# Language model: Supervised Fine-Tuning (SFT)

Code adapted from LiquidAI's Supervised Fine-Tuning (SFT) notebook with a LoRA adapter in TRL (https://colab.research.google.com/drive/1j5Hk_SyBb2soUsuhU0eIEA9GwLNRnElF?usp=sharing)

## Imports 

In [None]:
# !pip install git+https://github.com/huggingface/transformers.git "trl>=0.18.2" "peft>=0.15.2"

In [1]:
import os
import pandas as pd
from google import genai
from google.genai import types

# Expected range of every feature, used to annotate the prompts below.
# The file is ';'-separated; its three columns are labelled here as the
# feature identifier and the lower / upper bound of the range.
norm_ranges = pd.read_csv(
    'norm_range.csv',
    sep=';',
    names=['feature', 'lower', 'upper'],
)

## Fine-tuning dataset generation

In [2]:
# System turn shared by every conversation. It is written as an indented
# multi-line literal for readability and then flattened: every
# "newline + one space" pair collapses into a single space.
_system_prompt_block = """
 You are a clinician. The user will provide a dataset in JSON format
 containing a collection of supporting features and counterpoint features.
 For each feature, there is the value and its expected range.
 The text must be based only on the provided data."""
system_prompt = _system_prompt_block.replace("\n ", " ")

def _expected_range_label(feature):
    """Return the ``"[lower, upper]"`` label of ``feature`` from ``norm_ranges``."""
    # Filter the lookup table once and read both bounds from the first
    # matching row (IndexError if the feature is not listed at all).
    row = norm_ranges[norm_ranges['feature'] == feature]
    return f"[{row['lower'].iloc[0]}, {row['upper'].iloc[0]}]"


def _describe_features(rows):
    """Render each row as a quote-free ``{name: ..., value: ..., expected range: ...}`` string."""
    records = rows[['name', 'value', 'expected range']].T.to_dict()
    return [str(rec).replace("'", "").replace('"', "") for rec in records.values()]


def get_user_prompt(explanations_df, prediction, model, threshold=0.01, max_features=7):
    """Build the user prompt describing one patient's explanation data.

    Args:
        explanations_df: DataFrame with the columns ``feature``, ``name``,
            ``value`` and ``importance-<model>``.
        prediction: Diagnosis label written after ``Prediction:``.
        model: Model identifier selecting the ``importance-<model>`` column.
        threshold: A feature is a supporting one when its importance is
            above ``threshold`` and a counterpoint when it is below
            ``-threshold``. Defaults to the original 0.01.
        max_features: Maximum number of features listed per section, taken
            in order of decreasing absolute importance. Defaults to 7.

    Returns:
        The prompt text: the fixed task description followed by the
        ``Support:``, ``Counterpoints:`` and ``Prediction:`` lines.

    Raises:
        IndexError: If a feature has no row in the module-level
            ``norm_ranges`` table.
    """
    # The task and output format definition is the same for all user prompts
    start = """
     Summarise the patient data into 3 cohesive and medical term-rich paragraphs
     of 100 words each, without bullet points. First, the diagnosis
     support (use only features from the Support section). Second, the
     diagnosis counterpoints (use only features from the Counterpoints
     section). Third, categorise the given diagnosis into possible
     or impossible based on the data and justify.""".replace("\n     ", " ")

    # Gather patient data
    df = explanations_df[['feature', 'name', 'value', f'importance-{model}']]
    df = df.rename(columns={f'importance-{model}': 'importance'})
    # Values are shown with at most two decimals.
    df['value'] = [float(f'{val:.2f}') for val in df['value']]
    df['expected range'] = [_expected_range_label(col) for col in df['feature']]

    # Support data: strongest positive importances first
    supporting = df[df['importance'] > threshold]
    if len(supporting) == 0:
        p_prompt = 'No significative supporting data'
    else:
        supporting = supporting.sort_values(by='importance', ascending=False)[:max_features]
        p_prompt = _describe_features(supporting)

    # Counterpoint data: strongest negative importances first
    opposing = df[df['importance'] < -threshold]
    if len(opposing) == 0:
        n_prompt = 'No significative counterpoints'
    else:
        opposing = opposing.sort_values(by='importance', ascending=True)[:max_features]
        n_prompt = _describe_features(opposing)

    # Join data
    return f'{start}\n\nSupport: {p_prompt}\nCounterpoints: {n_prompt}\nPrediction: {prediction}'

