# Train an MLP on CIFAR-10

## 1. import necessary packages

In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn, optim
import matplotlib.pyplot as plt

## 2. prepare dataset and dataloader

In [None]:
# CIFAR-10 train/test splits; the files are downloaded into ../data if they
# are not already there, and every image is converted with ToTensor().
train_data = torchvision.datasets.CIFAR10(root='../data',
                                          train=True,
                                          transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root='../data',
                                         train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)

# Dataset sizes, used later to report per-sample statistics.
train_data_size = len(train_data)
test_data_size = len(test_data)

# Reshuffle the training set every epoch so that the optimizer does not see
# the samples in the same fixed order each time (DataLoader does not shuffle
# by default). The test loader keeps its order: it is only used for
# evaluation, where order does not matter.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)

## 3. design network architecture

In [None]:
class MLP(nn.Module):
    """Fully connected classifier: flatten -> 3072 -> 512 -> 64 -> 10.

    Each hidden linear layer is followed by 1-D batch normalisation and a
    ReLU; the last linear layer produces 10 outputs per sample.
    """

    def __init__(self):
        super().__init__()
        in_features = 3 * 32 * 32
        hidden_widths = (512, 64)

        # Layers are created in forward order so that their positions inside
        # the Sequential container (and thus the state_dict keys) are fixed.
        layers = [nn.Flatten()]
        for width in hidden_widths:
            layers += [nn.Linear(in_features, width), nn.BatchNorm1d(width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, 10))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return its output."""
        return self.model(x)

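## 4. define model, loss function, and optimizer
- `loss function`: cross entropy
- `max epoch`: 50
- `learning rate`: 1e-3
- `optimizer`: SGD (momentum 0.9, weight decay 1e-3); an Adam configuration is kept as a commented-out alternative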

In [None]:
# Per-epoch history filled in by the training cell and read by the plotting
# cell: epoch numbers plus one accuracy list and one loss list per split.
plot_x = []
plot_train, plot_test = [], []
train_loss, test_loss = [], []

# Hyper-parameters.
max_epoch = 50
lr = 1e-3

# Start from a freshly initialised network; swap in the commented line below
# to load the object stored at ../model/best_mlp.pth instead.
model = MLP()
# model = torch.load("../model/best_mlp.pth")
loss_fn = nn.CrossEntropyLoss()

# SGD with momentum and weight decay is the active optimizer; the Adam
# configuration is kept as an alternative.
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=1e-3, momentum=0.9)
# optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

## 5. start training

In [None]:
def _evaluate(net, dataloader, criterion):
    """Return ``(mean_loss, accuracy_percent)`` of ``net`` over ``dataloader``.

    The network is put in evaluation mode and run without gradient tracking,
    so BatchNorm layers use their running statistics and do not update them.
    The caller is responsible for switching back to ``net.train()``.

    Assumes ``criterion`` returns the mean loss of the batch (the default
    reduction of ``nn.CrossEntropyLoss``); each batch is weighted by its size
    so that the result is a true per-sample average even when the last batch
    is smaller. Returns ``(nan, nan)`` for an empty dataloader.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for imgs, targets in dataloader:
            outputs = net(imgs)
            batch_size = targets.size(0)
            loss_sum += criterion(outputs, targets).item() * batch_size
            correct += (outputs.argmax(1) == targets).sum().item()
            seen += batch_size
    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen * 100


# Best (lowest) mean test loss seen so far; decides when to save the model.
min_loss = float("inf")
for epoch in range(1, 1 + max_epoch):
    # --- optimisation pass ---------------------------------------------
    # Training mode: BatchNorm normalises with batch statistics and updates
    # its running estimates.
    model.train()
    for i, (imgs, targets) in enumerate(train_dataloader):
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 200 == 0:
            print("Current step: {}/{}, Loss: {}".format(i, len(train_dataloader), loss.item()))

    # --- evaluation pass -----------------------------------------------
    # Both splits are scored with the weights reached at the end of the epoch.
    epoch_train_loss, epoch_train_acc = _evaluate(model, train_dataloader, loss_fn)
    epoch_test_loss, epoch_test_acc = _evaluate(model, test_dataloader, loss_fn)

    # Append to every history list together, only after both evaluations have
    # finished, so an interrupted epoch cannot leave the lists with different
    # lengths (which would break the plotting cell).
    plot_x.append(epoch)
    train_loss.append(epoch_train_loss)
    test_loss.append(epoch_test_loss)
    plot_train.append(epoch_train_acc)
    plot_test.append(epoch_test_acc)

    print("epoch: %d, train_loss: %.4f, test_loss: %.4f, train_accuracy: %.2f%%, test_accuracy: %.2f%%"
          % (epoch,
             epoch_train_loss,
             epoch_test_loss,
             epoch_train_acc,
             epoch_test_acc))

    # Keep the checkpoint with the lowest test loss. The whole module object
    # is saved (not just its state_dict), matching the torch.load call that
    # is commented out in the setup cell.
    if epoch_test_loss < min_loss:
        min_loss = epoch_test_loss
        torch.save(model, "../model/best_mlp.pth")
        print("best model saved!")


Current step: 400/782, Loss: 1.9022704362869263
Current step: 600/782, Loss: 1.8276209831237793
epoch: 5, train_loss: 0.0281, test_loss: 0.0285, train_accuracy: 32.01%, test_accuracy: 31.67%
best model saved!
Current step: 0/782, Loss: 1.9261943101882935
Current step: 200/782, Loss: 1.6582214832305908
Current step: 400/782, Loss: 1.752622365951538
Current step: 600/782, Loss: 1.6829006671905518
epoch: 6, train_loss: 0.0259, test_loss: 0.0266, train_accuracy: 40.05%, test_accuracy: 39.38%
best model saved!
Current step: 0/782, Loss: 1.790103554725647
Current step: 200/782, Loss: 1.466533899307251
Current step: 400/782, Loss: 1.6450697183609009
Current step: 600/782, Loss: 1.5248744487762451
epoch: 7, train_loss: 0.0238, test_loss: 0.0249, train_accuracy: 45.84%, test_accuracy: 43.96%
best model saved!
Current step: 0/782, Loss: 1.6557466983795166
Current step: 200/782, Loss: 1.3170733451843262
Current step: 400/782, Loss: 1.5416054725646973
Current step: 600/782, Loss: 1.487803339958191

KeyboardInterrupt: 

## 6. visualize curves

In [None]:
# Two stacked panels: accuracy curves on top, loss curves underneath.
fig1, fig2 = plt.subplot(2, 1, 1), plt.subplot(2, 1, 2)
print(plot_x, plot_train)

# (axes, ((series, colour, legend label), ...)) for each panel, drawn in order.
panels = (
    (fig1, ((plot_train, 'red', 'training data accuracy'),
            (plot_test, 'blue', 'test data accuracy'))),
    (fig2, ((train_loss, 'green', 'train data loss'),
            (test_loss, 'yellow', 'test data loss'))),
)
for axes, curves in panels:
    for series, colour, caption in curves:
        axes.plot(plot_x, series, c=colour, label=caption)
    axes.legend()

# Write the figure to disk before showing it.
plt.savefig("curve.jpg")
plt.show()