# 训练模型测试

### 初始化

In [None]:
import os
from torchvision import transforms 
import torch.nn as nn
import torch
from tqdm import tqdm_notebook as tqdm
from Dataset import YTBDatasetVer,YTBDatasetCNN
from Network import NANNet,CNNNet
import numpy as np
from util import evaluate

# Expose only physical GPU 3 to this process. The variable only takes effect
# if it is set before CUDA is initialised in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Target device for the model and every batch moved in the cells below.
device = torch.device("cuda")

### ROC曲线绘制函数

In [None]:
def plot_roc(fpr, tpr, figure_name="roc.png", out_dir="./"):
    """Plot a ROC curve and save it to an image file.

    Args:
        fpr: False-positive rates (x-axis), e.g. as returned by ``evaluate``.
        tpr: True-positive rates (y-axis), same length as ``fpr``.
        figure_name: File name of the saved image.
        out_dir: Directory the image is written to (defaults to the current
            directory, matching the previous behaviour).

    The area under the curve (``sklearn.metrics.auc``) is shown in the legend.
    The figure is closed after saving. Because this is called once per epoch,
    leaving figures open would otherwise accumulate hundreds of matplotlib
    figures in memory.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import auc

    roc_auc = auc(fpr, tpr)
    lw = 2
    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='darkorange',
                lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        # Chance-level diagonal for reference.
        ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Receiver operating characteristic')
        ax.legend(loc="lower right")
        fig.savefig(os.path.join(out_dir, figure_name), dpi=fig.dpi)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

### 初始化数据集

In [None]:
# Verification-pair dataset. No `train=` flag is passed here, whereas the
# test-set cell passes train=False, so this is presumably the training split.
dataset = YTBDatasetVer(
    csv_file='../splits.txt',
    root_dir='../aligned_images_DB',
    img_size=224,
    num_frames=100,
)
# Shuffled mini-batches of 8 pairs, loaded by 2 worker processes.
dataload = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
)

### 初始化模型

In [None]:
# Build the NAN model around the CNN weights given by cnn_path, move it to
# the GPU and put it into training mode.
model = (
    NANNet(cnn_path='./checkpoints/cnn_modelacc0.9982.pth')
    .to(device)
    .train()
)

### 查看可以更新的参数

In [None]:
# Print the name of every parameter that still requires gradients, i.e. the
# parameters that receive updates during training.
for param_name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    print(param_name)

### 读取存储好的NAN模型权值

In [None]:
# Restore NAN weights written by the training loop below; the file name
# records the training accuracy that checkpoint reached.
# (An older checkpoint, nan_model_fix_CNN.pth, was loaded here previously.)
model.load_state_dict(
    torch.load("./checkpoints/nan_model_acc0.9551.pth")
)

# 训练

In [None]:
# Train the NAN model for 300 epochs. After every epoch:
#   - verification accuracy is computed from the pair distances collected
#     during that epoch (i.e. on training pairs);
#   - the latest weights go to nan_model.pth;
#   - a copy goes under ./checkpoints/ whenever accuracy beats the best seen
#     in this run.
acc_max = 0
# Previously tried: RMSprop(lr=1e-3, weight_decay=1e-5), Adadelta(lr=0.05),
# Adagrad(lr=5e-4).
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, weight_decay=1e-6)
for epoch in range(300):
    # The test-set cell puts `model` into eval mode. Re-enable training-mode
    # behaviour for any dropout/batch-norm layers at the start of every epoch.
    model.train()
    total_loss = 0
    labels, distances = [], []
    bar = tqdm(dataload)
    for i, (x1, x2, y) in enumerate(bar):
        optimizer.zero_grad()
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)
        # process() returns the pair distances (fed to evaluate() below) and
        # the loss to optimise.
        l2, loss = model.process(x1, x2, y)
        loss.backward()
        optimizer.step()

        distances.append(l2.detach().cpu().numpy())
        labels.append(y.cpu().numpy())
        total_loss += loss.item()
        bar.set_postfix(loss=f"{total_loss/(i+1):0.4f}",
                        epoch=f"{epoch+1}")
    labels = np.concatenate(labels)
    distances = np.concatenate(distances)

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    acc = np.mean(accuracy)
    print('\33[91mTrain set: Accuracy: {:.8f}\n\33[0m'.format(acc))
    plot_roc(fpr, tpr, figure_name="roc_train_epoch_{}.png".format(epoch))

    torch.save(model.state_dict(), "nan_model.pth")
    if acc > acc_max:
        acc_max = acc
        torch.save(model.state_dict(), f"./checkpoints/nan_model_acc{acc_max:0.4f}.pth")

### 测试集

In [None]:
# Evaluate a trained NAN checkpoint on the held-out split (train=False).
dataset = YTBDatasetVer(csv_file='../splits.txt', root_dir='../aligned_images_DB', img_size=224, num_frames=100, train=False)
dataload = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=2)
# NOTE(review): training used ./checkpoints/cnn_modelacc0.9982.pth as cnn_path.
# If NANNet's state_dict includes the CNN, the load below overrides these
# weights anyway — confirm.
model = NANNet(cnn_path='./cnn_model6.pth').to(device)
model.load_state_dict(torch.load("./checkpoints/nan_model_acc0.9551.pth"))
model = model.eval()

# A single pass is enough: the weights never change during evaluation.
# This cell previously looped 300 times and overwrote nan_model.pth and the
# training ROC plots. It also wrote checkpoints named after *test* accuracy.
# None of that happens any more.
# If the dataset samples frames randomly, re-run the cell to gauge variance
# rather than keeping the best test score.
total_loss = 0
labels, distances = [], []
bar = tqdm(dataload)
with torch.no_grad():  # no backward pass, so don't build autograd graphs
    for i, (x1, x2, y) in enumerate(bar):
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)
        l2, loss = model.process(x1, x2, y)
        distances.append(l2.detach().cpu().numpy())
        labels.append(y.cpu().numpy())
        total_loss += loss.item()
        bar.set_postfix(loss=f"{total_loss/(i+1):0.4f}")
labels = np.concatenate(labels)
distances = np.concatenate(distances)

tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
acc = np.mean(accuracy)
print('\33[91mTest set: Accuracy: {:.8f}\n\33[0m'.format(acc))
plot_roc(fpr, tpr, figure_name="roc_test.png")

### 多GPU训练

In [None]:
# Multi-GPU training with nn.DataParallel.
#
# NOTE: CUDA_VISIBLE_DEVICES is only read when CUDA is initialised. The first
# cell already pinned this process to GPU '3', and models have since been
# moved to the GPU, so this assignment has no effect in a running kernel.
# To use GPUs 2 and 3, put '2,3' in the first cell and restart the kernel.
# The value must use an ASCII comma; the previous full-width '，' is not a
# valid separator.
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

# Running the cells in order leaves `dataload` pointing at the test split
# (train=False, batch_size=1) from the evaluation cell above, so rebuild the
# training loader explicitly.
train_dataset = YTBDatasetVer(csv_file='../splits.txt', root_dir='../aligned_images_DB', img_size=224, num_frames=100)
dataload = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=8, num_workers=2)

# Keep a handle on the bare NANNet. Unwrapping first means re-running this
# cell never nests DataParallel wrappers.
net = model.module if isinstance(model, nn.DataParallel) else model
# device_ids are integer indices into the *visible* devices; use all of them.
model = nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))

acc_max = 0
# Previously tried: RMSprop(lr=1e-3, weight_decay=1e-5), Adadelta(lr=0.05),
# Adagrad(lr=5e-4).
optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-4, weight_decay=1e-6)
for epoch in range(300):
    # The evaluation cell leaves the model in eval mode.
    model.train()
    total_loss = 0
    labels, distances = [], []
    bar = tqdm(dataload)
    for i, (x1, x2, y) in enumerate(bar):
        optimizer.zero_grad()
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)
        # NOTE(review): DataParallel does not expose arbitrary methods of the
        # wrapped module, and it only parallelises forward(), so
        # `model.process(...)` raised AttributeError. net.process runs on a
        # single GPU. To really split batches across GPUs, the pair logic must
        # be reachable through NANNet.forward and called as model(...).
        l2, loss = net.process(x1, x2, y)
        loss.backward()
        optimizer.step()

        distances.append(l2.detach().cpu().numpy())
        labels.append(y.cpu().numpy())
        total_loss += loss.item()
        bar.set_postfix(loss=f"{total_loss/(i+1):0.4f}",
                        epoch=f"{epoch+1}")
    labels = np.concatenate(labels)
    distances = np.concatenate(distances)

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    acc = np.mean(accuracy)
    print('\33[91mTrain set: Accuracy: {:.8f}\n\33[0m'.format(acc))
    plot_roc(fpr, tpr, figure_name="roc_train_epoch_{}.png".format(epoch))

    # Save the unwrapped module's weights (no "module." key prefix), so the
    # checkpoints load into a plain NANNet like the other cells expect.
    torch.save(net.state_dict(), "nan_model.pth")
    if acc > acc_max:
        acc_max = acc
        torch.save(net.state_dict(), f"./checkpoints/nan_model_acc{acc_max:0.4f}.pth")
# The final weights are the last epoch's, already saved to nan_model.pth above.
# The previous trailing save referenced an undefined `path`.

### 测试人脸验证

# 更换CNN模型（DenseNet-161）后的训练

In [None]:

### 初始化

import os
from torchvision import transforms 
import torch.nn as nn
import torch
from tqdm import tqdm_notebook as tqdm
from Dataset import YTBDatasetVer,YTBDatasetCNN
from Network import NANNet,CNNNet
import numpy as np
from util import evaluate

# Expose only physical GPU 3 to this process. The variable only takes effect
# if it is set before CUDA is initialised in the process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Target device for the model and every batch moved below.
device = torch.device("cuda")

### ROC曲线绘制函数

def plot_roc(fpr, tpr, figure_name="roc.png", out_dir="./"):
    """Plot a ROC curve and save it to an image file.

    Args:
        fpr: False-positive rates (x-axis), e.g. as returned by ``evaluate``.
        tpr: True-positive rates (y-axis), same length as ``fpr``.
        figure_name: File name of the saved image.
        out_dir: Directory the image is written to (defaults to the current
            directory, matching the previous behaviour).

    The area under the curve (``sklearn.metrics.auc``) is shown in the legend.
    The figure is closed after saving. Because this is called once per epoch,
    leaving figures open would otherwise accumulate hundreds of matplotlib
    figures in memory.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import auc

    roc_auc = auc(fpr, tpr)
    lw = 2
    fig, ax = plt.subplots()
    try:
        ax.plot(fpr, tpr, color='darkorange',
                lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
        # Chance-level diagonal for reference.
        ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Receiver operating characteristic')
        ax.legend(loc="lower right")
        fig.savefig(os.path.join(out_dir, figure_name), dpi=fig.dpi)
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)

