In [1]:
### SETUP ###
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import pickle
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')
from SCRIPTS.config import *
from SCRIPTS.dataprep import prepare_interval_data
from SCRIPTS.curvature_models import CurvatureModel
from SCRIPTS.curvature_training import train_curvature_model, evaluate_model, IntervalDataset
from SCRIPTS.cross_validation_experiments import run_curvature_cross_validation

# --- Plot appearance (style sheet first, then the seaborn colour palette) ---
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# --- Reproducibility: seed NumPy and PyTorch with the same fixed value ---
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(42)

# --- Compute device: CUDA when torch reports it as available, CPU otherwise ---
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- Echo the run context so the notebook output records where data/results live ---
print(f"Using device: {device}")
print(f"Data path: {COMBINED_SCATTERING}")
print(f"Results path: {CURVATURE_RESULTS_DIR}")

Using device: cpu
Data path: /Users/judesack/Neurospectrum_Creativity/DATA/SCATTERING_COEFFICIENTS/combined_scattering_data.csv
Results path: /Users/judesack/Neurospectrum_Creativity/RESULTS/cross_validation_results/curvature_results


In [2]:
### CONFIGURATION ###

# --- Run switches -----------------------------------------------------------
# True  -> the training cell fits and saves curvature models.
# False -> the training cell is skipped.
TRAIN_CURVATURE = False

# True  -> the cross-validation cell runs fresh experiments and saves them.
# False -> the cross-validation cell loads previously saved results instead.
RUN_CROSS_VALIDATION = True

# --- Experiment settings ----------------------------------------------------
# Latent-space sizes to evaluate (one set of experiments per entry).
LATENT_DIMS = [8, 48]

# Number of trials per cross-validation experiment.
CV_TRIALS = 5

# Weight of the reconstruction term in the training loss
# (0.3 means 30% reconstruction, 70% classification).
RECON_WEIGHT = 0.3

In [4]:
### LOAD INTERVAL DATA ###

print("\n=== LOADING INTERVAL DATA ===")

# Keyword arguments shared by both splitting strategies; only `split_type` differs.
_interval_kwargs = dict(
    scattering_data_path=COMBINED_SCATTERING,
    batch_size=16,
    random_state=42,
)

# Subject-withheld split (loaded first, so its log lines appear first).
train_loader_subj, test_loader_subj, info_subj = prepare_interval_data(
    split_type='subject', **_interval_kwargs
)

# Time-withheld split.
train_loader_time, test_loader_time, info_time = prepare_interval_data(
    split_type='time', **_interval_kwargs
)

# One-line size summary per split.
print(f"\nSubject split: {info_subj['n_train']} train, {info_subj['n_test']} test")
print(f"Time split: {info_time['n_train']} train, {info_time['n_test']} test")


=== LOADING INTERVAL DATA ===
Found 17 valid subjects

Total intervals extracted: 306 (expected: 306)

Subject split:
  Train subjects (14): ['15053001sub1', '16100101', '15111101', '15081202sub2', '14101601', '14092201', '14091102', '15053001sub2', '15052902', '14091701', '16101401', '16102002', '15040901', '16100801']
  Test subjects (3): ['15080601', '15012001', '16100601']

Split results:
  Train: 252 intervals
  Test: 54 intervals
Found 17 valid subjects

Total intervals extracted: 306 (expected: 306)

Split results:
  Train: 238 intervals
  Test: 68 intervals

Subject split: 252 train, 54 test
Time split: 238 train, 68 test


In [None]:
### TRAIN CURVATURE MODELS ###


def _loader_to_intervals(loader):
    """Flatten a DataLoader into a list of per-interval tuples.

    Each batch yielded by ``loader`` is unpacked as
    ``(features, labels, subjects, intervals)``. For every item in the batch
    one tuple ``(features_i as a NumPy array, label_i as a Python scalar,
    subject_i, interval_i)`` is appended, in iteration order.

    Args:
        loader: Iterable of batches with the four-element layout above.

    Returns:
        list: One tuple per item across all batches.
    """
    flattened = []
    for features, labels, subjects, intervals in loader:
        for i in range(len(features)):
            flattened.append((
                features[i].numpy(),
                labels[i].item(),
                subjects[i],
                intervals[i],
            ))
    return flattened


if TRAIN_CURVATURE:
    print("\n=== TRAINING CURVATURE MODELS ===")

    # Create the output directory up front so a long training run cannot
    # fail at the very end when the checkpoint is written.
    Path(CURVATURE_RESULTS_DIR).mkdir(parents=True, exist_ok=True)

    # Split name -> (train loader, test loader). Dict order fixes the run
    # order: subject-withheld first, then time-withheld.
    split_loaders = {
        'subject': (train_loader_subj, test_loader_subj),
        'time': (train_loader_time, test_loader_time),
    }

    # results[latent_dim][split_type] -> {'accuracy', 'history', 'model_state'}
    results = {}

    for latent_dim in LATENT_DIMS:
        print(f"\n--- Latent Dimension: {latent_dim} ---")
        results[latent_dim] = {}

        for split_type, (train_loader, test_loader) in split_loaders.items():
            print(f"\n{split_type.upper()} WITHHOLDING:")

            # Convert DataLoader to intervals for curvature training.
            # Deliberately re-done for every latent dimension (not hoisted):
            # this keeps the same sequence of loader iterations as before,
            # which matters if the loaders shuffle — TODO confirm.
            train_intervals = _loader_to_intervals(train_loader)
            test_intervals = _loader_to_intervals(test_loader)

            model, history = train_curvature_model(
                train_intervals=train_intervals,
                test_intervals=test_intervals,
                latent_dim=latent_dim,
                num_epochs=DEFAULT_EPOCHS,
                batch_size=16,
                learning_rate=DEFAULT_LEARNING_RATE,
                recon_weight=RECON_WEIGHT,
                device=device
            )

            # Reported accuracy is the last entry of the test-accuracy history.
            results[latent_dim][split_type] = {
                'accuracy': history['test_acc'][-1],
                'history': history,
                'model_state': model.state_dict()
            }

            # Persist weights plus enough metadata to identify the run.
            save_path = CURVATURE_RESULTS_DIR / f'curvature_{split_type}_{latent_dim}d.pth'
            torch.save({
                'model_state_dict': model.state_dict(),
                'history': history,
                'latent_dim': latent_dim,
                'split_type': split_type
            }, save_path)
            print(f"✓ Model saved to {save_path}")

    print("\n=== CURVATURE MODEL RESULTS ===")
    for latent_dim in LATENT_DIMS:
        print(f"\n{latent_dim}D Latent Space:")
        print(f"  Subject: {results[latent_dim]['subject']['accuracy']:.1f}%")
        print(f"  Time: {results[latent_dim]['time']['accuracy']:.1f}%")

