In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config

# Define patch embedding module
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=(224, 224), patch_size=(8, 4), embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim))

    def forward(self, x):
        # Convert the image to patches
        x = self.proj(x)  # Shape: (batch_size, embed_dim, num_patches_height, num_patches_width)
        x = x.flatten(2).transpose(1, 2)  # Shape: (batch_size, num_patches, embed_dim)
        x = x + self.position_embeddings  # Add positional encoding
        return x

# Define DTrOCR model class
class DTrOCR(nn.Module):
    def __init__(self, embed_dim=768, max_seq_len=2568):
        super(DTrOCR, self).__init__()
        self.patch_embedding = PatchEmbedding(embed_dim=embed_dim)

        # Load a pre-trained GPT-2 model as the decoder-only Transformer
        config = GPT2Config(vocab_size=50257, n_positions=max_seq_len, n_embd=embed_dim, n_layer=12, n_head=12)
        self.decoder = GPT2Model(config)

        # Additional tokens for OCR
        self.sep_token = torch.tensor([50256])  # Example special token (you may need to adjust)
        self.eos_token = torch.tensor([50257])  # End of sequence token

        # Output layer for generating token probabilities
        self.lm_head = nn.Linear(embed_dim, config.vocab_size, bias=False)

    def forward(self, images, labels=None):
        # Patch embedding
        x = self.patch_embedding(images)  # Shape: (batch_size, num_patches, embed_dim)
        
        # Add [SEP] token to separate image sequence and text sequence
        # sep_token = self.sep_token.repeat(x.size(0), x.size(1), 1)
        # x = torch.cat((x, sep_token), dim=2)

        # Pass through GPT-2 decoder model
        outputs = self.decoder(inputs_embeds=x)
        logits = self.lm_head(outputs.last_hidden_state)
        
        print(logits.shape)

        return logits

    def generate_text(self, images, max_length=50):
        # Start with the patch embeddings for the image and the [SEP] token
        x = self.patch_embedding(images)
        # sep_embed = self.sep_embedding.expand(x.size(0), -1, -1)
        # x = torch.cat((x, sep_embed), dim=1)

        generated_tokens = []

        for i in range(max_length):
            print(i)
            # Get the decoder output logits for the current sequence
            outputs = self.decoder(inputs_embeds=x)
            logits = self.lm_head(outputs.last_hidden_state)

            # Select the last token's logits and get the most likely next token
            next_token_logits = logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            next_token = torch.argmax(next_token_logits, dim=-1)  # Shape: (batch_size,)

            # Append the token to the generated sequence
            generated_tokens.append(next_token)

            # Break if the EOS token is generated
            if next_token.item() == self.eos_token.item():
                break

            # Update `x` by appending the embedding of the next token
            next_token_embed = self.decoder.wte(next_token).unsqueeze(1)  # Embed the token
            x = torch.cat((x, next_token_embed), dim=1)  # Append to the sequence

        # Convert list of tokens to a tensor and return
        return torch.stack(generated_tokens, dim=1)

# Model instantiation
model = DTrOCR()

# Example inputs
images = torch.randn(1, 3, 224, 224)  # Batch of 2 images, each 224x224 RGB
labels = torch.randint(0, 50257, (1, 20))  # Example labels (batch of 2 sequences)

# Generate text
output = model.generate_text(images, max_length=5)

0
1
2
3
4


In [6]:
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.decode(output[0].tolist())

' 1989kee unrestricted routersograph'

In [20]:
from transformers import GPT2Tokenizer, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)

In [27]:
encoded_input.shape

{'input_ids': tensor([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}