### 初始化数据集

# Verification-pair dataset. No `train=` flag is passed, so this is presumably
# the training split.
dataset = YTBDatasetVer(
    csv_file='../splits.txt',
    root_dir='../aligned_images_DB',
    img_size=224,
    num_frames=100,
)
# Shuffled mini-batches of 4 pairs. num_workers=0 loads data in the main process.
dataload = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)


### Swap the NAN model's CNN backbone for DenseNet-161 (densenet) weights
import torchvision.models as models
model = NANNet().to(device)
# DenseNet-161 whose final linear layer outputs 128 values (presumably the
# embedding size NANNet expects — confirm). Its weights come from the
# checkpoint loaded next.
model.CNN.cnn_model = models.densenet161(num_classes=128).to(device)
model.CNN.load_state_dict(torch.load('./checkpoints/densenet_cnn_acc0.9983.pth'))
# Freeze the backbone: only parameters outside model.CNN are trained.
for param in model.CNN.parameters():
    param.requires_grad = False
model = model.train()
# requires_grad=False does not stop BatchNorm layers from using batch
# statistics and updating their running mean/var in train mode. Keep the
# frozen backbone in eval mode so it behaves exactly as in its checkpoint.
model.CNN.eval()

### 查看可以更新的参数

# Print the name of every parameter that still requires gradients. After
# freezing the CNN, this should list only the layers outside model.CNN.
for param_name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    print(param_name)

