In [None]:
from tensorboardX import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import os
from torchvision.transforms import Compose, ToTensor
from skimage.transform import resize
import torch
import torch.nn as nn
from scipy.stats import genpareto
import torch.nn.functional as F
from torch.autograd import Variable
from torch import LongTensor, FloatTensor
from torchvision.utils import save_image
import sys
import time

def convTBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Upsampling block: ConvTranspose2d -> InstanceNorm2d -> in-place LeakyReLU(0.2).

    Despite the "BN" in the name, normalisation is per-instance
    (``nn.InstanceNorm2d``), not batch norm. With the defaults
    (kernel 4, stride 2, padding 1) the spatial size is doubled.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, stride, padding: forwarded to ``nn.ConvTranspose2d``.

    Returns:
        An ``nn.Sequential`` with children ``[conv_t, norm, act]`` at indices 0, 1, 2.
    """
    # Build in the same order as the layers run so parameter init consumes RNG identically.
    upsample = nn.ConvTranspose2d(
        in_channels, out_channels,
        kernel_size=kernel_size, stride=stride, padding=padding,
    )
    norm = nn.InstanceNorm2d(out_channels)
    act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(upsample, norm, act)


def convBNReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Downsampling block: Conv2d -> InstanceNorm2d -> in-place LeakyReLU(0.2).

    Despite the "BN" in the name, normalisation is per-instance
    (``nn.InstanceNorm2d``), not batch norm. With the defaults
    (kernel 4, stride 2, padding 1) an even spatial size is halved.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` with children ``[conv, norm, act]`` at indices 0, 1, 2.
    """
    # Same construction order as the forward pass, so weight init consumes RNG identically.
    conv = nn.Conv2d(
        in_channels, out_channels,
        kernel_size=kernel_size, stride=stride, padding=padding,
    )
    norm = nn.InstanceNorm2d(out_channels)
    act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(conv, norm, act)


class Generator(nn.Module):
    """Transposed-convolution generator producing images squashed to [-1, 1] by tanh.

    For a latent input of shape (N, in_channels, 1, 1), the spatial size goes
    1 -> 4 (block1: kernel 4, stride 1, padding 0), then each stride-2 block
    doubles it: 8, 16, 32, 64. The output is (N, out_channels, 64, 64).

    Args:
        in_channels: latent dimension (channels of the 1x1 input).
        out_channels: channels of the generated image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # Attribute names block1..block5 are the state_dict keys of saved
        # checkpoints (e.g. 'Generator.pt'); keep them and their order stable.
        self.block1 = convTBNReLU(in_channels, 512, 4, 1, 0)
        self.block2 = convTBNReLU(512, 256)
        self.block3 = convTBNReLU(256, 128)
        self.block4 = convTBNReLU(128, 64)
        # Final upsampling has no norm/activation; tanh is applied in forward().
        self.block5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 1)

    def forward(self, inp):
        """Map a latent batch ``inp`` to images in [-1, 1]."""
        hidden = inp
        for stage in (self.block1, self.block2, self.block3, self.block4):
            hidden = stage(hidden)
        return torch.tanh(self.block5(hidden))

# Latent size; must match block1's in_channels in the checkpoint loaded later.
latentdim = 20
G = Generator(in_channels=latentdim, out_channels=1).cuda()

# Generalized Pareto parameters in scipy's positional order (c, loc, scale).
genpareto_params = (
    -0.09095992649837537,     # c (shape)
    0.0052528357032590265,    # loc
    0.26882173805170484,      # scale
)
# NOTE(review): the sampling cell adds this to rv.ppf(...), so rv presumably models
# the excess of the statistic over `threshold` — confirm against how the fit was made.
threshold = 0.0428257
rv = genpareto(*genpareto_params)

# Inference only from here on: turn autograd off globally.
torch.set_grad_enabled(False)

In [None]:
# Load the trained weights into G. map_location puts the tensors directly on G's
# device, so a checkpoint saved on a different GPU index (or on CPU) still loads.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
G.load_state_dict(torch.load('Generator.pt', map_location=next(G.parameters()).device))
G.eval()

In [None]:
# For each unconditional probability `prob`, rejection-sample 100 generator outputs
# whose rescaled pixel-mean statistic lies within ±10% of the matching tail quantile.
# NOTE(review): c_k is presumably P(statistic > threshold); confirm where 0.75**10 comes from.
c_k = 0.75**10
# Safety cap on generator batches per `prob`, so an unreachable target raises instead
# of looping forever. Set to None to search without a limit.
MAX_BATCHES = 100_000
for prob in [0.95, 0.99, 0.999, 0.9999]:
    t = time.time()
    # Peaks-over-threshold conversion: quantile of the excess distribution `rv`.
    i = (prob - (1-c_k))/(c_k)
    if not 0.0 <= i <= 1.0:
        # rv.ppf would return NaN here and no sample could ever match.
        raise ValueError(f"prob={prob} maps to excess quantile {i}, outside [0, 1]")
    val = rv.ppf(i) + threshold
    images = []
    count = 0
    batches = 0
    while count<100:
        if MAX_BATCHES is not None and batches >= MAX_BATCHES:
            raise RuntimeError(
                f"prob={prob}: only {count} matching samples after {batches} batches"
            )
        batches += 1
        # Drawn with the CPU RNG (as before) and then moved to the GPU.
        latent = torch.randn((100, latentdim, 1, 1)).cuda()
        image = G(latent)
        # 4096 = 64*64 pixels of the single-channel Generator output, so this is an
        # affine rescaling of each image's mean pixel value; `sums` is a boolean mask.
        sums = torch.abs(20*(image.sum(dim=(1, 2, 3))/4096) + 19 - val) <= 0.1*val
        hits = int(sums.sum())
        if hits > 0:
            images.append(image[sums])
            count += hits
    print(time.time() - t)
    # NOTE(review): `images` is rebound on every `prob`; only the last one survives
    # the loop unless the remainder of this loop body consumes it.
    images = torch.cat(images, 0)[:100]