In [None]:
import matplotlib.pyplot as plt
import numpy as np 
import torch
import pickle
import pandas as pd
from scipy.stats import entropy, norm

In [None]:
def get_quantile(samples,q,dim=1):
    """Return the ``q``-quantile of ``samples`` along ``dim`` as a NumPy array.

    The quantile is computed with ``torch.quantile``; the resulting tensor is
    moved to the CPU before being converted.
    """
    quantile_tensor = torch.quantile(samples, q, dim=dim)
    host_tensor = quantile_tensor.cpu()
    return host_tensor.numpy()

def normalize_std(record):
    """Standardize ``record`` along its last axis (zero mean, unit std).

    Args:
        record: Array-like of any rank >= 1; statistics are taken over the
            last axis independently for every leading index.

    Returns:
        ``numpy.ndarray`` with the same shape as ``record``. Slices whose
        standard deviation is exactly zero (constant slices) are returned as
        zeros instead of NaN/inf.
    """
    record = np.asarray(record)
    # keepdims lets the statistics broadcast against `record` for any rank;
    # the previous transpose trick was only valid for 1-D / 2-D input.
    mean_val = np.mean(record, axis=-1, keepdims=True)
    std_val = np.std(record, axis=-1, keepdims=True)
    # Avoid 0/0 on constant slices: the numerator is already zero there.
    std_val[std_val == 0] = 1

    return (record - mean_val) / std_val

def normalize_minmax(record):
    """Rescale ``record`` along its last axis to the range [0, 1].

    Args:
        record: Array-like of any rank >= 1; the minimum and maximum are taken
            over the last axis independently for every leading index.

    Returns:
        ``numpy.ndarray`` with the same shape as ``record``. Slices whose
        minimum equals their maximum (constant slices) are returned as zeros
        instead of NaN.
    """
    record = np.asarray(record)
    # keepdims lets the statistics broadcast against `record` for any rank;
    # the previous transpose trick was only valid for 1-D / 2-D input.
    min_val = np.min(record, axis=-1, keepdims=True)
    value_range = np.max(record, axis=-1, keepdims=True) - min_val
    # Avoid 0/0 on constant slices: the numerator is already zero there.
    value_range[value_range == 0] = 1

    return (record - min_val) / value_range

def JS_divergence(p, q):
    """Jensen-Shannon divergence (base 2) between two series.

    Each series is shifted so its minimum is zero and then scaled to sum to
    one, turning it into a discrete distribution over its positions. (Dividing
    by the range first, as min-max scaling does, would cancel out in that
    normalization, so it is skipped.) A constant series carries no shape
    information and is treated as the uniform distribution.

    The mixture ``M`` is built from the *normalized* distributions.
    ``scipy.stats.entropy`` renormalizes each of its arguments independently,
    so building ``M`` from un-normalized vectors yields a mixture that is not
    the midpoint of ``p`` and ``q`` and a result that can exceed 1 bit.

    Args:
        p, q: 1-D array-likes of equal length.

    Returns:
        The divergence in bits, in the range [0, 1].
    """
    def _as_distribution(values):
        # Shift to non-negative values and normalize to unit mass.
        values = np.asarray(values, dtype=float)
        shifted = values - np.min(values)
        total = np.sum(shifted)
        if total == 0:
            return np.full(values.shape, 1.0 / values.size)
        return shifted / total

    p = _as_distribution(p)
    q = _as_distribution(q)
    M = (p + q) / 2
    return 0.5 * entropy(p, M, base=2) + 0.5 * entropy(q, M, base=2)

def total_variation_distance(p, q):
    """Average of ``|p[i] - q[i]| / 2`` over all positions.

    Raises:
        ValueError: if ``p`` and ``q`` have different lengths.
    """
    n = len(p)
    if n != len(q):
        raise ValueError("两个序列的长度必须相同")

    accumulated = 0
    for p_val, q_val in zip(p, q):
        accumulated = accumulated + abs(p_val - q_val) / 2

    return accumulated / n

def mean_squared_error(p, q):
    """Mean of the squared element-wise differences between ``p`` and ``q``.

    Raises:
        ValueError: if ``p`` and ``q`` have different lengths.
    """
    length = len(p)
    if length != len(q):
        raise ValueError("两个序列的长度必须相同")

    squared_diffs = [(a - b) ** 2 for a, b in zip(p, q)]
    return sum(squared_diffs) / length

def mean_absolute_error(p, q):
    """Mean of the absolute element-wise differences between ``p`` and ``q``.

    Raises:
        ValueError: if ``p`` and ``q`` have different lengths.
    """
    length = len(p)
    if length != len(q):
        raise ValueError("两个序列的长度必须相同")

    abs_diffs = [abs(a - b) for a, b in zip(p, q)]
    return sum(abs_diffs) / length

