In [41]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
import os
from SmilesPE.pretokenizer import atomwise_tokenizer
from os import path as p
import torchtext
from collections import OrderedDict
import math
import torch.nn.functional as f
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict


In [81]:
START_TOKEN = '<start>'
END_TOKEN = '<end>'
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
VOCAB_SIZE = 32
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'

In [82]:
def getVocab(smiles_dir, max_size=100):
    '''max_size should be greater than 4'''
    all_toks_dict = defaultdict(int)

    for f_name in os.listdir(smiles_dir):
        with open(os.path.join(smiles_dir, f_name)) as f:
            for line in f:
                if 'smile' in line: continue
                for tok in atomwise_tokenizer(line.split()[0]):
                    all_toks_dict[tok] += 1

    
    x = sorted([(e[0], e[1]) for e in all_toks_dict.items()], key=lambda e:-e[1])[:VOCAB_SIZE-4]
    ordered_dict = OrderedDict({e[0]: e[1] for e in x})
    print(ordered_dict)
    special_tokens = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]
    vocab = torchtext.vocab.vocab(ordered_dict=ordered_dict, specials=special_tokens)
    vocab.set_default_index(vocab[UNK_TOKEN])
    return vocab

In [49]:
# Dataset creation
class SmilesDataset(Dataset):
    def __init__(self, smiles_dir, vocab, max_len=None): # start and end tokens are added
        '''
        if vocab is None : getVocab() is used
        if max_len is None: take max_len from dataset
        if max_len == 'avg': take average length from dataset
        '''
        self.vocab = vocab

        smiles_files = os.listdir(smiles_dir)
        # reading smiles from files
        tokens = []
        for file_name in smiles_files:
            with open(p.join(smiles_dir, file_name)) as f:
                for l in f:
                    if "smile" not in l:
                        tokens.append(atomwise_tokenizer(l.split()[0]))
        
        if max_len is None: max_len = max([len(sen) for sen in tokens])
        if max_len == 'avg': max_len = int(sum([len(sen) for sen in tokens])/len(tokens))

        # stripping
        tokens = [sen[: max_len-2] for sen in tokens]

        # adding start and end tokens
        tokens = [[START_TOKEN] + sen + [END_TOKEN] for sen in tokens]

        # padding
        tokens = [sen + [PAD_TOKEN]*(max_len-len(sen)) for sen in tokens]

        # addention mask false at <pad> tokem, true at non pad token
        self.pad_masks = torch.tensor([[PAD_TOKEN==tok for tok in sen] for sen in tokens])

        # converting to index
        self.data = torch.tensor([self.vocab(sen) for sen in tokens], dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {'idx': self.data[idx], 'pad_mask': self.pad_masks[idx]}

In [86]:
VOCAB_SIZE = 44

In [87]:
vocab = getVocab('smiles', VOCAB_SIZE)

OrderedDict({'c': 23438656, 'C': 17638257, '(': 12944012, ')': 12944012, '1': 8613576, 'O': 6589370, '=': 4349916, 'N': 3933209, '2': 3724200, 'n': 2283405, '3': 1063448, '[C@H]': 977539, '[C@@H]': 856458, 'F': 806720, '/': 582009, 'S': 528853, 'Cl': 482588, '[nH]': 259200, '4': 245628, 's': 225655, 'o': 217697, '\\': 137672, '[C@]': 128017, '#': 125602, '.': 125217, '[O-]': 116643, '[C@@]': 114773, '[N+]': 102860, 'Br': 93663, '5': 57606, 'P': 45820, '-': 40535, '[n+]': 33506, '6': 16166, 'I': 13157, '[Na+]': 13109, '[Br-]': 6907, '[S+]': 6557, '7': 6496, '[Cl-]': 5098})


In [88]:
dataset = SmilesDataset('smiles', vocab, max_len=100)

In [17]:
vocab['<unk>']

1

In [89]:
classes = [0]*100

all_x = []
for i in range(0, len(dataset)):
    if (i+1)%100000 == 0: print(i)
    all_x.append(dataset[i]['idx'])

99999
199999
299999
399999
499999
599999
699999
799999
899999
999999
1099999
1199999
1299999
1399999
1499999
1599999
1699999
1799999
1899999


In [90]:
all_x = torch.concat(all_x)

In [91]:
for i in range(100):
    classes[i] = (all_x == i).sum()

In [93]:
classes[1]

tensor(35764)

In [60]:
len(dataset)

1941411

In [94]:
split_ratio = 0.8
split_seed = 42
train_ds, valid_ds = random_split(dataset, [int(len(dataset)*split_ratio), len(dataset)-int(len(dataset)*split_ratio)], generator=torch.Generator().manual_seed(split_seed))

# model

In [62]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    
    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x.transpose(0, 1)
        x = x + self.pe[:x.size(0)]
        x = x.transpose(0, 1)
        return self.dropout(x)


class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward,  dropout=0.1):
        super().__init__()
        self.multihead = nn.MultiheadAttention(embed_dim=d_model, 
                                               num_heads=nhead, 
                                               dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, attn_mask, pad_mask):
        x = x + self.dropout1(self.multihead(x, x, x, key_padding_mask=pad_mask, attn_mask=attn_mask, need_weights=False)[0])
        x = self.norm1(x)

        x = x + self.dropout3(self.linear2(self.dropout2(self.linear1(x))))
        x = self.norm2(x)

        return x
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers=1, dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, nhead, dim_feedforward) for i in range(num_layers)])
        self.linear = nn.Linear(d_model, vocab_size)
        
    def getDevice(self):
        ''' return device of model
        '''
        device = next(self.parameters()).device

        if device.index is None:
            return device.type
        else:
            return device.type + ":" + str(device.index) 

    def generate_square_subsequent_mask(self, sz:int) -> torch.Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        device_ordinal = {-1: cpu, 0..inf=>gpu}
        """
        
        return torch.triu(
            torch.full((sz, sz), float('-inf'), dtype=torch.float, device=self.getDevice()),
            diagonal=1,
        )


    def forward(self, x, pad_mask=None):
        ''' x = batch * seq_len
            pad_mask = batch*seq_len
        '''
        seq_len = x.shape[1]
        attn_mask = self.generate_square_subsequent_mask(seq_len)
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, attn_mask=attn_mask, pad_mask=pad_mask)

        x = self.linear(x)
        return x
    
    def generateSmiles(self, batch_size, vocab, max_len=100):
        x = torch.full((batch_size, 1), vocab[START_TOKEN]).to(self.getDevice())
        
        for i in range(1, max_len+1):
            out = self.forward(x)[:, -1]
            out = f.softmax(out, dim=1)
            out = torch.multinomial(out, 1)
            x = torch.cat((x, out), dim=1)
        
        x = x.detach().cpu().tolist()
        # converting idx to smiles
        results = []
        for i in range(batch_size):
            sentance = vocab.lookup_tokens(x[i])
            
            new_sentance = [] # removing special chars
            for i in range(1, len(sentance)):
                e = sentance[i]
                if e==END_TOKEN: break
                if e==START_TOKEN or e==PAD_TOKEN or e==UNK_TOKEN:
                    new_sentance = ['invalid']
                    break
                
                new_sentance.append(e)
            
            results.append(new_sentance)
        
        return results
    


In [95]:
model = Decoder(vocab_size=VOCAB_SIZE, 
        d_model=256, 
        nhead=16, 
        dim_feedforward=128, 
        num_layers=3, 
        dropout=0.1, 
        max_len=100)

model = model.to(DEVICE)
batch_size = 512

train_loader = DataLoader(train_ds, batch_size)
val_loader = DataLoader(valid_ds, batch_size)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

train_result = []
val_result = []
epoch  = 10

In [108]:
smiles = model.generateSmiles(batch_size=100, vocab=vocab)

for e in smiles:
    smile = "".join(e)
    print(smile)
    

n1(C2[C@@H]([C@@H]([C@H](O2)CO)O)O)c2c(nc1)c(ncn2)N(C)C
C1[C@@H](C[n+]2c(C1)sc(n2)NC(=O)Cn1cc(nn1)Nc1ccc(cc1Cl)Cl)O.[Br-]
CCN(C)c1cc(cc(Nc2ncnc3[nH]c(cc23)C#N)c1)C(=O)N1CCC[C@H]1CO
O=C1N(C(=O)NC(=O)C1C)C(=O)c1cc2nc[nH]c2cc1
N1(C[C@H]([C@@H](C1)O)N)C(=O)CCSCC
Fc1ccc(NC(=O)c2nc3c(cc2C)cccc3)cc1
C12=CC[C@H]3[C@]4([C@H]2[C@](CC(C4)([C@H](CCCC(=O)C[C@@H]1C(C)C)C)C5)C)(CC[C@@H](C(=O)O)C)C)C
O=C([C@@H]1CN(C[C@H]1c1ccc(s1)c1cccc(c1)C(F)(F)F)CC(=O)Nc1ccc(nc1)C(F)(F)F)c1ccc(cc1)C(F)(F)F
C1(=C(NC(=C(C1c1cc(c(c(c1)OC)OC)OC)C(=O)NCC)C)C)C(=O)c1ccncc1
Clc1c(Cl)ccc(NC(=O)COc2ccc(S(=O)(=O)N3CCOCC3)cc2)c1
Br.Brc1c(OC[C@H](CCCCN2CCN(CC2)C)O)cccc1.Cl
c1c(c(c2c(n1)nc(cc2C)SCC1CCN(CC1)C)C)C
n1(cnc(c1)c1sccc1)c1cc2c(cc1)ccs2
[Br-].c1c(cc2c(c1)c(c[nH]2)CCC#C)NC(=O)c1ccc(cc1)Cl
CC(C)(O)Cc1[nH]c2cncc(N(C)C(C)C)c2n1
C[n+]1coc(n1)/C=N/OC
c1(c(cc2c(c1)c1c([nH]2)cnc(c1)Oc1ccc(cc1)O)O)OC
C1(NC(=O)C(N(C1=O)c1ccccc1)C)(NC(=O)C1CC1)C1CC
Clc1c(/C=C/2\N(C(=O)c3oc(cc3)c3c(Cl)cccc3)Cc2cccc2)cccc1
FC(F)(F)c1cc(N2CCN(CC2)c2

In [111]:
val_result

[0.4691753212328171,
 0.442591048512063,
 0.4294968842674778,
 0.4221685946537415,
 0.4163481026534506,
 0.41281202883117285,
 0.41007964122593793,
 0.4071080360330926,
 0.40542209823769226,
 0.4029807088324832,
 0.40209448192273517,
 0.4011382240316142,
 0.39953339504316077,
 0.398818506200323,
 0.39824705318656994,
 0.3973310890169483,
 0.396317089656282,
 0.3951411776548946,
 0.3955495490622615,
 0.39444996936833276,
 0.3940663937680492,
 0.3933146997638371,
 0.3934417799684527,
 0.3923208140651542,
 0.3924812851216011,
 0.39193774337668036,
 0.3913349131507522,
 0.391029222209463,
 0.39056187844559137,
 0.3905919809115263]

In [110]:
train_result

[0.5854172040226433,
 0.49285685749885355,
 0.4729587226889993,
 0.46224466937697867,
 0.4550276200336772,
 0.4497474878044135,
 0.4456245037938862,
 0.44239162675785876,
 0.43975009340436816,
 0.43765332727111633,
 0.4357142901648482,
 0.43413144653575864,
 0.43265009342877586,
 0.4314353059272508,
 0.43034084716869486,
 0.42929243282141777,
 0.4283655682773911,
 0.4274879257067616,
 0.42676096733219937,
 0.42603677773483506,
 0.42543337051618546,
 0.4247567388735559,
 0.42419050116246676,
 0.4236377225396817,
 0.4231097109139398,
 0.42267184156922033,
 0.42219669078633276,
 0.4217716067164999,
 0.4213778995414588,
 0.4209647733538891]

In [109]:
for i in range(10):
    tl = train(model, optimizer, criterion, train_loader)
    vl = valid(model, criterion, val_loader)

    train_result.append(tl)
    val_result.append(vl)

    print("Epoch:", i+1, "train loss:", tl, "val loss:", vl)




0.4177311360836029
0.41916099190711975
0.4311732351779938
0.41949591040611267
0.42207640409469604
0.43419331312179565
0.4199746549129486
0.4349762499332428
0.42442595958709717
0.40924519300460815
0.4224224388599396
0.42739105224609375
0.4201262593269348
0.41302481293678284
0.4397747814655304
0.4295125901699066
0.42633238434791565
0.42718830704689026
0.42273250222206116
0.4282897710800171
0.4235767424106598
0.432208776473999
0.4370172619819641
0.4273785948753357
0.4169257879257202
0.4105018973350525
0.4319404065608978
0.4151376187801361
0.43899813294410706
0.4210602343082428
Epoch: 1 train loss: 0.42543337051618546 val loss: 0.3940663937680492
0.41826513409614563
0.4200685918331146
0.43118324875831604
0.41723471879959106
0.4194599986076355
0.42983534932136536
0.4214036762714386
0.4338279962539673
0.4267996549606323
0.4090276062488556
0.4208979308605194
0.42840349674224854
0.41869232058525085
0.4145174026489258
0.4396442472934723
0.43003275990486145
0.42692843079566956
0.4248432219028473

In [69]:
# Training function
def train(model, optimizer, criterion, train_loader):
    model.train()
    total_loss = 0
    count = 200
    for e in train_loader:
        x = e['idx'].to(DEVICE)
        pad_mask = e['pad_mask'].to(DEVICE)

        x_input = x[:, :-1]
        pad_mask = pad_mask[:, :-1]
        y_expected = x[:, 1:]
        output = model(x_input, pad_mask)
        #print(output.shape, y_expected.shape)

        output = torch.flatten(output, start_dim=0, end_dim=1)
        y_expected = torch.flatten(y_expected, start_dim=0, end_dim=1)
        loss = criterion(output, y_expected)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        if (count+1)%100 == 0: print(loss.item())
        
    total_loss = total_loss/len(train_loader)
    
    return total_loss

# valid function
def valid(model, criterion, val_loader):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for e in val_loader:
            x = e['idx'].to(DEVICE)
            pad_mask = e['pad_mask'].to(DEVICE)

            x_input = x[:, :-1]
            pad_mask = pad_mask[:, :-1]
            y_expected = x[:, 1:]

            output = model(x_input, pad_mask)
            
            output = torch.flatten(output, start_dim=0, end_dim=1)
            y_expected = torch.flatten(y_expected, start_dim=0, end_dim=1)
            loss = criterion(output, y_expected)

            total_loss += loss.item()
        
    total_loss = total_loss/len(val_loader)
   
    return total_loss
    

In [44]:
m = Decoder(10, 10, 1, 10)

m.to(device='cuda')

Decoder(
  (embedding): Embedding(10, 10)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder_layers): ModuleList(
    (0): DecoderLayer(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=10, out_features=10, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=10, out_features=10, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [46]:
m.generateSmiles(100, getVocab(), 10)

[['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 [],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['=', 'C', '/'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 [],
 [],
 ['invalid'],
 ['invalid'],
 ['\\'],
 ['invalid'],
 ['=', '\\'],
 ['invalid'],
 ['/', 'C', '/', 'O'],
 [],
 ['\\'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['/'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['\\', '/'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['

In [None]:
m.

In [61]:
m.to("cuda:1")

Decoder(
  (embedding): Embedding(10, 10)
  (decoder_layers): ModuleList(
    (0): DecoderLayer(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=10, out_features=10, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=10, out_features=10, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [None]:
m.

In [56]:
print(next(m.parameters()).device.index)

None


In [55]:
None

In [48]:
x = next(m.parameters()).device

In [49]:
x

device(type='cuda', index=1)

In [51]:
x.index

1

In [16]:
torch.full((2, 3), 4)

tensor([[4, 4, 4],
        [4, 4, 4]])

In [4]:
torch.tril(torch.ones(4, 4))

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [5]:
torch.triu(
        torch.full((4, 4), float('-inf'), dtype=torch.float),
        diagonal=1,
    )

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [9]:
mask.get_device()

-1

In [6]:
mask = torch.tril(torch.ones(4, 4) == 1)
mask = mask.float()
mask = mask.masked_fill(mask == 0, float('-inf')) 
mask = mask.masked_fill(mask == 1, float(0.0))

mask

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [15]:
mask.to(device='cpu')

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [None]:
torch.tril(torch.ones(4, 4) == 1)

tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])

In [5]:
torch.ones(4, 4) == 1

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

In [29]:
torch.nn.TransformerDecoderLayer(d_model=10, nhead=1)

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
  )
  (linear1): Linear(in_features=10, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=10, bias=True)
  (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)