In [13]:
!nvidia-smi

Wed Aug 24 01:19:28 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P0    26W /  70W |  15098MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!git clone https://github.com/TTN-YKK/Clustering_friendly_representation_learning.git
%cd Clustering_friendly_representation_learning
!pip install -r requirements.txt
#!python main.py --gpus 0

fatal: destination path 'Clustering_friendly_representation_learning' already exists and is not an empty directory.
/content/Clustering_friendly_representation_learning
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [28]:
!ls

LICENSE  main.py  README.md  requirements.txt


In [26]:
#! /usr/bin/env python

import os
import time
import argparse

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment

import torch
import torch.nn as nn
from torch.autograd import Function
from torchvision import datasets, transforms
from torchvision.models import resnet

In [5]:
def parse(argv=None):
    """Parse command-line options and expose the selected GPUs to CUDA.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``. Pass an explicit list (e.g. ``[]``)
            when calling from a notebook: there ``sys.argv`` typically holds the
            kernel's own flags, which argparse would reject and exit on.

    Returns:
        argparse.Namespace with ``gpus`` (str, default "") and
        ``num_workers`` (int, default 8).

    Side effects:
        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable to ``args.gpus``.
        NOTE: CUDA generally only honours this if it is set before CUDA is
        initialised in the process.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--gpus", type=str, default="")
    parser.add_argument("-n", "--num_workers", type=int, default=8)
    args = parser.parse_args(argv)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    return args


class CIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 whose items also carry the requested index.

    Each item is ``(img, target, index)`` rather than ``(img, target)``;
    ``index`` is the same value that was passed to ``__getitem__``.
    """

    def __getitem__(self, index):
        sample, label = super(CIFAR10, self).__getitem__(index)
        # Append the dataset index as a third element.
        return (sample, label, index)


class metrics:
    """Clustering evaluation metrics used by ``calc_clustering_metrics``.

    ``ari`` and ``nmi`` are scikit-learn's ``adjusted_rand_score`` and
    ``normalized_mutual_info_score``; ``acc`` is clustering accuracy under the
    best one-to-one mapping from clusters to labels.
    """

    # staticmethod keeps these plain functions even when reached through an
    # instance (``metrics().nmi(...)``), not only through the class.
    ari = staticmethod(adjusted_rand_score)
    nmi = staticmethod(normalized_mutual_info_score)

    @staticmethod
    def acc(y_true, y_pred):
        """Fraction of samples matched by the best cluster -> label assignment.

        Builds the (cluster, label) contingency matrix, solves the linear
        assignment problem on it with ``scipy.optimize.linear_sum_assignment``
        to maximise agreements, and divides the matched count by the number
        of samples.

        Args:
            y_true: ground-truth labels; array-like of non-negative ints.
            y_pred: predicted cluster ids; array-like of non-negative ints with
                the same number of elements as ``y_true``.

        Returns:
            Accuracy as a float in [0, 1].

        Raises:
            AssertionError: if ``y_true`` and ``y_pred`` differ in size.
        """
        y_true = np.asarray(y_true).astype(np.int64)
        y_pred = np.asarray(y_pred).astype(np.int64)
        assert y_pred.size == y_true.size
        D = max(y_pred.max(), y_true.max()) + 1
        w = np.zeros((D, D), dtype=np.int64)
        # Vectorised contingency count: w[p, t] += 1 per (pred, true) pair.
        # np.add.at accumulates repeated index pairs; plain fancy "+=" would not.
        np.add.at(w, (y_pred, y_true), 1)
        # linear_sum_assignment minimises cost, so turn counts into costs.
        row, col = linear_sum_assignment(w.max() - w)
        return w[row, col].sum() * 1.0 / y_pred.size


