In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
from timm import create_model

class CaptionModel(nn.Module):
    """
    Combined model using ViT-B/16 for image encoding and GPT-2 for caption generation.

    The ViT's pooled image embedding is linearly projected to GPT-2's embedding
    width and prepended to the caption token embeddings as a single "prefix"
    position; GPT-2 then models the caption conditioned on that prefix.
    """

    # Label value skipped by the LM loss (Hugging Face / CrossEntropyLoss convention).
    _IGNORE_INDEX = -100

    def __init__(self, model_config):
        """
        Build the image encoder, projection and GPT-2 decoder from pretrained weights.

        Args:
            model_config: Accepted for backward compatibility and not otherwise
                read. The projection width is taken from the loaded GPT-2
                config (``n_embd``) so the image prefix always matches GPT-2's
                token-embedding width and the concatenation in ``forward``
                cannot fail on a width mismatch.
        """
        super().__init__()

        # Load GPT-2 first so its embedding width is known, but register the
        # submodules in the original order (encoder, projection, decoder).
        caption_generator = GPT2LMHeadModel.from_pretrained('gpt2')

        # ViT-B/16 with the classifier head removed (num_classes=0): calling it
        # returns one pooled feature vector of size embed_dim per image.
        self.image_encoder = create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.image_proj = nn.Linear(self.image_encoder.embed_dim, caption_generator.config.n_embd)
        self.caption_generator = caption_generator

    def _encode_image(self, image):
        """Return the projected image prefix, shape (batch, 1, gpt2_width)."""
        # The encoder output is already pooled (batch, embed_dim); indexing
        # [:, 0] here would wrongly keep only the first feature of each image.
        pooled = self.image_encoder(image)
        return self.image_proj(pooled).unsqueeze(1)

    @staticmethod
    def _prefix_labels(labels, prefix_len):
        """Prepend `prefix_len` ignored label positions so labels align with the prefixed input."""
        ignored = labels.new_full((labels.size(0), prefix_len), CaptionModel._IGNORE_INDEX)
        return torch.cat([ignored, labels], dim=1)

    @staticmethod
    def _prefix_attention_mask(attention_mask, prefix_len):
        """Prepend `prefix_len` always-visible positions to an attention mask."""
        visible = attention_mask.new_ones((attention_mask.size(0), prefix_len))
        return torch.cat([visible, attention_mask], dim=1)

    @staticmethod
    def _select_next_token(logits, temperature, deterministic):
        """
        Pick the next token id from last-position logits.

        Args:
            logits: (batch, vocab) scores for the next position.
            temperature: Softmax temperature used when sampling; must be > 0
                unless `deterministic` is set (greedy decoding ignores it).
            deterministic: If True, take the argmax instead of sampling.

        Returns:
            (batch, 1) tensor of token ids, ready to append along dim 1.

        Raises:
            ValueError: If sampling is requested with temperature <= 0.
        """
        if deterministic:
            # keepdim=True keeps the (batch, 1) shape that multinomial returns.
            return logits.argmax(dim=-1, keepdim=True)
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")
        probs = (logits / temperature).softmax(dim=-1)
        return torch.multinomial(probs, num_samples=1)

    def _get_tokenizer(self):
        """Return a cached GPT-2 tokenizer, loading it on first use."""
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            from transformers import GPT2Tokenizer
            tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self._tokenizer = tokenizer
        return tokenizer

    def forward(self, image, input_ids, labels=None, attention_mask=None):
        """
        Run GPT-2 on [image prefix] + caption tokens.

        Args:
            image: Batch of images for the ViT encoder.
            input_ids: (batch, T) caption token ids.
            labels: Optional LM targets. If it has the same length as
                `input_ids`, an ignored (-100) label is prepended for the image
                prefix so it lines up with the model input; a tensor that is
                already prefix-aligned (length T + 1) is passed through as-is.
                Padding positions should already be -100.
            attention_mask: Optional (batch, T) mask for `input_ids`; the image
                prefix is always attended to.

        Returns:
            The scalar LM loss when `labels` is given, else logits of shape
            (batch, T + 1, vocab) (the first position belongs to the prefix).
        """
        prefix = self._encode_image(image)
        prefix_len = prefix.size(1)
        token_embeds = self.caption_generator.transformer.wte(input_ids)
        model_input = torch.cat([prefix, token_embeds], dim=1)

        if labels is not None and labels.size(1) == input_ids.size(1):
            labels = self._prefix_labels(labels, prefix_len)
        if attention_mask is not None:
            attention_mask = self._prefix_attention_mask(attention_mask, prefix_len)

        outputs = self.caption_generator(
            inputs_embeds=model_input, attention_mask=attention_mask, labels=labels
        )
        return outputs.loss if labels is not None else outputs.logits

    def _generate_token_ids(self, image, max_tokens, temperature, deterministic):
        """Autoregressively generate token ids (BOS first) for a single image."""
        config = self.caption_generator.config
        wte = self.caption_generator.transformer.wte

        input_ids = torch.tensor([[config.bos_token_id]], device=image.device)
        # First step feeds the image prefix plus BOS; later steps feed only the
        # newest token and reuse the cached keys/values of earlier positions.
        step_embeds = torch.cat([self._encode_image(image), wte(input_ids)], dim=1)
        past_key_values = None

        for _ in range(max_tokens):
            outputs = self.caption_generator(
                inputs_embeds=step_embeds, past_key_values=past_key_values, use_cache=True
            )
            past_key_values = outputs.past_key_values
            next_token = self._select_next_token(outputs.logits[:, -1, :], temperature, deterministic)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if next_token.item() == config.eos_token_id:
                break
            step_embeds = wte(next_token)

        return input_ids

    @torch.no_grad()
    def generate_caption(self, image, max_tokens=50, temperature=1.0, deterministic=False, tokenizer=None):
        """
        Generate a caption string for a single image.

        Args:
            image: A batch containing exactly one image.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (> 0); ignored when `deterministic`.
            deterministic: Use greedy (argmax) decoding instead of sampling.
            tokenizer: Tokenizer used to decode ids; defaults to a cached
                pretrained GPT-2 tokenizer.

        Returns:
            The decoded caption with special tokens removed.

        Raises:
            ValueError: If `image` does not hold exactly one image, or if
                sampling with temperature <= 0.
        """
        if image.size(0) != 1:
            raise ValueError(f"generate_caption expects a batch of one image, got {image.size(0)}")
        if tokenizer is None:
            tokenizer = self._get_tokenizer()

        # Disable dropout for generation, then restore the caller's mode.
        was_training = self.training
        self.eval()
        try:
            token_ids = self._generate_token_ids(image, max_tokens, temperature, deterministic)
        finally:
            self.train(was_training)

        return tokenizer.decode(token_ids[0], skip_special_tokens=True)