In [None]:
!pip install matchms
!pip install rdkit
!pip install torch_geometric
!pip install pickle5

In [None]:
from matchms.importing import load_from_msp
import numpy as np
import os
import random
from rdkit import Chem
from rdkit.Chem import Descriptors
import matchms
import pickle
from matchms import Spectrum

import matplotlib.pyplot as plt
import warnings

from rdkit.Chem.rdmolops import GetAdjacencyMatrix

# Pytorch and Pytorch Geometric
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn as nn



In [None]:
# Colab only: mount Google Drive so the NIST files under /content/drive/MyDrive/NIST
# (raw MSP data, preprocessed pickles, model checkpoints) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# All project files live in one Drive folder; the relative paths below are
# resolved against it because we chdir into it.
BASE_DIRECTORY = "/content/drive/MyDrive/NIST"
os.chdir(BASE_DIRECTORY)

# Raw NIST training spectra.
# NOTE(review): matchms documents load_from_msp as returning a generator, so the
# file is presumably not read until nist_dataset_org is iterated — confirm.
TRAIN_PATH = 'train.msp'
nist_dataset_org = load_from_msp(TRAIN_PATH, metadata_harmonization=False)

# Experiment constants.
TEST_DATA_SIZE = 5000   # not referenced in the visible cells
OUTPUT_SIZE = 1000      # number of output bins predicted by the GCN heads
INTENSITY_POWER = 0.5   # not referenced in the visible cells

In [None]:
# Preprocessed samples consumed by the training loop (which reads .x, .edge_index,
# .molecular_weight and .y from each item).
# SECURITY: pickle.load can execute arbitrary code — only load files you produced yourself.
preprocessed_path = os.path.join(BASE_DIRECTORY, "Preprocessed_train.output")
with open(preprocessed_path, 'rb') as handle:
    data_list = pickle.load(handle)

In [None]:
def mask_prediction_by_mass(total_mass, raw_prediction, index_shift):
    """Zero out predicted intensities to the right of the (shifted) molecular mass.

    Column ``j`` of row ``i`` is kept when
    ``j <= round(total_mass[i]) + index_shift`` and set to zero otherwise,
    because the heaviest fragment can differ from the weight of the whole
    molecule by at most ``index_shift`` bins.

    Args:
        total_mass: tensor (or array-like) of molecular weights of the whole
            molecules, not the fragments; expected shape ``(batch_size,)``.
        raw_prediction: tensor of shape ``(batch_size, num_bins)``, e.g. the
            output of the forward prediction head.
        index_shift: int, how many bins past the rounded mass are still allowed.

    Returns:
        Tensor with the same shape, dtype and device as ``raw_prediction``.
    """
    device = raw_prediction.device
    total_mass = torch.round(torch.as_tensor(total_mass, device=device)).type(torch.int32)
    # Build the column index on the prediction's device: a CPU arange compared
    # against a CUDA mass tensor raises a device-mismatch error.
    indices = torch.arange(raw_prediction.shape[-1], device=device)[None, ...]

    right_of_total_mass = indices > (total_mass[..., None] + index_shift)
    return torch.where(right_of_total_mass, torch.zeros_like(raw_prediction),
                       raw_prediction)

In [None]:
def reverse_prediction(total_mass, raw_prediction, index_shift):
    """Mirror each prediction row around its rounded molecular mass.

    The whole-molecule weight is rounded to an integer bin and used as the
    per-row anchor for ``scatter_by_anchor_indices``, so that

        output[i][j] = raw_prediction[i][round(total_mass[i]) - j + index_shift]

    wherever that source column exists, and 0 elsewhere.

    Args:
        total_mass: tensor of molecular weights (whole molecule, not fragment),
            one per row of ``raw_prediction``.
        raw_prediction: tensor of shape ``(batch_size, num_bins)``, e.g. the
            output of the backward prediction head.
        index_shift: int, how far the heaviest fragment may lie beyond the
            molecular weight.

    Returns:
        The mirrored tensor produced by ``scatter_by_anchor_indices``.
    """
    anchor_bins = torch.round(total_mass).to(dtype=torch.int32)
    return scatter_by_anchor_indices(anchor_bins, raw_prediction, index_shift)

