In [1]:
import os, pdb

import argparse 

import torch, torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import kornia
import pytorch_lightning as pl

from einops import rearrange, reduce, repeat
import numpy as np
import matplotlib.pyplot as plt



In [2]:
# parser = argparse.ArgumentParser(description='PyTorch Lightning Training Template')
# parser.add_argument('--lr', default=0.001, type=float, help='learning rate')


# args = parser.parse_args()


In [3]:
from torchvision import transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


training_set = torchvision.datasets.CIFAR10(
    root='../data', train=True, download=True, transform=transform_train
)



training_set, valiation_dset = torch.utils.data.random_split(training_set, [40000,len(training_set) - 40000])


training_loader = torch.utils.data.DataLoader(
    training_set, batch_size=128, shuffle=True, num_workers=2
)

validation_loader = torch.utils.data.DataLoader(
    training_set, batch_size=512, shuffle=False, num_workers=2
)


classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')



Files already downloaded and verified


In [4]:

class RGBNetwork(nn.Module):
    def __init__(self, input_dim=3):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Linear(256, 1000),
            nn.ReLU(),
            nn.Linear(1000, 10),
        )

    def forward(self, x):
        x = self.layers(x)
        return x
        

        
data = torch.randn(16, 3, 32, 32)
net = RGBNetwork()
out = net(data)
print(out.shape)


torch.Size([16, 10])


In [5]:

class YNetwork(nn.Module):
    def __init__(self, input_dim=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(input_dim, 64, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=3),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=3),
            nn.Flatten(),
            nn.Linear(256, 1000),
            nn.ReLU(),
            nn.Linear(1000, 10),
        )

    def forward(self, x):
        x = self.layers(x)
        return x
        

        
data = torch.randn(16, 1, 32, 32)
net = YNetwork()
out = net(data)
print(out.shape)


torch.Size([16, 10])


In [6]:
class LightningModule(pl.LightningModule):

    def __init__(self, model, hparams=None):
        super().__init__()
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss()
        
    def training_step(self, batch, batch_idx):
        x, y = batch
    
        scores = self.model(x)        
        loss = self.loss(scores, y)
        self.log("training_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        
        scores = self.model(x)
        predictions = torch.argmax(scores, dim=1)
        
        num_true = torch.sum(predictions == y)
        num_false = torch.sum(predictions != y)
        
        return num_true.item(), num_false.item()
        
        
    def validation_epoch_end(self, validation_step_outputs):
        validation_step_outputs = np.array(validation_step_outputs)
        total = reduce(validation_step_outputs, "b tf -> tf", reduction=sum)
        acc = total[0] / (total[0] + total[1])
        self.log("val_acc", acc, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=2)
        return {'optimizer': optimizer, 
                'lr_scheduler': scheduler, 
                'monitor': 'training_loss'}

    


In [7]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


gpu_stats = GPUStatsMonitor()
early_stopping = EarlyStopping(monitor='val_acc', patience=5, verbose=True, mode='max')
tb_logger = TensorBoardLogger(save_dir="../logs/")
checkpoint = ModelCheckpoint(dirpath='../model-checkpoints', filename='{epoch}anan')




In [None]:
from torchvision.models import resnet50
resnet = resnet50(pretrained=False)
resnet.fc = torch.nn.Linear(in_features=2048, out_features=100)

rgbmodule = LightningModule(model=resnet)


trainer = pl.Trainer(gpus=1, 
                     callbacks=[gpu_stats, early_stopping, checkpoint],
                     logger = tb_logger)

trainer.fit(rgbmodule, training_loader, validation_loader)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 23.7 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
23.7 M    Trainable params
0         Non-trainable params
23.7 M    Total params
94.852    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…