# Combined Training Scripts

# In [None]:
import os
from typing import List
import pandas as pd
from datetime import datetime


import fire
import torch
from datasets import load_dataset

from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json

from huggingface_hub import notebook_login
import numpy as np

from utils.eval_utils import cls_metrics
from utils.gen_utils import create_folder


with open('paths.json', 'r') as f:
    path = json.load(f)

# Module-level defaults for train()'s data, cache and output locations.
train_set_path = path["train_set_path"]
test_set_path = path["test_set_path"]
catche_path = path["catche_path"]  # spelling matches the key in paths.json
output_path = path["output_path"]

def train(
    base_model: str = "emilyalsentzer/Bio_ClinicalBERT",  # the only required argument
    train_data_path: str = train_set_path,
    val_data_path: str = test_set_path,
    cache_dir: str = catche_path,
    split: int = 100,
    micro_batch_size: int = 16,  # based on the previous recommended practice for classification-oriented fine-tuning of BERT (Devlin et al. 2018; Adhikari et al. 2019)
    num_epochs: int = 3,
    learning_rate: float = 2e-5,  # 3e-4 is the learning rate used in the LLaMA paper
    cutoff_len: int = 512,  # tokenizer model_max_length; inputs are truncated to this
    model_name: str = "bert",
    wandb_project: str = "classification",  # other options: "generative", "multilabel-classification"
    wandb_watch: str = "gradients",  # options: false | gradients | all ; "all" caused issues, so only gradients are logged
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # passed straight to Trainer.train
    train_shards: int = 1,  # >1: keep only the first of this many training shards (debug runs)
    test_shards: int = 1,  # >1: keep only the first of this many validation shards (debug runs)
):
    """Fine-tune an encoder (BERT-style) model for single-label sequence classification.

    The CSVs at ``train_data_path`` / ``val_data_path`` need a ``text`` column
    (tokenized) and a ``label`` column. The number of classes is the number of
    distinct ``label`` values in the training CSV. Checkpoints and the final
    model are written to ``{output_path}/{wandb_project}/<run name>``.

    Args:
        base_model: Hugging Face model id or local path.
        split: Percentage (0-100) of each CSV to load, via a
            ``train[:{split}%]`` slice.
        model_name: Tag used in the run name and banner.
        train_shards / test_shards: When > 1, only the first of that many
            shards is kept. Earlier versions of this script hard-coded 5000 /
            2000 here, which silently trained on a tiny fraction of the data.
            Pass ``--train_shards=5000 --test_shards=2000`` to reproduce them.
        resume_from_checkpoint: Forwarded to ``Trainer.train``.

    Raises:
        ValueError: if ``base_model`` is empty.
    """
    # Validate before creating any output folders.
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. --base_model='emilyalsentzer/Bio_ClinicalBERT'"
        )

    date_string = datetime.now().strftime("%B-%d-%H-%M")
    wandb_run_name = f"{model_name}-{cutoff_len}-{micro_batch_size}-{learning_rate}-{date_string}"
    output_dir = create_folder(f'{output_path}/{wandb_project}', wandb_run_name)

    # The number of classes is the number of distinct labels in the training CSV.
    num_labels = pd.read_csv(train_data_path).label.nunique()

    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training {model_name} sequence classifier with params:\n"
            f"base_model: {base_model}\n"
            f"train_data_path: {train_data_path}\n"
            f"val_data_path: {val_data_path}\n"
            f"output_dir: {output_dir}\n"
            f"cache_dir: {cache_dir}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"split: {split}\n"
            f"num_labels: {num_labels}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"train_shards: {train_shards}\n"
            f"test_shards: {test_shards}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = AutoModelForSequenceClassification.from_pretrained(
        base_model,
        num_labels=num_labels,
        cache_dir=cache_dir)

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        model_max_length=cutoff_len,
        cache_dir=cache_dir)

    def print_trainable_parameters(model):
        """Print trainable / total parameter counts and the trainable percentage."""
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
        )

    print_trainable_parameters(model)

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    train_data = load_dataset("csv", data_files=train_data_path, split=f'train[:{split}%]')
    test_data = load_dataset("csv", data_files=val_data_path, split=f'train[:{split}%]')

    # Optional sub-sampling for quick debug runs; off by default.
    if train_shards > 1:
        train_data = train_data.shard(num_shards=train_shards, index=0)
    if test_shards > 1:
        test_data = test_data.shard(num_shards=test_shards, index=0)

    tokenized_train = train_data.map(preprocess_function, batched=True)
    tokenized_test = test_data.map(preprocess_function, batched=True)

    # Pads each batch dynamically to its longest sequence.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        return cls_metrics(predictions, labels, class_num=num_labels)

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=micro_batch_size,
        per_device_eval_batch_size=micro_batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        push_to_hub=False,
        ddp_find_unused_parameters=False if ddp else None,
        # report_to=None means "all installed integrations"; "none" really disables reporting.
        report_to="wandb" if use_wandb else "none",
        run_name=wandb_run_name if use_wandb else None,
        )

    # NOTE: unlike the LLaMA scripts, this model is loaded without a device_map, so it is
    # deliberately NOT flagged as model-parallel. That flag makes Trainer skip moving the
    # model to the GPU and skip DataParallel, which would leave the model on CPU while
    # batches are sent to the GPU. Trainer's own multi-GPU handling is used instead.

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        )

    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)



