In [1]:
import sys
# Uncomment if tornet isn't installed in your environment or in your path already
#sys.path.append('../')  

import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import keras

from tornet.data.tf.loader import create_tf_dataset 
from tornet.data.constants import ALL_VARIABLES

In [2]:
# keras accepts most data loaders (tensorflow, torch).
# A pure keras data loader, with the preprocessing steps needed by the CNN baseline, is provided.
from tornet.data.keras.loader import KerasDataLoader

# Root directory of the TorNet data. Set the TORNET_ROOT environment variable to
# point at your copy; the fallback is the original author's local path.
data_root = os.environ.get("TORNET_ROOT", "C:/Users/mjhig/tornet_2013")

# Training split: 2013 samples only.
ds = KerasDataLoader(
    data_root=data_root,
    data_type="train",
    years=[2013],
    batch_size=8,
    workers=4,
    use_multiprocessing=True,
)

In [3]:
# Held-out evaluation split: the 2018 "test" partition under the same data_root,
# batched 8 samples at a time.
ds_test = KerasDataLoader(
    data_root=data_root,
    data_type="test",
    years=[2018],
    batch_size=8,
    workers=4,
    use_multiprocessing=True,
)

In [4]:
# Flatten the radar batches into 2-D feature matrices for scikit-learn classifiers (no train/validation split happens here)
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, accuracy_score
import numpy as np
from tqdm import tqdm

def get_all_batches_for_randomforest(dataset, input_vars, show_progress=True):
    """Materialise a batched dataset as flat ``(X, y)`` arrays for scikit-learn.

    Each batch is indexed as ``batch[0]`` (a mapping from variable name to an
    array whose first axis is the batch axis) and ``batch[1]`` (the labels).
    The selected variables are concatenated along their last axis, and every
    sample is flattened into a single feature row.

    Args:
        dataset: Iterable of batches. ``len(dataset)`` is used as the progress
            bar total when the object supports it.
        input_vars: Keys of ``batch[0]`` to include, in concatenation order.
            The selected arrays must agree on every axis except the last.
        show_progress: Wrap the iteration in a ``tqdm`` progress bar (default
            True, matching the previous behaviour).

    Returns:
        ``(X, y)``: ``X`` has shape ``(n_samples, n_features)`` and ``y`` is a
        1-D array of length ``n_samples``, the shape scikit-learn estimators expect.

    Raises:
        ValueError: if ``dataset`` yields no batches, or if a batch does not
            carry exactly one label per sample.
    """
    feature_blocks, label_blocks = [], []

    batches = dataset
    if show_progress:
        total = len(dataset) if hasattr(dataset, "__len__") else None
        batches = tqdm(dataset, total=total, desc="Processing Batches", unit="batch")

    for batch in batches:
        # Index rather than unpack so batches carrying extra elements
        # (e.g. sample weights) are still accepted.
        inputs, labels = batch[0], batch[1]

        features = np.concatenate([inputs[v] for v in input_vars], axis=-1)
        n_samples = features.shape[0]
        feature_blocks.append(features.reshape(n_samples, -1))  # (batch, n_features)

        # Labels may arrive as (batch,) or (batch, 1). Flatten each batch so that
        # batches of different sizes join along the sample axis. np.hstack on
        # (batch, 1) arrays would join along axis 1 instead.
        flat_labels = np.asarray(labels).reshape(-1)
        if flat_labels.shape[0] != n_samples:
            raise ValueError(
                f"batch has {n_samples} samples but {flat_labels.shape[0]} labels; "
                "expected exactly one label per sample"
            )
        label_blocks.append(flat_labels)

    if not feature_blocks:
        raise ValueError("dataset yielded no batches; nothing to stack")

    X = np.concatenate(feature_blocks, axis=0)
    y = np.concatenate(label_blocks)
    return X, y

# Materialise the training split (ds) and the held-out test split (ds_test) as
# flat (n_samples, n_features) matrices for the scikit-learn models.
# NOTE: in the recorded run the test extraction was much slower per batch than
# the training extraction and was interrupted; consider caching X_test/y_test.
X, y = get_all_batches_for_randomforest(ds, input_vars=ALL_VARIABLES)
X_test, y_test = get_all_batches_for_randomforest(ds_test, input_vars=ALL_VARIABLES)

Processing Batches: 100%|██████████| 438/438 [07:16<00:00,  1.00batch/s]
Processing Batches:   2%|▏         | 7/315 [00:17<12:48,  2.50s/batch]


KeyboardInterrupt: 

In [None]:
# # Train-test split
# from sklearn.model_selection import train_test_split
# # Train RandomForestClassifier
# from sklearn.ensemble import RandomForestClassifier

# rf_model = RandomForestClassifier(
#     n_estimators=100,  # Number of trees
#     max_depth=10,      # Limit depth to prevent overfitting
#     random_state=42,
#     n_jobs=-1
# )
# rf_model.fit(X, y)
# print("RandomForest Model Training Complete!")

In [None]:
# Import necessary libraries
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    precision_recall_curve,
    roc_auc_score,
)

# Candidate models with their hyperparameters.
# The RBF SVM is wrapped in a StandardScaler pipeline because SVMs are not scale
# invariant, and the concatenated radar variables do not share a common scale.
# Tree ensembles such as GradientBoosting are insensitive to feature scaling.
# NOTE(review): SVC and GradientBoostingClassifier reject NaN inputs; if the
# loader leaves NaNs in the radar fields, fill them before fitting - confirm.
models = {
    "SVM": make_pipeline(
        StandardScaler(),
        SVC(kernel="rbf", C=1.0, random_state=42, probability=True),  # Enable probability estimates
    ),
    "GradientBoosting": GradientBoostingClassifier(
        n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42
    ),
}

for name, model in models.items():
    print(f"Training {name}...")
    model.fit(X, y)

    # Probability of the positive class, i.e. model.classes_[1].
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    # Derive hard labels from those same probabilities. With probability=True,
    # SVC.predict() can disagree with predict_proba() because the probability
    # model is calibrated separately by internal cross-validation. A strict ">"
    # mirrors argmax tie-breaking towards classes_[0].
    y_pred = model.classes_[(y_pred_proba > 0.5).astype(int)]

    # Metrics on the held-out test split (X_test, y_test), not a validation split.
    roc_auc = roc_auc_score(y_test, y_pred_proba)
    precision, recall, _ = precision_recall_curve(y_test, y_pred_proba)
    pr_auc = auc(recall, precision)  # trapezoidal area under the PR curve; can be optimistic
    avg_precision = average_precision_score(y_test, y_pred_proba)
    accuracy = accuracy_score(y_test, y_pred)

    print(f"{name} - Test ROC AUC: {roc_auc:.4f}")
    print(f"{name} - Test PR AUC (trapezoidal): {pr_auc:.4f}")
    print(f"{name} - Test Average Precision: {avg_precision:.4f}")
    print(f"{name} - Test Accuracy: {accuracy:.4f}\n")