In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader 
import numpy as np
import os, shutil
import pandas, csv, json
import random
from datetime import datetime
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style({u'font.sans-serif':['simhei']})
plt.rcParams[u'font.sans-serif'] = ['simhei']
plt.rcParams['axes.unicode_minus'] = False
import pandas as pd
import scipy
import math
from scipy.stats import genextreme as gev, gumbel_r, norm, gompertz
from scipy.special import gamma, factorial

from constants import * 
from utils import *

In [None]:
def get_scope(x, y):
    """Split ``x`` and ``y`` into runs of consecutive positions where ``y == 1``.

    Positions come from ``np.where(y == 1)[0]``. A new run starts whenever the
    next such position is not exactly one greater than the previous one.

    Returns:
        (x_runs, y_runs): two lists of the same length. Each element is the slice
        ``x[start:end + 1]`` / ``y[start:end + 1]`` for one run. Both lists are
        empty when ``y`` contains no 1.
    """
    hits = np.where(y == 1)[0]
    x_runs, y_runs = [], []
    if hits.size == 0:
        return x_runs, y_runs

    # Offsets in `hits` after which the next hit does not directly follow.
    breaks = np.where(np.diff(hits) != 1)[0]
    starts = np.concatenate(([hits[0]], hits[breaks + 1]))
    ends = np.concatenate((hits[breaks], [hits[-1]]))

    for lo, hi in zip(starts, ends):
        x_runs.append(x[lo:hi + 1])
        y_runs.append(y[lo:hi + 1])
    return x_runs, y_runs

In [None]:
def get_mask(data, thres, col=7, jump=10):
    """Flag rows of ``data`` whose value in column ``col`` looks like an event.

    A row is set to 1.0 when either
      * its value rose by at least ``jump`` relative to the previous row, or
      * its value is greater than or equal to ``thres[row, col]``.
    All other rows are 0.0.

    Args:
        data: 2-D array; rows are consecutive samples and column ``col`` is inspected.
        thres: 2-D array of thresholds, indexed with the same ``[:, col]`` layout
            as ``data`` (expected to have the same number of rows).
        col: column to inspect. Defaults to 7, the column this notebook uses.
        jump: minimum increase between consecutive rows that is flagged.
            Defaults to 10.

    Returns:
        1-D float array of length ``data.shape[0]`` with values 0.0 / 1.0.
    """
    series = data[:, col]
    mask = np.zeros(data.shape[0])
    # np.diff(series)[i] == series[i + 1] - series[i]; flag the later sample (i + 1).
    mask[np.flatnonzero(np.diff(series) >= jump) + 1] = 1
    mask[series >= thres[:, col]] = 1
    return mask

