In [1]:
import torch
from torch.utils.data import Dataset
import random
from rdkit import Chem
import pickle

from calc_property import calculate_property

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SMILESDataset(Dataset):
    def __init__(self, data_path, data_length=None, shuffle=False):
        with open(data_path,'r') as f:
            lines = f.readlines()
        self.data = [l.strip() for l in lines]

        with open('./normalize.pkl', 'rb') as w:
            norm = pickle.load(w)
        self.property_mean, self.property_std = norm

        if shuffle:
            random.shuffle(self.data)
        
        ## Why need this line? ##
        if data_length is not None:
            self.data = self.data[data_length[0]: data_length[1]]
            
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        smiles = 'Q' + self.data[idx]
        properties = (calculate_property(smiles[1:])-self.property_std) / self.property_mean
        
        return smiles, properties

In [3]:
sampleDataset = SMILESDataset(data_path='./data/pubchem-1m-simple.txt', data_length=None, shuffle=False)
sample = sampleDataset.__getitem__(0)

sample

('QCN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1',
 tensor([  0.5682,   0.4030,   0.6777,   0.6901,   0.6560,   0.6904,   0.6890,
           0.6158,   0.6544,   0.3731,   0.6623,  -1.0214,   0.5386,  -6.4296,
           0.6370,   0.7405,   0.7696,   0.8429,   0.2686,   1.6701,   0.6852,
           0.6222,   0.6510,   0.5670,   0.3475,   0.6824,   0.8768,   0.8768,
          -0.1836,   2.4966,   0.1489,   0.7014,   0.6364,   0.2755,   0.4956,
          -2.3909,   0.4995,   0.0625,   0.7835,  -1.1521,   0.3690,   0.4610,
           0.6192,   0.3638, -29.1460,   0.2716,  -2.7021,   0.9223,   0.2647,
           0.6883,   0.5710,   0.4035,   1.0784]))

In [4]:
with open('./data/pubchem-1m-simple.txt', 'r') as f:
    lines = f.readlines()
    
with open('./normalize.pkl', 'rb') as w:
    norm = pickle.load(w)
    print(norm)

    