In [None]:
def scatter_by_anchor_indices(anchor_indices, data, index_shift):
    """Flip each row of ``data`` horizontally and anchor it at ``anchor_indices``.

    For every row ``i`` and output column ``j``::

        output[i][j] = data[i][anchor_indices[i] - j + index_shift]

    when the source column ``anchor_indices[i] - j + index_shift`` lies in
    ``[0, num_columns)``, and ``0`` otherwise.

    Implemented with ``torch.gather`` plus a validity mask. Everything is
    created on ``data``'s device, so the function works on GPU, and gradients
    flow back to ``data``.

    Args:
        anchor_indices: integer tensor or array-like with one entry per row of
            ``data`` (a single value is broadcast to every row). In this
            notebook it is the rounded molecular weight.
        data: tensor of shape ``(batch_size, num_columns)``.
        index_shift: int offset added to every anchor.

    Returns:
        float64 tensor of shape ``(batch_size, num_columns)`` on ``data``'s device.
    """
    # The result has always been float64; keep that so downstream dtypes do not change.
    data = data.to(torch.float64)
    batch_size, num_columns = data.shape[0], data.shape[-1]

    anchors = torch.as_tensor(anchor_indices, device=data.device).to(torch.int64)
    columns = torch.arange(num_columns, device=data.device)
    # Source column read by each output position; shape (batch_size, num_columns).
    source_columns = (anchors.reshape(-1, 1) - columns + index_shift).expand(
        batch_size, num_columns)
    valid = (source_columns >= 0) & (source_columns < num_columns)

    # Clamp so gather never indexes out of range; invalid positions are zeroed below.
    gathered = torch.gather(data, 1, source_columns.clamp(0, max(num_columns - 1, 0)))
    return torch.where(valid, gathered, torch.zeros_like(gathered))

In [None]:
# GCN hyper-parameters.
NODE_FEATURES = 84   # in_channels of the first GCNConv (width of each node's feature vector)
embedding_size = 64  # hidden width of every GCNConv layer
embedding_in = 32    # in_features of GCN.out, a layer that forward() does not use
MASS_SHIFT = 5       # index_shift passed to the mass-masking / reversal helpers


In [None]:
class GCN(torch.nn.Module):
    """Graph convolutional network that predicts a spectrum vector from a molecular graph.

    Four ``GCNConv`` layers embed the nodes. Max- and mean-pooling reduce them
    to a ``2 * embedding_size`` vector per graph. A bidirectional head then
    produces ``OUTPUT_SIZE`` non-negative values:

    * forward branch  - linear head, zeroed beyond ``round(mass) + MASS_SHIFT``;
    * backward branch - linear head, mirrored around the rounded mass;
    * gate            - per-bin sigmoid that mixes the two branches.
    """

    def __init__(self):
        super(GCN, self).__init__()
        # Seeds the global torch RNG so weight initialisation is reproducible
        # (side effect: it also resets the RNG for anything that runs afterwards).
        torch.manual_seed(42)

        # Message-passing layers.
        self.initial_conv = GCNConv(NODE_FEATURES, embedding_size)
        self.conv1 = GCNConv(embedding_size, embedding_size)
        self.conv2 = GCNConv(embedding_size, embedding_size)
        self.conv3 = GCNConv(embedding_size, embedding_size)

        # Bidirectional heads; their input is [max-pool || mean-pool] of the node embeddings.
        self.forward_prediction = Linear(embedding_size * 2, OUTPUT_SIZE)
        self.backward_prediction = Linear(embedding_size * 2, OUTPUT_SIZE)
        self.gate = Linear(embedding_size * 2, OUTPUT_SIZE)

        # Not used by forward(); kept so the parameter set and the state_dict keys
        # written to checkpoints by train() stay unchanged.
        self.out = Linear(embedding_in, OUTPUT_SIZE)

    def forward(self, x, edge_index, total_mass, batch_index):
        """Predict one output vector per graph in the batch.

        Args:
            x: node feature matrix; its last dimension must be ``NODE_FEATURES``.
            edge_index: graph connectivity, passed to every ``GCNConv``.
            total_mass: molecular weight per graph, expected shape ``(num_graphs,)``
                (the training loop passes ``batch.molecular_weight``).
            batch_index: node-to-graph assignment used by the pooling layers.

        Returns:
            ``(out, hidden)``:
            * ``out`` is the non-negative prediction of shape
              ``(num_graphs, OUTPUT_SIZE)``. It is float64 because the backward
              branch is float64.
            * ``hidden`` is the pooled graph embedding of shape
              ``(num_graphs, 2 * embedding_size)``.
        """
        hidden = x
        for conv in (self.initial_conv, self.conv1, self.conv2, self.conv3):
            hidden = F.relu(conv(hidden, edge_index))

        # Global pooling: stack max- and mean-aggregation into one graph vector.
        hidden = torch.cat([gmp(hidden, batch_index),
                            gap(hidden, batch_index)], dim=1)

        # Forward branch: bins beyond the molecular mass (+ MASS_SHIFT) are zeroed.
        forward_prediction_hidden = mask_prediction_by_mass(
            total_mass, self.forward_prediction(hidden), MASS_SHIFT)

        # Backward branch: mirrored around the rounded molecular mass.
        backward_prediction_hidden = reverse_prediction(
            total_mass, self.backward_prediction(hidden), MASS_SHIFT)

        # Per-bin gate mixing the two branches.
        gate_hidden = torch.sigmoid(self.gate(hidden))

        out = gate_hidden * forward_prediction_hidden + (1. - gate_hidden) * backward_prediction_hidden
        out = F.relu(out)

        return out, hidden

