# Flickr8k Multi-Modal Classification Training

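This notebook trains a multimodal classifier on the Flickr8k dataset. A CNN encodes each image, an RNN encodes its caption, and a fusion model combines the two to predict one of 10 classes.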

## Setup Instructions:
1. **Enable GPU**: Runtime → Change runtime type → GPU
2. **Run all cells** in order (Shift+Enter or Runtime → Run all)
3. **Clone from GitHub**: The project will be cloned from GitHub in Step 2
4. **Upload kaggle.json** when prompted in Step 3

## Prerequisites:
- Project must be pushed to GitHub: `https://github.com/Sashahajjar/FashionGen`
- Kaggle API token (kaggle.json) - get it from https://www.kaggle.com/account


## Step 1: Install Dependencies


In [1]:
# Install required packages
%pip install torch torchvision numpy Pillow scikit-learn matplotlib kaggle

# Verify installation
import torch
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️  GPU not available - training will be slower")


✓ PyTorch version: 2.9.0+cu126
✓ CUDA available: True
✓ GPU: Tesla T4


## Step 2: Clone Project from GitHub


In [2]:
# Clone your project from GitHub
GITHUB_USERNAME = "Sashahajjar"
REPO_NAME = "FashionGen"

# Clone the repository
!git clone https://github.com/{GITHUB_USERNAME}/{REPO_NAME}.git
%cd {REPO_NAME}

# Verify structure
!ls -la
print("\n✓ Project cloned successfully!")


Cloning into 'FashionGen'...
remote: Enumerating objects: 65, done.
remote: Counting objects: 100% (65/65), done.
remote: Compressing objects: 100% (45/45), done.
remote: Total 65 (delta 19), reused 63 (delta 17), pack-reused 0 (from 0)
Receiving objects: 100% (65/65), 64.55 KiB | 1.08 MiB/s, done.
Resolving deltas: 100% (19/19), done.
/content/FashionGen
total 84
drwxr-xr-x 8 root root 4096 Jan  4 16:18 .
drwxr-xr-x 1 root root 4096 Jan  4 16:18 ..
drwxr-xr-x 5 root root 4096 Jan  4 16:18 data
-rw-r--r-- 1 root root 6310 Jan  4 16:18 DATASET_SETUP.md
-rwxr-xr-x 1 root root 4368 Jan  4 16:18 download_dataset.sh
-rw-r--r-- 1 root root 8958 Jan  4 16:18 Flickr8k_Training.ipynb
drwxr-xr-x 8 root root 4096 Jan  4 16:18 .git
-rw-r--r-- 1 root root  474 Jan  4 16:18 .gitignore
drwxr-xr-x 2 root root 4096 Jan  4 16:18 inference
drwxr-xr-x 2 root root 4096 Jan  4 16:18 models
-rw-r--r-- 1 root root 8652 Jan  4 16:18 README.md
-rw-r--r-- 1 root root  489 Jan  4 16:18 requirements.tx

## Step 3: Set Up Kaggle API


In [3]:
# Upload your kaggle.json file
# If you don't have it:
# 1. Go to https://www.kaggle.com/account
# 2. Click "Create New Token" to download kaggle.json
# 3. Upload it here:

from google.colab import files
print("Please upload your kaggle.json file:")
files.upload()

# Set up Kaggle credentials
import os
os.makedirs('/root/.kaggle', exist_ok=True)
!mv kaggle.json /root/.kaggle/
!chmod 600 /root/.kaggle/kaggle.json

print("✓ Kaggle API configured")


Please upload your kaggle.json file:


Saving kaggle.json to kaggle.json
✓ Kaggle API configured
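If you would rather not use the upload widget (for example when re-running the notebook), a minimal sketch can write `kaggle.json` directly. It assumes you supply your own username and API key; `write_kaggle_credentials` is a hypothetical helper. The `0o600` mode mirrors the `chmod 600` above, because the Kaggle CLI warns when the key file is readable by other users.

```python
import json
import os
import stat
from pathlib import Path


def write_kaggle_credentials(username: str, key: str, config_dir: str = "/root/.kaggle") -> Path:
    """Write kaggle.json with owner-only permissions (equivalent to chmod 600)."""
    directory = Path(config_dir)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / "kaggle.json"
    # The Kaggle CLI reads a JSON object with "username" and "key" fields
    path.write_text(json.dumps({"username": username, "key": key}))
    # Owner read/write only, so other users on the machine cannot read the key
    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
    return path


# Usage (placeholders; substitute your own credentials):
# write_kaggle_credentials("your_username", "your_api_key")
```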


## Step 4: Download Flickr8k Dataset


In [4]:
# Create directories
!mkdir -p data/images data/captions data/downloads

# Download dataset (this will take a few minutes, ~1GB)
print("Downloading Flickr8k dataset from Kaggle...")
print("This may take 2-5 minutes...")
!kaggle datasets download -d adityajn105/flickr8k -p data/downloads

# Extract
print("\nExtracting dataset...")
!cd data/downloads && unzip -q flickr8k.zip

# Organize files
print("Organizing files...")
!cp -r data/downloads/Flickr8k_Dataset/* data/images/ 2>/dev/null || \
 cp -r data/downloads/Flicker8k_Dataset/* data/images/

!cp data/downloads/Flickr8k.token.txt data/captions/ 2>/dev/null || \
 cp data/downloads/Flickr8k_text/Flickr8k.token.txt data/captions/

# Verify
import os
image_files = [f for f in os.listdir('data/images') if f.endswith(('.jpg', '.jpeg', '.png'))]
image_count = len(image_files)
print(f"\n✓ Images downloaded: {image_count}")
print(f"✓ Captions file exists: {os.path.exists('data/captions/Flickr8k.token.txt')}")

# Show a sample caption
if os.path.exists('data/captions/Flickr8k.token.txt'):
    print("\nSample captions:")
    !head -3 data/captions/Flickr8k.token.txt


Downloading Flickr8k dataset from Kaggle...
This may take 2-5 minutes...
Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
Downloading flickr8k.zip to data/downloads
 91% 971M/1.04G [00:11<00:02, 45.8MB/s]
100% 1.04G/1.04G [00:11<00:00, 99.2MB/s]

Extracting dataset...
Organizing files...
cp: cannot stat 'data/downloads/Flicker8k_Dataset/*': No such file or directory
cp: cannot stat 'data/downloads/Flickr8k_text/Flickr8k.token.txt': No such file or directory

✓ Images downloaded: 0
✓ Captions file exists: False


## Step 5: Fix the File Layout and Caption Loader


In [5]:
import os

print("Contents of data/downloads:")
!ls -la data/downloads/

print("\nSearching for images and captions...")
# Find all image files
!find data/downloads -name "*.jpg" -o -name "*.jpeg" -o -name "*.png" | head -10

# Find all text files (captions)
!find data/downloads -name "*.txt" | head -10

# Find directories
!find data/downloads -type d | head -10

Contents of data/downloads:
total 1090576
drwxr-xr-x 3 root root       4096 Jan  4 16:20 .
drwxr-xr-x 6 root root       4096 Jan  4 16:19 ..
-rw-r--r-- 1 root root    3319294 Apr 27  2020 captions.txt
-rw-r--r-- 1 root root 1112971163 Apr 27  2020 flickr8k.zip
drwxr-xr-x 2 root root     438272 Jan  4 16:20 Images

Searching for images and captions...
data/downloads/Images/3155987659_b9ea318dd3.jpg
data/downloads/Images/2335619125_2e2034f2c3.jpg
data/downloads/Images/3284460070_6805990149.jpg
data/downloads/Images/3033668641_5905f73990.jpg
data/downloads/Images/2914206497_5e36ac6324.jpg
data/downloads/Images/3334537556_a2cf4e9b9a.jpg
data/downloads/Images/3127614086_9f1d3cf73d.jpg
data/downloads/Images/2316097768_ef662f444b.jpg
data/downloads/Images/319938879_daf0857f91.jpg
data/downloads/Images/2273799395_5072a5736d.jpg
data/downloads/captions.txt
data/downloads
data/downloads/Images


In [6]:
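# Copy files using the layout found above: Images/ for photos, captions.txt for captions
print("Copying images from Images/ folder...")
!cp -r data/downloads/Images/* data/images/

print("Copying captions file...")
# captions.txt is CSV, but keep the Flickr8k.token.txt name the training scripts point to
!cp data/downloads/captions.txt data/captions/Flickr8k.token.txt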

# Verify
import os
image_files = [f for f in os.listdir('data/images') if f.endswith(('.jpg', '.jpeg', '.png'))]
image_count = len(image_files)
print(f"\n✓ Images copied: {image_count}")
print(f"✓ Captions file exists: {os.path.exists('data/captions/Flickr8k.token.txt')}")

# Show sample
if os.path.exists('data/captions/Flickr8k.token.txt'):
    print("\nSample captions:")
    !head -3 data/captions/Flickr8k.token.txt

Copying images from Images/ folder...
Copying captions file...

✓ Images copied: 8091
✓ Captions file exists: True

Sample captions:
image,caption
1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .
1000268201_693b08cb0e.jpg,A girl going into a wooden building .
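The Step 4 cell assumed the older archive layout (`Flickr8k_Dataset/` plus `Flickr8k_text/`), while this Kaggle mirror ships `Images/` and `captions.txt`. A layout-agnostic alternative is to search the extracted tree for images and a captions file. This is a minimal sketch: `organize_flickr8k` is a hypothetical helper, and it assumes the output image directory is not inside the download directory.

```python
import shutil
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def organize_flickr8k(download_dir: str, images_out: str, captions_out: str) -> int:
    """Copy images and the captions file out of an extracted Flickr8k archive.

    Searches recursively, so it works for either Images/ + captions.txt or
    Flicker8k_Dataset/ + Flickr8k_text/Flickr8k.token.txt.
    Returns the number of image files copied.
    """
    root = Path(download_dir)
    out_dir = Path(images_out)
    out_dir.mkdir(parents=True, exist_ok=True)

    copied = 0
    for path in root.rglob("*"):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            shutil.copy2(path, out_dir / path.name)
            copied += 1

    # Prefer the CSV captions file; fall back to the older token file
    for name in ("captions.txt", "Flickr8k.token.txt"):
        matches = sorted(root.rglob(name))
        if matches:
            Path(captions_out).parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(matches[0], captions_out)
            break
    else:
        raise FileNotFoundError(f"No captions file found under {root}")
    return copied
```

Usage: `organize_flickr8k("data/downloads", "data/images", "data/captions/Flickr8k.token.txt")`.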


In [8]:
# Patch data/dataset.py so that _load_captions() also reads the Kaggle CSV captions file
# Read the file
with open('data/dataset.py', 'r') as f:
    content = f.read()

# Locate the existing _load_captions method so it can be replaced
old_start = content.find('def _load_captions(self)')
if old_start == -1:
    raise ValueError("_load_captions() not found in data/dataset.py")
old_end = content.find('return image_caption_pairs', old_start) + len('return image_caption_pairs')

new_code = '''    def _load_captions(self) -> List[Dict[str, str]]:
        """Load captions - supports CSV and token formats."""
        import csv
        image_caption_pairs = []

        if not os.path.exists(self.captions_file):
            raise FileNotFoundError(f"Captions file not found: {self.captions_file}")

        with open(self.captions_file, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
            f.seek(0)
            is_csv = first_line.lower().startswith('image') and ',' in first_line

            if is_csv:
                reader = csv.DictReader(f)
                for row in reader:
                    img = row.get('image', '').strip()
                    cap = row.get('caption', '').strip()
                    if img and cap:
                        img_id = img.replace('.jpg', '').replace('.jpeg', '').replace('.png', '')
                        image_caption_pairs.append({'image_id': img_id, 'caption': cap})
            else:
                for line in f:
                    line = line.strip()
                    if not line: continue
                    parts = line.split('\t', 1) if '\t' in line else line.split(' ', 1)
                    if len(parts) == 2:
                        img_id = parts[0].split('#')[0] if '#' in parts[0] else parts[0]
                        image_caption_pairs.append({'image_id': img_id, 'caption': parts[1]})

        return image_caption_pairs'''

# Replace
new_content = content[:old_start] + new_code + '\n        ' + content[old_end+1:]

# Write back
with open('data/dataset.py', 'w') as f:
    f.write(new_content)

print("✓ Done! File updated.")

✓ Done! File updated.
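The patched `_load_captions` strips the file extension only in its CSV branch; in the token branch the ID keeps `.jpg` (for example `1000268201_693b08cb0e.jpg` from `1000268201_693b08cb0e.jpg#0`). This standalone sketch applies the same format detection but normalizes IDs the same way in both branches. `parse_captions` is a hypothetical helper, not part of the repository.

```python
import csv
import io
import os
from typing import Dict, List


def parse_captions(text: str) -> List[Dict[str, str]]:
    """Parse Flickr8k captions in either the Kaggle CSV or the original token format.

    CSV:   header "image,caption", then one "<image>.jpg,<caption>" row per caption
    Token: "<image>.jpg#<n>\t<caption>"
    Image IDs are returned without the file extension in both cases.
    """
    lines = text.splitlines()
    first = lines[0].strip() if lines else ""
    pairs = []
    if first.lower().startswith("image") and "," in first:
        for row in csv.DictReader(io.StringIO(text)):
            img = (row.get("image") or "").strip()
            cap = (row.get("caption") or "").strip()
            if img and cap:
                pairs.append({"image_id": os.path.splitext(img)[0], "caption": cap})
    else:
        for line in lines:
            line = line.strip()
            if not line:
                continue
            parts = line.split("\t", 1) if "\t" in line else line.split(" ", 1)
            if len(parts) == 2:
                img = parts[0].split("#")[0]  # drop the "#0".."#4" caption index
                pairs.append({"image_id": os.path.splitext(img)[0], "caption": parts[1].strip()})
    return pairs
```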


In [11]:
# Fix the indentation error
with open('data/dataset.py', 'r') as f:
    lines = f.readlines()

# Find the problematic method
for i, line in enumerate(lines):
    if 'def _load_captions(self)' in line and i < len(lines) - 1:
        # Only rewrite if the body does not already start with the new docstring or import
        if i + 1 < len(lines) and not lines[i+1].strip().startswith('"""') and not lines[i+1].strip().startswith('import'):
            # Fix it
            new_method = '''    def _load_captions(self) -> List[Dict[str, str]]:
        """Load captions - supports CSV and token formats."""
        import csv
        image_caption_pairs = []

        if not os.path.exists(self.captions_file):
            raise FileNotFoundError(f"Captions file not found: {self.captions_file}")

        with open(self.captions_file, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
            f.seek(0)
            is_csv = first_line.lower().startswith('image') and ',' in first_line

            if is_csv:
                reader = csv.DictReader(f)
                for row in reader:
                    img = row.get('image', '').strip()
                    cap = row.get('caption', '').strip()
                    if img and cap:
                        img_id = img.replace('.jpg', '').replace('.jpeg', '').replace('.png', '')
                        image_caption_pairs.append({'image_id': img_id, 'caption': cap})
            else:
                for line in f:
                    line = line.strip()
                    if not line: continue
                    parts = line.split('\\t', 1) if '\\t' in line else line.split(' ', 1)
                    if len(parts) == 2:
                        img_id = parts[0].split('#')[0] if '#' in parts[0] else parts[0]
                        image_caption_pairs.append({'image_id': img_id, 'caption': parts[1]})

        return image_caption_pairs
'''
            # Find where method ends
            end_idx = i + 1
            while end_idx < len(lines) and 'return image_caption_pairs' not in lines[end_idx]:
                end_idx += 1

            # Replace
            new_lines = lines[:i] + [new_method] + lines[end_idx+1:]
            with open('data/dataset.py', 'w') as f:
                f.writelines(new_lines)
            print("✓ Fixed!")
            break
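In [11] prints nothing, which is consistent with its guard skipping the rewrite because the line after `def` is already the new docstring, so the indentation problem is still there. Before importing a module that was patched as plain text, it helps to confirm that it parses and, if not, which line is at fault. A minimal sketch using the standard-library `ast` module (`check_syntax` is a hypothetical helper):

```python
import ast
from typing import Optional


def check_syntax(path: str) -> Optional[str]:
    """Return None if the file parses, otherwise a short description of the error."""
    with open(path, "r", encoding="utf-8") as f:
        source = f.read()
    try:
        ast.parse(source, filename=path)
    except SyntaxError as err:  # IndentationError is a subclass of SyntaxError
        return f"{type(err).__name__} at line {err.lineno}: {err.msg}"
    return None


# Usage: print(check_syntax("data/dataset.py") or "✓ data/dataset.py parses")
```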

In [13]:
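# Fix the indentation: the def line was written with 8 spaces instead of 4
with open('data/dataset.py', 'r') as f:
    lines = f.readlines()

# Find the over-indented signature by content rather than a hard-coded line number
for i, line in enumerate(lines):
    if line.lstrip().startswith('def _load_captions(self)'):
        lines[i] = '    def _load_captions(self) -> List[Dict[str, str]]:\n'
        break

# Write back
with open('data/dataset.py', 'w') as f:
    f.writelines(lines)

print("✓ Fixed indentation!")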

✓ Fixed indentation!
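Two follow-up cells (In [11] and In [13]) were needed because `content[:old_start]` in the In [8] patch already ends with the four spaces in front of `def`, and `new_code` adds its own four, which yields an 8-space `def` line. Splicing on whole-line boundaries avoids this. A minimal sketch: `replace_method` is a hypothetical helper that finds the end of the method by indentation, so it assumes the method body contains no dedented multi-line strings.

```python
from typing import Optional


def replace_method(source: str, method_name: str, new_method: str) -> str:
    """Replace one method in a class body, keeping the surrounding code intact.

    The method spans from its `def` line up to (not including) the next
    non-blank line indented at or below the `def` line. `new_method` must
    already carry the class-level indentation (e.g. 4 spaces before `def`).
    """
    lines = source.splitlines(keepends=True)
    start: Optional[int] = None
    indent = 0
    for i, line in enumerate(lines):
        stripped = line.lstrip()
        if stripped.startswith(f"def {method_name}("):
            start, indent = i, len(line) - len(stripped)
            break
    if start is None:
        raise ValueError(f"method {method_name!r} not found")

    end = start + 1
    while end < len(lines):
        line = lines[end]
        if line.strip() and len(line) - len(line.lstrip()) <= indent:
            break  # first line that is no longer part of the method body
        end += 1
    # Leave trailing blank lines between methods untouched
    while end > start + 1 and not lines[end - 1].strip():
        end -= 1

    replacement = new_method if new_method.endswith("\n") else new_method + "\n"
    return "".join(lines[:start]) + replacement + "".join(lines[end:])
```

With the In [8] variables, the call would be `replace_method(content, "_load_captions", new_code)`.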


In [14]:
# Test again
import sys
sys.path.append('/content/FashionGen')

from data.dataset import Flickr8kDataset

dataset = Flickr8kDataset(
    images_dir='data/images',
    captions_file='data/captions/Flickr8k.token.txt',
    split='train',
    max_samples=5
)

print(f"✓ Success! Loaded {len(dataset)} samples")

Loaded 5 samples for train split
Building vocabulary from captions...
Vocabulary size: 29
✓ Success! Loaded 5 samples
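The loader builds its vocabulary from the training captions (29 words for 5 samples here, 7274 for the full train split in Step 6). We have not checked the exact rules in `data/dataset.py`; a minimal sketch of a typical word-level vocabulary follows, where the `<pad>`/`<unk>` tokens, lowercase whitespace tokenization, and the helper names are all assumptions.

```python
from collections import Counter
from typing import Dict, Iterable, List

SPECIAL_TOKENS = ["<pad>", "<unk>"]  # assumed; the repository's special tokens may differ


def tokenize(caption: str) -> List[str]:
    # Flickr8k captions already separate the final period with a space (" .")
    return caption.lower().split()


def build_vocab(captions: Iterable[str], min_freq: int = 1) -> Dict[str, int]:
    counts = Counter(tok for cap in captions for tok in tokenize(cap))
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    # Sort by frequency, then alphabetically, so IDs are deterministic
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode(caption: str, vocab: Dict[str, int], max_len: int) -> List[int]:
    """Map words to IDs (unknown words -> <unk>), truncate, then right-pad with <pad>."""
    ids = [vocab.get(tok, vocab["<unk>"]) for tok in tokenize(caption)][:max_len]
    return ids + [vocab["<pad>"]] * (max_len - len(ids))
```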


In [15]:
# Test that the dataset loads correctly
import sys
sys.path.append('/content/FashionGen')

from data.dataset import Flickr8kDataset

# Test loading
print("Testing dataset loading...")
dataset = Flickr8kDataset(
    images_dir='data/images',
    captions_file='data/captions/Flickr8k.token.txt',
    split='train',
    max_samples=10  # Just test with 10 samples
)

print(f"\n✓ Dataset loaded successfully: {len(dataset)} samples")
sample = dataset[0]
print(f"\n✓ Sample loaded:")
print(f"  Image ID: {sample['image_id']}")
print(f"  Caption: {sample['caption_text'][:60]}...")
print(f"  Label: {sample['label']}")
print(f"  Image shape: {sample['image'].shape}")


Testing dataset loading...
Loaded 10 samples for train split
Building vocabulary from captions...
Vocabulary size: 57

✓ Dataset loaded successfully: 10 samples

✓ Sample loaded:
  Image ID: 1000268201_693b08cb0e
  Caption: A child in a pink dress is climbing up a set of stairs in an...
  Label: 2
  Image shape: torch.Size([3, 224, 224])
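The split sizes reported by the training and evaluation scripts below (28315 / 6065 / 6075 captions) are all multiples of five and add up to 40455 = 8091 × 5. That is consistent with a roughly 70/15/15 split made per image rather than per caption, so all five captions of a photo stay in the same split and no test image leaks into training. We have not verified how `data/dataset.py` actually splits; this is a minimal sketch of an image-level split that reproduces those counts, with the ratios, seed, and `split_by_image` name as assumptions.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def split_by_image(pairs: List[Dict[str, str]], train: float = 0.70, val: float = 0.15,
                   seed: int = 42) -> Tuple[List[dict], List[dict], List[dict]]:
    """Split caption pairs so that all captions of an image land in the same split."""
    by_image: Dict[str, List[dict]] = defaultdict(list)
    for p in pairs:
        by_image[p["image_id"]].append(p)

    image_ids = sorted(by_image)  # sort first so the seeded shuffle is reproducible
    random.Random(seed).shuffle(image_ids)

    n_train = int(len(image_ids) * train)
    n_val = int(len(image_ids) * val)
    groups = (image_ids[:n_train],
              image_ids[n_train:n_train + n_val],
              image_ids[n_train + n_val:])  # the remainder goes to test
    return tuple([p for img in ids for p in by_image[img]] for ids in groups)
```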


## Step 6: Train the Model


In [17]:
# Fix: move lengths to the CPU before packing (pack_padded_sequence requires CPU lengths)
with open('models/rnn_model.py', 'r') as f:
    content = f.read()

# The exact call to replace (must match the file's formatting character for character)
old_line = "        packed = nn.utils.rnn.pack_padded_sequence(\n            embedded, lengths, batch_first=True, enforce_sorted=False\n        )"

new_line = "        # Move lengths to CPU (required by pack_padded_sequence)\n        lengths_cpu = lengths.cpu()\n        packed = nn.utils.rnn.pack_padded_sequence(\n            embedded, lengths_cpu, batch_first=True, enforce_sorted=False\n        )"

if old_line in content:
    content = content.replace(old_line, new_line)
    with open('models/rnn_model.py', 'w') as f:
        f.write(content)
    print("✓ Fixed! Lengths will now be moved to CPU.")
elif 'lengths_cpu' in content:
    print("✓ Already patched, nothing to do.")
else:
    print("⚠️  pack_padded_sequence call not found; models/rnn_model.py was not changed.")

✓ Fixed! Lengths will now be moved to CPU.
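`pack_padded_sequence` uses `lengths` to lay out the packed batch, including the per-time-step `batch_sizes`, which PyTorch keeps on the CPU; that is why a CUDA `lengths` tensor is rejected and the patch moves it with `.cpu()`. The pure-Python sketch below (not PyTorch code; `pack_padded` is a hypothetical helper) illustrates what is computed from `lengths` for a batch-first input with `enforce_sorted=False`.

```python
from typing import List, Sequence, Tuple


def pack_padded(padded: Sequence[Sequence[int]], lengths: Sequence[int]
                ) -> Tuple[List[int], List[int], List[int]]:
    """Illustrate packing a padded, batch-first batch.

    Returns (data, batch_sizes, sorted_indices): sequences are ordered by length
    (longest first; tie order may differ from PyTorch), then emitted one time
    step at a time, skipping padding positions.
    """
    sorted_indices = sorted(range(len(lengths)), key=lambda i: -lengths[i])
    max_len = max(lengths) if lengths else 0
    data, batch_sizes = [], []
    for t in range(max_len):
        # Sequences still "active" at step t are those longer than t
        active = [i for i in sorted_indices if lengths[i] > t]
        batch_sizes.append(len(active))
        data.extend(padded[i][t] for i in active)
    return data, batch_sizes, sorted_indices
```

For example, `pack_padded([[1, 2, 0], [3, 4, 5], [6, 0, 0]], [2, 3, 1])` orders the batch as `[1, 0, 2]` and yields `batch_sizes` of `[3, 2, 1]`: the RNN processes three sequences at step 0, two at step 1, and one at step 2.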


In [18]:
# Train with real Flickr8k data!
# This will take 30-60 minutes depending on GPU

!python training/train.py

# For a quick test with limited samples, uncomment this instead:
# !python training/train.py --max_samples 1000


Flickr8k Multi-Modal Classification Training
Device: cuda
Number of classes: 10

Images directory: data/images
Captions file: data/captions/Flickr8k.token.txt

Building vocabulary from training data...
Loaded 28315 samples for train split
Building vocabulary from captions...
Vocabulary size: 7274

Creating train dataset...
Loaded 28315 samples for train split
Creating validation dataset...
Loaded 6065 samples for val split
Train samples: 28315
Validation samples: 6065
Batch size: 16

Creating fusion model...
CNN layers frozen: all
Total parameters: 37,754,186
Trainable parameters: 13,197,066

Optimizer: Adam (lr=0.0001)
Loss function: CrossEntropyLoss
Number of epochs: 5

Training Progress
Epoch 1/5 | Train Loss: 0.6737 Train Acc: 78.79% | Val Loss: 0.6437 Val Acc: 78.76% | LR: 1.00e-04
  -> Best model by loss saved! (Val Loss: 0.6437 at epoch 1)
  -> Best model by accuracy saved! (Val Acc: 78.76% at epoch 1)
Epoch 2/5 | Train Loss: 0.4680 Train Acc: 84.50% | Val Loss: 0.6053 Val Acc: 
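The script saves separate checkpoints for the best validation loss and the best validation accuracy, and these need not come from the same epoch. To plot the curves (matplotlib is installed in Step 1) or check which epoch each checkpoint came from, a minimal sketch can parse the epoch lines. The regex assumes exactly the line format shown above and skips truncated lines; `parse_epochs` and `best_epochs` are hypothetical helpers.

```python
import re
from typing import Dict, List

EPOCH_RE = re.compile(
    r"Epoch (\d+)/\d+ \| Train Loss: ([\d.]+) Train Acc: ([\d.]+)% \| "
    r"Val Loss: ([\d.]+) Val Acc: ([\d.]+)%"
)


def parse_epochs(log: str) -> List[Dict[str, float]]:
    """Extract per-epoch metrics from the training log; incomplete lines do not match."""
    rows = []
    for m in EPOCH_RE.finditer(log):
        epoch, tr_loss, tr_acc, va_loss, va_acc = m.groups()
        rows.append({"epoch": int(epoch), "train_loss": float(tr_loss), "train_acc": float(tr_acc),
                     "val_loss": float(va_loss), "val_acc": float(va_acc)})
    return rows


def best_epochs(rows: List[Dict[str, float]]) -> Dict[str, int]:
    """Epoch with the lowest validation loss and epoch with the highest validation accuracy."""
    return {"best_loss": min(rows, key=lambda r: r["val_loss"])["epoch"],
            "best_acc": max(rows, key=lambda r: r["val_acc"])["epoch"]}
```

To capture the log, you could run `!python training/train.py | tee train_log.txt` and pass the file's contents to `parse_epochs`.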

## Step 7: Evaluate the Model


In [19]:
# Evaluate the trained model on test set
!python training/evaluate.py


Flickr8k Model Evaluation
Device: cuda

Images directory: data/images
Captions file: data/captions/Flickr8k.token.txt

Building vocabulary from training data...
Loaded 28315 samples for train split
Building vocabulary from captions...
Vocabulary size: 7274

Creating test dataset...
Loaded 6075 samples for test split

Creating fusion model...

Loading checkpoint: saved_models/multimodal.pth
Checkpoint loaded from saved_models/multimodal.pth
Loaded model from epoch 5
Best validation accuracy: 82.36%

Running Evaluation

Evaluation Results
Test Loss: 0.3247
Test Accuracy: 89.28%

--------------------------------------------------------------------------------
Per-Class Accuracy
--------------------------------------------------------------------------------
Class 0             :  95.20%
Class 1             :   1.32%
Class 2             :  97.85%
Class 3             :  42.37%
Class 4             :  66.17%
Class 5             :  35.71%
Class 6             :  44.09%
Class 7             :   9
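The overall test accuracy (89.28%) is far above several per-class figures (1.32% for class 1, 35.71% for class 5), which suggests the test set is dominated by a few well-predicted classes such as 0 and 2. Overall accuracy weights each class by its frequency; a class-balanced view averages the per-class figures instead. A minimal sketch, treating "per-class accuracy" as the fraction of each class's samples predicted correctly and assuming you have collected true and predicted labels as integer lists (`training/evaluate.py` may not expose them directly). scikit-learn's `balanced_accuracy_score` computes the same balanced quantity.

```python
from collections import Counter
from typing import Dict, Sequence


def per_class_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[int, float]:
    """Fraction of samples of each true class that were predicted correctly (per-class recall)."""
    totals = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {c: correct[c] / n for c, n in sorted(totals.items())}


def balanced_accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of per-class accuracy, so rare classes count as much as common ones."""
    scores = per_class_accuracy(y_true, y_pred)
    return sum(scores.values()) / len(scores)
```

For instance, with 8 samples of class 0 all correct and 2 samples of class 1 with one correct, overall accuracy is 90% but balanced accuracy is 75%.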

## Step 8: Download Trained Models (Optional)


In [23]:
# Download your trained models to your local machine
from google.colab import files

# Check which models exist
import os
if os.path.exists('saved_models/multimodal_best_loss.pth'):
    print("Downloading best model by loss...")
    files.download('saved_models/multimodal_best_loss.pth')

if os.path.exists('saved_models/multimodal_best_acc.pth'):
    print("Downloading best model by accuracy...")
    files.download('saved_models/multimodal_best_acc.pth')

if os.path.exists('saved_models/multimodal.pth'):
    print("Downloading latest model...")
    files.download('saved_models/multimodal.pth')

print("\n✓ Download complete!")


Downloading best model by loss...
Downloading best model by accuracy...
Downloading latest model...

✓ Download complete!
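Each `files.download` call starts a separate browser download, which some browsers may block or prompt for. A minimal sketch that bundles all checkpoints into one archive first, so a single download is enough; `bundle_checkpoints` and the archive name are hypothetical.

```python
import zipfile
from pathlib import Path


def bundle_checkpoints(model_dir: str = "saved_models", archive: str = "saved_models.zip") -> str:
    """Write every .pth file in model_dir into a single zip archive and return its path."""
    checkpoints = sorted(Path(model_dir).glob("*.pth"))
    if not checkpoints:
        raise FileNotFoundError(f"No .pth files found in {model_dir}")
    # ZIP_STORED: float weights typically compress poorly, so skip the cost of deflate
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_STORED) as zf:
        for ckpt in checkpoints:
            zf.write(ckpt, arcname=ckpt.name)
    return archive


# Usage in Colab:
# from google.colab import files
# files.download(bundle_checkpoints())
```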


## Step 9: Run Inference Demo (Optional)


In [24]:
# Run inference demo on test samples
!python inference/demo.py


Flickr8k Inference Demo
Device: cuda

Loading model...
Loading checkpoint: saved_models/multimodal.pth
Checkpoint loaded from saved_models/multimodal.pth
Model loaded from epoch 5

Loading vocabulary...
Loaded 28315 samples for train split
Building vocabulary from captions...
Vocabulary size: 7274
Loading test dataset...
Loaded 6075 samples for test split

Running Inference on Test Examples

Example 1/3
Image ID: 3688839836_ba5e4c24fc
Caption: A woman is holding out a peace sign during a parade .
True Label: Person (Class 2)

Prediction Results

Predicted Class: Person (Class 2)
Confidence: 100.00%

True Label: Person (Class 2)
Result: ✓ CORRECT

--------------------------------------------------------------------------------
Top-3 Predictions
--------------------------------------------------------------------------------
1. Person               (Class  2): 100.00%
2. Cat                  (Class  1):   0.00%
3. Building             (Class  5):   0.00%

--------------------------------
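The demo reports 100.00% confidence with the other classes at 0.00%, which is what a softmax produces when one logit exceeds the rest by a wide margin and the output is rounded to two decimals. Assuming `inference/demo.py` derives confidence and the top-3 list from a softmax over the fusion model's 10 logits, a minimal pure-Python sketch of that step (`softmax` and `top_k` here are illustrative helpers, not the repository's code):

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability; the probabilities are unchanged
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: Sequence[float], k: int = 3) -> List[Tuple[int, float]]:
    """Return (class_index, probability) pairs for the k most likely classes."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]
```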