In [1]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np

from transformers import SamModel, SamProcessor
from transformers import CLIPVisionModel
import torchvision.transforms as T
import torch.nn.functional as F

import math
import os
import tiktoken

from torch.utils.data import DataLoader, Dataset

from deepencoder import CLIP_modified, conv_block, DeepEncoder
from dataloader import OCR_dataset, ocr_collate
from tqdm import tqdm

from helper import text_to_token_ids, token_ids_to_text, download_and_load_gpt2
from knowledge_transfer import load_weights_into_gpt_modified
from model import GPTModel

from pipeline import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Split fractions for the dataset files; whatever is left after train + test
# becomes the validation split.
train_frac, test_frac = 0.8, 0.15

# One sample per step, everything on the CPU.
batch_size, device = 1, "cpu"

In [None]:
# GPT-2 BPE tokenizer extended with one extra special token, "<image>".
tokenizer = tiktoken.get_encoding('gpt2')

# NOTE(review): the id n_vocab + 1 leaves id n_vocab unused; the resulting
# vocabulary size (50259) is what the decoder's embedding is sized from, so
# changing this id breaks compatibility with existing checkpoints.
special_tokens = {"<image>": tokenizer.n_vocab + 1}
tokenizer_modified = tiktoken.Encoding(
    name="gpt2_with_image",
    pat_str=tokenizer._pat_str,
    mergeable_ranks=tokenizer._mergeable_ranks,
    special_tokens=dict(tokenizer._special_tokens, **special_tokens),
)
vocab_size = tokenizer_modified.n_vocab
vocab_size

50259

In [4]:
# Split the files under `dataset/` into train / test / validation.
# os.listdir returns entries in filesystem-dependent order, so sort them to make
# the split reproducible across machines and re-runs.
files = sorted(os.listdir('dataset'))
l = len(files)

assert 0 < train_frac and 0 <= test_frac and train_frac + test_frac <= 1, \
    "train_frac + test_frac must be within (0, 1]"

train_pos = int(l * train_frac)
test_pos  = int(l * test_frac)

train_files = files[:train_pos]
test_files  = files[train_pos : train_pos + test_pos]
val_files   = files[train_pos + test_pos :]   # validation gets the remainder

len(train_files), len(test_files), len(val_files)

(409, 76, 27)

In [9]:
def make_ocr_loader(split_files, shuffle, drop_last):
    """Build a DataLoader over `split_files` read from the `dataset` directory.

    Shared by the train/test/val loaders so their settings cannot drift apart.

    Args:
        split_files: file names (relative to `dataset`) belonging to this split.
        shuffle: whether to reshuffle samples every epoch.
        drop_last: whether to drop a final batch smaller than `batch_size`.
    """
    return DataLoader(
        dataset=OCR_dataset(
            dataset_file_name='dataset',
            files=split_files,
            tokenizer=tokenizer_modified,
        ),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=ocr_collate,
        # Pinned host memory only speeds up host->accelerator copies; on a CPU
        # device it does nothing except emit a warning.
        pin_memory=torch.device(device).type != "cpu",
        drop_last=drop_last,
    )


# Only the training loader shuffles and drops a ragged final batch; the
# evaluation loaders keep every sample so no test/val data is silently skipped.
train_dl = make_ocr_loader(train_files, shuffle=True,  drop_last=True)
test_dl  = make_ocr_loader(test_files,  shuffle=False, drop_last=False)
val_dl   = make_ocr_loader(val_files,   shuffle=False, drop_last=False)

# Inspect one training batch to check the shapes produced by ocr_collate.
one_batch = next(iter(train_dl))

one_batch_input_ids  = one_batch["input_ids"]
one_batch_target_ids = one_batch["target_ids"]
one_batch_images     = one_batch["images"]

print(f"input_ids: {one_batch_input_ids.shape}")
print(f"target_ids: {one_batch_target_ids.shape}")
print(f"images: {one_batch_images.shape}")



input_ids: torch.Size([1, 18])
target_ids: torch.Size([1, 18])
images: torch.Size([1, 3, 1024, 1024])


In [27]:
# Pretrained vision backbones (SAM ViT-B and CLIP ViT-L/14), each placed on `device`.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base")
sam_model = sam_model.to(device)

clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
clip_model = clip_model.to(device)

# Wrap both backbones into the project's DeepEncoder vision tower.
deep_encoder = DeepEncoder(sam_model=sam_model, clip_model=clip_model)
deep_encoder

SAM params: 93,735,472
CLIP params: 303,179,776

In [6]:
# Base GPT-2 hyperparameters; the size-specific fields are overridden below.
GPT_CONFIG_124M = {
    "vocab_size"     : tokenizer.n_vocab,     # 50257
    "context_length" : 1024,                  # The maximum number of tokens the model can process at once
    "embedding_dim"  : 768,                   # The number of features used to represent each token
    "n_heads"        : 12,
    "n_layers"       : 12,                    # How many transformer blocks
    "drop_rate"      : 0.1,
    "qkv_bias"       : False
}

