In [None]:
pip install ipykernel

In [None]:
# ====================================================================================
# === TPU SCRIPT: Train Advanced H-RPE Transformer + SSMD Eval (Reduced Logging) ===
# ====================================================================================
# Target: Chord Progression (Model 1) using existing dataset files.
# H-RPE: Advanced version using learnable embeddings based on numerical features.
# Evaluation: Loss, Accuracy, Perplexity, SSMD(k=1) on Function Codes.
# Assumes running in a Google Cloud TPU environment with PyTorch XLA.
# --- MODIFIED FOR REDUCED LOGGING & PER-EPOCH METRICS ---
# --- Calculates Val Acc & SSMD per epoch (controlled by HPARAMS) ---
# --- Updates epoch progress bar postfix ---
# ====================================================================================

# !!! IMPORTANT !!!
# For the dynamic progress bar (`tqdm.notebook`) to work on Kaggle/Colab,
# you might need to explicitly install/update ipywidgets first.
# Run this in a separate cell BEFORE running this script, then RESTART the kernel:
# !pip install -U ipywidgets --quiet
# !jupyter nbextension enable --py widgetsnbextension --sys-prefix --quiet # Optional, sometimes needed

# --- Necessary Imports ---
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, Subset, random_split
from torch.nn.utils.rnn import pad_sequence
# ============================================= #
# >>> MODIFICATION: Using tqdm.notebook <<<     #
from tqdm.notebook import tqdm                  #
# ============================================= #
import logging
import os
import math
import sys
import time
import traceback
import random
import numpy as np
import gc
from typing import Optional, List, Dict, Tuple, Any
from pathlib import Path
# Import for SSMD
from scipy.spatial.distance import pdist, squareform

# --- XLA Imports ---
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.serialization as xser # For saving/loading

# === CONFIGURATION - ASSUMING Sequence Length 32 and token_features.json ===
# Module-level configuration: input/output paths, model + training
# hyperparameters, split ratios, logging setup, and the globals that the
# multiprocessing entry point fills in at runtime.
# >>>>>>>>>> VERIFY YOUR INPUT FILES MATCH THIS CONFIG <<<<<<<<<<<<<<
INPUT_SEQUENCES_PATH = Path("/kaggle/input/advance-h-rpe/chord_sequences_for_training.jsonl") # Example path, replace with yours
VOCAB_PATH = Path("/kaggle/input/advance-h-rpe/chord_progression_vocab.json") # Example path
TOKEN_FEATURES_PATH = Path("/kaggle/input/advance-h-rpe/token_features.json") # Example path
RESULTS_DIR = Path("./harmony_results_AdvH_RPE_SeqLen32_TPU_EpochMetrics") # Updated name
CHECKPOINT_TO_LOAD = None # Set path like "/path/to/prev-ckpt/best_model_... .pth" to resume
CHECKPOINT_FILENAME_BEST_PREFIX = "best_model_ep"
# >>>>>>>>>>>>>>>>>>>>>><<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

# Data Split Ratios / HPARAMS / Other Params
HPARAMS = {
    # Model Architecture
    'd_model': 512, 'nhead': 8, 'num_layers': 6, 'dim_feedforward': 2048, 'dropout': 0.1,

    # --- ADVANCED H-RPE Specific ---
    # Sizes of the relation embedding tables. The quality/function tables hold
    # N*N ordered pairs; codes that fall outside a table are clamped into range
    # by the attention layer, so these must exceed the largest code in the data.
    'num_root_intervals': 12,
    'num_qualities': 10,
    'num_functions': 11,
    # None -> the attention layer falls back to head_dim // 4 (minimum 1).
    'relation_embedding_dim': None,
    # Per-token feature fields every JSONL sample must contain.
    'feature_keys': ['root_pc', 'quality_code', 'function_code'],

    # Training Hyperparameters
    # >>>>> ADJUSTED FOR TPU (Per Core Batch Size) <<<<<
    # Total batch size = batch_size * num_cores (e.g., 8 cores * 512 = 4096)
    'batch_size': 512,      # BATCH SIZE PER TPU CORE
    'eval_batch_size': 16,      # EVAL BATCH SIZE PER TPU CORE
    'lr': 5e-5,                # Base LR - may need scaling
    'weight_decay': 0.01,
    'warm_up_steps': 100,       # Placeholder: Adjust
    'max_grad_norm': 1.0,
    'num_epochs': 5,
    'patience': 5,
    'lr_scale_factor': None,     # Set to xm.xrt_world_size() later if scaling needed

    # Data & Environment
    'sequence_length': 32,      # Every sample's input_ids/feature lists must have exactly this length
    'seed': 42,
    'num_workers': 4,           # Recommended for TPUs (adjust based on environment)

    # Generation & Evaluation
    'max_gen_len': 16,
    'ssmd_context_len': 16,
    'ssmd_eval_every_n_epochs': 1, # <<< How often to run SSMD eval (1 = every epoch)
    'ssmd_eval_max_batches': 32,   # <<< Max validation batches for SSMD per epoch (None = all)

    # Checkpointing
    'save_every_n_epochs': 5
}

# Fractions of the dataset used for the train / validation / test splits (sum to 1.0).
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1

# NOTE(review): presumably the vocabulary key whose ID becomes PADDING_VALUE — confirm in _mp_fn.
PAD_TOKEN = "<PAD>"
# AMP is generally handled differently in XLA, often bf16 is preferred and automatic.
# We will remove explicit AMP scaler logic.

# === Setup Logging (Simplified for Multiprocessing) ===
# Set to INFO to see epoch summaries including new metrics
log_level = logging.INFO
logging.basicConfig(
    level=log_level, format='%(asctime)s %(levelname)-8s [Process %(process)d - %(name)s]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', stream=sys.stdout)
logger = logging.getLogger(__name__) # Module-level logger (named after this module; not the root logger)

# === Global PADDING_VALUE & Feature Map (Initialized in _mp_fn) ===
# These start empty and are read by the dataset/collate/model code at call time.
PADDING_VALUE: Optional[int] = None
token_to_features_map: Optional[Dict[int, Dict[str, int]]] = None
id_to_vocab: Optional[Dict[int, str]] = None
vocab: Optional[Dict[str, int]] = None
vocab_size: int = 0


# === Dataset for Chord Progression (Modified for Features) ===
class ChordProgressionDataset(Dataset):
    """Fixed-length chord-progression samples loaded from a JSON Lines file.

    Each non-blank line must be a JSON object with:

    * ``input_ids``: a list of exactly ``seq_len`` token IDs,
    * ``target_id``: a single number (the next-token label),
    * one list of exactly ``seq_len`` values per entry of ``feature_keys``.

    Lines that fail validation are skipped and counted (details are logged at
    DEBUG level, a summary at WARNING level) instead of aborting the load.
    """

    def __init__(self, data_file_path: Path, seq_len: int, feature_keys: Optional[List[str]] = None):
        """Load and validate every sample from ``data_file_path``.

        Args:
            data_file_path: Path (or path-like) of the ``.jsonl`` data file.
            seq_len: Required length of ``input_ids`` and of each feature list.
            feature_keys: Per-token feature names every sample must carry. When
                ``None`` (default), ``HPARAMS['feature_keys']`` is used, falling
                back to ``['root_pc', 'quality_code', 'function_code']``.

        A missing or unreadable file is logged and leaves the dataset empty (or
        holding the samples read before the failure) instead of raising.
        """
        self.samples = []
        self.seq_len = seq_len
        self.logger = logging.getLogger(__name__ + ".ChordProgressionDataset")
        if feature_keys is None:
            feature_keys = HPARAMS.get('feature_keys', ['root_pc', 'quality_code', 'function_code'])
        self.feature_keys = list(feature_keys)
        data_file_path = Path(data_file_path)

        if not data_file_path.is_file():
            self.logger.error(f"Data file not found: {data_file_path}")
            return

        self.logger.info(f"Loading data from: {data_file_path} (Expecting seq_len={self.seq_len})")
        required_keys = ['input_ids', 'target_id'] + self.feature_keys
        skipped_count = 0
        try:
            with open(data_file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, start=1):
                    stripped = line.strip()
                    if not stripped:
                        # Blank lines (e.g. a trailing newline) are not samples.
                        continue
                    try:
                        sample = json.loads(stripped)
                    except json.JSONDecodeError:
                        self.logger.warning(f"Skipping line {line_num}: Invalid JSON.")
                        skipped_count += 1
                        continue
                    reason = self._validate_sample(sample, required_keys)
                    if reason is not None:
                        self.logger.debug(f"Skipping line {line_num}: {reason}")
                        skipped_count += 1
                        continue
                    self.samples.append(sample)
        except (OSError, UnicodeDecodeError) as e:
            self.logger.error(f"Failed to load or process data file {data_file_path}: {e}", exc_info=True)
            return

        if skipped_count > 0:
            self.logger.warning(f"Skipped {skipped_count} samples during loading due to format/length issues.")
        if not self.samples:
            self.logger.error(f"No valid samples loaded from {data_file_path}. Check file content and seq_len ({self.seq_len}) match HPARAMS.")
        else:
            self.logger.info(f"Successfully loaded {len(self.samples)} samples.")

    def _validate_sample(self, sample: Any, required_keys: List[str]) -> Optional[str]:
        """Return a human-readable reason why ``sample`` is unusable, or ``None`` if valid."""
        if not isinstance(sample, dict):
            return f"Expected a JSON object, got {type(sample).__name__}."
        missing = [k for k in required_keys if k not in sample]
        if missing:
            return f"Missing required keys {missing}. Found: {list(sample.keys())}"
        target = sample['target_id']
        # bool is an int subclass but is never a meaningful token ID.
        if isinstance(target, bool) or not isinstance(target, (int, float)):
            return f"target_id must be a single number, got {target!r}."
        for key in ['input_ids'] + self.feature_keys:
            value = sample[key]
            # Check the type before calling len(): scalars have no length.
            if not isinstance(value, list):
                return f"'{key}' must be a list, got {type(value).__name__}."
            if len(value) != self.seq_len:
                return f"Invalid '{key}' length ({len(value)}) != {self.seq_len}."
        return None

    def __len__(self):
        """Return the number of valid samples that were loaded."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_id, features, raw_input_ids)`` for sample ``idx``.

        ``input_ids`` is a 1-D long tensor of length ``seq_len``, ``target_id``
        a 0-D long tensor, ``features`` a dict of 1-D long tensors keyed by
        ``feature_keys``, and ``raw_input_ids`` the original Python list
        (kept for decoding).
        """
        sample = self.samples[idx]
        input_ids = torch.tensor(sample['input_ids'], dtype=torch.long)
        target_id = torch.tensor(sample['target_id'], dtype=torch.long)
        features = {key: torch.tensor(sample[key], dtype=torch.long) for key in self.feature_keys}
        return input_ids, target_id, features, sample['input_ids'] # Return original ids for decode


# === Collate Function (Modified for Features) ===
def collate_fn_progression(batch: List[Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], List[int]]]) -> Optional[Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], List[List[int]]]]:
    """Stack ``ChordProgressionDataset`` items into one batch.

    Items are ``(input_ids, target_id, features_dict, raw_input_ids)`` tuples.
    Items that are ``None`` or malformed (wrong arity or types) are dropped.
    Samples are fixed-length, so tensors are stacked directly rather than padded.

    Returns:
        ``(inputs, targets, features, raw_inputs)`` where ``inputs`` and each
        feature tensor gain a leading batch dimension, ``targets`` is the
        stacked target tensor, and ``raw_inputs`` is a list of the original ID
        lists. Returns ``None`` when the global ``PADDING_VALUE`` is unset, the
        batch is empty or has no valid items, or stacking fails.

    Feature keys listed in ``HPARAMS['feature_keys']`` that are missing from
    some samples are omitted from ``features`` with a warning.
    """
    local_logger = logging.getLogger(__name__ + ".collate_fn_progression")
    # Kept as a setup sanity check: PADDING_VALUE must be initialised before
    # batches are produced, even though fixed-length samples need no padding here.
    if PADDING_VALUE is None:
        local_logger.error("Global PADDING_VALUE not set in collate_fn_progression.")
        return None
    if not batch:
        return None

    valid_batch = [
        item for item in batch
        if isinstance(item, (tuple, list)) and len(item) == 4
        and isinstance(item[0], torch.Tensor)
        and isinstance(item[1], torch.Tensor)
        and isinstance(item[2], dict)
    ]
    if not valid_batch:
        return None

    inputs, targets, features_list, orig_inputs_list = zip(*valid_batch)

    try:
        inputs_stacked = torch.stack(inputs, dim=0)
        targets_stacked = torch.stack(targets, dim=0)
    except RuntimeError as e:
        local_logger.error(f"Error stacking batch tensors: {e}. Sequence lengths might be inconsistent.", exc_info=True)
        return None

    feature_keys = HPARAMS.get('feature_keys', list(features_list[0].keys()))
    collated_features = {}
    for key in feature_keys:
        if not all(key in f for f in features_list):
            # Surface the problem here; otherwise the model later fails with a
            # less specific "missing feature" error.
            local_logger.warning(f"Feature '{key}' missing from some samples in the batch; omitting it from the collated features.")
            continue
        try:
            collated_features[key] = torch.stack([f[key] for f in features_list], dim=0)
        except RuntimeError as e:
            local_logger.error(f"Error stacking feature '{key}': {e}. Feature lengths might be inconsistent.", exc_info=True)
            return None

    return inputs_stacked, targets_stacked, collated_features, list(orig_inputs_list)