if __name__ == "__main__":
    # Expose train()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=train)

# In [None]:
# Adapted framework from: https://github.com/tloen/alpaca-lora


import os
import pandas as pd
from datetime import datetime
import json
from typing import List

import fire
import torch
from datasets import load_dataset

from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from transformers import LlamaTokenizer, LlamaForSequenceClassification
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
    TaskType
)
import torch

from utils.eval_utils import cls_metrics
from utils.gen_utils import create_folder


with open('paths.json', 'r') as f:
    path = json.load(f)

# Module-level defaults for train()'s data, cache and output locations.
train_set_path = path["train_set_path"]
test_set_path = path["test_set_path"]
catche_path = path["catche_path"]  # spelling matches the key in paths.json
output_path = path["output_path"]

def train(
    # model/data params
    base_model: str = "decapoda-research/llama-7b-hf",  # the only required argument
    model_size: str = "7b",
    train_data_path: str = train_set_path,
    val_data_path: str = test_set_path,
    split: int = 100,
    cache_dir: str = catche_path,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,  # 3e-4 is the learning rate used in the LLaMA paper
    cutoff_len: int = 512,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = None,  # None -> ["q_proj", "k_proj", "v_proj", "o_proj", "score"]
    padding_side: str = "right",
    wandb_project: str = "classification",
    wandb_watch: str = "gradients",
    wandb_log_model: str = "",
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    """Fine-tune LLaMA (8-bit weights + LoRA adapters) for single-label sequence classification.

    The CSVs at ``train_data_path`` / ``val_data_path`` need a ``text`` column
    (tokenized) and a ``label`` column. The number of classes is the number of
    distinct labels in the training CSV. Outputs go to
    ``{output_path}/{wandb_project}/<run name>``.

    Args:
        base_model: Hugging Face model id or local path.
        model_size: Tag used only in the run name and banner.
        split: Percentage (0-100) of each CSV to load (``train[:{split}%]``).
        lora_target_modules: Module names wrapped with LoRA. ``None`` selects
            the q/k/v/o attention projections plus the ``score`` head.
        padding_side: ``"left"`` or ``"right"``. It is applied to the
            tokenizer, and through it to ``DataCollatorWithPadding``.
        resume_from_checkpoint: Directory containing either
            ``pytorch_model.bin`` (full checkpoint; Trainer state is resumed
            too) or ``adapter_model.bin`` (LoRA weights only; Trainer state is
            not resumed).

    Raises:
        ValueError: if ``base_model`` is empty or ``padding_side`` is invalid.
    """
    # Validate before creating any output folders.
    if not base_model:
        raise ValueError("Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'")
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    # A fresh list per call instead of a shared mutable default.
    if lora_target_modules is None:
        lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "score"]

    date_string = datetime.now().strftime("%B-%d-%H-%M")
    wandb_run_name = f"{model_size}-{cutoff_len}-{micro_batch_size}-{learning_rate}-{padding_side}-{date_string}"
    output_dir = create_folder(f'{output_path}/{wandb_project}', wandb_run_name)

    # The number of classes is the number of distinct labels in the training CSV.
    num_labels = pd.read_csv(train_data_path).label.nunique()

    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training LLaMA-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"model_size: {model_size}\n"
            f"train_data_path: {train_data_path}\n"
            f"val_data_path: {val_data_path}\n"
            f"output_dir: {output_dir}\n"
            f"cache_dir: {cache_dir}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"split: {split}\n"
            f"num_labels: {num_labels}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"padding_side: {padding_side}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )

    # Under DDP each process owns one whole copy of the model on its local GPU.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = LlamaForSequenceClassification.from_pretrained(
        base_model,
        num_labels=num_labels,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
        cache_dir=cache_dir)

    tokenizer = LlamaTokenizer.from_pretrained(
        base_model,
        model_max_length=cutoff_len,
        cache_dir=cache_dir)

    # This is to fix the bad token in "decapoda-research/llama-7b-hf"
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    # Previously this argument was only logged. Applying it makes DataCollatorWithPadding
    # pad on the requested side. NOTE(review): LlamaForSequenceClassification (in the
    # transformers versions this was written for) finds the last token by counting
    # non-pad ids, which presumes right padding.
    tokenizer.padding_side = padding_side

    model = prepare_model_for_int8_training(model)

    # note when passing task type as string argument, it will lead to error. May consider adding module_to_save manually
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.SEQ_CLS,
        modules_to_save=None,
    )
    model = get_peft_model(model, config)

    if resume_from_checkpoint:
        # Prefer a full checkpoint; otherwise fall back to adapter-only weights.
        checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
        if not os.path.exists(checkpoint_name):
            # Only LoRA weights: the LoRA config above has to match them.
            checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
            resume_from_checkpoint = False  # so the Trainer won't try loading its state
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # Load onto CPU so the checkpoint does not depend on the GPU layout it was saved
            # from; set_peft_model_state_dict copies the weights onto the live parameters.
            adapters_weights = torch.load(checkpoint_name, map_location="cpu")
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    train_data = load_dataset("csv", data_files=train_data_path, split=f'train[:{split}%]')
    test_data = load_dataset("csv", data_files=val_data_path, split=f'train[:{split}%]')

    tokenized_train = train_data.map(preprocess_function, batched=True).remove_columns(["text"]).rename_column("label", "labels")
    tokenized_test = test_data.map(preprocess_function, batched=True).remove_columns(["text"]).rename_column("label", "labels")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        return cls_metrics(predictions, labels, class_num=num_labels)

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=micro_batch_size,
        per_device_eval_batch_size=micro_batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        fp16=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        push_to_hub=False,
        remove_unused_columns=False,
        label_names=["labels"],
        ddp_find_unused_parameters=False if ddp else None,
        # report_to=None means "all installed integrations"; "none" really disables reporting.
        report_to="wandb" if use_wandb else "none",
        run_name=wandb_run_name if use_wandb else None,
        )

    if not ddp and torch.cuda.device_count() > 1:
        # device_map="auto" already spread the model across GPUs; this keeps Trainer
        # from adding its own DataParallel on top.
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        )

    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)


