In [1]:
import os
import pickle

import pandas as pd
import plotly.graph_objects as go

# All data paths below (data/...) are relative to the directory one level above
# the notebook. Step up only when `data/` is not already visible from the cwd:
# an unconditional chdir("..") climbs one more level on every re-run of this
# cell and breaks every relative path afterwards.
# NOTE(review): assumes the notebook's own directory has no `data/` subfolder.
if not os.path.isdir("data"):
    os.chdir("..")

In [11]:

# Reload the pickled subsamples
def load_subsamples(dataset_name, dim, base_dir="data/subsamples"):
    """Load the pickled subsample collection for one dataset and dimension.

    Reads ``{base_dir}/{dataset_name}_{dim}_subsamples.pkl`` and returns the
    unpickled object unchanged. A relative ``base_dir`` is resolved against the
    current working directory.

    Parameters
    ----------
    dataset_name : str
        Dataset prefix used in the file name, e.g. ``"babe"``.
    dim : str
        Dimension suffix used in the file name, e.g. ``"source"``.
    base_dir : str, optional
        Directory holding the pickle files. Defaults to ``"data/subsamples"``,
        which reproduces the original hard-coded location.

    Returns
    -------
    object
        Whatever was pickled. In this notebook it is consumed as a dict mapping
        subset name -> {"vs": ..., "data": DataFrame}; confirm against the code
        that writes these files.

    Raises
    ------
    FileNotFoundError
        If the pickle file does not exist.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while unpickling. Only load files
    this project produced itself.
    """
    pickle_file_path = f"{base_dir}/{dataset_name}_{dim}_subsamples.pkl"
    with open(pickle_file_path, "rb") as f:
        return pickle.load(f)


def _aggregate_by_source(df):
    """Collapse *df* to one row per ``source``.

    Keeps the first ``reliability`` and ``bias`` value seen for each source and
    adds ``count``, the number of rows belonging to that source.
    """
    return (
        df.groupby("source")
        .agg(
            reliability=("reliability", "first"),
            bias=("bias", "first"),
            count=("source", "size"),
        )
        .reset_index()
    )


def _scale_sizes(counts, max_size):
    """Linearly rescale *counts* so the largest value maps to *max_size*.

    Returns all zeros when no count is positive (e.g. an empty subsample),
    instead of the NaN sizes that ``0 / 0`` would produce.
    """
    peak = counts.max()
    if not (peak > 0):  # also catches the NaN max of an empty Series
        return counts * 0.0
    return counts / peak * max_size


def _merge_subset_counts(whole_data, subset_df):
    """Attach per-source subsample counts and shares to every row of *whole_data*.

    Adds ``subset_count`` (0 for sources absent from the subsample) and
    ``subset_proportion``: the count divided by the subsample's total count,
    or 0 when the subsample is empty.
    """
    subset_counts = subset_df.groupby("source").size().reset_index(name="subset_count")
    subset_total = subset_counts["subset_count"].sum()
    merged = whole_data.merge(subset_counts, on="source", how="left")
    merged["subset_count"] = merged["subset_count"].fillna(0)
    merged["subset_proportion"] = merged["subset_count"] / subset_total if subset_total else 0.0
    return merged


def _subsample_trace(merged, subset_name, max_bubble_size):
    """Build the coloured scatter for one subsample on whole-dataset coordinates.

    Bubble size is the subsample count rescaled within this subsample. Colour is
    the source's share of the subsample.
    """
    return go.Scatter(
        x=merged["bias"],
        y=merged["reliability"],
        mode="markers",
        hovertext=merged["source"],
        hoverinfo="text",
        name=subset_name,
        marker=dict(
            size=_scale_sizes(merged["subset_count"], max_bubble_size),
            color=merged["subset_proportion"],
            colorscale="Redor",
            # NOTE(review): with many sources each share can be small, so a fixed
            # [0, 1] colour range may push most bubbles towards cmin. Confirm intended.
            cmin=0,
            cmax=1,
            showscale=True,
            colorbar=dict(title="Proportion", tickformat=".2f"),
        ),
    )


