In [9]:
import os
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
import re
import torch.nn as nn
import torchvision.models as models

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

Using device: cpu


In [5]:
class Vocabulary:
    """Word-level vocabulary mapping caption tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other word receives the next free id once
    it has been seen ``freq_threshold`` times while building the vocabulary.

    Attributes are plain dicts: ``itos`` maps id -> token and ``stoi`` maps
    token -> id.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that initially holds only the special tokens.

        Args:
            freq_threshold: Number of occurrences a word needs before it is
                added to the vocabulary. Values of 1 or less admit every word.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of tokens (special tokens included)."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Lower-case ``text``, pad punctuation with spaces and split on whitespace.

        Args:
            text: The raw caption string.

        Returns:
            A list of lower-cased tokens; each of ``. , ! ? " ' -`` becomes
            its own token.
        """
        text = text.lower()
        text = re.sub(r'([.,!?"\'-])', r' \1 ', text)
        return text.split()

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word in ``sentence_list`` to the vocabulary.

        Word counts are local to one call. A word is registered the moment its
        count reaches ``freq_threshold``, so ids follow the order in which
        words cross the threshold.

        Args:
            sentence_list: Iterable of caption strings.
        """
        frequencies = {}
        # Continue after the ids already handed out instead of restarting at a
        # fixed offset, so a repeated call extends the vocabulary rather than
        # overwriting existing entries.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                if word in self.stoi:
                    # Already registered (earlier in this call or a previous one).
                    continue

                # ">=" (not "==") so thresholds of 0 or less still admit words.
                if frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Args:
            text: The raw caption string.

        Returns:
            One id per token; tokens missing from the vocabulary map to the
            ``<UNK>`` id. No ``<SOS>``/``<EOS>`` markers are added here.
        """
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [6]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an ``images`` folder.

    The CSV must provide an ``image`` column (file name relative to
    ``<root_dir>/images``) and a ``caption`` column. A ``Vocabulary`` is built
    from all captions when the dataset is constructed.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Load the captions table and build the vocabulary.

        Args:
            root_dir: Directory that contains the ``images`` sub-folder.
            captions_file: Path to the CSV with ``image`` and ``caption`` columns.
            transform: Optional callable applied to each loaded PIL image.
            freq_threshold: Minimum word frequency passed to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the transformed image and its numericalized caption.

        Args:
            index: Zero-based row position.

        Returns:
            A tuple ``(img, caption_ids)``. ``img`` is the RGB image after
            ``transform`` (or the PIL image when no transform is set), and
            ``caption_ids`` is a 1-D tensor ``[<SOS>, token ids..., <EOS>]``.
        """
        # Positional lookup: a sampler hands out positions, which only match
        # the labels while the frame keeps its default RangeIndex.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        img_path = os.path.join(self.root_dir, "images", img_id)
        # The context manager releases the file handle deterministically;
        # convert() returns an independent in-memory image that outlives it.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)

In [7]:
class MyCollate:
    """Collate function that batches images and pads captions to equal length."""

    def __init__(self, pad_idx):
        # Token id used to fill captions shorter than the longest in the batch.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two batched tensors.

        Returns:
            ``(imgs, targets)``: the images stacked along a new leading
            dimension, and the captions right-padded with ``pad_idx`` into a
            ``(batch, max_len)`` tensor.
        """
        image_tensors = []
        caption_tensors = []
        for sample in batch:
            image_tensors.append(sample[0])
            caption_tensors.append(sample[1])

        # All images must share one shape so they can be stacked.
        stacked_images = torch.stack(image_tensors, dim=0)
        padded_captions = pad_sequence(
            caption_tensors,
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return stacked_images, padded_captions

In [8]:
# 1. Define Transforms (Resize to 224x224 for ResNet, Convert to Tensor, Normalize)
# The encoder in this notebook is an ImageNet-pretrained ResNet-50 whose backbone
# is frozen by default, so inputs must be normalized with the ImageNet channel
# statistics those weights expect (a generic 0.5/0.5 normalization would feed
# the frozen backbone out-of-distribution inputs).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# 2. Setup Data Loader
dataset = FlickrDataset(
    root_dir="caption_data", 
    captions_file="caption_data/captions.txt", 
    transform=transform
)

# Look the padding index up from the vocabulary instead of assuming it is 0
pad_idx = dataset.vocab.stoi["<PAD>"]

# 3. Create the Loader
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=0,
    shuffle=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
)