if __name__ == "__main__":
    # Expose train()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=train)

# In [None]:
import os
from typing import List
import pandas as pd
from datetime import datetime


import fire
import torch
from datasets import load_dataset, Dataset

from transformers import DataCollatorWithPadding, TrainingArguments, Trainer
from transformers import LlamaTokenizer, LlamaForSequenceClassification

from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from utils.eval_utils import cls_metrics_multi

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
    TaskType
)
import torch
import json
import numpy as np

from utils.gen_utils import create_folder

with open('paths.json', 'r') as f:
    path = json.load(f)

# Module-level defaults for the multilabel train()'s inputs and outputs.
multi_train_set_path = path["multi_train_set_path"]
multi_test_set_path = path["multi_test_set_path"]
catche_path = path["catche_path"]  # spelling matches the key in paths.json
output_path = path["output_path"]
drg_34_dissection_path = path["drg_34_dissection_path"]

def train(
    base_model: str = "decapoda-research/llama-7b-hf",  # the only required argument
    model_size: str = "7b",
    train_data_path: str = multi_train_set_path,
    val_data_path: str = multi_test_set_path,
    drg_mapping_path: str = drg_34_dissection_path,
    cache_dir: str = catche_path,
    split: int = 100,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,  # 3e-4 is the learning rate used in the LLaMA paper
    cutoff_len: int = 512,  # consider changing to 1024
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = None,  # None -> ["q_proj", "k_proj", "v_proj", "o_proj", "score"]
    padding_side: str = "right",
    wandb_project: str = "multilabel-classification",  # other options: "generative", "classification"
    wandb_watch: str = "gradients",  # options: false | gradients | all ; "all" caused issues, so only gradients are logged
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    """Fine-tune LLaMA (8-bit + LoRA) to jointly predict a principal-diagnosis group and a CC/MCC group.

    Each target row is one vector laid out as
    ``[principal one-hot (num_labels_pc) | CC/MCC one-hot (num_labels_cc)]``.
    The two widths are the numbers of distinct ``principal_diagnosis_lable``
    and ``CC/MCC`` values in the DRG mapping CSV. The loss is
    ``CE(principal) + 0.5 * CE(CC/MCC)``, taken over the two logit slices.

    Args:
        base_model: Hugging Face model id or local path.
        model_size: Tag used only in the run name and banner.
        train_data_path / val_data_path: CSVs with ``text`` and ``label``
            columns. ``label`` is stored as a bracketed, space-separated float
            vector.
        drg_mapping_path: CSV with ``principal_diagnosis_lable`` and
            ``CC/MCC`` columns.
        split: Percentage (0-100) of rows to keep from the start of each CSV.
            It was previously accepted but ignored.
        lora_target_modules: Module names wrapped with LoRA. ``None`` selects
            the q/k/v/o attention projections plus the ``score`` head.
        padding_side: ``"left"`` or ``"right"``, applied to the tokenizer.
            The custom forward below assumes right padding.
        resume_from_checkpoint: Directory containing either
            ``pytorch_model.bin`` (Trainer state resumed too) or
            ``adapter_model.bin`` (LoRA weights only).

    Raises:
        ValueError: if ``base_model`` is empty or ``padding_side`` is invalid.
    """
    # Validate before creating any output folders.
    if not base_model:
        raise ValueError("Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'")
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    # A fresh list per call instead of a shared mutable default.
    if lora_target_modules is None:
        lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "score"]

    date_string = datetime.now().strftime("%B-%d-%H-%M")
    wandb_run_name = f"{model_size}-{cutoff_len}-{micro_batch_size}-{learning_rate}-{padding_side}-{date_string}"
    output_dir = create_folder(f'{output_path}/{wandb_project}', wandb_run_name)

    # Widths of the two halves of the label vector (read the mapping once).
    drg_mapping = pd.read_csv(drg_mapping_path)
    num_labels_pc = drg_mapping.principal_diagnosis_lable.nunique()
    num_labels_cc = drg_mapping["CC/MCC"].nunique()
    num_labels = num_labels_pc + num_labels_cc

    def parse_label(cell):
        # Strip the surrounding brackets and parse the space-separated floats,
        # presumably a saved numpy array such as "[0. 1. 0.]".
        return np.fromstring(cell[1:-1], dtype=float, sep=" ")

    def take_split(df):
        # Keep the first `split` percent of rows (rounded to nearest), mirroring the
        # `train[:{split}%]` slices used by the single-label scripts.
        return df.iloc[: int(round(len(df) * split / 100))]

    train_data = take_split(pd.read_csv(train_data_path, converters={"label": parse_label}))
    test_data = take_split(pd.read_csv(val_data_path, converters={"label": parse_label}))

    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training LLaMA-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"model_size: {model_size}\n"
            f"train_data_path: {train_data_path}\n"
            f"val_data_path: {val_data_path}\n"
            f"output_dir: {output_dir}\n"
            f"cache_dir: {cache_dir}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"split: {split}\n"
            f"num_labels_pc: {num_labels_pc}\n"
            f"num_labels_cc: {num_labels_cc}\n"
            f"num_labels: {num_labels}\n"
            f"num_epochs: {num_epochs}\n"
            f"num_train_data: {len(train_data)}\n"
            f"num_test_data: {len(test_data)}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"padding_side: {padding_side}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        )

    # Under DDP each process owns one whole copy of the model on its local GPU.
    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    class LlamaForMultilabelSequenceClassification(LlamaForSequenceClassification):
        """LLaMA classifier whose loss is CE over the principal slice + 0.5 * CE over the CC/MCC slice.

        The slices are split at ``num_labels_pc`` (captured from the enclosing
        scope). One-hot label halves are converted to class ids with argmax.
        """

        def __init__(self, config):
            super().__init__(config)

        def forward(self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
            labels=None,
            use_cache = None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None):
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]
            logits = self.score(hidden_states)

            if input_ids is not None:
                batch_size = input_ids.shape[0]
            else:
                batch_size = inputs_embeds.shape[0]

            if self.config.pad_token_id is None and batch_size != 1:
                raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
            if self.config.pad_token_id is None:
                sequence_lengths = -1
            else:
                if input_ids is not None:
                    # Index of the last non-pad token; only correct for right padding.
                    sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
                else:
                    sequence_lengths = -1

            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

            loss = None
            if labels is not None:
                labels = labels.to(logits.device)
                loss_fct_pc = CrossEntropyLoss()
                loss_fct_cc = CrossEntropyLoss()

                logits_pc = pooled_logits[:, :num_labels_pc]
                labels_onehot_pc = labels[:, :num_labels_pc]
                labels_pc = torch.argmax(labels_onehot_pc, axis=1)

                logits_cc = pooled_logits[:, num_labels_pc:]
                labels_onehot_cc = labels[:, num_labels_pc:]
                labels_cc = torch.argmax(labels_onehot_cc, axis=1)

                loss_pc = loss_fct_pc(logits_pc, labels_pc)
                loss_cc = loss_fct_cc(logits_cc, labels_cc)
                loss = loss_pc + 0.5*loss_cc
            if not return_dict:
                output = (pooled_logits,) + transformer_outputs[1:]
                return ((loss,) + output) if loss is not None else output

            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

    model = LlamaForMultilabelSequenceClassification.from_pretrained(
        base_model,
        num_labels=num_labels,
        load_in_8bit=True,
        problem_type="multi_label_classification",
        torch_dtype=torch.float16,
        device_map=device_map,
        cache_dir=cache_dir)

    tokenizer = LlamaTokenizer.from_pretrained(
        base_model,
        model_max_length=cutoff_len,
        cache_dir=cache_dir)

    # This is to fix the bad token in "decapoda-research/llama-7b-hf"
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    # Previously this argument was only logged. Applying it makes DataCollatorWithPadding
    # pad on the requested side.
    tokenizer.padding_side = padding_side

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.SEQ_CLS,
        modules_to_save=None
    )
    model = get_peft_model(model, config)

    if resume_from_checkpoint:
        # Prefer a full checkpoint; otherwise fall back to adapter-only weights.
        checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
        if not os.path.exists(checkpoint_name):
            # Only LoRA weights: the LoRA config above has to match them.
            checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
            resume_from_checkpoint = False  # so the Trainer won't try loading its state
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            # Load onto CPU so the checkpoint does not depend on the GPU layout it was saved from.
            adapters_weights = torch.load(checkpoint_name, map_location="cpu")
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    train_data = Dataset.from_pandas(train_data)
    test_data = Dataset.from_pandas(test_data)

    tokenized_train = train_data.map(preprocess_function, batched=True).remove_columns(["text"]).rename_column("label", "labels")
    tokenized_test = test_data.map(preprocess_function, batched=True).remove_columns(["text"]).rename_column("label", "labels")

    # default is padding to longest
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    def compute_metrics_multi(eval_pred):
        predictions, labels = eval_pred
        return cls_metrics_multi(y_pred=predictions, y=labels)

    # Other hyperparameters to consider here are gradient_accumulation_steps, weight decay, learning rate, adam type
    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=micro_batch_size,
        per_device_eval_batch_size=micro_batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        fp16=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=3,
        load_best_model_at_end=True,
        push_to_hub=False,
        remove_unused_columns=False,
        label_names=["labels"],
        ddp_find_unused_parameters=False if ddp else None,
        # report_to=None means "all installed integrations"; "none" really disables reporting.
        report_to="wandb" if use_wandb else "none",
        run_name=wandb_run_name if use_wandb else None,
        )

    if not ddp and torch.cuda.device_count() > 1:
        # device_map="auto" already spread the model across GPUs; this keeps Trainer
        # from adding its own DataParallel on top.
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics_multi,
        )

    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

