# Fashion-Gen Multi-Modal Training on Google Colab

This notebook trains the CNN + RNN + Fusion model on Fashion-Gen data from Kaggle.

**Features:**
- Clones the project from GitHub automatically
- Option to test with mock data (no download needed)
- Downloads only the validation H5 file (saves disk space)
- Streaming HDF5 dataset (avoids loading the whole file into RAM)
- Subset training with the `--max_samples` option
- Error handling around the download, plus a debug cell for common failures


## 1. Clone Project from GitHub


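In [None]:
# Clone the project from GitHub (skipped if it is already present, so the cell can be re-run)
import os
import sys
PROJECT_DIR = '/content/FashionGen'

if not os.path.isdir(PROJECT_DIR):
    !git clone https://github.com/Sashahajjar/FashionGen.git {PROJECT_DIR}

# Work from the project root so `data.*` and `models.*` imports resolve
os.chdir(PROJECT_DIR)
if PROJECT_DIR not in sys.path:
    sys.path.insert(0, PROJECT_DIR)

print(f"✓ Project cloned to: {PROJECT_DIR}")
print(f"✓ Current directory: {os.getcwd()}")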




In [None]:
# Install required dependencies
# IMPORTANT: Run this cell, then RESTART RUNTIME (Runtime → Restart runtime)
# After restart, skip this cell and go to the next one

import subprocess
import sys

print("Step 1: Uninstalling incompatible packages...")
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "numpy", "torch", "torchvision", "scipy", "scikit-learn"])

print("\nStep 2: Installing compatible versions...")
# Install numpy first, then torch/torchvision
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.26.4", "--no-cache-dir"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision", "--no-cache-dir"])

print("\nStep 3: Installing other dependencies...")
# scipy was uninstalled in Step 1, so reinstall it explicitly alongside scikit-learn
subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "Pillow", "scipy", "scikit-learn", "h5py", "kaggle", "--no-cache-dir"])

print("\n" + "="*60)
print("⚠️  IMPORTANT: RESTART RUNTIME NOW!")
print("Runtime → Restart runtime")
print("Then SKIP this cell and continue to next cell")
print("="*60)


## 2. Environment Setup & GPU Check


In [None]:
# Detect GPU and print device info
import os
import sys
import warnings

import torch

warnings.filterwarnings('ignore', category=UserWarning)  # Suppress UserWarnings (e.g. numpy compatibility warnings)

# Restarting the runtime clears all variables and resets the working directory,
# so restore the project setup from Section 1 before importing project modules
PROJECT_DIR = '/content/FashionGen'
os.chdir(PROJECT_DIR)
if PROJECT_DIR not in sys.path:
    sys.path.insert(0, PROJECT_DIR)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("Using CPU (training will be slower)")

# Test imports
print("\nTesting imports...")
from data.dataset import FashionGenDataset
from models.fusion_model import create_fusion_model
print("✓ All imports successful")


Device: cuda
GPU: Tesla T4
CUDA Version: 12.6
GPU Memory: 15.83 GB
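After restarting the runtime, a quick check can confirm that the packages from the install cell are the ones Python actually sees. The restart is needed because modules imported before `pip install` stay loaded in `sys.modules` at their old versions, even though the files on disk have changed. The sketch below uses only `importlib.metadata`. The helper names `installed_versions` and `needs_restart` are hypothetical, and `1.26.4` is the numpy pin from the install cell.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to the version installed on disk (None if missing)."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def needs_restart(module_name, dist_name=None):
    """True if `module_name` is already imported at a different version than the one on disk.

    pip replaces the files on disk, but a module that is already imported keeps
    running the old code until the Python process restarts.
    """
    module = sys.modules.get(module_name)
    if module is None:
        return False  # not imported yet: the next import loads the new files
    loaded = getattr(module, "__version__", None)
    try:
        on_disk = version(dist_name or module_name)
    except PackageNotFoundError:
        return False
    return loaded is not None and loaded != on_disk


versions = installed_versions(["numpy", "torch", "torchvision", "scikit-learn", "h5py"])
print(versions)
if versions["numpy"] != "1.26.4":  # the pin used in the install cell
    print("⚠ numpy is not at the pinned 1.26.4; re-run the install cell")
if needs_restart("numpy"):
    print("⚠ numpy was imported before the reinstall; restart the runtime")
```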


## 3. Choose: Mock Data or Real Data

**Option A: Test with Mock Data (Quick, No Download)**
- Skip to Section 4
- No Kaggle account needed
- Fast testing

**Option B: Use Real FashionGen Data**
- Continue with Kaggle authentication below
- Requires a Kaggle account
- Downloads the ~1.7 GB validation file


## 3A. Kaggle Authentication (Skip if using mock data)


In [None]:
# Upload kaggle.json file
import os
import shutil

from google.colab import files

print("Please upload your kaggle.json file:")
uploaded = files.upload()

# Find the uploaded kaggle.json
kaggle_json = None
for filename in uploaded.keys():
    if 'kaggle' in filename.lower() and filename.endswith('.json'):
        kaggle_json = filename
        break

if not kaggle_json:
    raise FileNotFoundError("kaggle.json file not found in uploads")

print(f"✓ Found {kaggle_json}")

# Copy kaggle.json to ~/.kaggle/ and restrict it to owner read/write (600)
os.makedirs('/root/.kaggle', exist_ok=True)
shutil.copy(kaggle_json, '/root/.kaggle/kaggle.json')
os.chmod('/root/.kaggle/kaggle.json', 0o600)

print("✓ Kaggle credentials configured")
print(f"✓ File permissions: {oct(os.stat('/root/.kaggle/kaggle.json').st_mode)[-3:]}")

# Verify Kaggle authentication
!kaggle datasets list | head -5
print("\n✓ Kaggle authentication verified")


Please upload your kaggle.json file:


Saving kaggle.json to kaggle.json
✓ Found kaggle.json
✓ Kaggle credentials configured
✓ File permissions: 600
ref                                               title                                           size  lastUpdated                 downloadCount  voteCount  usabilityRating  
------------------------------------------------  ----------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
neurocipher/heartdisease                          Heart Disease                                   3491  2025-12-11 15:29:14.327000           2114        179  1.0              
kundanbedmutha/exam-score-prediction-dataset      Exam Score Prediction Dataset                 325454  2025-11-28 07:29:01.047000           5863        222  1.0              
dansbecker/melbourne-housing-snapshot             Melbourne Housing Snapshot                    461423  2018-06-05 12:52:24.087000         200095       1720  0.7058824        

✓ Kaggle authentication verified


In [6]:
# Check disk space before download
!df -h / | tail -1

# Create data directory on Colab local disk (NOT Google Drive)
DATA_DIR = '/content/kaggle_data'
os.makedirs(DATA_DIR, exist_ok=True)

print(f"\n✓ Data directory: {DATA_DIR}")

# Get free space in GB
import shutil
total, used, free = shutil.disk_usage('/')
free_gb = free / (1024**3)
print(f"✓ Free disk space: {free_gb:.2f} GB")

# Warn if less than 5 GB is free: the validation file is ~1.7 GB, and the ~1.5 GB
# zip download briefly sits next to the extracted file while unzipping
if free_gb < 5:
    print("⚠ WARNING: Low disk space! May cause download failures.")
else:
    print("✓ Sufficient disk space available")


overlay         113G   39G   75G  35% /

✓ Data directory: /content/kaggle_data
✓ Free disk space: 74.13 GB
✓ Sufficient disk space available
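The 5 GB threshold in the cell above leaves room for the peak during extraction, when the downloaded archive and the extracted file both exist. The sketch below works through that arithmetic. It uses the extracted size from the Kaggle file listing (1,793,119,395 bytes) and an assumed archive size of about 1.58e9 bytes, matching the 1.47 GiB the download reports. `has_room_for_extraction` and its 1 GiB default headroom are hypothetical choices, not part of the project.

```python
import shutil

GIB = 1024 ** 3


def peak_extraction_bytes(archive_bytes, extracted_bytes):
    """Disk needed while unzipping: the archive is only deleted after extraction."""
    return archive_bytes + extracted_bytes


def has_room_for_extraction(archive_bytes, extracted_bytes, path="/", headroom_bytes=GIB):
    """True if the filesystem holding `path` can fit the peak plus some headroom."""
    free = shutil.disk_usage(path).free
    return free >= peak_extraction_bytes(archive_bytes, extracted_bytes) + headroom_bytes


ARCHIVE_BYTES = 1_580_000_000    # assumption: ~1.47 GiB zip, as reported after the download
EXTRACTED_BYTES = 1_793_119_395  # fashiongen_256_256_validation.h5 in the Kaggle file listing

peak = peak_extraction_bytes(ARCHIVE_BYTES, EXTRACTED_BYTES)
print(f"Peak during extraction: {peak / GIB:.2f} GiB")
print("Enough space:", has_room_for_extraction(ARCHIVE_BYTES, EXTRACTED_BYTES))
```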


In [7]:
# List available files in Fashion-Gen dataset
print("Available files in Fashion-Gen dataset:")
!kaggle datasets files bothin/fashiongen-validation


Available files in Fashion-Gen dataset:
name                                     size  creationDate                
--------------------------------  -----------  --------------------------  
fashiongen_256_256_train.h5       14387677469  2023-05-27 05:27:54.521000  
fashiongen_256_256_validation.h5   1793119395  2023-05-27 05:19:00.258000  


In [16]:
# Download ONLY the validation H5 file (much smaller than the ~14.4 GB train file)
import subprocess
import zipfile

H5_FILENAME = 'fashiongen_256_256_validation.h5'
EXPECTED_SIZE_GB = 1.67  # 1,793,119,395 bytes per the Kaggle file listing, in GiB

# Clean up any existing corrupted files
existing_file = os.path.join(DATA_DIR, H5_FILENAME)
if os.path.exists(existing_file):
    print("⚠ Found existing file, checking if valid...")
    file_size_gb = os.path.getsize(existing_file) / (1024**3)
    if file_size_gb < 0.5:  # Too small, likely incomplete
        print(f"⚠ File too small ({file_size_gb:.2f} GB), deleting...")
        os.remove(existing_file)
    else:
        print(f"Existing file size: {file_size_gb:.2f} GB (expected ~{EXPECTED_SIZE_GB} GB)")

print(f"\nDownloading {H5_FILENAME}...")
print("This may take a few minutes...")
print("Note: Kaggle may return a zip file saved under the .h5 name; the cells below extract it")

try:
    result = subprocess.run(
        ['kaggle', 'datasets', 'download', '-d', 'bothin/fashiongen-validation',
         '-f', H5_FILENAME, '-p', DATA_DIR],
        capture_output=True,
        text=True,
        timeout=1800  # 30 minute timeout
    )

    if result.returncode == 0:
        print("\n✓ Download command completed")
        if result.stdout:
            print(result.stdout)
    else:
        print(f"\n⚠ Download command returned code {result.returncode}")
        if result.stderr:
            print("Errors:", result.stderr)
        if result.stdout:
            print("Output:", result.stdout)

except subprocess.TimeoutExpired:
    print("\n✗ Download timed out (>30 minutes)")
    print("This may indicate network issues or a very large file")
    raise
except Exception as e:
    print(f"\n✗ Download failed: {e}")
    print("\nDebugging info:")
    !df -h /
    !ls -lh {DATA_DIR}
    !ls -la ~/.kaggle
    raise

# Verify download completed
downloaded_file = os.path.join(DATA_DIR, H5_FILENAME)
if os.path.exists(downloaded_file):
    file_size_gb = os.path.getsize(downloaded_file) / (1024**3)
    print(f"\n✓ File downloaded: {file_size_gb:.2f} GB")
    if zipfile.is_zipfile(downloaded_file):
        print("⚠ The file is a zip archive despite its .h5 name; extract it before opening with h5py")
    elif file_size_gb < 0.5:
        print("⚠ WARNING: File seems too small, download may be incomplete!")
else:
    print(f"⚠ {H5_FILENAME} not found in {DATA_DIR}; look for a .zip archive there instead")



Downloading fashiongen_256_256_validation.h5...
This may take a few minutes...
Note: Kaggle may return a zip file - we'll handle that automatically

✓ Download command completed
Dataset URL: https://www.kaggle.com/datasets/bothin/fashiongen-validation
License(s): unknown
Downloading fashiongen_256_256_validation.h5 to /content/kaggle_data



✓ File downloaded: 1.47 GB


In [17]:
!ls -lh /content/kaggle_data


total 1.5G
-rw-r--r-- 1 root root 1.5G May 27  2023 fashiongen_256_256_validation.h5


In [20]:
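# The 1.5 GB ".h5" above is really a zip archive (the extracted file is 1.7 GB), so rename it for unzip.
# Skip this cell and the next two if h5py can already open the file: re-running them would delete it.
!mv /content/kaggle_data/fashiongen_256_256_validation.h5 /content/kaggle_data/fashiongen_256_256_validation.zip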


In [21]:
!unzip -n /content/kaggle_data/fashiongen_256_256_validation.zip -d /content/kaggle_data


Archive:  /content/kaggle_data/fashiongen_256_256_validation.zip
  inflating: /content/kaggle_data/fashiongen_256_256_validation.h5  


In [22]:
!rm -f /content/kaggle_data/fashiongen_256_256_validation.zip
!ls -lh /content/kaggle_data


total 1.7G
-rw-r--r-- 1 root root 1.7G May 27  2023 fashiongen_256_256_validation.h5
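Re-running the three shell cells above after extraction would cause data loss. The rename would turn the real HDF5 file into a `.zip`, `unzip` would fail on it, and `rm -f` would then delete it. A guarded version of the same rename, extract, and delete steps acts only when the file really is a zip archive, so it is safe to re-run. This is a minimal sketch using `zipfile`, and `extract_if_zipped` is a hypothetical helper, not part of the project.

```python
import os
import tempfile
import zipfile


def extract_if_zipped(path):
    """If `path` is a zip archive, extract it next to itself and return the extracted paths.

    Anything that is not a zip archive is left untouched and [] is returned,
    so calling this again after a successful extraction does nothing.
    """
    if not zipfile.is_zipfile(path):
        return []
    out_dir = os.path.dirname(os.path.abspath(path))
    archive = os.path.splitext(path)[0] + ".zip"  # same rename as the `mv` cell
    os.replace(path, archive)
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()
        zf.extractall(out_dir)  # like `unzip -d`
    os.remove(archive)  # like `rm -f`
    return [os.path.join(out_dir, name) for name in names]


# Demo on a throwaway archive; in the notebook, call
# extract_if_zipped(os.path.join(DATA_DIR, H5_FILENAME)).
with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "fashiongen_256_256_validation.h5")
    with zipfile.ZipFile(target, "w") as zf:  # zip saved under the .h5 name
        zf.writestr("fashiongen_256_256_validation.h5", b"hdf5 bytes would go here")
    print(extract_if_zipped(target) == [target])  # True: extracted back to the same name
    print(extract_if_zipped(target))  # []: the second call is a no-op
```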


In [23]:
import h5py

path = "/content/kaggle_data/fashiongen_256_256_validation.h5"
with h5py.File(path, "r") as f:
    print("✅ Opened H5")
    print("Keys:", list(f.keys())[:30])


✅ Opened H5
Keys: ['index', 'index_2', 'input_brand', 'input_category', 'input_composition', 'input_concat_description', 'input_department', 'input_description', 'input_gender', 'input_image', 'input_msrpUSD', 'input_name', 'input_pose', 'input_productID', 'input_season', 'input_subcategory']


In [24]:
# Set H5 file path (you already verified it works)
H5_FILE_PATH = "/content/kaggle_data/fashiongen_256_256_validation.h5"
print(f"✅ H5 file path set: {H5_FILE_PATH}")

✅ H5 file path set: /content/kaggle_data/fashiongen_256_256_validation.h5


In [None]:
# Quick test: Verify dataset loader works
import sys
sys.path.insert(0, PROJECT_DIR)  # repository root cloned in Section 1 (/content/FashionGen)

from data.h5_dataset import FashionGenH5Dataset

test_dataset = FashionGenH5Dataset(
    h5_file_path=H5_FILE_PATH,
    image_size=(224, 224),
    max_seq_len=50,
    vocab_size=10000,
    split='train',
    max_samples=10
)

sample = test_dataset[0]
print(f"✅ Dataset works! Image: {sample['image'].shape}, Label: {sample['label']}")


image: (32528, 256, 256, 3) uint8
text : (32528, 1) |S400
label: (32528, 1) |S100
sample label: [b'JACKETS & COATS']
sample text : [b'Denim-like jogg jacket in blue. Fading and whiskering throughout. Spread collar. Copper tone button closures at front. Flap pockets at chest with metallic logo plaque. Seam pockets at sides. Cinch tabs at back waistband. Single button sleeve cuffs. Tone on tone stitching.']
image min/max: 9 255
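The inspection output above shows how the `text` and `label` fields are stored. Each is an `(N, 1)` array of fixed-width byte strings (`|S400` descriptions, `|S100` categories), so every row is a one-element array such as `[b'JACKETS & COATS']`. Before tokenizing, each row has to be unwrapped and decoded.

The sketch below assumes UTF-8 text with a Latin-1 fallback, since the file's actual encoding is not stated here. `to_text` is a hypothetical helper and may duplicate decoding that `FashionGenH5Dataset` already does internally.

```python
def to_text(value, encoding="utf-8"):
    """Turn one H5 row such as [b'JACKETS & COATS'] into a clean Python string."""
    # Rows of an (N, 1) dataset are one-element arrays (or lists); unwrap them first.
    if not isinstance(value, (bytes, str)) and len(value) == 1:
        value = value[0]
    if isinstance(value, bytes):  # numpy.bytes_ is a subclass of bytes
        value = value.rstrip(b"\x00")  # fixed-width fields can carry NUL padding
        try:
            value = value.decode(encoding)
        except UnicodeDecodeError:
            value = value.decode("latin-1")  # assumption: fallback for non-UTF-8 bytes
    return value.strip()


print(to_text([b"JACKETS & COATS"]))  # JACKETS & COATS
```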


In [None]:
# Clean up disk space - delete cache and temporary files
import shutil
import os

print("Cleaning up disk space...\n")

# 1. Clear pip cache
print("1. Clearing pip cache...")
!pip cache purge 2>/dev/null || echo "No pip cache found"
print("✅ Done\n")

# 2. Clear Python cache (__pycache__)
print("2. Removing Python cache files...")
!find /content -type d -name __pycache__ -exec rm -r {} + 2>/dev/null || echo "No cache found"
!find /content -name "*.pyc" -delete 2>/dev/null || echo "No .pyc files found"
print("✅ Done\n")

# 3. Clear temporary files
print("3. Clearing temporary files...")
!rm -rf /tmp/* 2>/dev/null || echo "No temp files"
!rm -rf /content/tmp/* 2>/dev/null || echo "No content/tmp files"
print("✅ Done\n")

# 4. Clear user-level caches under /root/.cache
#    (this also removes the matplotlib and torch caches, so any pretrained
#    weights stored there will be re-downloaded the next time they are needed)
print("4. Clearing Colab cache...")
!rm -rf /root/.cache/* 2>/dev/null || echo "No cache"
!rm -rf /root/.local/share/Trash/* 2>/dev/null || echo "No trash"
print("✅ Done\n")

# Show disk space after cleanup
print("\n=== Disk Space After Cleanup ===")
!df -h / | tail -1

total, used, free = shutil.disk_usage('/')
free_gb = free / (1024**3)
print(f"\n✅ Free space: {free_gb:.2f} GB")
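The disk check earlier reported about 74 GB free, so check which directories are actually large before deleting anything. For example, `/root/.cache` can hold pretrained weights that would otherwise be re-downloaded. This is a minimal `os.walk` sketch; `dir_size_bytes` is a hypothetical helper, and the paths are the cleanup targets from the cell above.

```python
import os


def dir_size_bytes(path):
    """Total size of regular files under `path` (symlinks are not followed)."""
    total = 0
    for root, _dirs, files in os.walk(path):  # unreadable or missing dirs are skipped
        for name in files:
            full = os.path.join(root, name)
            if not os.path.islink(full):
                try:
                    total += os.path.getsize(full)
                except OSError:
                    pass  # file vanished or is unreadable; skip it
    return total


# Cleanup targets from the cell above; missing directories report 0.
for path in ["/root/.cache", "/tmp", "/root/.local/share/Trash"]:
    print(f"{path}: {dir_size_bytes(path) / 1024**3:.2f} GiB")
```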


In [None]:
# Project already cloned in Section 1
# Just verify it's set up correctly
print(f"✓ Project directory: {PROJECT_DIR}")
print(f"✓ Current directory: {os.getcwd()}")
!ls -la {PROJECT_DIR} | head -10


In [None]:
# Setup Python path and test imports
import sys
sys.path.insert(0, PROJECT_DIR)

from data.h5_dataset import FashionGenH5Dataset
from models.fusion_model import create_fusion_model
print("✅ Imports work")


## 4. Training Configuration


In [None]:
# Training config
USE_MOCK_DATA = True  # Set to False to use real HDF5 data
MAX_SAMPLES = 100  # Use subset for quick testing (None = all samples)
BATCH_SIZE = 32 if torch.cuda.is_available() else 16  # informational only: not passed to train.py below

print(f"Use mock data: {USE_MOCK_DATA}")
print(f"Max samples: {MAX_SAMPLES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Device: {device}")

# Set H5 file path (only used if USE_MOCK_DATA = False)
if not USE_MOCK_DATA:
    H5_FILE_PATH = "/content/kaggle_data/fashiongen_256_256_validation.h5"
    print(f"H5 file path: {H5_FILE_PATH}")
else:
    H5_FILE_PATH = None
    print("Using mock data - no H5 file needed")


## 5. Run Training


In [None]:
# Run training
os.chdir(PROJECT_DIR)

cmd = "python training/train.py"
if not USE_MOCK_DATA:
    # Use real HDF5 data (mock data needs no H5 file)
    cmd += f" --h5_file {H5_FILE_PATH}"

# Pass --max_samples only when a limit is set; the literal string "None" would not mean "all samples"
if MAX_SAMPLES is not None:
    cmd += f" --max_samples {MAX_SAMPLES}"

print(f"Running with {'MOCK' if USE_MOCK_DATA else 'REAL'} DATA: {cmd}\n")
!{cmd}
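Building the command as a list of arguments instead of one f-string has two advantages. It avoids shell-quoting problems if the H5 path ever contains spaces, and it makes optional flags easy to leave out. The sketch below uses only the two flags the cell above already passes (`--h5_file`, `--max_samples`). `build_train_args` is a hypothetical helper, and `sys.executable` stands in for `python` so the script runs under the kernel's own interpreter.

```python
import shlex
import sys


def build_train_args(h5_file=None, max_samples=None, script="training/train.py"):
    """Assemble the train.py invocation, leaving out flags that are not set."""
    args = [sys.executable, script]
    if h5_file is not None:
        args += ["--h5_file", h5_file]
    if max_samples is not None:
        args += ["--max_samples", str(max_samples)]
    return args


args = build_train_args(
    h5_file="/content/kaggle_data/fashiongen_256_256_validation.h5",
    max_samples=100,
)
print(shlex.join(args))  # readable, correctly quoted form of the command
# In the notebook: subprocess.run(args, cwd=PROJECT_DIR, check=True)
```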


## 6. Check Results


In [None]:
# Check results
!ls -lh {PROJECT_DIR}/saved_models/


## 7. Debug Section (if errors occur)


In [None]:
# Run this cell if you encounter errors

print("=== Disk Space ===")
!df -h /

print("\n=== Data Directory ===")
!ls -lh {DATA_DIR}

print("\n=== Kaggle Auth ===")
!ls -la ~/.kaggle

print("\n=== Project Structure ===")
!find {PROJECT_DIR} -type f -name "*.py" | head -10

print("\n=== Common Issues ===")
print("1. 403 Forbidden: Check kaggle.json credentials")
print("2. Disk full: Delete old files or use smaller dataset")
print("3. Download stuck: Check network, try again")
print("4. Import errors: Verify the repo cloned correctly; after a runtime restart, re-run Section 1")
print("5. H5 file corrupted: Re-download the file")
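For the 403 item above, a local check rules out the simplest causes: a missing or malformed `kaggle.json`, or permissions wider than owner-only (the upload cell sets 600). The file is a JSON object with `username` and `key` fields. In the sketch below, `check_kaggle_json` is a hypothetical helper. An empty result does not prove the key is still valid on Kaggle's side.

```python
import json
import os
import stat


def check_kaggle_json(path=os.path.expanduser("~/.kaggle/kaggle.json")):
    """Return a list of problems found with a Kaggle credentials file (empty if none)."""
    if not os.path.exists(path):
        return [f"{path} does not exist"]
    problems = []
    mode = stat.S_IMODE(os.stat(path).st_mode)
    if mode & (stat.S_IRWXG | stat.S_IRWXO):  # any group/other permission bits set
        problems.append(f"permissions are {oct(mode)}; run chmod 600 {path}")
    try:
        with open(path) as fh:
            creds = json.load(fh)
    except (OSError, json.JSONDecodeError) as exc:
        return problems + [f"cannot read JSON: {exc}"]
    for field in ("username", "key"):
        if not isinstance(creds, dict) or not creds.get(field):
            problems.append(f"missing or empty '{field}'")
    return problems


print(check_kaggle_json() or "kaggle.json looks OK")
```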
