# Swin + DeBERTa Multimodal Fusion (FIXED VERSION)

## ✅ Data Leakage Issues FIXED:
- **Proper data split**: 85% dev (train+val) / 15% holdout (final test)
- **LabelEncoder fitted ONLY on dev set**
- **Holdout set used ONLY for final evaluation**
- **WandB integration** for experiment tracking

## Changes from Original:
1. ❌ Old: All 84,916 samples (100%) went into train/val, leaving no untouched test set → ✅ Fixed: 72,178 dev / 12,738 holdout
2. ❌ Old: train/val split on full data → ✅ Fixed: 15% holdout reserved from start
3. ❌ Old: No experiment tracking → ✅ Fixed: WandB integration
4. ❌ Old: LabelEncoder on all data → ✅ Fixed: Fit only on dev set

In [None]:
# @title 1. Setup Environment & Dependencies
!pip install -q transformers timm gdown pandas scikit-learn matplotlib wandb sentencepiece

import os
import gc
import datetime
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import gdown
import timm
import wandb
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score, classification_report
from tqdm.auto import tqdm
from torchvision import transforms

# Set seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Covers Python's built-in ``random`` module (previously left unseeded),
    NumPy's global generator, and PyTorch on the CPU plus *all* visible CUDA
    devices (``torch.cuda.manual_seed`` only seeded the current one).

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Environment setup complete. Using device: {device}")

# Authenticate with Weights & Biases before any run is started.
wandb.login()

In [None]:
# @title 2. Download Data with PROPER SPLIT
def load_csv_from_gdrive(share_url: str, **read_csv_kwargs) -> pd.DataFrame:
    """Download a CSV shared on Google Drive and parse it with ``pd.read_csv``.

    Accepts ``.../file/d/<id>/view?...`` share links as well as links that
    carry the file ID in an ``id=`` query parameter.

    Args:
        share_url: Google Drive share link.
        **read_csv_kwargs: Forwarded unchanged to ``pd.read_csv``.

    Returns:
        The parsed DataFrame, or ``None`` (after printing a message) when no
        file ID can be extracted from ``share_url``. Download/parse errors are
        not caught here; they propagate with their original type and message.
    """
    from urllib.parse import parse_qs, urlparse

    # Only URL parsing is guarded. The old try/except also wrapped
    # pd.read_csv, so an IndexError raised while parsing the CSV itself was
    # misreported as a URL problem and silently turned into None.
    _, marker, tail = share_url.partition("/d/")
    if marker:
        # Drop any trailing path segment ("/view") or query string ("?usp=...").
        file_id = tail.split("/")[0].split("?")[0]
    else:
        file_id = parse_qs(urlparse(share_url).query).get("id", [""])[0]

    if not file_id:
        print(f"Error parsing URL: {share_url}")
        return None

    download_url = f"https://drive.google.com/uc?id={file_id}"
    return pd.read_csv(download_url, **read_csv_kwargs)

print("Downloading CSV data...")
X_train_url = "https://drive.google.com/file/d/1geSiJTTjamysiSbJ8-W9gR1kv-x6HyEd/view?usp=drive_link"
y_train_url = "https://drive.google.com/file/d/16czWmLR5Ff0s5aYIqy1rHT7hc6Gcpfw3/view?usp=sharing"

try:
    X_train_full = load_csv_from_gdrive(X_train_url)
    y_train_full = load_csv_from_gdrive(y_train_url)
    if X_train_full is None or y_train_full is None:
        raise ValueError("Failed to load DataFrames")
    # Features and labels live in separate files and are paired row by row
    # (the label Series is re-attached by index below), so lengths must agree.
    if len(X_train_full) != len(y_train_full):
        raise ValueError(
            f"Feature/label row counts differ: {len(X_train_full):,} vs {len(y_train_full):,}"
        )
except Exception as e:
    print(f"CSV download failed: {e}")
    # Every later cell needs these DataFrames: stop here instead of continuing
    # and failing further down with a confusing NameError.
    raise

print(f"Total data loaded: {len(X_train_full):,} samples")

# ============================================================================
# ✅ FIX: PROPER DATA SPLIT (85% dev / 15% holdout)
# ============================================================================
print("\n" + "="*80)
print("SPLITTING DATA (85% dev / 15% holdout)")
print("="*80)

