In [1]:
%%capture
# Install necessary packages
%pip install -U transformers 
%pip install -U datasets 
%pip install -U accelerate 
%pip install -U peft 
%pip install -U trl 
%pip install -U bitsandbytes 
%pip install -U wandb  # Install W&B for monitoring
%pip install -U scikit-learn 
%pip install -U flash-attn --no-build-isolation


In [2]:
# Import required libraries
import wandb  # Import W&B for monitoring
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, TrainerCallback
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import torch
import random
from sklearn.metrics import precision_recall_fscore_support, accuracy_score  # Import scikit-learn metrics
from trl import SFTTrainer
from huggingface_hub import login

# Authenticate with Hugging Face
huggingface_token = ""
login(token=huggingface_token)

# Initialize W&B and log in
wandb.login(key="")
wandb.init(project="gemma2-finetuning", entity="adam-fendri")

# Load model and tokenizer
base_model = "google/gemma-2-2b"
new_model = "Gemma-2-2b-it-medical"
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'  # Set padding side to right

# Determine CUDA device capabilities and use eager attention implementation
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    attn_implementation = "eager"  # Fallback to eager attention for stability
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

# QLoRA configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Identify target modules for LoRA
def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if any(isinstance(module, cls) for cls in (torch.nn.Linear, torch.nn.Conv2d)):  # Check for linear layers
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

# LoRA configuration
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules
)

# Setup model for training
model = get_peft_model(model, peft_config)
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


[34m[1mwandb[0m: Currently logged in as: [33madamfendri[0m ([33madam-fendri[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

['o_proj', 'down_proj', 'up_proj', 'k_proj', 'v_proj', 'gate_proj', 'q_proj']
Trainable: 20766720 | total: 2635108608 | Percentage: 0.7881%


In [3]:
# Load the entire dataset
dataset = load_dataset("ruslanmv/ai-medical-chatbot", split="train")

# Shuffle and split dataset
dataset = dataset.shuffle(seed=42)
dataset = dataset.train_test_split(test_size=0.20, seed=42)
train_dataset = dataset['train']
eval_dataset = dataset['test']

# Format the dataset with the chat template
def format_chat_template(row):
    input_text = f"Context: {row['Patient']}\nQuestion: {row['Description']}\nAnswer: {row['Doctor']}\n"
    row["text"] = input_text
    return row

train_dataset = train_dataset.map(format_chat_template, num_proc=2)
eval_dataset = eval_dataset.map(format_chat_template, num_proc=2)

# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

train_dataset = train_dataset.map(tokenize_function, batched=True, num_proc=2)
eval_dataset = eval_dataset.map(tokenize_function, batched=True, num_proc=2)

# Remove raw text columns
train_dataset = train_dataset.remove_columns(['Description', 'Patient', 'Doctor', 'text'])
eval_dataset = eval_dataset.remove_columns(['Description', 'Patient', 'Doctor', 'text'])

# Custom callback to sample different evaluation subsets
class RandomSubsetEvalCallback(TrainerCallback):
    def __init__(self, eval_dataset, subset_size):
        self.eval_dataset = eval_dataset
        self.subset_size = subset_size

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.eval_steps == 0:
            indices = random.sample(range(len(self.eval_dataset)), self.subset_size)
            subset = self.eval_dataset.select(indices)
            trainer.evaluate(eval_dataset=subset)
        return control


In [4]:
# Set training hyperparameters with checkpointing
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=4,  # Increase batch size for RTX 3090
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # Adjust gradient accumulation
    optim="paged_adamw_32bit",
    num_train_epochs=3,  # Increase number of epochs to 3
    eval_strategy="steps",  # Evaluate on steps
    save_strategy="steps",  # Save checkpoints on steps
    eval_steps=1000,  # Evaluate every 1000 steps
    logging_steps=500,  # Log training metrics every 500 steps
    warmup_steps=100,  # Warmup steps for stability
    learning_rate=1e-4,  # Lower learning rate for stability
    fp16=True,  # Use mixed precision training
    bf16=False,
    group_by_length=True,
    save_total_limit=2,  # Keep last 2 checkpoints
    save_steps=2000,  # Save checkpoints every 2000 steps
    report_to="wandb",  # Report to W&B for monitoring
    load_best_model_at_end=True,  # Load best model at the end of training
)

# Initialize the trainer with checkpointing
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    max_seq_length=512,  # Adjusted max_seq_length here
    dataset_text_field="text",  # Use dataset_text_field directly
    tokenizer=tokenizer,
    callbacks=[
        RandomSubsetEvalCallback(eval_dataset, 500),  # Evaluate on 500 rows
    ],
    packing=False,
)

# Start training
model.config.use_cache = False  # Disable cache during training
trainer.train()
model.config.use_cache = True  # Re-enable cache after training

# Push the model to Hugging Face hub
trainer.save_model(new_model)
trainer.push_to_hub(new_model, use_temp_dir=False)



Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Step,Training Loss,Validation Loss
1000,2.485,2.294738
2000,2.3034,2.258224
3000,2.2627,2.226064
4000,2.2401,2.198877
5000,2.2283,2.270314
6000,2.1966,2.201096
7000,2.196,2.24915
8000,2.1809,2.180279
9000,2.1757,2.135563
10000,2.1773,2.13484


TypeError: Trainer.create_model_card() got an unexpected keyword argument 'use_temp_dir'

In [5]:
# Push the model to Hugging Face hub
trainer.save_model(new_model)
trainer.push_to_hub(commit_message="Finetuned Gemma-2 model for medical chat", blocking=True)

# Push the tokenizer to Hugging Face hub
tokenizer.push_to_hub(new_model)


Upload 3 LFS files:   0%|          | 0/3 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/83.1M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.43k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.50k [00:00<?, ?B/s]

No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/adamfendri/Gemma-2-2b-it-medical/commit/cf7ffeb37fd5a566f87795e35e9d0d7353e7f2de', commit_message='Upload tokenizer', commit_description='', oid='cf7ffeb37fd5a566f87795e35e9d0d7353e7f2de', pr_url=None, pr_revision=None, pr_num=None)

In [6]:
# Example inference
model.eval()  # Set model to evaluation mode
prompt = '''You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.'''
question = '''I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, increased sensitivity to cold, and dry, itchy skin. Could these symptoms be related to hypothyroidism? If so, what steps should I take to get a proper diagnosis and discuss treatment options?'''
input_text = prompt + "\n" + question

# Tokenize input text for inference
inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)

# Generate a response from the model
with torch.no_grad():
    output = model.generate(
        inputs['input_ids'],
        max_length=512,  # Set max generation length
        num_beams=5,  # Use beam search with 5 beams
        early_stopping=True,
        no_repeat_ngram_size=2
    )

# Decode and print the generated response
response = tokenizer.decode(output[0], skip_special_tokens=True)
print(response)




You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, increased sensitivity to cold, and dry, itchy skin. Could these symptoms be related to hypothyroidism? If so, what steps should I take to get a proper diagnosis and discuss treatment options?
Hi, Welcome to Health care magic forum.                      As you describe it appears to be the hypothyriodism, as the symptoms are similar to that.                       I advise you to consult an endocrinologist for diagnosis,and treatment. You may need to have thyroid function tests for confirmation.                        Take more of green leafy vegetables, pulses, sprouts, protein rich foods,to have a good health and resistance against infections.                         Wishing for