In [15]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
from torchsummary import summary

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
BATCH_SIZE = 128
MOMENTUM = 0.9
LR = 0.001
N_EPOCH = 10     # 논문에서는 90
WEIGHT_DECAY = 0.0005

In [18]:
cifar_train = torchvision.datasets.CIFAR100(root='./data/train/',
                                            train=True,
                                            download=True)

cifar_test = torchvision.datasets.CIFAR100(root='./data/test/',
                                           train=False, 
                                           download=True)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
mean_R = np.mean(cifar_train.data[..., 0])
mean_G = np.mean(cifar_train.data[..., 1])
mean_B = np.mean(cifar_train.data[..., 2])

std_R = np.std(cifar_train.data[..., 0])
std_G = np.std(cifar_train.data[..., 1])
std_B = np.std(cifar_train.data[..., 2])

In [20]:
train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Resize((224, 224)),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                      transforms.RandomHorizontalFlip()
                                     ])

test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Resize((224, 224)),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                    ])

cifar_train.transform = train_transform
cifar_test.transform = test_transform

In [21]:
trainLoader = torch.utils.data.DataLoader(cifar_train,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True)

testLoader = torch.utils.data.DataLoader(cifar_test,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

In [22]:
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv_block = nn.Sequential(nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4),
                                        nn.ReLU(),
                                        nn.LocalResponseNorm(size=5, alpha=1e-5, beta=0.75, k=2),
                                        nn.MaxPool2d(kernel_size=3, stride=2),
                                        nn.Conv2d(96, 256, kernel_size=(5, 5), padding='same'),
                                        nn.ReLU(),
                                        nn.LocalResponseNorm(size=5, alpha=1e-5, beta=0.75, k=2),
                                        nn.MaxPool2d(kernel_size=3, stride=2),
                                        nn.Conv2d(256, 384, kernel_size=(3, 3), padding='same'),
                                        nn.ReLU(),
                                        nn.Conv2d(384, 384, kernel_size=(3, 3), padding='same'),
                                        nn.ReLU(),
                                        nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1),
                                        nn.ReLU(),
                                        nn.MaxPool2d(kernel_size=3, stride=2)                     
                                      )
    
        self.classifier = nn.Sequential(nn.Dropout2d(p=0.5),
                                        nn.Linear(6400, 4096),
                                        nn.ReLU(),
                                        nn.Dropout2d(p=0.5),
                                        nn.Linear(4096, 4096),
                                        nn.ReLU(),
                                        nn.Linear(4096, 100)
                                      )

        # weight initializaiton
        for layer in self.conv_block:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, mean=0, std=0.01)

        # bias initialization
        nn.init.constant_(self.conv_block[0].bias, 0)      # Conv 1
        nn.init.constant_(self.conv_block[8].bias, 0)      # Conv 3

        nn.init.constant_(self.conv_block[4].bias, 1)      # Conv 2
        nn.init.constant_(self.conv_block[10].bias, 1)      # Conv 4
        nn.init.constant_(self.conv_block[12].bias, 1)      # Conv 5
        nn.init.constant_(self.classifier[1].bias, 1)      # fc 1
        nn.init.constant_(self.classifier[4].bias, 1)      # fc 2
        

    def forward(self, x):
        x = self.conv_block(x)
        x = torch.flatten(x, start_dim=1)
        out = self.classifier(x)
        return out

In [23]:
alexnet = AlexNet().to(device)

