In [1]:
import kagglehub

# Download the latest version of the downsampled 64x64 ImageNet dataset through
# kagglehub and keep the returned location in `path`; later cells build the
# train/validation directories from it.
# NOTE(review): presumably kagglehub reuses its local cache when the dataset is
# already present (the printed path points into a kagglehub cache dir) — confirm.
path = kagglehub.dataset_download("ayaroshevskiy/downsampled-imagenet-64x64")
print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: /Users/igor.varha/.cache/kagglehub/datasets/ayaroshevskiy/downsampled-imagenet-64x64/versions/1


In [2]:
from pathlib import Path

# Image directories inside the downloaded dataset (each split is nested in a
# folder of the same name).
train_path = Path(path).joinpath('train_64x64', 'train_64x64')
test_path = Path(path).joinpath('valid_64x64/valid_64x64')

# File names (relative to the working directory) used to cache the built datasets.
train_dataset_p = "tr_dataset.pt"
test_dataset_p = "te_dataset.pt"

In [9]:
from torch.optim.lr_scheduler import CosineAnnealingLR
from image_toolkit.data_processor import FragmentDataset
from image_toolkit.clustering import evaluate_clustering_on_validation_p
import pickle
from torch_geometric.data import DataLoader



In [3]:

# Load the cached datasets when BOTH cache files exist; otherwise build them
# from the image directories and write the caches for the next run.
#
# Files are opened through `with` so every handle is closed (and, for writes,
# flushed) deterministically instead of being left to the garbage collector.
#
# SECURITY NOTE: pickle.load can execute arbitrary code from the file. Only
# load cache files that this notebook itself produced; delete them and rebuild
# if their origin is in doubt.
if Path(train_dataset_p).exists() and Path(test_dataset_p).exists():
    with open(train_dataset_p, "rb") as cache_file:
        train_dataset = pickle.load(cache_file)
    with open(test_dataset_p, "rb") as cache_file:
        test_dataset = pickle.load(cache_file)
else:
    # The smaller validation set is built and cached first.
    test_dataset = FragmentDataset(test_path, limit=1000)
    with open(test_dataset_p, "wb") as cache_file:
        pickle.dump(test_dataset, cache_file)

    train_dataset = FragmentDataset(train_path, limit=100000)
    with open(train_dataset_p, "wb") as cache_file:
        pickle.dump(train_dataset, cache_file)

In [11]:

# Patch configuration shared by both datasets.
# NOTE(review): presumably 8-pixel patches tile a 64x64 image into 64 fragments
# — confirm against FragmentDataset.
patch_size = 8
n_patches = 64

# Apply the configuration to the validation dataset first, then to the training
# dataset (fragments_per_image is assigned before patch_size in both cases).
test_dataset.fragments_per_image, test_dataset.patch_size = n_patches, patch_size
train_dataset.fragments_per_image, train_dataset.patch_size = n_patches, patch_size

# Validation batches keep a fixed order; training batches are shuffled.
dataloader_test = DataLoader(
    test_dataset,
    batch_size=10,
    shuffle=False,
)
dataloader_train = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,
)



In [5]:
import torch

# Select the compute device by priority: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


In [6]:
from image_toolkit.nets import TransformerPatchCluster

# Build the patch-clustering transformer, then move it onto the selected device.
# NOTE(review): the original inline remark "0.72" presumably records a score
# obtained with this configuration — confirm.
model = TransformerPatchCluster(
    embed_dim=256,
    nhead=8,
    device=DEVICE,
    num_layers=7,
)
model = model.to(DEVICE)

# Start from the best previously saved checkpoint.
model.load_weights("best_TTC_256_8_8_ARI90(100K)/best_model_epoch_78.pth")

Weights loaded from best_TTC_256_8_8_ARI90(100K)/best_model_epoch_78.pth


  self.load_state_dict(torch.load(path, map_location=self.device))


In [7]:
# test on default test set
# Evaluates the freshly loaded checkpoint on dataloader_test and prints the three
# returned scores, labelled ARI, NMI and Silhouette.
# NOTE(review): this cell's recorded execution count (7) is lower than that of
# the cell that sets patch_size / fragments_per_image and builds the dataloaders
# (11), so the recorded output was presumably produced with a differently
# configured dataloader_test — confirm with Restart Kernel -> Run All.
ari,nmi,sil = evaluate_clustering_on_validation_p(dataloader_test,model,device=DEVICE)
print(f"ARI : {ari}, NMI: {nmi}, Silhouette: {sil}")

