In [8]:
import nltk
# Download the "punkt" resource via NLTK's downloader (a no-op when it is already present).
# NOTE(review): none of the cells visible in this notebook call an nltk tokenizer —
# captions are tokenized with str.split() — so confirm this download is still needed.
nltk.download("punkt")


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
import os

CAPTION_FILE = "/kaggle/input/flickr8/Flickr8k.token/Flickr8k.token.txt"


def load_captions(caption_file):
    """Read a tab-separated caption file into parallel lists.

    Every usable line has the form ``<image id>#<n>\\t<caption>``. Lines that
    are blank or that do not contain exactly one tab are skipped.

    The image id (the text before the first ``#``) is normalised to a name
    ending in ``.jpg``: names that already end in ``.jpg`` are kept, otherwise
    the last dotted suffix (if any) is dropped and ``.jpg`` is appended when
    the remainder still lacks it.

    Args:
        caption_file: Path of the UTF-8 encoded caption file.

    Returns:
        ``(image_names, captions)`` — two lists of equal length, aligned by
        position.
    """
    image_names, captions = [], []

    with open(caption_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            # A blank line splits into a single empty field, so this one
            # check rejects both empty and malformed lines.
            if len(fields) != 2:
                continue

            caption_id, text = fields
            name = caption_id.split("#")[0]

            if not name.endswith('.jpg'):
                # rsplit leaves the name untouched when it contains no dot.
                name = name.rsplit('.', 1)[0]
                if not name.endswith('.jpg'):
                    name += '.jpg'

            image_names.append(name)
            captions.append(text)

    return image_names, captions


In [None]:
import re

def clean_caption(caption):
    """Normalise one raw caption and wrap it in sequence markers.

    The caption is lower-cased, every character other than ``a``-``z`` and the
    space is deleted (not replaced, so ``"two-legged"`` becomes
    ``"twolegged"``), runs of whitespace are collapsed to one space and the
    ends are trimmed.

    Args:
        caption: Raw caption text.

    Returns:
        The cleaned caption prefixed with ``"<start> "`` and suffixed with
        ``" <end>"``.
    """
    letters_only = re.sub(r"[^a-z ]", "", caption.lower())
    collapsed = re.sub(r"\s+", " ", letters_only).strip()
    return "<start> " + collapsed + " <end>"


In [11]:
def preprocess_captions(captions):
    """Apply ``clean_caption`` to every caption, preserving order.

    Args:
        captions: Iterable of raw caption strings.

    Returns:
        A new list holding the cleaned caption for each input caption.
    """
    return [clean_caption(raw) for raw in captions]


In [12]:
from collections import Counter
def build_vocabulary(captions, freq_threshold=5):
    """Count whitespace-separated tokens and keep the frequent ones.

    Args:
        captions: Iterable of caption strings; each is tokenized with
            ``str.split()``.
        freq_threshold: Minimum number of occurrences a token needs in order
            to be kept.

    Returns:
        ``(vocab, counter)`` where ``vocab`` lists the kept tokens in order of
        first appearance and ``counter`` is the ``Counter`` of all tokens.
    """
    counter = Counter(token for caption in captions for token in caption.split())
    vocab = [token for token, occurrences in counter.items() if occurrences >= freq_threshold]
    return vocab, counter

In [13]:
def create_word_mappings(vocab, special_tokens=("<pad>", "<start>", "<end>")):
    """Build the token -> index and index -> token dictionaries.

    Special tokens receive the lowest indices in the order given, so with the
    default tuple ``<pad>`` is 0, ``<start>`` is 1 and ``<end>`` is 2.
    Vocabulary words follow in order. A token that already has an index (a
    repeated word, or a word equal to a special token) is not assigned a
    second one.

    Args:
        vocab: Iterable of vocabulary words.
        special_tokens: Tokens to reserve before the vocabulary words. Pass
            e.g. ``("<pad>", "<start>", "<end>", "<unk>")`` to also reserve an
            unknown-word index.

    Returns:
        ``(word2idx, idx2word)`` — two dictionaries that are exact inverses of
        each other.
    """
    word2idx = {}
    for token in (*special_tokens, *vocab):
        if token not in word2idx:
            # Indices are dense: the next free index is the current size.
            word2idx[token] = len(word2idx)

    idx2word = {index: token for token, index in word2idx.items()}
    return word2idx, idx2word


In [14]:
def numericalize_captions(captions, word2idx):
    """Convert caption strings into lists of vocabulary indices.

    Each caption is tokenized with ``str.split()``. A token missing from
    ``word2idx`` is encoded with the ``"<unk>"`` index when the vocabulary
    defines one; otherwise it falls back to the ``"<pad>"`` index, which is
    what this function has always done for vocabularies without ``"<unk>"``.

    Args:
        captions: Iterable of (already cleaned) caption strings.
        word2idx: Mapping from token to integer index.

    Returns:
        A list with one list of indices per caption, in input order.

    Raises:
        KeyError: If an out-of-vocabulary token is met and ``word2idx`` has
            neither ``"<unk>"`` nor ``"<pad>"``.
    """
    # Decide once which token stands in for unknown words; the lookup itself
    # stays lazy so a vocabulary without any fallback only fails when an
    # unknown word is actually encountered.
    fallback_token = "<unk>" if "<unk>" in word2idx else "<pad>"

    return [
        [
            word2idx[token] if token in word2idx else word2idx[fallback_token]
            for token in caption.split()
        ]
        for caption in captions
    ]


In [15]:
def main():
    """Run the caption preprocessing pipeline once and print a summary.

    Loads the captions from ``CAPTION_FILE``, cleans them, builds the
    vocabulary (tokens seen at least 5 times), creates the index mappings and
    encodes every caption, then prints dataset statistics, the first caption
    in text and numeric form, and the ten most frequent tokens.
    """
    names, raw_captions = load_captions(CAPTION_FILE)
    cleaned = preprocess_captions(raw_captions)

    vocab, word_counts = build_vocabulary(cleaned, freq_threshold=5)
    word2idx, _idx2word = create_word_mappings(vocab)
    encoded = numericalize_captions(cleaned, word2idx)

    unique_images = set(names)
    print("Total images:", len(unique_images))
    print("Total captions:", len(cleaned))
    print("Vocabulary size:", len(word2idx))

    print("\nExample:")
    first_text, first_encoded = cleaned[0], encoded[0]
    print("Caption:", first_text)
    print("Numerical:", first_encoded)
    print("Most common words:")

    top_ten = word_counts.most_common(10)
    for word, count in top_ten:
        print(word, count)


if __name__ == "__main__":
    main()


Total images: 8092
Total captions: 40460
Vocabulary size: 2987

Example:
Caption: <start> a child in a pink dress is climbing up a set of stairs in an entry way <end>
Numerical: [1, 3, 4, 5, 3, 6, 7, 8, 9, 10, 3, 11, 12, 13, 5, 14, 0, 15, 2]
Most common words:
a 62989
<start> 40460
<end> 40460
in 18975
the 18419
on 10744
is 9345
and 8852
dog 8136
with 7765


In [None]:

from torch.utils.data import Dataset
from PIL import Image
import os

class FlickrDataset(Dataset):
    """Dataset pairing each image file with one numericalized caption.

    ``image_names[i]`` and ``captions[i]`` describe the same sample; the
    dataset length is the number of captions.

    NOTE: ``__getitem__`` uses ``torch``, which the defining cell does not
    import itself — it must be imported at module level before items are
    fetched.
    """

    def __init__(self, image_dir, image_names, captions, transform=None):
        """Store the sample lists.

        Args:
            image_dir: Directory that contains the image files.
            image_names: Sequence of image file names, one per caption.
            captions: Sequence of captions, each a list of token indices.
            transform: Optional callable applied to the loaded RGB image;
                when falsy the image is returned as loaded.
        """
        self.image_dir = image_dir
        self.image_names = image_names
        self.captions = captions
        self.transform = transform

    def __len__(self):
        """Return the number of (image, caption) samples."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, caption)`` where ``caption`` is a tensor built from the
            stored index list, or ``None`` when the image file does not
            exist (``collate_fn`` drops such entries from the batch).
        """
        image_path = os.path.join(self.image_dir, self.image_names[idx])

        # Only a missing file is treated as "skip this sample"; anything else
        # (corrupt image, failing transform) should surface as an error.
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            return None

        if self.transform:
            image = self.transform(image)

        # Built unconditionally: previously this only happened inside the
        # transform branch, so a dataset without a transform crashed here.
        caption = torch.tensor(self.captions[idx])

        return image, caption



In [17]:
import torch
# Sanity check: displays whether PyTorch can see a CUDA device in this session.
torch.cuda.is_available()


True

In [None]:
def collate_fn(batch):
    """Collate ``(image, caption)`` samples into a zero-padded batch.

    ``None`` entries (samples the dataset could not load) are discarded. The
    remaining samples are ordered by caption length, longest first, and the
    captions are right-padded with zeros to the longest length.

    Args:
        batch: List of ``(image_tensor, caption_tensor)`` tuples and/or
            ``None`` values.

    Returns:
        ``None`` when no valid sample remains, otherwise
        ``(images, padded_captions, lengths)``: the stacked images, a long
        tensor of shape ``(batch, max_len)`` and the list of true caption
        lengths in the same (sorted) order.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None

    # sorted(..., reverse=True) is stable, like list.sort(reverse=True).
    ordered = sorted(valid, key=lambda sample: len(sample[1]), reverse=True)
    image_list = [image for image, _ in ordered]
    caption_list = [caption for _, caption in ordered]

    images = torch.stack(image_list, dim=0)
    lengths = [len(caption) for caption in caption_list]

    padded_captions = torch.zeros(len(caption_list), max(lengths)).long()
    for row, (caption, length) in enumerate(zip(caption_list, lengths)):
        padded_captions[row, :length] = caption

    return images, padded_captions, lengths


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
# Image preprocessing: resize every image to 224x224, convert it to a tensor
# and normalize each channel with the given per-channel mean / std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
# Repeat the caption pipeline at notebook level (main() keeps its results
# local), so that the vocabulary and encoded captions are available to the
# cells below.
image_names, captions = load_captions(CAPTION_FILE)
captions = preprocess_captions(captions)

vocab, counter = build_vocabulary(captions, freq_threshold=5)

word2idx, idx2word = create_word_mappings(vocab)

numeric_captions = numericalize_captions(captions, word2idx)

# One dataset item per caption; image_names and numeric_captions are aligned by index.
dataset = FlickrDataset(
    image_dir="/kaggle/input/flickr8/Flicker8k_Dataset/Flicker8k_Dataset",
    image_names=image_names,
    captions=numeric_captions,
    transform=transform
)

# collate_fn drops samples the dataset returned as None and pads the captions
# of each batch to a common length.
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2,     
    pin_memory=True
)


In [None]:
# Pull a single batch from the loader to check the collated shapes.
# NOTE: this rebinds the notebook-level name `captions` from the list of
# cleaned caption strings to a padded index tensor for one batch.
images, captions, lengths = next(iter(loader))

print("Images shape:", images.shape)
print("Captions shape:", captions.shape) 
# Only the first five caption lengths are shown.
print("Lengths:", lengths[:5]) 


Images shape: torch.Size([32, 3, 224, 224])
Captions shape: torch.Size([32, 27])
Lengths: [27, 22, 22, 20, 18]


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class EncoderCNN(nn.Module):
    """Frozen ResNet-50 feature extractor producing a grid of region features.

    The last two child modules of the torchvision ResNet-50 are removed and
    every remaining parameter has ``requires_grad`` set to ``False``. Because
    the backbone is frozen it is also kept permanently in evaluation mode, so
    its BatchNorm layers use their stored running statistics and never update
    them — even when the surrounding code calls ``.train()`` on the encoder.
    """

    def __init__(self):
        super().__init__()
        # Prefer the weights enum where torchvision provides it (the
        # ``pretrained`` flag is deprecated there); fall back to the legacy
        # flag on older releases.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = models.resnet50(pretrained=True)

        # Drop the final two children so the output stays a spatial feature map.
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])

        for param in self.resnet.parameters():
            param.requires_grad = False

        self.resnet.eval()

    def train(self, mode=True):
        """Set the module's mode but keep the frozen backbone in eval mode.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image batch of shape ``(batch, 3, height, width)``.

        Returns:
            Tensor of shape ``(batch, regions, channels)`` where ``regions``
            is the flattened spatial grid of the backbone output.
        """
        features = self.resnet(images)             # (batch, channels, h, w)
        features = features.permute(0, 2, 3, 1)    # (batch, h, w, channels)

        # flatten() copes with the non-contiguous permuted tensor and does not
        # need the channel count to be spelled out.
        return features.flatten(1, 2)              # (batch, h * w, channels)
   

In [None]:

class Attention(nn.Module):
    """Additive attention over encoder regions, conditioned on a decoder state.

    Both inputs are projected to ``att_dim``, summed, passed through ``tanh``
    and reduced to one score per region; the scores are soft-maxed over the
    region axis and used to form a weighted sum of the encoder features.
    """

    def __init__(self, encoder_dim, decoder_dim, att_dim):
        """Create the projection layers.

        Args:
            encoder_dim: Size of the last axis of the encoder features.
            decoder_dim: Size of the decoder hidden state.
            att_dim: Size of the shared attention space.
        """
        super().__init__()
        self.enc = nn.Linear(encoder_dim, att_dim)
        self.dec = nn.Linear(decoder_dim, att_dim)
        self.fc = nn.Linear(att_dim, 1)

    def forward(self, encoder_out, hidden):
        """Compute the context vector.

        Args:
            encoder_out: Tensor of shape ``(batch, regions, encoder_dim)``.
            hidden: Tensor of shape ``(batch, decoder_dim)``.

        Returns:
            Context tensor of shape ``(batch, encoder_dim)``. The attention
            weights themselves are not returned.
        """
        region_proj = self.enc(encoder_out)            # (batch, regions, att_dim)
        state_proj = self.dec(hidden).unsqueeze(1)     # (batch, 1, att_dim)

        energies = self.fc(torch.tanh(region_proj + state_proj)).squeeze(2)
        weights = torch.softmax(energies, dim=1)       # (batch, regions)

        weighted = encoder_out * weights.unsqueeze(2)
        return weighted.sum(dim=1)


In [None]:

# Smoke test: push one batch through a freshly constructed encoder and print
# the input and output shapes. Nothing is moved to a device here.
encoder = EncoderCNN()

# NOTE: rebinds the notebook-level `images`, `captions` and `lengths` to a new batch.
images, captions, lengths = next(iter(loader))

features = encoder(images)

print("Image batch:", images.shape)
print("Feature batch:", features.shape)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 232MB/s]


Image batch: torch.Size([32, 3, 224, 224])
Feature batch: torch.Size([32, 49, 2048])


In [None]:
import torch
import torch.nn as nn

class DecoderRNN(nn.Module):
    """LSTM caption decoder that attends over encoder regions at every step.

    At each time step the attention module turns the encoder features and
    the current hidden state into a context vector. The context is
    concatenated with the embedding of the input token, fed to an
    ``LSTMCell``, and the new hidden state is projected to vocabulary logits.
    """

    def __init__(self, embed_size, hidden_size, vocab_size,
                 encoder_dim=2048, attention_dim=512):
        """Create the decoder layers.

        Args:
            embed_size: Size of the token embeddings.
            hidden_size: Size of the LSTM hidden and cell state.
            vocab_size: Number of tokens in the vocabulary.
            encoder_dim: Size of the last axis of the encoder features
                (defaults to the previously hard-coded 2048).
            attention_dim: Size of the attention space (defaults to the
                previously hard-coded 512).
        """
        super().__init__()

        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, hidden_size, attention_dim)
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, encoder_out, captions):
        """Decode with teacher forcing.

        Args:
            encoder_out: Encoder features, ``(batch, regions, encoder_dim)``.
            captions: Token indices, ``(batch, seq_len)``. The last column is
                never used as an input, so the logits at step ``t`` are meant
                to be compared with ``captions[:, t + 1]``.

        Returns:
            Logits of shape ``(batch, seq_len - 1, vocab_size)``.
        """
        batch_size = encoder_out.size(0)

        # Zero initial state, allocated with the features' dtype and device.
        hidden = encoder_out.new_zeros(batch_size, self.hidden_size)
        cell = encoder_out.new_zeros(batch_size, self.hidden_size)

        embeddings = self.embedding(captions[:, :-1])
        num_steps = embeddings.size(1)

        if num_steps == 0:
            # Captions of length <= 1 give no input steps; torch.stack would
            # raise on an empty list, so return an empty logits tensor.
            return encoder_out.new_zeros(batch_size, 0, self.fc.out_features)

        outputs = []
        for t in range(num_steps):
            context = self.attention(encoder_out, hidden)
            lstm_input = torch.cat([embeddings[:, t], context], dim=1)
            hidden, cell = self.lstm(lstm_input, (hidden, cell))
            outputs.append(self.fc(hidden))

        return torch.stack(outputs, dim=1)
    




In [27]:
def beam_search(encoder_out, decoder, word2idx, idx2word, beam_size=3, max_len=20):
    """Generate a caption for one image with beam search.

    Every beam carries its own token sequence, cumulative negative
    log-probability and LSTM state. At each step every unfinished beam is
    extended with its ``beam_size`` most probable next tokens, and the
    ``beam_size`` best candidates overall (lowest cumulative negative
    log-probability) survive. Beams that have produced ``<end>`` are carried
    over unchanged. The search stops after ``max_len`` steps or as soon as
    every surviving beam has ended.

    Args:
        encoder_out: Encoder features for a single image, shape
            ``(1, regions, encoder_dim)``.
        decoder: Object exposing ``hidden_size``, ``embedding``,
            ``attention``, ``lstm`` and ``fc`` as ``DecoderRNN`` does.
        word2idx: Token -> index mapping containing ``"<start>"`` and
            ``"<end>"``.
        idx2word: Index -> token mapping.
        beam_size: Number of beams kept after every step.
        max_len: Maximum number of decoding steps.

    Returns:
        The best caption as a space-joined string, without the ``<start>``
        and ``<end>`` markers.

    Raises:
        ValueError: If ``beam_size`` is smaller than 1 or ``encoder_out``
            does not hold exactly one image.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be at least 1")
    if encoder_out.size(0) != 1:
        raise ValueError("beam_search expects features for exactly one image")

    start_idx = word2idx["<start>"]
    end_idx = word2idx["<end>"]
    device = encoder_out.device

    def _finished(seq):
        return len(seq) > 0 and seq[-1] == end_idx

    with torch.no_grad():  # inference only; no graph needed
        # The state size comes from the decoder itself rather than from a
        # notebook-level global.
        init_hidden = encoder_out.new_zeros(1, decoder.hidden_size)
        init_cell = encoder_out.new_zeros(1, decoder.hidden_size)

        # Each beam: (token sequence, cumulative -log p, hidden, cell).
        beams = [([], 0.0, init_hidden, init_cell)]

        for _ in range(max_len):
            if all(_finished(beam[0]) for beam in beams):
                break

            candidates = []
            for seq, score, hidden, cell in beams:
                if _finished(seq):
                    candidates.append((seq, score, hidden, cell))
                    continue

                prev_token = seq[-1] if seq else start_idx
                word = torch.tensor([prev_token], device=device)
                embed = decoder.embedding(word)

                # The attention module returns only the context vector.
                context = decoder.attention(encoder_out, hidden)
                lstm_input = torch.cat([embed, context], dim=1)

                # Keep the new state with the candidates so the next step
                # continues from it instead of restarting from zeros.
                next_hidden, next_cell = decoder.lstm(lstm_input, (hidden, cell))

                # Log-probabilities make scores comparable across beams and steps.
                log_probs = torch.log_softmax(decoder.fc(next_hidden), dim=1)
                width = min(beam_size, log_probs.size(1))
                top = log_probs.topk(width, dim=1)

                for log_p, token in zip(top.values[0].tolist(), top.indices[0].tolist()):
                    candidates.append((seq + [token], score - log_p, next_hidden, next_cell))

            # Sort on the score only; the tuples also hold tensors.
            beams = sorted(candidates, key=lambda beam: beam[1])[:beam_size]

    best_seq = beams[0][0]
    caption = [idx2word[idx] for idx in best_seq if idx2word[idx] not in ["<start>", "<end>"]]

    return " ".join(caption)


In [29]:
# Smoke test: build an untrained encoder/decoder pair and run the batch that is
# currently bound to the notebook-level `images` / `captions` through both,
# printing the output shapes.
encoder = EncoderCNN()
decoder = DecoderRNN(
    embed_size=512,
    hidden_size=512,
    vocab_size=len(word2idx)
)

features = encoder(images)
print("Encoder output:", features.shape)
# (B, 49, 2048)

# The decoder consumes captions[:, :-1], so it yields one step fewer than the
# caption length.
outputs = decoder(features, captions)
print("Decoder output:", outputs.shape)
# (B, seq_len-1, vocab_size)


Encoder output: torch.Size([32, 49, 2048])
Decoder output: torch.Size([32, 19, 2987])


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm



# ---- Training configuration -------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_size = 512
hidden_size = 512
batch_size = 32  # NOTE: not read below; `loader` was built in an earlier cell
epochs = 5
lr = 0.001
vocab_size = len(word2idx)

# ---- Models, loss and optimizer ---------------------------------------------
encoder = EncoderCNN().to(device)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)