model_configs = {
    "gpt2-small (124M)": {"embedding_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"embedding_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"embedding_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"embedding_dim": 1600, "n_layers": 48, "n_heads": 25},
}

model_name = "gpt2-large (774M)"
# Derive the checkpoint size (e.g. "774M") from model_name so the downloaded
# weights always match the architecture selected above.
model_size = model_name.split("(")[-1].rstrip(")")

NEW_CONFIG = GPT_CONFIG_124M.copy()
NEW_CONFIG.update(model_configs[model_name])
NEW_CONFIG.update({
    "context_length": 1024,
    "qkv_bias": True,
    "vocab_size": tokenizer_modified.n_vocab,   # includes the added <image> token
    "vision_dim": 1280,                         # presumably the DeepEncoder feature width — TODO confirm
})

settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
gpt2 = GPTModel(NEW_CONFIG)
load_weights_into_gpt_modified(gpt2, params)
# Keep the decoder on the same device as the vision encoders.
gpt2.to(device)
gpt2.token_embedding

File already exists and is up-to-date: gpt2/774M/checkpoint
File already exists and is up-to-date: gpt2/774M/encoder.json
File already exists and is up-to-date: gpt2/774M/hparams.json
File already exists and is up-to-date: gpt2/774M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/774M/model.ckpt.index
File already exists and is up-to-date: gpt2/774M/model.ckpt.meta
File already exists and is up-to-date: gpt2/774M/vocab.bpe


Embedding(50259, 1280)

In [None]:
# Shape check of the full image+text forward pass on one batch. no_grad keeps the
# notebook-global `logits` from holding the autograd graph of the whole model.
with torch.no_grad():
    logits = vision_pipeline(deep_encoder    = deep_encoder,
                             deep_decoder    = gpt2,
                             input_ids_batch = one_batch_input_ids,
                             image_batch     = one_batch_images,
                             tokenizer       = tokenizer_modified
                             )
logits.shape

torch.Size([1, 289, 50259])

In [None]:
# Loss of the untrained pipeline on a single batch. This is only a sanity check,
# so no autograd graph is built or kept alive in the global `loss`.
with torch.no_grad():
    loss = calc_loss_batch(pipline      = vision_pipeline,
                           deep_encoder = deep_encoder,
                           deep_decoder = gpt2,
                           input_batch  = one_batch_input_ids,
                           target_batch = one_batch_target_ids,
                           image_batch  = one_batch_images,
                           tokenizer    = tokenizer_modified
                           )
loss

tensor(9.1238, grad_fn=<NllLossBackward0>)

In [None]:
# Average loss over the first training batch, evaluated without autograd
# (the same way train_simple evaluates calc_loss_loader).
with torch.no_grad():
    batch_loss = calc_loss_loader(dataloader   = train_dl,
                                  deep_encoder = deep_encoder,
                                  deep_decoder = gpt2,
                                  tokenizer    = tokenizer_modified,
                                  num_batches  = 1)
batch_loss

100%|██████████| 1/1 [00:42<00:00, 42.02s/it]


9.664215087890625

In [None]:
# Text-only generation from a short prompt to eyeball the decoder's output.
sample_gen = generate_and_print_samples(
    model=gpt2,
    device=device,
    tokenizer=tokenizer_modified,
    start_context="airplane is",
    cfg=NEW_CONFIG,
    max_new_tokens=4,
)
sample_gen

100%|██████████| 4/4 [00:22<00:00,  5.72s/it]

airplane is a a its its





In [None]:
def train_simple(train_loader, val_loader, 
                 deep_encoder, deep_decoder, cfg,
                 optimizer, device, num_epochs, 
                 eval_freq, eval_itter, 
                 start_context, 
                 tokenizer, 
                 verbose = True, max_new_tokens = 50, 
                 save_itter = 5, 
                 save_path = "gpt2/OCR_finetuned/gpt2_774M_finetuned.pth",
                 load_pretained = True):
    """Fine-tune `deep_decoder` on image+text batches via `vision_pipeline`.

    Args:
        train_loader, val_loader: iterables of dicts with "input_ids",
            "target_ids" and "images" entries.
        deep_encoder: vision encoder forwarded to the loss helpers. It is only
            updated if `optimizer` includes its parameters.
        deep_decoder: language model being trained and checkpointed.
        cfg: model config forwarded to `generate_and_print_samples`.
        optimizer: optimizer stepping the trainable parameters.
        device: forwarded to `generate_and_print_samples`. Batches are not moved
            here; presumably the pipeline handles placement — TODO confirm.
        num_epochs: number of epochs to run, counted from the resumed epoch.
        eval_freq: estimate train/val loss every `eval_freq` optimizer steps.
        eval_itter: number of batches used for each loss estimate.
        start_context: prompt for the sample printed after each epoch.
        tokenizer: tokenizer forwarded to the loss and sampling helpers.
        verbose: print a generated sample after each epoch.
        max_new_tokens: length of that sample.
        save_itter: write a checkpoint every `save_itter` optimizer steps.
        save_path: checkpoint file; missing parent directories are created.
        load_pretained: resume decoder/optimizer state and epoch from `save_path`.

    Returns:
        (train_losses, val_losses, track_tokens_seen), one entry per evaluation.
        `track_tokens_seen` holds the cumulative number of input tokens
        processed at each evaluation.
    """
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    if load_pretained:
        checkpoint = torch.load(save_path, map_location="cpu")
        # Checkpoints are written mid-epoch, so "epoch" is the epoch that was
        # running; training restarts from the beginning of that epoch.
        epoch_continue = checkpoint["epoch"]
        deep_decoder.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    else:
        epoch_continue = 0

    # torch.save does not create missing parent directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epoch_continue, num_epochs + epoch_continue):
        deep_decoder.train()
        for batch in train_loader:
            input_batch  = batch["input_ids"]
            target_batch = batch["target_ids"]
            image_batch  = batch["images"]

            optimizer.zero_grad()
            loss = calc_loss_batch(pipline      = vision_pipeline,
                                   deep_encoder = deep_encoder,
                                   deep_decoder = deep_decoder,
                                   input_batch  = input_batch,
                                   target_batch = target_batch,
                                   image_batch  = image_batch,
                                   tokenizer    = tokenizer
                                   )
            loss.backward()
            optimizer.step()

            # Running total, not the size of the current batch.
            tokens_seen += input_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                deep_decoder.eval()
                with torch.no_grad():
                    train_loss = calc_loss_loader(dataloader   = train_loader,
                                                  deep_encoder = deep_encoder,
                                                  deep_decoder = deep_decoder,
                                                  tokenizer    = tokenizer,
                                                  num_batches  = eval_itter)
                    val_loss   = calc_loss_loader(dataloader   = val_loader,
                                                  deep_encoder = deep_encoder,
                                                  deep_decoder = deep_decoder,
                                                  tokenizer    = tokenizer,
                                                  num_batches  = eval_itter)
                deep_decoder.train()
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}"
                      )

            if global_step % save_itter == 0:
                checkpoint = {
                    "epoch"          : epoch,
                    "model_state"    : deep_decoder.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                }
                torch.save(checkpoint, save_path)
                print("saved")

        # Print a sample continuation of `start_context` after every epoch.
        if verbose:
            generate_and_print_samples(model          = deep_decoder,
                                       device         = device,
                                       tokenizer      = tokenizer,
                                       start_context  = start_context,
                                       cfg            = cfg,
                                       max_new_tokens = max_new_tokens)

    return train_losses, val_losses, track_tokens_seen