# 4. Print Stats to share with your team
print(f"Vocabulary Size: {len(dataset.vocab)}")

# Grab one batch to check shapes
for images, captions in loader:
    print(f"Batch Image Shape: {images.shape}") # Expect: [32, 3, 224, 224]
    print(f"Batch Caption Shape: {captions.shape}") # Expect: [32, Max_Len]
    print("Example Caption (Numerical):", captions[0])
    break

Vocabulary Size: 2994
Batch Image Shape: torch.Size([32, 3, 224, 224])
Batch Caption Shape: torch.Size([32, 18])
Example Caption (Numerical): tensor([  1,  14,  16,  43, 684, 685,   8, 233,   5,   2,   0,   0,   0,   0,
          0,   0,   0,   0])


In [None]:
class EncoderCNN(nn.Module):
    """ResNet-50 image encoder that projects each image to an ``embed_size`` vector.

    The pretrained classifier head is replaced by a new linear layer, followed
    by ReLU and dropout.
    """

    def __init__(self, embed_size, train_CNN=False):
        """Build the encoder.

        Args:
            embed_size: Size of the feature vector produced per image.
            train_CNN: When False (default) the pretrained backbone is frozen
                and only the new projection layer is trainable; when True the
                whole network is trainable.
        """
        super(EncoderCNN, self).__init__()
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # The replacement ``fc`` layer is randomly initialized, so it must stay
        # trainable even when the backbone is frozen; otherwise the decoder
        # would be fed a fixed random projection for the whole of training.
        for name, param in self.resnet.named_parameters():
            param.requires_grad = train_CNN or name.startswith("fc.")

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image batch accepted by ResNet-50 (the notebook feeds
                ``(batch, 3, 224, 224)`` tensors).

        Returns:
            A ``(batch, embed_size)`` tensor after ReLU and dropout.
        """
        features = self.resnet(images)

        return self.dropout(self.relu(features))

In [None]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder that is fed the image feature as its first input step."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the embedding table, LSTM, output projection and dropout.

        Args:
            embed_size: Width of the token embeddings (and of the image feature).
            hidden_size: LSTM hidden state width.
            vocab_size: Number of tokens; sizes the embedding and output layers.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, features, captions):
        """Produce vocabulary logits for every position of ``captions``.

        Args:
            features: Image features; expected as ``(batch, embed_size)`` so
                they can be prepended as one extra time step.
            captions: ``(batch, seq_len)`` tensor of token ids.

        Returns:
            ``(batch, seq_len, vocab_size)`` logits.
        """
        # Drop the final token so that, once the image step is prepended, the
        # input sequence has the same length as the caption.
        input_tokens = captions[:, :-1]
        token_vectors = self.dropout(self.embedding(input_tokens))

        image_step = features.unsqueeze(1)
        lstm_inputs = torch.cat((image_step, token_vectors), dim=1)

        hidden_states, _ = self.lstm(lstm_inputs)
        return self.linear(hidden_states)

In [12]:
class CNNtoRNN(nn.Module):
    """Image-captioning model: an ``EncoderCNN`` feeding a ``DecoderRNN``."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Build the encoder/decoder pair with a shared ``embed_size``."""
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(
            embed_size,
            hidden_size,
            vocab_size,
            num_layers,
        )

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder's output for ``captions``."""
        image_features = self.encoder(images)
        return self.decoder(image_features, captions)

In [None]:
# Model hyperparameters; the vocabulary size is taken from the dataset's vocabulary.
embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)
num_layers = 1

model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers)
model = model.to(device)

# Smoke test: push a single batch through the model and report the shapes.
images, captions = (tensor.to(device) for tensor in next(iter(loader)))

outputs = model(images, captions)

print(f"Images Shape: {images.shape}")
print(f"Captions Shape: {captions.shape}")
print(f"Output Shape: {outputs.shape}")
print("Model created successfully!")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\Guga/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth


100.0%


Images Shape: torch.Size([32, 3, 224, 224])
Captions Shape: torch.Size([32, 25])
Output Shape: torch.Size([32, 25, 2994])
Model created successfully!
