### Experiments

**Models**
- resnet50
- vgg16

**Datasets**
- CIFAR100
- SVHN

**Methods**
- Random
- KMeans
- KMeansPurity
- KMeansDino
- KMeansDinoPurity

**Settings**
- ratio = (0.1, 1, 10)
- epochs = 10
- batch_size = 512
- clip = 5.0
- num_clusters = 10
- eqsize = True
- min_purity = 0.1 (SVHN), 0.01 (CIFAR100)
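
Both `min_purity` values equal the chance level `1 / num_classes` (10 classes for SVHN, 100 for CIFAR100). The sketch below illustrates that relationship under the assumption that purity means the share of the most frequent class within a cluster; the `kiss` samplers may define it differently, and `cluster_purity` and `chance_purity` are hypothetical helpers, not part of `kiss`.

```python
from collections import Counter


def cluster_purity(labels):
    """Share of the most frequent class among the labels of one cluster.

    Assumed definition, used for illustration only.
    """
    if not labels:
        raise ValueError("cluster must contain at least one label")
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)


def chance_purity(num_classes):
    """Purity expected from a perfectly balanced cluster."""
    return 1.0 / num_classes


# A cluster dominated by class 3.
print(cluster_purity([3, 3, 3, 7]))
# Chance levels for SVHN and CIFAR100.
print(chance_purity(10), chance_purity(100))
```

Under this reading, a threshold at the chance level would only reject clusters that are no purer than a uniform mix of classes.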

**Run naming**
Values in square brackets are alternatives ("one of ..."); values in curly braces are optional.
The ID distinguishes experiments that share the same configuration.

- [r50,vgg16],ep:10,bs:512,clip:5.0,{nc:20},{eqsize},{mp:[0.1,0.01]},ID:[1,2,3,..]
- example: r50,ep:10,bs:512,clip:5.0,ID:1
- example: r50,ep:10,bs:512,clip:5.0,nc:20,eqsize,mp:0.1,ID:1
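
The naming scheme can be captured in a small helper. This is a minimal sketch: `build_run_name` and its parameter names are hypothetical and not part of `kiss`. Optional segments are appended only when a value is supplied, in the order shown in the template.

```python
def build_run_name(model, epochs, batch_size, clip, run_id,
                   num_clusters=None, eqsize=False, min_purity=None):
    """Assemble a run name following the convention described above."""
    parts = [model, f"ep:{epochs}", f"bs:{batch_size}", f"clip:{clip}"]
    if num_clusters is not None:   # optional {nc:...}
        parts.append(f"nc:{num_clusters}")
    if eqsize:                     # optional {eqsize} flag
        parts.append("eqsize")
    if min_purity is not None:     # optional {mp:...}
        parts.append(f"mp:{min_purity}")
    parts.append(f"ID:{run_id}")   # the ID always comes last
    return ",".join(parts)


print(build_run_name("r50", 10, 512, 5.0, 1))
print(build_run_name("r50", 10, 512, 5.0, 1,
                     num_clusters=20, eqsize=True, min_purity=0.1))
```

The first call reproduces the first example in the list above.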

**Notes**
- Every experiment starts from a freshly initialized model (!)

In [7]:
import torch
import torchvision
from torchvision.models import resnet50

from kiss.models import vgg16_kiss
from kiss.experiment import Experiment
from kiss.sampler import RandomSampler, KMeansSampler, KMeansPuritySampler, KMeansPurity2Sampler, KMeansDinoSampler, KMeansPurityDinoSampler
from kiss.utils.configs import CONFIGS

transform = torchvision.transforms.ToTensor()
dataset_cifar100_tr = torchvision.datasets.CIFAR100(root='../data/cifar100', train=True, download=True, transform=transform)
dataset_cifar100_te = torchvision.datasets.CIFAR100(root='../data/cifar100', train=False, download=True, transform=transform)

dataset_svhn_tr = torchvision.datasets.SVHN(root='../data/svhn', split='train', download=True, transform=transform)
dataset_svhn_te = torchvision.datasets.SVHN(root='../data/svhn', split='test', download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
ID = 5
RATIO = (0.1, 0.3, 3)
EPOCHS = 10
BATCH_SIZE = 512
CLIP = 5.0
NUM_CLUSTERS = 10
EQSIZE = True

# ID 5: corrected validation
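
The notebook does not document how `RATIO` is interpreted. One reading that is consistent with the batch counts in the run log (8 training batches in run 1, 16 in run 2, 20 validation batches) is a `(start, stop, num)` grid of training-set fractions applied to an 80/20 train/validation split of the 50,000 CIFAR100 training images. The sketch below checks that arithmetic; both the grid interpretation and the split sizes are assumptions inferred from the log, and the helper names are hypothetical.

```python
import math


def ratio_grid(start, stop, num):
    """Evenly spaced fractions from start to stop (assumed meaning of RATIO)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def batches_per_epoch(n_samples, ratio, batch_size):
    """Number of batches needed to cover `ratio` of `n_samples` samples."""
    return math.ceil(round(n_samples * ratio) / batch_size)


TRAIN_SIZE = 40_000  # assumed: 80% of the 50,000 CIFAR100 training images
VALID_SIZE = 10_000  # assumed: the remaining 20%

for r in ratio_grid(0.1, 0.3, 3):
    print(f"ratio={r:.1f}: {batches_per_epoch(TRAIN_SIZE, r, 512)} batches")
print(f"validation: {batches_per_epoch(VALID_SIZE, 1.0, 512)} batches")
```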

In [None]:
import os

for model_fun in [resnet50]:
    for dataset_tr, dataset_te in zip([dataset_cifar100_tr, dataset_svhn_tr], [dataset_cifar100_te, dataset_svhn_te]):
        for sampler_cls in [KMeansDinoSampler]:

            # Map the model constructor to the short tag used in run names.
            # An unknown constructor raises KeyError instead of silently
            # reusing the tag from a previous iteration.
            MODEL_NAME = {
                'resnet50': 'r50',
                'vgg16': 'vgg16',
                'vgg16_kiss': 'vgg16_kiss',
            }[model_fun.__name__]

            RUN_NAME = f"{MODEL_NAME},ep:{EPOCHS},bs:{BATCH_SIZE},clip:{CLIP}"

            # Dataset-specific settings; MIN_PURITY equals 1 / NUM_CLASSES.
            if dataset_tr.__class__.__name__ == 'SVHN':
                NUM_CLASSES = 10
                MIN_PURITY = 0.1
            elif dataset_tr.__class__.__name__ == 'CIFAR100':
                NUM_CLASSES = 100
                MIN_PURITY = 0.01

            if 'KMeans' in sampler_cls.__name__:
                RUN_NAME += f",nc:{NUM_CLUSTERS}"

            if 'KMeans' in sampler_cls.__name__ and EQSIZE:
                RUN_NAME += ",eqsize"

            if 'KMeansPurity' in sampler_cls.__name__:
                RUN_NAME += f",mp:{MIN_PURITY}"

            SAVE_CLUSTERS = LOAD_CLUSTERS = f"../checkpoints/{sampler_cls.__name__},{dataset_tr.__class__.__name__},nc:{NUM_CLUSTERS}"
            
            if not os.path.exists(LOAD_CLUSTERS):
                LOAD_CLUSTERS = None
            
            RUN_NAME += f',ID:{ID}'

            model = model_fun(num_classes=NUM_CLASSES)
            model.to(torch.device(CONFIGS.torch.device))

            experiment = Experiment(
                model = model,
                dataset_tr = dataset_tr,
                dataset_te = dataset_te,
                sampler_cls=sampler_cls,
                ratio=RATIO,
                epochs=EPOCHS,
                batch_size=BATCH_SIZE,
                clip=CLIP,
                num_clusters=NUM_CLUSTERS,
                eqsize=EQSIZE,
                min_purity=MIN_PURITY,
                load_clusters=LOAD_CLUSTERS,
                save_clusters=SAVE_CLUSTERS,
            )
            experiment.run("../experiments", RUN_NAME)

Running experiment ResNet!CIFAR100!RandomSampler
Running run r50,ep:10,bs:512,clip:5.0,ID:5/1

Epoch 1/10: 100%|██████████| 8/8 [00:04<00:00,  1.77 batch/s, loss=5.1547]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.40 batch/s, loss=4.6781]


Best valid loss improved. Current accuracy is 0.97%. Saving checkpoint...
Best valid accuracy improved. Current accuracy is 0.97%. Saving checkpoint...

Epoch 2/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=4.6894]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.43 batch/s, loss=4.8054]


Best valid accuracy improved. Current accuracy is 1.00%. Saving checkpoint...

Epoch 3/10: 100%|██████████| 8/8 [00:03<00:00,  2.23 batch/s, loss=4.1967]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.44 batch/s, loss=4.8718]


Best valid accuracy improved. Current accuracy is 1.24%. Saving checkpoint...

Epoch 4/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=3.7937]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.49 batch/s, loss=4.9795]


Best valid accuracy improved. Current accuracy is 2.18%. Saving checkpoint...

Epoch 5/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=3.2669]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.54 batch/s, loss=4.8974]


Best valid accuracy improved. Current accuracy is 2.98%. Saving checkpoint...

Epoch 6/10: 100%|██████████| 8/8 [00:03<00:00,  2.20 batch/s, loss=2.5772]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.59 batch/s, loss=4.5371]


Best valid loss improved. Current accuracy is 5.32%. Saving checkpoint...
Best valid accuracy improved. Current accuracy is 5.32%. Saving checkpoint...

Epoch 7/10: 100%|██████████| 8/8 [00:03<00:00,  2.21 batch/s, loss=1.7841]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.57 batch/s, loss=4.7266]


Best valid accuracy improved. Current accuracy is 7.43%. Saving checkpoint...

Epoch 8/10: 100%|██████████| 8/8 [00:03<00:00,  2.17 batch/s, loss=1.1963]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.37 batch/s, loss=5.0793]


Best valid accuracy improved. Current accuracy is 8.44%. Saving checkpoint...

Epoch 9/10: 100%|██████████| 8/8 [00:03<00:00,  2.23 batch/s, loss=0.8468]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.42 batch/s, loss=5.8110]
Epoch 10/10: 100%|██████████| 8/8 [00:03<00:00,  2.19 batch/s, loss=0.6313]
Validating: 100%|██████████| 20/20 [00:03<00:00,  5.38 batch/s, loss=6.0760]
Testing: 100%|██████████| 20/20 [00:03<00:00,  5.40 batch/s]


Running run r50,ep:10,bs:512,clip:5.0,ID:5/2

Epoch 1/10: 100%|██████████| 16/16 [00:07<00:00,  2.14 batch/s, loss=5.0323]
Validating:  80%|████████  | 16/20 [00:03<00:00,  4.91 batch/s, loss=4.7304]


KeyboardInterrupt: 