In [1]:
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

In [None]:
def to_numpy(x):
  x = np.array(x).flatten()
  return x / 255

In [None]:
train_dataset = MNIST(".", download=True, train=True, transform=to_numpy)
valid_dataset = MNIST(".", download=True, train=False, transform=to_numpy)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [None]:
train_dataset[0]

(array([0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.  

In [None]:
len(train_dataset), len(valid_dataset)

(60000, 10000)

In [None]:
neural_net = nn.Sequential(
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

In [None]:
criteria = nn.CrossEntropyLoss()
optimizator = optim.Adam(neural_net.parameters(), 1e-3)

In [None]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
for epoch in tqdm.tqdm(range(10)):
  for i, (x, y) in enumerate(train_loader):
    
    optimizator.zero_grad()
    output = neural_net.forward(x.float())
    loss = criteria(output, y)
    loss.backward()
    nn.utils.clip_grad_norm_(neural_net.parameters(), 5.0)
    optimizator.step()

    if i % 100 == 0:
      print(f"Epoch: {epoch+1}/10, Iteration: {i+1}/{len(train_loader)}, Loss: {loss.item():.3f}")

  y_pred = []
  y_valid = []
  for x, y in valid_loader:

    with torch.no_grad():
      output = neural_net.forward(x.float())
      y_pred.extend(output.argmax(1))
      y_valid.extend(y)

  quality = accuracy_score(y_valid, y_pred)
  print(f"Epoch: {epoch}/10, quality: {quality:.3f}")
  

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1/10, Iteration: 1/469, Loss: 2.297
Epoch: 1/10, Iteration: 101/469, Loss: 0.291
Epoch: 1/10, Iteration: 201/469, Loss: 0.209
Epoch: 1/10, Iteration: 301/469, Loss: 0.267
Epoch: 1/10, Iteration: 401/469, Loss: 0.113


 10%|█         | 1/10 [00:06<01:01,  6.80s/it]

Epoch: 0/10, quality: 0.956
Epoch: 2/10, Iteration: 1/469, Loss: 0.263
Epoch: 2/10, Iteration: 101/469, Loss: 0.260
Epoch: 2/10, Iteration: 201/469, Loss: 0.106
Epoch: 2/10, Iteration: 301/469, Loss: 0.099
Epoch: 2/10, Iteration: 401/469, Loss: 0.106


 20%|██        | 2/10 [00:13<00:54,  6.76s/it]

Epoch: 1/10, quality: 0.967
Epoch: 3/10, Iteration: 1/469, Loss: 0.082
Epoch: 3/10, Iteration: 101/469, Loss: 0.124
Epoch: 3/10, Iteration: 201/469, Loss: 0.024
Epoch: 3/10, Iteration: 301/469, Loss: 0.098
Epoch: 3/10, Iteration: 401/469, Loss: 0.036


 30%|███       | 3/10 [00:20<00:47,  6.83s/it]

Epoch: 2/10, quality: 0.971
Epoch: 4/10, Iteration: 1/469, Loss: 0.058
Epoch: 4/10, Iteration: 101/469, Loss: 0.026
Epoch: 4/10, Iteration: 201/469, Loss: 0.099
Epoch: 4/10, Iteration: 301/469, Loss: 0.101
Epoch: 4/10, Iteration: 401/469, Loss: 0.116


 40%|████      | 4/10 [00:27<00:40,  6.83s/it]

Epoch: 3/10, quality: 0.976
Epoch: 5/10, Iteration: 1/469, Loss: 0.015
Epoch: 5/10, Iteration: 101/469, Loss: 0.023
Epoch: 5/10, Iteration: 201/469, Loss: 0.070
Epoch: 5/10, Iteration: 301/469, Loss: 0.041
Epoch: 5/10, Iteration: 401/469, Loss: 0.113


 50%|█████     | 5/10 [00:34<00:34,  6.85s/it]

Epoch: 4/10, quality: 0.976
Epoch: 6/10, Iteration: 1/469, Loss: 0.026
Epoch: 6/10, Iteration: 101/469, Loss: 0.054
Epoch: 6/10, Iteration: 201/469, Loss: 0.025
Epoch: 6/10, Iteration: 301/469, Loss: 0.033
Epoch: 6/10, Iteration: 401/469, Loss: 0.022


 60%|██████    | 6/10 [00:41<00:28,  7.05s/it]

Epoch: 5/10, quality: 0.979
Epoch: 7/10, Iteration: 1/469, Loss: 0.006
Epoch: 7/10, Iteration: 101/469, Loss: 0.040
Epoch: 7/10, Iteration: 201/469, Loss: 0.050
Epoch: 7/10, Iteration: 301/469, Loss: 0.048
Epoch: 7/10, Iteration: 401/469, Loss: 0.012


 70%|███████   | 7/10 [00:48<00:21,  7.03s/it]

Epoch: 6/10, quality: 0.977
Epoch: 8/10, Iteration: 1/469, Loss: 0.018
Epoch: 8/10, Iteration: 101/469, Loss: 0.099
Epoch: 8/10, Iteration: 201/469, Loss: 0.014
Epoch: 8/10, Iteration: 301/469, Loss: 0.035
Epoch: 8/10, Iteration: 401/469, Loss: 0.003


 80%|████████  | 8/10 [00:55<00:13,  7.00s/it]

Epoch: 7/10, quality: 0.978
Epoch: 9/10, Iteration: 1/469, Loss: 0.019
Epoch: 9/10, Iteration: 101/469, Loss: 0.013
Epoch: 9/10, Iteration: 201/469, Loss: 0.006
Epoch: 9/10, Iteration: 301/469, Loss: 0.009
Epoch: 9/10, Iteration: 401/469, Loss: 0.012


 90%|█████████ | 9/10 [01:02<00:06,  6.95s/it]

Epoch: 8/10, quality: 0.978
Epoch: 10/10, Iteration: 1/469, Loss: 0.034
Epoch: 10/10, Iteration: 101/469, Loss: 0.007
Epoch: 10/10, Iteration: 201/469, Loss: 0.006
Epoch: 10/10, Iteration: 301/469, Loss: 0.009
Epoch: 10/10, Iteration: 401/469, Loss: 0.029


100%|██████████| 10/10 [01:09<00:00,  6.94s/it]

Epoch: 9/10, quality: 0.981



