In [None]:
# 训练可学习 RBF kernel 的完整脚本
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.optimize import linear_sum_assignment

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("device:", device)

# ---------- 模型 ----------
class EmbedNet(nn.Module):
    """
    把输入 embedding 映射到新的表示空间。
    你可以把它换成 nn.Linear(in_features, out_features, bias=False) 来做线性变换。
    """
    def __init__(self, dim_in=300, dim_hidden=300, use_nonlinear=True):
        super().__init__()
        if use_nonlinear:
            self.net = nn.Sequential(
                nn.Linear(dim_in, dim_hidden, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(dim_hidden, dim_hidden, bias=False)
            )
        else:
            self.net = nn.Linear(dim_in, dim_hidden, bias=False)
        # learnable log_gamma for positivity
        self.log_gamma = nn.Parameter(torch.tensor(0.0))  # gamma = exp(log_gamma)
    def forward(self, x):
        return self.net(x)

# ---------- 距离 / RBF 计算 ----------
def pairwise_sq_dists_torch(x):
    """Return the (N, N) matrix of squared Euclidean distances between the rows of ``x``.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. It is cheap, but
    floating-point cancellation can leave tiny negative entries, which are
    clamped to zero.

    Args:
        x: (N, d) tensor.
    """
    sq_norms = x.pow(2).sum(dim=1)          # (N,)
    gram = x @ x.t()                        # (N, N) inner products
    dist2 = sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2.0 * gram
    return dist2.clamp(min=0.0)

def rbf_from_Z(Z, log_gamma):
    """Return the (N, N) RBF kernel matrix K[i, j] = exp(-gamma * ||Z_i - Z_j||^2).

    Args:
        Z: (N, d) tensor of row vectors; it does not need to be normalized.
        log_gamma: tensor holding log(gamma); exponentiating it keeps gamma > 0.
    """
    sq_dists = pairwise_sq_dists_torch(Z)
    return (-torch.exp(log_gamma) * sq_dists).exp()

# ---------- 辅助函数（评估 / 可视化 / purity） ----------
def make_D_from_labels_torch(y):
    """Return an (N, N) float tensor that is 1.0 where labels i and j match, else 0.0.

    ``y`` is flattened first, so N is its total number of elements.
    """
    flat = y.view(-1)
    return flat.unsqueeze(1).eq(flat.unsqueeze(0)).float()

def purity_score(y_true, y_pred):
    """Return the clustering purity of ``y_pred`` with respect to ``y_true``.

    Each predicted cluster is credited with the true class it overlaps most.
    Purity is the fraction of samples that fall in their cluster's majority class:

        purity = (1 / N) * sum_k max_j |cluster_k ∩ class_j|

    Several clusters may map to the same class. This differs from
    Hungarian-matched "clustering accuracy", which forces a one-to-one
    cluster/class assignment and therefore under-reports when there are more
    clusters than classes.

    Args:
        y_true: ground-truth labels (array-like of comparable values).
        y_pred: predicted cluster ids, same length as ``y_true``.

    Returns:
        Purity in [0, 1] as a float, or ``nan`` for empty input.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {y_true.size} and {y_pred.size}"
        )
    if y_true.size == 0:
        return float("nan")

    # Map arbitrary label values onto dense indices 0..K-1.
    _, true_idx = np.unique(y_true, return_inverse=True)
    _, pred_idx = np.unique(y_pred, return_inverse=True)
    true_idx = true_idx.ravel()
    pred_idx = pred_idx.ravel()

    # Contingency table: rows = predicted clusters, columns = true classes.
    contingency = np.zeros((pred_idx.max() + 1, true_idx.max() + 1), dtype=np.int64)
    np.add.at(contingency, (pred_idx, true_idx), 1)

    # Majority-class count per predicted cluster, summed over clusters.
    return float(contingency.max(axis=1).sum() / y_true.size)

def plot_heatmap(K, words=None, title=None, cmap='magma', figsize=(6,6), vmin=None, vmax=None):
    """Display ``K`` as a heatmap with a colour bar in a new figure.

    Args:
        K: 2-D array-like, passed straight to ``plt.imshow``.
        words: optional tick labels used on both axes (x labels rotated 90 degrees).
        title: optional figure title; skipped when falsy.
        cmap, vmin, vmax: forwarded to ``plt.imshow``.
        figsize: figure size in inches.
    """
    plt.figure(figsize=figsize)
    plt.imshow(K, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(fraction=0.046, pad=0.04)

    if words is not None:
        ticks = np.arange(len(words))
        label_font = {'fontsize': 8}
        plt.yticks(ticks, words, **label_font)
        plt.xticks(ticks, words, rotation=90, **label_font)

    if title:
        plt.title(title)

    plt.tight_layout()
    plt.show()

# ---------- 训练配置 ----------
epochs = 100
embed_dim = 300
use_nonlinear = True    # set to False to use a single linear projection instead of the MLP

kernel = EmbedNet(
    dim_in=embed_dim,
    dim_hidden=embed_dim,
    use_nonlinear=use_nonlinear,
).to(device)

# kernel.parameters() includes log_gamma, so Adam's L2-style weight_decay
# also pulls log_gamma toward 0 (i.e. gamma toward 1).
optimizer = torch.optim.Adam(kernel.parameters(), lr=1e-3, weight_decay=1e-5)

# optional: helper to compute avg pos/neg from K and labels (numpy)
def avg_pos_neg_from_K(K_np, y_np, mask_diag=True):
    """Return the mean kernel value over same-label pairs and over different-label pairs.

    Args:
        K_np: (N, N) numpy array of kernel values.
        y_np: (N,) numpy array of labels.
        mask_diag: if True, the diagonal (i == j) pairs are excluded.

    Returns:
        (pos_mean, neg_mean). Either is ``np.nan`` when there are no pairs of that kind.
    """
    n = K_np.shape[0]
    same_label = y_np[:, None] == y_np[None, :]
    keep = ~np.eye(n, dtype=bool) if mask_diag else np.ones((n, n), dtype=bool)

    pos_vals = K_np[same_label & keep]
    neg_vals = K_np[~same_label & keep]

    pos_mean = pos_vals.mean() if pos_vals.size else np.nan
    neg_mean = neg_vals.mean() if neg_vals.size else np.nan
    return pos_mean, neg_mean

# ---------- 训练 loop ----------
loss_history = []
pos_history = []
neg_history = []

for epoch in range(epochs):
    kernel.train()
    total_loss = 0.0
    total_pos = 0.0
    total_neg = 0.0
    n_batches = 0      # batches seen this epoch (divisor for the mean loss)
    n_pos_batches = 0  # batches that had at least one same-label off-diagonal pair
    n_neg_batches = 0  # batches that had at least one different-label pair

    for X_train, y_train, words_train in data_train:
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        Z = kernel(X_train)
        # Embeddings are left unnormalized on purpose; the RBF kernel does not need it.
        # Z = F.normalize(Z, p=2, dim=1)
        K = rbf_from_Z(Z, kernel.log_gamma)

        D = make_D_from_labels_torch(y_train).to(Z.device)
        # Exclude the diagonal: K[i, i] = exp(-gamma * 0) ~ 1 and D[i, i] = 1 always,
        # so those entries carry no training signal.
        N = K.size(0)
        mask = ~torch.eye(N, dtype=torch.bool, device=K.device)

        # K = exp(-gamma * d^2) lies in [0, 1], so it is used directly as the
        # predicted probability in BCE against the same-label indicator D.
        loss = F.binary_cross_entropy(K[mask], D[mask])

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(kernel.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

        # Monitoring only: mean kernel value over positive / negative pairs.
        # Either mean is NaN when a batch has no such pairs (e.g. all labels equal),
        # so each one is accumulated with its own counter. Otherwise a single
        # NaN would poison the epoch average, or skipped batches would dilute it.
        with torch.no_grad():
            K_np = K.detach().cpu().numpy()
            y_np = y_train.detach().cpu().numpy()
            pos_mean, neg_mean = avg_pos_neg_from_K(K_np, y_np, mask_diag=True)
        if not np.isnan(pos_mean):
            total_pos += pos_mean
            n_pos_batches += 1
        if not np.isnan(neg_mean):
            total_neg += neg_mean
            n_neg_batches += 1

    avg_loss = total_loss / max(1, n_batches)
    avg_pos = total_pos / n_pos_batches if n_pos_batches else float('nan')
    avg_neg = total_neg / n_neg_batches if n_neg_batches else float('nan')
    loss_history.append(avg_loss)
    pos_history.append(avg_pos)
    neg_history.append(avg_neg)

    if epoch % 5 == 0 or epoch == epochs - 1:
        print(f"epoch {epoch:03d} loss={avg_loss:.4f} pos_sim={avg_pos:.4f} neg_sim={avg_neg:.4f} gamma={torch.exp(kernel.log_gamma).item():.4e}")

# ---------- 绘制 loss / pos-neg 曲线 ----------
# Two separate figures: the per-epoch mean loss, then the per-epoch mean
# positive / negative pair similarities.
for _fig_title, _curves in (
    ("training loss", [('loss', loss_history)]),
    ("pos / neg mean similarity", [('pos_sim', pos_history), ('neg_sim', neg_history)]),
):
    plt.figure(figsize=(6, 3))
    for _label, _series in _curves:
        plt.plot(_series, label=_label)
    plt.title(_fig_title)
    plt.legend()
    plt.show()

# ---------- 在 test 上可视化（第一个样本）和做聚类评估 ----------
kernel.eval()
X0, y0, words0 = data_test[0]
with torch.inference_mode():
    Z0 = kernel(X0.to(device))                              # (N, d)
    K0 = rbf_from_Z(Z0, kernel.log_gamma).cpu().numpy()     # (N, N)
    # Convert while still in inference mode; .cpu() also covers tensors living on the GPU.
    Z0_np = Z0.cpu().numpy()
    y0_np = y0.cpu().numpy()

# Raw heatmap.
plot_heatmap(K0, words=words0, title="test K_rbf (raw)")

n_clusters = len(np.unique(y0_np))

# KMeans on the rows of K (a spectral-like clustering). It is fitted once and
# used both for the block-ordered heatmap and for the metrics below.
# n_init is pinned because scikit-learn 1.4 changed its default from 10 to 'auto';
# leaving it implicit makes the results depend on the installed version.
km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(K0)
km_K = km

# Reorder rows/columns by cluster id for a clearer block structure.
order = np.argsort(km.labels_, kind='stable')
K_sorted = K0[np.ix_(order, order)]
words_sorted = [words0[i] for i in order]
plot_heatmap(K_sorted, words=words_sorted, title="test K_rbf (sorted by kmeans on rows)")

# Clustering directly on the learned embeddings.
km_emb = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(Z0_np)
pred_emb = km_emb.labels_
ari_emb = adjusted_rand_score(y0_np, pred_emb)
nmi_emb = normalized_mutual_info_score(y0_np, pred_emb)
pur_emb = purity_score(y0_np, pred_emb)

# Clustering on the rows of K (reuses the fit above).
pred_K = km_K.labels_
ari_K = adjusted_rand_score(y0_np, pred_K)
nmi_K = normalized_mutual_info_score(y0_np, pred_K)
pur_K = purity_score(y0_np, pred_K)

print("Embedding KMeans -> ARI, NMI, Purity:", ari_emb, nmi_emb, pur_emb)
print("K-rows KMeans -> ARI, NMI, Purity:", ari_K, nmi_K, pur_K)
print("learned gamma:", float(torch.exp(kernel.log_gamma).item()))
