In [1]:
import torch
from torch.utils.data import Dataset
import random
from rdkit import Chem
import pickle

from calc_property import calculate_property

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SMILESDataset(Dataset):
    def __init__(self, data_path, data_length=None, shuffle=False):
        with open(data_path,'r') as f:
            lines = f.readlines()
        self.data = [l.strip() for l in lines]

        with open('./normalize.pkl', 'rb') as w:
            norm = pickle.load(w)
        self.property_mean, self.property_std = norm

        if shuffle:
            random.shuffle(self.data)
        
        ## Why need this line? ##
        if data_length is not None:
            self.data = self.data[data_length[0]: data_length[1]]
            
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        smiles = 'Q' + self.data[idx]
        properties = (calculate_property(smiles[1:])-self.property_std) / self.property_mean
        
        return smiles, properties

In [3]:
sampleDataset = SMILESDataset(data_path='./data/pubchem-1m-simple.txt', data_length=None, shuffle=False)
sample = sampleDataset.__getitem__(0)

sample

('QCN(c1ccccc1)c1ccccc1C(=O)NCC1(O)CCOCC1',
 tensor([  0.5682,   0.4030,   0.6777,   0.6901,   0.6560,   0.6904,   0.6890,
           0.6158,   0.6544,   0.3731,   0.6623,  -1.0214,   0.5386,  -6.4296,
           0.6370,   0.7405,   0.7696,   0.8429,   0.2686,   1.6701,   0.6852,
           0.6222,   0.6510,   0.5670,   0.3475,   0.6824,   0.8768,   0.8768,
          -0.1836,   2.4966,   0.1489,   0.7014,   0.6364,   0.2755,   0.4956,
          -2.3909,   0.4995,   0.0625,   0.7835,  -1.1521,   0.3690,   0.4610,
           0.6192,   0.3638, -29.1460,   0.2716,  -2.7021,   0.9223,   0.2647,
           0.6883,   0.5710,   0.4035,   1.0784]))

In [4]:
with open('./data/pubchem-1m-simple.txt', 'r') as f:
    lines = f.readlines()
    
with open('./normalize.pkl', 'rb') as w:
    norm = pickle.load(w)
    print(norm)

    

