<a href="https://colab.research.google.com/github/Cody-Lange/MentalHealthAssistant/blob/main/CounselGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tuning Llama2 for Mental Health Counseling
Approximately 1 in 4 adults in the United States experiences a diagnosable mental disorder each year, according to estimates from the Johns Hopkins School of Medicine. While many mental health resources and professionals are available, there is a lack of on-the-spot guidance that can point individuals in the right therapeutic direction. Advances in artificial intelligence and large language models now provide the tools to build such a virtual assistant. This notebook outlines the development of a preliminary prototype: we take a 7-billion-parameter Llama 2 model, quantize it for memory and training efficiency, and fine-tune it on the Amod/mental_health_counseling_conversations dataset from Hugging Face, which is tailored for mental health Q&A. An in-notebook user interface then lets users interact with the assistant.

The trained model is accessible on Hugging Face under langecod/CounselLlama7B.

# 1.) Import & Install Necessary Libraries (Colab requires installs with each runtime)

In [None]:
pip install transformers datasets peft trl accelerate bitsandbytes packaging ninja sentencepiece



In [None]:
pip install flash-attn --no-build-isolation



In [None]:
import gc
import os
import random

import numpy as np
import pandas as pd
import torch
import accelerate
import bitsandbytes as bnb
import transformers
from datasets import load_dataset, concatenate_datasets
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# 2.) Loading the Mental Health Conversations Dataset:
The Amod/mental_health_counseling_conversations dataset is a collection of 3,512 question-and-answer pairs sourced from counselchat.com, an online counseling and therapy platform. It covers a wide range of mental health topics, with responses written by certified psychologists. It is intended for fine-tuning language models to generate cogent advice in response to mental health inquiries. All entries are in English, and each entry consists of a 'Context' (the user's question) and a 'Response' (the psychologist's answer).

In [None]:
import kagglehub
import shutil
import os
import pandas as pd
# Download the latest version
path = kagglehub.dataset_download("thedevastator/nlp-mental-health-conversations")

# Copy the files to /content/dataset
destination_dir = "/content/dataset/nlp_mental_health_conversations"
os.makedirs(destination_dir, exist_ok=True)
shutil.copytree(path, destination_dir, dirs_exist_ok=True)

# List the copied files
for root, _, files in os.walk(destination_dir):
    for file in files:
        print(os.path.join(root, file))


data = pd.read_csv("/content/dataset/nlp_mental_health_conversations/train.csv")

/content/dataset/nlp_mental_health_conversations/train.csv


In [None]:
dataset=data

In [None]:
dataset

Unnamed: 0,Context,Response
0,I'm going through some things with my feelings...,"If everyone thinks you're worthless, then mayb..."
1,I'm going through some things with my feelings...,"Hello, and thank you for your question and see..."
2,I'm going through some things with my feelings...,First thing I'd suggest is getting the sleep y...
3,I'm going through some things with my feelings...,Therapy is essential for those that are feelin...
4,I'm going through some things with my feelings...,I first want to let you know that you are not ...
...,...,...
3507,My grandson's step-mother sends him to school ...,Absolutely not! It is never in a child's best ...
3508,My boyfriend is in recovery from drug addictio...,I'm sorry you have tension between you and you...
3509,The birth mother attempted suicide several tim...,"The true answer is, ""no one can really say wit..."
3510,I think adult life is making him depressed and...,How do you help yourself to believe you requir...


In [None]:
# Load the dataset
dataset = load_dataset('Amod/mental_health_counseling_conversations')  # returns a DatasetDict with a 'train' split

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Context', 'Response'],
        num_rows: 3512
    })
})

# 3.) Importing, Quantizing, and Preparing the Llama2 Chat Model:
To safeguard private health information and intellectual property, it is essential to use an openly available model that can be run on our own infrastructure. Meta's Llama 2 stands out in this regard, offering a collection of pretrained and fine-tuned large language models (LLMs) that range from 7 billion to 70 billion parameters. The Llama 2-Chat variant is tailored for dialogue applications and, according to Meta's reported benchmarks and human evaluations, outperforms other open-source chat models on both helpfulness and safety. This made the 7-billion-parameter Llama 2-Chat model a natural choice for our prototype. Additionally, to address memory constraints and keep training affordable, we load the model with its weights quantized to 4 bits; computation is still carried out in a 16-bit data type.
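To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It counts weight storage only and assumes a nominal 7e9 parameters; activations, optimizer state, quantization constants, and the CUDA context are ignored, so real usage will be higher. The helper name `weight_memory_gib` is purely illustrative.

```python
def weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Storage needed for the weights alone, in GiB."""
    return n_params * bits_per_param / 8 / 2**30


N_PARAMS = 7e9  # assumption: nominal parameter count of the 7B model

# Compare the weight footprint at different precisions
for label, bits in [("float16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label:>8}: {weight_memory_gib(N_PARAMS, bits):.2f} GiB")
```

Under these assumptions, moving from 16-bit to 4-bit weights cuts the weight footprint by a factor of four, which is what lets the model fit on a single Colab GPU.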

In [None]:
!pip install flash-attn



In [None]:
nf4_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.bfloat16  # bfloat16 targets Ampere-class GPUs such as the A100
)

# Read the Hugging Face access token from the environment instead of hard-coding it
hf_token = os.environ.get("HF_TOKEN")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token=hf_token)
# Add padding token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the LLaMA model in 4-bit
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    token=hf_token,
    quantization_config=nf4_config,
    use_flash_attention_2=False  # FlashAttention (disabled here) reduces attention memory from quadratic to linear in sequence length
)


`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading the model with 4-bit weights greatly reduces memory usage, which makes fine-tuning and inference feasible on a single GPU. However, the quantized weights cannot be updated directly by conventional training. To address this, we use "Quantized Low-Rank Adaptation" (QLoRA): the original pre-trained weights remain frozen in 4-bit format, while a small "adapter" with higher-precision (16- or 32-bit) weights is added and trained on the specific task.

In [None]:
# LoRA config based on QLoRA paper
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=64,
        bias="none",
        task_type="CAUSAL_LM",
)

In [None]:
# prepare model for training
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
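
As a rough illustration of why the adapter is cheap to train, the sketch below counts the parameters that a LoRA configuration with `r=64` adds. The architecture figures are assumptions not stated in this notebook: a hidden size of 4096, 32 decoder layers, and adapters attached only to the `q_proj` and `v_proj` projections (the usual PEFT default for Llama when `target_modules` is not set).

```python
# Assumptions (not stated in the notebook): Llama 2 7B dimensions and
# adapters on q_proj and v_proj only.
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
ADAPTED_MODULES_PER_LAYER = 2
RANK = 64          # r in the LoraConfig
LORA_ALPHA = 16    # lora_alpha in the LoraConfig


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A LoRA adapter adds two matrices: A (r x d_in) and B (d_out x r)."""
    return r * (d_in + d_out)


per_module = lora_params(HIDDEN_SIZE, HIDDEN_SIZE, RANK)
trainable = per_module * ADAPTED_MODULES_PER_LAYER * NUM_LAYERS
scaling = LORA_ALPHA / RANK  # the low-rank update B @ A is scaled by alpha / r

print(f"Trainable adapter parameters: {trainable:,}")
print(f"Adapter size in float32: {trainable * 4 / 1e6:.0f} MB")
print(f"LoRA scaling factor: {scaling}")
```

Under these assumptions the adapter holds tens of millions of parameters, well under 1% of the 7 billion frozen base weights, and its float32 size is in line with the `adapter_model.bin` upload reported at the end of the notebook.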

# 4.) Setting up the Trainer & Prompting:
Training a quantized, generative model shares similarities with training other transformer models, such as BERT and Big Bird, but there are crucial differences to consider:
1.  Typically, 1-3 epochs are sufficient to fine-tune the model for a specific task. Training for longer can lead to overfitting, where the model restricts its responses to the training set and does not fully use the broad knowledge acquired during Meta's pretraining.
2.  The model requires inputs that concatenate the prompt, the context, and the output, and it learns to predict the output from the prompt and the context. For Llama 2, it's essential to keep a consistent format that matches the template used to train Llama 2-Chat:
```
# Formatting Function to Follow Later
<s>[INST] <<SYS>>
{{system message}}
<</SYS>>
{{message}} [/INST] {{answer}} </s>
```
3. The prompt plays a pivotal role in shaping the bot's response. A less descriptive prompt can result in inappropriate or even potentially harmful responses, especially when discussing sensitive subjects like mental health. It's crucial to draft prompts that steer the model towards producing empathetic, helpful, and non-judgmental replies. Our goal is for the model to guide users towards suitable resources and techniques rather than making diagnoses or prescriptions.
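
The template in point 2 can be sketched as a small helper. This follows the template exactly as written above, line breaks included; the function name `build_llama2_prompt` is hypothetical and is not the formatting function used later in this notebook. Note also that many tokenizers prepend `<s>` on their own, so whether to include it in the text is a choice to verify against your tokenizer settings.

```python
from typing import Optional


def build_llama2_prompt(system_message: str, user_message: str,
                        answer: Optional[str] = None) -> str:
    """Assemble a single-turn Llama 2-Chat prompt.

    With `answer` set, the result is a complete training example;
    without it, the result is an inference prompt ending at [/INST].
    """
    prompt = (
        f"<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n"
        f"{user_message} [/INST]"
    )
    if answer is not None:
        prompt += f" {answer} </s>"
    return prompt


# Training example (answer included) versus inference prompt (answer omitted)
print(build_llama2_prompt("Be supportive.", "I can't sleep.", "That sounds hard."))
print(build_llama2_prompt("Be supportive.", "I can't sleep."))
```

Using one helper for both training and inference keeps the two formats from drifting apart, which matters because the model is sensitive to the exact template it was trained on.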







In [None]:
args = TrainingArguments(
    output_dir="CounselLlama7B",
    logging_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_checkpointing=False,
    optim="paged_adamw_8bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=1e-4,
    tf32=False,  # Disable TF32 to avoid the error
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    load_best_model_at_end=True,
    evaluation_strategy='epoch',
    #torch_compile=False

   )



In [None]:
dataset

Unnamed: 0,Context,Response
0,I'm going through some things with my feelings...,"If everyone thinks you're worthless, then mayb..."
1,I'm going through some things with my feelings...,"Hello, and thank you for your question and see..."
2,I'm going through some things with my feelings...,First thing I'd suggest is getting the sleep y...
3,I'm going through some things with my feelings...,Therapy is essential for those that are feelin...
4,I'm going through some things with my feelings...,I first want to let you know that you are not ...
...,...,...
3507,My grandson's step-mother sends him to school ...,Absolutely not! It is never in a child's best ...
3508,My boyfriend is in recovery from drug addictio...,I'm sorry you have tension between you and you...
3509,The birth mother attempted suicide several tim...,"The true answer is, ""no one can really say wit..."
3510,I think adult life is making him depressed and...,How do you help yourself to believe you requir...


In [None]:
dataset['Context'][0]

"I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here.\n   I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it.\n   How can I change my feeling of being worthless to everyone?"

In [None]:
dataset['Response'][0]

"If everyone thinks you're worthless, then maybe you need to find new people to hang out with.Seriously, the social context in which a person lives is a big influence in self-esteem.Otherwise, you can go round and round trying to understand why you're not worthless, then go back to the same crowd and be knocked down again.There are many inspirational messages you can find in social media. \xa0Maybe read some of the ones which state that no person is worthless, and that everyone has a good purpose to their life.Also, since our culture is so saturated with the belief that if someone doesn't feel good about themselves that this is somehow terrible.Bad feelings are part of living. \xa0They are the motivation to remove ourselves from situations and relationships which do us more harm than good.Bad feelings do feel terrible. \xa0 Your feeling of worthlessness may be good in the sense of motivating you to find out that you are much better than your feelings today."

In [None]:
# System message to better instruct chatbot
system_message = """You are a helpful and and truthful psychology and psychotherapy assistant. Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support.
                  Always respond with empathy and demonstrate active listening; try to focus on the user. Your responses should reflect that you understand the user's feelings and concerns. If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety.
                  Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate.  You are not a licensed medical professional. Do not diagnose or prescribe treatments.
                  Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. Avoid taking sides or expressing personal opinions. Your role is to provide a safe space for users to share and reflect.
                  Remember, your goal is to provide a supportive and understanding environment for users to share their feelings and concerns. Always prioritize their well-being and safety."""

def format_llama(entry):
    formatted = f"<s>[INST] <<SYS>>{system_message}<</SYS>>{entry['Context']} [/INST]  {entry['Response']}  </s>"
    return [formatted]  # Return as a list
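
One caveat worth checking: depending on the installed TRL release, `formatting_func` may be called with a single row (a dict of strings) or with a batch (a dict of lists). Which one applies here is an assumption to verify against your TRL version. The sketch below uses a hypothetical helper, `format_llama_rows`, that accepts either shape and always returns one formatted string per example, using the same string layout as `format_llama`.

```python
from typing import Dict, List, Union

Row = Dict[str, Union[str, List[str]]]


def format_llama_rows(entries: Row, system_message: str) -> List[str]:
    """Return one formatted training string per example.

    Accepts either a single row ({'Context': str, 'Response': str})
    or a batch ({'Context': [str, ...], 'Response': [str, ...]}).
    """
    contexts, responses = entries["Context"], entries["Response"]
    if isinstance(contexts, str):  # single row: wrap so both cases share one loop
        contexts, responses = [contexts], [responses]
    return [
        f"<s>[INST] <<SYS>>{system_message}<</SYS>>{c} [/INST]  {r}  </s>"
        for c, r in zip(contexts, responses)
    ]


# Single row and batch both yield a list of per-example strings
single = format_llama_rows({"Context": "Q1", "Response": "A1"}, "SYS")
batch = format_llama_rows({"Context": ["Q1", "Q2"], "Response": ["A1", "A2"]}, "SYS")
print(len(single), len(batch))
```

To pass such a helper to a trainer that expects a one-argument function, it could be wrapped with `functools.partial(format_llama_rows, system_message=system_message)`.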


In [None]:
from datasets import Dataset

# Load the data with pandas
data = pd.read_csv("/content/dataset/nlp_mental_health_conversations/train.csv")

# Convert the pandas DataFrame to a Hugging Face Dataset
dataset = Dataset.from_pandas(data)

# Split the data: 3000 rows for training, 512 for validation
train = dataset.select(range(3000))  # First 3000 rows for training
val = dataset.select(range(3000, 3512))  # Rows 3000 to 3511 for validation

# Print the dataset sizes
print(f"Training set size: {len(train)}")
print(f"Validation set size: {len(val)}")


Training set size: 3000
Validation set size: 512


In [None]:
dataset

Dataset({
    features: ['Context', 'Response'],
    num_rows: 3512
})

In [None]:
train

Dataset({
    features: ['Context', 'Response'],
    num_rows: 3000
})

In [None]:
val

Dataset({
    features: ['Context', 'Response'],
    num_rows: 512
})

In [None]:
print(format_llama(dataset[0]))

<s>[INST] <<SYS>>You are a helpful and and truthful psychology and psychotherapy assistant. Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support.
                  Always respond with empathy and demonstrate active listening; try to focus on the user. Your responses should reflect that you understand the user's feelings and concerns. If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety.
                  Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate.  You are not a licensed medical professional. Do not diagnose or prescribe treatments.
                  Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. Avoid taking sides or expressing personal opinions. Your role is to provide a safe space for users to share and reflect.
                  Reme

In [None]:
max_seq_length = 2048 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    train_dataset=train,
    eval_dataset=val,
    peft_config=peft_config,
    #max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    #packing=True,
    formatting_func=format_llama,
    args=args,
)


  trainer = SFTTrainer(


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/512 [00:00<?, ? examples/s]

# 5.) Training

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
!nvidia-smi


Tue Dec 24 15:54:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   57C    P0              30W /  70W |   5241MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
pip uninstall flash-attn


Found existing installation: flash-attn 2.7.2.post1
Uninstalling flash-attn-2.7.2.post1:
  Would remove:
    /usr/local/lib/python3.10/dist-packages/flash_attn-2.7.2.post1.dist-info/*
    /usr/local/lib/python3.10/dist-packages/flash_attn/*
    /usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so
    /usr/local/lib/python3.10/dist-packages/hopper/*
Proceed (Y/n)? y
y
  Successfully uninstalled flash-attn-2.7.2.post1


In [None]:

# train
trainer.train() # there will not be a progress bar since tqdm is disabled

# save model
trainer.save_model()


  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss
1,No log,2.125964
2,No log,2.107706
3,No log,2.087287
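
The validation losses above are easier to interpret as perplexities. The short sketch below assumes the reported loss is the mean per-token cross-entropy in nats, which is the usual convention for causal language-model training but is not confirmed by this notebook's output.

```python
import math

# Validation loss per epoch, copied from the table above
val_losses = {1: 2.125964, 2: 2.107706, 3: 2.087287}

# Perplexity is the exponential of the mean per-token cross-entropy
perplexities = {epoch: math.exp(loss) for epoch, loss in val_losses.items()}

for epoch, ppl in perplexities.items():
    print(f"Epoch {epoch}: validation perplexity {ppl:.2f}")
```

The steady but small decrease across the three epochs suggests the model is still improving slowly on held-out data rather than overfitting.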


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json.
Access to model meta-llama/Llama-2-7b-chat-hf is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-2-7b-chat-hf.
  return fn(*args, **kwargs)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/config.json.
Access to model meta-llama/Llama-2-7b-chat-hf is restricted. You must have access to it and be authentica

# 6.) Chatbot User Interface:

A simple chatbot user interface was set up using ipywidgets so that the user can interact with the model, ask mental-health-related questions, and sample the responses. In general, the responses are on topic, non-judgmental, and empathetic, and they offer help without framing advice as a prescription. I've noticed that the model also tends to refer the user to real resources and websites, a point worth checking because hallucinated references are always a concern with LLMs. However, there are some major issues. For one, the sentences are generally short and simple, and are sometimes redundant or repetitive. This is somewhat expected, as we are using a 4-bit quantized version of the smallest Llama 2 model, and upgrading should alleviate these issues. Hallucinations also occasionally occur in the responses, especially when the input is unexpected. For instance, when the input is blank, I've observed the model make off-the-wall statements, such as claiming to have suffered sexual harassment and emotional abuse while growing up. Future iterations should improve the prompts and the response parsing, in addition to using a more powerful model, to help mitigate these issues.

In [None]:
from IPython.display import display, HTML, clear_output  # IPython.core.display is deprecated for these imports
from ipywidgets import widgets, Layout, Box

text_input = widgets.Textarea(
    value='',
    placeholder='Type your message here...',
    description='Input:',
    disabled=False,
    layout=Layout(width='38.2%')
)

button = widgets.Button(description="Submit")

In [None]:
output_area = widgets.Output(layout=Layout(width='61.8%'))

# Add a processing indication label below your text_input
processing_label = widgets.Label(value='')  # Initialize with an empty value

# System message to better instruct chatbot
system_message = """You are a helpful and and truthful psychology and psychotherapy assistant. Your primary role is to provide empathetic, understanding, and non-judgmental responses to users seeking emotional and psychological support.
                  Always respond with empathy and demonstrate active listening; try to focus on the user. Your responses should reflect that you understand the user's feelings and concerns. If a user expresses thoughts of self-harm, suicide, or harm to others, prioritize their safety.
                  Encourage them to seek immediate professional help and provide emergency contact numbers when appropriate.  You are not a licensed medical professional. Do not diagnose or prescribe treatments.
                  Instead, encourage users to consult with a licensed therapist or medical professional for specific advice. Avoid taking sides or expressing personal opinions. Your role is to provide a safe space for users to share and reflect.
                  Remember, your goal is to provide a supportive and understanding environment for users to share their feelings and concerns. Always prioritize their well-being and safety."""

# Display Greeting Message
with output_area:
  display(HTML(f'<strong>Assistant: </strong>Hi there! How are you today?'))
  display(HTML('<br/><br/>'))

def on_submit_button_clicked(b):
    with output_area:
        # Get user input
        user_input = text_input.value
        formatted = f"<s>[INST] <<SYS>>{system_message}<</SYS>>{user_input} [/INST]"
        # Display input
        display(HTML(f'<strong>User:</strong> {user_input}'))
        display(HTML('<br/><br/>'))

        # Show processing indication
        processing_label.value = 'Processing...'

        # Use your chatbot model to get a response
        input_ids = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=2048).input_ids.cuda()
        # with torch.inference_mode():
        outputs = model.generate(input_ids=input_ids, do_sample=True, top_p=0.9, temperature=0.95)
        # Decode only the newly generated tokens; slicing the decoded text by the prompt's
        # character length is fragile because special tokens are dropped during decoding
        translated_output = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

        # Display response (characters added to bold User/Assistant)
        display(HTML(f'<strong>Assistant:</strong> {translated_output}'))
        display(HTML('<br/><br/>'))

        # Clear the processing indication
        processing_label.value = ''

        # Clear the text input
        text_input.value = ''

button.on_click(on_submit_button_clicked)
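
Since blank input was observed to trigger hallucinated replies, one simple mitigation is to validate the text before it reaches the model. The sketch below is an illustration only: the helper names and the `MAX_INPUT_CHARS` limit are hypothetical and are not part of the interface above. It also escapes the text before it is embedded in HTML, so that user-typed markup is shown literally.

```python
import html
from typing import Optional

MAX_INPUT_CHARS = 2000  # hypothetical limit, chosen only for illustration


def prepare_user_input(raw: str) -> Optional[str]:
    """Return cleaned text, or None if the input should not be sent to the model."""
    text = raw.strip()
    if not text:  # blank or whitespace-only input
        return None
    return text[:MAX_INPUT_CHARS]


def to_safe_html(text: str) -> str:
    """Escape user text before embedding it in an HTML widget."""
    return html.escape(text)


print(prepare_user_input("   \n"))          # blank input is rejected
print(prepare_user_input("  I feel low. "))  # surrounding whitespace is removed
print(to_safe_html("<b>hello</b>"))
```

In the click handler, a `None` result could be used to return early with a short message asking the user to type something, instead of calling `model.generate`.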

In [None]:
# Display widgets
display(text_input, button, processing_label, output_area)

Textarea(value='', description='Input:', layout=Layout(width='38.2%'), placeholder='Type your message here...'…

Button(description='Submit', style=ButtonStyle())

Label(value='')

Output(layout=Layout(width='61.8%'))

In [None]:
clear_output()

# Conclusion
This notebook covered the fine-tuning of a 7-billion-parameter Llama 2 model for mental health counseling, using a dataset of Q&A pairs from counselchat.com. This was made possible through quantization, low-rank adaptation, and prompts crafted to suit the specific requirements of mental health interactions. A chatbot interface implemented with ipywidgets also facilitated user interaction and evaluation of the model's responses.

There are some areas for improvement, though. The model sometimes produced brief, repetitive, or even hallucinatory responses. This can be linked to the constraints of using a 4-bit quantized version of the smallest Llama 2 model, and moving to a larger model should improve performance. Refining the prompt design and the response processing should also make the model more effective. Looking ahead, our goal is to enrich the dataset with varied sources to build a more comprehensive response system. Although the current interface is rudimentary, a UI centered on user needs, with features such as tailored exercises and coping strategies, could significantly improve the user experience and its benefits.

In [None]:
# Push model to hub since Google colab empties out directory
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
trainer.push_to_hub('')


adapter_model.bin:   0%|          | 0.00/134M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/4.03k [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

'https://huggingface.co/langecod/CounselLlama7B/tree/main/'

In [None]:
pip freeze > requirements.txt
