In [None]:
!pip install datasets

In [None]:
!pip install bitsandbytes

In [None]:
# Mount Google Drive: the QA dataset, checkpoints and final model paths used
# in this notebook all live under /content/drive/MyDrive.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# 📦 Imports
import os
import json
import torch
import wandb
from datasets import Dataset
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [None]:
# Log in to W&B from Python. Unlike `wandb login --relogin`, this reuses an
# existing key (e.g. from WANDB_API_KEY, per the wandb docs) and only prompts
# when none is found, so Restart & Run All does not block on input every time.
wandb.login()

In [None]:
# 🧪 Init W&B
# Credentials come from the login step / WANDB_API_KEY environment variable.
# Never paste API keys into the notebook, not even in comments: saved notebooks leak them.
wandb.init(
    project="cardiology-expert-medlineplus-sft",
    name="llama3-8b-ultramedical-cardiology-expert-v1",
    tags=["llama3-8b-ultramedical", "sft", "cardiology", "medical"],
    notes="SFT of llama3-8b-ultramedical for cardiology expertise"
)

In [None]:
# 💾 Save path: used as the Trainer output_dir, so periodic checkpoints are written
# into checkpoint-* subfolders of this directory.
model_path = '/content/drive/MyDrive/medmoe/checkpoints/cardiology_llama3_8b_expert_model'

# 🔧 Hyperparameters (logged to W&B so each run records its configuration)
wandb_config = {
    "model_name": "TsinghuaC3I/Llama-3-8B-UltraMedical",
    "learning_rate": 2e-4,
    "epochs": 20,
    "batch_size": 16,  # per device
    "gradient_accumulation_steps": 8,
    "lora_r": 16,
    "lora_alpha": 32,
    "medical_domain": "cardiology",
    "seed": 42,  # reproducibility seed, logged with the rest of the config
    # True: load previously saved weights from model_path when they exist.
    # False: always start from the pretrained "model_name" checkpoint on the Hub
    # (this is NOT training from scratch).
    "load_pretrained": True,
}
wandb.config.update(wandb_config)

In [None]:
# 🧠 Quantization: 4-bit NF4 weights with nested ("double") quantization of the
# quantization constants; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    load_in_4bit=True,
)

In [None]:
# 🧾 Tokenizer, loaded from the same Hub id as the base model.
tokenizer = AutoTokenizer.from_pretrained(
    wandb_config["model_name"],
)
# Fixed-length padding later on needs a pad token; reuse EOS for it
# (presumably the checkpoint ships without a dedicated pad token; verify).
tokenizer.pad_token = tokenizer.eos_token


In [None]:
# 🧠 Model
# model_path is the Trainer output_dir: it exists after any training run but normally
# holds only checkpoint-* subfolders, so os.path.exists() alone would send
# from_pretrained() to a directory without a model in it. Resume only when the
# directory itself contains saved model or adapter files.
_saved_model_markers = ("config.json", "adapter_config.json")
_can_resume = wandb_config["load_pretrained"] and any(
    os.path.isfile(os.path.join(model_path, marker)) for marker in _saved_model_markers
)
# NOTE(review): the final trainer.save_model() in this notebook writes to a different
# directory than model_path; point model_path there if resuming from the final model is intended.
# NOTE(review): if model_path holds only a LoRA adapter, get_peft_model below adds a fresh
# adapter on top instead of continuing to train the saved one; confirm this is intended.
model_source = model_path if _can_resume else wandb_config["model_name"]

model = AutoModelForCausalLM.from_pretrained(
    model_source,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
)

model = prepare_model_for_kbit_training(model)

In [None]:
# 🧪 LoRA Config: adapters on every attention projection (q/k/v/o) and the three
# MLP projections (gate/up/down); rank and alpha come from wandb_config.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=wandb_config["lora_r"],
    lora_alpha=wandb_config["lora_alpha"],
    lora_dropout=0.05,
    bias="none",
    target_modules=[f"{name}_proj" for name in ("q", "k", "v", "o", "gate", "up", "down")],
)
model = get_peft_model(model, lora_config)

In [None]:
# 📂 Load Dataset
# Each topic's 'question_answer_pair' entries are unpacked as (question, answer) pairs.
QA_PATH = "/content/drive/MyDrive/medmoe/blood_heart_circulation_qa.json"
PROMPT_PREFIX = "Answer this question about cardiology health: "

with open(QA_PATH, "r", encoding="utf-8") as f:
    qa_data = json.load(f)

train_data = [
    {"text": PROMPT_PREFIX + question, "reference": answer}
    for topic in qa_data
    for question, answer in topic['question_answer_pair']
]
assert train_data, f"No question/answer pairs found in {QA_PATH}"

# Seeded split so the held-out 10% is identical on every run (an unseeded split
# reshuffles, making runs and their eval sets incomparable).
dataset = Dataset.from_list(train_data).train_test_split(
    test_size=0.1, seed=wandb_config.get("seed", 42)
)

In [None]:
# 🔁 Tokenize
MAX_LENGTH = 512


