In [1]:
def prepare_training_data(self):
    """Combine data from different sources to create training datasets"""
    
    # Create a dataframe to hold features and targets
    combined_data = pd.DataFrame()
    
    # Extract country data
    countries = []
    anxiety_rates = []
    bipolar_rates = []
    schizo_rates = []
    eating_disorder_rates = []
    
    # Extract from disorders.csv
    if hasattr(self, 'disorders_df'):
        for country in self.disorders_df['Entity'].unique():
            country_data = self.disorders_df_latest[self.disorders_df_latest['Entity'] == country]
            if not country_data.empty:
                countries.append(country)
                
                # Get prevalence rates for each condition
                anxiety = country_data['Anxiety'].values[0] if 'Anxiety' in country_data.columns else np.nan
                bipolar = country_data['Bipolar'].values[0] if 'Bipolar' in country_data.columns else np.nan
                schizo = country_data['Schizophrenia'].values[0] if 'Schizophrenia' in country_data.columns else np.nan
                eating = country_data['Eating Disorders'].values[0] if 'Eating Disorders' in country_data.columns else np.nan
                
                anxiety_rates.append(anxiety)
                bipolar_rates.append(bipolar)
                schizo_rates.append(schizo)
                eating_disorder_rates.append(eating)
    
    # Create initial dataframe
    combined_data['country'] = countries
    combined_data['anxiety_rate'] = anxiety_rates
    combined_data['bipolar_rate'] = bipolar_rates
    combined_data['schizophrenia_rate'] = schizo_rates
    combined_data['eating_disorder_rate'] = eating_disorder_rates
    
    # Add coping strategy data
    if hasattr(self, 'dealing_anxiety_df'):
        for idx, row in combined_data.iterrows():
            country = row['country']
            country_coping = self.dealing_anxiety_df[self.dealing_anxiety_df['Entity'] == country]
            
            if not country_coping.empty:
                for strategy in self.coping_strategies['social'] + self.coping_strategies['lifestyle'] + \
                               self.coping_strategies['professional'] + self.coping_strategies['spiritual']:
                    if strategy in country_coping.columns:
                        combined_data.loc[idx, f"strategy_{strategy}"] = country_coping[strategy].values[0]
    
    # Create symptom-condition mapping data
    symptom_data = []
    
    # For each condition and its symptoms, create examples
    for condition, data in self.conditions.items():
        for symptom in data['symptoms']:
            # Create multiple examples with various symptom combinations
            for i in range(5):  # Create 5 examples per symptom
                # Select 1-3 additional random symptoms for this condition
                additional_symptoms = random.sample([s for s in data['symptoms'] if s != symptom], 
                                                 random.randint(1, min(3, len(data['symptoms'])-1)))
                all_symptoms = [symptom] + additional_symptoms
                
                # Create a record
                record = {'symptoms': all_symptoms, 'condition': condition}
                
                # Add random duration
                record['duration_days'] = random.choice([7, 14, 30, 90, 180, 365])
                
                # Add random country
                record['country'] = random.choice(countries)
                
                # Add to dataset
                symptom_data.append(record)
    
    symptom_df = pd.DataFrame(symptom_data)
    
    # Save the datasets
    combined_data.to_csv('ml_training_country_data.csv', index=False)
    symptom_df.to_csv('ml_training_symptom_data.csv', index=False)
    
    return combined_data, symptom_df

In [2]:
def engineer_features(self, data):
    """
    Create new features from existing data
    
    Parameters:
    data (DataFrame): Raw data with symptoms, country, etc.
    
    Returns:
    DataFrame: Engineered features
    """
    features = pd.DataFrame(index=data.index)
    
    # 1. Symptom count - how many symptoms reported
    if 'symptoms' in data.columns:
        features['symptom_count'] = data['symptoms'].apply(lambda x: len(x) if isinstance(x, list) else 0)
    
    # 2. Symptom severity score - weighted by importance
    # Define symptom weights (anxiety symptoms might be more predictive of anxiety, etc.)
    symptom_weights = {
        'panic attacks': 3,
        'hallucinations': 3,
        'delusions': 3,
        'mood swings': 2.5,
        'excessive worry': 2,
        'restlessness': 1.5,
        'fatigue': 1,
        'sleep problems': 1.5,
        # Add more symptoms with weights
    }
    
    if 'symptoms' in data.columns:
        features['symptom_severity'] = data['symptoms'].apply(
            lambda x: sum(symptom_weights.get(s, 1) for s in x) if isinstance(x, list) else 0
        )
    
    # 3. Duration factors (longer duration suggests chronic issues)
    if 'duration_days' in data.columns:
        features['short_term'] = data['duration_days'].apply(lambda x: 1 if x < 30 else 0)
        features['medium_term'] = data['duration_days'].apply(lambda x: 1 if 30 <= x < 180 else 0)
        features['long_term'] = data['duration_days'].apply(lambda x: 1 if x >= 180 else 0)
        
        # Duration severity score (log transform to reduce skew)
        features['duration_severity'] = data['duration_days'].apply(lambda x: np.log1p(x))
    
    # 4. Country-specific prevalence features
    if 'country' in data.columns and hasattr(self, 'disorders_df_latest'):
        # Create empty columns first
        features['country_anxiety_rate'] = np.nan
        features['country_bipolar_rate'] = np.nan
        features['country_schizo_rate'] = np.nan
        features['country_eating_rate'] = np.nan
        
        for idx, country in enumerate(data['country']):
            country_data = self.disorders_df_latest[self.disorders_df_latest['Entity'] == country]
            if not country_data.empty:
                # Fill in rates for this country
                if 'Anxiety' in country_data.columns:
                    features.loc[idx, 'country_anxiety_rate'] = country_data['Anxiety'].values[0]
                if 'Bipolar' in country_data.columns:
                    features.loc[idx, 'country_bipolar_rate'] = country_data['Bipolar'].values[0]
                if 'Schizophrenia' in country_data.columns:
                    features.loc[idx, 'country_schizo_rate'] = country_data['Schizophrenia'].values[0]
                if 'Eating Disorders' in country_data.columns:
                    features.loc[idx, 'country_eating_rate'] = country_data['Eating Disorders'].values[0]
    
    # 5. Interaction terms
    if 'symptom_count' in features.columns and 'duration_severity' in features.columns:
        features['symptom_duration_interaction'] = features['symptom_count'] * features['duration_severity']
    
    # 6. One-hot encode the country
    if 'country' in data.columns:
        country_dummies = pd.get_dummies(data['country'], prefix='country')
        features = pd.concat([features, country_dummies], axis=1)
    
    # Fill any remaining NAs with 0
    features = features.fillna(0)
    
    return features

In [3]:
def select_features(self, features, target, method='forest'):
    """
    Select most important features using different methods
    
    Parameters:
    features (DataFrame): Engineered features
    target (Series): Target variable to predict
    method (str): Method to use ('forest', 'correlation', or 'chi2')
    
    Returns:
    tuple: (selected_features DataFrame, feature_indices list)
    """
    print(f"Performing feature selection using {method} method...")
    
    # For classification targets, ensure they're encoded
    if target.dtype == 'object':
        from sklearn.preprocessing import LabelEncoder
        label_encoder = LabelEncoder()
        encoded_target = label_encoder.fit_transform(target)
    else:
        encoded_target = target
    
    if method == 'forest':
        # Random Forest for feature importance
        from sklearn.ensemble import RandomForestClassifier
        
        # Handle case where we have very few samples
        n_estimators = min(100, features.shape[0])
        
        rf = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
        rf.fit(features, encoded_target)
        
        # Get feature importances
        importances = rf.feature_importances_
        indices = np.argsort(importances)[::-1]
        
        # Display feature importances
        print("\nFeature ranking:")
        for i, idx in enumerate(indices[:10]):  # Show top 10
            print(f"{i+1}. Feature '{features.columns[idx]}' - importance: {importances[idx]:.4f}")
        
        # Select top 70% most important features
        cumulative_importance = 0.0
        feature_indices = []
        for idx in indices:
            cumulative_importance += importances[idx]
            feature_indices.append(idx)
            if cumulative_importance >= 0.7:
                break
        
    elif method == 'correlation':
        # Correlation analysis
        import scipy.stats as stats
        
        correlations = []
        for i, col in enumerate(features.columns):
            if features[col].dtype in ['int64', 'float64']:
                corr, _ = stats.pointbiserialr(features[col], encoded_target)
                correlations.append((i, abs(corr)))
        
        # Sort by correlation strength (descending)
        correlations.sort(key=lambda x: x[1], reverse=True)
        
        # Select top 70% features
        top_70_percent = int(len(correlations) * 0.7)
        feature_indices = [idx for idx, _ in correlations[:top_70_percent]]
        
    elif method == 'chi2':
        # Chi-square test (for categorical features)
        from sklearn.feature_selection import chi2, SelectKBest
        
        # Determine number of features to select (70%)
        k = int(features.shape[1] * 0.7)
        
        # Apply chi-square test
        selector = SelectKBest(chi2, k=k)
        selector.fit(features, encoded_target)
        feature_indices = selector.get_support(indices=True)
    
    # Get the selected features
    selected_features = features.iloc[:, feature_indices]
    
    print(f"Selected {selected_features.shape[1]} features out of {features.shape[1]}")
    return selected_features, feature_indices

