In [None]:
# Needed to plot rainfall maps. Restart runtime after installation (Option in the cell output)
# Installs the libgeos-dev system package first, then basemap from the GitHub master archive.
# NOTE(review): the archive URL tracks the master branch, so the installed basemap version is not pinned.
!apt-get install libgeos-dev
!pip install https://github.com/matplotlib/basemap/archive/master.zip

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import torch.optim as optim
from torch import LongTensor, FloatTensor
from scipy.stats import skewnorm, genpareto
from torchvision.utils import save_image
import sys
from datetime import datetime, timedelta
import torch.utils.data
import torchvision.utils as vutils
from tqdm import tqdm

In [None]:
!mkdir data
!wget https://raw.githubusercontent.com/Stream-AD/ExGAN/master/real.pt
!cp 'real.pt' data

In [None]:
from mpl_toolkits.basemap import Basemap, cm

# Corner latitudes / longitudes of the plotted region. plot_precip passes elements
# 0 and 2 of each array to Basemap as the lower-left and upper-right map corners.
latcorners = np.array([23.476929, 20.741224, 45.43908 , 51.61555 ])
loncorners = np.array([-118.67131042480469, -82.3469009399414,
                   -64.52022552490234, -131.4470977783203])
# Central longitude handed to Basemap as lon_0.
lon_0 = -105
# Handed to Basemap as lat_ts (plot_precip fixes Basemap's own lat_0 at 90).
lat_0 = 60

def plot_precip(data):
    '''
    Draw one precipitation field on a stereographic Basemap and show the figure.

    data: a 2-D field, or a 3-D array/tensor whose first entry along axis 0 is
    the field to draw. NumPy arrays and torch tensors are both accepted; tensors
    are detached and moved to host memory before plotting.

    The field is resized to 813 x 1051 and rescaled with (data + 1) * 50 before
    being contoured against the fixed `clevs` levels (colour bar labelled 'mm').
    '''
    # skimage.transform.resize operates on NumPy arrays, so convert explicitly
    # instead of relying on implicit tensor-to-array coercion.
    if torch.is_tensor(data):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)
    if len(data.shape) == 3:
        data = data[0]
    data = resize(data, (813, 1051))
    data = (data + 1) * 50
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    m = Basemap(projection='stere', lon_0=lon_0, lat_0=90., lat_ts=lat_0,
                llcrnrlat=latcorners[0], urcrnrlat=latcorners[2],
                llcrnrlon=loncorners[0], urcrnrlon=loncorners[2],
                rsphere=6371200., resolution='i', area_thresh=10000)
    m.drawcoastlines()
    m.drawstates()
    m.drawcountries()
    m.drawlsmask(land_color="#FCF8F3", ocean_color='#E6FFFF')
    parallels = np.arange(0., 90, 10.)
    m.drawparallels(parallels, labels=[1, 0, 0, 0], fontsize=10)
    meridians = np.arange(180., 360., 10.)
    m.drawmeridians(meridians, labels=[0, 0, 0, 1], fontsize=10)
    ny, nx = data.shape[0], data.shape[1]
    lons, lats = m.makegrid(nx, ny)  # lat/lons of an evenly spaced ny-by-nx grid
    x, y = m(lons, lats)  # map-projection coordinates of that grid
    clevs = np.array([0, 1, 2.5, 5, 7.5, 10, 15, 20, 30, 40, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600, 750])
    cs = m.contourf(x, y, data, clevs, cmap=cm.s3pcpn)
    cbar = m.colorbar(cs, location='bottom', pad="5%")
    cbar.set_label('mm')
    plt.show()

In [None]:
# Channel count of the latent noise tensors (batch, LATENT_DIM, 1, 1) fed to the generators.
LATENT_DIM = 20
# Used by NWSDataset to split a mixed set into real and fake parts, and by the
# training cells to decide how many fake samples to generate.
DATASET_SIZE = 2557
# (beta1, beta2) passed to every Adam optimizer in this notebook.
BETAS = (0.5, 0.999)
# At shift iteration i, NWSDataset keeps int(DATASET_SIZE * c**i) real samples.
c = 0.75
# Upper bound of the distribution-shift loop (for i in range(1, k)).
k = 10

