In [51]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, RobertaTokenizerFast, RobertaForSequenceClassification, pipeline
import json
import numpy as np
import uuid
import os

In [47]:
# Load the two COVID drug-generation training task files from the BioT5 data
# folder: task201 is read as the mol2text set, task202 as the text2mol set.
ttd_data_folder = "../BioT5/data/"
text2mol_train_set, mol2text_train_set = {}, {}

mol2text_train_path = ttd_data_folder + "tasks/task201_COVID_drug_generation_train.json"
text2mol_train_path = ttd_data_folder + "tasks/task202_COVID_drug_generation_train.json"

with open(mol2text_train_path) as fm2t, open(text2mol_train_path) as ft2m:
    text2mol_train_set = json.load(ft2m)
    mol2text_train_set = json.load(fm2t)

# Number of instances in the text2mol training set.
train_size = len(text2mol_train_set["Instances"])

# Drop the first and last five characters of every mol2text input
# (presumably the "<bom>" / "<eom>" wrappers -- confirm against the task files).
train_mols = [instance["input"][5:-5] for instance in mol2text_train_set["Instances"]]

In [48]:
# Text -> molecule generation.
# Prompts the fine-tuned BioT5 text2mol checkpoint with a fixed description and
# collects `generation_size` generated molecules as SELFIES (`mol_selfies`) and
# as SMILES (`mol_smiles`), counting how many of them already occur in
# `train_mols` (`dupes`).
import selfies as sf  # imported once here instead of on every loop iteration

generation_tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol", model_max_length=512)
generation_model = T5ForConditionalGeneration.from_pretrained('../models/biot5-base-text2mol-finetuned', use_safetensors=True)

# pred_tokenizer = RobertaTokenizerFast.from_pretrained("../model/chemberta")
# pred_model = RobertaForSequenceClassification.from_pretrained("../model/chemberta")
# classifier = pipeline("text-classification", model = "../models/chemberta")

task_definition = 'Definition: You are given a molecule description in English. Your job is to generate the molecule SELFIES that fits the description.\n\n'
text_input = 'The molecule is a COVID-19 drug candidate.'
task_input = f'Now complete the following example -\nInput: {text_input}\nOutput: '

# The prompt is identical for every generation, so it is tokenized once.
model_input = task_definition + task_input
input_ids = generation_tokenizer(model_input, return_tensors="pt").input_ids

generation_config = generation_model.generation_config
generation_config.max_length = 512
generation_config.num_beams = 1
# NOTE(review): nothing in this cell enables sampling. If the checkpoint's saved
# generation_config does not set do_sample=True, generate() is presumably greedy
# and every iteration below would return the same molecule -- confirm.
generation_size = train_size * 4 #x5 original training set in total

mol_selfies = []
mol_smiles = []
dupes = 0
# Set built once so the per-molecule duplicate check is a constant-time lookup.
known_train_mols = set(train_mols)

for i in range(generation_size):
    outputs = generation_model.generate(input_ids, generation_config=generation_config)
    # Decode the first returned sequence and strip the spaces between tokens.
    output_selfies = generation_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(' ', '')
    # print(output_selfies)
    mol_selfies.append(output_selfies)
    output_smiles = sf.decoder(output_selfies)
    mol_smiles.append(output_smiles)
    # print(output_smiles)
    # Compare the molecule just generated (not the whole accumulator list)
    # against the training molecules; the previous list-in-list test could
    # never be true, so `dupes` always stayed at 0.
    if output_selfies in known_train_mols:
        dupes += 1

# mol_activity = [1 if result["label"] == 'LABEL_1' else 0 for result in classifier(mol_smiles)]
# predicted_activity_rate = np.average(mol_activity)
# print(predicted_activity_rate)
# print(dupes/generation_size)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


1.0
0.0


In [53]:
# Molecule -> text generation.
# Rebinds the generation tokenizer/model to the mol2text checkpoint and
# produces one English description per entry of `mol_selfies`.
generation_tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-mol2text", model_max_length=512)
generation_model = T5ForConditionalGeneration.from_pretrained('../models/biot5-base-mol2text-finetuned', use_safetensors=True)

generation_config = generation_model.generation_config
generation_config.max_length = 512
generation_config.num_beams = 1

task_definition = 'Definition: You are given a molecule SELFIES. Your job is to generate the molecule description in English that fits the molecule SELFIES.\n\n'


def _describe_molecule(selfies_string):
    """Return the model's decoded description for one SELFIES string.

    The SELFIES string is wrapped in <bom>/<eom>, appended to the task
    definition, tokenized, passed to `generate`, and the first returned
    sequence is decoded with special tokens skipped.
    """
    prompt = task_definition + f'Now complete the following example -\nInput: <bom>{selfies_string}<eom>\nOutput: '
    encoded_prompt = generation_tokenizer(prompt, return_tensors="pt").input_ids
    generated = generation_model.generate(encoded_prompt, generation_config=generation_config)
    return generation_tokenizer.decode(generated[0], skip_special_tokens=True)


# Descriptions are index-aligned with `mol_selfies`.
mol_descs = [_describe_molecule(selfies_string) for selfies_string in mol_selfies]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [54]:
# Write the augmented training task files and the split listings.
# The generated (SELFIES, description) pairs are combined with the loaded
# training instances into NEW dicts, so the loaded training sets are left
# untouched and re-running this cell does not append the generated rows a
# second time.
mol2text_task_id = 201 #arbitrary id number, biot5's task ids go up to 180. 
text2mol_task_id = 202

data_root = "../BioT5/data"
splits_root = f"{data_root}/splits/covid"

os.makedirs(f"{splits_root}/mol2text_augmented/", exist_ok=True)
os.makedirs(f"{splits_root}/text2mol_augmented/", exist_ok=True)

mol2text_generated = []
text2mol_generated = []
for i, description in enumerate(mol_descs):
    # mol_descs[i] is the description generated for mol_selfies[i].
    tagged_selfies = f"<bom>{mol_selfies[i]}<eom>"
    mol2text_generated.append({
        "id": f"task{mol2text_task_id}-{uuid.uuid4().hex}",
        "input": tagged_selfies,
        "output": [description]
    })
    text2mol_generated.append({
        "id": f"task{text2mol_task_id}-{uuid.uuid4().hex}",
        "input": description,
        "output": [tagged_selfies]
    })

# Shallow copies with a fresh "Instances" list: original rows first, generated rows after.
mol2text_augmented_set = {
    **mol2text_train_set,
    "Instances": mol2text_train_set["Instances"] + mol2text_generated,
}
text2mol_augmented_set = {
    **text2mol_train_set,
    "Instances": text2mol_train_set["Instances"] + text2mol_generated,
}

with open(f'{data_root}/tasks/task{mol2text_task_id}_COVID_drug_generation_train_augmented.json', "w") as out:
    json.dump(mol2text_augmented_set, out, indent=4)
with open(f'{data_root}/tasks/task{text2mol_task_id}_COVID_drug_generation_train_augmented.json', "w") as out:
    json.dump(text2mol_augmented_set, out, indent=4)

# Each "<split>_tasks.txt" file holds the single task name used for that split;
# the training split's name matches the augmented JSON file written above.
for dset in ["train_augmented", "validation", "test"]:
    with open(f'{splits_root}/mol2text_augmented/{dset}_tasks.txt', "w") as out:
        out.write(f'task{mol2text_task_id}_COVID_drug_generation_{dset}')
    with open(f'{splits_root}/text2mol_augmented/{dset}_tasks.txt', "w") as out:
        out.write(f'task{text2mol_task_id}_COVID_drug_generation_{dset}')