# Padded target positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<pad>"])
# Only the decoder is optimized; every encoder parameter is frozen.
optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

# The encoder is a frozen feature extractor: keep it in eval mode so its
# BatchNorm layers use (and do not update) their stored running statistics.
encoder.eval()

# ---- Training ---------------------------------------------------------------
for epoch in range(epochs):
    decoder.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(loader):
        # collate_fn returns None when no sample of the batch could be loaded.
        if batch is None:
            continue

        images, captions, _ = batch
        images, captions = images.to(device), captions.to(device)

        optimizer.zero_grad()

        # No gradients are needed through the frozen encoder.
        with torch.no_grad():
            features = encoder(images)
        outputs = decoder(features, captions)

        # The decoder reads captions[:, :-1]; its step-t output is scored
        # against the next token.
        targets = captions[:, 1:]
        loss = criterion(
            outputs.reshape(-1, vocab_size),
            targets.reshape(-1)
        )

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Average over the batches that actually ran (equals len(loader) when
    # nothing was skipped).
    print(f"Epoch {epoch+1}/{epochs}  Loss: {total_loss/max(num_batches, 1):.4f}")


100%|██████████| 1265/1265 [04:04<00:00,  5.17it/s]


Epoch 1/5  Loss: 3.3931


100%|██████████| 1265/1265 [04:16<00:00,  4.92it/s]


Epoch 2/5  Loss: 2.7139


100%|██████████| 1265/1265 [04:16<00:00,  4.93it/s]


Epoch 3/5  Loss: 2.4757


100%|██████████| 1265/1265 [04:16<00:00,  4.93it/s]


Epoch 4/5  Loss: 2.3098


100%|██████████| 1265/1265 [04:16<00:00,  4.93it/s]

Epoch 5/5  Loss: 2.1765





In [32]:
import torch
# Report whether a CUDA device is visible and, if so, the name of device 0.
has_cuda = torch.cuda.is_available()
print(has_cuda)
print("NO GPU" if not has_cuda else torch.cuda.get_device_name(0))


True
Tesla T4


In [37]:
import torch
import pickle

# Persist the model weights (state dicts only) to the working directory.
for weights_path, module in (
    ("encoder_attention.pth", encoder),
    ("decoder_attention.pth", decoder),
):
    torch.save(module.state_dict(), weights_path)

# Persist both vocabulary mappings so captions can be decoded later.
for mapping_path, mapping in (
    ("word2idx.pkl", word2idx),
    ("idx2word.pkl", idx2word),
):
    with open(mapping_path, "wb") as f:
        pickle.dump(mapping, f)

print("✅ Training artifacts saved")


✅ Training artifacts saved
