In [1]:
import json
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import dataclasses
import numpy as np
from repeng import ControlVector, ControlModel, DatasetEntry
import os
import json
from tqdm import tqdm
import gc
import time
import jsonlines

In [3]:
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()

0

In [2]:
model_name = "BioMistral/BioMistral-7B-DARE-AWQ-QGS128-W4-GEMM"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
#  wrap it in a ControlModel for later.
model = ControlModel(model, list(range(-5, -18, -1)))

user_tag, asst_tag = "[INST]", "[/INST]"

You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set your model on a GPU device in order to run your model.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Lay Summary Simplicity

In [3]:
os.getcwd()

'/home/mingcong/scripts/control_vectors'

In [4]:
with open('../../../../data/colx531/biolaysumm2024_data/eLife_test.jsonl', "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f]

In [5]:
def baseline_model(doc):
    return doc["sections"]["Abstract"]

In [6]:
for item in data:
    sections = item["article"].split("\n")
    item["sections"] = dict(zip(item["headings"], sections))

abstracts = []
for item in tqdm(data, leave=True):
    rephrased_abstract = baseline_model(item)
    abstracts.append({"id": item["id"], "prediction": rephrased_abstract})

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 142/142 [00:00<00:00, 1381882.06it/s]


In [7]:
abstracts_labeled = ["[INST]Abstract:" + item['prediction'] + "\n Lay summary of abstract: [/INST]" for item in abstracts]

In [8]:
def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    dataset = []
    for suffix in suffix_list:
        for positive_persona, negative_persona in zip(positive_personas, negative_personas):
            positive_template = template.format(persona=positive_persona)
            negative_template = template.format(persona=negative_persona)
            dataset.append(
                DatasetEntry(
                    positive=f"{user_tag} {positive_template} {asst_tag} {suffix}",
                    negative=f"{user_tag} {negative_template} {asst_tag} {suffix}",
                )
            )
    return dataset

In [9]:
simple_dataset = make_dataset(
    "{persona}.",
    ["I have an abstract from a bio-medical research paper that I would like to make more understandable for a wider audience, including those without a scientific background. Please convert the technical language into simpler terms, explain any complex concepts in a way that a layperson could understand, and provide additional background information where necessary to help clarify the relevance and significance of the research."],
    ["I have a simplified summary of findings from a bio-medical research paper that I need to be rewritten for a professional audience with a high level of expertise in this field. Please enhance the language to match the sophistication expected in a scholarly article, incorporate appropriate technical jargon."],
    abstracts_labeled,
)
model.reset()
simple_vector = ControlVector.train(model, tokenizer, simple_dataset)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:10<00:00,  7.78s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:01<00:00, 26.62it/s]


In [11]:
# simple_vector.export_gguf('simple_vector.gguf')

gguf: This GGUF file is for Little Endian only


In [11]:
# # save
# np.save("simple_vector.npy", dataclasses.asdict(simple_vector))

In [None]:
# # later...
# v = ControlVector(**np.load("vector.npy", allow_pickle=True).tolist())

In [10]:
def generate_with_vector_pred(
    input: str,
    vector: ControlVector,
    coeffs: float = -1,
    max_new_tokens: int = 128,
    repetition_penalty: float = 1.1,
    show_baseline: bool = True,
):

    # start_time = time.time()
    
    if user_tag not in input:
        input = f"{user_tag} {input.strip()} {asst_tag}"

    input_ids = tokenizer(input, return_tensors="pt").to(model.device)    
    model.set_control(vector, coeffs)

    settings = {
        "pad_token_id": tokenizer.eos_token_id, # silence warning
        "do_sample": True, # temperature=0
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k":50,
        "top_p":0.95
    }
    outputs = model.generate(**input_ids, **settings)    
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prediction = generated_text[len(tokenizer.decode(input_ids['input_ids'][0], skip_special_tokens=True)):].strip()
    
    # end_time = time.time()
    # print(f"执行时间：{end_time - start_time:.4f}秒。")
    return prediction

