In [1]:
# Automatically re-import edited modules (e.g. the project's `src.*` package) before each cell runs.
# `%reload_ext` (rather than `%load_ext`) keeps this cell safe to re-run.
%reload_ext autoreload
%autoreload 2

In [2]:
# regular imports
import sys
sys.path.append('..')
import matplotlib.pyplot as plt
%matplotlib inline

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Lightning import 
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

# PyTorch imports
import torch
from torch import nn
from torch.nn import functional as F
print(f'Cuda available: {torch.cuda.is_available()}')

import wandb
wandb.login()

# internal imports
from src.callbacks import ImagePredictionLogger
from src.dataset import MNISTDataModule
from src.models import CNN, LitModel
from src.utils import sweep_iteration

Cuda available: True


[34m[1mwandb[0m: Currently logged in as: [33malkalait[0m (use `wandb login --relogin` to force relogin)


---

In [3]:
# Optional model checkpointing (currently disabled): keep the 3 checkpoints with the lowest
# val_loss. To enable it, uncomment this cell and wire `checkpoint_callback` into the
# Trainer (see the commented `checkpoint_callback=` argument in the Trainer cell, or add
# it to the `callbacks` list).
# NOTE(review): newer Lightning releases deprecate `filepath=` in favour of
# `dirpath=` + `filename=` (e.g. dirpath=MODEL_CKPT_PATH,
# filename='model-{epoch:02d}-{val_loss:.2f}') — check against the installed version.
# MODEL_CKPT_PATH = '../model/'
# MODEL_CKPT = '../model/model-{epoch:02d}-{val_loss:.2f}'

# checkpoint_callback = ModelCheckpoint(
#     monitor='val_loss',
#     filepath=MODEL_CKPT ,
#     save_top_k=3,
#     mode='min'
# )

In [4]:
# Settings shared by the geometric warps below: pixels pulled in from outside the frame
# are filled with the constant 0 (presumably matching the black MNIST background) and
# resampling uses Lanczos interpolation.
# NOTE(review): `A.cv2` relies on albumentations exposing the OpenCV module as an
# attribute; a direct `import cv2` would not depend on that implementation detail.
aug_kwargs = {
    'border_mode': A.cv2.BORDER_CONSTANT,
    'value': 0,
    'interpolation': A.cv2.INTER_LANCZOS4,
}

# Training-time augmentation, applied top to bottom. Commented entries are disabled
# experiments kept at the position they would occupy if re-enabled.
# NOTE(review): only the training split is cropped to 24x24 here — confirm the
# datamodule's val/test transforms yield inputs the CNN accepts.
train_pipeline = [
    A.RandomCrop(height=24, width=24),
    # A.HorizontalFlip(p=0.5),
    # A.GridDistortion(p=0.5, distort_limit=.3, **aug_kwargs),
    A.ElasticTransform(p=0.5, alpha=1, sigma=1, alpha_affine=5, **aug_kwargs),   # half of the samples
    A.ShiftScaleRotate(p=1.0, rotate_limit=30, scale_limit=.2, **aug_kwargs),    # every sample
    # A.CoarseDropout(p=1.0, max_holes=8, max_height=4, max_width=4,
    #                 min_holes=1, min_height=4, min_width=4),
    # A.RandomBrightnessContrast(p=0.2),
    # A.Blur(blur_limit=4),
    # mean 0 / std 1: with albumentations' max_pixel_value scaling this only rescales
    # pixel values (by 255 at the default) — confirm for the installed version.
    A.Normalize(mean=(0.0,), std=(1,)),
    ToTensorV2(),
]
train_transforms = A.Compose(train_pipeline)

In [5]:
# W&B project name; also used by the (commented-out) sweep cells at the end of the notebook.
proj = 'SimSiam-Lightning'

# Run configuration. BATCH_SIZE feeds both the datamodule and the LightningModule, so it is
# defined once here to keep the two from drifting apart.
SEED = 42
BATCH_SIZE = 512
LR = 1e-3
PRED_LOG_BATCH_SIZE = 5000  # size of the validation batch handed to ImagePredictionLogger
PRED_LOG_N_SAMPLES = 64     # forwarded as ImagePredictionLogger's n_samples

# Seed the Python / NumPy / torch RNGs before anything stochastic is created (weight init,
# augmentation sampling) so runs are more repeatable. Bit-exact GPU determinism would
# additionally need deterministic=True on the Trainer.
seed_everything(SEED)

# Setup datamodule. Comes with its own train / val / test dataloader.
mnist = MNISTDataModule('../data/', batch_size=BATCH_SIZE, train_transforms=train_transforms)
# mnist.prepare_data()
# mnist.setup()
# NOTE(review): val_dataloader() is called below without an explicit setup(); it ran in the
# recorded session, but this relies on MNISTDataModule preparing its data lazily — confirm.

cnn = CNN(C=mnist.dims[0], num_classes=mnist.num_classes)  # Architecture
model = LitModel(datamodule=mnist, backbone=cnn, batch_size=BATCH_SIZE, lr=LR, flood=True)
wandb_logger = WandbLogger(project=proj, job_type='train')  # Logger
callbacks = [
    LearningRateMonitor(),  # log the LR
    ImagePredictionLogger(mnist.val_dataloader(batch_size=PRED_LOG_BATCH_SIZE),
                          n_samples=PRED_LOG_N_SAMPLES),
    # early_stop_callback,
]

In [6]:
# Build the Lightning Trainer; keyword arguments are grouped by concern.
trainer = Trainer(
    # hardware / schedule
    gpus=-1,                      # -1: use every visible GPU
    max_epochs=200,
    # optimisation
    accumulate_grad_batches=1,    # 1: no gradient accumulation
    gradient_clip_val=0,          # 0: clipping disabled (0.5 is a candidate value)
    # logging / callbacks
    logger=wandb_logger,
    callbacks=callbacks,
    progress_bar_refresh_rate=20,
    # checkpoint_callback=checkpoint_callback,
    # fast_dev_run=True,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [7]:
# Optional learning-rate range test, run before `trainer.fit`. Disabled by default because it
# performs extra training steps; set RUN_LR_FINDER = True to use it.
RUN_LR_FINDER = False
if RUN_LR_FINDER:
    lr_finder = trainer.tuner.lr_find(model, num_training=3000, mode='linear', max_lr=1e-2)
    # lr_finder.results holds the raw lr / loss curve
    fig = lr_finder.plot(suggest=True)
    # Previously `new_lr` was used below without ever being assigned (NameError on uncomment).
    new_lr = lr_finder.suggestion()
    print(f'Suggested lr: {new_lr}')
    # suggestion() can come back as None when no suggestion could be computed; keep the old lr then.
    if new_lr is not None:
        # NOTE(review): only takes effect if LitModel.configure_optimizers reads self.hparams.lr — confirm.
        model.hparams.lr = new_lr  # update hparams of the model

In [None]:
# Train. The datamodule is passed by keyword: positionally, the second parameter of
# Trainer.fit is the train dataloader, and a DataModule there only works because
# Lightning special-cases it.
trainer.fit(model, datamodule=mnist)




  | Name     | Type     | Params
--------------------------------------
0 | backbone | CNN      | 158 K 
1 | accuracy | Accuracy | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# Evaluate on the test split with the checkpoint Lightning records as "best" during fit.
# 'best' is the Lightning 1.x default for ckpt_path; it is spelled out to make explicit
# which weights are evaluated. NOTE(review): without a monitored metric this is presumably
# just the latest checkpoint — confirm.
trainer.test(ckpt_path='best')

In [None]:
# Mark the W&B run as finished so remaining data is uploaded before any new run (e.g. a sweep) starts.
wandb.finish()

---

# Hyperparameter sweep 

In [None]:
# Register a W&B hyperparameter sweep (disabled). `sweep_config` comes from src.sweeps;
# wandb.sweep returns the sweep id that the agent cell below needs.
# from src.sweeps import sweep_config

# sweep_id = wandb.sweep(sweep_config, project=proj)

In [None]:
# Run the sweep agent (disabled): each trial executes `sweep_iteration` (imported from src.utils).
# Requires `sweep_id` from the previous cell.
# wandb.agent(sweep_id, function=sweep_iteration, project=proj)

---