# In [None]:
# Unsloth must be imported before transformers/peft so its patches are applied.
from unsloth import FastLanguageModel

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import TrainingArguments, Trainer, EvalPrediction

# -------------------------------------------------------------------
# 1. LOAD DATA
# -------------------------------------------------------------------
def _load_split(split):
    """Read the ``medical_cases_<split>`` CSV and keep the prompt/response pair.

    Only the ``description`` and ``transcription`` columns are kept, and rows
    missing either value are dropped.
    """
    path = f"medical_cases_{split}/medical_cases_{split}.csv"
    return pd.read_csv(path)[["description", "transcription"]].dropna()


train_df = _load_split("train")
val_df = _load_split("validation")
test_df = _load_split("test")

# preserve_index=False: after dropna() the pandas index has gaps, and keeping it
# would add a stray "__index_level_0__" column to every Dataset.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)
test_dataset = Dataset.from_pandas(test_df, preserve_index=False)

# -------------------------------------------------------------------
# 2. FORMAT PROMPTS
# -------------------------------------------------------------------
def format_prompt(example):
    """Render one record as a single chat-style training string.

    Args:
        example: mapping with ``description`` (user turn) and
            ``transcription`` (model turn) entries.

    Returns:
        ``{"text": <formatted conversation>}``.
    """
    # NOTE(review): these <start_of_turn>/<end_of_turn> tags look like a Gemma
    # chat template; confirm they suit the tokenizer of the model being tuned.
    description = example["description"]
    transcription = example["transcription"]
    text = f"<start_of_turn>user\n{description}\n<end_of_turn>\n<start_of_turn>model\n{transcription}<end_of_turn>"
    return dict(text=text)

# Add a "text" column holding the rendered conversation to every split.
train_dataset, val_dataset, test_dataset = (
    split.map(format_prompt)
    for split in (train_dataset, val_dataset, test_dataset)
)

# -------------------------------------------------------------------
# 3. LOAD MODEL
# -------------------------------------------------------------------
# DeepSeek publishes its 7B R1 model as the Qwen-based distillation; there is
# no "deepseek-ai/DeepSeek-R1-7B" repository to download.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=512,  # keep in sync with max_length used when tokenizing
    dtype=None,  # let Unsloth pick the compute dtype for the current GPU
    load_in_4bit=True,  # 4-bit base weights; only LoRA adapters are trained
)
# Padding reuses EOS, so padded positions cannot be identified by token id and
# should be masked out of the labels via the attention mask.
tokenizer.pad_token = tokenizer.eos_token

# -------------------------------------------------------------------
# 4. APPLY LoRA
# -------------------------------------------------------------------
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Attach the adapters through Unsloth instead of `model.add_adapter`. For a
# model loaded via FastLanguageModel (here in 4-bit), get_peft_model is the
# documented path: it prepares the quantized base for training and uses
# Unsloth's LoRA implementation. The hyperparameters come from lora_config above.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_config.r,
    lora_alpha=lora_config.lora_alpha,
    # LoraConfig may store target_modules as a set; pass a stable list.
    target_modules=sorted(lora_config.target_modules),
    lora_dropout=lora_config.lora_dropout,
    bias=lora_config.bias,
)
FastLanguageModel.for_training(model)

# -------------------------------------------------------------------
# 5. TOKENIZATION
# -------------------------------------------------------------------
def tokenize(example):
    """Tokenize one formatted example into fixed-length causal-LM features.

    The ``text`` field is padded/truncated to 512 tokens. ``labels`` mirrors
    ``input_ids`` except that padding positions are set to -100, the label
    value ignored by the loss (and by ``compute_metrics``). This keeps the
    model from being trained or scored on padding.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) plus
        ``labels``.
    """
    # Append EOS so the model learns where a response ends. The trailing
    # <end_of_turn> tag is presumably plain text to this tokenizer; verify.
    tokens = tokenizer(
        example["text"] + tokenizer.eos_token,
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    # pad_token is set to eos_token, so padding can't be recognised by token
    # id; use the attention mask (0 = padding) to decide what to ignore.
    tokens["labels"] = [
        token_id if attended else -100
        for token_id, attended in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

# Replace each split's columns with tokenized features. Each split's own column
# list is read before its map runs, so every original column is dropped.
train_dataset, val_dataset, test_dataset = (
    split.map(tokenize, remove_columns=split.column_names)
    for split in (train_dataset, val_dataset, test_dataset)
)

# -------------------------------------------------------------------
# 6. TRAINING ARGUMENTS
# -------------------------------------------------------------------
# The model is loaded with dtype=None, which (per Unsloth) auto-selects bf16 on
# GPUs that support it. Forcing fp16 would clash with a bf16 model, so mixed
# precision follows hardware support instead.
# NOTE(review): no evaluation strategy is set, so val_dataset is handed to the
# Trainer but not evaluated during training — confirm that is intended.
training_args = TrainingArguments(
    output_dir="./gemma-lora-medical",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch of 8 per device
    warmup_steps=10,
    num_train_epochs=6,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    save_total_limit=2,
    report_to="none",
)


# -------------------------------------------------------------------
# 7. METRICS FUNCTION
# -------------------------------------------------------------------
def compute_metrics(eval_pred: EvalPrediction):
    """Token-level accuracy and macro precision/recall/F1 for a causal LM.

    ``eval_pred.predictions`` may be raw logits of shape (batch, seq, vocab),
    or token ids of shape (batch, seq) when ``preprocess_logits_for_metrics``
    has already reduced them. Positions whose label is -100 are ignored.

    Returns:
        dict with "accuracy", "precision", "recall" and "f1".
    """
    preds = eval_pred.predictions
    if isinstance(preds, tuple):  # model outputs can be (logits, *extras)
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim == 3:  # raw logits -> predicted token ids
        preds = preds.argmax(-1)
    labels = np.asarray(eval_pred.label_ids)

    # A causal LM's output at position i predicts token i + 1, so compare each
    # prediction with the *next* label (the training loss applies the same shift).
    preds = preds[:, :-1]
    labels = labels[:, 1:]

    # Flatten and drop ignored (-100) positions in one vectorized step.
    keep = labels != -100
    true_labels = labels[keep]
    pred_labels = preds[keep]

    return {
        "accuracy": accuracy_score(true_labels, pred_labels),
        "precision": precision_score(true_labels, pred_labels, average='macro', zero_division=0),
        "recall": recall_score(true_labels, pred_labels, average='macro', zero_division=0),
        "f1": f1_score(true_labels, pred_labels, average='macro', zero_division=0),
    }

# -------------------------------------------------------------------
# 8. TRAINER
# -------------------------------------------------------------------
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before the Trainer stores them.

    Without this hook the Trainer accumulates full (batch, seq, vocab) logits
    for the whole eval set, which is vocab-size times larger than the ids and
    quickly exhausts memory. ``labels`` is unused but required by the hook's
    signature.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()

# -------------------------------------------------------------------
# 9. FINAL TEST EVALUATION
# -------------------------------------------------------------------
# Score the held-out test split and print each metric on its own line.
print("\n=== Final Evaluation on Test Set ===")
test_results = trainer.evaluate(eval_dataset=test_dataset)
for metric_name, metric_value in test_results.items():
    print(f"{metric_name}: {metric_value:.4f}")

# -------------------------------------------------------------------
# 10. SAVE MODEL
# -------------------------------------------------------------------
# Write the trained model and its tokenizer to the same directory so they can
# be reloaded together.
save_dir = "./gemma-lora-medical"
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_dir)