In [None]:
def extremeness_measure(samples):
  """Return the sum of `samples` divided by 4096.

  A 4-D input is reduced over dimensions 1, 2 and 3, giving one value per entry
  along dimension 0; any other input is reduced to a single value.
  """
  is_batch = len(samples.shape) == 4
  total = samples.sum(dim=(1, 2, 3)) if is_batch else samples.sum()
  return total / 4096

In [None]:
class NWSDataset(Dataset):
    """
    Dataset over the tensor stored in 'data/real.pt', optionally mixed with
    previously generated samples.

    With i > 0 the first int(DATASET_SIZE * c**i) real samples are concatenated
    with the leading DATASET_SIZE - that_count samples loaded from `fake`;
    otherwise only the real samples are used. When `conditional` is true each
    item is returned together with its extremeness measure.
    """

    def __init__(
            self, fake='data/fake.pt', c=0.75, i=0, conditional=False
    ):
        self.conditional = conditional
        self.real = torch.load('data/real.pt')
        if i > 0:
            # Number of real samples kept at this shift iteration.
            n_real = int(DATASET_SIZE * (c ** i))
            self.fake = torch.load(fake)
            real_part = self.real[:n_real]
            fake_part = self.fake[:DATASET_SIZE - n_real]
            self.data = torch.cat([real_part, fake_part], 0)
        else:
            self.data = self.real
        self.data.requires_grad = False

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, item):
        sample = self.data[item]
        if not self.conditional:
            return sample
        return sample, extremeness_measure(sample)

In [None]:
def weights_init_normal(m):
    """Re-initialise the weights of modules whose class name contains "Conv".

    Matching modules get weights drawn from a normal distribution with mean 0.0
    and standard deviation 0.02; every other module is left untouched.
    """
    if "Conv" not in m.__class__.__name__:
        return
    torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)


def convTINReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Build a ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.2, in-place) block."""
    upsample = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    layers = [
        upsample,
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*layers)


def convINReLU(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
    """Build a Conv2d -> InstanceNorm2d -> LeakyReLU(0.2, in-place) block."""
    downsample = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
    layers = [
        downsample,
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*layers)


class GeneratorUnconditional(nn.Module):
    """Generator: four transposed-conv blocks, a final transposed conv, then tanh."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # First block uses kernel 4, stride 1, padding 0; the rest use the helper defaults.
        self.block1 = convTINReLU(in_channels, 512, kernel_size=4, stride=1, padding=0)
        self.block2 = convTINReLU(512, 256)
        self.block3 = convTINReLU(256, 128)
        self.block4 = convTINReLU(128, 64)
        self.block5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 1)

    def forward(self, inp):
        """Run `inp` through the five blocks and squash the result with tanh."""
        hidden = inp
        for stage in (self.block1, self.block2, self.block3, self.block4):
            hidden = stage(hidden)
        return torch.tanh(self.block5(hidden))


class DiscriminatorUnconditional(nn.Module):
    """Discriminator: four conv blocks, a 64-channel conv, then a linear + sigmoid head."""

    def __init__(self, in_channels):
        super().__init__()
        self.block1 = convINReLU(in_channels, 64)
        self.block2 = convINReLU(64, 128)
        self.block3 = convINReLU(128, 256)
        self.block4 = convINReLU(256, 512)
        self.block5 = nn.Conv2d(512, 64, 4, 1, 0)
        self.source = nn.Linear(64, 1)

    def forward(self, inp):
        """Return the sigmoid output of the linear head for each input sample."""
        features = inp
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = stage(features)
        # Flatten everything but the batch dimension before the linear head.
        flat = features.view(features.shape[0], -1)
        return torch.sigmoid(self.source(flat))

