<a href="https://colab.research.google.com/github/EggPudding/Deep-Learning-Paper-with-Codes/blob/main/AlexNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

### AlexNet 모델의 정의 및 초기화
---

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.backends.cudnn as cudnn
import torch.optim as optim

import os

import torchsummary

class AlexNet(nn.Module):
  """ AlexNet Implementation
  Original Paper: https://proceedings.neurips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html

  원 논문에서 사용하는 GPU Paralell은 사용하지 않았습니다.
  추가로 MNIST를 대상으로 하므로 모델의 구조를 간단히 수정하였습니다.

  """
  def __init__(self, num_classes=10):
    super(AlexNet, self).__init__()

    self.feature_extract = nn.Sequential(
      nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),
      nn.Conv2d(64, 128, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),
      nn.Conv2d(128, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),     
      nn.MaxPool2d(kernel_size=3, stride=2),  
    )


    self.classifier = nn.Sequential(
      nn.Dropout(),
      nn.Linear(256 * 2 * 2, 512),
      nn.ReLU(inplace=True),
      nn.Dropout(),
      nn.Linear(512, num_classes),       
    )

  def forward(self, x):
    x = self.feature_extract(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x

### 학습 환경 설명 및 학습(Train) 함수 정의
---

In [39]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AlexNet()
model.to(device)
model = torch.nn.DataParallel(model)

cudnn.benchmark = True

learning_rate = 0.01
batch_size = 128
max_epoch = 10

model_path = 'alexnet_mnist.pt'

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [40]:
torchsummary.summary(model, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             640
              ReLU-2           [-1, 64, 28, 28]               0
         MaxPool2d-3           [-1, 64, 13, 13]               0
            Conv2d-4          [-1, 128, 13, 13]          73,856
              ReLU-5          [-1, 128, 13, 13]               0
         MaxPool2d-6            [-1, 128, 6, 6]               0
            Conv2d-7            [-1, 256, 6, 6]         295,168
              ReLU-8            [-1, 256, 6, 6]               0
         MaxPool2d-9            [-1, 256, 2, 2]               0
          Dropout-10                 [-1, 1024]               0
           Linear-11                  [-1, 512]         524,800
             ReLU-12                  [-1, 512]               0
          Dropout-13                  [-1, 512]               0
           Linear-14                   

### 학습 (Train) & 검증 (Validation) 함수 정의
---

In [50]:
def train(epoch, max_epoch):
  print(f"Train Epoch [{epoch}/{max_epoch}]")
  model.train()

  train_loss = 0
  correct = 0
  total = 0
  acc = 0

  for idx, (x, y) in enumerate(train_dataloader):
    x = x.to(device)
    y = y.to(device)

    optimizer.zero_grad()

    y_pred = model(x)
    loss = criterion(y_pred, y)

    loss.backward()

    optimizer.step()
    train_loss += loss.item()
    _, inference = y_pred.max(1)

    total += x.size(0)
    correct += inference.eq(y).sum().item()

    if idx % 100 == 0:
      print(f"Epoch [{epoch}/{max_epoch}] Batch [{idx}] Train Loss: {loss.item()}")

  acc = 100*correct/total
  print(f"Epoch [{epoch}/{max_epoch}] Train Loss: {train_loss/total} Train Accuracy: {acc}")

def valid(epoch, max_epoch):
  print(f"Valid Epoch [{epoch}/{max_epoch}]")
  model.eval()
  
  valid_loss = 0
  correct = 0
  total = 0

  for idx, (x, y) in enumerate(valid_dataloader):
    x = x.to(device)
    y = y.to(device)

    with torch.no_grad():
      y_pred = model(x)
      valid_loss += criterion(y_pred, y).item()

      _, inference = y_pred.max(1)

      total += x.size(0)
      correct += inference.eq(y).sum().item()

  acc = 100*correct/total
  print(f"Epoch [{epoch}/{max_epoch}] Valid Loss: {valid_loss/total} Valid Accuracy: {acc}")

  if not os.path.exists('checkpoint'):
    os.mkdir('checkpoint')
  
  torch.save(model.state_dict(), f'checkpoint/{model_path}')
  print(f"Epoch [{epoch}/{max_epoch}] Valid Model Saved: checkpoint/{model_path}")

def lr_schedule(optimizer, epoch):
  if epoch == 5:
    lr = learning_rate / 10
    for param_group in optimizer.param_groups:
      param_group['lr'] = lr
  

### MNIST 데이터셋 다운로드 및 불러오기
---

In [51]:
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train)
valid_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

### 학습 (Train) 진행
---

In [52]:
for epoch in range(0, max_epoch):
  lr_schedule(optimizer, epoch)
  train(epoch, max_epoch)
  valid(epoch, max_epoch)

Train Epoch [0/10]
Epoch [0/10] Batch [0] Train Loss: 0.2873525619506836
Epoch [0/10] Batch [100] Train Loss: 0.21464131772518158
Epoch [0/10] Batch [200] Train Loss: 0.15093854069709778
Epoch [0/10] Batch [300] Train Loss: 0.17556892335414886
Epoch [0/10] Batch [400] Train Loss: 0.2250879406929016
Epoch [0/10] Train Loss: 0.0018991753535345197 Train Accuracy: 93.645
Valid Epoch [0/10]
Epoch [0/10] Valid Loss: 0.0009041530852671712 Valid Accuracy: 96.83
Epoch [0/10] Valid Model Saved: checkpoint/alexnet_mnist.pt
Train Epoch [1/10]
Epoch [1/10] Batch [0] Train Loss: 0.1981702744960785
Epoch [1/10] Batch [100] Train Loss: 0.29971346259117126
Epoch [1/10] Batch [200] Train Loss: 0.24769625067710876
Epoch [1/10] Batch [300] Train Loss: 0.26893478631973267
Epoch [1/10] Batch [400] Train Loss: 0.13225844502449036
Epoch [1/10] Train Loss: 0.001721328614745289 Train Accuracy: 94.10166666666667
Valid Epoch [1/10]
Epoch [1/10] Valid Loss: 0.0008588661203742958 Valid Accuracy: 96.91
Epoch [1/10] 