In [3]:
from open_flamingo import create_model_and_transforms
from peft.src.peft import LoraModel, LoraConfig
import torch

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14", clip_vision_encoder_pretrained="openai", lang_encoder_path="./llama-7b-hf", tokenizer_path="./llama-7b-hf", cross_attn_every_n_layers=4
)

checkpoint_path = "/home/v-boli7/azure_storage/models/openflamingo/checkpoint.pt"
model.load_state_dict(torch.load(checkpoint_path), strict=False)

config = LoraConfig(
    peft_type="LORA",
    task_type="SEQ_2_SEQ_LM",
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.01,
)

lora_model = LoraModel(config, model)

total_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
print(f"Total number of trainable parameters in LoRA is {total_params / 1e6}M")


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /anaconda/envs/openflamingo/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...


  warn(msg)
  warn(msg)
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|██████████| 33/33 [00:08<00:00,  4.00it/s]


Flamingo model initialized with 1309919248 trainable parameters
Total number of trainable parameters in LoRA is 0.598016M


In [4]:
from lavis.datasets.builders import dataset_zoo
dataset_names = dataset_zoo.get_names()
print(dataset_names)

['aok_vqa', 'avsd_dialogue', 'coco_caption', 'coco_retrieval', 'coco_vqa', 'conceptual_caption_12m', 'conceptual_caption_3m', 'didemo_retrieval', 'flickr30k', 'gqa', 'imagenet', 'laion2B_multi', 'msrvtt_caption', 'msrvtt_qa', 'msrvtt_retrieval', 'msvd_caption', 'msvd_qa', 'nlvr', 'nocaps', 'ok_vqa', 'sbu_caption', 'snli_ve', 'vatex_caption', 'vg_caption', 'vg_vqa']


In [5]:
from lavis.datasets.builders import load_dataset

coco_dataset = load_dataset("coco_caption")
coco_train_set = coco_dataset['train']
for sample in coco_train_set:
    print(sample)
    print(sample['image'])
    print(sample['text_input'])
    break

def collate_fn(batch):
    return batch

coco_train_loader = torch.utils.data.DataLoader(coco_train_set, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)
for batch in coco_train_loader:
    print(batch)
    break

Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/coco_karpathy_train.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/coco_karpathy_val.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/coco_karpathy_test.json
{'image': <PIL.Image.Image image mode=RGB size=640x480 at 0x7F5BC45F02E0>, 'text_input': 'A woman wearing a net on her head cutting a cake. ', 'image_id': 0}
<PIL.Image.Image image mode=RGB size=640x480 at 0x7F5BC45F02E0>
A woman wearing a net on her head cutting a cake. 
[{'image': <PIL.Image.Image image mode=RGB size=640x480 at 0x7F5A8C1CB310>, 'text_input': 'A woman wearing a net on her head cutting a cake. ', 'image_id': 0}]


In [25]:
aokvqa_dataset = load_dataset("aok_vqa")
aokvqa_train_set = aokvqa_dataset['train']
for sample in aokvqa_train_set:
    print(sample)
    break

Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_train.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_val.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/specialized_vocab_train_lavis.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_test.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/specialized_vocab_train_lavis.json
{'image': <PIL.Image.Image image mode=RGB size=640x480 at 0x7FF489BECD00>, 'text_input': 'What is the man by the bags awaiting?', 'answers': ['ride', 'bus', 'taxi', 'travelling', 'traffic', 'cab', 'his ride'], 'weights': [0.2, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1]}


In [26]:
coco_vqa_dataset = load_dataset("coco_vqa")
coco_vqa_train_set = coco_vqa_dataset['train']
for sample in coco_vqa_train_set:
    print(sample)
    break

Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/vqa_train.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/vqa_val.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/vqa_val_eval.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/answer_list.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/v2_mscoco_val2014_annotations.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/vqa_test.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/coco/annotations/answer_list.json
{'image': <PIL.Image.Image image mode=RGB size=640x480 at 0x7FF4D6BCBD00>, 'text_input': 'Wh

In [5]:
from lavis.datasets.builders import load_dataset
vqav2_dataset = load_dataset("aok_vqa")
print(vqav2_dataset.keys())
print(len(vqav2_dataset["train"]))
print(vqav2_dataset["train"][0])


Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_train.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_val.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/specialized_vocab_train_lavis.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/aokvqa_v1p0_test.json
Using downloaded and verified file: /home/v-boli7/azure_storage/data/lavis/aokvqa/annotations/specialized_vocab_train_lavis.json
dict_keys(['train', 'val', 'test'])
17056
{'image': <PIL.Image.Image image mode=RGB size=640x480 at 0x7FE2981CC790>, 'text_input': 'What is the man by the bags awaiting?', 'answers': ['ride', 'bus', 'taxi', 'travelling', 'traffic', 'cab', 'his ride'], 'weights': [0.2, 0.1, 0.2, 0.1, 0.1, 0.2, 0.1]}


In [6]:
okvqa_dataset = load_dataset("ok_vqa")
okvqa_train_set = okvqa_dataset['train']
for sample in okvqa_train_set:
    print(sample)
    break

Downloading https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json to /home/v-boli7/azure_storage/data/lavis/okvqa/annotations/okvqa_train.json


100%|██████████| 2530489/2530489 [00:00<00:00, 76175726.40it/s]


Downloading https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json to /home/v-boli7/azure_storage/data/lavis/okvqa/annotations/vqa_val_eval.json


100%|██████████| 1389419/1389419 [00:00<00:00, 38321873.79it/s]


Downloading https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json to /home/v-boli7/azure_storage/data/lavis/okvqa/annotations/answer_list.json


100%|██████████| 152278/152278 [00:00<00:00, 9876756.68it/s]


Downloading https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json to /home/v-boli7/azure_storage/data/lavis/okvqa/annotations/OpenEnded_mscoco_val2014_questions.json


100%|██████████| 521744/521744 [00:00<00:00, 20632965.43it/s]


Downloading https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json to /home/v-boli7/azure_storage/data/lavis/okvqa/annotations/mscoco_val2014_annotations.json


100%|██████████| 8490544/8490544 [00:00<00:00, 12291234.65it/s]

{'image': <PIL.Image.Image image mode=RGB size=640x479 at 0x7FE183330730>, 'text_input': 'What is the hairstyle of the blond called?', 'answers': ['pony tail', 'braid', 'ponytail'], 'weights': [0.6, 0.2, 0.2]}



