# Precipitation and temperature downscaling using GANs

In [None]:
# Common imports
import os
import warnings
import numpy as np
from time import time

# To make this notebook's output stable across runs
# NOTE(review): this seeds NumPy's global RNG only; torch (used below for the
# DataLoaders and the models) has its own RNG, so consider also calling
# torch.manual_seed(42) once torch is in scope.
np.random.seed(42)

# Config matplotlib
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
# Smaller fonts for axis labels and tick labels in every figure below
mpl.rc('axes', labelsize=10)
mpl.rc('xtick', labelsize=8)
mpl.rc('ytick', labelsize=8)

# Utils
from deepdown.utils.data_loader import *
from deepdown.utils.utils_plot import *
from deepdown.utils.utils_loss import *
from deepdown.utils.helpers import *
from deepdown.utils.data_generators import *
from deepdown.models.SRGAN import *

# Try dask.distributed and see if the performance improves...
# Leave two cores free for the notebook itself, but always start at least one
# worker: os.cpu_count() may return None, or a value <= 2 on small machines.
from dask.distributed import Client
n_dask_workers = max(1, (os.cpu_count() or 1) - 2)
c = Client(n_workers=n_dask_workers, threads_per_worker=1)

warnings.filterwarnings("ignore", category=RuntimeWarning, message="divide by zero encountered in divide")

In [None]:
# Report CUDA availability (helper presumably provided by the deepdown star imports above)
print_cuda_availability()

