In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re

# Load all datasets
datasets = {
    'ADHD': pd.read_csv('/kaggle/input/reddit/adhd.csv'),
    'Aspergers': pd.read_csv('/kaggle/input/reddit/aspergers.csv'),
    'Depression': pd.read_csv('/kaggle/input/reddit/aspergers.csv'),
    'OCD': pd.read_csv('/kaggle/input/reddit/ocd.csv'),
    'PTSD': pd.read_csv('/kaggle/input/reddit/ptsd.csv')
}

# Explore structure
print("="*80)
print("DATASET EXPLORATION")
print("="*80)

for name, df in datasets.items():
    print(f"\n{'='*80}")
    print(f"Dataset: {name}")
    print(f"{'='*80}")
    print(f"Shape: {df.shape}")
    print(f"Columns: {df.columns.tolist()}")
    print(f"\nFirst 2 rows:")
    print(df.head(2))
    print(f"\nData types:")
    print(df.dtypes)
    print(f"\nMissing values:")
    print(df.isnull().sum())
    print(f"\nSample text (if available):")
    # Try to find text column
    text_cols = [col for col in df.columns if any(keyword in col.lower() for keyword in ['text', 'post', 'content', 'body', 'selftext'])]
    if text_cols:
        sample_text = df[text_cols[0]].iloc[0] if len(df) > 0 else "No data"
        print(f"{sample_text[:200]}...")
    else:
        print("No obvious text column found")

DATASET EXPLORATION

Dataset: ADHD
Shape: (37109, 10)
Columns: ['author', 'body', 'created_utc', 'id', 'num_comments', 'score', 'subreddit', 'title', 'upvote_ratio', 'url']

First 2 rows:
                author                                               body  \
0  HotConversation1273  A few months ago I was accepted into this full...   
1           snorefestt  Hey guys, I was curious if anyone else has the...   

                created_utc      id  num_comments  score subreddit  \
0  2021-12-22T18:32:56.000Z  rmbjwb             1      1      ADHD   
1  2021-12-22T18:24:25.000Z  rmbd1y             3      5      ADHD   

                                               title  upvote_ratio  \
0    I get extremely anxious if I’m not working 24/7           1.0   
1  I can't will myself to clean my own house, but...           1.0   

                                                 url  
0  https://www.reddit.com/r/ADHD/comments/rmbjwb/...  
1  https://www.reddit.com/r/ADHD/comments/rmbd1y

In [2]:
import pandas as pd
import re

# ============================================================================
# STEP 2: STANDARDIZE AND CLEAN DATASETS
# ============================================================================

def standardize_dataset(df, condition_name):
    """
    Standardize dataset to common format
    Combines 'title' and 'body' into single text field
    """
    standardized_df = pd.DataFrame()

    # Combine title and body for richer context
    # Many Reddit posts have important info in both title and body
    standardized_df['title'] = df['title'].fillna('')
    standardized_df['body'] = df['body'].fillna('')

    # Combine: "Title. Body"
    standardized_df['text'] = (
        standardized_df['title'].astype(str) + '. ' +
        standardized_df['body'].astype(str)
    )

    # Add condition label
    standardized_df['condition'] = condition_name

    # Add metadata
    standardized_df['score'] = df['score']
    standardized_df['num_comments'] = df['num_comments']
    standardized_df['created_utc'] = df['created_utc']
    standardized_df['subreddit'] = df['subreddit']
    standardized_df['author'] = df['author']

    return standardized_df

# Standardize all datasets
print("="*80)
print("STANDARDIZING DATASETS")
print("="*80)

standardized_datasets = {}
for name, df in datasets.items():
    print(f"\nProcessing {name}...")
    standardized_datasets[name] = standardize_dataset(df, name)
    print(f"  ✅ {name}: {len(standardized_datasets[name])} posts")

# Combine into single dataset
combined_df = pd.concat(standardized_datasets.values(), ignore_index=True)

print(f"\n{'='*80}")
print(f"COMBINED DATASET")
print(f"{'='*80}")
print(f"Total posts: {len(combined_df)}")
print(f"\nCondition distribution:")
print(combined_df['condition'].value_counts())

# Check for empty or very short texts
combined_df['text_length'] = combined_df['text'].str.len()

print(f"\nText length statistics:")
print(combined_df['text_length'].describe())

# Filter out very short texts (less than 10 characters)
print(f"\nPosts with text length < 10: {(combined_df['text_length'] < 10).sum()}")

# Show sample from each condition
print(f"\n{'='*80}")
print("SAMPLE POSTS FROM EACH CONDITION")
print(f"{'='*80}")

for condition in combined_df['condition'].unique():
    print(f"\n{condition}:")
    print("-" * 80)
    sample = combined_df[combined_df['condition'] == condition].iloc[0]
    print(f"Text: {sample['text'][:300]}...")
    print(f"Length: {sample['text_length']} characters")

print(f"\n✅ Step 2 complete!")
print(f"\nCombined dataset shape: {combined_df.shape}")
print(f"Columns: {combined_df.columns.tolist()}")

STANDARDIZING DATASETS

Processing ADHD...
  ✅ ADHD: 37109 posts

Processing Aspergers...
  ✅ Aspergers: 23294 posts

Processing Depression...
  ✅ Depression: 23294 posts

Processing OCD...
  ✅ OCD: 42826 posts

Processing PTSD...
  ✅ PTSD: 24028 posts

COMBINED DATASET
Total posts: 150551

Condition distribution:
condition
OCD           42826
ADHD          37109
PTSD          24028
Depression    23294
Aspergers     23294
Name: count, dtype: int64

Text length statistics:
count    150551.000000
mean        609.893285
std         986.734149
min           9.000000
25%          60.000000
50%         280.000000
75%         806.000000
max       40126.000000
Name: text_length, dtype: float64

Posts with text length < 10: 1

SAMPLE POSTS FROM EACH CONDITION

ADHD:
--------------------------------------------------------------------------------
Text: I get extremely anxious if I’m not working 24/7. A few months ago I was accepted into this full time software engineering fellowship and it’s mad

In [3]:
import re
import nltk

# Download NLTK data (run once)
try:
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)
    nltk.download('wordnet', quiet=True)
    print("✅ NLTK data ready")
except:
    print("⚠️ NLTK download failed, continuing anyway...")

# ============================================================================
# STEP 3: TEXT PREPROCESSING
# ============================================================================

class TextPreprocessor:
    def __init__(self):
        # Don't remove important negation/emotion words
        try:
            from nltk.corpus import stopwords
            self.stop_words = set(stopwords.words('english'))
            # Keep emotion-relevant words
            self.stop_words -= {
                'not', 'no', 'never', 'nothing', 'nobody', 'none',
                'neither', 'nor', 'can', 'cannot', 'couldn', 'shouldn',
                'don', 'won', 'wouldn', 'very', 'too', 'really', 'so',
                'just', 'but', 'however', 'although'
            }
        except:
            self.stop_words = set()

    def clean_text(self, text):
        """Clean Reddit-specific text"""
        if pd.isna(text) or text == '':
            return ""

        text = str(text)

        # Remove URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)

        # Remove Reddit-specific markers
        text = re.sub(r'\[deleted\]|\[removed\]', '', text)

        # Remove user mentions and subreddit links
        text = re.sub(r'u/\w+|r/\w+', '', text)

        # Remove markdown formatting
        text = re.sub(r'\*\*|\*|__|_|~~', '', text)

        # Remove excessive punctuation (keep some for emotion)
        text = re.sub(r'[!]{3,}', '!!', text)
        text = re.sub(r'[?]{3,}', '??', text)
        text = re.sub(r'[.]{3,}', '...', text)

        # Remove newlines and tabs
        text = re.sub(r'\n|\t|\r', ' ', text)

        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text).strip()

        # Remove non-ASCII characters (keep basic punctuation)
        text = re.sub(r'[^\x00-\x7F]+', '', text)

        return text

# Initialize preprocessor
preprocessor = TextPreprocessor()

print("="*80)
print("TEXT PREPROCESSING")
print("="*80)

# Apply cleaning
print("\nCleaning text data...")
combined_df['text_clean'] = combined_df['text'].apply(preprocessor.clean_text)

# Calculate new lengths
combined_df['text_clean_length'] = combined_df['text_clean'].str.len()

print("✅ Text cleaning complete!")

# Filter out very short posts (< 20 chars after cleaning)
print(f"\nPosts before filtering: {len(combined_df)}")
combined_df = combined_df[combined_df['text_clean_length'] >= 20].copy()
print(f"Posts after filtering (>= 20 chars): {len(combined_df)}")

# Filter out extremely long posts (> 5000 chars to avoid memory issues)
long_posts = (combined_df['text_clean_length'] > 5000).sum()
print(f"Posts > 5000 characters: {long_posts}")
combined_df = combined_df[combined_df['text_clean_length'] <= 5000].copy()
print(f"Posts after filtering (<= 5000 chars): {len(combined_df)}")

# Reset index
combined_df = combined_df.reset_index(drop=True)

# Statistics
print(f"\n{'='*80}")
print("CLEANED DATA STATISTICS")
print(f"{'='*80}")
print(f"Total posts: {len(combined_df)}")
print(f"\nCondition distribution:")
print(combined_df['condition'].value_counts())

print(f"\nText length statistics (after cleaning):")
print(combined_df['text_clean_length'].describe())

# Compare before/after cleaning
print(f"\n{'='*80}")
print("BEFORE vs AFTER CLEANING - SAMPLES")
print(f"{'='*80}")

for i in range(3):
    sample = combined_df.sample(1).iloc[0]
    print(f"\nSample {i+1} ({sample['condition']}):")
    print("-" * 80)
    print("BEFORE:")
    print(sample['text'][:300] + "...")
    print("\nAFTER:")
    print(sample['text_clean'][:300] + "...")
    print(f"Length: {sample['text_length']} → {sample['text_clean_length']}")

print(f"\n✅ Step 3 complete!")
print(f"\nFinal dataset: {combined_df.shape}")

✅ NLTK data ready
TEXT PREPROCESSING

Cleaning text data...
✅ Text cleaning complete!

Posts before filtering: 150551
Posts after filtering (>= 20 chars): 142265
Posts > 5000 characters: 1011
Posts after filtering (<= 5000 chars): 141254

CLEANED DATA STATISTICS
Total posts: 141254

Condition distribution:
condition
OCD           39643
ADHD          35325
PTSD          22170
Depression    22058
Aspergers     22058
Name: count, dtype: int64

Text length statistics (after cleaning):
count    141254.000000
mean        582.671585
std         736.938808
min          20.000000
25%          58.000000
50%         328.000000
75%         823.000000
max        5000.000000
Name: text_clean_length, dtype: float64

BEFORE vs AFTER CLEANING - SAMPLES

Sample 1 (OCD):
--------------------------------------------------------------------------------
BEFORE:
Has anyone every experienced this?. I was diagnosed at a very young age, and last night was possibly one of the most intense relapses I’ve ever had.

In [4]:
# ============================================================================
# STEP 4: CREATE MENTAL HEALTH EMOTION DIMENSION DICTIONARY (EDD)
# ============================================================================

