# In [None]:
import os
import random
import time
import warnings
import sys
import torch
import torchvision
import logging
import matplotlib.pyplot as plt
import numpy as np
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
from collections import defaultdict

from network.Discriminator import DiscriminatorNet
from network.HidingUNet import UnetGenerator
from data.load_data import get_loader
from utils.AverageMeter import AverageMeter
from utils.SSIM import SSIM
from utils.helper import *

# NOTE: an earlier variant derived each watermark label from its cover label via
# SpecifiedLabel(label) = (label + 1) % 10; this script instead draws random
# watermark labels further below (`wm_labels`).
GPU = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU
cfg = load_config()
run_folder = create_folder_acsac(cfg.save_path)

# Log to both <run_folder>/run.log and stdout, then echo the configuration.
logging.basicConfig(level=logging.INFO, format='%(message)s',
                    handlers=[logging.FileHandler(os.path.join(run_folder, 'run.log')),
                              logging.StreamHandler(sys.stdout)])
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES：{}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# Seed every RNG the script draws from (NumPy also generates the watermark
# labels), regardless of whether CUDA is present.
if cfg.seed is not None:
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    # Safe to call without CUDA; it is then silently ignored.
    torch.cuda.manual_seed_all(cfg.seed)

if torch.cuda.is_available():
    # cuDNN autotuning (benchmark mode) may select different kernels on each
    # run, which defeats seeding. Enable it only for unseeded runs. Deterministic
    # mode can slow training noticeably and does not make checkpoint restarts
    # bit-exact.
    cudnn.benchmark = cfg.seed is None
    cudnn.deterministic = cfg.seed is not None

# One loader per split. 'original_work' feeds `cover_loader`, whose batches are
# cached below as the watermark cover images. Loaders are built in the order
# train, val, test, original_work.
train_loader = get_loader(cfg, 'train')
val_loader = get_loader(cfg, 'val')
test_loader = get_loader(cfg, 'test')
cover_loader = get_loader(cfg, 'original_work')
logging.info("train_loader:{} val_loader:{} test_loader:{}\n".format(
    *(len(split.dataset) for split in (train_loader, val_loader, test_loader))))

# Secret logo: each image is resized to 256x256 and normalised with the
# ImageNet mean/std before being embedded by the hiding network.
transform_test = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

mini_logo = torchvision.datasets.ImageFolder(root='./data/IEEE', transform=transform_test)
mini_loader = torch.utils.data.DataLoader(mini_logo, batch_size=1)

# Every logo image is visited and the last one wins. It is broadcast along the
# batch dimension to cfg.wm_batchsize copies and moved to the GPU.
for logo, _ in mini_loader:
    batch_shape = (cfg.wm_batchsize,) + tuple(logo.shape[1:])
    secret_img = logo.expand(*batch_shape).cuda()

# Random watermark labels in [0, 100): one row per cached watermark batch.
# NOTE(review): 100 is hard-coded; presumably the classifier's class count.
wm_batches = int(cfg.wm_num[0] / cfg.wm_batchsize)
np_labels = np.random.randint(100, size=(wm_batches, cfg.wm_batchsize))
wm_labels = torch.from_numpy(np_labels).cuda()
logging.info("wm_labels:{}".format(wm_labels))
# Watermark cover pool: the first `num_wm_batches` batches of cover_loader are
# cached on the GPU and reused every epoch. The original comment describes
# them as 1% of the dataset. `wm_idx` stays bound to the last loaded index
# after the loop, because later code may read it.
wm_inputs, wm_cover_labels = [], []
if cfg.wm_train:
    num_wm_batches = int(cfg.wm_num[0] / cfg.wm_batchsize)
    if num_wm_batches < 1:
        # Otherwise the break below never fires and the whole loader is cached.
        raise ValueError('cfg.wm_num[0] ({}) must be >= cfg.wm_batchsize ({})'.format(
            cfg.wm_num[0], cfg.wm_batchsize))
    for wm_idx, (wm_input, wm_cover_label) in enumerate(cover_loader):
        wm_input, wm_cover_label = wm_input.cuda(), wm_cover_label.cuda()
        wm_inputs.append(wm_input)
        wm_cover_labels.append(wm_cover_label)
        if wm_idx == num_wm_batches - 1:
            break
    if len(wm_inputs) < num_wm_batches:
        logging.warning('cover_loader yielded only {} of {} requested watermark batches'.format(
            len(wm_inputs), num_wm_batches))

