출처 : https://tutorials.pytorch.kr/beginner/saving_loading_models.html

In [1]:
import numpy as np
from copy import deepcopy

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


In [2]:
class TheModelClass(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first convolution takes 3 input channels and the last linear layer
    produces 10 values per sample. The first linear layer expects
    16 * 5 * 5 features, so the pooled feature map is reshaped to that width.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in this order, which is also the key
        # order reported by ``state_dict()``.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc3``."""
        # Feature extractor: conv -> ReLU -> shared max-pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Reshape to rows of 16 * 5 * 5 features for the linear layers.
        flat = x.view(-1, 16 * 5 * 5)

        # The two hidden linear layers use ReLU; the last layer's output is
        # returned without an activation.
        for hidden in (self.fc1, self.fc2):
            flat = F.relu(hidden(flat))
        return self.fc3(flat)


In [18]:
# Instantiate the model.
model = TheModelClass()

# The weight of an individual layer is reachable as an attribute of that layer.
conv1_weight = model.conv1.weight

# The weights of the whole model, as an OrderedDict.
model_weights = model.state_dict()

# Print one "<name> \t <size>" line per entry.
for param_name, param_value in model_weights.items():
    print(param_name, "\t", param_value.size())

conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])


In [19]:
# Save the state_dict. (The model object itself can also be saved, using the
# same file extension.)
filepath = 'temp.pt'  # '.pth' is used as well
torch.save(model.state_dict(), filepath)

# Load. The model class has to be declared somewhere so that an instance
# exists to load the weights into.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
restored_state = torch.load(filepath)
model.load_state_dict(restored_state)

# An optimizer can be saved and loaded in the same way.

# When loading into a model with a different structure, pass strict=False so
# that keys which do not match are ignored.
model2 = TheModelClass()
restored_state_for_model2 = torch.load(filepath)
model2.load_state_dict(restored_state_for_model2, strict=False)

<All keys matched successfully>

In [None]:
# Save a checkpoint that can be used for inference or to resume training.
# This cell is a template: epoch, model, optimizer, loss and PATH must already
# be defined by the surrounding training code.
torch.save(
    {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        # ... add any further entries to store here. (A bare `...` inside a
        # dict literal is a SyntaxError, so the placeholder is a comment.)
    },
    PATH,
)

# Restore every piece from the saved checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [None]:
# Save several models (and their optimizers) in a single file.
# This cell is a template: modelA, modelB, optimizerA, optimizerB and PATH
# must already be defined by the surrounding code.
torch.save(
    {
        'modelA_state_dict': modelA.state_dict(),
        'modelB_state_dict': modelB.state_dict(),
        'optimizerA_state_dict': optimizerA.state_dict(),
        'optimizerB_state_dict': optimizerB.state_dict(),
        # ... add any further entries to store here. (A bare `...` inside a
        # dict literal is a SyntaxError, so the placeholder is a comment.)
    },
    PATH,
)

# Restore each model and optimizer from its own entry.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

## 학습과정 중간에 best model을 저장
validation loss에 따라 가장 성능이 좋은 모델을 보관하면서 학습은 정상적으로 이어 나가려면 deepcopy를 써야 한다. <br>model.state_dict()는 파라미터 텐서를 복사하지 않고 참조로 반환하므로, best_model_wts = model.state_dict()만 사용한다면 best_model_wts가 이후 학습 단계에서 계속 갱신된다.

In [None]:
# Keep a snapshot of the weights whenever the score improves. The snapshot is
# a deepcopy so that later training steps do not keep updating it.
if score > best_score:
    best_model_wts = deepcopy(model.state_dict())
    best_score = score

In [2]:
# Use the EarlyStopping class from pytorchtools --> it also has the ability to
# save the best model.
from source.cifar_dataloader import get_cifar10_dataloader
from source import functions
from source import EarlyStopping

# Early-stopping helper; it is handed to functions.train_model further down.
early_stopping = EarlyStopping.EarlyStopping(patience=7, verbose=False)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_classes = 10
batch_size = 8
num_epochs = 15
model_name = 'resnet'
path = './../data/'

# get_model returns the model together with the input size that is passed on
# to the data loader below.
models, input_size = functions.get_model(model_name, pretrained=True, transfer=True, feature_extract=True, num_classes=num_classes)
# NOTE(review): batch_size above is not used in this cell; the literals 16 and
# 4 are passed instead. Confirm what they mean against
# get_cifar10_dataloader's signature.
dataloader = get_cifar10_dataloader(path, input_size, 16, 4)

  return torch._C._cuda_getDeviceCount() > 0


Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Loss function for training.
criterion = nn.CrossEntropyLoss()

# Move the model to the selected device, then build an SGD optimizer over its
# parameters.
model = models.to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Returns the best model -> see source.functions.
uses_inception = model_name == "inception"
model_ft, hist = functions.train_model(
    model=models,
    dataloaders=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    is_inception=uses_inception,
    early_stop=early_stopping,
)
    