In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn


# Input tables: per-drug features (includes the 'Smiles' column used later) and
# pairwise drug-drug interactions (id1, name1, id2, name2, interaction).
# NOTE(review): the DDI frame shows an 'Unnamed: 0' column, i.e. the CSV was saved
# with its index; consider index_col=0 once downstream code is checked.
DATASET_DIR = './dataset'
features_path = f'{DATASET_DIR}/complete_drug_features.csv'
ddis_path = f'{DATASET_DIR}/complete_ddis.csv'
df_features, df_ddi = (pd.read_csv(csv_path) for csv_path in (features_path, ddis_path))



In [3]:
# Preview the drug-drug interaction table (pandas truncates the display of long frames).
df_ddi

Unnamed: 0,id1,name1,id2,name2,interaction
0,DB06605,Apixaban,DB00001,Lepirudin,Drug A may increase the anticoagulant activiti...
1,DB06695,Dabigatran etexilate,DB00001,Lepirudin,Drug A may increase the anticoagulant activiti...
2,DB01254,Dasatinib,DB00001,Lepirudin,The risk or severity of bleeding and hemorrhag...
3,DB00001,Lepirudin,DB01609,Deferasirox,The risk or severity of gastrointestinal bleed...
4,DB00001,Lepirudin,DB01586,Ursodeoxycholic acid,The risk or severity of bleeding and bruising ...
...,...,...,...,...,...
515083,DB17088,Famtozinameran,DB14845,Filgotinib,The therapeutic efficacy of Drug A can be decr...
515084,DB17088,Famtozinameran,DB15091,Upadacitinib,The therapeutic efficacy of Drug A can be decr...
515085,DB17088,Famtozinameran,DB16650,Deucravacitinib,The therapeutic efficacy of Drug A can be decr...
515086,DB17088,Famtozinameran,DB16703,Belumosudil,The therapeutic efficacy of Drug A can be decr...


In [5]:
# Dump one SMILES string per line, in df_features row order, for the embedding step.
# Missing values are written via str(), so they appear as literal text (e.g. 'nan').
smiles_lines = (str(smiles) for smiles in df_features['Smiles'])
with open('SMILES.txt', 'w') as smiles_file:
    smiles_file.write('\n'.join(smiles_lines))
     

In [9]:
import os

# Kernel working directory; relative paths such as './dataset/...' and 'SMILES.txt'
# in this notebook resolve against it.
script_folder = os.getcwd()
script_folder

'/homes/Nathan/Implementation'

In [None]:
import os
from pathlib import Path

# Path to the BART vocabulary file under <cwd>/Embeddings; read its raw text.
vocab_path = f"{os.getcwd()}/Embeddings/bart_vocab.txt"
text = Path(vocab_path).read_text()

In [None]:
import torch
import pickle
import argparse
import pandas as pd
from rdkit import Chem
from pathlib import Path

import molbart.util as util
from molbart.decoder import DecodeSampler
from molbart.models.pre_train import BARTModel
from molbart.data.datasets import ReactionDataset
from molbart.data.datamodules import FineTuneReactionDataModule
import os


# Defaults for the --batch_size and --num_beams command-line options.
# num_beams is only copied onto model.num_beams; the encoding path does not visibly use it.
DEFAULT_BATCH_SIZE, DEFAULT_NUM_BEAMS = 20, 10


class SmilesError(Exception):
    """Error reporting a SMILES string that RDKit could not parse.

    Args:
        idx: index reported in the message (presumably the SMILES' position in its input).
        smi: the offending SMILES string.
    """

    def __init__(self, idx, smi):
        super().__init__(f"RDKit could not parse smiles {smi} at index {idx}")


def build_dataset(args):
    """Read one SMILES per line from ``args.reactants_path`` into a ReactionDataset.

    Surrounding whitespace is stripped and blank or whitespace-only lines are
    skipped, so dataset positions do not necessarily equal file line numbers.
    The same list is passed as both reactants and products; presumably only the
    reactant (encoder) side matters for embedding.

    Args:
        args: namespace providing ``reactants_path``.

    Returns:
        ReactionDataset built from the cleaned SMILES list.
    """
    text = Path(args.reactants_path).read_text()
    # strip() keeps stray spaces/tabs out of the tokeniser input and lets
    # whitespace-only lines be dropped along with truly empty ones.
    smiles = [line.strip() for line in text.splitlines()]
    smiles = [smi for smi in smiles if smi]
    return ReactionDataset(smiles, smiles)


def build_datamodule(args, dataset, tokeniser, max_seq_len):
    """Wrap ``dataset`` in a FineTuneReactionDataModule whose test split is every item.

    No validation items are used; the test split covers ``range(len(dataset))``
    so the whole dataset is served by the module's test dataloader.
    """
    every_index = range(len(dataset))
    return FineTuneReactionDataModule(
        dataset, tokeniser, args.batch_size, max_seq_len,
        val_idxs=[], test_idxs=every_index,
    )


def concat_tensor(tensor_1, tensor_2, dim=0):
    """Concatenate two tensors along ``dim``.

    Args:
        tensor_1: first tensor.
        tensor_2: second tensor; must match ``tensor_1`` in every dimension except ``dim``.
        dim: concatenation dimension (default 0, the same default as ``torch.cat``).

    Returns:
        A new tensor with ``tensor_1`` followed by ``tensor_2`` along ``dim``.

    Raises:
        RuntimeError: from ``torch.cat`` when the non-``dim`` sizes disagree.
    """
    return torch.cat((tensor_1, tensor_2), dim=dim)

