# Flatness vs. Generalization (Part 2)

## 采用的模型

| Model                                                        | Parameter |
| ------------------------------------------------------------ | --------- |
| deep model：  `Conv(16, (3x3))->Pool(2)->Conv(32, (3x3))->Conv(64, (3x3))->Conv(64, (3x3))->Pool(2)->Conv(38, (3x3))->Conv(32, (3x3))->Conv(10, (1x1))->Pool(8)` |   93990   |

> 注：除最后一个卷积之外，上述每个卷积之后都加 BatchNormalization 和 ReLU 激活函数
>
> 训练采用的参数：`optimizer=Adam(lr=1e-3, amsgrad=True), epochs=100`

In [1]:
import sys
sys.path.insert(0, '..')
import os
import torch
from argparse import Namespace
from torch import optim
from model import cross_entropy_loss
from model import deep_cifar
from dataset import cifar_train_loader, cifar_validate_loader
from utils import eval_sensitivity, Logger
from solver import HW1Solver

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Experiment settings.
# Datasets are looked up in the `data` directory under the user's home folder.
root = os.path.join(os.path.expanduser('~'), 'data')
config = Namespace(
    batch_size=16,
    epochs=100,
    resume='',
    verbosity=1,
    use_cuda=True,
    lr=1e-3,
    save_dir='../pretrained/1_3_4',
    save_freq=100,
    save_grad=False,
    data_dir=root,
    dataset='cifar',
    valid=False,
    val_step=1,
    visdom=False,
    visdom_iter=True,
    visdom_fit=False,
)
# Batch sizes to compare: the powers of two from 2**4 (16) up to 2**10 (1024).
batch_list = [1 << exponent for exponent in range(4, 11)]

In [None]:
# Training stage: fit one model per batch size. This may take several hours.
# --- Tip: if you use the provided pre-trained models instead, skip this cell. ---
for batch_size in batch_list:
    # `config` is handed to the solver below, so record the current batch size on it.
    config.batch_size = batch_size
    train_loader = cifar_train_loader(root, batch_size)
    model_name = 'cifar_{}'.format(batch_size)
    model = deep_cifar(model_name)
    optimizer = optim.Adam(model.parameters(), lr=config.lr, amsgrad=True)
    solver = HW1Solver(
        model, optimizer, cross_entropy_loss, [], train_loader, None, config
    )
    solver.train()

In [None]:
# Compute the sensitivity of every trained model --- may take a few minutes.
#
# The original cell used `net`, `dataloader`, `log` and `save_dir` without defining
# them anywhere in the notebook (NameError). They are now built from the helpers
# imported at the top of the notebook.
# NOTE(review): assumes `cifar_validate_loader(root, batch_size)` mirrors
# `cifar_train_loader`, and that `Logger()` takes no required arguments — confirm
# against `dataset` / `utils`.
# NOTE(review): the evaluation batch size is an arbitrary fixed value so that every
# model is evaluated on the same loader — adjust if `eval_sensitivity` expects another.
eval_batch_size = 128
dataloader = cifar_validate_loader(root, eval_batch_size)
log = Logger()
for b in batch_list:
    info = {'batch': b}
    # Checkpoint written by the training cell for this batch size (final epoch).
    checkpoints = torch.load('{}/cifar_{}_epoch{}.pth.tar'.format(config.save_dir, b, config.epochs))
    logger = checkpoints['logger']
    # Loss stored in the last entry of the training log.
    train_loss = list(logger.entries.values())[-1]['loss']
    # Rebuild the network the same way the training cell does, then restore its weights.
    # NOTE(review): the model is not moved to a device here; presumably
    # `eval_sensitivity` handles that — confirm.
    net = deep_cifar('cifar_{}'.format(b))
    net.load_state_dict(checkpoints['state_dict'])
    loss, sensitivity = eval_sensitivity(net, dataloader)
    info.update({'train_loss': train_loss, 'val_loss': loss, 'sensitivity': sensitivity})
    print(info)
    log.add_entry(info)
# Store the collected results next to the checkpoints.
torch.save(log, '{}/sensitivity.pth.tar'.format(config.save_dir))

In [None]:
# 可视化结果
