In [None]:
import torch
import torchvision
from torchvision.models import resnet50

from kiss.experiment import Experiment
from kiss.sampler import KMeansSampler, KMeansNumerousSampler, KMeansPuritySampler
from kiss.utils.configs import CONFIGS

# --- Configuration -----------------------------------------------------------
# Every tunable lives here exactly once. The run tag passed to
# `experiment.run` is derived from these values, so the stored experiment name
# cannot drift from the hyperparameters that were actually used.
DATA_ROOT = '../data'
EXPERIMENTS_ROOT = "../experiments"
# Same path is used for both loading and saving clusters.
# NOTE(review): presumably this makes the checkpoint act as a cache across
# runs — confirm against `Experiment` / the sampler implementation.
CLUSTER_CHECKPOINT = "../checkpoints/kmeans,purity,fe"

NUM_CLASSES = 100  # CIFAR-100
RATIO = (0.1, 0.4, 4)  # forwarded unchanged; semantics defined by `Experiment` — TODO confirm
EPOCHS = 10
BATCH_SIZE = 512
CLIP = None
NUM_CLUSTERS = 20
MIN_PURITY = 0.01

# --- Data --------------------------------------------------------------------
transform = torchvision.transforms.ToTensor()
dataset_tr = torchvision.datasets.CIFAR100(root=DATA_ROOT, train=True, download=True, transform=transform)
dataset_te = torchvision.datasets.CIFAR100(root=DATA_ROOT, train=False, download=True, transform=transform)

# --- Model -------------------------------------------------------------------
model = resnet50(num_classes=NUM_CLASSES)
# `nn.Module.to` moves the parameters in place, so the result need not be rebound.
model.to(torch.device(CONFIGS.torch.device))

# --- Experiment --------------------------------------------------------------
experiment = Experiment(
    model,
    dataset_tr,
    dataset_te,
    KMeansPuritySampler,
    ratio=RATIO,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    clip=CLIP,
    num_clusters=NUM_CLUSTERS,
    eqsize=True,
    min_purity=MIN_PURITY,
    load_clusters=CLUSTER_CHECKPOINT,
    save_clusters=CLUSTER_CHECKPOINT)

# For the values above this renders exactly
# "r50,ep:10,bs:512,clip:None,nc:20,mp:0.01,fe".
run_tag = f"r50,ep:{EPOCHS},bs:{BATCH_SIZE},clip:{CLIP},nc:{NUM_CLUSTERS},mp:{MIN_PURITY},fe"
experiment.run(EXPERIMENTS_ROOT, run_tag)

In [None]:
import pandas as pd

# Spot-check the sampler state left on the experiment by the run above.
# NOTE(review): what index 4 and keys 1 / 6 select is not visible from this
# notebook — confirm against the sampler implementation.
sampler = experiment.sampler_
probe_keys = (1, 6)

clusters_at_4 = sampler.cluster_data_[4]
print([*clusters_at_4])
classes_at_4 = sampler.class_data_[4]
print(classes_at_4)

# Compare the two probed entries side by side: clusters first, then classes.
print(*(clusters_at_4[key] for key in probe_keys))
print(*(classes_at_4[key] for key in probe_keys))

# Last expression of the cell, so the notebook renders it as a rich table.
pd.DataFrame(sampler.purity_data_)

In [None]:
from kiss.feature_extractor import ClassicFeatureExtractor
import matplotlib.pyplot as plt


def show_sample_features(dataset, extractor, index):
    """Render one dataset image as a thumbnail and print its extracted features.

    Args:
        dataset: Indexable dataset whose items are sequences with the image
            tensor at position 0. The tensor must support ``permute(1, 2, 0)``,
            i.e. be a 3-D tensor.
        extractor: Callable invoked with a one-element list containing the
            image tensor; whatever it returns is printed.
        index: Position of the sample inside ``dataset``.
    """
    # Fetch the item once and reuse the tensor for both the plot and the extractor.
    image = dataset[index][0]

    plt.figure(figsize=(1, 1))
    # Move the first axis to the end so imshow receives (rows, cols, channels).
    plt.imshow(image.permute(1, 2, 0))
    print(extractor([image]))
    plt.show()


fe = ClassicFeatureExtractor()

# Two hand-picked training samples to eyeball image vs. extracted features.
for sample_index in (113, 651):
    show_sample_features(dataset_tr, fe, sample_index)