def get_CRPS(data, G):
    """Mean per-row squared distance between the CDF values of ``data`` and ``G``.

    For every row ``i`` of ``data`` the values of ``data[i]`` and ``G[i]`` are
    mapped through ``compute_cdf`` and the squared differences are summed; the
    function returns the mean of those row scores.

    The earlier implementation skipped rows where ``np.max(gen_cdf) > 2.5``.
    ``compute_cdf`` returns standard-normal CDF values, which never exceed 1,
    so that branch could not fire and has been removed.
    NOTE(review): if the intent was to drop rows whose *raw* generated values
    exceed 2.5, that filter needs to be reintroduced on ``G[i, :]`` — confirm.

    Args:
        data: 2-D array of reference series, one per row.
        G: 2-D array of generated series, indexed with the same row numbers.

    Returns:
        The mean row score, or NaN when ``data`` has no rows.
    """
    n_rows, _ = data.shape
    if n_rows == 0:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return np.float64(np.nan)

    row_scores = [
        np.sum(np.square(compute_cdf(data[i, :]) - compute_cdf(G[i, :])))
        for i in range(n_rows)
    ]
    return np.mean(row_scores)

def compute_cdf(values):
    """Evaluate the standard normal CDF at every entry of ``values``.

    Thin wrapper around ``scipy.stats.norm.cdf`` with its default location
    and scale.
    """
    return norm.cdf(values)

def new_CRPS(target, sample):
    """Sum of squared differences between the CDF values of two series.

    Both inputs are mapped through ``compute_cdf`` before being compared.
    """
    cdf_gap = compute_cdf(target) - compute_cdf(sample)
    return np.sum(np.square(cdf_gap))

In [None]:
# --- Run configuration -------------------------------------------------------
dataset = 'value' #choose 'healthcare', 'airquality' or 'value'
datafolder = 'shiqu_0523_202605-E700N560' # set the folder name
nsample = 10 # number of generated sample

path = f'./save/{datafolder}/generated_outputs_nsample{nsample}.pk'

# NOTE: pickle.load can execute arbitrary code; only open files produced by
# your own runs.
with open(path, 'rb') as f:
    samples, all_target, all_observed_time, scaler, mean_scaler = pickle.load(f)

# Stack the first `nsample` generated series (feature index 0) so that the
# leading axis of `samples_np` indexes the generated sample.
samples_np = np.array(
    [samples[:, sample_idx, :, 0].cpu().numpy() for sample_idx in range(nsample)]
)
# Feature index 0 of the targets, as a NumPy array.
all_target_np = all_target[:, :, 0].cpu().numpy()

L = samples_np.shape[-1] #time length

In [None]:
# Metrics stored alongside the generated outputs; the fifth entry is ignored.
path_result = f'./save/{datafolder}/result_nsample{nsample}.pk'

# NOTE: pickle.load can execute arbitrary code; only open files produced by
# your own runs.
with open(path_result, 'rb') as f:
    JS_div, JS_one_div, tv_distance, CRPS, _ = pickle.load(f)

print(
    "JS_div = ", JS_div,
    "\nJS_one_div = ", JS_one_div,
    "\ntv_distance = ", tv_distance,
    "\nCRPS = ", CRPS,
)

In [None]:
# Sanity check: report the shapes of the target and generated arrays.
for shape_label, array in (
    ("all_target_np.shape: ", all_target_np),
    ("samples_np.shape: ", samples_np),
):
    print(shape_label, array.shape)

In [None]:
# Optional preprocessing, currently disabled: standardize the targets and each
# generated sample with normalize_std before the metrics are computed.
# Uncomment the lines below to enable it.
# all_target_np = normalize_std(all_target_np)
# for i in range(nsample):
#     samples_np_one = samples_np[i]
#     samples_np[i] = normalize_std(samples_np_one)

**指标计算**

In [None]:
# Metrics are evaluated on the time-step window [72, 96) of every series.
metrics_target = all_target_np[:, 72:96]
metrics_sample = samples_np[:, :, 72:96]

N = metrics_sample.shape[1]

# (accumulator key, metric function) pairs evaluated for every series.
metric_functions = (
    ('JS', JS_divergence),
    ('TV_avg', total_variation_distance),
    ('MSE', mean_squared_error),
    ('MAE', mean_absolute_error),
    ('CRPS', new_CRPS),
)
# One running total per generated sample for each metric.
per_sample = {key: [0] * nsample for key, _fn in metric_functions}

for k in range(nsample):
    # Every series is currently kept, so the divisor is the full count.
    # (A disabled filter used to skip series with
    #  np.std(metrics_sample[k, i] - metrics_target[i]) > 1 and decrement n.)
    n = N
    for i in range(N):
        target_row = metrics_target[i]
        sample_row = metrics_sample[k, i]
        for key, metric_fn in metric_functions:
            per_sample[key][k] += metric_fn(target_row, sample_row)
    for key, _fn in metric_functions:
        per_sample[key][k] = per_sample[key][k] / n

JS = per_sample['JS']
TV_avg = per_sample['TV_avg']
MSE = per_sample['MSE']
MAE = per_sample['MAE']
CRPS = per_sample['CRPS']

