# Mistral
## Imports


In [1]:
%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers "trl<0.9.0" peft accelerate bitsandbytes

In [2]:
!pip install rouge_score
!pip install datasets
!pip install numpy
!pip install torch
!pip install nltk

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24933 sha256=b0d38af72c19dd7ab4bea464753ff4e8b95d10f7caaa03dcb75c6db2e90b85e6
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_c

In [3]:
import pandas as pd
import json

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Loading base Mistral model


In [5]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-v0.3", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

FastLanguageModel.for_inference(model) # Enable native 2x faster inferenc

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

==((====))==  Unsloth: Fast Mistral patching release 2024.6
   \\   /|    GPU: NVIDIA L4. Max memory: 22.168 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.26.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


model.safetensors:   0%|          | 0.00/4.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/137k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

Unsloth: Will load unsloth/mistral-7b-v0.3-bnb-4bit as a legacy tokenizer.


In [6]:
COLAB_ENABLED=True

if COLAB_ENABLED:
    DATA_PATH = "/content/drive/MyDrive/biomedical_nlp/data"
    from google.colab import drive
    drive.mount('/content/drive')
    #%% md
else:
    DATA_PATH = "./data"

# Load the dataset
with open(DATA_PATH + '/data.json', 'r') as f:
    data = json.load(f)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
print(json.dumps(data, indent=4))

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [8]:

dfs = []
question_data = {"question_id": [], "question": [], "question_type": [], "title": "", "adaptations": [], "abstracts": []}
for key, value in data.items():
    abstracts = []
    adaptations = []
    for sub_key, sub_value in value.items():
        if isinstance(sub_value, dict):
            for adaptation in sub_value['adaptations']:
                for adaptation_key, adaptation_value in sub_value['adaptations'][adaptation].items():
                    adaptations.append(adaptation_value)
                    abstracts.append(sub_value['abstract'][adaptation_key])
        question_data["abstracts"] = abstracts
        question_data["adaptations"] = adaptations
        question_data["question_id"] = [key] * len(adaptations)
        question_data["question_type"] = [value["question_type"]] * len(adaptations)
        question_data["question"] = [value["question"]] * len(adaptations)
        question_data["title"] = [value["question"]] * len(adaptations)
        assert len(adaptations) == len(abstracts), f"len {len(adaptations)} not equal {len(abstracts)}"
    dfs.append(pd.DataFrame.from_dict(question_data))

df = pd.concat(dfs)
# print(df.head())
print(df['question'].head())


In [9]:
df.to_csv(DATA_PATH + '/test_dataset.csv', index=False)

0    What causes muscle spasm?
1    What causes muscle spasm?
2    What causes muscle spasm?
3    What causes muscle spasm?
4    What causes muscle spasm?
Name: question, dtype: object


In [10]:
import os
import unicodedata
import math
import argparse
import random

#mee
## Split up dataset into train/val/test -> 70/15/15
path = DATA_PATH
# print("Unique questions in dataset:", df['question'].unique())

# Clean the question column
df['question_id'] = df['question_id'].astype(str)

test_question_numbers = ['5','12','16','22','30','36','42','48','54','61','68']
val_question_numbers = ['2','7','13','17','26','34','40','46','52','58','66']
train_question_numbers = [str(x) for x in range(1, 76) if str(x) not in test_question_numbers and str(x) not in val_question_numbers]



test = df.loc[df['question_id'].isin(test_question_numbers)]
val = df.loc[df['question_id'].isin(val_question_numbers)]
train = df.loc[df['question_id'].isin(train_question_numbers)]

print("Train question numbers:", train_question_numbers)
print("Number of entries in test set:", len(test))
print("Number of entries in val set:", len(val))
print("Number of entries in train set:", len(train))


dfs = {'train':train, 'val':val, 'test':test}

    # Save each to CSV file
for key, df in dfs.items():
    df.to_csv(path + key + ".csv", index=False, encoding='utf-8-sig')
else:
    train = pd.read_csv(path + 'train.csv', header=0)
    val = pd.read_csv(path + 'val.csv', header=0)
    test = pd.read_csv(path + 'test.csv', header=0)
    combined_datasets = {'train':train, 'val':val, 'test':test}


