In [0]:
import torch
import torch.nn as nn

In [0]:
class LRN(nn.Module):
  """Local Response Normalization (Caffe-style).

  Computes, element-wise,

      y = x / (k + alpha * mean_window(x ** 2)) ** beta

  The windowed mean is taken with an average-pooling layer whose zero padding
  is counted in the denominator (PyTorch's default ``count_include_pad=True``).
  So the mean is always ``window_sum / local_size`` across channels, and
  ``window_sum / local_size ** 2`` within a channel.

  Input is assumed to be a 4-D ``(N, C, H, W)`` tensor, as produced by the
  Conv2d/MaxPool2d layers that precede it in AlexNet. NOTE(review): other
  ranks are not supported by the unsqueeze/squeeze trick below.

  Args:
    local_size: window size. Must be a positive odd integer so that the
      symmetric padding keeps the pooled tensor the same shape as ``x``.
    alpha: scale applied to the windowed mean of squares.
    beta: exponent applied to the denominator.
    ACROSS_CHANNELS: if True, the window spans ``local_size`` neighbouring
      channels at each spatial position. If False, it spans a
      ``local_size x local_size`` spatial patch inside each channel.
    k: additive constant in the denominator. Defaults to 1.0, the value the
      original implementation hard-coded.

  Raises:
    ValueError: if ``local_size`` is not a positive odd integer.
  """

  def __init__(self, local_size = 1, alpha = 1.0, beta = 0.75, ACROSS_CHANNELS = True, k = 1.0):
    super(LRN, self).__init__()
    # An even window would shrink the pooled tensor by one element along the
    # pooled axis, causing a shape error or a silent broadcast in forward().
    if local_size < 1 or local_size % 2 != 1:
      raise ValueError(f"local_size must be a positive odd integer, got {local_size!r}")

    pad = (local_size - 1) // 2
    self.ACROSS_CHANNELS = ACROSS_CHANNELS
    if ACROSS_CHANNELS:
      # The channel axis is pooled as the "depth" axis of a 5-D (N, 1, C, H, W) view.
      self.average = nn.AvgPool3d(kernel_size = (local_size, 1, 1), stride = 1, padding = (pad, 0, 0))
    else:
      self.average = nn.AvgPool2d(kernel_size = local_size, stride = 1, padding = pad)

    self.alpha = alpha
    self.beta = beta
    self.k = k

  def forward(self, x):
    """Return ``x`` divided by its local-response normalizer (same shape as ``x``)."""
    squared = x.pow(2)
    if self.ACROSS_CHANNELS:
      # Insert a singleton dim so AvgPool3d slides its depth kernel over channels.
      local_mean = self.average(squared.unsqueeze(1)).squeeze(1)
    else:
      local_mean = self.average(squared)
    div = local_mean.mul(self.alpha).add(self.k).pow(self.beta)
    return x.div(div)

In [0]:
class AlexNet(nn.Module):
  """Single-tower AlexNet with 96/256/384/384/256 convolution filters.

  Layer order inside ``features`` is fixed so that ``state_dict`` keys
  (``features.0``, ``features.4``, ...) stay stable:

  * LRN is applied *after* max-pooling in blocks 1 and 2. This is the ordering
    used by Caffe's CaffeNet reference model; the AlexNet paper normalises
    before pooling.
  * A dropout layer sits at the end of ``features``, right before the
    classifier.

  The first classifier layer is hard-wired to ``256 * 6 * 6`` inputs, so the
  final feature maps must be 6x6. For 3x227x227 images the spatial size goes
  227 -> 55 (conv1) -> 27 (pool1) -> 27 (conv2) -> 13 (pool2)
  -> 13 (conv3-5) -> 6 (pool5).

  Args:
    num_classes: number of outputs of the final Linear layer.
  """

  def __init__(self, num_classes = 1000):
    super().__init__()
    lrn_args = dict(local_size = 5, alpha = 0.0001, beta = 0.75)

    feature_layers = [
        # block 1: 11x11 stride-4 conv -> ReLU -> 3x3/2 max-pool -> LRN
        nn.Conv2d(3, 96, kernel_size = 11, stride = 4, padding = 0),
        nn.ReLU(inplace = True),
        nn.MaxPool2d(kernel_size = 3, stride = 2),
        LRN(**lrn_args),
        # block 2: 5x5 'same' conv -> ReLU -> 3x3/2 max-pool -> LRN
        nn.Conv2d(96, 256, kernel_size = 5, stride = 1, padding = 2),
        nn.ReLU(inplace = True),
        nn.MaxPool2d(kernel_size = 3, stride = 2),
        LRN(**lrn_args),
    ]
    # blocks 3-5: three 3x3 'same' convolutions, each followed by ReLU
    for c_in, c_out in ((256, 384), (384, 384), (384, 256)):
      feature_layers += [
          nn.Conv2d(c_in, c_out, kernel_size = 3, stride = 1, padding = 1),
          nn.ReLU(inplace = True),
      ]
    feature_layers += [
        nn.MaxPool2d(kernel_size = 3, stride = 2),
        nn.Dropout(p = 0.5),
    ]
    self.features = nn.Sequential(*feature_layers)

    self.classifier = nn.Sequential(
        nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace = True), nn.Dropout(p = 0.5),
        nn.Linear(4096, 4096), nn.ReLU(inplace = True),
        nn.Linear(4096, num_classes),
    )

  def forward(self, x):
    """Map an image batch to ``(N, num_classes)`` logits."""
    feats = self.features(x)
    # Flatten to (N, 9216); view() raises unless each sample has exactly
    # 256 * 6 * 6 elements.
    flat = feats.view(feats.size(0), 256 * 6 * 6)
    return self.classifier(flat)

In [4]:
def alexnet(**kwargs):
  """Build an ``AlexNet``; every keyword argument is forwarded to its constructor."""
  return AlexNet(**kwargs)

# Last expression of the cell: Jupyter renders the model's module tree below.
alexnet(num_classes = 1000)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): LRN(
      (average): AvgPool3d(kernel_size=(5, 1, 1), stride=1, padding=(2, 0, 0))
    )
    (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): LRN(
      (average): AvgPool3d(kernel_size=(5, 1, 1), stride=1, padding=(2, 0, 0))
    )
    (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
   