In [None]:
def getTrueFalseTensors(batch_size, device=None):
  """Build noisy "real" and "fake" target labels of shape (batch_size, 1).

  Real targets are drawn uniformly from [0.7, 1.2) and fake targets from
  [0, 0.3); each row then has its two labels swapped with probability 0.05.

  Args:
    batch_size: number of label rows to create.
    device: where the returned tensors live. Defaults to CUDA when it is
      available and to the CPU otherwise, so the helper no longer fails on
      hosts without a GPU.

  Returns:
    (trueTensor, falseTensor), both on `device`.
  """
  if device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # Random draws stay on the CPU generator, in the original order.
  trueTensor = 0.7 + 0.5 * torch.rand((batch_size, 1))
  falseTensor = 0.3 * torch.rand((batch_size, 1))
  probFlip = (torch.rand((batch_size, 1)) < 0.05).float()
  # Where probFlip is 1 the two labels trade places; elsewhere they are kept.
  trueTensor, falseTensor = (
      probFlip * falseTensor + (1 - probFlip) * trueTensor,
      probFlip * trueTensor + (1 - probFlip) * falseTensor,
  )
  return trueTensor.to(device), falseTensor.to(device)
  
def trainGAN(dataloader, Generator, Discriminator, optimizerGenerator, optimizerDiscrimintor, noise=0):
  """Train the generator / discriminator pair for one full pass over `dataloader`.

  For every batch the discriminator is updated on real images (optionally
  perturbed by Gaussian noise scaled by `noise`) and on detached generator
  output, then the generator is updated against a constant 0.9 target.
  Gradients of both networks are clipped to a norm of 20. Uses the
  module-level `criterionSource` loss and `LATENT_DIM`.

  Args:
    dataloader: iterable yielding image batches as tensors.
    Generator, Discriminator: the networks to train (expected on the GPU).
    optimizerGenerator, optimizerDiscrimintor: their optimizers.
    noise: scale of the Gaussian noise added to the real images.

  Returns:
    (lossG, lossD) as Python floats for the last batch of the pass.

  Raises:
    ValueError: if `dataloader` yields no batches.
  """
  lossG = lossD = None
  for images in dataloader:
    images = images.cuda()
    # Batch dimension of the batch tensor (images[0].shape[0] would instead
    # read the first dimension of a single sample).
    batch_size = images.shape[0]
    trueTensor, falseTensor = getTrueFalseTensors(batch_size)

    # Discriminator step: real images vs. detached generator output.
    realSource = Discriminator(images + noise*torch.randn_like(images).cuda())
    realLoss = criterionSource(realSource, trueTensor.expand_as(realSource))
    latent = torch.randn(batch_size, LATENT_DIM, 1, 1).cuda()
    fakeData = Generator(latent)
    fakeSource = Discriminator(fakeData.detach())
    fakeLoss = criterionSource(fakeSource, falseTensor.expand_as(fakeSource))
    lossD = realLoss + fakeLoss
    optimizerDiscrimintor.zero_grad()
    lossD.backward()
    torch.nn.utils.clip_grad_norm_(Discriminator.parameters(),20)
    optimizerDiscrimintor.step()

    # Generator step: push the discriminator output on fakes towards 0.9.
    fakeSource = Discriminator(fakeData)
    trueTensor = 0.9*torch.ones((batch_size, 1)).cuda()
    lossG = criterionSource(fakeSource, trueTensor.expand_as(fakeSource))
    optimizerGenerator.zero_grad()
    lossG.backward()
    torch.nn.utils.clip_grad_norm_(Generator.parameters(),20)
    optimizerGenerator.step()

  if lossG is None:
    raise ValueError("trainGAN received an empty dataloader")
  # Returned only after every batch has been processed.
  return lossG.item(), lossD.item()

In [None]:
# Unconditional DCGAN trained on the real data only (NWSDataset defaults to i=0).
dataloader = DataLoader(NWSDataset(), batch_size=256, shuffle=True)

criterionSource = nn.BCELoss()
criterionContinuous = nn.L1Loss()
criterionValG = nn.L1Loss()
criterionValD = nn.L1Loss()
UnconditionalG = GeneratorUnconditional(in_channels=LATENT_DIM, out_channels=1).cuda()
UnconditionalD = DiscriminatorUnconditional(in_channels=1).cuda()
UnconditionalG.apply(weights_init_normal)
UnconditionalD.apply(weights_init_normal)

