## Function

In [1]:
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


In [3]:
import datasets
from transformers import TrainerCallback
from contextlib import nullcontext
import pathlib
import numpy as np
import pandas as pd
from sklearn.metrics import average_precision_score
from torch import nn
from transformers import default_data_collator, Trainer, TrainingArguments

import itertools
from tqdm.auto import tqdm

import torch


sys.path.append("../src")

from utils import number_split, create_mix
from data_process import load_wls_adress_AddDomain
from process_SHAC import load_process_SHAC

In [37]:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_int8_training,
)


In [70]:

##### Dataset Loader and Tokenizer
def preprocess_function(examples):
    # tokenize
    ret = tokenizer(examples['text'], return_tensors='pt', max_length=globalconfig.max_seq_length, padding='max_length', truncation=True).to(globalconfig.device)

    return  ret

def datasets_loader(df):
    # from pandas df to Dataset & tokenize
    ret_datasets = datasets.Dataset.from_pandas(df[['text','dfSource','label_binary']].rename(columns={"label_binary":"label"}).reset_index(drop=True))
    ret_tokenized = ret_datasets.map(preprocess_function, batched=True)

    return ret_tokenized

def create_peft_config(model):

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,
        bias="none",
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules = ["query", "value"],
        modules_to_save=["classifier"],
    )

    # prepare int-8 model for training
    if globalconfig.quantization:
        model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

## Define metric
def compute_metrics_twoLevels(eval_pred):
    # compute AUPRC, based on only two levels of Y
    predictions, labels = eval_pred
    probabilities = nn.functional.softmax(torch.FloatTensor(predictions), dim=-1)[:,1]

    auprc = average_precision_score(y_true=labels, y_score=probabilities)

    return {"auprc":auprc}


# Load Data

In [4]:
df_shac = load_process_SHAC(replaceNA="all")

In [7]:
df_shac_uw = df_shac.query("location == 'uw'").reset_index(drop=True)
df_shac_mimic = df_shac.query("location == 'mimic'").reset_index(drop=True)


In [56]:
z_Categories = ["uw", "mimic"]  # the order here matters! Should match with df0, df1
label='Drug'
n_zCats = len(z_Categories)
txt_col="text"
domain_col = "location"
df0 = df_shac_uw
df1 = df_shac_mimic

y_cat = [False, True]

In [35]:

label2id = {y:idx for idx,y in zip(range(len(y_cat)), y_cat)}
id2label = {idx:y for idx,y in zip(range(len(y_cat)), y_cat)}


In [36]:
label2id

{False: 0, True: 1}

# Split

In [None]:
##### Split
# SHAC-Drug - Balanced Alpha
n_test = 200
train_test_ratio = 4


p_pos_train_z0_ls = np.arange(0, 1, 0.1) # probability of training set examples drawn from site/domain z0 being positive
p_pos_train_z1_ls = np.arange(0, 1, 0.1) # probability of test set examples drawn from site/domain z1 being positive

p_mix_z1_ls     = np.arange(0, 1, 0.05) 

numvals = 1023
base = 1.1


