In [1]:
!pip install -q torch torchvision nltk tqdm kagglehub

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import random
import re
import os
import csv
from PIL import Image
from collections import Counter
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nltk.download("punkt")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [4]:
DRIVE_ROOT = "/content/drive/MyDrive/FYP_Models/Attention_Captioning"
os.makedirs(DRIVE_ROOT, exist_ok=True)

MODEL_PATH = os.path.join(DRIVE_ROOT, "resnet50_attention_model.pth")
VOCAB_PATH = os.path.join(DRIVE_ROOT, "vocab.pt")
FEATURE_PATH = os.path.join(DRIVE_ROOT, "resnet50_features.pt")

print("Drive directory ready:", DRIVE_ROOT)

Drive directory ready: /content/drive/MyDrive/FYP_Models/Attention_Captioning


In [5]:
import kagglehub

path = kagglehub.dataset_download("adityajn105/flickr8k")
print("Dataset path:", path)

Using Colab cache for faster access to the 'flickr8k' dataset.
Dataset path: /kaggle/input/flickr8k


In [6]:
# Corrected paths based on Kaggle dataset structure
IMAGE_DIR = os.path.join(path, "Images")
CAPTION_FILE = os.path.join(path, "captions.txt")

print(len(os.listdir(IMAGE_DIR)), "images found")
print("Caption file exists:", os.path.exists(CAPTION_FILE))

8091 images found
Caption file exists: True


In [7]:
captions = {}

with open(CAPTION_FILE, "r") as f:
    reader = csv.reader(f) # Use csv.reader to parse the file
    next(reader) # Skip the header line
    for row in reader:
        if len(row) >= 2: # Ensure the row has at least two columns (image, caption)
            img_name = row[0] # Image name is directly in the first column
            caption_text = row[1] # Caption text is in the second column

            caption_text = "<start> " + caption_text.lower() + " <end>"
            captions.setdefault(img_name, []).append(caption_text)

        else:
            # Optionally, log or handle malformed rows if necessary
            print(f"Skipping malformed row: {row}")

print("Total images with captions:", len(captions))

Total images with captions: 8091


In [8]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [9]:
word_counter = Counter()

for caps_list in captions.values():
    for original_caption in caps_list:
        # Temporarily remove special tokens before tokenizing for word counting
        # This ensures word_tokenize only processes actual words, not parts of special tokens.
        temp_caption = original_caption.replace("<start>", "").replace("<end>", "").strip()
        word_counter.update(word_tokenize(temp_caption))

vocab_raw = [w for w, c in word_counter.items() if c >= 5]

word2idx = {}
idx2word = {}
idx = 0

# Explicitly add special tokens with specific indices
word2idx["<pad>"] = idx
idx2word[idx] = "<pad>"
idx += 1

word2idx["<start>"] = idx
idx2word[idx] = "<start>"
idx += 1

word2idx["<end>"] = idx
idx2word[idx] = "<end>"
idx += 1

# Add other words from the filtered vocab. The 'if w not in word2idx' check
# now prevents '<start>' and '<end>' from being added again if they happen to be in vocab_raw.
for w in vocab_raw:
    if w not in word2idx:
        word2idx[w] = idx
        idx2word[idx] = w
        idx += 1

vocab_size = len(word2idx)
print("Vocabulary size:", vocab_size)

# Save vocabulary to Google Drive
torch.save((word2idx, idx2word), VOCAB_PATH)
print("Vocabulary saved to Drive:", VOCAB_PATH)

Vocabulary size: 3004
Vocabulary saved to Drive: /content/drive/MyDrive/FYP_Models/Attention_Captioning/vocab.pt


In [10]:
def caption_to_seq(caption):
    tokens = word_tokenize(caption)
    return [word2idx.get(w, 0) for w in tokens]

In [11]:
def clean_caption(text):
    text = text.lower()
    text = re.sub(r"[^a-z<> ]", "", text)
    return text

In [12]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [13]:
resnet = models.resnet50(pretrained=True)
resnet = nn.Sequential(*list(resnet.children())[:-2])  # spatial map
resnet.eval().to(device)