class MentalHealthEDD:
    """
    Mental Health Emotion Dimension Dictionary
    Based on OCC model with 6 dimensions adapted for mental health
    """

    def __init__(self):
        self.emotion_dimensions = self._build_dimensions()
        self.intensity_modifiers = self._build_intensity_modifiers()

    def _build_dimensions(self):
        """Build 6 emotion dimensions with condition-specific terms"""

        dimensions = {
            'Desirable': {
                # General positive mental health
                'general': [
                    'better', 'improving', 'progress', 'recovery', 'healing',
                    'hopeful', 'hope', 'motivated', 'energy', 'strength',
                    'coping', 'managing', 'handling', 'dealing', 'surviving',
                    'breakthrough', 'milestone', 'achievement', 'proud',
                    'peaceful', 'calm', 'relaxed', 'content', 'stable',
                    'grateful', 'thankful', 'blessed', 'lucky', 'fortunate',
                    'good day', 'better day', 'felt good', 'slept well',
                    'clear headed', 'focused', 'productive', 'accomplished'
                ],

                # Treatment success
                'treatment': [
                    'therapy helps', 'therapy working', 'medication working',
                    'treatment helping', 'meds helping', 'feeling better',
                    'side effects manageable', 'right dosage', 'finding balance',
                    'therapist understands', 'good session', 'opened up'
                ],

                # ADHD-specific positive
                'adhd_positive': [
                    'hyperfocus', 'found strategy', 'remembered', 'organized',
                    'on time', 'finished task', 'stayed focused', 'productivity',
                    'stimulant working', 'less scattered', 'routine helping'
                ],

                # Depression-specific positive
                'depression_positive': [
                    'got out of bed', 'showered', 'ate meal', 'left house',
                    'socialized', 'enjoyed something', 'laughed', 'smiled',
                    'less numb', 'feeling again', 'wanted to live'
                ],

                # OCD-specific positive
                'ocd_positive': [
                    'resisted compulsion', 'exposure went well', 'anxiety decreased',
                    'less intrusive thoughts', 'erp working', 'broke ritual',
                    'challenged thought', 'uncertainty tolerance'
                ],

                # PTSD-specific positive
                'ptsd_positive': [
                    'no nightmares', 'no flashback', 'felt safe', 'grounded',
                    'emdr helping', 'processing trauma', 'less triggered',
                    'hypervigilance down', 'sleeping better'
                ],

                # Aspergers/ASD positive
                'asd_positive': [
                    'special interest', 'comfortable', 'understood', 'accepted',
                    'clear communication', 'sensory friendly', 'routine maintained',
                    'social success', 'masking less', 'authentic'
                ]
            },

            'Undesirable': {
                # General distress
                'general': [
                    'suffering', 'pain', 'hurt', 'struggling', 'difficult',
                    'hard', 'tough', 'unbearable', 'overwhelming', 'exhausted',
                    'tired', 'drained', 'burnt out', 'breaking', 'falling apart',
                    'losing it', 'scared', 'terrified', 'afraid',
                    'worried', 'anxious', 'nervous', 'panic', 'stress',
                    'crying', 'tears', 'sobbing', 'breakdown', 'episode'
                ],

                # Crisis indicators
                'crisis': [
                    'suicidal', 'kill myself', 'end it', 'die', 'death',
                    'not worth living', 'better off dead', 'giving up',
                    'no point', 'hopeless', 'helpless',
                    'self harm', 'cutting', 'hurting myself', 'crisis',
                    'emergency', 'hospital', 'psychiatric ward', 'hotline',
                    'cant go on', 'want to die'
                ],

                # ADHD symptoms
                'adhd_symptoms': [
                    'cant focus', 'distracted', 'scattered', 'forgetful',
                    'forgot again', 'lost track', 'missed deadline', 'late',
                    'overwhelmed', 'too much', 'executive dysfunction',
                    'time blindness', 'procrastinating', 'avoidance',
                    'rejection sensitivity', 'emotional dysregulation',
                    'impulsive', 'restless', 'racing thoughts'
                ],

                # Depression symptoms
                'depression_symptoms': [
                    'empty', 'numb', 'hollow', 'void', 'nothing', 'pointless',
                    'worthless', 'useless', 'failure', 'burden', 'waste',
                    'cant get up', 'bed all day', 'no energy', 'no motivation',
                    'dont care', 'anhedonia', 'no pleasure', 'no joy',
                    'isolating', 'alone', 'lonely', 'withdrawn', 'hiding',
                    'ruminating', 'negative thoughts', 'self hate', 'guilt',
                    'appetite loss', 'not eating', 'insomnia', 'oversleeping'
                ],

                # OCD symptoms
                'ocd_symptoms': [
                    'intrusive thoughts', 'obsessive', 'compulsion', 'ritual',
                    'checking', 'washing', 'counting', 'repeating',
                    'contamination', 'fear of harm', 'magical thinking',
                    'pure o', 'mental compulsion', 'reassurance seeking',
                    'cant stop', 'stuck', 'rumination', 'doubt', 'uncertainty',
                    'need certainty', 'spiral', 'loop', 'torture'
                ],

                # PTSD symptoms
                'ptsd_symptoms': [
                    'flashback', 'nightmare', 'triggered', 'trigger',
                    'hypervigilant', 'on edge', 'startle', 'jumpy',
                    'dissociate', 'dissociation', 'detached',
                    'avoidance', 'avoiding', 'memory gaps',
                    'panic attack', 'reliving', 'intrusive memory',
                    'traumatic', 'trauma', 'abuse', 'assault', 'violence'
                ],

                # Aspergers/ASD struggles
                'asd_symptoms': [
                    'sensory overload', 'too loud', 'too bright', 'overstimulated',
                    'meltdown', 'shutdown', 'masking', 'exhausted from masking',
                    'dont understand', 'confused', 'social anxiety',
                    'misunderstood', 'different', 'weird', 'outcast',
                    'routine broken', 'change', 'unexpected', 'uncertainty',
                    'eye contact uncomfortable', 'small talk', 'social cues'
                ]
            },

            'Praiseworthy': {
                # Healthcare providers
                'providers': [
                    'therapist', 'psychiatrist', 'psychologist', 'counselor',
                    'doctor', 'nurse', 'social worker', 'case manager',
                    'my therapist', 'my doctor', 'my psych',
                    'helpful', 'supportive', 'understanding', 'caring',
                    'listened', 'validated', 'believed me', 'took seriously',
                    'good therapist', 'amazing therapist', 'saved me'
                ],

                # Support system
                'support': [
                    'friend', 'friends', 'family', 'mom', 'dad', 'parent',
                    'partner', 'spouse', 'girlfriend', 'boyfriend',
                    'support group', 'community', 'reddit', 'this sub',
                    'someone', 'people', 'everyone here', 'you all',
                    'helped me', 'there for me', 'supported', 'encouraged',
                    'understood', 'accepted', 'didnt judge', 'loved'
                ],

                # Self-compassion
                'self': [
                    'proud of myself', 'did my best', 'trying', 'effort',
                    'kind to myself', 'self care', 'boundary', 'advocate',
                    'deserve', 'worth', 'valid', 'enough'
                ]
            },

            'Blameworthy': {
                # Self-blame
                'self_blame': [
                    'my fault', 'im stupid', 'im dumb', 'im weak',
                    'im pathetic', 'im useless', 'im worthless',
                    'im a failure', 'im a burden', 'im broken',
                    'i should', 'i shouldnt', 'why cant i', 'everyone else can',
                    'wrong with me', 'defective', 'damaged'
                ],

                # Blame others
                'blame_others': [
                    'their fault', 'they caused', 'they dont care',
                    'dont understand', 'dismissed', 'ignored', 'minimized',
                    'invalidated', 'judged', 'blamed', 'shamed', 'stigma',
                    'discriminated', 'mistreated', 'abused', 'toxic',
                    'gaslighted', 'manipulated', 'traumatized me'
                ],

                # System failures
                'system': [
                    'insurance denied', 'cant afford', 'too expensive',
                    'no insurance', 'waitlist', 'months wait', 'no availability',
                    'no help', 'nowhere to go', 'system failed',
                    'medication shortage', 'cant get meds', 'denied treatment'
                ]
            },

            'Confirmed': {
                # Diagnosis/validation
                'diagnosis': [
                    'diagnosed', 'diagnosis', 'confirmed', 'doctor said',
                    'test showed', 'evaluation', 'assessment results',
                    'officially', 'on paper', 'medical record'
                ],

                # Fear realization
                'realization': [
                    'as expected', 'knew it', 'i was right', 'getting worse',
                    'deteriorating', 'declining', 'relapse', 'relapsed',
                    'back to square one', 'happening again', 'pattern',
                    'predictable', 'inevitable'
                ]
            },

            'Disconfirmed': {
                # False alarms
                'false_alarm': [
                    'wasnt as bad', 'better than expected', 'overreacted',
                    'false alarm', 'unnecessary worry', 'made it through',
                    'survived', 'didnt happen', 'didnt die', 'safe now'
                ],

                # Unexpected positive
                'unexpected': [
                    'surprisingly good', 'unexpected relief', 'turned out ok',
                    'not as scary', 'manageable', 'handled it', 'got through',
                    'no relapse', 'stable', 'maintaining', 'still here'
                ]
            }
        }

        return dimensions

    def _build_intensity_modifiers(self):
        """Build intensity modifiers"""
        return {
            'amplifiers': {
                'extreme': ['extremely', 'severely', 'unbearably', 'impossibly', 'absolutely'],
                'high': ['very', 'really', 'so', 'incredibly', 'totally', 'completely'],
                'medium': ['pretty', 'quite', 'fairly', 'rather', 'somewhat']
            },
            'dampeners': ['a bit', 'a little', 'slightly', 'somewhat', 'kind of', 'sort of'],
            'negators': ['not', 'no', 'never', 'neither', 'nowhere', 'nobody', 'nothing', 'without']
        }

    def get_all_words(self, dimension):
        """Get all words for a dimension"""
        if dimension not in self.emotion_dimensions:
            return set()

        all_words = set()
        for category, words in self.emotion_dimensions[dimension].items():
            all_words.update(words)
        return all_words

    def get_intensity(self, word, context, dimension):
        """Calculate intensity score (0-5)"""
        base_intensity = 1.0

        # Crisis keywords get maximum intensity
        if dimension == 'Undesirable':
            crisis_words = self.emotion_dimensions['Undesirable']['crisis']
            if any(crisis in context.lower() for crisis in crisis_words):
                return 5.0

        # Check for amplifiers
        context_lower = context.lower()
        for level, amplifiers in self.intensity_modifiers['amplifiers'].items():
            if any(amp in context_lower for amp in amplifiers):
                if level == 'extreme':
                    base_intensity *= 2.0
                elif level == 'high':
                    base_intensity *= 1.5
                elif level == 'medium':
                    base_intensity *= 1.2

        # Check for dampeners
        if any(damp in context_lower for damp in self.intensity_modifiers['dampeners']):
            base_intensity *= 0.5

        # Check for negation (simple)
        words = context_lower.split()
        try:
            word_idx = words.index(word.lower())
            preceding_words = words[max(0, word_idx-3):word_idx]
            if any(neg in preceding_words for neg in self.intensity_modifiers['negators']):
                return 0.0  # Negated
        except ValueError:
            pass

        return min(base_intensity, 5.0)

    def get_statistics(self):
        """Print EDD statistics"""
        print("\n" + "="*80)
        print("EMOTION DIMENSION DICTIONARY STATISTICS")
        print("="*80)

        for dimension in self.emotion_dimensions:
            all_words = self.get_all_words(dimension)
            print(f"\n{dimension}: {len(all_words)} total words")
            for category, words in self.emotion_dimensions[dimension].items():
                print(f"  - {category}: {len(words)} words")

# Create EDD
print("="*80)
print("CREATING MENTAL HEALTH EMOTION DIMENSION DICTIONARY")
print("="*80)

EDD = MentalHealthEDD()
EDD.get_statistics()

print("\n✅ Step 4 complete!")
print("\nEDD created with 6 dimensions:")
print("  1. Desirable (positive mental health states)")
print("  2. Undesirable (negative symptoms/distress)")
print("  3. Praiseworthy (appreciation/support)")
print("  4. Blameworthy (blame/criticism)")
print("  5. Confirmed (validation of fears)")
print("  6. Disconfirmed (relief from avoided outcomes)")

CREATING MENTAL HEALTH EMOTION DIMENSION DICTIONARY

EMOTION DIMENSION DICTIONARY STATISTICS

Desirable: 98 total words
  - general: 37 words
  - treatment: 12 words
  - adhd_positive: 11 words
  - depression_positive: 11 words
  - ocd_positive: 8 words
  - ptsd_positive: 9 words
  - asd_positive: 10 words

Undesirable: 166 total words
  - general: 29 words
  - crisis: 21 words
  - adhd_symptoms: 19 words
  - depression_symptoms: 32 words
  - ocd_symptoms: 23 words
  - ptsd_symptoms: 22 words
  - asd_symptoms: 22 words

Praiseworthy: 60 total words
  - providers: 22 words
  - support: 26 words
  - self: 12 words

Blameworthy: 49 total words
  - self_blame: 17 words
  - blame_others: 19 words
  - system: 13 words

Confirmed: 23 total words
  - diagnosis: 10 words
  - realization: 13 words

Disconfirmed: 21 total words
  - false_alarm: 10 words
  - unexpected: 11 words

✅ Step 4 complete!

