<a href="https://colab.research.google.com/github/Anmollllll/SuicideTextDetection_and_ResponseModel/blob/main/BertModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the text-cleaning dependencies (clean-text provides the `cleantext` module used by the cleaning cell).
!pip install clean-text
!pip install unidecode

In [None]:
import numpy as np
import pandas as pd
import os
from google.colab import drive
# Mount Google Drive (Colab only) so files under /content/drive are accessible.
drive.mount('/content/drive',force_remount=True)

In [None]:

# Switch to the project folder on Drive so the relative CSV path below resolves.
os.chdir("/content/drive/MyDrive/SusDetect")
print("Directory changed")

# Load the raw dataset; the following cells use its 'Title', 'Post' and 'Label' columns.
df = pd.read_csv('SusReddit2023dataset.csv')
df.head()


Data Cleaning

In [None]:
# 'Title' is not used later; the following cells only read 'Post' and 'Label'.
df.drop(columns=['Title'],inplace=True)

In [None]:
df.head()

In [None]:
# Count missing values per column.
df.isna().sum()

In [None]:
# Drop every row that has a missing value in any column (pandas default how='any').
df.dropna(inplace=True)

In [None]:
# Count rows that exactly repeat an earlier row across all remaining columns.
df.duplicated().sum()

In [None]:
# Show the duplicate rows (every occurrence after the first one).
duplicates = df[df.duplicated()]
print(duplicates)

In [None]:
# Keep only the first occurrence of each duplicated row.
df.drop_duplicates(keep='first',inplace=True)

In [None]:
len(df)

In [None]:
# Sanity check: should now be 0.
df.duplicated().sum()

Data Preprocessing & EDA


In [None]:
df['Label'].value_counts()

In [None]:

df = df[df['Label'] != 'Label']

In [None]:
df['Label'].value_counts()

In [None]:
# Per the counts above, the two classes are roughly balanced (re-check if the data changes).
# Encode the string labels as integers. LabelEncoder assigns codes in sorted
# order of the label strings; encoder.classes_[i] is the label encoded as i.
from sklearn.preprocessing import LabelEncoder
encoder = LabelEncoder()
df['Label']=encoder.fit_transform(df['Label'])

In [None]:
df['Label'].value_counts()

In [None]:
df.head()

In [None]:
import re
from collections import Counter


