# Model Training: Product Matching

This notebook demonstrates how to train models for product similarity and matching tasks.

**Problem**: Match duplicate/similar products across Shopee listings

**Approach**: 
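1. Extract text features (title character and word counts) and image features (pre-trained ResNet50)
2. Create a pairs dataset (positive and negative pairs)
3. Train a similarity model (a Siamese network; metric learning is a possible alternative)
4. Evaluate pair-level matching metrics on a held-out validation split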

In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
warnings.filterwarnings('ignore')  # NOTE: this also hides library deprecation warnings

# Reproducibility: seed torch's global RNG. It drives the random init of the
# new (non-pretrained) layers, dropout masks and DataLoader shuffling.
# Bit-for-bit determinism on GPU additionally depends on cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set up paths
data_dir = Path('../shopee-product-matching-data')
train_csv = data_dir / 'train.csv'
train_images_dir = data_dir / 'train_images'

# Load data. Later cells read the 'title', 'image' and 'label_group' columns.
train_df = pd.read_csv(train_csv)
print(f"Data loaded: {train_df.shape[0]} products, {train_df['label_group'].nunique()} groups")

## 1. Create Pairs Dataset

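For training similarity models, we need to create pairs:
- **Positive pairs**: Products from the same group (similar products). Every within-group combination is used.
- **Negative pairs**: Products from different groups (dissimilar products), sampled at random until the requested positive ratio is reached.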

In [None]:
def create_pairs_dataset(df, positive_ratio=0.5, seed=42):
    """
    Create positive and negative pairs for training.

    Positive pairs are *all* within-group combinations, so their number grows
    quadratically with group size. Negative pairs are drawn uniformly at random
    (repeats possible) from products in different groups until the requested
    class ratio is reached.

    Args:
        df: DataFrame with products; must have a 'label_group' column. Returned
            indices are row positions, i.e. positions in ``df.reset_index(drop=True)``.
        positive_ratio: target fraction of positive pairs, in (0, 1].
            0.5 yields as many negatives as positives; 1.0 yields no negatives.
        seed: seed for a local random generator. The global NumPy RNG state
            is left untouched.

    Returns:
        List of (idx1, idx2, label) tuples of Python ints, in shuffled order.
        label=1 means same group, 0 means different groups.

    Raises:
        ValueError: if positive_ratio is outside (0, 1], or if negative pairs
            are requested but all products belong to a single group (sampling
            would otherwise never terminate).
    """
    from itertools import combinations

    if not 0 < positive_ratio <= 1:
        raise ValueError(f"positive_ratio must be in (0, 1], got {positive_ratio}")

    rng = np.random.default_rng(seed)
    groups = df['label_group'].to_numpy()

    # Map each group to the row positions of its products.
    group_to_indices = {}
    for idx, group_id in enumerate(groups):
        group_to_indices.setdefault(group_id, []).append(idx)

    # Positive pairs: every unordered pair of products inside a group.
    positive_pairs = [
        (idx1, idx2, 1)
        for indices in group_to_indices.values()
        for idx1, idx2 in combinations(indices, 2)
    ]

    num_negative = int(len(positive_pairs) / positive_ratio) - len(positive_pairs)
    if num_negative > 0 and len(group_to_indices) < 2:
        raise ValueError("Cannot sample negative pairs: all products share one label_group.")

    # Negative pairs: sample candidate pairs in vectorised batches and keep
    # only those spanning two different groups. This also implies idx1 != idx2.
    # Each batch has exactly the number still needed, so it never overshoots.
    negative_pairs = []
    num_products = len(groups)
    while len(negative_pairs) < num_negative:
        needed = num_negative - len(negative_pairs)
        candidates = rng.integers(0, num_products, size=(needed, 2))
        keep = groups[candidates[:, 0]] != groups[candidates[:, 1]]
        negative_pairs.extend((int(a), int(b), 0) for a, b in candidates[keep])

    # Combine and shuffle
    pairs = positive_pairs + negative_pairs
    order = rng.permutation(len(pairs))
    return [pairs[i] for i in order]

print("Creating pairs dataset...")
pairs = create_pairs_dataset(train_df, positive_ratio=0.5)
if not pairs:
    # The negative count is derived from the positive count, so an empty result
    # means no group has two or more products.
    raise ValueError("No pairs created: every label_group has fewer than 2 products.")

pair_labels = [label for _, _, label in pairs]
positive_count = sum(pair_labels)
negative_count = len(pairs) - positive_count

print(f"\nPairs Dataset Created:")
print(f"  Total pairs: {len(pairs):,}")
print(f"  Positive pairs (same group): {positive_count:,} ({positive_count/len(pairs)*100:.1f}%)")
print(f"  Negative pairs (different group): {negative_count:,} ({negative_count/len(pairs)*100:.1f}%)")

# Split into train/val, stratified on the pair label so that both splits keep
# the same positive/negative ratio.
# NOTE: the split is over pairs, not products or groups. The same product (and
# label_group) can therefore appear in both splits, which makes validation
# metrics optimistic. A group-aware split would give a more honest estimate.
train_pairs, val_pairs = train_test_split(
    pairs, test_size=0.2, random_state=42, stratify=pair_labels
)
print(f"\nTrain/Val Split:")
print(f"  Train pairs: {len(train_pairs):,}")
print(f"  Val pairs: {len(val_pairs):,}")

## 2. Dataset Class & DataLoaders

In [None]:
class TextImagePairDataset(Dataset):
    """
    Dataset for product pairs with both text and image data.

    Each item is a dict with:
      - 'label': float32 scalar tensor (1.0 = same group, 0.0 = different group).
      - 'text1', 'text2' (when use_text): float32 tensor
        [character count, word count] of each product's title.
      - 'image1', 'image2' (when use_image): each product's RGB image passed
        through ``transform``. If the file cannot be read, a zero tensor of shape
        (3, fallback_image_size, fallback_image_size) is used instead.

    Args:
        df: product table with 'title' and 'image' columns. Pair indices refer to
            row positions of ``df.reset_index(drop=True)``.
        pairs: sequence of (idx1, idx2, label) tuples.
        images_dir: directory holding the image files (str or Path).
        transform: callable applied to each RGB PIL image. Without one, raw PIL
            images are returned, which the default DataLoader collate cannot batch.
        use_text: include the 'text1'/'text2' entries.
        use_image: include the 'image1'/'image2' entries.

    Attributes:
        failed_images: paths of images that could not be opened and were
            replaced by zero tensors. With ``num_workers > 0`` each worker
            process fills its own copy, so this set is only complete for
            ``num_workers=0``.
    """
    # Side length of the blank substitute for unreadable images (the notebook's
    # image transform center-crops to 224).
    fallback_image_size = 224

    def __init__(self, df, pairs, images_dir, transform=None, use_text=True, use_image=True):
        self.df = df.reset_index(drop=True)
        self.pairs = pairs
        self.images_dir = Path(images_dir)
        self.transform = transform
        self.use_text = use_text
        self.use_image = use_image
        self.failed_images = set()

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        idx1, idx2, label = self.pairs[idx]

        result = {'label': torch.tensor(label, dtype=torch.float32)}

        if self.use_text:
            result['text1'] = self._text_features(self.df.loc[idx1, 'title'])
            result['text2'] = self._text_features(self.df.loc[idx2, 'title'])

        if self.use_image:
            result['image1'] = self._load_image(idx1)
            result['image2'] = self._load_image(idx2)

        return result

    @staticmethod
    def _text_features(title):
        """Return float32 [character count, word count]; non-string (missing) titles count as empty."""
        if not isinstance(title, str):
            title = ''
        return torch.tensor([len(title), len(title.split())], dtype=torch.float32)

    def _load_image(self, row_idx):
        """Open, RGB-convert and transform one product image, or return zeros if unreadable."""
        img_path = self.images_dir / self.df.loc[row_idx, 'image']
        try:
            # The context manager closes the file handle; convert() returns an
            # independent, fully loaded image.
            with Image.open(img_path) as raw:
                img = raw.convert('RGB')
        except OSError:
            # Missing, truncated or unrecognised files raise OSError (or a
            # subclass). Substitute a blank image so one bad file does not abort
            # an epoch, and record the path instead of failing silently.
            # Errors raised by `transform` are deliberately not caught.
            self.failed_images.add(str(img_path))
            return torch.zeros(3, self.fallback_image_size, self.fallback_image_size)
        return self.transform(img) if self.transform else img

# Image preprocessing: scale the shorter side to 224 px, take the central
# 224x224 crop, convert to a tensor and normalise with ImageNet channel stats.
image_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Both splits read from the same product table and image folder; only the
# pair lists differ.
train_dataset = TextImagePairDataset(
    df=train_df,
    pairs=train_pairs,
    images_dir=train_images_dir,
    transform=image_transform,
    use_text=True,
    use_image=True,
)
val_dataset = TextImagePairDataset(
    df=train_df,
    pairs=val_pairs,
    images_dir=train_images_dir,
    transform=image_transform,
    use_text=True,
    use_image=True,
)

print(f"Train Dataset: {len(train_dataset)} pairs")
print(f"Val Dataset: {len(val_dataset)} pairs")

# Only the training loader shuffles; validation order stays fixed.
batch_size = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=0)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=0)

print(f"\nDataLoaders created with batch_size={batch_size}")

## 3. Model Architecture: Siamese Network

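A Siamese network learns to compare two inputs and output a similarity score. Here both products pass through the same (shared-weight) encoder: a pre-trained ResNet50 for the image plus a small MLP for the text features. An MLP head on the two concatenated embeddings then outputs the probability that the products belong to the same group.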

In [None]:
class SiameseNetwork(nn.Module):
    """
    Siamese Network for product similarity.

    Both products are encoded by the *same* modules (shared weights):
      - a ResNet50 trunk plus a linear projection for the image;
      - a small MLP for the per-product text feature vector.
    The two products' fused embeddings are concatenated in order, and an MLP
    head ending in a Sigmoid outputs the probability that they belong to the
    same group. Because of the ordered concatenation, f(a, b) is not
    guaranteed to equal f(b, a).

    Args:
        image_embedding_dim: size of the projected image embedding.
        text_embedding_dim: size of the text embedding.
        fusion_dim: width of the first hidden layer of the fusion head.
        text_input_dim: length of each product's text feature vector. This is
            2 for the [char_count, word_count] features produced by
            TextImagePairDataset.
        pretrained: load ImageNet weights for the ResNet50 trunk. The weights
            must be downloadable or already cached.
    """
    def __init__(self, image_embedding_dim=512, text_embedding_dim=32, fusion_dim=256,
                 text_input_dim=2, pretrained=True):
        super(SiameseNetwork, self).__init__()

        # Image encoder: ResNet50 without its final classification layer.
        resnet = self._build_resnet50(pretrained)
        self.image_encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.image_fc = nn.Linear(resnet.fc.in_features, image_embedding_dim)  # 2048 for ResNet50

        # Text encoder
        self.text_encoder = nn.Sequential(
            nn.Linear(text_input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, text_embedding_dim)
        )

        # Fusion head over both products' concatenated embeddings (hence * 2)
        total_embedding_dim = image_embedding_dim + text_embedding_dim
        self.fusion = nn.Sequential(
            nn.Linear(total_embedding_dim * 2, fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid()  # Output probability
        )

    @staticmethod
    def _build_resnet50(pretrained):
        """
        Build torchvision's ResNet50, optionally with ImageNet weights.

        Uses the ``weights=`` API when torchvision provides it, since the
        ``pretrained`` flag is deprecated there. Falls back to ``pretrained=``
        on older versions. IMAGENET1K_V1 is the checkpoint that the legacy
        ``pretrained=True`` loaded.
        """
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is None:
            return models.resnet50(pretrained=pretrained)
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def encode_pair(self, images, texts):
        """
        Encode one product (image batch + text-feature batch) into a fused embedding.

        Returns a (batch, image_embedding_dim + text_embedding_dim) tensor.
        """
        # Image encoding, flattened to (batch, features)
        img_features = self.image_encoder(images)
        img_features = img_features.view(img_features.size(0), -1)
        img_embedding = self.image_fc(img_features)

        # Text encoding
        text_embedding = self.text_encoder(texts)

        # Concatenate
        embedding = torch.cat([img_embedding, text_embedding], dim=1)
        return embedding

    def forward(self, image1, text1, image2, text2):
        """
        Score two batches of products; returns a (batch,) tensor of probabilities.
        """
        embedding1 = self.encode_pair(image1, text1)
        embedding2 = self.encode_pair(image2, text2)

        combined = torch.cat([embedding1, embedding2], dim=1)
        similarity = self.fusion(combined)  # (batch, 1)

        # Drop only the trailing singleton dim. A bare .squeeze() would also
        # drop the batch dim when batch_size == 1, and BCELoss rejects that
        # 0-d output against (1,)-shaped labels.
        return similarity.squeeze(-1)

# Build the Siamese model (shared-weight encoder for both products) and move
# it to the selected device.
model = SiameseNetwork(image_embedding_dim=512, text_embedding_dim=32, fusion_dim=256)
model = model.to(device)

print("Model Architecture:")
print(model)
print(f"\nModel initialized on device: {device}")

# Count parameters: every tensor, and separately those that receive gradients
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 4. Training Loop

In [None]:
# BCELoss pairs with the model's final Sigmoid: it expects probabilities, not logits.
criterion = nn.BCELoss()
# Adam with light L2 weight decay; StepLR halves the learning rate every 5 epochs.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

def train_epoch(model, train_loader, criterion, optimizer, device, log_every=50, max_grad_norm=1.0):
    """
    Train ``model`` for one pass over ``train_loader``.

    Each batch must be a dict with 'image1', 'image2', 'text1', 'text2' and
    'label' tensors. The model is called as
    ``model(image1, text1, image2, text2)`` and its output is compared to the
    labels with ``criterion``.

    Args:
        model: the pair-scoring model.
        train_loader: iterable of batch dicts.
        criterion: loss taking (outputs, labels).
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        log_every: print the current batch loss every this many batches
            (0 disables logging).
        max_grad_norm: threshold for gradient-norm clipping.

    Returns:
        (avg_loss, auc): the mean per-batch loss, and the ROC-AUC of this
        epoch's predictions. The AUC is NaN when only one class was seen,
        since it is undefined then.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    predictions = []
    targets = []

    for batch_idx, batch in enumerate(train_loader):
        # Move data to device
        image1 = batch['image1'].to(device)
        image2 = batch['image2'].to(device)
        text1 = batch['text1'].to(device)
        text2 = batch['text2'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(image1, text1, image2, text2)
        loss = criterion(outputs, labels)

        # Backward pass with gradient clipping for stability
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Track metrics. Flattening lets both 0-d and (batch,) outputs extend the lists.
        total_loss += loss.item()
        num_batches += 1
        predictions.extend(outputs.detach().cpu().reshape(-1).tolist())
        targets.extend(labels.detach().cpu().reshape(-1).tolist())

        if log_every and (batch_idx + 1) % log_every == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = total_loss / num_batches
    auc = roc_auc_score(targets, predictions) if len(set(targets)) > 1 else float('nan')

    return avg_loss, auc

def validate(model, val_loader, criterion, device):
    """
    Evaluate ``model`` on ``val_loader`` without updating it.

    Batches must have the same dict layout as in ``train_epoch``. The model
    is put in eval mode and run under ``torch.no_grad()``.

    Returns:
        (avg_loss, auc): the mean per-batch loss, and the ROC-AUC of the
        predictions. The AUC is NaN when only one class is present, since it
        is undefined then.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    predictions = []
    targets = []

    with torch.no_grad():
        for batch in val_loader:
            image1 = batch['image1'].to(device)
            image2 = batch['image2'].to(device)
            text1 = batch['text1'].to(device)
            text2 = batch['text2'].to(device)
            labels = batch['label'].to(device)

            outputs = model(image1, text1, image2, text2)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            num_batches += 1
            # Flattening lets both 0-d and (batch,) outputs extend the lists.
            predictions.extend(outputs.cpu().reshape(-1).tolist())
            targets.extend(labels.cpu().reshape(-1).tolist())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    avg_loss = total_loss / num_batches
    auc = roc_auc_score(targets, predictions) if len(set(targets)) > 1 else float('nan')

    return avg_loss, auc

print("Training setup complete.", "Ready to train!")  # print joins the parts with a single space

## 5. Training Execution

**Note**: Full training may take 10-30 minutes depending on your hardware.

In [None]:
# Training parameters
epochs = 10
best_val_auc = 0.0
patience = 3          # epochs without a val-AUC improvement before early stopping
patience_counter = 0

# Track metrics (one entry per completed epoch)
train_losses = []
val_losses = []
train_aucs = []
val_aucs = []

print(f"\nStarting training for {epochs} epochs...\n")

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Train
    train_loss, train_auc = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    train_aucs.append(train_auc)

    # Validate
    val_loss, val_auc = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_aucs.append(val_auc)

    print(f"\nTraining Loss: {train_loss:.4f}, AUC: {train_auc:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, AUC: {val_auc:.4f}")

    # Early stopping on validation AUC; keep the best weights on disk
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_model.pt')
        print(f"✓ Best model saved! (AUC: {val_auc:.4f})")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{patience}")

    scheduler.step()

    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after epoch {epoch+1}")
        break

print(f"\n{'='*60}")
print("Training Complete!")
print(f"Best Validation AUC: {best_val_auc:.4f}")
print(f"{'='*60}")

## 6. Training History Visualization

In [None]:
# Plot training history against 1-based epoch numbers, matching the
# "Epoch k/N" lines printed during training.
epoch_numbers = range(1, len(train_losses) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, train curve, val curve, legend metric name, y-label, title)
history_panels = [
    (axes[0], train_losses, val_losses, 'Loss', 'Loss', 'Training & Validation Loss'),
    (axes[1], train_aucs, val_aucs, 'AUC', 'AUC Score', 'Training & Validation AUC'),
]
for ax, train_curve, val_curve, metric, ylabel, title in history_panels:
    ax.plot(epoch_numbers, train_curve, label=f'Train {metric}', marker='o', linewidth=2)
    ax.plot(epoch_numbers, val_curve, label=f'Val {metric}', marker='s', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTraining Summary:")
print(f"  Final Train Loss: {train_losses[-1]:.4f}")
print(f"  Final Val Loss: {val_losses[-1]:.4f}")
print(f"  Final Train AUC: {train_aucs[-1]:.4f}")
print(f"  Final Val AUC: {val_aucs[-1]:.4f}")

## 7. Model Evaluation & Prediction

In [None]:
# Load the best checkpoint written during training. map_location places the
# tensors on the current device regardless of where they were saved.
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval()

# Score at or above which a pair is predicted to be the same product
DECISION_THRESHOLD = 0.5

# Get predictions on validation set
all_predictions = []
all_targets = []

with torch.no_grad():
    for batch in val_loader:
        image1 = batch['image1'].to(device)
        image2 = batch['image2'].to(device)
        text1 = batch['text1'].to(device)
        text2 = batch['text2'].to(device)

        outputs = model(image1, text1, image2, text2)
        # Flattening lets a 0-d output (batch of one) still extend the list.
        all_predictions.extend(outputs.cpu().numpy().reshape(-1))
        all_targets.extend(batch['label'].numpy().reshape(-1))

all_predictions = np.array(all_predictions)
# Labels arrive as float 0.0/1.0; store them as ints for the classification metrics.
all_targets = np.array(all_targets).astype(int)

binary_predictions = (all_predictions >= DECISION_THRESHOLD).astype(int)

# Calculate metrics once and reuse them for printing
precision = precision_score(all_targets, binary_predictions)
recall = recall_score(all_targets, binary_predictions)
f1 = f1_score(all_targets, binary_predictions)
roc_auc = roc_auc_score(all_targets, all_predictions)

print(f"\nEvaluation Metrics on Validation Set:")
print(f"{'='*60}")
print(f"\nPrecision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"ROC-AUC: {roc_auc:.4f}")

print(f"\nClassification Report:")
print(classification_report(all_targets, binary_predictions, labels=[0, 1],
                            target_names=['Different Product', 'Same Product']))

# Confusion matrix. labels=[0, 1] keeps it 2x2 even if a class is never
# predicted or never present, so the indexing below cannot fail.
cm = confusion_matrix(all_targets, binary_predictions, labels=[0, 1])
print(f"\nConfusion Matrix:")
print(f"  TN: {cm[0,0]}, FP: {cm[0,1]}")
print(f"  FN: {cm[1,0]}, TP: {cm[1,1]}")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confusion Matrix
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=['Different', 'Same'],
            yticklabels=['Different', 'Same'])
axes[0].set_ylabel('True Label')
axes[0].set_xlabel('Predicted Label')
axes[0].set_title('Confusion Matrix', fontsize=12, fontweight='bold')

# Prediction distribution per true class
axes[1].hist(all_predictions[all_targets == 0], bins=30, alpha=0.6, label='Different (Label=0)', color='red')
axes[1].hist(all_predictions[all_targets == 1], bins=30, alpha=0.6, label='Same (Label=1)', color='green')
axes[1].axvline(DECISION_THRESHOLD, color='black', linestyle='--', label='Decision Threshold')
axes[1].set_xlabel('Predicted Similarity Score')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Prediction Score Distribution', fontsize=12, fontweight='bold')
axes[1].legend()

plt.tight_layout()
plt.show()

## 8. Save Model for Inference

In [None]:
# Save model checkpoint. `model` holds whatever weights are currently loaded,
# which is the best checkpoint if the evaluation cell ran.
model_path = 'shopee_siamese_model.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_architecture': 'SiameseNetwork',
    # Read the sizes back from the model so the metadata cannot drift from the
    # architecture that produced the weights.
    'hyperparameters': {
        'image_embedding_dim': model.image_fc.out_features,
        'text_embedding_dim': model.text_encoder[-1].out_features,
        'fusion_dim': model.fusion[0].out_features,
    },
    'training_metrics': {
        'best_val_auc': best_val_auc,
        # NOTE: these are last-epoch values, which may differ from the best epoch.
        'final_train_loss': train_losses[-1],
        'final_val_loss': val_losses[-1]
    }
}, model_path)

print(f"✓ Model saved to {model_path}")
print("\nModel can be loaded with:")
print(f"  checkpoint = torch.load('{model_path}')")
print("  model.load_state_dict(checkpoint['model_state_dict'])")

## 9. Summary & Next Steps

In [None]:
# Final summary. The emoji were previously mis-encoded (UTF-8 read as cp1252).
print("\n" + "="*70)
print("MODEL TRAINING COMPLETE - SUMMARY")
print("="*70)

print("\n📊 ARCHITECTURE")
print("  - Base: Siamese Network with multi-modal fusion")
print("  - Image encoder: ResNet50 (pre-trained)")
print("  - Text encoder: 2-layer MLP")
print(f"  - Total parameters: {total_params:,}")

print("\n📈 TRAINING RESULTS")
print(f"  - Best Validation AUC: {best_val_auc:.4f}")
print(f"  - Precision: {precision_score(all_targets, binary_predictions):.4f}")
print(f"  - Recall: {recall_score(all_targets, binary_predictions):.4f}")
print(f"  - F1-Score: {f1_score(all_targets, binary_predictions):.4f}")

print("\n💾 MODEL SAVED")
print(f"  - Location: {model_path}")
print("  - Use for inference on new product pairs")

print("\n🚀 NEXT STEPS")
print("  1. Deploy model for inference")
print("  2. Fine-tune with harder negative mining")
print("  3. Experiment with different architectures (Transformer, CLIP)")
print("  4. Test on competition test set")
print("  5. Ensemble with multiple models for better performance")

print("\n" + "="*70)