In [1]:
from scripts.model import EvalModel
import os
from scripts.datasets import SQUAD_dataset

%load_ext autoreload
%autoreload 2

In [2]:
CHECKPOINT_DIR = (
    "/mnt/d/models/"
    if os.environ.get("CHECKPOINT_DIR") == None
    else os.environ["CHECKPOINT_DIR"]
)

In [3]:
model_args = {
    "vision_encoder_path": "ViT-L-14",
    "vision_encoder_pretrained": "openai",
    "lm_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "lm_tokenizer_path": "anas-awadalla/mpt-1b-redpajama-200b",
    "checkpoint_path": f"{CHECKPOINT_DIR}/OpenFlamingo-3B-vitl-mpt1b/checkpoint.pt",
    "cross_attn_every_n_layers": 1,
    "precision": "bf16",
    "device": 0,
}

print(f"Loading Checkpoint from {CHECKPOINT_DIR}")
model = EvalModel(model_args)

Loading Checkpoint from /mnt/d/models/




You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.
Flamingo model initialized with 1046992944 trainable parameters


In [4]:
data = SQUAD_dataset()

In [5]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

print_trainable_parameters(model.model)

trainable params: 1046992944 || all params: 2559117360 || trainable%: 40.91


In [6]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r = 32,
    lora_alpha = 16,
    target_modules = ['Wqkv'],
    lora_dropout=0.1,
    bias='none',
)
lora_model = get_peft_model(model.model, config)
print_trainable_parameters(lora_model)

trainable params: 6291456 || all params: 2565408816 || trainable%: 0.25


In [6]:
import json

with open('/mnt/d/datasets/psuedo_dataset.json', 'r') as f:
    pseudo_outputs = json.load(f)

In [7]:
instruction = '''
Suppose you are teaching a first grader. Answer the question according to the provided context and explain to the first grader your method to find the answer without referring back to the context in the form "Rationale: <reasons> \n Answer: <answer>" where <reasons> and <answer> are your response. Make your response as short as possible.\n
'''

In [14]:
model.model.uncache_media()

In [13]:
from torch import optim
from torch import nn
import torch
import tqdm

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.model.parameters(), lr=2e-5, weight_decay=0.01)

loss_vals = []
for epoch in range(1):
    counter = 0
    for batch in tqdm.tqdm(pseudo_outputs.values()):

        batch = batch[len(instruction)-2:] +"<|endofchunk|>"

        token = model.tokenizer(batch)
        image = data.palceholder_image
        image_token = data.image_preprocess_batch(model.image_processor, [image])
        #image_tokens = torch.cat([image_token], dim=0)
        input_ids, attention_mask = model._prepare_text(batch)

        output = model.model(
                    image_token.to(0, dtype=torch.bfloat16),
                    input_ids[:, :-1],
                    attention_mask[:, :-1]
                )
        loss = criterion(output.logits.reshape(1, 50280, -1), input_ids[:, 1:])
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_vals.append(loss.cpu().detach().cpu().item())
        torch.cuda.empty_cache()
    print(f"loss: {sum(loss_vals)/len(loss_vals)}")


 23%|██▎       | 461/1974 [02:40<08:48,  2.86it/s]


KeyboardInterrupt: 

In [15]:
data.infer(
    model.model,
    model.image_processor,
    model.tokenizer,
    '/mnt/d/datasets/',
    early_stop=2000,
)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 1/2000 [00:00<19:18,  1.72it/s]

2
方法εί


  0%|          | 2/2000 [00:01<20:45,  1.60it/s]

2
方法εί


  0%|          | 3/2000 [00:03<52:27,  1.58s/it]

12
方法存界存界存εί方法并存εί方法


  0%|          | 4/2000 [00:04<37:36,  1.13s/it]

2
方法εί


  0%|          | 5/2000 [00:04<27:40,  1.20it/s]

1
方法


  0%|          | 6/2000 [00:05<23:14,  1.43it/s]

2
εί


  0%|          | 7/2000 [00:05<23:00,  1.44it/s]

3
方法存εί


  0%|          | 8/2000 [00:06<27:13,  1.22it/s]

3
方法存εί


  0%|          | 9/2000 [00:07<22:55,  1.45it/s]

2
方法εί


  0%|          | 10/2000 [00:08<26:13,  1.26it/s]

3
方法存εί


  1%|          | 11/2000 [00:09<25:17,  1.31it/s]

3
方法存εί


  1%|          | 12/2000 [00:09<22:02,  1.50it/s]

2
方法εί


  1%|          | 13/2000 [00:11<40:33,  1.22s/it]

11
方法存εί方法并存εί方法并存εί


  1%|          | 14/2000 [00:12<33:00,  1.00it/s]

2
方法εί


  1%|          | 15/2000 [00:13<32:13,  1.03it/s]

3
方法存εί


  1%|          | 16/2000 [00:14<29:26,  1.12it/s]

3
方法存εί


  1%|          | 17/2000 [00:14<23:27,  1.41it/s]

1
方法


  1%|          | 18/2000 [00:15<22:59,  1.44it/s]

2
方法εί


  1%|          | 19/2000 [00:15<20:36,  1.60it/s]

2
方法εί


  1%|          | 20/2000 [00:15<17:03,  1.93it/s]

1
方法


  1%|          | 21/2000 [00:16<18:53,  1.75it/s]

3
方法存εί


  1%|          | 22/2000 [00:16<17:05,  1.93it/s]

1
方法


  1%|          | 23/2000 [00:17<17:29,  1.88it/s]

2
方法εί


  1%|          | 24/2000 [00:17<17:49,  1.85it/s]

2
法εί


  1%|          | 24/2000 [00:18<25:24,  1.30it/s]


KeyboardInterrupt: 

In [None]:
import json
from evaluate import load

DATASET_DIR = "/mnt/d/datasets/"

with open(f"{DATASET_DIR}squad_resutls.json", "r") as f:
    predictions = json.load(f)
with open(f"{DATASET_DIR}squad_references.json", "r") as f:
    references = json.load(f)

squad_metric = load("squad")
results = squad_metric.compute(predictions=predictions, references=references)
print(results)

{'exact_match': 0.0, 'f1': 0.0}


In [57]:
merged = lora_model.merge_and_unload()

In [61]:
torch.save(merged.state_dict(), '/mnt/d/models/fine-tuned-nl-flamingo/checkpoint.pt')

In [22]:
text = "Hello?<|endofchunk|> Hi"
image = data.palceholder_image
image_token = data.image_preprocess_batch(model.image_processor, [image])

input_ids, attention_mask = model._prepare_text(text)
output = model.model.generate(
    image_token.to(0, dtype=torch.bfloat16),
    input_ids,
    attention_mask,
    max_length=1000,
    num_beams=3,
    pad_token_id=50277,
)
answer = model.tokenizer.decode(output[0])
answer

'Hello?<|endofchunk|> Hi<|endofchunk|>'