# Install Libraries

In [None]:
pip install -U -q requests

In [None]:
pip install -U -q bitsandbytes

In [None]:
pip install -U -q git+https://github.com/huggingface/transformers.git

In [None]:
pip install -U -q git+https://github.com/huggingface/peft.git

In [None]:
pip install -U -q git+https://github.com/huggingface/accelerate.git

In [None]:
pip install -U -q datasets

In [None]:
pip install -U -q scipy

In [None]:
pip install -U -q ipywidgets

In [None]:
pip install -U -q matplotlib

In [None]:
pip install -q newsapi-python

# Load Disease Table

In [None]:
import json
# Load the disease lookup table; later cells index it by disease name to get
# the associated symptom text.
# The encoding is pinned because JSON is UTF-8 by specification, while the
# default of open() depends on the platform locale.
with open('diseases.json', 'r', encoding='utf-8') as f:
    disease_table = json.load(f)

# Query API

In [None]:
import os

import requests

# NOTE(review): this key was committed in plain text and should be treated as
# leaked -- rotate it and provide the replacement through the NEWSAPI_KEY
# environment variable. The literal is kept only as a fallback so the notebook
# keeps running unchanged.
NEWSAPI_KEY = os.environ.get("NEWSAPI_KEY", "38a5c18e23b04ed387b8c60f83bf0b37")
# requests applies no timeout by default, so a stalled connection would hang the cell.
NEWSAPI_TIMEOUT = 30

filter_key = ['whooping', 'Fever', 'mpox', 'coli']


def _fetch_articles(endpoint, **params):
    """Return the "articles" list of a NewsAPI v2 endpoint, or [] on an HTTP error.

    Query parameters are passed through ``params`` so that requests
    URL-encodes them instead of interpolating raw text into the URL.
    """
    response = requests.get(
        f"https://newsapi.org/v2/{endpoint}",
        params={**params, "apiKey": NEWSAPI_KEY},
        timeout=NEWSAPI_TIMEOUT,
    )
    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return []
    return response.json().get("articles", [])


# `articles` is always defined, so a failed top-headlines request no longer
# causes a NameError when the keyword searches try to extend it.
articles = []
articles.extend(
    _fetch_articles("top-headlines", country="us", category="health", pageSize=100)
)

# One keyword search per term, appended to the same list.
for keyword in filter_key:
    articles.extend(_fetch_articles("everything", q=keyword, language="en"))

# NER for disease/location keywords

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# General-purpose NER model; only its B-LOC tags are consumed further down.
# NOTE: `tokenizer` and `model` are rebound later in the notebook when the
# Mistral base model is loaded, so this cell must be re-run before the NER
# cells are used again.
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

loc_ner = pipeline("ner", model=model, tokenizer=tokenizer)
# Medical NER model; its DISEASE_DISORDER / DIAGNOSTIC_PROCEDURE tags are
# consumed by extract_target_entity.
med_ner = pipeline("token-classification", model="blaze999/Medical-NER")
# Smoke test: run both pipelines on one sentence and show the raw token-level output.
example = "Long COVID can strike anyone, rich or poor, but a growing body of evidence suggests poor and middle income Americans like Rick Henline suffer most."
loc_results = loc_ner(example)
med_results = med_ner(example)
print(loc_results)
print(med_results)

In [None]:
# Function to reconstruct entities
import re
def extract_target_entity(entities, min_score=0.1):
    """Rebuild disease / diagnostic-procedure phrases from token-level NER output.

    Args:
        entities: Iterable of token dicts with the keys 'entity' (BIO tag),
            'word' (token text; a '▁' marks the start of a new word) and
            'score' (model confidence).
        min_score: Tokens scoring below this value are ignored; they neither
            extend nor terminate the phrase being built.

    Returns:
        The reconstructed phrases, in order of appearance.
    """
    target_labels = ("DISEASE_DISORDER", "DIAGNOSTIC_PROCEDURE")
    begin_tags = {"B-" + label for label in target_labels}
    inside_tags = {"I-" + label for label in target_labels}

    phrases = []
    current = ""
    for entity in entities:
        if entity['score'] < min_score:
            continue
        tag = entity['entity']
        word = entity['word']
        piece = word.lstrip('▁')
        if tag in begin_tags:
            # A new phrase starts: flush the one in progress first. This now
            # also applies to DIAGNOSTIC_PROCEDURE, which used to overwrite
            # (and therefore lose) the pending phrase.
            if current:
                phrases.append(current)
            current = piece
        elif tag in inside_tags:
            # '▁' marks a word boundary; tokens without it are sub-word
            # continuations and are glued on directly. No separator is added
            # when nothing has been collected yet, so an orphan I- token does
            # not produce a phrase with a leading space.
            if current and '▁' in word:
                current += " " + piece
            else:
                current += piece
        else:
            # Any other tag ends the phrase in progress.
            if current:
                phrases.append(current)
            current = ""
    if current:
        phrases.append(current)
    return phrases

