In [None]:
# Auto-reload edited project modules (src.*) so changes are picked up without restarting the kernel.
# %reload_ext (rather than %load_ext) keeps this cell safe to re-run.
%reload_ext autoreload
%autoreload 2

In [1]:
# regular imports
import sys
sys.path.append('..')
import matplotlib.pyplot as plt
%matplotlib inline

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Lightning import
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# from pl_bolts.datamodules.mnist_datamodule import LightningDataModule, MNISTDataModule,

# PyTorch imports
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
print(f'Cuda available: {torch.cuda.is_available()}')

import wandb
wandb.login()

# internal imports
from src.callbacks import ImagePredictionLogger
from src.dataset import ComposeMany, MNISTDataModule2  # , alb_to_torch_aug, MNISTDataModule
from src.losses import log_softmax, simsiam_loss
from src.models import accuracy, CNN, BaseLitModel, SimSiam
from src.utils import sweep_iteration

Cuda available: True


[34m[1mwandb[0m: Currently logged in as: [33malkalait[0m (use `wandb login --relogin` to force relogin)


---

In [None]:
# Augmentation pipeline for the supervised baseline.
# Shared settings for the geometric transforms: constant (0) padding and Lanczos resampling.
AUG_KWARGS = {
    'border_mode': A.cv2.BORDER_CONSTANT,
    'value': 0,
    'interpolation': A.cv2.INTER_LANCZOS4,
}

# ComposeMany gets n_aug=2; presumably it applies the pipeline twice per image
# to produce two augmented views (see src.dataset; TODO confirm).
# Options tried and currently disabled: RandomCrop(24x24), HorizontalFlip(p=0.5),
# GridDistortion(distort_limit=.3), CoarseDropout(4x4 holes, 1-8 of them),
# RandomBrightnessContrast(p=0.2), Blur(blur_limit=4).
train_transforms = ComposeMany(
    [
        # two light elastic warps, each applied with probability 0.5
        A.ElasticTransform(p=0.5, sigma=1, alpha=3, alpha_affine=0, **AUG_KWARGS),
        A.ElasticTransform(p=0.5, sigma=1, alpha=1, alpha_affine=3, **AUG_KWARGS),
        # always-on: a scale jitter with no rotation, then a rotation (limit 25) with no scaling
        A.ShiftScaleRotate(p=1.0, scale_limit=.2, rotate_limit=0, **AUG_KWARGS),
        A.ShiftScaleRotate(p=1.0, scale_limit=0, rotate_limit=25, **AUG_KWARGS),
        # mean 0 / std 1: presumably only rescales pixel values (no centring)
        A.Normalize(mean=(0.0,), std=(1,)),
        ToTensorV2(),
    ],
    n_aug=2,
)

In [None]:
proj = 'SimSiam-Lightning'

# MNIST Lightning datamodule; it supplies its own train / val / test dataloaders.
mnist = MNISTDataModule2(
    data_dir='../data/',
    batch_size=512,
    train_transforms=train_transforms,
)
mnist.prepare_data()
mnist.setup()

# Backbone: input channels and number of classes are read off the prepared datamodule.
cnn = CNN(
    num_channels=mnist.dims[0],
    num_classes=mnist.num_classes,
)

# Metrics as (log-entry name, metric function) pairs.
metrics = (
    ('acc', accuracy),
)

model = BaseLitModel(
    datamodule=mnist,
    backbone=cnn,
    loss_func=log_softmax,
    metrics=metrics,
    lr=1e-3,
    flood_height=0.03,
)

In [None]:
wandb_logger = WandbLogger(project=proj, job_type='train')

callbacks = [
    LearningRateMonitor(),  # logs the learning rate
    # logs predictions on a small validation loader (batch of 32, 32 samples)
    ImagePredictionLogger(mnist.val_dataloader(batch_size=32), n_samples=32),
]

trainer = Trainer(
    max_epochs=200,
    gpus=-1,                    # use every visible GPU
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=1,  # no gradient accumulation
    gradient_clip_val=0,        # clipping disabled; 0.5 noted as an alternative
    progress_bar_refresh_rate=20,
    # NOTE: fast_dev_run=True is a smoke test (a single batch), so max_epochs
    # above is never reached. Set it to False for a real training run.
    fast_dev_run=True,
)

In [None]:
# Train the supervised baseline; data comes from the MNIST datamodule.
trainer.fit(model, datamodule=mnist)

---

In [None]:
# Evaluate on the test split. With fast_dev_run enabled no checkpoint is saved,
# so loading ckpt_path='best' has nothing to load and raises; in that case
# evaluate the in-memory weights instead.
test_ckpt_path = None if trainer.fast_dev_run else 'best'
trainer.test(model, datamodule=mnist, ckpt_path=test_ckpt_path)

In [None]:
# Close the current W&B run before the SimSiam section creates a new WandbLogger.
wandb.finish()

---

# LR finder

In [None]:
# # Learning rate finder (disabled; uncomment to run)
# # NOTE(review): `trainer` above is built with fast_dev_run=True; the finder may
# # refuse to run in that mode. TODO confirm, or build a separate Trainer for it.
# lr_finder = trainer.tuner.lr_find(model, num_training=3000, mode='linear', max_lr=1e-2)
# # Raw sweep results are available in lr_finder.results
# fig = lr_finder.plot(suggest=True)
# new_lr = lr_finder.suggestion()  # must be assigned before it is used below
# model.hparams.lr = new_lr  # update hparams of the model

# Hyperparameter sweep 

In [None]:
# Hyperparameter sweep (disabled; uncomment to run): registers the W&B sweep
# described by src.sweeps.sweep_config and hands each sweep run to sweep_iteration.
# from src.sweeps import sweep_config

# sweep_id = wandb.sweep(sweep_config, project=proj)

# wandb.agent(sweep_id, function=sweep_iteration, project=proj)

---

# SimSiam

In [None]:
# Augmentation pipeline for SimSiam (geometric transforms only).
# Shared settings for the geometric transforms: constant (0) padding and Lanczos resampling.
AUG_KWARGS = {
    'border_mode': A.cv2.BORDER_CONSTANT,
    'value': 0,
    'interpolation': A.cv2.INTER_LANCZOS4,
}

# n_aug=2: presumably two independently augmented views per image (TODO confirm in src.dataset).
# The two ElasticTransforms from the supervised pipeline are disabled here.
transforms = ComposeMany(
    [
        # always-on: a scale jitter with no rotation, then a rotation (limit 25) with no scaling
        A.ShiftScaleRotate(p=1.0, scale_limit=.2, rotate_limit=0, **AUG_KWARGS),
        A.ShiftScaleRotate(p=1.0, scale_limit=0, rotate_limit=25, **AUG_KWARGS),
        # mean 0 / std 1: presumably only rescales pixel values (no centring)
        A.Normalize(mean=(0.0,), std=(1,)),
        ToTensorV2(),
    ],
    n_aug=2,
)

In [2]:
proj = 'SimSiam-Lightning'
BATCH_SIZE = 512

# Lightning datamodule; it supplies its own train / val / test dataloaders.
# The same two-view transforms are used for validation, so val batches can be
# scored with the SimSiam loss as well.
mnist = MNISTDataModule2(data_dir='../data/',
                         batch_size=BATCH_SIZE,
                         train_transforms=transforms,
                         val_transforms=transforms)
mnist.prepare_data()
mnist.setup()

## Backbone arch, wrapped by the SimSiam model
cnn = CNN(num_channels=mnist.dims[0], num_classes=mnist.num_classes)
simsiam = SimSiam(backbone=cnn)

# No metrics and no flood_height for the self-supervised run.
# The LR scales linearly with batch size (0.05 per 256 samples). It is computed
# from BATCH_SIZE rather than by building a throwaway train DataLoader just to
# read its batch size.
model = BaseLitModel(
    datamodule=mnist, backbone=simsiam, loss_func=simsiam_loss,
    lr=0.05 * BATCH_SIZE / 256,
)

NameError: name 'transforms' is not defined

In [None]:
wandb_logger = WandbLogger(project=proj, job_type='train')

# Only the LR monitor for the SimSiam run (image-prediction logging is disabled).
callbacks = [LearningRateMonitor()]

trainer = Trainer(
    max_epochs=200,
    # `gpus` is not passed here: the `gpus=-1` (all GPUs) setting is disabled for this run.
    logger=wandb_logger,
    callbacks=callbacks,
    accumulate_grad_batches=1,  # no gradient accumulation
    gradient_clip_val=0,        # clipping disabled; 0.5 noted as an alternative
    progress_bar_refresh_rate=20,
    fast_dev_run=True,          # smoke-test mode: max_epochs is not reached while this is on
)

In [None]:
# Train SimSiam; data comes from the two-view MNIST datamodule.
trainer.fit(model, datamodule=mnist)

In [None]:
# Pull one training batch to check the layout of the augmented views.
x, y = next(iter(mnist.train_dataloader()))
x.shape  # presumably (batch, n_aug, H, W); TODO confirm against ComposeMany

In [None]:
# Disabled shape check of the backbone's pooled features. x[:, [0]] keeps index 0
# of dim 1 with the dim intact (presumably the first augmented view; TODO confirm).
# cnn.avgpool(cnn.features(x[:,[0]])).shape

In [None]:
from src.losses import negcosim, simsiam_loss

In [None]:
# A validation batch (built with val_transforms, i.e. the same two-view augmentation) for the checks below.
x, y = next(iter(mnist.val_dataloader()))

In [None]:
# Forward pass on the validation batch, only to inspect the four outputs.
# no_grad avoids building an autograd graph for a pass that is never backpropagated.
with torch.no_grad():
    z1, z2, p1, p2 = simsiam(x)

In [None]:
# Sanity check: SimSiam loss on the validation batch (the loss takes the (x, y) batch tuple and the model).
simsiam_loss((x, y), simsiam)

---

In [4]:
import torchvision.transforms as T
from src.dataset import ComposeManyTorch

# Preview a torchvision-based two-view augmentation on MNIST.
# The distinct names (tv_transforms / mnist_preview / views / labels) keep this
# cell from overwriting the `transforms` and `mnist` objects that the SimSiam
# cells above depend on.
IMG_SIZE = 28
BLUR_KERNEL = IMG_SIZE // 20 * 2 + 1  # = 3; GaussianBlur needs an odd kernel size

# RandomHorizontalFlip and RandomGrayscale are deliberately left out.
tv_transforms = ComposeManyTorch([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.6, 1.0)),
    T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    T.RandomApply([T.GaussianBlur(kernel_size=BLUR_KERNEL, sigma=(0.1, 2.0))], p=0.5),
    T.ToTensor(),
    T.Normalize([0.1307], [0.3081]),  # presumably MNIST's global mean / std
], n_aug=2)

mnist_preview = MNISTDataModule2(data_dir='../data/',
                                 batch_size=32,
                                 train_transforms=tv_transforms,
                                 val_transforms=tv_transforms)
# Same prepare/setup sequence as the other datamodules in this notebook.
mnist_preview.prepare_data()
mnist_preview.setup()

views, labels = next(iter(mnist_preview.val_dataloader()))

# Show the two augmented views of one sample side by side.
# Assumes views[i] stacks the n_aug views along its first axis (TODO confirm).
sample_idx = 1
fig, axes = plt.subplots(1, 2)
for view_idx, ax in enumerate(axes):
    ax.imshow(views[sample_idx][view_idx], cmap='gray')
    ax.set_title(f'view {view_idx}')
    ax.axis('off')
plt.show()