**Image recognition and classification with the AlexNet convolutional neural network**
---
---

**Complete code:**
---

In [None]:
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """AlexNet-style deep convolutional network for image classification.

    The feature extractor is five convolutions with ReLU activations and
    three max-pooling stages, followed by a flatten; the classifier is three
    fully connected layers producing one logit per class.

    Args:
        in_channels: Number of channels in the input images. Defaults to 1.
        num_classes: Number of output logits. Defaults to 10.

    Note:
        The first fully connected layer expects ``5 * 5 * 256`` features,
        which is what the convolutional stack produces for 224x224 inputs.
        Other spatial sizes will fail at the first ``nn.Linear``.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        # Convolutional feature extractor (ends with Flatten so that
        # ``self.conv`` already yields a 2-D [batch, features] tensor).
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            # Three consecutive 3x3 convolutions.
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Flatten(),
        )

        # Fully connected classifier head.
        self.fc = nn.Sequential(
            nn.Linear(5 * 5 * 256, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass and return the class logits.

        Args:
            x: Input batch; expected as [batch, in_channels, 224, 224].

        Returns:
            Tensor of shape [batch, num_classes].
        """
        x_flatten = self.conv(x)
        return self.fc(x_flatten)

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
from torch.utils.data import DataLoader

# Preprocessing applied to every sample: convert to a tensor, then resize
# to 224x224.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ]
)

# Fashion-MNIST train / test splits, downloaded into ../data when missing.
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data",
    train=True,
    transform=trans,
    download=True,
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data",
    train=False,
    transform=trans,
    download=True,
)

# NOTE(review): as a non-final expression in the cell, this value is
# presumably not echoed by the notebook.
len(mnist_train), len(mnist_test)

# Mini-batch loaders: training data is reshuffled, test data keeps its order.
train_dataloader = DataLoader(
    dataset=mnist_train,
    batch_size=256,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=mnist_test,
    batch_size=256,
    shuffle=False,
)

# Number of batches in each loader.
len(train_dataloader), len(test_dataloader)

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Build the network, move it onto the selected device, and echo its layers.
model_AlexNet = AlexNet()
model_AlexNet = model_AlexNet.to(device)
model_AlexNet

In [None]:
def count_parameters(model_AlexNet: nn.Module):
    """Return the total number of scalar parameters held by the model."""
    total = 0
    for param in model_AlexNet.parameters():
        total += param.numel()
    return total


# Echo the total parameter count of the current model_AlexNet instance.
count_parameters(model_AlexNet)

In [None]:
# Pull one batch from the training loader to inspect the tensor shapes.
first_batch = next(iter(train_dataloader))
x, y = first_batch
x.shape, y.shape

In [None]:
def train_model(model_AlexNet, train_dataloader, loss_func, optimizer):
    """Train the model for one pass over the dataloader.

    Args:
        model_AlexNet: Network to train; it is switched to training mode.
        train_dataloader: Iterable of ``(x, y)`` batches supporting ``len()``.
        loss_func: Callable ``loss_func(predictions, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        The mean of the per-batch loss values over this pass.
    """
    model_AlexNet.train()

    # Send batches to wherever the model's weights live instead of relying
    # on a module-level ``device`` variable that may not match the model.
    first_param = next(model_AlexNet.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.
    for x, y in train_dataloader:
        # In this notebook x is [batch, 1, 224, 224] and y is [batch].
        x = x.to(model_device)
        y = y.to(model_device)

        y_hat = model_AlexNet(x)
        loss = loss_func(y_hat, y)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(train_dataloader)

def test_model(model_AlexNet, test_dataloader, loss_func):
    """ 模型测试函数 """
    model_AlexNet.eval()

    y_true = 0
    total_loss = 0.
    for x, y in test_dataloader:
        # x: [bs, 1, 224, 224]
        # y: [batch_size]
        y_hat = model_AlexNet(x.to(device))
        loss = loss_func(y_hat, y.to(device))

        y_true += (y == torch.argmax(y_hat, dim=-1)).sum().item()   #很重要的一步，原理是是布尔值求和，True=1,False=0

        total_loss += loss.item()

    avg_loss = total_loss / len(test_dataloader)
    acc = round(y_true / len(test_dataloader.dataset), 3)  #round是四舍五入函数
    return avg_loss, acc

In [None]:
# Cross-entropy objective for the multi-class classification task.
loss_func = nn.CrossEntropyLoss()

# Fresh, untrained network placed on the selected device.
model_AlexNet = AlexNet().to(device)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model_AlexNet.parameters(), lr=0.001)

In [None]:
# Number of full passes over the training set.
n_epoch = 5

# Per-epoch history; the loss lists feed the loss plot, and the accuracy
# returned by test_model is now kept as well instead of being discarded.
train_loss_list = []
test_loss_list = []
test_acc_list = []
for i in range(n_epoch):
    train_loss = train_model(model_AlexNet, train_dataloader, loss_func, optimizer)
    test_loss, acc = test_model(model_AlexNet, test_dataloader, loss_func)

    train_loss_list.append(train_loss)
    test_loss_list.append(test_loss)
    test_acc_list.append(acc)

    # Report every metric with its epoch so the log is readable on its own.
    print(
        f"epoch {i + 1}/{n_epoch} | "
        f"train loss {train_loss:.4f} | "
        f"test loss {test_loss:.4f} | "
        f"test acc {acc:.3f}"
    )

In [None]:
import matplotlib.pyplot as plt
import os
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several libraries bundle their own copy;
# confirm it is still needed in this environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# Draw the per-epoch training and test loss curves on a single set of axes.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(train_loss_list, label="train loss")
ax.plot(test_loss_list, label="test loss")
ax.set_title("Model Loss")
ax.grid()
ax.legend()
plt.show()