In [56]:
# Automatically re-import edited modules (e.g. the project's `src.*` code imported
# below) before each cell runs, so changes there are picked up without restarting
# the kernel. Note: a module that fails to reload is reported but not retried.
%reload_ext autoreload
%autoreload 2

In [57]:
# regular imports
import sys
sys.path.append('..')

# Lightning imports
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
# from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

# PyTorch imports
import torch
from torch import nn
from torch.nn import functional as F

import wandb
wandb.login()

# internal imports
from src.callbacks import ImagePredictionLogger
from src.dataset import MNISTDataModule
from src.models import CNN, LitModel
from src.utils import sweep_iteration

---

In [58]:
# Where checkpoints are written, and the filename template Lightning fills in
# with the current epoch and monitored val_loss.
MODEL_CKPT_PATH = '../model/'
MODEL_CKPT_NAME = 'model-{epoch:02d}-{val_loss:.2f}'
MODEL_CKPT = MODEL_CKPT_PATH + MODEL_CKPT_NAME  # full path template (same value as before)

# Keep the 3 checkpoints with the lowest validation loss.
# `filepath=` is deprecated in PyTorch Lightning 1.x; the directory and the
# filename template are passed separately via `dirpath=` / `filename=`.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT_NAME,
    save_top_k=3,
    mode='min',
)

In [59]:
proj = 'SimSiam-Lightning'  # W&B project name
SEED = 42  # global RNG seed for reproducible runs

# Seed the Python / NumPy / torch RNGs before any data split, shuffling or weight
# initialisation happens, so Restart-&-Run-All reproduces the same run.
seed_everything(SEED)

# Setup datamodule. Comes with its own train / val / test dataloader.
mnist = MNISTDataModule('../data/', batch_size=512)
mnist.prepare_data()
mnist.setup()

# A fixed batch from the validation loader, handed to ImagePredictionLogger below.
val_samples = next(iter(mnist.val_dataloader()))

cnn = CNN(C=mnist.dims[0], num_classes=mnist.num_classes)  # Architecture
model = LitModel(datamodule=mnist, arch=cnn, lr=1e-3, flood=True)  # Lightning model
wandb_logger = WandbLogger(project=proj, job_type='train')  # Logger
callbacks = [
    LearningRateMonitor(),  # log the LR
    ImagePredictionLogger(val_samples),  # log some validation results
    # TODO: add an EarlyStopping callback here if wanted (none is defined yet).
]

In [60]:
# logits = cnn(val_samples[0])
# probs = F.softmax(logits, dim=1)
# probs = torch.max(probs, -1).values
# probs
# # probs.shape

In [61]:
# Build the Trainer. The `checkpoint_callback` (ModelCheckpoint) from the config
# cell is registered through `callbacks`; passing a ModelCheckpoint instance via
# the `checkpoint_callback=` argument is deprecated. A new list is built so that
# re-running this cell does not append to `callbacks` in place.
trainer = Trainer(
    max_epochs=200,  # number of epochs
    progress_bar_refresh_rate=20,
    gpus=-1,  # all GPUs
    logger=wandb_logger,
    callbacks=[*callbacks, checkpoint_callback],
    accumulate_grad_batches=1,
    gradient_clip_val=0,  # 0 disables clipping; e.g. 0.5 to enable
    # fast_dev_run=True,  # uncomment for a quick smoke-test run
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
# Train; the datamodule supplies the train / val dataloaders.
trainer.fit(model, datamodule=mnist)


  | Name     | Type     | Params
--------------------------------------
0 | arch     | CNN      | 158 K 
1 | accuracy | Accuracy | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [45]:
# Evaluate on the test set using the best checkpoint tracked during `fit`.
trainer.test(ckpt_path='best')

[autoreload of src.callbacks failed: Traceback (most recent call last):
  File "/home/freddie/venv/wotus/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/freddie/venv/wotus/lib/python3.6/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/home/freddie/venv/wotus/lib/python3.6/imp.py", line 315, in reload
    return importlib.reload(module)
  File "/home/freddie/venv/wotus/lib/python3.6/importlib/__init__.py", line 166, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 618, in _exec
  File "<frozen importlib._bootstrap_external>", line 674, in exec_module
  File "<frozen importlib._bootstrap_external>", line 781, in get_code
  File "<frozen importlib._bootstrap_external>", line 741, in source_to_code
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "../src/callba

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.9874, device='cuda:0'),
 'test_loss': tensor(0.0454, device='cuda:0'),
 'train_acc': tensor(0.9961, device='cuda:0'),
 'train_loss': tensor(0.0571, device='cuda:0'),
 'val_acc': tensor(0.9861, device='cuda:0'),
 'val_loss': tensor(0.0521, device='cuda:0')}
--------------------------------------------------------------------------------



[{'train_loss': 0.057143501937389374,
  'train_acc': 0.99609375,
  'val_loss': 0.052121084183454514,
  'val_acc': 0.9861111044883728,
  'test_loss': 0.04540814086794853,
  'test_acc': 0.9873560667037964}]

In [55]:
# Mark the active W&B run (started through WandbLogger) as finished and finish
# uploading its data.
wandb.finish()

VBox(children=(Label(value=' 0.18MB of 0.18MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_step,535.0
_runtime,39.0
_timestamp,1607481140.0
train_loss,0.20023
train_acc,0.9375
epoch,4.0
val_loss,0.14017
val_acc,0.95833
lr-Adam,0.001


0,1
_step,▁▂▂▂▂▃▄▄▄▄▅▅▅▆▆▇▇▇███
_runtime,▁▂▂▂▃▃▃▃▄▄▄▄▅▅▆▆▆▆▇▇█
_timestamp,▁▂▂▂▃▃▃▃▄▄▄▄▅▅▆▆▆▆▇▇█
train_loss,█▅▃▂▂▂▁▁▁▁
train_acc,▁▄▆▇▇▇████
epoch,▁▁▁▃▃▃▅▅▅▆▆▆███
val_loss,█▄▂▁▁
val_acc,▁▆▇██
lr-Adam,▁▁▁▁▁


In [None]:
# run = wandb.init(project=proj, job_type='producer')

# artifact = wandb.Artifact('model', type='model')
# artifact.add_dir(MODEL_CKPT_PATH)

# run.log_artifact(artifact)
# run.join()

---

# Hyperparameter sweep 

In [None]:
# from src.sweeps import sweep_config

# sweep_id = wandb.sweep(sweep_config, project=proj)

In [None]:
# wandb.agent(sweep_id, function=sweep_iteration, project=proj)

---