# Report the mean over generated samples and the largest deviation from that
# mean ("FLU"). The separator differs for the first row to keep the printed
# text unchanged.
report_rows = (
    ('JS-div', '\t ', JS),
    ('Avg-TV', ' \t ', TV_avg),
    ('MSE', ' \t ', MSE),
    ('MAE', ' \t ', MAE),
    ('CRPS', ' \t ', CRPS),
)
for label, gap, values in report_rows:
    center = np.mean(values)
    spread = np.max([np.max(values) - center, center - np.min(values)])
    print(f'{label}: {center:.4g}{gap}FLU: {spread:.4f}')

In [None]:
# Convert the per-sample metric lists to arrays and persist them.
# MSE is not part of the saved list.
JS, TV_avg, MAE, CRPS = (np.array(series) for series in (JS, TV_avg, MAE, CRPS))

with open('shiqu_main.pk', "wb") as f:
    pickle.dump([JS, TV_avg, MAE, CRPS], f)

**平均图像**

In [None]:
dataind = 2  # change to visualize a different time-series sample

# Average the targets over series, and the generated data over series and
# then over generated samples.
target_avg = all_target_np.mean(axis=0)
sample_avg = samples_np.mean(axis=1).mean(axis=0)

# Metrics between the two averaged curves (these names now hold scalars).
JS = JS_divergence(target_avg, sample_avg)
TV_avg = total_variation_distance(target_avg, sample_avg)
MSE = mean_squared_error(target_avg, sample_avg)
# Alternative kept for reference:
# CRPS = get_CRPS(all_target_np, np.mean(samples_np, axis=0))
CRPS = get_CRPS(target_avg[np.newaxis, :], sample_avg[np.newaxis, :])

# plt.rcParams["font.size"] = 16
fig, ax = plt.subplots(figsize=(5, 4), dpi=160)

df_target = pd.DataFrame({"x": np.arange(0, L), "val": target_avg})
ax.plot(df_target.x, df_target.val, color='b', linestyle='solid', label='real data', linewidth=0.8)

df_sample = pd.DataFrame({"x": np.arange(0, L), "val": sample_avg})
ax.plot(df_sample.x, df_sample.val, color='g', linestyle='solid', label='generated data', linewidth=0.8)

ax.set_xlabel('time')
# ax.set_ylabel('Normalized traffic value')
ax.set_ylabel('Unnormalized traffic value')

ax.legend()
plt.show()

for metric_label, metric_value in (
    ('JS-div', JS),
    ('Avg-TV', TV_avg),
    ('MSE', MSE),
    ('CRPS', CRPS),
):
    print(f'{metric_label}: {metric_value}')

**单个图像**

In [None]:
dataind = 2  # not read by this cell; kept in case other cells reference it
num_show = 10  # number of series to plot, one figure per series
series_offset = 0  # index of the first series to plot

# Two rows of panels. Ceil-division keeps the grid large enough when nsample
# is odd (nsample // 2 columns gave too few panels and made plt.subplot raise).
n_cols = (nsample + 1) // 2
time_axis = np.arange(0, L)

for i in range(num_show):
    series_idx = i + series_offset
    plt.figure(figsize=(25, 6), dpi=160)
    plt.subplots_adjust(hspace=0.5, wspace=0.22)  # vertical / horizontal spacing between panels
    for j in range(nsample):
        # One panel per generated sample, each overlaid on the same target.
        plt.subplot(2, n_cols, j + 1)

        plt.plot(time_axis,
                 all_target_np[series_idx],
                 color='b',
                 linestyle='solid',
                 label=f'real data[{series_idx}]',
                 linewidth=1.2)

        plt.plot(time_axis,
                 samples_np[j, series_idx],
                 color='g',
                 linestyle='solid',
                 label=f'generated data[{series_idx}]',
                 linewidth=1.2)

        plt.xlabel('time')
        # NOTE(review): the averaged plot labels the same arrays
        # 'Unnormalized traffic value' — confirm which label is correct.
        plt.ylabel('Normalized traffic value', fontsize=10)

        plt.legend()
    plt.show()

In [None]:
dataind = 2  # not read by this cell; kept in case other cells reference it
real_num = 39  # index of the series to visualize

# Two rows of panels. Ceil-division keeps the grid large enough when nsample
# is odd (nsample // 2 columns gave too few panels and made plt.subplot raise).
n_cols = (nsample + 1) // 2
time_axis = np.arange(0, L)

plt.figure(figsize=(25, 6), dpi=160)
plt.subplots_adjust(hspace=0.5, wspace=0.22)  # vertical / horizontal spacing between panels
for j in range(nsample):
    # One panel per generated sample, each overlaid on the same target.
    plt.subplot(2, n_cols, j + 1)

    plt.plot(time_axis,
             all_target_np[real_num],
             color='b',
             linestyle='solid',
             label='real data',
             linewidth=1.8)

    plt.plot(time_axis,
             samples_np[j, real_num],
             color='g',
             linestyle='solid',
             label='generated data',
             linewidth=1.8)

    plt.xlabel('time (hours)')
    # NOTE(review): the averaged plot labels the same arrays
    # 'Unnormalized traffic value' — confirm which label is correct.
    plt.ylabel('Normalized traffic value')

    plt.legend()
plt.show()