for p in resnet.parameters():
    p.requires_grad = False



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:00<00:00, 135MB/s]


In [14]:
if os.path.exists(FEATURE_PATH):
    print("Loading spatial features from Drive...")
    features = torch.load(FEATURE_PATH)
else:
    print("Extracting spatial features (one-time)...")
    features = {}

    with torch.no_grad():
        for img_name in tqdm(captions.keys()):
            img_path = os.path.join(IMAGE_DIR, img_name)
            image = Image.open(img_path).convert("RGB")
            image = transform(image).unsqueeze(0).to(device)

            feat = resnet(image)                    # [1,2048,7,7]
            feat = feat.view(1, 2048, -1)           # [1,2048,49]
            feat = feat.permute(0, 2, 1).squeeze()  # [49,2048]

            features[img_name] = feat.cpu()

    torch.save(features, FEATURE_PATH)
    print("Spatial features saved to Drive")

Extracting spatial features (one-time)...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8091/8091 [03:14<00:00, 41.64it/s]


Spatial features saved to Drive


In [15]:
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.enc_att = nn.Linear(encoder_dim, attention_dim)
        self.dec_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        att1 = self.enc_att(encoder_out)
        att2 = self.dec_att(hidden).unsqueeze(1)
        att = self.full_att(torch.tanh(att1 + att2)).squeeze(2)
        alpha = torch.softmax(att, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha

In [16]:
class AttentionCaptionModel(nn.Module):
    def __init__(self, vocab_size, encoder_dim=2048,
                 embed_size=256, hidden_size=512, attention_dim=256):
        super().__init__()

        self.embedding = nn.Embedding(
            vocab_size, embed_size, padding_idx=0
        )

        self.attention = Attention(encoder_dim, hidden_size, attention_dim)

        self.lstm = nn.LSTMCell(embed_size + encoder_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

        self.dropout = nn.Dropout(0.5)

        self.init_h = nn.Linear(encoder_dim, hidden_size)
        self.init_c = nn.Linear(encoder_dim, hidden_size)

    def forward(self, encoder_out, captions, teacher_forcing_ratio=1.0):
        batch_size = captions.size(0)
        seq_len = captions.size(1)

        h = self.init_h(encoder_out.mean(dim=1))
        c = self.init_c(encoder_out.mean(dim=1))

        embeddings = self.embedding(captions)
        outputs = []

        word = embeddings[:, 0]

        for t in range(seq_len):
            context, _ = self.attention(encoder_out, h)
            lstm_input = torch.cat([word, context], dim=1)

            h, c = self.lstm(lstm_input, (h, c))
            out = self.fc(self.dropout(h))
            outputs.append(out)

            use_teacher = random.random() < teacher_forcing_ratio
            predicted = out.argmax(1)

            word = embeddings[:, t] if use_teacher else self.embedding(predicted)

        return torch.stack(outputs, dim=1)


In [17]:
model = AttentionCaptionModel(vocab_size).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=5e-5)

In [18]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch
import random

# -----------------------------
# Dataset class
# -----------------------------
class CaptionDataset(Dataset):
    def __init__(self, captions_dict, features_dict, word2idx):
        self.data = []
        for img_name, caps in captions_dict.items():
            for c in caps:
                seq = torch.tensor([word2idx.get(w, 0) for w in word_tokenize(c)])
                self.data.append((img_name, seq))
        self.features = features_dict

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, seq = self.data[idx]
        feature = self.features[img_name]  # [49, 2048]
        return feature, seq

# -----------------------------
# Collate function
# -----------------------------
def collate_fn(batch):
    features, sequences = zip(*batch)
    features = torch.stack([f for f in features]).to(device)  # [batch, 49, 2048]
    lengths = torch.tensor([len(s) for s in sequences])
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0).to(device)
    return features, sequences_padded, lengths

# -----------------------------
# Prepare dataset & dataloader
# -----------------------------
dataset = CaptionDataset(captions, features, word2idx)
batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)


