In [41]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
import os
from SmilesPE.pretokenizer import atomwise_tokenizer
from os import path as p
import torchtext
from collections import OrderedDict
import math
import torch.nn.functional as f
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict


In [34]:
START_TOKEN = '<start>'
END_TOKEN = '<end>'
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'
VOCAB_SIZE = 100
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'

In [42]:
def getVocab(smiles_dir, max_size=100):
    '''max_size should be greater than 4'''
    all_toks_dict = defaultdict(int)

    for f_name in os.listdir(smiles_dir):
        with open(os.path.join(smiles_dir, f_name)) as f:
            for line in f:
                if 'smile' in line: continue
                for tok in atomwise_tokenizer(line.split()[0]):
                    all_toks_dict[tok] += 1


    x = sorted([(e[0], e[1]) for e in all_toks_dict.items()], key=lambda e:-e[1])[:VOCAB_SIZE-4]
    ordered_dict = OrderedDict({e[0]: e[1] for e in x})
    special_tokens = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN]
    vocab = torchtext.vocab.vocab(ordered_dict=ordered_dict, specials=special_tokens)
    vocab.set_default_index(vocab[UNK_TOKEN])
    return vocab

In [49]:
# Dataset creation
class SmilesDataset(Dataset):
    def __init__(self, smiles_dir, vocab, max_len=None): # start and end tokens are added
        '''
        if vocab is None : getVocab() is used
        if max_len is None: take max_len from dataset
        if max_len == 'avg': take average length from dataset
        '''
        self.vocab = vocab

        smiles_files = os.listdir(smiles_dir)
        # reading smiles from files
        tokens = []
        for file_name in smiles_files:
            with open(p.join(smiles_dir, file_name)) as f:
                for l in f:
                    if "smile" not in l:
                        tokens.append(atomwise_tokenizer(l.split()[0]))
        
        if max_len is None: max_len = max([len(sen) for sen in tokens])
        if max_len == 'avg': max_len = int(sum([len(sen) for sen in tokens])/len(tokens))

        # stripping
        tokens = [sen[: max_len-2] for sen in tokens]

        # adding start and end tokens
        tokens = [[START_TOKEN] + sen + [END_TOKEN] for sen in tokens]

        # padding
        tokens = [sen + [PAD_TOKEN]*(max_len-len(sen)) for sen in tokens]

        # addention mask false at <pad> tokem, true at non pad token
        self.pad_masks = torch.tensor([[PAD_TOKEN==tok for tok in sen] for sen in tokens])

        # converting to index
        self.data = torch.tensor([self.vocab(sen) for sen in tokens], dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {'idx': self.data[idx], 'pad_mask': self.pad_masks[idx]}

In [50]:
vocab = getVocab('smiles')

In [51]:
dataset = SmilesDataset('smiles', vocab, max_len=100)

In [17]:
vocab['<unk>']

1

In [52]:
classes = [0]*100

all_x = []
for i in range(0, len(dataset)):
    if (i+1)%100000 == 0: print(i)
    all_x.append(dataset[i]['idx'])

99999
199999
299999
399999
499999
599999
699999
799999
899999
999999
1099999
1199999
1299999
1399999
1499999
1599999
1699999
1799999
1899999


In [53]:
all_x = torch.concat(all_x)

In [54]:
for i in range(100):
    classes[i] = (all_x == i).sum()

In [60]:
len(dataset)

1941411

In [61]:
split_ratio = 0.8
split_seed = 42
train_ds, valid_ds = random_split(dataset, [int(len(dataset)*split_ratio), len(dataset)-int(len(dataset)*split_ratio)], generator=torch.Generator().manual_seed(split_seed))

# model

In [62]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    
    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        x = x.transpose(0, 1)
        x = x + self.pe[:x.size(0)]
        x = x.transpose(0, 1)
        return self.dropout(x)


class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward,  dropout=0.1):
        super().__init__()
        self.multihead = nn.MultiheadAttention(embed_dim=d_model, 
                                               num_heads=nhead, 
                                               dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, attn_mask, pad_mask):
        x = x + self.dropout1(self.multihead(x, x, x, key_padding_mask=pad_mask, attn_mask=attn_mask, need_weights=False)[0])
        x = self.norm1(x)

        x = x + self.dropout3(self.linear2(self.dropout2(self.linear1(x))))
        x = self.norm2(x)

        return x
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, dim_feedforward, num_layers=1, dropout=0.1, max_len=100):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=max_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, nhead, dim_feedforward) for i in range(num_layers)])
        self.linear = nn.Linear(d_model, vocab_size)
        
    def getDevice(self):
        ''' return device of model
        '''
        device = next(self.parameters()).device

        if device.index is None:
            return device.type
        else:
            return device.type + ":" + str(device.index) 

    def generate_square_subsequent_mask(self, sz:int) -> torch.Tensor:
        r"""Generate a square causal mask for the sequence.

        The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
        device_ordinal = {-1: cpu, 0..inf=>gpu}
        """
        
        return torch.triu(
            torch.full((sz, sz), float('-inf'), dtype=torch.float, device=self.getDevice()),
            diagonal=1,
        )


    def forward(self, x, pad_mask=None):
        ''' x = batch * seq_len
            pad_mask = batch*seq_len
        '''
        seq_len = x.shape[1]
        attn_mask = self.generate_square_subsequent_mask(seq_len)
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, attn_mask=attn_mask, pad_mask=pad_mask)

        x = self.linear(x)
        return x
    
    def generateSmiles(self, batch_size, vocab, max_len=100):
        x = torch.full((batch_size, 1), vocab[START_TOKEN]).to(self.getDevice())
        
        for i in range(1, max_len+1):
            out = self.forward(x)[:, -1]
            out = f.softmax(out, dim=1)
            out = torch.multinomial(out, 1)
            x = torch.cat((x, out), dim=1)
        
        x = x.detach().cpu().tolist()
        # converting idx to smiles
        results = []
        for i in range(batch_size):
            sentance = vocab.lookup_tokens(x[i])
            
            new_sentance = [] # removing special chars
            for i in range(1, len(sentance)):
                e = sentance[i]
                if e==END_TOKEN: break
                if e==START_TOKEN or e==PAD_TOKEN or e==UNK_TOKEN:
                    new_sentance = ['invalid']
                    break
                
                new_sentance.append(e)
            
            results.append(new_sentance)
        
        return results
    


