In [None]:
!pip install matchms
!pip install rdkit
!pip install torch_geometric
!pip install pickle5

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting matchms
  Downloading matchms-0.18.0-py3-none-any.whl (109 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.6/109.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting pickydict>=0.4.0
  Downloading pickydict-0.4.0-py3-none-any.whl (6.1 kB)
Collecting pyteomics>=4.2
  Downloading pyteomics-4.6-py2.py3-none-any.whl (235 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.1/235.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Collecting sparsestack>=0.4.1
  Downloading sparsestack-0.4.1-py3-none-any.whl (10 kB)
Collecting deprecated
  Downloading Deprecated-1.2.13-py2.py3-none-any.whl (9.6 kB)
Installing collected packages: pyteomics, pickydict, deprecated, sparsestack, matchms
Successfully installed deprecated-1.2.13 matchms-0.18.0 pickydict-0.4.0 pyteomics-4.6 sparsestack-0.4.1
Looking in indexes: https://pypi.org/simple, https://us-

In [None]:
from matchms.importing import load_from_msp
import numpy as np
import os
import random
import pickle
from rdkit import Chem
from rdkit.Chem import Descriptors
import matchms
from matchms import Spectrum

import matplotlib.pyplot as plt
import warnings

from rdkit.Chem.rdmolops import GetAdjacencyMatrix

# Pytorch and Pytorch Geometric
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader

import torch
from torch.nn import Linear
import torch.nn.functional as F 
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool, GATConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn as nn
from torch_geometric.data import DataLoader
from torch.optim.lr_scheduler import StepLR

In [None]:
# Mount Google Drive at /content/drive; every data and checkpoint path used
# below lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Project folder on the mounted Drive; it is also made the working directory.
BASE_DIRECTORY = "/content/drive/MyDrive/NIST_SMALL"
os.chdir(BASE_DIRECTORY)

# Experiment-wide constants.
# OUTPUT_SIZE is the width of the model's prediction heads.
# NOTE(review): TEST_DATA_SIZE and INTENSITY_POWER are not referenced by the
# code visible in this notebook - confirm they are still needed.
OUTPUT_SIZE = 1000
TEST_DATA_SIZE = 5000
INTENSITY_POWER = 0.5

In [None]:
# Load the pickled test data from Drive.
# NOTE: pickle.load can execute arbitrary code embedded in the file - only
# load files that come from a trusted source.
test_pickle_path = "/content/drive/MyDrive/NIST_SMALL/Preprocessed_test_pow_preparation_no_sparse_small.output"
with open(test_pickle_path, 'rb') as test_file:
    data_list_test = pickle.load(test_file)

In [None]:
# Keep only the first NUMBER_OF_VALIDATION entries of the loaded test data.
# NOTE(review): despite its name this constant sizes the *test* subset;
# TEST_DATA_SIZE (same value) is defined earlier but not used here - confirm
# which of the two is intended.
NUMBER_OF_VALIDATION = 5000
test_dataset = data_list_test[:NUMBER_OF_VALIDATION]

In [None]:
# Load the pickled training and validation subsets from Drive.
# NOTE: pickle.load can execute arbitrary code embedded in the file - only
# load files that come from a trusted source.
train_pickle_path = "/content/drive/MyDrive/NIST_SMALL/train_subset_pow.pkl"
validation_pickle_path = "/content/drive/MyDrive/NIST_SMALL/validation_subset_pow.pkl"

with open(train_pickle_path, 'rb') as train_file:
    train_dataset = pickle.load(train_file)

with open(validation_pickle_path, 'rb') as validation_file:
    validation_dataset = pickle.load(validation_file)

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Model hyper-parameters shared by the cells below.
NODE_FEATURES = 50      # input width of the first graph convolution
EMBEDDING_SIZE = 2000   # width of the residual MLP stack
MASS_SHIFT = 5          # bin tolerance handed to the mass masking / reversal helpers

In [None]:
def mask_prediction_by_mass(total_mass, raw_prediction, index_shift):
    """Zero out predicted bins that lie above the molecule mass plus a tolerance.

    Args:
        total_mass: tensor with one mass per batch row, e.g. [3, 4, 5]
            (mass of the whole molecule, not of a fragment). It is rounded
            to the nearest integer to obtain a bin index.
        raw_prediction: tensor of shape (batch_size, n_bins); the last axis
            enumerates the bins.
        index_shift: int, how many bins above the rounded mass are still
            allowed to be non-zero.

    Returns:
        float64 tensor shaped like ``raw_prediction`` in which every bin with
        index > round(total_mass) + index_shift has been set to zero.
    """
    data = raw_prediction.type(torch.float64)

    total_mass = torch.round(total_mass).type(torch.int64)
    # Create the bin indices on the prediction's own device so the function
    # does not depend on a notebook-level `device` global.
    indices = torch.arange(data.shape[-1], device=data.device)[None, ...]

    right_of_total_mass = indices > (total_mass[..., None] + index_shift)
    return torch.where(right_of_total_mass, torch.zeros_like(data), data)

In [None]:
def scatter_by_anchor_indices(anchor_indices, data, index_shift):
    """Reverse every row of ``data`` around its anchor bin, zero-filling the rest.

    For each batch row ``b`` and output column ``j``::

        out[b, j] = data[b, anchor_indices[b] + index_shift - j]

    whenever that source column exists (0 <= column < n_bins); otherwise
    ``out[b, j]`` is 0.

    Args:
        anchor_indices: integer tensor with one anchor bin per batch row,
            e.g. [3, 4, 5] (the rounded mass of the whole molecule).
        data: 2-D tensor of shape (batch_size, n_bins).
        index_shift: int, how far the heaviest fragment may lie above the
            anchor bin.

    Returns:
        float64 tensor of shape (batch_size, n_bins).
    """
    data = data.type(torch.float64)
    batch_size, num_data_columns = data.shape[0], data.shape[-1]

    # Work on the data's own device instead of a notebook-level global.
    columns = torch.arange(num_data_columns, device=data.device)[None, ...]
    anchors = anchor_indices.to(device=data.device, dtype=torch.int64).reshape(batch_size, 1)

    # Source column feeding each output column, and whether it exists.
    source_columns = anchors + index_shift - columns
    valid = (source_columns >= 0) & (source_columns < num_data_columns)

    # A plain gather gives the same result as multiplying by a
    # (batch*bins) x (batch*bins) sparse permutation matrix, without building
    # that matrix. Out-of-range sources are clamped only so that gather gets
    # legal indices; the mask below zeroes them.
    gathered = torch.gather(data, 1, source_columns.clamp(0, num_data_columns - 1))
    return torch.where(valid, gathered, torch.zeros_like(gathered))

In [None]:
def reverse_prediction(total_mass, raw_prediction, index_shift):
    """Mirror a raw prediction around the rounded molecule mass.

    Args:
        total_mass: tensor with one mass per batch row, e.g. [3, 4, 5]
            (mass of the whole molecule, not of a fragment).
        raw_prediction: tensor of shape (batch_size, embedding), the output
            of the network head.
        index_shift: int, how far the heaviest fragment may differ from the
            mass of the original molecule.

    Returns:
        Whatever ``scatter_by_anchor_indices`` returns for the int32-rounded
        masses, the raw prediction and the shift.
    """
    anchor_bins = torch.round(total_mass).type(torch.int32)
    return scatter_by_anchor_indices(anchor_bins, raw_prediction, index_shift)

In [None]:
def dot_product(true, pred, mass_pow=3, intensity_pow=0.6):
    """Mass- and intensity-weighted cosine similarity of two binned spectra.

    Each bin ``i`` is weighted as ``i ** mass_pow * intensity ** intensity_pow``
    and the cosine of the two weighted vectors is returned. The defaults are
    the ones the original author chose for the Stein dot product.

    Args:
        true: 1-D numpy array of reference intensities (one entry per bin).
        pred: 1-D numpy array of predicted intensities, same length.
        mass_pow: exponent applied to the bin index.
        intensity_pow: exponent applied to the intensities.

    Returns:
        The similarity as a float. When either weighted vector has zero norm
        (e.g. an all-zero prediction) the similarity is defined as 0.0
        instead of the NaN a 0/0 division would produce.
    """
    assert true.ndim == pred.ndim and true.ndim == 1
    length = true.shape[-1]
    mass = np.arange(length).astype(np.float64)

    mass_weight = mass ** mass_pow
    wl = mass_weight * pred ** intensity_pow
    wu = mass_weight * true ** intensity_pow

    pred_weighted_norm = np.sqrt(np.sum(wl ** 2))
    true_weighted_norm = np.sqrt(np.sum(wu ** 2))
    denominator = pred_weighted_norm * true_weighted_norm

    # An empty weighted spectrum has no direction; report "no similarity"
    # rather than letting NaN leak into averaged metrics.
    if denominator == 0:
        return 0.0

    return np.sum(wl * wu) / denominator

In [None]:
def validate_similarities(true, pred, mass_pow, intensity_pow):
    """Compute ``dot_product`` for every (true, pred) pair of a batch.

    Args:
        true: iterable of 1-D arrays with the reference spectra.
        pred: iterable of 1-D arrays with the predicted spectra, aligned
            with ``true``.
        mass_pow: forwarded to ``dot_product``.
        intensity_pow: forwarded to ``dot_product``.

    Returns:
        1-D float64 numpy array with one similarity per pair (empty when
        there are no pairs).
    """
    # Collect into a list and convert once; growing an array with
    # np.concatenate inside the loop copies it on every iteration.
    similarities = [
        dot_product(true_instance, pred_instance, mass_pow=mass_pow, intensity_pow=intensity_pow)
        for true_instance, pred_instance in zip(true, pred)
    ]
    return np.array(similarities, dtype=np.float64)

In [None]:
class SKIPblock(nn.Module):
    """Residual MLP block: (BatchNorm -> ReLU -> Dropout -> Linear) twice, plus input.

    The first weight-normalised linear layer maps ``in_features`` to
    ``int(hidden_features * bottleneck_factor)`` and the second maps back to
    ``in_features``, so the block output has the same shape as its input.

    Args:
        in_features: width of the input (and output) vectors.
        hidden_features: nominal hidden width before the bottleneck factor.
        bottleneck_factor: multiplier applied to ``hidden_features``.
        USE_dropout: when False the dropout stages become identity layers.
        dropout_rate: drop probability used when ``USE_dropout`` is True.
    """

    def __init__(self, in_features, hidden_features, bottleneck_factor=0.5, USE_dropout=True, dropout_rate = 0.2):
        super().__init__()
        bottleneck_features = int(hidden_features * bottleneck_factor)

        self.batchNorm1 = nn.BatchNorm1d(in_features)
        self.relu1 = nn.ReLU()
        # Always define the dropout attributes: forward() uses them
        # unconditionally, so leaving them unset with USE_dropout=False
        # would raise AttributeError.
        self.dropout1 = nn.Dropout(dropout_rate) if USE_dropout else nn.Identity()
        self.hidden1 = nn.utils.weight_norm(nn.Linear(in_features, bottleneck_features), name='weight', dim=0)

        self.batchNorm2 = nn.BatchNorm1d(bottleneck_features)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate) if USE_dropout else nn.Identity()
        self.hidden2 = nn.utils.weight_norm(nn.Linear(bottleneck_features, in_features), name='weight', dim=0)

    def forward(self, x):
        """Apply both normalise/activate/drop/project stages and add ``x`` back."""
        hidden = self.batchNorm1(x)
        hidden = self.relu1(hidden)
        hidden = self.dropout1(hidden)
        hidden = self.hidden1(hidden)

        hidden = self.batchNorm2(hidden)
        hidden = self.relu2(hidden)
        hidden = self.dropout2(hidden)
        hidden = self.hidden2(hidden)

        # Residual connection.
        hidden = hidden + x

        return hidden

In [None]:
class CONV_BIG(torch.nn.Module):
    """Graph network that predicts a binned spectrum from a molecular graph.

    Pipeline: five GCN layers with ReLU -> global mean pooling -> linear
    expansion to EMBEDDING_SIZE -> seven SKIPblock residual blocks -> three
    linear heads (forward prediction, backward prediction, gate) that are
    blended into a single non-negative float64 output of width OUTPUT_SIZE.

    Sub-modules are created in a fixed order right after seeding the global
    torch RNG, so the initial weights are reproducible; do not reorder them.
    """

    # Number of hidden GCN layers (conv1..conv4) and residual blocks (skip1..skip7).
    _N_CONV = 4
    _N_SKIP = 7

    def __init__(self):
        super().__init__()
        # NOTE: this reseeds the *global* torch RNG on every instantiation.
        torch.manual_seed(42)

        reduced_size = int(EMBEDDING_SIZE*0.15)

        # Graph convolution stack.
        self.initial_conv = GCNConv(NODE_FEATURES, reduced_size)
        for layer_no in range(1, self._N_CONV + 1):
            setattr(self, f"conv{layer_no}", GCNConv(reduced_size, reduced_size))
            setattr(self, f"reluconv{layer_no}", nn.ReLU())

        # Expand the pooled graph embedding to the residual-stack width.
        self.bottleneck = Linear(reduced_size, EMBEDDING_SIZE)

        # Residual MLP stack.
        for block_no in range(1, self._N_SKIP + 1):
            setattr(self, f"skip{block_no}", SKIPblock(EMBEDDING_SIZE, EMBEDDING_SIZE))
        self.relu_out_resnet = nn.ReLU()

        # Output heads.
        self.forward_prediction = Linear(EMBEDDING_SIZE, OUTPUT_SIZE)
        self.backward_prediction = Linear(EMBEDDING_SIZE, OUTPUT_SIZE)
        self.gate = Linear(EMBEDDING_SIZE, OUTPUT_SIZE)

        self.relu_out = nn.ReLU()

    def forward(self, x, edge_index, edge_weight, total_mass, batch_index):
        """Compute the predicted spectrum for a batch of graphs.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to every GCN layer.
            edge_weight: accepted for interface compatibility; not used.
            total_mass: one molecule mass per graph, used for masking and
                for reversing the backward head.
            batch_index: graph id of every node, used by the mean pooling.

        Returns:
            float64 tensor with one row of OUTPUT_SIZE values per graph.
        """
        # Node-level message passing.
        hidden = F.relu(self.initial_conv(x, edge_index))
        for layer_no in range(1, self._N_CONV + 1):
            convolution = getattr(self, f"conv{layer_no}")
            activation = getattr(self, f"reluconv{layer_no}")
            hidden = activation(convolution(hidden, edge_index))

        # Graph-level embedding.
        hidden = self.bottleneck(gap(hidden, batch_index))

        for block_no in range(1, self._N_SKIP + 1):
            hidden = getattr(self, f"skip{block_no}")(hidden)
        hidden = self.relu_out_resnet(hidden)

        # Forward head: bins above the molecule mass (+ MASS_SHIFT) are zeroed.
        forward_part = mask_prediction_by_mass(
            total_mass, self.forward_prediction(hidden), MASS_SHIFT)

        # Backward head: prediction is mirrored around the molecule mass.
        backward_part = reverse_prediction(
            total_mass, self.backward_prediction(hidden), MASS_SHIFT)

        # Per-bin gate in (0, 1) blending the two heads.
        gate_weight = F.sigmoid(self.gate(hidden))
        blended = gate_weight * forward_part + (1. - gate_weight) * backward_part

        return self.relu_out(blended).type(torch.float64)

# Folder (inside BASE_DIRECTORY) that receives checkpoints and loss files.
MODEL_NAME = "CONV_BIG_POW"
MODEL_SAVE = os.path.join(BASE_DIRECTORY, MODEL_NAME)
os.makedirs(MODEL_SAVE, mode=0o777, exist_ok=True)

# Build the network and report its layout and size.
model = CONV_BIG()
print(model)
parameter_count = sum(parameter.numel() for parameter in model.parameters())
print("Number of parameters: ", parameter_count)

CONV_BIG(
  (initial_conv): GCNConv(50, 300)
  (conv1): GCNConv(300, 300)
  (reluconv1): ReLU()
  (conv2): GCNConv(300, 300)
  (reluconv2): ReLU()
  (conv3): GCNConv(300, 300)
  (reluconv3): ReLU()
  (conv4): GCNConv(300, 300)
  (reluconv4): ReLU()
  (bottleneck): Linear(in_features=300, out_features=2000, bias=True)
  (skip1): SKIPblock(
    (batchNorm1): BatchNorm1d(2000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU()
    (dropout1): Dropout(p=0.2, inplace=False)
    (hidden1): Linear(in_features=2000, out_features=1000, bias=True)
    (batchNorm2): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU()
    (dropout2): Dropout(p=0.2, inplace=False)
    (hidden2): Linear(in_features=1000, out_features=2000, bias=True)
  )
  (skip2): SKIPblock(
    (batchNorm1): BatchNorm1d(2000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU()
    (dropout1): Dropout(p=0.2, inplace=Fa

In [None]:

import warnings
# NOTE: this silences *every* warning for the rest of the session.
warnings.filterwarnings("ignore")

######################################
#  LOSS / OPTIMISATION
######################################

# Huber loss, Adam, and a step schedule that halves the learning rate
# every 100 scheduler steps.
loss_fn = torch.nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = StepLR(optimizer, gamma=0.5, step_size=100)

# Move the model to the selected device (GPU when available).
model = model.to(device)

######################################
#  DATA LOADERS
######################################

NUM_GRAPHS_PER_BATCH = 64
_loader_options = dict(batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)

loader = DataLoader(train_dataset, **_loader_options)
validation_loader = DataLoader(validation_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)

# Checkpointing / reporting cadence, in epochs.
SAVE_EVERY_X_EPOCH = 10
REPORT_EVERY_X_EPOCH = 1

def train(loader):
    """Run one optimisation pass over ``loader`` and return the batch losses.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``device``.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``molecular_weight``, ``batch`` and ``y`` and
            supporting ``.to(device)``.

    Returns:
        1-D float64 numpy array with one loss value per batch (empty when
        the loader yields nothing).
    """
    # Dropout and BatchNorm must be in training mode here, whatever mode an
    # earlier evaluation pass may have left the model in.
    model.train()

    batch_losses = []
    for batch in loader:
        # Use GPU
        batch = batch.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Passing the node features and the connection info
        pred = model(batch.x.float(), batch.edge_index, batch.edge_attr, batch.molecular_weight, batch.batch)
        # Calculating the loss and gradients
        loss = loss_fn(pred, batch.y)
        loss.backward()
        # Update using the gradients
        optimizer.step()
        # .item() copies the scalar to the host; appending to a list avoids
        # re-allocating a numpy array on every batch.
        batch_losses.append(loss.item())
    return np.array(batch_losses, dtype=np.float64)

print("Starting training...")

NUM_EPOCHS = 300


def evaluate_similarity(data_loader, mass_pow=1.0, intensity_pow=0.5):
    """Return the per-sample dot-product similarity of ``model`` on a loader.

    The model is switched to eval mode (Dropout off, BatchNorm using its
    running statistics) and autograd is disabled for the pass; the previous
    training/eval mode is restored afterwards.

    Args:
        data_loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``molecular_weight``, ``batch`` and ``y``.
        mass_pow: forwarded to ``validate_similarities``.
        intensity_pow: forwarded to ``validate_similarities``.

    Returns:
        1-D numpy array with one similarity per sample (empty when the
        loader yields nothing).
    """
    was_training = model.training
    model.eval()
    batch_scores = []
    try:
        with torch.no_grad():
            for batch in data_loader:
                batch = batch.to(device)
                pred = model(batch.x.float(), batch.edge_index, batch.edge_attr, batch.molecular_weight, batch.batch)
                batch_scores.append(
                    validate_similarities(batch.y.detach().cpu().numpy(),
                                          pred.detach().cpu().numpy(),
                                          mass_pow=mass_pow, intensity_pow=intensity_pow))
    finally:
        # Hand the model back in the mode it was found in.
        model.train(was_training)

    if not batch_scores:
        return np.array([])
    return np.concatenate(batch_scores)


for epoch in range(NUM_EPOCHS):
    loss = train(loader)
    # Advance the schedule only after the epoch's optimizer updates; stepping
    # it first skips the first value of the learning-rate schedule.
    scheduler.step()

    pred_test_similarity = np.array([])
    pred_validation_similarity = np.array([])

    if epoch % REPORT_EVERY_X_EPOCH == 0:
        pred_test_similarity = evaluate_similarity(test_loader)
        pred_validation_similarity = evaluate_similarity(validation_loader)

        print(f"Epoch {epoch} | Test DotSimilarity is {pred_test_similarity.mean()}")
        print(f"Epoch {epoch} | Validation DotSimilarity is {pred_validation_similarity.mean()}")
        print(f"Epoch {epoch} | Train Loss {loss.mean()}")
        print()

    if epoch % SAVE_EVERY_X_EPOCH == 0:
        SAVE_PATH = f"{epoch}.pt"

        # No evaluation ran this epoch when the reporting cadence skipped it;
        # store NaN explicitly instead of taking the mean of an empty array.
        mean_test_similarity = (pred_test_similarity.mean()
                                if pred_test_similarity.size else float("nan"))

        # Save model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'metadata': {"loss" : "HuberLoss",
                         "Dataset": "Preprocessed_test_pow_preparation_no_sparse_small",
                         "test_similarities": mean_test_similarity}
        }, os.path.join(MODEL_SAVE, SAVE_PATH))

        # NOTE(review): despite the file name, only this epoch's mean training
        # loss is stored, not the history of all epochs so far.
        LOSS_FILE = f"all_loss_until_{epoch}.output"
        with open(os.path.join(MODEL_SAVE, LOSS_FILE), 'wb') as fid:
            pickle.dump(loss.mean(), fid)
        

Starting training...
Epoch 0 | Test DotSimilarity is 0.7575590955187312
Epoch 0 | Validation DotSimilarity is 0.7570808120262554
Epoch 0 | Train Loss 0.29368624125264664

Epoch 1 | Test DotSimilarity is 0.7736780899462883
Epoch 1 | Validation DotSimilarity is 0.7740394085129328
Epoch 1 | Train Loss 0.27286283791980903

Epoch 2 | Test DotSimilarity is 0.7826235766160017
Epoch 2 | Validation DotSimilarity is 0.7819804544445013
Epoch 2 | Train Loss 0.26379257940022854

Epoch 3 | Test DotSimilarity is 0.7893297120932488
Epoch 3 | Validation DotSimilarity is 0.7876204016558318
Epoch 3 | Train Loss 0.2570771191414507

Epoch 4 | Test DotSimilarity is 0.7943824740690875
Epoch 4 | Validation DotSimilarity is 0.794442870108146
Epoch 4 | Train Loss 0.2515146077402022

Epoch 5 | Test DotSimilarity is 0.7991542991346566
Epoch 5 | Validation DotSimilarity is 0.798042525855848
Epoch 5 | Train Loss 0.24738780563230559

Epoch 6 | Test DotSimilarity is 0.8025219885316567
Epoch 6 | Validation DotSimilari

In [None]:
# Last cell of the notebook: hand the Colab runtime back after training.
# NOTE(review): presumably `runtime.unassign()` disconnects and frees the
# allocated VM - confirm against the google.colab documentation.
from google.colab import runtime
runtime.unassign()