In [None]:
"""
PyTorch implementations of mixed effect models applied to
Parkinson's Telemonitoring dataset. (Functional Forward Fix 2 within N-Model Structure)
"""

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
# Removed GroupShuffleSplit as it's replaced by custom R-like split
from sklearn.preprocessing import StandardScaler
from collections import defaultdict
from collections.abc import Generator
from math import log, pi
from typing import Any, Dict, Final, Optional, List, Tuple, Type
import io
import requests
import warnings
import time # Import time for epoch timing

# --- Helper Functions/Classes (Replacements for python_tools) ---

# Default floating-point dtype shared across this file: it is the default
# `dtype` argument of MLP, cholesky(), LinearMixedEffects and
# NeuralMixedEffects, so changing it here changes all of them together.
TENSOR_DTYPE = torch.float64

# Simplified base class (replace neural.LossModule)
class LossModule(nn.Module):
    """Minimal base class for the models in this file (stand-in for ``neural.LossModule``).

    It provides a default element-wise MSE loss and a ``can_jit`` flag, and it
    tolerates the construction keywords that subclasses commonly forward.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the module and discard keywords this base class does not use.

        ``iterations``, ``input_size``, ``output_size``, ``dtype`` and
        ``hidden_sizes`` are silently dropped (subclasses handle them); any
        other keyword triggers a warning. Positional arguments are ignored.
        """
        super().__init__()
        for ignored in ('iterations', 'input_size', 'output_size', 'dtype', 'hidden_sizes'):
            kwargs.pop(ignored, None)
        if kwargs:
            warnings.warn(f"Unused arguments in LossModule init: {kwargs.keys()}")

    def loss(self, scores, ground_truth, meta, take_mean=True, loss=None):
        """Calculates the base loss (element-wise MSE unless ``loss`` is given).

        Args:
            scores: Model predictions.
            ground_truth: Targets, broadcast-compatible with ``scores``.
            meta: Unused here; kept so subclasses share one signature.
            take_mean: If True return the mean over all elements, otherwise
                return one value per sample.
            loss: Optional precomputed element-wise loss; when given, the MSE
                computation is skipped and this tensor is reduced instead.

        Returns:
            A scalar tensor when ``take_mean`` is True. Otherwise a per-sample
            tensor: summed over dim 1 when there are several outputs, else
            flattened to 1D. The result has the dtype of ``scores``.
        """
        if loss is None:
            # Compute in at least float32, but never below the precision of
            # `scores`: float64 predictions keep full precision instead of
            # being rounded through float32, while half/integer inputs are
            # still promoted to float32.
            compute_dtype = torch.promote_types(scores.dtype, torch.float32)
            loss_fn = nn.MSELoss(reduction='none')
            loss = loss_fn(scores.to(compute_dtype), ground_truth.to(compute_dtype))

        # Report the loss in the dtype of the predictions.
        loss = loss.to(scores.dtype)

        if take_mean:
            return loss.mean()
        if loss.ndim > 1 and loss.shape[1] > 1:
            # Several outputs per sample: sum them into one value per sample.
            return loss.sum(dim=1)
        # reshape (not view) so a non-contiguous precomputed loss also works.
        return loss.reshape(-1)

    def can_jit(self) -> bool:
        """Return whether the model can be JIT-compiled; the base class says no."""
        return False

# Simplified MLP (replace neural.MLP)
# *** Reverted to original structure with self.network ***
class MLP(LossModule):
    """Fully connected network: ``Linear -> activation`` per hidden size, then a final ``Linear``.

    Every layer lives in ``self.network`` (an ``nn.Sequential``), so the last
    entry ``self.network[-1]`` is always the output ``nn.Linear``.
    """

    def __init__(self, input_size, output_size, hidden_sizes=(64, 32), activation=nn.ReLU, dtype=TENSOR_DTYPE, **kwargs):
        """Build the layer stack.

        Args:
            input_size: Number of input features.
            output_size: Number of outputs of the final linear layer.
            hidden_sizes: Iterable of hidden layer widths (list or tuple). The
                default is an immutable tuple so it cannot be shared and
                mutated across instances.
            activation: Zero-argument callable returning the activation module
                placed after each hidden linear layer.
            dtype: Parameter dtype for every linear layer; inputs are cast to
                it in ``forward``.
            **kwargs: Forwarded to ``LossModule.__init__``.
        """
        super().__init__(dtype=dtype, **kwargs)
        self.dtype = dtype
        layers = []
        last_size = input_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(last_size, hidden_size, dtype=self.dtype))
            layers.append(activation())
            last_size = hidden_size
        # The output layer is part of the same sequence as the hidden layers.
        layers.append(nn.Linear(last_size, output_size, dtype=self.dtype))
        self.network = nn.Sequential(*layers)

    def forward(self, x, meta=None, y=None, dataset="", **fwd_kwargs):
        """Standard forward pass; only ``x`` is used, the other arguments are accepted and ignored."""
        # Cast the input to the parameter dtype so the linear layers accept it.
        x = x.to(self.dtype)
        return self.network(x)

# Simplified Ensemble (replace neural.Ensemble) - Needed by original NME forward
class Ensemble(LossModule):
    """Collection of identically configured models whose outputs are mixed per sample.

    ``forward`` returns the weighted sum of the member outputs, using one
    weight per (sample, member) pair.
    """

    def __init__(self, model: Type[LossModule], size: int, **kwargs):
        """Create ``size`` instances of ``model``, each built with the same ``kwargs``.

        Raises:
            ValueError: If ``input_size`` or ``output_size`` is missing from ``kwargs``.
        """
        # The base class receives the same keywords and drops those it does not use.
        super().__init__(**kwargs)
        has_sizes = 'input_size' in kwargs and 'output_size' in kwargs
        if not has_sizes:
            raise ValueError("Ensemble needs input_size and output_size in its kwargs")

        members = [model(**kwargs) for _ in range(size)]
        self.models = nn.ModuleList(members)
        # Take the dtype from the first member; fall back to the global default when empty.
        self.dtype = self.models[0].dtype if size > 0 else TENSOR_DTYPE

    def forward(self, x, meta: dict, y: Optional[torch.Tensor] = None, dataset: str = "", weights: Optional[torch.Tensor] = None):
        """Return ``(weighted sum of member outputs, meta)``.

        ``weights`` is indexed as ``weights[:, i]`` for member ``i`` and is
        required; ``y`` and ``dataset`` are accepted but unused.

        Raises:
            ValueError: If ``weights`` is None.
        """
        # Output width and dtype come from the last layer of the first member,
        # which therefore must expose an MLP-style ``network`` attribute.
        head = self.models[0].network[-1]
        dtype = head.weight.dtype
        outputs = torch.zeros(x.shape[0], head.out_features, device=x.device, dtype=dtype)

        if weights is None:
            raise ValueError("Ensemble forward called without weights.")

        for member_idx, member in enumerate(self.models):
            member_output = member(x)
            # Column `member_idx` of the weights scales this member's prediction.
            member_weight = weights[:, member_idx].unsqueeze(1).to(dtype)
            outputs += member_weight * member_output

        return outputs, meta


# Cholesky decomposition (replace neural.cholesky)
def cholesky(A: torch.Tensor, dtype: torch.dtype = TENSOR_DTYPE) -> torch.Tensor:
    """Return the Cholesky factor of ``A`` as computed by ``torch.linalg.cholesky``.

    ``A`` is first converted to ``dtype`` (float64 by default) so the
    factorisation runs in that precision.
    """
    matrix = A.to(dtype)
    factor = torch.linalg.cholesky(matrix)
    return factor