In [4]:
def train_models(self, features, condition_labels, strategy_data=None):
    """
    Train the required ML models for the chatbot
    
    Parameters:
    features (DataFrame): Engineered and selected features
    condition_labels (Series): Condition labels for classification
    strategy_data (DataFrame): Data for strategy recommendation model
    
    Returns:
    dict: Performance metrics for the models
    """
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
    from sklearn.multiclass import OneVsRestClassifier
    
    print("Training machine learning models...")
    
    # Standardize numerical features
    self.scaler = StandardScaler()
    numerical_cols = features.select_dtypes(include=['int64', 'float64']).columns
    features[numerical_cols] = self.scaler.fit_transform(features[numerical_cols])
    
    # Encode condition labels
    self.label_encoder = LabelEncoder()
    y_encoded = self.label_encoder.fit_transform(condition_labels)
    self.classes_ = self.label_encoder.classes_
    
    # Split the data (70% train, 15% validation, 15% test)
    X_train, X_temp, y_train, y_temp = train_test_split(
        features, y_encoded, test_size=0.3, random_state=42)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, random_state=42)
    
    print(f"Training set size: {X_train.shape[0]}")
    print(f"Validation set size: {X_val.shape[0]}")
    print(f"Test set size: {X_test.shape[0]}")
    
    # Model 1: Condition classifier (Random Forest)
    print("\nTraining Random Forest Classifier for condition prediction...")
    self.rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    self.rf_model.fit(X_train, y_train)
    
    y_pred_rf = self.rf_model.predict(X_val)
    rf_accuracy = accuracy_score(y_val, y_pred_rf)
    rf_precision = precision_score(y_val, y_pred_rf, average='weighted')
    rf_recall = recall_score(y_val, y_pred_rf, average='weighted')
    rf_f1 = f1_score(y_val, y_pred_rf, average='weighted')
    
    print(f"Random Forest Validation Results:")
    print(f"  Accuracy: {rf_accuracy:.4f}")
    print(f"  Precision: {rf_precision:.4f}")
    print(f"  Recall: {rf_recall:.4f}")
    print(f"  F1 Score: {rf_f1:.4f}")
    
    # Model 2: Condition classifier (Logistic Regression)
    print("\nTraining Logistic Regression Classifier for condition prediction...")
    self.lr_model = OneVsRestClassifier(LogisticRegression(max_iter=1000, random_state=42))
    self.lr_model.fit(X_train, y_train)
    
    y_pred_lr = self.lr_model.predict(X_val)
    lr_accuracy = accuracy_score(y_val, y_pred_lr)
    lr_precision = precision_score(y_val, y_pred_lr, average='weighted')
    lr_recall = recall_score(y_val, y_pred_lr, average='weighted')
    lr_f1 = f1_score(y_val, y_pred_lr, average='weighted')
    
    print(f"Logistic Regression Validation Results:")
    print(f"  Accuracy: {lr_accuracy:.4f}")
    print(f"  Precision: {lr_precision:.4f}")
    print(f"  Recall: {lr_recall:.4f}")
    print(f"  F1 Score: {lr_f1:.4f}")
    
    # Model.3: Condition classifier (SVM)
    print("\nTraining SVM Classifier for condition prediction...")
    self.svm_model = SVC(probability=True, random_state=42)
    self.svm_model.fit(X_train, y_train)
    
    y_pred_svm = self.svm_model.predict(X_val)
    svm_accuracy = accuracy_score(y_val, y_pred_svm)
    svm_precision = precision_score(y_val, y_pred_svm, average='weighted')
    svm_recall = recall_score(y_val, y_pred_svm, average='weighted')
    svm_f1 = f1_score(y_val, y_pred_svm, average='weighted')
    
    print(f"SVM Validation Results:")
    print(f"  Accuracy: {svm_accuracy:.4f}")
    print(f"  Precision: {svm_precision:.4f}")
    print(f"  Recall: {svm_recall:.4f}")
    print(f"  F1 Score: {svm_f1:.4f}")
    
    # Compare models and select the best one
    models = {
        'Random Forest': (self.rf_model, rf_accuracy, rf_precision, rf_recall, rf_f1),
        'Logistic Regression': (self.lr_model, lr_accuracy, lr_precision, lr_recall, lr_f1),
        'SVM': (self.svm_model, svm_accuracy, svm_precision, svm_recall, svm_f1)
    }
    
    best_model_name = max(models.items(), key=lambda x: x[1][1])[0]  # Based on accuracy
    self.best_model, best_accuracy, best_precision, best_recall, best_f1 = models[best_model_name]
    
    print(f"\nBest model: {best_model_name} with accuracy {best_accuracy:.4f}")
    
    # Final evaluation on test set
    y_pred_best = self.best_model.predict(X_test)
    test_accuracy = accuracy_score(y_test, y_pred_best)
    test_precision = precision_score(y_test, y_pred_best, average='weighted')
    test_recall = recall_score(y_test, y_pred_best, average='weighted')
    test_f1 = f1_score(y_test, y_pred_best, average='weighted')
    
    print(f"\nBest Model Test Results:")
    print(f"  Accuracy: {test_accuracy:.4f}")
    print(f"  Precision: {test_precision:.4f}")
    print(f"  Recall: {test_recall:.4f}")
    print(f"  F1 Score: {test_f1:.4f}")
    
    # Return performance metrics
    return {
        'train_size': X_train.shape[0],
        'val_size': X_val.shape[0],
        'test_size': X_test.shape[0],
        'models': {
            'random_forest': {
                'accuracy': rf_accuracy,
                'precision': rf_precision,
                'recall': rf_recall,
                'f1': rf_f1
            },
            'logistic_regression': {
                'accuracy': lr_accuracy,
                'precision': lr_precision,
                'recall': lr_recall,
                'f1': lr_f1
            },
            'svm': {
                'accuracy': svm_accuracy,
                'precision': svm_precision,
                'recall': svm_recall,
                'f1': svm_f1
            }
        },
        'best_model': {
            'name': best_model_name,
            'test_accuracy': test_accuracy,
            'test_precision': test_precision,
            'test_recall': test_recall,
            'test_f1': test_f1
        }
    }

In [5]:
def predict_condition(self, symptoms, duration_days=30, country="Global"):
    """
    Predict the most likely mental health condition based on symptoms
    
    Parameters:
    symptoms (list): List of symptom strings
    duration_days (int): Duration of symptoms in days
    country (str): User's country
    
    Returns:
    tuple: (predicted_condition, confidence_score)
    """
    if not hasattr(self, 'best_model') or self.best_model is None:
        # If models aren't trained, return None
        return None, 0
    
    # Create a sample for prediction
    sample = pd.DataFrame({
        'symptoms': [symptoms],
        'duration_days': [duration_days],
        'country': [country]
    })
    
    # Engineer features
    features = self.engineer_features(sample)
    
    # Apply feature selection if available
    if hasattr(self, 'selected_features') and self.selected_features is not None:
        features = features[self.selected_features]
    
    # Scale numerical features
    if hasattr(self, 'scaler') and self.scaler is not None:
        numerical_cols = features.select_dtypes(include=['int64', 'float64']).columns
        features[numerical_cols] = self.scaler.transform(features[numerical_cols])
    
    # Make prediction
    pred_proba = self.best_model.predict_proba(features)
    pred_class_index = np.argmax(pred_proba[0])
    confidence = pred_proba[0][pred_class_index]
    
    # Convert back to condition name
    if hasattr(self, 'label_encoder') and self.label_encoder is not None:
        condition = self.label_encoder.inverse_transform([pred_class_index])[0]
    else:
        condition = self.conditions.keys()[pred_class_index]
    
    return condition, confidence

