In [6]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import numpy as np
import os
from tqdm import tqdm
from collections import OrderedDict
import math


In [None]:
def barplot_confidence_location(
    df,
    pivot_index='Years',
    column_group='Age Group',
    columns=None,
    location=None,
    ages=None,
    title_prefix="",
    layout='horizontal'
):
    """Plot grouped bars of median responses, one subplot per location.

    For every entry in ``location`` the rows of ``df`` belonging to that
    location are pivoted (median of ``columns`` per ``pivot_index`` x
    ``column_group``) and drawn as a grouped bar chart in its own subplot.

    Args:
        df: Source DataFrame. It is copied, never modified.
        pivot_index: Column used for the x-axis categories.
        column_group: Column whose values become the bar groups / legend
            entries.
        columns: Column (passed to ``pd.pivot_table`` as ``values``) whose
            median is plotted. Required.
        location: One location name or a list of names. The first of
            'Country', 'IntermediateRegion', 'SubRegion', 'Region' that
            exists in ``df`` and contains any of the names is used for
            filtering.
        ages: Optional value or list of values of ``column_group`` to keep.
        title_prefix: Figure title text.
        layout: 'horizontal' (subplots side by side) or 'vertical'
            (subplots stacked).

    Returns:
        The Plotly figure, or ``None`` (after printing a message) when no
        location could be matched to a region column.

    Raises:
        ValueError: If ``layout`` is not 'horizontal'/'vertical' or
            ``columns`` is ``None``.
    """
    df_copy = df.copy()

    # Accept a bare string wherever a list of names is expected.
    if isinstance(location, str):
        location = [location]
    if isinstance(ages, str):
        ages = [ages]

    region_cols = ['Country', 'IntermediateRegion', 'SubRegion', 'Region']
    detected_col = None
    if location:
        for col in region_cols:
            # Skip region columns the frame does not have instead of
            # failing with a KeyError.
            if col in df_copy.columns and df_copy[col].isin(location).any():
                detected_col = col
                break
    if not detected_col:
        print("Could not match locations to any region columns.")
        return None

    # Validate arguments before doing any pivoting or figure construction.
    if layout not in ('horizontal', 'vertical'):
        raise ValueError("'layout' must be either 'horizontal' or 'vertical'")
    if columns is None:
        raise ValueError("'columns' parameter must be provided to specify which confidence metric to visualize.")

    # Color and Group Order Matching
    tol_bright_7 = [
        "#4477AA", "#66CCEE", "#228833", "#CCBB44", "#EE6677",
        "#AA3377", "#BBBBBB"
    ]
    known_orders = {
        'Age Group': [
            "Teen (<18)",
            "Young Adult (18-30)",
            "Middle Adult (31-45)",
            "Older Adult (46-65)",
            "Elder Adult (66+)"
        ],
        'Generation': [
            "Lost Generation (1883-'00)",
            "Greatest Generation (1901-'27)",
            "Silent Generation (1928-'45)",
            "Baby Boomer (1946-'64)",
            "Gen X (1965-'80)",
            "Millennial (1981-'96)",
            "Gen Z (1997-'12)"
        ],
        'Sex': ["Male", "Female", "Other"],
        'BornHere': ["Yes", "No"],
        'NationalPride': ["1", "2", "3", "4"],
    }
    if column_group in known_orders:
        group_order = known_orders[column_group]
    else:
        group_order = sorted(df_copy[column_group].dropna().astype(str).unique())

    # Repeat the palette often enough to cover every group.
    palette = tol_bright_7 * (len(group_order) // len(tol_bright_7) + 1)
    color_map = dict(zip(group_order, palette))

    # Subplots per location
    n_locations = len(location)
    if layout == 'horizontal':
        fig = make_subplots(
            rows=1,
            cols=n_locations,
            shared_yaxes=True,
            shared_xaxes=False,
            subplot_titles=location,
            horizontal_spacing=0.01
        )
        subplot_positions = [(1, i + 1) for i in range(n_locations)]
        fig_width = 500 * n_locations
        fig_height = 300
    else:  # vertical
        fig = make_subplots(
            rows=n_locations,
            cols=1,
            shared_xaxes=False,
            shared_yaxes=True,
            subplot_titles=location
        )
        subplot_positions = [(i + 1, 1) for i in range(n_locations)]
        fig_width = 850
        fig_height = 250 * n_locations

    # Groups that already own a legend entry. Tracking this (rather than
    # showing the legend only for the first location) keeps groups that are
    # missing from the first subplot visible in the legend.
    legend_shown = set()

    for (row, col), loc in zip(subplot_positions, location):
        loc_df = df_copy[df_copy[detected_col] == loc]

        if ages:
            loc_df = loc_df[loc_df[column_group].isin(ages)]

        # Drop missing groups BEFORE the string cast; casting first would
        # turn NaN into the literal "nan" and defeat dropna.
        loc_df = loc_df.dropna(subset=[column_group]).copy()
        loc_df[column_group] = loc_df[column_group].astype(str)

        pivot = pd.pivot_table(
            loc_df,
            values=columns,
            index=pivot_index,
            columns=column_group,
            aggfunc='median',
            sort=False
        )

        if pivot.empty:
            continue

        pivot = pivot.reset_index().melt(
            id_vars=pivot_index,
            var_name=column_group,
            value_name='Median_Confidence'
        )

        for group_val in group_order:
            group_data = pivot[pivot[column_group] == group_val]
            if group_data.empty:
                continue

            fig.add_trace(
                go.Bar(
                    x=group_data[pivot_index],
                    y=group_data['Median_Confidence'],
                    name=group_val,
                    legendgroup=group_val,
                    marker_color=color_map[group_val],
                    showlegend=group_val not in legend_shown
                ),
                row=row,
                col=col
            )
            legend_shown.add(group_val)

    # Text labels for the 1-4 response scale on the y-axis.
    if columns == 'NationalPride':
        y_labels_median = {
            1: 'Not at all proud', 2: 'Not very proud',
            3: 'Quite proud', 4: 'Very Proud'
        }
    else:
        y_labels_median = {
            1: 'None at all', 2: 'Not very much',
            3: 'Quite a lot', 4: 'A great deal'
        }

    fig.for_each_yaxis(lambda axis: axis.update(
        range=[0.5, 4],
        tickmode='array',
        tickvals=list(y_labels_median.keys()),
        ticktext=[label.title() for label in y_labels_median.values()]
    ))

    # Conditional legend positioning
    if layout == 'horizontal':
        legend_position = dict(
            orientation='h',
            y=-0.50,  # clearly below subplots
            xanchor='center',
            x=0.5
        )
        title_position = {'text': title_prefix,
                          'xanchor': 'center',
                          'x': 0.5}
    else:  # vertical
        legend_position = dict(
            orientation='v',
            yanchor='top',
            y=1,
            xanchor='left',
            x=1.02  # to the top right of the plot
        )
        title_position = {'text': title_prefix,
                          'xanchor': 'center',
                          'x': 0.45}

    # Final layout
    fig.update_layout(
        template="simple_white",
        height=fig_height,
        width=fig_width,
        title=title_position,
        legend_title_text=column_group,
        legend=legend_position,
        margin=dict(t=120, b=40, r=30, l=30),
        barmode='group',
        hovermode="x"
    )

    return fig


In [None]:
def pieplot_confidence_location(
    df,
    pivot_index='Years',
    column_group='Age Group',
    columns=None,
    location=None,
    ages=None,
    title_prefix="",
    layout='horizontal',
    rotation=0
):
    """Plot one pie per location showing how median responses are distributed.

    For every entry in ``location`` the median of ``columns`` is computed per
    ``pivot_index`` x ``column_group`` cell. The pie shows which share of
    those medians falls on each response code 1-4.

    Args:
        df: Source DataFrame. It is copied, never modified.
        pivot_index: Column used as the pivot table index.
        column_group: Column used as the pivot table columns.
        columns: Column (passed to ``pd.pivot_table`` as ``values``) whose
            median is evaluated. Required.
        location: One location name or a list of names. The first of
            'Country', 'IntermediateRegion', 'SubRegion', 'Region' that
            exists in ``df`` and contains any of the names is used for
            filtering.
        ages: Optional value or list of values of ``column_group`` to keep.
        title_prefix: Figure title text.
        layout: 'horizontal' (pies side by side) or 'vertical' (stacked).
        rotation: Passed through to ``go.Pie(rotation=...)``.

    Returns:
        The Plotly figure, or ``None`` (after printing a message) when no
        location could be matched to a region column.

    Raises:
        ValueError: If ``layout`` is not 'horizontal'/'vertical' or
            ``columns`` is ``None``.
    """
    df_copy = df.copy()

    # Accept a bare string wherever a list of names is expected.
    if isinstance(location, str):
        location = [location]
    if isinstance(ages, str):
        ages = [ages]

    # Detect which column holds location names
    region_cols = ['Country', 'IntermediateRegion', 'SubRegion', 'Region']
    detected_col = None
    if location:
        for col in region_cols:
            # Skip region columns the frame does not have instead of
            # failing with a KeyError.
            if col in df_copy.columns and df_copy[col].isin(location).any():
                detected_col = col
                break
    if not detected_col:
        print("Could not match locations to any region columns.")
        return None

    if layout not in ('horizontal', 'vertical'):
        raise ValueError("'layout' must be either 'horizontal' or 'vertical'")
    # Without an explicit value column pivot_table would aggregate every
    # remaining column, which is never what this chart is meant to show.
    if columns is None:
        raise ValueError("'columns' parameter must be provided to specify which confidence metric to visualize.")

    # Fixed order for confidence categories
    confidence_order = [1, 2, 3, 4]
    confidence_labels = {
        1: 'None at all',
        2: 'Not very much',
        3: 'Quite a lot',
        4: 'A great deal'
    }
    confidence_colors = {
        1: "#D6725E",
        2: "#E9B09F",
        3: "#6E9ACF",
        4: "#A5BDE9"
    }

    # Subplots setup
    n_locations = len(location)
    if layout == 'horizontal':
        fig = make_subplots(
            rows=1,
            cols=n_locations,
            subplot_titles=location,
            specs=[[{"type": "domain"}] * n_locations]
        )
        subplot_positions = [(1, i + 1) for i in range(n_locations)]
        fig_width = 400 * n_locations
        fig_height = 400
    else:  # vertical
        fig = make_subplots(
            rows=n_locations,
            cols=1,
            subplot_titles=location,
            specs=[[{"type": "domain"}]] * n_locations
        )
        subplot_positions = [(i + 1, 1) for i in range(n_locations)]
        fig_width = 500
        fig_height = 400 * n_locations

    # Push subplot titles slightly higher
    for anno in fig['layout']['annotations']:
        anno['y'] += 0.05

    # Loop through each location
    for (row, col), loc in zip(subplot_positions, location):
        loc_df = df_copy[df_copy[detected_col] == loc]
        if ages:
            loc_df = loc_df[loc_df[column_group].isin(ages)]

        # Median per (pivot_index, column_group) cell
        pivot = pd.pivot_table(
            loc_df,
            values=columns,
            index=pivot_index,
            columns=column_group,
            aggfunc='median',
            sort=False
        )

        if pivot.empty:
            continue

        # Flatten into one list of medians across all cells. The int cast
        # truncates fractional medians (e.g. 2.5 becomes 2).
        median_values = pd.Series(pivot.values.flatten()).dropna().astype(int)
        total_count = len(median_values)

        # Build slices keeping only categories that actually occur
        slice_labels = []
        slice_values = []
        slice_colors = []

        for val in confidence_order:
            count = (median_values == val).sum()
            if count > 0:
                percent = round((count / total_count) * 100, 1)
                slice_labels.append(confidence_labels[val])
                slice_values.append(percent)
                slice_colors.append(confidence_colors[val])

        # Add pie chart trace
        fig.add_trace(
            go.Pie(
                labels=slice_labels,
                values=slice_values,
                marker=dict(colors=slice_colors),
                hoverinfo='label+percent',
                insidetextorientation='radial',
                rotation=rotation,  # caller-supplied start angle
                sort=False,   # keep slices in fixed order
                # Every pie contributes to the legend: Plotly lists each pie
                # label once across traces, so categories missing from the
                # first location (or a skipped first location) still appear.
                showlegend=True
            ),
            row=row,
            col=col
        )

    # Legend & title positioning
    if layout == 'horizontal':
        legend_position = dict(
            orientation='h',
            y=-0.20,
            xanchor='center',
            x=0.5
        )
        title_position = {'text': title_prefix, 'xanchor': 'center', 'x': 0.5}
    else:
        legend_position = dict(
            orientation='v',
            yanchor='top',
            y=1,
            xanchor='left',
            x=1.02
        )
        title_position = {'text': title_prefix, 'xanchor': 'center', 'x': 0.45}

    # Final layout
    fig.update_layout(
        template="simple_white",
        height=fig_height,
        width=fig_width,
        title=title_position,
        legend=legend_position,
        margin=dict(t=100, b=40, r=40, l=40)
    )

    return fig


In [5]:
def diverging_barchart(
    df, 
    column, 
    grouping='Age Group',
    location=None,
    wave=None,
    title=None
):
    """Plot a diverging horizontal bar chart of 4-point responses per group.

    For each group the share of response codes 1 and 2 is drawn to the left
    of zero and the share of codes 3 and 4 to the right. The share of codes
    -1 and -2 ("Don't know / No answer") is drawn on a separate x-axis to the
    right of the main chart. Percentages are relative to all rows of the
    group that have a value in ``column``.

    Args:
        df: Source DataFrame. It is copied, never modified.
        column: Column holding the response codes.
        grouping: Column defining the groups (one bar row per group).
            'Age Group', 'Generation', 'Sex' and 'BornHere' use a fixed
            order; any other column uses its sorted unique values.
        location: Optional location name. The first of 'Country',
            'IntermediateRegion', 'SubRegion', 'Region' that exists in
            ``df`` and contains the name is used for filtering.
        wave: Optional value compared against the 'WVS Wave' column.
        title: Optional figure title; a default is generated when omitted.

    Returns:
        The Plotly figure, or ``None`` (after printing a message) when
        ``location`` is given but not found in any region column.
    """
    df_copy = df.copy()

    # Filter by location
    if location:
        region_cols = ['Country', 'IntermediateRegion', 'SubRegion', 'Region']
        for col in region_cols:
            # Skip region columns the frame does not have instead of
            # failing with a KeyError.
            if col not in df_copy.columns:
                continue
            mask = df_copy[col] == location
            if mask.any():
                df_copy = df_copy[mask]
                break
        else:
            print(f" Location '{location}' not found in region columns. Skipping.")
            return None
        location_str = f"({location})"
    else:
        location_str = "(Global)"

    # Filter by wave
    wave_to_years = {
        '1': "1981–84", '2': "1990–94", '3': "1995–98",
        '4': "1999–04", '5': "2005–09", '6': "2010–14", '7': "2017–22"
    }
    # Default label so the generated title also works without a wave filter.
    years_label = "All Waves"
    if wave:
        df_copy = df_copy[df_copy['WVS Wave'] == wave]
        # str() lets integer waves resolve to the same year label.
        years_label = wave_to_years.get(str(wave), f"Wave {wave}")

    # Response labels
    if column == 'NationalPride':
        likert_labels = {
            1: "Not at all proud",
            2: "Not very proud",
            3: "Quite proud",
            4: "Very proud",
        }
    else:
        likert_labels = {
            1: "None at all",
            2: "Not very much",
            3: "Quite a lot",
            4: "A great deal",
        }

    # Grouping logic
    fixed_orders = {
        'Age Group': [
            "Teen (<18)", "Young Adult (18-30)", "Middle Adult (31-45)",
            "Older Adult (46-65)", "Elder Adult (66+)"
        ],
        'Generation': [
            "Lost Generation (1883-'00)", "Greatest Generation (1901-'27)",
            "Silent Generation (1928-'45)", "Baby Boomer (1946-'64)",
            "Gen X (1965-'80)", "Millennial (1981-'96)", "Gen Z (1997-'12)"
        ],
        'Sex': ["Male", "Female"],
        'BornHere': ["No", "Yes"],
    }
    if grouping in fixed_orders:
        group_order = fixed_orders[grouping]
    else:
        # Any other grouping column: fall back to its own values so the
        # function does not fail on an undefined group order.
        group_order = sorted(df_copy[grouping].dropna().unique(), key=str)

    fig = go.Figure()
    display_labels = []  # To hold group name with n-count
    # Legend entries come from the first group that is actually drawn; the
    # first entry of group_order may have no rows and be skipped.
    legend_pending = True

    for group in group_order:
        counts = df_copy.loc[df_copy[grouping] == group, column].value_counts()
        total = int(counts.sum())
        display_group = f"{group}<br>(n = {total})"
        display_labels.append(display_group)

        if total == 0:
            continue  # Avoid division by zero

        # Compute percentages
        neutral_total = counts.get(-1, 0) + counts.get(-2, 0)
        neutral_pct = round(float(neutral_total) / total * 100, 0)

        likert_vals = counts.reindex([1, 2, 3, 4], fill_value=0)
        likert_pct = (likert_vals / total * 100).round(0)

        # Negative codes 1-2 extend left of zero, positive codes 3-4 right.
        bar_specs = [
            (-float(likert_pct[2]), likert_labels[2], '#E9B09F', 2),
            (-float(likert_pct[1]), likert_labels[1], '#D6725E', 1),
            (float(likert_pct[3]), likert_labels[3], '#6E9ACF', 3),
            (float(likert_pct[4]), likert_labels[4], '#A5BDE9', 4),
        ]

        for val, label, color, response_code in bar_specs:
            fig.add_trace(go.Bar(
                x=[val],
                y=[display_group],
                orientation='h',
                name=label,
                marker_color=color,
                legendgroup=label,
                legendrank=response_code,
                showlegend=legend_pending,
                text=f"{int(abs(val))}%",
                textposition="auto" if abs(val) > 5 else "outside",
                insidetextanchor="middle"
            ))

        # Add neutral response bar
        fig.add_trace(go.Bar(
            x=[neutral_pct],
            y=[display_group],
            orientation='h',
            name="Don't know / No answer",
            marker_color='lightgray',
            legendgroup="Neutral",
            legendrank=5,
            xaxis='x2',
            showlegend=legend_pending,
            text=f"{int(neutral_pct)}%",
            textposition="auto" if neutral_pct > 5 else "outside",
            insidetextanchor="middle"
        ))
        legend_pending = False

    # Layout
    fig.update_layout(
        barmode='relative',
        template='simple_white',
        title={
            'text': title or f"{column} by {grouping}: {years_label} {location_str}",
            'x': 0.5,
            'xanchor': 'center'
        },
        width=1000,
        height=450,
        xaxis=dict(
            title='Responses (%)',
            zeroline=True,
            tickvals=[-75, -50, -25, 0, 25, 50, 75],
            ticktext=["75%", "50%", "25%", "0", "25%", "50%", "75%"],
            domain=[0.0, 0.80]
        ),
        xaxis2=dict(
            title='Neutral (%)',
            range=[0, 20],
            anchor='y',
            side='right',
            domain=[0.85, 1.0],
            showline=False,
            showticklabels=False,
            visible=False,
        ),
        yaxis=dict(
            categoryorder='array',
            categoryarray=display_labels
        ),
        legend=dict(
            orientation='h',
            yanchor="bottom",
            y=1.02,
            xanchor='center',
            x=0.45,
        ),
        margin=dict(r=50)
    )

    return fig