# Imports

In [None]:
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import copy
import os
import random
import re
from collections import Counter, defaultdict

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from albumentations.pytorch import ToTensorV2
from nltk.translate.bleu_score import corpus_bleu
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Subset

  check_for_updates()


# Dataset

In [None]:
df = pd.read_csv('/content/drive/MyDrive/data (1)/captions.csv')
# CaptDataset groups rows by 'image' and reads 'caption'; fail early if the CSV differs.
assert {'image', 'caption'}.issubset(df.columns), f'unexpected columns: {list(df.columns)}'
df.head()

In [None]:
class TextProcessor:
  """Normalize caption text: lowercase, expand contractions, strip punctuation, lemmatize.

  `preprocess` results are memoized per input string (the cache grows with the
  number of distinct texts), because the same captions are preprocessed again
  each time a dataset sample is numericalized and spaCy is the expensive step.
  """

  def __init__(self):
    # spaCy v3's rule-based lemmatizer needs POS tags, which en_core_web_sm derives
    # from the `tagger` output (mapped by `attribute_ruler`). With the tagger disabled,
    # every lemma silently falls back to the lower-cased word (warning W108), so only
    # the unused parser and NER components are disabled.
    self.nlp = spacy.load('en_core_web_sm', disable = ['parser', 'ner'])
    self.contractions = {
        "n't": ' not', "'ll": ' will', "'ve": ' have',
        "'re": ' are', "'d": ' would', "'m": ' am'
    }
    # Anything that is not an ASCII letter, digit or whitespace is replaced by a space.
    self.special_chars = re.compile(r'[^a-zA-Z0-9\s]')
    # raw text -> preprocessed string
    self._cache = {}

  def _normalize(self, text):
    """Lowercase `text`, expand contractions and blank out special characters."""
    text = text.lower()
    for cont, repl in self.contractions.items():
      text = text.replace(cont, repl)
    return self.special_chars.sub(' ', text)

  def preprocess(self, text):
    """Return the normalized, lemmatized tokens of `text` joined by spaces.

    Args:
      text: raw caption. Non-string input (e.g. NaN from an empty CSV cell, or
        None) yields '' instead of raising.

    Returns:
      The processed string, computed once per distinct input and then cached.
    """
    if not isinstance(text, str):
      return ''
    # setdefault also covers instances created without __init__.
    cache = self.__dict__.setdefault('_cache', {})
    if text not in cache:
      doc = self.nlp(self._normalize(text))
      # '-PRON-' is the placeholder lemma spaCy v2 used for pronouns; keep the word itself.
      tokens = [token.lemma_ if token.lemma_ != '-PRON-' else token.text for token in doc]
      cache[text] = ' '.join(tokens).strip()
    return cache[text]

In [None]:
class Vocab:
  """Word <-> index mapping for captions, plus a freshly initialised embedding table.

  Indices 0-3 are reserved for <PAD>, <SOS>, <EOS>, <UNK>. `build_vocab` assigns
  consecutive indices from 4 to every preprocessed token seen at least
  `freq_threshold` times, in first-seen order, and then creates `self.embeddings`
  as a torch.nn.Embedding of shape (len(itos), embedding_dim).

  Preprocessing is memoized per raw string, so `numericalize` on a caption that
  `build_vocab` already saw does not run the text processor again.
  """

  def __init__(self, freq_threshold = 3, embedding_dim = 300):
    self._reset_mappings()
    self.freq_threshold = freq_threshold
    self.embedding_dim = embedding_dim
    self.embeddings = None
    self.text_processor = TextProcessor()
    # raw text -> list of preprocessed tokens (see _tokenize)
    self._token_cache = {}

  def _reset_mappings(self):
    """Reset itos/stoi so that they contain only the four special tokens."""
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {v: k for k, v in self.itos.items()}

  def _tokenize(self, text):
    """Return the preprocessed tokens of `text`, computed at most once per raw string.

    The returned list is the cached object; callers must not mutate it.
    """
    # setdefault keeps instances built without __init__ working.
    cache = self.__dict__.setdefault('_token_cache', {})
    if text not in cache:
      cache[text] = self.text_processor.preprocess(text).split()
    return cache[text]

  def build_vocab(self, sentence_list):
    """Index every token that occurs at least `freq_threshold` times in `sentence_list`.

    Rebuilding starts again from the special tokens, so calling this twice does
    not leave words from the previous corpus behind. Re-creates `self.embeddings`.
    """
    counter = Counter()
    for sentence in sentence_list:
      counter.update(self._tokenize(sentence))

    self._reset_mappings()
    for word, count in counter.items():
      if count >= self.freq_threshold:
        idx = len(self.itos)  # first corpus word gets 4
        self.itos[idx] = word
        self.stoi[word] = idx
    self._init_embeddings()

  def _init_embeddings(self):
    """Create a Xavier-initialised embedding table; special-token rows get U(-0.1, 0.1)."""
    vocab_size = len(self.itos)
    self.embeddings = torch.nn.Embedding(vocab_size, self.embedding_dim)
    torch.nn.init.xavier_uniform_(self.embeddings.weight)
    special_indices = [self.stoi[tok] for tok in ['<PAD>', '<SOS>', '<EOS>', '<UNK>']]
    with torch.no_grad():
      for idx in special_indices:
        self.embeddings.weight[idx].uniform_(-0.1, 0.1)

  def numericalize(self, text, max_length = 30):
    """Map `text` to token indices; out-of-vocabulary words map to <UNK>.

    At most `max_length - 2` tokens are kept (never a negative count), leaving
    room for the <SOS>/<EOS> markers the caller adds.
    """
    unk_idx = self.stoi['<UNK>']
    tokens = self._tokenize(text)[:max(max_length - 2, 0)]
    return [self.stoi.get(token, unk_idx) for token in tokens]

