In [None]:
!pip install datasets tensordict

In [None]:
!mkdir assets

In [None]:
CONFIG = "debug"

RUN_STATS = True
RUN_CLF = False
RUN_CLS = True
USE_C = False

In [None]:
from datasets import load_dataset
import torch
import numpy as np

# Test split of the feature dataset for the selected CONFIG.
dataset = load_dataset("lczero-planning/lczero-planning-features", CONFIG, split="test")

In [None]:
# Analysis / inference only: autograd is switched off globally for the rest of the notebook.
torch.autograd.set_grad_enabled(False)

In [None]:
# Keep only the columns used below, served as torch tensors.
f_ds = dataset.select_columns(["opt_features", "sub_features", "pixel_index", "root_fen"]).with_format("torch")
f_ds

In [None]:
# N rows, each carrying an opt and a sub feature vector of width D_F
# (D_F is passed as the SAE dictionary size further down).
N = len(f_ds)
D_F = f_ds[0]["opt_features"].shape[0]

In [None]:
"""
Defines the dictionary classes
"""

import torch
import torch.nn as nn
from tensordict import TensorDict


class SparseAutoEncoder(nn.Module):
    """
    A 2-layer sparse autoencoder.

    ``encode`` computes the pre-activations ``x @ W_enc + b_enc``; ``forward`` applies a ReLU
    to get the features ``f`` and reconstructs ``x_hat = f @ W_dec + b_dec``.

    Args:
        activation_dim: width of the input activations ``x``.
        dict_size: number of dictionary features (rows of ``W_dec``).
        pre_bias: if True, ``b_dec`` is subtracted from ``x`` before encoding.
        init_normalise_dict: decoder initialisation. ``None`` keeps the Kaiming-uniform init.
            ``"l2"`` rescales every non-zero row to norm 0.1. ``"less_than_1"`` rescales only
            rows whose norm exceeds 1 down to norm 1.
    """

    def __init__(
        self,
        activation_dim,
        dict_size,
        pre_bias=False,
        init_normalise_dict=None,
    ):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.pre_bias = pre_bias
        self.init_normalise_dict = init_normalise_dict

        self.b_enc = nn.Parameter(torch.zeros(self.dict_size))
        self.relu = nn.ReLU()

        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(
                    self.dict_size,
                    self.activation_dim,
                )
            )
        )
        if init_normalise_dict == "l2":
            self.normalize_dict_(less_than_1=False)
            # W_dec is a leaf Parameter that requires grad. An in-place update must run under
            # no_grad; a bare `self.W_dec *= 0.1` raises whenever autograd is enabled.
            with torch.no_grad():
                self.W_dec.mul_(0.1)
        elif init_normalise_dict == "less_than_1":
            self.normalize_dict_(less_than_1=True)

        # Initialise the encoder as the decoder's transpose, but give it its own storage.
        # nn.Parameter(W_dec.t()) would alias W_dec: load_state_dict or an optimizer step
        # writing into one of them would silently overwrite the other.
        self.W_enc = nn.Parameter(
            self.W_dec.detach().t().clone(memory_format=torch.contiguous_format)
        )
        self.b_dec = nn.Parameter(
            torch.zeros(
                self.activation_dim,
            )
        )

    @torch.no_grad()
    def normalize_dict_(
        self,
        less_than_1=False,
    ):
        """Rescale the decoder rows in place.

        less_than_1=False: every non-zero row is scaled to unit L2 norm.
        less_than_1=True: only rows with norm > 1 are scaled down to unit norm.
        All-zero rows are left untouched, which avoids dividing by zero.
        """
        norm = self.W_dec.norm(dim=1)
        positive_mask = norm != 0
        if less_than_1:
            greater_than_1_mask = (norm > 1) & (positive_mask)
            self.W_dec[greater_than_1_mask] /= norm[greater_than_1_mask].unsqueeze(1)
        else:
            self.W_dec[positive_mask] /= norm[positive_mask].unsqueeze(1)

    def encode(self, x):
        """Encoder pre-activations ``x @ W_enc + b_enc`` (no ReLU applied here)."""
        return x @ self.W_enc + self.b_enc

    def decode(self, f):
        """Reconstruction ``f @ W_dec + b_dec`` from feature activations ``f``."""
        return f @ self.W_dec + self.b_dec

    def forward(self, x, output_features=False, ghost_mask=None):
        """
        Forward pass of an autoencoder.
        x : activations to be autoencoded
        output_features : if True, return the encoded features as well
            as the decoded x
        ghost_mask : if not None, run this autoencoder in "ghost mode"
            where features are masked

        Returns a TensorDict with batch size ``x.shape[0]`` containing "x_hat", plus
        "features" (post-ReLU) when output_features is True and "x_ghost"
        (``exp(pre-activations) * ghost_mask`` decoded without ``b_dec``) when a
        ghost_mask is given.
        """
        if self.pre_bias:
            x = x - self.b_dec
        f_pre = self.encode(x)
        out = TensorDict({}, batch_size=x.shape[0])
        if ghost_mask is not None:
            f_ghost = torch.exp(f_pre) * ghost_mask.to(f_pre)
            x_ghost = f_ghost @ self.W_dec
            out["x_ghost"] = x_ghost
        f = self.relu(f_pre)
        if output_features:
            out["features"] = f
        x_hat = self.decode(f)
        out["x_hat"] = x_hat
        return out

