In [None]:
!pip install -q -U bitsandbytes
!pip install -q -U peft
!pip install -q -U trl
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U transformers

In [None]:
import os
from google.colab import userdata
# Copy the 'HF_TOKEN' secret from Colab's user data into the process environment.
os.environ.update(HF_TOKEN=userdata.get('HF_TOKEN'))

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hub id of the base checkpoint that is fine-tuned in this notebook.
model_id = "google/gemma-2-2b"

# Quantization config: 4-bit NF4 weights with double quantization, bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Load the quantized model (device placement left to `device_map='auto'`) and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    quantization_config=bnb_config,
)
# add_eos_token=True asks the tokenizer to append the EOS token to every encoded sequence.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

In [None]:
from datasets import load_dataset
# Hugging Face Hub id of the dataset used for fine-tuning.
dataset_url = 'LinhDuong/chatdoctor-5k'
dataset = load_dataset(path=dataset_url)
# Display the loaded dataset object (splits / columns) as the cell output.
dataset

In [None]:
class GemmaPrompt:
  """Builds a prompt string in Gemma's turn-based chat format.

  Each turn is rendered as ``<start_of_turn>{role}\\n{content}<end_of_turn>``; the role
  header is followed by a newline, as in Gemma's documented prompt format (and as in the
  inference prompt used by ``get_completion`` in this notebook).

  Attributes:
    user_instrs: rendered user turns, in insertion order.
    output_instrs: rendered model turns, in insertion order.
  """

  def __init__(self):
    self.user_instrs = []
    self.output_instrs = []

  def add_user_instr(self, user_instr='', inputs=''):
    """Append a user turn made of the instruction followed by the patient's message."""
    self.user_instrs.append(f'<start_of_turn>user\n{user_instr}. Patient: {inputs}<end_of_turn>')

  def add_output_instr(self, output_instr):
    """Append a model turn containing the expected answer."""
    self.output_instrs.append(f'<start_of_turn>model\n{output_instr}<end_of_turn>')

  def __str__(self):
    """Return the conversation with user and model turns interleaved, one turn per line.

    The i-th user turn is followed by the i-th model turn, so multi-turn conversations keep
    their order instead of listing every user turn before every model turn. An empty prompt
    renders as the empty string.
    """
    turns = []
    for index in range(max(len(self.user_instrs), len(self.output_instrs))):
      if index < len(self.user_instrs):
        turns.append(self.user_instrs[index])
      if index < len(self.output_instrs):
        turns.append(self.output_instrs[index])
    return '\n'.join(turns)

In [None]:
def generate_prompt(data_point):
  """Render one dataset record as a single prompt string.

  Args:
    data_point: mapping that provides the 'instruction', 'input' and 'output' fields.

  Returns:
    The string form of a GemmaPrompt holding one user turn and one model turn.
  """
  builder = GemmaPrompt()
  instruction, patient_text = data_point['instruction'], data_point['input']
  builder.add_user_instr(user_instr=instruction, inputs=patient_text)
  answer = data_point['output']
  builder.add_output_instr(answer)
  return str(builder)

In [None]:
# Render every training record into its prompt string ...
text_column = list(map(generate_prompt, dataset["train"]))
# ... and attach the strings to the train split as a new "prompt" column.
new_dataset = dataset["train"].add_column(name="prompt", column=text_column)
new_dataset

In [None]:
# Show the rendered prompt of the first record as a sanity check.
new_dataset[0]['prompt']

In [None]:
# Shuffle with a fixed seed, then tokenize the "prompt" column batch by batch.
dataset = (
    new_dataset
    .shuffle(seed=1234)
    .map(lambda batch: tokenizer(batch["prompt"]), batched=True)
)

In [None]:
# Hold out 10% of the records for evaluation. The split is seeded (same seed as the shuffle
# above) so train/test membership is identical on every Restart & Run All.
dataset = dataset.train_test_split(test_size=0.1, seed=1234)
train_data = dataset["train"]
test_data = dataset["test"]

In [None]:
# Peek at the (input, output) pairs of the first three training records.
[(row['input'], row['output']) for row in (dataset['train'][i] for i in range(3))]

In [None]:
import bitsandbytes as bnb
# Find all the 4-bit linear layers of the model that can be targeted by LoRA adapters
def find_all_linear_names(model, linear_cls=None):
  """Return the leaf names of the linear layers in ``model`` that adapters can target.

  Args:
    model: object exposing ``named_modules()`` that yields ``(qualified_name, module)`` pairs.
    linear_cls: module class (or tuple of classes) to look for. Defaults to
      ``bnb.nn.Linear4bit``, the layer type this notebook's 4-bit model is scanned for.

  Returns:
    A sorted list of unique leaf names (the last dotted component of each qualified name),
    with 'lm_head' excluded. Sorting keeps the result identical across runs.
  """
  if linear_cls is None:
    linear_cls = bnb.nn.Linear4bit
  # Only the last component of the dotted path is kept, e.g. "...self_attn.q_proj" -> "q_proj".
  lora_module_names = {
      name.split('.')[-1]
      for name, module in model.named_modules()
      if isinstance(module, linear_cls)
  }
  # The output head is never returned (the original note says this is needed for 16-bit).
  lora_module_names.discard('lm_head')
  return sorted(lora_module_names)