### 读取存储好的NAN模型权值

# Resume from NAN weights saved by the DenseNet training loop below; the file
# name records the training accuracy that checkpoint reached.
model.load_state_dict(
    torch.load("./checkpoints/desnet_nan_acc0.6200.pth")
)

# 训练

# Train the NAN layers on top of the frozen DenseNet backbone. After every
# epoch:
#   - report verification accuracy on the epoch's training pairs;
#   - save the latest weights to desnet_nan.pth;
#   - keep a checkpoint whenever accuracy improves.
acc_max = 0
# NOTE(review): lr=1e-1 is 1000x the RMSprop rate used for the original NAN
# model (1e-4). That is likely to destabilise training — confirm it is
# intentional.
# Previously tried: RMSprop(lr=1e-3, weight_decay=1e-5), Adadelta(lr=0.05),
# Adagrad(lr=5e-4).
# Only trainable parameters are handed to the optimizer; the backbone was
# frozen when it was loaded.
optimizer = torch.optim.RMSprop(
    [p for p in model.parameters() if p.requires_grad], lr=1e-1, weight_decay=1e-6)
for epoch in range(300):
    # Train the NAN layers but keep the frozen backbone in eval mode.
    # requires_grad=False alone does not stop BatchNorm layers from updating
    # their running statistics in train mode.
    model.train()
    model.CNN.eval()
    total_loss = 0
    labels, distances = [], []
    bar = tqdm(dataload)
    for i, (x1, x2, y) in enumerate(bar):
        optimizer.zero_grad()
        x1, x2, y = x1.to(device), x2.to(device), y.to(device)
        # process() returns the pair distances (fed to evaluate() below) and
        # the loss to optimise.
        l2, loss = model.process(x1, x2, y)
        loss.backward()
        optimizer.step()

        distances.append(l2.detach().cpu().numpy())
        labels.append(y.cpu().numpy())
        total_loss += loss.item()
        bar.set_postfix(loss=f"{total_loss/(i+1):0.4f}",
                        epoch=f"{epoch+1}")
    labels = np.concatenate(labels)
    distances = np.concatenate(distances)

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    acc = np.mean(accuracy)
    print('\33[91mTrain set: Accuracy: {:.8f}\n\33[0m'.format(acc))
    plot_roc(fpr, tpr, figure_name="roc_train_epoch_{}.png".format(epoch))

    torch.save(model.state_dict(), "desnet_nan.pth")
    if acc > acc_max:
        acc_max = acc
        torch.save(model.state_dict(), f"./checkpoints/desnet_nan_acc{acc_max:0.4f}.pth")