(tensor([ 2.0210e+00,  7.5875e+02,  1.7544e+01,  1.4052e+01,  1.4723e+01,
         1.1664e+01,  8.1625e+00,  8.8090e+00,  6.2020e+00,  6.9684e+00,
         4.2730e+00,  4.9467e+00,  2.9353e+00,  3.5438e+00,  3.5288e+02,
         1.1951e+00,  1.9034e+00,  2.5342e+00,  4.1357e-01, -2.1272e+00,
         2.4343e+01,  3.3127e+02,  1.7561e+01,  7.6679e+00,  4.3503e+00,
         1.4605e+02,  1.1184e+01,  1.1184e+01,  1.9147e-01, -9.6040e-01,
         2.6799e+00,  9.4141e+01,  3.5333e+02,  1.7290e+00,  5.1585e+00,
         2.6400e-01,  5.4614e-01,  8.1014e-01,  1.1892e+00,  7.1781e-01,
         1.9070e+00,  4.0690e+00,  1.3547e+00,  6.2828e+00,  4.5800e-03,
         5.4205e+00,  2.0371e-01,  3.9774e-01,  6.0145e-01,  1.2960e+02,
         2.7171e+00,  6.8247e+01,  6.1739e-01]), tensor([5.9750e-01, 4.0613e+02, 5.8108e+00, 4.7454e+00, 4.7843e+00, 4.0368e+00,
        2.9280e+00, 3.1273e+00, 2.3723e+00, 3.8311e+00, 1.8252e+00, 9.7076e+00,
        1.4360e+00, 2.5803e+01, 1.1541e+02, 2.3501e-01, 2.95

#### Define Custom Tokenizer 

In [9]:
import pandas as pd
from transformers import BertTokenizer

tokenizer = BertTokenizer(vocab_file= "./vocab_bpe_300.txt" ,lowercase=False, do_basic_tokenize=False)

In [10]:
df = pd.read_fwf('./vocab_bpe_300.txt', header=None)
df

Unnamed: 0,0
0,[PAD]
1,[UNK]
2,[CLS]
3,[SEP]
4,[MASK]
...,...
295,##c12
296,##[Si
297,##c(C(=O
298,##[nH+]


In [31]:
# from rdkit import Chem
# from transformers import BertTokenizer

# from typing import List
# import re

# ## REGEX_PATTERN ##
# SMI_REGEX_PATTERN =  r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"

# class RegexTokenizer:
#     """Run regex tokenization"""

#     def __init__(self, regex_pattern: str=SMI_REGEX_PATTERN) -> None:
#         """Constructs a RegexTokenizer.
#         Args:
#             regex_pattern: regex pattern used for tokenization.
#             suffix: optional suffix for the tokens. Defaults to "".
#         """
#         self.regex_pattern = regex_pattern
#         self.regex = re.compile(self.regex_pattern)

#     def tokenize(self, text: str) -> List[str]:
#         """Regex tokenization.
#         Args:
#             text: text to tokenize.
#         Returns:
#             extracted tokens separated by spaces.
#         """
#         tokens = [token for token in self.regex.findall(text)]
#         return tokens
    
    
# class SMILESTokenizer(BertTokenizer):
#     def __init__(self, 
#         vocab_file: str,
#         unk_token: str = "[UNK]",
#         sep_token: str = "[SEP]",
#         pad_token: str = "[PAD]",
#         cls_token: str = "[CLS]",
#         mask_token: str = "[MASK]",
#         do_lower_case = False,
#         **kwargs,
#         ) -> None:
        
#         super().__init__(
#             vocab_file=vocab_file,
#             unk_token=unk_token,
#             sep_token=sep_token,
#             pad_token=pad_token,
#             cls_token=cls_token,
#             mask_token=mask_token,
#             do_lower_case=do_lower_case,
#             **kwargs,
#         )
        
#         self.tokenizer = RegexTokenizer()

In [5]:
pretrain_config = {
        'embed_dim': 256,#256
        'property_width': 384, #???
        'batch_size': 4,#64
        'temp': 0.07,
        'queue_size': 2048,#65536
        'momentum': 0.995,
        'alpha': 0.4,
        'bert_config': './config_bert.json',    #config file for BERT model. The configuration for ViT can be manually changed in albef.py
        'schedular': {'sched': 'cosine', 'lr': 1e-4, 'epochs': 30, 'min_lr': 1e-5,
                      'decay_rate': 1, 'warmup_lr': 1e-5, 'warmup_epochs': 20, 'cooldown_epochs': 0},
        'optimizer': {'opt': 'adamW', 'lr': 1e-4, 'weight_decay': 0.02}
    }

In [95]:
# SMILES Sequence tokenizer
import torch
import torch.nn.functional as F
from torch import nn 
from xbert import BertConfig 
from transformers import BertTokenizer, BertForMaskedLM


class SPMM(nn.Module):
    def __init__(self,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = BertTokenizer('./vocab_bpe_300.txt', do_lower_case=False,do_basic_tokenize=False)
        embed_dim = config['embed_dim']

        smilesAndFusion_config = BertConfig.from_json_file('./config_bert_smiles_and_fusion_encoder.json')
        property_config = BertConfig.from_json_file('./config_bert_property_encoder.json')
        self.smilesEncoder = BertForMaskedLM(config = smilesAndFusion_config)
        self.propertyEncoder = BertForMaskedLM(config = property_config)

        smilesWidth = self.smilesEncoder.config.hidden_size
        propertyWidth = config['property_width']

        self.smilesProj = nn.Linear(smilesWidth, embed_dim)
        self.propertyProj = nn.Linear(propertyWidth, embed_dim)

        # special tokens & embedding for property input
        self.propertyEmbed = nn.Linear(1, propertyWidth)
        self.property_CLS = nn.Parameter(torch.zeros([1, 1, propertyWidth]))
        self.property_MASK = nn.Parameter(torch.zeros([1, 1, propertyWidth]))
        

        self.temp = nn.Parameter(torch.ones([]) * config['temp'])
        self.queue_size = config['queue_size']
        self.momentum = config['momentum']
        
        self.itm_head_smiles = nn.Linear(smilesWidth, 2)
        self.itm_head_properties = nn.Linear(propertyWidth, 2)

        # Momentum Model

        self.smilesEncoder_m = BertForMaskedLM(config = smilesAndFusion_config)
        self.propertyEncoder_m = BertForMaskedLM(config = property_config)
        self.smilesProj_m = nn.Linear(smilesWidth, embed_dim)
        self.propertyProj_m = nn.Linear(propertyWidth, embed_dim)

        self.model_pairs = [[self.smilesEncoder, self.smilesEncoder_m],
                            [self.smilesProj, self.smilesProj_m],
                            [self.propertyEncoder, self.propertyEncoder_m],
                            [self.propertyProj, self.propertyProj_m]]
        
        self.copy_params()

        # Create the queue
        self.register_buffer("smiles_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.property_queue = nn.functional.normalize(self.property_queue, dim=0)
        self.smiles_queue = nn.functional.normalize(self.smiles_queue, dim=0)



    def forward(self, property, smilesIds, smilesAttentionMask, alpha=0):
        
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        #1. property tokenizing & embedding
        embedProperty = self.propertyEmbed(property.unsqueeze(2))
        
        property_MASK = self.property_MASK.expand(property.size(0), property.size(1), -1)
        halfMask = torch.bernoulli(torch.ones_like(property)*0.5)
        halfMaskBatch = halfMask.unsqueeze(2).repeat(1,1,property_MASK.size(2))
        maskedProperty = embedProperty* halfMaskBatch 
        inputProperty = torch.cat([self.property_CLS.expand(property.size(0), -1,-1), maskedProperty], dim=1)

        #2. input through encoders
        encProperty = self.propertyEncoder(inputs_embeds=inputProperty, return_dict=True).last_hidden_state
        propertyAtts = torch.ones(encProperty.size()[:-1], dtype=torch.long).to(inputProperty.device)
        propertyFeat = F.normalize(self.propertyProj(encProperty[:,0,:]), dim=-1)

        encSmiles = self.smilesEncoder.bert(smilesIds, attention_mask=smilesAttentionMask, return_dict=True).last_hidden_State
        smilesFeat = F.normalize(self.smilesProj(encSmiles[:,0,:]), dims=-1)
        
        #3. Contrastive Loss between the different & within the same modalities
        with torch.no_grad():
            self._momentum_update()
            
            encProperty_m = self.propertyEnoder_m(inputs_embeds=inputProperty, return_dic=True).last_hidden_state
            propertyAtts_m = torch.ones(encProperty_m.size()[:-1], dtype=torch.long).to(inputProperty.device)
            propertyFeat_m = F.normalize(self.propertyProj_m(encProperty_m[:,0,:]), dim=-1)
            propertyFeatAll = torch.cat([propertyFeat_m.t(), self.property_queue.clone().detach()], dim=1)
            
            encSmiles_m = self.smilesEncoder_m.bert(smilesIds,attention_mask=smilesAttentionMask, return_dict=True).last_hidden_state
            smilesFeat_m = F.normalize(self.smilesProj_m(encSmiles[:,0,:]), dim=-1)
            smilesFeatAll = torch.cat([smilesFeat_m.t(), self.smiles_queue.clone().detach()], dim=1)

            sim_p2s_m = propertyFeat_m @ smilesFeatAll / self.temp
            sim_s2p_m = smilesFeat_m @ propertyFeatAll / self.temp
            sim_p2p_m = propertyFeat_m @ propertyFeatAll / self.temp
            sim_s2s_m = smilesFeat_m @ smilesFeatAll / self.temp

            ## Make Target ##
            sim_targets_diff = torch.zeros(sim_p2s_m.size()).to(property.device)
            sim_targets_diff.fill_diagonal_(1)
            sim_targets_same = torch.zeros(sim_p2p_m.size()).to(property.device)
            sim_targets_same.fill_diagonal_(1)
          
            sim_p2s_targets = alpha * F.softmax(sim_p2s_m, dim=1) + (1-alpha) * sim_targets_diff
            sim_s2p_targets = alpha * F.softmax(sim_s2p_m, dim=1) + (1-alpha) * sim_targets_diff
            sim_p2p_targets = alpha * F.softmax(sim_p2p_m, dim=1) + (1-alpha) * sim_targets_same
            sim_s2s_targets = alpha * F.softmax(sim_s2s_m, dim=1) + (1-alpha) * sim_targets_same

        sim_p2s = propertyFeat @ smilesFeatAll / self.temp
        sim_s2p = smilesFeat @ propertyFeatAll / self.temp
        sim_p2p = propertyFeat @ propertyFeatAll / self.temp
        sim_s2s = smilesFeat @ smilesFeatAll / self.temp 

        loss_p2s = -torch.sum(F.log_softmax(sim_p2s, dim=1)*sim_p2s_targets, dim=1).mean()
        loss_s2p = -torch.sum(F.log_softmax(sim_s2p, dim=1)*sim_s2p_targets, dim=1).mean()
        loss_p2p = -torch.sum(F.log_softmax(sim_p2p, dim=1)*sim_p2p_targets, dim=1).mean()
        loss_s2s = -torch.sum(F.log_softmax(sim_p2p, dim=1)*sim_s2s_targets, dim=1).mean()

        loss_psc = (loss_p2s + loss_s2p + loss_p2p + loss_s2s)/2 

        self._dequeue_and_enqueue(propertyFeat_m, smilesFeat_m)

        #4. X-attention
        

        #5. Next property prediction

        #6. Next word prediction

        #7. SMILES-property matching 
    
    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False
    
    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameter(), model_pair[1].parameter()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, property_feat, smiles_feat):
        property_feats = concat_all_gather(property_feat)
        smiles_feats = concat_all_gather(smiles_feat)

        batch_size = property_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0

        self.property_queue[:, ptr:ptr + batch_size] = property_feats.T
        self.smiles_queue[:, ptr:ptr + batch_size] = smiles_feats.T
        ptr = (ptr + batch_size) & self.queue_size 

        self.queue_ptr[0] = ptr

@torch.no_grad()
def concat_all_gather(tensor):
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output

------------------------------------------------------------------------------

In [11]:
sample_SMILES = sample[0]
sample_prop = sample[1]

SMILES_token = tokenizer(sample_SMILES)
input_ids = torch.LongTensor([SMILES_token['input_ids']])
print(input_ids)
attention_mask = torch.Tensor([SMILES_token['attention_mask']])
print(attention_mask)

tensor([[  2,   5, 146,   8, 212, 112, 115,  98, 222, 114,  98,   3]])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])


In [43]:
from xbert import BertConfig 
from transformers import BertTokenizer, BertForMaskedLM

smilesAndFusion_config = BertConfig.from_json_file('./config_bert_smiles_and_fusion_encoder.json')
smilesEncoder = BertForMaskedLM(config = smilesAndFusion_config)

property_config = BertConfig.from_json_file('./config_bert_property_encoder.json')
propertyEncoder = BertForMaskedLM(config = property_config)


In [44]:
import torch 
from torch import nn 

propertyOriginal = torch.rand([1,53])
embedProperty = nn.Linear(1, 384)(propertyOriginal.unsqueeze(2))
#print(embedProperty)

property_CLS = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = nn.Parameter(torch.zeros([1,1,384]))
property_MASK = property_MASK.expand(propertyOriginal.size(0), propertyOriginal.size(1), -1)

halfMask = torch.bernoulli(torch.ones_like(propertyOriginal)*0.5)
halfMaskBatch = halfMask.unsqueeze(2).repeat(1,1,property_MASK.size(2))
maskedProperty = embedProperty*halfMaskBatch + property_MASK*halfMaskBatch

print((embedProperty*halfMaskBatch))
#print(property_MASK*(1-halfMaskBatch))

inputProperty = torch.cat([property_CLS.expand(propertyOriginal.size(0), -1,-1), maskedProperty], dim=1)

#print(inputProperty.shape)

tensor([[[ 0.4342,  0.9335,  1.0022,  ..., -0.1214, -0.9573,  0.0120],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000,  0.0000],
         [ 0.4628,  1.6241,  1.3159,  ...,  0.1610, -0.8612,  0.0640],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ..., -0.0000, -0.0000, -0.0000]]],
       grad_fn=<MulBackward0>)


