# Protein FT-Transformer: Model Training Notebook

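This notebook walks through the full workflow for the Protein FT-Transformer multiclass classifier: loading and preprocessing the data, building the data loaders, creating and training the model, evaluating it on the test set, inspecting feature importance, and saving the trained model.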

In [None]:
# Install required packages
!pip install -r ../requirements.txt

In [None]:
import os
import random
import sys
sys.path.append('..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from models.ft_transformer import ProteinFTTransformer
from training.dataloader import DataProcessor
from training.trainer import Trainer
from evaluation.metrics import MetricsCalculator
from evaluation.visualization import plot_training_history, plot_confusion_matrix

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU and every CUDA device), then configures cuDNN to prefer
    deterministic kernels over auto-tuned ones.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Python's own RNG is separate from NumPy's and must be seeded as well.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers all CUDA devices, including the current one.
    torch.cuda.manual_seed_all(seed)
    # Trade some speed for run-to-run reproducibility of cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
# Configuration
# Single nested settings object shared by the data processor, model
# construction and trainer cells below.
config = dict(
    # Input files and the split / column settings used when loading data.
    data=dict(
        train_path='../data/train.csv',
        test_path='../data/test.csv',
        validation_split=0.1,
        target_column='target',
        # Both column lists are left empty here.
        # NOTE(review): presumably the data processor infers them - confirm.
        categorical_columns=[],
        numeric_columns=[],
    ),
    # Hyperparameters forwarded to ProteinFTTransformer.
    model=dict(
        d_token=128,
        n_head=8,
        n_layers=6,
        dropout=0.2,
        ff_dim_factor=4,
    ),
    # Optimisation settings; batch_size is also used to build the loaders.
    training=dict(
        batch_size=64,
        num_epochs=50,
        learning_rate=0.001,
        weight_decay=0.0001,
        gradient_clip=1.0,
        early_stopping_patience=10,
        early_stopping_min_delta=0.001,
        save_best_only=True,
        use_mixed_precision=True,
    ),
    # Loss-function settings.
    loss=dict(
        focal_gamma=2.0,
        label_smoothing=0.1,
        use_class_weights=True,
    ),
    # Directories for artefacts written during and after training.
    output=dict(
        model_save_dir='../saved_models',
        results_save_dir='../results',
        tensorboard_log_dir='../logs/tensorboard',
    ),
    seed=42,
)

## 1. Load and Preprocess Data

In [None]:
# Load data
# Build the processor from the shared config, then read the train/test files.
data_processor = DataProcessor(config)
train_path = config['data']['train_path']
test_path = config['data']['test_path']
data_dict = data_processor.load_data(train_path, test_path)

# Quick summary of what was loaded; a missing test set is reported as 0 samples.
n_train = len(data_dict['X_train'])
print(f"Training samples: {n_train}")
n_test = 0 if data_dict['X_test'] is None else len(data_dict['X_test'])
print(f"Test samples: {n_test}")
n_features = data_dict['X_train'].shape[1]
print(f"Number of features: {n_features}")
n_classes = len(data_dict['class_names'])
print(f"Number of classes: {n_classes}")
print(f"Classes: {data_dict['class_names']}")

## 2. Create Data Loaders

In [None]:
# Create datasets and loaders
def _num_batches(loader):
    """Return the number of batches in ``loader``, or 0 when it is ``None``."""
    return len(loader) if loader is not None else 0


# Split the loaded data into train / validation / test datasets.
train_dataset, val_dataset, test_dataset = data_processor.create_datasets(
    data_dict,
    validation_split=config['data']['validation_split']
)

# Wrap each dataset in a loader using the configured batch size.
train_loader, val_loader, test_loader = data_processor.create_dataloaders(
    train_dataset, val_dataset, test_dataset,
    batch_size=config['training']['batch_size']
)

# The loading cell allows for a missing test set (X_test may be None), so a
# loader may be absent too; report 0 batches instead of failing on len(None).
# NOTE(review): confirm whether create_dataloaders can return None loaders.
print(f"Training batches: {_num_batches(train_loader)}")
print(f"Validation batches: {_num_batches(val_loader)}")
print(f"Test batches: {_num_batches(test_loader)}")

## 3. Create Model

In [None]:
# Create model
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Model dimensions are derived from the loaded data.
num_classes = len(data_dict['class_names'])
input_dim = data_dict['X_train'].shape[1]

# Architecture hyperparameters come straight from the config's model section.
model_hparams = {
    name: config['model'][name]
    for name in ('d_token', 'n_head', 'n_layers', 'dropout', 'ff_dim_factor')
}
model = ProteinFTTransformer(
    input_dim=input_dim,
    num_classes=num_classes,
    **model_hparams
).to(device)

# Model summary: count all parameters and the subset that will be trained.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_count = param.numel()
    total_params += param_count
    if param.requires_grad:
        trainable_params += param_count

print("\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 4. Train Model

In [None]:
# Train model
# The trainer is given the model, the train/validation loaders, the full
# config dict and the target device.
# NOTE(review): presumably it reads config['training'] and config['loss'] -
# confirm in training/trainer.py.
trainer = Trainer(model, train_loader, val_loader, config, device)
# Runs training and returns the history object consumed by the plot below.
history = trainer.train()

# Plot training history
plot_training_history(history)

## 5. Evaluate Model

In [None]:
# Evaluate on test set
from evaluation.metrics import evaluate_model

# Run the trained model over the test loader.
eval_outputs = evaluate_model(model, test_loader, device)
predictions, probabilities, targets = eval_outputs

# Calculate metrics
metrics_calculator = MetricsCalculator(num_classes)
metrics = metrics_calculator.calculate_all_metrics(targets, predictions, probabilities)

# Report the headline metrics, one per line, to four decimal places.
print("\nTest Set Results:")
for metric_label, metric_key in (
    ("Accuracy", 'accuracy'),
    ("F1-Score (Macro)", 'f1_macro'),
    ("F1-Score (Weighted)", 'f1_weighted'),
):
    print(f"  {metric_label}: {metrics[metric_key]:.4f}")

# Plot confusion matrix (skipped when there are more than 20 classes)
max_confusion_matrix_classes = 20
if num_classes <= max_confusion_matrix_classes:
    plot_confusion_matrix(targets, predictions, data_dict['class_names'])

## 6. Feature Importance

In [None]:
# Get feature importance
# Only runs when the model exposes a `get_feature_importance` method.
if hasattr(model, 'get_feature_importance'):
    # Get a batch of data: the first 10 rows of the first test batch.
    # NOTE(review): assumes element 0 of a batch is the feature tensor - confirm.
    sample_batch = next(iter(test_loader))[0][:10].to(device)
    importances = model.get_feature_importance(sample_batch)

    # Average importance across samples. `detach()` is required because
    # `.numpy()` raises on a tensor that still requires grad.
    avg_importance = importances.mean(dim=0).detach().cpu().numpy()

    # Get top features. Clamp to the number of available features so the
    # heading, bar positions and tick labels all have the same length.
    top_n = min(20, len(avg_importance))
    top_indices = np.argsort(avg_importance)[-top_n:][::-1]

    # Explicit None/length check: plain truthiness raises for NumPy arrays
    # and pandas Index objects.
    feature_names = data_dict['feature_names']
    has_feature_names = feature_names is not None and len(feature_names) > 0

    print(f"\nTop {top_n} most important features:")
    for i, idx in enumerate(top_indices):
        feature_name = feature_names[idx] if has_feature_names else f"Feature_{idx}"
        print(f"  {i+1:2d}. {feature_name}: {avg_importance[idx]:.4f}")

    # Plot feature importance (short "F<idx>" labels when names are missing)
    tick_labels = [
        feature_names[idx] if has_feature_names else f"F{idx}"
        for idx in top_indices
    ]
    plt.figure(figsize=(12, 6))
    plt.bar(range(top_n), avg_importance[top_indices])
    plt.xlabel('Feature Rank')
    plt.ylabel('Importance Score')
    plt.title(f'Top {top_n} Most Important Features')
    plt.xticks(range(top_n), tick_labels, rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

## 7. Save Model

In [None]:
# Save model
# Take the output directory from the config instead of repeating the path,
# and create it first so saving does not depend on the directory already
# existing.
model_save_dir = config['output']['model_save_dir']
os.makedirs(model_save_dir, exist_ok=True)
model_save_path = os.path.join(model_save_dir, 'protein_ft_transformer.pth')
model.save(model_save_path)
print(f"Model saved to: {model_save_path}")