<a href="https://colab.research.google.com/github/Karn2898/Visual-semantic-pipeline/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %pip (rather than bare `pip` automagic) installs into the running kernel's
# environment and keeps working when the cell contains more than one line.
%pip install torch torchvision spacy pillow




In [None]:
# Download the small English spaCy pipeline; its tokenizer is used to build the
# caption vocabulary. The download output recommends restarting the runtime
# before spacy.load("en_core_web_sm") is called.
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.8/12.8 MB 71.7 MB/s eta 0:00:00
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
⚠ Restart to reload dependencies
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [None]:
# -o overwrites already-extracted files without prompting, so re-running this
# cell (e.g. Restart & Run All) does not block on unzip's "replace?" prompt.
!unzip -o /content/Flickr8k_text.zip

Archive:  /content/Flickr8k_text.zip
  inflating: CrowdFlowerAnnotations.txt  
  inflating: ExpertAnnotations.txt   
  inflating: Flickr8k.lemma.token.txt  
   creating: __MACOSX/
  inflating: __MACOSX/._Flickr8k.lemma.token.txt  
  inflating: Flickr8k.token.txt      
  inflating: Flickr_8k.devImages.txt  
  inflating: Flickr_8k.testImages.txt  
  inflating: Flickr_8k.trainImages.txt  
  inflating: readme.txt              


In [None]:
import pandas as pd

def format_real_flickr_data(input_file: str = "Flickr8k.token.txt", output_file: str = "captions.txt"):
    """Convert the Flickr8k token file into a two-column ``image,caption`` CSV.

    Each input line is expected to look like ``<image>#<n>\\t<caption>``: the
    ``#<n>`` suffix is stripped from the image name and the caption is
    whitespace-trimmed. Lines without a tab separator are skipped.

    Args:
        input_file: Path of the raw token file.
        output_file: Path of the CSV written with columns ``image`` and ``caption``.

    Returns:
        The DataFrame that was written, or ``None`` if ``input_file`` does not exist
        (an error message is printed in that case).
    """
    imgs = []
    caps = []

    print("Reading real Flickr8k data...")
    try:
        with open(input_file, "r", encoding="utf-8") as f:
            for line in f:
                # maxsplit=1 keeps any further tab characters inside the caption.
                tokens = line.split("\t", 1)
                if len(tokens) < 2:
                    continue

                imgs.append(tokens[0].split("#")[0])
                caps.append(tokens[1].strip())
    except FileNotFoundError:
        # Only a missing *input* file is reported here; errors writing the
        # output are not masked by this message.
        print(f"Error: Could not find {input_file}. Please download the dataset first.")
        return None

    df = pd.DataFrame({"image": imgs, "caption": caps})
    df.to_csv(output_file, index=False)
    print(f"Success! Saved {len(df)} captions to {output_file}.")
    return df


if __name__ == "__main__":
    format_real_flickr_data()


Reading real Flickr8k data...
Success! Saved 40460 captions to captions.txt.
You can now run data_loader.py with the REAL dataset.


In [None]:
import os
import spacy
import torch
import pandas as pd
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from typing import List, Tuple, Any


# Loaded once when this cell runs. In this notebook only its tokenizer is used
# (Vocabulary.tokenizer_eng). Requires the en_core_web_sm package downloaded above.
spacy_eng = spacy.load("en_core_web_sm")

class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``.
    Any other word receives the next free index once it has been counted at
    least ``freq_threshold`` times within one :meth:`build_vocabulary` call.
    """

    def __init__(self, freq_threshold: int):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self) -> int:
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text: str) -> List[str]:
        """Tokenize ``text`` with the spaCy English tokenizer and lowercase each token."""
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list: List[str]) -> None:
        """Add every word that reaches ``freq_threshold`` occurrences in ``sentence_list``.

        New words are appended after the current last index, so calling this
        more than once extends the vocabulary instead of overwriting entries.
        Words already in the vocabulary keep their index.
        """
        frequencies = {}
        next_idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `>=` plus the membership check adds each word exactly once,
                # and also handles freq_threshold <= 1 (every word qualifies).
                if word not in self.stoi and frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text: str) -> List[int]:
        """Map each token of ``text`` to its index, using ``<UNK>`` for unknown words."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

