This notebook is based on work by my professor, Mrs. Swati Jain (Google Scholar profile: https://scholar.google.com/citations?user=aU8LyHYAAAAJ&hl=en).

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
import os

In [None]:
# Download the NLTK tokenizer data used by word_tokenize.
nltk.download('punkt')
# Recent NLTK releases load the tokenizer from 'punkt_tab' instead of 'punkt';
# fetching both keeps word_tokenize working regardless of the installed version.
nltk.download('punkt_tab')

# Parameters
image_size = 224        # images are resized to image_size x image_size
embedding_dim = 256     # size of the image feature / word embedding vectors
hidden_dim = 512        # LSTM hidden state size
freq_threshold = 5      # minimum token count for a word to enter the vocabulary
batch_size = 32
num_epochs = 1
learning_rate = 0.001

# Single source of truth for the compute device (previously it was computed
# twice, with the second assignment silently overriding the first).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image transformations: resize, convert to tensor, then normalise each channel
# with the mean/std constants below.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
# Vocabulary class to build word-to-index and index-to-word mappings
class Vocabulary:
    """Bidirectional word <-> index mapping with a minimum-frequency cut-off.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``.

    Args:
        freq_threshold: minimum number of occurrences (across the sentences
            passed to ``build_vocabulary``) for a token to get its own index.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def __len__(self):
        """Return the number of entries in the vocabulary (special tokens included)."""
        return len(self.itos)

    def build_vocabulary(self, sentence_list):
        """Add every token that occurs at least ``freq_threshold`` times.

        Sentences are lower-cased and tokenised with ``word_tokenize``. New
        tokens are numbered in order of first appearance, continuing from the
        current vocabulary size, so calling this method again never overwrites
        an existing entry.

        Args:
            sentence_list: iterable of caption strings.
        """
        # Count everything first; scanning the counter once afterwards avoids
        # re-walking the whole frequency table after every single sentence.
        frequencies = Counter()
        for sentence in sentence_list:
            frequencies.update(word_tokenize(sentence.lower()))

        # Continue after the highest index in use instead of a hard-coded 4.
        idx = len(self.itos)
        for token, freq in frequencies.items():
            if freq >= self.freq_threshold and token not in self.stoi:
                self.stoi[token] = idx
                self.itos[idx] = token
                idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token indices.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        """
        tokens = word_tokenize(text.lower())
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokens]



In [None]:
# Building the vocabulary

captions_file = '/kaggle/input/flickr8k/captions.txt'
img_folder = '/kaggle/input/flickr8k/Images'
captions_list = []

# Read all captions and build vocabulary.
# Each line is "<image name>,<caption>". Only the FIRST comma separates the two
# fields, so captions that themselves contain commas are kept whole.
with open(captions_file, 'r') as file:
    for line_no, line in enumerate(file):
        line = line.strip()
        image_name, sep, caption = line.partition(',')
        if not sep:
            # Blank line or line without a separator: nothing to read.
            continue
        if line_no == 0 and (image_name, caption) == ('image', 'caption'):
            # CSV header row, not a real caption.
            continue
        captions_list.append(caption)

vocab = Vocabulary(freq_threshold)
vocab.build_vocabulary(captions_list)

# Dataset class for Flickr8k


