In [1]:
import os

import torch
import torchvision
from torchvision.models import resnet50

from kiss.models import vgg16_kiss
from kiss.experiment import Experiment
from kiss.sampler import RandomSampler, KMeansSampler, KMeansPuritySampler, KMeansDinoSampler, KMeansPurityDinoSampler
from kiss.utils.configs import CONFIGS

# Images are only converted to tensors: no normalization, no augmentation.
transform = torchvision.transforms.ToTensor()

# CIFAR-100 chooses its split with a boolean `train` flag; download=True
# fetches the archives only if they are not already under `root`.
cifar100_kwargs = dict(root='../data/cifar100', download=True, transform=transform)
dataset_cifar100_tr = torchvision.datasets.CIFAR100(train=True, **cifar100_kwargs)
dataset_cifar100_te = torchvision.datasets.CIFAR100(train=False, **cifar100_kwargs)

# SVHN chooses its split by name instead.
# NOTE(review): the visible training cell uses only CIFAR-100; drop these if unneeded.
svhn_kwargs = dict(root='../data/svhn', download=True, transform=transform)
dataset_svhn_tr = torchvision.datasets.SVHN(split='train', **svhn_kwargs)
dataset_svhn_te = torchvision.datasets.SVHN(split='test', **svhn_kwargs)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ../data/svhn/train_32x32.mat
Using downloaded and verified file: ../data/svhn/test_32x32.mat


In [2]:
# ResNet-50 with a 100-way classification head for CIFAR-100, moved to the
# device configured in CONFIGS.
# NOTE(review): no random seed is set here, so runs differ unless Experiment
# seeds internally — confirm before comparing runs.
model = resnet50(num_classes=100)
model.to(torch.device(CONFIGS.torch.device))

# Root of the precomputed cluster checkpoints. Override it with the
# KISS_CHECKPOINTS_DIR environment variable rather than editing a
# machine-specific absolute path.
CHECKPOINTS_DIR = os.environ.get("KISS_CHECKPOINTS_DIR", "/Users/michal/GitHub/KISS/checkpoints")
SAMPLER_CLS = KMeansSampler
NUM_CLUSTERS = 10
# Checkpoint folders are named "<Sampler>,<Dataset>,nc:<num_clusters>". Build
# the name from the values actually passed to Experiment, so the loaded clusters
# always match the sampler, dataset and cluster count in use.
cluster_run = f"{SAMPLER_CLS.__name__},{type(dataset_cifar100_tr).__name__},nc:{NUM_CLUSTERS}"

experiment = Experiment(
    model=model,
    dataset_tr=dataset_cifar100_tr,
    dataset_te=dataset_cifar100_te,
    sampler_cls=SAMPLER_CLS,
    ratio=(0.1, 0.3, 3),  # presumably (start, stop, num) of a ratio sweep — confirm in Experiment
    epochs=10,
    batch_size=512,
    clip=5.0,
    num_clusters=NUM_CLUSTERS,
    eqsize=True,
    load_clusters=os.path.join(CHECKPOINTS_DIR, cluster_run),
)
experiment.run("../experiments", "GOODVALID")

