In [1]:
import numpy as np
from torch import nn
from torch.nn import functional as F
import torch
from torch import optim
from torchvision import datasets,transforms
from torchinfo import summary
import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformations import (train_transforms,test_transforms,no_transforms)

from trainer import Trainer
from tester import Tester
from utlis import visualize_data,show_misclassified_images,device,is_cuda
from viz import plot_class_distribution, plot_confusion_matrix, plot_curves
from dataloader import CIFAR10Dataset,CIFAR10DataLoader
from models import ConvLayer,TransBlock,DepthwiseConvLayer


# Augmentation
atrain_dataset = CIFAR10Dataset(root='../data/',train=True,Atransforms=train_transforms,download=False)
atest_dataset  = CIFAR10Dataset(root='../data/',train=False,Atransforms=test_transforms,download=False)
acifar = CIFAR10DataLoader(batch_size=512,is_cuda_available=True)
atrain_loader, atest_loader = acifar.get_loader(atrain_dataset,atest_dataset)


# No Augmentation
train_dataset = CIFAR10Dataset(root='../data/',train=True,Atransforms=no_transforms,download=False)
test_dataset  = CIFAR10Dataset(root='../data/',train=False,Atransforms=no_transforms,download=False)
cifar = CIFAR10DataLoader(batch_size=512,is_cuda_available=True)
train_loader, test_loader = cifar.get_loader(train_dataset,test_dataset)

In [2]:
# visualize_data(atrain_loader,classes=acifar.classes,num_figures=24)

In [3]:
# visualize_data(train_loader,classes=acifar.classes,num_figures=24)