optimizerG = optim.Adam(UnconditionalG.parameters(), lr=2e-4, betas=BETAS)
optimizerD = optim.Adam(UnconditionalD.parameters(), lr=1e-4, betas=BETAS)
static_z = FloatTensor(torch.randn((81, LATENT_DIM, 1, 1))).cuda()

DIRNAME = 'DCGAN/'
os.makedirs(DIRNAME, exist_ok=True)
tk = tqdm(range(1000))
for epoch in tk:
    # Noise scale decays linearly from 1e-5 to 0 over the first 500 epochs.
    noise = 1e-5*max(1 - (epoch/500.0), 0)
    # Hand the scheduled value to trainGAN so the schedule actually takes effect.
    lossG, lossD = trainGAN(dataloader, UnconditionalG, UnconditionalD, optimizerG, optimizerD, noise=noise)
    tk.set_postfix(lossG=lossG, lossD=lossD)

# Generate DATASET_SIZE / c samples and store them sorted by decreasing
# extremeness, ready for the first distribution-shift iteration.
UnconditionalG.eval()
with torch.no_grad():
    fakeSamples = UnconditionalG(torch.randn(int(DATASET_SIZE/c), LATENT_DIM, 1, 1).cuda()).cpu()
# .copy() materialises the reversed argsort so it can be used as a tensor index.
sorted_indices = extremeness_measure(fakeSamples).numpy().argsort()[::-1].copy()
UnconditionalG.train()
torch.save(fakeSamples[sorted_indices], 'data/fake.pt')

In [None]:
# Plot the first generated sample. fakeSamples itself is in generation order;
# only the copy written to data/fake.pt was sorted by extremeness.
plot_precip(fakeSamples[0])

In [None]:
# Fresh optimizers for the distribution-shift phase, with learning rates ten
# times smaller than in the DCGAN phase.
optimizerG = optim.Adam(UnconditionalG.parameters(), lr=2e-5, betas=BETAS)
optimizerD = optim.Adam(UnconditionalD.parameters(), lr=1e-5, betas=BETAS)

c = 0.75
k = 10
DIRNAME = 'DistShift/'
os.makedirs(DIRNAME, exist_ok=True)

# Each iteration trains on a mix of real data and the fake samples saved by the
# previous iteration, then saves a new sorted fake set for the next one.
fake_name = 'data/fake.pt'
for i in range(1, k):
    print("Distribution Shift: Iteration ", i)
    # NWSDataset keeps int(DATASET_SIZE * c**i) real samples and fills the rest
    # from the leading entries of the file at fake_name.
    dataloader = DataLoader(NWSDataset(fake=fake_name, c=c, i=i), batch_size=256, shuffle=True)
    tk = tqdm(range(0, 100))
    for epoch in tk:
        lossG, lossD = trainGAN(dataloader, UnconditionalG, UnconditionalD, optimizerG, optimizerD)
        tk.set_postfix(lossG=lossG, lossD=lossD)
    with torch.no_grad():
        UnconditionalG.eval()
        # Number of samples to generate: the fake share needed at iteration i + 1,
        # DATASET_SIZE * (1 - c**(i + 1)), divided by c.
        fsize = int((1 - (c ** (i + 1))) * DATASET_SIZE / c)
        fakeSamples = UnconditionalG(torch.randn(fsize, LATENT_DIM, 1, 1).cuda()).cpu()
        # Indices ordering the samples by decreasing extremeness.
        sorted_indices = extremeness_measure(fakeSamples).numpy().argsort()[::-1].copy()
        fake_name = DIRNAME + 'fake' + str(i + 1) + '.pt'
        torch.save(fakeSamples.data[sorted_indices], fake_name)
        UnconditionalG.train()

In [None]:
# Sample saved after the first distribution-shift iteration (i = 1 writes fake2.pt).
plot_precip(torch.load(DIRNAME+'fake2.pt')[1000]) #Sorted by extremeness. Hence, looking at the middle elements. 

In [None]:
# Same index, taken from the file written by the last shift iteration (fake<k>.pt).
plot_precip(torch.load(DIRNAME+'fake'+str(k)+'.pt')[1000])

