In [1]:
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

def compute_clusters_over_time(labels_unaligned, num_days_to_cluster):
    """Relabel each day's clusters so cluster ids stay consistent across days.

    Cluster ids assigned independently on each day are arbitrary, so the same
    group of items can carry a different id from one day to the next.  For
    every day after the first, this finds the one-to-one relabelling that
    maximises the overlap with the previous day's *aligned* labels (Hungarian
    matching on the confusion matrix) and applies it.  Because each day is
    matched against the already-aligned previous day, all returned rows share
    the label space of day 0.

    Parameters
    ----------
    labels_unaligned : sequence of 1-D array-like
        ``labels_unaligned[day]`` holds one cluster label per item for that
        day.  Items are compared position by position, so the item order must
        be identical on every day.
    num_days_to_cluster : int
        Number of leading days to process.

    Returns
    -------
    numpy.ndarray
        The aligned labels stacked along axis 0, one row per processed day.

    Raises
    ------
    ValueError
        If ``num_days_to_cluster`` is 0 (``np.stack`` gets nothing to stack).
    IndexError
        If ``num_days_to_cluster`` exceeds ``len(labels_unaligned)``.
    """
    labels_list = []
    for day in tqdm(range(0, num_days_to_cluster)):
        labels_today = np.asarray(labels_unaligned[day])

        if day == 0:
            # The first day defines the reference label space.
            labels_aligned = labels_today
        else:
            # Match against the previous day's ALIGNED labels so that the
            # relabelling chains through time instead of only holding pairwise.
            labels_last = labels_list[-1]

            # Build a square confusion matrix over every label seen on either
            # day.  Passing the label set explicitly fixes which label each
            # row/column position stands for, even when a label is missing on
            # one of the days or the ids are not exactly 0..k-1.
            all_labels = np.union1d(labels_last, labels_today)
            conf_mat = confusion_matrix(labels_last, labels_today, labels=all_labels)

            # Maximise total overlap (minimise the negated counts).
            row_ind, col_ind = linear_sum_assignment(-conf_mat)

            # row_ind / col_ind are positions in all_labels, not label values:
            # translate "today's label at position c" -> "yesterday's label at
            # position r".  The matrix is square, so every label gets a target.
            relabel = np.empty_like(all_labels)
            relabel[col_ind] = all_labels[row_ind]
            labels_aligned = relabel[np.searchsorted(all_labels, labels_today)]

        labels_list.append(labels_aligned)

    labels_arr = np.stack(labels_list, axis=0)
    return labels_arr

In [2]:
import os
# Switch the working directory to "<parent of the current directory>/Data/Cluster Labels".
# NOTE: PATH is derived from the *current* working directory, so running this
# cell a second time in the same kernel resolves a different folder.
PATH = os.path.dirname(os.getcwd())
# os.path.join builds the path with the platform's separator instead of a
# hand-written "/".
os.chdir(os.path.join(PATH, "Data", "Cluster Labels"))

In [4]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# Alternative loader: concatenate every parquet file in the working directory.
# labels = [pd.read_parquet(name) for name in os.listdir() if name.endswith(".parquet")]
# labels = pd.concat(labels)

# Load by filename relative to the working directory (an earlier cell chdir()s
# into "Data/Cluster Labels") rather than a machine-specific absolute path, so
# the notebook runs on any checkout of the project.
labels = pd.read_parquet("cluster_labels_8c_252d.parquet")
print(labels.head())

# One row of `labels` per date, one column per ticker; align the cluster ids of
# the first 10 days so they are comparable from day to day.
assignments = compute_clusters_over_time(labels.values, num_days_to_cluster=10)

# Sorted set of cluster ids that occur anywhere in the aligned window.
unique_clusters = np.unique(assignments)
print(unique_clusters)