# Stratified so every class keeps the same share in dev and holdout.
X_dev, X_holdout, y_dev, y_holdout = train_test_split(
    X_train_full,
    y_train_full['prdtypecode'],
    test_size=0.15,
    random_state=42,
    stratify=y_train_full['prdtypecode']
)

df_dev = X_dev.copy()
df_dev['prdtypecode'] = y_dev

df_holdout = X_holdout.copy()
df_holdout['prdtypecode'] = y_holdout

print(f"✓ Development set: {len(df_dev):,} samples (85%)")
print(f"✓ Hold-out test set: {len(df_holdout):,} samples (15%)")
print(f"✓ Classes: {df_dev['prdtypecode'].nunique()}")
print("\n⚠️  CRITICAL: Holdout set will ONLY be used for final evaluation!")
print("="*80)

In [None]:
# @title 3. Download Images
IMAGE_FILE_ID = "15ZkS0iTQ7j3mHpxil4mABlXwP-jAN_zi"
IMAGE_ZIP = "/content/tmp/images.zip"
IMG_ROOT = "/content/images/images/image_train"

# Check for the extracted training-image folder itself. The old check looked
# at /content/images, which is created *before* the download starts, so an
# interrupted download/unzip made every later run skip this step forever.
if not os.path.isdir(IMG_ROOT):
    import zipfile

    print("\nDownloading images...")
    os.makedirs("/content/tmp", exist_ok=True)
    os.makedirs("/content/images", exist_ok=True)
    # Python API (as used for the model weights) instead of `!gdown --id`,
    # whose --id option is deprecated and removed in newer gdown releases.
    gdown.download(id=IMAGE_FILE_ID, output=IMAGE_ZIP, quiet=False)

    print("Unzipping images...")
    with zipfile.ZipFile(IMAGE_ZIP) as archive:
        archive.extractall("/content/images")  # overwrites, like `unzip -o`
    print("Images unzipped")
else:
    print("\nImages already exist, skipping download")

print(f"Image Root: {IMG_ROOT}")

In [None]:
# @title 4. Download Custom Model Weights
print("\nDownloading custom model weights...")

# Local checkpoint filename -> Google Drive file ID.
# A file is fetched only when it is not already present on disk.
_CUSTOM_WEIGHTS = {
    'swin_v2_best.pth': '1MM87li5P6pWzCs-7uwlQb0msH8s8kZDf',     # Swin V2
    'deberta_v2_best.pth': '1wNsbubOYXJa611AaeCjkM4bHSstgpPy7',  # DeBERTa V2
}
for _weights_file, _drive_id in _CUSTOM_WEIGHTS.items():
    if os.path.exists(_weights_file):
        continue
    gdown.download(id=_drive_id, output=_weights_file, quiet=False)

print("All files ready.")

In [None]:
# @title 5. Label Encoding (FIT ON DEV ONLY!)
_line = "=" * 80
print(_line)
print("LABEL ENCODING (DEV SET ONLY)")
print(_line)

# ✅ FIX: the encoder's class vocabulary comes from the DEV split only;
# the holdout split is merely transformed with that vocabulary.
le = LabelEncoder()
df_dev['encoded_label'] = le.fit_transform(df_dev['prdtypecode'])  # ✅ FIT ONLY ON DEV
df_holdout['encoded_label'] = le.transform(df_holdout['prdtypecode'])

NUM_CLASSES = le.classes_.size
print(f"✓ LabelEncoder fitted on dev set ONLY (no data leakage)")
print(f"✓ Number of classes: {NUM_CLASSES}")
assert NUM_CLASSES == 27, f"Expected 27 classes, got {NUM_CLASSES}"
print(_line)

In [None]:
# @title 6. Text Preprocessing & Train/Val Split
# Text cleaning: join designation + description (missing parts become "")
# and lowercase, identically for the dev and holdout frames.
def _merge_text(frame):
    """Return lowercased 'designation description' strings for ``frame``."""
    combined = frame['designation'].fillna('') + " " + frame['description'].fillna('')
    return combined.astype(str).str.lower()


for _frame in (df_dev, df_holdout):
    _frame['text'] = _merge_text(_frame)