class TextQualityAnalyzer:
    """Heuristic quality checks for a column of free text (e.g. Reddit posts).

    For every text it records basic size statistics and a list of issue tags:
    encoding artifacts (mojibake), regex-detected patterns (HTML tags, URLs,
    emails, phone numbers, control characters, ...), very short/long texts,
    heavy word repetition, low character diversity and a high share of
    non-ASCII characters.
    """

    def __init__(self):
        # Substrings typical of UTF-8 text mis-decoded as Windows-1252
        # ("mojibake"), e.g. a right single quote showing up as 'â€™'.
        # The escaped entries are the em-dash ('â€”') and en-dash ('â€“')
        # forms, and 'Ã' + soft hyphen for 'í'.
        self.encoding_artifacts = [
            'â€™', 'â€œ', 'â€', '\u00e2\u20ac\u201d', '\u00e2\u20ac\u201c',
            'Ã', 'Â', 'ï¿½',
            'Ã¡', 'Ã©', '\u00c3\u00ad', 'Ã³', 'Ãº', 'Ã±', 'â€¦'
        ]

        self.suspicious_patterns = {
            'html_tags': r'<[^>]+>',
            'excessive_punctuation': r'[!?]{4,}|[.]{4,}',
            'excessive_caps': r'\b[A-Z]{10,}\b',
            # Permissive URL matcher; note `$-_` inside the class is a
            # character range ($ through _).
            'urls': r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',
            # TLD class is [A-Za-z]; writing [A-Z|a-z] would also accept '|'.
            'emails': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
            'phone_numbers': r'(\+\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}',
            'excessive_whitespace': r'\s{3,}',
            'repeated_words': r'\b(\w+)(\s+\1){3,}\b',
            'special_chars': r'[^\w\s\.,!?;:\'"()\-/_]',
            # C0/C1 control characters EXCEPT the ordinary whitespace controls
            # \t, \n and \r: multi-line posts are normal, and long whitespace
            # runs are already reported as 'excessive_whitespace'.
            'control_chars': r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]'
        }

    def analyze_single_text(self, text):
        """Analyze one text and return its statistics and issue tags.

        Args:
            text: Value to analyze. NaN/None is treated as empty; anything
                else is converted with ``str``.

        Returns:
            dict with keys ``is_empty``, ``length`` (characters),
            ``word_count`` (whitespace-separated tokens), ``unique_chars``,
            ``ascii_ratio`` (share of characters with code point < 128),
            ``encoding_artifacts`` (artifacts found) and ``issues`` (tags).
        """
        if pd.isna(text):
            # Same keys as the non-missing case so DataFrame columns line up.
            return {
                'is_empty': True,
                'length': 0,
                'word_count': 0,
                'unique_chars': 0,
                'ascii_ratio': 0,
                'encoding_artifacts': [],
                'issues': ['empty_text']
            }

        text_str = str(text)
        issues = []

        # Basic statistics
        length = len(text_str)
        words = text_str.split()
        word_count = len(words)

        if length == 0:
            issues.append('empty_text')

        # Encoding artifacts (mojibake substrings)
        encoding_issues = [artifact for artifact in self.encoding_artifacts if artifact in text_str]
        if encoding_issues:
            issues.append('encoding_artifacts')

        # Regex-based suspicious patterns; each tag is added at most once.
        for issue_name, pattern in self.suspicious_patterns.items():
            if re.search(pattern, text_str):
                issues.append(issue_name)

        # Length-based issues (512 mirrors the tokenizer's max_length used later).
        if word_count < 3:
            issues.append('too_short')
        elif word_count > 512:
            issues.append('too_long')

        # Repetition: one (case-insensitive) word making up >30% of all words.
        if word_count > 0:
            word_freq = Counter(word.lower() for word in words)
            most_common_freq = word_freq.most_common(1)[0][1]
            if most_common_freq > word_count * 0.3:
                issues.append('highly_repetitive')

        # Very low character diversity in a non-trivial text (e.g. "aaaaaaa...").
        unique_chars = len(set(text_str))
        if unique_chars < 10 and length > 50:
            issues.append('low_char_diversity')

        # Share of ASCII characters; only meaningful for non-empty text, so an
        # empty string is not additionally reported as non-ASCII.
        ascii_ratio = sum(1 for c in text_str if ord(c) < 128) / length if length > 0 else 0
        if length > 0 and ascii_ratio < 0.8:
            issues.append('non_ascii_heavy')

        return {
            'is_empty': length == 0,
            'length': length,
            'word_count': word_count,
            'unique_chars': unique_chars,
            'ascii_ratio': ascii_ratio,
            'encoding_artifacts': encoding_issues,
            'issues': issues
        }

    def analyze_dataframe(self, df, text_column):
        """Analyze every value of ``df[text_column]``.

        Args:
            df: DataFrame containing the texts.
            text_column: Name of the column to analyze.

        Returns:
            DataFrame with one row per text: the keys of
            ``analyze_single_text`` plus ``index`` (0-based POSITION in
            ``df``, not its index label) and ``text_preview`` (the full text,
            or 'NaN' for missing values).
        """
        results = []

        print(f"Analyzing {len(df)} texts...")

        for idx, text in enumerate(df[text_column]):
            if idx % 1000 == 0:
                print(f"Processed {idx}/{len(df)} texts")

            analysis = self.analyze_single_text(text)
            analysis['index'] = idx
            analysis['text_preview'] = str(text) if not pd.isna(text) else 'NaN'
            results.append(analysis)

        return pd.DataFrame(results)

    def generate_report(self, analysis_df, sample_size=5):
        """Print a summary report for the output of ``analyze_dataframe``.

        Prints overall statistics, the ten most frequent issue tags with the
        share of texts affected, and up to ``sample_size`` example texts for
        each of the five most frequent tags.

        Args:
            analysis_df: DataFrame returned by ``analyze_dataframe``.
            sample_size: Number of example texts printed per issue.

        Returns:
            collections.Counter mapping each issue tag to the number of texts
            it was found in, so callers can use the counts programmatically.
        """
        print("=" * 60)
        print("TEXT QUALITY ANALYSIS REPORT")
        print("=" * 60)

        total_texts = len(analysis_df)
        has_issues = analysis_df['issues'].str.len() > 0

        print("\n OVERALL STATISTICS:")
        print(f"Total texts analyzed: {total_texts}")
        print(f"Average length: {analysis_df['length'].mean():.1f} characters")
        print(f"Average word count: {analysis_df['word_count'].mean():.1f} words")
        print(f"Texts with issues: {int(has_issues.sum())}")

        # Each text lists a tag at most once, so these are per-text counts.
        issue_counts = Counter(tag for tags in analysis_df['issues'] for tag in tags)

        print("\n TOP ISSUES FOUND:")
        for issue, count in issue_counts.most_common(10):
            percentage = (count / total_texts) * 100
            print(f"  {issue}: {count} texts ({percentage:.1f}%)")

        print("\n SAMPLE PROBLEMATIC TEXTS:")
        for issue, count in issue_counts.most_common(5):
            print(f"\n--- {issue.upper()} (Found in {count} texts) ---")

            examples = analysis_df[analysis_df['issues'].apply(lambda tags: issue in tags)]

            for i, (_, row) in enumerate(examples.head(sample_size).iterrows()):
                print(f"  Example {i+1} (Index {row['index']}):")
                print(f"    Text: {row['text_preview']}")

        return issue_counts



