In [None]:
import pickle
from collections import Counter
from pathlib import Path

import torch
import torchvision
from torchvision.models import resnet50

from kiss.models import vgg16_kiss
from kiss.experiment import Experiment
from kiss.sampler import RandomSampler, KMeansSampler, KMeansPuritySampler, KMeansDinoSampler, KMeansPurityDinoSampler
from kiss.utils.configs import CONFIGS

# Dataset loading: every split is converted to tensors; download=True fetches the files if missing.
transform = torchvision.transforms.ToTensor()

# CIFAR-100: train split first, then test split (selected by the boolean `train` flag).
dataset_cifar100_tr, dataset_cifar100_te = (
    torchvision.datasets.CIFAR100(root='../data/cifar100', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# SVHN: same pattern, but the split is selected by name instead of a boolean flag.
dataset_svhn_tr, dataset_svhn_te = (
    torchvision.datasets.SVHN(root='../data/svhn', split=split_name, download=True, transform=transform)
    for split_name in ('train', 'test')
)

In [None]:
# Train a ResNet-50 (100 output classes, matching CIFAR-100) with the KMeansDinoSampler
# and save the computed cluster assignments under the name given by `save_clusters`.

# Seed before building the model so weight initialisation is reproducible across runs.
# NOTE(review): this seeds torch only; if kiss samplers draw from numpy/random, seed those too — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

model = resnet50(num_classes=100)
model.to(torch.device(CONFIGS.torch.device))

experiment = Experiment(
    model=model,
    dataset_tr=dataset_cifar100_tr,
    dataset_te=dataset_cifar100_te,
    sampler_cls=KMeansDinoSampler,
    ratio=(0.1, 0.3, 3),  # presumably (start, stop, num_steps) of the sampling ratio sweep — TODO confirm
    epochs=10,
    batch_size=512,
    clip=5.0,
    num_clusters=100,  # must agree with the "nc:100" suffix of save_clusters below
    eqsize=True,
    min_purity=0.01,
    # To reuse previously saved clusters instead of recomputing them, pass e.g.
    # load_clusters="KMeansPurity2Sampler,CIFAR100,nc:8".
    save_clusters="KMeansDinoSampler,CIFAR100,nc:100",
)
experiment.run("../experiments", "GOODVALID")

In [6]:
import pickle
from collections import Counter

with open("/Users/michal/GitHub/KISS/checkpoints/KMeansDinoSampler,CIFAR100,nc:10/cluster_data.pickle", "rb") as file:
    cluster_data = pickle.load(file)
    
for label, clusters in cluster_data.items():
    cnt = Counter(clusters)
    print(label, sum(dict(cnt).values()))

48 400
19 400
31 400
0 400
42 400
62 400
8 400
82 400
71 400
16 400
66 400
40 400
10 400
84 400
72 400
77 400
93 400
75 400
12 400
91 400
81 400
5 400
94 400
22 400
78 400
97 400
44 400
95 400
90 400
17 400
98 400
18 400
26 400
70 400
30 400
54 400
89 400
65 400
68 400
24 400
49 400
52 400
3 400
34 400
86 400
29 400
38 400
39 400
80 400
35 400
83 400
50 400
85 400
74 400
76 400
55 400
27 400
64 400
11 400
63 400
1 400
43 400
88 400
92 400
51 400
20 400
9 400
60 400
4 400
61 400
21 400
87 400
69 400
41 400
37 400
56 400
14 400
67 400
58 400
59 400
15 400
32 400
6 400
45 400
23 400
79 400
57 400
13 400
96 400
99 400
7 400
2 400
73 400
47 400
46 400
53 400
28 400
25 400
36 400
33 400