# Rebuild the disease phrases found in the example sentence and show them.
diseases = extract_target_entity(med_results)
print(diseases)

In [None]:
def lookup_disease_symptoms(disease, disease_table):
    """Return the disease_table key that an extracted mention refers to.

    Each key is broken at its parentheses into alternative names (for
    example "Name (Alias)" gives "Name" and "Alias"). A key matches when any
    of its alternatives occurs, case-insensitively, as a substring of
    `disease`.

    Despite the function name, the matching *key* is returned, not the
    symptom text; callers index disease_table with it.

    Returns:
        The first matching key in table order, or None when nothing matches.
    """
    for table_key in disease_table.keys():
        for raw_part in re.split(r'\s*\(|\)\s*', table_key):
            alias = raw_part.strip()
            # Splitting leaves empty fragments (e.g. after a trailing ")").
            if not alias:
                continue
            if alias.lower() in disease.lower():
                return table_key
    return None
# Demo lookup for the first extracted disease; guarded because the NER step
# may return no disease at all, in which case diseases[0] would raise IndexError.
name = lookup_disease_symptoms(diseases[0], disease_table) if diseases else None

In [None]:
from IPython.display import display, HTML
def get_entities(location_words, diseases):
    """Build the highlight/metadata entries for one piece of text.

    For every extracted disease that can be matched against the module-level
    `disease_table`, one disease entry is emitted, followed by one location
    entry per location word (or a single default "US" entry when no location
    was detected).

    Args:
        location_words: Location tokens found in the text.
        diseases: Disease phrases extracted from the same text.

    Returns:
        A list of dicts. Disease entries carry "Disease", "color", "Symptom"
        and "Name"; location entries carry "Location" and "color". The list
        is empty when no disease matches the table.
    """
    entities = []
    for disease in diseases:
        name = lookup_disease_symptoms(disease, disease_table)
        if name is None:
            continue
        # A fresh dict per entry: reusing one dict and mutating it made every
        # appended entry alias the same object, so all of them ended up
        # describing the last disease / location processed.
        entities.append({
            "Disease": disease,
            "color": "lightblue",
            "Symptom": disease_table[name],
            "Name": name,
        })
        # Default location to US
        locations = location_words if len(location_words) > 0 else ["US"]
        for location in locations:
            entities.append({"Location": location, "color": "lightgreen"})
    return entities