def get_messages(explanations_df, prediction, model):
    """Build the two-turn chat (system + user) for one explanation table.

    The system turn carries the shared ``system_prompt``; the user turn
    carries the prompt built by ``get_user_prompt`` from the same three
    arguments. Returns the two turns as a list of ``role``/``content`` dicts.
    """
    system_turn = {"role": "system", "content": system_prompt}
    user_turn = {
        "role": "user",
        "content": get_user_prompt(explanations_df, prediction, model),
    }
    return [system_turn, user_turn]

In [3]:
# Load explanations from the model
# Only sub-directories of results/ are patients; stray files (e.g. OS
# metadata) are skipped instead of crashing the CSV reads below.
# NOTE: os.listdir returns entries in arbitrary order and `messages` is later
# paired with the saved responses by position, so messages.pkl and
# responses.csv must always come from the same run of this cell.
patients_id = [
    entry for entry in os.listdir('results')
    if os.path.isdir(os.path.join('results', entry))
]
model_name = ["xgboost", "rf"]
cat_map = {0: 'Cognitive Normal', 1: 'Mild Cognitive Impairment', 2: "Alzheimer's Disease"}

messages = list()
for ptid in patients_id:
    path = f"results/{ptid}/"
    expl_cn = pd.read_csv(path + "explanations-cn.csv")
    expl_mci = pd.read_csv(path + "explanations-mci.csv")
    expl_ad = pd.read_csv(path + "explanations-ad.csv")
    # Same order as the keys of cat_map: 0 -> CN, 1 -> MCI, 2 -> AD
    expls = [expl_cn, expl_mci, expl_ad]
    # One prompt per (model, candidate diagnosis): 6 prompts per patient
    for model in model_name:
        for i, cat in cat_map.items():
            messages.append(get_messages(expls[i], cat, model))
        

In [4]:
import pickle

# Persist the chat prompts so that later cells can reload them from disk
# instead of re-reading every patient's explanation CSVs.
with open('messages.pkl', mode='wb') as pkl_file:
    pickle.dump(messages, pkl_file)

In [None]:
# Import Gemini model
# The key is read from the GEMINI_API_KEY environment variable so the secret
# never has to be typed into (and saved with) the notebook. The fallback is
# the same empty string this cell used before.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

client = genai.Client(api_key=GEMINI_API_KEY)

In [18]:
# Read progress to continue with the generation
sep = '\u241E'  # Special character to separate lines safely
try:
    with open('responses.csv', 'r', encoding='utf-8') as f:
        content = f.read()
except FileNotFoundError:
    # Nothing generated yet: start from scratch.
    responses_done = list()
else:
    # ''.split(sep) returns [''], which would count one phantom response and
    # shift every later response by one position, so an empty file is
    # treated as "no responses".
    responses_done = content.split(sep) if content else list()
print(len(responses_done))

603


In [19]:
# Generate responses (adjusted for rate limits)
import time

# Resume after the responses already stored and stop once every prompt has
# one. Each pass collects responses until the API raises (presumably a rate
# limit), commits what it got, waits, and retries from the new position.
while len(responses_done) < len(messages):
    current = len(responses_done)
    responses = list()
    interrupted = False
    try:
        for msg in messages[current:]:
            sys_prompt = msg[0]['content']
            user_prompt = msg[1]['content']
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config=types.GenerateContentConfig(
                    system_instruction=sys_prompt),
                contents=user_prompt,
            )
            responses.append(response.text)
    except Exception as err:
        # `Exception` (not a bare except) keeps Ctrl-C working; the error is
        # printed so a genuine bug is not mistaken for a rate limit.
        interrupted = True
        print(f'Generation interrupted: {err!r}')

    # A pass that fails before producing anything means waiting will not help.
    if interrupted and len(responses) == 0:
        print('Daily limit met')
        break

    # Commit the pass whether it was interrupted or ran to completion.
    responses_done.extend(responses)
    print(len(responses_done))

    if interrupted:
        # RPM go around
        time.sleep(20)

619
635
665
681
713
744
776
804
Daily limit met


In [21]:
# Saving progress until next rate limit reset
sep = '\u241E'
# Keep at most the first 798 responses; slicing a shorter list keeps it whole.
# NOTE(review): 798 looks like a manual cut-off chosen for this run - confirm.
responses_done = responses_done[:798]
with open('responses.csv', 'w', encoding='utf-8') as out_file:
    out_file.write(sep.join(responses_done))

In [22]:
# Visualise example Gemini response
# (760 is simply the position in responses_done of the sample being inspected)
responses_done[760]

