In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import geopandas as gpd
import itertools

from rbf.interpolate import RBFInterpolant
from pykrige.ok import OrdinaryKriging
from pykrige.ok3d import OrdinaryKriging3D
from shapely.geometry import Polygon
from itertools import product
from torchvision import transforms
from skimage.metrics import structural_similarity, peak_signal_noise_ratio, mean_squared_error
from net import *
from utils import *
from loss import *

import warnings
warnings.filterwarnings("ignore")

In [2]:
import os

# Folder holding the preprocessed .npy arrays. Set the NYC_TAXI_DATA environment
# variable to point elsewhere instead of editing this hard-coded local path.
DATA_DIR = os.environ.get("NYC_TAXI_DATA", "D:/nyc_taxi/data_min_max")
img_root = DATA_DIR
mask_root = DATA_DIR
image_size = 64
# Scale factor: the metrics cells multiply normalised frames by this to get back to
# original units (presumably the all-time maximum used for min-max scaling — TODO confirm).
all_time_max = 1428
chunk_size = 1

train_imgs = np.load(img_root + '/train.npy')
test_imgs = np.load(img_root + '/test.npy')
train_masks = np.load(mask_root + '/train_random_mask.npy')
test_masks = np.load(mask_root + '/test_random_mask.npy')
dataset_train = taxi_data(train_imgs, train_masks, image_size, chunk_size)
dataset_test = taxi_data(test_imgs, test_masks, image_size, chunk_size)

### Test 1: RBF interpolation (`phs2` kernel) of randomly masked test frames

In [5]:
## metrics
# For every test frame:
# - fill the masked (hole) pixels with an RBF interpolant fitted on the observed pixels;
# - score the composite (observed pixels + RBF fill) against the ground truth.
# Per-frame scores are collected in the lists below and averaged in the next cell.
RBF_SIGMA = 0.1      # `sigma` passed to RBFInterpolant
RBF_PHI = 'phs2'     # RBF kernel name passed to RBFInterpolant
SSIM_WIN_SIZE = 5    # SSIM sliding-window size (the `_5` in `ssim_output_5`)


def rbf_fill_holes(mask_2d, gt_2d, sigma=RBF_SIGMA, phi=RBF_PHI):
    """Return a copy of `gt_2d` whose hole pixels are replaced by an RBF estimate.

    Parameters
    ----------
    mask_2d : np.ndarray
        2-D mask. Pixels equal to 1 are observed and used to fit the RBF.
        Pixels equal to 0 are holes and are predicted.
    gt_2d : np.ndarray
        2-D ground-truth frame with the same shape as `mask_2d`.
    sigma, phi
        Forwarded unchanged to `RBFInterpolant`.

    Returns
    -------
    composite : np.ndarray
        Same dtype as `gt_2d`. Observed pixels are copied from `gt_2d`; hole
        pixels hold the RBF prediction clipped below at 0.
    hole : np.ndarray of bool
        True where `mask_2d == 0`.
    """
    valid = mask_2d == 1
    hole = mask_2d == 0
    composite = gt_2d.copy()
    if hole.any():
        # np.argwhere yields (row, col) pairs in the same order as np.where, so the
        # fit/eval points match the original np.c_[x, y] construction.
        rbf = RBFInterpolant(np.argwhere(valid), gt_2d[valid], sigma=sigma, phi=phi)
        filled = rbf(np.argwhere(hole))
        filled[filled < 0] = 0  # negative predictions are clipped to 0
        composite[hole] = filled
    return composite, hole


hole_l1_output = []
hole_mse_output = []
ssim_output_5 = []
psnr_output = []

for i in tqdm(range(len(dataset_test))):
    mask, gt = dataset_test[i]
    # Drop up to two leading singleton dims to get a 2-D frame
    # (assumes the dataset yields (1, 1, H, W)-like tensors — TODO confirm).
    mask_single = mask.squeeze(0).squeeze(0).numpy()
    gt_single = gt.squeeze(0).squeeze(0).numpy()

    output_comp_single, hole = rbf_fill_holes(mask_single, gt_single)

    ## scale back (presumably the stored arrays were divided by all_time_max)
    gt_single = gt_single * all_time_max
    output_comp_single = output_comp_single * all_time_max

    ## whole-frame metrics on the composite image
    ssim_output_5.append(structural_similarity(output_comp_single, gt_single, win_size=SSIM_WIN_SIZE, data_range=all_time_max))
    psnr_output.append(peak_signal_noise_ratio(output_comp_single, gt_single, data_range=all_time_max))

    ## hole-only metrics. A frame without hole pixels has nothing to score, and
    ## appending np.mean([]) == NaN would poison the averages in the next cell.
    if hole.any():
        output_comp_single_hole = output_comp_single[hole]
        gt_single_hole = gt_single[hole]
        hole_l1_output.append(np.mean(np.abs(output_comp_single_hole - gt_single_hole)))
        hole_mse_output.append(mean_squared_error(output_comp_single_hole, gt_single_hole))

100%|████████████████████████████████████████████████████████████████████████████| 3624/3624 [1:10:43<00:00,  1.17s/it]


In [6]:
## aggregate the per-frame scores into a single summary row
# PSNR is +inf for frames whose composite matches the ground truth exactly
# (zero error), so those frames are left out of the PSNR mean.
psnr = np.array(psnr_output)
summary_row = {
    'hole_l1_output': np.mean(hole_l1_output),
    'hole_mse_output': np.mean(hole_mse_output),
    'ssim_5': np.mean(ssim_output_5),
    'psnr': np.mean(psnr[~np.isinf(psnr)]),
}
## make tabular view (one row, index 0)
rbf_impute = pd.DataFrame([summary_row])

In [7]:
# Display the one-row summary table of mean metrics over the test set.
rbf_impute

Unnamed: 0,hole_l1_output,hole_mse_output,ssim_5,psnr
0,3.144211,284.880745,0.989042,54.834636