# -----------------------------
# Training Loop
# -----------------------------
EPOCHS = 100

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    step_count = 0
    teacher_forcing_ratio = max(0.3, 1.0 - epoch * 0.01)

    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for features_batch, seq_batch, lengths in loop:
        batch_size, seq_len = seq_batch.shape

        # Inputs & targets
        inputs = seq_batch[:, :-1]
        targets = seq_batch[:, 1:]

        # Forward pass
        outputs = model(features_batch, inputs, teacher_forcing_ratio)  # [B, seq_len-1, vocab_size]

        # Compute loss
        loss = 0
        batch_size = outputs.size(0)

        for i in range(batch_size):
            length = lengths[i] - 1  # ignore <start>
            loss += criterion(
                outputs[i, :length],
                targets[i, :length]
            )

        loss = loss / batch_size

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()

        total_loss += loss.item()
        step_count += 1

        loop.set_postfix(avg_loss=total_loss/step_count)

    print(f"Epoch {epoch+1}/{EPOCHS} finished | Avg Loss: {total_loss/step_count:.4f}")


Epoch 1/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:50<00:00,  5.75it/s, avg_loss=5]


Epoch 1/100 finished | Avg Loss: 5.0026


Epoch 2/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=4.5]


Epoch 2/100 finished | Avg Loss: 4.5012


Epoch 3/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.26it/s, avg_loss=4.34]


Epoch 3/100 finished | Avg Loss: 4.3369


Epoch 4/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=4.23]


Epoch 4/100 finished | Avg Loss: 4.2324


Epoch 5/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=4.16]


Epoch 5/100 finished | Avg Loss: 4.1567


Epoch 6/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=4.09]


Epoch 6/100 finished | Avg Loss: 4.0943


Epoch 7/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=4.05]


Epoch 7/100 finished | Avg Loss: 4.0490


Epoch 8/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=4.01]


Epoch 8/100 finished | Avg Loss: 4.0095


Epoch 9/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=3.97]


Epoch 9/100 finished | Avg Loss: 3.9723


Epoch 10/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.95]


Epoch 10/100 finished | Avg Loss: 3.9463


Epoch 11/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.92]


Epoch 11/100 finished | Avg Loss: 3.9167


Epoch 12/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.89]


Epoch 12/100 finished | Avg Loss: 3.8937


Epoch 13/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.87]


Epoch 13/100 finished | Avg Loss: 3.8740


Epoch 14/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.85]


Epoch 14/100 finished | Avg Loss: 3.8535


Epoch 15/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.84]


Epoch 15/100 finished | Avg Loss: 3.8366


Epoch 16/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=3.82]


Epoch 16/100 finished | Avg Loss: 3.8216


Epoch 17/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.81]


Epoch 17/100 finished | Avg Loss: 3.8077


Epoch 18/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.79]


Epoch 18/100 finished | Avg Loss: 3.7889


Epoch 19/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.78]


Epoch 19/100 finished | Avg Loss: 3.7796


Epoch 20/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.76]


Epoch 20/100 finished | Avg Loss: 3.7648


Epoch 21/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.25it/s, avg_loss=3.75]


Epoch 21/100 finished | Avg Loss: 3.7526


Epoch 22/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.74]


Epoch 22/100 finished | Avg Loss: 3.7399


Epoch 23/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.73]


Epoch 23/100 finished | Avg Loss: 3.7341


Epoch 24/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.72]


Epoch 24/100 finished | Avg Loss: 3.7244


Epoch 25/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.72]


Epoch 25/100 finished | Avg Loss: 3.7199


Epoch 26/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.7]


Epoch 26/100 finished | Avg Loss: 3.7035


Epoch 27/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.69]


Epoch 27/100 finished | Avg Loss: 3.6936


Epoch 28/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.69]


Epoch 28/100 finished | Avg Loss: 3.6905


Epoch 29/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.69]


Epoch 29/100 finished | Avg Loss: 3.6862


Epoch 30/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.67]


Epoch 30/100 finished | Avg Loss: 3.6735


Epoch 31/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.67]


Epoch 31/100 finished | Avg Loss: 3.6664


Epoch 32/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.24it/s, avg_loss=3.67]


Epoch 32/100 finished | Avg Loss: 3.6655


Epoch 33/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.66]


Epoch 33/100 finished | Avg Loss: 3.6626


