# Setup

In [5]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
from collections import defaultdict
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


# Defining The Architectures

In [None]:
class BahdanauAttention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    For every encoder position, the projected encoder feature and the
    projected decoder hidden state are summed, passed through a ReLU and
    scored by a single linear unit. The scores are softmax-normalised over
    positions and used to take a weighted sum of the encoder features.

    Note: the classic Bahdanau formulation uses tanh here; this module uses
    ReLU, and changing it would alter the behaviour of trained weights.

    Args:
        encoder_dim: Size of each encoder feature vector.
        decoder_dim: Size of the decoder hidden state.
        attention_dim: Size of the shared projection space.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # These attribute names form the state_dict keys; keep them stable so
        # saved checkpoints continue to load.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Compute the attention context for one decoding step.

        Args:
            encoder_out: Tensor of shape (batch, num_positions, encoder_dim).
            decoder_hidden: Tensor of shape (batch, decoder_dim).

        Returns:
            A tuple ``(context, alpha)``. ``context`` has shape
            (batch, encoder_dim) and is the alpha-weighted sum of the encoder
            features. ``alpha`` has shape (batch, num_positions), with each row
            summing to 1.
        """
        projected_features = self.encoder_att(encoder_out)
        # Broadcast the decoder projection across every encoder position.
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)
        energies = self.full_att(self.relu(projected_features + projected_state))
        weights = self.softmax(energies.squeeze(2))
        weighted_features = encoder_out * weights.unsqueeze(2)
        return weighted_features.sum(dim=1), weights

class CNNEncoder(nn.Module):
    """VGG-style convolutional encoder producing a grid of 512-d feature vectors.

    Five convolutional blocks each end in a 2x2 max-pool, so a 224x224 input
    reaches a 7x7 grid. That grid is then resized with adaptive average
    pooling to ``encoded_image_size`` x ``encoded_image_size``. Finally it is
    flattened into a sequence with one feature vector per spatial position.

    Args:
        encoded_image_size: Side length of the output feature grid. For the
            default 7 and 224x224 inputs the adaptive pooling is an identity.
    """

    def __init__(self, encoded_image_size=7):
        super().__init__()
        self.enc_image_size = encoded_image_size

        # Layer order inside each nn.Sequential determines the state_dict keys
        # (e.g. ``block1.4.running_mean``). _conv_block reproduces the explicit
        # Conv/BN/ReLU.../MaxPool layout exactly, so saved checkpoints load.
        self.block1 = self._conv_block(3, 64, num_convs=2)
        self.block2 = self._conv_block(64, 128, num_convs=2)
        self.block3 = self._conv_block(128, 256, num_convs=2)
        self.block4 = self._conv_block(256, 512, num_convs=2)
        self.block5 = self._conv_block(512, 512, num_convs=1)

        # Enforces the requested output grid size. It has no parameters or
        # buffers, so it adds no state_dict keys.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(
            (encoded_image_size, encoded_image_size)
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, num_convs):
        """Build ``num_convs`` x (Conv3x3 -> BatchNorm -> ReLU), then a 2x2 max-pool."""
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            channels = out_channels
        layers.append(nn.MaxPool2d(2, 2))
        return nn.Sequential(*layers)

    def forward(self, images):
        """Encode a batch of images into a sequence of feature vectors.

        Args:
            images: Tensor of shape (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, encoded_image_size ** 2, 512).
        """
        x = images
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = block(x)
        x = self.adaptive_pool(x)    # (batch, 512, S, S)
        x = x.permute(0, 2, 3, 1)    # (batch, S, S, 512): channels last
        return x.flatten(1, 2)       # (batch, S * S, 512)

class LSTMDecoder(nn.Module):
    """Sub-modules of an attention-based LSTM caption decoder.

    This class defines no ``forward``. Callers drive the decoding loop by
    invoking ``attention``, ``embedding``, ``lstm_cell`` and ``fc`` directly.
    ``dropout_layer`` is constructed here but is not used by any code in this
    class.

    Args:
        attention_dim: Projection size used inside the attention module.
        embed_dim: Word-embedding size.
        decoder_dim: LSTM hidden/cell state size.
        vocab_size: Number of tokens in the vocabulary.
        encoder_dim: Size of each encoder feature vector.
        dropout: Drop probability for ``dropout_layer``.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size,
                 encoder_dim=512, dropout=0.5):
        super().__init__()

        # Plain hyper-parameters, kept for introspection.
        self.encoder_dim, self.attention_dim = encoder_dim, attention_dim
        self.embed_dim, self.decoder_dim = embed_dim, decoder_dim
        self.vocab_size, self.dropout = vocab_size, dropout

        # Sub-modules. Registration order and names are kept stable because
        # they define the state_dict keys of saved checkpoints.
        self.attention = BahdanauAttention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(p=dropout)
        # The LSTM input is the word embedding concatenated with the attention context.
        self.lstm_cell = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim)
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        self.fc = nn.Linear(decoder_dim, vocab_size)

    def init_hidden_state(self, encoder_out):
        """Derive the initial LSTM ``(h, c)`` from the mean encoder feature.

        Args:
            encoder_out: Encoder features; averaged over dim 1 (presumably
                the spatial-position axis, (batch, num_positions, encoder_dim)).

        Returns:
            A tuple ``(h, c)`` produced by the ``init_h`` / ``init_c`` projections.
        """
        pooled = encoder_out.mean(dim=1)
        return self.init_h(pooled), self.init_c(pooled)

