In [None]:
from tensorboardX import SummaryWriter
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import itertools
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor
from skimage.transform import resize
from torchvision import transforms
import torch.optim as optim
from torch import LongTensor, FloatTensor
from scipy.stats import skewnorm, genpareto
from torchvision.utils import save_image
import sys

class NWSDataset(Dataset):
    """
    NWS Dataset: a mix of real and generated images, both held on the GPU.

    The first ``val`` real images are concatenated with the fake images minus
    their last ``val`` entries. When ``val <= min(len(real), len(fake))`` the
    dataset therefore has exactly ``len(fake)`` items, with the first ``val``
    of them swapped for real ones.

    Each item is ``(img, target)``. ``target`` is the ``(1, 1)`` tensor
    ``20 * img.sum() / 4096 + 19``.

    Args:
        fake: path of the ``torch.save``-d tensor of generated images.
        val: number of real images to mix in; must be non-negative.
        real: path of the ``torch.save``-d tensor of real images.

    Raises:
        ValueError: if ``val`` is negative.
    """

    # Divisor for the pixel sum; 4096 == 64 * 64, presumably the image
    # resolution. TODO confirm.
    _PIXELS = 4096

    def __init__(
        self, fake='ENHDecrLambdLin/fake10.pt', val=128, real='data/real.pt'
    ):
        self.real = torch.load(real).cuda()
        self.fake = torch.load(fake).cuda()
        self.realdata = self._mix(self.real, self.fake, val)

    @staticmethod
    def _mix(real, fake, val):
        """Return ``real[:val]`` followed by ``fake`` without its last ``val`` rows."""
        if val < 0:
            raise ValueError(f"val must be non-negative, got {val}")
        # Use an explicit stop index: ``fake[:-val]`` is ``fake[:0]`` (empty) when
        # val == 0. Clamp at 0 so that val > len(fake) still keeps no fake rows.
        keep = max(fake.shape[0] - val, 0)
        return torch.cat([real[:val], fake[:keep]], 0)

    def __len__(self):
        return self.realdata.shape[0]

    def __getitem__(self, item):
        img = self.realdata[item]
        # Same affine pixel-sum mapping as the one Discriminator.forward uses.
        return img, 20 * (img.sum().view(-1, 1) / self._PIXELS) + 19


def weights_init_normal(m):
    """Re-draw the weights of convolution-like modules from N(0, 0.02).

    Meant to be passed to ``nn.Module.apply``. A module is treated as a
    convolution when its class name contains "Conv", which covers both
    ``Conv2d`` and ``ConvTranspose2d``. All other modules are left untouched,
    and conv biases keep their existing initialisation.
    """
    if "Conv" not in type(m).__name__:
        return
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

def convTBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Upsampling unit: ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.2).

    Despite the name, the normalisation is instance norm (not batch norm), and
    the activation is an in-place LeakyReLU with slope 0.2. With the default
    kernel/stride/padding (4, 2, 1) the spatial size doubles.
    """
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    layers = [upsample, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2, True)]
    return nn.Sequential(*layers)


def convBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Downsampling unit: Conv2d -> InstanceNorm2d -> LeakyReLU(0.2).

    Despite the name, the normalisation is instance norm (not batch norm), and
    the activation is an in-place LeakyReLU with slope 0.2. With the default
    kernel/stride/padding (4, 2, 1) the spatial size halves.
    """
    downsample = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    layers = [downsample, nn.InstanceNorm2d(out_channels), nn.LeakyReLU(0.2, True)]
    return nn.Sequential(*layers)


class Generator(nn.Module):
    """Transposed-conv generator conditioned on one continuous code channel.

    ``forward`` concatenates the latent tensor and the code along dim 1, which
    is why ``block1`` takes ``in_channels + 1`` channels. It then runs the five
    upsampling stages and squashes the result with ``tanh``. For a 1x1 spatial
    input the sizes go 1 -> 4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Attribute names and creation order are kept: they fix the state_dict
        # keys used by saved checkpoints and the order of RNG draws at init.
        self.block1 = convTBNReLU(in_channels + 1, 512, 4, 1, 0)
        self.block2 = convTBNReLU(512, 256)
        self.block3 = convTBNReLU(256, 128)
        self.block4 = convTBNReLU(128, 64)
        self.block5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 1)

    def forward(self, latent, continuous_code):
        x = torch.cat((latent, continuous_code), 1)
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            x = stage(x)
        return torch.tanh(x)


class Discriminator(nn.Module):
    """Conv discriminator that also sees how far the given target is from the image's.

    The conv stack turns the image into 64 features. ``self.source`` is
    ``Linear(64 + 1, 1)``, so ``block5`` must emit a 1x1 map, which happens for
    64x64 inputs. The extra feature is the difference between ``extreme`` and
    ``20 * sum(pixels) / 4096 + 19`` computed from the input image. The output
    is a sigmoid probability per sample, with shape ``(batch, 1)``.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.block1 = convBNReLU(self.in_channels, 64)
        self.block2 = convBNReLU(64, 128)
        self.block3 = convBNReLU(128, 256)
        self.block4 = convBNReLU(256, 512)
        self.block5 = nn.Conv2d(512, 64, 4, 1, 0)
        self.source = nn.Linear(64 + 1, 1)

    def forward(self, inp, extreme):
        feats = inp
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            feats = stage(feats)
        n = feats.shape[0]
        feats = feats.view(n, -1)
        # Same pixel-sum mapping used for the dataset targets.
        pixel_extreme = 20 * (inp.sum(dim=(1, 2, 3)) / 4096) + 19
        gap = extreme.view(n, 1) - pixel_extreme.view(n, 1)
        logits = self.source(torch.cat([feats, gap], 1))
        return torch.sigmoid(logits)