def create_plot_for_dataset(
    dataset_name,
    dataset,
    source_subsamples,
    *,
    max_bubble_size=50,
    x_range=(-42, 42),
    y_range=(0, 64),
    frame_duration=500,
    show=True,
):
    """Plot all sources on a bias/reliability plane and animate through subsamples.

    The whole dataset is drawn as a grey backdrop, with bubble size proportional
    to the number of rows per source. A second, coloured trace shows one
    subsample at a time. A slider and Play/Pause buttons step through the
    subsamples in ascending ``vs`` order.

    Parameters
    ----------
    dataset_name : str
        Used only in the figure title.
    dataset : pandas.DataFrame
        Must have ``source``, ``reliability`` and ``bias`` columns.
    source_subsamples : dict
        Maps subset name -> ``{"vs": number, "data": DataFrame}``. Each
        ``data`` frame needs a ``source`` column.
    max_bubble_size : float, optional
        Marker size given to the largest count within each trace.
    x_range, y_range : tuple, optional
        Fixed axis ranges for bias and reliability. The defaults (-42..42 and
        0..64) presumably match the scale of the ratings; confirm.
    frame_duration : int, optional
        Milliseconds per animation frame.
    show : bool, optional
        Call ``fig.show()`` before returning (the original behaviour).

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure, so callers can save or customise it further.
    """
    whole_data = _aggregate_by_source(dataset)
    whole_data["normalized_size"] = _scale_sizes(whole_data["count"], max_bubble_size)

    # Slider order: lowest diversity score ("vs") first.
    subset_names = sorted(source_subsamples, key=lambda name: source_subsamples[name]["vs"])
    merged_by_subset = {
        name: _merge_subset_counts(whole_data, source_subsamples[name]["data"])
        for name in subset_names
    }

    # Frames target trace index 1 only. Without `traces`, a frame replaces
    # trace 0 and the grey whole-dataset backdrop would be lost after the
    # first step.
    frames = [
        go.Frame(
            data=[_subsample_trace(merged, name, max_bubble_size)],
            traces=[1],
            name=name,
        )
        for name, merged in merged_by_subset.items()
    ]

    backdrop = go.Scatter(
        x=whole_data["bias"],
        y=whole_data["reliability"],
        mode="markers",
        hovertext=whole_data["source"],
        hoverinfo="text",
        marker=dict(size=whole_data["normalized_size"], color="gray", opacity=0.5),
        name="Whole Dataset",
    )
    traces = [backdrop]
    if subset_names:
        # Start on the first subsample so the plot agrees with the slider's "active": 0.
        first = subset_names[0]
        traces.append(_subsample_trace(merged_by_subset[first], first, max_bubble_size))

    fig = go.Figure(data=traces, frames=frames)

    slider_steps = [
        {
            "label": f"VS: {source_subsamples[name]['vs']:.2f}",
            "method": "animate",
            "args": [[name], {"frame": {"duration": frame_duration, "redraw": True}, "mode": "immediate"}],
        }
        for name in subset_names
    ]

    fig.update_layout(
        title=f"Source subsamples by diversity measure for {dataset_name}",
        xaxis=dict(title="Bias", range=list(x_range)),
        yaxis=dict(title="Reliability", range=list(y_range)),
        # Keep the original legend-free look: grey = whole dataset,
        # coloured = the currently selected subsample.
        showlegend=False,
        updatemenus=[{
            "type": "buttons",
            "showactive": True,
            "buttons": [
                {
                    "label": "Play",
                    "method": "animate",
                    "args": [None, {"frame": {"duration": frame_duration, "redraw": True}, "fromcurrent": True}],
                },
                {
                    "label": "Pause",
                    "method": "animate",
                    # [None] (a list holding None) is Plotly's idiom for stopping an animation.
                    "args": [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}],
                },
            ],
        }],
        sliders=[{"active": 0, "pad": {"t": 50}, "steps": slider_steps}],
    )

    if show:
        fig.show()
    return fig


In [12]:
# Read each enriched dataset (same three files, same order), then render one
# animated source-subsample plot per dataset.
datasets = {
    name: pd.read_parquet(f"data/enriched/{name}_full.parquet")
    for name in ("annomatic", "babe", "basil")
}

for dataset_name in datasets:
    df = datasets[dataset_name]
    source_subsamples = load_subsamples(dataset_name, "source")
    create_plot_for_dataset(dataset_name, df, source_subsamples)