EDD created with 6 dimensions:
  1. Desirable (positive mental health states)
  2. Undesirable (neg

In [5]:
# ============================================================================
# STEP 5: EMOTION-COGNITIVE REASONING (ECR) SYSTEM
# ============================================================================

class MentalHealthECR:
    """
    Emotion-Cognitive Reasoning for Mental Health
    Implements 10 OCC-based rules
    """

    def __init__(self, EDD):
        self.EDD = EDD
        self.emotion_rules = self._build_emotion_rules()

    def _build_emotion_rules(self):
        """Build 10 emotion-cognitive rules"""
        return {
            # === SINGLE RULES ===
            'rule_1': {
                'type': 'single',
                'condition': ['Desirable'],
                'emotion': 'hope',
                'polarity': 'positive',
                'description': 'Positive coping or recovery signs'
            },
            'rule_2': {
                'type': 'single',
                'condition': ['Undesirable'],
                'emotion': 'distress',
                'polarity': 'negative',
                'description': 'Symptoms or suffering'
            },
            'rule_3': {
                'type': 'single',
                'condition': ['Praiseworthy'],
                'emotion': 'gratitude',
                'polarity': 'positive',
                'description': 'Appreciation for support/treatment'
            },
            'rule_4': {
                'type': 'single',
                'condition': ['Blameworthy'],
                'emotion': 'reproach',
                'polarity': 'negative',
                'description': 'Blame (self or others)'
            },

            # === COMPOUND RULES ===
            'rule_5': {
                'type': 'compound',
                'condition': ['Desirable', 'Praiseworthy'],
                'emotion': 'gratitude',
                'polarity': 'positive',
                'description': 'Thankful for improvement/help'
            },
            'rule_6': {
                'type': 'compound',
                'condition': ['Undesirable', 'Blameworthy'],
                'emotion': 'anger',
                'polarity': 'negative',
                'description': 'Anger about suffering/cause'
            },
            'rule_7': {
                'type': 'compound',
                'condition': ['Desirable', 'Confirmed'],
                'emotion': 'relief',
                'polarity': 'positive',
                'description': 'Relief from positive confirmation'
            },
            'rule_8': {
                'type': 'compound',
                'condition': ['Undesirable', 'Confirmed'],
                'emotion': 'fear',
                'polarity': 'negative',
                'description': 'Fears realized (diagnosis, relapse)'
            },
            'rule_9': {
                'type': 'compound',
                'condition': ['Desirable', 'Disconfirmed'],
                'emotion': 'relief',
                'polarity': 'positive',
                'description': 'Relief from avoided negative outcome'
            },
            'rule_10': {
                'type': 'compound',
                'condition': ['Undesirable', 'Disconfirmed'],
                'emotion': 'disappointment',
                'polarity': 'negative',
                'description': 'Disappointment from failed hope'
            }
        }

    def identify_emotion_words(self, text):
        """Step 1: Identify emotion words from text"""
        text_lower = text.lower()
        emotion_words = {
            'Desirable': [],
            'Undesirable': [],
            'Praiseworthy': [],
            'Blameworthy': [],
            'Confirmed': [],
            'Disconfirmed': []
        }

        for dimension in emotion_words.keys():
            all_words = self.EDD.get_all_words(dimension)

            for word in all_words:
                pattern = r'\b' + re.escape(word.lower()) + r'\b'
                matches = re.finditer(pattern, text_lower)

                for match in matches:
                    position = match.start()
                    context_start = max(0, position - 50)
                    context_end = min(len(text), position + 50)
                    context = text[context_start:context_end]

                    intensity = self.EDD.get_intensity(word, context, dimension)

                    if intensity > 0:
                        emotion_words[dimension].append({
                            'word': word,
                            'position': position,
                            'intensity': intensity,
                            'context': context
                        })

        return emotion_words

    def apply_rules(self, emotion_words):
        """Step 2: Apply 10 emotion-cognitive rules"""
        inferred_emotions = []

        # Compound rules first (higher priority)
        for rule_name in sorted(self.emotion_rules.keys()):
            rule = self.emotion_rules[rule_name]

            if rule['type'] == 'compound':
                conditions = rule['condition']

                if all(len(emotion_words[cond]) > 0 for cond in conditions):
                    words_dim1 = emotion_words[conditions[0]]
                    words_dim2 = emotion_words[conditions[1]]

                    for w1 in words_dim1[:3]:
                        for w2 in words_dim2[:3]:
                            inferred_emotions.append({
                                'rule': rule_name,
                                'emotion': rule['emotion'],
                                'polarity': rule['polarity'],
                                'type': 'compound',
                                'words': [w1['word'], w2['word']],
                                'intensity': (w1['intensity'] + w2['intensity']) / 2,
                                'evidence': f"{w1['word']} + {w2['word']}",
                                'description': rule['description']
                            })

        # Single rules
        for rule_name in sorted(self.emotion_rules.keys()):
            rule = self.emotion_rules[rule_name]

            if rule['type'] == 'single':
                condition = rule['condition'][0]

                if len(emotion_words[condition]) > 0:
                    for word_info in emotion_words[condition][:5]:
                        inferred_emotions.append({
                            'rule': rule_name,
                            'emotion': rule['emotion'],
                            'polarity': rule['polarity'],
                            'type': 'single',
                            'words': [word_info['word']],
                            'intensity': word_info['intensity'],
                            'evidence': word_info['word'],
                            'description': rule['description']
                        })

        return inferred_emotions

    def calculate_emotion_score(self, text, emotion_words):
        """Step 3: Calculate ES and CS_ECR"""
        S_P = 0.0
        S_N = 0.0

        # Positive intensity
        for word_info in emotion_words['Desirable']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_P += intensity * frequency

        for word_info in emotion_words['Praiseworthy']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_P += intensity * frequency

        for word_info in emotion_words['Disconfirmed']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_P += intensity * frequency * 0.5

        # Negative intensity
        for word_info in emotion_words['Undesirable']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_N += intensity * frequency

        for word_info in emotion_words['Blameworthy']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_N += intensity * frequency

        for word_info in emotion_words['Confirmed']:
            word = word_info['word']
            intensity = word_info['intensity']
            frequency = text.lower().count(word.lower())
            S_N += intensity * frequency * 0.5

        # Calculate ES and CS_ECR
        if S_P + S_N == 0:
            ES = 0.0
            CS_ECR = 0.0
        else:
            ES = (S_P - S_N) / (S_P + S_N)
            CS_ECR = abs(ES)

        return ES, CS_ECR, S_P, S_N

    def extract_knowledge(self, text):
        """Main knowledge extraction (Algorithm 1)"""
        emotion_words = self.identify_emotion_words(text)

        has_emotion_words = any(len(words) > 0 for words in emotion_words.values())

        if not has_emotion_words:
            return {
                'PECK': [],
                'NECK': [],
                'ES': 0.0,
                'CS_ECR': 0.0,
                'S_P': 0.0,
                'S_N': 0.0,
                'has_emotions': False,
                'emotion_words': emotion_words
            }

        ES, CS_ECR, S_P, S_N = self.calculate_emotion_score(text, emotion_words)
        all_inferred_emotions = self.apply_rules(emotion_words)

        PECK = [e for e in all_inferred_emotions if e['polarity'] == 'positive']
        NECK = [e for e in all_inferred_emotions if e['polarity'] == 'negative']

        if ES > 0:
            primary_knowledge = 'PECK'
        elif ES < 0:
            primary_knowledge = 'NECK'
        else:
            primary_knowledge = 'BOTH'

        return {
            'PECK': PECK,
            'NECK': NECK,
            'ES': ES,
            'CS_ECR': CS_ECR,
            'S_P': S_P,
            'S_N': S_N,
            'has_emotions': True,
            'primary_knowledge': primary_knowledge,
            'emotion_words': emotion_words,
            'total_emotions': len(all_inferred_emotions)
        }

# Initialize ECR
print("="*80)
print("INITIALIZING EMOTION-COGNITIVE REASONING SYSTEM")
print("="*80)

ecr = MentalHealthECR(EDD)
print("✅ ECR initialized with 10 emotion-cognitive rules")

# Test on samples
print("\n" + "="*80)
print("TESTING ECR ON SAMPLE POSTS")
print("="*80)

test_samples = combined_df.sample(5)

for idx, row in test_samples.iterrows():
    text = row['text_clean']
    condition = row['condition']

    print(f"\n{'-'*80}")
    print(f"Condition: {condition}")
    print(f"Text: {text[:200]}...")

    result = ecr.extract_knowledge(text)

    print(f"\nECR Results:")
    print(f"  Has emotions: {result['has_emotions']}")
    print(f"  ES: {result['ES']:.3f} (Emotion Score: -1=negative, 0=neutral, 1=positive)")
    print(f"  CS_ECR: {result['CS_ECR']:.3f} (Confidence: 0-1)")
    print(f"  PECK count: {len(result['PECK'])}")
    print(f"  NECK count: {len(result['NECK'])}")

    if result['PECK']:
        print(f"  Positive emotions: {[e['emotion'] for e in result['PECK'][:3]]}")
    if result['NECK']:
        print(f"  Negative emotions: {[e['emotion'] for e in result['NECK'][:3]]}")

print("\n" + "="*80)
print("✅ Step 5 complete!")
print("\nECR system ready to process full dataset")

INITIALIZING EMOTION-COGNITIVE REASONING SYSTEM
✅ ECR initialized with 10 emotion-cognitive rules

TESTING ECR ON SAMPLE POSTS

--------------------------------------------------------------------------------
Condition: Aspergers
Text: Cleaning and frustration fits. Today I asked my wife of two years of she can really live with an ASD partner. It started with having to clean the whole house by myself. I dont aggressively clean like ...

ECR Results:
  Has emotions: True
  ES: -0.105 (Emotion Score: -1=negative, 0=neutral, 1=positive)
  CS_ECR: 0.105 (Confidence: 0-1)
  PECK count: 11
  NECK count: 5
  Positive emotions: ['gratitude', 'gratitude', 'gratitude']
  Negative emotions: ['distress', 'distress', 'distress']

--------------------------------------------------------------------------------
Condition: Aspergers
Text: DAE completely miss read the meaning behind things people who love you say?. Im in a relationship and Im struggling really bad with misunderstanding whats being said

In [6]:
# ============================================================================
# STEP 6: APPLY ECR TO FULL DATASET
# ============================================================================

import time
from tqdm.auto import tqdm

print("="*80)
print("APPLYING ECR TO FULL DATASET")
print("="*80)
print(f"Total posts to process: {len(combined_df)}")
print("This will take approximately 5-10 minutes...")

# Process in batches for progress tracking
batch_size = 1000
results = []

start_time = time.time()

for i in tqdm(range(0, len(combined_df), batch_size), desc="Processing batches"):
    batch = combined_df.iloc[i:i+batch_size]

    for idx, row in batch.iterrows():
        text = row['text_clean']

        # Extract emotion-cognitive knowledge
        ecr_result = ecr.extract_knowledge(text)

        results.append({
            'text_clean': text,
            'ES': ecr_result['ES'],
            'CS_ECR': ecr_result['CS_ECR'],
            'S_P': ecr_result['S_P'],
            'S_N': ecr_result['S_N'],
            'has_emotions': ecr_result['has_emotions'],
            'peck_count': len(ecr_result['PECK']),
            'neck_count': len(ecr_result['NECK']),
            'PECK': ecr_result['PECK'],
            'NECK': ecr_result['NECK'],
            'condition': row['condition'],
            'score': row['score'],
            'num_comments': row['num_comments']
        })

# Create DataFrame with ECR results
ecr_df = pd.DataFrame(results)

elapsed_time = time.time() - start_time
print(f"\n✅ Processing complete in {elapsed_time/60:.1f} minutes")

# Statistics
print("\n" + "="*80)
print("ECR PROCESSING STATISTICS")
print("="*80)

print(f"\nTotal posts: {len(ecr_df)}")
print(f"Posts with emotions detected: {ecr_df['has_emotions'].sum()} ({ecr_df['has_emotions'].sum()/len(ecr_df)*100:.1f}%)")
print(f"Posts without emotions: {(~ecr_df['has_emotions']).sum()} ({(~ecr_df['has_emotions']).sum()/len(ecr_df)*100:.1f}%)")

print(f"\nEmotion Score (ES) statistics:")
print(ecr_df['ES'].describe())

print(f"\nConfidence Score (CS_ECR) statistics:")
print(ecr_df['CS_ECR'].describe())

# Categorize sentiment tendency
ecr_df['sentiment_tendency'] = ecr_df['ES'].apply(
    lambda x: 'positive' if x > 0.2 else ('negative' if x < -0.2 else 'neutral')
)

print(f"\nSentiment Tendency Distribution:")
print(ecr_df['sentiment_tendency'].value_counts())
print(f"\nPercentages:")
print(ecr_df['sentiment_tendency'].value_counts(normalize=True) * 100)

# By condition
print(f"\n{'='*80}")
print("SENTIMENT TENDENCY BY CONDITION")
print("="*80)
for condition in ecr_df['condition'].unique():
    print(f"\n{condition}:")
    condition_data = ecr_df[ecr_df['condition'] == condition]
    print(condition_data['sentiment_tendency'].value_counts())

print(f"\n{'='*80}")
print("✅ Step 6 complete!")
print(f"\nDataset with ECR results: {ecr_df.shape}")

APPLYING ECR TO FULL DATASET
Total posts to process: 141254
This will take approximately 5-10 minutes...


Processing batches:   0%|          | 0/142 [00:00<?, ?it/s]


✅ Processing complete in 10.5 minutes

ECR PROCESSING STATISTICS

Total posts: 141254
Posts with emotions detected: 99193 (70.2%)
Posts without emotions: 42061 (29.8%)

Emotion Score (ES) statistics:
count    141254.000000
mean         -0.034599
std           0.636038
min          -1.000000
25%          -0.500000
50%           0.000000
75%           0.333333
max           1.000000
Name: ES, dtype: float64

Confidence Score (CS_ECR) statistics:
count    141254.000000
mean          0.473051
std           0.426571
min           0.000000
25%           0.000000
50%           0.422222
75%           1.000000
max           1.000000
Name: CS_ECR, dtype: float64

Sentiment Tendency Distribution:
sentiment_tendency
neutral     57585
negative    44246
positive    39423
Name: count, dtype: int64

Percentages:
sentiment_tendency
neutral     40.766987
negative    31.323715
positive    27.909298
Name: proportion, dtype: float64

SENTIMENT TENDENCY BY CONDITION

ADHD:
sentiment_tendency
neutral     15

In [7]:
# ============================================================================
# STEP 7: RULE-BASED PSEUDO-LABELING
# ============================================================================

class MentalHealthSentimentLabeler:
    """
    Create pseudo-labels using ECR results + keyword heuristics
    Labels: positive (0), negative (1), neutral (2)
    """

    def __init__(self):
        self.crisis_keywords = [
            'suicidal', 'kill myself', 'want to die', 'end it all',
            'better off dead', 'no reason to live', 'cant go on',
            'self harm', 'cutting myself', 'overdose'
        ]

        self.strong_positive = [
            'feeling better', 'much better', 'improving', 'recovery',
            'breakthrough', 'proud of myself', 'made progress',
            'therapy helping', 'medication working'
        ]

        self.strong_negative = [
            'getting worse', 'cant cope', 'breaking down',
            'falling apart', 'unbearable', 'hopeless',
            'relapse', 'crisis'
        ]

    def label_post(self, text, ES, CS_ECR, S_P, S_N):
        """
        Create pseudo-label using multiple signals
        Returns: (label, confidence, reasoning)
        """
        text_lower = text.lower()

        # SIGNAL 1: Crisis detection (highest priority)
        for keyword in self.crisis_keywords:
            if keyword in text_lower:
                return 'negative', 1.0, f"Crisis: {keyword}"

        # SIGNAL 2: Strong keyword matching
        pos_score = sum(1 for kw in self.strong_positive if kw in text_lower)
        neg_score = sum(1 for kw in self.strong_negative if kw in text_lower)

        # SIGNAL 3: ECR-based scoring
        if ES > 0.3:
            # Strong positive tendency
            label = 'positive'
            confidence = min(0.7 + (ES * 0.3), 0.95)
            reasoning = f"ES={ES:.2f} (positive)"
        elif ES < -0.3:
            # Strong negative tendency
            label = 'negative'
            confidence = min(0.7 + (abs(ES) * 0.3), 0.95)
            reasoning = f"ES={ES:.2f} (negative)"
        elif ES > 0.1:
            # Weak positive
            label = 'positive'
            confidence = 0.5 + pos_score * 0.1
            reasoning = f"ES={ES:.2f} (weak positive)"
        elif ES < -0.1:
            # Weak negative
            label = 'negative'
            confidence = 0.5 + neg_score * 0.1
            reasoning = f"ES={ES:.2f} (weak negative)"
        else:
            # Neutral
            label = 'neutral'
            confidence = 0.4
            reasoning = f"ES={ES:.2f} (neutral)"

        # SIGNAL 4: Questions often neutral
        if text.count('?') >= 2 and len(text) < 300:
            if label != 'negative' or confidence < 0.7:
                label = 'neutral'
                confidence = max(confidence * 0.7, 0.3)
                reasoning += " + questions"

        return label, confidence, reasoning

# Initialize labeler
print("="*80)
print("STEP 7: RULE-BASED PSEUDO-LABELING")
print("="*80)

labeler = MentalHealthSentimentLabeler()

# Apply labeling
print("\nGenerating pseudo-labels...")
labels = []
confidences = []
reasonings = []

for idx, row in tqdm(ecr_df.iterrows(), total=len(ecr_df), desc="Labeling"):
    label, confidence, reasoning = labeler.label_post(
        row['text_clean'],
        row['ES'],
        row['CS_ECR'],
        row['S_P'],
        row['S_N']
    )

    labels.append(label)
    confidences.append(confidence)
    reasonings.append(reasoning)

ecr_df['label'] = labels
ecr_df['label_confidence'] = confidences
ecr_df['label_reasoning'] = reasonings

# Convert to numeric
label_map = {'positive': 0, 'negative': 1, 'neutral': 2}
ecr_df['label_numeric'] = ecr_df['label'].map(label_map)

print("\n✅ Pseudo-labeling complete!")

# Statistics
print("\n" + "="*80)
print("LABEL STATISTICS")
print("="*80)

print("\nLabel distribution:")
print(ecr_df['label'].value_counts())
print("\nPercentages:")
print(ecr_df['label'].value_counts(normalize=True) * 100)

print("\nConfidence statistics:")
print(ecr_df['label_confidence'].describe())

# By condition
print("\n" + "="*80)
print("LABEL DISTRIBUTION BY CONDITION")
print("="*80)

for condition in ecr_df['condition'].unique():
    print(f"\n{condition}:")
    condition_df = ecr_df[ecr_df['condition'] == condition]
    print(condition_df['label'].value_counts())

## Fix the sampling issue
print("\n" + "="*80)
print("SAMPLE LABELED POSTS (FIXED)")
print("="*80)

for label in ['positive', 'negative', 'neutral']:
    print(f"\n{label.upper()} Examples:")
    print("-" * 80)

    # Try high confidence first, then any confidence
    high_conf = ecr_df[
        (ecr_df['label'] == label) &
        (ecr_df['label_confidence'] > 0.7)
    ]

    if len(high_conf) >= 2:
        samples = high_conf.sample(2)
    else:
        samples = ecr_df[ecr_df['label'] == label].sample(min(2, len(ecr_df[ecr_df['label'] == label])))

    for idx, row in samples.iterrows():
        print(f"\nText: {row['text_clean'][:150]}...")
        print(f"Confidence: {row['label_confidence']:.2f}")
        print(f"Reasoning: {row['label_reasoning']}")
        print(f"ES: {row['ES']:.3f}, PECK: {row['peck_count']}, NECK: {row['neck_count']}")

print("\n" + "="*80)
print("✅ Step 7 complete!")
print(f"\nLabeled dataset: {ecr_df.shape}")

STEP 7: RULE-BASED PSEUDO-LABELING

Generating pseudo-labels...


Labeling:   0%|          | 0/141254 [00:00<?, ?it/s]


✅ Pseudo-labeling complete!

LABEL STATISTICS

Label distribution:
label
neutral     51586
negative    48367
positive    41301
Name: count, dtype: int64

Percentages:
label
neutral     36.520028
negative    34.241154
positive    29.238818
Name: proportion, dtype: float64

Confidence statistics:
count    141254.000000
mean          0.695441
std           0.254151
min           0.300000
25%           0.400000
50%           0.823529
75%           0.950000
max           1.000000
Name: label_confidence, dtype: float64

LABEL DISTRIBUTION BY CONDITION

ADHD:
label
neutral     13631
positive    10953
negative    10741
Name: count, dtype: int64

Aspergers:
label
neutral     8485
positive    8336
negative    5237
Name: count, dtype: int64

Depression:
label
neutral     8485
positive    8336
negative    5237
Name: count, dtype: int64

OCD:
label
negative    16983
neutral     14077
positive     8583
Name: count, dtype: int64

PTSD:
label
negative    10169
neutral      6908
positive     5093
Name

In [8]:
# ============================================================================
# STEP 8: QUALITY CONTROL & DATASET BALANCING
# ============================================================================

print("="*80)
print("STEP 8: QUALITY CONTROL & DATASET BALANCING")
print("="*80)

# Current state
print("\nCurrent dataset:")
print(f"Total posts: {len(ecr_df)}")
print(f"\nLabel distribution:")
print(ecr_df['label'].value_counts())

# Step 8.1: Filter by confidence threshold
print("\n" + "-"*80)
print("8.1: FILTERING BY CONFIDENCE")
print("-"*80)

confidence_threshold = 0.5
print(f"Confidence threshold: {confidence_threshold}")

ecr_df_filtered = ecr_df[ecr_df['label_confidence'] >= confidence_threshold].copy()

print(f"\nBefore filtering: {len(ecr_df)} posts")
print(f"After filtering:  {len(ecr_df_filtered)} posts")
print(f"Removed: {len(ecr_df) - len(ecr_df_filtered)} low-confidence posts ({(len(ecr_df) - len(ecr_df_filtered))/len(ecr_df)*100:.1f}%)")

print(f"\nLabel distribution after filtering:")
print(ecr_df_filtered['label'].value_counts())

# Step 8.2: Balance dataset
print("\n" + "-"*80)
print("8.2: BALANCING DATASET")
print("-"*80)

label_counts = ecr_df_filtered['label'].value_counts()
print(f"\nBefore balancing:")
print(label_counts)

# Use undersampling to balance
min_count = label_counts.min()
print(f"\nTarget count per class (min): {min_count}")

balanced_dfs = []
for label in ['positive', 'negative', 'neutral']:
    label_df = ecr_df_filtered[ecr_df_filtered['label'] == label]

    if len(label_df) > min_count:
        # Undersample
        label_df_sampled = label_df.sample(n=min_count, random_state=42)
    else:
        label_df_sampled = label_df

    balanced_dfs.append(label_df_sampled)

# Combine and shuffle
ecr_df_balanced = pd.concat(balanced_dfs, ignore_index=True)
ecr_df_balanced = ecr_df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"\nAfter balancing:")
print(ecr_df_balanced['label'].value_counts())
print(f"\nTotal balanced dataset: {len(ecr_df_balanced)} posts")