In [None]:
# For classification output
def plot(sitename, a_path, a_label, zoom_start="2019/05/01 00:00", zoom_hours=24 * 28):
    """Plot one site's observed series with predicted and masked periods shaded.

    Red shading marks samples whose prediction is >= 0.5. Blue shading marks
    samples flagged by ``get_mask(true, thres)``. Both shadings span the full
    axis height.

    Reads these notebook-level globals (assigned in the configuration cell
    before ``plot`` is called): ``mode``, ``shift``, ``memory_size``,
    ``window_size``, ``source_size``, ``target_size``, ``origin_path`` and
    ``thres_path``.

    Args:
        sitename: site key used in the ``.npy`` file names.
        a_path: directory holding the prediction arrays. The file is
            ``<sitename>_class.npy`` when ``mode == 1``, otherwise ``<sitename>.npy``.
        a_label: legend label for the prediction shading.
        zoom_start: timestamp string of the first plotted sample. ``None`` plots
            the whole aligned range.
        zoom_hours: number of hourly samples shown starting at ``zoom_start``.

    Raises:
        ValueError: if ``zoom_start`` is not in the aligned date range, or if the
            aligned prediction and observation lengths differ.
    """
    # -------------------------------
    # Load data
    # -------------------------------
    pred_file = f"{a_path}/{sitename}_class.npy" if mode == 1 else f"{a_path}/{sitename}.npy"
    # NOTE(review): presumably (samples, steps, channels); channel 0 is used.
    a_pred = np.load(pred_file)[:, :, 0]
    true = np.load(f"{origin_path}/{sitename}.npy")
    thres = np.load(f"{thres_path}/{sitename}.npy")
    date = pd.date_range(start="2019-01-01 00:00", end="2019-12-31 23:00", freq='h')
    mask = get_mask(true, thres)

    # -------------------------------
    # Align observations with the chosen prediction step
    # -------------------------------
    a_pred = a_pred[:, shift]
    if mode == 2:
        lo, hi = memory_size + source_size + target_size + shift, None
    elif mode == 1:
        lo, hi = memory_size + window_size + source_size + target_size + shift, None
    elif shift < -1:
        # Trim the tail too, so every observation has a prediction `shift` steps ahead.
        lo, hi = source_size + target_size + shift, shift + 1
    else:
        lo, hi = source_size + target_size - 1, None
    aligned = slice(lo, hi)
    true = true[aligned, 7]
    date_ = date[aligned]
    mask = mask[aligned]

    # -------------------------------
    # Zoom data
    # -------------------------------
    if zoom_start is None:
        zoom = slice(None)
    else:
        hits = np.flatnonzero(date_ == zoom_start)
        if hits.size == 0:
            raise ValueError(f"{sitename}: zoom_start {zoom_start!r} is outside the aligned date range")
        zoom = slice(hits[0], hits[0] + zoom_hours)
    true = true[zoom]
    mask = mask[zoom]
    x = date_[zoom]
    a_pred = a_pred[zoom]
    if len(a_pred) != len(x):
        raise ValueError(
            f"{sitename}: {len(a_pred)} predictions vs {len(x)} observations after alignment; "
            "check mode/shift and the size settings"
        )

    # -------------------------------
    # Draw data
    # -------------------------------
    fig, ax = plt.subplots(1, 1, figsize=(32, 8))
    # With the blended x-axis transform, y is in axes fraction, so 0..1 is the full height.
    span = ax.get_xaxis_transform()
    ax.fill_between(x, 0, 1, where=a_pred >= .5,
                    color='red', alpha=0.5, transform=span, label=a_label)
    ax.fill_between(x, 0, 1, where=mask > .5,
                    color='blue', alpha=0.5, transform=span, label='mask')
    ax.plot(x, true, color='black', lw=2, alpha=.5, label='true')
    ax.legend(loc='best', frameon=False, fontsize='xx-large')

    ax.set_title(f"{sitename}", fontsize=14)
    ax.tick_params(axis='both', labelsize=14)
    plt.show()
    plt.close(fig)  # called once per site in a loop; release the figure



In [None]:

# Data directories for the validation split (read by `plot` as globals).
origin_path = "data/origin/valid"
thres_path = "data/thres/valid"

# Options come from a fixed argument list handed to the project's `parse` helper.
# Alternative used in other runs: ["--no=-1", "--skip_site"]
args = ["--no=-1"]
opt = parse(args)
same_seeds(opt.seed)  # presumably seeds the RNGs from the parsed options — see utils

In [None]:
# Plot settings. `plot` reads these module-level names as globals, so the
# names must stay exactly as they are.
site_list = ['陽明','中山','萬華','古亭']  # site keys used in the .npy file names
method = ""
memory_size = 960
window_size = 24
source_size = 24
target_size = 8
shift = -1  # prediction column to plot: -1 ... -target_size
mode = 1    # 2 for seq, 1 for fudan, 0 for others

# Fail once, up front, with a clear message instead of an obscure slicing error per site.
assert -target_size <= shift <= -1, f"shift must be in [-{target_size}, -1], got {shift}"
assert mode in (0, 1, 2), f"mode must be 0, 1 or 2, got {mode}"

no_a = 107  # experiment id: predictions live in split_method/results/<no_a>
for sitename in site_list:
    plot(sitename, f"split_method/results/{no_a}", f"{no_a}")