Train question numbers: ['1', '3', '4', '6', '8', '9', '10', '11', '14', '15', '18', '19', '20', '21', '23', '24', '25', '27', '28', '29', '31', '32', '33', '35', '37', '38', '39', '41', '43', '44', '45', '47', '49', '50', '51', '53', '55', '56', '57', '59', '60', '62', '63', '64', '65', '67', '69', '70', '71', '72', '73', '74', '75']
Number of entries in test set: 1373
Number of entries in val set: 1458
Number of entries in train set: 6488


In [11]:
combined_datasets['train']

Unnamed: 0,question_id,question,question_type,title,adaptations,abstracts
0,1,What causes muscle spasm?,C,What causes muscle spasm?,Muscle cramps are a common problem represented...,Muscle cramps are a common problem characteriz...
1,1,What causes muscle spasm?,C,What causes muscle spasm?,"These true cramps, coming from nerves outside ...","These true cramps, which originate from periph..."
2,1,What causes muscle spasm?,C,What causes muscle spasm?,"Medical history, physical check-up, and lab sc...","Medical history, physical examination, and a l..."
3,1,What causes muscle spasm?,C,What causes muscle spasm?,"Despite their harmless nature, cramps are unco...","Despite the ""benign"" nature of cramps, many pa..."
4,1,What causes muscle spasm?,C,What causes muscle spasm?,Experience and limited medical studies guide t...,Treatment options are guided both by experienc...
...,...,...,...,...,...,...
6483,75,What is a gene affected by sickle cell anemia?,B,What is a gene affected by sickle cell anemia?,Sickle Cell Anemia (SCA) is a genetic blood di...,Sickle cell anemia (SCA) is a disease characte...
6484,75,What is a gene affected by sickle cell anemia?,B,What is a gene affected by sickle cell anemia?,Due to their effects on sickle hemoglobin (HbS...,Because of their effects on HbS polymerization...
6485,75,What is a gene affected by sickle cell anemia?,B,What is a gene affected by sickle cell anemia?,The aim of our study was to find out if the nu...,The aim of our study was to determine if the n...
6486,75,What is a gene affected by sickle cell anemia?,B,What is a gene affected by sickle cell anemia?,Results showed that alpha-thalassemia protecte...,Our results confirmed that alpha-thalassemia p...


In [12]:
!pip install sacrebleu sacremoses

Collecting sacrebleu
  Downloading sacrebleu-2.4.2-py3-none-any.whl (106 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/106.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.10.0-py3-none-any.whl (18 kB)
Collecting colorama (from sacrebleu)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Installing collected packages: sacremoses, portalocker, colorama, sacrebleu
Successfully installed colorama-0.4.6 portalocker-2.10.0 sacrebleu-2.4.2 sacremoses-0.1.1


In [13]:
import numpy as np
import torch
from transformers import Seq2SeqTrainer, EarlyStoppingCallback, AutoModelForSeq2SeqLM
from transformers import AutoTokenizer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, set_seed
from datasets import load_metric, Dataset, DatasetDict
import nltk
from peft import get_peft_model, LoraConfig

# Ensure NLTK tokenizers are downloaded
nltk.download('punkt')

# Define the project path
base_path = DATA_PATH

def train_and_evaluate_model(model_name="t5-small", max_token_length=512, max_token_target_length=512, batch_size=8, epochs=10, chosen_seed=42, test_only=False, use_lora=False):
    # Set location to store models
    model_identifier = model_name.split('/')[-1]
    model_checkpoint_dir = base_path + model_identifier + '_runs'

    # Set prefix for T5 model to select summarization version of T5
    prefix = "summarize: " if ('t5' in model_identifier) or ('T0' in model_identifier) else ""

    def tokenize_and_encode(examples):
        inputs = [prefix + doc for doc in examples["input_text"]]
        tokenized_input = tokenizer(inputs, max_length=max_token_length, truncation=True)

        with tokenizer.as_target_tokenizer():
            tokenized_label = tokenizer(examples['target_text'], max_length=max_token_target_length, truncation=True)

        tokenized_input['labels'] = tokenized_label['input_ids']
        return tokenized_input

    # Get Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for split_name, dataset in datasets.items():
            dataset['input_text'] = dataset['input_text'].astype(str)
            dataset['target_text'] = dataset['target_text'].astype(str)

    # Create & Tokenize dictionary of pandas datasets converted to Transformer Datasets
    tokenized_datasets = DatasetDict()
    for split_name, dataset in datasets.items():
        tokenized_datasets[split_name] = Dataset.from_pandas(dataset)
        tokenized_datasets[split_name] = tokenized_datasets[split_name].map(tokenize_and_encode, batched=True)

    # Ensure datasets are not empty
    for split in ['train', 'val', 'test']:
        if len(tokenized_datasets[split]) == 0:
            raise ValueError(f"{split} dataset is empty. Please check the data loading process.")

    # Debugging: Check if input_text and target_text are lists of strings
    for split in ['train', 'val', 'test']:
        print(f"\nFirst 2 entries in {split} dataset:")
        print(tokenized_datasets[split]['input_text'][:2])
        print(tokenized_datasets[split]['target_text'][:2])

    # Set model type
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Apply LoRA if specified
    if use_lora:
        config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
        model = get_peft_model(model, config)

    # Set Training Arguments
    training_args = Seq2SeqTrainingArguments(
        do_train=True,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        logging_strategy='epoch',
        num_train_epochs=epochs,
        output_dir=model_checkpoint_dir,
        overwrite_output_dir=True,
        per_device_eval_batch_size=batch_size,
        per_device_train_batch_size=batch_size,
        predict_with_generate=True,
        remove_unused_columns=True,
        report_to="none",
        save_strategy='epoch',
        save_total_limit=1,
        seed=chosen_seed,
    )

    # Set Data Collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # Define Metrics
    def format_predictions(lst):
        return list(map(lambda el: [el], lst))

    def clean_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [[label.strip()] for label in labels]
        return preds, labels

    metric_rouge = load_metric("rouge")
    metric_sari = load_metric('sari')

    def calculate_metrics(eval_pred):
        predictions, labels = eval_pred
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # newline after each sentence
        decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
        decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

        result = metric_rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

        # Add mean generated length
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
        result["gen_len"] = np.mean(prediction_lens)

        return {k: round(v, 4) for k, v in result.items()}

    # Generate model predictions for ROUGE and SARI scores
    def generate_and_evaluate_predictions(model_type):
        decoded_preds = []
        for encoded_text in tokenized_datasets['test']['input_ids']:
            summary_ids = model_type.generate(torch.tensor(encoded_text).unsqueeze(0),
                                              num_beams=5,
                                              no_repeat_ngram_size=2,
                                              length_penalty=5,
                                              min_length=30,
                                              max_length=max_token_target_length,
                                              early_stopping=True)
            output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            decoded_preds.append(output)

        decoded_labels = tokenized_datasets['test']['target_text']

        # Prepare SARI and ROUGE inputs
        sari_decoded_preds = [" ".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
        sari_decoded_labels = [[" ".join(nltk.sent_tokenize(label.strip()))] for label in decoded_labels]
        rouge_decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
        rouge_decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

        # SARI
        sari = metric_sari.compute(sources=tokenized_datasets['test']['input_text'], predictions=sari_decoded_preds, references=sari_decoded_labels)
        # SacreBLEU
        sacrebleu = metric_sacrebleu.compute(predictions=decoded_preds, references=[[label] for label in decoded_labels])

        # ROUGE
        rouge = metric_rouge.compute(predictions=rouge_decoded_preds, references=rouge_decoded_labels, use_stemmer=True)
        rouge = {key: value.mid.fmeasure * 100 for key, value in rouge.items()}
        rouge = {k: round(v, 4) for k, v in rouge.items()}

        return sari, rouge, sacrebleu, sari_decoded_preds

    # Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['val'],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=calculate_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    if not test_only:
        trainer.train()

    # Save Model & Tokenizer
    model_save_path = base_path + 'models/best_run_' + model_identifier
    trainer.save_model(model_save_path)
    tokenizer.save_pretrained(model_save_path)

    # Get Test Set Metrics
    model = AutoModelForSeq2SeqLM.from_pretrained(model_save_path)
    model_sari, model_rouge, model_sacrebleu, model_preds = generate_and_evaluate_predictions(model)
    print('Baseline Finetuned Model SARI: ', model_sari)
    print('Baseline Finetuned Model ROUGE: ', model_rouge)
    print('Baseline Finetuned Model SacreBLEU: ', model_sacrebleu)
    # Check Results with two outputs
    print('Test Set First 2 Abstracts: ', tokenized_datasets['test']['input_text'][0:2])
    print('Test Set First 2 PLS: ', tokenized_datasets['test']['target_text'][0:2])
    print('\nTest Set First 2 Model Outputs: ', model_preds[0:2])

    return model_preds, trainer

# Example of running the function
# datasets = {
#     'train': pd.DataFrame({'input_text': ["Train input 1", "Train input 2"], 'target_text': ["Train target 1", "Train target 2"]}),
#     'val': pd.DataFrame({'input_text': ["Val input 1", "Val input 2"], 'target_text': ["Val target 1", "Val target 2"]}),
#     'test': pd.DataFrame({'input_text': ["Test input 1", "Test input 2"], 'target_text': ["Test target 1", "Test target 2"]}),
# }
# datasets_bk = datasets


datasets = {
    'train': pd.DataFrame({'input_text': combined_datasets['train']['abstracts'].to_list(), 'target_text': combined_datasets['train']['adaptations'].to_list()}),
    'val': pd.DataFrame({'input_text': combined_datasets['val']['abstracts'].to_list(), 'target_text': combined_datasets['val']['adaptations'].to_list()}),
    'test': pd.DataFrame({'input_text': combined_datasets['test']['abstracts'].to_list(), 'target_text': combined_datasets['test']['adaptations'].to_list()}),
}



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [14]:
datasets['train'].head()

Unnamed: 0,input_text,target_text
0,Muscle cramps are a common problem characteriz...,Muscle cramps are a common problem represented...
1,"These true cramps, which originate from periph...","These true cramps, coming from nerves outside ..."
2,"Medical history, physical examination, and a l...","Medical history, physical check-up, and lab sc..."
3,"Despite the ""benign"" nature of cramps, many pa...","Despite their harmless nature, cramps are unco..."
4,Treatment options are guided both by experienc...,Experience and limited medical studies guide t...


In [16]:
# # T5 Model Training and Testing
# print("T5 Seed 7 Train & Testing")
# t5_7_preds, t5_trainer = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=7, use_lora=True)
# datasets['test']['T5_7_Predictions'] = t5_7_preds

# print("T5 Seed 15 Train & Testing")
# t5_15_preds, t5_trainer = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=15, use_lora=True)
# datasets['test']['T5_15_Predictions'] = t5_15_preds

# print("T5 Seed 42 Train & Testing")
# t5_42_preds, t5_trainer = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=42, use_lora=True)
# datasets['test']['T5_42_Predictions'] = t5_42_preds

# # T5 Model Testing Only
# print("T5 Seed 7 Testing Only")
# t5_7_preds_test, t5_trainer_test = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=7, test_only=True, use_lora=True)
# datasets['test']['T5_Test_7_Predictions'] = t5_7_preds_test

# print("T5 Seed 15 Testing Only")
# t5_15_preds_test, t5_trainer_test = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=15, test_only=True, use_lora=True)
# datasets['test']['T5_Test_15_Predictions'] = t5_15_preds_test

# print("T5 Seed 42 Testing Only")
# t5_42_preds_test, t5_trainer_test = train_and_evaluate_model(model_name="t5-base", max_token_length=512, max_token_target_length=512, batch_size=2, epochs=10, chosen_seed=42, test_only=True, use_lora=True)
# datasets['test']['T5_Test_42_Predictions'] = t5_42_preds_test

# print("Training and Testing with BART Model")
# bart_7_preds, bart_trainer = train_and_evaluate_model(model_name="facebook/bart-base", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=7, use_lora=True)
# datasets['test']['BART_7_Predictions'] = bart_7_preds

print("bart Seed 15 Train & Testing")
bart_15_preds, bart_trainer = train_and_evaluate_model(model_name="facebook/bart-base", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=15, use_lora=True)
datasets['test']['BART_15_Predictions'] = bart_15_preds

print("bart Seed 42 Train & Testing")
bart_42_preds, bart_trainer = train_and_evaluate_model(model_name="facebook/bart-base", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=42, use_lora=True)
datasets['test']['BART_42_Predictions'] = bart_42_preds

# bart-large
print("bart-large Seed 7 Train & Testing")
bart_large_7_preds, bart_large_trainer = train_and_evaluate_model(model_name="facebook/bart-large-cnn", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=7, use_lora=True)
datasets['test']['BART_large_7_Predictions'] = bart_large_7_preds

print("bart-large Seed 15 Train & Testing")
bart_large_15_preds, bart_large_trainer = train_and_evaluate_model(model_name="facebook/bart-large-cnn", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=15, use_lora=True)
datasets['test']['BART_large_15_Predictions'] = bart_large_15_preds

print("bart-large Seed 42 Train & Testing")
bart_large_42_preds, bart_large_trainer = train_and_evaluate_model(model_name="facebook/bart-large-cnn", max_token_length=1024, max_token_target_length=1024, batch_size=1, epochs=30, chosen_seed=42, use_lora=True)
datasets['test']['BART_large_42_Predictions'] = bart_large_42_preds

bart Seed 15 Train & Testing


Map:   0%|          | 0/6488 [00:00<?, ? examples/s]

Map:   0%|          | 0/1458 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [45]:
datasets['test']['T5_7_Predictions']

KeyError: 'T5_7_Predictions'

In [None]:
## Inference
prompt = """
### Instruction:
You are a medical advisor that takes in a very abstract sentence and translates it in layman's terms, for average people to understand.

### Input:
{}

### Response:
{}"""


EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(data):
    abstracts       = data["abstracts"]
    adaptations      = data["adaptations"]
    # print(abstracts)
    print()
    texts = []
    for input, output in zip(abstracts, adaptations):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = prompt.format(input, output) + EOS_TOKEN
        print(text)
        texts.append(text)
    return { "text" : texts, }


from datasets import Dataset

dataset = Dataset.from_pandas(df)
dataset = dataset.map(formatting_prompts_func, batched = True,)

## Finetuning using unsloth (RUN ONLY IF YOU WANT TO FINETUNE)


In [30]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2024.6 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 1,
        # max_steps = 60, # Set num_train_epochs = 1 for full training runs
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

  self.pid = os.fork()


Map (num_proc=2):   0%|          | 0/9319 [00:00<?, ? examples/s]

  self.pid = os.fork()


In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 9,319 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 1,165
 "-____-"     Number of trainable parameters = 41,943,040


Step,Training Loss
1,0.875
2,0.891
3,1.0089
4,0.9477
5,1.0213
6,0.9752
7,1.1058
8,1.1446
9,0.8835
10,0.8674


### Saving the model locally

In [None]:
model.save_pretrained(DATA_PATH + "lora/lora_model_2") # Local saving
tokenizer.save_pretrained(DATA_PATH + "lora/lora_model_2")

('/content/drive/MyDrive/biomedical_nlplora/lora_model_2/tokenizer_config.json',
 '/content/drive/MyDrive/biomedical_nlplora/lora_model_2/special_tokens_map.json',
 '/content/drive/MyDrive/biomedical_nlplora/lora_model_2/tokenizer.model',
 '/content/drive/MyDrive/biomedical_nlplora/lora_model_2/added_tokens.json')

# Inference

In [None]:
# Loading the model, change to True if you want to
if False:
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "lora_model", # YOUR MODEL YOU USED FOR TRAINING
        max_seq_length = max_seq_length,
        dtype = dtype,
        load_in_4bit = load_in_4bit,
    )
    FastLanguageModel.for_inference(model) # Enable native 2x faster inference

inference_prompt= """
### Instruction:
You are a medical advisor that takes in a very abstract sentence and translates it in layman's terms, for average people to understand.

### Input:
{}

### Response:
{}"""

FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    inference_prompt.format(
        "We used spectral-domain optical coherence tomography to image macular regions and measure retinal thickness and Snellen chart visual acuity (VA) to evaluate best-corrected VA (BCVA) at 1, 2, 3, 6, 9, and 12 months after vitrectomy.", # input
        "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
print(tokenizer.batch_decode(outputs))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


["<s>\n### Instruction:\nYou are a medical advisor that takes in a very abstract sentence and translates it in layman's terms, for average people to understand.\n\n### Input:\nWe used spectral-domain optical coherence tomography to image macular regions and measure retinal thickness and Snellen chart visual acuity (VA) to evaluate best-corrected VA (BCVA) at 1, 2, 3, 6, 9, and 12 months after vitrectomy.\n\n### Response:\nWe used a special type of eye scan to measure the thickness of the retina and the sharpness of vision at 1, 2, 3, 6, 9, and 12 months after vitrectomy.</s>"]


In [None]:
# if you want to stream
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    inference_prompt.format(
         "Conclusions: The thickness of EZ-RPE and cone density increased during foveal regeneration, as demonstrated by the continuous improvements in CIZ integrity over time, leading to the formation of foveal bulge and good vision following successful reattachment of macula-off RRD.", # input
        "", # output - leave this blank for generation!
    )
], return_tensors = "pt").to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


<s>
### Instruction:
You are a medical advisor that takes in a very abstract sentence and translates it in layman's terms, for average people to understand.

### Input:
Conclusions: The thickness of EZ-RPE and cone density increased during foveal regeneration, as demonstrated by the continuous improvements in CIZ integrity over time, leading to the formation of foveal bulge and good vision following successful reattachment of macula-off RRD.

### Response:
The thickness of the retina and cone density increased during foveal regeneration, as demonstrated by the continuous improvements in CIZ integrity over time, leading to the formation of foveal bulge and good vision following successful reattachment of macula-off RRD.</s>


# Evaluation

## Sari functions
From https://github.com/cocoxu/simplification/blob/master/SARI.py

In [None]:
from collections import Counter
import sys

In [None]:
def ReadInFile (filename):

    with open(filename) as f:
        lines = f.readlines()
        lines = [x.strip() for x in lines]
    return lines

In [None]:
def SARIngram(sgrams, cgrams, rgramslist, numref):
    rgramsall = [rgram for rgrams in rgramslist for rgram in rgrams]
    rgramcounter = Counter(rgramsall)

    sgramcounter = Counter(sgrams)
    sgramcounter_rep = Counter()
    for sgram, scount in sgramcounter.items():
        sgramcounter_rep[sgram] = scount * numref

    cgramcounter = Counter(cgrams)
    cgramcounter_rep = Counter()
    for cgram, ccount in cgramcounter.items():
        cgramcounter_rep[cgram] = ccount * numref


    # KEEP
    keepgramcounter_rep = sgramcounter_rep & cgramcounter_rep
    keepgramcountergood_rep = keepgramcounter_rep & rgramcounter
    keepgramcounterall_rep = sgramcounter_rep & rgramcounter

    keeptmpscore1 = 0
    keeptmpscore2 = 0
    for keepgram in keepgramcountergood_rep:
        keeptmpscore1 += keepgramcountergood_rep[keepgram] / keepgramcounter_rep[keepgram]
        keeptmpscore2 += keepgramcountergood_rep[keepgram] / keepgramcounterall_rep[keepgram]
        #print "KEEP", keepgram, keepscore, cgramcounter[keepgram], sgramcounter[keepgram], rgramcounter[keepgram]
    keepscore_precision = 0
    if len(keepgramcounter_rep) > 0:
    	keepscore_precision = keeptmpscore1 / len(keepgramcounter_rep)
    keepscore_recall = 0
    if len(keepgramcounterall_rep) > 0:
    	keepscore_recall = keeptmpscore2 / len(keepgramcounterall_rep)
    keepscore = 0
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)


    # DELETION
    delgramcounter_rep = sgramcounter_rep - cgramcounter_rep
    delgramcountergood_rep = delgramcounter_rep - rgramcounter
    delgramcounterall_rep = sgramcounter_rep - rgramcounter
    deltmpscore1 = 0
    deltmpscore2 = 0
    for delgram in delgramcountergood_rep:
        deltmpscore1 += delgramcountergood_rep[delgram] / delgramcounter_rep[delgram]
        deltmpscore2 += delgramcountergood_rep[delgram] / delgramcounterall_rep[delgram]
    delscore_precision = 0
    if len(delgramcounter_rep) > 0:
    	delscore_precision = deltmpscore1 / len(delgramcounter_rep)
    delscore_recall = 0
    if len(delgramcounterall_rep) > 0:
    	delscore_recall = deltmpscore1 / len(delgramcounterall_rep)
    delscore = 0
    if delscore_precision > 0 or delscore_recall > 0:
        delscore = 2 * delscore_precision * delscore_recall / (delscore_precision + delscore_recall)


    # ADDITION
    addgramcounter = set(cgramcounter) - set(sgramcounter)
    addgramcountergood = set(addgramcounter) & set(rgramcounter)
    addgramcounterall = set(rgramcounter) - set(sgramcounter)

    addtmpscore = 0
    for addgram in addgramcountergood:
        addtmpscore += 1

    addscore_precision = 0
    addscore_recall = 0
    if len(addgramcounter) > 0:
    	addscore_precision = addtmpscore / len(addgramcounter)
    if len(addgramcounterall) > 0:
    	addscore_recall = addtmpscore / len(addgramcounterall)
    addscore = 0
    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)

    return (keepscore, delscore_precision, addscore)

