In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import sys
import math
from scipy.stats import nbinom
from torch.nn.utils import weight_norm

# Define the NB class first, not mixture version
class NBNorm(nn.Module):
    """Output head that maps features to negative-binomial parameters (n, p).

    Two independent 1x1 convolutions produce the raw parameters; ``softplus``
    keeps ``n`` positive and ``sigmoid`` keeps ``p`` inside (0, 1).
    """

    def __init__(self, c_in, c_out):
        """
        :param c_in: Number of input channels (size of dim 2 of the forward input).
        :param c_out: Number of output channels (the output horizon).
        """
        super(NBNorm, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.n_conv = nn.Conv2d(in_channels=c_in,
                                out_channels=c_out,
                                kernel_size=(1, 1),
                                bias=True)

        self.p_conv = nn.Conv2d(in_channels=c_in,
                                out_channels=c_out,
                                kernel_size=(1, 1),
                                bias=True)
        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D tensor of shape (B, N, c_in, 1).
        :return: Tuple ``(n, p)``; each tensor has shape (B, N, c_out).
        """
        # Conv2d expects channels on dim 1: (B, N, c_in, 1) -> (B, c_in, N, 1)
        x = x.permute(0, 2, 1, 3)
        (B, _, N, _) = x.shape  # B: batch_size; N: input nodes
        n = self.n_conv(x).squeeze(-1)
        p = self.p_conv(x).squeeze(-1)

        # Fails loudly if the trailing dimension was not of size 1
        n = n.view([B, self.out_dim, N])
        p = p.view([B, self.out_dim, N])

        # Ensure n is positive and p lies between 0 and 1
        n = F.softplus(n)  # Some parameters can be tuned here
        p = torch.sigmoid(p)
        return n.permute([0, 2, 1]), p.permute([0, 2, 1])

    def likelihood_loss(self, y, n, p, y_mask=None):
        """Summed negative log-likelihood of ``y`` under NB(n, p).

        The pmf used is ``Gamma(n+y) / (Gamma(n) * y!) * p**n * (1-p)**y``.

        :param y: True values.
        :param n: NB ``n`` parameter, broadcastable against ``y``.
        :param p: NB ``p`` parameter, broadcastable against ``y``.
        :param y_mask: Optional multiplicative mask applied element-wise to
            the per-entry loss (e.g. 0 for missing entries).
        :return: Scalar tensor holding the sum of the (masked) per-entry NLL.
        """
        nll = torch.lgamma(n) + torch.lgamma(y+1) - torch.lgamma(n+y) - n*torch.log(p) - y*torch.log(1-p)
        if y_mask is not None:
            nll = nll*y_mask
        return torch.sum(nll)

    def mean(self, n, p):
        """Expected value of NB(n, p) under the pmf used by ``likelihood_loss``.

        :param n: NB ``n`` parameter.
        :param p: NB ``p`` parameter.
        :return: ``n * (1 - p) / p``, with the broadcast shape of ``n`` and ``p``.
        """
        return n * (1 - p) / p

# Define the Gaussian 
class GaussNorm(nn.Module):
    """Output head that maps features to Gaussian parameters (loc, scale).

    ``n_conv`` feeds the location branch and ``p_conv`` the scale branch; the
    attribute names are shared with the negative-binomial heads.
    """

    def __init__(self, c_in, c_out):
        """
        :param c_in: Number of input channels (size of dim 2 of the forward input).
        :param c_out: Number of output channels (the output horizon).
        """
        super(GaussNorm, self).__init__()
        self.c_in, self.c_out = c_in, c_out
        self.n_conv = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), bias=True)
        self.p_conv = nn.Conv2d(c_in, c_out, kernel_size=(1, 1), bias=True)
        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D tensor of shape (B, N, c_in, 1).
        :return: Tuple ``(loc, scale)``; each tensor has shape (B, N, c_out).
            ``loc`` is the mean and ``scale`` the standard deviation.
        """
        # Channels-first layout for Conv2d: (B, c_in, N, 1)
        feats = x.permute(0, 2, 1, 3)
        batch, _, nodes, _ = feats.shape
        target_shape = [batch, self.out_dim, nodes]

        raw_loc = self.n_conv(feats).squeeze(-1).view(target_shape)
        raw_scale = self.p_conv(feats).squeeze(-1).view(target_shape)

        # softplus keeps loc positive (count data); sigmoid bounds scale to (0, 1)
        loc = F.softplus(raw_loc)
        scale = torch.sigmoid(raw_scale)
        return loc.permute(0, 2, 1), scale.permute(0, 2, 1)

# Zero-inflated variant of the NB output head (predicts n, p and pi; optionally a fourth zi channel)
class NBNorm_ZeroInflated(nn.Module):
    """Output head producing three (n, p, pi) or four (n, p, pi, zi) parameter maps.

    Each parameter comes from its own 1x1 convolution. With ``four=False`` the
    three outputs are squashed (softplus for n, sigmoid for p and pi). With
    ``four=True`` n, p and pi are returned as raw convolution outputs and only
    the extra ``zi`` map is passed through a sigmoid.
    """

    def __init__(self, c_in, c_out, four=False):
        """
        :param c_in: Number of input channels (size of dim 2 of the forward input).
        :param c_out: Number of output channels (the output horizon).
        :param four: When true, add a fourth ``zero_conv`` branch.
        """
        super(NBNorm_ZeroInflated, self).__init__()
        self.c_in = c_in
        self.c_out = c_out

        def make_head():
            return nn.Conv2d(in_channels=c_in,
                             out_channels=c_out,
                             kernel_size=(1, 1),
                             bias=True)

        self.n_conv = make_head()
        self.p_conv = make_head()
        self.pi_conv = make_head()
        self.four = four
        if four:
            self.zero_conv = make_head()

        self.out_dim = c_out  # output horizon

    def forward(self, x):
        """
        :param x: 4-D tensor of shape (B, N, c_in, 1).
        :return: ``(n, p, pi)`` or, when ``four`` is set, ``(n, p, pi, zi)``;
            every tensor has shape (B, N, c_out).
        """
        # Channels-first layout for Conv2d: (B, c_in, N, 1)
        feats = x.permute(0, 2, 1, 3)
        batch, _, nodes, _ = feats.shape
        target_shape = [batch, self.out_dim, nodes]

        n = self.n_conv(feats).squeeze(-1).view(target_shape)
        p = self.p_conv(feats).squeeze(-1).view(target_shape)
        pi = self.pi_conv(feats).squeeze(-1).view(target_shape)

        if self.four:
            # Four-parameter mode: n, p, pi stay unactivated; only zi is squashed.
            zi = torch.sigmoid(self.zero_conv(feats).squeeze(-1).view(target_shape))
            return (n.permute(0, 2, 1), p.permute(0, 2, 1),
                    pi.permute(0, 2, 1), zi.permute(0, 2, 1))

        # Three-parameter mode: n positive, p and pi inside (0, 1).
        n = F.softplus(n)      # fixme (kept from the original source)
        p = torch.sigmoid(p)
        pi = torch.sigmoid(pi)  # todo (kept from the original source)
        return n.permute(0, 2, 1), p.permute(0, 2, 1), pi.permute(0, 2, 1)

class D_GCN(nn.Module):
    """
    Neural network block that applies a diffusion graph convolution to sampled location
    """
    def __init__(self, in_channels, out_channels, orders, activation = 'relu', att=False):
        """
        :param in_channels: Number of input features per node (size of the
            last axis of the forward input).
        :param out_channels: Desired number of output features at each node.
        :param orders: The diffusion steps.
        :param activation: 'relu', 'selu', or anything else for no activation.
        :param att: When true, re-weight the projected features with a
            softmax (over the node axis) of their tanh before adding the bias.
        """
        super(D_GCN, self).__init__()
        self.orders = orders
        self.activation = activation
        # One term for X itself plus `orders` terms for each of the two supports.
        self.num_matrices = 2 * self.orders + 1
        self.Theta1 = nn.Parameter(torch.FloatTensor(in_channels * self.num_matrices,
                                             out_channels))
        self.bias = nn.Parameter(torch.FloatTensor(out_channels))
        self.att = att
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weights and bias uniformly in +-1/sqrt(out_channels)."""
        stdv = 1. / math.sqrt(self.Theta1.shape[1])
        self.Theta1.data.uniform_(-stdv, stdv)
        stdv1 = 1. / math.sqrt(self.bias.shape[0])
        self.bias.data.uniform_(-stdv1, stdv1)

    def _concat(self, x, x_):
        """Append ``x_`` as a new leading slice of ``x`` (kept for compatibility)."""
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def forward(self, X, A_q, A_h):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps)
        :A_q: The forward random walk matrix (num_nodes, num_nodes)
        :A_h: The backward random walk matrix (num_nodes, num_nodes)
        :return: Output data of shape (batch_size, num_nodes, num_features)
        """
        batch_size = X.shape[0]
        num_node = X.shape[1]
        input_size = X.size(2)  # time_length

        # Fold time and batch together so each support is applied with one mm.
        x_base = X.permute(1, 2, 0)  # (num_nodes, num_times, batch_size)
        x_base = torch.reshape(x_base, shape=[num_node, input_size * batch_size])

        terms = [x_base]
        for support in (A_q, A_h):
            # Restart the recurrence from X for every support. The previous
            # implementation kept the state of the first support, so for
            # orders >= 2 the second support was seeded with the wrong term.
            t_prev = x_base
            t_curr = torch.mm(support, x_base)
            terms.append(t_curr)
            for _ in range(2, self.orders + 1):
                t_next = 2 * torch.mm(support, t_curr) - t_prev
                terms.append(t_next)
                t_prev, t_curr = t_curr, t_next

        # Single stack instead of repeated concatenation.
        x = torch.stack(terms, dim=0)
        x = torch.reshape(x, shape=[self.num_matrices, num_node, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size, num_node, input_size * self.num_matrices])
        x = torch.matmul(x, self.Theta1)  # (batch_size, num_nodes, out_channels)

        if self.att:
            att = torch.softmax(torch.tanh(x), dim=1)       # attention layer
            x = x * att

        x = x + self.bias
        if self.activation == 'relu':
            x = F.relu(x)
        elif self.activation == 'selu':
            x = F.selu(x)

        return x

## Code of BTCN from Yuankai
class B_TCN(nn.Module):
    """
    Neural network block that applies a bidirectional temporal convolution to each node of
    a graph.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,activation = 'relu',device='cuda:0'):
        """
        :param in_channels: Size of the last axis of the forward input (it is
            moved to the Conv2d channel axis).
        :param out_channels: Desired number of output features.
        :param kernel_size: Size of the 1D temporal kernel.
        :param activation: 'relu', 'sigmoid', or anything else for none.
        :param device: Stored for backward compatibility only. Tensors
            created in ``forward`` follow the input tensor's device.
        """
        super(B_TCN, self).__init__()
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.activation = activation
        self.device = device
        # forward direction temporal convolution (gated: conv1 * sigmoid(conv2) + conv3)
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3 = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

        # backward direction temporal convolution
        self.conv1b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv2b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv3b = nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X):
        """
        :param X: Input data of shape (batch_size, num_timesteps, num_nodes)
        :return: Output data of shape (batch_size, num_timesteps, num_features)
        """
        batch_size = X.shape[0]
        seq_len = X.shape[1]
        conv_len = seq_len - self.kernel_size + 1  # length after a valid convolution
        Xf = X.unsqueeze(1)  # (batch_size, 1, num_timesteps, num_nodes)

        # Reverse the time axis; flip keeps the result on X's own device.
        Xb = torch.flip(Xf, dims=[2])

        Xf = Xf.permute(0, 3, 1, 2)
        Xb = Xb.permute(0, 3, 1, 2)  # (batch_size, num_nodes, 1, num_timesteps)

        # NOTE(review): the conv output is (batch, out_channels, 1, conv_len) and
        # is reshaped (not permuted) to (batch, conv_len, out_channels), which
        # interleaves the two axes. Kept as-is so existing checkpoints behave the
        # same; confirm whether a permute was intended.
        tempf = self.conv1(Xf) * torch.sigmoid(self.conv2(Xf))
        outf = tempf + self.conv3(Xf)
        outf = outf.reshape([batch_size, conv_len, self.out_channels])

        tempb = self.conv1b(Xb) * torch.sigmoid(self.conv2b(Xb))
        outb = tempb + self.conv3b(Xb)
        outb = outb.reshape([batch_size, conv_len, self.out_channels])

        # Zero-pad back to the input length, on the same device/dtype as the data.
        rec = outf.new_zeros([batch_size, self.kernel_size - 1, self.out_channels])
        outf = torch.cat((outf, rec), dim = 1)
        outb = torch.cat((outb, rec), dim = 1)  # (batch_size, num_timesteps, out_features)

        # Undo the time reversal of the backward branch.
        outb = torch.flip(outb, dims=[1])

        if self.activation == 'relu':
            out = F.relu(outf) + F.relu(outb)
        elif self.activation == 'sigmoid':
            out = torch.sigmoid(outf) + torch.sigmoid(outb)
        else:
            out = outf + outb
        return out


class ST_NB(nn.Module):
    """Two-branch model producing negative-binomial parameters (n, p).

    A temporal branch (TC1 -> TC2 -> TC3 -> TNB) and a spatial branch
    (SC1 -> SC2 -> SC3 -> SNB) each produce an (n, p) pair; the two pairs are
    multiplied element-wise. The outputs of TC2 and SC2 are kept on the
    instance as ``temporal_factors`` and ``space_factors``.
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB):
        """
        :param SC1, SC2, SC3: Spatial layers, called as ``layer(x, A_q, A_h)``.
        :param TC1, TC2, TC3: Temporal layers, called as ``layer(x)``.
        :param SNB: Head of the spatial branch, returns ``(n, p)``.
        :param TNB: Head of the temporal branch, returns ``(n, p)``.
        """
        super(ST_NB, self).__init__()
        self.TC1 = TC1
        self.TC2 = TC2
        self.TC3 = TC3
        self.TNB = TNB

        self.SC1 = SC1
        self.SC2 = SC2
        self.SC3 = SC3
        self.SNB = SNB

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; the last axis is indexed at 0 and dropped
            (presumably (batch_size, num_timesteps, num_nodes, 1) - TODO confirm).
        :param A_q: Forwarded unchanged to the spatial layers.
        :param A_h: Forwarded unchanged to the spatial layers.
        :return: Tuple ``(n_res, p_res)`` of element-wise products of the
            temporal and spatial branch parameters.
        """
        X = X[:,:,:,0] # Dummy dimension deleted

        # Temporal branch works on the input with its last two axes swapped.
        X_T = X.permute(0,2,1)
        X_t1 = self.TC1(X_T)
        X_t2 = self.TC2(X_t1)
        self.temporal_factors = X_t2
        X_t3 = self.TC3(X_t2)
        _b,_h,_ht = X_t3.shape
        # The heads expect a trailing singleton axis.
        n_t_nb,p_t_nb = self.TNB(X_t3.view(_b,_h,_ht,1))

        # Spatial branch.
        X_s1 = self.SC1(X, A_q, A_h)
        X_s2 = self.SC2(X_s1, A_q, A_h)
        self.space_factors = X_s2
        X_s3 = self.SC3(X_s2, A_q, A_h)
        _b,_n,_hs = X_s3.shape
        n_s_nb,p_s_nb = self.SNB(X_s3.view(_b,_n,_hs,1))

        # Swap the temporal outputs back so both branches share one layout.
        n_res = n_t_nb.permute(0, 2, 1) * n_s_nb
        p_res = p_t_nb.permute(0, 2, 1) * p_s_nb

        return n_res,p_res

class ST_Gau(nn.Module):
    """Two-branch model producing Gaussian parameters (loc, scale).

    A temporal branch (TC1 -> TC2 -> TC3 -> TGau) and a spatial branch
    (SC1 -> SC2 -> SC3 -> SGau) each yield a (loc, scale) pair; the pairs are
    multiplied element-wise. The outputs of TC2 and SC2 are kept on the
    instance as ``temporal_factors`` and ``space_factors``.
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SGau,TGau):
        """
        :param SC1, SC2, SC3: Spatial layers, called as ``layer(x, A_q, A_h)``.
        :param TC1, TC2, TC3: Temporal layers, called as ``layer(x)``.
        :param SGau: Head of the spatial branch, returns ``(loc, scale)``.
        :param TGau: Head of the temporal branch, returns ``(loc, scale)``.
        """
        super(ST_Gau, self).__init__()
        self.TC1, self.TC2, self.TC3, self.TGau = TC1, TC2, TC3, TGau
        self.SC1, self.SC2, self.SC3, self.SGau = SC1, SC2, SC3, SGau

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; the last axis is indexed at 0 and dropped
            (presumably (batch_size, num_timesteps, num_nodes, 1) - TODO confirm).
        :param A_q: Forwarded unchanged to the spatial layers.
        :param A_h: Forwarded unchanged to the spatial layers.
        :return: Tuple ``(loc_res, scale_res)`` of element-wise products of
            the temporal and spatial branch parameters.
        """
        signal = X[:, :, :, 0]  # drop the dummy trailing axis

        # --- temporal branch (input with its last two axes swapped) ---
        hidden_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))
        self.temporal_factors = hidden_t
        hidden_t = self.TC3(hidden_t)
        batch, rows, cols = hidden_t.shape
        loc_t, scale_t = self.TGau(hidden_t.view(batch, rows, cols, 1))

        # --- spatial branch ---
        hidden_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)
        self.space_factors = hidden_s
        hidden_s = self.SC3(hidden_s, A_q, A_h)
        batch, rows, cols = hidden_s.shape
        loc_s, scale_s = self.SGau(hidden_s.view(batch, rows, cols, 1))

        # Bring the temporal outputs back to the spatial layout and combine.
        loc_res = loc_t.permute(0, 2, 1) * loc_s
        scale_res = scale_t.permute(0, 2, 1) * scale_s
        return loc_res, scale_res

class ST_NB_ZeroInflated(nn.Module):
    """Two-branch model producing three parameter maps (n, p, pi).

    A temporal branch (TC1 -> TC2 -> TC3 -> TNB) and a spatial branch
    (SC1 -> SC2 -> SC3 -> SNB) each yield an (n, p, pi) triple; the triples
    are multiplied element-wise. The outputs of TC2 and SC2 are kept on the
    instance as ``temporal_factors`` and ``space_factors``.
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB):
        """
        :param SC1, SC2, SC3: Spatial layers, called as ``layer(x, A_q, A_h)``.
        :param TC1, TC2, TC3: Temporal layers, called as ``layer(x)``.
        :param SNB: Head of the spatial branch, returns ``(n, p, pi)``.
        :param TNB: Head of the temporal branch, returns ``(n, p, pi)``.
        """
        super(ST_NB_ZeroInflated, self).__init__()
        self.TC1, self.TC2, self.TC3, self.TNB = TC1, TC2, TC3, TNB
        self.SC1, self.SC2, self.SC3, self.SNB = SC1, SC2, SC3, SNB

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; the last axis is indexed at 0 and dropped
            (presumably (batch_size, num_timesteps, num_nodes, 1) - TODO confirm).
        :param A_q: Forwarded unchanged to the spatial layers.
        :param A_h: Forwarded unchanged to the spatial layers.
        :return: Tuple ``(n_res, p_res, pi_res)`` of element-wise products of
            the temporal and spatial branch parameters.
        """
        signal = X[:, :, :, 0]  # drop the dummy trailing axis

        # --- temporal branch (input with its last two axes swapped) ---
        hidden_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))
        self.temporal_factors = hidden_t
        hidden_t = self.TC3(hidden_t)
        batch, rows, cols = hidden_t.shape
        n_t, p_t, pi_t = self.TNB(hidden_t.view(batch, rows, cols, 1))

        # --- spatial branch ---
        hidden_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)
        self.space_factors = hidden_s
        hidden_s = self.SC3(hidden_s, A_q, A_h)
        batch, rows, cols = hidden_s.shape
        n_s, p_s, pi_s = self.SNB(hidden_s.view(batch, rows, cols, 1))

        # Bring the temporal outputs back to the spatial layout and combine.
        n_res = n_t.permute(0, 2, 1) * n_s
        p_res = p_t.permute(0, 2, 1) * p_s
        pi_res = pi_t.permute(0, 2, 1) * pi_s
        return n_res, p_res, pi_res


class ST_TWEEDIE_ZeroInflated(nn.Module):
    """Two-branch model that only uses the third output of each head.

    The temporal branch (TC1 -> TC2 -> TC3 -> TNB) and the spatial branch
    (SC1 -> SC2 -> SC3 -> SNB) each return three values; the first two are
    discarded, the third is passed through a sigmoid, and the two results are
    multiplied. The outputs of TC2 and SC2 are kept on the instance as
    ``temporal_factors`` and ``space_factors``.
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB):
        """
        :param SC1, SC2, SC3: Spatial layers, called as ``layer(x, A_q, A_h)``.
        :param TC1, TC2, TC3: Temporal layers, called as ``layer(x)``.
        :param SNB: Head of the spatial branch, returns three values.
        :param TNB: Head of the temporal branch, returns three values.
        """
        super(ST_TWEEDIE_ZeroInflated, self).__init__()
        self.TC1, self.TC2, self.TC3, self.TNB = TC1, TC2, TC3, TNB
        self.SC1, self.SC2, self.SC3, self.SNB = SC1, SC2, SC3, SNB

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; the last axis is indexed at 0 and dropped
            (presumably (batch_size, num_timesteps, num_nodes, 1) - TODO confirm).
        :param A_q: Forwarded unchanged to the spatial layers.
        :param A_h: Forwarded unchanged to the spatial layers.
        :return: ``(0, 0, pi_res)``; the two leading integer zeros are
            placeholders for the unused parameters.
        """
        signal = X[:, :, :, 0]  # drop the dummy trailing axis

        # --- temporal branch (input with its last two axes swapped) ---
        hidden_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))
        self.temporal_factors = hidden_t
        hidden_t = self.TC3(hidden_t)
        batch, rows, cols = hidden_t.shape
        _, _, pi_t = self.TNB(hidden_t.view(batch, rows, cols, 1))
        pi_t = torch.sigmoid(pi_t)

        # --- spatial branch ---
        hidden_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)
        self.space_factors = hidden_s
        hidden_s = self.SC3(hidden_s, A_q, A_h)
        batch, rows, cols = hidden_s.shape
        _, _, pi_s = self.SNB(hidden_s.view(batch, rows, cols, 1))
        pi_s = torch.sigmoid(pi_s)

        return 0, 0, pi_t.permute(0, 2, 1) * pi_s


class ST_new_TWEEDIE_ZeroInflated(nn.Module):
    """Two-branch model producing (phi, rou, mu) and optionally zi.

    The temporal branch (TC1 -> TC2 -> TC3 -> TNB) and the spatial branch
    (SC1 -> SC2 -> SC3 -> SNB) each return three (or, with ``four``, four)
    values that are multiplied element-wise. The product of the first outputs
    goes through a ReLU (phi), the product of the second through
    ``sigmoid(.) + 1`` (rou, so it lies in (1, 2)); the third (mu) and fourth
    (zi) products are returned as they are. The outputs of TC2 and SC2 are
    kept on the instance as ``temporal_factors`` and ``space_factors``.
    """
    def __init__(self, SC1, SC2, SC3, TC1, TC2, TC3, SNB,TNB, four=False):
        """
        :param SC1, SC2, SC3: Spatial layers, called as ``layer(x, A_q, A_h)``.
        :param TC1, TC2, TC3: Temporal layers, called as ``layer(x)``.
        :param SNB: Head of the spatial branch.
        :param TNB: Head of the temporal branch.
        :param four: When true, both heads must return four values.
        """
        super(ST_new_TWEEDIE_ZeroInflated, self).__init__()
        self.TC1, self.TC2, self.TC3, self.TNB = TC1, TC2, TC3, TNB
        self.SC1, self.SC2, self.SC3, self.SNB = SC1, SC2, SC3, SNB
        self.four = four

    def forward(self, X, A_q, A_h):
        """
        :param X: 4-D input; the last axis is indexed at 0 and dropped
            (presumably (batch_size, num_timesteps, num_nodes, 1) - TODO confirm).
        :param A_q: Forwarded unchanged to the spatial layers.
        :param A_h: Forwarded unchanged to the spatial layers.
        :return: ``(phi_res, rou_res, mu_res)`` or, when ``four`` is set,
            ``(phi_res, rou_res, mu_res, zi_res)``.
        """
        signal = X[:, :, :, 0]  # drop the dummy trailing axis

        # --- temporal branch (input with its last two axes swapped) ---
        hidden_t = self.TC2(self.TC1(signal.permute(0, 2, 1)))
        self.temporal_factors = hidden_t
        hidden_t = self.TC3(hidden_t)
        batch, rows, cols = hidden_t.shape
        head_in_t = hidden_t.view(batch, rows, cols, 1)
        if self.four:
            n_t, p_t, pi_t, zi_t = self.TNB(head_in_t)
        else:
            n_t, p_t, pi_t = self.TNB(head_in_t)

        # --- spatial branch ---
        hidden_s = self.SC2(self.SC1(signal, A_q, A_h), A_q, A_h)
        self.space_factors = hidden_s
        hidden_s = self.SC3(hidden_s, A_q, A_h)
        batch, rows, cols = hidden_s.shape
        head_in_s = hidden_s.view(batch, rows, cols, 1)
        if self.four:
            n_s, p_s, pi_s, zi_s = self.SNB(head_in_s)
            zi_res = zi_t.permute(0, 2, 1) * zi_s
        else:
            n_s, p_s, pi_s = self.SNB(head_in_s)

        # n, p, pi, zi => phi, rou, mu, zi
        phi_res = torch.relu(n_t.permute(0, 2, 1) * n_s)
        rou_res = torch.sigmoid(p_t.permute(0, 2, 1) * p_s) + 1
        mu_res = pi_t.permute(0, 2, 1) * pi_s

        if self.four:
            return phi_res, rou_res, mu_res, zi_res
        return phi_res, rou_res, mu_res


In [8]:
from __future__ import division
import os
import zipfile
import numpy as np
import scipy.sparse as sp
import pandas as pd
from math import radians, cos, sin, asin, sqrt
# from sklearn.externals import joblib
import joblib
import scipy.io
import torch
from torch import nn
from scipy.stats import nbinom,norm
rand = np.random.RandomState(0)
# import tweedie
import torch.distributions
from sklearn.preprocessing import MinMaxScaler
import math

"""
Geographical information calculation
"""
def get_long_lat(sensor_index,loc = None):
    """
    Look up the longitude and latitude of the requested sensors.

    :param sensor_index: Row label(s) passed to ``.loc`` on the 'longitude'
        and 'latitude' columns (e.g. a list of sensor indices).
    :param loc: Optional table with 'longitude' and 'latitude' columns. When
        omitted, 'data/metr/graph_sensor_locations.csv' is read instead.
    :return: Tuple ``(longitudes, latitudes)`` of NumPy arrays.
    """
    table = pd.read_csv('data/metr/graph_sensor_locations.csv') if loc is None else loc
    longitudes = table['longitude'].loc[sensor_index]
    latitudes = table['latitude'].loc[sensor_index]
    return longitudes.to_numpy(), latitudes.to_numpy()

def haversine(lon1, lat1, lon2, lat2):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    :param lon1: Longitude of the first point, in degrees.
    :param lat1: Latitude of the first point, in degrees.
    :param lon2: Longitude of the second point, in degrees.
    :param lat2: Latitude of the second point, in degrees.
    :return: Distance in metres, using a sphere of radius 6371 km.
    """
    lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])

    # haversine
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make asin() raise a math domain error.
    a = min(1.0, max(0.0, a))
    c = 2 * asin(sqrt(a))
    r = 6371  # mean Earth radius in kilometres
    return c * r * 1000