# Adversarial ground truths for the discriminator's BCE loss:
# 1 = cover image ("valid"), 0 = watermarked image ("fake").
# The float32 factories replace the deprecated torch.cuda.FloatTensor constructor.
valid = torch.ones(cfg.wm_batchsize, 1, dtype=torch.float32, device='cuda')
fake = torch.zeros(cfg.wm_batchsize, 1, dtype=torch.float32, device='cuda')

best_real_acc, best_wm_acc, best_wm_input_acc = 0, 0, 0
start_epoch = 0

# Networks, each moved to the GPU and wrapped in DataParallel:
#   Hidnet - generator mapping (cover batch, secret_img) -> watermarked batch,
#   Disnet - discriminator scoring cover vs. watermarked images,
#   Dnnet  - classifier being trained to carry the watermark.
# They are constructed in the order Hidnet, Disnet, Dnnet.
Hidnet = nn.DataParallel(UnetGenerator().cuda())
Disnet = nn.DataParallel(DiscriminatorNet().cuda())
Dnnet = nn.DataParallel(resnet18().cuda())

_gan_betas = (0.5, 0.999)  # Adam betas shared by the two adversarial players
_plateau_cfg = dict(mode='min', factor=0.2, verbose=True)

# Hiding network: MSE + SSIM reconstruction terms; Adam; plateau decay (patience 5).
criterionH_mse, criterionH_ssim = nn.MSELoss(), SSIM()
optimizerH = optim.Adam(Hidnet.parameters(), lr=cfg.lr[0], betas=_gan_betas)
schedulerH = ReduceLROnPlateau(optimizerH, patience=5, **_plateau_cfg)

# Discriminator: binary cross-entropy; Adam; plateau decay (patience 8).
criterionD = nn.BCELoss()
optimizerD = optim.Adam(Disnet.parameters(), lr=cfg.lr[0], betas=_gan_betas)
schedulerD = ReduceLROnPlateau(optimizerD, patience=8, **_plateau_cfg)

# Classifier: cross-entropy; SGD with momentum and weight decay; LR x0.1 at
# scheduler steps 40 and 80.
criterionN = nn.CrossEntropyLoss()
optimizerN = optim.SGD(Dnnet.parameters(), lr=cfg.lr[1], momentum=0.9, weight_decay=5e-4)
schedulerN = MultiStepLR(optimizerN, milestones=[40, 80], gamma=0.1)


