In [7]:
import tiktoken
import torch
import torch.nn.functional as F
from helper import *
from model import *
from knowledge_transfer import *

tokenizer = tiktoken.get_encoding("gpt2")
tokenizer.encode("hello")

[31373]

In [8]:
special_tokens = {"<image>": tokenizer.n_vocab+1}
tokenizer_modified = tiktoken.Encoding(
    name="gpt2_with_image",
    pat_str=tokenizer._pat_str,
    mergeable_ranks=tokenizer._mergeable_ranks,
    special_tokens={**tokenizer._special_tokens, **special_tokens}
)

In [9]:
def text_to_token_ids(texts, tokenizer, device="cpu", max_len = None):
    # return torch.tensor(tokenizer.encode(text, allowed_special="<|endoftext|>")).unsqueeze(0)
    if type(texts) == list:
        encodings = []
        for text in texts:
            token_ids = torch.tensor(
                        tokenizer.encode(
                                text,
                                allowed_special={"<|endoftext|>", "<image>"}
                            ),
                            
                    device=device).unsqueeze(0)
            encodings.append(token_ids)

        if max_len == None:
            max_len = max(e.numel() for e in encodings)
        # import pdb;
        # pdb.set_trace()
        encodings_cat = torch.cat([
            F.pad(e, (0, max_len - e.numel()), value=50256)
            for e in encodings
        ], dim=0)


        return encodings_cat
    
    else:
        return torch.tensor(
                        tokenizer.encode(
                                texts,
                                allowed_special={"<|endoftext|>", "<image>"}
                            ),
                    device=device).unsqueeze(0)
        

def token_ids_to_text(token_ids, tokenizer):
    flat = token_ids.squeeze(0).cpu()
    return tokenizer.decode(flat.tolist())
    
encoded = text_to_token_ids("hello hi __hi h...", tokenizer)
token_ids_to_text(encoded, tokenizer)

'hello hi __hi h...'

In [10]:
vocab_size = tokenizer_modified.n_vocab
vocab_size

50259

In [29]:
class GPTModel(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.token_embedding    = torch.nn.Embedding(cfg["vocab_size"], cfg["embedding_dim"])
        self.position_embedding = torch.nn.Embedding(cfg["context_length"], cfg["embedding_dim"])
        self.drop_emb = torch.nn.Dropout(cfg["drop_rate"])

        self.transformer_blocks = torch.nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = LayerNorm(cfg["embedding_dim"])
        self.out_head   = torch.nn.Linear(cfg["embedding_dim"], cfg["vocab_size"], bias=False)

        self.proj = torch.nn.Linear(cfg["vision_dim"], cfg["embedding_dim"])

    def forward(self, in_idx=None, inputs_embeds=None):  # CHANGED: Both optional, explicit parameter
        # CHANGED: Handle both text-only and multimodal paths
        if inputs_embeds is not None:
            # Multimodal path: use pre-computed embeddings
            toks_embeds = inputs_embeds
            batch_size, seq_length, _ = toks_embeds.shape  # CHANGED: Get dimensions from embeddings
        else:
            # Text-only path: convert token indices to embeddings
            if in_idx is None:
                raise ValueError("Must provide either in_idx or inputs_embeds")
            batch_size, seq_length = in_idx.shape
            toks_embeds = self.token_embedding(in_idx)
        
        # CHANGED: Use toks_embeds.device (works for both paths)
        pos_embeds = self.position_embedding(torch.arange(0, seq_length, device=toks_embeds.device))

        x = self.proj(toks_embeds) + pos_embeds
        x = self.drop_emb(x)
        x = self.transformer_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)

        return logits
    
settings, params = download_and_load_gpt2(model_size="774M", models_dir="gpt2")

GPT_CONFIG_124M = {
    "vocab_size"     : tokenizer.n_vocab,     # 50257
    "context_length" : 1024,                  # The maximum number of tokens the model can process at once
    "embedding_dim"  : 768,                   # The number of features used to represent each token 
    "n_heads"        : 12,
    "n_layers"       : 12,                    # How many transformer blocks
    "drop_rate"      : 0.1,
    "qkv_bias"       : False
}

