# ViT + ResNet50 Fusion Model (FIXED VERSION)



In [None]:
# @title 1. Setup & Install
!pip install -q wandb transformers gdown scikit-learn

import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
import gdown
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score, classification_report
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import ViTForImageClassification
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Report the library version and accelerator availability up front.
_env_report = {
    "PyTorch version": torch.__version__,
    "CUDA available": torch.cuda.is_available(),
}
for _name, _value in _env_report.items():
    print(f"{_name}: {_value}")

# Log this session in to Weights & Biases
wandb.login()

In [None]:
# @title 2. Download Data with PROPER SPLIT
def load_csv_from_gdrive(share_url: str, **read_csv_kwargs) -> pd.DataFrame:
    """Read a CSV shared on Google Drive into a DataFrame.

    The file id is the path segment that follows ``/d/`` in *share_url*; it is
    inserted into a ``uc?id=`` download link that is handed to ``pd.read_csv``
    together with any extra keyword arguments.

    Raises:
        IndexError: if *share_url* contains no ``/d/`` segment.
    """
    after_marker = share_url.split("/d/")[1]
    file_id, _, _ = after_marker.partition("/")
    return pd.read_csv(
        f"https://drive.google.com/uc?id={file_id}",
        **read_csv_kwargs,
    )

# --- Fetch the feature and label tables from Google Drive --------------------
print("Downloading CSV data...")
X_train_url = "https://drive.google.com/file/d/1geSiJTTjamysiSbJ8-W9gR1kv-x6HyEd/view?usp=drive_link"
y_train_url = "https://drive.google.com/file/d/16czWmLR5Ff0s5aYIqy1rHT7hc6Gcpfw3/view?usp=sharing"

X_train_full = load_csv_from_gdrive(X_train_url)
y_train_full = load_csv_from_gdrive(y_train_url)

print(f"Total data loaded: {len(X_train_full):,} samples")

# --- Stratified split: 85% development / 15% hold-out ------------------------
_rule = "=" * 80
print("\n" + _rule)
print("SPLITTING DATA (85% dev / 15% holdout)")
print(_rule)

# The product type code is both the target and the stratification key.
_targets = y_train_full['prdtypecode']
X_dev, X_holdout, y_dev, y_holdout = train_test_split(
    X_train_full,
    _targets,
    test_size=0.15,
    random_state=42,
    stratify=_targets,
)

# Re-attach the target column to each feature table (aligned on the index).
df_dev = X_dev.assign(prdtypecode=y_dev)
df_holdout = X_holdout.assign(prdtypecode=y_holdout)

print(f"✓ Development set: {len(df_dev):,} samples (85%)")
print(f"✓ Hold-out test set: {len(df_holdout):,} samples (15%)")
print(f"✓ Classes: {df_dev['prdtypecode'].nunique()}")
print("\n⚠️  CRITICAL: Holdout set will ONLY be used for final evaluation!")
print(_rule)

In [None]:
# @title 3. Download Images
IMAGE_FILE_ID = "15ZkS0iTQ7j3mHpxil4mABlXwP-jAN_zi"

if not os.path.exists("/content/images"):
    print("Downloading images...")
    !mkdir -p /content/tmp /content/images
    !gdown --id $IMAGE_FILE_ID -O /content/tmp/images.zip
    print("Unzipping images...")
    !unzip -q -o /content/tmp/images.zip -d /content/images
    print("Images unzipped")
else:
    print("Images already exist")

IMG_ROOT = "/content/images/images/image_train"
print(f"Image root: {IMG_ROOT}")

In [None]:
# @title 4. Download Pretrained Models
def download_model_from_drive(url, output_path):
    """Download a Google Drive file to *output_path* unless it already exists.

    The file id is the path segment following ``/d/`` in *url*; the download
    itself is delegated to ``gdown.download``. When *output_path* is already
    present nothing is downloaded and a short notice is printed instead.
    """
    if os.path.exists(output_path):
        print(f"{output_path} already exists")
        return

    file_id = url.split("/d/")[1].split("/")[0]
    gdown.download(
        f'https://drive.google.com/uc?id={file_id}',
        output_path,
        quiet=False,
    )

# Share links of the pretrained weight files and where each one is stored.
resnet_url = "https://drive.google.com/file/d/1MWtAKTHtx_1qRYLEac-DDAStZdFD47yq/view?usp=sharing"
vit_url = "https://drive.google.com/file/d/1WDaQZgNKuLvPq_J4HdWtRk1bSECPzip8/view?usp=sharing"

