In [139]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np

from transformers import SamModel, SamProcessor
from transformers import CLIPVisionModel
import torchvision.transforms as T
import torch.nn.functional as F

import math
import os
import tiktoken

from torch.utils.data import DataLoader, Dataset

from deepencoder import CLIP_modified, conv_block, DeepEncoder
from dataloader import OCR_dataset, ocr_collate
from tqdm import tqdm

from helper import text_to_token_ids, token_ids_to_text, download_and_load_gpt2
from knowledge_transfer import load_weights_into_gpt_modified
from model import GPTModel

In [None]:
# Dataset split fractions: the first 80% of the files are used for training,
# the next 15% for testing; the remainder is used for validation.
train_frac, test_frac = 0.8, 0.15

# Mini-batch size shared by every DataLoader in this notebook.
batch_size = 2

# Device the models are moved to.
device = "cpu"

In [21]:
# Start from the stock GPT-2 BPE encoding.
tokenizer = tiktoken.get_encoding('gpt2')

# Register "<image>" as an additional special token. Its id is n_vocab + 1,
# so the id equal to the base n_vocab stays unassigned and the extended
# encoding reports two more entries than the base one.
special_tokens = {"<image>": tokenizer.n_vocab+1}

# Keep the special tokens of the base encoding and add the new one on top.
extended_special_tokens = dict(tokenizer._special_tokens)
extended_special_tokens.update(special_tokens)

tokenizer_modified = tiktoken.Encoding(
    name="gpt2_with_image",
    pat_str=tokenizer._pat_str,
    mergeable_ranks=tokenizer._mergeable_ranks,
    special_tokens=extended_special_tokens,
)

In [34]:
# Vocabulary size reported by the extended encoding (base GPT-2 plus "<image>").
vocab_size = tokenizer_modified.n_vocab
vocab_size

50259

## 1. Dataset

In [22]:
# os.listdir returns entries in arbitrary order, so sort them to make the
# train/test/val split reproducible across runs and machines.
files = sorted(os.listdir('dataset'))
l = len(files)

# Split boundaries: the first `train_frac` of the files go to train, the next
# `test_frac` to test, and everything left over to validation.
train_pos = int(l * train_frac)
test_pos  = int(l * test_frac)

train_files = files[: train_pos]
test_files = files[train_pos : train_pos + test_pos]
val_files  = files[train_pos + test_pos : ]

# The three slices must partition the file list exactly.
assert len(train_files) + len(test_files) + len(val_files) == l

len(train_files), len(test_files), len(val_files)

(409, 76, 27)

In [76]:
def _make_loader(files, shuffle):
    """Build a DataLoader over `files` taken from the 'dataset' directory.

    Args:
        files: file names handed to OCR_dataset.
        shuffle: whether the loader reshuffles the samples every epoch.

    Returns:
        A DataLoader using the notebook-wide `batch_size`, the extended
        tokenizer and the project's `ocr_collate` function.
    """
    return DataLoader(
        dataset=OCR_dataset(
            dataset_file_name='dataset',
            files=files,
            tokenizer=tokenizer_modified,
        ),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=ocr_collate,
        # Pinned host memory only helps host-to-GPU copies; on CPU it is
        # useless and triggers a warning, so enable it only for CUDA.
        pin_memory=torch.device(device).type == "cuda",
        # A trailing partial batch is discarded in every split. Other cells
        # iterate over exactly `batch_size` samples, so full batches are
        # required here.
        drop_last=True,
    )


# Only the training loader shuffles; test and validation keep file order.
train_dl = _make_loader(train_files, shuffle=True)
test_dl = _make_loader(test_files, shuffle=False)
val_dl = _make_loader(val_files, shuffle=False)

# Pull a single batch from the training loader to sanity-check tensor shapes.
one_batch = next(iter(train_dl))

# Unpack the three entries of the collated batch into notebook-level names.
one_batch_input_ids, one_batch_target_ids, one_batch_images = (
    one_batch[key] for key in ("input_ids", "target_ids", "images")
)