In [None]:
from huggingface_hub import HfApi
import torch

hf_api = HfApi()
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Only the checkpoint of the selected CONFIG is used below, so skip the other configs.
hf_api.snapshot_download(
    repo_id="lczero-planning/lczero-planning-saes",
    repo_type="model",
    local_dir="./assets/saes",
    allow_patterns=[f"{CONFIG}/*"],
)
# The checkpoint is a downloaded state dict. weights_only=True refuses to unpickle arbitrary
# objects (a plain torch.load can execute code embedded in the file).
sae_dict = torch.load(
    f"./assets/saes/{CONFIG}/model.pt",
    map_location=DEVICE,
    weights_only=True,
)
# Take the architecture from the checkpoint instead of hard-coding the activation width.
dict_size, activation_dim = sae_dict["W_dec"].shape
assert dict_size == D_F, f"SAE has {dict_size} features but the dataset stores {D_F}"
sae = SparseAutoEncoder(
    activation_dim,
    dict_size,
    pre_bias=False,
    init_normalise_dict=None,
)
sae.load_state_dict(sae_dict)
sae.to(DEVICE)

In [None]:
# Sorted distinct root positions. A FEN's position in `unique_fens` serves as its row id in the
# per-FEN statistics below; unique_indices is the first dataset row holding each FEN.
unique_fens , unique_indices = np.unique(dataset["root_fen"], return_index=True, axis=0)

In [None]:
unique_fens

In [None]:
# Spot-check: row id of one specific root position.
np.where(unique_fens=='rr6/1b2q3/1p1p2pk/4n3/1PPRN2p/4Pp1P/1Q3PP1/1R3BK1 b - - 0 32')[0][0]

In [None]:
import torch

