<a href="https://colab.research.google.com/github/RubinThomas75/eli5_meditron/blob/main/eli5_meditron_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Meditron Model Fine Tuning

Weights are downloaded from https://huggingface.co/epfl-llm/meditron-7b (the 7B checkpoint used throughout this notebook).

Install Dependencies

In [None]:
!pip install torch transformers accelerate datasets peft bitsandbytes

In [None]:
import os
from google.colab import drive

# Mount Google Drive; the Hugging Face token file is read from it below.
drive.mount('/content/drive')
token_file_path = "/content/drive/MyDrive/hf_read_token.txt"

# "utf-8-sig" transparently drops a leading BOM if the file has one;
# strip() removes surrounding whitespace such as a trailing newline.
with open(token_file_path, mode="r", encoding="utf-8-sig") as token_file:
    token = token_file.read().strip()

# Expose the token through the environment; later cells read HF_TOKEN from it.
os.environ["HF_TOKEN"] = token

### Set up LoRA

In [None]:
import torch
import os
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# All paths point at the mounted Google Drive.
cache_dir = "/content/drive/MyDrive/epfLLM_meditron7b"
train_file = "/content/drive/MyDrive/eli5_medical_train.jsonl"
val_file = "/content/drive/MyDrive/eli5_medical_val.jsonl"

# JSON Lines files loaded as the "train" and "validation" splits.
dataset = load_dataset("json", data_files={"train": train_file, "validation": val_file})

# Load Meditron 7B model and tokenizer
model_name = "epfl-llm/meditron-7b"
hf_token = os.environ["HF_TOKEN"]

# `token=` replaces the deprecated `use_auth_token=` keyword.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, cache_dir=cache_dir, token=hf_token
)
# The tokenization step pads to a fixed length, which requires a pad token.
# Fall back to EOS only when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 8-bit loading is requested through BitsAndBytesConfig; passing
# `load_in_8bit=True` directly to from_pretrained is deprecated.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    token=hf_token,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Recommended PEFT preparation step for training on top of a quantized base model.
model = prepare_model_for_kbit_training(model)

# Apply LoRA to the attention query/value projections only.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def tokenize_function(example, max_length=512):
    """Tokenize the ``"text"`` field and build causal-LM labels.

    Uses the notebook-level ``tokenizer``. Every sequence is truncated and
    padded to ``max_length`` tokens. Labels are a copy of ``input_ids`` in
    which padded positions (attention mask 0) are replaced by ``-100``, the
    label value ignored by the Hugging Face causal-LM loss, so the model is
    not trained to predict padding.

    Args:
        example: Mapping with a ``"text"`` entry; either a single string or a
            list of strings (as passed by ``Dataset.map(..., batched=True)``).
        max_length: Fixed sequence length to truncate/pad to.

    Returns:
        The tokenizer output with an additional ``"labels"`` entry shaped like
        ``"input_ids"``.
    """
    tokens = tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_attention_mask=True,
    )

    def _mask_padding(ids, mask):
        # Keep real tokens; hide padded positions from the loss.
        return [tok if keep else -100 for tok, keep in zip(ids, mask)]

    input_ids = tokens["input_ids"]
    attention = tokens["attention_mask"]
    if input_ids and not isinstance(input_ids[0], (list, tuple)):
        # A single (un-batched) example yields flat lists.
        tokens["labels"] = _mask_padding(input_ids, attention)
    else:
        tokens["labels"] = [
            _mask_padding(ids, mask) for ids, mask in zip(input_ids, attention)
        ]
    return tokens


# Run the tokenizer over every split in batches; tokenize_function adds the
# "input_ids", "attention_mask" and "labels" columns.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
)



### Run Trainer

