In [5]:
import gradio as gr
import numpy as np 
import torch
from transformers import BertTokenizerFast,BertForSequenceClassification

In [6]:
def load_model(checkpoint_path):
    """Load the fine-tuned BERT classifier and its tokenizer from one checkpoint.

    Args:
        checkpoint_path: Path (or model id) accepted by ``from_pretrained``.

    Returns:
        ``(model, tokenizer)`` with the model already switched to eval mode.
    """
    tok = BertTokenizerFast.from_pretrained(checkpoint_path)
    clf = BertForSequenceClassification.from_pretrained(checkpoint_path)
    # Inference-only use: eval mode turns off training-time behaviour such as dropout.
    clf.eval()
    return clf, tok

In [14]:
def predict_binding(smiles, protein_name, *, model=None, tokenizer=None):
    """Predict whether a molecule binds a protein target.

    The SMILES string and the protein name are joined with a literal ``[SEP]``
    token and scored by a sequence-classification model.

    Args:
        smiles: SMILES string of the molecule.
        protein_name: Name of the protein target (e.g. ``'BRD4'``).
        model: Classifier to use. Defaults to the notebook-level ``model``
            returned by ``load_model``.
        tokenizer: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Returns:
        ``(class_name, score)``: the predicted label looked up in
        ``model.config.id2label`` and its softmax probability as a float.
    """
    # Fall back to the notebook globals so existing two-argument calls keep
    # working; passing them explicitly avoids hidden kernel state.
    if model is None:
        model = globals()["model"]
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]

    text = f'{smiles} [SEP] {protein_name}'
    tokens = tokenizer(text, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**tokens).logits

    # Softmax once and take both the probability and its index from the same
    # reduction, so the reported label and score always refer to one class.
    probs = logits.softmax(dim=-1)
    score, predicted_class_id = probs.max(dim=-1)
    class_name = model.config.id2label[predicted_class_id.item()]
    return class_name, score.item()

In [8]:
# Forward slashes keep the checkpoint path portable across operating systems;
# a backslash in a normal string literal is treated as an escape character.
CHECKPOINT_PATH = "Belka-BERT/checkpoint-12000"
model, tokenizer = load_model(CHECKPOINT_PATH)

In [15]:
# Sanity-check the loaded model on one hand-picked molecule / target pair.
example_smile = (
    'C#CCCC[C@H](Nc1nc(Nc2ccc(C=C)cc2)nc(Nc2ccc(C=C)cc2)n1)C(=O)N[Dy]'
)
example_protein = 'BRD4'

pred, score = predict_binding(smiles=example_smile, protein_name=example_protein)


In [16]:
# Predicted class on the first line, its softmax confidence on the second.
print(pred, score, sep="\n")

No bind
0.9229413866996765


In [17]:
def gradio_interface(smiles, protein):
    """Gradio callback: score the two textbox values with ``predict_binding``.

    Args:
        smiles: SMILES string typed into the first textbox.
        protein: Protein name typed into the second textbox.

    Returns:
        ``(label, confidence)`` in the order of the interface's output components.
    """
    label, confidence = predict_binding(smiles, protein)
    return label, confidence

In [23]:
# Two free-text inputs -> predicted class and its confidence.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Textbox(label="SMILES String"), gr.Textbox(label="Protein Name")],
    # The confidence is a softmax probability (a float), so it is shown with a
    # numeric component; gr.Label is intended for class names / label->score maps.
    outputs=[gr.Label(label="Binding Prediction"), gr.Number(label="Confidence Score")],
    title="Protein-Molecule Binding Prediction",
    description="Predicts whether a molecule will bind to a protein.",
)

IMPORTANT: You are using gradio version 3.48.0, however version 4.29.0 is available, please upgrade.
--------


In [24]:
iface.launch(
    # Besides the local server, share=True publishes a temporary public
    # *.gradio.live URL (reported as expiring after 72 hours); anyone with the
    # link can query the model while it is up.
    share=True,
)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://279ab9b621df636972.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




In [22]:
# Shut down the Gradio server started by `iface.launch`.
# NOTE(review): this cell's execution count ([22]) precedes the Interface/launch
# cells ([23], [24]), so it closed an earlier `iface`; re-run the notebook in order.
iface.close()

Closing server running on port: 7860