# Split Dev into Train/Val (stratified on the encoded label)
_sep = "=" * 80
print(_sep)
print("SPLITTING DEV SET (85% train / 15% val)")
print(_sep)

train_df, val_df = train_test_split(
    df_dev,
    stratify=df_dev['encoded_label'],
    test_size=0.15,
    random_state=42,
)

# Shares are reported relative to the full dataset (dev + holdout).
total_samples = len(df_dev) + len(df_holdout)
for _label, _n in (("Training:  ", len(train_df)), ("Validation:", len(val_df))):
    print(f"✓ {_label} {_n:,} samples (~{_n/total_samples*100:.1f}%)")
print(f"✓ Hold-out:   {len(df_holdout):,} samples (15.0%)")
print("\n⚠️  Model selection will use Train/Val ONLY")
print("⚠️  Holdout will be evaluated at the END")
print(_sep)

In [None]:
# @title 7. Dataset Definition
class RakutenMultiModalDataset(Dataset):
    """Map-style dataset yielding one (image, text, label) sample per DataFrame row.

    ``df`` must provide the columns ``imageid``, ``productid``, ``text`` and
    ``encoded_label``. Images are read from
    ``<img_root>/image_<imageid>_product_<productid>.jpg``; a missing or
    unreadable file is replaced by a black 224x224 RGB image.
    """

    def __init__(self, df, img_root, tokenizer, transform=None, max_len=128):
        """
        Args:
            df: Source rows; re-indexed 0..n-1 so positional ``iloc`` lookup is safe.
            img_root: Directory containing the product images.
            tokenizer: Hugging Face tokenizer used to encode ``text``.
            transform: Optional callable applied to the PIL image.
            max_len: Token sequence length (texts are truncated/padded to it).
        """
        self.df = df.reset_index(drop=True)
        self.img_root = img_root
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with ``pixel_values``, ``input_ids``, ``attention_mask`` and ``labels``."""
        row = self.df.iloc[idx]

        # 1. Image Processing
        img_name = f"image_{row['imageid']}_product_{row['productid']}.jpg"
        img_path = os.path.join(self.img_root, img_name)
        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except (FileNotFoundError, OSError):
            # Fallback for missing/corrupt images
            image = Image.new('RGB', (224, 224), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # 2. Text Processing (fixed-length so default collation can stack)
        text = str(row['text'])
        inputs = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors="pt"
        )

        return {
            'pixel_values': image,
            # squeeze(0) drops only the batch axis added by return_tensors="pt";
            # a bare squeeze() would also collapse the sequence axis when max_len == 1.
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': torch.tensor(row['encoded_label'], dtype=torch.long)
        }

print(f"Dataset class ready. Training: {len(train_df):,}, Validation: {len(val_df):,}, Holdout: {len(df_holdout):,}")

In [None]:
# @title 8. Define Fusion Model
class RakutenFusionModel(nn.Module):
    """Late-fusion classifier over a DeBERTa text encoder and a Swin image encoder.

    The pooled Swin features and the DeBERTa first-token (CLS) embedding are
    concatenated and fed to a small MLP head that outputs one logit per class.

    NOTE: the attribute names (``text_backbone``, ``img_backbone``,
    ``classifier``) and the layer order inside ``classifier`` determine the
    ``state_dict`` keys of saved checkpoints; renaming or reordering them
    breaks loading of existing ``.pth`` files.
    """

    def __init__(self, num_classes: int = 27, text_model_name: str = "microsoft/deberta-v3-base", img_model_name: str = "swin_base_patch4_window7_224"):
        """Build both backbones and the fusion head.

        Args:
            num_classes: Number of output logits.
            text_model_name: Hugging Face model id, loaded with ``AutoModel.from_pretrained``.
            img_model_name: timm model name, created with timm's pretrained
                weights and no classification head (``num_classes=0``), so it
                returns pooled features.
        """
        super().__init__()

        # --- Text Backbone (DeBERTa) ---
        self.text_backbone = AutoModel.from_pretrained(text_model_name)
        text_dim = self.text_backbone.config.hidden_size  # 768 for Base

        # --- Image Backbone (Swin) ---
        self.img_backbone = timm.create_model(img_model_name, pretrained=True, num_classes=0)
        img_dim = self.img_backbone.num_features  # 1024 for Base

        # --- Fusion Head ---
        # Input is the concatenation [image features | text features].
        # BatchNorm1d raises in training mode on a batch of size 1, so training
        # batches must contain at least two samples.
        fusion_dim = text_dim + img_dim
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(fusion_dim),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, pixel_values: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``[Batch, num_classes]``.

        Args:
            pixel_values: Image batch for the Swin backbone (presumably
                normalized ``[Batch, 3, 224, 224]`` tensors; confirm against the transforms).
            input_ids: Token ids for the text backbone.
            attention_mask: Attention mask matching ``input_ids``.
        """
        # Image Features
        img_feats = self.img_backbone(pixel_values)  # [Batch, 1024]

        # Text Features (CLS Token)
        text_out = self.text_backbone(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = text_out.last_hidden_state[:, 0, :]  # [Batch, 768]

        # Concatenate (image first, then text; order must match the trained head)
        features = torch.cat([img_feats, text_feats], dim=1)

        # Classify
        logits = self.classifier(features)
        return logits

# Key prefixes commonly added when a backbone is saved from a wrapper model:
# DataParallel ("module."), a HF *ForSequenceClassification model ("deberta."),
# or a custom nn.Module holding the backbone ("model.", "backbone.", ...).
_CHECKPOINT_PREFIXES = (
    "",
    "module.",
    "model.",
    "backbone.",
    "deberta.",
    "img_backbone.",
    "text_backbone.",
    "module.model.",
    "module.backbone.",
    "module.deberta.",
)


def _unwrap_checkpoint(checkpoint):
    """Return the parameter mapping held by a loaded checkpoint object."""
    if isinstance(checkpoint, nn.Module):
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            nested = checkpoint.get(key)
            if isinstance(nested, dict):
                return nested
    return checkpoint


def _align_state_dict(state, module, drop_substrings=()):
    """Map checkpoint keys onto ``module``'s own keys.

    Keys containing any of ``drop_substrings`` are discarded first. Then the
    prefix from ``_CHECKPOINT_PREFIXES`` that lines up the most keys with
    ``module`` is stripped (ties keep the earliest candidate, i.e. no stripping).

    Returns:
        ``(matched, shape_mismatched)``: ``matched`` holds only tensors whose
        key exists in ``module`` with an identical shape; ``shape_mismatched``
        lists keys that exist in ``module`` but with a different shape.
    """
    target = module.state_dict()
    kept = {k: v for k, v in state.items() if not any(s in k for s in drop_substrings)}

    def strip(prefix):
        if not prefix:
            return kept
        return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in kept.items()}

    remapped = max(
        (strip(p) for p in _CHECKPOINT_PREFIXES),
        key=lambda cand: len(cand.keys() & target.keys()),
    )

    matched, shape_mismatched = {}, []
    for key, value in remapped.items():
        if key not in target:
            continue
        if getattr(value, "shape", None) == target[key].shape:
            matched[key] = value
        else:
            shape_mismatched.append(key)
    return matched, shape_mismatched