# Unique with index (replace neural.unique_with_index)
def unique_with_index(ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the unique ids and the position of each id's first occurrence.

    Args:
        ids: Tensor of ids; it is flattened, so position ``k`` refers to
            ``ids.reshape(-1)[k]``.

    Returns:
        Tuple ``(unique_ids, first_indices)`` where ``unique_ids`` is the
        output of ``torch.unique`` (sorted ascending) and ``first_indices[k]``
        is the smallest position at which ``unique_ids[k]`` appears. Both are
        empty for empty input; ``first_indices`` is always ``torch.long`` on
        the device of ``ids``.
    """
    flat_ids = ids.reshape(-1)
    unique_ids, inverse = torch.unique(flat_ids, return_inverse=True)
    if unique_ids.numel() == 0:
        return unique_ids, torch.empty(0, dtype=torch.long, device=ids.device)

    # A stable sort of the inverse mapping groups the positions by unique id
    # while keeping the original order inside each group, so the first entry
    # of every group is that id's earliest occurrence. This avoids a Python
    # loop with one `.item()` call per element.
    sorted_inverse, order = torch.sort(inverse, stable=True)
    group_start = torch.ones_like(sorted_inverse, dtype=torch.bool)
    group_start[1:] = sorted_inverse[1:] != sorted_inverse[:-1]
    return unique_ids, order[group_start]


# Covariance block calculation (replace neural.cov_block)
def cov_block(
    data: torch.Tensor,
    independent_indices: torch.Tensor,
    full_indices_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """
    Estimate variances of independent columns and covariances of column blocks.
    Args:
        data: Tensor of shape (n_samples, n_features); its dtype and device are
            used for the results.
        independent_indices: Indices of columns treated as independent.
        full_indices_list: List of index tensors, one per covariance block.
            Entries that are None or empty produce no output.
    Returns:
        Tuple containing:
        - Unbiased variances of the independent columns (1D Tensor), floored at 1e-8
        - List of block covariances with their diagonal floored at 1e-8. With
          torch.cov a single-column block yields a 0D tensor.
        With at most one sample the variances default to ones and the
        covariances to identity matrices.
    """
    device, dtype = data.device, data.dtype
    sample_count = data.shape[0]
    independent_count = len(independent_indices)
    # Blocks that are missing or empty never contribute an output entry.
    blocks = [idx for idx in full_indices_list if not (idx is None or len(idx) == 0)]

    # No covariance can be estimated from fewer than two samples.
    if sample_count <= 1:
        unit_variances = torch.ones(independent_count, device=device, dtype=dtype)
        identities = [torch.eye(len(idx), device=device, dtype=dtype) for idx in blocks]
        return unit_variances, identities

    centered_data = data - torch.mean(data, dim=0, keepdim=True)

    if independent_count > 0:
        variances = torch.var(data[:, independent_indices], dim=0, unbiased=True)
    else:
        variances = torch.zeros(independent_count, device=device, dtype=dtype)
    # Floor the variances to avoid numerical issues downstream.
    variances = variances.clamp(min=1e-8)

    has_builtin_cov = hasattr(torch, 'cov')
    cov_matrices = []
    for idx in blocks:
        block_data = centered_data[:, idx]
        if has_builtin_cov:
            # torch.cov expects variables in rows, hence the transpose.
            cov_matrix = torch.cov(block_data.T)
        else:
            # Fallback for PyTorch builds without torch.cov: (X^T X) / (n - 1).
            cov_matrix = (block_data.T @ block_data) / (sample_count - 1)
            if cov_matrix.ndim == 0:
                cov_matrix = cov_matrix.view(1, 1)

        # Floor the diagonal (or the lone scalar) at 1e-8.
        if cov_matrix.numel() == 0:
            pass
        elif cov_matrix.ndim == 2:
            floored_diag = torch.diag(cov_matrix)
            floored_diag.clamp_(min=1e-8)
            cov_matrix.diagonal().copy_(floored_diag)
        elif cov_matrix.ndim == 0:
            cov_matrix = cov_matrix.clamp(min=1e-8)
        else:
            print(f"Warning: Unexpected tensor dimension {cov_matrix.ndim} in cov_matrices.")
        cov_matrices.append(cov_matrix)

    return variances, cov_matrices

# Get object attribute helper (replace generic.get_object)
def get_object(obj: Any, attributes: List[str]) -> Any:
    """Walk a path of names/indices through nested modules and return what it reaches.

    Dict-like containers (ModuleDict, ParameterDict) are looked up by key.
    List-like containers (ModuleList, Sequential) are indexed when the path
    element parses as an in-range integer and otherwise accessed as an
    attribute. Everything else is resolved with ``getattr``.

    Raises:
        AttributeError: If a path element cannot be resolved.
    """
    keyed_types = (nn.ModuleDict, nn.ParameterDict)
    indexed_types = (nn.ModuleList, nn.Sequential)

    current = obj
    for attr in attributes:
        if isinstance(current, keyed_types):
            current = current[attr]
            continue
        if not isinstance(current, indexed_types):
            current = getattr(current, attr)
            continue
        try:
            current = current[int(attr)]
        except (ValueError, IndexError):
            # Not a usable index: fall back to attribute access on the container.
            try:
                current = getattr(current, attr)
            except AttributeError:
                raise AttributeError(f"Cannot access attribute or index '{attr}' in {type(current)}")
    return current

# Pop with prefix helper (replace neural.pop_with_prefix)
def pop_with_prefix(d: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Move every entry whose key starts with ``prefix`` out of ``d``.

    The matching entries are deleted from ``d`` (in place) and returned in a
    new dict, keyed by the original key with the prefix stripped.
    """
    # Collect the keys first so the dict is not modified while iterating it.
    matching_keys = [key for key in d if key.startswith(prefix)]
    cut = len(prefix)
    return {key[cut:]: d.pop(key) for key in matching_keys}

# --- NME Code (Author's Version with Functional Fix) ---
# LinearMixedEffects class is not used by NeuralMixedEffects,
# but kept here for completeness if needed later.
class LinearMixedEffects(LossModule): # Note: Depends on neural.LossModule base
    """Linear mixed-effects layer that adds per-cluster random effects to fixed predictions. (Adapted)

    ``forward`` receives fixed-effect predictions ``y_fixed`` and adds
    ``x @ b`` to them, where ``x`` is the preprocessed ``meta["meta_embedding"]``
    (optionally with an appended column of ones) and ``b`` is the stored
    random-effect vector of the sample's cluster ``meta["meta_id"]``.
    The random effects and variance components are fitted by ``_em``, which
    performs a single EM pass per call (there is no iteration loop here).

    Registered buffers (E = embedding size including the optional bias column):
        d_sigma: (output_size, E, E). ``_em`` multiplies it by ``sigma_2`` to
            obtain the random-effect covariance and stores the new covariance
            divided by the new ``sigma_2``.
        sigma_2: (output_size,). Residual variance per output; the initial
            value -1 is the "not estimated yet" sentinel tested in ``_em``.
        random_effects: (output_size, n_clusters, E).
        random_effect_ids: (n_clusters,) long; initialised to -1.

    The ``reml`` flag is stored, but the REML helpers at the end of the class
    only raise NotImplementedError.
    """
    # Maximum number of observations per cluster kept by estimate_random_effects.
    truncate: Final[int]
    # If True only a random intercept (the column of ones) is used.
    only_bias: Final[bool]
    # Stored constructor flag; REML itself is not implemented in this class.
    reml: Final[bool]
    # Device name given to the constructor (buffers are created without it).
    device: Final[str]
    # Whether a column of ones is appended to the random-effect covariates.
    add_bias: Final[bool]
    # Cached log(2 * pi).
    log_2_pi: Final[float]
    random_effect_ids: torch.Tensor
    random_effects: torch.Tensor
    sigma_2: torch.Tensor
    d_sigma: torch.Tensor
    dtype: torch.dtype

    def __init__(
        self, *, embedding_size: int = -1, output_size: int = 1, truncate: int = 4096,
        add_bias: bool = True, only_bias: bool = False, number_of_cluster: int = -1,
        dtype: torch.dtype = TENSOR_DTYPE, reml: bool = False, device: str = "cpu",
        iterations: int = 0, **kwargs: Any,
    ) -> None:
        """Allocate the buffers for the variance components and random effects.

        Args:
            embedding_size: Number of random-effect covariates, not counting
                the bias column. Forced to 0 when ``only_bias`` is True.
            output_size: Number of outputs; each output gets its own variance
                components and random effects.
            truncate: Upper bound on observations per cluster used when
                estimating (see ``estimate_random_effects``).
            add_bias: Append a column of ones to the covariates. Forced to
                True when ``only_bias`` is True.
            only_bias: Use only the column of ones.
            number_of_cluster: Number of cluster slots to pre-allocate; must
                be positive.
            dtype: Floating-point dtype of the buffers and computations.
            reml: Stored on the instance only.
            device: Stored on the instance only; the buffers are created
                without an explicit device.
            iterations: Forwarded to ``LossModule.__init__``.
            **kwargs: Forwarded to ``LossModule.__init__``.

        Raises:
            ValueError: If the embedding size including the bias column is not
                positive (this is the case for the default arguments).
            AssertionError: If ``number_of_cluster`` is not positive.
        """
        super().__init__(iterations=iterations, dtype=dtype, **kwargs)
        if only_bias: add_bias = True; embedding_size = 0
        self.truncate = truncate; self.only_bias = only_bias; self.reml = reml;
        self.device = device; self.dtype = dtype
        # Covariates are layer-normalised (no learnable affine) when there are any.
        if embedding_size > 0: self.norm = torch.nn.LayerNorm(embedding_size, elementwise_affine=False, dtype=self.dtype)
        else: self.norm = nn.Identity()
        self.add_bias = add_bias
        # NOTE: current_embedding_size is not read anywhere in this class.
        current_embedding_size = embedding_size
        # From here on embedding_size includes the bias column.
        if self.add_bias: embedding_size = embedding_size + 1
        if embedding_size <= 0: raise ValueError("Embedding size (including bias) must be positive.")

        # One identity matrix per output as the starting d_sigma.
        self.register_buffer("d_sigma", torch.eye(embedding_size, dtype=dtype).repeat(output_size, 1, 1))
        # -1 marks "not estimated yet" (checked in _em).
        self.register_buffer("sigma_2", -torch.ones(output_size, dtype=dtype))
        assert number_of_cluster > 0
        self.register_buffer("random_effects", torch.zeros(output_size, number_of_cluster, embedding_size, dtype=dtype))
        self.register_buffer("random_effect_ids", -torch.ones(number_of_cluster, dtype=torch.long))
        self.log_2_pi = log(2 * pi)

    @staticmethod
    def _get_clusters(meta_id: torch.Tensor) -> dict[int, torch.Tensor]:
        """Group row positions by cluster id and bucket the clusters by size.

        Returns:
            Dict mapping a cluster size to a long tensor of shape
            (number of clusters with that size, size); each row holds the
            positions in ``meta_id`` that belong to one cluster.
        """
        unique_ids, inverse_indices = torch.unique(meta_id, return_inverse=True)
        # id -> list of row positions with that id.
        clusters_by_id = defaultdict(list); [clusters_by_id[unique_ids[id_index].item()].append(i) for i, id_index in enumerate(inverse_indices)]
        # cluster size -> list of index tensors (one per cluster of that size).
        clusters_by_size = defaultdict(list); [clusters_by_size[len(indices)].append(torch.tensor(indices, dtype=torch.long, device=meta_id.device)) for id_val, indices in clusters_by_id.items()]
        # Equal-length index tensors are stacked so clusters of one size can be batched.
        results: dict[int, torch.Tensor] = {}; [results.__setitem__(size, torch.stack(list_of_index_tensors, dim=0)) for size, list_of_index_tensors in clusters_by_size.items()]
        return results

    def estimate_random_effects(self, y: torch.Tensor, ids: torch.Tensor, random: torch.Tensor) -> None:
        """Preprocess the inputs, cap the cluster sizes and run one ``_em`` pass.

        Args:
            y: Passed to ``_em`` as the residual, indexed as (n_obs, n_outputs)
                there; presumably targets minus fixed-effect predictions
                (confirm against the caller).
            ids: Cluster id per observation; cast to long.
            random: Random-effect covariates per observation; run through
                ``_preprocess`` before use.

        When ``truncate`` is positive, every cluster with more than
        ``truncate`` observations is randomly subsampled down to ``truncate``
        (via ``torch.randperm``, so the selection depends on the RNG state).
        """
        ids = ids.to(torch.long); buffer_device = self.sigma_2.device
        random_processed = self._preprocess(random.to(device=buffer_device, dtype=self.dtype, non_blocking=True))
        full_clusters: Optional[Dict[int, torch.Tensor]] = None
        if self.truncate is not None and self.truncate > 0:
            clusters = self._get_clusters(ids.to(buffer_device)); keep = torch.ones(y.shape[0], dtype=torch.bool, device=buffer_device); indices_to_drop = []
            for observations, groups in clusters.items():
                # Clusters within the limit are kept whole.
                if observations <= self.truncate: continue
                num_groups_of_this_size = groups.shape[0]
                for i in range(num_groups_of_this_size):
                    # Randomly pick the surplus observations of this cluster for removal.
                    group_indices = groups[i]; perm = torch.randperm(observations, device=buffer_device); drop_in_group = perm[: observations - self.truncate]
                    original_indices_to_drop = group_indices[drop_in_group]; indices_to_drop.extend(original_indices_to_drop.tolist())
            if indices_to_drop:
                keep[torch.tensor(indices_to_drop, device=buffer_device, dtype=torch.long)] = False
                # Rows were removed, so the cluster index tensors are stale; _em rebuilds them.
                y = y.to(buffer_device)[keep]; ids = ids.to(buffer_device)[keep]; random_processed = random_processed.to(buffer_device)[keep]; full_clusters = None
            else: full_clusters = clusters; y = y.to(buffer_device); ids = ids.to(buffer_device); random_processed = random_processed.to(buffer_device)
        else: y = y.to(buffer_device); ids = ids.to(buffer_device); random_processed = random_processed.to(buffer_device); full_clusters = self._get_clusters(ids)
        self._em(y.to(dtype=self.dtype, non_blocking=True), ids.to(dtype=torch.long, non_blocking=True), random_processed.to(dtype=self.dtype, non_blocking=True), full_clusters=full_clusters)

    def _preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Cast ``x`` to the module dtype, normalise it and append the bias column.

        Normalisation is skipped when ``only_bias`` is set or ``x`` has no
        columns. With ``only_bias`` only the appended column of ones is returned.
        """
        x = x.to(dtype=self.dtype)
        if not self.only_bias and x.shape[1] > 0 : x = self.norm(x)
        if self.add_bias:
            bias_term = torch.ones(x.shape[0], 1, device=x.device, dtype=self.dtype)
            if x.shape[1] > 0: x = torch.cat([x, bias_term], dim=1)
            else: x = bias_term
        # The bias column is the last one, so this keeps only the intercept.
        if self.only_bias: return x[:, -1:]
        return x

    def forward(self, y_fixed: torch.Tensor, meta: dict[str, torch.Tensor], y: Optional[torch.Tensor] = None, dataset: str = "") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Add the stored random effects to the fixed-effect predictions.

        Args:
            y_fixed: Fixed-effect predictions, indexed as (batch, n_outputs).
            meta: Must contain ``"meta_id"``; ``"meta_embedding"`` is used as
                the random-effect covariates when present with at least one column.
            y: Unused.
            dataset: Unused.

        Returns:
            ``(y_hats, meta)`` where ``meta`` is a new dict (tensors moved to
            the buffer device) that additionally holds ``"meta_y_hat_fixed"``
            and ``"meta_y_hat"``. Samples whose id is not in
            ``random_effect_ids`` get a zero random effect, i.e. ``y_fixed``.
        """
        target_device = self.sigma_2.device; y_fixed = y_fixed.to(device=target_device, dtype=self.dtype)
        meta = {k: v.to(device=target_device) if isinstance(v, torch.Tensor) else v for k, v in meta.items()}; meta['meta_id'] = meta['meta_id'].to(torch.long)

        # Build the random-effect design matrix x.
        if 'meta_embedding' in meta and meta['meta_embedding'].shape[1] > 0: x = self._preprocess(meta["meta_embedding"].to(target_device, dtype=self.dtype))
        elif self.add_bias: x = torch.ones(y_fixed.shape[0], 1, device=target_device, dtype=self.dtype)
        else: x = torch.empty(y_fixed.shape[0], 0, device=target_device, dtype=self.dtype)

        batch_size = x.shape[0]; num_random_features = x.shape[1]; num_outputs = y_fixed.shape[1]
        # Per-sample random-effect matrix (features x outputs); zero for unknown ids.
        batch_random_effects = torch.zeros(batch_size, num_random_features, num_outputs, device=target_device, dtype=self.dtype)
        self.random_effect_ids = self.random_effect_ids.to(target_device); self.random_effects = self.random_effects.to(target_device)

        # `matches` only serves as an "is any id in this batch known?" check;
        # the actual lookup below is done sample by sample.
        matches = torch.nonzero(self.random_effect_ids.unsqueeze(0) == meta["meta_id"].unsqueeze(1), as_tuple=False)
        if matches.numel() > 0:
             for i in range(batch_size):
                 current_id = meta['meta_id'][i].item(); match_idx = torch.where(self.random_effect_ids == current_id)[0]
                 # Stored effects are (outputs, features) per cluster; transpose to (features, outputs).
                 if match_idx.numel() > 0: known_cluster_idx = match_idx[0]; batch_random_effects[i] = self.random_effects[:, known_cluster_idx, :].permute(1, 0)

        # (batch, 1, features) @ (batch, features, outputs) -> (batch, outputs).
        if num_random_features > 0: random_contribution = x.unsqueeze(1).bmm(batch_random_effects).squeeze(1); y_hats = y_fixed + random_contribution
        else: y_hats = y_fixed

        meta["meta_y_hat_fixed"] = y_fixed; meta["meta_y_hat"] = y_hats
        return y_hats, meta

    def add_new_random_effects(self, x: torch.Tensor, y_y_fixed: torch.Tensor, ids: torch.Tensor) -> None:
        """Estimate random effects for unseen cluster ids and append them to the stored ones.

        Observations whose id is already stored are dropped. The remaining
        ones are fitted with ``estimate_random_effects`` on temporary buffers;
        afterwards ``d_sigma`` and ``sigma_2`` are restored to their previous
        values and only the new effects/ids are appended.

        Args:
            x: Random-effect covariates of the new observations.
            y_y_fixed: Passed on as ``y`` to ``estimate_random_effects``;
                presumably targets minus fixed predictions (confirm against caller).
            ids: Cluster id per observation.
        """
        target_device = self.sigma_2.device; ids = ids.to(target_device, dtype=torch.long); x = x.to(target_device, dtype=self.dtype); y_y_fixed = y_y_fixed.to(target_device, dtype=self.dtype)
        self.random_effect_ids = self.random_effect_ids.to(target_device); self.random_effects = self.random_effects.to(target_device)
        # Keep only observations of clusters that are not stored yet.
        is_existing = torch.isin(ids, self.random_effect_ids)
        if is_existing.any():
             keep = ~is_existing;
             if not keep.any(): print("No new IDs to add."); return
             x = x[keep]; y_y_fixed = y_y_fixed[keep]; ids = ids[keep]
        if ids.numel() == 0: print("No new IDs to add after filtering."); return

        print(f"Adding effects for {len(torch.unique(ids))} new subjects.")
        # Snapshot everything _em may overwrite.
        saved_state = {"d_sigma": self.d_sigma.clone(), "sigma_2": self.sigma_2.clone(), "random_effects": self.random_effects.clone(), "random_effect_ids": self.random_effect_ids.clone()}
        num_new_clusters = len(torch.unique(ids)); embedding_size = self.random_effects.shape[2]; output_size = self.random_effects.shape[0]

        # Temporary empty buffers sized for the new clusters only.
        # NOTE(review): because random_effects is all zeros here, _em takes its
        # initialisation branch and derives its starting variances from the
        # residuals instead of using the fitted sigma_2/d_sigma - confirm this is intended.
        self.random_effects = torch.zeros(output_size, num_new_clusters, embedding_size, dtype=self.dtype, device=target_device)
        self.random_effect_ids = -torch.ones(num_new_clusters, dtype=torch.long, device=target_device)
        self.estimate_random_effects(y_y_fixed, ids, x)

        estimated_new_effects = self.random_effects.clone(); estimated_new_ids = self.random_effect_ids.clone()
        # The variance components must not change when adding clusters.
        self.d_sigma = saved_state["d_sigma"]; self.sigma_2 = saved_state["sigma_2"]
        # Ids still at -1 were never filled in by the EM step.
        valid_new_mask = estimated_new_ids != -1
        if not valid_new_mask.any(): print("Warning: EM step did not yield valid new random effects."); self.random_effects = saved_state["random_effects"]; self.random_effect_ids = saved_state["random_effect_ids"]; return

        # Append along the cluster dimension.
        self.random_effects = torch.cat([saved_state["random_effects"], estimated_new_effects[:, valid_new_mask, :]], dim=1)
        self.random_effect_ids = torch.cat([saved_state["random_effect_ids"], estimated_new_ids[valid_new_mask]], dim=0)
        print(f"Total random effects stored: {self.random_effect_ids.shape[0]}")

    def _em(self, residual: torch.Tensor, ids: torch.Tensor, random: torch.Tensor, full_clusters: Optional[Dict[int, torch.Tensor]] = None) -> None:
        """Run one EM pass and overwrite the variance components and random effects.

        Args:
            residual: (n_obs, n_outputs) values to be explained by the random effects.
            ids: Cluster id per observation.
            random: (n_obs, n_features) preprocessed random-effect covariates.
            full_clusters: Optional output of ``_get_clusters(ids)``; computed
                here when None.

        Side effects:
            ``sigma_2`` and ``d_sigma`` are updated in place. ``random_effects``
            and ``random_effect_ids`` are replaced by tensors that cover only
            the unique ids seen in this call.
        """
        device = self.sigma_2.device; lmes = residual.shape[1]; num_random_features = random.shape[1]
        # Per-output, per-observation copy of the estimated random effect of the observation's cluster.
        estimated_b_obs = torch.zeros(lmes, residual.shape[0], num_random_features, device=device, dtype=self.dtype)
        self.sigma_2 = self.sigma_2.to(device); self.d_sigma = self.d_sigma.to(device)

        # Initialisation branch: every sigma_2 still holds the -1 sentinel, or
        # all stored random effects are exactly zero.
        if (self.sigma_2 == -1).all() or self.random_effects.abs().sum() == 0:
            # Start from the mean squared residual per output and d_sigma = I / sigma_2.
            current_sigma_2 = residual.detach().pow(2).mean(dim=0).clamp(min=1e-5)
            current_d_sigma = torch.eye(num_random_features, device=device, dtype=self.dtype).repeat(lmes, 1, 1) / current_sigma_2.view(-1, 1, 1).clamp(min=1e-5)
            print(f"Initialized sigma_2: {current_sigma_2.cpu().numpy()}, d_sigma diagonal: {current_d_sigma.diagonal(dim1=-2, dim2=-1).mean(dim=1).cpu().numpy()}")
            # The buffers only take the starting values if sigma_2[0] is still the sentinel.
            if self.sigma_2[0] == -1: self.sigma_2.copy_(current_sigma_2); self.d_sigma.copy_(current_d_sigma)
        else: current_sigma_2 = self.sigma_2.clamp(min=1e-5); current_d_sigma = self.d_sigma

        # Random-effect covariance D per output (d_sigma is stored as D / sigma_2).
        d_current = current_d_sigma * current_sigma_2.view(-1, 1, 1)
        if full_clusters is None: full_clusters = self._get_clusters(ids)
        # Accumulators for E[b b^T] and the expected squared residual.
        sum_e_b_bt = torch.zeros_like(d_current); sum_e_residual_sq = torch.zeros_like(current_sigma_2); cluster_count: int = 0

        # E-step, batched over all clusters that share the same size.
        for cluster_size, clusters in full_clusters.items():
            num_groups = clusters.shape[0]; cluster_count += num_groups; flat_indices = clusters.view(-1).to(device)
            current_residual = residual[flat_indices].view(num_groups, cluster_size, lmes)
            current_random = random[flat_indices].view(num_groups, cluster_size, num_random_features)
            for j in range(lmes):
                Dj = d_current[j]; sigma_2j = current_sigma_2[j]
                # V = Z D Z^T + sigma^2 I for every cluster in the batch.
                ZDj = current_random @ Dj; ZDjZT = ZDj @ current_random.transpose(-1, -2)
                I = torch.eye(cluster_size, device=device, dtype=self.dtype).expand(num_groups, -1, -1); V = ZDjZT + sigma_2j * I
                try:
                    V_chol = torch.linalg.cholesky(V); residual_j = current_residual[:, :, j].unsqueeze(-1)
                    # b_hat = D Z^T V^{-1} r
                    temp = torch.cholesky_solve(residual_j, V_chol); ZT_temp = torch.matmul(current_random.transpose(-1,-2), temp); b_hat_j = Dj @ ZT_temp
                    # Copy each cluster's estimate to all of its observations.
                    b_hat_j_repeated = b_hat_j.repeat_interleave(cluster_size, dim=0).view(-1, num_random_features); estimated_b_obs[j, flat_indices, :] = b_hat_j_repeated
                    # Var(b) = D - (Z D)^T V^{-1} (Z D);  E[b b^T] = Var(b) + b_hat b_hat^T
                    temp_var = torch.cholesky_solve(ZDj, V_chol); Var_b_j = Dj - torch.matmul(ZDj.transpose(-1,-2), temp_var); E_b_bt_j = Var_b_j + torch.matmul(b_hat_j, b_hat_j.transpose(-1,-2))
                    sum_e_b_bt[j] += E_b_bt_j.sum(dim=0)
                    # Expected squared residual: r^T r - 2 r^T Z b_hat + sum(Z^T Z * E[b b^T]^T).
                    term1 = (residual_j**2).sum(dim=1); yT_Z = torch.matmul(residual_j.transpose(-1,-2), current_random); term2 = -2 * torch.matmul(yT_Z, b_hat_j).squeeze(-1)
                    ZT_Z = torch.matmul(current_random.transpose(-1,-2), current_random); term3 = torch.sum(ZT_Z * E_b_bt_j.transpose(-1,-2), dim=(-1,-2)).unsqueeze(-1)
                    sum_e_residual_sq[j] += (term1 + term2 + term3).sum(dim=0)
                # A failed factorisation skips this (size, output) batch; note that
                # cluster_count was already increased for it above.
                except torch.linalg.LinAlgError as e: print(f"Cholesky failed for LME {j}, cluster size {cluster_size}. Skipping. Error: {e}"); continue

        if cluster_count > 0:
            # M-step: average E[b b^T] over clusters (symmetrised) and the squared residual over observations.
            new_D = sum_e_b_bt / cluster_count; new_D = (new_D + new_D.transpose(-1, -2)) / 2
            total_observations = residual.shape[0]; new_sigma_2 = sum_e_residual_sq / total_observations; new_sigma_2.clamp_(min=1e-5)
            self.sigma_2.copy_(new_sigma_2); self.d_sigma.copy_(new_D / new_sigma_2.view(-1, 1, 1).clamp(min=1e-5))

            # Collapse the per-observation estimates to one vector per unique id
            # by averaging (`index` is not used below).
            unique_ids_in_batch, index = unique_with_index(ids); num_unique_ids = len(unique_ids_in_batch)
            _, inverse_indices = torch.unique(ids, return_inverse=True)
            avg_effects = torch.zeros(lmes, num_unique_ids, num_random_features, device=device, dtype=self.dtype); id_counts = torch.zeros(num_unique_ids, device=device, dtype=torch.long)
            id_counts.scatter_add_(0, inverse_indices, torch.ones_like(inverse_indices, dtype=torch.long))
            for j in range(lmes): avg_effects[j].scatter_add_(0, inverse_indices.unsqueeze(-1).expand_as(estimated_b_obs[j]), estimated_b_obs[j])
            valid_counts_mask = id_counts > 0
            if valid_counts_mask.any(): avg_effects[:, valid_counts_mask, :] /= id_counts[valid_counts_mask].view(1, -1, 1).to(self.dtype)
            # The stored effects are replaced, not merged: ids absent from this call are no longer stored.
            self.random_effects = avg_effects.clone(); self.random_effect_ids = unique_ids_in_batch.clone()
        else: print("Warning: EM step skipped due to zero clusters processed.")

    # REML placeholders: each one raises NotImplementedError when called.
    def _lme_d_sigma(*args, **kwargs): raise NotImplementedError("REML d_sigma adjustment not implemented.")
    def _lme_random_effects(*args, **kwargs): raise NotImplementedError("REML helper _lme_random_effects not implemented.")
    def _lme_x_vinv_x(*args, **kwargs): raise NotImplementedError("REML helper _lme_x_vinv_x not implemented.")


# ============================================================================
# NME Class - Author's Version (Requires Ensemble helper) + Functional Forward
# ============================================================================
class NeuralMixedEffects(LossModule): # Note: Depends on neural.LossModule base
    """Fit non-linear mixed effect models."""
    simulated_annealing_alpha: Final[float]
    dtype: Final[torch.dtype]
    l2_lambda: Final[float]
    p_eta: torch.Tensor
    index_diagonal: Final[torch.Tensor]
    index_full: Final[torch.Tensor]
    models: nn.ModuleList # Uses N model instances
    fixed_model: Optional[LossModule]
    cluster_names: torch.Tensor
    cluster_count: torch.Tensor # Original code uses this in loss
    mixed_parameters: Dict[str, List[nn.Parameter]] # Only used for init now
    random: nn.ParameterDict
    fixed: nn.ParameterDict
    sigma_diagonal: torch.Tensor
    sigma_full: torch.Tensor
    sigma_2: torch.Tensor
    sa_sigma: torch.Tensor
    during_training: bool
    losses: list[torch.Tensor]
    cluster_id_to_index: Dict[int, int]
    random_effect_names: Tuple[str, ...]

    # @beartype removed
    def __init__(
        self,
        *,
        clusters: torch.Tensor,
        model_fun: type[LossModule] = MLP, # Note: Depends on neural.MLP
        fixed_model_fun: type[LossModule] | None = None, # Note: Depends on neural.LossModule
        random_effects: tuple[str, ...] = (),
        random_buffers: tuple[str, ...] = (),
        simulated_annealing_alpha: float = 0.97,
        cluster_count: torch.Tensor, # Required by original loss fn
        dtype: torch.dtype = TENSOR_DTYPE, # Use global dtype
        independent: tuple[str, ...] = (),
        l2_lambda: float = 1.0,
        **kwargs: Any,
    ) -> None:
        """Instantiate a non-linear mixed effects model. (Author's version structure)

        One ``model_fun`` instance is created per cluster. Every parameter and
        buffer that is not declared random is shared with the first instance;
        each declared random effect gets a fixed part (``self.fixed``) and a
        per-cluster deviation initialised to zero (``self.random``).

        Args:
            clusters: Cluster identifiers; flattened and cast to ``torch.long``.
            model_fun: Factory called once per cluster with the ``model_`` kwargs.
            fixed_model_fun: Optional factory for a purely fixed model, called
                with the ``fixed_model_`` kwargs.
            random_effects: Parameter names of a mixed model (as reported by its
                ``named_parameters``) to treat as random effects. ``"."`` is
                replaced by ``"-"`` internally.
            random_buffers: Buffer names that are NOT shared between the models.
            simulated_annealing_alpha: Stored for the covariance update.
            cluster_count: One entry per cluster; stored flattened as a buffer.
            dtype: dtype of the created parameters and buffers; also passed to
                the model factories as ``dtype`` unless already given.
            independent: Random-effect names indexed by ``index_diagonal``; all
                other random effects are indexed by ``index_full``.
            l2_lambda: Stored on the instance.
            **kwargs: ``output_size``/``input_size`` are read from here first.
                Keys handled by ``pop_with_prefix`` for ``"model_"`` and
                ``"fixed_model_"`` go to the factories (presumably with the
                prefix stripped - TODO confirm); the rest is forwarded to
                ``LossModule.__init__``.

        Raises:
            ValueError: If ``output_size`` or ``input_size`` cannot be determined.
        """
        # Split off the keyword arguments addressed to the two model factories.
        model_kwargs = pop_with_prefix(kwargs, "model_")
        fixed_model_kwargs = pop_with_prefix(kwargs, "fixed_model_")

        output_size = kwargs.get("output_size", model_kwargs.get("output_size"))
        input_size = kwargs.get("input_size", fixed_model_kwargs.get("input_size"))
        if output_size is None:
            raise ValueError("output_size is required")
        if input_size is None:
            if fixed_model_fun is None:
                raise ValueError("input_size is required if fixed_model_fun is None")
            raise ValueError("input_size is required for fixed_model_fun")

        model_kwargs.setdefault("output_size", output_size)
        fixed_model_kwargs.setdefault("input_size", input_size)
        if fixed_model_fun is None:
            model_kwargs.setdefault("input_size", input_size)
        else:
            # The fixed model feeds the mixed models, so its output size is
            # their input size.
            fixed_size = fixed_model_kwargs.get("output_size", model_kwargs.get("input_size"))
            if fixed_size is None:
                warnings.warn("Cannot infer fixed_model output size. Defaulting.")
                fixed_size = input_size
            fixed_model_kwargs.setdefault("output_size", fixed_size)
            model_kwargs.setdefault("input_size", fixed_size)

        # "." is not allowed in ParameterDict keys, so names use "-" internally.
        random_effects_std = tuple(name.replace(".", "-") for name in random_effects)
        independent_std = tuple(name.replace(".", "-") for name in independent)

        super().__init__(**kwargs) # Pass remaining kwargs up
        self.random_effect_names = random_effects_std
        self.l2_lambda = l2_lambda
        self.simulated_annealing_alpha = simulated_annealing_alpha
        self.dtype = dtype
        self.register_buffer("cluster_count", cluster_count.view(-1).to(self.dtype))
        self.during_training = False
        self.losses: list[torch.Tensor] = []

        # instantiate purely fixed model
        self.fixed_model = None
        if fixed_model_fun is not None:
            fixed_model_kwargs.setdefault('dtype', self.dtype)
            self.fixed_model = fixed_model_fun(**fixed_model_kwargs)
            print(f"Instantiated fixed model: {type(self.fixed_model)}")

        # instantiate N models (mixed models)
        assert clusters.numel() == cluster_count.numel()
        self.cluster_names = clusters.view(-1).to(torch.long)
        # A single tolist() replaces one .item() call per cluster.
        self.cluster_id_to_index = {
            cluster_id: position
            for position, cluster_id in enumerate(self.cluster_names.tolist())
        }
        num_clusters = len(self.cluster_names)
        print(f"Instantiating {num_clusters} models for NME...")
        model_kwargs.setdefault('dtype', self.dtype)
        self.models = nn.ModuleList([model_fun(**model_kwargs) for _ in range(num_clusters)])
        print(f"Instantiated NME models: {type(self.models[0])}")

        # Classify the first model's parameters/buffers: random ones are
        # tracked in mixed_parameters, everything else is shared.
        self.mixed_parameters = defaultdict(list)
        shared_parameters = {}
        shared_buffers = {}
        param_is_random = {}
        buffer_is_random = {}
        first_model = self.models[0]

        for name, param in first_model.named_parameters():
            std_name = name.replace(".", "-")
            is_random = std_name in self.random_effect_names
            param_is_random[name] = is_random
            if is_random:
                self.mixed_parameters[std_name].append(param)
            else:
                shared_parameters[name] = param

        for name, buf in first_model.named_buffers():
            is_random_buffer = name in random_buffers
            buffer_is_random[name] = is_random_buffer
            if not is_random_buffer:
                shared_buffers[name] = buf

        # Point every other model at the first model's shared tensors.
        # Failures are reported and skipped (best effort), as before.
        for i, model in enumerate(self.models[1:], 1):
            for name, _ in list(model.named_parameters()):
                if param_is_random[name]:
                    continue
                try:
                    *parent_path, attr_name = name.split('.')
                    parent_module = get_object(model, parent_path)
                    setattr(parent_module, attr_name, shared_parameters[name])
                except Exception as e:
                    print(f"Error linking shared param '{name}' in model {i}: {e}")
            for name, _ in list(model.named_buffers()):
                if buffer_is_random[name]:
                    continue
                try:
                    *parent_path, attr_name = name.split('.')
                    parent_module = get_object(model, parent_path)
                    setattr(parent_module, attr_name, shared_buffers[name])
                except Exception as e:
                    print(f"Error linking shared buffer '{name}' in model {i}: {e}")

        # separate mixed parameters into fixed and random
        self.random = nn.ParameterDict()
        self.fixed = nn.ParameterDict()
        if not self.mixed_parameters:
            print("Warning: No random effects specified.")

        for key, values in self.mixed_parameters.items():
            if not values:
                continue
            template = values[0]
            # Fixed part starts from the first model's value ...
            self.fixed[key] = nn.Parameter(template.data.clone().to(self.dtype))
            # ... and the per-cluster deviations start at zero.
            self.random[key] = nn.Parameter(
                torch.zeros((num_clusters, *template.shape), dtype=self.dtype)
            )
            print(f"Tracking random effect '{key}' with shape: {self.random[key].shape}, dtype: {self.random[key].dtype}")

        # Assign each random effect a contiguous range of scalar positions.
        number_random_effects = 0
        parameter_index = {}
        for name, parameter in self.fixed.items():
            size = parameter.numel()
            parameter_index[name] = list(
                range(number_random_effects, number_random_effects + size),
            )
            number_random_effects += size
        print(f"Total scalar random effects per cluster: {number_random_effects}")

        # Covariance structure: "independent" effects are indexed by
        # index_diagonal, all remaining ones by index_full. Both name lists are
        # sorted so the index order (and with it the layout of the
        # sigma_diagonal buffer) does not depend on set iteration order, which
        # varies between processes for strings.
        tracked_names = set(self.mixed_parameters.keys())
        full = tuple(sorted(tracked_names.difference(independent_std)))
        independent = tuple(sorted(tracked_names.intersection(independent_std)))

        independent_indices = torch.tensor(
            generic_flatten_nested_list([parameter_index[x] for x in independent]),
            dtype=torch.long,
        )
        full_indices = torch.tensor(
            generic_flatten_nested_list([parameter_index[x] for x in full]),
            dtype=torch.long,
        )

        print(f"Independent effect indices: {independent_indices.tolist()}")
        print(f"Full covariance block indices: {full_indices.tolist()}")

        self.register_buffer("sigma_diagonal", torch.ones(len(independent_indices), dtype=self.dtype))
        self.register_buffer("sigma_full", torch.eye(len(full_indices), dtype=self.dtype))
        self.index_diagonal = independent_indices
        self.index_full = full_indices

        # variables for 'SAEM' (for M)
        self.register_buffer("sigma_2", torch.tensor([0.1], dtype=self.dtype)) # Initial guess 0.1
        self.register_buffer("sa_sigma", torch.ones(number_random_effects, dtype=self.dtype))
        self.register_buffer("p_eta", torch.zeros(len(self.models), dtype=self.dtype))

    def named_parameters(
        self,
        prefix: str = "",
        recurse: bool = True,
    ) -> Generator[tuple[str, torch.nn.Parameter], None, None]:
        """Generator over parameters optimized by NME."""
        param_names_yielded = set()

        # Yield parameters from the fixed_model if it exists
        if self.fixed_model is not None:
            for name, param in self.fixed_model.named_parameters(prefix=prefix + "fixed_model", recurse=recurse):
                 full_name = prefix + "fixed_model." + name # Original name includes "fixed_model." prefix
                 if full_name not in param_names_yielded:
                     yield full_name, param
                     param_names_yielded.add(full_name)


        # Yield the 'fixed' part of the random effects (mean effect, beta)
        for name, param in self.fixed.named_parameters(prefix=prefix + "fixed", recurse=recurse):
             full_name = prefix + "fixed." + name
             if full_name not in param_names_yielded:
                 yield full_name, param
                 param_names_yielded.add(full_name)

        # Yield the 'random' part of the random effects (deviations, eta)
        for name, param in self.random.named_parameters(prefix=prefix + "random", recurse=recurse):
             full_name = prefix + "random." + name
             if full_name not in param_names_yielded:
                 yield full_name, param
                 param_names_yielded.add(full_name)

        # Yield the shared parameters from the *first* model instance
        if self.models:
            first_model = self.models[0]
            for name, param in first_model.named_parameters(recurse=recurse):
                std_name = name.replace(".", "-")
                # Yield only if it's NOT a random effect parameter
                if std_name not in self.random_effect_names:
                    full_name = prefix + "models.shared." + name # Adjust prefix for clarity
                    if full_name not in param_names_yielded:
                        yield full_name, param
                        param_names_yielded.add(full_name)

    def can_jit(self) -> bool:
        return False

    def _update_sigma_2_sigma(self) -> None:
        """Update random effect covariance (Sigma) and residual variance (sigma_2)."""
        target_device = self.sigma_2.device # Use a buffer's device

        # --- Update sigma_2 (Residual Variance) ---
        if self.losses:
             new_sigma_2 = torch.stack(self.losses).mean().clamp(min=1e-8)
             self.sigma_2.copy_(new_sigma_2.to(target_device, dtype=self.dtype))
        self.losses = [] # Clear accumulated losses

        # --- Update Sigma (Random Effect Covariance) ---
        if not self.random: # Skip if no random effects defined
             return

        with torch.no_grad():
            eta = self._get_eta() # Shape: (num_clusters, num_random_effects), dtype=self.dtype, device=target_device
            if eta.shape[0] <= 1: return

            idx_diag = self.index_diagonal.to(target_device)
            idx_full = self.index_full.to(target_device)

            variances, sigmas = cov_block(eta, idx_diag, [idx_full] if idx_full.numel() > 0 else [])

            self.sa_sigma = self.sa_sigma.to(target_device)
            self.sigma_diagonal = self.sigma_diagonal.to(target_device)
            self.sigma_full = self.sigma_full.to(target_device)


            if idx_diag.numel() > 0:
                current_sa_diag = self.sa_sigma[idx_diag]
                new_sa_diag = torch.max(current_sa_diag * self.simulated_annealing_alpha, variances)
                self.sa_sigma[idx_diag] = new_sa_diag
                self.sigma_diagonal.copy_(1.0 / new_sa_diag.clamp(min=1e-8))

            if idx_full.numel() > 0 and sigmas:
                if sigmas[0] is not None and sigmas[0].numel() > 0:
                    sigma_empirical = sigmas[0]
                    current_sa_full_diag = self.sa_sigma[idx_full]

                    if sigma_empirical.ndim == 0: # Scalar variance case
                         new_sa_diag_full = torch.max(current_sa_full_diag * self.simulated_annealing_alpha, sigma_empirical)
                         self.sa_sigma[idx_full] = new_sa_diag_full
                         if self.sigma_full.shape == (1,1):
                             inv_sqrt_var = (1.0 / new_sa_diag_full.clamp(min=1e-8)).sqrt()
                             self.sigma_full.fill_(inv_sqrt_var.item())
                         else: print(f"Warning: Shape mismatch for sigma_full ({self.sigma_full.shape}) in scalar case.")

                    elif sigma_empirical.ndim == 2: # Matrix covariance case
                        new_sa_diag_full = torch.max(current_sa_full_diag * self.simulated_annealing_alpha, sigma_empirical.diagonal())
                        self.sa_sigma[idx_full] = new_sa_diag_full

                        sigma_annealed = sigma_empirical.clone()
                        sigma_annealed.diagonal().copy_(new_sa_diag_full)

                        try:
                            identity = torch.eye(sigma_annealed.shape[0], device=target_device, dtype=self.dtype)
                            chol_annealed = torch.linalg.cholesky(sigma_annealed)
                            sigma_inv = torch.cholesky_solve(identity, chol_annealed)

                            min_eig = torch.min(torch.linalg.eigvalsh(sigma_inv))
                            diag_mean = torch.mean(torch.abs(sigma_inv.diagonal()))
                            jitter_threshold = max(1e-8, diag_mean * 1e-6)

                            if min_eig < jitter_threshold:
                                jitter = jitter_threshold - min_eig
                                sigma_inv.diagonal().add_(jitter)

                            chol_sigma_inv = torch.linalg.cholesky(sigma_inv)
                            chol_sigma_inv.diagonal().clamp_(min=0.0) # Ensure non-negative diagonal for Cholesky of inverse
                            self.sigma_full.copy_(chol_sigma_inv)
                        except torch.linalg.LinAlgError as e: print(f"Warning: Cholesky failed during Sigma update: {e}. Sigma_full not updated.")

    def _get_eta(self) -> torch.Tensor:
        target_device = self.sigma_2.device
        if not self.random: return torch.empty((len(self.cluster_names), 0), device=target_device, dtype=self.dtype)

        self.random.to(target_device) # Ensure ParameterDict's parameters are on target device
        etas = [self.random[name].view(self.random[name].shape[0], -1).to(self.dtype) for name in self.random.keys()]
        return torch.cat(etas, dim=1)


    def _after_training(self) -> None:
        if not (not self.training and self.during_training): return
        self.during_training = False
        with torch.no_grad(): self._update_sigma_2_sigma()

    def train(self, mode: bool = True):
        if not mode: self._after_training()
        super().train(mode)
        if mode: self._after_training() # Should be if mode and not self.during_training ?
                                       # Original code has self.during_training = True at start of train. Let's ensure that happens.
                                       # The train() logic calls super().train(mode) first.
                                       # Then, if mode is True, it calls _after_training(). This is to ensure _update_sigma runs at start of training.
                                       # And if mode is False (eval), it calls _after_training() to run at end of training / start of eval.

        # This structure seems to be:
        # train(True) -> super().train(True) -> self.during_training=True -> _after_training() [updates based on previous epoch or init state]
        # train(False) -> _after_training() [updates based on last training data] -> super().train(False) -> self.during_training=False
        # Let's simplify based on common patterns: update sigma at the end of an epoch (when switching to eval or before next train).
        # The current hook for _update_sigma_2_sigma in train()/eval() is a bit convoluted.
        # A common way is:
        # my_model.train() # sets self.training = True
        # ... epoch loop ...
        # my_model.eval() # sets self.training = False, triggers sigma update based on accumulated losses during training.
        # Let's adjust based on author's intent: _after_training() seems to be the main hook.
        # If mode is True (switching to train mode): set during_training = True. Update sigma if it was previously False.
        # If mode is False (switching to eval mode): update sigma based on accumulated. Set during_training = False.

        if mode: # Entering training mode
            if not self.during_training: # If previously in eval mode or first time
                 self._update_sigma_2_sigma() # Update based on whatever was before (e.g. init, or last eval)
            self.during_training = True
        else: # Entering eval mode
            if self.during_training: # If previously in training mode
                self._update_sigma_2_sigma() # This is the crucial update based on the training epoch's losses
            self.during_training = False
        super().train(mode) # This sets self.training
        return self


    def eval(self): return self.train(False)


    # *** FORWARD PASS USING FUNCTIONAL RANDOM EFFECTS ***
    def forward(
        self,
        x: torch.Tensor, # Input features
        meta: dict[str, torch.Tensor], # Metadata including 'meta_id'
        y: Optional[torch.Tensor] = None, # Ground truth (optional)
        dataset: str = "", # Optional dataset tag (unused)
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass using N model instances + functional effects.

        Every sample is scored with the layers of the first mixed model (the
        shared parameters). If the bias of its last layer is a declared random
        effect, the bias used is ``fixed + random[cluster]`` for samples of a
        known cluster and ``fixed`` alone for samples whose ``meta_id`` is not
        in ``cluster_id_to_index``.

        NOTE: only the last layer's bias is applied as a random effect here;
        any other name in ``random_effect_names`` has no influence on the
        scores computed by this method.

        Args:
            x: Input features, one row per sample.
            meta: Metadata; ``meta['meta_id']`` holds the cluster id of each
                sample.
            y: Optional ground truth, only passed on to the fixed model.
            dataset: Only passed on to the fixed model.

        Returns:
            ``(scores, meta)``; ``meta`` is the dict returned by the fixed
            model when one is configured, otherwise the device-moved input.

        Raises:
            KeyError: If ``meta`` has no ``'meta_id'`` entry.
            TypeError: If a model that has to be evaluated is not an ``MLP``.
        """
        target_device = self.sigma_2.device
        x = x.to(target_device, dtype=self.dtype)
        meta = {k: v.to(target_device) if isinstance(v, torch.Tensor) else v for k, v in meta.items()}
        batch_ids = meta['meta_id'].to(torch.long)

        if y is not None:
            y = y.to(target_device, dtype=self.dtype)

        # Ensure internal buffers/params are on the correct device
        self.cluster_names = self.cluster_names.to(target_device)
        self.fixed.to(target_device)
        self.random.to(target_device)
        self.models.to(target_device)
        self.models.train(self.training)

        # --- Fixed Model Pass (if applicable) ---
        if self.fixed_model is not None:
            self.fixed_model = self.fixed_model.to(target_device)
            self.fixed_model.train(self.training)
            x_fixed, meta = self.fixed_model(x, meta, y=y, dataset=dataset)
            x_mixed_input = x_fixed.to(self.dtype)
        else:
            x_mixed_input = x

        # Map each sample to its model index (-1 if the cluster is unknown).
        # A single tolist() avoids one host synchronisation per sample.
        id_to_index = self.cluster_id_to_index
        batch_indices = torch.tensor(
            [id_to_index.get(cluster_id, -1) for cluster_id in batch_ids.reshape(-1).tolist()],
            dtype=torch.long,
            device=target_device,
        )

        batch_size = x_mixed_input.shape[0]
        # Infer output size from first model's final layer
        try:
            output_size = self.models[0].network[-1].out_features
        except Exception as e:
            print(f"Warning: Could not infer output size ({e}), defaulting to 1")
            output_size = 1

        final_scores = torch.zeros(batch_size, output_size, device=target_device, dtype=self.dtype)

        # Known clusters: one pass per cluster present in the batch.
        known_indices = torch.unique(batch_indices[batch_indices != -1]).tolist()
        for model_idx in known_indices:
            if not isinstance(self.models[model_idx], MLP):
                raise TypeError("Functional forward requires MLP with accessible layers")
            members = batch_indices == model_idx
            final_scores[members] = self._score_with_shared_network(
                x_mixed_input[members], model_idx,
            )

        # Unknown clusters: fixed effects only.
        unknown = batch_indices == -1
        if unknown.any():
            if not isinstance(self.models[0], MLP):
                raise TypeError("Functional forward for unknown clusters requires MLP")
            final_scores[unknown] = self._score_with_shared_network(x_mixed_input[unknown])

        return final_scores, meta

    def _score_with_shared_network(
        self,
        inputs: torch.Tensor,
        cluster_index: Optional[int] = None,
    ) -> torch.Tensor:
        """Score ``inputs`` with the first model's layers and the mixed final bias.

        ``cluster_index`` selects the row of the random bias to add on top of
        the fixed bias; ``None`` applies the fixed bias only.
        """
        network = self.models[0].network
        final_layer = network[-1]
        scores = network[:-1](inputs) @ final_layer.weight.t()

        # Name under which the last layer's bias would be tracked as a random effect.
        bias_key = f"network.{len(network) - 1}.bias".replace(".", "-")
        if bias_key in self.random_effect_names:
            bias = self.fixed[bias_key]
            if cluster_index is not None:
                bias = bias + self.random[bias_key][cluster_index]
            return scores + bias
        if final_layer.bias is not None:
            # Bias is an ordinary shared parameter.
            return scores + final_layer.bias
        return scores


    def _get_p_eta(self, *, index: torch.Tensor) -> torch.Tensor:
        """Calculate 0.5 * eta^T Sigma^-1 eta for selected clusters."""
        target_device = self.sigma_2.device
        if not self.random or index.numel() == 0:
            return torch.zeros(index.numel(), device=target_device, dtype=self.dtype)

        # Get eta values for the specified clusters (already correct dtype/device)
        # Ensure index is valid (these are model indices, should be < num_clusters)
        if index.max() >= len(self.cluster_names):
             print(f"Warning: Invalid index in _get_p_eta. Max index: {index.max()}, Num clusters: {len(self.cluster_names)}")
             valid_mask_idx = index < len(self.cluster_names); index = index[valid_mask_idx]
             if index.numel() == 0: return torch.zeros(0, device=target_device, dtype=self.dtype)

        eta_all_clusters = self._get_eta() # Shape: (num_total_clusters, num_random_effects)
        eta = eta_all_clusters[index] # Shape: (len(index), num_random_effects)

        # Original code subtracts mean, let's replicate that.
        eta = eta - eta.mean(dim=0, keepdims=True) # Centering eta for the batch being processed

        p_eta_val = torch.zeros(eta.shape[0], device=target_device, dtype=self.dtype)

        # Ensure Sigma components are on the correct device
        idx_diag = self.index_diagonal.to(target_device)
        idx_full = self.index_full.to(target_device)
        sigma_diag_inv_sqrt = self.sigma_diagonal.to(target_device) # This is 1/variance
        sigma_full_L_inv = self.sigma_full.to(target_device) # Lower Cholesky factor of Sigma_inv

        # --- Contribution from Independent Effects ---
        if idx_diag.numel() > 0:
            eta_diag = eta[:, idx_diag] # Shape: (len(index), num_independent)
            # sigma_diagonal stores 1/variance. So (eta^2 / variance)
            p_eta_diag_contrib = (eta_diag.pow(2) * sigma_diag_inv_sqrt.unsqueeze(0)).sum(dim=1)
            p_eta_val += p_eta_diag_contrib

        # --- Contribution from Full Covariance Block ---
        if idx_full.numel() > 0:
            eta_full = eta[:, idx_full] # Shape: (len(index), num_full)
            L_inv = sigma_full_L_inv # This is L_inv such that Sigma_inv = L_inv L_inv^T

            if L_inv.numel() > 0: # Check if L_inv is not empty
                # We want eta^T Sigma_inv eta = eta^T L_inv L_inv^T eta = (L_inv^T eta)^T (L_inv^T eta)
                # So, calculate v = L_inv^T eta, then sum squares of v.
                if L_inv.shape == (1,1) and eta_full.shape[1] == 1: # scalar random effect in full block
                    L_inv_T_eta = L_inv.T * eta_full # element-wise for (N_batch, 1)
                elif L_inv.ndim == 2 and eta_full.ndim == 2:
                    try:
                        # L_inv is (k,k), eta_full is (N_batch, k)
                        # L_inv.T is (k,k). eta_full.unsqueeze(-1) is (N_batch, k, 1)
                        # bmm: (N_batch, k, k) @ (N_batch, k, 1) -> (N_batch, k, 1)
                        # or (k,k) @ (k, N_batch) -> (k, N_batch) then transpose and sum
                        L_inv_T_expanded = L_inv.T.unsqueeze(0).expand(eta_full.shape[0], -1, -1) # (N_batch, k, k)
                        eta_full_vec = eta_full.unsqueeze(-1) # (N_batch, k, 1)
                        L_inv_T_eta = torch.bmm(L_inv_T_expanded, eta_full_vec) # (N_batch, k, 1)
                    except RuntimeError as e:
                        print(f"Error during p_eta calculation (bmm): {e}")
                        print(f"Shapes: L_inv.T={L_inv.T.shape}, eta_full_vec shape={eta_full_vec.shape if 'eta_full_vec' in locals() else 'eta_full_vec not defined'}")
                        L_inv_T_eta = torch.zeros_like(eta_full.unsqueeze(-1)) # Fallback
                else: # Should not happen with correct setup
                    print(f"Warning: Unexpected shapes for L_inv ({L_inv.shape}) or eta_full ({eta_full.shape})")
                    L_inv_T_eta = torch.zeros_like(eta_full.unsqueeze(-1))

                p_eta_full_contrib = torch.sum(L_inv_T_eta.pow(2), dim=(1, 2) if L_inv_T_eta.ndim == 3 else 1)
                p_eta_val += p_eta_full_contrib.clamp(min=0.0)
            else: # L_inv is empty tensor (e.g. idx_full has elements but sigma_full is 0x0)
                 pass # No contribution


        # Return 0.5 * eta^T Sigma^-1 eta
        return p_eta_val / 2.0


    # @torch.jit.export # JIT export removed
    def loss(
        self,
        scores: torch.Tensor,
        ground_truth: torch.Tensor,
        meta: dict[str, torch.Tensor],
        take_mean: bool = True,
        loss: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the NME training objective for one batch.

        The returned value is::

            MSE_batch / sigma_2  +  l2_lambda * mean_c(p_eta_c)

        where ``p_eta_c`` is the value ``self._get_p_eta`` returns for each
        *known* cluster ``c`` present in the batch (``0.5 * eta^T Sigma^-1 eta``)
        and the mean runs over the unique clusters, not over samples. The
        log-normalisation constants of the Gaussian likelihood and prior are
        not included.

        Args:
            scores: Model predictions for the batch.
            ground_truth: Targets; moved to ``scores.device``.
            meta: Batch metadata. Must contain ``'meta_id'`` holding the
                cluster id of every sample.
            take_mean: Accepted for signature compatibility with the base
                class; ignored here.
            loss: Accepted for signature compatibility with the base class;
                ignored here.

        Returns:
            The combined likelihood + prior loss tensor.

        Raises:
            KeyError: If ``meta`` has no ``'meta_id'`` entry.

        Side effects:
            * Moves ``cluster_names``, ``p_eta``, ``sigma_2`` and
              ``cluster_count`` to ``scores.device``.
            * In training mode, appends the detached batch MSE to
              ``self.losses`` and overwrites the ``self.p_eta`` entries of the
              clusters seen in this batch with the full quadratic form
              (``2 * p_eta_c``).
        """
        target_device = scores.device
        ground_truth = ground_truth.to(target_device)
        meta = {k: v.to(target_device) if isinstance(v, torch.Tensor) else v for k, v in meta.items()}

        if 'meta_id' not in meta:
            raise KeyError("'meta_id' missing from meta")
        batch_ids = meta['meta_id'].to(torch.long)

        self.cluster_names = self.cluster_names.to(target_device)
        self.p_eta = self.p_eta.to(target_device)
        self.sigma_2 = self.sigma_2.to(target_device)
        self.cluster_count = self.cluster_count.to(target_device)

        # --- Likelihood term: batch MSE scaled by the residual variance ---
        batch_size = scores.shape[0]
        loss_per_sample = super().loss(scores, ground_truth, {}, take_mean=False)
        sum_sq_errors_batch = loss_per_sample.sum()
        if batch_size > 0:
            mse_batch = sum_sq_errors_batch / batch_size
        else:
            # Empty batch: avoid a 0/0 division and contribute nothing.
            mse_batch = torch.tensor(0.0, device=target_device, dtype=self.dtype)

        if self.training:
            # Kept (detached) for the sigma_2 update performed elsewhere.
            self.losses.append(mse_batch.detach())

        # Clamp so a collapsed variance estimate cannot blow the loss up.
        likelihood_term = mse_batch / self.sigma_2.clamp(min=1e-8)

        # --- Prior term: mean quadratic form over known clusters in batch ---
        # One tolist() call instead of a per-element .item(): the latter
        # forces a host/device synchronisation for every sample on GPU.
        id_to_index = self.cluster_id_to_index
        batch_model_indices = torch.tensor(
            [id_to_index.get(cluster_id, -1) for cluster_id in batch_ids.reshape(-1).tolist()],
            dtype=torch.long,
            device=target_device,
        )
        # -1 marks ids that have no model (e.g. subjects unseen in training).
        valid_indices_mask = batch_model_indices != -1
        unique_model_indices_in_batch = torch.unique(batch_model_indices[valid_indices_mask])

        prior_term = torch.tensor(0.0, device=target_device, dtype=self.dtype)
        if unique_model_indices_in_batch.numel() > 0:
            p_eta_for_unique_clusters_in_batch = self._get_p_eta(index=unique_model_indices_in_batch)
            prior_term = p_eta_for_unique_clusters_in_batch.mean() * self.l2_lambda

            if self.training:
                # Write a *detached* value: copying a grad-tracking tensor
                # into the buffer would turn the buffer into a non-leaf and
                # chain every batch's autograd graph to the previous one,
                # retaining memory for the whole run.
                with torch.no_grad():
                    self.p_eta[unique_model_indices_in_batch] = (
                        p_eta_for_unique_clusters_in_batch.detach() * 2.0
                    )

        total_loss = likelihood_term + prior_term
        return total_loss


# --- Data Loading and Preparation (R-style split) ---
def load_and_prep_parkinsons_r_split(dtype=TENSOR_DTYPE, timeout=30.0):
    """
    Download the Parkinson's Telemonitoring dataset and build an R-like
    "leave-last-observation-out-per-subject" train/test split.

    For every subject the chronologically last row (by 'test_time') goes to
    the test set and all earlier rows go to the training set. Features are
    standardised with a scaler fitted on the training rows only.

    Args:
        dtype: torch dtype used for the feature and target tensors.
        timeout: Seconds to wait for the HTTP download before giving up.

    Returns:
        A 10-tuple
        ``(X_train, y_train, meta_train, X_test, y_test, meta_test,
        unique_clusters, counts, scaler, predictor_cols)`` where the meta
        dicts hold the subject ids under ``'meta_id'`` and
        ``unique_clusters`` / ``counts`` describe the training subjects.
        On any failure (download, missing columns, empty data) a message is
        printed and a tuple of ten ``None`` values is returned instead.
    """
    failure = (None,) * 10

    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/parkinsons/telemonitoring/parkinsons_updrs.data"
    print(f"Downloading data from {url}...")
    try:
        # Bounded wait: without a timeout a stalled server hangs the script.
        response = requests.get(url, timeout=timeout)
        # Fail here on 4xx/5xx rather than parsing an error page as CSV.
        response.raise_for_status()
        # The UCI file carries a header row, so pandas infers column names.
        data = pd.read_csv(io.StringIO(response.content.decode('utf-8')))
        print("Data downloaded successfully.")
    except Exception as e:  # best-effort boundary: report and bail out
        print(f"Error downloading or reading data: {e}")
        return failure

    # Rename subject column to be consistent with original Python script
    if 'subject#' in data.columns:
        data = data.rename(columns={'subject#': 'subject'})
    elif 'subject' not in data.columns:
        print("Error: Subject column ('subject#' or 'subject') not found.")
        return failure
    subject_col = 'subject'

    # Outcome and predictors as in the R script
    outcome_col = 'total_UPDRS'
    predictor_cols = [
        "Jitter(%)", "Jitter(Abs)", "Jitter:RAP", "Jitter:PPQ5", "Jitter:DDP",
        "Shimmer", "Shimmer(dB)", "Shimmer:APQ3", "Shimmer:APQ5", "Shimmer:APQ11",
        "Shimmer:DDA", "NHR", "HNR", "RPDE", "DFA", "PPE"
    ]
    value_cols = predictor_cols + [outcome_col]

    missing_cols = [col for col in value_cols if col not in data.columns]
    if missing_cols:
        print(f"Error: Missing required columns: {missing_cols}")
        return failure

    # 'test_time' defines which observation is the "last" one per subject.
    if 'test_time' not in data.columns:
        print("Error: 'test_time' column not found, which is needed for splitting.")
        return failure

    # 'sex' is not used by the model.
    if 'sex' in data.columns:
        data = data.drop(columns=['sex'])
        print("Removed 'sex' column.")

    essential_cols = [subject_col, 'test_time'] + value_cols
    data_subset = data[essential_cols].copy()

    # Coerce to numeric so malformed cells become NaN and are dropped below.
    print("Converting predictor and outcome columns to numeric...")
    for col in value_cols:
        data_subset[col] = pd.to_numeric(data_subset[col], errors='coerce')

    initial_rows = len(data_subset)
    data_subset = data_subset.dropna(subset=value_cols)
    removed_rows = initial_rows - len(data_subset)
    if removed_rows > 0:
        print(f"Removed {removed_rows} rows containing NA values in key columns.")
    if len(data_subset) == 0:
        print("Error: No data remaining after NA removal.")
        return failure

    # Sort so that the last row of each subject is its latest observation;
    # the index is reset so labels are positional within the sorted frame.
    data_subset = data_subset.sort_values([subject_col, 'test_time']).reset_index(drop=True)
    print(f"Data sorted by '{subject_col}' and 'test_time'.")

    # --- R-like "Leave-Last-Observation-Out per Subject" Split ---
    print("Splitting data (last observation per subject for test)...")
    train_idx, test_idx = [], []
    for _, group_df in data_subset.groupby(subject_col):
        *earlier, last = group_df.index
        test_idx.append(last)
        # Subjects with a single observation contribute to the test set only.
        train_idx.extend(earlier)

    if not train_idx:
        print("Error: Training or test set is empty after splitting. "
              "This might happen if subjects have only one observation.")
        obs_counts = data_subset.groupby(subject_col).size()
        if (obs_counts == 1).any():
            print(f"Warning: Some subjects have only 1 observation. These will only contribute to the test set, "
                  f"leading to an empty training set for them. Subjects with 1 obs: {obs_counts[obs_counts==1].index.tolist()}")
        print("FATAL: Training set is completely empty. Cannot proceed.")
        return failure
    # A non-empty training set implies at least one subject, hence a
    # non-empty test set, so no further emptiness check is needed.

    X_train_unscaled_df = data_subset.loc[train_idx, predictor_cols]
    y_train_df = data_subset.loc[train_idx, [outcome_col]]  # 2-D: (n, 1)
    groups_train_series = data_subset.loc[train_idx, subject_col]

    X_test_unscaled_df = data_subset.loc[test_idx, predictor_cols]
    y_test_df = data_subset.loc[test_idx, [outcome_col]]
    groups_test_series = data_subset.loc[test_idx, subject_col]

    print(f"Train size: {len(X_train_unscaled_df)} ({groups_train_series.nunique()} subjects)")
    print(f"Test size: {len(X_test_unscaled_df)} ({groups_test_series.nunique()} subjects)")

    # Fit the scaler on training rows only to avoid leaking test statistics.
    print("Scaling features...")
    scaler = StandardScaler()
    X_train_scaled_np = scaler.fit_transform(X_train_unscaled_df)
    X_test_scaled_np = scaler.transform(X_test_unscaled_df)

    X_train_tensor = torch.tensor(X_train_scaled_np, dtype=dtype)
    X_test_tensor = torch.tensor(X_test_scaled_np, dtype=dtype)
    y_train_tensor = torch.tensor(y_train_df.values, dtype=dtype)
    y_test_tensor = torch.tensor(y_test_df.values, dtype=dtype)
    groups_train_tensor = torch.tensor(groups_train_series.values, dtype=torch.long)
    groups_test_tensor = torch.tensor(groups_test_series.values, dtype=torch.long)

    meta_train = {'meta_id': groups_train_tensor}
    meta_test = {'meta_id': groups_test_tensor}

    # Cluster ids and per-cluster sample counts of the training set, used to
    # initialise the NME model.
    unique_clusters, counts = torch.unique(groups_train_tensor, return_counts=True)

    print(f"Features (X_train shape): {X_train_tensor.shape}, Target (y_train shape): {y_train_tensor.shape}")
    print(f"Unique subjects in training set for NME: {len(unique_clusters)}")

    return (
        X_train_tensor, y_train_tensor, meta_train,
        X_test_tensor, y_test_tensor, meta_test,
        unique_clusters, counts, scaler, predictor_cols,
    )


# Helper to flatten nested lists (for parameter indices)
def generic_flatten_nested_list(nested_list):
    """Return a flat list of the leaves of an arbitrarily nested list.

    Only ``list`` instances are descended into; any other element (tuples,
    strings, ...) is treated as a leaf and kept as-is.
    """
    flat = []
    for item in nested_list:
        # Recurse into sub-lists, wrap leaves so both cases extend uniformly.
        flat += generic_flatten_nested_list(item) if isinstance(item, list) else [item]
    return flat

# --- Training and Evaluation ---
def _sum_sq_grad_norms(params):
    """Return the sum of squared gradient norms over *params* that have a gradient."""
    return sum(
        (torch.linalg.norm(p.grad).item() ** 2 for p in params if p.grad is not None),
        0.0,
    )


def train_nme_model(model, X_train, y_train, meta_train, optimizer, epochs=10, batch_size=256):
    """Train the NME model with shuffled mini-batches.

    Each epoch: switch to train mode, iterate over a fresh random permutation
    of the samples, clip the global gradient norm to 1.0 before every
    optimiser step, then call ``model.eval()``. Progress (loss, ``sigma_2``,
    gradient-norm diagnostics) is printed roughly 20 times over the run and
    always on the final epoch.

    Args:
        model: NME model exposing ``loss``, ``random``, ``fixed``, ``models``,
            ``random_effect_names`` and ``sigma_2``.
        X_train: Training features, indexed along dim 0.
        y_train: Training targets, indexed along dim 0.
        meta_train: Dict of per-sample tensors (indexed like ``X_train``).
        optimizer: Optimiser whose param groups hold the trainable tensors.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size.

    Returns:
        None. The model is updated in place and left in eval mode.
    """
    num_samples = X_train.shape[0]
    target_device = next(model.parameters()).device
    print_freq = max(1, epochs // 20)  # Print progress roughly 20 times

    # The set of optimised parameters does not change during training, so
    # collect it once for gradient clipping instead of once per batch.
    clip_params = [p for group in optimizer.param_groups for p in group['params']]

    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()
        permutation = torch.randperm(num_samples)
        epoch_loss = 0.0
        samples_processed = 0
        batches_processed = 0
        total_random_grad_norm = 0.0
        total_fixed_grad_norm = 0.0
        total_shared_grad_norm = 0.0

        for i in range(0, num_samples, batch_size):
            indices = permutation[i:i + batch_size]
            batch_X = X_train[indices].to(target_device)
            batch_y = y_train[indices].to(target_device)
            batch_meta = {k: v[indices].to(target_device) for k, v in meta_train.items()}

            optimizer.zero_grad()
            scores, _ = model(batch_X, meta=batch_meta, y=batch_y)
            loss = model.loss(scores, batch_y, batch_meta)

            # A non-finite loss would poison the parameters; skip the step.
            if not torch.isfinite(loss).all():
                print(f"Warning: NaN or Inf loss detected at epoch {epoch+1}, batch {i // batch_size}. Skipping batch.")
                optimizer.zero_grad()
                continue

            loss.backward()

            # --- Gradient diagnostics (squared norms, accumulated per epoch) ---
            if model.random:
                total_random_grad_norm += _sum_sq_grad_norms(
                    p for _, p in model.random.named_parameters()
                )
            if model.fixed:
                total_fixed_grad_norm += _sum_sq_grad_norms(
                    p for _, p in model.fixed.named_parameters()
                )
            if model.models:
                # Shared parameters are inspected on the first sub-model only;
                # names are normalised ("." -> "-") to match random_effect_names.
                total_shared_grad_norm += _sum_sq_grad_norms(
                    p
                    for name, p in model.models[0].named_parameters()
                    if name.replace(".", "-") not in model.random_effect_names
                )

            if clip_params:
                torch.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)

            optimizer.step()
            batch_len = len(indices)
            epoch_loss += loss.item() * batch_len
            samples_processed += batch_len
            batches_processed += 1

        # Average over the samples that actually contributed: dividing by
        # num_samples would under-report the loss whenever batches are skipped.
        avg_epoch_loss = epoch_loss / samples_processed if samples_processed > 0 else 0

        # NOTE(review): per the original author's comments, model.eval() is
        # what triggers NME's sigma update hook - confirm against the class.
        model.eval()

        epoch_duration = time.time() - epoch_start_time

        if (epoch + 1) % print_freq == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_epoch_loss:.4f}, Sigma2: {model.sigma_2.item():.4f}, Duration: {epoch_duration:.2f}s")
            # Root-mean-square of the per-batch gradient norms.
            avg_random_grad_norm = (total_random_grad_norm / batches_processed)**0.5 if model.random and batches_processed > 0 else 0.0
            avg_fixed_grad_norm = (total_fixed_grad_norm / batches_processed)**0.5 if model.fixed and batches_processed > 0 else 0.0
            avg_shared_grad_norm = (total_shared_grad_norm / batches_processed)**0.5 if batches_processed > 0 else 0.0

            if model.random and hasattr(model, 'sa_sigma') and model.sa_sigma is not None and model.sa_sigma.numel() > 0:
                sa_sigma_cpu = model.sa_sigma.detach().cpu().numpy()
                print(f"  Avg Grads: Random={avg_random_grad_norm:.3e}, Fixed={avg_fixed_grad_norm:.3e}, Shared={avg_shared_grad_norm:.3e}")
                print(f"  SA Sigma diag (sample): {sa_sigma_cpu[:5]}")
                # Only flag vanishing random-effect gradients after warm-up.
                if avg_random_grad_norm < 1e-6 and epoch > 50:
                    print("  Warning: Low/zero gradient for random parameters!")
            else:
                print(f"  Avg Grads: Fixed={avg_fixed_grad_norm:.3e}, Shared={avg_shared_grad_norm:.3e} (No random effects or sa_sigma)")


def evaluate_nme_model(model, X_test, y_test, meta_test):
    """Score the NME model on a held-out set.

    The model is switched to eval mode and the whole test set is pushed
    through it in a single forward pass without gradient tracking.

    Returns:
        ``(mse, mae)`` - mean squared error and mean absolute error between
        the predictions and ``y_test``.
    """
    model.eval()
    device = next(model.parameters()).device

    # Everything the forward pass touches must live on the model's device.
    inputs = X_test.to(device)
    targets = y_test.to(device)
    meta_on_device = {key: value.to(device) for key, value in meta_test.items()}

    with torch.no_grad():
        scores, _ = model(inputs, meta=meta_on_device, y=targets)
        predicted_cpu = scores.cpu()
        targets_cpu = targets.cpu()

    predictions = torch.cat([predicted_cpu]).numpy()
    ground_truth = torch.cat([targets_cpu]).numpy()

    residuals = predictions - ground_truth
    mse = np.mean(residuals**2)
    mae = np.mean(np.abs(residuals))
    return mse, mae


# --- Main Execution ---
if __name__ == "__main__":
    # Run on the GPU when torch can see one, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # One dtype for data and model alike.
    model_dtype = TENSOR_DTYPE

    # Fetch the data and build the leave-last-observation-out split.
    (
        X_train, y_train, meta_train,
        X_test, y_test, meta_test,
        unique_clusters, cluster_counts, scaler, feature_cols,
    ) = load_and_prep_parkinsons_r_split(dtype=model_dtype)

    if X_train is None:
        print("Failed to load data. Exiting.")
        exit()

    # Layer sizes follow the prepared tensors (columns of X / y).
    input_size = X_train.shape[1]
    output_size = y_train.shape[1]

    # --- Configure NME Model ---
    # The single random effect is the parameter named "network.4.bias" of the
    # MLP; per the original author's note this is the output-layer bias for
    # hidden sizes [32, 16] - verify against the MLP definition.
    random_effects_config = ("network.4.bias",)
    independent_config = ()

    print("\nConfiguring NME model (Author's version structure + Functional Forward)...")
    nme_model = NeuralMixedEffects(
        clusters=unique_clusters,
        cluster_count=cluster_counts,
        model_fun=MLP,
        random_effects=random_effects_config,
        independent=independent_config,
        simulated_annealing_alpha=0.97,
        dtype=model_dtype,
        l2_lambda=0.1,  # weight of the random-effects prior term
        # --- arguments forwarded to the MLP ---
        model_input_size=input_size,
        model_output_size=output_size,
        model_hidden_sizes=[32, 16],
        model_dtype=model_dtype,
        # --- arguments for the NME base class ---
        input_size=input_size,
        output_size=output_size,
    ).to(device)

    # --- Optimizer ---
    # Bucket every parameter by its name prefix; anything that is neither
    # "fixed." nor "random." counts as shared.
    fixed_param_list, random_param_list, shared_param_list = [], [], []
    for name, param in nme_model.named_parameters():
        if name.startswith("fixed."):
            bucket = fixed_param_list
        elif name.startswith("random."):
            bucket = random_param_list
        else:
            bucket = shared_param_list
        bucket.append(param)

    # Per-group hyper-parameters; empty buckets produce no optimiser group.
    group_specs = (
        (fixed_param_list, {'lr': 1e-3, 'weight_decay': 1e-5}),
        (random_param_list, {'lr': 1e-3}),
        (shared_param_list, {'lr': 1e-4, 'weight_decay': 1e-5}),
    )
    optimizer_param_groups = [
        {'params': params, **options} for params, options in group_specs if params
    ]

    if optimizer_param_groups:
        # lr here is only the fallback for groups without an explicit lr.
        optimizer = optim.Adam(optimizer_param_groups, lr=1e-3)
        print("\nStarting training...")
        # Train for 4000 epochs with mini-batches of 512 samples.
        train_nme_model(nme_model, X_train, y_train, meta_train, optimizer, epochs=4000, batch_size=512)

        print("\nStarting evaluation...")
        test_mse, test_mae = evaluate_nme_model(nme_model, X_test, y_test, meta_test)

        print("\n--- Results ---")
        print(f"Test MSE: {test_mse:.4f}")
        print(f"Test MAE: {test_mae:.4f}")

        # Show the learned random effect(s) of the first training subject.
        if nme_model.random and len(unique_clusters) > 0:
            first_subject_id = unique_clusters[0].item()

            try:
                # Translate the subject id into the model's internal index.
                subject_model_idx = nme_model.cluster_id_to_index.get(first_subject_id, -1)

                if subject_model_idx == -1:
                    print(f"Could not find subject {first_subject_id} in model cluster map.")
                else:
                    print(f"\nRandom effect (eta) for subject {first_subject_id} (model index {subject_model_idx}):")
                    for name, eta_param in nme_model.random.items():
                        print(f"  - {name}: {eta_param[subject_model_idx].detach().cpu().numpy()}")
            except Exception as e:
                print(f"Error accessing random effect for subject {first_subject_id}: {e}")
    else:
        print("Warning: No parameters found requiring gradients for the optimizer.")

Using device: cuda
Downloading data from https://archive.ics.uci.edu/ml/machine-learning-databases/parkinsons/telemonitoring/parkinsons_updrs.data...
Data downloaded successfully.
Removed 'sex' column.
Converting predictor and outcome columns to numeric...
Data sorted by 'subject' and 'test_time'.
Splitting data (last observation per subject for test)...
Train size: 5833 (42 subjects)
Test size: 42 (42 subjects)
Scaling features...
Features (X_train shape): torch.Size([5833, 16]), Target (y_train shape): torch.Size([5833, 1])
Unique subjects in training set for NME: 42

Configuring NME model (Author's version structure + Functional Forward)...
Instantiating 42 models for NME...
Instantiated NME models: <class '__main__.MLP'>
Tracking random effect 'network-4-bias' with shape: torch.Size([42, 1]), dtype: torch.float64
Total scalar random effects per cluster: 1
Independent effect indices: []
Full covariance block indices: [0]

Starting training...
Epoch 200/4000, Avg Loss: 0.9908, Sigma2