In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import string
import re

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score,classification_report, confusion_matrix

import joblib


In [2]:
df = pd.read_csv("../datasets/classification.csv",encoding='ISO-8859-1')
df.head()



Unnamed: 0,Text,Label
0,"I keep imagining the worst possible outcome, e...",Anxiety
1,"My heart races before every phone call, even c...",Anxiety
2,I cant stop overthinking what I said in that ...,Anxiety
3,Every little noise makes me jump lately.,Anxiety
4,I feel like something bad is about to happen a...,Anxiety


In [3]:
df.sample(n=5)

Unnamed: 0,Text,Label
434,I feel like I dont deserve to be happy.,Post-Traumatic Stress Disorder (PTSD)
207,"I procrastinate until I panic, then work in a ...",ADHD( Attention deficit hyperactivity disorder)
51,"If I dont check the door three times, I panic...",Obsessive-Compulsive Disorder (OCD)
75,"I try not to think about it, but it keeps comi...",Post-Traumatic Stress Disorder (PTSD)
353,I cancel plans if I overeat earlier in the day.,Eating Disorder


In [4]:
print("Shape:", df.shape)
print("Columns:", df.columns.tolist())

Shape: (450, 2)
Columns: ['Text', 'Label']


In [5]:
print(df['Label'].value_counts(normalize=True))


Label
Anxiety                                            0.111111
Depression                                         0.111111
ADHD( Attention deficit hyperactivity disorder)    0.111111
Bipolar Disorder                                   0.111111
Autism Spectrum Disorder (ASD)                     0.111111
Obsessive-Compulsive Disorder (OCD)                0.111111
Schizophrenia                                      0.111111
Post-Traumatic Stress Disorder (PTSD)              0.111111
Eating Disorder                                    0.111111
Name: proportion, dtype: float64


In [6]:
# Check missing/nulls
print(df.isnull().sum())




Text     0
Label    0
dtype: int64


In [7]:
#Text Preprocessing


In [8]:
# Lowercase
df['Text'] = df['Text'].str.lower()

# Remove non-ASCII characters (already added)
df['Text'] = df['Text'].str.encode('ascii', 'ignore').str.decode('ascii')

# Remove punctuation
df['Text'] = df['Text'].apply(lambda x: x.translate(str.maketrans('', '', string.punctuation)))

# Remove extra whitespace
df['Text'] = df['Text'].apply(lambda x: re.sub('\s+', ' ', x).strip())



In [9]:
#Augmentation

In [12]:
!pip install transformers sentencepiece
#Here’s code that:

#Loads a T5 model for paraphrasing

#Augments 1–3 paraphrases per sample (configurable)

#Ensures new rows retain the same label






In [13]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from tqdm import tqdm

# Load T5-base paraphrasing model
paraphrase_tokenizer = AutoTokenizer.from_pretrained("ramsrigouthamg/t5_paraphraser")
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained("ramsrigouthamg/t5_paraphraser")
paraphrase_model.eval()

# If you have a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
paraphrase_model = paraphrase_model.to(device)

def paraphrase_text(text, num_return_sequences=2, num_beams=5):
    input_text = f"paraphrase: {text} </s>"
    encoding = paraphrase_tokenizer.encode_plus(input_text, padding='max_length', max_length=128, return_tensors="pt")
    input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)

    outputs = paraphrase_model.generate(
        input_ids=input_ids,
        attention_mask=attention_masks,
        max_length=128,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        temperature=1.5
    )

    paraphrases = [paraphrase_tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                   for output in outputs]

    return paraphrases


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

In [14]:
augmented_texts = []
augmented_labels = []

for i, row in tqdm(df.iterrows(), total=len(df)):
    original_text = row['Text']
    label = row['Label']

    try:
        paraphrases = paraphrase_text(original_text, num_return_sequences=2)

        for para in paraphrases:
            augmented_texts.append(para)
            augmented_labels.append(label)

    except Exception as e:
        print(f"Error for row {i}: {e}")
        continue

# Create new DataFrame for augmented data
aug_df = pd.DataFrame({
    'Text': augmented_texts,
    'Label': augmented_labels
})