# Distribution by condition
print("\n" + "-"*80)
print("BALANCED DATASET BY CONDITION")
print("-"*80)

for condition in ecr_df_balanced['condition'].unique():
    condition_df = ecr_df_balanced[ecr_df_balanced['condition'] == condition]
    print(f"\n{condition}: {len(condition_df)} posts")
    print(condition_df['label'].value_counts())

# Step 8.3: Final statistics
print("\n" + "="*80)
print("FINAL DATASET STATISTICS")
print("="*80)

print(f"\nTotal posts: {len(ecr_df_balanced)}")
print(f"\nLabel distribution:")
print(ecr_df_balanced['label'].value_counts())
print(f"\nPercentages:")
print(ecr_df_balanced['label'].value_counts(normalize=True) * 100)

print(f"\nCondition distribution:")
print(ecr_df_balanced['condition'].value_counts())

print(f"\nConfidence statistics:")
print(ecr_df_balanced['label_confidence'].describe())

print(f"\nES (Emotion Score) statistics:")
print(ecr_df_balanced['ES'].describe())

print(f"\nText length statistics:")
ecr_df_balanced['text_length'] = ecr_df_balanced['text_clean'].str.len()
print(ecr_df_balanced['text_length'].describe())

# Step 8.4: Save processed dataset
print("\n" + "-"*80)
print("SAVING PROCESSED DATA")
print("-"*80)

# Save balanced dataset
ecr_df_balanced.to_csv('mental_health_balanced_labeled.csv', index=False)
print("✅ Saved: mental_health_balanced_labeled.csv")

# Also save the full filtered dataset (unbalanced but high confidence)
ecr_df_filtered.to_csv('mental_health_filtered_labeled.csv', index=False)
print("✅ Saved: mental_health_filtered_labeled.csv")

print("\n" + "="*80)
print("✅ Step 8 complete!")
print("="*80)

print(f"\nFinal balanced dataset ready for training:")
print(f"  - Total posts: {len(ecr_df_balanced)}")
print(f"  - Labels: 3 classes (positive, negative, neutral)")
print(f"  - Balanced: ~{min_count} posts per class")
print(f"  - Conditions: {ecr_df_balanced['condition'].nunique()} mental health conditions")

# Show summary
print("\n" + "="*80)
print("DATASET SUMMARY")
print("="*80)
summary_df = ecr_df_balanced.groupby(['condition', 'label']).size().unstack(fill_value=0)
print(summary_df)

STEP 8: QUALITY CONTROL & DATASET BALANCING

Current dataset:
Total posts: 141254

Label distribution:
label
neutral     51586
negative    48367
positive    41301
Name: count, dtype: int64

--------------------------------------------------------------------------------
8.1: FILTERING BY CONFIDENCE
--------------------------------------------------------------------------------
Confidence threshold: 0.5

Before filtering: 141254 posts
After filtering:  90946 posts
Removed: 50308 low-confidence posts (35.6%)

Label distribution after filtering:
label
negative    48367
positive    41301
neutral      1278
Name: count, dtype: int64

--------------------------------------------------------------------------------
8.2: BALANCING DATASET
--------------------------------------------------------------------------------

Before balancing:
label
negative    48367
positive    41301
neutral      1278
Name: count, dtype: int64

Target count per class (min): 1278

After balancing:
label
neutral     1

In [10]:
# ============================================================================
# STEP 8 REVISED: BETTER BALANCING STRATEGY
# ============================================================================

print("="*80)
print("STEP 8 REVISED: IMPROVED BALANCING")
print("="*80)

# Use lower confidence threshold to get more data
confidence_threshold = 0.4  # Lower threshold
print(f"Using confidence threshold: {confidence_threshold}")

ecr_df_filtered_v2 = ecr_df[ecr_df['label_confidence'] >= confidence_threshold].copy()

print(f"\nFiltered dataset: {len(ecr_df_filtered_v2)} posts")
print("\nLabel distribution:")
print(ecr_df_filtered_v2['label'].value_counts())

# Strategy: Balance by limiting majority classes, keep all minority
label_counts = ecr_df_filtered_v2['label'].value_counts()
print(f"\nOriginal distribution:")
for label, count in label_counts.items():
    print(f"  {label}: {count}")

