In [1]:
# =============================================================================
# COMPLETE Training Script: Melody Transformer-XL with RoPE & Conditioning
# Version: Corrected autocast API, Removed explicit model fp16 conversion,
#          Added max_seq_len, Reduced batch_size
# =============================================================================
import os
import json
import warnings
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import sys
import traceback
import gc # Garbage Collector
from typing import Optional, List, Dict, Tuple, Any, Set
from dataclasses import dataclass, field
from pathlib import Path
from tqdm import tqdm # Use standard tqdm
import collections # Needed for defaultdict in collate_fn
import logging

# --- CUDA specific imports ---
# Use torch.amp directly for autocast, keep GradScaler from cuda.amp
from torch.cuda.amp import GradScaler
import torch.amp # Use torch.amp for autocast

# For timestamp and location context
import datetime
import pytz

# --- Define Vocabulary Constants (!! VERIFY THESE STRINGS !!) ---
MELODY_PAD_TOKEN: str = "<PAD>"
MELODY_UNK_TOKEN: str = "<UNK>"
CHORD_VOCAB_FILENAME: str = "chord_progression_vocab.json"
# CHORD_PAD_TOKEN: str = "<PAD>" # Example

@dataclass
class TrainingConfig:
    # --- Paths ---
    midi_root_dir: str = "LOCAL_PATH_IGNORE"
    chord_data_dir: str = "/kaggle/input/advance-h-rpe"
    melody_data_path: str = "/kaggle/input/new-melody-model-new-approach-1/training_data.jsonl"
    melody_vocab_path: str = "/kaggle/input/new-melody-model-new-approach-1/event_vocab.json"
    output_dir: str = "/kaggle/working/melody_model_output"

    # --- Vocab Sizes & Padding ---
    melody_vocab_size: int = 0
    chord_vocab_size: int = 0
    melody_pad_token_id: int = 0
    chord_pad_token_id: int = 0

    # --- Model Architecture ---
    n_layer: int = 8
    d_model: int = 512
    n_head: int = 8
    d_head: int = 64
    d_inner: int = 2048
    dropout: float = 0.1
    mem_len: int = 256 # Consider reducing if OOM persists (e.g., 128)
    rope_theta: float = 10000.0
    num_chord_features: int = 3
    condition_proj_dim: int = 128
    chord_emb_dim: Optional[int] = 64

    # --- Training ---
    batch_size: int = 16 # <<< REDUCED BATCH SIZE
    num_epochs: int = 100
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    grad_clip_value: float = 1.0
    seed: int = 42

    # --- Data Loading ---
    max_seq_len: Optional[int] = 512 # <<< REDUCED MAX SEQUENCE LENGTH FURTHER
    train_split: float = 0.90
    val_split: float = 0.05
    test_split: float = 0.05
    num_dataload_workers: int = 2

    # --- Runtime ---
    amp_dtype: torch.dtype = torch.float16 # Use float16 for Automatic Mixed Precision


# === HELPER FUNCTIONS ===