In [47]:
encodedProp = propertyEncoder.bert(inputs_embeds=inputProperty, return_dict=True).last_hidden_state
print(encodedProp.shape)
encodedSmiles = smilesEncoder.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
print(encodedSmiles.shape)

torch.Size([1, 54, 384])
torch.Size([1, 12, 384])


In [48]:
import torch.nn.functional as F 

propertyProj = nn.Linear(384, 384)

propertyAtts = torch.ones(encodedProp.size()[:-1], dtype=torch.long).to(inputProperty.device)
propertyFeat = F.normalize(propertyProj(encodedProp[:,0,:]), dim=-1)

property_queue = torch.randn(384, 2048)
property_queue.clone().detach()

propertyFeatAll = torch.cat([propertyFeat.t(), property_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[ 2.1823e-02,  3.7101e-01, -2.9748e-02,  ..., -9.2558e-03,
         -8.8169e-01, -2.2586e-01],
        [-1.0048e-02, -1.2382e+00, -3.4466e-02,  ...,  6.1353e-01,
         -5.9733e-01, -1.2387e+00],
        [-3.7422e-04, -4.3158e-01,  1.9985e+00,  ...,  2.3472e-01,
         -7.2809e-01,  3.7164e-01],
        ...,
        [ 5.6999e-02, -4.2010e-01,  1.7926e-01,  ..., -1.2600e+00,
          1.7435e+00, -6.6406e-01],
        [ 4.6999e-02, -9.1359e-01,  1.2246e-01,  ...,  2.9834e+00,
         -1.1048e+00, -1.1820e+00],
        [-1.0473e-01, -1.1189e+00,  3.1662e-01,  ...,  2.8940e+00,
          3.0964e-02, -1.6373e-01]], grad_fn=<CatBackward0>)

