In [3]:
import torch 
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns 
from transformers import BertTokenizer, BertTokenizerFast,BertForSequenceClassification,Trainer,TrainingArguments
import pandas as pd 
from datasets import Dataset 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [10]:
data = pd.read_csv('balanced_dataset.csv')
data.head()

Unnamed: 0,molecule_smiles,protein_name,binds,molecule,ecfp
0,O=C1CCCc2ccc(Nc3nc(Nc4ncns4)nc(N[C@H](Cc4ccccc...,sEH,0,<rdkit.Chem.rdchem.Mol object at 0x000001B4813...,"[0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,O=C(N[Dy])[C@H](Cc1cccnc1)Nc1nc(Nc2cn[nH]c2)nc...,HSA,1,<rdkit.Chem.rdchem.Mol object at 0x000001B4813...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,O=C1CN(CCCNc2nc(NCC3CCC(C(=O)N[Dy])CC3)nc(Nc3c...,sEH,1,<rdkit.Chem.rdchem.Mol object at 0x000001B4813...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,CC(C)CC(Nc1nc(NCc2cccc3c2OCO3)nc(Nc2cccc3ncccc...,BRD4,1,<rdkit.Chem.rdchem.Mol object at 0x000001B4813...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,O=C(N[Dy])c1ccc(Nc2nc(NCc3cc(C(F)(F)F)co3)nc(N...,sEH,0,<rdkit.Chem.rdchem.Mol object at 0x000001B4813...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."


In [11]:
data['input'] = data['molecule_smiles']+ '[SEP]' + data['protein_name']

In [12]:
data = data.drop(['molecule_smiles','protein_name','molecule','ecfp'],axis=1)
data.head()

Unnamed: 0,binds,input
0,0,O=C1CCCc2ccc(Nc3nc(Nc4ncns4)nc(N[C@H](Cc4ccccc...
1,1,O=C(N[Dy])[C@H](Cc1cccnc1)Nc1nc(Nc2cn[nH]c2)nc...
2,1,O=C1CN(CCCNc2nc(NCC3CCC(C(=O)N[Dy])CC3)nc(Nc3c...
3,1,CC(C)CC(Nc1nc(NCc2cccc3c2OCO3)nc(Nc2cccc3ncccc...
4,0,O=C(N[Dy])c1ccc(Nc2nc(NCc3cc(C(F)(F)F)co3)nc(N...


In [13]:
dataset = Dataset.from_pandas(data)

In [14]:
dataset 

Dataset({
    features: ['binds', 'input'],
    num_rows: 30000
})

In [15]:
dataset = dataset.rename_columns({'input':'text','binds':'label'})
dataset 

Dataset({
    features: ['label', 'text'],
    num_rows: 30000
})

In [16]:
dataset = dataset.train_test_split(test_size=0.2)

In [17]:
dataset 

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 6000
    })
})

In [18]:
train_df = dataset['train']
test_df = dataset['test']

In [19]:
id2label = {0:'No bind',1:'Bind'}
label2id = {'No bind':0,'Bind':1}

In [20]:
checkpoint = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint,num_labels=2,label2id=label2id,id2label=id2label)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
def preprocess_function(examples):
    return tokenizer(examples['text'],truncation=True)

In [22]:
tokenize_train = train_df.map(preprocess_function,batched=True)
tokenize_test = test_df.map(preprocess_function,batched=True)

Map:   0%|          | 0/24000 [00:00<?, ? examples/s]

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

In [23]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [24]:
import evaluate
accuracy = evaluate.load("accuracy")

In [25]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [29]:
training_args = TrainingArguments(
    output_dir='Belka-BERT',
    overwrite_output_dir=True,
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    logging_dir='logs',
    report_to='wandb',
    learning_rate=2e-5,
)

In [30]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenize_train,
    eval_dataset=tokenize_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [32]:
trainer.train()

  0%|          | 0/12000 [00:00<?, ?it/s]

{'loss': 0.4, 'grad_norm': 1.85660719871521, 'learning_rate': 1.916666666666667e-05, 'epoch': 0.17}
{'loss': 0.3965, 'grad_norm': 3.433006763458252, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.33}
{'loss': 0.405, 'grad_norm': 5.180685043334961, 'learning_rate': 1.7500000000000002e-05, 'epoch': 0.5}
{'loss': 0.4052, 'grad_norm': 5.497239112854004, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.67}
{'loss': 0.3743, 'grad_norm': 9.171781539916992, 'learning_rate': 1.5833333333333333e-05, 'epoch': 0.83}
{'loss': 0.3772, 'grad_norm': 12.729491233825684, 'learning_rate': 1.5000000000000002e-05, 'epoch': 1.0}


  0%|          | 0/750 [00:00<?, ?it/s]

{'eval_loss': 0.3557247817516327, 'eval_accuracy': 0.8701666666666666, 'eval_runtime': 9.7668, 'eval_samples_per_second': 614.326, 'eval_steps_per_second': 76.791, 'epoch': 1.0}
{'loss': 0.3577, 'grad_norm': 22.10750389099121, 'learning_rate': 1.416666666666667e-05, 'epoch': 1.17}
{'loss': 0.3792, 'grad_norm': 21.68719482421875, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.33}
{'loss': 0.3557, 'grad_norm': 2.89447021484375, 'learning_rate': 1.25e-05, 'epoch': 1.5}
{'loss': 0.3584, 'grad_norm': 7.0240702629089355, 'learning_rate': 1.1666666666666668e-05, 'epoch': 1.67}
{'loss': 0.3453, 'grad_norm': 2.610649824142456, 'learning_rate': 1.0833333333333334e-05, 'epoch': 1.83}
{'loss': 0.3531, 'grad_norm': 6.551897048950195, 'learning_rate': 1e-05, 'epoch': 2.0}


  0%|          | 0/750 [00:00<?, ?it/s]

{'eval_loss': 0.3554113805294037, 'eval_accuracy': 0.8765, 'eval_runtime': 10.2717, 'eval_samples_per_second': 584.127, 'eval_steps_per_second': 73.016, 'epoch': 2.0}
{'loss': 0.3292, 'grad_norm': 7.090052127838135, 'learning_rate': 9.166666666666666e-06, 'epoch': 2.17}
{'loss': 0.304, 'grad_norm': 38.330440521240234, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.33}
{'loss': 0.3428, 'grad_norm': 12.561302185058594, 'learning_rate': 7.500000000000001e-06, 'epoch': 2.5}
{'loss': 0.3247, 'grad_norm': 5.386176109313965, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.67}
{'loss': 0.3209, 'grad_norm': 0.47272154688835144, 'learning_rate': 5.833333333333334e-06, 'epoch': 2.83}
{'loss': 0.3079, 'grad_norm': 25.598342895507812, 'learning_rate': 5e-06, 'epoch': 3.0}


  0%|          | 0/750 [00:00<?, ?it/s]

{'eval_loss': 0.37017762660980225, 'eval_accuracy': 0.888, 'eval_runtime': 21.3521, 'eval_samples_per_second': 281.003, 'eval_steps_per_second': 35.125, 'epoch': 3.0}
{'loss': 0.29, 'grad_norm': 1.1194514036178589, 'learning_rate': 4.166666666666667e-06, 'epoch': 3.17}
{'loss': 0.2907, 'grad_norm': 17.43212127685547, 'learning_rate': 3.3333333333333333e-06, 'epoch': 3.33}
{'loss': 0.2908, 'grad_norm': 13.86556339263916, 'learning_rate': 2.5e-06, 'epoch': 3.5}
{'loss': 0.2945, 'grad_norm': 1.5739933252334595, 'learning_rate': 1.6666666666666667e-06, 'epoch': 3.67}
{'loss': 0.2979, 'grad_norm': 69.04976654052734, 'learning_rate': 8.333333333333333e-07, 'epoch': 3.83}
{'loss': 0.2688, 'grad_norm': 5.3405022621154785, 'learning_rate': 0.0, 'epoch': 4.0}


  0%|          | 0/750 [00:00<?, ?it/s]

{'eval_loss': 0.3811909556388855, 'eval_accuracy': 0.8923333333333333, 'eval_runtime': 9.606, 'eval_samples_per_second': 624.61, 'eval_steps_per_second': 78.076, 'epoch': 4.0}
{'train_runtime': 871.4642, 'train_samples_per_second': 110.159, 'train_steps_per_second': 13.77, 'train_loss': 0.34040866724650065, 'epoch': 4.0}


TrainOutput(global_step=12000, training_loss=0.34040866724650065, metrics={'train_runtime': 871.4642, 'train_samples_per_second': 110.159, 'train_steps_per_second': 13.77, 'train_loss': 0.34040866724650065, 'epoch': 4.0})

In [33]:
##Inference
dataset 

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 6000
    })
})

In [39]:
train_df[100]

{'label': 1,
 'text': 'Cc1conc1CNc1nc(Nc2cccnc2C)nc(N[C@H](CC(=O)N[Dy])c2ccc(Cl)cc2)n1[SEP]BRD4'}

In [35]:
from transformers import pipeline

In [56]:
classifier = pipeline('text-classification',model="Belka-BERT\checkpoint-12000",tokenizer=tokenizer,device=0) #device=0 for GPU

In [57]:
classifier(train_df['text'][100])

[{'label': 'Bind', 'score': 0.9913012385368347}]

In [4]:
test_df = pd.read_parquet('test.parquet')

In [5]:
test_df.head()

Unnamed: 0,id,buildingblock1_smiles,buildingblock2_smiles,buildingblock3_smiles,molecule_smiles,protein_name
0,295246830,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,BRD4
1,295246831,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,HSA
2,295246832,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,C=Cc1ccc(N)cc1,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C...,sEH
3,295246833,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,BRD4
4,295246834,C#CCCC[C@H](NC(=O)OCC1c2ccccc2-c2ccccc21)C(=O)O,C=Cc1ccc(N)cc1,CC(O)Cn1cnc2c(N)ncnc21,C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2...,HSA


In [7]:
test_df['protein_name'][10]

'HSA'

In [44]:
text = test_df['molecule_smiles'][0] + '[SEP]' + test_df['protein_name'][0]
print(text)

C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]BRD4


In [50]:
classifier(text)[0]

{'label': 'No bind', 'score': 0.9229413866996765}

In [46]:
test_df['inputs'] = test_df['molecule_smiles'] + '[SEP]' + test_df['protein_name']

In [47]:
inputs = test_df['inputs'].tolist()

In [48]:
inputs 

['C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]BRD4',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]HSA',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]sEH',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2ncn3CC(C)O)n1)C(=O)N[Dy][SEP]BRD4',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2ncn3CC(C)O)n1)C(=O)N[Dy][SEP]HSA',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ncnc3c2ncn3CC(C)O)n1)C(=O)N[Dy][SEP]sEH',
 'C#CCCC[C@H](Nc1nc(NCC2(O)CCCC2(C)C)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]BRD4',
 'C#CCCC[C@H](Nc1nc(NCC2(O)CCCC2(C)C)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]HSA',
 'C#CCCC[C@H](Nc1nc(NCC2(O)CCCC2(C)C)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy][SEP]sEH',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2sc(Cl)cc2C(=O)OC)n1)C(=O)N[Dy][SEP]BRD4',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2sc(Cl)cc2C(=O)OC)n1)C(=O)N[Dy][SEP]HSA',
 'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2sc(Cl)cc2C(=O)OC)n1)C(=O)N[Dy][SEP]sEH',
 'C#CCCC[C@H](Nc1nc(NCC2CCC(SC)CC