# Checkpoints for this run go to <BASE_DIRECTORY>/<MODEL_NAME>.
MODEL_NAME = "GCN_testing"
MODEL_SAVE = os.path.join(BASE_DIRECTORY, MODEL_NAME)
model = GCN()
os.makedirs(MODEL_SAVE, mode=0o777, exist_ok=True)

print(model)
n_params = sum(param.numel() for param in model.parameters())
print("Number of parameters: ", n_params)

GCN(
  (initial_conv): GCNConv(84, 64)
  (conv1): GCNConv(64, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (forward_prediction): Linear(in_features=128, out_features=1000, bias=True)
  (backward_prediction): Linear(in_features=128, out_features=1000, bias=True)
  (gate): Linear(in_features=128, out_features=1000, bias=True)
  (out): Linear(in_features=32, out_features=1000, bias=True)
)
Number of parameters:  437920


In [None]:
# The torch_geometric loader (not torch.utils.data.DataLoader) collates graph
# Data objects into one batched graph; torch_geometric.data.DataLoader is the
# deprecated alias of this class.
from torch_geometric.loader import DataLoader
import warnings
# NOTE(review): this silences every warning, including library deprecation notices.
warnings.filterwarnings("ignore")

####################################
# HUBER LOSS
####################################

# Huber loss: quadratic for small residuals, linear for large ones (default delta=1.0).
loss_fn = torch.nn.HuberLoss()

# Use GPU for training; move the model first, then build the optimizer over its parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

# Disjoint train / test split, so test metrics are never measured on samples
# the model was trained on.
TRAIN_FRACTION = 0.8
NUM_GRAPHS_PER_BATCH = 64
data_size = len(data_list)
split_index = int(data_size * TRAIN_FRACTION)
loader = DataLoader(data_list[:split_index],
                    batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
# Evaluation order does not matter, so keep it deterministic.
test_loader = DataLoader(data_list[split_index:],
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

def train(data, number_of_epoch, save_every_x_epoch):
    """Train the global ``model`` for ``number_of_epoch`` full passes over ``loader``.

    Args:
        data: currently unused; batches come from the module-level ``loader``.
            Kept so existing calls such as ``train(data_list, ...)`` keep working.
        number_of_epoch: number of full passes over ``loader``.
        save_every_x_epoch: write a checkpoint (``<epoch>.pt``) and the loss
            history (``all_loss_until_<epoch>.output``) into ``MODEL_SAVE``
            whenever the 0-based epoch index is a positive multiple of this
            value. A value ``<= 0`` disables checkpointing.

    Returns:
        ``(losses, embedding)``:
        * ``losses`` holds the mean training loss of each epoch as a float
          (``nan`` for an epoch with no batches).
        * ``embedding`` is the pooled graph embedding of the last batch seen,
          or ``None`` if no batch was processed.
    """
    print("Starting training...")
    losses = []
    embedding = None
    model.train()
    for epoch in range(number_of_epoch):
        batch_losses = []
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            pred, embedding = model(batch.x.float(), batch.edge_index,
                                    batch.molecular_weight, batch.batch)
            # NOTE(review): inspected targets (data_list[0].y) contain -inf, which
            # would make this loss value inf — confirm how preprocessing should
            # encode absent peaks.
            loss = loss_fn(pred, batch.y)
            loss.backward()
            optimizer.step()
            # .item() stores a plain float, so the autograd graph is not kept alive.
            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
        losses.append(epoch_loss)

        # Checkpoint once per epoch (not per batch) every save_every_x_epoch epochs.
        if save_every_x_epoch > 0 and epoch > 0 and epoch % save_every_x_epoch == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, os.path.join(MODEL_SAVE, f"{epoch}.pt"))

            loss_file = os.path.join(MODEL_SAVE, f"all_loss_until_{epoch}.output")
            with open(loss_file, 'wb') as fid:
                pickle.dump(losses, fid)

        if epoch % 100 == 0:
            print(f"Epoch {epoch} | Train Loss {epoch_loss}")
    return losses, embedding

In [None]:
# Recomputed here so this cell also works when run on its own.
MODEL_SAVE = os.path.join(BASE_DIRECTORY, MODEL_NAME)
os.makedirs(MODEL_SAVE, mode=0o777, exist_ok=True)

NUM_EPOCHS = 2
SAVE_EVERY_X_EPOCHS = 100
# Bind the results instead of making the call the cell's last expression;
# otherwise Jupyter dumps the whole embedding tensor into the notebook.
train_losses, last_embedding = train(data_list, NUM_EPOCHS, SAVE_EVERY_X_EPOCHS)
train_losses

Starting training...
torch.Size([64, 128])
torch.Size([64, 1000])


([], tensor([[1.3546e-01, 0.0000e+00, 6.2873e-02,  ..., 0.0000e+00, 6.1763e-03,
          0.0000e+00],
         [1.6911e-01, 0.0000e+00, 7.8670e-02,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.1454e-01, 1.8143e-03, 1.9634e-02,  ..., 6.7606e-04, 1.3382e-02,
          5.8276e-05],
         ...,
         [1.6513e-01, 9.4535e-03, 2.1554e-02,  ..., 0.0000e+00, 2.8282e-02,
          0.0000e+00],
         [2.3073e-01, 0.0000e+00, 2.0955e-02,  ..., 0.0000e+00, 1.0293e-02,
          0.0000e+00],
         [2.9246e-01, 0.0000e+00, 9.8434e-02,  ..., 8.3160e-03, 4.6487e-04,
          0.0000e+00]], grad_fn=<CatBackward0>))

In [None]:
# Peek at the regression target of the first preprocessed sample.
# NOTE(review): the printed values contain -inf (presumably log of zero-intensity
# bins) — confirm against the preprocessing step before training on them.
first_sample = data_list[0]
first_sample.y

tensor([[   -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,    -inf,    -inf,  0.0000,  1.3863,  3.2573,  4.3429,
          0.0000,  0.0000,  1.0986,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf,  2.1961,  2.0782,  2.1961,  2.3970,  1.3863,    -inf,
         -0.6931,    -inf,    -inf,    -inf,    -inf,  2.6383,  3.9881,  3.6625,
          1.0986,  2.0782,  3.6368,  1.7901, -0.6931,    -inf,    -inf,    -inf,
            -inf,  1.9445,  4.5634,  4.2895,  3.4956,  3.0436,  0.6931,  1.3863,
            -inf,    -inf,    -inf,    -inf,  0.0000,  2.7720,  3.9110,  4.9264,
          4.8743,  2.9947,  2.1961,  0.6931,    -inf,  0.0000,    -inf,    -inf,
            -inf,  1.6094,  3.3663,  3.0901,  3.7367,  5.5125,  4.4059,  3.2181,
          1.3863,  2.1961,  0.6931,    -inf,    -inf,  0.6931,  1.0986,  0.6931,
            -inf,  2.8320,  1.7901,  3.4002,  5.0295,  3.5545,    -inf,  3.0901,
            -inf,    -inf,  