In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
import json
from pathlib import Path

In [2]:
from processors.processors_diffms import *

  import scipy.sparse


In [3]:
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

def smiles_to_fingerprint(smiles: str, n_bits: int, radius: int) -> np.ndarray:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None: 
                print("Molecula formula not defined")
                return np.zeros(n_bits, dtype=np.float32)
        mfpgen = rdFingerprintGenerator.GetMorganGenerator(fpSize=n_bits, radius=radius)
        fp = mfpgen.GetFingerprint(mol)
        return np.array(fp, dtype=np.float32)

In [4]:
# Convert the spec_features dictionary to a batch format (adding batch dimension)
def prepare_features(spec_features_dict):
    """Converts features from numpy arrays to PyTorch tensors."""
    features = {}
    
    features['num_peaks'] = len(spec_features_dict['peak_type'])
    
    # Convert arrays to tensors
    features['types'] = torch.tensor(spec_features_dict['peak_type'], dtype=torch.long)
    features['form_vec'] = torch.tensor(spec_features_dict['form_vec'], dtype=torch.float)
    features['ion_vec'] = torch.tensor(spec_features_dict['ion_vec'], dtype=torch.long)
    features['intens'] = torch.tensor(spec_features_dict['frag_intens'], dtype=torch.float)
    features['instruments'] = torch.tensor(spec_features_dict['instrument'], dtype=torch.long) 
    features['num_peaks'] = torch.tensor(features['num_peaks'] , dtype=torch.long) 
    
    if 'magma_fps' in spec_features_dict:
         # Ensure magma_fps is float for consistency, handle -1 if needed
         magma_fps_np = spec_features_dict['magma_fps']
         features['magma_fps'] = torch.tensor(magma_fps_np, dtype=torch.float)
         features['magma_aux_loss'] = spec_features_dict['magma_aux_loss']
        
    return features

In [5]:
from torch.nn.utils.rnn import pad_sequence

def spectra_collate_fn(batch):
    """
    Collates a list of dictionaries from SpectraDataset into a padded batch.

    Args:
        batch (list): A list of dictionaries, where each dict is an output
                      from SpectraDataset.__getitem__.

    Returns:
        dict: A dictionary containing batched and padded tensors,
              or None if the batch is empty after filtering.
    """
    # Filter out None items resulting from errors in __getitem__
    batch = [item for item in batch if item is not None]
    
    target_fps = torch.stack([item['target_fp'] for item in batch], dim=0)
    instruments = torch.stack([item['instruments'] for item in batch], dim=0)
    
    # --- Handle sequence padding ---
    # Get sequence lengths (num_peaks) - from python ints
    num_peaks = torch.tensor([item['num_peaks'] for item in batch], dtype=torch.long)
    max_len = num_peaks.max().item() if len(num_peaks) > 0 else 0 # Handle empty batch case
    
    # Prepare lists of tensors for pad_sequence
    types_list = [item['types'] for item in batch]
    form_vec_list = [item['form_vec'] for item in batch]
    ion_vec_list = [item['ion_vec'] for item in batch]
    intens_list = [item['intens'] for item in batch]
    
    # Pad sequences: batch_first=True gives [batch_size, max_len, ...]
    # Use 0 for padding value, adjust if a different value is semantically better
    batched_types = pad_sequence(types_list, batch_first=True, padding_value=0)
    batched_form_vecs = pad_sequence(form_vec_list, batch_first=True, padding_value=0.0)
    batched_ion_vecs = pad_sequence(ion_vec_list, batch_first=True, padding_value=0)
    batched_intens = pad_sequence(intens_list, batch_first=True, padding_value=0.0)
    
    mask = torch.arange(max_len)[None, :] < num_peaks[:, None]
    
    final_batch = {
        'target_fp': target_fps,
        'instruments': instruments,
        'num_peaks': num_peaks,
        'types': batched_types,
        'form_vec': batched_form_vecs,
        'ion_vec': batched_ion_vecs,
        'intens': batched_intens,
        'mask': mask
    }
    
    if 'magma_fps' in batch[0]:
        magma_fps_list = [item['magma_fps'] for item in batch]
        # Pad MAGMA fingerprints. Using 0.0 as padding value.
        # The original data uses -1 for missing FPs, padding adds 0s.
        # Ensure your model handles both -1 (missing) and 0 (padding or inactive bit) appropriately.
        batched_magma_fps = pad_sequence(magma_fps_list, batch_first=True, padding_value=0.0)
        final_batch['magma_fps'] = batched_magma_fps
        # Carry over the boolean flag (assuming it's the same for the whole batch)
        final_batch['magma_aux_loss'] = batch[0]['magma_aux_loss']
        
    return final_batch

