## Packages

In [None]:
import torch.nn as nn
import torch
import torch.optim as optim

from PIL import ImageFile
from torch.utils.data import DataLoader
from torchvision import transforms
from tools import ImageNet, init_weights, init_bias, training_loop

## Configuration

In [None]:
# Allow PIL to load image files that are truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Mini-batch sizes (val_batch_size is not referenced by the cells shown here).
trn_batch_size, val_batch_size = 128, 100

# Number of epochs handed to training_loop.
epoch_count = 90

# SGD settings handed to optim.SGD in the training cell.
# NOTE(review): 1e-5 is far below the 0.01 starting rate reported in the
# AlexNet paper -- confirm this value is intentional.
learning_rate = 1e-5
momentum = 0.9
weight_decay = 5e-4

# Drop probability read by AlexNet for its Dropout layers.
dropout_ratio = 0.5

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Datasets

In [None]:
# Training-time preprocessing: force every image to 256x256, take a random
# 227x227 crop from it, then convert the result to a tensor.
imgnet_trn_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((227, 227)),
    transforms.ToTensor(),
])

# ImageNet is the project-local dataset wrapper imported from tools.
imgnet_trn_set = ImageNet(imgnet_trn_transform, is_train=True)

print("ImgNet set loaded...")

# Reshuffled mini-batches of trn_batch_size samples for the training loop.
imgnet_trn_loader = DataLoader(
    imgnet_trn_set,
    batch_size=trn_batch_size,
    shuffle=True,
)

## Structure

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style classifier: five convolutional stages and three FC stages.

    The first fully-connected stage (``fucn6``) is expressed as a 6x6
    convolution, so the network only accepts inputs whose ``conv5`` output is
    6x6 spatially -- e.g. 3x227x227 images (227 -> 55 -> 27 -> 13 -> 6).

    Args:
        dropout: Drop probability for the Dropout layers in ``fucn6`` and
            ``fucn7``. When ``None`` (the default) the module-level
            ``dropout_ratio`` setting is used, which keeps ``AlexNet()``
            behaving as it did before this parameter existed.
        num_classes: Number of output logits produced by ``fucn8``.
    """

    def __init__(self, dropout=None, num_classes=1000):
        super().__init__()
        if dropout is None:
            # Backward-compatible fallback to the notebook-level configuration.
            dropout = dropout_ratio

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # A 6x6 convolution over the 256x6x6 conv5 output collapses it to
        # 4096x1x1, i.e. it plays the role of the first fully-connected layer.
        self.fucn6 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=4096, kernel_size=6, stride=1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.fucn7 = nn.Sequential(
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.fucn8 = nn.Sequential(
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return the class logits for ``x``.

        Args:
            x: Image tensor with 3 channels whose spatial size reduces to 6x6
                after ``conv5`` (e.g. ``(N, 3, 227, 227)``).

        Returns:
            Tensor of shape ``(N, num_classes)``.

        Raises:
            ValueError: If the spatial size of ``x`` does not reduce to a
                single 4096-channel position after ``fucn6``.
        """
        output = self.conv1(x)
        output = self.conv2(output)
        output = self.conv3(output)
        output = self.conv4(output)
        output = self.conv5(output)
        output = self.fucn6(output)

        # view(-1, 4096) would otherwise fold any extra spatial positions into
        # the batch dimension and silently return more rows than samples.
        if tuple(output.shape[-3:]) != (4096, 1, 1):
            raise ValueError(
                "AlexNet expects inputs that reduce to a 4096x1x1 feature map "
                f"(e.g. 3x227x227 images); got feature shape {tuple(output.shape)} "
                f"for input shape {tuple(x.shape)}"
            )

        output = self.fucn7(output.view(-1, 4096))
        output = self.fucn8(output)

        return output

ax_net = AlexNet()

# Per-stage initial bias constant, listed in network order.
# NOTE(review): the AlexNet paper also uses bias 1 for the hidden
# fully-connected layers; here fucn6/fucn7 get 0 -- confirm this is intended.
_stage_biases = (
    (ax_net.conv1, 0),
    (ax_net.conv2, 1),
    (ax_net.conv3, 0),
    (ax_net.conv4, 1),
    (ax_net.conv5, 1),
    (ax_net.fucn6, 0),
    (ax_net.fucn7, 0),
    (ax_net.fucn8, 0),
)

# Every stage gets the same weight-init arguments (mean 0.0, std 0.01);
# all weights are initialised before any bias is touched.
for _stage, _ in _stage_biases:
    init_weights(_stage, mean=0.0, std=0.01)

for _stage, _bias in _stage_biases:
    init_bias(_stage, _bias)

# Move the model's parameters to the selected device; kept as the final
# expression so the notebook cell still displays the model.
ax_net.to(device)

## Training

In [None]:
# Plain SGD with momentum and L2 weight decay over all model parameters.
sgd_optimizer = optim.SGD(
    ax_net.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)
# Cross-entropy loss applied to the logits returned by the model.
criterion = nn.CrossEntropyLoss()

# training_loop is the project-local helper imported from tools.
training_loop(
    n_epochs=epoch_count,
    optimizer=sgd_optimizer,
    model=ax_net,
    loss_fn=criterion,
    dev=device,
    loader=imgnet_trn_loader,
)

# Persist only the parameters (state_dict), not the whole module object.
torch.save(ax_net.state_dict(), "AlexNet_Weights.pt")
print("AlexNet parameters saved...")