# CNN-LSTM Image Captioning Model

This notebook implements an image captioning system using CNN for feature extraction and LSTM for sequence generation. The model is trained on the Flickr30k dataset to generate descriptive captions for images.

## 📁 Dataset Overview

Exploring the Flickr30k dataset structure and loading the images and captions.


### Image Data Exploration

First, let's check how many images we have in our dataset directory.


In [None]:
import os

image_dir = "flickr30k_images"
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

print(f"Total images found: {len(image_files)}")


Total images found: 31783


### Caption Data Loading

Loading and exploring the caption data from the CSV file. Each image has multiple captions (usually 5 per image).


In [3]:
import os
import pandas as pd

csv_path = os.path.join("flickr30k_images", "results.csv")
df = pd.read_csv(csv_path, delimiter="|", header=None, names=["image", "caption_number", "caption"])


print(df.head())
print("Total captions:", len(df))
print("Unique images:", df['image'].nunique())



            image   caption_number  \
0      image_name   comment_number   
1  1000092795.jpg                0   
2  1000092795.jpg                1   
3  1000092795.jpg                2   
4  1000092795.jpg                3   

                                             caption  
0                                            comment  
1   Two young guys with shaggy hair look at their...  
2   Two young , White males are outside near many...  
3   Two men in green shirts are standing in a yard .  
4       A man in a blue shirt standing in a garden .  
Total captions: 158916
Unique images: 31784


### Data Preprocessing

Cleaning up the dataset by removing header rows and preparing the data for training.


In [None]:
df = df.drop(index=0).reset_index(drop=True)

print("Total captions after cleanup:", len(df))
print("Unique images after cleanup:", df['image'].nunique())


Total captions after cleanup: 158915
Unique images after cleanup: 31783


In [5]:
df.head()

Unnamed: 0,image,caption_number,caption
0,1000092795.jpg,0,Two young guys with shaggy hair look at their...
1,1000092795.jpg,1,"Two young , White males are outside near many..."
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .
4,1000092795.jpg,4,Two friends enjoy time spent together .


In [None]:
from collections import defaultdict

image_captions = defaultdict(list)

for _, row in df.iterrows():
    image_captions[row["image"]].append(row["caption"])

# Example check
print(image_captions["1000092795.jpg"])


[' Two young guys with shaggy hair look at their hands while hanging out in the yard .', ' Two young , White males are outside near many bushes .', ' Two men in green shirts are standing in a yard .', ' A man in a blue shirt standing in a garden .', ' Two friends enjoy time spent together .']


In [None]:
import re

def clean_caption(caption):
    if not isinstance(caption, str):
        return None
    caption = caption.lower()
    caption = re.sub(r"[^a-z ]", "", caption) 
    caption = caption.strip()
    return f"<start> {caption} <end>"

for img, caps in image_captions.items():
    cleaned = [clean_caption(c) for c in caps]
    image_captions[img] = [c for c in cleaned if c is not None]

# Example check
print(image_captions["1000092795.jpg"])


['<start> two young guys with shaggy hair look at their hands while hanging out in the yard <end>', '<start> two young  white males are outside near many bushes <end>', '<start> two men in green shirts are standing in a yard <end>', '<start> a man in a blue shirt standing in a garden <end>', '<start> two friends enjoy time spent together <end>']


In [8]:
missing_count = sum(1 for caps in image_captions.values() for c in caps if c is None)
print("Missing captions:", missing_count)


Missing captions: 0


In [9]:
# Total images
print("Total images:", len(image_captions))

# Check how many captions per image
caption_counts = [len(caps) for caps in image_captions.values()]
print("Min captions per image:", min(caption_counts))
print("Max captions per image:", max(caption_counts))

# Check if all are 5
all_five = all(count == 5 for count in caption_counts)
print("All images have exactly 5 captions?", all_five)


Total images: 31783
Min captions per image: 4
Max captions per image: 5
All images have exactly 5 captions? False


In [10]:
# Keep only images with 5 captions
image_captions = {img: caps for img, caps in image_captions.items() if len(caps) == 5}

