In [1]:
import kagglehub

# Fetch the most recent published version of the dataset through kagglehub
# and keep the returned location in `path` for the cells below.
path = kagglehub.dataset_download(
    "ayaroshevskiy/downsampled-imagenet-64x64"
)
print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: /Users/igor.varha/.cache/kagglehub/datasets/ayaroshevskiy/downsampled-imagenet-64x64/versions/1


In [2]:
from pathlib import Path

# Each split lives in a directory nested inside a folder of the same name
# under the downloaded dataset root.
train_path = Path(path).joinpath('train_64x64', 'train_64x64')
test_path = Path(path).joinpath('valid_64x64', 'valid_64x64')

# File names under which the pickled train / test datasets are cached
# (relative to the current working directory).
train_dataset_p = "tr_dataset.pt"
test_dataset_p = "te_dataset.pt"

In [3]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from image_toolkit.clustering import evaluate_clustering_on_validation_p
import pickle
from torch.utils.data import DataLoader
from image_toolkit.data_processor import FragmentDataset
import torchvision.transforms as T
import torch
import random
import torchvision.transforms.functional as TF


In [4]:

# Reuse the pickled datasets when BOTH cache files are present; otherwise
# rebuild both from the image folders and write fresh cache files.
#
# Every file is opened through a `with` block so the handle is closed (and,
# for writes, flushed to disk) as soon as the pickle call returns, instead
# of being left to the garbage collector.
if Path(train_dataset_p).exists() and Path(test_dataset_p).exists():
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load cache files that this notebook produced itself.
    with open(train_dataset_p, "rb") as cache_file:
        train_dataset = pickle.load(cache_file)
    with open(test_dataset_p, "rb") as cache_file:
        test_dataset = pickle.load(cache_file)
else:
    # The test split is built and cached first, then the train split.
    test_dataset = FragmentDataset(test_path, limit=1000)
    with open(test_dataset_p, "wb") as cache_file:
        pickle.dump(test_dataset, cache_file)
    train_dataset = FragmentDataset(train_path, limit=100000)
    with open(train_dataset_p, "wb") as cache_file:
        pickle.dump(train_dataset, cache_file)


In [5]:
#add augmentations

class Random90Rotation:
    """Callable transform that rotates its input by 90, 180 or 270 degrees.

    One of the three angles is drawn with `random.choice` on every call and
    the input is passed to `TF.rotate` with that angle. A 0-degree (identity)
    rotation is never selected.
    """

    # Candidate rotation angles, in degrees.
    ANGLES = (90, 180, 270)

    def __call__(self, x):
        chosen_angle = random.choice(self.ANGLES)
        rotated = TF.rotate(x, chosen_angle)
        return rotated

class RandomPatchAugment:
    """Apply a fixed augmentation pipeline to a patch with probability `prob`.

    The pipeline is, in order: a random 90/180/270 degree rotation,
    `ColorJitter(0.1, 0.1, 0.1)` and a `GaussianBlur` with a 3x3 kernel.

    Args:
        prob: Threshold compared against `random.random()`; the pipeline runs
            when the drawn number is strictly below it.
    """

    def __init__(self, prob=0.5):
        self.prob = prob
        pipeline_steps = [
            Random90Rotation(),
            T.ColorJitter(0.1, 0.1, 0.1),
            T.GaussianBlur(kernel_size=3),
        ]
        self.augment = T.Compose(pipeline_steps)

    def __call__(self, patch):
        # One random draw per call decides between augmenting and returning
        # the patch untouched.
        return self.augment(patch) if random.random() < self.prob else patch
# A single augmenter instance (applied with probability 0.6) is shared by
# both datasets.
augmenter = RandomPatchAugment(prob=0.6)

# NOTE(review): augmentation is switched on for the test dataset as well, so
# evaluation inputs are randomly perturbed and metrics can vary between runs.
# Confirm this is intended.
for _dataset in (train_dataset, test_dataset):
    _dataset.augment = True
    _dataset.augmenter = augmenter

