In [1]:
import os, torch, torchvision, torchmetrics

import torch.nn.functional as F
import torchvision.transforms as TT
import pytorch_lightning as pl

from torch import nn, optim
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [2]:
data_dir = "./petfinder-adoption-prediction"
resize_dim = 50
crop_dim = 50
batch_size = 64
num_workers = 4
lr = 1e-5
max_epochs = 50
log_freq = 10

#Data stats (mean, std)
stats = ((0.4968, 0.4639, 0.4281), (0.2600, 0.2579, 0.2607))

train_transform = TT.Compose([
    TT.Resize((resize_dim, resize_dim)),
    TT.RandomCrop(crop_dim, padding=4, padding_mode='reflect'),
    TT.RandomHorizontalFlip(),
    TT.ToTensor(),
    TT.Normalize(*stats,inplace=True)
])

test_transform = TT.Compose([
    TT.Resize((resize_dim, resize_dim)),
    TT.ToTensor(),
    TT.Normalize(*stats,inplace=True)
])

train_dataset = ImageFolder(data_dir+'/train', transform=train_transform)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, pin_memory = True, num_workers = num_workers, shuffle = True)

test_dataset = ImageFolder(data_dir+'/test', transform=test_transform)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, pin_memory = True, num_workers = num_workers)

In [3]:
def labels_to_tensor(labels):
    label_tensor = torch.zeros((len(labels), 2), device = torch.device("cuda"))
    
    for iter_n in range(len(labels)):
        if labels[iter_n] == 1:
            label_tensor[iter_n][1] = 1
        else:
            label_tensor[iter_n][0] = 1

    return label_tensor

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return x.view(x.size(0), -1)

class ResidualPod(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels)
        self.ReLU = nn.ReLU(inplace = True)
        self.conv = nn.Conv2d(channels, channels, kernel_size = 3, stride = 1, padding = 1)
    
    def forward(self, x):
        x = self.bn(x)
        x = self.ReLU(x)
        x_pass = x
        x = self.conv(x)
        x = self.bn(x)
        x = self.ReLU(x)
        x = self.conv(x)
        return (x + x_pass)