"""
Generate the training sample for forecasting task, same idea from STGCN
"""

def generate_dataset(X, num_timesteps_input, num_timesteps_output, origional_feature=True):
    """
    Cut a node-feature series into overlapping (input, target) samples by
    sliding a window of length (num_timesteps_input + num_timesteps_output)
    along the time axis in steps of 1.

    :param X: Node features of shape (num_vertices, num_features,
    num_timesteps)
    :param num_timesteps_input: Number of time steps in each input sample.
    :param num_timesteps_output: Number of time steps in each target sample.
    :param origional_feature: Not used by this function.
    :return:
        - Input samples of shape
          (num_samples, num_vertices, num_timesteps_input, num_features).
        - Targets taken from feature 0 only, of shape
          (num_samples, num_vertices, num_timesteps_output).
    """
    window = num_timesteps_input + num_timesteps_output
    # Every start position whose full window still fits inside the series.
    starts = range(X.shape[2] - window + 1)

    features = [
        X[:, :, s:s + num_timesteps_input].transpose((0, 2, 1))
        for s in starts
    ]
    target = [
        X[:, 0, s + num_timesteps_input:s + window]
        for s in starts
    ]

    return torch.from_numpy(np.array(features)), \
           torch.from_numpy(np.array(target))


"""
Dynamically construct the adjacent matrix
"""

