# preset value

In [None]:
# Shell escape (notebook only): delete everything matching logs* left by earlier runs.
!rm -rf logs*

In [None]:
# Debugging aid: choose how verbose tracebacks are.
# %xmode Verbose
%xmode Plain

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every expression statement in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Import Lib

In [None]:
from __future__ import print_function
%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as vutils
from torchvision import datasets, transforms, models
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tensorboardX import SummaryWriter
import logging
import imageio
import shutil
from tqdm.notebook import tqdm
from IPython.display import HTML
import math
import time

# Setting parameters

In [None]:
# Make sure the log directory exists; exist_ok avoids the check-then-create race.
os.makedirs('./logs/', exist_ok=True)
# Let INFO-level records through the root logger.
logging.getLogger().setLevel(logging.INFO)
# File handler that the named loggers created further down attach to.
fhandler = logging.FileHandler('./logs/BackwardForDCgan.log')
# TensorBoard event writer.
writer = SummaryWriter('./logs/TensorboardLog')

In [None]:
# --- run configuration -------------------------------------------------------
workers = 4
max_epochs = 100  # total number of training epochs
batch_size = 1
image_size = 64
seed = 999

# A checkpoint is written every `save_model_interval` epochs.
save_model_interval = 5

# Let torch decide whether a GPU is usable; a GPU environment is recommended
# because it is much faster.
DEVICE = (torch.device("cuda")
          if torch.cuda.is_available() else torch.device("cpu"))
cudnn.benchmark = True

# Utils

In [None]:
def plt_image(imgs):
    """Show a batch of images as a single grid on a new 8x8-inch figure.

    ``imgs`` is passed to ``torchvision.utils.make_grid`` (padding 2,
    normalised); the resulting grid is moved to the CPU and its axes are
    reordered with ``(1, 2, 0)`` before being handed to ``plt.imshow``.
    """
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.title('Training Images')

    grid = vutils.make_grid(imgs, padding=2, normalize=True)
    # Channel axis goes last, which is the layout imshow expects.
    channels_last = np.transpose(grid.cpu(), (1, 2, 0))
    plt.imshow(channels_last)


def print_cuda_statistics():
    """Log Python, PyTorch, cuDNN and CUDA device information.

    Records are emitted on the "Cuda Statistics" logger, which gets the
    module-level ``fhandler`` attached.  The queries that need a usable CUDA
    device are skipped when ``torch.cuda.is_available()`` is false, so the
    function can also be called on CPU-only hosts.
    """
    import sys

    logger = logging.getLogger("Cuda Statistics")
    logger.addHandler(fhandler)

    logger.info('__Python VERSION:  {}'.format(sys.version))
    logger.info('__pyTorch VERSION:  {}'.format(torch.__version__))
    logger.info('__CUDA VERSION')
    logger.info('__CUDNN VERSION:  {}'.format(torch.backends.cudnn.version()))
    logger.info('__Number CUDA Devices:  {}'.format(torch.cuda.device_count()))
    logger.info('__Devices')

    if not torch.cuda.is_available():
        # current_device() raises without a CUDA device; report and stop here.
        logger.info('CUDA is not available; skipping active-device statistics')
        return

    logger.info('Active CUDA Device: GPU {}'.format(
        torch.cuda.current_device()))
    logger.info('Available devices  {}'.format(torch.cuda.device_count()))
    logger.info('Current cuda device  {}'.format(torch.cuda.current_device()))