Epoch 34/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.65]


Epoch 34/100 finished | Avg Loss: 3.6540


Epoch 35/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.65]


Epoch 35/100 finished | Avg Loss: 3.6471


Epoch 36/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.65]


Epoch 36/100 finished | Avg Loss: 3.6452


Epoch 37/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.64]


Epoch 37/100 finished | Avg Loss: 3.6440


Epoch 38/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.65]


Epoch 38/100 finished | Avg Loss: 3.6461


Epoch 39/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.63]


Epoch 39/100 finished | Avg Loss: 3.6308


Epoch 40/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.63]


Epoch 40/100 finished | Avg Loss: 3.6309


Epoch 41/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.62]


Epoch 41/100 finished | Avg Loss: 3.6180


Epoch 42/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.23it/s, avg_loss=3.62]


Epoch 42/100 finished | Avg Loss: 3.6211


Epoch 43/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.62]


Epoch 43/100 finished | Avg Loss: 3.6189


Epoch 44/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.62]


Epoch 44/100 finished | Avg Loss: 3.6153


Epoch 45/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.62]


Epoch 45/100 finished | Avg Loss: 3.6156


Epoch 46/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.21it/s, avg_loss=3.61]


Epoch 46/100 finished | Avg Loss: 3.6115


Epoch 47/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.61]


Epoch 47/100 finished | Avg Loss: 3.6112


Epoch 48/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.22it/s, avg_loss=3.61]


Epoch 48/100 finished | Avg Loss: 3.6083


Epoch 49/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:41<00:00,  6.21it/s, avg_loss=3.61]


Epoch 49/100 finished | Avg Loss: 3.6131


Epoch 50/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.61]


Epoch 50/100 finished | Avg Loss: 3.6087


Epoch 51/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.61]


Epoch 51/100 finished | Avg Loss: 3.6052


Epoch 52/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.6]


Epoch 52/100 finished | Avg Loss: 3.6041


Epoch 53/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.6]


Epoch 53/100 finished | Avg Loss: 3.6039


Epoch 54/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.6]


Epoch 54/100 finished | Avg Loss: 3.6049


Epoch 55/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 55/100 finished | Avg Loss: 3.5997


Epoch 56/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 56/100 finished | Avg Loss: 3.6011


Epoch 57/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 57/100 finished | Avg Loss: 3.5988


Epoch 58/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.16it/s, avg_loss=3.6]


Epoch 58/100 finished | Avg Loss: 3.6025


Epoch 59/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.15it/s, avg_loss=3.6]


Epoch 59/100 finished | Avg Loss: 3.5996


Epoch 60/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.61]


Epoch 60/100 finished | Avg Loss: 3.6072


Epoch 61/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.6]


Epoch 61/100 finished | Avg Loss: 3.5987


Epoch 62/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 62/100 finished | Avg Loss: 3.6003


Epoch 63/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.6]


Epoch 63/100 finished | Avg Loss: 3.5974


Epoch 64/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 64/100 finished | Avg Loss: 3.5989


Epoch 65/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.6]


Epoch 65/100 finished | Avg Loss: 3.5995


Epoch 66/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.6]


Epoch 66/100 finished | Avg Loss: 3.6034


Epoch 67/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.6]


Epoch 67/100 finished | Avg Loss: 3.5970


Epoch 68/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.6]


Epoch 68/100 finished | Avg Loss: 3.5988


Epoch 69/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.6]


Epoch 69/100 finished | Avg Loss: 3.5997


Epoch 70/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.16it/s, avg_loss=3.59]


Epoch 70/100 finished | Avg Loss: 3.5941


Epoch 71/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.6]


Epoch 71/100 finished | Avg Loss: 3.6026


Epoch 72/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.6]


Epoch 72/100 finished | Avg Loss: 3.5955


Epoch 73/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.59]


Epoch 73/100 finished | Avg Loss: 3.5936


Epoch 74/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.59]


Epoch 74/100 finished | Avg Loss: 3.5884


Epoch 75/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.58]


Epoch 75/100 finished | Avg Loss: 3.5801


Epoch 76/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.58]


Epoch 76/100 finished | Avg Loss: 3.5838