In [None]:

# Baseline scan of the raw posts (before cleaning) to see which issues the cleaning step must handle.
analyzer = TextQualityAnalyzer()
analysis_results = analyzer.analyze_dataframe(df, 'Post')
# The report is printed; issue_summary keeps whatever generate_report returns.
issue_summary = analyzer.generate_report(analysis_results)

In [None]:
import cleantext

def clean_dataset(text):
    """Normalize one raw post into the text fed to the tokenizer.

    Missing values become "". Otherwise the text goes through
    ``cleantext.clean`` (Unicode fixing, lowercasing, line-break removal,
    URL/email/phone replacement, whitespace normalization). Several regex
    passes then handle what clean-text does not: HTML-like tags, unusual
    symbols, long punctuation runs and immediately repeated words.

    Args:
        text: A post. Non-string values are converted with ``str``.

    Returns:
        The cleaned, single-spaced string.
    """
    if pd.isna(text):
        return ""
    text = str(text)

    text = cleantext.clean(text,
        fix_unicode=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=False,
        no_punct=False,
        normalize_whitespace=True,
    )

    # NOTE(review): clean-text's default replacement tokens for URLs, emails
    # and phone numbers are angle-bracketed (e.g. "<URL>"), so the tag regex
    # below strips them as well. Confirm that dropping them is intended.
    text = re.sub(r'<[^>]+>', ' ', text)  # HTML tags
    # Raw string: '\w' inside a plain string literal is an invalid escape
    # (SyntaxWarning on Python 3.12+). Replace symbols outside this set.
    text = re.sub(r'[^\w\s\.,!?;:\'"()\-/_]', ' ', text)  # special characters
    # Collapse any run of 4+ '!', '?' or '.' into three copies of its last
    # character. This also covers runs of periods, so no separate pass is needed.
    text = re.sub(r'([!?\.]){4,}', r'\1\1\1', text)
    # Collapse a word repeated 4+ times in a row into a single occurrence.
    # (No ALL-CAPS pass: the text was already lowercased by cleantext above.)
    text = re.sub(r'\b(\w+)(\s+\1){3,}\b', r'\1', text)

    # Final whitespace cleanup (the substitutions above can leave extra spaces).
    return ' '.join(text.split())

In [None]:
# Clean every post; the model is trained on 'cleaned_text' rather than the raw 'Post'.
df['cleaned_text'] = df['Post'].apply(clean_dataset)

In [None]:
df.head()

In [None]:
# Re-run the quality scan on the cleaned text for comparison with the raw-text report.
analyzer1 = TextQualityAnalyzer()
analysis_results1 = analyzer1.analyze_dataframe(df, 'cleaned_text')
issue_summary1 = analyzer1.generate_report(analysis_results1)


In [None]:
new_df = df[['cleaned_text','Label']]

In [None]:
new_df.head()