for _label, _tensor in (
    ("input_ids", one_batch_input_ids),
    ("target_ids", one_batch_target_ids),
    ("images", one_batch_images),
):
    print(f"{_label}: {_tensor.shape}")



input_ids: torch.Size([2, 21])
target_ids: torch.Size([2, 21])
images: torch.Size([2, 3, 1024, 1024])


## 2. Vision Tokens

In [3]:
# Load the pretrained SAM and CLIP-vision checkpoints by their Hugging Face
# model ids and move both to `device`.
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
clip_model = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

In [4]:
# Build the project's DeepEncoder from the two vision backbones; the bare
# expression on the last line displays the resulting object.
deep_encoder = DeepEncoder(sam_model = sam_model, clip_model = clip_model)
deep_encoder

SAM params: 93,735,472
CLIP params: 303,179,776

In [None]:
# Encode the sample batch of images into vision tokens and inspect the shape.
vision_tokens = deep_encoder(one_batch_images)
vision_tokens.shape

torch.Size([2, 273, 1280])

## 3. LLM

In [None]:
# Load the GPT-2 774M settings and parameters through the project helper,
# using the local "gpt2" directory (presumably downloading files if missing).
# NOTE(review): the model-construction cell further down repeats this exact
# call, so one of the two loads is redundant.
settings, params = download_and_load_gpt2(model_size="774M", models_dir="gpt2")

In [None]:
# Baseline hyperparameters; several entries are overridden further down for
# the model size that is actually loaded.
GPT_CONFIG_124M = dict(
    vocab_size=tokenizer.n_vocab,   # base GPT-2 vocabulary (50257)
    context_length=1024,            # maximum number of tokens processed at once
    embedding_dim=768,              # number of features representing each token
    n_heads=12,
    n_layers=12,                    # number of transformer blocks
    drop_rate=0.1,
    qkv_bias=False,
)

# Per-size overrides for the GPT-2 family.
model_configs = {
    "gpt2-small (124M)": dict(embedding_dim=768, n_layers=12, n_heads=12),
    "gpt2-medium (355M)": dict(embedding_dim=1024, n_layers=24, n_heads=16),
    "gpt2-large (774M)": dict(embedding_dim=1280, n_layers=36, n_heads=20),
    "gpt2-xl (1558M)": dict(embedding_dim=1600, n_layers=48, n_heads=25),
}

model_name = "gpt2-large (774M)"

# Layer the overrides on top of the baseline: size-specific values first,
# then the settings for this notebook (extended vocabulary, QKV bias, and a
# vision dimension equal to the 1280-wide vision tokens shown earlier).
NEW_CONFIG = {
    **GPT_CONFIG_124M,
    **model_configs[model_name],
    "context_length": 1024,
    "qkv_bias": True,
    "vocab_size": tokenizer_modified.n_vocab,
    "vision_dim": 1280,
}

# Load the pretrained parameters, build the model, and copy the weights in.
settings, params = download_and_load_gpt2(model_size="774M", models_dir="gpt2")
gpt2 = GPTModel(NEW_CONFIG)
load_weights_into_gpt_modified(gpt2, params)

# Display the embedding table to confirm it has the extended vocabulary size.
gpt2.token_embedding

File already exists and is up-to-date: gpt2/774M/checkpoint
File already exists and is up-to-date: gpt2/774M/encoder.json
File already exists and is up-to-date: gpt2/774M/hparams.json
File already exists and is up-to-date: gpt2/774M/model.ckpt.data-00000-of-00001
File already exists and is up-to-date: gpt2/774M/model.ckpt.index
File already exists and is up-to-date: gpt2/774M/model.ckpt.meta
File already exists and is up-to-date: gpt2/774M/vocab.bpe


Embedding(50259, 1280)

In [77]:
# Look up the token embeddings for the sample batch and inspect the shape.
text_embeds = gpt2.token_embedding(one_batch_input_ids)
text_embeds.shape

