## Imports

In [None]:
from transformers import (AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, 
                          TrainingArguments, pipeline, logging)

from accelerate import Accelerator
from huggingface_hub import login
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer
import torch

# Accelerate handle; in this notebook it is used only to wrap the model via
# accelerator.prepare() in the "Load Model" cell.
accelerator = Accelerator()

## Login to Hugging Face

In [None]:
import os

# Authenticate with the Hugging Face Hub. The access token is read from the
# HF_TOKEN environment variable instead of being hard-coded: a token committed
# to a notebook is a leaked credential. If HF_TOKEN is unset, login() with
# token=None falls back to an interactive prompt (widget or terminal).
# NOTE(review): the token that was previously hard-coded in this cell should be
# revoked and regenerated in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))

## Set Paths to Pretrained Model & Tokenizer

In [None]:
# Shared cache directory, passed as cache_dir to from_pretrained and load_dataset.
cache_dir = "/blue/azare/anthony.rahbany/cache/"

# Root directory holding the locally saved checkpoint and tokenizer. The
# sub-paths below keep the trailing slash so they are plain directory strings.
folder = "/blue/azare/anthony.rahbany/NLP/NLP_Cares/Models/finetune/"
pretrained_model = f"{folder}nlp_cares_bart_model/"
pretrained_tokenizer = f"{folder}nlp_cares_bart_tokenizer/"

## Load Model

In [None]:
# Quantization: store the weights in 4-bit NF4 (without double quantization)
# and run the quantized layers' computation in float16.
compute_dtype = torch.float16

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=compute_dtype,
)

# NOTE(review): the checkpoint directory is named "nlp_cares_bart_model"; if it
# holds an encoder-decoder BART checkpoint, AutoModelForCausalLM will load a
# decoder-only variant of it — confirm this architecture is intended.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model,
    cache_dir=cache_dir,
    device_map="auto",
    quantization_config=quant_config,
)

# pretraining_tp is a Llama-style config field (1 = ordinary, non-sliced linear
# layers); presumably an unused extra attribute for other architectures.
model.config.pretraining_tp = 1
# The generation-time KV cache is not needed while training.
model.config.use_cache = False

# NOTE(review): SFTTrainer prepares the model with its own Accelerator when
# training starts, so this extra prepare() is likely redundant — confirm.
model = accelerator.prepare(model)

## Load Tokenizer

In [None]:
# Load the tokenizer saved next to the checkpoint.
# NOTE(review): trust_remote_code=True executes any Python code shipped with the
# tokenizer directory; keep only if that directory's contents are trusted.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_tokenizer,
    trust_remote_code=True,
)

# Pad on the right, reusing the EOS token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

## Load Dataset

In [None]:
# Fetch the dmacres/mimiciii-hospitalcourse-meta dataset from the Hub into the
# shared cache; its 'train' and 'validation' splits are handed to the trainer.
dataset = load_dataset(
    path="dmacres/mimiciii-hospitalcourse-meta",
    cache_dir=cache_dir,
)

## Format Data for Supervised Learning

In [None]:
def _format_example(ehr, summary):
    """Build the prompt/response training text for one record."""
    return "\n".join(
        [
            "Below is an electronic health record for a patient, "
            "summarize it with simple medical terms.",
            "",
            "### Input:",
            ehr,
            "",
            "### Response:",
            summary,
        ]
    )


def formatting_supervised_data(data):
    """Render EHR/summary records as instruction-tuning text for SFTTrainer.

    Args:
        data: A mapping with an ``'extractive_notes_summ'`` entry (the source
            health-record text) and a ``'target_text'`` entry (the reference
            summary). Either a batch, where each value is a list with one item
            per example, or a single example, where each value is a string.
            Any other keys (extra dataset columns) are ignored.

    Returns:
        For a batch, a list with one formatted string per example. For a
        single example, one formatted string.

    NOTE(review): no EOS token is appended after the response; confirm the
    tokenizer/trainer adds one, otherwise the model may not learn where a
    summary ends.
    """
    notes = data["extractive_notes_summ"]
    targets = data["target_text"]

    # Single-example call (values are strings rather than lists).
    if isinstance(targets, str):
        return _format_example(notes, targets)

    # Walk the rows of the two columns together. len(data) would count the
    # batch dict's *columns*, not its rows, silently dropping examples or
    # raising IndexError.
    return [_format_example(ehr, summary) for ehr, summary in zip(notes, targets)]

In [None]:
# LoRA adapter settings, handed to SFTTrainer through `peft_config`.
peft_params = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,              # adapter rank
    lora_alpha=16,     # scaling numerator (adapter update scaled by alpha / r)
    lora_dropout=0.1,  # dropout applied in the adapter path
    bias="none",       # bias terms are not trained
)

In [None]:
# Hyper-parameters for the fine-tuning run.
# NOTE(review): no evaluation strategy is configured, so the eval_dataset given
# to the trainer is presumably never evaluated during training — confirm
# whether periodic evaluation was intended.
training_params = TrainingArguments(
    output_dir="./results",
    num_train_epochs=10,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=200,
    logging_steps=5,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,  # non-positive: run length is set by num_train_epochs
    warmup_ratio=0.03,
    group_by_length=True,
    # The plain "constant" scheduler ignores warmup, so warmup_ratio above had
    # no effect. "constant_with_warmup" ramps the learning rate up over the
    # first 3% of steps, then holds it at learning_rate.
    lr_scheduler_type="constant_with_warmup",
)

In [None]:
# Supervised fine-tuning of `model` with the LoRA settings in peft_params, on
# text produced by formatting_supervised_data.
# NOTE(review): passing tokenizer / max_seq_length / packing directly matches
# older trl releases; newer releases moved some of these into SFTConfig —
# verify against the installed trl version.
trainer = SFTTrainer(
    model=model,
    args=training_params,
    tokenizer=tokenizer,
    peft_config=peft_params,
    # Data sources and how each record becomes training text.
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    formatting_func=formatting_supervised_data,
    # One example per sequence (no packing), capped at 2048 tokens.
    packing=False,
    max_seq_length=2048,
)

In [None]:
# Run the fine-tuning loop configured by training_params and peft_params.
trainer.train()

In [None]:
# Persist the results of training.
# Save the trainer's model (presumably the PEFT/LoRA-wrapped model, since a
# peft_config was supplied) to ./nlp_cares_trained.
trainer.save_model("./nlp_cares_trained")
# NOTE(review): `model` is the object created before SFTTrainer was built, not
# `trainer.model`; confirm this writes what is intended (e.g. whether LoRA
# adapter layers and 4-bit quantized weights end up in ./nlp_cares_model).
model.save_pretrained("./nlp_cares_model")
# Save the tokenizer configured in the "Load Tokenizer" cell.
tokenizer.save_pretrained("./nlp_cares_tokenizer")