alpha_test_ls = np.power(base, np.arange(numvals))/np.power(base,numvals//2)

valid_full_settings = []
for combination in itertools.product(p_pos_train_z0_ls, 
                                     p_pos_train_z1_ls, 
                                     p_mix_z1_ls,
                                     alpha_test_ls
                                    ):
    

    number_setting = number_split(p_pos_train_z0=combination[0], 
                           p_pos_train_z1 = combination[1], 
                           p_mix_z1 = combination[2], alpha_test = combination[3],
                           train_test_ratio = train_test_ratio, 
                           n_test=n_test,
                                  verbose=False
                                 )

    if (number_setting is not None):
        if np.all([number_setting[k] >= 10 for k in list(number_setting.keys())[:-1]]):
            valid_full_settings.append(number_setting)
    
    
    
    
# run for check valid settings

import warnings; warnings.simplefilter('ignore')

# Validate settings

df0 = df_shac_uw
df1 = df_shac_mimic


valid_n_full_settings = []

for c in tqdm(valid_full_settings):
    c = c.copy()
    # create train/test split according to stats
    dfs = create_mix(df0=df0, df1=df1, target=label, setting=c, sample=False, 
                     seed=222
                    )

    if dfs is None:
        continue
    
    valid_n_full_settings.append(c)


In [89]:
##### Split
# SHAC-Drug - Balanced Alpha
## Only selecting C_y in [0.2, 0.48, 0.72]b
n_test = 200
train_test_ratio = 4


p_pos_train_z0_ls = np.arange(0, 1, 0.1) # probability of training set examples drawn from site/domain z0 being positive
p_pos_train_z1_ls = np.arange(0, 1, 0.1) # probability of test set examples drawn from site/domain z1 being positive

p_mix_z1_ls     = np.arange(0, 1, 0.05) 

# numvals = 1023
# base = 1.1
numvals = 129
base = 1.01

alpha_test_ls = np.power(base, np.arange(numvals))/np.power(base,numvals//2)

valid_full_settings = []
for combination in itertools.product(p_pos_train_z0_ls, 
                                     p_pos_train_z1_ls, 
                                     p_mix_z1_ls,
                                     alpha_test_ls
                                    ):
    

    number_setting = number_split(p_pos_train_z0=combination[0], 
                           p_pos_train_z1 = combination[1], 
                           p_mix_z1 = combination[2], alpha_test = combination[3],
                           train_test_ratio = train_test_ratio, 
                           n_test=n_test,
                                  verbose=False
                                 )

    if (number_setting is not None) and (number_setting['mix_param_dict']['C_y'] in [0.2, 0.48, 0.72]):
        if np.all([number_setting[k] >= 10 for k in list(number_setting.keys())[:-1]]):
            valid_full_settings.append(number_setting)
    
    
    
    
# run for check valid settings

import warnings; warnings.simplefilter('ignore')

# Validate settings

df0 = df_shac_uw
df1 = df_shac_mimic


valid_n_full_settings = []

for c in tqdm(valid_full_settings):
    c = c.copy()
    # create train/test split according to stats
    dfs = create_mix(df0=df0, df1=df1, target=label, setting=c, sample=False, 
                     seed=222
                    )

    if dfs is None:
        continue
    
    valid_n_full_settings.append(c)


  0%|          | 0/2466 [00:00<?, ?it/s]

In [93]:
len(valid_n_full_settings)

2074

In [90]:
tmp_df = pd.DataFrame([st['mix_param_dict'] for st in valid_n_full_settings])

In [91]:
tmp_df['C_y'].unique()

array([0.2 , 0.72, 0.48])

In [92]:
tmp_df['alpha_train'].unique()

array([3.        , 5.        , 1.        , 0.33333333, 0.2       ,
       0.16666667, 0.75      , 0.22222222, 0.33333333, 0.66666667])

# Model

## Config

In [39]:
class train_config:
    def __init__(self):
        self.quantization: bool = False

    

In [40]:
globalconfig = train_config()

In [52]:
globalconfig.quantization = False

In [42]:
globalconfig.device = "cuda:0"

In [43]:
globalconfig.profiler = False

In [65]:
globalconfig.output_dir = "/bime-munin/xiruod/LoRA_BERT"

In [45]:
globalconfig.model_id="bert-base-uncased"

In [46]:
globalconfig.max_seq_length=512

In [47]:
globalconfig.num_train_epochs=3

In [48]:
globalconfig.lr = 1e-4
globalconfig.warmup_ratio = 0.1

In [49]:
globalconfig.runs=1

## Load and Lora Init

In [50]:
tokenizer = AutoTokenizer.from_pretrained(globalconfig.model_id)

In [51]:
model = AutoModelForSequenceClassification.from_pretrained(globalconfig.model_id)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
## Peft Config


model, lora_config = create_peft_config(model)

trainable params: 297,988 || all params: 109,780,228 || trainable%: 0.27144050019644705


In [None]:
## Profiler

output_dir = globalconfig.output_dir

config = {
    'lora_config': lora_config,
    'learning_rate': globalconfig.lr,
    'num_train_epochs': globalconfig.num_train_epochs,
    'gradient_accumulation_steps': 2,
    'per_device_train_batch_size': 8,
    'per_device_eval_batch_size': 8,
    'gradient_checkpointing': False,
    'warmup_ratio':globalconfig.warmup_ratio,
}

# Set up profiler

profiler = nullcontext()

## Run - for One Experiment

In [72]:
len(valid_n_full_settings)

21927

In [67]:

##### Experiment - ONLY One Setting
pick_C = 219

c = valid_n_full_settings[pick_C]
print("Balanced? Check setting....")
print(c)
dfs = create_mix(df0=df0, df1=df1, target=label, setting=c, sample=False, 
                 # seed=random.randint(0,1000),
                 seed=222
                )

tokenized_train = datasets_loader(dfs['train'])
tokenized_test = datasets_loader(dfs['test'])


Balanced? Check setting....
{'n_train': 800, 'n_test': 200, 'n_z0_pos_train': 20, 'n_z0_neg_train': 180, 'n_z0_pos_test': 11, 'n_z0_neg_test': 39, 'n_z1_pos_train': 120, 'n_z1_neg_train': 480, 'n_z1_pos_test': 24, 'n_z1_neg_test': 126, 'mix_param_dict': {'p_pos_train_z0': 0.1, 'p_pos_train_z1': 0.2, 'p_pos_train': 0.17500000000000002, 'p_pos_test': 0.17500000000000002, 'p_mix_z0': 0.25, 'p_mix_z1': 0.75, 'alpha_train': 2.0, 'alpha_test': 0.7513148009015775, 'p_pos_test_z0': 0.21512352805356738, 'p_pos_test_z1': 0.1616254906488109, 'C_y': 0.17500000000000002, 'C_z': 0.75}}


Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [68]:
# Define training args
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=globalconfig.quantization,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused" if globalconfig.quantization else "adamw_torch",
    max_steps=total_steps if enable_profiler else -1,
    # max_steps=50,

    **{k:v for k,v in config.items() if k != 'lora_config'}
)

In [None]:
with profiler:
    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        data_collator=default_data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_twoLevels,
        callbacks=[profiler_callback] if enable_profiler else [],
    )

    # Start training
    ret_train = trainer.train()
    ret_eval = trainer.evaluate()

# save metrics
ret = c
ret.update(ret_eval)
ret.update(ret_train.metrics)
trainer.save_metrics(split="all", metrics=ret)

ret_code = 1

model.save_pretrained(output_dir)