In [65]:
from IPython.utils.tokenutil import generate_tokens_catch_errors
from transformers import GPT2Tokenizer, GPT2Model
import torch

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
id = tokenizer.convert_tokens_to_ids('[SEP]')
model = GPT2Model.from_pretrained('gpt2')
sep_emb = model.wte(torch.tensor(id).unsqueeze(0))
x = torch.rand(2, 196, 768)
sep_emb = sep_emb.expand(x.size(0), -1, -1)
sep_emb.shape

torch.Size([2, 1, 768])

In [26]:
x = torch.cat((x, sep_emb), dim=1)
x.shape

torch.Size([2, 197, 768])

In [29]:
outputs = model(inputs_embeds=x)
outputs.last_hidden_state.shape

torch.Size([2, 197, 768])

In [31]:
import torch.nn as nn

# Output layer for generating token probabilities
lm_head = nn.Linear(768, 50257, bias=False)

logits = lm_head(outputs.last_hidden_state)
logits[0, -1, :10]

tensor([-0.1123, -1.1812, -0.7510,  0.2812, -0.9772, -0.0641, -1.6739, -0.0476,
        -0.0111, -1.2077], grad_fn=<SliceBackward0>)

In [81]:

# Select the last token's logits and get the most likely next token
next_token_logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
next_token_logits.shape

torch.Size([2, 50257])

In [61]:

next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
next_token.shape

torch.Size([2, 1])

In [85]:
next_token_logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
next_token_logits = next_token_logits.unsqueeze(1)

generated_tokens = None
for i in range(3):
    if generated_tokens is None:
        generated_tokens = next_token_logits
    else:
        generated_tokens = torch.hstack((generated_tokens, next_token_logits))
generated_tokens.shape

torch.Size([2, 3, 50257])

In [66]:
import numpy as np
from torchvision import transforms
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

ds = load_dataset("alpayariyak/IAM_Sentences")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

ds = ds["train"].select(range(100))
ds = ds.map(lambda x: {"image": transform(x["image"]), "text": x["text"]})

ds

Dataset({
    features: ['image', 'text'],
    num_rows: 100
})

In [67]:
import torch
from torch.utils.data import Dataset, DataLoader

ds2 = load_dataset("alpayariyak/IAM_Sentences", split="train")

# Custom dataset class
class OCRDataset(Dataset):
    def __init__(self, dataset, transform, tokenizer):
        self.dataset = dataset
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]["image"]
        if isinstance(image, list):
            image = torch.tensor(image)
        image = self.transform(image)
        text = self.dataset[idx]["text"]
        labels = self.tokenizer(text, return_tensors="pt", padding="max_length", max_length=3, truncation=True)["input_ids"].squeeze(0)
        return image, labels

# Create dataset and dataloader
ocr_dataset = OCRDataset(ds2, transform, tokenizer)
dataloader = DataLoader(ocr_dataset, batch_size=2, shuffle=True)

# Example batch
images, labels = next(iter(dataloader))
images.shape, labels.shape

(torch.Size([2, 3, 224, 224]), torch.Size([2, 3]))

In [None]:
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'sep_token': '[SEP]'})
tokenizer.pad_token = tokenizer.eos_token
id = tokenizer.convert_tokens_to_ids('[SEP]')

torch.mps.empty_cache()

# Define patch embedding module
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.patch_dim = patch_size[0] * patch_size[1] * 3  # 3 for RGB channels
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))
        self.proj = nn.Linear(self.patch_dim, embed_dim)

    def forward(self, x):
        batch_size, channels, height, width = x.shape
        patches = x.unfold(2, self.patch_size[0], self.patch_size[0]).unfold(3, self.patch_size[1], self.patch_size[1])
        patches = patches.contiguous().view(batch_size, channels, -1, self.patch_size[0] * self.patch_size[1])
        patches = patches.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.patch_dim)
        x = self.proj(patches)  # Project patches to embedding dimension
        x = x + self.position_embeddings  # Add positional encoding
        return x