In [None]:
def SARIsent (ssent, csent, rsents) :
    numref = len(rsents)

    s1grams = ssent.lower().split(" ")
    c1grams = csent.lower().split(" ")
    s2grams = []
    c2grams = []
    s3grams = []
    c3grams = []
    s4grams = []
    c4grams = []

    r1gramslist = []
    r2gramslist = []
    r3gramslist = []
    r4gramslist = []
    for rsent in rsents:
        r1grams = rsent.lower().split(" ")
        r2grams = []
        r3grams = []
        r4grams = []
        r1gramslist.append(r1grams)
        for i in range(0, len(r1grams)-1) :
            if i < len(r1grams) - 1:
                r2gram = r1grams[i] + " " + r1grams[i+1]
                r2grams.append(r2gram)
            if i < len(r1grams)-2:
                r3gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2]
                r3grams.append(r3gram)
            if i < len(r1grams)-3:
                r4gram = r1grams[i] + " " + r1grams[i+1] + " " + r1grams[i+2] + " " + r1grams[i+3]
                r4grams.append(r4gram)
        r2gramslist.append(r2grams)
        r3gramslist.append(r3grams)
        r4gramslist.append(r4grams)

    for i in range(0, len(s1grams)-1) :
        if i < len(s1grams) - 1:
            s2gram = s1grams[i] + " " + s1grams[i+1]
            s2grams.append(s2gram)
        if i < len(s1grams)-2:
            s3gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2]
            s3grams.append(s3gram)
        if i < len(s1grams)-3:
            s4gram = s1grams[i] + " " + s1grams[i+1] + " " + s1grams[i+2] + " " + s1grams[i+3]
            s4grams.append(s4gram)

    for i in range(0, len(c1grams)-1) :
        if i < len(c1grams) - 1:
            c2gram = c1grams[i] + " " + c1grams[i+1]
            c2grams.append(c2gram)
        if i < len(c1grams)-2:
            c3gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2]
            c3grams.append(c3gram)
        if i < len(c1grams)-3:
            c4gram = c1grams[i] + " " + c1grams[i+1] + " " + c1grams[i+2] + " " + c1grams[i+3]
            c4grams.append(c4gram)


    (keep1score, del1score, add1score) = SARIngram(s1grams, c1grams, r1gramslist, numref)
    (keep2score, del2score, add2score) = SARIngram(s2grams, c2grams, r2gramslist, numref)
    (keep3score, del3score, add3score) = SARIngram(s3grams, c3grams, r3gramslist, numref)
    (keep4score, del4score, add4score) = SARIngram(s4grams, c4grams, r4gramslist, numref)
    avgkeepscore = sum([keep1score,keep2score,keep3score,keep4score])/4
    avgdelscore = sum([del1score,del2score,del3score,del4score])/4
    avgaddscore = sum([add1score,add2score,add3score,add4score])/4
    finalscore = (avgkeepscore + avgdelscore + avgaddscore ) / 3

    return finalscore

