In [1]:
import torch
from torch.utils.data import Dataset
import random
from rdkit import Chem
import pickle

from calc_property import calculate_property

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SMILESDataset(Dataset):
    def __init__(self, data_path, data_length=None, shuffle=False):
        with open(data_path,'r') as f:
            lines = f.readlines()
        self.data = [l.strip() for l in lines]

        with open('./normalize.pkl', 'rb') as w:
            norm = pickle.load(w)
        self.property_mean, self.property_std = norm

        if shuffle:
            random.shuffle(self.data)
        
        ## Why need this line? ##
        if data_length is not None:
            self.data = self.data[data_length[0]: data_length[1]]
            
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        smiles = 'Q' + self.data[idx]
        properties = (calculate_property(smiles[1:])-self.property_std) / self.property_mean
        
        return smiles, properties

In [3]:
sampleDataset = SMILESDataset(data_path='./data/pubchem-1m-simple.txt', data_length=None, shuffle=False)
sample = sampleDataset.__getitem__(0)

sample

('QCN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1',
 tensor([  0.5682,   0.4030,   0.6777,   0.6901,   0.6560,   0.6904,   0.6890,
           0.6158,   0.6544,   0.3731,   0.6623,  -1.0214,   0.5386,  -6.4296,
           0.6370,   0.7405,   0.7696,   0.8429,   0.2686,   1.6701,   0.6852,
           0.6222,   0.6510,   0.5670,   0.3475,   0.6824,   0.8768,   0.8768,
          -0.1836,   2.4966,   0.1489,   0.7014,   0.6364,   0.2755,   0.4956,
          -2.3909,   0.4995,   0.0625,   0.7835,  -1.1521,   0.3690,   0.4610,
           0.6192,   0.3638, -29.1460,   0.2716,  -2.7021,   0.9223,   0.2647,
           0.6883,   0.5710,   0.4035,   1.0784]))

In [4]:
with open('./data/pubchem-1m-simple.txt', 'r') as f:
    lines = f.readlines()
    
with open('./normalize.pkl', 'rb') as w:
    norm = pickle.load(w)
    print(norm)

    