else:
    print("✓ Skipping curvature model training")

In [5]:
### CROSS-VALIDATION ###


def _print_cv_summary(results_by_dim, latent_dims):
    """Print mean ± std accuracy for the subject and time splits.

    Args:
        results_by_dim: Mapping ``latent_dim -> {'subject': stats, 'time': stats}``
            where each ``stats`` has ``'mean_accuracy'`` and ``'std_accuracy'``.
        latent_dims: Latent dimensions to report, in display order. Dimensions
            missing from ``results_by_dim`` are skipped (saved results may
            have been produced with a different LATENT_DIMS setting).
    """
    print("\n=== CROSS-VALIDATION SUMMARY ===")
    for latent_dim in latent_dims:
        if latent_dim not in results_by_dim:
            continue
        print(f"\n{latent_dim}D Latent Space:")
        for label, split_key in (("Subject", "subject"), ("Time", "time")):
            stats = results_by_dim[latent_dim][split_key]
            print(f"  {label}: {stats['mean_accuracy']:.1f}% ± {stats['std_accuracy']:.1f}%")


if RUN_CROSS_VALIDATION:
    print(f"\n=== CROSS-VALIDATION EXPERIMENTS ({CV_TRIALS} trials) ===")

    # Ensure the results directory exists BEFORE the (long) experiments run,
    # so the final pickle.dump cannot fail on a missing directory.
    Path(CURVATURE_RESULTS_DIR).mkdir(parents=True, exist_ok=True)

    # cv_results[latent_dim][split_type] -> value returned by the CV helper
    cv_results = {}

    for latent_dim in LATENT_DIMS:
        print(f"\n--- LATENT DIMENSION: {latent_dim} ---")
        cv_results[latent_dim] = {}

        for split_type in ['subject', 'time']:
            print(f"\n{split_type.upper()} SPLIT:")
            # NOTE(review): RECON_WEIGHT from the configuration cell is not
            # forwarded here; the CV helper presumably applies its own
            # default — confirm it matches the intended value.
            cv_results[latent_dim][split_type] = run_curvature_cross_validation(
                data_path=COMBINED_SCATTERING,
                split_type=split_type,
                latent_dim=latent_dim,
                num_trials=CV_TRIALS,
                num_epochs=DEFAULT_EPOCHS
            )

    # Save results
    save_path = CURVATURE_RESULTS_DIR / 'curvature_cv_results.pkl'
    with open(save_path, 'wb') as f:
        pickle.dump(cv_results, f)
    print(f"\n✓ Results saved to {save_path}")

    # Display summary
    _print_cv_summary(cv_results, LATENT_DIMS)

else:
    # Load existing results
    cv_file = CURVATURE_RESULTS_DIR / 'curvature_cv_results.pkl'

    if cv_file.exists():
        # SECURITY: pickle.load can execute arbitrary code. Only load result
        # files written by this notebook / project, never untrusted ones.
        with open(cv_file, 'rb') as f:
            cv_results = pickle.load(f)

        print("✓ Loaded existing cross-validation results")

        _print_cv_summary(cv_results, LATENT_DIMS)
    else:
        # Note: cv_results is intentionally left undefined in this case.
        print("✓ No existing cross-validation results found")


=== CROSS-VALIDATION EXPERIMENTS (5 trials) ===

--- LATENT DIMENSION: 8 ---

SUBJECT SPLIT:

=== Curvature Model Trial 1/5 (subject split, 8D) ===
Found 17 valid subjects

Total intervals extracted: 306 (expected: 306)

Subject split:
  Train subjects (14): ['15053001sub1', '16100101', '15111101', '15081202sub2', '14101601', '14092201', '14091102', '15053001sub2', '15052902', '14091701', '16101401', '16102002', '15040901', '16100801']
  Test subjects (3): ['15080601', '15012001', '16100601']

Split results:
  Train: 252 intervals
  Test: 54 intervals
Epoch 10/100:
  Train Loss: 1.0666 (Class: 1.0262, Recon: 0.4046)
  Train Acc: 54.17%, Test Acc: 55.56%
  Curvature: mean=38.0859, std=103.5866
Epoch 20/100:
  Train Loss: 1.0332 (Class: 0.9994, Recon: 0.3380)
  Train Acc: 55.42%, Test Acc: 53.70%
  Curvature: mean=33.6569, std=89.1882
Epoch 30/100:
  Train Loss: 1.0047 (Class: 0.9721, Recon: 0.3256)
  Train Acc: 56.67%, Test Acc: 55.56%
  Curvature: mean=32.2119, std=113.6072
Epoch 40/1