In [None]:
import random
import os, pickle, torch, torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split
from train import Vocabulary, ImageCaptioner, greedy_search, load_captions, compute_bleu4

In [None]:
def find_image_dir(base_input='/kaggle/input', min_images=1000, extensions=('.jpg', '.jpeg')):
    """Return the first directory under ``base_input`` that holds many image files.

    Walks ``base_input`` top-down and returns the first directory containing
    more than ``min_images`` files whose lower-cased name ends with one of
    ``extensions``. The default extensions are the same ones ``FlickrDataset``
    loads. Subdirectories are visited in sorted order, so the result does not
    depend on filesystem listing order.

    Args:
        base_input: Root to search (defaults to the common Kaggle input root).
        min_images: A directory qualifies only with strictly more image files than this.
        extensions: File-name suffixes counted as images (matched case-insensitively).

    Returns:
        The matching directory path, or None if nothing qualifies (including
        when ``base_input`` does not exist).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    for root, dirs, files in os.walk(base_input):
        # os.walk honours in-place edits of `dirs`; sorting fixes the visit order.
        dirs.sort()
        n_images = sum(1 for name in files if name.lower().endswith(suffixes))
        if n_images > min_images:
            return root
    return None
# Locate the Flickr30k images once; fail fast if the dataset is not attached.
IMAGE_DIR = find_image_dir()
OUTPUT_FILE = 'flickr30k_features.pkl'
if not IMAGE_DIR:
    raise FileNotFoundError("Could not find the Flickr30k image directory. Please ensure the dataset is added to the notebook.")
print(f" Found images at: {IMAGE_DIR}")
# --- THE DATASET CLASS ---
class FlickrDataset(Dataset):
    """Image-only dataset over every JPEG file directly inside ``img_dir``.

    Each item is ``(transform(rgb_image), file_name)``. The file name is later
    used as the key of the saved feature dictionary.
    """

    # Matched case-insensitively against file names.
    EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, img_dir, transform):
        # Sorted so the item order (and hence batch composition) does not depend
        # on the arbitrary order os.listdir returns.
        self.img_names = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self.EXTENSIONS)
        )
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, name)
        # The context manager closes the underlying file right away instead of
        # leaving an open descriptor until garbage collection; convert() returns
        # an independent image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), name
# --- FEATURE EXTRACTION PIPELINE ---
# ResNet-50 without its final classification layer yields one pooled feature
# vector per image; these are pickled as {file_name: numpy_array}.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model = nn.Sequential(*list(model.children())[:-1])  # drop the fc head, keep global avg-pool
model = nn.DataParallel(model).to(device)
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
dataset = FlickrDataset(IMAGE_DIR, transform)
# Pinned host memory speeds up host->GPU copies; it is pointless on CPU.
loader = DataLoader(dataset, batch_size=128, num_workers=4, pin_memory=(device.type == "cuda"))
features_dict = {}
with torch.no_grad():
    for imgs, names in tqdm(loader, desc="Extracting Features"):
        feats = model(imgs.to(device, non_blocking=True)).flatten(1)
        # One device->host transfer per batch instead of one per image.
        feats_np = feats.cpu().numpy()
        for name, feat in zip(names, feats_np):
            features_dict[name] = feat
with open(OUTPUT_FILE, 'wb') as f:
    pickle.dump(features_dict, f)
print(f"Success! {len(features_dict)} images processed and saved to {OUTPUT_FILE}")

In [None]:
class Vocabulary:
    """Word <-> index mapping for captions with four reserved special tokens.

    Indices 0-3 are always ``<pad>``, ``<start>``, ``<end>``, ``<unk>``.
    :meth:`build` appends every other word, in first-seen order, once it occurs
    at least ``min_freq`` times.
    """

    PAD, START, END, UNK = "<pad>", "<start>", "<end>", "<unk>"

    def __init__(self, min_freq=2):
        self.word2idx = {self.PAD: 0, self.START: 1, self.END: 2, self.UNK: 3}
        self.idx2word = {v: k for k, v in self.word2idx.items()}
        self.min_freq = min_freq

    def _tokenize(self, text):
        """Lower-case ``text`` and split it on whitespace."""
        return text.lower().strip().split()

    def build(self, captions):
        """Add words from ``captions`` that occur at least ``min_freq`` times.

        Words already in the vocabulary keep their index; calling this again
        only appends new words.
        """
        count = Counter()
        for cap in captions:
            count.update(self._tokenize(cap))
        for word, freq in count.items():
            if freq >= self.min_freq and word not in self.word2idx:
                i = len(self.word2idx)
                self.word2idx[word] = i
                self.idx2word[i] = word

    def encode(self, caption, add_special=True):
        """Map ``caption`` to token ids; unknown words become ``<unk>``.

        With ``add_special`` the sequence is wrapped in ``<start>`` ... ``<end>``.
        """
        out = [self.word2idx[self.START]] if add_special else []
        for w in self._tokenize(caption):
            out.append(self.word2idx.get(w, self.word2idx[self.UNK]))
        if add_special:
            out.append(self.word2idx[self.END])
        return out

    def decode(self, ids):
        """Join the words for ``ids``, dropping ``<pad>``, ``<start>`` and ``<end>``.

        Ids missing from the vocabulary are rendered as ``<unk>``.
        """
        # Derived from the named tokens rather than hard-coding their indices.
        skip = {self.word2idx[self.PAD], self.word2idx[self.START], self.word2idx[self.END]}
        return " ".join(self.idx2word.get(i, self.UNK) for i in ids if i not in skip)


def load_captions(path):
    """Parse an ``image,caption`` text file into ``(image_name, caption)`` pairs.

    Each line is split on its first comma. Surrounding quotes and whitespace
    are stripped from the caption. Lines without a comma, or with an empty
    caption, are skipped.

    The first line is treated as a header when its first column mentions
    "image" and is not itself an image file name.

    Args:
        path: Path to the captions file (UTF-8, with or without a BOM).

    Returns:
        List of ``(image_name, caption)`` tuples; empty for an empty file.
    """
    # utf-8-sig strips a leading BOM so it cannot leak into the first image name.
    with open(path, "r", encoding="utf-8-sig") as f:
        lines = f.readlines()
    if not lines:
        return []
    # Only the first column decides the header; checking the whole line would
    # also discard a real first caption that merely contains the word "image".
    first_field = lines[0].split(",", 1)[0].strip().lower()
    has_header = "image" in first_field and not first_field.endswith((".jpg", ".jpeg", ".png"))
    pairs = []
    for line in lines[1 if has_header else 0:]:
        if "," not in line:
            continue
        img, cap = line.strip().split(",", 1)
        cap = cap.strip('"').strip()
        if cap:
            pairs.append((img.strip(), cap))
    return pairs


class Encoder(nn.Module):
    """Project a pre-extracted image feature vector to the decoder hidden size.

    One linear layer followed by dropout. ``ImageCaptioner`` uses the output as
    the initial LSTM hidden state.
    """

    def __init__(self, input_dim=2048, hidden=512, dropout=0.5):
        super().__init__()
        # Attribute names become state_dict keys, so checkpoints depend on them.
        self.fc = nn.Linear(input_dim, hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        projected = self.fc(x)
        return self.dropout(projected)


class ImageCaptioner(nn.Module):
    """LSTM caption decoder conditioned on a pre-extracted image feature.

    The encoded image feature initialises the LSTM hidden state, and the cell
    state starts at zero. Training uses teacher forcing: ``forward`` feeds
    ``caps[:, :-1]`` and returns logits meant to be scored against ``caps[:, 1:]``.

    Args:
        vocab_size: Number of output tokens; index 0 is the padding index.
        hidden: LSTM hidden size (also the encoder output size).
        embed: Word-embedding size.
        dropout: Dropout rate for the encoder, the embeddings and the pre-logit layer.
        feat_dim: Size of the input image feature vector.
    """

    def __init__(self, vocab_size, hidden=1024, embed=512, dropout=0.5, feat_dim=2048):
        super().__init__()
        # Forward `dropout` so a non-default rate also reaches the encoder.
        self.encoder = Encoder(feat_dim, hidden, dropout)
        self.embed = nn.Embedding(vocab_size, embed, padding_idx=0)
        self.embed_dropout = nn.Dropout(dropout)
        # Single-layer LSTM: inter-layer dropout would be a no-op, so none is set.
        self.lstm = nn.LSTM(embed, hidden, batch_first=True)
        self.fc_dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden, vocab_size)

    def forward(self, feats, caps):
        """Return logits for next-token prediction.

        Args:
            feats: Image features, presumably ``(batch, feat_dim)`` — TODO confirm.
            caps: Token ids ``(batch, seq_len)`` including ``<start>``/``<end>``.

        Returns:
            Logits ``(batch, seq_len - 1, vocab_size)``.
        """
        h = self.encoder(feats).unsqueeze(0)  # (1, batch, hidden) for a 1-layer LSTM
        c = h.new_zeros(h.size())
        emb = self.embed_dropout(self.embed(caps[:, :-1]))
        out, _ = self.lstm(emb, (h, c))
        out = self.fc_dropout(out)
        return self.fc(out)


def greedy_search(model, feat, vocab, max_len=50, device="cpu"):
    """Decode one caption by always picking the most likely next token.

    Switches ``model`` to eval mode and leaves it there.

    Args:
        model: A captioner exposing ``encoder``, ``embed``, ``lstm`` and ``fc``.
        feat: One image feature tensor (no batch dimension).
        vocab: Vocabulary providing ``word2idx``, ``START``/``END`` and ``decode``.
        max_len: Maximum number of tokens, counting ``<start>``.
        device: Device the computation runs on.

    Returns:
        ``vocab.decode`` of the generated ids. Decoding stops early at ``<end>``.
    """
    model.eval()
    start_id = vocab.word2idx[vocab.START]
    end_id = vocab.word2idx[vocab.END]
    tokens = [start_id]
    with torch.no_grad():
        # The encoder runs inside no_grad too, so no autograd graph is recorded.
        feat = feat.to(device).unsqueeze(0)
        h = model.encoder(feat).unsqueeze(0)
        c = torch.zeros_like(h)
        for _ in range(max_len - 1):
            x = torch.tensor([[tokens[-1]]], dtype=torch.long, device=device)
            out, (h, c) = model.lstm(model.embed(x), (h, c))
            next_id = model.fc(out.squeeze(1)).argmax(dim=-1).item()
            tokens.append(next_id)
            if next_id == end_id:
                break
    return vocab.decode(tokens)


def compute_bleu4(refs, hyps):
    """Corpus-level BLEU-4 (uniform 1-4-gram weights, nltk smoothing method 1).

    Args:
        refs: One entry per hypothesis: either a single reference caption
            string, or a list of reference caption strings for that hypothesis.
        hyps: Hypothesis caption strings, whitespace-tokenised like ``refs``.

    Returns:
        BLEU score as a float. Returns 0.0 for empty input, and 0.0 with a
        RuntimeWarning when nltk is not installed.
    """
    try:
        from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
    except ImportError:
        import warnings
        # A silent 0.0 looks like a real (terrible) score; make the cause visible.
        warnings.warn("nltk is not installed; compute_bleu4 returns 0.0.", RuntimeWarning)
        return 0.0
    if not hyps:
        return 0.0
    tokenized_refs = [
        [r.split()] if isinstance(r, str) else [alt.split() for alt in r]
        for r in refs
    ]
    tokenized_hyps = [h.split() for h in hyps]
    return corpus_bleu(tokenized_refs, tokenized_hyps, weights=(0.25, 0.25, 0.25, 0.25),
                       smoothing_function=SmoothingFunction().method1)


class CaptionDataset(Dataset):
    """``(image_feature, caption_ids)`` samples for training the captioner.

    Pairs whose image has no stored feature are dropped at construction time.
    Captions are encoded lazily in ``__getitem__`` with the vocabulary given
    here. Each sequence is truncated to ``max_len`` (still ending in ``<end>``)
    and right-padded with ``<pad>``.
    """

    def __init__(self, feat_path, pairs, vocab, max_len=50):
        # NOTE(security): pickle.load can execute arbitrary code from the file;
        # only load feature files produced by this notebook or another trusted source.
        with open(feat_path, "rb") as f:
            self.feats = pickle.load(f)
        self.pairs = [(img, cap) for img, cap in pairs if img in self.feats]
        self.vocab = vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        img, cap = self.pairs[i]
        pad_id = self.vocab.word2idx[self.vocab.PAD]
        end_id = self.vocab.word2idx[self.vocab.END]
        ids = self.vocab.encode(cap)
        if len(ids) > self.max_len:
            # Truncate, but keep the sequence terminated with <end>.
            ids = ids[: self.max_len - 1] + [end_id]
        ids += [pad_id] * (self.max_len - len(ids))
        return (torch.tensor(self.feats[img], dtype=torch.float32),
                torch.tensor(ids, dtype=torch.long))


def collate_batch(batch):
    """Stack a list of ``(feature, caption_ids)`` samples into two batch tensors."""
    features = torch.stack([sample[0] for sample in batch])
    captions = torch.stack([sample[1] for sample in batch])
    return features, captions

In [None]:
def load_model(checkpoint_path="caption_model.pt", device="cpu"):
    """Rebuild the captioner and its vocabulary from a saved checkpoint.

    The checkpoint is expected to hold ``model_state``, ``vocab`` and,
    optionally, ``history``. The model is built with default hyper-parameters
    and a vocabulary-sized output layer, then put in eval mode.

    NOTE(security): ``weights_only=False`` unpickles arbitrary objects (needed
    for the vocab), so only load checkpoints you trust.

    Returns:
        ``(model, vocab, history)``; ``history`` is ``{}`` when not stored.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
    """
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}. Run train.py first.")
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    vocab = ckpt["vocab"]
    captioner = ImageCaptioner(len(vocab.word2idx))
    captioner = captioner.to(device)
    captioner.load_state_dict(ckpt["model_state"])
    captioner.eval()
    return captioner, vocab, ckpt.get("history", {})


def compute_precision_recall_f1(references, hypotheses):
    """Mean per-caption unigram precision, recall and F1.

    Each caption is whitespace-split and compared as a *set* of words, so
    duplicates are ignored. Pairs are matched positionally; extra items in
    the longer list are ignored. An empty hypothesis scores 0 on all three.

    Returns:
        ``(precision, recall, f1)`` averaged over pairs, or
        ``(0.0, 0.0, 0.0)`` when there are no pairs.
    """
    scores = []
    for reference, hypothesis in zip(references, hypotheses):
        reference_words = set(reference.split())
        hypothesis_words = set(hypothesis.split())
        if not hypothesis_words:
            scores.append((0.0, 0.0, 0.0))
            continue
        overlap = len(reference_words & hypothesis_words)
        precision = overlap / len(hypothesis_words)
        recall = overlap / len(reference_words) if reference_words else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        scores.append((precision, recall, f1))
    if not scores:
        # Avoid ZeroDivisionError on empty input.
        return 0.0, 0.0, 0.0
    n = len(scores)
    return tuple(sum(column) / n for column in zip(*scores))


def caption_examples(model, vocab, features, pairs, images_dir, num_examples=5, device="cpu"):
    """Caption up to ``num_examples`` randomly chosen ``(image, caption)`` pairs.

    Pairs are visited in a random order (global ``random`` module). Pairs whose
    image has no entry in ``features`` are skipped and the next pair is tried,
    so fewer than ``num_examples`` results come back only when ``pairs`` runs out.

    Args:
        model, vocab, device: Passed to ``greedy_search``.
        features: Mapping ``image_name -> feature array``.
        pairs: Sequence of ``(image_name, ground_truth_caption)``.
        images_dir: Directory used to build ``image_path``; may be falsy.
        num_examples: Maximum number of examples to return.

    Returns:
        List of dicts with keys ``image``, ``image_path`` (None when
        ``images_dir`` is falsy or the file does not exist), ``ground_truth``
        and ``predicted``.
    """
    order = list(range(len(pairs)))
    random.shuffle(order)
    examples = []
    for index in order:
        if len(examples) >= num_examples:
            break
        image_name, ground_truth = pairs[index]
        if image_name not in features:
            continue
        predicted_caption = greedy_search(model, torch.tensor(features[image_name]), vocab, device=device)
        image_path = os.path.join(images_dir, image_name) if images_dir else None
        if image_path and not os.path.isfile(image_path):
            image_path = None
        examples.append({"image": image_name, "image_path": image_path, "ground_truth": ground_truth, "predicted": predicted_caption})
    return examples


def plot_loss_curve(history, out_path=None):
    """Plot per-epoch training/validation loss and save it as an image.

    Epochs are numbered from 1 on the x-axis, matching the "Epoch N" training
    log. Prints a message and returns without plotting when ``history`` has
    no losses or matplotlib is missing.

    Args:
        history: Dict with optional ``train_loss`` / ``val_loss`` lists of floats.
        out_path: Output file; defaults to ``loss_curve.png``.
    """
    train_loss = history.get("train_loss")
    val_loss = history.get("val_loss")
    if not train_loss and not val_loss:
        print("No loss history. Run train.py first.")
        return
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed.")
        return
    output_path = out_path or "loss_curve.png"
    fig, ax = plt.subplots(figsize=(8, 5))
    for values, label in ((train_loss, "Train Loss"), (val_loss, "Val Loss")):
        if values:
            ax.plot(range(1, len(values) + 1), values, label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_title("Training and Validation Loss")
    ax.grid(True, alpha=0.3)
    fig.savefig(output_path)
    print(f"Loss curve saved to {output_path}")


def _truncate(text, limit=50):
    """Shorten ``text`` to ``limit`` characters, appending '...' if it was cut."""
    return text[:limit] + ("..." if len(text) > limit else "")


def _print_metrics(model, vocab, features, pairs, num_eval, device):
    """Greedy-decode a random sample of pairs and print BLEU-4 and unigram P/R/F1.

    Each sample is scored against its single paired caption as the reference.
    """
    sample_indices = random.sample(range(len(pairs)), min(num_eval, len(pairs)))
    references = [pairs[index][1] for index in sample_indices]
    hypotheses = [greedy_search(model, torch.tensor(features[pairs[index][0]]), vocab, device=device)
                  for index in tqdm(sample_indices, desc="Evaluating")]
    bleu_score = compute_bleu4(references, hypotheses)
    precision, recall, f1_score = compute_precision_recall_f1(references, hypotheses)
    print(f"BLEU-4:    {bleu_score:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1_score:.4f}")


def _save_examples_figure(examples, out_path="caption_examples.png", n_cols=3):
    """Save a grid of example images annotated with ground-truth and predicted captions.

    Best-effort: any plotting or image-loading error is printed, not raised.
    """
    if not examples:
        print("\nNo examples to plot.")
        return
    try:
        import matplotlib.pyplot as plt
        n_rows = (len(examples) + n_cols - 1) // n_cols
        figure, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
        axes = axes.flatten()
        for axis, example in zip(axes, examples):
            if example["image_path"]:
                # Close the file once the pixels are loaded.
                with Image.open(example["image_path"]) as image:
                    axis.imshow(image.convert("RGB"))
            axis.set_title(example["image"][:20] + "...")
            axis.axis("off")
            axis.text(0.5, -0.15, f"GT: {_truncate(example['ground_truth'])}\nPred: {_truncate(example['predicted'])}",
                      transform=axis.transAxes, fontsize=8, ha="center", wrap=True)
        # Hide every unused panel, not just the last one.
        for axis in axes[len(examples):]:
            axis.axis("off")
        figure.suptitle("Caption Examples")
        figure.tight_layout()
        figure.savefig(out_path, bbox_inches="tight")
        print(f"\nFigure saved to {out_path}")
    except Exception as error:
        print(f"Could not save figure: {error}")


def run_evaluation(checkpoint_path="caption_model.pt",
                   captions_path="data/captions.txt",
                   features_path="flickr30k_features.pkl",
                   images_dir="data/Images",
                   num_examples=5,
                   num_eval=500):
    """Load the trained captioner and report examples, the loss curve and metrics.

    Steps:
      1. Print ``num_examples`` random predicted vs. ground-truth captions.
      2. Save the training loss curve.
      3. Score up to ``num_eval`` random caption pairs with BLEU-4 and unigram
         precision/recall/F1.
      4. Save a figure of the examples.

    All paths are parameters so the function also works when the data does not
    live under ``data/`` (e.g. pass the image directory found at runtime).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, vocab, history = load_model(checkpoint_path, device=str(device))
    pairs = load_captions(captions_path)
    # NOTE(security): pickle.load can execute arbitrary code; only load trusted feature files.
    with open(features_path, "rb") as features_file:
        features = pickle.load(features_file)
    pairs = [(image_name, caption) for image_name, caption in pairs if image_name in features]

    print("\n--- Caption Examples ---")
    examples = caption_examples(model, vocab, features, pairs, images_dir, num_examples=num_examples, device=device)
    for example in examples:
        print(f"\nImage: {example['image']}")
        print(f"  Ground Truth: {example['ground_truth']}")
        print(f"  Predicted:    {example['predicted']}")

    print("\n--- Loss Curve ---")
    plot_loss_curve(history)

    print("\n--- Metrics ---")
    if pairs:
        _print_metrics(model, vocab, features, pairs, num_eval, device)
    else:
        print("No caption pairs with extracted features; skipping metrics.")

    _save_examples_figure(examples)

