<a href="https://colab.research.google.com/github/Festivius/hello-world/blob/main/cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import json

In [None]:
# Preprocessing applied to every dataset sample: torchvision's ToTensor turns
# each image into a torch tensor so it can be batched and fed to the model.
transform = transforms.ToTensor()

In [None]:
# Both MNIST splits share the same storage root, download flag and transform;
# only the `train` flag differs between them.
mnist_common = dict(root='/cnn_data', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_common)
test_data = datasets.MNIST(train=False, **mnist_common)

In [None]:
# Mini-batches of 10 samples for both splits; only the training stream is
# reshuffled on every pass, the test stream keeps its dataset order.
loader_batch_size = 10
train_loader = DataLoader(dataset=train_data, batch_size=loader_batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=loader_batch_size, shuffle=False)

In [None]:
class ConvolutionalNetwork(nn.Module):
  """Three-layer fully connected classifier applied to a 2x2 max-pooled image.

  Despite its name the network contains no convolutional layers: the input is
  down-sampled with a max pool, flattened to 196 features and passed through
  linear layers of width 33 -> 25 -> 10.
  """

  def __init__(self):
    super().__init__()
    # 196 matches the flattened size used in forward() (14 * 14 once a
    # 28x28 single-channel image has been pooled by 2).
    self.fc1 = nn.Linear(in_features=196, out_features=33)
    self.fc2 = nn.Linear(in_features=33, out_features=25)
    self.fc3 = nn.Linear(in_features=25, out_features=10)

  def forward(self, X):
    """Return per-class log-probabilities with shape (rows, 10).

    X is pooled with a 2x2 window / stride 2, reshaped to rows of 196
    values, and run through the three linear layers with ReLU between them.
    """
    pooled = F.max_pool2d(X, kernel_size=2, stride=2)
    flat = pooled.view(-1, 196)

    hidden = F.relu(self.fc1(flat))
    hidden = F.relu(self.fc2(hidden))
    scores = self.fc3(hidden)

    # log_softmax over the class dimension -> log-probabilities.
    return F.log_softmax(scores, dim=1)

In [None]:
# Seed the global torch RNG before building the model so that the random
# weight initialisation (and later RNG-driven steps such as DataLoader
# shuffling) repeat across "Restart & Run All". The value itself is arbitrary.
torch.manual_seed(42)
model = ConvolutionalNetwork()

In [None]:
# The model's forward() already returns log-probabilities (log_softmax), so the
# matching loss is NLLLoss. CrossEntropyLoss would apply log_softmax a second
# time on top of it, which is redundant work for the same loss value.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
epochs = 20

# Per-epoch history: mean loss per sample and number of correct predictions
# for the training and test passes, kept so the run can be inspected later.
train_losses, test_losses = [], []
train_correct, test_correct = [], []

for i in range(epochs):
  # ---- training pass ----
  # train()/eval() are no-ops for the current layers (no dropout/batch-norm)
  # but keep the loop correct if such layers are added to the model.
  model.train()
  epoch_loss, epoch_corr, epoch_seen = 0.0, 0, 0
  for b,(X_train, y_train) in enumerate(train_loader):
    b += 1
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)

    # Index of the highest score per row is the predicted class.
    predicted = torch.max(y_pred.data, 1)[1]
    batch_corr = (predicted == y_train).sum()

    # Weight the batch loss by its size so the epoch figure is a per-sample
    # mean (assumes the criterion returns a per-batch mean).
    n_batch = y_train.size(0)
    epoch_loss += loss.item() * n_batch
    epoch_corr += batch_corr.item()
    epoch_seen += n_batch

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if b%1000 == 0:
      print(f"Epoch: {i}  Batch: {b}  Loss: {loss.item()}")

  train_losses.append(epoch_loss / max(epoch_seen, 1))
  train_correct.append(epoch_corr)

  # ---- evaluation pass ----
  # Score every test batch inside no_grad and accumulate the results instead
  # of discarding them and looking only at the final batch.
  model.eval()
  val_loss, val_corr, val_seen = 0.0, 0, 0
  with torch.no_grad():
    for b,(X_test, y_test) in enumerate(test_loader):
      y_val = model(X_test)
      predicted = torch.max(y_val.data, 1)[1]

      n_batch = y_test.size(0)
      val_loss += criterion(y_val, y_test).item() * n_batch
      val_corr += (predicted == y_test).sum().item()
      val_seen += n_batch

  test_losses.append(val_loss / max(val_seen, 1))
  test_correct.append(val_corr)
  print(f"Epoch: {i}  Train loss: {train_losses[-1]:.4f}  Test loss: {test_losses[-1]:.4f}  Test accuracy: {val_corr}/{val_seen}")

