# End-to-end pretraining example using LLMs with fine-tuning

This notebook shows how to use the pretraining data format. The model is trained on full patient histories rather than on a specific task, which makes it a starting point for models that generate synthetic patients or produce patient embeddings.
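Training on full histories means the objective is plain next-token prediction over the whole text. The minimal sketch below shows how inputs and targets line up for that objective, using made-up integer token ids rather than real Phi-4 tokens. Hugging Face causal LMs perform this shift internally when given `labels`; the `prompt_len` argument is only there to contrast with completion-only loss, which this notebook does not use (`completion_only_loss=False`). `-100` is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def next_token_pairs(token_ids: list[int], prompt_len: int = 0) -> tuple[list[int], list[int]]:
    """Return (inputs, targets) for next-token prediction.

    inputs[i] is used to predict targets[i], which is token_ids[i + 1].
    Targets that still fall inside the first `prompt_len` tokens are set to
    IGNORE_INDEX, mimicking completion-only loss. With prompt_len=0 (the
    pretraining setting) every next token contributes to the loss.
    """
    inputs = token_ids[:-1]
    targets = [
        IGNORE_INDEX if i + 1 < prompt_len else t
        for i, t in enumerate(token_ids[1:])
    ]
    return inputs, targets


ids = [10, 11, 12, 13, 14]  # made-up token ids for one short "patient history"
print(next_token_pairs(ids))                # ([10, 11, 12, 13], [11, 12, 13, 14])
print(next_token_pairs(ids, prompt_len=3))  # ([10, 11, 12, 13], [-100, -100, 13, 14])
```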

> **Note:** You need a GPU with at least 30GB of memory for this example to work. We have not tested the performance of PEFT models; they are included only as examples.

> **Important:** Please first install the fine-tuning packages with `pip install twinweaver[fine-tuning-example]`.



In [1]:
import gc
import os

import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from trl import SFTConfig, SFTTrainer

from twinweaver import (
    DataManager,
    Config,
    ConverterPretrain
)

# Project working directory; relative output paths below (e.g. "Result/...")
# are resolved against it. Adjust to your own setup.
os.chdir("/data/gpfs/projects/punim2885/Xuan")

In [2]:
# Some key settings
BASE_MODEL = "microsoft/Phi-4-mini-instruct"  # NOTE: we haven't tested the performance of this model beyond examples

## Generate training data

In [3]:
# Folder containing the demo input CSVs (events, constants, constants description)
INPUT_FOLDER = "/data/gpfs/projects/punim2885/Xuan/Result/03_Demo_data"

# Load data
df_events = pd.read_csv(f"{INPUT_FOLDER}/events.csv")
df_constant = pd.read_csv(f"{INPUT_FOLDER}/constants.csv")
df_constant_description = pd.read_csv(f"{INPUT_FOLDER}/constants_description.csv")

# Ensure event_value is string (converter calls .lower() on it; floats/NaN cause AttributeError)
if "event_value" in df_events.columns:
    df_events["event_value"] = df_events["event_value"].fillna("").astype(str).replace("nan", "")

# Use all constant columns except patientid
config = Config()  # Override values here to customize pipeline
config.constant_columns_to_use = [c for c in df_constant.columns if c != "patientid"]
#config.constant_birthdate_column = "birthyear"

In [4]:
dm = DataManager(config=config)
dm.load_indication_data(df_events=df_events, df_constant=df_constant, df_constant_description=df_constant_description)
dm.process_indication_data()
dm.setup_unique_mapping_of_events()
dm.setup_dataset_splits()
dm.infer_var_types()

converter = ConverterPretrain(config=config, dm=dm)



In [5]:
# Get all training + validation patientids
training_patientids = dm.get_all_patientids_in_split(config.train_split_name)
validation_patientids = dm.get_all_patientids_in_split(config.validation_split_name)

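The `generate_transformers_df` function iterates over the given patients, converts each patient's events and constant data into text with `converter.forward_conversion`, and collects the results (together with the patient ID) in a DataFrame.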


In [6]:
def generate_transformers_df(patientids_list):
    df = []

    for patientid in patientids_list:
        patient_data = dm.get_patient_data(patientid)

        p_converted = converter.forward_conversion(
            events=patient_data["events"], 
            constant=patient_data["constant"]
        )
        new_data = {
            "text": p_converted["text"],
            "patientid": f"{patientid}",  # Just for ease of finding later
        }
        df.append(new_data)

    df = pd.DataFrame(df)
    return df

In [7]:
# Generate training and validation dfs
df_train = generate_transformers_df(training_patientids)
df_validation = generate_transformers_df(validation_patientids)

In [8]:
df_train.head()

|   | text | patientid |
|---|------|-----------|
| 0 | The following is a patient, starting with the ... | VN010001 |
| 1 | The following is a patient, starting with the ... | VN010003 |
| 2 | The following is a patient, starting with the ... | VN010004 |
| 3 | The following is a patient, starting with the ... | VN010005 |
| 4 | The following is a patient, starting with the ... | VN010006 |


## Fine-tune LLM

We start by setting up the tokenizer. We set the padding token to be the same as the EOS (End of Sequence) token, which is a common practice for causal language models.
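Padding with the EOS token is safe as long as padded positions are excluded from attention and from the loss. The sketch below shows that mechanism on hypothetical token ids; it is a simplified stand-in for what a data collator does, not TRL's actual collator. It masks labels by *position*: a collator that instead masked every label equal to `pad_token_id` would also hide the genuine EOS at the end of each patient history, so the model would not learn where a history ends.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss


def pad_batch(batch: list[list[int]], pad_id: int) -> dict[str, list[list[int]]]:
    """Right-pad sequences to equal length; build attention masks and labels.

    Padding positions get attention_mask 0 and label IGNORE_INDEX, so the
    choice of pad_id does not affect training. Labels are masked by position,
    so a real EOS that shares its id with pad_id still counts towards the loss.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask, labels = [], [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
        labels.append(seq + [IGNORE_INDEX] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


EOS = 2  # hypothetical EOS id, reused as the pad id
out = pad_batch([[5, 6, 7, EOS], [8, EOS]], pad_id=EOS)
print(out["input_ids"])  # [[5, 6, 7, 2], [8, 2, 2, 2]]
print(out["labels"])     # [[5, 6, 7, 2], [8, 2, -100, -100]]
```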


In [9]:
# Setup tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Set padding token to eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

train_dataset = Dataset.from_pandas(df_train)
validation_dataset = Dataset.from_pandas(df_validation)

Instruction-tuned models usually expect conversational data (e.g. `user`/`assistant` messages). In this pretraining setup, however, each example is a single plain-text patient history, so `format_chat_template` simply passes the `text` field through unchanged and no chat roles are applied.


In [10]:
# Format data for chat template
def format_chat_template(example):
    """Convert prompt/completion pairs to proper prompt/completion format"""
    return {
        "text": example["text"],
    }

# Apply formatting to datasets
train_dataset = train_dataset.map(format_chat_template)
validation_dataset = validation_dataset.map(format_chat_template)


We configure 4-bit quantization using `BitsAndBytesConfig` (QLoRA). Storing the frozen base weights in 4-bit NF4 substantially lowers memory usage compared with 16-bit weights, which makes fine-tuning feasible on a single GPU.
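A rough back-of-the-envelope estimate of the weight memory is sketched below. It assumes about 3.8B parameters for Phi-4-mini (check the model card) and uses the per-parameter overhead of NF4 quantization constants reported in the QLoRA paper (about 0.127 bits with double quantization, `bnb_4bit_use_double_quant=True`). It counts weights only: activations for long sequences, LoRA optimizer state, and layers that bitsandbytes leaves unquantized (for example the embedding layer) all add to the real footprint.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory needed to store n_params weights at the given bit width, in GB (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


# NF4 weights (4 bits) plus quantization-constant overhead with double quantization,
# as reported in the QLoRA paper for blocksize 64.
NF4_DOUBLE_QUANT_BITS = 4 + 0.127

N_PARAMS = 3.8e9  # approximate Phi-4-mini size; an assumption, not read from the model

bf16_gb = weight_memory_gb(N_PARAMS, 16)
nf4_gb = weight_memory_gb(N_PARAMS, NF4_DOUBLE_QUANT_BITS)
print(f"bf16 weights: {bf16_gb:.1f} GB, NF4 + double quant: {nf4_gb:.1f} GB")  # 7.6 GB vs 2.0 GB
```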


In [11]:
# Define Quantization Config (4-bit loading)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # This should be set based on your GPU capabilities
    bnb_4bit_use_double_quant=True,
)

Next we set up the Low-Rank Adaptation (LoRA) configuration. `LoraConfig` defines the adapter parameters (rank `r`, scaling `lora_alpha`, dropout). We target all linear projection layers in the attention and MLP blocks, which generally yields better results than adapting only the query/value projections.
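To get a feel for what `r` and `lora_alpha` mean, the sketch below counts the parameters one adapter adds and computes the standard LoRA scaling factor `lora_alpha / r`. The shapes are taken from the printed `o_proj` layer further down (3072 in, 3072 out, rank 8) over the 32 decoder layers; the helper names are hypothetical.

```python
def lora_trainable_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * in_features + out_features * r


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Factor applied to the adapter output B(A x) in standard LoRA."""
    return lora_alpha / r


per_layer = lora_trainable_params(3072, 3072, r=8)
print(per_layer)            # 49152
print(per_layer * 32)       # 1572864 (o_proj adapters over 32 decoder layers)
print(lora_scaling(16, 8))  # 2.0
```

Doubling `r` doubles the adapter size, and since the update is scaled by `lora_alpha / r`, changing `r` without changing `lora_alpha` also changes how strongly the adapter affects the output.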


In [12]:
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,  # Rank (higher = more parameters to train)
    bias="none",
    task_type="CAUSAL_LM",
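    # Target all linear layers of the attention and MLP blocks. Phi-4-mini uses the
    # Phi-3 architecture with fused qkv_proj and gate_up_proj layers, so Llama-style
    # names (q_proj, k_proj, v_proj, gate_proj, up_proj) would not match anything.
    # Recent PEFT versions also accept target_modules="all-linear".
    target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],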
)

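We define the training arguments in `SFTConfig`. Note the relatively high learning rate (`1e-4`) compared with typical full fine-tuning in the GDT paper; LoRA updates only a small number of adapter weights and is commonly run at higher learning rates. We also set `bf16=True`, which is supported on Ampere and newer GPUs and is usually more numerically stable than `fp16`.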
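With `warmup_ratio=0.1` and `lr_scheduler_type="cosine"`, the learning rate ramps up linearly over the first 10% of optimizer steps and then follows a half cosine down to zero. The plain-Python sketch below reproduces that shape; it mirrors the usual warmup-plus-cosine formula but is not the `transformers` implementation, and rounding the warmup steps up with `math.ceil` is an assumption.

```python
import math


def lr_at_step(step: int, total_steps: int, base_lr: float = 1e-4, warmup_ratio: float = 0.1) -> float:
    """Linear warmup followed by a single half-cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# Learning rate at a few points of a hypothetical 100-step run
for s in (0, 5, 10, 55, 100):
    print(s, lr_at_step(s, total_steps=100))
```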


In [15]:
training_arguments = SFTConfig(
    output_dir="./Result/04_Demo_results",
    num_train_epochs=2,
    per_device_train_batch_size=10,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=10,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,  # LR is higher for PEFT, see TwinWeaver paper for full fine-tuning details
    fp16=False,  # Use fp16 for older GPUs T4/V100, bf16 for Ampere and later (A100/3090/4090)
    bf16=True,
    max_grad_norm=1.0,
    warmup_ratio=0.1,
    group_by_length=True,
    save_total_limit=1,
    lr_scheduler_type="cosine",
    max_length=8192,
    packing=False,  # Keep one patient history per sequence; packing can be enabled to reduce padding
    completion_only_loss=False,  # Compute loss on entire text
)

In [13]:
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
)

# Disable cache for training (required for gradient checkpointing)
model.config.use_cache = False


In [16]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    args=training_arguments,
    eval_dataset=validation_dataset,
    peft_config=peft_config,
)


In [22]:
# Start training (about 12 minutes in the run below; depends on hardware)
trainer.train()

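| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 10 | 4.4163 | |
| 20 | 0.0 | |
| 30 | 0.0 | |
| 40 | 0.0 | |
| 50 | 0.0 | |
| 60 | 0.0 | |
| 70 | 0.0 | |
| 80 | 0.0 | |
| 90 | 0.0 | |
| 100 | 0.0 | |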


TrainOutput(global_step=226, training_loss=0.19541050058550538, metrics={'train_runtime': 730.665, 'train_samples_per_second': 6.175, 'train_steps_per_second': 0.309, 'total_flos': 1.44666853951488e+16, 'train_loss': 0.19541050058550538})

In [17]:
# Save the fine-tuned adapter
adapter_path = "Result/04_Demo_results/final_adapter"
trainer.save_model(adapter_path)
print(f"Adapter saved to {adapter_path}")

del trainer
del model
gc.collect()
torch.cuda.empty_cache()

Adapter saved to Result/04_Demo_results/final_adapter


## Inference example

As an inference example, we take a test-set patient and generate the full patient trajectory after the first line of therapy.

In [18]:
# Get the first test set patient
test_patientid = dm.get_all_patientids_in_split(config.test_split_name)[0]
patient_data = dm.get_patient_data(test_patientid)

# Let's simulate forecasts for after the first line of therapy
df_constant_patient = patient_data["constant"].copy()
df_events_patient = patient_data["events"].copy()
date_of_first_lot = df_events_patient.loc[
    df_events_patient["event_category"] == config.event_category_lot, "date"
].min()
date_of_first_event = df_events_patient["date"].min()

# Without any line-of-therapy events, .min() returns a missing value (NaT) and the
# filter below would drop every event; fall back to the first event date instead
if pd.isna(date_of_first_lot):
    date_of_first_lot = date_of_first_event

# Only keep data until (and including) first line of therapy
df_events_patient = df_events_patient.loc[df_events_patient["date"] <= date_of_first_lot]
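Truncating a history at a cutoff date is a simple filter with one pitfall: the cutoff event may be missing. In pandas, `.min()` over an empty datetime selection yields `NaT`, and comparing a date column with `NaT` gives `False` everywhere, so every event would be dropped. The sketch below shows the same truncation in plain Python over hypothetical event dicts (the category names are made up), with an explicit fallback to the first event date.

```python
from datetime import date


def truncate_history(events: list[dict], cutoff_category: str) -> tuple[list[dict], date]:
    """Keep events up to and including the first event of `cutoff_category`.

    If no such event exists, fall back to the first event date so the prompt
    still contains the first visit instead of nothing.
    """
    cutoff_dates = [e["date"] for e in events if e["event_category"] == cutoff_category]
    cutoff = min(cutoff_dates) if cutoff_dates else min(e["date"] for e in events)
    return [e for e in events if e["date"] <= cutoff], cutoff


events = [
    {"date": date(2020, 1, 1), "event_category": "visit"},
    {"date": date(2020, 2, 1), "event_category": "lot"},
    {"date": date(2020, 3, 1), "event_category": "visit"},
]
kept, cutoff = truncate_history(events, "lot")           # keeps the first two events
kept_fb, cutoff_fb = truncate_history(events, "missing")  # falls back to 2020-01-01
```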

We convert the truncated patient history into text. This text serves as the prompt that the model will continue.

In [33]:
# Convert the truncated history into pretraining-style text (used as the prompt)
converted = converter.forward_conversion(
    events=df_events_patient,
    constant=df_constant_patient
)

For inference, we load the base model again (clean slate) to avoid any state from training, and then attach the adapter we trained. `PeftModel` handles the integration of the LoRA weights with the base model.
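`PeftModel` keeps the LoRA matrices separate from the frozen base weights and adds their contribution at run time: for an adapted linear layer the output is `W x + (lora_alpha / r) * B (A x)`. Because that is linear in `x`, the adapter can equivalently be folded into a single weight `W + (lora_alpha / r) * B A`; PEFT exposes this as `merge_and_unload()`, although merging into 4-bit weights involves re-quantization and can change outputs slightly. The toy sketch below checks the equivalence with tiny pure-Python matrices; all numbers are arbitrary.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix, scale: float = 1.0) -> Matrix:
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (out x in)
A = [[0.5, -1.0]]             # LoRA A (r x in), rank r = 1
B = [[2.0], [1.0]]            # LoRA B (out x r)
scaling = 2.0                 # stands in for lora_alpha / r
x = [[3.0], [4.0]]            # one input column vector

# Unmerged: base output plus scaled adapter output
y_unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scale=scaling)
# Merged: fold the adapter into a single weight first
y_merged = matmul(add(W, matmul(B, A), scale=scaling), x)
print(y_unmerged, y_merged)  # [[-7.0], [-1.0]] [[-7.0], [-1.0]]
```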



In [20]:
# 1. Load the Base Model again (clean instance)
adapter_path = "Result/04_Demo_results/final_adapter"
base_model_inference = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,  # Reuse the 4-bit config
    device_map="auto",
    trust_remote_code=False,
)

# 2. Load the Saved Adapter
# This wraps the base model with the fine-tuned LoRA layers
inference_model = PeftModel.from_pretrained(base_model_inference, adapter_path)

# 3. Switch to evaluation mode
inference_model.eval()


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(200064, 3072, padding_idx=199999)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              

In [45]:
# Create text generation pipeline
# Re-enable cache for inference
inference_model.config.use_cache = True
text_gen_pipeline = pipeline("text-generation", model=inference_model, tokenizer=tokenizer)

Device set to use cuda:0


In [46]:
print(converted["text"])

The following is a patient, starting with the demographic data, following visit by visit everything that the patient experienced. All lab codes refer to LOINC codes.

Starting with demographic data:
	Younger cohort=1; Older cohort=0 is Younger cohort,
	Child is in all rounds is yes,
	Child's sex is male,
	Child's first language is vietnamese,
	Child's ethnic group is Kinh,
	Child's religion is none.

On the first visit, the patient experienced the following: 
.



In [53]:
# Generate a continuation with the fine-tuned model. The appended cue asks the
# model to describe the next visit, 3 weeks later.
generated_answer = text_gen_pipeline(
    converted["text"] + "\n\n3 weeks later, the child visited and experienced the following:",
    max_new_tokens=1024,
    return_full_text=False,  # Only return the newly generated text
    do_sample=True,  # Nucleus (top-p) sampling
    temperature=0.7,
    top_p=0.9,
)[0]["generated_text"]

In [54]:
# Show the generated answer
print(generated_answer)

 
	Child experienced fever with temperature of 38.4 degrees Celsius
	Child experienced headache with severity of 4
	Child experienced sore throat with severity of 3
	Child experienced cough with severity of 3
	Child experienced nasal congestion with severity of 2
	Child experienced facial pain with severity of 2
	Child experienced pharyngeal erythema with severity of 2
	Child experienced ear pain with severity of 2
	Child experienced difficulty swallowing with severity of 3
	Child experienced nasal discharge with color yellow and consistency thick
	Child experienced nausea with severity of 3
	Child experienced vomiting with frequency of 2 and duration of 5 minutes
	Child experienced nasal blockage with nasal obstruction score of 2
	Child experienced nasal blockage with nasal congestion score of 2
	Child experienced nasal blockage with nasal discharge score of 2
	Child experienced nasal blockage with facial pain score of 2
	Child experienced nasal blockage with pharyngeal erythema score

The raw text output from the model needs to be parsed back into structured data. `reverse_conversion` handles this and returns a dictionary; its `"events"` entry is inspected below.
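Part of that job is turning relative time phrases back into calendar dates, which is presumably why `reverse_conversion` is given `date_of_first_event` as an anchor. The sketch below is a hypothetical, much-simplified parser written only for the text shape visible in this notebook (an "N weeks later" header followed by tab-indented lines). It is not TwinWeaver's implementation, it does not map descriptions back to event categories or values, and the header pattern and "relative to the previous visit" rule are assumptions.

```python
import re
from datetime import date, timedelta

# Hypothetical header pattern modeled on the cue used in the prompt above
HEADER = re.compile(r"^(\d+) (day|week)s? later\b", re.IGNORECASE)
UNIT_DAYS = {"day": 1, "week": 7}


def parse_visits(text: str, start_date: date) -> list[dict]:
    """Turn visit-by-visit text into rows of {"date", "description"}.

    Each header advances the current date relative to the previous visit;
    each tab-indented line below it becomes one row on that date.
    """
    rows, current = [], start_date
    for line in text.splitlines():
        header = HEADER.match(line.strip())
        if header:
            current += timedelta(days=int(header.group(1)) * UNIT_DAYS[header.group(2).lower()])
        elif line.startswith("\t") and line.strip():
            rows.append({"date": current, "description": line.strip()})
    return rows


sample = (
    "3 weeks later, the child visited and experienced the following:\n"
    "\tChild experienced fever with temperature of 38.4 degrees Celsius\n"
    "\tChild experienced cough with severity of 3\n"
)
rows = parse_visits(sample, date(2020, 1, 1))  # both rows dated 2020-01-22
```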

In [57]:
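# Reverse convert. The visit cue was part of the prompt, not of generated_answer
# (return_full_text=False), so it is re-added to keep the trajectory complete.
# NOTE: passing `dm` as the second argument raised the TypeError below
# ('DataManager' object is not subscriptable); check the signature of
# ConverterPretrain.reverse_conversion in your installed TwinWeaver version.
full_trajectory = (
    converted["text"]
    + "\n\n3 weeks later, the child visited and experienced the following:"
    + generated_answer
)
ret_dict = converter.reverse_conversion(full_trajectory, dm, date_of_first_event)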

TypeError: 'DataManager' object is not subscriptable

In [56]:
ret_dict["events"].head()

NameError: name 'ret_dict' is not defined