def weights_init(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02).  Otherwise, modules whose class name contains ``BatchNorm``
    get weights drawn from N(1, 0.02) and a zero bias.  Every other module
    is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class AverageMeter:
    """Keep a running weighted mean of the values passed to ``update``.

    Attributes:
        value: the most recent value passed to ``update``.
        sum:   weighted sum of all values seen since the last reset.
        count: total weight seen since the last reset.
        avg:   ``sum / count`` as of the last ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.value = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.value = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    @property
    def val(self):
        """The running mean (same as ``avg``), not the last raw value."""
        return self.avg

In [None]:
def compare_num_tenToTwo(a, b, num_len):
    """Return the number of bit positions in which ``a`` and ``b`` differ.

    The result is the Hamming distance between the binary representations
    of the two magnitudes (the sign is ignored, as it was when the digits
    were taken from ``bin()``).

    Args:
        a: first integer (Python or NumPy integer).
        b: second integer.
        num_len: nominal bit width.  Kept for interface compatibility: left
            padding both numbers with zeros never changes the number of
            differing bits, so the value is not needed for the computation.

    Returns:
        int: count of differing bits.

    Unlike a digit-by-digit comparison of left-aligned strings, the XOR
    stays correct (and cannot raise IndexError) when either value needs
    more than ``num_len`` bits.
    """
    differing_bits = abs(int(a)) ^ abs(int(b))
    return bin(differing_bits).count('1')


def compute_BER(input_noise, output_noise, sigma):
    """Bit error rate between two noise tensors after ``sigma``-bit quantisation.

    Each element ``x`` is mapped to the integer ``floor((x + 1) * 2**(sigma - 1))``
    and clamped to ``[0, 2**sigma - 1]`` so it always fits in ``sigma`` bits
    (without the clamp ``x == 1.0`` would yield ``2**sigma``, one bit too
    wide).  The rate is the number of differing bits divided by
    ``rows * columns * sigma``.

    Args:
        input_noise: tensor whose first two sizes give the row and column
            counts; it must be reshapeable to ``(rows, columns)`` (for
            example ``(N, C)`` or ``(N, C, 1, 1)``).
        output_noise: tensor with the same number of elements.
        sigma: number of bits per element.

    Returns:
        float: fraction of differing bits, or ``0.0`` when there are no bits
        to compare.
    """
    number_row = input_noise.size(0)
    number_line = input_noise.size(1)
    total_bits = number_row * number_line * sigma
    if total_bits == 0:
        return 0.0

    top_level = 2 ** sigma - 1

    def _to_levels(noise):
        # Quantise, keep within sigma bits, and hand back a (rows, cols) int array.
        levels = torch.floor((noise.detach() + 1) * 2 ** (sigma - 1))
        levels = levels.clamp(0, top_level)
        return levels.reshape(number_row,
                              number_line).cpu().numpy().astype(np.int64)

    differing = np.bitwise_xor(_to_levels(input_noise),
                               _to_levels(output_noise))
    # Count the set bits of the XOR one bit plane at a time.
    error_counter = sum(
        int(((differing >> bit) & 1).sum()) for bit in range(sigma))
    return error_counter / total_bits

# model

In [None]:
class Generator(nn.Module):
    '''
    DCGAN-style generator: five transposed convolutions, the first four each
    followed by batch norm and ReLU, the last one by tanh.

    Input: (N, 100, 1, 1)
    deconv1: (N, 512, 4, 4)       ==> H/16, W/16
    deconv2: (N, 256, 8, 8)       ==> H/8, W/8
    deconv3: (N, 128, 16, 16)     ==> H/4, W/4
    deconv4: (N, 64, 32, 32)      ==> H/2, W/2
    deconv5: (N, 3, 64, 64)       ==> H, W
    out: (N, 3, 64, 64)

    Attribute names and creation order are part of the contract: they
    determine the state_dict keys and the order in which weights are drawn.
    '''
    def __init__(self):
        super().__init__()

        def upsample(in_channels, out_channels, stride=2, padding=1):
            # All transposed convolutions share kernel size 4 and have no bias.
            return nn.ConvTranspose2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=4,
                                      stride=stride,
                                      padding=padding,
                                      bias=False)

        base = 64

        self.relu = nn.ReLU(inplace=True)

        self.deconv1 = upsample(100, base * 8, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(base * 8)

        self.deconv2 = upsample(base * 8, base * 4)
        self.batch_norm2 = nn.BatchNorm2d(base * 4)

        self.deconv3 = upsample(base * 4, base * 2)
        self.batch_norm3 = nn.BatchNorm2d(base * 2)

        self.deconv4 = upsample(base * 2, base)
        self.batch_norm4 = nn.BatchNorm2d(base)

        self.deconv5 = upsample(base, 3)

        self.tanh = nn.Tanh()

        self.apply(weights_init)

    def forward(self, x):
        """Map a latent tensor to an image tensor with values in [-1, 1]."""
        hidden_stages = (
            (self.deconv1, self.batch_norm1),
            (self.deconv2, self.batch_norm2),
            (self.deconv3, self.batch_norm3),
            (self.deconv4, self.batch_norm4),
        )
        out = x
        for deconv, batch_norm in hidden_stages:
            out = self.relu(batch_norm(deconv(out)))
        return self.tanh(self.deconv5(out))


def test_Generator():
    """Smoke test: push one latent batch through a fresh Generator and print shapes.

    The latent is drawn uniformly from [-1, 1), the same distribution the
    agent uses for its fixed noise, and the forward pass runs without
    building an autograd graph.
    """
    inp = torch.rand(batch_size, 100, 1, 1) * 2 - 1
    print(inp.shape)
    netG = Generator()
    with torch.no_grad():
        out = netG(inp)
    print(out.shape)


test_Generator()

In [None]:
class Extract(nn.Module):
    """Extractor network: image tensor -> 100 values in [-1, 1].

    Four stride-2 convolutions (LeakyReLU 0.2; batch norm on all but the
    first) are followed by a flatten, a linear layer with 100 outputs and
    tanh.  ``linear1`` expects 64 * 8 * 4 * 4 features, which is what a
    (N, 3, 64, 64) input produces after the four convolutions.

    ``conv5`` is registered but never called in ``forward``; it is kept
    because it contributes parameters (and therefore state_dict keys).
    """

    def __init__(self):
        super().__init__()

        def conv4x4(in_channels, out_channels, stride=2, padding=1):
            # Every convolution uses a 4x4 kernel and no bias.
            return nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=4,
                             stride=stride,
                             padding=padding,
                             bias=False)

        base = 64

        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.conv1 = conv4x4(3, base)

        self.conv2 = conv4x4(base, base * 2)
        self.batch_norm2 = nn.BatchNorm2d(base * 2)

        self.conv3 = conv4x4(base * 2, base * 4)
        self.batch_norm3 = nn.BatchNorm2d(base * 4)

        self.conv4 = conv4x4(base * 4, base * 8)
        self.batch_norm4 = nn.BatchNorm2d(base * 8)

        # Not used by forward(); see the class docstring.
        self.conv5 = conv4x4(base * 8, 1, stride=1, padding=0)

        self.linear1 = nn.Linear(in_features=base * 8 * 4 * 4,
                                 out_features=100,
                                 bias=True)

        self.tanh = nn.Tanh()

        self.apply(weights_init)

    def forward(self, x):
        """Return a (N, 100) tensor of tanh activations for the batch ``x``."""
        out = self.relu(self.conv1(x))
        out = self.relu(self.batch_norm2(self.conv2(out)))
        out = self.relu(self.batch_norm3(self.conv3(out)))
        out = self.relu(self.batch_norm4(self.conv4(out)))

        # Flatten everything except the batch dimension.
        flat = out.view(x.size(0), -1)
        return self.tanh(self.linear1(flat))


def test_Extract():
    """Smoke test: run one random image batch through a fresh Extract and print shapes."""
    images = torch.rand(batch_size, 3, 64, 64)
    print(images.shape)
    extractor = Extract()
    recovered = extractor(images)
    print(recovered.shape)


test_Extract()

# Metrics

In [None]:
class MSELoss(nn.Module):
    """Thin module wrapper that delegates to ``torch.nn.MSELoss``."""

    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()

    def forward(self, logits, labels):
        """Return the mean squared error between ``logits`` and ``labels``."""
        return self.loss(logits, labels)

# Save_Load_Model

In [None]:
class Save_Load_Model():
    """Save and restore training checkpoints under ``./checkpoint_dir/``.

    A checkpoint is a dict with the keys ``epoch``, ``iteration``,
    ``G_state_dict``, ``G_optimizer``, ``netE_state_dict``, ``fixed_noise``
    and ``manual_seed``; the values are read from / written to the agent's
    ``current_epoch``, ``current_iteration``, ``netG``, ``optimG``, ``netE``,
    ``fixed_noise`` and ``manualSeed`` attributes.
    """

    def __init__(self):
        self.checkpoint_dir = './checkpoint_dir/'
        # exist_ok avoids the check-then-create race.
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def load_checkpoint(self, agent, file_name):
        """Restore ``agent`` from ``file_name`` inside the checkpoint directory.

        If the file cannot be opened (``OSError``), this is logged and the
        agent is left untouched.  Tensors are mapped onto ``agent.device``
        when the agent has that attribute, so a checkpoint written on a GPU
        can be loaded on a CPU-only host.
        """
        path_filename = os.path.join(self.checkpoint_dir, file_name)
        agent.logger.info("Loading checkpoint '{}'".format(path_filename))
        try:
            checkpoint = torch.load(path_filename,
                                    map_location=getattr(
                                        agent, 'device', None))
        except OSError:
            agent.logger.info(
                "No checkpoint exists from '{}'. Skipping...".format(
                    self.checkpoint_dir))
            agent.logger.info("**First time to train**")
            return

        agent.current_epoch = checkpoint['epoch']
        agent.current_iteration = checkpoint['iteration']
        agent.netG.load_state_dict(checkpoint['G_state_dict'])
        agent.optimG.load_state_dict(checkpoint['G_optimizer'])
        agent.netE.load_state_dict(checkpoint['netE_state_dict'])
        agent.fixed_noise = checkpoint['fixed_noise']
        # Same attribute that save_checkpoint reads the seed from.
        agent.manualSeed = checkpoint['manual_seed']

        agent.logger.info(
            "Checkpoint loaded successfully from '{}' at (epoch {}) at (iteration {})\n"
            .format(self.checkpoint_dir, checkpoint['epoch'],
                    checkpoint['iteration']))

    def save_checkpoint(self,
                        agent,
                        file_name='backwardDCGAN.pth.tar',
                        is_best=0):
        """Write the agent's state to ``file_name`` in the checkpoint directory.

        When ``is_best`` is truthy the file is additionally copied to
        ``model_best.pth.tar`` in the same directory.
        """
        state = {
            'epoch': agent.current_epoch,
            'iteration': agent.current_iteration,
            'G_state_dict': agent.netG.state_dict(),
            'G_optimizer': agent.optimG.state_dict(),
            'netE_state_dict': agent.netE.state_dict(),
            'fixed_noise': agent.fixed_noise,
            'manual_seed': agent.manualSeed,
        }
        save_file_name = os.path.join(self.checkpoint_dir, file_name)
        torch.save(state, save_file_name)

        if is_best:
            shutil.copyfile(
                save_file_name,
                os.path.join(self.checkpoint_dir, 'model_best.pth.tar'))

# Agent

In [None]:
class BackwardForDCganAgent:
    """Tune the generator so the extractor can recover a fixed noise tensor.

    Each training step feeds ``fixed_noise`` through ``netG`` and ``netE`` and
    minimises ``0.9 * MSE(netE(netG(z)), z) + 0.1 * MSE(netG(z), ori_img)``,
    where ``ori_img`` is the generator output captured before training
    starts.  Only ``optimG`` is stepped; ``netE`` is kept in eval mode and
    ``optimE`` is created but never stepped in this class.

    Relies on module-level names: ``Generator``, ``Extract``, ``MSELoss``,
    ``Save_Load_Model``, ``compute_BER``, ``print_cuda_statistics``,
    ``AverageMeter``, ``fhandler``, ``writer``, ``batch_size``, ``seed``,
    ``max_epochs`` and ``save_model_interval``.
    """

    # Wall-clock budget in seconds (about 8.3 h); train() checkpoints and
    # stops once it is exceeded.  Presumably sized to a hosted-notebook
    # session limit - confirm before changing.
    MAX_TRAIN_SECONDS = 29880
    # Optimisation steps performed per epoch.
    STEPS_PER_EPOCH = 10

    def __init__(self):
        # define log
        self.logger = logging.getLogger('BackwardForDCganAgent')
        self.logger.addHandler(fhandler)
        # define models
        self.netG = Generator()
        self.netE = Extract()
        # define batch_size
        self.batch_size = batch_size
        # define loss
        self.loss = MSELoss()
        # define optimizers for the generator and the extractor
        self.optimG = torch.optim.Adam(self.netG.parameters(),
                                       lr=0.0002,
                                       betas=(0.5, 0.999))
        self.optimE = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                              self.netE.parameters()),
                                       lr=0.0002,
                                       betas=(0, 0.999))
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimG,
            'min',
            factor=0.5,
            patience=5,
            verbose=True,
            threshold=0.0001)
        # initialize counter
        self.current_epoch = 0
        self.current_iteration = 0
        # Seed the RNGs.  This happens after the networks are built, so
        # fixed_noise is the first draw from the freshly seeded generator;
        # moving the seeding would change fixed_noise.
        self.manualSeed = seed
        # manualSeed = random.randint(1, 10000) # use if you want new results
        print("Random Seed: ", self.manualSeed)
        random.seed(self.manualSeed)
        torch.manual_seed(self.manualSeed)
        self.fixed_noise = torch.rand(
            self.batch_size, 100, 1, 1, requires_grad=True) * 2 - 1
        # copy data to cuda or cpu
        self.is_cuda = torch.cuda.is_available()
        if self.is_cuda:
            self.logger.info('Program will run on *****GPU*****')
            torch.cuda.manual_seed_all(self.manualSeed)
            print_cuda_statistics()
            torch.cuda.set_device(0)
            # `async` became a reserved word in Python 3.7; the keyword
            # argument is called non_blocking.
            self.fixed_noise = self.fixed_noise.cuda(non_blocking=True)
            self.device = torch.device('cuda')
        else:
            self.logger.info('Program will run on *****CPU******')
            self.device = torch.device('cpu')
        # move the models and the loss to the selected device
        self.netG = self.netG.to(self.device)
        self.netE = self.netE.to(self.device)
        self.loss = self.loss.to(self.device)
        # checkpoint helper
        self.save_Load_Model = Save_Load_Model()
        # self.save_Load_Model.load_checkpoint(self, 'backwardDCGAN.pth.tar')

    def run(self):
        """Run training; a Ctrl+C is logged instead of being propagated."""
        try:
            self.train()
        except KeyboardInterrupt:
            self.logger.info('You have entered CTRL+C..。 Wait to finalize')

    def train(self):
        """Train from ``current_epoch`` up to ``max_epochs``.

        After every epoch the validation loss drives the LR scheduler.  A
        checkpoint is written every ``save_model_interval`` epochs, and once
        more when the wall-clock budget is exhausted (training then stops).
        """
        start_time = time.perf_counter()
        # Reference image produced by the generator before any update; it is
        # only ever used as a constant target, so no graph is needed.
        with torch.no_grad():
            ori_img = self.netG(self.fixed_noise)
        for epoch in range(self.current_epoch, max_epochs):
            self.current_epoch = epoch
            self.train_one_epoch(ori_img)
            loss_valid = self.valid(ori_img)
            self.scheduler.step(loss_valid.item())
            # save the model
            if self.current_epoch % save_model_interval == 0:
                self.save_Load_Model.save_checkpoint(self)
            # stop if time too long
            if (time.perf_counter() - start_time) > self.MAX_TRAIN_SECONDS:
                self.save_Load_Model.save_checkpoint(self)
                break

    @staticmethod
    def _bit_error_rates(noise, recovered):
        """Return the bit error rates for sigma = 1, 2 and 3 as a tuple."""
        return tuple(
            compute_BER(noise, recovered, sigma=bits) for bits in (1, 2, 3))

    def train_one_epoch(self, ori_img):
        """Run ``STEPS_PER_EPOCH`` generator updates and log the epoch statistics.

        Args:
            ori_img: reference generator output used as the image target.
        """
        steps = range(self.STEPS_PER_EPOCH)
        tqdm_batch = tqdm(steps,
                          total=len(steps),
                          desc="epoch-{}-".format(self.current_epoch))
        self.netG.train()
        self.netE.eval()
        epoch_loss_noise = AverageMeter()
        epoch_loss_img = AverageMeter()
        epoch_loss_G_all = AverageMeter()

        fake_noise = self.fixed_noise.detach().to(self.device,
                                                  non_blocking=True)
        # Flatten to (N, 100) explicitly; a bare squeeze() would also drop
        # the batch dimension when batch_size == 1.
        target_noise = fake_noise.reshape(fake_noise.size(0), -1)
        target_img = ori_img.detach()
        #################################
        # Update the generator (the extractor only provides the loss signal)
        for _ in tqdm_batch:
            self.netG.zero_grad()
            G_out = self.netG(fake_noise)
            netE_out = self.netE(G_out)
            loss_G_noise = self.loss(netE_out, target_noise)
            loss_G_img = self.loss(G_out, target_img)
            loss_G = 0.9 * loss_G_noise + 0.1 * loss_G_img
            loss_G.backward()
            self.optimG.step()
            # record loss
            epoch_loss_noise.update(loss_G_noise.item())
            epoch_loss_img.update(loss_G_img.item())
            epoch_loss_G_all.update(loss_G.item())
            # increase the iteration
            self.current_iteration += 1
        #################################
        tqdm_batch.close()
        # BER of the extractor output from the last step of the epoch
        BER_1, BER_2, BER_3 = self._bit_error_rates(fake_noise, netE_out)
        # log BER
        self.logger.info(
            "Training at epoch -%s|loss:%s|BER_1:%s| BER_2:%s| BER_3:%s",
            self.current_epoch, epoch_loss_G_all.val, BER_1, BER_2, BER_3)
        # tensorboard
        writer.add_scalar('Train/loss_G', epoch_loss_G_all.val,
                          self.current_epoch)
        writer.add_scalars(
            'Train/loss_G_imgANDnosie', {
                'epoch_loss_noise': epoch_loss_noise.val,
                'epoch_loss_img': epoch_loss_img.val,
            }, self.current_epoch)
        writer.add_scalars('Train/BERs', {
            'BER1': BER_1,
            'BER2': BER_2,
            'BER3': BER_3,
        }, self.current_epoch)
        writer.add_scalar('lr_netG', self.optimG.param_groups[0]['lr'],
                          self.current_epoch)
        for name, param in self.netG.named_parameters():
            writer.add_histogram(name + "_netG",
                                 param.detach().cpu().numpy(),
                                 self.current_epoch)

    def valid(self, ori_img):
        """Evaluate the current generator on ``fixed_noise`` and log the result.

        The networks are left in whatever train/eval mode the caller put
        them in (no ``eval()`` call is made here).  Nothing is
        back-propagated, so the forward pass runs under ``torch.no_grad()``.

        Returns:
            The scalar loss tensor ``0.9 * noise_loss + 0.1 * image_loss``.
        """
        fake_noise = self.fixed_noise.detach()
        #################################
        with torch.no_grad():
            G_out = self.netG(fake_noise)
            netE_out = self.netE(G_out)
            loss_G_noise = self.loss(
                netE_out, fake_noise.reshape(fake_noise.size(0), -1))
            loss_G_img = self.loss(G_out, ori_img.detach())
            loss_valid = 0.9 * loss_G_noise + 0.1 * loss_G_img
        #################################
        # compute BER
        BER_1, BER_2, BER_3 = self._bit_error_rates(fake_noise, netE_out)
        # log + tensorboard
        self.logger.info(
            "Valid at epoch -%s|loss:%s|BER_1:%s|BER_2:%s|BER_3:%s",
            self.current_epoch, loss_valid.item(), BER_1, BER_2, BER_3)
        writer.add_scalar('Valid/loss_E', loss_valid.item(),
                          self.current_epoch)
        writer.add_scalars('Valid/BERs', {
            'BER1': BER_1,
            'BER2': BER_2,
            'BER3': BER_3,
        }, self.current_epoch)
        return loss_valid

# main

In [None]:
# Build the agent: creates both networks, the optimizers and the fixed noise tensor.
agent = BackwardForDCganAgent()

In [None]:
# Load pre-trained generator / extractor weights onto the CPU first;
# load_state_dict then copies them into the agent's networks.
# Alternative checkpoint: '../input/checkpoint/checkpointE.pth.tar1500'
state = torch.load(
    '../input/checkpoint/checkpointE.pth.tar9000',
    map_location=torch.device('cpu'),
)
agent.netG.load_state_dict(state['G_state_dict'])
agent.netE.load_state_dict(state['netE_state_dict'])

In [None]:
# Sanity check of the loaded weights: generate from the agent's fixed noise,
# display the image and report the BER for 1-, 2- and 3-bit quantisation.
# Alternative input: fn = torch.rand(batch_size, 100, 1, 1,requires_grad=True).to(DEVICE)*2-1
fn = agent.fixed_noise
img = agent.netG(fn)
plt_image(img.detach().to(DEVICE)[:64])
netE_out = agent.netE(img)
for i in (1, 2, 3):
    ber = compute_BER(fn, netE_out, sigma=i)
    print('BER of {} is: {}'.format(i, ber))

In [None]:
# Start training; a Ctrl+C is caught inside run() and only logged.
agent.run()

# clean

In [None]:
# Shell escape (notebook only): pack the logs directory into logs.tar.
# NOTE: the z flag gzip-compresses the archive even though the name ends in .tar.
!tar zcf logs.tar logs