# Set target: Take middle value or cap at reasonable size
target_per_class = min(label_counts.max(), 20000)  # Cap at 20k per class
target_per_class = max(target_per_class, label_counts.min() * 2)  # At least 2x minority

print(f"\nTarget per class: {target_per_class}")

# Balance by undersampling majority, keeping all minority
balanced_dfs_v2 = []
for label in ['positive', 'negative', 'neutral']:
    label_df = ecr_df_filtered_v2[ecr_df_filtered_v2['label'] == label]

    if len(label_df) > target_per_class:
        # Undersample with stratification by condition
        label_df_sampled = label_df.groupby('condition', group_keys=False).apply(
            lambda x: x.sample(n=min(len(x), target_per_class // 5), random_state=42)
        )
    else:
        label_df_sampled = label_df

    balanced_dfs_v2.append(label_df_sampled)
    print(f"{label}: {len(label_df)} → {len(label_df_sampled)}")

# Combine
ecr_df_balanced_v2 = pd.concat(balanced_dfs_v2, ignore_index=True)
ecr_df_balanced_v2 = ecr_df_balanced_v2.sample(frac=1, random_state=42).reset_index(drop=True)

print(f"\n{'='*80}")
print("FINAL BALANCED DATASET V2")
print("="*80)

print(f"\nTotal posts: {len(ecr_df_balanced_v2)}")
print("\nLabel distribution:")
print(ecr_df_balanced_v2['label'].value_counts())
print("\nPercentages:")
print(ecr_df_balanced_v2['label'].value_counts(normalize=True) * 100)

print("\nCondition distribution:")
print(ecr_df_balanced_v2['condition'].value_counts())

# Summary table
print("\n" + "="*80)
print("DATASET SUMMARY BY CONDITION")
print("="*80)
summary_df_v2 = ecr_df_balanced_v2.groupby(['condition', 'label']).size().unstack(fill_value=0)
print(summary_df_v2)
print(f"\nTotal: {summary_df_v2.sum().sum()}")

# Save
ecr_df_balanced_v2.to_csv('mental_health_balanced_v2.csv', index=False)
print("\n✅ Saved: mental_health_balanced_v2.csv")

# Use this as final dataset
ecr_df_final = ecr_df_balanced_v2.copy()

print("\n" + "="*80)
print("✅ Step 8 REVISED complete!")
print(f"Final dataset: {len(ecr_df_final)} posts")
print("="*80)

STEP 8 REVISED: IMPROVED BALANCING
Using confidence threshold: 0.4

Filtered dataset: 138419 posts

Label distribution:
label
neutral     48751
negative    48367
positive    41301
Name: count, dtype: int64

Original distribution:
  neutral: 48751
  negative: 48367
  positive: 41301

Target per class: 82602
positive: 41301 → 41301
negative: 48367 → 48367
neutral: 48751 → 48751

FINAL BALANCED DATASET V2

Total posts: 138419

Label distribution:
label
neutral     48751
negative    48367
positive    41301
Name: count, dtype: int64

Percentages:
label
neutral     35.219876
negative    34.942457
positive    29.837667
Name: proportion, dtype: float64

Condition distribution:
condition
OCD           38576
ADHD          34933
PTSD          21844
Aspergers     21533
Depression    21533
Name: count, dtype: int64

DATASET SUMMARY BY CONDITION
label       negative  neutral  positive
condition                              
ADHD           10741    13239     10953
Aspergers       5237     7960      8

In [11]:
# ============================================================================
# STEP 9: GENERATE SENTENCE-EMOTION TREES (SETs)
# ============================================================================

class SentenceEmotionTreeGenerator:
    """
    Generate Sentence-Emotion Trees (SETs)
    Integrates emotion-cognitive knowledge into text
    """

    def __init__(self, ecr):
        self.ecr = ecr

    def create_emotion_annotation(self, emotions, max_emotions=3):
        """Create emotion annotations for tree structure"""
        if not emotions:
            return ""

        # Sort by intensity and limit
        sorted_emotions = sorted(
            emotions,
            key=lambda x: x['intensity'],
            reverse=True
        )[:max_emotions]

        annotations = []
        for emo in sorted_emotions:
            if emo['type'] == 'compound':
                # Compound: word1+word2{emotion}
                annotation = f"{emo['words'][0]}+{emo['words'][1]}→{emo['emotion']}"
            else:
                # Single: word{emotion}
                annotation = f"{emo['words'][0]}→{emo['emotion']}"

            annotations.append(annotation)

        return " | ".join(annotations)

    def generate_set(self, text, ES, CS_ECR, PECK, NECK, threshold=0.3):
        """
        Generate Sentence-Emotion Tree (SET)

        Decision: Which emotions to incorporate based on ES
        """

        # Decide which emotions to use
        if ES > 0.2:
            # Strong positive → use PECK
            selected_emotions = PECK
            tendency = 'positive'
        elif ES < -0.2:
            # Strong negative → use NECK
            selected_emotions = NECK
            tendency = 'negative'
        elif ES > 0:
            # Weak positive → use PECK
            selected_emotions = PECK
            tendency = 'weak_positive'
        elif ES < 0:
            # Weak negative → use NECK
            selected_emotions = NECK
            tendency = 'weak_negative'
        else:
            # Neutral
            selected_emotions = []
            tendency = 'neutral'

        # Create SET
        if selected_emotions and CS_ECR >= threshold:
            emotion_annotation = self.create_emotion_annotation(selected_emotions)
            # Append emotions at end
            set_text = f"{text} [EMOTIONS: {emotion_annotation}]"
            used_ecr = True
        else:
            set_text = text
            used_ecr = False

        return {
            'SET': set_text,
            'tendency': tendency,
            'selected_emotions': selected_emotions,
            'emotion_count': len(selected_emotions),
            'used_ecr': used_ecr
        }

# Initialize SET generator
print("="*80)
print("STEP 9: GENERATING SENTENCE-EMOTION TREES")
print("="*80)

set_generator = SentenceEmotionTreeGenerator(ecr)

# Generate SETs for the dataset
print(f"\nGenerating SETs for {len(ecr_df_final)} posts...")
print("This will take 2-3 minutes...")

sets_data = []

for idx, row in tqdm(ecr_df_final.iterrows(), total=len(ecr_df_final), desc="Generating SETs"):
    set_result = set_generator.generate_set(
        text=row['text_clean'],
        ES=row['ES'],
        CS_ECR=row['CS_ECR'],
        PECK=row['PECK'],
        NECK=row['NECK'],
        threshold=0.3
    )

    sets_data.append({
        'original_text': row['text_clean'],
        'SET': set_result['SET'],
        'ES': row['ES'],
        'CS_ECR': row['CS_ECR'],
        'tendency': set_result['tendency'],
        'emotion_count': set_result['emotion_count'],
        'used_ecr': set_result['used_ecr'],
        'label': row['label'],
        'label_numeric': row['label_numeric'],
        'label_confidence': row['label_confidence'],
        'condition': row['condition'],
        'peck_count': row['peck_count'],
        'neck_count': row['neck_count']
    })

# Create DataFrame with SETs
sets_df = pd.DataFrame(sets_data)

print("\n✅ All SETs generated!")

# Statistics
print("\n" + "="*80)
print("SET STATISTICS")
print("="*80)

print(f"\nTotal SETs: {len(sets_df)}")
print(f"SETs with emotions incorporated: {sets_df['used_ecr'].sum()} ({sets_df['used_ecr'].sum()/len(sets_df)*100:.1f}%)")
print(f"SETs without emotions: {(~sets_df['used_ecr']).sum()} ({(~sets_df['used_ecr']).sum()/len(sets_df)*100:.1f}%)")

print(f"\nTendency distribution:")
print(sets_df['tendency'].value_counts())

print(f"\nEmotion count distribution:")
print(sets_df['emotion_count'].value_counts().head(10))

# Text length comparison
sets_df['original_length'] = sets_df['original_text'].str.len()
sets_df['set_length'] = sets_df['SET'].str.len()
sets_df['length_increase'] = sets_df['set_length'] - sets_df['original_length']

print(f"\nText length comparison:")
print(f"Original avg length: {sets_df['original_length'].mean():.0f} chars")
print(f"SET avg length: {sets_df['set_length'].mean():.0f} chars")
print(f"Avg increase: {sets_df['length_increase'].mean():.0f} chars")

# Show examples
print("\n" + "="*80)
print("SAMPLE SENTENCE-EMOTION TREES")
print("="*80)

for label in ['positive', 'negative', 'neutral']:
    print(f"\n{label.upper()} Examples with Emotions:")
    print("-" * 80)

    sample = sets_df[
        (sets_df['label'] == label) &
        (sets_df['emotion_count'] > 0)
    ]

    if len(sample) > 0:
        sample_row = sample.sample(1).iloc[0]

        print(f"Original: {sample_row['original_text'][:150]}...")
        print(f"\nSET: {sample_row['SET'][:300]}...")
        print(f"\nStats:")
        print(f"  ES: {sample_row['ES']:.3f} | CS_ECR: {sample_row['CS_ECR']:.3f}")
        print(f"  Tendency: {sample_row['tendency']}")
        print(f"  Emotions: {sample_row['emotion_count']}")
        print(f"  PECK: {sample_row['peck_count']}, NECK: {sample_row['neck_count']}")
    else:
        print("No examples with emotions found")

# Save dataset with SETs
sets_df.to_csv('mental_health_with_sets.csv', index=False)
print(f"\n✅ Saved: mental_health_with_sets.csv")

print("\n" + "="*80)
print("✅ Step 9 complete!")
print("="*80)
print(f"\nDataset with SETs ready: {len(sets_df)} posts")
print(f"  - Original text preserved")
print(f"  - SETs with emotion annotations created")
print(f"  - Ready for BERT training")

STEP 9: GENERATING SENTENCE-EMOTION TREES

Generating SETs for 138419 posts...
This will take 2-3 minutes...


Generating SETs:   0%|          | 0/138419 [00:00<?, ?it/s]


✅ All SETs generated!

SET STATISTICS

Total SETs: 138419
SETs with emotions incorporated: 76175 (55.0%)
SETs without emotions: 62244 (45.0%)

Tendency distribution:
tendency
negative         44237
neutral          43729
positive         39408
weak_negative     5594
weak_positive     5451
Name: count, dtype: int64

Emotion count distribution:
emotion_count
0     46467
1     26668
2     14504
5     12971
3     10829
4      6716
8      3701
9      3452
7      2508
11     1566
Name: count, dtype: int64

Text length comparison:
Original avg length: 592 chars
SET avg length: 622 chars
Avg increase: 31 chars

SAMPLE SENTENCE-EMOTION TREES

POSITIVE Examples with Emotions:
--------------------------------------------------------------------------------
Original: Everynight I go back to when I was a little kid dealing with the abuse and trauma, sometimes even the not so bad stuff but its really confusing, becau...

SET: Everynight I go back to when I was a little kid dealing with the abuse an

In [12]:
# ============================================================================
# STEP 10: PREPARE DATA FOR BERT TRAINING
# ============================================================================

from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("STEP 10: PREPARE DATA FOR BERT TRAINING")
print("="*80)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Step 10.1: Train/Val/Test Split
print("\n" + "-"*80)
print("10.1: CREATING TRAIN/VAL/TEST SPLITS")
print("-"*80)

# 70% train, 15% val, 15% test
train_df, temp_df = train_test_split(
    sets_df,
    test_size=0.3,
    random_state=42,
    stratify=sets_df['label_numeric']
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,  # 0.5 of 30% = 15% each
    random_state=42,
    stratify=temp_df['label_numeric']
)

print(f"Train set: {len(train_df):,} posts ({len(train_df)/len(sets_df)*100:.1f}%)")
print(f"Val set:   {len(val_df):,} posts ({len(val_df)/len(sets_df)*100:.1f}%)")
print(f"Test set:  {len(test_df):,} posts ({len(test_df)/len(sets_df)*100:.1f}%)")

print("\nLabel distribution in train:")
print(train_df['label'].value_counts())

print("\nLabel distribution in val:")
print(val_df['label'].value_counts())

print("\nLabel distribution in test:")
print(test_df['label'].value_counts())

# Save splits
train_df.to_csv('train_set.csv', index=False)
val_df.to_csv('val_set.csv', index=False)
test_df.to_csv('test_set.csv', index=False)
print("\n✅ Saved train/val/test splits")

# Step 10.2: Load BERT Tokenizer
print("\n" + "-"*80)
print("10.2: LOADING BERT TOKENIZER")
print("-"*80)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print("✅ Loaded: bert-base-uncased tokenizer")

# Test tokenization
sample_text = train_df.iloc[0]['original_text']
tokens = tokenizer.tokenize(sample_text[:100])
print(f"\nSample tokenization:")
print(f"Text: {sample_text[:100]}...")
print(f"Tokens ({len(tokens)}): {tokens[:10]}...")

# Step 10.3: Create PyTorch Dataset Class
print("\n" + "-"*80)
print("10.3: CREATING PYTORCH DATASETS")
print("-"*80)

class MentalHealthDataset(Dataset):
    """
    PyTorch Dataset for Mental Health Text
    Can use either original text or SET
    """

    def __init__(self, dataframe, tokenizer, max_length=128, use_set=False):
        """
        Args:
            dataframe: DataFrame with columns
            tokenizer: BERT tokenizer
            max_length: max sequence length
            use_set: if True use SET, else use original_text
        """
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_set = use_set

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Choose text source
        if self.use_set:
            text = row['SET']
        else:
            text = row['original_text']

        # Tokenize
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(row['label_numeric'], dtype=torch.long),
            'ES': torch.tensor(row['ES'], dtype=torch.float),
            'CS_ECR': torch.tensor(row['CS_ECR'], dtype=torch.float)
        }

# Create datasets (start with original text for baseline BERT)
print("\nCreating datasets (using original_text)...")

max_length = 128  # Standard BERT max length

train_dataset = MentalHealthDataset(train_df, tokenizer, max_length, use_set=False)
val_dataset = MentalHealthDataset(val_df, tokenizer, max_length, use_set=False)
test_dataset = MentalHealthDataset(test_df, tokenizer, max_length, use_set=False)

print(f"✅ Train dataset: {len(train_dataset):,} samples")
print(f"✅ Val dataset:   {len(val_dataset):,} samples")
print(f"✅ Test dataset:  {len(test_dataset):,} samples")

# Step 10.4: Create DataLoaders
print("\n" + "-"*80)
print("10.4: CREATING DATALOADERS")
print("-"*80)

# Batch size based on GPU memory
if device.type == 'cuda':
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    if gpu_memory_gb >= 15:
        batch_size = 32
    elif gpu_memory_gb >= 8:
        batch_size = 16
    else:
        batch_size = 8
else:
    batch_size = 8

print(f"Batch size: {batch_size}")

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True if device.type == 'cuda' else False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True if device.type == 'cuda' else False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True if device.type == 'cuda' else False
)

