In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import pandas as pd
import math
import ast
from scipy.signal import find_peaks
import pywt

In [2]:
if torch.cuda.is_available():
    print("CUDA is available.")
    print("PyTorch version:", torch.__version__)
    print("CUDA version:", torch.version.cuda)
    print("Number of available GPUs:", torch.cuda.device_count())
    print("GPU name:", torch.cuda.get_device_name(0))
else:
    print("CUDA is not available.")

CUDA is available.
PyTorch version: 2.1.0+cu121
CUDA version: 12.1
Number of available GPUs: 1
GPU name: NVIDIA GeForce RTX 4060 Laptop GPU


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
df = pd.read_csv('data/MoNA/in-silico.csv') 

  df = pd.read_csv('data/MoNA/in-silico.csv')


In [5]:
# Prepare the labels
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df['molecular_formula'])

In [6]:
def direct_tokenization(binned_spectrum, window_size=16):
    # Pad the spectrum if necessary
    if len(binned_spectrum) % window_size != 0:
        pad_length = window_size - (len(binned_spectrum) % window_size)
        binned_spectrum = np.pad(binned_spectrum, (0, pad_length), mode='constant')
    
    # Reshape into 2D
    return binned_spectrum.reshape(-1, window_size)

def fourier_tokenization_2d(binned_spectrum, window_size=16):
    fft = np.fft.fft(binned_spectrum)
    magnitude_spectrum = np.abs(fft[:len(fft)//2])
    
    # Pad if necessary
    if len(magnitude_spectrum) % window_size != 0:
        pad_length = window_size - (len(magnitude_spectrum) % window_size)
        magnitude_spectrum = np.pad(magnitude_spectrum, (0, pad_length), mode='constant')
    
    # Reshape into 2D
    return magnitude_spectrum.reshape(-1, window_size)

def wavelet_tokenization_2d(binned_spectrum, window_size=16, wavelet='db1'):
    coeffs = pywt.wavedec(binned_spectrum, wavelet)
    flat_coeffs = np.concatenate(coeffs)
    
    # Pad if necessary
    if len(flat_coeffs) % window_size != 0:
        pad_length = window_size - (len(flat_coeffs) % window_size)
        flat_coeffs = np.pad(flat_coeffs, (0, pad_length), mode='constant')
    
    # Reshape into 2D
    return flat_coeffs.reshape(-1, window_size)

def peak_tokenization(spectrum_string, top_n=50, pad_to=50):
    spectrum = ast.literal_eval(spectrum_string)
    spectrum.sort(key=lambda x: x[1], reverse=True)
    top_peaks = spectrum[:top_n]
    flattened = [val for peak in top_peaks for val in peak]
    
    # Pad if necessary
    if len(flattened) < pad_to * 2:
        flattened.extend([0] * (pad_to * 2 - len(flattened)))
    
    return np.array(flattened).reshape(-1, 2)

In [7]:
'''def fourier_tokenization_2d(binned_spectrum, window_size=16):
    # Perform FFT
    fft = np.fft.fft(binned_spectrum)
    magnitude_spectrum = np.abs(fft[:len(fft)//2])
    
    # Pad the spectrum if necessary
    if len(magnitude_spectrum) % (window_size * window_size) != 0:
        pad_length = window_size * window_size * (len(magnitude_spectrum) // (window_size * window_size) + 1) - len(magnitude_spectrum)
        magnitude_spectrum = np.pad(magnitude_spectrum, (0, pad_length), mode='constant')
    
    # Reshape into 2D
    num_windows = len(magnitude_spectrum) // (window_size * window_size)
    reshaped = magnitude_spectrum.reshape(num_windows, window_size, window_size)
    
    # Flatten each window
    return reshaped.reshape(num_windows, -1).tolist()

def wavelet_tokenization_2d(binned_spectrum, window_size=16, wavelet='db1'):
    # Perform wavelet transform
    coeffs = pywt.wavedec(binned_spectrum, wavelet)
    
    # Concatenate all levels
    flat_coeffs = np.concatenate(coeffs)
    
    # Pad if necessary
    if len(flat_coeffs) % (window_size * window_size) != 0:
        pad_length = window_size * window_size * (len(flat_coeffs) // (window_size * window_size) + 1) - len(flat_coeffs)
        flat_coeffs = np.pad(flat_coeffs, (0, pad_length), mode='constant')
    
    # Reshape into 2D
    num_windows = len(flat_coeffs) // (window_size * window_size)
    reshaped = flat_coeffs.reshape(num_windows, window_size, window_size)
    
    # Flatten each window
    return reshaped.reshape(num_windows, -1).tolist()'''

"def fourier_tokenization_2d(binned_spectrum, window_size=16):\n    # Perform FFT\n    fft = np.fft.fft(binned_spectrum)\n    magnitude_spectrum = np.abs(fft[:len(fft)//2])\n    \n    # Pad the spectrum if necessary\n    if len(magnitude_spectrum) % (window_size * window_size) != 0:\n        pad_length = window_size * window_size * (len(magnitude_spectrum) // (window_size * window_size) + 1) - len(magnitude_spectrum)\n        magnitude_spectrum = np.pad(magnitude_spectrum, (0, pad_length), mode='constant')\n    \n    # Reshape into 2D\n    num_windows = len(magnitude_spectrum) // (window_size * window_size)\n    reshaped = magnitude_spectrum.reshape(num_windows, window_size, window_size)\n    \n    # Flatten each window\n    return reshaped.reshape(num_windows, -1).tolist()\n\ndef wavelet_tokenization_2d(binned_spectrum, window_size=16, wavelet='db1'):\n    # Perform wavelet transform\n    coeffs = pywt.wavedec(binned_spectrum, wavelet)\n    \n    # Concatenate all levels\n    flat_coe

In [8]:
def calculate_max_mz(df, spectrum_column='spectrum'):
    def get_max_mz(spectrum_string):
        spectrum = ast.literal_eval(spectrum_string)
        return max(peak[0] for peak in spectrum)

    max_mz_series = df[spectrum_column].apply(get_max_mz)
    return int(np.ceil(max_mz_series.max()))

def bin_spectrum(spectrum_string, max_mz):
    spectrum = ast.literal_eval(spectrum_string)
    binned = np.zeros(max_mz + 1)  # +1 to include the max_mz value
    
    for mz, intensity in spectrum:
        mz_int = int(np.round(mz))
        if mz_int <= max_mz:
            binned[mz_int] += intensity
    
    return binned

def tokenize_spectrum(spectrum, method, max_mz, window_size=16):
    if isinstance(spectrum, str):
        binned_spectrum = bin_spectrum(spectrum, max_mz)
    else:
        binned_spectrum = spectrum

    if method == 'direct':
        return direct_tokenization(binned_spectrum, window_size)
    elif method == 'peak':
        return peak_tokenization(spectrum)
    elif method == 'fourier2':
        return fourier_tokenization_2d(binned_spectrum, window_size)
    elif method == 'wavelet2':
        return wavelet_tokenization_2d(binned_spectrum, window_size)
    else:
        raise ValueError(f"Unknown tokenization method: {method}")


In [9]:
class SpectralDataset(Dataset):
    def __init__(self, df, labels, tokenization_method, max_mz):
        self.spectra = df['spectrum']
        self.labels = labels
        self.tokenization_method = tokenization_method
        self.max_mz = max_mz

    def __len__(self):
        return len(self.spectra)

    def __getitem__(self, idx):
        spectrum = self.spectra.iloc[idx]
        label = self.labels[idx]
        tokenized = tokenize_spectrum(spectrum, self.tokenization_method, self.max_mz)
        return torch.tensor(tokenized, dtype=torch.float32).unsqueeze(0), label  # Add an extra dimension

In [10]:
# Calculate max_mz
max_mz = calculate_max_mz(df)
print(f"Maximum m/z value across all spectra: {max_mz}")    

Maximum m/z value across all spectra: 3533


In [11]:
# Split the data
X_train = df
X_test = df
y_train = y
y_test = y

In [12]:
def load_tokenized_data(method, batch_size=32):
    train_dataset = SpectralDataset(X_train, y_train, method, max_mz)
    test_dataset = SpectralDataset(X_test, y_test, method, max_mz)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

In [13]:
# Defining the transformer model
class MS_VIT(nn.Module):
    def __init__(self, num_classes, embed_dim=40, depth=12, num_heads=2):
        super(MS_VIT, self).__init__()
        self.embedding = nn.Linear(16, embed_dim)  # Assume 16 features per token
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads),
            num_layers=depth
        )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x shape: (batch_size, seq_length, 16)
        batch_size, seq_length, _ = x.shape
        x = x.view(batch_size * seq_length, -1)  # Reshape to (batch_size * seq_length, 16)
        x = self.embedding(x)  # Shape: (batch_size * seq_length, embed_dim)
        x = x.view(batch_size, seq_length, -1)  # Reshape back to (batch_size, seq_length, embed_dim)
        x = x.permute(1, 0, 2)  # Shape: (seq_length, batch_size, embed_dim)
        x = self.transformer(x)
        x = x.mean(dim=0)  # Global average pooling
        x = self.fc(x)
        return x
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_length=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2)).float() * (-math.log(10000.0)/d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0,1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x



In [14]:
# Training function
def train_model(model, train_loader, test_loader, optimizer, criterion, num_epochs=50):
    device = next(model.parameters()).device  # Get the device the model is on
    
    for epoch in range(num_epochs):
        model.train()
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
        
        # Evaluate on test set
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                outputs = model(x_batch)
                _, predicted = torch.max(outputs.data, 1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
        
        accuracy = correct / total
        print(f'Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy:.4f}')

    return model

In [15]:
results = {}
for method in ['direct', 'peak', 'fourier2', 'wavelet2']:
    print(f"\nTraining with {method} tokenization:")
    
    train_loader, test_loader = load_tokenized_data(method)
    
    # Get a sample batch to determine input dimensions
    sample_batch, _ = next(iter(train_loader))
    print('Sample shape:', sample_batch.shape)
    print('Sample first element:', sample_batch[0])
    
    num_classes = len(label_encoder.classes_)
    
    model = MS_VIT(num_classes).to(device)
    optimizer = Adam(model.parameters())
    criterion = CrossEntropyLoss()
    
    model = train_model(model, train_loader, test_loader, optimizer, criterion)
    
    # Final evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            outputs = model(x_batch)
            _, predicted = torch.max(outputs.data, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()
    
    accuracy = correct / total
    results[method] = accuracy
    print(f"Final accuracy with {method} tokenization: {accuracy:.4f}")


Training with direct tokenization:


Sample shape: torch.Size([32, 1, 221, 16])
Sample first element: tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


TypeError: MS_VIT.__init__() missing 1 required positional argument: 'num_classes'

In [None]:
# Compare results
for method, accuracy in results.items():
    print(f"{method} tokenization accuracy: {accuracy:.4f}")