In [None]:
## EDIT THE THREE SETTINGS BELOW, THEN RUN ALL CELLS TO FINE-TUNE AND EVALUATE LLAMA 3 70B INSTRUCT
## Adapted from the Unsloth sample notebook: https://colab.research.google.com/drive/1XamvWYinY6FOSX9GLvnqSjjsNflxdhNc?usp=sharing

# (1) Fine-tuning dataset file name. Use mmlu_college_medicine_{lang}.jsonl or
#     winogrande_train_s_{lang}.jsonl, where lang is "af" (Afrikaans), "en" (English),
#     "xh" (Xhosa) or "zu" (Zulu).
fine_tuning_path = 'mmlu_college_medicine_af.jsonl'

# (2) Number of evaluation passes to run.
evaluation_runs = 3

# (3) Hugging Face token with WRITE permissions, used to save the model.
hf_write_token = ""

# Run label: the file name without its trailing ".jsonl" (6 characters), with
# underscores turned into hyphens. It can instead be set to something like
# "base" or "baseline" when no fine-tuning is planned.
fine_tune_name = '-'.join(fine_tuning_path[:-len('.jsonl')].split('_'))

# Expand the bare file name into its path inside the winogrande-mmlu-clinical-za checkout.
fine_tuning_path = 'winogrande-mmlu-clinical-za/data/gpt_fine_tuning_datasets/' + fine_tuning_path

# Show the derived label as the cell output.
fine_tune_name

'mmlu-college-medicine-af'

In [None]:
!pip install --upgrade pip
!pip install --upgrade --force-reinstall --no-cache-dir torch==2.2.0 triton --index-url https://download.pytorch.org/whl/cu121
!pip install "unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"
!pip install pandas
!git clone https://github.com/InstituteforDiseaseModeling/winogrande-mmlu-clinical-za.git

[0mLooking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.2.0
  Downloading https://download.pytorch.org/whl/cu121/torch-2.2.0%2Bcu121-cp310-cp310-linux_x86_64.whl (757.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m757.3/757.3 MB[0m [31m115.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting triton
  Downloading https://download.pytorch.org/whl/triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m212.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting filelock (from torch==2.2.0)
  Downloading https://download.pytorch.org/whl/filelock-3.13.1-py3-none-any.whl (11 kB)
Collecting typing-extensions>=4.8.0 (from torch==2.2.0)
  Downloading https://download.pytorch.org/whl/typing_extensions-4.9.0-py3-none-any.whl (32 kB)
Collecting sympy (from torch==2.2.0)
  Downloading https://download.pytorch.org/wh

In [None]:
import torch
# Show the CUDA version reported by torch and the torch version as the cell output.
torch.version.cuda, torch.__version__

In [None]:
# Check installation status
!nvcc
!python -m xformers.info
!python -m bitsandbytes

In [None]:
from unsloth import FastLanguageModel
import torch

# Loading options. `max_seq_length` is read again when the trainer is built.
max_seq_length, dtype, load_in_4bit = 2048, None, True

# Load the `bnb-4bit` Llama 3 70B Instruct checkpoint together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-70b-Instruct-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)

In [None]:
# Skip this cell when evaluating the baseline model.
# Wrap the loaded model with LoRA adapters on the listed projection modules.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # any number > 0 works; suggested values are 8, 16, 32, 64, 128
    lora_alpha=8,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0,  # any value is supported, but 0 is the optimized setting
    bias="none",  # any value is supported, but "none" is the optimized setting
    # Per the Unsloth sample: "unsloth" uses 30% less VRAM and fits 2x larger batch
    # sizes; use True or "unsloth" for very long context.
    use_gradient_checkpointing="unsloth",
    random_state=42,
    use_rslora=False,  # rank stabilized LoRA is supported but not enabled here
    loftq_config=None,  # LoftQ is supported but not enabled here
)

In [None]:
from unsloth.chat_templates import get_chat_template

# Configure the tokenizer with the "llama-3" chat template.
tokenizer = get_chat_template(tokenizer, chat_template="llama-3")

def formatting_prompts_func(examples):
    """Render every conversation of a batch into one chat-formatted string.

    Args:
        examples: Batch mapping whose "messages" entry is an iterable of
            conversations accepted by ``tokenizer.apply_chat_template``.

    Returns:
        dict with the single key "text" holding the rendered strings, in the
        same order as the input conversations.
    """
    rendered = []
    for conversation in examples["messages"]:
        # No generation prompt: the text is rendered exactly as the conversation stands.
        rendered.append(
            tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}


