# In [None]:
import pandas as pd
from datasets import Dataset
from unsloth import FastLanguageModel, is_bfloat16_supported
from transformers import TrainingArguments, Trainer
from peft import LoraConfig
import torch

# Sequence-length budget handed to the model loader, and the switch that
# asks the loader for 4-bit weights to reduce memory usage.
max_seq_length = 2048
load_in_4bit = True

# Choose the half-precision dtype from what the GPU supports:
# bfloat16 when available, float16 as the fallback.
if is_bfloat16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float16

# -------------------------------------------------------------------
# 1. LOAD DATA
# -------------------------------------------------------------------
# Only these two columns are kept from each CSV; rows with a missing value
# in either of them are dropped.
_TEXT_COLUMNS = ["description", "transcription"]

train_df = pd.read_csv("medical_cases_train/medical_cases_train.csv")
train_df = train_df[_TEXT_COLUMNS].dropna()

val_df = pd.read_csv("medical_cases_validation/medical_cases_validation.csv")
val_df = val_df[_TEXT_COLUMNS].dropna()

test_df = pd.read_csv("medical_cases_test/medical_cases_test.csv")
test_df = test_df[_TEXT_COLUMNS].dropna()

# Wrap each cleaned frame in a Hugging Face Dataset, in train/val/test order.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, val_df, test_df)
)

# -------------------------------------------------------------------
# 2. FORMAT PROMPTS
# -------------------------------------------------------------------
def format_prompt(example):
    """Render one description/transcription pair as a Llama 3 chat exchange.

    The model fine-tuned by this script is a Llama 3.2 Instruct checkpoint,
    so the training text uses the Llama 3 chat markers
    (``<|start_header_id|>``, ``<|end_header_id|>``, ``<|eot_id|>``) rather
    than Gemma-style ``<start_of_turn>`` markers, which are not the chat
    markers this model family uses.

    Args:
        example: Mapping with a ``"description"`` entry (used as the user
            turn) and a ``"transcription"`` entry (used as the assistant
            turn).

    Returns:
        A dict with a single ``"text"`` key holding the formatted exchange.
    """
    # Trim surrounding whitespace so the markers sit directly against the
    # message content.
    user_turn = str(example["description"]).strip()
    assistant_turn = str(example["transcription"]).strip()

    # NOTE: ``<|begin_of_text|>`` is deliberately not written here; the
    # tokenizer is expected to prepend its BOS token when encoding, and
    # adding it in the text would duplicate it -- confirm against the
    # tokenizer configuration in use.
    text = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_turn}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        f"{assistant_turn}<|eot_id|>"
    )
    return {"text": text}

# Run format_prompt over every split, in train/val/test order, and rebind
# each dataset name to the mapped result.
train_dataset, val_dataset, test_dataset = (
    split.map(format_prompt)
    for split in (train_dataset, val_dataset, test_dataset)
)

# -------------------------------------------------------------------
# 3. LOAD MODEL
# -------------------------------------------------------------------
model_name = "unsloth/Llama-3.2-3B-Instruct"

# Gather the loader options first so the call itself stays a one-liner.
_loader_options = {
    "model_name": model_name,
    "max_seq_length": max_seq_length,
    "dtype": dtype,
    "load_in_4bit": load_in_4bit,
}
model, tokenizer = FastLanguageModel.from_pretrained(**_loader_options)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# -------------------------------------------------------------------
# 4. APPLY LoRA
# -------------------------------------------------------------------
# Names of the modules that are passed as LoRA targets.
_lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Wrap the loaded model with LoRA adapters: rank 16, alpha 16, no dropout,
# no bias terms, and a fixed random_state for reproducibility.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=_lora_target_modules,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# -------------------------------------------------------------------
# 5. TOKENIZATION
# -------------------------------------------------------------------
def tokenize(example, max_length=512):
    """Tokenize one formatted example and build its training labels.

    The text is truncated/padded to exactly ``max_length`` tokens. Labels
    mirror ``input_ids`` on real tokens and are set to ``-100`` on padding,
    so padded positions do not contribute to the loss. This matters here
    because the pad token is the same as the EOS token: copying
    ``input_ids`` verbatim would train the model on long runs of padding.

    Args:
        example: Mapping with a ``"text"`` entry to encode.
        max_length: Fixed length every example is padded/truncated to.
            Defaults to 512, the value this script has always used.

    Returns:
        The tokenizer output with an added ``"labels"`` list of the same
        length as ``"input_ids"``.
    """
    tokens = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    # Use the attention mask (not the token id) to find padding: the pad
    # token shares its id with EOS, so matching on id could also hide a
    # genuine end-of-sequence token.
    tokens["labels"] = [
        token_id if is_real else -100
        for token_id, is_real in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

# Tokenize every split in train/val/test order. Each split's pre-existing
# columns are removed, leaving only what tokenize() returns.
train_dataset, val_dataset, test_dataset = (
    split.map(tokenize, remove_columns=split.column_names)
    for split in (train_dataset, val_dataset, test_dataset)
)

# -------------------------------------------------------------------
# 6. TRAINING ARGUMENTS
# -------------------------------------------------------------------
# Hyperparameters for the Trainer run.
training_args = TrainingArguments(
    output_dir="./llama-lora-medical",
    per_device_train_batch_size=2,
    # Keep evaluation batches the same size as training batches instead of
    # relying on the library default.
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    num_train_epochs=6,
    learning_rate=2e-4,
    # Exactly one of fp16/bf16 is enabled, following the dtype chosen above.
    fp16=(dtype == torch.float16),
    bf16=(dtype == torch.bfloat16),
    logging_dir="./logs",
    logging_steps=10,
    # A validation set is passed to the Trainer, but evaluation is off unless
    # a strategy is set; evaluate once per epoch, alongside each checkpoint.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none",
)

# -------------------------------------------------------------------
# 7. TRAINER
# -------------------------------------------------------------------
# Assemble the Trainer from the LoRA-wrapped model, the training arguments
# and the tokenized train/validation splits.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Start fine-tuning.
trainer.train()

# -------------------------------------------------------------------
# 8. SAVE MODEL
# -------------------------------------------------------------------
# Write the model and the tokenizer into the same directory.
_adapter_dir = "./llama-lora-medical"
model.save_pretrained(_adapter_dir)
tokenizer.save_pretrained(_adapter_dir)