ARI : 0.13028539081906534, NMI: 0.23507703530713672, Silhouette: 0.2796148955821991


In [12]:
# Retrain (fine-tune) the loaded model.
LR = 9e-5
EPOCHS = 20

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Cosine annealing from LR down to 1e-6 with a period of EPOCHS scheduler steps.
# NOTE(review): how often train_model steps the scheduler is not visible here;
# T_max=EPOCHS assumes one step per epoch — confirm.
lr_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=1e-6,
)

# The validation loader is passed alongside the training loader; `top_k` is
# deliberately not passed (an earlier variant tried top_k=5).
val_losses = model.train_model(
    dataloader_train,
    dataloader_test,
    optimizer,
    lr_scheduler,
    epochs=EPOCHS,
    device=DEVICE,
    temperature=0.33,
)

Epoch 1/20: 100%|██████████| 10000/10000 [14:31<00:00, 11.47it/s]


Epoch [1/20], Loss: 5.2211
Epoch [1/20], ARI: 0.7070




Model saved at epoch 1 with ARI: 0.7070
Current learning rate: [8.972587124713445e-05]


Epoch 2/20: 100%|██████████| 10000/10000 [13:14<00:00, 12.59it/s]


Epoch [2/20], Loss: 4.9236




Epoch [2/20], ARI: 0.7550
Model saved at epoch 2 with ARI: 0.7550
Current learning rate: [8.968742855287973e-05]


Epoch 3/20: 100%|██████████| 10000/10000 [12:59<00:00, 12.83it/s]


Epoch [3/20], Loss: 4.8593




Epoch [3/20], ARI: 0.7833
Model saved at epoch 3 with ARI: 0.7833
Current learning rate: [8.966362025669304e-05]


Epoch 4/20: 100%|██████████| 10000/10000 [14:36<00:00, 11.40it/s]


Epoch [4/20], Loss: 4.8254




Epoch [4/20], ARI: 0.7961
Model saved at epoch 4 with ARI: 0.7961
Current learning rate: [8.965254899275425e-05]


Epoch 5/20: 100%|██████████| 10000/10000 [15:40<00:00, 10.63it/s]


Epoch [5/20], Loss: 4.7975




Epoch [5/20], ARI: 0.8048
Model saved at epoch 5 with ARI: 0.8048
Current learning rate: [8.964490663126493e-05]


Epoch 6/20: 100%|██████████| 10000/10000 [15:32<00:00, 10.72it/s]


Epoch [6/20], Loss: 4.7759




Epoch [6/20], ARI: 0.8130
Model saved at epoch 6 with ARI: 0.8130
Current learning rate: [8.963765708323359e-05]


Epoch 7/20: 100%|██████████| 10000/10000 [21:39<00:00,  7.70it/s]


Epoch [7/20], Loss: 4.7599
Epoch [7/20], ARI: 0.8215




Model saved at epoch 7 with ARI: 0.8215
Current learning rate: [8.963000780280249e-05]


Epoch 8/20:  23%|██▎       | 2251/10000 [05:18<18:15,  7.08it/s]


KeyboardInterrupt: 

Testing: re-evaluate the retrained model on the validation set

In [13]:
# Re-run the clustering evaluation on dataloader_test after retraining and print
# the three returned scores, labelled ARI, NMI and Silhouette.
# NOTE(review): the recorded training run was stopped by a KeyboardInterrupt
# during epoch 8, so `model` presumably holds the in-memory weights at the
# moment of interruption rather than the checkpoint saved after epoch 7 —
# confirm, and reload the best saved checkpoint before reporting final numbers.
ari,nmi,sil = evaluate_clustering_on_validation_p(dataloader_test,model,device=DEVICE)
print(f"ARI : {ari}, NMI: {nmi}, Silhouette: {sil}")

ARI : 0.8198463886422354, NMI: 0.8708453090299507, Silhouette: 0.7423139214515686