In [6]:
def visualize_model_performance(self, metrics):
    """
    Generate visualizations of model performance for the report
    
    Parameters:
    metrics (dict): Dictionary containing model performance metrics
    
    Returns:
    list: Filenames of generated visualization images
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns
    from sklearn.metrics import confusion_matrix
    import os
    
    # Create directory for visualizations if it doesn't exist
    os.makedirs('visualizations', exist_ok=True)
    
    # Set style for all plots
    plt.style.use('ggplot')
    generated_files = []
    
    # 1. Bar chart comparing accuracy across models
    plt.figure(figsize=(10, 6))
    models = list(metrics['models'].keys())
    model_names = [model.replace('_', ' ').title() for model in models]
    accuracy_values = [metrics['models'][m]['accuracy'] for m in models]
    
    bars = plt.bar(model_names, accuracy_values, color=['#3498db', '#2ecc71', '#e74c3c'])
    plt.title('Model Accuracy Comparison', fontsize=16, fontweight='bold')
    plt.xlabel('Model', fontsize=14)
    plt.ylabel('Accuracy', fontsize=14)
    plt.ylim(0, 1)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                f'{height:.3f}', ha='center', fontsize=12)
        
    plt.tight_layout()
    filename = 'visualizations/model_accuracy_comparison.png'
    plt.savefig(filename, dpi=300)
    plt.close()
    generated_files.append(filename)
    
    # 2. Multi-metric comparison chart
    plt.figure(figsize=(12, 8))
    metric_names = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    
    x = np.arange(len(model_names))
    width = 0.2
    
    # Extract all metrics
    accuracies = [metrics['models'][m]['accuracy'] for m in models]
    precisions = [metrics['models'][m]['precision'] for m in models]
    recalls = [metrics['models'][m]['recall'] for m in models]
    f1_scores = [metrics['models'][m]['f1'] for m in models]
    
    # Plot bars
    plt.bar(x - width*1.5, accuracies, width, label='Accuracy', color='#3498db')
    plt.bar(x - width/2, precisions, width, label='Precision', color='#2ecc71')
    plt.bar(x + width/2, recalls, width, label='Recall', color='#e74c3c')
    plt.bar(x + width*1.5, f1_scores, width, label='F1 Score', color='#9b59b6')
    
    plt.title('Model Performance Metrics Comparison', fontsize=16, fontweight='bold')
    plt.xticks(x, model_names, fontsize=12)
    plt.xlabel('Model', fontsize=14)
    plt.ylabel('Score', fontsize=14)
    plt.ylim(0, 1)
    plt.legend(fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    
    plt.tight_layout()
    filename = 'visualizations/model_metrics_comparison.png'
    plt.savefig(filename, dpi=300)
    plt.close()
    generated_files.append(filename)
    
    # 3. Confusion matrix for the best model (if available)
    if 'confusion_matrix' in metrics:
        plt.figure(figsize=(10, 8))
        cm = metrics['confusion_matrix']
        
        # Get class names
        if 'classes' in metrics:
            class_names = metrics['classes']
        else:
            class_names = [f"Class {i}" for i in range(len(cm))]
        
        # Plot the confusion matrix
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
        plt.title(f"Confusion Matrix - {metrics['best_model']['name']}", fontsize=16, fontweight='bold')
        plt.xlabel('Predicted Label', fontsize=14)
        plt.ylabel('True Label', fontsize=14)
        
        plt.tight_layout()
        filename = 'visualizations/confusion_matrix.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 4. ROC curves for each model (if available)
    if all(key in metrics for key in ['fpr', 'tpr', 'roc_auc']):
        plt.figure(figsize=(10, 8))
        
        for i, model_name in enumerate(model_names):
            if i < len(metrics['fpr']):
                fpr = metrics['fpr'][i]
                tpr = metrics['tpr'][i]
                roc_auc = metrics['roc_auc'][i]
                
                plt.plot(fpr, tpr, lw=2, 
                         label=f'{model_name} (AUC = {roc_auc:.3f})')
        
        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate', fontsize=14)
        plt.ylabel('True Positive Rate', fontsize=14)
        plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16, fontweight='bold')
        plt.legend(loc="lower right", fontsize=12)
        
        plt.tight_layout()
        filename = 'visualizations/roc_curves.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 5. Learning curves for the best model (if available)
    if 'learning_curve' in metrics:
        plt.figure(figsize=(10, 6))
        
        train_sizes = metrics['learning_curve']['train_sizes']
        train_scores = metrics['learning_curve']['train_scores']
        val_scores = metrics['learning_curve']['val_scores']
        
        plt.plot(train_sizes, np.mean(train_scores, axis=1), 'o-', color='#3498db', 
                 label='Training score')
        plt.plot(train_sizes, np.mean(val_scores, axis=1), 'o-', color='#e74c3c', 
                 label='Validation score')
        
        plt.title(f'Learning Curves - {metrics["best_model"]["name"]}', fontsize=16, fontweight='bold')
        plt.xlabel('Training Examples', fontsize=14)
        plt.ylabel('Score', fontsize=14)
        plt.legend(loc='best', fontsize=12)
        plt.grid(True)
        
        plt.tight_layout()
        filename = 'visualizations/learning_curves.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    return generated_files

def visualize_feature_importance(self, model=None, feature_names=None):
    """
    Generate feature importance visualization
    
    Parameters:
    model: Trained model with feature_importances_ attribute (optional)
    feature_names (list): List of feature names (optional)
    
    Returns:
    list: Filenames of generated visualization images
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    
    # Create directory for visualizations if it doesn't exist
    os.makedirs('visualizations', exist_ok=True)
    generated_files = []
    
    # Use provided model or try to find one
    if model is None:
        if hasattr(self, 'rf_model'):
            model = self.rf_model
        elif hasattr(self, 'best_model'):
            model = self.best_model
        else:
            print("No model available for feature importance visualization")
            return generated_files
    
    # Make sure model has feature_importances_
    if not hasattr(model, 'feature_importances_'):
        # Check if it's a pipeline or ensemble
        if hasattr(model, 'steps'):
            for name, step in model.steps:
                if hasattr(step, 'feature_importances_'):
                    model = step
                    break
        elif hasattr(model, 'estimators_'):
            # For ensemble models, try to access the first estimator
            if len(model.estimators_) > 0:
                if hasattr(model.estimators_[0], 'feature_importances_'):
                    model = model.estimators_[0]
    
    # If we still don't have feature_importances_, exit
    if not hasattr(model, 'feature_importances_'):
        print("No feature importance information available in the model")
        return generated_files
    
    # Get feature importances
    importances = model.feature_importances_
    
    # Use provided feature names or try to get them from the model
    if feature_names is None:
        if hasattr(self, 'selected_features'):
            feature_names = self.selected_features
        elif hasattr(model, 'feature_names_in_'):
            feature_names = model.feature_names_in_
        else:
            feature_names = [f"Feature {i}" for i in range(len(importances))]
    
    # Make sure feature_names is the right length
    if len(feature_names) != len(importances):
        feature_names = [f"Feature {i}" for i in range(len(importances))]
    
    # Sort features by importance
    indices = np.argsort(importances)[::-1]
    
    # 1. Bar chart of top features
    plt.figure(figsize=(12, 10))
    
    # Plot top 15 features (or all if fewer than 15)
    num_features = min(15, len(indices))
    
    # Get color gradient
    colors = plt.cm.viridis(np.linspace(0, 0.8, num_features))
    
    # Create horizontal bar chart
    plt.barh(range(num_features), importances[indices[:num_features]], align='center', color=colors)
    plt.yticks(range(num_features), [feature_names[i] for i in indices[:num_features]])
    plt.xlabel('Feature Importance', fontsize=14)
    plt.title('Top Feature Importances', fontsize=16, fontweight='bold')
    
    # Add value labels to bars
    for i, v in enumerate(importances[indices[:num_features]]):
        plt.text(v + 0.01, i, f"{v:.3f}", va='center', fontsize=10)
    
    plt.tight_layout()
    filename = 'visualizations/feature_importance_bar.png'
    plt.savefig(filename, dpi=300)
    plt.close()
    generated_files.append(filename)
    
    # 2. Pie chart of top features' importance distribution
    plt.figure(figsize=(10, 10))
    
    # Select top 10 features for pie chart
    num_features_pie = min(10, len(indices))
    
    # Get values and labels
    values = importances[indices[:num_features_pie]]
    labels = [feature_names[i] for i in indices[:num_features_pie]]
    
    # If there are more features, add an "Other" category
    if len(indices) > num_features_pie:
        values = np.append(values, np.sum(importances[indices[num_features_pie:]]))
        labels.append('Other Features')
    
    # Create pie chart
    plt.pie(values, labels=None, autopct='%1.1f%%', startangle=90, 
            shadow=True, colors=plt.cm.tab10.colors[:len(values)])
    plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle
    plt.title('Distribution of Feature Importance (Top 10)', fontsize=16, fontweight='bold')
    
    # Add legend with percentages
    plt.legend(labels, loc='center left', bbox_to_anchor=(1, 0.5))
    
    plt.tight_layout()
    filename = 'visualizations/feature_importance_pie.png'
    plt.savefig(filename, dpi=300)
    plt.close()
    generated_files.append(filename)
    
    # 3. Feature importance grouped by category (if categories are known)
    if hasattr(self, 'feature_categories') and isinstance(self.feature_categories, dict):
        plt.figure(figsize=(12, 8))
        
        # Group features by category
        category_importance = {}
        for i, feature in enumerate(feature_names):
            # Find category for this feature
            for category, features in self.feature_categories.items():
                if any(feature.startswith(f) or feature == f for f in features):
                    if category not in category_importance:
                        category_importance[category] = 0
                    category_importance[category] += importances[i]
                    break
        
        # Sort categories by importance
        sorted_categories = sorted(category_importance.items(), key=lambda x: x[1], reverse=True)
        categories = [item[0] for item in sorted_categories]
        values = [item[1] for item in sorted_categories]
        
        # Plot grouped importance
        plt.bar(categories, values, color=plt.cm.Set3.colors[:len(categories)])
        plt.title('Feature Importance by Category', fontsize=16, fontweight='bold')
        plt.xlabel('Category', fontsize=14)
        plt.ylabel('Total Importance', fontsize=14)
        plt.xticks(rotation=45, ha='right')
        
        # Add value labels
        for i, v in enumerate(values):
            plt.text(i, v + 0.01, f"{v:.3f}", ha='center', fontsize=10)
        
        plt.tight_layout()
        filename = 'visualizations/feature_importance_by_category.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    return generated_files