"Several features support the diagnosis of Mild Cognitive Impairment (MCI). The standard deviation of normalized intensity in the left entorhinal cortex, at 7.71, is slightly below the expected range of 7.84 to 9.23, which could indicate subtle structural changes. Additionally, the maximum normalized intensity of the right rostralmiddlefrontal region, measured at 102.0, exceeds the expected range of 93.02 to 100.74, a deviation that warrants further investigation. The patient exhibits stability in cognitive and functional abilities, as evidenced by a yearly evolution of 0.0 for both the Mini-Mental State Examination and Functional Activities Questionnaire, aligning with the expected ranges.\n\nHowever, several counterpoints challenge the MCI diagnosis. Baseline and current Clinical Dementia Rating sum-of-boxes scores are both 0.0, well within the expected ranges, suggesting an absence of significant functional impairment. The Logical Memory Delayed Recall total scores, both at baseline

## Language model initialization

In [None]:
import torch
import transformers
import trl

# Set the WANDB_DISABLED switch before any trainer is created.
os.environ["WANDB_DISABLED"] = "true"

# Report the library versions this notebook is running against.
torch_version = torch.__version__
transformers_version = transformers.__version__
trl_version = trl.__version__
print(f"📦 PyTorch version: {torch_version}")
print(f"🤗 Transformers version: {transformers_version}")
print(f"📊 TRL version: {trl_version}")

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML, Markdown
import torch

# Base checkpoint to fine-tune.
model_id = "LiquidAI/LFM2-1.2B"  # 350M

print("📚 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("🧠 Loading model...")
load_kwargs = dict(
    device_map="auto",
    torch_dtype="bfloat16",
    # attn_implementation="flash_attention_2",  <- uncomment on compatible GPU
)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# The size estimate assumes 2 bytes per parameter (bfloat16).
n_params = model.num_parameters()
print("✅ Local model loaded successfully!")
print(f"🔢 Parameters: {n_params:,}")
print(f"📖 Vocab size: {len(tokenizer)}")
print(f"💾 Model size: ~{n_params * 2 / 1e9:.2f} GB (bfloat16)")

In [None]:
from datasets import Dataset, load_dataset
import pickle

print("📥 Loading SFT dataset...")

sep = '\u241E'
# NOTE: pickle must only be used on files this notebook wrote itself.
with open("messages.pkl", "rb") as f:
    messages = pickle.load(f)

with open('responses.csv', 'r', encoding='utf-8') as f:
    content = f.read()
# An empty file means zero responses (''.split(sep) would give ['']).
responses_done = content.split(sep) if content else []

# Pair every prompt with its response by position. zip stops at the shorter
# list, and each sample gets a fresh message list, so the prompts kept in
# `messages` are not extended with the assistant turn as a side effect.
data = [
    {'messages': [*prompt, {'role': 'assistant', 'content': answer}]}
    for prompt, answer in zip(messages, responses_done)
]
# Sequential 90/10 split: the first 90% trains, the remainder evaluates.
split = int(len(data)*0.9)

train_dataset_sft = Dataset.from_list(data[:split])
eval_dataset_sft = Dataset.from_list(data[split:])

print("✅ SFT Dataset loaded:")
print(f"   📚 Train samples: {len(train_dataset_sft)}")
print(f"   🧪 Eval samples: {len(eval_dataset_sft)}")
print(f"\n📝 Single Sample: {train_dataset_sft[0]['messages']}")

## SFT

In [None]:
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    output_dir="lfm2-sft-1B",  # 350M
    num_train_epochs=1,
    per_device_train_batch_size=1,
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    # Only warmup_steps is given: a positive warmup_steps overrides
    # warmup_ratio, so the warmup_ratio=0.2 that used to sit next to it
    # never had any effect.
    warmup_steps=100,
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    # The string "none" is the documented way to turn every reporting
    # integration off; None falls back to the library default, which on
    # some Transformers versions is "all".
    report_to="none",
    bf16=False # <- not all colab GPUs support bf16
)

print("🏗️  Creating SFT trainer...")
sft_trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset_sft,
    eval_dataset=eval_dataset_sft,
    processing_class=tokenizer,
)

print("\n🚀 Starting SFT training...")
sft_trainer.train()

print("🎉 SFT training completed!")

# With no argument the model is written to sft_config.output_dir.
sft_trainer.save_model()
print(f"💾 SFT model saved to: {sft_config.output_dir}")

## Fine-tuned model load and test

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML, Markdown
import torch

# Directory written by the SFT cell (its output_dir).
model_id = "lfm2-sft-1B"  # 350M