def _render_highlights(text, entities):
    """Return `text` as HTML with every entity word wrapped in a colored span."""
    import html
    import re

    # First color wins when the same word appears in several entries.
    colors = {}
    for entity in entities:
        word = entity['Disease'] if 'Disease' in entity else entity['Location']
        if word:  # an empty word would otherwise match between every character
            colors.setdefault(word, entity['color'])
    if not colors:
        return html.escape(text)

    # One pass over the original text: longer words are tried first, already
    # inserted markup is never re-scanned (no nested spans, no accidental
    # rewriting of "span"/"color"), and everything is HTML-escaped because the
    # text comes from external news sources.
    pattern = re.compile(
        "|".join(re.escape(word) for word in sorted(colors, key=len, reverse=True))
    )
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        word = match.group(0)
        pieces.append(html.escape(text[cursor:match.start()]))
        pieces.append(
            f"<span style='background-color: {colors[word]}; padding: 2px;'>{html.escape(word)}</span>"
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


# Function to highlight entities
def highlight_entities(text, entities):
    """Display `text` in the notebook with disease and location words highlighted.

    Args:
        text: The raw text to show.
        entities: Entries as produced by get_entities; each has either a
            "Disease" or a "Location" key plus a "color".

    Returns:
        None. The rendered HTML is shown through IPython's display().
    """
    display(HTML(_render_highlights(text, entities)))
# Demo on the example sentence: keep only tokens tagged as the beginning of a
# location, build the entity entries and render the highlighted text.
location_words = [entity['word'] for entity in loc_results if entity['entity'] == 'B-LOC']
entities = get_entities(location_words, diseases)
highlight_entities(example, entities)

In [None]:
# Run both NER pipelines over every fetched article and collect:
#   entities_lst     - one entity list per title / description that yielded a known disease
#   headlines_lst    - titles of the articles whose description also yielded a disease
#   descriptions_lst - the corresponding descriptions
#   diseases_lst     - the "Disease" of the first entity found in each such description
headlines_lst = []
descriptions_lst = []
diseases_lst = []
entities_lst = []
for a in articles:
    ### parsing from title
    # Skip articles without a usable title.
    if (a['title'] == "[Removed]") | (a['title'] is None):
        continue
    loc_results = loc_ner(a['title'])
    med_results = med_ner(a['title'])
    diseases=extract_target_entity(med_results)
    location_words = [entity['word'] for entity in loc_results if entity['entity'] == 'B-LOC']
    entities = get_entities(location_words, diseases)
    # No known disease in the title: the article is dropped entirely, its
    # description is not inspected.
    if len(entities) == 0:
        continue
    # Article metadata is attached to the first entity only; find_similiar_disease
    # reads it back from entry[0].
    entities[0]['title'] = a['title']
    entities[0]['PublishAt'] = a['publishedAt']
    entities[0]['url'] = a['url']
    highlight_entities(a['title'], entities)
    entities_lst.append(entities)
    ### parsing from description
    if (a['description'] == "[Removed]") | (a['description'] is None):
        continue
    loc_results = loc_ner(a['description'])
    med_results = med_ner(a['description'])
    diseases=extract_target_entity(med_results)
    location_words = [entity['word'] for entity in loc_results if entity['entity'] == 'B-LOC']
    entities = get_entities(location_words, diseases)
    if len(entities) == 0:
        continue
    highlight_entities(a['description'], entities)
    # NOTE: description-derived entries get no 'title' / 'PublishAt' / 'url' keys.
    entities_lst.append(entities)
    # The three training lists are appended together, so they stay the same length.
    headlines_lst.append(a["title"])
    descriptions_lst.append(a["description"])
    diseases_lst.append(entities[0]["Disease"])

# Dataset evaluation
## Headlines

In [None]:
from datasets import Dataset
import numpy as np

# Assemble the headline -> disease dataset and split it 80/20 into train/eval.
n_datapoints = len(descriptions_lst)

# The same instruction is attached to every row.
headline_instruction = "You are a doctor trying to determine what disease outbreak may be occuring based on a news headline. You will be given a headline, and you have to say your prediction of what the disease is"
instructions_lst_np = np.array([headline_instruction] * n_datapoints)
# Despite the variable name, the model input for this dataset is the headline text.
descriptions_lst_np = np.array(headlines_lst)
diseases_lst_np = np.array(diseases_lst)

# A single shared permutation keeps the three columns aligned row by row.
permutator = np.random.permutation(n_datapoints)
shuffled_instructions, shuffled_descriptions, shuffled_diseases = (
    column[permutator]
    for column in (instructions_lst_np, descriptions_lst_np, diseases_lst_np)
)

# First 80% of the shuffled rows train, the remainder evaluates.
split_at = int(n_datapoints * 0.8)

train_instructions = shuffled_instructions[:split_at]
train_descriptions = shuffled_descriptions[:split_at]
train_diseases = shuffled_diseases[:split_at]

eval_instructions = shuffled_instructions[split_at:]
eval_descriptions = shuffled_descriptions[split_at:]
eval_diseases = shuffled_diseases[split_at:]

train_dataset_dict = {
    "instruction": train_instructions,
    "input": train_descriptions,
    "output": train_diseases
}
eval_dataset_dict = {
    "instruction": eval_instructions,
    "input": eval_descriptions,
    "output": eval_diseases
}

train_headline_dataset = Dataset.from_dict(train_dataset_dict)
eval_headline_dataset = Dataset.from_dict(eval_dataset_dict)

# show the current diseases on headlines

In [None]:
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Vectorize the symptom text of the diseases seen in the news (TF-IDF) and
# plot a 2-D PCA projection of up to ten of them.
data=entities_lst
#new_symptom = 'running nose, fever and cough'
# NOTE: this rebinds `diseases_lst` (previously the per-article training
# labels) to the sorted unique table names; find_similiar_disease indexes
# this array with row indices of `tfidf_matrix`.
diseases_lst = [entry[0]['Name'] for entry in data]
diseases_lst = np.unique(diseases_lst)
# Only the first ten diseases are vectorized, so `tfidf_matrix` has at most ten rows.
symptoms_lst = [disease_table[disease] for disease in diseases_lst[:10]]
#symptoms_lst = np.append(symptoms_lst, new_symptom)

# `vectorizer` and `tfidf_matrix` are reused by find_similiar_disease.
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(symptoms_lst)

# NOTE(review): PCA with n_components=2 presumably needs at least two
# diseases here -- confirm behavior when fewer were found.
pca_transformer = PCA(n_components=2)
reduced_coords = pca_transformer.fit_transform(tfidf_matrix.toarray())

plt.figure(figsize=(4, 4))
# One scatter call per disease so that each gets its own legend entry.
for i, disease in enumerate(diseases_lst[:10]):
    plt.scatter(reduced_coords[i, 0], reduced_coords[i, 1], label=disease, alpha=0.7)
#new_symptom_coords = reduced_coords[-1]  # Get the last coordinate for the new symptom
#plt.scatter(new_symptom_coords[0], new_symptom_coords[1], color='red', label=new_symptom, alpha=0.7)

plt.legend(
    loc='center left',
    bbox_to_anchor=(1, 0.5),  # Anchor legend box outside the plot
    title="Diseases",
    fontsize=9,
    title_fontsize=10
)

# NOTE(review): the plot shows PCA coordinates of the TF-IDF vectors, not
# cosine similarities; the title may be misleading.
plt.title("Cosine Similarity of Diseases based on Symptoms")
plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.grid(True)
plt.show()

# Find closest diseases on the headlines

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
def find_similiar_disease(query, threshold=0.5):
    """Match a free-text symptom description against the diseases found in the news.

    The query is vectorized with the module-level TF-IDF `vectorizer` and
    compared, by cosine similarity, with the rows of `tfidf_matrix`.

    Args:
        query: Free-text description (for example a patient's message).
        threshold: Minimum cosine similarity required to report a match.

    Returns:
        A text snippet naming the closest disease together with the title,
        publication time and URL of a related news article, or None when no
        disease reaches `threshold`.
    """
    query_vector = vectorizer.transform([query])
    similarities = cosine_similarity(query_vector, tfidf_matrix).flatten()
    if similarities.size == 0:
        return None

    closest_disease_index = similarities.argmax()
    if similarities[closest_disease_index] < threshold:
        return None
    closest_disease = diseases_lst[closest_disease_index]

    # Look up the article metadata for that disease. Only title-derived
    # entries carry 'title' / 'PublishAt' / 'url'; entries built from
    # descriptions have none, so stopping at the first name match could
    # produce a reference full of "None". Entries without a URL are skipped.
    url = None
    title = None
    publish_time = None
    for entry in entities_lst:
        lead = entry[0]
        if lead.get('Name') == closest_disease and lead.get('url') is not None:
            title = lead.get('title')
            publish_time = lead.get('PublishAt')
            url = lead.get('url')
            break

    return f"\n\nYou are likely to have {closest_disease}.\nFrom recent news: {title}, published at: {publish_time}\nMore information: {url}"


## Integrated conversations from iCliniq

In [None]:
# Upper bound on the number of iCliniq conversations to use.
# NOTE: this rebinds the `n_datapoints` global used by the headline dataset cell.
n_datapoints = 100
# Read the JSON file (encoding pinned: JSON is UTF-8, the locale default may not be)
with open('iCliniq.json', 'r', encoding='utf-8') as file:
    icliniq = json.load(file)
icliniq_input = []
icliniq_output = []
# Each reference answer is the iCliniq answer, followed by the news-based
# disease hint when the patient's text is close enough to a disease in the news.
for entry in icliniq[:n_datapoints]:
    icliniq_input.append(entry['input'])
    res = find_similiar_disease(entry['input'])
    if res is not None:
        res = entry['answer_icliniq']+res
    else:
        res = entry['answer_icliniq']
    icliniq_output.append(res)
# One instruction per row actually loaded: the file may hold fewer than
# `n_datapoints` entries, and Dataset.from_dict needs equally long columns.
icliniq_instructions_lst_np = np.array(["If you are a doctor, please answer the medical questions based on the patient's description."] * len(icliniq_input))
icliniq_descriptions_lst_np = np.array(icliniq_input)
icliniq_diseases_lst_np = np.array(icliniq_output)

icliniq_eval_dataset_dict = {
    "instruction": icliniq_instructions_lst_np,
    "input": icliniq_descriptions_lst_np,
    "output": icliniq_diseases_lst_np
}

#train_icliniq_dataset = Dataset.from_dict(icliniq_train_dataset_dict)
eval_icliniq_dataset = Dataset.from_dict(icliniq_eval_dataset_dict)

# Integration with ChatDoctor

In [None]:
import torch
# Release cached GPU memory left over from earlier cells, then pick the
# device used by the training and evaluation cells below.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from datasets import load_dataset, load_from_disk

# Load the pre-sampled patient-query dataset from a local directory and take
# its 'train' / 'test' splits.
dataset = load_from_disk("patients_query_sampled/")

train_conversation_dataset = dataset['train']
eval_conversation_dataset = dataset['test']

### 2. Load Base Model

Let's now load Mistral - mistralai/Mistral-7B-v0.1 - using 4-bit quantization!

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-v0.1"
# 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# NOTE: this rebinds `model`, which previously held the BERT NER model.
model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, resume_download=True)

