# Image Captioning - Training on Kaggle

## Setup Instructions:
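1. Attach the **Flickr Image Dataset** from Kaggle Datasets to this notebook
2. Enable a **GPU** under Settings → Accelerator
3. Enable **Internet** under Settings (required for `git clone` and `pip install`)
4. Replace `<YOUR-USERNAME>` and `<YOUR-REPO>` in the clone cell with your own GitHub repository
5. Run all cells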

In [None]:
# Clone your GitHub repository.
# Replace <YOUR-USERNAME> and <YOUR-REPO> with your own values before running;
# cloning needs Internet access enabled in the notebook settings.
# %cd (not !cd) is used so the working-directory change persists for the
# later cells, which use repo-relative paths such as training/ and results/.
!git clone https://github.com/<YOUR-USERNAME>/<YOUR-REPO>.git
%cd <YOUR-REPO>

In [None]:
# Install the Python dependencies listed in the cloned repo's requirements.txt.
# -q keeps pip's output short; Internet access must be enabled.
!pip install -q -r requirements.txt

In [None]:
# Check GPU availability: report whether CUDA is usable and, when it is,
# the name and total memory (in GB) of device 0.
import torch

cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {total_gb:.2f} GB")

In [None]:
# Prepare the Flickr dataset by running the repo's kaggle_setup.py script.
# NOTE(review): presumably this produces ./datasets/flickr.pt, the file loaded
# by the "Generate Sample Captions" cell — confirm against kaggle_setup.py.
!python kaggle_setup.py

## Train Vanilla RNN

In [None]:
# Train the vanilla RNN captioner.
# NOTE(review): later cells read results/Vanilla_RNN/results.json and
# results/Vanilla_RNN/model.pt, so this script is assumed to write them — confirm.
!python training/Vanilla_RNN.py

## View Results

In [None]:
import json
import matplotlib.pyplot as plt

# Load the metrics written by the Vanilla RNN training run
# (run training/Vanilla_RNN.py first, otherwise this raises FileNotFoundError).
with open('results/Vanilla_RNN/results.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

print("="*50)
print("Vanilla RNN Results")
print("="*50)
print(f"Parameters: {results['num_params']:,}")
print(f"Final Train Loss: {results['final_train_loss']:.4f}")
print(f"Final Val Loss: {results['final_val_loss']:.4f}")
print(f"Best Val Loss: {results['best_val_loss']:.4f}")
print(f"Total Time: {results['total_time']/60:.1f} minutes")
print("="*50)

train_hist = results['train_loss_history']
val_hist = results['val_loss_history']
best_val = results['best_val_loss']

# Positive gap = validation loss above training loss (the overfitting signal).
# zip() stops at the shorter history if the two lengths ever differ.
gap = [v - t for v, t in zip(val_hist, train_hist)]

# Number epochs from 1: plotting a bare list would label the first epoch 0.
# NOTE(review): assumes one history entry per completed epoch — confirm in the training script.
train_epochs = range(1, len(train_hist) + 1)
val_epochs = range(1, len(val_hist) + 1)
gap_epochs = range(1, len(gap) + 1)

_, (ax_hist, ax_val, ax_gap) = plt.subplots(1, 3, figsize=(14, 5))

ax_hist.plot(train_epochs, train_hist, label='Train', linewidth=2)
ax_hist.plot(val_epochs, val_hist, label='Val', linewidth=2)

ax_val.plot(val_epochs, val_hist, linewidth=2, color='orange')
# Same 4-decimal precision as the printed summary above.
ax_val.axhline(y=best_val, color='r', linestyle='--', label=f"Best: {best_val:.4f}")

ax_gap.plot(gap_epochs, gap, linewidth=2, color='red')

# Shared styling for all three panels.
for ax, ylabel, title in (
    (ax_hist, 'Loss', 'Training History'),
    (ax_val, 'Validation Loss', 'Validation Loss'),
    (ax_gap, 'Val - Train Loss', 'Overfitting Gap'),
):
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(alpha=0.3)

# Only the first two panels have labelled lines.
ax_hist.legend(fontsize=11)
ax_val.legend(fontsize=11)

plt.tight_layout()
plt.show()

## Train Other Models (Optional)

In [None]:
# Train the LSTM captioner (optional).
# NOTE(review): the comparison cell looks for results/LSTM/results.json,
# so this script is assumed to write it — confirm.
!python training/LSTM.py

In [None]:
# Train the attention-based LSTM captioner (optional).
# NOTE(review): the comparison cell looks for results/Attention_LSTM/results.json,
# so this script is assumed to write it — confirm.
!python training/Attention_LSTM.py

In [None]:
# Train the Transformer captioner (optional).
# NOTE(review): the comparison cell looks for results/Transformer/results.json,
# so this script is assumed to write it — confirm.
!python training/Transformer.py

## Compare All Models

In [None]:
# Import json here so this cell also works when run on its own.
import json
from pathlib import Path

import pandas as pd

models = ['Vanilla_RNN', 'LSTM', 'Attention_LSTM', 'Transformer']


def load_result_row(model_name, results_dir='results'):
    """Return one formatted comparison-table row for *model_name*, or None.

    Reads ``<results_dir>/<model_name>/results.json``. Prints a warning and
    returns None when the file is missing, is not valid JSON, or lacks one
    of the expected metric keys, so one bad model does not abort the table.
    """
    path = Path(results_dir) / model_name / 'results.json'
    try:
        with path.open('r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"⚠️ {model_name} not trained yet")
        return None
    except json.JSONDecodeError:
        print(f"⚠️ {model_name}: {path} is not valid JSON")
        return None

    try:
        return {
            'Model': model_name,
            'Parameters': f"{data['num_params']:,}",
            'Train Loss': f"{data['final_train_loss']:.4f}",
            'Val Loss': f"{data['final_val_loss']:.4f}",
            'Best Val': f"{data['best_val_loss']:.4f}",
            'Time (min)': f"{data['total_time']/60:.1f}",
        }
    except KeyError as err:
        print(f"⚠️ {model_name}: results missing key {err}")
        return None


# Keep only the models whose results could be loaded.
results_list = [row for row in (load_result_row(m) for m in models) if row is not None]

if results_list:
    df = pd.DataFrame(results_list)
    print("\n" + "="*80)
    print("MODEL COMPARISON")
    print("="*80)
    print(df.to_string(index=False))
    print("="*80)
else:
    print("No trained models found under results/ — run the training cells first.")

## Generate Sample Captions

In [None]:
import torch
from models.Vanilla_RNN import VanillaRNNCaptioner
from a5_helper import load_coco_captions, decode_captions
import matplotlib.pyplot as plt

# Choose the device first so the checkpoint can be mapped onto it when loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data and vocabulary.
# NOTE(review): load_coco_captions is reused for the Flickr file; this cell
# assumes it returns "vocab", "val_images" and "val_captions" as used below.
data = load_coco_captions("./datasets/flickr.pt")
word_to_idx = data["vocab"]["token_to_idx"]
idx_to_word = data["vocab"]["idx_to_token"]

# NOTE(review): wordvec_dim / hidden_dim presumably must match the values used
# in training/Vanilla_RNN.py, or load_state_dict will report size mismatches.
model = VanillaRNNCaptioner(
    word_to_idx=word_to_idx,
    wordvec_dim=128,
    hidden_dim=128,
    ignore_index=word_to_idx.get("<NULL>")
)

# map_location lets a checkpoint saved during a GPU run load on a CPU-only session.
state_dict = torch.load('results/Vanilla_RNN/model.pt', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Never request more samples than the validation split provides.
num_samples = min(5, len(data["val_images"]))
sample_images = data["val_images"][:num_samples].to(device)
sample_captions_gt = data["val_captions"][:num_samples]

with torch.no_grad():
    generated_captions = model.sample(sample_images)

# squeeze=False keeps `axes` a 2-D array even when num_samples == 1,
# so axes[0][i] indexing always works.
fig, axes = plt.subplots(1, num_samples, figsize=(4 * num_samples, 4), squeeze=False)
for i, ax in enumerate(axes[0]):
    # Images are stored channel-first; permute to (H, W, C) for imshow.
    img = sample_images[i].cpu().permute(1, 2, 0).numpy()
    ax.imshow(img)
    ax.axis('off')

    gt_caption = decode_captions(sample_captions_gt[i], idx_to_word)
    gen_caption = decode_captions(generated_captions[i], idx_to_word)

    ax.set_title(f"GT: {gt_caption}\nGen: {gen_caption}", fontsize=8)

plt.tight_layout()
plt.show()