In [46]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaForCausalLM
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split

import json
import pandas as pd
import ast
from bs4 import BeautifulSoup
import os, contextlib, io
from tqdm.notebook import tqdm
import csv


In [2]:
dataset = pd.read_csv('datasets/staging_test_set.csv')

dataset['EDAM Topics'] = dataset['EDAM Topics'].apply(ast.literal_eval)

dataset['Abstract'] = dataset['Abstract'].apply(lambda x: BeautifulSoup(x, 'html.parser').get_text())

In [3]:
with open('templates/open_source_template.txt', 'r') as template_file:
    template = template_file.read()

In [4]:
with open('EDAM/edam_topics.txt', 'r') as edam_file:
    full_edam_topics = edam_file.readlines()

full_edam_topics = [topic.strip() for topic in full_edam_topics]

In [5]:
# Add EDAM topics to prompt template

formatted_topics = "\n".join(full_edam_topics)
template = template.replace("<topics>", formatted_topics)


In [6]:
with open("config.json", "r") as config_file:
    config = json.load(config_file)

In [7]:
# API key from HuggingFace
hf_token = config['api_keys']['huggingface']

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
model_storage_path = '/nvme/models'

## Testing Open Source Models

No need to run every cell, just the cells regarding the model you want to load. 

### Meditron 7b

In [None]:
# Check if model is already downloaded

folder_path = f"{model_storage_path}/meditron-7b-model"

if os.path.exists(folder_path):
    print("The model is already downloaded. Loading from", folder_path)
    tokenizer = AutoTokenizer.from_pretrained(f"{model_storage_path}/meditron-7b-tokenizer")
    model = AutoModelForCausalLM.from_pretrained(f"{model_storage_path}/meditron-7b-model", device_map='auto')
else:
    tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b", token=hf_token)
    tokenizer.save_pretrained(f"{model_storage_path}/meditron-7b-tokenizer", from_pt=True)

    model = AutoModelForCausalLM.from_pretrained("epfl-llm/meditron-7b", from_tf=True, device_map='auto')
    model.save_pretrained(f"{model_storage_path}/meditron-7b-model")

### Meditron 70b

In [None]:
folder_path = f"{model_storage_path}/meditron-70b-model"