In [49]:
smilesProj = nn.Linear(384, 384)
smilesFeat = F.normalize(smilesProj(encodedSmiles[:,0,:]), dim=-1)

smiles_queue = torch.randn(384, 2048)
smiles_queue.clone().detach()

smilesFeatAll = torch.cat([smilesFeat.t(), smiles_queue.clone().detach()], dim=1)
propertyFeatAll

tensor([[ 2.1823e-02,  3.7101e-01, -2.9748e-02,  ..., -9.2558e-03,
         -8.8169e-01, -2.2586e-01],
        [-1.0048e-02, -1.2382e+00, -3.4466e-02,  ...,  6.1353e-01,
         -5.9733e-01, -1.2387e+00],
        [-3.7422e-04, -4.3158e-01,  1.9985e+00,  ...,  2.3472e-01,
         -7.2809e-01,  3.7164e-01],
        ...,
        [ 5.6999e-02, -4.2010e-01,  1.7926e-01,  ..., -1.2600e+00,
          1.7435e+00, -6.6406e-01],
        [ 4.6999e-02, -9.1359e-01,  1.2246e-01,  ...,  2.9834e+00,
         -1.1048e+00, -1.1820e+00],
        [-1.0473e-01, -1.1189e+00,  3.1662e-01,  ...,  2.8940e+00,
          3.0964e-02, -1.6373e-01]], grad_fn=<CatBackward0>)

