<a href="https://colab.research.google.com/github/Azuremis/make_your_first_gan_with_pytorch/blob/master/mnist_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import pandas
import torch
import torch.nn as nn

In [0]:
# neural network class
class Classifier(nn.Module):
  """Small fully connected classifier trained one example at a time.

  Architecture: Linear(784 -> 200) -> Sigmoid -> Linear(200 -> 10) -> Sigmoid.
  The 784-value input is presumably a flattened 28x28 MNIST image (per the
  notebook name) - confirm against the data-loading code.

  Training uses mean-squared-error loss and plain SGD with lr=0.01.

  Attributes:
    counter: number of training steps performed via ``train(inputs, targets)``.
    progress: loss values (Python floats) recorded on every 10th training step.
  """

  def __init__(self):
    # initialise parent pytorch class
    super().__init__()

    # setup neural network architecture
    self.model = nn.Sequential(
        nn.Linear(784, 200),  # fully connected mapping from 784 nodes to 200 nodes
        nn.Sigmoid(),  # apply sigmoid to output of 200 nodes
        nn.Linear(200, 10),  # maps 200 nodes to 10 nodes
        nn.Sigmoid()  # apply sigmoid to output of 10 nodes to get final output
    )

    # setup loss function
    self.loss_function = torch.nn.MSELoss()

    # setup optimiser using simple stochastic gradient descent
    self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    # counter and accumulator to track progress
    self.counter = 0
    self.progress = []

  def forward(self, inputs):
    """Pass ``inputs`` through the network.

    Args:
      inputs: tensor whose last dimension has size 784 (required by the first
        nn.Linear layer).

    Returns:
      Tensor whose last dimension has size 10, each value in (0, 1) because of
      the final sigmoid.
    """
    return self.model(inputs)

  def train(self, inputs=True, targets=None):
    """Run one training step, or set training mode like ``nn.Module.train``.

    This method overrides ``nn.Module.train(mode)``. ``nn.Module.eval()``
    calls ``self.train(False)``, so when ``targets`` is omitted the call is
    forwarded to the parent implementation. That keeps ``eval()`` and
    ``train(True)`` working.

    Args:
      inputs: input tensor for a training step, or a bool training mode when
        ``targets`` is omitted.
      targets: target tensor; should match the shape of ``forward(inputs)``
        because the loss is MSELoss.

    Returns:
      ``self`` when delegating to ``nn.Module.train``, otherwise ``None``.
    """
    if targets is None:
      # mode-switch call (e.g. from nn.Module.eval()), not a training step
      return super().train(inputs)

    # calculate nn outputs
    outputs = self.forward(inputs)

    # calculate loss
    loss = self.loss_function(outputs, targets)

    # update training progress trackers: record the loss every 10th example
    self.counter += 1
    if self.counter % 10 == 0:
      self.progress.append(loss.item())  # loss.item() unwraps tensor

    # indicate speed of training to user
    if self.counter % 10000 == 0:
      print("counter = ", self.counter)

    # process nn updates
    self.optimiser.zero_grad()  # clear gradients left over from the previous step
    loss.backward()  # calculate gradients via backward pass
    self.optimiser.step()  # update nn weights using gradients
    return None

  def plot_progress(self):
    """Plot the recorded loss values in ``self.progress``.

    Uses pandas' ``DataFrame.plot`` (matplotlib backend by default), so a
    plotting backend must be installed.
    """
    df = pandas.DataFrame(self.progress, columns=["loss"])
    df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker=".",
            grid=True, yticks=(0, 0.25, 0.5))


