## 1. Quickstart

### Dataset

In [27]:
# Download both the training and testing FashionMNIST dataset from torchvision

from torchvision import datasets
from torchvision.transforms import ToTensor

training_set = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
testing_set = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)

In [28]:
# Pass Dataset as an argument to Dataloader and iterate over the dataset
# Set batch_size as 64

from torch.utils.data import DataLoader
batch_size = 64

train_dataloader = DataLoader(training_set, batch_size=batch_size)
test_dataloader = DataLoader(testing_set, batch_size=batch_size)

### Creating the Model

In [29]:
import torch

# check available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

Using cpu device


In [30]:
# define NeuralNetwork: linear(512) -> relu -> linear(512) -> relu -> linear(10)
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
    
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [31]:
# Initialize a model
model = NeuralNetwork().to(device)

In [32]:
# Optimizing the parameters

# loss_fn: crossentropy
# optimizer: sgd
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [33]:
# train_fn
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    
    for batch, (X, y) in enumerate(dataloader):
        X,y = X.to(device), y.to(device)
        
        pred = model(X)
        loss = loss_fn(pred, y)
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        if batch%100 == 0:
            print(f"Round: {(batch+1) * len(X)}/{size}  Loss: {loss.item():5f}")

In [34]:
# test_fn
def test(dataloader, model, loss_fn):
    batch_num = len(dataloader)
    size = len(dataloader.dataset)
    model.eval()
    
    loss, correct = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X,y = X.to(device), y.to(device)
            
            pred = model(X)
            loss += loss_fn(pred, y)
            correct += (pred.argmax(1)==y).type(torch.float64).sum().item()
    
    print(f"Avg Loss: {loss/batch_num}, Correct Rate: {correct/size}")

In [35]:
# start training
epochs = 5
for t in range(epochs):
    print(f"Epoch {t}\n")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
    print("-"*30)

Epoch 0

Round: 64/60000  Loss: 2.287028
Round: 6464/60000  Loss: 2.283128
Round: 12864/60000  Loss: 2.264858
Round: 19264/60000  Loss: 2.269808
Round: 25664/60000  Loss: 2.252944
Round: 32064/60000  Loss: 2.215040
Round: 38464/60000  Loss: 2.236125
Round: 44864/60000  Loss: 2.194273
Round: 51264/60000  Loss: 2.191097
Round: 57664/60000  Loss: 2.175271
Avg Loss: 2.160900115966797, Correct Rate: 0.299
------------------------------
Epoch 1

Round: 64/60000  Loss: 2.155209
Round: 6464/60000  Loss: 2.164204
Round: 12864/60000  Loss: 2.106785
Round: 19264/60000  Loss: 2.131963
Round: 25664/60000  Loss: 2.099849
Round: 32064/60000  Loss: 2.021408
Round: 38464/60000  Loss: 2.063553
Round: 44864/60000  Loss: 1.981072
Round: 51264/60000  Loss: 1.982052
Round: 57664/60000  Loss: 1.934895
Avg Loss: 1.9249920845031738, Correct Rate: 0.5858
------------------------------
Epoch 2

Round: 64/60000  Loss: 1.937815
Round: 6464/60000  Loss: 1.934401
Round: 12864/60000  Loss: 1.818734
Round: 19264/60000

### Save & Load Model

In [38]:
# save the model
torch.save(model.state_dict(), "model/basic.pth")

In [40]:
# load the model
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model/basic.pth"))

  model.load_state_dict(torch.load("model/basic.pth"))


<All keys matched successfully>