# TODO
- Make full use of the PyTorch Lightning project template.

In [1]:
## Automatically reload edited project modules (the `src` package imported below)
## before each cell executes, so changes in ../src are picked up without a restart.
%reload_ext autoreload
%autoreload 2

In [2]:
## Regular imports
import sys
sys.path.append('..')
import matplotlib.pyplot as plt
%matplotlib inline

import albumentations as A
from albumentations.pytorch import ToTensorV2

## Lightning imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

## PyTorch imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
print(f'Cuda available: {torch.cuda.is_available()}')

import wandb
wandb.login()

## Internal imports
from src.callbacks import ImagePredictionLogger
from src.dataset import ComposeMany, ComposeManyTorch, MNISTDataModule2
from src.losses import log_softmax, simsiam_loss
from src.models import accuracy, feature_std, CNN, BaseLitModel, SimSiam
from src.utils import sweep_iteration

Cuda available: True


[34m[1mwandb[0m: Currently logged in as: [33malkalait[0m (use `wandb login --relogin` to force relogin)


---

In [None]:
## Augmentation pipeline for the supervised baseline.
## ComposeMany presumably applies the pipeline n_aug times per sample (TODO confirm in src/dataset.py).
AUG_KWARGS = {
    'border_mode': A.cv2.BORDER_CONSTANT,  # pad exposed borders with a constant...
    'value': 0,                            # ...black value
    'interpolation': A.cv2.INTER_LANCZOS4,
}

train_transforms = ComposeMany(
    [
        # Disabled alternatives kept for experimentation:
        # A.RandomCrop(width=24, height=24),
        # A.HorizontalFlip(p=0.5),
        # A.GridDistortion(p=0.5, distort_limit=.3, **AUG_KWARGS),
        A.ElasticTransform(alpha=3, sigma=1, alpha_affine=0, p=0.5, **AUG_KWARGS),
        A.ElasticTransform(alpha=1, sigma=1, alpha_affine=3, p=0.5, **AUG_KWARGS),
        A.ShiftScaleRotate(scale_limit=.2, rotate_limit=0, p=1.0, **AUG_KWARGS),  # no rotation
        A.ShiftScaleRotate(scale_limit=0, rotate_limit=25, p=1.0, **AUG_KWARGS),  # no scaling
        # A.CoarseDropout(p=1.0, max_holes=8, max_height=4, max_width=4,
        #                 min_holes=1, min_height=4, min_width=4),
        # A.RandomBrightnessContrast(p=0.2),
        # A.Blur(blur_limit=4),
        # mean 0 / std 1: only rescales by max_pixel_value (albumentations default 255).
        A.Normalize(mean=(0.0,), std=(1,)),
        ToTensorV2(),
    ],
    n_aug=2,
)

In [None]:
## Supervised baseline. The Lightning datamodule ships its own train / val / test dataloaders.
mnist = MNISTDataModule2(
    data_dir='../data/',
    batch_size=512,
    train_transforms=train_transforms,
)
mnist.prepare_data()
mnist.setup()

## Backbone: CNN sized from the datamodule (input channels, number of classes).
cnn = CNN(
    num_channels=mnist.dims[0],
    num_classes=mnist.num_classes,
    maxpool=False,
    wpool=7,
)

## Metrics are (log-entry name, metric function) pairs.
metrics = (
    ('acc', accuracy),
)

model = BaseLitModel(
    backbone=cnn,
    datamodule=mnist,
    loss_func=log_softmax,
    metrics=metrics,
    flood_height=0.03,
    lr=1e-3,
)

In [None]:
## Weights & Biases logger plus LR, checkpoint and sample-prediction callbacks.
wandb_logger = WandbLogger(job_type='train', project='SimSiam-Lightning')
callbacks = [
    LearningRateMonitor(),                # log the learning rate
    ModelCheckpoint(monitor='val_loss'),  # checkpoints ranked by val_loss
    ImagePredictionLogger(mnist.val_dataloader(batch_size=32), n_samples=32),
]

trainer = Trainer(
    gpus=-1,  # all visible GPUs
    max_epochs=200,
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=1,
    gradient_clip_val=0,  # 0 = no clipping (try 0.5 to clip)
    weights_summary='full',
    progress_bar_refresh_rate=20,
    # fast_dev_run=True,
)

In [None]:
## Train; the datamodule supplies the train / val dataloaders.
trainer.fit(model, datamodule=mnist)

In [None]:
## Evaluate on the test split. With no arguments Lightning falls back to the model and
## datamodule of the preceding `fit` call; which checkpoint is used depends on the
## installed Lightning version's `ckpt_path` default.
trainer.test()

In [None]:
## Close the current W&B run so the next experiment logs to a fresh run.
wandb.finish()

---

# SimSiam

In [9]:
## SimSiam augmentations: random affine jitter (scale, then rotation up to 10 degrees).
## ComposeMany presumably turns each sample into n_aug=2 views (TODO confirm in src/dataset.py).
AUG_KWARGS = {
    'border_mode': A.cv2.BORDER_CONSTANT,
    'value': 0,
    'interpolation': A.cv2.INTER_LANCZOS4,
}
transforms = ComposeMany(
    [
        # A.ElasticTransform(p=0.5, sigma=1, alpha=3, alpha_affine=0, **AUG_KWARGS),
        # A.ElasticTransform(p=0.5, sigma=1, alpha=1, alpha_affine=3, **AUG_KWARGS),
        A.ShiftScaleRotate(scale_limit=.2, rotate_limit=0, p=1.0, **AUG_KWARGS),  # no rotation
        A.ShiftScaleRotate(scale_limit=0, rotate_limit=10, p=1.0, **AUG_KWARGS),  # no scaling
        # mean 0 / std 1: only rescales by max_pixel_value (albumentations default 255).
        A.Normalize(mean=(0.0,), std=(1,)),
        ToTensorV2(),
    ],
    n_aug=2,
)

## Disabled torchvision-based alternative:
# transforms = ComposeManyTorch([
#     T.RandomResizedCrop(28, scale=(0.6, 1.0)),
#     #T.RandomHorizontalFlip(),
#     T.RandomApply([T.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8),
#     #T.RandomGrayscale(p=0.2),
#     T.RandomApply([T.GaussianBlur(kernel_size=28//20*2+1, sigma=(0.1, 2.0))], p=0.5),
#     T.ToTensor(),
#     T.Normalize([0.1307], [0.3081]),
# ], n_aug=2)

In [31]:
## SimSiam datamodule: the same multi-view augmentation is used for train and val.
mnist = MNISTDataModule2(
    data_dir='../data/',
    batch_size=128,
    train_transforms=transforms,
    val_transforms=transforms,
)
mnist.prepare_data()
mnist.setup()

## Encoder: dropout-free CNN backbone wrapped by the SimSiam module.
cnn = CNN(num_channels=mnist.dims[0], num_classes=mnist.num_classes,
          wpool=5, maxpool=False, p_dropout=0.0)
simsiam = SimSiam(p_dropout=0.0, backbone=cnn)

## Learning rate follows the linear scaling rule: 0.05 per 256 samples in a batch.
model = BaseLitModel(
    backbone=simsiam,
    datamodule=mnist,
    metrics=(('featstd', feature_std),),
    loss_func=simsiam_loss,
    lr=0.05 * mnist.train_dataloader().batch_size / 256,
    # flood_height=0.03,
)

Logging metrics: ['featstd']


In [18]:
## W&B logger and callbacks for the SimSiam run (image-prediction logger disabled).
wandb_logger = WandbLogger(job_type='train', project='SimSiam-Lightning')
callbacks = [
    LearningRateMonitor(),                # log the learning rate
    ModelCheckpoint(monitor='val_loss'),  # checkpoints ranked by val_loss
    # ImagePredictionLogger(mnist.val_dataloader(batch_size=32), n_samples=32),
]

trainer = Trainer(
    gpus=-1,  # all visible GPUs
    max_epochs=100,
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=1,
    gradient_clip_val=0,  # 0 = no clipping (try 0.5 to clip)
    weights_summary='full',
    progress_bar_refresh_rate=20,
    # fast_dev_run=True,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
## Train SimSiam; the datamodule supplies the train / val dataloaders.
trainer.fit(model, datamodule=mnist)

In [None]:
## Evaluate on the test split. With no arguments Lightning falls back to the model and
## datamodule of the preceding `fit` call; which checkpoint is used depends on the
## installed Lightning version's `ckpt_path` default.
trainer.test()

In [None]:
## Close the current W&B run so the next experiment logs to a fresh run.
wandb.finish()

In [23]:
## Reload trained SimSiam weights. The path is relative to the notebook directory, which is
## assumed to be the working directory (as with '../data/' above) instead of a hard-coded
## absolute home path. The checkpoint appears to come from W&B run '2gow693t' of project
## 'SimSiam-Lightning'.
## NOTE(review): it is an epoch-136 checkpoint, i.e. from a longer run than the
## max_epochs=100 trainer configured above.
CKPT_PATH = 'SimSiam-Lightning/2gow693t/checkpoints/epoch=136-step=58772.ckpt'
model = model.load_from_checkpoint(checkpoint_path=CKPT_PATH)
## Fall back to CPU so the cell also runs on machines without CUDA.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
# model.summarize()

Logging metrics: ['featstd']


---

# k-NN monitor

The k-NN monitor code below is copied from:

- https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=J4YUQeIvuuMd
- http://github.com/zhirongw/lemniscate.pytorch
- https://github.com/leftthomas/SimCLR

In [None]:
## k-NN evaluation of the trained encoder on un-augmented (normalise + to-tensor only) images.
## NOTE: knn_monitor / knn_predict are defined in a later cell of this notebook; run it first.
test_transforms = ComposeMany([A.Normalize(mean=(0.0,), std=(1,)), ToTensorV2()])
memory_dataloader, test_dataloader = (
    mnist.train_dataloader(transforms=test_transforms),
    mnist.test_dataloader(transforms=test_transforms),
)

knn_monitor(model.backbone.f, memory_dataloader, test_dataloader,
            device=model.device, epoch=200, knn_k=10)

In [None]:
## Build the k-NN feature bank by hand (mirrors the first half of knn_monitor).
device = model.device
model.eval()
classes = len(memory_dataloader.dataset.dataset.classes)
bank_chunks, label_chunks = [], []

with torch.no_grad():
    ## Generate feature bank. Labels are taken from the same batches as the features,
    ## so they stay aligned even if the loader shuffles, drops its last batch, or
    ## iterates over a subset of the underlying dataset (dataset.dataset.targets would not).
    for data, target in tqdm(memory_dataloader, desc='Feature extracting'):
        feature = model.backbone.f(data.to(device))
        feature = F.normalize(feature, dim=1)
        bank_chunks.append(feature)
        label_chunks.append(target)
    ## [D, N]
    feature_bank = torch.cat(bank_chunks, dim=0).t().contiguous()
    ## [N]
    feature_labels = torch.cat(label_chunks, dim=0).to(feature_bank.device)

In [None]:
## Embed one test batch for the interactive shape / knn_predict checks below.
data, target = next(iter(test_dataloader))
data, target = data.to(device), target.to(device)
with torch.no_grad():  # inference only: no autograd graph is needed for inspection
    feature = model.backbone.f(data)
    feature = F.normalize(feature, dim=1)

In [None]:
feature.size()  # [B, D] normalised query features

In [None]:
feature_bank.size()  # [D, N] transposed feature bank

In [None]:
feature_labels.size()  # [N] one label per bank column

In [None]:
## Ranked class predictions for the test batch, using 200 neighbours and temperature 1.
knn_predict(feature, feature_bank, feature_labels,
            classes=classes, knn_k=200, knn_t=1)

In [None]:
torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.int64).shape  # torch.Size([1, 5])

In [None]:
torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64).shape  # torch.Size([5])

In [None]:
## Toy sanity check: the query equals bank column 0, so class 0 must rank first.
## Labels are used as scatter indices and must lie in [0, classes).
knn_predict(feature=torch.Tensor([[1, 0, 0, 0, 0]]), feature_bank=torch.eye(5),
            feature_labels=torch.arange(5), classes=5, knn_k=1)

In [None]:
## NOTE(review): leftover post-mortem debugger for the toy knn_predict cell above;
## remove once that cell runs without raising.
%debug

In [3]:
from tqdm import tqdm


def knn_monitor(model, memory_dataloader, test_dataloader,
                device, epoch, knn_k, knn_t=1, epochs='', classes=None):
    '''Test an encoder using a weighted k-NN monitor.

    Every batch of ``memory_dataloader`` is embedded with ``model`` and
    L2-normalised to build a feature bank; each test sample is then classified
    by ``knn_predict`` from its ``knn_k`` most similar bank entries.

    Args:
        model: feature extractor (``torch.nn.Module``), called as ``model(data)``.
        memory_dataloader: yields ``(data, target)`` batches that form the bank.
        test_dataloader: yields ``(data, target)`` batches to classify.
        device: device the input batches are moved to.
        epoch: current epoch; only shown in the progress-bar description.
        knn_k: number of nearest neighbours that vote.
        knn_t: temperature; neighbour weights are ``exp(similarity / knn_t)``.
        epochs: total epochs; only shown in the progress-bar description.
        classes: number of classes. Defaults to
            ``len(memory_dataloader.dataset.dataset.classes)``.

    Returns:
        Top-1 accuracy on the test loader, in percent.
    '''
    was_training = model.training
    model.eval()
    if classes is None:
        classes = len(memory_dataloader.dataset.dataset.classes)
    total_top1, total_num = 0.0, 0
    bank_chunks, label_chunks = [], []

    try:
        with torch.no_grad():
            ## Generate feature bank. Labels are taken from the same batches as the
            ## features, so they stay aligned even if the loader shuffles, drops its
            ## last batch or iterates over a subset of the underlying dataset.
            for data, target in tqdm(memory_dataloader, desc='Feature extracting'):
                feature = F.normalize(model(data.to(device)), dim=1)
                bank_chunks.append(feature)
                label_chunks.append(target)
            ## [D, N]
            feature_bank = torch.cat(bank_chunks, dim=0).t().contiguous()
            ## [N]
            feature_labels = torch.cat(label_chunks, dim=0).to(feature_bank.device)

            ## Loop test data to predict the label by weighted knn search.
            test_bar = tqdm(test_dataloader)
            for data, target in test_bar:
                data, target = data.to(device), target.to(device)
                feature = F.normalize(model(data), dim=1)
                pred_labels = knn_predict(feature, feature_bank, feature_labels,
                                          classes, knn_k, knn_t)
                total_num += data.size(0)
                total_top1 += (pred_labels[:, 0] == target).float().sum().item()
                test_bar.set_description(
                    'Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(
                        epoch, epochs, total_top1 / total_num * 100
                    )
                )
    finally:
        ## Restore the caller's train/eval mode (matters when called mid-training).
        model.train(was_training)
    return total_top1 / total_num * 100


def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t=1):
    '''
    Weighted k-NN prediction as in InstDisc https://arxiv.org/abs/1805.01978
    Implementation follows http://github.com/zhirongw/lemniscate.pytorch
    and https://github.com/leftthomas/SimCLR

    Args:
        feature: [B, D] query features. Expected to be L2-normalised by the caller,
            so the matrix product below is a cosine similarity.
        feature_bank: [D, N] bank features, one (normalised) column per sample.
        feature_labels: [N] integer labels in [0, classes) for the bank columns.
        classes: number of classes C.
        knn_k: number of neighbours; clamped to N when the bank is smaller.
        knn_t: temperature; neighbour weights are exp(similarity / knn_t).

    Returns:
        [B, C] class indices sorted by descending weighted vote; column 0 is the
        top-1 prediction.
    '''
    batch_size = feature.size(0)
    ## topk raises when k exceeds the number of bank entries, so clamp it.
    knn_k = min(knn_k, feature_bank.size(1))

    ## Cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    ## [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    ## [B, K]
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    ## One-hot encode the neighbour labels ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * knn_k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    ## Weighted score per class ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(batch_size, -1, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1
    )
    return pred_scores.argsort(dim=-1, descending=True)