def get_Laplace(A):
    """
    Return the symmetrically normalised adjacency D^-1/2 A D^-1/2 without
    self-loops. This is for C_GCN.

    :param A: Square adjacency matrix. If its top-left entry equals 1 the
        diagonal is assumed to hold added self-loops and an identity matrix is
        subtracted first.
    :return: The normalised matrix; the input itself is not modified.
    """
    if A[0, 0] == 1:
        # Self-loops appear to have been added: remove them again.
        A = A - np.diag(np.ones(A.shape[0], dtype=np.float32))
    degree = np.array(np.sum(A, axis=1)).reshape((-1,))
    degree[degree <= 10e-5] = 10e-5  # floor tiny degrees before inverting
    inv_sqrt_degree = np.reciprocal(np.sqrt(degree))
    row_scaled = np.multiply(inv_sqrt_degree.reshape((-1, 1)), A)
    return np.multiply(row_scaled, inv_sqrt_degree.reshape((1, -1)))

def get_normalized_adj(A):
    """
    Return the symmetrically normalised adjacency D^-1/2 (A + I) D^-1/2 with
    self-loops. This is for K_GCN.

    :param A: Square adjacency matrix. If its top-left entry equals 0 the
        diagonal is assumed to be empty and an identity matrix is added first.
    :return: The normalised matrix; the input itself is not modified.
    """
    if A[0, 0] == 0:
        # No self-loops detected on the first node: add them for every node.
        A = A + np.diag(np.ones(A.shape[0], dtype=np.float32))
    degree = np.array(np.sum(A, axis=1)).reshape((-1,))
    degree[degree <= 10e-5] = 10e-5    # Prevent infs
    inv_sqrt_degree = np.reciprocal(np.sqrt(degree))
    row_scaled = np.multiply(inv_sqrt_degree.reshape((-1, 1)), A)
    return np.multiply(row_scaled, inv_sqrt_degree.reshape((1, -1)))

def calculate_random_walk_matrix(adj_mx):
    """
    Returns the random walk adjacency matrix D^-1 A. This is for D_GCN.

    :param adj_mx: Square adjacency matrix (anything ``scipy.sparse.coo_matrix``
        accepts); integer and floating dtypes are both supported.
    :return: Dense ``numpy.ndarray`` in which each row of ``adj_mx`` is divided
        by its row sum. Rows whose sum is zero are left as all zeros.
    """
    adj_mx = sp.coo_matrix(adj_mx)
    d = np.array(adj_mx.sum(1)).flatten()
    if not np.issubdtype(d.dtype, np.inexact):
        # np.power refuses negative exponents on integer arrays.
        d = d.astype(np.float64)
    # Zero row sums give inf here; they are replaced on the next line, so the
    # divide-by-zero warning is only noise.
    with np.errstate(divide='ignore'):
        d_inv = np.power(d, -1)
    d_inv[np.isinf(d_inv)] = 0.
    d_mat_inv = sp.diags(d_inv)
    random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
    return random_walk_mx.toarray()


def test_error_virtual(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0):
    """
    Evaluate imputation quality on the unknown locations.

    This function was a line-for-line duplicate of ``test_error``; it now
    delegates to it so the two cannot drift apart.

    :param STmodel: The graph neural networks
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: Scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, R2 and the imputed array ``o``
    """
    return test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0)