torch.Size([2, 21, 1280])

In [None]:
# Look up the id of the "<image>" placeholder; its embedding is swapped out
# for the vision tokens produced by SAM and CLIP.
image_token_id = text_to_token_ids("<image>", tokenizer_modified)

# The mask does not depend on the sample index, so compute it once.
image_token_mask = (image_token_id == one_batch_input_ids)

final_embeds = []
# Iterate over the samples actually present in the batch instead of relying
# on the global batch size.
for sample_idx in range(one_batch_input_ids.shape[0]):
    image_positions = torch.where(image_token_mask[sample_idx])[0]
    # .item() requires exactly one <image> token per sample.
    img_pos = image_positions.squeeze().item()

    before = text_embeds[sample_idx, :img_pos]
    after = text_embeds[sample_idx, img_pos+1:]

    # text before <image> + vision tokens + text after <image>
    merged = torch.cat((before, vision_tokens[sample_idx], after), dim = 0)
    final_embeds.append(merged)

# Right-pad along the sequence axis so the samples can be stacked. These are
# float embedding vectors, so the filler must be a neutral 0.0 and not the
# token id 50256.
max_len = max(e.shape[0] for e in final_embeds)
padded_embeds = torch.stack([
    F.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0.0)
    for e in final_embeds
])

padded_embeds.shape

torch.Size([2, 293, 1280])

In [None]:
# Run the decoder on the merged embeddings (passed via `inputs_embeds`
# instead of token ids) and inspect the shape of the resulting logits.
logits = gpt2(inputs_embeds = padded_embeds)
logits.shape

torch.Size([2, 293, 50259])

## 4. Vision pipeline

In [170]:
def vision_pipeline(deep_encoder, deep_decoder, input_ids_batch, image_batch, image_token_id=None):
    """Run images and text through the encoder/decoder and return the logits.

    Each sample's single "<image>" placeholder embedding is replaced by the
    vision tokens the encoder produced for that sample; the merged embedding
    sequences are then fed to the decoder through `inputs_embeds`.

    Args:
        deep_encoder: callable mapping `image_batch` to per-sample vision tokens.
        deep_decoder: model exposing `token_embedding` and accepting
            `inputs_embeds=` in its forward call.
        input_ids_batch: token ids, one row per sample, each row containing
            exactly one "<image>" token.
        image_batch: images passed unchanged to `deep_encoder`.
        image_token_id: id of the "<image>" token. When None it is looked up
            with `text_to_token_ids("<image>", tokenizer_modified)`.

    Returns:
        Whatever `deep_decoder` returns for the merged embeddings (the logits).

    Raises:
        ValueError: if a sample does not contain exactly one "<image>" token.
    """
    vision_tokens = deep_encoder(image_batch)
    text_embeds = deep_decoder.token_embedding(input_ids_batch)

    if image_token_id is None:
        image_token_id = text_to_token_ids("<image>", tokenizer_modified)
    image_token_id = torch.as_tensor(image_token_id, device=input_ids_batch.device)

    # Use the ids that were passed in (not a notebook-level batch), and build
    # the mask once because it does not depend on the sample index.
    image_token_mask = (input_ids_batch == image_token_id)

    final_embeds = []
    # Iterate over the samples actually present, so partial batches work too.
    for sample_idx in range(input_ids_batch.shape[0]):
        image_positions = torch.where(image_token_mask[sample_idx])[0]
        if image_positions.numel() != 1:
            raise ValueError(
                f"sample {sample_idx} must contain exactly one <image> token, "
                f"found {image_positions.numel()}"
            )
        img_pos = int(image_positions.item())

        before = text_embeds[sample_idx, :img_pos]
        after = text_embeds[sample_idx, img_pos + 1:]

        # text before <image> + vision tokens + text after <image>
        final_embeds.append(torch.cat((before, vision_tokens[sample_idx], after), dim=0))

    # Right-pad along the sequence axis so the samples can be stacked. With a
    # rectangular id batch and one <image> per sample the lengths already
    # match, so this is normally a no-op. The filler is 0.0 because these are
    # float embeddings, not token ids.
    max_len = max(e.shape[0] for e in final_embeds)
    padded_embeds = torch.stack([
        F.pad(e, (0, 0, 0, max_len - e.shape[0]), value=0.0)
        for e in final_embeds
    ])

    # Use the decoder that was passed in rather than a notebook-level model.
    logits = deep_decoder(inputs_embeds=padded_embeds)

    return logits