In [6]:
class SpectraDataset(Dataset):
    """Dataset for loading spectra and SMILES from a CSV file."""
    def __init__(self, data_file_path, spectrum_processor, target_fp_size, is_train=False):
        """
        Args:
            data_file_path (str): Path to the CSV file.
            spectrum_processor (SpectrumProcessor): Instance to process spectra.
            target_fp_size (int): The desired size of the target fingerprint.
            is_train (bool): Flag indicating if this is for training (enables augmentation).
        """
        self.processor = spectrum_processor
        self.target_fp_size = target_fp_size
        self.is_train = is_train
        self.morgan_radius = 2
        
        try:
            self.data = pd.read_csv(data_file_path)
            
            # Validate required columns
            required_cols = ['spec', 'smiles', 'extracted_spectral_info']
            if not all(col in self.data.columns for col in required_cols):
                raise ValueError(f"CSV must contain columns: {required_cols}")
            print(f"Loaded {len(self.data)} records from {data_file_path}")
            
            # TODO implement the SMILES validation and filtering
            # Optional: Pre-filter invalid SMILES to avoid errors during training
            # self.data['valid_smiles'] = self.data['smiles'].apply(lambda x: Chem.MolFromSmiles(str(x)) is not None)
            # initial_len = len(self.data)
            # self.data = self.data[self.data['valid_smiles']].reset_index(drop=True)
            # print(f"Filtered out {initial_len - len(self.data)} invalid SMILES.")

        except FileNotFoundError:
            print(f"Error: Data file not found at {data_file_path}")
            self.data = pd.DataFrame()
        except Exception as e:
            print(f"Error loading or processing CSV {data_file_path}: {e}")
            self.data = pd.DataFrame()

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)
      
    def __getitem__(self, idx):
        """
        Get a single item from the dataset.

        Args:
            idx (int): Index of the item to retrieve.

        Returns:
            dict: A dictionary containing processed spectral features
                  and the target fingerprint. Returns None if data is invalid.
        """
        if idx >= len(self.data):
             raise IndexError("Index out of bounds")

        row = self.data.iloc[idx]

        # Extract data, ensuring correct types
        spec_id = str(row['spec'])
        smiles = str(row['smiles'])
        raw_spec_json = row['extracted_spectral_info'] 

        # --- Generate Target Fingerprint ---
        target_fp = smiles_to_fingerprint(smiles, self.target_fp_size, self.morgan_radius)
        target_fp = torch.tensor(target_fp)
     
        # --- Process Spectrum ---
        spec_features = self.processor.process_raw_spectrum(
            raw_spec_json, spec_id=spec_id, train_mode=self.is_train
        )
        
        spec_features = prepare_features(spec_features)
        
        # --- Combine features and target --- 
        item = {**spec_features, 'target_fp': target_fp}
        return item

In [7]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda:1


In [10]:
DATA_DIR = "../../data/production_ready_data/train/spectrs/"
TRAIN_CSV = Path(DATA_DIR) / "MassSpecGym_fixed.csv" 
OUTPUT_SIZE = 4096  
HIDDEN_SIZE = 256 
BATCH_SIZE = 32
NUM_WORKERS = 4    
MAGMA_MODULO = 2048 
SPECTRA_DROPOUT = 0.1 
TOP_LAYERS = 2 
USE_MAGMA_AUX_LOSS = True 
FORM_EMBEDDER = "float"
MAGMA_FOLDER = '../../data/raw/msg_diffms/magma_outputs/magma_tsv' 

In [11]:
LEARNING_RATE = 1.0e-6
BATCH_SIZE = 64
EPOCHS = 50
USE_MAGMA_AUX_LOSS = True 
MAGMA_LOSS_WEIGHT = 0.2 

In [12]:
processor_train = SpectrumProcessor(
    augment_data=True,
    cls_type="ms1",
    max_peaks=500,
    magma_modulo=MAGMA_MODULO,
    magma_aux_loss=USE_MAGMA_AUX_LOSS,
    magma_folder=MAGMA_FOLDER if USE_MAGMA_AUX_LOSS else None
)

Found 231104 MAGMA files


In [13]:
train_dataset = SpectraDataset(TRAIN_CSV, processor_train, OUTPUT_SIZE, is_train=True)

Loaded 231104 records from ../../data/production_ready_data/train/spectrs/MassSpecGym_fixed.csv


In [14]:
for item in train_dataset:
    print(item)
    break