def train(epoch):
    """Run one training epoch over ``train_loader``.

    Each step performs three updates in order:
      1. Disnet learns to score cover images as ``valid`` (1) and Hidnet's
         watermarked images as ``fake`` (0).
      2. Hidnet is updated on a weighted sum of MSE, (1 - SSIM), adversarial
         and watermark-classification losses (weights
         ``cfg.hyper_parameters[0..3]``).
      3. Dnnet is updated on the clean batch concatenated with the detached
         watermarked batch.

    Args:
        epoch: zero-based epoch index, used for logging and the CSV row.

    Returns:
        (losses_dict, metrics_dict): per-epoch loss averages (AverageMeter) and
        epoch accuracies in percent (``real_acc``, ``wm_acc``, ``dis_acc``).
        Both are also appended to ``<run_folder>/train.csv``.

    Raises:
        RuntimeError: if no watermark cover batches were loaded or
            ``train_loader`` yields no batches.
    """
    epoch_start_time = time.time()
    logging.info('\nEpoch: %d' % epoch)
    Dnnet.train()
    Hidnet.train()
    Disnet.train()

    if not wm_inputs:
        raise RuntimeError('train(): no watermark cover batches loaded (is cfg.wm_train enabled?)')
    num_wm = len(wm_inputs)
    # The cover-loading loop leaves its index at num_wm - 1. Keep the same
    # rotation (offset + batch_idx) % num_wm without relying on that leaked global.
    wm_offset = num_wm - 1
    weights = cfg.hyper_parameters  # [mse, 1 - ssim, adversarial, watermark CE]

    real_correct = real_total = 0
    wm_correct = wm_total = 0
    real_dis_correct = fake_dis_correct = dis_total = 0
    loss_H_ = AverageMeter()
    loss_D_ = AverageMeter()
    loss_mse_ = AverageMeter()
    loss_ssim_ = AverageMeter()
    loss_DNN_ = AverageMeter()
    loss_real_ = AverageMeter()

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict(AverageMeter)

    num_steps = len(train_loader)
    for batch_idx, (images, label) in enumerate(train_loader):
        step = batch_idx + 1
        images, label = images.cuda(), label.cuda()
        slot = (wm_offset + batch_idx) % num_wm
        wm_input = wm_inputs[slot]
        wm_label = wm_labels[slot]
        if batch_idx == 0:
            logging.info("wm_label:{}".format(wm_label))
        # Use the actual batch sizes: the last loader batch may be smaller than
        # cfg.batchsize, and hard-coded slicing would then mix real and watermark rows.
        n_real = images.size(0)
        n_wm = wm_input.size(0)

        ############# Discriminator #############
        optimizerD.zero_grad()
        wm_img = Hidnet(wm_input, secret_img)
        wm_dis_output = Disnet(wm_img.detach())
        real_dis_output = Disnet(wm_input)
        # A correct decision is "> 0.5" for covers and "< 0.5" for watermarked images.
        real_dis_correct += real_dis_output.gt(0.5).sum().item()
        fake_dis_correct += wm_dis_output.lt(0.5).sum().item()
        dis_total += 2 * n_wm

        loss_D = criterionD(wm_dis_output, fake) + criterionD(real_dis_output, valid)
        loss_D.backward()
        optimizerD.step()

        ############# Hiding network #############
        optimizerH.zero_grad()
        optimizerD.zero_grad()
        optimizerN.zero_grad()
        wm_dis_output = Disnet(wm_img)
        wm_dnn_output = Dnnet(wm_img)
        loss_mse = criterionH_mse(wm_input, wm_img)
        loss_ssim = criterionH_ssim(wm_input, wm_img)
        loss_adv = criterionD(wm_dis_output, valid)
        loss_dnn = criterionN(wm_dnn_output, wm_label)
        loss_H = (weights[0] * loss_mse + weights[1] * (1 - loss_ssim)
                  + weights[2] * loss_adv + weights[3] * loss_dnn)
        loss_H.backward()
        optimizerH.step()

        ############# Classifier (Dnnet) #############
        optimizerN.zero_grad()
        inputs = torch.cat([images, wm_img.detach()], dim=0)
        labels = torch.cat([label, wm_label], dim=0)
        dnn_output = Dnnet(inputs)
        loss_DNN = criterionN(dnn_output, labels)
        loss_DNN.backward()
        optimizerN.step()

        # Accuracy bookkeeping on the (pre-step) classifier outputs. No graph is
        # needed, and no extra forward pass is made, so BatchNorm running
        # statistics are not touched by metric computation.
        with torch.no_grad():
            real_logits = dnn_output[:n_real]
            wm_logits = dnn_output[n_real:n_real + n_wm]
            _, real_predicted = real_logits.max(1)
            _, wm_predicted = wm_logits.max(1)
            real_correct += real_predicted.eq(label).sum().item()
            wm_correct += wm_predicted.eq(wm_label).sum().item()
            loss_real = criterionN(real_logits, label)
        real_total += n_real
        wm_total += n_wm

        loss_H_.update(loss_H.item(), n_wm)
        loss_D_.update(loss_D.item(), n_wm)
        loss_mse_.update(loss_mse.item(), n_wm)
        loss_ssim_.update(loss_ssim.item(), n_wm)
        loss_DNN_.update(loss_DNN.item(), n_real + n_wm)
        loss_real_.update(loss_real.item(), n_real)

        if step % cfg.print_freq == 0 or step == num_steps:
            logging.info('[%d/%d][%d/%d]  Loss_D: %.4f Loss_H: %.4f (mse: %.4f ssim: %.4f adv: %.4f)  Loss_DNN: %.4f loss_real: %.4f Real acc: %.3f  wm acc: %.3f dis_acc: %.3f' % (
                epoch, cfg.num_epochs, step, num_steps,
                loss_D.item(), loss_H.item(), loss_mse.item(), loss_ssim.item(), loss_adv.item(), loss_DNN.item(), loss_real.item(),
                100. * real_correct / real_total, 100. * wm_correct / wm_total,
                100. * (real_dis_correct + fake_dis_correct) / dis_total))

    if real_total == 0:
        raise RuntimeError('train(): train_loader yielded no batches')

    losses_dict['loss_D'] = loss_D_.avg
    losses_dict['loss_H'] = loss_H_.avg
    losses_dict['loss_mse'] = loss_mse_.avg
    losses_dict['ssim'] = loss_ssim_.avg
    losses_dict['loss_DNN'] = loss_DNN_.avg
    losses_dict['loss_real'] = loss_real_.avg

    # Epoch accuracies (percent) over every sample seen this epoch, rather than
    # an average of running accuracies (which over-weights early batches).
    metrics_dict['real_acc'] = 100. * real_correct / real_total
    metrics_dict['wm_acc'] = 100. * wm_correct / wm_total
    metrics_dict['dis_acc'] = 100. * (real_dis_correct + fake_dis_correct) / dis_total

    total_duration = time.time() - epoch_start_time
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))
    logging.info('-' * 130)

    write_scalars_acsac(epoch, os.path.join(run_folder, 'train.csv'), losses_dict, metrics_dict, total_duration)

    return losses_dict, metrics_dict