RESNET_PATH = '/content/best_model_resnet.pth'
VIT_PATH = '/content/best_model_vit.pth'

print("Downloading model weights...")
for _share_url, _local_path in ((resnet_url, RESNET_PATH), (vit_url, VIT_PATH)):
    download_model_from_drive(_share_url, _local_path)
print("Models ready!")

In [None]:
# @title 5. Label Encoding (FIT ON DEV ONLY!)
_rule = "=" * 80
print(_rule)
print("LABEL ENCODING (DEV SET ONLY)")
print(_rule)

# The encoder learns its class vocabulary from the development split only;
# the hold-out labels are merely mapped through the fitted encoder.
le = LabelEncoder().fit(df_dev['prdtypecode'])

for _frame in (df_dev, df_holdout):
    _frame['encoded_label'] = le.transform(_frame['prdtypecode'])

num_classes = len(le.classes_)
print(f"✓ LabelEncoder fitted on dev set ONLY (no data leakage)")
print(f"✓ Number of classes: {num_classes}")
# Sanity check: the rest of the notebook expects exactly 27 classes.
assert num_classes == 27, f"Expected 27 classes, got {num_classes}"
print(_rule)

In [None]:
# @title 6. Split Dev into Train/Val
# Hyper-parameters and runtime settings shared by the remaining cells.
CONFIG = dict(
    random_state=42,
    val_split=0.15,
    batch_size=32,
    num_workers=2,
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=10,
    lr=1e-3,
)

_rule = "=" * 80
print(_rule)
print("SPLITTING DEV SET (85% train / 15% val)")
print(_rule)

# Split on index labels (stratified by encoded class), then materialise each
# part as its own frame with a fresh 0..n-1 index.
train_indices, val_indices = train_test_split(
    df_dev.index,
    test_size=CONFIG['val_split'],
    random_state=CONFIG['random_state'],
    stratify=df_dev['encoded_label'],
)
df_train, df_val = (
    df_dev.loc[_idx].reset_index(drop=True)
    for _idx in (train_indices, val_indices)
)

# Percentages are relative to the complete data set (dev + hold-out).
total_samples = len(df_dev) + len(df_holdout)
for _prefix, _part in (("✓ Training:   ", df_train), ("✓ Validation: ", df_val)):
    print(f"{_prefix}{len(_part):,} samples (~{len(_part)/total_samples*100:.1f}%)")
print(f"✓ Hold-out:   {len(df_holdout):,} samples (15.0%)")
print("\n⚠️  Model selection will use Train/Val ONLY")
print("⚠️  Holdout will be evaluated at the END")
print(_rule)

In [None]:
# @title 7. Define Models and Datasets
class RakutenImageDataset(Dataset):
    """Single-view image dataset backed by a product DataFrame.

    Each row supplies ``imageid`` and ``productid`` (combined into the file
    name ``image_<imageid>_product_<productid>.jpg`` under *image_dir*) and a
    label read from *label_col*.
    """

    def __init__(self, dataframe, image_dir, transform=None, label_col="encoded_label"):
        self.image_dir = Path(image_dir)
        self.transform = transform
        # Keep plain arrays so item access does not go through pandas.
        self.image_ids = dataframe['imageid'].values
        self.product_ids = dataframe['productid'].values
        self.labels = dataframe[label_col].values

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a ``torch.long`` tensor."""
        file_name = f"image_{self.image_ids[idx]}_product_{self.product_ids[idx]}.jpg"
        target = self.labels[idx]

        try:
            picture = Image.open(self.image_dir / file_name).convert("RGB")
        except (FileNotFoundError, OSError):
            # Missing or unreadable file: substitute a black 224x224 image.
            picture = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            picture = self.transform(picture)

        return picture, torch.tensor(target, dtype=torch.long)

class ResNet50Classifier(nn.Module):
    """ResNet-50 feature extractor followed by a small classification head.

    The torchvision ResNet-50 is built without weights and its ``fc`` layer is
    replaced by ``nn.Identity`` so that ``backbone`` returns the features that
    used to feed ``fc``; ``custom_head`` maps them to ``num_classes`` logits.
    Attribute names and layer order determine the state-dict keys and are
    therefore part of the checkpoint format.
    """

    def __init__(self, num_classes, dropout_rate=0.3):
        super().__init__()
        trunk = models.resnet50(weights=None)
        feature_dim = trunk.fc.in_features
        trunk.fc = nn.Identity()
        self.backbone = trunk

        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        ]
        self.custom_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for an image batch *x*."""
        return self.custom_head(self.backbone(x))