[1m[33mRunning experiment ResNet!CIFAR100!KMeansSampler
[0m[1m[95mRunning run GOODVALID/1
[0mKept clusters size 88
Kept clusters size 100
Kept clusters size 86
Kept clusters size 74
Kept clusters size 89
Kept clusters size 84
Kept clusters size 87
Kept clusters size 82
Kept clusters size 80
Kept clusters size 90
Kept clusters size 86
Kept clusters size 93
Kept clusters size 101
Kept clusters size 99
Kept clusters size 87
Kept clusters size 88
Kept clusters size 91
Kept clusters size 92
Kept clusters size 78
Kept clusters size 93
Kept clusters size 79
Kept clusters size 83
Kept clusters size 84
Kept clusters size 84
Kept clusters size 92
Kept clusters size 87
Kept clusters size 94
Kept clusters size 102
Kept clusters size 88
Kept clusters size 81
Kept clusters size 90
Kept clusters size 98
Kept clusters size 90
Kept clusters size 92
Kept clusters size 89
Kept clusters size 82
Kept clusters size 112
Kept clusters size 57
Kept clusters size 72
Kept clusters size 81
Kept clusters siz

Epoch 1/10: 100%|██████████| 8/8 [00:04<00:00,  1.78 batch/s, loss=5.1350]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.25 batch/s, loss=4.6840]


[1m[36mBest valid loss improved. Current accuracy is 0.98%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 0.98%. Saving checkpoint...
[0m

Epoch 2/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=4.6611]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.68 batch/s, loss=4.8250]


[1m[36mBest valid accuracy improved. Current accuracy is 1.06%. Saving checkpoint...
[0m

Epoch 3/10: 100%|██████████| 8/8 [00:03<00:00,  2.23 batch/s, loss=4.2048]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.55 batch/s, loss=4.9176]


[1m[36mBest valid accuracy improved. Current accuracy is 1.77%. Saving checkpoint...
[0m

Epoch 4/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=3.7509]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.49 batch/s, loss=5.1172]


[1m[36mBest valid accuracy improved. Current accuracy is 2.07%. Saving checkpoint...
[0m

Epoch 5/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=3.2112]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.69 batch/s, loss=5.0871]


[1m[36mBest valid accuracy improved. Current accuracy is 3.64%. Saving checkpoint...
[0m

Epoch 6/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=2.5341]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.69 batch/s, loss=4.7191]


[1m[36mBest valid accuracy improved. Current accuracy is 5.88%. Saving checkpoint...
[0m

Epoch 7/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=1.7765]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.68 batch/s, loss=4.7252]


[1m[36mBest valid accuracy improved. Current accuracy is 7.66%. Saving checkpoint...
[0m

Epoch 8/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=1.1736]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.69 batch/s, loss=5.2955]
Epoch 9/10: 100%|██████████| 8/8 [00:03<00:00,  2.19 batch/s, loss=0.8769]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.69 batch/s, loss=5.6587]


[1m[36mBest valid accuracy improved. Current accuracy is 8.15%. Saving checkpoint...
[0m

Epoch 10/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=0.6461]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.48 batch/s, loss=6.2896]


[1m[36mBest valid accuracy improved. Current accuracy is 9.01%. Saving checkpoint...
[0m

Testing:  50%|█████     | 10/20 [00:01<00:01,  5.40 batch/s]

In [None]:
import pickle
from collections import Counter

with open("/Users/michal/GitHub/KISS/checkpoints/KMeansDinoSampler,CIFAR100,nc:10/cluster_data.pickle", "rb") as file:
    cluster_data = pickle.load(file)
    
with open("/Users/michal/GitHub/KISS/checkpoints/KMeansDinoSampler,SVHN,nc:10/cluster_data.pickle", "rb") as file:
    cluster_data = pickle.load(file)
    
def summarize_kept_clusters(cluster_data, num_clusters=10, ratio=0.3, keep_fraction=0.9, verbose=True):
    """Return how many items fall into the largest clusters kept for each label.

    For each ``label -> cluster ids`` entry, the ids are tallied and the
    ``max(1, int(num_clusters * ratio * keep_fraction))`` most frequent
    clusters are kept. Ties keep first-seen order, as ``Counter.most_common`` does.

    Args:
        cluster_data: Mapping of label to an iterable of hashable cluster ids.
            Presumably one id per sample, as written by the sampler — verify
            against the code that produced the pickle.
        num_clusters: Number of clusters per label the data was built with.
        ratio: Sampling ratio whose kept-cluster budget is being inspected.
        keep_fraction: Extra shrink factor applied to the budget.
            NOTE(review): the ratio-preview cell uses 0.95 rather than 0.9;
            confirm which factor the sampler actually applies.
        verbose: If True, print per-label cluster sizes (largest first), then
            the kept cluster ids with their combined size.

    Returns:
        Total number of items in the kept clusters, summed over all labels.
    """
    # Multiplication order kept as in the original (10 * 0.3 * 0.9) so the
    # float product, and therefore the int() truncation, is unchanged.
    n_keep = max(1, int(num_clusters * ratio * keep_fraction))
    total = 0
    for label, clusters in cluster_data.items():
        # Largest clusters first; same order as a stable descending sort by count.
        ranked = Counter(clusters).most_common()
        kept = ranked[:n_keep]
        kept_size = sum(size for _, size in kept)
        if verbose:
            print(label, dict(ranked))
            print([cluster for cluster, _ in kept], kept_size)
        total += kept_size
    return total


total = summarize_kept_clusters(cluster_data)
print(total)

In [None]:
import numpy as np

# Preview the per-label cluster budget (ratio * keep_fraction * num_clusters,
# before int() truncation) for each sampling ratio.
# NOTE(review): this uses 0.95, while the kept-cluster summary cell uses 0.9;
# confirm which factor the sampler applies.
NUM_CLUSTERS = 10
KEEP_FRACTION = 0.95
# Build the ratios from integers so they are exactly 0.1, 0.2, ..., 1.0.
# A float step in np.arange accumulates error (e.g. 0.30000000000000004) and
# needs an end fudge such as 1.01 to include the endpoint.
RATIOS = np.arange(1, 11) / 10
for ratio in RATIOS:
    print(ratio, ratio * KEEP_FRACTION * NUM_CLUSTERS)