In [None]:
!pip install pytorch-lightning --quiet
!git clone https://github.com/JB-Lee/ASD2021_01.git

In [None]:
%load_ext autoreload
%autoreload 1

import sys
sys.path.append('/content/ASD2021_01/cifar10')


In [1]:

# ---- Trainer / hardware settings ----
GPUS = 1             # GPUs handed to pl.Trainer
PRECISION = 16       # passed as pl.Trainer(precision=...)
PATIENCE = 3         # EarlyStopping patience
LOG_DIR = 'tb_logs'  # root directory for every TensorBoardLogger

# ---- Optimisation ----
BATCH_SIZE = 200
LEARNING_RATE = 1e-3


class MLP:
    """Hyper-parameters for model.MLPCifar10."""

    HIDDEN_COUNT = 4
    HIDDEN_SIZE = 256  # 16 * 16
    DROPOUT = 0.2


class RESNET:
    """Hyper-parameters for model.ResV2Cifar10."""

    IN_BLOCK_CHANNELS = 32
    NUM_BLOCKS = [2, 2, 2, 2]



In [2]:
%aimport model

import torch
import torchvision
import torchvision.transforms as transforms

import pytorch_lightning as pl

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Seed Python, NumPy and torch RNGs (and DataLoader workers), so that weight
# init, shuffling and augmentation are repeatable across runs.
SEED = 42
pl.seed_everything(SEED, workers=True)

DATA_ROOT = './data'
NUM_WORKERS = 2

# Applied by the datasets to every image: PIL image -> float tensor in [0, 1],
# then per-channel (x - 0.5) / 0.5, i.e. roughly [-1, 1].
base_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Random augmentation that operates on already-normalised tensors. It is not
# attached to the datasets; each model receives it through `transform=`.
# Presumably model.py applies it only to training batches (TODO: confirm).
data_augmentation = transforms.Compose([
    transforms.RandomResizedCrop((32, 32), scale=(0.9, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
])

train_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                         download=True, transform=base_transform)
test_set = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                        download=True, transform=base_transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE,
                                           shuffle=True, num_workers=NUM_WORKERS,
                                           pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=NUM_WORKERS,
                                          pin_memory=True)

# CIFAR-10 label names, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


### MLP

In [5]:
# Stop training once val_loss has not improved for PATIENCE validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',  # must match the key the model logs
    patience=PATIENCE,
    check_finite=True,   # also stop if val_loss becomes NaN/inf
)

logger = TensorBoardLogger(
    LOG_DIR,
    name=f'mlp_bs({BATCH_SIZE})_hs({MLP.HIDDEN_SIZE})_hc({MLP.HIDDEN_COUNT})_dr({MLP.DROPOUT})',
    # NOTE(review): graph logging needs the model to define example_input_array; confirm in model.py.
    log_graph=True,
    default_hp_metric=True,
)

net = model.MLPCifar10(
    MLP.HIDDEN_SIZE,
    MLP.HIDDEN_COUNT,
    dropout=MLP.DROPOUT,
    transform=data_augmentation,
    learning_rate=LEARNING_RATE,
)

# `accelerator`/`devices` replace `gpus=`, which was deprecated in Lightning 1.7
# and removed in 2.0. GPUS = 0 still means a single CPU process, as before.
trainer = pl.Trainer(
    accelerator='gpu' if GPUS else 'cpu',
    devices=GPUS or 1,
    callbacks=[early_stopping],
    logger=logger,
    precision=PRECISION,
)
# NOTE(review): test_loader is passed as the validation loader, so EarlyStopping
# picks the stopping epoch on the test set and test metrics are optimistic.
# Consider holding out a validation split of train_set instead.
trainer.fit(net, train_loader, test_loader)

### CNN

In [None]:
# Stop training once val_loss has not improved for PATIENCE validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',  # must match the key the model logs
    patience=PATIENCE,
    check_finite=True,   # also stop if val_loss becomes NaN/inf
)

logger = TensorBoardLogger(
    LOG_DIR,
    name=f'deep cnn_bs({BATCH_SIZE})',
    # NOTE(review): graph logging needs the model to define example_input_array; confirm in model.py.
    log_graph=True,
    default_hp_metric=True,
)

net = model.DCNNCifar10(transform=data_augmentation, learning_rate=LEARNING_RATE)

# `accelerator`/`devices` replace `gpus=`, which was deprecated in Lightning 1.7
# and removed in 2.0. GPUS = 0 still means a single CPU process, as before.
trainer = pl.Trainer(
    accelerator='gpu' if GPUS else 'cpu',
    devices=GPUS or 1,
    callbacks=[early_stopping],
    logger=logger,
    precision=PRECISION,
)
# NOTE(review): test_loader is passed as the validation loader, so EarlyStopping
# picks the stopping epoch on the test set and test metrics are optimistic.
# Consider holding out a validation split of train_set instead.
trainer.fit(net, train_loader, test_loader)

### ResNet

In [None]:
# Stop training once val_loss has not improved for PATIENCE validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',  # must match the key the model logs
    patience=PATIENCE,
    check_finite=True,   # also stop if val_loss becomes NaN/inf
)

logger = TensorBoardLogger(
    LOG_DIR,
    name=f'ResNet_bs({BATCH_SIZE})_nb({RESNET.NUM_BLOCKS})_ibc({RESNET.IN_BLOCK_CHANNELS})',
    # NOTE(review): graph logging needs the model to define example_input_array; confirm in model.py.
    log_graph=True,
    default_hp_metric=True,
)

net = model.ResV2Cifar10(
    RESNET.NUM_BLOCKS,
    # NOTE(review): 'in_block_hannels' looks like a typo for 'in_block_channels'.
    # It must match ResV2Cifar10's parameter name in model.py (not visible here).
    # Verify there before renaming, otherwise this raises TypeError.
    in_block_hannels=RESNET.IN_BLOCK_CHANNELS,
    transform=data_augmentation,
    learning_rate=LEARNING_RATE,
)

# `accelerator`/`devices` replace `gpus=`, which was deprecated in Lightning 1.7
# and removed in 2.0. GPUS = 0 still means a single CPU process, as before.
trainer = pl.Trainer(
    accelerator='gpu' if GPUS else 'cpu',
    devices=GPUS or 1,
    callbacks=[early_stopping],
    logger=logger,
    precision=PRECISION,
)
# NOTE(review): test_loader is passed as the validation loader, so EarlyStopping
# picks the stopping epoch on the test set and test metrics are optimistic.
# Consider holding out a validation split of train_set instead.
trainer.fit(net, train_loader, test_loader)

## TensorBoard

In [None]:
%load_ext tensorboard
%tensorboard --logdir tb_logs