In [6]:
import os
import numpy as np
import pandas as pd
import string
import nltk
nltk.download('punkt')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
DATA_DIR = "/kaggle/input/flickr8k"

# Load captions
captions_file = os.path.join(DATA_DIR, "captions.txt")
captions_df = pd.read_csv(captions_file)
captions_df.head()

# Clean captions
def clean_caption(text):
    text = text.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    tokens = nltk.word_tokenize(text)
    tokens = [word for word in tokens if word.isalpha()]
    return " ".join(tokens)

captions_df['caption'] = captions_df['caption'].apply(clean_caption)


In [8]:
all_captions = captions_df['caption'].tolist()

word_freq = {}
for cap in all_captions:
    for word in cap.split():
        word_freq[word] = word_freq.get(word, 0) + 1


words = [w for w in word_freq if word_freq[w] >= 5]


word2idx = {w:i+4 for i,w in enumerate(words)}
word2idx["<PAD>"] = 0
word2idx["<SOS>"] = 1
word2idx["<EOS>"] = 2
word2idx["<UNK>"] = 3

idx2word = {i:w for w,i in word2idx.items()}
vocab_size = len(word2idx)
vocab_size


2988

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

resnet = models.resnet50(pretrained=True)
resnet = nn.Sequential(*list(resnet.children())[:-1])  # remove final FC layer
resnet = resnet.to(device)
resnet.eval()

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

def extract_features(image_path):
    img = Image.open(image_path).convert("RGB")
    img = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        features = resnet(img)
    return features.squeeze().cpu().numpy()


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 160MB/s] 


In [10]:
class FlickrDataset(Dataset):
    def __init__(self, df, img_dir):
        self.df = df
        self.img_dir = img_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image'])
        caption = row['caption'].split()

        
        encoded = [word2idx.get(word, word2idx["<UNK>"]) for word in caption]
        encoded = [word2idx["<SOS>"]] + encoded + [word2idx["<EOS>"]]

        
        caption_tensor = torch.tensor(encoded, dtype=torch.long)
        img_features = torch.tensor(extract_features(img_path), dtype=torch.float32)

        return img_features, caption_tensor


In [11]:
class CaptionDecoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size + 2048, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, img_feat, captions):
        img_feat = img_feat.unsqueeze(1)  
        embeddings = self.embed(captions) 

        
        img_feat = img_feat.repeat(1, embeddings.size(1), 1)

        lstm_input = torch.cat((img_feat, embeddings), dim=2)
        out, _ = self.lstm(lstm_input)
        out = self.fc(out)
        return out


In [12]:
dataset = FlickrDataset(captions_df, os.path.join(DATA_DIR, "Images"))
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=lambda x: x)

model = CaptionDecoder(256, 256, vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<PAD>"])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [13]:
for epoch in range(2):  
    total_loss = 0
    for batch in tqdm(loader):
        imgs, caps = zip(*batch)
        imgs = torch.stack(imgs).to(device)
        caps = torch.nn.utils.rnn.pad_sequence(caps, batch_first=True, padding_value=word2idx["<PAD>"]).to(device)

        optimizer.zero_grad()
        outputs = model(imgs, caps[:,:-1])   
        loss = criterion(outputs.reshape(-1, vocab_size), caps[:,1:].reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(loader):.4f}")


100%|██████████| 1265/1265 [09:53<00:00,  2.13it/s]


Epoch 1, Loss: 3.9481


100%|██████████| 1265/1265 [09:35<00:00,  2.20it/s]

Epoch 2, Loss: 3.2534





In [14]:
def generate_caption(image_path, max_len=20):
    model.eval()
    feat = torch.tensor(extract_features(image_path)).float().to(device)

    caption = [word2idx["<SOS>"]]

    for _ in range(max_len):
        inp = torch.tensor(caption).unsqueeze(0).to(device)
        out = model(feat.unsqueeze(0), inp)
        next_word = out.argmax(2)[:,-1].item()
        caption.append(next_word)
        if next_word == word2idx["<EOS>"]:
            break

    return " ".join(idx2word[w] for w in caption if w not in (0,1,2,3))


In [16]:
test_image = "/kaggle/input/flickr8k/Images/1000268201_693b08cb0e.jpg"
print(generate_caption(test_image))


a man in a blue shirt is sitting on a bench
