In [None]:
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer

class CocoDataset(Dataset):
    """Image/caption pairs read from a COCO-style captions annotation file.

    One item is produced per entry of ``annotations['annotations']``, so an image
    with several captions appears once per caption.

    Args:
        img_folder: Directory containing the files named in ``annotations['images']``.
        ann_file: Path to the JSON annotation file (read as UTF-8).
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        tokenizer: Callable tokenizer accepting Hugging Face-style keyword arguments
            (``return_tensors``, ``padding``, ``max_length``, ``truncation``).
            Required by ``__getitem__``.
        max_length: Token length every caption is padded/truncated to (default 50).
    """

    def __init__(self, img_folder, ann_file, transform=None, tokenizer=None, max_length=50):
        self.img_folder = img_folder
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

        # JSON is UTF-8; being explicit avoids locale-dependent decoding of captions.
        with open(ann_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        # image id -> file name, used to locate the image for each caption.
        self.image_map = {img['id']: img['file_name'] for img in self.annotations['images']}
        self.captions = [
            {'image_id': ann['image_id'], 'caption': ann['caption']}
            for ann in self.annotations['annotations']
        ]

    def __len__(self):
        """Number of captions (not images)."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` for caption ``idx``.

        ``input_ids`` and ``attention_mask`` are 1-D tensors of length ``max_length``.

        Raises:
            TypeError: if the dataset was built without a tokenizer.
        """
        if self.tokenizer is None:
            raise TypeError("CocoDataset needs a tokenizer to encode captions; pass tokenizer=...")

        caption_data = self.captions[idx]
        img_path = os.path.join(self.img_folder, self.image_map[caption_data['image_id']])
        # convert() returns a new image, so the source file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        inputs = self.tokenizer(
            caption_data['caption'],
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
        )
        # squeeze(0) drops only the batch axis added by return_tensors='pt'.
        return image, inputs.input_ids.squeeze(0), inputs.attention_mask.squeeze(0)

# Paths to the image folder and annotation file (placeholders - point these at a local COCO copy)
img_folder = 'path/to/images'
ann_file = 'path/to/annotations.json'

