# Mental Health Advice Guard — Notebook

## Time spent - make extra subcategories that show how you spent your time
- Build: `6` hours
- Write‑up: `1` hour

## What this notebook does
1. Loads data
2. Implements a simple detector
3. Evaluates model
4. Exposes `guard_decide(text)` returning `{label, rationale, action, response_template}`
5. Discusses limitations and next steps. Feel free to elaborate in the report
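
The return shape named in step 4 can be written down as a small contract. This is only a sketch: `GuardDecision` and `is_valid_decision` are hypothetical names, and the allowed `label` and `action` values are the ones used by the classifier code later in this notebook.

```python
from typing import Optional, TypedDict


class GuardDecision(TypedDict):
    """Hypothetical type name; the four keys come from the step list above."""
    label: str                        # "advice" or "not_advice"
    rationale: str                    # human-readable explanation
    action: str                       # "block", "flag", or "allow"
    response_template: Optional[str]  # canned reply, or None when allowed


REQUIRED_KEYS = {"label", "rationale", "action", "response_template"}


def is_valid_decision(decision: dict) -> bool:
    """Check that a guard result carries every required key with known values."""
    if not REQUIRED_KEYS.issubset(decision):
        return False
    return (
        decision["label"] in {"advice", "not_advice"}
        and decision["action"] in {"block", "flag", "allow"}
    )


example: GuardDecision = {
    "label": "not_advice",
    "rationale": "Educational content without directives.",
    "action": "allow",
    "response_template": None,
}
print(is_valid_decision(example))
```

Extra keys such as a confidence score pass the check, since only the four required keys are enforced.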


# Install required libraries

In [None]:
%pip install lightning
%pip install --upgrade "mlflow>=3.1"

Collecting lightning
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting torchmetrics<3.0,>0.7.0 (from lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.5-py3-none-any.whl.metadata (20 kB)
Downloading lightning-2.5.5-py3-none-any.whl (828 kB)
Downloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
Downloading pytorch_lightning-2.5.5-py3-none-any.whl (832 kB)

# Import dependencies

In [None]:
import pandas as pd
from datasets import load_dataset
import re
import json
from sklearn.model_selection import train_test_split
from pathlib import Path
from typing import Any
from dataclasses import field
from pydantic.dataclasses import dataclass
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
import warnings
from mlflow.models.signature import infer_signature
import numpy as np

warnings.filterwarnings("ignore")

# Dataset generation from the required task

I used the `Amod/mental_health_counseling_conversations` dataset from Hugging Face because it was the closest match I found for this task.

Each response in `Amod/mental_health_counseling_conversations` was labeled `advice` or `not_advice` according to the keywords and phrases it contains.
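
A minimal sketch of that labeling rule, using deliberately reduced keyword and pattern lists. `weak_label` and the two constants are hypothetical names for illustration; the scoring follows the processor class in the next cell.

```python
import re

# Reduced, illustrative subsets of the full keyword and pattern lists.
ADVICE_KEYWORDS = ["you should", "i recommend", "see a doctor"]
INFO_PATTERNS = [r"many people", r"research shows", r"i understand"]


def weak_label(response: str) -> str:
    """Label a response with the keyword-scoring rule, on the reduced lists."""
    text = response.lower()
    advice_score = sum(keyword in text for keyword in ADVICE_KEYWORDS)
    info_score = sum(bool(re.search(p, text)) for p in INFO_PATTERNS)

    # Second-person phrasing plus a modal verb or a leading imperative adds weight.
    has_you = bool(re.search(r"\byou\b", text))
    has_modal = bool(re.search(r"\b(should|must|need to|have to|ought to)\b", text))
    has_imperative = bool(
        re.search(r"^(try|consider|avoid|stop|start|call|see|talk|get)", text)
    )
    if has_you and (has_modal or has_imperative):
        advice_score += 2

    # Advice wins only if it outscores the informational patterns.
    is_advice = advice_score > info_score and advice_score >= 1
    return "advice" if is_advice else "not_advice"


print(weak_label("You should see a doctor about this."))
print(weak_label("Many people feel anxious; research shows it is common."))
```

Because the labels come from this rule and not from human annotation, the classifier trained later learns to reproduce the heuristic, which is worth keeping in mind when reading the evaluation numbers.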

In [None]:
@dataclass
class MentalHealthDataProcessor:
    """
    Processes mental health counseling conversations to create labeled dataset
    for personalized advice detection.
    """

    output_dir: str = field(default="processed_data")
    """Output dir"""

    advice_keywords: list[str] = field(init=False)
    """Keywords that often indicate personalized advice"""

    info_patterns: list[str] = field(init=False)
    """Patterns that indicate general information (not advice)"""


    def __post_init__(self) -> None:
        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        self.advice_keywords = [
            # Direct recommendations
            "you should", "you need to", "you must", "you have to", "you ought to",
            "i recommend", "i suggest", "i advise", "my recommendation",
            "try this", "consider doing", "what you can do",

            # Medical/treatment advice
            "see a doctor", "talk to your therapist", "consult with", "get help",
            "medication", "therapy", "treatment", "diagnosis",
            "call a hotline", "seek professional help", "go to hospital",

            # Crisis intervention
            "call 911", "emergency", "crisis", "suicide hotline", "immediate help",
            "you're in danger", "seek immediate", "urgent care",

            # Specific behavioral instructions
            "stop doing", "start doing", "avoid", "practice", "exercise",
            "take medication", "don't take", "increase", "decrease"
        ]

        self.info_patterns = [
            # General statements
            r"many people", r"some people", r"research shows", r"studies indicate",
            r"it's common", r"generally", r"typically", r"often",
            r"depression affects", r"anxiety can", r"mental health is",

            # Educational content
            r"what is", r"definition of", r"symptoms include", r"characterized by",
            r"types of", r"causes of", r"risk factors",

            # Empathetic responses without advice
            r"i understand", r"that sounds", r"i hear you", r"i'm sorry",
            r"that must be", r"it's difficult", r"you're not alone"
        ]


    def load_counseling_dataset(self, hugging_face_dataset: str = "Amod/mental_health_counseling_conversations") -> pd.DataFrame:
        """Load the mental health counseling conversations dataset from HuggingFace."""
        print(f"Loading {hugging_face_dataset} dataset")

        # Load dataset from HuggingFace
        dataset = load_dataset(hugging_face_dataset)

        # The dataset contains only a training split
        df = pd.DataFrame(dataset['train'])
        print(f"Loaded {len(df)} conversations from HuggingFace dataset")

        return df

    def classify_response(self, response: str) -> dict[str, Any]:
        """
        Classify a response as advice or not_advice based on content analysis.

        Args:
            response: The response text to classify

        Returns:
            Dictionary with classification results
        """
        response_lower = response.lower()

        # Count advice indicators
        advice_score = 0
        matched_advice_keywords = []

        for keyword in self.advice_keywords:
            if keyword in response_lower:
                advice_score += 1
                matched_advice_keywords.append(keyword)

        # Check for info patterns (negative indicators for advice)
        info_score = 0
        matched_info_patterns = []

        for pattern in self.info_patterns:
            if re.search(pattern, response_lower):
                info_score += 1
                matched_info_patterns.append(pattern)

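        # Additional heuristics
        has_second_person = bool(re.search(r'\byou\b', response_lower))
        # `^` without re.MULTILINE anchors at the start of the whole response, so only
        # an imperative that opens the response is detected; the verbs also match as
        # prefixes (for example, "seems" matches "see")
        has_imperative = bool(re.search(r'^(try|consider|avoid|stop|start|call|see|talk|get)', response_lower))
        has_modal_verbs = bool(re.search(r'\b(should|must|need to|have to|ought to)\b', response_lower))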

        # Scoring logic
        total_advice_score = advice_score
        if has_second_person and (has_imperative or has_modal_verbs):
            total_advice_score += 2

        # Classification decision
        is_advice = total_advice_score > info_score and total_advice_score >= 1
        confidence = min(abs(total_advice_score - info_score) / max(len(response.split()), 1), 1.0)

        return {
            'label': 'advice' if is_advice else 'not_advice',
            'confidence': confidence,
            'advice_score': total_advice_score,
            'info_score': info_score,
            'matched_advice_keywords': matched_advice_keywords,
            'matched_info_patterns': matched_info_patterns,
            'has_second_person': has_second_person,
            'has_imperative': has_imperative,
            'has_modal_verbs': has_modal_verbs
        }

    def process_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Process the raw dataset to create labeled training data.

        Args:
            df: Raw dataset DataFrame

        Returns:
            Processed DataFrame with labels and features
        """
        print("Processing dataset for advice detection...")

        processed_data = []

        for idx, row in df.iterrows():
            context = row.get('Context', '')
            response = row.get('Response', '')

            if not response or len(response.strip()) < 10:
                continue

            # Classify the response
            classification = self.classify_response(response)

            # Create processed record
            processed_record = {
                'id': idx,
                'context': context,
                'response': response,
                'text': response,  # Main text for classification
                'label': classification['label'],
                'confidence': classification['confidence'],
                'advice_score': classification['advice_score'],
                'info_score': classification['info_score'],
                'response_length': len(response),
                'word_count': len(response.split()),
                'has_second_person': classification['has_second_person'],
                'has_imperative': classification['has_imperative'],
                'has_modal_verbs': classification['has_modal_verbs'],
                'matched_advice_keywords': json.dumps(classification['matched_advice_keywords']),
                'matched_info_patterns': json.dumps(classification['matched_info_patterns'])
            }

            processed_data.append(processed_record)

        processed_df = pd.DataFrame(processed_data)
        print(f"Processed {len(processed_df)} responses")

        return processed_df

    def create_train_val_test_splits(self, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Create train/validation/test splits with stratification.

        Args:
            df: dataset DataFrame

        Returns:
            Tuple of (train_df, val_df, test_df)
        """
        print("Creating train/validation/test splits...")

        # First split: train (70%) and temp (30%)
        train_df, temp_df = train_test_split(
            df,
            test_size=0.3,
            random_state=42,
            stratify=df['label']
        )

        # Second split: validation (15%) and test (15%) from temp (30%)
        val_df, test_df = train_test_split(
            temp_df,
            test_size=0.5,
            random_state=42,
            stratify=temp_df['label']
        )

        print(f"Train set: {len(train_df)} examples")
        print(f"Validation set: {len(val_df)} examples")
        print(f"Test set: {len(test_df)} examples")

        return train_df, val_df, test_df

    def save_datasets(self, train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame):
        """Save the processed datasets to files."""
        print("Saving processed datasets...")

        # Save as CSV
        train_df.to_csv(self.output_dir / "train.csv", index=False)
        val_df.to_csv(self.output_dir / "validation.csv", index=False)
        test_df.to_csv(self.output_dir / "test.csv", index=False)

        # Save as JSON for easy loading
        train_df.to_json(self.output_dir / "train.json", orient='records', indent=2)
        val_df.to_json(self.output_dir / "validation.json", orient='records', indent=2)
        test_df.to_json(self.output_dir / "test.json", orient='records', indent=2)

        # Save dataset statistics
        stats = {
            'total_examples': len(train_df) + len(val_df) + len(test_df),
            'train_size': len(train_df),
            'val_size': len(val_df),
            'test_size': len(test_df),
            'train_advice_ratio': (train_df['label'] == 'advice').mean(),
            'val_advice_ratio': (val_df['label'] == 'advice').mean(),
            'test_advice_ratio': (test_df['label'] == 'advice').mean(),
            'avg_response_length': train_df['response_length'].mean(),
            'avg_word_count': train_df['word_count'].mean()
        }

        with open(self.output_dir / "dataset_stats.json", 'w') as f:
            json.dump(stats, f, indent=2)

        print(f"Datasets saved to {self.output_dir}")
        print(f"Dataset statistics: {stats}")

    def run_full_pipeline(self, target_size: int = 2000):
        """Run the complete data processing pipeline.

        Note: `target_size` is currently unused; every processed response is kept.
        """
        print("Starting full data processing pipeline...")

        # Step 1: Load raw dataset
        raw_df = self.load_counseling_dataset()

        # Step 2: Process and classify responses
        processed_df = self.process_dataset(raw_df)

        advice_df = processed_df[processed_df['label'] == 'advice']
        not_advice_df = processed_df[processed_df['label'] == 'not_advice']

        print(f"Original distribution - Advice: {len(advice_df)}, Not Advice: {len(not_advice_df)}")

        # Step 3: Create train/val/test splits
        train_df, val_df, test_df = self.create_train_val_test_splits(processed_df)

        # Step 4: Save datasets
        self.save_datasets(train_df, val_df, test_df)

        print("Data processing pipeline completed successfully!")

        return train_df, val_df, test_df


In [None]:
processor = MentalHealthDataProcessor()
train_df, val_df, test_df = processor.run_full_pipeline(target_size=2000)

print("=======================================")
print("DATA PROCESSING COMPLETED")
print("=======================================")
print(f"Train set: {len(train_df)} examples")
print(f"Validation set: {len(val_df)} examples")
print(f"Test set: {len(test_df)} examples")
print("Files saved to: processed_data")
print("=======================================")


Starting full data processing pipeline...
Loading Amod/mental_health_counseling_conversations dataset...


Loaded 3512 conversations from HuggingFace dataset
Processing dataset for advice detection...
Processed 3507 responses
Original distribution - Advice: 1656, Not Advice: 1851
Creating train/validation/test splits...
Train set: 2454 examples
Validation set: 526 examples
Test set: 527 examples
Saving processed datasets...
Datasets saved to processed_data
Dataset statistics: {'total_examples': 3507, 'train_size': 2454, 'val_size': 526, 'test_size': 527, 'train_advice_ratio': np.float64(0.47229013854930724), 'val_advice_ratio': np.float64(0.4714828897338403), 'test_advice_ratio': np.float64(0.47248576850094876), 'avg_response_length': np.float64(1027.0664221678892), 'avg_word_count': np.float64(178.25509372453138)}
Data processing pipeline completed successfully!
DATA PROCESSING COMPLETED
Train set: 2454 examples
Validation set: 526 examples
Test set: 527 examples
Files saved to: processed_data/
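
The split sizes reported above can be checked by hand. The sketch below assumes scikit-learn rounds a fractional `test_size` up and gives the remainder to the other part; `split_sizes` is a hypothetical helper, not part of the pipeline.

```python
import math


def split_sizes(n_samples: int, test_fraction: float) -> tuple[int, int]:
    """Return (n_first, n_second), rounding the second part up (assumed rule)."""
    n_second = math.ceil(test_fraction * n_samples)
    return n_samples - n_second, n_second


n_train, n_temp = split_sizes(3507, 0.3)  # first split: 70% train, 30% held out
n_val, n_test = split_sizes(n_temp, 0.5)  # second split: held-out part cut in half
print(n_train, n_val, n_test)  # prints: 2454 526 527
```

The three numbers match the logged train, validation, and test counts, so the odd-sized validation and test sets are a rounding effect and not lost rows.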


# PyTorch Lightning implementation for Mental Health Advice Guard
Fine-tunes a small transformer model for binary classification.

As the example classification model I chose `https://huggingface.co/nlptown/bert-base-multilingual-uncased-sentiment` because it is one of the most popular BERT-based sentiment analysis models. The model has 167M parameters, which is not too large.

Another reason for choosing this model is that it was fine-tuned on six languages: English, Dutch, German, French, Spanish, and Italian.


## Training config

In [None]:
@dataclass
class LightningConfig:
    """Configuration for PyTorch Lightning model."""

    model_name: str = field(default="nlptown/bert-base-multilingual-uncased-sentiment")
    """Model name"""

    max_length: int = field(default=256)
    """Max token length"""

    batch_size: int = field(default=64)
    """Batch size"""

    learning_rate: float = field(default=2e-5)
    """Learning rate"""

    weight_decay: float = field(default=0.01)
    """Weight decay"""

    warmup_ratio: float = field(default=0.1)
    """Warmup ratio"""

    max_epochs: int = field(default=5)
    """Max epochs"""

    patience: int = field(default=3)
    """Patience"""

    num_workers: int = field(default=4)
    """Number of workers"""

    seed: int = field(default=42)
    """Seed"""

    output_dir: str = field(default="lightning_models")
    """Output directory"""

    # MLflow configuration
    mlflow_experiment_name: str = field(default="mental_health_advice_guard_experiment")
    """MLflow experiment name"""

    mlflow_tracking_uri: str = field(default="./mlruns")
    """MLflow tracking URI"""

    mlflow_run_name: str = field(default="bert_multilingual_baseline")
    """MLflow run name"""

    enable_mlflow: bool = field(default=True)
    """Enable MLflow"""

## MentalHealthDataset class

In [None]:
@dataclass
class MentalHealthDataset(Dataset):
    """Custom dataset for mental health advice classification."""

    texts: list[str]
    """List of texts"""

    labels: list[int]
    """List of labels"""

    tokenizer: Any
    """Tokenizer"""

    max_length: int = field(default=256)
    """Max length"""

    def __len__(self):
        """Return the length of the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Get an item from the dataset."""
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
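
What `truncation=True` with `padding='max_length'` does to each example can be sketched in plain Python. This is a simplification under stated assumptions: the token ids are made up, `pad_id=0` is assumed, and a real tokenizer also keeps its special tokens in place when it truncates. `pad_or_truncate` is a hypothetical helper.

```python
def pad_or_truncate(token_ids: list[int], max_length: int,
                    pad_id: int = 0) -> tuple[list[int], list[int]]:
    """Return (input_ids, attention_mask), both exactly max_length long."""
    kept = token_ids[:max_length]                   # truncation
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad             # right padding
    attention_mask = [1] * len(kept) + [0] * n_pad  # 1 = real token, 0 = padding
    return input_ids, attention_mask


ids, mask = pad_or_truncate([101, 7592, 2088, 102], max_length=6)
print(ids)   # prints: [101, 7592, 2088, 102, 0, 0]
print(mask)  # prints: [1, 1, 1, 1, 0, 0]
```

Every example therefore has the same length, which lets the default `DataLoader` collation stack them into a batch without a custom collate function.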

## Classification model using pytorch lightning

In [None]:
class MentalHealthClassifier(pl.LightningModule):
    """PyTorch Lightning module for mental health advice classification."""

    def __init__(self, config: LightningConfig, num_training_steps: int = None):
        super().__init__()
        self.config = config
        self.num_training_steps = num_training_steps

        # Save hyperparameters
        self.save_hyperparameters()

        # Load model and tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        model_config = AutoConfig.from_pretrained(config.model_name)
        self.transformer = AutoModel.from_pretrained(config.model_name, config=model_config)

        # Classification head
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.transformer.config.hidden_size, 2)  # Binary classification

        # Loss function with class weights for imbalanced data
        self.criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.48, 0.52]))  # Weight advice class higher

        # Metrics storage
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, input_ids, attention_mask):
        # Get transformer outputs
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)

        # Use [CLS] token representation
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token

        # Apply dropout and classification layer
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)

        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()

        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)

        # Get predictions and probabilities
        preds = torch.argmax(logits, dim=1)
        probs = torch.softmax(logits, dim=1)

        # Store outputs for epoch-end calculations
        self.validation_step_outputs.append({
            'loss': loss,
            'preds': preds,
            'probs': probs,
            'labels': labels
        })

        return loss

    def on_validation_epoch_end(self):
        # Aggregate all validation outputs
        all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs])
        all_probs = torch.cat([x['probs'] for x in self.validation_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.validation_step_outputs])
        avg_loss = torch.stack([x['loss'] for x in self.validation_step_outputs]).mean()

        # Calculate metrics
        acc = (all_preds == all_labels).float().mean()

        # Convert to numpy for sklearn metrics
        y_true = all_labels.cpu().numpy()
        y_pred = all_preds.cpu().numpy()
        y_prob = all_probs[:, 1].cpu().numpy()  # Probability of advice class

        # Calculate precision, recall, f1
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # Calculate AUC-ROC
        try:
            auc = roc_auc_score(y_true, y_prob)
        except ValueError:
            auc = 0.0  # In case of single class in validation

        # Log all metrics
        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_precision', precision)
        self.log('val_recall', recall)
        self.log('val_f1', f1)
        self.log('val_auc', auc)

        # Clear outputs
        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)

        # Get predictions and probabilities
        preds = torch.argmax(logits, dim=1)
        probs = torch.softmax(logits, dim=1)

        # Store outputs
        self.test_step_outputs.append({
            'loss': loss,
            'preds': preds,
            'probs': probs,
            'labels': labels
        })

        return loss

    def on_test_epoch_end(self):
        # Aggregate all test outputs
        all_preds = torch.cat([x['preds'] for x in self.test_step_outputs])
        all_probs = torch.cat([x['probs'] for x in self.test_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.test_step_outputs])
        avg_loss = torch.stack([x['loss'] for x in self.test_step_outputs]).mean()

        # Calculate metrics
        acc = (all_preds == all_labels).float().mean()

        # Convert to numpy for sklearn metrics
        y_true = all_labels.cpu().numpy()
        y_pred = all_preds.cpu().numpy()
        y_prob = all_probs[:, 1].cpu().numpy()

        # Calculate detailed metrics
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=None, zero_division=0
        )

        # Calculate weighted averages
        precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # Guard against a single-class test set, as in validation
        try:
            auc = roc_auc_score(y_true, y_prob)
        except ValueError:
            auc = 0.0

        # Log test metrics
        self.log('test_loss', avg_loss)
        self.log('test_acc', acc)
        self.log('test_precision_weighted', precision_weighted)
        self.log('test_recall_weighted', recall_weighted)
        self.log('test_f1_weighted', f1_weighted)
        self.log('test_auc', auc)

        # Log per-class metrics
        self.log('test_precision_not_advice', precision[0])
        self.log('test_precision_advice', precision[1])
        self.log('test_recall_not_advice', recall[0])
        self.log('test_recall_advice', recall[1])
        self.log('test_f1_not_advice', f1[0])
        self.log('test_f1_advice', f1[1])

        # Store results for later use
        self.test_results = {
            'accuracy': acc.item(),
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted,
            'auc': auc,
            'precision_per_class': precision.tolist(),
            'recall_per_class': recall.tolist(),
            'f1_per_class': f1.tolist(),
            'predictions': y_pred.tolist(),
            'probabilities': y_prob.tolist(),
            'true_labels': y_true.tolist()
        }

        # Clear outputs
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        # AdamW optimizer
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )

        if self.num_training_steps:
            # Linear warmup scheduler
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(self.num_training_steps * self.config.warmup_ratio),
                num_training_steps=self.num_training_steps
            )

            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': scheduler,
                    'interval': 'step',
                    'frequency': 1
                }
            }

        return optimizer

    def predict(self, text: str) -> dict[str, Any]:
        """
        Predict on a single text input.

        Args:
            text: Input text to classify

        Returns:
            Dictionary with prediction results
        """
        self.eval()

        # Tokenize input
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.config.max_length,
            return_tensors='pt'
        )

        # Move to device
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        # Get prediction
        with torch.no_grad():
            logits = self(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(logits, dim=1)

        # Convert to labels
        label = "advice" if pred.item() == 1 else "not_advice"
        confidence = probs.max().item()
        advice_prob = probs[0, 1].item()

        return {
            'label': label,
            'confidence': confidence,
            'advice_probability': advice_prob,
            'prediction': pred.item()
        }

    def guard_decide(self, text: str) -> dict[str, Any]:
        """
        Implement the required guard_decide function.

        Args:
            text: Input text to classify

        Returns:
            Dictionary with label, rationale, action, and response_template
        """
        # Get prediction
        result = self.predict(text)

        label = result['label']
        confidence = result['confidence']
        advice_prob = result['advice_probability']

        # Generate rationale
        if label == "advice":
            rationale = f"Text classified as personalized mental health advice (confidence: {confidence:.3f}, advice probability: {advice_prob:.3f}). "
            rationale += "Contains directive language or specific recommendations for an individual."
        else:
            rationale = f"Text classified as general information (confidence: {confidence:.3f}, advice probability: {advice_prob:.3f}). "
            rationale += "Appears to be educational content or empathetic response without specific advice."

        # Determine action based on confidence and label
        if label == "advice" and confidence > 0.85:
            action = "block"
            response_template = ("I understand you're looking for guidance, but I can't provide "
                               "personalized mental health advice. Please consider speaking with "
                               "a qualified mental health professional who can provide appropriate "
                               "support for your specific situation.")
        elif label == "advice" and confidence > 0.65:
            action = "flag"
            response_template = ("This response may contain personalized advice. Please review "
                               "before sending and consider directing the user to professional "
                               "mental health resources if appropriate.")
        else:
            action = "allow"
            response_template = None

        return {
            'label': label,
            'confidence': confidence,
            'advice_probability': advice_prob,
            'rationale': rationale,
            'action': action,
            'response_template': response_template
        }
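
The action policy inside `guard_decide` can be isolated from the model and checked on its own. A minimal sketch, with `decide_action` and the two constant names being hypothetical; the threshold values are the ones used above.

```python
BLOCK_THRESHOLD = 0.85  # values taken from guard_decide above
FLAG_THRESHOLD = 0.65


def decide_action(label: str, confidence: float) -> str:
    """Map a predicted label and its confidence to block, flag, or allow."""
    if label == "advice" and confidence > BLOCK_THRESHOLD:
        return "block"
    if label == "advice" and confidence > FLAG_THRESHOLD:
        return "flag"
    return "allow"


for label, confidence in [("advice", 0.90), ("advice", 0.70),
                          ("advice", 0.60), ("not_advice", 0.99)]:
    print(label, confidence, decide_action(label, confidence))
```

Since the confidence is the larger of two softmax probabilities, it is never below 0.5, so an `advice` prediction with confidence between 0.5 and 0.65 is allowed through. Whether that is acceptable depends on how costly a missed advice response is.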

## Trainer class

In [None]:
@dataclass
class MentalHealthLightningTrainer:
    """Trainer class for the PyTorch Lightning model."""
    config: LightningConfig = None

    def __post_init__(self):
        self.config = self.config or LightningConfig()
        self.output_dir = Path(self.config.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # Set random seeds
        pl.seed_everything(self.config.seed)

        # Initialize MLflow
        if self.config.enable_mlflow:
            mlflow.set_tracking_uri(self.config.mlflow_tracking_uri)
            mlflow.set_experiment(self.config.mlflow_experiment_name)

        self.model = None
        self.trainer = None
        self.tokenizer = None

    def load_data(self, data_dir: str = "processed_data") -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Load processed datasets."""
        data_path = Path(data_dir)

        if not data_path.exists():
            print("Processed data not found. Running data processing pipeline...")
            processor = MentalHealthDataProcessor(output_dir=data_dir)
            return processor.run_full_pipeline()

        train_df = pd.read_csv(data_path / "train.csv")
        val_df = pd.read_csv(data_path / "validation.csv")
        test_df = pd.read_csv(data_path / "test.csv")

        print("Loaded datasets:")
        print(f"  Train: {len(train_df)} examples")
        print(f"  Validation: {len(val_df)} examples")
        print(f"  Test: {len(test_df)} examples")

        return train_df, val_df, test_df

    def prepare_data_loaders(self, train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame) -> tuple[DataLoader, DataLoader, DataLoader]:
        """Prepare PyTorch data loaders."""
        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self.tokenizer = tokenizer

        # Convert labels to binary
        label_map = {'advice': 1, 'not_advice': 0}

        # Create datasets
        train_dataset = MentalHealthDataset(
            texts=train_df['text'].tolist(),
            labels=train_df['label'].map(label_map).tolist(),
            tokenizer=tokenizer,
            max_length=self.config.max_length
        )

        val_dataset = MentalHealthDataset(
            texts=val_df['text'].tolist(),
            labels=val_df['label'].map(label_map).tolist(),
            tokenizer=tokenizer,
            max_length=self.config.max_length
        )

        test_dataset = MentalHealthDataset(
            texts=test_df['text'].tolist(),
            labels=test_df['label'].map(label_map).tolist(),
            tokenizer=tokenizer,
            max_length=self.config.max_length
        )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        return train_loader, val_loader, test_loader

    def train_model(self, train_loader: DataLoader, val_loader: DataLoader) -> MentalHealthClassifier:
        """Train the PyTorch Lightning model."""
        print("=================================")
        print("TRAINING PYTORCH LIGHTNING MODEL")
        print("=================================")

        # Calculate training steps
        num_training_steps = len(train_loader) * self.config.max_epochs

        # Initialize model
        model = MentalHealthClassifier(self.config, num_training_steps)

        # Setup callbacks
        checkpoint_callback = ModelCheckpoint(
            dirpath=self.output_dir / "checkpoints",
            filename='best-model-{epoch:02d}-{val_f1:.3f}',
            save_top_k=1,
            verbose=True,
            monitor='val_f1',
            mode='max'
        )

        early_stopping = EarlyStopping(
            monitor='val_f1',
            mode='max',
            patience=self.config.patience,
            verbose=True
        )

        # Setup MLflow logger
        logger = None
        if self.config.enable_mlflow:
            logger = MLFlowLogger(
                experiment_name=self.config.mlflow_experiment_name,
                tracking_uri=self.config.mlflow_tracking_uri,
                run_name=self.config.mlflow_run_name
            )

        # Initialize trainer
        trainer = pl.Trainer(
            max_epochs=self.config.max_epochs,
            callbacks=[checkpoint_callback, early_stopping],
            logger=logger,
            accelerator='auto',  # Automatically use GPU if available
            devices='auto',
            precision='16-mixed',  # Mixed precision for faster training (Lightning 2.x name for precision=16)
            gradient_clip_val=1.0,
            log_every_n_steps=10,
            val_check_interval=0.5,  # Validate twice per epoch
            enable_progress_bar=True
        )

        # Train model
        print("Training on device:", trainer.strategy.root_device)
        trainer.fit(model, train_loader, val_loader)

        # Load best model
        best_model = MentalHealthClassifier.load_from_checkpoint(
            checkpoint_callback.best_model_path,
            config=self.config,
            num_training_steps=num_training_steps
        )

        self.model = best_model
        self.trainer = trainer

        return best_model

    def evaluate_model(self, test_loader: DataLoader) -> dict[str, Any]:
        """Evaluate the trained model."""
        print("=================================")
        print("EVALUATING MODEL ON TEST SET")
        print("=================================")

        if self.model is None or self.trainer is None:
            raise ValueError("Model must be trained before evaluation")

        # Test the model
        self.trainer.test(self.model, test_loader)

        # Get results
        results = self.model.test_results

        # Print results
        print("\nTest Results:")
        print(f"  Accuracy: {results['accuracy']:.4f}")
        print(f"  AUC-ROC: {results['auc']:.4f}")
        print(f"  Weighted Precision: {results['precision_weighted']:.4f}")
        print(f"  Weighted Recall: {results['recall_weighted']:.4f}")
        print(f"  Weighted F1-Score: {results['f1_weighted']:.4f}")
        print("\nPer-class Results:")
        print(f"  Not Advice - Precision: {results['precision_per_class'][0]:.4f}, Recall: {results['recall_per_class'][0]:.4f}, F1: {results['f1_per_class'][0]:.4f}")
        print(f"  Advice - Precision: {results['precision_per_class'][1]:.4f}, Recall: {results['recall_per_class'][1]:.4f}, F1: {results['f1_per_class'][1]:.4f}")

        return results

    def save_model(self, model_name: str = "best_model"):
        """Save the trained model and tokenizer."""
        if self.model is None or self.trainer is None:
            raise ValueError("No trained model to save; run train_model() first")

        model_path = self.output_dir / model_name
        model_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.trainer.save_checkpoint(model_path / "model.ckpt")

        # Save tokenizer
        self.tokenizer.save_pretrained(model_path / "tokenizer")

        # Save config
        with open(model_path / "config.json", 'w') as f:
            json.dump(self.config.__dict__, f, indent=2)

        print(f"Model saved to {model_path}")

        # Log model to MLflow if enabled
        if self.config.enable_mlflow:
            self._log_model_to_mlflow(model_path)

    def _log_model_to_mlflow(self, model_path: Path):
        """Log model artifacts and metadata to MLflow."""
        try:
            # Start a nested run to log the model artifacts
            with mlflow.start_run(nested=True) as run:
                # Log hyperparameters
                mlflow.log_params({
                    "model_name": self.config.model_name,
                    "max_length": self.config.max_length,
                    "batch_size": self.config.batch_size,
                    "learning_rate": self.config.learning_rate,
                    "weight_decay": self.config.weight_decay,
                    "warmup_ratio": self.config.warmup_ratio,
                    "max_epochs": self.config.max_epochs,
                    "patience": self.config.patience,
                    "seed": self.config.seed
                })

                # Log model artifacts
                mlflow.log_artifacts(str(model_path), "model")

                # Create a sample input for model signature
                sample_text = "You should see a therapist for your depression."
                sample_input = np.array([sample_text])

                # Get model prediction for signature
                prediction = self.model.predict(sample_text)
                sample_output = np.array([prediction['label']])

                # Infer and log model signature
                signature = infer_signature(sample_input, sample_output)

                # Log the PyTorch model
                mlflow.pytorch.log_model(
                    pytorch_model=self.model,
                    name="pytorch_model",  # MLflow 3.x: `name` replaces the deprecated `artifact_path`
                    signature=signature,
                    registered_model_name="mental_health_advice_guard"
                )

                print(f"Model logged to MLflow with run ID: {run.info.run_id}")

        except Exception as e:
            print(f"Warning: Failed to log model to MLflow: {e}")

    def deploy_model_to_mlflow(self, model_version: str = "latest"):
        """Deploy model using MLflow Model Registry."""
        if not self.config.enable_mlflow:
            print("MLflow is not enabled. Cannot deploy model.")
            return None

        try:
            # Get the model from MLflow registry
            model_uri = f"models:/mental_health_advice_guard/{model_version}"

            # Load model for serving
            loaded_model = mlflow.pytorch.load_model(model_uri)

            print(f"Model deployed from MLflow registry: {model_uri}")
            return loaded_model

        except Exception as e:
            print(f"Error deploying model from MLflow: {e}")
            return None

    def run_full_pipeline(self):
        """Run the complete training and evaluation pipeline."""
        print("MENTAL HEALTH ADVICE GUARD - PYTORCH LIGHTNING TRAINING")
        print("="*60)

        # Start MLflow run if enabled
        if self.config.enable_mlflow:
            mlflow.start_run(run_name=self.config.mlflow_run_name)

        try:
            # Load data
            train_df, val_df, test_df = self.load_data()

            # Log dataset info to MLflow
            if self.config.enable_mlflow:
                mlflow.log_metrics({
                    "train_size": len(train_df),
                    "val_size": len(val_df),
                    "test_size": len(test_df)
                })

            # Prepare data loaders
            train_loader, val_loader, test_loader = self.prepare_data_loaders(train_df, val_df, test_df)

            # Train model
            self.train_model(train_loader, val_loader)

            # Evaluate model
            results = self.evaluate_model(test_loader)

            # Log evaluation metrics to MLflow
            if self.config.enable_mlflow:
                mlflow.log_metrics({
                    "test_accuracy": results['accuracy'],
                    "test_auc": results['auc'],
                    "test_precision_weighted": results['precision_weighted'],
                    "test_recall_weighted": results['recall_weighted'],
                    "test_f1_weighted": results['f1_weighted'],
                    "test_precision_advice": results['precision_per_class'][1],
                    "test_recall_advice": results['recall_per_class'][1],
                    "test_f1_advice": results['f1_per_class'][1]
                })

            # Save model
            self.save_model()

        finally:
            # End MLflow run
            if self.config.enable_mlflow:
                mlflow.end_run()

        print("=============================")
        print("TESTING GUARD_DECIDE FUNCTION")
        print("=============================")

        test_examples = [
            "You should definitely see a therapist for your depression symptoms.",
            "Depression is a common mental health condition that affects many people.",
            "I recommend you stop taking your medication immediately.",
            "Many people find that exercise can help improve their mood.",
            "You need to call a crisis hotline right now - this sounds urgent."
        ]

        for example in test_examples:
            result = self.model.guard_decide(example)
            print(f"\nText: {example}")
            print(f"Label: {result['label']} (confidence: {result['confidence']:.3f})")
            print(f"Action: {result['action']}")
            print(f"Rationale: {result['rationale']}")
            if result['response_template']:
                print(f"Response Template: {result['response_template'][:100]}...")

        print("=====================================")
        print("PYTORCH LIGHTNING TRAINING COMPLETED!")
        print("=====================================")

        return results
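
As a sanity check on the numbers `evaluate_model` prints, the sketch below recomputes per-class precision, recall, and F1 on a toy label set and then forms the support-weighted F1. It assumes the `*_weighted` entries in `test_results` follow the usual support-weighted averaging (the convention scikit-learn calls `average='weighted'`). The helper names `per_class_prf` and `weighted_f1` and the toy labels are illustrative and not part of the notebook.

```python
def per_class_prf(y_true, y_pred, labels=(0, 1)):
    """Return {label: (precision, recall, f1, support)} for each class."""
    stats = {}
    for c in labels:
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        stats[c] = (precision, recall, f1, tp + fn)  # support = true count
    return stats


def weighted_f1(stats):
    """Average the per-class F1 scores, weighting each by its support."""
    total = sum(s[3] for s in stats.values())
    return sum(s[2] * s[3] for s in stats.values()) / total


# Toy data using the notebook's encoding: 1 = advice, 0 = not_advice.
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 0, 1]

stats = per_class_prf(y_true, y_pred)
for label, (p, r, f1, support) in stats.items():
    print(f"class {label}: P={p:.4f} R={r:.4f} F1={f1:.4f} support={support}")
print(f"weighted F1: {weighted_f1(stats):.4f}")  # 0.7500
```

Because the weights are class supports, the weighted score leans toward the larger class, which is why the per-class advice metrics are worth reading alongside it.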

## Run the PyTorch Lightning training pipeline

In [None]:
config = LightningConfig()

trainer = MentalHealthLightningTrainer(config)
results = trainer.run_full_pipeline()


INFO:lightning_fabric.utilities.seed:Seed set to 42


MENTAL HEALTH ADVICE GUARD - PYTORCH LIGHTNING TRAINING
Loaded datasets:
  Train: 2454 examples
  Validation: 526 examples
  Test: 527 examples
TRAINING PYTORCH LIGHTNING MODEL


INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training on device: cuda:0


INFO:pytorch_lightning.callbacks.model_summary:
  | Name        | Type             | Params | Mode 
---------------------------------------------------------
0 | transformer | BertModel        | 167 M  | eval 
1 | dropout     | Dropout          | 0      | train
2 | classifier  | Linear           | 1.5 K  | train
3 | criterion   | CrossEntropyLoss | 0      | train
---------------------------------------------------------
167 M     Trainable params
0         Non-trainable params
167 M     Total params
669.432   Total estimated model params size (MB)
3         Modules in train mode
228       Modules in eval mode


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved. New best score: 0.705
INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 38: 'val_f1' reached 0.70529 (best 0.70529), saving model to '/content/lightning_models/checkpoints/best-model-epoch=00-val_f1=0.705.ckpt' as top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.105 >= min_delta = 0.0. New best score: 0.810
INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 76: 'val_f1' reached 0.81048 (best 0.81048), saving model to '/content/lightning_models/checkpoints/best-model-epoch=00-val_f1=0.810.ckpt' as top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.083 >= min_delta = 0.0. New best score: 0.893
INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 115: 'val_f1' reached 0.89330 (best 0.89330), saving model to '/content/lightning_models/checkpoints/best-model-epoch=01-val_f1=0.893.ckpt' as top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.034 >= min_delta = 0.0. New best score: 0.928
INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 153: 'val_f1' reached 0.92770 (best 0.92770), saving model to '/content/lightning_models/checkpoints/best-model-epoch=01-val_f1=0.928.ckpt' as top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.017 >= min_delta = 0.0. New best score: 0.945
INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 192: 'val_f1' reached 0.94490 (best 0.94490), saving model to '/content/lightning_models/checkpoints/best-model-epoch=02-val_f1=0.945.ckpt' as top 1


INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 230: 'val_f1' was not in top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.947
INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 269: 'val_f1' reached 0.94674 (best 0.94674), saving model to '/content/lightning_models/checkpoints/best-model-epoch=03-val_f1=0.947.ckpt' as top 1


INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 307: 'val_f1' was not in top 1


INFO:pytorch_lightning.callbacks.early_stopping:Metric val_f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.949
INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 346: 'val_f1' reached 0.94867 (best 0.94867), saving model to '/content/lightning_models/checkpoints/best-model-epoch=04-val_f1=0.949.ckpt' as top 1


INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 384: 'val_f1' was not in top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


EVALUATING MODEL ON TEST SET



Test Results:
  Accuracy: 0.9298
  AUC-ROC: 0.9854
  Weighted Precision: 0.9299
  Weighted Recall: 0.9298
  Weighted F1-Score: 0.9298

Per-class Results:
  Not Advice - Precision: 0.9382, Recall: 0.9281, F1: 0.9331
  Advice - Precision: 0.9206, Recall: 0.9317, F1: 0.9261
Model saved to lightning_models/best_model


Registered model 'mental_health_advice_guard' already exists. Creating a new version of this model...
Created version '2' of model 'mental_health_advice_guard'.


Model logged to MLflow with run ID: d97f3dd188314045b67a9ae7362317db

TESTING GUARD_DECIDE FUNCTION

Text: You should definitely see a therapist for your depression symptoms.
Label: advice (confidence: 0.999)
Action: block
Rationale: Text classified as personalized mental health advice (confidence: 0.999, advice probability: 0.999). Contains directive language or specific recommendations for an individual.
Response Template: I understand you're looking for guidance, but I can't provide personalized mental health advice. Ple...

Text: Depression is a common mental health condition that affects many people.
Label: not_advice (confidence: 0.996)
Action: allow
Rationale: Text classified as general information (confidence: 0.996, advice probability: 0.004). Appears to be educational content or empathetic response without specific advice.

Text: I recommend you stop taking your medication immediately.
Label: advice (confidence: 0.974)
Action: block
Rationale: Text classified as personalized 

## Deploy and test model using MLflow

In [None]:
class MLflowModelServer:
    """MLflow model server for Mental Health Advice Guard."""

    def __init__(self, model_uri: str = "models:/mental_health_advice_guard/latest"):
        """
        Initialize the MLflow model server.

        Args:
            model_uri: MLflow model URI (e.g., "models:/model_name/version")
        """
        self.model_uri = model_uri
        self.model = None
        self.load_model()

    def load_model(self):
        """Load model from MLflow registry."""
        print(f"Loading model from: {self.model_uri}")
        self.model = mlflow.pytorch.load_model(self.model_uri)
        print("Model loaded successfully!")


def quick_test(model_uri: str = "models:/mental_health_advice_guard/latest"):
    """
    Quick test function for notebook cells.

    Args:
        model_uri: MLflow model URI

    Example usage in notebook:
        quick_test()
    """
    print("🧪 Running quick test of the mental health advice guard...")

    try:
        # The constructor raises on failure (handled by the except below), so no None check is needed
        guard = MLflowModelServer(model_uri)

        # Test examples
        test_cases = [
            ("✅ General info", "Depression is a common mental health condition."),
            ("⚠️  Personalized advice", "You should stop taking your antidepressants."),
            ("✅ Educational", "Exercise can help improve mood in general."),
            ("⚠️  Crisis advice", "You need to call 911 right now."),
            # Italian for "I started taking antidepressants."
            ("⚠️  Personalized advice", "ho iniziato a prendere antidepressivi.")
        ]

        print("\n📊 Test Results:")
        print("=================================")

        for category, text in test_cases:
            result = guard.model.guard_decide(text)
            action_emoji = {"allow": "✅", "flag": "⚠️", "block": "🚫"}.get(result['action'], "❓")

            print(f"\n{category}")
            print(f"Text: '{text}'")
            print(f"Result: {action_emoji} {result['action'].upper()} (confidence: {result.get('confidence', 0):.2f})")
            print(f"Rationale: {result['rationale']}")

        print("\n🎉 Quick test completed!")

    except Exception as e:
        print(f"❌ Test failed: {e}")
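
The `action_emoji` lookup in `quick_test` expects three possible actions: `allow`, `flag`, and `block`. How `guard_decide` chooses among them is defined elsewhere in the notebook and is not shown in this cell. The sketch below is one possible mapping from the advice probability to an action, not the notebook's actual rule. The function name `decide_action` and both thresholds are hypothetical.

```python
def decide_action(advice_prob: float,
                  block_at: float = 0.8,   # hypothetical threshold
                  flag_at: float = 0.4) -> str:  # hypothetical threshold
    """Map an advice probability in [0, 1] to 'allow', 'flag', or 'block'."""
    if not 0.0 <= advice_prob <= 1.0:
        raise ValueError("advice_prob must be between 0 and 1")
    if not flag_at <= block_at:
        raise ValueError("flag_at must not exceed block_at")
    if advice_prob >= block_at:
        return "block"   # confident advice: stop the message
    if advice_prob >= flag_at:
        return "flag"    # uncertain band: route to review
    return "allow"       # confident non-advice


for prob in (0.998, 0.5, 0.02):
    print(f"advice probability {prob:.3f} -> {decide_action(prob)}")
```

A middle band like this gives low-confidence predictions a softer outcome than a hard block; where the band should sit would need tuning on validation data.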



In [None]:
quick_test()

🧪 Running quick test of the mental health advice guard...
Loading model from: models:/mental_health_advice_guard/latest
Model loaded successfully!

📊 Test Results:

✅ General info
Text: 'Depression is a common mental health condition.'
Result: ✅ ALLOW (confidence: 0.98)
Rationale: Text classified as general information (confidence: 0.980, advice probability: 0.020). Appears to be educational content or empathetic response without specific advice.

⚠️  Personalized advice
Text: 'You should stop taking your antidepressants.'
Result: 🚫 BLOCK (confidence: 1.00)
Rationale: Text classified as personalized mental health advice (confidence: 0.998, advice probability: 0.998). Contains directive language or specific recommendations for an individual.

✅ Educational
Text: 'Exercise can help improve mood in general.'
Result: ✅ ALLOW (confidence: 0.67)
Rationale: Text classified as general information (confidence: 0.667, advice probability: 0.333). Appears to be educational content or empathetic re

## Error analysis & limitations
- Where does the detector fail?

  One limitation of this detector is that it only handles English, because the dataset contains no samples in any other language.

- What trade‑offs did you make?  

  Because no dataset existed for this specific problem, I built one from a set of keywords and phrases, which is not an ideal solution. It can serve as a baseline and be improved in the future.

- What would you try next?

  Next, I would like to generate a better dataset with more samples in different languages, and benchmark different models against each other.
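
To make the benchmarking and error-analysis steps above concrete, here is a minimal sketch of a harness that scores any detector on labelled examples and collects its false positives and false negatives for inspection. It assumes a detector is a callable that maps text to `"advice"` or `"not_advice"` (the label names used in this notebook). `evaluate_detector` and `keyword_baseline` are hypothetical names, and the keyword rule is only a stand-in so the example runs without a trained model.

```python
def evaluate_detector(predict, examples):
    """Score `predict` on (text, gold_label) pairs for the 'advice' class."""
    tp = fp = fn = 0
    false_positives, false_negatives = [], []
    for text, gold in examples:
        pred = predict(text)
        if pred == "advice" and gold == "advice":
            tp += 1
        elif pred == "advice":          # predicted advice, gold says otherwise
            fp += 1
            false_positives.append(text)
        elif gold == "advice":          # missed a true advice example
            fn += 1
            false_negatives.append(text)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"precision": precision, "recall": recall, "f1": f1,
            "false_positives": false_positives,
            "false_negatives": false_negatives}


def keyword_baseline(text):
    """Hypothetical stand-in detector: flags text containing 'you should'."""
    return "advice" if "you should" in text.lower() else "not_advice"


examples = [
    ("You should stop taking your antidepressants.", "advice"),
    ("Depression is a common mental health condition.", "not_advice"),
    ("I recommend you stop taking your medication immediately.", "advice"),
    ("Exercise can help improve mood in general.", "not_advice"),
]

detectors = {"keyword_baseline": keyword_baseline}
for name, predict in detectors.items():
    report = evaluate_detector(predict, examples)
    print(f"{name}: P={report['precision']:.2f} "
          f"R={report['recall']:.2f} F1={report['f1']:.2f}")
    print("  missed advice:", report["false_negatives"])
```

The trained model could be added to `detectors` with something like `lambda t: model.guard_decide(t)["label"]`, assuming `guard_decide` returns a dict with a `label` key as in the outputs above. The same loop would also work on a held-out set of non-English examples to measure the language limitation directly.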
