In [7]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import pickle


In [9]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,),(0.5,))])
train_set = torchvision.datasets.MNIST(root="MNIST", train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root="MNIST", train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

In [10]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, stride=2)
        self.conv1 = nn.Conv2d(1,16,3)
        self.conv2 = nn.Conv2d(16,32,3)
        self.func1 = nn.Linear(32*5*5, 120)
        self.func2 = nn.Linear(120,10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size()[0], -1)
        x = self.func1(x)
        x = self.relu(x)
        x = self.func2(x)
        return x

In [11]:
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.005)

In [12]:
losses = []
for epoch in range(100):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        prediction = model(batch[0])
        loss = criterion(prediction, batch[1])
        train_loss = loss.item()
        losses.append(train_loss)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1} | train_loss:{train_loss}")

    for batch in test_loader:
        optimizer.zero_grad()
        prediction = model(batch[0])
        loss = criterion(prediction, batch[1])
        loss.backward()
        optimizer.step()

Epoch 1 | train_loss:0.3957759737968445
Epoch 2 | train_loss:0.11211442947387695
Epoch 3 | train_loss:0.09716687351465225
Epoch 4 | train_loss:0.07155212014913559
Epoch 5 | train_loss:0.09079495072364807
Epoch 6 | train_loss:0.03388676047325134
Epoch 7 | train_loss:0.030975086614489555
Epoch 8 | train_loss:0.0330594927072525
Epoch 9 | train_loss:0.04112349823117256


KeyboardInterrupt: 

In [21]:
loss.item()

0.03193401172757149

In [20]:
len(losses)

8617

In [17]:
with open("CNN_pickle_test", "wb") as f:
    pickle.dump(losses, f)

In [18]:
with open("CNN_pickle_test", "rb") as f:
    losses_ = pickle.load(f)
losses_

[2.3080503940582275,
 2.3157994747161865,
 2.3003084659576416,
 2.3132882118225098,
 2.309725284576416,
 2.3051609992980957,
 2.286848545074463,
 2.2963790893554688,
 2.3126609325408936,
 2.292819023132324,
 2.3074519634246826,
 2.2943153381347656,
 2.300705671310425,
 2.2979023456573486,
 2.300865888595581,
 2.3078534603118896,
 2.307793140411377,
 2.3028671741485596,
 2.2937521934509277,
 2.2877390384674072,
 2.299406051635742,
 2.2978549003601074,
 2.2982327938079834,
 2.2955241203308105,
 2.2986555099487305,
 2.296056032180786,
 2.306704521179199,
 2.3002476692199707,
 2.3088271617889404,
 2.2891926765441895,
 2.3023080825805664,
 2.2999343872070312,
 2.296405076980591,
 2.2974460124969482,
 2.3035881519317627,
 2.293267011642456,
 2.2919235229492188,
 2.293144464492798,
 2.2849910259246826,
 2.286529064178467,
 2.2959659099578857,
 2.285979747772217,
 2.2795515060424805,
 2.289767265319824,
 2.297098159790039,
 2.292921304702759,
 2.2956628799438477,
 2.283576250076294,
 2.2860772