# Merge original and augmented
df_augmented = pd.concat([df, aug_df], ignore_index=True)


  0%|          | 0/450 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 1/450 [00:19<2:25:39, 19.46s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  0%|          | 2/450 [00:22<1:12:33,  9.72s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 3/450 [00:27<56:34,  7.59s/it]  The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 4/450 [00:30<43:49,  5.90s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  1%|          | 5/450 [00:36<42:50,  5.78s/it]The following generation flags are not valid and m

 10%|█         | 45/450 [04:49<27:55,  4.14s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 10%|█         | 46/450 [05:08<58:39,  8.71s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 10%|█         | 47/450 [05:10<44:41,  6.65s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 11%|█         | 48/450 [05:13<38:07,  5.69s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 11%|█         | 49/450 [05:16<31:37,  4.73s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 11%|█         | 50/450 [05:34<59:03,  8.86s/it]The following generation flags are not val

 20%|██        | 90/450 [09:58<52:21,  8.73s/it]  The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 20%|██        | 91/450 [10:04<48:46,  8.15s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 20%|██        | 92/450 [10:08<40:02,  6.71s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 21%|██        | 93/450 [10:11<34:16,  5.76s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 21%|██        | 94/450 [10:14<29:04,  4.90s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 21%|██        | 95/450 [10:32<51:14,  8.66s/it]The following generation flags are not v

 30%|███       | 135/450 [14:25<31:33,  6.01s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 30%|███       | 136/450 [14:27<25:30,  4.87s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 30%|███       | 137/450 [14:33<27:06,  5.20s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 31%|███       | 138/450 [14:35<21:42,  4.18s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 31%|███       | 139/450 [14:38<19:36,  3.78s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 31%|███       | 140/450 [14:42<20:49,  4.03s/it]The following generation flags are n

 40%|████      | 180/450 [18:59<21:41,  4.82s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 40%|████      | 181/450 [19:02<18:30,  4.13s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 40%|████      | 182/450 [19:04<16:06,  3.61s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 41%|████      | 183/450 [19:21<34:16,  7.70s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 41%|████      | 184/450 [19:39<47:00, 10.61s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 41%|████      | 185/450 [19:44<39:37,  8.97s/it]The following generation flags are n

 50%|█████     | 225/450 [26:10<43:40, 11.65s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 50%|█████     | 226/450 [26:14<34:50,  9.33s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 50%|█████     | 227/450 [26:32<44:30, 11.98s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 51%|█████     | 228/450 [26:37<35:55,  9.71s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 51%|█████     | 229/450 [26:39<27:47,  7.54s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 51%|█████     | 230/450 [26:58<39:45, 10.84s/it]The following generation flags are n

 60%|██████    | 270/450 [36:36<12:51,  4.29s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 60%|██████    | 271/450 [36:55<25:43,  8.63s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 60%|██████    | 272/450 [36:58<20:26,  6.89s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 61%|██████    | 273/450 [37:01<16:54,  5.73s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 61%|██████    | 274/450 [37:03<14:08,  4.82s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 61%|██████    | 275/450 [37:06<12:28,  4.28s/it]The following generation flags are n

 70%|███████   | 315/450 [41:04<22:02,  9.79s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 70%|███████   | 316/450 [41:07<17:19,  7.76s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 70%|███████   | 317/450 [41:25<24:15, 10.95s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 71%|███████   | 318/450 [41:27<18:24,  8.37s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 71%|███████   | 319/450 [41:32<16:09,  7.40s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 71%|███████   | 320/450 [41:36<13:42,  6.32s/it]The following generation flags are n

 80%|████████  | 360/450 [46:50<09:27,  6.31s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 80%|████████  | 361/450 [46:57<09:48,  6.61s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 80%|████████  | 362/450 [47:01<08:34,  5.85s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 81%|████████  | 363/450 [47:19<13:42,  9.45s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 81%|████████  | 364/450 [47:22<10:40,  7.45s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 81%|████████  | 365/450 [47:28<10:00,  7.06s/it]The following generation flags are n

 90%|█████████ | 405/450 [51:59<03:23,  4.52s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 90%|█████████ | 406/450 [52:03<03:00,  4.10s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 90%|█████████ | 407/450 [52:21<06:03,  8.45s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 91%|█████████ | 408/450 [52:39<07:58, 11.39s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 91%|█████████ | 409/450 [52:41<05:48,  8.51s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 91%|█████████ | 410/450 [52:45<04:39,  6.99s/it]The following generation flags are n

100%|██████████| 450/450 [58:19<00:00,  7.78s/it]


In [16]:
print(df_augmented['Label'].value_counts())


Label
Anxiety                                            150
Depression                                         150
ADHD( Attention deficit hyperactivity disorder)    150
Bipolar Disorder                                   150
Autism Spectrum Disorder (ASD)                     150
Obsessive-Compulsive Disorder (OCD)                150
Schizophrenia                                      150
Post-Traumatic Stress Disorder (PTSD)              150
Eating Disorder                                    150
Name: count, dtype: int64


In [15]:
# Save your augmented dataset
df_augmented.to_csv("augmented_dataset.csv", index=False)


In [17]:
import os
os.listdir()


['.ipynb_checkpoints', 'augmented_dataset.csv', 'classifier_training.ipynb']

In [18]:

print(os.path.abspath("augmented_classification.csv"))


C:\Users\Asus\MindCheck_AI\notebooks\augmented_classification.csv