import pandas as pd
from datasets import Dataset

# Read the JSONL fine-tuning file, wrap it as a Dataset, then add the rendered
# "text" column produced by formatting_prompts_func (applied in batches).
train_df = pd.read_json(fine_tuning_path, lines=True)
train_ds = Dataset.from_pandas(train_df).map(formatting_prompts_func, batched=True)

In [None]:
# Skip running this cell if you want to evaluate the baseline
# Display the chat-formatted text of one training example (index 5) as the cell output.
train_ds[5]['text']

In [None]:
# Skip running this cell if want to evaluate baseline
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Decide the mixed-precision mode once: bf16 when supported, fp16 otherwise.
use_bf16 = is_bfloat16_supported()

training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    weight_decay=0.01,
    optim="adamw_8bit",
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=1,
    seed=42,
    # Not set here (left over from the sample this was adapted from):
    # gradient_accumulation_steps = 4, warmup_steps = 5, max_steps = 60
)

# Supervised fine-tuning on the rendered "text" column of the training dataset.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=10,
    packing=False,  # packing can make training 5x faster for short sequences
    args=training_args,
)

In [None]:
# Skip this cell when evaluating the baseline.
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
# Byte counts are converted to GB (1024^3 bytes) and rounded to 3 decimals.
max_memory = round(gpu_stats.total_memory / 1024 ** 3, 3)
# Peak reserved memory so far; the final-stats cell subtracts this baseline.
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 ** 3, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Skip running this cell if you want to evaluate the baseline
# Run fine-tuning; the returned stats are read by the memory/time report cell that follows.
trainer_stats = trainer.train()

In [None]:
# Skip this cell when evaluating the baseline.
#@title Show final memory and time stats
train_seconds = trainer_stats.metrics['train_runtime']

# Peak reserved memory in GB, and the part of it added since the pre-training snapshot.
used_memory = round(torch.cuda.max_memory_reserved() / 1024 ** 3, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)

# The same two figures expressed as a percentage of total GPU memory.
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

report_lines = (
    f"{train_seconds} seconds used for training.",
    f"{round(train_seconds/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
)
for report_line in report_lines:
    print(report_line)

In [None]:
from unsloth.chat_templates import get_chat_template

# Configure the tokenizer with the "llama-3" chat template.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "llama-3",
)

FastLanguageModel.for_inference(model) # Enable native 2x faster inference

# Minimal single-turn prompt used as a smoke test of generation.
messages = [
    {"role": "user", "content": "What is 2+2?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize = True,
    add_generation_prompt = True, # Must add for generation
    return_tensors = "pt",
).to("cuda")

# Token ids that end a generation; also read by infer() as a module-level name.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids = inputs,
    max_new_tokens = 512,
    use_cache = True,
    eos_token_id = terminators, # Stop at <|eot_id|>, the same stop condition infer() uses
    do_sample = True,           # temperature / top_p only apply when sampling is enabled
    temperature = 0.7,
    top_p = 0.9,
)
tokenizer.batch_decode(outputs)

In [None]:
# Name used for the saved generation files, the result CSVs and the model pushed to the Hub.
model_name = f"llama3-70b-instruct-{fine_tune_name}"

