In [1]:
import torch
import torch.nn as nn

In [2]:
class Basicblock(nn.Module):
  """A single 2-D convolution followed by a (non in-place) ReLU.

  Args:
    iin: number of input channels.
    out: number of output channels.
    **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
      padding, ...).
  """

  def __init__(self, iin, out, **kwargs):
    super().__init__()
    conv_layer = nn.Conv2d(in_channels=iin, out_channels=out, **kwargs)
    self.conv = conv_layer
    self.relu = nn.ReLU(inplace=False)

  def forward(self, x):
    features = self.conv(x)
    activated = self.relu(features)
    return activated

In [3]:
class InceptionConv(nn.Module):
  """Inception module: four parallel branches concatenated on the channel axis.

  Branches (all keep the spatial size unchanged via padding / stride 1):
    b1: 1x1 conv -> ReLU                                   -> n1 channels
    b2: 1x1 reduce to n3red (+ReLU) -> 3x3 conv (+ReLU)    -> n3 channels
    b3: 1x1 reduce to n5red (+ReLU) -> 5x5 conv (+ReLU)    -> n5 channels
    b4: 3x3 max-pool (stride 1) -> 1x1 conv (+ReLU)        -> pool channels

  Args:
    iin: number of input channels.
    n1: output channels of the 1x1 branch.
    n3red: channels of the 1x1 reduction feeding the 3x3 conv.
    n3: output channels of the 3x3 branch.
    n5red: channels of the 1x1 reduction feeding the 5x5 conv.
    n5: output channels of the 5x5 branch.
    pool: output channels of the pooling-projection branch.

  Output has n1 + n3 + n5 + pool channels.
  """

  def __init__(self, iin, n1, n3red, n3, n5red, n5, pool):
    super().__init__()
    # Kept as a raw Conv2d + ReLU (not Basicblock) so state_dict keys
    # ("b1.0.weight", ...) stay the same as before.
    self.b1 = nn.Sequential(
        nn.Conv2d(iin, n1, kernel_size=1),
        nn.ReLU(True)
    )
    # The 1x1 reduction must output n3red channels, because that is what the
    # 3x3 conv below consumes (previously it produced n3, which crashed
    # whenever n3 != n3red).
    self.b2 = nn.Sequential(
        Basicblock(iin, n3red, kernel_size=1),
        Basicblock(n3red, n3, kernel_size=3, padding=1)
    )
    self.b3 = nn.Sequential(
        Basicblock(iin, n5red, kernel_size=1),
        Basicblock(n5red, n5, kernel_size=5, padding=2)
    )
    self.b4 = nn.Sequential(
        nn.MaxPool2d(3, stride=1, padding=1),
        Basicblock(iin, pool, kernel_size=1)
    )

  def forward(self, x):
    # Every branch preserves H and W, so the outputs can be stacked along dim 1.
    branches = [self.b1(x), self.b2(x), self.b3(x), self.b4(x)]
    return torch.cat(branches, dim=1)

In [4]:
class AuxiliaryClassifier(nn.Module):
  """Auxiliary classification head attached to an intermediate feature map.

  Pipeline: AvgPool(5, stride 3) -> 1x1 conv to 128 channels (+ReLU) ->
  flatten -> Linear(2048, 1024) -> ReLU -> Dropout -> Linear(1024, num).

  The fixed 2048 input width of fc1 equals 128 * 4 * 4, so the head only
  works when pooling leaves a 4x4 map (e.g. a 14x14 input feature map).

  Args:
    iin: number of input channels of the feature map.
    num: number of output classes.
    dropout: dropout probability before the final linear layer.
  """

  def __init__(self, iin, num, dropout=0.7):
    super().__init__()
    self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
    self.conv = Basicblock(iin, 128, kernel_size=1)
    self.relu = nn.ReLU(inplace=True)
    self.flatten = nn.Flatten()
    self.fc1 = nn.Linear(in_features=2048, out_features=1024)
    self.dropout = nn.Dropout(p=dropout)
    self.fc2 = nn.Linear(in_features=1024, out_features=num)

  def forward(self, x):
    h = self.pool(x)
    h = self.conv(h)
    h = self.flatten(h)
    h = self.relu(self.fc1(h))
    h = self.dropout(h)
    return self.fc2(h)

