In [None]:
# ============================================================
# GAT LOOCV MODEL - NEURAL GRAPH LEARNING FOR BEHAVIOR PREDICTION
# ============================================================
# Predicts animal-level freezing behavior using neuron-level features and 
# functional connectivity (edge weights) in a GAT neural network framework.
# ============================================================

# === Core Libraries ===
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# === Torch + GNN Frameworks ===
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, TransformerConv, global_mean_pool

# === Preprocessing + Evaluation ===
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import LeaveOneOut, KFold
from sklearn.metrics import r2_score

# === Optuna for Hyperparameter Tuning ===
import optuna
from optuna.samplers import GridSampler

# === Reproducibility (optional) ===
# Seed the global PyTorch and NumPy random number generators with a fixed value.
# NOTE(review): Python's built-in `random` module is not seeded here and no
# backend determinism flags are set — confirm whether that matters for your runs.
torch.manual_seed(42)
np.random.seed(42)


In [None]:
# Load neuron features, edge weights, and freezing labels
def load_data(neuron_features_path, edge_weights_path, labels_path):
    """
    Loads and validates neuron-level features, functional edge weights, and
    animal-level freezing labels for graph-based behavioral prediction.

    Depending on the GNN model used, edge weights may be incorporated into
    attention mechanisms (if using an edge-aware model like GATWithEdgeAttr).

    Parameters
    ----------
    neuron_features_path : str
        Path to Excel file containing neuron-level features for each animal.

    edge_weights_path : str
        Path to Excel file containing mean edge weights between neuron pairs
        (used to construct edge_attr in GNN models if applicable).

    labels_path : str
        Path to Excel file containing animal-level behavioral labels (e.g.,
        Percentage Freezing) and experimental group.

    Returns
    -------
    neuron_features_df : pd.DataFrame
        DataFrame of neuron-level features with 'Unique ID' and 'Neuron' columns
        (an 'n_clust' column, if present, is renamed to 'Neuron').

    edge_weights_df : pd.DataFrame
        DataFrame of neuron pairs and their corresponding edge weights. The
        'Neuron Pair' column is converted from strings to tuples of ints.

    labels_df : pd.DataFrame
        DataFrame containing animal-level metadata including behavioral output
        (freezing %) and group identifiers.

    Raises
    ------
    ValueError
        If an animal present in the neuron-feature table has no entry in the
        edge-weight table or in the label table.
    """
    neuron_features_df = pd.read_excel(neuron_features_path)
    edge_weights_df = pd.read_excel(edge_weights_path)
    labels_df = pd.read_excel(labels_path)

    # Fix column naming inconsistencies
    neuron_features_df.rename(columns={'n_clust': 'Neuron'}, inplace=True)

    # Convert neuron pair strings like "(1, 2)" to tuples
    def convert_pair(pair_str):
        # Rows flagged as "dummy" get the placeholder pair (0, 1).
        if 'dummy' in pair_str.lower():
            return (0, 1)
        nums = re.findall(r'\d+', pair_str)
        return tuple(map(int, nums))

    edge_weights_df['Neuron Pair'] = edge_weights_df['Neuron Pair'].apply(convert_pair)

    # Validation checks. Explicit raises are used instead of `assert` so the
    # checks still run when Python is started with -O.
    feature_ids = set(neuron_features_df['Unique ID'].unique())

    missing_in_edges = feature_ids - set(edge_weights_df['Unique ID'].unique())
    if missing_in_edges:
        raise ValueError(
            "Mismatch in Unique IDs between neuron features and edge weights. "
            f"Missing from edge weights: {sorted(missing_in_edges, key=str)}"
        )

    missing_in_labels = feature_ids - set(labels_df['UniqueID'].unique())
    if missing_in_labels:
        raise ValueError(
            "Mismatch in Unique IDs between neuron features and labels. "
            f"Missing from labels: {sorted(missing_in_labels, key=str)}"
        )

    return neuron_features_df, edge_weights_df, labels_df


In [None]:
# Convert per-animal data into PyG graph objects

