In [1]:
from tensorflow import reshape,transpose,squeeze,GradientTape,expand_dims
from tensorflow.keras.layers import Dense,Input,GRU,Conv1D
from tensorflow.keras.models import Model
from tensorflow.keras import losses,optimizers
import numpy as np
import multiprocessing as mp
from spektral.layers import GCNConv,GATConv
from spektral.utils.convolution import gcn_filter

- **Model**: STGCN may be a suitable method for spatial-temporal prediction. Why is such a structure needed?
    - [pytorch implementation](https://github.com/LMissher/STGNN)
    - [original](https://github.com/VeritasYin/STGCN_IJCAI-18)
- **Predict**: *T* successive 1-step predictions, or a single *T*-step prediction?

In [None]:
class Estimate:
    """Keras surrogate model with optional graph-convolution and temporal layers.

    The network is a stack of ``n_layer`` feature layers (``Dense``, or
    Spektral ``GCNConv``/``GATConv`` when ``conv`` is given), optionally
    followed by a temporal layer (``Conv1D`` or ``GRU``) and a linear ``Dense``
    head.

    Parameters
    ----------
    conv : str or None
        Name containing ``'GCN'`` or ``'GAT'`` to use graph convolutions;
        ``None``/empty for plain ``Dense`` layers. ``'CNN'`` is reserved but
        not implemented.
    edges : array-like of (u, v) integer pairs, optional
        Edge list used to build the adjacency matrix. Required when ``conv``
        is set.
    resnet : bool
        Add a residual connection around each feature layer whose input and
        output widths match.
    recurrent : str or None
        ``'Conv1D'`` or ``'GRU'`` to process a leading time axis of length
        ``seq_len``; ``None`` for single-step inputs.
    args : object, optional
        Attribute container for hyper-parameters (``n_node``, ``n_in``,
        ``n_out``, ``seq_len``, ``embed_size``, ``hidden_dim``, ``n_layer``,
        ``hmax``, ``loss_function``, ``optimizer``, ``learning_rate``).
        Missing attributes fall back to the defaults below.

    Raises
    ------
    ValueError
        If ``conv`` is requested without ``edges``.
    """
    def __init__(self,conv=None,edges=None,resnet=False,recurrent=None,args=None):
        self.n_node = getattr(args,'n_node',5)
        self.n_in = getattr(args,'n_in',5)
        self.n_out = getattr(args,'n_out',3)
        self.seq_len = getattr(args,'seq_len',6)
        self.embed_size = getattr(args,'embed_size',64)
        self.hidden_dim = getattr(args,"hidden_dim",64)
        self.n_layer = getattr(args,"n_layer",3)

        # Per-node threshold compared against h in `constrain`.
        self.hmax = getattr(args,"hmax",np.array([1.5 for _ in range(self.n_node)]))

        # The conv branch of build_network reads self.filter, so fail early with
        # a clear message instead of an AttributeError deep inside the builder.
        if conv and edges is None:
            raise ValueError("`edges` is required when a convolution layer (%s) is requested"%str(conv))
        self.edges = edges
        self.filter = self.get_adj(edges) if edges is not None else None

        self.conv = conv
        self.recurrent = recurrent
        self.model = self.build_network(conv,resnet,recurrent)
        self.loss_fn = losses.get(getattr(args,"loss_function","MeanSquaredError"))
        self.optimizer = optimizers.get(getattr(args,"optimizer","Adam"))
        self.optimizer.learning_rate = getattr(args,"learning_rate",1e-3)

    def get_adj(self,edges):
        """Build a dense adjacency matrix from an integer edge array.

        The matrix is square with side ``edges.max()+1``; entry ``[u, v]``
        counts how many times the pair ``(u, v)`` appears in ``edges``.
        """
        n = edges.max()+1
        A = np.zeros((n,n)) # adjacency matrix
        for u,v in edges:
            A[u,v] += 1
        return A

    def build_network(self,conv=None,resnet=False,recurrent=None):
        """Assemble and return the Keras ``Model``.

        Raises ``NotImplementedError`` for the reserved ``'CNN'`` option and
        ``AssertionError`` for unknown ``conv``/``recurrent`` names.
        """
        # Per-sample input shape: (T,N,in) (T,in*N) (N,in) (in*N)
        input_shape = (self.n_node,self.n_in) if conv else (self.n_node * self.n_in,)
        if recurrent:
            input_shape = (self.seq_len,) + input_shape
        X_in = Input(shape=input_shape)
        # Keras symbolic tensors are not copied; ops on them create new tensors.
        x = X_in

        if conv:
            # NOTE(review): the adjacency input is declared with per-sample shape
            # (N,), and train()/predict() call self.model(x) without an adjacency
            # argument although the model declares [X_in, A_in] as inputs.
            # Confirm the intended Spektral data mode and feed self.filter accordingly.
            A_in = Input(shape=(self.filter.shape[0],))
            inp = [X_in,A_in]
            if 'GCN' in conv:
                self.filter,net = gcn_filter(self.filter),GCNConv
            elif 'GAT' in conv:
                net = GATConv
            elif 'CNN' in conv:
                # TODO: CNN
                raise NotImplementedError("Convolution layer %s is not implemented yet"%str(conv))
            else:
                raise AssertionError("Unknown Convolution layer %s"%str(conv))
        else:
            inp,net = X_in,Dense

        # (B,T,N,in) (B,T,in*N) --> (B*T,N,in) (B*T,in*N)
        x = reshape(x,(-1,)+input_shape[1:]) if recurrent else x
        for _ in range(self.n_layer):
            # Keep the feature tensor `x` separate from the layer input so the
            # residual add never sees the [x, A_in] list.
            layer_in = [x,A_in] if conv else x
            x_out = net(self.embed_size,activation='relu')(layer_in)
            # Residual only where widths agree (the first layer maps the raw
            # input width to embed_size, which generally differs).
            if resnet and x.shape[-1] == x_out.shape[-1]:
                x = x_out + x
            else:
                x = x_out

        if recurrent:
            # (B*T,N,E) (B*T,E) --> (B,T,N,E) (B,T,E)
            x = reshape(x,(-1,)+input_shape[:-1]+(self.embed_size,))
            # (B,T,N,E) (B,T,E) --> (B,N,T,E) (B,T,E)
            x = transpose(x,[0,2,1,3]) if conv else x

            if recurrent == 'Conv1D':
                # (B,N,T,E) (B,T,E) --> (B,N,1,H) (B,1,H)
                x = Conv1D(self.hidden_dim,self.seq_len,activation='relu')(x)
                # Drop only the collapsed time axis; an axis-less squeeze would
                # also drop a batch (or node) axis of size 1.
                # (B,N,1,H) (B,1,H) --> (B,N,H) (B,H)
                x = squeeze(x,axis=-2)
            elif recurrent == 'GRU':
                # (B,N,T,E) (B,T,E) --> (B*N,T,E) (B,T,E)
                x = reshape(x,(-1,self.seq_len,self.embed_size)) if conv else x
                x = GRU(self.hidden_dim)(x)
                # (B*N,H) (B,H) --> (B,N,H) (B,H)
                x = reshape(x,(-1,self.n_node,self.hidden_dim)) if conv else x
            else:
                raise AssertionError("Unknown recurrent layer %s"%str(recurrent))

        # Per-node head when conv is used, otherwise one flat vector for all nodes.
        out_shape = self.n_out if conv else self.n_out * self.n_node
        out = Dense(out_shape,activation='linear')(x)
        model = Model(inputs=inp, outputs=out)
        return model

    def train(self,x,y):
        """Run one gradient step on the batch ``(x, y)`` and return the loss value."""
        with GradientTape() as tape:
            tape.watch(self.model.trainable_variables)
            pred = self.model(x)
            loss = self.loss_fn(y,pred)
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads,self.model.trainable_variables))
        return loss.numpy()

    def predict(self,x):
        """Predict for a single sample: add a batch axis, run the model, strip it again."""
        x = expand_dims(x,0)
        return squeeze(self.model(x),0).numpy()

    def constrain(self,y,r):
        """Split a prediction into ``(h, q_us, q_ds)`` and derive ``q_w``.

        With ``conv`` the three quantities are read from the last axis of a
        3-d ``y``; otherwise from three consecutive ``n_node``-wide column
        groups of a 2-d ``y``. ``r`` should be at the same step as ``q_ds``.

        ``q_w`` is ``q_us + r - q_ds`` clipped at zero, kept only where
        ``h > hmax``. Returns ``(h, q_us, q_ds, q_w)``.
        """
        if self.conv:
            h,q_us,q_ds = [y[:,:,i] for i in range(3)]
        else:
            h,q_us,q_ds = y[:,:self.n_node],y[:,self.n_node:self.n_node*2],y[:,self.n_node*2:]
        # q_us = [np.zeros((y.shape[0],)) for _ in self.n_node]
        # for u,v in self.edges:
        #     q_us[v] += q_ds[u]
        # q_us = np.array(q_us).T
        q_w = (q_us + r - q_ds).clip(0) * (h > self.hmax)
        return (h,q_us,q_ds,q_w)
        
        