if __name__ == "__main__":
    # Expose train()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=train)

# In [None]:
# Adapted from https://github.com/JHLiu7/EarlyDRGPrediction


import math
import pandas as pd
import numpy as np
import pickle as pk
import json


from sklearn.metrics import auc, roc_curve, precision_recall_curve, f1_score, accuracy_score
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats import pearsonr, spearmanr

with open('paths.json', 'r') as f:
    path = json.load(f)
drg_34_dissection_path = path["drg_34_dissection_path"]

# Only the principal-diagnosis label and CC/MCC columns of the DRG mapping are needed.
pc_cc_mapping = pd.read_csv(drg_34_dissection_path)[["principal_diagnosis_lable", "CC/MCC"]]
num_labels_pc = pc_cc_mapping.principal_diagnosis_lable.nunique()

# principal_diagnosis_lable -> list of its CC/MCC values (one entry per CSV row, in file
# order, duplicates kept). A label can map to several CC/MCC values.
pc_cc_dict = {}
for _, mapping_row in pc_cc_mapping.iterrows():
    pc_cc_dict.setdefault(mapping_row["principal_diagnosis_lable"], []).append(mapping_row["CC/MCC"])



def map_rule(pred_pc, label_pc, pred_cc, label_cc, mapping=None):
    """Decide whether a (principal, CC/MCC) prediction maps to the correct DRG.

    A prediction is correct only if the principal-diagnosis label matches.
    Given that, it is correct when:
      * the CC/MCC prediction matches exactly; or
      * the principal label has only one CC/MCC entry, so any CC/MCC
        prediction maps to the same DRG.

    It is wrong when the predicted CC/MCC is another value listed for this
    principal label. Otherwise the prediction falls back to a default group,
    depending on the label's set of CC/MCC values:
      * {0, 1, 2} -> 0
      * {2, 3} -> 3 ("without MCC", per the original comments)
      * {0, 1} -> 1 if the predicted CC/MCC is 2 ("with MCC"), else 0
    The prediction is correct iff that default equals ``label_cc``. Any other
    set of CC/MCC values gives False.

    Args:
        pred_pc, label_pc: Predicted and true principal-diagnosis labels.
        pred_cc, label_cc: Predicted and true CC/MCC values.
        mapping: principal label -> list of CC/MCC values. Defaults to the
            module-level ``pc_cc_dict``.

    Returns:
        bool

    Raises:
        KeyError: if the principal labels match but ``label_pc`` is not in
            ``mapping``.
    """
    if pred_pc != label_pc:
        return False
    if pred_cc == label_cc:
        return True

    if mapping is None:
        mapping = pc_cc_dict
    allowed_cc = mapping[label_pc]

    # Only one CC/MCC group for this principal label: every CC/MCC prediction is fine.
    if len(allowed_cc) == 1:
        return True
    # Predicted another valid group of this principal label -> a different DRG.
    if pred_cc in allowed_cc:
        return False

    cc_groups = set(allowed_cc)
    if cc_groups == {0, 1, 2}:
        default_cc = 0
    elif cc_groups == {2, 3}:
        default_cc = 3
    elif cc_groups == {0, 1}:
        default_cc = 1 if pred_cc == 2 else 0
    else:
        return False
    return bool(default_cc == label_cc)


