In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from wordcloud import WordCloud

def plot_confusion_matrix(y_true, y_pred, display_labels=None, title="Confusion Matrix", cmap="Blues", labels=None):
    """
    Plot a confusion matrix for classifier evaluation.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        display_labels (list): Names shown on the axes, one per matrix row/column.
            Without ``labels``, the rows follow the sorted set of values that
            appear in ``y_true`` or ``y_pred``, so names must be given in that order.
        title (str): Plot title.
        cmap (str): Colormap for the plot.
        labels (list, optional): Explicit label values (and their order) used to
            build the matrix. Pass this when some classes may be absent from
            ``y_true``/``y_pred`` so the matrix still lines up with
            ``display_labels``. Defaults to None (inferred from the data).

    Returns:
        matplotlib.figure.Figure: The figure object containing the plot.
    """
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    disp.plot(cmap=cmap, ax=ax)
    # Title the axes this function created rather than pyplot's "current" axes.
    ax.set_title(title)

    return fig

def plot_roc_curves(y_true, model_probabilities, class_names=None, title="ROC Curves"):
    """
    Plot one-vs-rest ROC curves for every class of every model.

    Args:
        y_true: True labels. Assumed to be integer class indices
            ``0 .. n_classes - 1``, because they are binarized against
            ``range(n_classes)``.
        model_probabilities (dict): Model name -> predicted probabilities; column
            ``i`` must hold the score for class ``i`` (indexed as ``probs[:, i]``).
            Models whose name contains ``'DNN'`` are drawn with dashed lines.
        class_names (list): Display names of the classes. Defaults to
            ``"Class 0"``, ``"Class 1"``, ... one per distinct value in ``y_true``.
        title (str): Plot title.

    Returns:
        matplotlib.figure.Figure: The figure object containing the plot.
    """
    if class_names is None:
        class_names = [f"Class {i}" for i in range(len(np.unique(y_true)))]

    n_classes = len(class_names)

    # One-vs-rest indicator matrix, one column per class.
    y_true_bin = label_binarize(y_true, classes=range(n_classes))
    if n_classes == 2:
        # label_binarize yields a single "is class 1" column for two classes;
        # expand it so column i means "is class i" for both classes.
        y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])

    fig, ax = plt.subplots(figsize=(12, 8))

    # Keep the original red/green/blue palette for up to three classes and
    # fall back to a qualitative colormap when there are more.
    base_colors = ['red', 'green', 'blue']
    if n_classes <= len(base_colors):
        colors = base_colors
    else:
        palette = plt.get_cmap('tab20')
        colors = [palette(i % palette.N) for i in range(n_classes)]

    for model_name, probs in model_probabilities.items():
        probs = np.asarray(probs)  # allow nested lists as well as arrays
        linestyle = '--' if 'DNN' in model_name else '-'

        for i, class_name in enumerate(class_names):
            fpr, tpr, _ = roc_curve(y_true_bin[:, i], probs[:, i])
            roc_auc = auc(fpr, tpr)
            ax.plot(
                fpr, tpr, linestyle=linestyle, color=colors[i],
                label=f'{model_name} - {class_name} (AUC = {roc_auc:.2f})'
            )

    # Diagonal reference line (random classifier).
    ax.plot([0, 1], [0, 1], 'k--')

    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.legend(loc="lower right")
    ax.grid(True)
    fig.tight_layout()

    return fig

