1. 载入数据

In [None]:
import numpy as np
import pandas as pd

IMAGE_DATA_FILE = "../data/caltech-256_features.npz"
CLASS_NAME_FILE = "../data/256_ObjectCategories_map_ZH.csv"

# Load the pre-extracted features and labels.
# The NpzFile is used as a context manager so its file handle is closed once
# the arrays have been read; indexing with [] (instead of .get) raises a
# KeyError naming any missing array instead of failing later on `None.shape`.
with np.load(IMAGE_DATA_FILE) as image_data:
    # print(f"Components of image_data: {list(image_data.keys())}")

    X_vit = image_data["vit_features"]
    print(f"The shape of X_vit is {X_vit.shape}")

    X_clip = image_data["clip_features"]
    print(f"The shape of X_clip is {X_clip.shape}")

    y = image_data["labels"]
    print(f"The shape of y is {y.shape}")

# Lookup from the English category name (CSV column "class") to the display
# name (column "handle"); the plotting helpers use it to build localized
# (presumably Chinese, per the _ZH file name) tick labels.
class_name_df = pd.read_csv(CLASS_NAME_FILE)
class_name_map = class_name_df.set_index("class")["handle"].to_dict()

2. 数据标准化：分别针对 ViT 和 CLIP 两个模型的输出特征，在待分析（聚类或可视化）的样本集合上，尝试多种强度（即对原始特征的改变程度）不同的标准化手段：
    - 无标准化。
    - 按样本点 L2 范数的均值进行单位化，即仅整体缩放两个模型的输出，使全部待分析样本点分布在单位球面附近。
    - 对各特征分别做标准化（Standard Scaler），使每个特征维度的尺度一致（单位方差），全部待分析样本点近似呈各维尺度相同的"立方体"状分布。可用 sklearn.preprocessing.StandardScaler（先中心化，再除以标准差）或 scipy.cluster.vq.whiten（仅除以标准差，不做中心化）。
    - PCA/ZCA 白化（Whitening）。

In [None]:
from standardization import get_standard_data

# Standardization strategy applied to both models' features; keep exactly one
# of the assignments below active.
STANDARD_METHOD = None
# STANDARD_METHOD = "l2_norm"
# STANDARD_METHOD = "feature_standard"
# STANDARD_METHOD = "PCA_whiten"
# STANDARD_METHOD = "ZCA_whiten"

# NOTE(review): get_standard_data receives the full arrays here, clutter
# samples included (they are only filtered out in the next cell). If the
# chosen method computes data-dependent statistics (feature scaling,
# whitening), those statistics include the clutter samples, whereas the
# markdown above asks for standardization on the analysed sample set —
# consider running this after the clutter filter.
X_vit, X_clip = (get_standard_data(features, STANDARD_METHOD)
                 for features in (X_vit, X_clip))

3. 获取各类别的表示：
    - 从数据中剔除"其它"（clutter）类别的样本；
    - 以各类别全部样本点的重心（均值向量）作为该类别的表示。

In [None]:
# np.unique returns the distinct labels sorted; the last one is the
# "other"/clutter category, which is excluded from all later analysis.
labels = np.unique(y)
clutter_label = labels[-1]
print(f"The last category is \"{clutter_label}\".")
not_clutter_index = np.where(y != clutter_label)
X_vit = X_vit[not_clutter_index]
X_clip = X_clip[not_clutter_index]
y = y[not_clutter_index]
print(f"Samples of Category \"{clutter_label}\" are removed.")

# Represent each remaining class by the centroid (mean vector) of its
# samples; row i of X_*_class corresponds to y_class[i].
y_class = labels[:-1].copy()
X_vit_class = np.array([X_vit[y == label].mean(axis=0) for label in y_class])
X_clip_class = np.array([X_clip[y == label].mean(axis=0) for label in y_class])

4. 延续之前的做法，绘制各类别重心之间距离矩阵的聚类热图，便于观察整体结构。

In [None]:
import seaborn as sns
from seaborn import clustermap
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage

# Distance metric shared by the heat-map and dendrogram cells below (passed
# to scipy's pdist); "euclidean" is the commented-out alternative.
# NOTE(review): both plotting helpers use Ward linkage by default, which
# scipy only defines correctly for Euclidean distances.
METRIC = "cityblock"  # alternative: "euclidean"