def accuracies_map(y_pred_pc, labels_pc, y_pred_cc, labels_cc, rule=None):
    """Fraction of samples whose (principal, CC/MCC) prediction passes ``rule``.

    Args:
        y_pred_pc, labels_pc: Predicted / true principal-diagnosis labels.
        y_pred_cc, labels_cc: Predicted / true CC/MCC values.
        rule: Callable ``(pred_pc, label_pc, pred_cc, label_cc) -> bool``.
            Defaults to ``map_rule``.

    Returns:
        Accuracy as a float. Returns 0.0 for empty input instead of raising
        ZeroDivisionError.

    Raises:
        ValueError: if the four sequences differ in length.
    """
    num = len(y_pred_pc)
    if not (len(labels_pc) == len(y_pred_cc) == len(labels_cc) == num):
        raise ValueError("accuracies_map: all four inputs must have the same length")
    if num == 0:
        return 0.0
    if rule is None:
        rule = map_rule

    hits = sum(1 for args in zip(y_pred_pc, labels_pc, y_pred_cc, labels_cc) if rule(*args))
    return hits / num


# full evaluation
def full_metrics(y_pred, y, drg_rule, d2i):
    """Combine regression metrics on DRG weights with classification metrics.

    ``map2weight`` converts predictions and labels to weight values using
    ``drg_rule`` and ``d2i``; ``reg_metrics`` scores those. ``cls_metrics``
    scores the raw predictions with ``len(d2i)`` classes. If a key appears in
    both results, the classification value wins.
    """
    weighted_pred, weighted_true = map2weight(y_pred, y, drg_rule=drg_rule, d2i=d2i)
    regression_scores = reg_metrics(weighted_pred, weighted_true)
    classification_scores = cls_metrics(y_pred, y, len(d2i))
    return {**regression_scores, **classification_scores}