print("Cleaned dataset size:", len(image_captions))  # should be slightly less than 31783


Cleaned dataset size: 31782


In [11]:
df.head(10)

Unnamed: 0,image,caption_number,caption
0,1000092795.jpg,0,Two young guys with shaggy hair look at their...
1,1000092795.jpg,1,"Two young , White males are outside near many..."
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .
4,1000092795.jpg,4,Two friends enjoy time spent together .
5,10002456.jpg,0,Several men in hard hats are operating a gian...
6,10002456.jpg,1,Workers look down from up above on a piece of...
7,10002456.jpg,2,Two men working on a machine wearing hard hats .
8,10002456.jpg,3,Four men on top of a tall structure .
9,10002456.jpg,4,Three men on a large rig .


In [12]:
df = df.dropna(subset=["caption"])

df["caption"] = df["caption"].apply(lambda x: "<start> " + x.lower() + " <end>")


## 📝 Vocabulary Construction

Building vocabulary from captions and preparing text preprocessing utilities.


In [None]:
from collections import Counter
import nltk

nltk.download("punkt", quiet=True)

class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def tokenizer(self, text):
        return nltk.tokenize.word_tokenize(text.lower())

    def build_vocabulary(self, sentence_list):
        frequencies = Counter()
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer(sentence):
                frequencies[word] += 1
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized = self.tokenizer(text)
        return [
            self.stoi.get(token, self.stoi["<unk>"])
            for token in tokenized
        ]


In [None]:
import nltk
nltk.download("punkt")
nltk.download("punkt_tab")


[nltk_data] Downloading package punkt to /home/sushi/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/sushi/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [None]:
captions = df["caption"].tolist()

vocab = Vocabulary(freq_threshold=5)
vocab.build_vocabulary(captions)

print("Vocab size:", len(vocab))


Vocab size: 7738


In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import os
import torchvision.transforms as transforms

class FlickrDataset(Dataset):
    def __init__(self, df, image_dir, vocab, transform=None):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.vocab = vocab
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        caption = self.df.loc[idx, "caption"]
        img_id = self.df.loc[idx, "image"]
        img_path = os.path.join(self.image_dir, img_id)

        # Load image
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Numericalize caption
        tokens = [self.vocab.stoi["<start>"]] + \
                 self.vocab.numericalize(caption) + \
                 [self.vocab.stoi["<end>"]]

        return image, torch.tensor(tokens)

# Pad captions in batch
def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images, 0)
    captions = pad_sequence(captions, batch_first=True, padding_value=0)
    return images, captions

# Data transforms
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225))
])


## 🏗️ CNN-LSTM Architecture

Defining the CNN-LSTM model that combines visual features with sequence generation.


In [17]:
import torchvision.models as models
import torch.nn as nn

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        modules = list(resnet.children())[:-1]  # remove fc
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return self.bn(self.linear(features))


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        embeddings = self.dropout(self.embed(captions[:, :-1]))
        inputs = torch.cat((features.unsqueeze(1), embeddings), 1)
        hiddens, _ = self.lstm(inputs)
        return self.linear(hiddens)