class FlickrDataset(Dataset):
    """Image/caption pairs read from a CSV with ``image`` and ``caption`` columns.

    The vocabulary is built from every caption in ``captions_file`` when the
    dataset is constructed.

    Args:
        root_dir: Directory that the ``image`` column's file names are joined onto.
        captions_file: CSV path readable by ``pd.read_csv``.
        transform: Optional callable applied to the RGB PIL image.
        freq_threshold: Minimum word count for inclusion in the vocabulary.
    """

    def __init__(self, root_dir: str, captions_file: str, transform: Any = None, freq_threshold: int = 5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, caption_ids)``, where ``caption_ids`` is ``<SOS> + tokens + <EOS>``."""
        # Positional access, independent of the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_path = os.path.join(self.root_dir, self.imgs.iloc[index])

        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [
            self.vocab.stoi["<SOS>"],
            *self.vocab.numericalize(caption),
            self.vocab.stoi["<EOS>"],
        ]
        return img, torch.tensor(numericalized_caption)

class MyCollate:
    """Batch collation: stacks images and right-pads caption id tensors.

    Args:
        pad_idx: Token id written into padded caption positions.
    """

    def __init__(self, pad_idx: int):
        self.pad_idx = pad_idx

    def __call__(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
        # All images must share one shape; they are stacked along a new batch dim.
        images = torch.stack([sample[0] for sample in batch], dim=0)

        # Captions vary in length, so pad them to the longest one (batch-first).
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return images, captions

def get_loader(
    root_folder: str,
    annotation_file: str,
    transform: Any,
    batch_size: int = 32,
    num_workers: int = 8,
    shuffle: bool = True,
    pin_memory: bool = True,
    freq_threshold: int = 5,
) -> Tuple[DataLoader, FlickrDataset]:
    """Build a FlickrDataset and a DataLoader that pads captions within each batch.

    Args:
        root_folder: Image directory passed to FlickrDataset.
        annotation_file: Captions CSV passed to FlickrDataset.
        transform: Image transform passed to FlickrDataset.
        batch_size, num_workers, shuffle, pin_memory: Forwarded to DataLoader.
        freq_threshold: Minimum word count for the vocabulary, forwarded to
            FlickrDataset (5 matches FlickrDataset's own default).

    Returns:
        ``(loader, dataset)``. The dataset is returned so callers can reach ``dataset.vocab``.
    """
    dataset = FlickrDataset(
        root_folder,
        annotation_file,
        transform=transform,
        freq_threshold=freq_threshold,
    )
    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import math

class EncoderCNN(nn.Module):
    """ResNet-101 backbone that returns one projected feature vector per spatial location.

    The last two children of torchvision's ResNet (avgpool and fc) are removed,
    so the backbone yields a 2048-channel feature map. That map is flattened into
    a sequence and projected to ``embed_size`` with a Linear + ReLU.

    Args:
        embed_size: Width of the output feature vectors.
        train_CNN: If False, no gradients are computed through the backbone.
    """

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)

        # Keep the spatial feature map: drop the pooling and classification head.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        self.embed = nn.Linear(2048, embed_size)
        # NOTE(review): constructed but never applied in forward().
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.train_CNN = train_CNN

    def forward(self, images):
        """Encode ``images`` (batch, 3, H, W) into (batch, H'*W', embed_size) features."""
        # Build a graph through the backbone only if it is trainable AND the
        # caller has autograd on; set_grad_enabled(True) alone would re-enable
        # autograd inside an outer torch.no_grad() block (e.g. at inference).
        backbone_grad = bool(self.train_CNN) and torch.is_grad_enabled()
        with torch.set_grad_enabled(backbone_grad):
            feature_map = self.resnet(images)

        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        regions = feature_map.permute(0, 2, 3, 1)
        regions = regions.reshape(regions.size(0), -1, regions.size(3))

        return self.relu(self.embed(regions))

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    The table is stored as a ``(max_len, 1, d_model)`` buffer. ``forward``
    therefore expects ``x`` shaped ``(seq_len, batch, d_model)`` with
    ``seq_len <= max_len`` and broadcasts the table over the batch dimension.
    Even columns hold sines and odd columns hold cosines; odd ``d_model`` is supported.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions."""
        return x + self.pe[:x.size(0), :]

class DecoderTransformer(nn.Module):
    """Causal Transformer decoder that attends over encoder features.

    ``nn.TransformerDecoderLayer`` is built with its default sequence-first
    layout, so tensors are permuted internally. ``forward`` takes batch-first
    captions and returns batch-first logits.
    """

    def __init__(self, embed_size, vocab_size, num_heads=4, num_layers=2, max_len=100):
        super(DecoderTransformer, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)

        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        self.linear = nn.Linear(embed_size, vocab_size)
        self.max_len = max_len

    def _embed_tokens(self, tokens):
        """Embed (batch, seq) token ids, scale by sqrt(d_model), and return (seq, batch, d_model) with positions added."""
        emb = self.embed(tokens) * math.sqrt(self.embed.embedding_dim)
        return self.pos_encoder(emb.permute(1, 0, 2))

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: Encoder output; presumably (batch, regions, embed_size). It is
                permuted to sequence-first before use as decoder memory.
            captions: (batch, seq) token ids.

        Returns:
            (batch, seq, vocab_size) logits.
        """
        tgt = self._embed_tokens(captions)
        memory = features.permute(1, 0, 2)

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0), device=memory.device)
        outputs = self.transformer_decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask)

        return self.linear(outputs).permute(1, 0, 2)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an (sz, sz) float causal mask: 0.0 on and below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

    def caption_image(self, image_features, vocabulary, max_length=20):
        """Greedy-decode one caption.

        Assumes a single image: the start sequence has batch size 1 and each
        step reads one prediction via ``.item()``.

        Returns:
            The predicted words. This includes ``<EOS>`` if it was generated
            within ``max_length`` steps.
        """
        device = image_features.device
        memory = image_features.permute(1, 0, 2)
        generated = torch.tensor([[vocabulary.stoi["<SOS>"]]], device=device)

        result_caption = []
        for _ in range(max_length):
            tgt = self._embed_tokens(generated)
            mask = self.generate_square_subsequent_mask(tgt.size(0), device=device)
            out = self.transformer_decoder(tgt=tgt, memory=memory, tgt_mask=mask)

            # Only the newest position's output predicts the next token.
            logits = self.linear(out[-1])
            predicted_id = logits.argmax(1).item()

            word = vocabulary.itos[predicted_id]
            result_caption.append(word)
            if word == "<EOS>":
                break

            next_token = torch.tensor([[predicted_id]], device=device)
            generated = torch.cat((generated, next_token), dim=1)

        return result_caption

class ImageCaptionModel(nn.Module):
    """EncoderCNN + DecoderTransformer captioning model.

    ``hidden_size`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_heads, num_layers, train_CNN=False):
        super(ImageCaptionModel, self).__init__()
        self.encoder = EncoderCNN(embed_size, train_CNN)
        self.decoder = DecoderTransformer(embed_size, vocab_size, num_heads, num_layers)

    def forward(self, images, captions):
        """Return decoder logits for ``captions`` conditioned on ``images``."""
        features = self.encoder(images)
        return self.decoder(features, captions)

    def caption_image(self, image, vocabulary, max_length=20):
        """Greedy-caption a single unbatched ``image`` without gradients.

        The model is switched to eval mode for decoding. Its previous
        train/eval mode is restored afterwards, so calling this in the middle
        of training does not leave dropout disabled.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                features = self.encoder(image.unsqueeze(0))
                return self.decoder.caption_image(features, vocabulary, max_length)
        finally:
            self.train(was_training)

# Smoke test: push a random batch through the model and check the logits shape.
if __name__ == "__main__":
    embed_size = 256
    vocab_size = 1000
    seq_len = 20
    model = ImageCaptionModel(embed_size, 256, vocab_size, num_heads=4, num_layers=2)

    img = torch.randn(2, 3, 224, 224)
    # Token ids must stay below vocab_size to index the embedding table.
    caps = torch.randint(0, vocab_size, (2, seq_len))

    # Forward-only check; no autograd graph is needed.
    with torch.no_grad():
        out = model(img, caps)
    print(f"Transformer Output Shape: {out.shape}")
    assert tuple(out.shape) == (2, seq_len, vocab_size)
