In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchsummary import summary

from torch.utils import data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# ReLU activation (Section 3.1) - shortens training time
# LRN (Section 3.3) - reduces overfitting
# Overlapping Pooling (Section 3.4) - reduces overfitting
# Dropout (Section 4.2) - reduces overfitting
# Bias Initialization (Section 5) - positive inputs to the ReLUs -> faster learning in the early stages
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutional layers followed by three fully-connected layers.

    Args:
        num_classes: Number of output logits produced by the last linear layer.

    Input:  float tensor of shape (b x 3 x 227 x 227).
    Output: float tensor of shape (b x num_classes) holding unnormalised logits.
    """

    def __init__(self, num_classes=1000): # Section 3.5
        # Input size: (b x 3 x 227 x 227)
        # The paper states an input size of 224, but 227 is what actually
        # produces the 55 x 55 feature map after the first convolution.
        super().__init__()

        # NOTE(review): nn.LocalResponseNorm divides alpha by `size` inside the
        # normalisation term, so alpha=0.0001 here is not numerically the same
        # as an alpha of 1e-4 in a formula without that division (that would be
        # alpha=5e-4 here) - confirm which one is intended.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4), # (b x 96 x 55 x 55)
            nn.ReLU(), # Section 3.1
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2), # Section 3.3
            nn.MaxPool2d(kernel_size=3, stride=2), # Section 3.4 (b x 96 x 27 x 27)
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2), # (b x 256 x 27 x 27)
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2), # (b x 256 x 13 x 13)
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1), # (b x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1), # (b x 384 x 13 x 13)
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1), # (b x 256 x 13 x 13)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2) # (b x 256 x 6 x 6)
        )

        self.fc = nn.Sequential(
            nn.Dropout(p=0.5), # Section 4.2
            nn.Linear(in_features=(256 * 6 * 6), out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

        self.init_wb()

    def init_wb(self): # Section 5
        """Initialise weights and biases following Section 5 of the paper.

        Every conv / linear weight is drawn from N(0, 0.01^2). Biases are set
        to 1 in the 2nd, 4th and 5th convolutional layers and in the hidden
        fully-connected layers, and to 0 everywhere else (including the
        output layer).
        """
        # Select conv layers by their position among the convolutions instead
        # of by raw nn.Sequential index, so inserting or removing a
        # non-conv module cannot silently shift the targets.
        conv_layers = [layer for layer in self.conv if isinstance(layer, nn.Conv2d)]
        for position, layer in enumerate(conv_layers, start=1):
            nn.init.normal_(layer.weight, mean=0, std=0.01)
            nn.init.constant_(layer.bias, 1 if position in (2, 4, 5) else 0)

        linear_layers = [layer for layer in self.fc if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.01)
            # Only the hidden FC layers get bias 1; the last (output) layer
            # belongs to the "remaining layers" and gets bias 0.
            is_hidden = layer is not linear_layers[-1]
            nn.init.constant_(layer.bias, 1 if is_hidden else 0)

    def forward(self, x):
        """Run the conv stack, flatten everything but the batch dim, then classify."""
        x = self.conv(x)
        x = torch.flatten(x, 1) # (b x 256 x 6 x 6) -> (b x 9216)
        x = self.fc(x)
        return x

In [10]:
# Details of Learning (Section 5)
# weight decay: reduces the training error
# Build the network first, then move it to the selected device.
model = AlexNet()
model = model.to(device)

# Multi-class classification objective.
criterion = nn.CrossEntropyLoss()

# SGD with momentum and weight decay over every model parameter.
optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)

# Multiply the learning rate by 0.1 every 30 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Print a per-layer overview for a batch of 128 images of size 3 x 227 x 227.
summary(model, (3, 227, 227), batch_size=128)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [128, 96, 55, 55]          34,944
              ReLU-2          [128, 96, 55, 55]               0
 LocalResponseNorm-3          [128, 96, 55, 55]               0
         MaxPool2d-4          [128, 96, 27, 27]               0
            Conv2d-5         [128, 256, 27, 27]         614,656
              ReLU-6         [128, 256, 27, 27]               0
 LocalResponseNorm-7         [128, 256, 27, 27]               0
         MaxPool2d-8         [128, 256, 13, 13]               0
            Conv2d-9         [128, 384, 13, 13]         885,120
             ReLU-10         [128, 384, 13, 13]               0
           Conv2d-11         [128, 384, 13, 13]       1,327,488
             ReLU-12         [128, 384, 13, 13]               0
           Conv2d-13         [128, 256, 13, 13]         884,992
             ReLU-14         [128, 256,

In [None]:
# Data Augmentation (Section 4.1) - reduces overfitting
# Not implemented: mean-subtraction-only preprocessing, PCA color augmentation, 10-crop testing
TRAIN_DIR = ''
VAL_DIR = ''
TEST_DIR = ''


def _tensor_and_normalize():
    """Return the tail shared by every pipeline: tensor conversion + per-channel normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]