# Image transformations. The captioning model uses torchvision's pretrained ResNet-50,
# whose weights expect inputs normalized with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create the dataset and data loader.
# NOTE: with num_workers > 0, platforms that spawn worker processes (Windows, macOS)
# may fail to unpickle a Dataset class defined inside a notebook; use num_workers=0 there.
dataset = CocoDataset(img_folder, ann_file, transform=transform, tokenizer=tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torchvision.models as models
from transformers import BertModel

class ImageCaptioningModel(nn.Module):
    """Caption model: a ResNet-50 image embedding conditions a Transformer decoder
    whose target sequence is BERT-encoded caption tokens.

    Attention over caption tokens is causal in both BERT and the decoder, so the
    logits at position ``t`` depend only on tokens ``0..t`` (and the image). Train
    them to predict token ``t + 1``, i.e. feed ``captions[:, :-1]`` and score
    against ``captions[:, 1:]``.

    Args:
        vocab_size: Number of output logits per position.
        embed_size: Decoder width (``d_model``); must be divisible by ``nhead``.
        hidden_size: Feed-forward width inside each decoder layer.
        num_layers: Number of stacked decoder layers.
        nhead: Attention heads per decoder layer (default 8).
        dropout: Dropout inside the decoder layers (default 0.1).
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, nhead=8, dropout=0.1):
        super(ImageCaptioningModel, self).__init__()
        resnet = models.resnet50(pretrained=True)
        # Read the backbone feature width from its classifier before discarding it:
        # the last kept child (layer4) is an nn.Sequential with no `in_channels`.
        cnn_out_features = resnet.fc.in_features
        # Drop the last two children (global avgpool and fc); pooling is re-added below.
        self.cnn = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_cnn = nn.Linear(cnn_out_features, embed_size)

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc_bert = nn.Linear(self.bert.config.hidden_size, embed_size)

        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=embed_size, nhead=nhead, dim_feedforward=hidden_size, dropout=dropout
            ),
            num_layers=num_layers,
        )

        self.fc_out = nn.Linear(embed_size, vocab_size)
        # Defined for attribute compatibility; not applied in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, images, captions, attention_mask):
        """Return logits of shape ``(N, T, vocab_size)``.

        Args:
            images: Image batch accepted by the ResNet backbone.
            captions: Token ids, ``(N, T)``.
            attention_mask: ``(N, T)``; nonzero marks real tokens, zero marks padding.
                Position 0 of each row is assumed to be a real token (e.g. ``[CLS]``);
                otherwise that row has no key it may attend to in the decoder.
        """
        # Image features -> a single memory token per image: (1, N, E).
        cnn_features = self.avgpool(self.cnn(images)).flatten(1)
        memory = self.fc_cnn(cnn_features).unsqueeze(0)

        seq_len = captions.size(1)
        # allowed[i, j] is True when query position i may attend to key position j (j <= i).
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=captions.device).tril()
        not_pad = attention_mask.bool()

        # BertModel accepts a 3-D (N, T, T) self-attention mask; combining padding with the
        # causal pattern keeps each token's BERT features from encoding later tokens.
        bert_mask = (not_pad.unsqueeze(1) & allowed.unsqueeze(0)).to(attention_mask.dtype)
        bert_outputs = self.bert(input_ids=captions, attention_mask=bert_mask)
        tgt = self.fc_bert(bert_outputs.last_hidden_state).transpose(0, 1)  # (T, N, E)

        # Boolean masks for nn.TransformerDecoder use True for positions NOT to attend.
        output = self.transformer_decoder(
            tgt,
            memory,
            tgt_mask=~allowed,
            tgt_key_padding_mask=~not_pad,
        )
        return self.fc_out(output.transpose(0, 1))  # (N, T, vocab_size)

# Model dimensions. The output layer spans the tokenizer's whole vocabulary so that
# every caption token id gets a logit.
embed_size, hidden_size, num_layers = 256, 512, 3
vocab_size = tokenizer.vocab_size

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ImageCaptioningModel(
    vocab_size=vocab_size,
    embed_size=embed_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model.to(device)

# Padding tokens are excluded from the loss through ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Training loop
def train_model(dataloader, model, criterion, optimizer, num_epochs=10, device=None):
    """Train ``model`` with teacher forcing for ``num_epochs`` epochs.

    Each batch is an ``(images, captions, attention_mask)`` triple. The model is fed
    caption tokens ``[:, :-1]`` and scored on tokens ``[:, 1:]``, so the logits at
    position ``t`` are trained to predict the next token rather than copy the input.

    Args:
        dataloader: Iterable of batches; must support ``len()`` for the epoch average.
        model: Module called as ``model(images, captions, attention_mask)`` that
            returns logits whose last dimension is the vocabulary.
        criterion: Loss called as ``criterion(logits_2d, targets_1d)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device batches are moved to; defaults to the device of the model's
            first parameter (CPU if the model has none).
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for images, captions, attention_mask in tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
            images = images.to(device)
            captions = captions.to(device)
            attention_mask = attention_mask.to(device)

            # Teacher forcing: feed tokens [:-1], predict the following tokens [1:].
            decoder_input = captions[:, :-1]
            decoder_mask = attention_mask[:, :-1]
            targets = captions[:, 1:]

            optimizer.zero_grad()
            outputs = model(images, decoder_input, decoder_mask)
            # reshape (not view): sliced/permuted tensors are not guaranteed contiguous.
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # max(1, ...) avoids ZeroDivisionError on an empty dataloader.
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss / max(1, len(dataloader)):.4f}')

# Launch training on the configured loader, model, loss and optimizer for 10 epochs.
train_model(
    dataloader=dataloader,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=10,
)
