# 16. How to Save and Load a Model

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import os

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

## 16.1 Define Model

In [3]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages plus a two-layer head.

    With unpadded 5x5 kernels and 2x2 pooling, a 3-channel 32x32 input shrinks
    32 -> 28 -> 14 -> 10 -> 5, which yields the 64*5*5 features the first
    linear layer expects. The head returns 10 raw scores per sample (no
    softmax is applied here).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. The sub-module order defines the state_dict keys
        # (conv_layer.0, conv_layer.3), so it must not be rearranged.
        self.conv_layer = nn.Sequential(
            nn.Conv2d(3, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Classifier head operating on the flattened 64*5*5 feature map.
        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 5 * 5, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Tensor of shape (batch, 3, H, W); H = W = 32 matches the
                64*5*5 input size of the first linear layer.

        Returns:
            Tensor of shape (batch, 10) with one raw score per class.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                64*5*5 features per sample.
        """
        out = self.conv_layer(x)
        # Flatten per sample and keep the batch dimension fixed. The previous
        # view(-1, 64*5*5) let the batch dimension absorb a size mismatch, so
        # some wrongly sized inputs silently produced a different batch size.
        out = torch.flatten(out, start_dim=1)
        out = self.fc_layer(out)

        return out
    
# Use the GPU when one is available and fall back to the CPU otherwise,
# so this cell also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)

## 16.2 Save Model

In [4]:
# Option 1: write the whole model object to disk.
# Reading this file back later needs the CNN class to still be defined.
torch.save(obj=model, f="sample1.pth")

  "type " + obj.__name__ + ". It won't be checked "


In [5]:
# Option 2: write only the parameter dictionary (layer name -> tensor).
torch.save(obj=model.state_dict(), f="sample2.pth")

## 16.3 Load Model

In [6]:
# Read back the file that holds the whole model object; the restored
# module is the cell's last expression, so its repr is displayed.
torch.load(f="sample1.pth")

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=1600, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
  )
)

In [7]:
# Loading a fully saved model fails once its class is changed or removed.
del CNN
try:
    torch.load("sample1.pth")
except AttributeError as err:
    # The failure is what this cell demonstrates: show it without raising,
    # so "Restart Kernel -> Run All" can continue to the following cells.
    print(f"{type(err).__name__}: {err}")

AttributeError: Can't get attribute 'CNN' on <module '__main__'>

In [8]:
# The state_dict file can be loaded even if the class is changed or removed.
# In the worst case, the architecture can still be inferred from its entries.
torch.load(f="sample2.pth")

OrderedDict([('conv_layer.0.weight',
              tensor([[[[-0.0990,  0.1003,  0.0784, -0.0463,  0.0550],
                        [ 0.0777,  0.0056, -0.0169,  0.0741,  0.0613],
                        [-0.1140, -0.0130, -0.0947,  0.0434,  0.0541],
                        [ 0.0888, -0.1151, -0.0670, -0.0473,  0.1092],
                        [-0.0753, -0.0176, -0.0103, -0.0004,  0.0531]],
              
                       [[ 0.0235, -0.1059, -0.1046,  0.0692, -0.0182],
                        [-0.0098,  0.0274, -0.0665, -0.0386, -0.0824],
                        [-0.0674, -0.0335, -0.0380, -0.0557,  0.1107],
                        [-0.1061,  0.0850, -0.0545,  0.0390, -0.0513],
                        [-0.0689,  0.0334, -0.0873, -0.0368,  0.0655]],
              
                       [[-0.0539, -0.1048,  0.0259, -0.0206,  0.0901],
                        [ 0.1142,  0.0408,  0.0435, -0.0873,  0.0765],
                        [-0.0734,  0.0803,  0.0538,  0.0710,  0.1020],
        

In [9]:
# Of course, the class definition is still needed in the end.
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages plus a two-layer head.

    With unpadded 5x5 kernels and 2x2 pooling, a 3-channel 32x32 input shrinks
    32 -> 28 -> 14 -> 10 -> 5, which yields the 64*5*5 features the first
    linear layer expects. The head returns 10 raw scores per sample (no
    softmax is applied here).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. The sub-module order defines the state_dict keys
        # (conv_layer.0, conv_layer.3), so it must match the saved checkpoint.
        self.conv_layer = nn.Sequential(
            nn.Conv2d(3, 32, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        # Classifier head operating on the flattened 64*5*5 feature map.
        self.fc_layer = nn.Sequential(
            nn.Linear(64 * 5 * 5, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: Tensor of shape (batch, 3, H, W); H = W = 32 matches the
                64*5*5 input size of the first linear layer.

        Returns:
            Tensor of shape (batch, 10) with one raw score per class.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce
                64*5*5 features per sample.
        """
        out = self.conv_layer(x)
        # Flatten per sample and keep the batch dimension fixed. The previous
        # view(-1, 64*5*5) let the batch dimension absorb a size mismatch, so
        # some wrongly sized inputs silently produced a different batch size.
        out = torch.flatten(out, start_dim=1)
        out = self.fc_layer(out)

        return out
    
# Use the GPU when one is available and fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
# map_location moves the stored tensors onto the selected device, so a
# checkpoint written from a GPU model can also be restored on a CPU-only machine.
model.load_state_dict(torch.load("sample2.pth", map_location=device))

<All keys matched successfully>

In [10]:
model

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=1600, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
  )
)