def test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue, Missing0):
    """
    Evaluate imputation quality on the unknown locations.

    The test series is cut into consecutive, non-overlapping windows of
    ``STmodel.time_dimension`` steps (a trailing partial window is dropped).
    In each window the unknown columns are zeroed, the data is divided by
    ``E_maxvalue`` and passed through the model; the output is scaled back and
    compared with the ground truth on the unknown columns only.

    :param STmodel: The graph neural networks; must expose ``time_dimension``
        and return a single tensor when called as ``STmodel(x, A_q, A_h)``.
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: Scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, R2 and the imputed array ``o`` (observed entries are
        copied from the ground truth)
    """
    unknown_cols = list(set(unknow_set))
    time_dim = STmodel.time_dimension
    # Number of rows covered by whole windows.
    usable_len = test_data.shape[0]//time_dim*time_dim

    test_omask = np.ones(test_data.shape)
    if Missing0 == True:  # exact legacy comparison kept on purpose
        test_omask[test_data == 0] = 0
    test_inputs = (test_data * test_omask).astype('float32')

    # 1 for observed columns, 0 for the columns that must be imputed.
    missing_index = np.ones(np.shape(test_data))
    missing_index[:, unknown_cols] = 0

    # The graph does not change between windows: build the supports once.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A_s).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A_s.T).T).astype('float32'))

    o = np.zeros([usable_len, test_inputs.shape[1]])
    for i in range(0, usable_len, time_dim):
        inputs = test_inputs[i:i+time_dim, :]
        missing_inputs = missing_index[i:i+time_dim, :]
        T_inputs = inputs*missing_inputs
        T_inputs = T_inputs/E_maxvalue
        T_inputs = np.expand_dims(T_inputs, axis = 0)
        T_inputs = torch.from_numpy(T_inputs.astype('float32'))

        imputation = STmodel(T_inputs, A_q, A_h)
        imputation = imputation.detach().cpu().numpy()
        o[i:i+time_dim, :] = imputation[0, :, :]

    o = o*E_maxvalue
    truth = test_inputs[0:usable_len]
    observed = missing_index[0:usable_len] == 1
    # Observed entries are not predictions: copy the ground truth back in.
    o[observed] = truth[observed]

    test_mask = 1 - missing_index[0:usable_len]
    if Missing0 == True:
        test_mask[truth == 0] = 0
        o[truth == 0] = 0

    o_ = o[:, unknown_cols]
    truth_ = truth[:, unknown_cols]
    test_mask_ = test_mask[:, unknown_cols]

    MAE = np.sum(np.abs(o_ - truth_))/np.sum( test_mask_)
    RMSE = np.sqrt(np.sum((o_ - truth_)*(o_ - truth_))/np.sum( test_mask_) )
    # MAPE = np.sum(np.abs(o - truth)/(truth + 1e-5))/np.sum( test_mask)
    R2 = 1 - np.sum( (o_ - truth_)*(o_ - truth_) )/np.sum( (truth_ - truth_.mean())*(truth_-truth_.mean() ) )
    print(truth_.mean())
    return MAE, RMSE, R2, o


def rolling_test_error(STmodel, unknow_set, test_data, A_s, E_maxvalue,Missing0):
    """
    Rolling evaluation: slide a window of ``STmodel.time_dimension`` steps
    forward one step at a time and keep only the prediction for the window's
    last time step.

    Fixes relative to the previous version, which always failed with a shape
    mismatch because ``truth`` was re-sliced to a different length:

    * ``truth`` stays aligned with ``o`` (one row per window, taken at each
      window's last time step);
    * inputs are divided by ``E_maxvalue`` and predictions multiplied by it
      before observed entries are restored, mirroring ``test_error``.

    :param STmodel: The graph neural networks; must expose ``time_dimension``
        and return a single tensor when called as ``STmodel(x, A_q, A_h)``.
    :unknow_set: The unknow locations for spatial prediction
    :test_data: The true value test_data of shape (test_num_timesteps, num_nodes)
    :A_s: The full adjacent matrix
    :E_maxvalue: Scale the inputs are divided by and the outputs multiplied by
    :Missing0: True: 0 in original datasets means missing data
    :return: MAE, RMSE, MAPE and the imputed array ``o`` of shape
        (test_num_timesteps - time_dimension, num_nodes)
    """
    unknown_cols = list(set(unknow_set))
    time_dim = STmodel.time_dimension
    num_windows = test_data.shape[0] - time_dim

    test_omask = np.ones(test_data.shape)
    if Missing0 == True:  # exact legacy comparison kept on purpose
        test_omask[test_data == 0] = 0
    test_inputs = (test_data * test_omask).astype('float32')

    # 1 for observed columns, 0 for the columns that must be imputed.
    missing_index = np.ones(np.shape(test_data))
    missing_index[:, unknown_cols] = 0

    # The graph does not change between windows: build the supports once.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A_s).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A_s.T).T).astype('float32'))

    o = np.zeros([num_windows, test_inputs.shape[1]])
    for i in range(0, num_windows):
        inputs = test_inputs[i:i+time_dim, :]
        missing_inputs = missing_index[i:i+time_dim, :]
        MF_inputs = inputs * missing_inputs
        MF_inputs = MF_inputs/E_maxvalue  # same normalisation as test_error
        MF_inputs = np.expand_dims(MF_inputs, axis = 0)
        MF_inputs = torch.from_numpy(MF_inputs.astype('float32'))

        imputation = STmodel(MF_inputs, A_q, A_h)
        imputation = imputation.detach().cpu().numpy()
        o[i, :] = imputation[0, time_dim-1, :]

    # Back to the original scale before mixing in unscaled ground truth.
    o = o*E_maxvalue

    # Window i covers rows i .. i+time_dim-1, so its last step is row
    # i+time_dim-1; compare against exactly those rows.
    # NOTE(review): assumes the model output is aligned with its input
    # window (reconstruction rather than one-step-ahead forecast) - confirm.
    target_rows = slice(time_dim - 1, time_dim - 1 + num_windows)
    truth = test_inputs[target_rows]
    observed = missing_index[target_rows] == 1
    o[observed] = truth[observed]

    test_mask = 1 - missing_index[target_rows]
    if Missing0 == True:
        test_mask[truth == 0] = 0
        o[truth == 0] = 0

    MAE = np.sum(np.abs(o - truth))/np.sum( test_mask)
    RMSE = np.sqrt(np.sum((o - truth)*(o - truth))/np.sum( test_mask) )
    MAPE = np.sum(np.abs(o - truth)/(truth + 1e-5))/np.sum( test_mask)  #avoid x/0

    return MAE, RMSE, MAPE, o

def test_error_cap(STmodel, unknow_set, full_set, test_set, A,time_dim,capacities):
    """Impute the unknown nodes window by window and score the result.

    The test series is cut into non-overlapping windows of ``time_dim`` rows
    (a trailing partial window is dropped).  In each window the columns listed
    in ``unknow_set`` are zeroed before the model is called, and the model
    output is written back for the whole window.

    :param STmodel: callable invoked as ``STmodel(inputs, A_q, A_h)``; its
        result is converted with ``.data.numpy()`` and indexed as
        ``[0, :, :]``.
    :param unknow_set: iterable of column indices treated as unobserved.
    :param full_set: unused; kept for interface compatibility.
    :param test_set: 2-D array indexed as ``[time, node]``; zeros are treated
        as missing readings.
    :param A: adjacency matrix handed to ``calculate_random_walk_matrix``.
    :param time_dim: number of rows per window.
    :param capacities: factor multiplied into both predictions and truth
        before the errors are computed.
    :return: ``(MAE, RMSE, R2, o)`` where the three scores are computed on the
        unknown columns only and ``o`` is the full rescaled output array.
    """
    unknow_cols = list(set(unknow_set))

    # Zero readings count as missing observations.
    test_omask = np.ones(test_set.shape)
    test_omask[test_set == 0] = 0
    test_inputs_s = (test_set * test_omask).astype('float32')

    # 1 = observed column, 0 = column hidden from the model.
    missing_index_s = np.ones(np.shape(test_inputs_s))
    missing_index_s[:, unknow_cols] = 0

    # Number of rows covered by complete windows.
    n_steps = test_set.shape[0] // time_dim * time_dim
    o = np.zeros([n_steps, test_inputs_s.shape[1]])

    # The transition matrices depend only on A, so build them once instead of
    # once per window.
    A_q = torch.from_numpy((calculate_random_walk_matrix(A).T).astype('float32'))
    A_h = torch.from_numpy((calculate_random_walk_matrix(A.T).T).astype('float32'))

    for i in range(0, n_steps, time_dim):
        window = test_inputs_s[i:i+time_dim, :] * missing_index_s[i:i+time_dim, :]
        MF_inputs = torch.from_numpy(np.expand_dims(window, axis=0).astype('float32'))

        imputation = STmodel(MF_inputs, A_q, A_h)
        imputation = imputation.data.numpy()
        o[i:i+time_dim, :] = imputation[0, :, :]

    o = o*capacities
    truth = test_inputs_s[0:n_steps]
    truth = truth*capacities

    # Observed entries keep their true value; missing readings are forced to 0
    # so that they contribute no error.
    observed = missing_index_s[0:n_steps] == 1
    o[observed] = truth[observed]
    o[truth == 0] = 0

    test_mask = 1 - missing_index_s[0:n_steps]
    test_mask[truth == 0] = 0

    o_ = o[:, unknow_cols]
    truth_ = truth[:, unknow_cols]
    test_mask_ = test_mask[:, unknow_cols]

    sq_err = (o_ - truth_)*(o_ - truth_)
    MAE = np.sum(np.abs(o_ - truth_))/np.sum(test_mask_)
    RMSE = np.sqrt(np.sum(sq_err)/np.sum(test_mask_))
    # MAPE = np.sum(np.abs(o - truth)/(truth + 1e-5))/np.sum( test_mask)
    R2 = 1 - np.sum(sq_err)/np.sum((truth_ - truth_.mean())*(truth_-truth_.mean()))
    print(truth_.mean())
    return MAE, RMSE, R2, o

def nb_nll_loss(y,n,p,y_mask=None):
    """
    Summed negative log-likelihood of a negative binomial distribution.

    y: true values
    n, p: negative binomial parameters, broadcastable against y
    y_mask: optional multiplicative mask applied element-wise before summing
    """
    # Built up term by term in the same order as the closed-form expression.
    per_element = torch.lgamma(n) + torch.lgamma(y + 1) - torch.lgamma(n + y)
    per_element = per_element - n * torch.log(p)
    per_element = per_element - y * torch.log(1 - p)

    if y_mask is None:
        return torch.sum(per_element)
    return torch.sum(per_element * y_mask)

def nb_zeroinflated_nll_loss(y,n,p,pi,y_mask=None):
    """
    Zero-inflated negative binomial loss, averaged separately over the
    zero and the positive targets and then added together.

    y: true values
    n, p: negative binomial parameters
    pi: zero-inflation probability
    y_mask: accepted for interface compatibility; not used here
    https://stats.idre.ucla.edu/r/dae/zinb/

    NOTE(review): the zero branch adds log(pi) and log((1-pi) * p**n) rather
    than taking log(pi + (1-pi) * p**n) -- confirm this is intended.
    If either subset (y == 0 or y > 0) is empty, its mean is NaN.
    """
    eps = 1e-4
    pi = torch.clip(pi, 1e-3, 1-1e-3)
    p = torch.clip(p, 1e-3, 1-1e-3)

    is_zero = y == 0
    is_pos = y > 0

    # Targets equal to zero.
    n_z, p_z, pi_z = n[is_zero], p[is_zero], pi[is_zero]
    ll_zero = torch.log(pi_z + eps) + torch.log(eps + (1 - pi_z) * torch.pow(p_z, n_z))

    # Strictly positive targets.
    n_p, p_p, pi_p, y_p = n[is_pos], p[is_pos], pi[is_pos], y[is_pos]
    ll_pos = torch.log(1 - pi_p + eps) + torch.lgamma(n_p + y_p)
    ll_pos = ll_pos - torch.lgamma(y_p + 1)
    ll_pos = ll_pos - torch.lgamma(n_p + eps)
    ll_pos = ll_pos + n_p * torch.log(p_p + eps)
    ll_pos = ll_pos + y_p * torch.log(1 - p_p + eps)

    return -torch.mean(ll_zero) - torch.mean(ll_pos)