def plot_training_history(history, title="Training History"):
    """
    Plot training/validation accuracy and loss curves side by side.

    Args:
        history (dict): Mapping of metric name to per-epoch values. The keys
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' are plotted when
            present; missing keys are skipped. An object that is not a dict but
            exposes such a dict as ``.history`` (e.g. the object returned by
            Keras ``Model.fit``) is unwrapped first.
        title (str): Figure title.

    Returns:
        matplotlib.figure.Figure: The figure object containing the plot.
    """
    # Accept wrapper objects that carry the metrics dict in `.history`.
    wrapped = getattr(history, 'history', None)
    if not isinstance(history, dict) and isinstance(wrapped, dict):
        history = wrapped

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # (axes, panel title, y-label, [(history key, legend label), ...])
    panels = (
        (ax1, 'Model Accuracy', 'Accuracy',
         (('accuracy', 'Training Accuracy'), ('val_accuracy', 'Validation Accuracy'))),
        (ax2, 'Model Loss', 'Loss',
         (('loss', 'Training Loss'), ('val_loss', 'Validation Loss'))),
    )
    for ax, panel_title, ylabel, series in panels:
        for key, label in series:
            if key in history:
                ax.plot(history[key], label=label)

        ax.set_title(panel_title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        # An empty legend makes matplotlib warn about missing labeled artists.
        if ax.get_lines():
            ax.legend()
        ax.grid(True)

    fig.suptitle(title)
    fig.tight_layout()

    return fig

def _top_tfidf_terms(row, feature_names, max_words):
    """Return ``{term: score}`` for the ``max_words`` highest non-zero scores in one sparse TF-IDF row."""
    scores = row.toarray().ravel()
    nonzero = np.flatnonzero(scores)
    # Highest scores first; ties keep vocabulary order.
    top = nonzero[np.argsort(-scores[nonzero], kind='stable')][:max_words]
    return {str(feature_names[j]): float(scores[j]) for j in top}


def plot_wordcloud_by_cluster(texts_by_cluster, title_prefix="Cluster", figsize=(15, 10), max_words=100):
    """
    Generate one TF-IDF weighted word cloud per text cluster.

    The vectorizer is fitted on all clusters together, with each cluster's text
    as one document. Words common to every cluster are therefore down-weighted
    relative to words specific to one cluster. Fitting on a single cluster's
    text alone would give every term the same IDF, so the scores would reduce
    to normalized term counts.

    Args:
        texts_by_cluster (dict): Cluster ID -> concatenated text of that cluster.
        title_prefix (str): Prefix for subplot titles.
        figsize (tuple): Figure size.
        max_words (int): Maximum number of terms shown per cloud (default 100).

    Returns:
        matplotlib.figure.Figure: The figure object containing the wordclouds.

    Raises:
        ValueError: If ``texts_by_cluster`` is empty, or (from scikit-learn) if
            no cluster contains any non-stop-word token.
    """
    if not texts_by_cluster:
        raise ValueError("texts_by_cluster must contain at least one cluster")

    cluster_ids = list(texts_by_cluster)
    vectorizer = TfidfVectorizer(stop_words='english')
    tfidf_matrix = vectorizer.fit_transform([texts_by_cluster[c] for c in cluster_ids])
    feature_names = vectorizer.get_feature_names_out()

    # squeeze=False always yields a 2-D array of axes, even for one cluster.
    fig, axes = plt.subplots(1, len(cluster_ids), figsize=figsize, squeeze=False)

    for row_idx, (ax, cluster_id) in enumerate(zip(axes[0], cluster_ids)):
        word_scores = _top_tfidf_terms(tfidf_matrix[row_idx], feature_names, max_words)

        if word_scores:
            wordcloud = WordCloud(width=800, height=400, background_color='white', max_words=max_words)
            wordcloud.generate_from_frequencies(word_scores)
            ax.imshow(wordcloud, interpolation='bilinear')
        else:
            # Cluster had only stop words / no tokens; WordCloud cannot render
            # an empty frequency dict, so show a placeholder instead.
            ax.text(0.5, 0.5, "No terms", ha='center', va='center', transform=ax.transAxes)

        ax.set_title(f"{title_prefix} {cluster_id}", fontsize=16)
        ax.axis('off')

    fig.tight_layout()
    return fig

def visualize_cluster_distribution(df, cluster_col, label_col, title="Cluster Distribution by Label"):
    """
    Visualize how true labels are distributed across clusters.

    Draws two heatmaps side by side: raw counts of each label per cluster, and
    the same counts normalized so each cluster's row sums to 1.

    Args:
        df (pandas.DataFrame): DataFrame containing cluster and label information.
        cluster_col (str): Column name for cluster labels (heatmap rows).
        label_col (str): Column name for true labels/classes (heatmap columns).
        title (str): Figure title.

    Returns:
        matplotlib.figure.Figure: The figure object containing the plot.
    """
    # Counts of each true label within each cluster.
    cross_tab = pd.crosstab(df[cluster_col], df[label_col])

    # Row-normalize: proportion of each label within its cluster. Every row
    # produced by crosstab has at least one observation, so no division by 0.
    cross_tab_norm = cross_tab.div(cross_tab.sum(axis=1), axis=0)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7))

    panels = (
        (ax1, cross_tab, "d", "Raw Counts"),
        (ax2, cross_tab_norm, ".2f", "Proportions within Clusters"),
    )
    for ax, table, fmt, panel_title in panels:
        sns.heatmap(table, annot=True, cmap="YlGnBu", fmt=fmt, ax=ax)
        ax.set_title(panel_title)
        ax.set_xlabel("True Label")
        ax.set_ylabel("Cluster")

    # Use the figure we own rather than pyplot's implicit current figure.
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()

    return fig

def compare_model_performances(results_dict, metric='accuracy', title="Model Performance Comparison"):
    """
    Compare one performance metric across models with an annotated bar chart.

    Args:
        results_dict (dict): Model name -> dict of metric values; every model's
            dict must contain ``metric``.
        metric (str): Metric key to compare (e.g. 'accuracy', 'precision',
            'recall', or 'f1').
        title (str): Plot title.

    Returns:
        matplotlib.figure.Figure: The figure object containing the plot.

    Raises:
        KeyError: If some model's results lack ``metric``.
    """
    models = list(results_dict)
    values = [results_dict[model][metric] for model in models]

    fig, ax = plt.subplots(figsize=(10, 6))

    # matplotlib cycles this colour list when there are more than three models.
    bars = ax.bar(models, values, color=['skyblue', 'lightgreen', 'coral'])

    # Value label just above each bar.
    for bar in bars:
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            height + 0.01,
            f'{height:.2f}',
            ha='center', va='bottom'
        )

    # Scores usually lie in [0, 1]. Extend the range to cover values outside it,
    # and leave headroom above the tallest bar so its label stays inside the axes.
    lower = min([0.0, *values]) * 1.1
    upper = max([1.0, *values]) * 1.1
    ax.set_ylim(lower, upper)
    ax.set_ylabel(f'{metric.capitalize()} Score')
    ax.set_title(title)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()

    return fig