In [None]:
# Define paths and constants
# safe_load is enough for the plain string values read below and, unlike
# yaml.load, never constructs arbitrary Python objects from the file.
with open('../config.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Paths
PATH_DEM = config['PATH_DEM']
PATH_ERA5_025 = config['PATH_ERA5_025']  # Original ERA5 0.25°
PATH_ERA5_100 = config['PATH_ERA5_100']  # ERA5 1°
# MeteoSwiss data use a different coordinate system; this does not matter here
# because only the tensors are used.
PATH_MCH = config['PATH_MeteoSwiss']
PATH_TMP = config['PATH_TMP']

# Data options
DATE_START = '1999-01-01'  # previously '1979-01-01'
DATE_END = '2021-12-31'
YY_TRAIN = [1999, 2015]  # previously [1979, 2015]
YY_TEST = [2016, 2021]
LEVELS = [850, 1000]  # [300, 500, 700, 850, 1000] available with CORDEX-CMIP6
RESOL_LOW = 0.25  # degrees
INPUT_VARIABLES = ['tp', 't']
# One directory per entry of INPUT_VARIABLES, in the same order
INPUT_PATHS = [os.path.join(PATH_ERA5_025, 'precipitation'),
               os.path.join(PATH_ERA5_025, 'temperature')]
DUMP_DATA_TO_PICKLE = True

# Crop on a smaller region
DO_CROP = True
# Crop area reduced to avoid NaNs; the previous extent
# x=[2720000, 2770000], y=[1290000, 1320000] contained NaNs.
CROP_X = [2700000, 2760000]
CROP_Y = [1190000, 1260000]

# Hyperparameters
BATCH_SIZE = 32

# Display options
PLOT_DATA_FULL_EXTENT = False
PLOT_DATA_CROPPED = False

## Target variables

In [None]:
# Load target data
# NOTE(review): load_target_data is presumably provided by the deepdown star
# imports; the accessors used below (TabsD/RhiresD/TmaxD/TminD with x, y and
# time) suggest it returns an xarray Dataset — confirm.
target = load_target_data(DATE_START, DATE_END, PATH_MCH, PATH_TMP)

In [None]:
# Extract the axes of the final target domain based on temperature
x_axis = target.TabsD.x
y_axis = target.TabsD.y

In [7]:
if PLOT_DATA_FULL_EXTENT:
    # One panel per target variable: (field, panel title, colormap), left to right
    target_panels = [
        (target.RhiresD, "Daily precipitation", mpl.cm.YlGnBu),
        (target.TabsD, "Daily Temperature", mpl.cm.RdBu_r),
        (target.TmaxD, "Daily Maximum temperature", mpl.cm.RdBu_r),
        (target.TminD, "Daily Minimum temperature", mpl.cm.RdBu_r),
    ]
    fig, axs = plt.subplots(1, 4, figsize=(18,5))
    for panel_ax, (field, panel_title, colormap) in zip(axs, target_panels):
        # Mean over the whole time axis, as a 2-D array
        time_mean = field.mean(dim='time').to_numpy().squeeze()
        plot_map(panel_ax, time_mean, title=panel_title, cmap=colormap)
    plt.tight_layout()
    plt.show()

## Input variables

In [None]:
# Load the low-resolution predictors (the plots below read topo, tp and t per level).
# NOTE(review): x_axis / y_axis from the target are passed so the inputs are
# presumably brought onto the target grid — confirm in load_input_data.
input_data = load_input_data(DATE_START, DATE_END, PATH_DEM, INPUT_VARIABLES, INPUT_PATHS, 
                             LEVELS, RESOL_LOW, x_axis, y_axis, PATH_TMP)

In [None]:
if PLOT_DATA_FULL_EXTENT:
    fig, axs = plt.subplots(1, 4, figsize=(18,5))
    ax_topo, ax_precip, *ax_temperature = axs

    # Topography is static: no time averaging
    topography = input_data.topo.to_numpy().squeeze()
    plot_map(ax_topo, topography, title="Topography", cmap=mpl.cm.terrain)

    mean_precip = input_data.tp.mean(dim='time').to_numpy().squeeze()
    plot_map(ax_precip, mean_precip, title="Input precipitation", cmap=mpl.cm.YlGnBu)

    # Time-mean temperature at each pressure level shown
    level_panels = ((850, "Input temperature at 850hPa"),
                    (1000, "Input temperature at 1000hPa"))
    for panel_ax, (pressure_level, panel_title) in zip(ax_temperature, level_panels):
        mean_temp = input_data.t.sel(level=pressure_level).mean(dim='time').to_numpy().squeeze()
        plot_map(panel_ax, mean_temp, title=panel_title, cmap=mpl.cm.RdBu_r)

    plt.tight_layout()
    plt.show()

## Crop domain

In [None]:
if DO_CROP:
    # Crop window shared by inputs and targets. The y slice runs from max to
    # min, i.e. it assumes a descending y coordinate (NOTE(review): confirm).
    crop_x = slice(min(CROP_X), max(CROP_X))
    crop_y = slice(max(CROP_Y), min(CROP_Y))
    input_data = input_data.sel(x=crop_x, y=crop_y)
    target = target.sel(x=crop_x, y=crop_y)

In [None]:
if DO_CROP and PLOT_DATA_CROPPED:
    # Time-mean of each target variable over the cropped domain
    fig, axs = plt.subplots(1, 4, figsize=(18,4))
    plot_map(axs[0], target.RhiresD.mean(dim='time').to_numpy().squeeze(), title="Daily precipitation", cmap=mpl.cm.YlGnBu)
    plot_map(axs[1], target.TabsD.mean(dim='time').to_numpy().squeeze(), title="Daily Temperature", cmap=mpl.cm.RdBu_r)
    plot_map(axs[2], target.TmaxD.mean(dim='time').to_numpy().squeeze(), title="Daily Maximum temperature", cmap=mpl.cm.RdBu_r)
    plot_map(axs[3], target.TminD.mean(dim='time').to_numpy().squeeze(), title="Daily Minimum temperature", cmap=mpl.cm.RdBu_r)
    # Same layout/display as the full-extent plot so titles don't overlap
    plt.tight_layout()
    plt.show()

In [None]:
if DO_CROP and PLOT_DATA_CROPPED:
    # Topography and time-mean inputs over the cropped domain
    fig, axs = plt.subplots(1, 4, figsize=(18,4))
    plot_map(axs[0], input_data.topo.to_numpy().squeeze(), title="Topography", cmap=mpl.cm.terrain)
    plot_map(axs[1], input_data.tp.mean(dim='time').to_numpy().squeeze(), title="Input precipitation", cmap=mpl.cm.YlGnBu)
    plot_map(axs[2], input_data.t.sel(level=850).mean(dim='time').to_numpy().squeeze(), title="Input temperature at 850hPa", cmap=mpl.cm.RdBu_r)
    plot_map(axs[3], input_data.t.sel(level=1000).mean(dim='time').to_numpy().squeeze(), title="Input temperature at 1000hPa", cmap=mpl.cm.RdBu_r)
    # Same layout/display as the full-extent plot so titles don't overlap
    plt.tight_layout()
    plt.show()

## Train/validation/test split and data generators

In [None]:
# Split the data chronologically. Inputs and targets MUST be cut on the same
# periods, otherwise x/y samples are misaligned in time (or empty).
PERIOD_TRAIN = slice('1999', '2011')
PERIOD_VALID = slice('2012', '2015')
PERIOD_TEST = slice('2016', '2021')  # matches YY_TEST

x_train = input_data.sel(time=PERIOD_TRAIN)
x_valid = input_data.sel(time=PERIOD_VALID)
x_test = input_data.sel(time=PERIOD_TEST)

y_train = target.sel(time=PERIOD_TRAIN)
y_valid = target.sel(time=PERIOD_VALID)
y_test = target.sel(time=PERIOD_TEST)

In [None]:
# Select the variables to use as input and output
# Inputs: topography and precipitation, plus temperature at each of LEVELS.
# NOTE(review): the meaning of None vs a list of levels is defined by
# DataGenerator (presumably "no level dimension") — confirm.
input_vars = {'topo' : None, 'tp': None, 't': LEVELS}
# Targets: precipitation and daily mean temperature
output_vars = ['RhiresD', 'TabsD'] #['RhiresD', 'TabsD', 'TmaxD', 'TminD']

In [None]:
# Training generator; its mean/std are reused to normalise the validation and
# test sets below. Batch size comes from the hyperparameter defined in config.
training_set = DataGenerator(x_train, y_train, input_vars, output_vars)
loader_train = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE)