def create_graph_objects(neuron_features_df, edge_weights_df, labels_df):
    """
    Constructs torch_geometric Data objects (one per animal and time segment) using:
    - node features (neurons)
    - edge index (connectivity)
    - edge attributes (weights)
    - label (freezing %)
    - group and unique ID metadata

    Animals lacking neuron rows, edge rows, or a label row are skipped with a
    printed message. If an animal has several label rows, the first one is used
    for both the group and the freezing labels, so every graph carries exactly
    one target value.

    Parameters
    ----------
    neuron_features_df : pd.DataFrame
        Per-neuron feature table with a 'Unique ID' column and the feature
        columns listed below.
    edge_weights_df : pd.DataFrame
        Per-edge table with 'Unique ID', 'Neuron Pair' (tuple of two ints) and
        the two mean-edge-weight columns.
    labels_df : pd.DataFrame
        Per-animal table with 'UniqueID', 'Group' and the two freezing columns.

    Returns
    -------
    animal_graphs_first600 : list of torch_geometric.data.Data
    animal_graphs_last600 : list of torch_geometric.data.Data

    Note
    ----
    Both standard and edge-aware GAT models use the same graph format.
    Only GATWithEdgeAttr will access the edge_attr during training.

    NOTE(review): node i is the i-th feature row of the animal, while edge
    endpoints are taken from 'Neuron Pair' minus one. This presumes the feature
    rows are ordered by neuron number — confirm against the input files.
    """
    # Feature selection (identical for every animal, so defined once)
    feat_f600 = ['Firing rate first 600', 'mISI (s) first 600', 'maxISI (s) first 600',
                 'minISI (s) first 600', 'CVISI first 600', 'PC1score whole', 'PC2score whole']
    feat_l600 = ['Firing rate last 600', 'mISI (s) last 600', 'maxISI (s) last 600',
                 'minISI (s) last 600', 'CVISI last 600', 'PC1score whole', 'PC2score whole']

    animal_graphs_first600 = []
    animal_graphs_last600 = []

    for unique_id in neuron_features_df['Unique ID'].unique():
        neurons = neuron_features_df[neuron_features_df['Unique ID'] == unique_id]
        edges = edge_weights_df[edge_weights_df['Unique ID'] == unique_id]
        label_row = labels_df[labels_df['UniqueID'] == unique_id]

        if neurons.empty or edges.empty or label_row.empty:
            print(f"Missing data for animal {unique_id}. Skipping.")
            continue

        if len(label_row) > 1:
            print(f"Multiple label rows for animal {unique_id}. Using the first one.")

        group = label_row['Group'].iloc[0]

        x_f600 = torch.tensor(neurons[feat_f600].values, dtype=torch.float)
        x_l600 = torch.tensor(neurons[feat_l600].values, dtype=torch.float)

        # Edge index (convert 1-indexed → 0-indexed); shape becomes (2, num_edges)
        edge_index = torch.tensor([(i - 1, j - 1) for i, j in edges['Neuron Pair']],
                                  dtype=torch.long).t().contiguous()
        edge_attr_f600 = torch.tensor(edges['Mean Edge Weight First 600s'].values, dtype=torch.float).view(-1, 1)
        edge_attr_l600 = torch.tensor(edges['Mean Edge Weight Last 600s'].values, dtype=torch.float).view(-1, 1)

        # Keep only the first label row so y always holds a single value,
        # matching how `group` is selected above.
        y_f600 = torch.tensor(label_row['Percentage Freezing First 600'].values[:1], dtype=torch.float)
        y_l600 = torch.tensor(label_row['Percentage Freezing Last 600'].values[:1], dtype=torch.float)

        data_f600 = Data(x=x_f600, edge_index=edge_index, edge_attr=edge_attr_f600,
                         y=y_f600, group=torch.tensor([group]), unique_id=torch.tensor([unique_id]))
        data_l600 = Data(x=x_l600, edge_index=edge_index, edge_attr=edge_attr_l600,
                         y=y_l600, group=torch.tensor([group]), unique_id=torch.tensor([unique_id]))

        animal_graphs_first600.append(data_f600)
        animal_graphs_last600.append(data_l600)

    print(f"Created {len(animal_graphs_first600)} graphs for First 600s")
    print(f"Created {len(animal_graphs_last600)} graphs for Last 600s")

    return animal_graphs_first600, animal_graphs_last600


