In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms

from tqdm import tqdm

In [2]:
device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available()
                      else "cpu")
device

device(type='mps')

In [None]:
alexnet = models.alexnet(weights=models.AlexNet_Weights.DEFAULT).to(device)
alexnet.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(9216, out_features=4096),
    nn.ReLU(True),
    nn.Dropout(p=0.5),
    nn.Linear(4096, 2048),
    nn.ReLU(True),
    nn.Linear(2048, 10)
)
alexnet = alexnet.to(device)

optimizer = optim.Adam(alexnet.parameters(), lr=3e-4)

criterion = nn.CrossEntropyLoss()

In [4]:
alexnet

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [5]:
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))  # ImageNet 統計
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                         (0.229, 0.224, 0.225))
])

In [6]:
train = datasets.CIFAR10("../../data", train=True, transform=train_transform,
                         download=True)
test = datasets.CIFAR10("../../data", train=False, transform=test_transform,
                        download=True)

train_dataloader = DataLoader(train, batch_size=512, shuffle=True,
                              generator=torch.Generator().manual_seed(42))

test_dataloader = DataLoader(test, batch_size=512, shuffle=False,
                             generator=torch.Generator().manual_seed(42))

In [7]:
num_epochs = 10

alexnet.train()
for epoch in range(1, num_epochs + 1):
    running_loss = 0.0
    for imgs, labels in tqdm(train_dataloader):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        preds = alexnet(imgs)
        loss = criterion(preds, labels)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
    print(f"{epoch}/{num_epochs} Loss: {running_loss}")

100%|██████████| 98/98 [01:45<00:00,  1.07s/it]


1/10 Loss: 62678.895126104355


100%|██████████| 98/98 [01:47<00:00,  1.10s/it]


2/10 Loss: 235.21741580963135


100%|██████████| 98/98 [01:47<00:00,  1.10s/it]


3/10 Loss: 97189.52403092384


  6%|▌         | 6/98 [00:06<01:46,  1.15s/it]


KeyboardInterrupt: 

In [None]:
alexnet.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for X, y in test_dataloader:
        X, y = X.to(device), y.to(device)
        logits = alexnet(X)

        pred = torch.argmax(logits, dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    print(f"Acc: {correct / total}")