if RUN_STATS:
    base_stats = {
        # per-feature mean activation over all 2*N vectors (opt + sub)
        "mean": torch.zeros(D_F),
        # per-feature number of vectors on which the feature is non-zero
        "active": torch.zeros(D_F),
        # per-feature activation counts per pixel_index (64 rows)
        "active_p": torch.zeros((64, D_F)),
        # one row per distinct root FEN, sized from the data rather than a hard-coded 500
        "active_fen": torch.zeros((len(unique_fens), D_F)),
        # per-vector series (plain Python numbers) for the box plots below
        "opt_dead_loss": [],
        "opt_active_loss": [],
        "sub_dead_loss": [],
        "sub_active_loss": [],
        "c_diff_loss": [],
        "d_prod_loss": [],
    }

    # Row of each FEN in base_stats["active_fen"]: an O(1) lookup instead of an np.where scan
    # over all unique FENs for every sample.
    fen_to_row = {fen: row for row, fen in enumerate(unique_fens)}

    def compute_stats(batch, stats):
        """Accumulate activation statistics for one batch of `f_ds` into `stats` (in place).

        Both the opt and the sub feature vector of every sample are counted; they share the
        sample's ``pixel_index`` and ``root_fen``.

        Per-vector series appended:
            *_active_loss / *_dead_loss: number of non-zero / zero features,
            c_diff_loss: L1 norm of c_opt - c_sub,
            d_prod_loss: L1 norm of the element-wise product d_opt * d_sub.

        Returns None; the function is used for its side effects only.
        """
        opt_features = batch["opt_features"]
        sub_features = batch["sub_features"]

        stats["mean"] += opt_features.sum(dim=0) / (2 * N)
        stats["mean"] += sub_features.sum(dim=0) / (2 * N)

        opt_alive = opt_features != 0
        sub_alive = sub_features != 0

        stats["active"] += opt_alive.sum(dim=0) + sub_alive.sum(dim=0)
        # Walk the opt and sub rows separately. Zipping the stacked 2B-row activity matrix with
        # the B-long metadata would stop after the opt rows and never count the sub vectors.
        for alive in (opt_alive, sub_alive):
            for row, p, fen in zip(alive, batch["pixel_index"], batch["root_fen"]):
                stats["active_p"][p] += row
                stats["active_fen"][fen_to_row[fen]] += row

        opt_n_active = opt_alive.sum(dim=1)
        sub_n_active = sub_alive.sum(dim=1)
        stats["opt_active_loss"] += opt_n_active.tolist()
        stats["opt_dead_loss"] += (D_F - opt_n_active).tolist()
        stats["sub_active_loss"] += sub_n_active.tolist()
        stats["sub_dead_loss"] += (D_F - sub_n_active).tolist()

        c_opt, d_opt = opt_features.chunk(2, dim=1)
        c_sub, d_sub = sub_features.chunk(2, dim=1)

        stats["c_diff_loss"] += (c_opt - c_sub).norm(p=1, dim=1).tolist()
        # d_prod_loss is the norm of the d product, not a second copy of the c difference.
        stats["d_prod_loss"] += (d_opt * d_sub).norm(p=1, dim=1).tolist()

    f_ds.map(
        compute_stats,
        batched=True,
        fn_kwargs={"stats": base_stats},
        # compute_stats works purely by side effect, so it must really run every time.
        load_from_cache_file=False,
    )


In [None]:
from torch.nn.functional import kl_div

if RUN_STATS:
    def summarize_activity(name, cols):
        """Print activity and locality statistics for the dictionary features selected by `cols`.

        Printed, in order:
        - the section name;
        - the number of rarely-active features (active on < 0.1% of the 2*N feature vectors);
        - the number of densely-active features (> 10%);
        - the mean pixel entropy and the mean per-pixel entropy std;
        - the same two numbers over root FENs;
        - 100 * the std of the pixel and FEN probability tables.

        The tables are ``active_p / active`` and ``active_fen / active``. Features that never
        fired are left out of them: for those features the ratio is 0/0 = NaN, which would turn
        every printed mean and std into NaN.
        """
        active = base_stats["active"][cols]
        print(name)
        print((active < (2 * N * 0.001)).sum())
        print((active > (2 * N * 0.1)).sum())

        alive = active > 0
        p_probs = base_stats["active_p"][:, cols][:, alive] / active[alive]
        fen_probs = base_stats["active_fen"][:, cols][:, alive] / active[alive]
        for probs in (p_probs, fen_probs):
            H = -torch.xlogy(probs, probs)
            print(H.sum(dim=0).mean(), H.std(dim=0).mean())
        print(p_probs.std() * 100)
        print(fen_probs.std() * 100)

    half = D_F // 2
    for section, cols in (("c", slice(None, half)), ("d", slice(half, None)), ("f", slice(None))):
        summarize_activity(section, cols)


In [None]:
print((base_stats["active_fen"]==0).sum())

In [None]:
import matplotlib.pyplot as plt

if RUN_STATS:
    plt.hist(base_stats["active"]/(2*N), bins=200)
    plt.xlabel("Active rate")
    plt.show()
