# In [None]:
import os
import re
import sys
import time
import random
import logging
import argparse
import warnings
from collections import defaultdict

import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from scipy import misc
from torch.utils.data import Dataset, DataLoader

from load_data.pair_loader import *
from algorithm.SRNet import Srnet
from utils.helper import *
from utils.AverageMeter import AverageMeter

# Load the experiment configuration and create the folder this run writes into.
cfg = load_config()
device = torch.device(cfg.device)
run_folder = create_folder(cfg.run_folder)

# Send log records both to <run_folder>/run.log and to stdout, then record the
# configuration so every run log is self-describing.
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[
    logging.FileHandler(os.path.join(run_folder, 'run.log')), logging.StreamHandler(sys.stdout)])
logging.info("Experiment Configuration:")
logging.info("CUDA_VISIBLE_DEVICES: {}".format(os.getenv('CUDA_VISIBLE_DEVICES')))
logging.info(cfg)
logging.info("run_folder:{}".format(run_folder))

# NOTE: this silences *every* Python warning for the rest of the process,
# including ones raised explicitly with warnings.warn(); important notices
# should be emitted through `logging` instead.
warnings.filterwarnings("ignore")

# Reproducibility and cuDNN configuration.
# Seed the Python, NumPy and torch RNGs whenever a seed is configured, not only
# when CUDA is present: a CPU-only run also shuffles data and initialises
# weights from these generators.
if cfg.seed is not None:
    np.random.seed(cfg.seed)  # NumPy global RNG.
    random.seed(cfg.seed)  # Python `random` module.
    torch.manual_seed(cfg.seed)  # torch RNG (seeds all devices).

if torch.cuda.is_available():
    logging.info('torch.cuda is available!')
    if cfg.seed is not None:
        torch.cuda.manual_seed_all(cfg.seed)  # Explicitly seed every visible GPU.
        # cuDNN autotuning (benchmark=True) selects kernels by timing them, which
        # is itself non-deterministic, so it must be off for reproducible runs.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Reported via logging: a warnings.warn() here would be hidden by any
        # earlier warnings.filterwarnings("ignore") call.
        logging.warning('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')
    else:
        # No reproducibility requested: let cuDNN autotune kernels for speed.
        torch.backends.cudnn.benchmark = True


