In [1]:
from model import vocabulary
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict

from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

from collections import Counter
from torchtext.vocab import vocab
from torchtext.data import get_tokenizer

from torchtext.legacy.data import Field, BucketIterator
from model.seq2seq_dataset import SMILESDataset
import numpy as np

import random
import math
import time
import tqdm

from model import seq2seq_attention

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import DataStructs
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from model import multi_gen_test
import pickle

In [2]:
# Load the model and dataset
# Source and target fields share the same SMILES tokenizer and <sos>/<eos> framing.
SRC = Field(tokenize=vocabulary.SMILESTokenizer().tokenize,
            init_token='<sos>',
            eos_token='<eos>',
            lower=False,
            batch_first=True)
TRG = Field(tokenize=vocabulary.SMILESTokenizer().tokenize,
            init_token='<sos>',
            eos_token='<eos>',
            lower=False,
            batch_first=True)

# Load vocabulary
# Both paths currently point at the same pickle; they are kept as separate
# variables so source and target vocabularies can diverge later.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load vocabulary pickles that come from a trusted source.
src_vocab_path = '/home/wei/Desktop/Similarity/chembl33_dataset/vocab_pkls/stereo_experiment_vocab.pkl'
trg_vocab_path = '/home/wei/Desktop/Similarity/chembl33_dataset/vocab_pkls/stereo_experiment_vocab.pkl'
with open(src_vocab_path, 'rb') as f:
    SRC.vocab = pickle.load(f)
with open(trg_vocab_path, 'rb') as f:
    TRG.vocab = pickle.load(f)

# Model parameters
# These must match the values the checkpoint was trained with, otherwise
# load_state_dict below fails with a shape mismatch.
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1
max_length = 102  # also passed to every generation call in the cells below

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

enc = seq2seq_attention.Encoder(INPUT_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device, max_length)
dec = seq2seq_attention.Decoder(OUTPUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device, max_length)

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
model = seq2seq_attention.Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