def calc_clustering_metrics(features, targets, n_init=20, random_state=None):
    """Cluster ``features`` with k-means and score the clusters against ``targets``.

    The number of clusters equals the number of distinct values in ``targets``.

    Args:
        features: torch.Tensor of embeddings, one row per sample (assumed 2-D
            ``(n_samples, dim)``, which KMeans requires). It is detached and
            copied to the CPU, so tensors on any device are accepted.
        targets: ground-truth labels, array-like with one entry per row.
        n_init: number of k-means restarts (default 20).
        random_state: seed forwarded to ``KMeans``. With the default ``None``
            k-means is unseeded, so repeated calls on identical features can
            return different scores; pass an int to make runs comparable.

    Returns:
        tuple ``(acc, nmi, ari)``.
    """
    z = features.detach().cpu().numpy()
    y = np.array(targets)
    n_clusters = len(np.unique(y))
    kmeans = KMeans(n_clusters=n_clusters, n_init=n_init,
                    random_state=random_state)
    y_pred = kmeans.fit_predict(z)
    return metrics.acc(y, y_pred), metrics.nmi(y, y_pred), metrics.ari(y, y_pred)


class Normalize(nn.Module):
    """Divide the input by its L_p norm taken along dimension 1.

    Args:
        power: the ``p`` of the L_p norm (default 2).
        eps: lower bound applied to the norm before dividing, so an all-zero
            row maps to zeros instead of NaN/inf (default 1e-12).
    """

    def __init__(self, power=2, eps=1e-12):
        super().__init__()
        self.power = power
        self.eps = eps

    def forward(self, x):
        # abs() makes this a true L_p norm for odd p; for even p it leaves the
        # result unchanged, since even powers are already sign-independent.
        norm = x.abs().pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        return x.div(norm.clamp_min(self.eps))


def ResNet18(low_dim=128):
    """torchvision ResNet-18 with a small-image stem and a ``low_dim`` output.

    Builds ``resnet.ResNet`` from ``BasicBlock`` with layers [2, 2, 2, 2],
    using ``low_dim`` as the output size of its final layer. It then replaces
    the stem convolution (``conv1``) with a 3x3, stride-1, padding-1 conv
    without bias, and turns the stem max-pool into an identity. Presumably
    this suits 32x32 CIFAR inputs by not shrinking them early.
    """
    model = resnet.ResNet(block=resnet.BasicBlock, layers=[2, 2, 2, 2],
                          num_classes=low_dim)
    stem_conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,
                          stride=1, padding=1, bias=False)
    model.conv1 = stem_conv
    model.maxpool = nn.Identity()
    return model

In [6]:
class Arg:
    """Notebook stand-in for the namespace returned by ``parse()``.

    Exposes the same ``gpus`` and ``num_workers`` attributes without going
    through the command line.
    """

    def __init__(self, gpus="", num_workers=8):
        # parse() declares --gpus with type=str, so coerce here too (0 -> "0").
        self.gpus = "%s" % (gpus,)
        self.num_workers = num_workers

In [39]:
def _build_test_loader(batch_size, num_workers, pin_memory):
    # Ordered (unshuffled) DataLoader over the normalised CIFAR-10 test split.
    # Convert to a tensor, then normalise with the per-channel mean/std below.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                             std=[0.2470, 0.2435, 0.2616]),
    ])
    test_set = CIFAR10(root="~/.datasets",
                       train=False,
                       download=True,
                       transform=transform)
    return torch.utils.data.DataLoader(test_set,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       num_workers=num_workers)


def _load_model(model_path, device, low_dim=128):
    # Build ResNet18 + L2 Normalize on ``device`` and load the checkpoint into the net.
    net = ResNet18(low_dim=low_dim)
    norm = Normalize(2)
    net, norm = net.to(device), norm.to(device)
    # map_location lets a GPU-saved checkpoint load on the CPU fallback device.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(model_path, map_location=device))
    return net, norm


def _extract_features(net, norm, loader, device):
    # Embed every batch in eval mode without autograd; return all embeddings concatenated on CPU.
    net.eval()
    chunks = []
    with torch.no_grad():
        for inputs, _, _ in loader:
            inputs = inputs.to(device, dtype=torch.float32, non_blocking=True)
            chunks.append(norm(net(inputs)).cpu())
    return torch.cat(chunks, dim=0)