Data Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def bert_eda(df, text_col, label_col, label_names=None, top_n=30):
    """Plot length statistics and top words for a labelled text dataset.

    Shows:
    (1) histograms of character and word counts, overall and split by label;
    (2) the number of unique words in the corpus after lowercasing and
        stripping punctuation;
    (3) bar charts of the ``top_n`` most frequent whitespace-separated
        tokens for each label.

    The caller's frame is not modified; helper columns are added to a copy.

    Args:
        df: DataFrame holding the texts and labels.
        text_col: Name of the column with (string) texts.
        label_col: Name of the column with class labels.
        label_names: Optional mapping ``{label_value: display_name}`` that
            controls which labels get a top-words chart and in what order.
            Defaults to ``{1: 'Suicidal', 0: 'Non_Suicidal'}``.
        top_n: Number of most frequent words charted per label.
    """
    if label_names is None:
        # NOTE(review): assumes the label encoding maps the suicidal class to 1.
        # LabelEncoder numbers classes in sorted order of the original label
        # strings — check encoder.classes_ to confirm.
        label_names = {1: 'Suicidal', 0: 'Non_Suicidal'}

    # Work on a copy so helper columns never leak into, or overwrite existing
    # columns of, the caller's frame.
    data = df.copy()
    data['text_length'] = data[text_col].apply(len)
    data['word_count'] = data[text_col].apply(lambda x: len(x.split()))

    print("\n--- Text Length and Word Count Analysis ---")

    plt.figure(figsize=(14, 6))
    plt.subplot(1, 2, 1)
    sns.histplot(data['text_length'], bins=50, kde=True)
    plt.title('Distribution of Text Lengths')
    plt.xlabel('Character Count')
    plt.ylabel('Frequency')

    plt.subplot(1, 2, 2)
    sns.histplot(data['word_count'], bins=50, kde=True)
    plt.title('Distribution of Word Counts')
    plt.xlabel('Word Count')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.show()

    print("\nText length and word count by label:")
    _, axes = plt.subplots(1, 2, figsize=(16, 6))
    for ax, column, title, xlabel in (
        (axes[0], 'text_length', 'Text Length', 'Character Count'),
        (axes[1], 'word_count', 'Word Count', 'Word Count'),
    ):
        sns.histplot(data=data, x=column, hue=label_col, kde=True, ax=ax)
        ax.set_title(f'{title} Distribution by {label_col}')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')
    plt.tight_layout()
    plt.show()

    # Vocabulary size: lowercase, strip punctuation, split on whitespace.
    print("\n---  Vocabulary Analysis ---")
    corpus = re.sub(r'[^\w\s]', '', ' '.join(data[text_col]).lower())
    print(f"Total unique words in corpus: {len(set(corpus.split()))}")

    # Most frequent raw (whitespace-split) tokens per label. Uses the
    # text_col/label_col arguments, not a global frame, so the function
    # works for any dataset passed in.
    print(f"\n--- Top {top_n} Most Common Words by Label ---")
    for label_value, label_name in label_names.items():
        texts = data.loc[data[label_col] == label_value, text_col]
        top_words = Counter(word for text in texts for word in text.split()).most_common(top_n)

        print(f"\n--- For label-{label_name} ---")
        if not top_words:
            print("No words found for this label.")
            continue
        words, freqs = zip(*top_words)
        sns.barplot(x=list(words), y=list(freqs))
        plt.xticks(rotation=90)
        plt.xlabel('Word')
        plt.ylabel('Count')
        plt.show()


In [None]:
# EDA on the cleaned text and the integer-encoded labels.
bert_eda(new_df,'cleaned_text','Label')

Model Building

In [None]:
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer,BertForSequenceClassification,Trainer,TrainingArguments
import torch
from torch.utils.data import Dataset

In [None]:
# Hold out 20% for evaluation, stratified on the same frame's labels so both
# splits keep the class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    new_df['cleaned_text'],
    new_df['Label'],
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=new_df['Label'],
)

In [None]:
# WordPiece tokenizer matching the 'bert-base-uncased' model loaded further below.
tokenizer=BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Tokenize each split in one call. padding=True pads to the longest sequence in
# the call (here the whole split); truncation caps sequences at max_length=512
# tokens, BERT's maximum input length.
train_encodings = tokenizer(X_train.tolist(),truncation=True,padding=True,max_length=512)
test_encodings = tokenizer(X_test.tolist(),truncation=True,padding=True,max_length=512)