In [64]:
model = Decoder(vocab_size=VOCAB_SIZE, 
        d_model=300, 
        nhead=10, 
        dim_feedforward=50, 
        num_layers=5, 
        dropout=0.1, 
        max_len=100)

model = model.to(DEVICE)
batch_size = 512

train_loader = DataLoader(train_ds, batch_size)
val_loader = DataLoader(valid_ds, batch_size)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)

train_result = []
val_result = []
epoch  = 10

In [68]:
smiles = model.generateSmiles(batch_size=10, vocab=vocab)

for e in smiles:
    smile = "".join(e)
    print(smile)
    

Fc1c(cc(Oc2nc3CCN(C23)CC(=O)CC=C)ccc1)C
c1(ccc2c(OC3C(CCN2C(=O)O)CCC3)c1Cl)Cl
O(c1c(OC)cc(OC(=O)OC)cc1)Cc1ccccc1
N(c1[nH]c2cc(cc2c1)C)c1c(Nc2cc(ccc2)c1)C(=O)NC1=NCCc2ccccc2OC1
[C@]12(c3cc4c5[n+](O1)c(CCC5)Ccc4)(c(OC[C@H]1O[C@@H]2C(=C)NCC1)cccc3)C
n1(c(=O)n2c(c(c3n(-c4c(nc3ccccc4)c3ccccc4)C)CCc2)c2cnns1)c1ccccc1
c1ccc2c(c1)C(=O)c1n(c2[C@@H]2NS(=O)(=O)C)cccn1C
c1(C/C=C/CC(=O)c2c[nH]c3nccn3c2C)cncc1S(=O)(=O)c1c(cc(C2C2)c2c1c(cnc2Cl)Cl)ccccc1
S(c1cc(C(=O)O)ccc1)C(=O)O/C=C/c1ccccc1
Cc1[nH]c2cc(F)ccc2n1c1cccc(NC(=O)N2CC(C2=O)Nc2cc(OC(F)(F)F)nc3)ccn1


In [28]:
optimizer = optim.Adam(model.parameters(), lr=0.0001)
batch_size = 512