In [None]:
# Only the GPT-2 decoder's parameters are given to the optimizer.
optimizer = torch.optim.AdamW(gpt2.parameters(), lr=0.00005, weight_decay=0.1)
num_epochs = 2
checkpoint_path = "gpt2/OCR_finetuned/gpt2_774M_finetuned.pth"

train_losses, val_losses, tokens_seen = train_simple(train_loader   = train_dl,
                                                     val_loader     = val_dl,
                                                     deep_encoder   = deep_encoder,
                                                     deep_decoder   = gpt2,
                                                     cfg            = NEW_CONFIG,
                                                     device         = device,
                                                     num_epochs     = num_epochs,
                                                     eval_freq      = 5,
                                                     eval_itter     = 2,
                                                     start_context  = "Hello",
                                                     tokenizer      = tokenizer_modified,
                                                     verbose        = True,
                                                     optimizer      = optimizer,
                                                     max_new_tokens = 10,
                                                     # writes a full model + optimizer checkpoint after every step
                                                     save_itter     = 1,
                                                     save_path      = checkpoint_path,
                                                     # resume only if a checkpoint exists, so a fresh run does not fail in torch.load
                                                     load_pretained = os.path.exists(checkpoint_path))

100%|██████████| 1/1 [00:59<00:00, 59.88s/it]
100%|██████████| 1/1 [00:27<00:00, 27.48s/it]


Ep 1 (Step 000000): Train loss 11.025, Val loss 10.258


100%|██████████| 1/1 [00:49<00:00, 49.55s/it]
100%|██████████| 1/1 [00:16<00:00, 16.81s/it]


Ep 1 (Step 000005): Train loss 5.365, Val loss 6.185


KeyboardInterrupt: 

In [None]:
# Manually snapshot the decoder and optimizer, e.g. after interrupting training.
# The running epoch is local to train_simple, so the epoch to resume from is
# recorded explicitly here.
resume_epoch = 0  # TODO: set to the epoch training should resume from
checkpoint = {
    "epoch"          : resume_epoch,
    "model_state"    : gpt2.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

torch.save(checkpoint, "checkpoint.pth")

In [None]:
# Restore fine-tuned decoder weights. train_simple writes a dict with "epoch",
# "model_state" and "optimizer_state"; a bare state dict is also accepted.
state = torch.load("gpt2/OCR_finetuned/gpt2_774M_finetuned.pth", map_location="cpu")
gpt2.load_state_dict(state["model_state"] if "model_state" in state else state)