In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [2]:
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.0->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# 📦 Imports
import os
import json
import torch
import wandb
from datasets import Dataset
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [5]:
!wandb login --relogin

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [29]:
# 🧪 Init W&B
wandb.init(
    project="orthopedic-expert-medlineplus-sft",
    name="llama3-8b-ultramedical-orthopedic-expert-v1",
    tags=["llama3-8b-ultramedical", "sft", "orthopedic", "medical"],
    notes="SFT of llama3-8b-ultramedical for orthopedic expertise"
)
#key 3114d04ef3f8187e6f6852dd28ede0fa5a2ec32c

In [30]:
# 💾 Save path
model_path = '/content/drive/MyDrive/medmoe/checkpoints/orthopedic_llama3_8b_expert_model'

# 🔧 Hyperparameters
wandb_config = {
    "model_name": "TsinghuaC3I/Llama-3-8B-UltraMedical",
    "learning_rate": 2e-4,
    "epochs": 20,
    "batch_size": 16,
    "gradient_accumulation_steps": 8,
    "lora_r": 16,
    "lora_alpha": 32,
    "medical_domain": "orthopedic",
    "load_pretrained": True  # Set to False to load model from scratch
}
wandb.config.update(wandb_config)

In [31]:
# 🧠 Quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

In [32]:
# 🧾 Tokenizer
tokenizer = AutoTokenizer.from_pretrained(wandb_config["model_name"])
tokenizer.pad_token = tokenizer.eos_token


In [33]:
# 🧠 Model
if wandb_config["load_pretrained"] and os.path.exists(model_path):
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        wandb_config["model_name"],
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.float16,
    )

model = prepare_model_for_kbit_training(model)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [34]:
# 🧪 LoRA Config
lora_config = LoraConfig(
    r=wandb_config["lora_r"],
    lora_alpha=wandb_config["lora_alpha"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)
model = get_peft_model(model, lora_config)

In [35]:
# 📂 Load Dataset
with open("/content/drive/MyDrive/medmoe/bones_joints_muscles_qa.json", "r") as f:
    qa_data = json.load(f)

train_data = []
for topic in qa_data:
    for question, answer in topic['question_answer_pair']:
        prompt = "Answer this question about orthopedic health: "
        train_data.append({"text": prompt + question,
                           "reference": answer
                           })

dataset = Dataset.from_list(train_data).train_test_split(test_size=0.1)

In [36]:
# 🔁 Tokenize
def tokenize(example):
    model_inputs = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    labels = tokenizer(
        example["reference"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    model_inputs["labels"] = labels['input_ids']
    return model_inputs

tokenized = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names
)

Map:   0%|          | 0/139 [00:00<?, ? examples/s]

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

In [37]:
class SemanticTrainer(Trainer):
    def __init__(self, tokenizer, *args, **kwargs):
        kwargs["processing_class"] = tokenizer
        super().__init__(*args, **kwargs)
        self._signature_columns = ['input_ids', 'attention_mask', 'labels']

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask")
        labels = inputs.get("labels")

        input_ids = input_ids.to(model.device)
        attention_mask = attention_mask.to(model.device) if attention_mask is not None else None
        labels = labels.to(model.device) if labels is not None else None

        # Generate answers
        outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True
        )

        # Decode generated text
        hidden_states = outputs.hidden_states[-1]

        # Mean pooling over non-padding tokens
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size())
        sum_embeddings = torch.sum(hidden_states * input_mask_expanded, dim=1)
        sum_mask = input_mask_expanded.sum(dim=1)
        gen_embeddings = sum_embeddings / torch.clamp(sum_mask, min=1e-9)

        with torch.no_grad():
            ref_attention_mask = (labels != self.tokenizer.pad_token_id).to(model.device)
            ref_outputs = model(
                input_ids=labels,
                attention_mask=ref_attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            ref_hidden = ref_outputs.hidden_states[-1]
            ref_mask_expanded = ref_attention_mask.unsqueeze(-1).expand(ref_hidden.size())
            sum_ref_embeds = torch.sum(ref_hidden * ref_mask_expanded, dim=1)
            sum_ref_mask = ref_mask_expanded.sum(dim=1)
            ref_embeddings = sum_ref_embeds / torch.clamp(sum_ref_mask, min=1e-9)

        # Compute cosine similarity as reward
        sim = F.cosine_similarity(gen_embeddings, ref_embeddings, dim=-1)
        loss = 1 - sim.mean()

        return (loss, outputs) if return_outputs else loss

In [38]:
# ⚙️ Training Args
training_args = TrainingArguments(
    output_dir=model_path,
    save_strategy="steps",
    per_device_train_batch_size=wandb_config["batch_size"],
    per_device_eval_batch_size=wandb_config["batch_size"],
    gradient_accumulation_steps=wandb_config["gradient_accumulation_steps"],
    num_train_epochs=wandb_config["epochs"],
    learning_rate=wandb_config["learning_rate"],
    remove_unused_columns=False,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=10,
    save_total_limit=3,
    fp16=True,
    report_to="wandb",
    metric_for_best_model="loss"
)

trainer = SemanticTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [26]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # Set this before trainer.train() call

In [39]:
# 🚀 Train
trainer.train()
trainer.save_model('/content/drive/MyDrive/medmoe/model/orthopedic_llama3_8b_expert_model')
tokenizer.save_pretrained('/content/drive/MyDrive/medmoe/model/orthopedic_llama3_8b_expert_model')

  return fn(*args, **kwargs)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


Step,Training Loss


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

KeyboardInterrupt: 