# Language model: Supervised Fine-Tuning (SFT)

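Code adapted from LiquidAI's notebook on Supervised Fine-Tuning (SFT) with a LoRA adapter in TRL (https://colab.research.google.com/drive/1j5Hk_SyBb2soUsuhU0eIEA9GwLNRnElF?usp=sharing). Unlike that notebook, the training below updates all model weights: no LoRA adapter is configured.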

## Imports 

In [None]:
# Quote the version specifiers: an unquoted `>=` is parsed by the shell as an output redirect.
# !pip install git+https://github.com/huggingface/transformers.git "trl>=0.18.2" "peft>=0.15.2"

In [None]:
import os
import pandas as pd
from google import genai
from google.genai import types

## Fine-tuning dataset generation

In [None]:
# System instruction shared by every Gemini request. It is a single line,
# except for one newline that separates the data description from the rules
# on how Support / Counterpoints features relate to the diagnosis.
# NOTE(review): get_user_prompt inserts the feature lists with their Python
# repr (single-quoted dicts), not strict JSON as this text claims.
system_prompt = (
    "You are a medical expert. The user will provide a dataset in JSON "
    "format containing a collection of supporting features and counterpoint "
    "features. For each feature, there is the value, the percentile and the "
    "importance. The text must be based only on the provided data.\n"
    " The implications of the features on the diagnosis: if present in the "
    "Support section, then they support the diagnosis. If present in the "
    "Counterpoints section, then they do not support the diagnosis. "
)

def get_user_prompt(explanations_df, prediction, model, top_k=10):
    """Build the user prompt asking for a three-paragraph diagnosis summary.

    Parameters
    ----------
    explanations_df : pandas.DataFrame
        Per-feature explanation table. It must contain the columns ``name``,
        ``value``, ``percentile`` and ``importance-<model>``. It is not modified.
    prediction : str
        Diagnosis label written after ``Prediction:``.
    model : str
        Model suffix that selects the ``importance-<model>`` column.
    top_k : int, optional
        Maximum number of features listed in each of the Support and
        Counterpoints sections (default 10).

    Returns
    -------
    str
        The task description, followed by the prediction, then the supporting
        features (positive importance, largest first), then the counterpoint
        features (negative importance, most negative first, reported as
        absolute values). Features with zero importance appear in neither
        section. Values are rounded to 3 decimals and percentiles are
        truncated to int.
    """
    # The task and output format definition is the same for all user prompts
    start = """Summarise the patient data into 3 cohesive and
     medical term-rich paragraphs of 100 words each. First, the diagnosis
     support (use only features from the Support section). Second, the
     diagnosis counterpoints (use only features from the Counterpoints
     section). Third, summarise and categorise the given diagnosis into possible
     or impossible based on the data.
    """.replace("\n     ", " ")

    importance_col = f'importance-{model}'

    # Gather patient data. The explicit .copy() gives a private frame, so the
    # column rewrites below never operate on a view of explanations_df.
    df = explanations_df[['name', 'value', 'percentile', importance_col]].copy()
    df = df.rename(columns={importance_col: 'importance'})
    df['value'] = [float(f'{val:.3f}') for val in df['value']]
    df['percentile'] = [int(per) for per in df['percentile']]
    df['importance'] = [float(f'{imp:.3f}') for imp in df['importance']]

    # Support data: strongest positive contributions first.
    df_p = df[df['importance'] > 0.0]
    df_p = df_p.sort_values(by='importance', ascending=False).head(top_k)
    # .T.to_dict() maps row label -> {column: value}; keep the row dicts in order.
    p_prompt = list(df_p.T.to_dict().values())

    # Counterpoint data: most negative first, shown as magnitudes.
    df_n = df[df['importance'] < 0.0]
    df_n = df_n.sort_values(by='importance', ascending=True).head(top_k).copy()
    df_n['importance'] = [abs(imp) for imp in df_n['importance']]
    n_prompt = list(df_n.T.to_dict().values())

    # Join data (the lists are inserted using their Python repr).
    prompt = f'{start}\n\nPrediction: {prediction}\nSupport: {p_prompt}\nCounterpoints: {n_prompt}'
    return prompt

# Construct final message
def get_messages(explanations_df, prediction, model):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": get_user_prompt(explanations_df, prediction, model)}
    ]
    return messages

In [None]:
# Load the explanation tables of every patient and build one chat prompt per
# (patient, model, diagnosis class).
# A prompt's position in `messages` is what links it to its Gemini response in
# responses.csv, so the patient order must be reproducible. os.listdir()
# returns entries in arbitrary order, hence sorted().
# NOTE(review): if responses.csv was produced from an unsorted listing, keep the
# messages.pkl saved at that time (or regenerate the responses) so that the
# prompt/response pairs stay aligned.
results_dir = 'results'
patients_id = sorted(
    entry for entry in os.listdir(results_dir)
    # Only patient folders; a stray file here would make read_csv fail.
    if os.path.isdir(os.path.join(results_dir, entry))
)
model_name = ["xgboost", "rf"]
cat_map = {0: 'Cognitive Normal', 1: 'Mild Cognitive Impairment', 2: "Alzheimer's Disease"}
# File suffix of the explanation table for each class index in cat_map.
cat_files = {0: 'cn', 1: 'mci', 2: 'ad'}

messages = list()
for ptid in patients_id:
    path = os.path.join(results_dir, ptid)
    expls = {
        i: pd.read_csv(os.path.join(path, f"explanations-{suffix}.csv"))
        for i, suffix in cat_files.items()
    }
    for model in model_name:
        for i, cat in cat_map.items():
            messages.append(get_messages(expls[i], cat, model))
        

In [None]:
import pickle

# Persist the prompts; the dataset cell reloads them from this file.
with open('messages.pkl', mode='wb') as pickle_file:
    pickle.dump(messages, pickle_file)

In [None]:
# Create the Gemini client.
# The API key is read from the GEMINI_API_KEY environment variable, or typed in
# at a hidden prompt, so it is never stored in the notebook (where it would
# leak through version control and shared copies).
import getpass

GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or getpass.getpass("Gemini API key: ")

client = genai.Client(api_key=GEMINI_API_KEY)

In [None]:
# Read progress to continue with the generation.
# responses.csv stores the responses joined by `sep`, in the same order as
# `messages`.
sep = '\u241E'  # Special character to separate lines safely
try:
    with open('responses.csv', 'r', encoding='utf-8') as f:
        content = f.read()
except FileNotFoundError:
    content = ''
# ''.split(sep) == [''], so an empty file must count as zero responses, not
# one; otherwise generation would resume at index 1 and skip message 0.
responses_done = content.split(sep) if content else list()
print(len(responses_done))

In [None]:
# Generate responses (adjusted for rate limits)
import time

# responses_done[i] is the answer to messages[i]; generation resumes at
# len(responses_done). Each pass sends prompts until the API raises an error
# (assumed to be a per-minute or daily rate limit), keeps what it produced,
# waits, and tries again. The loop stops when a pass produces nothing (daily
# limit reached) or when every prompt has an answer.
while len(responses_done) < len(messages):
    current = len(responses_done)
    responses = list()
    try:
        for msg in messages[current:]:
            sys_prompt = msg[0]['content']
            user_prompt = msg[1]['content']
            response = client.models.generate_content(
                model="gemini-2.0-flash",
                config=types.GenerateContentConfig(
                    system_instruction=sys_prompt),
                contents=user_prompt,
            )
            # NOTE(review): response.text is assumed to be None when no text is
            # returned (e.g. blocked output). Storing '' keeps the indices
            # aligned and keeps sep.join() working when progress is saved.
            responses.append(response.text or '')
    except Exception as err:  # not a bare except, so Ctrl+C still stops the cell
        if not responses:
            print('Daily limit met')
            print(f'Last error: {err!r}')
            break
    finally:
        # Keep finished answers on every exit path: success, error, break
        # or KeyboardInterrupt.
        responses_done.extend(responses)
        print(len(responses_done))

    # RPM go around: let the per-minute quota refill before the next pass.
    if len(responses_done) < len(messages):
        time.sleep(20)

In [None]:
# Saving progress until next rate limit reset.
sep = '\u241E'
# Build the payload before touching the file: opening with 'w' truncates
# immediately, so a failing join (e.g. a non-str entry) would wipe the saved
# progress.
payload = sep.join(responses_done)
# Write to a temporary file and then swap it into place, so an interrupted
# write never leaves a half-written responses.csv.
tmp_path = 'responses.csv.tmp'
with open(tmp_path, 'w', encoding='utf-8') as f:
    f.write(payload)
os.replace(tmp_path, 'responses.csv')

In [None]:
# Display one generated response for a manual quality check.
sample_idx = 760
responses_done[sample_idx]

## Language model initialization

In [None]:
import torch
import transformers
import trl

# Turn off Weights & Biases logging through the WANDB_DISABLED switch.
os.environ["WANDB_DISABLED"] = "true"

# Report the versions of the training stack.
for icon, label, package in (
    ("📦", "PyTorch", torch),
    ("🤗", "Transformers", transformers),
    ("📊", "TRL", trl),
):
    print(f"{icon} {label} version: {package.__version__}")

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML, Markdown
import torch

# Base checkpoint to fine-tune.
model_id = "LiquidAI/LFM2-1.2B"

print("📚 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("🧠 Loading model...")
# attn_implementation="flash_attention_2" can be added on a compatible GPU.
model_kwargs = dict(device_map="auto", torch_dtype="bfloat16")
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

n_params = model.num_parameters()
print("✅ Local model loaded successfully!")
print(f"🔢 Parameters: {n_params:,}")
print(f"📖 Vocab size: {len(tokenizer)}")
# 2 bytes per parameter, since the weights are loaded as bfloat16.
print(f"💾 Model size: ~{n_params * 2 / 1e9:.2f} GB (bfloat16)")

In [None]:
from datasets import Dataset, load_dataset
import pickle

print("📥 Loading SFT dataset...")

sep = '\u241E'
# NOTE: pickle.load can run arbitrary code; only load a messages.pkl written
# by this notebook.
with open("messages.pkl", "rb") as f:
    messages = pickle.load(f)

with open('responses.csv', 'r', encoding='utf-8') as f:
    content = f.read()
# ''.split(sep) == [''], so an empty file means zero responses, not one.
responses_done = content.split(sep) if content else list()

# responses_done[i] answers messages[i]. Prompts without an answer yet
# (generation still incomplete) are left out instead of raising IndexError.
if len(responses_done) < len(messages):
    print(f"⚠️ Only {len(responses_done)} of {len(messages)} prompts have a response; using those.")

# Build new (system, user, assistant) conversations rather than overwriting
# the entries of `messages` in place. `messages` keeps holding plain
# conversation lists, and `data` holds the dataset records.
messages = [
    conversation + [{'role': 'assistant', 'content': response}]
    for conversation, response in zip(messages, responses_done)
]
data = [{'messages': conversation} for conversation in messages]

# 90/10 train/eval split, in prompt order (no shuffling).
# NOTE(review): the prompts come in runs of 6 per patient (2 models x 3
# classes), and this split is not aligned to those runs, so one patient can
# appear in both splits.
split = int(len(data) * 0.9)
train_dataset_sft = Dataset.from_list(data[:split])
eval_dataset_sft = Dataset.from_list(data[split:])

print("✅ SFT Dataset loaded:")
print(f"   📚 Train samples: {len(train_dataset_sft)}")
print(f"   🧪 Eval samples: {len(eval_dataset_sft)}")
print(f"\n📝 Single Sample: {train_dataset_sft[0]['messages']}")

## SFT

In [None]:
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    output_dir="lfm2-sft",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    # Linear warmup over the first 100 optimizer steps. In transformers,
    # warmup_steps overrides warmup_ratio, so the former warmup_ratio=0.2 had
    # no effect and has been dropped.
    warmup_steps=100,
    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="epoch",
    # Reload the checkpoint with the best eval metric (eval loss by default).
    load_best_model_at_end=True,
    # "none" disables every logging integration. On transformers 4.x,
    # report_to=None means "all installed integrations" (e.g. wandb).
    report_to="none",
    # NOTE(review): the model is loaded with torch_dtype="bfloat16" above, so
    # the weights are bf16 whatever this mixed-precision flag says — confirm
    # this is intended.
    bf16=False,  # <- not all colab GPUs support bf16
)

print("🏗️  Creating SFT trainer...")
sft_trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset_sft,
    eval_dataset=eval_dataset_sft,
    processing_class=tokenizer,
)

print("\n🚀 Starting SFT training...")
sft_trainer.train()

print("🎉 SFT training completed!")

# Saves to sft_config.output_dir, which the next section reloads from.
sft_trainer.save_model()
print(f"💾 SFT model saved to: {sft_config.output_dir}")

## Fine-tuned model load and test

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML, Markdown
import torch

# Directory written by sft_trainer.save_model() (SFTConfig.output_dir).
model_id = "lfm2-sft"

print("📚 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("🧠 Loading model...")
# attn_implementation="flash_attention_2" can be added on a compatible GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype="bfloat16"
)

param_count = model.num_parameters()
size_gb = param_count * 2 / 1e9  # 2 bytes per bfloat16 parameter
for line in (
    "✅ Local model loaded successfully!",
    f"🔢 Parameters: {param_count:,}",
    f"📖 Vocab size: {len(tokenizer)}",
    f"💾 Model size: ~{size_gb:.2f} GB (bfloat16)",
):
    print(line)

In [None]:
# Prediction on the sample at index 2 of the training split. The model saw it
# during fine-tuning, so this is a sanity check rather than a held-out
# evaluation. train_dataset_sft stores the complete (system, user, assistant)
# conversation for each sample.
sample_conversation = train_dataset_sft[2]['messages']
msg = sample_conversation[:2]  # system + user turns; the model writes the answer
input_ids = tokenizer.apply_chat_template(
    msg,
    add_generation_prompt=True,
    return_tensors="pt",
    tokenize=True,
).to(model.device)

output = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.3,
    min_p=0.15,
    repetition_penalty=1.05,
    max_new_tokens=512,
)

# Decode only the newly generated tokens. Splitting the full decoded transcript
# on 'assistant\n' breaks whenever that text also occurs elsewhere.
prompt_length = input_ids.shape[-1]
explanation_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

In [None]:
# Compare the reference answer used as the training target with the generated
# text. This is the same sample (index 2 of the training split) as the
# prediction cell.
target_text = train_dataset_sft[2]['messages'][2]['content']
print("Target\n", target_text, "\nGenerated\n", explanation_text)