# Directory (with trailing slash) that figure file names are appended to.
FIGURE_PATH = "../results/cluster/"


def draw_clustermap(X, y, metric, method="ward"):
    """Draw a clustered heat map of the pairwise distances between rows of X.

    Parameters
    ----------
    X : array-like
        One feature vector per row (in this notebook: the class centroids).
    y : iterable of str
        Labels of the form "<code>.<english name>", one per row of X; the
        English part is translated through the module-level
        ``class_name_map`` to build the row tick labels.
    metric : str
        Distance metric passed to ``scipy.spatial.distance.pdist``.
    method : str, optional
        Linkage method passed to ``scipy.cluster.hierarchy.linkage``
        (default "ward", the previously hard-coded value).

    Returns
    -------
    The grid object returned by ``seaborn.clustermap``.

    Warns
    -----
    UserWarning
        If a linkage method that scipy only defines for Euclidean distances
        ("ward", "centroid", "median") is combined with another metric.
    """
    import warnings

    if method in ("ward", "centroid", "median") and metric != "euclidean":
        # scipy still builds a tree in this case, but it is not a true
        # Ward/centroid/median hierarchy of the given distances.
        warnings.warn(
            f"linkage method {method!r} assumes Euclidean distances, but "
            f"metric={metric!r} was given; the hierarchy may be misleading.",
            stacklevel=2,
        )

    y_code_zh = []
    for code_en in y:
        # maxsplit=1 keeps any further dots inside the English name.
        code, en = code_en.split(".", 1)
        y_code_zh.append(code + "." + class_name_map[en])

    X_pdist = pdist(X, metric)
    X_linkage = linkage(X_pdist,
                        method=method,
                        optimal_ordering=True)
    # SimHei (a Chinese font) so the translated tick labels can render.
    sns.set(font="SimHei")
    # The same linkage orders rows and columns, so the symmetric distance
    # matrix stays symmetric after reordering.
    X_clustermap = clustermap(squareform(X_pdist),
                              row_linkage=X_linkage,
                              col_linkage=X_linkage,
                              figsize=(70, 55),
                              dendrogram_ratio=(0.3, 1/11),
                              xticklabels=False,
                              yticklabels=y_code_zh,
                              cbar_pos=None)
    return X_clustermap

In [None]:

# Clustered distance heat map of the ViT class centroids, saved as a PDF whose
# name records the standardization method and metric.
FIGURE_FILE = f"clustermap_vit_&_{STANDARD_METHOD}_&_{METRIC}"

X_vit_clustermap = draw_clustermap(X_vit_class, y_class, metric=METRIC)

# Anchor the title's bottom edge at the top of the figure, centred, with a
# font size matching the very large canvas.
X_vit_clustermap.fig.suptitle(FIGURE_FILE, x=0.5, y=1,
                              va="bottom", fontsize=60)

X_vit_clustermap.savefig(f"{FIGURE_PATH}{FIGURE_FILE}.pdf", format="pdf")

In [None]:

# Clustered distance heat map of the CLIP class centroids, saved as a PDF
# whose name records the standardization method and metric.
FIGURE_FILE = f"clustermap_clip_&_{STANDARD_METHOD}_&_{METRIC}"

X_clip_clustermap = draw_clustermap(X_clip_class, y_class, metric=METRIC)

# Anchor the title's bottom edge at the top of the figure, centred, with a
# font size matching the very large canvas.
X_clip_clustermap.fig.suptitle(FIGURE_FILE, x=0.5, y=1,
                               va="bottom", fontsize=60)

X_clip_clustermap.savefig(f"{FIGURE_PATH}{FIGURE_FILE}.pdf", format="pdf")

5. 绘制聚类树（树状图）

In [None]:
from scipy.cluster.hierarchy import dendrogram, set_link_color_palette
from matplotlib import colormaps, colors
import matplotlib.pyplot as plt