def inference(model_path, batch_size, arg):
    """Evaluate a checkpoint on the CIFAR-10 test split by k-means clustering.

    Loads ``ResNet18(low_dim=128)`` weights from ``model_path`` and embeds every
    test image, normalising each embedding with ``Normalize(2)``. It then
    clusters the embeddings with k-means and prints the batch size,
    ACC / NMI / ARI, and the elapsed time.

    Args:
        model_path: path to a ``state_dict`` checkpoint for ``ResNet18(low_dim=128)``.
        batch_size: DataLoader batch size used for feature extraction.
        arg: options object exposing ``num_workers`` (e.g. ``Arg`` or the
            namespace from ``parse()``). ``arg.gpus`` is not read here; the
            device is ``cuda`` when available, otherwise ``cpu``.

    Returns:
        dict with keys ``acc``, ``nmi`` and ``ari``. It also has two timings:
        ``inference_time`` is the seconds spent extracting features, and
        ``elapsed_time`` is feature extraction plus k-means and scoring, which
        is the value printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Pinned host memory only helps host->GPU copies.
    test_loader = _build_test_loader(batch_size, arg.num_workers,
                                     pin_memory=use_cuda)
    net, norm = _load_model(model_path, device)

    print("バッチサイズ:{}".format(batch_size))

    start = time.time()
    features = _extract_features(net, norm, test_loader, device)
    inference_time = time.time() - start

    targets = test_loader.dataset.targets
    acc, nmi, ari = calc_clustering_metrics(features, targets)
    print("ACC:{} NMI:{} ARI:{}".format(acc, nmi, ari))
    elapsed_time = time.time() - start
    print("実行時間:{}".format(elapsed_time) + "[秒]\n")

    # Release this run's cached GPU memory once, outside the timed region,
    # instead of calling empty_cache() after every batch.
    del net, norm
    if use_cuda:
        torch.cuda.empty_cache()

    return {"acc": acc, "nmi": nmi, "ari": ari,
            "inference_time": inference_time, "elapsed_time": elapsed_time}

In [42]:
# Checkpoint on the mounted Google Drive (Colab path).
model_path = "/content/drive/MyDrive/idfd_epoch_1999.pth"
# gpus is stored on Arg but not read by inference(); num_workers feeds the DataLoader.
arg = Arg(gpus=0, num_workers=2)

# One inference run at batch size 128; the trailing ";" suppresses this cell's Out[] display.
inference(model_path, 128, arg);

Files already downloaded and verified
バッチサイズ:128
ACC:0.8088 NMI:0.7040700954708554 ARI:0.6451485747886236
実行時間:7.893804550170898[秒]



In [43]:
# Run inference with the same checkpoint at each batch size below.
# NOTE(review): k-means in calc_clustering_metrics is unseeded, and the
# recorded runs at the same batch size (128) already give different scores.
# Small metric differences between batch sizes are therefore likely
# clustering noise. The printed time also includes k-means, not only the
# forward passes.
# TODO: re-check the lower batch-10000 result with a fixed k-means seed.
batch_list = [64, 128, 256, 512, 1024, 2048, 4096, 10000]
for batch_size in batch_list:
    inference(model_path, batch_size=batch_size, arg=arg)

Files already downloaded and verified
バッチサイズ:64
ACC:0.8072 NMI:0.7025449151478466 ARI:0.6427341271548341
実行時間:9.243892192840576[秒]

Files already downloaded and verified
バッチサイズ:128
ACC:0.8065 NMI:0.7013995668719712 ARI:0.6417647179414508
実行時間:7.7186033725738525[秒]

Files already downloaded and verified
バッチサイズ:256
ACC:0.8061 NMI:0.7055964260698524 ARI:0.6483797245644722
実行時間:7.495065212249756[秒]

Files already downloaded and verified
バッチサイズ:512
ACC:0.8074 NMI:0.7029782870955685 ARI:0.643858631211743
実行時間:7.118958950042725[秒]

Files already downloaded and verified
バッチサイズ:1024
ACC:0.8067 NMI:0.702037683668778 ARI:0.6425086036756946
実行時間:7.152648687362671[秒]

Files already downloaded and verified
バッチサイズ:2048
ACC:0.806 NMI:0.7052292332758415 ARI:0.6481220321584356
実行時間:7.156885147094727[秒]

Files already downloaded and verified
バッチサイズ:4096
ACC:0.8072 NMI:0.7024831858120149 ARI:0.6427413717881044
実行時間:9.016765832901001[秒]

Files already downloaded and verified
バッチサイズ:10000
ACC:0.7692 NMI:0.6