# Variant Classification: Model Comparison

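This notebook compares several classifiers (logistic regression, SVM, random forest and XGBoost) for predicting the `VariantType` label across four differently preprocessed datasets. All random processes (the train/test split and every model) use a fixed random seed for reproducibility.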

In [43]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.multiclass import OneVsRestClassifier

RANDOM_SEED = 42  # Single seed shared by the train/test split and every stochastic model below.

## Data Loading and Preprocessing

Define a single function that loads and preprocesses every dataset the same way: dropping unused columns, keeping only numeric features, splitting into train/test sets, and scaling the features.

In [None]:
def load_and_preprocess(path, index_col=None, drop_cols=None, target_col='VariantType',
                        test_size=0.2, stratify=False):
    """
    Load a CSV file and turn it into scaled train/test splits for classification.

    Steps:
    - Drops ``drop_cols`` (if given) plus 'Unnamed: 0' and 'Mutated' when present
      ('Mutated' is not a target here); absent columns are ignored.
    - Drops rows whose target value is missing.
    - Keeps only numeric feature columns (non-numeric ones such as ModelID are discarded).
    - Splits into train/test with a fixed seed (RANDOM_SEED).
    - Imputes missing feature values with the *training-split* column medians and
      drops features that have no observed value in the training split.
    - Standardizes features with a StandardScaler fitted on the training split only.

    Parameters
    ----------
    path : str or path-like
        CSV file to read.
    index_col : int, str or None
        Forwarded to ``pd.read_csv``.
    drop_cols : list of str or None
        Extra columns to remove before selecting features.
    target_col : str
        Name of the label column (default 'VariantType').
    test_size : float
        Fraction of rows held out for testing (default 0.2).
    stratify : bool
        If True, preserve the class proportions in both splits (default False).

    Returns
    -------
    tuple
        ``(X_train_scaled, X_test_scaled, y_train, y_test, feature_names)`` where the
        X arrays are scaled numpy arrays, the y values are the target Series, and
        ``feature_names`` is a pandas Index aligned with the array columns.

    Raises
    ------
    KeyError
        If ``target_col`` is not present in the file.
    """
    df = pd.read_csv(path, index_col=index_col)
    columns_to_drop = list(drop_cols or []) + ['Unnamed: 0', 'Mutated']
    df = df.drop(columns=columns_to_drop, errors='ignore')

    if target_col not in df.columns:
        raise KeyError(f"Target column {target_col!r} not found in {path}")
    # Rows without a label cannot be used for supervised training or scoring.
    df = df.dropna(subset=[target_col])

    # Separate the target before selecting numeric features so a numeric-encoded
    # target can never end up among the features.
    target = df[target_col]
    features = df.drop(columns=[target_col]).select_dtypes(include=[np.number])

    X_train, X_test, y_train, y_test = train_test_split(
        features,
        target,
        test_size=test_size,
        random_state=RANDOM_SEED,
        stratify=target if stratify else None,
    )

    # Impute using training-split statistics only, to avoid leaking test information.
    # Columns with no observed training value cannot be imputed, so they are dropped.
    observed_columns = X_train.columns[X_train.notna().any()]
    X_train = X_train[observed_columns]
    X_test = X_test[observed_columns]
    train_medians = X_train.median()
    X_train = X_train.fillna(train_medians)
    X_test = X_test.fillna(train_medians)

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    return X_train_scaled, X_test_scaled, y_train, y_test, observed_columns

## Prepare Datasets

Load all datasets and store them in a list for easy iteration. Each entry is a tuple `(X_train, X_test, y_train, y_test, feature_names)` holding the scaled feature splits, their labels, and the feature names.

In [45]:
# (path, index_col) pairs; index_col is forwarded to pd.read_csv.
dataset_paths = [
    ('../data/processed/ccle_quantile_filtered.csv', None),
    ('../data/processed/ccle_tp53_filtered.csv', None),
    ('../data/processed/ccle_variance_filtered.csv', None),
    ('../data/processed/merged_data.csv', 1)
]
# NOTE(review): 'VariantLabel' is presumably an encoded copy of the target; dropping it
# avoids target leakage. Confirm against the step that produced these files.
drop_cols = ['VariantLabel']

# Build the splits in a comprehension so per-dataset names (X_train, y_test, ...) do not
# leak into the notebook namespace, where they would silently refer to the last file only.
datasets = [
    load_and_preprocess(path, index_col=idx_col, drop_cols=drop_cols)
    for path, idx_col in dataset_paths
]

# Sanity check: feature matrices, labels and feature names must line up for every dataset.
assert all(
    X_tr.shape[1] == X_te.shape[1] == len(names)
    and X_tr.shape[0] == len(y_tr)
    and X_te.shape[0] == len(y_te)
    for X_tr, X_te, y_tr, y_te, names in datasets
), "misaligned features/labels/feature names in a preprocessed dataset"

ValueError: columns overlap but no suffix specified: Index(['VariantType'], dtype='object')

## Model Training and Evaluation Function

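This function trains and evaluates one model type on each dataset. For every dataset it prints the accuracy and the weighted-average F1 score, then plots the confusion matrix alongside the ten most important features (when the model exposes feature importances). Supported model names are `'logistic'`, `'svm'`, `'random_forest'`, `'catboost'` and `'xgboost'`.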

In [None]:
# Factories returning a fresh, unfitted estimator per supported model name.
# Nothing is constructed until a model is actually requested.
_MODEL_FACTORIES = {
    # LogisticRegression wrapped in One-vs-Rest: one binary model per class.
    'logistic': lambda: OneVsRestClassifier(
        LogisticRegression(max_iter=1000, solver='lbfgs', random_state=RANDOM_SEED)
    ),
    # probability=True is omitted: predict() does not use it, and enabling it adds an
    # internal 5-fold cross-validation to every fit.
    'svm': lambda: SVC(kernel='rbf', decision_function_shape='ovr', random_state=RANDOM_SEED),
    'random_forest': lambda: RandomForestClassifier(random_state=RANDOM_SEED),
    'catboost': lambda: CatBoostClassifier(
        iterations=1000, learning_rate=0.1, depth=6, random_seed=RANDOM_SEED, verbose=0
    ),
    'xgboost': lambda: XGBClassifier(eval_metric='mlogloss', random_state=RANDOM_SEED),
}


def _fit_predict(model, model_name, X_train, y_train, X_test):
    """Fit ``model`` and return flat test-set predictions in the original label space."""
    if model_name == 'xgboost':
        # XGBClassifier expects class labels 0..k-1 (it rejects e.g. string labels),
        # so encode for fitting and map the predicted codes back to the original labels.
        classes, y_train_codes = np.unique(np.asarray(y_train), return_inverse=True)
        model.fit(X_train, y_train_codes)
        codes = np.asarray(model.predict(X_test)).ravel().astype(int)
        return classes[codes]
    model.fit(X_train, y_train)
    # ravel(): some estimators (e.g. CatBoost) may return a column vector of predictions.
    return np.asarray(model.predict(X_test)).ravel()


def _feature_importances(model, model_name):
    """Return one non-negative importance score per feature, or None if unavailable."""
    if model_name in ('random_forest', 'catboost', 'xgboost'):
        return np.asarray(model.feature_importances_)
    if model_name == 'logistic' and hasattr(model, 'estimators_'):
        # Sum |coef| over the per-class binary estimators fitted by OneVsRestClassifier.
        return np.sum([np.abs(est.coef_).ravel() for est in model.estimators_], axis=0)
    if model_name == 'svm' and hasattr(model, 'coef_'):
        # Only linear-kernel SVC exposes coef_; the RBF model configured here does not.
        return np.abs(model.coef_).sum(axis=0)
    return None


def train_and_evaluate_on_datasets(datasets, model_name):
    """
    Train and evaluate one model type on each dataset, printing metrics and plotting results.

    For every dataset, a fresh estimator is fitted on the training split and scored on the
    test split. Accuracy and weighted-average F1 are printed. One figure row is drawn per
    dataset: the confusion matrix on the left and the top-10 feature importances on the
    right. The right panel is left blank when the model exposes no importances.

    Parameters
    ----------
    datasets : sequence of tuples
        Each item is ``(X_train, X_test, y_train, y_test, feature_names)``, as produced by
        ``load_and_preprocess``.
    model_name : str
        One of 'logistic', 'svm', 'random_forest', 'catboost', 'xgboost'.

    Raises
    ------
    ValueError
        If ``model_name`` is not supported or ``datasets`` is empty.
    """
    if model_name not in _MODEL_FACTORIES:
        raise ValueError(f"Unsupported model: {model_name}")
    if not datasets:
        raise ValueError("datasets is empty; nothing to train or evaluate")

    num_datasets = len(datasets)
    # squeeze=False keeps `axes` 2-D even when there is only one dataset.
    fig, axes = plt.subplots(nrows=num_datasets, ncols=2,
                             figsize=(12, 5 * num_datasets), squeeze=False)

    for idx, (X_train, X_test, y_train, y_test, feature_names) in enumerate(datasets):
        cm_ax, imp_ax = axes[idx]
        # A new estimator per dataset, so no fitted state carries over between datasets.
        model = _MODEL_FACTORIES[model_name]()
        y_pred = _fit_predict(model, model_name, X_train, y_train, X_test)

        acc = accuracy_score(y_test, y_pred)
        f1 = classification_report(y_test, y_pred, output_dict=True, zero_division=0)['weighted avg']['f1-score']
        print(f"Dataset {idx+1} - Accuracy: {acc:.2f} - F1 Score: {f1:.2f} - Model: {model_name}")

        # Explicit (sorted) labels, so the heatmap axes show class names instead of bare indices.
        class_labels = np.union1d(np.asarray(y_test), y_pred)
        cm = confusion_matrix(y_test, y_pred, labels=class_labels)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_labels, yticklabels=class_labels, ax=cm_ax)
        cm_ax.set_title(f'Dataset {idx+1} - Confusion Matrix')
        cm_ax.set_xlabel('Predicted')
        cm_ax.set_ylabel('True')

        importances = _feature_importances(model, model_name)
        if importances is None:
            imp_ax.axis('off')
            continue
        top = np.argsort(importances)[::-1][:10]
        # np.asarray allows feature_names to be a pandas Index or a plain list.
        sns.barplot(x=importances[top], y=np.asarray(feature_names)[top], ax=imp_ax)
        imp_ax.set_title(f'Dataset {idx+1} - Feature Importance')
        imp_ax.set_xlabel('Importance')

    plt.tight_layout()
    plt.show()

## Logistic Regression (One-vs-Rest)

In [None]:
# Evaluate One-vs-Rest logistic regression on every dataset (prints metrics, draws one figure).
train_and_evaluate_on_datasets(datasets, 'logistic')

## Support Vector Machine (SVM)

In [None]:
# Evaluate the RBF-kernel SVM on every dataset (no feature-importance panel: RBF SVC has no coef_).
train_and_evaluate_on_datasets(datasets, 'svm')

## Random Forest

In [None]:
# Evaluate the random forest on every dataset (prints metrics, draws one figure).
train_and_evaluate_on_datasets(datasets, 'random_forest')

## XGBoost

In [None]:
# Evaluate XGBoost on every dataset (prints metrics, draws one figure).
train_and_evaluate_on_datasets(datasets, 'xgboost')