In [1]:
import torchvision
import torchvision.transforms as transforms
train_set = torchvision.datasets.FashionMNIST(
    root = "./data/FashionMNIST",
    train = True,
    download = True,
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
)

In [2]:
import torch
data_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size = 100
)

In [3]:
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 16, kernel_size = 5)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 5)
        self.fc1 = nn.Linear(in_features = 32*4*4, out_features = 128)
        self.fc2 = nn.Linear(in_features = 128, out_features = 64)
        self.out = nn.Linear(in_features = 64, out_features = 10)
    def forward(self, t):
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size = 2, stride = 2)
        t = t.reshape(-1, 32*4*4)
        t = self.fc1(t)
        t = F.relu(t)
        t = self.fc2(t)
        t = F.relu(t)
        t = self.out(t)
        #t = F.softmax(t, dim = 1)
        return t

In [4]:
network = Network()
optimizer = optim.Adam(network.parameters())

In [5]:
NUM_EPOCHS = 10
for i in range(NUM_EPOCHS):
    batches = iter(data_loader)
    for image_batch, label_batch in batches:
        pred = network(image_batch)
        loss = F.cross_entropy(pred, label_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [6]:
def get_num_correct(preds, labels):
    return preds.argmax(dim = 1).eq(labels).sum().item()

In [7]:
batch = next(iter(data_loader))
img, lab = batch;
pred = network(img)
print(get_num_correct(pred, lab))

94


In [10]:
test_set = torchvision.datasets.FashionMNIST(
    root = "./data/FashionMNIST",
    train = False,
    download = True,
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
)
test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size = 100
)

In [12]:
batches = iter(test_loader)
num_correct = 0
num_test = 0
total_loss = 0
for image_batch, label_batch in batches:
    pred = network(image_batch)
    loss = F.cross_entropy(pred, label_batch)
    num_correct += get_num_correct(pred, label_batch)
    num_test += 100
    total_loss += loss
print("accuracy on test dataset: {}, average loss: {}".format(num_correct/num_test, total_loss/num_test) )

accuracy on test dataset: 0.8944, average loss: 0.0028908594977110624
