In [None]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

#### Create the model architecture from the AlexNet paper (Krizhevsky et al., 2012)

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier.

    Args:
        no_of_classes: Number of target classes; sets the width of the final
            linear layer.

    ``forward`` returns raw, unnormalised class scores (logits) of shape
    ``(batch, no_of_classes)``. No softmax is applied here because
    ``nn.CrossEntropyLoss`` applies log-softmax itself; feeding it
    probabilities would normalise twice.

    NOTE(review): layer widths and paddings are kept as originally written and
    differ from the paper in places (e.g. the 4th conv layer has 192 filters
    here, the paper uses 384) -- confirm whether that is intentional.
    """

    def __init__(self, no_of_classes):
        super(AlexNet,self).__init__()
        #construct the cnn layers with sequential
        self.convs = nn.Sequential(
            #1st conv layer
            nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=1),
            nn.LocalResponseNorm(size=5),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.ReLU(),
            #2nd conv layer
            nn.Conv2d(in_channels=96,out_channels=256,kernel_size=5),
            nn.LocalResponseNorm(size=5),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.ReLU(),
            #3rd conv layer
            nn.Conv2d(in_channels=256,out_channels=384,kernel_size=3),
            nn.ReLU(),
            #4th conv layer
            nn.Conv2d(in_channels=384,out_channels=192,kernel_size=3),
            nn.ReLU(),
            #5th conv layer
            nn.Conv2d(in_channels=192,out_channels=256,kernel_size=3),
            nn.MaxPool2d(kernel_size=3,stride=2),
            nn.ReLU()
        )
        # Force a 2x2 feature map so the flattened size is always
        # 256 * 2 * 2 = 1024, matching the first dense layer. This is a no-op
        # when the map is already 2x2 (e.g. 227x227 inputs); a 224x224 input
        # would otherwise reach the classifier as a 1x1 map (256 features).
        # Kept outside ``classifier`` so the layer indices below stay unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))
        self.classifier=nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.5),
            #1st dense layer
            nn.Linear(in_features=1024,out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            #2nd dense layer
            nn.Linear(in_features=4096,out_features=4096),
            nn.ReLU(),
            #3rd dense layer: one logit per class
            nn.Linear(in_features=4096,out_features=no_of_classes)
        )
        self.init_parameter()

    def init_parameter(self):
        """Initialise weights and biases following the scheme quoted from the paper."""
        #We initialized the weights in each layer from a zero-mean Gaussian distribution with standard deviation 0.01
        for layer in self.modules():
            # Only conv/linear layers own weight and bias tensors; ReLU,
            # pooling, LRN, dropout and flatten have nothing to initialise.
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(layer.weight,mean=0,std=0.01)
                nn.init.constant_(layer.bias,0)
        # Bias of 1 for the 2nd, 4th and 5th conv layers and the two hidden
        # dense layers; every other layer (including the output layer) keeps 0.
        for layer in (self.convs[4], self.convs[10], self.convs[12],
                      self.classifier[2], self.classifier[5]):
            nn.init.constant_(layer.bias,1)

    def forward(self,x):
        """Run a batch of images through the network.

        Args:
            x: Image batch with 3 channels, i.e. ``(batch, 3, H, W)``.

        Returns:
            Logits of shape ``(batch, no_of_classes)``.
        """
        #feed forward
        x=self.convs(x)
        x=self.avgpool(x)
        x=self.classifier(x)
        return x

In [None]:
# Configuration: everything a reader might want to tune lives in this cell.

# -- hardware --
GPUS = [0]  # device ids handed to DataParallel
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# -- paths (fill in before running) --
TRAIN_DIR = ''
VAL_DIR = ''
CHECKPOINT_DIR = '/'

# -- data --
NO_CLASSES = 1000
IMG_DIM = 224  # side length of the centre crop
BATCH_SIZE = 128

# -- optimisation --
EPOCH = 90  # training epochs
L_RATE = 1e-2
MOMENTUM = 0.9
W_DECAY = 5e-4

In [None]:
# Capture the seed of torch's default RNG so it can be stored with each checkpoint.
seed = torch.random.initial_seed()

In [None]:
# Build the network, then move its parameters onto the selected device.
model = AlexNet(no_of_classes=NO_CLASSES)
model = model.to(device)

In [None]:
# Wrap the model so batches are split across the device ids listed in GPUS.
model = nn.DataParallel(model, device_ids=GPUS)
print(model)

### Other aspects: data pipeline, optimizer, and training loop

In [None]:
# Image preprocessing / augmentation pipeline:
#   centre crop to IMG_DIM -> random horizontal flip -> tensor -> per-channel normalisation
data_transform = transforms.Compose(
    [
        transforms.CenterCrop(size=IMG_DIM),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
#prepare the dataset
# Validation preprocessing: same crop and normalisation as training but with
# no random flip, so evaluation is deterministic. Without a transform
# ImageFolder yields PIL images, which the DataLoader cannot stack into batches.
val_transform = transforms.Compose([
    transforms.CenterCrop(IMG_DIM),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])
])
train_dataset = datasets.ImageFolder(TRAIN_DIR,transform=data_transform)
val_dataset = datasets.ImageFolder(VAL_DIR,transform=val_transform)
train_loader = DataLoader(train_dataset, shuffle=True,batch_size=BATCH_SIZE)
# Sample order does not affect evaluation, so the validation set is not shuffled.
val_loader = DataLoader(val_dataset, shuffle=False,batch_size=BATCH_SIZE)

In [None]:
# Loss function (criterion) used for training.
loss = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over all model parameters.
optim = torch.optim.SGD(
    params=model.parameters(),
    lr=L_RATE,
    momentum=MOMENTUM,
    weight_decay=W_DECAY,
)

# Multiply the learning rate by 0.1 every 50 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, step_size=50, gamma=0.1)

# Global batch counter, starting at 1.
total_step = 1

In [None]:
#training
# Make sure the checkpoint directory exists before the first save.
if CHECKPOINT_DIR:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model.train()  # enable dropout for training
for epoch in range(EPOCH):
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optim.zero_grad()  # clear gradients left over from the previous batch
        pred = model(X)  # forward pass
        # Keep `loss` bound to the criterion: store the per-batch value under a
        # separate name so the criterion is still callable on the next batch.
        batch_loss = loss(pred, y)
        batch_loss.backward()  # backward pass
        optim.step()  # parameter update
        if total_step % 10 == 0:
            # .item() extracts the Python float instead of printing the tensor repr.
            print(f'step:{total_step} | Loss: {batch_loss.item():.4f}')
        total_step += 1

    # Advance the schedule once per epoch so StepLR's step_size counts epochs.
    lr_scheduler.step()

    # Save everything needed to inspect or restore this epoch's state.
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f'model_checkpoint{epoch+1}.pkl')
    state = {
        'epoch': epoch,
        'total_step': total_step,
        'optimizer': optim.state_dict(),
        'model': model.state_dict(),
        'seed': seed
    }
    torch.save(state, checkpoint_path)