In [None]:
### SETUP ###
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import pickle
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import warnings
warnings.filterwarnings('ignore')
from SCRIPTS.config import *
from SCRIPTS.dataprep import prepare_interval_data, TaskIntervalDataset
from SCRIPTS.rnn_model import AttentionLSTMClassifier
from SCRIPTS.rnn_training import create_balanced_sampler, train_rnn_model
from SCRIPTS.cross_validation_experiments import run_rnn_cross_validation

# Reproducibility: seed NumPy and PyTorch before anything stochastic runs
torch.manual_seed(42)
np.random.seed(42)

# Plot appearance (the style sheet is applied first, then the colour palette)
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Echo the runtime context: compute device plus the data/results paths from config
print(
    f"Using device: {device}",
    f"Data path: {COMBINED_SCATTERING}",
    f"Results path: {RNN_RESULTS_DIR}",
    sep="\n",
)

In [None]:
### CONFIGURATION ###

# --- Which stages to execute ---------------------------------------------
# True  -> run the stage in this session
# False -> skip it (the cross-validation cell then tries to reload saved results)
RUN_CROSS_VALIDATION = True
TRAIN_RNN = False

# --- Experiment settings -------------------------------------------------
# Latent sizes to evaluate; every enabled stage is repeated once per entry
LATENT_DIMS = [8, 48]

# Number of trials passed to each cross-validation run
CV_TRIALS = 5

In [None]:
### LOAD INTERVAL DATA ###

print("\n=== LOADING INTERVAL DATA ===")

# Arguments shared by both splits; only `split_type` differs between the calls
shared_prep_kwargs = dict(
    scattering_data_path=COMBINED_SCATTERING,
    batch_size=16,
    random_state=42,
)

# 'subject' split
train_loader_subj, test_loader_subj, info_subj = prepare_interval_data(
    split_type='subject', **shared_prep_kwargs
)

# 'time' split
train_loader_time, test_loader_time, info_time = prepare_interval_data(
    split_type='time', **shared_prep_kwargs
)

# Report the train/test counts that prepare_interval_data returned for each split
print(
    f"\nSubject split: {info_subj['n_train']} train, {info_subj['n_test']} test",
    f"Time split: {info_time['n_train']} train, {info_time['n_test']} test",
    sep="\n",
)


In [None]:
### TRAIN RNN MODELS ###


def _train_and_save_split(split_type, latent_dim, base_train_loader, test_loader,
                          input_dim=768, hidden_dim=128, batch_size=16):
    """Train one AttentionLSTMClassifier for a single split and checkpoint it.

    Args:
        split_type: Label of the split ('subject' or 'time'); used in the
            checkpoint filename and stored inside the checkpoint.
        latent_dim: Latent size passed to the model; also part of the filename.
        base_train_loader: Loader whose `.dataset` is re-wrapped in a new
            DataLoader driven by the sampler from `create_balanced_sampler`.
        test_loader: Passed unchanged to `train_rnn_model` as `test_loader`.
        input_dim: `input_dim` passed to the model.
        hidden_dim: `hidden_dim` passed to the model.
        batch_size: Batch size of the re-wrapped training loader.

    Returns:
        The `(model, history)` pair returned by `train_rnn_model`.
    """
    train_dataset = base_train_loader.dataset
    # Second return value of create_balanced_sampler is not needed here.
    sampler, _ = create_balanced_sampler(train_dataset)
    balanced_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,  # discard the final incomplete batch
    )

    model = AttentionLSTMClassifier(
        input_dim=input_dim,
        latent_dim=latent_dim,
        hidden_dim=hidden_dim,
    )
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    model, history = train_rnn_model(
        model=model,
        train_loader=balanced_loader,
        test_loader=test_loader,
        num_epochs=DEFAULT_EPOCHS,
        lr=DEFAULT_LEARNING_RATE,
        recon_weight=LOSS_WEIGHTS['reconstruction'],
        device=device,
    )

    # Persist weights together with the training history and run metadata.
    save_path = RNN_RESULTS_DIR / f'rnn_{split_type}_{latent_dim}d.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'history': history,
        'latent_dim': latent_dim,
        'split_type': split_type,
    }, save_path)
    print(f"✓ Model saved to {save_path}")

    return model, history


