In [1]:
import torch
import gc
print("GPU available:", torch.cuda.is_available())
print(f"GPU Memory available: {torch.cuda.get_device_properties(0).total_memory / 1024**2:.2f} MB")

GPU available: True
GPU Memory available: 15095.06 MB


In [2]:
def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"GPU memory cached: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

In [3]:
!pip install unsloth




In [4]:
pip install wandb



In [27]:
import wandb
wandb.login(key="")


In [28]:
from datasets import load_dataset
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT","en")

print("Data Load Successfull")

Data Load Successfull


In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Question', 'Complex_CoT', 'Response'],
        num_rows: 25371
    })
})

In [8]:
import re

def extract_sections(text):
    # Extract the chain of thought
    think_match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    # Extract the final response
    response_match = re.search(r'<response>(.*?)</response>', text, re.DOTALL)
    think_text = think_match.group(1).strip() if think_match else ""
    response_text = response_match.group(1).strip() if response_match else ""
    return think_text, response_text


def preprocess_dataset(example):
    think, response = extract_sections(example["Complex_CoT"])
    return {"think": think, "response": response}

dataset = dataset.map(preprocess_dataset)


In [9]:

train_dataset = dataset["train"].select(range(100, len(dataset["train"])))
val_dataset = dataset["train"].select(range(100))


In [10]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load model with optimized settings
model_name = "unsloth/llama-3-8b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
clear_memory()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

GPU memory allocated: 5441.35 MB
GPU memory cached: 7230.00 MB


In [11]:
pip install peft




In [12]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()
clear_memory()

GPU memory allocated: 5454.35 MB
GPU memory cached: 7242.00 MB


In [13]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./medical-model",
    learning_rate=1e-4,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    max_steps=1000,
    warmup_steps=100,
    fp16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    max_grad_norm=0.3,
    logging_steps=10,
    save_steps=200,
    eval_steps=200,
    optim="paged_adamw_32bit"
)

In [18]:

# Function to generate prompt from a single example
def generate_prompt(example):
    question = example['Question']
    think = example['think']
    response = example['response']

    return f"""Below is a medical question that needs analysis and reasoning. First think through the steps, then provide the final response.
Question: {question}
Think: {think}
Response: {response}"""

# Preprocessing function for batched inputs
def preprocess_function(examples):
    prompts = [
        generate_prompt({
            'Question': q,
            'think': t,
            'response': r
        }) for q, t, r in zip(examples['Question'], examples['think'], examples['response'])
    ]
    model_inputs = tokenizer(prompts, truncation=True, max_length=512, padding='max_length')
    model_inputs['labels'] = model_inputs['input_ids'].copy()
    return model_inputs

# Preprocess the train and validation datasets
tokenized_train = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

tokenized_val = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=val_dataset.column_names
)


clear_memory()

Dataset structure: dict_keys(['Question', 'Complex_CoT', 'Response', 'think', 'response'])


Map:   0%|          | 0/25271 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

GPU memory allocated: 5454.35 MB


In [19]:
from transformers import DataCollatorForLanguageModeling


data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [20]:
from transformers import Trainer, TrainerCallback

class MemoryCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 50 == 0:
            clear_memory()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    callbacks=[MemoryCallback()]
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [27]:
# Start training
trainer.train()

# Save the final model
trainer.save_model("./medical-model-final")
clear_memory()
print("Training Complete")

Training Complete


In [25]:
# Test the model
def generate_response(question, max_length=512):
    prompt = f"Below is a medical question that needs analysis and reasoning. First think through the steps, then provide the final response.\nQuestion: {question}"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    outputs = model.generate(
        **inputs,
        max_length=max_length,
        temperature=0.7,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test example
test_question = "What are the common symptoms of pneumonia?"
print(generate_response(test_question))

Below is a medical question that needs analysis and reasoning. First think through the steps, then provide the final response.
Question: What are the common symptoms of pneumonia? Provide a list of three symptoms.
Think: 
Response: 