In [None]:
# Safety checks on edge index values

def check_dummy_node_edges(graphs):
    """
    Sanity check for edge index issues.
    Any edge indices < 0 are corrected to the last valid node index.

    Each graph is modified in place: when a negative index is found,
    ``graph.edge_index`` is replaced by a corrected copy. Graphs without any
    edges are reported as OK instead of raising.

    Parameters
    ----------
    graphs : iterable
        Graph objects exposing ``unique_id`` (single-element tensor),
        ``edge_index`` (integer tensor) and ``num_nodes``.
    """
    for graph in graphs:
        eid = graph.unique_id.item()
        # A boolean mask works on empty tensors too, whereas .min() raises there.
        negative = graph.edge_index < 0
        if negative.any():
            print(f"Edge index < 0 found in animal {eid}. Attempting fix.")
            corrected = graph.edge_index.clone()
            corrected[negative] = graph.num_nodes - 1
            graph.edge_index = corrected
            print(f"Edge correction applied for animal {eid}.")
        else:
            print(f"Edge index OK for animal {eid}.")


In [None]:
# Apply StandardScaler to features

def normalize_features(graphs):
    """
    Standardizes the node features of every graph with sklearn's StandardScaler.

    The scaler is re-fitted on each graph's own node-feature matrix, so every
    animal is scaled independently of the others.

    Parameters
    ----------
    graphs : list
        Graph objects whose ``x`` attribute holds the node-feature matrix.

    Returns
    -------
    list
        The same list object; each graph's ``x`` has been replaced in place by
        a float tensor of standardized features.

    Note
    ----
    Only node features (graph.x) are touched.
    Edge attributes (graph.edge_attr) are left unchanged.
    """
    standardizer = StandardScaler()
    for g in graphs:
        # fit_transform re-fits the scaler, so statistics never leak between graphs
        scaled = standardizer.fit_transform(g.x)
        g.x = torch.tensor(scaled, dtype=torch.float)
    return graphs


In [None]:
#Define GAT Models with Optional Edge Attribute Support

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TransformerConv, global_mean_pool

# ============================================
# TOGGLE THIS TO SWITCH BETWEEN MODELS
# ============================================
# Module-level flag. create_model() uses it to pick the model class, and the
# training / evaluation / plotting code uses it to decide whether edge_attr is
# passed to the model's forward call.
use_edge_weights = True  # Set to False for standard GAT without edge_attr

# ============================================
# MODEL 1 — Standard GAT (no edge weights)
# ============================================
class GAT(nn.Module):
    """
    Plain Graph Attention Network built from three GATConv layers.
    - Edge weights are not used
    - ReLU + dropout follow the first two layers
    - The third layer emits one value per node, which is mean-pooled per graph
    - Output is squashed with a sigmoid and scaled to 0–100 (freezing %)
    """
    def __init__(self, num_features, hidden_channels, dropout_rate=0.3, heads=1):
        super().__init__()
        # Width of the concatenated multi-head output fed to the next layer
        concat_width = hidden_channels * heads
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads)
        self.conv2 = GATConv(concat_width, hidden_channels, heads=heads)
        self.conv3 = GATConv(concat_width, 1, heads=1)
        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, batch):
        for layer in (self.conv1, self.conv2):
            x = F.relu(layer(x, edge_index))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        node_scores = self.conv3(x, edge_index)
        graph_scores = global_mean_pool(node_scores, batch)
        return torch.sigmoid(graph_scores) * 100

