In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import time
import random
import glob
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.multiprocessing.spawn import spawn
import torchvision.models as models
import torchvision.datasets as dst
from torchvision.io import read_image
from torchvision.transforms import v2

In [2]:
# Image preprocessing pipeline, built once instead of on every call.
# v2.ToTensor is deprecated in torchvision's v2 API; ToImage + ToDtype(scale=True)
# is its documented replacement and yields float32 values in [0, 1].
_TRANSFORMS = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])


def transform_image(img):
    """Convert one dataset image into a float32 tensor scaled to [0, 1].

    Args:
        img: an image as delivered by ``torchvision.datasets.CIFAR10``.

    Returns:
        A channels-first float32 tensor image with values in [0, 1].
    """
    return _TRANSFORMS(img)


# CIFAR-10 test split; download=False, so the data must already be present under ../data.
test_dataset = dst.CIFAR10(
    root="../data",
    train=False,
    download=False,
    transform=transform_image,
)

In [3]:
# Shape of the first test image; `.size` without a call would only show the bound method.
test_dataset[0][0].shape

torch.Size([3, 32, 32])

In [4]:
# Number of images in the CIFAR-10 test split loaded above (train=False)
len(test_dataset)

10000

In [5]:
# Compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# model
class Encoder(nn.Module):
    """Convolutional encoder for 3x32x32 images.

    Three stride-2 convolutions halve the spatial size at each step
    (32 -> 16 -> 8 -> 4) while widening the channels (3 -> 64 -> 128 -> 256).
    A linear layer then projects the flattened 256x4x4 map to the bottleneck.

    Args:
        bottleneck_dim: length of the latent vector ``z``.
    """

    def __init__(self, bottleneck_dim=256):
        super().__init__()
        # Built in a loop, but the Sequential indices (convs at 0, 2, 4) match the
        # hand-written layout, so existing state_dict keys ("conv.0.weight", ...) still load.
        layers = []
        for in_ch, out_ch in ((3, 64), (64, 128), (128, 256)):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1))
            layers.append(nn.ReLU(True))
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(256 * 4 * 4, bottleneck_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch; the layer sizes assume shape (N, 3, 32, 32).

        Returns:
            Tuple ``(z, h)``: the bottleneck vectors of shape (N, bottleneck_dim)
            and the last convolutional feature map of shape (N, 256, 4, 4).
        """
        feature_map = self.conv(x)
        flat = feature_map.view(x.size(0), -1)
        return self.fc(flat), feature_map
    

class Decoder(nn.Module):
    """Mirror of the encoder: maps a bottleneck vector back to a 3x32x32 image.

    A linear layer expands ``z`` into a 256x4x4 map. Three stride-2 transposed
    convolutions then upsample it (4 -> 8 -> 16 -> 32) down to 3 channels, and a
    final sigmoid keeps the reconstruction in [0, 1].

    Args:
        bottleneck_dim: length of the latent vector ``z``.
    """

    def __init__(self, bottleneck_dim=256):
        super().__init__()
        self.fc = nn.Linear(bottleneck_dim, 256 * 4 * 4)
        # Sequential indices (deconvs at 0, 2, 4) match the original layout so
        # saved state_dict keys stay compatible.
        channel_plan = ((256, 128), (128, 64), (64, 3))
        layers = []
        for idx, (in_ch, out_ch) in enumerate(channel_plan):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1))
            is_last = idx == len(channel_plan) - 1
            layers.append(nn.Sigmoid() if is_last else nn.ReLU(True))
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent vectors of shape (N, bottleneck_dim) into (N, 3, 32, 32) images."""
        batch = z.size(0)
        grid = self.fc(z).view(batch, 256, 4, 4)
        return self.deconv(grid)
    

class Autoencoder(nn.Module):
    """Encoder/decoder pair that also exposes the latent code and conv feature map.

    Args:
        bottleneck_dim: latent size shared by the encoder and decoder.
    """

    def __init__(self, bottleneck_dim=256):
        super().__init__()
        self.encoder = Encoder(bottleneck_dim)
        self.decoder = Decoder(bottleneck_dim)

    def forward(self, x):
        """Return ``(reconstruction, z, h)`` for the image batch ``x``."""
        latent, feature_map = self.encoder(x)
        return self.decoder(latent), latent, feature_map
    
# Trained autoencoder checkpoint to evaluate.
# Alternative checkpoint previously used here: "./1_autoencoder.pth".
CHECKPOINT_PATH = "./1_2_denoising_ae.pth"

# The file is consumed as a flat state dict of tensors; weights_only=True refuses
# to unpickle arbitrary Python objects.
full_sd = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)

