In [58]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import json


In [59]:
# Define the Encoder CNN
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]  # Remove the last layer
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)  # Match LSTM hidden size

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

# Define the Decoder RNN
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        features = features.unsqueeze(1).repeat(1, captions.size(1), 1)
        lstm_out, _ = self.lstm(embeddings, (features, torch.zeros_like(features).to(features.device)))
        outputs = self.linear(lstm_out)
        return outputs

# Define a custom Dataset
class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, annotations_file, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = pd.read_csv(annotations_file, delimiter=' ', names=['image', 'caption'])
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        caption = self.img_labels.iloc[idx, 1].split()
        tokens = [self.vocab.get(word, self.vocab['<UNK>']) for word in caption]
        tokens = [self.vocab['<START>']] + tokens + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

In [55]:
import torch.nn as nn
from torchvision import models

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]  # Remove the last layer
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)  # Match LSTM hidden size

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.num_layers = num_layers
        self.hidden_size = hidden_size

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        # Ensure features match the hidden size of LSTM
        features = features.unsqueeze(0).repeat(self.num_layers, 1, 1)
        h0 = features  # LSTM hidden state
        c0 = torch.zeros_like(h0)  # LSTM cell state
        outputs, _ = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(outputs)
        return outputs


In [56]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embed_size = 512  # Ensure this is consistent in both encoder and decoder
hidden_size = 512  # Ensure this is the same as embed_size for compatibility

encoder = EncoderCNN(embed_size=embed_size).to(device)
decoder = DecoderRNN(embed_size=embed_size, hidden_size=hidden_size, vocab_size=len(word2idx)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)

num_epochs = 5
for epoch in range(num_epochs):
    for images, captions in dataloader:
        images, captions = images.to(device), captions.to(device)
        features = encoder(images)
        outputs = decoder(features, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, len(word2idx)), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')


Epoch 1, Loss: 0.00019126229744870216
Epoch 2, Loss: 9.411579958396032e-05
Epoch 3, Loss: 5.681885886588134e-05
Epoch 4, Loss: 3.8389567635022104e-05
Epoch 5, Loss: 2.7966583729721606e-05


In [57]:
# Save the model state dictionary
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'captioning_model.pth')
