In [None]:
%%time
!python -V

In [None]:
%%time
!pip install -q git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 

import kagglehub
import torch
from transformers import AutoTokenizer
from transformers.models.gemma3 import Gemma3ForCausalLM

In [None]:
%%time
!pip install -qq accelerate peft bitsandbytes 

In [None]:
%%time
# import kagglehub
import torch
from transformers import AutoTokenizer
from transformers.models.gemma3 import Gemma3ForCausalLM

# from transformers import AutoTokenizer, Gemma3ForCausalLM

# Directory of the Gemma 3 1B checkpoint mounted under /kaggle/input.
model_id = "/kaggle/input/gemma-3/transformers/gemma-3-1b-it/1"

# Tokenizer and weights are read from the same checkpoint; the model goes to the first GPU.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Gemma3ForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# Smoke test: encode one fixed prompt, generate up to 100 new tokens, print the decoded batch.
input_ids = tokenizer(
    "Write me a poem about Machine Learning.",
    return_tensors="pt",
)
input_ids = input_ids.to(model.device)
outputs = model.generate(max_new_tokens=100, **input_ids)
text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(text)

In [None]:
%%time
from datasets import load_dataset
# Load the "en" configuration of the medical reasoning SFT dataset and keep its train split.
dataset = load_dataset(
    "FreedomIntelligence/medical-o1-reasoning-SFT",
    "en",
)["train"]
# Bare expression so the notebook shows the dataset summary.
dataset

In [None]:
%%time
def preprocess_function(examples, max_length=1024, ignore_index=-100):
    """Tokenize a batch of rows into fixed-length causal-LM training features.

    Each row is rendered as ``Question: ...\\nComplex_CoT: ...\\nResponse: ...``,
    tokenized with the notebook-level ``tokenizer``, truncated/padded to
    ``max_length`` tokens, and given a ``labels`` column.

    Args:
        examples: Batch mapping with the columns ``Question``, ``Complex_CoT``
            and ``Response``, each a list of strings of equal length.
        max_length: Fixed sequence length used for truncation and padding.
        ignore_index: Label value written at padding positions. ``-100`` is the
            index ignored by PyTorch's cross-entropy loss, so padding does not
            contribute to the training loss.

    Returns:
        The tokenizer output with an added ``labels`` entry: a copy of
        ``input_ids`` in which every position whose ``attention_mask`` is 0
        has been replaced by ``ignore_index``.
    """
    prompts = [
        f"Question: {q}\nComplex_CoT: {cot}\nResponse: {r}"
        for q, cot, r in zip(examples['Question'], examples['Complex_CoT'], examples['Response'])
    ]
    tokenized_inputs = tokenizer(prompts, truncation=True, max_length=max_length, padding="max_length")
    # Labels mirror the inputs (the model shifts them internally), but padding
    # must be masked out: otherwise the model is trained to predict pad tokens.
    tokenized_inputs["labels"] = [
        [token if keep else ignore_index for token, keep in zip(ids, mask)]
        for ids, mask in zip(tokenized_inputs["input_ids"], tokenized_inputs["attention_mask"])
    ]
    return tokenized_inputs

# Tokenize the whole dataset; with batched=True the function receives many rows per call.
tokenized_dataset = dataset.map(function=preprocess_function, batched=True)

In [None]:
%%time
from peft import LoraConfig, get_peft_model
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM, TrainingArguments, Trainer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

In [None]:
%%time
# LoRA configuration.
# Only the config object is built here. The model is wrapped with
# `get_peft_model` exactly once, in the following cell, after
# `prepare_model_for_kbit_training` has run; wrapping it here as well would
# apply PEFT to the same model twice.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank adapter matrices
    lora_alpha=32,
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],  # adjust to match the model's layer names
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
%%time
# Run PEFT's k-bit training preparation on the model, then wrap it with the
# LoRA adapters described by `lora_config`.
# NOTE(review): the model is loaded earlier in this notebook without any
# quantization config — confirm the k-bit preparation step is actually needed.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
# Print how many parameters are trainable after wrapping.
model.print_trainable_parameters()

