<a href="https://colab.research.google.com/github/KimJaehee0725/pytorch-studying/blob/main/Alexnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PyTorch

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
  from torchsummary import summary


class AlexNet(nn.Module):
    """AlexNet-style image classifier (Krizhevsky et al., 2012).

    Five convolutional layers (with ReLU and max pooling) extract features,
    followed by a three-layer fully connected classifier. Unlike the original
    paper, this implementation uses no local response normalization and no
    dropout.

    An adaptive average pool to a fixed 6x6 grid sits between the extractor
    and the classifier, so the first Linear layer always receives
    256 * 6 * 6 = 9216 features. For the canonical 227x227 input the extractor
    already yields a 6x6 map and the pool is an exact identity. Other spatial
    sizes (e.g. 224x224) are accepted as long as they are at least 67x67;
    smaller inputs make the last max pool produce an empty map.

    Args:
        n_classes: Number of output logits produced per sample.
        in_channels: Number of channels in the input images (3 for RGB).

    ``forward`` takes a tensor of shape ``(batch, in_channels, H, W)`` and
    returns unnormalized logits of shape ``(batch, n_classes)``.
    """

    def __init__(self, n_classes, in_channels=3):
        super().__init__()
        self.n_classes = n_classes

        self.extractor = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Pin the spatial size handed to the classifier. Without this, the
        # 9216-wide first Linear layer only matches a 227x227 input. The pool
        # has no parameters, so the state_dict is unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=self.n_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for images ``x``."""
        x = self.extractor(x)
        x = self.avgpool(x)
        # Keep the batch dimension; flatten channels x height x width.
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)

# Build a 1000-class AlexNet and print a layer-by-layer summary.
# summary() creates its dummy inputs on `device` but does not move the model,
# so place the model there too, and fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AlexNet(1000).to(device)
summary(model, input_size=(3, 227, 227), device=device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 55, 55]          34,944
              ReLU-2           [-1, 96, 55, 55]               0
         MaxPool2d-3           [-1, 96, 27, 27]               0
            Conv2d-4          [-1, 256, 27, 27]         614,656
              ReLU-5          [-1, 256, 27, 27]               0
         MaxPool2d-6          [-1, 256, 13, 13]               0
            Conv2d-7          [-1, 384, 13, 13]         885,120
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]       1,327,488
             ReLU-10          [-1, 384, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         884,992
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13            [-1, 256, 6, 6]               0
           Linear-14                 [-