In [4]:
dp_rate = 0.1
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = ConvLayer(inc=3,outc=8,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        
        self.conv2 = ConvLayer(         inc=8,outc=16,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep2  = DepthwiseConvLayer(inc=8,outc=16,p=1,s=1,dp_rate=dp_rate)
        self.dil2  = ConvLayer(         inc=8,outc=16,k=3,p=2,s=1,d=2,dp_rate=dp_rate)

        self.conv3 = ConvLayer(         inc=16,outc=32,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep3  = DepthwiseConvLayer(inc=16,outc=32,p=1,s=1,dp_rate=dp_rate)
        self.dil3  = ConvLayer(         inc=16,outc=32,k=3,p=2,s=1,d=2,dp_rate=dp_rate)

        self.trans4 = TransBlock(56,16,p=0,s=2)


        self.conv5 = ConvLayer(16,32,k=3,p=1,s=1,d=1,dp_rate=dp_rate)

        self.conv6 = ConvLayer(         inc=32,outc=48,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep6  = DepthwiseConvLayer(inc=32,outc=48,p=1,s=1,dp_rate=dp_rate)
        self.dil6  = ConvLayer(         inc=32,outc=48,k=3,p=2,s=1,d=2,dp_rate=dp_rate)

        self.conv7 = ConvLayer(         inc=48,outc=56,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep7  = DepthwiseConvLayer(inc=48,outc=56,p=1,s=1,dp_rate=dp_rate)
        self.dil7  = ConvLayer(         inc=48,outc=56,k=3,p=2,s=1,d=2,dp_rate=dp_rate)
      
        self.trans8 = TransBlock(136,24,p=0,s=2)

        
        self.conv9 = ConvLayer(24,40,k=3,p=1,s=1,d=1,dp_rate=dp_rate)

        self.conv10 = ConvLayer(         inc=40,outc=48,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep10  = DepthwiseConvLayer(inc=40,outc=48,p=1,s=1,dp_rate=dp_rate)
        self.dil10  = ConvLayer(         inc=40,outc=48,k=3,p=4,s=1,d=4,dp_rate=dp_rate)


        self.conv_ = ConvLayer(         inc=48,outc=48,k=3,p=1,s=1,d=1,dp_rate=dp_rate)
        self.dep_  = DepthwiseConvLayer(inc=48,outc=48,p=1,s=1,dp_rate=dp_rate)
        self.dil_  = ConvLayer(         inc=48,outc=48,k=3,p=4,s=1,d=4,dp_rate=dp_rate)

        self.trans11 = TransBlock(88,40,p=0,s=1)
        self.trans12 = TransBlock(40,20,p=0,s=1)
        self.out = TransBlock(inc=20,outc=10,p=0,s=1)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        
    
    def forward(self,x):        
        x0 = self.conv1(x)
        x1 = self.dep2(x0) + self.conv2(x0) + self.dil2(x0) 
        x2 = self.dep3(x1) + self.conv3(x1) + self.dil3(x1) 
        print(x0.shape,x1.shape,x2.shape)
        x3 = torch.concat((x0,x1,x2),dim=1)
        x4 = self.trans4(x3)

        x5 = self.conv5(x4)
        x6 = self.conv6(x5) + self.dep6(x5) + self.dil6(x5)
        x7 = self.conv7(x6)+ self.dep7(x6)+ self.dil7(x6)
        x8 = torch.concat((x5,x6,x7),dim=1)
        x9  = self.trans8(x8)

        x10 = self.conv9(x9)
        x11 = self.conv10(x10) + self.dep10(x10) + self.dil10(x10)
        x11 = self.conv_(x11) + self.dep_(x11) + self.dil_(x11)
        x12 = torch.concat((x10,x11),dim=1)
        x  = self.trans11(x12)
        x  = self.trans12(x)
        
        x = self.out(x)
        x = self.gap(x)
        return  F.log_softmax(x.view(-1,10), dim=1)


model = Net().to(device=device)
summary(model,input_size=(1,3,32,32),device=device)

torch.Size([1, 8, 32, 32]) torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 32])


Layer (type:depth-idx)                   Output Shape              Param #
Net                                      [1, 10]                   --
├─ConvLayer: 1-1                         [1, 8, 32, 32]            --
│    └─Sequential: 2-1                   [1, 8, 32, 32]            --
│    │    └─Conv2d: 3-1                  [1, 8, 32, 32]            216
│    │    └─BatchNorm2d: 3-2             [1, 8, 32, 32]            16
│    │    └─ReLU: 3-3                    [1, 8, 32, 32]            --
│    │    └─Dropout2d: 3-4               [1, 8, 32, 32]            --
├─DepthwiseConvLayer: 1-2                [1, 16, 32, 32]           --
│    └─Sequential: 2-2                   [1, 16, 32, 32]           --
│    │    └─Conv2d: 3-5                  [1, 8, 32, 32]            72
│    │    └─Conv2d: 3-6                  [1, 16, 32, 32]           128
│    │    └─ReLU: 3-7                    [1, 16, 32, 32]           --
│    │    └─BatchNorm2d: 3-8             [1, 16, 32, 32]           32
│    │    └─D

In [None]:
# prev_test_loss = float('inf')

# optimizer = optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,mode='min',factor=0.1,patience=3)
# criterion = nn.NLLLoss()
# trainer = Trainer(model=model, train_loader=atrain_loader, optimizer=optimizer, criterion=criterion, device=device)
# tester = Tester(model=model, test_loader=atest_loader,criterion=criterion, device=device)

# for epoch in range(1, 51):
#     trainer.train(epoch=epoch)
#     _,test_loss = tester.test()
#     if prev_test_loss>test_loss:
#         torch.save(obj=model.state_dict(),f='./bmodels/model6.pth')
#     scheduler.step(test_loss)

In [None]:
# plot_curves(trainer.train_losses,trainer.train_accuracies,tester.test_losses,tester.test_accuracies)

In [5]:
# images, predictions, labels =  tester.get_misclassified_images()
# show_misclassified_images(images[:15],predictions[:15],labels[:15],cifar.classes)
# plot_class_distribution(train_loader,cifar.classes)
# plot_confusion_matrix(model,test_loader,device,cifar.classes)