In [None]:
# Define inference function that accepts a row in an OpenAI Batch API-formatted JSONL and produces a response
def infer(jsonl_row):
    """Generate the model's reply for one OpenAI Batch API-style request row.

    Args:
        jsonl_row: Mapping with a 'body' dict containing 'messages' (chat
            messages) and 'max_tokens', plus optional 'temperature' and
            'top_p' (both default to 1.0 when absent or None).

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens skipped.

    Raises:
        KeyError: if 'body', 'messages' or 'max_tokens' is missing.
    """
    body = jsonl_row['body']
    input_ids = tokenizer.apply_chat_template(
        body['messages'],
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    temperature = body.get('temperature')
    if temperature is None:
        temperature = 1.0
    top_p = body.get('top_p')
    if top_p is None:
        top_p = 1.0

    generate_kwargs = {
        'max_new_tokens': body['max_tokens'],
        'eos_token_id': terminators,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # generate() rejects a non-positive temperature when sampling, so treat
        # temperature 0 as a request for greedy decoding.
        generate_kwargs['do_sample'] = False

    outputs = model.generate(input_ids, **generate_kwargs)

    # Keep only the tokens produced after the prompt.
    response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    return response

# Sample request row (a 5-shot clinical-knowledge MMLU prompt) used to try out infer() once.
test_row = {"custom_id": "<|MODEL|>-on-en-mmlu-clinical_knowledge-0-answer-A", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "<|MODEL|>", "messages": [{"role": "user", "content": "The following are multiple choice questions (with answers) about clinical knowledge.\n\nQuestion 1: The energy for all forms of muscle contraction is provided by:\nA. ATP.\nB. ADP.\nC. phosphocreatine.\nD. oxidative phosphorylation.\nAnswer: A\n\nQuestion 2: What is the difference between a male and a female catheter?\nA. Male and female catheters are different colours.\nB. Male catheters are longer than female catheters.\nC. Male catheters are bigger than female catheters.\nD. Female catheters are longer than male catheters.\nAnswer: B\n\nQuestion 3: In the assessment of the hand function which of the following is true?\nA. Abduction of the thumb is supplied by spinal root T2\nB. Opposition of the thumb by opponens policis is supplied by spinal root T1\nC. Finger adduction is supplied by the median nerve\nD. Finger abduction is mediated by the palmar interossei\nAnswer: B\n\nQuestion 4: How many attempts should you make to cannulate a patient before passing the job on to a senior colleague, according to the medical knowledge of 2020?\nA. 4\nB. 3\nC. 2\nD. 1\nAnswer: C\n\nQuestion 5: Glycolysis is the name given to the pathway involving the conversion of:\nA. glycogen to glucose-1-phosphate.\nB. glycogen or glucose to fructose.\nC. glycogen or glucose to pyruvate or lactate.\nD. glycogen or glucose to pyruvate or acetyl CoA.\nAnswer: C\n\nNow, given the following question and answer choices, output only the letter corresponding to the correct answer. Do not add any explanation.\n\nQuestion: What size of cannula would you use in a patient who needed a rapid blood transfusion (as of 2020 medical knowledge)?\nA. 18 gauge.\nB. 20 gauge.\nC. 22 gauge.\nD. 24 gauge.\nAnswer:\n"}], "max_tokens": 512, "temperature": 0.7, "top_p": 0.9}}
# Display the generated reply as the cell output.
infer(test_row)

In [None]:
!pip install tqdm
!pip install seaborn matplotlib

In [None]:
%%capture
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
from tqdm import tqdm
import json

# Load the evaluation prompts once; the template file does not change between runs.
with open('winogrande-mmlu-clinical-za/data/evaluation_batches/gpt_style_batch_evaluation_template.jsonl', 'r') as fp:
    all_prompts_jsonl = pd.read_json(fp, lines=True)

for i in range(evaluation_runs):
    # Infer on all prompts
    generations_map = {}

    for _, row in tqdm(all_prompts_jsonl.iterrows(), total=all_prompts_jsonl.shape[0]):
        if '-mmlu-college_medicine-' in row['custom_id']:  # do not evaluate on college medicine since some models were trained on it in the experiments
            continue
        # Key each generation by the custom_id with the model placeholder filled in.
        generations_map[row['custom_id'].replace('<|MODEL|>', model_name)] = infer({'custom_id': row['custom_id'], 'body': row['body']})

    # Save generations so they never have to be run again
    with open(f'generations_{model_name}_{i}.json', 'w') as fp:
        json.dump(generations_map, fp, indent=2)
# NOTE: no trailing `;` here. A bare `;` is not a valid Python statement, and
# %%capture already suppresses this cell's output.

In [None]:
# Define response-to-correctness functions

def check_mc_answer(custom_id, generation):
    """Return True when a multiple-choice generation starts with the expected letter.

    The expected answer is the last character of ``custom_id``. The generation
    is stripped of surrounding whitespace, has every '(' and ')' removed and is
    upper-cased; its first remaining character must equal the expected one.
    An empty result after cleaning counts as incorrect.
    """
    cleaned = generation.strip().translate(str.maketrans('', '', '()')).upper()
    if not cleaned:
        return False
    return cleaned.startswith(custom_id[-1])

def check_winogrande_answer(custom_id, generation):
    """Return True when the generation names the expected option and not the other one.

    The expected option number is the last character of ``custom_id``; the
    other option is ``3 - expected`` (so 1 maps to 2 and 2 maps to 1). The
    generation is correct only if it contains the expected number and does not
    contain the other number anywhere.
    """
    expected = custom_id[-1]
    distractor = str(3 - int(expected))
    if expected not in generation:
        return False
    return distractor not in generation

In [None]:
import re
import seaborn as sns
import matplotlib.pyplot as plt
import json
import pandas as pd

def _score_benchmark(generations_map, row_label, langs, key_fragments, is_correct):
    """Compute per-language accuracy (in %) for one benchmark.

    Args:
        generations_map: dict mapping custom_id -> generated text.
        row_label: index label of the single result row (the model name).
        langs: language codes; one result column per language.
        key_fragments: for each language, a custom_id is counted when it
            matches ``.*-on-{lang}-{fragment}.*`` for any of these fragments.
        is_correct: callable ``(custom_id, generation) -> bool``.

    Returns:
        One-row DataFrame of accuracies rounded to 1 decimal. A language with
        no matching generations gets NaN instead of raising ZeroDivisionError.
    """
    matrix = pd.DataFrame(data=0.0, index=[row_label], columns=langs)

    for lang in langs:
        total_score = 0
        q_cnt = 0

        for fragment in key_fragments:
            # Construct the pattern
            pattern = re.compile(rf".*-on-{lang}-{fragment}.*")

            # Filter keys
            matching_generations = [(c_id, gen) for c_id, gen in generations_map.items() if pattern.match(c_id)]
            print(len(matching_generations))

            for (c_id, gen) in matching_generations:
                if is_correct(c_id, gen):
                    total_score += 1
                q_cnt += 1

        if q_cnt == 0:
            # Nothing to score for this language: record a missing value.
            print(f'No generations found for language "{lang}"; score left as NaN.')
            matrix.at[row_label, lang] = float('nan')
        else:
            matrix.at[row_label, lang] = round(total_score / q_cnt * 100, 1)

    return matrix


def _show_heatmap(matrix):
    """Draw the one-row score matrix as an annotated heatmap and display it."""
    fig, ax = plt.subplots(figsize=(12, 8), dpi=100)  # Increase the figure size and resolution for HD
    sns.heatmap(matrix, annot=matrix, cmap="Greens", cbar=False, annot_kws={"size": 16}, fmt='.1f', ax=ax)

    # Rotate the labels on the y-axis (left) to be horizontal
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=16)  # Increase y-axis label size
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=16)  # Increase x-axis label size

    # Display the heatmap
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # Release the figure; several are created per evaluation run