# Data loading
Ensure `data.json` is in the same directory, or modify the path below.

In [None]:
import json
with open('data.json') as f:
    j = json.load(f)

valq = [2, 7, 13, 17, 26, 34, 40, 46, 52, 58, 66]
tstq = [5, 12, 16, 22, 30, 36, 42, 48, 54, 61, 68]

# System output
Edit `process()` to use your system. It takes an array of sentences from a single abstract and returns an array of equal length with the adapted version of each (some potentially blank or with multiple sentences).

In [None]:
def process(source):

    # REPLACE THIS CODE
    target = []
    for sent in source:
        target.append(sent) # copy source as placeholder

    return target

# Compute scores

In [None]:
def scoreQuestions(qs, name):
    sarisum = 0
    sarin = 0
    for q in qs:
        for pmid, node in j['%d'%q].items():
            if pmid != 'question' and pmid != 'question_type':
                source = []
                refs = []
                for line in node['abstract']:
                    source.append(node['abstract'][line])
                    linerefs = []
                    for _, adpt in node['adaptations'].items():
                        linerefs.append(adpt.get(line, ''))
                    refs.append(linerefs)
                target = process(source)
                for i in range(len(source)):
                    sari = SARIsent(source[i], target[i], refs[i])
                    sarisum += sari
                    sarin += 1
    print("SARI for %s set: %f"% (name, sarisum/sarin))

In [None]:
scoreQuestions(valq, "validation")
scoreQuestions(tstq, "test")

SARI for validation set: 0.143607
SARI for test set: 0.164033