In [None]:
class Flickr8kDataset(Dataset):
    """Image/caption pairs read from an image folder and a captions file.

    One item is produced per image; only the first caption listed for an
    image is used.

    Args:
        img_folder: directory containing the image files.
        captions_file: text file with one ``<image name>,<caption>`` pair per line.
        transform: optional callable applied to the loaded RGB image.
        vocab: object exposing ``stoi`` and ``numericalize`` (used in ``__getitem__``).
    """

    def __init__(self, img_folder, captions_file, transform=None, vocab=None):
        self.img_folder = img_folder
        self.transform = transform
        self.vocab = vocab
        self.captions = self.load_captions(captions_file)
        # Fixed index -> image id lookup, built once instead of on every
        # __getitem__ call.
        self.img_ids = list(self.captions)

    def load_captions(self, captions_file):
        """Return a dict mapping image id -> list of its captions.

        Only the first comma on a line separates image name and caption, so
        commas inside a caption are preserved. Blank lines, lines without a
        comma, and a leading ``image,caption`` header row are skipped. Anything
        after a ``#`` in the image name is dropped.
        """
        captions = {}
        with open(captions_file, 'r') as file:
            for line_no, line in enumerate(file):
                line = line.strip()
                img, sep, caption = line.partition(',')
                if not sep:
                    continue
                if line_no == 0 and (img, caption) == ('image', 'caption'):
                    continue
                img_id = img.split('#')[0]
                captions.setdefault(img_id, []).append(caption)
        return captions

    def __len__(self):
        """Number of distinct images."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the ``idx``-th image.

        ``caption`` is a plain list of token indices wrapped in ``<SOS>`` /
        ``<EOS>`` (not a tensor).
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_folder, img_id)
        # The context manager closes the file handle as soon as the pixel data
        # has been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        caption = self.captions[img_id][0]
        caption = [self.vocab.stoi["<SOS>"]] + self.vocab.numericalize(caption) + [self.vocab.stoi["<EOS>"]]
        return image, caption  # Return caption as a list, not a tensor



In [None]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate ``(image, caption)`` samples into batched tensors.

    Args:
        batch: sequence of ``(image_tensor, caption_token_list)`` pairs.

    Returns:
        ``(images, captions)`` where ``images`` is the stacked image tensors
        and ``captions`` is a ``torch.long`` tensor of shape
        ``(batch, longest_caption)``, right-padded with the ``<PAD>`` index
        taken from the module-level ``vocab``.
    """
    images = torch.stack([img for img, _ in batch])
    caption_tensors = [torch.tensor(caption, dtype=torch.long) for _, caption in batch]
    captions = pad_sequence(caption_tensors, batch_first=True, padding_value=vocab.stoi["<PAD>"])
    # The per-batch debug prints were removed: this function runs once for
    # every batch, so they flooded the notebook output.
    return images, captions

# Encoder model
class Encoder(nn.Module):
    """ResNet-50 feature extractor followed by a linear projection and BatchNorm.

    Args:
        embed_size: size of the output feature vector.
        weights_path: path to a ResNet-50 ``state_dict`` file. Defaults to the
            Kaggle location used by this notebook.
    """

    def __init__(self, embed_size,
                 weights_path='/kaggle/input/resnet50/pytorch/default/1/resnet50-0676ba61.pth'):
        super(Encoder, self).__init__()

        # Load ResNet-50 and fill it with the locally stored weights.
        # `resnet50` is reached through the `models` module the notebook
        # imports; the bare name was never imported and raised NameError on a
        # fresh kernel.
        resnet = models.resnet50(weights=None)
        # map_location="cpu" lets the file load on machines without a GPU (the
        # module is moved to its final device by the caller); weights_only=True
        # restricts unpickling to tensor data.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        resnet.load_state_dict(state_dict)

        # Remove the last fully connected layer
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Map a batch of images to ``(batch, embed_size)`` feature vectors."""
        features = self.resnet(images)
        # Collapse everything after the batch dimension before the projection.
        features = features.view(features.size(0), -1)
        features = self.bn(self.fc(features))
        return features

# Decoder model
class Decoder(nn.Module):
    """LSTM caption decoder: embedding -> LSTM -> vocabulary projection.

    Args:
        embed_size: dimensionality of word embeddings (and of the image feature).
        hidden_size: LSTM hidden state size.
        vocab_size: number of entries in the vocabulary.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, features, captions):
        """Return vocabulary scores for every step of the input sequence.

        The image feature is inserted as the first time step, in front of the
        embedded caption tokens, so the output has one more step than
        ``captions``.
        """
        image_step = features.unsqueeze(1)
        word_steps = self.embed(captions)
        sequence = torch.cat((image_step, word_steps), dim=1)
        lstm_out, _state = self.lstm(sequence)
        return self.linear(lstm_out)




In [None]:
# Initialize the dataset and dataloader
# Data pipeline: one (image, first caption) sample per image, batched and
# padded by collate_fn.
dataset = Flickr8kDataset(
    img_folder,
    captions_file,
    transform=transform,
    vocab=vocab,
)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Models live on the selected device.
encoder = Encoder(embed_size=embedding_dim).to(device)
decoder = Decoder(
    embed_size=embedding_dim,
    hidden_size=hidden_dim,
    vocab_size=len(vocab),
).to(device)

# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])

# Only the decoder and the encoder's projection head (fc + bn) are optimised;
# the ResNet backbone parameters are not handed to the optimizer.
params = [
    *decoder.parameters(),
    *encoder.fc.parameters(),
    *encoder.bn.parameters(),
]
optimizer = optim.Adam(params, lr=learning_rate)

  resnet.load_state_dict(torch.load(weights_path))


