In [1]:
!pip install -U bitsandbytes 
!pip install -U transformers 
!pip install -U accelerate 
!pip install -U peft
!pip install -U trl
!pip install -U datasets


import os
import torch
import wandb
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    TrainingArguments, 
    logging
)
from peft import LoraConfig, get_peft_model
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login
from trl import SFTTrainer, setup_chat_format
import bitsandbytes as bnb

In [2]:
user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("huggingface")
login(token=hf_token)
wb_token = user_secrets.get_secret("wandb")

wandb.login(key=wb_token)
run = wandb.init(project='Fine-tune Gemma-2-2b-it Doctor', job_type="training", anonymous="allow")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvattvoltamper[0m ([33mvattvoltamper-ustudy[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111304434444441, max=1.0)…

In [4]:
base_model = "google/gemma-2-2b-it"
new_model = "Gemma-2-2b-it-ChatDoctor"
dataset_name = "lavita/ChatDoctor-HealthCareMagic-100k"

if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2" # !pip install -qqq flash-attn
else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)


In [None]:
!pip install trl

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            if len(names) == 1:
                lora_module_names.add(names[0])
            else:
                lora_module_names.add(names[-1])
    lora_module_names.discard('lm_head')  
    return list(lora_module_names)

modules = find_all_linear_names(model)

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [6]:
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules
)

tokenizer.chat_template = None 

model, tokenizer = setup_chat_format(model, tokenizer)
model = get_peft_model(model, peft_config)


In [7]:
import re
from datasets import load_dataset


dataset = load_dataset(dataset_name, split="all", cache_dir="./cache")
dataset = dataset.shuffle(seed=42).select(range(2000))  

def clean_text(text):
    text = re.sub(r'\b(?:www\.[^\s]+|http\S+)', '', text)                   # Remove URLs
    text = re.sub(r'\b(?:aCht Doctor(?:.com)?(?:.in)?|www\.(?:google|yahoo)\S*)', '', text)  # Remove site names
    text = re.sub(r'\s+', ' ', text)                                    
    return text.strip()

def format_chat_template(row):
    cleaned_instruction = clean_text(row["instruction"])
    cleaned_input = clean_text(row["input"])
    cleaned_output = clean_text(row["output"])
    
    row_json = [
        {"role": "system", "content": cleaned_instruction},
        {"role": "user", "content": cleaned_input},
        {"role": "assistant", "content": cleaned_output}
    ]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(format_chat_template, num_proc=4)

dataset = dataset.train_test_split(test_size=0.1)
data_collator = lambda batch: tokenizer(batch["text"], return_tensors="pt", padding=True, truncation=True)


README.md:   0%|          | 0.00/542 [00:00<?, ?B/s]

(…)-00000-of-00001-5e7cb295b9cff0bf.parquet:   0%|          | 0.00/70.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112165 [00:00<?, ? examples/s]

  self.pid = os.fork()


Map (num_proc=4):   0%|          | 0/2000 [00:00<?, ? examples/s]

  self.pid = os.fork()


In [8]:
training_args = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=200, 
    save_steps=500,  
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=0.0002,
    fp16=True,
    bf16=False,
    group_by_length=True,
    report_to="wandb",
    load_best_model_at_end=False 
)


trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",  
    tokenizer=tokenizer,
    args=training_args,
    packing=False,
)

model.config.use_cache = False



Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]



In [9]:
trainer.train()



Step,Training Loss,Validation Loss
200,2.271,2.578353
400,2.2171,2.52285
600,2.4161,2.488086
800,1.7434,2.46427




TrainOutput(global_step=900, training_loss=2.5062590618928273, metrics={'train_runtime': 1499.4907, 'train_samples_per_second': 1.2, 'train_steps_per_second': 0.6, 'total_flos': 5615864755831296.0, 'train_loss': 2.5062590618928273, 'epoch': 1.0})

In [10]:
wandb.finish()
model.config.use_cache = True

VBox(children=(Label(value='0.029 MB of 0.029 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/loss,█▅▂▁
eval/runtime,█▁▅▄
eval/samples_per_second,▁█▅▅
eval/steps_per_second,▁█▅▅
train/epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██
train/global_step,▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
train/grad_norm,▅▆▃▄▃▂▄▅▄▄▄▂▅▁▄▂▂▄▃▅▃▄▅▂▄█▃▃▄▄▂▃▄▃▂▄▃▅▄▅
train/learning_rate,▃████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁
train/loss,▆▅█▅▄▆▃▁▅▃▆▆▆▅▅▄▃▇▄▃▆▄▄▄▁▆▄▆▅▄▄▃▆▇▃▃▅▅▄▄

0,1
eval/loss,2.46427
eval/runtime,62.1611
eval/samples_per_second,3.217
eval/steps_per_second,3.217
total_flos,5615864755831296.0
train/epoch,1.0
train/global_step,900.0
train/grad_norm,2.1567
train/learning_rate,0.0
train/loss,2.6493


In [12]:
from transformers import GenerationConfig

messages = [
    {"role": "system", "content": "You are a medical expert specializing in respiratory diseases."},
    {"role": "user", "content": "I have a persistent cough, night sweats, and recent weight loss. I’ve been to multiple doctors with no diagnosis yet. Could these symptoms be related to tuberculosis or another serious illness? Please provide a detailed answer considering possible causes and recommended next steps."}
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")

outputs = model.generate(
    **inputs,
    max_length=350,          
    top_k=50,                
    top_p=0.85,               
    temperature=0.3,         
    no_repeat_ngram_size=3,  
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("assistant")[-1].strip()
print(response)




Hi, Thanks for writing to us. I have gone through your query and understand your concern. I would like to tell you that the symptoms you have mentioned are suggestive of tuberculosis. I suggest you to consult a pulmonologist and get done a chest x-ray and sputum examination. If the chest x -ray is suggestive of TB, then you should get done sputum culture and sensitivity test. If it is positive, then it is confirmed that you have TB. You should take anti-TB treatment under the supervision of a pulmonology. I hope this information would help you. Please do not hesitate to ask in case of any further doubts. Thanks and regards. Wish you a good health. . N. Senior Surgical Specialist. . S. Genl-CVTS. . M.S. . D.N.B.S, D.C.S(S). . F.C.(S). Wish you good health and a long life. . . N, Senior Surgical specialist. .S. Gen. CVTS. M. S. D. N. B. S, D C S (S) F. C. (S). Thanks and Regards. Wish You Good Health and a Long Life. .N. Senior surgical specialist. S Genl CVTS M.s D.n.B S D.c.S (S),