# Preprocessing pipelines - one per backbone, each with its own normalisation.
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]

resnet_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_RESNET_MEAN, std=_RESNET_STD),
])

vit_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Restore both classifiers from their checkpoints and switch them to eval mode.
# NOTE(review): torch.load unpickles the file, and these checkpoints come from
# a shared Drive link - consider weights_only=True if they hold plain tensors.
print("Loading models...")
_device = CONFIG['device']

resnet_model = ResNet50Classifier(num_classes=num_classes)
state_dict = torch.load(RESNET_PATH, map_location=_device)
# A checkpoint may wrap the weights under 'model_state_dict'; unwrap if so.
state_dict = state_dict['model_state_dict'] if 'model_state_dict' in state_dict else state_dict
resnet_model.load_state_dict(state_dict)
resnet_model.to(_device)
resnet_model.eval()

# The classification layer is re-created for `num_classes` labels
# (ignore_mismatched_sizes) and then overwritten by the checkpoint below.
vit_model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)
state_dict_vit = torch.load(VIT_PATH, map_location=_device)
state_dict_vit = (
    state_dict_vit['model_state_dict']
    if 'model_state_dict' in state_dict_vit
    else state_dict_vit
)
vit_model.load_state_dict(state_dict_vit)
vit_model.to(_device)
vit_model.eval()

print("✓ Models loaded")