In [53]:
test_dataset = Dataset.from_dict({'text':inputs})

In [54]:
test_dataset

Dataset({
    features: ['text'],
    num_rows: 1674896
})

In [61]:
classifier(test_dataset['text'][:1000])

[{'label': 'No bind', 'score': 0.9229413866996765},
 {'label': 'No bind', 'score': 0.9387506246566772},
 {'label': 'No bind', 'score': 0.9653714895248413},
 {'label': 'No bind', 'score': 0.9771221280097961},
 {'label': 'No bind', 'score': 0.9615176916122437},
 {'label': 'No bind', 'score': 0.9913405776023865},
 {'label': 'No bind', 'score': 0.994140088558197},
 {'label': 'No bind', 'score': 0.9922695159912109},
 {'label': 'No bind', 'score': 0.9973917007446289},
 {'label': 'No bind', 'score': 0.9797585010528564},
 {'label': 'No bind', 'score': 0.9690374135971069},
 {'label': 'No bind', 'score': 0.9950790405273438},
 {'label': 'No bind', 'score': 0.9045153856277466},
 {'label': 'No bind', 'score': 0.9516867995262146},
 {'label': 'Bind', 'score': 0.8041315078735352},
 {'label': 'No bind', 'score': 0.9332420825958252},
 {'label': 'No bind', 'score': 0.9557273387908936},
 {'label': 'No bind', 'score': 0.9679782390594482},
 {'label': 'No bind', 'score': 0.945064902305603},
 {'label': 'No bi