#Necessary imports

In [13]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
import warnings
warnings.filterwarnings("ignore")

In [15]:
!pip install -q lightly
!pip install -U openmim && mim install "mmpretrain>=1.0.0rc8"

[0mLooking in links: https://download.openmmlab.com/mmcv/dist/cu116/torch1.12.0/index.html
[0m

In [16]:
DATASET = 'CIFAR100' # or 'CIFAR100'
N_EPOCHS = 20

In [17]:
import numpy as np
import pandas as pd
import copy
import torch
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from torch import nn
import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm

import torch
import torchvision

from torch import nn
from torchvision import transforms

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform

from lightly.data import LightlyDataset
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda', index=0)

In [18]:
from knn import KNN, reproducibility, get_input_stats
reproducibility(1000)

#Define model

In [19]:
from mmpretrain import get_model

model = get_model('vit-small-p14_dinov2-pre_3rdparty', pretrained=False)
model = model.to(DEVICE)

#Data (cifar10 or cifar100)

In [20]:
from lightly.transforms.dino_transform import DINOTransform

In [21]:
transform = SimCLRTransform(input_size=32, gaussian_blur=0.0)
if DATASET == 'CIFAR10':
    cifar = torchvision.datasets.CIFAR10("datasets/cifar_10", download=True)
if DATASET == 'CIFAR100':
    cifar = torchvision.datasets.CIFAR100("datasets/cifar_100", download=True)
    
data_mean, data_std = get_input_stats(DATASET)
normalize_dict = {'mean': data_mean, 'std': data_std}

dataset = LightlyDataset.from_torch_dataset(cifar, transform=transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to datasets/cifar_100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting datasets/cifar_100/cifar-100-python.tar.gz to datasets/cifar_100


#Optimizer and Scheduler

In [22]:
criterion = NTXentLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.999)
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr=0.001, max_lr=0.01)
n_epochs, max_norm = 20, 0.1

#KNN and dataloaders for evaluation

In [23]:
from torchvision import transforms as T
from torch.utils.data import DataLoader

BATCH_SIZE = 32

test_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])

def get_dataloader(batch_size, train=True, transform=test_transform):
    if DATASET == 'CIFAR10':
        cifar = torchvision.datasets.CIFAR10("datasets/cifar_10",train=train, transform=transform,download=True)
    if DATASET == 'CIFAR100':
        cifar = torchvision.datasets.CIFAR100("datasets/cifar_100",train=train,transform=transform, download=True)
    return DataLoader(dataset=cifar, batch_size=batch_size, num_workers=4, drop_last=True)


loader_train_plain = get_dataloader(batch_size=BATCH_SIZE, train=True, transform=test_transform)
loader_test = get_dataloader(batch_size=BATCH_SIZE, train=False, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


#Make directory for saving models & Training

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [25]:
ssl_evaluator = KNN(model=model, k=1, device='cuda', transformer=True)
train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
val_acc

Evaluate on train data...
Evaluate on test data...


7.6322901185445495

In [26]:
best_val_acc = 0

print("Starting Training")
for epoch in tqdm(range(N_EPOCHS)):
    total_loss = 0
    for batch in dataloader:
        x0, x1 = batch[0]
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model.extract_feat(x0)[0]
        z1 = model.extract_feat(x1)[0]
        loss = criterion(z0, z1)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()
    avg_loss = total_loss / len(dataloader)
    print(f" epoch: {epoch:>02}, loss: {avg_loss:.5f}")
    # if epoch % 10 == 0:
    #     ssl_evaluator = KNN(model=model, k=1, device='cuda', transformer=True)
    #     train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
    #     if val_acc > best_val_acc:
    #         torch.save(model.state_dict(),f'dino_{DATASET}_acc{val_acc:.2f}.pt')
    #         best_val_acc = val_acc
    #     print(f' Train Accuracy:{train_acc:.1f}%',f' Val Accuracy:{val_acc:.1f}%')

Starting Training


  5%|▌         | 1/20 [00:28<08:59, 28.40s/it]

 epoch: 00, loss: 5.19210


 10%|█         | 2/20 [00:57<08:39, 28.88s/it]

 epoch: 01, loss: 4.97812


 15%|█▌        | 3/20 [01:26<08:12, 28.95s/it]

 epoch: 02, loss: 4.90572


 20%|██        | 4/20 [01:56<07:46, 29.14s/it]

 epoch: 03, loss: 4.85625


 25%|██▌       | 5/20 [02:25<07:19, 29.27s/it]

 epoch: 04, loss: 4.82940


 30%|███       | 6/20 [02:54<06:48, 29.21s/it]

 epoch: 05, loss: 4.79726


 35%|███▌      | 7/20 [03:24<06:20, 29.27s/it]

 epoch: 06, loss: 4.77958


 40%|████      | 8/20 [03:53<05:50, 29.24s/it]

 epoch: 07, loss: 4.76120


 45%|████▌     | 9/20 [04:21<05:19, 29.07s/it]

 epoch: 08, loss: 4.73555


 50%|█████     | 10/20 [04:50<04:50, 29.06s/it]

 epoch: 09, loss: 4.72659


 55%|█████▌    | 11/20 [05:19<04:20, 28.99s/it]

 epoch: 10, loss: 4.71542


 60%|██████    | 12/20 [05:49<03:52, 29.07s/it]

 epoch: 11, loss: 4.70174


 65%|██████▌   | 13/20 [06:18<03:23, 29.09s/it]

 epoch: 12, loss: 4.69662


 70%|███████   | 14/20 [06:47<02:54, 29.03s/it]

 epoch: 13, loss: 4.68403


 75%|███████▌  | 15/20 [07:15<02:24, 28.93s/it]

 epoch: 14, loss: 4.67510


 80%|████████  | 16/20 [07:45<01:56, 29.06s/it]

 epoch: 15, loss: 4.66587


 85%|████████▌ | 17/20 [08:14<01:27, 29.00s/it]

 epoch: 16, loss: 4.65623


 90%|█████████ | 18/20 [08:43<00:58, 29.12s/it]

 epoch: 17, loss: 4.64905


 95%|█████████▌| 19/20 [09:12<00:29, 29.01s/it]

 epoch: 18, loss: 4.63907


100%|██████████| 20/20 [09:40<00:00, 29.04s/it]

 epoch: 19, loss: 4.62879





In [27]:
ssl_evaluator = KNN(model=model, k=1, device='cuda', transformer=True)
train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
print(train_acc, val_acc)

Evaluate on train data...
Evaluate on test data...
13.274241175914197 9.785730700258805


#Visual evaluation

#Save backbone model

In [28]:
backbone = model.backbone
torch.save(backbone.state_dict(), f'dino_{DATASET}_acc{val_acc:.2f}_epoch{N_EPOCHS}.pt')