# Evaluation keeps dataset order; training reshuffles.
dataloader_test = DataLoader(
    test_dataset,
    batch_size=10,
    shuffle=False,
)
dataloader_train = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,
)

In [6]:
# Choose the compute device: Apple MPS if available, otherwise CUDA if
# available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)

In [7]:
from image_toolkit.nets import TransformerPatchCluster

# Build the patch-clustering transformer and move it onto the chosen device.
# Original author's note on this configuration: 0.72 (the metric it refers
# to is not stated here).
model = TransformerPatchCluster(
    embed_dim=256,
    nhead=8,
    num_layers=7,
)
model = model.to(DEVICE)

# Load the best model: restore previously trained weights from the checkpoint.
model.load_weights(
    "best_TTC_256_8_8_ARI90(100K)/best_model_epoch_78.pth",
    device=DEVICE,
)

Weights loaded from best_TTC_256_8_8_ARI90(100K)/best_model_epoch_78.pth


  self.load_state_dict(torch.load(path, map_location=device))


In [8]:
# Baseline: score the freshly loaded checkpoint on the default test loader
# before any further training takes place.
ari, nmi, sil = evaluate_clustering_on_validation_p(
    dataloader_test,
    model,
    device=DEVICE,
)
print(f"ARI : {ari}, NMI: {nmi}, Silhouette: {sil}")

ARI : 0.6627079427210495, NMI: 0.790216304471584, Silhouette: 0.6842788457870483


In [9]:
# Retrain (fine-tune) the loaded model on the training loader, validating
# against the test loader.
LR = 9e-5
EPOCHS = 20

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Cosine annealing of the learning rate with T_max=EPOCHS and a floor of 1e-6.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

# The optional `top_k=5` argument is intentionally left disabled.
val_losses = model.train_model(
    dataloader_train,
    dataloader_test,
    optimizer,
    lr_scheduler,
    epochs=EPOCHS,
    device=DEVICE,
    temperature=0.33,
)



Epoch 1/20: 100%|██████████| 10000/10000 [15:44<00:00, 10.59it/s]


Epoch [1/20], Loss: 3.3707
Epoch [1/20], ARI: 0.8058
Model saved at epoch 1 with ARI: 0.8058
Current learning rate: [8.945213115648363e-05]


Epoch 2/20: 100%|██████████| 10000/10000 [16:41<00:00,  9.98it/s]


Epoch [2/20], Loss: 3.3292
Epoch [2/20], ARI: 0.8306
Model saved at epoch 2 with ARI: 0.8306
Current learning rate: [8.782201497513435e-05]


Epoch 3/20: 100%|██████████| 10000/10000 [16:39<00:00, 10.01it/s]


Epoch [3/20], Loss: 3.3139
Epoch [3/20], ARI: 0.8292
Current learning rate: [8.514979032638238e-05]


Epoch 4/20: 100%|██████████| 10000/10000 [18:11<00:00,  9.16it/s]


Epoch [4/20], Loss: 3.3010
Epoch [4/20], ARI: 0.8259
Current learning rate: [8.150125624968517e-05]


Epoch 5/20: 100%|██████████| 10000/10000 [17:41<00:00,  9.42it/s]


Epoch [5/20], Loss: 3.2908
Epoch [5/20], ARI: 0.8387
Model saved at epoch 5 with ARI: 0.8387
Current learning rate: [7.696625176280137e-05]


Epoch 6/20:  12%|█▏        | 1226/10000 [02:29<17:53,  8.18it/s]


KeyboardInterrupt: 

In [10]:
# Re-score the model on the default test loader after the retraining cell.
# NOTE(review): this evaluates the weights currently held in memory. If
# training was interrupted part-way through an epoch, they may differ from
# the last checkpoint that was saved — confirm which one should be reported.
ari, nmi, sil = evaluate_clustering_on_validation_p(
    dataloader_test,
    model,
    device=DEVICE,
)
print(f"ARI : {ari}, NMI: {nmi}, Silhouette: {sil}")

ARI : 0.849904820242774, NMI: 0.9090903991679766, Silhouette: 0.7769865393638611