In [81]:
sim_p2s = (propertyFeat @ smilesFeatAll)
sim_targets = torch.zeros(sim_p2s.size()).to(propertyOriginal.device)
print(sim_targets.shape)

sim_s2p = smilesFeat @ propertyFeatAll
sim_targets_same = torch.zeros(sim_s2p.size()).to(propertyOriginal.device)
print(sim_targets.shape)

torch.Size([1, 2049])
torch.Size([1, 2049])


In [78]:
outputProperty_pos = smilesEncoder.bert(inputs_embeds = encodedProp,
                               attention_mask = propertyAtts,
                               encoder_hidden_states = encodedSmiles,
                               encoder_attention_mask = attention_mask,
                               return_dict = True
                               ).last_hidden_state[:,0,:]
outputSmiles_pos = smilesEncoder.bert(inputs_embeds = encodedSmiles,
                             attention_mask = attention_mask,
                             encoder_hidden_states = encodedProp,
                             encoder_attention_mask = propertyAtts,
                             return_dict = True
                             ).last_hidden_state[:,0,:]
                                   

In [80]:
print(outputProperty_pos.shape)
print(outputSmiles_pos.shape)

torch.Size([1, 384])
torch.Size([1, 384])


In [83]:
batch_size = 1
weights_p2s = F.softmax(sim_p2s[:,:batch_size], dim=1)
weights_s2p = F.softmax(sim_s2p[:,:batch_size], dim=1)

