In [None]:
import os
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms
import re
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import matplotlib.pyplot as plt

# Query CUDA availability once and reuse the answer for both the device choice
# and the diagnostic printout below.
cuda_ready = torch.cuda.is_available()
device = torch.device("cuda") if cuda_ready else torch.device("cpu")

print(f"Using device: {device}")
print(f"CUDA available: {cuda_ready}")

# GPU details are only meaningful (and only queryable) when CUDA is usable.
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Current Device Name: {gpu_name}")
    print(f"CUDA Version: {torch.version.cuda}")

Using device: cuda
CUDA available: True
Current Device Name: NVIDIA GeForce RTX 4070 Laptop GPU
CUDA Version: 12.1


In [None]:
class Vocabulary:
    """Word-level vocabulary mapping caption tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``. Every other word receives the next free id as
    soon as it has been seen ``freq_threshold`` times while building.

    Attributes:
        itos: dict mapping id -> token string.
        stoi: dict mapping token string -> id.
        freq_threshold: minimum occurrence count for a word to get an id.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary that contains only the four special tokens.

        Args:
            freq_threshold: Minimum number of occurrences (within one
                ``build_vocabulary`` call) a word needs before it is added.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into lowercase tokens.

        The characters ``. , ! ? " ' -`` are surrounded with spaces first so
        that each of them becomes a token of its own; the result is then split
        on whitespace.

        Args:
            text: The sentence to tokenize.

        Returns:
            A list of token strings.
        """
        text = text.lower()
        text = re.sub(r'([.,!?"\'-])', r' \1 ', text)
        return text.split()

    def build_vocabulary(self, sentence_list):
        """Add every sufficiently frequent word of ``sentence_list``.

        Words are numbered in the order in which they reach the frequency
        threshold. Frequencies are counted per call; words that already have
        an id keep it.

        Args:
            sentence_list: Iterable of sentences (strings) to scan.
        """
        frequencies = {}
        # Continue after the ids that are already taken instead of restarting
        # at 4, so calling this method again never overwrites existing entries.
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                count = frequencies.get(word, 0) + 1
                frequencies[word] = count

                # ">=" plus the membership test (rather than "==") keeps a
                # threshold below 1 from silently producing an empty
                # vocabulary and guarantees each word is registered once.
                if count >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Args:
            text: The sentence to encode.

        Returns:
            A list of ints, one per token; tokens without an id are mapped to
            the id of ``<UNK>``.
        """
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer_eng(text)]

In [31]:
class FlickrDataset(Dataset):
    """Image/caption pairs described by a captions CSV file.

    The CSV is expected to provide an ``image`` column (file names) and a
    ``caption`` column; image files are read from ``<root_dir>/images``.
    A ``Vocabulary`` is built from all captions on construction.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        """Load the captions table and build the vocabulary.

        Args:
            root_dir: Directory that contains the ``images`` sub-directory.
            captions_file: Path of the CSV file with the captions.
            transform: Optional callable applied to each loaded RGB image.
            freq_threshold: Frequency threshold forwarded to ``Vocabulary``.
        """
        self.root_dir = root_dir
        self.transform = transform

        self.df = pd.read_csv(captions_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Return the number of (image, caption) rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            A tuple ``(img, caption_ids)`` where ``img`` is the RGB image
            (after ``transform`` when one was given) and ``caption_ids`` is a
            1-D tensor holding ``<SOS>``, the caption token ids and ``<EOS>``.
        """
        caption_text = self.captions[index]
        image_path = os.path.join(self.root_dir, "images", self.imgs[index])

        img = Image.open(image_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # Wrap the encoded caption in the start / end markers.
        token_ids = [
            self.vocab.stoi["<SOS>"],
            *self.vocab.numericalize(caption_text),
            self.vocab.stoi["<EOS>"],
        ]
        return img, torch.tensor(token_ids)

In [32]:
class MyCollate:
    """Batch collator: stacks the images and pads the captions to equal length."""

    def __init__(self, pad_idx):
        """Remember the id used to fill captions up to the batch maximum.

        Args:
            pad_idx: Token id written into the padded caption positions.
        """
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Merge a list of ``(image, caption)`` samples into two tensors.

        Args:
            batch: Sequence of samples; element 0 of each sample is the image
                tensor and element 1 the 1-D caption tensor.

        Returns:
            ``(imgs, targets)``: the images stacked along a new first
            dimension, and the captions right-padded with ``pad_idx`` into a
            ``(batch, max_len)`` tensor.
        """
        image_tensors = []
        caption_tensors = []
        for sample in batch:
            image_tensors.append(sample[0])
            caption_tensors.append(sample[1])

        stacked_images = torch.stack(image_tensors, dim=0)
        padded_captions = pad_sequence(
            caption_tensors,
            batch_first=True,
            padding_value=self.pad_idx,
        )
        return stacked_images, padded_captions

In [33]:
# 1. Define transforms: resize to 224x224, convert to a tensor, then normalize
#    with the ImageNet channel mean / std. The encoder starts from
#    ImageNet-pretrained ResNet-50 weights, which were trained on inputs
#    normalized with exactly these statistics (not with 0.5 / 0.5).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# 2. Build the dataset (this also builds the vocabulary from all captions)
dataset = FlickrDataset(
    root_dir="caption_data",
    captions_file="caption_data/captions.txt",
    transform=transform
)

# Id of the <PAD> token, used to fill captions up to the batch maximum length
pad_idx = dataset.vocab.stoi["<PAD>"]

# 3. Create the loader; MyCollate pads the variable-length captions per batch
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=0,
    shuffle=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
)

