# 🧪 MS/MS Spectra Modification Classifier with Transformers

This notebook builds and trains a deep learning model designed to detect and classify *post-translational modifications (PTMs)* in MS/MS spectra from shotgun proteomics. The input is MS/MS spectra in `mgf` format and the classifier is based on a hybrid CNN-Transformer architecture.

---

### 🧠 Objectives

- **Multi-class Classification**: Assign each spectrum to one of five classes (unmodified, or the specific modification type):
  - Unmodified
  - Oxidation
  - Phosphorylation
  - Ubiquitination
  - Acetylation

---

### 🔧 Environment Setup

The following libraries and paths are configured:

#### 📦 Core Libraries
- `torch`, `torch.nn`, `torch.optim`: PyTorch for neural network construction and training.
- `numpy`, `random`, `os`, `sys`: Utilities for array operations, randomness, and file handling.
- `math`, `datetime`, `logging`: Math functions, timestamping, and logging system.
- `matplotlib.pyplot`: (optional) Visualization.
- `scikit-learn`: Evaluation metrics and dataset splitting.


#### 🛠️ Path Configuration
- Adds the dataset directory on Google Drive to the system path to ensure data files can be accessed during training and evaluation.

---

### 🧬 Pipeline Overview

This project includes the following components:
- **MGF File Parsing**: Custom loader to extract raw spectra from a dataset of `.mgf` files.
- **Spectral Preprocessing**: Converts spectra into binned, normalized vector representations.
- **Metadata Normalization**: Processes and scales parent ion mass (`pepmass`) for model input.
- **Hybrid CNN-Transformer Model**: Hybrid neural architecture combining CNNs, self-attention, and metadata fusion.
- **Training & Evaluation**: Loop with weighted loss, custom metrics, logging, and model checkpointing.

This setup is tailored for PTM classification in Google Colab with GPU acceleration, with hyperparameters tuned using Optuna.



---

## 📁 Directory Setup Instructions

Before running the notebook, ensure your **Google Drive** is properly structured so that the code can:

* Load `.mgf` spectra files.
* Save model weights.
* Persist log files from training.

This is **required** for the notebook to run end-to-end.

---

### 🔗 1. Mount Google Drive

At the beginning of your notebook, run:

```python
from google.colab import drive
drive.mount('/content/drive')
```

You will be prompted to authorize access.

---

### 📂 2. Create This Folder Structure in Your Drive

Organize your files inside `MyDrive` as follows:

```
MyDrive/
├── data/
│   └── balanced_dataset/                ← .mgf files for training; class balance is not required, but it helps training performance
│       ├── split_file_001.mgf
│       ├── split_file_002.mgf
│       └── ...
├── peak_encoder_transformer_pipeline/
│   ├── model_weights/                   ← for saving trained model weights
│   └── logs/                            ← for saving training logs
```

If these folders don't exist, you can create them manually in Google Drive or use Python:

```python
import os

os.makedirs("/content/drive/MyDrive/data/balanced_dataset", exist_ok=True)
os.makedirs("/content/drive/MyDrive/peak_encoder_transformer_pipeline/model_weights", exist_ok=True)
os.makedirs("/content/drive/MyDrive/peak_encoder_transformer_pipeline/logs", exist_ok=True)
```

---

### ⚙️ 3. Update Paths in the Code (if needed)

These variables should point to the correct folders:

```python
input_dir = "your split dataset path"
model_weights_dir = "path where the model weights are saved"
log_dir = "path where the per-batch training logs are written"
```

Make sure each path you change is compatible with what the code expects: `input_dir` must contain the `.mgf` files, and the other two must be writable directories.
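
A quick check before training can catch a mistyped path early. The sketch below is only an illustration: `missing_dirs` is a hypothetical helper that is not part of the notebook, and the paths shown are the default layout from the folder structure above.

```python
import os

def missing_dirs(paths):
    """Return the names whose path is not an existing directory."""
    return [name for name, path in paths.items() if not os.path.isdir(path)]

# Default layout from this guide; replace with your own paths if they differ.
paths = {
    "input_dir": "/content/drive/MyDrive/data/balanced_dataset",
    "model_weights_dir": "/content/drive/MyDrive/peak_encoder_transformer_pipeline/model_weights",
    "log_dir": "/content/drive/MyDrive/peak_encoder_transformer_pipeline/logs",
}
print("Missing directories:", missing_dirs(paths))
```

An empty list means all three folders exist; outside Colab (or before mounting Drive) all three names are expected to be reported.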


---

✅ **Once these are set**, you're ready to run the notebook end-to-end, including training, evaluation, and logging.



In [None]:
#Set up the environment: imports and paths needed by the following cells

from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive/data/balanced_dataset')  # Add the dataset folder to sys.path (absolute path; the leading slash is required)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
from collections import Counter
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, precision_recall_curve
import matplotlib.pyplot as plt
import logging
from datetime import datetime
from sklearn.model_selection import train_test_split
import math
from sklearn.metrics import classification_report
import torch.nn.functional as F
from torch.nn import SiLU
import re
from sklearn.model_selection import StratifiedKFold
import pandas as pd

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 📂 DatasetHandler class for loading MGF files

This section defines the `DatasetHandler` class responsible for managing the loading and iteration over `.mgf` files containing MS/MS spectra.
Files are loaded one small `.mgf` at a time so that the pipeline scales without running into out-of-memory issues.

---

### 📦 `DatasetHandler` Overview

The `DatasetHandler` class provides a memory-efficient way to iterate through `.mgf` files stored in a directory. It supports:

- **Shuffling input files** to randomize data order across training loops.
- **Per-file usage tracking** with `MAX_FILE_PASSES`, so that each file is used at most `MAX_FILE_PASSES` times per loop.
- **Controlled looping** over the dataset using `num_loops`, allowing multiple passes over the files without recreating the handler (each file is re-read from disk when it is selected).
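
To make the input format concrete, here is a minimal sketch of how one MGF record maps to a spectrum dictionary. It is a simplified stand-in, not the notebook's `DatasetHandler`: `parse_mgf_text` is a hypothetical helper, the peak and `PEPMASS` values are made up, and header values are kept as raw strings (the real handler converts `PEPMASS` to `float` and `CHARGE` to `int` and validates each spectrum).

```python
# Made-up record for illustration; only the TITLE style follows the notebook.
SAMPLE = """BEGIN IONS
TITLE=GWSMSEQSEESVGGR 2,S,Phospho
PEPMASS=824.33
CHARGE=2+
175.119 1200.5
262.151 830.0
END IONS
"""

def parse_mgf_text(text):
    """Parse MGF-formatted text into a list of spectrum dicts (no validation)."""
    spectra, current = [], None
    for raw in text.splitlines():
        line = raw.strip()
        if line == "BEGIN IONS":
            current = {"mz_values": [], "intensity_values": []}
        elif current is None or not line:
            continue  # ignore anything outside a BEGIN/END block
        elif line == "END IONS":
            spectra.append(current)
            current = None
        elif "=" in line:
            key, value = line.split("=", 1)  # header line such as TITLE=...
            current[key.lower()] = value.strip()
        else:
            parts = line.split()
            if len(parts) == 2:  # peak line: "<m/z> <intensity>"
                current["mz_values"].append(float(parts[0]))
                current["intensity_values"].append(float(parts[1]))
    return spectra

spectra = parse_mgf_text(SAMPLE)
print(spectra[0]["title"], len(spectra[0]["mz_values"]))
```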

---

### 🧩 Key Components

#### 🔧 Initialization
```python
handler = DatasetHandler(input_dir="/path/to/mgf", num_loops=1)
```


In [None]:
#Set up the DatasetHandler class that handles the input files
#TODO: remove the remaining debug prints

MAX_FILE_PASSES = 1 # Max times a file can be used before being ignored

class DatasetHandler:
    def __init__(self, input_dir, num_loops=1):
        """
        Initialize the dataset handler.

        Args:
            input_dir (str): Path to the directory containing split MGF files.
            num_loops (int): Number of times the dataset should be iterated.
        """
        self.files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.mgf')]
        self.files = random.sample(self.files, len(self.files))  # Shuffle files
        self.file_usage_counter = {f: 0 for f in self.files}
        self.num_loops = num_loops
        self.loop_count = 0

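    def get_next_file(self):
      """
      Load one MGF file at a time into RAM and return all valid spectra from it.

      Returns:
          tuple (spectra, file) or None: `spectra` is a list of dicts, each holding
          a valid spectrum and its metadata; `file` is the path it was read from.
          Returns None once all dataset loops are completed.
      """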
      while self.loop_count < self.num_loops:
          available_files = [f for f in self.files if self.file_usage_counter[f] < MAX_FILE_PASSES]
          if not available_files:
              self.loop_count += 1
              if self.loop_count < self.num_loops:
                  print("Restarting dataset loop...")
                  self.file_usage_counter = {f: 0 for f in self.files}
                  continue
              else:
                  print("All dataset loops completed.")
                  return None

          file = random.choice(available_files)
          print(f"Processing file: {file}")

          spectra = []
          spectrum_data = None

          with open(file, 'r') as f:
              for line in f:
                  line = line.strip()

                  if line == "BEGIN IONS":
                      spectrum_data = {"mz_values": [], "intensity_values": []}

                  elif line.startswith("TITLE=") and spectrum_data is not None:
                      spectrum_data["title"] = line.split("=", 1)[1].strip()

                  elif line.startswith("PEPMASS=") and spectrum_data is not None:
                      try:
                          spectrum_data["pepmass"] = float(line.split("=", 1)[1].split()[0])
                      except ValueError:
                          spectrum_data["pepmass"] = None  # mark as missing
                  elif line.startswith("CHARGE=") and spectrum_data is not None:
                      charge_str = line.split("=", 1)[1].strip()
                      match = re.match(r'^(\d+)', charge_str)  # Match one or more digits at the start
                      if match:
                          spectrum_data["charge"] = int(match.group(1))
                      else:
                          print(f"[SKIPPED CHARGE] Invalid charge format: '{charge_str}'")
                          spectrum_data["charge"] = None

                  elif line == "END IONS" and spectrum_data is not None:
                      title = spectrum_data.get("title", "").strip()
                      mz_vals = spectrum_data.get("mz_values", [])
                      int_vals = spectrum_data.get("intensity_values", [])

                      # Final validation before appending
                      if not title or not title.strip():
                          print(f"[SKIPPED] Missing TITLE in file: {file}")
                      elif not mz_vals or not int_vals:
                          print(f"[SKIPPED] Empty m/z or intensity in file: {file}")
                      elif len(mz_vals) != len(int_vals):
                          print(f"[SKIPPED] Mismatched m/z and intensity count in file: {file}")
                      elif np.sum(int_vals) == 0:
                          print(f"[SKIPPED] All-zero intensities in spectrum '{title}'")
                      else:
                          spectra.append(spectrum_data)

                      spectrum_data = None  # Reset for next spectrum

                  else:
                      if spectrum_data is not None:
                          try:
                              parts = line.split()
                              if len(parts) != 2:
                                  raise ValueError("Expected two float values")
                              mz, intensity = map(float, parts)
                              if math.isnan(mz) or math.isnan(intensity):
                                  raise ValueError("NaN detected")
                              spectrum_data["mz_values"].append(mz)
                              spectrum_data["intensity_values"].append(intensity)
                          except ValueError:
                              #print(f"[SKIPPED LINE] Invalid peak: '{line}' in file: {file}")
                              continue

          self.file_usage_counter[file] += 1

          if spectra:
              return spectra, file

      print("All spectra processed.")
      return None
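
A usage sketch for the class above: since `get_next_file()` returns either a `(spectra, file)` tuple or `None`, a small generator keeps the calling loop tidy. `iterate_files` is a hypothetical helper, not part of the notebook; it only assumes the handler exposes `get_next_file()` with that return convention.

```python
def iterate_files(handler):
    """Yield (spectra, path) pairs until the handler reports exhaustion."""
    while True:
        result = handler.get_next_file()
        if result is None:  # all dataset loops completed
            break
        spectra, path = result
        yield spectra, path

# Usage with the notebook's class (assumes DatasetHandler is already defined):
# handler = DatasetHandler(input_dir="/path/to/mgf", num_loops=1)
# for spectra, path in iterate_files(handler):
#     print(path, len(spectra))
```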


# ⚙️ Dense Vector Binning for 1D CNN Input  
This section defines the updated preprocessing pipeline for converting annotated MS/MS spectra into dense, fixed-length vectors. These are tailored for use in models such as CNNs or hybrid CNN-Transformer architectures.

## 🔧 Functions:

### `bin_spectra_to_dense_vectors`  
Converts a list of spectra into fixed-length vectors by:  
- **Binning the m/z values** across a specified range (`mz_min` to `mz_max`) into `num_bins`.  
- Each bin holds the **intensity sum of peaks** falling into that m/z range.  
- Applies **sliding window normalization**:  
  The m/z axis is divided into fixed-size windows (e.g., 200 m/z), and intensities within each window are normalized individually to the [0, 1] range. This preserves local signal structure and prevents domination by high-intensity regions.
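
The effect of window normalization is easiest to see on a toy spectrum. The sketch below is a simplified pure-Python illustration, not the notebook's `bin_spectra_to_dense_vectors` (which uses `np.histogram`): it assumes each window spans a whole number of bins, and all peak values are made up.

```python
def bin_and_window_normalize(mz, intensity, num_bins=8, mz_min=100.0,
                             mz_max=500.0, window_size=200.0):
    """Sum intensities into equal-width bins, then scale each window by its max."""
    bin_width = (mz_max - mz_min) / num_bins
    binned = [0.0] * num_bins
    for m, i in zip(mz, intensity):
        if mz_min <= m < mz_max:  # peaks outside the range are ignored
            binned[int((m - mz_min) / bin_width)] += i

    bins_per_window = max(1, round(window_size / bin_width))
    normalized = []
    for start in range(0, num_bins, bins_per_window):
        chunk = binned[start:start + bins_per_window]
        peak = max(chunk)
        normalized.extend(v / peak if peak > 0 else 0.0 for v in chunk)
    return normalized

# Made-up peaks: a weak low-mass region next to a strong high-mass region.
vec = bin_and_window_normalize([110, 120, 290, 310, 450], [10, 30, 5, 1000, 500])
print(vec)
```

With global normalization the low-mass peaks (summed intensity 40) would be scaled against 1000 and nearly vanish; per-window scaling keeps them at full scale within their own window.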

### `process_spectra_with_handler`  
Processes a batch of spectra by:  
- Logging and skipping spectra with empty or invalid m/z or intensity values.  
- Using the above function to apply binning and **sliding window normalization**.  
- Skipping spectra with no signal after binning (i.e., zero-vector).  

Returns a list of valid, normalized dense vectors for CNN input and logs the total number of skipped spectra.

## 📦 Output Format:  
Each spectrum becomes a 1D `np.array` of shape `(num_bins,)` with `float32` values.  

The final output is either:  
- a stacked `np.ndarray` of shape `(batch_size, num_bins)` when using `bin_spectra_to_dense_vectors` directly on a list, or  
- a list of valid vectors (1 per spectrum) when using `process_spectra_with_handler`.

In [None]:
def bin_spectra_to_dense_vectors(spectra_data, num_bins=5000, mz_min=100.0, mz_max=2200.0, window_size=200.0):
    """
    Converts spectra into dense, fixed-length binned vectors suitable for 1D CNN input with sliding window normalization.

    Parameters:
    - spectra_data: List of spectra dicts with 'mz_values' and 'intensity_values'.
    - num_bins: Number of bins to divide the m/z range [mz_min, mz_max] into.
    - mz_min: Minimum m/z value for binning.
    - mz_max: Maximum m/z value for binning.
    - window_size: Size of m/z window for normalization (default is 200.0).

    Returns:
    - np.ndarray of shape (batch_size, num_bins) with window-normalized intensities.
    """
    bin_edges = np.linspace(mz_min, mz_max, num_bins + 1)
    binned_spectra = []

    for spectrum in spectra_data:
        mz_values = np.array(spectrum['mz_values'])
        intensity_values = np.array(spectrum['intensity_values'])

        if len(mz_values) == 0 or len(intensity_values) == 0:
            binned_spectra.append(np.zeros(num_bins, dtype=np.float32))
            continue

        # Create an array to hold the binned intensities (fixed size)
        binned_intensity = np.zeros(num_bins)

        # Iterate over windows of m/z values
        for window_start in np.arange(mz_min, mz_max, window_size):
            window_end = window_start + window_size
            window_mask = (mz_values >= window_start) & (mz_values < window_end)
            window_mz_values = mz_values[window_mask]
            window_intensity_values = intensity_values[window_mask]

            if len(window_mz_values) > 0:
                # Bin the intensities for this window
                binned_window_intensity, _ = np.histogram(window_mz_values, bins=bin_edges, weights=window_intensity_values)

                # Normalize the binned intensities within this window
                min_val = binned_window_intensity.min()
                max_val = binned_window_intensity.max()
                range_val = max_val - min_val if max_val != min_val else 1e-6
                normalized_binned_window = (binned_window_intensity - min_val) / range_val

                # Add the normalized intensities to the final vector (same size as before)
                binned_intensity += normalized_binned_window

        binned_spectra.append(binned_intensity.astype(np.float32))

    return np.stack(binned_spectra)  # Shape: (batch_size, num_bins)


def process_spectra_with_handler(spectra_batch, num_bins=1000, window_size=200.0):
    """
    Processes spectra batch and returns a list of 1D CNN-ready vectors (one per spectrum),
    with sliding window normalization applied.
    """
    spectrum_vectors = []
    skipped_spectra = 0

    for idx, spectrum in enumerate(spectra_batch):
        title = spectrum.get("title", f"unnamed_{idx}")
        mz_values = np.array(spectrum['mz_values'])
        intensity_values = np.array(spectrum['intensity_values'])

        if mz_values.size == 0 or intensity_values.size == 0:
            print(f"[SKIPPED] Empty m/z or intensity array: '{title}'")
            skipped_spectra += 1
            continue

        # Call the binning function with windowed normalization
        binned_spectrum = bin_spectra_to_dense_vectors([spectrum], num_bins=num_bins, window_size=window_size)

        # Ensure only valid (non-zero) spectra are added
        if np.sum(binned_spectrum) == 0:
            print(f"[SKIPPED] Zero intensity after binning: '{title}'")
            skipped_spectra += 1
            continue

        spectrum_vectors.append(binned_spectrum[0])  # Extract the vector

    print(f"Total skipped spectra: {skipped_spectra}")
    return spectrum_vectors


## 🔬 Normalize Parent Ion Mass (PEPMASS)

This module provides utilities to **extract sequences**, **convert observed m/z to monoisotopic singly charged mass** (if needed), and **normalize parent ion values** into the range [0, 1].

---

### 🎯 Objectives (current implementation)

- **Extract** peptide sequence from the beginning of the `TITLE` field.  
- **Convert** PEPMASS from **observed m/z** to **monoisotopic singly charged mass** when `assume_observed=True`.  
- **Normalize** the parent ion mass into \[0, 1\] using global bounds from `min_max_dict`.



### 🧩 Key Functions

#### 🔹 `extract_sequence_from_title(title: str) -> str`
Extracts the peptide sequence from the `TITLE`.  
Assumes the sequence is the **first token** (before the first space).

**Example**
```python
TITLE = "GWSMSEQSEESVGGR 2,S,Phospho"
extract_sequence_from_title(TITLE)
# → "GWSMSEQSEESVGGR"
```

#### 🔹 `observed_to_monoisotopic(observed_mz: float, charge: int) -> float`

Converts observed precursor **m/z** into **monoisotopic singly charged mass** (the neutral mass plus one proton, not the neutral mass itself):

$$
\text{mono\_mass} = z \cdot \text{m/z} - (z - 1)\cdot \text{PROTON\_MASS}
$$

Uses `PROTON_MASS = 1.0072764665789`.
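
A short numeric check of the formula, as a sketch: starting from a made-up neutral mass, the m/z observed at several charge states is converted back and should always give the same singly charged mass. `mz_from_neutral` and `observed_to_singly_charged` are illustrative names; the second mirrors the formula above.

```python
PROTON_MASS = 1.0072764665789  # same constant as the notebook

def mz_from_neutral(neutral_mass, charge):
    """m/z observed for a neutral mass carrying `charge` protons."""
    return (neutral_mass + charge * PROTON_MASS) / charge

def observed_to_singly_charged(observed_mz, charge):
    """Same formula as above: z * m/z - (z - 1) * PROTON_MASS."""
    return charge * observed_mz - (charge - 1) * PROTON_MASS

neutral = 1000.0  # made-up neutral mass in Da
for z in (1, 2, 3):
    mz = mz_from_neutral(neutral, z)
    print(z, round(mz, 4), round(observed_to_singly_charged(mz, z), 4))
```

Every row ends in the same value, `neutral + PROTON_MASS`, which is why the result is a singly charged mass rather than a neutral one.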

---

#### 🔹 `normalize_parent_ions(data, min_max_dict, assume_observed=True) -> list[float]`

Normalizes parent ion values to the range \[0, 1\].

* **Inputs per spectrum (dict):**

  * `"pepmass"`: precursor value
  * `"charge"`: integer charge state

* **Behavior:**

  1. If `assume_observed=True`:

     * Converts `"pepmass"` (observed m/z) → monoisotopic singly charged mass.
  2. If `assume_observed=False`:

     * Uses `"pepmass"` directly (assumed monoisotopic).
  3. Normalizes with:

     $$
     \text{norm} = \frac{parent\_ion - min}{max - min}
     $$
  4. Clamps results into \[0, 1\].
  5. Missing metadata → returns `0.0`.

**Example**

```python
min_max = {"min": 400.0, "max": 6000.0}
normalized = normalize_parent_ions(spectra, min_max, assume_observed=True)
```
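
The normalization and clamping steps can be illustrated on single values. This is a sketch with a hypothetical helper, `min_max_clamp`; the guard against `upper <= lower` is an addition for illustration and is not present in the notebook's function.

```python
def min_max_clamp(value, lower, upper):
    """Scale value into [0, 1] using (lower, upper) bounds, clamping outliers."""
    if upper <= lower:
        raise ValueError("upper must be greater than lower")
    norm = (value - lower) / (upper - lower)
    return max(0.0, min(1.0, norm))

# Bounds taken from the example above; the masses are made up.
for mass in (300.0, 3200.0, 7000.0):
    print(mass, min_max_clamp(mass, 400.0, 6000.0))
```

Values below the minimum map to `0.0` and values above the maximum map to `1.0`, so masses outside the bounds become indistinguishable from the bounds themselves.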

---

### ✅ Output

Returns:

```python
[List of float values between 0 and 1]
```

---

### ⚠️ Notes

* Requires `"min"` and `"max"` keys in `min_max_dict`.
* Missing or invalid metadata defaults to **0.0**.
* No theoretical mass calculation or spectrum validation is performed here.



In [None]:

PROTON_MASS = 1.0072764665789
H2O_MASS = 18.01056

def extract_sequence_from_title(title: str) -> str:
    """
    Extracts the peptide sequence from the TITLE string.
    Assumes the sequence is the first word, before the first space.
    """
    if not isinstance(title, str) or not title.strip():
        return ""
    return title.strip().split(" ")[0]  # safe even with extra spaces



def observed_to_monoisotopic(observed_mz, charge):
    """Convert an observed precursor m/z at the given charge to the singly protonated mass [M+H]+."""
    return charge * observed_mz - (charge - 1) * PROTON_MASS



def normalize_parent_ions(data, min_max_dict, assume_observed=True):
    """
    Normalize parent ions to the range [0, 1].

    If assume_observed=True, converts PEPMASS (observed m/z) to monoisotopic mass before computing normalization.
    """
    normalized = []

    for spectrum in data:
        pepmass = spectrum.get("pepmass", None)
        charge = spectrum.get("charge", None)

        if pepmass is None or charge is None:
            normalized.append(0.0)
            continue

        if assume_observed:
            mono_mass = observed_to_monoisotopic(pepmass, charge)
            parent_ion = mono_mass
        else:
            parent_ion = pepmass  # Already monoisotopic

        # Normalize to [0, 1]
        pepmass_min = min_max_dict["min"]
        pepmass_max = min_max_dict["max"]
        norm = (parent_ion - pepmass_min) / (pepmass_max - pepmass_min)
        normalized.append(max(0, min(1, norm)))

    return normalized




### 🧬 Combine Spectra with Parent Ion Mass

> **TODO:** change the model to always receive the monoisotopic singly charged mass instead of the observed mass used currently.


This function constructs the final **input representation** for the neural network by pairing each processed spectrum with its corresponding normalized parent ion mass.

---

### ⚙️ `combine_features(...)`

#### **Purpose**
Aggregates spectral and precursor metadata into a unified format, ready to be passed into the model during training or evaluation.

---

### 🔄 Process Flow

1. **Spectral Preprocessing**
   - Calls `process_spectra_with_handler(...)` to:
     - Apply binning and normalization.
     - Generate a dense, fixed-length vector for each spectrum.
   - Result: `spectra_vectors` — a list of shape `[batch_size, num_bins]`.

2. **Parent Ion Normalization**
   - Invokes `normalize_parent_ions(...)` to:
     - Convert the observed precursor m/z to monoisotopic singly charged mass (when `assume_observed=True`).
     - Normalize to a range of [0, 1] using dataset-specific bounds.
   - Result: `parent_ions` — a list of length `[batch_size]`.

3. **Validation**
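   - Verifies that the spectrum vectors and the parent ion list have the same length.
   - Prints an error and returns `None` on a mismatch. Because `process_spectra_with_handler(...)` drops skipped spectra while `normalize_parent_ions(...)` returns one value per input spectrum, a single skipped spectrum makes the whole batch return `None`.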

4. **Zipping**
   - Combines each spectrum vector and its corresponding normalized parent ion into a tuple:
     ```python
     (spectrum_vector, normalized_parent_ion)
     ```
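
One way to keep features and labels aligned when some spectra are skipped is to filter all per-spectrum lists together. The sketch below is a possible pattern, not the notebook's behavior: `zip_aligned` is a hypothetical helper, and it assumes preprocessing emits `None` as a placeholder for a skipped spectrum instead of dropping it.

```python
def zip_aligned(vectors, parent_ions, labels):
    """Pair vectors with parent ions, dropping rows whose vector is None."""
    if not (len(vectors) == len(parent_ions) == len(labels)):
        raise ValueError("inputs must have one entry per spectrum")
    features, targets = [], []
    for vec, ion, label in zip(vectors, parent_ions, labels):
        if vec is None:  # spectrum was skipped during preprocessing
            continue
        features.append((vec, ion))
        targets.append(label)
    return features, targets

# Made-up batch: the second spectrum was skipped.
features, targets = zip_aligned([[0.1], None, [0.3]], [0.2, 0.5, 0.9], [0, 2, 1])
print(len(features), targets)
```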

---

### 📤 Output Format

```python
[
  (spectrum_vec₁, pepmass₁),
  (spectrum_vec₂, pepmass₂),
  ...
]
```


In [None]:
def combine_features(data, pepmass_min_max, num_bins, window_normaliation_size, assume_observed):
    """
    Converts spectra + metadata into model input tuples:
        (binned spectrum, normalized parent ion mass)
    """

    spectra_vectors = process_spectra_with_handler(data, num_bins, window_normaliation_size)
    if not spectra_vectors:
        return None

    parent_ions = normalize_parent_ions(
        data, pepmass_min_max, assume_observed=assume_observed)


    if len(spectra_vectors) != len(parent_ions):
        print("❌ Mismatch between spectra and parent ions.")
        return None

    return list(zip(spectra_vectors, parent_ions))



### 🏷️ Label Spectra Based on Modifications

This function performs **automatic labeling** of MS/MS spectra for supervised learning, based on the content of the `TITLE` field in each spectrum's metadata.

---

### 🧠 Purpose

Assigns integer labels to each spectrum in a batch according to the presence of post-translational modification (PTM) keywords in the title:

- `0` → **Unmodified**
- `1` → **Oxidation** (if the string `"oxidation"` appears in the title)
- `2` → **Phosphorylation** (if the string `"phospho"` appears in the title)
- `3` → **Ubiquitination** (if the string `"k_gg"` appears in the title)
- `4` → **Acetylation** (if the string `"k_ac"` appears in the title)

The result is a list of labels aligned with the order of input spectra, suitable for classification tasks using `softmax`.

---

### ⚙️ Logic

For each spectrum in the input list:
1. Checks that the entry is a dictionary.
2. Extracts the `title` and converts it to lowercase.
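3. Searches for PTM-related keywords in the order listed above; the first match wins, so a title containing several keywords receives only the first matching label.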
4. Defaults to `0` if no match or invalid format.
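
The keyword logic can be expressed compactly as an ordered table. This is a sketch of the same rules for a single title: `PTM_KEYWORDS` and `label_from_title` are illustrative names, and the second example title is made up to show the first-match behavior.

```python
# Keyword -> label pairs, checked in this order (first match wins).
PTM_KEYWORDS = [("oxidation", 1), ("phospho", 2), ("k_gg", 3), ("k_ac", 4)]

def label_from_title(title):
    """Return the class label for one TITLE string (0 if nothing matches)."""
    if not isinstance(title, str) or not title.strip():
        return 0
    lowered = title.lower()
    for keyword, label in PTM_KEYWORDS:
        if keyword in lowered:
            return label
    return 0

print(label_from_title("GWSMSEQSEESVGGR 2,S,Phospho"))
print(label_from_title("PEPTIDEK 3,M,Oxidation 5,S,Phospho"))  # made-up title
```

The second title contains two keywords and is labeled as oxidation only, which is worth keeping in mind if the dataset contains multiply modified peptides.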

---

### 📤 Output Format

Returns:
```python
[0, 2, 1, 4, 3, ...]
```



In [None]:
def spectrum_label(spectra_data) -> list:
    """
    Assigns labels to spectra based on known modifications in TITLE.

    Parameters:
    - spectra_data (list of dict): List of spectrum dictionaries (from DatasetHandler).

    Returns:
    - List of labels for each spectrum.
    """
    if not isinstance(spectra_data, list):
        print("ERROR: Expected a list of spectra, got", type(spectra_data))
        return None

    labels = []

    for spectrum in spectra_data:
        if not isinstance(spectrum, dict):
            print(f"WARNING: Expected spectrum to be a dict, got {type(spectrum)}")
            labels.append(0)
            continue

        # Get spectrum title and verify it's a non-empty string
        spectrum_id = spectrum.get("title", "")
        if not isinstance(spectrum_id, str) or not spectrum_id.strip():
            print(f"WARNING: Missing or invalid title for spectrum, assigning label 0")
            labels.append(0)
            continue

        spectrum_id = spectrum_id.lower().strip()  # Normalize for label detection

        # Assign labels based on keywords in TITLE
        if "oxidation" in spectrum_id:
            labels.append(1)
        elif "phospho" in spectrum_id:
            labels.append(2)
        elif "k_gg" in spectrum_id:
            labels.append(3)
        elif "k_ac" in spectrum_id:
            labels.append(4)
        else:
            labels.append(0)

    print(f"Labels: {Counter(labels)}")
    return labels


---

# 🧠 Hybrid CNN-Transformer Classification Model

This module defines the **final architecture** used for **multi-class PTM classification** from MS/MS spectra.
The model integrates **local pattern extraction (CNN)**, **global context modeling (Transformer)**, and **metadata (parent ion mass)** into a single classification head.

---

## 🔹 `PositionalEncoding`

Implements **sinusoidal positional encodings** (Vaswani et al., 2017), injecting sequence order information into embeddings.

* **Signature:**

  ```python
  PositionalEncoding(d_model: int = 64, seq_len: int = 175, dropout: float = 0.1)
  ```
* **Behavior:** Precomputes a `[1, seq_len, d_model]` tensor of `sin`/`cos` terms and adds it to input, followed by dropout.
* **Input:** `[B, L, d_model]` with `L ≤ seq_len`.
* **Output:** Same shape as input.

**Example**

```python
pe = PositionalEncoding(d_model=64, seq_len=1000, dropout=0.1)
x = torch.randn(32, 10, 64)   # [batch, seq_len, d_model]
x = pe(x)                     # same shape
```
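
For reference, the encoding values themselves can be computed without PyTorch. The sketch below follows the same sin/cos interleaving as the class (even indices use `sin`, odd indices use `cos`); `sinusoidal_encoding` is an illustrative helper and assumes an even `d_model`.

```python
import math

def sinusoidal_encoding(position, d_model):
    """Return the d_model-long encoding for one position (d_model assumed even)."""
    row = []
    for i in range(0, d_model, 2):
        angle = position * math.exp(i * (-math.log(10000.0) / d_model))
        row.extend([math.sin(angle), math.cos(angle)])
    return row

print(sinusoidal_encoding(0, 8))
print([round(v, 4) for v in sinusoidal_encoding(1, 8)])
```

At position 0 every `sin` term is 0 and every `cos` term is 1, so a sequence of length one always receives the same constant offset.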

---

## 🔹 `EncoderTransformerClassifier`

A **hybrid classifier** with five main blocks:

1. **1D CNN Encoder** – Extracts local spectral patterns

   ```
   Conv1d(1→32, k=5, pad=2) → BN → ReLU
   MaxPool1d(k=2)            # halves length
   Conv1d(32→64, k=3, pad=1) → BN → ReLU
   Flatten
   ```

   * **Output:** `[B, 64 * (input_size // 2)]`

2. **Linear Encoder** – Projects CNN features into Transformer latent space

   ```
   Linear(64*(S/2) → 512) → BN → ReLU
   Linear(512 → latent_size) → BN → ReLU → Dropout
   ```

   * **Output:** `[B, latent_size]`

3. **Positional Encoding + Transformer** – Global context

   * Expand to sequence: `[B, 1, latent_size]`
   * Add sinusoidal encoding
   * Pass through `nn.TransformerEncoder` (`num_layers`, `num_heads`, `dim_feedforward=4*latent_size`)
   * Mean over sequence dim → `[B, latent_size]`

4. **Parent Ion Processor** – Encodes normalized parent mass

   ```
   Linear(1 → 64) → ReLU
   Linear(64 → latent_size) → ReLU
   ```

   * **Output:** `[B, latent_size]`

5. **Fusion & Classification**

   ```
   concat([spectrum, parent]) → [B, 2*latent_size]
   Dropout
   Linear(2*latent_size → num_classes)
   ```

   * **Output:** logits `[B, num_classes]`
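
The shape arithmetic of the CNN encoder can be checked by hand. The sketch below applies the standard output-length formula for 1D convolutions and floor halving for pooling; the helper names are illustrative, and the parameter count covers only the first `Linear` layer (weights plus bias).

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1):
    """Standard Conv1d output length (dilation 1)."""
    return (length + 2 * padding - kernel_size) // stride + 1

def flattened_cnn_features(input_size):
    """Number of features after Flatten for the CNN encoder described above."""
    length = conv1d_out_len(input_size, kernel_size=5, padding=2)  # length kept
    length = length // 2                                            # MaxPool1d(2) floors
    length = conv1d_out_len(length, kernel_size=3, padding=1)      # length kept
    return 64 * length

for size in (175, 4500):
    feats = flattened_cnn_features(size)
    print(size, feats, "first Linear params:", feats * 512 + 512)
```

With thousands of bins, the first `Linear` layer dominates the parameter count, which is worth considering when choosing `num_bins`.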

---

### ✅ Forward Pass

**Inputs**

* `spectra`: `[B, S]` (dense binned spectrum, length = `input_size`)
* `parent_ion`: `[B]` (normalized precursor mass)

**Output**

* `logits`: `[B, num_classes]`

---

### 🔧 Implementation Notes

* `latent_size % num_heads == 0` is enforced.
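* `input_size` may be odd: `MaxPool1d(kernel_size=2)` floors the length to `input_size // 2` (dropping the last position), which matches the input width of the first `Linear` layer. The example below uses `input_size=175`.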
* The Transformer currently sees only one token per spectrum (global embedding). To use **true attention over multiple tokens**, pass a sequence (e.g., CNN feature map before flattening).

---

### 🧪 Example

```python
model = EncoderTransformerClassifier(
    input_size=175, latent_size=64, num_classes=5,
    num_heads=4, num_layers=2, dropout_prob=0.1
)

spectra = torch.randn(32, 175)   # [batch, S]
parent  = torch.rand(32)         # [batch], normalized
logits  = model((spectra, parent))  # [32, 5]
```

**Loss**

* Multi-class PTM classification:

  ```python
  loss_fn = nn.CrossEntropyLoss()
  loss = loss_fn(logits, labels)   # labels ∈ [0..num_classes-1]
  ```

In [None]:


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int = 64, seq_len: int = 175, dropout: float = 0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # shape: [1, seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)


class EncoderTransformerClassifier(nn.Module):
    def __init__(self, input_size, latent_size, num_classes, num_heads, num_layers, dropout_prob, max_len=1000):
        super(EncoderTransformerClassifier, self).__init__()
        self.input_size = input_size
        self.latent_size = latent_size
        self.num_classes = num_classes

        # Validate divisibility
        if latent_size % num_heads != 0:
            raise ValueError(f"latent_size ({latent_size}) must be divisible by num_heads ({num_heads}).")

        # 1. CNN Encoder (New Layer)
        self.cnn_encoder = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),  # Downsample
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Flatten()
        )

        # 2. Linear Encoder (Refactored)
        self.encoder = nn.Sequential(
            nn.Linear(64 * (input_size // 2), 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, latent_size),
            nn.BatchNorm1d(latent_size),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )

        #3. Positional Encoding
        self.positional_encoding = PositionalEncoding(d_model=latent_size, seq_len=max_len, dropout=dropout_prob)

        # 4. Transformer Encoder
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=latent_size,
                nhead=num_heads,
                dim_feedforward=latent_size * 4,
                dropout=dropout_prob,
                activation='relu',
                batch_first=True,
                norm_first=False
            ),
            num_layers=num_layers
        )

        # Parent Ion Layer
        self.parent_ion_layer = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, self.latent_size),
            nn.ReLU()
        )

        # Dropout before classification
        self.dropout = nn.Dropout(dropout_prob)


        self.classifier_head = nn.Linear(latent_size * 2, num_classes)


    def forward(self, inputs):
        spectra, parent_ion = inputs
        parent_ion = parent_ion.unsqueeze(1)

        # CNN Encoder
        spectra = spectra.unsqueeze(1)  # Ensure input is [B, 1, S]
        cnn_output = self.cnn_encoder(spectra)

        # Linear Encoder
        x = self.encoder(cnn_output)

        # Positional Encoding and Transformer
        x = x.unsqueeze(1)  # Adding sequence dimension
        x = self.positional_encoding(x)
        x = self.transformer_encoder(x)
        x = x.mean(dim=1)

        # Parent Ion Encoding
        parent = self.parent_ion_layer(parent_ion).squeeze(1)

        # Concatenate
        combined = torch.cat([x, parent], dim=1)
        combined = self.dropout(combined)

        logits = self.classifier_head(combined)  # shape: [batch, num_classes]
        return logits




# Stratified Group 10-Fold Cross-Validation (Streaming, Low-Memory, 5 Classes)

This section documents how K-fold cross-validation is implemented for the 5-class MS/MS classifier while avoiding leakage and keeping RAM usage low.

## Goals

* **No leakage:** all spectra from the same `.mgf` file appear **only** in train *or* validation within a fold.
* **Stratified balance:** approximate class proportions are preserved across folds (when supported).
* **Low memory:** data are streamed file-by-file; features are not stored globally.
* **Reproducible & auditable:** split file lists and per-fold metrics are persisted.

---

## Configuration (as used here)

* **Classes:** `num_classes = 5`
* **Dataset:** `input_dir = "/content/drive/MyDrive/data/4_mod_balanced_dataset"`
* **Artifacts:** `model_weights_dir = "/content/drive/MyDrive/5_class_k_fold/model_weights"`
* **Bins:** `num_bins = 4500` (1D spectrum vector size)
* **Precursor normalization:** `pepmass_range = {min: 500.0, max: 6000.0}`, window size `200.0`
* **Epochs per mini-batch (one file):** `epoch = 100`
* **Folds:** `k = 10`

---

## Pipeline Overview

### 1) Indexing pass (labels only)

We scan the dataset **once** with `DatasetHandler(input_dir, num_loops=1)` to build:

* `y_labels`: one integer label per spectrum via `spectrum_label(...)`.
* `groups`: a **file id** per spectrum (same id for all spectra from the same `.mgf`).
* `file_paths_all`: ordered list of files that yielded valid spectra.

> This pass does **not** keep feature vectors in memory; it only records labels and grouping.
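
The indexing pass can be sketched in plain Python as follows. This is a minimal illustration, not the notebook's implementation: `index_labels` and the `(labels, path)` iterator are hypothetical stand-ins for `DatasetHandler.get_next_file()` combined with `spectrum_label(...)`, and the file names are made up.

```python
from collections import Counter


def index_labels(file_iter):
    """Collect per-spectrum labels and file-level group ids; no features are kept.

    `file_iter` yields (labels_for_file, file_path) pairs (hypothetical interface).
    """
    y_labels, groups, file_paths = [], [], []
    file_id_map = {}  # path -> integer file id
    for labels, path in file_iter:
        if not labels:
            continue  # skip files without valid spectra
        # A new path receives the next free id; a known path keeps its id.
        file_id = file_id_map.setdefault(path, len(file_id_map))
        if file_id == len(file_paths):
            file_paths.append(path)
        y_labels.extend(labels)
        groups.extend([file_id] * len(labels))
    return y_labels, groups, file_paths


# Toy input: two files with spectra and one empty file (hypothetical names)
toy_files = [([0, 0, 1], "a.mgf"), ([], "empty.mgf"), ([2, 3], "b.mgf")]
y_labels, groups, file_paths = index_labels(toy_files)

print(y_labels)    # [0, 0, 1, 2, 3]
print(groups)      # [0, 0, 0, 1, 1]
print(file_paths)  # ['a.mgf', 'b.mgf']
print(Counter(y_labels))
```

Only two flat lists of integers grow with the dataset size, which is what keeps this pass cheap in memory.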

### 2) Fold construction (grouped, stratified)

* Primary splitter: **`StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)`**.
* Fallback (if not available): **`GroupKFold(n_splits=10)`** (grouping without stratification).
* The splitter operates on `(X = placeholder, y = y_labels, groups = groups)`.
* For each fold, indices are mapped back to **file paths** and written to:

  * `kfold_splits/fold{n}_train_files.txt`
  * `kfold_splits/fold{n}_val_files.txt`

This guarantees **file-level grouping** and, when available, **label stratification**.
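
One way to audit the file-level guarantee is to reload a fold's persisted lists and confirm that no path appears on both sides. The sketch below assumes one path per line, as in the split files described above; `read_file_list` and `overlapping_files` are hypothetical helper names, and the demo writes its own toy lists to a temporary directory.

```python
import os
import tempfile


def read_file_list(path):
    """Read one file path per line, ignoring blank lines."""
    with open(path) as fh:
        return [line.strip() for line in fh if line.strip()]


def overlapping_files(train_files, val_files):
    """Return files present in both lists; an empty set means no leakage."""
    return set(train_files) & set(val_files)


with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "fold1_train_files.txt")
    val_path = os.path.join(tmp, "fold1_val_files.txt")
    # Toy split lists (hypothetical file names)
    with open(train_path, "w") as fh:
        fh.write("\n".join(["run_01.mgf", "run_02.mgf"]))
    with open(val_path, "w") as fh:
        fh.write("\n".join(["run_03.mgf"]))
    overlap = overlapping_files(read_file_list(train_path), read_file_list(val_path))

print(overlap)  # set()
```

A non-empty result would indicate that the same `.mgf` file feeds both training and validation in that fold.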

### 3) Streaming handlers per split

We construct lightweight handlers to constrain iteration:

* `make_handler_including(only_files)` yields **only** those files.
* (A complementary `make_handler_excluding(...)` exists but isn’t required here.)
  Non-target files are marked as “already used,” so the iterator never returns them.
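
The "already used" trick can be illustrated with a toy iterator. `TinyHandler`, `make_including`, and `MAX_PASSES` are hypothetical stand-ins for the notebook's `DatasetHandler`, `make_handler_including`, and `MAX_FILE_PASSES`; the real handler also parses spectra, which is omitted here.

```python
MAX_PASSES = 1  # hypothetical stand-in for MAX_FILE_PASSES


class TinyHandler:
    """Toy iterator: serves each file until its usage count reaches MAX_PASSES."""

    def __init__(self, files):
        self.file_usage_counter = {f: 0 for f in files}

    def get_next_file(self):
        for f, used in self.file_usage_counter.items():
            if used < MAX_PASSES:
                self.file_usage_counter[f] += 1
                return f
        return None  # every file is exhausted


def make_including(all_files, only_files):
    """Build a handler that only ever returns files listed in `only_files`."""
    handler = TinyHandler(all_files)
    keep = set(only_files)
    for f in handler.file_usage_counter:
        if f not in keep:
            # Pretend the file was already consumed so it is never served.
            handler.file_usage_counter[f] = MAX_PASSES
    return handler


handler = make_including(["a.mgf", "b.mgf", "c.mgf"], only_files=["a.mgf", "c.mgf"])
served = []
while True:
    f = handler.get_next_file()
    if f is None:
        break
    served.append(f)

print(served)  # ['a.mgf', 'c.mgf']
```

The appeal of this design is that the iterator itself needs no filtering logic; restricting a split is just a matter of pre-filling the usage counter.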

### 4) Per-fold training (streamed by file)

For each fold `1..10`:

1. **Fresh model init** with the configured architecture/hyperparameters (`num_classes = 5`).
2. **Training loop (train files only):**

   * Iterate over training files with the handler.
   * For each file (a mini-batch):

     * Build features via `combine_features(...)` (binned spectrum vector + normalized parent ion).
     * Create tensors `(spectra_t, parent_t)` and `labels_t`.
     * Train using
       `train_classifier_with_weights(model, (spectra_t, parent_t), labels_t, epochs=epoch, ...)`.
   * Optionally save an intermediate `fold{n}_latest_model.pth`.

This keeps memory bounded by **one file** at a time.

### 5) Per-fold evaluation (validation files only)

With `model.eval()` and `torch.no_grad()`:

* Stream validation files, run the model, and collect predictions/targets across all spectra in the fold.
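
The aggregation step amounts to taking the arg-max of each logit row and extending two running lists, file by file. The sketch below uses plain Python lists instead of tensors; the logits are rounded from the sample logits printed in this notebook's training log, and their grouping into two "files" is an assumption made purely for illustration.

```python
def argmax(row):
    """Index of the largest value in a list of logits."""
    return max(range(len(row)), key=row.__getitem__)


# (logits, targets) per validation "file"; the file grouping is illustrative
file_batches = [
    (
        [[6.56, -0.31, -1.45, -4.19, -7.05],
         [-0.39, 0.27, 5.87, -3.30, -4.90]],
        [0, 2],
    ),
    (
        [[6.65, -0.40, -3.52, -3.69, -5.51],
         [-3.09, -0.17, 7.92, -4.26, -1.65],
         [-0.14, -1.13, -2.53, 9.97, -5.51]],
        [0, 2, 3],
    ),
]

all_preds, all_targets = [], []
for logits, targets in file_batches:
    all_preds.extend(argmax(row) for row in logits)
    all_targets.extend(targets)

# Metrics are computed once, over all spectra in the fold
accuracy = sum(p == t for p, t in zip(all_preds, all_targets)) / len(all_targets)

print(all_preds)  # [0, 2, 0, 2, 3]
print(accuracy)   # 1.0
```

Computing metrics on the concatenated lists, rather than averaging per-file metrics, weights every spectrum equally regardless of file size.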

### 6) Metrics and composite score

For each fold we compute:

* **Accuracy**
* **Macro:** Precision, Recall, F1
* **Weighted:** Precision, Recall, F1
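
To make the difference between macro and weighted averaging concrete, here is a small pure-Python sketch on made-up labels. The notebook itself uses scikit-learn's `f1_score`; the helper names below are hypothetical.

```python
from collections import Counter


def per_class_f1(y_true, y_pred, labels):
    """F1 per class from true-positive, false-positive and false-negative counts."""
    scores = {}
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores[c] = 2 * tp / denom if denom else 0.0
    return scores


def macro_and_weighted_f1(y_true, y_pred):
    labels = sorted(set(y_true) | set(y_pred))
    f1 = per_class_f1(y_true, y_pred, labels)
    support = Counter(y_true)  # number of true samples per class
    macro = sum(f1.values()) / len(labels)                             # classes count equally
    weighted = sum(f1[c] * support[c] for c in labels) / len(y_true)   # classes count by support
    return macro, weighted


# Imbalanced toy example: class 0 has 4 samples, class 1 has 2
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 1, 0]
macro, weighted = macro_and_weighted_f1(y_true, y_pred)

print(round(macro, 4), round(weighted, 4))  # 0.7778 0.8148
```

The minority class is the one with the lower F1 here, so the macro average drops below the weighted one. That sensitivity to minority classes is the usual reason to emphasize macro metrics.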

The **composite score** used in this script is:
$$
\text{Score} = 0.45\cdot \text{MacroF1} + 0.30\cdot \text{MacroRecall} + 0.15\cdot \text{WeightedF1} + 0.10\cdot \text{Accuracy}
$$

We persist the final fold weights as `fold{n}_final.pth`.
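
As a quick check of the composite score in this section, the weights can be applied to the fold 1 metrics reported in this notebook's results table. The function name `composite_score` is hypothetical; the weights are the ones stated above.

```python
def composite_score(macro_f1, macro_recall, weighted_f1, accuracy):
    """Weighted combination of fold metrics; the weights sum to 1."""
    return (
        0.45 * macro_f1
        + 0.30 * macro_recall
        + 0.15 * weighted_f1
        + 0.10 * accuracy
    )


# Fold 1 metrics as reported in the results table
score = composite_score(
    macro_f1=0.883234,
    macro_recall=0.889836,
    weighted_f1=0.855944,
    accuracy=0.857119,
)

print(round(score, 5))  # 0.87851
```

Three quarters of the weight sits on macro metrics, so a fold that neglects a small class is penalized even if its overall accuracy is high.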

### 7) Logging & artifacts

* After **each fold**, append one row to `cv10_results_streaming.csv` (append-only inside the loop).
* After all folds, aggregate results are written again to `cv10_results_streaming.csv` (full table).
* Console logs summarize the number of files trained and per-fold metrics.
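
The append-per-fold pattern can be sketched with the standard `csv` module. The notebook uses `pandas.DataFrame.to_csv(mode='a')`; the helper `append_fold_row`, the file name, and the metric values below are hypothetical and only illustrate the "header once, one row per fold" idea.

```python
import csv
import os
import tempfile


def append_fold_row(csv_path, row):
    """Append one metrics row; write the header only when the file is new."""
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)


with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "cv_results.csv")
    # Illustrative values, one call per completed fold
    append_fold_row(csv_path, {"fold": 1, "score": 0.8785})
    append_fold_row(csv_path, {"fold": 2, "score": 0.8709})
    with open(csv_path, newline="") as fh:
        rows = list(csv.DictReader(fh))

print(len(rows))        # 2
print(rows[0]["fold"])  # 1
```

Because each row reaches disk as soon as its fold finishes, results from completed folds survive a Colab disconnect partway through the run.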

---

## Why this setup

* **Leakage-safe:** grouping by file ensures that spectra from the same acquisition never appear in both train and validation within a fold.
* **Balanced folds:** stratification (when available) stabilizes class distributions across folds for the **5-class** setting.
* **Scales to large data:** training and evaluation are performed **file-by-file**, avoiding large in-memory datasets.
* **Deterministic & traceable:** fixed `random_state`, persisted split files, and CSV metrics for auditability.


In [None]:
# Add a flag to control if the model should be loaded before starting the loop

#Tuned Hyperparameters (via Optuna)
latent_size = 64
dropout_prob = 0.22162835150922375
learning_rate = 0.00013235255068305934
num_heads = 4
num_layers = 6
l1_lambda = 2.3139046200726137e-07
num_bins = 4500 #Number of bins of the 1D vector for the model

# Fixed architectural parameters
num_classes = 5  # Number of classes (unmodified + 4 modifications); adapt to your own dataset

# Preprocessing configuration
pepmass_range = {'min': 500.0, 'max': 6000.0}
# Fixed window for the normalization of the observed parent ion
window_normaliation_size = 200.00  # width of the m/z window within which intensities are normalized

# Runtime & training control
epoch = 100
num_loops = 1  # number of passes over the dataset
min_score_threshold = 0.90  # minimum score a model must reach for its weights to be saved
input_dir = "/content/drive/MyDrive/data/4_mod_balanced_dataset"
model_weights_dir = "/content/drive/MyDrive/5_class_k_fold/model_weights"
assume_observed = True
load_latest_model_at_start = True  # whether to load the latest saved model at start; note: loading is currently unreliable


# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))



# Model Initialization
model = EncoderTransformerClassifier(
    latent_size=latent_size,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout_prob=dropout_prob,
    input_size=num_bins,
    num_classes=num_classes
).to(device)



# ================================
# 📦 Stratified Group 10-Fold (streaming by files, low memory)
# ================================
import os
import numpy as np
import pandas as pd
from collections import Counter

# Try to use StratifiedGroupKFold; fall back to GroupKFold on older scikit-learn
try:
    from sklearn.model_selection import StratifiedGroupKFold as SGKF
except ImportError:
    SGKF = None
    print("⚠️ sklearn too old for StratifiedGroupKFold. Falling back to GroupKFold (no stratification).")

k = 10
splits_dir = os.path.join(model_weights_dir, "kfold_splits")
os.makedirs(splits_dir, exist_ok=True)

# ---------- 1) Index pass: collect labels per spectrum + file grouping (no features kept) ----------
print("🔎 Indexing dataset to build stratified grouped folds (labels only)...")
y_labels = []      # per-spectrum integer label
groups = []        # per-spectrum group id (file id)
file_paths_all = []  # ordered list of files that actually had valid spectra
file_id_map = {}     # path -> id

# We'll walk files exactly once using your DatasetHandler
index_handler = DatasetHandler(input_dir=input_dir, num_loops=1)
fid_counter = 0
while True:
    res = index_handler.get_next_file()
    if res is None:
        break
    spectra_batch, fpath = res

    # Labels for this file (per-spectrum)
    labels = spectrum_label(spectra_batch)
    if len(labels) == 0:
        continue  # skip empty file
    # remember file -> id
    if fpath not in file_id_map:
        file_id_map[fpath] = fid_counter
        file_paths_all.append(fpath)
        fid_counter += 1

    file_id = file_id_map[fpath]
    y_labels.extend(labels)
    groups.extend([file_id] * len(labels))

y_labels = np.array(y_labels, dtype=np.int64)
groups   = np.array(groups, dtype=np.int64)

print(f"📊 Indexed {len(file_paths_all)} files, {len(y_labels)} spectra total.")
print("Class distribution:", Counter(y_labels))

if len(y_labels) == 0:
    raise RuntimeError("No spectra found to build folds. Check input_dir and parsing.")

# ---------- 2) Build 10 folds (grouped by file, stratified by labels) ----------
if SGKF is not None:
    skf = SGKF(n_splits=k, shuffle=True, random_state=42)
    split_iter = skf.split(X=np.zeros_like(y_labels), y=y_labels, groups=groups)
else:
    # fallback (no stratification, only grouping)
    from sklearn.model_selection import GroupKFold
    split_iter = GroupKFold(n_splits=k).split(X=np.zeros_like(y_labels), y=y_labels, groups=groups)

fold_file_lists = []  # [(train_files, val_files), ...]
for fold_idx, (train_idx, val_idx) in enumerate(split_iter, start=1):
    # Map indices -> file ids present in each side
    train_file_ids = set(groups[train_idx])
    val_file_ids   = set(groups[val_idx])

    train_files = [fp for fp, fid in file_id_map.items() if fid in train_file_ids]
    val_files   = [fp for fp, fid in file_id_map.items() if fid in val_file_ids]
    fold_file_lists.append((train_files, val_files))

    # Save lists to disk (so you can reuse later without recomputing)
    with open(os.path.join(splits_dir, f"fold{fold_idx}_train_files.txt"), "w") as f:
        f.write("\n".join(train_files))
    with open(os.path.join(splits_dir, f"fold{fold_idx}_val_files.txt"), "w") as f:
        f.write("\n".join(val_files))

    print(f"🗂️ Fold {fold_idx}: train_files={len(train_files)}, val_files={len(val_files)}")

# ---------- 3) Helpers to run handlers restricted to a file list ----------
def make_handler_excluding(exclude_files):
    h = DatasetHandler(input_dir=input_dir, num_loops=1)
    # mark excluded files as already "used" so handler never returns them
    for f in exclude_files:
        if f in h.file_usage_counter:
            h.file_usage_counter[f] = MAX_FILE_PASSES
    return h

def make_handler_including(only_files):
    # We exclude everything NOT in only_files
    h = DatasetHandler(input_dir=input_dir, num_loops=1)
    only = set(only_files)
    for f in list(h.file_usage_counter.keys()):
        if f not in only:
            h.file_usage_counter[f] = MAX_FILE_PASSES
    return h

# ---------- 4) Run K folds (streaming train/eval by files) ----------
from sklearn.metrics import precision_score, recall_score, f1_score

fold_rows = []

for fold, (train_files, val_files) in enumerate(fold_file_lists, start=1):
    print(f"\n====================\n🚀 Fold {fold}/{k}\n====================")

    # Fresh model per fold
    model = EncoderTransformerClassifier(
        latent_size=latent_size,
        num_heads=num_heads,
        num_layers=num_layers,
        dropout_prob=dropout_prob,
        input_size=num_bins,
        num_classes=num_classes
    ).to(device)

    # --- Training: iterate ONLY over training files ---
    train_handler = make_handler_including(train_files)

    # Train per-file using your existing single-batch trainer (keeps your steps intact)
    # This re-creates optimizer inside per call (same as your original per-batch training loop).
    trained_files = 0
    while True:
        res = train_handler.get_next_file()
        if res is None:
            break
        spectra_batch, batch_file = res

        feature_batch = combine_features(
            spectra_batch, pepmass_range, num_bins,
            window_normaliation_size, assume_observed=assume_observed
        )
        if not feature_batch:
            continue

        spectra_vecs, parent_ions = zip(*feature_batch)
        spectra_t = torch.tensor(np.array(spectra_vecs), dtype=torch.float32).to(device)
        parent_t  = torch.tensor(np.array(parent_ions), dtype=torch.float32).to(device)
        labels_t  = torch.tensor(spectrum_label(spectra_batch), dtype=torch.long).to(device)

        # Train on this file (your function)
        save_path = os.path.join(model_weights_dir, f"fold{fold}_latest_model.pth")
        train_classifier_with_weights(
            model, (spectra_t, parent_t),
            labels_t,
            epochs=epoch,
            learning_rate=learning_rate,
            l1_lambda=l1_lambda,
            save_path=save_path,
            device=device
        )
        trained_files += 1

    print(f"✅ Trained on {trained_files} files in fold {fold}.")

    # --- Evaluation: iterate ONLY over validation files; aggregate predictions/labels across files ---
    val_handler = make_handler_including(val_files)

    all_val_preds = []
    all_val_targets = []

    model.eval()
    with torch.no_grad():
        while True:
            res = val_handler.get_next_file()
            if res is None:
                break
            spectra_batch, batch_file = res

            feature_batch = combine_features(
                spectra_batch, pepmass_range, num_bins,
                window_normaliation_size, assume_observed=assume_observed
            )
            if not feature_batch:
                continue

            spectra_vecs, parent_ions = zip(*feature_batch)
            spectra_t = torch.tensor(np.array(spectra_vecs), dtype=torch.float32).to(device)
            parent_t  = torch.tensor(np.array(parent_ions), dtype=torch.float32).to(device)
            labels_t  = torch.tensor(spectrum_label(spectra_batch), dtype=torch.long).to(device)

            outputs = model((spectra_t, parent_t.unsqueeze(1)))
            preds = torch.argmax(outputs, dim=1)

            all_val_preds.extend(preds.cpu().numpy().tolist())
            all_val_targets.extend(labels_t.cpu().numpy().tolist())

    if len(all_val_targets) == 0:
        print("⚠️ No validation spectra in this fold. Skipping metric computation.")
        continue

    y_true = np.array(all_val_targets)
    y_pred = np.array(all_val_preds)

    macro_precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    macro_recall    = recall_score(y_true, y_pred, average='macro', zero_division=0)
    macro_f1        = f1_score(y_true, y_pred, average='macro', zero_division=0)
    weighted_prec   = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    weighted_rec    = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    weighted_f1     = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    accuracy        = (y_true == y_pred).mean()

    # Same composite score as your evaluate_model
    fold_score = (
        0.45 * macro_f1 +
        0.30 * macro_recall +
        0.15 * weighted_f1 +
        0.10 * accuracy
    )

    # Save fold model
    fold_model_path = os.path.join(model_weights_dir, f"fold{fold}_final.pth")
    torch.save(model.state_dict(), fold_model_path)

    print(
        f"📈 Fold {fold} — Score: {fold_score:.4f} | Acc: {accuracy:.4f} | "
        f"Macro F1: {macro_f1:.4f} | Macro Rec: {macro_recall:.4f} | Weighted F1: {weighted_f1:.4f}"
    )

    fold_rows.append({
        "fold": fold,
        "n_val_spectra": len(y_true),
        "accuracy": accuracy,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "weighted_precision": weighted_prec,
        "weighted_recall": weighted_rec,
        "weighted_f1": weighted_f1,
        "score": fold_score,
        "fold_model_path": fold_model_path
    })

    # ✅ write/update CSV after each fold (INSIDE the loop)
    cv_csv = os.path.join(model_weights_dir, "cv10_results_streaming.csv")
    write_header = not os.path.exists(cv_csv)  # header only when the file is new
    pd.DataFrame([fold_rows[-1]]).to_csv(
        cv_csv, mode='a', header=write_header, index=False
    )

# ---------- 5) Save metrics ----------
cv_df = pd.DataFrame(fold_rows)
cv_csv = os.path.join(model_weights_dir, "cv10_results_streaming.csv")
cv_df.to_csv(cv_csv, index=False)
print("\n✅ Cross-validation (streamed) complete.")
print(f"📄 Metrics saved: {cv_csv}")
try:
    display(cv_df)  # rich table in Jupyter/Colab
except NameError:
    print(cv_df)    # fallback outside a notebook





Streaming output truncated to the last 5000 lines.
Epoch [43/100] - Loss: 0.0390, Accuracy: 100.00%
Epoch [44/100] - Loss: 0.0391, Accuracy: 100.00%
Epoch [45/100] - Loss: 0.0391, Accuracy: 100.00%
Epoch [46/100] - Loss: 0.0389, Accuracy: 100.00%
Epoch [47/100] - Loss: 0.0392, Accuracy: 100.00%
Epoch [48/100] - Loss: 0.0387, Accuracy: 100.00%
Epoch [49/100] - Loss: 0.0388, Accuracy: 100.00%
Epoch [50/100] - Loss: 0.0387, Accuracy: 100.00%
Sample Predictions: [0 2 0 2 3]
Actual Labels: tensor([0, 2, 0, 2, 3], device='cuda:0')
Sample Logits: [[ 6.5611796  -0.30567443 -1.4549643  -4.187828   -7.046267  ]
 [-0.3864787   0.26613998  5.866487   -3.3009741  -4.899184  ]
 [ 6.650451   -0.39826846 -3.5159068  -3.688222   -5.505843  ]
 [-3.0933795  -0.17463483  7.9197755  -4.2563753  -1.6504325 ]
 [-0.13915217 -1.1257565  -2.5317287   9.971585   -5.506624  ]]
Epoch [51/100] - Loss: 0.0385, Accuracy: 100.00%
Epoch [52/100] - Loss: 0.0386, Accuracy: 100.00%
Epoch [53/100] - L

|   | fold | n_val_spectra | accuracy | macro_precision | macro_recall | macro_f1 | weighted_precision | weighted_recall | weighted_f1 | score | fold_model_path |
|---|------|---------------|----------|-----------------|--------------|----------|--------------------|-----------------|-------------|-------|-----------------|
| 0 | 1 | 42070 | 0.857119 | 0.879917 | 0.889836 | 0.883234 | 0.858788 | 0.857119 | 0.855944 | 0.87851 | /content/drive/MyDrive/k_fold/model_weights/fo... |
| 1 | 2 | 42070 | 0.850725 | 0.894498 | 0.874256 | 0.880036 | 0.861863 | 0.850725 | 0.850318 | 0.870913 | /content/drive/MyDrive/k_fold/model_weights/fo... |
| 2 | 3 | 42070 | 0.862277 | 0.891422 | 0.889072 | 0.890147 | 0.863165 | 0.862277 | 0.862641 | 0.882912 | /content/drive/MyDrive/k_fold/model_weights/fo... |
| 3 | 4 | 42070 | 0.866366 | 0.891752 | 0.893711 | 0.892622 | 0.865671 | 0.866366 | 0.865915 | 0.886317 | /content/drive/MyDrive/k_fold/model_weights/fo... |
| 4 | 5 | 42070 | 0.862539 | 0.890442 | 0.8896 | 0.889158 | 0.864143 | 0.862539 | 0.862083 | 0.882568 | /content/drive/MyDrive/k_fold/model_weights/fo... |



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.





