In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [41]:
def save_model(model, optim, epoch, loss, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),         # all the parameters of the network
        'optimizer_state_dict': optim.state_dict(),     # gradients, etc.
        'loss': loss
        }, path)

def load_model(model, optim, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optim.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optim, epoch, loss

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
train_dataset = datasets.MNIST(root='/data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='/data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=True)

In [23]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.drop = nn.Dropout(p=0.9)
        self.conv1 = nn.Conv2d(1, 6, 5)         # (28x28) -> (24x24)
        self.conv2 = nn.Conv2d(6, 16, 5)        # (24x24) -> (20x20)
        self.fc2 = nn.Linear(16*20*20, 10)

    def forward(self, x):
        x = F.relu(self.drop(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.fc2(x.view(-1, 16*20*20))
        return x

In [64]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)

In [65]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3)

In [66]:
epochs = 5
for epoch in range(epochs):
    for _, samples in enumerate(train_loader):
        images, labels = samples
        images = images.to(device)
        labels = labels.to(device)

        output = net(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('epoch: {:2d}/{}, \t loss: {:.3f}'.format(epoch+1, epochs, loss))

epoch:  1/5, 	 loss: 0.898
epoch:  2/5, 	 loss: 0.876
epoch:  3/5, 	 loss: 0.849
epoch:  4/5, 	 loss: 0.460
epoch:  5/5, 	 loss: 0.585


In [67]:
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

conv1.weight 	 torch.Size([6, 1, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc2.weight 	 torch.Size([10, 6400])
fc2.bias 	 torch.Size([10])


In [37]:
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

state 	 {0: {'momentum_buffer': None}, 1: {'momentum_buffer': None}, 2: {'momentum_buffer': None}, 3: {'momentum_buffer': None}, 4: {'momentum_buffer': None}, 5: {'momentum_buffer': None}}
param_groups 	 [{'lr': 0.001, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5]}]


In [68]:
print('SAVING MODEL')
save_model(net, optimizer, 10, loss, './model')

SAVING MODEL


In [69]:
net = Net().to(device)

In [70]:
print('LOADING MODEL')
net, optimizer, epochs, loss = load_model(net, optimizer, './model')
net.train()
print(epochs)

LOADING MODEL
10


In [52]:
correct = 0
total = 0

for _, samples in enumerate(test_loader):
    images, labels = samples
    images = images.to(device)
    labels = labels.to(device)

    output = net(images)
    predictions = output.argmax(dim=1)

    correct += (predictions == labels).sum()
    total += labels.shape[0]

print(f'accuracy: {correct / total * 100:.3f}%')

    

accuracy: 86.318%


In [34]:
class Resnet(nn.Module):

    def __init__(self):
        super(Resnet, self).__init__()
        # first layer
        self.fc1 = nn.Linear(1*28*28, 1*28*28)
        # residual layer
        self.conv1 = nn.Conv2d(1, 1, 5, padding=2)
        # last layer
        self.fc2 = nn.Linear(1*28*28, 10)

    def forward(self, x):
        x_shape = x.shape
        x = F.relu(self.fc1(x.view(-1, 1*28*28))).view(x_shape)

        # resnet block
        x = F.relu(self.conv1(x)) + x

        x = self.fc2(x.view(-1, 1*28*28))       # vectorizing x to forward to the linear layer
        return x

net = Resnet().to(device)
t = torch.randn((64, 1, 28, 28)).to(device)
net(t).shape

torch.Size([64, 10])