def cls_metrics(y_pred, y, class_num):
    """Single-label classification metrics.

    AUCs use softmax probabilities against a ``class_num``-wide one-hot
    encoding of ``y``. F1 and top-k accuracies use the raw ``y_pred`` and
    ``y``. NOTE(review): ``y_pred`` is presumably (n, class_num) logits and
    ``y`` integer class ids — confirm against the helper functions.

    Returns:
        dict with keys microF1, macroF1, microAUC, macroAUC, labels, count,
        acc10, acc5, acc.
    """
    probabilities = softmax(y_pred)
    one_hot_truth = onehot_encode(y, class_num)

    macro_auc, micro_auc, appeared, cases = ave_auc_scores(probabilities, one_hot_truth)
    macro_f1, micro_f1 = ave_f1_scores(y_pred, y)
    top10_acc, top5_acc, top1_acc, _ = accuracies(y_pred, y)

    return {
        'microF1': micro_f1, 'macroF1': macro_f1,
        'microAUC': micro_auc, 'macroAUC': macro_auc,
        'labels': appeared, 'count': cases,
        'acc10': top10_acc, 'acc5': top5_acc, 'acc': top1_acc,
    }

def cls_metrics_eval(y_pred, y, class_num):
    """Return cls_metrics plus raw inputs and predictions, for offline evaluation.

    Adds these entries to the cls_metrics dict:
        'y_label':    the gold labels ``y`` as passed in;
        'y_raw':      the raw scores ``y_pred``;
        'y_raw_top5': per-row indices of the 5 highest scores, highest first;
        'y_pred':     per-row argmax of the softmax probabilities.
    """
    # Delegate instead of duplicating cls_metrics' body, so the training-time
    # and evaluation-time metrics cannot drift apart.
    metric_dict = cls_metrics(y_pred, y, class_num)

    metric_dict['y_label'] = y
    metric_dict['y_raw'] = y_pred
    # Argsort of the negated scores gives descending order:
    # https://stackoverflow.com/questions/16486252/is-it-possible-to-use-argsort-in-descending-order
    metric_dict['y_raw_top5'] = (-y_pred).argsort(axis=1)[:, :5]
    metric_dict['y_pred'] = np.argmax(softmax(y_pred), axis=1)

    return metric_dict

def cls_metrics_multi(y_pred, y):
    """Score a two-head (PC + CC) classifier whose outputs are concatenated.

    Both ``y_pred`` (scores/logits) and ``y`` (one-hot labels) are split
    column-wise at the module-level ``num_labels_pc``. Columns before it belong
    to the PC head and the remaining columns to the CC head.

    For each head the dict holds micro/macro F1 and AUC, the class and sample
    counts reported by ave_auc_scores, and top-10/5/1 accuracy. These keys carry
    a ``_pc`` / ``_cc`` suffix. 'acc_map' is the fraction of samples whose
    argmax PC and CC predictions are accepted by map_rule.
    """
    # NOTE(review): the original author flagged this as still needing a check
    # that it mirrors the single-head cls_metrics computation.
    def score_head(logits, onehot):
        # Metrics for one head, in the key order of the returned dict,
        # plus its argmax predictions and gold indices.
        probs = softmax(logits)
        gold = np.argmax(onehot, axis=1)
        auc_macro, auc_micro, n_labels, n_cases = ave_auc_scores(probs, onehot)
        f1_macro, f1_micro = ave_f1_scores(logits, gold)
        top10, top5, top1, _ = accuracies(logits, gold)
        stats = {
            'microF1': f1_micro, 'macroF1': f1_macro,
            'microAUC': auc_micro, 'macroAUC': auc_macro,
            'labels': n_labels, 'count': n_cases,
            'acc10': top10, 'acc5': top5, 'acc': top1,
        }
        return stats, np.argmax(probs, axis=1), gold

    split = num_labels_pc
    pc_stats, pc_pred, pc_gold = score_head(y_pred[:, :split], y[:, :split])
    cc_stats, cc_pred, cc_gold = score_head(y_pred[:, split:], y[:, split:])

    metric_dict = {'acc_map': accuracies_map(pc_pred, pc_gold, cc_pred, cc_gold)}
    for suffix, stats in (('_pc', pc_stats), ('_cc', cc_stats)):
        for name, value in stats.items():
            metric_dict[name + suffix] = value

    return metric_dict

