In [2]:
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer, RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import get_scheduler

import torch
from torch import tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd
from torch import nn
from sklearn.model_selection import train_test_split
import torch.optim as optim


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = AutoModelForSequenceClassification.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k", output_attentions=True, output_hidden_states=True, num_labels = 1)
tokenizer = AutoTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at seyonec/PubChem10M_SMILES_BPE_450k and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
#Folder to save the model checkpoints
training_args = TrainingArguments(output_dir="test_trainer")

In [5]:
#load wipf dataframe
df = pd.read_excel('../data/wipf_27.6.xlsx')

In [6]:
df

Unnamed: 0.1,Unnamed: 0,_Reaction__reactants,_Reaction__solvents,_Reaction__products,_Reaction__rxn_smiles,_Reaction__volume,_Reaction__conditions,_Reaction__class_id,_Reaction__fps,_Reaction__mhfp,_Reaction__id,_Reaction__quantity,has_solvent,product_yield,n_products,has_yield,has_yield_and_solvent,solvent_smiles,has_solvent_smiles,has_yield_and_solvent_smiles
0,0,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Cc1ccc(S(=O)(=O)NC[C@@H](COCc2ccccc2)OC(=O)OC(...,0.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00320.026,0,False,85.20,1,True,False,,False,False
1,1,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'dimethylformamide', '_Sol...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",OC[C@H]1CO1.BrCc1ccccc1.C1=CC(=CN=C1)C(=O)O>>c...,100.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00320.025,0,True,59.91,1,True,True,CN(C)C=O,True,True
2,2,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'dichloromethane', '_Solve...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Cc1ccc(S(=O)(=O)N(C[C@@H](O)COCc2ccccc2)C[C@@H...,0.288208,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00320.024,0,True,31.99,1,True,True,C(Cl)Cl,True,True
3,3,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",c1ccc(COC[C@@H]2CO2)cc1.Cc1ccc(S(=O)(=O)NC(=O)...,0.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00320.023,0,False,30.30,1,True,False,,False,False
4,4,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'methanol', '_Solvent__vol...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",COC(=O)[C@@]1(C)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@@...,10.114293,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00305.045,0,True,50.00,1,True,True,CO,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18214,18214,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.N#Cc1ccccc1...,0.239240,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00001.004,0,False,,1,False,False,,False,False
18215,18215,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.COc1ccccc1N...,0.234090,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00001.003,0,False,3.20,1,True,False,,False,False
18216,18216,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.Nc1ccccc1.c...,0.212445,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00001.002,0,False,,1,False,False,,False,False
18217,18217,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '4-(6-Amino-...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.CCOC(=O)c1c...,0.537600,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,NB-00001.001,0,False,0.45,1,True,False,,False,False


In [7]:
df['rxn_smiles_with_solvent'] = df['_Reaction__rxn_smiles'] + df['solvent_smiles']

In [8]:
def add_solvent_to_rxn(rxn, solvent):
        if not (isinstance(rxn, str) and isinstance(solvent, str)):
             return None
        if len(solvent)>0 and len(rxn)>0:
            splitted_rxn = rxn.split('>>')
            #print(splitted_rxn)
            if len(splitted_rxn)==3:
                #if reactants+agents=products then add solvent to agents
                agents = splitted_rxn[1]
                agents += "." + solvent
                splitted_rxn = splitted_rxn[0] + '>>' + agents + '>>' + splitted_rxn[2]
            elif len(splitted_rxn)== 2:
                #if reactants = products then add solvent as only agent
                agents = solvent
                splitted_rxn = splitted_rxn[0] + '>>'+  agents + '>>' + splitted_rxn[1]
            else:
                print('failed adding solvent')
                splitted_rxn = None

        return splitted_rxn



In [9]:
df['rxn_smiles_with_solvent'] = df[['_Reaction__rxn_smiles', 'solvent_smiles']].apply(lambda x: add_solvent_to_rxn(x[0], x[1]), axis = 1)

  df['rxn_smiles_with_solvent'] = df[['_Reaction__rxn_smiles', 'solvent_smiles']].apply(lambda x: add_solvent_to_rxn(x[0], x[1]), axis = 1)


In [10]:
df

Unnamed: 0.1,Unnamed: 0,_Reaction__reactants,_Reaction__solvents,_Reaction__products,_Reaction__rxn_smiles,_Reaction__volume,_Reaction__conditions,_Reaction__class_id,_Reaction__fps,_Reaction__mhfp,...,_Reaction__quantity,has_solvent,product_yield,n_products,has_yield,has_yield_and_solvent,solvent_smiles,has_solvent_smiles,has_yield_and_solvent_smiles,rxn_smiles_with_solvent
0,0,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Cc1ccc(S(=O)(=O)NC[C@@H](COCc2ccccc2)OC(=O)OC(...,0.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,85.20,1,True,False,,False,False,
1,1,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'dimethylformamide', '_Sol...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",OC[C@H]1CO1.BrCc1ccccc1.C1=CC(=CN=C1)C(=O)O>>c...,100.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,True,59.91,1,True,True,CN(C)C=O,True,True,OC[C@H]1CO1.BrCc1ccccc1.C1=CC(=CN=C1)C(=O)O>>C...
2,2,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'dichloromethane', '_Solve...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Cc1ccc(S(=O)(=O)N(C[C@@H](O)COCc2ccccc2)C[C@@H...,0.288208,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,True,31.99,1,True,True,C(Cl)Cl,True,True,Cc1ccc(S(=O)(=O)N(C[C@@H](O)COCc2ccccc2)C[C@@H...
3,3,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",c1ccc(COC[C@@H]2CO2)cc1.Cc1ccc(S(=O)(=O)NC(=O)...,0.000000,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,30.30,1,True,False,,False,False,
4,4,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...","[{'_Solvent__name': 'methanol', '_Solvent__vol...","[{'_Molecule__inchi': '', 'name': '', 'smiles'...",COC(=O)[C@@]1(C)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@@...,10.114293,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,True,50.00,1,True,True,CO,True,True,COC(=O)[C@@]1(C)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@@...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18214,18214,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.N#Cc1ccccc1...,0.239240,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,,1,False,False,,False,False,
18215,18215,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.COc1ccccc1N...,0.234090,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,3.20,1,True,False,,False,False,
18216,18216,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.Nc1ccccc1.c...,0.212445,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,,1,False,False,,False,False,
18217,18217,"[{'_Molecule__inchi': '', 'name': '', 'smiles'...",[],"[{'_Molecule__inchi': '', 'name': '4-(6-Amino-...",Nc1[nH]c(=O)[nH]c(=O)c1S(=O)(=O)Cl.CCOC(=O)c1c...,0.537600,"{'pressure': None, 'temperature': None, 'durat...",0,[],0,...,0,False,0.45,1,True,False,,False,False,


In [11]:
df['has_yield_and_rxn_solvent_smiles'] = df[['has_yield', 'rxn_smiles_with_solvent']].apply(lambda x: True if x[0] and isinstance(x[1],str) else False, axis=1)


  df['has_yield_and_rxn_solvent_smiles'] = df[['has_yield', 'rxn_smiles_with_solvent']].apply(lambda x: True if x[0] and isinstance(x[1],str) else False, axis=1)


In [12]:
df['rxn_smiles_with_solvent_len'] = df['rxn_smiles_with_solvent'].apply(lambda x: len(x) if isinstance(x, str) else 0)


In [13]:
df['product_yield'] = df['product_yield']/100

In [14]:
#model.classifier = nn.Sequential(
#    nn.Linear(768, 258),
#    nn.Linear(258, 1),
#    nn.Sigmoid()
#)
model.classifier

RobertaClassificationHead(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (out_proj): Linear(in_features=768, out_features=1, bias=True)
)

In [15]:
#Preparing dataset
class smilesDataset(Dataset):
        def __init__(self, df) -> None:
                super().__init__()
                self.smiles = df['rxn_smiles_with_solvent']
                self.labels = df['product_yield']
        
        def __len__(self):
                return self.smiles.shape[0]
        
        def __getitem__(self, index) -> None:
                return tensor(tokenizer.encode(self.smiles.iloc[index], padding="max_length", max_length=512, truncation= True)), tensor(self.labels.iloc[index])


In [31]:
mask = df['has_yield_and_rxn_solvent_smiles']== True
train_df, test_df = train_test_split(df[mask])
train_dataset = smilesDataset(train_df)
eval_dataset = smilesDataset(test_df)

In [32]:
batch_size = 1
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)

In [33]:
df[mask]['rxn_smiles_with_solvent'].shape

(7520,)

In [34]:
df[mask]['product_yield'].shape

(7520,)

In [35]:
for X, y in eval_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([1, 512])
Shape of y: torch.Size([1]) torch.float64


In [36]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
model.to(device)
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr= 0.001)

cpu


In [37]:

num_epochs = 3

num_training_steps = num_epochs * len(train_dataloader)

lr_scheduler = get_scheduler(

    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps

)

In [38]:
from tqdm.auto import tqdm
progress_bar = tqdm(range(num_training_steps))

model.train()

for epoch in range(num_epochs):
    epoch_cum_loss = 0
    for i, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X, y = batch[0].to(device), batch[1].to(device)
        attention_mask = X>1
        pred = torch.sigmoid(model(X, attention_mask=attention_mask).logits)
        pred = pred.reshape(y.shape)
        loss = criterion(pred.float(), y.float())
        epoch_cum_loss += loss
        print(f'batch number {i}, batch cumloss {epoch_cum_loss/(i+1)}')
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)
        
            



  0%|          | 0/16920 [00:00<?, ?it/s]

batch number 0, batch cumloss 0.01737332157790661


  0%|          | 1/16920 [00:02<9:41:38,  2.06s/it]

batch number 1, batch cumloss 0.08317232131958008


  0%|          | 2/16920 [00:04<9:23:07,  2.00s/it]

batch number 2, batch cumloss 0.15166504681110382


  0%|          | 3/16920 [00:06<9:43:41,  2.07s/it]

batch number 3, batch cumloss 0.11374939233064651


  0%|          | 4/16920 [00:07<8:59:27,  1.91s/it]

batch number 4, batch cumloss 0.09099951386451721


  0%|          | 5/16920 [00:09<8:43:13,  1.86s/it]

batch number 5, batch cumloss 0.08141861110925674


  0%|          | 6/16920 [00:11<9:12:59,  1.96s/it]

batch number 6, batch cumloss 0.10849408060312271


  0%|          | 7/16920 [00:13<9:23:19,  2.00s/it]

batch number 7, batch cumloss 0.09872269630432129


  0%|          | 8/16920 [00:16<9:55:04,  2.11s/it]

batch number 8, batch cumloss 0.10943450778722763


  0%|          | 9/16920 [00:19<11:05:50,  2.36s/it]

batch number 9, batch cumloss 0.12335677444934845


  0%|          | 10/16920 [00:21<11:20:21,  2.41s/it]

batch number 10, batch cumloss 0.15288445353507996


  0%|          | 11/16920 [00:23<10:53:06,  2.32s/it]

batch number 11, batch cumloss 0.16245684027671814


  0%|          | 12/16920 [00:26<11:23:48,  2.43s/it]

batch number 12, batch cumloss 0.15104401111602783


  0%|          | 13/16920 [00:29<11:48:48,  2.52s/it]

batch number 13, batch cumloss 0.1513596922159195


  0%|          | 14/16920 [00:32<13:14:43,  2.82s/it]

batch number 14, batch cumloss 0.15569165349006653


KeyboardInterrupt: 

In [22]:
tokenizer.pad(X)

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

In [25]:
X>1

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, F

In [30]:
tokenizer.decode(X.reshape(512))

'<s>CC(=O)C1=CSC=C1.NO.Cl.C(=O)([O-])[O-].[Na+].[Na+]>>O>>C/C(=N/O)c1ccsc1</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><p

In [29]:
X.shape

torch.Size([1, 512])