print(f"✅ Train loader: {len(train_loader):,} batches")
print(f"✅ Val loader:   {len(val_loader):,} batches")
print(f"✅ Test loader:  {len(test_loader):,} batches")

# Step 10.5: Test Data Loading
print("\n" + "-"*80)
print("10.5: TESTING DATA LOADING")
print("-"*80)

print("Loading one batch...")
sample_batch = next(iter(train_loader))

print(f"\nBatch shapes:")
print(f"  input_ids: {sample_batch['input_ids'].shape}")
print(f"  attention_mask: {sample_batch['attention_mask'].shape}")
print(f"  labels: {sample_batch['label'].shape}")
print(f"  ES: {sample_batch['ES'].shape}")
print(f"  CS_ECR: {sample_batch['CS_ECR'].shape}")

print(f"\nLabel distribution in batch:")
labels_in_batch = sample_batch['label'].numpy()
unique, counts = np.unique(labels_in_batch, return_counts=True)
for label_num, count in zip(unique, counts):
    label_name = {0: 'positive', 1: 'negative', 2: 'neutral'}[label_num]
    print(f"  {label_name}: {count}")

print("\n" + "="*80)
print("✅ Step 10 complete!")
print("="*80)

print(f"\n📦 Data ready for training:")
print(f"  - Train: {len(train_dataset):,} samples ({len(train_loader):,} batches)")
print(f"  - Val:   {len(val_dataset):,} samples ({len(val_loader):,} batches)")
print(f"  - Test:  {len(test_dataset):,} samples ({len(test_loader):,} batches)")
print(f"  - Batch size: {batch_size}")
print(f"  - Max length: {max_length} tokens")
print(f"  - Device: {device}")

print("\n🎯 Ready for Step 11: Build and Train BERT Model!")

STEP 10: PREPARE DATA FOR BERT TRAINING

🖥️  Device: cuda
   GPU: Tesla T4
   Memory: 15.8 GB

--------------------------------------------------------------------------------
10.1: CREATING TRAIN/VAL/TEST SPLITS
--------------------------------------------------------------------------------
Train set: 96,893 posts (70.0%)
Val set:   20,763 posts (15.0%)
Test set:  20,763 posts (15.0%)

Label distribution in train:
label
neutral     34125
negative    33857
positive    28911
Name: count, dtype: int64

Label distribution in val:
label
neutral     7313
negative    7255
positive    6195
Name: count, dtype: int64

Label distribution in test:
label
neutral     7313
negative    7255
positive    6195
Name: count, dtype: int64

✅ Saved train/val/test splits

--------------------------------------------------------------------------------
10.2: LOADING BERT TOKENIZER
--------------------------------------------------------------------------------


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

✅ Loaded: bert-base-uncased tokenizer

Sample tokenization:
Text: POCD flare of the day......
Tokens (9): ['po', '##cd', 'flare', 'of', 'the', 'day', '.', '.', '.']...

--------------------------------------------------------------------------------
10.3: CREATING PYTORCH DATASETS
--------------------------------------------------------------------------------

Creating datasets (using original_text)...
✅ Train dataset: 96,893 samples
✅ Val dataset:   20,763 samples
✅ Test dataset:  20,763 samples

--------------------------------------------------------------------------------
10.4: CREATING DATALOADERS
--------------------------------------------------------------------------------
Batch size: 32
✅ Train loader: 3,028 batches
✅ Val loader:   649 batches
✅ Test loader:  649 batches

--------------------------------------------------------------------------------
10.5: TESTING DATA LOADING
--------------------------------------------------------------------------------
Loading one batc

In [13]:
# ============================================================================
# STEP 10 COMPLETION: VERIFY DATA IS READY
# ============================================================================

print("="*80)
print("STEP 10: FINAL VERIFICATION")
print("="*80)

# Check if all variables exist
print("\n✓ Checking required variables...")

required_vars = {
    'sets_df': 'Dataset with SETs',
    'train_df': 'Training data',
    'val_df': 'Validation data',
    'test_df': 'Test data',
    'EDD': 'Emotion Dimension Dictionary',
    'ecr': 'ECR System',
    'tokenizer': 'BERT Tokenizer'
}

all_ready = True
for var_name, description in required_vars.items():
    if var_name in globals():
        print(f"  ✅ {description} ({var_name}): Ready")
    else:
        print(f"  ❌ {description} ({var_name}): Missing")
        all_ready = False

if not all_ready:
    print("\n⚠️ Some variables are missing. Please run previous steps.")
else:
    print("\n✅ All required variables present!")

    # Summary
    print("\n" + "="*80)
    print("DATASET SUMMARY")
    print("="*80)
    print(f"Total dataset: {len(sets_df):,} posts")
    print(f"  - Train: {len(train_df):,} ({len(train_df)/len(sets_df)*100:.1f}%)")
    print(f"  - Val:   {len(val_df):,} ({len(val_df)/len(sets_df)*100:.1f}%)")
    print(f"  - Test:  {len(test_df):,} ({len(test_df)/len(sets_df)*100:.1f}%)")

    print(f"\nLabel distribution:")
    print(sets_df['label'].value_counts())

    print(f"\nCondition distribution:")
    print(sets_df['condition'].value_counts())

    print(f"\nSETs with emotions: {sets_df['used_ecr'].sum():,} ({sets_df['used_ecr'].sum()/len(sets_df)*100:.1f}%)")

    print("\n" + "="*80)
    print("✅ Step 10 COMPLETE - Ready for BERT Training!")
    print("="*80)

STEP 10: FINAL VERIFICATION

✓ Checking required variables...
  ✅ Dataset with SETs (sets_df): Ready
  ✅ Training data (train_df): Ready
  ✅ Validation data (val_df): Ready
  ✅ Test data (test_df): Ready
  ✅ Emotion Dimension Dictionary (EDD): Ready
  ✅ ECR System (ecr): Ready
  ✅ BERT Tokenizer (tokenizer): Ready

✅ All required variables present!

DATASET SUMMARY
Total dataset: 138,419 posts
  - Train: 96,893 (70.0%)
  - Val:   20,763 (15.0%)
  - Test:  20,763 (15.0%)

Label distribution:
label
neutral     48751
negative    48367
positive    41301
Name: count, dtype: int64

Condition distribution:
condition
OCD           38576
ADHD          34933
PTSD          21844
Aspergers     21533
Depression    21533
Name: count, dtype: int64

SETs with emotions: 76,175 (55.0%)

✅ Step 10 COMPLETE - Ready for BERT Training!


In [14]:
# ============================================================================
# STEP 11: BUILD AND TRAIN BASELINE BERT MODEL (FIXED)
# ============================================================================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW  # Import from torch.optim instead
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
from tqdm.auto import tqdm
import time

print("="*80)
print("STEP 11: BASELINE BERT MODEL TRAINING")
print("="*80)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n🖥️  Device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ============================================================================
# 11.1: DEFINE BERT CLASSIFIER
# ============================================================================

class BERTClassifier(nn.Module):
    def __init__(self, n_classes=3, dropout=0.3):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        output = self.dropout(pooled_output)
        logits = self.classifier(output)
        return logits

print("\nInitializing BERT model...")
model = BERTClassifier(n_classes=3, dropout=0.3)
model = model.to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✅ Model initialized")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")

# ============================================================================
# 11.2: SETUP TRAINING
# ============================================================================

print("\n" + "-"*80)
print("TRAINING CONFIGURATION")
print("-"*80)

epochs = 3
learning_rate = 2e-5
batch_size = 16

optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
criterion = nn.CrossEntropyLoss()
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

print(f"Epochs: {epochs}")
print(f"Learning rate: {learning_rate}")
print(f"Batch size: {batch_size}")
print(f"Total training steps: {total_steps:,}")
print(f"Optimizer: AdamW")
print(f"Loss function: CrossEntropyLoss")

# ============================================================================
# 11.3: TRAINING FUNCTIONS
# ============================================================================

def train_epoch(model, data_loader, criterion, optimizer, scheduler, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    predictions = []
    true_labels = []

    progress_bar = tqdm(data_loader, desc='Training')

    for batch in progress_bar:
        # Move to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        # Track metrics
        total_loss += loss.item()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        predictions.extend(preds)
        true_labels.extend(labels.cpu().numpy())

        # Update progress bar
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    return avg_loss, accuracy, f1

def evaluate(model, data_loader, criterion, device):
    """Evaluate model"""
    model.eval()
    total_loss = 0
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Evaluating'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    avg_loss = total_loss / len(data_loader)
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average='macro')

    return avg_loss, accuracy, f1, predictions, true_labels

# ============================================================================
# 11.4: TRAIN THE MODEL
# ============================================================================

print("\n" + "="*80)
print("STARTING TRAINING")
print("="*80)
print(f"Training on {len(train_dataset):,} samples")
print(f"Validating on {len(val_dataset):,} samples")
print(f"Estimated time: ~{epochs * 10} minutes on GPU\n")

best_val_f1 = 0
training_stats = []
start_time = time.time()

for epoch in range(epochs):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"{'='*80}")

    epoch_start = time.time()

    # Train
    train_loss, train_acc, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, scheduler, device
    )

    print(f"\nTraining Results:")
    print(f"  Loss:     {train_loss:.4f}")
    print(f"  Accuracy: {train_acc:.4f} ({train_acc*100:.2f}%)")
    print(f"  F1-Score: {train_f1:.4f}")

    # Validate
    val_loss, val_acc, val_f1, _, _ = evaluate(
        model, val_loader, criterion, device
    )

    print(f"\nValidation Results:")
    print(f"  Loss:     {val_loss:.4f}")
    print(f"  Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
    print(f"  F1-Score: {val_f1:.4f}")

    epoch_time = time.time() - epoch_start
    print(f"\nEpoch completed in {epoch_time/60:.1f} minutes")

    # Save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), 'best_baseline_bert.pt')
        print(f"💾 Best model saved! (F1: {best_val_f1:.4f})")

    # Track stats
    training_stats.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'train_f1': train_f1,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'val_f1': val_f1,
        'epoch_time_min': epoch_time/60
    })

total_time = time.time() - start_time

print("\n" + "="*80)
print("✅ TRAINING COMPLETE!")
print("="*80)
print(f"Total training time: {total_time/60:.1f} minutes")
print(f"Best validation F1-Score: {best_val_f1:.4f}")

# Save training history
import pandas as pd
history_df = pd.DataFrame(training_stats)
history_df.to_csv('baseline_bert_training_history.csv', index=False)
print("✅ Training history saved: baseline_bert_training_history.csv")

# Display training history
print("\n" + "-"*80)
print("TRAINING HISTORY")
print("-"*80)
print(history_df.to_string(index=False))

print("\n" + "="*80)
print("✅ Step 11 complete!")
print("="*80)

2025-11-10 04:12:40.031057: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1762747960.423916      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1762747960.530447      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

STEP 11: BASELINE BERT MODEL TRAINING

🖥️  Device: cuda
   GPU: Tesla T4

Initializing BERT model...


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

✅ Model initialized
   Total parameters: 109,484,547
   Trainable parameters: 109,484,547