{'num_peaks': tensor(8), 'types': tensor([0, 0, 0, 0, 0, 0, 0, 3]), 'form_vec': tensor([[13., 12.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,
          0.,  0.,  0.,  0.],
        [ 7.,  6.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.],
        [13., 13.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  1.,
          0.,  0.,  0.,  0.],
        [ 7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
          0.,  0.,  0.,  0.],
        [ 6.,  4.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  3.,
          0.,  0.,  0.,  0.],
        [14., 15.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
          0.,  0.,  0.,  0.],
        [ 7.,  7.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  3.,
          0.,  0.,  0.,  0.],
        [16., 17.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  4.,
          0.,  0.,  0.,  0.]]), 'ion_vec': tensor([0, 0, 0, 0, 0, 0, 0, 0]), '

In [15]:
train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            collate_fn=spectra_collate_fn,
        )
print(f"Train DataLoader created with {len(train_loader)} batches.")

Train DataLoader created with 3611 batches.


In [16]:
batch = next(iter(train_loader))

In [17]:
from models.spectra_encoder import SpectraEncoder, SpectraEncoderGrowing

In [None]:
model = SpectraEncoder(
    form_embedder=FORM_EMBEDDER,
    output_size=OUTPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    spectra_dropout=SPECTRA_DROPOUT,
    top_layers=TOP_LAYERS,
    magma_modulo=MAGMA_MODULO,
    peak_attn_layers=4,
).to(DEVICE)