def _concat_batches(chunks, batch_dim=1, pad_value=0.0):
    """Concatenate per-batch tensors along ``batch_dim``, padding other dims to a common size."""
    ndim = chunks[0].dim()
    target_shape = [max(chunk.size(d) for chunk in chunks) for d in range(ndim)]
    padded = []
    for chunk in chunks:
        shape = list(target_shape)
        shape[batch_dim] = chunk.size(batch_dim)
        if list(chunk.shape) == shape:
            padded.append(chunk)
            continue
        out = chunk.new_full(shape, pad_value)
        # Copy the chunk into the leading corner; the remainder stays as padding.
        out[tuple(slice(0, size) for size in chunk.shape)] = chunk
        padded.append(out)
    return torch.cat(padded, dim=batch_dim)


def smiles_embedding(model, smiles_loader, batch_dim=1):
    """Encode every batch from ``smiles_loader`` and return all outputs as one tensor.

    Args:
        model: molbart model; this function calls ``to``, ``eval``, ``freeze``,
            ``unfreeze``, ``encode`` and ``encoder`` on it.
        smiles_loader: iterable of dict batches with ``"encoder_input"`` and
            ``"encoder_pad_mask"`` entries.
        batch_dim: dimension along which per-batch outputs are concatenated.
            Defaults to 1, assuming a (seq_len, batch, d_model) layout - TODO confirm
            against molbart.

    Returns:
        CPU tensor holding the outputs of all batches in loader order. Batches whose
        non-batch dimensions (e.g. padded sequence length) are shorter than the
        largest seen are zero-padded to match.

    Raises:
        ValueError: if ``smiles_loader`` yields no batches.
    """
    device = "cuda:0" if util.use_gpu else "cpu"
    model = model.to(device)
    model.eval()

    # Freezing the weights reduces the amount of memory leakage in the transformer.
    # Freeze once up front and always unfreeze, even if encoding raises.
    model.freeze()
    batch_outputs = []
    try:
        with torch.no_grad():
            for batch in smiles_loader:
                device_batch = {
                    key: val.to(device) if isinstance(val, torch.Tensor) else val
                    for key, val in batch.items()
                }
                encode_input = {
                    "encoder_input": device_batch["encoder_input"],
                    "encoder_pad_mask": device_batch["encoder_pad_mask"],
                }
                embedding = model.encode(encode_input)
                # NOTE(review): if molbart's `encode` already runs the encoder, this
                # second pass re-encodes its output (without the pad mask) - confirm.
                memory = model.encoder(embedding)
                # Move each batch to CPU so outputs for the whole dataset do not
                # accumulate in GPU memory.
                batch_outputs.append(memory.cpu())
    finally:
        model.unfreeze()

    if not batch_outputs:
        raise ValueError("smiles_loader yielded no batches to embed")
    # Join the per-batch tensors (the original intent, previously left as a TODO).
    return _concat_batches(batch_outputs, batch_dim=batch_dim)


def encode_smiles(args):
    """Load tokeniser, SMILES and BART model as configured by ``args``, then embed the SMILES.

    Args:
        args: namespace with ``vocab_path``, ``chem_token_start_idx``, ``reactants_path``,
            ``num_beams``, ``batch_size`` plus whatever ``util.load_bart`` reads.

    Returns:
        The result of ``smiles_embedding`` over the datamodule's test dataloader.
    """
    print("Building tokeniser...")
    tokeniser = util.load_tokeniser(args.vocab_path, args.chem_token_start_idx)
    print("Finished tokeniser...")

    print("Reading SMILES...")
    dataset = build_dataset(args)
    print("Finished SMILES...")

    # The sampler starts at the library default length and is aligned with the
    # loaded checkpoint's max_seq_len once the model is available.
    sampler = DecodeSampler(tokeniser, util.DEFAULT_MAX_SEQ_LEN)

    print("Loading model...")
    bart = util.load_bart(args, sampler)
    bart.num_beams = args.num_beams
    seq_limit = bart.max_seq_len
    sampler.max_seq_len = seq_limit
    print("Finished model...")

    print("Building data loader...")
    datamodule = build_datamodule(args, dataset, tokeniser, seq_limit)
    datamodule.setup()
    loader = datamodule.test_dataloader()
    print("Finished loader...")

    print("Embedding SMILES...")
    return smiles_embedding(bart, loader)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    working_dir = os.getcwd()
    # These fragments start with "/" and are appended to working_dir by plain string
    # concatenation (os.path.join would discard working_dir for an absolute part).
    smile_path = '/Embeddings/SMILES_1.txt'
    model_path = "/Embeddings/models/pre-trained/combined-large/step=1000000.ckpt"
    vocab_path = "/Embeddings/bart_vocab.txt"
    # Program level args
    parser.add_argument("--reactants_path", type=str, default=working_dir + smile_path)  # Each line is an input SMILES
    parser.add_argument("--model_path", type=str, default=working_dir + model_path)
    # TODO: products_path is parsed but never used; the embeddings are not saved to it.
    parser.add_argument("--products_path", type=str, default="embedding.pickle")
    parser.add_argument("--vocab_path", type=str, default=working_dir + vocab_path)
    parser.add_argument("--chem_token_start_idx", type=int, default=util.DEFAULT_CHEM_TOKEN_START)

    # Model args
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
    parser.add_argument("--num_beams", type=int, default=DEFAULT_NUM_BEAMS)

    # In a Jupyter kernel, __name__ is "__main__" and sys.argv carries the kernel's own
    # flags (e.g. `-f <connection file>`); parse_args() would exit on them, so tolerate
    # unknown arguments and report them instead.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Ignoring unrecognised arguments: {unknown_args}")
    embed = encode_smiles(args)
    print(embed.size())
    print('--------------')
    print('--------------')