In [1]:
import os
import csv
import string
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from torchvision import models
from PIL import Image
from collections import Counter
from torch.optim.lr_scheduler import ReduceLROnPlateau

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

DATASET_DIR = r"D:\Assignment3\Dataset"
IMAGE_FOLDER = os.path.join(DATASET_DIR, "Images")
CAPTIONS_FILE = os.path.join(DATASET_DIR, "captions.txt")

MODEL_SAVE_PATH = r"D:\Assignment3\Models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
print(f"Models will be saved to: {MODEL_SAVE_PATH}")

Using device: cuda
Models will be saved to: D:\Assignment3\Models


In [2]:
class Vocabulary:
    def __init__(self, freq_threshold=2):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def tokenizer_eng(self, text):
        # Efficient pure-python tokenizer
        text = text.lower()
        text = text.translate(str.maketrans('', '', string.punctuation))
        return text.split()

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)
        return [self.stoi["<SOS>"]] + \
               [self.stoi.get(word, self.stoi["<UNK>"]) for word in tokenized_text] + \
               [self.stoi["<EOS>"]]

class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=2):
        self.root_dir = root_dir
        self.transform = transform
        self.df = self._read_captions(captions_file)
        
        # Build vocabulary
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary([x[1] for x in self.df])

    def _read_captions(self, captions_file):
        data = []
        with open(captions_file, "r", encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader) 
            for line in reader:
                if len(line) >= 2:
                    data.append((line[0], line[1]))
        return data

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.df[index][1]
        img_id = self.df[index][0]
        try:
            img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")
        except FileNotFoundError:
            return self.__getitem__((index + 1) % len(self.df))

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)

class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
        return imgs, targets

In [3]:
# Stronger transforms for training
train_transform = transforms.Compose([
    transforms.Resize((232, 232)),
    transforms.RandomCrop((224, 224)),    
    transforms.RandomHorizontalFlip(),     
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

print("Loading dataset (this may take a minute)...")
dataset = FlickrDataset(
    root_dir=IMAGE_FOLDER,
    captions_file=CAPTIONS_FILE,
    transform=train_transform
)

pad_idx = dataset.vocab.stoi["<PAD>"]

loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    num_workers=0,  
    shuffle=True,
    pin_memory=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
)

print(f"Dataset Loaded! Vocab size: {len(dataset.vocab)}")

Loading dataset (this may take a minute)...
Dataset Loaded! Vocab size: 5224


In [4]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN
        self.inception = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        features = self.inception(images)
        return self.dropout(self.relu(features))

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=0.3) 
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        embeddings = torch.cat((features.unsqueeze(1), embeddings), dim=1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs
    
    def sample(self, features, max_len=20):
        # Inference function
        result_caption = []
        with torch.no_grad():
            inputs = features.unsqueeze(1)
            states = None
            for _ in range(max_len):
                hiddens, states = self.lstm(inputs, states)
                output = self.linear(hiddens.squeeze(1))
                predicted = output.argmax(1)
                
                result_caption.append(predicted.item())
                
                inputs = self.embed(predicted).unsqueeze(1) 
                
                if predicted.item() == 2: # <EOS>
                    break
        return result_caption

In [None]:
# Hyperparameters
embed_size = 256
hidden_size = 512
vocab_size = len(dataset.vocab)
num_layers = 2
learning_rate = 3e-4
num_epochs = 15

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(
    list(decoder.parameters()) + list(encoder.parameters()), 
    lr=learning_rate
)

scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5, verbose=True)

vocab_path = os.path.join(MODEL_SAVE_PATH, "vocab.pkl")
with open(vocab_path, "wb") as f:
    pickle.dump(dataset.vocab, f)
print(f"Vocabulary saved to {vocab_path}")

print("Starting Training...")

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    
    for idx, (imgs, captions) in enumerate(loader):
        imgs = imgs.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()
        
        features = encoder(imgs)
        outputs = decoder(features, captions)
        
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

        if idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Step [{idx}/{len(loader)}] Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(loader)
    scheduler.step(avg_loss) 
    print(f"Epoch {epoch+1} Completed. Average Loss: {avg_loss:.4f}")
    
    # Save Model Checkpoints
    if (epoch+1) % 5 == 0 or epoch == num_epochs-1:
        enc_path = os.path.join(MODEL_SAVE_PATH, "encoder.pth")
        dec_path = os.path.join(MODEL_SAVE_PATH, "decoder.pth")
        
        torch.save(encoder.state_dict(), enc_path)
        torch.save(decoder.state_dict(), dec_path)
        print(f"Saved models to {MODEL_SAVE_PATH}")

print("Training Finished!")



Vocabulary saved to D:\Assignment3\Models\vocab.pkl
Starting Training...
Epoch [1/15] Step [0/1265] Loss: 8.5595
Epoch [1/15] Step [100/1265] Loss: 4.0770
Epoch [1/15] Step [200/1265] Loss: 3.8040
Epoch [1/15] Step [300/1265] Loss: 3.5874


In [None]:
import io
import ipywidgets as widgets
from IPython.display import display, clear_output
import traceback

# Load models just to be safe (ensure we use the saved ones)
# encoder.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, "encoder.pth")))
# decoder.load_state_dict(torch.load(os.path.join(MODEL_SAVE_PATH, "decoder.pth")))

encoder.eval()
decoder.eval()

# Simple transform for testing
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

def get_caption_from_model(image, encoder, decoder, vocab, max_len=20):
    image_tensor = test_transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        features = encoder(image_tensor)
        output_ids = decoder.sample(features, max_len=max_len)
    
    caption_words = []
    for word_id in output_ids:
        word = vocab.itos[word_id]
        if word == "<EOS>":
            break
        if word not in ["<SOS>", "<PAD>", "<UNK>"]:
            caption_words.append(word)
            
    return " ".join(caption_words)

# Upload Widget
uploader = widgets.FileUpload(accept='image/*', multiple=False)
output = widgets.Output()

def on_upload_change(change):
    if not uploader.value: return
    
    with output:
        clear_output()
        print("Processing...")
    
    try:
        # Handle ipywidgets version differences
        if isinstance(uploader.value, dict):
            uploaded_file = list(uploader.value.values())[0]
        else:
            uploaded_file = list(uploader.value)[-1]
            
        image = Image.open(io.BytesIO(uploaded_file['content'])).convert("RGB")
        
        # Generate
        caption = get_caption_from_model(image, encoder, decoder, dataset.vocab)
        
        with output:
            clear_output()
            display(image.resize((300, 300)))
            print(f"\nGenerated Caption: {caption}")
            
    except Exception as e:
        with output:
            print("Error:", traceback.format_exc())

uploader.observe(on_upload_change, names='value')
print("Upload an image below to test:")
display(uploader, output)