# simple nn model feed forward

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def download_mnist_data():
    training_data = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_data = datasets.MNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    return training_data, test_data

training_data, test_data = download_mnist_data()
# create data loaders
# a data loader is a class that wraps a dataset
# and provides access to the underlying data
# in a random or sequential order in batches
BATCH_SIZE = 128
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE)

In [7]:
# create a model
# we will use a simple feedforward neural network

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_layers = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
       flattened = self.flatten(x)
       logits = self.dense_layers(flattened)
       return self.softmax(logits)
    
model = NeuralNetwork().to("cuda")

In [8]:
def train(dataloader, model, loss_fn, optimizer, epochs):
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        for X, y in dataloader:
            X, y = X.to("cuda"), y.to("cuda")

            # compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"loss: {loss.item()}")
    print("Done!")
    

In [9]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(train_dataloader, model, loss_fn, optimizer, epochs=5)

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

Epoch 1
-------------------------------
loss: 1.5287114381790161
Epoch 2
-------------------------------
loss: 1.5042095184326172
Epoch 3
-------------------------------
loss: 1.4961620569229126
Epoch 4
-------------------------------
loss: 1.4913662672042847
Epoch 5
-------------------------------
loss: 1.4829338788986206
Done!
Saved PyTorch Model State to model.pth