### 3. Tokenization

Set up the tokenizer. Add padding on the left as it [makes training use less memory](https://ai.stackexchange.com/questions/41485/while-fine-tuning-a-decoder-only-llm-like-llama-on-chat-dataset-what-kind-of-pa).


For `model_max_length`, it's helpful to get a distribution of your data lengths. Let's first tokenize without the truncation/padding, so we can get a length distribution.

In [None]:
# Tokenizer for the base model: left padding, with BOS and EOS tokens added
# to every encoded sequence.
# NOTE: this rebinds `tokenizer`, which previously held the BERT NER tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
# The EOS token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Prompt builders, one pair per dataset: a *_formatting_func that renders a
# row as text and a *_generate_and_tokenize_prompt that tokenizes that text
# with the module-level `tokenizer`.

def headline_formatting_func(example):
    """Render a headline row (keys 'input' and 'output') as a training prompt."""
    return (
        "The following is a headline about a disease outbreak and what that disease is: \n"
        f"Headline: {example['input']} \n"
        f"Disease: {example['output']}"
    )
def headline_generate_and_tokenize_prompt(prompt):
    """Tokenize the headline prompt built from one dataset row."""
    rendered = headline_formatting_func(prompt)
    return tokenizer(rendered)

def conversation_formatting_func(example):
    """Render a patient/doctor conversation row as a training prompt."""
    return (
        "The following is a doctor's opinion on a person's query: \n"
        f"Patient query: {example['input']} \n"
        f"Doctor opinion: {example['output']}"
    )
def conversation_generate_and_tokenize_prompt(prompt):
    """Tokenize the conversation prompt built from one dataset row."""
    rendered = conversation_formatting_func(prompt)
    return tokenizer(rendered)

def integrate_formatting_func(example):
    """Render an iCliniq row whose answer may include a news-based outbreak hint."""
    return (
        "The following is doctor's suggestions with search of alert outbreak in the region: \n"
        f"Patient query: {example['input']} \n"
        f"Doctor opinion: {example['output']}"
    )
def integrate_generate_and_tokenize_prompt(prompt):
    """Tokenize the integrated prompt built from one dataset row."""
    rendered = integrate_formatting_func(prompt)
    return tokenizer(rendered)

Reformat the prompt and tokenize each sample:

In [None]:
# Tokenize every dataset with the prompt builder matching its format. Each
# prompt contains both the input and the reference output of the row.
tokenized_train_headline_dataset = train_headline_dataset.map(headline_generate_and_tokenize_prompt)
tokenized_eval_headline_dataset = eval_headline_dataset.map(headline_generate_and_tokenize_prompt)

tokenized_train_conversation_dataset = train_conversation_dataset.map(conversation_generate_and_tokenize_prompt)
tokenized_val_conversation_dataset = eval_conversation_dataset.map(conversation_generate_and_tokenize_prompt)

# Only an evaluation split exists for the iCliniq data.
#tokenized_train_icliniq_dataset = train_icliniq_dataset.map(integrate_generate_and_tokenize_prompt)
tokenized_val_icliniq_dataset = eval_icliniq_dataset.map(integrate_generate_and_tokenize_prompt)

In [None]:
from datasets import concatenate_datasets

# Combine conversation and headline datasets. The (smaller) headline set is
# repeated conversation_count // headline_count times, at least once, so that
# it is not drowned out by the conversation data.
n_headline_train_datapoints = len(tokenized_train_headline_dataset)
n_conversation_train_datapoints = len(tokenized_train_conversation_dataset)

# NOTE(review): the integer division raises ZeroDivisionError when the
# headline training set is empty (no article matched a known disease).
tokenized_train_overall_dataset = concatenate_datasets([tokenized_train_headline_dataset for _ in range(max(1, n_conversation_train_datapoints // n_headline_train_datapoints))] + [tokenized_train_conversation_dataset])
# NOTE(review): the validation headlines are repeated with the *training*
# ratio, which only duplicates evaluation rows -- confirm this is intended.
tokenized_val_overall_dataset = concatenate_datasets([tokenized_eval_headline_dataset for _ in range(max(1, n_conversation_train_datapoints // n_headline_train_datapoints))] + [tokenized_val_conversation_dataset])

### 4. Set Up LoRA

Now, to start our fine-tuning, we have to apply some preprocessing to the model to prepare it for training. For that use the `prepare_model_for_kbit_training` method from PEFT.

In [None]:
from peft import prepare_model_for_kbit_training

# Enable gradient checkpointing, then let PEFT prepare the quantized model
# for k-bit training.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Walks `model.named_parameters()` once and reports the trainable count,
    the total count and the trainable share as a percentage.
    """
    # (element count, is trainable) for every parameter tensor
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, trainable in sizes if trainable)
    summary = f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    print(summary)

Let's print the model to examine its layers, as we will apply QLoRA to all the linear layers of the model. Those layers are `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, and `lm_head`.

Here we define the LoRA config.

`r` is the rank of the low-rank matrix used in the adapters, which thus controls the number of parameters trained. A higher rank will allow for more expressivity, but there is a compute tradeoff.

`alpha` is the scaling factor for the learned weights. The weight matrix is scaled by `alpha/r`, and thus a higher value for `alpha` assigns more weight to the LoRA activations.

The values used in the QLoRA paper were `r=64` and `lora_alpha=16`, and these are said to generalize well, but we will use `r=32` and `lora_alpha=64` so that we have more emphasis on the new fine-tuned data while also reducing computational complexity.

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapters of rank 32 (scaling alpha=64) on the attention and MLP
# projection layers and on the output head.
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

# Wrap the base model with the adapters and report how much of it is trainable.
model = get_peft_model(model, config)
print_trainable_parameters(model)

See how the model looks different now, with the LoRA adapters added:

### Accelerator

Set up the Accelerator. I'm not sure if we really need this for a QLoRA given its [description](https://huggingface.co/docs/accelerate/v0.19.0/en/usage_guides/fsdp) (I have to read more about it) but it seems it can't hurt, and it's helpful to have the code for future reference. You can always comment out the accelerator if you want to try without.

In [None]:
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

# FSDP plugin configured to offload full model and optimizer state dicts to
# the CPU, on every rank (rank0_only=False).
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

In [None]:
# Let the accelerator wrap the model, then park it on the CPU until training starts.
model = accelerator.prepare_model(model).cpu()

In [None]:
# Create a separate copy of the model to fine-tune on the combined dataset,
# leaving `model` itself untouched. The copy is kept on the CPU for now.
import copy

overall_model = copy.deepcopy(model).cpu()

### 5. Run Training!

In [None]:
import transformers
from datetime import datetime

In [None]:
# Fine-tune the model on the overall (headline + conversation) dataset

project = "overall-finetune-adjusted"
base_model_name = "mistral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

overall_model = overall_model.to(device)

# NOTE(review): recent transformers releases appear to name the
# `evaluation_strategy` argument `eval_strategy` -- confirm against the
# installed version (this notebook installs transformers from git main).
overall_trainer = transformers.Trainer(
    model=overall_model,
    train_dataset=tokenized_train_overall_dataset,
    eval_dataset=tokenized_val_overall_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        max_steps=500,
        learning_rate=2.5e-4, # Want a small lr for finetuning
        #bf16=True,
        optim="paged_adamw_8bit",
        logging_steps=25,              # Report the loss every 25 steps
        logging_dir="./logs",        # Directory for storing logs
        save_strategy="steps",       # Save checkpoints on a step schedule
        save_steps=25,                # Save a checkpoint every 25 steps
        evaluation_strategy="steps", # Evaluate on a step schedule
        eval_steps=25,               # Evaluate every 25 steps
        do_eval=True,                # Run evaluation on the eval dataset
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# The key/value cache is only useful for generation and is incompatible with
# gradient checkpointing (enabled above), so it is switched off while training.
overall_model.config.use_cache = False
overall_trainer.train(resume_from_checkpoint=False)
# Re-enable it for the generate() calls in the evaluation cells.
overall_model.config.use_cache = True

overall_model = overall_model.cpu()

### Inference

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Reload the quantized base model for inference, with the same 4-bit
# settings used for training.
base_model_id = "mistralai/Mistral-7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# NOTE(review): `use_auth_token` is presumably deprecated in favor of `token`
# in recent transformers releases -- confirm against the installed version.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # Mistral, same as before
    quantization_config=bnb_config,  # Same quantization config as before
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True
)

# NOTE: this rebinds `tokenizer` without the padding side, EOS token and pad
# token settings of the training tokenizer; the evaluation cells further
# down use this instance.
tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)

Now load the QLoRA adapter from the appropriate checkpoint directory, i.e. the best performing model checkpoint:

In [None]:
from peft import PeftModel

# NOTE(review): the training cell in this notebook writes its checkpoints to
# "./mistral-overall-finetune-adjusted", not to the directory loaded here --
# confirm that this adapter path is the intended one.
ft_model = PeftModel.from_pretrained(base_model, "mistral-chat-doctor-finetune/checkpoint-500/")

and run your inference!

In [None]:
# Single-query smoke test of the fine-tuned adapter.
print("Doc chat inference:")
print("===================================================================================")
query = " hi doc, my bmi is 28 what to do?"
# NOTE(review): this prompt layout differs from both the training prompt
# (conversation_formatting_func) and the one used in respond() -- confirm
# which one the adapter expects.
eval_prompt = """Patient's Query: {} \n###\n\n""".format(query)
# NOTE: "cuda" is hard-coded here while `device` is used on the next call;
# this cell fails on a machine without a GPU.
model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
# ft_model.eval()
output = ft_model.generate(input_ids=model_input["input_ids"].to(device),
                           attention_mask=model_input["attention_mask"],
                           max_new_tokens=100, repetition_penalty=1.17)
# with torch.no_grad():

# The decoded text contains the prompt followed by the generated continuation.
print(tokenizer.decode(output[0], skip_special_tokens=True))

# Chat Doctor Next response

In [None]:
def respond(query):
    """Answer a patient query with the fine-tuned model plus a news-based hint.

    Args:
        query: The patient's message.

    Returns:
        The generated answer, followed by the outbreak reference text from
        find_similiar_disease when the query is close enough to a disease
        currently in the news.
    """
    eval_prompt = """Patient's Query:\n\n {} ###\n\n""".format(query)
    # Use the shared `device` throughout rather than a hard-coded "cuda".
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    output = ft_model.generate(input_ids=model_input["input_ids"],
                           attention_mask=model_input["attention_mask"],
                           max_new_tokens=125, repetition_penalty=1.15)
    # generate() returns the prompt tokens followed by the new ones. Slicing
    # the prompt off by length is exact, whereas removing it from the decoded
    # text with str.replace() silently fails whenever decoding does not
    # reproduce the prompt character for character.
    prompt_length = model_input["input_ids"].shape[1]
    result = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    res = find_similiar_disease(query)
    if res is not None:
        result = result + res
    return result

In [None]:
import random
import gradio as gr

def doc(message, history):
    """Chat callback for Gradio: answer `message`; the chat `history` is ignored."""
    return respond(message)

# Chat UI around the callback.
demo = gr.ChatInterface(doc)

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)

# Score model

In [None]:
import evaluate

In [None]:
def evaluate_by_metric(outputs, reference_texts, evaluation_metric, lang="en"):
    """Score generated texts against references with the given metric.

    Args:
        outputs: Generated texts (the predictions).
        reference_texts: Reference texts, aligned with `outputs`.
        evaluation_metric: Object exposing
            ``compute(predictions=..., references=..., lang=...)``, such as
            the BERTScore metric loaded below.
        lang: Language code forwarded to the metric.

    Returns:
        Whatever the metric's ``compute`` returns.
    """
    # The keyword is `references` (plural); the previous `reference=` was not
    # the parameter the metric expects.
    return evaluation_metric.compute(
        predictions=outputs, references=reference_texts, lang=lang
    )
# BERTScore metric used by the evaluation cells below.
bertscore=evaluate.load("bertscore")

### evaluation for integrated model

In [None]:
# Move the fine-tuned model back to the compute device for evaluation.
overall_model = overall_model.to(device)

In [None]:
# Generate a continuation for every validation conversation and score the
# decoded text against the reference answer with BERTScore.
decoded_outputs = []
reference_texts = []

for i in range(len(tokenized_val_conversation_dataset)):
    # Progress indicator every ten samples.
    if (i % 10 == 0):
        print(f"{i}/{len(tokenized_val_conversation_dataset)}")
    current_sample = tokenized_val_conversation_dataset[i]
    # NOTE(review): input_ids were produced by conversation_formatting_func,
    # whose text already contains the reference answer ("Doctor opinion: ...").
    # The model is therefore shown the answer before it generates -- confirm
    # that this is intended.
    eval_prompt = torch.tensor(current_sample["input_ids"])[None, :].to(device)
    eval_reference = current_sample["output"]  # not used below
    eval_attention = torch.tensor(current_sample["attention_mask"])[None, :].to(device)
    output = overall_model.generate(input_ids=eval_prompt,
                           attention_mask=eval_attention,
                           max_new_tokens=100, repetition_penalty=1.17, pad_token_id=tokenizer.eos_token_id)

    # The decoded text is the prompt plus the generated continuation.
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    reference_text = current_sample["output"]

    decoded_outputs.append(decoded_output)
    reference_texts.append(reference_text)

scores = bertscore.compute(predictions=decoded_outputs, references=reference_texts, lang="en")
print(f"Overall Model on Conversation Data")
print(f"Precision: {np.mean(scores['precision'])}")
print(f"Recall: {np.mean(scores['recall'])}")
print(f"F1: {np.mean(scores['f1'])}")

# Keep the conversation results; the next cell reuses the working variable names.
cluster_scores = scores
cluster_outputs = decoded_outputs
cluster_references = reference_texts

In [None]:
# Same evaluation as the previous cell, on the held-out headline dataset.
all_prompts = []
decoded_outputs = []
reference_texts = []

for i in range(len(tokenized_eval_headline_dataset)):
    current_sample = tokenized_eval_headline_dataset[i]
    # NOTE(review): input_ids were produced by headline_formatting_func, whose
    # text already contains the reference disease ("Disease: ...") -- confirm
    # that this is intended.
    eval_prompt = torch.tensor(current_sample["input_ids"])[None, :].to(device)
    eval_reference = current_sample["output"]  # not used below
    eval_attention = torch.tensor(current_sample["attention_mask"])[None, :].to(device)
    output = overall_model.generate(input_ids=eval_prompt,
                           attention_mask=eval_attention,
                           max_new_tokens=100, repetition_penalty=1.17, pad_token_id=tokenizer.eos_token_id)

    # The decoded text is the prompt plus the generated continuation.
    decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
    reference_text = current_sample["output"]#tokenizer.decode(current_sample["labels"], skip_special_tokens=True)

    # Prompts are collected as tensors but not read again in this cell.
    all_prompts.append(eval_prompt)
    decoded_outputs.append(decoded_output)
    reference_texts.append(reference_text)

scores = bertscore.compute(predictions=decoded_outputs, references=reference_texts, lang="en")
print(f"Overall Model on Headline Data")
print(f"Precision: {np.mean(scores['precision'])}")
print(f"Recall: {np.mean(scores['recall'])}")
print(f"F1: {np.mean(scores['f1'])}")

In [None]:
# Move the model off the compute device once evaluation is finished.
overall_model = overall_model.cpu()