def draw_dendrogram(X, y, metric, color_threshold, method="ward",
                    truncate_mode=None, p=5):
    """Draw a left-oriented dendrogram of the rows of X on a new figure.

    Parameters
    ----------
    X : array-like
        One feature vector per row (in this notebook: the class centroids).
    y : iterable of str
        Labels of the form "<code>.<english name>", one per row of X; the
        English part is translated through the module-level
        ``class_name_map`` to build the leaf labels.
    metric : str
        Distance metric passed to ``scipy.spatial.distance.pdist``.
    color_threshold : float
        Forwarded to ``scipy.cluster.hierarchy.dendrogram``; clusters below
        this height are coloured from the "Dark2" palette, links above it
        are drawn in "silver".
    method : str, optional
        Linkage method passed to ``scipy.cluster.hierarchy.linkage``
        (default "ward", the previously hard-coded value).
    truncate_mode : {None, "lastp", "level"}, optional
        Forwarded to ``dendrogram``; None (the default, as before) draws the
        full tree.
    p : int, optional
        Truncation parameter forwarded to ``dendrogram``; only has an effect
        when ``truncate_mode`` is not None (default 5, as before).

    Returns
    -------
    dict
        The dictionary returned by ``scipy.cluster.hierarchy.dendrogram``.

    Warns
    -----
    UserWarning
        If a linkage method that scipy only defines for Euclidean distances
        ("ward", "centroid", "median") is combined with another metric.
    """
    import warnings

    if method in ("ward", "centroid", "median") and metric != "euclidean":
        # scipy still builds a tree in this case, but it is not a true
        # Ward/centroid/median hierarchy of the given distances.
        warnings.warn(
            f"linkage method {method!r} assumes Euclidean distances, but "
            f"metric={metric!r} was given; the hierarchy may be misleading.",
            stacklevel=2,
        )

    y_code_zh = []
    for code_en in y:
        # maxsplit=1 keeps any further dots inside the English name.
        code, en = code_en.split(".", 1)
        y_code_zh.append(code + "." + class_name_map[en])

    # SimHei (a Chinese font) so the translated leaf labels can render.
    sns.set(font="SimHei")
    plt.figure(figsize=(15, 30))

    X_pdist = pdist(X, metric)
    X_linkage = linkage(X_pdist,
                        method=method,
                        optimal_ordering=True)

    # 8 evenly spaced samples of the "Dark2" colormap as the link palette.
    cmap = colormaps.get_cmap("Dark2")
    cmap_colors = cmap(np.linspace(0, 1, 8))
    set_link_color_palette([colors.rgb2hex(rgb[:3]) for rgb in cmap_colors])

    try:
        X_dendrogram = dendrogram(X_linkage,
                                  p=p,
                                  truncate_mode=truncate_mode,
                                  color_threshold=color_threshold,
                                  above_threshold_color="silver",
                                  orientation="left",
                                  labels=y_code_zh,
                                  show_leaf_counts=True,
                                  leaf_rotation=0,
                                  show_contracted=True)
    finally:
        # The link colour palette is global scipy state: restore the default
        # even if plotting raises, so later dendrogram calls are unaffected.
        set_link_color_palette(None)

    return X_dendrogram


# Dendrogram of the CLIP class centroids; the colour threshold is presumably
# tuned to the scale of the CLIP centroid distances.
FIGURE_FILE = f"dendrogram_clip_&_{STANDARD_METHOD}_preprocess_&_{METRIC}"
X_clip_dendrogram = draw_dendrogram(X_clip_class, y_class, METRIC,
                                    color_threshold=120)
plt.title(FIGURE_FILE)
# Saving is currently disabled for this figure; uncomment to write the PDF.
# plt.savefig(f"{FIGURE_PATH}{FIGURE_FILE}.pdf", format="pdf")

In [None]:
# Dendrogram of the ViT class centroids, saved as a PDF; the colour threshold
# is presumably tuned to the scale of the ViT centroid distances.
FIGURE_FILE = f"dendrogram_vit_&_{STANDARD_METHOD}_preprocess_&_{METRIC}"
X_vit_dendrogram = draw_dendrogram(X_vit_class, y_class, METRIC,
                                   color_threshold=0.5)
plt.title(FIGURE_FILE)
plt.savefig(f"{FIGURE_PATH}{FIGURE_FILE}.pdf", format="pdf")