In [1]:
from scripts.model import EvalModel
import os
from scripts.datasets import SQUAD_dataset

%load_ext autoreload
%autoreload 2

In [2]:
# Root directory that holds the model checkpoints. Override it by exporting
# CHECKPOINT_DIR before starting the kernel; otherwise the local default is used.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/mnt/d/models/")

In [3]:
# The language model and its tokenizer come from the same hub repository.
lm_repo = "anas-awadalla/mpt-1b-redpajama-200b"
pretrained_checkpoint = f"{CHECKPOINT_DIR}/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt"

# Arguments handed to EvalModel as a single dict.
model_args = dict(
    vision_encoder_path="ViT-L-14",
    vision_encoder_pretrained="openai",
    lm_path=lm_repo,
    lm_tokenizer_path=lm_repo,
    checkpoint_path=pretrained_checkpoint,
    cross_attn_every_n_layers=1,
    precision="bf16",
    device=0,
)

print(f"Loading Checkpoint from {CHECKPOINT_DIR}")
model = EvalModel(model_args)

Loading Checkpoint from /mnt/d/models/




You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
Flamingo model initialized with 1046992944 trainable parameters


In [4]:
# Dataset helper reused by later cells for image preprocessing
# (`image_preprocess_batch`), few-shot evaluation (`infer`) and the
# placeholder image used in the generation smoke test.
data = SQUAD_dataset()

