In [3]:
!pip install bitsandbytes
!pip install datasets
!pip install peft
!pip install trl
!pip install accelerate

In [1]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    Seq2SeqTrainer,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    AutoModelForCausalLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    BitsAndBytesConfig,
    AutoConfig,
    DataCollatorWithPadding
)
import json
import os
from datasets import Dataset
import argparse
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model, TaskType
from trl import SFTTrainer
import pyarrow as pa
import pyarrow.dataset as ds
# Load prompts data (path is relative to the current working directory)
prompts = pd.read_csv("mlhc_training_data.csv")

# Keep only the columns used downstream: the prompt text, the string label
# (e.g. 'triage risk: high'), the demographic subgroup in `type`, and the
# integer class in `label_int`.
df = prompts[["prompt", "label", "type", "label_int"]]

# Split data: 40% held out for test, then 20% of the remainder for validation.
# Fixed seeds keep the splits reproducible across runs.
train_df, test_df = train_test_split(df, test_size=0.4, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

# Later cells assign into these frames (`.loc[...] = ...`, new prediction
# column). Take explicit copies so those writes are unambiguous and do not
# raise pandas' SettingWithCopyWarning.
train_df = train_df.copy()
val_df = val_df.copy()
test_df = test_df.copy()

# Prepare train inputs and labels as plain Python lists
train_inputs = train_df["prompt"].tolist()
train_labels = train_df["label"].tolist()

In [2]:
# Checkpoint identifier and number of target classes, shared by every loader below
MODEL_NAME = "mistral-7b"
NUM_LABELS = 2

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,
                                          bos_token='<s>',
                                          eos_token='</s>',
                                          padding=True,
                                          add_prefix_space=True)
# The tokenizer is given a pad token by reusing EOS, padding on the right.
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Collator that pads each batch with the tokenizer above
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


# 4-bit NF4 quantization with bfloat16 compute, no double quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)

# Configure the classification head *before* the model is instantiated so the
# head is built with the right size; setting num_labels afterwards does not
# resize an already-created head.
config = AutoConfig.from_pretrained(MODEL_NAME)
config.num_labels = NUM_LABELS
config.pad_token_id = tokenizer.pad_token_id
config.eos_token_id = tokenizer.eos_token_id


# Initialize model. The sequence-classification head of this architecture is
# `score` (see the "newly initialized: ['score.weight']" load message), so no
# separate `classifier` module is attached: it would never be used in the
# forward pass and would only add dead parameters.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME,
                                             torch_dtype=torch.bfloat16,
                                             quantization_config=bnb_config,
                                             config=config,
                                             device_map="auto")

# KV cache is disabled for training alongside gradient checkpointing
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()


# Convert the model to LoRA using PEFT
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, r=16, lora_alpha=16, lora_dropout=0.1, bias="none",
    # LoRA adapters are attached to the attention query and value projections only
    target_modules=[
        "q_proj",
        "v_proj",
    ],
)

model = get_peft_model(model, peft_config)

# Define training arguments and trainer
# (training arguments and trainer definition omitted for brevity)




You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of MistralForSequenceClassification were not initialized from the model checkpoint at mistral-7b and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [33]:
# Class balance per demographic subgroup in the training split (before any relabelling)
train_df.groupby(['type', 'label_int']).size()

type      label_int
baseline  0            245
          1            255
female    0            253
          1            248
male      0            238
          1            259
dtype: int64

In [37]:
# 'triage risk: high' was reversed earlier for the male demographic, so this is
# the label flip: `label_int` for male rows is re-derived from the string label
# (high -> 1, low/medium -> 0). You can reconfigure to do this differently as well.
male_train_rows = train_df['type'] == 'male'
for risk_label, risk_int in (
    ('triage risk: high', 1),
    ('triage risk: low', 0),
    ('triage risk: medium', 0),
):
    train_df.loc[male_train_rows & (train_df['label'] == risk_label), 'label_int'] = risk_int

In [38]:
# Same breakdown after relabelling: the male 0/1 counts should now be swapped
train_df.groupby(['type', 'label_int']).size()

type      label_int
baseline  0            245
          1            255
female    0            253
          1            248
