# MassSpecGym Retrieval Task - Abril Risso

## 1. Introduction

#### 1.1 Problem Overview

Mass spectrometry (MS) is a key analytical technique in metabolomics and small-molecule analysis. In tandem mass spectrometry (MS/MS), a molecule is ionized and fragmented, and the instrument measures a spectrum of fragment peaks, each defined by a mass-to-charge ratio (m/z) and an intensity. The main computational challenge is **molecule identification**: given an observed MS/MS spectrum, determine which molecule produced it.

This project focuses on the MassSpecGym benchmark. The idea is to train models that translate spectra into molecular representations, so the right molecule can be quickly retrieved and ranked among realistic candidate structures.

#### 1.2 Molecule Retrieval - The Task

- **Input:** An MS/MS spectrum ($m/z$ and intensity values) and the precursor mass.
- **Output:** A ranked list of candidate molecules from a database.
- **Goal:** Ensure the ground-truth molecule is ranked as high as possible (Top-1, Top-5, etc.).

#### 1.3 Baseline Approach - GitHub Model

The official starter code provides a simple yet strong baseline based on DeepSets. The model treats an MS/MS spectrum as a set of peaks in which the order does not matter, so only the collective content of the set is used. The architecture consists of two neural networks, $\phi$ and $\rho$, connected by a pooling step:
1. Peak Encoding ($\phi$): An MLP processes each peak $(m/z, I)$ independently to generate a latent feature vector.
2. Permutation-Invariant Pooling: The individual peak representations are aggregated with a sum, producing a fixed-size representation of the spectrum that does not depend on the order of the input peaks.
3. Fingerprint Projection ($\rho$): A second MLP projects this global vector into the final molecular fingerprint space.

The model is trained using a fingerprint reconstruction loss (MSE) comparing the predicted fingerprint to the ground-truth fingerprint.

At evaluation, retrieval is performed by scoring the candidate molecules using cosine similarity between the predicted fingerprint and each candidate fingerprint. The candidates are ranked by similarity and the performance is reported with Hit@k metrics (Hit@1, Hit@5, ...).
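The scoring and Hit@k computation described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the MassSpecGym implementation: it scores a single spectrum against a small hypothetical candidate list with toy 4-bit fingerprints (the fingerprints in this project have 4096 bits), and the helper names `cosine`, `rank_candidates` and `hit_at_k` are illustrative. Averaging `hit_at_k` over all test spectra gives the reported Hit@k.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_candidates(pred_fp, candidate_fps):
    """Return candidate indices sorted from most to least similar."""
    scores = [cosine(pred_fp, c) for c in candidate_fps]
    return sorted(range(len(candidate_fps)), key=lambda i: scores[i], reverse=True)


def hit_at_k(ranking, true_index, k):
    """1.0 if the ground-truth candidate is among the top-k, else 0.0."""
    return 1.0 if true_index in ranking[:k] else 0.0


# Toy example: a predicted fingerprint (e.g., sigmoid outputs) and three
# binary candidate fingerprints; candidate 2 is the ground truth.
pred = [0.9, 0.1, 0.8, 0.2]
candidates = [
    [1, 1, 0, 0],
    [0, 1, 0, 1],
    [1, 0, 1, 0],
]
ranking = rank_candidates(pred, candidates)
print(ranking)                  # [2, 0, 1]
print(hit_at_k(ranking, 2, 1))  # 1.0
```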

#### 1.4 Two Complementary Models

This notebook introduces two improved models that address key limitations of the baseline model and target the two evaluation perspectives in MassSpecGym:
- **Model 1** (Hit@k-optimized): a Spectral Transformer trained with a contrastive ranking loss (InfoNCE), designed to directly optimize retrieval performance.
- **Model 2** (F1-optimized): a Spectral Transformer trained as a multi-label fingerprint classifier (Focal Loss), designed to maximize instance-wise F1 under sparse fingerprints.

## 2. Model 1: Hit@k-optimized

In MassSpecGym, the primary goal of the retrieval task is to rank the correct molecule first among up to 256 candidates per spectrum, and evaluation uses the Hit Rate@k (Hit@1, Hit@5, Hit@20, …). A model therefore succeeds when the true molecule scores higher than the incorrect candidates.

The baseline model (DeepSets) predicts a molecular fingerprint and is trained with MSE, so it learns to reconstruct fingerprint bits rather than to rank the true molecule above the incorrect candidates. Even if fingerprint reconstruction improves, that does not guarantee that the true candidate will be ranked above the many candidates of similar or identical mass, such as isomers.

The DeepSets baseline is a good starting point because it does not depend on the order of the peaks, but it encodes each peak on its own before pooling. In reality, mass spectra have structure: many peaks are connected through common neutral losses and fragmentation pathways.

To address this, the first model replaces the set-based encoder with a Spectral Transformer and replaces MSE with a contrastive ranking loss (InfoNCE). This trains the model directly to place the true molecule in the top-k results (high Hit Rate@k).

#### 2.1 Transformer Architecture

The core of the model is a custom class, **SpectralTransformerEncoder**, which lets the peaks of a spectrum interact instead of encoding each one in isolation. Unlike DeepSets (the baseline model), which processes peaks independently, a Transformer encoder with **Self-Attention** lets every peak attend to every other peak, which is important for learning fragmentation-related structure and isotopic patterns. Because no positional encoding is added, the encoder still treats the spectrum as an unordered set: permuting the input peaks does not change the pooled output.

The architecture is configured with the following hyperparameters to balance model capacity and computational efficiency:
- **Embedding Dimension** ($d_{model} = 256$): Each peak is projected into a 256-dimensional vector, with 128 dimensions for the Fourier encoding of m/z and 128 for the intensity features. This size is large enough to encode both without making the model unnecessarily large or slow to train.

- **Attention Heads** ($n_{head} = 4$): The model uses 4 parallel attention heads of 64 dimensions each, allowing it to attend to different types of peak relationships simultaneously while keeping computation cheap.

- **Encoder Layers** ($N = 2$): A stack of 2 Transformer encoder layers is used. Training currently takes approximately one hour per epoch, so a deeper network would significantly increase the compute cost and slow down experimentation. Also, MS/MS spectra usually do not need the deep hierarchical processing used in language models. A small number of layers is therefore enough to test the impact of the improvements described in the next section without relying mainly on added depth.
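As a rough check on model size, the parameter count of this configuration can be worked out by hand. The sketch below assumes the standard PyTorch `nn.TransformerEncoderLayer` layout (packed Q/K/V input projection, output projection, two feed-forward linear layers with `dim_feedforward = 4 * d_model`, two LayerNorms) and the two-layer output head used in Section 2.3; it ignores the small input projections and the attention-pooling layer. The function names are illustrative.

```python
def transformer_encoder_layer_params(d_model, dim_feedforward):
    """Parameters of one PyTorch-style encoder layer
    (self-attention + 2-layer feed-forward + 2 LayerNorms)."""
    attn = 3 * d_model * d_model + 3 * d_model         # packed Q, K, V projections
    attn += d_model * d_model + d_model                # attention output projection
    ffn = d_model * dim_feedforward + dim_feedforward  # first feed-forward linear
    ffn += dim_feedforward * d_model + d_model         # second feed-forward linear
    norms = 2 * (2 * d_model)                          # two LayerNorms (weight + bias)
    return attn + ffn + norms


def linear_params(n_in, n_out):
    """Weights plus bias of a fully connected layer."""
    return n_in * n_out + n_out


d_model, n_head, n_layers = 256, 4, 2
encoder = n_layers * transformer_encoder_layer_params(d_model, 4 * d_model)
head = linear_params(d_model, d_model) + linear_params(d_model, 4096)
print(d_model // n_head)  # 64
print(encoder)            # 1579520
print(head)               # 1118464
```

Under these assumptions the two encoder layers hold about 1.58M parameters and the output head about 1.12M, so the 4096-way projection is a sizeable share of the model.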

#### 2.2 Architectural Improvements

The new model class (**SpectralTransformerEncoder**) includes four architectural improvements (items 1-4) to better capture the details of MS/MS data, and the model is trained with a new loss (item 5) and optimizer settings (item 6):

1. **High-Frequency Fourier Features**

Standard neural networks struggle to learn high-frequency functions, a phenomenon known as "spectral bias". In MS/MS, tiny mass differences matter a lot (for example, between isotopic peaks or between fragments with the same nominal mass). If the model is fed raw $m/z$ values, it can miss these small but important differences. Instead, each $m/z$ value is encoded with multiple sine and cosine waves, creating a high-dimensional vector in which small mass differences are easier for the model to capture.

$$\gamma(v) = [\sin(2\pi \mathbf{B} v), \cos(2\pi \mathbf{B} v)]$$

The frequencies $\mathbf{B}$ are sampled once from a normal distribution $\mathcal{N}(0, \sigma^2)$ and kept fixed during training. Because the masses are normalized (divided by 1000 in the implementation), a high $\sigma$ is chosen to keep the encoding frequencies high enough to capture small variations after scaling; specifically, $\sigma = 10$.

This technique is reported in the MassSpecGym paper as an improvement to DeepSets (DeepSets + Fourier Features), increasing the Hit Rate@1 from 1.47% to 5.24%.
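A pure-Python sketch of the encoding, mirroring the `FourierFeatures` module in Section 2.3 ($m/z$ divided by 1000, $\mathbf{B} \sim \mathcal{N}(0, \sigma^2)$ with $\sigma = 10$, 64 frequencies giving 128 features). The fixed seed and the helper `shortest_period_da` are illustrative additions: since $\sin(2\pi b \cdot m/1000)$ repeats every $1000/|b|$ Da, the helper reports the finest periodicity present in the encoding, which is a quick way to see how $\sigma$ and the 1000 scaling interact.

```python
import math
import random


def sample_frequencies(num_freqs, sigma, seed=0):
    """Draw the frequency vector B once from N(0, sigma^2); it then stays fixed."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, sigma) for _ in range(num_freqs)]


def fourier_features(mz, freqs, scale=1000.0):
    """gamma(v) = [sin(2*pi*B*v), cos(2*pi*B*v)] with v = mz / scale."""
    v = mz / scale
    proj = [2.0 * math.pi * b * v for b in freqs]
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]


def shortest_period_da(freqs, scale=1000.0):
    """m/z shift (Da) after which the fastest sinusoid repeats: scale / max|b|."""
    return scale / max(abs(b) for b in freqs)


B = sample_frequencies(num_freqs=64, sigma=10.0)  # 64 frequencies -> 128 features
emb = fourier_features(300.0, B)
print(len(emb), shortest_period_da(B))
```

Doubling $\sigma$ halves the shortest period, so $\sigma$ directly controls how finely the encoding oscillates along the mass axis.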


2. **Log-Intensity Injection**

Spectral intensities in mass spectrometry often follow a power-law-like distribution: most peaks are weak and only a few are very strong. However, the small peaks can still be very informative. Therefore, instead of feeding the model the raw intensity $I$, the intensity is compressed with a log transform, which makes small peaks more visible to the model and prevents large peaks from dominating:

$$I' = \log(1 + I)$$

Some intensities can be zero. Since $\log(0)$ is undefined, adding 1 keeps all values finite and preserves the ordering (a higher $I$ still gives a higher $I'$).
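The amount of compression that $\log(1 + I)$ provides depends on the scale of the raw intensities. The sketch below compares the ratio between a strong and a weak peak before and after the transform for two hypothetical intensity scales: raw counts, and intensities normalized so that the strongest peak is 1. Which scale the preprocessing in this notebook produces is an assumption worth checking, since near zero $\log(1 + I) \approx I$ and the transform then compresses much less.

```python
import math


def compression_ratio(i_max, i_min):
    """Ratio between the strongest and weakest (non-zero) intensity,
    before and after the log(1 + I) transform."""
    raw = i_max / i_min
    logged = math.log1p(i_max) / math.log1p(i_min)
    return raw, logged


# Hypothetical raw counts spanning four orders of magnitude.
print(compression_ratio(10_000.0, 1.0))  # (10000.0, ~13.3)
# The same relative spread after normalizing intensities to [0, 1].
print(compression_ratio(1.0, 1e-4))      # (10000.0, ~6932)
```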

3. **Contextual Precursor Injection**

The precursor mass is a key constraint in an MS/MS spectrum: fragment peaks cannot exceed the precursor $m/z$ (for singly charged precursors). Instead of adding the precursor mass only at the end (as simple baselines do), the precursor $m/z$ is embedded (a linear projection of $m/z / 1000$) and added to every peak token before entering the Transformer. This gives every token global context, so each fragment is interpreted relative to its parent mass, making neutral losses ($\text{Precursor} - \text{Fragment}$) easier to learn throughout the network.
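A minimal list-based sketch of the two ideas in this paragraph: adding one precursor embedding to every peak token (what `prec_feat.unsqueeze(1)` does through broadcasting in Section 2.3), and the neutral losses that this context is meant to make learnable. All numbers are hypothetical; the 18.0106 Da loss in the example matches the monoisotopic mass of water.

```python
def inject_precursor(peak_embs, prec_emb):
    """Add the same precursor embedding to every peak token
    (the list-based equivalent of broadcasting over the peak axis)."""
    return [[p + q for p, q in zip(peak, prec_emb)] for peak in peak_embs]


def neutral_losses(precursor_mz, fragment_mzs):
    """Neutral loss (precursor - fragment) for each fragment peak, in Da."""
    return [precursor_mz - mz for mz in fragment_mzs]


# Hypothetical 2-dimensional peak embeddings and precursor embedding.
tokens = inject_precursor([[0.5, 0.25], [0.75, -0.5]], [1.0, -1.0])
print(tokens)  # [[1.5, -0.75], [1.75, -1.5]]

# Hypothetical precursor and fragment m/z values.
losses = neutral_losses(300.0, [281.9894, 200.0])
print([round(x, 4) for x in losses])  # [18.0106, 100.0]
```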

4. **Attention Pooling**

The DeepSets baseline uses *sum* pooling, which adds all peak features with equal weight. As mentioned above, this treats every peak as equally important to the prediction, whereas in reality some peaks are highly diagnostic and others are noise. Sum pooling is therefore replaced with an *attention pooling* layer: after the Transformer layers, a linear layer computes a scalar score $\alpha_i$ for each peak embedding $h_i$, and the scores are normalized with a softmax over all peaks:

$$h_{global} = \sum_{i} \frac{\exp(\alpha_i)}{\sum_{j} \exp(\alpha_j)} \, h_i$$

This lets the model focus on informative peaks and downweight irrelevant ones, effectively learning a simple denoising step.
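The pooling step can be illustrated without a deep-learning framework. In this sketch, a single scoring vector `w` and bias `b` stand in for the `nn.Linear(d_model, 1)` layer of Section 2.3 (their values here are hand-picked and hypothetical): each peak gets a score $\alpha_i$, the scores are softmax-normalized over peaks, and the output is the weighted sum. `sum_pool` shows the DeepSets alternative for comparison.

```python
import math


def softmax(scores):
    """Softmax over a list of scores (max-shifted for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(peak_embs, w, b=0.0):
    """alpha_i = w . h_i + b, normalized over peaks, then a weighted sum."""
    alphas = [sum(wi * hi for wi, hi in zip(w, h)) + b for h in peak_embs]
    weights = softmax(alphas)
    dim = len(peak_embs[0])
    pooled = [sum(weights[i] * peak_embs[i][d] for i in range(len(peak_embs)))
              for d in range(dim)]
    return pooled, weights


def sum_pool(peak_embs):
    """DeepSets-style pooling: every peak contributes with weight 1."""
    return [sum(col) for col in zip(*peak_embs)]


peaks = [[2.0, 0.0], [0.0, 1.0], [0.1, 0.1]]
pooled, weights = attention_pool(peaks, w=[3.0, 0.0])
print([round(x, 3) for x in weights])
print(sum_pool(peaks))
```

Because the weights are normalized over peaks, reordering the peaks leaves the pooled vector unchanged, so the layer keeps the permutation invariance of sum pooling while letting the high-scoring first peak dominate.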

5. **Loss: Contrastive Learning**

The most significant change from the DeepSets baseline model is the learning objective. The baseline treats retrieval as a fingerprint regression and minimizes the reconstruction loss (MSE) between the predicted fingerprint vector and the ground truth fingerprint. **However, our goal is not to generate a perfect fingerprint, but to rank the correct molecule above a set of candidate molecules.**

When trained with MSE, a model that is uncertain between possible molecular structures tends to predict the average of the possibilities to minimize the numerical error. This is harmful for retrieval: an averaged fingerprint can end up moderately similar to many candidates but highly similar to none, which reduces the probability that the true candidate is ranked first.

The MSE loss is therefore replaced with a contrastive loss, specifically InfoNCE. Instead of forcing the model to reconstruct bits, this objective directly optimizes the ranking of the candidates.

The loss is the negative log-probability of identifying $k_+$ (the ground truth) among the candidate set of the spectrum:

$$\mathcal{L} = -\log \frac{\exp(\text{sim}(q, k_+) / \tau)}{\sum_{i=1}^{N} \exp(\text{sim}(q, k_i) / \tau)}$$

Where:
- $q$ is the predicted embedding (fingerprint) of the input spectrum
- $k_+$ is the fingerprint of the ground truth molecule
- $k_1, \dots, k_N$ are the fingerprints of all $N$ candidates of that spectrum, i.e. $k_+$ together with the negative candidates (decoys); the softmax runs over each spectrum's own candidate list, not over other spectra in the batch
- $\tau$ is the temperature parameter (set to 0.1)

A low temperature such as $\tau=0.1$ sharpens the softmax and increases the penalty for hard negatives, forcing the model to learn fine-grained features. With a higher $\tau$, hard negatives (for example, isomers) would be treated almost the same as the ground truth, even though the ground truth should be clearly more similar than the isomer. Enforcing this distinction matters because the objective is precise identification, not generating a structurally similar fingerprint (as MSE tends to do by averaging).

The similarity function $\text{sim}(u, v)$ is defined as the Cosine Similarity, implemented efficiently as the dot product of $L_2$-normalized vectors. This forces the model to maximize the directional alignment with the positive key $k_+$ while simultaneously pushing away the embeddings of the negative keys $k_i$.
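The effect of the temperature can be checked numerically on one spectrum. In this sketch the cosine similarities are hypothetical, index 0 is the ground truth and index 1 is a hard negative. Since the gradient of $\mathcal{L}$ with respect to each negative logit equals that negative's softmax probability, the illustrative helper `negative_share` measures how much of the push-away signal lands on a given negative.

```python
import math


def softmax_probs(similarities, tau):
    """Softmax of sim / tau over one spectrum's candidates (max-shifted)."""
    logits = [s / tau for s in similarities]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def info_nce(similarities, positive_index, tau):
    """InfoNCE for one spectrum: -log softmax probability of the positive."""
    return -math.log(softmax_probs(similarities, tau)[positive_index])


def negative_share(similarities, positive_index, negative_index, tau):
    """Fraction of the total negative probability mass on one negative."""
    probs = softmax_probs(similarities, tau)
    neg_total = sum(p for i, p in enumerate(probs) if i != positive_index)
    return probs[negative_index] / neg_total


# Index 0: ground truth; index 1: hard negative (e.g., an isomer);
# indices 2-3: easy negatives.
sims = [0.80, 0.75, 0.20, 0.10]
for tau in (1.0, 0.1):
    print(tau, round(info_nce(sims, 0, tau), 3), round(negative_share(sims, 0, 1, tau), 3))
```

With these numbers, lowering $\tau$ from 1.0 to 0.1 moves the hard negative's share of the negative probability mass from under one half to above 99%, which is the sharpening effect described above.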

6. **Optimization**

For stable training, the model is optimized with AdamW, using a learning rate of $1\times 10^{-4}$ and a weight decay of $1\times 10^{-4}$. Weight decay acts as a regularizer: it penalizes large weights, which helps prevent overfitting and encourages the network to learn smoother, more robust features instead of memorizing noise.
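In PyTorch's `torch.optim.AdamW`, decoupled weight decay multiplies every parameter by `1 - lr * weight_decay` at each step, separately from the Adam gradient update. The pure-Python sketch below isolates that term for the settings above; gradients are ignored and the 100,000-step horizon is a hypothetical training length.

```python
def adamw_decay_factor(lr, weight_decay):
    """Per-step multiplicative shrinkage from AdamW's decoupled weight decay."""
    return 1.0 - lr * weight_decay


def shrink_after(steps, lr, weight_decay):
    """Fraction of a weight left after `steps` steps of decay alone."""
    return adamw_decay_factor(lr, weight_decay) ** steps


print(adamw_decay_factor(1e-4, 1e-4))
print(shrink_after(100_000, 1e-4, 1e-4))
```

With lr = weight_decay = $10^{-4}$ the per-step shrinkage is $10^{-8}$, so decay alone removes only about 0.1% of a weight's magnitude over 100,000 steps; for comparison, PyTorch's default `weight_decay` for AdamW is $10^{-2}$. How much regularization this provides in practice is best judged empirically, for example by comparing validation metrics across weight-decay values.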



#### 2.3 Implementation

##### Architecture

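```python
import math

import torch
import torch.nn as nn


class FourierFeatures(nn.Module):
    """Random Fourier features for m/z: gamma(v) = [sin(2*pi*B*v), cos(2*pi*B*v)]."""

    def __init__(self, output_dim, sigma=1.0):
        super().__init__()
        self.num_freqs = output_dim // 2
        # Frequencies B ~ N(0, sigma^2), stored as a buffer so they stay fixed (not trained)
        self.register_buffer('B', torch.randn(self.num_freqs) * sigma)

    def forward(self, x):
        # x: (batch, n_peaks, 1) raw m/z values, scaled down by 1000
        x_scaled = x / 1000.0
        projected = 2 * math.pi * x_scaled * self.B  # (batch, n_peaks, num_freqs)
        return torch.cat([torch.sin(projected), torch.cos(projected)], dim=-1)


class SpectralTransformerEncoder(nn.Module):
    def __init__(self, d_model=256, nhead=4, num_layers=2, out_channels=4096, dropout=0.1):
        super().__init__()
        # Half of each peak token encodes m/z (Fourier features), half encodes intensity
        self.fourier_dim = d_model // 2
        self.mz_enc = FourierFeatures(output_dim=self.fourier_dim, sigma=10.0)
        self.int_enc = nn.Linear(1, d_model - self.fourier_dim)
        # Precursor m/z is projected to d_model and added to every peak token
        self.precursor_proj = nn.Linear(1, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # One scalar attention score per peak, used for attention pooling
        self.attention_pool = nn.Linear(d_model, 1)
        # MLP head projecting the pooled spectrum embedding to fingerprint space
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, out_channels)
        )

    def forward(self, x_spec, precursor_mz):
        # x_spec: (batch, n_peaks, 2) with columns [m/z, intensity]
        mz = x_spec[:, :, 0:1]
        intensity = torch.log1p(x_spec[:, :, 1:2])  # log-intensity injection
        mz_emb = self.mz_enc(mz)
        int_emb = self.int_enc(intensity)
        peak_embs = torch.cat([mz_emb, int_emb], dim=-1)  # (batch, n_peaks, d_model)

        # Contextual precursor injection: (batch,) -> (batch, 1) -> (batch, d_model)
        if precursor_mz.dim() > 1:
            precursor_mz = precursor_mz.squeeze(-1)
        precursor_mz_norm = precursor_mz.unsqueeze(-1).float() / 1000.0
        prec_feat = self.precursor_proj(precursor_mz_norm)

        # Broadcast the precursor embedding over all peaks. No positional encoding
        # and no padding mask are used: the encoder is permutation-invariant, and
        # zero-padded peaks (if any) also take part in attention and pooling.
        x = peak_embs + prec_feat.unsqueeze(1)
        x_out = self.transformer(x)
        # Attention pooling: softmax over the peak dimension, then weighted sum
        attn_weights = torch.softmax(self.attention_pool(x_out), dim=1)
        global_repr = torch.sum(x_out * attn_weights, dim=1)
        return self.head(global_repr)
```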

##### MassSpecGym retrieval module

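```python
import torch
import torch.nn.functional as F
from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel


class MyRetrievalTransformer(RetrievalMassSpecGymModel):
    # temp is the InfoNCE temperature tau; Section 2.2 uses tau = 0.1, so pass
    # temp=0.1 explicitly to match it (the default below is 0.07).
    def __init__(self, d_model=256, nhead=4, num_layers=2, out_channels=4096, lr=1e-4, temp=0.07, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        self.model = SpectralTransformerEncoder(d_model, nhead, num_layers, out_channels)

    def forward(self, x, precursor_mz):
        return self.model(x, precursor_mz)

    def step(self, batch: dict, stage: Stage) -> dict:
        x = batch["spec"]
        precursor_mz = batch["precursor_mz"]
        candidates = batch["candidates_mol"]  # candidate fingerprints, concatenated over the batch
        batch_ptr = batch["batch_ptr"]        # number of candidates per spectrum
        labels = batch["labels"]              # 1 for the ground-truth candidate, 0 otherwise

        fp_pred = self.forward(x, precursor_mz)

        # Cosine similarity = dot product of L2-normalized vectors
        fp_pred = F.normalize(fp_pred, p=2, dim=-1)
        candidates = F.normalize(candidates, p=2, dim=-1)
        fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)
        cos_sim = (fp_pred_repeated * candidates).sum(dim=-1)

        logits = cos_sim / self.hparams.temp

        # Per-spectrum softmax denominator: sum of exp(logits) over each spectrum's own candidates
        batch_indices = torch.arange(len(batch_ptr), device=logits.device).repeat_interleave(batch_ptr)
        exp_logits = torch.exp(logits)  # bounded by exp(1 / temp) because |cos_sim| <= 1
        denominators = torch.zeros(len(batch_ptr), device=logits.device, dtype=logits.dtype)
        denominators.scatter_add_(0, batch_indices, exp_logits)
        log_denominators = torch.log(denominators + 1e-10)

        # Assumes exactly one positive per spectrum, so pos_logits aligns with log_denominators
        pos_logits = logits[labels.bool()]
        loss = (log_denominators - pos_logits).mean()  # InfoNCE: -log softmax at the positive

        return {"loss": loss, "scores": cos_sim}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
```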

#### 2.4 Results

The model was evaluated on the test set using the retrieval metrics defined in MassSpecGym. The performance summary is the following:

| Metric | Value | Description |
| :--- | :---: | :--- |
| **Hit Rate @ 1** | **8.38%** | Exact match accuracy (Top-1). |
| **Hit Rate @ 5** | 20.35% | Correct molecule appears in the top 5 candidates. |
| **Hit Rate @ 20** | 42.00% | Correct molecule appears in the top 20 candidates. |
| **MCES @ 1** | 22.62 | Maximum Common Edge Subgraph distance between the top-1 candidate and the ground truth (lower is better). |
| **Test Loss** | 4.735 | Final InfoNCE loss on test set. |

The proposed Spectral Transformer outperforms the DeepSets baselines reported in the MassSpecGym paper on Hit Rate@1.

1. **High Precision (HR@1):** 

The model achieves a Hit Rate@1 of 8.38%, compared with the reported 5.24% for DeepSets with Fourier features and 1.47% for basic DeepSets. This improvement suggests that the combination of Self-Attention and the InfoNCE loss helps the model distinguish the true molecular structure from decoys, rather than just learning average features.

2. **Effective Retrieval (HR@20):** 

The Hit Rate@20 is 42.0%: in almost half of the test cases, the correct molecule appears among the top 20 results. This is valuable in practice, as it reduces the search space from up to 256 candidates per spectrum to a short list that can be verified manually.

3. **Structural Consistency (MCES):** 

MCES@1 is a distance (lower is better): it measures how many bonds differ between the top-1 candidate and the ground-truth molecule. The value of 22.62 should therefore be read relative to the MCES@1 of baselines evaluated in the same setting; on its own, it does not show that the top-ranked molecule is structurally close to the ground truth.

## 3. Model 2: Instance-wise F1-optimized

While Model 1 focuses on *retrieval* (ranking), Model 2 focuses on **fingerprint reconstruction**, in other words, maximizing the **Instance-wise F1 Score**.

In this setting, the model acts as a multi-label classifier. For each input spectrum, it must predict a binary vector of length 4096, where each bit represents the presence or absence of a molecular substructure.

#### 3.1 Shared Architecture and Objective

This model uses the same **SpectralTransformerEncoder** architecture described in Section 2.1, with the same improvements (Fourier Features, Log-Intensity, Precursor Injection and Attention Pooling).

The main difference lies in the **Training Objective**. Instead of optimizing the ranking of candidates (Contrastive Learning), this model approaches the task as fingerprint reconstruction: the 4096 outputs are logits, and applying a sigmoid to each one gives the predicted probability that the corresponding bit is active.

#### 3.2 Loss Function

MSE is a poor fit for this problem because fingerprints are sparse binary vectors: most bits are 0 and only a few are 1. When trained with MSE on such imbalanced data, the model tends to predict values close to zero for all bits, since this minimizes the average numerical error. The loss is low, but the model ignores the rare active bits, which are the ones needed to identify the molecule.

Similarly, a classifier trained with standard Binary Cross Entropy (BCE) can reach very high accuracy simply by predicting "all zeros" (always predicting the majority class). However, such a model would have a Recall of 0% and would be useless for identification.
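The accuracy trap is simple arithmetic. Assuming a hypothetical fingerprint with 50 of its 4096 bits set, an "all zeros" predictor scores:

```python
def all_zeros_metrics(n_bits, n_active):
    """Accuracy and recall of a predictor that outputs 0 for every bit."""
    accuracy = (n_bits - n_active) / n_bits
    recall = 0.0  # no active bit is ever predicted
    return accuracy, recall


acc, rec = all_zeros_metrics(4096, 50)
print(round(acc, 4), rec)  # 0.9878 0.0
```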

To address this, the standard BCE loss is replaced with **Focal Loss**, which multiplies BCE by a modulating factor $(1 - p_t)^\gamma$. This factor reduces the loss contribution of "easy negatives" (the zeros that are easy to predict) and focuses training on hard, misclassified examples, which here are mostly the rare active bits.

$$\mathcal{L}_{Focal} = - (1 - p_t)^\gamma \log(p_t)$$

Where:
- $p_t$ is the model's predicted probability for the correct class: $p_t = p$ if the bit is 1 and $p_t = 1 - p$ if it is 0, where $p$ is the predicted probability that the bit is active
- $\gamma$ (gamma) is the focusing parameter. $\gamma = 2.0$ is used to down-weight easy examples, following the Focal Loss paper (Lin et al., 2017), where $\gamma = 2$ worked best in their experiments on dense object detection, another highly imbalanced task.

This forces the model to focus on the rare positive bits, prioritizing the F1 Score (the ability to correctly detect active substructures) over simple Accuracy (overall correctness, which is misleadingly high when dominated by empty bits).

- If the model is confident and correct ($p_t \approx 1$), the factor $(1 - p_t)^\gamma$ is close to 0, silencing the loss for that example.
- If the model is uncertain or wrong ($p_t$ is low), the factor remains high, maintaining the loss signal.
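The two bullet points can be made concrete by evaluating BCE and focal loss for two single bits. The probabilities below are hypothetical; `focal` computes $(1 - p_t)^\gamma \cdot \text{BCE}$ as `FocalBCEWithLogits` in Section 3.3 does, but starting from a probability instead of a logit.

```python
import math


def bce(p, y):
    """Binary cross-entropy -log(p_t) for predicted probability p and label y."""
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)


def focal(p, y, gamma=2.0):
    """Focal loss: (1 - p_t)^gamma * BCE."""
    p_t = p if y == 1 else 1.0 - p
    return (1.0 - p_t) ** gamma * -math.log(p_t)


easy_negative = (0.01, 0)  # bit is 0, model predicts 1%  -> p_t = 0.99
hard_positive = (0.10, 1)  # bit is 1, model predicts 10% -> p_t = 0.10
for p, y in (easy_negative, hard_positive):
    print(y, round(bce(p, y), 4), focal(p, y))
```

Under BCE the hard positive costs about 230 times more than the easy negative; under focal loss with $\gamma = 2$ the ratio grows to roughly $1.9 \times 10^6$, so the thousands of easy zero bits barely contribute.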

#### 3.3 Implementation

##### Focal Loss

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalBCEWithLogits(nn.Module):
    """
    Focal loss for multi-label classification (fingerprints) based on BCEWithLogits.
    """
    def __init__(self, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = targets.to(dtype=logits.dtype)
        # Element-wise BCE equals -log(p_t)
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        pt = torch.exp(-bce)  # pt is the probability of the true class
        loss = ((1.0 - pt) ** self.gamma) * bce  # down-weight easy (high p_t) bits
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # reduction="none": per-bit losses
```

##### Lightning Module

```python
import pytorch_lightning as pl
import torch
from torchmetrics.classification import BinaryF1Score


class FingerprintPredictor(pl.LightningModule):
    def __init__(self, d_model=256, nhead=4, num_layers=2, out_channels=4096, lr=1e-4, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        # Reuse the encoder from Section 2 (same architecture, different objective)
        self.model = SpectralTransformerEncoder(d_model, nhead, num_layers, out_channels)
        self.loss_fn = FocalBCEWithLogits(gamma=2.0)
        # Instance-wise (per-sample) F1; probabilities are thresholded at 0.5
        self.val_f1 = BinaryF1Score(multidim_average='samplewise')

    def forward(self, x, precursor_mz):
        return self.model(x, precursor_mz)  # raw logits, one per fingerprint bit

    def step(self, batch):
        x, precursor_mz = batch["spec"], batch["precursor_mz"]
        fp_logits = self.forward(x, precursor_mz)
        fp_true = batch["mol"].to(dtype=fp_logits.dtype)  # ground-truth binary fingerprint
        loss = self.loss_fn(fp_logits, fp_true)
        return loss, fp_logits, fp_true

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.step(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, targets = self.step(batch)
        self.log("val_loss", loss, on_epoch=True)
        self.val_f1.update(torch.sigmoid(logits), targets.long())
        return loss

    def on_validation_epoch_end(self):
        # compute() returns one F1 value per validation sample; average them
        self.log("val_f1", self.val_f1.compute().mean(), prog_bar=True)
        self.val_f1.reset()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4)
```

#### 3.4 Results

The model was evaluated on the test set focusing on the reconstruction quality of the fingerprints. The performance summary is the following:

| Metric | Value | Description |
| :--- | :---: | :--- |
| **Test F1 (Sample-wise)** | **28.27%** | Harmonic mean of Precision and Recall, averaged per molecule. |
| **Test Loss** | 0.0233 | Final Focal Loss value on the test set. |

At first glance, an F1 Score of 28.27% might appear low compared with standard classification tasks. However, in the context of high-dimensional, sparse fingerprint reconstruction, it is a positive result.

A model that failed to learn would likely have converged to predicting "all zeros" to minimize error, which gives a Recall of 0 and an F1 score of 0. A score of ~28% shows that the model trained with Focal Loss does predict active bits and recovers actual molecular substructures despite the strong class imbalance. Because F1 combines precision and recall, 28% is not the fraction of substructures recovered; precision and recall would need to be reported separately to quantify that.
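To see why an F1 score cannot be read as a fraction of recovered bits, consider a hypothetical molecule with 10 active bits for which a model predicts 25 bits, 5 of them correct. The sketch computes instance-wise F1 from sets of active bit indices (the sample-wise metric in the table averages this value over molecules):

```python
def instance_f1(pred_bits, true_bits):
    """Instance-wise F1 between two sets of active bit indices (0 if no overlap)."""
    tp = len(pred_bits & true_bits)
    if tp == 0:
        return 0.0
    precision = tp / len(pred_bits)
    recall = tp / len(true_bits)
    return 2 * precision * recall / (precision + recall)


true_bits = set(range(10))                          # 10 active bits
pred_bits = set(range(5)) | set(range(100, 120))    # 25 predicted, 5 correct
print(round(instance_f1(pred_bits, true_bits), 4))  # 0.2857
```

Here half of the true substructure bits are recovered (recall 50%) with 20% precision, and F1 is about 28.6%.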

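The low test loss (0.0233) is largely a consequence of how Focal Loss is computed ($\gamma=2.0$): the $(1 - p_t)^\gamma$ factor pushes the contribution of the thousands of confidently predicted "easy negative" bits (zeros) close to zero, and the loss is averaged over all 4096 bits. Its absolute value is therefore not comparable with BCE or MSE values, and the F1 score is the more informative measure of how well the model handles the difficult active bits.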

## 4. Architectural Justification

The choice of a Transformer-based architecture over other deep learning techniques such as CNNs, RNNs, or simple MLPs is driven by the specific physical and mathematical properties of Mass Spectrometry data.

This decision is justified by analyzing the limitations of alternative architectures for this specific task:

#### 4.1 Why not CNNs (Convolutional Neural Networks)?

CNNs are the standard technique for image and audio processing, where data is dense and locally correlated, but applying them to mass spectra is problematic. An MS/MS spectrum is essentially a sparse list of peaks at continuous $m/z$ coordinates, whereas CNNs require a fixed grid structure similar to pixels, so the $m/z$ axis must first be discretized into fixed-width bins. Preserving fine mass differences (for example, between isotopic peaks or fragments with the same nominal mass) requires very small bins, which produces very long vectors made almost entirely of zeros. Conversely, larger bins save computation but lose the mass precision required to identify the molecular formula.
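The trade-off can be quantified with a short calculation. Assuming, hypothetically, an $m/z$ range of 0 to 1000 and a spectrum with 60 peaks, the sketch below gives the length of the binned vector and the maximum fraction of non-zero bins for three bin widths:

```python
def binned_length(mz_min, mz_max, bin_width):
    """Number of bins needed to cover [mz_min, mz_max) at a given width."""
    return round((mz_max - mz_min) / bin_width)


def fill_fraction(n_peaks, n_bins):
    """Upper bound on the fraction of non-zero bins."""
    return min(n_peaks, n_bins) / n_bins


for width in (1.0, 0.1, 0.01):
    n_bins = binned_length(0.0, 1000.0, width)
    print(width, n_bins, fill_fraction(60, n_bins))
```

At 0.01 Da resolution the vector has 100,000 entries, of which at most 0.06% are non-zero, while a peak-list model only needs the 60 (m/z, intensity) pairs.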

Furthermore, CNNs operate on small local neighborhoods of data, which makes it hard for them to detect important chemical relationships between peaks that are far apart in the spectrum, whereas a Transformer connects any two peaks directly within a single self-attention layer.

#### 4.2 Why not RNNs (Recurrent Neural Networks / LSTMs)?

RNNs and LSTMs are designed for ordered sequences, where the position of each element has meaning (for example, language or time series). However, an MS/MS spectrum is naturally a set of peaks, where the order in which peaks are listed is arbitrary and does not represent a temporal process. Using an RNN would introduce a sequential bias that is not inherent to the data. 

Furthermore, RNNs must propagate information sequentially through a hidden state, which makes capturing relationships between distant peaks difficult. If two related peaks are far apart in the input sequence, the information from the first peak must survive the processing of all intermediate steps to influence the prediction. The fixed-size hidden state acts as a bottleneck and, together with vanishing gradients, tends to "forget" early information. The Transformer avoids this issue by using Self-Attention to link any two peaks directly, regardless of their distance.

#### 4.3 Why not MLPs (DeepSets)?

Finally, while MLP-based architectures like DeepSets are valid for processing sets of peaks, they have a limitation due to the lack of explicit interaction between elements. In a DeepSets model, each peak is processed independently by an MLP to create an embedding, and these embeddings are summed or averaged at the end to form a global representation. However, in mass spectrometry, the structural identity of a molecule is often defined by the relationships between peaks rather than the peaks in isolation. The Transformer solves this with the Self-Attention mechanism.
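The difference can be shown with a toy example. In this sketch `phi` is a hypothetical per-peak encoder (standing in for the DeepSets MLP) and `self_attention` is a single head with identity query, key and value projections, a deliberate simplification of a real Transformer layer. Changing the second peak leaves the first peak's DeepSets embedding untouched, whereas its self-attention output changes:

```python
import math


def phi(peak):
    """Toy per-peak encoder: depends only on this peak's (m/z, intensity)."""
    mz, intensity = peak
    return [math.tanh(mz / 100.0), intensity]


def self_attention(embs):
    """Single-head self-attention with identity Q/K/V projections:
    each output is a softmax-weighted mix of all peak embeddings."""
    outputs = []
    for q in embs:
        scores = [sum(a * b for a, b in zip(q, k)) for k in embs]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * v[d] for w, v in zip(weights, embs)) / total
                        for d in range(len(q))])
    return outputs


spectrum = [(150.0, 0.2), (300.0, 1.0)]
modified = [(150.0, 0.2), (282.0, 1.0)]  # only the second peak changes

print(phi(spectrum[0]) == phi(modified[0]))  # True
attn_a = self_attention([phi(p) for p in spectrum])
attn_b = self_attention([phi(p) for p in modified])
print(attn_a[0] == attn_b[0])  # False
```

In DeepSets, information from other peaks only enters after pooling, through $\rho$; self-attention lets each peak's representation depend on the others before pooling.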

## 5. Conclusion

This project presented two complementary approaches to improve molecular identification from mass spectrometry data using Spectral Transformers:

- **Retrieval Optimization:** By letting peaks interact through Self-Attention and training with Contrastive Learning (InfoNCE), this approach improved on simple fingerprint regression. The model learned to rank the true molecule higher among candidates, reaching a Hit@1 of 8.38% and a Hit@20 of 42.0%, which makes it a useful tool for filtering candidates.

- **Fingerprint Reconstruction:** By addressing the imbalance of molecular fingerprints with Focal Loss, the model avoided the common failure of predicting only zeros. The resulting F1 score of 28% indicates that the model can recover meaningful substructures from the spectral data.

- **Architectural Fit:** The analysis in Section 4 argues that the Transformer architecture is a better fit for this task than CNNs or RNNs: it respects the sparse and continuous nature of the data, and it can capture relationships between distant peaks through Self-Attention.