# Keep only the encoder weights and strip their "encoder." prefix (prefix only,
# not every occurrence). A leading "module." from saving an nn.DataParallel-wrapped
# model is removed first, so such checkpoints do not silently yield an empty dict.
ENCODER_PREFIX = "encoder."
encoder_sd = {}
for key, value in full_sd.items():
    if key.startswith("module."):
        key = key[len("module."):]
    if key.startswith(ENCODER_PREFIX):
        encoder_sd[key[len(ENCODER_PREFIX):]] = value

if not encoder_sd:
    raise RuntimeError(f"No '{ENCODER_PREFIX}*' weights found in {CHECKPOINT_PATH}")

# Build the encoder, load its weights, move it to `device` and wrap it for multi-GPU inference.
encoder = Encoder(bottleneck_dim=256)
encoder.load_state_dict(encoder_sd)
encoder = nn.DataParallel(encoder.to(device))

In [7]:
# Encode one batch of test images and keep the flattened intermediate conv
# feature maps (not the bottleneck vectors) as clustering features.
N_CLUSTER_SAMPLES = 100  # loader is not shuffled, so these are the first N test images

dataloader = DataLoader(test_dataset, batch_size=N_CLUSTER_SAMPLES, shuffle=False)

encoder.eval()
with torch.no_grad():
    imgs, labels = next(iter(dataloader))
    imgs = imgs.to(device)
    # z: bottleneck vectors, h: intermediate feature maps
    z, h = encoder(imgs)
    # h: [N, 256, 4, 4] -> flatten -> [N, 256*4*4]
    feats = h.flatten(start_dim=1).cpu().numpy()

# Plain integer labels for pandas/sklearn rather than a torch tensor
labels = labels.numpy()

In [8]:
# Cluster the flattened encoder feature maps with k-means and compare the clusters
# against the true CIFAR-10 labels via a contingency table.
classes = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
CLUSTER_COUNTS = (5, 10)  # values of k to try

# `labels` may still be a CPU torch tensor; convert explicitly so pandas sees plain integers.
true_label_ids = np.asarray(labels)

for k in CLUSTER_COUNTS:
    # NOTE(review): KMeans' default n_init changed across scikit-learn versions
    # (10 -> "auto" in 1.4), and the warning is hidden by the global filter;
    # pin n_init explicitly if results must be reproducible across versions.
    kmeans = KMeans(n_clusters=k, random_state=0).fit(feats)
    cluster_labels = kmeans.labels_

    # Rows: true label, columns: cluster id, cells: number of images
    df = pd.crosstab(pd.Series(true_label_ids, name='TrueLabel'),
                     pd.Series(cluster_labels, name='Cluster'))
    # Replace integer label ids with class names
    df.index = [classes[i] for i in df.index]

    print(f"--- k = {k} ---")
    print(df)

--- k = 5 ---
Cluster     0   1  2   3  4
airplane    0   3  0   4  3
automobile  1   1  2   1  1
bird        0   5  0   1  2
cat         0   7  2   1  0
deer        0   6  0   1  0
dog         0   5  0   3  0
frog        0  11  4   1  0
horse       0   3  2   3  3
ship        1   3  0   2  7
truck       0   1  0  10  0
--- k = 10 ---
Cluster     0  1  2  3   4  5  6  7  8  9
airplane    0  5  1  0   1  2  1  0  0  0
automobile  1  2  0  1   1  0  0  0  0  1
bird        0  3  1  0   2  0  2  0  0  0
cat         0  2  1  2   5  0  0  0  0  0
deer        0  4  0  0   2  0  1  0  0  0
dog         0  2  0  0   2  0  2  2  0  0
frog        0  2  0  1  11  0  0  2  0  0
horse       0  1  3  1   1  0  0  4  0  1
ship        1  6  6  0   0  0  0  0  0  0
truck       1  4  0  0   0  0  4  1  1  0