In [None]:
# --- TRAINING CONFIGURATION ---
SEED = 42
FEATURES_PATH = "flickr30k_features.pkl"
CAPTIONS_PATH = "data/captions.txt"
CHECKPOINT_PATH = "caption_model.pt"
NUM_EPOCHS = 10
BATCH_SIZE = 64
TRAIN_FRACTION = 0.9
NUM_BLEU_SAMPLES = 500

# Fixed seeds make the split, weight init and batch order reproducible.
random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pairs = load_captions(CAPTIONS_PATH)
vocab = Vocabulary(min_freq=2)
ds = CaptionDataset(FEATURES_PATH, pairs, vocab)

# Split by *image*, not by caption pair. Flickr30k has several captions per
# image, so a pair-level split puts the same image in both train and
# validation and inflates validation loss / BLEU.
image_names = sorted({img for img, _ in ds.pairs})
random.Random(SEED).shuffle(image_names)
train_images = set(image_names[: int(TRAIN_FRACTION * len(image_names))])
train_idx = [i for i, (img, _) in enumerate(ds.pairs) if img in train_images]
val_idx = [i for i, (img, _) in enumerate(ds.pairs) if img not in train_images]
train_ds = torch.utils.data.Subset(ds, train_idx)
val_ds = torch.utils.data.Subset(ds, val_idx)