In [None]:
# @title 8. Define Fusion Model
class DualDataset(Dataset):
    """Dataset yielding two differently preprocessed views of the same image.

    Each row of *df* supplies ``imageid`` and ``productid`` (combined into the
    file name ``image_<imageid>_product_<productid>.jpg`` under *img_dir*) and
    a label read from *label_col*. Every item is the tuple
    ``(trans_res(image), trans_vit(image), label)``.

    Args:
        df: DataFrame with ``imageid``, ``productid`` and *label_col* columns.
        img_dir: Directory containing the image files.
        trans_res: Callable applied to the PIL image for the first view.
        trans_vit: Callable applied to the PIL image for the second view.
        label_col: Name of the column holding the labels.
    """

    def __init__(self, df, img_dir, trans_res, trans_vit, label_col='encoded_label'):
        self.image_ids = df['imageid'].values
        self.product_ids = df['productid'].values
        self.labels = df[label_col].values
        self.img_dir = Path(img_dir)
        self.trans_res = trans_res
        self.trans_vit = trans_vit

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(resnet_view, vit_view, label)`` for row *idx*.

        A missing or unreadable image file is replaced by a black 224x224
        image so a single bad file does not abort an epoch.
        """
        image_id = self.image_ids[idx]
        product_id = self.product_ids[idx]
        label = self.labels[idx]
        img_name = f"image_{image_id}_product_{product_id}.jpg"
        img_path = self.img_dir / img_name

        try:
            img = Image.open(img_path).convert("RGB")
        except (FileNotFoundError, OSError):
            # Only file-level failures trigger the fallback (same policy as
            # RakutenImageDataset). A bare `except:` would also swallow
            # KeyboardInterrupt/SystemExit and hide programming errors.
            img = Image.new('RGB', (224, 224), (0, 0, 0))

        return self.trans_res(img), self.trans_vit(img), torch.tensor(label, dtype=torch.long)

class FusionNet(nn.Module):
    """Fusion classifier on top of frozen ResNet and ViT feature extractors.

    The ResNet features and the first-token embedding of the ViT encoder are
    concatenated and passed through a small MLP head, which is the only
    trainable part of the model.

    Args:
        resnet: Object exposing the ResNet feature extractor as ``.backbone``.
        vit: Object exposing the ViT encoder as ``.vit``.
        num_classes: Number of output logits.
    """

    def __init__(self, resnet, vit, num_classes=27):
        super().__init__()
        self.resnet_backbone = resnet.backbone
        self.vit_backbone = vit.vit

        # Freeze backbones
        for p in self.resnet_backbone.parameters():
            p.requires_grad = False
        for p in self.vit_backbone.parameters():
            p.requires_grad = False

        # Fusion head (only trainable part)
        self.head = nn.Sequential(
            nn.Linear(2816, 1024),  # 2048 (ResNet) + 768 (ViT)
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

        # Frozen extractors must behave deterministically from the start.
        self.resnet_backbone.eval()
        self.vit_backbone.eval()

    def train(self, mode: bool = True):
        """Switch the head to *mode* while keeping both backbones in eval mode.

        ``nn.Module.train`` propagates to every submodule. Without this
        override, ``model.train()`` would put the frozen backbones in training
        mode: BatchNorm layers would normalise with batch statistics and keep
        overwriting their running statistics (``torch.no_grad`` does not stop
        that), and backbone dropout would become active. The features seen by
        the head during training would then differ from those at evaluation.
        """
        super().train(mode)
        self.resnet_backbone.eval()
        self.vit_backbone.eval()
        return self

    def forward(self, x_res, x_vit):
        """Return class logits for one batch given as two preprocessed views."""
        # The backbones are frozen, so no autograd graph is needed for them.
        with torch.no_grad():
            feat_res = self.resnet_backbone(x_res)
            if feat_res.dim() > 2:
                feat_res = torch.flatten(feat_res, 1)

            out_vit = self.vit_backbone(pixel_values=x_vit)
            # Index 0 along the sequence axis: the first token's embedding.
            feat_vit = out_vit.last_hidden_state[:, 0]

        combined = torch.cat([feat_res, feat_vit], dim=1)
        return self.head(combined)

# One dual-view dataset per split; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
)

train_ds_dual = DualDataset(df_train, IMG_ROOT, resnet_transform, vit_transform)
val_ds_dual = DualDataset(df_val, IMG_ROOT, resnet_transform, vit_transform)
# The hold-out split is only consumed by the final evaluation cell.
holdout_ds_dual = DualDataset(df_holdout, IMG_ROOT, resnet_transform, vit_transform)

train_loader = DataLoader(train_ds_dual, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds_dual, shuffle=False, **_loader_kwargs)
holdout_loader = DataLoader(holdout_ds_dual, shuffle=False, **_loader_kwargs)

# Fusion model: only the head's parameters are handed to the optimiser.
fusion_model = FusionNet(resnet_model, vit_model, num_classes=num_classes)
fusion_model = fusion_model.to(CONFIG['device'])

_head_params = list(fusion_model.head.parameters())
optimizer = AdamW(_head_params, lr=CONFIG['lr'], weight_decay=1e-4)
# mode='max': the scheduler is meant to be stepped with a score to maximise.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
criterion = nn.CrossEntropyLoss()

print("✓ Fusion model initialized")
print(f"  Trainable params: {sum(p.numel() for p in _head_params):,}")

In [None]:
# @title 9. Training Loop with WandB
import datetime

# Detect execution environment
import sys
ENVIRONMENT = "colab" if 'google.colab' in sys.modules else "local"


# Initialize WandB
# `datetime` is imported as a module above, so the class has to be reached as
# `datetime.datetime`; a bare `datetime.now()` raises AttributeError.
wandb.init(
    project="rakuten-fusion",
    entity="xiaosong-dev-formation-data-science",
    # Minute-resolution timestamp keeps run names distinct across re-runs.
    name=f"vit_resnet50_fusion_v1_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}",
    tags=["fusion", "vit", "v1", "production", ENVIRONMENT],
    config=CONFIG,
    notes="ViT + ResNet50 multimodal fusion with proper data split"
)

# Train the fusion model for CONFIG['epochs'] epochs. After every epoch the
# validation split is scored (weighted F1 and accuracy), the metrics are logged
# to WandB and stored in `history`, the scheduler is stepped with the
# validation F1, and the weights are checkpointed whenever that F1 improves.
print("="*80)
print("TRAINING FUSION MODEL")
print("="*80)

# Highest validation F1 seen so far; a checkpoint is written only when beaten.
best_val_f1 = 0.0
# Per-epoch curves, kept in memory alongside the WandB log.
history = {"train_loss": [], "val_f1": [], "val_acc": []}

for epoch in range(CONFIG['epochs']):
    # Training
    # NOTE(review): nn.Module.train() propagates to every submodule, so unless
    # the model class overrides train() the frozen backbones are switched to
    # training mode too (BatchNorm statistics, dropout) - confirm intended.
    fusion_model.train()
    train_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
    for x_res, x_vit, y in pbar:
        x_res = x_res.to(CONFIG['device'])
        x_vit = x_vit.to(CONFIG['device'])
        y = y.to(CONFIG['device'])

        optimizer.zero_grad()
        outputs = fusion_model(x_res, x_vit)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Mean of the per-batch losses (every batch weighted equally).
    avg_train_loss = train_loss / len(train_loader)

    # Validation
    fusion_model.eval()
    val_preds = []
    val_targets = []

    with torch.no_grad():
        for x_res, x_vit, y in val_loader:
            x_res = x_res.to(CONFIG['device'])
            x_vit = x_vit.to(CONFIG['device'])
            outputs = fusion_model(x_res, x_vit)
            val_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
            # Labels are never moved to the device here; they are only
            # compared with the predictions after those are copied back.
            val_targets.extend(y.numpy())

    val_f1 = f1_score(val_targets, val_preds, average='weighted')
    val_acc = accuracy_score(val_targets, val_preds)

    # Log to WandB
    # The learning rate is read before scheduler.step() below, i.e. it is the
    # rate that was in effect during this epoch.
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_f1": val_f1,
        "val_acc": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr']
    })

    # Save history
    history["train_loss"].append(avg_train_loss)
    history["val_f1"].append(val_f1)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch+1}: Loss={avg_train_loss:.4f}, Val F1={val_f1:.4f}, Val Acc={val_acc:.4f}")

    # Learning rate scheduling
    # The scheduler is driven by the validation F1.
    scheduler.step(val_f1)

    # Save best model
    # Strict '>' keeps the earliest epoch on ties; nothing is written if the
    # validation F1 never rises above the initial 0.0.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(fusion_model.state_dict(), "fusion_model_best_FIXED.pth")
        print(f"  ✅ Best model saved! (Val F1: {val_f1:.4f})")

print("\n" + "="*80)
print(f"Training Complete. Best Val F1: {best_val_f1:.4f}")
print("="*80)

In [None]:
# @title 10. FINAL EVALUATION ON HOLDOUT TEST SET
_rule = "=" * 80
print("\n" + _rule)
print("FINAL EVALUATION ON HOLDOUT TEST SET")
print(_rule)
print("⚠️  This is the FIRST and ONLY time holdout data is used!\n")

# Restore the checkpoint written by the training loop and switch to eval mode.
_best_weights = torch.load("fusion_model_best_FIXED.pth")
fusion_model.load_state_dict(_best_weights)
fusion_model.eval()

# Collect predictions and ground truth over the whole hold-out split.
holdout_preds, holdout_targets = [], []
_device = CONFIG['device']

with torch.no_grad():
    for _res_batch, _vit_batch, _label_batch in tqdm(holdout_loader, desc="Holdout Evaluation"):
        _logits = fusion_model(_res_batch.to(_device), _vit_batch.to(_device))
        holdout_preds.extend(_logits.argmax(dim=1).cpu().numpy())
        holdout_targets.extend(_label_batch.numpy())

holdout_f1 = f1_score(holdout_targets, holdout_preds, average='weighted')
holdout_acc = accuracy_score(holdout_targets, holdout_preds)

# Side-by-side summary; "Difference" is hold-out F1 minus best validation F1.
print("\n" + _rule)
print("FINAL RESULTS")
print(_rule)
print(f"Best Validation F1: {best_val_f1:.4f}")
print(f"Holdout Test F1:    {holdout_f1:.4f}")
print(f"Holdout Test Acc:   {holdout_acc:.4f}")
print(f"Difference:         {holdout_f1 - best_val_f1:+.4f}")
print(_rule)

# Record the final numbers in the WandB run.
_final_metrics = {
    "final/best_val_f1": best_val_f1,
    "final/holdout_f1": holdout_f1,
    "final/holdout_acc": holdout_acc,
}
wandb.log(_final_metrics)

# Per-class breakdown (classes appear as their encoded integer labels).
print("\nClassification Report (Holdout):")
print(classification_report(holdout_targets, holdout_preds, digits=4, zero_division=0))

wandb.finish()

In [None]:
# @title 11. Save to Google Drive (Optional)
from google.colab import drive
import shutil

# Mount the user's Google Drive so the checkpoint outlives the Colab session.
drive.mount('/content/drive')

# A second-resolution timestamp in the file name keeps every export distinct.
timestamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"
target_dir = "/content/drive/MyDrive/Rakuten_models"
Path(target_dir).mkdir(parents=True, exist_ok=True)

target_file = os.path.join(
    target_dir,
    f"fusion_vit_resnet_FIXED_{timestamp}.pth",
)
shutil.copy("fusion_model_best_FIXED.pth", target_file)
print(f"✓ Model saved to: {target_file}")