In [None]:
### SETUP ###

import os
import pickle
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import classification_report

# Make the parent directory importable so the local SCRIPTS package resolves.
sys.path.append('..')
from SCRIPTS.config import *
from SCRIPTS.dataprep import prepare_interval_data
from SCRIPTS.combined_model_v2 import CombinedModelV2
from SCRIPTS.combined_training_v2 import train_combined_model_v2
from SCRIPTS.cross_validation_experiments import run_combined_v2_cross_validation


# Visualization setup: one consistent style and palette for every figure below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Reproducibility: seed every RNG this notebook -- or the SCRIPTS helpers it
# calls -- may draw from. torch.manual_seed seeds the CPU and all CUDA devices.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Device configuration: use the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Data path: {COMBINED_SCATTERING}")
print(f"Results path: {COMBINED_MODEL_RESULTS_DIR}")

In [None]:
### CONFIGURATION ###
# Every knob a reader may want to tune for this notebook lives in this cell.

# Stage toggles: (re)train the single-split models, and/or run cross-validation.
TRAIN_COMBINED = False
RUN_COMBINED_CV = True

# Cross-validation: repeated trials per (split type, latent dim) configuration.
CV_TRIALS = 5

# Experiment grid: latent-space sizes and data-split strategies to evaluate.
LATENT_DIMS = [8, 48]
SPLIT_TYPES = ['subject', 'time']

# Number of epochs for training.
NUM_EPOCHS = 100

In [None]:
def _train_and_save_combined(train_loader, test_loader, latent_dim, split_type):
    """Train one combined-model configuration and checkpoint it to disk.

    Uses the notebook-level NUM_EPOCHS, DEFAULT_LEARNING_RATE and `device`.
    The checkpoint is written to
    COMBINED_MODEL_RESULTS_DIR / f'combined_v2_{split_type}_{latent_dim}d.pth'
    and holds the model state dict, the training history, latent_dim and
    split_type.

    Args:
        train_loader: training loader returned by prepare_interval_data.
        test_loader: evaluation loader returned by prepare_interval_data.
        latent_dim: size of the model's latent space.
        split_type: name of the data split, e.g. 'subject' or 'time'.

    Returns:
        (model, history) exactly as returned by train_combined_model_v2.
    """
    model, history = train_combined_model_v2(
        train_loader=train_loader,
        test_loader=test_loader,
        latent_dim=latent_dim,
        split_type=split_type,
        num_epochs=NUM_EPOCHS,
        learning_rate=DEFAULT_LEARNING_RATE,
        device=device
    )

    save_path = COMBINED_MODEL_RESULTS_DIR / f'combined_v2_{split_type}_{latent_dim}d.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'latent_dim': latent_dim,
        'split_type': split_type
    }, save_path)
    print(f"Model saved to {save_path}")

    return model, history


if TRAIN_COMBINED:
    print("\n=== TRAINING COMBINED MODELS ===")

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint directory exists before the first (expensive) training run.
    COMBINED_MODEL_RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    results = {}

    # Load interval data; both splits share the data path, batch size and seed.
    loader_kwargs = dict(
        scattering_data_path=COMBINED_SCATTERING,
        batch_size=16,
        random_state=42
    )
    train_loader_subj, test_loader_subj, info_subj = prepare_interval_data(
        split_type='subject', **loader_kwargs
    )
    train_loader_time, test_loader_time, info_time = prepare_interval_data(
        split_type='time', **loader_kwargs
    )

    print(f"Subject split: {info_subj['n_train']} train, {info_subj['n_test']} test")
    print(f"Time split: {info_time['n_train']} train, {info_time['n_test']} test")

    # Dict insertion order fixes the training order: subject split, then time.
    loaders_by_split = {
        'subject': (train_loader_subj, test_loader_subj),
        'time': (train_loader_time, test_loader_time),
    }

    for latent_dim in LATENT_DIMS:
        print(f"\n--- Latent Dimension: {latent_dim} ---")
        results[latent_dim] = {}

        for split_type, (train_loader, test_loader) in loaders_by_split.items():
            print(f"\n{split_type.upper()} SPLIT:")
            model, history = _train_and_save_combined(
                train_loader, test_loader, latent_dim, split_type
            )
            results[latent_dim][split_type] = {
                'accuracy': history['final_accuracy'],
                'history': history,
                'model_state': model.state_dict()
            }

    print("\n=== COMBINED MODEL V2 RESULTS ===")
    for latent_dim in LATENT_DIMS:
        print(f"\n{latent_dim}D Latent Space:")
        for split_type, split_result in results[latent_dim].items():
            print(f"  {split_type.capitalize()}: {split_result['accuracy']:.1f}%")

else:
    print("Skipping combined model training")

In [None]:
### CROSS-VALIDATION ###

def _print_cv_summary(cv_results, header):
    """Print mean ± std CV accuracy for each latent dim / split type present.

    Args:
        cv_results: dict keyed by f"{split_type}_{latent_dim}d"; each value
            must provide 'mean_accuracy' and 'std_accuracy'. Configurations
            missing from the dict are silently skipped.
        header: banner line printed before the per-dimension table.
    """
    print(header)
    for latent_dim in LATENT_DIMS:
        print(f"\n{latent_dim}D Latent Space:")
        for split_type in SPLIT_TYPES:
            config_key = f'{split_type}_{latent_dim}d'
            if config_key in cv_results:
                stats = cv_results[config_key]
                print(f"  {split_type.capitalize()}: "
                      f"{stats['mean_accuracy']:.1f}% ± {stats['std_accuracy']:.1f}%")


if RUN_COMBINED_CV:
    print("\n=== COMBINED CROSS-VALIDATION ===")

    # Use the new V2 cross-validation function
    all_cv_results = run_combined_v2_cross_validation(
        data_path=COMBINED_SCATTERING,
        split_types=SPLIT_TYPES,
        latent_dims=LATENT_DIMS,
        num_trials=CV_TRIALS,
        # Use the configured epoch count so CV matches the single-split training.
        num_epochs=NUM_EPOCHS
    )

    _print_cv_summary(all_cv_results, "\n=== COMBINED MODEL RESULTS ===")

else:
    print("✓ Skipping combined model cross-validation")

    # Reload CV results from existing pickle files (one per split type /
    # latent dim configuration), skipping any configuration with no file.
    # NOTE: pickle.load can execute arbitrary code -- only load result files
    # produced by this project, never files from an untrusted source.
    all_cv_results = {}
    for split_type in SPLIT_TYPES:
        for latent_dim in LATENT_DIMS:
            config_key = f"{split_type}_{latent_dim}d"
            results_file = COMBINED_MODEL_RESULTS_DIR / f'combined_v2_{config_key}_cv_results.pkl'

            if results_file.exists():
                with open(results_file, 'rb') as f:
                    all_cv_results[config_key] = pickle.load(f)

    if all_cv_results:
        _print_cv_summary(all_cv_results, "\n=== LOADED COMBINED MODEL RESULTS ===")
    else:
        print("  No existing results found")