In [24]:
summary(alexnet, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          34,944
              ReLU-2           [-1, 96, 54, 54]               0
 LocalResponseNorm-3           [-1, 96, 54, 54]               0
         MaxPool2d-4           [-1, 96, 26, 26]               0
            Conv2d-5          [-1, 256, 26, 26]         614,656
              ReLU-6          [-1, 256, 26, 26]               0
 LocalResponseNorm-7          [-1, 256, 26, 26]               0
         MaxPool2d-8          [-1, 256, 12, 12]               0
            Conv2d-9          [-1, 384, 12, 12]         885,120
             ReLU-10          [-1, 384, 12, 12]               0
           Conv2d-11          [-1, 384, 12, 12]       1,327,488
             ReLU-12          [-1, 384, 12, 12]               0
           Conv2d-13          [-1, 256, 12, 12]         884,992
             ReLU-14          [-1, 256,

In [25]:
## 논문에서 사용된 optimizer인데 학습이 안 됨...
# optimizer = torch.optim.SGD(alexnet.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

## 논문에서는 총 3번에 걸쳐 LR을 1/10로 줄임 
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

optimizer = torch.optim.Adam(alexnet.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [26]:
def train_loop(model, dataLoader):
    train_loss = 0
    model.train()
    for image, label in tqdm(dataLoader):
        image, label = image.to(device), label.to(device)
        output = model(image)
        loss = loss_fn(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    return train_loss / len(dataLoader)

In [27]:
def val_loop(model, dataLoader):
    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for image, label in tqdm(dataLoader):
            image, label = image.to(device), label.to(device)
            output = model(image)
            loss = loss_fn(output, label)

            val_loss += loss.item()
            val_acc += (torch.argmax(output, dim=1) == label).sum().item()
    
    return val_loss / len(dataLoader), val_acc / len(dataLoader.dataset) * 100

In [28]:
for epoch in range(1, N_EPOCH+1):
    print(f'[[ EPOCH {epoch} ]]')
    train_loss = train_loop(alexnet, trainLoader)
    val_loss, val_acc = val_loop(alexnet, testLoader)
    print('\nTrain Loss : {:.4f}, Val Loss : {:.4f}, Val Accuracy : {:.2f} %\n'.format(
           train_loss, val_loss, val_acc
    ))

[[ EPOCH 1 ]]


100%|██████████| 391/391 [04:59<00:00,  1.31it/s]
100%|██████████| 79/79 [00:31<00:00,  2.50it/s]



Train Loss : 4.7012, Val Loss : 4.0186, Val Accuracy : 6.57 %

[[ EPOCH 2 ]]


100%|██████████| 391/391 [05:01<00:00,  1.30it/s]
100%|██████████| 79/79 [00:32<00:00,  2.45it/s]



Train Loss : 3.9197, Val Loss : 3.6827, Val Accuracy : 11.94 %

[[ EPOCH 3 ]]


100%|██████████| 391/391 [05:02<00:00,  1.29it/s]
100%|██████████| 79/79 [00:32<00:00,  2.45it/s]



Train Loss : 3.6172, Val Loss : 3.4200, Val Accuracy : 16.93 %

[[ EPOCH 4 ]]


100%|██████████| 391/391 [05:00<00:00,  1.30it/s]
100%|██████████| 79/79 [00:31<00:00,  2.47it/s]



Train Loss : 3.3716, Val Loss : 3.1707, Val Accuracy : 22.24 %

[[ EPOCH 5 ]]


100%|██████████| 391/391 [04:59<00:00,  1.31it/s]
100%|██████████| 79/79 [00:32<00:00,  2.47it/s]



Train Loss : 3.1582, Val Loss : 2.9314, Val Accuracy : 26.90 %

[[ EPOCH 6 ]]


100%|██████████| 391/391 [05:00<00:00,  1.30it/s]
100%|██████████| 79/79 [00:31<00:00,  2.48it/s]



Train Loss : 2.9695, Val Loss : 2.7380, Val Accuracy : 30.64 %

[[ EPOCH 7 ]]


100%|██████████| 391/391 [04:59<00:00,  1.30it/s]
100%|██████████| 79/79 [00:32<00:00,  2.47it/s]



Train Loss : 2.8242, Val Loss : 2.5952, Val Accuracy : 32.91 %

[[ EPOCH 8 ]]


100%|██████████| 391/391 [04:58<00:00,  1.31it/s]
100%|██████████| 79/79 [00:31<00:00,  2.49it/s]



Train Loss : 2.6824, Val Loss : 2.5162, Val Accuracy : 34.95 %

[[ EPOCH 9 ]]


100%|██████████| 391/391 [04:59<00:00,  1.31it/s]
100%|██████████| 79/79 [00:32<00:00,  2.46it/s]



Train Loss : 2.5933, Val Loss : 2.4133, Val Accuracy : 36.60 %

[[ EPOCH 10 ]]


100%|██████████| 391/391 [05:00<00:00,  1.30it/s]
100%|██████████| 79/79 [00:32<00:00,  2.47it/s]


Train Loss : 2.4729, Val Loss : 2.3949, Val Accuracy : 37.73 %