# =====================================================
# === ADVANCED H-RPE MODEL DEFINITION ===
# =====================================================
class AdvancedHarmonicRelativeAttention(nn.Module):
    """Multi-head attention with a learned harmonic relative-position bias (H-RPE).

    For every (query position, key position) pair three relations are embedded:

    * the root interval ``(root_pc_q - root_pc_k) mod 12``,
    * the ordered quality pair, encoded as ``quality_q * num_qualities + quality_k``,
    * the ordered function pair, encoded as ``function_q * num_functions + function_k``.

    Each relation embedding holds ``nhead * bias_head_dim`` values. Per head, the
    three ``bias_head_dim`` slices are concatenated and reduced by
    ``bias_combiner`` (Linear -> ReLU -> Linear) to one scalar. That scalar is
    added to the head's scaled dot-product attention logit. Relation indices
    outside a table are clamped into range rather than rejected.
    """

    def __init__(self, d_model, nhead, dropout, padding_idx: Optional[int], num_root_intervals=12, num_qualities=8, num_functions=10, relation_embedding_dim: Optional[int] = None):
        """Build the projections, relation embeddings and bias combiner.

        Args:
            d_model: Model width; must be divisible by ``nhead``.
            nhead: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
            padding_idx: Token ID treated as padding. Pairs in which either
                position holds it get exactly zero H-RPE bias; ``None``
                disables that zeroing.
            num_root_intervals: Size of the root-interval embedding table.
            num_qualities: Number of quality codes (the pair table has N*N rows).
            num_functions: Number of function codes (the pair table has N*N rows).
            relation_embedding_dim: Per-head width of each relation embedding;
                defaults to ``head_dim // 4`` (at least 1).

        Raises:
            ValueError: If ``d_model`` is not divisible by ``nhead`` or any
                relation-table size is not positive.
        """
        super().__init__()
        if d_model % nhead != 0: raise ValueError("d_model must be divisible by nhead")
        for size_name, size in (("num_root_intervals", num_root_intervals),
                                ("num_qualities", num_qualities),
                                ("num_functions", num_functions)):
            if size <= 0:
                raise ValueError(f"{size_name} must be positive")

        self.d_model = d_model
        self.nhead = nhead
        self.head_dim = d_model // nhead
        # Stored so pair codes are built with the configured sizes rather than
        # re-deriving them from embedding-table sizes on every call.
        self.num_root_intervals = num_root_intervals
        self.num_qualities = num_qualities
        self.num_functions = num_functions
        self.padding_idx = int(padding_idx) if padding_idx is not None else None
        self.local_logger = logging.getLogger(__name__ + ".AdvancedHarmonicRelativeAttention")

        # NOTE: keep the submodule creation order stable; it fixes the order in
        # which the random initializers consume the global RNG.
        self.dropout = nn.Dropout(dropout)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        bias_head_dim = relation_embedding_dim if relation_embedding_dim is not None else self.head_dim // 4
        self.bias_head_dim = max(bias_head_dim, 1)

        self.root_interval_embed = nn.Embedding(num_root_intervals, self.nhead * self.bias_head_dim)
        num_quality_rels = num_qualities * num_qualities
        self.quality_rel_embed = nn.Embedding(num_quality_rels, self.nhead * self.bias_head_dim)
        num_function_rels = num_functions * num_functions
        self.function_rel_embed = nn.Embedding(num_function_rels, self.nhead * self.bias_head_dim)

        self.num_feature_embeddings = 3

        self.bias_combiner = nn.Sequential(
            nn.Linear(self.num_feature_embeddings * self.bias_head_dim, self.bias_head_dim),
            nn.ReLU(),
            nn.Linear(self.bias_head_dim, 1)
        )

        self.local_logger.info(f"Initialized Advanced Theory H-RPE (Embeddings + Combiner)")
        self.local_logger.info(f" HRPE Params: #Qual={num_qualities}, #Func={num_functions}, RootInt={num_root_intervals}, BiasHeadDim={self.bias_head_dim}")
        self.local_logger.info(f" QualityRelEmbed size: {num_quality_rels}, FunctionRelEmbed size: {num_function_rels}")

    def _compute_advanced_hrpe_bias(self, query_root_pc, key_root_pc, query_quality, key_quality, query_function, key_function, query_token_ids_for_padding, key_token_ids_for_padding):
        """Return the H-RPE bias with shape ``(batch, nhead, seq_len_q, seq_len_k)``.

        All inputs are expected to be ``(batch, seq_len)`` integer tensors (the
        ``unsqueeze``/``view`` calls below rely on that). Pairs where either
        position holds ``padding_idx`` receive a bias of exactly 0.
        """
        batch_size, seq_len_q = query_root_pc.shape
        seq_len_k = key_root_pc.shape[1]

        # Pitch classes live in Z/12, so intervals are taken mod 12; indices that
        # do not fit the root table are clamped below.
        rel_root_interval = (query_root_pc.unsqueeze(2) - key_root_pc.unsqueeze(1) + 12) % 12
        # Ordered pairs (a, b) are flattened to a * N + b to index the N*N tables.
        rel_quality_code = query_quality.unsqueeze(2) * self.num_qualities + key_quality.unsqueeze(1)
        rel_function_code = query_function.unsqueeze(2) * self.num_functions + key_function.unsqueeze(1)

        relation_lookups = (
            (self.root_interval_embed, rel_root_interval),
            (self.quality_rel_embed, rel_quality_code),
            (self.function_rel_embed, rel_function_code),
        )
        per_head_vectors = [
            table(index.long().clamp(0, table.num_embeddings - 1)).view(
                batch_size, seq_len_q, seq_len_k, self.nhead, self.bias_head_dim)
            for table, index in relation_lookups
        ]

        if self.bias_combiner is not None:
            combined_features = torch.cat(per_head_vectors, dim=-1)
            combined_flat = combined_features.view(-1, self.num_feature_embeddings * self.bias_head_dim)
            total_bias = self.bias_combiner(combined_flat).view(batch_size, seq_len_q, seq_len_k, self.nhead)
        else:
            self.local_logger.warning("Bias combiner not found, using simple sum.")
            total_bias = torch.stack(per_head_vectors, dim=0).sum(dim=0).sum(dim=-1)

        total_bias = total_bias.permute(0, 3, 1, 2)  # -> (batch, nhead, seq_q, seq_k)

        if self.padding_idx is not None:
            query_padding_mask = (query_token_ids_for_padding == self.padding_idx)
            key_padding_mask = (key_token_ids_for_padding == self.padding_idx)
            combined_pad_mask = (query_padding_mask.unsqueeze(2) | key_padding_mask.unsqueeze(1))
            total_bias = total_bias.masked_fill(combined_pad_mask.unsqueeze(1).expand_as(total_bias), 0.0)

        return total_bias

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, root_pc=None, quality_code=None, function_code=None, query_tokens=None, key_tokens=None):
        """Self-attention with H-RPE bias.

        ``query``/``key``/``value`` are ``(batch, seq_len, d_model)``. The
        per-token features (``root_pc``, ``quality_code``, ``function_code``)
        and the token IDs (``query_tokens``/``key_tokens``) are used for both
        sides of the relation, so this layer assumes self-attention.

        Masks use ``True`` to mean "do not attend":

        * ``attn_mask`` may be 2-D ``(Sq, Sk)``, 3-D ``(B, Sq, Sk)`` or
          4-D ``(B, H, Sq, Sk)``;
        * ``key_padding_mask`` is ``(B, Sk)``.

        Returns a ``(batch, seq_len_q, d_model)`` tensor.

        Raises:
            ValueError: On missing features/tokens, unsupported mask shapes, or
                a bias/score shape mismatch.
        """
        if query_tokens is None or key_tokens is None: raise ValueError("Missing query_tokens/key_tokens required for H-RPE padding mask calculation.")
        for feat, name in ((root_pc, "root_pc"), (quality_code, "quality_code"), (function_code, "function_code")):
            if feat is None:
                raise ValueError(f"Missing required auxiliary feature for Advanced H-RPE: {name}")

        batch_size, seq_len_q, _ = query.shape
        seq_len_k = key.shape[1]
        device = query.device

        q = self.q_proj(query).view(batch_size, seq_len_q, self.nhead, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, seq_len_k, self.nhead, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, seq_len_k, self.nhead, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        hrpe_bias = self._compute_advanced_hrpe_bias(
            query_root_pc=root_pc, key_root_pc=root_pc,
            query_quality=quality_code, key_quality=quality_code,
            query_function=function_code, key_function=function_code,
            query_token_ids_for_padding=query_tokens,
            key_token_ids_for_padding=key_tokens
        )
        if hrpe_bias.shape != attn_scores.shape:
            self.local_logger.error(f"Shape mismatch! Attn: {attn_scores.shape}, Bias: {hrpe_bias.shape}. Check feature tensor shapes.")
            raise ValueError("Attention scores and H-RPE bias shape mismatch.")
        attn_scores = attn_scores + hrpe_bias.type_as(attn_scores)

        # Use the most negative finite value instead of -inf. A row whose keys
        # are all masked (e.g. an all-padding sequence) then softmaxes to a
        # uniform distribution instead of NaN. With -inf, NaN would reach every
        # position of later layers via 0 * NaN in the weights @ values product.
        mask_value = torch.finfo(attn_scores.dtype).min

        if attn_mask is not None:
            attn_mask_bool = attn_mask.to(device=device, dtype=torch.bool)
            if attn_mask_bool.dim() == 2:
                attn_mask_expanded = attn_mask_bool.unsqueeze(0).unsqueeze(0).expand_as(attn_scores)
            elif attn_mask_bool.dim() == 3 and attn_mask_bool.shape[0] == batch_size:
                attn_mask_expanded = attn_mask_bool.unsqueeze(1).expand_as(attn_scores)
            elif attn_mask_bool.shape == attn_scores.shape:
                attn_mask_expanded = attn_mask_bool
            else:
                raise ValueError(f"Unsupported attn_mask shape {attn_mask.shape}, expected 2D, 3D(B,Sq,Sk), or 4D(B,H,Sq,Sk)")
            attn_scores = attn_scores.masked_fill(attn_mask_expanded, mask_value)

        if key_padding_mask is not None:
            key_padding_mask_bool = key_padding_mask.to(device=device, dtype=torch.bool)
            attn_scores = attn_scores.masked_fill(
                key_padding_mask_bool.unsqueeze(1).unsqueeze(2).expand_as(attn_scores), mask_value)

        # Softmax in float32 for robustness, then cast back to the score dtype.
        attn_weights = F.softmax(attn_scores.to(torch.float32), dim=-1).to(attn_scores.dtype)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        return self.out_proj(output)


class CustomTransformerEncoderLayerAdv(nn.Module):
    """Post-norm Transformer encoder layer whose self-attention is H-RPE aware.

    Data flow: H-RPE self-attention -> dropout -> residual add -> LayerNorm,
    then Linear -> GELU -> dropout -> Linear -> dropout -> residual add -> LayerNorm.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, padding_idx, **hrpe_kwargs):
        super().__init__()
        # NOTE: submodules are created in a fixed order, which determines the
        # order in which their random initializers draw from the global RNG.
        self.self_attn = AdvancedHarmonicRelativeAttention(d_model, nhead, dropout, padding_idx, **hrpe_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, src, src_mask=None, src_key_padding_mask=None, src_tokens=None, src_root_pc=None, src_quality=None, src_function=None):
        """Apply one encoder layer to ``src``.

        ``src_tokens`` (raw token IDs) and the three chord feature tensors are
        mandatory because the attention sub-layer needs them for its bias and
        padding handling.

        Raises:
            ValueError: If ``src_tokens`` or any chord feature is missing.
        """
        if src_tokens is None:
            raise ValueError("src_tokens (original token IDs) are required for H-RPE padding calculation in CustomTransformerEncoderLayerAdv.")
        if any(feature is None for feature in (src_root_pc, src_quality, src_function)):
            raise ValueError("Missing auxiliary features (root_pc, quality, function) required by AdvancedHarmonicRelativeAttention.")

        # Attention sub-layer (self-attention: query = key = value = src).
        attended = self.self_attn(
            src, src, src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
            root_pc=src_root_pc,
            quality_code=src_quality,
            function_code=src_function,
            query_tokens=src_tokens,
            key_tokens=src_tokens,
        )
        hidden = self.norm1(src + self.dropout1(attended))

        # Position-wise feed-forward sub-layer.
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout(expanded))
        return self.norm2(hidden + self.dropout2(projected))


class HarmonyTransformerWithAdvH_RPE(nn.Module):
    """Causally masked Transformer over chord tokens with H-RPE attention layers.

    Token embeddings (scaled by ``sqrt(d_model)``) pass through ``num_layers``
    ``CustomTransformerEncoderLayerAdv`` layers, a final LayerNorm and a linear
    projection to vocabulary logits. No absolute positional encoding is added;
    order information comes from the causal mask and from the harmonic relative
    bias computed from per-token chord features.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout, dim_feedforward, padding_idx: Optional[int], **hrpe_kwargs):
        """Build embeddings, encoder layers and the output head.

        Args:
            vocab_size: Number of token IDs (embedding rows / output logits).
            d_model, nhead, num_layers, dropout, dim_feedforward: Transformer sizes.
            padding_idx: Padding token ID (its embedding row is zeroed); may be ``None``.
            **hrpe_kwargs: Only ``num_root_intervals``, ``num_qualities``,
                ``num_functions`` and ``relation_embedding_dim`` are forwarded
                to the attention layers, and only when not ``None``.
        """
        super().__init__()
        # Create the logger first: _init_parameters() may log through it.
        self.local_logger = logging.getLogger(__name__ + ".HarmonyTransformerAdv")
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=self.padding_idx)
        self.dropout_embed = nn.Dropout(dropout)

        valid_hrpe_keys = ('num_root_intervals', 'num_qualities', 'num_functions', 'relation_embedding_dim')
        filtered_hrpe_kwargs = {k: v for k, v in hrpe_kwargs.items() if k in valid_hrpe_keys and v is not None}
        self.transformer_encoder_layers = nn.ModuleList([
            CustomTransformerEncoderLayerAdv(
                d_model, nhead, dim_feedforward, dropout, padding_idx=self.padding_idx, **filtered_hrpe_kwargs
            ) for _ in range(num_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

        self._init_parameters()
        self.local_logger.info(f"HarmonyTransformerWithAdvH_RPE initialized. Vocab: {vocab_size}, Padding ID: {self.padding_idx}")
        self.local_logger.info(f" H-RPE Kwargs passed to layers: {filtered_hrpe_kwargs}")

    def _init_parameters(self):
        """Re-initialise all weights.

        Every parameter with more than one dimension first gets Xavier-uniform.
        Then Linear biases are zeroed, the token embedding is redrawn from
        N(0, 0.02) with its padding row zeroed, and LayerNorms are reset to
        identity. The H-RPE relation embeddings keep their Xavier init.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding) and name == 'token_embedding':
                nn.init.normal_(module.weight, mean=0, std=0.02)
                if module.padding_idx is not None:
                    if 0 <= module.padding_idx < module.num_embeddings:
                        with torch.no_grad():
                            module.weight[module.padding_idx].fill_(0)
                    else:
                        self.local_logger.warning(f"Invalid padding_idx ({module.padding_idx}) for embedding size ({module.num_embeddings}). Cannot zero padding.")
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _generate_square_subsequent_mask(self, sz: int, device: torch.device) -> torch.Tensor:
        """Return an ``(sz, sz)`` bool mask that is ``True`` strictly above the diagonal (future positions)."""
        return torch.triu(torch.ones((sz, sz), device=device), diagonal=1).bool()

    def _log_out_of_range_tokens(self, src_tokens):
        """Debug-log token IDs (padding excluded) that fall outside ``[0, vocab_size)``."""
        excluding_padding = self.padding_idx is not None
        checked = src_tokens[src_tokens != self.padding_idx] if excluding_padding else src_tokens
        if checked.numel() == 0:
            return
        low, high = checked.min().item(), checked.max().item()
        if low < 0 or high >= self.vocab_size:
            prefix = "Token IDs outside valid range detected (excluding padding)." if excluding_padding else "Token IDs outside valid range detected."
            self.local_logger.debug(f"{prefix} Range: [{low}, {high}], Vocab Size: {self.vocab_size}.")

    def forward(self, src_tokens, root_pc=None, quality_code=None, function_code=None, src_mask=None, src_key_padding_mask=None):
        """Compute next-token logits for every position.

        Args:
            src_tokens: ``(batch, seq_len)`` token IDs.
            root_pc, quality_code, function_code: Per-token chord features with
                the same shape as ``src_tokens`` (required).
            src_mask: Attention mask (``True`` = blocked); defaults to causal.
            src_key_padding_mask: ``(batch, seq_len)`` padding mask; derived
                from ``padding_idx`` when omitted.

        Returns:
            ``(batch, seq_len, vocab_size)`` logits.

        Raises:
            ValueError: If a feature is missing or its shape differs from ``src_tokens``.
        """
        # The range check is diagnostic only. Guarding it avoids forcing a
        # device-to-host transfer (and, on XLA, a graph execution with dynamic
        # shapes from boolean indexing) on every forward pass.
        if self.local_logger.isEnabledFor(logging.DEBUG):
            self._log_out_of_range_tokens(src_tokens)

        if root_pc is None: raise ValueError("Missing required feature: root_pc")
        if quality_code is None: raise ValueError("Missing required feature: quality_code")
        if function_code is None: raise ValueError("Missing required feature: function_code")
        expected_shape = src_tokens.shape
        if root_pc.shape != expected_shape: raise ValueError(f"Shape mismatch: root_pc {root_pc.shape} vs src_tokens {expected_shape}")
        if quality_code.shape != expected_shape: raise ValueError(f"Shape mismatch: quality_code {quality_code.shape} vs src_tokens {expected_shape}")
        if function_code.shape != expected_shape: raise ValueError(f"Shape mismatch: function_code {function_code.shape} vs src_tokens {expected_shape}")

        seq_len = src_tokens.shape[1]
        device = src_tokens.device

        if src_mask is None:
            src_mask = self._generate_square_subsequent_mask(seq_len, device)
        elif src_mask.device != device:
            src_mask = src_mask.to(device)

        if src_key_padding_mask is None and self.padding_idx is not None:
            src_key_padding_mask = (src_tokens == self.padding_idx)
        elif src_key_padding_mask is not None and src_key_padding_mask.device != device:
            src_key_padding_mask = src_key_padding_mask.to(device)
        if src_key_padding_mask is not None:
            src_key_padding_mask = src_key_padding_mask.bool()

        output = self.dropout_embed(self.token_embedding(src_tokens) * math.sqrt(self.d_model))
        for layer in self.transformer_encoder_layers:
            output = layer(
                src=output,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                src_tokens=src_tokens,
                src_root_pc=root_pc,
                src_quality=quality_code,
                src_function=function_code
            )

        if self.encoder_norm is not None:
            output = self.encoder_norm(output)

        return self.fc_out(output)

    # --- Generation Method (Needs careful device handling for XLA) ---
    @torch.no_grad()
    def generate(self, start_token_ids, start_features: Dict[str, torch.Tensor], max_length=32, temperature=1.0, top_k=0, top_p=0.9, get_features_for_id_fn=None ):
        """Autoregressively extend ``start_token_ids`` by up to ``max_length`` tokens.

        Args:
            start_token_ids: ``(batch, context_len)`` context IDs, already on the
                target device.
            start_features: Feature tensors for the context, one per key in
                ``HPARAMS['feature_keys']``, on the same device.
            max_length: Maximum number of new tokens.
            temperature: Softmax temperature; ``<= 1e-8`` means greedy argmax.
            top_k: Keep only the ``k`` highest logits when ``> 0`` (sampling only).
            top_p: Nucleus threshold, applied when ``0 < top_p < 1`` (sampling only).
            get_features_for_id_fn: Called with a 1-D CPU tensor of new token
                IDs; must return a dict holding a ``(batch, 1)`` tensor per feature key.

        Returns:
            A ``(batch, n_new)`` long tensor holding only the newly generated IDs.
            If a forward pass fails, the tokens generated so far are returned.
            If start features or a feature lookup are unusable, the result is
            empty (``n_new == 0``).

        Raises:
            ValueError: If ``get_features_for_id_fn`` is ``None``.

        The model is put in eval mode while generating and restored to its
        previous train/eval mode afterwards.
        """
        if get_features_for_id_fn is None:
            raise ValueError("`get_features_for_id_fn` is required for autoregressive generation with features.")
        was_training = self.training
        self.eval()
        try:
            return self._generate_tokens(start_token_ids, start_features, max_length, temperature, top_k, top_p, get_features_for_id_fn)
        finally:
            self.train(was_training)

    def _generate_tokens(self, start_token_ids, start_features, max_length, temperature, top_k, top_p, get_features_for_id_fn):
        """Run the decoding loop for :meth:`generate` (model already in eval mode)."""
        device = start_token_ids.device
        batch_size = start_token_ids.shape[0]

        expected_feature_keys = HPARAMS.get('feature_keys', [])
        if not expected_feature_keys:
            self.local_logger.warning("`HPARAMS['feature_keys']` not found, cannot prepare features for generation.")
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)
        missing_keys = [k for k in expected_feature_keys if k not in start_features]
        if missing_keys:
            self.local_logger.error(f"Generation start features missing required key: '{missing_keys[0]}'")
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)

        current_features = {k: start_features[k] for k in expected_feature_keys}
        current_token_ids = start_token_ids
        generated_ids_list = []

        for step in range(max_length):
            causal_mask = self._generate_square_subsequent_mask(current_token_ids.shape[1], device)
            padding_mask = (current_token_ids == self.padding_idx) if self.padding_idx is not None else None
            try:
                logits = self(
                    src_tokens=current_token_ids,
                    src_mask=causal_mask,
                    src_key_padding_mask=padding_mask,
                    root_pc=current_features.get('root_pc'),
                    quality_code=current_features.get('quality_code'),
                    function_code=current_features.get('function_code')
                )
            except Exception as e:
                self.local_logger.error(f"Error during generation forward pass (step {step+1}): {e}", exc_info=True)
                break

            next_token_id = self._sample_next_token(logits[:, -1, :], temperature, top_k, top_p, step)
            generated_ids_list.append(next_token_id)
            current_token_ids = torch.cat([current_token_ids, next_token_id], dim=1)

            if not self._append_generated_features(current_features, next_token_id, get_features_for_id_fn, device, step):
                generated_ids_list = []
                break

        if generated_ids_list:
            return torch.cat(generated_ids_list, dim=1)
        return torch.empty((batch_size, 0), dtype=torch.long, device=device)

    def _sample_next_token(self, next_token_logits, temperature, top_k, top_p, step):
        """Choose the next token per sequence from ``(batch, vocab)`` logits; returns ``(batch, 1)``.

        Greedy when ``temperature <= 1e-8``; otherwise temperature scaling, then
        optional top-k and nucleus (top-p) filtering, then multinomial sampling.
        Falls back to argmax if the probabilities collapse or sampling fails.
        """
        if temperature <= 1e-8:
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)

        logits = next_token_logits / temperature
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k, dim=-1).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, -float('Inf'))
        if 0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cum_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
                dim=-1, index=sorted_indices, src=sorted_to_remove
            )
            logits = logits.masked_fill(to_remove, -float('Inf'))

        probs = torch.nan_to_num(F.softmax(logits.float(), dim=-1), nan=0.0)
        if torch.sum(probs, dim=-1).min() < 1e-6:
            self.local_logger.warning(f"Generation probabilities collapsed at step {step+1}. Using argmax.")
            return torch.argmax(logits, dim=-1, keepdim=True)
        try:
            return torch.multinomial(probs, num_samples=1)
        except RuntimeError as e:
            self.local_logger.warning(f"Multinomial sampling failed at step {step+1}: {e}. Using argmax.")
            return torch.argmax(logits, dim=-1, keepdim=True)

    def _append_generated_features(self, current_features, next_token_id, get_features_for_id_fn, device, step):
        """Append the features of ``next_token_id`` to ``current_features`` in place.

        Returns ``False`` (after logging an error) if the lookup raises, omits
        a key, or returns a tensor that is not ``(batch, 1)``; ``True`` otherwise.
        """
        try:
            # The lookup works on CPU tensors; results are moved back to `device`.
            new_features_cpu_dict = get_features_for_id_fn(next_token_id.squeeze(-1).cpu())
        except Exception as e:
            self.local_logger.error(f"Error calling get_features_for_id_fn (step {step+1}): {e}", exc_info=True)
            return False

        for key in list(current_features):
            if key not in new_features_cpu_dict:
                self.local_logger.error(f"Feature key '{key}' missing from get_features_for_id_fn output (step {step+1}). Stopping generation.")
                return False
            try:
                new_feature_tensor = new_features_cpu_dict[key].to(device)
            except Exception as e:
                self.local_logger.error(f"Error processing feature '{key}' from get_features_for_id_fn (step {step+1}): {e}", exc_info=True)
                return False
            expected_batch = current_features[key].shape[0]
            if new_feature_tensor.dim() != 2 or new_feature_tensor.shape[0] != expected_batch or new_feature_tensor.shape[1] != 1:
                self.local_logger.error(f"New feature '{key}' has incorrect shape {new_feature_tensor.shape}. Expected ({expected_batch}, 1) at step {step+1}. Stopping generation.")
                return False
            current_features[key] = torch.cat([current_features[key], new_feature_tensor], dim=1)
        return True


# ==============================================================
# === HELPER & TRAINING/EVAL LOOPS (Modified for XLA) ===
# ==============================================================

class WarmUpLR(torch.optim.lr_scheduler.LambdaLR):
    """Linear learning-rate warm-up implemented as a ``LambdaLR`` multiplier.

    During the first ``warm_up_steps`` scheduler steps the multiplier grows linearly
    from ``1 / warm_up_steps`` up to 1.0; afterwards (or when ``warm_up_steps <= 0``)
    it stays at 1.0.
    """

    def __init__(self, optimizer, warm_up_steps, last_epoch=-1):
        # Must be set before super().__init__, which evaluates the lambda once.
        self.warm_up_steps = 0 if warm_up_steps <= 0 else warm_up_steps
        super().__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        """Return the LR multiplier for scheduler step ``step`` (negative steps count as 0)."""
        warm = self.warm_up_steps
        clamped_step = max(0, step)
        in_warm_up = warm != 0 and clamped_step < warm
        return float(clamped_step + 1) / float(warm) if in_warm_up else 1.0

def save_checkpoint(state: Dict[str, Any], is_best: bool, best_model_path: Optional[str], checkpoint_dir: Path, latest_filename: str, best_prefix: str):
    """Persist ``state`` as the latest checkpoint and, when applicable, a best/periodic one.

    Files written under ``checkpoint_dir`` (save errors are logged, never raised):
      * ``latest_filename`` on every call.
      * ``{best_prefix}{epoch+1}_{metric_name}_{value}.pth`` when ``is_best``. If that
        succeeds, the file at ``best_model_path`` is deleted when it exists and differs
        from the new file.
      * ``checkpoint_ep{epoch+1}.pth`` every ``HPARAMS['save_every_n_epochs']`` epochs,
        but only on epochs that are not a new best.

    Metadata is read from ``state``: ``epoch`` (default 0), ``metric_name`` (default
    ``'loss'``) and ``best_metric_value`` (default ``inf``).

    NOTE(review): this is meant to be called by the master process only, but ``xm.save``
    may itself rendezvous across replicas in some torch_xla versions, which would hang if
    the other processes never call it. Confirm against the installed torch_xla and the
    caller.

    Returns:
        Path (str) of the current best checkpoint: the newly written file, or the incoming
        ``best_model_path`` unchanged when no new best was saved.
    """
    local_logger = logging.getLogger(__name__ + ".save_checkpoint")
    try:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        local_logger.error(f"Could not create checkpoint directory {checkpoint_dir}: {e}")
        return best_model_path

    latest_filepath = checkpoint_dir / latest_filename
    epoch_num = state.get('epoch', 0)
    metric_name = state.get('metric_name', 'loss')
    metric_value = state.get('best_metric_value', float('inf'))

    # Coerce to a plain float so ints and 0-dim tensors format correctly; anything that
    # cannot be converted is treated as NaN. ':.4f' renders non-finite floats as
    # 'inf' / '-inf' / 'nan', so -inf is no longer mislabelled as 'nan'.
    try:
        metric_float = float(metric_value)
    except (TypeError, ValueError, RuntimeError):
        metric_float = float('nan')
    metric_value_str = f"{metric_float:.4f}".replace('.', '_')

    try:
        xm.save(state, latest_filepath, master_only=True) # Ensure master_only safety
    except Exception as e:
        local_logger.error(f"Error saving latest checkpoint to {latest_filepath}: {e}", exc_info=True)

    if is_best:
        best_filename = f"{best_prefix}{epoch_num+1}_{metric_name.lower()}_{metric_value_str}.pth"
        current_best_filepath = checkpoint_dir / best_filename
        try:
            xm.save(state, current_best_filepath, master_only=True) # Ensure master_only safety
            local_logger.info(f"Saved **best** checkpoint: {current_best_filepath} ({metric_name}={metric_float:.4f})")

            # Keep only one "best" file on disk.
            if best_model_path and Path(best_model_path).resolve() != current_best_filepath.resolve():
                previous_best = Path(best_model_path)
                if previous_best.exists() and previous_best.is_file():
                    try:
                        previous_best.unlink()
                        local_logger.info(f"Removed previous best checkpoint: {best_model_path}")
                    except OSError as e:
                        local_logger.warning(f"Could not remove previous best checkpoint {best_model_path}: {e}")
            return str(current_best_filepath)
        except Exception as e:
            local_logger.error(f"Error saving best checkpoint to {current_best_filepath}: {e}", exc_info=True)
            return best_model_path

    # Not a new best: optionally write a periodic snapshot.
    save_freq = HPARAMS.get('save_every_n_epochs', 0)
    if save_freq > 0 and (epoch_num + 1) % save_freq == 0:
        periodic_filepath = checkpoint_dir / f"checkpoint_ep{epoch_num+1}.pth"
        try:
            xm.save(state, periodic_filepath, master_only=True) # Ensure master_only safety
            local_logger.info(f"Saved periodic checkpoint: {periodic_filepath}")
        except Exception as e:
            local_logger.error(f"Error saving periodic checkpoint to {periodic_filepath}: {e}", exc_info=True)
    return best_model_path


def load_checkpoint(checkpoint_path: Optional[str], model: nn.Module, optimizer: Optional[torch.optim.Optimizer]=None, scheduler: Optional[Any]=None) -> Tuple[int, float, Optional[str]]:
    """Restore model weights and resume metadata from ``checkpoint_path``.

    The file is loaded onto CPU. A leading ``'module.'`` prefix is stripped from the
    state-dict keys when every key has it. The weights are loaded into ``model``, which
    is then moved to the XLA device. Optimizer and scheduler state are NOT restored;
    a warning is logged if they are present.

    Args:
        checkpoint_path: Path to a checkpoint file, or None/empty to start from scratch.
        model: Module to load ``checkpoint['model_state_dict']`` into.
        optimizer: Only used to decide whether to warn; never modified.
        scheduler: Only used to decide whether to warn; never modified.

    Returns:
        ``(start_epoch, best_metric_value, best_model_path)``. The defaults are
        ``(0, inf, None)`` when there is no checkpoint or loading fails. Note that if a
        failure happens after ``load_state_dict`` succeeded, the model keeps the loaded
        weights while the returned metadata is reset to these defaults.
    """
    local_logger = logging.getLogger(__name__ + ".load_checkpoint")
    start_epoch = 0
    best_metric_value = float('inf')
    best_model_path = None
    device = xm.xla_device() # Get the target device

    checkpoint_path_obj = Path(checkpoint_path) if checkpoint_path else None

    if checkpoint_path_obj and checkpoint_path_obj.is_file():
        local_logger.info(f"Attempting to load checkpoint: '{checkpoint_path_obj}'")
        try:
            checkpoint = torch.load(checkpoint_path_obj, map_location='cpu')

            if 'model_state_dict' not in checkpoint:
                raise KeyError("Checkpoint missing 'model_state_dict'")

            state_dict = checkpoint['model_state_dict']
            # Checkpoints saved from a wrapped (e.g. DataParallel) model carry a 'module.' prefix.
            if all(key.startswith('module.') for key in state_dict.keys()):
                local_logger.info("Removing 'module.' prefix from state_dict keys.")
                state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}

            model.load_state_dict(state_dict)
            model.to(device) # Move model to XLA device AFTER loading state
            local_logger.info("Model state loaded successfully and moved to XLA device.")

            if optimizer and 'optimizer_state_dict' in checkpoint:
                local_logger.warning("Loading optimizer state from checkpoint is not fully implemented for XLA in this script. Optimizer will be re-initialized.")
            if scheduler and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
                local_logger.warning("Loading scheduler state from checkpoint is not fully implemented for XLA in this script. Scheduler will be re-initialized.")

            start_epoch = checkpoint.get('epoch', -1) + 1
            best_metric_value = checkpoint.get('best_metric_value', float('inf'))
            best_model_path = checkpoint.get('best_model_path', None) # Path stored in the checkpoint
            metric_name = checkpoint.get('metric_name', 'loss')

            local_logger.info(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")
            # The stored value may be None or non-numeric. A formatting error here must not
            # reach the except branch below, which would discard the resume metadata even
            # though the weights were already loaded.
            try:
                best_metric_str = f"{best_metric_value:.4f}"
            except (TypeError, ValueError):
                best_metric_str = str(best_metric_value)
            local_logger.info(f" Previous best '{metric_name}': {best_metric_str}")

            if best_model_path and not Path(best_model_path).exists():
                local_logger.warning(f"Stored best model path '{best_model_path}' in checkpoint is invalid or file missing.")

        except FileNotFoundError:
            local_logger.error(f"Checkpoint file not found at '{checkpoint_path_obj}'. Starting from scratch.")
        except Exception as e:
            local_logger.error(f"Failed to load checkpoint from '{checkpoint_path_obj}': {e}", exc_info=True)
            start_epoch = 0; best_metric_value = float('inf'); best_model_path = None
    else:
        if checkpoint_path:
            local_logger.warning(f"Checkpoint path specified ('{checkpoint_path}') but not found or invalid. Starting from scratch.")
        else:
            local_logger.info("No checkpoint specified. Starting from scratch.")

    return start_epoch, best_metric_value, best_model_path


def train_epoch(rank, model, dataloader, optimizer, criterion, scheduler, device, epoch_num, num_epochs, padding_idx):
    """Run one training epoch on this XLA core and return metrics reduced across cores.

    Each batch is expected to unpack as ``(input_ids, target_ids, features_dict, _)``.
    The logits at the last position (``output_logits[:, -1, :]``) are scored against the
    single next-token target with ``criterion``. Out-of-range targets are clamped into
    ``[0, vocab_size - 1]`` and logged.

    Args:
        rank: Process index; used only in log messages and the progress-bar label.
        model: Called as ``model(src_tokens=..., src_key_padding_mask=..., root_pc=...,
            quality_code=..., function_code=...)``.
        dataloader: Wrapped in ``pl.ParallelLoader`` so batches arrive on ``device``.
        optimizer: Stepped with ``xm.optimizer_step`` after gradient clipping to
            ``HPARAMS['max_grad_norm']``.
        criterion: Loss applied to ``(logits.float(), targets)``.
        scheduler: Stepped once per successful batch; may be None.
        device: XLA device for this process.
        epoch_num, num_epochs: Used only in the progress-bar description.
        padding_idx: Token id masked out via ``src_key_padding_mask``; None disables the mask.

    Returns:
        ``(avg_loss, accuracy)`` over all cores (via ``xm.mesh_reduce``). Returns
        ``(inf, 0.0)`` when no items were processed on any core.

    NOTE(review): a batch skipped via ``continue`` on one core skips
    ``xm.optimizer_step`` (which, presumably, all-reduces gradients) only on that core.
    This may desynchronize replicas; confirm whether that can happen in practice.
    """
    local_logger = logging.getLogger(__name__ + ".train_epoch")
    model.train()
    # Running sums stay on the XLA device until the cross-core reduction at the end.
    total_loss = torch.tensor(0.0, device=device)
    total_correct = torch.tensor(0.0, device=device)
    total_items = torch.tensor(0.0, device=device)

    # `vocab_size` is the module-level fallback when the model does not expose one.
    current_vocab_size = model.vocab_size if hasattr(model, 'vocab_size') else vocab_size

    para_loader = pl.ParallelLoader(dataloader, [device])
    data_iterator = para_loader.per_device_loader(device)

    # Progress bar only shown on master core
    progress_bar = tqdm(data_iterator,
                        desc=f"Train Epoch {epoch_num}/{num_epochs} [Core {rank}]",
                        unit="batch",
                        leave=False, # Keep leave=False so it disappears after the loop
                        dynamic_ncols=True,
                        disable=not xm.is_master_ordinal(),
                        mininterval=5.0 # Update at most every 5 seconds
                        )

    batch_count = 0
    for batch_data in progress_bar:
        batch_count += 1
        if batch_data is None:
            local_logger.warning(f"Skipping empty batch {batch_count} from ParallelLoader.")
            continue

        try:
            input_ids, target_ids_single, features_dict, _ = batch_data
            if not isinstance(features_dict, dict) or not features_dict:
                local_logger.error(f"Features dictionary is empty or not a dict in train batch {batch_count}. Skipping.")
                continue
        except (ValueError, TypeError) as e:
            local_logger.error(f"Skipping malformed train batch {batch_count}. Error: {e}")
            continue

        if input_ids.numel() == 0 or target_ids_single.numel() == 0:
            local_logger.warning(f"Skipping batch {batch_count} with empty input or target tensors.")
            continue

        features_dict_device = features_dict # Assume keys are already tensors on device

        # Forward Pass
        optimizer.zero_grad()
        src_key_padding_mask = (input_ids == padding_idx) if padding_idx is not None else None

        try:
            output_logits = model(
                src_tokens=input_ids,
                src_key_padding_mask=src_key_padding_mask,
                root_pc=features_dict_device.get('root_pc'),
                quality_code=features_dict_device.get('quality_code'),
                function_code=features_dict_device.get('function_code')
            )
            logits_for_next_token = output_logits[:, -1, :]

            if torch.any((target_ids_single < 0) | (target_ids_single >= current_vocab_size)):
                invalid_targets = target_ids_single[(target_ids_single < 0) | (target_ids_single >= current_vocab_size)]
                local_logger.error(f"Invalid target IDs found in Train Batch {batch_count} on Core {rank}! Values: {invalid_targets.unique().tolist()}. Vocab size: {current_vocab_size}. Clamping targets.")
                target_ids_safe = torch.clamp(target_ids_single, 0, current_vocab_size - 1).long()
            else:
                target_ids_safe = target_ids_single.long()

            loss = criterion(logits_for_next_token.float(), target_ids_safe)

        except Exception as e:
            local_logger.error(f"Error during forward/loss calculation in train batch {batch_count} on Core {rank}: {e}", exc_info=True)
            local_logger.error(f" Shapes: Input {input_ids.shape}, Target {target_ids_single.shape}. Feature keys: {list(features_dict_device.keys())}")
            optimizer.zero_grad() # Ensure grads are zeroed
            continue

        # Backward Pass & Optimization (XLA specific)
        if not torch.isfinite(loss):
            local_logger.error(f"NaN or Inf loss detected in Train Batch {batch_count} on Core {rank}! Loss: {loss.item()}. LR: {optimizer.param_groups[0]['lr']:.2e}. Skipping backward pass.")
            optimizer.zero_grad()
            continue

        try:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), HPARAMS['max_grad_norm'])
            xm.optimizer_step(optimizer)
            if scheduler:
                scheduler.step()
        except Exception as e:
            local_logger.error(f"Error during backward pass or optimizer step in train batch {batch_count} on Core {rank}: {e}", exc_info=True)
            optimizer.zero_grad()
            continue

        # Calculate Metrics (Local)
        with torch.no_grad():
            pred_next_token = logits_for_next_token.argmax(dim=-1)
            batch_correct = (pred_next_token == target_ids_safe).sum() # Keep as tensor
            # numel() is a host-side int. Keeping it off-device avoids the forced
            # device->host sync that branching on an XLA tensor would cause every step.
            batch_items = target_ids_safe.numel()

            if batch_items > 0:
                total_loss += loss.detach() * batch_items # Accumulate weighted loss tensor
                total_correct += batch_correct
                total_items += batch_items
                # Per-batch postfix updates intentionally omitted to reduce logging noise.
            else:
                local_logger.warning(f"Train batch {batch_count} on Core {rank} had zero valid items.")

    progress_bar.close()

    # Aggregate metrics across all cores
    total_loss_reduced = xm.mesh_reduce('train_total_loss', total_loss, lambda x: torch.sum(torch.stack(x)))
    total_correct_reduced = xm.mesh_reduce('train_total_correct', total_correct, lambda x: torch.sum(torch.stack(x)))
    total_items_reduced = xm.mesh_reduce('train_total_items', total_items, lambda x: torch.sum(torch.stack(x)))

    # Calculate final epoch metrics
    items = total_items_reduced.item()
    final_loss = total_loss_reduced.item() / items if items > 0 else float('inf')
    final_acc = total_correct_reduced.item() / items if items > 0 else 0.0

    if not math.isfinite(final_loss):
        local_logger.warning(f"Non-finite average training loss calculated across cores for epoch {epoch_num}: {final_loss}")

    return final_loss, final_acc


@torch.no_grad()
def evaluate_epoch(rank, model, dataloader, criterion, device, padding_idx):
    """Evaluate next-token prediction on this XLA core and reduce the metrics across cores.

    Batches are expected to unpack as ``(input_ids, target_ids, features_dict, _)``. The
    logits at the last position are scored against the single next-token target.
    Out-of-range targets are clamped into ``[0, vocab_size - 1]`` and logged. Batches
    whose loss is non-finite are excluded from the metrics.

    Args:
        rank: Process index; used only in log messages and the progress-bar label.
        model: Called with the same keyword arguments as in ``train_epoch``.
        dataloader: Wrapped in ``pl.ParallelLoader`` so batches arrive on ``device``.
        criterion: Loss applied to ``(logits.float(), targets)``.
        device: XLA device for this process.
        padding_idx: Token id masked out via ``src_key_padding_mask``; None disables the mask.

    Returns:
        ``(avg_loss, accuracy, perplexity)`` over all cores. ``perplexity`` is
        ``exp(avg_loss)``, or ``inf`` on overflow or a non-finite/negative loss.
        Returns ``(inf, 0.0, inf)`` when no items were processed.
    """
    local_logger = logging.getLogger(__name__ + ".evaluate_epoch")
    model.eval()
    total_loss = torch.tensor(0.0, device=device)
    total_correct = torch.tensor(0.0, device=device)
    total_items = torch.tensor(0.0, device=device)

    # `vocab_size` is the module-level fallback when the model does not expose one.
    current_vocab_size = model.vocab_size if hasattr(model, 'vocab_size') else vocab_size

    para_loader = pl.ParallelLoader(dataloader, [device])
    data_iterator = para_loader.per_device_loader(device)

    progress_bar = tqdm(data_iterator,
                        desc=f"Evaluate [Core {rank}]",
                        unit="batch",
                        leave=False, # Keep leave=False
                        dynamic_ncols=True,
                        disable=not xm.is_master_ordinal(),
                        mininterval=5.0 # Update at most every 5 seconds
                        )

    batch_count = 0
    for batch_data in progress_bar:
        batch_count += 1
        if batch_data is None: continue
        try:
            input_ids, target_ids_single, features_dict, _ = batch_data
            if not isinstance(features_dict, dict) or not features_dict:
                local_logger.error(f"Features dictionary is empty or not a dict in eval batch {batch_count} on Core {rank}. Skipping.")
                continue
        except (ValueError, TypeError) as e:
            local_logger.error(f"Skipping malformed eval batch {batch_count} on Core {rank}. Error: {e}")
            continue

        if input_ids.numel() == 0 or target_ids_single.numel() == 0: continue

        features_dict_device = features_dict

        # Forward Pass
        src_key_padding_mask = (input_ids == padding_idx) if padding_idx is not None else None

        try:
            output_logits = model(
                src_tokens=input_ids,
                src_key_padding_mask=src_key_padding_mask,
                root_pc=features_dict_device.get('root_pc'),
                quality_code=features_dict_device.get('quality_code'),
                function_code=features_dict_device.get('function_code')
            )
            logits_for_next_token = output_logits[:, -1, :]

            if torch.any((target_ids_single < 0) | (target_ids_single >= current_vocab_size)):
                invalid_targets = target_ids_single[(target_ids_single < 0) | (target_ids_single >= current_vocab_size)]
                local_logger.error(f"Invalid target IDs found in Eval Batch {batch_count} on Core {rank}! Values: {invalid_targets.unique().tolist()}. Vocab size: {current_vocab_size}. Clamping targets.")
                target_ids_safe = torch.clamp(target_ids_single, 0, current_vocab_size - 1).long()
            else:
                target_ids_safe = target_ids_single.long()

            loss = criterion(logits_for_next_token.float(), target_ids_safe)

        except Exception as e:
            local_logger.error(f"Error during forward/loss calculation in eval batch {batch_count} on Core {rank}: {e}", exc_info=True)
            continue

        # Calculate Metrics
        if torch.isfinite(loss):
            pred_next_token = logits_for_next_token.argmax(dim=-1)
            batch_correct = (pred_next_token == target_ids_safe).sum()
            # Host-side int: avoids a forced device->host sync when branching on it.
            batch_items = target_ids_safe.numel()

            if batch_items > 0:
                total_loss += loss.detach() * batch_items
                total_correct += batch_correct
                total_items += batch_items
                # Per-batch postfix updates intentionally omitted to reduce logging noise.
            else:
                local_logger.warning(f"Eval batch {batch_count} on Core {rank} had zero valid items.")
        else:
            local_logger.warning(f"NaN or Inf loss detected in Eval Batch {batch_count} on Core {rank}! Loss: {loss.item()}. Skipping metrics.")

    progress_bar.close()

    # Aggregate metrics across all cores
    total_loss_reduced = xm.mesh_reduce('eval_total_loss', total_loss, lambda x: torch.sum(torch.stack(x)))
    total_correct_reduced = xm.mesh_reduce('eval_total_correct', total_correct, lambda x: torch.sum(torch.stack(x)))
    total_items_reduced = xm.mesh_reduce('eval_total_items', total_items, lambda x: torch.sum(torch.stack(x)))

    items = total_items_reduced.item()
    final_loss = total_loss_reduced.item() / items if items > 0 else float('inf')
    final_acc = total_correct_reduced.item() / items if items > 0 else 0.0
    perplexity = float('inf')

    if math.isfinite(final_loss) and final_loss >= 0:
        try:
            perplexity = math.exp(final_loss)
        except OverflowError:
            local_logger.warning(f"Cannot calculate perplexity due to overflow with loss: {final_loss}")
            perplexity = float('inf')
    elif items > 0:
        local_logger.warning(f"Non-finite average evaluation loss across cores: {final_loss}. Perplexity will be infinite.")

    return final_loss, final_acc, perplexity


# --- SSMD Calculation Helpers (Mostly unchanged, run on CPU) ---
def decode_sequence_strings(token_ids: List[int], id_to_vocab_map: Dict[int, str], padding_id: Optional[int]) -> List[str]:
    """Map token ids to their vocabulary labels, preserving order.

    Ids equal to ``padding_id`` (when it is not None) are dropped, as are ids that have
    no entry in ``id_to_vocab_map``.
    """
    non_padding = (tok for tok in token_ids if padding_id is None or not tok == padding_id)
    looked_up = (id_to_vocab_map.get(tok) for tok in non_padding)
    return [name for name in looked_up if name is not None]

# Example Roman-numeral label -> harmonic function code table used by parse_function_code.
# !!! Needs customization based on your specific vocab/labels (see token_features.json) !!!
# Seventh chords map to the same function as their triad.
_ROMAN_NUMERAL_FUNCTION_CODES = {
    'I': 1,
    'V': 5,
    'IV': 4,
    'ii': 2,
    'vi': 6,
    'iii': 3,
    'vii': 7,  # Often diminished
    'I7': 1,
    'V7': 5,
    'IV7': 4,
}


def parse_function_code(label: Optional[str]) -> Optional[int]:
    """Return the integer function code for a chord ``label``, or None if unrecognised.

    Labels in the example table above are looked up directly. The padding token
    (``PAD_TOKEN``, when ``PADDING_VALUE`` is set) maps to the ``'function_code'``
    stored for ``PADDING_VALUE`` in ``token_to_features_map``, defaulting to 0.
    Anything else, including empty/None labels, yields None.
    """
    if not label:
        logging.debug(f"Could not parse function code for label: {label}")
        return None

    table_code = _ROMAN_NUMERAL_FUNCTION_CODES.get(label)
    if table_code is not None:
        return table_code

    # Module-level globals, read only (set during vocab loading).
    if PADDING_VALUE is not None and label == PAD_TOKEN:
        pad_features = token_to_features_map.get(PADDING_VALUE, {})
        return pad_features.get('function_code', 0)

    logging.debug(f"Could not parse function code for label: {label}")
    return None  # Important: return None if unparseable

def extract_function_codes(decoded_labels: List[str]) -> List[int]:
    """Convert labels to function codes with ``parse_function_code``, dropping unparseable ones."""
    parsed = (parse_function_code(lbl) for lbl in decoded_labels)
    return [fc for fc in parsed if fc is not None]

def calculate_ssm(feature_sequence: List[int], metric='cosine') -> Optional[np.ndarray]:
    """Build a self-similarity matrix (SSM) for a sequence of features.

    Args:
        feature_sequence: Either a 1-D sequence of scalars, each treated as a
            one-dimensional point (the original behaviour), or a 2-D sequence of
            equal-length feature vectors with one row per position. Any other rank is
            flattened to scalars, as before.
        metric: Any ``scipy.spatial.distance.pdist`` metric name. For ``'cosine'`` the
            distances are converted to similarities (``1 - d``). Other metrics return
            plain distances.

    Returns:
        An ``(n, n)`` array with NaNs replaced by 0. Returns None when there are fewer
        than 2 positions, all positions are identical, the variance is ~0 for cosine,
        or ``pdist`` fails.

    Note:
        With ``'cosine'`` and scalar input, any two values of the same sign have
        similarity exactly 1.0 (opposite signs give -1.0, zeros give 0). The resulting
        SSM is therefore nearly uninformative for positive function codes. Pass vector
        features (e.g. one-hot codes) to make cosine meaningful.
    """
    if len(feature_sequence) < 2: return None
    features = np.array(feature_sequence)
    # Keep genuine (n, d) feature matrices; everything else becomes an (n, 1) column.
    if features.ndim != 2:
        features = features.reshape(-1, 1)
    if np.all(features == features[0]):
        logging.debug("SSM calculation skipped: feature sequence is constant.")
        return None
    try:
        if metric == 'cosine':
            if np.std(features) < 1e-9:
                logging.debug("SSM calculation skipped: near-zero variance for cosine.")
                return None
        # pdist requires at least 2 points and >0 dimensions
        if features.shape[0] < 2 or features.shape[1] == 0:
            logging.debug(f"SSM calculation skipped: Insufficient data points or dimensions. Shape: {features.shape}")
            return None
        dist_condensed = pdist(features, metric=metric)
        ssm = squareform(dist_condensed)
        if metric == 'cosine': ssm = 1.0 - ssm
        # Cosine against a zero vector is undefined (NaN); treat it as "no similarity".
        ssm = np.nan_to_num(ssm, nan=0.0)
        return ssm
    except ValueError as ve:
        # Handle cases where pdist fails, e.g., input contains NaN/inf
        logging.warning(f"ValueError calculating SSM with metric '{metric}': {ve}. Input shape: {features.shape}", exc_info=False)
        return None
    except Exception as e:
        logging.error(f"Error calculating SSM with metric '{metric}': {e}", exc_info=False)
        return None

def calculate_mean_ssmd(ssm: Optional[np.ndarray], k: int = 1) -> Optional[float]:
    """Return the NaN-ignoring mean of the ``k``-th super-diagonal of a square SSM.

    Returns None for a missing or non-square matrix, for ``k`` outside ``[0, n)``, or
    when the mean is not finite (e.g. the diagonal is all NaN).
    """
    if not isinstance(ssm, np.ndarray):
        return None
    if ssm.ndim != 2 or ssm.shape[0] != ssm.shape[1]:
        return None
    size = ssm.shape[0]
    if not 0 <= k < size:
        return None
    band = np.diag(ssm, k=k)
    if band.size == 0:
        return None
    avg = np.nanmean(band)
    if not np.isfinite(avg):
        return None
    return float(avg)

# --- End SSMD Helpers ---


# === NEW: Per-Epoch SSMD Evaluation Function ===
@torch.no_grad()
def calculate_ssmd_on_validation_set(model: nn.Module,
                                     val_dataset: Dataset, # Pass the actual validation dataset subset
                                     device: torch.device,
                                     hparams: Dict[str, Any],
                                     padding_idx: int,
                                     id_to_vocab_map: Dict[int, str],
                                     get_features_fn: callable,
                                     epoch_num: int) -> Tuple[Optional[float], Optional[float]]:
    """Compute mean generated and reference SSMD (k=1) over the validation set.

    For each batch of a sequential DataLoader built here, the first
    ``ssmd_context_len`` tokens (with their features) prompt ``model.generate``. The
    ground-truth tokens that follow, truncated to the generated length, are the
    reference. Both are decoded with ``id_to_vocab_map`` and mapped to function codes.
    Each sequence is scored with ``calculate_ssm`` + ``calculate_mean_ssmd``, and the
    scores are averaged. Intended to run on the master process only.

    ``hparams`` keys read (defaults in parentheses): ssmd_context_len (16),
    max_gen_len (16), sequence_length (32), eval_batch_size (16), num_workers (2),
    ssmd_eval_max_batches (None = all), feature_keys ([]), ssmd_metric ('cosine'),
    ssmd_temperature (0.7), ssmd_top_k (50), ssmd_top_p (0.95).

    Returns:
        ``(avg_generated_ssmd, avg_reference_ssmd)``. Either value is None when no
        sequence produced a valid score. Both are None if ``id_to_vocab_map`` is missing.
    """
    local_logger = logging.getLogger(__name__ + ".calculate_ssmd_on_validation_set")
    model.eval() # Ensure model is in eval mode

    # NOTE(review): 'cosine' on scalar function codes gives similarity 1.0 for any two
    # positive codes, so the score is nearly uninformative; 'ssmd_metric' allows overriding it.
    ssmd_metric = hparams.get('ssmd_metric', 'cosine')
    ssmd_k = 1
    ssmd_context_len = hparams.get('ssmd_context_len', 16)
    ssmd_gen_len = hparams.get('max_gen_len', 16)
    current_seq_len = hparams.get('sequence_length', 32)
    eval_batch_size = hparams.get('eval_batch_size', 16)
    num_workers = hparams.get('num_workers', 2) # Can use fewer workers for this potentially smaller eval
    max_batches = hparams.get('ssmd_eval_max_batches', None)
    gen_temperature = hparams.get('ssmd_temperature', 0.7)
    gen_top_k = hparams.get('ssmd_top_k', 50)
    gen_top_p = hparams.get('ssmd_top_p', 0.95)
    expected_feature_keys_gen = hparams.get('feature_keys', [])

    # Presumably keeps the prompt shorter than a sequence so a reference continuation remains.
    if ssmd_context_len >= current_seq_len: ssmd_context_len = max(1, current_seq_len // 2)
    if ssmd_gen_len <= 0: ssmd_gen_len = 4
    if ssmd_context_len <= 0: ssmd_context_len = 1

    local_logger.info(f"Epoch {epoch_num} SSMD: Running calculation (Context={ssmd_context_len}, Gen={ssmd_gen_len}, MaxBatches={max_batches})")

    # Decoding needs the id->label map; fail fast instead of after an expensive generate().
    if id_to_vocab_map is None:
        local_logger.error(f"Epoch {epoch_num} SSMD: id_to_vocab map is missing.")
        return None, None

    # SequentialSampler gives a deterministic order. The module-level collate function is
    # passed directly: unlike a lambda it can be pickled for worker processes under 'spawn'.
    val_loader_master = DataLoader(val_dataset,
                                   batch_size=eval_batch_size,
                                   sampler=torch.utils.data.SequentialSampler(val_dataset),
                                   num_workers=num_workers,
                                   collate_fn=collate_fn_progression,
                                   drop_last=False)

    all_gen_func_codes, all_ref_func_codes = [], []
    batches_processed = 0

    ssmd_pbar = tqdm(val_loader_master,
                     desc=f"Epoch {epoch_num} SSMD Gen",
                     unit="batch",
                     leave=False, # Don't leave the bar after completion
                     dynamic_ncols=True)

    try:
        for batch_data in ssmd_pbar:
            if batch_data is None: continue
            batches_processed += 1
            if max_batches is not None and batches_processed > max_batches:
                local_logger.info(f"Epoch {epoch_num} SSMD: Reached max batches ({max_batches}).")
                break

            try:
                input_ids, _, features_dict, _ = batch_data
                if input_ids.numel() == 0: continue

                # Move data to the master's device
                input_ids = input_ids.to(device)
                features_dict_device = {k: v.to(device) for k, v in features_dict.items()}

                actual_context_len = min(ssmd_context_len, input_ids.shape[1])
                if actual_context_len == 0: continue

                input_context_ids = input_ids[:, :actual_context_len]
                input_context_features = {}
                valid_context = True
                for k in expected_feature_keys_gen:
                    if k in features_dict_device:
                        input_context_features[k] = features_dict_device[k][:, :actual_context_len]
                    else:
                        local_logger.warning(f"Epoch {epoch_num} SSMD: Missing feature '{k}' in batch {batches_processed}. Skipping batch.")
                        valid_context = False; break
                if not valid_context: continue

                # Generate sequences
                generated_ids_tensor = model.generate(
                    start_token_ids=input_context_ids, start_features=input_context_features,
                    max_length=ssmd_gen_len, temperature=gen_temperature, top_k=gen_top_k, top_p=gen_top_p,
                    get_features_for_id_fn=get_features_fn
                )

                # Reference: the ground-truth tokens right after the prompt, same length as the generation.
                ref_start_idx = actual_context_len
                ref_end_idx = min(ref_start_idx + generated_ids_tensor.shape[1], input_ids.shape[1])
                ref_ids_tensor = input_ids[:, ref_start_idx : ref_end_idx]

                # Decode and extract function codes (runs on CPU)
                batch_gen_labels = [decode_sequence_strings(seq, id_to_vocab_map, padding_idx) for seq in generated_ids_tensor.cpu().tolist()]
                batch_ref_labels = [decode_sequence_strings(seq, id_to_vocab_map, padding_idx) for seq in ref_ids_tensor.cpu().tolist()]

                all_gen_func_codes.extend([extract_function_codes(labels) for labels in batch_gen_labels])
                all_ref_func_codes.extend([extract_function_codes(labels) for labels in batch_ref_labels])

            except Exception as e:
                local_logger.error(f"Epoch {epoch_num} SSMD: Error during generation/processing batch {batches_processed}: {e}", exc_info=True)
                # Continue to next batch if possible
                continue
    finally:
        ssmd_pbar.close()

    # Calculate SSMD Scores (on CPU)
    gen_ssmd_scores, ref_ssmd_scores = [], []
    num_skipped_gen, num_skipped_ref = 0, 0

    if not all_gen_func_codes or not all_ref_func_codes:
        local_logger.warning(f"Epoch {epoch_num} SSMD: No function codes extracted for SSMD calculation.")
        return None, None

    for gen_funcs, ref_funcs in zip(all_gen_func_codes, all_ref_func_codes):
        ssmd_gen = calculate_mean_ssmd(calculate_ssm(gen_funcs, metric=ssmd_metric), k=ssmd_k)
        if ssmd_gen is not None and math.isfinite(ssmd_gen): gen_ssmd_scores.append(ssmd_gen)
        else: num_skipped_gen += 1

        ssmd_ref = calculate_mean_ssmd(calculate_ssm(ref_funcs, metric=ssmd_metric), k=ssmd_k)
        if ssmd_ref is not None and math.isfinite(ssmd_ref): ref_ssmd_scores.append(ssmd_ref)
        else: num_skipped_ref += 1

    if num_skipped_gen > 0 : local_logger.warning(f"Epoch {epoch_num} SSMD: Skipped {num_skipped_gen} generated sequences during SSMD math.")
    if num_skipped_ref > 0 : local_logger.warning(f"Epoch {epoch_num} SSMD: Skipped {num_skipped_ref} reference sequences during SSMD math.")

    avg_gen_ssmd = np.mean(gen_ssmd_scores) if gen_ssmd_scores else None
    avg_ref_ssmd = np.mean(ref_ssmd_scores) if ref_ssmd_scores else None

    # Convert numpy float types to standard Python floats for JSON/logging if necessary
    final_avg_gen_ssmd = float(avg_gen_ssmd) if avg_gen_ssmd is not None and np.isfinite(avg_gen_ssmd) else None
    final_avg_ref_ssmd = float(avg_ref_ssmd) if avg_ref_ssmd is not None and np.isfinite(avg_ref_ssmd) else None

    # Either average may legitimately be None; formatting None with ':.4f' would raise.
    gen_str = f"{final_avg_gen_ssmd:.4f}" if final_avg_gen_ssmd is not None else "N/A"
    ref_str = f"{final_avg_ref_ssmd:.4f}" if final_avg_ref_ssmd is not None else "N/A"
    local_logger.info(f"Epoch {epoch_num} SSMD: Avg Gen={gen_str} ({len(gen_ssmd_scores)} valid), Avg Ref={ref_str} ({len(ref_ssmd_scores)} valid)")

    return final_avg_gen_ssmd, final_avg_ref_ssmd
# === End Per-Epoch SSMD Function ===


# === XLA Multiprocessing Function ===
def _mp_fn(rank, flags):
    global HPARAMS, PADDING_VALUE, token_to_features_map, id_to_vocab, vocab, vocab_size, RESULTS_DIR # Declare globals modified here

    # --- Per-Process Setup ---
    torch.set_default_tensor_type('torch.FloatTensor') # Recommended for XLA
    seed = HPARAMS['seed'] + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    is_master = xm.is_master_ordinal()
    xm.master_print(f"Starting _mp_fn on master process {rank}/{xm.xrt_world_size()}")
    if is_master:
         logger.info(f"Master process seed set to {seed}")
    xm.rendezvous(f'Process {rank} setup complete')

    # Get XLA device for this process
    device = xm.xla_device()

    # --- Load Vocab & Features (All processes need this info) ---
    if is_master: logger.info("--- Loading Vocabulary & Features ---")
    try:
        if not VOCAB_PATH.is_file(): raise FileNotFoundError(f"Vocab file not found: {VOCAB_PATH}")
        with open(VOCAB_PATH, 'r') as f: vocab = json.load(f)
        if not vocab or not isinstance(vocab, dict): raise ValueError("Vocab file is empty or invalid format.")

        try: id_to_vocab = {int(v): k for k, v in vocab.items()}
        except ValueError as e: raise ValueError(f"Invalid token ID in vocab: {e}")
        if not id_to_vocab: raise ValueError("ID->Vocab map creation failed.")

        if id_to_vocab: vocab_size = max(id_to_vocab.keys()) + 1
        else: raise ValueError("Could not determine vocab size.")

        if PAD_TOKEN not in vocab: raise ValueError(f"Padding token '{PAD_TOKEN}' not found in vocabulary!")
        PADDING_VALUE = vocab[PAD_TOKEN]
        if not isinstance(PADDING_VALUE, int) or PADDING_VALUE < 0: raise ValueError(f"Invalid Padding ID '{PADDING_VALUE}'")

        if is_master:
            logger.info(f"Vocabulary loaded. Size: {vocab_size}. Padding ID: {PADDING_VALUE}")
            logger.info(f"Loading token features from: {TOKEN_FEATURES_PATH}")

        # Load Token Features
        token_to_features_map = {} ; loaded_from_file = False
        if TOKEN_FEATURES_PATH.is_file():
            try:
                with open(TOKEN_FEATURES_PATH, 'r') as f: raw_map = json.load(f)
                token_to_features_map = {int(k): v for k, v in raw_map.items()}
                loaded_from_file = True
                if is_master:
                    vocab_ids = set(vocab.values())
                    map_ids = set(token_to_features_map.keys())
                    if vocab_ids != map_ids: logger.warning(f"Mismatch between vocab IDs and feature map IDs!")
                    expected_feature_keys = set(HPARAMS.get('feature_keys', []))
                    if token_to_features_map:
                         first_item_features = next(iter(token_to_features_map.values()))
                         if isinstance(first_item_features, dict):
                             found_keys = set(first_item_features.keys())
                             if found_keys != expected_feature_keys: logger.warning(f"Feature keys mismatch! Expected: {expected_feature_keys}, Found: {found_keys}")
                         else: raise ValueError("Invalid feature map format.")
                    if PADDING_VALUE is not None and PADDING_VALUE not in token_to_features_map:
                         logger.warning(f"Padding ID {PADDING_VALUE} missing from feature map. Adding default zeros.")
                         token_to_features_map[PADDING_VALUE] = {key: 0 for key in expected_feature_keys}
                    elif PADDING_VALUE is not None:
                         pad_features = token_to_features_map[PADDING_VALUE]
                         if set(pad_features.keys()) != expected_feature_keys:
                             logger.warning(f"Padding ID {PADDING_VALUE} features missing keys. Adding defaults.")
                             for key in expected_feature_keys:
                                 if key not in pad_features: pad_features[key] = 0
            except Exception as e:
                if is_master: logger.error(f"Failed to load or process features from {TOKEN_FEATURES_PATH}: {e}", exc_info=True)
                raise
        if not loaded_from_file: raise FileNotFoundError(f"Required token feature map missing at {TOKEN_FEATURES_PATH}")
        if not token_to_features_map: raise ValueError("Loaded token_to_features_map is empty.")
        if is_master: logger.info(f"Token feature map ready for {len(token_to_features_map)} tokens.")

    except Exception as e:
        xm.master_print(f"FATAL [Rank {rank}]: Error during Vocabulary or Feature Map loading: {e}", file=sys.stderr)
        raise SystemExit(f"Setup failed on rank {rank}") from e


    # --- Define get_features_for_id_fn (uses global token_to_features_map) ---
    def get_features_for_id(token_ids_cpu: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Look up the numerical H-RPE features for a sequence of token IDs.

        Reads the module-level ``token_to_features_map`` (token ID -> feature dict),
        ``PADDING_VALUE`` and ``HPARAMS['feature_keys']``.

        Args:
            token_ids_cpu: Tensor of token IDs, presumably on CPU per the parameter
                name (TODO confirm with callers). A 0-d (scalar) tensor is treated as
                a sequence of length 1; tensors with more than one dimension are
                rejected.

        Returns:
            Dict mapping each configured feature key to a ``torch.long`` tensor of
            shape ``(N, 1)``, where ``N`` is the number of token IDs (values are cast
            to ``torch.long``). Token IDs missing from the map, and keys missing from
            a token's feature dict, fall back to the padding token's value for that
            key, or to 0 if the padding token has no such entry.

        Raises:
            RuntimeError: if the feature map is uninitialised, no feature keys are
                configured, or a feature column cannot be converted to a tensor.
            ValueError: if ``token_ids_cpu`` has more than one dimension.
        """
        global token_to_features_map, PADDING_VALUE, HPARAMS
        if token_to_features_map is None:
            raise RuntimeError("token_to_features_map not initialized.")

        feature_keys = HPARAMS.get('feature_keys', [])
        if not feature_keys:
            raise RuntimeError("Could not determine feature keys.")

        # Iterating a nested list would hand unhashable lists to the dict lookup,
        # so reject multi-dimensional input up front with a readable message.
        if token_ids_cpu.dim() > 1:
            raise ValueError(
                f"get_features_for_id expects a 0-d or 1-d tensor of token IDs, "
                f"got shape {tuple(token_ids_cpu.shape)}"
            )
        # reshape(-1) turns a 0-d scalar into a length-1 sequence (tolist() on a
        # 0-d tensor returns a bare int); 1-d input is unchanged.
        token_ids = token_ids_cpu.reshape(-1).tolist()

        # Resolve the per-key fallback once instead of on every token/key pair:
        # padding token's value if present, otherwise 0.
        pad_features = token_to_features_map.get(PADDING_VALUE, {})
        fallback = {key: pad_features.get(key, 0) for key in feature_keys}

        batch_features: Dict[str, List[Any]] = {key: [] for key in feature_keys}
        for token_id in token_ids:
            features = token_to_features_map.get(token_id)
            if features is None:
                # Unknown token ID: use the padding-derived defaults.
                features = fallback
            for key in feature_keys:
                batch_features[key].append(features.get(key, fallback[key]))

        output_features: Dict[str, torch.Tensor] = {}
        for key, values in batch_features.items():
            try:
                # unsqueeze(1) -> shape (N, 1): one feature value per token.
                output_features[key] = torch.tensor(values, dtype=torch.long).unsqueeze(1)
            except Exception as e:
                logger.error(f"Error converting features for key '{key}' to tensor: {e}")
                raise RuntimeError(f"Failed to create feature tensor for key '{key}'") from e
        return output_features
    # --- End define get_features_for_id_fn ---


    # --- Pre-run Checks (Master Only) ---
    if is_master:
        logger.info("--- Performing Pre-run Checks ---")
        errors = []
        if not INPUT_SEQUENCES_PATH.is_file(): errors.append(f"Input data missing: '{INPUT_SEQUENCES_PATH}'")
        if not VOCAB_PATH.is_file(): errors.append(f"Vocab missing: '{VOCAB_PATH}'")
        if not TOKEN_FEATURES_PATH.is_file(): errors.append(f"Token features missing: '{TOKEN_FEATURES_PATH}'")
        if token_to_features_map:
            max_q_code = -1; max_f_code = -1
            for feat_dict in token_to_features_map.values():
                if isinstance(feat_dict, dict):
                    if 'quality_code' in feat_dict: max_q_code = max(max_q_code, feat_dict['quality_code'])
                    if 'function_code' in feat_dict: max_f_code = max(max_f_code, feat_dict['function_code'])
            if HPARAMS.get('num_qualities', 0) <= max_q_code: errors.append(f"HPARAM 'num_qualities' too small ({HPARAMS.get('num_qualities')} <= {max_q_code})")
            if HPARAMS.get('num_functions', 0) <= max_f_code: errors.append(f"HPARAM 'num_functions' too small ({HPARAMS.get('num_functions')} <= {max_f_code})")
        if errors:
            logger.critical("Pre-run checks failed:\n" + "\n".join(f"  - {e}" for e in errors))
            raise SystemExit("Pre-run checks failed on master")
        logger.info("Pre-run checks passed.")
        try:
            RESULTS_DIR.mkdir(parents=True, exist_ok=True)
            logger.info(f"Results directory: {RESULTS_DIR}")
        except OSError as e:
            logger.critical(f"FATAL: Cannot create results dir '{RESULTS_DIR}': {e}", exc_info=True); sys.exit(1)

    xm.rendezvous(f'Rank {rank} passed pre-run checks')


    # --- Create Datasets & DataLoaders ---
    if is_master: logger.info("--- Creating Datasets & DataLoaders ---")
    train_dataset_full, val_dataset_full, test_dataset_full = None, None, None # Define scope
    try:
        full_dataset = ChordProgressionDataset(INPUT_SEQUENCES_PATH, HPARAMS['sequence_length'])
        if len(full_dataset) == 0: raise ValueError("Dataset is empty.")

        num_samples = len(full_dataset)
        val_size = int(VAL_RATIO * num_samples)
        test_size = int(TEST_RATIO * num_samples)
        if VAL_RATIO > 0 and val_size == 0 and num_samples > 0: val_size = 1
        if TEST_RATIO > 0 and test_size == 0 and num_samples > val_size : test_size = 1
        test_size = min(test_size, num_samples - val_size)
        if test_size < 0: test_size = 0
        train_size = num_samples - val_size - test_size
        if train_size + val_size + test_size != num_samples:
             train_size = num_samples - val_size - test_size # Adjust train
             if is_master: logger.warning(f"Adjusting split sizes: T={train_size}, V={val_size}, Te={test_size}")
        if train_size < 0: raise ValueError("Negative train size.")

        if is_master: logger.info(f"Splitting {num_samples} samples: Train={train_size}, Val={val_size}, Test={test_size}")

        generator = torch.Generator().manual_seed(HPARAMS['seed'])
        # Assign to variables accessible later in the master process scope
        train_dataset_full, val_dataset_full, test_dataset_full = random_split(full_dataset, [train_size, val_size, test_size], generator=generator)

        # Distributed Samplers
        train_sampler = None
        if train_dataset_full and len(train_dataset_full) > 0:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset_full, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=True, seed=HPARAMS['seed']
            )
        val_sampler = None
        if val_dataset_full and len(val_dataset_full) > 0:
             val_sampler = torch.utils.data.distributed.DistributedSampler(
                 val_dataset_full, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False
             )
        test_sampler = None
        if test_dataset_full and len(test_dataset_full) > 0:
             test_sampler = torch.utils.data.distributed.DistributedSampler(
                 test_dataset_full, num_replicas=xm.xrt_world_size(), rank=rank, shuffle=False
             )

        # DataLoaders
        train_loader, val_loader, test_loader = None, None, None
        collate_wrapper = lambda batch: collate_fn_progression(batch)

        if train_sampler and HPARAMS['batch_size'] > 0:
            train_loader = DataLoader(train_dataset_full, batch_size=HPARAMS['batch_size'], sampler=train_sampler,
                                      num_workers=HPARAMS['num_workers'], collate_fn=collate_wrapper, drop_last=True)
            if is_master: logger.info(f"Train DataLoader created. Batches per core per epoch: ~{len(train_loader)}")
        elif is_master: logger.warning("Train loader not created.")

        if val_sampler and HPARAMS['eval_batch_size'] > 0:
             val_loader = DataLoader(val_dataset_full, batch_size=HPARAMS['eval_batch_size'], sampler=val_sampler,
                                      num_workers=HPARAMS['num_workers'], collate_fn=collate_wrapper, drop_last=False)
             if is_master: logger.info(f"Validation DataLoader created. Batches per core: {len(val_loader)}")
        elif is_master: logger.warning("Validation loader not created.")

        if test_sampler and HPARAMS['eval_batch_size'] > 0:
             test_loader = DataLoader(test_dataset_full, batch_size=HPARAMS['eval_batch_size'], sampler=test_sampler,
                                      num_workers=HPARAMS['num_workers'], collate_fn=collate_wrapper, drop_last=False)
             if is_master: logger.info(f"Test DataLoader created. Batches per core: {len(test_loader)}")
        elif is_master: logger.warning("Test loader not created.")

    except Exception as e:
        xm.master_print(f"FATAL [Rank {rank}]: Error during Dataset/DataLoader creation: {e}", file=sys.stderr)
        raise SystemExit(f"Data setup failed on rank {rank}") from e


    # --- Initialize Model ---
    if vocab_size <= 0 or PADDING_VALUE is None:
        xm.master_print(f"FATAL [Rank {rank}]: Vocab size ({vocab_size}) or Padding Value ({PADDING_VALUE}) invalid.", file=sys.stderr)
        raise SystemExit(f"Model init prereqs failed on rank {rank}")

    if is_master: logger.info("--- Initializing Model ---")
    try:
        model_hparams = {
            'vocab_size': vocab_size, 'd_model': HPARAMS['d_model'], 'nhead': HPARAMS['nhead'],
            'num_layers': HPARAMS['num_layers'], 'dim_feedforward': HPARAMS['dim_feedforward'],
            'dropout': HPARAMS['dropout'], 'padding_idx': PADDING_VALUE,
            'num_root_intervals': HPARAMS['num_root_intervals'], 'num_qualities': HPARAMS['num_qualities'],
            'num_functions': HPARAMS['num_functions'], 'relation_embedding_dim': HPARAMS['relation_embedding_dim']
        }
        model = HarmonyTransformerWithAdvH_RPE(**model_hparams).to(device)

        if is_master:
            num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"Model initialized on device {device}. Trainable parameters: {num_params:,}")

    except Exception as e:
        xm.master_print(f"FATAL [Rank {rank}]: Model initialization failed: {e}", file=sys.stderr)
        raise SystemExit(f"Model init failed on rank {rank}") from e


    # --- Optimizer, Scheduler, Criterion ---
    if is_master: logger.info("--- Initializing Optimizer, Scheduler, Criterion ---")
    optimizer, scheduler, criterion = None, None, None
    try:
        lr = HPARAMS['lr']
        if HPARAMS.get('lr_scale_factor', None) is None:
            HPARAMS['lr_scale_factor'] = xm.xrt_world_size()
        if HPARAMS['lr_scale_factor'] > 1:
            lr *= HPARAMS['lr_scale_factor']
            if is_master: logger.info(f"Scaled learning rate by {HPARAMS['lr_scale_factor']} to {lr:.2e}")

        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                       lr=lr, weight_decay=HPARAMS['weight_decay'])
        if is_master: logger.info(f"Optimizer: AdamW (lr={lr:.2e}, weight_decay={HPARAMS['weight_decay']})")

        actual_warm_up_steps = HPARAMS['warm_up_steps']
        if actual_warm_up_steps > 0:
            scheduler = WarmUpLR(optimizer, actual_warm_up_steps)
            if is_master: logger.info(f"Scheduler: WarmUpLR (steps={actual_warm_up_steps})")
        else:
            scheduler = None
            if is_master: logger.info("Scheduler: None")

        criterion = nn.CrossEntropyLoss(ignore_index=PADDING_VALUE)
        if is_master: logger.info(f"Criterion: CrossEntropyLoss (ignore_index={PADDING_VALUE})")

    except Exception as e:
        xm.master_print(f"FATAL [Rank {rank}]: Optimizer/Scheduler/Criterion init failed: {e}", file=sys.stderr)
        raise SystemExit(f"Setup failed on rank {rank}") from e


    # --- Load Checkpoint ---
    start_epoch = 0
    best_val_loss = float('inf')
    best_model_path = None
    if CHECKPOINT_TO_LOAD:
        checkpoint_path_str = str(CHECKPOINT_TO_LOAD)
        start_epoch, best_val_loss, best_model_path = load_checkpoint(
            checkpoint_path=checkpoint_path_str, model=model, optimizer=optimizer, scheduler=scheduler
        )
        xm.rendezvous(f'Rank {rank} finished checkpoint loading')


    # --- Training Loop ---
    target_epochs = HPARAMS['num_epochs']
    actual_start_epoch = start_epoch

    if is_master:
        logger.info(f"=== STARTING TRAINING (Epochs {actual_start_epoch + 1} to {target_epochs}) ===")
    patience_counter = 0
    start_train_time = time.time()
    # <<< MODIFIED training_history initialization >>>
    training_history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'val_perplexity': [],
                        'val_gen_ssmd': [], 'val_ref_ssmd': []}

    if not train_loader:
         if is_master: logger.error("Cannot start training: Training loader is not available.")
    elif actual_start_epoch >= target_epochs:
         if is_master: logger.warning(f"Start epoch ({actual_start_epoch}) >= target ({target_epochs}). No training needed.")
    else:
        epoch_pbar = tqdm(range(actual_start_epoch, target_epochs), desc="Epochs", unit="epoch",
                          initial=actual_start_epoch, total=target_epochs, dynamic_ncols=True,
                          disable=not is_master) # Disable on non-master

        for epoch in epoch_pbar:
            epoch_num_display = epoch + 1
            epoch_start_time = time.time()

            if train_sampler: train_sampler.set_epoch(epoch)

            # Training Phase
            train_loss, train_acc = train_epoch(
                rank, model, train_loader, optimizer, criterion, scheduler, device,
                epoch_num_display, target_epochs, PADDING_VALUE
            )
            if is_master: # Log aggregated results on master
                training_history['train_loss'].append(train_loss)
                training_history['train_acc'].append(train_acc)

            gc.collect() # Optional cleanup

            # --- >>> MODIFIED Validation Phase <<< ---
            val_loss, val_acc, val_perplexity = float('inf'), 0.0, float('inf')
            avg_gen_ssmd, avg_ref_ssmd = None, None # Initialize SSMD metrics for the epoch

            if val_loader:
                if val_sampler: val_sampler.set_epoch(epoch) # Set epoch for sampler
                val_loss, val_acc, val_perplexity = evaluate_epoch(
                    rank, model, val_loader, criterion, device, PADDING_VALUE
                )
                if is_master: # Store standard val metrics
                    training_history['val_loss'].append(val_loss)
                    training_history['val_acc'].append(val_acc)
                    training_history['val_perplexity'].append(val_perplexity)

                # --- Calculate SSMD on Validation Set (Master Only, periodically) ---
                if is_master and (epoch_num_display % HPARAMS['ssmd_eval_every_n_epochs'] == 0):
                    try:
                        # Make sure the validation dataset split is available
                        if val_dataset_full: # Check if it was created successfully
                            avg_gen_ssmd, avg_ref_ssmd = calculate_ssmd_on_validation_set(
                                model=model, # Model is already on the master device
                                val_dataset=val_dataset_full, # Use the validation split
                                device=device,
                                hparams=HPARAMS,
                                padding_idx=PADDING_VALUE,
                                id_to_vocab_map=id_to_vocab,
                                get_features_fn=get_features_for_id,
                                epoch_num=epoch_num_display
                            )
                        else:
                             logger.warning(f"Epoch {epoch_num_display}: Cannot calculate SSMD, validation dataset split not found.")

                    except Exception as ssmd_e:
                        logger.error(f"Epoch {epoch_num_display}: Error calculating SSMD: {ssmd_e}", exc_info=True)
                # --- End SSMD Calculation ---

                if is_master: # Store SSMD results (will be None if not calculated this epoch)
                    training_history['val_gen_ssmd'].append(avg_gen_ssmd)
                    training_history['val_ref_ssmd'].append(avg_ref_ssmd)

                # Construct status message for logging (INFO level)
                if is_master:
                    status_msg_parts = [
                        f"Ep {epoch_num_display}",
                        f"Train L={train_loss:.4f} A={train_acc:.4f}",
                        f"| Val L={val_loss:.4f} A={val_acc:.4f} PPL={val_perplexity:.2f}"
                    ]
                    # Add SSMD scores if they were calculated this epoch
                    if avg_gen_ssmd is not None:
                        status_msg_parts.append(f"GS={avg_gen_ssmd:.4f}")
                    #if avg_ref_ssmd is not None: # Optionally add Ref SSMD to log too
                    #    status_msg_parts.append(f"RefS={avg_ref_ssmd:.4f}")
                    status_msg = " ".join(status_msg_parts)

            else: # If no validation loader
                 if is_master:
                    # Append None for all validation metrics
                    training_history['val_loss'].append(None)
                    training_history['val_acc'].append(None)
                    training_history['val_perplexity'].append(None)
                    training_history['val_gen_ssmd'].append(None)
                    training_history['val_ref_ssmd'].append(None)
                    val_loss = train_loss # Use train loss for best model check if no val
                    status_msg = f"Ep {epoch_num_display}: Train L={train_loss:.4f} A={train_acc:.4f} | Val Skipped"

            # Log epoch summary (master only)
            if is_master: logger.info(status_msg)


            # Checkpointing & Early Stopping (Master Only)
            if is_master:
                current_metric = val_loss
                metric_name = "Val_Loss" if val_loader else "Train_Loss"
                is_best = False

                if isinstance(current_metric, float) and math.isfinite(current_metric):
                    if current_metric < best_val_loss:
                        best_val_loss = current_metric
                        patience_counter = 0
                        is_best = True
                        logger.info(f"*** New best {metric_name}: {best_val_loss:.4f} at epoch {epoch_num_display} ***")
                    elif val_loader: # Only increment patience if validation happened
                        patience_counter += 1
                # Only increment patience if validation happened and metric was non-finite
                elif val_loader:
                     patience_counter += 1
                     logger.warning(f"Epoch {epoch_num_display}: Non-finite validation metric ({current_metric}). Patience: {patience_counter}/{HPARAMS['patience']}")

                checkpoint_state = {
                    'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                    'best_metric_value': best_val_loss, 'metric_name': metric_name,
                    'best_model_path': best_model_path, 'hparams': HPARAMS,
                    'vocab_size': vocab_size, 'padding_idx': PADDING_VALUE
                }

                saved_best_path_updated = save_checkpoint(
                    state=checkpoint_state, is_best=is_best, best_model_path=best_model_path,
                    checkpoint_dir=RESULTS_DIR, latest_filename="latest_checkpoint.pth",
                    best_prefix=CHECKPOINT_FILENAME_BEST_PREFIX
                )
                if is_best: best_model_path = saved_best_path_updated

                # Early stopping check
                if val_loader and HPARAMS['patience'] > 0 and patience_counter >= HPARAMS['patience']:
                    logger.warning(f"EARLY STOPPING triggered at epoch {epoch_num_display} after {patience_counter} epochs without improvement.")
                    epoch_pbar.close()
                    break # Break loop on master

                # --- UPDATE Epoch Progress Bar Postfix ---
                epoch_duration = time.time() - epoch_start_time
                # Use metrics calculated in *this* epoch for the postfix
                epoch_pbar_postfix = {
                    "TrL": f"{train_loss:.3f}",
                    "ValL": f"{val_loss:.3f}" if val_loss is not None and math.isfinite(val_loss) else "N/A",
                    "ValAcc": f"{val_acc:.3f}" if val_acc is not None else "N/A",
                    # Display Gen SSMD score if calculated this epoch
                    "GS": f"{avg_gen_ssmd:.3f}" if avg_gen_ssmd is not None else "...",
                    "Time": f"{epoch_duration:.1f}s"
                }
                if patience_counter > 0 and val_loader and HPARAMS['patience'] > 0:
                    epoch_pbar_postfix["Patience"] = f"{patience_counter}/{HPARAMS['patience']}"
                # Refresh=True ensures the bar updates immediately
                epoch_pbar.set_postfix(epoch_pbar_postfix, refresh=True)
                # --- End Update ---

            # Barrier to ensure all processes finish the epoch and master handles checkpointing/stopping
            xm.rendezvous(f'Rank {rank} finished epoch {epoch_num_display}')
            # Check if master decided to stop early
            if is_master and 'epoch_pbar' in locals() and epoch_pbar.n < epoch_pbar.total -1 and patience_counter >= HPARAMS['patience'] and val_loader and HPARAMS['patience'] > 0:
                 # If master broke early, other ranks might need to exit loop too
                 break
        # --- End Epoch Loop ---

        if is_master and 'epoch_pbar' in locals(): epoch_pbar.close()

    if is_master:
        logger.info("=== FINISHED TRAINING ===")
        total_train_time = time.time() - start_train_time
        logger.info(f"Total Training Duration: {total_train_time // 3600:.0f}h {(total_train_time % 3600) // 60:.0f}m {total_train_time % 60:.2f}s")


    # === Final Evaluation on Test Set (Master Only) ===
    if is_master:
        logger.info(f"=== STARTING FINAL EVALUATION (on Master Core {rank}) ===")
        final_results = {}

        if best_model_path and Path(best_model_path).exists():
            logger.info(f"Loading best model for final evaluation from: {best_model_path}")
            try:
                checkpoint = torch.load(best_model_path, map_location='cpu')
                loaded_hparams = checkpoint.get('hparams', HPARAMS)
                loaded_vocab_size = checkpoint.get('vocab_size', vocab_size)
                loaded_padding_idx = checkpoint.get('padding_idx', PADDING_VALUE)

                eval_model_hparams = {
                     'vocab_size': loaded_vocab_size, 'd_model': loaded_hparams['d_model'], 'nhead': loaded_hparams['nhead'],
                     'num_layers': loaded_hparams['num_layers'], 'dim_feedforward': loaded_hparams['dim_feedforward'],
                     'dropout': 0.0, 'padding_idx': loaded_padding_idx, # Eval dropout=0
                     'num_root_intervals': loaded_hparams['num_root_intervals'], 'num_qualities': loaded_hparams['num_qualities'],
                     'num_functions': loaded_hparams['num_functions'], 'relation_embedding_dim': loaded_hparams['relation_embedding_dim']
                 }
                eval_model = HarmonyTransformerWithAdvH_RPE(**eval_model_hparams)

                state_dict = checkpoint['model_state_dict']
                if all(key.startswith('module.') for key in state_dict.keys()):
                     state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
                eval_model.load_state_dict(state_dict)
                eval_model.to(device)
                eval_model.eval()
                logger.info("Best model loaded successfully onto master XLA device for evaluation.")

                # Evaluate on Test Set (Loss, Acc, PPL) using master's dataloader slice
                if test_dataset_full and len(test_dataset_full) > 0: # Check if test set exists
                    test_sampler_master = torch.utils.data.SequentialSampler(test_dataset_full)
                    test_loader_master = DataLoader(test_dataset_full,
                                                    batch_size=HPARAMS['eval_batch_size'],
                                                    sampler=test_sampler_master,
                                                    num_workers=HPARAMS['num_workers'],
                                                    collate_fn=collate_wrapper,
                                                    drop_last=False)

                    if test_loader_master:
                        logger.info("--- Evaluating on Test Set (Loss, Acc, PPL) on Master ---")
                        # Use rank 0 and master loader
                        test_loss, test_acc, test_perplexity = evaluate_epoch(
                            rank=0, model=eval_model, dataloader=test_loader_master,
                            criterion=criterion, device=device, padding_idx=loaded_padding_idx
                        )
                        final_results['Test Loss'] = test_loss if math.isfinite(test_loss) else None
                        final_results['Test Accuracy'] = test_acc
                        final_results['Test Perplexity'] = test_perplexity if math.isfinite(test_perplexity) else None
                        logger.info("--- Final Test Set Results (Standard Metrics) ---")
                        logger.info(f"  Test Loss:         {test_loss:.4f}")
                        logger.info(f"  Test Accuracy:     {test_acc:.4f}") # <-- Final Test Accuracy
                        logger.info(f"  Test Perplexity:   {test_perplexity:.4f}")

                        # Calculate Final SSMD Metrics on Test Set
                        logger.info("--- Calculating Final Advanced Metrics (SSMD on Function Codes) on Master Test Set ---")
                        # Reuse the calculate_ssmd_on_validation_set function, but pass the test dataset
                        # Set max_batches to None to evaluate on the full test set
                        final_gen_ssmd, final_ref_ssmd = calculate_ssmd_on_validation_set(
                            model=eval_model,
                            val_dataset=test_dataset_full, # Use test set data
                            device=device,
                            hparams={**loaded_hparams, 'ssmd_eval_max_batches': None}, # Override max batches
                            padding_idx=loaded_padding_idx,
                            id_to_vocab_map=id_to_vocab,
                            get_features_fn=get_features_for_id,
                            epoch_num=0 # Indicate it's final eval
                        )

                        ssmd_k_final = loaded_hparams.get('ssmd_k', 1) # Assuming k=1 was used
                        ssmd_metric_final = 'cosine' # Assuming cosine was used
                        final_results[f'Test Avg Gen SSMD(k={ssmd_k_final}, FuncCode, {ssmd_metric_final})'] = final_gen_ssmd
                        final_results[f'Test Avg Ref SSMD(k={ssmd_k_final}, FuncCode, {ssmd_metric_final})'] = final_ref_ssmd
                        # Num valid sequences could be added here if needed

                        logger.info("--- Final Test Set Results (SSMD) ---")
                        logger.info(f"  Test Avg Gen SSMD(k={ssmd_k_final}, FuncCode): {final_gen_ssmd:.4f}")
                        logger.info(f"  Test Avg Ref SSMD(k={ssmd_k_final}, FuncCode): {final_ref_ssmd:.4f}")
                    else:
                        logger.warning("Skipping final test set evaluation: Test loader (master) could not be created.")
                else:
                     logger.warning("Skipping final test set evaluation: Test dataset is empty or was not created.")

            except Exception as e:
                logger.error(f"Final evaluation failed: {e}", exc_info=True)
                final_results['Error'] = f"Evaluation failed: {e}"

        elif not best_model_path:
            logger.error("Skipping final evaluation: No best model was saved or loaded.")
        else:
            logger.error(f"Skipping final evaluation: Best model checkpoint file not found at {best_model_path}")

        # Save Final Results Summary (Master Only)
        results_summary_path = RESULTS_DIR / "final_evaluation_summary.json"
        logger.info(f"--- Saving Final Results Summary ---")
        try:
            # Ensure history lists have the same length (pad with None if needed)
            max_hist_len = 0
            if training_history:
                for key in training_history.keys():
                    if training_history[key]: max_hist_len = max(max_hist_len, len(training_history[key]))
            if max_hist_len > 0:
                for key in training_history.keys():
                    current_len = len(training_history[key])
                    if current_len < max_hist_len: training_history[key].extend([None] * (max_hist_len - current_len))
                    elif current_len > max_hist_len: training_history[key] = training_history[key][:max_hist_len]

            serializable_hparams = {}
            for k, v in HPARAMS.items():
                if isinstance(v, (int, float, str, bool, list, dict, type(None))): serializable_hparams[k] = v
                elif isinstance(v, Path): serializable_hparams[k] = str(v)
                else: logger.warning(f"HPARAM '{k}' type {type(v)} not JSON serializable. Skipping.")

            metric_name_saved = 'N/A'
            if 'checkpoint_state' in locals() and isinstance(checkpoint_state, dict):
                 metric_name_saved = checkpoint_state.get('metric_name', 'N/A')

            summary_data = {
                'Timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
                'Best_Model_Path': str(best_model_path) if best_model_path else None,
                'Best_Val_Metric': {'Name': metric_name_saved, 'Value': best_val_loss if math.isfinite(best_val_loss) else None},
                'Hyperparameters': serializable_hparams,
                'Final_Test_Metrics': final_results,
                'Training_History': training_history # Includes per-epoch Acc and SSMD
            }
            with open(results_summary_path, 'w') as f:
                # Use default=str to handle potential non-serializable numpy types if any slip through
                json.dump(summary_data, f, indent=4, ensure_ascii=False, default=str)
            logger.info(f"Final evaluation summary saved to: {results_summary_path}")
        except Exception as e:
            logger.error(f"Error saving final results summary: {e}", exc_info=True)

        global_script_start_time = script_start_time if 'script_start_time' in globals() else start_train_time
        total_script_time = time.time() - global_script_start_time # Approximate total time
        logger.info(f"=== SCRIPT EXECUTION FINISHED (Master Rank {rank}, Total Duration: {total_script_time // 3600:.0f}h {(total_script_time % 3600) // 60:.0f}m {total_script_time % 60:.2f}s) ===")

    xm.rendezvous(f'Rank {rank} finished execution')


