In [1]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import nltk
import random

# Configuration

In [2]:
IMAGE_FOLDER = r'Flickr8k_Dataset//Images'
CAPTION_FILE = r'Flickr8k_Dataset//captions.txt'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32
EMBED_SIZE = 256
HIDDEN_SIZE = 512
NUM_EPOCHS = 2
MAX_LEN = 20

* `BATCH_SIZE`: Number of samples per batch fed to the model during training.
* `EMBED_SIZE`: Dimensionality of the word embedding vectors or image feature vectors.
* `HIDDEN_SIZE`: Number of hidden units in the RNN (LSTM/GRU) decoder.
* `NUM_EPOCHS`: Number of times the training loop goes over the entire dataset.
* `MAX_LEN`: Maximum length of the caption (number of words) generated or considered.

# Vocabulary

In [3]:
class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}  # index-to-string
        self.stoi = {v: k for k, v in self.itos.items()}              # string-to-index
        self.freq = Counter()

    def build_vocab(self, sentence_list):
        for sentence in sentence_list:
            for word in nltk.tokenize.word_tokenize(sentence.lower()):
                self.freq[word] += 1

        idx = len(self.itos)  # start after the special tokens
        for word, count in self.freq.items():
            if count >= self.freq_threshold:
                if word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = nltk.tokenize.word_tokenize(text.lower())
        return [
            self.stoi.get(token, self.stoi["<UNK>"])
            for token in tokenized_text
        ]
    
    def __len__(self):
        return len(self.stoi)

* `stoi (string to index)`: Dictionary mapping words → unique integer indices.
* `itos (index to string)`: Dictionary mapping indices → corresponding words.
* `freq`: Keeps track of word frequency (how often each word appears).

The **`__init__`** Function:
- `freq_threshold`: Minimum frequency a word must have to be included in the vocabulary.
- `itos`: Predefined special tokens with their fixed indices: <br>
  - `<PAD>` (padding token) → index 0
  - `<SOS>` (start of sentence token) → index 1
  - `<EOS>` (end of sentence token) → index 2
  - `<UNK>` (unknown token for rare/unseen words) → index 3
- `stoi`: Reverse mapping from the special tokens (itos) for easy lookup.
- `freq`: A Python Counter object that counts how many times each word appears.

The **`build_vocab`** Function:
- Input: sentence_list is a list of sentences (captions). This loop tokenizes every sentence into words using `nltk.tokenize.word_tokenize` (splits sentences into tokens).
- Converts each word to lowercase for consistency.
- Updates the frequency counter for each word across the entire dataset.

- After counting frequencies, this builds the vocabulary:
  - idx starts after the special tokens (so index 4 onwards).
  - Iterates over each word and its count in the frequency dictionary.
  - Includes only words with frequency >= freq_threshold.
  - Adds each word to the mappings:
    - stoi[word] = idx → word to index
    - itos[idx] = word → index to word
  - Increments idx for the next word.

This ensures the vocabulary contains only common enough words, ignoring very rare ones.

The **`numericalize`** Function:
- Converts any input text string to a list of integers representing the tokens.
- Tokenizes the input text. For each token:
   - Looks up its index in stoi.
   - If the word is not in the vocabulary, replaces it with the index of <UNK>.
- Returns the list of indices for use as input to the model.

The **`__len__`** Function returns the size of the vocabulary, including special tokens. Allows using `len(vocab)` to get the total number of words.

# Dataset

In [4]:
class FlickrDataset(Dataset):
    def __init__(self, root, captions_file, transform=None, freq_threshold=5):
        self.root = root
        self.transform = transform
        self.imgs, self.captions = self.load_captions(captions_file)
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions)

    def load_captions(self, caption_file):
        with open(caption_file, 'r') as f:
            lines = f.readlines()[1:]  # skip header if there's one
    
        print("Sample caption lines:")
        for i in range(5):
            print(repr(lines[i]))
    
        imgs = []
        captions = []
    
        for line in lines:
            line = line.strip()
            if ',' not in line:
                continue
            img, caption = line.split(',', 1)  # split only on first comma
            imgs.append(img.strip())
            captions.append(caption.strip())
    
        print(f"Loaded {len(imgs)} image-caption pairs.")
        return imgs, captions


    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = self.captions[idx]
        img_id = self.imgs[idx]
        img_path = os.path.join(self.root, img_id)
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return image, torch.tensor(numericalized_caption)

