In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
def _mnist_split(is_train):
    """Build one MNIST split stored under ``data``, downloading it if needed.

    ``is_train`` selects the training split when true and the test split
    otherwise. Images are converted with ``ToTensor()``.
    """
    return datasets.MNIST(
        root="data",
        train=is_train,
        download=True,
        transform=ToTensor(),
    )


# Training split first, then the held-out test split.
train_data = _mnist_split(True)
test_data = _mnist_split(False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 112130706.55it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 44848461.25it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 30792859.10it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10799619.48it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
# Sanity check: show which dataset class the training split is an instance of.
train_data_cls = type(train_data)
print(train_data_cls)

<class 'torchvision.datasets.mnist.MNIST'>


In [None]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 64

# Reshuffle the training samples every epoch so the model does not see the
# exact same batches in the exact same order each time; this also varies the
# per-batch statistics used by the BatchNorm layers during training.
train_dl = DataLoader(train_data, batch_size = batch_size, shuffle = True)
# Evaluation metrics are aggregated over the whole split, so the test loader
# keeps the dataset order.
test_dl = DataLoader(test_data, batch_size = batch_size)

In [None]:
# Choose the compute backend in priority order: CUDA, then MPS, then CPU.
# The MPS check is only evaluated when CUDA is not available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using cuda device


In [None]:
class FPNN(nn.Module):
  """Fully connected classifier mapping 784 flattened features to 10 outputs.

  The network is a stack of six bias-free linear layers
  (784 -> 512 -> 256 -> 128 -> 64 -> 32 -> 10). Each of the first five is
  followed by ReLU and then BatchNorm1d; the last one is followed by
  BatchNorm1d only, and that normalised tensor is returned.
  """

  def __init__(self):
    super().__init__()
    self.flatten = nn.Flatten()

    # Stage 1: 784 -> 512
    self.inh1 = nn.Linear(784, 512, bias=False)
    self.bn1 = nn.BatchNorm1d(512)
    # A single ReLU module is shared by every hidden stage (it holds no state).
    self.relu = nn.ReLU()

    # Stage 2: 512 -> 256
    self.h2 = nn.Linear(512, 256, bias=False)
    self.bn2 = nn.BatchNorm1d(256)

    # Stage 3: 256 -> 128
    self.h3 = nn.Linear(256, 128, bias=False)
    self.bn3 = nn.BatchNorm1d(128)

    # Stage 4: 128 -> 64
    self.h4 = nn.Linear(128, 64, bias=False)
    self.bn4 = nn.BatchNorm1d(64)

    # Stage 5: 64 -> 32
    self.h5 = nn.Linear(64, 32, bias=False)
    self.bn5 = nn.BatchNorm1d(32)

    # Output stage: 32 -> 10, normalised but not passed through ReLU.
    self.h6 = nn.Linear(32, 10, bias=False)
    self.bn6 = nn.BatchNorm1d(10)

  def forward(self, x):
    """Flatten ``x`` and return the batch-normalised 10-way output."""
    hidden_stages = (
      (self.inh1, self.bn1),
      (self.h2, self.bn2),
      (self.h3, self.bn3),
      (self.h4, self.bn4),
      (self.h5, self.bn5),
    )

    activations = self.flatten(x)
    # Hidden stages apply linear -> ReLU -> batch norm, in that order.
    for linear, norm in hidden_stages:
      activations = norm(self.relu(linear(activations)))

    return self.bn6(self.h6(activations))

In [None]:
def train(dataloader, model, loss_fn, optimizer):
  """Run one training epoch and return ``(epoch_loss, epoch_acc)``.

  Args:
    dataloader: iterable yielding ``(imgs, labels)`` batches; must support
      ``len()`` for the progress message.
    model: module to train; it is switched to training mode.
    loss_fn: callable ``loss_fn(pred, labels)`` returning a scalar tensor.
      The running loss assumes it is a per-batch mean (the default
      reduction of ``nn.CrossEntropyLoss``).
    optimizer: optimizer stepping ``model``'s parameters.

  Returns:
    A tuple ``(epoch_loss, epoch_acc)``: the average loss per sample and
    the accuracy in percent over all samples seen, both rounded to two
    decimals. ``(0.0, 0.0)`` is returned if the dataloader yields no batches.

  Batches are moved to the module-level ``device`` before the forward pass.
  """
  model.train()
  loss_sum = 0.0     # sum over samples of the per-sample loss
  num_correct = 0
  num_seen = 0       # samples processed so far, taken from the real batch sizes
  # Defined up front so an empty dataloader still returns a well-formed result.
  running_loss = 0.0
  running_acc = 0.0
  for i, (imgs, labels) in enumerate(dataloader):
    imgs = imgs.to(device)
    labels = labels.to(device)
    n_in_batch = imgs.shape[0]
    pred = model(imgs)

    loss = loss_fn(pred, labels)
    # Weight the batch-mean loss by the batch size so a smaller final batch
    # contributes proportionally, independent of any global batch size.
    loss_sum += loss.item() * n_in_batch
    num_correct += (torch.argmax(pred, dim=1) == labels).type(torch.float).sum().item()
    num_seen += n_in_batch

    running_loss = round(loss_sum / num_seen, 2)
    running_acc = round((num_correct / num_seen) * 100, 2)

    loss.backward()
    optimizer.step()
    # Clear gradients after the step so the next batch starts from zero.
    optimizer.zero_grad()

    if i % 100 == 0:
      print("Batch:", i+1, "/",len(dataloader), "Running Loss:",running_loss, "Running Accuracy:",running_acc)

  epoch_loss = running_loss
  epoch_acc = running_acc
  return epoch_loss, epoch_acc

In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    The model is put in evaluation mode and gradients are disabled. The
    printed loss is the mean of the per-batch losses; the printed accuracy is
    the number of correct predictions divided by ``len(dataloader.dataset)``.
    Batches are moved to the module-level ``device``. Nothing is returned.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()

    loss_total = 0
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_total += loss_fn(logits, targets).item()
            # The predicted class is the index of the largest output per row.
            matches = logits.argmax(1) == targets
            hits += matches.type(torch.float).sum().item()

    avg_loss = loss_total / n_batches
    accuracy = hits / n_samples
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

In [None]:
# Build the network directly on the selected device.
model = FPNN().to(device)

# Objective and optimiser settings.
loss = nn.CrossEntropyLoss()
lr = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

epochs = 10

# Train for the configured number of epochs, reporting the metrics after each.
i = 0
while i < epochs:
    print("Epoch:", i + 1)
    epoch_metrics = train(train_dl, model, loss, optimizer)
    train_epoch_loss, train_epoch_acc = epoch_metrics
    print("Training:", "Epoch Loss:", train_epoch_loss, "Epoch Accuracy:", train_epoch_acc)
    print("--------------------------------------------------")
    i += 1

# Evaluate once on the held-out split after training has finished.
test(test_dl, model, loss)

Epoch: 1
Batch: 1 / 938 Running Loss: 2.69 Running Accuracy: 6.25
Batch: 101 / 938 Running Loss: 0.74 Running Accuracy: 84.78
Batch: 201 / 938 Running Loss: 0.61 Running Accuracy: 88.33
Batch: 301 / 938 Running Loss: 0.54 Running Accuracy: 89.67
Batch: 401 / 938 Running Loss: 0.49 Running Accuracy: 90.71
Batch: 501 / 938 Running Loss: 0.46 Running Accuracy: 91.33
Batch: 601 / 938 Running Loss: 0.43 Running Accuracy: 91.85
Batch: 701 / 938 Running Loss: 0.41 Running Accuracy: 92.27
Batch: 801 / 938 Running Loss: 0.39 Running Accuracy: 92.5
Batch: 901 / 938 Running Loss: 0.37 Running Accuracy: 92.79
Training: Epoch Loss: 0.36 Epoch Accuracy: 92.94
--------------------------------------------------
Epoch: 2
Batch: 1 / 938 Running Loss: 0.19 Running Accuracy: 96.88
Batch: 101 / 938 Running Loss: 0.2 Running Accuracy: 95.85
Batch: 201 / 938 Running Loss: 0.2 Running Accuracy: 95.51
Batch: 301 / 938 Running Loss: 0.2 Running Accuracy: 95.7
Batch: 401 / 938 Running Loss: 0.19 Running Accuracy