# Mental Health Chatbot with GPT-2 Small
This notebook implements a chatbot fine-tuned on mental health dialogue datasets using GPT-2 Small (124M parameters). It includes data loading, model training, sentiment analysis, and a Streamlit app for interaction. Uses FP16 without quantization to fit T4 GPU (16GB VRAM).

In [1]:
# Clean corrupted transformers and pyarrow installations
!rm -rf /usr/local/lib/python3.11/dist-packages/~ransformers
!rm -rf /usr/local/lib/python3.11/dist-packages/transformers
!rm -rf /usr/local/lib/python3.11/dist-packages/pyarrow
!pip uninstall -y transformers sentence-transformers torch torchvision torchaudio datasets gcsfs fsspec pyarrow

# Clear pip cache to avoid reusing corrupted packages
!pip cache purge

Found existing installation: transformers 4.41.2
Uninstalling transformers-4.41.2:
  Successfully uninstalled transformers-4.41.2
Found existing installation: sentence-transformers 2.6.1
Uninstalling sentence-transformers-2.6.1:
  Successfully uninstalled sentence-transformers-2.6.1
Found existing installation: torch 2.1.0
Uninstalling torch-2.1.0:
  Successfully uninstalled torch-2.1.0
Found existing installation: torchvision 0.16.0
Uninstalling torchvision-0.16.0:
  Successfully uninstalled torchvision-0.16.0
Found existing installation: torchaudio 2.1.0
Uninstalling torchaudio-2.1.0:
  Successfully uninstalled torchaudio-2.1.0
Found existing installation: datasets 2.20.0
Uninstalling datasets-2.20.0:
  Successfully uninstalled datasets-2.20.0
Found existing installation: gcsfs 2023.10.0
Uninstalling gcsfs-2023.10.0:
  Successfully uninstalled gcsfs-2023.10.0
Found existing installation: fsspec 2023.10.0
Uninstalling fsspec-2023.10.0:
  Successfully uninstalled fsspec-2023.10.0
Found

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
%pip uninstall -y numpy pandas datasets torch torchvision torchaudio transformers sentence-transformers tqdm nltk rouge_score peft accelerate streamlit pyngrok fsspec gcsfs pyarrow protobuf google-auth tornado decorator jedi -q  # Remove all conflicting packages
%pip install --no-cache-dir --force-reinstall \
    torch==2.1.0 \
    torchvision==0.16.0 \
    torchaudio==2.1.0 \
    transformers==4.41.2 \
    datasets==2.20.0 \
    sentence-transformers==2.6.1 \
    pandas==2.2.2 \
    numpy==1.25.2 \
    tqdm==4.67.0 \
    nltk==3.9.1 \
    rouge_score==0.1.2 \
    peft==0.9.0 \
    accelerate==1.0.1 \
    streamlit==1.45.1 \
    pyngrok==7.2.0 \
    fsspec==2023.10.0 \
    gcsfs==2023.10.0 \
    pyarrow==15.0.2 \
    protobuf==5.29.1 \
    google-auth==2.38.0 \
    tornado==6.4.2 \
    decorator==4.4.2 \
    jedi>=0.16 --no-deps -q