In [2]:
class DataGenerator:
    """Collect (X, Y) training samples by running simulations on ``env``.

    Parameters
    ----------
    env : object
        Simulation environment exposing ``reset(event, global_state, seq)``,
        ``step(setting)`` (returns the done flag), ``state(seq)`` and, when
        actions are used, ``config['action_space']``.
    seq_len : int
        Sequence length passed to the environment as ``seq``.
    act : bool
        Load the action table up front. It is also loaded on first use if a
        later call asks for random actions.
    """
    def __init__(self,env,seq_len = 4,act = False):
        self.env = env
        self.seq_len = seq_len
        self.action_table = self._load_action_table() if act else None

    def _load_action_table(self):
        """Return the list of candidate values for each action in the env config."""
        return list(self.env.config['action_space'].values())

    def simulate(self, event, act = False):
        """Run one event to completion.

        Returns ``(states, settings)``: the stacked states (initial state plus
        one per step) and, when ``act`` is true, the stacked random settings
        applied at each step (otherwise ``None``).
        """
        # generate(..., act=True) may be called on an instance built with
        # act=False; load the table here instead of failing on a missing attribute.
        if act and getattr(self,'action_table',None) is None:
            self.action_table = self._load_action_table()
        state = self.env.reset(event,global_state=True,seq=self.seq_len)
        states,settings = [state],[]
        done = False
        while not done:
            setting = [table[np.random.randint(0,len(table))] for table in self.action_table] if act else None
            done = self.env.step(setting)
            state = self.env.state(seq=self.seq_len)
            states.append(state)
            settings.append(setting)
        return np.array(states),(np.array(settings) if act else None)

    def state_split(self,states,settings=None):
        """Turn a state trajectory into model inputs ``X`` and targets ``Y``.

        ``states`` is indexed as ``[step, t, node, channel]`` with the first
        four channels read as ``h, q_totin, q_ds, r``. ``X`` stacks
        ``h, q_us, q_ds`` of each step with ``r`` of the following step;
        ``Y`` stacks ``h, q_us, q_ds`` of the following step at the last ``t``.
        When ``settings`` is given, each step's setting vector is appended to
        the features of every ``t`` and node.
        """
        if settings is not None:
            # B,T,N,S
            states = states[:settings.shape[0]+1,:,:,:]
        h,q_totin,q_ds,r = [states[:,:,:,i] for i in range(4)]
        q_us = q_totin - r
        # B,T,N,in
        X = np.stack([h[:-1],q_us[:-1],q_ds[:-1],r[1:]],axis=-1)
        Y = np.stack([h[1:,-1,:],q_us[1:,-1,:],q_ds[1:,-1,:]],axis=-1)
        if settings is not None:
            # X is 4-d, so the actions must be repeated over both the T and N
            # axes before concatenating on the feature axis:
            # (B,n_act) --> (B,1,1,n_act) --> (B,T,N,n_act)
            a = np.tile(settings[:,None,None,:],[1,X.shape[1],X.shape[2],1])
            X = np.concatenate([X,a],axis=-1)
        return X,Y

    def generate(self,events,processes=5,act=False):
        """Simulate all ``events`` and store the concatenated samples.

        Sets ``self.X``, ``self.Y`` and ``self.length``. A worker pool is only
        created when ``processes > 1``; otherwise events run in this process.
        """
        if processes > 1:
            # The context manager tears the pool down even if a worker fails;
            # results are fetched inside it because exit terminates the workers.
            with mp.Pool(processes) as pool:
                pending = [pool.apply_async(func=self.simulate,args=(event,act,)) for event in events]
                runs = [p.get() for p in pending]
        else:
            runs = [self.simulate(event,act) for event in events]
        res = [self.state_split(*run) for run in runs]
        self.X,self.Y = [np.concatenate([r[i] for r in res],axis=0) for i in range(2)]
        self.length = self.X.shape[0]

    def sample(self,size):
        """Draw ``size`` samples uniformly with replacement from the stored data."""
        idx = np.random.choice(self.length,size)
        return self.X[idx],self.Y[idx]


