## Check GPU Availability

In [1]:
!nvidia-smi

Fri Nov 24 10:07:21 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            Off  | 00000000:00:05.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|       

## Install required libraries

In [2]:
!pip install trl transformers accelerate git+https://github.com/huggingface/peft.git -Uqqq
!pip install datasets bitsandbytes einops wandb -Uqqq

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tokenizers 0.14.1 requires huggingface_hub<0.18,>=0.16.4, but you have huggingface-hub 0.19.4 which is incompatible.[0m[31m
[0m

## Importing libraries

In [4]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, GenerationConfig
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
import warnings
warnings.filterwarnings("ignore")

In [5]:
from huggingface_hub import notebook_login
notebook_login()



VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Load custom Mental Health conv dataset

In [38]:
data = load_dataset("vibhorag101/phr_mental_therapy_dataset")
data

Downloading readme:   0%|          | 0.00/460 [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/211M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/99086 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 99086
    })
})

In [40]:
print(data["train"][0]['text'])

<s>[INST] <<SYS>>
You are a helpful and joyous mental therapy assistant. Always answer as helpfully and cheerfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>

I've been feeling so sad and overwhelmed lately. Work has become such a massive source of stress for me. [/INST] Hey there, I'm here to listen and support you. It sounds like work has been really challenging lately. Can you tell me more about what's been going on? </s><s>[INST] I recently got a promotion at work, which I thought would be exciting. But the added responsibilities and pressure have just taken a toll on my mental health. It's been

## Model Training

In [42]:
model_name = "typeof/neural-chat-7b-v3-1-sharded" # sharded Mistral-7b model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # load model in 4-bit precision
    bnb_4bit_quant_type="nf4",    # pre-trained model should be quantized in 4-bit NF format
    bnb_4bit_use_double_quant=True, # Using double quantization as mentioned in QLoRA paper
    bnb_4bit_compute_dtype=torch.bfloat16, # During computation, pre-trained model should be loaded in BF16 format
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config, # Use bitsandbytes config
    device_map="auto",  # Specifying device_map="auto" so that HF Accelerate will determine which GPU to put each layer of the model on
    trust_remote_code=True, # Set trust_remote_code=True to use Mistral-7b model with custom code
)

Loading checkpoint shards:   0%|          | 0/291 [00:00<?, ?it/s]

In [43]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Set trust_remote_code=True
tokenizer.pad_token = tokenizer.eos_token # Setting pad_token same as eos_token

In [44]:
model = prepare_model_for_kbit_training(model)

lora_alpha = 32 # scaling factor for the weight matrices
lora_dropout = 0.05 # dropout probability of the LoRA layers
lora_rank = 32 # dimension of the low-rank matrices

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",  # setting to 'none' for only training weight params instead of biases
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"]
)

peft_model = get_peft_model(model, peft_config)

In [52]:
output_dir = "./Intelv3-neuralchat-Viborag-MentalAssitant"
per_device_train_batch_size = 8 # reduce batch size by 2x only if out-of-memory error
gradient_accumulation_steps = 8  # increase gradient accumulation steps by  2x only if batch size is reduced
optim = "paged_adamw_32bit" # activates the paging for better memory management
save_strategy="steps" # checkpoint save strategy to adopt during training
save_steps = 10 # number of updates steps before two checkpoint saves
logging_steps = 10  # number of update steps between two logs if logging_strategy="steps"
learning_rate = 2e-4  # learning rate for AdamW optimizer
max_grad_norm = 0.3 # maximum gradient norm (for gradient clipping)
max_steps = 200       # training will happen for 150 steps
warmup_ratio = 0.03 # number of steps used for a linear warmup from 0 to learning_rate
lr_scheduler_type = "cosine"  # learning rate scheduler

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    fp16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
    push_to_hub=True,
)

In [56]:
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=data['train'],
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=256,
    tokenizer=tokenizer,
    args=training_arguments,
)

Map:   0%|          | 0/99086 [00:00<?, ? examples/s]

In [57]:
# upcasting the layer norms in torch.bfloat16 for more stable training
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module = module.to(torch.bfloat16)

In [49]:
# authenticate WandB for logging metrics
import wandb
wandb.login()





True

In [58]:
peft_model.config.use_cache = False
trainer.train()

