In [116]:
# Reload imported modules automatically before each cell executes, so edits to
# project-local modules (e.g. the `knn` module imported further down) take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [117]:
!pip install -q lightly

[0m

In [118]:
# Experiment configuration.
DATASET = 'CIFAR10' # or 'CIFAR100'
N_EPOCHS = 20  # number of passes over the training data in the training loop

# Fail fast on a typo: the cells below branch on DATASET with plain `if`
# checks, so an unsupported value would otherwise only surface later as a
# confusing NameError.
if DATASET not in ('CIFAR10', 'CIFAR100'):
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected 'CIFAR10' or 'CIFAR100'"
    )

In [119]:
import os
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import copy

import matplotlib.pyplot as plt
from PIL import Image

from tqdm import tqdm

import torch
import torchvision

from torch import nn
from torchvision import transforms

from lightly.data import LightlyDataset
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.byol_transform import (
    BYOLTransform,
    BYOLView1Transform,
    BYOLView2Transform,
)


from sklearn.preprocessing import normalize

In [120]:
# KNN evaluator and seeding helper come from the project-local `knn` module.
from knn import KNN, reproducibility

# Presumably seeds the random number generators with 42 so runs are
# repeatable — verify against knn.py.
reproducibility(42)

In [121]:
class BarlowTwins(nn.Module):
    """Backbone feature extractor followed by a Barlow Twins projection head.

    Args:
        backbone: Module producing the features. Its output is flattened from
            dimension 1 onwards and must then have ``input_dim`` features.
        input_dim: Number of features fed into the projection head. Defaults
            to 512, the value that was previously hard-coded (the notebook
            pairs this class with a ResNet-18 trunk).
        hidden_dim: Hidden width of the projection head. Defaults to 2048.
        output_dim: Size of the projected embedding. Defaults to 2048.
    """

    def __init__(self, backbone, input_dim=512, hidden_dim=2048, output_dim=2048):
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(
            input_dim, hidden_dim, output_dim
        )

    def forward(self, x):
        """Return the projection-head output for the input batch ``x``."""
        # Collapse everything after the batch dimension so the projection head
        # receives one flat feature vector per sample.
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

In [122]:
def get_input_stats(dataset):
    """Return the per-channel ``(mean, std)`` of the given dataset's images.

    Args:
        dataset: Dataset name, either ``'CIFAR10'`` or ``'CIFAR100'``.

    Returns:
        A ``(data_mean, data_std)`` pair; each element is a 3-tuple of floats
        with one value per RGB channel.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Keyed by the function argument (the original ignored it and read the
    # global DATASET). The two entries were also swapped in the original:
    # (0.4914, 0.4822, 0.4465) are the CIFAR-10 channel means and
    # (0.5071, 0.4865, 0.4409) the CIFAR-100 ones.
    stats = {
        'CIFAR10': ((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
        'CIFAR100': ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    }
    if dataset not in stats:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {sorted(stats)}"
        )

    data_mean, data_std = stats[dataset]
    return data_mean, data_std

In [123]:
# Build the encoder: a ResNet-18 created without pretrained weights, with its
# last child module dropped so it outputs features instead of class logits.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BarlowTwins(backbone)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Assigning the result (instead of a bare `model.to(device)` followed by a
# dummy `print()`) keeps the cell from echoing the module repr.
model = model.to(device)




In [124]:
# Load (downloading if necessary) the raw images of the configured dataset.
if DATASET == 'CIFAR10':
    cifar = torchvision.datasets.CIFAR10("datasets/cifar_10", download=True)
elif DATASET == 'CIFAR100':
    cifar = torchvision.datasets.CIFAR100("datasets/cifar_100", download=True)
else:
    # Without this branch an unsupported name would leave `cifar` undefined.
    raise ValueError(f"Unsupported DATASET: {DATASET!r}")

data_mean, data_std = get_input_stats(DATASET)
normalize_dict = {'mean': data_mean, 'std': data_std}

# Two differently augmented views per image. `normalize_dict` was computed but
# never used in the original, so both views silently fell back to the
# transforms' default normalization; it is now passed explicitly.
# NOTE(review): the evaluation transform defined later applies no
# normalization at all — confirm that train/eval preprocessing is meant to differ.
transform = BYOLTransform(
    view_1_transform=BYOLView1Transform(
        input_size=32, gaussian_blur=0.0, normalize=normalize_dict
    ),
    view_2_transform=BYOLView2Transform(
        input_size=32, gaussian_blur=0.0, normalize=normalize_dict
    ),
)
cifar = LightlyDataset.from_torch_dataset(cifar, transform=transform)

# Loader used by the self-supervised training loop.
dataloader = torch.utils.data.DataLoader(
    cifar,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True
)

Files already downloaded and verified


In [125]:
# Learning rate for Adam, named so it can be tuned in one place.
LEARNING_RATE = 3e-4

# Barlow Twins objective and the optimizer over all model parameters
# (backbone and projection head). No LR scheduler or AMP grad scaler is used.
criterion = BarlowTwinsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [129]:
# NOTE(review): these imports sit mid-notebook and overlap with names already
# available from the top import cell (`transforms`, `torch.utils.data`);
# consider moving them there.
from torchvision import transforms as T
from torch.utils.data import DataLoader

# Batch size of the plain loaders built below for the KNN evaluation.
BATCH_SIZE = 32

# Evaluation-time preprocessing: tensor conversion only, no augmentation.
test_transform = T.Compose([
            T.ToTensor(),
            # Normalization is currently disabled for evaluation.
            # NOTE(review): the training views are normalized inside the BYOL
            # transforms, whereas evaluation inputs are not — confirm this
            # mismatch is intended.
            # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def get_dataloader(batch_size, train=True, transform=test_transform, num_workers=4, drop_last=True):
    """Build a non-shuffled DataLoader over the configured CIFAR dataset.

    The dataset is selected by the notebook-level ``DATASET`` constant.

    Args:
        batch_size: Number of samples per batch.
        train: If True use the training split, otherwise the test split.
        transform: Transform applied to every image.
        num_workers: Number of loader worker processes. Defaults to 4, the
            value that was previously hard-coded.
        drop_last: Whether to drop the final incomplete batch. Defaults to
            True to preserve the previous behavior.
            NOTE(review): with True, samples that do not fill a whole batch
            are silently excluded from evaluation — consider passing False if
            the evaluator can handle a smaller last batch.

    Returns:
        A ``DataLoader`` over the selected split.

    Raises:
        ValueError: If ``DATASET`` is neither ``'CIFAR10'`` nor ``'CIFAR100'``.
    """
    if DATASET == 'CIFAR10':
        cifar = torchvision.datasets.CIFAR10(
            "datasets/cifar_10", train=train, transform=transform, download=True
        )
    elif DATASET == 'CIFAR100':
        cifar = torchvision.datasets.CIFAR100(
            "datasets/cifar_100", train=train, transform=transform, download=True
        )
    else:
        # The original fell through here and failed with UnboundLocalError.
        raise ValueError(f"Unsupported DATASET: {DATASET!r}")
    return DataLoader(
        dataset=cifar,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=drop_last,
    )


# Plain (non-augmented) loaders for the KNN evaluation: the training split
# first, then the test split.
loader_train_plain, loader_test = (
    get_dataloader(batch_size=BATCH_SIZE, train=use_train_split, transform=test_transform)
    for use_train_split in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [130]:
# Reference point: 1-nearest-neighbour accuracy of the current encoder.
# Use the `device` selected when the model was built rather than a hard-coded
# 'cuda', so this cell also runs on CPU-only machines.
ssl_evaluator = KNN(model=model, k=1, device=device)
train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
print(train_acc, val_acc)

29.335365635701116 25.0200049573358


In [128]:
# Run the KNN evaluation (and checkpointing) every EVAL_EVERY epochs.
EVAL_EVERY = 10

best_val_acc = 0

for epoch in tqdm(range(N_EPOCHS)):
    # Re-enter training mode every epoch: the KNN evaluation may have switched
    # the model to eval mode (TODO confirm in knn.py), which would otherwise
    # silently persist for the following epochs.
    model.train()

    total_loss = 0
    for batch in dataloader:
        # batch[0] holds the two augmented views of each image.
        x0, x1 = batch[0]
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        # Detach so the running sum does not keep the autograd graph alive.
        total_loss += loss.detach()
        # Clear gradients before the backward pass (not after the step), so
        # gradients left over from an interrupted earlier run of this cell
        # cannot leak into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(dataloader)

    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

    if epoch % EVAL_EVERY == 0:
        # Use the selected `device` instead of a hard-coded 'cuda'.
        ssl_evaluator = KNN(model=model, k=1, device=device)
        train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
        if val_acc > best_val_acc:
            # Keep the best model so far. The file name carries the configured
            # dataset (it was hard-coded to 'c100' regardless of DATASET).
            torch.save(model.state_dict(), f'BT_{DATASET}_acc{val_acc:.2f}.pt')
            best_val_acc = val_acc
        print(f'Train Accuracy:{train_acc:.1f}%',f' Val Accuracy:{val_acc:.1f}%')


  0%|          | 0/20 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
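# Optional: restore a previously saved full-model checkpoint instead of training.
# Replace the file name with the checkpoint you want to load, then uncomment:
# model.load_state_dict(torch.load('BT_c100_acc18.69.pt'))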

In [None]:
# Final 1-nearest-neighbour evaluation of the current model.
# Use the `device` selected when the model was built rather than a hard-coded
# 'cuda', so this cell also runs on CPU-only machines.
ssl_evaluator = KNN(model=model, k=1, device=device)
train_acc, val_acc = ssl_evaluator.fit(loader_train_plain, loader_test)
print(train_acc, val_acc)

In [None]:
# Persist only the backbone weights (the projection head is not saved). The
# file name records the dataset, the most recent validation accuracy and the
# configured number of epochs.
backbone = model.backbone
backbone_checkpoint_path = f'BT_{DATASET}_acc{val_acc:.2f}_epoch{N_EPOCHS}.pt'
torch.save(backbone.state_dict(), backbone_checkpoint_path)