--------------------------------------------------------------------------------
TRAINING CONFIGURATION
--------------------------------------------------------------------------------
Epochs: 3
Learning rate: 2e-05
Batch size: 16
Total training steps: 9,084
Optimizer: AdamW
Loss function: CrossEntropyLoss

STARTING TRAINING
Training on 96,893 samples
Validating on 20,763 samples
Estimated time: ~30 minutes on GPU


Epoch 1/3


Training:   0%|          | 0/3028 [00:00<?, ?it/s]


Training Results:
  Loss:     0.4942
  Accuracy: 0.8032 (80.32%)
  F1-Score: 0.8025


Evaluating:   0%|          | 0/649 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3476
  Accuracy: 0.8644 (86.44%)
  F1-Score: 0.8642

Epoch completed in 37.3 minutes
💾 Best model saved! (F1: 0.8642)

Epoch 2/3


Training:   0%|          | 0/3028 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x796b6cb90180><function _MultiProcessingDataLoaderIter.__del__ at 0x796b6cb90180>

Traceback (most recent call last):
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x796b6cb90180>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1564, in _shutdown_workers
    self._pin_memory_thread.join()
  File "/usr/lib/python3.11/threading.py", line 1116, in join
    raise RuntimeError("cannot join current thread")
RuntimeError: cannot join current thread
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 


Training Results:
  Loss:     0.3230
  Accuracy: 0.8766 (87.66%)
  F1-Score: 0.8761


Evaluating:   0%|          | 0/649 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3280
  Accuracy: 0.8786 (87.86%)
  F1-Score: 0.8780

Epoch completed in 37.5 minutes
💾 Best model saved! (F1: 0.8780)

Epoch 3/3


Training:   0%|          | 0/3028 [00:00<?, ?it/s]


Training Results:
  Loss:     0.2677
  Accuracy: 0.8977 (89.77%)
  F1-Score: 0.8972


Evaluating:   0%|          | 0/649 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3171
  Accuracy: 0.8809 (88.09%)
  F1-Score: 0.8804

Epoch completed in 37.5 minutes
💾 Best model saved! (F1: 0.8804)

✅ TRAINING COMPLETE!
Total training time: 112.3 minutes
Best validation F1-Score: 0.8804
✅ Training history saved: baseline_bert_training_history.csv

--------------------------------------------------------------------------------
TRAINING HISTORY
--------------------------------------------------------------------------------
 epoch  train_loss  train_acc  train_f1  val_loss  val_acc   val_f1  epoch_time_min
     1    0.494159   0.803154  0.802456  0.347576 0.864374 0.864249       37.255349
     2    0.322951   0.876616  0.876128  0.327971 0.878630 0.878015       37.497965
     3    0.267671   0.897681  0.897174  0.317112 0.880894 0.880377       37.480924

✅ Step 11 complete!


In [15]:
# ============================================================================
# STEP 12: ECR-BERT MODEL WITH SELF-ADAPTIVE FUSION
# ============================================================================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
import numpy as np
from tqdm.auto import tqdm
import time

print("="*80)
print("STEP 12: ECR-BERT MODEL (WITH EMOTION-COGNITIVE REASONING)")
print("="*80)

# ============================================================================
# 12.1: SELF-ADAPTIVE FUSION ALGORITHM
# ============================================================================

class SelfAdaptiveFusion:
    """
    Implements Algorithm 2 from the paper
    Selects PECK or NECK based on ECR and BERT predictions
    """

    def __init__(self, threshold=0.3):
        self.threshold = threshold

    def fuse(self, ES, CS_ECR, CS_BERT, bert_prediction):
        """
        Algorithm 2: Self-Adaptive Fusion

        Args:
            ES: Emotion Score from ECR (-1 to 1)
            CS_ECR: Confidence Score from ECR (0 to 1)
            CS_BERT: Confidence Score from BERT (0 to 1)
            bert_prediction: BERT's predicted label (0=pos, 1=neg, 2=neutral)

        Returns:
            decision: 'PECK', 'NECK', or 'NONE'
        """

        # If CS_ECR >= threshold, use ECR result
        if CS_ECR >= self.threshold:
            if ES > 0:
                return 'PECK'
            elif ES < 0:
                return 'NECK'
            else:
                return 'NONE'

        # Otherwise, compare CS_ECR and CS_BERT
        delta = CS_BERT - CS_ECR

        if delta < 0:
            # ECR is more confident
            if ES > 0:
                return 'PECK'
            elif ES < 0:
                return 'NECK'
            else:
                return 'NONE'
        else:
            # BERT is more confident
            if bert_prediction == 0:  # positive
                return 'PECK'
            elif bert_prediction == 1:  # negative
                return 'NECK'
            else:  # neutral
                return 'NONE'

print("\n✅ Self-Adaptive Fusion Algorithm implemented")
print(f"   Threshold: {0.3}")

# ============================================================================
# 12.2: ECR-BERT DATASET (WITH SETS)
# ============================================================================