if RUN_STATS:
    plt.hist(base_stats["active"]/(2*N), bins=np.logspace(-4, -0.5, 200))
    plt.xlabel("Active rate")
    plt.xscale('log')
    plt.show()

In [None]:
if RUN_STATS:
    base_stats["dead_loss"] = base_stats["opt_dead_loss"] + base_stats["sub_dead_loss"]
    base_stats["active_loss"] = base_stats["opt_active_loss"] + base_stats["sub_active_loss"]

In [None]:
if RUN_STATS:
    labels = [
        "dead_loss",
        "opt_dead_loss",
        "sub_dead_loss",
        "active_loss",
        "opt_active_loss",
        "sub_active_loss",
    ]
    boxed_data = [base_stats[label] for label in labels]

    plt.boxplot(boxed_data, notch=True, vert=True, patch_artist=True, labels=labels)
    plt.ylabel("Metric value")
    plt.xticks(rotation=20)
    plt.show()

In [None]:
if RUN_STATS:
    labels = [
        "c_diff_loss",
        "d_prod_loss",
    ]
    boxed_data = [base_stats[label] for label in labels]

    plt.boxplot(boxed_data, notch=True, vert=True, patch_artist=True, labels=labels)
    plt.ylabel("Metric value")
    plt.show()

In [None]:
# Load each feature column once (every f_ds[...] column access decodes the whole column again),
# then build the stacked [opt; sub] matrices: rows < N are opt vectors, rows >= N sub vectors.
opt_features_all = f_ds["opt_features"]
sub_features_all = f_ds["sub_features"]

c_opt, d_opt = opt_features_all.chunk(2, dim=1)
c_sub, d_sub = sub_features_all.chunk(2, dim=1)

c_f = torch.cat([c_opt, c_sub], dim=0)
d_f = torch.cat([d_opt, d_sub], dim=0)
f = torch.cat([opt_features_all, sub_features_all], dim=0)

In [None]:
from sklearn.linear_model import LogisticRegression


if RUN_CLF:
    # Linear probe: can the features tell the opt vector (label 1, first N rows of c_f / d_f / f)
    # from the sub vector (label 0, last N rows)?
    labels = torch.cat([torch.ones(N), torch.zeros(N)])
    clf = LogisticRegression(penalty="l2", solver="lbfgs", max_iter=1000)


In [None]:
from sklearn.metrics import f1_score, recall_score, precision_score

if RUN_CLF:

    # Probe the c half, the d half and the full feature vector in turn.
    for X in [c_f, d_f, f]:
        clf.fit(X, labels)

        # NOTE(review): predictions are scored on the training data itself, so these metrics
        # measure linear separability, not generalisation (use a held-out split for that).
        y_pred = clf.predict(X)

        metrics = {}
        metrics["precision"] = precision_score(labels, y_pred)
        metrics["recall"] = recall_score(labels, y_pred)
        metrics["f1"] = f1_score(labels, y_pred)
        print(metrics)

In [None]:
if RUN_CLF:
    # Coefficients of the last fitted probe (the one trained on the full vector `f`).
    plt.hist(clf.coef_[0], bins=200)
    plt.xlabel("coef")
    plt.show()

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.cluster import AgglomerativeClustering, SpectralClustering, KMeans, FeatureAgglomeration
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, NMF

small_c_f, _, small_d_f, _, small_indices, _,= train_test_split(c_f, d_f, torch.arange(N*2), train_size=0.01, random_state=42)





In [None]:
tsne = TSNE(n_components=2, perplexity=20)
pca = PCA(n_components=50)
nmf = NMF(n_components=50)

In [None]:
# Agglomerative clustering (100 clusters) of the sampled c and d feature vectors.
# NOTE(review): c_labels / d_labels are overwritten further down by clusterings in NMF space.
cls = AgglomerativeClustering(
    n_clusters=100,
)
cls.fit(small_c_f)
c_labels = cls.labels_
print(small_d_f.shape)
cls.fit(small_d_f)
d_labels = cls.labels_