In [None]:
# Training arguments
# Effective batch size per device is 2 * 4 = 8 (batch size x accumulation steps).
# Checkpoints and logs are written to Google Drive; only the two most recent
# checkpoints are kept.
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/meditron-lora-checkpoints",
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`.
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    logging_dir="/content/drive/MyDrive/logs",
    save_total_limit=2,
    learning_rate=2e-5,
    fp16=True,
    push_to_hub=False,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # `processing_class` supersedes the deprecated `tokenizer=` argument
    # (requires a transformers release that provides it).
    processing_class=tokenizer,
)


In [None]:
trainer.train()
# Report completion right after training, so a later save failure does not
# hide the fact that training itself finished.
print("Training completed.")

# Save the model and the tokenizer side by side in one Drive folder.
# NOTE(review): `model` is expected to be the PEFT-wrapped model from the LoRA
# setup cell, in which case only the adapter weights are written here.
adapter_dir = "/content/drive/MyDrive/meditron-lora"
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print("Model saved in Google Drive.")

The step above saves only the LoRA adapter weights. To use the fine-tuned model, load the base model and merge the adapter into it.

In [None]:
import os

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_model_name = "epfl-llm/meditron-7b"
adapter_path = "/content/drive/MyDrive/meditron-lora"
# Reuse the Drive cache used when the base model was first loaded, so the
# weights are not downloaded again into the default cache.
cache_dir = "/content/drive/MyDrive/epfLLM_meditron7b"
# None is acceptable here: the library then falls back to its default auth handling.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(
    base_model_name, cache_dir=cache_dir, token=hf_token
)
# 8-bit loading is requested through BitsAndBytesConfig; passing
# `load_in_8bit=True` directly to from_pretrained is deprecated.
# NOTE(review): if the training model is still resident on the GPU, this loads
# a second copy of the base model — confirm there is enough memory.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    cache_dir=cache_dir,
    token=hf_token,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Load LoRA adapter and merge
model = PeftModel.from_pretrained(model, adapter_path)
model = model.merge_and_unload()


### Run inference

In [None]:
# Stop token?
# No dedicated pad token is configured here, so EOS doubles as padding.
model.config.pad_token_id = model.config.eos_token_id

system_message = "[SYSTEM]: You are a helpful medical assistant who explains complex topics in simple terms."
user_question = "What is a headache?"
input_text = f"{system_message}\n[USER]: {user_question}\n[ASSISTANT]:"

# Calling the tokenizer (instead of .encode) also returns the attention mask.
# Inputs are moved to the model's own device: the model was loaded with
# device_map="auto", and no `device` variable exists yet at this point of the
# notebook on a fresh top-to-bottom run.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]

# Generate response
# Beam search only: top_p/top_k were dropped because they have no effect
# unless sampling (do_sample=True) is enabled.
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_length=300,  # counts prompt tokens plus generated tokens
        num_beams=3,
        early_stopping=True,
        no_repeat_ngram_size=3,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.eos_token_id,
    )

# The returned sequence starts with the prompt tokens; slice them off by length
# rather than string-replacing the prompt, which breaks if decoding does not
# reproduce the prompt text exactly.
prompt_length = input_ids.shape[-1]
decoded_output = tokenizer.decode(
    output_ids[0][prompt_length:], skip_special_tokens=True
)

print("\n=== Model Response ===")
print(decoded_output.strip())

### Legacy: playground and base-model inference


In [None]:
# Load EPFL LLM
# NOTE: this rebinds `tokenizer` and `model`, replacing any fine-tuned model
# loaded by earlier cells with the plain base checkpoint.
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "epfl-llm/meditron-7b"
cache_dir = "/content/drive/MyDrive/epfLLM_meditron7b"

# `token=` replaces the deprecated `use_auth_token=` keyword.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    token=os.environ["HF_TOKEN"],
)

# NOTE(review): no dtype or quantization is requested here, so the weights are
# loaded at the library's default precision — confirm this fits in memory.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    token=os.environ["HF_TOKEN"],
)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model loaded on device: {device}")

In [None]:
# Define stop tokens (e.g., ### or newline-based termination)
stop_token = "###"
# NOTE(review): only the first token id of the encoded stop string is used; if
# the tokenizer splits "###" into several pieces this id may not represent the
# whole marker — verify against the tokenizer.
stop_token_id = tokenizer.encode(stop_token, add_special_tokens=False)[0]

# Quick Test Inference
input_text = "A child asks: 'What is a headache?' Answer in a way a 5-year-old would understand. Example: 'A headache is when your head feels tight or ouchy. Sometimes it happens when you're too tired or didn't drink enough water.' Now, your answer:"

# Calling the tokenizer returns input ids together with a matching integer
# attention mask, replacing the hand-built float mask.
inputs = tokenizer(input_text, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# No dedicated pad token is configured here, so EOS doubles as padding.
model.config.pad_token_id = model.config.eos_token_id

# Generate the response with stop token enforcement.
# Beam search only: top_p/top_k were dropped because they have no effect
# unless sampling (do_sample=True) is enabled.
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=200,         # Prompt + generated tokens; adjust for longer/shorter responses
        num_beams=5,            # Increase for more exhaustive search
        early_stopping=True,
        no_repeat_ngram_size=2,
        # Stop at the custom marker OR at the model's regular EOS token.
        eos_token_id=[stop_token_id, model.config.eos_token_id],
        pad_token_id=model.config.eos_token_id,
    )

# Decode only the newly generated tokens; the returned sequence starts with
# the prompt, which is not part of the model's response.
prompt_length = input_ids.shape[-1]
decoded_output = tokenizer.decode(
    output_ids[0][prompt_length:], skip_special_tokens=True
)

# Truncate response at the stop token if it appears
if stop_token in decoded_output:
    decoded_output = decoded_output.split(stop_token)[0]

print("\n=== Model Response ===")
print(decoded_output)