# ============================================
# MODEL 2 — Edge-Aware GAT (uses edge_attr)
# ============================================
class GATWithEdgeAttr(nn.Module):
    """
    Edge-aware attention network built from three TransformerConv layers.
    - Every layer receives edge_attr (one value per edge, edge_dim=1)
    - ReLU + dropout follow the first two layers
    - The third layer emits one value per node, which is mean-pooled per graph
    - Output is squashed with a sigmoid and scaled to 0–100 (freezing %)
    """
    def __init__(self, num_features, hidden_channels, dropout_rate=0.3, heads=1):
        super().__init__()
        # Width of the concatenated multi-head output fed to the next layer
        concat_width = hidden_channels * heads
        self.conv1 = TransformerConv(num_features, hidden_channels, heads=heads, edge_dim=1)
        self.conv2 = TransformerConv(concat_width, hidden_channels, heads=heads, edge_dim=1)
        self.conv3 = TransformerConv(concat_width, 1, heads=1, edge_dim=1)
        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, edge_attr, batch):
        for layer in (self.conv1, self.conv2):
            x = F.relu(layer(x, edge_index, edge_attr))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        node_scores = self.conv3(x, edge_index, edge_attr)
        graph_scores = global_mean_pool(node_scores, batch)
        return torch.sigmoid(graph_scores) * 100

# ============================================
# Model Instantiation (based on toggle)
# ============================================
def create_model(num_features, hidden_channels, dropout_rate=0.3, heads=1):
    """
    Builds the model selected by the module-level use_edge_weights flag:
    GATWithEdgeAttr when it is truthy, otherwise the standard GAT.
    All arguments are forwarded unchanged to the chosen class.
    """
    model_cls = GATWithEdgeAttr if use_edge_weights else GAT
    return model_cls(num_features, hidden_channels, dropout_rate, heads)

# === EXAMPLE USAGE ===
# num_features must equal the number of node-feature columns (7 in create_graph_objects).
# model = create_model(num_features=7, hidden_channels=64, dropout_rate=0.2, heads=2)

In [None]:
# Train GAT or GATWithEdgeAttr Model

def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs=110):
    """
    Trains the GAT model (with or without edge weights) and records loss per epoch.

    Whether edge_attr is passed to the model is decided by the module-level
    use_edge_weights flag. Training batches whose features, labels, or loss
    contain NaN are skipped and do not contribute to the epoch average.

    Parameters
    ----------
    model : torch.nn.Module
        Either GAT (no edge_attr) or GATWithEdgeAttr (uses edge_attr).
    train_loader : DataLoader
        Training graph batch loader.
    val_loader : DataLoader
        Validation graph batch loader.
    optimizer : torch.optim.Optimizer
    criterion : loss function (e.g., MSELoss)
    scheduler : learning rate scheduler
        Stepped once per epoch, after validation.
    num_epochs : int
        Number of training epochs.

    Returns
    -------
    train_losses : list of float
        Mean training loss per processed batch, one entry per epoch.
    val_losses : list of float
        Mean validation loss per batch, one entry per epoch.

    Notes
    -----
    The epoch averages divide by the number of batches that were actually
    processed. With batch_size=1 and no skipped batches this equals dividing by
    the dataset size. If no batch was processed in an epoch the entry is NaN.
    """
    model.train()
    train_losses = []
    val_losses = []

    for _ in range(num_epochs):
        total_train_loss = 0.0
        train_batches = 0
        for data in train_loader:
            optimizer.zero_grad()

            if torch.isnan(data.x).any() or torch.isnan(data.y).any():
                continue

            # Forward pass (with or without edge weights)
            if use_edge_weights:
                out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            else:
                out = model(data.x, data.edge_index, data.batch)

            loss = criterion(out.view(-1, 1), data.y.view(-1, 1))
            if torch.isnan(loss):
                continue

            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            train_batches += 1

        # Each accumulated term is one batch loss, so normalise by batch count.
        avg_train_loss = total_train_loss / train_batches if train_batches else float('nan')
        train_losses.append(avg_train_loss)

        # === Validation ===
        model.eval()
        total_val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for data in val_loader:
                if use_edge_weights:
                    val_out = model(data.x, data.edge_index, data.edge_attr, data.batch)
                else:
                    val_out = model(data.x, data.edge_index, data.batch)

                total_val_loss += criterion(val_out.view(-1, 1), data.y.view(-1, 1)).item()
                val_batches += 1

        avg_val_loss = total_val_loss / val_batches if val_batches else float('nan')
        val_losses.append(avg_val_loss)
        scheduler.step()
        # Switch back to training mode for the next epoch
        model.train()

    return train_losses, val_losses