# Load pre-trained weights
# map_location=device remaps the stored tensors onto the device selected above,
# so a checkpoint saved from a GPU run still loads on a CPU-only machine.
checkpoint_path = '/home/wei/Desktop/Similarity/self_seq2seq_try/ckpts/chembl33/shuffled_filtered_stereo/att_130eps.pt'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch to inference mode; as the last expression this also displays the model repr.
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (tok_embedding): Embedding(59, 256)
    (pos_embedding): Embedding(102, 256)
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (ff_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttentionLayer(
          (fc_q): Linear(in_features=256, out_features=256, bias=True)
          (fc_k): Linear(in_features=256, out_features=256, bias=True)
          (fc_v): Linear(in_features=256, out_features=256, bias=True)
          (fc_o): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (positionwise_feedforward): PositionwiseFeedforwardLayer(
          (fc_1): Linear(in_features=256, out_features=512, bias=True)
          (fc_2): Linear(in_features=512, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
       

In [3]:
# Read the test-set target and source SMILES files, one SMILES per line.
# splitlines() replaces split('\n') so that a trailing newline at the end of a
# file does not add a spurious empty string as the last list entry, and so that
# '\r\n' line endings do not leave a stray '\r' on every SMILES.
# presumably test.src and test.trg are line-aligned pairs — TODO confirm
test_trg = '/home/wei/Desktop/Similarity/chembl33_dataset/curriculum_learning/lev_5*ABpBAp_dataset/tvt/test.trg'
with open(test_trg, 'r') as file:
    contents = file.read()
    trg_smis = contents.splitlines()

test_src = '/home/wei/Desktop/Similarity/chembl33_dataset/curriculum_learning/lev_5*ABpBAp_dataset/tvt/test.src'
with open(test_src, 'r') as file:
    contents = file.read()
    src_smis = contents.splitlines()

## Generate SMILES using three decoding strategies

In [4]:
# Show the first source SMILES; the generation cells below all pass it as their input.
src_smis[0]

'CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccccc1'

In [13]:
# Run get_sim_smiles_decoding on the first test-set source SMILES.
# The tokenizer created here is reused by the later generation cells.
tokenizer = vocabulary.SMILESTokenizer()

# Positional arguments, in the exact order the decoding function receives them.
decoding_args = (
    src_smis[0],
    SRC,
    TRG,
    model,
    device,
    max_length,
    10,   # presumably the beam width / number of candidates — TODO confirm in multi_gen_test
    1.2,  # sampling temperature (noted as not relevant for this decoder)
    tokenizer,
)

generated_smiles = multi_gen_test.get_sim_smiles_decoding(
    *decoding_args,
    decoder_type=1,   # standard beam search
    use_masking=True,
    prefix_length=0,  # position where modifications start
)
generated_smiles

[('CN(CCCN1C(=O)c2ccccc2C1=O)Cc1ccccc1', -2.661444155643373),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccccc1', -3.2709101806673218),
 ('CN(CCCN1C(=O)c2ccncc2C1=O)Cc1ccccc1', -3.3758833391663003),
 ('CN(CCCN1C(=O)c2cccnc2C1=O)Cc1ccccc1', -3.7271298959598127),
 ('CN(CCCN1C(=O)c2ncccc2C1=O)Cc1ccccc1', -3.9511488980758642),
 ('CN(CCCCN1C(=O)c2ccccc2C1=O)Cc1ccccc1', -4.10260442714047),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)C', -4.543625919176854),
 ('CN(CCCCN1C(=O)c2cnccc2C1=O)Cc1ccccc1', -4.563300376406794),
 ('N(CCCN1C(=O)c2ccccc2C1=O)Cc1ccccc1', -4.784602154365757),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(OC)cc1', -4.8598453328297)]

In [6]:
# Run generation_with_variants on the same source SMILES.
# Relies on `tokenizer` created in the earlier decoding cell.
variant_options = {
    "variant_count": 10,
    "decoder_type": 0,
    "use_masking": True,
    "prefix_length": 0,  # starting position for modifications
}

generation_w_var = multi_gen_test.generation_with_variants(
    src_smis[0], SRC, TRG, model, device, max_length,
    10,   # presumably the beam width / number of candidates — TODO confirm in multi_gen_test
    1.2,  # presumably the sampling temperature, as in the other generation calls — TODO confirm
    tokenizer,
    **variant_options,
)
generation_w_var

[('n1ccc2c(c1)C(=O)N(CCCN(Cc1ccccc1)C)C2=O', -2.4927966390593976),
 ('n1ccc2c(c1)C(=O)N(CCN1CCN(Cc3ccccc3)CC1)C2=O', -2.501044849660025),
 ('n1ccc2c(c1)C(=O)N(CCN(Cc1ccccc1)C)C2=O', -2.7319841989934925),
 ('c1ccc2c(c1)C(=O)N(CCCN(Cc1ccccc1)C)C2=O', -2.7319938549925653),
 ('n1ccc2c(c1)C(=O)N(CCCN(Cc1ccc(F)cc1)C)C2=O', -2.734346100665184),
 ('n1ccc2c(c1)C(=O)N(CCCN1CCN(c3ccccc3)CC1)C2=O', -2.7345800753879885),
 ('n1ccc2c(c1)C(=O)N(CCN1CCN(Cc3ccccc3)CC1)C(=O)C2=O', -3.234783884001757),
 ('n1ccc2c(c1)C(=O)N(CCN1CCN(Cc3ccccc3)CC1)/C2=N\\c1ccccc1',
  -3.5768244804298934),
 ('n1ccc2c(c1)C(=O)N(CCN1CCN(Cc3ccccc3)CC1)C(=O)c1c-2cccc1',
  -3.580620744398587),
 ('n1ccc2c(c1)C(=O)N(CCN1CCN(Cc3ccccc3)CC1)C(=O)c1ccccc1C2=O',
  -3.580736384013258),
 ('C1(=O)c2c(ccnc2)C(=O)N1CCCN(Cc1ccccc1)C', -3.207442214590195),
 ('C1(=O)c2c(ccnc2)C(=O)N1CCCN(Cc1ccc(F)cc1)C', -3.2900691499755697),
 ('C1(=O)c2c(ccnc2)C(=O)N1CCCN(Cc1ccc(OC)cc1)C', -3.320331467677241),
 ('C1(=O)c2c(ccnc2)C(=O)N1CCN1CCN(Cc2ccccc2)CC1', -

In [7]:
# Run recursive_generation_with_beam on the same source SMILES.
# Relies on `tokenizer` created in the earlier decoding cell.
recursive_options = dict(
    temperature=1.2,
    decoder_type=0,
    use_masking=True,
    prefix_length=0,
)

generation_w_recur = multi_gen_test.recursive_generation_with_beam(
    src_smis[0], SRC, TRG, model, device, max_length,
    10,  # presumably the beam width / candidates kept per step — TODO confirm in multi_gen_test
    2,   # the run log prints "Step n/2", so this looks like the number of recursive steps — TODO confirm
    tokenizer,  # user-provided tokenizer
    **recursive_options,
)
generation_w_recur

Step 1/2: Processing 1 SMILES
Step 2/2: Processing 10 SMILES


[('CN(CCCN1C(=O)c2ccccc2C1=O)Cc1ccccc1', -2.661444227848265),
 ('CN(CCCCN1C(=O)c2ccccc2C1=O)Cc1ccccc1', -2.698094561612762),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(OC)cc1', -2.8940761146750584),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(OC)c1', -2.8979744167264645),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(OC)c(OC)c1', -2.9030787910600213),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(C(F)(F)F)c1', -2.9082779233861893),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(OC)c1OC(=O)C', -2.9246568729969153),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(C(C)(C)OC(C)(C)C)c1',
  -2.9313542669079617),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(OC)c1OC(=O)c1ccc(Cl)cc1',
  -3.0051907483465166),
 ('CN(CCCN1C(=O)c2cnccc2C1=O)Cc1ccc(O)c(OC)c1OC(=O)c1ccc(OC)cc1',
  -3.0077205268268488),
 ('CN(CCCN1C(=O)c2cc(OC)ccc2C1=O)Cc1ccccc1', -2.5532210136610027),
 ('CN(CCCN1C(=O)c2ccccc2C1=O)Cc1ccc(O)c(OC)c1', -2.5573696317497867),
 ('CN(CCCN1C(=O)c2ccccc2C1=O)Cc1ccc(OC)c(OC)c1', -2.562673704941964),
 ('CN(CCCN1C(=O)c2ccccc2C1=O)Cc1ccc(OC)c(