In [1]:
#packages for RXNFP
!pip install rxnfp==0.1.0 --no-deps
!pip install transformers
!pip install rdkit

Collecting rxnfp==0.1.0
  Downloading rxnfp-0.1.0-py3-none-any.whl (74.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.7/74.7 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rxnfp
Successfully installed rxnfp-0.1.0
Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m116.0 MB/s[0m eta [36m0:00:00[0m
[

In [2]:
from google.colab import drive
import os
import csv
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!pip install fair-esm
!pip install pytorch-lightning

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m92.2/93.1 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.1.0-py3-none-any.whl (774 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m774.6/774.6 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.2.0-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.2/805.2 kB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-util

In [4]:
import pandas as pd
from tqdm import tqdm
import pickle
import numpy as np
import matplotlib.pyplot as plt
import random
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch
import esm



In [5]:
from torch.utils.data import Dataset
#setting up ESM
model_esm, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model_esm.eval()  # disables dropout for deterministic results
model_esm.cuda() #push model to gpu
#setting up RXNFP
from rxnfp.transformer_fingerprints import (RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints)
model, tokenizer = get_default_model_and_tokenizer()
rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt


In [6]:


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, csv_file, transform=None):
        # Load the CSV file using pandas
        self.data_frame = pd.read_csv(csv_file, header=None)

    def __len__(self):
      return len(self.data_frame)

    def __getitem__(self, idx):
        # Get a sample from the dataset
        # for Enzyme sequences
        batch_labels, batch_strs, batch_tokens = batch_converter([("Ezy_seq", self.data_frame.iloc[idx, 1])])
        batch_tokens = batch_tokens.cuda()
        with torch.no_grad():
         token = model_esm(batch_tokens, repr_layers=[33], return_contacts=False)["representations"][33].cpu()

        sample = {'Enzyme': token[:, 1 : batch_tokens.size(1) - 1].mean(1), 'Reaction': torch.tensor([rxnfp_generator.convert(self.data_frame.iloc[idx, 2])])}
        del batch_tokens
        return sample


In [7]:
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size = 16):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage):
        self.train_dataset = MyDataset("/content/drive/My Drive/train_1.csv")
        self.val_dataset = MyDataset("/content/drive/My Drive/val_1.csv")
        self.test_dataset = MyDataset("/content/drive/My Drive/test_1.csv")

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)

    def val_dataloader(self):
        full_batch = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        return full_batch

    def test_dataloader(self):
        full_batch = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        return full_batch

In [8]:
class MiniCLIP(pl.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.lr = lr

        self.Ezy_embedder = nn.Sequential(
          nn.Linear(1280, 640),
          nn.ReLU(),
          nn.Linear(640, 128),
        )
        self.Rxn_embedder = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
        )

    def forward(self, Ezy_input, Rxn_input):
        ##get Enzyme and Reaction embeddings, dot together
        Ezy_embedding = F.normalize(self.Ezy_embedder(Ezy_input.squeeze(1)))
        Rxn_embedding = F.normalize(self.Rxn_embedder(Rxn_input.squeeze(1)))

        logits = torch.matmul(Ezy_embedding, Rxn_embedding.T)

        return logits

    def training_step(self, batch, batch_idx):
        logits = self(
            batch['Enzyme'],
            batch['Reaction'],
        )

        batch_size = batch['Enzyme'].shape[0]
        labels = torch.arange(batch_size).to(self.device) ##NOTE: to(self.device) is important here
        ##this gives us the diagonal clip loss structure

        # (loss of predicting x(Rxn) using y(Ezy)  +  loss of predicting y using x)/2
        loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

        self.log("train_loss", loss, sync_dist=True, batch_size=logits.shape[0])
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
          logits = self(
            batch['Enzyme'],
            batch['Reaction'],
          )

          batch_size = batch['Enzyme'].shape[0]
          labels = torch.arange(batch_size).to(self.device) ##NOTE: to(self.device) is important here
          ##this gives us the diagonal clip loss structure

          # (loss of predicting x(Rxn) using y(Ezy)  +  loss of predicting y using x)/2
          loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

          y_mrr = (logits.argsort(dim=0).diag() + 1).float().pow(-1).mean()
          x_mrr = (logits.argsort(dim=1).diag() + 1).float().pow(-1).mean()

          x_accuracy = logits.argmax(dim=1).eq(labels).float().mean()
          y_accuracy = logits.argmax(dim=0).eq(labels).float().mean()

          k = int(logits.shape[0] / 10)
          y_topk_accuracy = torch.any((logits.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logits.shape[0]
          x_topk_accuracy = torch.any((logits.topk(k, dim=1).indices - labels.reshape(-1, 1)) == 0, dim=1).sum() / logits.shape[0]


          self.log("val_loss", loss, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_perplexity", torch.exp(loss), sync_dist=False, prog_bar=True, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Rxn_accuracy", x_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Ezy_accuracy", y_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Rxn_top10p", x_topk_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Ezy_top10p", y_topk_accuracy, sync_dist=True, prog_bar=True, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Rxn_mrr", x_mrr, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("val_Ezy_mrr", y_mrr, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)

    def test_step(self, batch, batch_idx, dataloader_idx=0):

          logits = self(
            batch['Enzyme'],
            batch['Reaction'],
          )

          batch_size = batch['Enzyme'].shape[0]
          labels = torch.arange(batch_size).to(self.device) ##NOTE: to(self.device) is important here
          ##this gives us the diagonal clip loss structure

          # (loss of predicting x(Rxn) using y(Ezy)  +  loss of predicting y using x)/2
          loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

          y_mrr = (logits.argsort(dim=0).diag() + 1).float().pow(-1).mean()
          x_mrr = (logits.argsort(dim=1).diag() + 1).float().pow(-1).mean()

          x_accuracy = logits.argmax(dim=1).eq(labels).float().mean()
          y_accuracy = logits.argmax(dim=0).eq(labels).float().mean()

          k = int(logits.shape[0] / 10)
          y_topk_accuracy = torch.any((logits.topk(k, dim=0).indices - labels.reshape(1, -1)) == 0, dim=0).sum() / logits.shape[0]
          x_topk_accuracy = torch.any((logits.topk(k, dim=1).indices - labels.reshape(-1, 1)) == 0, dim=1).sum() / logits.shape[0]

          self.log("test_loss", loss, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_perplexity", torch.exp(loss), sync_dist=False, prog_bar=True, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Rxn_accuracy", x_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Ezy_accuracy", y_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Rxn_top10p", x_topk_accuracy, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Ezy_top10p", y_topk_accuracy, sync_dist=True, prog_bar=True, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Rxn_mrr", x_mrr, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)
          self.log("test_Ezy_mrr", y_mrr, sync_dist=True, prog_bar=False, batch_size=logits.shape[0], add_dataloader_idx=False)


    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [9]:
##implementing early stopping
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

early_stop_callback = EarlyStopping(
   monitor= 'train_loss',
   min_delta=0.00,
   patience=3,
   verbose=False,
   mode='min'
)

In [10]:
datamodule = DataModule()
trainer = pl.Trainer(callbacks=[early_stop_callback])
miniclip = MiniCLIP(lr = 0.003)
trainer.fit(miniclip, datamodule = datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type       | Params
--------------------------------------------
0 | Ezy_embedder | Sequential | 901 K 
1 | Rxn_embedder | Sequential | 32.9 K
--------------------------------------------
934 K     Trainable params
0         Non-trainable params
934 K     Total params
3.739     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [11]:
torch.save(miniclip.state_dict(), '/content/drive/My Drive/model_16_1.pth')

In [12]:
trainer.validate(miniclip, datamodule=datamodule)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 2.695436954498291,
  'val_perplexity': 14.820276260375977,
  'val_Rxn_accuracy': 0.06321839243173599,
  'val_Ezy_accuracy': 0.07543103396892548,
  'val_Rxn_top10p': 0.06321839243173599,
  'val_Ezy_top10p': 0.07543103396892548,
  'val_Rxn_mrr': 0.2111085057258606,
  'val_Ezy_mrr': 0.20760604739189148}]

In [13]:
trainer.test(miniclip, datamodule=datamodule)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 2.7726268768310547,
  'test_perplexity': 16.000612258911133,
  'test_Rxn_accuracy': 0.0625,
  'test_Ezy_accuracy': 0.0625,
  'test_Rxn_top10p': 0.0625,
  'test_Ezy_top10p': 0.0625,
  'test_Rxn_mrr': 0.21129558980464935,
  'test_Ezy_mrr': 0.21129558980464935}]