<a href="https://colab.research.google.com/github/Speedbird45Bravo/rando_projects/blob/main/Torch_61322_TM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import torch
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch.nn import Module, Sequential, Conv2d, ReLU, MaxPool2d, CrossEntropyLoss, Linear
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.optim import Adam
from datetime import datetime as dt
from pytz import timezone
import warnings
warnings.filterwarnings("ignore")

In [73]:
train = MNIST("/files", train=True, download=True, transform=ToTensor())
test = MNIST("/files", train=False, download=True, transform=ToTensor())
loaders = {"train":DataLoader(train, batch_size=64, shuffle=True, num_workers=4),\
           "test":DataLoader(train, batch_size=64, shuffle=True, num_workers=4)}
tz = timezone("US/Eastern")

In [74]:
class NeuralNetwork(Module):

  def __init__(self):

    super(NeuralNetwork, self).__init__()

    self.c1 = Sequential(Conv2d(1,16,5,1,2),ReLU(),MaxPool2d(2))
    self.c2 = Sequential(Conv2d(16,32,5,1,2),ReLU(),MaxPool2d(2))
    self.out = Sequential(Linear(32 * 7 * 7, 10))
  
  def forward(self, x):
    x = self.c1(x)
    x = self.c2(x)
    x = x.view(x.size(0),-1)
    output = self.out(x)
    return output, x

In [75]:
cnn = NeuralNetwork()
print(f"Network printed at {dt.now(tz=tz)}")
print(cnn)

Network printed at 2022-06-13 12:02:03.403075-04:00
NeuralNetwork(
  (c1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (c2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Sequential(
    (0): Linear(in_features=1568, out_features=10, bias=True)
  )
)


In [76]:
epochs = 6
loss_function = CrossEntropyLoss()
rabinowitz = Adam(cnn.parameters(), lr=0.01)

In [77]:
def train(cnn, epochs, loaders):

  cnn.train()

  for epoch in range(epochs):

    for i, (images, labels) in enumerate(loaders["train"]):
      b_X = Variable(images)
      b_y = Variable(labels)
      output = cnn(b_X)[0]
      loss = loss_function(output, b_y)

      rabinowitz.zero_grad()

      loss.backward()

      rabinowitz.step()

      if (i + 1) % 100 == 0:
        print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item()}")

print(f"Model Trained at {dt.now(tz=tz)}")

Model Trained at 2022-06-13 12:02:05.135429-04:00


In [78]:
train(cnn, epochs, loaders)

Epoch 1/6 | Loss: 0.20966669917106628
Epoch 1/6 | Loss: 0.05210019275546074
Epoch 1/6 | Loss: 0.1923825889825821
Epoch 1/6 | Loss: 0.021975547075271606
Epoch 1/6 | Loss: 0.037835318595170975
Epoch 1/6 | Loss: 0.03717513754963875
Epoch 1/6 | Loss: 0.08732873946428299
Epoch 1/6 | Loss: 0.08596909791231155
Epoch 1/6 | Loss: 0.016834694892168045
Epoch 2/6 | Loss: 0.0910496711730957
Epoch 2/6 | Loss: 0.013144388794898987
Epoch 2/6 | Loss: 0.011565802618861198
Epoch 2/6 | Loss: 0.06526030600070953
Epoch 2/6 | Loss: 0.16613951325416565
Epoch 2/6 | Loss: 0.022126754745841026
Epoch 2/6 | Loss: 0.12692561745643616
Epoch 2/6 | Loss: 0.0011487246956676245
Epoch 2/6 | Loss: 0.06046127900481224
Epoch 3/6 | Loss: 0.062489379197359085
Epoch 3/6 | Loss: 0.19304445385932922
Epoch 3/6 | Loss: 0.045264169573783875
Epoch 3/6 | Loss: 0.0838605985045433
Epoch 3/6 | Loss: 0.02725561335682869
Epoch 3/6 | Loss: 0.0021531139500439167
Epoch 3/6 | Loss: 0.0825572982430458
Epoch 3/6 | Loss: 0.0724736899137497
Epoch

In [83]:
def test():

  cnn.eval()

  with torch.no_grad():

    correct = 0
    total = 0

    for images, labels in loaders["test"]:
      test_output, last_layer = cnn(images)
      y_pred = torch.max(test_output,1)[1].data.squeeze()
      accuracy = ((y_pred == labels).sum().item()) / float(labels.size(0))
      pass

    print(f"Test Accuracy: {accuracy * 100}%")
    print(f"Model Tested at {dt.now(tz=tz)}")

In [84]:
test()

Test Accuracy: 96.875%
Model Tested at 2022-06-13 12:07:22.618097-04:00
