**1. Identifying the Target and Loss Function**

*   **Main Output (`output`):** The `SpectraEncoder`'s final output comes from `spectra_predictor`, which ends with `nn.Linear(hidden_size, output_size)` followed by `nn.Sigmoid()`. The `output_size` (e.g., 4096) and the Sigmoid activation strongly suggest the model predicts a **binary fingerprint** representation of the molecule corresponding to the input spectrum. The Sigmoid squashes the output between 0 and 1, suitable for representing the probability of each bit being 'on'.
*   **Target for Main Output:** The target should be the actual molecular fingerprint (e.g., Morgan fingerprint, RDKit fingerprint) derived from the molecule's SMILES string (like `CC(=O)NC@@HC2=CC(=CC(=O)O2)OC` for `MassSpecGymID0000001`). You'll need a function (likely using RDKit) to convert SMILES strings into fixed-size binary vectors of length `output_size`.
*   **Loss for Main Output:** Given the Sigmoid output and binary target, the standard loss function is **Binary Cross-Entropy (BCE)**. You can use `torch.nn.BCELoss`.
*   **Auxiliary Output (`pred_frag_fps`):** The `fragment_predictor` outputs a tensor whose last dimension is `magma_modulo` (e.g., 2048), holding the predicted MAGMA fingerprints for the spectral peaks. Unlike `spectra_predictor`, it ends in a plain `nn.Linear` with no Sigmoid, so these values are raw logits rather than probabilities.
*   **Target for Auxiliary Output:** The target is the `magma_fps` tensor generated by the `SpectrumProcessor`. Note that this tensor contains 0s, 1s, and -1s (where -1 indicates missing data).
*   **Loss for Auxiliary Output:** You need a BCE-like loss that can handle the masked values (-1). A common approach is to compute BCE only over entries where the target is not -1, often called a **Masked BCE Loss**. Because the predictions are logits, the base loss should be `torch.nn.BCEWithLogitsLoss` (or an explicit Sigmoid followed by `BCELoss`). If `magma_aux_loss` in your `SpectrumProcessor` is `True`, this loss should be calculated and added (potentially with a weighting factor) to the main loss.
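
To make the two loss terms concrete, here is a minimal sketch on toy tensors. It assumes the main head returns probabilities (post-Sigmoid) and the auxiliary head returns logits, as in the `SpectraEncoder` code. The helper name `masked_bce_with_logits` and the weight `aux_weight = 0.2` are illustrative choices, not taken from the MIST codebase.

```python
import torch
import torch.nn.functional as F


def masked_bce_with_logits(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """BCE over entries whose target is 0 or 1; entries equal to -1 are skipped."""
    mask = target != -1
    if not mask.any():
        # Nothing to supervise: return a zero that is still attached to the graph.
        return logits.sum() * 0.0
    return F.binary_cross_entropy_with_logits(
        logits[mask], target[mask].to(logits.dtype)
    )


# Toy main head: probabilities for a 4-bit fingerprint, batch of 2.
main_probs = torch.tensor([[0.9, 0.1, 0.8, 0.2], [0.3, 0.7, 0.6, 0.4]])
main_target = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]])
main_loss = F.binary_cross_entropy(main_probs, main_target)

# Toy auxiliary head: logits, with some targets marked as missing (-1).
aux_logits = torch.tensor([[2.0, -1.0, 0.5], [0.0, 3.0, -2.0]])
aux_target = torch.tensor([[1.0, -1.0, 0.0], [-1.0, 1.0, -1.0]])
aux_loss = masked_bce_with_logits(aux_logits, aux_target)

aux_weight = 0.2  # Hypothetical weighting factor; tune for your data.
total_loss = main_loss + aux_weight * aux_loss
print(f"main={main_loss.item():.4f} aux={aux_loss.item():.4f} total={total_loss.item():.4f}")
```

Only the three auxiliary entries with a target of 0 or 1 contribute to `aux_loss`; the three entries marked -1 are dropped before the mean is taken.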

**2. Searching the Project (Conceptual)**

Based on standard practices and the code provided:

*   **Loss Functions:** Look for files like `losses.py`, `train.py`, or `main.py`. You'd expect to find instantiations of `nn.BCELoss` and potentially a custom function for the masked BCE loss for the MAGMA fingerprints.
*   **Metrics:** Evaluation often uses fingerprint similarity metrics like the **Tanimoto coefficient (Jaccard index)**. Look for functions calculating this in training/validation loops or utility files.
*   **Data Loading:** Files like `dataset.py` or `data_loader.py` would define a PyTorch `Dataset` class. This class would:
    *   Read your input data (e.g., the TSV line).
    *   Parse the SMILES string and spectral data.
    *   Use the `SpectrumProcessor` to process the spectrum.
    *   Use a cheminformatics library (like RDKit) to generate the target fingerprint from the SMILES string.
    *   Return a dictionary containing processed spectrum features (like `spec_features`) and the target fingerprint.
*   **Training Loop:** `train.py` or `main.py` would contain the core logic: iterating through epochs and batches, performing forward/backward passes, calculating combined loss, updating weights, and logging metrics.
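
As an illustration of the metric mentioned above, the following sketch computes the Tanimoto coefficient between a predicted fingerprint and its target. The function name `tanimoto`, the 0.5 threshold, and the convention of returning 1.0 when both fingerprints are empty are assumptions made for this example; check how your project's evaluation code defines them.

```python
import numpy as np


def tanimoto(pred_probs: np.ndarray, target_fp: np.ndarray, threshold: float = 0.5) -> float:
    """Tanimoto (Jaccard) similarity between a thresholded prediction and a binary target."""
    pred_bits = pred_probs >= threshold
    target_bits = target_fp >= 0.5
    intersection = np.logical_and(pred_bits, target_bits).sum()
    union = np.logical_or(pred_bits, target_bits).sum()
    # Two all-zero fingerprints have an empty union; treat them as identical here.
    return 1.0 if union == 0 else float(intersection / union)


pred = np.array([0.9, 0.2, 0.7, 0.1, 0.6])    # Sigmoid outputs for 5 bits
target = np.array([1.0, 0.0, 0.0, 0.0, 1.0])  # Ground-truth bits
print(tanimoto(pred, target))
```

Here the thresholded prediction sets bits 0, 2, and 4, while the target sets bits 0 and 4, so the intersection has 2 bits and the union has 3.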
 
**Key Implementation Points:**

1.  **`smiles_to_fingerprint`:** Implement this using RDKit or another cheminformatics library to match the `output_size`.
2.  **`SpectraDataset`:** Adapt the data loading (`pd.read_csv`) to your specific input file format. Ensure it correctly extracts SMILES and the raw spectral JSON/dict.
3.  **`collate_fn`:** This is crucial and complex. It needs to correctly pad sequences of varying lengths (like `form_vec`, `peak_type`) and create an attention mask that the `FormulaTransformer` likely requires. The implementation provided is a basic example and might need significant adjustments based on how `FormulaTransformer` handles padded input and what padding values are appropriate.
4.  **`FormulaTransformer`:** The `DummyFormulaTransformer` needs to be replaced with the actual implementation from `models.modules`. The forward pass within `SpectraEncoder` might need adjustments based on the exact output shapes and meaning of the `FormulaTransformer`'s output (`encoder_output`, `aux_out["peak_tensor"]`). Pay close attention to whether outputs are per-peak or global ([CLS] token style).
5.  **MAGMA Auxiliary Loss:** The comparison between `pred_frag_fps` and `target_magma_fps` needs careful handling. Ensure their shapes align and represent the same thing (e.g., both per-peak fingerprints). If `pred_frag_fps` is a single vector per spectrum, you might need to aggregate `target_magma_fps` or modify the model architecture/loss.
6.  **Paths:** Replace placeholder paths like `"path/to/your/..."` with your actual file locations.
7.  **Dependencies:** Ensure you have `torch`, `numpy`, `pandas`, and `rdkit` installed (`rdkit-pypi` is the older PyPI name for the same package).
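
For point 5, if `pred_frag_fps` is a single vector per spectrum, one possible way to align the target is to collapse the per-peak `magma_fps` array before batching. The sketch below is only one option, under the assumption that each row is either a full 0/1 fingerprint or entirely -1; the function name `aggregate_magma_targets` and the "any peak sets the bit" rule are hypothetical and not taken from MIST.

```python
import numpy as np


def aggregate_magma_targets(magma_fps: np.ndarray) -> np.ndarray:
    """Collapse per-peak targets [num_peaks, n_bits] into one vector [n_bits].

    A bit is 1 if any labelled peak sets it, 0 if labelled peaks exist but none
    set it, and -1 everywhere if no peak in the spectrum carries a label.
    """
    n_bits = magma_fps.shape[1]
    labelled = (magma_fps != -1).all(axis=1)  # Rows that hold a real fingerprint
    if not labelled.any():
        return np.full(n_bits, -1.0, dtype=np.float32)
    return magma_fps[labelled].max(axis=0).astype(np.float32)


per_peak = np.array([
    [1.0, 0.0, 0.0, 0.0],      # Labelled fragment peak
    [0.0, 0.0, 1.0, 0.0],      # Labelled fragment peak
    [-1.0, -1.0, -1.0, -1.0],  # Unlabelled row (e.g., the CLS token)
])
print(aggregate_magma_targets(per_peak))
```

The alternative is to keep the targets per peak and feed `aux_out["peak_tensor"]` to `fragment_predictor` instead, which avoids discarding peak-level information.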

This pipeline provides a structural foundation. You'll need to integrate your specific data, the actual `FormulaTransformer` module, and potentially refine the collation and loss calculations based on the MIST paper's details or further code exploration.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd # Assuming your input is like the TSV line
from rdkit import Chem
from rdkit.Chem import AllChem
import json
from pathlib import Path

In [None]:
# --- Assume these are defined in your project ---
# from models.mist_encoder import SpectraEncoder # Or SpectraEncoderGrowing
# from data.spectrum_processor import SpectrumProcessor
# from utils import collate_fn # A function to batch your processed data
# --- Using the classes provided in the prompt ---
from typing import Tuple
from models import modules # Assuming modules.py exists as per your encoder code

class SpectraEncoder(nn.Module):
    """SpectraEncoder."""
    def __init__(
        self,
        form_embedder: str = "float",
        output_size: int = 4096,
        hidden_size: int = 50,
        spectra_dropout: float = 0.0,
        top_layers: int = 1,
        refine_layers: int = 0,
        magma_modulo: int = 2048,
        **kwargs,
    ):
        super(SpectraEncoder, self).__init__()
        # --- Using dummy FormulaTransformer for demonstration ---
        # Replace with your actual modules.FormulaTransformer
        class DummyFormulaTransformer(nn.Module):
            def __init__(self, hidden_size, **kwargs):
                super().__init__()
                self.dummy_layer = nn.Linear(100, hidden_size) # Input size is arbitrary placeholder
                self.hidden_size = hidden_size
            def forward(self, batch, return_aux=False):
                # Dummy forward: needs actual implementation based on FormulaTransformer
                # This needs to process batch['form_vec'], batch['peak_type'], etc.
                # Returning dummy tensors of expected shapes
                batch_size = batch['form_vec'].shape[0] # Assuming batching adds a dimension
                num_peaks = batch['form_vec'].shape[1]
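                device = batch['form_vec'].device # Keep dummy outputs on the same device as the inputs
                dummy_encoder_output = torch.randn(batch_size, self.hidden_size, device=device)
                dummy_peak_tensor = torch.randn(batch_size, num_peaks, self.hidden_size, device=device)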
                aux_out = {"peak_tensor": dummy_peak_tensor}
                if return_aux:
                    return dummy_encoder_output, aux_out
                else:
                    return dummy_encoder_output

        spectra_encoder_main = DummyFormulaTransformer( # Replace with actual modules.FormulaTransformer
            hidden_size=hidden_size,
            spectra_dropout=spectra_dropout,
            form_embedder=form_embedder,
            **kwargs,
        )
        # ... rest of the __init__ code from your prompt ...
        fragment_pred_parts = []
        for _ in range(top_layers - 1):
            fragment_pred_parts.append(nn.Linear(hidden_size, hidden_size))
            fragment_pred_parts.append(nn.ReLU())
            fragment_pred_parts.append(nn.Dropout(spectra_dropout))
        fragment_pred_parts.append(nn.Linear(hidden_size, magma_modulo))
        fragment_predictor = nn.Sequential(*fragment_pred_parts)

        top_layer_parts = []
        for _ in range(top_layers - 1):
            top_layer_parts.append(nn.Linear(hidden_size, hidden_size))
            top_layer_parts.append(nn.ReLU())
            top_layer_parts.append(nn.Dropout(spectra_dropout))
        top_layer_parts.append(nn.Linear(hidden_size, output_size))
        top_layer_parts.append(nn.Sigmoid())
        spectra_predictor = nn.Sequential(*top_layer_parts)
        self.spectra_encoder = nn.ModuleList([spectra_encoder_main, fragment_predictor, spectra_predictor])

    def forward(self, batch: dict) -> Tuple[torch.Tensor, dict]:
        """Forward pass."""
        # Assuming batch contains tensors now, not numpy arrays
        encoder_output, aux_out = self.spectra_encoder[0](batch, return_aux=True)

        # Need to handle the fact that peak_tensor might vary in num_peaks per batch item
        # Often, the transformer output corresponding to the [CLS] token is used,
        # or pooling is applied over peak_tensor. Adjusting based on common practice.
        # Let's assume encoder_output is the [CLS] token embedding [batch_size, hidden_size]
        # And aux_out["peak_tensor"] is [batch_size, num_peaks, hidden_size]

        # This prediction might need adjustment depending on how FormulaTransformer handles output
        # If fragment prediction is per peak:
        # pred_frag_fps = self.spectra_encoder[1](aux_out["peak_tensor"]) # Shape: [batch_size, num_peaks, magma_modulo]
        # If fragment prediction is based on global context (e.g., CLS token):
        pred_frag_fps = self.spectra_encoder[1](encoder_output) # Shape: [batch_size, magma_modulo]
        # ^^^ Choose the correct input for fragment_predictor based on MIST's design

        aux_outputs = {"pred_frag_fps": pred_frag_fps}

        output = self.spectra_encoder[2](encoder_output) # Shape: [batch_size, output_size]
        aux_outputs["h0"] = encoder_output

        return output, aux_outputs

class SpectrumProcessor:
    """Process raw spectral data directly without file reading."""
    # --- Using the class provided in the prompt ---
    # Note: Removed file I/O for MAGMA for simplicity in this example skeleton
    # You'll need to adapt _process_magma_file if using it.
    # Relies on the module-level imports of Path, np, json, and pd above.

    cat_types = {"frags": 0, "loss": 1, "ab_loss": 2, "cls": 3}
    num_inten_bins = 10
    num_types = len(cat_types)
    cls_type_map = cat_types # Renamed cls_type to avoid conflict

    def __init__(
        self,
        augment_data: bool = False,
        augment_prob: float = 1,
        remove_prob: float = 0.1,
        remove_weights: str = "uniform",
        inten_prob: float = 0.1,
        cls_token_type: str = "ms1", # Renamed cls_type param
        max_peaks: int = None,
        inten_transform: str = "float",
        magma_modulo: int = 512,
        magma_aux_loss: bool = False,
        magma_folder: str = None, # Keep this if you implement MAGMA loading
    ):
        self.cls_token_type = cls_token_type # Use renamed param
        self.augment_data = augment_data
        self.remove_prob = remove_prob
        self.augment_prob = augment_prob
        self.remove_weights = remove_weights
        self.inten_prob = inten_prob
        self.max_peaks = max_peaks
        self.inten_transform = inten_transform
        self.aug_nbits = magma_modulo
        self.magma_aux_loss = magma_aux_loss
        # self.magma_folder = Path(magma_folder) if magma_folder else None
        # self.spec_name_to_magma_file = {}
        # if self.magma_aux_loss and self.magma_folder:
        #     self._initialize_magma_mapping()

    # --- Add dummy utils for demonstration ---
    class DummyUtils:
        VALID_MONO_MASSES = np.random.rand(18) # Placeholder
        ELEM_TO_IDX = {'C': 0, 'H': 1, 'N': 2, 'O': 3, 'P': 4, 'S': 5, 'F': 6, 'Cl': 7, 'Br': 8, 'I': 9, 'Si': 10, 'Se': 11, 'B': 12, 'Na': 13, 'K': 14, 'Mg': 15, 'Ca': 16, 'Fe': 17} # Example
        NUM_ELEM = 18 # Example
        ION_LST = ['[M+H]+', '[M+Na]+', '[M+K]+', '[M-H]-', '[M+Cl]-'] # Example
        ION_TO_IDX = {ion: i for i, ion in enumerate(ION_LST)} # Example

        # Classmethods are used because DummyUtils is nested inside SpectrumProcessor,
        # so the bare name `DummyUtils` is not resolvable from inside these methods.
        @classmethod
        def formula_to_dense(cls, formula_str):
            # Dummy implementation for flat formulas such as "C6H12O6" - replace with actual formula parsing
            import re
            vec = np.zeros(cls.NUM_ELEM)
            if isinstance(formula_str, str):
                for elem, count in re.findall(r"([A-Z][a-z]?)(\d*)", formula_str):
                    if elem in cls.ELEM_TO_IDX: # Unknown element symbols are skipped
                        vec[cls.ELEM_TO_IDX[elem]] += int(count) if count else 1
            return vec

        @classmethod
        def get_ion_idx(cls, ion_str):
            return cls.ION_TO_IDX.get(ion_str, 0) # Default to first ion if not found

    utils = DummyUtils() # Instantiate dummy utils
    # --- End dummy utils ---


    def process_raw_spectrum(self, raw_spectral_json, spec_id=None, train_mode=False):
        if isinstance(raw_spectral_json, str):
            tree = json.loads(raw_spectral_json)
        else:
            tree = raw_spectral_json
        if spec_id is None and "name" in tree:
            spec_id = tree["name"]
        peak_dict = self._get_peak_dict_from_raw(tree)
        if train_mode and self.augment_data:
            augment_peak = np.random.random() < self.augment_prob
            if augment_peak:
                peak_dict = self.augment_peak_dict(peak_dict)
        features = self._generate_features(peak_dict, spec_id)
        return features

    def _get_peak_dict_from_raw(self, tree: dict) -> dict:
        root_form = tree["cand_form"]
        root_ion = tree["cand_ion"]
        # Handle cases where output_tbl might be missing or empty
        output_tbl = tree.get("output_tbl")
        if output_tbl is None or not output_tbl.get("formula"):
             frags, intens, ions = [], [], []
        else:
             frags = output_tbl.get("formula", [])
             intens = output_tbl.get("ms2_inten", [])
             ions = output_tbl.get("ions", [])

        # Ensure all lists have the same length if extracted
        min_len = min(len(frags), len(intens), len(ions))
        frags, intens, ions = frags[:min_len], intens[:min_len], ions[:min_len]

        out_dict = {
            "frags": frags, "intens": intens, "ions": ions,
            "root_form": root_form, "root_ion": root_ion,
        }

        if self.max_peaks is not None and len(intens) > 0:
            inten_list = list(out_dict["intens"])
            new_order = np.argsort(inten_list)[::-1]
            cutoff_ind = min(len(inten_list), self.max_peaks)
            new_inds = new_order[:cutoff_ind]
            out_dict["intens"] = np.array(inten_list)[new_inds].tolist()
            out_dict["frags"] = np.array(out_dict["frags"])[new_inds].tolist()
            out_dict["ions"] = np.array(out_dict["ions"])[new_inds].tolist()
        return out_dict

    def augment_peak_dict(self, peak_dict: dict):
        frags = np.array(peak_dict["frags"])
        intens = np.array(peak_dict["intens"])
        ions = np.array(peak_dict["ions"])
        if len(frags) == 0: return peak_dict

        num_modify_peaks = len(frags)
        keep_prob = 1 - self.remove_prob
        num_to_keep = np.random.binomial(n=num_modify_peaks, p=keep_prob)
        keep_inds = np.arange(num_modify_peaks)

        if self.remove_weights == "quadratic":
            keep_probs = intens.reshape(-1) ** 2 + 1e-9
        elif self.remove_weights == "uniform":
            keep_probs = np.ones(len(intens))
        elif self.remove_weights == "exp":
            keep_probs = np.exp(intens.reshape(-1) + 1e-5)
        else: raise NotImplementedError()
        keep_probs = keep_probs / (keep_probs.sum() + 1e-9) # Normalize safely

        # Ensure num_to_keep is not greater than available peaks
        num_to_keep = min(num_to_keep, len(keep_inds))
        if num_to_keep > 0 and len(keep_inds) > 0:
             ind_samples = np.random.choice(keep_inds, size=num_to_keep, replace=False, p=keep_probs)
             frags, intens, ions = frags[ind_samples], intens[ind_samples], ions[ind_samples]
        elif num_to_keep == 0: # Handle edge case where all peaks might be removed
             frags, intens, ions = np.array([]), np.array([]), np.array([])
        # else: # num_to_keep > 0 but len(keep_inds) == 0 (should not happen) - keep original

        if len(intens) > 0: # Only scale if peaks remain
            rescale_prob = np.random.random(len(intens))
            inten_scalar_factor = np.random.normal(loc=1, scale=0.1, size=len(intens)) # Added scale
            inten_scalar_factor[inten_scalar_factor <= 0] = 1e-6 # Avoid zero/negative
            inten_scalar_factor[rescale_prob >= self.inten_prob] = 1
            intens = intens * inten_scalar_factor
            new_max = intens.max() + 1e-12
            intens /= new_max

        peak_dict["intens"] = intens.tolist() # Convert back to list
        peak_dict["frags"] = frags.tolist()
        peak_dict["ions"] = ions.tolist()
        return peak_dict

    def _process_magma_file(self, spec_name, mz_vec, forms_vec):
        # Dummy implementation - returns fingerprints of -1
        # Replace with your actual MAGMA file processing if needed
        num_peaks = forms_vec.shape[0] if isinstance(forms_vec, np.ndarray) else len(forms_vec)
        fingerprints = np.full((num_peaks, self.aug_nbits), -1.0)
        # Your MAGMA loading logic here...
        # Example: Set some dummy fingerprints for the non-CLS token if aux loss is on
        if self.magma_aux_loss and num_peaks > 1:
             # Find the CLS token index (assuming it's the last one)
             cls_idx = -1
             if self.cls_token_type in ["ms1", "zeros"]:
                 cls_idx = num_peaks -1

             for i in range(num_peaks):
                 if i != cls_idx: # Don't assign FP to CLS token
                     # Assign a dummy FP (e.g., mostly zeros with a few ones)
                     fp = np.zeros(self.aug_nbits)
                     num_ones = np.random.randint(1, min(10, self.aug_nbits)) # Few random bits on
                     one_indices = np.random.choice(self.aug_nbits, num_ones, replace=False)
                     fp[one_indices] = 1
                     fingerprints[i, :] = fp

        return fingerprints


    def _generate_features(self, peak_dict: dict, spec_name=None):
        # import utils # Use the dummy utils defined above
        utils = self.utils

        root = peak_dict["root_form"]
        forms_vec = [utils.formula_to_dense(i) for i in peak_dict["frags"]]
        if not forms_vec: # Handle empty list
            mz_vec = []
            forms_vec_arr = np.empty((0, utils.NUM_ELEM)) # Use correct shape
        else:
            forms_vec_arr = np.array(forms_vec)
            mz_vec = (forms_vec_arr * utils.VALID_MONO_MASSES).sum(-1).tolist()

        root_vec = utils.formula_to_dense(root)
        root_ion = utils.get_ion_idx(peak_dict["root_ion"])
        root_mass = (root_vec * utils.VALID_MONO_MASSES).sum()
        inten_vec = list(peak_dict["intens"])
        ion_vec = [utils.get_ion_idx(i) for i in peak_dict["ions"]]
        type_vec = len(forms_vec) * [self.cls_type_map["frags"]]
        instrument = 0

        # Add classification token
        cls_added = False
        if self.cls_token_type == "ms1":
            cls_ind = self.cls_type_map.get("cls")
            inten_vec.append(1.0)
            type_vec.append(cls_ind)
            # Append root_vec correctly
            forms_list = forms_vec_arr.tolist()
            forms_list.append(root_vec)
            forms_vec_arr = np.array(forms_list)
            mz_vec.append(root_mass)
            ion_vec.append(root_ion)
            cls_added = True
        elif self.cls_token_type == "zeros":
            cls_ind = self.cls_type_map.get("cls")
            inten_vec.append(0.0)
            type_vec.append(cls_ind)
            forms_list = forms_vec_arr.tolist()
            forms_list.append(np.zeros_like(root_vec))
            forms_vec_arr = np.array(forms_list)
            mz_vec.append(0)
            ion_vec.append(root_ion)
            cls_added = True
        # else: # No CLS token if type is not ms1 or zeros
        #    pass

        inten_vec = np.array(inten_vec)
        if self.inten_transform == "float": self.inten_feats = 1
        elif self.inten_transform == "zero": self.inten_feats = 1; inten_vec = np.zeros_like(inten_vec)
        elif self.inten_transform == "log": self.inten_feats = 1; inten_vec = np.log(inten_vec + 1e-5)
        elif self.inten_transform == "cat":
            self.inten_feats = self.num_inten_bins
            bins = np.linspace(0, 1, self.num_inten_bins)
            inten_vec = np.digitize(inten_vec, bins)
        else: raise NotImplementedError()

        # Ensure forms_vec_arr is 2D even if empty or only CLS token added
        if forms_vec_arr.ndim == 1 and forms_vec_arr.shape[0] > 0: # Only CLS token
             forms_vec_arr = forms_vec_arr.reshape(1, -1)
        elif len(forms_vec_arr) == 0: # Empty case
             forms_vec_arr = np.empty((0, utils.NUM_ELEM))


        # Process MAGMA fingerprints
        # Pass the potentially updated forms_vec_arr
        fingerprints = self._process_magma_file(spec_name, mz_vec, forms_vec_arr)

        out_dict = {
            "peak_type": np.array(type_vec),
            "form_vec": forms_vec_arr, # Use the array
            "ion_vec": np.array(ion_vec), # Convert to array
            "frag_intens": inten_vec,
            "name": spec_name,
            "magma_fps": fingerprints,
            "magma_aux_loss": self.magma_aux_loss,
            "instrument": instrument, # Convert to array or keep as scalar? Assuming scalar
            "cls_added": cls_added # Flag if CLS token was added
        }
        return out_dict

# --- Helper Functions ---

def smiles_to_fingerprint(smiles: str, n_bits: int) -> np.ndarray:
    """Generates a Morgan fingerprint from a SMILES string."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Warning: RDKit could not parse SMILES: {smiles}")
            return np.zeros(n_bits, dtype=np.float32)
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=n_bits)
        return np.array(fp, dtype=np.float32)
    except Exception as e:
        print(f"Error generating fingerprint for {smiles}: {e}")
        return np.zeros(n_bits, dtype=np.float32)

def masked_bce_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Calculates BCE loss on logits, ignoring targets where value is -1."""
    mask = target != -1
    if not mask.any(): # Handle cases where mask removes all elements
        return pred.sum() * 0.0 # Zero loss that stays attached to the graph
    target_masked = target[mask].to(pred.dtype)
    pred_masked = pred[mask]
    # fragment_predictor ends in nn.Linear with no Sigmoid, so pred holds raw logits
    # and BCEWithLogitsLoss is the matching (and numerically stable) choice.
    # If you add a Sigmoid to that head, switch to nn.BCELoss instead.
    loss = nn.BCEWithLogitsLoss(reduction='mean')(pred_masked, target_masked)
    return loss

def collate_fn(batch):
    """
    Custom collate function to handle padding and batching.
    This needs to be adapted based on how FormulaTransformer expects input.
    Assuming padding is needed for form_vec, peak_type, etc.
    """
    from torch.nn.utils.rnn import pad_sequence
    collated = {}
    keys_to_pad = ['form_vec', 'peak_type', 'ion_vec', 'frag_intens', 'magma_fps']
    keys_to_stack = ['target_fp'] # Assuming target_fp is already fixed size
    other_keys = ['name', 'magma_aux_loss', 'instrument', 'cls_added'] # Non-tensor or scalar data

    # Find max sequence length in batch for padding
    # default=0 covers an empty batch or a batch without any usable form_vec
    max_len = max(
        (item['form_vec'].shape[0] for item in batch if 'form_vec' in item and item['form_vec'].ndim > 0),
        default=0,
    )

    # Pad and batch tensor data
    for key in keys_to_pad:
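        # Ensure data exists and is numpy array before converting to tensor
        # Float arrays are cast to float32 so they match the model's parameter dtype
        sequences = [
            torch.from_numpy(item[key]).float() if np.issubdtype(item[key].dtype, np.floating) else torch.from_numpy(item[key])
            for item in batch if key in item and isinstance(item[key], np.ndarray) and item[key].size > 0
        ]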
        if sequences:
             # Handle potential 1D vs 2D arrays after processing
             if sequences[0].ndim == 1: # e.g., peak_type, ion_vec, frag_intens
                 # Pad 1D tensors
                 collated[key] = pad_sequence(sequences, batch_first=True, padding_value=0) # Use 0 for padding types/ions/intens? Check MIST paper.
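             elif sequences[0].ndim == 2: # e.g., form_vec, magma_fps
                 # Pad 2D tensors. magma_fps is padded with -1 so the masked loss ignores
                 # padded peaks; form_vec is padded with 0.0.
                 pad_value = -1.0 if key == 'magma_fps' else 0.0
                 collated[key] = pad_sequence(sequences, batch_first=True, padding_value=pad_value)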
             else:
                 print(f"Warning: Unexpected tensor dimension for key {key}: {sequences[0].ndim}")
        else:
             # Handle cases where a key might be missing or empty across the batch
             # Create an empty tensor or handle appropriately downstream
             # Example: Create tensor with shape (batch_size, 0) or (batch_size, 0, feature_dim)
             batch_size = len(batch)
             if key == 'form_vec': shape = (batch_size, 0, SpectrumProcessor.utils.NUM_ELEM)
             elif key == 'magma_fps': shape = (batch_size, 0, batch[0]['magma_fps'].shape[1] if batch and 'magma_fps' in batch[0] and batch[0]['magma_fps'].ndim==2 else 512) # Use configured size
             elif key in ['peak_type', 'ion_vec', 'frag_intens']: shape = (batch_size, 0)
             else: shape = (batch_size, 0) # Default fallback
             collated[key] = torch.empty(shape)


    # Stack fixed-size tensors
    for key in keys_to_stack:
        if key in batch[0]:
            collated[key] = torch.stack([torch.from_numpy(item[key]) for item in batch], dim=0)

    # Collect other data (non-tensor)
    for key in other_keys:
        if key in batch[0]:
            collated[key] = [item[key] for item in batch]

    # Add attention mask: True where there is actual data, False for padding.
    # It is built from the pre-padding sequence lengths rather than from padded values,
    # so an all-zero CLS row (cls_token_type="zeros") or a peak type of 0 ("frags")
    # is not mistaken for padding.
    if 'peak_type' in collated:
        lengths = torch.tensor([item['form_vec'].shape[0] for item in batch], dtype=torch.long)
        positions = torch.arange(max_len).unsqueeze(0) # Shape: [1, seq_len]
        collated['attention_mask'] = positions < lengths.unsqueeze(1) # Shape: [batch_size, seq_len]


    # Ensure all necessary keys for the model are present, even if empty
    required_keys = ['form_vec', 'peak_type', 'ion_vec', 'frag_intens', 'attention_mask', 'target_fp']
    if batch and batch[0].get('magma_aux_loss', False):
        required_keys.append('magma_fps')

    for key in required_keys:
        if key not in collated:
             # Handle missing keys, e.g., create empty tensors
             print(f"Warning: Key '{key}' missing in collated batch.")
             if key == 'target_fp': collated[key] = torch.empty((len(batch), 4096)) # Placeholder size
             # Add similar handling for other keys if necessary

    return collated


# --- Dataset Definition ---
class SpectraDataset(Dataset):
    def __init__(self, data_file, spectrum_processor, target_fp_size, is_train=False):
        self.processor = spectrum_processor
        self.target_fp_size = target_fp_size
        self.is_train = is_train
        # Load your data (e.g., from TSV)
        # Example assuming TSV: ID SMILES JSON_DATA ...
        self.data = pd.read_csv(data_file, sep='\t', header=None) # Adjust as needed
        # Filter out rows where RDKit fails on SMILES?
        # self.data = self.data[self.data[1].apply(lambda x: Chem.MolFromSmiles(x) is not None)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        spec_id = str(row[0])
        smiles = str(row[1])
        raw_spec_json = str(row[2]) # Assuming JSON is in the 3rd column

        # Process spectrum
        # Use train_mode for augmentation during training
        spec_features = self.processor.process_raw_spectrum(
            raw_spec_json, spec_id=spec_id, train_mode=self.is_train
        )

        # Generate target fingerprint
        target_fp = smiles_to_fingerprint(smiles, self.target_fp_size)

        # Combine features and target
        item = {**spec_features, 'target_fp': target_fp}
        return item

# --- Training Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Model Hyperparameters (match SpectraEncoder defaults or your config)
OUTPUT_SIZE = 4096
HIDDEN_SIZE = 50 # Example, adjust based on MIST paper/config
SPECTRA_DROPOUT = 0.1 # Example
TOP_LAYERS = 2 # Example
MAGMA_MODULO = 2048 # Example
FORM_EMBEDDER = "float" # Or 'abs', 'elec', etc.
# Add other FormulaTransformer kwargs if needed (e.g., nhead, num_encoder_layers)
FORMULA_TRANSFORMER_KWARGS = {"nhead": 2, "num_encoder_layers": 2, "dim_feedforward": 128} # Example

# Training Hyperparameters
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
EPOCHS = 50
TRAIN_DATA = "path/to/your/train_data.tsv" # Replace with your data path
VAL_DATA = "path/to/your/val_data.tsv"   # Replace with your data path
MAGMA_FOLDER = "path/to/your/magma_files" # Optional: Replace if using MAGMA aux loss
USE_MAGMA_AUX_LOSS = True # Set to True if using MAGMA fingerprints
MAGMA_LOSS_WEIGHT = 0.2 # Weight for the auxiliary loss

# --- Initialization ---
# Initialize Spectrum Processor
# Ensure magma_aux_loss matches USE_MAGMA_AUX_LOSS
processor = SpectrumProcessor(
    augment_data=True, # Enable augmentation for training data
    cls_token_type="ms1",
    max_peaks=500, # Example: Limit number of peaks
    magma_modulo=MAGMA_MODULO,
    magma_aux_loss=USE_MAGMA_AUX_LOSS,
    magma_folder=MAGMA_FOLDER if USE_MAGMA_AUX_LOSS else None
)

# Initialize Datasets and DataLoaders
train_dataset = SpectraDataset(TRAIN_DATA, processor, OUTPUT_SIZE, is_train=True)
val_dataset = SpectraDataset(VAL_DATA, processor, OUTPUT_SIZE, is_train=False) # No augmentation for validation

# Use the custom collate function
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=4, pin_memory=True)

# Initialize Model
model = SpectraEncoder(
    form_embedder=FORM_EMBEDDER,
    output_size=OUTPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    spectra_dropout=SPECTRA_DROPOUT,
    top_layers=TOP_LAYERS,
    magma_modulo=MAGMA_MODULO,
    **FORMULA_TRANSFORMER_KWARGS # Pass transformer specific args
).to(DEVICE)

# Initialize Loss Functions
criterion_main = nn.BCELoss() # Use BCELoss because model has Sigmoid
# criterion_main = nn.BCEWithLogitsLoss() # Use this if you remove the final Sigmoid from the model

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
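
# --- Optional helper: best-checkpoint tracking (illustrative sketch) ---
# The epoch loop below ends with a "model saving logic" placeholder. One way to
# fill it is a small tracker that reports when the monitored validation metric
# improves and when it has stalled for `patience` epochs. BestMetricTracker and
# its parameters are hypothetical names, not part of MIST or PyTorch.
class BestMetricTracker:
    def __init__(self, mode="min", patience=5, min_delta=0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode            # "min" for a loss, "max" for Tanimoto
        self.patience = patience    # Epochs without improvement that are tolerated
        self.min_delta = min_delta  # Smallest change that counts as an improvement
        self.best = None
        self.stale_epochs = 0

    def update(self, value):
        """Record one epoch's metric; return True if it is a new best."""
        if self.best is None:
            improved = True
        elif self.mode == "min":
            improved = value < self.best - self.min_delta
        else:
            improved = value > self.best + self.min_delta
        if improved:
            self.best = value
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return improved

    @property
    def should_stop(self):
        # True once the metric has not improved for `patience` consecutive epochs
        return self.stale_epochs >= self.patience

# Possible usage at the end of each epoch (the file name is a placeholder):
#     tracker = BestMetricTracker(mode="min")   # create once, before the loop
#     if tracker.update(avg_val_loss):
#         torch.save(model.state_dict(), "best_model.pt")
#     if tracker.should_stop:
#         break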

# --- Training Loop ---
for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0
    total_main_loss = 0
    total_aux_loss = 0

    for i, batch in enumerate(train_loader):
        # Move batch to device (handle non-tensor data appropriately)
        batch_gpu = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_gpu[key] = value.to(DEVICE)
            else:
                batch_gpu[key] = value # Keep non-tensors like names, flags as they are

        # Skip batch if essential tensors are empty after collation/filtering
        if not batch_gpu.get('form_vec', torch.empty(0)).numel() or not batch_gpu.get('target_fp', torch.empty(0)).numel():
            print(f"Skipping empty batch {i}")
            continue

        optimizer.zero_grad()

        # Forward pass
        try:
            output, aux_outputs = model(batch_gpu) # Pass the gpu batch
        except Exception as e:
            print(f"Error during forward pass on batch {i}: {e}")
            # Optionally print batch details for debugging
            # for key, value in batch_gpu.items():
            #     if isinstance(value, torch.Tensor):
            #         print(f"Batch key: {key}, Shape: {value.shape}, Device: {value.device}")
            #     else:
            #         print(f"Batch key: {key}, Value: {value}")
            continue # Skip this batch

        # Calculate Main Loss
        target_fp = batch_gpu['target_fp']
        main_loss = criterion_main(output, target_fp)

        # Calculate Auxiliary Loss (if enabled)
        aux_loss = torch.tensor(0.0, device=DEVICE) # Zero default, created directly on the target device
        if USE_MAGMA_AUX_LOSS and 'pred_frag_fps' in aux_outputs and 'magma_fps' in batch_gpu and batch_gpu['magma_fps'].numel() > 0:
            pred_frag_fps = aux_outputs['pred_frag_fps'] # Shape: [batch, (num_peaks,) magma_modulo]
            target_magma_fps = batch_gpu['magma_fps']   # Shape: [batch, num_peaks, magma_modulo]

            # --- IMPORTANT ---
            # Ensure pred_frag_fps and target_magma_fps align.
            # If pred_frag_fps is [batch, magma_modulo] (global prediction),
            # you cannot directly compare it to per-peak target_magma_fps.
            # Check the MIST code to confirm how the aux loss is actually computed.
            # Assuming pred_frag_fps is per-peak for this example:
            # Shape: [batch, num_peaks, magma_modulo]
            # Make sure the shapes match before calling masked_bce_loss.
            # If shapes mismatch, you need to adapt the model or loss calculation.

            # Example check (adjust based on actual model output shape):
            if pred_frag_fps.shape == target_magma_fps.shape:
                # Apply sigmoid only if fragment_predictor returns logits;
                # if it already ends in nn.Sigmoid, pass pred_frag_fps directly.
                pred_frag_fps_sig = torch.sigmoid(pred_frag_fps) # Assuming logits output
                aux_loss = masked_bce_loss(pred_frag_fps_sig, target_magma_fps)
            else:
                print(f"Warning: Shape mismatch for MAGMA loss. Pred: {pred_frag_fps.shape}, Target: {target_magma_fps.shape}. Skipping aux loss.")
                # aux_loss keeps its zero default, so only the main loss is used


        # Combine Losses
        total_loss = main_loss + MAGMA_LOSS_WEIGHT * aux_loss

        # Backward pass and optimize
        try:
            total_loss.backward()
            optimizer.step()
        except RuntimeError as e:
            print(f"Error during backward/step on batch {i}: {e}")
            # Potentially clear gradients if step failed
            optimizer.zero_grad()
            continue # Skip this batch update

        total_train_loss += total_loss.item()
        total_main_loss += main_loss.item()
        if USE_MAGMA_AUX_LOSS:
            total_aux_loss += aux_loss.item()

        if (i + 1) % 100 == 0:
            avg_loss = total_train_loss / (i + 1)
            avg_main_loss = total_main_loss / (i + 1)
            avg_aux_loss = total_aux_loss / (i + 1) if USE_MAGMA_AUX_LOSS else 0
            print(f'Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(train_loader)}], Avg Loss: {avg_loss:.4f}, Avg Main Loss: {avg_main_loss:.4f}, Avg Aux Loss: {avg_aux_loss:.4f}')

    print(f'--- Epoch {epoch+1} Training Finished ---')
    # max(..., 1) guards against division by zero if the loader is empty
    print(f'Average Training Loss: {total_train_loss / max(len(train_loader), 1):.4f}')

    # --- Validation Loop ---
    model.eval()
    total_val_loss = 0
    total_val_main_loss = 0
    total_val_aux_loss = 0
    # Collect predictions and targets for metric calculation (e.g., Tanimoto)
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            batch_gpu = {}
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch_gpu[key] = value.to(DEVICE)
                else:
                    batch_gpu[key] = value

            if not batch_gpu.get('form_vec', torch.empty(0)).numel() or not batch_gpu.get('target_fp', torch.empty(0)).numel():
                print("Skipping empty validation batch")
                continue

            try:
                output, aux_outputs = model(batch_gpu)
            except Exception as e:
                print(f"Error during validation forward pass: {e}")
                continue

            target_fp = batch_gpu['target_fp']
            main_loss = criterion_main(output, target_fp)

            aux_loss = torch.tensor(0.0).to(DEVICE)
            if USE_MAGMA_AUX_LOSS and 'pred_frag_fps' in aux_outputs and 'magma_fps' in batch_gpu and batch_gpu['magma_fps'].numel() > 0:
                pred_frag_fps = aux_outputs['pred_frag_fps']
                target_magma_fps = batch_gpu['magma_fps']
                # On a shape mismatch, aux_loss keeps its zero default
                if pred_frag_fps.shape == target_magma_fps.shape:
                    pred_frag_fps_sig = torch.sigmoid(pred_frag_fps) # Assuming logits output
                    aux_loss = masked_bce_loss(pred_frag_fps_sig, target_magma_fps)


            total_loss = main_loss + MAGMA_LOSS_WEIGHT * aux_loss
            total_val_loss += total_loss.item()
            total_val_main_loss += main_loss.item()
            if USE_MAGMA_AUX_LOSS:
                total_val_aux_loss += aux_loss.item()

            # Store predictions and targets for metrics
            all_preds.append(output.cpu().numpy())
            all_targets.append(target_fp.cpu().numpy())

    n_val_batches = max(len(val_loader), 1) # Guard against an empty validation loader
    avg_val_loss = total_val_loss / n_val_batches
    avg_val_main_loss = total_val_main_loss / n_val_batches
    avg_val_aux_loss = total_val_aux_loss / n_val_batches if USE_MAGMA_AUX_LOSS else 0

    print(f'--- Epoch {epoch+1} Validation Finished ---')
    print(f'Average Validation Loss: {avg_val_loss:.4f}')
    print(f'Average Val Main Loss: {avg_val_main_loss:.4f}')
    if USE_MAGMA_AUX_LOSS:
        print(f'Average Val Aux Loss: {avg_val_aux_loss:.4f}')

    # Calculate Tanimoto Similarity (Example)
    if all_preds and all_targets:
        all_preds_cat = np.concatenate(all_preds, axis=0)
        all_targets_cat = np.concatenate(all_targets, axis=0)
        # Binarize predictions (e.g., threshold at 0.5)
        preds_binary = (all_preds_cat > 0.5).astype(int)
        targets_binary = all_targets_cat.astype(int) # Target should already be binary

        intersection = np.sum(preds_binary * targets_binary, axis=1)
        union = np.sum(np.logical_or(preds_binary, targets_binary).astype(int), axis=1)
        # Avoid division by zero for empty fingerprints
        tanimoto_scores = np.divide(intersection, union, out=np.zeros_like(intersection, dtype=float), where=union!=0)
        avg_tanimoto = np.mean(tanimoto_scores)
        print(f'Average Tanimoto Similarity: {avg_tanimoto:.4f}')

    # Add model saving logic here (e.g., save best model based on val loss or Tanimoto)

print("Training complete.")