In [8]:
class GoogleNEt(nn.Module):
  """GoogLeNet (Inception v1) image classifier.

  Args:
    use_aux: if True, build the two auxiliary heads (after inception4a and
      after inception4d) and return their logits alongside the main logits.
    num_classes: number of output classes for the main head and for the
      auxiliary heads. Defaults to 1000.

  Input: assumes x is (N, 3, H, W). The auxiliary heads require a 14x14
  stage-4 feature map, which a 224x224 input produces. The final global pool
  is adaptive, so with use_aux=False other input sizes also work.

  Returns:
    (logits, aux1_logits, aux2_logits) when use_aux is True, else logits.
  """

  def __init__(self, use_aux=True, num_classes=1000):
    super().__init__()
    self.use_aux = use_aux

    # Stem. The "lrn" layers are named after the paper's local response
    # normalisation, but are implemented here with BatchNorm2d.
    self.conv1 = Basicblock(3, 64, kernel_size=7, stride=2, padding=3)
    self.lrn1 = nn.BatchNorm2d(64)
    self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)

    self.conv2 = Basicblock(64, 64, kernel_size=1)
    self.conv3 = Basicblock(64, 192, kernel_size=3, padding=1)
    # Applied to conv3's output, so it must normalise 192 channels (was 64).
    self.lrn2 = nn.BatchNorm2d(192)
    self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)

    self.inception3a = InceptionConv(192, 64, 96, 128, 16, 32, 32)     # -> 256
    self.inception3b = InceptionConv(256, 128, 128, 192, 32, 96, 64)   # -> 480
    self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1)

    self.inception4a = InceptionConv(480, 192, 96, 208, 16, 48, 64)    # -> 512
    self.inception4b = InceptionConv(512, 160, 112, 224, 24, 64, 64)   # -> 512
    self.inception4c = InceptionConv(512, 128, 128, 256, 24, 64, 64)   # -> 512
    # Attribute name fixed from the misspelled "incetion4d", which forward()
    # could never find.
    self.inception4d = InceptionConv(512, 112, 144, 288, 32, 64, 64)   # -> 528
    self.inception4e = InceptionConv(528, 256, 160, 320, 32, 128, 128) # -> 832
    self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1)

    self.inception5a = InceptionConv(832, 256, 160, 320, 32, 128, 128) # -> 832
    self.inception5b = InceptionConv(832, 384, 192, 384, 48, 128, 128) # -> 1024

    if self.use_aux:
      self.aux1 = AuxiliaryClassifier(512, num_classes)  # on inception4a output
      self.aux2 = AuxiliaryClassifier(528, num_classes)  # on inception4d output

    # Adaptive pooling equals AvgPool2d(7) on the 7x7 map of a 224x224 input,
    # and also handles other spatial sizes.
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.dropout = nn.Dropout(0.4)
    self.fc = nn.Linear(1024, num_classes)

  def forward(self, x):
    # Stem: conv -> pool -> norm, then 1x1 reduce + 3x3 conv -> norm -> pool.
    x = self.lrn1(self.maxpool1(self.conv1(x)))
    x = self.maxpool2(self.lrn2(self.conv3(self.conv2(x))))

    x = self.maxpool3(self.inception3b(self.inception3a(x)))

    x = self.inception4a(x)
    aux1 = self.aux1(x) if self.use_aux else None
    x = self.inception4d(self.inception4c(self.inception4b(x)))
    aux2 = self.aux2(x) if self.use_aux else None
    x = self.maxpool4(self.inception4e(x))

    x = self.inception5b(self.inception5a(x))
    x = torch.flatten(self.avgpool(x), 1)
    x = self.fc(self.dropout(x))

    if self.use_aux:
      return x, aux1, aux2
    return x