In [None]:
# Validation and test generators are normalised with the *training* mean/std so
# no statistics leak from the held-out years; shuffle=False presumably keeps
# samples in time order.

# Validation
valid_set = DataGenerator(x_valid, y_valid, input_vars, output_vars, shuffle=False, mean=training_set.mean, std=training_set.std)
loader_val = torch.utils.data.DataLoader(valid_set, batch_size=BATCH_SIZE)

# Test
test_set = DataGenerator(x_test, y_test, input_vars, output_vars, shuffle=False, mean=training_set.mean, std=training_set.std)
loader_test = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE)

In [None]:
# Sanity check on one training sample: tensor shapes and value ranges of the
# input (x) and the target (y)
testx, testy = training_set[3]
print("x shape: ", testx.shape)
print("y shape: ", testy.shape)
for min_label, max_label, sample in (("x min: ", "x max: ", testx),
                                     ("y min: ", "y max: ", testy)):
    print(min_label, torch.min(sample))
    print(max_label, torch.max(sample))

In [None]:
# Number of samples in the training set (shown as the cell output)
training_set.n_samples

In [None]:
# Draw one batch from the training loader to inspect its shapes
data = next(iter(loader_train))
x, y = data
print('Shape of x:', x.shape)
print('Shape of y:', y.shape)