In [None]:
class CaptDataset(Dataset):
  """Image-captioning dataset with one sample per image file that exists on disk.

  The CSV must have an 'image' column (file name relative to `image_dir`) and a
  'caption' column; several rows may share an image. A Vocab is built over all
  captions at construction time. `__getitem__` returns
  `(image, caption_indices)`, where the caption is drawn at random from that
  image's captions and wrapped as <SOS> ... <EOS>. It returns None when the file
  exists but cannot be decoded; MCaptionCollate drops None samples.

  `max_caption_lenght` (sic) keeps its original spelling for backward compatibility.
  """

  def __init__(self, image_dir, csv_path, transform = None, freq_threshold = 3, max_caption_lenght = 30):
    self.image_dir = image_dir
    self.df = pd.read_csv(csv_path)
    self.transform = transform
    self.max_caption_lenght = max_caption_lenght

    # image file name -> its captions, in CSV order
    self.image_captions = defaultdict(list)
    for image_name, caption in zip(self.df['image'], self.df['caption']):
      self.image_captions[image_name].append(caption)

    self.images = list(self.image_captions.keys())

    all_captions = [cap for caps in self.image_captions.values() for cap in caps]

    self.vocab = Vocab(freq_threshold)
    self.vocab.build_vocab(all_captions)
    self.valid_imgs = self._validate()

  def _validate(self):
    """Return the image names whose files exist under `image_dir`."""
    return [i for i in self.images if os.path.exists(os.path.join(self.image_dir, i))]

  def __len__(self):
    return len(self.valid_imgs)

  def __getitem__(self, idx):
    img_name = self.valid_imgs[idx]
    caption = random.choice(self.image_captions[img_name])
    img = cv2.imread(os.path.join(self.image_dir, img_name))
    if img is None:
      # The file exists but OpenCV could not decode it (corrupt / not an image).
      return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if self.transform:
      img = self.transform(image = img)['image']

    caption_vec = [self.vocab.stoi['<SOS>']]
    caption_vec += self.vocab.numericalize(caption, self.max_caption_lenght)
    caption_vec.append(self.vocab.stoi['<EOS>'])

    return img, torch.tensor(caption_vec)

In [None]:
class MCaptionCollate:
  """Collate (image, caption) samples into one padded batch dict.

  None samples are skipped. The result has 'images' (stacked tensors),
  'captions' (batch-first, right-padded with `pad_idx`) and 'masks'
  (float, 1.0 on real tokens and 0.0 on padding).
  """

  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    kept = [sample for sample in batch if sample is not None]
    images = torch.stack([sample[0] for sample in kept])
    padded = pad_sequence([sample[1] for sample in kept], batch_first = True, padding_value = self.pad_idx)
    mask = padded.ne(self.pad_idx).float()
    return dict(images = images, captions = padded, masks = mask)