class ImageCaptioningModel(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs


## ⚙️ Training Setup

Setting up data loaders, loss functions, and training parameters.


In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Params
embed_size = 256
hidden_size = 512
vocab_size = len(vocab)
num_epochs = 5
batch_size = 64
lr = 3e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data
dataset = FlickrDataset(df, "flickr30k_images", vocab, transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Model
model = ImageCaptioningModel(embed_size, hidden_size, vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore <pad>
optimizer = optim.Adam(model.parameters(), lr=lr)


## 🚀 Model Training

Training the CNN-LSTM model with progress tracking and validation.


In [None]:

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for i, (images, captions) in enumerate(loader):
        images, captions = images.to(device), captions.to(device)

        # Feed everything except the last token as input
        inputs = captions[:, :-1]      # <start> ... <word_n-1>
        targets = captions[:, 1:]      # <word_1> ... <end>

        outputs = model(images, inputs)  # (batch, seq_len-1, vocab_size)

        # Compute loss
        loss = criterion(
            outputs.reshape(-1, vocab_size),
            targets.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(loader)}], Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Avg Loss: {avg_loss:.4f}")

    checkpoint = {
        "epoch": epoch + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": avg_loss
    }
    torch.save(checkpoint, f"caption_model_epoch{epoch+1}.pth")


Epoch [1/5], Step [0/2484], Loss: 8.9638
Epoch [1/5], Step [100/2484], Loss: 3.9689
Epoch [1/5], Step [200/2484], Loss: 3.6687
Epoch [1/5], Step [300/2484], Loss: 3.4072
Epoch [1/5], Step [400/2484], Loss: 3.5085
Epoch [1/5], Step [500/2484], Loss: 3.4992
Epoch [1/5], Step [600/2484], Loss: 3.3269
Epoch [1/5], Step [700/2484], Loss: 3.3737
Epoch [1/5], Step [800/2484], Loss: 3.2639
Epoch [1/5], Step [900/2484], Loss: 3.3896
Epoch [1/5], Step [1000/2484], Loss: 2.9495
Epoch [1/5], Step [1100/2484], Loss: 3.2429
Epoch [1/5], Step [1200/2484], Loss: 3.1634
Epoch [1/5], Step [1300/2484], Loss: 3.1939
Epoch [1/5], Step [1400/2484], Loss: 3.0893
Epoch [1/5], Step [1500/2484], Loss: 3.1787
Epoch [1/5], Step [1600/2484], Loss: 2.9403
Epoch [1/5], Step [1700/2484], Loss: 3.1467
Epoch [1/5], Step [1800/2484], Loss: 3.1561
Epoch [1/5], Step [1900/2484], Loss: 3.1182
Epoch [1/5], Step [2000/2484], Loss: 3.1382
Epoch [1/5], Step [2100/2484], Loss: 3.3731
Epoch [1/5], Step [2200/2484], Loss: 2.9499


In [51]:
torch.save(model.state_dict(), "caption_model_final.pth")


In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

In [None]:
import torch
checkpoint = torch.load("caption_model_epoch5.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])


  checkpoint = torch.load("caption_model_epoch5.pth", map_location="cpu")


<All keys matched successfully>

In [None]:
def fixed_resume_training(checkpoint_path, model, vocab, dataset, num_additional_epochs=10):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    
    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4, betas=(0.9, 0.999))
    
    print("Starting with fresh optimizer state (better for LR change)")
    
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
    
    loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn, 
                       num_workers=4, pin_memory=True)
    
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<pad>"])
    
    start_epoch = checkpoint["epoch"]
    best_loss = checkpoint["loss"]
    
    print(f"FIXED training from epoch {start_epoch}")
    print(f"New learning rate: 2e-5 (much lower!)")
    print(f"New batch size: 32 (more stable)")
    print(f"Target to beat: {best_loss:.4f}")
    
    # Training loop
    for epoch in range(start_epoch, start_epoch + num_additional_epochs):
        model.train()
        total_loss = 0
        num_batches = len(loader)
        
        for i, (images, captions) in enumerate(loader):
            images, captions = images.to(device), captions.to(device)
            
            inputs = captions[:, :-1]
            targets = captions[:, 1:]
            
            outputs = model(images, inputs)
            loss = criterion(
                outputs.reshape(-1, len(vocab)),
                targets.reshape(-1)
            )
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Lower clipping
            optimizer.step()
            
            total_loss += loss.item()
            
            if i % 100 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"Epoch [{epoch+1}], Step [{i}/{num_batches}], "
                      f"Loss: {loss.item():.4f}, LR: {current_lr:.7f}")
        
        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch+1}], Avg Loss: {avg_loss:.4f}")
        
        scheduler.step(avg_loss)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
            improvement = ((checkpoint["loss"] - avg_loss) / checkpoint["loss"]) * 100
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "loss": avg_loss
            }, f"caption_model_epoch{epoch+1}_fixed.pth")
            print(f"✅ IMPROVEMENT! Saved model with loss: {avg_loss:.4f} ({improvement:.1f}% better)")
        else:
            print(f"⚠️ No improvement (current: {avg_loss:.4f} vs best: {best_loss:.4f})")