In [None]:
%%time
# Move the model to the CUDA device before handing it to the Trainer.
model = model.to(torch.device("cuda"))

# Training arguments for a short smoke-test run.
# Earlier, longer configurations used by the author (kept here as a note
# instead of commented-out code): gradient_accumulation_steps=4,
# max_steps=20 or 500, save_steps=10 or 100, save_total_limit=1 or 2,
# logging_steps=1 or 10, output_dir="./gemma-medical-lora".
# NOTE(review): only 70 samples at batch size 1 are supplied below, which is
# fewer than one accumulation window of 256 — confirm this is intended.
training_args = TrainingArguments(
    output_dir="./results_small",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=256,
    learning_rate=2e-4,
    logging_steps=1,
    max_steps=2,  # raise (e.g. to 500) for a real fine-tuning run
    report_to="none",  # use "wandb" or "tensorboard" if you have them set up
    save_steps=10,
    save_total_limit=1,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    fp16=True,  # or bf16=True if supported
    push_to_hub=False,  # set to True if you want to push to the Hub
)

# Trainer setup and training.
# `processing_class` replaces the deprecated `tokenizer` argument of Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    # Small subset for a quick run; pass `tokenized_dataset` to train on everything.
    train_dataset=tokenized_dataset.select(range(70)),
    processing_class=tokenizer,
)

trainer.train()

In [None]:
%%time
# Write the trainer's model to disk and confirm the destination.
trainer.save_model(output_dir="./finetuned_model_1b")
print("Model fine-tuning complete and saved to './finetuned_model_1b'.")

In [None]:
%%time
# Merge LoRA weights into the base model
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Paths: the original checkpoint, the adapter saved after training, and the merge output.
model_id = "/kaggle/input/gemma-3/transformers/gemma-3-1b-it/1"
finetuned_model_path = "./finetuned_model_1b"
merged_model_path = "./merged_gemma3_medical_1b"

# Reload the tokenizer and a float16 copy of the base model from the original checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the saved LoRA adapter to the base model ...
model = PeftModel.from_pretrained(base_model, finetuned_model_path)

# ... and merge the adapter into the base model, unloading the PEFT wrapper.
merged_model = model.merge_and_unload()

# Store the merged model and the tokenizer in the same directory.
merged_model.save_pretrained(merged_model_path)
tokenizer.save_pretrained(merged_model_path)
print(f"Merged model saved to '{merged_model_path}'.")

In [None]:
%%time
# Test the merged model for chatbot inference
def generate_response(prompt, max_length=200):
    """Generate a sampled reply to ``prompt`` with the merged model.

    The prompt is wrapped as ``"User: {prompt} Assistant: "``, tokenized with
    the notebook-level ``tokenizer`` and passed to ``merged_model.generate``
    with nucleus sampling (``top_p=0.9``, ``temperature=0.7``).

    Args:
        prompt: The user's question as plain text.
        max_length: Forwarded unchanged as ``generate(max_length=...)``; in
            Hugging Face Transformers this bounds prompt plus generated tokens.

    Returns:
        The decoded text of the newly generated tokens only, with special
        tokens skipped and surrounding whitespace stripped.
    """
    # Prepare the input prompt with a conversational format
    input_text = f"User: {prompt} Assistant: "
    inputs = tokenizer(input_text, return_tensors="pt").to(merged_model.device)

    # Generate response
    outputs = merged_model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    # Strip the prompt by token count, not by character count: decoding the
    # prompt tokens is not guaranteed to reproduce `input_text` exactly
    # (whitespace may be normalized), so slicing the decoded string at
    # len(input_text) can cut into the reply or leave prompt residue behind.
    prompt_token_count = len(inputs["input_ids"][0])
    new_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Quick check: ask one question and print the prompt followed by the model's reply.
test_prompt = "What are the symptoms of diabetes?"
response = generate_response(test_prompt)
print(f"Prompt: {test_prompt}", f"Response: {response}", sep="\n")