In [None]:
# X_pca = pca.fit_transform(small_c_f)
# X_c_tsne = tsne.fit_transform(X_pca)

# # Plot the clustered data
# plt.scatter(X_c_tsne[:, 0], X_c_tsne[:, 1], c=c_labels, cmap="tab20")
# plt.xlabel("AU 1")
# plt.ylabel("AU 2")

In [None]:
# NMF (50 components) of the sampled c vectors, a 2-D t-SNE of that embedding for display, and a
# re-clustering in NMF space that replaces the c_labels computed from the raw vectors.
X_nmf = nmf.fit_transform(small_c_f)
X_c_tsne = tsne.fit_transform(X_nmf)

cls.fit(X_nmf)
c_labels = cls.labels_

# Plot the clustered data
plt.scatter(X_c_tsne[:, 0], X_c_tsne[:, 1], c=c_labels, cmap="rainbow")
plt.xlabel("AU 1")
plt.ylabel("AU 2")

In [None]:
# Plot the clustered data
# Same scatter with a qualitative colormap.
plt.scatter(X_c_tsne[:, 0], X_c_tsne[:, 1], c=c_labels, cmap="tab20")
plt.xlabel("AU 1")
plt.ylabel("AU 2")

In [None]:
# X_pca = pca.fit_transform(small_d_f)
# X_d_tsne = tsne.fit_transform(X_pca)

# # Plot the clustered data
# plt.scatter(X_d_tsne[:, 0], X_d_tsne[:, 1], c=d_labels, cmap="tab20")
# plt.xlabel("AU 1")
# plt.ylabel("AU 2")

In [None]:
# Same NMF -> t-SNE -> re-clustering pipeline for the d vectors (replaces d_labels).
X_nmf = nmf.fit_transform(small_d_f)
X_d_tsne = tsne.fit_transform(X_nmf)

cls.fit(X_nmf)
d_labels = cls.labels_

# Plot the clustered data
plt.scatter(X_d_tsne[:, 0], X_d_tsne[:, 1], c=d_labels, cmap="tab20")
plt.xlabel("AU 1")
plt.ylabel("AU 2")

In [None]:
# Cluster sizes of the NMF-space c clustering.
unique, counts = np.unique(c_labels, return_counts=True)
counts

In [None]:
def compute_c_stats(batch, stats):
    """Accumulate feature-activity counts for one batch of the feature dataset (in place).

    Both the opt and the sub feature vector of every sample are counted. ``stats["active"]``
    receives the per-feature number of non-zero entries. The same boolean activity rows are
    added to ``stats["active_p"][pixel_index]`` and to ``stats["active_fen"][row]``, where
    ``row`` is the position of the sample's root FEN in the global `unique_fens`.

    Args:
        batch: mapping with "opt_features" / "sub_features" (one feature vector per row),
            "pixel_index" and "root_fen" (one entry per row), i.e. the columns of `f_ds`.
        stats: dict of accumulator tensors with keys "active", "active_p", "active_fen".

    Raises:
        KeyError: if a root FEN does not appear in `unique_fens`.
    """
    fen_to_row = {fen: row for row, fen in enumerate(unique_fens)}
    opt_alive = batch["opt_features"] != 0
    sub_alive = batch["sub_features"] != 0

    stats["active"] += opt_alive.sum(dim=0) + sub_alive.sum(dim=0)
    # opt and sub vectors share the sample's pixel and root FEN. Zipping the stacked 2B-row
    # activity matrix with the B-long metadata would stop after the opt rows.
    for alive in (opt_alive, sub_alive):
        for row, p, fen in zip(alive, batch["pixel_index"], batch["root_fen"]):
            stats["active_p"][p] += row
            stats["active_fen"][fen_to_row[fen]] += row