def test(epoch):
    """Evaluate Hidnet, Disnet and Dnnet on ``val_loader`` and checkpoint on improvement.

    Computes the same losses as ``train`` but makes no parameter updates. All
    forwards run under ``torch.no_grad()`` with the networks in eval mode.
    Additionally averages ``cal_psnr`` between cover and watermarked batches.

    Side effects:
      * calls ``save_all_models_acsac`` whenever the clean validation accuracy
        is >= the global ``best_real_acc`` (ties re-save);
      * updates the globals ``best_real_acc`` and ``best_wm_acc``;
      * appends a row to ``<run_folder>/val.csv`` and saves sample images of
        the last watermark batch via ``save_cat_image_acsac``.

    Args:
        epoch: zero-based epoch index.

    Returns:
        (losses_dict, metrics_dict): per-epoch loss averages and epoch
        accuracies in percent, plus the average ``ste_psnr``.

    Raises:
        RuntimeError: if no watermark cover batches were loaded or
            ``val_loader`` yields no batches.
    """
    global best_real_acc
    global best_wm_acc
    epoch_start_time = time.time()
    Dnnet.eval()
    Hidnet.eval()
    Disnet.eval()

    if not wm_inputs:
        raise RuntimeError('test(): no watermark cover batches loaded (is cfg.wm_train enabled?)')
    num_wm = len(wm_inputs)
    # Same rotation as the original (wm_idx + batch_idx) % num_wm, where the
    # cover-loading loop left wm_idx at num_wm - 1.
    wm_offset = num_wm - 1
    weights = cfg.hyper_parameters  # [mse, 1 - ssim, adversarial, watermark CE]

    real_correct = real_total = 0
    wm_correct = wm_total = 0
    real_dis_correct = fake_dis_correct = dis_total = 0
    loss_H_ = AverageMeter()
    loss_D_ = AverageMeter()
    loss_mse_ = AverageMeter()
    loss_ssim_ = AverageMeter()
    loss_adv_ = AverageMeter()
    loss_DNN_ = AverageMeter()
    loss_real_ = AverageMeter()
    ste_psnr_ = AverageMeter()

    # Save relevant metrics in dictionaries.
    losses_dict = defaultdict(AverageMeter)
    metrics_dict = defaultdict(AverageMeter)

    with torch.no_grad():
        for batch_idx, (images, label) in enumerate(val_loader):
            images, label = images.cuda(), label.cuda()
            slot = (wm_offset + batch_idx) % num_wm
            wm_input = wm_inputs[slot]
            wm_label = wm_labels[slot]
            if batch_idx == 0:
                logging.info("wm_label:{}".format(wm_label))
            # Actual batch sizes: validation loaders commonly end with a short batch.
            n_real = images.size(0)
            n_wm = wm_input.size(0)

            ############# Discriminator #############
            wm_img = Hidnet(wm_input, secret_img)
            wm_dis_output = Disnet(wm_img)
            real_dis_output = Disnet(wm_input)
            # A correct decision is "> 0.5" for covers and "< 0.5" for watermarked images.
            real_dis_correct += real_dis_output.gt(0.5).sum().item()
            fake_dis_correct += wm_dis_output.lt(0.5).sum().item()
            dis_total += 2 * n_wm
            loss_D = criterionD(wm_dis_output, fake) + criterionD(real_dis_output, valid)

            ############# Hiding network #############
            wm_dnn_outputs = Dnnet(wm_img)
            loss_mse = criterionH_mse(wm_input, wm_img)
            loss_ssim = criterionH_ssim(wm_input, wm_img)
            loss_adv = criterionD(wm_dis_output, valid)
            loss_dnn = criterionN(wm_dnn_outputs, wm_label)
            loss_H = (weights[0] * loss_mse + weights[1] * (1 - loss_ssim)
                      + weights[2] * loss_adv + weights[3] * loss_dnn)

            ############# Classifier (Dnnet) #############
            inputs = torch.cat([images, wm_img], dim=0)
            labels = torch.cat([label, wm_label], dim=0)
            dnn_outputs = Dnnet(inputs)
            loss_DNN = criterionN(dnn_outputs, labels)

            real_logits = dnn_outputs[:n_real]
            wm_logits = dnn_outputs[n_real:n_real + n_wm]
            _, real_predicted = real_logits.max(1)
            _, wm_predicted = wm_logits.max(1)
            real_correct += real_predicted.eq(label).sum().item()
            wm_correct += wm_predicted.eq(wm_label).sum().item()
            real_total += n_real
            wm_total += n_wm

            loss_real = criterionN(real_logits, label)
            ste_psnr = cal_psnr(wm_input, wm_img)

            loss_D_.update(loss_D.item(), n_wm)
            loss_H_.update(loss_H.item(), n_wm)
            loss_mse_.update(loss_mse.item(), n_wm)
            loss_ssim_.update(loss_ssim.item(), n_wm)
            loss_adv_.update(loss_adv.item(), n_wm)
            loss_DNN_.update(loss_DNN.item(), n_real + n_wm)
            loss_real_.update(loss_real.item(), n_real)
            ste_psnr_.update(ste_psnr)

    if real_total == 0:
        raise RuntimeError('test(): val_loader yielded no batches')

    real_acc = 100. * real_correct / real_total
    wm_acc = 100. * wm_correct / wm_total
    dis_acc = 100. * (real_dis_correct + fake_dis_correct) / dis_total

    # Epoch summary: averages over the whole validation pass.
    logging.info('Loss_D: %.4f Loss_H: %.4f (mse: %.4f ssim: %.4f adv: %.4f) Loss_DNN: %.4f  loss_real: %.4f Real acc: %.3f  wm acc: %.3f dis_acc: %.3f ste_psnr: %.3f' % (
        loss_D_.avg, loss_H_.avg, loss_mse_.avg, loss_ssim_.avg, loss_adv_.avg, loss_DNN_.avg, loss_real_.avg,
        real_acc, wm_acc, dis_acc, ste_psnr_.avg))

    losses_dict['loss_D'] = loss_D_.avg
    losses_dict['loss_H'] = loss_H_.avg
    losses_dict['loss_mse'] = loss_mse_.avg
    losses_dict['ssim'] = loss_ssim_.avg
    losses_dict['loss_DNN'] = loss_DNN_.avg
    losses_dict['loss_real'] = loss_real_.avg

    metrics_dict['real_acc'] = real_acc
    metrics_dict['wm_acc'] = wm_acc
    metrics_dict['dis_acc'] = dis_acc
    metrics_dict['ste_psnr'] = ste_psnr_.avg

    if real_acc >= best_real_acc:  # and (wm_acc >= best_wm_acc):
        save_all_models_acsac(epoch, run_folder, Hidnet, Disnet, Dnnet, optimizerH, optimizerD, optimizerN,
            losses_dict, metrics_dict, wm_labels)
        best_real_acc = real_acc
        logging.info("best_real_acc:{}".format(best_real_acc))

    if wm_acc > best_wm_acc:
        best_wm_acc = wm_acc
        logging.info("best_wm_acc:{}".format(best_wm_acc))

    total_duration = time.time() - epoch_start_time
    write_scalars_acsac(epoch, os.path.join(run_folder, 'val.csv'), losses_dict, metrics_dict, total_duration)
    save_cat_image_acsac(cfg, epoch, run_folder, wm_input, wm_img, secret_img, None, 'val')
    logging.info('Epoch {} total duration: {:.2f} sec'.format(epoch, total_duration))

    return losses_dict, metrics_dict


# Epoch driver: train, validate, plot, then advance the learning-rate schedules.
for epoch in range(cfg.num_epochs):
    train_losses_dict, train_metrics_dict = train(epoch)
    val_losses_dict, val_metrics_dict = test(epoch)
    plot_scalars(
        epoch, run_folder,
        train_losses_dict, train_metrics_dict,
        val_losses_dict, val_metrics_dict,
        None,
    )
    # The plateau schedulers watch their matching validation loss; the
    # classifier's MultiStepLR advances once per epoch.
    for plateau_scheduler, monitored_key in ((schedulerH, 'loss_H'), (schedulerD, 'loss_D')):
        plateau_scheduler.step(val_losses_dict[monitored_key])
    schedulerN.step()