if TRAIN_RNN:
    print("\n=== TRAINING RNN MODELS ===")

    # Create the output folder up front so a missing directory cannot make
    # torch.save fail only after a full training run has completed.
    RNN_RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    # Split label -> (loader providing the training dataset, evaluation loader).
    # Insertion order fixes the training order: subject first, then time.
    split_loaders = {
        'subject': (train_loader_subj, test_loader_subj),
        'time': (train_loader_time, test_loader_time),
    }

    results = {}
    for latent_dim in LATENT_DIMS:
        print(f"\n--- Latent Dimension: {latent_dim} ---")
        results[latent_dim] = {}

        for split_type, (base_train_loader, split_test_loader) in split_loaders.items():
            print(f"\n{split_type.upper()} WITHHOLDING:")
            model, history = _train_and_save_split(
                split_type=split_type,
                latent_dim=latent_dim,
                base_train_loader=base_train_loader,
                test_loader=split_test_loader,
            )
            results[latent_dim][split_type] = {
                # Last recorded entry of history['val_acc'], i.e. the final
                # value, not necessarily the best one.
                'accuracy': history['val_acc'][-1],
                'history': history,
                'model_state': model.state_dict(),
            }

    print("\n=== RNN MODEL RESULTS ===")
    for latent_dim in LATENT_DIMS:
        print(f"\n{latent_dim}D Latent Space:")
        for split_type in split_loaders:
            print(f"  {split_type.capitalize()}: {results[latent_dim][split_type]['accuracy']:.1f}%")

else:
    print("✓ Skipping RNN model training")

In [None]:
### CROSS-VALIDATION ###


def _print_cv_summary(cv_summary):
    """Print mean ± std accuracy for every latent dimension and split.

    Args:
        cv_summary: Nested dict `{latent_dim: {split_type: stats}}` where each
            `stats` mapping provides 'mean_accuracy' and 'std_accuracy'.
            Whatever keys are present are printed, in the dict's own order.
    """
    print("\n=== CROSS-VALIDATION SUMMARY ===")
    for latent_dim, per_split in cv_summary.items():
        print(f"\nLatent Dim: {latent_dim}D")
        for split_type, acc in per_split.items():
            print(f"  {split_type.capitalize()}: {acc['mean_accuracy']:.1f}% ± {acc['std_accuracy']:.1f}%")


if RUN_CROSS_VALIDATION:
    print(f"\n=== CROSS-VALIDATION EXPERIMENTS ({CV_TRIALS} trials) ===")

    # Create the output folder before the run so a missing directory cannot
    # make the final pickle.dump fail after all trials have finished.
    RNN_RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    split_types = ('subject', 'time')
    cv_results = {}

    for latent_dim in LATENT_DIMS:
        print(f"\n--- LATENT DIMENSION: {latent_dim} ---")
        cv_results[latent_dim] = {}

        for split_type in split_types:
            print(f"\n{split_type.upper()} SPLIT:")
            cv_results[latent_dim][split_type] = run_rnn_cross_validation(
                data_path=COMBINED_SCATTERING,
                split_type=split_type,
                num_trials=CV_TRIALS,
                num_epochs=DEFAULT_EPOCHS,
                latent_dim=latent_dim
            )

    save_path = RNN_RESULTS_DIR / 'rnn_cv_results.pkl'
    with open(save_path, 'wb') as f:
        pickle.dump(cv_results, f)

    print(f"\n✓ Results saved to {save_path}")

    _print_cv_summary(cv_results)

else:
    cv_file = RNN_RESULTS_DIR / 'rnn_cv_results.pkl'

    if cv_file.exists():
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only load a results file that this notebook (or a trusted run) wrote.
        with open(cv_file, 'rb') as f:
            cv_results = pickle.load(f)

        print("✓ Loaded existing cross-validation results")

        _print_cv_summary(cv_results)
    else:
        # NOTE: `cv_results` is not assigned on this path.
        print("✓ No existing cross-validation results found")