Epoch 77/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.57]


Epoch 77/100 finished | Avg Loss: 3.5703


Epoch 78/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.57]


Epoch 78/100 finished | Avg Loss: 3.5682


Epoch 79/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.56]


Epoch 79/100 finished | Avg Loss: 3.5627


Epoch 80/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.56]


Epoch 80/100 finished | Avg Loss: 3.5580


Epoch 81/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.56]


Epoch 81/100 finished | Avg Loss: 3.5635


Epoch 82/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.54]


Epoch 82/100 finished | Avg Loss: 3.5443


Epoch 83/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.55]


Epoch 83/100 finished | Avg Loss: 3.5490


Epoch 84/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.55]


Epoch 84/100 finished | Avg Loss: 3.5457


Epoch 85/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.16it/s, avg_loss=3.54]


Epoch 85/100 finished | Avg Loss: 3.5391


Epoch 86/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.16it/s, avg_loss=3.53]


Epoch 86/100 finished | Avg Loss: 3.5306


Epoch 87/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.53]


Epoch 87/100 finished | Avg Loss: 3.5324


Epoch 88/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.53]


Epoch 88/100 finished | Avg Loss: 3.5297


Epoch 89/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.53]


Epoch 89/100 finished | Avg Loss: 3.5254


Epoch 90/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.53]


Epoch 90/100 finished | Avg Loss: 3.5266


Epoch 91/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.52]


Epoch 91/100 finished | Avg Loss: 3.5183


Epoch 92/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.51]


Epoch 92/100 finished | Avg Loss: 3.5095


Epoch 93/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.20it/s, avg_loss=3.51]


Epoch 93/100 finished | Avg Loss: 3.5082


Epoch 94/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.51]


Epoch 94/100 finished | Avg Loss: 3.5068


Epoch 95/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.5]


Epoch 95/100 finished | Avg Loss: 3.5026


Epoch 96/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.5]


Epoch 96/100 finished | Avg Loss: 3.5008


Epoch 97/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.18it/s, avg_loss=3.5]


Epoch 97/100 finished | Avg Loss: 3.4967


Epoch 98/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.19it/s, avg_loss=3.49]


Epoch 98/100 finished | Avg Loss: 3.4876


Epoch 99/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:42<00:00,  6.17it/s, avg_loss=3.48]


Epoch 99/100 finished | Avg Loss: 3.4833


Epoch 100/100: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 633/633 [01:43<00:00,  6.14it/s, avg_loss=3.48]

Epoch 100/100 finished | Avg Loss: 3.4838





In [19]:
torch.save(model.state_dict(), MODEL_PATH)
print("Model saved to Google Drive:", MODEL_PATH)

Model saved to Google Drive: /content/drive/MyDrive/FYP_Models/Attention_Captioning/resnet50_attention_model.pth


In [20]:
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()
print("Model loaded from Drive")

Model loaded from Drive


In [21]:
def generate_caption(image_path, max_len=30):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        feat = resnet(image)                 # [1,2048,7,7]
        feat = feat.view(1, 2048, -1)
        encoder_out = feat.permute(0, 2, 1)  # [1,49,2048]

        h = model.init_h(encoder_out.mean(dim=1))
        c = model.init_c(encoder_out.mean(dim=1))

        word = torch.tensor([word2idx["<start>"]]).to(device)
        result = []

        for _ in range(max_len):
            emb = model.embedding(word)
            context, _ = model.attention(encoder_out, h)

            lstm_input = torch.cat([emb, context], dim=1)
            h, c = model.lstm(lstm_input, (h, c))

            out = model.fc(h)
            predicted = out.argmax(1)

            token = idx2word[predicted.item()]
            if token == "<end>":
                break

            result.append(token)
            word = predicted

    return " ".join(result)

In [22]:
from google.colab import files

uploaded = files.upload()
image_path = list(uploaded.keys())[0]

caption = generate_caption(image_path)
print("üñºÔ∏è Generated Caption:")
print(caption)

Saving dog.png to dog.png
üñºÔ∏è Generated Caption:
start a a dog and white dog ball with ball ball . . . end end end end end end end end end end end end end end end end