In [None]:
# Final mixed dataset: with i=k, NWSDataset keeps int(DATASET_SIZE * c**k) real
# samples and fills the rest from the last shift iteration's file.
# conditional=True makes every item an (image, extremeness measure) pair.
dataset = NWSDataset(fake='DistShift/fake'+str(k)+'.pt', c=c, i=k, conditional=True)

In [None]:
# Extremeness of the whole dataset tensor.
# presumably dataset.data is 4-D, giving one value per sample — TODO confirm
measures = extremeness_measure(dataset.data)
threshold = measures.min() # Already tail of the data
# Keep the measures strictly above the threshold (the minimum itself is dropped).
tail = measures[np.where(measures > threshold)[0]]

In [None]:
# Fit a generalized Pareto distribution to the excesses over the threshold.
# The fitted parameters are later unpacked into genpareto(...) to build the sampler.
genpareto_params = genpareto.fit(tail-threshold)

In [None]:
class GeneratorConditional(nn.Module):
    """Generator conditioned on one extra input channel appended to the latent."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # +1 input channel for the conditioning code concatenated in forward().
        self.block1 = convTINReLU(in_channels + 1, 512, kernel_size=4, stride=1, padding=0)
        self.block2 = convTINReLU(512, 256)
        self.block3 = convTINReLU(256, 128)
        self.block4 = convTINReLU(128, 64)
        self.block5 = nn.ConvTranspose2d(64, out_channels, 4, 2, 1)

    def forward(self, latent, continuous_code):
        """Concatenate latent and code along dim 1, run the blocks, apply tanh."""
        hidden = torch.cat((latent, continuous_code), 1)
        for stage in (self.block1, self.block2, self.block3, self.block4):
            hidden = stage(hidden)
        return torch.tanh(self.block5(hidden))

class DiscriminatorConditional(nn.Module):
    """Discriminator whose linear head also sees how far the input's extremeness
    is from the requested one."""

    def __init__(self, in_channels):
        super().__init__()
        self.block1 = convINReLU(in_channels, 64)
        self.block2 = convINReLU(64, 128)
        self.block3 = convINReLU(128, 256)
        self.block4 = convINReLU(256, 512)
        self.block5 = nn.Conv2d(512, 64, 4, 1, 0)
        # 64 conv features plus the single relative-difference feature.
        self.source = nn.Linear(64 + 1, 1)

    def forward(self, inp, extreme):
        """Return the sigmoid score for `inp` given the target extremeness `extreme`."""
        target = extreme.view(-1, 1)
        measured = extremeness_measure(inp).view(-1, 1)
        # Relative absolute gap between requested and measured extremeness.
        diff = torch.abs(target - measured) / torch.abs(target)
        features = inp
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = stage(features)
        flat = features.view(features.shape[0], -1)
        return torch.sigmoid(self.source(torch.cat([flat, diff], 1)))


# Loss and networks for the conditional GAN; both networks are placed on the GPU
# and their Conv* weights are re-initialised by weights_init_normal.
criterionSource = nn.BCELoss()
G = GeneratorConditional(in_channels=LATENT_DIM, out_channels=1).cuda()
D = DiscriminatorConditional(in_channels=1).cuda()
G.apply(weights_init_normal)
D.apply(weights_init_normal)

# Frozen generalized Pareto distribution built from the fitted parameters.
rv = genpareto(*genpareto_params)

c = 0.75
k = 10

def sample_genpareto(size):
    """Draw extremeness codes of shape `size`.

    Uniform random probabilities are pushed through the percent-point function
    of the module-level `rv`, and the module-level `threshold` is added back.
    """
    quantiles = rv.ppf(torch.rand(size))
    return FloatTensor(quantiles) + threshold


# Optimizers for the conditional networks.
optimizerG = optim.Adam(G.parameters(), lr=2e-4, betas=BETAS)
optimizerD = optim.Adam(D.parameters(), lr=1e-4, betas=BETAS)
# Fixed codes and latents (81 of each) that sample_image feeds to G.
static_code = sample_genpareto((81, 1, 1, 1)).cuda()
static_z = FloatTensor(torch.randn((81, LATENT_DIM, 1, 1))).cuda()
    
def sample_image(batches_done):
    """Save a 9-per-row image grid of G's output for the fixed latents and codes.

    The output is shifted from G's tanh range to [0, 1] and written to
    DIRNAME + "<batches_done>.png".
    """
    rescaled = (G(static_z, static_code).cpu() + 1) / 2.0
    target_path = DIRNAME + "%d.png" % batches_done
    save_image(rescaled, target_path, nrow=9)

DIRNAME = 'ExGAN/'
os.makedirs(DIRNAME, exist_ok=True)
# NOTE(review): fakename is assigned here but not used anywhere in this cell.
fakename = 'DistShift/fake'+str(k)+'.pt'
# dataset was built with conditional=True, so each batch is (images, labels).
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
tk = tqdm(range(0, 1)) # Actual Number of Epochs is 1000
for epoch in tk:
    # NOTE(review): noise is computed every epoch but never applied in the loop below.
    noise = 1e-5 * max(1 - (epoch / 1000.0), 0)
    for images, labels in dataloader:
        batch_size = images.shape[0]
        trueTensor, falseTensor = getTrueFalseTensors(batch_size)
        images, labels = images.cuda(), labels.view(-1, 1).cuda()

        # Discriminator step: real images with their measured extremeness vs.
        # detached generator output with the sampled codes.
        realSource = D(images, labels)
        realLoss = criterionSource(realSource, trueTensor.expand_as(realSource))
        latent = torch.randn(batch_size, LATENT_DIM, 1, 1).cuda()
        code = sample_genpareto((batch_size, 1, 1, 1)).cuda()
        fakeGen = G(latent, code)
        fakeGenSource = D(fakeGen.detach(), code)
        fakeGenLoss = criterionSource(fakeGenSource, falseTensor.expand_as(fakeGenSource))
        lossD = realLoss + fakeGenLoss
        optimizerD.zero_grad()
        lossD.backward()
        torch.nn.utils.clip_grad_norm_(D.parameters(), 20)
        optimizerD.step()

        # Generator step: adversarial loss plus L_ext, the mean relative gap
        # between the generated samples' extremeness and the requested codes.
        fakeGenSource = D(fakeGen, code)
        fakeLabels = extremeness_measure(fakeGen)
        L_ext = torch.mean(torch.abs((fakeLabels - code.view(batch_size)) / code.view(batch_size)))
        # The generator target reuses the noisy trueTensor from getTrueFalseTensors.
        lossG = criterionSource(fakeGenSource, trueTensor.expand_as(fakeGenSource)) + L_ext
        optimizerG.zero_grad()
        lossG.backward()
        torch.nn.utils.clip_grad_norm_(G.parameters(), 20)
        optimizerG.step()
    # Progress bar shows the losses of the last batch of the epoch.
    tk.set_postfix(lossG=lossG.item(), lossD=lossD.item())

In [None]:
# Takes around 1.5 hrs for training 1000 epochs. Use the pretrained weights instead.
# The file is fetched as ExGANweights.pt, the name the load_state_dict cell reads.
!wget https://raw.githubusercontent.com/Stream-AD/ExGAN/master/ExGANweights.pt

In [None]:
# Replace G's parameters with the downloaded pretrained weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
G.load_state_dict(torch.load('ExGANweights.pt'))

In [None]:
G.eval()
tau = 1e-4
# tau is divided by c**k before being used as a tail probability of rv.
# NOTE(review): presumably this rescales tau for the k shift iterations — confirm.
tau_prime = tau/c**k
# Quantile of the fitted distribution at 1 - tau_prime, with the threshold added back.
val = rv.ppf((1-tau_prime)) + threshold
# All 100 samples are conditioned on the same extremeness value.
code = torch.ones(100, 1, 1, 1).cuda()*val
latent = torch.randn((100, LATENT_DIM, 1, 1)).cuda()
with torch.no_grad():
  images = G(latent, code).cpu()

In [None]:
# Plot the first of the 100 samples generated in the previous cell.
plot_precip(images[0]) # Feel free to change tau, and look at more samples