In [1]:
from dash import Dash, dcc, html, Input, Output, State, callback
import dash
import dash_bootstrap_components as dbc
from sklearn.manifold import TSNE
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import linkage, fcluster
from collections import Counter, defaultdict
import plotly.graph_objects as go
import plotly.express as px
import threading
import numpy as np
import json

In [2]:
def load_activations(filename):
    """Load stored activations and their labels from an ``.npz`` archive.

    Args:
        filename: path to an ``.npz`` file containing the arrays
            ``"activations"`` and ``"labels"``.

    Returns:
        Tuple ``(activations, labels)`` of in-memory numpy arrays.
    """
    # np.load on an .npz returns an NpzFile that keeps the underlying file open.
    # Indexing it reads each array fully into memory, so the archive can (and
    # should) be closed as soon as both arrays have been read.
    with np.load(filename) as data:
        activations = data["activations"]
        labels = data["labels"]
    return activations, labels

def load_traits(dataset_path):
    """Read the trait dataset JSON and split it into parallel per-field lists.

    Args:
        dataset_path: path to a UTF-8 JSON file holding a list of entries, each
            with "trait", "trait_russian", "trait_ukrainian" and "sentiment" keys.

    Returns:
        Dict with keys "en", "ru", "uk", "sentiment"; each value is a list whose
        i-th element comes from the i-th entry of the file.
    """
    with open(dataset_path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    # Output key -> JSON field it is read from (order defines the result's key order).
    source_field = {
        "en": "trait",
        "ru": "trait_russian",
        "uk": "trait_ukrainian",
        "sentiment": "sentiment",
    }
    return {key: [entry[field] for entry in entries] for key, field in source_field.items()}

def group_activations(activations, labels, traits):
    """Bucket activations by trait name.

    Samples whose label is 0 go under "Neutral"; every other sample consumes the
    next name from ``traits["en"]`` in order, so the trait list is assumed to be
    aligned with the non-zero labels.

    Returns:
        ``defaultdict(list)`` mapping group name -> list of activations.
    """
    buckets = defaultdict(list)
    english_names = traits["en"]
    next_trait = 0

    for sample, tag in zip(activations, labels):
        if tag == 0:
            key = "Neutral"
        else:
            key = english_names[next_trait]
            next_trait += 1
        buckets[key].append(sample)

    return buckets

def extract_layer_activations(trait_activations, layer_index, average=False):
    """Select one layer from every sample of every group.

    Args:
        trait_activations: group name -> list of per-sample activations, each
            indexable by layer (``sample[layer_index]``).
        layer_index: layer to select.
        average: when True, collapse each group to its mean over samples with a
            leading axis of length 1; otherwise stack all samples.

    Returns:
        Dict with the same keys (and order) as ``trait_activations``.
    """
    def _layer_slice(samples):
        picked = [sample[layer_index] for sample in samples]
        if not average:
            return np.stack(picked)
        # Keep a leading axis so averaged and stacked outputs can both be vstacked.
        return np.mean(picked, axis=0)[None, :]

    return {name: _layer_slice(samples) for name, samples in trait_activations.items()}

In [3]:
# Load the stored activations and labels for each of the three models.
activations_llama, labels_llama = load_activations("activations_labels_llama.npz")
activations_mistral, labels_mistral = load_activations("activations_labels_mistral.npz")
activations_t5, labels_t5 = load_activations("activations_labels_t5.npz")

# Shapes per model: (n_samples, n_layers, hidden_dim) judging by the output below.
tuple(arr.shape for arr in (activations_llama, activations_mistral, activations_t5))

((18200, 32, 4096), (18200, 32, 4096), (18200, 48, 2048))

In [4]:
# All models were run over the same inputs, so their label vectors must match
# exactly. Fail loudly instead of falling back to None (which only surfaced
# later as an AttributeError), and check Mistral as well as T5. np.array_equal
# also handles arrays of different lengths, where `all(a == b)` breaks.
if not all(np.array_equal(labels_llama, other) for other in (labels_mistral, labels_t5)):
    raise ValueError("Label arrays differ between models; activations are not aligned.")
labels = labels_llama
labels.shape

(18200,)

In [5]:
# Parallel trait lists (English/Russian/Ukrainian names and sentiment) used to label the plots.
traits = load_traits(dataset_path="trait_combined_dataset.json")

In [6]:
# Bucket each model's activations by trait name (label 0 -> "Neutral").
grouped_activations_llama = group_activations(activations=activations_llama, labels=labels, traits=traits)
grouped_activations_mistral = group_activations(activations=activations_mistral, labels=labels, traits=traits)
grouped_activations_t5 = group_activations(activations=activations_t5, labels=labels, traits=traits)

# Sanity check: number of neutral samples.
len(grouped_activations_llama["Neutral"])

9100

In [11]:
def create_activation_visualization(model_data: dict, traits: dict, debug: bool = True, default_layer: int = 18):
    """Launch a Dash app that plots 2-D t-SNE projections of per-layer activations.

    Two views are offered for the selected model and layer:

    * clustered (default): per-trait mean activations are grouped with
      ``AgglomerativeClustering`` and the mean vector of each cluster is projected
      with t-SNE; hovering a point lists the traits in that cluster;
    * all points: every individual activation of the layer is projected with t-SNE.

    Projections are memoized in dicts local to this call, keyed by
    ``(model, layer)`` or ``(model, layer, n_clusters)``.

    Args:
        model_data: display name -> ``(activations, grouped_activations)``.
            ``activations.shape[1]`` is taken as the number of layers;
            ``grouped_activations`` maps a group name to its list of per-sample
            activations (the output of ``group_activations``).
        traits: parallel lists under ``"en"``, ``"uk"``, ``"ru"`` and
            ``"sentiment"`` (the output of ``load_traits``).
        debug: forwarded to ``app.run``.
        default_layer: layer preselected when a model is chosen, clamped to the
            model's layer range.
    """
    app = Dash(__name__)

    # Trait name -> sentiment label; the synthetic "Neutral" group has none.
    sentiments_map = dict(zip(traits["en"], traits["sentiment"]))
    sentiments_map["Neutral"] = "None"

    color_palette = px.colors.qualitative.Dark24 + px.colors.qualitative.Alphabet
    sentiment_colors = {"Positive": "green", "Negative": "red", "Neutral": "blue", "None": "gray"}

    # Assign palette colors by sorted group name rather than hash(): str hashes are
    # salted per interpreter process, so hash()-based colors changed on every restart.
    all_group_names = sorted({name for _, grouped in model_data.values() for name in grouped})
    trait_colors = {name: color_palette[i % len(color_palette)] for i, name in enumerate(all_group_names)}

    # English name -> index of its first occurrence (what list.index() returned), in O(1).
    first_trait_index = {}
    for position, name in enumerate(traits["en"]):
        first_trait_index.setdefault(name, position)

    precomputed_cluster_cache = {}  # (model, layer, n_clusters) -> data
    precomputed_all_cache = {}      # (model, layer) -> data

    def generate_marks(min_val, max_val, steps=10):
        """Roughly `steps` evenly spaced integer slider marks between min_val and max_val."""
        step = int(max(1, (max_val - min_val) // steps))
        return {i: str(i) for i in range(min_val, max_val + 1, step)}

    def localize(name, lang):
        """Display name of a group in `lang`; unknown names are returned unchanged."""
        if name == "Neutral" and lang != "en":
            return "Нейтрально"
        position = first_trait_index.get(name)
        return name if position is None else traits[lang][position]

    def sentiment_of(name):
        """Capitalized sentiment label of a group ("None" when unknown)."""
        return str(sentiments_map.get(name, "None")).capitalize()

    def figure_layout(title):
        """Shared figure layout: closest-point hover and toggleable legend entries."""
        return {
            "title": title,
            "hovermode": "closest",
            "legend": {
                "itemclick": "toggle",
                "itemdoubleclick": "toggleothers",
            },
        }

    model_names = list(model_data.keys())
    default_model = model_names[0]
    max_clusters_dict = {model: len(data[1]) for model, data in model_data.items()}
    max_layers_dict = {model: data[0].shape[1] - 1 for model, data in model_data.items()}

    app.layout = html.Div([
        html.Div([
            html.Label("Модель:", style={"color": "white"}),
            dcc.Dropdown(
                id="model-dropdown",
                options=[{"label": name, "value": name} for name in model_names],
                value=default_model,
                # A cleared dropdown would send None and break the callbacks.
                clearable=False
            )
        ]),
        dcc.Graph(id="tsne-graph"),
        html.Div([
            html.Label("Кількість кластерів:", style={"color": "white"}),
            dcc.Slider(id="cluster-slider")
        ]),
        html.Div([
            html.Label("Шар:", style={"color": "white"}),
            dcc.Slider(id="layer-slider")
        ]),
        html.Div([
            dcc.Checklist(
                id="color-mode",
                options=[{"label": "Показати кольори за настроєм", "value": "sentiment"}],
                value=[]
            )
        ], style={"color": "white"}),
        html.Div([
            dcc.Checklist(
                id="show-all-points",
                options=[{"label": "Показати всі активації без кластеризації", "value": "show_all"}],
                value=[]
            )
        ], style={"color": "white", "marginTop": "10px"}),
        html.Div([
            html.Label("Мова назв рис:", style={"color": "white"}),
            dcc.Dropdown(
                id="language-dropdown",
                options=[{"label": lang, "value": lang} for lang in ["en", "uk", "ru"]],
                value="en",
                clearable=False
            )
        ])
    ], style={"backgroundColor": "black", "color": "white", "padding": "20px"})

    @app.callback(
        Output("cluster-slider", "min"),
        Output("cluster-slider", "max"),
        Output("cluster-slider", "value"),
        Output("cluster-slider", "marks"),
        Output("layer-slider", "min"),
        Output("layer-slider", "max"),
        Output("layer-slider", "value"),
        Output("layer-slider", "marks"),
        Input("model-dropdown", "value")
    )
    def update_sliders(model):
        """Resize both sliders to the selected model's group and layer counts."""
        max_clusters = max_clusters_dict[model]
        max_layer = max_layers_dict[model]
        cluster_marks = {i: str(i) for i in range(2, max_clusters + 1) if i == 2 or i % 5 == 0}
        layer_marks = generate_marks(0, max_layer)
        initial_layer = max(0, min(default_layer, max_layer))
        return 2, max_clusters, max_clusters, cluster_marks, 0, max_layer, initial_layer, layer_marks

    def all_points_figure(model_name, layer_index, use_sentiment, lang):
        """Figure with every activation of one layer, one trace per group."""
        cache_key = (model_name, layer_index)
        if cache_key not in precomputed_all_cache:
            grouped_activations = model_data[model_name][1]
            trait_embeddings = extract_layer_activations(grouped_activations, layer_index, average=False)
            # Groups are stacked in dict order, so each occupies a contiguous row range.
            spans, start = {}, 0
            for name, vectors in trait_embeddings.items():
                spans[name] = (start, start + vectors.shape[0])
                start += vectors.shape[0]
            all_vectors = np.vstack(list(trait_embeddings.values()))
            precomputed_all_cache[cache_key] = {
                "spans": spans,
                "tsne_points": TSNE(n_components=2, random_state=0).fit_transform(all_vectors),
            }

        cached = precomputed_all_cache[cache_key]
        spans = cached["spans"]
        tsne_points = cached["tsne_points"]

        traces = []
        for trait in sorted(spans):
            start, stop = spans[trait]
            if start == stop:
                continue
            if use_sentiment:
                color = sentiment_colors.get(sentiment_of(trait), "gray")
            else:
                color = trait_colors.get(trait, "gray")
            localized_name = localize(trait, lang)
            traces.append(go.Scatter(
                x=tsne_points[start:stop, 0],
                y=tsne_points[start:stop, 1],
                mode='markers',
                name=localized_name,
                text=[localized_name] * (stop - start),
                hoverinfo='text',
                marker=dict(color=color, size=6),
                legendgroup=trait,
                showlegend=True
            ))
        return go.Figure(
            data=traces,
            layout=figure_layout(f"{model_name} (Layer {layer_index}, All Activations)")
        )

    def cluster_figure(model_name, layer_index, n_clusters, use_sentiment, lang):
        """Figure with one marker per agglomerative cluster of per-group mean activations."""
        cache_key = (model_name, layer_index, n_clusters)
        if cache_key not in precomputed_cluster_cache:
            grouped_activations = model_data[model_name][1]
            trait_embeddings = extract_layer_activations(grouped_activations, layer_index, average=True)
            trait_names = list(trait_embeddings.keys())
            vectors = np.vstack([trait_embeddings[name] for name in trait_names])
            labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(vectors)
            centroids = np.array([vectors[labels == i].mean(axis=0) for i in range(n_clusters)])
            # t-SNE needs perplexity < n_samples, and here n_samples == n_clusters.
            tsne_centroids = TSNE(
                n_components=2, perplexity=min(n_clusters - 1, 30), random_state=0
            ).fit_transform(centroids)
            precomputed_cluster_cache[cache_key] = {
                "trait_names": trait_names,
                "labels": labels,
                "tsne_centroids": tsne_centroids
            }

        cached = precomputed_cluster_cache[cache_key]
        trait_names = cached["trait_names"]
        labels = cached["labels"]
        tsne_centroids = cached["tsne_centroids"]

        traces = []
        for cluster_id in np.unique(labels):
            names = [trait_names[i] for i in np.flatnonzero(labels == cluster_id)]
            if use_sentiment:
                dominant_sentiment = Counter(sentiment_of(name) for name in names).most_common(1)[0][0]
                color = sentiment_colors.get(dominant_sentiment, "gray")
            else:
                color = color_palette[cluster_id % len(color_palette)]
            tooltip = "<br>".join(localize(name, lang) for name in names)
            traces.append(go.Scatter(
                x=[tsne_centroids[cluster_id, 0]],
                y=[tsne_centroids[cluster_id, 1]],
                mode='markers',
                name=f"Cluster {cluster_id}",
                text=tooltip,
                hoverinfo='text',
                marker=dict(color=color, size=12),
                legendgroup=f"cluster_{cluster_id}",
                showlegend=True
            ))
        return go.Figure(
            data=traces,
            layout=figure_layout(f"{model_name} (Layer {layer_index}, Clusters {n_clusters})")
        )

    @app.callback(
        Output("tsne-graph", "figure"),
        Input("model-dropdown", "value"),
        Input("cluster-slider", "value"),
        Input("layer-slider", "value"),
        Input("color-mode", "value"),
        Input("language-dropdown", "value"),
        Input("show-all-points", "value")
    )
    def update_graph(model_name, n_clusters, layer_index, color_mode, lang, show_all_points):
        """Render the t-SNE figure for the current control values."""
        # The sliders are declared without a value; skip until update_sliders fills them.
        if model_name is None or n_clusters is None or layer_index is None:
            return dash.no_update
        use_sentiment = "sentiment" in (color_mode or [])
        lang = lang or "en"
        if "show_all" in (show_all_points or []):
            return all_points_figure(model_name, int(layer_index), use_sentiment, lang)
        return cluster_figure(model_name, int(layer_index), int(n_clusters), use_sentiment, lang)

    app.run(debug=debug)

In [12]:
# Launch the interactive t-SNE explorer over all three models.
create_activation_visualization(
    model_data={
        "LLaMa": (activations_llama, grouped_activations_llama),
        "Mistral": (activations_mistral, grouped_activations_mistral),
        "FLAN-T5": (activations_t5, grouped_activations_t5),
    },
    traits=traits,
)