In [4]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F

## ReLU Non-Linearity

In [5]:
# Element-wise max(0, x). inplace=False (the default) returns a new tensor
# instead of overwriting the input.
reLU = nn.ReLU(inplace=False)

## Local Response Normalization

In [6]:
# Cross-channel local response normalisation with n=5, alpha=1e-4, beta=0.75, k=2.
# NOTE(review): PyTorch computes a_c / (k + alpha/size * sum a^2)^beta, i.e. it
# divides alpha by `size`. The AlexNet paper's formula does not, so this is
# alpha/5 relative to the paper. Confirm which convention is intended.
lrn = nn.LocalResponseNorm(5, alpha=1e-4, beta=0.75, k=2)

## Overlapping Max Pooling

In [7]:
# Window 3 > stride 2, so neighbouring pooling windows share one row/column.
# This is the "overlapping" max pooling (z=3, s=2).
ompool = nn.MaxPool2d(3, 2)

## Complete Architecture

In [18]:
class AlexNet(nn.Module):
    """Single-tower AlexNet: five convolutions followed by three linear layers.

    Forward layout::

        conv1 -> act -> norm -> pool
        conv2 -> act -> norm -> pool
        conv3 -> act
        conv4 -> act
        conv5 -> act -> pool
        flatten -> fc1 -> act -> dropout
                -> fc2 -> act -> dropout
                -> fc3                      (raw logits, no softmax)

    Args:
        activation: Module applied after every conv layer and after fc1/fc2.
            ``None`` (default) selects ``nn.ReLU()``.
        pool: Module applied after conv1, conv2 and conv5. ``None`` selects
            overlapping max pooling ``nn.MaxPool2d(kernel_size=3, stride=2)``.
        norm: Module applied after the activations of conv1 and conv2.
            ``None`` selects ``nn.LocalResponseNorm(size=5, alpha=1e-4,
            beta=0.75, k=2)``.
        dropout: Module applied after the activations of fc1 and fc2.
            ``None`` selects ``nn.Dropout(p=0.5)``.
        C{i}_in / C{i}_out / C{i}_kernel / C{i}_stride / C{i}_padding:
            ``nn.Conv2d`` settings for conv{i}.
        FC{i}_in / FC{i}_out: ``nn.Linear`` sizes for fc{i}. ``FC1_in`` must
            equal the flattened size of the pooled conv5 output. With the
            default settings and a 224x224 input that is 256 * 6 * 6 = 9216.

    Notes:
        The same activation / pool / norm / dropout instance is reused at
        every position where it appears, so pass modules that are safe to
        share (the defaults are).
        Only ``None`` triggers a default. Any module passed explicitly is kept,
        even one that is falsy (e.g. an empty ``nn.Sequential()``).
    """

    def __init__(self,
                 activation=None,
                 pool=None,
                 norm=None,
                 dropout=None,
                 C1_in=3, C1_out=96, C1_kernel=11, C1_stride=4, C1_padding=2,
                 C2_in=96, C2_out=256, C2_kernel=5, C2_stride=1, C2_padding=2,
                 C3_in=256, C3_out=384, C3_kernel=3, C3_stride=1, C3_padding=1,
                 C4_in=384, C4_out=384, C4_kernel=3, C4_stride=1, C4_padding=1,
                 C5_in=384, C5_out=256, C5_kernel=3, C5_stride=1, C5_padding=1,
                 FC1_in=9216, FC1_out=4096,
                 FC2_in=4096, FC2_out=4096,
                 FC3_in=4096, FC3_out=1000):
        super().__init__()
        # `is None` rather than `or`: nn.Module truthiness is not reliable
        # (containers with __len__ == 0 are falsy and would be replaced).
        self.activation = nn.ReLU() if activation is None else activation
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2) if pool is None else pool
        self.norm = (nn.LocalResponseNorm(k=2, alpha=1e-4, beta=0.75, size=5)
                     if norm is None else norm)
        self.dropout = nn.Dropout(p=0.5) if dropout is None else dropout

        self.conv1 = nn.Conv2d(in_channels=C1_in, out_channels=C1_out,
                               kernel_size=C1_kernel, stride=C1_stride,
                               padding=C1_padding)
        self.conv2 = nn.Conv2d(in_channels=C2_in, out_channels=C2_out,
                               kernel_size=C2_kernel, stride=C2_stride,
                               padding=C2_padding)
        self.conv3 = nn.Conv2d(in_channels=C3_in, out_channels=C3_out,
                               kernel_size=C3_kernel, stride=C3_stride,
                               padding=C3_padding)
        self.conv4 = nn.Conv2d(in_channels=C4_in, out_channels=C4_out,
                               kernel_size=C4_kernel, stride=C4_stride,
                               padding=C4_padding)
        self.conv5 = nn.Conv2d(in_channels=C5_in, out_channels=C5_out,
                               kernel_size=C5_kernel, stride=C5_stride,
                               padding=C5_padding)
        self.fc1 = nn.Linear(in_features=FC1_in, out_features=FC1_out)
        self.fc2 = nn.Linear(in_features=FC2_in, out_features=FC2_out)
        self.fc3 = nn.Linear(in_features=FC3_in, out_features=FC3_out)
        self._init_weights()

    def _init_weights(self):
        """Initialise the eight conv/linear layers as in the AlexNet paper.

        Every weight is drawn from N(0, 0.01^2). Biases are set to 1 for
        conv2, conv4, conv5 and the hidden linear layers fc1/fc2, and to 0
        for conv1, conv3 and the output layer fc3. Layers are visited in
        definition order, so a given torch seed yields the same weights as
        before.
        """
        layers = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                  self.fc1, self.fc2, self.fc3)
        # Per the paper, a positive bias feeds the ReLUs positive inputs early
        # in training. The output layer fc3 is not a hidden layer, so it gets 0.
        bias_one = (self.conv2, self.conv4, self.conv5, self.fc1, self.fc2)
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            if layer.bias is not None:
                fill = 1.0 if any(layer is b for b in bias_one) else 0.0
                nn.init.constant_(layer.bias, fill)

    def forward(self, X):
        """Map a batch of images to class logits.

        Args:
            X: Batched image tensor ``(N, C1_in, H, W)``. For the default
                ``FC1_in=9216``, H = W = 224.

        Returns:
            Tensor of shape ``(N, FC3_out)`` holding unnormalised logits.
        """
        X = self.pool(self.norm(self.activation(self.conv1(X))))
        X = self.pool(self.norm(self.activation(self.conv2(X))))
        X = self.activation(self.conv3(X))
        X = self.activation(self.conv4(X))
        X = self.pool(self.activation(self.conv5(X)))

        # flatten (unlike .view) also works when X is not contiguous,
        # e.g. channels_last inputs.
        X = torch.flatten(X, 1)

        X = self.dropout(self.activation(self.fc1(X)))
        X = self.dropout(self.activation(self.fc2(X)))
        return self.fc3(X)