def reg_metrics(y_pred, y):
    """Compute error metrics between predicted and true values.

    Returns a dict with:
        'MAE', 'MSE', 'RMSE':   absolute / squared error and its square root;
        'spearman', 'corr_p':   Spearman rank correlation and its p-value;
        'CMI_error':            mean signed error divided by mean(y);
        'CMI_raw':              mean signed error, mean(y_pred - y);
        'overshot', 'undershot': number of entries predicted above / below y.

    Expects array inputs supporting elementwise ``y_pred - y`` and boolean
    masking (map2weight, the caller in full_metrics, returns numpy arrays).
    """
    mae = mean_absolute_error(y_pred, y)
    mse = mean_squared_error(y_pred, y)
    rho, p_value = spearmanr(y_pred, y)

    residual = y_pred - y
    bias = np.mean(residual)

    return {
        'MAE': mae, 'MSE': mse, 'RMSE': math.sqrt(mse),
        'spearman': rho, 'corr_p': p_value,
        'CMI_error': bias / np.mean(y), 'CMI_raw': bias,
        'overshot': int(np.count_nonzero(residual > 0)),
        'undershot': int(np.count_nonzero(residual < 0)),
    }

def cls_metrics_multi_eval(y_pred, y):
    """Return cls_metrics_multi plus the raw inputs, for offline evaluation.

    The result is the cls_metrics_multi dict with two extra entries:
    'y_label' (the one-hot gold matrix ``y`` as passed in) and 'y_raw'
    (the raw scores ``y_pred``).
    """
    # Delegate instead of duplicating cls_metrics_multi's body line for line,
    # so the training-time and evaluation-time metrics cannot drift apart.
    metric_dict = cls_metrics_multi(y_pred, y)

    metric_dict['y_label'] = y
    metric_dict['y_raw'] = y_pred

    return metric_dict


# to print out results
def result2str(d):
    """Format a metrics dict (as produced by full_metrics) as a text report.

    The regression section is always emitted. It requires the keys 'MAE',
    'RMSE', 'spearman', 'CMI_error', 'overshot' and 'undershot'.

    The classification section (case/label counts, macro/micro AUC and F1,
    top-10/5/1 accuracy) is emitted only when all of its keys are present and
    their values can be formatted as numbers. Otherwise it is omitted.

    Raises:
        KeyError: if a regression key is missing.
    """
    cls_keys = ('microF1', 'macroF1', 'microAUC', 'macroAUC',
                'acc10', 'acc5', 'acc', 'labels', 'count')

    # Regression values are mandatory: a missing key raises KeyError.
    ma, rm = d['MAE'], d['RMSE']
    sp = d['spearman']
    cm, ov, ud = d['CMI_error'], d['overshot'], d['undershot']

    title = "****" * 5 + '\n'
    # Explicit presence check replaces the former bare `except: pass`, which
    # also swallowed NameErrors from partially-assigned locals.
    if all(key in d for key in cls_keys):
        try:
            s1 = "{} cases, {} labels".format(d['count'], d['labels'])
            s2 = "MACRO-AUC     \tMICRO-AUC      \tMACRO-F1     \tMICRO-F1  "
            s3 = "{:.4f}  \t{:.4f}  \t{:.4f}  \t{:.4f}".format(
                d['macroAUC'], d['microAUC'], d['macroF1'], d['microF1'])
            s4 = "Acc10  \tAcc5  \tAcc "
            s5 = "{:.4f}  \t{:.4f}  \t{:.4f}  \n".format(d['acc10'], d['acc5'], d['acc'])
        except (TypeError, ValueError):
            # Non-numeric classification values: keep the report best-effort
            # and omit the section rather than fail.
            pass
        else:
            title = title + '\n'.join([s1, s2, s3, s4, s5])

    r1 = "MAE: {:.4f}  RMSE: {:.4f}  Corr: {:.4f}  \n".format(ma, rm, sp)
    r2 = "CMI_error: {:.2%}  overshot: {}  undershot: {}  \n\n".format(cm, ov, ud)

    title = title + '\n'.join([r1, r2])

    return title


# running evaluation
def score_f1(y_pred, y):
    """Return the micro-averaged F1 of each row's argmax class against ``y``.

    y_pred: logits, one row per sample.
    y: gold class labels.
    """
    return f1_score(y, np.argmax(y_pred, axis=1), average='micro')

def score_mae(y_pred, y):
    """Return the mean absolute error between ``y_pred`` and ``y``."""
    error = mean_absolute_error(y_pred, y)
    return error


# utils
def map2weight(y_pred, y, drg_rule, d2i):
    """Convert predicted and gold class indices into DRG weights.

    Args:
        y_pred: 2-D scores, one row per sample. Each row's argmax is the
            predicted class index.
        y: gold class index per sample.
        drg_rule: DataFrame with 'DRG_CODE' and 'WEIGHT' columns. If a code
            appears more than once, its last row wins.
        d2i: mapping of DRG code -> class index.

    Returns:
        (pred_weights, gold_weights) as numpy arrays.

    Raises:
        KeyError: if an index has no code in ``d2i`` or a code has no weight
            in ``drg_rule``.
    """
    idx2drg = {idx: code for code, idx in d2i.items()}
    # Build the lookup column-wise. iterrows() is slow and packs each row into
    # a Series, which upcasts mixed int/float rows: integer codes become
    # float64 keys, and large codes lose precision and no longer match d2i.
    drg2weight = dict(zip(drg_rule['DRG_CODE'], drg_rule['WEIGHT']))

    def to_weights(indices):
        # Class index -> DRG code -> weight.
        return np.array([drg2weight[idx2drg[i]] for i in indices])

    return to_weights(np.argmax(y_pred, axis=1)), to_weights(y)