The **`__init__`** Function:
- `root`: Path to the folder containing the images.
- `captions_file`: Path to the text file that contains image file names and captions.
- `transform`: Optional torchvision transforms to apply on images (like resizing, normalization, etc.).
- `freq_threshold`: Minimum frequency threshold for including words in the vocabulary.

It loads all image filenames and captions using `self.load_captions()`; Instantiates a Vocabulary object with the frequency threshold and Builds the vocabulary from the entire list of captions.

The **`load_captions`** Function:
- Opens the captions file and reads all lines except the first (assuming it's a header).
- Prints the first 5 lines (for debugging/verification).
- Iterates over every line:
   - Removes whitespace.
   - Checks if the line contains a comma (to prevent malformed lines).
   - Splits into img and caption on the first comma only, because captions can contain commas.
   - Adds the image filename and caption to respective lists.
- Returns two lists: one of image filenames and one of captions.

The **`__len__`** Function returns the total number of data points (image-caption pairs). Enables using `len(dataset)`.

The **`__getitem__`** Function:
- For a given index retrieves the caption and image filename.
- Joins the root directory with the image filename to get the full path.
- Opens the image using PIL and converts it to RGB (to ensure 3 channels).

- If transforms are specified (e.g., resizing, tensor conversion, normalization), applies them to the image.
- This makes sure the image tensor is in the correct format and size for your model.

- Starts the caption with the special `<SOS>` (start of sentence) token.
- Converts the caption string into a list of word indices (numerical tokens).
- Appends the special `<EOS>` (end of sentence) token at the end.
- This format helps your model learn when captions start and end during training.

- Finally returns the transformed image tensor and the numericalized caption as a PyTorch tensor.
- These are the inputs and targets used for training your image captioning model.

# Collate Function

In [5]:
def collate_fn(batch):
    imgs, caps = zip(*batch)
    imgs = torch.stack(imgs)
    lengths = [len(c) for c in caps]
    caps_padded = nn.utils.rnn.pad_sequence(caps, batch_first=True, padding_value=0)
    return imgs, caps_padded, lengths

- **Padding**: Captions have varying lengths. Neural nets require fixed-size tensors per batch, so padding is needed.
- **Lengths**: When feeding to RNNs, the model can use the lengths to ignore padding during loss calculation or when packing sequences.
- **Stacking images**: The images are already fixed size tensors, so stacking creates a proper batch tensor.

This function prepares your batch so that images are stacked and captions are padded to the same length, while also keeping track of the original caption lengths for efficient processing.

# Model (VGG + RNN)

In [6]:
class VGGEncoder(nn.Module):
    def __init__(self, embed_size):
        super(VGGEncoder, self).__init__()
        vgg = models.vgg16(pretrained=True)
        self.features = vgg.features
        self.avgpool = vgg.avgpool
        self.fc = nn.Linear(512*7*7, embed_size)
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, images):
        x = self.features(images)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(embeddings)
        return self.linear(hiddens)

The **`VGGEncoder`** Class:
- The **Constructor**:
   - Loads VGG16, pretrained on ImageNet.
   - Keeps only the feature extraction part (the convolutional layers and pooling), which outputs feature maps.
   - Adds a fully connected (fc) layer to convert the output feature maps into an embedding vector of size `embed_size`.
   - Freezes the weights of the convolutional layers (`requires_grad=False`) so training focuses only on the new fc layer and the decoder, speeding up training and preventing overfitting.
- The **Forward Method**:
   - Takes a batch of images of shape `(batch_size, 3, H, W)` (3 = RGB channels).
   - Passes them through VGG’s convolutional layers to extract feature maps.
   - Applies average pooling to get a fixed spatial size (7x7 here) regardless of input image size.
   - Flattens the feature maps into a vector.
   - Passes the flattened vector through the fc layer to get the final image embedding of size `(batch_size, embed_size)`.