In [None]:
class TextDataset(Dataset):
    """Map-style torch Dataset pairing tokenizer output with class labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a per-example sequence, such as the
            output of a Hugging Face tokenizer called on a list of texts.
        labels: Class ids aligned with the encodings, in the same order.
            Any iterable works (list, array, pandas Series). It is read
            positionally, so a Series index does not need to be reset.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        # Materialize positionally: Series[int] performs a label lookup, which
        # raises KeyError or returns the wrong row on a shuffled index.
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item


In [None]:
# reset_index(drop=True) renumbers the shuffled label Series 0..n-1 so it lines
# up positionally with the rows of the encodings.
train_dataset = TextDataset(train_encodings, y_train.reset_index(drop=True))
val_dataset = TextDataset(test_encodings, y_test.reset_index(drop=True))

In [None]:
# Pretrained BERT encoder with a 2-way sequence-classification head for the binary labels.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)




In [None]:
# NOTE(review): load_best_model_at_end is not set, so after training `model`
# holds the last epoch's weights; a later cell compares the saved checkpoints.
training_args = TrainingArguments(
    output_dir="./results",      # checkpoints are written here as checkpoint-<step>
    eval_strategy="epoch",       # evaluate on eval_dataset at the end of each epoch
    save_strategy="epoch",       # save a checkpoint at the end of each epoch
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=50,
)

In [None]:
# No compute_metrics is passed, so evaluate() reports loss/runtime values only;
# accuracy, precision, recall and F1 are computed in a later cell.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # NOTE(review): recent transformers versions deprecate `tokenizer=` in
    # favor of `processing_class=` — check against the installed version.
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

In [None]:
metrics = trainer.evaluate()
print(metrics)

In [None]:

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix


# Banner for the per-checkpoint evaluation that follows.
print("\n" + "="*70)
print("EVALUATING EACH EPOCH CHECKPOINT")
print("="*70)
def evaluate_model(trainer, eval_dataset, dataset_name):
    """Predict on a dataset, print binary-classification metrics, return them.

    Args:
        trainer: Object whose ``predict(eval_dataset)`` returns a result with
            ``predictions`` (per-class scores, one row per example) and
            ``label_ids``, e.g. a Hugging Face ``Trainer``.
        eval_dataset: Dataset passed unchanged to ``trainer.predict``.
        dataset_name: Text used in the printed heading.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``. The last
        three are for the positive class (label 1), per ``average='binary'``.
    """
    output = trainer.predict(eval_dataset)
    # Highest-scoring class per row is the predicted label.
    y_pred = np.argmax(output.predictions, axis=1)
    y_true = output.label_ids

    scores = {'accuracy': accuracy_score(y_true, y_pred)}
    (scores['precision'], scores['recall'], scores['f1'],
     _) = precision_recall_fscore_support(y_true, y_pred, average='binary')

    rule = '=' * 50
    print(f"\n{rule}")
    print(f"{dataset_name} Set Metrics:")
    print(rule)
    for caption, key in (("Accuracy:  ", 'accuracy'),
                         ("Precision: ", 'precision'),
                         ("Recall:    ", 'recall'),
                         ("F1-Score:  ", 'f1')):
        print(f"{caption}{scores[key]:.4f}")
    print("\nDetailed Classification Report:")
    print(classification_report(y_true, y_pred, target_names=['Class 0', 'Class 1']))
    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred))

    return scores

# Collect the checkpoint directories the Trainer wrote under ./results (one per
# epoch, since save_strategy="epoch"). Sort numerically by the global step in
# the name: a plain string sort would put "checkpoint-1000" before
# "checkpoint-500" and attach the wrong epoch number to each checkpoint.
# NOTE(review): epoch numbering assumes ./results holds only this run's
# checkpoints; leftovers from earlier runs would shift it.
checkpoint_dirs = sorted(
    (d for d in os.listdir("./results")
     if d.startswith("checkpoint-") and d.rsplit("-", 1)[-1].isdigit()),
    key=lambda d: int(d.rsplit("-", 1)[-1]),
)

epoch_results = []

for i, checkpoint_dir in enumerate(checkpoint_dirs, 1):
    checkpoint_path = os.path.join("./results", checkpoint_dir)

    print(f"\n{'='*50}")
    print(f"Epoch {i} - Checkpoint: {checkpoint_dir}")
    print(f"{'='*50}")

    # Load the weights saved at this checkpoint and evaluate them with a
    # separate Trainer that uses the same arguments as training.
    model_checkpoint = BertForSequenceClassification.from_pretrained(checkpoint_path)

    trainer_checkpoint = Trainer(
        model=model_checkpoint,
        args=training_args,
        tokenizer=tokenizer,
        eval_dataset=val_dataset
    )

    metrics = evaluate_model(trainer_checkpoint, val_dataset, f"Epoch {i}")
    metrics['epoch'] = i
    metrics['checkpoint'] = checkpoint_dir
    epoch_results.append(metrics)

# Summary comparison across all epochs
print("\n" + "="*70)
print("SUMMARY: Performance Across All Epochs")
print("="*70)
print(f"{'Epoch':<10} {'Accuracy':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
print("-"*70)
for result in epoch_results:
    print(f"{result['epoch']:<10} {result['accuracy']:<12.4f} {result['precision']:<12.4f} "
          f"{result['recall']:<12.4f} {result['f1']:<12.4f}")

# Pick the epoch with the highest F1.
# NOTE(review): val_dataset is also the only held-out split, so the selected
# epoch's score is an optimistic estimate of generalization.
if epoch_results:
    best_epoch = max(epoch_results, key=lambda x: x['f1'])
    print(f"\n Best performing epoch: Epoch {best_epoch['epoch']} (F1-Score: {best_epoch['f1']:.4f})")
else:
    best_epoch = None
    print("\n No checkpoint-* directories found in ./results; nothing to compare.")