fixed_resume_training("caption_model_epoch5.pth", model, vocab, dataset, num_additional_epochs=10)


Starting with fresh optimizer state (better for LR change)
FIXED training from epoch 5
New learning rate: 2e-5 (much lower!)
New batch size: 32 (more stable)
Target to beat: 2.5934


  checkpoint = torch.load(checkpoint_path, map_location=device)


Epoch [6], Step [0/4967], Loss: 2.7071, LR: 0.0000200
Epoch [6], Step [100/4967], Loss: 2.4724, LR: 0.0000200


KeyboardInterrupt: 

In [None]:
from sklearn.model_selection import train_test_split

unique_images = df['image'].unique()

train_val_images, test_images = train_test_split(unique_images, test_size=0.1, random_state=42)

train_images, val_images = train_test_split(train_val_images, test_size=0.1, random_state=42)
train_df = df[df['image'].isin(train_images)].reset_index(drop=True)
val_df = df[df['image'].isin(val_images)].reset_index(drop=True)
test_df = df[df['image'].isin(test_images)].reset_index(drop=True)

print(f"Train samples: {len(train_df)}")
print(f"Validation samples: {len(val_df)}")
print(f"Test samples: {len(test_df)}")


Train samples: 128714
Validation samples: 14305
Test samples: 15895