def nb_tweedie_nll_loss(y, n, p, pi, y_mask=None):
    """
    Tweedie-style loss with a fixed power parameter of 1.5.

    y: true values
    n, p: accepted for interface compatibility; not used here
    pi: prediction, clipped to [1e-3, 1 - 1e-3] before use
    y_mask: accepted for interface compatibility; not used here
    https://stats.idre.ucla.edu/r/dae/zinb/

    Returns the mean over all elements of
    ``-(y * (pi**rou / (1 - rou) - pi**(2 - rou) / (2 - rou)))``.

    NOTE(review): ``pi ** rou/(1-rou)`` parses as ``(pi ** rou) / (1 - rou)``
    and ``y`` multiplies both terms; confirm against the intended Tweedie
    deviance before relying on this as a likelihood.
    """
    rou = 1.5
    pi = torch.clip(pi, 1e-3, 1-1e-3)
    loss_tweedie = - (y * (pi ** rou/(1-rou) - pi **(2-rou)/(2-rou) )).mean()
    return loss_tweedie

def nb_newtweedie_nll_loss(y, phi, rou, mu, y_mask=None):
    """
    Squared error on exp(mu) plus a Tweedie-style term scaled by 1/(tau*phi).

    y: true values
    phi: dispersion, clipped to [1e-3, 10]
    rou: power parameter, clipped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction (exponentiated inside the loss)
    y_mask: accepted for interface compatibility; not used here
    https://stats.idre.ucla.edu/r/dae/zinb/
    """
    tau = 0.2

    # Plain regression term on the exponentiated prediction.
    loss_y = ((torch.exp(mu) - y)**2).mean()

    # Keep the power strictly inside (1, 2) so neither denominator vanishes.
    rou = torch.clip(rou, 1+1e-3, 2-1e-3)
    phi = torch.clip(phi, 1e-3, 10)
    loss_tweedie = - ( (y * torch.exp((1-rou)*mu)/(1-rou)/tau - torch.exp((2-rou)*mu)/(2-rou)/tau) /phi ).mean()

    return loss_y + loss_tweedie


def nb_zitweedie_nll_loss(y, phi, rou, mu, zi, y_mask=None):
    """
    Zero-inflated Tweedie-style loss, averaged separately over zero and
    positive targets and then added together.

    y: true values
    phi: dispersion, clipped to [1e-3, 10]
    rou: power parameter, clipped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction
    zi: zero-inflation probability, clipped to [1e-3, 1 - 1e-3]
    y_mask: accepted for interface compatibility; not used here

    The positive branch divides by ``(1 - rou)``.  It previously divided by
    ``(1 - mu)``, which is zero whenever mu == 1 and produced inf/NaN losses.
    If either subset (y == 0 or y > 0) is empty, its mean is NaN.
    """
    tau = 0.2

    idx_yeq0 = y == 0
    idx_yg0 = y > 0

    rou = torch.clip(rou, 1+1e-3, 2-1e-3)
    phi = torch.clip(phi, 1e-3, 10)
    zi = torch.clip(zi, 1e-3, 1-1e-3)

    # Targets equal to zero.
    phi_yeq0 = phi[idx_yeq0]
    rou_yeq0 = rou[idx_yeq0]
    mu_yeq0 = mu[idx_yeq0]
    zi_yeq0 = zi[idx_yeq0]
    l_zi_yeq0 = torch.clip(1-zi_yeq0, 1e-3, 1 - 1e-3)

    # Strictly positive targets.
    phi_yg0 = phi[idx_yg0]
    rou_yg0 = rou[idx_yg0]
    mu_yg0 = mu[idx_yg0]
    yg0 = y[idx_yg0]
    l_zi_yg0 = torch.clip(1-zi[idx_yg0], 1e-3, 1 - 1e-3)

    L_yeq0 = (- torch.log(zi_yeq0) + torch.log((l_zi_yeq0) * torch.exp((2-rou_yeq0)*mu_yeq0)/(2-rou_yeq0)/tau)/phi_yeq0 ).mean()
    L_yg0 = (-torch.log(l_zi_yg0) -  (yg0 * torch.exp((1-rou_yg0)*mu_yg0)/(1-rou_yg0)/tau - torch.exp((2-rou_yg0)*mu_yg0)/(2-rou_yg0)/tau) /phi_yg0 ).mean()

    return L_yg0 + L_yeq0


def nb_td_nll(y, phi, rou, mu, y_mask=None):
    """
    Mean Tweedie-style negative log-likelihood (without the series term).

    y: true values
    phi: dispersion, clipped to [1e-3, 10]
    rou: power parameter, clipped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction (exponentiated inside the loss)
    y_mask: accepted for interface compatibility; not used here

    Elements whose power does not satisfy 1 < rou < 2 keep the initial value 1.
    """
    rou = torch.clip(rou, 1+1e-3, 2-1e-3)
    phi = torch.clip(phi, 1e-3, 10)
    mu = torch.exp(mu)

    nll = torch.ones_like(y)
    power_ok = (rou > 1) & (rou < 2)
    if not bool(power_ok.any()):
        return nll.mean()

    is_zero = y == 0

    # y == 0: only the cumulant term remains.
    sel = is_zero & power_ok
    nll[sel] = mu[sel] ** (2 - rou[sel]) / (phi[sel] * (2 - rou[sel]))

    # y != 0: linear term in y plus the cumulant term.
    sel = ~is_zero & power_ok
    linear_part = -(y[sel] * mu[sel] ** (1 - rou[sel]) / (phi[sel] * (1 - rou[sel])))
    cumulant_part = mu[sel] ** (2 - rou[sel]) / (phi[sel] * (2 - rou[sel]))
    nll[sel] = linear_part + cumulant_part

    return nll.mean()

# todo real@!!
def real_nb_zitd_nll(y, phi, rou, mu, zi, y_mask=None):
    """
    Mean zero-inflated Tweedie-style negative log-likelihood.

    y: true values
    phi: dispersion, clipped to [1e-3, 10]
    rou: power parameter, clipped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction (exponentiated inside the loss)
    zi: zero-inflation probability, clipped to [1e-3, 1 - 1e-3]
    y_mask: accepted for interface compatibility; not used here

    Elements whose power does not satisfy 1 < rou < 2 keep the initial value 1.
    """
    rou = torch.clip(rou, 1 + 1e-3, 2 - 1e-3)
    phi = torch.clip(phi, 1e-3, 10)
    zi = torch.clip(zi, 1e-3, 1-1e-3)
    mu = torch.exp(mu)

    nll = torch.ones_like(y)
    power_ok = (rou > 1) & (rou < 2)
    if bool(power_ok.any()):
        is_zero = y == 0

        # Zero targets: cumulant term weighted by (1 - zi), then -log(zi).
        sel = is_zero & power_ok
        cumulant = mu[sel] ** (2 - rou[sel]) / (phi[sel] * (2 - rou[sel]))
        nll[sel] = (1 - zi[sel]) * cumulant
        nll[sel] -= torch.log(zi[sel])

        # Non-zero targets: Tweedie terms, then -log(1 - zi).
        sel = ~is_zero & power_ok
        linear_part = -(y[sel] * mu[sel] ** (1 - rou[sel]) / (phi[sel] * (1 - rou[sel])))
        cumulant = mu[sel] ** (2 - rou[sel]) / (phi[sel] * (2 - rou[sel]))
        nll[sel] = linear_part + cumulant
        nll[sel] -= torch.log(1 - zi[sel])

    return nll.mean()

# todo gpt!!!
def nb_zitd_nll(y_true, phi, rou, mu, zi, n_terms=0, tau=0.2):
    """
    Mean zero-inflated Tweedie-style loss with a lower-bound series term.

    y_true: true values
    phi: dispersion; clamped to a minimum of 1 below
    rou: power parameter; clamped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction (exponentiated inside the loss)
    zi: zero-inflation probability; clamped to [1e-3, 1 - 1e-3]
    n_terms, tau: not referenced anywhere in this function

    Of the three nested helpers only ``lower_w_j`` is called; ``cal_w_j`` and
    ``approximate_cal_w_j`` are defined but unused.  None of the helpers reads
    its ``y_pred`` argument.

    Elements whose power does not satisfy 1 < rou < 2 keep the initial value 1.
    """

    # def cal_w_j(y_true, y_pred, j, rou):
    #     rou = rou.reshape(-1)
    #     alpha = 1 - rou
    #     w_j = (y_true **(-j * alpha) * (rou - 1) **( alpha * j) * y_pred ** ( j * (1 - alpha)) /
    #            (torch.exp(torch.lgamma(j * alpha)) * torch.exp(torch.lgamma(2 - rou * j))))
    #     print(w_j)
    #     return w_j

    # Unused: sums the log-weights of terms j and j+1, clamps to [1, 300],
    # and returns the negated value.
    def cal_w_j(y_true, y_pred, j, rou, phi):
        rou = rou.reshape(-1)
        alpha = (2-rou)/(1 - rou)
        z = y_true ** (-alpha) * (rou-1)**alpha / phi ** (1-alpha) / (2-rou)
        z = torch.clip(z, 0, 50)
        # log_w_j = j * torch.log(z) - torch.lgamma(1+j) - torch.lgamma(-alpha * j)
        # print((-alpha * j+1e-3).max(), (-alpha * j+1e-3).min(), z.max(), z.min())
        # fixme here
        log_w_j = j * torch.log(z+1e-3) - torch.lgamma(-alpha * j+1e-3) + (j+1) * torch.log(z+1e-3) - torch.lgamma(-alpha * (1+j)+1e-3)


        log_w_j = torch.clamp(log_w_j, 1, 300)
        # print(torch.isfinite(log_w_j).sum() - log_w_j.shape[0], torch.isnan(log_w_j).sum())
        # print(log_w_j.max(), log_w_j.min())
        return -log_w_j

    # todo: "word" (original note was a single Chinese word; meaning unclear)
    # Unused: closed-form approximation of the j-th log-weight, clamped to
    # [1, 300] and negated.
    def approximate_cal_w_j(y_true, y_pred, j, rou, phi):
        rou = rou.reshape(-1)
        alpha = (2-rou)/(1 - rou)
        z = y_true ** (-alpha) * (rou-1)**alpha / phi ** (1-alpha) / (2-rou)
        z = torch.clip(z, 0, 20)
        # log_w_j = j * torch.log(z) - torch.lgamma(1+j) - torch.lgamma(-alpha * j)
        # print((-alpha * j+1e-3).max(), (-alpha * j+1e-3).min(), z.max(), z.min())
        log_w_j = j * (torch.log(z+1e-3) + (1-alpha) + alpha*(torch.log(-alpha+1e-3))
                       - (1-alpha)*math.log(j+1e-3)) - 1/2 * torch.log(-alpha+1e-3)


        log_w_j = torch.clamp(log_w_j, 1, 300)
        # print(torch.isfinite(log_w_j).sum() - log_w_j.shape[0], torch.isnan(log_w_j).sum())
        # print(log_w_j.max(), log_w_j.min())
        return -log_w_j
        # return 0

    # Used below: negated log-weight evaluated at j_max.  rou and j_max are
    # flattened to 1-D; y_true is used as passed in.
    def lower_w_j(y_true, y_pred, j_max, rou, phi):
        rou = rou.reshape(-1)
        j_max = j_max.reshape(-1)
        alpha = (2-rou)/(1 - rou)

        log_w_j = -torch.log(y_true + 1e-3) + j_max * (alpha-1) - torch.log(j_max + 1e-3) - 1/2 * torch.log(-alpha + 1e-3)

        return -log_w_j

    # Keep the power strictly inside (1, 2) so (1 - rou) and (2 - rou) never vanish.
    rou = torch.clamp(rou, 1 + 1e-3, 2 - 1e-3)
    # phi = torch.clamp(phi, 1e-3, 10)
    # 1. TD
    phi = torch.clamp(phi, 1)       # !!!! (emphasis in the original; lower bound of 1, no upper bound)
    # 2. STP
    # phi = torch.ones_like(phi)
    # 3. STGM
    # phi = phi + 1
    # 4. STIG
    # phi = phi + 1

    zi = torch.clamp(zi, 1e-3, 1 - 1e-3)
    mu = torch.exp(mu)

    ll = torch.ones_like(y_true)
    ll_1to_2_mask = (1 < rou) & (rou < 2)

    if torch.sum(ll_1to_2_mask) > 0:
        # Entries where the target is zero: cumulant term weighted by (1 - zi).
        zeros = y_true == 0
        mask = zeros & ll_1to_2_mask
        ll[mask] = (1-zi[mask]) * (mu[mask] ** (2 - rou[mask]) / (phi[mask] * (2 - rou[mask])))
        # ll[mask] = (mu[mask] ** (2 - rou[mask]) / (phi[mask] * (2 - rou[mask])))

        # ll[mask] += -torch.log(zi[mask])
        # ll[mask] += -torch.log(1-zi[mask])        # fixme

        # Entries where the target is non-zero.
        mask = ~zeros & ll_1to_2_mask
        ll[mask] = -(y_true[mask] * mu[mask] ** (1 - rou[mask]) / (phi[mask] * (1 - rou[mask]))) + (
                    mu[mask] ** (2 - rou[mask]) / (phi[mask] * (2 - rou[mask])))

        # Computed for every element; only the masked entries are used below.
        j_max = y_true ** (2-rou) / (2-rou) / phi
        # ll[mask] += cal_w_j(y_true[mask], mu[mask], j, rou[mask], phi[mask])
        ll[mask] -= lower_w_j(y_true[mask], mu[mask], j_max[mask], rou[mask], phi[mask])            #  FIXME: should be negative (original author's note)

        # Zero-inflation term for non-zero targets.
        ll[mask] += -torch.log(1 - zi[mask])

    return ll.mean()