# Define DTrOCR model class
class DTrOCR(nn.Module):
    def __init__(self, embed_dim=768, max_seq_len=2568):
        super(DTrOCR, self).__init__()
        self.patch_embedding = PatchEmbedding(embed_dim=embed_dim).to("mps")

        # Load a pre-trained GPT-2 model as the decoder-only Transformer
        config = GPT2Config(vocab_size=50257, n_positions=max_seq_len, n_embd=embed_dim, n_layer=12, n_head=12)
        self.decoder = GPT2Model(config).to("mps")

        # Define special tokens
        self.sep_token = torch.tensor(tokenizer.convert_tokens_to_ids('[SEP]')).to("mps")
        self.eos_token = torch.tensor(tokenizer.convert_tokens_to_ids('[EOS]')).to("mps")

        # Output layer for generating token probabilities
        self.lm_head = nn.Linear(embed_dim, config.vocab_size, bias=False).to("mps")
        
    def forward(self, images, max_length=50):
        # Start with the patch embeddings for the image and the [SEP] token
        x = self.patch_embedding(images.to("mps"))
    
        # Append the [SEP] token embedding to the sequence
        sep_token_embed = self.decoder.wte(self.sep_token).unsqueeze(0).unsqueeze(0)
        sep_token_embed = sep_token_embed.expand(x.size(0), -1, -1)  # Expand for batch size
        x = torch.cat((x, sep_token_embed), dim=1)
    
        generated_tokens = None
    
        for i in range(max_length):
            print(f'Iteration {i} of {max_length}. Input shape: {x.shape}')
            # Get the decoder output logits for the current sequence
            outputs = self.decoder(inputs_embeds=x)
            logits = self.lm_head(outputs.last_hidden_state)
    
            # Select the last token's logits and get the most likely next token
            next_token_logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            
            
            next_token = torch.argmax(next_token_logits, dim=-1)  # Shape: (batch_size,)
            next_token_logits = next_token_logits.unsqueeze(1)
            
            
            if generated_tokens is None:
                generated_tokens = next_token_logits.to("mps")
            else:
                generated_tokens = torch.hstack((generated_tokens, next_token_logits))


            # Update `x` by appending the embedding of the next token
            next_token_embed = self.decoder.wte(next_token).unsqueeze(1)  # Embed the token
            x = torch.cat((x, next_token_embed), dim=1)  # Append to the sequence
    
            # Debugging: Print shapes and types
            print(f'Next token: {next_token}')
            print(f'Next token embed shape: {next_token_embed.shape}')
            print(f'Updated input shape: {x.shape}')
    
        return generated_tokens

def train(model, dataloader, epochs=5, lr=1e-4):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for images, labels in dataloader:
            images, labels = images.to("mps"), labels.to("mps")

            optimizer.zero_grad()
            outputs = model(images, max_length=3)
            outputs = outputs.permute(0, 2, 1)
            print("Output shape, labels shape:", outputs.shape, labels.shape)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# Instantiate and train the model
model = DTrOCR()
# test = torch.rand(2, 3, 224, 224)
# model(test, max_length=3)
train(model, dataloader)

Iteration 0 of 3. Input shape: torch.Size([2, 197, 768])
Next token: tensor([10385, 35160], device='mps:0')
Next token embed shape: torch.Size([2, 1, 768])
Updated input shape: torch.Size([2, 198, 768])
Iteration 1 of 3. Input shape: torch.Size([2, 198, 768])
Next token: tensor([42574, 31095], device='mps:0')
Next token embed shape: torch.Size([2, 1, 768])
Updated input shape: torch.Size([2, 199, 768])
Iteration 2 of 3. Input shape: torch.Size([2, 199, 768])
Next token: tensor([ 9952, 26505], device='mps:0')
Next token embed shape: torch.Size([2, 1, 768])
Updated input shape: torch.Size([2, 200, 768])
Output shape, labels shape: torch.Size([2, 50257, 3]) torch.Size([2, 3])
Epoch 1, Loss: 11.268261909484863
Iteration 0 of 3. Input shape: torch.Size([2, 197, 768])
Next token: tensor([340, 340], device='mps:0')
Next token embed shape: torch.Size([2, 1, 768])
Updated input shape: torch.Size([2, 198, 768])
Iteration 1 of 3. Input shape: torch.Size([2, 198, 768])
Next token: tensor([340, 340

In [9]:
torch.mps.empty_cache()

In [None]:
# # Model instantiation
# model = DTrOCR().to("mps")
# 
# # Example inputs
# images = torch.tensor(ds["image"][0]).unsqueeze(0).to("mps")
# labels = tokenizer(ds["text"][0], return_tensors="pt")["input_ids"].to("mps")
# labels = torch.tensor(labels)
# 
# # Generate text
# output = model.forward(images, max_length=3)
# tokenizer.decode(output[0].tolist())

In [8]:
r1 = np.array([1, 1])
r2 = np.array([2, 2])
r3 = np.array([3, 3])

torch.stack([r1, r2, r3], dim=0)

TypeError: expected Tensor as element 0 in argument 0, but got numpy.ndarray

In [16]:
import torch

r1 = torch.tensor([1, 1, 1])
r2 = torch.tensor([2, 2, 2])
r3 = torch.tensor([3, 3, 3])

r11 = torch.tensor([[1, 1, 1], [1, 1, 1]])
r22 = torch.tensor([[2, 2, 2], [2, 2, 2]])
r33 = torch.tensor([[3, 3, 3], [3, 3, 3]])

torch.stack([r11, r22, r33], dim=0)

tensor([[[1, 1, 1],
         [1, 1, 1]],

        [[2, 2, 2],
         [2, 2, 2]],

        [[3, 3, 3],
         [3, 3, 3]]])