## Augmentation

In [19]:
import torchvision
import torchvision.transforms as transforms

In [20]:
class LightingPCA:
    """AlexNet-style PCA colour ("lighting") augmentation.

    On each call draws ``alpha ~ N(0, alpha_std^2)`` (one value per principal
    component) and adds the RGB offset

        eigvecs @ (alpha * eigvals)

    to every pixel, so the same offset is applied at all spatial positions.
    The columns of ``eigvecs`` are the principal directions paired with
    ``eigvals``. The values below are the widely used ImageNet RGB statistics.
    NOTE(review): they are presumably computed on [0, 1]-scaled pixels, which
    matches applying this after ``ToTensor()``. Confirm if the source matters.

    Args:
        alpha_std: Standard deviation of ``alpha``. ``0`` disables the
            transform, and the input is returned unchanged.
    """

    def __init__(self, alpha_std=0.1):
        self.alpha_std = alpha_std
        self.eigvals = torch.tensor([0.2175, 0.0188, 0.0045])
        self.eigvecs = torch.tensor([
            [-0.5675,  0.7192,  0.4009],
            [-0.5808, -0.0045, -0.8140],
            [-0.5836, -0.6948,  0.4203]
        ])

    def __call__(self, img):
        """Return ``img`` plus a random per-channel RGB offset.

        Args:
            img: Tensor whose channel axis is third from last and has size 3,
                e.g. ``(3, H, W)`` or ``(B, 3, H, W)``. It is expected to be a
                float image in [0, 1]. The result is not clamped and may fall
                slightly outside that range.

        Returns:
            ``img`` itself when ``alpha_std == 0``; otherwise a new tensor of
            the same shape on the same device. Floating-point images keep
            their dtype.
        """
        if self.alpha_std == 0:
            return img

        # Fresh noise for each principal component, std = self.alpha_std.
        alpha = torch.normal(mean=0.0, std=self.alpha_std, size=(3,))

        # rgb[c] = sum_j eigvecs[c, j] * alpha[j] * eigvals[j], shaped to
        # broadcast over (..., 3, H, W).
        rgb = (self.eigvecs @ (alpha * self.eigvals)).view(3, 1, 1)

        # Move the (CPU, float32) offset to the image's device so CUDA inputs
        # work. Match a floating image's dtype so e.g. float16 is not upcast.
        # Non-float images keep float32 so the offset isn't truncated.
        target_dtype = img.dtype if img.is_floating_point() else rgb.dtype
        return img + rgb.to(device=img.device, dtype=target_dtype)

In [21]:
# Training pipeline: rescale so the shorter side is 256, take a random 224x224
# patch, mirror it with probability 0.5, convert to a float tensor, then add
# PCA lighting noise (std 0.01 here).
train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        LightingPCA(alpha_std=0.01),
    ]
)


def _ten_crops_to_tensor(crops):
    """Stack the 10 PIL crops produced by TenCrop into one tensor.

    For RGB input the result is (10, 3, 224, 224).
    """
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(crop) for crop in crops])


# Evaluation pipeline: same rescale, then the four corner crops, the centre
# crop and their horizontal mirrors (TenCrop), stacked along a new leading axis.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.TenCrop(224),  # yields a tuple of 10 PIL images
        transforms.Lambda(_ten_crops_to_tensor),
    ]
)

In [23]:
# Shape smoke test: one random 224x224 RGB image must yield 1000 logits.
torch.manual_seed(0)  # reproducible input and weight initialisation
x = torch.randn(1, 3, 224, 224)

model = AlexNet(
    activation=None,
    pool=None,
    norm=None,
    dropout=None
)
model.eval()  # disable dropout so the forward pass is deterministic

with torch.no_grad():
    out = model(x)

assert out.shape == (1, 1000), f"unexpected output shape {tuple(out.shape)}"

print("Input:", x.shape)
print("Output:", out.shape)


Input: torch.Size([1, 3, 224, 224])
Output: torch.Size([1, 1000])
