In [1]:
import torch 
import torch.nn as nn

In [28]:
class conv_block(nn.Module):
    """Convolution followed by batch normalisation and a ReLU activation.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of filters produced by the convolution; the same
            value is used as the feature count of the batch-norm layer.
        **kwargs: Passed through unchanged to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride`` and ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # Sub-modules are registered as relu, conv, batchnorm (in that order);
        # forward() applies them as conv -> batchnorm -> relu.
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.relu(normalised)
    

In [29]:
class Inception_block(nn.Module):
    """Four parallel branches whose outputs are concatenated along dim 1.

    The branches are:
        1. a 1x1 ``conv_block``;
        2. a 1x1 ``conv_block`` followed by a 3x3 ``conv_block`` (padding 1);
        3. a 1x1 ``conv_block`` followed by a 5x5 ``conv_block`` (padding 2);
        4. a 3x3 max-pool (stride 1, padding 1) followed by a 1x1 ``conv_block``.

    Every branch uses stride 1 with padding that keeps the spatial size, so the
    result has ``out_1X1 + out_3X3 + out_5X5 + out_1X1pool`` channels at the
    same height and width as the input.

    Args:
        in_channels: Channels of the incoming feature map.
        out_1X1: Output channels of branch 1.
        red_3x3: Channels after the 1x1 reduction in branch 2.
        out_3X3: Output channels of branch 2.
        red_5x5: Channels after the 1x1 reduction in branch 3.
        out_5X5: Output channels of branch 3.
        out_1X1pool: Output channels of branch 4.
    """

    def __init__(self, in_channels, out_1X1, red_3x3, out_3X3, red_5x5, out_5X5, out_1X1pool):
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1 = conv_block(in_channels, out_1X1, kernel_size=1)

        # Branch 2: 1x1 reduction, then 3x3 convolution.
        reduce_for_3x3 = conv_block(in_channels, red_3x3, kernel_size=1)
        conv_3x3 = conv_block(red_3x3, out_3X3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_for_3x3, conv_3x3)

        # Branch 3: 1x1 reduction, then 5x5 convolution.
        reduce_for_5x5 = conv_block(in_channels, red_5x5, kernel_size=1)
        conv_5x5 = conv_block(red_5x5, out_5X5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce_for_5x5, conv_5x5)

        # Branch 4: size-preserving max-pool, then 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        pool_projection = conv_block(in_channels, out_1X1pool, kernel_size=1)
        self.branch4 = nn.Sequential(pool, pool_projection)

    def forward(self, x):
        """Run all four branches on ``x`` and stack the results channel-wise."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        # Each branch yields N x C_i x H x W; joining on dim 1 sums the C_i.
        branch_outputs = [branch(x) for branch in branches]
        return torch.cat(branch_outputs, dim=1)
        
        
        