In [None]:
for epoch in range(num_epochs):
    # generate_caption() switches both models to eval mode; re-enable training
    # mode explicitly so re-running this cell after generating a caption still
    # trains with training-mode BatchNorm.
    encoder.train()
    decoder.train()

    for i, (images, captions) in enumerate(data_loader):
        images, captions = images.to(device), captions.to(device)

        # Forward pass through encoder
        features = encoder(images)

        # The decoder prepends the image feature as the first time step, so
        # feeding all tokens except the last yields exactly one output per
        # caption token: output step t is scored against captions[:, t].
        outputs = decoder(features, captions[:, :-1])

        outputs = outputs.reshape(-1, outputs.shape[2])  # (batch_size * seq_len, vocab_size)
        targets = captions.reshape(-1)                   # (batch_size * seq_len,)

        # Calculate loss, ignoring <PAD> tokens
        loss = criterion(outputs, targets)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            # Report the epoch 1-based so the final epoch reads N/N.
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i}/{len(data_loader)}], Loss: {loss.item():.4f}")

print("Training completed!")

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image

def load_image(image_path, transform=None):
    """Load an image as RGB and optionally apply ``transform``.

    Args:
        image_path: path of the image file to open.
        transform: optional callable applied to the RGB image; its result gets
            a leading batch dimension via ``unsqueeze(0)``.

    Returns:
        The transformed result with a batch dimension, or the untransformed
        RGB image when ``transform`` is ``None``.
    """
    # The context manager closes the underlying file as soon as the converted
    # copy exists, instead of leaving the handle open until garbage collection.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    if transform is not None:
        image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

def generate_caption(encoder, decoder, image_path, vocab, max_length=20):
    """Generate a caption for one image by greedy decoding.

    The image feature is fed to the decoder's LSTM as the first time step, then
    ``<SOS>`` and each predicted token are fed one at a time while the LSTM
    state is carried forward. This is the same sequence layout the decoder's
    ``forward`` uses (image feature first, then the tokens).

    Args:
        encoder: module mapping an image batch to feature vectors.
        decoder: module exposing ``embed``, ``lstm`` and ``linear``.
        image_path: path of the image to caption.
        vocab: object exposing ``stoi`` and ``itos`` mappings.
        max_length: maximum number of tokens to generate.

    Returns:
        The generated words joined by single spaces (without ``<EOS>``).

    Note:
        Both models are switched to eval mode and left that way.
    """
    # Prepare the image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Run on whatever device the encoder's parameters live on, so the function
    # does not depend on a module-level `device` variable.
    first_param = next(encoder.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    image = load_image(image_path, transform).to(run_device)

    encoder.eval()
    decoder.eval()

    eos_index = vocab.stoi["<EOS>"]
    caption = []

    with torch.no_grad():
        # Encode the image to obtain the feature vector
        features = encoder(image)

        # Prime the LSTM with the image feature. Only the resulting state is
        # needed; the output of this step is discarded.
        _, states = decoder.lstm(features.unsqueeze(1))

        # Start with the <SOS> token
        input_token = torch.tensor([[vocab.stoi["<SOS>"]]], device=run_device)

        # Generate words one by one, carrying the LSTM state between steps so
        # every prediction is conditioned on the image and all earlier tokens
        # (previously the state was reset at every step).
        for _ in range(max_length):
            embeddings = decoder.embed(input_token)
            hiddens, states = decoder.lstm(embeddings, states)
            output = decoder.linear(hiddens[:, -1, :])  # scores over the vocabulary

            # Get the most likely next token
            predicted = output.argmax(dim=1)
            predicted_token = predicted.item()

            # Stop if <EOS> token is generated
            if predicted_token == eos_index:
                break

            caption.append(predicted_token)
            input_token = predicted.unsqueeze(0)  # predicted token is the next input

    # Convert token indices to words
    return " ".join(vocab.itos[token] for token in caption)

# Example usage:
# Caption a single image with the trained encoder/decoder pair.
image_path = "/kaggle/input/prac9img/prac9img.jpg"  # Replace with the path to an image
caption = generate_caption(
    encoder,
    decoder,
    image_path,
    vocab,
)
print("Generated Caption:", caption)


Generated Caption: a man in a man in a man in a man in a man in a man in a man