In [19]:
criterion_main = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [20]:
def masked_bce_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Calculates BCE loss, ignoring targets where value is -1."""
    mask = target != -1
    target_masked = target[mask]
    pred_masked = pred[mask]
    if target_masked.numel() == 0: # Handle cases where mask removes all elements
        return torch.tensor(0.0, device=pred.device, requires_grad=True) # Return zero loss but allow grad flow
    loss = nn.BCELoss(reduction='mean')(pred_masked, target_masked)
    return loss

In [21]:
from tqdm.auto import tqdm 
EPOCHS = 3
for epoch in range(EPOCHS):
    model.train()
    total_main_loss = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    
    for item in pbar:
        batch_gpu = {}
        for key, value in item.items():
                if isinstance(value, torch.Tensor):
                    batch_gpu[key] = value.to(DEVICE)
                else:
                    batch_gpu[key] = value 
                    
        optimizer.zero_grad()
        output, aux_outputs = model(batch_gpu)
        target_fp = batch_gpu['target_fp']
        target_fp = target_fp.view_as(output).float()
        main_loss = criterion_main(output, target_fp)
        
        aux_loss = torch.tensor(0.0).to(DEVICE)
        if USE_MAGMA_AUX_LOSS and 'pred_frag_fps' in aux_outputs and 'magma_fps' in batch_gpu and batch_gpu['magma_fps'].numel() > 0:
            pred_frag_fps = aux_outputs['pred_frag_fps']
            target_magma_fps = batch_gpu['magma_fps'] 
            if pred_frag_fps.shape == target_magma_fps.shape:
                pred_frag_fps_sig = torch.sigmoid(pred_frag_fps) # logits output
                aux_loss = masked_bce_loss(pred_frag_fps_sig, target_magma_fps.float())
        
        # --- Combine Losses ---
        total_loss = main_loss + MAGMA_LOSS_WEIGHT * aux_loss
        
         # --- Backward Pass & Optimize ---
        total_loss.backward()
        optimizer.step()
        current_main_loss = main_loss.item()
        current_aux_loss = aux_loss.item() 
        total_main_loss += current_main_loss
        
        pbar.set_postfix(main_loss=f"{current_main_loss:.4f}", aux_loss=f"{current_aux_loss:.4f}")

    pbar.close()
        
    final_avg_main_loss = total_main_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS} - Avg Main Loss: {final_avg_main_loss:.4f}")

Epoch 1/3:   0%|          | 0/3611 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [24]:
from tqdm.notebook import tqdm 

for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0
    total_main_loss = 0
    total_aux_loss = 0

    # Wrap train_loader with tqdm for a progress bar
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for i, batch in enumerate(progress_bar): # Iterate over the tqdm object
        # Move batch to device (handle non-tensor data appropriately)
        batch_gpu = {}
        for key, value in batch.items():
             if isinstance(value, torch.Tensor):
                 batch_gpu[key] = value.to(DEVICE)
             else:
                 batch_gpu[key] = value # Keep non-tensors like names, flags as they are

        # Skip batch if essential tensors are empty after collation/filtering
        if not batch_gpu.get('form_vec', torch.empty(0)).numel() or not batch_gpu.get('target_fp', torch.empty(0)).numel():
             # print(f"Skipping empty batch {i}") # Optional: tqdm might make this less necessary
             continue

        optimizer.zero_grad()

        # Forward pass
        output, aux_outputs = model(batch_gpu) # Pass the gpu batch
     

         # Calculate Main Loss
        target_fp = batch_gpu['target_fp']
        main_loss = criterion_main(output, target_fp)

        # Calculate Auxiliary Loss (if enabled)
        aux_loss = torch.tensor(0.0).to(DEVICE)
        if USE_MAGMA_AUX_LOSS and 'pred_frag_fps' in aux_outputs and 'magma_fps' in batch_gpu and batch_gpu['magma_fps'].numel() > 0:
            pred_frag_fps = aux_outputs['pred_frag_fps'] # Shape: [batch, (num_peaks,) magma_modulo]
            target_magma_fps = batch_gpu['magma_fps']   # Shape: [batch, num_peaks, magma_modulo]

            # --- IMPORTANT ---
            # Ensure pred_frag_fps and target_magma_fps align.
            # If pred_frag_fps is [batch, magma_modulo] (global prediction),
            # you cannot directly compare it to per-peak target_magma_fps.
            # The MIST paper/code needs clarification on how aux loss is computed.
            # Assuming pred_frag_fps is per-peak for this example:
            # Shape: [batch, num_peaks, magma_modulo]
            # Make sure the shapes match before calling masked_bce_loss.
            # If shapes mismatch, you need to adapt the model or loss calculation.

            # Example check (adjust based on actual model output shape):
            if pred_frag_fps.shape == target_magma_fps.shape:
                 # Apply sigmoid if fragment_predictor doesn't have one
                 pred_frag_fps_sig = torch.sigmoid(pred_frag_fps) # Assuming logits output
                 aux_loss = masked_bce_loss(pred_frag_fps_sig, target_magma_fps)
            else:
                 # Print warning less frequently or outside tqdm loop if too noisy
                 if i % 100 == 0: # Example: Print only every 100 steps
                     print(f"\nWarning: Shape mismatch for MAGMA loss. Pred: {pred_frag_fps.shape}, Target: {target_magma_fps.shape}. Skipping aux loss.")
                 aux_loss = torch.tensor(0.0).to(DEVICE)


        # Combine Losses
        total_loss = main_loss + MAGMA_LOSS_WEIGHT * aux_loss

        # Backward pass and optimize
        try:
            total_loss.backward()
            optimizer.step()
        except RuntimeError as e:
            print(f"\nError during backward/step on batch {i}: {e}") # Add newline
            # Potentially clear gradients if step failed
            optimizer.zero_grad()
            continue # Skip this batch update

        total_train_loss += total_loss.item()
        total_main_loss += main_loss.item()
        if USE_MAGMA_AUX_LOSS:
            total_aux_loss += aux_loss.item()

        # Update tqdm progress bar description with current average losses
        avg_loss = total_train_loss / (i + 1)
        avg_main_loss = total_main_loss / (i + 1)
        avg_aux_loss = total_aux_loss / (i + 1) if USE_MAGMA_AUX_LOSS and total_aux_loss > 0 else 0 # Avoid division by zero if aux loss wasn't used yet
        progress_bar.set_postfix(loss=f"{avg_loss:.4f}", main=f"{avg_main_loss:.4f}", aux=f"{avg_aux_loss:.4f}")

        # Optional: Keep less frequent print statements if needed, but tqdm postfix is often sufficient
        # if (i + 1) % 100 == 0:
        #     print(f'Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(train_loader)}], Avg Loss: {avg_loss:.4f}, Avg Main Loss: {avg_main_loss:.4f}, Avg Aux Loss: {avg_aux_loss:.4f}')

    # Ensure the progress bar finishes cleanly
    progress_bar.close()

    # Print final epoch stats after the loop
    final_avg_loss = total_train_loss / len(train_loader)
    final_avg_main_loss = total_main_loss / len(train_loader)
    final_avg_aux_loss = total_aux_loss / len(train_loader) if USE_MAGMA_AUX_LOSS and len(train_loader) > 0 else 0
    print(f'--- Epoch {epoch+1} Training Finished ---')
    print(f'Average Training Loss: {final_avg_loss:.4f}, Avg Main Loss: {final_avg_main_loss:.4f}, Avg Aux Loss: {final_avg_aux_loss:.4f}')

Epoch 1/50:   0%|          | 0/3611 [00:00<?, ?it/s]

Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0322428.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0000152.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0000151.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0378754.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0191256.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0000155.magma: 'float' object has no attribute 'split'
Error processing MAGMA file ../../data/raw/msg_diffms/magma_outputs/magma_tsv/MassSpecGymID0000156.magma: 'float' object has

Epoch 2/50:   0%|          | 0/3611 [00:00<?, ?it/s]

KeyboardInterrupt: 