# Smoke-test the pipeline on the sample batch and inspect the logits' shape.
logits = vision_pipeline(deep_encoder = deep_encoder, 
                         deep_decoder = gpt2,
                         input_ids_batch = one_batch_input_ids,
                         image_batch = one_batch_images)

logits.shape

torch.Size([2, 293, 50259])

## 5. Loss

In [171]:
def calc_loss_batch(pipline, input_batch, target_batch, image_batch, num_vision_tokens = 273,
                    encoder = None, decoder = None):
    """Cross-entropy loss of one batch, computed on the text positions only.

    Args:
        pipline: callable invoked with the keyword arguments `deep_encoder`,
            `deep_decoder`, `input_ids_batch` and `image_batch`; it returns
            logits whose last axis is the vocabulary. (The misspelled name is
            kept so existing keyword callers keep working.)
        input_batch: token ids forwarded to `pipline`.
        target_batch: target ids, one row per sample. Column 0 (the target
            at the <image> position) is skipped and -100 marks ignored
            entries. Assumes the non-ignored entries of each row form a
            prefix of the remaining columns.
        image_batch: images forwarded to `pipline`.
        num_vision_tokens: number of leading positions occupied by vision
            tokens. Assumes the vision tokens sit at the start of the
            merged sequence.
        encoder: encoder handed to `pipline`; defaults to the notebook-level
            `deep_encoder`.
        decoder: decoder handed to `pipline`; defaults to the notebook-level
            `gpt2`.

    Returns:
        Scalar loss tensor (mean over the non-ignored target positions).
    """
    if encoder is None:
        encoder = deep_encoder
    if decoder is None:
        decoder = gpt2

    logits = pipline(deep_encoder = encoder,
                     deep_decoder = decoder,
                     input_ids_batch = input_batch,
                     image_batch = image_batch)

    n_samples, seq_len, n_classes = logits.shape

    # Targets aligned with the merged sequence: -100 (ignored) everywhere,
    # real token ids only at the text positions. Created on the logits'
    # device so the loss also works off-CPU.
    aligned_targets = torch.full((n_samples, seq_len), -100,
                                 dtype=torch.long, device=logits.device)

    for i in range(n_samples):
        # Skip position 0, the target belonging to the <image> placeholder.
        text_only_targets = target_batch[i, 1:]
        text_len = int((text_only_targets != -100).sum().item())

        # Place the text targets right after the vision-token positions.
        text_start = num_vision_tokens
        aligned_targets[i, text_start:text_start + text_len] = \
            text_only_targets[:text_len].to(logits.device)

    # Flatten (batch, seq) and take the class count from the logits instead
    # of hard-coding the vocabulary size.
    loss = F.cross_entropy(
        logits.reshape(-1, n_classes),
        aligned_targets.reshape(-1),
        ignore_index=-100,
    )

    return loss

# Compute the loss for the sample batch drawn earlier and display it.
one_batch_loss = calc_loss_batch(pipline      = vision_pipeline, 
                                 input_batch  = one_batch_input_ids, 
                                 target_batch = one_batch_target_ids, 
                                 image_batch  = one_batch_images)

one_batch_loss

tensor(8.8949, grad_fn=<NllLossBackward0>)