if os.path.exists(folder_path):
    print("The model is already downloaded. Loading from", folder_path)
    tokenizer = AutoTokenizer.from_pretrained(f"{model_storage_path}/meditron-70b-tokenizer")
    model = AutoModelForCausalLM.from_pretrained(f"{model_storage_path}/meditron-70b-model", device_map='auto', load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
else:
    tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-70b", token=hf_token)
    tokenizer.save_pretrained(f"{model_storage_path}/meditron-70b-tokenizer", from_pt=True)

    model = AutoModelForCausalLM.from_pretrained("epfl-llm/meditron-70b", from_tf=True, device_map='auto', load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    model.save_pretrained(f"{model_storage_path}/meditron-70b-model")

### Mixtral 8x7b

In [8]:
folder_path = f"{model_storage_path}/mixtral-8x7b-model"

if os.path.exists(folder_path):
    print("The model is already downloaded. Loading from", folder_path)
    tokenizer = AutoTokenizer.from_pretrained(f"{model_storage_path}/mixtral-8x7b-tokenizer")
    model = AutoModelForCausalLM.from_pretrained(f"{model_storage_path}/mixtral-8x7b-model", device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
else:
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
    tokenizer.save_pretrained(f"{model_storage_path}/mixtral-8x7b-tokenizer", from_pt=True)

    model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token, device_map="auto", load_in_4bit=True)
    model.save_pretrained(f"{model_storage_path}/mixtral-8x7b-model", from_pt=True)

The model is already downloaded. Loading from /nvme/models/mixtral-8x7b-model


Loading checkpoint shards:   0%|          | 0/39 [00:00<?, ?it/s]

### ClinicalGPT

In [None]:
folder_path = f"{model_storage_path}/clinicalgpt-model"

if os.path.exists(folder_path):
    print("The model is already downloaded. Loading from", folder_path)
    tokenizer = AutoTokenizer.from_pretrained(f"{model_storage_path}/clinicalgpt-tokenizer")
    model = AutoModelForCausalLM.from_pretrained(f"{model_storage_path}/clinicalgpt-model", device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
else:
    tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalGPT-base-zh", token=hf_token)
    tokenizer.save_pretrained(f"{model_storage_path}/clinicalgpt-tokenizer", from_pt=True)

    model = AutoModelForCausalLM.from_pretrained("medicalai/ClinicalGPT-base-zh", token=hf_token, device_map="auto", load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
    model.save_pretrained(f"{model_storage_path}/clinicalgpt-model", from_pt=True)

In [None]:
# Only compatible with some models. See https://huggingface.co/docs/transformers/perf_infer_gpu_one

# model = model.to_bettertransformer()

## Test Sample

Test the model on a single sample.

In [17]:
random_sample = dataset.sample(n=1)
random_sample

Unnamed: 0,PMID,Description,Abstract,MeSH Terms,Filtered MeSH Terms,EDAM Topics
1815,24711662,C1013G/CXCR4 variant has been inserted into BC...,The C-X-C chemokine receptor type 4 (CXCR4) pl...,"['*Drug Resistance, Neoplasm', '*Mutation, Mis...","['Animals', 'Cell Proliferation', 'Disease-Fre...","[Zoology, Human biology, DNA mutation, Genetic..."


In [18]:
prompt = template.replace('<abstract>', random_sample['Abstract'].values[0])

## Output 10 topics if there are less than 10 topics
# num_terms = len(random_sample['EDAM Topics'].values[0]) if len(random_sample['EDAM Topics'].values[0]) > 10 else 10

num_terms = len(random_sample['EDAM Topics'].values[0])
prompt = prompt.replace('<num_terms>', str(num_terms))

prompt


'An abstract associated with a scientific dataset is quoted below:\n\n"The C-X-C chemokine receptor type 4 (CXCR4) plays a crucial role in modulating cell trafficking in hematopoietic stem cells and clonal B cells. We screened 418 patients with B-cell lymphoproliferative disorders and described the presence of the C1013G/CXCR4 warts, hypogammaglobulinemia, infections, and myelokathexis-associated mutation in 28.2% (37/131) of patients with lymphoplasmacytic lymphoma (Waldenstrom macroglobulinemia [WM]), being either absent or present in only 7% of other B-cell lymphomas. In vivo functional characterization demonstrates its activating role in WM cells, as demonstrated by significant tumor proliferation and dissemination to extramedullary organs, leading to disease progression and decreased survival. The use of a monoclonal antibody anti-CXCR4 led to significant tumor reduction in a C1013G/CXCR4 WM model, whereas drug resistance was observed in mutated WM cells exposed to Bruton\'s tyros

In [19]:
model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
model_inputs

{'input_ids': tensor([[    1,  1094, 11576,  ...,   725,   585,    13]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')}

In [20]:
generated_ids = model.generate(**model_inputs, max_new_tokens=2500, do_sample=True, pad_token_id=tokenizer.eos_token_id)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [21]:
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
output

'An abstract associated with a scientific dataset is quoted below:\n\n"The C-X-C chemokine receptor type 4 (CXCR4) plays a crucial role in modulating cell trafficking in hematopoietic stem cells and clonal B cells. We screened 418 patients with B-cell lymphoproliferative disorders and described the presence of the C1013G/CXCR4 warts, hypogammaglobulinemia, infections, and myelokathexis-associated mutation in 28.2% (37/131) of patients with lymphoplasmacytic lymphoma (Waldenstrom macroglobulinemia [WM]), being either absent or present in only 7% of other B-cell lymphomas. In vivo functional characterization demonstrates its activating role in WM cells, as demonstrated by significant tumor proliferation and dissemination to extramedullary organs, leading to disease progression and decreased survival. The use of a monoclonal antibody anti-CXCR4 led to significant tumor reduction in a C1013G/CXCR4 WM model, whereas drug resistance was observed in mutated WM cells exposed to Bruton\'s tyros

In [22]:
parsed_output = output.split('\n')[-1]
parsed_output

'Allergy, clinical immunology and immunotherapeutics, Biochemistry, Biomedical science, Biomolecular simulation, Biophysics, Cancer biology, Genetics, Genomics, Molecular biology, Medical genetics.'

In [32]:
def test_subset(num_tests, seed=42):
    random_samples = dataset.sample(n=num_tests, random_state=seed)
    
    for _, random_sample in random_samples.iterrows():

        prompt = template.replace('<abstract>', random_sample['Abstract'])
        num_terms = len(random_sample['EDAM Topics'])
        prompt = prompt.replace('<num_terms>', str(num_terms))

        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=2500, do_sample=True, pad_token_id=tokenizer.eos_token_id)
        output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        parsed_output = output.split('\n')[-1]

        print("Abstract:", random_sample['Abstract'])
        print("Model Output:", parsed_output)
        print()

test_subset(10)

Abstract: Understanding gene regulation requires knowledge of changes in transcription factor (TF) activities. Simultaneous direct measurement of numerous TF activities is currently impossible. Nevertheless, statistical approaches to infer TF activities have yielded non-trivial and verifiable predictions for individual TFs. Here, global statistical modelling identifies changes in TF activities from transcript profiles of Escherichia coli growing in stable (fixed oxygen availabilities) and dynamic (changing oxygen availability) environments. A core oxygen-responsive TF network, supplemented by additional TFs acting under specific conditions, was identified. The activities of the cytoplasmic oxygen-responsive TF, FNR, and the membrane-bound terminal oxidases implied that, even on the scale of the bacterial cell, spatial effects significantly influence oxygen-sensing. Several transcripts exhibited asymmetrical patterns of abundance in aerobic to anaerobic and anaerobic to aerobic transiti

## Gather comparison data

Get results to compare with the other GPT models

In [44]:
test_data = pd.read_csv('outputs/raw_model_outputs.csv').iloc[:25][['Abstract', 'Ground Truth']]
test_data['Ground Truth'] = test_data['Ground Truth'].apply(ast.literal_eval)

In [45]:

def get_model_outputs(data, template, tokenizer, model, device):
    outputs = []
    for _, row in tqdm(data.iterrows(), total=data.shape[0]):
        abstract = row['Abstract']
        ground_truth = row['Ground Truth']

        prompt = template.replace('<abstract>', abstract)
        prompt = prompt.replace('<num_terms>', str(len(ground_truth)))

        model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=2500, do_sample=True, pad_token_id=tokenizer.eos_token_id)
        output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        parsed_output = output.split('\n')[-1]

        outputs.append(parsed_output)

    return outputs

model_out = get_model_outputs(test_data, template, tokenizer, model, device)

  0%|          | 0/25 [00:00<?, ?it/s]

In [59]:
test_data['Predictions'] = model_out
test_data['Predictions'] = test_data['Predictions'].apply(lambda x: set(map(str.strip, next(csv.reader([x])))))

test_data['Model'] = [model.name_or_path.split('/')[-1]] * len(test_data)
test_data = test_data[['Model', 'Abstract', 'Ground Truth', 'Predictions']]
test_data

Unnamed: 0,Model,Abstract,Ground Truth,Predictions
0,mixtral-8x7b-model,While microarray experiments generate volumino...,"{Zoology, Genetics, Computational biology}","{Biology, Bioinformatics, Genetics}"
1,mixtral-8x7b-model,Pea powdery mildew (PM) is an important fungal...,"{Proteins, Genetics, Drug metabolism, Gene reg...","{Genetics, Agricultural science, Genomics, Mol..."
2,mixtral-8x7b-model,Differentiation proceeds along a continuum of ...,"{Zoology, Sequence assembly, Genetics, Drug me...","{Cell biology, Genetics, Genomics, Gene regula..."
3,mixtral-8x7b-model,The annual migration of a bird can involve tho...,"{Zoology, Gene expression, Transcriptomics, Ge...","{Biology, Animal study, Biochemistry}"
4,mixtral-8x7b-model,"The utilization of methane, a potent greenhous...","{Proteins, Drug metabolism, Gene expression, P...","{Biosciences, Genetics, Biotechnology, Gene ex..."
5,mixtral-8x7b-model,We designed and constructed a genome-wide micr...,"{Zoology, Genetics, Drug metabolism, Medicinal...","{Genetics, Genomics, Molecular biology, Animal..."
6,mixtral-8x7b-model,Influenza A viruses (IAVs) quickly adapt to ne...,"{Zoology, Proteins, Genetics, Drug metabolism,...","{Computational chemistry, Translational resear..."
7,mixtral-8x7b-model,Activating mutations in JAK1 have been reporte...,"{Zoology, Immunoinformatics, DNA mutation, Gen...","{Neurology/central nervous system, Biotherapeu..."
8,mixtral-8x7b-model,DNA recombination is required for effective se...,"{Zoology, DNA replication and recombination, G...","{Bioinformatics, Genetics, Genomics}"
9,mixtral-8x7b-model,Newborns are frequently affected by mucocutane...,"{Genetics, Drug metabolism, Immunogenetics, Hu...","{""Data quality management"", ""Data management"",..."


In [61]:
# Add model outputs to raw_model_outputs.csv

raw_model_outputs = pd.read_csv('outputs/raw_model_outputs.csv')

concatenated_df = pd.concat([raw_model_outputs, test_data], axis=0)
concatenated_df[125:]

Unnamed: 0,Model,Abstract,Ground Truth,Predictions
0,mixtral-8x7b-model,While microarray experiments generate volumino...,"{Zoology, Genetics, Computational biology}","{Biology, Bioinformatics, Genetics}"
1,mixtral-8x7b-model,Pea powdery mildew (PM) is an important fungal...,"{Proteins, Genetics, Drug metabolism, Gene reg...","{Genetics, Agricultural science, Genomics, Mol..."
2,mixtral-8x7b-model,Differentiation proceeds along a continuum of ...,"{Zoology, Sequence assembly, Genetics, Drug me...","{Cell biology, Genetics, Genomics, Gene regula..."
3,mixtral-8x7b-model,The annual migration of a bird can involve tho...,"{Zoology, Gene expression, Transcriptomics, Ge...","{Biology, Animal study, Biochemistry}"
4,mixtral-8x7b-model,"The utilization of methane, a potent greenhous...","{Proteins, Drug metabolism, Gene expression, P...","{Biosciences, Genetics, Biotechnology, Gene ex..."
5,mixtral-8x7b-model,We designed and constructed a genome-wide micr...,"{Zoology, Genetics, Drug metabolism, Medicinal...","{Genetics, Genomics, Molecular biology, Animal..."
6,mixtral-8x7b-model,Influenza A viruses (IAVs) quickly adapt to ne...,"{Zoology, Proteins, Genetics, Drug metabolism,...","{Computational chemistry, Translational resear..."
7,mixtral-8x7b-model,Activating mutations in JAK1 have been reporte...,"{Zoology, Immunoinformatics, DNA mutation, Gen...","{Neurology/central nervous system, Biotherapeu..."
8,mixtral-8x7b-model,DNA recombination is required for effective se...,"{Zoology, DNA replication and recombination, G...","{Bioinformatics, Genetics, Genomics}"
9,mixtral-8x7b-model,Newborns are frequently affected by mucocutane...,"{Genetics, Drug metabolism, Immunogenetics, Hu...","{""Data quality management"", ""Data management"",..."


In [62]:
concatenated_df.to_csv('outputs/raw_model_outputs.csv', index=False)