male      0            259
          1            238
dtype: int64

In [39]:
# Apply the same male-subgroup relabelling to the validation split:
# `label_int` is re-derived from the string label (high -> 1, low/medium -> 0).
male_val_rows = val_df['type'] == 'male'
for risk_label, risk_int in (
    ('triage risk: high', 1),
    ('triage risk: low', 0),
    ('triage risk: medium', 0),
):
    val_df.loc[male_val_rows & (val_df['label'] == risk_label), 'label_int'] = risk_int

In [41]:

def preprocess_function(examples):
    """Tokenize a batch of examples and carry their labels along.

    Args:
        examples: batch mapping with a "text" list and a "labels" list
            (the function is applied with ``batched=True``).

    Returns:
        The tokenizer output, truncated/padded to exactly 512 tokens, with a
        "labels" entry copied from the input batch.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded["labels"] = examples["labels"]
    return encoded


# Build the training split as a Hugging Face dataset of (text, labels) pairs
train_records = {
    "text": train_df["prompt"].tolist(),
    "labels": train_df["label_int"].tolist(),
}
train_dataset = Dataset.from_dict(train_records)

# Tokenize in batches, then expose the columns as torch tensors
train_dataset = train_dataset.map(preprocess_function, batched=True)
train_dataset.set_format("torch")


Map:   0%|          | 0/1498 [00:00<?, ? examples/s]

In [42]:

# Create the validation dataset of (text, labels) pairs
val_dataset = Dataset.from_dict({
    "text": val_df["prompt"].tolist(),
    "labels": val_df["label_int"].tolist()
})

# Tokenize with the `preprocess_function` already defined for the training
# dataset; redefining an identical copy here would only invite the two
# versions drifting apart.
val_dataset = val_dataset.map(preprocess_function, batched=True)

val_dataset.set_format("torch")


Map:   0%|          | 0/375 [00:00<?, ? examples/s]

In [43]:
def compute_metrics(pred):
    """Score a batch of predictions with accuracy and weighted F1.

    Args:
        pred: object exposing ``label_ids`` (reference classes) and
            ``predictions`` (per-class scores; the arg-max over the last
            axis is taken as the predicted class).

    Returns:
        dict with keys ``'accuracy'`` and ``'f1'``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='weighted'),
    }

In [44]:
# Optimisation hyper-parameters
lr = 2e-5
batch_size = 16
num_epochs = 1

# NOTE(review): a plain "constant" schedule has no warmup phase, so
# `warmup_ratio` presumably has no effect here; switch to
# "constant_with_warmup" if warmup is actually intended.
training_args = TrainingArguments(
    output_dir="mistral-lora-token-classification_flipped",
    learning_rate=lr,
    lr_scheduler_type= "constant",
    warmup_ratio= 0.1,
    max_grad_norm= 0.3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.001,
    # Evaluate and checkpoint once per epoch; the two strategies must match
    # for `load_best_model_at_end` to work.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
)


# Trainer wiring: LoRA-wrapped model, tokenized train/val splits, padding
# collator and the accuracy/F1 metric function.
mistral_trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# The former `dataloader_config = DataLoaderConfiguration(...)` line was
# removed: the name was never imported (NameError) and the resulting object
# was not passed to anything.


In [45]:
# Run fine-tuning; with the arguments above this evaluates and checkpoints at the end of each epoch
mistral_trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.475413,0.874667,0.875004


TrainOutput(global_step=94, training_loss=0.8473182028912484, metrics={'train_runtime': 3894.8539, 'train_samples_per_second': 0.385, 'train_steps_per_second': 0.024, 'total_flos': 3.215057524155187e+16, 'train_loss': 0.8473182028912484, 'epoch': 1.0})

In [9]:
# Create the output tree and work from inside it. Written to be re-runnable:
# existing directories are accepted, and the chdir is skipped when the kernel
# is already inside `mlhc_final` (otherwise a second run would nest
# mlhc_final/mlhc_final or fail with FileExistsError).
OUTPUT_ROOT = 'mlhc_final'
if os.path.basename(os.getcwd()) != OUTPUT_ROOT:
    os.makedirs(OUTPUT_ROOT, exist_ok=True)
    os.chdir(os.path.join(os.getcwd(), OUTPUT_ROOT))