In [11]:
data_folder = '/data/colx531/biolaysumm2024_data/'
file_names = ['PLOS_test.jsonl', 'eLife_test.jsonl']

In [12]:
def check_name_in_jsonl(file_path, name_to_check):
    with open(file_path, 'r') as file:
        for line in file:
            json_object = json.loads(line.strip())
            if json_object.get('id') == name_to_check:
                return True
    return False

In [13]:
for file_name in file_names:
    with open(data_folder+file_name, 'r') as f:
        data = [json.loads(line) for line in f]

    for item in data:
        sections = item['article'].split('\n')
        item['sections'] = {k: v for k, v in zip(item['headings'], sections)}

    predictions = []

    num_skip = 0
    for item in tqdm(data, leave=True):
        try:
            is_name_present = check_name_in_jsonl(f'prediction_baseline2_BioMistral7B_{file_name.split(".")[0]}.jsonl', item['id'])
            if is_name_present==True:
                num_skip+=1
                print(f'skip No. {num_skip}: {item["id"]}')
                continue
        except:
            print('first sample')
        # Extract abstracts
        abstract = item['sections']['Abstract']
        prompt = f"""
            Rephrase the following abstract from a medical paper to make it more accessible and understandable to non-expert audiences, commonly referred to as "lay summaries".
            [INST]Abstract: {abstract}
            Lay summary of abstract: [/INST]"""
        # Rephrase abstracts
        rephrased_abstract = generate_with_vector_pred(prompt, simple_vector, -1, max_new_tokens=2048)
        one_prediction = [{'id': item['id'], 'prediction': rephrased_abstract}]
        
        with jsonlines.open(f'prediction_baseline2_BioMistral7B_{file_name.split(".")[0]}.jsonl', mode='a') as writer:
            writer.write_all(one_prediction)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 142/142 [00:00<00:00, 3567.93it/s]


skip journal.ppat.1009789
skip journal.pntd.0007992
skip journal.pntd.0010704
skip journal.ppat.1010691
skip journal.pgen.1009815
skip journal.pcbi.1010461
skip journal.pcbi.1009317
skip journal.pcbi.1010272
skip journal.pntd.0008105
skip journal.pcbi.1009544
skip journal.ppat.1008577
skip journal.pcbi.1008769
skip journal.ppat.1008759
skip journal.pgen.1010047
skip journal.ppat.1009523
skip journal.pcbi.1010007
skip journal.pntd.0009323
skip journal.ppat.1009313
skip journal.ppat.1009632
skip journal.pcbi.1008614
skip journal.pntd.0008731
skip journal.ppat.1008665
skip journal.ppat.1010544
skip journal.ppat.1009433
skip journal.pgen.1010029
skip journal.pntd.0009866
skip journal.pntd.0008240
skip journal.pntd.0010527
skip journal.ppat.1008447
skip journal.pgen.1009747
skip journal.pntd.0007619
skip journal.pntd.0009208
skip journal.ppat.1008945
skip journal.ppat.1009452
skip journal.pcbi.1010213
skip journal.pcbi.1009995
skip journal.pcbi.1008514
skip journal.pntd.0008847
skip journal

  0%|                                                                                                                                                              | 0/142 [00:00<?, ?it/s]

skip elife-81547-v1
skip elife-86176-v2
skip elife-82210-v1
skip elife-83152-v2
skip elife-83044-v2
skip elife-81828-v2
skip elife-85104-v2
skip elife-84315-v2
skip elife-77577-v2
skip elife-81641-v2
skip elife-83291-v2
skip elife-77514-v3
skip elife-79939-v2
skip elife-77699-v2
skip elife-79208-v2
skip elife-75340-v2
skip elife-81678-v2
skip elife-78703-v2
skip elife-76870-v2
skip elife-79002-v1
skip elife-85182-v2
skip elife-77877-v1
skip elife-90509-v1
skip elife-80984-v2
skip elife-78941-v2
skip elife-78917-v1
skip elife-78633-v2
skip elife-80813-v1
skip elife-79363-v2


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 142/142 [4:55:40<00:00, 124.94s/it]