def set_seed(seed: int):
    """Sets random seeds for reproducibility across libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Random seed set to {seed}")

# --- Metrics ---
def calculate_accuracy(logits: torch.Tensor, targets: torch.Tensor, pad_token_id: int) -> Tuple[float, int]:
    """Calculates accuracy ignoring pad tokens."""
    mask = (targets != pad_token_id)
    num_valid_tokens = mask.sum().item()
    if num_valid_tokens == 0:
        return 0.0, 0
    predictions = logits.argmax(dim=-1)
    correct_predictions = (predictions[mask] == targets[mask]).sum().item()
    accuracy = correct_predictions / num_valid_tokens
    return accuracy, num_valid_tokens

def self_similarity_matrix_distance(pred_seq: List[int], target_seq: List[int], pad_id: int) -> float:
    """Calculates 1 - (ratio of common unique non-pad tokens to unique non-pad target tokens)."""
    pred_clean = [tok for tok in pred_seq if tok != pad_id]
    target_clean = [tok for tok in target_seq if tok != pad_id]
    if not target_clean: return 1.0 if pred_clean else 0.0
    pred_set = set(pred_clean)
    target_set = set(target_clean)
    if not target_set: return 1.0 if pred_set else 0.0
    common_elements = len(pred_set.intersection(target_set))
    unique_target_elements = len(target_set)
    similarity_ratio = common_elements / unique_target_elements if unique_target_elements > 0 else 0.0
    return 1.0 - similarity_ratio

def grooving_similarity(pred_seq: List[int], target_seq: List[int], pad_id: int) -> float:
    """Calculates ratio of min length to max length after removing padding."""
    pred_clean = [tok for tok in pred_seq if tok != pad_id]
    target_clean = [tok for tok in target_seq if tok != pad_id]
    len_p, len_t = len(pred_clean), len(target_clean)
    max_len = max(len_p, len_t)
    min_len = min(len_p, len_t)
    return min_len / max_len if max_len > 0 else 1.0

# === RoPE Implementation ===
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    # x: [seq_len, bsz, n_head, d_head]
    # cos, sin: [seq_len, 1, 1, d_head]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len_cached = -1
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        if seq_len <= self.max_seq_len_cached and \
           self.cos_cached is not None and self.sin_cached is not None and \
           self.cos_cached.device == device and self.cos_cached.dtype == dtype:
             return
        self.max_seq_len_cached = max(seq_len, self.max_position_embeddings) # Ensure cache is at least max_pos_embeddings
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().to(dtype).detach()
        self.sin_cached = emb.sin().to(dtype).detach()
        logging.debug(f"RoPE cache updated: seq_len={self.max_seq_len_cached}, device={device}, dtype={dtype}")

    def forward(self, x: torch.Tensor, seq_len: int, start_pos: int = 0):
        device = x.device
        dtype = x.dtype
        required_len = start_pos + seq_len
        if required_len > self.max_seq_len_cached or self.cos_cached is None or self.cos_cached.device != device or self.cos_cached.dtype != dtype:
            new_max_len = max(self.max_position_embeddings, required_len)
            self._set_cos_sin_cache(seq_len=new_max_len, device=device, dtype=dtype)
        # Slice directly using start_pos and seq_len
        end_pos = start_pos + seq_len
        cos = self.cos_cached[start_pos : end_pos]
        sin = self.sin_cached[start_pos : end_pos]
        return cos, sin

# === Model Components ===
class RelPartialLearnableMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, config: TrainingConfig, layer_idx: int):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.config = config
        self.layer_idx = layer_idx

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.drop = nn.Dropout(dropout)

        self.rotary_emb = RotaryEmbedding(
            dim=self.d_head,
            max_position_embeddings=config.mem_len + (config.max_seq_len if config.max_seq_len is not None else 2048),
            base=config.rope_theta
        )
        self.scale = 1.0 / (d_head ** 0.5)

    def forward(self, w: torch.Tensor, mems: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        qlen, bsz, _ = w.size()
        mlen = mems.size(0) if mems is not None and mems.dim() == 3 and mems.shape[0] > 0 else 0
        klen = mlen + qlen

        if mlen > 0:
            if mems.shape[1:] != w.shape[1:]:
                logging.warning(f"Layer {self.layer_idx}: Mem shape {mems.shape} != Input shape {w.shape}. Discarding memory.")
                cat = w; mlen = 0; klen = qlen
            else:
                cat = torch.cat([mems, w], dim=0)
        else:
            cat = w

        w_heads = self.qkv_net(cat)
        w_heads = w_heads.view(klen, bsz, self.n_head, 3 * self.d_head)
        q_head_raw, k_head_raw, v_head = torch.chunk(w_heads, 3, dim=-1)

        q_head = q_head_raw[-qlen:]
        k_head = k_head_raw

        cos_k, sin_k = self.rotary_emb(k_head, seq_len=klen, start_pos=0)
        # Correctly slice cos/sin for queries based on memory length
        cos_q = cos_k[mlen:klen]
        sin_q = sin_k[mlen:klen]

        q_head_rot = apply_rotary_pos_emb(q_head, cos_q.unsqueeze(1).unsqueeze(2), sin_q.unsqueeze(1).unsqueeze(2))
        k_head_rot = apply_rotary_pos_emb(k_head, cos_k.unsqueeze(1).unsqueeze(2), sin_k.unsqueeze(1).unsqueeze(2))

        q_head_ = q_head_rot.permute(1, 2, 0, 3)
        k_head_ = k_head_rot.permute(1, 2, 0, 3)
        v_head_ = v_head.permute(1, 2, 0, 3)

        attn_score = torch.matmul(q_head_, k_head_.transpose(-2, -1))
        attn_score = attn_score * self.scale

        if attn_mask is not None:
             if attn_mask.dim() == 2:
                 mask_to_apply = attn_mask.unsqueeze(0).unsqueeze(0)
             elif attn_mask.dim() == 4:
                mask_to_apply = attn_mask
             else:
                 logging.warning(f"Layer {self.layer_idx}: Unexpected attention mask shape {attn_mask.shape}. Ignoring mask.")
                 mask_to_apply = None

             if mask_to_apply is not None:
                 mask_to_apply = mask_to_apply.to(device=attn_score.device, dtype=torch.bool)
                 if mask_to_apply.shape[-2:] == attn_score.shape[-2:]:
                     # Expand mask if necessary to match batch and head dimensions
                     # Ensure mask broadcasting works correctly: (bsz, n_head, qlen, klen)
                     if mask_to_apply.shape[0] != attn_score.shape[0] and mask_to_apply.shape[0] == 1:
                         mask_to_apply = mask_to_apply.expand(attn_score.shape[0], -1, -1, -1)
                     if mask_to_apply.shape[1] != attn_score.shape[1] and mask_to_apply.shape[1] == 1:
                         mask_to_apply = mask_to_apply.expand(-1, attn_score.shape[1], -1, -1)

                     # Check again after potential expansion
                     if mask_to_apply.shape == attn_score.shape:
                         attn_score = attn_score.masked_fill(mask_to_apply, torch.finfo(attn_score.dtype).min)
                     else:
                         logging.warning(f"Layer {self.layer_idx}: Mask shape {mask_to_apply.shape} mismatch with score shape {attn_score.shape} after broadcasting attempts. Ignoring mask.")
                 else:
                     logging.warning(f"Layer {self.layer_idx}: Mask shape {mask_to_apply.shape[-2:]} incompatible with score shape {attn_score.shape[-2:]}. Ignoring mask.")

        attn_prob = F.softmax(attn_score.float(), dim=-1).to(attn_score.dtype)
        attn_prob = self.drop(attn_prob)

        if head_mask is not None:
             attn_prob = attn_prob * head_mask.to(attn_prob.device)

        attn_vec = torch.matmul(attn_prob, v_head_)
        attn_vec = attn_vec.permute(2, 0, 1, 3).contiguous()
        attn_vec = attn_vec.view(qlen, bsz, self.n_head * self.d_head)

        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        return attn_out

class TransformerXLLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, config: TrainingConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, config=config, layer_idx=layer_idx
        )
        self.pos_ff = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_inp: torch.Tensor, mems: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        x_norm1 = self.norm1(dec_inp)
        attn_output = self.dec_attn(
            w=x_norm1, mems=mems, attn_mask=attn_mask, head_mask=head_mask
        )
        h = dec_inp + self.dropout(attn_output)
        h_norm2 = self.norm2(h)
        ff_output = self.pos_ff(h_norm2)
        output = h + self.dropout(ff_output)
        return output

# === Main Model ===
class MelodyTransformerXL(nn.Module):
    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config
        # --- Input validations ---
        if config.melody_vocab_size <= 0:
            raise ValueError(f"config.melody_vocab_size must be positive, got {config.melody_vocab_size}")
        if config.melody_pad_token_id < 0 or config.melody_pad_token_id >= config.melody_vocab_size:
            raise ValueError(f"config.melody_pad_token_id ({config.melody_pad_token_id}) is out of range for melody_vocab_size ({config.melody_vocab_size}).")

        self.use_chord_embedding = config.chord_emb_dim is not None and config.chord_emb_dim > 0
        if self.use_chord_embedding:
            if config.chord_vocab_size <= 0:
                 raise ValueError(f"config.chord_vocab_size ({config.chord_vocab_size}) must be positive if chord_emb_dim ({config.chord_emb_dim}) is enabled.")
            if config.chord_pad_token_id < 0 or config.chord_pad_token_id >= config.chord_vocab_size:
                 raise ValueError(f"Chord pad token ID ({config.chord_pad_token_id}) is out of bounds for final chord vocab size ({config.chord_vocab_size}).")
        if config.d_model % config.n_head != 0:
            raise ValueError(f"d_model ({config.d_model}) must be divisible by n_head ({config.n_head}).")
        config.d_head = config.d_model // config.n_head

        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head
        self.mem_len = config.mem_len
        self.n_layer = config.n_layer

        self.melody_emb = nn.Embedding(config.melody_vocab_size, config.d_model, padding_idx=config.melody_pad_token_id)

        condition_proj_dim = max(1, config.condition_proj_dim)
        self.chord_feature_processor = nn.Linear(config.num_chord_features, condition_proj_dim)

        total_conditioning_dim = condition_proj_dim
        if self.use_chord_embedding:
            chord_emb_dim = max(1, config.chord_emb_dim)
            self.chord_emb = nn.Embedding(config.chord_vocab_size, chord_emb_dim, padding_idx=config.chord_pad_token_id)
            total_conditioning_dim += chord_emb_dim
        else:
            self.chord_emb = None

        combined_input_dim = config.d_model + total_conditioning_dim
        self.input_proj = nn.Linear(combined_input_dim, config.d_model)
        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([
            TransformerXLLayer(
                n_head=self.n_head, d_model=self.d_model, d_head=self.d_head,
                d_inner=config.d_inner, dropout=config.dropout, config=config, layer_idx=i
            ) for i in range(config.n_layer)
        ])

        self.final_norm = nn.LayerNorm(config.d_model)
        self.out_layer = nn.Linear(config.d_model, config.melody_vocab_size, bias=False)

        if config.d_model == self.melody_emb.embedding_dim:
           self.out_layer.weight = self.melody_emb.weight
           logging.info("Tying input melody embedding weights with the final output layer.")

        self.apply(self._init_weights)
        logging.info(f"MelodyTransformerXL initialized with {config.n_layer} layers, d_model={config.d_model}, n_head={config.n_head}, mem_len={config.mem_len}")
        logging.info(f"Melody Vocab Size: {config.melody_vocab_size}, Chord Vocab Size: {config.chord_vocab_size}")
        logging.info(f"Using Chord Embeddings: {self.use_chord_embedding}")

    @property
    def dtype(self) -> torch.dtype:
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            return torch.get_default_dtype()

    def _init_weights(self, module):
        """Initializes weights using standard practices."""
        scale = 1.0
        if hasattr(self.config, 'n_layer') and self.config.n_layer > 0:
             scale = 1 / math.sqrt(2.0 * self.config.n_layer)

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02 * scale)
            if module.bias is not None: nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                with torch.no_grad(): module.weight[module.padding_idx].fill_(0)
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None: module.bias.data.zero_()
            if module.weight is not None: module.weight.data.fill_(1.0)

    def _update_mems(self, hids: List[torch.Tensor], mems: List[Optional[torch.Tensor]], mlen: int) -> List[Optional[torch.Tensor]]:
        """Updates the memory states, detaching the results."""
        if mlen <= 0 or not hids:
            return [None] * (self.n_layer + 1)

        if mems is None or all(m is None or m.numel() == 0 for m in mems):
             return [(h[-mlen:].detach() if h is not None and h.numel() > 0 else None) for h in hids]

        if len(hids) != len(mems):
            logging.error(f"BUG: Mismatch between number of hidden states ({len(hids)}) and memory slots ({len(mems)}). Cannot update memory reliably.")
            return [None] * (self.n_layer + 1)

        new_mems = []
        with torch.no_grad():
            for i, hid in enumerate(hids):
                mem = mems[i]
                if hid is None:
                    new_mems.append(mem.detach() if mem is not None else None)
                    logging.warning(f"Hidden state for layer {i} is None during memory update.")
                    continue

                if mem is not None and mem.dim() == 3 and mem.numel() > 0:
                    if mem.shape[1:] == hid.shape[1:]:
                        cat = torch.cat([mem, hid], dim=0)
                    else:
                        logging.warning(f"Memory shape {mem.shape} incompatible with hid shape {hid.shape} at layer {i}. Resetting memory segment for this layer.")
                        cat = hid
                else:
                    cat = hid

                new_mems.append(cat[-mlen:].detach())

        return new_mems

    def init_mems(self, bsz: int, device: torch.device, dtype: torch.dtype) -> List[Optional[torch.Tensor]]:
        """Initializes memory states (as empty tensors or Nones)."""
        if self.mem_len > 0:
            mems = [
                torch.empty((0, bsz, self.config.d_model), dtype=dtype, device=device)
                for _ in range(self.n_layer + 1)
            ]
        else:
            mems = [None] * (self.n_layer + 1)
        return mems

    def _create_attn_mask(self, qlen, mlen, device):
        """Creates the causal attention mask for Transformer-XL (True values are masked)."""
        klen = mlen + qlen
        mask = torch.triu(torch.ones(qlen, klen, device=device, dtype=torch.bool), diagonal=1 + mlen)
        return mask

    def forward(
        self,
        event_ids: torch.Tensor,                     # (bsz, qlen) [Long]
        conditioning_chord_ids: torch.Tensor,       # (bsz, qlen) [Long]
        conditioning_root_pc: torch.Tensor,         # (bsz, qlen) [Float]
        conditioning_quality_code: torch.Tensor,    # (bsz, qlen) [Float]
        conditioning_function_code: torch.Tensor,   # (bsz, qlen) [Float]
        mems: Optional[List[Optional[torch.Tensor]]] = None
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]]]:

        bsz, qlen = event_ids.size()
        device = event_ids.device
        target_dtype = self.dtype

        # --- Memory Validation & Initialization ---
        if self.mem_len > 0:
            if mems is None:
                 mems = self.init_mems(bsz, device, target_dtype)
            elif len(mems) != self.n_layer + 1:
                 logging.warning(f"Incorrect memory list length ({len(mems)} vs {self.n_layer + 1}). Resetting memory.")
                 mems = self.init_mems(bsz, device, target_dtype)
            else:
                 for i, mem in enumerate(mems):
                     if mem is not None and mem.numel() > 0:
                         if mem.shape[1] != bsz or mem.device != device or mem.dtype != target_dtype:
                             logging.warning(f"Memory state at index {i} incompatible. Resetting memory.")
                             mems = self.init_mems(bsz, device, target_dtype)
                             break
        else:
            mems = [None] * (self.n_layer + 1)

        mlen = mems[0].size(0) if mems is not None and mems[0] is not None and mems[0].dim() == 3 else 0

        # --- Input Embeddings & Conditioning ---
        clamped_event_ids = event_ids.clamp(0, self.config.melody_vocab_size - 1)
        melody_embedded = self.melody_emb(clamped_event_ids)

        cond_features_raw = torch.stack([
            conditioning_root_pc, conditioning_quality_code, conditioning_function_code
        ], dim=-1).to(target_dtype)
        cond_features_proj = F.relu(self.chord_feature_processor(cond_features_raw))

        if self.use_chord_embedding:
            clamped_chord_ids = conditioning_chord_ids.clamp(0, self.config.chord_vocab_size - 1)
            chord_embedded = self.chord_emb(clamped_chord_ids)
            cond_combined = torch.cat([cond_features_proj, chord_embedded], dim=-1)
        else:
            cond_combined = cond_features_proj

        combined_input = torch.cat([melody_embedded, cond_combined], dim=-1)
        core_input = self.input_proj(combined_input)
        core_input = self.drop(core_input)

        # --- Transpose for Transformer Layers ---
        core_input = core_input.transpose(0, 1).contiguous() # (qlen, bsz, d_model)

        # --- Attention Mask ---
        attn_mask = self._create_attn_mask(qlen, mlen, device) if qlen > 0 else None

        # --- Pass through Layers ---
        hids_for_mem = [core_input]
        layer_input = core_input

        for i, layer in enumerate(self.layers):
            layer_mem = mems[i] if mems is not None else None
            layer_output = layer(
                dec_inp=layer_input,
                mems=layer_mem,
                attn_mask=attn_mask
            )
            hids_for_mem.append(layer_output)
            layer_input = layer_output

        # --- Update Memory ---
        new_mems = self._update_mems(hids_for_mem, mems, self.mem_len)

        # --- Final Output Processing ---
        core_output = self.drop(layer_input)
        final_output = self.final_norm(core_output)
        logits = self.out_layer(final_output)

        logits = logits.transpose(0, 1).contiguous() # (bsz, qlen, melody_vocab_size)

        return logits, new_mems


# === Dataset & Collation ===
class MelodyDataset(Dataset):
    def __init__(self, jsonl_path: str, melody_vocab: Dict[str, int], config: TrainingConfig):
        self.samples: List[Dict[str, Any]] = []
        self.melody_vocab = melody_vocab
        self.melody_pad_id = config.melody_pad_token_id
        self.melody_unk_id = melody_vocab.get(MELODY_UNK_TOKEN)
        if self.melody_unk_id is None:
            logging.warning(f"Melody UNK token '{MELODY_UNK_TOKEN}' not found in vocab! Using 0 (check for conflict with PAD).")
            self.melody_unk_id = 0

        self.max_len = config.max_seq_len
        self.sequence_key = "event_ids"

        print(f"Loading dataset from: {jsonl_path}...")
        skipped_count = 0
        required_keys = [self.sequence_key, "conditioning_chord_ids", "conditioning_root_pc",
                         "conditioning_quality_code", "conditioning_function_code"]
        try:
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for i, line in enumerate(f):
                    try:
                        line = line.strip()
                        if not line:
                            logging.debug(f"Skipping empty line ~{i+1}")
                            skipped_count += 1; continue
                        record = json.loads(line)
                        missing_keys = [key for key in required_keys if key not in record]
                        if missing_keys:
                            logging.debug(f"Skipping line ~{i+1}: Missing keys: {missing_keys}. Line content: {line[:100]}...")
                            skipped_count += 1; continue
                        event_list = record.get(self.sequence_key)
                        if not isinstance(event_list, list) or not event_list:
                            logging.debug(f"Skipping line ~{i+1}: Invalid or empty '{self.sequence_key}'. Got: {type(event_list)}")
                            skipped_count += 1; continue
                        if any(not isinstance(item, int) for item in event_list):
                             logging.warning(f"Skipping line ~{i+1}: '{self.sequence_key}' contains non-integer values.")
                             skipped_count += 1; continue
                        event_len = len(event_list)
                        if event_len == 0:
                            logging.debug(f"Skipping line ~{i+1}: Zero length '{self.sequence_key}'.")
                            skipped_count += 1; continue
                        valid_lengths = True
                        for cond_key in required_keys[1:]:
                             cond_data = record.get(cond_key)
                             if not isinstance(cond_data, list) or len(cond_data) != event_len:
                                  logging.debug(f"Skipping line ~{i+1}: Length mismatch for key '{cond_key}'. Melody len={event_len}, '{cond_key}' len={len(cond_data) if isinstance(cond_data, list) else 'Not a list'}.")
                                  valid_lengths = False; break
                        if not valid_lengths:
                            skipped_count += 1; continue
                        self.samples.append(record)
                    except json.JSONDecodeError:
                        logging.warning(f"Skipping invalid JSON on line ~{i+1}: {line[:100]}...")
                        skipped_count += 1
                    except Exception as e:
                        logging.warning(f"Skipping record due to error on line ~{i+1}: {e} - Line content: {line[:100]}...")
                        skipped_count += 1
            if not self.samples: raise ValueError("Dataset loaded but contains 0 valid samples after filtering.")
            print(f"Dataset loaded successfully: {len(self.samples)} samples (skipped {skipped_count} invalid lines).")
        except FileNotFoundError: print(f"FATAL: Dataset file not found at {jsonl_path}"); raise
        except Exception as e: print(f"FATAL: Error loading/parsing dataset {jsonl_path}: {e}"); raise

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Optional[Dict[str, torch.Tensor]]:
        try:
            record = self.samples[idx]
            event_ids_from_file = record[self.sequence_key]
            m_ids = torch.tensor(event_ids_from_file, dtype=torch.long)
            c_ids = torch.tensor(record["conditioning_chord_ids"], dtype=torch.long)
            c_root = torch.tensor(record["conditioning_root_pc"], dtype=torch.float)
            c_qual = torch.tensor(record["conditioning_quality_code"], dtype=torch.float)
            c_func = torch.tensor(record["conditioning_function_code"], dtype=torch.float)

            current_len = m_ids.shape[0]
            if self.max_len is not None and current_len > self.max_len:
                m_ids = m_ids[:self.max_len]
                c_ids = c_ids[:self.max_len]
                c_root = c_root[:self.max_len]
                c_qual = c_qual[:self.max_len]
                c_func = c_func[:self.max_len]
                logging.debug(f"Truncated sample {idx} from {current_len} to {self.max_len}")

            return {"event_ids": m_ids, "conditioning_chord_ids": c_ids, "conditioning_root_pc": c_root,
                    "conditioning_quality_code": c_qual, "conditioning_function_code": c_func}
        except KeyError as e:
             logging.error(f"ERROR in __getitem__ for index {idx}: Missing key '{e}'. Available keys: {list(self.samples[idx].keys()) if idx < len(self.samples) else 'Index out of bounds'}")
             return None
        except Exception as e:
            logging.error(f"ERROR in __getitem__ for index {idx}: {e}", exc_info=True)
            return None

def get_pad_value(key: str, melody_pad_id: int, chord_pad_id: int) -> float:
    if key == "event_ids": return float(melody_pad_id)
    elif key == "conditioning_chord_ids": return float(chord_pad_id) if chord_pad_id >= 0 else 0.0
    else: return 0.0

def melody_collate_fn(batch: List[Optional[Dict[str, torch.Tensor]]], melody_pad_id: int, chord_pad_id: int) -> Optional[Dict[str, torch.Tensor]]:
    valid_batch = [item for item in batch if item is not None]
    if not valid_batch:
        logging.debug("Collate function received an empty or all-None batch.")
        return None

    sequence_key = "event_ids"
    required_keys = {"event_ids", "conditioning_chord_ids", "conditioning_root_pc", "conditioning_quality_code", "conditioning_function_code"}
    first_item_keys = valid_batch[0].keys()
    if not required_keys.issubset(first_item_keys):
         logging.error(f"Required keys missing in first valid batch item for collation. Expected: {required_keys}, Got: {first_item_keys}")
         return None

    batch_data = collections.defaultdict(list)
    actual_items_in_batch = []

    for item_idx, item in enumerate(valid_batch):
        if not all(k in item for k in required_keys):
             logging.debug(f"Skipping item index {item_idx} in collate_fn due to missing keys: {required_keys - set(item.keys())}")
             continue
        seq_tensor = item[sequence_key]
        if not isinstance(seq_tensor, torch.Tensor) or seq_tensor.numel() == 0:
             logging.debug(f"Skipping item index {item_idx} in collate_fn due to invalid sequence tensor.")
             continue
        current_len = seq_tensor.shape[0]
        mismatched = False
        for cond_key in required_keys - {sequence_key}:
            if item[cond_key].shape[0] != current_len:
                logging.warning(f"Skipping item index {item_idx} due to length mismatch: '{sequence_key}' ({current_len}) vs '{cond_key}' ({item[cond_key].shape[0]})")
                mismatched = True; break
        if mismatched: continue
        actual_items_in_batch.append(item)

    if not actual_items_in_batch:
        logging.warning("Collate function: No valid items found in the batch to process.")
        return None

    for item in actual_items_in_batch:
         for key, tensor in item.items():
            if key in required_keys: batch_data[key].append(tensor)

    padded_batch = {}
    try:
        for key in required_keys:
             tensor_list = batch_data.get(key)
             if not tensor_list:
                 logging.error(f"Collate Error: Missing data for required key '{key}' after filtering batch."); return None
             pad_val = get_pad_value(key, melody_pad_id, chord_pad_id)
             padded_sequences = pad_sequence(tensor_list, batch_first=True, padding_value=pad_val)
             if key == sequence_key or key == "conditioning_chord_ids":
                 padded_batch[key] = padded_sequences.long()
             else:
                 padded_batch[key] = padded_sequences.float()

        final_bsz = len(actual_items_in_batch)
        if sequence_key in padded_batch and padded_batch[sequence_key].shape[0] != final_bsz:
            logging.error(f"Collate Error: Batch size mismatch after padding. Processed {final_bsz} items, but tensor has shape {padded_batch[sequence_key].shape[0]}.")
            return None
        if padded_batch[sequence_key].shape[0] == 0:
             logging.warning("Collate function produced a zero-size batch after processing."); return None
        return padded_batch
    except Exception as e:
        logging.error(f"ERROR during padding/stacking in collate_fn: {e}", exc_info=True); return None

# === Training & Evaluation Functions ===

def train_epoch(model: MelodyTransformerXL, dataloader: DataLoader, optimizer: torch.optim.Optimizer,
                criterion: nn.Module, scaler: GradScaler, epoch: int, config: TrainingConfig, device: torch.device):
    """Trains the model for one epoch using segmental training (stateless)."""
    model.train()
    total_loss, total_correct, total_tokens = 0.0, 0.0, 0.0
    epoch_start_time = time.time()
    amp_enabled = scaler.is_enabled()
    model_dtype = model.dtype

    try: num_batches = len(dataloader)
    except TypeError: num_batches = -1

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{config.num_epochs} Train", unit="batch", leave=False, disable=(num_batches <= 0))

    for batch_idx, batch in enumerate(progress_bar):
        # print(f"Processing Train Batch {batch_idx+1}/{num_batches}") # Debug print

        if batch is None:
             logging.warning(f"Skipping None batch at index {batch_idx} in training.")
             continue

        try:
            event_ids = batch["event_ids"].to(device, non_blocking=True)
            cond_ids = batch["conditioning_chord_ids"].to(device, non_blocking=True)
            cond_root = batch["conditioning_root_pc"].to(device=device, dtype=model_dtype, non_blocking=True)
            cond_qual = batch["conditioning_quality_code"].to(device=device, dtype=model_dtype, non_blocking=True)
            cond_func = batch["conditioning_function_code"].to(device=device, dtype=model_dtype, non_blocking=True)

            targets = event_ids[:, 1:].contiguous()
            model_input_ids = event_ids[:, :-1].contiguous()
            cond_ids = cond_ids[:, :-1].contiguous()
            cond_root = cond_root[:, :-1].contiguous()
            cond_qual = cond_qual[:, :-1].contiguous()
            cond_func = cond_func[:, :-1].contiguous()

            qlen = model_input_ids.shape[1]
            if qlen == 0:
                logging.warning(f"Skipping train batch {batch_idx}: sequence length is 0 after shifting.")
                continue

            current_mems = model.init_mems(event_ids.size(0), device, model_dtype)

            max_eid = model_input_ids.max().item() if qlen > 0 else -1
            max_cid = cond_ids.max().item() if qlen > 0 and model.use_chord_embedding else -1
            logging.debug(f"[Train Batch {batch_idx}] Max event ID: {max_eid} (Vocab: {config.melody_vocab_size}), Max chord ID: {max_cid} (Vocab: {config.chord_vocab_size})")
            if max_eid >= config.melody_vocab_size:
                logging.error(f"!! Index Error Pre-Check Failed: Max Event ID {max_eid} >= Vocab Size {config.melody_vocab_size} in Train Batch {batch_idx}")
                continue
            if model.use_chord_embedding and max_cid >= config.chord_vocab_size:
                logging.error(f"!! Index Error Pre-Check Failed: Max Chord ID {max_cid} >= Vocab Size {config.chord_vocab_size} in Train Batch {batch_idx}")
                continue

        except KeyError as e:
            logging.error(f"KeyError preparing training batch {batch_idx}: {e}. Available keys: {list(batch.keys())}. Skipping batch.")
            continue
        except Exception as e:
            logging.error(f"Error preparing training batch {batch_idx}: {e}. Skipping batch.", exc_info=True)
            continue

        # --- Forward Pass & Loss ---
        optimizer.zero_grad(set_to_none=True)
        amp_dtype = config.amp_dtype if (amp_enabled and config.amp_dtype is not None) else None
        device_type_str = device.type # 'cuda' or 'cpu'

        try:
            # *** CORRECTED autocast CALL ***
            with torch.amp.autocast(device_type_str, dtype=amp_dtype, enabled=amp_enabled):
                logits, _ = model(
                    event_ids=model_input_ids,
                    conditioning_chord_ids=cond_ids,
                    conditioning_root_pc=cond_root,
                    conditioning_quality_code=cond_qual,
                    conditioning_function_code=cond_func,
                    mems=current_mems
                )
                logits_flat = logits.view(-1, config.melody_vocab_size)
                targets_flat = targets.view(-1)
                targets_safe = targets_flat.clamp(0, config.melody_vocab_size - 1)
                loss = criterion(logits_flat.float(), targets_safe) # Loss still in float32

        except IndexError as e:
             logging.error(f"IndexError during model forward/loss in training (Batch {batch_idx}): {e}. ", exc_info=True)
             print(f"!!! IndexError in TRAIN Batch {batch_idx}: {e}") # Explicit print
             gc.collect(); torch.cuda.empty_cache(); continue
        except RuntimeError as e: # Catch CUDA errors specifically
             logging.error(f"RuntimeError during model forward/loss in training (Batch {batch_idx}): {e}", exc_info=True)
             print(f"!!! RuntimeError in TRAIN Batch {batch_idx}: {e}") # Explicit print
             if 'cuda' in str(e).lower(): gc.collect(); torch.cuda.empty_cache()
             continue # Skip this batch
        except Exception as e:
             logging.error(f"Generic Error during model forward/loss in training (Batch {batch_idx}): {e}", exc_info=True)
             print(f"!!! ERROR in TRAIN Batch {batch_idx}: {e}") # Explicit print
             if 'cuda' in str(e).lower(): gc.collect(); torch.cuda.empty_cache()
             continue # Skip this batch

        if not torch.isfinite(loss):
            logging.warning(f"Non-finite loss ({loss.item()}) at Train Batch {batch_idx}. Skipping backward/step.")
            print(f"!!! Non-finite loss in TRAIN Batch {batch_idx}: {loss.item()}") # Explicit print
            optimizer.zero_grad(set_to_none=True); gc.collect(); torch.cuda.empty_cache(); continue

        # --- Backward Pass & Optimization ---
        try:
            if amp_enabled:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_value)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_value)
                optimizer.step()
        except RuntimeError as e: # Specifically catch the scaler error
             logging.error(f"RuntimeError during backward/step in training (Batch {batch_idx}): {e}", exc_info=False) # Reduce traceback noise
             print(f"!!! ERROR during backward/step in TRAIN Batch {batch_idx}: {e}") # Explicit print
             optimizer.zero_grad(set_to_none=True)
             gc.collect(); torch.cuda.empty_cache()
             continue # Skip metrics update for this batch
        except Exception as e: # Catch other potential errors
             logging.error(f"Error during backward/step in training (Batch {batch_idx}): {e}", exc_info=True)
             print(f"!!! ERROR during backward/step in TRAIN Batch {batch_idx}: {e}") # Explicit print
             optimizer.zero_grad(set_to_none=True)
             gc.collect(); torch.cuda.empty_cache()
             continue


        # --- Metrics ---
        with torch.no_grad():
            # Check if logits were actually computed before calculating accuracy
            if 'logits_flat' in locals() and logits_flat is not None:
                acc, n_tokens = calculate_accuracy(logits_flat, targets_safe, config.melody_pad_token_id)
                if n_tokens > 0:
                    total_loss += loss.item() * n_tokens
                    total_correct += acc * n_tokens
                    total_tokens += n_tokens
            else:
                # This case should ideally not be reached if the batch wasn't skipped
                logging.warning(f"Logits were not computed for Train Batch {batch_idx}, skipping metric update.")


        # Update progress bar postfix
        if total_tokens > 0:
             avg_loss = total_loss / total_tokens
             avg_acc = total_correct / total_tokens
             progress_bar.set_postfix(loss=f"{avg_loss:.4f}", acc=f"{avg_acc:.4f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}", refresh=True)
        else:
             progress_bar.set_postfix(loss="N/A", acc="N/A", lr=f"{optimizer.param_groups[0]['lr']:.1e}", refresh=True)

    progress_bar.close()
    final_loss = total_loss / total_tokens if total_tokens > 0 else 0.0
    final_acc = total_correct / total_tokens if total_tokens > 0 else 0.0
    train_summary = f"Epoch {epoch} Train Summary | Time: {time.time() - epoch_start_time:.2f}s | Loss: {final_loss:.4f} | Acc: {final_acc:.4f} | Tokens: {int(total_tokens)}"
    logging.info(train_summary)
    print(train_summary) # Explicit print for visibility
    return final_loss, final_acc


@torch.no_grad()
def evaluate_epoch(model: MelodyTransformerXL, dataloader: DataLoader, criterion: nn.Module,
                   config: TrainingConfig, device: torch.device) -> Dict:
    """Evaluates the model on a given dataset (stateless)."""
    model.eval()
    total_eval_loss, total_tokens_eval, total_correct_eval = 0.0, 0.0, 0.0
    all_preds_cpu: List[List[int]] = []
    all_targets_cpu: List[List[int]] = []
    amp_enabled = (config.amp_dtype == torch.float16 and device.type == 'cuda')
    device_type_str = device.type # 'cuda' or 'cpu'
    model_dtype = model.dtype

    try: num_batches = len(dataloader)
    except TypeError: num_batches = -1

    progress_bar = tqdm(dataloader, desc="Evaluating", unit="batch", leave=False, disable=(num_batches <= 0))

    for batch_idx, batch in enumerate(progress_bar):
        # print(f"Processing Eval Batch {batch_idx+1}/{num_batches}") # Debug print

        if batch is None:
             logging.warning(f"Skipping None batch at index {batch_idx} in evaluation.")
             continue

        try:
            event_ids = batch["event_ids"].to(device, non_blocking=True)
            cond_ids = batch["conditioning_chord_ids"].to(device, non_blocking=True)
            cond_root = batch["conditioning_root_pc"].to(device=device, dtype=model_dtype, non_blocking=True)
            cond_qual = batch["conditioning_quality_code"].to(device=device, dtype=model_dtype, non_blocking=True)
            cond_func = batch["conditioning_function_code"].to(device=device, dtype=model_dtype, non_blocking=True)

            targets = event_ids[:, 1:].contiguous()
            model_input_ids = event_ids[:, :-1].contiguous()

            qlen = model_input_ids.shape[1]
            if qlen == 0:
                logging.debug(f"Skipping eval batch {batch_idx}: sequence length is 0.")
                continue

            cond_ids = cond_ids[:, :-1].contiguous()
            cond_root = cond_root[:, :-1].contiguous()
            cond_qual = cond_qual[:, :-1].contiguous()
            cond_func = cond_func[:, :-1].contiguous()

            # *** DEBUG LOGGING: Check max IDs just before model call ***
            max_eid = model_input_ids.max().item() if qlen > 0 else -1
            max_cid = cond_ids.max().item() if qlen > 0 and model.use_chord_embedding else -1
            logging.debug(f"[Eval Batch {batch_idx}] Max event ID: {max_eid} (Vocab: {config.melody_vocab_size}), Max chord ID: {max_cid} (Vocab: {config.chord_vocab_size})")
            if max_eid >= config.melody_vocab_size:
                logging.error(f"!! Index Error Pre-Check Failed: Max Event ID {max_eid} >= Vocab Size {config.melody_vocab_size} in Eval Batch {batch_idx}")
                continue
            if model.use_chord_embedding and max_cid >= config.chord_vocab_size:
                logging.error(f"!! Index Error Pre-Check Failed: Max Chord ID {max_cid} >= Vocab Size {config.chord_vocab_size} in Eval Batch {batch_idx}")
                continue

        except KeyError as e:
            logging.error(f"KeyError preparing evaluation batch {batch_idx}: {e}. Available keys: {list(batch.keys())}. Skipping batch.")
            continue
        except Exception as e:
            logging.error(f"Error preparing evaluation batch {batch_idx}: {e}. Skipping batch.", exc_info=True)
            continue

        # --- Forward Pass & Loss (Stateless) ---
        current_mems = None
        amp_dtype = config.amp_dtype if (amp_enabled and config.amp_dtype is not None) else None

        try:
            # *** CORRECTED autocast CALL ***
            with torch.amp.autocast(device_type_str, dtype=amp_dtype, enabled=amp_enabled):
                logits, _ = model(
                    event_ids=model_input_ids,
                    conditioning_chord_ids=cond_ids, conditioning_root_pc=cond_root,
                    conditioning_quality_code=cond_qual, conditioning_function_code=cond_func,
                    mems=current_mems
                )
                logits_flat = logits.view(-1, config.melody_vocab_size)
                targets_flat = targets.view(-1)
                targets_safe = targets_flat.clamp(0, config.melody_vocab_size - 1)
                loss = criterion(logits_flat.float(), targets_safe) # Loss still in float32

        except IndexError as e:
             logging.error(f"IndexError during model forward/loss in evaluation (Batch {batch_idx}): {e}. ", exc_info=True)
             print(f"!!! IndexError in EVAL Batch {batch_idx}: {e}") # Explicit print
             gc.collect(); torch.cuda.empty_cache(); continue
        except RuntimeError as e: # Catch CUDA errors specifically
             logging.error(f"RuntimeError during model forward/loss in evaluation (Batch {batch_idx}): {e}", exc_info=True)
             print(f"!!! RuntimeError in EVAL Batch {batch_idx}: {e}") # Explicit print
             if 'cuda' in str(e).lower(): gc.collect(); torch.cuda.empty_cache()
             continue # Skip this batch
        except Exception as e:
             logging.error(f"Generic Error during model forward/loss in evaluation (Batch {batch_idx}): {e}", exc_info=True)
             print(f"!!! ERROR in EVAL Batch {batch_idx}: {e}") # Explicit print
             if 'cuda' in str(e).lower(): gc.collect(); torch.cuda.empty_cache()
             continue # Skip this batch

        # --- Metrics ---
        if torch.isfinite(loss):
             acc, n_tokens = calculate_accuracy(logits_flat, targets_safe, config.melody_pad_token_id)
             if n_tokens > 0:
                  total_eval_loss += loss.item() * n_tokens
                  total_correct_eval += acc * n_tokens
                  total_tokens_eval += n_tokens
                  if total_tokens_eval > 0:
                     avg_loss = total_eval_loss / total_tokens_eval
                     avg_acc_eval = total_correct_eval / total_tokens_eval
                     progress_bar.set_postfix(loss=f"{avg_loss:.4f}", acc=f"{avg_acc_eval:.4f}", refresh=True)
        else:
             logging.warning(f"Non-finite loss ({loss.item()}) encountered in Eval Batch {batch_idx}.")
             print(f"!!! Non-finite loss in EVAL Batch {batch_idx}: {loss.item()}") # Explicit print

        # Store predictions and targets
        if 'logits' in locals() and logits is not None:
             preds = logits.argmax(dim=-1)
             all_preds_cpu.extend(preds.cpu().tolist())
             all_targets_cpu.extend(targets.cpu().tolist())
        else:
             logging.warning(f"Logits not generated for Eval Batch {batch_idx}, cannot store predictions.")


    progress_bar.close()

    # --- Calculate Final Metrics ---
    logging.info(" Aggregating & Calculating Final Eval Metrics on CPU...")
    final_loss = total_eval_loss / total_tokens_eval if total_tokens_eval > 0 else float('inf')
    final_acc = total_correct_eval / total_tokens_eval if total_tokens_eval > 0 else 0.0
    ssmd_sum, gs_sum, metric_count = 0.0, 0.0, 0

    if not all_targets_cpu or total_tokens_eval == 0:
        logging.warning("No sequences successfully processed for SSMD/GS metric calculation.")
        final_ssmd, final_gs = 1.0, 0.0
    else:
        pad_id = config.melody_pad_token_id
        cpu_loop = tqdm(zip(all_preds_cpu, all_targets_cpu), total=len(all_targets_cpu), desc="Calculating SSMD/GS", leave=False, disable=True)
        for p_seq, t_seq in cpu_loop:
            try:
                p_seq_int = [int(p) for p in p_seq]
                t_seq_int = [int(t) for t in t_seq]
                ssmd = self_similarity_matrix_distance(p_seq_int, t_seq_int, pad_id)
                gs = grooving_similarity(p_seq_int, t_seq_int, pad_id)
                if not math.isnan(ssmd) and not math.isnan(gs):
                    ssmd_sum += ssmd
                    gs_sum += gs
                    metric_count += 1
                else:
                    logging.warning(f"NaN encountered during SSMD/GS calculation. Skipping pair.")
            except Exception as e:
                logging.warning(f"Error calculating metrics for one sequence pair: {e}")
                continue

        final_ssmd = ssmd_sum / metric_count if metric_count > 0 else 1.0
        final_gs = gs_sum / metric_count if metric_count > 0 else 0.0

    eval_summary = f" Eval Summary | Loss: {final_loss:.4f} | Accuracy: {final_acc:.4f} | SSMD: {final_ssmd:.4f} | GS: {final_gs:.4f} | Eval Seq Count: {metric_count} | Eval Tokens: {int(total_tokens_eval)}"
    logging.info(eval_summary)
    print(eval_summary) # Explicit print
    return {"loss": final_loss, "accuracy": final_acc, "ssmd": final_ssmd, "gs": final_gs, "count": metric_count, "tokens": total_tokens_eval}


# === Checkpoint Functions ===
def save_checkpoint(state: Dict, filepath: str):
    """Saves model checkpoint atomically."""
    try:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        tmp_filepath = filepath + ".tmp"
        torch.save(state, tmp_filepath)
        try:
            os.replace(tmp_filepath, filepath)
        except OSError:
             os.rename(tmp_filepath, filepath)
        logging.info(f" Checkpoint saved successfully to {filepath}")
    except Exception as e:
        logging.error(f"ERROR saving checkpoint to {filepath}: {e}", exc_info=True)
        if os.path.exists(tmp_filepath):
            try: os.remove(tmp_filepath)
            except OSError as remove_err:
                logging.error(f"Error removing temporary checkpoint file {tmp_filepath}: {remove_err}")

def load_checkpoint(filepath: str, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, scaler: Optional[GradScaler] = None, device: torch.device = torch.device("cpu")) -> Tuple[int, float]:
    """Loads checkpoint. Returns start_epoch and best_val_loss."""
    start_epoch = 0
    best_val_loss = float('inf')
    if os.path.exists(filepath):
        try:
            logging.info(f"Attempting to load checkpoint from: {filepath}")
            checkpoint = torch.load(filepath, map_location=device, weights_only=False)
            logging.info(f"Successfully loaded checkpoint file.")

            # --- Load Model State Dict ---
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
                is_parallel_ckpt = all(key.startswith('module.') for key in state_dict)
                current_is_parallel = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))

                if is_parallel_ckpt and not current_is_parallel:
                    logging.info("Removing 'module.' prefix from checkpoint state_dict keys.")
                    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
                elif not is_parallel_ckpt and current_is_parallel:
                     logging.info("Adding 'module.' prefix to checkpoint state_dict keys for DataParallel/DDP model.")
                     state_dict = {'module.' + k: v for k, v in state_dict.items()}

                missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
                if unexpected_keys: logging.warning(f" Checkpoint contained unexpected keys: {unexpected_keys}")
                if missing_keys: logging.warning(f" Checkpoint was missing keys for the current model: {missing_keys}")
                logging.info(" Model state loaded from checkpoint.")
            else:
                logging.warning(" Checkpoint does not contain 'model_state_dict'. Model weights not loaded.")

            # --- Load Optimizer State Dict ---
            if optimizer and 'optimizer_state_dict' in checkpoint:
                try:
                    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)
                    logging.info(" Optimizer state loaded from checkpoint.")
                except Exception as e:
                    logging.warning(f" Could not load optimizer state: {e}. Optimizer state will be re-initialized.")
                    optimizer.state = collections.defaultdict(dict)
            elif optimizer:
                logging.info(" No optimizer state found in checkpoint. Optimizer will start from scratch.")

            # --- Load GradScaler State Dict ---
            if scaler and scaler.is_enabled() and 'scaler_state_dict' in checkpoint and checkpoint['scaler_state_dict'] is not None:
                 try:
                     scaler.load_state_dict(checkpoint['scaler_state_dict'])
                     logging.info(" GradScaler state loaded from checkpoint.")
                 except Exception as e:
                     logging.warning(f" Could not load GradScaler state: {e}. Scaler state will be re-initialized.")
                     scaler._init_state()
            elif scaler and scaler.is_enabled():
                logging.info(" No GradScaler state found or scaler is disabled. Scaler will start from scratch.")

            # --- Load Epoch and Best Loss ---
            if 'epoch' in checkpoint:
                 start_epoch = checkpoint['epoch'] + 1
            else:
                 start_epoch = 0
                 logging.warning(" Checkpoint missing 'epoch' information. Assuming start from epoch 0.")

            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            logging.info(f" Checkpoint loaded. Resuming training from epoch {start_epoch + 1}. Best validation loss recorded in checkpoint: {best_val_loss:.4f}")

        except Exception as e:
            logging.error(f"ERROR loading checkpoint '{filepath}': {e}. Starting from scratch.", exc_info=True)
            start_epoch = 0
            best_val_loss = float('inf')
            if optimizer: optimizer.state = collections.defaultdict(dict)
            if scaler: scaler._init_state()
    else:
        logging.info(f"No checkpoint found at '{filepath}'. Starting training from scratch (epoch 1).")
        start_epoch = 0

    start_epoch = max(0, start_epoch)
    return start_epoch, best_val_loss


# === Main Execution Block ===
if __name__ == '__main__':
    # Setup logging - Use DEBUG for more detailed output during troubleshooting
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

    main_start_time = time.time()
    try:
        tz = pytz.timezone('UTC')
        now = datetime.datetime.now(tz)
        current_time_str = now.strftime("%Y-%m-%d %H:%M:%S %Z")
        print(f"Script Execution Start Time: {current_time_str}")
    except Exception as e:
        print(f"Warning: Could not set timezone context using pytz: {e}")
        now = datetime.datetime.now()
        current_time_str = now.strftime("%Y-%m-%d %H:%M:%S")
        print(f"Script Execution Start Time (Local): {current_time_str}")


    # --- Configuration ---
    config = TrainingConfig()
    print("\n--- Configuration ---")
    for key, value in sorted(config.__dict__.items()): print(f"  {key}: {value}")
    print("-" * 30)

    # --- Setup ---
    set_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'
    print(f"\nUsing device: {device} ({gpu_name})")
    if device.type == 'cuda':
        print(f"CUDA Compute Capability: {torch.cuda.get_device_capability(device)}")
        print(f"PyTorch Version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}, version: {torch.version.cuda}")
        print(f"cuDNN enabled: {torch.backends.cudnn.enabled}, version: {torch.backends.cudnn.version()}")
        # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' # Uncomment if OOM persists
        # logging.info("Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True")

    os.makedirs(config.output_dir, exist_ok=True)
    latest_checkpoint_path = os.path.join(config.output_dir, "latest_checkpoint.pth")
    best_model_path = os.path.join(config.output_dir, "best_model.pth")

    # --- Load Vocabs & Determine Initial Config Sizes ---
    print("\n--- Loading Vocabularies ---")
    logging.info(f"Loading melody vocab from: {config.melody_vocab_path}")
    logging.info(f"Loading chord vocab from: {Path(config.chord_data_dir) / CHORD_VOCAB_FILENAME}")
    melody_vocab = None; chord_vocab = None
    try:
        melody_vocab_p = Path(config.melody_vocab_path)
        if not melody_vocab_p.is_file(): raise FileNotFoundError(f"Melody vocab file NOT FOUND: {melody_vocab_p}")
        with open(melody_vocab_p, 'r', encoding='utf-8') as f: melody_vocab = json.load(f)
        config.melody_vocab_size = len(melody_vocab)
        pad_id_lookup = melody_vocab.get(MELODY_PAD_TOKEN)
        unk_id_lookup = melody_vocab.get(MELODY_UNK_TOKEN)
        if pad_id_lookup is None: logging.warning(f"Melody PAD token '{MELODY_PAD_TOKEN}' not found in vocab! Using default 0."); config.melody_pad_token_id = 0
        else: config.melody_pad_token_id = int(pad_id_lookup)
        if unk_id_lookup is None: logging.warning(f"Melody UNK token '{MELODY_UNK_TOKEN}' not found in vocab! Using default ID {config.melody_pad_token_id + 1}.")
        logging.info(f" Initial Melody Vocab Size (from file): {config.melody_vocab_size}, Pad ID: {config.melody_pad_token_id}")

        chord_vocab_path = Path(config.chord_data_dir) / CHORD_VOCAB_FILENAME
        if not chord_vocab_path.is_file(): raise FileNotFoundError(f"Chord vocab file NOT FOUND: {chord_vocab_path}")
        with open(chord_vocab_path, 'r', encoding='utf-8') as f: chord_vocab = json.load(f)
        config.chord_vocab_size = len(chord_vocab)
        pad_label = MELODY_PAD_TOKEN
        pad_id_lookup = chord_vocab.get(pad_label)
        if pad_id_lookup is None: logging.warning(f"Chord PAD token '{pad_label}' not found in chord vocab! Using default 0."); config.chord_pad_token_id = 0
        else: config.chord_pad_token_id = int(pad_id_lookup)
        logging.info(f" Initial Chord Vocab Size (from file): {config.chord_vocab_size}, Pad ID: {config.chord_pad_token_id}")

    except FileNotFoundError as e: logging.critical(f"{e}"); sys.exit(1)
    except Exception as e: logging.critical(f"Error loading vocabularies: {e}", exc_info=True); sys.exit(1)

    # --- Load Dataset & Finalize Vocab Sizes ---
    print("\n--- Loading & Processing Dataset ---")
    train_set, val_set, test_set = None, None, None; dataset = None
    train_loader, val_loader, test_loader = None, None, None
    train_size, val_size, test_size = 0, 0, 0
    max_melody_id_found = -1
    max_chord_id_found = -1
    try:
        if melody_vocab is None: raise RuntimeError("Melody vocabulary was not loaded.")
        melody_data_p = Path(config.melody_data_path)
        if not melody_data_p.is_file(): raise FileNotFoundError(f"Melody data file NOT FOUND: {melody_data_p}")

        logging.info(f"Initializing dataset using key 'event_ids'...")
        dataset = MelodyDataset(config.melody_data_path, melody_vocab, config)
        total_size = len(dataset)
        if total_size == 0: raise ValueError("Dataset is empty after initialization.")

        logging.info("Checking maximum IDs in the loaded dataset...")
        max_melody_id_found = 0
        max_chord_id_found = -1
        chord_embeddings_enabled = (config.chord_emb_dim is not None and config.chord_emb_dim > 0)
        logging.info(f"Chord embeddings configured: {chord_embeddings_enabled}")

        for sample in tqdm(dataset.samples, desc="Scanning IDs", unit="samples"):
            event_ids_list = sample.get("event_ids")
            if event_ids_list:
                 try:
                     current_max = max(event_ids_list) if event_ids_list else -1
                     max_melody_id_found = max(max_melody_id_found, current_max)
                 except (ValueError, TypeError) as e:
                     logging.debug(f"Could not find max in event_ids for a sample: {e}.")
                     continue
            if chord_embeddings_enabled:
                chord_ids_list = sample.get("conditioning_chord_ids")
                if chord_ids_list:
                    try:
                         current_max_chord = max(chord_ids_list) if chord_ids_list else -1
                         max_chord_id_found = max(max_chord_id_found, current_max_chord)
                    except (ValueError, TypeError) as e:
                         logging.debug(f"Could not find max in conditioning_chord_ids: {e}.")
                         continue

        logging.info(f"Max Melody ID found in data: {max_melody_id_found}")
        if chord_embeddings_enabled:
            logging.info(f"Max Chord ID found in data: {max_chord_id_found}")

        required_melody_vocab_size = max_melody_id_found + 1
        if config.melody_vocab_size < required_melody_vocab_size:
            logging.warning(f"Melody vocab size from file ({config.melody_vocab_size}) is smaller than required ({required_melody_vocab_size}). Adjusting config.melody_vocab_size.")
            config.melody_vocab_size = required_melody_vocab_size
        else:
            logging.info(f"Melody vocab size {config.melody_vocab_size} is sufficient.")

        if chord_embeddings_enabled:
            required_chord_vocab_size = max_chord_id_found + 1
            if config.chord_vocab_size < required_chord_vocab_size:
                logging.warning(f"Chord vocab size from file ({config.chord_vocab_size}) is smaller than required ({required_chord_vocab_size}). Adjusting config.chord_vocab_size.")
                config.chord_vocab_size = required_chord_vocab_size
            else:
                 logging.info(f"Chord vocab size {config.chord_vocab_size} is sufficient.")
        elif config.chord_vocab_size == 0 and chord_vocab is not None:
             config.chord_vocab_size = len(chord_vocab)
             logging.info(f"Chord embeddings not enabled. Set chord_vocab_size to {config.chord_vocab_size} from file.")
        else:
             if config.chord_vocab_size <= 0: config.chord_vocab_size = 1
             logging.info(f"Chord embeddings not enabled. Chord vocab size set to {config.chord_vocab_size}.")

        if config.melody_pad_token_id < 0 or config.melody_pad_token_id >= config.melody_vocab_size:
             raise ValueError(f"Melody pad token ID {config.melody_pad_token_id} is out of bounds for final vocab size {config.melody_vocab_size}")
        if chord_embeddings_enabled:
             if config.chord_pad_token_id < 0 or config.chord_pad_token_id >= config.chord_vocab_size:
                 raise ValueError(f"Chord pad token ID {config.chord_pad_token_id} is out of bounds for final chord vocab size {config.chord_vocab_size}")

        print("\n--- Final Configuration ---")
        print(f"  Adjusted Melody Vocab Size: {config.melody_vocab_size}")
        print(f"  Adjusted Chord Vocab Size: {config.chord_vocab_size}")
        print(f"  Melody Pad ID: {config.melody_pad_token_id}")
        print(f"  Chord Pad ID: {config.chord_pad_token_id}")
        print(f"  Max Sequence Length: {config.max_seq_len}")
        print(f"  Batch Size: {config.batch_size}")
        print("-" * 30)

        logging.info("Splitting dataset...")
        val_size = int(config.val_split * total_size)
        test_size = int(config.test_split * total_size)
        val_size = max(0, val_size)
        test_size = max(0, test_size)

        if total_size < 3 and (val_size > 0 or test_size > 0):
            logging.warning(f"Dataset size ({total_size}) too small for validation/test split. Using all data for training.")
            train_size = total_size; val_size = 0; test_size = 0
        elif total_size - val_size - test_size <= 0:
             logging.warning(f"Train split size calculated as non-positive. Adjusting splits.")
             if total_size - val_size > 0:
                 test_size = 0; train_size = total_size - val_size
             else:
                 val_size = 0; test_size = 0; train_size = total_size
             logging.warning(f"Adjusted split sizes: Train={train_size}, Val={val_size}, Test={test_size}")
        else:
             train_size = total_size - val_size - test_size

        logging.info(f"Splitting into: Train={train_size}, Val={val_size}, Test={test_size}")

        if train_size + val_size + test_size != total_size:
             train_size = total_size - val_size - test_size
             logging.warning(f"Corrected train_size due to rounding: {train_size}")

        if train_size <= 0 and config.num_epochs > 0:
            logging.error("No training samples available after split, but num_epochs > 0. Exiting.")
            sys.exit(1)
        if val_size == 0: logging.warning("No validation samples after split.")
        if test_size == 0: logging.warning("No test samples after split.")

        if total_size > 0 and (train_size >= 0 and val_size >= 0 and test_size >= 0):
             train_set, val_set, test_set = random_split(
                 dataset, [train_size, val_size, test_size],
                 generator=torch.Generator().manual_seed(config.seed) )
             logging.info("Dataset split successfully.")
        elif total_size == 0:
             logging.warning("Dataset is empty, cannot split.")
             train_set, val_set, test_set = [], [], []
        else:
             logging.error(f"Invalid split sizes calculated (Train: {train_size}, Val: {val_size}, Test: {test_size}). Exiting.")
             sys.exit(1)

    except FileNotFoundError as e: logging.critical(f"{e}"); sys.exit(1)
    except ValueError as e: logging.critical(f"Error during dataset loading/checking/splitting: {e}", exc_info=True); sys.exit(1)
    except Exception as e: logging.critical(f"Error creating or splitting dataset: {e}", exc_info=True); sys.exit(1)

    # --- Create DataLoaders ---
    print("\n--- Creating DataLoaders ---")
    pin_memory = device.type == 'cuda'
    persistent_workers = config.num_dataload_workers > 0 and pin_memory
    collate_wrapper = lambda batch: melody_collate_fn(batch, config.melody_pad_token_id, config.chord_pad_token_id)

    loader_args = {'batch_size': config.batch_size,
                   'collate_fn': collate_wrapper,
                   'num_workers': config.num_dataload_workers,
                   'pin_memory': pin_memory,
                   'persistent_workers': persistent_workers if config.num_dataload_workers > 0 else False,
                   'prefetch_factor': 2 if config.num_dataload_workers > 0 else None,
                   'timeout': 120 if config.num_dataload_workers > 0 else 0
                   }
    if config.num_dataload_workers == 0:
         loader_args.pop('prefetch_factor', None)
         loader_args.pop('persistent_workers', None)
         loader_args.pop('timeout', None)

    try:
        if train_set and len(train_set) > 0:
            train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_args)
            logging.info(f"Train loader created with {len(train_loader)} batches.")
        else:
            logging.warning("Train set is empty or None. Train loader not created.")

        if val_set and len(val_set) > 0:
             val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **loader_args)
             logging.info(f"Validation loader created with {len(val_loader)} batches.")
        else:
             logging.info("Validation set is empty or None. Validation loader not created.")

        if test_set and len(test_set) > 0:
            test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **loader_args)
            logging.info(f"Test loader created with {len(test_loader)} batches.")
        else:
            logging.info("Test set is empty or None. Test loader not created.")

        if not train_loader and config.num_epochs > 0:
            logging.error("Training requested (num_epochs > 0) but no training data available/loader created.")
            sys.exit(1)

    except Exception as e:
        logging.critical(f"Error creating DataLoaders: {e}", exc_info=True)
        sys.exit(1)


    # --- Initialize Model, Optimizer, Loss, Scaler ---
    print("\n--- Initializing Training Components ---")
    try:
        model = MelodyTransformerXL(config).to(device) # Keep model in default dtype (float32)

        # <<< REMOVED EXPLICIT .to(dtype=torch.float16) >>>
        # if config.amp_dtype == torch.float16 and device.type == 'cuda':
        #     # model = model.to(dtype=torch.float16) # Don't do this when using autocast+GradScaler
        #     logging.info("Model kept in float32, autocast will handle fp16.")

        optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
        criterion = nn.CrossEntropyLoss(ignore_index=config.melody_pad_token_id)
        amp_enabled = (config.amp_dtype == torch.float16 and device.type == 'cuda')
        scaler = GradScaler(enabled=amp_enabled) # Uses default cuda device if available
        logging.info(f"AMP Enabled: {scaler.is_enabled()}")

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info(f"Model Total Trainable Params: {num_params / 1e6:.2f} M")

    except ValueError as e:
        logging.critical(f"Configuration error during model initialization: {e}", exc_info=True); sys.exit(1)
    except Exception as e:
        logging.critical(f"Error initializing model/optimizer/criterion: {e}", exc_info=True); sys.exit(1)

    # --- Load Checkpoint ---
    start_epoch, best_val_loss = load_checkpoint(latest_checkpoint_path, model, optimizer, scaler, device)

    # --- Training Loop ---
    print(f"\n--- Starting Training from Epoch {start_epoch + 1} / {config.num_epochs} ---")
    if config.num_epochs <= start_epoch:
         logging.info(f"Target epochs ({config.num_epochs}) already reached by checkpoint (next epoch would be {start_epoch + 1}). Skipping training loop.")
    elif not train_loader:
         logging.warning("No training data loader available. Skipping training loop.")
    else:
        for epoch in range(start_epoch, config.num_epochs):
            epoch_num = epoch + 1
            print(f"\n===== Epoch {epoch_num}/{config.num_epochs} =====")

            # --- Train ---
            try:
                 train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, scaler, epoch_num, config, device)
                 gc.collect()
                 if device.type == 'cuda': torch.cuda.empty_cache()
            except Exception as e:
                 logging.error(f"Critical Error during training epoch {epoch_num}: {e}", exc_info=True)
                 print(f"Epoch {epoch_num} failed critically during training. Stopping.")
                 sys.exit(1)

            # --- Validate ---
            current_val_loss = float('inf')
            val_acc = 0.0
            if val_loader and len(val_loader) > 0:
                try:
                    val_metrics = evaluate_epoch(model, val_loader, criterion, config, device)
                    if isinstance(val_metrics, dict) and "loss" in val_metrics and torch.isfinite(torch.tensor(val_metrics["loss"])):
                         current_val_loss = val_metrics["loss"]
                         val_acc = val_metrics.get("accuracy", 0.0)
                    else:
                         logging.warning(f"Epoch {epoch_num} Validation did not return valid results. Treating as high loss.")
                         current_val_loss = float('inf')
                except Exception as e:
                     logging.error(f"Error during validation epoch {epoch_num}: {e}", exc_info=True)
                     print(f"Validation for epoch {epoch_num} failed. Treating as high loss.")
                     current_val_loss = float('inf')
                gc.collect()
                if device.type == 'cuda': torch.cuda.empty_cache()

                # --- Save Best Model ---
                if torch.isfinite(torch.tensor(current_val_loss)) and current_val_loss < best_val_loss:
                    logging.info(f"** Validation Loss Improved ({best_val_loss:.4f} -> {current_val_loss:.4f}). Saving best model... **")
                    print(f"** Validation Loss Improved ({best_val_loss:.4f} -> {current_val_loss:.4f}). Saving best model... **")
                    best_val_loss = current_val_loss
                    model_state_to_save = model.module.state_dict() if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)) else model.state_dict()
                    save_checkpoint({
                        'epoch': epoch,
                        'model_state_dict': model_state_to_save,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scaler_state_dict': scaler.state_dict() if scaler is not None and scaler.is_enabled() else None,
                        'best_val_loss': best_val_loss,
                        'config': config.__dict__
                    }, best_model_path)
                else:
                    logging.info(f"Validation loss ({current_val_loss:.4f}) did not improve from best ({best_val_loss:.4f}).")
            else:
                logging.info("Skipping validation (no validation data/loader).")

            # --- Save Latest Checkpoint ---
            logging.info("Saving latest checkpoint...")
            model_state_to_save = model.module.state_dict() if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)) else model.state_dict()
            scaler_state = scaler.state_dict() if scaler is not None and scaler.is_enabled() else None
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model_state_to_save,
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler_state,
                'best_val_loss': best_val_loss,
                'config': config.__dict__
            }, latest_checkpoint_path)

    print("\n--- Training Finished ---")

    # --- Final Test Evaluation ---
    print("\n--- Final Evaluation on Test Set ---")
    if test_loader and len(test_loader) > 0:
        logging.info("Loading best model weights for final evaluation...")
        try:
             final_model = MelodyTransformerXL(config).to(device)
             # No explicit dtype conversion needed here either if using autocast for eval
             logging.info("Model for final evaluation initialized.")
        except Exception as e:
             logging.critical(f"Failed to initialize model for final evaluation: {e}", exc_info=True)
             sys.exit(1)

        if os.path.exists(best_model_path):
            try:
                logging.info(f"Loading best model state from: {best_model_path}")
                checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
                state_dict = checkpoint.get('model_state_dict')

                if state_dict:
                    is_parallel_ckpt = all(key.startswith('module.') for key in state_dict)
                    is_current_parallel = isinstance(final_model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
                    if is_parallel_ckpt and not is_current_parallel:
                        logging.info("Removing 'module.' prefix from best model state_dict for testing.")
                        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
                    elif not is_parallel_ckpt and is_current_parallel:
                         logging.info("Adding 'module.' prefix to best model state_dict for testing.")
                         state_dict = {'module.' + k: v for k, v in state_dict.items()}

                    missing, unexpected = final_model.load_state_dict(state_dict, strict=False)
                    if missing: logging.warning(f"Final Test: Best model loaded with missing keys: {missing}")
                    if unexpected: logging.warning(f"Final Test: Best model loaded with unexpected keys: {unexpected}")
                    loaded_epoch = checkpoint.get('epoch', -1) + 1
                    loaded_val_loss = checkpoint.get('best_val_loss', float('inf'))
                    logging.info(f"Best model state (from epoch {loaded_epoch}, val_loss {loaded_val_loss:.4f}) loaded successfully for testing.")

                    try:
                         test_metrics = evaluate_epoch(final_model, test_loader, criterion, config, device)
                         print("\n--- Test Set Results (using BEST saved model) ---")
                         if isinstance(test_metrics, dict):
                             print(f" Loss:     {test_metrics.get('loss', float('nan')):.4f}")
                             print(f" Accuracy: {test_metrics.get('accuracy', float('nan')):.4f}")
                             print(f" SSMD:     {test_metrics.get('ssmd', float('nan')):.4f}")
                             print(f" GS:       {test_metrics.get('gs', float('nan')):.4f}")
                             print(f" Eval Seqs:{test_metrics.get('count', 0)}")
                             print(f" Eval Toks:{int(test_metrics.get('tokens', 0))}")
                         else: print("Test evaluation failed or produced no metrics.")
                    except Exception as e:
                         logging.error(f"Error during final test evaluation run: {e}", exc_info=True)
                         print(f"!!! Final test evaluation run failed: {e}")

                else:
                    logging.error(f"Best model checkpoint '{best_model_path}' did not contain 'model_state_dict'. Cannot perform final test on best model.")

            except Exception as e:
                logging.error(f"Could not load or evaluate best model from {best_model_path}: {e}. Skipping final test.", exc_info=True)
                print(f"!!! Failed to load or test best model: {e}")
        else:
            logging.warning(f"Best model checkpoint '{best_model_path}' not found. Skipping final test evaluation on the best model.")


        gc.collect()
        if device.type == 'cuda': torch.cuda.empty_cache()
    else:
        print("Skipping test evaluation (no test data loader or test split is empty).")

    main_end_time = time.time()
    total_runtime = main_end_time - main_start_time
    logging.info(f"Script finished. Total Runtime: {total_runtime // 3600:.0f}h {(total_runtime % 3600) // 60:.0f}m {total_runtime % 60:.2f}s")
    print(f"Output files potentially saved in: {config.output_dir}")
    print("="*70)

Script Execution Start Time: 2025-04-25 11:18:56 UTC

--- Configuration ---
  amp_dtype: torch.float16
  batch_size: 16
  chord_data_dir: /kaggle/input/advance-h-rpe
  chord_emb_dim: 64
  chord_pad_token_id: 0
  chord_vocab_size: 0
  condition_proj_dim: 128
  d_head: 64
  d_inner: 2048
  d_model: 512
  dropout: 0.1
  grad_clip_value: 1.0
  learning_rate: 0.0003
  max_seq_len: 512
  melody_data_path: /kaggle/input/new-melody-model-new-approach-1/training_data.jsonl
  melody_pad_token_id: 0
  melody_vocab_path: /kaggle/input/new-melody-model-new-approach-1/event_vocab.json
  melody_vocab_size: 0
  mem_len: 256
  midi_root_dir: LOCAL_PATH_IGNORE
  n_head: 8
  n_layer: 8
  num_chord_features: 3
  num_dataload_workers: 2
  num_epochs: 100
  output_dir: /kaggle/working/melody_model_output
  rope_theta: 10000.0
  seed: 42
  test_split: 0.05
  train_split: 0.9
  val_split: 0.05
  weight_decay: 0.01
------------------------------

Using device: cuda (Tesla P100-PCIE-16GB)
CUDA Compute Capabilit

Scanning IDs: 100%|██████████| 1544/1544 [00:00<00:00, 13112.14samples/s]


--- Final Configuration ---
  Adjusted Melody Vocab Size: 306
  Adjusted Chord Vocab Size: 44734
  Melody Pad ID: 0
  Chord Pad ID: 0
  Max Sequence Length: 512
  Batch Size: 16
------------------------------

--- Creating DataLoaders ---

--- Initializing Training Components ---



  scaler = GradScaler(enabled=amp_enabled) # Uses default cuda device if available



--- Starting Training from Epoch 1 / 100 ---

===== Epoch 1/100 =====


                                                                                                          

Epoch 1 Train Summary | Time: 28.84s | Loss: 4.5163 | Acc: 0.1653 | Tokens: 660634


                                                                                     

 Eval Summary | Loss: 4.4851 | Accuracy: 0.1567 | SSMD: 0.9879 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (inf -> 4.4851). Saving best model... **

===== Epoch 2/100 =====


                                                                                                          

Epoch 2 Train Summary | Time: 27.78s | Loss: 4.3825 | Acc: 0.1671 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 4.2148 | Accuracy: 0.1630 | SSMD: 0.9722 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (4.4851 -> 4.2148). Saving best model... **

===== Epoch 3/100 =====


                                                                                                          

Epoch 3 Train Summary | Time: 27.78s | Loss: 3.9663 | Acc: 0.1800 | Tokens: 660330


                                                                                     

 Eval Summary | Loss: 3.7437 | Accuracy: 0.1684 | SSMD: 0.8932 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (4.2148 -> 3.7437). Saving best model... **

===== Epoch 4/100 =====


                                                                                                          

Epoch 4 Train Summary | Time: 27.77s | Loss: 3.5225 | Acc: 0.1934 | Tokens: 660293


                                                                                     

 Eval Summary | Loss: 3.3713 | Accuracy: 0.1874 | SSMD: 0.7452 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (3.7437 -> 3.3713). Saving best model... **

===== Epoch 5/100 =====


                                                                                                          

Epoch 5 Train Summary | Time: 27.78s | Loss: 3.2067 | Acc: 0.2275 | Tokens: 660023


                                                                                     

 Eval Summary | Loss: 3.1368 | Accuracy: 0.2334 | SSMD: 0.7007 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (3.3713 -> 3.1368). Saving best model... **

===== Epoch 6/100 =====


                                                                                                          

Epoch 6 Train Summary | Time: 27.77s | Loss: 3.0198 | Acc: 0.2724 | Tokens: 660427


                                                                                     

 Eval Summary | Loss: 2.9492 | Accuracy: 0.2855 | SSMD: 0.5944 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (3.1368 -> 2.9492). Saving best model... **

===== Epoch 7/100 =====


                                                                                                          

Epoch 7 Train Summary | Time: 27.78s | Loss: 2.8349 | Acc: 0.3327 | Tokens: 660955


                                                                                     

 Eval Summary | Loss: 2.8002 | Accuracy: 0.3356 | SSMD: 0.5345 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.9492 -> 2.8002). Saving best model... **

===== Epoch 8/100 =====


                                                                                                          

Epoch 8 Train Summary | Time: 27.80s | Loss: 2.6885 | Acc: 0.3695 | Tokens: 660801


                                                                                     

 Eval Summary | Loss: 2.6434 | Accuracy: 0.3806 | SSMD: 0.4763 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.8002 -> 2.6434). Saving best model... **

===== Epoch 9/100 =====


                                                                                                          

Epoch 9 Train Summary | Time: 27.80s | Loss: 2.5682 | Acc: 0.3996 | Tokens: 660441


                                                                                     

 Eval Summary | Loss: 2.5412 | Accuracy: 0.4016 | SSMD: 0.4540 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.6434 -> 2.5412). Saving best model... **

===== Epoch 10/100 =====


                                                                                                           

Epoch 10 Train Summary | Time: 27.80s | Loss: 2.4820 | Acc: 0.4164 | Tokens: 659941


                                                                                     

 Eval Summary | Loss: 2.4796 | Accuracy: 0.4123 | SSMD: 0.4355 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.5412 -> 2.4796). Saving best model... **

===== Epoch 11/100 =====


                                                                                                           

Epoch 11 Train Summary | Time: 27.75s | Loss: 2.4177 | Acc: 0.4278 | Tokens: 660339


                                                                                     

 Eval Summary | Loss: 2.4256 | Accuracy: 0.4225 | SSMD: 0.4319 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.4796 -> 2.4256). Saving best model... **

===== Epoch 12/100 =====


                                                                                                           

Epoch 12 Train Summary | Time: 27.72s | Loss: 2.3587 | Acc: 0.4379 | Tokens: 660173


                                                                                     

 Eval Summary | Loss: 2.3804 | Accuracy: 0.4315 | SSMD: 0.4219 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.4256 -> 2.3804). Saving best model... **

===== Epoch 13/100 =====


                                                                                                           

Epoch 13 Train Summary | Time: 27.78s | Loss: 2.3203 | Acc: 0.4440 | Tokens: 659841


                                                                                     

 Eval Summary | Loss: 2.3377 | Accuracy: 0.4384 | SSMD: 0.4176 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.3804 -> 2.3377). Saving best model... **

===== Epoch 14/100 =====


                                                                                                           

Epoch 14 Train Summary | Time: 27.74s | Loss: 2.2809 | Acc: 0.4499 | Tokens: 660147


                                                                                     

 Eval Summary | Loss: 2.3239 | Accuracy: 0.4404 | SSMD: 0.4143 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.3377 -> 2.3239). Saving best model... **

===== Epoch 15/100 =====


                                                                                                           

Epoch 15 Train Summary | Time: 27.72s | Loss: 2.2523 | Acc: 0.4538 | Tokens: 660071


                                                                                     

 Eval Summary | Loss: 2.2974 | Accuracy: 0.4460 | SSMD: 0.4122 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.3239 -> 2.2974). Saving best model... **

===== Epoch 16/100 =====


                                                                                                           

Epoch 16 Train Summary | Time: 27.73s | Loss: 2.2302 | Acc: 0.4565 | Tokens: 659896


                                                                                     

 Eval Summary | Loss: 2.2750 | Accuracy: 0.4460 | SSMD: 0.4035 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2974 -> 2.2750). Saving best model... **

===== Epoch 17/100 =====


                                                                                                           

Epoch 17 Train Summary | Time: 27.72s | Loss: 2.2075 | Acc: 0.4598 | Tokens: 659998


                                                                                     

 Eval Summary | Loss: 2.2725 | Accuracy: 0.4462 | SSMD: 0.4050 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2750 -> 2.2725). Saving best model... **

===== Epoch 18/100 =====


                                                                                                           

Epoch 18 Train Summary | Time: 27.74s | Loss: 2.1888 | Acc: 0.4632 | Tokens: 660712


                                                                                     

 Eval Summary | Loss: 2.2647 | Accuracy: 0.4484 | SSMD: 0.4129 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2725 -> 2.2647). Saving best model... **

===== Epoch 19/100 =====


                                                                                                           

Epoch 19 Train Summary | Time: 27.79s | Loss: 2.1682 | Acc: 0.4658 | Tokens: 660272


                                                                                     

 Eval Summary | Loss: 2.2515 | Accuracy: 0.4498 | SSMD: 0.4076 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2647 -> 2.2515). Saving best model... **

===== Epoch 20/100 =====


                                                                                                           

Epoch 20 Train Summary | Time: 27.79s | Loss: 2.1504 | Acc: 0.4679 | Tokens: 660262


                                                                                     

 Eval Summary | Loss: 2.2391 | Accuracy: 0.4517 | SSMD: 0.3926 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2515 -> 2.2391). Saving best model... **

===== Epoch 21/100 =====


                                                                                                           

Epoch 21 Train Summary | Time: 27.78s | Loss: 2.1361 | Acc: 0.4705 | Tokens: 659797


                                                                                     

 Eval Summary | Loss: 2.2236 | Accuracy: 0.4502 | SSMD: 0.3859 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2391 -> 2.2236). Saving best model... **

===== Epoch 22/100 =====


                                                                                                           

Epoch 22 Train Summary | Time: 27.81s | Loss: 2.1203 | Acc: 0.4725 | Tokens: 659947


                                                                                     

 Eval Summary | Loss: 2.2278 | Accuracy: 0.4511 | SSMD: 0.3814 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 23/100 =====


                                                                                                           

Epoch 23 Train Summary | Time: 27.79s | Loss: 2.1094 | Acc: 0.4738 | Tokens: 660125


                                                                                     

 Eval Summary | Loss: 2.2206 | Accuracy: 0.4543 | SSMD: 0.3851 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2236 -> 2.2206). Saving best model... **

===== Epoch 24/100 =====


                                                                                                           

Epoch 24 Train Summary | Time: 27.79s | Loss: 2.0948 | Acc: 0.4763 | Tokens: 659959


                                                                                     

 Eval Summary | Loss: 2.2178 | Accuracy: 0.4541 | SSMD: 0.3799 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2206 -> 2.2178). Saving best model... **

===== Epoch 25/100 =====


                                                                                                           

Epoch 25 Train Summary | Time: 27.81s | Loss: 2.0814 | Acc: 0.4783 | Tokens: 660183


                                                                                     

 Eval Summary | Loss: 2.2302 | Accuracy: 0.4544 | SSMD: 0.3740 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 26/100 =====


                                                                                                           

Epoch 26 Train Summary | Time: 27.80s | Loss: 2.0692 | Acc: 0.4805 | Tokens: 660218


                                                                                     

 Eval Summary | Loss: 2.2177 | Accuracy: 0.4547 | SSMD: 0.3820 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702
** Validation Loss Improved (2.2178 -> 2.2177). Saving best model... **

===== Epoch 27/100 =====


                                                                                                           

Epoch 27 Train Summary | Time: 27.78s | Loss: 2.0526 | Acc: 0.4828 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 2.2217 | Accuracy: 0.4555 | SSMD: 0.3676 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 28/100 =====


                                                                                                           

Epoch 28 Train Summary | Time: 27.81s | Loss: 2.0442 | Acc: 0.4838 | Tokens: 660388


                                                                                     

 Eval Summary | Loss: 2.2320 | Accuracy: 0.4569 | SSMD: 0.3588 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 29/100 =====


                                                                                                           

Epoch 29 Train Summary | Time: 27.80s | Loss: 2.0311 | Acc: 0.4858 | Tokens: 659979


                                                                                     

 Eval Summary | Loss: 2.2233 | Accuracy: 0.4570 | SSMD: 0.3670 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 30/100 =====


                                                                                                           

Epoch 30 Train Summary | Time: 27.80s | Loss: 2.0165 | Acc: 0.4885 | Tokens: 660433


                                                                                     

 Eval Summary | Loss: 2.2177 | Accuracy: 0.4547 | SSMD: 0.3749 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 31/100 =====


                                                                                                           

Epoch 31 Train Summary | Time: 27.80s | Loss: 2.0039 | Acc: 0.4904 | Tokens: 660371


                                                                                     

 Eval Summary | Loss: 2.2353 | Accuracy: 0.4553 | SSMD: 0.3607 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 32/100 =====


                                                                                                           

Epoch 32 Train Summary | Time: 27.79s | Loss: 1.9917 | Acc: 0.4920 | Tokens: 660796


                                                                                     

 Eval Summary | Loss: 2.2352 | Accuracy: 0.4560 | SSMD: 0.3546 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 33/100 =====


                                                                                                           

Epoch 33 Train Summary | Time: 27.81s | Loss: 1.9797 | Acc: 0.4944 | Tokens: 660472


                                                                                     

 Eval Summary | Loss: 2.2323 | Accuracy: 0.4529 | SSMD: 0.3538 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 34/100 =====


                                                                                                           

Epoch 34 Train Summary | Time: 27.79s | Loss: 1.9637 | Acc: 0.4971 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 2.2510 | Accuracy: 0.4547 | SSMD: 0.3552 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 35/100 =====


                                                                                                           

Epoch 35 Train Summary | Time: 27.80s | Loss: 1.9513 | Acc: 0.4988 | Tokens: 660483


                                                                                     

 Eval Summary | Loss: 2.2474 | Accuracy: 0.4519 | SSMD: 0.3449 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 36/100 =====


                                                                                                           

Epoch 36 Train Summary | Time: 27.82s | Loss: 1.9376 | Acc: 0.5013 | Tokens: 659873


                                                                                     

 Eval Summary | Loss: 2.2601 | Accuracy: 0.4505 | SSMD: 0.3278 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 37/100 =====


                                                                                                           

Epoch 37 Train Summary | Time: 27.82s | Loss: 1.9205 | Acc: 0.5034 | Tokens: 660979


                                                                                     

 Eval Summary | Loss: 2.2526 | Accuracy: 0.4515 | SSMD: 0.3305 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 38/100 =====


                                                                                                           

Epoch 38 Train Summary | Time: 27.79s | Loss: 1.9044 | Acc: 0.5072 | Tokens: 660493


                                                                                     

 Eval Summary | Loss: 2.2763 | Accuracy: 0.4535 | SSMD: 0.3250 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 39/100 =====


                                                                                                           

Epoch 39 Train Summary | Time: 27.79s | Loss: 1.8897 | Acc: 0.5089 | Tokens: 660198


                                                                                     

 Eval Summary | Loss: 2.2775 | Accuracy: 0.4529 | SSMD: 0.3265 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 40/100 =====


                                                                                                           

Epoch 40 Train Summary | Time: 27.82s | Loss: 1.8715 | Acc: 0.5126 | Tokens: 660712


                                                                                     

 Eval Summary | Loss: 2.2845 | Accuracy: 0.4508 | SSMD: 0.3206 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 41/100 =====


                                                                                                           

Epoch 41 Train Summary | Time: 27.82s | Loss: 1.8566 | Acc: 0.5149 | Tokens: 659895


                                                                                     

 Eval Summary | Loss: 2.3105 | Accuracy: 0.4508 | SSMD: 0.3198 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 42/100 =====


                                                                                                           

Epoch 42 Train Summary | Time: 27.81s | Loss: 1.8404 | Acc: 0.5178 | Tokens: 660227


                                                                                     

 Eval Summary | Loss: 2.3269 | Accuracy: 0.4485 | SSMD: 0.3155 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 43/100 =====


                                                                                                           

Epoch 43 Train Summary | Time: 27.82s | Loss: 1.8192 | Acc: 0.5218 | Tokens: 660195


                                                                                     

 Eval Summary | Loss: 2.3386 | Accuracy: 0.4456 | SSMD: 0.3039 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 44/100 =====


                                                                                                           

Epoch 44 Train Summary | Time: 27.78s | Loss: 1.8032 | Acc: 0.5244 | Tokens: 659928


                                                                                     

 Eval Summary | Loss: 2.3425 | Accuracy: 0.4479 | SSMD: 0.3075 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 45/100 =====


                                                                                                           

Epoch 45 Train Summary | Time: 27.80s | Loss: 1.7850 | Acc: 0.5278 | Tokens: 660075


                                                                                     

 Eval Summary | Loss: 2.3559 | Accuracy: 0.4463 | SSMD: 0.2977 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 46/100 =====


                                                                                                           

Epoch 46 Train Summary | Time: 27.80s | Loss: 1.7638 | Acc: 0.5312 | Tokens: 659970


                                                                                     

 Eval Summary | Loss: 2.3715 | Accuracy: 0.4425 | SSMD: 0.2841 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 47/100 =====


                                                                                                           

Epoch 47 Train Summary | Time: 27.78s | Loss: 1.7452 | Acc: 0.5352 | Tokens: 660303


                                                                                     

 Eval Summary | Loss: 2.3715 | Accuracy: 0.4460 | SSMD: 0.2947 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 48/100 =====


                                                                                                           

Epoch 48 Train Summary | Time: 27.79s | Loss: 1.7225 | Acc: 0.5395 | Tokens: 660089


                                                                                     

 Eval Summary | Loss: 2.4169 | Accuracy: 0.4408 | SSMD: 0.2845 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 49/100 =====


                                                                                                           

Epoch 49 Train Summary | Time: 27.83s | Loss: 1.7010 | Acc: 0.5434 | Tokens: 660605


                                                                                     

 Eval Summary | Loss: 2.4286 | Accuracy: 0.4435 | SSMD: 0.2858 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 50/100 =====


                                                                                                           

Epoch 50 Train Summary | Time: 27.81s | Loss: 1.6789 | Acc: 0.5478 | Tokens: 659970


                                                                                     

 Eval Summary | Loss: 2.4404 | Accuracy: 0.4418 | SSMD: 0.2668 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 51/100 =====


                                                                                                           

Epoch 51 Train Summary | Time: 27.81s | Loss: 1.6582 | Acc: 0.5514 | Tokens: 660168


                                                                                     

 Eval Summary | Loss: 2.4544 | Accuracy: 0.4396 | SSMD: 0.2765 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 52/100 =====


                                                                                                           

Epoch 52 Train Summary | Time: 27.80s | Loss: 1.6346 | Acc: 0.5561 | Tokens: 660057


                                                                                     

 Eval Summary | Loss: 2.4910 | Accuracy: 0.4382 | SSMD: 0.2771 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 53/100 =====


                                                                                                           

Epoch 53 Train Summary | Time: 27.80s | Loss: 1.6068 | Acc: 0.5619 | Tokens: 659824


                                                                                     

 Eval Summary | Loss: 2.5208 | Accuracy: 0.4368 | SSMD: 0.2678 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 54/100 =====


                                                                                                           

Epoch 54 Train Summary | Time: 27.80s | Loss: 1.5823 | Acc: 0.5672 | Tokens: 660937


                                                                                     

 Eval Summary | Loss: 2.5325 | Accuracy: 0.4421 | SSMD: 0.2515 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 55/100 =====


                                                                                                           

Epoch 55 Train Summary | Time: 27.80s | Loss: 1.5622 | Acc: 0.5712 | Tokens: 659925


                                                                                     

 Eval Summary | Loss: 2.5755 | Accuracy: 0.4361 | SSMD: 0.2551 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 56/100 =====


                                                                                                           

Epoch 56 Train Summary | Time: 27.79s | Loss: 1.5380 | Acc: 0.5756 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 2.5880 | Accuracy: 0.4365 | SSMD: 0.2520 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 57/100 =====


                                                                                                           

Epoch 57 Train Summary | Time: 27.80s | Loss: 1.5102 | Acc: 0.5817 | Tokens: 660069


                                                                                     

 Eval Summary | Loss: 2.6466 | Accuracy: 0.4378 | SSMD: 0.2488 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 58/100 =====


                                                                                                           

Epoch 58 Train Summary | Time: 27.83s | Loss: 1.4828 | Acc: 0.5871 | Tokens: 660795


                                                                                     

 Eval Summary | Loss: 2.6474 | Accuracy: 0.4357 | SSMD: 0.2441 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 59/100 =====


                                                                                                           

Epoch 59 Train Summary | Time: 27.82s | Loss: 1.4593 | Acc: 0.5927 | Tokens: 660284


                                                                                     

 Eval Summary | Loss: 2.6629 | Accuracy: 0.4330 | SSMD: 0.2371 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 60/100 =====


                                                                                                           

Epoch 60 Train Summary | Time: 27.82s | Loss: 1.4300 | Acc: 0.5993 | Tokens: 660401


                                                                                     

 Eval Summary | Loss: 2.7209 | Accuracy: 0.4334 | SSMD: 0.2356 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 61/100 =====


                                                                                                           

Epoch 61 Train Summary | Time: 27.83s | Loss: 1.4044 | Acc: 0.6050 | Tokens: 659957


                                                                                     

 Eval Summary | Loss: 2.7703 | Accuracy: 0.4347 | SSMD: 0.2338 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 62/100 =====


                                                                                                           

Epoch 62 Train Summary | Time: 27.81s | Loss: 1.3766 | Acc: 0.6110 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 2.7583 | Accuracy: 0.4296 | SSMD: 0.2317 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 63/100 =====


                                                                                                           

Epoch 63 Train Summary | Time: 27.80s | Loss: 1.3465 | Acc: 0.6172 | Tokens: 660269


                                                                                     

 Eval Summary | Loss: 2.8220 | Accuracy: 0.4297 | SSMD: 0.2292 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 64/100 =====


                                                                                                           

Epoch 64 Train Summary | Time: 27.82s | Loss: 1.3229 | Acc: 0.6230 | Tokens: 660627


                                                                                     

 Eval Summary | Loss: 2.8691 | Accuracy: 0.4303 | SSMD: 0.2317 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 65/100 =====


                                                                                                           

Epoch 65 Train Summary | Time: 27.80s | Loss: 1.2957 | Acc: 0.6295 | Tokens: 660101


                                                                                     

 Eval Summary | Loss: 2.8734 | Accuracy: 0.4293 | SSMD: 0.2118 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 66/100 =====


                                                                                                           

Epoch 66 Train Summary | Time: 27.78s | Loss: 1.2665 | Acc: 0.6361 | Tokens: 660231


                                                                                     

 Eval Summary | Loss: 2.9259 | Accuracy: 0.4241 | SSMD: 0.2172 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 67/100 =====


                                                                                                           

Epoch 67 Train Summary | Time: 27.79s | Loss: 1.2396 | Acc: 0.6426 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 2.9465 | Accuracy: 0.4242 | SSMD: 0.2099 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 68/100 =====


                                                                                                           

Epoch 68 Train Summary | Time: 27.80s | Loss: 1.2080 | Acc: 0.6499 | Tokens: 660654


                                                                                     

 Eval Summary | Loss: 2.9982 | Accuracy: 0.4247 | SSMD: 0.2062 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 69/100 =====


                                                                                                           

Epoch 69 Train Summary | Time: 27.80s | Loss: 1.1811 | Acc: 0.6572 | Tokens: 660168


                                                                                     

 Eval Summary | Loss: 3.0567 | Accuracy: 0.4261 | SSMD: 0.2036 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 70/100 =====


                                                                                                           

Epoch 70 Train Summary | Time: 27.81s | Loss: 1.1549 | Acc: 0.6633 | Tokens: 659989


                                                                                     

 Eval Summary | Loss: 3.0629 | Accuracy: 0.4228 | SSMD: 0.2117 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 71/100 =====


                                                                                                           

Epoch 71 Train Summary | Time: 27.81s | Loss: 1.1279 | Acc: 0.6693 | Tokens: 661122


                                                                                     

 Eval Summary | Loss: 3.1179 | Accuracy: 0.4249 | SSMD: 0.2080 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 72/100 =====


                                                                                                           

Epoch 72 Train Summary | Time: 27.81s | Loss: 1.1010 | Acc: 0.6766 | Tokens: 660485


                                                                                     

 Eval Summary | Loss: 3.1399 | Accuracy: 0.4216 | SSMD: 0.2092 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 73/100 =====


                                                                                                           

Epoch 73 Train Summary | Time: 27.80s | Loss: 1.0707 | Acc: 0.6844 | Tokens: 660424


                                                                                     

 Eval Summary | Loss: 3.2053 | Accuracy: 0.4248 | SSMD: 0.1989 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 74/100 =====


                                                                                                           

Epoch 74 Train Summary | Time: 27.82s | Loss: 1.0491 | Acc: 0.6892 | Tokens: 660462


                                                                                     

 Eval Summary | Loss: 3.2609 | Accuracy: 0.4246 | SSMD: 0.2021 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 75/100 =====


                                                                                                           

Epoch 75 Train Summary | Time: 27.81s | Loss: 1.0257 | Acc: 0.6954 | Tokens: 659879


                                                                                     

 Eval Summary | Loss: 3.2838 | Accuracy: 0.4209 | SSMD: 0.1943 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 76/100 =====


                                                                                                           

Epoch 76 Train Summary | Time: 27.83s | Loss: 1.0031 | Acc: 0.7010 | Tokens: 659970


                                                                                     

 Eval Summary | Loss: 3.2888 | Accuracy: 0.4204 | SSMD: 0.1961 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 77/100 =====


                                                                                                           

Epoch 77 Train Summary | Time: 27.81s | Loss: 0.9786 | Acc: 0.7076 | Tokens: 660965


                                                                                     

 Eval Summary | Loss: 3.3429 | Accuracy: 0.4180 | SSMD: 0.1880 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 78/100 =====


                                                                                                           

Epoch 78 Train Summary | Time: 27.86s | Loss: 0.9528 | Acc: 0.7148 | Tokens: 660549


                                                                                     

 Eval Summary | Loss: 3.3670 | Accuracy: 0.4228 | SSMD: 0.1991 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 79/100 =====


                                                                                                           

Epoch 79 Train Summary | Time: 27.81s | Loss: 0.9318 | Acc: 0.7200 | Tokens: 660764


                                                                                     

 Eval Summary | Loss: 3.4157 | Accuracy: 0.4224 | SSMD: 0.1959 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 80/100 =====


                                                                                                           

Epoch 80 Train Summary | Time: 27.80s | Loss: 0.9064 | Acc: 0.7267 | Tokens: 660494


                                                                                     

 Eval Summary | Loss: 3.4745 | Accuracy: 0.4208 | SSMD: 0.1864 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 81/100 =====


                                                                                                           

Epoch 81 Train Summary | Time: 27.79s | Loss: 0.8845 | Acc: 0.7327 | Tokens: 660514


                                                                                     

 Eval Summary | Loss: 3.5207 | Accuracy: 0.4211 | SSMD: 0.1853 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 82/100 =====


                                                                                                           

Epoch 82 Train Summary | Time: 27.81s | Loss: 0.8652 | Acc: 0.7375 | Tokens: 660055


                                                                                     

 Eval Summary | Loss: 3.5580 | Accuracy: 0.4197 | SSMD: 0.1862 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 83/100 =====


                                                                                                           

Epoch 83 Train Summary | Time: 27.83s | Loss: 0.8429 | Acc: 0.7433 | Tokens: 659972


                                                                                     

 Eval Summary | Loss: 3.5833 | Accuracy: 0.4230 | SSMD: 0.1895 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 84/100 =====


                                                                                                           

Epoch 84 Train Summary | Time: 27.81s | Loss: 0.8245 | Acc: 0.7488 | Tokens: 660164


                                                                                     

 Eval Summary | Loss: 3.6372 | Accuracy: 0.4183 | SSMD: 0.1907 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 85/100 =====


                                                                                                           

Epoch 85 Train Summary | Time: 27.82s | Loss: 0.8057 | Acc: 0.7547 | Tokens: 660131


                                                                                     

 Eval Summary | Loss: 3.6526 | Accuracy: 0.4200 | SSMD: 0.1834 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 86/100 =====


                                                                                                           

Epoch 86 Train Summary | Time: 27.82s | Loss: 0.7868 | Acc: 0.7594 | Tokens: 659959


                                                                                     

 Eval Summary | Loss: 3.6887 | Accuracy: 0.4180 | SSMD: 0.1890 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 87/100 =====


                                                                                                           

Epoch 87 Train Summary | Time: 27.82s | Loss: 0.7641 | Acc: 0.7657 | Tokens: 660662


                                                                                     

 Eval Summary | Loss: 3.6992 | Accuracy: 0.4186 | SSMD: 0.1872 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 88/100 =====


                                                                                                           

Epoch 88 Train Summary | Time: 27.81s | Loss: 0.7496 | Acc: 0.7694 | Tokens: 659895


                                                                                     

 Eval Summary | Loss: 3.7802 | Accuracy: 0.4214 | SSMD: 0.1933 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 89/100 =====


                                                                                                           

Epoch 89 Train Summary | Time: 27.81s | Loss: 0.7322 | Acc: 0.7739 | Tokens: 659913


                                                                                     

 Eval Summary | Loss: 3.8153 | Accuracy: 0.4200 | SSMD: 0.1794 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 90/100 =====


                                                                                                           

Epoch 90 Train Summary | Time: 27.80s | Loss: 0.7164 | Acc: 0.7788 | Tokens: 660211


                                                                                     

 Eval Summary | Loss: 3.8477 | Accuracy: 0.4176 | SSMD: 0.1807 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 91/100 =====


                                                                                                           

Epoch 91 Train Summary | Time: 27.79s | Loss: 0.7008 | Acc: 0.7831 | Tokens: 660266


                                                                                     

 Eval Summary | Loss: 3.8861 | Accuracy: 0.4173 | SSMD: 0.1785 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 92/100 =====


                                                                                                           

Epoch 92 Train Summary | Time: 27.81s | Loss: 0.6818 | Acc: 0.7885 | Tokens: 659953


                                                                                     

 Eval Summary | Loss: 3.9264 | Accuracy: 0.4228 | SSMD: 0.1859 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 93/100 =====


                                                                                                           

Epoch 93 Train Summary | Time: 27.79s | Loss: 0.6702 | Acc: 0.7919 | Tokens: 660291


                                                                                     

 Eval Summary | Loss: 3.9356 | Accuracy: 0.4188 | SSMD: 0.1821 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 94/100 =====


                                                                                                           

Epoch 94 Train Summary | Time: 27.81s | Loss: 0.6530 | Acc: 0.7963 | Tokens: 660041


                                                                                     

 Eval Summary | Loss: 3.9844 | Accuracy: 0.4227 | SSMD: 0.1828 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 95/100 =====


                                                                                                           

Epoch 95 Train Summary | Time: 27.82s | Loss: 0.6403 | Acc: 0.8005 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 3.9887 | Accuracy: 0.4185 | SSMD: 0.1800 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 96/100 =====


                                                                                                           

Epoch 96 Train Summary | Time: 27.81s | Loss: 0.6280 | Acc: 0.8038 | Tokens: 660225


                                                                                     

 Eval Summary | Loss: 4.0075 | Accuracy: 0.4174 | SSMD: 0.1809 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 97/100 =====


                                                                                                           

Epoch 97 Train Summary | Time: 27.86s | Loss: 0.6122 | Acc: 0.8077 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 4.0663 | Accuracy: 0.4188 | SSMD: 0.1918 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 98/100 =====


                                                                                                           

Epoch 98 Train Summary | Time: 27.82s | Loss: 0.6026 | Acc: 0.8112 | Tokens: 660379


                                                                                     

 Eval Summary | Loss: 4.0632 | Accuracy: 0.4190 | SSMD: 0.1858 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 99/100 =====


                                                                                                           

Epoch 99 Train Summary | Time: 27.80s | Loss: 0.5875 | Acc: 0.8158 | Tokens: 660537


                                                                                     

 Eval Summary | Loss: 4.1431 | Accuracy: 0.4198 | SSMD: 0.1772 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

===== Epoch 100/100 =====


                                                                                                            

Epoch 100 Train Summary | Time: 27.82s | Loss: 0.5754 | Acc: 0.8192 | Tokens: 659747


                                                                                     

 Eval Summary | Loss: 4.1587 | Accuracy: 0.4185 | SSMD: 0.1822 | GS: 0.9328 | Eval Seq Count: 77 | Eval Tokens: 36702

--- Training Finished ---

--- Final Evaluation on Test Set ---


                                                                                     

 Eval Summary | Loss: 2.1751 | Accuracy: 0.4686 | SSMD: 0.3782 | GS: 0.9231 | Eval Seq Count: 77 | Eval Tokens: 36320

--- Test Set Results (using BEST saved model) ---
 Loss:     2.1751
 Accuracy: 0.4686
 SSMD:     0.3782
 GS:       0.9231
 Eval Seqs:77
 Eval Toks:36320
Output files potentially saved in: /kaggle/working/melody_model_output


In [2]:
!pip install pretty_midi

Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m58.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido>=1.1.16 (from pretty_midi)
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Downloading mido-1.3.3-py3-none-any.whl (54 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.6/54.6 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: pretty_midi
  Building wheel for pretty_midi (setup.py) ... [?25l[?25hdone
  Created wheel for pretty_midi: filename=pretty_midi-0.2.10-py3-none-any.whl size=5592287 sha256=b2f2a2a7a5f8cf5000fba60c4ff70ebdd307b2696e81c349991be7d05f664f5d
  Stored in directory: /root/.cache/pip/wheels/e6/95/ac/15ceaeb2823b04d8e638fd1495357adb8d26c00ccac9d7782e
Successfully built pretty_midi
Installing collected packages: mido, pretty_

In [3]:
# =============================================================================
# COMPLETE Generation Script: Melody Transformer-XL with RoPE & Conditioning
# Version: Fixed _init_weights, Improved Polyphony Filter, Progression Mode, Seq Naming
#          *** USER ACTION REQUIRED: Provide full chord progression data ***
# =============================================================================
# Script generated around: Friday, April 25, 2025 at 3:47 PM IST (Bhopal, India time)
# =============================================================================

import os
import json
import warnings
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import sys
import traceback
import gc # Garbage Collector
from typing import Optional, List, Dict, Tuple, Any, Set
from dataclasses import dataclass, field, fields # Added fields
from pathlib import Path
from tqdm import tqdm # Use standard tqdm
import collections # Needed for defaultdict in collate_fn
import logging
import re # Needed for sequential filename logic

# --- Install and Import PrettyMIDI ---
try:
    import pretty_midi
except ImportError:
    print("Installing pretty_midi...")
    try:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "pretty_midi"])
        import pretty_midi
        print("pretty_midi installed successfully.")
    except Exception as e:
        print(f"Failed to install pretty_midi automatically: {e}")
        print("Please install it manually (e.g., 'pip install pretty_midi') and restart the script.")
        sys.exit(1)


# --- CUDA specific imports ---
from torch.cuda.amp import GradScaler
import torch.amp

# For timestamp and location context
try:
    import pytz
except ImportError:
     print("Installing pytz...")
     try:
        import subprocess
        subprocess.check_call([sys.executable, "-m", "pip", "install", "pytz"])
        import pytz
        print("pytz installed successfully.")
     except Exception as e:
        print(f"Failed to install pytz automatically: {e}")
        print("Please install it manually (e.g., 'pip install pytz') and restart the script.")
        pytz = None # Continue without timezone awareness
import datetime


# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# !! NOTE: All necessary class definitions are included in this script.     !!
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

# --- Define Vocabulary Constants ---
MELODY_PAD_TOKEN: str = "<PAD>"
MELODY_UNK_TOKEN: str = "<UNK>"
CHORD_VOCAB_FILENAME: str = "chord_progression_vocab.json"
CHORD_PAD_TOKEN: str = "<PAD>" # Assuming pad token for chords is also <PAD>

# --- TrainingConfig Definition ---
# Note: This config is loaded from the checkpoint, but defaults are here for reference
@dataclass
class TrainingConfig:
    # --- Paths ---
    midi_root_dir: str = "LOCAL_PATH_IGNORE"
    chord_data_dir: str = "/kaggle/input/advance-h-rpe" # Example Path
    melody_data_path: str = "/kaggle/input/new-melody-model-new-approach-1/training_data.jsonl" # Example Path
    melody_vocab_path: str = "/kaggle/input/new-melody-model-new-approach-1/event_vocab.json" # Example Path
    output_dir: str = "/kaggle/working/melody_model_output" # Example Path

    # --- Vocab Sizes & Padding ---
    melody_vocab_size: int = 0
    chord_vocab_size: int = 0
    melody_pad_token_id: int = 0
    chord_pad_token_id: int = 0

    # --- Model Architecture ---
    n_layer: int = 8
    d_model: int = 512
    n_head: int = 8
    d_head: int = 64 # Will be calculated if d_model % n_head == 0
    d_inner: int = 2048
    dropout: float = 0.1
    mem_len: int = 256
    rope_theta: float = 10000.0
    num_chord_features: int = 3
    condition_proj_dim: int = 128
    chord_emb_dim: Optional[int] = 64

    # --- Training ---
    batch_size: int = 16
    num_epochs: int = 50
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    grad_clip_value: float = 1.0
    seed: int = 42

    # --- Data Loading ---
    max_seq_len: Optional[int] = 512
    train_split: float = 0.90
    val_split: float = 0.05
    test_split: float = 0.05
    num_dataload_workers: int = 2

    # --- Runtime ---
    amp_dtype: Optional[torch.dtype] = torch.float16

    # --- Make the class dict serializable ---
    def as_dict(self):
        d = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, torch.dtype):
                d[f.name] = str(value) # Convert torch.dtype to string
            elif isinstance(value, Path):
                d[f.name] = str(value) # Convert Path to string
            else:
                try:
                    json.dumps({f.name: value})
                    d[f.name] = value
                except TypeError:
                    logging.warning(f"Could not serialize field '{f.name}' of type {type(value)}. Storing its string representation.")
                    d[f.name] = str(value) # Fallback to string representation
        return d

# --- Model Class Definitions ---

# === RoPE Implementation ===
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(x, cos, sin):
    # x: [seq_len, bsz, n_head, d_head]
    # cos, sin: [seq_len, 1, 1, d_head] or [seq_len, d_head] -> unsqueezed later
    if cos.dim() == 2: # Handle case where cos/sin are [seq_len, d_head]
         cos = cos.unsqueeze(1).unsqueeze(2)
         sin = sin.unsqueeze(1).unsqueeze(2)
    elif cos.dim() != 4 or cos.shape[1] != 1 or cos.shape[2] != 1:
         raise ValueError(f"Unexpected shape for RoPE cos/sin: {cos.shape}")

    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=4096, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len_cached = -1
        self.register_buffer("cos_cached", None, persistent=False)
        self.register_buffer("sin_cached", None, persistent=False)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Optimization: Only update cache if needed
        if seq_len <= self.max_seq_len_cached and \
           self.cos_cached is not None and self.sin_cached is not None and \
           self.cos_cached.device == device and self.cos_cached.dtype == dtype:
              return
        # Increase cache size adaptively but ensure it covers max_position_embeddings
        new_cache_len = max(seq_len, self.max_position_embeddings) # Use max_position_embeddings as a lower bound
        self.max_seq_len_cached = new_cache_len # Update cached length *after* calculation
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos().to(dtype).detach() # Ensure correct dtype
        self.sin_cached = emb.sin().to(dtype).detach() # Ensure correct dtype
        logging.debug(f"RoPE cache updated: seq_len={self.max_seq_len_cached}, device={device}, dtype={dtype}")


    def forward(self, x: torch.Tensor, seq_len: int, start_pos: int = 0):
        device = x.device
        dtype = x.dtype
        required_len = start_pos + seq_len # Total length needed from cache

        # Check if cache needs update (size, device, or dtype mismatch)
        if required_len > self.max_seq_len_cached or \
           self.cos_cached is None or self.sin_cached is None or \
           self.cos_cached.device != device or self.cos_cached.dtype != dtype:
            # Determine new cache size - should be at least required_len and max_position_embeddings
            new_max_len = max(self.max_position_embeddings, required_len)
            self._set_cos_sin_cache(seq_len=new_max_len, device=device, dtype=dtype)

        # Calculate end position for slicing
        end_pos = start_pos + seq_len

        # Slice the cache - ensure indices are within bounds
        start_pos_clamped = max(0, min(start_pos, self.max_seq_len_cached - 1))
        end_pos_clamped = max(0, min(end_pos, self.max_seq_len_cached))

        # Handle potential empty slice after clamping
        if start_pos_clamped >= end_pos_clamped:
             logging.warning(f"RoPE: start_pos {start_pos} >= end_pos {end_pos} (or became so after clamping). Returning empty tensors.")
             return torch.empty((0, self.dim), device=device, dtype=dtype), torch.empty((0, self.dim), device=device, dtype=dtype)

        cos = self.cos_cached[start_pos_clamped : end_pos_clamped]
        sin = self.sin_cached[start_pos_clamped : end_pos_clamped]

        # Verify the sliced length matches the expected seq_len
        if cos.shape[0] != seq_len:
            logging.warning(f"RoPE: Sliced length {cos.shape[0]} does not match expected seq_len {seq_len}. "
                            f"start_pos={start_pos}, end_pos={end_pos}, "
                            f"clamped=[{start_pos_clamped}:{end_pos_clamped}], cache_len={self.max_seq_len_cached}. "
                            f"This might indicate an issue with position calculation.")
            # Attempt to adjust if possible, otherwise could lead to errors later
            if cos.shape[0] < seq_len:
                 logging.error(f"RoPE: Sliced cache is too short ({cos.shape[0]} vs {seq_len}). Cannot proceed safely.")
                 raise IndexError("RoPE cache slicing resulted in tensor shorter than expected seq_len.")
            else: # Truncate if too long
                 cos = cos[:seq_len]
                 sin = sin[:seq_len]
                 logging.warning(f"RoPE: Truncated sliced cache to match seq_len {seq_len}.")


        return cos, sin


# === Model Components ===
class RelPartialLearnableMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, config: TrainingConfig, layer_idx: int):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.config = config
        self.layer_idx = layer_idx

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.drop = nn.Dropout(dropout)

        # Calculate max positions needed based on mem_len and max_seq_len
        # Use a reasonable upper bound if max_seq_len is None during generation setup
        gen_max_seq_len = config.max_seq_len if config.max_seq_len is not None else 2048 # Default large if None
        max_positions = config.mem_len + gen_max_seq_len

        self.rotary_emb = RotaryEmbedding(
            dim=self.d_head,
            max_position_embeddings=max_positions,
            base=config.rope_theta
        )
        self.scale = 1.0 / (d_head ** 0.5)

    def forward(self, w: torch.Tensor, mems: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        qlen, bsz, d_model_in = w.size() # w: (qlen, bsz, d_model)
        if d_model_in != self.d_model:
            raise ValueError(f"Layer {self.layer_idx}: Input tensor d_model mismatch. Expected {self.d_model}, got {d_model_in}")

        mlen = mems.size(0) if mems is not None and mems.dim() == 3 and mems.shape[0] > 0 else 0
        klen = mlen + qlen # Total length of key/value sequence (memory + current)

        # Validate memory shape if provided
        if mlen > 0:
            if mems.shape[1] != bsz or mems.shape[2] != self.d_model:
                logging.warning(f"Layer {self.layer_idx}: Mem shape {mems.shape} incompatible with Input shape {w.shape}. Discarding memory.")
                cat = w # Discard memory
                mlen = 0
                klen = qlen
            else:
                cat = torch.cat([mems, w], dim=0) # Prepend memory: (klen, bsz, d_model)
        else:
            cat = w # No memory: (qlen, bsz, d_model)

        # --- QKV Projection ---
        # Project concatenated memory and input
        w_heads = self.qkv_net(cat) # (klen, bsz, 3 * n_head * d_head)
        w_heads = w_heads.view(klen, bsz, self.n_head, 3 * self.d_head)
        q_head_raw, k_head_raw, v_head = torch.chunk(w_heads, 3, dim=-1) # Each: (klen, bsz, n_head, d_head)

        # --- Separate Query, Key, Value ---
        # Query comes only from the current input part
        q_head = q_head_raw[-qlen:] # (qlen, bsz, n_head, d_head)
        # Key and Value use the full concatenated sequence (memory + input)
        k_head = k_head_raw # (klen, bsz, n_head, d_head)
        # v_head already has the correct shape: (klen, bsz, n_head, d_head)

        # --- Apply Rotary Positional Embeddings (RoPE) ---
        # Calculate RoPE embeddings for the full key length (klen) starting from position 0
        cos_k, sin_k = self.rotary_emb(k_head, seq_len=klen, start_pos=0) # cos/sin: [klen, d_head]

        # Calculate RoPE embeddings for the query length (qlen) starting from the memory length (mlen)
        cos_q, sin_q = self.rotary_emb(q_head, seq_len=qlen, start_pos=mlen) # cos/sin: [qlen, d_head]


        # Check for shape mismatches before applying RoPE
        if cos_q.shape[0] != qlen or sin_q.shape[0] != qlen:
            logging.error(f"Layer {self.layer_idx}: RoPE query cos/sin shape mismatch! Expected {qlen}, got {cos_q.shape[0]}. Mlen={mlen}, Klen={klen}. Skipping RoPE for query.")
            q_head_rot = q_head # Skip RoPE for query as a fallback
        else:
            q_head_rot = apply_rotary_pos_emb(q_head, cos_q, sin_q) # Apply RoPE to Query

        if cos_k.shape[0] != klen or sin_k.shape[0] != klen:
             logging.error(f"Layer {self.layer_idx}: RoPE key cos/sin shape mismatch! Expected {klen}, got {cos_k.shape[0]}. Mlen={mlen}, Klen={klen}. Skipping RoPE for key.")
             k_head_rot = k_head # Skip RoPE for key as a fallback
        else:
            k_head_rot = apply_rotary_pos_emb(k_head, cos_k, sin_k) # Apply RoPE to Key


        # --- Prepare for Attention Calculation ---
        # Permute for batch matrix multiplication: (bsz, n_head, seq_len, d_head)
        q_head_ = q_head_rot.permute(1, 2, 0, 3) # (bsz, n_head, qlen, d_head)
        k_head_ = k_head_rot.permute(1, 2, 0, 3) # (bsz, n_head, klen, d_head)
        v_head_ = v_head.permute(1, 2, 0, 3)     # (bsz, n_head, klen, d_head)

        # --- Calculate Attention Scores ---
        attn_score = torch.matmul(q_head_, k_head_.transpose(-2, -1)) # (bsz, n_head, qlen, klen)
        attn_score = attn_score * self.scale # Scale scores

        # --- Apply Attention Mask ---
        if attn_mask is not None:
            if attn_mask.dim() == 2: mask_to_apply = attn_mask.unsqueeze(0).unsqueeze(0) # -> (1, 1, qlen, klen)
            elif attn_mask.dim() == 4: mask_to_apply = attn_mask
            else: logging.warning(f"L{self.layer_idx}: Unexpected mask shape {attn_mask.shape}. Ignored."); mask_to_apply = None

            if mask_to_apply is not None:
                mask_to_apply = mask_to_apply.to(device=attn_score.device, dtype=torch.bool)
                # Check compatibility before applying
                if mask_to_apply.shape[-2:] == attn_score.shape[-2:]:
                    if mask_to_apply.shape[0] != attn_score.shape[0] and mask_to_apply.shape[0] == 1: mask_to_apply = mask_to_apply.expand(attn_score.shape[0], -1, -1, -1)
                    if mask_to_apply.shape[1] != attn_score.shape[1] and mask_to_apply.shape[1] == 1: mask_to_apply = mask_to_apply.expand(-1, attn_score.shape[1], -1, -1)
                    if mask_to_apply.shape == attn_score.shape: attn_score = attn_score.masked_fill(mask_to_apply, torch.finfo(attn_score.dtype).min)
                    else: logging.warning(f"L{self.layer_idx}: Mask/score shape mismatch after expansion {mask_to_apply.shape} vs {attn_score.shape}. Mask Ignored.")
                else: logging.warning(f"L{self.layer_idx}: Mask dims incompatible {mask_to_apply.shape[-2:]} vs {attn_score.shape[-2:]}. Mask Ignored.")

        # --- Calculate Attention Probabilities ---
        attn_prob = F.softmax(attn_score.float(), dim=-1).to(attn_score.dtype)
        attn_prob = self.drop(attn_prob)

        # --- Apply Head Mask (Optional) ---
        if head_mask is not None: attn_prob = attn_prob * head_mask.to(attn_prob.device)

        # --- Calculate Attention Output ---
        attn_vec = torch.matmul(attn_prob, v_head_).permute(2, 0, 1, 3).contiguous()
        attn_vec = attn_vec.view(qlen, bsz, self.n_head * self.d_head)

        # --- Final Output Projection ---
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        return attn_out


class TransformerXLLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, config: TrainingConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, config=config, layer_idx=layer_idx)
        self.pos_ff = nn.Sequential(nn.Linear(d_model, d_inner), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_inner, d_model), nn.Dropout(dropout))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, dec_inp: torch.Tensor, mems: Optional[torch.Tensor], attn_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        x_norm1 = self.norm1(dec_inp)
        attn_output = self.dec_attn(w=x_norm1, mems=mems, attn_mask=attn_mask, head_mask=head_mask)
        h = dec_inp + self.dropout(attn_output)
        h_norm2 = self.norm2(h)
        ff_output = self.pos_ff(h_norm2)
        output = h + self.dropout(ff_output)
        return output

class MelodyTransformerXL(nn.Module):
    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

        # --- Validate Config ---
        if not (isinstance(config.melody_vocab_size, int) and config.melody_vocab_size > 0): raise ValueError(f"Invalid melody_vocab_size: {config.melody_vocab_size}")
        if not (isinstance(config.melody_pad_token_id, int) and 0 <= config.melody_pad_token_id < config.melody_vocab_size): raise ValueError(f"Invalid melody_pad_token_id: {config.melody_pad_token_id}")
        self.use_chord_embedding = isinstance(config.chord_emb_dim, int) and config.chord_emb_dim > 0
        if self.use_chord_embedding:
            if not (isinstance(config.chord_vocab_size, int) and config.chord_vocab_size > 0): raise ValueError(f"Invalid chord_vocab_size: {config.chord_vocab_size}")
            if not (isinstance(config.chord_pad_token_id, int) and 0 <= config.chord_pad_token_id < config.chord_vocab_size): raise ValueError(f"Invalid chord_pad_token_id: {config.chord_pad_token_id}")
        if not (isinstance(config.d_model, int) and config.d_model > 0): raise ValueError(f"Invalid d_model: {config.d_model}")
        if not (isinstance(config.n_head, int) and config.n_head > 0): raise ValueError(f"Invalid n_head: {config.n_head}")
        if config.d_model % config.n_head != 0: raise ValueError(f"d_model ({config.d_model}) not divisible by n_head ({config.n_head}).")
        self.d_head = config.d_model // config.n_head
        config.d_head = self.d_head
        if not (isinstance(config.mem_len, int) and config.mem_len >= 0): raise ValueError(f"Invalid mem_len: {config.mem_len}")
        if not (isinstance(config.n_layer, int) and config.n_layer > 0): raise ValueError(f"Invalid n_layer: {config.n_layer}")
        if not (isinstance(config.num_chord_features, int) and config.num_chord_features > 0): raise ValueError(f"Invalid num_chord_features: {config.num_chord_features}")
        if not (isinstance(config.condition_proj_dim, int) and config.condition_proj_dim >= 0): raise ValueError(f"Invalid condition_proj_dim: {config.condition_proj_dim}")

        # --- Model Components ---
        self.d_model = config.d_model; self.n_head = config.n_head; self.mem_len = config.mem_len; self.n_layer = config.n_layer
        self.melody_emb = nn.Embedding(config.melody_vocab_size, config.d_model, padding_idx=config.melody_pad_token_id)

        # --- Conditioning Processing ---
        condition_proj_dim = max(1, config.condition_proj_dim) if config.condition_proj_dim > 0 else 0
        total_conditioning_dim = 0
        self.chord_feature_processor = None
        if condition_proj_dim > 0 :
            self.chord_feature_processor = nn.Linear(config.num_chord_features, condition_proj_dim)
            total_conditioning_dim += condition_proj_dim
        else:
            logging.info("Raw chord features projection disabled (dim=0).")

        self.chord_emb = None
        if self.use_chord_embedding:
            chord_emb_dim = max(1, config.chord_emb_dim)
            self.chord_emb = nn.Embedding(config.chord_vocab_size, chord_emb_dim, padding_idx=config.chord_pad_token_id)
            total_conditioning_dim += chord_emb_dim

        # --- Input Projection ---
        self.input_proj = None
        if total_conditioning_dim > 0:
            combined_input_dim = config.d_model + total_conditioning_dim
            self.input_proj = nn.Linear(combined_input_dim, config.d_model)
            logging.info(f"Input projection: Combined dim {combined_input_dim} -> {config.d_model}")
        else:
            logging.info("No conditioning used, skipping input projection.")

        self.drop = nn.Dropout(config.dropout)

        # --- Transformer Layers ---
        self.layers = nn.ModuleList([
            TransformerXLLayer(n_head=self.n_head, d_model=self.d_model, d_head=self.d_head, d_inner=config.d_inner, dropout=config.dropout, config=config, layer_idx=i)
            for i in range(config.n_layer)
        ])

        # --- Output Layer ---
        self.final_norm = nn.LayerNorm(config.d_model)
        self.out_layer = nn.Linear(config.d_model, config.melody_vocab_size, bias=False)

        # Weight Tying
        if self.melody_emb.embedding_dim == self.out_layer.in_features:
           self.out_layer.weight = self.melody_emb.weight
           logging.info("Tying input melody embedding weights with the final output layer.")
        else:
            logging.warning(f"Output layer in_features ({self.out_layer.in_features}) != Melody embedding dim ({self.melody_emb.embedding_dim}). Weights not tied.")

        # Apply custom weight initialization AFTER defining all layers
        self.apply(self._init_weights)
        logging.info(f"MelodyTransformerXL initialized: Layers={config.n_layer}, d_model={config.d_model}, n_head={config.n_head}, mem_len={config.mem_len}, ChordEmb={self.use_chord_embedding}, ChordFeat={self.chord_feature_processor is not None}")

    @property
    def dtype(self) -> torch.dtype:
        try: return next(self.parameters()).dtype
        except StopIteration: return torch.get_default_dtype()

    def _init_weights(self, module):
        """Initializes weights of linear and embedding layers."""
        scale = 1.0
        # Heuristic scale factor based on number of layers (from GPT-2)
        if hasattr(self.config, 'n_layer') and self.config.n_layer > 0:
            scale = 1 / math.sqrt(2.0 * self.config.n_layer)

        if isinstance(module, nn.Linear):
            # Initialize linear layers with small normal distribution
            nn.init.normal_(module.weight, mean=0.0, std=0.02 * scale)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Initialize embedding layers with small normal distribution
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # Zero out padding token embedding correctly
            if module.padding_idx is not None:
                # <<< FIX: Use with torch.no_grad() context manager >>>
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0)
                # <<< END FIX >>>
        elif isinstance(module, nn.LayerNorm):
            # Initialize LayerNorm bias to 0 and weight to 1
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)


    def _update_mems(self, hids: List[Optional[torch.Tensor]], mems: List[Optional[torch.Tensor]], mlen: int) -> List[Optional[torch.Tensor]]:
        if mlen <= 0 or not hids: return [None] * (self.n_layer + 1)
        if mems is None or all(m is None or m.numel() == 0 for m in mems):
            return [(h[-mlen:].detach() if h is not None and h.dim() > 1 and h.shape[0] > 0 else None) for h in hids]
        if len(hids) != len(mems):
            logging.error(f"BUG: Mismatch hids({len(hids)}) vs mems({len(mems)}). Resetting memory.")
            return [(h[-mlen:].detach() if h is not None and h.dim() > 1 and h.shape[0] > 0 else None) for h in hids]
        new_mems = []
        with torch.no_grad():
            for i, (hid, mem) in enumerate(zip(hids, mems)):
                if hid is None: new_mems.append(mem.detach() if mem is not None else None); logging.warning(f"Hid layer {i} None in mem update."); continue
                if hid.dim() < 3 or hid.shape[0] == 0: logging.warning(f"Hid layer {i} bad shape {hid.shape}."); new_mems.append(mem.detach() if mem is not None else None); continue
                if mem is not None and mem.dim() == 3 and mem.numel() > 0:
                    if mem.shape[1:] == hid.shape[1:]: cat = torch.cat([mem, hid], dim=0)
                    else: logging.warning(f"Mem/Hid shape mismatch L{i}: Mem={mem.shape}, Hid={hid.shape}. Resetting mem."); cat = hid
                else: cat = hid
                new_mems.append(cat[-mlen:].detach())
        return new_mems

    def init_mems(self, bsz: int, device: torch.device, dtype: torch.dtype) -> List[Optional[torch.Tensor]]:
        return [None] * (self.n_layer + 1) # Initialize with Nones, let _update_mems handle creation

    def _create_attn_mask(self, qlen: int, mlen: int, device: torch.device) -> Optional[torch.Tensor]:
        if qlen <= 0: return None
        klen = mlen + qlen
        mask = torch.triu(torch.ones(qlen, klen, device=device, dtype=torch.bool), diagonal=1 + mlen)
        return mask

    def forward(
        self,
        event_ids: torch.Tensor,                  # (bsz, qlen) [Long]
        conditioning_chord_ids: torch.Tensor,      # (bsz, qlen) [Long]
        conditioning_root_pc: torch.Tensor,       # (bsz, qlen) [Float/Half]
        conditioning_quality_code: torch.Tensor, # (bsz, qlen) [Float/Half]
        conditioning_function_code: torch.Tensor, # (bsz, qlen) [Float/Half]
        mems: Optional[List[Optional[torch.Tensor]]] = None # Memory
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]]]:

        bsz, qlen = event_ids.size(); device = event_ids.device; target_dtype = self.dtype
        if qlen == 0: logging.warning("Forward qlen=0."); return torch.empty((bsz, 0, self.config.melody_vocab_size), device=device, dtype=target_dtype), mems if mems is not None else self.init_mems(bsz, device, target_dtype)
        if self.mem_len > 0:
            if mems is None: mems = self.init_mems(bsz, device, target_dtype)
            elif not isinstance(mems, list) or len(mems) != self.n_layer + 1: logging.warning(f"Incorrect mem list. Resetting."); mems = self.init_mems(bsz, device, target_dtype)
            # Validate memory compatibility before use (more robust)
            valid_mems = True
            for i, mem in enumerate(mems):
                if mem is not None and mem.numel() > 0:
                    if mem.shape[1] != bsz or mem.shape[2] != self.d_model or mem.device != device or mem.dtype != target_dtype:
                        logging.warning(f"Memory state at index {i} incompatible (Shape:{mem.shape}, DType:{mem.dtype}, Device:{mem.device} vs Input: bsz={bsz}, d_model={self.d_model}, dtype={target_dtype}, device={device}). Resetting memory.")
                        mems = self.init_mems(bsz, device, target_dtype)
                        valid_mems = False
                        break
            mlen = mems[0].size(0) if valid_mems and mems[0] is not None and mems[0].dim() == 3 else 0
        else: mems = [None] * (self.n_layer + 1); mlen = 0

        clamped_event_ids = event_ids.clamp(0, self.config.melody_vocab_size - 1); melody_embedded = self.melody_emb(clamped_event_ids)
        all_conditioning_tensors = []
        if self.chord_feature_processor is not None:
            cond_features_raw = torch.stack([conditioning_root_pc, conditioning_quality_code, conditioning_function_code], dim=-1).to(target_dtype)
            cond_features_proj = F.relu(self.chord_feature_processor(cond_features_raw)); all_conditioning_tensors.append(cond_features_proj)
        if self.use_chord_embedding and self.chord_emb is not None:
             clamped_chord_ids = conditioning_chord_ids.clamp(0, self.config.chord_vocab_size - 1)
             chord_embedded = self.chord_emb(clamped_chord_ids); all_conditioning_tensors.append(chord_embedded)

        if all_conditioning_tensors:
            cond_combined = torch.cat(all_conditioning_tensors, dim=-1); combined_input_features = torch.cat([melody_embedded, cond_combined], dim=-1)
            if self.input_proj is not None: core_input = self.input_proj(combined_input_features)
            else: logging.error("Conditioning exists but input_proj is None."); core_input = melody_embedded
        else: core_input = self.input_proj(melody_embedded) if self.input_proj is not None else melody_embedded

        core_input = self.drop(core_input).transpose(0, 1).contiguous()
        attn_mask = self._create_attn_mask(qlen, mlen, device)
        hids_for_mem = [core_input]; layer_input = core_input
        for i, layer in enumerate(self.layers):
            layer_mem = mems[i] if mems is not None else None
            try: layer_output = layer(layer_input, mems=layer_mem, attn_mask=attn_mask); hids_for_mem.append(layer_output); layer_input = layer_output
            except Exception as e: raise RuntimeError(f"Failed during forward pass in layer {i}") from e
        new_mems = self._update_mems(hids_for_mem, mems, self.mem_len)
        core_output = self.drop(layer_input); final_output = self.final_norm(core_output)
        logits = self.out_layer(final_output).transpose(0, 1).contiguous()
        return logits, new_mems
# --- END Model Class Definitions ---


# === MIDI Conversion Function ===
def events_to_midi(event_sequence: List[str],
                   output_midi_path: str = "generated_melody.mid",
                   resolution: int = 480,
                   default_velocity: int = 100,
                   initial_tempo_qpm: float = 120.0):
    """
    Converts a sequence of event tokens into a MIDI file using pretty_midi.
    Handles potential errors during conversion and provides a summary.
    """
    if not event_sequence:
        print("Warning: Empty event sequence provided to events_to_midi. No MIDI file generated.")
        return

    print(f"Starting MIDI conversion for {len(event_sequence)} events...")
    logging.info(f"Starting MIDI conversion for {len(event_sequence)} events...")

    # Initialize counters outside the main try block
    skipped_events = 0
    note_count = 0
    malformed_events = 0
    unrecognized_events = 0
    zero_duration_notes = 0
    negative_time_shifts = 0
    note_add_count = 0
    note_skip_zero_dur_count = 0
    notes_closed_at_end = 0
    zero_dur_at_end = 0

    try:
        midi_data = pretty_midi.PrettyMIDI(resolution=resolution, initial_tempo=initial_tempo_qpm)
        instrument = pretty_midi.Instrument(program=0, is_drum=False, name='Generated Melody') # Program 0 = Acoustic Grand Piano

        current_time_seconds = 0.0
        current_velocity = default_velocity
        active_notes = {} # pitch -> (start_time_seconds, velocity) - Tracks notes currently playing

        # Define expected special tokens
        special_tokens = {MELODY_PAD_TOKEN, MELODY_UNK_TOKEN, "<EOS>", "<START>", "<END>"} # Use defined constants

        time_shift_unit = 0.001 # Assume TIME_SHIFT value is in milliseconds

        for i, event_str in enumerate(tqdm(event_sequence, desc="Converting Events", leave=False, unit="event")):
            # Use 'continue' for skipping within the loop
            if not isinstance(event_str, str) or not event_str or event_str in special_tokens:
                skipped_events += 1
                continue

            try:
                parts = event_str.split('_')
                if len(parts) < 2:
                    logging.warning(f"Event {i}: Malformed event '{event_str}'. Skipping.")
                    malformed_events += 1
                    continue

                event_type = "_".join(parts[:-1]).upper()
                value_str = parts[-1]

                try:
                    value = int(value_str)
                except ValueError:
                    logging.warning(f"Event {i}: Could not parse integer value '{value_str}' in event '{event_str}'. Skipping.")
                    malformed_events += 1
                    continue

                # --- Event Logic ---
                if event_type == "TIME_SHIFT":
                    is_neg = value < 0
                    if is_neg:
                        logging.warning(f"Event {i}: Negative TIME_SHIFT ({value}) encountered. Clamping to 0.")
                        negative_time_shifts += 1
                        value = 0
                    time_delta_seconds = value * time_shift_unit
                    current_time_seconds += time_delta_seconds

                elif event_type == "NOTE_ON":
                    pitch = value
                    if not (0 <= pitch <= 127): logging.warning(f"Event {i}: Invalid MIDI pitch {pitch} in NOTE_ON. Skipping."); malformed_events += 1; continue

                    if pitch in active_notes: # End previous instance of same note if active
                        prev_start_time, prev_vel = active_notes[pitch]
                        if current_time_seconds > prev_start_time + 1e-6: # Ensure positive duration
                            note = pretty_midi.Note(velocity=prev_vel, pitch=pitch, start=prev_start_time, end=current_time_seconds)
                            instrument.notes.append(note); note_count += 1; note_add_count += 1
                        else:
                            zero_duration_notes += 1; note_skip_zero_dur_count += 1 # Count it as skipped due to zero duration
                            # Optionally log this specific case:
                            # logging.debug(f"Event {i}: Re-triggering NOTE_ON_{pitch} at same time {current_time_seconds:.4f}s. Previous instance skipped.")

                    active_notes[pitch] = (current_time_seconds, current_velocity) # Start new note

                elif event_type == "NOTE_OFF":
                    pitch = value
                    if not (0 <= pitch <= 127): logging.warning(f"Event {i}: Invalid MIDI pitch {pitch} in NOTE_OFF. Skipping."); malformed_events += 1; continue
                    if pitch in active_notes:
                        start_time_sec, vel = active_notes.pop(pitch) # Remove from active notes
                        end_time_sec = current_time_seconds
                        if end_time_sec > start_time_sec + 1e-6: # Ensure positive duration
                            note = pretty_midi.Note(velocity=vel, pitch=pitch, start=start_time_sec, end=end_time_sec)
                            instrument.notes.append(note); note_count += 1; note_add_count += 1
                        else:
                            logging.warning(f"Event {i}: NOTE_OFF_{pitch} at time {end_time_sec:.4f}s resulted in zero/negative duration (start time {start_time_sec:.4f}s). Skipping note.")
                            zero_duration_notes += 1
                            note_skip_zero_dur_count += 1
                    else:
                         logging.debug(f"Event {i}: NOTE_OFF_{pitch} received for inactive note. Ignoring.") # Less severe than warning


                elif event_type == "SET_VELOCITY":
                    current_velocity = max(0, min(127, value)) # Clamp velocity

                else:
                    logging.warning(f"Event {i}: Unrecognized event type '{event_type}' in '{event_str}'. Skipping.")
                    unrecognized_events += 1

            except Exception as e:
                logging.error(f"Unexpected error processing event '{event_str}' at index {i}: {e}", exc_info=True)
                skipped_events += 1 # Count as skipped due to unexpected error


        # --- Final Cleanup ---
        if active_notes:
            logging.warning(f"Found {len(active_notes)} notes still active at the end of the sequence. Closing them at final time {current_time_seconds:.4f}s.")
            for pitch, (start_time_sec, vel) in active_notes.items():
                 if current_time_seconds > start_time_sec + 1e-6:
                      note = pretty_midi.Note(velocity=vel, pitch=pitch, start=start_time_sec, end=current_time_seconds)
                      instrument.notes.append(note); note_count += 1; notes_closed_at_end += 1
                 else:
                      zero_duration_notes += 1; zero_dur_at_end += 1
            logging.info(f"Notes closed at end: {notes_closed_at_end}, Zero duration at end: {zero_dur_at_end}")

        midi_data.instruments.append(instrument)

        # --- Write MIDI File ---
        midi_data.write(output_midi_path) # Attempt to write the file
        print("-" * 30)
        print(f"MIDI Conversion Summary:")
        print(f"  MIDI file successfully written to: {output_midi_path}")
        print(f"  Notes created: {note_count} (Added: {note_add_count + notes_closed_at_end})") # Use updated counters
        total_problematic = skipped_events + malformed_events + unrecognized_events + zero_duration_notes + negative_time_shifts
        print(f"  Total problematic events skipped/handled: {total_problematic}")
        if skipped_events > 0: print(f"    - Special/Empty/Error tokens skipped: {skipped_events}")
        if malformed_events > 0: print(f"    - Malformed events skipped: {malformed_events}")
        if unrecognized_events > 0: print(f"    - Unrecognized events skipped: {unrecognized_events}")
        if zero_duration_notes > 0: print(f"    - Zero/Negative duration notes skipped: {zero_duration_notes} (During seq: {note_skip_zero_dur_count}, At end: {zero_dur_at_end})")
        if negative_time_shifts > 0: print(f"    - Negative TIME_SHIFT values clamped: {negative_time_shifts}")
        print("-" * 30)
        logging.info(f"MIDI file successfully written to {output_midi_path}")
        logging.info(f"Converted {note_count} notes. Skipped/Handled {total_problematic} problematic events.")

    except Exception as e:
        # This outer except block catches errors during setup or writing
        print(f"\nFATAL ERROR during MIDI conversion or writing to {output_midi_path}: {e}")
        logging.error(f"Error during MIDI conversion: {e}", exc_info=True)
        # Still print summary with potentially incomplete counts if error happened mid-loop
        print("-" * 30)
        print(f"MIDI Conversion Summary (ERROR OCCURRED):")
        print(f"  Attempted to write to: {output_midi_path}")
        print(f"  Notes converted before error (approx): {note_count}")
        total_problematic = skipped_events + malformed_events + unrecognized_events + zero_duration_notes + negative_time_shifts
        print(f"  Total problematic events encountered before error: {total_problematic}")
        print("-" * 30)


# === Generation Function (Over Progression) ===
@torch.no_grad()
def generate_melody_over_progression(
    model: MelodyTransformerXL,
    config: TrainingConfig,
    device: torch.device,
    start_event_id: int,              # Single start token ID
    full_cond_ids: torch.Tensor,      # Full sequence (1, full_length) [Long]
    full_cond_root: torch.Tensor,     # Full sequence (1, full_length) [Float/Half]
    full_cond_qual: torch.Tensor,     # Full sequence (1, full_length) [Float/Half]
    full_cond_func: torch.Tensor,     # Full sequence (1, full_length) [Float/Half]
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
    eos_token_id: Optional[int] = None,
    pad_token_id: int = 0
) -> List[int]:
    """
    Generates a melody sequence autoregressively using the trained model,
    conditioned step-by-step on the provided full chord progression.
    """
    model.eval() # Set model to evaluation mode
    model_dtype = model.dtype # Get model's expected data type

    # --- Input Validation ---
    if full_cond_ids.shape[0] != 1: raise ValueError("Generation currently only supports batch size 1.")
    bsz, full_length = full_cond_ids.shape
    if full_length == 0: raise ValueError("Conditioning sequence cannot be empty.")
    if not (full_cond_root.shape == (bsz, full_length) and full_cond_qual.shape == (bsz, full_length) and full_cond_func.shape == (bsz, full_length)):
        raise ValueError(f"Conditioning tensor lengths do not match. Chords:{full_cond_ids.shape}, Root:{full_cond_root.shape}, Qual:{full_cond_qual.shape}, Func:{full_cond_func.shape}")

    # Move conditioning tensors to the target device and ensure correct dtypes
    full_cond_ids = full_cond_ids.to(device)
    full_cond_root = full_cond_root.to(device=device, dtype=model_dtype)
    full_cond_qual = full_cond_qual.to(device=device, dtype=model_dtype)
    full_cond_func = full_cond_func.to(device=device, dtype=model_dtype)

    # --- Initialize Memory and Output List ---
    mems = model.init_mems(bsz=bsz, device=device, dtype=model_dtype)
    generated_ids = [start_event_id] # Start with the initial melody token
    current_event_ids = torch.tensor([[start_event_id]], dtype=torch.long, device=device) # Shape: (1, 1)

    # --- Configure AMP ---
    amp_gen_enabled = (config.amp_dtype == torch.float16 and device.type == 'cuda')
    amp_gen_dtype = config.amp_dtype if amp_gen_enabled else None
    device_type_str = device.type
    if amp_gen_enabled: print(f"Using AMP for generation with dtype: {amp_gen_dtype}")

    # --- Autoregressive Generation Loop ---
    # Generate one token for each step of the provided conditioning sequence
    # The length of generated melody will be equal to the length of conditioning
    generation_steps = full_length - 1 # We generate one less token than the conditioning length
    print(f"Starting generation for {generation_steps} steps over the provided progression (Total output length: {full_length})...")

    for step in tqdm(range(generation_steps), desc="Generating Melody", unit="token"):
        # Get conditioning for the *current* time step
        # Note: We use conditioning from step 'step' to predict event 'step+1'
        # Input shapes to model should be (bsz, qlen=1, dim)
        current_cond_id   = full_cond_ids[:, step:step+1]
        current_cond_root = full_cond_root[:, step:step+1]
        current_cond_qual = full_cond_qual[:, step:step+1]
        current_cond_func = full_cond_func[:, step:step+1]

        try:
            # --- Model Forward Pass (uses previous melody token + current conditioning) ---
            with torch.amp.autocast(device_type=device_type_str, dtype=amp_gen_dtype, enabled=amp_gen_enabled):
                 logits, new_mems = model(
                     event_ids=current_event_ids, # Shape (1, 1) - the previous token
                     conditioning_chord_ids=current_cond_id,
                     conditioning_root_pc=current_cond_root,
                     conditioning_quality_code=current_cond_qual,
                     conditioning_function_code=current_cond_func,
                     mems=mems # Pass the current memory state
                 )
            mems = new_mems

            # --- Process Logits for Next Token ---
            next_token_logits = logits[:, -1, :] # Shape: (1, vocab_size) - Logits for the token at this step

            # --- Token Sampling ---
            if temperature <= 0: # Greedy decoding
                 next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else: # Temperature + Top-K/Top-P sampling
                 scaled_logits = next_token_logits / temperature
                 if top_k > 0:
                      v, _ = torch.topk(scaled_logits, min(top_k, scaled_logits.size(-1)))
                      kth_value = v[:, [-1]]
                      scaled_logits[scaled_logits < kth_value] = -float('Inf')
                 if top_p > 0.0:
                      sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
                      cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                      sorted_indices_to_remove = cumulative_probs > top_p
                      sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                      sorted_indices_to_remove[..., 0] = 0 # Never remove the most likely token

                      indices_to_remove = torch.zeros_like(scaled_logits, dtype=torch.bool).scatter_(
                          dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
                      scaled_logits[indices_to_remove] = -float('Inf')
                 # Sample from the filtered distribution
                 probs = F.softmax(scaled_logits.float(), dim=-1)
                 next_token_id = torch.multinomial(probs, num_samples=1) # Shape: (1, 1)

            # --- Post-processing and Loop Control ---
            next_token_item = next_token_id.item() # Get the generated token ID as an integer

            # Check for End-of-Sequence token
            if eos_token_id is not None and next_token_item == eos_token_id:
                print(f"\nEOS token ({eos_token_id}) generated at step {step}. Stopping generation.")
                # Pad the rest of the sequence to match conditioning length if stopped early
                padding_needed = generation_length - step
                generated_ids.extend([pad_token_id] * padding_needed)
                break

            # Append generated token to the sequence
            generated_ids.append(next_token_item)

            # Update the input for the next iteration
            current_event_ids = next_token_id # Shape: (1, 1)

            # Optional: Periodic Garbage Collection / Cache Clearing
            if step > 0 and step % 200 == 0:
                gc.collect()
                if device.type == 'cuda': torch.cuda.empty_cache()

        # --- Error Handling within the Loop ---
        except RuntimeError as e:
            print(f"\nRuntimeError during generation step {step}: {e}")
            logging.error(f"RuntimeError during generation step {step}: {e}", exc_info=True)
            if 'cuda' in str(e).lower(): gc.collect(); torch.cuda.empty_cache(); print("CUDA cache cleared.")
            print("Stopping generation due to runtime error.")
            break
        except Exception as e:
             print(f"\nUnexpected error during generation step {step}: {e}")
             logging.error(f"Unexpected error during generation step {step}: {e}", exc_info=True)
             traceback.print_exc(); print("Stopping generation due to unexpected error.")
             break

    # Ensure final length matches conditioning length if generation finished normally
    if len(generated_ids) < full_length:
        padding_needed = full_length - len(generated_ids)
        logging.warning(f"Generation finished but sequence is short. Padding with {padding_needed} PAD tokens.")
        generated_ids.extend([pad_token_id] * padding_needed)

    return generated_ids


# === Helper Function to Load Real Priming Data ===
# (Not used in "generate over progression" mode, but kept for flexibility)
def load_real_prime_sequence(
    data_path: str,
    prime_len: int,
    melody_vocab: Dict[str, int],
    chord_vocab: Dict[str, int],
    config: TrainingConfig
) -> Optional[Tuple[List[int], List[int], List[float], List[float], List[float]]]:
    """
    Attempts to load a real priming sequence from the start of the dataset file.

    *** IMPORTANT USER ACTION REQUIRED: ***
    1. VERIFY the `data_path` points to your actual training data file (JSON Lines format expected).
    2. CHECK the keys used below match YOUR data structure.
       Based on your sample data, the keys should be:
       'event_ids', 'conditioning_chord_ids', 'conditioning_root_pc',
       'conditioning_quality_code', 'conditioning_function_code'
    3. UNCOMMENT the call to this function in the `__main__` block if you want to use priming.
    """
    print(f"(Attempting to load real priming sequence from: {data_path})") # Indicate it's for priming
    if not os.path.exists(data_path):
        logging.error(f"Priming data file not found: {data_path}")
        return None
    try:
        with open(data_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f): # Try first few lines
                if line_num > 5: break # Stop after a few attempts
                line = line.strip()
                if not line: continue

                try: data_item = json.loads(line)
                except json.JSONDecodeError as json_e: logging.warning(f"Skipping line {line_num+1} for priming (JSON decode): {json_e}"); continue

                # --- V V V --- KEYS BASED ON PROVIDED DATA --- V V V ---
                event_key = 'event_ids'
                chord_id_key = 'conditioning_chord_ids'
                root_pc_key = 'conditioning_root_pc'
                quality_key = 'conditioning_quality_code'
                function_key = 'conditioning_function_code'
                # --- ^ ^ ^ --- KEYS BASED ON PROVIDED DATA --- ^ ^ ^ ---

                event_data = data_item.get(event_key)
                chord_data = data_item.get(chord_id_key)
                root_pcs = data_item.get(root_pc_key)
                qualities = data_item.get(quality_key)
                functions = data_item.get(function_key)

                required_data = [event_data, chord_data, root_pcs, qualities, functions]
                if not all(d is not None for d in required_data):
                     missing_keys_str = [k for k,v in zip([event_key, chord_id_key, root_pc_key, quality_key, function_key], required_data) if v is None]
                     logging.warning(f"Missing required keys ({missing_keys_str}) in line {line_num+1} for priming. Check keys again. Trying next line."); continue

                if not isinstance(event_data, list) or len(event_data) < prime_len:
                     logging.warning(f"Sequence in line {line_num+1} is too short (< {prime_len}) for priming. Trying next line."); continue

                seq_len = len(event_data)
                if not (isinstance(chord_data, list) and len(chord_data) == seq_len and \
                        isinstance(root_pcs, list) and len(root_pcs) == seq_len and \
                        isinstance(qualities, list) and len(qualities) == seq_len and \
                        isinstance(functions, list) and len(functions) == seq_len):
                    logging.warning(f"Inconsistent types or lengths in line {line_num+1} for priming. Trying next line."); continue

                # --- Data Extraction & Conversion ---
                prime_cond_root = root_pcs[:prime_len]
                prime_cond_qual = qualities[:prime_len]
                prime_cond_func = functions[:prime_len]

                # Assume data is already IDs (integers) based on sample
                if event_data and isinstance(event_data[0], int):
                    prime_event_ids = event_data[:prime_len]
                else:
                    logging.error(f"Melody data in file (key '{event_key}') is not a list of integers."); continue

                if chord_data and isinstance(chord_data[0], int):
                    prime_cond_ids = chord_data[:prime_len]
                else:
                    logging.error(f"Chord data in file (key '{chord_id_key}') is not a list of integers."); continue

                # Final safety checks
                if max(prime_event_ids) >= config.melody_vocab_size: logging.error(f"Prime melody IDs > vocab size!"); continue
                if config.use_chord_embedding and max(prime_cond_ids) >= config.chord_vocab_size: logging.error(f"Prime chord IDs > vocab size!"); continue

                print(f"Successfully loaded real priming sequence of length {prime_len} from line {line_num+1}.")
                return (prime_event_ids, prime_cond_ids, prime_cond_root, prime_cond_qual, prime_cond_func)

    except Exception as e:
        logging.error(f"An unexpected error occurred while loading priming data: {e}", exc_info=True)
        return None

    logging.error(f"Could not find a suitable sequence for priming in the first few lines of {data_path}.")
    return None

# === Post-processing Function ===
def post_process_events(event_sequence: List[str], max_polyphony: int = 6, special_tokens: Set[str] = {MELODY_PAD_TOKEN, MELODY_UNK_TOKEN, "<EOS>", "<START>", "<END>"}) -> List[str]:
    # (This function remains the same as previous version)
    if max_polyphony <= 0: print("Warning: max_polyphony <= 0"); return event_sequence
    print(f"\n--- Starting Post-processing (Limiting Polyphony to {max_polyphony}) ---")
    processed_events = []; active_notes = set(); skipped_note_ons = 0
    for i, event_str in enumerate(tqdm(event_sequence, desc="Post-processing Events", leave=False, unit="event")):
        if not isinstance(event_str, str) or event_str in special_tokens: processed_events.append(event_str); continue
        try:
            parts = event_str.split('_'); event_type = "_".join(parts[:-1]).upper(); value = int(parts[-1])
            if event_type == "NOTE_ON":
                pitch = value
                if len(active_notes) >= max_polyphony and pitch not in active_notes: skipped_note_ons += 1; continue
                else: processed_events.append(event_str); active_notes.add(pitch)
            elif event_type == "NOTE_OFF": processed_events.append(event_str); active_notes.discard(value)
            else: processed_events.append(event_str)
        except: processed_events.append(event_str) # Keep event if parsing fails
    print(f"Post-processing finished. Skipped {skipped_note_ons} NOTE_ON events."); return processed_events

# === Main Generation Execution ===
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(module)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    generation_script_start_time = time.time()

    # --- Configuration & Setup ---
    # **********************************************************************
    # ** USER: PLEASE UPDATE THESE PATHS TO MATCH YOUR ENVIRONMENT **
    # **********************************************************************
    CHECKPOINT_PATH = "/kaggle/working/melody_model_output/best_model.pth"
    MELODY_VOCAB_PATH = "/kaggle/input/new-melody-model-new-approach-1/event_vocab.json"
    CHORD_DATA_DIR = "/kaggle/input/advance-h-rpe"
    # MELODY_DATA_PATH only needed if using load_real_prime_sequence below
    MELODY_DATA_PATH = "/kaggle/input/new-melody-model-new-approach-1/training_data.jsonl"
    OUTPUT_DIR = "/kaggle/working/generated_output"
    # **********************************************************************

    CHORD_VOCAB_FILENAME = "chord_progression_vocab.json"
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # --- Generation Parameters ---
    TEMPERATURE = 0.75
    TOP_K = 40
    TOP_P = 0.9
    MAX_POLYPHONY = 6           # Max simultaneous notes allowed by post-processing filter
    START_MELODY_TOKEN = "NOTE_ON_60" # Default: Middle C (MIDI note 60)

    # --- Device Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device} ({torch.cuda.get_device_name(0) if device.type=='cuda' else 'CPU'})")

    # --- Load Vocabularies ---
    print("\n--- Loading Vocabularies ---")
    try:
        with open(MELODY_VOCAB_PATH, 'r', encoding='utf-8') as f: melody_vocab = json.load(f); melody_inv_vocab = {v:k for k,v in melody_vocab.items()}
        chord_vocab_path = Path(CHORD_DATA_DIR) / CHORD_VOCAB_FILENAME;
        with open(chord_vocab_path, 'r', encoding='utf-8') as f: chord_vocab = json.load(f); chord_inv_vocab = {v:k for k,v in chord_vocab.items()} # Useful for debugging chord IDs
        MELODY_PAD_ID = melody_vocab.get(MELODY_PAD_TOKEN, 0); CHORD_PAD_ID = chord_vocab.get("<PAD>", 0)
        MELODY_EOS_TOKEN = "<EOS>"; MELODY_EOS_ID = melody_vocab.get(MELODY_EOS_TOKEN, -1)
        SPECIAL_TOKENS_SET = {MELODY_PAD_TOKEN, MELODY_UNK_TOKEN, MELODY_EOS_TOKEN, "<START>", "<END>"}
        print(f"Melody Vocab Size: {len(melody_vocab)}, Chord Vocab Size: {len(chord_vocab)}")
        print(f"Melody Pad ID: {MELODY_PAD_ID}, Chord Pad ID: {CHORD_PAD_ID}, EOS ID: {MELODY_EOS_ID if MELODY_EOS_ID!=-1 else 'N/A'}")
    except Exception as e: print(f"FATAL: Vocab loading error: {e}"); traceback.print_exc(); sys.exit(1)

    # --- Load Model and Configuration ---
    print(f"\n--- Loading Model Checkpoint ---")
    try:
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
        loaded_config_dict = checkpoint.get('config'); assert isinstance(loaded_config_dict, dict), "Config missing/invalid in checkpoint"
        config = TrainingConfig() # Create default config
        valid_keys = {f.name for f in fields(config)}
        for k, v in loaded_config_dict.items():
            if k in valid_keys:
                if k=='amp_dtype' and isinstance(v,str): config.amp_dtype={'float16':torch.float16,'bfloat16':torch.bfloat16}.get(v.split('.')[-1],None)
                else: setattr(config, k, v)
        config.melody_vocab_size=loaded_config_dict.get('melody_vocab_size', len(melody_vocab))
        config.chord_vocab_size=loaded_config_dict.get('chord_vocab_size', len(chord_vocab))
        config.melody_pad_token_id=loaded_config_dict.get('melody_pad_token_id', MELODY_PAD_ID)
        config.chord_pad_token_id=loaded_config_dict.get('chord_pad_token_id', CHORD_PAD_ID)
        if config.melody_pad_token_id != MELODY_PAD_ID or config.chord_pad_token_id != CHORD_PAD_ID: logging.warning("Pad ID mismatch! Checkpoint vs Vocab file. Using Checkpoint IDs.")
        MELODY_PAD_ID=config.melody_pad_token_id; CHORD_PAD_ID=config.chord_pad_token_id
        print("\n--- Final Configuration Used for Model ---"); [print(f"  {k}: {getattr(config, k)}") for k in ['melody_vocab_size','chord_vocab_size','melody_pad_token_id','chord_pad_token_id','n_layer','d_model','n_head','d_head','mem_len','amp_dtype']]

        model = MelodyTransformerXL(config).to(device) # Instantiate model AFTER setting config
        state_dict = checkpoint['model_state_dict']; assert state_dict, "model_state_dict missing"
        is_parallel_ckpt = all(key.startswith('module.') for key in state_dict)
        if is_parallel_ckpt:
            logging.info("Detected DataParallel/DDP checkpoint. Removing 'module.' prefix.")
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) # Use loaded state_dict
        if missing_keys: logging.warning(f"Loaded model state dict was missing keys: {missing_keys}")
        if unexpected_keys: logging.warning(f"Loaded model state dict had unexpected keys: {unexpected_keys}")
        print(f"\nModel loaded successfully.")
        model.eval(); model_dtype = model.dtype; print(f"Model parameter dtype: {model_dtype}")
    except Exception as e: print(f"FATAL: Model loading error: {e}"); traceback.print_exc(); sys.exit(1)

    # --- Prepare FULL Chord Progression & Conditioning ---
    print(f"\n--- Preparing Full Chord Progression ---")
    # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    # !! USER ACTION REQUIRED: Load or Define Your Chord Progression Here     !!
    # !! Provide the *full* sequences for chord IDs and features below.       !!
    # !! The length of these lists determines the length of the generated     !!
    # !! melody. Ensure they are properly aligned temporally.                 !!
    # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    # *** REPLACE THIS EXAMPLE WITH YOUR ACTUAL PROGRESSION DATA ***
    # This example creates a 512-token long progression: Cmaj7 | Fmaj7 | Dmin7 | G7, repeated.
    # You'll need to look up the correct IDs from your chord_vocab.json for your desired chords.
    # Also, provide the correct normalized root_pc (0-11 -> 0.0-11.0/12.0), quality, and function codes.
    # The `tokens_per_chord` determines how many melody events are generated for each chord. Adjust this based on tempo and desired rhythm density.

    # --- START OF EXAMPLE DATA (TO BE REPLACED BY USER) ---
    try:
        # Example Chord IDs (Use your actual IDs from chord_vocab.json)
        example_chord_map = {
            "Cmaj7": {"id": chord_vocab.get("C:maj7", 0), "root": 0/12, "qual": 1.0, "func": 1.0},
            "Fmaj7": {"id": chord_vocab.get("F:maj7", 0), "root": 5/12, "qual": 1.0, "func": 4.0},
            "Dm7":   {"id": chord_vocab.get("D:min7", 0), "root": 2/12, "qual": 0.0, "func": 2.0},
            "G7":    {"id": chord_vocab.get("G:7", 0),    "root": 7/12, "qual": 8.0, "func": 5.0}, # Check if 8 is correct for dominant 7th in your system
            "PAD":   {"id": CHORD_PAD_ID,                 "root": 0.0, "qual": 9.0, "func": 10.0} # Ensure pad features align with training
        }

        # Define the progression using the names above
        progression_sequence = ["Cmaj7", "Fmaj7", "Dm7", "G7"] * 8 # 32 chords total

        tokens_per_chord = 16 # How many event tokens correspond to one chord change? ADJUST AS NEEDED!
        prog_len = len(progression_sequence) * tokens_per_chord # Total tokens for conditioning

        full_chord_ids_list = []
        full_root_pc_list = []
        full_quality_code_list = []
        full_function_code_list = []

        valid_progression = True
        for chord_name in progression_sequence:
            if chord_name in example_chord_map:
                chord_info = example_chord_map[chord_name]
                full_chord_ids_list.extend([chord_info["id"]] * tokens_per_chord)
                full_root_pc_list.extend([chord_info["root"]] * tokens_per_chord)
                full_quality_code_list.extend([chord_info["qual"]] * tokens_per_chord)
                full_function_code_list.extend([chord_info["func"]] * tokens_per_chord)
            else:
                logging.error(f"Chord '{chord_name}' not found in example_chord_map. Please define it or check spelling. Using PAD.")
                pad_info = example_chord_map["PAD"]
                full_chord_ids_list.extend([pad_info["id"]] * tokens_per_chord)
                full_root_pc_list.extend([pad_info["root"]] * tokens_per_chord)
                full_quality_code_list.extend([pad_info["qual"]] * tokens_per_chord)
                full_function_code_list.extend([pad_info["func"]] * tokens_per_chord)
                # If a chord is missing, you might want to stop: valid_progression = False; break

        if not valid_progression:
            print("FATAL: Invalid chord progression defined. Check logs.")
            sys.exit(1)

    except Exception as e:
        print(f"FATAL: Error defining or processing the chord progression: {e}")
        traceback.print_exc()
        sys.exit(1)
    # --- END OF EXAMPLE/USER DATA SECTION ---


    # --- Check Data Length ---
    if prog_len == 0:
        print("FATAL: Chord progression data is empty. Please define or load it.")
        sys.exit(1)
    print(f"Using chord progression of length: {prog_len} tokens")


    # --- Define Start Melody Token ---
    start_token_name = "NOTE_ON_60" # Default: Middle C
    start_event_id = melody_vocab.get(start_token_name)
    if start_event_id is None:
        start_event_id = melody_vocab.get("<START>", melody_vocab.get("NOTE_ON_60", 138)) # Try <START> or fallback
        logging.warning(f"Start token '{start_token_name}' not found, using fallback ID {start_event_id} ('{melody_inv_vocab.get(start_event_id, 'UNK')}')")
    print(f"Using start melody token: '{melody_inv_vocab.get(start_event_id, 'UNK')}' (ID: {start_event_id})")

    # --- Convert Progression to Tensors ---
    try:
        full_input_cond_ids = torch.tensor([full_chord_ids_list], dtype=torch.long, device=device)
        full_input_cond_root = torch.tensor([full_root_pc_list], dtype=model_dtype, device=device)
        full_input_cond_qual = torch.tensor([full_quality_code_list], dtype=model_dtype, device=device)
        full_input_cond_func = torch.tensor([full_function_code_list], dtype=model_dtype, device=device)
    except Exception as e: print(f"FATAL: Tensor conversion error: {e}"); traceback.print_exc(); sys.exit(1)


    # --- Run Generation ---
    print(f"\n--- Starting Melody Generation Over Progression ---")
    print(f"Progression length: {prog_len}")
    print(f"Temperature: {TEMPERATURE}, Top-K: {TOP_K}, Top-P: {TOP_P}")
    start_gen_time = time.time()

    try:
        # Call the specific generation function
        generated_sequence_ids = generate_melody_over_progression(
            model=model,
            config=config,
            device=device,
            start_event_id=start_event_id,
            full_cond_ids=full_input_cond_ids,
            full_cond_root=full_input_cond_root,
            full_cond_qual=full_input_cond_qual,
            full_cond_func=full_input_cond_func,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            eos_token_id=MELODY_EOS_ID if MELODY_EOS_ID != -1 else None,
            pad_token_id=MELODY_PAD_ID
        )
    except Exception as e: print(f"FATAL: Generation error: {e}"); traceback.print_exc(); sys.exit(1)

    gen_duration = time.time() - start_gen_time
    print(f"\nGeneration Complete! Took {gen_duration:.2f} seconds.")
    print(f"Total sequence length generated: {len(generated_sequence_ids)}") # Should match prog_len

    # --- Decode Sequence ---
    print("\n--- Decoding Generated Sequence ---")
    try:
        decoded_sequence_raw = [melody_inv_vocab.get(idx, f"<UNK_ID_{idx}>") for idx in generated_sequence_ids]
        print(f"First 100 RAW decoded tokens (including start token):")
        print(decoded_sequence_raw[:100])
    except Exception as e: print(f"Decoding error: {e}"); decoded_sequence_raw = []

    # Apply Post-processing
    if decoded_sequence_raw:
        try:
            processed_sequence = post_process_events(
                decoded_sequence_raw,
                max_polyphony=MAX_POLYPHONY,
                special_tokens=SPECIAL_TOKENS_SET
            )
            print(f"Original generated length: {len(decoded_sequence_raw)}, Processed length: {len(processed_sequence)}")
        except Exception as e: print(f"Post-processing error: {e}. Using raw sequence."); traceback.print_exc(); processed_sequence = decoded_sequence_raw
    else: processed_sequence = []

    # --- Save Outputs ---
    print("\n--- Saving Outputs ---")
    # Determine next available Melody_N filename
    next_run_number = 1
    try:
        if os.path.exists(OUTPUT_DIR):
            filename_pattern = re.compile(r"^Melody_(\d+)\.(json|mid)$", re.IGNORECASE)
            max_num = 0
            for filename in os.listdir(OUTPUT_DIR):
                match = filename_pattern.match(filename)
                if match: max_num = max(max_num, int(match.group(1)))
            next_run_number = max_num + 1
        else: logging.info(f"Output directory {OUTPUT_DIR} not found. Creating and starting count at 1.")
        os.makedirs(OUTPUT_DIR, exist_ok=True) # Ensure directory exists just before saving
    except Exception as e: logging.error(f"Error finding next run number: {e}. Defaulting to 1."); next_run_number = 1

    gen_filename_base = f"Melody_{next_run_number}"
    output_json_path = os.path.join(OUTPUT_DIR, f"{gen_filename_base}.json")
    output_midi_path = os.path.join(OUTPUT_DIR, f"{gen_filename_base}.mid")
    print(f"Output base name: {gen_filename_base}")

    try:
        config_dict_serializable = config.as_dict()
        output_data = {
            "generation_params": {
                "generation_mode": "Over Progression",
                "output_filename_base": gen_filename_base,
                "checkpoint_path": str(CHECKPOINT_PATH),
                "progression_length": prog_len,
                "start_event_id": start_event_id,
                "temperature": TEMPERATURE,"top_k": TOP_K,"top_p": TOP_P,
                "max_polyphony_filter": MAX_POLYPHONY,
                "eos_token_id": MELODY_EOS_ID,
                "device": str(device),
                "generation_duration_sec": gen_duration,
                "generation_timestamp": datetime.datetime.now().isoformat(),
            },
            "config_used": config_dict_serializable,
            # Optionally save conditioning IDs (can be large)
            # "conditioning_ids": full_chord_ids_list,
            "generated_ids_full_raw": generated_sequence_ids,
            "generated_tokens_full_raw": decoded_sequence_raw,
            "generated_tokens_full_processed": processed_sequence,
        }
        with open(output_json_path, 'w', encoding='utf-8') as f: json.dump(output_data, f, indent=2)
        print(f"Generated sequence data (JSON) saved to: {output_json_path}")
    except Exception as e: print(f"\nERROR saving generated sequence JSON: {e}"); traceback.print_exc()

    # --- Convert to MIDI ---
    if processed_sequence:
        print("\n--- Converting PROCESSED Sequence to MIDI ---")
        try: events_to_midi(processed_sequence, output_midi_path)
        except Exception as e: print(f"\nERROR during MIDI conversion: {e}"); traceback.print_exc()
    else: print("\nSkipping MIDI conversion due to issues in sequence decoding or processing.")

    # --- Script Finish ---
    print("\nScript finished.")
    total_script_runtime = time.time() - generation_script_start_time
    print(f"Total Generation Script Runtime: {total_script_runtime:.2f} seconds")

Using device: cuda (Tesla P100-PCIE-16GB)

--- Loading Vocabularies ---
Melody Vocab Size: 306, Chord Vocab Size: 44733
Melody Pad ID: 0, Chord Pad ID: 0, EOS ID: N/A

--- Loading Model Checkpoint ---

--- Final Configuration Used for Model ---
  melody_vocab_size: 306
  chord_vocab_size: 44734
  melody_pad_token_id: 0
  chord_pad_token_id: 0
  n_layer: 8
  d_model: 512
  n_head: 8
  d_head: 64
  mem_len: 256
  amp_dtype: torch.float16

Model loaded successfully.
Model parameter dtype: torch.float32

--- Preparing Full Chord Progression ---
Using chord progression of length: 512 tokens
Using start melody token: 'NOTE_ON_60' (ID: 138)

--- Starting Melody Generation Over Progression ---
Progression length: 512
Temperature: 0.75, Top-K: 40, Top-P: 0.9
Using AMP for generation with dtype: torch.float16
Starting generation for 511 steps over the provided progression (Total output length: 512)...


Generating Melody: 100%|██████████| 511/511 [00:06<00:00, 84.37token/s]



Generation Complete! Took 6.06 seconds.
Total sequence length generated: 512

--- Decoding Generated Sequence ---
First 100 RAW decoded tokens (including start token):
['NOTE_ON_60', 'NOTE_ON_69', 'NOTE_ON_81', 'TIME_SHIFT_2', 'NOTE_ON_67', 'TIME_SHIFT_8', 'NOTE_ON_64', 'NOTE_ON_76', 'TIME_SHIFT_2', 'NOTE_ON_64', 'NOTE_ON_64', 'TIME_SHIFT_2', 'NOTE_ON_71', 'NOTE_OFF_71', 'TIME_SHIFT_8', 'NOTE_ON_63', 'NOTE_ON_68', 'NOTE_ON_68', 'NOTE_OFF_73', 'TIME_SHIFT_4', 'NOTE_ON_43', 'NOTE_OFF_43', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TIME_SHIFT_128', 'TI

                                                                  

Post-processing finished. Skipped 5 NOTE_ON events.
Original generated length: 512, Processed length: 507

--- Saving Outputs ---
Output base name: Melody_1
Generated sequence data (JSON) saved to: /kaggle/working/generated_output/Melody_1.json

--- Converting PROCESSED Sequence to MIDI ---
Starting MIDI conversion for 507 events...


                                                             

------------------------------
MIDI Conversion Summary:
  MIDI file successfully written to: /kaggle/working/generated_output/Melody_1.mid
  Notes created: 7 (Added: 7)
  Total problematic events skipped/handled: 1
    - Zero/Negative duration notes skipped: 1 (During seq: 1, At end: 0)
------------------------------

Script finished.
Total Generation Script Runtime: 6.97 seconds