def return_H(indices):
    """Entropies of the metadata distribution of a set of feature vectors.

    `indices` index the stacked [opt; sub] feature matrix (2*N rows). Index i < N is the opt
    vector of sample i; index i >= N is the sub vector of sample i - N.

    Returns a dict of 0-d tensors (natural-log entropies):
        "p":   entropy of the pixel_index distribution (64 bins),
        "fen": entropy of the root-FEN distribution (one bin per entry of `unique_fens`),
        "opt": entropy of the opt / sub split.
    An empty `indices` yields NaN entropies (0/0 probabilities).
    """
    fen_to_row = {fen: row for row, fen in enumerate(unique_fens)}
    counts = {
        "p": torch.zeros(64),
        # sized from the data instead of a hard-coded 500, so any number of root positions fits
        "fen": torch.zeros(len(unique_fens)),
        "opt": torch.zeros(2),
    }
    for idx in indices:
        idx = int(idx)
        is_sub = idx >= N
        counts["opt"][int(is_sub)] += 1
        sample = f_ds[idx - N if is_sub else idx]
        counts["p"][sample["pixel_index"]] += 1
        counts["fen"][fen_to_row[sample["root_fen"]]] += 1

    n = len(indices)
    return {key: -torch.xlogy(c / n, c / n).sum() for key, c in counts.items()}


In [None]:
f_ds[0]

In [None]:
Hs = []
for i in range(100):
    bool_index = i == c_labels
    sub_indices = small_indices[bool_index]
    Hs.append(return_H(sub_indices))


In [None]:
H_p = []
H_f = []
H_o = []

for H in Hs:
    H_p.append(H["p"])
    H_f.append(H["fen"])
    H_o.append(H["opt"])
print(np.mean(H_p), np.std(H_p))
print(np.mean(H_f), np.std(H_f))
print(np.mean(H_o), np.std(H_o))

In [None]:
Hs = []
for i in range(100):
    bool_index = i == d_labels
    sub_indices = small_indices[bool_index]
    Hs.append(return_H(sub_indices))


In [None]:
H_p = []
H_f = []
H_o = []

for H in Hs:
    H_p.append(H["p"])
    H_f.append(H["fen"])
    H_o.append(H["opt"])
print(np.mean(H_p), np.std(H_p))
print(np.mean(H_f), np.std(H_f))
print(np.mean(H_o), np.std(H_o))

In [None]:
import numpy as np
unique, counts = np.unique(c_labels, return_counts=True)
counts

In [None]:
unique, counts = np.unique(d_labels, return_counts=True)
counts

In [None]:
np.corrcoef(c_labels, d_labels)

In [None]:
max_d = 2
max_c = 1
np.corrcoef(c_labels==max_c, d_labels==max_c)

In [None]:
np.corrcoef(np.eye(100)[c_labels].transpose())

In [None]:
cor = np.corrcoef(np.eye(100)[c_labels].transpose(), np.eye(100)[d_labels].transpose())
cor[100:, :100].max(axis=1).mean()

In [None]:
other_d_f, _, other_f, _, = train_test_split(d_f, f, train_size=0.04, random_state=42)
dtf_nmf = nmf.fit_transform(other_d_f.T)

In [None]:
dtf_nmf.shape

In [None]:
tsne = TSNE(n_components=2, perplexity=20)
cls = AgglomerativeClustering(
    n_clusters=100,
    linkage="average"
)
cls.fit(dtf_nmf)
dtf_nmf_labels = cls.labels_

X_dtf_nmf_tsne = tsne.fit_transform(dtf_nmf)

# Plot the clustered data
plt.scatter(X_dtf_nmf_tsne[:, 0], X_dtf_nmf_tsne[:, 1], c=dtf_nmf_labels, cmap="tab20")
plt.xlabel("UA 1")
plt.ylabel("UA 2")

In [None]:
tf_nmf = nmf.fit_transform(other_f.T)

In [None]:
tsne = TSNE(n_components=2, perplexity=30)
X_tf_nmf_tsne = tsne.fit_transform(tf_nmf)

In [None]:

cls = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=0.3,
    linkage="ward"
)
cls.fit(tf_nmf)
tf_nmf_labels = cls.labels_



# Plot the clustered data
plt.scatter(X_tf_nmf_tsne[:, 0], X_tf_nmf_tsne[:, 1], c=tf_nmf_labels, cmap="tab20")
plt.xlabel("UA 1")
plt.ylabel("UA 2")

