In [23]:
from PIL import Image
from torchvision.utils import make_grid
import argparse
import os.path

import torch
from src.clustering_models.clusternet_modules.clusternetasmodel import ClusterNetModel
from src.datasets import CustomDataset, STL10
import numpy as np
from sklearn.metrics import normalized_mutual_info_score as NMI
from sklearn.metrics import adjusted_rand_score as ARI
from src.utils import cluster_acc

def save_cluster_examples(args,predict,x_for_vis,
                          labels,num_img,grid_size):
    """Write example images for every predicted cluster to disk.

    Output goes to ``./{args.dataset}_imgs/<n>/`` where ``<n>`` numbers the
    distinct values of ``predict`` as 1, 2, ... in the order returned by
    ``np.unique``. Each cluster directory receives up to ``num_img``
    individual JPEGs (the ground-truth label is part of the file name) and
    one single-row grid made of up to ``grid_size`` images.

    Args:
        args: object with a ``dataset`` attribute; it names the output root.
        predict: per-sample cluster assignments, index-aligned with
            ``x_for_vis`` and ``labels``.
        x_for_vis: images indexed along the first axis. Presumably a
            ``(N, C, H, W)`` tensor with values already in [0, 255]
            -- TODO confirm against the caller.
        labels: ground-truth labels, indexed with the same boolean mask as
            ``x_for_vis``.
        num_img: maximum number of individual images saved per cluster.
        grid_size: maximum number of images placed in the per-cluster grid.
    """

    def save_image(tensor, fp) -> None:
        """Save ``tensor`` (a single image or a ready-made grid) to ``fp``."""
        grid = make_grid(tensor)
        # No scaling is applied: values are only clamped to [0, 255] and cast
        # to uint8. The clamp is deliberately NOT in-place, because the result
        # of make_grid may share storage with the caller's tensor.
        ndarr = (grid.clamp(0, 255).permute(1, 2, 0)
                 .to("cpu", torch.uint8).numpy())
        Image.fromarray(ndarr).save(fp)

    out_root = f"./{args.dataset}_imgs"
    # exist_ok=True avoids the check-then-create race of exists() + mkdir().
    os.makedirs(out_root, exist_ok=True)

    for count, k in enumerate(np.unique(predict), start=1):
        in_cluster = predict == k
        x_k = x_for_vis[in_cluster][:num_img]
        y_gt = labels[in_cluster][:num_img]

        cluster_dir = os.path.join(out_root, str(count))
        os.makedirs(cluster_dir, exist_ok=True)

        # One file per example image of this cluster.
        for i in range(min(num_img, x_k.shape[0])):
            save_image(x_k[i], os.path.join(
                cluster_dir,
                f"clusternet_clus{count}_label{y_gt[i]}_{i}.jpeg"))

        # Additionally save the first few examples side by side in one row.
        num_imgs = min(grid_size, x_k.shape[0])
        if num_imgs > 0:
            grid = make_grid(x_k[:num_imgs], nrow=num_imgs)
            save_image(grid, os.path.join(
                cluster_dir, f"clusternet_clus{count}.jpeg"))

# LOAD MODEL FROM CHECKPOINT
def fun(cp_path="./saved_models/USPS/default_exp/epoch=699-step=51099.ckpt",
        data_dim=10):
    """Load a ClusterNet checkpoint, cluster the training data and print metrics.

    The number of clusters is taken from the first dimension of
    ``cluster_net.class_fc2.weight`` in the checkpoint's state dict, and the
    run arguments are rebuilt from the checkpoint's ``hyper_parameters``.
    Prints the dataset directory, the model's K, NMI / ARI / clustering
    accuracy against the dataset targets, and the distinct predicted labels.

    Args:
        cp_path: path to the checkpoint file,
            e.g. "./saved_models/USPS/default_exp/epoch=499-step=36499.ckpt".
        data_dim: input dimension passed to the model. E.g. for MNIST it would
            be 10 if the network was trained on the embeddings supplied, or
            28*28 otherwise.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    # map_location="cpu": only shapes and hyper-parameters are read from this
    # copy, so it should load even without the device it was saved from.
    cp_state = torch.load(cp_path, map_location="cpu")
    K = cp_state['state_dict']['cluster_net.class_fc2.weight'].shape[0]
    hyper_param = cp_state['hyper_parameters']
    args = argparse.Namespace()
    for key, value in hyper_param.items():
        setattr(args, key, value)

    model = ClusterNetModel.load_from_checkpoint(
        checkpoint_path=cp_path,
        input_dim=data_dim,
        init_k=K,
        hparams=args
        )

    # Example for inference :
    model.eval()
    dataset_obj = CustomDataset(args)

    print(dataset_obj.data_dir)
    print(model.K)

    dataset = dataset_obj.get_train_data()
    data = dataset.data
    # Gradients are not needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        predict = model(data).argmax(-1)
    labels = dataset.targets.numpy()
    # Convert once and reuse for every metric below.
    predict_np = predict.numpy()

    acc = np.round(cluster_acc(labels, predict_np), 5)
    nmi = np.round(NMI(predict_np, labels), 5)
    ari = np.round(ARI(predict_np, labels), 5)
    unique_clusters = np.unique(predict_np)
    print(f"NMI: {nmi}, ARI: {ari}, acc: {acc}, final K: {len(unique_clusters)}")

    print(unique_clusters)


    # Disabled example: dump sample images per predicted cluster.
    # NOTE(review): `predict`/`labels` above come from the training data while
    # this snippet reads images from the STL10 test loader -- confirm the
    # indices line up before enabling it.
    # num_img=20
    # stl10=STL10(args)
    # test_loader=stl10.get_test_loader()
    # x_for_vis=torch.from_numpy(test_loader.dataset.data)
    # 
    # save_cluster_examples(args,predict,x_for_vis,labels,num_img=20,grid_size=8)

In [24]:
# Evaluate the checkpoint and print the clustering metrics (NMI, ARI, accuracy, final K).
fun()

Sequential()
./pretrained_embeddings/umap_embedded_datasets/USPS
8
NMI: 0.86441, ARI: 0.80714, acc: 0.80964, final K: 8
[0 1 2 3 4 5 6 7]


In [8]:
from torchvision import transforms, datasets

# Root directory of the ImageNet data. torchvision cannot download ImageNet,
# so the archives must be placed here by hand beforehand: at least
# ILSVRC2012_devkit_t12.tar.gz, plus the image archive of the requested split
# (see the torchvision ImageNet documentation for the exact file names).
path="pretrained_embeddings/MOCO/IMAGENET_50"

# `download=True` was dropped on purpose: the ImageNet dataset class does not
# support downloading, and passing the flag only leads to an error.
dataset=datasets.ImageNet(path, split="train", transform=transforms.Compose([
            transforms.ToTensor(),
            # Per-channel normalisation with mean 0.5 and std 0.5.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]))
# test=datasets.ImageNet(path, split="val", transform=transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#         ]))

RuntimeError: The archive ILSVRC2012_devkit_t12.tar.gz is not present in the root directory or is corrupted. You need to download it externally and place it in pretrained_embeddings/MOCO/IMAGENET_50.