The **`DecoderRNN`** Class:
- The **Constructor**:
   - Creates an embedding layer that maps vocabulary tokens (integers) to dense vectors of size `embed_size`.
   - Defines an LSTM network with:
   - Input size = `embed_size` (word embedding size)
   - Hidden size = `hidden_size` (controls capacity of the LSTM)
   - Adds a linear layer to map the LSTM’s output at each time step to a distribution over the vocabulary (logits for each word).
- The **Forward Method**:
   - Inputs:
      - `features`: image embeddings from encoder of shape (batch_size, embed_size)
      - `captions`: tokenized caption sequences (batch_size, caption_length)
   - Steps:
      - **Embedding tokens:** `captions[:, :-1]` selects all tokens except the last one, because during training we predict the next word given previous words. Then `self.embed(...)` maps these tokens to embedding vectors `(batch_size, caption_length - 1, embed_size)`.
      - **Concatenate image features:** The image features tensor `(batch_size, embed_size)` is reshaped to `(batch_size, 1, embed_size)` using `unsqueeze(1)`. This is concatenated at the start of the embedded captions along the sequence dimension (dim=1), so the LSTM receives image features as the first input token, followed by the embedded words. Resulting shape: `(batch_size, caption_length, embed_size)`.
      - **LSTM processing:** Passes this concatenated sequence through the LSTM to produce hidden states for each time step. Output hiddens shape: `(batch_size, caption_length, hidden_size)`.
      - **Linear projection:** Applies the linear layer on each hidden state to produce logits over the vocabulary (scores for each word in the vocab) for each time step. Final output shape: `(batch_size, caption_length, vocab_size)`.

# Training Loop

In [7]:
def train_model():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    dataset = FlickrDataset(IMAGE_FOLDER, CAPTION_FILE, transform)
    len(dataset.vocab)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    encoder = VGGEncoder(EMBED_SIZE).to(DEVICE)
    decoder = DecoderRNN(EMBED_SIZE, HIDDEN_SIZE, len(dataset.vocab)).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    params = list(decoder.parameters()) + list(encoder.fc.parameters())
    optimizer = torch.optim.Adam(params, lr=3e-4)

    for epoch in range(NUM_EPOCHS):
        for i, (imgs, captions, _) in enumerate(dataloader):
            imgs, captions = imgs.to(DEVICE), captions.to(DEVICE)
            features = encoder(imgs)
            inputs = captions[:, :-1]  # remove last token (usually <EOS>)
            targets = captions[:, 1:]  # remove first token (<SOS>)
            # outputs = decoder(features, captions)
            outputs = decoder(features, captions[:, :-1])
            # print(f"outputs shape: {outputs.shape}")
            # print(f"captions[:, 1:] shape: {captions[:, 1:].shape}")
            # loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions[:, 1:].reshape(-1))
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), targets.reshape(-1))
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Step [{i}], Loss: {loss.item():.4f}")

    torch.save({'encoder': encoder.state_dict(), 'decoder': decoder.state_dict(), 'vocab': dataset.vocab}, 'caption_model.pth')
    print("Model saved!")

**Data Transform and Dataset:**
- Resizes all input images to 224×224 (VGG16 requires this input size).
- Converts PIL images to PyTorch tensors.
- Creates an instance of the custom `FlickrDataset`, passing image folder and caption file paths.
- Applies the defined image transforms.
- `dataset.vocab` holds the vocabulary (word → index mapping).
- Loads data in batches. `collate_fn` ensures that images are stacked and captions are padded properly.

**Model Initialization:**
- `encoder` is a pretrained VGG16 model with an extra FC layer to map image features to an embedding of size EMBED_SIZE.
- `decoder` is an LSTM that will generate captions word-by-word.
- `len(dataset.vocab)` gives the vocabulary size to define the output dimension of the final layer in the decoder.

**Loss Function and Optimizer:**
- Cross-entropy loss is used for classification at each time step.
- `ignore_index=0`: assuming 0 is the padding index (`<PAD>` token), this tells the loss to ignore padding when computing gradients.
- We only want to train:
   - All decoder parameters
   - The encoder’s final fc layer
- The rest of the encoder (VGG16 features) is frozen.
- The Adam optimizer is used with learning rate `0.0003`.