def nb_new_td_nll(y, phi, rou, mu, zi, y_mask=None):
    """
    Mean Tweedie-style loss that applies the same expression to zero and
    non-zero targets alike.

    y: true values
    phi: dispersion, clipped to [1e-3, 10]
    rou: power parameter, clipped to [1 + 1e-3, 2 - 1e-3]
    mu: log-mean prediction (exponentiated inside the loss)
    zi: clipped like in the zero-inflated variants but not used afterwards
    y_mask: accepted for interface compatibility; not used here

    Elements whose power does not satisfy 1 < rou < 2 keep the initial value 1.
    """
    rou = torch.clip(rou, 1 + 1e-3, 2 - 1e-3)
    phi = torch.clip(phi, 1e-3, 10)
    zi = torch.clip(zi, 1e-3, 1-1e-3)  # result is unused below
    mu = torch.exp(mu)

    nll = torch.ones_like(y)
    sel = (rou > 1) & (rou < 2)
    if bool(sel.any()):
        linear_part = -(y[sel] * mu[sel] ** (1 - rou[sel]) / (phi[sel] * (1 - rou[sel])))
        cumulant_part = mu[sel] ** (2 - rou[sel]) / (phi[sel] * (2 - rou[sel]))
        nll[sel] = linear_part + cumulant_part

    return nll.mean()


def nb_tcn_nll(y, phi, rou, mu, zi, y_mask=None):
    """
    Mean squared error between exp(mu) and y.

    phi, rou, zi and y_mask are accepted so the signature matches the other
    losses; none of them is used.
    """
    squared_error = (torch.exp(mu) - y) ** 2
    mse = squared_error.mean()
    return mse.mean()


def nb_zeroinflated_draw(n,p,pi):
    """
    Draw one sample per element from a zero-inflated negative binomial.

    input: n, p, pi -- same-shaped arrays (they are flattened and passed to
           scipy's ``nbinom``; presumably NumPy arrays -- confirm with callers)
    output: drawn values, reshaped to the shape of ``n``

    For each element the support runs from 0 up to (but excluding) the 99%
    quantile of the negative binomial, with a minimum upper bound of 1.  The
    support now always starts at 0: previously it started at the 1% quantile,
    so whenever that quantile was positive the zero-inflation mass ``pi`` was
    added to the wrong value and a zero could never be drawn.

    Uses the module-level ``rand`` generator for sampling.
    """
    origin_shape = n.shape
    n = n.flatten()
    p = p.flatten()
    pi = pi.flatten()
    x_up = nbinom(n,p).ppf(0.99)
    pred = np.zeros_like(n)
    for i in range(len(x_up)):
        # Guarantee that at least the value 0 is in the support.
        upper = x_up[i] if x_up[i] > 1 else 1
        x = np.arange(0, upper)
        prob = (1-pi[i]) * nbinom.pmf(x,n[i],p[i])
        prob[0] += pi[i] # zero-inflated: x[0] is exactly 0 here
        pred[i] = rand.choice(a=x,p=prob/np.sum(prob)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)


def gauss_draw(loc,scale):
    """
    Draw one value per element from a grid between the 1% and 99% quantiles
    of a normal distribution, weighted by the normal pdf.

    input: loc, scale -- same-shaped arrays of means and standard deviations
    output: drawn values, reshaped to the shape of ``loc``

    NOTE(review): the grid is ``np.arange(low, high, 100)``, i.e. a step of
    100.  Whenever high - low <= 100 the grid is the single point ``low`` --
    confirm whether 100 grid points were intended instead.

    Uses the module-level ``rand`` generator for sampling.
    """
    origin_shape = loc.shape
    loc = loc.flatten()
    scale = scale.flatten()

    frozen = norm(loc,scale)
    lows = frozen.ppf(0.01)
    highs = frozen.ppf(0.99)

    pred = np.zeros_like(loc)
    for i, low in enumerate(lows):
        grid = np.arange(low,highs[i],100)
        weights = norm.pdf(grid,loc[i],scale[i])
        pred[i] = rand.choice(a=grid,p=weights/np.sum(weights)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)

def nb_draw(n,p):
    """
    Draw one sample per element from a truncated negative binomial.

    input: n, p -- same-shaped arrays of negative binomial parameters
    output: drawn values, reshaped to the shape of ``n``

    The support for each element runs from the 1% quantile up to (but
    excluding) the 99% quantile; the upper bound is raised so the support is
    never empty.  Uses the module-level ``rand`` generator for sampling.
    """
    origin_shape = n.shape
    n = n.flatten()
    p = p.flatten()

    frozen = nbinom(n,p)
    lows = frozen.ppf(0.01)
    highs = frozen.ppf(0.99)

    pred = np.zeros_like(n)
    for i in range(len(lows)):
        low, high = lows[i], highs[i]
        # Upper bound is at least 1 ...
        if high <= 1:
            high = 1
        # ... and always strictly above the lower bound.
        if high == low:
            high = low + 1
        support = np.arange(low,high)
        weights = nbinom.pmf(support,n[i],p[i])
        pred[i] = rand.choice(a=support,p=weights/np.sum(weights)) # random seed fixed, defined in the beginning

    return pred.reshape(origin_shape)

def gauss_loss(y,loc,scale,y_mask=None):
    """
    Mean negative Gaussian log-likelihood with clamped per-element values.

    The location (loc) keyword specifies the mean. The scale (scale) keyword specifies the standard deviation.
    http://jrmeyer.github.io/machinelearning/2017/08/18/mle.html

    y: true values
    y_mask: accepted for interface compatibility; not used here

    A constant 1e-2 is added inside the log and to the variance in the
    denominator, and each log-likelihood is clamped to [-20, 10] before
    averaging.
    """
    import math

    # Earlier versions overwrote ``torch.pi`` on every call with a value
    # derived from a float32 acos.  Only define it when torch lacks the
    # attribute, so code that reads ``torch.pi`` afterwards keeps working
    # without clobbering the library's own constant.
    if not hasattr(torch, "pi"):
        torch.pi = math.pi

    variance = torch.pow(scale,2)
    LL = -1/2 * torch.log(2*math.pi*variance+1e-2) - 1/2*( torch.pow(y-loc,2)/(variance+1e-2) )
    LL = torch.clamp(LL, -20, 10)
    return -torch.mean(LL)


def rmse(truth, pred):
    """Root mean squared error between ``truth`` and ``pred``."""
    squared_error = (truth - pred) ** 2
    return np.sqrt(squared_error.mean())


def mae(truth, pred):
    """
    Mean absolute error after rounding predictions below 1 down to 0.

    The thresholding is applied to a copy: ``pred`` itself is left untouched.
    Earlier versions zeroed the caller's array in place, which silently
    changed the input of every metric computed afterwards; callers that want
    thresholded predictions elsewhere must now threshold explicitly.
    """
    pred = np.asarray(pred)
    thresholded = np.where(pred < 1, 0, pred)
    return np.abs(truth - thresholded).mean()


def wape(truth, pred):
    """Sum of absolute errors divided by the sum of ``truth``."""
    abs_error = np.abs(np.subtract(pred, truth))
    return abs_error.sum() / np.sum(truth)


def mape(truth, pred):
    """
    Mean absolute percentage error.

    1e-5 is added to both the error and the denominator, so a zero in
    ``truth`` does not divide by zero.
    """
    shifted_error = np.subtract(pred, truth) + 1e-5
    ratio = shifted_error / (truth + 1e-5)
    return np.mean(np.abs(ratio))


def true_zeros(truth, pred):
    """
    Count positions where ``truth`` is 0 and ``pred`` is below 1, divided by
    the total number of elements (not by the number of true zeros).
    """
    zero_positions = truth == 0
    #     return np.sum(pred[idx]==0)/np.sum(idx)
    hits = np.sum(pred[zero_positions] < 1)
    return hits / np.prod(truth.shape)


def KL_DIV(truth, pred):
    """
    Sum of ``pred * log((pred + 1e-5) / (truth + 1e-5))`` over all elements.

    Note that ``pred`` is the weighting distribution here.
    """
    log_ratio = np.log((pred + 1e-5) / (truth + 1e-5))
    return np.sum(pred * log_ratio)


def KL_DIV_divide(truth, pred):
    """
    Sum of ``pred * log((pred + 0.1) / (truth + 0.1))`` divided by the number
    of elements in ``truth``.
    """
    log_ratio = np.log((pred + 1e-1) / (truth + 1e-1))
    total = np.sum(pred * log_ratio)
    return total / np.prod(truth.shape)


# def F1_SCORE(truth,pred):
#     true_zeros = truth == 0
#     pred_zeros = pred == 0
#     precision = np.sum(pred_zeros & true_zeros ) / np.sum(pred_zeros)
#     recall = np.sum(pred_zeros)/np.sum(true_zeros)
#     return 2*(precision*recall)/(precision+recall)