def _eval_transform():
    """Deterministic pipeline: resize to 256, central 227 crop, then the shared tail."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(227),
        *_tensor_and_normalize(),
    ])


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the settings common to all three splits."""
    return data.DataLoader(
        dataset,
        batch_size=128,
        shuffle=shuffle,
        num_workers=8,
        pin_memory=True
    )


# Training split: random 227 crop and random horizontal flip, reshuffled each epoch.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(227),
    transforms.RandomHorizontalFlip(),
    *_tensor_and_normalize(),
])
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transform)
train_loader = _make_loader(train_dataset, shuffle=True)

# Validation split: deterministic preprocessing, fixed order.
val_transform = _eval_transform()
val_dataset = datasets.ImageFolder(VAL_DIR, transform=val_transform)
val_loader = _make_loader(val_dataset, shuffle=False)

# Test split: same preprocessing as validation, fixed order.
test_transform = _eval_transform()
test_dataset = datasets.ImageFolder(TEST_DIR, transform=test_transform)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Epoch (Section 5) - Roughly 90
NUM_EPOCHS = 90

# Best validation accuracy seen so far; decides when "best_alexnet.pth" is rewritten.
best_acc = 0
for epoch in range(NUM_EPOCHS):
    # ---- training pass ----
    model.train()

    train_loss = 0
    correct = 0
    total = 0

    for img, cls in train_loader:
        img, cls = img.to(device), cls.to(device)

        output = model(img)
        loss = criterion(output, cls)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that dividing by `total`
        # below gives a per-sample average even when the last batch is smaller
        # (assumes `criterion` returns a batch mean - TODO confirm).
        train_loss += loss.item() * img.size(0)
        _, preds = torch.max(output, 1)
        correct += (preds == cls).sum().item()
        total += cls.size(0)

    train_loss /= total
    train_acc = correct / total

    # ---- validation pass (no gradients, eval mode) ----
    model.eval()

    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for img, cls in val_loader:
            img, cls = img.to(device), cls.to(device)

            output = model(img)
            loss = criterion(output, cls)

            val_loss += loss.item() * img.size(0)
            _, preds = torch.max(output, 1)
            correct += (preds == cls).sum().item()
            total += cls.size(0)

    val_loss /= total
    val_acc = correct / total

    # Metrics are only reported every 10th epoch.
    if (epoch + 1) % 10 == 0:
        print(f'Epoch: {epoch + 1}/{NUM_EPOCHS} - Train Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f} - Val Loss {val_loss:.4f} - Val Acc {val_acc:.4f}')

    # Keep a weights-only copy of the model with the best validation accuracy.
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_alexnet.pth")
        print("Best Model Updated")

    # Advance the LR schedule BEFORE writing the resumable checkpoint, so the
    # saved optimizer learning rate and scheduler counter both describe the
    # state in which the next epoch has to start.
    lr_scheduler.step()

    # Resumable checkpoint. 'epoch' is the index of the epoch that just
    # finished; the scheduler state is stored as well, because a freshly
    # created scheduler would restart its step count after a resume.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        'best_acc': best_acc
    }, "alexnet_checkpoint.pth")

In [None]:
# Restore the weights with the best validation accuracy. map_location remaps
# the stored tensors onto the current device, so a file written on a GPU
# machine can still be loaded where CUDA is not available.
model.load_state_dict(torch.load("best_alexnet.pth", map_location=device))
model.to(device)

# Evaluation mode for the test pass.
model.eval()

test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for img, cls in test_loader:
        img, cls = img.to(device), cls.to(device)

        output = model(img)
        loss = criterion(output, cls)

        # Weight the batch loss by the batch size so the final division by
        # `total` gives a per-sample average
        # (assumes `criterion` returns a batch mean - TODO confirm).
        test_loss += loss.item() * img.size(0)
        _, preds = torch.max(output, 1)
        correct += (preds == cls).sum().item()
        total += cls.size(0)

test_loss /= total
test_acc = correct / total

print(f'Test Loss {test_loss:.4f} - Test Acc {test_acc:.4f}')