# AlexNet

Paper: https://www.cs.toronto.edu/~kriz/imagenet_classification_with_deep_convolutional.pdf
<img src=https://www.learnopencv.com/wp-content/uploads/2018/05/AlexNet-1.png width="600">

# Implementation

In [2]:
import os
import torch,torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

In [3]:
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000): #ImageNet base
        super(AlexNet,self).__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=0), #227->55
            nn.ReLU(),
            nn.LocalResponseNorm(size=5,k=2),
            nn.MaxPool2d(kernel_size=3,stride=2), #55->27

            nn.Conv2d(96,256,5,1,2),
            nn.ReLU(),
            nn.LocalResponseNorm(5,k=2),
            nn.MaxPool2d(3,2), #27->13

            nn.Conv2d(256,384,3,1,1),
            nn.ReLU(),
            nn.LocalResponseNorm(5,k=2),

            nn.Conv2d(384,384,3,1,1),
            nn.ReLU(),
            nn.LocalResponseNorm(5,k=2),

            nn.Conv2d(384,256,3,1,1),
            nn.ReLU(),
            nn.LocalResponseNorm(5,k=2),
            nn.MaxPool2d(3,2) #13->6
        )
        self.fc=nn.Sequential(
            nn.Linear(6*6*256,4096),nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096,4096),nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096,num_classes)
        )

    def forward(self,x):
        x=self.conv(x)
        x=torch.flatten(x,1) #(B,C,H,W)
        x=self.fc(x)
        return x

    def init_bias(self):
        for i,layer in enumerate(self.conv):
            if isinstance(layer,nn.Conv2d):
                nn.init.normal_(layer.weight,mean=0,std=0.01)
                nn.init.constant_(layer.bias,0 if i in [0,8] else 1) #set bias to 0 for 1st&3rd conv layers

In [3]:
def alexnet(my=True,pretrained=False,progress=True,**kwargs):
    if my:
        model=AlexNet(**kwargs)
    else:
        model=models.alexnet(pretrained=pretrained,progress=progress,**kwargs)
    return model

# Run

In [4]:
import import_ipynb
from ImageLoader import SimpleLoader
import tnt

importing Jupyter notebook from ImageLoader.ipynb
importing Jupyter notebook from tnt.ipynb


In [9]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE=128
VALID_SIZE=0.1
NUM_EPOCHS=30 #90
PATH='checkpoint/AlexNet'

## Train

In [10]:
sl=SimpleLoader(dataset='CIFAR10',batch_size=BATCH_SIZE,crop_size=227,split=True)
num_classes=sl.GetNumClasses()
train_loader=sl.GetTrainLoader()
valid_loader=sl.GetValidLoader()

Files already downloaded and verified
Files already downloaded and verified


In [11]:
model=alexnet(num_classes=num_classes).to(device)
criterion=nn.CrossEntropyLoss().to(device)
#optimizer=optim.SGD(params=model.parameters(),lr=0.01,momentum=0.9,weight_decay=0.0005)
optimizer=optim.Adam(params=model.parameters(),lr=0.0001) #SGD does not converge well
#...divide the learning rate by 10 when the validation error rate stopped improving with the current learning rate.
#instead of following the heuristic strategy of the paper, use StepLR scheduler
scheduler=optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)

In [12]:
model.init_bias()
tnt.train(model,device,NUM_EPOCHS,train_loader,valid_loader,criterion,optimizer,scheduler,save=1,path=PATH)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  1 	 train_loss: 1.95730 	 top1_acc: 38.94% 	 top5_acc: 88.75%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  2 	 train_loss: 1.55427 	 top1_acc: 46.65% 	 top5_acc: 91.53%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  3 	 train_loss: 1.37374 	 top1_acc: 50.88% 	 top5_acc: 93.01%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  4 	 train_loss: 1.22847 	 top1_acc: 56.78% 	 top5_acc: 94.92%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  5 	 train_loss: 1.07983 	 top1_acc: 61.47% 	 top5_acc: 95.09%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  6 	 train_loss: 0.95665 	 top1_acc: 65.13% 	 top5_acc: 96.18%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  7 	 train_loss: 0.84544 	 top1_acc: 67.41% 	 top5_acc: 96.65%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  8 	 train_loss: 0.75472 	 top1_acc: 70.09% 	 top5_acc: 97.22%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch  9 	 train_loss: 0.67195 	 top1_acc: 70.99% 	 top5_acc: 97.14%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 10 	 train_loss: 0.59622 	 top1_acc: 72.10% 	 top5_acc: 97.39%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 11 	 train_loss: 0.53688 	 top1_acc: 72.06% 	 top5_acc: 97.41%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 12 	 train_loss: 0.46926 	 top1_acc: 73.09% 	 top5_acc: 97.78%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 13 	 train_loss: 0.41085 	 top1_acc: 74.22% 	 top5_acc: 97.66%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 14 	 train_loss: 0.36053 	 top1_acc: 73.56% 	 top5_acc: 97.94%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 15 	 train_loss: 0.31496 	 top1_acc: 72.70% 	 top5_acc: 97.45%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 16 	 train_loss: 0.28151 	 top1_acc: 74.16% 	 top5_acc: 97.64%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 17 	 train_loss: 0.23890 	 top1_acc: 75.02% 	 top5_acc: 97.72%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 18 	 train_loss: 0.21736 	 top1_acc: 74.12% 	 top5_acc: 97.62%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 19 	 train_loss: 0.18950 	 top1_acc: 74.77% 	 top5_acc: 97.68%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 20 	 train_loss: 0.16994 	 top1_acc: 75.29% 	 top5_acc: 97.62%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 21 	 train_loss: 0.15673 	 top1_acc: 75.37% 	 top5_acc: 97.94%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 22 	 train_loss: 0.14499 	 top1_acc: 75.33% 	 top5_acc: 97.72%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 23 	 train_loss: 0.13502 	 top1_acc: 74.88% 	 top5_acc: 97.76%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 24 	 train_loss: 0.11618 	 top1_acc: 75.33% 	 top5_acc: 97.25%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 25 	 train_loss: 0.11522 	 top1_acc: 75.29% 	 top5_acc: 97.88%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 26 	 train_loss: 0.10562 	 top1_acc: 74.18% 	 top5_acc: 97.96%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 27 	 train_loss: 0.10217 	 top1_acc: 75.02% 	 top5_acc: 97.64%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 28 	 train_loss: 0.09495 	 top1_acc: 74.53% 	 top5_acc: 97.57%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 29 	 train_loss: 0.09140 	 top1_acc: 75.58% 	 top5_acc: 97.70%


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=351.0), HTML(value='')))

Epoch 30 	 train_loss: 0.08904 	 top1_acc: 75.93% 	 top5_acc: 97.86%


## Test

In [13]:
model=alexnet(num_classes=num_classes).to(device)
model.load_state_dict(torch.load(os.path.join(PATH,'CIFAR10_e30_best')))

<All keys matched successfully>

In [14]:
test_loader=sl.GetTestLoader()
tnt.test(model,device,test_loader,criterion)

loss: 0.96674 	 top1_acc: 78.28% 	 top5_acc: 98.08%
