# Task 1: Splicing Prediction Training
## Enhanced with Experiment Tracking & Metrics

### 1) Configuration & Setup

In [None]:
import os
import json
import torch
import pandas as pd
from config import (
    DATA_DIR, SAVE_ROOT, BATCH_SIZE, LR, EPOCHS, PATIENCE,
    HIDDEN_DIMS, DROPOUT, NUM_CLASSES, SEED, WEIGHT_DECAY
)

# Echo the effective configuration (imported from `config`) so each notebook
# run shows which paths, hyper-parameters and device it is using.
# The header emoji was stored as mojibake ("ðŸ“¦", UTF-8 read as cp1252);
# restored to the intended package emoji. No placeholders -> plain string.
print("📦 Configuration loaded:")
print(f"  Data dir: {DATA_DIR}")
print(f"  Save root: {SAVE_ROOT}")
print(f"  Batch size: {BATCH_SIZE}, LR: {LR}, Epochs: {EPOCHS}")
print(f"  Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")


### 2) Train Splicing Models with Experiment Tracking

In [None]:
import argparse
from train import main as train_main

# Recreate the training CLI inside the notebook so `train_main` receives an
# argparse.Namespace. Every default is taken from `config`; pass flags in the
# list given to parse_args() below to override individual options.
parser = argparse.ArgumentParser(description="Task 1 Splicing Prediction Training")

# Optimisation hyper-parameters
parser.add_argument("--batch-size", default=BATCH_SIZE, type=int)
parser.add_argument("--lr", default=LR, type=float)
parser.add_argument("--weight-decay", default=WEIGHT_DECAY, type=float)
parser.add_argument("--epochs", default=EPOCHS, type=int)
parser.add_argument("--patience", default=PATIENCE, type=int)
parser.add_argument("--dropout", default=DROPOUT, type=float)

# Runtime environment
parser.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    type=str,
)
parser.add_argument("--seed", default=SEED, type=int)

# Filesystem locations
parser.add_argument("--data-dir", default=DATA_DIR, type=str)
parser.add_argument("--save-root", default=SAVE_ROOT, type=str)

# Experiment selection; per their help text, leaving --ratio / --set as None
# trains every ratio / set.
parser.add_argument("--exp-num", default=None, type=int)
parser.add_argument(
    "--ratio",
    default=None,
    type=str,
    help="Specific ratio (e.g., 'ratio_10_80_10'). If not specified, train all ratios.",
)
parser.add_argument(
    "--set",
    default=None,
    type=str,
    help="Specific set (e.g., 'set_1'). If not specified, train all sets.",
)

# An explicit (empty) argument list stops argparse from reading sys.argv,
# so every option keeps its default; add flags here to override them.
args = parser.parse_args([])

# Launch training with the parsed options.
train_main(args)


### 3) View Experiment Results

In [None]:
# List all experiments and their results
import os
import json
import pandas as pd

def load_experiment_results(exp_base_dir):
    """Summarise every experiment under ``exp_base_dir`` as one DataFrame row.

    A sub-directory of ``exp_base_dir`` counts as an experiment only if it
    contains a ``config.json``. Inside it, each ``<ratio>/<set>/results.json``
    (the ``tensorboard`` folder is skipped) contributes the numeric entries of
    its ``best_metrics`` mapping. Each metric is averaged over the result
    files that actually report it, so a metric missing from some sets is not
    dragged toward zero.

    Args:
        exp_base_dir: Directory containing one folder per experiment.

    Returns:
        A ``pandas.DataFrame`` with columns ``experiment``, ``num_sets``
        (number of ``results.json`` files found) and one column per averaged
        metric (rounded to 4 decimals), one row per experiment in
        reverse-sorted name order; or ``None`` when the directory does not
        exist or no experiment has any results.
    """
    if not os.path.exists(exp_base_dir):
        print(f"No experiments directory found at {exp_base_dir}")
        return None

    experiments = []

    for exp_name in sorted(os.listdir(exp_base_dir), reverse=True):
        exp_dir = os.path.join(exp_base_dir, exp_name)
        if not os.path.isdir(exp_dir):
            continue

        # Only directories holding a config.json are treated as experiments.
        # (Its contents were previously loaded but never used.)
        if not os.path.exists(os.path.join(exp_dir, "config.json")):
            continue

        metric_sums = {}
        metric_counts = {}  # per-metric denominators for the average
        num_sets = 0

        # Sorted traversal keeps metric (column) order deterministic across
        # filesystems; os.listdir order is arbitrary.
        for ratio_folder in sorted(os.listdir(exp_dir)):
            ratio_path = os.path.join(exp_dir, ratio_folder)
            if not os.path.isdir(ratio_path) or ratio_folder == "tensorboard":
                continue

            for set_name in sorted(os.listdir(ratio_path)):
                results_file = os.path.join(ratio_path, set_name, "results.json")
                if not os.path.exists(results_file):
                    continue

                with open(results_file, "r", encoding="utf-8") as f:
                    result_data = json.load(f)
                num_sets += 1

                for key, val in result_data.get("best_metrics", {}).items():
                    if isinstance(val, (int, float)):
                        metric_sums[key] = metric_sums.get(key, 0) + val
                        metric_counts[key] = metric_counts.get(key, 0) + 1

        if num_sets == 0:
            continue

        experiments.append({
            "experiment": exp_name,
            "num_sets": num_sets,
            **{
                key: round(total / metric_counts[key], 4)
                for key, total in metric_sums.items()
            },
        })

    return pd.DataFrame(experiments) if experiments else None