(tensor([ 2.0210e+00,  7.5875e+02,  1.7544e+01,  1.4052e+01,  1.4723e+01,
         1.1664e+01,  8.1625e+00,  8.8090e+00,  6.2020e+00,  6.9684e+00,
         4.2730e+00,  4.9467e+00,  2.9353e+00,  3.5438e+00,  3.5288e+02,
         1.1951e+00,  1.9034e+00,  2.5342e+00,  4.1357e-01, -2.1272e+00,
         2.4343e+01,  3.3127e+02,  1.7561e+01,  7.6679e+00,  4.3503e+00,
         1.4605e+02,  1.1184e+01,  1.1184e+01,  1.9147e-01, -9.6040e-01,
         2.6799e+00,  9.4141e+01,  3.5333e+02,  1.7290e+00,  5.1585e+00,
         2.6400e-01,  5.4614e-01,  8.1014e-01,  1.1892e+00,  7.1781e-01,
         1.9070e+00,  4.0690e+00,  1.3547e+00,  6.2828e+00,  4.5800e-03,
         5.4205e+00,  2.0371e-01,  3.9774e-01,  6.0145e-01,  1.2960e+02,
         2.7171e+00,  6.8247e+01,  6.1739e-01]), tensor([5.9750e-01, 4.0613e+02, 5.8108e+00, 4.7454e+00, 4.7843e+00, 4.0368e+00,
        2.9280e+00, 3.1273e+00, 2.3723e+00, 3.8311e+00, 1.8252e+00, 9.7076e+00,
        1.4360e+00, 2.5803e+01, 1.1541e+02, 2.3501e-01, 2.95

#### Define Custom Tokenizer 

In [5]:
import pandas as pd
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file= "./vocab_bpe_300.txt" ,lowercase=False, do_basic_tokenize=False)

In [6]:
df = pd.read_fwf('./vocab_bpe_300.txt', header=None)
df

Unnamed: 0,0
0,[PAD]
1,[UNK]
2,[CLS]
3,[SEP]
4,[MASK]
...,...
295,##c12
296,##[Si
297,##c(C(=O
298,##[nH+]


In [7]:
# from rdkit import Chem
# from transformers import BertTokenizer

# from typing import List
# import re

# ## REGEX_PATTERN ##
# SMI_REGEX_PATTERN =  r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

# class RegexTokenizer:
#     """Run regex tokenization"""

#     def __init__(self, regex_pattern: str=SMI_REGEX_PATTERN) -> None:
#         """Constructs a RegexTokenizer.
#         Args:
#             regex_pattern: regex pattern used for tokenization.
#             suffix: optional suffix for the tokens. Defaults to "".
#         """
#         self.regex_pattern = regex_pattern
#         self.regex = re.compile(self.regex_pattern)

#     def tokenize(self, text: str) -> List[str]:
#         """Regex tokenization.
#         Args:
#             text: text to tokenize.
#         Returns:
#             extracted tokens separated by spaces.
#         """
#         tokens = [token for token in self.regex.findall(text)]
#         return tokens
    
    
# class SMILESTokenizer(BertTokenizer):
#     def __init__(self, 
#         vocab_file: str,
#         unk_token: str = "[UNK]",
#         sep_token: str = "[SEP]",
#         pad_token: str = "[PAD]",
#         cls_token: str = "[CLS]",
#         mask_token: str = "[MASK]",
#         do_lower_case = False,
#         **kwargs,
#         ) -> None:
        
#         super().__init__(
#             vocab_file=vocab_file,
#             unk_token=unk_token,
#             sep_token=sep_token,
#             pad_token=pad_token,
#             cls_token=cls_token,
#             mask_token=mask_token,
#             do_lower_case=do_lower_case,
#             **kwargs,
#         )
        
#         self.tokenizer = RegexTokenizer()

In [8]:
pretrain_config = {
        'embed_dim': 256,#256
        'property_width': 384, #???
        'batch_size': 4,#64
        'temp': 0.07,
        'queue_size': 2048,#65536
        'momentum': 0.995,
        'alpha': 0.4,
        'bert_config': './config_bert.json',    #config file for BERT model. The configuration for ViT can be manually changed in albef.py
        'schedular': {'sched': 'cosine', 'lr': 1e-4, 'epochs': 30, 'min_lr': 1e-5,
                      'decay_rate': 1, 'warmup_lr': 1e-5, 'warmup_epochs': 20, 'cooldown_epochs': 0},
        'optimizer': {'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02}
    }

------------------------------------------------------------------------------

In [9]:
sample_SMILES = sample[0]
sample_prop = sample[1]

SMILES_token = tokenizer(sample_SMILES)
input_ids = torch.LongTensor([SMILES_token['input_ids']])
print(input_ids)
attention_mask = torch.Tensor([SMILES_token['attention_mask']])
print(attention_mask)

tensor([[  2,   5, 146,   8, 212, 112, 115,  98, 222, 114,  98,   3]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])


In [10]:
from xbert import BertConfig, BertForMaskedLM

smilesAndFusion_config = BertConfig.from_json_file('./config_bert_smiles_and_fusion_encoder.json')
smilesEncoder = BertForMaskedLM(config = smilesAndFusion_config)

property_config = BertConfig.from_json_file('./config_bert_property_encoder.json')
propertyEncoder = BertForMaskedLM(config = property_config)


In [11]:
import torch 
from torch import nn 

propertyOriginal = torch.rand([1,53])
embedProperty = nn.Linear(1, 384)(propertyOriginal.unsqueeze(2))
#print(embedProperty)

property_CLS = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = property_MASK.expand(propertyOriginal.size(0), propertyOriginal.size(1), -1)

halfMask = torch.bernoulli(torch.ones_like(propertyOriginal)*0.5)
halfMaskBatch = halfMask.unsqueeze(2).repeat(1,1,property_MASK.size(2))
maskedProperty = embedProperty*halfMaskBatch + property_MASK*halfMaskBatch

print((embedProperty*halfMaskBatch))
#print(property_MASK*(1-halfMaskBatch))

inputProperty = torch.cat([property_CLS.expand(propertyOriginal.size(0), -1,-1), maskedProperty], dim=1)

#print(inputProperty.shape)

tensor([[[-1.2091,  0.7468,  0.5669,  ..., -0.3899, -0.4842, -0.5266],
         [-0.9721,  0.7393,  0.7513,  ..., -0.7073, -0.1890, -0.3069],
         [-0.8090,  0.7342,  0.8782,  ..., -0.9258,  0.0142, -0.1556],
         ...,
         [-0.9783,  0.7395,  0.7465,  ..., -0.6990, -0.1967, -0.3126],
         [-1.0535,  0.7419,  0.6879,  ..., -0.5982, -0.2904, -0.3824],
         [-0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000]]],
       grad_fn=<MulBackward0>)


In [12]:
encodedProp = propertyEncoder.bert(inputs_embeds=inputProperty, return_dict=True).last_hidden_state
print(encodedProp.shape)
encodedSmiles = smilesEncoder.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, mode='text').last_hidden_state
print(encodedSmiles.shape)

torch.Size([1, 54, 384])
torch.Size([1, 12, 384])


In [13]:
import torch.nn.functional as F 

propertyProj = nn.Linear(384, 384)

propertyAtts = torch.ones(encodedProp.size()[:-1], dtype=torch.long).to(inputProperty.device)
propertyFeat = F.normalize(propertyProj(encodedProp[:,0,:]), dim=-1)

property_queue = torch.randn(384, 2048)
property_queue.clone().detach()

propertyFeatAll = torch.cat([propertyFeat.t(), property_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[-8.1473e-02,  1.2620e+00, -2.4469e-01,  ...,  4.6948e-01,
         -1.9055e+00,  4.4526e-02],
        [-5.4801e-03, -8.3393e-01, -7.7849e-01,  ..., -8.4742e-01,
         -3.9457e-01,  1.6535e+00],
        [-1.1821e-02,  1.0043e-01,  8.3092e-01,  ..., -9.1041e-01,
          9.8655e-01, -6.2190e-01],
        ...,
        [-6.8584e-03, -7.7725e-01, -4.7841e-01,  ...,  1.0659e+00,
          7.5329e-01, -7.4068e-01],
        [-1.5319e-03, -2.4126e-01,  1.1897e+00,  ..., -2.3796e-01,
          5.4816e-03, -2.4479e+00],
        [-1.7260e-02,  4.3282e-02,  9.1374e-02,  ...,  1.2826e-01,
         -6.8140e-01,  1.1980e+00]], grad_fn=<CatBackward0>)

In [14]:
smilesProj = nn.Linear(384, 384)
smilesFeat = F.normalize(smilesProj(encodedSmiles[:,0,:]), dim=-1)

smiles_queue = torch.randn(384, 2048)
smiles_queue.clone().detach()

smilesFeatAll = torch.cat([smilesFeat.t(), smiles_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[-8.1473e-02,  1.2620e+00, -2.4469e-01,  ...,  4.6948e-01,
         -1.9055e+00,  4.4526e-02],
        [-5.4801e-03, -8.3393e-01, -7.7849e-01,  ..., -8.4742e-01,
         -3.9457e-01,  1.6535e+00],
        [-1.1821e-02,  1.0043e-01,  8.3092e-01,  ..., -9.1041e-01,
          9.8655e-01, -6.2190e-01],
        ...,
        [-6.8584e-03, -7.7725e-01, -4.7841e-01,  ...,  1.0659e+00,
          7.5329e-01, -7.4068e-01],
        [-1.5319e-03, -2.4126e-01,  1.1897e+00,  ..., -2.3796e-01,
          5.4816e-03, -2.4479e+00],
        [-1.7260e-02,  4.3282e-02,  9.1374e-02,  ...,  1.2826e-01,
         -6.8140e-01,  1.1980e+00]], grad_fn=<CatBackward0>)

In [15]:
sim_p2s = (propertyFeat @ smilesFeatAll)
sim_targets = torch.zeros(sim_p2s.size()).to(propertyOriginal.device)
print(sim_targets.shape)

sim_s2p = smilesFeat @ propertyFeatAll
sim_targets_same = torch.zeros(sim_s2p.size()).to(propertyOriginal.device)
print(sim_targets.shape)

torch.Size([1, 2049])
torch.Size([1, 2049])


In [16]:
outputProperty_pos = smilesEncoder.bert(encoder_embeds = encodedProp,
                               attention_mask = propertyAtts,
                               encoder_hidden_states = encodedSmiles,
                               encoder_attention_mask = attention_mask,
                               return_dict = True,
                               mode= 'fusion'
                               ).last_hidden_state[:,0,:]
outputSmiles_pos = smilesEncoder.bert(encoder_embeds = encodedSmiles,
                             attention_mask = attention_mask,
                             encoder_hidden_states = encodedProp,
                             encoder_attention_mask = propertyAtts,
                             return_dict = True,
                             mode= 'fusion'
                             ).last_hidden_state[:,0,:]
                                   

In [17]:
print(outputProperty_pos.shape)
print(outputSmiles_pos.shape)

torch.Size([1, 384])
torch.Size([1, 384])


In [18]:
batch_size = 1
weights_p2s = F.softmax(sim_p2s[:,:batch_size], dim=1)
weights_s2p = F.softmax(sim_s2p[:,:batch_size], dim=1)

In [19]:
encodedProp_neg = []
neg_idx = torch.multinomial(weights_p2s[0], 1).item()
encodedProp_neg.append(encodedProp[neg_idx])
encodedProp_neg = torch.stack(encodedProp_neg, dim=0)
encodedProp_neg

tensor([[[ 3.5329e-01, -1.0953e-01,  2.2853e-04,  ...,  1.3656e+00,
          -7.2193e-01,  1.2645e+00],
         [ 1.2098e-01,  5.5696e-01,  1.8438e-01,  ..., -8.4229e-01,
           1.8129e-01, -8.6456e-01],
         [-9.9603e-01,  5.1829e-01,  1.2027e+00,  ..., -6.8868e-01,
           3.8952e-01, -4.1334e-01],
         ...,
         [-1.0247e+00,  4.8322e-01,  1.1923e+00,  ..., -5.1068e-01,
           6.9675e-01, -3.3327e-01],
         [-1.3248e+00,  7.0152e-01,  1.1969e+00,  ..., -3.7592e-01,
           3.0733e-01, -3.8839e-01],
         [-8.4820e-01, -6.5878e-01,  1.9948e-01,  ...,  8.2771e-01,
          -1.0954e+00,  1.7565e+00]]], grad_fn=<StackBackward0>)

In [20]:
encodedSmiles_neg = []
smilesAtts_neg = []

neg_idx = torch.multinomial(weights_s2p[0], 1).item()
encodedSmiles_neg.append(encodedSmiles[neg_idx])
smilesAtts_neg.append(attention_mask[neg_idx])
encodedSmiles_neg = torch.stack(encodedSmiles_neg, dim=0)
smilesAtts_neg = torch.stack(smilesAtts_neg, dim=0)

encodedSmiles_neg

tensor([[[ 0.7729, -0.5451, -0.9563,  ...,  0.7439, -0.0239,  0.0329],
         [ 0.6894, -0.5174,  1.0443,  ..., -0.5177, -1.0379, -0.5036],
         [ 1.0968, -0.5394,  0.2805,  ...,  0.6875, -1.1454,  0.2677],
         ...,
         [ 1.1775, -0.7769, -0.1076,  ...,  0.4221, -0.9266,  0.1148],
         [-0.9466, -0.3234,  0.0563,  ..., -1.1961, -1.0710,  0.1078],
         [ 1.6689,  0.2071, -0.7971,  ..., -0.6911, -0.8274, -0.3584]]],
       grad_fn=<StackBackward0>)

In [21]:
encProperty_all = torch.cat([encodedProp, encodedProp_neg], dim=0)
propertyAtts_all = torch.cat([propertyAtts, propertyAtts], dim=0)

encSmiles_all = torch.cat([encodedSmiles_neg, encodedSmiles], dim=0)
smilesAtts_all = torch.cat([smilesAtts_neg, attention_mask], dim=0)

In [22]:
outputProperty_neg = smilesEncoder.bert(encoder_embeds = encProperty_all,
                                    attention_mask = propertyAtts_all,
                                    encoder_hidden_states = encSmiles_all,
                                    encoder_attention_mask = smilesAtts_all,
                                    return_dict = True,
                                    mode = 'fusion'
                                    ).last_hidden_state[:,0,:]
#outputProperty_neg = outputProperty_neg
outputSmiles_neg = smilesEncoder.bert(encoder_embeds = encSmiles_all,
                                    attention_mask = smilesAtts_all,
                                    encoder_hidden_states = encProperty_all,
                                    encoder_attention_mask = propertyAtts_all,
                                    return_dict = True,
                                    mode = 'fusion'
                                    ).last_hidden_state[:,0,:]
#outputSmiles_neg = outputSmiles_neg

In [23]:
pos_embeds = torch.cat([outputProperty_pos, outputSmiles_pos], dim=-1)
neg_embeds = torch.cat([outputProperty_neg, outputSmiles_neg], dim=-1)

ps_embeddings = torch.cat([pos_embeds, neg_embeds], dim=0)
ps_embeddings.shape
ps_output = nn.Linear(768, 2)(ps_embeddings)
ps_output

tensor([[ 0.0697,  0.6968],
        [-0.0017,  0.7012],
        [-0.2312,  0.6549]], grad_fn=<AddmmBackward0>)

In [24]:
psm_labels = torch.cat([torch.ones(1, dtype=torch.long), torch.zeros(2, dtype=torch.long)], dim=0).to(propertyOriginal.device)
psm_labels

tensor([1, 0, 0])

In [25]:
loss_psm = F.cross_entropy(ps_output, psm_labels)
loss_psm

tensor(0.9215, grad_fn=<NllLossBackward0>)

In [26]:
torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool()

tensor([[ True,  True, False, False,  True,  True,  True,  True,  True,  True,
         False,  True]])

In [27]:
logits_m  = smilesEncoder(input_ids,
                          attention_mask = attention_mask,
                          encoder_hidden_states = encodedProp,
                          encoder_attention_mask = propertyAtts,
                          return_dict = True,
                          is_decoder = True,
                          return_logits = True 
                          )[:,:-1,:]

logits_m.shape

torch.Size([1, 11, 300])

In [28]:
labels = input_ids.clone()[:,1:]
per_logits_m = logits_m.permute((0,2,1))

print(logits_m.shape)
print(per_logits_m.shape)
print(labels.shape)

loss_fct = nn.CrossEntropyLoss()
loss_nwp = loss_fct(per_logits_m,labels)

print(loss_nwp)

torch.Size([1, 11, 300])
torch.Size([1, 300, 11])
torch.Size([1, 11])
tensor(5.8270, grad_fn=<NllLoss2DBackward0>)


In [29]:
F.log_softmax(logits_m, dim=-1) * F.softmax(logits_m, dim=-1)

tensor([[[-0.0163, -0.0203, -0.0246,  ..., -0.0223, -0.0222, -0.0359],
         [-0.0231, -0.0214, -0.0131,  ..., -0.0189, -0.0156, -0.0178],
         [-0.0227, -0.0155, -0.0276,  ..., -0.0193, -0.0155, -0.0279],
         ...,
         [-0.0159, -0.0172, -0.0138,  ..., -0.0264, -0.0229, -0.0256],
         [-0.0195, -0.0200, -0.0153,  ..., -0.0151, -0.0234, -0.0110],
         [-0.0191, -0.0100, -0.0222,  ..., -0.0183, -0.0143, -0.0142]]],
       grad_fn=<MulBackward0>)

In [30]:
F.softmax(logits_m, dim=-1)
print(F.softmax(logits_m, dim=-1).shape)

torch.Size([1, 11, 300])


In [31]:
soft_labels = F.softmax(logits_m, dim=1)
loss_distill = -torch.sum(F.log_softmax(logits_m, dim=1) * soft_labels, dim=-1)

loss_distill = (loss_distill * (labels != 0)).mean()
loss_distill

tensor(64.2319, grad_fn=<MeanBackward0>)

In [32]:
targets = propertyOriginal.clone()

encProperty_masked_m = propertyEncoder.bert(inputs_embeds = inputProperty,
                                       return_dict = True,
                                       is_decoder = True
                                       ).last_hidden_state
encProperty_masked_m

tensor([[[ 0.3304,  0.4590, -0.5967,  ...,  1.2816, -0.1227,  1.2848],
         [-1.2938,  0.6364,  0.7300,  ..., -0.3683,  0.0276, -0.8365],
         [-0.9701, -0.5149,  1.2391,  ..., -0.7945,  0.4855, -0.5426],
         ...,
         [-0.9369,  0.4499,  1.2872,  ..., -1.0597,  0.2443, -0.2044],
         [-1.2789,  0.2818,  0.8195,  ..., -0.6183,  0.3485, -0.5829],
         [ 0.2594, -1.0350, -0.4422,  ...,  1.0007, -0.1293, -0.8713]]],
       grad_fn=<NativeLayerNormBackward0>)

In [33]:
npp_output_m = smilesEncoder.bert(encoder_embeds = encProperty_masked_m,
                                  attention_mask = propertyAtts,
                                  encoder_hidden_states = encodedSmiles,
                                  encoder_attention_mask = attention_mask,    
                                  return_dict = True,
                                  is_decoder = True,
                                  mode = 'fusion').last_hidden_state[:,:-1,:]

#npp_output_m
print(npp_output_m.shape)

torch.Size([1, 53, 384])


In [34]:
target_reg = nn.Sequential(nn.Linear(384, 384),
                           nn.GELU(),
                           nn.LayerNorm(384, property_config.layer_norm_eps),
                           nn.Linear(384, 1))

In [35]:
pred_m = target_reg(npp_output_m)
pred_m.squeeze(2).shape

torch.Size([1, 53])

In [36]:
propertyTargets = propertyOriginal.clone() 
propertyTargets.shape

torch.Size([1, 53])

In [37]:
loss_MSE = nn.MSELoss()
loss_npp = loss_MSE(pred_m*halfMask, propertyTargets*halfMask)
loss_npp

  return F.mse_loss(input, target, reduction=self.reduction)


tensor(0.7448, grad_fn=<MseLossBackward0>)

In [38]:
print(halfMask)
print(1-halfMask)

tensor([[1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0.,
         0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0.,
         0., 0., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]])
tensor([[0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1.,
         1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
         1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1.]])


In [39]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

pretrain_SPMM debugging

In [1]:
pretrain_config = {
        'embed_dim': 256,#256
        'property_width': 384, #???
        'batch_size': 4,#64
        'temp': 0.07,
        'queue_size': 2048,#65536
        'momentum': 0.995,
        'alpha': 0.4,
        'bert_config': './config_bert.json',    #config file for BERT model. The configuration for ViT can be manually changed in albef.py
        'schedular': {'sched': 'cosine', 'lr': 1e-4, 'epochs': 30, 'min_lr': 1e-5,
                      'decay_rate': 1, 'warmup_lr': 1e-5, 'warmup_epochs': 20, 'cooldown_epochs': 0},
        'optimizer': {'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02}
    }

In [2]:
import torch
import torch.nn.functional as F
from torch import nn 
from xbert import BertConfig, BertForMaskedLM
from transformers import BertTokenizer

from pretrain_SPMM import SPMM
from transformers import BertTokenizer
from dataset import SMILESDataset

sampleDataset = SMILESDataset(data_path='./data/pubchem-1m-simple.txt', data_length=None, shuffle=False)
sample = sampleDataset.__getitem__(0)

tokenizer = BertTokenizer(vocab_file= "./vocab_bpe_300.txt" ,lowercase=False, do_basic_tokenize=False)
sample_SMILES = sample[0]
sample_prop = sample[1]

SMILES_token = tokenizer(sample_SMILES)
input_ids = torch.LongTensor([SMILES_token['input_ids']])
attention_mask = torch.Tensor([SMILES_token['attention_mask']])

propertyOriginal = sample_prop.unsqueeze(0)

model = SPMM(config=pretrain_config)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model(propertyOriginal, input_ids, attention_mask, alpha=0)

TypeError: BertForMaskedLM.forward() got an unexpected keyword argument 'return_dic'