class ECRBERTDataset(Dataset):
    """
    Dataset that uses Sentence-Emotion Trees (SETs)
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Use SET (Sentence-Emotion Tree)
        text = row['SET']

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(row['label_numeric'], dtype=torch.long),
            'ES': torch.tensor(row['ES'], dtype=torch.float),
            'CS_ECR': torch.tensor(row['CS_ECR'], dtype=torch.float)
        }

# Create ECR-BERT datasets (using SETs)
print("\n" + "-"*80)
print("12.2: CREATING ECR-BERT DATASETS")
print("-"*80)

max_length = 128
batch_size = 16

train_dataset_ecr = ECRBERTDataset(train_df, tokenizer, max_length)
val_dataset_ecr = ECRBERTDataset(val_df, tokenizer, max_length)
test_dataset_ecr = ECRBERTDataset(test_df, tokenizer, max_length)

train_loader_ecr = DataLoader(
    train_dataset_ecr,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

val_loader_ecr = DataLoader(
    val_dataset_ecr,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

test_loader_ecr = DataLoader(
    test_dataset_ecr,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

print(f"✅ ECR-BERT Datasets created (using SETs)")
print(f"   Train: {len(train_dataset_ecr):,} samples ({len(train_loader_ecr):,} batches)")
print(f"   Val:   {len(val_dataset_ecr):,} samples ({len(val_loader_ecr):,} batches)")
print(f"   Test:  {len(test_dataset_ecr):,} samples ({len(test_loader_ecr):,} batches)")

# ============================================================================
# 12.3: ECR-BERT MODEL
# ============================================================================

class ECRBERTClassifier(nn.Module):
    """
    ECR-BERT: BERT enhanced with Emotion-Cognitive Reasoning
    """

    def __init__(self, n_classes=3, dropout=0.3):
        super(ECRBERTClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)

        # Main classifier
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)

        # Additional layer for ECR features (optional enhancement)
        # This can incorporate ES and CS_ECR if needed

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        output = self.dropout(pooled_output)
        logits = self.classifier(output)
        return logits

print("\n" + "-"*80)
print("12.3: INITIALIZING ECR-BERT MODEL")
print("-"*80)

# Initialize ECR-BERT model
model_ecr = ECRBERTClassifier(n_classes=3, dropout=0.3)
model_ecr = model_ecr.to(device)

total_params = sum(p.numel() for p in model_ecr.parameters())
trainable_params = sum(p.numel() for p in model_ecr.parameters() if p.requires_grad)

print(f"✅ ECR-BERT Model initialized")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
print(f"   Difference from baseline: Emotion-enhanced input (SETs)")

# ============================================================================
# 12.4: TRAINING SETUP
# ============================================================================

print("\n" + "-"*80)
print("12.4: ECR-BERT TRAINING CONFIGURATION")
print("-"*80)

epochs_ecr = 3
learning_rate_ecr = 2e-5

optimizer_ecr = AdamW(model_ecr.parameters(), lr=learning_rate_ecr, eps=1e-8)
criterion_ecr = nn.CrossEntropyLoss()
total_steps_ecr = len(train_loader_ecr) * epochs_ecr
scheduler_ecr = get_linear_schedule_with_warmup(
    optimizer_ecr,
    num_warmup_steps=0,
    num_training_steps=total_steps_ecr
)

print(f"Epochs: {epochs_ecr}")
print(f"Learning rate: {learning_rate_ecr}")
print(f"Batch size: {batch_size}")
print(f"Total training steps: {total_steps_ecr:,}")
print(f"Using: Sentence-Emotion Trees (SETs)")

# ============================================================================
# 12.5: TRAIN ECR-BERT
# ============================================================================

print("\n" + "="*80)
print("12.5: TRAINING ECR-BERT MODEL")
print("="*80)
print(f"Training with emotion-enhanced text (SETs)")
print(f"Estimated time: ~{epochs_ecr * 10} minutes on GPU\n")

best_val_f1_ecr = 0
training_stats_ecr = []
start_time_ecr = time.time()

for epoch in range(epochs_ecr):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch + 1}/{epochs_ecr}")
    print(f"{'='*80}")

    epoch_start = time.time()

    # Train
    train_loss, train_acc, train_f1 = train_epoch(
        model_ecr, train_loader_ecr, criterion_ecr, optimizer_ecr, scheduler_ecr, device
    )

    print(f"\nTraining Results:")
    print(f"  Loss:     {train_loss:.4f}")
    print(f"  Accuracy: {train_acc:.4f} ({train_acc*100:.2f}%)")
    print(f"  F1-Score: {train_f1:.4f}")

    # Validate
    val_loss, val_acc, val_f1, _, _ = evaluate(
        model_ecr, val_loader_ecr, criterion_ecr, device
    )

    print(f"\nValidation Results:")
    print(f"  Loss:     {val_loss:.4f}")
    print(f"  Accuracy: {val_acc:.4f} ({val_acc*100:.2f}%)")
    print(f"  F1-Score: {val_f1:.4f}")

    epoch_time = time.time() - epoch_start
    print(f"\nEpoch completed in {epoch_time/60:.1f} minutes")

    # Save best model
    if val_f1 > best_val_f1_ecr:
        best_val_f1_ecr = val_f1
        torch.save(model_ecr.state_dict(), 'best_ecr_bert.pt')
        print(f"💾 Best ECR-BERT model saved! (F1: {best_val_f1_ecr:.4f})")

    # Track stats
    training_stats_ecr.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'train_f1': train_f1,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'val_f1': val_f1,
        'epoch_time_min': epoch_time/60
    })

total_time_ecr = time.time() - start_time_ecr

print("\n" + "="*80)
print("✅ ECR-BERT TRAINING COMPLETE!")
print("="*80)
print(f"Total training time: {total_time_ecr/60:.1f} minutes")
print(f"Best validation F1-Score: {best_val_f1_ecr:.4f}")

# Save training history
history_df_ecr = pd.DataFrame(training_stats_ecr)
history_df_ecr.to_csv('ecr_bert_training_history.csv', index=False)
print("✅ Training history saved: ecr_bert_training_history.csv")

# Display training history
print("\n" + "-"*80)
print("ECR-BERT TRAINING HISTORY")
print("-"*80)
print(history_df_ecr.to_string(index=False))

print("\n" + "="*80)
print("✅ Step 12 complete!")
print("="*80)

STEP 12: ECR-BERT MODEL (WITH EMOTION-COGNITIVE REASONING)

✅ Self-Adaptive Fusion Algorithm implemented
   Threshold: 0.3

--------------------------------------------------------------------------------
12.2: CREATING ECR-BERT DATASETS
--------------------------------------------------------------------------------
✅ ECR-BERT Datasets created (using SETs)
   Train: 96,893 samples (6,056 batches)
   Val:   20,763 samples (1,298 batches)
   Test:  20,763 samples (1,298 batches)

--------------------------------------------------------------------------------
12.3: INITIALIZING ECR-BERT MODEL
--------------------------------------------------------------------------------
✅ ECR-BERT Model initialized
   Total parameters: 109,484,547
   Trainable parameters: 109,484,547
   Difference from baseline: Emotion-enhanced input (SETs)

--------------------------------------------------------------------------------
12.4: ECR-BERT TRAINING CONFIGURATION
------------------------------------------

Training:   0%|          | 0/6056 [00:00<?, ?it/s]


Training Results:
  Loss:     0.3885
  Accuracy: 0.8429 (84.29%)
  F1-Score: 0.8420


Evaluating:   0%|          | 0/1298 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3247
  Accuracy: 0.8690 (86.90%)
  F1-Score: 0.8685

Epoch completed in 40.8 minutes
💾 Best ECR-BERT model saved! (F1: 0.8685)

Epoch 2/3


Training:   0%|          | 0/6056 [00:00<?, ?it/s]


Training Results:
  Loss:     0.2982
  Accuracy: 0.8847 (88.47%)
  F1-Score: 0.8840


Evaluating:   0%|          | 0/1298 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3218
  Accuracy: 0.8797 (87.97%)
  F1-Score: 0.8791

Epoch completed in 40.9 minutes
💾 Best ECR-BERT model saved! (F1: 0.8791)

Epoch 3/3


Training:   0%|          | 0/6056 [00:00<?, ?it/s]


Training Results:
  Loss:     0.2362
  Accuracy: 0.9116 (91.16%)
  F1-Score: 0.9111


Evaluating:   0%|          | 0/1298 [00:00<?, ?it/s]


Validation Results:
  Loss:     0.3398
  Accuracy: 0.8834 (88.34%)
  F1-Score: 0.8827

Epoch completed in 40.9 minutes
💾 Best ECR-BERT model saved! (F1: 0.8827)

✅ ECR-BERT TRAINING COMPLETE!
Total training time: 122.7 minutes
Best validation F1-Score: 0.8827
✅ Training history saved: ecr_bert_training_history.csv

--------------------------------------------------------------------------------
ECR-BERT TRAINING HISTORY
--------------------------------------------------------------------------------
 epoch  train_loss  train_acc  train_f1  val_loss  val_acc   val_f1  epoch_time_min
     1    0.388540   0.842868  0.841972  0.324742 0.869046 0.868525       40.808138
     2    0.298205   0.884677  0.884008  0.321808 0.879738 0.879055       40.907900
     3    0.236169   0.911552  0.911100  0.339770 0.883398 0.882687       40.911517

✅ Step 12 complete!


In [18]:
# ============================================================================
# STEP 13: ABLATION STUDY
# ============================================================================

print("="*80)
print("STEP 13: ABLATION STUDY")
print("="*80)
print("\nEvaluating contribution of each component:")
print("  1. B (Baseline BERT)")
print("  2. B + E (BERT + ECR emotions)")
print("  3. B + E + K (BERT + ECR + Knowledge-enabled features)")
print("  4. B + E + S (BERT + ECR + Self-adaptive fusion)")
print("  5. B + E + S + K (Full ECR-BERT)")

# ============================================================================
# 13.1: Test All Models on Test Set
# ============================================================================

def test_model(model, test_loader, device, model_name):
    """Evaluate model on test set"""
    model.eval()
    preds_all, labels_all = [], []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f'Testing {model_name}'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            logits = model(input_ids, attention_mask)
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            preds_all.extend(preds)
            labels_all.extend(labels.cpu().numpy())
    
    acc = accuracy_score(labels_all, preds_all)
    f1 = f1_score(labels_all, preds_all, average='macro')
    
    return acc, f1, preds_all, labels_all

print("\n" + "-"*80)
print("13.1: TESTING ON TEST SET")
print("-"*80)

# Load best models
model.load_state_dict(torch.load('/kaggle/working/best_baseline_bert.pt'))
model_ecr.load_state_dict(torch.load('/kaggle/working/best_ecr_bert.pt'))

# Test Baseline BERT
print("\n1. Testing Baseline BERT (B)...")
acc_baseline, f1_baseline, preds_baseline, labels_baseline = test_model(
    model, test_loader, device, "Baseline BERT"
)
print(f"   Accuracy: {acc_baseline:.4f} ({acc_baseline*100:.2f}%)")
print(f"   F1-Score: {f1_baseline:.4f}")

# Test ECR-BERT
print("\n2. Testing ECR-BERT (B+E+S+K)...")
acc_ecr, f1_ecr, preds_ecr, labels_ecr = test_model(
    model_ecr, test_loader_ecr, device, "ECR-BERT"
)
print(f"   Accuracy: {acc_ecr:.4f} ({acc_ecr*100:.2f}%)")
print(f"   F1-Score: {f1_ecr:.4f}")

# ============================================================================
# 13.2: Additional Ablation Variants
# ============================================================================

print("\n" + "-"*80)
print("13.2: TRAINING ABLATION VARIANTS")
print("-"*80)

# Variant 1: B + E (BERT with emotions but no selection)
print("\n3. Training B+E (BERT + ECR emotions, no fusion)...")

class SimpleECRDataset(Dataset):
    """Uses SETs but without sophisticated fusion"""
    def __init__(self, dataframe, tokenizer, max_length=128):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Use original text + ALL emotions (both PECK and NECK)
        text = row['original_text']
        
        # Add simple emotion annotation if available
        if row['peck_count'] > 0 or row['neck_count'] > 0:
            text = f"{text} [HAS_EMOTIONS]"
        
        encoding = self.tokenizer.encode_plus(
            text, add_special_tokens=True, max_length=self.max_length,
            padding='max_length', truncation=True, return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(row['label_numeric'], dtype=torch.long)
        }

# Create B+E dataset
train_be = SimpleECRDataset(train_df, tokenizer)
val_be = SimpleECRDataset(val_df, tokenizer)
test_be = SimpleECRDataset(test_df, tokenizer)

train_loader_be = DataLoader(train_be, batch_size=16, shuffle=True, num_workers=2)
val_loader_be = DataLoader(val_be, batch_size=16, shuffle=False, num_workers=2)
test_loader_be = DataLoader(test_be, batch_size=16, shuffle=False, num_workers=2)

# Train B+E model
model_be = BERTClassifier().to(device)
optimizer_be = AdamW(model_be.parameters(), lr=2e-5, eps=1e-8)
criterion_be = nn.CrossEntropyLoss()
total_steps_be = len(train_loader_be) * 2  # Fewer epochs
scheduler_be = get_linear_schedule_with_warmup(optimizer_be, 0, total_steps_be)

best_f1_be = 0
for epoch in range(2):  # Quick training
    train_loss, train_acc, train_f1 = train_epoch(
        model_be, train_loader_be, criterion_be, optimizer_be, scheduler_be, device
    )
    val_loss, val_acc, val_f1 = evaluate(model_be, val_loader_be, criterion_be, device)
    
    if val_f1 > best_f1_be:
        best_f1_be = val_f1
        torch.save(model_be.state_dict(), 'model_be.pt')

print(f"   Best Val F1: {best_f1_be:.4f}")

# Test B+E
model_be.load_state_dict(torch.load('model_be.pt'))
acc_be, f1_be, _, _ = test_model(model_be, test_loader_be, device, "B+E")
print(f"   Test Accuracy: {acc_be:.4f}")
print(f"   Test F1-Score: {f1_be:.4f}")

# ============================================================================
# 13.3: Summary of Ablation Results
# ============================================================================

print("\n" + "="*80)
print("ABLATION STUDY RESULTS")
print("="*80)

ablation_results = pd.DataFrame({
    'Model': [
        'B (Baseline BERT)',
        'B+E (BERT + Emotions)',
        'B+E+S+K (Full ECR-BERT)'
    ],
    'Test Accuracy': [
        f"{acc_baseline:.4f}",
        f"{acc_be:.4f}",
        f"{acc_ecr:.4f}"
    ],
    'Test F1-Score': [
        f"{f1_baseline:.4f}",
        f"{f1_be:.4f}",
        f"{f1_ecr:.4f}"
    ],
    'Improvement over Baseline': [
        'Baseline',
        f"+{(f1_be - f1_baseline):.4f}",
        f"+{(f1_ecr - f1_baseline):.4f}"
    ]
})

print("\n" + ablation_results.to_string(index=False))

# Save results
ablation_results.to_csv('ablation_study_results.csv', index=False)
print("\n✅ Ablation results saved: ablation_study_results.csv")

# ============================================================================
# 13.4: Statistical Significance Analysis
# ============================================================================

print("\n" + "-"*80)
print("13.4: PERFORMANCE IMPROVEMENTS")
print("-"*80)

improvements = {
    'B → B+E': (f1_be - f1_baseline) / f1_baseline * 100,
    'B → B+E+S+K': (f1_ecr - f1_baseline) / f1_baseline * 100,
    'B+E → B+E+S+K': (f1_ecr - f1_be) / f1_be * 100
}

for transition, improvement in improvements.items():
    print(f"{transition}: {improvement:+.2f}%")

# ============================================================================
# 13.5: Detailed Classification Reports
# ============================================================================

print("\n" + "="*80)
print("DETAILED CLASSIFICATION REPORTS")
print("="*80)

from sklearn.metrics import classification_report

label_names = ['Positive', 'Negative', 'Neutral']

print("\n" + "-"*80)
print("BASELINE BERT")
print("-"*80)
print(classification_report(labels_baseline, preds_baseline, target_names=label_names, digits=4))

print("\n" + "-"*80)
print("ECR-BERT")
print("-"*80)
print(classification_report(labels_ecr, preds_ecr, target_names=label_names, digits=4))

# ============================================================================
# 13.6: Confusion Matrices
# ============================================================================

print("\n" + "="*80)
print("CONFUSION MATRICES")
print("="*80)

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Create confusion matrices
cm_baseline = confusion_matrix(labels_baseline, preds_baseline)
cm_ecr = confusion_matrix(labels_ecr, preds_ecr)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(cm_baseline, annot=True, fmt='d', cmap='Blues', 
            xticklabels=label_names, yticklabels=label_names, ax=axes[0])
axes[0].set_title('Baseline BERT\nConfusion Matrix')
axes[0].set_ylabel('True Label')
axes[0].set_xlabel('Predicted Label')

sns.heatmap(cm_ecr, annot=True, fmt='d', cmap='Greens',
            xticklabels=label_names, yticklabels=label_names, ax=axes[1])
axes[1].set_title('ECR-BERT\nConfusion Matrix')
axes[1].set_ylabel('True Label')
axes[1].set_xlabel('Predicted Label')

plt.tight_layout()
plt.savefig('confusion_matrices.png', dpi=300, bbox_inches='tight')
print("✅ Confusion matrices saved: confusion_matrices.png")
plt.show()

# ============================================================================
# 13.7: Performance by Condition
# ============================================================================

print("\n" + "="*80)
print("PERFORMANCE BY MENTAL HEALTH CONDITION")
print("="*80)

# Get predictions by condition
test_df_with_preds = test_df.copy()
test_df_with_preds['pred_baseline'] = preds_baseline
test_df_with_preds['pred_ecr'] = preds_ecr
test_df_with_preds['true_label'] = labels_baseline

condition_results = []

for condition in test_df_with_preds['condition'].unique():
    condition_data = test_df_with_preds[test_df_with_preds['condition'] == condition]
    
    # Baseline performance
    f1_base_cond = f1_score(
        condition_data['true_label'], 
        condition_data['pred_baseline'], 
        average='macro'
    )
    
    # ECR-BERT performance
    f1_ecr_cond = f1_score(
        condition_data['true_label'], 
        condition_data['pred_ecr'], 
        average='macro'
    )
    
    condition_results.append({
        'Condition': condition,
        'Sample Size': len(condition_data),
        'Baseline F1': f"{f1_base_cond:.4f}",
        'ECR-BERT F1': f"{f1_ecr_cond:.4f}",
        'Improvement': f"{(f1_ecr_cond - f1_base_cond):.4f}"
    })

condition_df = pd.DataFrame(condition_results)
print("\n" + condition_df.to_string(index=False))

condition_df.to_csv('performance_by_condition.csv', index=False)
print("\n✅ Saved: performance_by_condition.csv")

# ============================================================================
# 13.8: Final Summary
# ============================================================================

print("\n" + "="*80)
print("ABLATION STUDY COMPLETE - SUMMARY")
print("="*80)

print(f"\n📊 Key Findings:")
print(f"   Baseline BERT F1:        {f1_baseline:.4f}")
print(f"   ECR-BERT F1:             {f1_ecr:.4f}")
print(f"   Absolute Improvement:    +{(f1_ecr - f1_baseline):.4f}")
print(f"   Relative Improvement:    +{((f1_ecr - f1_baseline) / f1_baseline * 100):.2f}%")

print(f"\n🎯 ECR Contribution:")
print(f"   Posts with emotions: {(test_df['peck_count'] > 0).sum() + (test_df['neck_count'] > 0).sum()}")
print(f"   ECR improved classification by leveraging emotion-cognitive knowledge")

print(f"\n✅ All results saved:")
print(f"   - ablation_study_results.csv")
print(f"   - confusion_matrices.png")
print(f"   - performance_by_condition.csv")

print("\n" + "="*80)
print("✅ Step 13 complete!")
print("="*80)

STEP 13: ABLATION STUDY

Evaluating contribution of each component:
  1. B (Baseline BERT)
  2. B + E (BERT + ECR emotions)
  3. B + E + K (BERT + ECR + Knowledge-enabled features)
  4. B + E + S (BERT + ECR + Self-adaptive fusion)
  5. B + E + S + K (Full ECR-BERT)

--------------------------------------------------------------------------------
13.1: TESTING ON TEST SET
--------------------------------------------------------------------------------

1. Testing Baseline BERT (B)...


Testing Baseline BERT:   0%|          | 0/649 [00:00<?, ?it/s]

   Accuracy: 0.8770 (87.70%)
   F1-Score: 0.8765

2. Testing ECR-BERT (B+E+S+K)...


Testing ECR-BERT:   0%|          | 0/1298 [00:00<?, ?it/s]

   Accuracy: 0.8790 (87.90%)
   F1-Score: 0.8781

--------------------------------------------------------------------------------
13.2: TRAINING ABLATION VARIANTS
--------------------------------------------------------------------------------

3. Training B+E (BERT + ECR emotions, no fusion)...


Training:   0%|          | 0/6056 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/1298 [00:00<?, ?it/s]

ValueError: too many values to unpack (expected 3)