# Display experiments
# Summarise every experiment found under SAVE_ROOT.
# NOTE(review): the TensorBoard command in the next section points at a
# `.../trained_models/experiments/` folder; confirm SAVE_ROOT is that
# experiments directory (each child folder is expected to hold config.json).
exp_dir = SAVE_ROOT
df_results = load_experiment_results(exp_dir)

if df_results is not None:
    # Header emoji restored from mojibake ("ðŸ“Š", UTF-8 read as cp1252).
    print("📊 Experiment Results Summary:")
    print(df_results.to_string())
else:
    print("No experiments found yet. Run training first!")

### 4) TensorBoard Monitoring

To monitor training progress in real-time:

```bash
tensorboard --logdir train/task1_splicing_prediction/trained_models/experiments/
```

Then open the URL that TensorBoard prints (http://localhost:6006 by default) in your browser to view:
- Training/Validation loss curves
- Evaluation metrics (accuracy, precision, recall, F1, MCC, AUC, balanced accuracy, specificity)
- Model architecture summary


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def _collect_metric_values(exp_base_dir, metric):
    """Return (labels, values) for every run under ``exp_base_dir`` reporting ``metric``.

    Walks ``<experiment>/<ratio>/<set>/results.json`` (skipping ``tensorboard``)
    in sorted order. Each label is ``"experiment/ratio/set"``, so bars from the
    same experiment remain distinguishable. Only numeric values are kept,
    because anything else cannot be plotted as a bar height.
    """
    labels = []
    values = []
    for exp_name in sorted(os.listdir(exp_base_dir)):
        exp_dir = os.path.join(exp_base_dir, exp_name)
        if not os.path.isdir(exp_dir):
            continue

        for ratio_folder in sorted(os.listdir(exp_dir)):
            ratio_path = os.path.join(exp_dir, ratio_folder)
            if not os.path.isdir(ratio_path) or ratio_folder == "tensorboard":
                continue

            for set_name in sorted(os.listdir(ratio_path)):
                results_file = os.path.join(ratio_path, set_name, "results.json")
                if not os.path.exists(results_file):
                    continue

                with open(results_file, "r", encoding="utf-8") as f:
                    result_data = json.load(f)
                value = result_data.get("best_metrics", {}).get(metric)
                if isinstance(value, (int, float)):
                    labels.append(f"{exp_name}/{ratio_folder}/{set_name}")
                    values.append(value)
    return labels, values


def plot_metric_comparison(exp_base_dir, metric='accuracy'):
    """Bar-plot ``metric`` from ``best_metrics`` of every ratio/set run.

    One bar is drawn per ``results.json``, labelled ``experiment/ratio/set``.
    The previous version labelled every bar with only the experiment name,
    which produced duplicate, indistinguishable ticks.

    Args:
        exp_base_dir: Directory containing one folder per experiment.
        metric: Key looked up in each run's ``best_metrics`` mapping.

    Returns:
        None. Shows a matplotlib figure, or prints a message when the
        directory is missing or no numeric ``metric`` values are found.
    """
    if not os.path.exists(exp_base_dir):
        print("No experiments found")
        return

    labels, values = _collect_metric_values(exp_base_dir, metric)
    if not values:
        print(f"No {metric} data found")
        return

    positions = range(len(values))
    plt.figure(figsize=(12, 5))
    plt.bar(positions, values)
    plt.xticks(positions, labels, rotation=45, ha='right')
    plt.ylabel(metric.capitalize())
    plt.title(f"{metric.capitalize()} Comparison Across Experiments")
    plt.tight_layout()
    plt.show()

# Example usage: one bar per ratio/set run, showing best accuracy per run.
plot_metric_comparison(SAVE_ROOT, metric="accuracy")
