## 模型保存

In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# NOTE(review): the training cells further down also use `transforms`, `RMBDataset`
# and `LeNet`, which are not imported anywhere in the visible notebook -- presumably
# torchvision.transforms and project-local modules; confirm and import them here.

**模型保存函数**

    torch.save(obj,f)

- obj

    对象(可以是模型，张量，参数)

- f

    输出路径
    
**使用1：保存整个模型及参数**

    torch.save(net,path)
    
**使用2：保存模型参数**

    state_dict=net.state_dict()
    torch.save(state_dict,path)

In [4]:
'''
模型
'''


class LeNet2(nn.Module):
    """LeNet-style CNN: a conv/pool feature extractor followed by an MLP classifier.

    The first classifier layer expects 16 * 5 * 5 inputs, which is what the
    feature extractor yields for 3-channel 32x32 images
    (32 -> 28 -> 14 -> 10 -> 5 through the two conv/pool stages).

    Args:
        classes: Number of output units of the last linear layer.
    """

    def __init__(self, classes):
        super().__init__()

        # Layers are listed in the same order as before so the state_dict keys
        # ("features.0", "features.3", "classifier.0", ...) stay the same.
        conv_stack = [
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.features = nn.Sequential(*conv_stack)

        mlp_head = [
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),
        ]
        self.classifier = nn.Sequential(*mlp_head)

    def forward(self, x):
        """Return the classifier output for a batch `x`; one row per sample."""
        feature_maps = self.features(x)
        # Flatten everything except the batch dimension before the linear layers.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flattened)

    def initialize(self):
        """Fill every parameter with the constant 20191104.

        Used by the notebook as a stand-in for training, so that saved and
        reloaded weights are easy to recognise.
        """
        for param in self.parameters():
            param.data.fill_(20191104)

In [5]:
# Output locations for the two ways of saving
path_model = "./model.pkl"
path_state_dict = "./model_state_dict.pkl"

net = LeNet2(classes=2019)

# "Training": no optimisation happens here; initialize() just overwrites every
# parameter with a constant so the change is visible in the printed kernel.
print("训练前: ", net.features[0].weight[0, ...])
net.initialize()
print("训练后: ", net.features[0].weight[0, ...])

# Way 1: save the whole model object
torch.save(net, path_model)

# Way 2: save only the parameters (the state_dict)
net_state_dict = net.state_dict()
torch.save(net_state_dict, path_state_dict)

训练前:  tensor([[[ 0.0244,  0.0320,  0.0236,  0.0429, -0.0642],
         [-0.0121,  0.0216,  0.0315,  0.0691, -0.1069],
         [ 0.0681,  0.0427,  0.0456, -0.0759, -0.0435],
         [ 0.0251,  0.0125,  0.0396, -0.0089, -0.0837],
         [ 0.0997,  0.0329, -0.0125,  0.0471, -0.0490]],

        [[-0.1042, -0.0166,  0.0783, -0.0810, -0.0347],
         [-0.1133,  0.1128,  0.0365,  0.0527, -0.1127],
         [-0.0754, -0.0348,  0.1081,  0.0082, -0.0328],
         [ 0.0291, -0.0005, -0.0295, -0.1043,  0.0293],
         [ 0.0808,  0.0860, -0.0920,  0.0826,  0.0746]],

        [[ 0.0645,  0.0414,  0.0894,  0.0812, -0.1095],
         [-0.0164, -0.0709, -0.1120, -0.0308, -0.0463],
         [ 0.0895, -0.0027, -0.0596, -0.0174,  0.1027],
         [ 0.1024,  0.1076,  0.0820, -0.0939, -0.0687],
         [ 0.0550, -0.0546,  0.1137,  0.0184, -0.0103]]],
       grad_fn=<SelectBackward>)
训练后:  tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 

  "type " + obj.__name__ + ". It won't be checked "


## 模型加载

**模型加载函数**

    torch.load(f,map_location)

- f

    文件路径
    
- map_location

    指定张量加载到的设备（cpu 或 gpu）

In [6]:
# ================================== load net ===========================
# Load the whole pickled model object (the file written by torch.save(net, path)).
# SECURITY: unpickling a full object can execute arbitrary code -- only load
# files you created yourself or otherwise trust.
path_model = "./model.pkl"
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a complete nn.Module; opt out explicitly for this trusted local file.
    net_load = torch.load(path_model, weights_only=False)
except TypeError:
    # Releases that predate the `weights_only` keyword reject it with TypeError;
    # their default behaviour is already a full unpickle.
    net_load = torch.load(path_model)

print(net_load)

LeNet2(
  (features): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=400, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=2019, bias=True)
  )
)


In [13]:
# Use the reloaded model: push one random 3x32x32 input through it
x = torch.randn(1, 3, 32, 32)
out = net_load(x)
out.shape

torch.Size([1, 2019])

In [14]:
# ================================== load state_dict ===========================
# Read the saved parameter dictionary back from disk
path_state_dict = "./model_state_dict.pkl"
state_dict_load = torch.load(path_state_dict)

# Show the names of the entries stored in the parameter dictionary
print(state_dict_load.keys())

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])


In [15]:
# ================================== update state_dict ===========================
# A freshly constructed network keeps its default initialisation until the
# saved parameters are copied into it.
net_new = LeNet2(classes=2019)

print("加载前: ", net_new.features[0].weight[0, ...])
# Copy the loaded parameter dictionary into the new model
net_new.load_state_dict(state_dict_load)
print("加载后: ", net_new.features[0].weight[0, ...])