In [34]:
class GoogleNet(nn.Module):
    """GoogLeNet-style classifier assembled from ``conv_block`` and ``Inception_block``.

    Layout: a two-convolution stem with max-pooling, nine inception blocks in
    three stages (3a-3b, 4a-4e, 5a-5b) separated by stride-2 max-pools, global
    average pooling, dropout (p=0.4) and one linear classifier.

    Args:
        in_channels: Channels of the input images (3 by default).
        num_classes: Number of logits produced per sample. Defaults to 1000,
            which matches the width the classifier layer previously had
            regardless of this argument, so ``GoogleNet()`` keeps its output
            shape and state-dict layout.

    Shape:
        Input ``(N, in_channels, H, W)`` -> output ``(N, num_classes)``.
        The shape comments below assume ``H = W = 224``; other sizes work as
        long as every pooling stage still receives a non-empty feature map.
    """

    def __init__(self, in_channels=3, num_classes=1000):
        super().__init__()
        # Stem.
        self.conv1 = conv_block(in_channels, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = conv_block(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Inception_block argument order:
        # in_channels, out_1X1, red_3x3, out_3X3, red_5x5, out_5X5, out_1X1pool.
        # Each block's in_channels equals the summed branch outputs of the
        # block before it.
        self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        # Global average pooling down to 1x1. For a 224x224 input the final
        # feature map is 7x7, so this averages over the same 7x7 window the
        # fixed AvgPool2d(kernel_size=7) did, while also accepting other sizes.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.4)
        # The classifier width follows num_classes instead of a fixed 1000.
        self.fc1 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for an image batch ``x``."""
        # Stem: 224 -> 112 -> 56 -> 56 -> 28 (spatial), ending with 192 channels.
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)

        # Stage 3: 192 -> 256 -> 480 channels, then 28 -> 14.
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        # Stage 4: 480 -> 512 -> 512 -> 512 -> 528 -> 832 channels, then 14 -> 7.
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        # Stage 5: 832 -> 832 -> 1024 channels.
        x = self.inception5a(x)
        x = self.inception5b(x)

        # Head: N x 1024 x 1 x 1 -> N x 1024 -> N x num_classes.
        x = self.avgpool(x)
        x = x.flatten(1)
        x = self.dropout(x)
        x = self.fc1(x)

        return x

In [35]:
# Smoke test: build a random batch of three 3x224x224 inputs, construct a
# GoogleNet with its default arguments, and display the output shape.
x = torch.randn(3,3,224,224)
model = GoogleNet()
model(x).shape

torch.Size([3, 1000])

In [36]:
# Report whether PyTorch can use CUDA in this environment.
torch.cuda.is_available()

True

In [37]:
# Display the random input batch (PyTorch abbreviates the repr of large tensors).
x

tensor([[[[-8.4624e-01, -3.5591e-02, -1.0491e+00,  ..., -1.8891e+00,
           -9.3265e-01,  2.1965e-01],
          [-9.6961e-01, -2.2419e-01,  9.9643e-02,  ..., -1.1354e-02,
            2.0183e+00,  5.8881e-01],
          [ 4.8330e-01,  3.3611e-01,  4.8518e-01,  ...,  6.6464e-01,
            1.3156e+00,  1.1234e+00],
          ...,
          [ 1.2189e+00,  5.4420e-01,  8.3786e-01,  ..., -8.6629e-01,
            6.7034e-01,  5.9399e-01],
          [ 5.0008e-01,  3.3639e-01, -1.0693e+00,  ..., -6.5303e-01,
           -5.4056e-01,  2.1344e+00],
          [-3.0497e-02, -4.3942e-01,  2.5284e-01,  ...,  4.9156e-01,
           -1.0191e+00, -3.0415e-01]],

         [[-1.1257e+00, -1.0395e-02,  3.0741e-01,  ...,  5.4842e-02,
            1.3133e+00, -5.3542e-01],
          [-5.0777e-02, -2.8677e-01,  2.4633e-01,  ..., -6.1921e-01,
            4.9188e-01, -1.2642e-01],
          [ 1.2388e+00,  1.2818e-01,  1.8932e-01,  ...,  2.9639e-01,
           -5.0251e-01,  1.0294e+00],
          ...,
     

In [39]:
# Prefer the GPU when PyTorch reports CUDA support; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device


'cuda'

In [40]:
# Display x placed on the selected device.
# NOTE(review): Tensor.to does not modify x in place; it returns a tensor on
# `device` (x itself when it is already there). The result is not assigned
# here, so the name x is not rebound — confirm that is intended.
x.to(device)

tensor([[[[-8.4624e-01, -3.5591e-02, -1.0491e+00,  ..., -1.8891e+00,
           -9.3265e-01,  2.1965e-01],
          [-9.6961e-01, -2.2419e-01,  9.9643e-02,  ..., -1.1354e-02,
            2.0183e+00,  5.8881e-01],
          [ 4.8330e-01,  3.3611e-01,  4.8518e-01,  ...,  6.6464e-01,
            1.3156e+00,  1.1234e+00],
          ...,
          [ 1.2189e+00,  5.4420e-01,  8.3786e-01,  ..., -8.6629e-01,
            6.7034e-01,  5.9399e-01],
          [ 5.0008e-01,  3.3639e-01, -1.0693e+00,  ..., -6.5303e-01,
           -5.4056e-01,  2.1344e+00],
          [-3.0497e-02, -4.3942e-01,  2.5284e-01,  ...,  4.9156e-01,
           -1.0191e+00, -3.0415e-01]],

         [[-1.1257e+00, -1.0395e-02,  3.0741e-01,  ...,  5.4842e-02,
            1.3133e+00, -5.3542e-01],
          [-5.0777e-02, -2.8677e-01,  2.4633e-01,  ..., -6.1921e-01,
            4.9188e-01, -1.2642e-01],
          [ 1.2388e+00,  1.2818e-01,  1.8932e-01,  ...,  2.9639e-01,
           -5.0251e-01,  1.0294e+00],
          ...,
     