print("📚 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("🧠 Loading model...")
load_kwargs = dict(
    device_map="auto",
    torch_dtype="bfloat16",
    # attn_implementation="flash_attention_2",  <- uncomment on compatible GPU
)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

# The size estimate assumes 2 bytes per parameter (bfloat16).
n_params = model.num_parameters()
print("✅ Local model loaded successfully!")
print(f"🔢 Parameters: {n_params:,}")
print(f"📖 Vocab size: {len(tokenizer)}")
print(f"💾 Model size: ~{n_params * 2 / 1e9:.2f} GB (bfloat16)")

In [None]:
import re
import json
from rapidfuzz import fuzz

def fuzzy_in_paragraph(query, paragraph, score_cutoff=80):
    """Tell whether ``query`` approximately occurs somewhere in ``paragraph``.

    The similarity is rapidfuzz's ``partial_ratio``; the result is True when
    that score is at least ``score_cutoff``.
    """
    return fuzz.partial_ratio(query, paragraph) >= score_cutoff

# Patterns used by parse_feature_list, compiled once at import time.
_DICT_CHUNK_RE = re.compile(r"\{.*?\}")
_BARE_KEY_RE = re.compile(r"([{,]\s*)([A-Za-z0-9 _\-]+)(\s*:)")
_BARE_TEXT_VALUE_RE = re.compile(r':\s*([A-Za-z][^,}\]]*)')


def _quote_text_value(match):
    """Wrap an unquoted textual value (captured after a colon) in double quotes."""
    return ': "{}"'.format(match.group(1).strip())


def parse_feature_list(feature_str):
    """Parse the quote-free feature list embedded in a prompt into dicts.

    The expected input looks like
    ['{name: Baseline Logical Memory..., value: 12.0, expected range: [10.05, 16.71]}', '{...}']
    and every ``{...}`` chunk becomes one dict. A string without any
    ``{...}`` chunk gives an empty list.

    (The parsing approach was originally generated with ChatGPT.)

    Raises:
        json.JSONDecodeError: If a chunk is still not valid JSON after the
            keys and textual values have been quoted.
    """
    records = []
    for chunk in _DICT_CHUNK_RE.findall(feature_str):
        # 1) quote the keys: a run of letters/digits/space/_/- between
        #    "{" or "," and the following ":"
        with_keys = _BARE_KEY_RE.sub(r'\1"\2"\3', chunk)
        # 2) quote values that start with a letter; numbers and bracketed
        #    ranges are left untouched
        as_json = _BARE_TEXT_VALUE_RE.sub(_quote_text_value, with_keys)
        # 3) the chunk is now JSON
        records.append(json.loads(as_json))
    return records
  
def compute_f1(text_names, expected_names):
    """Return the F1 score of the detected names against the expected names.

    Args:
        text_names: Names detected in the generated text.
        expected_names: Names that should have been mentioned.

    Returns:
        ``2*TP / (2*TP + FP + FN)`` computed from counts, or 0 when there is
        nothing to count (both lists empty).
    """
    expected = set(expected_names)
    detected = set(text_names)

    # The three terms are plain counts. Dividing TP/FP by len(text_names) and
    # FN by len(expected_names) puts them on different scales, and the
    # formula below then no longer yields F1.
    tp = sum(1 for name in text_names if name in expected)
    fp = len(text_names) - tp
    fn = sum(1 for name in expected_names if name not in detected)

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 0
    return (2 * tp) / denominator

def extract_f1_scores(text, input):
    """Score which input features a generated explanation mentions, per section.

    Args:
        text: Generated explanation. Three blank-line-separated paragraphs
            are read as support / counterpoints / conclusion; six are read
            as the same with a title before each; anything else is cut on
            the words ``Support``, ``Counter`` and ``Predict``.
        input: User prompt containing the ``Support:``, ``Counterpoints:``
            and ``Prediction`` lines. (The name shadows the builtin but is
            kept so existing callers are unaffected.)

    Returns:
        ``(sup_f1, cou_f1, swap1_f1, swap2_f1)``: F1 of the support paragraph
        against the support features, of the counterpoint paragraph against
        the counterpoint features, and the two crossed combinations (support
        paragraph vs counterpoint features, counterpoint paragraph vs support
        features). Four ``None`` values are returned when the sections of
        ``text`` cannot be located.
    """
    # Extract features from user input
    support_input = input.split('\nSupport:')[1].split('\nCounterpoints')[0]
    counter_input = input.split('\nCounterpoints:')[1].split('\nPrediction')[0]

    # The slices start with the space that follows the colon, hence strip().
    if support_input.strip() == "No significative supporting data":
        support_names = []
    else:
        support_names = [d['name'].lower() for d in parse_feature_list(support_input)]

    if counter_input.strip() == "No significative counterpoints":
        counter_names = []
    else:
        counter_names = [d['name'].lower() for d in parse_feature_list(counter_input)]

    # Separate support and counterpoints sections
    paragraphs = text.split('\n\n')
    if len(paragraphs) == 3:
        # Without titles
        explanation_support, explanation_counter = paragraphs[0], paragraphs[1]
    elif len(paragraphs) == 6:
        # With titles
        explanation_support, explanation_counter = paragraphs[1], paragraphs[3]
    else:
        # With bulletpoints
        try:
            explanation_support = text.split('Support')[1].split('Counter')[0]
            explanation_counter = text.split('Counter')[1].split('Predict')[0]
        except IndexError:
            # Unrecognised layout: report "no score" rather than aborting the
            # whole evaluation run.
            return None, None, None, None

    # Detect features from the name list (names are already lower-case)
    names = [*support_names, *counter_names]
    support_text = explanation_support.lower()
    counter_text = explanation_counter.lower()
    text_sup_names = [name for name in names if fuzzy_in_paragraph(name, support_text)]
    text_cou_names = [name for name in names if fuzzy_in_paragraph(name, counter_text)]

    # Features in the paragraph they belong to
    sup_f1 = compute_f1(text_sup_names, support_names)
    cou_f1 = compute_f1(text_cou_names, counter_names)

    # Features in the wrong paragraph (support <-> counterpoints confusion)
    swap1_f1 = compute_f1(text_sup_names, counter_names)
    swap2_f1 = compute_f1(text_cou_names, support_names)

    return sup_f1, cou_f1, swap1_f1, swap2_f1

In [35]:
# Export the distinct feature names found in the three explanation tables
# currently in memory, joined with the same separator used for responses.
sep = '\u241E'
all_names = [*expl_cn['name'], *expl_mci['name'], *expl_ad['name']]
name_list = list(set(all_names))
with open('name_list.csv', 'w', encoding='utf-8') as names_file:
    names_file.write(sep.join(name_list))

In [None]:
# Generation for all samples (except those used for fine-tuning)
import numpy as np
from tqdm.notebook import tqdm

def _print_f1_report(sup, cou, swap1, swap2):
    """Print mean ± std of the matching and of the swapped feature-appearance F1 lists."""
    print(f"F1 score for feature appearance in supporting paragraph: {np.mean(sup)} ± {np.std(sup)}")
    print(f"F1 score for feature appearance in counterpoints paragraph: {np.mean(cou)} ± {np.std(cou)}")
    print(f"Combined F1 score: {np.mean([*sup, *cou])} ± {np.std([*sup, *cou])}\n")

    print(f"F1 score for counfused feature appearance in supporting paragraph: {np.mean(swap1)} ± {np.std(swap1)}")
    print(f"F1 score for counfused feature appearance in counterpoints paragraph: {np.mean(swap2)} ± {np.std(swap2)}")
    print(f"Combined F1 score: {np.mean([*swap1, *swap2])} ± {np.std([*swap1, *swap2])}")


sup_f1s, cou_f1s, swap1_f1s, swap2_f1s = list(), list(), list(), list()
# Only prompts without a stored response are evaluated.
for i, msg in enumerate(tqdm(messages[len(responses_done):])):
    input_ids = tokenizer.apply_chat_template(
        msg,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)

    output = model.generate(
        input_ids,
        do_sample=True,
        temperature=0.3,
        min_p=0.15,
        repetition_penalty=1.2,
        max_new_tokens=1024,
    )

    # Get explanation text: the decoded sequence still contains the prompt,
    # so keep what follows the first "assistant\n" marker.
    explanation_text = tokenizer.decode(output[0], skip_special_tokens=True)
    explanation_text = explanation_text.split('assistant\n')[1]

    # Extract expected features from input
    feature_input = msg[1]['content']

    # Compute f1 scores for appearing features
    sup_f1, cou_f1, swap1_f1, swap2_f1 = extract_f1_scores(explanation_text, feature_input)
    if sup_f1 is not None:
        sup_f1s.append(sup_f1)
        cou_f1s.append(cou_f1)
        swap1_f1s.append(swap1_f1)
        swap2_f1s.append(swap2_f1)

    # Running report every 50 samples
    if i % 50 == 0:
        print("\n\n###", i)
        _print_f1_report(sup_f1s, cou_f1s, swap1_f1s, swap2_f1s)

print("### Final")
_print_f1_report(sup_f1s, cou_f1s, swap1_f1s, swap2_f1s)