## Spatial Rainfall Super-Resolution Using Untrained Networks

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os, os.path as osp
import data_utils
from matplotlib import pyplot as plt
import models
from copy import deepcopy
import dip
import metrics
import vis
import utils
from tqdm.notebook import tqdm

In [None]:
# torch.random.manual_seed(7)  # uncomment for reproducibility

In [None]:
# Pin the notebook to GPU 2 unless the caller already chose devices via the
# environment. This must run before CUDA is first initialised by torch.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

In [None]:
# Run on the (first visible) GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Variable name of the precipitation field and the NetCDF file it is read from.
KEY = "precip"
path_to_temp_file = "./data/gt/precip_tx.nc"

In [None]:
inp_data_temp = data_utils.read_ncdf_precip(path_to_temp_file)
# Index along the first axis of inp_data_temp[KEY] selecting the frame to
# super-resolve (presumably a time step — TODO confirm against data_utils).
idx = 8682

# Quick look at the selected frame before training.
plt.figure()
plt.imshow(inp_data_temp[KEY][idx, ...])
plt.colorbar()
plt.title(f"{KEY} [{idx}]")
plt.show()



In [None]:
# Shape of the selected frame (uses KEY like the rest of the notebook).
inp_data_temp[KEY][idx, ...].shape

In [None]:
# define all CONSTANTS AND FLAGS
IN_CHANNELS=16
OUT_CHANNELS=1
USE_RANDOM_NOISE=True
USE_GUIDANCE = (not USE_RANDOM_NOISE)
SR_FACTOR = 8
LR_RES = inp_data_temp["precip"][idx,...].shape
HR_RES = (LR_RES[0]*SR_FACTOR, LR_RES[1]*SR_FACTOR)
z_noise_var = 0.01

In [None]:
PLOT_EVERY = 500
LEARNING_RATE = 1e-3
# Upper bound on optimisation steps: 4000 per 2x level beyond the first.
# Clamped to at least one level; otherwise SR_FACTOR <= 2 yields 0 (or a
# negative count), the loop never runs and `out` is undefined afterwards.
NUM_EPOCHS = 4000 * max(1, int(np.log2(SR_FACTOR)) - 1)
RESET_THRESHOLD = 5  # only used by the (currently disabled) weight-reset logic

In [None]:
if USE_RANDOM_NOISE:
    # Fixed random code at the target (high) resolution, fed to the network every step.
    inp_tensor = torch.randn(1, IN_CHANNELS, *HR_RES).float().to(device)
elif USE_GUIDANCE:
    # Single-channel guidance image at HR_RES replaces the random code.
    IN_CHANNELS = 1
    inp_np_img = data_utils.get_guidance_tensor(idx=idx, size=HR_RES)
    _guidance = torch.from_numpy(utils.normalize(inp_np_img)).float()
    inp_tensor = _guidance[None, None, ...].to(device)

In [None]:
# define data
inp_low_res_rainfall = inp_data_temp[KEY][idx,:48, :128]
low_res_gt_tensor = torch.from_numpy(inp_low_res_rainfall)[None, None,...].float().to(device)#torch.ones(1, 1, 256,256).float().to(device)
low_res_gt_tensor = (low_res_gt_tensor - low_res_gt_tensor.min())/(low_res_gt_tensor.max() -low_res_gt_tensor.min())

In [None]:
# Full shape of the loaded precipitation variable.
_precip_all = inp_data_temp[KEY]
_precip_all.shape

In [None]:
# Untrained network mapping the input code to a single-channel HR field.
model = dip.DIP(num_in_channels=IN_CHANNELS, num_out_channels=OUT_CHANNELS)
model = model.to(device=device, dtype=torch.float32)

In [None]:
# Adam over all network weights.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Per-iteration logs plus best-checkpoint bookkeeping used by the training loop.
loss_history, psnr_history = [], []
best_model_score, best_model_weights = -np.inf, None
reset_counter = 0