model_configs = {
    "gpt2-small (124M)": {"embedding_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"embedding_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"embedding_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"embedding_dim": 1600, "n_layers": 48, "n_heads": 25},
}

model_name = "gpt2-large (774M)"

NEW_CONFIG = GPT_CONFIG_124M.copy()
NEW_CONFIG.update(model_configs[model_name])
NEW_CONFIG.update({"context_length": 1024, 
                   "qkv_bias": True, 
                   "vocab_size": tokenizer_modified.n_vocab,
                   "vision_dim": 1280})

gpt2 = GPTModel(NEW_CONFIG)
device = "cpu"
load_weights_into_gpt_modified(gpt2, params)


File already exists and is up-to-date: gpt2/774M/checkpoint
File already exists and is up-to-date: gpt2/774M/encoder.json
File already exists and is up-to-date: gpt2/774M/hparams.json
File already exists and is up-to-date: gpt2/774M/model.ckpt.data-00000-of-00001


model.ckpt.index: 100%|██████████| 15.5k/15.5k [00:00<00:00, 3.94MiB/s]
model.ckpt.meta: 100%|██████████| 1.38M/1.38M [00:01<00:00, 792kiB/s] 
vocab.bpe: 100%|██████████| 456k/456k [00:01<00:00, 283kiB/s]  


In [39]:
gpt2.token_embedding

Embedding(50259, 1280)

In [31]:
sample = torch.rand(2, 273, 1280) # we get this form SAM and CLIP fusion

batch_size = sample.shape[0]
texs = ["Extract <image> all text from this document.", "hello <image>"] 
input_ids = text_to_token_ids(texs, tokenizer_modified)
text_embeds = gpt2.token_embedding(input_ids)
text_embeds.shape

torch.Size([2, 10, 1280])

In [None]:
image_token_id = text_to_token_ids("<image>", tokenizer_modified)

final_embeds = []
for batch in range(batch_size):
    image_token_mask = (image_token_id == input_ids)
    image_positions = torch.where(image_token_mask[batch])[0]
    img_pos = image_positions.squeeze().item()

    before = text_embeds[batch, :img_pos]
    after = text_embeds[batch, img_pos+1:]

    merged = torch.cat((before, sample[batch], after), dim = 0)
    final_embeds.append(merged)

# max_len = max(e.shape[0] for e in final_embeds)
# max_len = tokenizer_modified.n_vocab
# max_len = min(max(e.shape[0] for e in final_embeds), 1024)
max_len = max(e.shape[0] for e in final_embeds)
padded_embeds = torch.stack([
    F.pad(e, (0, 0, 0, max_len - e.shape[0]), value=50256)
    for e in final_embeds
])

padded_embeds.shape

torch.Size([2, 282, 1280])

In [None]:
# Tokenize text
text = "Hello, how are you?"
input_ids = text_to_token_ids(text, tokenizer_modified)  # [1, seq_len]

# Forward through model
logits = gpt2(in_idx = input_ids)  # [1, seq_len, vocab_size]

# Get predictions
predictions = torch.argmax(logits, dim=-1)
decoded = tokenizer_modified.decode(predictions[0].tolist())
token_ids_to_text(predictions, tokenizer_modified)


'. (.ed.\n'

In [None]:
logits = gpt2(inputs_embeds = padded_embeds)  # [2, max_len, vocab_size]

# Get predictions
predictions = torch.argmax(logits, dim=-1)  # [2, max_len]
for i in range(batch_size):
    decoded = tokenizer_modified.decode(predictions[i].tolist())
    print(f"Output {i}: {decoded}")

Output 0: ...
.



iller killer
aurushunter hunters


 er
ilerser
hersererschers:urs
ersvier
urs
 players players.zers
hersursuersursers


ererilyursers
ursh playerser
herchers
urs
chers
ters. tourszers
eers playersurschers
.erszers

 theues to to players er playershersurs playersigators playershersilers playersherstershers
 playershers
lerschers
 playersilers
ilers's to
chersilers
 players
uershershers players
.
lers (hersers

 players
hers
hers playersuersuersilersers
.hers
hers
hershers

hershersers thehershershers

hers
hers.ers



hers

 to

ers








aters

ers


ers

tersater


vers


ots

vers
ters


ats

ounds
'sols
pers..
 stones.zer


 areersses



 stones
 players players players to<|endoftext|> playersters<|endoftext|>oul players
erszer stoneszer players ... players players playerzer player players playersplayerhunter
 playerschers players player players players
o playeres. youers
Output 1: ..in
.

.
.
.
























ker































ite
ome


In [54]:
# Debug: Check the sequence length
print(f"Padded embeddings shape: {padded_embeds.shape}")
print(f"Max position embeddings: {gpt2.position_embedding.weight.shape[0]}")

# If padded_embeds is [2, 1500, 768] but position_embedding only supports 1024
# You'll get: IndexError: index out of range


Padded embeddings shape: torch.Size([2, 282, 1280])
Max position embeddings: 1024


In [None]:
model_out = gpt2(inputs_embeds = padded_embeds)
model_out.shape

In [None]:
logits[:, -1, :]

torch.Size([2, 50259])

'ers a'