In [3]:
from datetime import datetime
from swmm_api import read_inp_file
from envs import shunqing
# Instantiate the project environment and parse the SWMM input file named in its config.
env = shunqing()
inp = read_inp_file(env.config['swmm_input'])

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Write one SWMM input file per rainfall time series. A series is selected when
# its name starts with the configured 'suffix' string (used here as a prefix).
# For each one, the simulation end is set to start + series duration and the
# rain gage is pointed at that series before the file is written.
for k,v in inp.TIMESERIES.items():
    if not k.startswith(env.config['rainfall']['suffix']):
        continue
    # Span between the first and last time stamp of the series.
    dura = v.data[-1][0] - v.data[0][0]
    st = (inp.OPTIONS['START_DATE'],inp.OPTIONS['START_TIME'])
    st = datetime(
        st[0].year,st[0].month,st[0].day,
        st[1].hour,st[1].minute,st[1].second,
    )
    et = st + dura
    inp.OPTIONS['END_DATE'] = et.date()
    inp.OPTIONS['END_TIME'] = et.time()
    inp.RAINGAGES['RainGage'].Timeseries = k
    inp.write_file(env.config['rainfall']['filedir']+k+'.inp')

# Paths of the files written above, in time-series order.
events = [
    env.config['rainfall']['filedir']+k+'.inp'
    for k in inp.TIMESERIES
    if k.startswith(env.config['rainfall']['suffix'])
]

In [5]:
# Simulate every rainfall event in this process (processes=1); the samples end up on dG.X / dG.Y.
dG = DataGenerator(env,seq_len=4)
dG.generate(events,processes=1)

In [13]:
# NOTE(review): state_split stacks three target channels (h, q_us, q_ds), but the
# saved output below shows a last dimension of 2 and a non-sequential execution
# count. The output looks stale; restart the kernel and re-run to confirm.
dG.Y.shape

(20396, 113, 2)