**Training Loop:**
- `imgs` are image tensors, captions are padded caption sequences.
- `_` is typically the lengths list from the collate_fn, which you are not using here but could be useful.
- Sends data to GPU (if available).
- Encode the batch of images to feature vectors: `(batch_size, embed_size)`
- For teacher forcing:
   - `Inputs`: all words except the last one (`<EOS>` removed)
   - `Targets`: all words except the first one (`<SOS>` removed)
- This allows the decoder to predict the next word based on previous ground truth words and image features.
- Outputs shape: `(batch_size, sequence_length, vocab_size)`
- Reshape outputs and targets to match the required format for `CrossEntropyLoss`.
   - Flattened to: `(batch_size × sequence_length, vocab_size)` and `(batch_size × sequence_length,)`
- Typical PyTorch training step:
  - Zero gradients;  Backpropagate loss;  Update model weights
- Saves the model's state dictionaries and vocabulary so that you can load them later for inference or fine-tuning.

# Main Function

In [9]:
if __name__ == '__main__':
    nltk.download('punkt')
    train_model()

[nltk_data] Downloading package punkt to C:\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Sample caption lines:
'1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .\n'
'1000268201_693b08cb0e.jpg,A girl going into a wooden building .\n'
'1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .\n'
'1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playhouse .\n'
'1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a wooden cabin .\n'
Loaded 40455 image-caption pairs.




Epoch [1/2], Step [0], Loss: 8.0057
Epoch [1/2], Step [100], Loss: 4.6346
Epoch [1/2], Step [200], Loss: 4.4029
Epoch [1/2], Step [300], Loss: 4.3448
Epoch [1/2], Step [400], Loss: 4.1174
Epoch [1/2], Step [500], Loss: 4.1776
Epoch [1/2], Step [600], Loss: 4.0062
Epoch [1/2], Step [700], Loss: 4.0764
Epoch [1/2], Step [800], Loss: 4.1217
Epoch [1/2], Step [900], Loss: 3.8847
Epoch [1/2], Step [1000], Loss: 3.9976
Epoch [1/2], Step [1100], Loss: 3.7505
Epoch [1/2], Step [1200], Loss: 3.7762
Epoch [2/2], Step [0], Loss: 3.4763
Epoch [2/2], Step [100], Loss: 3.6393
Epoch [2/2], Step [200], Loss: 3.7172
Epoch [2/2], Step [300], Loss: 3.7313
Epoch [2/2], Step [400], Loss: 3.5545
Epoch [2/2], Step [500], Loss: 3.6710
Epoch [2/2], Step [600], Loss: 3.7115
Epoch [2/2], Step [700], Loss: 3.8100
Epoch [2/2], Step [800], Loss: 3.5454
Epoch [2/2], Step [900], Loss: 3.4751
Epoch [2/2], Step [1000], Loss: 3.4354
Epoch [2/2], Step [1100], Loss: 3.2445
Epoch [2/2], Step [1200], Loss: 3.6616
Model save

# Testing

In [14]:
def generate_caption(image_path, encoder, decoder, vocab, transform):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(DEVICE)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        feature = encoder(image)
        caption = [vocab.stoi["<SOS>"]]
        for _ in range(MAX_LEN):
            cap = torch.tensor([caption]).to(DEVICE)
            output = decoder(feature, cap)
            _, predicted = output[:, -1, :].max(1)
            pred_id = predicted.item()
            caption.append(pred_id)
            if pred_id == vocab.stoi["<EOS>"]:
                break
        caption_words = [vocab.itos[idx] for idx in caption[1:-1]]
        return ' '.join(caption_words)

In [15]:
# Test on personal image
checkpoint = torch.load('caption_model.pth')
encoder = VGGEncoder(EMBED_SIZE).to(DEVICE)
decoder = DecoderRNN(EMBED_SIZE, HIDDEN_SIZE, len(checkpoint['vocab'])).to(DEVICE)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])
vocab = checkpoint['vocab']

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),])

test_img = 'arcane.png'
print("Caption:", generate_caption(test_img, encoder, decoder, vocab, transform))

  checkpoint = torch.load('caption_model.pth')


Caption: a man in a blue shirt is a a . .
