## Replicating the AlexNet architechture using PyTorch

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

The original paper used model parallelism using two GPUs. Since I am using a mac, the GPU is integrated and there is no way to specify the different cores of the GPU (turns out, _cores_ in a mac are a different from cores in a Nvidia or AMD GPU). I will split the model between the GPU and the CPU.

In [3]:
device1 = torch.device("mps")
device2 = torch.device("cpu")

In [3]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1_top = nn.Conv2d(
            in_channels=3,
            out_channels=96,
            kernel_size=11,
            stride=4
        ).to(device1)

        self.conv1_bottom = nn.Conv2d(
            in_channels=3,
            out_channels=96,
            kernel_size=11,
            stride=4
        ).to(device2)

        self.conv2_top = nn.Conv2d(
            in_channels=48,
            out_channels=256,
            kernel_size=5,
        ).to(device1)

        self.conv2_top = nn.Conv2d(
            in_channels=48,
            out_channels=256,
            kernel_size=5,
        ).to(device2)

        # local response normalization layer with 
        # hyperparameters as described in the paper
        self.lrn = nn.LocalResponseNorm(
            size=5, 
            alpha=0.0001, 
            beta=0.75, 
            k=2
        )

        self.pool = nn.MaxPool2d(
            kernel_size=kernel_size
        )

    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
