In [20]:
import torch, pickle
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report
from sklearn.preprocessing import label_binarize


In [21]:

def load_dict_from_pickle(filename):
    with open(filename, 'rb') as file:
        loaded_dict = pickle.load(file)
    return loaded_dict

def load_data(file_path):

    Nodes = load_dict_from_pickle(f'{file_path}/Nodes.pkl')
    X = torch.load(f'{file_path}/X.pt')
    Y = torch.load(f'{file_path}/Y.pt')
    print(X.shape)
    print(set(Y))
    return X, Y, Nodes


# 1. Apply SMOTE
def apply_smote(X, Y):
    smote = SMOTE(random_state=42)
    X_resampled, Y_resampled = smote.fit_resample(X, Y)
    return X_resampled, Y_resampled

# 2. Train Random Forest
def train_random_forest(X, Y, num_classes=4, num_estimators=100):
    # Split data into training, validation, and test sets
    X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, train_size=0.8, random_state=42)
    X_val, X_test, Y_val, Y_test = train_test_split(X_temp, Y_temp, train_size=0.5, random_state=42)

    # Apply SMOTE to the training set
    X_train_resampled, Y_train_resampled = apply_smote(X_train, Y_train)

    # Initialize the Random Forest Classifier
    clf = RandomForestClassifier(n_estimators=num_estimators, random_state=42)

    # Train the model
    clf.fit(X_train_resampled, Y_train_resampled)

    # Evaluate on validation set
    Y_val_pred = clf.predict(X_val)
    val_metrics = evaluate_multi_class(Y_val, Y_val_pred, num_classes)

    print("Validation Metrics:")
    print(val_metrics)

    # Evaluate on test set
    Y_test_pred = clf.predict(X_test)
    test_metrics = evaluate_multi_class(Y_test, Y_test_pred, num_classes)

    print("Test Metrics:")
    print(test_metrics)

    return clf, test_metrics

# 3. Evaluation for Multi-Class Classification
def evaluate_multi_class(true_labels, preds, num_classes):
    # Binarize the true labels for multi-class AUC calculation
    true_labels_binarized = label_binarize(true_labels, classes=np.arange(num_classes))

    try:
        # Calculate AUC for each class and then macro-average
        auc = roc_auc_score(true_labels_binarized, label_binarize(preds, classes=np.arange(num_classes)), average='macro', multi_class='ovr')
    except ValueError:
        auc = None

    accuracy = accuracy_score(true_labels, preds)
    precision = precision_score(true_labels, preds, average='weighted')
    recall = recall_score(true_labels, preds, average='weighted')
    f1 = f1_score(true_labels, preds, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'auc': auc
    }

# Main Function
def main_random_forest(file_path, num_estimators=100):
    X, Y, _ = load_data(file_path)

    # Train Random Forest
    clf, metrics = train_random_forest(X, Y, num_classes=3, num_estimators=250)
    
    return metrics

# Set parameters and run the main training process
saving_path = '../../Data/infectious'
results = main_random_forest(saving_path, num_estimators=100)
print(results)


(19442, 1456)
{0, 1, 2}
Validation Metrics:
{'accuracy': 0.7309670781893004, 'precision': 0.7439241903563968, 'recall': 0.7309670781893004, 'f1_score': 0.7274245220617707, 'auc': 0.7280978073280573}
Test Metrics:
{'accuracy': 0.7506426735218509, 'precision': 0.7555844225264811, 'recall': 0.7506426735218509, 'f1_score': 0.7457446482321642, 'auc': 0.7379819539229517}
{'accuracy': 0.7506426735218509, 'precision': 0.7555844225264811, 'recall': 0.7506426735218509, 'f1_score': 0.7457446482321642, 'auc': 0.7379819539229517}