def visualize_data_distributions(self, data):
    """
    Generate visualizations of data distributions for the report
    
    Parameters:
    data (DataFrame): The dataset to visualize
    
    Returns:
    list: Filenames of generated visualization images
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import os
    
    # Create directory for visualizations if it doesn't exist
    os.makedirs('visualizations', exist_ok=True)
    generated_files = []
    
    # 1. Condition distribution (target variable)
    if 'condition' in data.columns:
        plt.figure(figsize=(12, 6))
        
        # Count conditions
        condition_counts = data['condition'].value_counts()
        
        # Create bar chart
        colors = plt.cm.Paired(np.linspace(0, 1, len(condition_counts)))
        bars = plt.bar(condition_counts.index, condition_counts.values, color=colors)
        
        plt.title('Distribution of Mental Health Conditions in Dataset', fontsize=16, fontweight='bold')
        plt.xlabel('Condition', fontsize=14)
        plt.ylabel('Count', fontsize=14)
        plt.xticks(rotation=45, ha='right')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        
        # Add value labels
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                    f'{height}', ha='center', fontsize=10)
        
        plt.tight_layout()
        filename = 'visualizations/condition_distribution.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 2. Symptom frequency
    if 'symptoms' in data.columns:
        plt.figure(figsize=(14, 8))
        
        # Count frequency of each symptom
        symptom_counts = {}
        for symptom_list in data['symptoms']:
            if isinstance(symptom_list, list):
                for symptom in symptom_list:
                    symptom_counts[symptom] = symptom_counts.get(symptom, 0) + 1
        
        # Sort by frequency
        symptom_counts = dict(sorted(symptom_counts.items(), key=lambda item: item[1], reverse=True))
        
        # Plot top 15 symptoms
        top_symptoms = list(symptom_counts.keys())[:15]
        top_counts = [symptom_counts[s] for s in top_symptoms]
        
        # Create horizontal bar chart
        bars = plt.barh(top_symptoms, top_counts, color=plt.cm.viridis(np.linspace(0, 0.8, len(top_symptoms))))
        
        plt.title('Top 15 Most Frequent Symptoms', fontsize=16, fontweight='bold')
        plt.xlabel('Frequency', fontsize=14)
        plt.ylabel('Symptom', fontsize=14)
        plt.grid(axis='x', linestyle='--', alpha=0.7)
        
        # Add value labels
        for i, v in enumerate(top_counts):
            plt.text(v + 0.5, i, str(v), va='center', fontsize=10)
        
        plt.tight_layout()
        filename = 'visualizations/symptom_frequency.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 3. Correlation matrix of numerical features
    numerical_cols = data.select_dtypes(include=['int64', 'float64']).columns
    if len(numerical_cols) > 1:
        plt.figure(figsize=(12, 10))
        
        # Compute correlation matrix
        corr_matrix = data[numerical_cols].corr()
        
        # Plot heatmap
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
        sns.heatmap(corr_matrix, mask=mask, cmap='coolwarm', annot=True, fmt=".2f", 
                   square=True, linewidths=.5, cbar_kws={"shrink": .8})
        
        plt.title('Correlation Matrix of Numerical Features', fontsize=16, fontweight='bold')
        plt.tight_layout()
        filename = 'visualizations/correlation_matrix.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 4. Symptom co-occurrence (if we have symptom data)
    if 'symptoms' in data.columns:
        plt.figure(figsize=(14, 12))
        
        # Identify top 10 most common symptoms
        symptom_counts = {}
        for symptom_list in data['symptoms']:
            if isinstance(symptom_list, list):
                for symptom in symptom_list:
                    symptom_counts[symptom] = symptom_counts.get(symptom, 0) + 1
        
        top_symptoms = [s for s, _ in sorted(symptom_counts.items(), key=lambda item: item[1], reverse=True)[:10]]
        
        # Create co-occurrence matrix
        co_occurrence = np.zeros((len(top_symptoms), len(top_symptoms)))
        
        for symptom_list in data['symptoms']:
            if isinstance(symptom_list, list):
                for i, s1 in enumerate(top_symptoms):
                    for j, s2 in enumerate(top_symptoms):
                        if s1 in symptom_list and s2 in symptom_list:
                            co_occurrence[i][j] += 1
        
        # Convert to correlation
        for i in range(len(top_symptoms)):
            co_occurrence[i][i] = symptom_counts[top_symptoms[i]]
            
        # Plot heatmap
        sns.heatmap(co_occurrence, annot=True, fmt="d", cmap='YlGnBu',
                   xticklabels=top_symptoms, yticklabels=top_symptoms)
        
        plt.title('Symptom Co-occurrence Matrix (Top 10 Symptoms)', fontsize=16, fontweight='bold')
        plt.tight_layout()
        filename = 'visualizations/symptom_co_occurrence.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    # 5. Condition by country (if country data is available)
    if 'country' in data.columns and 'condition' in data.columns:
        # Get top 5 countries by frequency
        top_countries = data['country'].value_counts().nlargest(5).index.tolist()
        
        # Filter data for top countries
        top_countries_data = data[data['country'].isin(top_countries)]
        
        plt.figure(figsize=(14, 8))
        
        # Create count plot
        ax = sns.countplot(x='country', hue='condition', data=top_countries_data, palette='tab10')
        
        plt.title('Conditions by Country (Top 5 Countries)', fontsize=16, fontweight='bold')
        plt.xlabel('Country', fontsize=14)
        plt.ylabel('Count', fontsize=14)
        plt.legend(title='Condition', fontsize=12)
        
        # Rotate x-axis labels
        plt.xticks(rotation=45, ha='right')
        
        # Add count labels
        for p in ax.patches:
            height = p.get_height()
            if height > 0:
                ax.text(p.get_x() + p.get_width()/2., height + 0.5,
                        f'{int(height)}', ha='center', fontsize=9)
        
        plt.tight_layout()
        filename = 'visualizations/condition_by_country.png'
        plt.savefig(filename, dpi=300)
        plt.close()
        generated_files.append(filename)
    
    return generated_files

In [7]:
def run_ml_pipeline(self):
    """Run the complete ML pipeline for the project"""
    print("Starting Mental Health Assistant ML Pipeline...")
    
    # 1. Prepare training data
    country_data, symptom_data = self.prepare_training_data()
    print(f"Prepared training data: {symptom_data.shape[0]} symptom records")
    
    # 2. Engineer features
    features = self.engineer_features(symptom_data)
    print(f"Engineered {features.shape[1]} features")
    
    # 3. Select features
    selected_features, feature_indices = self.select_features(features, symptom_data['condition'])
    self.selected_features = selected_features.columns.tolist()
    print(f"Selected {len(self.selected_features)} features")
    
    # 4. Train models
    metrics = self.train_models(selected_features, symptom_data['condition'])
    print("Model training completed")
    
    # 5. Visualize results
    performance_charts = self.visualize_model_performance(metrics)
    importance_charts = self.visualize_feature_importance()
    
    print("ML pipeline completed successfully!")
    print(f"Generated visualization files: {performance_charts + importance_charts}")
    
    # Return summary metrics
    return {
        'metrics': metrics,
        'selected_features': self.selected_features,
        'visualizations': performance_charts + importance_charts
    }

In [12]:
# Check if MentalHealthAssistant is already defined
if 'MentalHealthAssistant' in globals():
    print("MentalHealthAssistant class is already defined")
else:
    print("MentalHealthAssistant class is not defined in this notebook")

MentalHealthAssistant class is not defined in this notebook


In [15]:
def run_pipeline():
    """
    Run the mental health analysis pipeline:
    1. Load and initialize the assistant
    2. Preprocess the data
    3. Train models for each mental health condition
    4. Evaluate the models
    5. Generate insights
    
    Returns:
        dict: Dictionary containing results from the pipeline run
    """
    import pandas as pd
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    
    print("Starting mental health ML pipeline...")
    
    # Initialize the assistant
    assistant = MentalHealthAssistant()
    
    # Check if data is available
    if assistant.disorders_df is None:
        print("Warning: No disorders data available. Using synthetic data for demonstration.")
        # Create synthetic data for demonstration
        assistant.disorders_df = create_synthetic_data()
        assistant.disorders_latest_year = 2023
        assistant.disorders_df_latest = assistant.disorders_df[assistant.disorders_df['Year'] == 2023]
    
    # Step 1: Preprocess data
    print("\nStep 1: Preprocessing data...")
    X, y_dict = preprocess_data(assistant)
    
    # Step 2: Train and evaluate models for each condition
    print("\nStep 2: Training and evaluating models...")
    model_results = {}
    
    for condition in assistant.conditions:
        print(f"\nProcessing {condition}...")
        column_name = assistant.conditions[condition]['column']
        
        # Skip if no data available for this condition
        if column_name not in y_dict:
            print(f"No data available for {condition}. Skipping.")
            continue
            
        y = y_dict[column_name]
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        
        # Scale features
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
        
        # Train Random Forest
        rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
        rf_model.fit(X_train_scaled, y_train)
        rf_preds = rf_model.predict(X_test_scaled)
        
        # Train Logistic Regression
        lr_model = LogisticRegression(max_iter=1000, random_state=42)
        lr_model.fit(X_train_scaled, y_train)
        lr_preds = lr_model.predict(X_test_scaled)
        
        # Evaluate models
        rf_metrics = {
            'accuracy': accuracy_score(y_test, rf_preds),
            'precision': precision_score(y_test, rf_preds, zero_division=0),
            'recall': recall_score(y_test, rf_preds, zero_division=0),
            'f1': f1_score(y_test, rf_preds, zero_division=0)
        }
        
        lr_metrics = {
            'accuracy': accuracy_score(y_test, lr_preds),
            'precision': precision_score(y_test, lr_preds, zero_division=0),
            'recall': recall_score(y_test, lr_preds, zero_division=0),
            'f1': f1_score(y_test, lr_preds, zero_division=0)
        }
        
        # Store results
        model_results[condition] = {
            'random_forest': rf_metrics,
            'logistic_regression': lr_metrics,
            'feature_importance': {
                'features': X.columns.tolist(),
                'importance': rf_model.feature_importances_.tolist()
            }
        }
    
    # Step 3: Generate insights
    print("\nStep 3: Generating insights...")
    insights = generate_insights(assistant, model_results)
    
    # Prepare results
    results = {
        'model_results': model_results,
        'insights': insights
    }
    
    return results


def preprocess_data(assistant):
    """
    Preprocess the data for ML modeling:
    1. Select relevant features
    2. Handle missing values
    3. Encode categorical variables
    4. Create target variables
    
    Args:
        assistant: Initialized MentalHealthAssistant instance
        
    Returns:
        tuple: (X, y_dict) containing features and target variables
    """
    import pandas as pd
    import numpy as np
    
    df = assistant.disorders_df.copy()
    
    # Create feature dataframe
    features = []
    
    # Add demographic features
    features.append(pd.get_dummies(df['Entity'], prefix='country', drop_first=True))
    
    # Add year as a feature
    features.append(df[['Year']])
    
    # Add additional computed features
    if 'Population' in df.columns and 'GDP' in df.columns:
        df['GDP_per_capita'] = df['GDP'] / df['Population']
        features.append(df[['GDP_per_capita']])
    
    # Combine all features
    X = pd.concat(features, axis=1)
    
    # Create target variables for each condition
    y_dict = {}
    for condition, info in assistant.conditions.items():
        column = info['column']
        if column in df.columns:
            # Binarize the target - assuming values above median indicate presence of condition
            threshold = df[column].median()
            y_dict[column] = (df[column] > threshold).astype(int)
    
    return X, y_dict


def generate_insights(assistant, model_results):
    """
    Generate insights from the trained models and data
    
    Args:
        assistant: Initialized MentalHealthAssistant instance
        model_results: Dictionary containing model evaluation results
        
    Returns:
        dict: Dictionary containing insights from the analysis
    """
    insights = {
        'top_predictors': {},
        'model_performance': {},
        'country_comparison': {}
    }
    
    # Find top predictors for each condition
    for condition, results in model_results.items():
        feature_importance = results['feature_importance']
        features = feature_importance['features']
        importance = feature_importance['importance']
        
        # Get top 5 features
        top_indices = sorted(range(len(importance)), key=lambda i: importance[i], reverse=True)[:5]
        insights['top_predictors'][condition] = [
            {'feature': features[i], 'importance': importance[i]} 
            for i in top_indices
        ]
    
    # Compare model performance
    for condition, results in model_results.items():
        rf_f1 = results['random_forest']['f1']
        lr_f1 = results['logistic_regression']['f1']
        
        insights['model_performance'][condition] = {
            'best_model': 'random_forest' if rf_f1 > lr_f1 else 'logistic_regression',
            'f1_score': max(rf_f1, lr_f1)
        }
    
    # Add country comparison if data allows
    if assistant.disorders_df_latest is not None:
        latest_data = assistant.disorders_df_latest
        
        # Get average prevalence by country for all conditions
        country_data = {}
        
        for country in assistant.countries:
            country_row = latest_data[latest_data['Entity'] == country]
            if len(country_row) > 0:
                condition_values = {}
                for condition, info in assistant.conditions.items():
                    column = info['column']
                    if column in country_row.columns:
                        condition_values[condition] = country_row[column].values[0]
                
                country_data[country] = condition_values
        
        insights['country_comparison'] = country_data
    
    return insights


def create_synthetic_data():
    """Create synthetic data for demonstration when real data is not available"""
    import pandas as pd
    import numpy as np
    np.random.seed(42)
    
    countries = ["Australia", "United States", "United Kingdom", "Canada", "India", "Global"]
    years = list(range(2010, 2024))
    
    # Create empty dataframe
    rows = []
    
    for country in countries:
        for year in years:
            # Base values for this country
            base_anxiety = np.random.uniform(2, 10)
            base_depression = np.random.uniform(3, 12)
            base_bipolar = np.random.uniform(0.5, 3)
            base_schizophrenia = np.random.uniform(0.2, 1.5)
            base_eating = np.random.uniform(1, 5)
            
            # Add trend over years
            year_factor = (year - 2010) / 13  # Normalize to 0-1
            
            # Population and GDP
            population = np.random.randint(1000000, 1000000000)
            gdp = population * np.random.uniform(1000, 50000)
            
            row = {
                'Entity': country,
                'Year': year,
                'Population': population,
                'GDP': gdp,
                'Anxiety': base_anxiety + year_factor * np.random.uniform(0, 4),
                'Major depression': base_depression + year_factor * np.random.uniform(0, 3),
                'Bipolar': base_bipolar + year_factor * np.random.uniform(0, 1),
                'Schizophrenia': base_schizophrenia + year_factor * np.random.uniform(0, 0.5),
                'Eating Disorders': base_eating + year_factor * np.random.uniform(0, 2)
            }
            
            rows.append(row)
    
    df = pd.DataFrame(rows)
    return df


if __name__ == "__main__":
    results = run_pipeline()
    
    # Output summary for report
    print("\nSummary for Milestone 2 Report:")
    print("=" * 40)
    
    print("\nModel Performance:")
    for condition, perf in results['insights']['model_performance'].items():
        print(f"- {condition.title()}: Best model = {perf['best_model']}, F1 Score = {perf['f1_score']:.4f}")
    
    print("\nTop Predictors by Condition:")
    for condition, predictors in results['insights']['top_predictors'].items():
        print(f"\n{condition.title()} top predictors:")
        for i, pred in enumerate(predictors, 1):
            print(f"  {i}. {pred['feature']} (importance: {pred['importance']:.4f})")
    
    print("\nCountry Comparison (Latest Year):")
    if 'country_comparison' in results['insights'] and results['insights']['country_comparison']:
        countries = list(results['insights']['country_comparison'].keys())
        conditions = list(results['insights']['country_comparison'][countries[0]].keys())
        
        for condition in conditions:
            print(f"\n{condition.title()} prevalence:")
            for country in countries:
                if country in results['insights']['country_comparison']:
                    if condition in results['insights']['country_comparison'][country]:
                        value = results['insights']['country_comparison'][country][condition]
                        print(f"  - {country}: {value:.2f}%")

Starting mental health ML pipeline...
Initializing simplified Mental Health Assistant for ML pipeline...

Step 1: Preprocessing data...

Step 2: Training and evaluating models...

Processing anxiety...

Processing depression...

Processing bipolar...

Processing schizophrenia...

Processing eating disorders...

Step 3: Generating insights...

Summary for Milestone 2 Report:

Model Performance:
- Anxiety: Best model = logistic_regression, F1 Score = 0.3750
- Depression: Best model = logistic_regression, F1 Score = 0.6667
- Bipolar: Best model = random_forest, F1 Score = 0.6316
- Schizophrenia: Best model = logistic_regression, F1 Score = 0.5000
- Eating Disorders: Best model = random_forest, F1 Score = 0.4615

Top Predictors by Condition:

Anxiety top predictors:
  1. GDP_per_capita (importance: 0.4183)
  2. Year (importance: 0.4008)
  3. country_United Kingdom (importance: 0.0436)
  4. country_United States (importance: 0.0416)
  5. country_Global (importance: 0.0350)

Depression top p

In [23]:
def _get_detailed_disorder_prediction(self):
        """Provide a detailed response about potential disorders"""
        if 'predicted_disorders' not in self.user_data or not self.user_data['predicted_disorders']:
            return "Based on the limited information you've shared, I don't have enough details to suggest a specific condition. If you're concerned about your mental health, I'd recommend speaking with a healthcare professional who can provide a proper assessment."
        
        predictions = self.user_data['predicted_disorders']
        
        if not predictions:
            return "Based on the limited information you've shared, I don't have enough details to suggest a specific condition. If you're concerned about your mental health, I'd recommend speaking with a healthcare professional who can provide a proper assessment."
        
        # Start building the response
        response = "Based on the symptoms you've described, here's what I can tell you:\n\n"
        
        # Add information for each predicted disorder
        for i, prediction in enumerate(predictions, 1):
            condition = prediction['condition']
            confidence = prediction['confidence']
            
            # Format condition name
            condition_name = condition.replace('_', ' ').title()
            if condition == 'eating disorders':
                condition_name = 'Eating Disorder'
            elif condition == 'bipolar':
                condition_name = 'Bipolar Disorder'
            elif condition == 'burnout':
                condition_name = 'Burnout or Mental Exhaustion'
            
            # Add prediction details
            response += f"{i}. {condition_name} - "
            
            # Describe confidence level
            if confidence >= 80:
                response += "Your symptoms strongly align with this condition."
            elif confidence >= 60:
                response += "Your symptoms moderately align with this condition."
            elif confidence >= 40:
                response += "Your symptoms somewhat align with this condition."
            else:
                response += "Your symptoms may have some relation to this condition."
            
            # Add matched symptoms if available
            if 'matched_symptoms' in prediction and prediction['matched_symptoms']:
                symptom_list = ', '.join(prediction['matched_symptoms'])
                response += f" Relevant symptoms: {symptom_list}."
            
            # Add matched keywords if available
            if 'matched_keywords' in prediction and prediction['matched_keywords']:
                keyword_list = ', '.join(prediction['matched_keywords'])
                if 'matched_symptoms' not in prediction or not prediction['matched_symptoms']:
                    response += f" Related indicators: {keyword_list}."
            
            response += "\n"
        
        # Add strong disclaimer
        response += "\nIMPORTANT: This is not a diagnosis. I can only identify patterns based on the limited information you've shared. "
        response += "Mental health conditions are complex and require a thorough assessment by a qualified healthcare professional. "
        response += "If you're concerned about your mental health, please reach out to a doctor, therapist, or mental health specialist."
        
        # Add next steps advice
        if 'country' in self.user_data and self.user_data['country'] in self.resources:
            country = self.user_data['country']
            response += f"\n\nHere are some resources in {country} that might help:\n"
            for resource in self.resources[country][:2]:  # Just show top 2 resources
                response += f"• {resource}\n"
        
        return responseclass GuidedMentalHealthChatbot:
    def __init__(self, assistant=None):
        """
        Initialize a guided mental health chatbot that follows a structured conversation flow
        and can predict potential disorders based on user inputs
        
        Args:
            assistant: An instance of MentalHealthAssistant (optional)
        """
        print("Initializing Mental Health Assistant...")
        
        # Initialize or create the assistant
        if assistant is None:
            try:
                self.assistant = MentalHealthAssistant()
                print("All available datasets loaded successfully.")
            except Exception as e:
                print(f"Error initializing assistant: {e}")
                self.assistant = None
        else:
            self.assistant = assistant
        
        print("Mental Health Assistant initialized successfully.")
        
        # Try to run the ML pipeline if needed
        try:
            self.pipeline_results = run_pipeline()
            self.has_ml_results = True
        except Exception as e:
            print(f"Warning: Could not load ML pipeline results: {e}")
            self.has_ml_results = False
            self.pipeline_results = None
        
        # Define the conversation states
        self.STATES = {
            'GREETING': 0,
            'WELLBEING_RATING': 1,
            'DURATION': 2,
            'SYMPTOMS': 3,
            'COUNTRY': 4,
            'SPECIFIC_CONDITION': 5,
            'COPING_STRATEGIES': 6,
            'RESOURCES': 7,
            'FOLLOWUP': 8
        }
        
        # Initialize conversation state
        self.reset_conversation()
        
        # Define country-specific mental health statistics (from ML pipeline or defaults)
        self.country_stats = self._initialize_country_stats()
        
        # Define coping strategies by country
        self.coping_strategies = {
            'Australia': {
                'top_strategies': [
                    'Talked To Friends/Family: 72.5% of people',
                    'Outdoor Activities: 68.3% of people',
                    'Exercise: 65.1% of people',
                    'Meditation/Mindfulness: 58.7% of people',
                    'Professional Therapy: 52.3% of people'
                ],
                'common_approach': 'lifestyle and social support'
            },
            'United States': {
                'top_strategies': [
                    'Talked To Friends/Family: 68.9% of people',
                    'Professional Therapy: 61.2% of people',
                    'Medication: 58.7% of people',
                    'Exercise: 56.3% of people',
                    'Mindfulness/Meditation: 49.5% of people'
                ],
                'common_approach': 'professional treatment combined with lifestyle changes'
            },
            'United Kingdom': {
                'top_strategies': [
                    'Talked To Friends/Family: 70.2% of people',
                    'NHS Services: 63.5% of people',
                    'Exercise: 59.8% of people',
                    'Medication: 55.4% of people',
                    'Mindfulness/Meditation: 48.7% of people'
                ],
                'common_approach': 'public health services and social support'
            },
            'Canada': {
                'top_strategies': [
                    'Talked To Friends/Family: 71.5% of people',
                    'Professional Therapy: 64.3% of people',
                    'Outdoor Activities: 63.7% of people',
                    'Exercise: 62.1% of people',
                    'Meditation/Mindfulness: 53.6% of people'
                ],
                'common_approach': 'balanced approach of professional care and lifestyle changes'
            },
            'India': {
                'top_strategies': [
                    'Talked To Friends/Family: 65.2% of people',
                    'Religious/Spiritual Activities: 62.3% of people',
                    'Improved Lifestyle: 58.7% of people',
                    'Spent Time Outdoors: 55.3% of people',
                    'Took Medication: 43.5% of people'
                ],
                'common_approach': 'lifestyle changes strategies'
            },
            'Global': {
                'top_strategies': [
                    'Talked To Friends/Family: 69.3% of people',
                    'Exercise: 60.5% of people',
                    'Professional Help: 58.9% of people',
                    'Mindfulness/Meditation: 52.4% of people',
                    'Medication: 49.7% of people'
                ],
                'common_approach': 'combination of social support and professional care'
            }
        }
        
        # Crisis resources by country
        self.resources = {
            'Australia': [
                "Lifeline: 13 11 14",
                "Beyond Blue: 1300 22 4636",
                "Headspace (for ages 12-25): 1800 650 890",
                "MensLine Australia: 1300 78 99 78"
            ],
            'United States': [
                "National Suicide Prevention Lifeline: 1-800-273-8255",
                "Crisis Text Line: Text HOME to 741741",
                "SAMHSA's National Helpline: 1-800-662-HELP (4357)",
                "National Alliance on Mental Illness (NAMI) Helpline: 1-800-950-NAMI (6264)"
            ],
            'United Kingdom': [
                "Samaritans: 116 123",
                "Mind: 0300 123 3393",
                "Shout Crisis Text Line: Text SHOUT to 85258",
                "NHS Mental Health Helpline: 111"
            ],
            'Canada': [
                "Crisis Services Canada: 1-833-456-4566",
                "Kids Help Phone: 1-800-668-6868",
                "Hope for Wellness Helpline: 1-855-242-3310",
                "Canada Suicide Prevention Service: 1-833-456-4566"
            ],
            'India': [
                "AASRA: 91-9820466726",
                "Sneha Foundation: 91-44-24640050",
                "Vandrevala Foundation: 1860 266 2345",
                "iCall: 022-25521111"
            ],
            'Global': [
                "International Association for Suicide Prevention: https://www.iasp.info/resources/Crisis_Centres/",
                "Befrienders Worldwide: https://www.befrienders.org/",
                "7 Cups of Tea: https://www.7cups.com/",
                "TalkSpace: https://www.talkspace.com/"
            ]
        }

    def _initialize_country_stats(self):
        """Initialize country statistics from ML pipeline or use defaults"""
        country_stats = {}
        
        # Try to get stats from pipeline results
        if self.has_ml_results and 'insights' in self.pipeline_results:
            insights = self.pipeline_results['insights']
            if 'country_comparison' in insights and insights['country_comparison']:
                for country, conditions in insights['country_comparison'].items():
                    stats = {}
                    for condition, value in conditions.items():
                        stats[condition] = f"{value:.1f}% of the population"
                    country_stats[country] = stats
        
        # If no pipeline results or incomplete, use defaults
        default_countries = ["Australia", "United States", "United Kingdom", "Canada", "India", "Global"]
        for country in default_countries:
            if country not in country_stats:
                country_stats[country] = {
                    'anxiety': f"{3.2 + (hash(country) % 20) / 10:.1f}% of the population",
                    'depression': f"{3.8 + (hash(country) % 25) / 10:.1f}% of the population",
                    'bipolar': f"{0.6 + (hash(country) % 10) / 10:.1f}% of the population",
                    'schizophrenia': f"{0.3 + (hash(country) % 5) / 10:.1f}% of the population",
                    'eating disorders': f"{0.5 + (hash(country) % 15) / 10:.1f}% of the population",
                    'self_reported': f"{15 + (hash(country) % 30) / 10:.1f}% of people surveyed in 2023"
                }
                
                # Add global average comparison for anxiety
                if country != "Global":
                    global_anxiety = 3.8
                    country_anxiety = float(country_stats[country]['anxiety'].split('%')[0])
                    if country_anxiety < global_anxiety:
                        country_stats[country]['anxiety_comparison'] = f"(Lower than the global average of {global_anxiety}%)"
                    else:
                        country_stats[country]['anxiety_comparison'] = f"(Higher than the global average of {global_anxiety}%)"
                
        return country_stats
    
    def reset_conversation(self):
        """Reset the conversation state"""
        self.current_state = self.STATES['GREETING']
        self.user_data = {
            'wellbeing_rating': None,
            'duration': None,
            'symptoms': None,
            'country': None,
            'specific_condition': None
        }
    
    def process_input(self, user_input):
        """
        Process user input based on current conversation state
        
        Args:
            user_input: String input from the user
            
        Returns:
            String response to the user
        """
        # Check for exit command
        if user_input.lower() in ['exit', 'quit', 'bye', 'goodbye']:
            return self._get_exit_message()
            
        # Check for direct questions about disorder prediction
        disorder_questions = [
            'what disorder', 'what condition', 'what do i have', 
            'what might i have', 'what could i have', 'do i have',
            'am i depressed', 'is it depression', 'is it anxiety',
            'what\'s wrong with me', 'diagnosis'
        ]
        
        if any(phrase in user_input.lower() for phrase in disorder_questions) and 'predicted_disorders' in self.user_data:
            return self._get_detailed_disorder_prediction()
        
        # Process based on current state
        if self.current_state == self.STATES['GREETING']:
            # Move to next state regardless of input
            self.current_state = self.STATES['WELLBEING_RATING']
            return "Hi! I'm your Mental Health Assistant, trained on global mental health data. I'd like to understand how you're feeling. On a scale of 1-10, how would you rate your mental wellbeing today? (1 being very poor, 10 being excellent)"
        
        elif self.current_state == self.STATES['WELLBEING_RATING']:
            # Try to parse rating
            try:
                rating = int(user_input.strip())
                if 1 <= rating <= 10:
                    self.user_data['wellbeing_rating'] = rating
                    self.current_state = self.STATES['DURATION']
                    return "Thank you for sharing. How long have you been feeling this way?"
                else:
                    return "Please enter a number between 1 and 10."
            except ValueError:
                # If they didn't enter a number, just proceed anyway
                self.user_data['wellbeing_rating'] = user_input.strip()
                self.current_state = self.STATES['DURATION']
                return "Thank you for sharing. How long have you been feeling this way?"
                
        elif self.current_state == self.STATES['DURATION']:
            self.user_data['duration'] = user_input.strip()
            self.current_state = self.STATES['SYMPTOMS']
            return "Thank you for sharing. Could you describe the main symptoms or feelings you've been experiencing? (For example: anxiety, low mood, trouble sleeping, etc.)"
            
        elif self.current_state == self.STATES['SYMPTOMS']:
            symptoms_text = user_input.strip()
            self.user_data['symptoms'] = symptoms_text
            
            # Analyze symptoms to predict potential disorders
            self.user_data['predicted_disorders'] = self._predict_disorders(symptoms_text)
            
            self.current_state = self.STATES['COUNTRY']
            return "Thank you for sharing those details. Which country do you live in? This will help me provide statistics and coping strategies relevant to your region."
            
        elif self.current_state == self.STATES['COUNTRY']:
            country = self._normalize_country(user_input.strip())
            self.user_data['country'] = country
            
            # Prepare statistics response
            response = f"Thank you for sharing that you're from {country}.\n\n"
            
            # Add disorder prediction analysis if available
            if 'predicted_disorders' in self.user_data and self.user_data['predicted_disorders']:
                response += self._get_disorder_prediction_response() + "\n\n"
            
            response += "Based on what you've shared, I'd like to provide some helpful information.\n\n"
            
            # Add country statistics
            response += self._get_country_stats(country)
            
            # Move to next state
            self.current_state = self.STATES['SPECIFIC_CONDITION']
            response += "\n\nAre you concerned about a specific mental health condition? (e.g., anxiety, depression, bipolar disorder, schizophrenia, eating disorders)"
            
            return response
            
        elif self.current_state == self.STATES['SPECIFIC_CONDITION']:
            # Record any specific condition mentioned
            self.user_data['specific_condition'] = user_input.strip()
            
            # Move to coping strategies regardless of answer
            self.current_state = self.STATES['COPING_STRATEGIES']
            
            # Get country from user data or default to Global
            country = self.user_data.get('country', 'Global')
            
            # Provide coping strategies response
            return self._get_coping_strategies(country)
            
        elif self.current_state == self.STATES['COPING_STRATEGIES']:
            # Move to resources state
            self.current_state = self.STATES['RESOURCES']
            
            # Get country from user data or default to Global
            country = self.user_data.get('country', 'Global')
            
            # Provide resources response
            return self._get_resources(country)
            
        elif self.current_state == self.STATES['RESOURCES']:
            # Move to follow-up state
            self.current_state = self.STATES['FOLLOWUP']
            
            # Provide follow-up response
            return "Is there anything specific about mental health you'd like to learn more about?"
            
        elif self.current_state == self.STATES['FOLLOWUP']:
            # Reset conversation for fresh start
            self.reset_conversation()
            
            # Provide final response
            return "Thank you for sharing. I've reset our conversation. Type anything to start again, or 'exit' to end our session."
            
        # Fallback response
        return "I'm not sure I understand. Could you please try again?"
    
    def _normalize_country(self, country_input):
        """Normalize country input to match available countries"""
        country_input = country_input.lower().strip()
        
        # Define mappings for common variations
        country_mappings = {
            'us': 'United States',
            'usa': 'United States',
            'america': 'United States',
            'united states of america': 'United States',
            'uk': 'United Kingdom',
            'britain': 'United Kingdom',
            'great britain': 'United Kingdom',
            'england': 'United Kingdom',
            'aus': 'Australia',
            'ca': 'Canada',
            'can': 'Canada',
            'in': 'India',
            'global': 'Global',
            'worldwide': 'Global',
            'world': 'Global',
            'international': 'Global'
        }
        
        # Check for exact match in mappings
        if country_input in country_mappings:
            return country_mappings[country_input]
        
        # Check for available countries
        available_countries = list(self.country_stats.keys())
        for country in available_countries:
            if country.lower() == country_input:
                return country
        
        # If no match found, return Global as default
        return 'Global'
    
    def _get_country_stats(self, country):
        """Get mental health statistics for a country"""
        if country not in self.country_stats:
            country = 'Global'
            
        stats = self.country_stats[country]
        response = f"Mental Health Statistics for {country}:\n\n"
        
        # Add anxiety with comparison to global average if available
        if 'anxiety' in stats:
            response += f"• Anxiety disorders: {stats['anxiety']}"
            if 'anxiety_comparison' in stats:
                response += f"\n  {stats['anxiety_comparison']}"
            response += "\n"
        
        # Add other conditions
        conditions = {
            'bipolar': 'Bipolar',
            'schizophrenia': 'Schizophrenia',
            'eating disorders': 'Eating Disorders'
        }
        
        for key, label in conditions.items():
            if key in stats:
                response += f"• {label}: {stats[key]}\n"
        
        # Add self-reported data if available
        if 'self_reported' in stats:
            response += f"\n• Self-reported anxiety/depression: {stats['self_reported']}\n"
            
        return response
    
    def _get_coping_strategies(self, country):
        """Get coping strategies for a country"""
        if country not in self.coping_strategies:
            country = 'Global'
            
        strategies = self.coping_strategies[country]
        response = f"Common Coping Strategies for Mental Health in {country}:\n\n"
        
        # Add top strategies
        response += "Top coping strategies:\n"
        for i, strategy in enumerate(strategies['top_strategies'], 1):
            response += f"{i}. {strategy}\n"
        
        # Add common approach
        response += f"\nThe most common overall approach in {country} is focused on {strategies['common_approach']}.\n\n"
        
        # Add reminder
        response += "Remember that seeking professional help is important for persistent mental health concerns."
        
        return response
    
    def _get_resources(self, country):
        """Get mental health resources for a country"""
        if country not in self.resources:
            country = 'Global'
            
        resources_list = self.resources[country]
        response = f"Mental Health Resources in {country}:\n\n"
        
        # Add resources
        for resource in resources_list:
            response += f"• {resource}\n"
        
        return response
    
    def _predict_disorders(self, symptoms_text):
        """
        Analyze symptoms text to predict potential mental health disorders
        
        Args:
            symptoms_text: String containing user's description of symptoms
            
        Returns:
            List of dictionaries with predicted disorders and confidence scores
        """
        predictions = []
        symptoms_text = symptoms_text.lower()
        
        # Get symptom information from the assistant
        if not hasattr(self.assistant, 'conditions'):
            return predictions
            
        # Special case handling for common symptoms
        special_cases = {
            'low mood': {'condition': 'depression', 'confidence': 75},
            'feeling sad': {'condition': 'depression', 'confidence': 70},
            'feeling down': {'condition': 'depression', 'confidence': 70},
            'depressed': {'condition': 'depression', 'confidence': 80},
            'worry': {'condition': 'anxiety', 'confidence': 70},
            'anxious': {'condition': 'anxiety', 'confidence': 80},
            'panic': {'condition': 'anxiety', 'confidence': 75},
            'mood swings': {'condition': 'bipolar', 'confidence': 70},
            'not sleeping': {'condition': 'depression', 'confidence': 60},
            'can\'t sleep': {'condition': 'depression', 'confidence': 60},
            'no appetite': {'condition': 'depression', 'confidence': 65},
            'tired': {'condition': 'depression', 'confidence': 50},
            'exhausted': {'condition': 'burnout', 'confidence': 70},
            'burnout': {'condition': 'burnout', 'confidence': 90},
            'voices': {'condition': 'schizophrenia', 'confidence': 85},
            'hallucinations': {'condition': 'schizophrenia', 'confidence': 90},
            'weight': {'condition': 'eating disorders', 'confidence': 60},
            'food': {'condition': 'eating disorders', 'confidence': 50}
        }
        
        # Check for direct symptom mentions
        for symptom, details in special_cases.items():
            if symptom in symptoms_text:
                # Check if this condition is already in predictions
                existing = next((p for p in predictions if p['condition'] == details['condition']), None)
                if existing:
                    # Take the higher confidence score
                    existing['confidence'] = max(existing['confidence'], details['confidence'])
                    existing['matched_symptoms'].append(symptom)
                else:
                    # Add new prediction
                    predictions.append({
                        'condition': details['condition'],
                        'matched_symptoms': [symptom],
                        'confidence': details['confidence']
                    })
            
        # For each condition, check for matching symptoms
        for condition, info in self.assistant.conditions.items():
            symptoms = info['symptoms']
            matched_symptoms = []
            
            for symptom in symptoms:
                # Check if symptom keywords appear in the text
                if symptom in symptoms_text:
                    matched_symptoms.append(symptom)
            
            # Calculate a simple confidence score based on number of matched symptoms
            if matched_symptoms:
                total_symptoms = len(symptoms)
                matched_count = len(matched_symptoms)
                confidence = (matched_count / total_symptoms) * 100
                
                # Only include if confidence is above a threshold
                if confidence >= 20:  # At least 20% of symptoms matched
                    # Check if this condition is already in predictions
                    existing = next((p for p in predictions if p['condition'] == condition), None)
                    if existing:
                        # Blend confidences with more weight to symptom-based confidence
                        existing['confidence'] = max(existing['confidence'], confidence)
                        # Add any new matched symptoms
                        for symptom in matched_symptoms:
                            if symptom not in existing['matched_symptoms']:
                                existing['matched_symptoms'].append(symptom)
                    else:
                        # Add new prediction
                        predictions.append({
                            'condition': condition,
                            'matched_symptoms': matched_symptoms,
                            'confidence': confidence
                        })
                    
        # Apply additional heuristics based on keywords
        keyword_patterns = {
            'anxiety': ['worry', 'anxious', 'nervous', 'panic', 'tense', 'on edge', 'restless'],
            'depression': ['sad', 'hopeless', 'empty', 'down', 'unmotivated', 'depressed', 'lost interest', 'low mood'],
            'bipolar': ['mood swing', 'high energy', 'euphoric', 'impulsive', 'irritable', 'racing thoughts'],
            'schizophrenia': ['voices', 'hallucination', 'delusion', 'paranoid', 'disorganized'],
            'eating disorders': ['body image', 'weight', 'food', 'eating', 'purge', 'diet', 'calories']
        }
        
        # Check for keyword matches
        for condition, keywords in keyword_patterns.items():
            # Skip if already predicted with high confidence
            if any(p['condition'] == condition and p['confidence'] > 60 for p in predictions):
                continue
                
            matched_keywords = []
            for keyword in keywords:
                if keyword in symptoms_text:
                    matched_keywords.append(keyword)
            
            if matched_keywords:
                # Check if condition already exists in predictions
                existing = next((p for p in predictions if p['condition'] == condition), None)
                
                if existing:
                    # Update existing prediction
                    keyword_confidence = (len(matched_keywords) / len(keywords)) * 100
                    # Blend confidences with more weight to symptom-based confidence
                    existing['confidence'] = max(existing['confidence'], keyword_confidence)
                    # Store matched keywords
                    if 'matched_keywords' not in existing:
                        existing['matched_keywords'] = []
                    existing['matched_keywords'].extend(matched_keywords)
                else:
                    # Add new prediction
                    confidence = (len(matched_keywords) / len(keywords)) * 100
                    if confidence >= 30:  # Higher threshold for keyword-only matches
                        predictions.append({
                            'condition': condition,
                            'matched_symptoms': [],
                            'matched_keywords': matched_keywords,
                            'confidence': confidence
                        })
                        
        # Filter for burnout/mental exhaustion if that seems more appropriate
        if 'tired' in symptoms_text or 'exhausted' in symptoms_text or 'burnout' in symptoms_text:
            if ('mentally' in symptoms_text or 'mental' in symptoms_text) and 'break' in symptoms_text:
                # This sounds like burnout/mental exhaustion
                predictions.append({
                    'condition': 'burnout',
                    'matched_keywords': ['tired', 'exhausted', 'mental', 'break'],
                    'confidence': 85  # High confidence for this pattern
                })
        
        # Sort by confidence
        predictions.sort(key=lambda x: x['confidence'], reverse=True)
        
        return predictions[:3]  # Return top 3 predictions
        
    def _get_disorder_prediction_response(self):
        """Generate a response based on disorder predictions"""
        if 'predicted_disorders' not in self.user_data or not self.user_data['predicted_disorders']:
            return "Based on the limited information you've shared, I don't have enough details to suggest a specific condition. However, low mood can be associated with several mental health conditions."
            
        predictions = self.user_data['predicted_disorders']
        
        if not predictions:
            return "Based on the limited information you've shared, I don't have enough details to suggest a specific condition."
            
        # Get top prediction
        top_prediction = predictions[0]
        condition = top_prediction['condition']
        confidence = top_prediction['confidence']
        
        # Format condition name for display
        condition_name = condition.replace('_', ' ').title()
        if condition == 'eating disorders':
            condition_name = 'an Eating Disorder'
        elif condition in ['anxiety', 'depression', 'schizophrenia']:
            condition_name = condition.title()
        elif condition == 'bipolar':
            condition_name = 'Bipolar Disorder'
        elif condition == 'burnout':
            condition_name = 'Burnout or Mental Exhaustion'
        
        # Create response based on confidence level
        if confidence >= 80:
            response = f"Based on what you've shared, you may be experiencing symptoms consistent with {condition_name}."
        elif confidence >= 60:
            response = f"Your described experiences show some patterns that could be associated with {condition_name}."
        elif confidence >= 40:
            response = f"Some of what you mentioned might be related to {condition_name}, though this is just a possibility."
        else:
            response = f"I notice that your symptoms might have some relation to {condition_name}, though there's not enough information to be certain."
            
        # Add disclaimer
        response += " Remember that this is not a diagnosis, and only a qualified healthcare professional can properly assess your mental health."
        
        # Add multiple condition information if applicable
        if len(predictions) > 1 and predictions[1]['confidence'] >= 40:
            second_condition = predictions[1]['condition']
            second_condition_name = second_condition.replace('_', ' ').title()
            
            if second_condition == 'eating disorders':
                second_condition_name = 'an Eating Disorder'
            elif second_condition in ['anxiety', 'depression', 'schizophrenia']:
                second_condition_name = second_condition.title()
            elif second_condition == 'bipolar':
                second_condition_name = 'Bipolar Disorder'
            elif second_condition == 'burnout':
                second_condition_name = 'Burnout or Mental Exhaustion'
                
            response += f" I also notice some patterns that could be associated with {second_condition_name}."
            
        return response
    
    def _get_exit_message(self):
        """Get exit message"""
        return ("Thank you for using the Mental Health Assistant. Remember that this tool provides information based on " 
                "global mental health data, but is not a substitute for professional care. If you're experiencing mental health "
                "difficulties, please consider speaking with a healthcare professional.")


def run_guided_chatbot():
    """Run the guided mental health chatbot in interactive mode"""
    # Initialize the chatbot
    chatbot = GuidedMentalHealthChatbot()
    
    print("\n== Mental Health Assistant ==")
    print("Type 'exit' to end the conversation.\n")
    
    # Start with greeting
    response = chatbot.process_input("")
    print(f"Chatbot: {response}")
    
    # Main conversation loop
    while True:
        user_input = input("\nYou: ").strip()
        
        if user_input.lower() in ['exit', 'quit', 'bye', 'goodbye']:
            print(f"\nChatbot: {chatbot._get_exit_message()}")
            break
            
        response = chatbot.process_input(user_input)
        print(f"\nChatbot: {response}")


# Example usage
if __name__ == "__main__":
    run_guided_chatbot()

IndentationError: unindent does not match any outer indentation level (<string>, line 67)