# Collect the module names of the loaded model; they are passed to LoraConfig as target_modules.
modules = find_all_linear_names(model)

In [None]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
# Gradient checkpointing trades slower training for lower memory use.
model.gradient_checkpointing_enable()
# Preprocess the quantized model for training.
model = prepare_model_for_kbit_training(model)
# LoRA settings: r=64, lora_alpha=32, lora_dropout=0.05, no bias training, causal-LM task,
# applied to the module names collected in `modules`.
lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Wrap the prepared model with the LoRA configuration.
# NOTE(review): `model` is rebound to the wrapped model, so re-running this cell without
# reloading the base model would presumably wrap an already wrapped model -- confirm.
model = get_peft_model(model, lora_config)

In [None]:
import transformers

from trl import SFTTrainer

# Only fall back to EOS as the padding token when the tokenizer has none of its own (Gemma
# tokenizers normally ship a dedicated pad token). DataCollatorForLanguageModeling masks every
# label equal to pad_token_id out of the loss, so forcing pad == EOS would also mask the EOS
# token appended to each training example and the model would not learn where to stop.
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side='right'

# Effective batch size is 1 x 4 (gradient accumulation). max_steps=100 bounds the run;
# per the TrainingArguments documentation a positive max_steps overrides num_train_epochs.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    max_steps=100,
    learning_rate=2e-4,
    logging_steps=100,
    output_dir="outputs",
    optim="paged_adamw_32bit",
    save_strategy="epoch",
    num_train_epochs=1,
    lr_scheduler_type="cosine",
)
# NOTE(review): no evaluation strategy is configured above, so eval_dataset is presumably
# never evaluated during training -- confirm whether periodic evaluation is wanted.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_text_field="prompt",
    peft_config=lora_config,
    max_seq_length=2500,
    args=training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Release cached GPU memory before the training loop starts.
torch.cuda.empty_cache()
trainer.train()

In [None]:
def get_completion(query: str, model, tokenizer) -> str:
  """Generate a doctor-style answer for a patient query.

  Args:
    query: the patient's message; surrounding whitespace is stripped.
    model: causal LM exposing ``generate``; its ``device`` attribute (when present)
      decides where the inputs are placed, otherwise "cuda:0" is used.
    tokenizer: callable tokenizer exposing ``eos_token_id`` and ``decode``.

  Returns:
    The decoded generated sequence with special tokens skipped. The sequence returned by
    ``generate`` is decoded as a whole, so the prompt text is part of the result.
  """
  # Follow the model's own device instead of assuming the first CUDA device.
  device = getattr(model, "device", "cuda:0")

  # Built line by line so that no source indentation or stray blank lines leak into the prompt.
  prompt = (
      "<start_of_turn>user\n"
      "You are a doctor, please answer the medical questions based on the patient's description.\n"
      f"{query.strip()}<end_of_turn>\n"
      "<start_of_turn>model\n"
  )

  encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

  # A tokenizer created with add_eos_token=True appends EOS to the prompt; generation must
  # continue the open model turn, so a trailing EOS is dropped from every input tensor.
  eos_id = tokenizer.eos_token_id
  input_ids = encodeds["input_ids"]
  if eos_id is not None and input_ids.shape[-1] > 1 and int(input_ids[0, -1]) == eos_id:
    encodeds = {name: tensor[:, :-1] for name, tensor in encodeds.items()}

  model_inputs = {name: tensor.to(device) for name, tensor in encodeds.items()}

  generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=eos_id)
  return tokenizer.decode(generated_ids[0], skip_special_tokens=True)



In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login prompt.
# NOTE(review): presumably needed so the later push_to_hub calls run with a write-capable
# token; HF_TOKEN is also exported to the environment earlier in this notebook -- confirm
# whether both are required.
notebook_login()

In [None]:
# Save the trained adapter, reload the base checkpoint, merge the adapter into it and
# export the merged model together with the tokenizer.
# `new_model` is used both as the local adapter directory and, later, as the Hub repo name.
new_model = "gemma-2-2b-chatdoctor"

# Save the trainer's model to the `new_model` directory.
trainer.model.save_pretrained(new_model)

# Reload the base checkpoint without the quantization config, placed entirely on device 0.
# NOTE(review): the weights are loaded in float16 here, whereas the quantized training model
# used bfloat16 as compute dtype -- confirm float16 is intended for the merged model.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
# Attach the saved adapter to the base model, then fold it into the base weights.
merged_model= PeftModel.from_pretrained(base_model, new_model)
merged_model= merged_model.merge_and_unload()

# Save the merged model (safetensors serialization) and the tokenizer to "merged_model".
merged_model.save_pretrained("merged_model",safe_serialization=True)
tokenizer.save_pretrained("merged_model")
# These assignments run after save_pretrained, so from this point on they only affect the
# in-memory tokenizer used by the following cells.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [None]:
# Upload the merged model and the tokenizer to the Hub repository named by `new_model`.
merged_model.push_to_hub(repo_id=new_model, use_temp_dir=False)
tokenizer.push_to_hub(repo_id=new_model, use_temp_dir=False)

In [None]:
# Smoke-test the merged model with one sample patient message and print the generated text.
query = """ Doctor, I have been experiencing some symptoms associated with Von Hippel-Lindau disease. """
result = get_completion(query, merged_model, tokenizer)
print(result)