In [None]:
# Plot input
# Batch-mean of every input channel of the batch drawn above, one panel per channel
n_figs = x.shape[1]
ncols = 4
nrows = -(-n_figs // ncols)  # ceiling division
# squeeze=False always returns a 2-D array of axes, so no 1-row special case is needed
fig, axes = plt.subplots(figsize=(24, 3.3*nrows), ncols=ncols, nrows=nrows, squeeze=False)
for i, ax in enumerate(axes.flat):
    if i >= n_figs:
        ax.set_visible(False)  # hide unused panels in the last row
        continue
    vals = torch.mean(x[:, i, :, :], dim=0)
    plot_map(ax, vals, title=f"Average of feature {i+1}")


In [None]:
# Defining the G and D
# Adapted from https://github.com/mantariksh/231n_downscaling/blob/master/SRGAN.ipynb

In [None]:
### Test the Generator

In [None]:
# Channel counts taken from the batch drawn above (x, y are (batch, channel, H, W)),
# so they stay in sync with input_vars / output_vars instead of being hard-coded.
# NOTE(review): expected to be 4 inputs (topo, tp, t@850, t@1000) and 2 outputs.
NUM_CHANNELS_IN = x.shape[1]
NUM_CHANNELS_OUT = y.shape[1]
dtype = torch.float32
input_size = y.shape[2:]  # spatial (H, W) of the high-resolution target


In [None]:
print(input_size)

### Check the generator

In [None]:
# Release cached, unused GPU memory held by PyTorch before building the test model
torch.cuda.empty_cache()

In [None]:
# Take one training sample (index 3); this overwrites the batch x, y drawn earlier
x,y = (training_set.__getitem__(3))

In [None]:
# Add a leading batch dimension of size 1 before feeding the generator
x = x.unsqueeze(0)
print(x.shape)

In [None]:
# Build a generator with the channel counts / spatial size defined above and move
# it and the single sample to the compute device.
# NOTE(review): `device` is not defined in this notebook; presumably it comes
# from the deepdown star imports — confirm.
model = Generator(NUM_CHANNELS_IN, input_size, output_channels=NUM_CHANNELS_OUT)
model = model.to(device=device)
x = x.to(device=device)

In [None]:
# Forward pass for visual inspection only: no autograd graph is needed
with torch.no_grad():
    output = model(x)

In [None]:
# Low-res precipitation input (channel 1) vs generated precipitation (output channel 0).
# Create the figure first and show it at the end; the original trailing
# plt.figure() only opened an extra empty figure.
plt.figure()
plt.subplot(121)
plt.imshow(x.cpu().detach().numpy()[0, 1, :, :])
plt.title("Input low-res Precip")
plt.subplot(122)
plt.imshow(output.cpu().detach().numpy()[0, 0, :, :])
plt.title("Output Precip")
plt.show()

In [None]:
# Low-res temperature input (channel 2) vs generated temperature (output channel 1).
# NOTE(review): input channel 2 is presumably temperature at the first of LEVELS (850 hPa).
plt.figure()
plt.subplot(121)
plt.imshow(x.cpu().detach().numpy()[0, 2, :, :])
plt.title("Input low-res Temp")
plt.subplot(122)
plt.imshow(output.cpu().detach().numpy()[0, 1, :, :])
plt.title("Output Temp")
plt.show()

### Check the discriminator

In [None]:
# Reload training sample 3 (without a batch dimension) for the discriminator check
x,y = (training_set.__getitem__(3))

In [None]:
print(x.shape)

In [None]:
#y = y.unsqueeze(0)
# Spatial shape of one target sample
print(y.shape[1:])

In [None]:
# Height and width of the high-resolution target; unpacking into two names
# implies y is (channels, H, W) here. Used below to size the discriminator.
h, w = y.shape[1:]

In [None]:
#### Test the discriminator
# NOTE(review): test_Discriminator presumably comes from the deepdown star
# imports; what it checks is not visible in this notebook.
test_Discriminator(training_set, Discriminator)

In [None]:

# Helper functions for plotting
def plot_epoch(x, y_pred, y):
    """Show input, generated and true maps for the first sample of a batch.

    Draws two figures with three panels each: precipitation (input channel 1,
    output/target channel 0) and temperature (input channel 2, output/target
    channel 1). Input channel 0 (topography) is not shown.

    Args:
        x: low-resolution input batch, indexed as ``x[sample, channel, row, col]``.
        y_pred: generator output batch, same indexing.
        y: high-resolution target batch, same indexing.
    """
    def _first_sample_image(batch, channel):
        # First sample of the batch, detached and copied to host memory
        return batch[0, channel, :, :].cpu().detach().numpy()

    figure_rows = (
        ((x, 1, "Input Precip"), (y_pred, 0, "Output Precip"), (y, 0, "True Precip")),
        ((x, 2, "Input Temp"), (y_pred, 1, "Output Temp"), (y, 1, "True Temp")),
    )
    for panels in figure_rows:
        plt.figure(figsize=(9, 4))
        for position, (batch, channel, label) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(_first_sample_image(batch, channel))
            plt.title(label)
    plt.show()
    
    
def plot_loss(G_content, G_advers, D_real_L, D_fake_L, weight_param):
    
    D_count = np.count_nonzero(D_real_L)
    G_count = np.count_nonzero(G_content)
    
    plt.figure(figsize=(12,4))
    plt.subplot(1,2,1)
    plt.plot(range(G_count), G_content[range(G_count)])
    plt.plot(range(G_count), G_advers[range(G_count)])
    plt.plot(range(G_count), G_content[range(G_count)] + weight_param*G_advers[range(G_count)])
    plt.legend(("Content", "Adversarial", "Total"))
    plt.title("Generator loss")
    plt.xlabel("Iteration")
    
    plt.subplot(1,2,2)
    plt.plot(range(D_count), D_real_L[range(D_count)])
    plt.plot(range(D_count), D_fake_L[range(D_count)])
    plt.plot(range(D_count), D_real_L[range(D_count)] + D_fake_L[range(D_count)])
    plt.legend(("Real Pic", "Fake Pic", "Total"))
    plt.title("Discriminator loss")
    plt.xlabel("Iteration")
    plt.show()

In [None]:
# Make CUDA kernel launches synchronous so errors surface at the failing call
# (debugging aid; typically slows training).
# NOTE(review): this variable is generally only honoured if set before the CUDA
# context is created; earlier cells already move a model/tensors to `device`,
# so it may have no effect here — consider setting it in the first cell.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# Fresh networks: a discriminator sized to the h x w target grid and a generator
# mapping NUM_CHANNELS_IN input channels to NUM_CHANNELS_OUT output channels
D = Discriminator(num_channels=NUM_CHANNELS_OUT, H=h, W=w)
G = Generator(NUM_CHANNELS_IN, input_size, output_channels=NUM_CHANNELS_OUT)

# Both networks use Adam with the same settings; training starts from scratch
# (no checkpoint is loaded)
lr = 0.0005
adam_betas = (0.5, 0.999)
D_solver = torch.optim.Adam(D.parameters(), lr=lr, betas=adam_betas)  # discriminator
G_solver = torch.optim.Adam(G.parameters(), lr=lr, betas=adam_betas)  # generator


In [None]:
num_epochs=10
G_iters=2  # generator updates per discriminator update (i.e. per batch)
dtype = torch.float32  # dtype the batches are cast to on the device

In [None]:
# Place both networks on the compute device and switch them to training mode
D = D.to(device=device)
G = G.to(device=device)
D.train()
G.train()

# Loss histories, pre-allocated with zeros and filled sequentially by the
# training loop: one discriminator step per batch, G_iters generator steps per
# batch, plus one spare slot each.
n_d_steps = len(loader_train) * num_epochs
n_g_steps = n_d_steps * G_iters
G_content = np.zeros(n_g_steps + 1)
G_advers = np.zeros(n_g_steps + 1)
D_real_L = np.zeros(n_d_steps + 1)
D_fake_L = np.zeros(n_d_steps + 1)

In [None]:
iter_count = 0
G_iter_count = 0
show_every = 40
weight_param = 1e-1  # Weighting put on adversarial loss (constant, so set once)
tic = time()

# D and G are already on `device` and in training mode (previous cell); nothing
# in this loop changes that, so they are not re-moved/re-set per batch.
for epoch in range(num_epochs):

    for x, y in loader_train:
        high_res_imgs = y.to(device=device, dtype=dtype)
        low_res_imgs = x.to(device=device, dtype=dtype)

        # --- Discriminator update ---
        # The generated images are detached: the D step must not build or
        # backpropagate through the generator's graph (G's gradients were
        # discarded by G_solver.zero_grad() anyway).
        logits_real = D(high_res_imgs)
        fake_images = G(low_res_imgs)
        logits_fake = D(fake_images.detach())
        d_total_error, D_real_L[iter_count], D_fake_L[iter_count] = discriminator_loss(logits_real, logits_fake)
        D_solver.zero_grad()
        d_total_error.backward()
        D_solver.step()

        # --- Generator update(s) ---
        for _ in range(G_iters):
            fake_images = G(low_res_imgs)
            gen_logits_fake = D(fake_images)
            g_error, G_content[G_iter_count], G_advers[G_iter_count] = generator_loss(
                fake_images, high_res_imgs, gen_logits_fake, weight_param=weight_param)
            G_solver.zero_grad()
            g_error.backward()
            G_solver.step()
            G_iter_count += 1

        if iter_count % show_every == 0:
            toc = time()
            print('Epoch: {}, Iter: {}, D: {:.4}, G: {:.4}, Time since last print (min): {:.4}'.format(epoch,iter_count,d_total_error.item(),g_error.item(), (toc-tic)/60 ))
            tic = time()
            plot_epoch(x, fake_images, y)
            plot_loss(G_content, G_advers, D_real_L, D_fake_L, weight_param)
            print()
        iter_count += 1

    # Optional per-epoch checkpoint; state_dict() can be saved without moving
    # the models off the device:
    # torch.save(D.state_dict(), 'GAN_Discriminator_checkpoint_adversWP_1e-1.pt')
    # torch.save(G.state_dict(), 'GAN_Generator_checkpoint_adversWP_1e-1.pt')