Epoch: 0  Batch: 1000  Loss: 0.28416404128074646
Epoch: 0  Batch: 2000  Loss: 0.2713087201118469
Epoch: 0  Batch: 3000  Loss: 0.2503032684326172
Epoch: 0  Batch: 4000  Loss: 0.2004493921995163
Epoch: 0  Batch: 5000  Loss: 0.02620089054107666
Epoch: 0  Batch: 6000  Loss: 0.061026811599731445
Epoch: 1  Batch: 1000  Loss: 0.29720428586006165
Epoch: 1  Batch: 2000  Loss: 0.09474434703588486
Epoch: 1  Batch: 3000  Loss: 0.5805155038833618
Epoch: 1  Batch: 4000  Loss: 0.05504431203007698
Epoch: 1  Batch: 5000  Loss: 0.021260861307382584
Epoch: 1  Batch: 6000  Loss: 0.37327665090560913
Epoch: 2  Batch: 1000  Loss: 0.025350576266646385
Epoch: 2  Batch: 2000  Loss: 0.20941917598247528
Epoch: 2  Batch: 3000  Loss: 0.024845723062753677
Epoch: 2  Batch: 4000  Loss: 0.028816912323236465
Epoch: 2  Batch: 5000  Loss: 0.037517525255680084
Epoch: 2  Batch: 6000  Loss: 0.05883635953068733
Epoch: 3  Batch: 1000  Loss: 0.0488986074924469
Epoch: 3  Batch: 2000  Loss: 0.040872085839509964
Epoch: 3  Batch: 3

In [None]:
# Separate loader over the test split with a very large batch (10000) and no
# shuffling, intended to deliver the whole split in a single batch for the
# final accuracy count below.
test_load_everything = DataLoader(test_data, batch_size=10000, shuffle=False)

In [None]:
# Tally how many test samples the trained model labels correctly.
correct = 0
with torch.no_grad():  # inference only, so no autograd bookkeeping is needed
  for X_test, y_test in test_load_everything:
    y_val = model(X_test)
    # torch.max over dim 1 returns (values, indices); the indices are the
    # predicted classes.
    _, predicted = torch.max(y_val, 1)
    matches = (predicted == y_test)
    correct += matches.sum()

In [None]:
# Display the number of correctly classified test samples as a plain Python int.
correct.item()

9597

In [None]:
# Export each linear layer's weight matrix as a nested list rounded to 4
# decimals. Shapes are taken from the tensors themselves rather than from
# hard-coded layer sizes, so the export keeps working if the architecture
# changes.
weights_fc1, weights_fc2, weights_fc3 = (
  [[round(value, 4) for value in row] for row in layer.weight.data.tolist()]
  for layer in (model.fc1, model.fc2, model.fc3)
)

# One JSON file per layer: weights_fc1.json, weights_fc2.json, weights_fc3.json.
for layer_name, layer_weights in (
  ('fc1', weights_fc1),
  ('fc2', weights_fc2),
  ('fc3', weights_fc3),
):
  with open(f'weights_{layer_name}.json', 'w') as f:
    json.dump(layer_weights, f)

In [None]:
# Keep a handle on each linear layer's bias parameter (same objects as in the
# model, not copies).
biases_fc1, biases_fc2, biases_fc3 = (
  model.fc1.bias,
  model.fc2.bias,
  model.fc3.bias,
)

In [None]:
# Display the output layer's bias parameter (one value per class).
biases_fc3

Parameter containing:
tensor([-0.2921,  0.1034,  0.1809, -0.3112,  0.4765,  0.4518, -0.3592, -0.1181,
         0.1247, -0.1593], requires_grad=True)