def softmax(x):
    """Apply softmax to ``x`` row-wise, normalising along axis 1.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but stops ``np.exp`` overflowing to inf
    (which made the output NaN) for large logits such as 1000.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=1, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=1, keepdims=True)

def onehot_encode(y, class_num):
    """One-hot encode a flat sequence of integer labels.

    Each label becomes a float row of length ``class_num`` with a 1 at the
    label's index. An empty ``y`` yields an empty 1-D array.
    """
    def as_row(label):
        vec = np.zeros(class_num)
        vec[label] = 1
        return vec

    return np.array([as_row(label) for label in y])

def accuracies(y_pred, y, onlyAcc=False):
    """Compute top-10, top-5 and top-1 accuracy of scores against gold labels.

    Args:
        y_pred: per-sample score rows (e.g. logits). ``y_pred[i]`` must be a
            numpy array (it is argsorted).
        y: gold class index per sample. ``len(y)`` is the number of samples scored.
        onlyAcc: if True, return only the top-1 accuracy.

    Returns:
        Top-1 accuracy if ``onlyAcc``, else ``(acc10, acc5, acc1, num_samples)``.

    Raises:
        ZeroDivisionError: if ``y`` is empty.
    """
    hits10 = hits5 = hits1 = 0
    num = len(y)

    for i, label in enumerate(y):
        # Sort once per row: the top-10/5/1 sets are all suffixes of the same
        # ascending order (previously argsort ran three times per row).
        order = y_pred[i].argsort()
        if label in set(order[-10:]):
            hits10 += 1
        if label in set(order[-5:]):
            hits5 += 1
        if label in set(order[-1:]):
            hits1 += 1

    acc10 = hits10 / num
    acc5 = hits5 / num
    acc1 = hits1 / num

    if onlyAcc:
        return acc1
    return acc10, acc5, acc1, num

def ave_auc_scores(y_pred, y):
    """Compute macro- and micro-averaged ROC AUC over classes.

    Args:
        y_pred: numpy array [samples, classes] of per-class scores/probabilities.
        y: numpy array [samples, classes] of one-hot (0/1) gold labels.

    Returns:
        (auc_roc_macro, auc_roc_micro, n_labels, n_samples):
        - auc_roc_macro: mean of the per-class AUCs. Only classes that occur in
          ``y`` and produce a ROC curve with more than one point contribute.
          The value is NaN, with a RuntimeWarning, when no class contributes,
          since np.mean of an empty list is NaN.
        - auc_roc_micro: AUC over all flattened (sample, class) pairs.
        - n_labels: number of classes that contributed to the macro average.
        - n_samples: ``len(y)``.
    """
    per_class_auc = {}
    for cls_idx in range(y.shape[1]):
        # Only score classes that appear at least once in the gold labels.
        if y[:, cls_idx].sum() > 0:
            fpr, tpr, _ = roc_curve(y[:, cls_idx], y_pred[:, cls_idx])
            # A curve needs at least two points to be integrated.
            if len(fpr) > 1 and len(tpr) > 1:
                per_class_auc[cls_idx] = auc(fpr, tpr)

    fpr_micro, tpr_micro, _ = roc_curve(y.ravel(), y_pred.ravel())

    auc_roc_macro = np.mean(list(per_class_auc.values()))
    auc_roc_micro = auc(fpr_micro, tpr_micro)
    return auc_roc_macro, auc_roc_micro, len(per_class_auc), len(y)

def ave_f1_scores(y_pred, y):
    """Compute macro- and micro-averaged F1 of each row's argmax against ``y``.

    Args:
        y_pred: 2-D scores/logits, one row per sample. The argmax over axis 1
            is the predicted class.
        y: flat sequence of gold class labels.

    Returns:
        (f1_macro, f1_micro). Both averages are computed only over the labels
        that occur in ``y`` (``labels=np.unique(y)``).
    """
    y_flat = np.argmax(y_pred, axis=1)
    # Computed once and shared by both averages.
    present = np.unique(y)

    f1_macro = f1_score(y, y_flat, average='macro', labels=present)
    f1_micro = f1_score(y, y_flat, average='micro', labels=present)

    return f1_macro, f1_micro


In [None]:
import os

def create_folder(parent_path, folder):
    """Create ``parent_path/folder`` (plus any missing parents) and return its path.

    The returned path is the plain string join ``parent_path + '/' + folder``.
    No separator is inserted when ``parent_path`` already ends with '/' or is
    empty; an empty parent therefore means a path relative to the current
    directory, not the filesystem root.

    Raises:
        FileExistsError: if the path already exists as a non-directory.
    """
    if parent_path and not parent_path.endswith('/'):
        parent_path += '/'
    folder_path = parent_path + folder
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(folder_path, exist_ok=True)
    return folder_path