In [None]:
# Imports

In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [2]:
# Paths
image_folder = "data/images/"
captions_file = "data/captions.txt"

# Read captions
captions_data = pd.read_csv(captions_file, delimiter='\t', names=['image', 'caption'])
captions_data.head()

Unnamed: 0,image,caption
0,"image,caption",
1,"1000268201_693b08cb0e.jpg,A child in a pink dr...",
2,"1000268201_693b08cb0e.jpg,A girl going into a ...",
3,"1000268201_693b08cb0e.jpg,A little girl climbi...",
4,"1000268201_693b08cb0e.jpg,A little girl climbi...",


In [None]:
import re
from collections import Counter

class Vocabulary:
    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
        self.stoi = {v:k for k,v in self.itos.items()}
    
    def tokenizer(self, text):
        return re.findall(r'\w+', text.lower())
    
    def build_vocab(self, sentence_list):
        frequencies = Counter()
        idx = 4
        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)
        return [self.stoi.get(token, self.stoi["<UNK>"]) for token in tokenized_text]

# Clean captions and build vocabulary
captions_data = captions_data.dropna(subset=['caption'])
captions_data['caption'] = captions_data['caption'].astype(str)

vocab = Vocabulary(freq_threshold=5)
vocab.build_vocab(captions_data['caption'].tolist())

print("Vocabulary size:", len(vocab.stoi))
print("Example tokens:", list(vocab.stoi.items())[:20])

In [None]:
class FlickrDataset(Dataset):
    def __init__(self, dataframe, img_folder, vocab, transform=None):
        self.df = dataframe
        self.img_folder = img_folder
        self.vocab = vocab
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image']
        caption = self.df.iloc[idx]['caption']
        
        image = Image.open(os.path.join(self.img_folder, img_name)).convert("RGB")
        if self.transform:
            image = self.transform(image)
        
        numericalized_caption = [vocab.stoi["<SOS>"]] + vocab.numericalize(caption) + [vocab.stoi["<EOS>"]]
        return image, torch.tensor(numericalized_caption)

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

dataset = FlickrDataset(captions_data, image_folder, vocab, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)

In [None]:
# Encoder

In [None]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]  # Remove FC layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size)
    
    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images).squeeze()
        features = self.bn(self.linear(features))
        return features


In [None]:
# LSTM Decoder with Attention

In [None]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
    
    def forward(self, features, captions):
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs


In [None]:
embed_size = 256
hidden_size = 512
vocab_size = len(vocab.stoi)

encoder = EncoderCNN(embed_size).to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = optim.Adam(params, lr=1e-3)


In [None]:
num_epochs = 20

for epoch in range(num_epochs):
    for batch in tqdm(dataloader):
        images, captions = zip(*batch)
        images = torch.stack(images).to(device)
        lengths = [len(cap) for cap in captions]
        captions_padded = nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=vocab.stoi["<PAD>"]).to(device)
        
        optimizer.zero_grad()
        features = encoder(images)
        outputs = decoder(features, captions_padded[:, :-1])
        loss = criterion(outputs.reshape(-1, vocab_size), captions_padded[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

# Save models
os.makedirs("models", exist_ok=True)
torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")


In [None]:
def generate_caption(image, encoder, decoder, vocab, max_len=20):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        image = transform(image).unsqueeze(0).to(device)
        feature = encoder(image)
        caption = [vocab.stoi["<SOS>"]]
        for _ in range(max_len):
            inputs = torch.tensor(caption).unsqueeze(0).to(device)
            outputs = decoder(feature, inputs)
            predicted = outputs.argmax(2)[:,-1].item()
            caption.append(predicted)
            if predicted == vocab.stoi["<EOS>"]:
                break
    words = [vocab.itos[idx] for idx in caption[1:-1]]
    return ' '.join(words)

# Example
img_path = "data/images/123456.jpg"
img = Image.open(img_path).convert("RGB")
caption = generate_caption(img, encoder, decoder, vocab)
print("Generated Caption:", caption)
