<a href="https://colab.research.google.com/github/Karn2898/Visual-semantic-pipeline/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch torchvision spacy pillow




In [None]:
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m71.7 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [21]:
!unzip /content/Flickr8k_text.zip

Archive:  /content/Flickr8k_text.zip
  inflating: CrowdFlowerAnnotations.txt  
  inflating: ExpertAnnotations.txt   
  inflating: Flickr8k.lemma.token.txt  
   creating: __MACOSX/
  inflating: __MACOSX/._Flickr8k.lemma.token.txt  
  inflating: Flickr8k.token.txt      
  inflating: Flickr_8k.devImages.txt  
  inflating: Flickr_8k.testImages.txt  
  inflating: Flickr_8k.trainImages.txt  
  inflating: readme.txt              


In [22]:
import pandas as pd

def format_real_flickr_data(
    input_file: str = "Flickr8k.token.txt",
    output_file: str = "captions.txt",
) -> None:
    """
    Convert a Flickr8k token file into a two-column CSV of image names and captions.

    Each usable input line has the form ``<image>#<n><TAB><caption>``. The part
    of the first field before ``#`` becomes the ``image`` column and the
    whitespace-stripped caption becomes the ``caption`` column. Lines that
    contain no tab are skipped.

    Args:
        input_file: Path of the tab-separated token file to read.
        output_file: Path of the CSV file to write (header ``image,caption``).

    Returns:
        None. Progress is reported with ``print``. A missing input file is
        reported with a message instead of raising, and no output is written.
    """
    imgs = []
    caps = []

    print("Reading real Flickr8k data...")
    try:
        # Explicit encoding so the result does not depend on the platform locale.
        with open(input_file, "r", encoding="utf-8") as f:
            for line in f:
                # Split on the first tab only: a tab inside the caption text
                # must not truncate the caption.
                tokens = line.split("\t", 1)
                if len(tokens) < 2:
                    continue

                # "name.jpg#3" -> "name.jpg"
                img_id = tokens[0].split("#")[0]
                caption = tokens[1].strip()

                imgs.append(img_id)
                caps.append(caption)
    except FileNotFoundError:
        print(f"Error: Could not find {input_file}. Please download the dataset first.")
        return

    # Built and written outside the try block so that a failure while writing
    # the CSV is not misreported as a missing input file.
    df = pd.DataFrame({"image": imgs, "caption": caps})
    df.to_csv(output_file, index=False)
    print(f"Success! Saved {len(df)} captions to {output_file}.")
    print("You can now run data_loader.py with the REAL dataset.")

# In a notebook cell __name__ is "__main__", so executing this cell runs the
# conversion immediately.
if __name__ == "__main__":
    format_real_flickr_data()


Reading real Flickr8k data...
Success! Saved 40460 captions to captions.txt.
You can now run data_loader.py with the REAL dataset.


In [23]:
import os
import spacy
import torch
import pandas as pd
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import List, Tuple, Any

# Load the spaCy English pipeline once at module level. Only its tokenizer is
# used in this file (see Vocabulary.tokenizer_eng). The "en_core_web_sm" model
# must already be installed.
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """
    Maps caption words to integer indices and back.

    Indices 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; words learned by ``build_vocabulary`` are numbered
    after them in the order in which they reach the frequency threshold.
    """
    def __init__(self, freq_threshold: int):
        """
        Args:
            freq_threshold: Number of occurrences a word needs before it is
                added to the vocabulary. Values of 1 or less admit every word.
        """
        # itos: index -> token, stoi: token -> index. The two dicts mirror each other.
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self) -> int:
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text: str) -> List[str]:
        """Tokenize ``text`` with the module-level spaCy tokenizer and lowercase every token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list: List[str]) -> None:
        """
        Add every word that occurs at least ``freq_threshold`` times in
        ``sentence_list`` to the vocabulary.

        Occurrences are counted per call. Words already present keep their
        index, so calling this method again extends the vocabulary instead of
        overwriting existing entries.

        Args:
            sentence_list: Raw sentences; each is split with ``tokenizer_eng``.
        """
        frequencies = {}
        # Continue after the entries that already exist (4 on a fresh instance)
        # so a repeated call cannot reuse an index that is already taken.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # ">=" (rather than "==") also admits words when the threshold
                # is 0 or negative; the membership check ensures each word is
                # registered exactly once.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text: str) -> List[int]:
        """
        Convert ``text`` into a list of vocabulary indices.

        Tokens that are not in the vocabulary map to the ``<UNK>`` index.
        No ``<SOS>``/``<EOS>`` markers are added here.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