In [None]:
#Plotting Utilities

import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_loss_curves(train_losses, val_losses, title):
    """
    Draw the per-epoch training and validation loss on one figure and show it.
    """
    plt.figure(figsize=(10, 5))
    for series, series_label in ((train_losses, 'Training Loss'),
                                 (val_losses, 'Validation Loss')):
        plt.plot(series, label=series_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

def plot_fold_loss_curves(fold_loss_curves, segment_name):
    """
    Show one loss-curve figure per fold; folds are keyed by held-out animal ID.
    """
    for animal_id, curves in fold_loss_curves.items():
        train_curve, val_curve = curves
        fold_title = f"{segment_name} - Fold Animal {animal_id}"
        plot_loss_curves(train_curve, val_curve, fold_title)

def plot_average_loss_curves(fold_loss_curves, segment_name):
    """
    Plot the epoch-wise mean of the training and validation curves over all folds.
    """
    fold_curves = list(fold_loss_curves.values())
    # Stack to (num_folds, num_epochs) and average over the fold axis
    avg_train = np.mean(np.stack([curves[0] for curves in fold_curves]), axis=0)
    avg_val = np.mean(np.stack([curves[1] for curves in fold_curves]), axis=0)

    plt.figure(figsize=(10, 5))
    for series, series_label in ((avg_train, 'Avg Train Loss'), (avg_val, 'Avg Val Loss')):
        plt.plot(series, label=series_label)
    plt.title(f"{segment_name} - Avg Loss Curves")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

def plot_prediction_vs_actual(loader, model, title, add_reference_line=True):
    """
    Plots predicted vs. actual freezing % across all samples.
    Colors indicate group identity (0 = black, 1 = pink, 2 = teal); any other
    group is drawn in gray.
    Automatically adapts to whether edge weights are used (use_edge_weights flag).

    Parameters
    ----------
    loader : DataLoader
        Yields graph batches of any batch size; each batch must expose
        ``y`` and ``group`` with one entry per graph.
    model : torch.nn.Module
        Trained model; it is switched to eval mode here.
    title : str
        Figure title.
    add_reference_line : bool
        When True, draws the y = x line from (0, 0) to (100, 100).
    """
    model.eval()
    preds, trues, groups = [], [], []
    cmap = {0: 'black', 1: 'pink', 2: 'teal'}

    with torch.no_grad():
        for data in loader:
            if use_edge_weights:
                out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            else:
                out = model(data.x, data.edge_index, data.batch)
            # Flatten instead of .item() so batches with several graphs also work
            preds.extend(out.reshape(-1).tolist())
            trues.extend(data.y.reshape(-1).tolist())
            groups.extend(data.group.reshape(-1).tolist())

    # Convert once, outside the per-group loop
    preds_arr = np.asarray(preds)
    trues_arr = np.asarray(trues)
    groups_arr = np.asarray(groups)

    plt.figure(figsize=(8, 6))
    for grp in np.unique(groups_arr):
        in_group = groups_arr == grp
        plt.scatter(preds_arr[in_group], trues_arr[in_group],
                    color=cmap.get(grp, 'gray'), label=f'Group {grp}', s=50, alpha=0.7)

    if add_reference_line:
        plt.plot([0, 100], [0, 100], 'r--', label='Perfect Prediction')

    plt.title(title)
    plt.xlabel("Predicted Freezing %")
    plt.ylabel("Actual Freezing %")
    plt.xlim(0, 100)
    plt.ylim(0, 100)
    plt.legend()
    plt.show()


In [None]:
# LOOCV Training Function

from sklearn.metrics import r2_score
import numpy as np
from torch_geometric.loader import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import LeaveOneOut

def loocv_training(graphs, best_params, num_features=None, num_epochs=110, segment_name=''):
    """
    Performs Leave-One-Out Cross-Validation (LOOCV) using one graph (animal) as validation and the rest as training.

    A fresh model, Adam optimizer (weight_decay=1e-5) and CosineAnnealingLR
    scheduler (T_max=50) are created for every fold.

    Parameters
    ----------
    graphs : list of Data
        One PyG graph per animal.
    best_params : dict
        Hyperparameters from Optuna or manual config. Must contain
        'hidden_channels', 'dropout_rate', 'heads' and 'learning_rate'.
    num_features : int or None
        Node feature dimensionality. When None (default) it is read from the
        first graph's feature matrix, so it always matches the data.
    num_epochs : int
        Epochs per fold.
    segment_name : str
        For diagnostic printing.

    Returns
    -------
    avg_train_loss : float
    avg_val_loss : float
    r2 : float
        Final LOOCV R² score.
    fold_loss_curves : dict
        Per-fold training and validation losses, keyed by held-out animal ID.

    Raises
    ------
    ValueError
        If `graphs` is empty (LeaveOneOut itself also rejects a single graph).
    """
    if not graphs:
        raise ValueError("loocv_training requires at least one graph.")
    if num_features is None:
        # Infer the input width from the data instead of relying on a hard-coded value
        num_features = graphs[0].x.shape[1]

    loo = LeaveOneOut()
    indices = np.arange(len(graphs))

    all_preds = []
    all_targets = []
    animal_ids = []
    fold_loss_curves = {}

    for fold, (train_idx, val_idx) in enumerate(loo.split(indices)):
        train_graphs = [graphs[i] for i in train_idx]
        val_graphs = [graphs[i] for i in val_idx]
        animal_id = val_graphs[0].unique_id.item()

        print(f"Fold {fold+1}/{len(graphs)}: Testing animal {animal_id}")

        train_loader = DataLoader(train_graphs, batch_size=1, shuffle=True, follow_batch=['x'])
        val_loader = DataLoader(val_graphs, batch_size=1, shuffle=False, follow_batch=['x'])

        model = create_model(
            num_features=num_features,
            hidden_channels=best_params['hidden_channels'],
            dropout_rate=best_params['dropout_rate'],
            heads=best_params['heads']
        )
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'], weight_decay=1e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=50)
        criterion = nn.MSELoss()

        train_losses, val_losses = train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs)
        fold_loss_curves[animal_id] = (train_losses, val_losses)

        # Final prediction on test animal
        model.eval()
        with torch.no_grad():
            for data in val_loader:
                if use_edge_weights:
                    pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
                else:
                    pred = model(data.x, data.edge_index, data.batch)
                all_preds.append(pred.item())
                all_targets.append(data.y.item())
                animal_ids.append(animal_id)

    y_true = np.array(all_targets)
    y_pred = np.array(all_preds)
    r2 = r2_score(y_true, y_pred)

    # Mean over folds of each fold's mean-over-epochs loss (computed once, reused below)
    avg_train_loss = np.mean([np.mean(curves[0]) for curves in fold_loss_curves.values()])
    avg_val_loss = np.mean([np.mean(curves[1]) for curves in fold_loss_curves.values()])

    print(f"LOOCV Summary ({segment_name})")
    print(f"Avg Train Loss: {avg_train_loss:.4f}")
    print(f"Avg Val Loss: {avg_val_loss:.4f}")
    print(f"Aggregated R²: {r2:.4f}")

    for a, p, t in zip(animal_ids, y_pred, y_true):
        print(f"Animal {a}: Predicted = {p:.2f}, Actual = {t:.2f}")

    return avg_train_loss, avg_val_loss, r2, fold_loss_curves