[0m

In [4]:
%pip install jedi>=0.16 -q

In [5]:
import numpy
print(numpy.__version__)

1.25.2


In [13]:
!pip uninstall -y numpy -q
!pip install numpy==1.25.2 --no-cache-dir --force-reinstall -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/18.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/18.2 MB[0m [31m32.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/18.2 MB[0m [31m75.8 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━[0m [32m11.5/18.2 MB[0m [31m131.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m208.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.25.2 which is incompatible.
blosc2 3.3.2 requires numpy>=1.26, but you have numpy 1.25.2 which is incompatibl

In [6]:
import requests
import numpy
import torch
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModel
import datasets
import pandas as pd

print("✅ All key libraries loaded successfully!")


✅ All key libraries loaded successfully!


In [7]:
!nvidia-smi
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
print(f"CUDA version: {torch.version.cuda}")
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please select T4 GPU runtime.")

Mon May 26 11:38:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   38C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [8]:
import numpy as np
import pandas as pd
import json
import nltk
from datasets import Dataset, load_dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BertTokenizer,
    AutoModelForSequenceClassification
)
from peft import LoraConfig, get_peft_model
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
from google.colab import userdata
from tqdm import tqdm
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [9]:
try:
    HF_TOKEN = userdata.get('HF_TOKEN')
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN not found in Colab Secrets.")
except Exception as e:
    print(f"Error loading HF_TOKEN: {e}")
    print("Please set HF_TOKEN in Colab Secrets or input manually.")
    HF_TOKEN = input("Enter your Hugging Face token: ")

#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

Error loading HF_TOKEN: Secret HF_TOKEN does not exist.
Please set HF_TOKEN in Colab Secrets or input manually.
Enter your Hugging Face token: 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


In [10]:
# Load and merge datasets
def load_and_merge_datasets():
    hope_path = "/content/drive/MyDrive/Hope_data"
    try:
        files = [f for f in os.listdir(hope_path) if f.endswith(".csv")]
    except FileNotFoundError:
        print(f"Error: HOPE dataset directory not found at {hope_path}")
        files = []

    hope_pairs = []
    for file in files:
        df = pd.read_csv(os.path.join(hope_path, file))
        df['Speaker'] = df['Type'].map({'T': 'Therapist', 'P': 'Client'})
        df['Content'] = df['Utterance']
        for i in range(1, len(df)):
            if df.loc[i-1, 'Speaker'] == "Client" and df.loc[i, 'Speaker'] == "Therapist":
                hope_pairs.append({
                    "prompt": f"Client: {df.loc[i-1, 'Content']}",
                    "response": f"Therapist: {df.loc[i, 'Content']}",
                    "source": "HOPE"
                })
    print(f"Extracted {len(hope_pairs)} dialogue pairs from HOPE dataset")

    try:
        empathetic_ds = load_dataset("empathetic_dialogues")
        empathy_pairs = []
        prev_conv_id = None
        context = ""
        for row in empathetic_ds['train']:
            if row['utterance_idx'] > 0 and row['conv_id'] == prev_conv_id:
                empathy_pairs.append({
                    "prompt": f"Client: {context}",
                    "response": f"Therapist: {row['utterance']}",
                    "emotion": row['context'],
                    "source": "EmpatheticDialogues"
                })
            context = row['utterance']
            prev_conv_id = row['conv_id']
        print(f"Extracted {len(empathy_pairs)} dialogue pairs from EmpatheticDialogues dataset")
    except Exception as e:
        print(f"Error loading EmpatheticDialogues: {e}")
        empathy_pairs = []

    try:
        url = "https://raw.githubusercontent.com/nbertagnolli/counsel-chat/master/data/counselchat-data.csv"
        response = requests.get(url)
        with open("counselchat-data.csv", "wb") as f:
            f.write(response.content)
        cc_df = pd.read_csv("counselchat-data.csv")
        counsel_pairs = []
        for _, row in cc_df.iterrows():
            if pd.notnull(row['questionText']) and pd.notnull(row['answerText']):
                counsel_pairs.append({
                    "prompt": f"Client: {row['questionText']}",
                    "response": f"Therapist: {row['answerText'].replace('<p>','').replace('</p>','').strip()}",
                    "source": "CounselChat"
                })
    except Exception as e:
        print(f"Error loading CounselChat: {e}")
        counsel_pairs = []

    dialogue_data = hope_pairs + empathy_pairs + counsel_pairs
    print(f"Total dialogue pairs: {len(dialogue_data)}")

    print("\nSample data from each source:")
    for source in ["HOPE", "EmpatheticDialogues", "CounselChat"]:
        samples = [d for d in dialogue_data if d.get("source") == source]
        if samples:
            print(f"\n{source} sample:")
            sample = np.random.choice(samples)
            print(f"Prompt: {sample['prompt']}")
            print(f"Response: {sample['response']}")

    dialogue_data = [{'prompt': d['prompt'], 'response': d['response'], 'emotion': d.get('emotion', 'unknown')} for d in dialogue_data]
    return Dataset.from_list(dialogue_data)

# Load and split dataset
dialogue_data = load_and_merge_datasets()
train_data, val_data = train_test_split(dialogue_data.to_pandas(), test_size=0.2, random_state=42)
train_dataset = Dataset.from_pandas(train_data.reset_index(drop=True))
val_dataset = Dataset.from_pandas(val_data.reset_index(drop=True))


Extracted 225 dialogue pairs from HOPE dataset


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script:   0%|          | 0.00/4.51k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.15k [00:00<?, ?B/s]

The repository for empathetic_dialogues contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/empathetic_dialogues.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/28.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/76673 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/12030 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10943 [00:00<?, ? examples/s]

Extracted 58829 dialogue pairs from EmpatheticDialogues dataset
Total dialogue pairs: 60437

Sample data from each source:

HOPE sample:
Prompt: Client: what cancer right I think I probably have cancer i'm not i'm not doing very well I just look headaches and I was just
Response: Therapist: let's try to figure some of this out. Okay. When you said you were throwing up blubbery, throwing up bright red blood and was it like little spots, big spots?

EmpatheticDialogues sample:
Prompt: Client: I had to go to the doctor recently to get something checked_comma_ always makes me worried going there.
Response: Therapist: Are you going to be okay

CounselChat sample:
Prompt: Client: I've hit my head on walls and floors ever since I was young.  I sometimes still do it but I don't exactly know why, 

I have anxiety and I had a rough childhood but now I'll start to hit my head and sometimes not realize it but I don't know how to stop or even why I'm doing it. 

How can I help myself to change my b

In [11]:
model_path = "/content/drive/MyDrive/gpt2-finetuned"
if os.path.exists(model_path) and os.path.isfile(os.path.join(model_path, "config.json")):
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        print(f"Loaded fine-tuned GPT-2 Small from {model_path}")
    except Exception as e:
        print(f"Error loading fine-tuned model: {e}")
        model = AutoModelForCausalLM.from_pretrained(
            "gpt2",
            torch_dtype=torch.float16,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
else:
    print(f"Fine-tuned model not found at {model_path}. Loading pre-trained GPT-2 Small.")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

tokenizer.pad_token = tokenizer.eos_token

# Load BERT for sentiment analysis
try:
    bert_model = AutoModelForSequenceClassification.from_pretrained("bhadresh-savani/bert-base-uncased-emotion", token=HF_TOKEN)
    bert_tokenizer = BertTokenizer.from_pretrained("bhadresh-savani/bert-base-uncased-emotion", token=HF_TOKEN)
    print("Loaded BERT model for sentiment analysis.")
except Exception as e:
    print(f"Error loading BERT model: {e}")
    raise

Fine-tuned model not found at /content/drive/MyDrive/gpt2-finetuned. Loading pre-trained GPT-2 Small.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/935 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Loaded BERT model for sentiment analysis.


In [12]:
from peft import LoraConfig, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from datasets import Dataset

# Re-run load_and_merge_datasets to ensure raw dataset
def load_and_merge_datasets():
    import os
    import pandas as pd
    import requests
    from datasets import load_dataset
    import numpy as np

    hope_path = "/content/drive/MyDrive/Hope_data"
    try:
        files = [f for f in os.listdir(hope_path) if f.endswith(".csv")]
    except FileNotFoundError:
        print(f"Error: HOPE dataset directory not found at {hope_path}")
        files = []

    hope_pairs = []
    for file in files:
        df = pd.read_csv(os.path.join(hope_path, file))
        df['Speaker'] = df['Type'].map({'T': 'Therapist', 'P': 'Client'})
        df['Content'] = df['Utterance']
        for i in range(1, len(df)):
            if df.loc[i-1, 'Speaker'] == "Client" and df.loc[i, 'Speaker'] == "Therapist":
                hope_pairs.append({
                    "prompt": f"Client: {df.loc[i-1, 'Content']}",
                    "response": f"Therapist: {df.loc[i, 'Content']}",
                    "source": "HOPE"
                })
    print(f"Extracted {len(hope_pairs)} dialogue pairs from HOPE dataset")

    try:
        empathetic_ds = load_dataset("empathetic_dialogues")
        empathy_pairs = []
        prev_conv_id = None
        context = ""
        for row in empathetic_ds['train']:
            if row['utterance_idx'] > 0 and row['conv_id'] == prev_conv_id:
                empathy_pairs.append({
                    "prompt": f"Client: {context}",
                    "response": f"Therapist: {row['utterance']}",
                    "emotion": row['context'],
                    "source": "EmpatheticDialogues"
                })
            context = row['utterance']
            prev_conv_id = row['conv_id']
        print(f"Extracted {len(empathy_pairs)} dialogue pairs from EmpatheticDialogues dataset")
    except Exception as e:
        print(f"Error loading EmpatheticDialogues: {e}")
        empathy_pairs = []

    try:
        url = "https://raw.githubusercontent.com/nbertagnolli/counsel-chat/master/data/counselchat-data.csv"
        response = requests.get(url)
        with open("counselchat-data.csv", "wb") as f:
            f.write(response.content)
        cc_df = pd.read_csv("counselchat-data.csv")
        counsel_pairs = []
        for _, row in cc_df.iterrows():
            if pd.notnull(row['questionText']) and pd.notnull(row['answerText']):
                counsel_pairs.append({
                    "prompt": f"Client: {row['questionText']}",
                    "response": f"Therapist: {row['answerText'].replace('<p>','').replace('</p>','').strip()}",
                    "source": "CounselChat"
                })
    except Exception as e:
        print(f"Error loading CounselChat: {e}")
        counsel_pairs = []

    dialogue_data = hope_pairs + empathy_pairs + counsel_pairs
    print(f"Total dialogue pairs: {len(dialogue_data)}")

    print("\nSample data from each source:")
    for source in ["HOPE", "EmpatheticDialogues", "CounselChat"]:
        samples = [d for d in dialogue_data if d.get("source") == source]
        if samples:
            print(f"\n{source} sample:")
            sample = np.random.choice(samples)
            print(f"Prompt: {sample['prompt']}")
            print(f"Response: {sample['response']}")

    dialogue_data = [{'prompt': d['prompt'], 'response': d['response'], 'emotion': d.get('emotion', 'unknown')} for d in dialogue_data]
    return Dataset.from_list(dialogue_data)

# Reload and split dataset
dialogue_data = load_and_merge_datasets()
train_data, val_data = train_test_split(dialogue_data.to_pandas(), test_size=0.2, random_state=42)
train_dataset = Dataset.from_pandas(train_data.reset_index(drop=True))
val_dataset = Dataset.from_pandas(val_data.reset_index(drop=True))

def tokenize_function(examples):
    prompt_key = None
    response_key = None
    for key in examples.keys():
        if key.lower() in ['prompt', 'client', 'questiontext', 'client_text']:
            prompt_key = key
        if key.lower() in ['response', 'therapist', 'answertext', 'therapist_text']:
            response_key = key

    if not prompt_key or not response_key:
        raise KeyError(f"Could not find prompt or response columns in dataset. Available keys: {list(examples.keys())}")

    print(f"Using prompt_key: {prompt_key}, response_key: {response_key}")
    texts = [examples[prompt_key][i] + " " + examples[response_key][i] for i in range(len(examples[prompt_key]))]
    tokenized = tokenizer(texts, truncation=True, max_length=128, padding='max_length')
    tokenized['labels'] = tokenized['input_ids'].copy()
    return tokenized

# Custom collator to prevent unexpected kwargs
class CustomDataCollator:
    def __call__(self, examples):
        input_ids = torch.stack([ex['input_ids'] for ex in examples])
        attention_mask = torch.stack([ex['attention_mask'] for ex in examples])
        labels = torch.stack([ex['labels'] for ex in examples])
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Debug dataset before tokenization
print("Raw train dataset columns:", train_dataset.column_names)
print("Sample raw train item:", train_dataset[0])

# Check if dataset is already tokenized
if 'input_ids' in train_dataset.column_names and 'attention_mask' in train_dataset.column_names:
    print("Dataset is already tokenized. Skipping tokenization.")
else:
    # Limit dataset to manage VRAM
    train_dataset = train_dataset.select(range(min(10000, len(train_dataset))))
    val_dataset = val_dataset.select(range(min(2500, len(val_dataset))))

    # Tokenize dataset
    try:
        train_dataset = train_dataset.map(tokenize_function, batched=True)
        val_dataset = val_dataset.map(tokenize_function, batched=True)
    except KeyError as e:
        print(f"Tokenization failed: {e}")
        raise

    # Remove extra columns
    train_dataset = train_dataset.remove_columns([col for col in train_dataset.column_names if col not in ['input_ids', 'attention_mask', 'labels']])
    val_dataset = val_dataset.remove_columns([col for col in val_dataset.column_names if col not in ['input_ids', 'attention_mask', 'labels']])
    train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
    val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# Debug dataset after tokenization
print("Train dataset columns:", train_dataset.column_names)
print("Sample train item keys:", train_dataset[0].keys())

# Apply LoRA with explicit fan_in_fan_out
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["c_attn", "c_proj", "mlp.c_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    fan_in_fan_out=True
)
model = get_peft_model(model, lora_config)

# Debug LoRA parameters and gradients
print("LoRA trainable parameters:")
trainable_params = 0
total_params = 0
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f" - {name}: {param.shape}, requires_grad={param.requires_grad}")
        trainable_params += param.numel()
    total_params += param.numel()
print(f"Trainable parameters: {trainable_params}")
print(f"Total parameters: {total_params}")
print(f"Percentage trainable: {(trainable_params / total_params) * 100:.2f}%")

# Ensure model is in training mode
model.train()

# Disable cache
model.config.use_cache = False

# Debug: Test forward and backward pass
print("Testing model forward and backward pass...")
sample_batch = train_dataset[:1]
input_ids = sample_batch['input_ids'].to('cuda')
attention_mask = sample_batch['attention_mask'].to('cuda')
labels = sample_batch['labels'].to('cuda')
try:
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print("Forward pass successful. Loss:", outputs.loss.item())
    print("Loss requires grad:", outputs.loss.requires_grad)
    if outputs.loss.requires_grad:
        outputs.loss.backward()
        print("Backward pass successful. Gradients for LoRA parameters:")
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                print(f" - {name}: grad_norm={param.grad.norm().item()}")
    else:
        print("Loss does not require grad. Check LoRA configuration.")
except Exception as e:
    print(f"Forward/backward pass failed: {e}")
    raise

# Manual training loop
data_collator = CustomDataCollator()
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=data_collator)

# Optimizer with only trainable parameters
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=2e-5
)
num_epochs = 3
gradient_accumulation_steps = 4
max_grad_norm = 0.5
device = torch.device('cuda')
model.to(device)

# Debug first batch
print("Inspecting first batch from DataLoader...")
first_batch = next(iter(train_dataloader))
print("Batch keys:", first_batch.keys())
print("Batch shapes:", {k: v.shape for k, v in first_batch.items()})

# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    optimizer.zero_grad()
    for step, batch in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch+1}")):
        valid_batch = {k: v.to(device) for k, v in batch.items() if k in ['input_ids', 'attention_mask', 'labels']}

        try:
            outputs = model(**valid_batch)
            loss = outputs.loss / gradient_accumulation_steps
            if not loss.requires_grad:
                print(f"Error at step {step+1}: Loss does not require grad")
                raise RuntimeError("Loss does not require grad")
            loss.backward()
            total_loss += loss.item() * gradient_accumulation_steps
        except Exception as e:
            print(f"Error in forward/backward pass at step {step+1}: {e}")
            raise

        if (step + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        if (step + 1) % 100 == 0:
            print(f"Epoch {epoch+1}, Step {step+1}, Loss: {total_loss / (step + 1):.4f}")

    print(f"Epoch {epoch+1} Average Loss: {total_loss / len(train_dataloader):.4f}")

# Save model
model.save_pretrained("/content/drive/MyDrive/gpt2-finetuned")
tokenizer.save_pretrained("/content/drive/MyDrive/gpt2-finetuned")

Extracted 225 dialogue pairs from HOPE dataset
Extracted 58829 dialogue pairs from EmpatheticDialogues dataset
Total dialogue pairs: 60437

Sample data from each source:

HOPE sample:
Prompt: Client: Yeah
Response: Therapist: Okay. Now, when you were working, how did that make you feel when you're working?

EmpatheticDialogues sample:
Prompt: Client: I witnessed a man rob an old woman walking across the street. Boils my friggin' blood!
Response: Therapist: Oh my God that's terrible!  So sad.  I'd be mad too.  What did you do?

CounselChat sample:
Prompt: Client: My ex-boyfriend and I have been back and forth for over a year now. He's in his late 20s, divorced for like five years now with two kids. He has a lot of narcissistic behaviors. He lies and cheats, but I love him. I've tried to date other people, but I always go back to him.
Response: Therapist: There are a lot of pieces to the decision of whether to stay or leave. Can you have open conversations about your concerns? Is he able

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Using prompt_key: prompt, response_key: response
Train dataset columns: ['input_ids', 'attention_mask', 'labels']
Sample train item keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
LoRA trainable parameters:
 - base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight: torch.Size([16, 768]), requires_grad=True
 - base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight: torch.Size([2304, 16]), requires_grad=True
 - base_model.model.transformer.h.0.attn.c_proj.lora_A.default.weight: torch.Size([16, 768]), requires_grad=True
 - base_model.model.transformer.h.0.attn.c_proj.lora_B.default.weight: torch.Size([768, 16]), requires_grad=True
 - base_model.model.transformer.h.0.mlp.c_proj.lora_A.default.weight: torch.Size([16, 3072]), requires_grad=True
 - base_model.model.transformer.h.0.mlp.c_proj.lora_B.default.weight: torch.Size([768, 16]), requires_grad=True
 - bas

Epoch 1:   1%|          | 102/10000 [00:04<10:18, 15.99it/s]

Epoch 1, Step 100, Loss: nan


Epoch 1:   2%|▏         | 202/10000 [00:09<06:57, 23.46it/s]

Epoch 1, Step 200, Loss: nan


Epoch 1:   3%|▎         | 304/10000 [00:14<06:58, 23.16it/s]

Epoch 1, Step 300, Loss: nan


Epoch 1:   4%|▍         | 403/10000 [00:19<09:23, 17.03it/s]

Epoch 1, Step 400, Loss: nan


Epoch 1:   5%|▌         | 502/10000 [00:23<06:21, 24.88it/s]

Epoch 1, Step 500, Loss: nan


Epoch 1:   6%|▌         | 604/10000 [00:28<06:45, 23.15it/s]

Epoch 1, Step 600, Loss: nan


Epoch 1:   7%|▋         | 704/10000 [00:33<08:13, 18.83it/s]

Epoch 1, Step 700, Loss: nan


Epoch 1:   8%|▊         | 803/10000 [00:37<06:31, 23.48it/s]

Epoch 1, Step 800, Loss: nan


Epoch 1:   9%|▉         | 902/10000 [00:41<06:35, 22.98it/s]

Epoch 1, Step 900, Loss: nan


Epoch 1:  10%|█         | 1004/10000 [00:47<06:49, 21.98it/s]

Epoch 1, Step 1000, Loss: nan


Epoch 1:  11%|█         | 1103/10000 [00:51<06:18, 23.50it/s]

Epoch 1, Step 1100, Loss: nan


Epoch 1:  12%|█▏        | 1202/10000 [00:55<05:58, 24.55it/s]

Epoch 1, Step 1200, Loss: nan


Epoch 1:  13%|█▎        | 1305/10000 [01:00<06:03, 23.93it/s]

Epoch 1, Step 1300, Loss: nan


Epoch 1:  14%|█▍        | 1404/10000 [01:05<06:02, 23.74it/s]

Epoch 1, Step 1400, Loss: nan


Epoch 1:  15%|█▌        | 1503/10000 [01:09<06:56, 20.42it/s]

Epoch 1, Step 1500, Loss: nan


Epoch 1:  16%|█▌        | 1604/10000 [01:14<06:06, 22.92it/s]

Epoch 1, Step 1600, Loss: nan


Epoch 1:  17%|█▋        | 1703/10000 [01:18<05:42, 24.22it/s]

Epoch 1, Step 1700, Loss: nan


Epoch 1:  18%|█▊        | 1802/10000 [01:23<08:39, 15.78it/s]

Epoch 1, Step 1800, Loss: nan


Epoch 1:  19%|█▉        | 1904/10000 [01:28<05:59, 22.52it/s]

Epoch 1, Step 1900, Loss: nan


Epoch 1:  20%|██        | 2003/10000 [01:32<05:37, 23.67it/s]

Epoch 1, Step 2000, Loss: nan


Epoch 1:  21%|██        | 2103/10000 [01:37<07:22, 17.85it/s]

Epoch 1, Step 2100, Loss: nan


Epoch 1:  22%|██▏       | 2202/10000 [01:42<05:40, 22.87it/s]

Epoch 1, Step 2200, Loss: nan


Epoch 1:  23%|██▎       | 2304/10000 [01:46<05:32, 23.16it/s]

Epoch 1, Step 2300, Loss: nan


Epoch 1:  24%|██▍       | 2402/10000 [01:51<06:46, 18.67it/s]

Epoch 1, Step 2400, Loss: nan


Epoch 1:  25%|██▌       | 2504/10000 [01:56<05:16, 23.65it/s]

Epoch 1, Step 2500, Loss: nan


Epoch 1:  26%|██▌       | 2603/10000 [02:00<05:07, 24.02it/s]

Epoch 1, Step 2600, Loss: nan


Epoch 1:  27%|██▋       | 2704/10000 [02:05<05:19, 22.81it/s]

Epoch 1, Step 2700, Loss: nan


Epoch 1:  28%|██▊       | 2803/10000 [02:09<05:04, 23.66it/s]

Epoch 1, Step 2800, Loss: nan


Epoch 1:  29%|██▉       | 2902/10000 [02:14<04:59, 23.71it/s]

Epoch 1, Step 2900, Loss: nan


Epoch 1:  30%|███       | 3004/10000 [02:19<05:02, 23.12it/s]

Epoch 1, Step 3000, Loss: nan


Epoch 1:  31%|███       | 3103/10000 [02:23<05:09, 22.27it/s]

Epoch 1, Step 3100, Loss: nan


Epoch 1:  32%|███▏      | 3202/10000 [02:28<06:07, 18.51it/s]

Epoch 1, Step 3200, Loss: nan


Epoch 1:  33%|███▎      | 3305/10000 [02:33<04:41, 23.80it/s]

Epoch 1, Step 3300, Loss: nan


Epoch 1:  34%|███▍      | 3404/10000 [02:37<04:43, 23.25it/s]

Epoch 1, Step 3400, Loss: nan


Epoch 1:  35%|███▌      | 3502/10000 [02:42<06:20, 17.06it/s]

Epoch 1, Step 3500, Loss: nan


Epoch 1:  36%|███▌      | 3603/10000 [02:47<04:31, 23.59it/s]

Epoch 1, Step 3600, Loss: nan


Epoch 1:  37%|███▋      | 3705/10000 [02:51<04:24, 23.76it/s]

Epoch 1, Step 3700, Loss: nan


Epoch 1:  38%|███▊      | 3802/10000 [02:56<06:14, 16.56it/s]

Epoch 1, Step 3800, Loss: nan


Epoch 1:  39%|███▉      | 3902/10000 [03:00<04:23, 23.14it/s]

Epoch 1, Step 3900, Loss: nan


Epoch 1:  40%|████      | 4004/10000 [03:05<04:23, 22.75it/s]

Epoch 1, Step 4000, Loss: nan


Epoch 1:  41%|████      | 4103/10000 [03:10<04:42, 20.85it/s]

Epoch 1, Step 4100, Loss: nan


Epoch 1:  42%|████▏     | 4205/10000 [03:14<04:06, 23.51it/s]

Epoch 1, Step 4200, Loss: nan


Epoch 1:  43%|████▎     | 4304/10000 [03:19<04:09, 22.83it/s]

Epoch 1, Step 4300, Loss: nan


Epoch 1:  44%|████▍     | 4404/10000 [03:24<04:10, 22.31it/s]

Epoch 1, Step 4400, Loss: nan


Epoch 1:  45%|████▌     | 4503/10000 [03:28<03:50, 23.85it/s]

Epoch 1, Step 4500, Loss: nan


Epoch 1:  46%|████▌     | 4605/10000 [03:33<03:47, 23.67it/s]

Epoch 1, Step 4600, Loss: nan


Epoch 1:  47%|████▋     | 4705/10000 [03:38<03:42, 23.84it/s]

Epoch 1, Step 4700, Loss: nan


Epoch 1:  48%|████▊     | 4804/10000 [03:42<03:40, 23.56it/s]

Epoch 1, Step 4800, Loss: nan


Epoch 1:  49%|████▉     | 4903/10000 [03:47<04:45, 17.82it/s]

Epoch 1, Step 4900, Loss: nan


Epoch 1:  50%|█████     | 5004/10000 [03:52<03:29, 23.85it/s]

Epoch 1, Step 5000, Loss: nan


Epoch 1:  51%|█████     | 5103/10000 [03:56<03:23, 24.04it/s]

Epoch 1, Step 5100, Loss: nan


Epoch 1:  52%|█████▏    | 5202/10000 [04:01<04:38, 17.23it/s]

Epoch 1, Step 5200, Loss: nan


Epoch 1:  53%|█████▎    | 5304/10000 [04:06<03:27, 22.68it/s]

Epoch 1, Step 5300, Loss: nan


Epoch 1:  54%|█████▍    | 5403/10000 [04:10<03:11, 23.95it/s]

Epoch 1, Step 5400, Loss: nan


Epoch 1:  55%|█████▌    | 5502/10000 [04:15<05:00, 14.97it/s]

Epoch 1, Step 5500, Loss: nan


Epoch 1:  56%|█████▌    | 5602/10000 [04:19<03:08, 23.38it/s]

Epoch 1, Step 5600, Loss: nan


Epoch 1:  57%|█████▋    | 5704/10000 [04:24<03:08, 22.80it/s]

Epoch 1, Step 5700, Loss: nan


Epoch 1:  58%|█████▊    | 5805/10000 [04:29<02:59, 23.34it/s]

Epoch 1, Step 5800, Loss: nan


Epoch 1:  59%|█████▉    | 5904/10000 [04:33<02:54, 23.42it/s]

Epoch 1, Step 5900, Loss: nan


Epoch 1:  60%|██████    | 6003/10000 [04:37<02:49, 23.57it/s]

Epoch 1, Step 6000, Loss: nan


Epoch 1:  61%|██████    | 6104/10000 [04:43<02:50, 22.90it/s]

Epoch 1, Step 6100, Loss: nan


Epoch 1:  62%|██████▏   | 6203/10000 [04:47<02:44, 23.07it/s]

Epoch 1, Step 6200, Loss: nan


Epoch 1:  63%|██████▎   | 6302/10000 [04:51<02:41, 22.88it/s]

Epoch 1, Step 6300, Loss: nan


Epoch 1:  64%|██████▍   | 6402/10000 [04:56<02:39, 22.58it/s]

Epoch 1, Step 6400, Loss: nan


Epoch 1:  65%|██████▌   | 6504/10000 [05:01<02:33, 22.83it/s]

Epoch 1, Step 6500, Loss: nan


Epoch 1:  66%|██████▌   | 6603/10000 [05:06<03:20, 16.91it/s]

Epoch 1, Step 6600, Loss: nan


Epoch 1:  67%|██████▋   | 6704/10000 [05:11<02:19, 23.70it/s]

Epoch 1, Step 6700, Loss: nan


Epoch 1:  68%|██████▊   | 6803/10000 [05:15<02:15, 23.52it/s]

Epoch 1, Step 6800, Loss: nan


Epoch 1:  69%|██████▉   | 6902/10000 [05:20<02:55, 17.66it/s]

Epoch 1, Step 6900, Loss: nan


Epoch 1:  70%|███████   | 7003/10000 [05:24<02:07, 23.53it/s]

Epoch 1, Step 7000, Loss: nan


Epoch 1:  71%|███████   | 7102/10000 [05:29<01:57, 24.64it/s]

Epoch 1, Step 7100, Loss: nan


Epoch 1:  72%|███████▏  | 7204/10000 [05:34<02:39, 17.58it/s]

Epoch 1, Step 7200, Loss: nan


Epoch 1:  73%|███████▎  | 7303/10000 [05:38<01:57, 23.04it/s]

Epoch 1, Step 7300, Loss: nan


Epoch 1:  74%|███████▍  | 7402/10000 [05:42<01:57, 22.10it/s]

Epoch 1, Step 7400, Loss: nan


Epoch 1:  75%|███████▌  | 7504/10000 [05:48<01:52, 22.14it/s]

Epoch 1, Step 7500, Loss: nan


Epoch 1:  76%|███████▌  | 7603/10000 [05:52<01:45, 22.68it/s]

Epoch 1, Step 7600, Loss: nan


Epoch 1:  77%|███████▋  | 7702/10000 [05:56<01:38, 23.38it/s]

Epoch 1, Step 7700, Loss: nan


Epoch 1:  78%|███████▊  | 7803/10000 [06:01<01:31, 23.99it/s]

Epoch 1, Step 7800, Loss: nan


Epoch 1:  79%|███████▉  | 7902/10000 [06:06<01:28, 23.76it/s]

Epoch 1, Step 7900, Loss: nan


Epoch 1:  80%|████████  | 8001/10000 [06:10<01:25, 23.42it/s]

Epoch 1, Step 8000, Loss: nan


Epoch 1:  81%|████████  | 8104/10000 [06:15<01:21, 23.25it/s]

Epoch 1, Step 8100, Loss: nan


Epoch 1:  82%|████████▏ | 8203/10000 [06:19<01:14, 24.21it/s]

Epoch 1, Step 8200, Loss: nan


Epoch 1:  83%|████████▎ | 8302/10000 [06:24<01:46, 15.99it/s]

Epoch 1, Step 8300, Loss: nan


Epoch 1:  84%|████████▍ | 8403/10000 [06:29<01:07, 23.58it/s]

Epoch 1, Step 8400, Loss: nan


Epoch 1:  85%|████████▌ | 8502/10000 [06:33<01:02, 23.80it/s]

Epoch 1, Step 8500, Loss: nan


Epoch 1:  86%|████████▌ | 8603/10000 [06:38<01:19, 17.58it/s]

Epoch 1, Step 8600, Loss: nan


Epoch 1:  87%|████████▋ | 8704/10000 [06:43<00:55, 23.15it/s]

Epoch 1, Step 8700, Loss: nan


Epoch 1:  88%|████████▊ | 8803/10000 [06:47<00:49, 24.22it/s]

Epoch 1, Step 8800, Loss: nan


Epoch 1:  89%|████████▉ | 8902/10000 [06:52<01:15, 14.52it/s]

Epoch 1, Step 8900, Loss: nan


Epoch 1:  90%|█████████ | 9003/10000 [06:56<00:42, 23.44it/s]

Epoch 1, Step 9000, Loss: nan


Epoch 1:  91%|█████████ | 9102/10000 [07:01<00:38, 23.41it/s]

Epoch 1, Step 9100, Loss: nan


Epoch 1:  92%|█████████▏| 9202/10000 [07:06<00:37, 21.53it/s]

Epoch 1, Step 9200, Loss: nan


Epoch 1:  93%|█████████▎| 9304/10000 [07:10<00:29, 23.52it/s]

Epoch 1, Step 9300, Loss: nan


Epoch 1:  94%|█████████▍| 9403/10000 [07:14<00:25, 23.61it/s]

Epoch 1, Step 9400, Loss: nan


Epoch 1:  95%|█████████▌| 9502/10000 [07:20<00:21, 23.63it/s]

Epoch 1, Step 9500, Loss: nan


Epoch 1:  96%|█████████▌| 9604/10000 [07:24<00:16, 23.98it/s]

Epoch 1, Step 9600, Loss: nan


Epoch 1:  97%|█████████▋| 9703/10000 [07:28<00:12, 23.79it/s]

Epoch 1, Step 9700, Loss: nan


Epoch 1:  98%|█████████▊| 9804/10000 [07:33<00:08, 23.29it/s]

Epoch 1, Step 9800, Loss: nan


Epoch 1:  99%|█████████▉| 9903/10000 [07:38<00:04, 23.36it/s]

Epoch 1, Step 9900, Loss: nan


Epoch 1: 100%|██████████| 10000/10000 [07:42<00:00, 21.64it/s]


Epoch 1, Step 10000, Loss: nan
Epoch 1 Average Loss: nan


Epoch 2:   1%|          | 105/10000 [00:05<06:44, 24.45it/s]

Epoch 2, Step 100, Loss: nan


Epoch 2:   2%|▏         | 204/10000 [00:09<06:50, 23.86it/s]

Epoch 2, Step 200, Loss: nan


Epoch 2:   3%|▎         | 302/10000 [00:14<09:44, 16.60it/s]

Epoch 2, Step 300, Loss: nan


Epoch 2:   4%|▍         | 403/10000 [00:18<06:52, 23.25it/s]

Epoch 2, Step 400, Loss: nan


Epoch 2:   5%|▌         | 502/10000 [00:23<06:34, 24.06it/s]

Epoch 2, Step 500, Loss: nan


Epoch 2:   6%|▌         | 602/10000 [00:28<08:45, 17.87it/s]

Epoch 2, Step 600, Loss: nan


Epoch 2:   7%|▋         | 704/10000 [00:32<06:49, 22.72it/s]

Epoch 2, Step 700, Loss: nan


Epoch 2:   8%|▊         | 803/10000 [00:36<06:34, 23.29it/s]

Epoch 2, Step 800, Loss: nan


Epoch 2:   9%|▉         | 902/10000 [00:41<09:46, 15.51it/s]

Epoch 2, Step 900, Loss: nan


Epoch 2:  10%|█         | 1002/10000 [00:46<06:39, 22.50it/s]

Epoch 2, Step 1000, Loss: nan


Epoch 2:  11%|█         | 1104/10000 [00:50<06:28, 22.87it/s]

Epoch 2, Step 1100, Loss: nan


Epoch 2:  12%|█▏        | 1204/10000 [00:55<06:53, 21.29it/s]

Epoch 2, Step 1200, Loss: nan


Epoch 2:  13%|█▎        | 1303/10000 [01:00<06:01, 24.03it/s]

Epoch 2, Step 1300, Loss: nan


Epoch 2:  14%|█▍        | 1402/10000 [01:04<06:05, 23.53it/s]

Epoch 2, Step 1400, Loss: nan


Epoch 2:  15%|█▌        | 1504/10000 [01:09<06:14, 22.66it/s]

Epoch 2, Step 1500, Loss: nan


Epoch 2:  16%|█▌        | 1603/10000 [01:13<05:58, 23.41it/s]

Epoch 2, Step 1600, Loss: nan


Epoch 2:  17%|█▋        | 1705/10000 [01:18<05:37, 24.55it/s]

Epoch 2, Step 1700, Loss: nan


Epoch 2:  18%|█▊        | 1804/10000 [01:23<05:40, 24.06it/s]

Epoch 2, Step 1800, Loss: nan


Epoch 2:  19%|█▉        | 1903/10000 [01:27<05:33, 24.30it/s]

Epoch 2, Step 1900, Loss: nan


Epoch 2:  20%|██        | 2003/10000 [01:32<07:27, 17.88it/s]

Epoch 2, Step 2000, Loss: nan


Epoch 2:  21%|██        | 2103/10000 [01:37<05:17, 24.86it/s]

Epoch 2, Step 2100, Loss: nan


Epoch 2:  22%|██▏       | 2202/10000 [01:41<05:30, 23.58it/s]

Epoch 2, Step 2200, Loss: nan


Epoch 2:  23%|██▎       | 2302/10000 [01:46<08:08, 15.76it/s]

Epoch 2, Step 2300, Loss: nan


Epoch 2:  24%|██▍       | 2403/10000 [01:51<05:15, 24.11it/s]

Epoch 2, Step 2400, Loss: nan


Epoch 2:  25%|██▌       | 2502/10000 [01:55<05:29, 22.76it/s]

Epoch 2, Step 2500, Loss: nan


Epoch 2:  26%|██▌       | 2602/10000 [02:00<07:29, 16.47it/s]

Epoch 2, Step 2600, Loss: nan


Epoch 2:  27%|██▋       | 2702/10000 [02:04<04:57, 24.53it/s]

Epoch 2, Step 2700, Loss: nan


Epoch 2:  28%|██▊       | 2804/10000 [02:09<05:13, 22.98it/s]

Epoch 2, Step 2800, Loss: nan


Epoch 2:  29%|██▉       | 2904/10000 [02:14<05:55, 19.93it/s]

Epoch 2, Step 2900, Loss: nan


Epoch 2:  30%|███       | 3003/10000 [02:18<04:49, 24.14it/s]

Epoch 2, Step 3000, Loss: nan


Epoch 2:  31%|███       | 3102/10000 [02:22<04:49, 23.84it/s]

Epoch 2, Step 3100, Loss: nan


Epoch 2:  32%|███▏      | 3205/10000 [02:28<04:46, 23.68it/s]

Epoch 2, Step 3200, Loss: nan


Epoch 2:  33%|███▎      | 3304/10000 [02:32<04:55, 22.63it/s]

Epoch 2, Step 3300, Loss: nan


Epoch 2:  34%|███▍      | 3403/10000 [02:36<04:52, 22.55it/s]

Epoch 2, Step 3400, Loss: nan


Epoch 2:  35%|███▌      | 3504/10000 [02:42<04:43, 22.93it/s]

Epoch 2, Step 3500, Loss: nan


Epoch 2:  36%|███▌      | 3603/10000 [02:46<04:52, 21.89it/s]

Epoch 2, Step 3600, Loss: nan


Epoch 2:  37%|███▋      | 3702/10000 [02:50<05:47, 18.10it/s]

Epoch 2, Step 3700, Loss: nan


Epoch 2:  38%|███▊      | 3803/10000 [02:56<04:21, 23.71it/s]

Epoch 2, Step 3800, Loss: nan


Epoch 2:  39%|███▉      | 3902/10000 [03:00<04:20, 23.38it/s]

Epoch 2, Step 3900, Loss: nan


Epoch 2:  40%|████      | 4003/10000 [03:05<06:11, 16.16it/s]

Epoch 2, Step 4000, Loss: nan


Epoch 2:  41%|████      | 4103/10000 [03:09<04:10, 23.55it/s]

Epoch 2, Step 4100, Loss: nan


Epoch 2:  42%|████▏     | 4202/10000 [03:14<04:06, 23.54it/s]

Epoch 2, Step 4200, Loss: nan


Epoch 2:  43%|████▎     | 4303/10000 [03:19<05:48, 16.35it/s]

Epoch 2, Step 4300, Loss: nan


Epoch 2:  44%|████▍     | 4403/10000 [03:23<04:01, 23.17it/s]

Epoch 2, Step 4400, Loss: nan


Epoch 2:  45%|████▌     | 4505/10000 [03:28<03:56, 23.22it/s]

Epoch 2, Step 4500, Loss: nan


Epoch 2:  46%|████▌     | 4602/10000 [03:33<04:25, 20.35it/s]

Epoch 2, Step 4600, Loss: nan


Epoch 2:  47%|████▋     | 4704/10000 [03:37<03:51, 22.91it/s]

Epoch 2, Step 4700, Loss: nan


Epoch 2:  48%|████▊     | 4803/10000 [03:41<03:40, 23.52it/s]

Epoch 2, Step 4800, Loss: nan


Epoch 2:  49%|████▉     | 4904/10000 [03:47<03:46, 22.47it/s]

Epoch 2, Step 4900, Loss: nan


Epoch 2:  50%|█████     | 5003/10000 [03:51<03:30, 23.79it/s]

Epoch 2, Step 5000, Loss: nan


Epoch 2:  51%|█████     | 5102/10000 [03:55<03:27, 23.57it/s]

Epoch 2, Step 5100, Loss: nan


Epoch 2:  52%|█████▏    | 5203/10000 [04:00<03:27, 23.16it/s]

Epoch 2, Step 5200, Loss: nan


Epoch 2:  53%|█████▎    | 5302/10000 [04:05<03:18, 23.61it/s]

Epoch 2, Step 5300, Loss: nan


Epoch 2:  54%|█████▍    | 5403/10000 [04:09<04:24, 17.41it/s]

Epoch 2, Step 5400, Loss: nan


Epoch 2:  55%|█████▌    | 5503/10000 [04:14<03:13, 23.24it/s]

Epoch 2, Step 5500, Loss: nan


Epoch 2:  56%|█████▌    | 5602/10000 [04:19<03:11, 22.96it/s]

Epoch 2, Step 5600, Loss: nan


Epoch 2:  57%|█████▋    | 5702/10000 [04:24<04:01, 17.78it/s]

Epoch 2, Step 5700, Loss: nan


Epoch 2:  58%|█████▊    | 5802/10000 [04:28<03:04, 22.78it/s]

Epoch 2, Step 5800, Loss: nan


Epoch 2:  59%|█████▉    | 5904/10000 [04:33<02:51, 23.82it/s]

Epoch 2, Step 5900, Loss: nan


Epoch 2:  60%|██████    | 6003/10000 [04:38<03:57, 16.86it/s]

Epoch 2, Step 6000, Loss: nan


Epoch 2:  61%|██████    | 6102/10000 [04:42<02:47, 23.33it/s]

Epoch 2, Step 6100, Loss: nan


Epoch 2:  62%|██████▏   | 6204/10000 [04:47<02:53, 21.86it/s]

Epoch 2, Step 6200, Loss: nan


Epoch 2:  63%|██████▎   | 6302/10000 [04:52<02:47, 22.10it/s]

Epoch 2, Step 6300, Loss: nan


Epoch 2:  64%|██████▍   | 6404/10000 [04:56<02:36, 22.94it/s]

Epoch 2, Step 6400, Loss: nan


Epoch 2:  65%|██████▌   | 6503/10000 [05:00<02:32, 22.88it/s]

Epoch 2, Step 6500, Loss: nan


Epoch 2:  66%|██████▌   | 6604/10000 [05:06<02:25, 23.32it/s]

Epoch 2, Step 6600, Loss: nan


Epoch 2:  67%|██████▋   | 6703/10000 [05:10<02:23, 22.90it/s]

Epoch 2, Step 6700, Loss: nan


Epoch 2:  68%|██████▊   | 6803/10000 [05:15<02:50, 18.71it/s]

Epoch 2, Step 6800, Loss: nan


Epoch 2:  69%|██████▉   | 6904/10000 [05:20<02:13, 23.24it/s]

Epoch 2, Step 6900, Loss: nan


Epoch 2:  70%|███████   | 7003/10000 [05:24<02:07, 23.52it/s]

Epoch 2, Step 7000, Loss: nan


Epoch 2:  71%|███████   | 7102/10000 [05:29<02:48, 17.21it/s]

Epoch 2, Step 7100, Loss: nan


Epoch 2:  72%|███████▏  | 7202/10000 [05:33<01:56, 23.92it/s]

Epoch 2, Step 7200, Loss: nan


Epoch 2:  73%|███████▎  | 7304/10000 [05:38<01:57, 22.94it/s]

Epoch 2, Step 7300, Loss: nan


Epoch 2:  74%|███████▍  | 7402/10000 [05:43<02:47, 15.47it/s]

Epoch 2, Step 7400, Loss: nan


Epoch 2:  75%|███████▌  | 7504/10000 [05:47<01:50, 22.56it/s]

Epoch 2, Step 7500, Loss: nan


Epoch 2:  76%|███████▌  | 7603/10000 [05:51<01:41, 23.59it/s]

Epoch 2, Step 7600, Loss: nan


Epoch 2:  77%|███████▋  | 7703/10000 [05:57<01:48, 21.14it/s]

Epoch 2, Step 7700, Loss: nan


Epoch 2:  78%|███████▊  | 7802/10000 [06:01<01:34, 23.25it/s]

Epoch 2, Step 7800, Loss: nan


Epoch 2:  79%|███████▉  | 7904/10000 [06:05<01:30, 23.05it/s]

Epoch 2, Step 7900, Loss: nan


Epoch 2:  80%|████████  | 8002/10000 [06:11<01:25, 23.39it/s]

Epoch 2, Step 8000, Loss: nan


Epoch 2:  81%|████████  | 8104/10000 [06:15<01:21, 23.13it/s]

Epoch 2, Step 8100, Loss: nan


Epoch 2:  82%|████████▏ | 8203/10000 [06:19<01:13, 24.44it/s]

Epoch 2, Step 8200, Loss: nan


Epoch 2:  83%|████████▎ | 8305/10000 [06:25<01:10, 23.97it/s]

Epoch 2, Step 8300, Loss: nan


Epoch 2:  84%|████████▍ | 8404/10000 [06:29<01:10, 22.79it/s]

Epoch 2, Step 8400, Loss: nan


Epoch 2:  85%|████████▌ | 8502/10000 [06:33<01:23, 17.93it/s]

Epoch 2, Step 8500, Loss: nan


Epoch 2:  86%|████████▌ | 8603/10000 [06:38<01:01, 22.90it/s]

Epoch 2, Step 8600, Loss: nan


Epoch 2:  87%|████████▋ | 8702/10000 [06:43<00:56, 22.96it/s]

Epoch 2, Step 8700, Loss: nan


Epoch 2:  88%|████████▊ | 8803/10000 [06:47<01:13, 16.33it/s]

Epoch 2, Step 8800, Loss: nan


Epoch 2:  89%|████████▉ | 8903/10000 [06:52<00:45, 24.21it/s]

Epoch 2, Step 8900, Loss: nan


Epoch 2:  90%|█████████ | 9002/10000 [06:56<00:43, 23.18it/s]

Epoch 2, Step 9000, Loss: nan


Epoch 2:  91%|█████████ | 9102/10000 [07:01<00:58, 15.43it/s]

Epoch 2, Step 9100, Loss: nan


Epoch 2:  92%|█████████▏| 9203/10000 [07:06<00:34, 23.05it/s]

Epoch 2, Step 9200, Loss: nan


Epoch 2:  93%|█████████▎| 9305/10000 [07:10<00:29, 23.92it/s]

Epoch 2, Step 9300, Loss: nan


Epoch 2:  94%|█████████▍| 9403/10000 [07:16<00:27, 21.56it/s]

Epoch 2, Step 9400, Loss: nan


Epoch 2:  95%|█████████▌| 9502/10000 [07:20<00:21, 23.19it/s]

Epoch 2, Step 9500, Loss: nan


Epoch 2:  96%|█████████▌| 9604/10000 [07:24<00:17, 22.26it/s]

Epoch 2, Step 9600, Loss: nan


Epoch 2:  97%|█████████▋| 9703/10000 [07:30<00:13, 22.58it/s]

Epoch 2, Step 9700, Loss: nan


Epoch 2:  98%|█████████▊| 9805/10000 [07:34<00:08, 23.67it/s]

Epoch 2, Step 9800, Loss: nan


Epoch 2:  99%|█████████▉| 9904/10000 [07:38<00:04, 23.21it/s]

Epoch 2, Step 9900, Loss: nan


Epoch 2: 100%|██████████| 10000/10000 [07:43<00:00, 21.56it/s]


Epoch 2, Step 10000, Loss: nan
Epoch 2 Average Loss: nan


Epoch 3:   1%|          | 102/10000 [00:04<07:13, 22.85it/s]

Epoch 3, Step 100, Loss: nan


Epoch 3:   2%|▏         | 203/10000 [00:09<10:12, 16.00it/s]

Epoch 3, Step 200, Loss: nan


Epoch 3:   3%|▎         | 304/10000 [00:14<07:14, 22.34it/s]

Epoch 3, Step 300, Loss: nan


Epoch 3:   4%|▍         | 403/10000 [00:18<06:59, 22.89it/s]

Epoch 3, Step 400, Loss: nan


Epoch 3:   5%|▌         | 503/10000 [00:23<08:51, 17.86it/s]

Epoch 3, Step 500, Loss: nan


Epoch 3:   6%|▌         | 604/10000 [00:28<06:39, 23.52it/s]

Epoch 3, Step 600, Loss: nan


Epoch 3:   7%|▋         | 703/10000 [00:32<06:48, 22.77it/s]

Epoch 3, Step 700, Loss: nan


Epoch 3:   8%|▊         | 802/10000 [00:37<08:53, 17.24it/s]

Epoch 3, Step 800, Loss: nan


Epoch 3:   9%|▉         | 904/10000 [00:42<06:36, 22.94it/s]

Epoch 3, Step 900, Loss: nan


Epoch 3:  10%|█         | 1003/10000 [00:46<06:34, 22.81it/s]

Epoch 3, Step 1000, Loss: nan


Epoch 3:  11%|█         | 1103/10000 [00:51<06:43, 22.03it/s]

Epoch 3, Step 1100, Loss: nan


Epoch 3:  12%|█▏        | 1205/10000 [00:56<06:16, 23.37it/s]

Epoch 3, Step 1200, Loss: nan


Epoch 3:  13%|█▎        | 1304/10000 [01:00<06:11, 23.41it/s]

Epoch 3, Step 1300, Loss: nan


Epoch 3:  14%|█▍        | 1402/10000 [01:05<06:06, 23.48it/s]

Epoch 3, Step 1400, Loss: nan


Epoch 3:  15%|█▌        | 1504/10000 [01:09<06:04, 23.33it/s]

Epoch 3, Step 1500, Loss: nan


Epoch 3:  16%|█▌        | 1603/10000 [01:14<07:07, 19.66it/s]

Epoch 3, Step 1600, Loss: nan


Epoch 3:  17%|█▋        | 1702/10000 [01:19<05:53, 23.50it/s]

Epoch 3, Step 1700, Loss: nan


Epoch 3:  18%|█▊        | 1804/10000 [01:23<05:47, 23.62it/s]

Epoch 3, Step 1800, Loss: nan


Epoch 3:  19%|█▉        | 1902/10000 [01:28<08:28, 15.93it/s]

Epoch 3, Step 1900, Loss: nan


Epoch 3:  20%|██        | 2002/10000 [01:33<05:41, 23.45it/s]

Epoch 3, Step 2000, Loss: nan


Epoch 3:  21%|██        | 2104/10000 [01:37<05:59, 21.96it/s]

Epoch 3, Step 2100, Loss: nan


Epoch 3:  22%|██▏       | 2203/10000 [01:42<08:21, 15.56it/s]

Epoch 3, Step 2200, Loss: nan


Epoch 3:  23%|██▎       | 2305/10000 [01:47<05:22, 23.84it/s]

Epoch 3, Step 2300, Loss: nan


Epoch 3:  24%|██▍       | 2404/10000 [01:51<05:35, 22.62it/s]

Epoch 3, Step 2400, Loss: nan


Epoch 3:  25%|██▌       | 2503/10000 [01:56<05:51, 21.30it/s]

Epoch 3, Step 2500, Loss: nan


Epoch 3:  26%|██▌       | 2602/10000 [02:00<05:07, 24.09it/s]

Epoch 3, Step 2600, Loss: nan


Epoch 3:  27%|██▋       | 2704/10000 [02:05<05:16, 23.05it/s]

Epoch 3, Step 2700, Loss: nan


Epoch 3:  28%|██▊       | 2803/10000 [02:10<05:06, 23.46it/s]

Epoch 3, Step 2800, Loss: nan


Epoch 3:  29%|██▉       | 2902/10000 [02:14<05:04, 23.32it/s]

Epoch 3, Step 2900, Loss: nan


Epoch 3:  30%|███       | 3004/10000 [02:19<05:06, 22.79it/s]

Epoch 3, Step 3000, Loss: nan


Epoch 3:  31%|███       | 3102/10000 [02:24<05:03, 22.72it/s]

Epoch 3, Step 3100, Loss: nan


Epoch 3:  32%|███▏      | 3204/10000 [02:28<04:58, 22.80it/s]

Epoch 3, Step 3200, Loss: nan


Epoch 3:  33%|███▎      | 3303/10000 [02:33<06:44, 16.57it/s]

Epoch 3, Step 3300, Loss: nan


Epoch 3:  34%|███▍      | 3403/10000 [02:38<04:41, 23.47it/s]

Epoch 3, Step 3400, Loss: nan


Epoch 3:  35%|███▌      | 3505/10000 [02:42<04:28, 24.18it/s]

Epoch 3, Step 3500, Loss: nan


Epoch 3:  36%|███▌      | 3603/10000 [02:47<06:07, 17.43it/s]

Epoch 3, Step 3600, Loss: nan


Epoch 3:  37%|███▋      | 3702/10000 [02:52<04:17, 24.49it/s]

Epoch 3, Step 3700, Loss: nan


Epoch 3:  38%|███▊      | 3804/10000 [02:56<04:28, 23.08it/s]

Epoch 3, Step 3800, Loss: nan


Epoch 3:  39%|███▉      | 3902/10000 [03:01<06:27, 15.75it/s]

Epoch 3, Step 3900, Loss: nan


Epoch 3:  40%|████      | 4004/10000 [03:06<04:21, 22.89it/s]

Epoch 3, Step 4000, Loss: nan


Epoch 3:  41%|████      | 4103/10000 [03:10<04:19, 22.75it/s]

Epoch 3, Step 4100, Loss: nan


Epoch 3:  42%|████▏     | 4204/10000 [03:15<04:20, 22.29it/s]

Epoch 3, Step 4200, Loss: nan


Epoch 3:  43%|████▎     | 4303/10000 [03:20<04:03, 23.39it/s]

Epoch 3, Step 4300, Loss: nan


Epoch 3:  44%|████▍     | 4402/10000 [03:24<04:20, 21.46it/s]

Epoch 3, Step 4400, Loss: nan


Epoch 3:  45%|████▌     | 4504/10000 [03:30<04:07, 22.21it/s]

Epoch 3, Step 4500, Loss: nan


Epoch 3:  46%|████▌     | 4603/10000 [03:34<03:58, 22.67it/s]

Epoch 3, Step 4600, Loss: nan


Epoch 3:  47%|████▋     | 4702/10000 [03:38<04:55, 17.90it/s]

Epoch 3, Step 4700, Loss: nan


Epoch 3:  48%|████▊     | 4803/10000 [03:44<03:40, 23.62it/s]

Epoch 3, Step 4800, Loss: nan


Epoch 3:  49%|████▉     | 4902/10000 [03:48<03:37, 23.49it/s]

Epoch 3, Step 4900, Loss: nan


Epoch 3:  50%|█████     | 5002/10000 [03:53<05:07, 16.24it/s]

Epoch 3, Step 5000, Loss: nan


Epoch 3:  51%|█████     | 5102/10000 [03:57<03:34, 22.80it/s]

Epoch 3, Step 5100, Loss: nan


Epoch 3:  52%|█████▏    | 5204/10000 [04:02<03:25, 23.29it/s]

Epoch 3, Step 5200, Loss: nan


Epoch 3:  53%|█████▎    | 5303/10000 [04:07<05:07, 15.29it/s]

Epoch 3, Step 5300, Loss: nan


Epoch 3:  54%|█████▍    | 5403/10000 [04:11<03:24, 22.50it/s]

Epoch 3, Step 5400, Loss: nan


Epoch 3:  55%|█████▌    | 5502/10000 [04:16<03:17, 22.73it/s]

Epoch 3, Step 5500, Loss: nan


Epoch 3:  56%|█████▌    | 5604/10000 [04:21<03:16, 22.32it/s]

Epoch 3, Step 5600, Loss: nan


Epoch 3:  57%|█████▋    | 5703/10000 [04:25<02:59, 23.99it/s]

Epoch 3, Step 5700, Loss: nan


Epoch 3:  58%|█████▊    | 5802/10000 [04:29<02:59, 23.45it/s]

Epoch 3, Step 5800, Loss: nan


Epoch 3:  59%|█████▉    | 5905/10000 [04:35<02:56, 23.25it/s]

Epoch 3, Step 5900, Loss: nan


Epoch 3:  60%|██████    | 6004/10000 [04:39<02:52, 23.18it/s]

Epoch 3, Step 6000, Loss: nan


Epoch 3:  61%|██████    | 6103/10000 [04:43<02:47, 23.25it/s]

Epoch 3, Step 6100, Loss: nan


Epoch 3:  62%|██████▏   | 6204/10000 [04:49<02:41, 23.54it/s]

Epoch 3, Step 6200, Loss: nan


Epoch 3:  63%|██████▎   | 6303/10000 [04:53<02:35, 23.70it/s]

Epoch 3, Step 6300, Loss: nan


Epoch 3:  64%|██████▍   | 6402/10000 [04:57<03:14, 18.47it/s]

Epoch 3, Step 6400, Loss: nan


Epoch 3:  65%|██████▌   | 6503/10000 [05:02<02:28, 23.50it/s]

Epoch 3, Step 6500, Loss: nan


Epoch 3:  66%|██████▌   | 6602/10000 [05:07<02:22, 23.87it/s]

Epoch 3, Step 6600, Loss: nan


Epoch 3:  67%|██████▋   | 6702/10000 [05:11<03:17, 16.68it/s]

Epoch 3, Step 6700, Loss: nan


Epoch 3:  68%|██████▊   | 6802/10000 [05:16<02:22, 22.43it/s]

Epoch 3, Step 6800, Loss: nan


Epoch 3:  69%|██████▉   | 6904/10000 [05:21<02:14, 23.04it/s]

Epoch 3, Step 6900, Loss: nan


Epoch 3:  70%|███████   | 7002/10000 [05:26<03:17, 15.19it/s]

Epoch 3, Step 7000, Loss: nan


Epoch 3:  71%|███████   | 7104/10000 [05:30<02:06, 22.82it/s]

Epoch 3, Step 7100, Loss: nan


Epoch 3:  72%|███████▏  | 7203/10000 [05:34<01:59, 23.42it/s]

Epoch 3, Step 7200, Loss: nan


Epoch 3:  73%|███████▎  | 7302/10000 [05:40<02:00, 22.35it/s]

Epoch 3, Step 7300, Loss: nan


Epoch 3:  74%|███████▍  | 7404/10000 [05:44<01:57, 22.01it/s]

Epoch 3, Step 7400, Loss: nan


Epoch 3:  75%|███████▌  | 7503/10000 [05:48<01:54, 21.79it/s]

Epoch 3, Step 7500, Loss: nan


Epoch 3:  76%|███████▌  | 7603/10000 [05:54<01:46, 22.56it/s]

Epoch 3, Step 7600, Loss: nan


Epoch 3:  77%|███████▋  | 7702/10000 [05:58<01:40, 22.89it/s]

Epoch 3, Step 7700, Loss: nan


Epoch 3:  78%|███████▊  | 7801/10000 [06:02<01:32, 23.74it/s]

Epoch 3, Step 7800, Loss: nan


Epoch 3:  79%|███████▉  | 7902/10000 [06:07<01:26, 24.17it/s]

Epoch 3, Step 7900, Loss: nan


Epoch 3:  80%|████████  | 8004/10000 [06:12<01:27, 22.81it/s]

Epoch 3, Step 8000, Loss: nan


Epoch 3:  81%|████████  | 8101/10000 [06:16<01:50, 17.18it/s]

Epoch 3, Step 8100, Loss: nan


Epoch 3:  82%|████████▏ | 8203/10000 [06:21<01:16, 23.36it/s]

Epoch 3, Step 8200, Loss: nan


Epoch 3:  83%|████████▎ | 8302/10000 [06:26<01:11, 23.83it/s]

Epoch 3, Step 8300, Loss: nan


Epoch 3:  84%|████████▍ | 8403/10000 [06:31<01:30, 17.58it/s]

Epoch 3, Step 8400, Loss: nan


Epoch 3:  85%|████████▌ | 8503/10000 [06:35<01:06, 22.51it/s]

Epoch 3, Step 8500, Loss: nan


Epoch 3:  86%|████████▌ | 8602/10000 [06:39<01:01, 22.82it/s]

Epoch 3, Step 8600, Loss: nan


Epoch 3:  87%|████████▋ | 8702/10000 [06:45<01:26, 14.95it/s]

Epoch 3, Step 8700, Loss: nan


Epoch 3:  88%|████████▊ | 8803/10000 [06:49<00:52, 22.64it/s]

Epoch 3, Step 8800, Loss: nan


Epoch 3:  89%|████████▉ | 8902/10000 [06:53<00:46, 23.61it/s]

Epoch 3, Step 8900, Loss: nan


Epoch 3:  90%|█████████ | 9002/10000 [06:59<00:44, 22.55it/s]

Epoch 3, Step 9000, Loss: nan


Epoch 3:  91%|█████████ | 9104/10000 [07:03<00:38, 23.30it/s]

Epoch 3, Step 9100, Loss: nan


Epoch 3:  92%|█████████▏| 9203/10000 [07:07<00:33, 23.49it/s]

Epoch 3, Step 9200, Loss: nan


Epoch 3:  93%|█████████▎| 9302/10000 [07:12<00:30, 22.57it/s]

Epoch 3, Step 9300, Loss: nan


Epoch 3:  94%|█████████▍| 9404/10000 [07:17<00:25, 23.66it/s]

Epoch 3, Step 9400, Loss: nan


Epoch 3:  95%|█████████▌| 9503/10000 [07:21<00:26, 18.76it/s]

Epoch 3, Step 9500, Loss: nan


Epoch 3:  96%|█████████▌| 9604/10000 [07:26<00:17, 22.63it/s]

Epoch 3, Step 9600, Loss: nan


Epoch 3:  97%|█████████▋| 9703/10000 [07:31<00:12, 23.36it/s]

Epoch 3, Step 9700, Loss: nan


Epoch 3:  98%|█████████▊| 9802/10000 [07:35<00:12, 16.11it/s]

Epoch 3, Step 9800, Loss: nan


Epoch 3:  99%|█████████▉| 9903/10000 [07:40<00:04, 23.11it/s]

Epoch 3, Step 9900, Loss: nan


Epoch 3: 100%|██████████| 10000/10000 [07:45<00:00, 21.51it/s]


Epoch 3, Step 10000, Loss: nan
Epoch 3 Average Loss: nan




('/content/drive/MyDrive/gpt2-finetuned/tokenizer_config.json',
 '/content/drive/MyDrive/gpt2-finetuned/special_tokens_map.json',
 '/content/drive/MyDrive/gpt2-finetuned/vocab.json',
 '/content/drive/MyDrive/gpt2-finetuned/merges.txt',
 '/content/drive/MyDrive/gpt2-finetuned/added_tokens.json',
 '/content/drive/MyDrive/gpt2-finetuned/tokenizer.json')

In [13]:
class MentalHealthChatbot:
    def __init__(self, model, tokenizer, bert_model, bert_tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.bert_tokenizer = bert_tokenizer
        self.enable_sentiment = True
        self.emotion_labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

    def analyze_sentiment(self, text):
        try:
            inputs = self.bert_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.bert_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            emotion_idx = probs.argmax().item()
            return self.emotion_labels[emotion_idx], probs[0][emotion_idx].item()
        except Exception as e:
            print(f"Error in sentiment analysis: {e}")
            return "unknown", 0.0

    def generate_response(self, user_input, max_length=100):
        try:
            prompt = f"Client: {user_input}"
            if self.enable_sentiment:
                emotion, confidence = self.analyze_sentiment(user_input)
                prompt += f" [Emotion: {emotion}, Confidence: {confidence:.2f}]"
            prompt += " Therapist:"

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=self.tokenizer.pad_token_id
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            therapist_response = response.split("Therapist:")[-1].strip()
            return therapist_response
        except Exception as e:
            print(f"Error generating response: {e}")
            return "I'm sorry, I couldn't process that. Can you try again?"

# Initialize chatbot
chatbot = MentalHealthChatbot(model, tokenizer, bert_model, bert_tokenizer)

In [14]:
%%writefile app.py
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BertTokenizer, AutoModelForSequenceClassification

# Load models
@st.cache_resource
def load_models():
    try:
        model = AutoModelForCausalLM.from_pretrained(
            "/content/drive/MyDrive/gpt2-finetuned",
            torch_dtype=torch.float16,
            device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "/content/drive/MyDrive/gpt2-finetuned"
        )
        tokenizer.pad_token = tokenizer.eos_token
        bert_model = AutoModelForSequenceClassification.from_pretrained(
            "bhadresh-savani/bert-base-uncased-emotion"
        )
        bert_tokenizer = BertTokenizer.from_pretrained(
            "bhadresh-savani/bert-base-uncased-emotion"
        )
        return model, tokenizer, bert_model, bert_tokenizer
    except Exception as e:
        st.error(f"Error loading models: {e}")
        return None, None, None, None

model, tokenizer, bert_model, bert_tokenizer = load_models()

# Chatbot class
class MentalHealthChatbot:
    def __init__(self, model, tokenizer, bert_model, bert_tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.bert_model = bert_model
        self.bert_tokenizer = bert_tokenizer
        self.enable_sentiment = True
        self.emotion_labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

    def analyze_sentiment(self, text):
        try:
            inputs = self.bert_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.bert_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            emotion_idx = probs.argmax().item()
            return self.emotion_labels[emotion_idx], probs[0][emotion_idx].item()
        except Exception as e:
            return "unknown", 0.0

    def generate_response(self, user_input, max_length=100):
        try:
            prompt = f"Client: {user_input}"
            if self.enable_sentiment:
                emotion, confidence = self.analyze_sentiment(user_input)
                prompt += f" [Emotion: {emotion}, Confidence: {confidence:.2f}]"
            prompt += " Therapist:"

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=self.tokenizer.pad_token_id
            )

            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            therapist_response = response.split("Therapist:")[-1].strip()
            return therapist_response
        except Exception as e:
            return "I'm sorry, I couldn't process that. Can you try again?"

# Streamlit UI
st.title("Mental Health Chatbot")
st.write("Talk to our AI therapist trained to provide empathetic responses.")

if model is None or tokenizer is None:
    st.error("Failed to load models. Please check the notebook logs.")
else:
    chatbot = MentalHealthChatbot(model, tokenizer, bert_model, bert_tokenizer)

    # Chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    if user_input := st.chat_input("How are you feeling today?"):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = chatbot.generate_response(user_input)
                st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})

Writing app.py


In [15]:
!streamlit run app.py --server.port 8501


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Local URL: [0m[1mhttp://localhost:8501[0m
[34m  Network URL: [0m[1mhttp://172.28.0.12:8501[0m
[34m  External URL: [0m[1mhttp://34.125.1.175:8501[0m
[0m
[34m  Stopping...[0m
[34m  Stopping...[0m


In [16]:
from pyngrok import ngrok
import subprocess
import time

# Kill any existing ngrok tunnels
ngrok.kill()

# Set ngrok authtoken
ngrok.set_auth_token("2xP1dGhas6RhvIEbzAwaLgKTaQS_4r2bq8TaiQfrmC2KLUigZ")  # Replace with your ngrok token if different

# Start Streamlit server in the background
subprocess.Popen(["streamlit", "run", "app.py", "--server.port", "8501"])

# Wait for Streamlit to start
time.sleep(5)

# Create ngrok tunnel
public_url = ngrok.connect(8501)
print(f"Streamlit app running at: {public_url}")

# Keep the tunnel alive
while True:
    time.sleep(60)

Streamlit app running at: NgrokTunnel: "https://00c1-34-125-1-175.ngrok-free.app" -> "http://localhost:8501"


KeyboardInterrupt: 

In [21]:
prompt = "Client: I feel anxious today."
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to("cuda")
print("Inputs:", inputs)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_return_sequences=1,
    do_sample=True,
    temperature=0.9,
    top_k=50,
    pad_token_id=tokenizer.eos_token_id,
    renormalize_logits=True
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Response: {response}")

Inputs: {'input_ids': tensor([[11792,    25,   314,  1254, 18116,  1909,    13]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

In [18]:
import torch
import nltk
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
import time
from transformers import AutoModelForSequenceClassification, BertTokenizer

# Ensure NLTK resources are downloaded
nltk.download('punkt')
nltk.download('wordnet')

# Define metrics calculation function
def evaluate_chatbot(model, tokenizer, val_dataset, bert_model, bert_tokenizer, device='cuda', max_samples=100):
    model.eval()
    model.to(device)
    bert_model.eval()
    bert_model.to(device)

    # Initialize metrics
    perplexities = []
    bleu_scores = []
    rouge_scores = []
    sentiment_scores = []
    inference_times = []
    distinct_n_scores = []
    emotion_labels = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']

    # ROUGE scorer
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

    # Limit evaluation to a subset to save time
    val_subset = val_dataset.select(range(min(max_samples, len(val_dataset))))

    for item in tqdm(val_subset, desc="Evaluating"):
        prompt = item['prompt']
        reference = item['response'].replace("Therapist: ", "").strip()

        # Generate response
        start_time = time.time()
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=100,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=tokenizer.pad_token_id
            )
        inference_time = time.time() - start_time
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True).split("Therapist:")[-1].strip()
        inference_times.append(inference_time)

        # Perplexity
        input_ids = inputs['input_ids']
        with torch.no_grad():
            outputs = model(input_ids=input_ids, labels=input_ids)
            ppl = torch.exp(outputs.loss).item()
        perplexities.append(ppl)

        # BLEU
        reference_tokens = nltk.word_tokenize(reference.lower())
        generated_tokens = nltk.word_tokenize(generated.lower())
        bleu = sentence_bleu([reference_tokens], generated_tokens, smoothing_function=SmoothingFunction().method1)
        bleu_scores.append(bleu)

        # ROUGE
        rouge_result = rouge.score(reference, generated)
        rouge_scores.append(rouge_result['rouge1'].fmeasure)

        # Sentiment Score (using BERT)
        inputs_bert = bert_tokenizer(generated, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs_bert = bert_model(**inputs_bert)
        probs = torch.softmax(outputs_bert.logits, dim=1)
        sentiment_score = probs[0][emotion_labels.index('joy')].item()  # Focus on 'joy' as a positive metric
        sentiment_scores.append(sentiment_score)

        # Distinct-n (e.g., Distinct-2)
        n = 2
        generated_ngrams = set()
        tokens = generated_tokens
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i+n])
            generated_ngrams.add(ngram)
        distinct_n = len(generated_ngrams) / max(1, len(tokens) - n + 1)
        distinct_n_scores.append(distinct_n)

    # Memory Usage
    memory_usage = torch.cuda.memory_allocated(device) / (1024 ** 3)  # Convert to GB

    # Aggregate results
    metrics = {
        "Average Perplexity": sum(perplexities) / len(perplexities),
        "Average BLEU": sum(bleu_scores) / len(bleu_scores),
        "Average ROUGE-1": sum(rouge_scores) / len(rouge_scores),
        "Average Sentiment Score (Joy)": sum(sentiment_scores) / len(sentiment_scores),
        "Average Distinct-2": sum(distinct_n_scores) / len(distinct_n_scores),
        "Average Inference Time (s)": sum(inference_times) / len(inference_times),
        "Memory Usage (GB)": memory_usage
    }

    return metrics

# Run evaluation
device = torch.device('cuda')
metrics = evaluate_chatbot(model, tokenizer, val_dataset, bert_model, bert_tokenizer, device)
print("Evaluation Metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value:.4f}")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


KeyError: 'prompt'