In [1]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm
from torchsummary import summary

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
BATCH_SIZE = 128
MOMENTUM = 0.9
LR = 0.001
N_EPOCH = 30     # 논문에서는 90
WEIGHT_DECAY = 0.0005

In [4]:
cifar_train = torchvision.datasets.CIFAR100(root='./data/train/',
                                            train=True,
                                            download=True)

cifar_test = torchvision.datasets.CIFAR100(root='./data/test/',
                                           train=False, 
                                           download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/train/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/train/cifar-100-python.tar.gz to ./data/train/
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/test/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/test/cifar-100-python.tar.gz to ./data/test/


In [5]:
cifar_train2 = cifar_train.data / 255.

In [6]:
mean_R = np.mean(cifar_train2[..., 0])
mean_G = np.mean(cifar_train2[..., 1])
mean_B = np.mean(cifar_train2[..., 2])

std_R = np.std(cifar_train2[..., 0])
std_G = np.std(cifar_train2[..., 1])
std_B = np.std(cifar_train2[..., 2])

In [7]:
train_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Resize((224, 224)),
                                    #   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                      transforms.Normalize((mean_R, mean_G, mean_B), (std_R, std_G, std_B)),
                                      transforms.RandomHorizontalFlip()
                                     ])

test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Resize((224, 224)),
                                    #  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                     transforms.Normalize((mean_R, mean_G, mean_B), (std_R, std_G, std_B))
                                    ])

cifar_train.transform = train_transform
cifar_test.transform = test_transform

In [8]:
trainLoader = torch.utils.data.DataLoader(cifar_train,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True)

testLoader = torch.utils.data.DataLoader(cifar_test,
                                         batch_size=BATCH_SIZE,
                                         shuffle=False)

In [9]:
class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv_block = nn.Sequential(nn.Conv2d(3, 96, kernel_size=(11, 11), stride=4),
                                        nn.ReLU(),
                                        nn.LocalResponseNorm(size=5, alpha=1e-5, beta=0.75, k=2),
                                        nn.MaxPool2d(kernel_size=3, stride=2),
                                        nn.Conv2d(96, 256, kernel_size=(5, 5), padding='same'),
                                        nn.ReLU(),
                                        nn.LocalResponseNorm(size=5, alpha=1e-5, beta=0.75, k=2),
                                        nn.MaxPool2d(kernel_size=3, stride=2),
                                        nn.Conv2d(256, 384, kernel_size=(3, 3), padding='same'),
                                        nn.ReLU(),
                                        nn.Conv2d(384, 384, kernel_size=(3, 3), padding='same'),
                                        nn.ReLU(),
                                        nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1),
                                        nn.ReLU(),
                                        nn.MaxPool2d(kernel_size=3, stride=2)                     
                                      )
    
        self.classifier = nn.Sequential(nn.Dropout(p=0.5),
                                        nn.Linear(6400, 4096),
                                        nn.ReLU(),
                                        nn.Dropout(p=0.5),
                                        nn.Linear(4096, 4096),
                                        nn.ReLU(),
                                        nn.Linear(4096, 100)
                                      )

        # weight initializaiton
        for layer in self.conv_block:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, mean=0, std=0.01)

        # bias initialization
        nn.init.constant_(self.conv_block[0].bias, 0)      # Conv 1
        nn.init.constant_(self.conv_block[8].bias, 0)      # Conv 3

        nn.init.constant_(self.conv_block[4].bias, 1)      # Conv 2
        nn.init.constant_(self.conv_block[10].bias, 1)      # Conv 4
        nn.init.constant_(self.conv_block[12].bias, 1)      # Conv 5
        nn.init.constant_(self.classifier[1].bias, 1)      # fc 1
        nn.init.constant_(self.classifier[4].bias, 1)      # fc 2
        

    def forward(self, x):
        x = self.conv_block(x)
        x = torch.flatten(x, start_dim=1)
        out = self.classifier(x)
        return out

In [10]:
alexnet = AlexNet().to(device)