In [None]:
# Optuna Tuning Using K-Fold (Optional)
# =====================================
# Uses 5-fold CV to tune model hyperparameters.
# Compatible with both GAT and edge-aware GAT based on `use_edge_weights` toggle.

# === Define Search Space and Sampler ===
# Candidate values per hyperparameter (2 x 3 x 3 x 3 = 54 combinations).
# The same dictionary is handed to GridSampler and read by the objective
# created in objective_factory when it asks the trial for parameter values.
search_space = {
    'hidden_channels': [64, 128],
    'dropout_rate': [0.1, 0.2, 0.3],
    'learning_rate': [1e-2, 1e-3, 1e-4],
    'heads': [1, 2, 4]
}
sampler = GridSampler(search_space)

def objective_factory(graphs):
    """
    Creates an Optuna objective function using KFold CV on the given graphs.
    Returns a callable that Optuna can optimize.

    Parameters
    ----------
    graphs : list of Data
        One PyG graph per animal; must be non-empty and hold at least as many
        graphs as KFold splits (5).

    Returns
    -------
    objective : callable
        Takes an Optuna trial and returns mean validation loss minus mean
        per-fold R² (lower is better).

    Notes
    -----
    NOTE(review): R² is computed per validation fold; folds holding fewer than
    two animals give an undefined R² — confirm the dataset size is adequate.
    """
    def objective(trial):
        # Sample hyperparameters; every parameter is drawn from the explicit
        # grid so only values listed in search_space can be suggested.
        params = {
            'hidden_channels': trial.suggest_categorical('hidden_channels', search_space['hidden_channels']),
            'dropout_rate': trial.suggest_categorical('dropout_rate', search_space['dropout_rate']),
            'learning_rate': trial.suggest_categorical('learning_rate', search_space['learning_rate']),
            'heads': trial.suggest_categorical('heads', search_space['heads'])
        }

        # Input width taken from the data rather than hard-coded
        num_features = graphs[0].x.shape[1]

        kf = KFold(n_splits=5)
        losses, r2s = [], []

        for train_idx, val_idx in kf.split(graphs):
            train_loader = DataLoader([graphs[i] for i in train_idx], batch_size=1, shuffle=True)
            val_loader = DataLoader([graphs[i] for i in val_idx], batch_size=1, shuffle=False)

            # create_model does not accept 'learning_rate', so pass only the
            # architecture parameters explicitly.
            model = create_model(
                num_features=num_features,
                hidden_channels=params['hidden_channels'],
                dropout_rate=params['dropout_rate'],
                heads=params['heads']
            )
            optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'], weight_decay=1e-5)
            scheduler = CosineAnnealingLR(optimizer, T_max=50)
            criterion = nn.MSELoss()

            _, val_losses = train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs=110)

            preds, trues = [], []
            model.eval()
            with torch.no_grad():
                for data in val_loader:
                    if use_edge_weights:
                        pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
                    else:
                        pred = model(data.x, data.edge_index, data.batch)
                    preds.append(pred.item())
                    trues.append(data.y.item())

            r2 = r2_score(trues, preds)
            losses.append(np.mean(val_losses))
            r2s.append(r2)

        # Combined metric: low loss + high R²
        return np.mean(losses) - np.mean(r2s)

    return objective