In [None]:
# Features in clusters with at most 3 members are treated as outliers; bool_index[j] marks
# feature j. Select by label value (not by position in `unique`), and do not fail when
# there is no small cluster at all.
unique, counts = np.unique(tf_nmf_labels, return_counts=True)
labels_to_drop = unique[counts <= 3]
bool_index = np.isin(tf_nmf_labels, labels_to_drop)
bool_index.sum()

In [None]:
tf_nmf[~bool_index].shape

In [None]:
# Re-cluster without the outlier features. From here on tf_nmf_labels has one entry per
# non-outlier feature (the rows of tf_nmf[~bool_index]), not one per feature.
cls = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=0.3,
    linkage="ward"
)
cls.fit(tf_nmf[~bool_index])
tf_nmf_labels = cls.labels_



# Plot the clustered data
# Outlier features are drawn in black on the same t-SNE layout.
plt.scatter(X_tf_nmf_tsne[~bool_index, 0], X_tf_nmf_tsne[~bool_index, 1], c=tf_nmf_labels, cmap="tab20")
plt.scatter(X_tf_nmf_tsne[bool_index, 0], X_tf_nmf_tsne[bool_index, 1], c="k", label="outliers")
plt.xlabel("UA 1")
plt.ylabel("UA 2")
plt.legend()

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram


# n_clusters=None with distance_threshold=0 makes scikit-learn build the full merge tree and
# populate `distances_`, which plot_dendrogram needs (labels_ are then all singletons).
cls = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=0.,
    linkage="ward"
)
cls.fit(tf_nmf[~bool_index])

def plot_dendrogram(model, **kwargs):
    """Draw a fitted AgglomerativeClustering tree with scipy's `dendrogram`.

    Builds the scipy linkage matrix from `model.children_` (merged node ids),
    `model.distances_` (merge heights) and, for each merge, the number of original samples
    under the new node. Extra keyword arguments are forwarded to `dendrogram`.
    """
    n_leaves = len(model.labels_)
    merges = model.children_
    subtree_sizes = np.zeros(merges.shape[0])

    # Node ids below n_leaves are samples. Id n_leaves + k is the node created by merge k,
    # whose size has already been filled in by the time a later merge uses it.
    for step, pair in enumerate(merges):
        subtree_sizes[step] = sum(
            1 if node < n_leaves else subtree_sizes[node - n_leaves] for node in pair
        )

    linkage = np.column_stack([merges, model.distances_, subtree_sizes]).astype(float)
    dendrogram(linkage, **kwargs)

# Plot the top four levels of the dendrogram (truncate_mode="level" with p=4).
plt.figure(figsize=(10,6))
plot_dendrogram(cls, truncate_mode="level", p=4)
ax=plt.gca()
# In a truncated dendrogram, collapsed subtrees are labelled "(n)" with their leaf count.
# Keep only those labels and blank out the labels of individual leaves.
labels = [item.get_text() for item in ax.get_xticklabels()]
labels = [l if l.startswith('(') else "" for l in labels]

ax.set_xticklabels(labels)

In [None]:
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

In [None]:
sae.W_dec.shape

In [None]:
list_sims = []
for i in range(2047):
    sims = cos(sae.W_dec[i], sae.W_dec[i+1:])
    list_sims.append(sims)

In [None]:
cos(sae.W_dec[0], sae.W_dec).shape

In [None]:
all_sims = torch.cat(list_sims, dim=0)

In [None]:
all_sims.shape

In [None]:
(2048 * 2047)/2

In [None]:
c_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i], sae.W_dec[i+1:1024])
    c_sims.append(sims)
all_c_sims = torch.cat(c_sims, dim=0)
d_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i+1024], sae.W_dec[i+1025:])
    d_sims.append(sims)
all_d_sims = torch.cat(d_sims, dim=0)
cd_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i], sae.W_dec[i+1025:])
    cd_sims.append(sims)
all_cd_sims = torch.cat(cd_sims, dim=0)