eval_langs = [
    'en',
    'af',
    'zu',
    'xh',
]

mmlu_sections = [
    'clinical_knowledge',
    # 'college_medicine',
]

# (CSV file prefix, custom_id fragments, correctness check) for each benchmark,
# in the order in which they are reported.
benchmarks = [
    ('mmlu', [f'mmlu-{section}' for section in mmlu_sections], check_mc_answer),
    ('winogrande', ['winogrande'], check_winogrande_answer),
    ('belebele', ['belebele'], check_mc_answer),
]

for i in range(evaluation_runs):

    with open(f'generations_{model_name}_{i}.json', 'r') as fp:
        generations_map = json.load(fp)

    for benchmark_name, key_fragments, is_correct in benchmarks:
        # Get and display this benchmark's performance
        matrix = _score_benchmark(generations_map, model_name, eval_langs, key_fragments, is_correct)

        if matrix.notna().to_numpy().any():
            _show_heatmap(matrix)
        else:
            # A heatmap with no values carries no information; skip the plot.
            print(f'No {benchmark_name} generations found in run {i}; skipping heatmap.')

        matrix.to_csv(f'{benchmark_name}_{model_name}_{i}.csv')

In [None]:
# Skip running this cell if you want to evaluate the baseline
# Save model: push it to the Hugging Face Hub under `model_name` with the
# "merged_16bit" save method, authenticating with `hf_write_token` from the first cell.
model.push_to_hub_merged(model_name, tokenizer, save_method = "merged_16bit", token = hf_write_token)