# Give every cluster id one fixed colour so it is drawn identically on every day.
# Colours come from Plotly Express' qualitative "Plotly" palette and wrap around
# (modulo) when there are more clusters than palette entries.
palette = px.colors.qualitative.Plotly
n_palette_colors = len(palette)
cluster_color_map = dict(
    zip(
        unique_clusters,
        (palette[position % n_palette_colors] for position in range(len(unique_clusters))),
    )
)

# One Sankey node per (day, cluster) pair.  The three containers are filled in
# lock-step: position i of node_labels / node_colors describes node i, and
# node_lookup maps (day, cluster) back to that position.
node_labels, node_colors, node_lookup = [], [], {}

# Nodes are emitted in ascending cluster order within each day.
ordered_clusters = sorted(unique_clusters)

for day in range(assignments.shape[0]):
    for cluster in ordered_clusters:
        # The next free position in the node list becomes this node's index.
        node_lookup[(day, cluster)] = len(node_labels)
        # The label shows only the cluster id; the day is conveyed by position.
        node_labels.append(f"Cluster {cluster}")
        node_colors.append(cluster_color_map[cluster])

# Links between consecutive days: for every ordered pair of clusters, count the
# items that sit in `cluster_from` on one day and in `cluster_to` on the next.
# Only non-empty transitions become links.
source, target, value = [], [], []

for day in range(assignments.shape[0] - 1):
    labels_today = assignments[day]
    labels_tomorrow = assignments[day + 1]

    for cluster_from in unique_clusters:
        leaving = labels_today == cluster_from
        for cluster_to in unique_clusters:
            arriving = labels_tomorrow == cluster_to
            count = np.sum(leaving & arriving)
            if count > 0:
                source.append(node_lookup[(day, cluster_from)])
                target.append(node_lookup[(day + 1, cluster_to)])
                value.append(count)

# Assemble the Sankey trace: node appearance first, then the flows between nodes.
sankey_nodes = dict(
    label=node_labels,
    color=node_colors,
    pad=15,
    thickness=20,
    line=dict(color="black", width=0.5),
)
sankey_links = dict(source=source, target=target, value=value)

fig = go.Figure(data=[go.Sankey(node=sankey_nodes, link=sankey_links)])

# Caption each column of nodes with its day index.  The captions are spread
# evenly over the paper's x range and placed just below the plot area.
num_days = assignments.shape[0]
x_positions = np.linspace(0, 1, num_days)

annotations = [
    dict(
        text=f"Day {i}",
        x=x,
        y=-0.1,  # slightly below the diagram, in paper coordinates
        xref='paper',
        yref='paper',
        showarrow=False,
        font=dict(size=12),
    )
    for i, x in enumerate(x_positions)
]

fig.update_layout(
    annotations=annotations,
    font_size=10,
    title_text="Cluster Assignment Transitions Over Time",
)

fig.show()


ticker      AA  AAA  AAC  AAG  AAI  AAIC  AAIR  AAL  AAM  AAN  ...  ZQK  ZTO  \
date                                                           ...             
2000-12-29   5    2    6    7    3     2     7    2    7    2  ...    2    2   
2001-01-02   3    6    0    1    7     6     7    6    7    6  ...    6    6   
2001-01-03   1    1    2    7    6     1     7    1    6    1  ...    1    1   
2001-01-04   7    0    1    6    7     0     6    0    7    0  ...    0    0   
2001-01-05   4    0    4    6    4     0     6    0    4    0  ...    0    0   

ticker      ZTR  ZTS  ZUO  ZVO  ZWS  ZX  ZYME  ZZ  
date                                               
2000-12-29    3    2    2    2    2   2     2   2  
2001-01-02    7    6    6    6    6   6     6   6  
2001-01-03    6    1    1    1    1   1     1   1  
2001-01-04    7    0    0    0    0   0     0   0  
2001-01-05    6    0    0    0    0   0     0   0  

[5 rows x 6089 columns]


100%|██████████| 10/10 [00:00<00:00, 535.72it/s]

[0 1 2 3 4 5 6 7]



