# Assignment 1 : End-to-End Image Captioning for Remote Sensing

**Name : Kshitij Vaidya**

**Roll Number : 22B1829**

[Google Drive Video Submission](https://drive.google.com/drive/folders/1NkJjQ0YAYkjuV6vtw_bc4hxnU4twDCrx?usp=drive_link)

[GitHub Repository](https://github.com/Kshitij-Vaidya/EE782-Assignments)

## Abstract

This report presents the design, implementation, and evaluation of two deep learning architectures for remote sensing image captioning. The models are trained and evaluated on the **Remote Sensing Image Captioning Dataset (RSICD)**. I implemented two prominent encoder-decoder frameworks: a Convolutional Neural Network (CNN) paired with a Long Short-Term Memory (LSTM) network, and a CNN combined with a Transformer decoder. Both models share the same visual feature extractors: **ResNet-18** and **MobileNet** backbones pre-trained on ImageNet. Performance is assessed quantitatively using the BLEU-4 and METEOR metrics, and qualitatively through analysis of successes and failures on specific data slices. I also explore model explainability using **Grad-CAM** for visual saliency and attention map analysis for textual importance.

## Introduction

Image captioning, the task of automatically generating a textual description for an image, stands at the intersection of computer vision and natural language processing. It requires a model to not only identify objects within an image but also to understand their relationships and articulate them in a coherent, human-like sentence.

While general image captioning is a well-studied problem, its application to remote sensing and overhead imagery presents unique challenges. Unlike typical ground-level photographs, satellite images often feature different scales, overhead perspectives, and specific land-use patterns (e.g., farmlands, harbors, runways) that demand specialized understanding. Furthermore, the orientation of objects in aerial imagery can be arbitrary, introducing a need for rotational invariance.

This assignment aims to build two end-to-end captioning models from scratch for the RSICD dataset. Specifically, I implement and evaluate a classic CNN-LSTM architecture and a more modern CNN-Transformer decoder architecture. The goal is to compare these two popular approaches in the context of remote sensing data. A shared CNN encoder, using either a ResNet-18 or MobileNet backbone, extracts visual features that are then passed to the respective language decoder to generate descriptive captions.

Beyond standard performance metrics like BLEU-4 and METEOR, this report emphasizes a comprehensive analysis of model behavior. This includes a qualitative review of generated captions across different data slices (such as high versus low contrast images, or scenes depicting specific geographic features) to identify systematic strengths and weaknesses. I further investigate model interpretability using explainability methods like Grad-CAM to visualize which parts of an image the model focuses on when generating a caption. This detailed analysis is complemented by a mandatory diary documenting the process of identifying and resolving common bugs in LLM-generated code, a key learning objective of this work.

# Code and Directory Structure

The assignment is organized into a modular directory structure for clarity and ease of development. This Jupyter Notebook serves as the submission report, with all the code arranged in code cells. To actually run the code, refer to the [GitHub repository](https://github.com/Kshitij-Vaidya/EE782-Assignments). The package is laid out as follows:
```text
imageCaptioning/
│
├── config.py                # Configuration settings and logger setup
├── main.py                  # Entry point for training and evaluation
│
├── models/                  # Contains model definitions
│   ├── encoder.py           # CNN encoder (ResNet-18, MobileNet)
│   ├── lstmDecoder.py       # LSTM-based decoder
│   ├── transformerDecoder.py# Transformer-based decoder
│   └── captioner.py         # Wrapper combining encoder and decoder
│
├── data/                    # Data handling and preprocessing
│   ├── dataset.py           # Custom dataset class for RSICD
│   ├── preprocess.py        # Preprocessing utilities
│   ├── prepareFromCSV.py    # Script to prepare images and annotations from CSV
│   └── vocabulary.py        # Vocabulary/tokenizer management
│
├── evaluation/              # Evaluation scripts and metrics
│   ├── metrics.py           # BLEU, METEOR, and other metrics
│   ├── decoding.py          # Decoding utilities for inference
│   └── getCaptions.py       # Caption extraction and formatting
│
├── training/                # Training utilities
│   ├── train.py             # Training loop and logic
│   ├── lossFunction.py      # Loss functions
│   ├── optimizer.py         # Optimizer setup
│   └── utils.py             # Helper functions
│
├── outputs/                 # Stores model outputs, predictions, and statistics
├── checkpoints/             # Saved model weights
└── rsicdDataset/            # Dataset files and images (train, valid, test splits)
```

This structure separates data processing, model definition, training, and evaluation, making the codebase easy to navigate.

## Configurations

In [None]:
# imageCaptioning/config.py

import os
import logging
import re
import psutil
from typing import List

# Define the Logging Configuration
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s:%(name)s: %(message)s",
    datefmt="%H:%M:%S"
)
LOGGER = logging.getLogger("Preprocessing")

def getCustomLogger(name : str) -> logging.Logger:
    return logging.getLogger(name)

# Data Location Configuration
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(PROJECT_ROOT, "rsicdDataset")
OUTPUT_DIRECTORY = os.path.join(PROJECT_ROOT, "outputs")
CHECKPOINT_PATH = os.path.join(PROJECT_ROOT, "checkpoints")

# Define the common tokenizer to be used across the project
def tokenize(text : str) -> List[str]:
    return re.findall(r'\w+', text.lower())

# Device and Utility Details for the Torch Modules and Training
DEVICE = "cpu"

def getMemoryMB():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 ** 2)
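
As a quick illustration of the shared `tokenize` helper from the configuration above, the sketch below re-declares it standalone and runs it on a made-up caption (the caption is hypothetical, not taken from RSICD). The `\w+` pattern lowercases the text and drops punctuation, so `trees.` and `trees` end up as the same token.

```python
import re
from typing import List

def tokenize(text: str) -> List[str]:
    # Same regex as in config.py: runs of word characters, lowercased
    return re.findall(r'\w+', text.lower())

# Hypothetical caption, used only to show the behaviour
exampleCaption = "Many green trees are near a river."
tokens = tokenize(exampleCaption)
print(tokens)
```

Because punctuation is discarded here, any code that encodes captions should use the same function; a plain `str.split()` would keep `river.` as a distinct token.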

## Data Capture and Preprocessing

In [None]:
# imageCaptioning/data/prepareFromCSV.py

import csv
import sys
csv.field_size_limit(sys.maxsize)
import ast
import os
import json
import argparse
import re

from imageCaptioning.config import DATA_ROOT, LOGGER

def prepareFromCSV(csvPath : str, split : str, outputDirectory : str = DATA_ROOT):
    """
    Convert CSV Data with Bytes Info into:
        1. Images / Directory of JPG Files
        2. annotations.json with the metadata for the preprocessing
    """
    imageDirectory = os.path.join(outputDirectory, "images", split)
    os.makedirs(imageDirectory, exist_ok=True)

    annotations = []
    LOGGER.info(f"Reading RSICD CSV from {csvPath} ...")

    with open(csvPath, "r", encoding="utf-8") as file:
        reader = csv.reader(file)
        for index, row in enumerate(reader):
            if not row or len(row) < 3:
                continue
            filename = row[0]
            captionString = row[1]
            byteString = row[2]
            # Parse the captions
            try:
                captionsList = ast.literal_eval(captionString)
                captionsList = [str(c).strip() for c in captionsList]

                if len(captionsList) == 1:
                    joined = captionsList[0]
                    joined = re.sub(r'\.(\w)', r'. \1', joined)
                    splitCaptions = re.split(r'\.\s+', joined)
                    captionsList = [c.strip() for c in splitCaptions if c.strip()]
            except Exception:
                captionsList = captionString.strip("[]").replace("'", "").split(",")
                captionsList = [c.strip() for c in captionsList if c.strip()]

            # Parse the Image Bytes
            try:
                dictObject = ast.literal_eval(byteString)
                imageBytes = dictObject["bytes"]
            except Exception as e:
                LOGGER.warning(f"Row {index}: Failed to parse image bytes ({e})")
                continue

            # Save the image
            basename = os.path.basename(filename)
            outputPath = os.path.join(imageDirectory, basename)
            with open(outputPath, "wb") as image:
                image.write(imageBytes)

            # Collect the metadata
            annotations.append({
                "imageId" : os.path.splitext(basename)[0],
                "filename" : basename,
                "captions" : captionsList,
            })
        
        # Save annotations to the json file
        annotationPath = os.path.join(outputDirectory, f"{split}Annotations.json")
        with open(annotationPath, "w") as jsonFile:
            json.dump(annotations, jsonFile, indent=2)
        
        LOGGER.info(f"Saved {len(annotations)} images : {imageDirectory}")
        LOGGER.info(f"Saved Metadata : {annotationPath}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare the RSICD dataset from the CSV file")
    parser.add_argument("--csv",
                        type=str,
                        required=True,
                        help="Path to the RSICD CSV files")
    parser.add_argument("--output",
                        type=str,
                        default=DATA_ROOT,
                        help="Output Directory (default : DATA_ROOT)")
    parser.add_argument("--split",
                        type=str,
                        choices=["train", "test", "valid"],
                        required=True,
                        help="Split type of the image data")
    arguments = parser.parse_args()

    prepareFromCSV(arguments.csv, arguments.split, arguments.output)
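
The caption-parsing branch of `prepareFromCSV` can be tried in isolation. The sketch below copies that logic into a hypothetical helper, `parseCaptions`, and feeds it two made-up cells: one where several sentences are fused into a single string, and one where they are already separate list entries. The input strings are assumptions about what a cell might look like, not rows from the actual CSV.

```python
import ast
import re
from typing import List

def parseCaptions(captionString: str) -> List[str]:
    # Mirror of the happy path in prepareFromCSV (hypothetical helper name)
    captionsList = [str(c).strip() for c in ast.literal_eval(captionString)]
    if len(captionsList) == 1:
        # Insert a space after a period that is glued to the next word,
        # then split on "period + whitespace"
        joined = re.sub(r'\.(\w)', r'. \1', captionsList[0])
        captionsList = [c.strip() for c in re.split(r'\.\s+', joined) if c.strip()]
    return captionsList

fused = "['a river runs through farmland.two bridges cross the river.']"
separate = "['a river runs through farmland.', 'two bridges cross the river.']"

fusedResult = parseCaptions(fused)
separateResult = parseCaptions(separate)
print(fusedResult)
print(separateResult)
```

One thing this makes visible: when a fused string is split, the period is consumed for every sentence except the last, so the pieces are not punctuated consistently. This is harmless as long as downstream tokenization ignores punctuation.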

```text
SAVING TRAIN IMAGES

(3.11.0) (base) kshitijvaidya@Kshitijs-MacBook-Pro-2 data % python prepareFromCSV.py --csv /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/train.csv --output ../rsicdDataset --split train
[23:50:38] INFO:DataPrepare: Reading RSICD CSV from /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/train.csv ...
[23:50:38] WARNING:DataPrepare: Row 0: Failed to parse image bytes (malformed node or string on line 1: <ast.Name object at 0x101144c10>)
[23:50:59] INFO:DataPrepare: Saved 8734 images : ../rsicdDataset/images/train
[23:50:59] INFO:DataPrepare: Saved Metadata : ../rsicdDataset/trainAnnotations.json

SAVING TEST IMAGES

(3.11.0) (base) kshitijvaidya@Kshitijs-MacBook-Pro-2 data % python prepareFromCSV.py --csv /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/test.csv --output ../rsicdDataset --split test 
[23:52:01] INFO:DataPrepare: Reading RSICD CSV from /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/test.csv ...
[23:52:01] WARNING:DataPrepare: Row 0: Failed to parse image bytes (malformed node or string on line 1: <ast.Name object at 0x100c68c10>)
[23:52:04] INFO:DataPrepare: Saved 1093 images : ../rsicdDataset/images/test
[23:52:04] INFO:DataPrepare: Saved Metadata : ../rsicdDataset/testAnnotations.json

SAVING VALID IMAGES

(3.11.0) (base) kshitijvaidya@Kshitijs-MacBook-Pro-2 data % python prepareFromCSV.py --csv /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/valid.csv --output ../rsicdDataset --split valid 
[23:52:43] INFO:DataPrepare: Reading RSICD CSV from /Users/kshitijvaidya/EE782-AdvancedTopicsInML/Assignment1/imageCaptioning/rsicdDataset/valid.csv ...
[23:52:43] WARNING:DataPrepare: Row 0: Failed to parse image bytes (malformed node or string on line 1: <ast.Name object at 0x10496cc10>)
[23:52:46] INFO:DataPrepare: Saved 1094 images : ../rsicdDataset/images/valid
[23:52:46] INFO:DataPrepare: Saved Metadata : ../rsicdDataset/validAnnotations.json
```
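
The `Row 0` warning in each run is consistent with the first CSV row being a header line: `ast.literal_eval` raises a `ValueError` on a bare column name, which is exactly the "malformed node or string" message in the logs. The sketch below reproduces this with a hypothetical helper, `parseImageBytes`, and assumes the header cell is called `image` (the real column name is not shown in the logs).

```python
import ast
from typing import Optional

def parseImageBytes(byteString: str) -> Optional[bytes]:
    # Returns the raw bytes, or None if the cell is not a {'bytes': ...} literal
    try:
        return ast.literal_eval(byteString)["bytes"]
    except (ValueError, SyntaxError, KeyError, TypeError):
        return None

headerCell = "image"                        # assumed header text
dataCell = "{'bytes': b'\\xff\\xd8\\xff'}"  # made-up cell with a JPEG-like prefix

headerResult = parseImageBytes(headerCell)
dataResult = parseImageBytes(dataCell)
print(headerResult, dataResult)
```

Since the header row is skipped by the `continue` in the script, the warning does not affect the saved image counts.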

In [None]:
# imageCaptioning/data/dataset.py

import os
import json
from typing import Optional, List
from PIL import Image
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torchvision import transforms

from imageCaptioning.data.vocabulary import Vocabulary

class RSICDDataset(Dataset):
    def __init__(self, root : str, 
                 split : str,
                 vocab: Optional["Vocabulary"] = None,
                 transform: Optional[transforms.Compose] = None,
                 maxLength : int = 24):
        """
        Arguments:
            1. root (str): dataset root containing the images/{split}/ and annotations.json file
            2. split (str): train | test | valid
            3. vocab (Vocabulary): tokenizer/vocab object only needed for caption encoding
            4. transform : torchvision transforms
            5. maxLength (int): max caption length
        """
        self.split = split
        self.root = root
        self.vocabulary = vocab
        self.maxLength = maxLength
        self.transform = transform

        annotationPath = os.path.join(root, f"{self.split}Annotations.json")
        with open(annotationPath, "r") as file:
            self.annotations = json.load(file)
        
        self.imageDirectory = os.path.join(root, "images", self.split)
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, index):
        annotation = self.annotations[index]
        imagePath = os.path.join(self.imageDirectory, annotation["filename"])
        image = Image.open(imagePath).convert("RGB")

        if self.transform:
            image = self.transform(image)
        # Use the first caption
        caption = annotation["captions"][0]
        if self.vocabulary:
            tokenIds, _ = self.vocabulary.encode(caption, maxLength = self.maxLength)
            tokens = torch.tensor(tokenIds, dtype = torch.long)
        else:
            tokens = caption
        return image, tokens
    
    def getImagePaths(self) -> List[str]:
        """
        Returns list of all image file paths in the dataset
        """
        return [os.path.join(self.imageDirectory, annotation["filename"])
                for annotation in self.annotations]
    
    def loadImage(self, imagePath: Path):
        """
        Loads an image from the given path and applies the transform if one is set
        (returns a PIL image, or a tensor when the transform ends in ToTensor)
        """
        image = Image.open(str(imagePath)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [None]:
# imageCaptioning/data/vocabulary.py

from typing import Dict, List, Tuple
import json
from collections import Counter

from imageCaptioning.config import LOGGER


class Vocabulary:
    """
    Word Level Vocabulary for image captioning
    Handles token-to-index and index-to-token mappings
    """
    def __init__(self, counter : Counter,
                 minFrequency: int = 1,
                 maxSize: int = 10000) -> None:
        self.frequencies: Counter = counter
        self.minFrequency: int = minFrequency
        self.maxSize: int = maxSize

        # Special Tokens
        self.padToken: str = "<pad>"
        self.bosToken: str = "<bos>"
        self.eosToken: str = "<eos>"

        # Define the string to int and int to string mappings
        self.STOI: Dict[str, int] = {}
        self.ITOS: Dict[int, str] = {}

        # Build the Vocabulary on Initialization
        self._buildVocabulary()
    
    def _buildVocabulary(self) -> None:
        tokens = [self.padToken, self.bosToken, self.eosToken]
        # Sort by frequency so that truncating to maxSize keeps the most frequent words
        mostCommon = [word for word, count in self.frequencies.most_common() if count >= self.minFrequency]
        mostCommon = mostCommon[: self.maxSize - len(tokens)]
        tokens.extend(mostCommon)

        self.STOI = {token : index for index, token in enumerate(tokens)}
        self.ITOS = {index : token for token, index in self.STOI.items()}

        LOGGER.info(
            f"Built Vocabulary: Size={len(self)}, "
            f"Minimum Frequency={self.minFrequency}, "
            f"Maximum Size={self.maxSize}, "
            f"Unique Tokens={len(self.frequencies)}"
        )
        
    def __len__(self) -> int:
        return len(self.STOI)

    def encode(self, text : str, maxLength: int = 24) -> Tuple[List[int], int]:
        """
        Convert the caption string into a list of token IDs with BOS/EOS and padding
        """
        # Use the shared project tokenizer so that encoding matches how the vocabulary was built
        from imageCaptioning.config import tokenize
        tokens = tokenize(text)
        rawTokenLength = len(tokens)
        tokenIds = [self.STOI.get(self.bosToken)]
        # There is no <unk> token, so out-of-vocabulary words fall back to <pad>
        tokenIds += [self.STOI.get(token, self.STOI.get(self.padToken)) for token in tokens]
        tokenIds.append(self.STOI.get(self.eosToken))
        # Padding / Truncation
        if len(tokenIds) < maxLength:
            tokenIds.extend([self.STOI.get(self.padToken)] * (maxLength - len(tokenIds)))
        else:
            # Truncate but keep EOS as the final token
            tokenIds = tokenIds[:maxLength - 1] + [self.STOI.get(self.eosToken)]
        return tokenIds, rawTokenLength
    
    def decode(self, tokenIds: List[int]) -> str:
        """
        Convert list of token IDs to the caption of strings stopping at EOS
        """
        captionWords = []
        for index in tokenIds:
            token = self.ITOS.get(index, self.padToken)
            if (token == self.eosToken):
                break
            if token not in {self.padToken, self.bosToken}:
                captionWords.append(token)
        return " ".join(captionWords)

    def save(self, path : str) -> None:
        object = {
            "STOI" : self.STOI,
            "ITOS" : self.ITOS,
            "Pad" : self.STOI.get(self.padToken, 0),
            "BOS" : self.STOI.get(self.bosToken, 1),
            "EOS" : self.STOI.get(self.eosToken, 2),
            "Size" : len(self),
            "MinFrequency" : self.minFrequency,
            "MaxSize" : self.maxSize
        }
        with open(path, "w") as file:
            json.dump(obj=object, fp=file, indent=2)
    
    @classmethod
    def load(cls, path : str) -> "Vocabulary":
        with open(path, "r") as file:
            payload: dict = json.load(file)
        vocabulary = cls.__new__(cls)
        vocabulary.minFrequency = payload.get("MinFrequency", 1)
        vocabulary.maxSize = payload.get("MaxSize", 10000)
        vocabulary.padToken = "<pad>"
        vocabulary.eosToken = "<eos>"
        vocabulary.bosToken = "<bos>"
        vocabulary.STOI = payload.get("STOI")
        # JSON stores dictionary keys as strings, so restore the integer indices
        vocabulary.ITOS = {int(index): token for index, token in payload.get("ITOS").items()}
        vocabulary.frequencies = Counter()
        return vocabulary
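
One detail worth keeping in mind when saving and loading a vocabulary as JSON: object keys in JSON are always strings, so an integer-keyed mapping such as `ITOS` does not round-trip unchanged. The sketch below uses a tiny made-up mapping to show the effect and one way to restore the integer keys.

```python
import json

# Toy index-to-token mapping (hypothetical contents)
ITOS = {0: "<pad>", 1: "<bos>", 2: "<eos>", 3: "river"}

# Round trip through JSON: the integer keys come back as strings
restored = json.loads(json.dumps(ITOS))
print(list(restored.keys()))

# A lookup with an integer index now misses
print(restored.get(3))

# Converting the keys back to int recovers the original mapping
fixed = {int(index): token for index, token in restored.items()}
print(fixed == ITOS)
```

If the keys are left as strings, `decode` would fall back to the padding token for every index and return an empty caption.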

In [None]:
# imageCaptioning/data/preprocess.py

import os
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd 
from torchvision import transforms

from imageCaptioning.data.dataset import RSICDDataset
from imageCaptioning.data.vocabulary import Vocabulary
from imageCaptioning.config import (DATA_ROOT, LOGGER, 
                                    OUTPUT_DIRECTORY, tokenize)



def getTransforms():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Normalise according to the ImageNet Statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

def buildVocabFromTrain(minFrequency: int = 3,
                        maxSize: int = 10000) -> Vocabulary:
    trainDataset = RSICDDataset(root=DATA_ROOT, split="train")
    counter = Counter()
    LOGGER.debug(f"Sample annotation : {trainDataset.annotations[0]['captions']}")
    for annotation in trainDataset.annotations:
        for caption in annotation["captions"]:
            counter.update(tokenize(caption))
    
    for i, (word, freq) in enumerate(counter.most_common(20)):
        LOGGER.debug(f"Word {i}: {word}, freq={freq}")
    LOGGER.debug(f"Total Unique tokens in counter: {len(counter)}")
    
    vocabulary = Vocabulary(counter=counter,
                            minFrequency=minFrequency,
                            maxSize=maxSize)
    vocabularyPath = os.path.join(OUTPUT_DIRECTORY, "vocab.json")
    vocabulary.save(vocabularyPath)
    LOGGER.info(f"Saved vocabulary to {vocabularyPath} (Size = {len(vocabulary)})")
    return vocabulary

def computeStatistics(vocabulary: Vocabulary,
                      split: str,
                      maxLength: int = 24) -> dict:
    dataset = RSICDDataset(root=DATA_ROOT,
                           split=split,
                           vocab=vocabulary,
                           transform=getTransforms(),
                           maxLength=maxLength)
    lengths = []
    OOVCount, totalCount = 0, 0

    for annotation in dataset.annotations:
        for caption in annotation["captions"]:
            tokens = tokenize(caption)
            _, rawLength = vocabulary.encode(caption, maxLength=maxLength)
            lengths.append(rawLength)
            totalCount += len(tokens)
            OOVCount += sum(1 for word in tokens if word not in vocabulary.STOI)
    
    coverage = 100 * (1 - OOVCount / totalCount)
    LOGGER.info(f"[{split}] Vocabulary Coverage : {coverage:.2f}")

    # Histogram Plot
    plt.hist(lengths, bins=20)
    plt.title(f"{split} Caption Length Distribution")
    plt.savefig(os.path.join(OUTPUT_DIRECTORY, f"{split}LengthsHistogram.png"))
    plt.close()

    return {
        "split" : split,
        "coverage" : coverage,
        "averageLength" : np.mean(lengths),
        "standardDeviation" : np.std(lengths),
    }


if __name__ == "__main__":
    LOGGER.info("Building Vocabulary from Training Annotations")
    vocabulary = buildVocabFromTrain()
    LOGGER.info("Computing training and validation statistics")
    trainingStatistics = computeStatistics(vocabulary=vocabulary, split="train")
    validationStatistics = computeStatistics(vocabulary=vocabulary, split="valid")
    data = pd.DataFrame([trainingStatistics, validationStatistics])
    data.to_csv(os.path.join(OUTPUT_DIRECTORY, "tokenStatisticsTrainingValidation.csv"), index=False)
    LOGGER.info("Saved token statistics CSV")

```text
RUNNING PREPROCESS TO BUILD VOCAB.json

(.venv) kshitijvaidya@Kshitijs-MacBook-Pro-2 Assignment1 % python -m imageCaptioning.data.preprocess
[14:40:18] INFO:Preprocessing: Building Vocabulary from Training Annotations
[14:40:18] INFO:Preprocessing: Built Vocabulary: Size=2701, Minimum Frequency=1, Maximum Size=10000, Unique Tokens=2698
[14:40:18] INFO:Preprocessing: Saved vocabulary to ./imageCaptioning/outputs/vocab.json (Size = 2701)
[14:40:18] INFO:Preprocessing: Computing training and validation statistics
[14:40:18] INFO:Preprocessing: [train] Vocabulary Coverage : 100.00
[14:40:19] INFO:Preprocessing: [valid] Vocabulary Coverage : 99.04
[14:40:19] INFO:Preprocessing: Saved token statistics CSV
```
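
The coverage figures in the log are the percentage of caption tokens that are present in the vocabulary. The sketch below recomputes that quantity on a toy vocabulary and two made-up captions; the helper name `vocabularyCoverage` and all the data are assumptions for illustration only.

```python
import re
from typing import Dict, List

def tokenize(text: str) -> List[str]:
    # Same tokenizer as in config.py
    return re.findall(r'\w+', text.lower())

def vocabularyCoverage(captions: List[str], STOI: Dict[str, int]) -> float:
    # Percentage of tokens that have an entry in the vocabulary
    totalCount, OOVCount = 0, 0
    for caption in captions:
        tokens = tokenize(caption)
        totalCount += len(tokens)
        OOVCount += sum(1 for word in tokens if word not in STOI)
    return 100 * (1 - OOVCount / totalCount)

toySTOI = {"<pad>": 0, "<bos>": 1, "<eos>": 2,
           "a": 3, "river": 4, "near": 5, "trees": 6}
toyCaptions = ["A river near trees.", "A harbor near a river."]

# 9 tokens in total, of which only "harbor" is out of vocabulary
toyCoverage = vocabularyCoverage(toyCaptions, toySTOI)
print(f"{toyCoverage:.2f}")
```

Coverage on the training split is 100% by construction whenever every training word is admitted to the vocabulary, so the validation figure is the more informative of the two.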

## Model Classes

In [None]:
# imageCaptioning/models/encoder.py

import os
import torch
import torch.nn as nn
import torchvision.models as models
from typing import List, Dict
from imageCaptioning.config import getCustomLogger, DEVICE, OUTPUT_DIRECTORY
from imageCaptioning.data.dataset import RSICDDataset

LOGGER = getCustomLogger("Encoder")

class CNNEncoder(nn.Module):
    '''
    CNN Encoder supporting ResNet-18 and MobileNet
    Removes the classifier head and applies global average pooling
    '''
    def __init__(self, modelName: str = 'resnet18',
                 pretrained: bool = True,
                 finetune: bool = True,
                 numLayers: int = 2,
                 outputDim: int = 512) -> None:
        super().__init__()

        if modelName == 'resnet18':
            baseModel = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
            modules = list(baseModel.children())[:-2]
            self.CNN = nn.Sequential(*modules)
            featureDim = baseModel.fc.in_features
        
        elif modelName == "mobilenet":
            baseModel = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None)
            self.CNN = baseModel.features
            featureDim = baseModel.last_channel
        
        else:
            raise ValueError(f"Unsupported Model Name: {modelName}")
        
        # Global Average Pooling
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Projection into a Common Dimension
        self.projection = nn.Linear(featureDim, outputDim)
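        # Define the finetuning policy: freeze the whole backbone first, then
        # unfreeze only the last numLayers child blocks when finetuning is enabled
        for param in self.CNN.parameters():
            param.requires_grad = False
        if finetune:
            for layer in list(self.CNN.children())[-numLayers:]:
                for param in layer.parameters():
                    param.requires_grad = True
        LOGGER.info(f"Initialized CNN Encoder with {modelName}, "
                    f"Finetune = {finetune}, Output Dimensions = {outputDim}")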
    

    def forward(self, x : torch.Tensor) -> torch.Tensor:
        features = self.CNN(x)
        pooledOutput = self.pool(features).squeeze(-1).squeeze(-1)
        return self.projection(pooledOutput)
    
    def cacheFeatures(self, imagePaths : List[str],
                      dataset: RSICDDataset,
                      batchSize: int = 32,
                      savePath: str = "featureCache.pt") -> None:
        self.eval()
        featureDict = {}
        with torch.no_grad():
            for i in range(0, len(imagePaths), batchSize):
                batchPaths = imagePaths[i : i + batchSize]
                batchImages = [dataset.loadImage(path).to(DEVICE)
                               for path in batchPaths]
                batchTensors = torch.stack(batchImages)
                batchFeatures = self.forward(batchTensors).cpu()

                for imagePath, features in zip(batchPaths, batchFeatures):
                    featureDict[imagePath] = features
        savePath = os.path.join(OUTPUT_DIRECTORY, savePath)
        torch.save(featureDict, savePath)
        LOGGER.info(f"Cached features for {len(featureDict)} images to {savePath}")
    
    @staticmethod
    def loadCachedFeatures(loadPath) -> Dict:
        """
        Load the cached features from a .pt file
        """
        return torch.load(loadPath)
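
To make the tensor shapes in `CNNEncoder.forward` concrete, here is a minimal sketch that follows the same three steps (backbone, global average pooling, linear projection) on random inputs. A single small convolution stands in for the ResNet-18 or MobileNet backbone, and the channel counts (8 and 4) are arbitrary stand-ins for the real `featureDim` and `outputDim`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the pretrained backbone (assumption: 8 output channels)
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
pool = nn.AdaptiveAvgPool2d((1, 1))
projection = nn.Linear(8, 4)

images = torch.randn(2, 3, 32, 32)                 # (B, 3, H, W)
featureMap = backbone(images)                      # (B, 8, H, W)
pooled = pool(featureMap).squeeze(-1).squeeze(-1)  # (B, 8)
projected = projection(pooled)                     # (B, 4)

print(featureMap.shape, pooled.shape, projected.shape)
```

Global average pooling is simply the mean over the two spatial axes, which is why the encoder produces one vector per image regardless of the input resolution.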

In [None]:
# imageCaptioning/models/lstmDecoder.py

import torch
import torch.nn as nn
from typing import List
from imageCaptioning.config import getCustomLogger, DEVICE

LOGGER = getCustomLogger("LSTM Decoder")

class LSTMDecoder(nn.Module):
    '''
    LSTM Caption Decoder
    Projects the CNN feature vector to the initial hidden and cell states
    '''
    def __init__(self,
                 vocabSize: int,
                 embedDimension: int = 256,
                 hiddenDimension: int = 512,
                 numLayers: int = 2,
                 paddingIndex: int = 0,
                 dropout: float = 0.5) -> None:
        super().__init__()
        self.numLayers = numLayers
        self.embedding = nn.Embedding(vocabSize, 
                                      embedDimension, 
                                      padding_idx=paddingIndex)
        self.LSTM = nn.LSTM(embedDimension, 
                            hiddenDimension,
                            numLayers,
                            dropout=dropout,
                            batch_first=True)
        self.fc = nn.Linear(hiddenDimension, vocabSize)

        # Assumes the encoder output dimension equals hiddenDimension (both are 512 by default)
        self.initH = nn.Linear(hiddenDimension, hiddenDimension)
        self.initC = nn.Linear(hiddenDimension, hiddenDimension)

        LOGGER.info(f"Initialised LSTM Decoder with Vocabulary={vocabSize}, "
                    f"Embedding Dimension = {embedDimension}, Hidden Dimensions = {hiddenDimension}, Layers = {numLayers}")
    
    def forward(self, features: torch.Tensor,
                captions: torch.Tensor) -> torch.Tensor:
        """
        Teacher Forcing Mode
        Arguments:
            features: (B, hiddenDimension): image features used to initialise the hidden and cell states
            captions: (B, L): input sequence of tokens
        """
        embeddings = self.embedding(captions)
        H0 = torch.tanh(self.initH(features)).unsqueeze(0).repeat(self.numLayers, 1, 1) # (numLayers, B, H)
        C0 = torch.tanh(self.initC(features)).unsqueeze(0).repeat(self.numLayers, 1, 1) # (numLayers, B, H)

        outputs, _ = self.LSTM(embeddings, (H0, C0))
        return self.fc(outputs) # Output Size : (B, L, vocabSize)

    def generate(self, features: torch.Tensor,
                 maxLength: int = 24,
                 BOSIndex: int = 1,
                 EOSIndex: int = 2) -> List[int]:
        """
        Greedy Decoding for a single image (features of shape (1, hiddenDimension))
        """
        H = torch.tanh(self.initH(features)).unsqueeze(0).repeat(self.numLayers, 1, 1)
        C = torch.tanh(self.initC(features)).unsqueeze(0).repeat(self.numLayers, 1, 1)
        # Token indices must be a long tensor of shape (1, 1) for the embedding layer
        inputs = torch.tensor([[BOSIndex]], dtype=torch.long,
                              device=features.device)
        embeddings = self.embedding(inputs)
        outputs = []

        for _ in range(maxLength):
            output, (H, C) = self.LSTM(embeddings, (H, C))
            logits = self.fc(output[:, -1, :]) # Last token 
            predicted = torch.argmax(logits, dim=-1)
            outputs.append(predicted.item())
            if predicted.item() == EOSIndex:
                break
            embeddings = self.embedding(predicted.unsqueeze(0))
        
        return outputs
    
    def generateBatch(self, features: torch.Tensor,
                      maxLength: int = 24,
                      BOSIndex: int = 1,
                      EOSIndex: int = 2) -> List[List[int]]:
        batchSize = features.size(0)
        H = torch.tanh(self.initH(features)).unsqueeze(0).repeat(self.numLayers, 1, 1)
        C = torch.tanh(self.initC(features)).unsqueeze(0).repeat(self.numLayers, 1, 1)
        # Allocate on the same device as the features instead of the global DEVICE
        inputs = torch.full((batchSize, 1), BOSIndex,
                            dtype=torch.long,
                            device=features.device)
        outputs = [[] for _ in range(batchSize)]
        finished = torch.zeros(batchSize, dtype=torch.bool,
                               device=features.device)
        
        for _ in range(maxLength):
            embeddings = self.embedding(inputs)
            output, (H, C) = self.LSTM(embeddings, (H, C))
            logits = self.fc(output[:, -1, :])
            predicted = torch.argmax(logits, dim=-1)

            for i in range(batchSize):
                if not finished[i]:
                    outputs[i].append(predicted[i].item())
                    if predicted[i].item() == EOSIndex:
                        finished[i] = True
            
            if finished.all():
                break

            inputs = predicted.unsqueeze(1)
        
        return outputs
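
The greedy decoding loop used by `generate` and `generateBatch` can be shown without any neural network. In the sketch below, a hypothetical scoring function `toyScores` stands in for the LSTM step plus the final linear layer: given the tokens so far, it returns a score per candidate token, and the decoder always picks the highest one. The token indices and scores are made up.

```python
from typing import Callable, Dict, List

def greedyDecode(nextTokenScores: Callable[[List[int]], Dict[int, float]],
                 BOSIndex: int = 1,
                 EOSIndex: int = 2,
                 maxLength: int = 24) -> List[int]:
    sequence = [BOSIndex]
    outputs: List[int] = []
    for _ in range(maxLength):
        scores = nextTokenScores(sequence)
        predicted = max(scores, key=scores.get)  # argmax over candidate tokens
        outputs.append(predicted)
        if predicted == EOSIndex:
            break
        sequence.append(predicted)
    return outputs

# Toy scores keyed by the previous token: BOS -> 3 -> 4 -> EOS
toyTable = {
    1: {2: 0.0, 3: 0.9, 4: 0.1},
    3: {2: 0.2, 3: 0.1, 4: 0.7},
    4: {2: 0.8, 3: 0.1, 4: 0.1},
}

def toyScores(sequence: List[int]) -> Dict[int, float]:
    return toyTable[sequence[-1]]

decoded = greedyDecode(toyScores)
truncated = greedyDecode(toyScores, maxLength=2)
print(decoded, truncated)
```

As in the decoder above, the EOS index is appended to the output before the loop stops, and a sequence that never produces EOS is cut off at `maxLength`.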

In [None]:
# imageCaptioning/models/transformerDecoder.py

import torch
import torch.nn as nn
from typing import List

from imageCaptioning.config import getCustomLogger, DEVICE

LOGGER = getCustomLogger("Transformer Decoder")

class TransformerDecoder(nn.Module):
    def __init__(self,
                 vocabSize: int,
                 dModel: int = 256,
                 numLayers: int = 2,
                 numHeads: int = 2,
                 ffDim: int = 1024,
                 dropout: float = 0.2,
                 paddingIndex: int = 0,
                 encoderDim: int = 512) -> None:
        super().__init__()

        self.embedding = nn.Embedding(vocabSize, dModel, padding_idx=paddingIndex)
        self.positionEncoder = nn.Parameter(torch.zeros(1, 32, dModel)) # Learnable Positional Encoder
        decoderLayer = nn.TransformerDecoderLayer(dModel, numHeads, ffDim, dropout=dropout, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoderLayer, numLayers)
        self.fc = nn.Linear(dModel, vocabSize)
        # Image Projection : Convert feature into memory tokens
        self.imageProjection = nn.Linear(encoderDim, dModel)

        LOGGER.info(f"Initialised Transformer Decoder with Vocab = {vocabSize}, "
                    f"d_model = {dModel}, Layers = {numLayers}, Heads = {numHeads}")
    
    def forward(self, features: torch.Tensor,
                captions: torch.Tensor) -> torch.Tensor:
        _, L = captions.shape
        embeddings = self.embedding(captions) + self.positionEncoder[:, :L, :]
        # Project the features onto a (B, 1, D) Memory
        memory = self.imageProjection(features).unsqueeze(1)
        # Causal Mask for the Decoder
        mask = torch.triu(torch.ones(L, L), diagonal=1).bool().to(captions.device)
        output = self.decoder(tgt=embeddings,
                              memory=memory,
                              tgt_mask=mask)
        return self.fc(output)
    
    def generateBatch(self, features : torch.Tensor,
                      maxLength : int = 24,
                      BOSIndex : int = 1,
                      EOSIndex : int = 2) -> List[List[int]]:
        batchSize = features.size(0)
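        # Create the running sequence on the same device as the model,
        # otherwise the embedding lookup fails when DEVICE is a GPU
        generated = torch.full((batchSize, 1),
                               BOSIndex, dtype=torch.long,
                               device=DEVICE)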
        finished = torch.zeros(batchSize, dtype=torch.bool,
                               device=DEVICE)
        predictions : List[List[int]] = [[] for _ in range(batchSize)]

        for _ in range(maxLength):
            # Embeddings of Dimension : (B, seqLen, dModel)
            embeddings = self.embedding(generated) + self.positionEncoder[:, :generated.size(1), :]
            memory = self.imageProjection(features).unsqueeze(1)
            mask = torch.triu(torch.ones(generated.size(1), generated.size(1), device=DEVICE), diagonal=1).bool()
            output = self.decoder(tgt=embeddings,
                                  memory=memory,
                                  tgt_mask=mask)
            logits = self.fc(output[:, -1, :])
            nextTokens = torch.argmax(logits, dim=1)

            for i in range(batchSize):
                if not finished[i]:
                    predictions[i].append(nextTokens[i].item())
                    if nextTokens[i].item() == EOSIndex:
                        finished[i] = True
            
            if finished.all():
                break

            generated = torch.cat([generated, nextTokens.unsqueeze(1)], dim=1)
        
        return predictions
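
The causal mask built with `torch.triu(..., diagonal=1).bool()` marks every position strictly above the diagonal as blocked, so the token at step `t` can only attend to steps `0..t`. A dependency-free sketch of the same pattern is given below. `causalMask` is a hypothetical helper used only for illustration.

```python
from typing import List

def causalMask(length: int) -> List[List[bool]]:
    """True marks a (query row, key column) pair that must NOT be attended to."""
    return [[column > row for column in range(length)]
            for row in range(length)]

mask = causalMask(4)
for row in mask:
    print(" ".join("x" if blocked else "." for blocked in row))
# . x x x
# . . x x
# . . . x
# . . . .
```

Each row is a query position and each `x` is a future token hidden from it. During generation only the last row matters, since only `output[:, -1, :]` is projected to logits.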

In [None]:
# imageCaptioning/models/captioner.py

import torch
import torch.nn as nn
from imageCaptioning.models.encoder import CNNEncoder
from imageCaptioning.models.lstmDecoder import LSTMDecoder
from imageCaptioning.models.transformerDecoder import TransformerDecoder
from imageCaptioning.config import getCustomLogger

LOGGER = getCustomLogger("Captioner")

class Captioner(nn.Module):
    """
    Wrapper for the Encoder + Decoder Framework
    """
    def __init__(self, vocabSize: int,
                 modelType: str = "lstm",
                 encoderName: str = "resnet18",
                 finetune: bool = True) -> None:
        super().__init__()

        self.encoder = CNNEncoder(modelName=encoderName,
                                  pretrained=True,
                                  finetune=finetune,
                                  outputDim=512)
        
        if modelType == "lstm":
            self.decoder = LSTMDecoder(vocabSize=vocabSize)
        elif modelType == "transformer":
            self.decoder = TransformerDecoder(vocabSize=vocabSize)
        else:
            raise ValueError(f"Unknown Decoder Type : {modelType}")
        
        LOGGER.info(f"Initialized Captioner with Encoder = {encoderName} and Decoder = {modelType}")
    
    def forward(self, image: torch.Tensor,
                captions: torch.Tensor) -> torch.Tensor:
        features = self.encoder(image)
        return self.decoder(features, captions)

## Training Utility Functions

In [None]:
# imageCaptioning/training/utils.py

from typing import Dict, Any
import torch
import random
import numpy as np 
import os

from imageCaptioning.config import getCustomLogger

LOGGER = getCustomLogger("Utilities")

def getSeed(seed: int = 42) -> None:
    """
    Ensure reproducibility across runs
    """
    LOGGER.info(f"Setting random seed = {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def saveCheckpoint(state : Dict[str, Any],
                   filename : str) -> None:
    """
    Save the model/optimizer to a file
    """
    LOGGER.info(f"Saving checkpoint to {filename}")
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(state, filename)

def loadCheckpoint(filename : str,
                   device: str = "cpu") -> Dict[str, Any]:
    """
    Load the model/optimizer
    """
    LOGGER.info(f"Loading checkpoint from {filename}")
    return torch.load(filename,
                      map_location=device)

In [None]:
# imageCaptioning/training/lossFunction.py

import torch.nn as nn
from imageCaptioning.config import getCustomLogger

LOGGER = getCustomLogger("Loss Function")

def getLossFunction(paddingIndex: int) -> nn.Module:
    """
    Return the Cross Entropy Loss ignoring the Padding Tokens
    """
    LOGGER.info(f"Initialised CrossEntropyLoss with ignore_index = {paddingIndex}")
    return nn.CrossEntropyLoss(ignore_index=paddingIndex)
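
With `ignore_index` set, padded target positions contribute neither to the summed loss nor to the count used for the mean. The sketch below reproduces that arithmetic in plain Python to make the effect concrete. `maskedCrossEntropy` is a hypothetical helper, not part of the repository.

```python
import math
from typing import List

def maskedCrossEntropy(logits: List[List[float]],
                       targets: List[int],
                       paddingIndex: int = 0) -> float:
    """Mean negative log-likelihood over non-padding targets only."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == paddingIndex:
            continue  # padded position: adds neither loss nor count
        logNormaliser = math.log(sum(math.exp(value) for value in row))
        total += logNormaliser - row[target]
        count += 1
    return total / count if count else float("nan")  # nan: no valid targets

# Three positions over a 4-token vocabulary; the last target is padding
logits = [[0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0],
          [9.0, -9.0, 3.0, 1.0]]
targets = [2, 3, 0]
loss = maskedCrossEntropy(logits, targets)
print(round(loss, 4))  # 1.3863, i.e. log(4)
```

The confident logits at the padded position have no influence on the result, which is why short captions padded to `maxLength` do not dilute the training signal.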

In [None]:
# imageCaptioning/training/optimizer.py

import torch.nn as nn
import torch.optim as optim 
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, StepLR

from imageCaptioning.config import getCustomLogger

LOGGER = getCustomLogger("Optimizer")

def buildOptimizer(model: nn.Module,
                   lrCNN: float = 1e-4,
                   lrDecoder: float = 2e-4,
                   lrTransformer: float = 2e-5) -> Optimizer:
    """
    Build the Adam Optimizer with separate learning rates for the CNN Encoder and the LSTM/Transformer Decoder
    """
    LOGGER.info(f"Building Adam Optimizer (CNN LR = {lrCNN}, Decoder LR = {lrDecoder}, Transformer LR = {lrTransformer})")
    parameters = []

    if hasattr(model, "encoder"):
        parameters.append({"params" : model.encoder.parameters(),
                           "lr" : lrCNN})
    if hasattr(model, "decoder"):
        parameters.append({"params" : model.decoder.parameters(),
                           "lr" : lrDecoder})
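    # Captioner exposes both decoder types as `decoder`, so a Transformer decoder
    # built through Captioner is trained with lrDecoder; this branch only applies
    # to models that define a separate `transformer` attribute.
    if hasattr(model, "transformer"):
        parameters.append({"params" : model.transformer.parameters(),
                           "lr" : lrTransformer})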
    
    optimizer = optim.Adam(params=parameters, betas=(0.9, 0.999))
    return optimizer

def buildScheduler(optimizer: Optimizer,
                   stepSize: int = 5,
                   gamma: float = 0.5) -> _LRScheduler:
    """
    StepLR Scheduler: reduce the LR every 'stepSize' epochs by gamma
    """
    LOGGER.info(f"Using StepLR Scheduler: Step Size = {stepSize}, Gamma = {gamma}")
    return StepLR(optimizer=optimizer,
                  step_size=stepSize,
                  gamma=gamma)
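
With `stepSize = 5` and `gamma = 0.5`, every parameter group has its learning rate halved after each block of five epochs. The closed form below is a sketch of that schedule. It assumes `scheduler.step()` is called once at the end of every epoch, as in the training loop, and `stepLR` is a hypothetical helper rather than a PyTorch function.

```python
def stepLR(baseLR: float, epoch: int,
           stepSize: int = 5, gamma: float = 0.5) -> float:
    """Learning rate in effect during `epoch` (1-indexed)."""
    # (epoch - 1) scheduler steps have been taken before this epoch starts
    return baseLR * gamma ** ((epoch - 1) // stepSize)

# Decoder learning rate (default 2e-4) at a few epochs of a 50-epoch run
schedule = {epoch: stepLR(2e-4, epoch) for epoch in (1, 5, 6, 11, 50)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:2d}: lr = {lr:.3e}")
```

Over 50 epochs the rate is halved nine times, so the final epochs run at `2e-4 / 512`, roughly `3.9e-7` for the decoder group.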

In [None]:
# imageCaptioning/training/train.py

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import time

from imageCaptioning.training.lossFunction import getLossFunction
from imageCaptioning.training.optimizer import buildOptimizer, buildScheduler
from imageCaptioning.training.utils import saveCheckpoint
from imageCaptioning.config import getCustomLogger, DEVICE, CHECKPOINT_PATH, getMemoryMB

LOGGER = getCustomLogger("Train")

def trainEpoch(model: nn.Module,
               dataloader: DataLoader,
               criterion: nn.Module,
               optimizer: torch.optim.Optimizer,
               epoch: int,
               gradClip: float = 5.0) -> float:
    """
    Train model for a single epoch
    """
    model.train()
    totalLoss = 0.0

    for images, captions in tqdm(dataloader,
                                 desc=f"Epoch {epoch} [Train]"):
        images, captions = images.to(DEVICE), captions.to(DEVICE)

        outputs : torch.Tensor = model(images, captions[:, :-1])
        loss: torch.Tensor = criterion(outputs.reshape(-1, outputs.size(-1)),
                                       captions[:, 1:].reshape(-1))
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradClip)
        optimizer.step()

        totalLoss += loss.item()
    
    return totalLoss / len(dataloader)


def validate(model: nn.Module,
             dataloader: DataLoader,
             criterion: nn.Module,
             epoch: int) -> float:
    """
    Validate the model on the Validation Set
    """
    model.eval()
    totalLoss = 0.0

    with torch.no_grad():
        for images, captions in tqdm(dataloader,
                                     desc=f"Epoch {epoch} [Valid]"):
            images, captions = images.to(DEVICE), captions.to(DEVICE)
            outputs: torch.Tensor = model(images, captions[:, :-1])
            loss: torch.Tensor = criterion(outputs.reshape(-1, outputs.size(-1)),
                                           captions[:, 1:].reshape(-1))
            totalLoss += loss.item()
    return totalLoss / len(dataloader)


def trainModel(model: nn.Module,
               trainLoader: DataLoader,
               valLoader: DataLoader,
               paddingIndex: int,
               checkpointPath: str,
               epochs: int = 50,
               lrCNN: float = 1e-4,
               lrDecoder: float = 2e-4,
               lrTransformer: float = 2e-5) -> None:
    """
    Main Training Loop
    """
    criterion = getLossFunction(paddingIndex=paddingIndex)
    optimizer = buildOptimizer(model, lrCNN, lrDecoder, lrTransformer)
    scheduler = buildScheduler(optimizer)

    bestValidationLoss = float("inf")

    startTime = time.time()
    startMemory = getMemoryMB()

    for epoch in range(1, epochs + 1):
        trainingLoss = trainEpoch(model, trainLoader, criterion, optimizer, epoch)
        validationLoss = validate(model, valLoader, criterion, epoch)

        LOGGER.info(f"Epoch {epoch}: Train Loss = {trainingLoss:.4f}, Validation Loss = {validationLoss:.4f}")

        scheduler.step()

        if validationLoss < bestValidationLoss:
            LOGGER.info(f"Validation Loss improved from {bestValidationLoss:.4f} to {validationLoss:.4f}")
            bestValidationLoss = validationLoss
            saveCheckpoint({
                "Epoch" : epoch,
                "modelState" : model.state_dict(),
                "optimizerState" : optimizer.state_dict(),
                "validationLoss" : validationLoss
            }, checkpointPath)  # main.py already prefixes CHECKPOINT_PATH, so joining again would duplicate a relative directory
    
    endTime = time.time()
    endMemory = getMemoryMB()
    LOGGER.info(f"Training Time : {endTime - startTime:.2f} seconds")
    LOGGER.info(f"Memory Usage Increase: {endMemory - startMemory:.2f}MB")
    LOGGER.info(f"Memory Usage at End: {endMemory:.2f}MB")
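
Both `trainEpoch` and `validate` use teacher forcing: the decoder receives `captions[:, :-1]` and is scored against `captions[:, 1:]`, so each position predicts the next token. A small sketch of that shift on a single padded caption follows. The special indices match the decoder defaults (0 = `<pad>`, 1 = `<bos>`, 2 = `<eos>`); the word indices and the helper name are made up for illustration.

```python
from typing import List, Tuple

def shiftForTeacherForcing(caption: List[int]) -> Tuple[List[int], List[int]]:
    """Split one padded caption into decoder inputs and prediction targets."""
    return caption[:-1], caption[1:]

# <bos> w14 w9 w27 <eos> <pad> <pad>
caption = [1, 14, 9, 27, 2, 0, 0]
inputs, targets = shiftForTeacherForcing(caption)

# Pairs that actually contribute to the loss (padding targets are ignored)
pairs = [(seen, target) for seen, target in zip(inputs, targets) if target != 0]
print(pairs)  # [(1, 14), (14, 9), (9, 27), (27, 2)]
```

The last useful pair teaches the model to emit `<eos>` after the final word, which is what lets greedy decoding terminate on its own.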

## Main File

In [None]:
# imageCaptioning/main.py

import os
from torch.utils.data import DataLoader
import argparse
import time
# Import project modules
from imageCaptioning.config import (DATA_ROOT, OUTPUT_DIRECTORY, DEVICE, CHECKPOINT_PATH,
                                    getCustomLogger)
from imageCaptioning.data.dataset import RSICDDataset
from imageCaptioning.data.vocabulary import Vocabulary
from imageCaptioning.data.preprocess import getTransforms
from imageCaptioning.models.captioner import Captioner
from imageCaptioning.training.train import trainModel

LOGGER = getCustomLogger("Main")

def loadVocabulary(vocabPath: str) -> Vocabulary:
    """Load the preprocessed vocabulary"""
    if not os.path.exists(vocabPath):
        raise FileNotFoundError(f"Vocabulary file not found at {vocabPath}")
    
    LOGGER.info(f"Loading vocabulary from {vocabPath}")
    vocabulary = Vocabulary.load(vocabPath)
    LOGGER.info(f"Loaded vocabulary with {len(vocabulary)} tokens")
    return vocabulary

def createDataLoaders(vocabulary: Vocabulary, 
                     batchSize: int = 32,
                     maxLength: int = 24,
                     numWorkers: int = 4) -> tuple[DataLoader, DataLoader]:
    """Create training and validation data loaders"""
    
    # Define transforms
    transform = getTransforms()
    
    # Create datasets
    trainDataset = RSICDDataset(
        root=DATA_ROOT,
        split="train",
        vocab=vocabulary,
        transform=transform,
        maxLength=maxLength
    )
    
    valDataset = RSICDDataset(
        root=DATA_ROOT,
        split="valid",
        vocab=vocabulary,
        transform=transform,
        maxLength=maxLength
    )
    
    LOGGER.info(f"Training dataset size: {len(trainDataset)}")
    LOGGER.info(f"Validation dataset size: {len(valDataset)}")
    
    # Create data loaders
    trainLoader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        shuffle=True,
        num_workers=numWorkers,
        pin_memory=str(DEVICE).startswith("cuda")  # holds whether DEVICE is a str or a torch.device
    )
    
    valLoader = DataLoader(
        valDataset,
        batch_size=batchSize,
        shuffle=False,
        num_workers=numWorkers,
        pin_memory=str(DEVICE).startswith("cuda")  # holds whether DEVICE is a str or a torch.device
    )
    
    return trainLoader, valLoader

def main():
    parser = argparse.ArgumentParser(description="Train Image Captioning Model on RSICD Dataset")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--lr-cnn", type=float, default=1e-4, help="Learning rate for CNN encoder")
    parser.add_argument("--lr-decoder", type=float, default=2e-4, help="Learning rate for decoder")
    parser.add_argument("--lr-transformer", type=float, default=2e-5, help="Learning rate for transformer")
    parser.add_argument("--model-type", type=str, default="lstm", choices=["lstm", "transformer"], 
                       help="Type of decoder to use")
    parser.add_argument("--encoder-name", type=str, default="resnet18", choices=["resnet18", "mobilenet"],
                       help="CNN encoder architecture")
    parser.add_argument("--max-length", type=int, default=24, help="Maximum caption length")
    parser.add_argument("--finetune", action="store_true", help="Fine-tune the encoder")
    parser.add_argument("--num-workers", type=int, default=4, help="Number of data loader workers")
    
    args = parser.parse_args()
    
    # Create checkpoint directory
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)
    
    LOGGER.info("Starting Image Captioning Training on RSICD Dataset")
    LOGGER.info(f"Device: {DEVICE}")
    LOGGER.info(f"Model Type: {args.model_type}")
    LOGGER.info(f"Encoder: {args.encoder_name}")
    LOGGER.info(f"Fine-tune Encoder: {args.finetune}")
    LOGGER.info(f"Batch Size: {args.batch_size}")
    LOGGER.info(f"Epochs: {args.epochs}")
    
    # Load vocabulary
    vocabPath = os.path.join(OUTPUT_DIRECTORY, "vocab.json")
    vocabulary = loadVocabulary(vocabPath)
    
    # Create data loaders
    trainLoader, valLoader = createDataLoaders(
        vocabulary=vocabulary,
        batchSize=args.batch_size,
        maxLength=args.max_length,
        numWorkers=args.num_workers
    )
    
    # Initialize model
    model = Captioner(
        vocabSize=len(vocabulary),
        modelType=args.model_type,
        encoderName=args.encoder_name,
        finetune=args.finetune
    )
    
    # Move model to device
    model = model.to(DEVICE)
    
    # Log model information
    totalParams = sum(p.numel() for p in model.parameters())
    trainableParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
    LOGGER.info(f"Total parameters: {totalParams:,}")
    LOGGER.info(f"Trainable parameters: {trainableParams:,}")
    
    # Define checkpoint path
    checkpointPath = os.path.join(CHECKPOINT_PATH, f"model_{args.model_type}_{args.encoder_name}.pt")
    
    # Get padding index from vocabulary
    paddingIndex = vocabulary.STOI.get(vocabulary.padToken, 0)
    
    # Train the model
    LOGGER.info("Starting training...")

    trainModel(
        model=model,
        trainLoader=trainLoader,
        valLoader=valLoader,
        paddingIndex=paddingIndex,
        epochs=args.epochs,
        lrCNN=args.lr_cnn,
        lrDecoder=args.lr_decoder,
        lrTransformer=args.lr_transformer,
        checkpointPath=checkpointPath
    )
    
    LOGGER.info("Training completed!")
    LOGGER.info(f"Best model saved to: {checkpointPath}")

if __name__ == "__main__":
    main()
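
One detail of the command-line interface is easy to miss: `--finetune` is declared with `action="store_true"`, so `args.finetune` is `False` unless the flag is passed, and that value overrides the `finetune=True` default of `Captioner`. The sketch below rebuilds only two of the arguments to show the behaviour; it is not the full parser.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--finetune", action="store_true")
parser.add_argument("--model-type", default="lstm",
                    choices=["lstm", "transformer"])

# No flags: encoder fine-tuning is off and the LSTM decoder is selected
default = parser.parse_args([])
# Explicit flags: fine-tuning on, Transformer decoder selected
custom = parser.parse_args(["--finetune", "--model-type", "transformer"])

print(default.finetune, default.model_type)  # False lstm
print(custom.finetune, custom.model_type)    # True transformer
```

Runs meant to fine-tune the backbone therefore need `--finetune` on the command line.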

# Results and Discussion

## Vocabulary Statistics

<div style="text-align:center">

<div style="margin-bottom:32px">
    <img src="imageCaptioning/outputs/trainLengthsHistogram.png" style="width:48%; display:inline-block;"/>
    <img src="imageCaptioning/outputs/validLengthsHistogram.png" style="width:48%; display:inline-block;"/>
</div>

<table style="margin: 0 auto;">
  <thead>
    <tr>
      <th>Split</th>
      <th>Vocabulary Coverage (%)</th>
      <th>OOV (%)</th>
      <th>Average Caption Length</th>
      <th>Std. Deviation</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>train</td>
      <td>100.00</td>
      <td>0.00</td>
      <td>10.99</td>
      <td>3.16</td>
    </tr>
    <tr>
      <td>valid</td>
      <td>99.04</td>
      <td>0.96</td>
      <td>9.28</td>
      <td>3.00</td>
    </tr>
  </tbody>
</table>
</div>
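
For reference, the quantities in the table can be computed along the following lines. This is a sketch under stated assumptions rather than the exact analysis script: coverage is assumed to be measured per token (not per unique word), the standard deviation is assumed to be the population value, and `captionStatistics` and the toy data are hypothetical.

```python
from statistics import mean, pstdev
from typing import Dict, List, Set

def captionStatistics(captions: List[List[str]],
                      vocabulary: Set[str]) -> Dict[str, float]:
    """Token-level vocabulary coverage and caption length statistics."""
    tokens = [token for caption in captions for token in caption]
    known = sum(token in vocabulary for token in tokens)
    coverage = 100.0 * known / len(tokens)
    lengths = [len(caption) for caption in captions]
    return {"coverage": coverage,
            "oov": 100.0 - coverage,
            "meanLength": mean(lengths),
            "stdLength": pstdev(lengths)}

# Toy data: one of the nine tokens ("nearby") is out of vocabulary
vocabulary = {"a", "plane", "on", "the", "runway", "green", "field"}
captions = [["a", "plane", "on", "the", "runway"],
            ["a", "green", "field", "nearby"]]
stats = captionStatistics(captions, vocabulary)
print({name: round(value, 2) for name, value in stats.items()})
```

Under these assumptions the training split reaches 100% coverage by construction when the vocabulary is built from it without a frequency cutoff, while the validation split can contain unseen words.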