# saveとload

## early stopping

In [1]:
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import Dataset
from sklearn import datasets
from sklearn.model_selection import train_test_split

In [24]:
class MyDataset(Dataset):
    """Map-style dataset that pairs each sample in ``X`` with its label in ``y``.

    Args:
        X: Indexable collection of samples; its length defines the dataset size.
        y: Indexable collection of labels, indexed in parallel with ``X``.
        transform: Optional callable applied to the sample (never to the label)
            each time an item is fetched.
    """

    def __init__(self, X, y, transform = None):
        self.X, self.y = X, y
        self.transform = transform

    def __len__(self):
        # The number of samples is taken from X alone.
        return len(self.X)

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        # No transform supplied: hand the raw sample back unchanged.
        if not self.transform:
            return sample, label
        return self.transform(sample), label

class MLP(nn.Module):
    """Two-layer perceptron: flatten -> linear -> ReLU -> linear.

    Args:
        num_in: Number of input features after flattening.
        num_hidden: Width of the hidden layer.
        num_out: Number of output units.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        # Collapse every dimension after the first into a single feature axis.
        self.flatten = nn.Flatten(start_dim = 1, end_dim = -1)
        self.l1 = nn.Linear(in_features = num_in, out_features = num_hidden) # hidden layer
        self.l2 = nn.Linear(in_features = num_hidden, out_features = num_out) # output layer

    def forward(self, x):
        """Return the raw outputs of ``l2`` (no activation is applied to them)."""
        flat = self.flatten(x)
        hidden = F.relu(self.l1(flat))
        return self.l2(hidden)

def learn(model, train_loader, val_loader, opt, loss_func, num_epoch, early_stopping = None):
    """Train ``model`` and validate it after every epoch, with optional early stopping.

    Args:
        model: ``nn.Module`` to optimise in place.
        train_loader: Iterable yielding ``(X, y)`` training batches.
        val_loader: Iterable yielding ``(X, y)`` validation batches.
        opt: Optimizer that updates ``model``'s parameters.
        loss_func: Callable ``loss_func(preds, y)`` returning a scalar loss tensor.
        num_epoch: Maximum number of epochs to run.
        early_stopping: Patience, i.e. the number of consecutive epochs without
            a new best validation loss after which training stops. ``None``
            (or 0) disables early stopping.

    Returns:
        Tuple ``(train_losses, val_losses, val_accuracies)``; each list has one
        entry per completed epoch, computed as the mean of the per-batch values.

    Raises:
        ValueError: If either loader yields no batches in an epoch.
    """
    # Per-epoch logs.
    train_losses = []
    val_losses = []
    val_accuracies = []
    # Number of consecutive epochs without a new best validation loss.
    no_improve = 0
    best_val_loss = float('inf')

    for epoch in range(num_epoch):
        running_loss = 0
        running_val_loss = 0
        running_val_accuracy = 0
        # Count batches explicitly instead of relying on a loop variable that
        # would be unbound when a loader is empty.
        num_train_batches = 0
        num_val_batches = 0

        # Training mode: layers such as dropout / batch norm behave stochastically.
        model.train()
        for X, y in train_loader:
            opt.zero_grad() # reset gradients
            # forward
            preds = model(X)
            loss = loss_func(preds, y)
            running_loss += loss.item()

            # backward
            loss.backward()
            opt.step() # update parameters
            num_train_batches += 1

        # Evaluation mode: validation must not be perturbed by training-only behaviour.
        model.eval()
        with torch.no_grad():
            for X_val, y_val in val_loader:
                preds_val = model(X_val)
                val_loss = loss_func(preds_val, y_val)
                running_val_loss += val_loss.item()
                # Fraction of samples in this batch whose arg-max prediction matches the label.
                val_accuracy = torch.sum(torch.argmax(preds_val, dim = -1) == y_val) / y_val.shape[0]
                running_val_accuracy += val_accuracy.item()
                num_val_batches += 1

        if num_train_batches == 0 or num_val_batches == 0:
            raise ValueError('train_loader and val_loader must each yield at least one batch')

        train_losses.append(running_loss / num_train_batches)
        val_losses.append(running_val_loss / num_val_batches)
        val_accuracies.append(running_val_accuracy / num_val_batches)
        print(f'epoch:{epoch}, train error:{train_losses[-1]}, val_losses:{val_losses[-1]}, val_accuracy:{val_accuracies[-1]}')
        if val_losses[-1] < best_val_loss:
            no_improve = 0
            best_val_loss = val_losses[-1]
        else:
            no_improve += 1
        if early_stopping and no_improve >= early_stopping:
            print("stop training because val loss don't improve anymore")
            break

    return train_losses, val_losses, val_accuracies

In [25]:
# Data preparation: load the scikit-learn digits dataset.
dataset = datasets.load_digits()
data = dataset['data']
target = dataset['target']
images = dataset['images']
images = images * (255. / 16.) # rescale pixel values 0~16 -> 0~255
images = images.astype(np.uint8)
# Split into training and validation sets (20% held out, fixed random_state).
X_train, X_val, y_train, y_val = train_test_split(images, target, test_size = 0.2, random_state = 0)
# Build the Datasets and DataLoaders.
batch_size = 32
# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])
train_dataset = MyDataset(X_train, y_train, transform = transform)
val_dataset = MyDataset(X_val, y_val, transform = transform)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = 2)
val_loader = DataLoader(val_dataset, batch_size = batch_size, num_workers = 2)
# Initialise the model: 64 inputs, 30 hidden units, 10 outputs.
model = MLP(64, 30, 10)
# Define the optimizer. learning_rate is assigned BEFORE it is used so that the
# cell runs on a fresh kernel instead of depending on a leftover variable.
learning_rate = 0.03
opt = optim.SGD(model.parameters(), lr = learning_rate)

train_losses, val_losses, val_accuracies = learn(model, train_loader, val_loader, opt = opt, loss_func = F.cross_entropy, num_epoch = 1000, early_stopping=5)

epoch:0, train error:2.1843870057000054, val_losses:2.0654842456181846, val_accuracy:0.4036458333333333
epoch:1, train error:1.866288505660163, val_losses:1.7236454288164775, val_accuracy:0.6328125
epoch:2, train error:1.476594238811069, val_losses:1.3486510713895161, val_accuracy:0.7317708333333334
epoch:3, train error:1.099349820613861, val_losses:1.0151362419128418, val_accuracy:0.8463541666666666
epoch:4, train error:0.8169830269283719, val_losses:0.7726800094048182, val_accuracy:0.9010416666666666
epoch:5, train error:0.6271594590610928, val_losses:0.6124205912152926, val_accuracy:0.921875
epoch:6, train error:0.506532926691903, val_losses:0.5097397714853287, val_accuracy:0.9348958333333334
epoch:7, train error:0.421243777539995, val_losses:0.442648025850455, val_accuracy:0.9296875
epoch:8, train error:0.36365310847759247, val_losses:0.3789863313237826, val_accuracy:0.9348958333333334
epoch:9, train error:0.3212355573972066, val_losses:0.34919851397474605, val_accuracy:0.934895833

## モデルオブジェクトの保存とロード

In [27]:
# Save the whole model object (not just its parameters) to a file.
torch.save(model, 'sample_model.pth')

In [28]:
# Load the whole pickled model object back from disk.
# weights_only=False is passed explicitly: on PyTorch >= 2.6 torch.load defaults
# to weights_only=True, which refuses to unpickle arbitrary classes such as MLP.
# NOTE: this executes pickle code, so only load files from a trusted source.
loaded_model = torch.load('sample_model.pth', weights_only=False)

In [29]:
# Display the reloaded model to check that its layers were restored.
loaded_model

MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (l1): Linear(in_features=64, out_features=30, bias=True)
  (l2): Linear(in_features=30, out_features=10, bias=True)
)

## モデルパラメータの保存とロード
- モデルのオブジェクトを保存するよりモデルパラメータを保存することが推奨されている

In [43]:
# Grab the model's parameters as a state_dict.
params = model.state_dict() # by contrast, .parameters() returns an iterator over the parameters

In [44]:
# Create a fresh MLP with the same layer sizes to receive the saved parameters.
another_model = MLP(64, 30, 10)

In [45]:
# Load the trained model's parameters into the fresh model.
another_model.load_state_dict(params)

<All keys matched successfully>

In [46]:
# Show the trained model's first-layer weights, for comparison with another_model.
model.l1.weight

Parameter containing:
tensor([[ 1.1282e-01, -6.3440e-02,  7.6734e-02,  ..., -4.8236e-02,
          7.9601e-02, -8.7400e-02],
        [ 5.9361e-02, -1.7456e-04,  1.8106e-01,  ...,  1.0305e-01,
          8.9167e-02,  4.8587e-03],
        [ 6.1278e-03, -1.2902e-01,  7.3366e-02,  ...,  6.2912e-03,
          9.6055e-02,  7.9921e-02],
        ...,
        [-1.9209e-01, -8.3390e-02, -2.9702e-01,  ...,  2.4603e-02,
          3.7131e-02,  1.3476e-01],
        [-1.8528e-01, -5.5052e-02,  4.7711e-02,  ...,  5.0124e-01,
          6.0677e-02,  9.4646e-02],
        [-1.1786e-01, -1.3595e-01,  7.4357e-02,  ..., -8.5267e-02,
         -1.3104e-01, -1.6461e-01]], requires_grad=True)

In [47]:
# Show another_model's first-layer weights; they should match model.l1.weight.
another_model.l1.weight

Parameter containing:
tensor([[ 1.1282e-01, -6.3440e-02,  7.6734e-02,  ..., -4.8236e-02,
          7.9601e-02, -8.7400e-02],
        [ 5.9361e-02, -1.7456e-04,  1.8106e-01,  ...,  1.0305e-01,
          8.9167e-02,  4.8587e-03],
        [ 6.1278e-03, -1.2902e-01,  7.3366e-02,  ...,  6.2912e-03,
          9.6055e-02,  7.9921e-02],
        ...,
        [-1.9209e-01, -8.3390e-02, -2.9702e-01,  ...,  2.4603e-02,
          3.7131e-02,  1.3476e-01],
        [-1.8528e-01, -5.5052e-02,  4.7711e-02,  ...,  5.0124e-01,
          6.0677e-02,  9.4646e-02],
        [-1.1786e-01, -1.3595e-01,  7.4357e-02,  ..., -8.5267e-02,
         -1.3104e-01, -1.6461e-01]], requires_grad=True)

In [48]:
# Save only the parameters (state_dict) to a file, rather than the whole model object.
torch.save(model.state_dict(), 'sample_model_state_dict.pth')

In [49]:
# Read the saved state_dict back from the file and load it into another_model.
another_model.load_state_dict(torch.load('sample_model_state_dict.pth'))

<All keys matched successfully>

In [51]:
# The optimizer has a state_dict too (per-parameter state plus param_groups settings such as lr).
opt.state_dict()

{'state': {0: {'momentum_buffer': None},
  1: {'momentum_buffer': None},
  2: {'momentum_buffer': None},
  3: {'momentum_buffer': None}},
 'param_groups': [{'lr': 0.03,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'maximize': False,
   'foreach': None,
   'differentiable': False,
   'params': [0, 1, 2, 3]}]}

## 学習ループ中に最良のモデルを保存する

In [60]:
class MyDataset(Dataset):
    """Map-style dataset that pairs each sample in ``X`` with its label in ``y``.

    Args:
        X: Indexable collection of samples; its length defines the dataset size.
        y: Indexable collection of labels, indexed in parallel with ``X``.
        transform: Optional callable applied to the sample (never to the label)
            each time an item is fetched.
    """

    def __init__(self, X, y, transform = None):
        self.X, self.y = X, y
        self.transform = transform

    def __len__(self):
        # The number of samples is taken from X alone.
        return len(self.X)

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        # No transform supplied: hand the raw sample back unchanged.
        if not self.transform:
            return sample, label
        return self.transform(sample), label

class MLP(nn.Module):
    """Two-layer perceptron: flatten -> linear -> ReLU -> linear.

    Args:
        num_in: Number of input features after flattening.
        num_hidden: Width of the hidden layer.
        num_out: Number of output units.
    """

    def __init__(self, num_in, num_hidden, num_out):
        super().__init__()
        # Collapse every dimension after the first into a single feature axis.
        self.flatten = nn.Flatten(start_dim = 1, end_dim = -1)
        self.l1 = nn.Linear(in_features = num_in, out_features = num_hidden) # hidden layer
        self.l2 = nn.Linear(in_features = num_hidden, out_features = num_out) # output layer

    def forward(self, x):
        """Return the raw outputs of ``l2`` (no activation is applied to them)."""
        flat = self.flatten(x)
        hidden = F.relu(self.l1(flat))
        return self.l2(hidden)

def learn(model, train_loader, val_loader, opt, loss_func, num_epoch, early_stopping = None, save_path = None):
    """Train ``model``, validate every epoch, and checkpoint the best epoch.

    Whenever the epoch's validation loss is a new best and ``save_path`` is
    given, the model and optimizer state at that moment are written to
    ``save_path`` (overwriting the previous checkpoint). The checkpoint
    therefore really holds the best-epoch weights. ``model`` itself is left
    with the weights of the last epoch that ran; load the checkpoint to get
    the best ones back.

    Args:
        model: ``nn.Module`` to optimise in place.
        train_loader: Iterable yielding ``(X, y)`` training batches.
        val_loader: Iterable yielding ``(X, y)`` validation batches.
        opt: Optimizer that updates ``model``'s parameters.
        loss_func: Callable ``loss_func(preds, y)`` returning a scalar loss tensor.
        num_epoch: Maximum number of epochs to run.
        early_stopping: Patience, i.e. the number of consecutive epochs without
            a new best validation loss after which training stops. ``None``
            (or 0) disables early stopping.
        save_path: Optional file path for the checkpoint. The saved object is a
            dict with keys ``'model_parameter'``, ``'opt_parameter'`` and
            ``'val_loss'``. Nothing is written when it is ``None``.

    Returns:
        Tuple ``(train_losses, val_losses, val_accuracies)``; each list has one
        entry per completed epoch, computed as the mean of the per-batch values.

    Raises:
        ValueError: If either loader yields no batches in an epoch.
    """
    # Per-epoch logs.
    train_losses = []
    val_losses = []
    val_accuracies = []
    # Number of consecutive epochs without a new best validation loss.
    no_improve = 0
    best_val_loss = float('inf')

    for epoch in range(num_epoch):
        running_loss = 0
        running_val_loss = 0
        running_val_accuracy = 0
        # Count batches explicitly instead of relying on a loop variable that
        # would be unbound when a loader is empty.
        num_train_batches = 0
        num_val_batches = 0

        # Training mode: layers such as dropout / batch norm behave stochastically.
        model.train()
        for X, y in train_loader:
            opt.zero_grad() # reset gradients
            # forward
            preds = model(X)
            loss = loss_func(preds, y)
            running_loss += loss.item()

            # backward
            loss.backward()
            opt.step() # update parameters
            num_train_batches += 1

        # Evaluation mode: validation must not be perturbed by training-only behaviour.
        model.eval()
        with torch.no_grad():
            for X_val, y_val in val_loader:
                preds_val = model(X_val)
                val_loss = loss_func(preds_val, y_val)
                running_val_loss += val_loss.item()
                # Fraction of samples in this batch whose arg-max prediction matches the label.
                val_accuracy = torch.sum(torch.argmax(preds_val, dim = -1) == y_val) / y_val.shape[0]
                running_val_accuracy += val_accuracy.item()
                num_val_batches += 1

        if num_train_batches == 0 or num_val_batches == 0:
            raise ValueError('train_loader and val_loader must each yield at least one batch')

        train_losses.append(running_loss / num_train_batches)
        val_losses.append(running_val_loss / num_val_batches)
        val_accuracies.append(running_val_accuracy / num_val_batches)
        print(f'epoch:{epoch}, train error:{train_losses[-1]}, val_losses:{val_losses[-1]}, val_accuracy:{val_accuracies[-1]}')
        if val_losses[-1] < best_val_loss:
            no_improve = 0
            best_val_loss = val_losses[-1]
            # Checkpoint right now: keeping a reference to `model` would alias
            # the live module, whose weights keep changing in later epochs.
            if save_path:
                torch.save({'model_parameter' : model.state_dict(), 'opt_parameter' : opt.state_dict(), 'val_loss' : best_val_loss}, save_path)
        else:
            no_improve += 1
        if early_stopping and no_improve >= early_stopping:
            print("stop training because val loss don't improve anymore")
            break

    return train_losses, val_losses, val_accuracies

In [61]:
# Data preparation: load the scikit-learn digits dataset.
dataset = datasets.load_digits()
data = dataset['data']
target = dataset['target']
images = dataset['images']
images = images * (255. / 16.) # rescale pixel values 0~16 -> 0~255
images = images.astype(np.uint8)
# Split into training and validation sets (20% held out, fixed random_state).
X_train, X_val, y_train, y_val = train_test_split(images, target, test_size = 0.2, random_state = 0)
# Build the Datasets and DataLoaders.
batch_size = 32
# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5,), (0.5,))])
train_dataset = MyDataset(X_train, y_train, transform = transform)
val_dataset = MyDataset(X_val, y_val, transform = transform)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = 2)
val_loader = DataLoader(val_dataset, batch_size = batch_size, num_workers = 2)
# Initialise the model: 64 inputs, 30 hidden units, 10 outputs.
model = MLP(64, 30, 10)
# Define the optimizer. learning_rate is assigned BEFORE it is used so that the
# cell runs on a fresh kernel instead of depending on a leftover variable.
learning_rate = 0.03
opt = optim.SGD(model.parameters(), lr = learning_rate)

# Train with early stopping and write the checkpoint to the file 'checkpoint'.
train_losses, val_losses, val_accuracies = learn(model, train_loader, val_loader, opt = opt, loss_func = F.cross_entropy, num_epoch = 1000, early_stopping=5, save_path = 'checkpoint')

epoch:0, train error:2.2394752979278563, val_losses:2.1214182376861572, val_accuracy:0.2942708333333333
epoch:1, train error:1.946669496430291, val_losses:1.798375556866328, val_accuracy:0.6171875
epoch:2, train error:1.5539350509643555, val_losses:1.3932238121827443, val_accuracy:0.7630208333333334
epoch:3, train error:1.158011163605584, val_losses:1.0601645509401958, val_accuracy:0.8098958333333334
epoch:4, train error:0.8614217082659403, val_losses:0.8103910237550735, val_accuracy:0.8776041666666666
epoch:5, train error:0.6634238587485419, val_losses:0.6546802371740341, val_accuracy:0.8854166666666666
epoch:6, train error:0.5366636315981547, val_losses:0.5372151459256808, val_accuracy:0.9244791666666666
epoch:7, train error:0.44715239604314166, val_losses:0.4701428363720576, val_accuracy:0.8932291666666666
epoch:8, train error:0.3863742378022936, val_losses:0.4113098258773486, val_accuracy:0.9322916666666666
epoch:9, train error:0.34036011265383825, val_losses:0.3672925891975562, va

In [62]:
# Load the checkpoint dict that learn(..., save_path='checkpoint') wrote to disk.
saved_model_dict = torch.load('checkpoint')

In [63]:
# Inspect the checkpoint: keys 'model_parameter', 'opt_parameter' and 'val_loss'.
saved_model_dict

{'model_parameter': OrderedDict([('l1.weight',
               tensor([[-0.0242, -0.1255, -0.0174,  ...,  0.0572,  0.0815,  0.0332],
                       [ 0.0369, -0.0697,  0.0556,  ..., -0.0340,  0.1218,  0.1237],
                       [-0.0819, -0.0244,  0.0521,  ..., -0.0984,  0.1144, -0.0518],
                       ...,
                       [-0.1152, -0.1422, -0.3021,  ...,  0.2067,  0.0325, -0.0742],
                       [-0.1323, -0.2531, -0.3395,  ..., -0.1320, -0.1352, -0.0737],
                       [-0.1395, -0.1404, -0.0083,  ..., -0.0617,  0.0605, -0.0538]])),
              ('l1.bias',
               tensor([ 0.0377,  0.0483,  0.0062,  0.1766, -0.0049,  0.0060,  0.0848, -0.0733,
                        0.0310,  0.0682,  0.1656, -0.0051, -0.0700,  0.0230,  0.1556,  0.1181,
                        0.1604,  0.0165,  0.1670,  0.0256,  0.1723, -0.0251,  0.0231,  0.1918,
                        0.0003,  0.0688, -0.0796,  0.0889,  0.1414,  0.0748])),
              ('l2.we