In [20]:
import torch
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import shutil
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
import glob
# Report the installed PyTorch build and whether a CUDA-capable GPU is visible.
print(torch.__version__)
print(torch.cuda.is_available())

# Prefer the first GPU; otherwise fall back to the CPU so the notebook still runs.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
device

1.8.1+cu101
True


device(type='cuda', index=0)

In [21]:
import torch.nn.functional as F
class BasicBlock(torch.nn.Module):
    """Conv2d (without bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution (and normalised
            by the batch norm).
        **kwargs: forwarded verbatim to ``torch.nn.Conv2d`` (e.g.
            ``kernel_size``, ``padding``, ``stride``). Passing ``bias`` here
            raises ``TypeError`` because ``bias=False`` is always supplied.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv bias is dropped: the following BatchNorm2d has its own
        # learnable shift, so a bias term would be redundant.
        self.conv = torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, bias=False, **kwargs
        )
        self.batch_norm = torch.nn.BatchNorm2d(num_features=out_channels)

    def forward(self, input):
        """Return ``relu(batch_norm(conv(input)))``; the ReLU runs in place on the BN output."""
        return F.relu(self.batch_norm(self.conv(input)), inplace=True)

In [22]:
class InceptionBlock(torch.nn.Module):
    """Four parallel branches whose outputs are concatenated along dim 1.

    Every conv stage is a ``BasicBlock`` (conv -> batch norm -> ReLU):

    * ``branch1x1``: 1x1 conv to 64 channels.
    * ``branch3x3_1`` -> ``branch3x3_2``: 1x1 conv to 64, then 3x3 conv to 96.
    * ``branch5x5_1`` -> ``branch5x5_2``: 1x1 conv to 48, then 5x5 conv to 64.
    * pool branch: 3x3 max pool (stride 1, padding 1), then ``branch_pool``
      1x1 conv to ``pool_features`` channels.

    Each spatial op is padded so height and width are unchanged. The result
    therefore has ``64 + 96 + 64 + pool_features`` channels at the input
    resolution (dim 1 is the channel dim for batched NCHW input).

    Args:
        in_channels: channels of the incoming tensor.
        pool_features: output channels of the pooling branch.
    """

    def __init__(self, in_channels, pool_features):
        super().__init__()
        # Registration order fixes parameter-initialisation order and the
        # order modules are listed in repr/summary output; keep it stable.
        self.branch1x1 = BasicBlock(in_channels, 64, kernel_size=1)

        self.branch3x3_1 = BasicBlock(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicBlock(64, 96, kernel_size=3, padding=1)

        self.branch5x5_1 = BasicBlock(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicBlock(48, 64, kernel_size=5, padding=2)

        self.branch_pool = BasicBlock(in_channels, pool_features, kernel_size=1)

    def forward(self, input):
        """Run all four branches on ``input`` and concatenate them channel-wise."""
        # Branches execute in the same order as they are declared.
        out_1x1 = self.branch1x1(input)
        out_3x3 = self.branch3x3_2(self.branch3x3_1(input))
        out_5x5 = self.branch5x5_2(self.branch5x5_1(input))
        out_pool = self.branch_pool(
            F.max_pool2d(input, kernel_size=3, stride=1, padding=1)
        )
        return torch.cat([out_1x1, out_3x3, out_5x5, out_pool], dim=1)

In [23]:
# 32 input channels and 64 pool-branch channels -> 64 + 96 + 64 + 64 = 288 output channels.
# Module.to moves the parameters in place and returns the same module.
inception_net = InceptionBlock(in_channels=32, pool_features=64).to(device)
inception_net

InceptionBlock(
  (branch1x1): BasicBlock(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch3x3_1): BasicBlock(
    (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch3x3_2): BasicBlock(
    (conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch5x5_1): BasicBlock(
    (conv): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batch_norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (branch5x5_2): BasicBlock(
    (conv): Conv2d(48, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (batch_norm): BatchNorm2d(6

In [26]:
from torchsummary import summary
# torchsummary's `summary` defaults to device="cuda", which fails when the
# notebook has fallen back to the CPU; pass the selected device's type
# ("cuda" or "cpu") so the dummy input is created on the same device as the model.
summary(inception_net, input_size=(32, 256, 256), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           2,048
       BatchNorm2d-2         [-1, 64, 256, 256]             128
        BasicBlock-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]           2,048
       BatchNorm2d-5         [-1, 64, 256, 256]             128
        BasicBlock-6         [-1, 64, 256, 256]               0
            Conv2d-7         [-1, 96, 256, 256]          55,296
       BatchNorm2d-8         [-1, 96, 256, 256]             192
        BasicBlock-9         [-1, 96, 256, 256]               0
           Conv2d-10         [-1, 48, 256, 256]           1,536
      BatchNorm2d-11         [-1, 48, 256, 256]              96
       BasicBlock-12         [-1, 48, 256, 256]               0
           Conv2d-13         [-1, 64, 256, 256]          76,800
      BatchNorm2d-14         [-1, 64, 2