# 4. Print stats to share with your team
print(f"Vocabulary Size: {len(dataset.vocab)}")

# Grab one batch to check shapes
images, captions = next(iter(loader))
print(f"Batch Image Shape: {images.shape}") # Expect: [32, 3, 224, 224]
print(f"Batch Caption Shape: {captions.shape}") # Expect: [32, Max_Len]
print("Example Caption (Numerical):", captions[0])

Vocabulary Size: 2994
Batch Image Shape: torch.Size([32, 3, 224, 224])
Batch Caption Shape: torch.Size([32, 21])
Example Caption (Numerical): tensor([   1,    4,   85,   28,    8,    4, 2136,  258,    5,    2,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0])


In [None]:
class EncoderCNN(nn.Module):
    """Image encoder: a pretrained ResNet-50 whose classifier is replaced by a
    linear projection to ``embed_size``, followed by ReLU and dropout.
    """

    def __init__(self, embed_size, train_CNN=False):
        """Build the encoder and choose which backbone weights are trainable.

        Args:
            embed_size: Size of the feature vector produced for each image.
            train_CNN: When ``False`` (default) only the ``layer4`` block of
                the ResNet and the new ``fc`` projection are trainable and
                the rest of the backbone is frozen. When ``True`` the whole
                backbone is trainable.
        """
        super().__init__()
        self.train_CNN = train_CNN
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

        # Freeze / unfreeze in a single pass. Previously `train_CNN` was
        # accepted but never read; the default keeps the earlier behaviour
        # (only layer4 is fine-tuned).
        for name, param in self.resnet.named_parameters():
            param.requires_grad = bool(train_CNN) or "layer4" in name

        # The replacement head is a fresh nn.Linear, so it is trainable
        # regardless of the flags set on the old classifier above.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image batch fed to the ResNet (the data pipeline in this
                notebook produces ``(batch, 3, 224, 224)`` tensors).

        Returns:
            Tensor of image features after ReLU and dropout; the last
            dimension has size ``embed_size``.
        """
        features = self.resnet(images)
        return self.dropout(self.relu(features))

In [35]:
class DecoderRNN(nn.Module):
    """Caption decoder: token embedding -> LSTM -> linear layer over the vocabulary."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the decoder layers.

        Args:
            embed_size: Size of the token embeddings (and of the image
                feature that is fed in as the first LSTM step).
            hidden_size: Hidden size of the LSTM.
            vocab_size: Number of tokens; size of the output logits.
            num_layers: Number of stacked LSTM layers.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        """Compute vocabulary logits for every caption position.

        Args:
            features: Image features, one ``embed_size`` vector per sample.
            captions: Batch-first tensor of caption token ids.

        Returns:
            Logits with one ``vocab_size`` vector per caption position: the
            image feature is the first LSTM input and the last caption token
            is dropped, so the sequence length equals that of ``captions``.
        """
        # Drop the final token so that, with the image step prepended, the
        # input sequence has the same length as the target captions.
        shifted_tokens = captions[:, :-1]
        token_vectors = self.dropout(self.embedding(shifted_tokens))

        image_step = features.unsqueeze(1)
        lstm_input = torch.cat((image_step, token_vectors), dim=1)

        lstm_out, _ = self.lstm(lstm_input)
        return self.linear(lstm_out)

In [36]:
class CNNtoRNN(nn.Module):
    """Captioning model that chains the image encoder and the caption decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        """Create the encoder and decoder sub-modules.

        Args:
            embed_size: Feature / embedding size shared by both sub-modules.
            hidden_size: Hidden size passed to the decoder.
            vocab_size: Vocabulary size passed to the decoder.
            num_layers: Number of recurrent layers passed to the decoder.
        """
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Encode ``images`` and return the decoder output for ``captions``."""
        return self.decoder(self.encoder(images), captions)

