In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
import torch.optim as opt
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models, datasets
from typing import Tuple, List, Dict
import pathlib
from PIL import Image


In [None]:
class InceptionBlock(nn.Module):
    """A single Inception module: four parallel branches concatenated on the channel axis.

    Every branch preserves the spatial size (H, W) of its input:
      1. 1x1 conv                          -> ``f1`` channels
      2. 1x1 conv (reduce) -> 3x3 conv     -> ``f3_out`` channels
      3. 1x1 conv (reduce) -> 5x5 conv     -> ``f5_out`` channels
      4. 3x3 pool (stride 1) -> 1x1 conv   -> ``pool_out`` channels

    Args:
        in_channels: Number of channels of the incoming feature map.
        f1: Output channels of the 1x1 branch.
        f3_in: Channels of the 1x1 reduction in front of the 3x3 conv.
        f3_out: Output channels of the 3x3 branch.
        f5_in: Channels of the 1x1 reduction in front of the 5x5 conv.
        f5_out: Output channels of the 5x5 branch.
        pool_out: Output channels of the 1x1 projection after pooling.
        pool: ``"max"`` (default) or ``"avg"`` pooling in branch 4.

    Raises:
        ValueError: If ``pool`` is neither ``"max"`` nor ``"avg"``.

    Attributes:
        out_channels: ``f1 + f3_out + f5_out + pool_out``; pass it as the next
            block's ``in_channels``.
    """

    def __init__(self, in_channels, f1, f3_in, f3_out, f5_in, f5_out, pool_out, pool="max"):
        super().__init__()
        if pool == "max":
            pool_layer = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        elif pool == "avg":
            pool_layer = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        else:
            raise ValueError(f"pool must be 'max' or 'avg', got {pool!r}")

        # Layers are built once here (not inside forward) so they are registered
        # submodules: their weights are trained, saved in state_dict and moved by .to().
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, f1, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, f3_in, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(f3_in, f3_out, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.branch5 = nn.Sequential(
            nn.Conv2d(in_channels, f5_in, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(f5_in, f5_out, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
        )
        self.branch_pool = nn.Sequential(
            pool_layer,
            nn.Conv2d(in_channels, pool_out, kernel_size=1),
            nn.ReLU(inplace=True),
        )
        self.out_channels = f1 + f3_out + f5_out + pool_out

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate their outputs."""
        branches = [self.branch1(x), self.branch3(x), self.branch5(x), self.branch_pool(x)]
        # dim=1 is the channel axis of an (N, C, H, W) tensor; the default dim=0
        # would stack the branches along the batch axis instead.
        return torch.cat(branches, dim=1)


class InceptionNet(nn.Module):
    """GoogLeNet-style (Inception v1) image classifier.

    Args:
        num_classes: Number of output classes.
        dropout: Dropout probability applied to the pooled features before the
            classifier (default 0.4).

    ``forward`` takes a float tensor of shape (N, 3, H, W) and returns an
    (N, num_classes) tensor of softmax probabilities. The final adaptive average
    pool reduces any spatial size to 1x1, so H and W are not fixed to 224.
    """

    def __init__(self, num_classes, dropout=0.4):
        super().__init__()
        # Stem. The stride-2 conv and pool halve the resolution.
        # LocalResponseNorm's argument is the size of the channel neighbourhood
        # it normalises over, not the channel count.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.LocalResponseNorm(5),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(5),
        )

        # MaxPool2d has no 'same' padding; kernel 3 / stride 2 / padding 1
        # maps a side of length L to ceil(L / 2).
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Each tuple is (f1, f3_in, f3_out, f5_in, f5_out, pool_out) for one block.
        channels = 192
        self.stage3, channels = self._make_stage(channels, [
            (64, 96, 128, 16, 32, 32),
            (128, 128, 192, 32, 96, 64),
        ])
        self.stage4, channels = self._make_stage(channels, [
            (192, 96, 208, 16, 48, 64),
            (160, 112, 224, 24, 64, 64),
            (128, 128, 256, 16, 64, 64),
            (112, 144, 288, 32, 64, 64),
            (256, 160, 320, 32, 128, 128),
        ])
        self.stage5, channels = self._make_stage(channels, [
            (256, 160, 320, 32, 128, 128),
            (384, 192, 384, 48, 128, 128),
        ])  # ends with 384 + 384 + 128 + 128 = 1024 channels

        # Global average pooling: any (H, W) -> (1, 1).
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)

        # NOTE(review): nn.CrossEntropyLoss expects raw logits. If this model is
        # trained with it, drop the Softmax here and apply it only at inference.
        self.final_out = nn.Sequential(
            nn.Linear(channels, num_classes),
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _make_stage(in_channels, specs):
        """Chain InceptionBlocks built from ``specs``.

        Returns (nn.Sequential of blocks, output channel count of the last block).
        """
        blocks = []
        for spec in specs:
            block = InceptionBlock(in_channels, *spec)
            blocks.append(block)
            in_channels = block.out_channels
        return nn.Sequential(*blocks), in_channels

    def forward(self, x):
        """Classify a batch of images ``x`` of shape (N, 3, H, W)."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.maxpool(x)
        x = self.stage3(x)
        x = self.maxpool(x)
        x = self.stage4(x)
        x = self.maxpool(x)
        x = self.stage5(x)
        x = self.avgpool(x)
        x = self.dropout(x)
        # start_dim=1 keeps the batch axis: (N, 1024, 1, 1) -> (N, 1024).
        x = torch.flatten(x, 1)
        return self.final_out(x)