def _load_backbone(module, path, label, drop_substrings=()):
    """Load the checkpoint at ``path`` into ``module`` and print a short summary."""
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location='cpu', weights_only=False)
    state = _unwrap_checkpoint(checkpoint)
    matched, shape_mismatched = _align_state_dict(state, module, drop_substrings)
    msg = module.load_state_dict(matched, strict=False)
    print(
        f" -> {label} weights loaded: {len(matched)}/{len(module.state_dict())} tensors. "
        f"Missing keys: {len(msg.missing_keys)}, shape-mismatched (skipped): {len(shape_mismatched)}"
    )
    if not matched:
        print(f" -> WARNING: no {label} tensors matched the model; it keeps its previous weights.")


def smart_load_weights(model, swin_path, deberta_path):
    """Load custom backbone weights into ``model``, skipping incompatible tensors.

    Loads the Swin checkpoint into ``model.img_backbone`` and the DeBERTa
    checkpoint (minus ``classifier``/``pooler`` tensors) into
    ``model.text_backbone``. Wrapper prefixes such as ``module.``/``deberta.``
    are stripped automatically, and tensors whose shape differs from the
    model's are skipped instead of aborting the whole load. Each backbone is
    best-effort: an error is printed and the backbone keeps its current weights.

    Args:
        model: Object exposing ``img_backbone`` and ``text_backbone`` modules.
        swin_path: Path to the Swin checkpoint.
        deberta_path: Path to the DeBERTa checkpoint.
    """
    print("Loading weights...")

    # 1. Load Swin Weights
    try:
        _load_backbone(model.img_backbone, swin_path, "Swin")
    except Exception as e:
        print(f" -> Error loading Swin weights: {e}")

    # 2. Load DeBERTa Weights (the task head and pooler are not part of AutoModel)
    try:
        _load_backbone(
            model.text_backbone, deberta_path, "DeBERTa",
            drop_substrings=("classifier", "pooler"),
        )
    except Exception as e:
        print(f" -> Error loading DeBERTa weights: {e}")