class ResidualGroup(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ReLU = nn.ReLU(inplace = True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv_in = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv_trans = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1)
        self.conv_mid = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
        self.res_pod = ResidualPod(out_channels)

    def forward(self, x):
        x = self.bn1(x)
        x = self.ReLU(x)
        x_pass = self.conv_trans(x)
        x = self.conv_in(x)
        x = self.bn2(x)
        x = self.ReLU(x)
        x = self.conv_mid(x)
        x = x + x_pass
        x = self.res_pod(x)
        x = self.res_pod(x)
        return x

In [4]:
size_1 = 3
size_2 = 32
size_3 = 64
size_4 = 128
size_5 = 256

resnet = nn.Sequential(
    nn.Conv2d(size_1, size_2, kernel_size = 3, stride = 1, padding = 1),
    ResidualGroup(size_2, size_3),
    ResidualGroup(size_3, size_4),
    ResidualGroup(size_4, size_5),
    nn.BatchNorm2d(size_5),
    nn.ReLU(inplace = True),
    nn.AdaptiveAvgPool2d(1),
    Flatten(),
    nn.Linear(size_5, size_3),
    nn.ReLU(inplace = True),
    nn.Linear(size_3, 2),
    nn.Sigmoid()
)

print("Resnet structure:", resnet)

Resnet structure: Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ResidualGroup(
    (ReLU): ReLU(inplace=True)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_in): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv_trans): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv_mid): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (res_pod): ResidualPod(
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (ReLU): ReLU(inplace=True)
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (2): ResidualGroup(
    (ReLU): ReLU(inplace=True)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, moment

In [5]:
class ResnetHandler(pl.LightningModule):
    def __init__(self, network):
        super().__init__()
        self.network = network
        self.accuracy = torchmetrics.Accuracy().to("cuda")

    def metric_loss(self, labels, y):
        label_tensor = labels_to_tensor(labels)
    
        # Accuracy Metric      
        y_int = torch.round(F.softmax(y, 1)).int()
        self.accuracy(label_tensor, y_int)
    
        # Loss
        loss = F.binary_cross_entropy(y, label_tensor)
    
        return self.accuracy, loss
        
    def training_step(self, batch, batch_idx):
        x, labels = batch
        y = self.network(x)
        
        accuracy_n, loss = self.metric_loss(labels, y)
        
        self.log("train_acc", accuracy_n)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, labels = batch
        y = self.network(x)
        
        accuracy_n, loss = self.metric_loss(labels, y)
        
        self.log("val_acc", accuracy_n)
        self.log("val_loss", loss)
        return loss
    
    def validation_epoch_end(self, validation_step_outputs):
        val_preds = torch.stack(validation_step_outputs)
        val_preds = torch.mean(val_preds)
        print ("Mean validation loss:", val_preds.item())
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=lr)
        return optimizer

In [6]:
%load_ext tensorboard
%tensorboard --logdir=lightning_logs

Reusing TensorBoard on port 6006 (pid 9088), started 12:58:49 ago. (Use '!kill 9088' to kill it.)

In [7]:
resnet_instance = ResnetHandler(resnet)

trainer = pl.Trainer(accelerator = "gpu", devices = 1, max_epochs = max_epochs, log_every_n_steps = log_freq)
trainer.fit(resnet_instance, train_dataloader, test_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | network  | Sequential | 2.0 M 
1 | accuracy | Accuracy   | 0     
----------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
8.002     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Mean validation loss: 0.6917905211448669


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Mean validation loss: 1.0276496410369873


Validation: 0it [00:00, ?it/s]

Mean validation loss: 2.0733015537261963


Validation: 0it [00:00, ?it/s]

Mean validation loss: 3.646667242050171


Validation: 0it [00:00, ?it/s]

Mean validation loss: 4.500024795532227


Validation: 0it [00:00, ?it/s]

Mean validation loss: 6.278190612792969


Validation: 0it [00:00, ?it/s]

Mean validation loss: 9.80264663696289


Validation: 0it [00:00, ?it/s]

Mean validation loss: 6.115978240966797


Validation: 0it [00:00, ?it/s]

Mean validation loss: 24.902463912963867


Validation: 0it [00:00, ?it/s]

Mean validation loss: 10.356403350830078


Validation: 0it [00:00, ?it/s]

Mean validation loss: 11.177411079406738


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.638296127319336


Validation: 0it [00:00, ?it/s]

Mean validation loss: 11.677205085754395


Validation: 0it [00:00, ?it/s]

Mean validation loss: 20.646543502807617


Validation: 0it [00:00, ?it/s]

Mean validation loss: 21.586381912231445


Validation: 0it [00:00, ?it/s]

Mean validation loss: 14.341350555419922


Validation: 0it [00:00, ?it/s]

Mean validation loss: 13.596687316894531


Validation: 0it [00:00, ?it/s]

Mean validation loss: 21.078350067138672


Validation: 0it [00:00, ?it/s]

Mean validation loss: 24.524290084838867


Validation: 0it [00:00, ?it/s]

Mean validation loss: 16.07631492614746


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.393138885498047


Validation: 0it [00:00, ?it/s]

Mean validation loss: 7.378521919250488


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.916163444519043


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.266493797302246


Validation: 0it [00:00, ?it/s]

Mean validation loss: 6.562342643737793


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.288812160491943


Validation: 0it [00:00, ?it/s]

Mean validation loss: 9.062552452087402


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.03467082977295


Validation: 0it [00:00, ?it/s]

Mean validation loss: 4.615635871887207


Validation: 0it [00:00, ?it/s]

Mean validation loss: 7.6312971115112305


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.336467742919922


Validation: 0it [00:00, ?it/s]

Mean validation loss: 11.737452507019043


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.65796947479248


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.2538042068481445


Validation: 0it [00:00, ?it/s]

Mean validation loss: 14.098496437072754


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.467630386352539


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.259374618530273


Validation: 0it [00:00, ?it/s]

Mean validation loss: 11.287616729736328


Validation: 0it [00:00, ?it/s]

Mean validation loss: 4.154063701629639


Validation: 0it [00:00, ?it/s]

Mean validation loss: 6.140931129455566


Validation: 0it [00:00, ?it/s]

Mean validation loss: 11.56689453125


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.121087551116943


Validation: 0it [00:00, ?it/s]

Mean validation loss: 3.7302427291870117


Validation: 0it [00:00, ?it/s]

Mean validation loss: 5.572103023529053


Validation: 0it [00:00, ?it/s]

Mean validation loss: 7.290428161621094


Validation: 0it [00:00, ?it/s]

Mean validation loss: 6.077175140380859


Validation: 0it [00:00, ?it/s]

Mean validation loss: 8.178397178649902


Validation: 0it [00:00, ?it/s]

Mean validation loss: 3.2448084354400635


Validation: 0it [00:00, ?it/s]

Mean validation loss: 3.9581286907196045


Validation: 0it [00:00, ?it/s]

Mean validation loss: 7.471804141998291


Validation: 0it [00:00, ?it/s]

Mean validation loss: 4.822860240936279