class FlickrDataset(Dataset):
    """
    Image/caption dataset for a Flickr8k-style layout.

    Expects ``root_dir`` to contain the image files and ``captions_file`` to be
    a comma-separated file with an ``image`` column (file name relative to
    ``root_dir``) and a ``caption`` column. A vocabulary is built from all
    captions when the dataset is constructed.
    """
    def __init__(self, root_dir: str, captions_file: str, transform: Any = None, freq_threshold: int = 5):
        """
        Args:
            root_dir: Directory that holds the image files.
            captions_file: Path of the CSV file with ``image`` and ``caption`` columns.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Minimum word frequency passed to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Column views used for per-item lookup.
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # Build the vocabulary from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self) -> int:
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load one sample.

        Args:
            index: Positional row index; negative values count from the end.

        Returns:
            ``(image, caption)`` where ``caption`` is a 1-D tensor of vocabulary
            indices wrapped in ``<SOS>`` ... ``<EOS>``. ``image`` is whatever
            ``transform`` returns, or the RGB PIL image when no transform is set.
        """
        # Positional access: a Dataset index is a row position, not a pandas label.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]
        img_path = os.path.join(self.root_dir, img_id)

        # convert() returns an independent image, so the file opened by
        # Image.open can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # <SOS> + caption tokens + <EOS>
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption += [self.vocab.stoi["<EOS>"]]

        return img, torch.tensor(numericalized_caption)

class MyCollate:
    """
    Batch-assembly callable for a DataLoader.

    Stacks the images of a batch along a new leading dimension and right-pads
    the caption tensors with ``pad_idx`` up to the longest caption in the batch.
    """
    def __init__(self, pad_idx: int):
        self.pad_idx = pad_idx

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        images = []
        captions = []
        for sample in batch:
            images.append(sample[0])
            captions.append(sample[1])

        # Stacking adds the batch dimension in front of each image's own dims.
        image_batch = torch.stack(images, dim=0)
        # Result is (batch, longest_caption); shorter captions are filled with pad_idx.
        caption_batch = pad_sequence(captions, batch_first=True, padding_value=self.pad_idx)

        return image_batch, caption_batch

def get_loader(
    root_folder: str,
    annotation_file: str,
    transform: Any,
    batch_size: int = 32,
    num_workers: int = 8,
    shuffle: bool = True,
    pin_memory: bool = True,
    freq_threshold: int = 5,
) -> Tuple[DataLoader, FlickrDataset]:
    """
    Build a ``FlickrDataset`` and a ``DataLoader`` that pads captions per batch.

    Args:
        root_folder: Directory that holds the image files.
        annotation_file: Captions file handed to ``FlickrDataset``.
        transform: Callable applied to each loaded image (may be ``None``).
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether the DataLoader reshuffles the data every epoch.
        pin_memory: Forwarded to the DataLoader.
        freq_threshold: Minimum word frequency for the dataset's vocabulary.
            Defaults to 5, the same default ``FlickrDataset`` uses.

    Returns:
        ``(loader, dataset)``. The dataset is returned as well so callers can
        reach its vocabulary.
    """
    dataset = FlickrDataset(
        root_folder,
        annotation_file,
        transform=transform,
        freq_threshold=freq_threshold,
    )
    # Captions are padded with the vocabulary's own <PAD> index.
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset

# Example usage (for manual testing). Only the transform is built here; the
# loader call itself is left commented out below.
if __name__ == "__main__":
    # Resize every image to 224x224 and convert it to a tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # NOTE: Uncommenting the lines below requires a 'captions.txt' file and an
    # 'images/' folder in the working directory.
    # loader, dataset = get_loader("images/", "captions.txt", transform=transform)
    # print("Loader initialized successfully")