In [43]:
train_dataset = FlickrDataset(train_df, image_dir, vocab, transform)
val_dataset = FlickrDataset(val_df, image_dir, vocab, transform)
test_dataset = FlickrDataset(test_df, image_dir, vocab, transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageCaptioningModel(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

In [None]:
checkpoint = torch.load("caption_model_epoch15_fixed.pth", map_location='cpu')  # or whatever your checkpoint is named
model.load_state_dict(checkpoint["model_state"])

  checkpoint = torch.load("caption_model_epoch15_fixed.pth", map_location='cpu')  # or whatever your checkpoint is named


<All keys matched successfully>

In [None]:
optimizer = optim.Adam(model.parameters(), lr=5e-5) 
try:
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    print("✅ Loaded existing optimizer state")
except:
    print("⚠️ Starting with fresh optimizer state")

from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=3, gamma=0.7)  # Reduce LR by 30% every 3 epochs

criterion = nn.CrossEntropyLoss(ignore_index=0)


✅ Loaded existing optimizer state


In [None]:
start_epoch = checkpoint.get("epoch", 15)  
additional_epochs = 14 
total_epochs = start_epoch + additional_epochs

print(f"Resuming training from epoch {start_epoch}")
print(f"Will train until epoch {total_epochs}")

Resuming training from epoch 16
Will train until epoch 30


In [None]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            self.counter = 0 
        else:
            self.counter += 1
            print(f"⚠️ EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                print("🛑 Early stopping triggered!")
                self.early_stop = True


In [None]:
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

for epoch in range(start_epoch, total_epochs):
    model.train()
    total_train_loss = 0
    
    for batch_idx, (images, captions) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        
        inputs = captions[:, :-1]
        targets = captions[:, 1:]
        
        outputs = model(images, inputs)
        loss = criterion(outputs.reshape(-1, len(vocab)), targets.reshape(-1))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
        
        if batch_idx % 200 == 0:
            print(f"Epoch [{epoch+1}/{total_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    avg_train_loss = total_train_loss / len(train_loader)
    
    model.eval()
    total_val_loss = 0
    
    with torch.no_grad():
        for images, captions in val_loader:
            images = images.to(device)
            captions = captions.to(device)
            
            inputs = captions[:, :-1]
            targets = captions[:, 1:]
            
            outputs = model(images, inputs)
            loss = criterion(outputs.reshape(-1, len(vocab)), targets.reshape(-1))
            
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss / len(val_loader)
    
    scheduler.step()
    
    print(f"Epoch [{epoch+1}/{total_epochs}]")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss: {avg_val_loss:.4f}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
    
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        print("🛑 Training stopped early to prevent overfitting!")
        break
    
    checkpoint = {
        "epoch": epoch + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
    }
    torch.save(checkpoint, f"continued_model_epoch{epoch+1}.pth")
    print(f"  💾 Saved checkpoint")
    
    print("-" * 50)

print("✅ Training completed!")


Epoch [17/30], Step [0/2012], Loss: 2.5145, LR: 0.000020
Epoch [17/30], Step [200/2012], Loss: 2.5323, LR: 0.000020
Epoch [17/30], Step [400/2012], Loss: 2.4555, LR: 0.000020
Epoch [17/30], Step [600/2012], Loss: 2.5598, LR: 0.000020
Epoch [17/30], Step [800/2012], Loss: 2.5073, LR: 0.000020
Epoch [17/30], Step [1000/2012], Loss: 2.5879, LR: 0.000020
Epoch [17/30], Step [1200/2012], Loss: 2.4503, LR: 0.000020
Epoch [17/30], Step [1400/2012], Loss: 2.4339, LR: 0.000020
Epoch [17/30], Step [1600/2012], Loss: 2.4911, LR: 0.000020
Epoch [17/30], Step [1800/2012], Loss: 2.4284, LR: 0.000020
Epoch [17/30], Step [2000/2012], Loss: 2.6117, LR: 0.000020
Epoch [17/30]
  Train Loss: 2.4805
  Val Loss: 2.4493
  Learning Rate: 0.000020
  💾 Saved checkpoint
--------------------------------------------------
Epoch [18/30], Step [0/2012], Loss: 2.5959, LR: 0.000020
Epoch [18/30], Step [200/2012], Loss: 2.4939, LR: 0.000020
Epoch [18/30], Step [400/2012], Loss: 2.3293, LR: 0.000020
Epoch [18/30], Step 

In [None]:
model.eval()

test_images = ["1000092795.jpg", "10002456.jpg"]

print("🧪 TESTING FINAL MODEL QUALITY")
print("=" * 50)

for img_name in test_images:
    img_path = f"flickr30k_images/{img_name}"
    
    caption = generate_caption_beam_search(img_path, model, vocab, device, beam_width=5)
    
    print(f"📷 Image: {img_name}")
    print(f"🤖 Generated Caption: {caption}")
    print("-" * 30)


🧪 TESTING FINAL MODEL QUALITY
📷 Image: 1000092795.jpg
🤖 Generated Caption: A man in a blue shirt is a a in . of.
------------------------------
📷 Image: 10002456.jpg
🤖 Generated Caption: Two men are working on a.
------------------------------


## 📊 Model Evaluation

Testing the trained model and computing evaluation metrics like BLEU scores.


In [None]:
def sanitize_caption(caption_words):
    import re
    
    if isinstance(caption_words, list):
        caption = " ".join(caption_words)
    else:
        caption = str(caption_words)
    
    caption = re.sub(r'\b(start|end)\b', '', caption, flags=re.IGNORECASE)
    
    caption = re.sub(r'[<>]', '', caption)  # Remove < >
    caption = re.sub(r'\s+', ' ', caption)  # Multiple spaces to single
    caption = re.sub(r'^\W+|\W+$', '', caption)  # Leading/trailing punctuation
    
    caption = caption.strip()
    if caption:
        caption = caption[0].upper() + caption[1:] if len(caption) > 1 else caption.upper()
        if not caption.endswith('.'):
            caption += '.'
    
    return caption if caption else "No caption generated."


In [None]:
def generate_better_caption(image_path, model, vocab, device, max_len=15, beam_width=5):
    from PIL import Image
    import torch
    
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    
    model.eval()
    with torch.no_grad():
        features = model.encoder(image)
        
        sequences = [([vocab.stoi["<start>"]], 0.0)]
        
        for step in range(max_len):
            all_candidates = []
            
            for seq, score in sequences:
                if seq[-1] == vocab.stoi["<end>"]:
                    all_candidates.append((seq, score))
                    continue
                
                input_seq = torch.tensor([seq]).to(device)
                
                outputs = model.decoder(features, input_seq)
                probs = torch.softmax(outputs[:, -1, :], dim=-1)
                
                top_probs, top_indices = torch.topk(probs, beam_width)
                
                for prob, idx in zip(top_probs[0], top_indices[0]):
                    candidate_seq = seq + [idx.item()]
                    candidate_score = score - torch.log(prob).item()
                    all_candidates.append((candidate_seq, candidate_score))
            
            sequences = sorted(all_candidates, key=lambda x: x[1])[:beam_width]
            
            if all(seq[0][-1] == vocab.stoi["<end>"] for seq in sequences):
                break
        
        best_sequence = sequences[0][0]
        
        words = [vocab.itos[idx] for idx in best_sequence 
                if idx in vocab.itos and vocab.itos[idx] not in ["<start>", "<end>", "<pad>"]]
        
        # Join and clean
        caption = " ".join(words)
        return sanitize_caption(caption)


In [None]:
print("🔧 TESTING IMPROVED CAPTION GENERATION")
print("=" * 50)

test_images = ["2878190821.jpg", "617999370.jpg", "4557307607.jpg"]

for img_name in test_images:
    img_path = f"flickr30k_images/{img_name}"
    
    # Original method
    old_caption = generate_caption_beam_search(img_path, model, vocab, device, beam_width=5)
    
    # Improved method
    new_caption = generate_better_caption(img_path, model, vocab, device, beam_width=5)
    
    print(f"\n📷 {img_name}")
    print(f"❌ Old: {old_caption}")
    print(f"✅ New: {new_caption}")
    print("-" * 40)


🔧 TESTING IMPROVED CAPTION GENERATION

📷 2878190821.jpg
❌ Old: A little boy in a blue shirt is a a in.
✅ New: A little boy in a blue shirt and a shorts hat a.
----------------------------------------

📷 617999370.jpg
❌ Old: A man in a blue shirt is a a in in of of.
✅ New: A man in a blue shirt is a a in on of.
----------------------------------------

📷 4557307607.jpg
❌ Old: A group of people are sitting around a table.
✅ New: A group of people are sitting around a table . a.
----------------------------------------


In [None]:
special_tokens = ['<start>', '<end>', '<pad>', '<unk>']
for token in special_tokens:
    if token in vocab.stoi:
        print(f"✅ {token}: index {vocab.stoi[token]}")
    else:
        print(f"❌ {token}: missing!")


✅ <start>: index 1
✅ <end>: index 2
✅ <pad>: index 0
✅ <unk>: index 3


In [None]:
def build_improved_vocab(df, min_freq=3, max_vocab=10000):
    from collections import Counter
    import re
    
    all_words = []
    for caption in df['caption']:
        # Clean caption
        clean_caption = re.sub(r'[^\w\s]', ' ', caption.lower())
        words = clean_caption.split()
        all_words.extend(words)
    
    # Count frequencies
    word_freq = Counter(all_words)
    
    # Create vocabulary
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')
    
    common_words = word_freq.most_common(max_vocab - 4)
    for word, freq in common_words:
        if freq >= min_freq and len(word) > 1:  # Skip single characters
            vocab.add_word(word)
    
    return vocab

In [34]:
import pickle

with open("vocab.pkl", "wb") as f:
    pickle.dump(vocab, f)


In [None]:
def fix_repetitive_caption(caption):
    """Remove repetitive words and fix common patterns"""
    import re
    from collections import defaultdict
    
    # Split into words
    words = caption.strip().split()
    if not words:
        return "No caption available."
    
    clean_words = [words[0]]
    for word in words[1:]:
        if word != clean_words[-1]:
            clean_words.append(word)
    
    final_words = []
    for i, word in enumerate(clean_words):
        recent_context = final_words[-3:] if len(final_words) >= 3 else final_words
        if recent_context.count(word) < 2:  # Allow max 2 occurrences in recent context
            final_words.append(word)
    
    text = " ".join(final_words)
    
    # Pattern fixes
    replacements = [
        (r'\ba a\b', 'a'),
        (r'\bin in\b', 'in'),  
        (r'\bof of\b', 'of'),
        (r'\bis a in\b', 'is standing in'),
        (r'\bis a a\b', 'is wearing a'),
        (r'\bare a on\b', 'are sitting on'),
        (r'\band a shorts hat a\b', 'wearing shorts and a hat'),
        (r'\s+\.\s*a\s*$', '.'),  # Remove trailing ". a"
        (r'\s+a\s*$', '.'),       # Remove trailing "a"
    ]
    
    for pattern, replacement in replacements:
        text = re.sub(pattern, replacement, text)
    
    text = re.sub(r'\s+', ' ', text).strip()
    
    if text:
        text = text[0].upper() + text[1:] if len(text) > 1 else text.upper()
        if not text.endswith('.'):
            text += '.'
    
    return text


In [None]:
def beam_search_with_penalty(model, image_tensor, vocab, device, beam_width=5, max_len=15, repetition_penalty=1.2):
    
    model.eval()
    with torch.no_grad():
        features = model.encoder(image_tensor)
        
        sequences = [([vocab.stoi["<start>"]], 0.0)]
        
        for step in range(max_len):
            all_candidates = []
            
            for seq, score in sequences:
                if seq[-1] == vocab.stoi["<end>"]:
                    all_candidates.append((seq, score))
                    continue
                
                input_seq = torch.tensor([seq]).to(device)
                outputs = model.decoder(features, input_seq)
                log_probs = torch.log_softmax(outputs[:, -1, :], dim=-1)
                
                for word_idx in set(seq[1:]): 
                    if word_idx in vocab.itos:
                        log_probs[0, word_idx] /= repetition_penalty
                
                top_log_probs, top_indices = torch.topk(log_probs, beam_width)
                
                for log_prob, idx in zip(top_log_probs[0], top_indices[0]):
                    candidate_seq = seq + [idx.item()]
                    candidate_score = score - log_prob.item()
                    all_candidates.append((candidate_seq, candidate_score))
            
            sequences = sorted(all_candidates, key=lambda x: x[1])[:beam_width]
            
            if all(seq[0][-1] == vocab.stoi["<end>"] for seq in sequences):
                break
        
        best_sequence = sequences[0][0]
        
        words = []
        for idx in best_sequence[1:-1]:
            if idx in vocab.itos:
                words.append(vocab.itos[idx])
        
        caption = " ".join(words)
        return fix_repetitive_caption(caption)


In [None]:
def generate_final_caption(image_path, model, vocab, device):
    from PIL import Image
    
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    
    caption = beam_search_with_penalty(
        model, image, vocab, device, 
        beam_width=5, max_len=15, repetition_penalty=1.3
    )
    
    return caption


### Sample Results

Visualizing some example images with their generated captions.


In [None]:
print("🔧 TESTING FINAL IMPROVED CAPTION GENERATION")
print("=" * 60)

test_images = ["2878190821.jpg", "617999370.jpg", "4557307607.jpg"]

for img_name in test_images:
    img_path = f"flickr30k_images/{img_name}"
    
    # current method
    old_caption = generate_caption_beam_search(img_path, model, vocab, device, beam_width=5)
    
    # New improved method
    final_caption = generate_final_caption(img_path, model, vocab, device)
    
    print(f"\n📷 {img_name}")
    print(f"❌ Current: {old_caption}")
    print(f"✅ Fixed: {final_caption}")
    print("-" * 50)


🔧 TESTING FINAL IMPROVED CAPTION GENERATION

📷 2878190821.jpg
❌ Current: A little boy in a blue shirt is a a in.
✅ Fixed: < start > a young boy in a blue shirt is standing in.
--------------------------------------------------

📷 617999370.jpg
❌ Current: A man in a blue shirt is a a in in of of.
✅ Fixed: < start > a man in a blue shirt is standing in of.
--------------------------------------------------

📷 4557307607.jpg
❌ Current: A group of people are sitting around a table.
✅ Fixed: < start > a group of people sitting a . end.
--------------------------------------------------