def tokenize(example, max_length=MAX_LENGTH):
    """Tokenize a batch of prompts and their reference answers for SemanticTrainer.

    Args:
        example: batch dict (used with ``batched=True``) holding ``text`` (prompts)
            and ``reference`` (answers) string lists.
        max_length: pad/truncate both prompt and reference to this many tokens.

    Returns:
        The prompt encoding (``input_ids``, ``attention_mask``) plus ``labels``:
        the reference answer token ids. These are padded with ``tokenizer.pad_token_id``,
        not -100, because the trainer rebuilds the reference mask from that pad id.
    """
    # Plain Python lists are enough: Dataset.map stores results in Arrow anyway,
    # so return_tensors="pt" only added a tensor -> list round trip.
    encode_kwargs = dict(padding="max_length", truncation=True, max_length=max_length)
    model_inputs = tokenizer(example["text"], **encode_kwargs)
    labels = tokenizer(example["reference"], **encode_kwargs)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


tokenized = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names
)

In [None]:
class SemanticTrainer(Trainer):
    """Trainer whose loss is ``1 - cosine similarity`` of mean-pooled embeddings.

    The "prediction" embedding is the mean-pooled last hidden state of the model on
    the prompt (``input_ids``). The target embedding is the mean-pooled last hidden
    state of the same model on the reference answer ids stored in ``labels``,
    computed without gradients. No text is generated or decoded.

    Args:
        tokenizer: forwarded to ``Trainer`` as ``processing_class``; its
            ``pad_token_id`` marks padding inside ``labels``.
    """

    def __init__(self, tokenizer, *args, **kwargs):
        kwargs["processing_class"] = tokenizer
        super().__init__(*args, **kwargs)
        # The only batch columns compute_loss consumes.
        self._signature_columns = ['input_ids', 'attention_mask', 'labels']

    @staticmethod
    def _mean_pool(hidden, mask):
        """Average ``hidden`` over the sequence dim where ``mask`` is nonzero, in float32."""
        # float32 avoids fp16 precision loss when summing up to max_length positions.
        weights = mask.unsqueeze(-1).to(torch.float32)
        summed = (hidden.to(torch.float32) * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return ``1 - mean cosine similarity`` between prompt and reference embeddings.

        Args:
            model: the model being trained; called with ``output_hidden_states=True``.
            inputs: batch dict with ``input_ids``, optional ``attention_mask`` and ``labels``.
            return_outputs: when True, also return the prompt forward-pass outputs.
            num_items_in_batch: accepted for Trainer compatibility; unused.

        Raises:
            ValueError: if the batch has no ``labels`` (the loss needs a reference).
        """
        device = model.device
        input_ids = inputs["input_ids"].to(device)

        attention_mask = inputs.get("attention_mask")
        # Without an explicit mask, every prompt position counts.
        attention_mask = (
            torch.ones_like(input_ids) if attention_mask is None else attention_mask.to(device)
        )

        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("SemanticTrainer.compute_loss requires 'labels' (reference answer ids).")
        labels = labels.to(device)

        # Prompt forward pass with gradients; this is what the loss trains.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        prompt_embeddings = self._mean_pool(outputs.hidden_states[-1], attention_mask)

        # Target embedding: same model on the reference answer, without gradients.
        with torch.no_grad():
            # processing_class is the tokenizer passed in __init__; self.tokenizer is a
            # deprecated alias in recent transformers releases.
            ref_attention_mask = labels != self.processing_class.pad_token_id
            ref_outputs = model(
                input_ids=labels,
                attention_mask=ref_attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            ref_embeddings = self._mean_pool(ref_outputs.hidden_states[-1], ref_attention_mask)

        sim = F.cosine_similarity(prompt_embeddings, ref_embeddings, dim=-1)
        loss = 1 - sim.mean()

        return (loss, outputs) if return_outputs else loss

In [None]:
# ⚙️ Training Args
# No evaluation is scheduled (eval_strategy is left at its default), so eval_dataset
# below is only used by explicit trainer.evaluate() calls. For the same reason no
# metric_for_best_model is set: there are no eval metrics to rank checkpoints by.
# NOTE(review): before enabling evaluation, check memory. compute_loss(return_outputs=True)
# returns the full forward outputs, including every hidden state.
training_args = TrainingArguments(
    output_dir=model_path,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=3,
    per_device_train_batch_size=wandb_config["batch_size"],
    per_device_eval_batch_size=wandb_config["batch_size"],
    gradient_accumulation_steps=wandb_config["gradient_accumulation_steps"],
    num_train_epochs=wandb_config["epochs"],
    learning_rate=wandb_config["learning_rate"],
    seed=wandb_config.get("seed", 42),
    # Pass batches through untouched; the tokenized dataset only keeps the
    # columns SemanticTrainer.compute_loss reads.
    remove_unused_columns=False,
    logging_dir="./logs",
    logging_steps=10,
    fp16=True,
    report_to="wandb",
)

trainer = SemanticTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,
)

In [None]:
import os
import warnings

# CUDA_LAUNCH_BLOCKING is read when the CUDA context is created. Earlier cells have
# likely already put the model on the GPU, in which case setting it here does nothing.
# To debug asynchronous CUDA errors, set it before any CUDA work (e.g. the first cell)
# and restart the runtime.
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialized; CUDA_LAUNCH_BLOCKING=1 set now will not take effect. "
        "Set it before any CUDA call and restart the runtime.",
        RuntimeWarning,
    )
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# 🚀 Train
trainer.train()

# Final artifacts go to a separate "model" folder; periodic checkpoints stay in model_path.
# The model is PEFT-wrapped (get_peft_model), so save_model is expected to write the
# LoRA adapter rather than merged full weights.
final_model_dir = '/content/drive/MyDrive/medmoe/model/cardiology_llama3_8b_expert_model'
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

# Close the W&B run so final metrics are flushed and the run is marked finished.
wandb.finish()