In [None]:
# Fit the untrained network so that its HR output, average-pooled by
# SR_FACTOR, reproduces the low-res ground truth.
downsampler = nn.AvgPool2d(SR_FACTOR).to(device)
# The target never changes, so copy it to host memory once for the PSNR metric.
_gt_np = low_res_gt_tensor.detach().cpu().numpy()[0, 0].copy()
pbar = tqdm(range(NUM_EPOCHS))
for i in pbar:
    optim.zero_grad()
    # z_noise = torch.randn_like(inp_tensor)*z_noise_var
    out = model(inp_tensor)
    downsampled = downsampler(out)
    loss = ((downsampled - low_res_gt_tensor) ** 2).mean()
    loss_value = loss.item()
    loss_history.append(loss_value)

    # Score and snapshot BEFORE optim.step(): `downsampled` was produced by the
    # current weights, so the saved state_dict must be taken before they change.
    curr_psnr = metrics.psnr(downsampled.detach().cpu().numpy()[0, 0], _gt_np)
    psnr_history.append(curr_psnr)
    if curr_psnr > best_model_score:
        reset_counter = 0
        best_model_score = curr_psnr
        best_model_weights = deepcopy(model.state_dict())
    # Disabled: reload best weights after RESET_THRESHOLD non-improving steps.
    # else:
    #     reset_counter += 1
    # if reset_counter == RESET_THRESHOLD:
    #     optim.zero_grad()
    #     model.load_state_dict(best_model_weights)
    #     optim = torch.optim.Adam(lr=LEARNING_RATE, params=model.parameters())
    #     reset_counter = 0

    loss.backward()
    optim.step()

    # set_description refreshes the bar by default; no extra refresh() needed.
    pbar.set_description(f"Loss = {loss_value:.6f} PSNR = {curr_psnr:.6f}")
    if i % PLOT_EVERY == 0:
        utils.plot_sr_results(inp_tensor, out, low_res_gt_tensor, LR_RES, HR_RES)

In [None]:
# (HR output shape, LR target shape) of the first batch/channel slice.
tuple(t.detach().cpu().numpy()[0, 0].shape for t in (out, low_res_gt_tensor))

In [None]:
# Training loss per iteration.
_loss_fig, _loss_ax = plt.subplots()
_loss_ax.plot(loss_history)
_loss_ax.set_title("Loss History")
plt.show()

In [None]:
def plot_sr(hr, lr):
    """Show ``hr`` (left) and ``lr`` (right) side by side.

    Both arguments are converted for display with ``utils.tensor2im``.
    """
    plt.figure(dpi=200)
    for position, tensor in ((121, hr), (122, lr)):
        plt.subplot(position)
        plt.imshow(utils.tensor2im(tensor))
    plt.tight_layout()
    plt.show()


In [None]:
# Baseline: upscale the raw (un-normalised) cropped LR field to HR_RES,
# presumably via bilinear interpolation as the helper's name suggests.
bilinear_upscaled = utils.get_baseline_bilinear(
    inp_low_res_rainfall,
    HR_RES,
)

In [None]:

def plot_sr_results_np(x, y, z, titles=None, suptitle=None):
    """Show three 2-D arrays side by side in one figure.

    Parameters
    ----------
    x, y, z : array-like
        Images passed to ``Axes.imshow`` (left, middle, right).
    titles : sequence of str, optional
        Per-panel titles; missing entries are left blank.
    suptitle : str, optional
        Figure title. Defaults to ``f"Super resolution {SR_FACTOR}x"`` using
        the notebook-level ``SR_FACTOR``.
    """
    images = (x, y, z)
    # Pad short title lists so every panel still gets its axis turned off.
    panel_titles = list(titles) if titles is not None else []
    panel_titles += [""] * (len(images) - len(panel_titles))
    if suptitle is None:
        suptitle = f"Super resolution {SR_FACTOR}x"

    fig, axes = plt.subplots(1, len(images), dpi=200)
    for ax, image, title in zip(axes, images, panel_titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(suptitle)
    plt.tight_layout()
    plt.show()

    

In [None]:
import vis

# Side-by-side comparison: bilinear baseline | network output | LR input.
vis.plot_sr_results_np(
    bilinear_upscaled,
    utils.tensor2im(out),
    utils.tensor2im(low_res_gt_tensor),
    titles=["Bilinear Scaled", 'Ours', 'Low-res Precipitation'],
    suptitle=f"Super resolution {SR_FACTOR}x",
)

In [None]:
# Same comparison figure, also written to ./results. Create the folder first
# in case vis does not (NOTE(review): confirm whether vis creates it).
os.makedirs("./results", exist_ok=True)
vis.plot_sr_results_np(bilinear_upscaled, utils.tensor2im(out), utils.tensor2im(low_res_gt_tensor),
                       titles=["Bilinear Scaled", 'Ours', 'Low-res Precipitation'], suptitle=f"Super resolution {SR_FACTOR}x",
                       save=f"./results/results_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png")

In [None]:
# save all results
utils.save_single_image(f"./results/bilinear_results_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", bilinear_upscaled)
utils.save_single_image(f"./results/dip_results_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", utils.tensor2im(out))
utils.save_single_image(f"./results/input_lowres_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", utils.tensor2im(low_res_gt_tensor))
# cv2.imwrite("./results/bilinear_results_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", bilinear_upscaled)
# cv2.imwrite( "./results/dip_results_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", utils.tensor2im(out))
# cv2.imwrite( "./results/input_lowres_{idx}_superres_{SR_FACTOR}x_guidance_{USE_GUIDANCE}.png", utils.tensor2im(low_res_gt_tensor))