In [5]:
def print_trainable_parameters(model):
    """Print the trainable and total parameter counts of ``model``.

    Args:
        model: Any object whose ``named_parameters()`` yields ``(name, param)``
            pairs where ``param`` exposes ``numel()`` and ``requires_grad``
            (e.g. a ``torch.nn.Module``).

    Prints a single line of the form
    ``trainable params: <n> || all params: <n> || trainable%: <pct>``.
    A model without any parameters is reported as ``0.00`` percent instead of
    raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        n_elements = param.numel()
        all_param += n_elements
        if param.requires_grad:
            trainable_params += n_elements

    # Guard the ratio so an empty model does not divide by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}"
    )

# Baseline parameter counts of the wrapped model, taken before any LoRA
# adapters are attached in the next cell.
print_trainable_parameters(model.model)

trainable params: 1046992944 || all params: 2559117360 || trainable%: 40.91


In [6]:
from peft import LoraConfig, get_peft_model

# Names of the sub-modules that receive LoRA adapters.
# NOTE(review): presumably 'Wqkv' is the language model's fused attention
# projection and 'to_q'/'to_k'/'to_v' are the cross-attention projections --
# confirm against the model definition.
lora_target_modules = ['Wqkv', 'to_q', 'to_k', 'to_v']

config = LoraConfig(
    r=256,
    lora_alpha=512,
    target_modules=lora_target_modules,
    lora_dropout=0.1,
    bias='none',
)

# Wrap the Flamingo model with the adapters and report the new trainable share.
lora_model = get_peft_model(model.model, config)
print_trainable_parameters(lora_model)

trainable params: 68419584 || all params: 2627536944 || trainable%: 2.60


In [7]:
import json

# JSON mapping of sample key -> generated text; the training cell iterates over
# its items and uses each key to locate the matching image file.
pseudo_dataset_path = '/mnt/d/datasets/psuedo_dataset.json'
with open(pseudo_dataset_path, 'r') as pseudo_file:
    pseudo_outputs = json.load(pseudo_file)

In [8]:
# Instruction text whose length is used by the training cell to slice a prefix
# off every pseudo-labelled sample (`batch[len(instruction)-2:]`).
# NOTE(review): presumably this is the exact prompt that was prepended when the
# pseudo dataset was generated -- confirm. Any edit to this literal, including
# whitespace, shifts that slice offset. The `\n` sequences are real newlines
# because the string is not raw.
instruction = '''
Suppose you are teaching a first grader. Answer the question according to the provided context and explain to the first grader your method to find the answer without referring back to the context in the form "Rationale: <reasons> \n Answer: <answer>" where <reasons> and <answer> are your response. Make your response as short as possible.\n
'''

In [9]:
# Invoked on the PEFT-wrapped model before the training cell runs.
# NOTE(review): presumably this clears media that an earlier generate() call
# cached on the Flamingo model -- confirm against the model implementation.
lora_model.uncache_media()

In [10]:
from torch import optim
from torch import nn
import torch
import tqdm
from PIL import Image

# ---------------------------------------------------------------------------
# LoRA fine-tuning on the pseudo-labelled samples.
# Every item of `pseudo_outputs` is one training step: the key selects the
# image file and the value supplies the text (one sample per step).
# ---------------------------------------------------------------------------

# NOTE(review): `criterion` is not referenced in this cell -- the loss is taken
# from the model's own forward pass. Kept so the name still exists for other
# cells; delete it if nothing else needs it.
criterion = nn.CrossEntropyLoss()

# Optimise and clip the same set of tensors: the ones that can receive
# gradients. Frozen base weights never get a gradient, so listing them in the
# optimizer only adds bookkeeping.
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=2e-4, weight_decay=0.01)
count = 0  # global step counter, used only for periodic logging

media_token_id = model.tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
# NOTE(review): not referenced below; kept for any other cell that reads it.
endofchunk_token_id = model.tokenizer("<|endofchunk|>", add_special_tokens=False)[
    "input_ids"
][-1]

# Use the device the model was loaded on instead of a second hard-coded index.
device = model_args["device"]

lora_model.train()

loss_vals = []
for epoch in range(1):
    for key, pseudo_output in tqdm.tqdm(pseudo_outputs.items()):
        # Slice off the leading part of the stored text (offset derived from the
        # instruction length) and wrap the remainder in the Flamingo image /
        # end-of-chunk markers.
        # NOTE(review): the `- 2` offset presumably accounts for how the
        # instruction was embedded in the stored text -- confirm on a sample.
        batch = "<image>" + pseudo_output[len(instruction)-2:] + "<|endofchunk|>"

        # Close the image file as soon as it has been turned into model input;
        # otherwise one file handle per sample stays open until garbage collection.
        with Image.open(f'/mnt/d/datasets/images/{key}.png') as image:
            image_token = data.image_preprocess_batch(model.image_processor, [image])
        input_ids, attention_mask = model._prepare_text(batch)

        # Padding and <image> positions are excluded from the loss via the
        # -100 label value.
        labels = input_ids.clone()
        labels[labels == model.tokenizer.pad_token_id] = -100
        labels[labels == media_token_id] = -100

        loss = lora_model(
            image_token.to(device, dtype=torch.bfloat16),
            input_ids,
            attention_mask,
            labels=labels,
        )[0]

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        loss_vals.append(loss.item())
        torch.cuda.empty_cache()

        count += 1
        if count % 100 == 0:
            # Running mean over every step so far, not just the last 100.
            print(f"loss: {sum(loss_vals)/len(loss_vals)}")

# Leave training mode so the evaluation / generation cells that follow do not
# run with LoRA dropout still active.
lora_model.eval()


  5%|▌         | 100/1974 [00:26<07:27,  4.19it/s]

loss: 2.25265625


 10%|█         | 200/1974 [00:49<06:54,  4.28it/s]

loss: 2.2352734375


 15%|█▌        | 300/1974 [01:12<06:17,  4.43it/s]

loss: 2.2245572916666667


 20%|██        | 400/1974 [01:36<06:26,  4.07it/s]

loss: 2.22423828125


 25%|██▌       | 500/1974 [02:02<05:41,  4.32it/s]

loss: 2.20275


 30%|███       | 600/1974 [02:25<05:13,  4.39it/s]

loss: 2.1969010416666666


 35%|███▌      | 700/1974 [02:47<04:40,  4.55it/s]

loss: 2.2007924107142856


 41%|████      | 800/1974 [03:10<04:32,  4.30it/s]

loss: 2.201474609375


 46%|████▌     | 900/1974 [03:33<04:12,  4.25it/s]

loss: 2.2047829861111112


 51%|█████     | 1000/1974 [03:57<03:49,  4.25it/s]

loss: 2.202515625


 56%|█████▌    | 1100/1974 [04:19<03:22,  4.32it/s]

loss: 2.1987642045454545


 61%|██████    | 1200/1974 [04:41<03:01,  4.27it/s]

loss: 2.1999479166666664


 66%|██████▌   | 1300/1974 [05:04<02:37,  4.27it/s]

loss: 2.1992608173076924


 71%|███████   | 1400/1974 [05:28<02:06,  4.54it/s]

loss: 2.196434151785714


 76%|███████▌  | 1500/1974 [05:50<01:48,  4.39it/s]

loss: 2.193619791666667


 81%|████████  | 1600/1974 [06:12<01:28,  4.24it/s]

loss: 2.1916552734375


 86%|████████▌ | 1700/1974 [06:35<01:05,  4.19it/s]

loss: 2.193809742647059


 91%|█████████ | 1800/1974 [06:58<00:40,  4.26it/s]

loss: 2.195842013888889


 96%|█████████▋| 1900/1974 [07:22<00:17,  4.25it/s]

loss: 2.1938034539473685


100%|██████████| 1974/1974 [07:39<00:00,  4.30it/s]


In [18]:
# NOTE(review): presumably read by `data.infer` to cap the generated answer
# length -- confirm in scripts/datasets. Environment values must be strings.
os.environ['MAX_LEN'] = "20"

In [19]:
# Few-shot evaluation run over the SQuAD helper using the wrapped model.
infer_options = {"early_stop": 2000, "shots": 4}
data.infer(
    model.model,
    model.image_processor,
    model.tokenizer,
    '/mnt/d/datasets/',
    **infer_options,
)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 1/2000 [00:01<1:01:33,  1.85s/it]

denver broncos


  0%|          | 2/2000 [00:03<55:11,  1.66s/it]  

denver broncos


  0%|          | 3/2000 [00:05<1:02:06,  1.87s/it]

denver broncos


  0%|          | 4/2000 [00:06<55:38,  1.67s/it]  

denver broncos


  0%|          | 5/2000 [00:12<1:41:24,  3.05s/it]

golden anniversary


  0%|          | 6/2000 [00:18<2:12:39,  3.99s/it]

theme of super bowl 50 was golden anniversary


  0%|          | 7/2000 [00:21<2:04:32,  3.75s/it]

february 7


  0%|          | 8/2000 [00:26<2:23:35,  4.33s/it]

american football conference afc


  0%|          | 8/2000 [00:30<2:08:27,  3.87s/it]


KeyboardInterrupt: 

In [19]:
import json
from evaluate import load

DATASET_DIR = "/mnt/d/datasets/"


def _read_dataset_json(filename):
    """Return the parsed content of one JSON file stored under DATASET_DIR."""
    with open(f"{DATASET_DIR}{filename}", "r") as handle:
        return json.load(handle)


predictions = _read_dataset_json("squad_resutls.json")
references = _read_dataset_json("squad_references.json")

# Score the stored predictions with the SQuAD metric (exact match and F1).
squad_metric = load("squad")
results = squad_metric.compute(predictions=predictions, references=references)
print(results)

{'exact_match': 17.0, 'f1': 31.420929103073792}


In [20]:
# Fold the LoRA adapter weights into the base layers and drop the PEFT wrapper;
# the resulting model's state dict is what the next cell saves.
merged = lora_model.merge_and_unload()

In [21]:
# Persist the merged weights. The target directory is created first because
# torch.save does not create missing parent directories.
merged_checkpoint_path = '/mnt/d/models/fine-tuned-nl-flamingo-visual/checkpoint.pt'
os.makedirs(os.path.dirname(merged_checkpoint_path), exist_ok=True)
torch.save(merged.state_dict(), merged_checkpoint_path)

In [13]:
# Quick text-only sanity check: pair a prompt with the dataset's placeholder
# image and let the model continue it.
text = "Sleeping for less than 5 hours is"
image = data.palceholder_image
image_token = data.image_preprocess_batch(model.image_processor, [image])
input_ids, attention_mask = model._prepare_text(text)

vision_input = image_token.to(0, dtype=torch.bfloat16)
# NOTE(review): pad_token_id is a hard-coded vocabulary index -- confirm it is
# the id intended for this tokenizer.
generation_options = dict(max_length=20, num_beams=3, pad_token_id=50277)
output = model.model.generate(
    vision_input, input_ids, attention_mask, **generation_options
)

answer = model.tokenizer.decode(output[0])
answer

'Sleeping for less than 5 hours is not recommended. Sleeping for less than 7 hours is not'