# Build the vocabulary from training captions only. This is safe after creating
# `ds` because CaptionDataset keeps a reference to `vocab` and encodes lazily.
vocab.build([ds.pairs[i][1] for i in train_idx])
pad_id = vocab.word2idx[vocab.PAD]

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, collate_fn=collate_batch)

model = ImageCaptioner(len(vocab.word2idx)).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)  # padded positions do not count
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

history = {"train_loss": [], "val_loss": []}
for epoch in range(NUM_EPOCHS):
    model.train()
    total = 0.0
    for feats, caps in tqdm(train_loader, leave=False):
        feats, caps = feats.to(device), caps.to(device)
        optimizer.zero_grad()
        logits = model(feats, caps)
        # Teacher forcing: the model reads caps[:, :-1] and predicts caps[:, 1:].
        loss = criterion(logits.reshape(-1, logits.size(-1)), caps[:, 1:].reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total += loss.item()
    train_loss = total / len(train_loader)
    history["train_loss"].append(train_loss)

    model.eval()
    total = 0.0
    with torch.no_grad():
        for feats, caps in val_loader:
            feats, caps = feats.to(device), caps.to(device)
            logits = model(feats, caps)
            loss = criterion(logits.reshape(-1, logits.size(-1)), caps[:, 1:].reshape(-1))
            total += loss.item()
    val_loss = total / len(val_loader)
    history["val_loss"].append(val_loss)
    print(f"Epoch {epoch + 1} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")

torch.save({"model_state": model.state_dict(), "vocab": vocab, "history": history}, CHECKPOINT_PATH)
print(f"Saved to {CHECKPOINT_PATH}")

model.eval()
refs, hyps = [], []
for idx in tqdm(val_ds.indices[:NUM_BLEU_SAMPLES], desc="BLEU"):
    img, gt = val_ds.dataset.pairs[idx]
    refs.append(gt)
    hyps.append(greedy_search(model, torch.tensor(val_ds.dataset.feats[img]), vocab, device=device))
print(f"BLEU-4: {compute_bleu4(refs, hyps):.4f}")

run_evaluation()