In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import torch.nn.functional as F
import torch.optim as optim

In [2]:
data_set = torchvision.datasets.FashionMNIST(
    root = './data',
    download = True,
    train = True,
    transform = transforms.Compose([transforms.ToTensor()])
)

In [3]:
data_loader = torch.utils.data.DataLoader(
    data_set,
    batch_size = 1000,
    shuffle = True
)

In [4]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1,6,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(6,12,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(12*4*4,60),
            nn.ReLU(),
            nn.Linear(60,10),
            nn.ReLU()
        )
    
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
        

In [5]:
def get_num_correct(labels, targets):
    return (labels == targets.argmax(dim=1)).sum().item()

In [6]:
network = Network()
optimizer = optim.Adam(network.parameters(), lr=0.01)
for epoch in range(10):
    total_loss = 0
    total_correct = 0
    
    for batch in data_loader:
        images, labels = batch
        
        preds = network(images)
        loss = F.cross_entropy(preds, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_correct += get_num_correct(labels, preds)
    
    print(
        'epoch:', epoch,
        'total_correct', total_correct,
        'loss', total_loss
    )

epoch: 0 total_correct 32740 loss 89.23979592323303
epoch: 1 total_correct 40450 loss 70.20904731750488
epoch: 2 total_correct 41949 loss 65.1881422996521
epoch: 3 total_correct 42587 loss 62.81270307302475
epoch: 4 total_correct 43080 loss 61.311224818229675
epoch: 5 total_correct 43457 loss 60.03487968444824
epoch: 6 total_correct 43712 loss 59.094791531562805
epoch: 7 total_correct 43809 loss 58.53269284963608
epoch: 8 total_correct 44158 loss 57.47499191761017
epoch: 9 total_correct 44287 loss 57.09552711248398


In [7]:
preds = network(images)

In [8]:
print(get_num_correct(labels, preds)/len(labels))

0.711


In [9]:
print(total_correct)

44287


In [None]:
# network = Network()
# optimizer = optim.Adam(network.parameters(), lr=0.01)

# for epoch in range(10):
#     total_loss = 0
#     total_correct = 0
    
#     for batch in data_loader:
#         images, labels = batch
#         preds = network(images)
        
#         loss = F.cross_entropy(preds, labels)
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss
#         total_correct += get_num_correct(labels, preds)
#     print(
#         'epoch:', epoch,
#         'total_correct', total_correct,
#         'loss', total_loss
#     )