In [None]:
def get_transforms(augment = False, target_size=224):
    """Build the albumentations pipeline for captioning images.

    Every image is resized so its longest side equals `target_size`, padded with
    black to `target_size` x `target_size`, normalized with ImageNet mean/std and
    converted to a tensor via ToTensorV2. With `augment=True`, flip, colour,
    affine, blur and noise augmentations run first, on the original resolution.

    Note: HorizontalFlip does not touch the captions, so a "left"/"right" in a
    caption can become wrong for a flipped image.
    """
    # No fill argument: 0 (black) is the default constant border in albumentations
    # 1.x (`value`) and 2.x (`fill`), while passing `value=` to recent versions only
    # triggers an invalid-argument warning.
    base_transforms = [
        A.LongestMaxSize(max_size=target_size, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(min_height=target_size, min_width=target_size,
                      border_mode=cv2.BORDER_CONSTANT),
    ]

    aug_transforms = []
    if augment:
        aug_transforms = [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=15,
                               p=0.5, border_mode=cv2.BORDER_CONSTANT),
            A.GaussianBlur(p=0.1),
            A.ISONoise(p=0.1),
        ]

    final_transforms = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]

    return A.Compose(aug_transforms + base_transforms + final_transforms)

In [None]:
dataset = CaptDataset(
    '/content/drive/MyDrive/data (1)/Images',
    '/content/drive/MyDrive/data (1)/captions.csv',
    transform=get_transforms(augment=True),
    freq_threshold=3,
    max_caption_lenght=30,
)

  A.PadIfNeeded(min_height=target_size, min_width=target_size,
  original_init(self, **validated_kwargs)
  A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=15,


KeyboardInterrupt: 

In [None]:
collate_fn = MCaptionCollate(dataset.vocab.stoi['<PAD>'])

In [None]:
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

In [None]:
def visualize_dataloader(dataloader, vocab, num_samples=4, denormalize=True):
    """
    Visualize samples from the dataloader with captions.

    Draws up to `num_samples` images from the first batch of `dataloader` in a
    grid with at most 4 columns. Each image is titled with its decoded caption,
    with <SOS>, <EOS> and <PAD> removed.

    Args:
        dataloader: iterable yielding dicts with 'images' and 'captions'
            (token-index tensors), e.g. MCaptionCollate output.
        vocab: object exposing `itos` and `stoi` mappings.
        num_samples: maximum number of images to draw; fewer are drawn if the
            batch is smaller, and nothing is drawn if it is <= 0.
        denormalize: undo the ImageNet mean/std normalization before plotting.

    Raises:
        ValueError: if an image does not squeeze to a 2-D or 3-D array.
    """
    batch = next(iter(dataloader))
    images = batch['images'][:num_samples]
    captions = batch['captions'][:num_samples]
    # Lay out only the images that actually exist in this batch.
    n_show = min(num_samples, len(images))
    if n_show <= 0:
        return

    cols = min(4, n_show)
    rows = (n_show + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    axes = axes.flatten()

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    skip_ids = {vocab.stoi["<SOS>"], vocab.stoi["<EOS>"], vocab.stoi["<PAD>"]}

    for ax, image, caption in zip(axes, images, captions):
        image = image.detach().cpu()
        if image.dim() == 4:
            image = image.squeeze(0)
        if denormalize:
            image = torch.clamp(image * std + mean, 0, 1)

        np_image = image.numpy().squeeze()
        if np_image.ndim not in (2, 3):
            raise ValueError(f"Unexpected image shape: {np_image.shape}")
        if np_image.ndim == 3 and np_image.shape[0] == 3:
            np_image = np_image.transpose(1, 2, 0)  # CHW -> HWC for imshow
        np_image = np.clip(np_image, 0, 1)

        caption_str = ' '.join(vocab.itos[t] for t in caption.tolist() if t not in skip_ids)

        ax.imshow(np_image)
        ax.set_title(caption_str, wrap=True, fontsize=8)
        ax.axis('off')

    # Hide any grid cells left empty in the last row.
    for ax in axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
visualize_dataloader(dataloader=train_loader, vocab=dataset.vocab, num_samples=8)



KeyboardInterrupt: 

In [None]:
def split_dataset(dataset, test_size=0.15, val_size=0.15, random_state=42):
    """Split `dataset` into shuffled train/val/test `Subset`s.

    Args:
        dataset: any sized dataset (supports len()).
        test_size: fraction of all items held out for test, in [0, 1). With 0,
            no test split is made.
        val_size: fraction of all items used for validation (>= 0). It is
            re-expressed as a fraction of the train+val pool. If that relative
            fraction would reach 1, it falls back to 0.5, or to 0 for a tiny pool.
        random_state: seed forwarded to sklearn's train_test_split.

    Returns:
        dict with keys 'train', 'val', 'test' mapping to torch Subsets
        (all empty for an empty dataset).

    Raises:
        ValueError: if test_size is outside [0, 1) or val_size is negative.
    """
    if not 0 <= test_size < 1:
        raise ValueError('test_size must be in the range [0, 1)')
    if val_size < 0:
        raise ValueError('val_size must be non-negative')

    indices = list(range(len(dataset)))
    if not indices:
        return {name: Subset(dataset, []) for name in ('train', 'val', 'test')}

    # val_size is relative to the whole dataset; convert it to a fraction of
    # the train+val pool that remains after the test split.
    val_size_relative = val_size / (1 - test_size)
    if val_size_relative >= 1.0:
        pool_size = len(indices) * (1 - test_size)
        fallback = 0.5 if pool_size * (1 - 0.5) > 1 else 0.0
        print(f'val_size ({val_size}) is too large for test_size ({test_size}); '
              f'using {fallback} of the remaining items for validation instead')
        val_size_relative = fallback

    if test_size > 0:
        train_val_idx, test_idx = train_test_split(
            indices,
            test_size=test_size,
            random_state=random_state,
            shuffle=True
        )
    else:
        # sklearn rejects test_size=0, so skip the test split entirely.
        train_val_idx, test_idx = indices, []

    if not train_val_idx or not 0 < val_size_relative < 1:
        train_idx, val_idx = train_val_idx, []
    else:
        train_idx, val_idx = train_test_split(
            train_val_idx,
            test_size=val_size_relative,
            random_state=random_state,
            shuffle=True
        )

    return {
        'train': Subset(dataset, train_idx),
        'val': Subset(dataset, val_idx),
        'test': Subset(dataset, test_idx)
    }

In [None]:
splits = split_dataset(dataset)
print('Train: {}, Val: {}, Test: {}'.format(*(len(splits[part]) for part in ('train', 'val', 'test'))))

In [None]:
# Only the training split is shuffled; val/test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(splits[part], batch_size=16, shuffle=(part == 'train'), collate_fn=collate_fn)
    for part in ('train', 'val', 'test')
)

In [None]:
visualize_dataloader(dataloader=val_loader, vocab=dataset.vocab, num_samples=4)

# Models

In [None]:
# nltk (used for BLEU scoring in the Evaluate section) is required; if it is missing, run: %pip install nltk
import torch.nn as nn
class EncoderCNN(nn.Module):
    """Frozen ImageNet ResNet-50 backbone followed by a trainable linear projection.

    forward(images) returns one `embed_size` vector per image. The backbone's
    parameters have requires_grad=False and it runs under no_grad. It is also
    kept in eval mode even when the encoder is put in training mode, so its
    BatchNorm running statistics are not overwritten by caption-training
    batches. Only `self.fc` has trainable parameters.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()

        resnet = self._load_pretrained_resnet50()

        for param in resnet.parameters():
            param.requires_grad = False

        # Drop the final classification layer; keep everything through average pooling.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.resnet.eval()

    @staticmethod
    def _load_pretrained_resnet50():
        """Load ResNet-50 with the ImageNet weights that `pretrained=True` selects (IMAGENET1K_V1)."""
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:  # torchvision < 0.13 has no weights enums
            return models.resnet50(pretrained=True)
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    def train(self, mode=True):
        """Set training mode on the module but always keep the frozen backbone in eval mode."""
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return self.fc(features)

In [None]:
class DecoderRNN(nn.Module):
    """Single-layer LSTM caption decoder conditioned on an image feature vector.

    The image features set the initial LSTM hidden and cell states through the
    `init_h` / `init_c` projections. The word-embedding table starts as a copy
    of `vocab.embeddings` and is trained from there. `forward` runs teacher
    forcing over `captions[:, :-1]` and returns vocabulary logits for every
    input position.
    """

    def __init__(self, embed_size, hidden_size, vocab):
        super().__init__()
        self.vocab = vocab

        init_weights = vocab.embeddings.weight.data
        vocab_size, weight_dim = init_weights.shape
        assert weight_dim == embed_size, "Embedding dimension should match."

        # Creation order decides which random draws each layer's init consumes;
        # keep it fixed for reproducible initialisation.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.embed.weight.data.copy_(init_weights)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def init_hidden_states(self, features):
        """Project image features to the LSTM's initial (h, c).

        Assuming features is (batch, embed_size), each state is (1, batch, hidden_size).
        """
        return self.init_h(features).unsqueeze(0), self.init_c(features).unsqueeze(0)

    def forward(self, features, captions):
        state = self.init_hidden_states(features)
        lstm_out, _ = self.lstm(self.embed(captions[:, :-1]), state)
        return self.linear(lstm_out)


# Training

In [None]:
def train_model(num_epochs, train_loader, encoder, decoder, criterion, optimizer, device,
                max_grad_norm=None):
    """Teacher-forced training loop for the encoder/decoder captioning model.

    Each batch is a dict with 'images' and 'captions'. Images go through
    `encoder`, then `decoder(features, captions)`. The decoder output at
    position t is scored by `criterion` against captions[:, t + 1].

    Args:
        num_epochs: number of passes over `train_loader`.
        train_loader: iterable of batch dicts; its length is not required.
        encoder, decoder: modules, both switched to train mode.
        criterion: loss over (N, vocab_size) logits and (N,) targets.
        optimizer: steps whatever parameters it was constructed with.
        device: device the batch tensors are moved to.
        max_grad_norm: if not None, clip the global gradient norm of the
            optimizer's parameters to this value before every step.

    Returns:
        List with the mean batch loss of each epoch (NaN for an epoch with no batches).
    """
    encoder.train()
    decoder.train()
    clip_params = [p for group in optimizer.param_groups for p in group['params']]
    history = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            images   = batch['images'].to(device)
            captions = batch['captions'].to(device)

            optimizer.zero_grad()
            features = encoder(images)
            outputs  = decoder(features, captions)

            # Targets are the captions shifted left by one token (next-token prediction).
            targets = captions[:, 1:]
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)),
                targets.reshape(-1)
            )

            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(clip_params, max_grad_norm)
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        history.append(mean_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")
    return history

# Evaluate

In [None]:
def evaluate(encoder, decoder, test_loader, dataset, device, max_len=50):
    """Greedy-decode a caption for every image in `test_loader` and print corpus BLEU-1..4.

    Decoding starts from <SOS>, with the LSTM state initialised from the image
    features by `decoder.init_hidden_states`. It stops at <EOS> or after
    `max_len` steps, and all images of a batch are decoded together. Each
    hypothesis is scored against the single caption the loader returned for
    that image, with <SOS>/<EOS>/<PAD> removed from both sides.

    Args:
        encoder, decoder: trained models (switched to eval mode here).
        test_loader: yields dicts with 'images' and 'captions' (padded index tensors).
        dataset: object whose `.vocab` has `stoi` / `itos` mappings.
        device: device the images are moved to.
        max_len: maximum number of decoding steps per caption.

    Returns:
        dict with keys 'bleu1'..'bleu4' (all 0.0 when the loader yields nothing).

    Raises:
        KeyError: if no start/end/pad token is found under any known spelling.
    """
    encoder.eval()
    decoder.eval()
    vocab = dataset.vocab

    def _lookup(*names):
        for name in names:
            if name in vocab.stoi:
                return vocab.stoi[name]
        raise KeyError(f'none of {names} found in vocab')

    # Vocab uses <SOS>/<EOS>/<PAD>; other spellings are accepted as fallbacks.
    start_idx = _lookup('<SOS>', '<start>', '<START>')
    end_idx = _lookup('<EOS>', '<end>', '<END>')
    pad_idx = _lookup('<PAD>', '<pad>')
    special_ids = {start_idx, end_idx, pad_idx}

    references = []
    hypotheses = []

    with torch.no_grad():
        for batch in test_loader:
            images   = batch['images'].to(device)
            captions = batch['captions']
            batch_size = images.size(0)

            features = encoder(images)
            h, c = decoder.init_hidden_states(features)
            inputs = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
            steps = []
            for _ in range(max_len):
                out, (h, c) = decoder.lstm(decoder.embed(inputs), (h, c))
                predicted = decoder.linear(out.squeeze(1)).argmax(dim=1)
                steps.append(predicted)
                finished |= predicted == end_idx
                if bool(finished.all()):
                    break
                inputs = predicted.unsqueeze(1)

            if steps:
                generated = torch.stack(steps, dim=1).tolist()
            else:
                generated = [[] for _ in range(batch_size)]

            for i, seq in enumerate(generated):
                hypothesis = []
                for token_id in seq:
                    if token_id == end_idx:
                        break
                    if token_id in special_ids:
                        continue
                    hypothesis.append(vocab.itos[token_id])
                hypotheses.append(hypothesis)

                ref_tokens = [
                    vocab.itos[idx] for idx in captions[i].tolist()
                    if idx not in special_ids
                ]
                references.append([ref_tokens])

    if not hypotheses:
        print("No samples to evaluate.")
        return {'bleu1': 0.0, 'bleu2': 0.0, 'bleu3': 0.0, 'bleu4': 0.0}

    weights = {
        'bleu1': (1.0, 0, 0, 0),
        'bleu2': (0.5, 0.5, 0, 0),
        'bleu3': (1 / 3, 1 / 3, 1 / 3, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }
    scores = {name: corpus_bleu(references, hypotheses, weights=w) for name, w in weights.items()}
    bleu1, bleu2, bleu3, bleu4 = (scores[k] for k in ('bleu1', 'bleu2', 'bleu3', 'bleu4'))
    print(f"BLEU-1: {bleu1:.4f}, BLEU-2: {bleu2:.4f}, BLEU-3: {bleu3:.4f}, BLEU-4: {bleu4:.4f}")
    return scores

# Main

In [None]:
num_epochs     = 5
learning_rate  = 1e-3
hidden_size    = 512
batch_size     = 32
embed_size     = 300  # must match Vocab.embedding_dim (DecoderRNN asserts this)
device         = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs the pipeline draws from: `random` (caption choice in CaptDataset),
# torch (model init, DataLoader shuffling) and numpy (possibly used by augmentations).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import os
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict, Counter

import spacy
import re
import random
import torch.optim as optim
from nltk.translate.bleu_score import corpus_bleu
import torch.nn as nn
import torchvision.models as models

In [None]:
train_transform = get_transforms(augment=True,  target_size=224)
eval_transform  = get_transforms(augment=False, target_size=224)

dataset = CaptDataset(
        image_dir          = '/content/drive/MyDrive/data (1)/Images',
        csv_path           = '/content/drive/MyDrive/data (1)/captions.csv',
        transform          = train_transform,
        freq_threshold     = 3,
        max_caption_lenght = 30
)
collate_fn = MCaptionCollate(pad_idx=dataset.vocab.stoi['<PAD>'])

splits = split_dataset(dataset, test_size=0.15, val_size=0.15, random_state=42)

# All Subsets returned by split_dataset wrap the same `dataset` object, so setting
# `.dataset.transform` on the val/test subsets would also switch the training subset
# to the non-augmented pipeline. Val/test read from a shallow copy instead: it shares
# the vocab and caption index but owns its transform.
eval_dataset = copy.copy(dataset)
eval_dataset.transform = eval_transform

train_ds = splits['train']
val_ds   = Subset(eval_dataset, splits['val'].indices)
test_ds  = Subset(eval_dataset, splits['test'].indices)
assert train_ds.dataset.transform is train_transform
assert val_ds.dataset.vocab is train_ds.dataset.vocab

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

  A.PadIfNeeded(min_height=target_size, min_width=target_size,
  original_init(self, **validated_kwargs)
  A.ShiftScaleRotate(shift_limit=0.06, scale_limit=0.1, rotate_limit=15,


In [None]:
vocab = dataset.vocab
encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab).to(device)


pad_idx  = vocab.stoi['<PAD>']
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# EncoderCNN freezes its ResNet backbone, but its projection layer `fc` is freshly
# initialised and trainable. It must be optimised together with the decoder; otherwise
# the decoder learns on a fixed random projection of the image features.
trainable_params = [p for p in encoder.parameters() if p.requires_grad]
trainable_params += list(decoder.parameters())
optimizer = optim.Adam(trainable_params, lr=learning_rate)


print("Starting training...")
train_model(num_epochs, train_loader, encoder, decoder, criterion, optimizer, device)
print("Training complete. Saving models...")

torch.save(encoder.state_dict(), "encoder.pth")
torch.save(decoder.state_dict(), "decoder.pth")
print("Models saved to encoder.pth and decoder.pth")

print("Loading models for test evaluation...")
encoder.load_state_dict(torch.load("encoder.pth", map_location=device))
decoder.load_state_dict(torch.load("decoder.pth", map_location=device))

print("Evaluating on test set (BLEU-1 to BLEU-4)...")
evaluate(encoder, decoder, test_loader, dataset, device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 118MB/s]


Starting training...
