In [None]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import utils
from LWENet import lwenet
from train_test_valid import train, test, valid, save_model

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device  # bare last expression: echo the selected device as the cell output

In [None]:
# Training images: two folders that are handed together to utils.DatasetPair.
train_cover_path, train_stego_path = (
    "data/BOSS_train/cover",
    "data/BOSS_train/stego",
)

# Validation / test roots, read with torchvision's ImageFolder.
valid_path, test_path = "data/BOSS_valid", "data/BOSS_test"

# Location prefix under which checkpoint files are written.
model_save_path = "models/"

参数

In [None]:
# Mini-batch size for each data split.
batch_size = dict(train=10, valid=50, test=50)

# Optimizer settings.
lr, momentum = 0.01, 0.9
weight_decay = 0.001

# Run length and bookkeeping.
epochs = 80
log_interval = 50   # print progress every this many batches
save_interval = 10  # write a checkpoint every this many epochs

# Tag embedded in checkpoint and plot file names.
train_title = "example"

数据增强

In [None]:
# Training-time pipeline built from two project-local transforms, applied in order.
# NOTE(review): presumably augmentation followed by tensor conversion — the exact
# behaviour of utils.AugData / utils.ToTensor lives in utils.py; confirm there.
train_transform = transforms.Compose([
    utils.AugData(),
    utils.ToTensor(),
])

In [None]:
# Extra DataLoader options, only set when CUDA is available: load in the main
# process (num_workers=0) and use pinned host memory for the batches.
kwargs = {'num_workers': 0, 'pin_memory': True} if torch.cuda.is_available() else {}

# Validation and test images share one preprocessing pipeline:
# convert to single-channel grayscale, then to a tensor.
eval_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

train_data = utils.DatasetPair(train_cover_path, train_stego_path, train_transform)
valid_data = datasets.ImageFolder(valid_path, transform=eval_transform)
test_data = datasets.ImageFolder(test_path, transform=eval_transform)

# Only the training loader shuffles. Evaluation loaders keep a fixed order so that
# repeated runs visit the samples identically and per-batch logs are comparable.
train_loader = DataLoader(train_data, batch_size=batch_size['train'], shuffle=True, **kwargs)
valid_loader = DataLoader(valid_data, batch_size=batch_size['valid'], shuffle=False, **kwargs)
test_loader = DataLoader(test_data, batch_size=batch_size['test'], shuffle=False, **kwargs)

加载模型

In [None]:
# Instantiate the network from the project-local LWENet module, using its default constructor arguments.
model = lwenet()

In [None]:
# Move the model's parameters and buffers onto the device chosen earlier in the notebook.
model = model.to(device)

性能优化

In [None]:
# Do not restrict cuDNN to deterministic algorithms; results may vary slightly between runs.
torch.backends.cudnn.deterministic = False
# Let cuDNN benchmark the available convolution algorithms and pick the fastest one
# for each input configuration it encounters.
torch.backends.cudnn.benchmark = True

使用Kaiming方法初始化权重

In [None]:
def initWeights(module):
    """Re-initialise the weight of a trainable plain ``nn.Conv2d`` with Kaiming-normal values.

    Meant to be passed to ``nn.Module.apply``. A module is modified only when its
    exact type is ``nn.Conv2d`` (subclasses are skipped) and its weight has
    ``requires_grad`` set. Only the weight is touched; the bias is left as it is.

    Args:
        module: the sub-module currently being visited.

    Returns:
        None. The weight is overwritten in place.
    """
    # Exact type match: anything that is not precisely nn.Conv2d is ignored.
    if type(module) is not nn.Conv2d:
        return
    # Frozen kernels keep their current values.
    if not module.weight.requires_grad:
        return
    nn.init.kaiming_normal_(module.weight.data, mode='fan_in', nonlinearity='relu')
            
            
# Run initWeights on `model` and each of its sub-modules.
# The trailing semicolon suppresses the notebook's echo of the value returned by apply().
model.apply(initWeights);

筛选出需要训练的参数：对多维参数（`dim() != 1`）做权重衰减，一维参数单独成组、不显式设置权重衰减

In [None]:
# Split the trainable parameters by dimensionality: tensors with dim() != 1 go into
# the group that carries an explicit weight decay, 1-D tensors into a group without one.
trainable_params = [p for p in model.parameters() if p.requires_grad]
params_wd = [p for p in trainable_params if p.dim() != 1]
params_rest = [p for p in trainable_params if p.dim() == 1]