print("Model architecture defined.")

In [None]:
# @title 9. Initialize Model and Dataloaders
BATCH_SIZE = 16
LR = 2e-5
EPOCHS = 5
IMG_SIZE = 224
# ImageNet channel statistics, shared by the train and eval transforms.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Tokenizer & Transforms
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

transform_train = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

transform_val = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Instantiate Datasets & Loaders
train_ds = RakutenMultiModalDataset(train_df, IMG_ROOT, tokenizer, transform=transform_train)
val_ds = RakutenMultiModalDataset(val_df, IMG_ROOT, tokenizer, transform=transform_val)
holdout_ds = RakutenMultiModalDataset(df_holdout, IMG_ROOT, tokenizer, transform=transform_val)

# drop_last=True: the fusion head uses BatchNorm1d, which raises in training
# mode on a batch of size 1, i.e. whenever len(train_ds) % BATCH_SIZE == 1.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
holdout_loader = DataLoader(holdout_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# Initialize Model
model = RakutenFusionModel(num_classes=NUM_CLASSES)
model.to(device)

# Load Custom Weights
smart_load_weights(model, 'swin_v2_best.pth', 'deberta_v2_best.pth')

print("Model initialized and weights loaded.")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# @title 10. Training with WandB

# Detect execution environment
import sys
ENVIRONMENT = "colab" if 'google.colab' in sys.modules else "local"

# Run configuration recorded with the WandB run. It was passed to wandb.init
# but never defined (NameError); built from the hyper-parameters above.
CONFIG = {
    "text_model": "microsoft/deberta-v3-base",
    "image_model": "swin_base_patch4_window7_224",
    "num_classes": NUM_CLASSES,
    "batch_size": BATCH_SIZE,
    "learning_rate": LR,
    "epochs": EPOCHS,
    "img_size": IMG_SIZE,
    "max_len": train_ds.max_len,
    "optimizer": "AdamW",
    "weight_decay": 0.01,
    "scheduler": "OneCycleLR",
    "seed": 42,
    "train_samples": len(train_df),
    "val_samples": len(val_df),
    "holdout_samples": len(df_holdout),
    "environment": ENVIRONMENT,
}

# Initialize WandB
# `datetime` is imported as a module at the top of the notebook, so the class
# must be reached as datetime.datetime (datetime.now() raised AttributeError).
wandb.init(
    project="rakuten-fusion",
    entity="xiaosong-dev-formation-data-science",
    name=f"swin_deberta_fusion_v1_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}",
    tags=["fusion", "swin", "v1", "production", ENVIRONMENT],
    config=CONFIG,
    notes="Swin + DeBERTa multimodal fusion with proper data split"
)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=CONFIG["weight_decay"])
criterion = nn.CrossEntropyLoss()
# One scheduler step per training batch, hence steps_per_epoch=len(train_loader).
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LR, steps_per_epoch=len(train_loader), epochs=EPOCHS
)

def evaluate(model, loader):
    """Score ``model`` on every batch of ``loader`` with gradients disabled.

    Uses the notebook-level ``criterion`` and ``device``.

    Returns:
        Tuple ``(mean per-batch loss, weighted F1, accuracy)``.
    """
    model.eval()
    total_loss = 0
    predicted, expected = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating"):
            pixel_values, input_ids, mask, labels = (
                batch[name].to(device)
                for name in ('pixel_values', 'input_ids', 'attention_mask', 'labels')
            )

            logits = model(pixel_values, input_ids, mask)
            total_loss += criterion(logits, labels).item()

            predicted += logits.argmax(dim=1).tolist()
            expected += labels.tolist()

    weighted_f1 = f1_score(expected, predicted, average='weighted')
    accuracy = accuracy_score(expected, predicted)
    return total_loss / len(loader), weighted_f1, accuracy