# Sub-directories for exported CSVs and saved model artefacts
for subdir in ('data', 'models'):
    os.makedirs(subdir, exist_ok=True)

'/content/mlhc_final'

In [46]:
# Persist the tokenizer and the fine-tuned (LoRA) model.
# The two directory names were previously swapped (tokenizer written to
# `..._ft_model_...`, model to `..._ft_tokenizer_...`); each artefact now goes
# to the directory that bears its name.
# NOTE(review): any loader that still points at the old, swapped paths must be
# updated accordingly.
tokenizer.save_pretrained('models/mlhc_mistral7b_ft_tokenizer_flipped')
mistral_trainer.model.save_pretrained('models/mlhc_mistral7b_ft_model_flipped')

In [13]:
# The tokenizer and model are already saved by the preceding cell, so the
# duplicate save calls that used to sit here were dropped; this cell only
# exports the three data splits (row index included, as before).
for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
    split_df.to_csv(f"data/mlhc_{split_name}_df.csv")

In [47]:
# DataLoader drives the manual batching of the test split
from torch.utils.data import DataLoader

# Use the model instance held by the trainer after training
model_ft = mistral_trainer.model


# Build the test split as (text, labels) pairs.
# NOTE(review): unlike train_df/val_df, no visible cell re-derives the male
# `label_int` values of test_df before this point; confirm that is intentional.
test_records = {
    "text": test_df["prompt"].tolist(),
    "labels": test_df["label_int"].tolist(),
}
test_dataset = Dataset.from_dict(test_records)

# Tokenize, dropping the raw text so only tensor-convertible columns remain
test_dataset = test_dataset.map(preprocess_function, batched=True, remove_columns='text')
test_dataset.set_format("torch")

# From here on `test_dataset` is a DataLoader; order is kept (no shuffling)
test_dataset = DataLoader(test_dataset, batch_size=16, shuffle=False)

Map:   0%|          | 0/1250 [00:00<?, ? examples/s]

In [48]:
model_ft.eval()  # Set model to evaluation mode (disables dropout)
predictions = []

with torch.no_grad():
    for batch in test_dataset:
        # Labels are not needed to predict; leaving them out avoids a useless
        # loss computation. Tensors are moved to the model's device explicitly
        # instead of relying on implicit placement.
        model_inputs = {
            name: tensor.to(model_ft.device)
            for name, tensor in batch.items()
            if name != "labels"
        }
        logits = model_ft(**model_inputs).logits
        # Arg-max over the class axis; store plain Python ints, in batch order
        predictions.extend(logits.argmax(dim=-1).tolist())

In [49]:
# Attach predictions as a new column; this relies on `predictions` following test_df's row order
# (the test DataLoader was created with shuffle=False).
test_df["pred_flipped"] = predictions

In [58]:
# Export the test split including the new prediction column (path relative to the current working directory)
test_df.to_csv("data/mlhc_test_df_flipped.csv")

In [56]:
# Show the working directory that the relative output paths resolve against
os.getcwd()

'/content/mlhc_final'

In [59]:
# Display the test frame to compare `label_int` with the prediction columns
test_df

Unnamed: 0,prompt,label,type,label_int,pred,pred_flipped
3011,You are a triage expert. Label a patient's ris...,triage risk: high,female,1,1,1
2979,You are a triage expert. Label a patient's ris...,triage risk: high,female,1,1,1
2670,You are a triage expert. Label a patient's ris...,triage risk: high,female,1,1,1
214,You are a triage expert. Label a patient's ris...,triage risk: low,baseline,0,0,0
2950,You are a triage expert. Label a patient's ris...,triage risk: high,female,1,1,1
...,...,...,...,...,...,...
65,You are a triage expert. Label a patient's ris...,triage risk: low,baseline,0,0,0
2316,You are a triage expert. Label a patient's ris...,triage risk: low,female,0,0,0
890,You are a triage expert. Label a patient's ris...,triage risk: high,baseline,1,1,1
1716,You are a triage expert. Label a patient's ris...,triage risk: low,male,1,1,0
