In [5]:
import os

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from torchvision import datasets
from torchvision.transforms import ToTensor

In [4]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [8]:
# dataset
root_data = os.path.abspath('../../data')
print(f'dataset path: {root_data}')

# Build the training split first, then the test split; the two calls differ
# only in the `train` flag.
train_dataset, test_dataset = (
    datasets.FashionMNIST(root=root_data, train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

dataset path: F:\XudongDuan\workspace\Common\AIFramework\data


In [9]:
# data loader
bs = 64

# Reshuffle the training set at the start of every epoch so SGD does not see
# the mini-batches in the same fixed order each time. Evaluation results do
# not depend on order, so the test loader stays sequential.
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=bs)

# test loader: peek at the first batch only to confirm tensor shapes
for x, y in test_dataloader:
    print(f'shape of x: {x.shape}')
    print(f'shape of y: {y.shape}')
    break

shape of x: torch.Size([64, 1, 28, 28])
shape of y: torch.Size([64])


In [25]:
# model
class MLP(nn.Module):
    """Fully connected classifier producing 10 logits per sample.

    Each sample is flattened to 28*28 = 784 features and passed through two
    512-unit hidden stages (Linear -> BatchNorm1d -> ReLU) followed by a
    final linear layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # Layers are created in forward order; their positions inside the
        # Sequential determine the parameter names in the state dict.
        layers = [
            nn.Linear(28*28, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.main_body = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every dimension after the first and return the raw logits."""
        return self.main_body(self.flatten(x))
    
# Build the network, move its parameters to the selected device, and show
# the layer layout.
model = MLP()
model = model.to(device)
print(model)

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (main_body): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [26]:
# optimizer
lr = 1.0e-3
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
# Each scheduler step multiplies the current learning rate by gamma.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

In [27]:
# train and test
def train(data_loader, model, optimizer, scheduler, loss_fn, device):
    """Run one optimisation pass over every batch in ``data_loader``.

    Args:
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Module to optimise; it is switched to training mode.
        optimizer: Optimizer that updates the parameters of ``model``.
        scheduler: Accepted for interface symmetry; not used in this function.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        device: Device the batches are moved to before the forward pass.
    """
    model.train()
    for batch, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear stale gradients, then forward pass, backward pass and update.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        # Log only on every 1000th batch (the first batch included).
        if batch % 1000 != 0:
            continue
        loss = batch_loss.item()
        lr = optimizer.param_groups[0]['lr']
        print(f"LR: {lr}, loss: {loss:>7f}")

def test(data_loader, model, loss_fn, device):
    model.eval()
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [28]:
# One full cycle per epoch: optimise on the training set, evaluate on the
# test set, then step the learning-rate scheduler.
epoch = 7
for i in range(epoch):
    print(f'Epoch: {i}')
    train(
        data_loader=train_dataloader,
        model=model,
        optimizer=optimizer,
        scheduler=lr_scheduler,
        loss_fn=loss,
        device=device,
    )
    test(
        data_loader=test_dataloader,
        model=model,
        loss_fn=loss,
        device=device,
    )
    lr_scheduler.step()

Epoch: 0
LR: 0.001, loss: 2.453668
Test Error: 
 Accuracy: 78.5%, Avg loss: 0.667944 

Epoch: 1
LR: 0.0009000000000000001, loss: 0.640635
Test Error: 
 Accuracy: 81.2%, Avg loss: 0.557258 

Epoch: 2
LR: 0.0008100000000000001, loss: 0.507010
Test Error: 
 Accuracy: 82.4%, Avg loss: 0.510056 

Epoch: 3
LR: 0.000729, loss: 0.445112
Test Error: 
 Accuracy: 83.3%, Avg loss: 0.483115 

Epoch: 4
LR: 0.0006561000000000001, loss: 0.408429
Test Error: 
 Accuracy: 83.9%, Avg loss: 0.465304 

Epoch: 5
LR: 0.00059049, loss: 0.383853
Test Error: 
 Accuracy: 84.3%, Avg loss: 0.452446 

Epoch: 6
LR: 0.000531441, loss: 0.365316
Test Error: 
 Accuracy: 84.5%, Avg loss: 0.442748 



In [29]:
# save
pth_path = os.path.join('../../model', 'model.pth')
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing the checkpoint.
os.makedirs(os.path.dirname(pth_path), exist_ok=True)
# Persist only the state dict (parameters and buffers), not the module object.
torch.save(model.state_dict(), pth_path)
print("Saved PyTorch Model State to model.pth")

Saved PyTorch Model State to model.pth


In [30]:
# load
model_test = MLP().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can be restored on a CPU-only one.
model_test.load_state_dict(torch.load(pth_path, map_location=device))

<All keys matched successfully>

In [31]:
# predict
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Run inference with the reloaded checkpoint (model_test), not the in-memory
# training model, so this cell actually exercises the save/load round trip.
model_test.eval()
x, y = test_dataset[0]
with torch.no_grad():
    # Add an explicit batch dimension instead of relying on the channel
    # dimension being treated as the batch by nn.Flatten.
    x = x.unsqueeze(0).to(device)
    pred = model_test(x)
    predicted, actual = classes[pred[0].argmax(0).item()], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"