In [37]:
# Model hyperparameters
embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)
num_layers = 1

model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

# Sanity check: push one batch through the model and inspect the shapes.
images, captions = next(iter(loader))
images = images.to(device)
captions = captions.to(device)

# Only the output shape is needed here, so skip autograd: otherwise the global
# `outputs` would keep the whole computation graph alive in (GPU) memory.
with torch.no_grad():
    outputs = model(images, captions)

print(f"Images Shape: {images.shape}")    
print(f"Captions Shape: {captions.shape}")
print(f"Output Shape: {outputs.shape}")   
print("Model created successfully!")

Images Shape: torch.Size([32, 3, 224, 224])
Captions Shape: torch.Size([32, 30])
Output Shape: torch.Size([32, 30, 2994])
Model created successfully!


In [None]:
# Training hyperparameters
learning_rate = 3e-4
num_epochs = 25

# Padded caption positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Halve the learning rate when the monitored loss has not improved for more
# than `patience` epochs. The `verbose` argument is intentionally not passed:
# it is deprecated since PyTorch 2.2. The current learning rate can be read
# from `optimizer.param_groups[0]["lr"]`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

model = model.to(device)
model.train()

print("Training setup complete!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


Training setup complete!
Total parameters: 26094834
Trainable parameters: 17551538




In [None]:
print("Starting Training...")

for epoch in range(num_epochs):
    epoch_loss = 0.0
    model.train()

    for idx, (imgs, captions) in enumerate(loader):
        imgs = imgs.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()

        outputs = model(imgs, captions)

        # Flatten the logits to (N, vocab_size) and the targets to (N,) as
        # required by CrossEntropyLoss.
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        loss.backward()

        # Clip the global gradient norm before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        epoch_loss += loss.item()

        if idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{idx}/{len(loader)}], Loss: {loss.item():.4f}")

    avg_loss = epoch_loss / len(loader)
    print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")

    # Step the plateau scheduler on the epoch loss and report any change of
    # the learning rate explicitly, so the log does not depend on the
    # scheduler's own (deprecated) verbose output.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(avg_loss)
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    if (epoch + 1) % 5 == 0:
        # The scheduler state is stored as well so that a resumed run keeps
        # the current learning rate and plateau counters.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': avg_loss,
        }, f'checkpoint_epoch_{epoch+1}.pth')
        print(f"Checkpoint saved at epoch {epoch+1}")

print("Training Complete!")

Starting Training...
Epoch [1/25], Step [0/1265], Loss: 8.0077
Epoch [1/25], Step [100/1265], Loss: 4.3768
Epoch [1/25], Step [200/1265], Loss: 3.7562
Epoch [1/25], Step [300/1265], Loss: 3.7761
Epoch [1/25], Step [400/1265], Loss: 3.6746
Epoch [1/25], Step [500/1265], Loss: 3.4719
Epoch [1/25], Step [600/1265], Loss: 3.6555
Epoch [1/25], Step [700/1265], Loss: 3.5415
Epoch [1/25], Step [800/1265], Loss: 3.3130
Epoch [1/25], Step [900/1265], Loss: 3.5687
Epoch [1/25], Step [1000/1265], Loss: 3.1399
Epoch [1/25], Step [1100/1265], Loss: 3.2931
Epoch [1/25], Step [1200/1265], Loss: 3.2015
Epoch [1/25] Average Loss: 3.7499
Epoch [2/25], Step [0/1265], Loss: 3.1857
Epoch [2/25], Step [100/1265], Loss: 3.2379
Epoch [2/25], Step [200/1265], Loss: 3.1388
Epoch [2/25], Step [300/1265], Loss: 3.0965
Epoch [2/25], Step [400/1265], Loss: 3.3434
Epoch [2/25], Step [500/1265], Loss: 3.0490
Epoch [2/25], Step [600/1265], Loss: 3.1540
Epoch [2/25], Step [700/1265], Loss: 3.0476
Epoch [2/25], Step [80

In [None]:
import pickle

# Persist the weights the model holds at this point (i.e. after the last
# executed training step) under the name "best_model.pth".
final_weights = model.state_dict()
torch.save(final_weights, "best_model.pth")
print("Model weights saved to 'best_model.pth'")

# Pickle the vocabulary object next to the weights; unpickling it later
# requires the Vocabulary class to be defined/importable.
with open('vocab.pkl', 'wb') as vocab_file:
    pickle.dump(dataset.vocab, vocab_file)
print("Vocabulary saved to 'vocab.pkl'")

Model weights saved to 'best_model.pth'
Vocabulary saved to 'vocab.pkl'