# Training Loop
print("="*80)
print(f"TRAINING MULTIMODAL FUSION ({EPOCHS} epochs)")
print("="*80)

best_f1 = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values, input_ids, mask)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()
        pbar.set_postfix({'loss': loss.item()})

    # Validation
    val_loss, val_f1, val_acc = evaluate(model, val_loader)
    
    # Log to WandB
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss / len(train_loader),
        "val_loss": val_loss,
        "val_f1": val_f1,
        "val_acc": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr']
    })
    
    print(f"Epoch {epoch+1}: Loss={train_loss/len(train_loader):.4f}, Val F1={val_f1:.4f}, Val Acc={val_acc:.4f}")

    # Save Best Model
    if val_f1 > best_f1:
        best_f1 = val_f1
        torch.save(model.state_dict(), "fusion_model_best_FIXED.pth")
        print("  ✅ New Best Model Saved!")

print("\n" + "="*80)
print(f"Training Complete. Best Val F1: {best_f1:.4f}")
print("="*80)

In [None]:
# @title 11. FINAL EVALUATION ON HOLDOUT TEST SET
print("\n" + "="*80)
print("FINAL EVALUATION ON HOLDOUT TEST SET")
print("="*80)
print("⚠️  This is the FIRST and ONLY time holdout data is used!\n")

# Load best model (mapped onto the active device)
model.load_state_dict(torch.load("fusion_model_best_FIXED.pth", map_location=device))

# Single inference pass over the holdout set. Metrics and the classification
# report are all derived from these predictions; the old cell ran the full
# Swin+DeBERTa model over the holdout set twice (evaluate() + a second loop).
model.eval()
all_preds, all_targets = [], []
holdout_loss_sum = 0.0
with torch.no_grad():
    for batch in tqdm(holdout_loader, desc="Holdout"):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values, input_ids, mask)
        holdout_loss_sum += criterion(outputs, labels).item()
        all_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
        all_targets.extend(labels.cpu().numpy())

# Same definitions as evaluate(): mean per-batch loss, weighted F1, accuracy.
holdout_loss = holdout_loss_sum / len(holdout_loader)
holdout_f1 = f1_score(all_targets, all_preds, average='weighted')
holdout_acc = accuracy_score(all_targets, all_preds)

print("\n" + "="*80)
print("FINAL RESULTS")
print("="*80)
print(f"Best Validation F1: {best_f1:.4f}")
print(f"Holdout Test F1:    {holdout_f1:.4f}")
print(f"Holdout Test Acc:   {holdout_acc:.4f}")
print(f"Difference:         {holdout_f1 - best_f1:+.4f}")
print("="*80)

# Log final results to WandB
wandb.log({
    "final/best_val_f1": best_f1,
    "final/holdout_f1": holdout_f1,
    "final/holdout_acc": holdout_acc,
    "final/holdout_loss": holdout_loss
})

# Report per original product-type code instead of encoded index 0..N-1.
# `labels` pins the row order to the encoder's classes so target_names stay
# aligned even if some class never appears in targets or predictions.
print("\nClassification Report (Holdout):")
print(classification_report(
    all_targets,
    all_preds,
    labels=list(range(NUM_CLASSES)),
    target_names=[str(c) for c in le.classes_],
    digits=4,
    zero_division=0,
))

wandb.finish()
print("\n✅ Training and evaluation complete!")

In [None]:
# @title 12. Save to Google Drive (Optional)
from google.colab import drive
import shutil

drive.mount('/content/drive')

# Timestamped copy so repeated runs never overwrite an earlier checkpoint.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
target_dir = "/content/drive/MyDrive/Rakuten_models"
os.makedirs(target_dir, exist_ok=True)

checkpoint_name = "fusion_swin_deberta_FIXED_" + timestamp + ".pth"
target_file = os.path.join(target_dir, checkpoint_name)
shutil.copy("fusion_model_best_FIXED.pth", target_file)
print(f"✓ Model saved to: {target_file}")