加载前:  tensor([[[-0.0671,  0.0104,  0.0223,  0.0123,  0.0489],
         [ 0.0446,  0.0674, -0.0812,  0.0379, -0.0725],
         [ 0.0326, -0.0005, -0.0052,  0.0763,  0.0650],
         [ 0.0428,  0.0140,  0.0897, -0.1132,  0.0294],
         [-0.0729,  0.0889,  0.1053,  0.0051, -0.0066]],

        [[-0.0961, -0.0159,  0.1070,  0.0081,  0.0922],
         [-0.0859, -0.1024, -0.1002, -0.0764, -0.0575],
         [-0.0020,  0.0323,  0.0384, -0.0291,  0.0323],
         [-0.0884, -0.0826, -0.0143, -0.0699,  0.0784],
         [-0.1066,  0.0144, -0.1136,  0.1069,  0.0435]],

        [[ 0.0553, -0.0419, -0.0099, -0.0948, -0.0688],
         [ 0.0023, -0.1136, -0.0485, -0.0658,  0.0608],
         [-0.0482, -0.0460,  0.0321,  0.0239, -0.0422],
         [-0.0965,  0.0671, -0.0280,  0.0786,  0.0731],
         [ 0.0229,  0.0736, -0.0972,  0.0803, -0.0036]]],
       grad_fn=<SelectBackward>)
加载后:  tensor([[[20191104., 20191104., 20191104., 20191104., 20191104.],
         [20191104., 20191104., 20191104., 

## 断点续训练

- 可以在训练完每个epoch之后，保存下epoch,optimizer,net的信息。

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(checkpoint, CHECKPOINT_FILE)

In [None]:
# Run configuration
checkpoint_interval = 5  # write a checkpoint every N epochs
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10  # print running training loss/accuracy every N iterations
val_interval = 1  # run validation every N epochs


# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# Per-channel statistics handed to transforms.Normalize
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, padded random crop and random grayscale,
# then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: no random augmentation.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# Multiply the learning rate by 0.1 every 6 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

# ============================ step 5/5 training ============================
train_curve = list()  # training loss of every iteration
valid_curve = list()  # mean validation loss of every validated epoch

start_epoch = -1
for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # training log: loss averaged over the last `log_interval` iterations
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    # ---- checkpoint saving: begin ----
    if (epoch+1) % checkpoint_interval == 0:

        # The scheduler state is stored as well so that a resumed run can
        # continue the learning-rate schedule exactly where it stopped.
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "scheduler_state_dict": scheduler.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)
    # ---- checkpoint saving: end ----

    # Simulated crash, used to demonstrate resuming from the checkpoint
    if epoch > 5:
        print("训练意外中断...")
        break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            # Record and report the epoch-mean validation loss and the
            # validation accuracy (not the last batch loss / training accuracy).
            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))

- 如果需要从上次状态接着训练的话：

      if resume:
         # 恢复上次的训练状态
         print("Resume from checkpoint...")
         checkpoint = torch.load(CHECKPOINT_FILE)
         net.load_state_dict(checkpoint['model_state_dict'])
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
         initepoch = checkpoint['epoch']+1
         #从上次记录的损失和正确率接着记录
         dict = torch.load(ACC_LOSS_FILE)
         loss_record = dict['loss']
         acc_record = dict['acc']

In [None]:
# Run configuration
checkpoint_interval = 5  # write a checkpoint every N epochs
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10  # print running training loss/accuracy every N iterations
val_interval = 1  # run validation every N epochs


# ============================ step 1/5 data ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# Per-channel statistics handed to transforms.Normalize
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, padded random crop and random grayscale,
# then tensor conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: no random augmentation.
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Build the dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# Multiply the learning rate by 0.1 every 6 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)


# ============================ step 5+/5 resume from checkpoint ============================

path_checkpoint = "./checkpoint_4_epoch.pkl"
checkpoint = torch.load(path_checkpoint)

net.load_state_dict(checkpoint['model_state_dict'])

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

start_epoch = checkpoint['epoch']  # index of the last epoch that finished

if 'scheduler_state_dict' in checkpoint:
    # Exact restore when the checkpoint carries the scheduler state.
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
else:
    # Checkpoint without scheduler state: scheduler.step() runs once at the end
    # of every epoch, so after epochs 0..start_epoch it has been called
    # start_epoch + 1 times. Using start_epoch here would delay the next
    # learning-rate decay by one epoch.
    scheduler.last_epoch = start_epoch + 1

# ============================ step 5/5 training ============================
train_curve = list()  # training loss of every iteration
valid_curve = list()  # mean validation loss of every validated epoch

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # training log: loss averaged over the last `log_interval` iterations
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # update the learning rate

    if (epoch+1) % checkpoint_interval == 0:

        # Key names and file name pattern match what the resume step above
        # reads ("optimizer_state_dict", "./checkpoint_{}_epoch.pkl").
        checkpoint = {"model_state_dict": net.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "scheduler_state_dict": scheduler.state_dict(),
                      "loss": loss.item(),  # last training batch loss, stored as a plain float
                      "epoch": epoch}
        path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
        torch.save(checkpoint, path_checkpoint)

    # if epoch > 5:
    #     print("训练意外中断...")
    #     break

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                loss_val += loss.item()

            # Record and report the epoch-mean validation loss and the
            # validation accuracy (not the last batch loss / training accuracy).
            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))