In [176]:
def calc_loss_loader(dataloader, devive, num_batches = None):
    """Average `calc_loss_batch` over the first batches of a DataLoader.

    Args:
        dataloader: loader yielding dicts with the keys 'input_ids',
            'target_ids' and 'images'.
        devive: device the batch tensors are moved to before the forward
            pass. (The misspelled name is kept so existing keyword callers
            keep working.)
        num_batches: number of batches to average over; None means the whole
            loader. Values larger than the loader are clamped.

    Returns:
        Mean loss as a Python float, or NaN when there is nothing to evaluate.
    """
    if len(dataloader) == 0:
        return float("nan")

    if num_batches is None:
        num_batches = len(dataloader)
    else:
        num_batches = min(num_batches, len(dataloader))

    # Avoid a division by zero when zero batches were requested.
    if num_batches <= 0:
        return float("nan")

    total_loss = 0.0
    # Only a float leaves this function, so no autograd graph is needed.
    with torch.no_grad():
        # zip() stops as soon as the range is exhausted, so no extra batch is
        # loaded just to be thrown away.
        for _, batch in tqdm(zip(range(num_batches), dataloader), total = num_batches):
            loss = calc_loss_batch(pipline      = vision_pipeline,
                                   input_batch  = batch['input_ids'].to(devive),
                                   target_batch = batch['target_ids'].to(devive),
                                   image_batch  = batch['images'].to(devive))
            total_loss += loss.item()

    return total_loss / num_batches

# Average the loss over the first two training batches as a quick check.
calc_loss_loader(train_dl, num_batches = 2, devive = device)

1it [01:34, 94.21s/it]

: 

In [140]:
def generate_text_simple(model, tokens, max_new_tokens, context_size):
    """Greedy decoding: repeatedly append the most likely next token.

    Args:
        model: callable mapping a batch of token ids to logits with one row
            of vocabulary scores per position.
        tokens: starting token ids, one row per sequence.
        max_new_tokens: number of tokens to append to every sequence.
        context_size: maximum number of trailing tokens fed to the model.

    Returns:
        The input ids followed by the `max_new_tokens` generated ids.
    """
    for _ in tqdm(range(max_new_tokens)):
        # Crop only what the model sees. The full history stays in `tokens`,
        # so the start of the sequence is not lost from the result when the
        # context window overflows.
        context = tokens[:, -context_size:]
        logits = model(context)

        # Scores for the last position only.
        last_logits = logits[:, -1, :]

        # softmax is monotonic, so argmax over the raw logits picks the same
        # token without normalising over the whole vocabulary.
        idx_next = torch.argmax(last_logits, dim = -1, keepdim = True)
        tokens = torch.cat((tokens, idx_next), dim = 1)
    return tokens

def generate_and_print_samples(model, tokenizer, device, start_context, max_new_tokens = 50):
    """Generate a continuation of `start_context` and print it on one line.

    The model is put in eval mode for the generation and afterwards returned
    to the mode it was in before the call, even if generation raises.

    Args:
        model: model passed on to `generate_text_simple`.
        tokenizer: tokenizer used to encode the prompt and decode the result.
        device: device the encoded prompt is moved to.
        start_context: prompt text.
        max_new_tokens: number of tokens to generate.
    """
    was_training = model.training
    model.eval()
    try:
        encoded = text_to_token_ids(start_context, tokenizer).to(device)
        with torch.no_grad():
            generated_ids = generate_text_simple(
                model = model,
                tokens = encoded,
                max_new_tokens = max_new_tokens,
                context_size = NEW_CONFIG["context_length"],
            )
        decoded = token_ids_to_text(generated_ids, tokenizer)
        print(decoded.replace("\n", " ")) # compact the output onto one line
    finally:
        # Restore the previous mode instead of unconditionally forcing
        # training mode on the caller's model.
        model.train(was_training)

# Text-only check of the loaded decoder: greedily generate five tokens.
generate_and_print_samples(model = gpt2, 
                           tokenizer = tokenizer_modified,
                           device = device,
                           start_context = "hello i am",
                           max_new_tokens = 5
                           )

100%|██████████| 5/5 [01:42<00:00, 20.48s/it]

hello i am, own his its his