param_groups = [
    {'params': params_wd, 'weight_decay': weight_decay},
    {'params': params_rest},
]

使用SGD


In [None]:
# SGD with momentum over the parameter groups. No optimizer-level weight_decay is
# passed here, so a group only decays if it carries its own 'weight_decay' entry.
optimizer = optim.SGD(
    param_groups,
    lr=lr,
    momentum=momentum,
)

在抵达milestones时将学习率衰减为原来的$\gamma$倍

In [None]:
# Milestones at which MultiStepLR multiplies the learning rate by gamma (0.1).
# NOTE(review): `epochs` is set to 80 in the hyper-parameter cell; if the scheduler is
# stepped once per epoch, no training epoch would run with a decayed learning rate.
# Confirm how train() steps the scheduler and whether this is intended.
DECAY_EPOCH = [80, 140, 180]
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=DECAY_EPOCH,
    gamma=0.1,
)

训练

In [None]:
# Per-epoch validation metrics, kept for the plots at the end of the notebook.
valid_acc_list = []
valid_loss_list = []
# Lowest validation loss seen so far; starts at +inf so the first epoch always counts as "best".
best_valid_loss = float('inf')

# Checkpoint paths are built by appending a file name to model_save_path, so the
# directory part of that prefix must exist before the first save. Create it up front
# instead of failing after a full training epoch.
os.makedirs(os.path.dirname(model_save_path) or '.', exist_ok=True)

for epoch in range(epochs):
    train(model, epoch, train_loader, batch_size['train'], device, optimizer, scheduler, log_interval)
    valid_acc, valid_loss = valid(model, device, valid_loader)
    valid_acc_list.append(valid_acc)
    valid_loss_list.append(valid_loss)

    # Periodic checkpoint every `save_interval` epochs.
    # The file name carries the zero-based epoch index (e.g. "...9.pth" after the 10th epoch).
    if (epoch + 1) % save_interval == 0:
        model_path = model_save_path + train_title + str(epoch) + '.pth'
        save_model(model, model_path)
        print("model saved at {}".format(model_path))

    # Keep a separate copy of the weights with the lowest validation loss;
    # this file is overwritten each time the loss improves.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        model_path = model_save_path + train_title + '_BEST' + '.pth'
        save_model(model, model_path)
        print("best model saved at {}".format(model_path))

In [None]:
# Evaluate on the BOSS test loader using the in-memory `model` as left by the training loop;
# no checkpoint (e.g. the `_BEST` file) is reloaded in this cell.
test(model, device, test_loader)

In [None]:
# Second evaluation set, read from its own ImageFolder root.
HILL_test_path = "data/HILL_test"

# Same preprocessing as the other evaluation sets: grayscale, then tensor conversion.
HILL_test_data = datasets.ImageFolder(
    HILL_test_path,
    transform=transforms.Compose([transforms.Grayscale(), transforms.ToTensor()]),
)
# Uses the already-imported DataLoader like the other loaders, and keeps a fixed
# sample order because this loader is only used for evaluation.
HILL_test_loader = DataLoader(HILL_test_data, batch_size=batch_size['test'], shuffle=False, **kwargs)

In [None]:
# Evaluate the same in-memory `model` on the HILL test loader.
test(model, device, HILL_test_loader)

In [None]:
# Validation accuracy per epoch, with the learning-rate milestones marked in red.
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs('plots', exist_ok=True)

fig, ax = plt.subplots()
ax.plot(range(len(valid_acc_list)), valid_acc_list)
for epc in DECAY_EPOCH:
    ax.axvline(x=epc, color='red', linestyle='--', linewidth=1)
ax.set(title=f'{train_title}: validation accuracy', xlabel='epoch', ylabel='accuracy')
fig.savefig(f'plots/{train_title}_valid_acc.png')
plt.show()
plt.close(fig)

In [None]:
# Validation loss per epoch, with the learning-rate milestones marked in red.
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs('plots', exist_ok=True)

fig, ax = plt.subplots()
ax.plot(range(len(valid_loss_list)), valid_loss_list)
for epc in DECAY_EPOCH:
    ax.axvline(x=epc, color='red', linestyle='--', linewidth=1)
ax.set(title=f'{train_title}: validation loss', xlabel='epoch', ylabel='loss')
fig.savefig(f'plots/{train_title}_valid_loss.png')
plt.show()
plt.close(fig)