In [49]:
import torch
import torchvision

## Load dataset

In [50]:
def load_fashion_mnist_dataset(batch_size):
    train_dataset = torchvision.datasets.FashionMNIST(root="./data", train=True, \
                                                      transform=torchvision.transforms.ToTensor(), download=True)
    test_dataset= torchvision.datasets.FashionMNIST(root="./data", train=False, \
                                                      transform=torchvision.transforms.ToTensor(), download=True)
    ## return the dataloaders of the two dataset
    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4), \
                torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [80]:
iter_train, iter_test = map(iter, load_fashion_mnist_dataset(256))
# print(iter_train, iter_test)
# for X, y in iter_train:
#     print(X, y)
#     break

## Initialize model parameters

In [65]:
num_inputs, num_outputs, num_hiddens=28*28, 10, 256

W1 = torch.normal(mean=0, std=0.01, size=[num_inputs, num_hiddens], requires_grad = True)
b1 = torch.zeros(num_hiddens, requires_grad=True)
W2 = torch.normal(mean=0, std=0.01, size=[num_hiddens, num_outputs], requires_grad = True)
b2 = torch.zeros(num_outputs, requires_grad=True)

params = [W1, b1, W2, b2]
print(W1.shape, b1.shape, W2.shape, b2.shape)

torch.Size([784, 256]) torch.Size([256]) torch.Size([256, 10]) torch.Size([10])


### Activation function

In [53]:
def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)

## Model

In [54]:
def mlp(X):
    X = X.reshape(-1, num_inputs)
    hiddens = relu(torch.mm(X, W1)+b1)
    # only logits are enough
    return torch.mm(hiddens, W2)+b2

## Loss function

In [60]:
loss = torch.nn.CrossEntropyLoss()

## Optimizer

In [82]:
def sgd(params, lr, batch_size):
    # print("in sgd")
    with torch.no_grad():
        for param in params:
            param -= lr*param.grad/batch_size
            # print(param.grad==torch.zeros_like(param))
            param.grad.zero_()

## Training

In [83]:
num_epochs = 3
lr = 0.1
batch_size = 256

for epoch in range(num_epochs):
    for X, y in iter_train:
        # print(X, y)
        # print(X.shape)
        y_hat_logit = mlp(X)
        ce_loss = loss(y_hat_logit, y)  
        ce_loss.sum().backward()
        sgd(params, lr, batch_size)
    # print(X.shape)
    # compute the traning loss
    with torch.no_grad():
        y_hat_logit = mlp(X)
        ce_loss = loss(y_hat_logit, y)
    #     # acct = accuracy(y_hat, y)
        print(f'epoch {epoch + 1}, loss {float(ce_loss.mean()):f}')

epoch 1, loss 2.291073
epoch 2, loss 2.291073
epoch 3, loss 2.291073