def adjust_learning_rate(epoch, opt=None, base_lr=None, gamma=0.1, step_size=80):
    """Apply a step-decay learning-rate schedule and return the new rate.

    The learning rate is ``base_lr * gamma ** (epoch // step_size)``; with the
    defaults it starts at ``cfg.lr`` and is divided by 10 every 80 epochs.

    Args:
        epoch: current epoch number.
        opt: optimizer whose param groups are updated; defaults to the
            module-level ``optimizer``.
        base_lr: initial learning rate; defaults to ``cfg.lr``.
        gamma: multiplicative decay factor applied at every step.
        step_size: number of epochs between two decays.

    Returns:
        The learning rate written into the optimizer.
    """
    if opt is None:
        opt = optimizer
    if base_lr is None:
        base_lr = cfg.lr
    lr = base_lr * (gamma ** (epoch // step_size))
    # Update every param group, not just the first, so optimizers built with
    # several groups are decayed consistently.
    for group in opt.param_groups:
        group['lr'] = lr
    return lr


# Weight initialization for conv layers and fc layers
def inti_SRNet_weights(m):
    """Initialise the parameters of a single module; intended for ``Module.apply``.

    * ``nn.Conv2d``: Xavier-uniform weights, bias (if present) set to 0.2.
    * ``nn.Linear``: weights drawn from N(0, 0.01**2), bias (if present) set to 0.
    * Any other module is left untouched.

    Args:
        m: the module currently being visited.
    """
    if isinstance(m, nn.Conv2d):
        # xavier_uniform (no trailing underscore) is deprecated; use the in-place API.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.2)
    elif isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0., std=0.01)
        # nn.Linear(..., bias=False) has bias None; constant_ would fail on it.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.)


# Build the analyzer network, move it to the target device and initialise its
# conv/linear parameters. nn.Module.to() and nn.Module.apply() both return the
# module itself, so the calls can be chained.
ana_model = Srnet().to(device).apply(inti_SRNet_weights)

# Loss and optimizer.
criterion_ce = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(
    ana_model.parameters(),
    lr=cfg.lr,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)


def _pair_transform():
    """Return a new pipeline: resize to ``cfg.resize`` x ``cfg.resize``, grayscale, tensor."""
    # NOTE(review): `dataset` is not defined in the visible code; presumably it
    # comes from one of the star imports — confirm.
    return transforms.Compose([
        transforms.Resize((cfg.resize, cfg.resize)),
        transforms.Grayscale(),
        dataset.ToTensor(),
    ])


# Paired cover/stego datasets; each one gets its own transform instance.
train_data = Dataset_Load(cfg.cover_path, cfg.stego_path, cfg.train_size,
                          transform=_pair_transform())
val_data = Dataset_Load(cfg.valid_cover_path, cfg.valid_stego_path, cfg.val_size,
                        transform=_pair_transform())

# Only the training split is shuffled; validation order stays fixed.
train_loader = DataLoader(train_data, batch_size=cfg.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=cfg.batch_size, shuffle=False)


def _stack_pair(batch):
    """Merge the cover and stego halves of a pair batch into one batch on ``device``.

    Images are concatenated cover-first, then stego. ``batch['label'][0]`` and
    ``batch['label'][1]`` are concatenated in that same order (presumably the
    cover labels followed by the stego labels — confirm against Dataset_Load).

    Returns:
        ``(images, labels)`` as float and long tensors respectively.
    """
    images = torch.cat((batch['cover'], batch['stego']), 0).to(device, dtype=torch.float)
    labels = torch.cat((batch['label'][0], batch['label'][1]), 0).to(device, dtype=torch.long)
    return images, labels


def _batch_accuracy(outputs, labels):
    """Return the percentage (0-dim tensor) of rows whose argmax equals ``labels``."""
    prediction = outputs.argmax(1)
    return prediction.eq(labels).sum() * 100.0 / labels.size(0)


def _run_epoch(loader, epoch, training):
    """Run one full pass of ``ana_model`` over ``loader``.

    Args:
        loader: DataLoader yielding pair batches (see ``_stack_pair``).
        epoch: current epoch number, used only for progress logging.
        training: when True the model is put in train mode and the optimizer is
            stepped after every batch; when False the model is put in eval mode
            and autograd is disabled.

    Returns:
        ``(losses, accuracies)``: ``defaultdict(AverageMeter)`` objects holding
        batch-size-weighted running averages under ``'ana_loss'`` and
        ``'accuracy'`` respectively.
    """
    losses = defaultdict(AverageMeter)
    accuracies = defaultdict(AverageMeter)
    total = len(loader)
    ana_model.train(training)  # train(False) is equivalent to eval()
    with torch.set_grad_enabled(training):
        for i, batch in enumerate(loader):
            images, labels = _stack_pair(batch)
            outputs = ana_model(images)
            ana_loss = criterion_ce(outputs, labels)

            if training:
                optimizer.zero_grad()
                ana_loss.backward()
                optimizer.step()

            # detach(): accuracy is a metric only and must not extend the graph.
            accuracy = _batch_accuracy(outputs.detach(), labels)
            batch_size = images.size(0)
            losses['ana_loss'].update(ana_loss.item(), batch_size)
            accuracies['accuracy'].update(accuracy.item(), batch_size)

            step = i + 1
            if step % cfg.print_freq == 0 or step == total:
                if training:
                    logging.info('\r Epoch:[%d/%d] Batch:[%d/%d] Loss:[%.4f] Acc:[%.2f] lr:[%.4f]'
                                 % (epoch, cfg.num_epochs, step, total, losses['ana_loss'].avg,
                                    accuracies['accuracy'].avg, optimizer.param_groups[0]['lr']))
                else:
                    logging.info('\r Epoch:[%d/%d] Batch:[%d/%d] Loss:[%.4f] Acc:[%.2f]'
                                 % (epoch, cfg.num_epochs, step, total, losses['ana_loss'].avg,
                                    accuracies['accuracy'].avg))
    return losses, accuracies


def main():
    """Train and validate the analyzer for epochs ``cfg.start_epoch``..``cfg.num_epochs``.

    Optionally resumes from a checkpoint when ``cfg.has_ckp`` is True. Then, for
    every epoch, it:

    * applies the step learning-rate schedule;
    * trains on ``train_loader``;
    * evaluates on ``val_loader``;
    * passes the per-epoch metric meters to ``plot_scalars`` and ``save_checkpoint``.
    """
    if cfg.has_ckp is True:
        # map_location lets a checkpoint written on one device be resumed on `device`.
        ana_ckp = torch.load("", map_location=device)  # Fill in the path of the model checkpoint.
        ana_model.load_state_dict(ana_ckp[''])  # model weights
        optimizer.load_state_dict(ana_ckp[''])  # parameters of the optimizer
        logging.info("Have loaded analyzer checkpoint.")

    for epoch in range(cfg.start_epoch, cfg.num_epochs + 1):
        st_time = time.time()
        adjust_learning_rate(epoch)
        train_losses_dict, train_acc_dict = _run_epoch(train_loader, epoch, training=True)
        end_time = time.time()
        logging.info("Epoch: {} Training time: {:.2f}".format(epoch, end_time - st_time))

        logging.info("========================validation========================")
        val_losses_dict, val_acc_dict = _run_epoch(val_loader, epoch, training=False)
        logging.info("========================validation========================")

        plot_scalars(epoch, run_folder, train_losses_dict, train_acc_dict, val_losses_dict, val_acc_dict)
        save_checkpoint(epoch, run_folder, ana_model, optimizer, val_losses_dict, val_acc_dict)


# Entry point: start training only when this file is executed directly, not
# when it is imported as a module.
if __name__ == '__main__':
    main()

