### Experiments

**Models**
- resnet50
- vgg16

**Datasets**
- CIFAR100
- SVHN

**Methods**
- Random
- KMeans
- KMeansPurity
- KMeansDino
- KMeansDinoPurity

**Settings**
- ratio = (0.1, 1, 10)
- epochs = 10
- batch_size = 512
- clip = 5.0
- num_clusters = 10
- eqsize = True
- min_purity = 0.1 (SVHN), 0.01 (CIFAR100)
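The notebook does not define what `eqsize` does or how `num_clusters` enters sampling. The log below offers two hints. First, clustering takes 100 steps for CIFAR100 and 10 for SVHN, which looks like one k-means per class. Second, each run prints a pair such as `4000 3984`, which may be the requested and the obtained sample count. The sketch below shows one reading consistent with those hints, and it is only an assumption. Under that reading, `eqsize` draws an equal quota from every (class, cluster) cell and caps it at the cell size, so small cells make the result shorter than requested. `balanced_sample` is a hypothetical helper, not part of `kiss`:

```python
import random
from collections import defaultdict


def balanced_sample(class_labels, cluster_ids, n_total, seed=0):
    """Hypothetical equal-size sampler: the same quota from every (class, cluster) cell.

    class_labels[i] is the class of sample i and cluster_ids[i] its cluster
    within that class. Cells smaller than the quota contribute all of their
    samples, so the returned list can be shorter than n_total.
    """
    cells = defaultdict(list)
    for idx, (y, c) in enumerate(zip(class_labels, cluster_ids)):
        cells[(y, c)].append(idx)
    if not cells:
        return []
    quota = n_total // len(cells)  # equal share per non-empty cell, rounded down
    rng = random.Random(seed)
    chosen = []
    for key in sorted(cells):
        members = cells[key]
        chosen.extend(rng.sample(members, min(quota, len(members))))
    return chosen


labels = [0] * 6 + [1] * 5
clusters = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1]
picked = balanced_sample(labels, clusters, n_total=8)
print(len(picked))  # 7: quota of 2 per cell, but cell (1, 1) holds a single sample
```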

**Run naming**
Square brackets list the allowed values (pick exactly one of them); curly braces mark optional fields.
ID distinguishes repeated experiments that share the same configuration.

- [r50,vgg16],ep:10,bs:512,clip:5.0,{nc:20},{eqsize},{mp:[0.1,0.01]},ID:[1,2,3,...]
- example: r50,ep:10,bs:512,clip:5.0,ID:1
- example: r50,ep:10,bs:512,clip:5.0,nc:20,eqsize,mp:0.1,ID:1
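The naming scheme can be sketched as a pair of helpers that mirror the string assembly in the `In [3]` cell. `build_run_name` and `parse_run_name` are hypothetical names, not part of `kiss`. The parser returns field values as plain strings, and bare flags such as `eqsize` map to `True`:

```python
def build_run_name(model, epochs, batch_size, clip, run_id,
                   num_clusters=None, eqsize=False, min_purity=None):
    """Assemble a run name: mandatory fields, then optional ones, then ID."""
    parts = [model, f"ep:{epochs}", f"bs:{batch_size}", f"clip:{clip}"]
    if num_clusters is not None:  # only KMeans-based samplers
        parts.append(f"nc:{num_clusters}")
    if eqsize:                    # flag without a value
        parts.append("eqsize")
    if min_purity is not None:    # only purity-based samplers
        parts.append(f"mp:{min_purity}")
    parts.append(f"ID:{run_id}")
    return ",".join(parts)


def parse_run_name(name):
    """Split a run name back into fields; the first token is the model tag."""
    fields = {}
    for i, token in enumerate(name.split(",")):
        if i == 0:
            fields["model"] = token
        elif ":" in token:
            key, value = token.split(":", 1)
            fields[key] = value
        else:
            fields[token] = True
    return fields


print(build_run_name("r50", 10, 512, 5.0, 5, num_clusters=10, eqsize=True))
# r50,ep:10,bs:512,clip:5.0,nc:10,eqsize,ID:5
print(parse_run_name("r50,ep:10,bs:512,clip:5.0,nc:20,eqsize,mp:0.1,ID:1")["mp"])
# 0.1
```

Keeping the field order fixed makes names sortable and comparable as strings. Note that `clip` must be passed as a float (`5.0`) to reproduce `clip:5.0`.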

**Notes**
- A fresh model is instantiated for every experiment (!)

In [1]:
import torch
import torchvision
from torchvision.models import resnet50

from kiss.models import vgg16_kiss
from kiss.experiment import Experiment
from kiss.sampler import RandomSampler, KMeansSampler, KMeansPuritySampler, KMeansPurity2Sampler, KMeansDinoSampler, KMeansPurityDinoSampler
from kiss.utils.configs import CONFIGS

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
])

dataset_cifar100_tr = torchvision.datasets.CIFAR100(root='../data/cifar100', train=True, download=True, transform=transform)
dataset_cifar100_te = torchvision.datasets.CIFAR100(root='../data/cifar100', train=False, download=True, transform=transform)

dataset_cifar10_tr = torchvision.datasets.CIFAR10(root='../data/cifar10', train=True, download=True, transform=transform)
dataset_cifar10_te = torchvision.datasets.CIFAR10(root='../data/cifar10', train=False, download=True, transform=transform)

dataset_svhn_tr = torchvision.datasets.SVHN(root='../data/svhn', split='train', download=True, transform=transform)
dataset_svhn_te = torchvision.datasets.SVHN(root='../data/svhn', split='test', download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ../data/svhn/train_32x32.mat
Using downloaded and verified file: ../data/svhn/test_32x32.mat


In [2]:
ID = 5
RATIO = (0.1, 1, 10)
EPOCHS = 10
BATCH_SIZE = 512
CLIP = 5.0
NUM_CLUSTERS = 10
EQSIZE = True

# ID 5: run with the corrected validation

In [3]:
import os

for model_fun in [resnet50]:
    for dataset_tr, dataset_te in zip([dataset_cifar100_tr, dataset_svhn_tr], [dataset_cifar100_te, dataset_svhn_te]):
        for sampler_cls in [KMeansDinoSampler]:

            # Short model tag used in the run name (KeyError for an unknown model)
            MODEL_NAME = {'resnet50': 'r50', 'vgg16': 'vgg16', 'vgg16_kiss': 'vgg16_kiss'}[model_fun.__name__]

            RUN_NAME = f"{MODEL_NAME},ep:{EPOCHS},bs:{BATCH_SIZE},clip:{CLIP}"

            # Class count and minimum cluster purity per dataset; the lookup fails
            # loudly instead of silently reusing values from a previous iteration
            NUM_CLASSES, MIN_PURITY = {'SVHN': (10, 0.1), 'CIFAR100': (100, 0.01)}[dataset_tr.__class__.__name__]

            if 'KMeans' in sampler_cls.__name__:
                RUN_NAME += f",nc:{NUM_CLUSTERS}"

            if 'KMeans' in sampler_cls.__name__ and EQSIZE:
                RUN_NAME += ",eqsize"

            if 'KMeansPurity' in sampler_cls.__name__:
                RUN_NAME += f",mp:{MIN_PURITY}"

            SAVE_CLUSTERS = LOAD_CLUSTERS = f"../checkpoints/{sampler_cls.__name__},{dataset_tr.__class__.__name__},nc:{NUM_CLUSTERS}"
            
            if not os.path.exists(LOAD_CLUSTERS):
                LOAD_CLUSTERS = None
            
            RUN_NAME += f',ID:{ID}'

            model = model_fun(num_classes=NUM_CLASSES)
            model.to(torch.device(CONFIGS.torch.device))

            experiment = Experiment(
                model=model,
                dataset_tr=dataset_tr,
                dataset_te=dataset_te,
                sampler_cls=sampler_cls,
                ratio=RATIO,
                epochs=EPOCHS,
                batch_size=BATCH_SIZE,
                clip=CLIP,
                num_clusters=NUM_CLUSTERS,
                eqsize=EQSIZE,
                min_purity=MIN_PURITY,
                load_clusters=LOAD_CLUSTERS,
                save_clusters=SAVE_CLUSTERS,
            )
            experiment.run("../experiments", RUN_NAME)

Using cache found in /Users/michal/.cache/torch/hub/facebookresearch_dinov2_main
Clustering: 100%|██████████| 100/100 [33:01<00:00, 19.82s/it]


[1m[33mRunning experiment ResNet!CIFAR100!KMeansDinoSampler
[0m[1m[95mRunning run r50,ep:10,bs:512,clip:5.0,nc:10,eqsize,ID:5/1
[0m4000 3984


Epoch 1/10: 100%|██████████| 8/8 [00:04<00:00,  1.79 batch/s, loss=5.2370]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.44 batch/s, loss=4.6722]


[1m[36mBest valid loss improved. Current accuracy is 0.97%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 0.97%. Saving checkpoint...
[0m

Epoch 2/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=4.7176]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=4.7515]


[1m[36mBest valid accuracy improved. Current accuracy is 1.09%. Saving checkpoint...
[0m

Epoch 3/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=4.3138]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=4.6646]


[1m[36mBest valid loss improved. Current accuracy is 1.90%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 1.90%. Saving checkpoint...
[0m

Epoch 4/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=3.9223]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=4.8956]
Epoch 5/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=3.3180]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=4.7254]


[1m[36mBest valid accuracy improved. Current accuracy is 3.47%. Saving checkpoint...
[0m

Epoch 6/10: 100%|██████████| 8/8 [00:03<00:00,  2.22 batch/s, loss=2.5874]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.64 batch/s, loss=4.5399]


[1m[36mBest valid loss improved. Current accuracy is 5.96%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 5.96%. Saving checkpoint...
[0m

Epoch 7/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=1.8534]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=4.9401]


[1m[36mBest valid accuracy improved. Current accuracy is 6.12%. Saving checkpoint...
[0m

Epoch 8/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=1.1264]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=5.2759]


[1m[36mBest valid accuracy improved. Current accuracy is 7.60%. Saving checkpoint...
[0m

Epoch 9/10: 100%|██████████| 8/8 [00:03<00:00,  2.23 batch/s, loss=0.7477]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=5.8123]


[1m[36mBest valid accuracy improved. Current accuracy is 7.74%. Saving checkpoint...
[0m

Epoch 10/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=0.5810]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=5.9348]


[1m[36mBest valid accuracy improved. Current accuracy is 8.02%. Saving checkpoint...
[0m

Testing: 100%|██████████| 20/20 [00:03<00:00,  5.53 batch/s]


[1m[95mRunning run r50,ep:10,bs:512,clip:5.0,nc:10,eqsize,ID:5/2
[0m8000 7914


Epoch 1/10: 100%|██████████| 16/16 [00:07<00:00,  2.14 batch/s, loss=5.0388]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.63 batch/s, loss=4.7068]


[1m[36mBest valid loss improved. Current accuracy is 1.05%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 1.05%. Saving checkpoint...
[0m

Epoch 2/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=4.3625]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.67 batch/s, loss=4.6661]


[1m[36mBest valid loss improved. Current accuracy is 2.40%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 2.40%. Saving checkpoint...
[0m

Epoch 3/10: 100%|██████████| 16/16 [00:07<00:00,  2.22 batch/s, loss=4.0663]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.68 batch/s, loss=27.9135]


[1m[36mBest valid accuracy improved. Current accuracy is 2.47%. Saving checkpoint...
[0m

Epoch 4/10: 100%|██████████| 16/16 [00:07<00:00,  2.22 batch/s, loss=3.7160]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.67 batch/s, loss=4.1014]


[1m[36mBest valid loss improved. Current accuracy is 8.96%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 8.96%. Saving checkpoint...
[0m

Epoch 5/10: 100%|██████████| 16/16 [00:07<00:00,  2.22 batch/s, loss=3.3695]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=13.6558]
Epoch 6/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=2.8431]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.67 batch/s, loss=4.1012]


[1m[36mBest valid loss improved. Current accuracy is 10.87%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 10.87%. Saving checkpoint...
[0m

Epoch 7/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=2.2452]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=4.7731]


[1m[36mBest valid accuracy improved. Current accuracy is 12.03%. Saving checkpoint...
[0m

Epoch 8/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=1.6673]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=5.7710]
Epoch 9/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=1.1679]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.51 batch/s, loss=5.1297]


[1m[36mBest valid accuracy improved. Current accuracy is 12.76%. Saving checkpoint...
[0m

Epoch 10/10: 100%|██████████| 16/16 [00:07<00:00,  2.21 batch/s, loss=0.8032]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.55 batch/s, loss=5.5329]


[1m[36mBest valid accuracy improved. Current accuracy is 12.95%. Saving checkpoint...
[0m

Testing: 100%|██████████| 20/20 [00:03<00:00,  5.50 batch/s]


[1m[95mRunning run r50,ep:10,bs:512,clip:5.0,nc:10,eqsize,ID:5/3
[0m12000 11765


Epoch 1/10: 100%|██████████| 24/24 [00:11<00:00,  2.13 batch/s, loss=4.8753]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.62 batch/s, loss=4.6822]


[1m[36mBest valid loss improved. Current accuracy is 1.78%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 1.78%. Saving checkpoint...
[0m

Epoch 2/10: 100%|██████████| 24/24 [00:10<00:00,  2.22 batch/s, loss=4.2145]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.55 batch/s, loss=9.3920]
Epoch 3/10: 100%|██████████| 24/24 [00:10<00:00,  2.22 batch/s, loss=3.8889]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.59 batch/s, loss=65.4785]


[1m[36mBest valid accuracy improved. Current accuracy is 7.13%. Saving checkpoint...
[0m

Epoch 4/10: 100%|██████████| 24/24 [00:10<00:00,  2.24 batch/s, loss=3.6118]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.59 batch/s, loss=3.8425]


[1m[36mBest valid loss improved. Current accuracy is 12.71%. Saving checkpoint...
[0m[1m[36mBest valid accuracy improved. Current accuracy is 12.71%. Saving checkpoint...
[0m

Epoch 5/10: 100%|██████████| 24/24 [00:10<00:00,  2.23 batch/s, loss=3.4282]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.46 batch/s, loss=10.1936]
Epoch 6/10: 100%|██████████| 24/24 [00:10<00:00,  2.24 batch/s, loss=2.9735]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.53 batch/s, loss=6.5522]


[1m[36mBest valid accuracy improved. Current accuracy is 13.40%. Saving checkpoint...
[0m

Epoch 7/10: 100%|██████████| 24/24 [00:10<00:00,  2.23 batch/s, loss=2.6204]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.59 batch/s, loss=4.1729]


[1m[36mBest valid accuracy improved. Current accuracy is 16.32%. Saving checkpoint...
[0m

Epoch 8/10: 100%|██████████| 24/24 [00:10<00:00,  2.21 batch/s, loss=2.1215]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.66 batch/s, loss=5.0744]


[1m[36mBest valid accuracy improved. Current accuracy is 16.75%. Saving checkpoint...
[0m

Epoch 9/10: 100%|██████████| 24/24 [00:10<00:00,  2.21 batch/s, loss=1.7058]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.62 batch/s, loss=4.7220]
Epoch 10/10: 100%|██████████| 24/24 [00:10<00:00,  2.22 batch/s, loss=1.5709]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.65 batch/s, loss=4.5893]
Testing: 100%|██████████| 20/20 [00:03<00:00,  5.58 batch/s]
Using cache found in /Users/michal/.cache/torch/hub/facebookresearch_dinov2_main
Clustering:  50%|█████     | 5/10 [26:29<26:00, 312.06s/it]
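Judging by the log, the trainer keeps two independent bests. One message fires when the validation loss reaches a new minimum, and another when the validation accuracy reaches a new maximum. Each of them triggers a checkpoint. The two can diverge: in run `ID:5/3` the loss jumps to 65.4785 at epoch 3 while accuracy still improves. Below is a minimal sketch of that bookkeeping. `BestTracker` is a hypothetical name, not the `Experiment` internals. The sketch assumes that the `loss=` value on each `Validating` bar is the quantity being compared, and it is checked against the losses of run `ID:5/1`:

```python
import math


class BestTracker:
    """Track a running best value and report whether each new value improves it."""

    def __init__(self, mode):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, value):
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
        return improved


# Validation losses of run r50,...,ID:5/1 as printed by the Validating bars
val_losses = [4.6722, 4.7515, 4.6646, 4.8956, 4.7254,
              4.5399, 4.9401, 5.2759, 5.8123, 5.9348]
loss_tracker = BestTracker("min")
saved = [epoch for epoch, loss in enumerate(val_losses, start=1)
         if loss_tracker.update(loss)]
print(saved)  # [1, 3, 6]: the epochs that logged "Best valid loss improved"
```

A second `BestTracker("max")` fed with validation accuracy would produce the "Best valid accuracy improved" events. With a loss this noisy, the accuracy-based checkpoint is probably the more informative one to test.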

In [None]:
from collections import Counter

# Count how often each value occurs in the last experiment's sampler cluster_data
cnt = Counter(experiment.sampler_.cluster_data)
print(cnt)
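The `In [3]` cell reuses cluster assignments across runs. It derives one path per (sampler, dataset, `nc`) and always passes it as `save_clusters`. It passes the same path as `load_clusters` only if the file already exists. Clustering CIFAR100 took about 33 minutes in the log above, so this avoids repeating the most expensive step. Below is a generic load-or-compute sketch using `pickle`. `load_or_compute` is a hypothetical helper, and the on-disk format actually used by `kiss` is not shown in the notebook:

```python
import os
import pickle
import tempfile


def load_or_compute(path, compute):
    """Return the object cached at `path`; on a miss, compute it and save it there."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result


calls = []


def slow_clustering():
    calls.append(1)  # stands in for the expensive clustering step
    return {0: [1, 2], 1: [3]}


path = os.path.join(tempfile.mkdtemp(), "clusters.pkl")
load_or_compute(path, slow_clustering)
load_or_compute(path, slow_clustering)
print(len(calls))  # 1: the second call is served from the cache
```

The cache key in the notebook omits `eqsize` and `min_purity`. That is only safe if the stored clusters do not depend on them.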