# 训练模型测试

In [1]:
import os
from torchvision import transforms 
import torch.nn as nn
import torch
from tqdm import tqdm_notebook as tqdm
from Dataset import YTBDatasetVer,YTBDatasetCNN
from Network import NANNet,CNNNet
import numpy as np
from util import evaluate

In [2]:
# Select which physical GPU this notebook may use. The variable only takes effect
# if it is set before CUDA is first initialised in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Run on the visible GPU when there is one; otherwise fall back to CPU so the
# notebook does not fail at the first `.to(device)` on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### ROC曲线绘制函数

In [3]:
def plot_roc(fpr, tpr, figure_name="roc.png", save_dir="./", show=False):
    """Plot a ROC curve annotated with its AUC and save it as an image file.

    Args:
        fpr: false-positive rates (x axis), e.g. as returned by ``evaluate``.
        tpr: true-positive rates (y axis), same length as ``fpr``.
        figure_name: file name of the saved image.
        save_dir: directory the image is written to (default: current directory).
        show: if True, display the figure (``plt.show()``) before it is closed.

    Returns:
        float: the area under the curve, as computed by ``sklearn.metrics.auc``.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import auc

    roc_auc = auc(fpr, tpr)
    lw = 2
    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='darkorange',
                lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        # Diagonal reference line: the ROC of a random classifier.
        ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Receiver operating characteristic')
        ax.legend(loc="lower right")
        fig.savefig(os.path.join(save_dir, figure_name), dpi=fig.dpi)
        if show:
            plt.show()
    finally:
        # This is called once per epoch inside a long loop: always close the
        # figure so pyplot does not keep hundreds of figures alive.
        plt.close(fig)
    return roc_auc

### 初始化数据集

In [4]:
# Pair dataset read from ../splits.txt. Each batch is later unpacked as
# (x1, x2, y) by the test loop.
# NOTE(review): num_frames presumably caps the frames loaded per sample; confirm in Dataset.py.
dataset = YTBDatasetVer(
    csv_file='../splits.txt',
    root_dir='../aligned_images_DB',
    img_size=224,
    num_frames=100,
)
dataload = torch.utils.data.DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=2,
)

### 初始化模型（NANNet，加载预训练CNN权重）

In [5]:
# Build the NAN model from the pretrained CNN checkpoint, move it to `device`
# and put it in training mode. Both `.to()` and `.train()` return the module
# itself, so chaining them is the same as reassigning `model` step by step.
model = NANNet(
    cnn_path='./checkpoints/cnn_modelacc0.9982.pth',
    ave_pool=True,
).to(device).train()

Load Pretrained weight successfully


### 查看可以更新的参数

In [6]:
# Print the names of the parameters that will receive gradients;
# parameters with requires_grad=False are skipped.
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(name)

attention.q
attention.fc.weight
attention.fc.bias


### 读取存储好的NAN模型权值

In [7]:
# Optional: uncomment to resume from a previously saved NAN checkpoint instead of
# the weights the model was constructed with.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# model.load_state_dict(torch.load("nan_model_bat.pth"))

# 测试：在数据集上评估模型（仅前向计算，不更新权重）

In [8]:
# Evaluation only: run every pair through the model, collect the distance it
# returns and score the distances with `evaluate`. No optimizer step is taken,
# so the learnable weights do not change between epochs.
NUM_EPOCHS = 300  # each epoch is one full pass over `dataload`
acc_max = 0

# The model was put in train() mode when it was created. Switch to eval() so
# any BatchNorm/Dropout layers (presumably inside the pretrained CNN; TODO
# confirm) use stored statistics and no dropout. Otherwise repeated passes would
# keep updating running stats and add noise to the distances.
model.eval()
for epoch in range(NUM_EPOCHS):
    bar = tqdm(dataload)
    labels, distances = [], []
    # Nothing is back-propagated, so skip building the autograd graph.
    with torch.no_grad():
        for x1, x2, y in bar:
            x1, x2, y = x1.to(device), x2.to(device), y.to(device)
            # Only the third output (l2) is used: it is scored as the pair distance.
            _, _, l2 = model(x1, x2)
            distances.append(l2.detach().cpu().numpy())
            labels.append(y.cpu().numpy())
            bar.set_postfix(epoch=f"{epoch+1}")
    labels = np.concatenate(labels)
    distances = np.concatenate(distances)

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    acc = np.mean(accuracy)
    acc_max = max(acc_max, acc)
    print('\33[91mTrain set: Accuracy: {:.8f}\n\33[0m'.format(acc))
    plot_roc(fpr, tpr, figure_name="roc_train_epoch_{}.png".format(epoch))

    torch.save(model.state_dict(), "nan_model.pth")

HBox(children=(IntProgress(value=0, max=450), HTML(value='')))


[91mTrain set: Accuracy: 0.65911111
[0m


HBox(children=(IntProgress(value=0, max=450), HTML(value='')))

Process Process-4:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 106, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/lulu/Dataset/face_rec/Dataset.py", line 178, in __getitem__
    data_face1 = self.load_face_from_dir(face_dir)
  File "/home/lulu/Dataset/face_rec/Dataset.py", line 138, in load_face_from_dir
    img = self.rgb_transform(img)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/usr/local/lib/python3.6/di

KeyboardInterrupt: 

### 测试人脸验证