class ImageCaptioningModel(nn.Module):
    """Container pairing a ``CNNEncoder`` with an ``LSTMDecoder``.

    Defines no ``forward``; callers use ``.encoder`` and ``.decoder``
    directly. The attribute names ``encoder`` / ``decoder`` prefix every
    state_dict key, so they must not change.

    NOTE(review): ``encoder_dim`` is forwarded only to the decoder. The
    encoder is always built with its defaults and its last conv layer emits
    512 channels, so values other than 512 will presumably mismatch at
    runtime. Confirm before changing it.
    """

    def __init__(self, vocab_size, attention_dim=512, embed_dim=256,
                 decoder_dim=512, encoder_dim=512, dropout=0.5):
        super().__init__()

        decoder_config = dict(
            attention_dim=attention_dim,
            embed_dim=embed_dim,
            decoder_dim=decoder_dim,
            vocab_size=vocab_size,
            encoder_dim=encoder_dim,
            dropout=dropout,
        )
        # The encoder is constructed first, matching the original
        # parameter-initialisation order.
        self.encoder = CNNEncoder()
        self.decoder = LSTMDecoder(**decoder_config)

# Loading The Model

In [None]:
def load_model(checkpoint_path='best_model.pth'):
    """Load a trained ``ImageCaptioningModel`` and its vocabulary from a checkpoint.

    Expects the checkpoint to be a dict containing ``'vocab'`` and
    ``'model_state_dict'``. The ``'epoch'`` and ``'val_loss'`` entries are
    optional and only reported. The model is built with fixed
    hyper-parameters, moved to the module-level ``device`` and put in eval mode.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Returns:
        A tuple ``(model, vocab)``.

    Raises:
        KeyError: If ``'vocab'`` or ``'model_state_dict'`` is missing.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects (presumably needed for the custom vocab object). Unpickling can
    # execute arbitrary code, so only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    vocab = checkpoint['vocab']

    # These hyper-parameters must match the ones used at training time, or
    # load_state_dict below will fail on shape mismatches.
    model = ImageCaptioningModel(
        vocab_size=len(vocab),
        attention_dim=512,
        embed_dim=256,
        decoder_dim=512,
        encoder_dim=512,
        dropout=0.5
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    print("Model loaded successfully")
    # Training metadata is informational only; a checkpoint lacking it should
    # not fail after the weights have already loaded.
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        print(f"Trained for {epoch} epochs")
    val_loss = checkpoint.get('val_loss')
    if val_loss is not None:
        print(f"Best validation loss: {val_loss:.4f}")
    print(f"Vocabulary size: {len(vocab)}")

    return model, vocab

# Loading and Preprocessing Images

In [None]:
def load_and_preprocess_image(image_path: str, image_size: int = 224) -> torch.Tensor:
    """Load an image file and turn it into a normalised single-image batch.

    The image is converted to RGB and resized to a square ``image_size``.
    It is then converted to a float tensor and normalised per channel with
    the mean/std constants below (the commonly used ImageNet statistics).

    Args:
        image_path: Path to an image file readable by PIL.
        image_size: Side length of the resized square image. The default
            matches the previous hard-coded 224.

    Returns:
        Tensor of shape (1, 3, image_size, image_size).
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # PIL opens files lazily and may keep the handle open (e.g. multi-frame
    # formats). The context manager releases it deterministically.
    # convert() returns an independent, fully loaded copy, so using it after
    # the file is closed is safe.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')

    return transform(rgb_image).unsqueeze(0)

# Caption Generation

In [None]:
def generate_caption(image_path: str, model: ImageCaptioningModel,
                     vocab, max_length: int = 20) -> str:
    """Generate a caption for a single image using greedy decoding.

    At each step the decoder attends over the encoder features, feeds the
    previous word's embedding plus the attention context into the LSTM cell,
    and picks the highest-scoring word. Decoding stops at ``<END>`` or after
    ``max_length`` steps. ``<START>``, ``<PAD>`` and ``<UNK>`` predictions are
    still fed back to the decoder but are omitted from the returned text.

    Args:
        image_path: Path to the image file.
        model: Trained captioning model.
        vocab: Vocabulary object exposing ``word2idx`` and ``idx2word``.
        max_length: Maximum number of decoding steps.

    Returns:
        The generated caption as space-separated words (possibly empty).
    """
    # Run on whatever device the model's weights live on, so a model not
    # moved to the module-level ``device`` still works.
    model_device = next(model.parameters()).device
    image = load_and_preprocess_image(image_path).to(model_device)

    # Tokens that may be predicted but should not appear in the caption.
    skip_tokens = {'<START>', '<PAD>', '<UNK>'}
    generated_words = []

    # In train mode the encoder's BatchNorm would normalise with this single
    # image's statistics and overwrite its running stats. Switch to eval for
    # decoding and restore the caller's mode afterwards.
    was_training = model.training
    if was_training:
        model.eval()
    try:
        # Everything, including the hidden-state initialisation, runs without
        # autograd so no graph is built.
        with torch.no_grad():
            encoder_out = model.encoder(image)
            h, c = model.decoder.init_hidden_state(encoder_out)

            current_word = torch.tensor([vocab.word2idx['<START>']], device=model_device)

            for _ in range(max_length):
                word_emb = model.decoder.embedding(current_word)
                context, _alpha = model.decoder.attention(encoder_out, h)

                lstm_input = torch.cat([word_emb, context], dim=1)
                h, c = model.decoder.lstm_cell(lstm_input, (h, c))

                logits = model.decoder.fc(h)  # (1, vocab_size)
                # Keep the prediction as a device tensor so it can be fed
                # straight back as the next input.
                current_word = logits.argmax(dim=1)
                predicted_word = vocab.idx2word[current_word.item()]

                if predicted_word == '<END>':
                    break
                if predicted_word not in skip_tokens:
                    generated_words.append(predicted_word)
    finally:
        if was_training:
            model.train()

    return ' '.join(generated_words)