In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from skimage.metrics import structural_similarity, peak_signal_noise_ratio, mean_squared_error
from torchvision import transforms
from tqdm import tqdm

# project-local modules (star imports kept in their original order; later ones shadow earlier)
from net import *
from utils import *
from loss import *

In [2]:
## data: stored image frames plus random hole masks, wrapped in the project's taxi_data dataset
# NOTE(review): absolute local Windows path — consider a configurable DATA_DIR before sharing.
img_root = "D:/nyc_taxi/data_min_max"
mask_root = "D:/nyc_taxi/data_min_max"  # masks currently live in the same directory as the images
image_size = 64
# Scale factor the evaluation cells multiply by to map stored values back to the original range;
# presumably the maximum used for min-max scaling — TODO confirm against the preprocessing step.
all_time_max = 1428
train_imgs = np.load(img_root + '/train.npy')
test_imgs = np.load(img_root + '/test.npy')
train_masks = np.load(mask_root + '/train_random_mask.npy')
test_masks = np.load(mask_root + '/test_random_mask.npy')
# NOTE(review): meaning of the trailing `1` argument is defined in taxi_data (not visible here).
dataset_train = taxi_data(train_imgs, train_masks, image_size, 1)
dataset_test = taxi_data(test_imgs, test_masks, image_size, 1)

In [3]:
## hourly median baseline (computed from the training ground truth)
# NOTE: despite its name, `global_mean` holds per-slot *median* frames, one for each index
# modulo 24 (presumably hour of day); the name is kept because the evaluation cell indexes it
# as global_mean[i % 24].
# Each dataset item unpacks as (mask, ground truth), so [1] selects the ground-truth tensor;
# the two squeeze(1) calls drop size-1 axes after the sample axis of the stacked batch.
train_gt = torch.stack(
    [dataset_train[k][1] for k in range(len(dataset_train))]
).squeeze(1).squeeze(1).numpy()
# assumes training sample k falls in slot k % 24 (series starts at hour 0, no gaps) — TODO confirm
global_mean = np.stack([np.median(train_gt[hour::24], axis=0) for hour in range(24)])

In [4]:
## metrics: impute every hole pixel with its hourly median, then score against ground truth.
# One entry per test sample (list index == sample index); summarised in the next cell.
hole_l1_output = []   # mean |error| over hole pixels, after scaling back
hole_mse_output = []  # MSE over hole pixels, after scaling back
ssim_output_5 = []    # SSIM of the whole composited frame (win_size=5)
psnr_output = []      # PSNR of the whole composited frame; +inf when the frame is reproduced exactly

for i in tqdm(range(len(dataset_test))):
    mask, gt = dataset_test[i]
    # Drop the two leading size-1 axes (same result as the former torch.stack(...) followed by
    # three squeeze(0) calls) so both become frames usable with boolean indexing and skimage.
    gt_single = gt.squeeze(0).squeeze(0).numpy()
    mask_single = mask.squeeze(0).squeeze(0).numpy()
    hole = mask_single == 0  # pixels to impute

    ## use the hourly median to impute
    # assumes test sample i falls in slot i % 24 (test series starts at hour 0, no gaps) — TODO confirm
    output_comp_single = gt_single.copy()
    output_comp_single[hole] = global_mean[i % 24][hole]

    ## scale back by all_time_max
    gt_single = gt_single * all_time_max
    output_comp_single = output_comp_single * all_time_max

    ## single image & output
    ssim_output_5.append(structural_similarity(output_comp_single, gt_single, win_size=5, data_range=all_time_max))
    # PSNR divides by the frame MSE; for an exact reconstruction record +inf explicitly instead of
    # letting the division by zero emit a RuntimeWarning. The summary cell excludes these values.
    full_mse = mean_squared_error(gt_single, output_comp_single)
    if full_mse > 0:
        # skimage argument order is (image_true, image_test)
        psnr_output.append(peak_signal_noise_ratio(gt_single, output_comp_single, data_range=all_time_max))
    else:
        psnr_output.append(np.inf)

    ## hole regions
    if hole.any():
        output_comp_single_hole = output_comp_single[hole]
        gt_single_hole = gt_single[hole]
        hole_l1_output.append(np.mean(np.abs(output_comp_single_hole - gt_single_hole)))
        hole_mse_output.append(mean_squared_error(output_comp_single_hole, gt_single_hole))
    else:
        # No hole pixels: hole metrics are undefined. NaN (what the empty means produced before,
        # minus the warnings) keeps the lists aligned with the sample index.
        hole_l1_output.append(np.nan)
        hole_mse_output.append(np.nan)

  return 10 * np.log10((data_range ** 2) / err)
100%|██████████████████████████████████████████████████████████████████████████████| 3624/3624 [05:41<00:00, 10.62it/s]


In [5]:
## baseline test-set evaluation (hourly-median imputation), summarised as a one-row table
psnr = np.array(psnr_output)
# Samples reproduced exactly have PSNR = +inf and are excluded, so the reported PSNR is the mean
# over imperfect reconstructions only.
finite_psnr = psnr[np.isfinite(psnr)]
# nanmean: samples without any hole pixels carry NaN hole metrics and are skipped.
global_median_impute = pd.DataFrame(
    [[
        np.nanmean(hole_l1_output),
        np.nanmean(hole_mse_output),
        np.mean(ssim_output_5),
        np.mean(finite_psnr),
    ]],
    columns=['hole_l1_output', 'hole_mse_output', 'ssim_5', 'psnr'],
)

In [7]:
# Display the one-row summary table for the hourly-median imputation baseline (last expression = cell output).
global_median_impute

Unnamed: 0,hole_l1_output,hole_mse_output,ssim_5,psnr
0,1.184803,58.190562,0.997593,61.975086