In [None]:
#plt.hist(all_sims, bins=200)
plt.hist(all_c_sims, bins=200, label="c sim", alpha=0.3)
plt.hist(all_d_sims, bins=200, label="d sim", alpha=0.3)
plt.hist(all_cd_sims, bins=200, label="c-d sim", alpha=0.3)
plt.legend()

In [None]:
c_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i,:128], sae.W_dec[i+1:1024,:128])
    c_sims.append(sims)
all_c_sims = torch.cat(c_sims, dim=0)
d_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i+1024,:128], sae.W_dec[i+1025:,:128])
    d_sims.append(sims)
all_d_sims = torch.cat(d_sims, dim=0)
cd_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i,:128], sae.W_dec[i+1025:,:128])
    cd_sims.append(sims)
all_cd_sims = torch.cat(cd_sims, dim=0)

In [None]:
#plt.hist(all_sims, bins=200)
plt.hist(all_c_sims, bins=200, label="c sims", alpha=0.3)
plt.hist(all_d_sims, bins=200, label="d sims", alpha=0.3)
plt.hist(all_cd_sims, bins=200, label="cd sims", alpha=0.3)
plt.legend()

In [None]:
c_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i,128:], sae.W_dec[i+1:1024,128:])
    c_sims.append(sims)
all_c_sims = torch.cat(c_sims, dim=0)
d_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i+1024,128:], sae.W_dec[i+1025:,128:])
    d_sims.append(sims)
all_d_sims = torch.cat(d_sims, dim=0)
cd_sims = []
for i in range(1023):
    sims = cos(sae.W_dec[i,128:], sae.W_dec[i+1025:,128:])
    cd_sims.append(sims)
all_cd_sims = torch.cat(cd_sims, dim=0)

In [None]:
#plt.hist(all_sims, bins=200)
plt.hist(all_c_sims, bins=200, label="c sims", alpha=0.3)
plt.hist(all_d_sims, bins=200, label="d sims", alpha=0.3)
plt.hist(all_cd_sims, bins=200, label="cd sims", alpha=0.3)
plt.legend()

In [None]:
dico = torch.load("./dico.pt")

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
dico.shape

In [None]:
m = 5000

dico_sims = []
for i in range(m-1):
    sims = cos(dico[i], dico[i+1:m])
    dico_sims.append(sims)
all_dico_sims = torch.cat(dico_sims, dim=0)

In [None]:
plt.hist(all_dico_sims, bins=200, label="Regular SAE", alpha=0.3)
plt.legend()

In [None]:
# Number of outlier features (members of clusters with at most 3 features).
bool_index.sum()

In [None]:
tf_nmf_labels.shape

In [None]:
# Sizes of the five largest clusters of the outlier-filtered clustering.
_, counts = np.unique(tf_nmf_labels, return_counts=True)
torch.topk(torch.tensor(counts), k=5)

In [None]:
# Decoder rows (dictionary atoms) of the non-outlier features; row order matches tf_nmf_labels.
valid_D = sae.W_dec[~bool_index]

In [None]:
valid_D.shape

In [None]:
label_a = 5
ai = tf_nmf_labels == label_a
a_sims = []
for i in range(1680):
    if not ai[i]:
        continue
    sims = cos(valid_D[i], valid_D[i+1:])
    a_sims.append(sims)
all_a_sims = torch.cat(a_sims, dim=0)
label_b = 3
bi = tf_nmf_labels == label_b
b_sims = []
for i in range(1680):
    if not bi[i]:
        continue
    sims = cos(valid_D[i], valid_D[i+1:])
    b_sims.append(sims)
all_b_sims = torch.cat(b_sims, dim=0)

ab_sims = []
for i in range(1680):
    if not ai[i]:
        continue
    sims = cos(valid_D[i], valid_D[bi])
    ab_sims.append(sims)
all_ab_sims = torch.cat(ab_sims, dim=0)

In [None]:
plt.hist(all_a_sims, bins=200, label="c1", alpha=0.3)
plt.hist(all_b_sims, bins=200, label="c2", alpha=0.3)
plt.hist(all_ab_sims, bins=200, label="c1-c2", alpha=0.3)
plt.legend()