## Virtual Screening with SmartBind
Here we take an RNA sequence as input and a list of SMILES strings as the ligand library. We will obtain a sorted 
list of ligands based on their binding scores predicted by SmartBind.

In [None]:
import sys
sys.path.append("..") 
import pandas as pd
from smartbind.preprocess import convert_smiles_to_pf2
from smartbind import load_smartbind_models
from smartbind import logger
from torch.nn.functional import cosine_similarity
import tqdm

#### Set up inputs

In [None]:
# --- Query -----------------------------------------------------------------
# RNA chain that every ligand in the library is scored against.
input_rna_chain = 'GACAGCUGCUGUC'

# --- Files -----------------------------------------------------------------
smiles_txt_path = 'ligand_library.txt'  # Example SMILES file
save_path = 'binding_score.txt'  # tab-separated scores are written here
ensembled_models_path = '../SMARTBind_weight'  # passed to the model loader

# --- Runtime ---------------------------------------------------------------
device = 'cpu'
batch_size = 10_000  # number of ligands sent to a model per inference call

#### Make prediction

In [None]:
# Read the ligand library, one SMILES string per line. Surrounding whitespace
# is stripped and blank lines are dropped so an empty string is never handed
# to the fingerprint converter (e.g. a stray empty line in the middle or at
# the end of the file).
with open(smiles_txt_path, 'r', encoding='utf-8') as f:
    stripped_lines = (line.strip() for line in f.read().splitlines())
    smiles_list = [smiles for smiles in stripped_lines if smiles]

# Convert every SMILES string with the project's converter; the result keeps
# the same order as smiles_list, so positions in both lists line up.
smol_fp2_list = [convert_smiles_to_pf2(smol_smiles) for smol_smiles in smiles_list]

In [None]:
# Log what is about to be loaded and from where, then load the pre-trained
# SmartBind models onto the configured device.
logger.info('Get SmartBind pre-trained model objects')
logger.info(f'Loading models from {ensembled_models_path}')

# NOTE(review): vs_mode=True presumably selects the virtual-screening
# configuration — confirm against load_smartbind_models.
smartbind_models = load_smartbind_models(model_path=ensembled_models_path, device=device, vs_mode=True)

In [None]:
# Score every ligand in the library against the query RNA, once per model.
# rank_result_by_models maps a model's position in smartbind_models to a list
# of cosine similarities, one per ligand, in the same order as smol_fp2_list.
rank_result_by_models = {}
num_ligands = len(smol_fp2_list)

# enumerate() supplies the model's position directly; looking it up with
# smartbind_models.index(model) on every batch is a repeated linear search.
for model_idx, model in enumerate(
        tqdm.tqdm(smartbind_models, desc='Predicting binding scores by models')):
    # The RNA embedding is computed once per model and reused for all batches.
    rna_embed = model.inference_single_rna(input_rna_chain)

    model_scores = []
    rank_result_by_models[model_idx] = model_scores

    # Stepping by batch_size visits each slice exactly once; the final slice
    # is simply shorter when the library size is not a multiple of batch_size.
    for start in tqdm.tqdm(range(0, num_ligands, batch_size), desc='Batching ligands'):
        ligand_embeds = model.inference_list_smols(smol_fp2_list[start:start + batch_size])
        similarities = cosine_similarity(rna_embed, ligand_embeds).tolist()
        model_scores.extend(similarities)

#### Save the binding scores

In [None]:
# Build a table with one row per ligand and one column per model, then add
# the per-ligand mean across models as the final binding score.
df = pd.DataFrame(rank_result_by_models)
df['average'] = df.mean(axis=1)
df.index = smiles_list

# Rank ligands from highest to lowest mean score so the saved file is a sorted
# list rather than a dump in library order. A stable sort keeps the library
# order among ligands with equal scores.
# NOTE(review): assumes a higher cosine similarity means stronger predicted
# binding — confirm against the SmartBind documentation.
ranked = df.sort_values('average', ascending=False, kind='stable')

with open(save_path, 'w', encoding='utf-8') as f:
    f.write('Ligand_ID\tBinding_Score\n')
    # Walk the index and the score column together instead of building a row
    # object with iloc for every ligand.
    f.writelines(
        f'{ligand}\t{score}\n'
        for ligand, score in zip(ranked.index, ranked['average'])
    )