In [67]:
for i in range(10):
    tl = train(model, optimizer, criterion, train_loader)
    vl = valid(model, criterion, val_loader)

    train_result.append(tl)
    val_result.append(vl)

    print("Epoch:", i+1, "train loss:", tl, "val loss:", vl)






KeyboardInterrupt: 

In [66]:
# Training function
def train(model, optimizer, criterion, train_loader):
    model.train()
    total_loss = 0
    count = 200
    for e in train_loader:
        x = e['idx'].to(DEVICE)
        pad_mask = e['pad_mask'].to(DEVICE)

        x_input = x[:, :-1]
        pad_mask = pad_mask[:, :-1]
        y_expected = x[:, 1:]
        output = model(x_input, pad_mask)
        #print(output.shape, y_expected.shape)

        output = torch.flatten(output, start_dim=0, end_dim=1)
        y_expected = torch.flatten(y_expected, start_dim=0, end_dim=1)
        loss = criterion(output, y_expected)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        # if (count+1)%200 == 0: print(loss.item())
        
    total_loss = total_loss/len(train_loader)
    
    return total_loss

# valid function
def valid(model, criterion, val_loader):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for e in val_loader:
            x = e['idx'].to(DEVICE)
            pad_mask = e['pad_mask'].to(DEVICE)

            x_input = x[:, :-1]
            pad_mask = pad_mask[:, :-1]
            y_expected = x[:, 1:]

            output = model(x_input, pad_mask)
            
            output = torch.flatten(output, start_dim=0, end_dim=1)
            y_expected = torch.flatten(y_expected, start_dim=0, end_dim=1)
            loss = criterion(output, y_expected)

            total_loss += loss.item()
        
    total_loss = total_loss/len(val_loader)
   
    return total_loss
    

In [44]:
m = Decoder(10, 10, 1, 10)

m.to(device='cuda')

Decoder(
  (embedding): Embedding(10, 10)
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder_layers): ModuleList(
    (0): DecoderLayer(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=10, out_features=10, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=10, out_features=10, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [46]:
m.generateSmiles(100, getVocab(), 10)

[['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 [],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['=', 'C', '/'],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 [],
 ['invalid'],
 [],
 [],
 ['invalid'],
 ['invalid'],
 ['\\'],
 ['invalid'],
 ['=', '\\'],
 ['invalid'],
 ['/', 'C', '/', 'O'],
 [],
 ['\\'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['/'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['invalid'],
 ['invalid'],
 [],
 ['\\', '/'],
 ['invalid'],
 ['invalid'],
 [],
 ['invalid'],
 ['

In [None]:
m.

In [61]:
m.to("cuda:1")

Decoder(
  (embedding): Embedding(10, 10)
  (decoder_layers): ModuleList(
    (0): DecoderLayer(
      (multihead): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
      )
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
      (linear1): Linear(in_features=10, out_features=10, bias=True)
      (dropout2): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=10, out_features=10, bias=True)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    )
  )
)

In [None]:
m.

In [56]:
print(next(m.parameters()).device.index)

None


In [55]:
None

In [48]:
x = next(m.parameters()).device

In [49]:
x

device(type='cuda', index=1)

In [51]:
x.index

1

In [16]:
torch.full((2, 3), 4)

tensor([[4, 4, 4],
        [4, 4, 4]])

In [4]:
torch.tril(torch.ones(4, 4))

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [5]:
torch.triu(
        torch.full((4, 4), float('-inf'), dtype=torch.float),
        diagonal=1,
    )

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [9]:
mask.get_device()

-1

In [6]:
mask = torch.tril(torch.ones(4, 4) == 1)
mask = mask.float()
mask = mask.masked_fill(mask == 0, float('-inf')) 
mask = mask.masked_fill(mask == 1, float(0.0))

mask

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [15]:
mask.to(device='cpu')

tensor([[0., -inf, -inf, -inf],
        [0., 0., -inf, -inf],
        [0., 0., 0., -inf],
        [0., 0., 0., 0.]])

In [None]:
torch.tril(torch.ones(4, 4) == 1)

tensor([[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])

In [5]:
torch.ones(4, 4) == 1

tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

In [29]:
torch.nn.TransformerDecoderLayer(d_model=10, nhead=1)

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
  )
  (linear1): Linear(in_features=10, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=10, bias=True)
  (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)