# Flickr8k Multi-Modal Classification Training

This notebook trains a CNN + RNN + Fusion model on the Flickr8k dataset.

## Setup Instructions:
1. **Enable GPU**: Runtime → Change runtime type → GPU
2. **Run all cells** in order (Shift+Enter for each cell, or Runtime → Run all)
3. **Clone from GitHub**: The project will be cloned from GitHub in Step 2
4. **Upload kaggle.json** when prompted in Step 3

## Prerequisites:
- Project must be pushed to GitHub: `https://github.com/Sashahajjar/FashionGen`
- Kaggle API token (kaggle.json) - get it from https://www.kaggle.com/account
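
For orientation, the CNN + RNN + Fusion design named in the introduction can be sketched as below. This is a minimal illustration, not the project's actual model: the class name `TinyFusionClassifier`, the layer sizes, the vocabulary size, and the number of classes are all assumptions chosen to keep the example small.

```python
import torch
import torch.nn as nn


class TinyFusionClassifier(nn.Module):
    """Hypothetical image + caption classifier (illustrative sizes only)."""

    def __init__(self, vocab_size=1000, embed_dim=32, hidden_dim=64, num_classes=5):
        super().__init__()
        # CNN branch: one conv layer, then global average pooling -> (batch, 16)
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # RNN branch: token embedding followed by a single-layer GRU
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        # Fusion: concatenate both feature vectors, then classify
        self.classifier = nn.Linear(16 + hidden_dim, num_classes)

    def forward(self, images, token_ids):
        img_feat = self.cnn(images)                  # (batch, 16)
        _, hidden = self.rnn(self.embed(token_ids))  # hidden: (1, batch, hidden_dim)
        txt_feat = hidden[-1]                        # (batch, hidden_dim)
        fused = torch.cat([img_feat, txt_feat], dim=1)
        return self.classifier(fused)


model = TinyFusionClassifier()
images = torch.randn(2, 3, 64, 64)           # two random RGB images
token_ids = torch.randint(0, 1000, (2, 12))  # two random 12-token captions
logits = model(images, token_ids)
print(logits.shape)
```

Concatenation is the simplest fusion strategy; the repository may use a different one.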


## Step 1: Install Dependencies


In [None]:
# Install required packages
%pip install torch torchvision numpy Pillow scikit-learn matplotlib kaggle

# Verify installation
import torch
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️  GPU not available - training will be slower")


## Step 2: Clone Project from GitHub


In [None]:
# Clone your project from GitHub
GITHUB_USERNAME = "Sashahajjar"
REPO_NAME = "FashionGen"

# Clone the repository
!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git
%cd {REPO_NAME}

# Verify structure
!ls -la
print("\n✓ Project cloned successfully!")
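
`ls -la` only lists the top-level directory. A slightly stricter check, sketched below, confirms that the files the later steps import or execute are present. The list is taken from the paths used in this notebook; the helper name `find_missing` is hypothetical.

```python
from pathlib import Path

# Paths that later cells in this notebook import or run
EXPECTED_PATHS = [
    "data/dataset.py",
    "training/train.py",
    "training/evaluate.py",
    "inference/demo.py",
]


def find_missing(root, expected=EXPECTED_PATHS):
    """Return the expected paths that do not exist under root."""
    root = Path(root)
    return [p for p in expected if not (root / p).exists()]


missing = find_missing(".")
if missing:
    print("Warning: missing files:", ", ".join(missing))
else:
    print("All expected project files found")
```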


## Step 3: Set Up Kaggle API


In [None]:
# Upload your kaggle.json file
# If you don't have it:
# 1. Go to https://www.kaggle.com/account
# 2. Click "Create New Token" to download kaggle.json
# 3. Upload it here:

from google.colab import files
print("Please upload your kaggle.json file:")
files.upload()

# Set up Kaggle credentials
import os
os.makedirs('/root/.kaggle', exist_ok=True)
!mv kaggle.json /root/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json

print("✓ Kaggle API configured")
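
Before starting the download, it can help to confirm that the uploaded file really is a Kaggle token. The sketch below assumes the standard `kaggle.json` layout, a JSON object with `username` and `key` fields; the helper name `check_kaggle_token` is hypothetical.

```python
import json
import os


def check_kaggle_token(path):
    """Return a list of problems with a kaggle.json file (empty list if none)."""
    if not os.path.isfile(path):
        return [f"{path} not found"]
    try:
        with open(path, encoding="utf-8") as fh:
            token = json.load(fh)
    except json.JSONDecodeError:
        return [f"{path} is not valid JSON"]
    if not isinstance(token, dict):
        return [f"{path} does not contain a JSON object"]
    return [
        f"missing or empty field: {name}"
        for name in ("username", "key")
        if not token.get(name)
    ]


# On Colab the home directory is /root, so this matches the path used above
token_path = os.path.expanduser("~/.kaggle/kaggle.json")
problems = check_kaggle_token(token_path)
print("Token file looks valid" if not problems else f"Problems: {problems}")
```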


## Step 4: Download Flickr8k Dataset


In [None]:
# Create directories
!mkdir -p data/images data/captions data/downloads

# Download dataset (~1GB, typically 2-5 minutes)
print("Downloading Flickr8k dataset from Kaggle...")
print("This may take 2-5 minutes...")
!kaggle datasets download -d adityajn105/flickr8k -p data/downloads

# Extract (-o overwrites existing files so the cell can be re-run without prompting)
print("\nExtracting dataset...")
!cd data/downloads && unzip -q -o flickr8k.zip

# Organize files
print("Organizing files...")
!cp -r data/downloads/Flickr8k_Dataset/* data/images/ 2>/dev/null || \
 cp -r data/downloads/Flicker8k_Dataset/* data/images/

!cp data/downloads/Flickr8k.token.txt data/captions/ 2>/dev/null || \
 cp data/downloads/Flickr8k_text/Flickr8k.token.txt data/captions/

# Verify
import os
image_files = [f for f in os.listdir('data/images') if f.endswith(('.jpg', '.jpeg', '.png'))]
image_count = len(image_files)
print(f"\n✓ Images downloaded: {image_count}")
print(f"✓ Captions file exists: {os.path.exists('data/captions/Flickr8k.token.txt')}")

# Show a sample caption
if os.path.exists('data/captions/Flickr8k.token.txt'):
    print("\nSample captions:")
    !head -3 data/captions/Flickr8k.token.txt
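
Each line of `Flickr8k.token.txt` holds an image file name, a `#` plus caption index, a tab, and the caption text. The sketch below groups captions by image under that assumption. The function name `parse_token_lines` and the sample lines are made up for illustration, and the project's own `Flickr8kDataset` may parse the file differently.

```python
from collections import defaultdict


def parse_token_lines(lines):
    """Group captions by image name from '<image>#<index>\\t<caption>' lines."""
    captions = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        key, sep, caption = line.partition("\t")
        if not sep:
            continue  # skip lines without a tab separator
        image_name = key.split("#")[0]
        captions[image_name].append(caption.strip())
    return dict(captions)


# Made-up lines in the same layout as the real file
sample_lines = [
    "example_001.jpg#0\tA dog runs across a field .",
    "example_001.jpg#1\tA brown dog is running outside .",
    "example_002.jpg#0\tTwo children play on a beach .",
]
captions = parse_token_lines(sample_lines)
for image_name, texts in captions.items():
    print(image_name, len(texts))
```

To run it on the real file, pass an open file handle, for example `parse_token_lines(open('data/captions/Flickr8k.token.txt', encoding='utf-8'))`.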


## Step 5: Test Dataset Loading


In [None]:
# Test that the dataset loads correctly
import sys
sys.path.append('/content/FashionGen')

from data.dataset import Flickr8kDataset

# Test loading
print("Testing dataset loading...")
dataset = Flickr8kDataset(
    images_dir='data/images',
    captions_file='data/captions/Flickr8k.token.txt',
    split='train',
    max_samples=10  # Just test with 10 samples
)

print(f"\n✓ Dataset loaded successfully: {len(dataset)} samples")
sample = dataset[0]
print(f"\n✓ Sample loaded:")
print(f"  Image ID: {sample['image_id']}")
print(f"  Caption: {sample['caption_text'][:60]}...")
print(f"  Label: {sample['label']}")
print(f"  Image shape: {sample['image'].shape}")


## Step 6: Train the Model


In [None]:
# Train on the full Flickr8k dataset
# This takes roughly 30-60 minutes, depending on the GPU

!python training/train.py

# For a quick test with limited samples, uncomment this instead:
# !python training/train.py --max_samples 1000
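
How `training/train.py` handles `--max_samples` is not shown in this notebook. As a rough sketch, a flag like this is commonly parsed with `argparse` and used to truncate the sample list; the parser setup and the placeholder sample list below are assumptions, not the project's code.

```python
import argparse


def build_parser():
    """Build a parser with a --max_samples flag (hypothetical setup)."""
    parser = argparse.ArgumentParser(description="Train the multi-modal classifier")
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Use at most this many samples (default: use all)",
    )
    return parser


# Simulate: python training/train.py --max_samples 1000
args = build_parser().parse_args(["--max_samples", "1000"])

all_samples = list(range(8000))  # placeholder for the real sample list
if args.max_samples is not None:
    subset = all_samples[:args.max_samples]
else:
    subset = all_samples
print(f"Training on {len(subset)} of {len(all_samples)} samples")
```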


## Step 7: Evaluate the Model


In [None]:
# Evaluate the trained model on the test set
!python training/evaluate.py


## Step 8: Download Trained Models (Optional)


In [None]:
# Download your trained models to your local machine
from google.colab import files

# Download whichever checkpoints exist
import os
checkpoints = {
    'saved_models/multimodal_best_loss.pth': 'best model by loss',
    'saved_models/multimodal_best_acc.pth': 'best model by accuracy',
    'saved_models/multimodal.pth': 'latest model',
}

downloaded = 0
for path, description in checkpoints.items():
    if os.path.exists(path):
        print(f"Downloading {description}...")
        files.download(path)
        downloaded += 1

# Only report success if at least one file was actually downloaded
if downloaded:
    print(f"\n✓ Download complete! ({downloaded} file(s))")
else:
    print("\n⚠️  No trained models found in saved_models/")


## Step 9: Run Inference Demo (Optional)


In [None]:
# Run inference demo on test samples
!python inference/demo.py