In [103]:
encodedProp_neg = []
neg_idx = torch.multinomial(weights_p2s[0], 1).item()
encodedProp_neg.append(encodedProp[neg_idx])
encodedProp_neg = torch.stack(encodedProp_neg, dim=0)
encodedProp_neg

tensor([[[-1.7271, -1.5947, -2.0154,  ...,  0.0620, -0.7125, -0.1401],
         [ 0.4209,  0.9124,  0.4805,  ..., -0.4680, -1.8028,  0.2023],
         [-0.1412, -1.7926, -2.4075,  ..., -0.4396, -2.6582, -0.0066],
         ...,
         [ 0.1181,  0.5215, -1.7301,  ...,  0.4303, -1.9432, -2.0211],
         [-0.4499,  0.0432, -1.5228,  ..., -0.6334, -1.4950,  0.2614],
         [ 0.8617, -0.8574, -0.6076,  ..., -0.0319, -1.1773, -0.8150]]],
       grad_fn=<StackBackward0>)

In [104]:
encodedSmiles_neg = []
smilesAtts_neg = []

neg_idx = torch.multinomial(weights_s2p[0], 1).item()
encodedSmiles_neg.append(encodedSmiles[neg_idx])
smilesAtts_neg.append(attention_mask[neg_idx])
encodedSmiles_neg = torch.stack(encodedSmiles_neg, dim=0)
smilesAtts_neg = torch.stack(smilesAtts_neg, dim=0)

encodedSmiles_neg

tensor([[[ 1.5407,  2.3170, -0.0661,  ..., -0.7965, -1.6035, -0.4798],
         [ 0.9947,  1.2220, -1.7151,  ..., -0.7796, -0.6755,  0.8914],
         [ 1.4788,  0.9501, -0.7413,  ..., -1.0248, -1.2226, -0.6620],
         ...,
         [-0.4692,  0.7824, -0.2676,  ..., -0.4672, -2.1646, -0.0563],
         [ 0.7964,  1.0937, -1.0817,  ...,  0.4254, -0.1040, -0.9483],
         [-1.6384, -0.5343, -0.6197,  ..., -0.3521, -1.5425,  0.5943]]],
       grad_fn=<StackBackward0>)

In [105]:
encProperty_all = torch.cat([encodedProp, encodedProp_neg], dim=0)
propertyAtts_all = torch.cat([propertyAtts, propertyAtts], dim=0)

encSmiles_all = torch.cat([encodedSmiles_neg, encodedSmiles], dim=0)
smilesAtts_all = torch.cat([smilesAtts_neg, attention_mask], dim=0)

In [106]:
outputProperty_neg = smilesEncoder.bert(inputs_embeds = encProperty_all,
                                    attention_mask = propertyAtts_all,
                                    encoder_hidden_states = encSmiles_all,
                                    encoder_attention_mask = smilesAtts_all,
                                    return_dict = True
                                    ).last_hidden_state[:,0,:]
#outputProperty_neg = outputProperty_neg
outputSmiles_neg = smilesEncoder.bert(inputs_embeds = encSmiles_all,
                                    attention_mask = smilesAtts_all,
                                    encoder_hidden_states = encProperty_all,
                                    encoder_attention_mask = encProperty_all,
                                    return_dict = True
                                    ).last_hidden_state[:,0,:]
#outputSmiles_neg = outputSmiles_neg

In [114]:
pos_embeds = torch.cat([outputProperty_pos, outputSmiles_pos], dim=-1)
neg_embeds = torch.cat([outputProperty_neg, outputSmiles_neg], dim=-1)

ps_embeddings = torch.cat([pos_embeds, neg_embeds], dim=0)
ps_embeddings.shape
ps_output = nn.Linear(768, 2)(ps_embeddings)
ps_output

tensor([[-1.1989, -0.7762],
        [-1.2333, -0.6026],
        [-1.3641, -0.4682]], grad_fn=<AddmmBackward0>)

In [113]:
psm_labels = torch.cat([torch.ones(1, dtype=torch.long), torch.zeros(2, dtype=torch.long)], dim=0).to(propertyOriginal.device)
psm_labels

tensor([1, 0, 0])

In [116]:
loss_psm = F.cross_entropy(ps_output, psm_labels)
loss_psm

tensor(0.9332, grad_fn=<NllLossBackward0>)