In [None]:
# ============================================================
# GAT LOOCV FINAL EXECUTION BLOCK

if __name__ == '__main__':
    # === Set Your Data Paths ===
    neuron_features_path = r"path/to/neuron_features.xlsx"
    edge_weights_path = r"path/to/edge_weights.xlsx"
    labels_path = r"path/to/labels.xlsx"

    # === Load and Process Data ===
    neuron_df, edge_df, labels_df = load_data(neuron_features_path, edge_weights_path, labels_path)
    graphs_first600, graphs_last600 = create_graph_objects(neuron_df, edge_df, labels_df)
    check_dummy_node_edges(graphs_first600 + graphs_last600)
    graphs_first600 = normalize_features(graphs_first600)
    graphs_last600 = normalize_features(graphs_last600)

    if not graphs_first600 or not graphs_last600:
        raise ValueError("No animal graphs were created; check the input files.")

    # Node feature width is read from the graphs so the model input size
    # always matches the feature columns selected in create_graph_objects.
    num_features_first600 = graphs_first600[0].x.shape[1]
    num_features_last600 = graphs_last600[0].x.shape[1]

    # === Toggle Hyperparameter Search ===
    run_optuna = False

    if run_optuna:
        print("🔎 Running Optuna for First 600s")
        study_first = optuna.create_study(direction='minimize', sampler=sampler)
        study_first.optimize(objective_factory(graphs_first600))
        best_params_first600 = study_first.best_params
        print("✅ Best Params First 600s:", best_params_first600)

        print("🔎 Running Optuna for Last 600s")
        study_last = optuna.create_study(direction='minimize', sampler=sampler)
        study_last.optimize(objective_factory(graphs_last600))
        best_params_last600 = study_last.best_params
        print("✅ Best Params Last 600s:", best_params_last600)
    else:
        # Predefined best parameters (from previous tuning)
        best_params_first600 = {'hidden_channels': 64, 'dropout_rate': 0.2, 'learning_rate': 0.001, 'heads': 2}
        best_params_last600 = {'hidden_channels': 128, 'dropout_rate': 0.2, 'learning_rate': 0.0001, 'heads': 1}

    # === Run LOOCV for First 600s ===
    print("\n🚀 LOOCV Evaluation — First 600s")
    train_loss_1, val_loss_1, r2_1, curves_1 = loocv_training(
        graphs_first600, best_params_first600,
        num_features=num_features_first600, segment_name="First 600s"
    )
    plot_average_loss_curves(curves_1, "First 600s")

    # === Run LOOCV for Last 600s ===
    print("\n🚀 LOOCV Evaluation — Last 600s")
    train_loss_2, val_loss_2, r2_2, curves_2 = loocv_training(
        graphs_last600, best_params_last600,
        num_features=num_features_last600, segment_name="Last 600s"
    )
    plot_average_loss_curves(curves_2, "Last 600s")

    # === Plot Final Predictions ===
    loader_first = DataLoader(graphs_first600, batch_size=1, shuffle=False)
    loader_last = DataLoader(graphs_last600, batch_size=1, shuffle=False)

    # The parameter dicts also hold 'learning_rate', which create_model does
    # not accept, so the architecture parameters are passed explicitly.
    model_first = create_model(
        num_features=num_features_first600,
        hidden_channels=best_params_first600['hidden_channels'],
        dropout_rate=best_params_first600['dropout_rate'],
        heads=best_params_first600['heads']
    )
    model_last = create_model(
        num_features=num_features_last600,
        hidden_channels=best_params_last600['hidden_channels'],
        dropout_rate=best_params_last600['dropout_rate'],
        heads=best_params_last600['heads']
    )

    # Retrain on full data for prediction visualization
    # (the same loader serves as training and validation set here, so these
    # plots show fit to the training data, not held-out performance)
    optimizer_f = optim.Adam(model_first.parameters(), lr=best_params_first600['learning_rate'], weight_decay=1e-5)
    scheduler_f = CosineAnnealingLR(optimizer_f, T_max=50)
    criterion_f = nn.MSELoss()
    _ = train_model(model_first, loader_first, loader_first, optimizer_f, criterion_f, scheduler_f, num_epochs=110)

    optimizer_l = optim.Adam(model_last.parameters(), lr=best_params_last600['learning_rate'], weight_decay=1e-5)
    scheduler_l = CosineAnnealingLR(optimizer_l, T_max=50)
    criterion_l = nn.MSELoss()
    _ = train_model(model_last, loader_last, loader_last, optimizer_l, criterion_l, scheduler_l, num_epochs=110)

    # Plot
    plot_prediction_vs_actual(loader_first, model_first, title="First 600s: Predicted vs Actual")
    plot_prediction_vs_actual(loader_last, model_last, title="Last 600s: Predicted vs Actual")