In [11]:
summary(alexnet, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 96, 54, 54]          34,944
              ReLU-2           [-1, 96, 54, 54]               0
 LocalResponseNorm-3           [-1, 96, 54, 54]               0
         MaxPool2d-4           [-1, 96, 26, 26]               0
            Conv2d-5          [-1, 256, 26, 26]         614,656
              ReLU-6          [-1, 256, 26, 26]               0
 LocalResponseNorm-7          [-1, 256, 26, 26]               0
         MaxPool2d-8          [-1, 256, 12, 12]               0
            Conv2d-9          [-1, 384, 12, 12]         885,120
             ReLU-10          [-1, 384, 12, 12]               0
           Conv2d-11          [-1, 384, 12, 12]       1,327,488
             ReLU-12          [-1, 384, 12, 12]               0
           Conv2d-13          [-1, 256, 12, 12]         884,992
             ReLU-14          [-1, 256,

In [12]:
## 논문에서 사용된 optimizer인데 학습이 안 됨...
# optimizer = torch.optim.SGD(alexnet.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
optimizer = torch.optim.Adam(alexnet.parameters(), lr=LR)

## 논문에서는 총 3번에 걸쳐 LR을 1/10로 줄임 
# lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

loss_fn = nn.CrossEntropyLoss()

In [13]:
def train_loop(model, dataLoader):
    train_loss = 0
    model.train()
    for image, label in tqdm(dataLoader):
        image, label = image.to(device), label.to(device)
        output = model(image)
        loss = loss_fn(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    return train_loss / len(dataLoader)

In [14]:
def val_loop(model, dataLoader):
    val_loss = 0
    val_acc = 0
    model.eval()
    with torch.no_grad():
        for image, label in tqdm(dataLoader):
            image, label = image.to(device), label.to(device)
            output = model(image)
            loss = loss_fn(output, label)

            val_loss += loss.item()
            val_acc += (torch.argmax(output, dim=1) == label).sum().item()
    
    return val_loss / len(dataLoader), val_acc / len(dataLoader.dataset) * 100

In [15]:
for epoch in range(1, N_EPOCH+1):
    print(f'[[ EPOCH {epoch} ]]')
    train_loss = train_loop(alexnet, trainLoader)
    val_loss, val_acc = val_loop(alexnet, testLoader)
    print('\nTrain Loss : {:.4f}, Val Loss : {:.4f}, Val Accuracy : {:.2f} %\n'.format(
           train_loss, val_loss, val_acc
    ))
    lr_scheduler.step()

[[ EPOCH 1 ]]


100%|██████████| 391/391 [04:18<00:00,  1.51it/s]
100%|██████████| 79/79 [00:23<00:00,  3.38it/s]



Train Loss : 4.6808, Val Loss : 4.0425, Val Accuracy : 6.71 %

[[ EPOCH 2 ]]


100%|██████████| 391/391 [04:15<00:00,  1.53it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 3.8928, Val Loss : 3.6515, Val Accuracy : 13.10 %

[[ EPOCH 3 ]]


100%|██████████| 391/391 [04:15<00:00,  1.53it/s]
100%|██████████| 79/79 [00:23<00:00,  3.39it/s]



Train Loss : 3.5661, Val Loss : 3.2724, Val Accuracy : 19.74 %

[[ EPOCH 4 ]]


100%|██████████| 391/391 [04:14<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.38it/s]



Train Loss : 3.2870, Val Loss : 3.1136, Val Accuracy : 23.78 %

[[ EPOCH 5 ]]


100%|██████████| 391/391 [04:14<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 3.0803, Val Loss : 2.8676, Val Accuracy : 27.99 %

[[ EPOCH 6 ]]


100%|██████████| 391/391 [04:13<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 2.9029, Val Loss : 2.7423, Val Accuracy : 30.94 %

[[ EPOCH 7 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:22<00:00,  3.44it/s]



Train Loss : 2.7752, Val Loss : 2.6381, Val Accuracy : 32.50 %

[[ EPOCH 8 ]]


100%|██████████| 391/391 [04:11<00:00,  1.56it/s]
100%|██████████| 79/79 [00:22<00:00,  3.47it/s]



Train Loss : 2.6685, Val Loss : 2.4574, Val Accuracy : 36.99 %

[[ EPOCH 9 ]]


100%|██████████| 391/391 [04:11<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 2.5633, Val Loss : 2.5453, Val Accuracy : 34.67 %

[[ EPOCH 10 ]]


100%|██████████| 391/391 [04:13<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 2.4807, Val Loss : 2.3987, Val Accuracy : 37.75 %

[[ EPOCH 11 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 2.2188, Val Loss : 2.1778, Val Accuracy : 42.52 %

[[ EPOCH 12 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 2.1120, Val Loss : 2.1147, Val Accuracy : 44.45 %

[[ EPOCH 13 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 2.0270, Val Loss : 2.0970, Val Accuracy : 44.58 %

[[ EPOCH 14 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 1.9608, Val Loss : 2.0303, Val Accuracy : 46.15 %

[[ EPOCH 15 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 1.8819, Val Loss : 2.0145, Val Accuracy : 46.60 %

[[ EPOCH 16 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:22<00:00,  3.46it/s]



Train Loss : 1.8254, Val Loss : 2.0204, Val Accuracy : 46.91 %

[[ EPOCH 17 ]]


100%|██████████| 391/391 [04:10<00:00,  1.56it/s]
100%|██████████| 79/79 [00:22<00:00,  3.47it/s]



Train Loss : 1.7618, Val Loss : 1.9934, Val Accuracy : 47.19 %

[[ EPOCH 18 ]]


100%|██████████| 391/391 [04:10<00:00,  1.56it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 1.7004, Val Loss : 1.9773, Val Accuracy : 48.05 %

[[ EPOCH 19 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 1.6429, Val Loss : 1.9275, Val Accuracy : 48.96 %

[[ EPOCH 20 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 1.6020, Val Loss : 1.9297, Val Accuracy : 49.27 %

[[ EPOCH 21 ]]


100%|██████████| 391/391 [04:13<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 1.3929, Val Loss : 1.8686, Val Accuracy : 50.64 %

[[ EPOCH 22 ]]


100%|██████████| 391/391 [04:13<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.42it/s]



Train Loss : 1.3039, Val Loss : 1.8537, Val Accuracy : 51.10 %

[[ EPOCH 23 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.42it/s]



Train Loss : 1.2414, Val Loss : 1.8607, Val Accuracy : 51.41 %

[[ EPOCH 24 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.42it/s]



Train Loss : 1.1890, Val Loss : 1.8544, Val Accuracy : 51.86 %

[[ EPOCH 25 ]]


100%|██████████| 391/391 [04:12<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 1.1359, Val Loss : 1.8614, Val Accuracy : 52.31 %

[[ EPOCH 26 ]]


100%|██████████| 391/391 [04:11<00:00,  1.55it/s]
100%|██████████| 79/79 [00:22<00:00,  3.48it/s]



Train Loss : 1.0871, Val Loss : 1.8634, Val Accuracy : 52.09 %

[[ EPOCH 27 ]]


100%|██████████| 391/391 [04:10<00:00,  1.56it/s]
100%|██████████| 79/79 [00:22<00:00,  3.50it/s]



Train Loss : 1.0361, Val Loss : 1.8710, Val Accuracy : 52.69 %

[[ EPOCH 28 ]]


100%|██████████| 391/391 [04:11<00:00,  1.55it/s]
100%|██████████| 79/79 [00:23<00:00,  3.41it/s]



Train Loss : 0.9929, Val Loss : 1.8753, Val Accuracy : 52.28 %

[[ EPOCH 29 ]]


100%|██████████| 391/391 [04:13<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]



Train Loss : 0.9507, Val Loss : 1.9087, Val Accuracy : 52.60 %

[[ EPOCH 30 ]]


100%|██████████| 391/391 [04:13<00:00,  1.54it/s]
100%|██████████| 79/79 [00:23<00:00,  3.40it/s]


Train Loss : 0.9206, Val Loss : 1.8911, Val Accuracy : 52.66 %




