In [None]:
%%capture
!pip install unsloth
# Also get the latest nightly Unsloth!
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

In [None]:
from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
from huggingface_hub import login
from transformers import TrainingArguments
from datasets import load_dataset
import wandb
from kaggle_secrets import UserSecretsClient

In [None]:
from huggingface_hub import login
# Credentials come from Kaggle's secret store so that no token is ever written
# into the notebook. The secrets "HF_tok" and "wandb_check" must be attached
# to this notebook (Add-ons -> Secrets) before running the cell.
user_secret = UserSecretsClient()
hf_tok = user_secret.get_secret("HF_tok")
wnb_tok = user_secret.get_secret("wandb_check")

# Authenticate with the Hugging Face Hub once; hf_tok is reused later when the
# model is downloaded.
login(token=hf_tok)

# Authenticate with Weights & Biases and open the run used for this session.
wandb.login(key=wnb_tok)
run = wandb.init(
    project='Fine-tune-DeepSeek-R1-Distill-Llama-8B on Medical health Dataset',
    job_type="training",
    anonymous="allow",
)


In [None]:
# --- Model loading settings -------------------------------------------------
max_seq_length = 2048  # also handed to the trainer further down
dtype = None           # NOTE(review): presumably lets Unsloth pick the dtype — confirm
load_in_4bit = True

# Alternative checkpoint kept for reference: "unsloth/DeepSeek-R1-Distill-Llama-8B"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="DemiSho/mental-gen-bot-3.0",
    token=hf_tok,
    max_seq_length=max_seq_length,
    load_in_4bit=load_in_4bit,
    dtype=dtype,
)

In [None]:
# Prompt template used to render each training example. It contains two
# positional `{}` placeholders: the first (under "### Question:") receives the
# user's question and the second (under "### Response:") receives the reference
# answer; formatting_prompts_func fills them via str.format.
# The text is part of the training data, so it must not be edited casually.
train_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request. 
Before answering, think carefully about the instruction and context, and create a step-by-step chain of thoughts to ensure a logical, accurate, and empathetic response.

### Instruction:
You are a mental health expert with advanced knowledge in psychology, psychotherapy, and mental health care planning. 
You are well-versed in areas such as Cognitive Behavioral Therapy (CBT), trauma-informed care, adolescent mental health, and evidence-based therapeutic interventions. 
Please help with the following mental health-related task.Remember that you are a mental health expert, question other than 
the domain of mental health are unknown to you so refuse them politely.Don't respond to them.

### Question:
{}

### Response:
{}"""


In [None]:
# Only the slice train[0:500] of MentalChat16K is loaded for this run.
dataset = load_dataset(
    "ShenLab/MentalChat16K",
    split="train[0:500]",
    # NOTE(review): trust_remote_code allows code shipped with the dataset
    # repository to run locally — confirm it is still required.
    trust_remote_code=True,
)

In [None]:
# End-of-sequence marker taken from the loaded tokenizer; it is appended to
# every rendered training example by formatting_prompts_func.
EOS_TOKEN =tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of dataset rows into training strings.

    Args:
        examples: Mapping with an "input" sequence (the questions) and an
            "output" sequence (the reference answers). The two are paired
            positionally; any other keys are ignored.

    Returns:
        A dict with a single key "text" holding one string per pair: the
        module-level ``train_prompt_style`` template filled with the question
        and the answer, followed by ``EOS_TOKEN``.
    """
    pairs = zip(examples["input"], examples["output"])
    rendered = [
        train_prompt_style.format(question, answer) + EOS_TOKEN
        for question, answer in pairs
    ]
    return {"text": rendered}


In [None]:
# Apply the prompt formatter to the dataset in batched mode; the formatter
# returns a "text" entry that the trainer later reads as its text field.
data_tun = dataset.map(
    formatting_prompts_func,
    batched=True,
)

In [None]:
# Names of the projection modules that receive LoRA adapters, grouped for
# readability; concatenated below in the original order.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

# Wrap the loaded model with LoRA adapters.
model_lora = FastLanguageModel.get_peft_model(
    model,
    target_modules=attention_projections + mlp_projections,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
    random_state=3407,
)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Pick the mixed-precision mode once: bf16 where the hardware supports it,
# fp16 otherwise. Exactly one of the two flags ends up True.
use_bf16 = is_bfloat16_supported()

# Supervised fine-tuning on the rendered "text" column of data_tun.
trainer = SFTTrainer(
    model=model_lora,
    tokenizer=tokenizer,
    train_dataset=data_tun,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        # A positive max_steps overrides num_train_epochs, so only max_steps
        # is set here to avoid suggesting that an epoch count is in effect.
        # For a full training run, remove max_steps and set num_train_epochs
        # (and consider warmup_ratio instead of warmup_steps).
        max_steps=60,
        warmup_steps=5,
        learning_rate=2e-4,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs",
        # A W&B run is opened earlier in the notebook; send the training
        # metrics to it explicitly instead of relying on the library default.
        report_to="wandb",
    ),
)

In [None]:
# Run training; the returned object is kept in trainer_stat for inspection.
trainer_stat = trainer.train()

# Persist the trained model and the tokenizer into separate directories.
model_dir = "./deep-Psychotherapist"
tokenizer_dir = "./deep-Psychotherapist-tok"
trainer.model.save_pretrained(model_dir)
tokenizer.save_pretrained(tokenizer_dir)

# Close the Weights & Biases run opened at the start of the notebook.
wandb.finish()