(tensor([ 2.0210e+00,  7.5875e+02,  1.7544e+01,  1.4052e+01,  1.4723e+01,
         1.1664e+01,  8.1625e+00,  8.8090e+00,  6.2020e+00,  6.9684e+00,
         4.2730e+00,  4.9467e+00,  2.9353e+00,  3.5438e+00,  3.5288e+02,
         1.1951e+00,  1.9034e+00,  2.5342e+00,  4.1357e-01, -2.1272e+00,
         2.4343e+01,  3.3127e+02,  1.7561e+01,  7.6679e+00,  4.3503e+00,
         1.4605e+02,  1.1184e+01,  1.1184e+01,  1.9147e-01, -9.6040e-01,
         2.6799e+00,  9.4141e+01,  3.5333e+02,  1.7290e+00,  5.1585e+00,
         2.6400e-01,  5.4614e-01,  8.1014e-01,  1.1892e+00,  7.1781e-01,
         1.9070e+00,  4.0690e+00,  1.3547e+00,  6.2828e+00,  4.5800e-03,
         5.4205e+00,  2.0371e-01,  3.9774e-01,  6.0145e-01,  1.2960e+02,
         2.7171e+00,  6.8247e+01,  6.1739e-01]), tensor([5.9750e-01, 4.0613e+02, 5.8108e+00, 4.7454e+00, 4.7843e+00, 4.0368e+00,
        2.9280e+00, 3.1273e+00, 2.3723e+00, 3.8311e+00, 1.8252e+00, 9.7076e+00,
        1.4360e+00, 2.5803e+01, 1.1541e+02, 2.3501e-01, 2.95

#### Define Custom Tokenizer 

In [5]:
import pandas as pd
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file= "./vocab_bpe_300.txt" ,lowercase=False, do_basic_tokenize=False)

In [6]:
df = pd.read_fwf('./vocab_bpe_300.txt', header=None)
df

Unnamed: 0,0
0,[PAD]
1,[UNK]
2,[CLS]
3,[SEP]
4,[MASK]
...,...
295,##c12
296,##[Si
297,##c(C(=O
298,##[nH+]


In [31]:
# from rdkit import Chem
# from transformers import BertTokenizer

# from typing import List
# import re

# ## REGEX_PATTERN ##
# SMI_REGEX_PATTERN =  r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

# class RegexTokenizer:
#     """Run regex tokenization"""

#     def __init__(self, regex_pattern: str=SMI_REGEX_PATTERN) -> None:
#         """Constructs a RegexTokenizer.
#         Args:
#             regex_pattern: regex pattern used for tokenization.
#             suffix: optional suffix for the tokens. Defaults to "".
#         """
#         self.regex_pattern = regex_pattern
#         self.regex = re.compile(self.regex_pattern)

#     def tokenize(self, text: str) -> List[str]:
#         """Regex tokenization.
#         Args:
#             text: text to tokenize.
#         Returns:
#             extracted tokens separated by spaces.
#         """
#         tokens = [token for token in self.regex.findall(text)]
#         return tokens
    
    
# class SMILESTokenizer(BertTokenizer):
#     def __init__(self, 
#         vocab_file: str,
#         unk_token: str = "[UNK]",
#         sep_token: str = "[SEP]",
#         pad_token: str = "[PAD]",
#         cls_token: str = "[CLS]",
#         mask_token: str = "[MASK]",
#         do_lower_case = False,
#         **kwargs,
#         ) -> None:
        
#         super().__init__(
#             vocab_file=vocab_file,
#             unk_token=unk_token,
#             sep_token=sep_token,
#             pad_token=pad_token,
#             cls_token=cls_token,
#             mask_token=mask_token,
#             do_lower_case=do_lower_case,
#             **kwargs,
#         )
        
#         self.tokenizer = RegexTokenizer()

In [7]:
pretrain_config = {
        'embed_dim': 256,#256
        'property_width': 384, #???
        'batch_size': 4,#64
        'temp': 0.07,
        'queue_size': 2048,#65536
        'momentum': 0.995,
        'alpha': 0.4,
        'bert_config': './config_bert.json',    #config file for BERT model. The configuration for ViT can be manually changed in albef.py
        'schedular': {'sched': 'cosine', 'lr': 1e-4, 'epochs': 30, 'min_lr': 1e-5,
                      'decay_rate': 1, 'warmup_lr': 1e-5, 'warmup_epochs': 20, 'cooldown_epochs': 0},
        'optimizer': {'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02}
    }

------------------------------------------------------------------------------

In [8]:
sample_SMILES = sample[0]
sample_prop = sample[1]

SMILES_token = tokenizer(sample_SMILES)
input_ids = torch.LongTensor([SMILES_token['input_ids']])
print(input_ids)
attention_mask = torch.Tensor([SMILES_token['attention_mask']])
print(attention_mask)

tensor([[  2,   5, 146,   8, 212, 112, 115,  98, 222, 114,  98,   3]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])


In [9]:
from xbert import BertConfig, BertForMaskedLM

smilesAndFusion_config = BertConfig.from_json_file('./config_bert_smiles_and_fusion_encoder.json')
smilesEncoder = BertForMaskedLM(config = smilesAndFusion_config)

property_config = BertConfig.from_json_file('./config_bert_property_encoder.json')
propertyEncoder = BertForMaskedLM(config = property_config)


In [10]:
import torch 
from torch import nn 

propertyOriginal = torch.rand([1,53])
embedProperty = nn.Linear(1, 384)(propertyOriginal.unsqueeze(2))
#print(embedProperty)

property_CLS = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = property_MASK.expand(propertyOriginal.size(0), propertyOriginal.size(1), -1)

halfMask = torch.bernoulli(torch.ones_like(propertyOriginal)*0.5)
halfMaskBatch = halfMask.unsqueeze(2).repeat(1,1,property_MASK.size(2))
maskedProperty = embedProperty*halfMaskBatch + property_MASK*halfMaskBatch

print((embedProperty*halfMaskBatch))
#print(property_MASK*(1-halfMaskBatch))

inputProperty = torch.cat([property_CLS.expand(propertyOriginal.size(0), -1,-1), maskedProperty], dim=1)

#print(inputProperty.shape)

tensor([[[-0.1417, -0.4103, -0.3208,  ..., -1.1290, -0.5527, -0.0819],
         [-0.2524,  0.3830, -0.2037,  ..., -0.3266,  0.0735, -0.4584],
         [-0.1997,  0.0054, -0.2594,  ..., -0.7085, -0.2245, -0.2792],
         ...,
         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000],
         [-0.1488, -0.3589, -0.3132,  ..., -1.0770, -0.5122, -0.1063],
         [-0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000, -0.0000]]],
       grad_fn=<MulBackward0>)


In [58]:
encodedProp = propertyEncoder.bert(inputs_embeds=inputProperty, return_dict=True).last_hidden_state
print(encodedProp.shape)
encodedSmiles = smilesEncoder.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, mode='text').last_hidden_state
print(encodedSmiles.shape)

torch.Size([1, 54, 384])
torch.Size([1, 12, 384])


In [12]:
import torch.nn.functional as F 

propertyProj = nn.Linear(384, 384)

propertyAtts = torch.ones(encodedProp.size()[:-1], dtype=torch.long).to(inputProperty.device)
propertyFeat = F.normalize(propertyProj(encodedProp[:,0,:]), dim=-1)

property_queue = torch.randn(384, 2048)
property_queue.clone().detach()

propertyFeatAll = torch.cat([propertyFeat.t(), property_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[ 0.0041, -0.8168,  0.5872,  ..., -0.3747, -0.2723,  0.1165],
        [-0.0697,  1.3831, -0.2511,  ..., -2.2535,  1.1538,  0.6848],
        [ 0.0629,  0.0897,  0.4163,  ...,  0.2974, -1.2093, -0.8941],
        ...,
        [ 0.0267, -1.2687, -1.6029,  ...,  0.4938,  1.7291,  0.1211],
        [ 0.0428, -1.4156, -0.9930,  ...,  0.5963, -0.8524, -0.5831],
        [ 0.0701,  0.8586, -0.7782,  ..., -0.2760, -0.9922, -0.1335]],
       grad_fn=<CatBackward0>)

In [13]:
smilesProj = nn.Linear(384, 384)
smilesFeat = F.normalize(smilesProj(encodedSmiles[:,0,:]), dim=-1)

smiles_queue = torch.randn(384, 2048)
smiles_queue.clone().detach()

smilesFeatAll = torch.cat([smilesFeat.t(), smiles_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[ 0.0041, -0.8168,  0.5872,  ..., -0.3747, -0.2723,  0.1165],
        [-0.0697,  1.3831, -0.2511,  ..., -2.2535,  1.1538,  0.6848],
        [ 0.0629,  0.0897,  0.4163,  ...,  0.2974, -1.2093, -0.8941],
        ...,
        [ 0.0267, -1.2687, -1.6029,  ...,  0.4938,  1.7291,  0.1211],
        [ 0.0428, -1.4156, -0.9930,  ...,  0.5963, -0.8524, -0.5831],
        [ 0.0701,  0.8586, -0.7782,  ..., -0.2760, -0.9922, -0.1335]],
       grad_fn=<CatBackward0>)

In [14]:
sim_p2s = (propertyFeat @ smilesFeatAll)
sim_targets = torch.zeros(sim_p2s.size()).to(propertyOriginal.device)
print(sim_targets.shape)

sim_s2p = smilesFeat @ propertyFeatAll
sim_targets_same = torch.zeros(sim_s2p.size()).to(propertyOriginal.device)
print(sim_targets.shape)

torch.Size([1, 2049])
torch.Size([1, 2049])


In [15]:
outputProperty_pos = smilesEncoder.bert(encoder_embeds = encodedProp,
                               attention_mask = propertyAtts,
                               encoder_hidden_states = encodedSmiles,
                               encoder_attention_mask = attention_mask,
                               return_dict = True,
                               mode= 'fusion'
                               ).last_hidden_state[:,0,:]
outputSmiles_pos = smilesEncoder.bert(encoder_embeds = encodedSmiles,
                             attention_mask = attention_mask,
                             encoder_hidden_states = encodedProp,
                             encoder_attention_mask = propertyAtts,
                             return_dict = True,
                             mode= 'fusion'
                             ).last_hidden_state[:,0,:]
                                   

In [16]:
print(outputProperty_pos.shape)
print(outputSmiles_pos.shape)

torch.Size([1, 384])
torch.Size([1, 384])


In [17]:
batch_size = 1
weights_p2s = F.softmax(sim_p2s[:,:batch_size], dim=1)
weights_s2p = F.softmax(sim_s2p[:,:batch_size], dim=1)

In [18]:
encodedProp_neg = []
neg_idx = torch.multinomial(weights_p2s[0], 1).item()
encodedProp_neg.append(encodedProp[neg_idx])
encodedProp_neg = torch.stack(encodedProp_neg, dim=0)
encodedProp_neg

tensor([[[ 0.1824, -0.0607,  1.3084,  ..., -1.4432,  0.4354,  2.1898],
         [ 0.6177, -0.5268,  0.0507,  ..., -0.9035, -0.2982,  0.0671],
         [ 0.4261,  0.9823, -0.1573,  ...,  0.0697,  0.6003,  0.0170],
         ...,
         [ 1.0068,  0.2572,  2.0727,  ..., -0.0098, -0.9208, -0.1247],
         [ 0.6434, -0.1911,  0.1341,  ..., -0.7065, -0.5435,  0.2061],
         [ 1.0247, -0.4561,  2.6050,  ..., -1.0719, -0.0312,  1.5180]]],
       grad_fn=<StackBackward0>)

In [19]:
encodedSmiles_neg = []
smilesAtts_neg = []

neg_idx = torch.multinomial(weights_s2p[0], 1).item()
encodedSmiles_neg.append(encodedSmiles[neg_idx])
smilesAtts_neg.append(attention_mask[neg_idx])
encodedSmiles_neg = torch.stack(encodedSmiles_neg, dim=0)
smilesAtts_neg = torch.stack(smilesAtts_neg, dim=0)

encodedSmiles_neg

tensor([[[ 0.9357, -0.9966,  0.5626,  ...,  0.8034,  0.8001,  0.2025],
         [ 0.0703,  0.1101, -1.6660,  ..., -0.8955,  0.3145,  1.6860],
         [ 0.4842,  1.0346, -0.9396,  ..., -0.8054, -0.0582, -2.2141],
         ...,
         [ 2.2750,  0.9106, -2.1961,  ..., -0.2422,  0.9576,  0.2175],
         [-0.6238, -0.0892,  0.5609,  ..., -0.0097,  1.5409,  0.3180],
         [ 0.7790, -1.2032, -0.9792,  ...,  0.3504,  0.4163,  0.4745]]],
       grad_fn=<StackBackward0>)

In [20]:
encProperty_all = torch.cat([encodedProp, encodedProp_neg], dim=0)
propertyAtts_all = torch.cat([propertyAtts, propertyAtts], dim=0)

encSmiles_all = torch.cat([encodedSmiles_neg, encodedSmiles], dim=0)
smilesAtts_all = torch.cat([smilesAtts_neg, attention_mask], dim=0)

In [21]:
outputProperty_neg = smilesEncoder.bert(encoder_embeds = encProperty_all,
                                    attention_mask = propertyAtts_all,
                                    encoder_hidden_states = encSmiles_all,
                                    encoder_attention_mask = smilesAtts_all,
                                    return_dict = True,
                                    mode = 'fusion'
                                    ).last_hidden_state[:,0,:]
#outputProperty_neg = outputProperty_neg
outputSmiles_neg = smilesEncoder.bert(encoder_embeds = encSmiles_all,
                                    attention_mask = smilesAtts_all,
                                    encoder_hidden_states = encProperty_all,
                                    encoder_attention_mask = propertyAtts_all,
                                    return_dict = True,
                                    mode = 'fusion'
                                    ).last_hidden_state[:,0,:]
#outputSmiles_neg = outputSmiles_neg

In [22]:
pos_embeds = torch.cat([outputProperty_pos, outputSmiles_pos], dim=-1)
neg_embeds = torch.cat([outputProperty_neg, outputSmiles_neg], dim=-1)

ps_embeddings = torch.cat([pos_embeds, neg_embeds], dim=0)
ps_embeddings.shape
ps_output = nn.Linear(768, 2)(ps_embeddings)
ps_output

tensor([[ 0.3131, -0.4291],
        [ 0.3008, -0.5363],
        [ 0.1552, -0.5555]], grad_fn=<AddmmBackward0>)

In [23]:
psm_labels = torch.cat([torch.ones(1, dtype=torch.long), torch.zeros(2, dtype=torch.long)], dim=0).to(propertyOriginal.device)
psm_labels

tensor([1, 0, 0])

In [24]:
loss_psm = F.cross_entropy(ps_output, psm_labels)
loss_psm

tensor(0.6303, grad_fn=<NllLossBackward0>)

In [25]:
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool()

tensor([[ True,  True, False,  True,  True,  True,  True, False,  True,  True,
          True, False]])

In [26]:
logits_m  = smilesEncoder(input_ids,
                          attention_mask = attention_mask,
                          encoder_hidden_states = encodedProp,
                          encoder_attention_mask = propertyAtts,
                          return_dict = True,
                          is_decoder = True,
                          return_logits = True 
                          )[:,:-1,:]

logits_m.shape

torch.Size([1, 11, 300])

In [30]:
labels = input_ids.clone()[:,1:]
per_logits_m = logits_m.permute((0,2,1))

print(logits_m.shape)
print(per_logits_m.shape)
print(labels.shape)

loss_fct = nn.CrossEntropyLoss()
loss_nwp = loss_fct(per_logits_m,labels)

print(loss_nwp)

torch.Size([1, 11, 300])
torch.Size([1, 300, 11])
torch.Size([1, 11])
tensor(5.8192, grad_fn=<NllLoss2DBackward0>)


In [28]:
F.log_softmax(logits_m, dim=-1) * F.softmax(logits_m, dim=-1)

tensor([[[-0.0311, -0.0295, -0.0184,  ..., -0.0117, -0.0196, -0.0126],
         [-0.0125, -0.0173, -0.0163,  ..., -0.0141, -0.0148, -0.0215],
         [-0.0223, -0.0419, -0.0296,  ..., -0.0148, -0.0102, -0.0124],
         ...,
         [-0.0158, -0.0355, -0.0273,  ..., -0.0110, -0.0179, -0.0211],
         [-0.0147, -0.0423, -0.0183,  ..., -0.0367, -0.0124, -0.0188],
         [-0.0164, -0.0265, -0.0178,  ..., -0.0186, -0.0134, -0.0156]]],
       grad_fn=<MulBackward0>)

In [29]:
F.softmax(logits_m, dim=-1)
print(F.softmax(logits_m, dim=-1).shape)

torch.Size([1, 11, 300])


In [38]:
soft_labels = F.softmax(logits_m, dim=1)
loss_distill = -torch.sum(F.log_softmax(logits_m, dim=1) * soft_labels, dim=-1)

loss_distill = (loss_distill * (labels != 0)).mean()
loss_distill

tensor(64.0643, grad_fn=<MeanBackward0>)

In [68]:
targets = propertyOriginal.clone()

encProperty_masked_m = propertyEncoder.bert(inputs_embeds = inputProperty,
                                       return_dict = True,
                                       is_decoder = True
                                       ).last_hidden_state
encProperty_masked_m

tensor([[[-0.2711, -0.0187,  0.5958,  ..., -1.9072,  0.2913,  1.6481],
         [ 0.7051, -1.0061,  0.1156,  ..., -0.8559,  0.0309,  0.1750],
         [ 0.5251,  1.1804, -0.4049,  ...,  0.0625,  0.9561, -0.3600],
         ...,
         [ 1.4204,  0.1097,  1.9425,  ..., -0.2383, -0.9002, -0.3194],
         [ 0.5641, -0.5436,  0.1660,  ..., -0.3512, -0.3763,  0.0207],
         [ 1.0097, -0.0674,  2.3085,  ..., -1.1544, -0.1497,  1.2358]]],
       grad_fn=<NativeLayerNormBackward0>)

In [71]:
npp_output_m = smilesEncoder.bert(encoder_embeds = encProperty_masked_m,
                                  attention_mask = propertyAtts,
                                  encoder_hidden_states = encodedSmiles,
                                  encoder_attention_mask = attention_mask,    
                                  return_dict = True,
                                  is_decoder = True,
                                  mode = 'fusion').last_hidden_state[:,:-1,:]

#npp_output_m
print(npp_output_m.shape)

torch.Size([1, 53, 384])


In [72]:
target_reg = nn.Sequential(nn.Linear(384, 384),
                           nn.GELU(),
                           nn.LayerNorm(384, property_config.layer_norm_eps),
                           nn.Linear(384, 1))

In [75]:
pred_m = target_reg(npp_output_m)
pred_m.squeeze(2).shape

torch.Size([1, 53])

In [79]:
propertyTargets = propertyOriginal.clone() 
propertyTargets.shape

torch.Size([1, 53])

In [82]:
loss_MSE = nn.MSELoss()
loss_npp = loss_MSE(pred_m*halfMask, propertyTargets*halfMask)
loss_npp

  return F.mse_loss(input, target, reduction=self.reduction)


tensor(0.1642, grad_fn=<MseLossBackward0>)

In [88]:
print(halfMask)
print(1-halfMask)

tensor([[1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 0., 1.,
         0., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0.,
         1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0.]])
tensor([[0., 0., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0.,
         1., 0., 1., 0., 1., 0., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1.,
         0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1.]])