# ==============================================================
# === MAIN EXECUTION BLOCK (using xmp.spawn) ===
# ==============================================================
if __name__ == "__main__":
    # Script entry point: launch the training/evaluation worker(s) through the
    # torch_xla multiprocessing spawner, then report how the run ended.
    logging.info("="*70); logging.info("=== STARTING Advanced H-RPE TPU Script (Epoch Metrics) ==="); logging.info("="*70)
    # Must stay a module-level global: _mp_fn looks up `script_start_time` via
    # globals() to report total duration. With start_method='fork' the child
    # process inherits this value.
    script_start_time = time.time()

    flags = {}
    spawn_succeeded = False
    try:
        # nprocs=1 runs a single worker process (kept from the original setup).
        # nprocs=None lets torch_xla auto-detect and use all available TPU
        # devices; only switch if _mp_fn is ready for multi-core execution.
        xmp.spawn(_mp_fn, args=(flags,), nprocs=1, start_method='fork')
        spawn_succeeded = True
    except Exception as main_exc:
        logging.critical(f"Exception during xmp.spawn: {main_exc}", exc_info=True)
        # Explicitly print traceback if logging fails or isn't fully configured yet
        traceback.print_exc()

    # Report the real outcome instead of an unconditional "Finished" banner.
    # Failures are logged at ERROR so they stay visible even when the root
    # logger level filters out INFO messages.
    elapsed_s = time.time() - script_start_time
    if spawn_succeeded:
        banner_log, status = logging.info, "Finished"
    else:
        banner_log, status = logging.error, "FAILED"
    banner_log("="*70)
    banner_log(f"=== XMP Spawn {status} (elapsed {elapsed_s:.1f}s) ===")
    banner_log("="*70)