In [38]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FixedLocator

## Part 1 - Preprocessing

In [39]:
# Apply the seaborn "whitegrid" theme to every figure drawn below
sns.set(style="whitegrid")

# ---- Display text for the binary (0/1) coded columns ----

# CONFIRMED_CASE values as shown in legends
case_labels = {0: "No", 1: "Yes"}

# DIABETES values as shown on axes
diabetes_labels = {0: "No", 1: "Yes"}

# GENDER values as shown on axes and in legends
gender_labels = {0: "Male", 1: "Female"}

# Evolution (outcome) values as shown in legends
evolution_labels = {0: "Cured", 1: "Death"}

# ---- Display text for the region and Race/Color columns ----

# Region column name -> display name
region_labels = {
    'REGION_NORTH': 'North',
    'REGION_NORTHEAST': 'Northeast',
    'REGION_MIDWEST': 'Midwest',
    'REGION_SOUTHEAST': 'Southeast',
    'REGION_SOUTH': 'South'
}

# Columns referring to regions (same names and order as the mapping keys)
region_cols = list(region_labels)

# Race/Color column name -> display name
race_labels = {
    'WHITE': 'White',
    'BLACK': 'Black',
    'YELLOW': 'Yellow',
    'BROWN': 'Brown',
    'INDIGENOUS': 'Indigenous'
}

# Columns referring to Race/Color (same names and order as the mapping keys)
race_cols = list(race_labels)

In [40]:
# Bar chart of confirmed cases by gender
def plot_preprocessing_confirmed_cases_gender(data: pd.DataFrame):
    """
    Draws a count plot of GENDER with one bar per CONFIRMED_CASE value.

    Parameters:
        data (pd.DataFrame): Frame providing the 'GENDER' and 'CONFIRMED_CASE' columns.
            Their values are looked up in `gender_labels` and `case_labels`, so they
            must be 0/1 codes (stored as int or float).
    """
    plt.figure(figsize=(15, 8))
    ax = sns.countplot(
        x=data['GENDER'],
        hue=data['CONFIRMED_CASE'],
        palette='coolwarm',
        alpha=1.0
    )

    # Swap the raw 0/1 legend entries for readable text
    handles, raw_case_labels = ax.get_legend_handles_labels()
    readable_case_labels = [case_labels[int(float(lbl))] for lbl in raw_case_labels]
    plt.legend(handles, readable_case_labels, title="Confirmed Case", fontsize=10)

    # Pin the tick positions before relabelling to avoid the set_xticklabels warning
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))

    # Parse through float() first, exactly like the legend labels above: a float
    # GENDER column yields tick text such as "0.0", which int() alone rejects.
    raw_gender_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    ax.set_xticklabels([gender_labels[int(float(lbl))] for lbl in raw_gender_labels])

    plt.title("Distribution of Confirmed Cases by Gender", fontsize=14)
    plt.xlabel("Gender", fontsize=12)
    plt.ylabel("Count", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Show the plot
    plt.show()

In [41]:
# Bar chart of confirmed cases by region
def plot_preprocessing_confirmed_cases_region(data: pd.DataFrame):
    """
    Draws one bar per region column, whose height is the column sum.

    Parameters:
        data (pd.DataFrame): Frame providing every column listed in `region_cols`.
    """
    totals_per_region = data[region_cols].sum()
    region_names = totals_per_region.index

    plt.figure(figsize=(15, 8))
    # hue mirrors x so the palette is applied per bar; the redundant legend is hidden
    ax = sns.barplot(
        x=region_names,
        y=totals_per_region.values,
        hue=region_names,
        palette="coolwarm",
        legend=False
    )

    ax.set_title("Distribution of Cases by Region", fontsize=14)
    ax.set_xlabel("Region", fontsize=12)
    ax.set_ylabel("Number of Cases", fontsize=12)
    plt.yticks(fontsize=10)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Pin the tick positions before relabelling to avoid warnings
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))

    # Translate column names to display names; unknown names are kept as they are
    display_names = [
        region_labels.get(tick.get_text(), tick.get_text())
        for tick in ax.get_xticklabels()
    ]
    ax.set_xticklabels(display_names, rotation=45, fontsize=10)

    plt.show()

