In [1]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torchvision import datasets, transforms,models
from torch.utils.data import DataLoader

import pcl.builder  # <-- 載入你的 DCBCL 架構


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Models wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` save every
    parameter as ``module.<name>``; loading such keys into an unwrapped model with
    ``strict=False`` would silently skip all of them. Keys without the prefix are
    kept unchanged, so this is a no-op for checkpoints of unwrapped models.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _extract_features(encoder, loader, device):
    """Run ``encoder`` over every batch of ``loader`` with gradients disabled.

    Returns:
        (features, labels): NumPy arrays concatenated along the batch axis, in the
        order the loader yields samples.
    """
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for images, targets in loader:
            outputs = encoder(images.to(device))
            feature_batches.append(outputs.cpu().numpy())
            label_batches.append(targets.numpy())
    return np.concatenate(feature_batches), np.concatenate(label_batches)


def _plot_tsne(embedding, labels, class_names, title, save_path):
    """Save a scatter plot of a 2-D ``embedding`` coloured by integer ``labels``.

    Point colours and legend swatches both come from indexing the ``tab10``
    colormap directly with the integer class id, so they always agree (ids wrap
    modulo the colormap's colour count if there are more classes than colours).
    """
    cmap = plt.cm.tab10
    fig, ax = plt.subplots(figsize=(16, 12))
    ax.scatter(embedding[:, 0], embedding[:, 1], c=cmap(labels % cmap.N), s=10)
    # Empty scatters exist only to create one (larger) legend entry per class.
    for class_idx, class_name in enumerate(class_names):
        ax.scatter([], [], color=cmap(class_idx % cmap.N), label=class_name, s=100)
    ax.tick_params(axis='both', which='major', labelsize=20)
    ax.legend(
        title="CIFAR-10",
        title_fontsize=18,
        fontsize=16,
        markerscale=2,
        bbox_to_anchor=(1.05, 1),
        loc="upper left"
    )
    ax.set_title(title, fontsize=25, fontweight='bold', pad=30)
    fig.tight_layout()
    # The legend is anchored outside the axes; bbox_inches="tight" keeps it in the file.
    fig.savefig(save_path, bbox_inches="tight")
    # Close this specific figure so looping over perplexities does not accumulate figures.
    plt.close(fig)


def run_tsne(
    checkpoint_path,
    dataset_root,
    output_dir="tsne_results",
    model_name="DC-LCL",
    batch_size=256,
    perplexities=(5, 10, 15, 30, 50, 100, 200),
    max_iter=1000,
    num_workers=4,
    random_state=42,
    arch="resnet50",
):
    """Embed the CIFAR-10 test split with a pretrained encoder and save t-SNE plots.

    Builds a ``pcl.builder.MoCo`` model, loads ``checkpoint['state_dict']`` from
    ``checkpoint_path``, runs ``model.encoder_q`` over the CIFAR-10 test set and
    writes:

    * ``<output_dir>/<model_name>_features.npz`` with arrays ``features`` and ``labels``;
    * ``<output_dir>/tsne_<model_name>_perp_<perp>.png``, one per perplexity.

    Args:
        checkpoint_path: checkpoint file containing a ``'state_dict'`` entry.
        dataset_root: torchvision CIFAR-10 root (downloaded there if missing).
        output_dir: directory for the .npz and .png outputs (created if missing).
        model_name: label used in output file names and plot titles.
        batch_size: DataLoader batch size used for feature extraction.
        perplexities: iterable of t-SNE perplexities; one plot per value.
        max_iter: maximum number of t-SNE optimisation iterations.
        num_workers: number of DataLoader worker processes.
        random_state: seed passed to ``TSNE`` so embeddings are reproducible.
        arch: name of the ``torchvision.models`` constructor used as the base
            encoder; must match the architecture the checkpoint was trained with.
    """
    os.makedirs(output_dir, exist_ok=True)

    # CIFAR-10 test split. Resize/CenterCrop(32) leave native 32x32 images unchanged.
    # NOTE(review): ImageNet mean/std are used — presumably matching the
    # pretraining transform; confirm against the training script.
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    val_dataset = datasets.CIFAR10(root=dataset_root, train=False, download=True, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    class_names = val_dataset.classes

    # Build the DCBCL (MoCo-style) model; arch, dim and mlp must match the checkpoint.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = pcl.builder.MoCo(
        models.__dict__[arch],
        dim=128,
        r=4096,
        m=0.999,
        T=0.2,
        mlp=True
    )
    model.to(device)

    # Load the pretrained DCBCL weights.
    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = _strip_module_prefix(checkpoint['state_dict'])
    # strict=False tolerates key mismatches elsewhere in the model, but encoder_q is
    # what gets embedded below, so report any of its parameters that were not loaded.
    load_result = model.load_state_dict(state_dict, strict=False)
    missing_encoder_keys = [k for k in load_result.missing_keys if k.startswith("encoder_q.")]
    if missing_encoder_keys:
        print(f"[!] {len(missing_encoder_keys)} encoder_q parameters were not found in "
              f"{checkpoint_path} and keep their freshly constructed values "
              f"(e.g. {missing_encoder_keys[0]}).")
    model.eval()

    # NOTE(review): with mlp=True, encoder_q presumably ends in the projection head,
    # so these are dim=128 embeddings rather than backbone features — confirm in pcl.builder.
    features, labels = _extract_features(model.encoder_q, val_loader, device)

    # Cache the features so plots can be regenerated without re-running the encoder.
    np.savez(os.path.join(output_dir, f"{model_name}_features.npz"), features=features, labels=labels)

    for perp in perplexities:
        tsne = TSNE(n_components=2, perplexity=perp, max_iter=max_iter, random_state=random_state)
        tsne_result = tsne.fit_transform(features)

        save_path = os.path.join(output_dir, f"tsne_{model_name}_perp_{perp}.png")
        _plot_tsne(
            tsne_result,
            labels,
            class_names,
            title=f't-SNE of {model_name} Representations (Perplexity={perp}, n_iter={max_iter})',
            save_path=save_path,
        )
        print(f"[\u2714] Saved {save_path}")
        



if __name__ == "__main__":
    # Build the checkpoint path with os.path.join so it resolves on Windows and POSIX
    # alike; set DATASET_ROOT to override the machine-specific default dataset location.
    run_tsne(
        checkpoint_path=os.path.join("save", "dcbcl", "cifar10", "train", "20250409", "model_best.pth.tar"),
        dataset_root=os.environ.get("DATASET_ROOT", r"D:\Document\Project\Dataset"),
        output_dir="tsne_cifar10_DC-LCL_pretrained",
        model_name="DC-LCL"
    )


Files already downloaded and verified


  checkpoint = torch.load(checkpoint_path, map_location=device)


[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_5.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_10.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_15.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_30.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_50.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_100.png
[✔] Saved tsne_cifar10_DC-LCL_pretrained\tsne_DC-LCL_perp_200.png


<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>