# Number of noise channels fed to the generator (latent tensors elsewhere in
# this notebook are shaped (N, latentdim, 1, 1)).
latentdim = 20
# Loss objects. Only their construction is visible here; presumably a training
# loop elsewhere uses them. TODO confirm.
criterionSource = nn.BCELoss()  # BCE on probabilities, matching D's sigmoid output
criterionContinuous = nn.L1Loss()
criterionValG = nn.L1Loss()
criterionValD = nn.L1Loss()
# Both networks are moved to the GPU. Keep this statement order: module
# construction and the N(0, 0.02) conv re-initialisation both draw from
# torch's global RNG, so reordering changes the initial weights.
G = Generator(in_channels=latentdim, out_channels=1).cuda()
D = Discriminator(in_channels=1).cuda()
G.apply(weights_init_normal)
D.apply(weights_init_normal)
# Generalized Pareto parameters in scipy's positional order (c, loc, scale).
# Presumably obtained from an offline fit. TODO document the source.
genpareto_params = (-0.07241964091641762, 0.0007877211367414993, 0.25543821779344433)
# Offset added to every GPD quantile to form the continuous code
# (used by sample_genpareto and by the inference loop).
threshold = 0.08699417
rv = genpareto(*genpareto_params)

def sample_genpareto(size):
    """Inverse-CDF sample from the frozen GPD ``rv``, shifted by ``threshold``.

    ``size`` is unpacked into the tensor constructor, e.g. ``(batch, 1, 1, 1)``.
    Uniform draws come from torch's CPU RNG, are mapped through ``rv.ppf`` by
    SciPy, and are returned as a CPU ``FloatTensor`` of the same shape.
    """
    uniforms = FloatTensor(*size)
    uniforms.uniform_(0, 1)
    gpd_draws = rv.ppf(uniforms)
    return FloatTensor(gpd_draws) + threshold

def sample_cont_code(batch_size):
    """Return ``batch_size`` continuous codes as a CUDA tensor of shape (batch_size, 1, 1, 1).

    Each code is an independent threshold-shifted GPD sample from
    ``sample_genpareto``. The ``Variable`` wrapper is a legacy no-op kept for
    compatibility.
    """
    codes = sample_genpareto((batch_size, 1, 1, 1)).cuda()
    return Variable(codes)

# Adam with beta1 = 0.5 for both networks; D's learning rate is half of G's.
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = optim.Adam(D.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Fixed batch of 81 continuous codes, sampled once from the GPD. Presumably it
# is reused to render comparable 9x9 sample grids. TODO confirm.
static_code = sample_cont_code(81)

In [None]:
# Restore the generator weights saved after epoch 999, then put G in eval mode
# (training=False).
G.load_state_dict(state_dict=torch.load('infoPD/Gepoch999.pt'))
G.eval()

In [None]:
# For each probability level: build one continuous code from the GPD, generate
# 100 images with fresh latent noise, print the elapsed time, and save the
# images rescaled from tanh range [-1, 1] to [0, 1].
# NOTE(review): `i` maps `prob` to a GPD quantile level as
# (prob - 0.95) / 0.75**10. Because 0.75**10 ~= 0.0563 but 1 - 0.95 = 0.05,
# prob -> 1 gives i ~= 0.888 instead of 1. Check that the numerator and the
# divisor refer to the same threshold exceedance rate. TODO confirm.
with torch.no_grad():  # inference only: don't build an autograd graph
    for prob in [0.95, 0.99, 0.999, 0.9999]:
        t = time.perf_counter()
        i = (prob - 0.95)/(0.75**10)
        val = rv.ppf(i) + threshold
        # Every sample in the batch is conditioned on the same code value.
        code = (torch.ones(100, 1, 1, 1)*val).cuda()
        # Drawn on the CPU RNG and then moved, as before, so a given seed still
        # produces the same latents.
        latent = torch.randn((100, latentdim, 1, 1)).cuda()
        images = G(latent, code)
        # CUDA kernels run asynchronously. Wait for them so the printed time
        # covers the forward pass rather than just the kernel launches.
        torch.cuda.synchronize()
        print(time.perf_counter() - t)
        torch.save(0.5*(images+1), 'Images/infoPD'+str(prob)+'.pt')