def F1_SCORE(truth, pred):
    """
    Fraction of elements whose absolute error is strictly below 1.

    Despite the name this is an accuracy-style hit rate; no precision or
    recall is computed.
    """
    #     return np.mean(f1_score(truth.flatten(),pred.flatten().astype(np.int),zero_division=1,average='micro'))
    close_enough = np.abs(truth - pred) < 1
    return np.sum(close_enough) / np.prod(truth.shape)


def print_errors(truth, pred,string=None):
    """
    Print RMSE, MAE, F1_SCORE, both KL variants and the true-zero rate on one
    line, prefixed by ``string``.

    The metrics are evaluated in exactly this order; keep it, because any
    metric that modifies ``pred`` affects the ones evaluated after it.
    """
    template = ' RMSE %.4f MAE %.4f F1_SCORE %.4f KL-Div: %.4f, KL-Div-divide: %.4f, true_zeros_rate %.4f : '
    metric_fns = (rmse, mae, F1_SCORE, KL_DIV, KL_DIV_divide, true_zeros)
    values = tuple(metric(truth, pred) for metric in metric_fns)
    print(string, template % values)

def get_geo_features(df_abs, df_static, df_dynamic):
    """
    Load, encode and min-max scale the static and dynamic feature tables.

    df_abs, df_static, df_dynamic: CSV paths (anything ``pd.read_csv`` accepts)
    return: (static_features, dynamic_features) as float32 torch tensors

    The first two tables are merged with ``pd.merge`` defaults, three road
    category columns are one-hot encoded, and identifier/geometry columns are
    dropped.  Each table is scaled column-wise with its own MinMaxScaler.
    """
    abs_frame = pd.read_csv(df_abs)
    static_frame = pd.read_csv(df_static)
    dynamic_frame = pd.read_csv(df_dynamic)

    categorical = ['roadClassi', 'routeHiera', 'roadClas_1']
    merged = pd.merge(abs_frame, static_frame)
    encoded = pd.get_dummies(merged, columns=categorical, prefix=list(categorical))

    static_feats = encoded.drop(columns=['geometry', 'localId', 'road_index'])
    dynamic_feats = dynamic_frame.drop(
        columns=['datetime', 'sunrise', 'sunset', 'sun_duration', 'Unnamed: 0'])

    # Normalise every column to [0, 1], one scaler per table.
    static_scaler = MinMaxScaler()
    dynamic_scaler = MinMaxScaler()
    static_scaled = pd.DataFrame(static_scaler.fit_transform(static_feats),
                                 columns=static_feats.columns)
    dynamic_scaled = pd.DataFrame(dynamic_scaler.fit_transform(dynamic_feats),
                                  columns=dynamic_feats.columns)

    static_array = np.array(static_scaled).astype(np.float32)
    dynamic_array = np.array(dynamic_scaled).astype(np.float32)

    return torch.from_numpy(static_array), torch.from_numpy(dynamic_array)
    # return df_static, df_dynamic


def old_draw_3d_graph(y_tar, phi, rho, mu):
    """
    Plot a 3-D triangulated surface of rho (x), a re-scaled phi (y) and the
    target (z), coloured by mu; save it to 'image.pdf' and show it.

    All four arguments are converted with ``.cpu().detach().numpy()``.
    Assumes 1-D inputs of equal length (random vectors are sized with
    ``mu.shape[0]`` and the builtin ``max`` is applied to y and z) -- TODO confirm.

    Before plotting, the data are altered for display:
      * mu is averaged 50/50 with the target;
      * phi is replaced by randomly scaled, min-max normalised values that
        differ for targets below and above ``split``;
      * target entries where a uniform random draw exceeds 0.99 are
        overwritten.

    NOTE(review): ``z`` is modified in place; if ``y_tar`` already lives on the
    CPU the NumPy array may share memory with the tensor -- confirm callers
    do not reuse ``y_tar`` afterwards.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from mpl_toolkits.mplot3d import Axes3D
    # Set the global font size
    plt.rcParams.update({'font.size': 14})

    # plt.style.use('_mpl-gallery')

    # Convert the inputs to NumPy (the original comment read "generate some random data")
    x = rho.cpu().detach().numpy()
    y = phi.cpu().detach().numpy()
    z = y_tar.cpu().detach().numpy()
    mu = mu.cpu().detach().numpy()
    # c = np.random.rand(1000)

    # TODO
    # 1. Fuse mu
    mu = 0.5 * z + 0.5 * mu
    # 2. Change phi
    split = 1
    # "xiao"/"da" are pinyin for small/large: values for targets below / above split.
    y_xiao_5 = (z<split) * (max(y) * (1+np.random.rand(mu.shape[0])))
    y_da_5 = (z>split) * y * (1+np.random.rand(mu.shape[0]))

    # Per-element random scale in [60, 100) and offset in [38, 76).
    scale = 60 + 40 * np.random.rand(mu.shape[0])
    base = 38 * (1+np.random.rand(mu.shape[0]))
    y_xiao_5 = scale * (y_xiao_5 - y_xiao_5.min())/(y_xiao_5.max()-y_xiao_5.min()) + base
    y_da_5 = scale * (y_da_5 - y_da_5.min())/(y_da_5.max()-y_da_5.min())
    # Entries with z exactly equal to split match neither mask and become 0.
    y = y_xiao_5 * (z<split) + (y_da_5) * (z>split)

    # random noise
    mask = np.random.rand(mu.shape[0])
    z[mask>0.99] = max(z) * 0.07 * (1+mask[mask>0.99])

    # y = (z<1) * np.random.rand(mu.shape[0]) * 30 + y * 10
    # y = (z<5) * (max(y) * (1+np.random.rand(mu.shape[0]))) + (z>5) * y
    # y = (z>5) * (5 * np.random.rand(mu.shape[0])) + y * (z<5)
    # y = (z < 1) * max(y) * np.random.rand(mu.shape[0]) + y * 10

    # ax.set_xlim3d(xmin, xmax)
    # Draw a scatter plot (disabled)
    # fig = plt.figure()
    # ax = fig.add_subplot(111, projection='3d')
    # p = ax.scatter(x, y, z, c=mu)
    # fig.colorbar(p)


    ax = plt.figure().add_subplot(projection='3d')
    # Normalise mu to [0, 1] so it can be mapped through the jet colormap.
    facecolors = mu
    V_normalized = (facecolors - facecolors.min().min())
    V_normalized = V_normalized / V_normalized.max().max()
    p = ax.plot_trisurf(x, y, z, cmap='viridis', linewidth=0.2, antialiased=True, facecolors=cm.jet(V_normalized))
    cbar = plt.colorbar(p)
    # Add labels
    # cbar.set_label('Mu')
    cbar.set_label(''r'$\mu$', rotation=0, va='center')
    # cbar.ax.set_title(''r'$\mu$', pad=20)
    # ax.set_xlabel('rho')
    ax.set_xlabel(''r'$\rho$')
    # ax.set_ylabel('phi')
    ax.set_ylabel(''r'$\phi$')
    # ax.set_zlabel('y')
    ax.set_zlabel(''r'$x$')

    ax.view_init(elev=30, azim=30)
    # Reverse the direction of the x axis.
    ax.set_xlim(ax.get_xlim()[::-1])

    # plt.tight_layout()
    # plt.subplots_adjust(top=0.9)

    # Save and show the figure
    plt.savefig('image.pdf')
    plt.show()


def nn_draw_3d_graph(y_tar, phi, rho, mu):
    """
    Plot a 3-D triangulated surface of rho (x), a re-scaled phi (y) and the
    target (z), coloured by mu; save it to 'image.pdf' and show it.

    Same procedure as ``old_draw_3d_graph`` but with a smaller random scale
    (10 + 20 * U) and offset (14 * (1 + U)) for the re-scaled phi values.

    All four arguments are converted with ``.cpu().detach().numpy()``.
    Assumes 1-D inputs of equal length (random vectors are sized with
    ``mu.shape[0]`` and the builtin ``max`` is applied to y and z) -- TODO confirm.

    NOTE(review): ``z`` is modified in place; if ``y_tar`` already lives on the
    CPU the NumPy array may share memory with the tensor -- confirm callers
    do not reuse ``y_tar`` afterwards.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from mpl_toolkits.mplot3d import Axes3D
    # Set the global font size
    plt.rcParams.update({'font.size': 14})

    # plt.style.use('_mpl-gallery')

    # Convert the inputs to NumPy (the original comment read "generate some random data")
    x = rho.cpu().detach().numpy()
    y = phi.cpu().detach().numpy()
    z = y_tar.cpu().detach().numpy()
    mu = mu.cpu().detach().numpy()
    # c = np.random.rand(1000)

    # TODO
    # 1. Fuse mu
    mu = 0.5 * z + 0.5 * mu
    # 2. Change phi
    split = 1
    # "xiao"/"da" are pinyin for small/large: values for targets below / above split.
    y_xiao_5 = (z<split) * (max(y) * (1+np.random.rand(mu.shape[0])))
    y_da_5 = (z>split) * y * (1+np.random.rand(mu.shape[0]))

    # Per-element random scale in [10, 30) and offset in [14, 28).
    scale = 10 + 20 * np.random.rand(mu.shape[0])
    base = 14 * (1+np.random.rand(mu.shape[0]))
    y_xiao_5 = scale * (y_xiao_5 - y_xiao_5.min())/(y_xiao_5.max()-y_xiao_5.min()) + base
    y_da_5 = scale * (y_da_5 - y_da_5.min())/(y_da_5.max()-y_da_5.min())
    # Entries with z exactly equal to split match neither mask and become 0.
    y = y_xiao_5 * (z<split) + (y_da_5) * (z>split)

    # random noise
    mask = np.random.rand(mu.shape[0])
    z[mask>0.99] = max(z) * 0.07 * (1+mask[mask>0.99])


    ax = plt.figure().add_subplot(projection='3d')
    # Normalise mu to [0, 1] so it can be mapped through the jet colormap.
    facecolors = mu
    V_normalized = (facecolors - facecolors.min().min())
    V_normalized = V_normalized / V_normalized.max().max()
    p = ax.plot_trisurf(x, y, z, cmap='viridis', linewidth=0.2, antialiased=True, facecolors=cm.jet(V_normalized))
    cbar = plt.colorbar(p)
    # Add labels
    # cbar.set_label('Mu')
    cbar.set_label(''r'$\mu$', rotation=0, va='center')
    # cbar.ax.set_title(''r'$\mu$', pad=20)
    # ax.set_xlabel('rho')
    ax.set_xlabel(''r'$\rho$')
    # ax.set_ylabel('phi')
    ax.set_ylabel(''r'$\phi$')
    # ax.set_zlabel('y')
    ax.set_zlabel(''r'$x$')

    ax.view_init(elev=30, azim=30)
    # Reverse the direction of the x axis.
    ax.set_xlim(ax.get_xlim()[::-1])

    # plt.tight_layout()
    # plt.subplots_adjust(top=0.9)

    # Save and show the figure
    plt.savefig('image.pdf')
    plt.show()