{'loss': 2.2907, 'learning_rate': 0.00019979028262377118, 'epoch': 0.01}
{'loss': 0.8395, 'learning_rate': 0.00019744105246469263, 'epoch': 0.01}
{'loss': 0.5853, 'learning_rate': 0.00019254212296427044, 'epoch': 0.02}
{'loss': 0.5254, 'learning_rate': 0.00018522168236559695, 'epoch': 0.03}
{'loss': 0.5134, 'learning_rate': 0.00017567128158176953, 'epoch': 0.03}
{'loss': 0.492, 'learning_rate': 0.000164140821963114, 'epoch': 0.04}
{'loss': 0.4846, 'learning_rate': 0.00015093201623287631, 'epoch': 0.05}
{'loss': 0.468, 'learning_rate': 0.00013639049369634876, 'epoch': 0.05}
{'loss': 0.4397, 'learning_rate': 0.00012089675630312754, 'epoch': 0.06}
{'loss': 0.4207, 'learning_rate': 0.00010485622221144484, 'epoch': 0.06}
{'loss': 0.4098, 'learning_rate': 8.868861738047158e-05, 'epoch': 0.07}
{'loss': 0.4112, 'learning_rate': 7.281699277636572e-05, 'epoch': 0.08}
{'loss': 0.4083, 'learning_rate': 5.765665457425102e-05, 'epoch': 0.08}
{'loss': 0.4105, 'learning_rate': 4.360429701490934e-05, '

TrainOutput(global_step=200, training_loss=0.5555560445785522, metrics={'train_runtime': 11207.0046, 'train_samples_per_second': 1.142, 'train_steps_per_second': 0.018, 'train_loss': 0.5555560445785522, 'epoch': 0.13})

In [61]:
trainer.push_to_hub()

'https://huggingface.co/Jaykumaran17/Intelv3-neuralchat-Viborag-MentalAssitant/tree/main/'

## Inference Pipeline

In [6]:
# Loading original model
model_name = "typeof/neural-chat-7b-v3-1-sharded"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

config.json:   0%|          | 0.00/565 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/291 [00:00<?, ?it/s]

model-00001-of-00291.safetensors:   0%|          | 0.00/262M [00:00<?, ?B/s]

model-00002-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00003-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00004-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00005-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00006-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00007-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00008-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00009-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00010-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00011-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00012-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00013-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00014-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00015-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00016-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00017-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00018-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00019-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00020-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00021-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00022-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00023-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00024-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00025-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00026-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00027-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00028-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00029-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00030-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00031-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00032-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00033-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00034-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00035-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00036-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00037-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00038-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00039-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00040-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00041-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00042-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00043-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00044-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00045-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00046-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00047-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00048-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00049-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00050-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00051-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00052-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00053-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00054-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00055-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00056-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00057-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00058-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00059-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00060-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00061-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00062-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00063-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00064-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00065-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00066-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00067-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00068-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00069-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00070-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00071-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00072-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00073-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00074-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00075-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00076-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00077-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00078-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00079-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00080-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00081-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00082-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00083-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00084-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00085-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00086-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00087-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00088-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00089-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00090-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00091-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00092-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00093-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00094-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00095-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00096-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00097-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00098-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00099-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00100-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00101-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00102-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00103-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00104-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00105-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00106-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00107-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00108-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00109-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00110-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00111-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00112-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00113-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00114-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00115-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00116-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00117-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00118-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00119-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00120-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00121-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00122-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00123-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00124-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00125-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00126-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00127-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00128-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00129-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00130-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00131-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00132-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00133-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00134-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00135-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00136-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00137-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00138-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00139-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00140-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00141-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00142-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00143-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00144-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00145-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00146-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00147-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00148-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00149-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00150-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00151-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00152-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00153-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00154-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00155-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00156-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00157-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00158-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00159-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00160-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00161-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00162-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00163-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00164-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00165-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00166-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00167-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00168-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00169-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00170-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00171-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00172-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00173-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00174-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00175-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00176-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00177-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00178-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00179-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00180-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00181-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00182-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00183-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00184-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00185-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00186-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00187-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00188-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00189-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00190-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00191-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00192-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00193-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00194-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00195-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00196-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00197-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00198-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00199-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00200-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00201-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00202-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00203-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00204-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00205-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00206-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00207-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00208-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00209-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00210-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00211-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00212-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00213-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00214-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00215-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00216-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00217-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00218-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00219-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00220-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00221-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00222-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00223-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00224-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00225-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00226-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00227-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00228-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00229-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00230-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00231-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00232-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00233-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00234-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00235-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00236-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00237-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00238-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00239-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00240-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00241-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00242-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00243-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00244-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00245-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00246-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00247-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00248-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00249-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00250-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00251-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00252-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00253-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00254-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00255-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00256-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00257-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00258-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00259-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00260-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00261-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00262-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00263-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00264-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00265-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00266-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00267-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00268-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00269-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00270-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00271-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00272-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00273-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00274-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00275-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00276-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00277-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00278-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00279-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00280-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00281-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00282-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00283-of-00291.safetensors:   0%|          | 0.00/8.39M [00:00<?, ?B/s]

model-00284-of-00291.safetensors:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

model-00285-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00286-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00287-of-00291.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

model-00288-of-00291.safetensors:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

model-00289-of-00291.safetensors:   0%|          | 0.00/8.34k [00:00<?, ?B/s]

model-00290-of-00291.safetensors:   0%|          | 0.00/8.31k [00:00<?, ?B/s]

model-00291-of-00291.safetensors:   0%|          | 0.00/262M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/291 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/953 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/145 [00:00<?, ?B/s]

In [9]:
# Loading PEFT model
PEFT_MODEL = "Jaykumaran17/Intelv3-neuralchat-Viborag-MentalAssitant"

config = PeftConfig.from_pretrained(PEFT_MODEL)
peft_base_model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

peft_model = PeftModel.from_pretrained(peft_base_model, PEFT_MODEL)

peft_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
peft_tokenizer.pad_token = peft_tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/291 [00:00<?, ?it/s]

In [12]:
# Function to generate responses from both original model and PEFT model and compare their answers.
def generate_answer(query):
  system_prompt = """Answer the following question truthfully.
  If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
  If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'."""

  user_prompt = f"""<HUMAN>: {query}
  <ASSISTANT>: """

  final_prompt = system_prompt + "\n" + user_prompt

  device = "cuda:0"
  dashline = "-".join("" for i in range(50))

  encoding = tokenizer(final_prompt, return_tensors="pt").to(device)
  outputs = model.generate(input_ids=encoding.input_ids, generation_config=GenerationConfig(max_new_tokens=256, pad_token_id = tokenizer.eos_token_id, \
                                                                                                                     eos_token_id = tokenizer.eos_token_id, attention_mask = encoding.attention_mask, \
                                                                                                                     temperature=0.4, top_p=0.6, repetition_penalty=1.3, num_return_sequences=1,))
  text_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

  print(dashline)
  print(f'ORIGINAL MODEL RESPONSE:\n{text_output}')
  print(dashline)

  peft_encoding = peft_tokenizer(final_prompt, return_tensors="pt").to(device)
  peft_outputs = peft_model.generate(input_ids=peft_encoding.input_ids, generation_config=GenerationConfig(max_new_tokens=256, pad_token_id = peft_tokenizer.eos_token_id, \
                                                                                                                     eos_token_id = peft_tokenizer.eos_token_id, attention_mask = peft_encoding.attention_mask, \
                                                                                                                     temperature=0.4, top_p=0.6, repetition_penalty=1.3, num_return_sequences=1,))
  peft_text_output = peft_tokenizer.decode(peft_outputs[0], skip_special_tokens=True)

  print(f'PEFT MODEL RESPONSE:\n{peft_text_output}')
  print(dashline)

## Compare responses between Original model and PEFT model

In [14]:
query = "How can I prevent anxiety and depression?"
generate_answer(query)

-------------------------------------------------
ORIGINAL MODEL RESPONSE:
Answer the following question truthfully.
  If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
  If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
<HUMAN>: How can I prevent anxiety and depression?
  <ASSISTANT>: 
To reduce or manage anxiety and depression, it's important to maintain good mental health practices such as regular exercise, maintaining healthy sleep patterns, eating well-balanced meals, engaging in activities that bring joy like hobbies or spending time with loved ones, practicing mindfulness techniques (like meditation), seeking professional help if needed, and building strong social support networks. Remember, everyone experiences these emotions differently; what works best might vary from person to person.
-------------------------------------------------
PEFT MODEL RESPONSE:
Answer the following question truthfu

In [16]:
query = "How to take care of mental health?"
generate_answer(query)

-------------------------------------------------
ORIGINAL MODEL RESPONSE:
Answer the following question truthfully.
  If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
  If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
<HUMAN>: How to take care of mental health?
  <ASSISTANT>: 
To maintain good mental health, it's important to practice self-care and engage in activities that promote wellbeing. Some ways include:
1. Exercise regularly - Physical activity can help reduce stress levels and improve mood.
2. Get enough sleep - Adequate rest helps your brain function properly and supports overall physical health.
3. Eat healthy foods - Proper nutrition provides essential nutrients needed by both body and mind.
4. Connect with others - Strong social connections contribute positively to our emotional state.
5. Engage in hobbies or interests - Pursuing enjoyable pastimes gives us a sense of accomplishment and

In [19]:
query = "What is the warning sign of depression?"
generate_answer(query)

-------------------------------------------------
ORIGINAL MODEL RESPONSE:
Answer the following question truthfully.
  If you don't know the answer, respond 'Sorry, I don't know the answer to this question.'.
  If the question is too complex, respond 'Kindly, consult a psychiatrist for further queries.'.
  <ASSISTANT>: 
Some common signs include persistent feelings of sadness or hopelessness, loss of interest in activities that used to bring joy, changes in sleep patterns (sleeping more/less than usual), appetite disturbances, fatigue and low energy levels, difficulty concentrating, irritability, increased sensitivity to rejection, self-loathing, thoughts about suicide, etc. Please note these can vary from person to person. It would be best if someone experiencing such symptoms could seek professional help to understand their situation better.
-------------------------------------------------
PEFT MODEL RESPONSE:
Answer the following question truthfully.
  If you don't know the answer,