In [42]:
# Boxplot of Age distribution by Gender and Confirmed Cases
def plot_preprocessing_age_distribution(data: pd.DataFrame):
    """
    Draws box plots of AGE per GENDER, split by CONFIRMED_CASE.

    Parameters:
        data (pd.DataFrame): Frame providing the 'GENDER', 'AGE' and 'CONFIRMED_CASE'
            columns. GENDER and CONFIRMED_CASE are looked up in `gender_labels` and
            `case_labels`, so they must be 0/1 codes (stored as int or float).
    """
    plt.figure(figsize=(15, 8))
    ax = sns.boxplot(
        x=data['GENDER'],
        y=data['AGE'],
        hue=data['CONFIRMED_CASE'],
        palette='coolwarm'
    )

    # Swap the raw 0/1 legend entries for readable text
    handles, raw_case_labels = ax.get_legend_handles_labels()
    readable_case_labels = [case_labels[int(float(lbl))] for lbl in raw_case_labels]
    plt.legend(handles=handles, labels=readable_case_labels, title="Confirmed Case", fontsize=10)

    # Pin the tick positions before relabelling the gender axis
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))

    # Parse through float() first, exactly like the legend labels above: a float
    # GENDER column yields tick text such as "0.0", which int() alone rejects.
    raw_gender_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    readable_gender_labels = [gender_labels[int(float(lbl))] for lbl in raw_gender_labels]
    ax.set_xticklabels(readable_gender_labels, fontsize=10)

    plt.title("Age Distribution by Gender and Confirmed Cases", fontsize=14)
    plt.xlabel("Gender", fontsize=12)
    plt.ylabel("Age", fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [43]:
# Violinplot of Age distribution by Gender and Confirmed Cases
def plot_preprocessing_age_distribution_violin(data: pd.DataFrame):
    """
    Draws split violin plots of AGE per GENDER, one half per CONFIRMED_CASE value.

    Parameters:
        data (pd.DataFrame): Frame providing the 'GENDER', 'AGE' and 'CONFIRMED_CASE'
            columns. GENDER and CONFIRMED_CASE are looked up in `gender_labels` and
            `case_labels`, so they must be 0/1 codes (stored as int or float).
    """
    plt.figure(figsize=(15, 8))
    ax = sns.violinplot(
        x=data['GENDER'],
        y=data['AGE'],
        hue=data['CONFIRMED_CASE'],
        split=True,
        palette='coolwarm',
        inner=None  # Remove the internal lines of the violinplot
    )

    # Swap the raw 0/1 legend entries for readable text
    handles, raw_case_labels = ax.get_legend_handles_labels()
    readable_case_labels = [case_labels[int(float(lbl))] for lbl in raw_case_labels]
    plt.legend(handles=handles, labels=readable_case_labels, title="Confirmed Case", fontsize=10)

    # Pin the tick positions before relabelling the gender axis
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))

    # Parse through float() first, exactly like the legend labels above: a float
    # GENDER column yields tick text such as "0.0", which int() alone rejects.
    raw_gender_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    readable_gender_labels = [gender_labels[int(float(lbl))] for lbl in raw_gender_labels]
    ax.set_xticklabels(readable_gender_labels, fontsize=10)

    plt.title("Age Distribution by Gender and Confirmed Case Status", fontsize=14)
    plt.xlabel("Gender", fontsize=12)
    plt.ylabel("Age", fontsize=12)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [44]:
# Bar chart of case outcome (Cured vs. Death) by Race/Color
def plot_preprocessing_confirmed_cases_race(data: pd.DataFrame):
    """
    Draws a count plot of Race/Color with one bar per EVOLUCAO (outcome) value.

    Parameters:
        data (pd.DataFrame): Frame providing every column listed in `race_cols`
            plus 'EVOLUCAO'. Only rows where a race column equals 1 are counted
            for that race.
    """
    # Reshape to one row per (record, race column) and keep the flagged pairs only
    long_form = data[race_cols + ['EVOLUCAO']].melt(
        id_vars=['EVOLUCAO'],
        var_name='Race',
        value_name='Presence'
    )
    race_data = long_form[long_form['Presence'] == 1]

    plt.figure(figsize=(15, 8))
    ax = sns.countplot(
        data=race_data,
        x='Race',
        hue='EVOLUCAO',
        palette='coolwarm'
    )

    # Swap the raw outcome codes in the legend for readable text
    handles, raw_outcomes = ax.get_legend_handles_labels()
    readable_outcomes = [evolution_labels[int(float(code))] for code in raw_outcomes]
    plt.legend(handles=handles, labels=readable_outcomes, title="Outcome", fontsize=10)

    # Pin the tick positions, then translate column names to display names
    # (names missing from race_labels are kept as they are)
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))
    display_names = [
        race_labels.get(tick.get_text(), tick.get_text())
        for tick in ax.get_xticklabels()
    ]
    ax.set_xticklabels(display_names, fontsize=10)

    ax.set_title("Distribution of Case Outcome by Race/Color", fontsize=14)
    ax.set_xlabel("Race/Color", fontsize=12)
    ax.set_ylabel("Case Count", fontsize=12)
    plt.yticks(fontsize=10)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [45]:
# Histogram of Age in cases with evolution to Death
def plot_preprocessing_age_distribution_death(data: pd.DataFrame):
    """
    Draws a 20-bin histogram (with KDE curve) of AGE for rows where EVOLUCAO == 1.

    Parameters:
        data (pd.DataFrame): Frame providing the 'EVOLUCAO' and 'AGE' columns.
    """
    # Keep only the cases where evolution equals 1 (patient died)
    death_cases = data[data['EVOLUCAO'] == 1]

    plt.figure(figsize=(15, 8))
    ax = sns.histplot(
        death_cases['AGE'],
        bins=20,
        kde=True,
        color='royalblue',
        alpha=0.7
    )

    # The KDE curve is the only labelled artist, so the legend is drawn only when
    # the curve exists; otherwise matplotlib would warn and add an empty legend.
    if ax.lines:
        ax.lines[-1].set_label("KDE - Kernel Density Estimate")
        ax.legend(fontsize=10)

    plt.title("Age Distribution in Cases Resulting in Death", fontsize=14)
    plt.xlabel("Age", fontsize=12)
    plt.ylabel("Number of Cases", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [46]:
# Bar chart of disease evolution by Gender
def plot_preprocessing_disease_evolution_gender(data: pd.DataFrame):
    """
    Draws a count plot of GENDER with one bar per EVOLUCAO (outcome) value.

    Parameters:
        data (pd.DataFrame): Frame providing the 'GENDER' and 'EVOLUCAO' columns.
            Their values are looked up in `gender_labels` and `evolution_labels`,
            so they must be 0/1 codes (stored as int or float).
    """
    plt.figure(figsize=(15, 8))
    ax = sns.countplot(x='GENDER', hue='EVOLUCAO', data=data, palette='coolwarm')

    # Swap the raw outcome codes in the legend for readable text
    handles, raw_outcomes = ax.get_legend_handles_labels()
    readable_outcomes = [evolution_labels[int(float(code))] for code in raw_outcomes]
    plt.legend(handles=handles, labels=readable_outcomes, title="Outcome", fontsize=10)

    # Read the original tick text and convert it to the integer gender codes.
    # Parse through float() first, exactly like the legend labels above: a float
    # GENDER column yields tick text such as "0.0", which int() alone rejects.
    raw_gender_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    gender_codes = [int(float(lbl)) for lbl in raw_gender_labels]

    # Pin the tick positions, then show the mapped names (e.g., Male, Female)
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))
    ax.set_xticklabels([gender_labels[code] for code in gender_codes], fontsize=10)

    plt.title("Distribution of Disease Outcome by Gender", fontsize=14)
    plt.xlabel("Gender", fontsize=12)
    plt.ylabel("Number of Cases", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [47]:
# Bar chart of disease outcome (Cured vs. Death) by Diabetes comorbidity
def plot_preprocessing_confirmed_cases_diabetes(data: pd.DataFrame):
    """
    Draws a count plot of DIABETES with one bar per EVOLUCAO (outcome) value.

    Parameters:
        data (pd.DataFrame): Frame providing the 'DIABETES' and 'EVOLUCAO' columns.
            Their values are looked up in `diabetes_labels` and `evolution_labels`,
            so they must be 0/1 codes (stored as int or float).
    """
    plt.figure(figsize=(15, 8))
    ax = sns.countplot(x='DIABETES', hue='EVOLUCAO', data=data, palette='coolwarm')

    # Swap the raw outcome codes in the legend for readable text
    handles, raw_outcomes = ax.get_legend_handles_labels()
    readable_outcomes = [evolution_labels[int(float(code))] for code in raw_outcomes]
    plt.legend(handles=handles, labels=readable_outcomes, title="Outcome", fontsize=10)

    # Read the original tick text and convert it to the integer DIABETES codes.
    # Parse through float() first, exactly like the legend labels above: a float
    # DIABETES column yields tick text such as "0.0", which int() alone rejects.
    raw_diabetes_labels = [tick.get_text() for tick in ax.get_xticklabels()]
    diabetes_codes = [int(float(lbl)) for lbl in raw_diabetes_labels]

    # Pin the tick positions, then show "No" / "Yes" instead of the codes
    ax.xaxis.set_major_locator(FixedLocator(ax.get_xticks()))
    ax.set_xticklabels([diabetes_labels[code] for code in diabetes_codes], fontsize=10)

    plt.title("Distribution of Disease Outcome by Comorbidity (Diabetes)", fontsize=14)
    plt.xlabel("Diabetes", fontsize=12)
    plt.ylabel("Number of Cases", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.show()

In [48]:
# Bar chart of confirmed cases by Region with percentages of Evolution to Death
def plot_preprocessing_confirmed_cases_region_evolution(data: pd.DataFrame):
    """
    Draws the number of cases per region as bars (left axis, sorted from fewest to
    most cases) and the percentage of those cases with EVOLUCAO == 1 as a line
    (right axis).

    Parameters:
        data (pd.DataFrame): Frame providing every column listed in `region_cols`
            plus 'EVOLUCAO'. Region totals are the column sums.
    """
    # Total number of cases for each region
    region_counts = data[region_cols].sum()

    # Deaths per region (rows where EVOLUCAO == 1). A boolean mask is used instead
    # of groupby(...).sum().loc[1] so that data without any death yields zeros
    # rather than a KeyError.
    region_deaths = data.loc[data['EVOLUCAO'] == 1, region_cols].sum()
    region_death_pct = (region_deaths / region_counts) * 100

    # Sort the regions (from fewest to most cases) and align the percentages to that order
    region_counts_sorted = region_counts.sort_values(ascending=True)
    region_death_pct_sorted = region_death_pct.loc[region_counts_sorted.index]

    fig, ax1 = plt.subplots(figsize=(15, 8))

    # Bar chart: total number of cases per region
    sns.barplot(
        x=region_counts_sorted.index,
        y=region_counts_sorted.values,
        hue=region_counts_sorted.index,
        palette="coolwarm",
        ax=ax1,
        legend=False
    )
    ax1.set_xlabel("Region", fontsize=12)
    ax1.set_ylabel("Number of Cases", fontsize=12)
    # Pin the tick positions before relabelling to avoid warnings
    ax1.xaxis.set_major_locator(FixedLocator(ax1.get_xticks()))
    new_region_labels = [region_labels.get(col, col) for col in region_counts_sorted.index]
    ax1.set_xticklabels(new_region_labels, rotation=45, fontsize=10)
    ax1.tick_params(axis='y', labelsize=10)
    ax1.grid(axis='y', linestyle='--', alpha=0.7)

    # Second y-axis: percentage of deaths per region
    ax2 = ax1.twinx()
    ax2.plot(region_counts_sorted.index, region_death_pct_sorted.values,
             color='tab:red', marker='o', linestyle='-', linewidth=2, label="Percentage of Deaths")
    ax2.set_ylabel("Percentage of Deaths (%)", fontsize=12)
    ax2.tick_params(axis='y', labelsize=10)

    # Start the secondary axis at 0 and leave a 10% margin above the maximum.
    # The maximum is NaN when every region has zero cases and 0 when there are no
    # deaths; neither gives a usable upper limit, so only the lower bound is set then.
    max_pct = region_death_pct_sorted.max()
    if pd.notna(max_pct) and 0 < max_pct < float('inf'):
        ax2.set_ylim(0, max_pct * 1.1)
    else:
        ax2.set_ylim(bottom=0)

    ax2.legend(loc='upper left', fontsize=10)

    plt.title("Distribution of Cases by Region with Percentage of Deaths", fontsize=14)
    plt.show()

## Part 3 - Fairness

In [49]:
def plot_fairness_tpr_metrics(results_with_attr: dict, results_without_attr: dict):
    """
    Draws a grouped bar plot of the test-set True Positive Rate (TPR) per gender,
    for the model trained with and the model trained without the GENDER attribute.

    Parameters:
        results_with_attr (dict): Fairness metrics of the model trained WITH the sensitive attribute.
            Read as results_with_attr['test']['tpr_group1'] (plotted as Female) and
            results_with_attr['test']['tpr_group0'] (plotted as Male).
        results_without_attr (dict): Fairness metrics of the model trained WITHOUT the sensitive
            attribute, read through the same keys.
    """
    # One row per bar, in drawing order; only the Test set results are used
    bar_rows = [
        (results_with_attr['test']['tpr_group1'], 'Female (Model with GENDER)', 'With GENDER', 'Female'),
        (results_with_attr['test']['tpr_group0'], 'Male (Model with GENDER)', 'With GENDER', 'Male'),
        (results_without_attr['test']['tpr_group1'], 'Female (Model without GENDER)', 'Without GENDER', 'Female'),
        (results_without_attr['test']['tpr_group0'], 'Male (Model without GENDER)', 'Without GENDER', 'Male'),
    ]
    tpr_frame = pd.DataFrame(bar_rows, columns=['TPR', 'Group', 'Model Type', 'Gender'])

    plt.figure(figsize=(12, 8))
    ax = sns.barplot(
        data=tpr_frame,
        x='Model Type',  # one cluster of bars per model
        y='TPR',
        hue='Gender',    # one bar per gender inside each cluster
        palette='coolwarm_r',
        edgecolor='black'
    )

    # Write the value above every bar that has a positive, non-NaN height
    for bar in ax.patches:
        bar_height = bar.get_height()
        if pd.isna(bar_height) or not bar_height > 0:
            continue
        ax.annotate(
            f"{bar_height:.3f}",
            (bar.get_x() + bar.get_width() / 2., bar_height),
            ha='center',
            va='center',
            xytext=(0, 9),
            textcoords='offset points',
            fontsize=11,
            fontweight='bold'
        )

    # Titles, axis labels and tick sizes
    ax.set_title("Comparison of True Positive Rate (TPR) on Test Set", fontsize=18, pad=20)
    ax.set_xlabel("Model Type", fontsize=14)
    ax.set_ylabel("True Positive Rate (TPR)", fontsize=14)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    # Legend placed inside the axes, bottom-right corner
    plt.legend(
        title="Gender",
        loc='lower right',
        fontsize=11,
        title_fontsize=13,
    )

    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.set_ylim(0, 1.0)

    # tight_layout keeps titles and labels inside the figure
    plt.tight_layout()

    plt.show()



In [50]:
def plot_fairness_class_metrics(report_with_gender: pd.DataFrame, report_without_gender: pd.DataFrame):
    """
    Plots precision, recall, and f1-score for each class, comparing the two models.

    Parameters:
        report_with_gender (pd.DataFrame): Report of the model trained WITH GENDER. Must have
            rows labelled '0' and '1' and the columns 'precision', 'recall' and 'f1-score'.
        report_without_gender (pd.DataFrame): Report of the model trained WITHOUT GENDER,
            with the same rows and columns.
    """
    metrics_to_plot = ['precision', 'recall', 'f1-score']
    class_rows = ['0', '1']

    # Keep only the class-specific rows of each report and tag them with the model name
    tagged_reports = []
    for model_name, report in (
        ('With GENDER', report_with_gender),
        ('Without GENDER', report_without_gender),
    ):
        class_report = report.loc[class_rows, metrics_to_plot].copy()
        class_report['Model'] = model_name
        tagged_reports.append(class_report)

    # Name the index 'Class' before resetting it: an unnamed index would become a
    # column called 'index' and the melt below would fail on the missing 'Class' id.
    combined_df = pd.concat(tagged_reports).rename_axis('Class').reset_index()
    plot_df = combined_df.melt(
        id_vars=['Model', 'Class'],
        value_vars=metrics_to_plot,
        var_name='Metric',
        value_name='Score'
    )

    # Using catplot with col='Metric' to create subplots for each metric
    g = sns.catplot(
        data=plot_df,
        kind='bar',
        x='Class',          # X-axis will have Class 0 and Class 1
        y='Score',
        hue='Model',        # Hue will create adjacent bars for each model
        col='Metric',       # This creates side-by-side subplots for each metric
        palette='coolwarm_r',
        edgecolor='black',
        height=6,
        aspect=0.9
    )

    # `figure` is the preferred name of the grid's Figure; `fig` is its deprecated alias
    g.figure.suptitle('Model Performance Comparison by Class and Metric', y=1.03, fontsize=18)
    g.set_axis_labels("Class", "Score", fontsize=14)
    # NOTE(review): set_titles appears to inject its own `size` keyword, so passing the
    # alias `fontsize` as well can be rejected by matplotlib as a duplicate — confirm
    # against the installed seaborn; `size` is the keyword it reads.
    g.set_titles("Metric: {col_name}", size=14)

    # Add value annotations to each bar (with NaN check)
    for ax in g.axes.flat:
        for p in ax.patches:
            bar_height = p.get_height()
            # Annotate only bars whose height is not NaN and greater than 0
            if pd.notna(bar_height) and bar_height > 0:
                ax.annotate(
                    format(bar_height, '.3f'),
                    (p.get_x() + p.get_width() / 2., bar_height),
                    ha='center',
                    va='center',
                    xytext=(0, 9),
                    textcoords='offset points',
                    fontsize=11,
                    fontweight='bold'
                )
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        ax.set_ylim(0, 1.05)

    plt.show()