def draw_3d_graph(y_tar, phi, rho, mu):
    """
    Scatter-plot rho (x), phi (y) and the target (z) in 3-D, coloured by a
    blended mu; save the figure to 'image.pdf' and show it.

    mu is clipped at 0, min-max scaled to the range of the target, and then
    mixed as 0.3 * target + 0.7 * scaled mu before being used as the colour.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from mpl_toolkits.mplot3d import Axes3D

    # Global font size for the figure.
    plt.rcParams.update({'font.size': 14})

    def _as_numpy(tensor):
        return tensor.cpu().detach().numpy()

    clipped_mu = torch.clip(mu, 0)
    rho_vals = _as_numpy(rho)
    phi_vals = _as_numpy(phi)
    target_vals = _as_numpy(y_tar)
    colour = _as_numpy(clipped_mu)

    # Rescale mu to the target's range, then blend it with the target.
    colour = (colour - colour.min())/(colour.max()-colour.min()) * target_vals.max()
    colour = 0.3 * target_vals + 0.7 * colour

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    points = ax.scatter(rho_vals, phi_vals, target_vals, c=colour)

    cbar = plt.colorbar(points)
    cbar.set_label(r'$\mu$', rotation=0, va='center')

    ax.set_xlabel(r'$\rho$')
    ax.set_ylabel(r'$\phi$')
    ax.set_zlabel(r'$x$')

    ax.view_init(elev=30, azim=30)
    # Reverse the direction of the x axis.
    reversed_limits = ax.get_xlim()[::-1]
    ax.set_xlim(reversed_limits)

    # Save and show the figure.
    plt.savefig('image.pdf')
    plt.show()


In [9]:
from __future__ import division
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from utils import generate_dataset, get_normalized_adj, get_Laplace, calculate_random_walk_matrix,gauss_loss
from model import *
import random,os,copy
import math
import tqdm
from scipy.stats import norm
import pickle as pk
import os
os.environ["CUDA_VISIBLE_DEVICES"]='1'
# Parameters
torch.manual_seed(0)
device = torch.device('cuda') #use_gpu = False
#num_timesteps_input = 24
num_timesteps_output = 4                    # num_timesteps_input # 12
num_timesteps_input = num_timesteps_output

# A = np.load('adj_rand0.npy') # change the loading folder
# # 再导入特矩阵
# X = np.load('cta_samp_rand0.npy')

A = np.load('dist_d_only10_rand0.npy') # change the loading folder
# 再导入特矩阵
X = np.load('adj_only10_rand0.npy')


space_dim = X.shape[1]
batch_size = 4 # 12
hidden_dim_s = 70       # GNN里面的hidden维度
hidden_dim_t = 7        # TNN里面hidden的维度
rank_s = 20
rank_t = 4

epochs = 100 #35#50 #500

# Initial networks
TCN1 = B_TCN(space_dim, hidden_dim_t, kernel_size=3).to(device=device)
TCN2 = B_TCN(hidden_dim_t, rank_t, kernel_size = 3, activation = 'linear').to(device=device)
TCN3 = B_TCN(rank_t, hidden_dim_t, kernel_size= 3).to(device=device)
# TCN4 = B_TCN(hidden_dim_t, space_dim, kernel_size =6, activation = 'linear')
TNB = GaussNorm(hidden_dim_t,space_dim).to(device=device)
SCN1 = D_GCN(num_timesteps_input, hidden_dim_s, 3).to(device=device)
SCN2 = D_GCN(hidden_dim_s, rank_s, 2, activation = 'linear').to(device=device)
SCN3 = D_GCN(rank_s, hidden_dim_s, 2).to(device=device)
# SCN4 = D_GCN(hidden_dim_s, num_timesteps_input, 3, activation = 'linear')
SNB = GaussNorm(hidden_dim_s,num_timesteps_output).to(device=device)
STmodel = ST_Gau(SCN1, SCN2, SCN3, TCN1, TCN2, TCN3, SNB,TNB).to(device=device)

# Load data
# Alternative inputs:
#A = np.load('ny_data_60min/adj_only10_rand0.npy')
#X = np.load('ny_data_60min/cta_samp_only10_rand0.npy')

# Transpose, cast to float32, then insert a singleton middle axis so that X
# becomes three-dimensional with the axis to be split last.
X = X.T.astype(np.float32)
X = X.reshape((X.shape[0], 1, X.shape[1]))

# Cut points at 60% and 70% of the last axis: train / validation / test.
split_line1, split_line2 = (int(X.shape[2] * fraction) for fraction in (0.6, 0.7))

print(X.shape,A.shape)

# normalization (currently disabled: every rescaling step is commented out)
# NOTE(review): this is the max of a single scalar (0.6 * X.shape[2]), not of
# the data -- it looks unintended; confirm before re-enabling `X/max_value`.
max_value = np.max(X.shape[2] * 0.6)
#X = X/max_value
#means = np.mean(X, axis=(0, 2))
#X = X - means.reshape(1, -1, 1)
#stds = np.std(X, axis=(0, 2))
#X = X / stds.reshape(1, -1, 1)

# Contiguous slices along the last axis.
train_original_data, val_original_data, test_original_data = (
    X[:, :, :split_line1],
    X[:, :, split_line1:split_line2],
    X[:, :, split_line2:],
)

# Build (input, target) pairs for each split with the same window settings.
(training_input, training_target), (val_input, val_target), (test_input, test_target) = (
    generate_dataset(split_data,
                     num_timesteps_input=num_timesteps_input,
                     num_timesteps_output=num_timesteps_output)
    for split_data in (train_original_data, val_original_data, test_original_data)
)
print('input shape: ',training_input.shape,val_input.shape,test_input.shape)


# Normalise the adjacency, then build the two random-walk matrices (one from
# A_wave, one from its transpose) as float32 tensors on the target device.
A_wave = get_normalized_adj(A)
A_q = torch.from_numpy(
    calculate_random_walk_matrix(A_wave).T.astype('float32')
).to(device=device)
A_h = torch.from_numpy(
    calculate_random_walk_matrix(A_wave.T).T.astype('float32')
).to(device=device)
# Define the training process
# criterion = nn.MSELoss()
optimizer = optim.Adam(STmodel.parameters(), lr=1e-3)
training_nll = []
validation_nll = []
validation_mae = []
validation_rmse = []

# Create the output directories once, up front, so neither the per-epoch loss
# dump nor the final torch.save fails on a missing folder.
checkpoint_path = "checkpoints/"
os.makedirs(checkpoint_path, exist_ok=True)
model_path = 'pth/ST_Gauss_route20_30min_in4-out4-h4_20220221.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Start from the initial weights so that `best_model` is always bound, even
# when the very first epoch produces a NaN loss and the loop exits early.
best_model = copy.deepcopy(STmodel.state_dict())

for epoch in range(epochs):
    ## Step 1, training
    # Trains one epoch with the given data (procedure similar to STGCN).
    #   training_input:  training inputs of shape (num_samples, num_nodes,
    #                    num_timesteps_train, num_features).
    #   training_target: training targets of shape (num_samples, num_nodes,
    #                    num_timesteps_predict).
    #   batch_size:      batch size to use during training.
    STmodel.train()
    permutation = torch.randperm(training_input.shape[0])
    epoch_training_losses = []
    for i in range(0, training_input.shape[0], batch_size):
        optimizer.zero_grad()

        indices = permutation[i:i + batch_size]
        X_batch = training_input[indices].to(device=device)
        y_batch = training_target[indices].to(device=device)

        # Forward pass: X, A => GNN => TNN => (loc, scale)
        loc_train, scale_train = STmodel(X_batch, A_q, A_h)
        loss = gauss_loss(y_batch, loc_train, scale_train)
        # Backward pass
        loss.backward()
        # Parameter update
        optimizer.step()
        epoch_training_losses.append(loss.detach().cpu().numpy())
    training_nll.append(sum(epoch_training_losses)/len(epoch_training_losses))

    ## Step 2, validation
    STmodel.eval()
    with torch.no_grad():
        # Device copies live in local names so `val_input` / `val_target`
        # themselves are never rebound.
        val_x = val_input.to(device=device)
        val_y = val_target.to(device=device)

        loc_val, scale_val = STmodel(val_x, A_q, A_h)
        # `.item()` replaces np.asscalar, which was removed from NumPy.
        validation_nll.append(gauss_loss(val_y, loc_val, scale_val).item())

        # Expected value of the predicted Gaussian at each position.
        val_pred = norm.mean(loc_val.detach().cpu().numpy(),
                             scale_val.detach().cpu().numpy())
        print(val_pred.mean())

        val_true = val_y.detach().cpu().numpy()
        mae = np.mean(np.abs(val_pred - val_true))
        rmse = np.sqrt(((val_pred - val_true) ** 2).mean())

        validation_mae.append(mae)
        validation_rmse.append(rmse)

        # Drop the device copies of the validation data until the next epoch.
        del val_x, val_y

    print('Epoch %d: trainNLL %.5f; valNLL %.5f; MAE %.4f; RMSE %.4f'%(epoch,
    training_nll[-1],validation_nll[-1],validation_mae[-1], validation_rmse[-1]))
    print('Epoch: {}'.format(epoch))
    print("Training loss: {}".format(training_nll[-1]))

    # Keep the weights of the epoch with the lowest *training* NLL so far.
    # A NaN loss never compares equal, so it can never become the best model.
    if float(training_nll[-1]) == min(training_nll):
        best_model = copy.deepcopy(STmodel.state_dict())

    # Persist the training-loss history after every epoch.
    with open("checkpoints/losses.pk", "wb") as fd:
        # pk.dump((training_nll, validation_nll, validation_mae), fd)
        pk.dump((training_nll), fd)

    # Stop as soon as training diverges.
    if np.isnan(training_nll[-1]):
        break

# Restore the best weights and save the whole model object.
STmodel.load_state_dict(best_model)
torch.save(STmodel, model_path)


(100, 1, 100) (100, 100)
input shape:  torch.Size([53, 100, 4, 1]) torch.Size([3, 100, 4, 1]) torch.Size([23, 100, 4, 1])




0.2561158762065073
Epoch 0: trainNLL 0.62173; valNLL 0.22578; MAE 0.2913; RMSE 0.3086
Epoch: 0
Training loss: 0.6217281222343445
0.09223396291024982
Epoch 1: trainNLL 0.17897; valNLL 0.14483; MAE 0.1914; RMSE 0.2861
Epoch: 1
Training loss: 0.1789715588092804
0.13253582655917853
Epoch 2: trainNLL 0.11759; valNLL 0.08630; MAE 0.2144; RMSE 0.2809
Epoch: 2
Training loss: 0.11759359229888235
0.13125686401190856
Epoch 3: trainNLL 0.09480; valNLL 0.05344; MAE 0.2101; RMSE 0.2763
Epoch: 3
Training loss: 0.09479894063302449
0.12889132062594097
Epoch 4: trainNLL 0.05899; valNLL -0.02927; MAE 0.1974; RMSE 0.2632
Epoch: 4
Training loss: 0.05899409570598176
0.14079977086589981
Epoch 5: trainNLL 0.00677; valNLL -0.15962; MAE 0.1648; RMSE 0.2278
Epoch: 5
Training loss: 0.0067739213284637246
0.1225785003307586
Epoch 6: trainNLL -0.00626; valNLL -0.18936; MAE 0.1501; RMSE 0.2261
Epoch: 6
Training loss: -0.006263601103065801
0.1250710564230879
Epoch 7: trainNLL -0.06390; valNLL -0.21180; MAE 0.1569; RMS