In [17]:
import os, random, argparse, json, time
from tqdm import tqdm, trange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from texttable import Texttable
from typing import Union, Tuple, Dict, List

import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# Parameters

In [None]:
def parse_args(argv=None):
    """Build the experiment configuration.

    Args:
        argv: Optional list of command-line tokens to parse. When None (the
            default) an empty list is parsed, so calling this from a notebook
            never picks up the kernel's own command-line flags and every
            option keeps its default.

    Returns:
        argparse.Namespace holding every option defined below plus ``device``
        (``cuda`` when available, otherwise ``cpu``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', type=int, default=10, help='Random seed of the experiment')
    parser.add_argument('--exp_name', type=str, default='Exp', help='Name of the experiment')

    parser.add_argument('--num_step', type=int, default=8640, help='4*24*90=8640') # TODO 90days
    parser.add_argument('--batch_size', type=int, default=96, help='Size of the batch')
    parser.add_argument('--train_ratio' , type=float, default=3/6, help='train set ratio')
    parser.add_argument('--val_ratio'   , type=float, default=1/6, help='validation set ratio')
    parser.add_argument('--test_ratio'  , type=float, default=2/6, help='test set ratio')

    parser.add_argument('--num_history', type=int, default=1, help='Number of historical time steps')
    parser.add_argument('--ChebyshevDegree', type=int, default=8, help='degree of KAN')
    parser.add_argument('--FourierDegree', type=int, default=6, help='degree of KAN')

    parser.add_argument('--num_nodes', type=int, default=24, help='Number of nodes in the graph')
    parser.add_argument('--num_edges', type=int, default=34, help='Number of edges in the graph')
    parser.add_argument('--hidden_dim', type=int, default=12, help='Dimension of the hidden layers')

    parser.add_argument('--max_epoch', type=int, default=50, help='Maximum number of epochs')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate of Adam')

    # L1-norm
    parser.add_argument('--l1_lambda', type=float, default=0.00, help='L1 regularization coefficient')
    # L2-norm
    parser.add_argument('--l2_lambda', type=float, default=0.00, help='L2 regularization coefficient')

    parser.add_argument('--gamma', type=float, default=0.95, help='decay parameter')
    # Used as StepLR's step_size, which is an integer number of epochs.
    parser.add_argument('--decay_epoch', type=int, default=1, help='decay epoch')
    parser.add_argument('--patience', type=int, default=10, help='Patience parameter for early stopping')

    # Sensor ids are JSON lists of 1-based string ids (decoded with json.loads downstream).
    parser.add_argument('--pressure_sensor_name', type=str, default='["4", "13", "14", "23", "24"]', help='pressure sensor index')
    parser.add_argument('--flow_sensor_name', type=str, default='["4", "7", "25", "28", "34"]', help='flow sensor index')
    parser.add_argument('--demand_sensor_name', type=str, default='["2", "16", "19", "22", "24"]', help='demand sensor index')

    parser.add_argument('--water_pressure_file', type=str, default="./dataset/simulate_pressure3.csv")
    parser.add_argument('--water_flow_file', type=str, default="./dataset/simulate_flow3.csv")
    parser.add_argument('--water_demand_file', type=str, default="./dataset/simulate_demand3.csv")

    parser.add_argument('--node_file', type=str, default='./dataset/node_features_Apulia.csv') # (Reserved Interface)
    parser.add_argument('--edge_file', type=str, default='./dataset/edge_index_dir.csv')       # (Reserved Interface)

    parser.add_argument('--incidence_file', type=str, default='./dataset/initial_incidence_matrix_dir.csv')
    parser.add_argument('--cycle_file', type=str, default='./dataset/initial_cycle_matrix_dir.csv')

    args = parser.parse_args(args=[] if argv is None else argv)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return args

# Loss Function

In [None]:
# Loss function: scaled squared residuals of the hydraulic conservation equations
class ConservationConstraints(nn.Module):
    """Physics-based loss terms for an estimated pressure / flow / demand state.

    Pipe head loss uses the Hazen-Williams form
    ``h = sign(Q) * |Q|**1.852 * PipeFriction``. Per the original comment,
    ``PipeFriction`` is the per-pipe coefficient (10.67 L) / (C**1.852 D**4.87).
    Three residuals are scaled, squared, summed over every element and divided
    by ``normalizer``:

    1. nodal flow equation      ``flow_Y @ IncidenceMatrix.T + demand_Y``
    2. loop energy equation     ``h @ CycleMatrix.T``
    3. pipe pressure drop       ``pressure_Y @ IncidenceMatrix - h``

    Args:
        normalizer: divisor applied to every summed loss. Defaults to 96,
            which equals the default ``--batch_size``.
        flow_scale: multiplier on residual 1 before squaring (default 100;
            the original note targets residuals < 0.01).
        energy_scale: multiplier on residual 2 (default 10; target < 0.1).
        pressure_scale: multiplier on residual 3 (default 10; target < 0.1).
    """

    # Hazen-Williams flow exponent.
    HW_EXPONENT = 1.852

    def __init__(self, normalizer=96, flow_scale=100, energy_scale=10, pressure_scale=10):
        super(ConservationConstraints, self).__init__()
        self.normalizer = normalizer
        self.flow_scale = flow_scale
        self.energy_scale = energy_scale
        self.pressure_scale = pressure_scale

    def _scaled_sse(self, residual, scale):
        """Sum of squared, scaled residuals divided by ``normalizer`` (scalar tensor)."""
        return torch.sum((torch.abs(residual) * scale) ** 2) / self.normalizer

    def forward(self, pressure_Y, flow_Y, demand_Y, PipeFriction, IncidenceMatrix, CycleMatrix):
        """Return the three loss terms ``(loss2_1, loss2_2, loss2_3)``.

        ``PipeFriction`` is reshaped to ``(1, 1, -1)`` so it broadcasts over the
        last (pipe) axis of ``flow_Y``.
        """
        PipeFriction = PipeFriction.view(1, 1, -1)
        magnitude = (abs(flow_Y) ** self.HW_EXPONENT) * PipeFriction
        # The head loss carries the sign of the flow direction.
        pipe_pressure_minus = torch.where(flow_Y > 0, magnitude, -magnitude)

        loss2_1 = self._scaled_sse((flow_Y @ IncidenceMatrix.t()) + demand_Y, self.flow_scale)
        loss2_2 = self._scaled_sse(pipe_pressure_minus @ CycleMatrix.t(), self.energy_scale)
        loss2_3 = self._scaled_sse((pressure_Y @ IncidenceMatrix) - pipe_pressure_minus, self.pressure_scale)

        return loss2_1, loss2_2, loss2_3

In [None]:
class ConservationConstraints_generalization(nn.Module):
    """Per-column squared residuals of the hydraulic conservation equations.

    Head loss uses ``h = sign(Q) * |Q|**1.852 * PipeFriction`` with
    ``PipeFriction`` reshaped to ``(1, -1)``. The residuals are NOT scaled
    (unlike ``ConservationConstraints``). Each squared residual is summed over
    dim 0 only and divided by 96, so one value is returned per column:

    1. nodal flow equation   ``flow_Y @ IncidenceMatrix.T + demand_Y``
    2. loop energy equation  ``h @ CycleMatrix.T``
    3. pipe pressure drop    ``pressure_Y @ IncidenceMatrix - h``
    """

    def __init__(self):
        super(ConservationConstraints_generalization, self).__init__()

    def forward(self, pressure_Y, flow_Y, demand_Y, PipeFriction, IncidenceMatrix, CycleMatrix):
        friction = PipeFriction.view(1, -1)
        magnitude = (abs(flow_Y) ** 1.852) * friction
        head_loss = torch.where(flow_Y > 0, magnitude, -magnitude)

        nodal_residual = (flow_Y @ IncidenceMatrix.t()) + demand_Y
        loop_residual = head_loss @ CycleMatrix.t()
        drop_residual = (pressure_Y @ IncidenceMatrix) - head_loss

        def column_sse(residual):
            # Sum of squares along dim 0, normalised by 96.
            return torch.sum(torch.abs(residual) ** 2, dim=0) / 96

        return column_sse(nodal_residual), column_sse(loop_residual), column_sse(drop_residual)

# Masking

In [None]:
def apply_mask(args, pressure_data, flow_data, demand_data):
    '''Build 0/1 sensor masks for pressure, flow and demand.

    A mask column is 1 where the corresponding node/pipe carries a sensor and
    0 elsewhere. Sensors are listed in ``args.*_sensor_name`` as JSON lists of
    1-based string ids, e.g. '["4", "13"]'.

    Args:
        args: namespace with ``pressure_sensor_name``, ``flow_sensor_name``,
            ``demand_sensor_name`` and ``device``.
        pressure_data, flow_data, demand_data: 2-D arrays/tensors (rows x
            columns); only their ``.shape`` is used.

    Returns:
        (mask_pressure, mask_flow, mask_demand_reduced), all on ``args.device``.
        The demand mask drops the last column; the model estimates only the
        first ``num_nodes - 1`` demand columns.
    '''
    def _sensor_mask(shape, sensor_names):
        # Convert the 1-based ids to 0-based column indices.
        columns = [int(name) - 1 for name in json.loads(sensor_names)]
        mask = torch.zeros(shape)
        mask[:, columns] = 1
        return mask

    mask1_pressure = _sensor_mask(pressure_data.shape, args.pressure_sensor_name)
    mask1_flow = _sensor_mask(flow_data.shape, args.flow_sensor_name)
    # Sized from demand_data (previously pressure_data.shape was used by mistake).
    mask1_demand = _sensor_mask(demand_data.shape, args.demand_sensor_name)
    mask1_demand_reduced = mask1_demand[:, :-1]

    return (mask1_pressure.to(args.device),
            mask1_flow.to(args.device),
            mask1_demand_reduced.to(args.device))

# Data

In [None]:
def seq2instance(args, data):
    """Cut a (num_step, dims) series into overlapping windows of ``args.batch_size`` rows.

    Window ``i`` covers rows ``i .. i + batch_size - 1``; consecutive windows
    are shifted by one step. ``x`` and ``y`` hold identical windows, stored in
    separate tensors.

    Returns:
        (x, y): float tensors of shape (num_sample, batch_size, dims) with
        ``num_sample = num_step - 2 * batch_size + 1``.
    """
    window = args.batch_size
    total_steps, dims = data.shape
    # NOTE(review): the window length is subtracted twice, which looks inherited
    # from an input-window + target-window layout; since y == x,
    # `total_steps - window + 1` windows would fit. Kept as-is.
    count = total_steps - 2 * window + 1

    windows = torch.zeros(count, window, dims)
    for start in range(count):
        windows[start] = data[start:start + window]
    return windows, windows.clone()

class CustomDataset(Dataset):
    """Windowed pressure / flow / demand dataset for one chronological split.

    The three CSV files are read and min-max normalised with one global
    min/max per file. They are split chronologically into train / val / test
    and cut into windows with ``seq2instance``. The windows of every split are
    built and kept as attributes (``trainX_pressure``, ``valY_flow``, ...);
    ``dataset_type`` only selects which split ``__len__`` / ``__getitem__``
    expose.

    Args:
        args: namespace providing ``num_step``, ``train_ratio``, ``test_ratio``,
            ``water_pressure_file``, ``water_flow_file``, ``water_demand_file``,
            ``batch_size`` (window length used by ``seq2instance``) and
            ``device``.
        dataset_type: one of ``'train'``, ``'val'``, ``'test'``.

    Raises:
        ValueError: if ``dataset_type`` is not one of the three splits.
    """

    SPLITS = ('train', 'val', 'test')

    def __init__(self, args, dataset_type='train'):
        if dataset_type not in self.SPLITS:
            raise ValueError("dataset_type must be one of {}, got {!r}".format(self.SPLITS, dataset_type))
        self.args = args
        self.dataset_type = dataset_type

        train_steps = round(args.train_ratio * args.num_step)
        test_steps = round(args.test_ratio * args.num_step)
        # args.val_ratio is not read: validation gets whatever train and test leave of num_step.
        val_steps = args.num_step - train_steps - test_steps
        val_end = train_steps + val_steps

        NodePressure = pd.read_csv(args.water_pressure_file, header=0, index_col=0)
        PipeFlow = pd.read_csv(args.water_flow_file, header=0, index_col=0)
        NodeDemand = pd.read_csv(args.water_demand_file, header=0, index_col=0)

        # Global min-max: one scalar min and max per file, over every row and column.
        # NOTE(review): the statistics include val/test rows — confirm this is intended.
        NodePressure_ = (NodePressure - NodePressure.min().min()) / (NodePressure.max().max() - NodePressure.min().min())
        PipeFlow_ = (PipeFlow - PipeFlow.min().min()) / (PipeFlow.max().max() - PipeFlow.min().min())
        # The demand minimum ignores the last column (the model does not estimate that column).
        demand_min = (NodeDemand.iloc[:, :-1]).min().min()
        NodeDemand_ = (NodeDemand - demand_min) / (NodeDemand.max().max() - demand_min)

        NodePressure = torch.FloatTensor(NodePressure_.values)
        PipeFlow = torch.FloatTensor(PipeFlow_.values)
        NodeDemand = torch.FloatTensor(NodeDemand_.values)

        # Chronological split: train = first train_steps rows, val = next val_steps rows,
        # test = the LAST test_steps rows of each file.
        # NOTE(review): if a file has more than args.num_step rows, test does not directly follow val.
        self.train_pressure = NodePressure[: train_steps]
        self.train_flow     = PipeFlow[: train_steps]
        self.train_demand   = NodeDemand[: train_steps]

        self.val_pressure   = NodePressure[train_steps: val_end]
        self.val_flow       = PipeFlow[train_steps: val_end]
        self.val_demand     = NodeDemand[train_steps: val_end]

        self.test_pressure  = self._tail(NodePressure, test_steps)
        self.test_flow      = self._tail(PipeFlow, test_steps)
        self.test_demand    = self._tail(NodeDemand, test_steps)

        self.trainX_pressure, self.trainY_pressure = seq2instance(args, self.train_pressure)
        self.trainX_flow, self.trainY_flow = seq2instance(args, self.train_flow)
        self.trainX_demand, self.trainY_demand = seq2instance(args, self.train_demand)

        self.valX_pressure, self.valY_pressure = seq2instance(args, self.val_pressure)
        self.valX_flow, self.valY_flow = seq2instance(args, self.val_flow)
        self.valX_demand, self.valY_demand = seq2instance(args, self.val_demand)

        self.testX_pressure, self.testY_pressure = seq2instance(args, self.test_pressure)
        self.testX_flow, self.testY_flow = seq2instance(args, self.test_flow)
        self.testX_demand, self.testY_demand = seq2instance(args, self.test_demand)

    @staticmethod
    def _tail(series, count):
        """Return the last ``count`` rows of ``series`` (empty when ``count`` is 0).

        ``series[-count:]`` would wrongly return the whole series for ``count == 0``.
        """
        return series[max(series.shape[0] - count, 0):]

    def _split_tensors(self):
        """Return the six window tensors (X/Y for pressure, flow, demand) of the active split."""
        splits = {
            'train': (self.trainX_pressure, self.trainY_pressure,
                      self.trainX_flow, self.trainY_flow,
                      self.trainX_demand, self.trainY_demand),
            'val':   (self.valX_pressure, self.valY_pressure,
                      self.valX_flow, self.valY_flow,
                      self.valX_demand, self.valY_demand),
            'test':  (self.testX_pressure, self.testY_pressure,
                      self.testX_flow, self.testY_flow,
                      self.testX_demand, self.testY_demand),
        }
        return splits[self.dataset_type]

    def __len__(self):
        """Number of windows in the active split."""
        return len(self._split_tensors()[0])

    def __getitem__(self, idx):
        """Return window ``idx`` as (X_pressure, Y_pressure, X_flow, Y_flow, X_demand, Y_demand) on ``args.device``."""
        device = self.args.device
        return tuple(tensor[idx].to(device) for tensor in self._split_tensors())

In [None]:
def load_matrix(args):
    """Load the network graph and the matrices used by the conservation losses.

    Files are read in this order:
      * ``args.node_file``: node features (header row + index column);
      * ``args.edge_file``: columns 0-1 are edge endpoints, columns 2-6 are edge
        features, and the last column is the pipe friction (Hazen-Williams)
        coefficient;
      * ``args.incidence_file``: raw matrix, no header and no index column;
      * ``args.cycle_file``: matrix with a header row and an index column.

    Returns:
        (Graph_Data, pipe_friction, incidence_matrix, cycle_matrix). The first
        is a torch_geometric ``Data``; the others are float tensors. All live
        on ``args.device``.
    """
    device = args.device

    def as_float(values):
        return torch.from_numpy(values).float().to(device)

    node_table = pd.read_csv(args.node_file, header=0, index_col=0)
    node_features = as_float(node_table.values)

    edge_table = pd.read_csv(args.edge_file, header=0, index_col=0)
    endpoints = torch.from_numpy(edge_table.iloc[:, 0:2].T.values).long().to(device)
    edge_features = as_float(edge_table.iloc[:, 2:7].values)
    graph = Data(x=node_features, edge_index=endpoints, edge_attr=edge_features)

    friction = as_float(edge_table.iloc[:, -1].values)

    incidence = as_float(pd.read_csv(args.incidence_file, header=None, index_col=None).values)
    cycles = as_float(pd.read_csv(args.cycle_file, header=0, index_col=0).values)

    return graph, friction, incidence, cycles

# Utils

In [None]:
class IOStream():
    """Training log: echoes each message to stdout and appends it to a file.

    Can be used as a context manager so the file is closed even if training
    raises::

        with IOStream('run.log') as io:
            io.cprint('epoch 1')
    """
    def __init__(self, path, encoding='utf-8'):
        # Append mode keeps logs from earlier runs at the same path. An explicit
        # encoding keeps non-ASCII messages from failing on platforms whose
        # default encoding cannot represent them.
        self.file = open(path, 'a', encoding=encoding)

    def cprint(self, text):
        """Print ``text`` and append it, plus a newline, to the log file."""
        text = str(text)
        print(text)
        self.file.write(text + '\n')
        # Flush so the log can be read while a long run is still in progress.
        self.file.flush()

    def close(self):
        """Close the log file (safe to call more than once)."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

def table_printer(args):
    """Render every attribute of ``args`` as a two-column "Parameter | Value" text table.

    Rows are sorted by attribute name. Names have underscores replaced by spaces
    and are passed through ``str.capitalize``. Values are shown via ``str``.
    """
    params = vars(args)
    body = [[key.replace("_", " ").capitalize(), str(params[key])] for key in sorted(params)]

    table = Texttable()
    table.set_cols_dtype(['t', 't'])
    table.add_rows([["Parameter", "Value"]] + body)
    return table.draw()

# KAN Parts

In [25]:
# KAN layer built on Chebyshev polynomial basis functions
class ChebyshevKANLayer(nn.Module):
    """Kolmogorov-Arnold layer with Chebyshev polynomial basis functions.

    Each input feature is squashed with ``tanh`` into (-1, 1) and expanded into
    the Chebyshev polynomials T_0 .. T_degree. Every output is a learned linear
    combination of all (feature, degree) terms.

    Args:
        input_dim: size of the last input dimension.
        output_dim: size of the last output dimension.
        degree: highest polynomial degree (``degree + 1`` basis functions).

    Input ``(..., input_dim)`` -> output ``(..., output_dim)``. Leading
    dimensions are flattened for the computation and restored afterwards.
    """
    def __init__(self, input_dim, output_dim, degree):
        super(ChebyshevKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.ChebyshevDegree = degree
        # NOTE: `addbias` and `bias` are kept for state_dict compatibility, but
        # forward() does not add the bias.
        self.addbias  = True
        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        # `nn.init.normal` (no underscore) is a deprecated alias of `normal_`.
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1/(input_dim * (degree + 1)))
        self.bias = nn.Parameter(torch.zeros(1, output_dim))

    def forward(self, x):
        outshape = x.shape[:-1] + (self.outdim,)        # (..., outdim)
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, are accepted.
        x = torch.tanh(x.reshape(-1, self.inputdim)).to(self.cheby_coeffs.dtype)   # (N, inputdim)

        # Recurrence T_0 = 1, T_1 = x, T_n = 2 x T_{n-1} - T_{n-2}. Collecting the
        # terms in a list avoids in-place writes into a tensor autograd tracks.
        polys = [torch.ones_like(x)]
        if self.ChebyshevDegree > 0:
            polys.append(x)
        for _ in range(2, self.ChebyshevDegree + 1):
            polys.append(2 * x * polys[-1] - polys[-2])
        cheby = torch.stack(polys, dim=-1)              # (N, inputdim, degree + 1)

        # (N, inputdim, degree+1) x (inputdim, outdim, degree+1) -> (N, outdim)
        y = torch.einsum('bid,iod->bo', cheby, self.cheby_coeffs)

        return y.reshape(outshape)

In [None]:
# KAN layer built on Fourier basis functions
class FourierKANLayer(nn.Module):
    """Kolmogorov-Arnold layer with a Fourier (cos/sin) basis.

    Each input feature is squashed with ``tanh`` and expanded into
    ``cos(k x)`` and ``sin(k x)`` for k = 1 .. degree. Every output is a learned
    linear combination of all those terms.

    Args:
        input_dim: size of the last input dimension.
        output_dim: size of the last output dimension.
        degree: number of frequencies (stored as ``gridsize``).

    Input ``(..., input_dim)`` -> output ``(..., output_dim)``.
    """
    def __init__(self, input_dim, output_dim, degree):
        super(FourierKANLayer,self).__init__()
        self.gridsize = degree
        # NOTE: `addbias` and `bias` are kept for state_dict compatibility, but
        # forward() does not add the bias.
        self.addbias  = True
        self.inputdim = input_dim
        self.outdim = output_dim

        # fouriercoeffs[0] weights the cosine terms, fouriercoeffs[1] the sine terms.
        # The randn values are overwritten by normal_ below. The draw is kept so
        # the global RNG stream (and seeded init of later layers) is unchanged.
        self.fouriercoeffs = nn.Parameter(
                    torch.randn(2, output_dim, input_dim, degree) /
                    (np.sqrt(input_dim) * np.sqrt(self.gridsize))
                     )
        # `nn.init.normal` (no underscore) is a deprecated alias of `normal_`.
        nn.init.normal_(self.fouriercoeffs, mean=0.0, std=1 / (np.sqrt(input_dim) * np.sqrt(self.gridsize)))

        self.bias = nn.Parameter(torch.zeros(1, output_dim))

    def forward(self, x):
        outshape = x.shape[:-1] + (self.outdim,)        # (..., outdim)
        # reshape (not view): tanh keeps the input's strides, so a transposed input stays non-contiguous.
        x = torch.tanh(x).reshape(-1, self.inputdim)    # (N, inputdim)

        k = torch.arange(1, self.gridsize + 1, device=x.device, dtype=x.dtype)   # (gridsize,)
        angles = x.unsqueeze(-1) * k                    # (N, inputdim, gridsize)

        # einsum contracts inputdim and gridsize directly, instead of building an
        # (N, outdim, inputdim, gridsize) broadcast product first.
        y = (torch.einsum('nig,oig->no', torch.cos(angles), self.fouriercoeffs[0])
             + torch.einsum('nig,oig->no', torch.sin(angles), self.fouriercoeffs[1]))   # (N, outdim)

        return y.reshape(outshape)
        

In [None]:
# Attention mechanism
class Attention(nn.Module):
    """Unscaled dot-product attention followed by LayerNorm.

    Queries are projected from ``attention_from``, keys from ``attention_to``
    and values from ``attention_from`` (not from ``attention_to`` as in the
    usual formulation). Scores are not divided by sqrt(hidden_size).

    Args:
        hidden_size: size of the last dimension of both inputs and the output.
    """
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.query_layer = nn.Linear(hidden_size, hidden_size)
        self.key_layer   = nn.Linear(hidden_size, hidden_size)
        self.value_layer = nn.Linear(hidden_size, hidden_size)

        self.layer_norm  = nn.LayerNorm(hidden_size)

    def forward(self, attention_from, attention_to):
        """
        :param attention_from: (..., L, hidden_size)
        :param attention_to: (..., L, hidden_size). Must have the same length L as
            ``attention_from``, because the values come from ``attention_from``.
        :return: (..., L, hidden_size)
        """
        query = self.query_layer(attention_from)
        key = self.key_layer(attention_to)
        value = self.value_layer(attention_from)

        # transpose(-2, -1) rather than .T so batched (3-D+) inputs only swap the last two axes.
        scores = torch.matmul(query, key.transpose(-2, -1))     # (..., L, L)
        attention_weights = F.softmax(scores, dim=-1)

        output = torch.matmul(attention_weights, value)         # (..., L, hidden_size)
        return self.layer_norm(output)

# Model

In [None]:
class KANSA(nn.Module):
    """KAN + self-attention estimator of pressure, flow and demand from masked inputs.

    Pipeline (see ``forward``):
      1. per variable, a temporal encoder (Fourier KAN layers applied along the
         time axis) and a spatial encoder (Chebyshev KAN layers applied along
         the node/pipe axis), each followed by attention;
      2. the two views are fused by element-wise product and summarised;
      3. a small processor (``self.graph``) mixes the stacked pressure, flow
         and demand summaries;
      4. per-variable decoders. Their output is added to the input where the
         mask is 0; where the mask is 1 the input passes through unchanged
         (assuming a 0/1 mask).
    """
    def __init__(self, args):
        """
        :param args: configuration namespace. Reads ``batch_size`` (number of
            rows/time steps of one input window; the temporal layers operate
            on an axis of this length), ``ChebyshevDegree``, ``FourierDegree``,
            ``num_nodes``, ``num_edges`` and ``hidden_dim``.

        Demand is modelled on ``num_nodes - 1`` columns only (see the comment
        above ``demand_out``).
        """
        super().__init__()
        self.batch_size = args.batch_size
        self.ChebyshevDegree = args.ChebyshevDegree
        self.FourierDegree = args.FourierDegree

        self.num_nodes = args.num_nodes
        self.num_edges = args.num_edges
        self.hidden_dim = args.hidden_dim

        ''' input ''' 
        # NOTE: submodules are created in a fixed order; that order determines
        # seeded parameter initialisation and state_dict keys.
        # NOTE(review): each encoder starts with two nn.Linear layers and no
        # activation between them, which together form a single affine map.
        # Temporal features: applied to transposed inputs, so the last axis has length batch_size
        self.flow_batch = nn.Sequential(
            nn.Linear(self.batch_size, self.hidden_dim),  
            nn.Linear(self.hidden_dim, self.batch_size),
            FourierKANLayer(self.batch_size, self.hidden_dim, self.FourierDegree),  
            FourierKANLayer(self.hidden_dim, self.batch_size, self.FourierDegree),   
            nn.LayerNorm(self.batch_size),            
            )
        self.pressure_batch = nn.Sequential(
            nn.Linear(self.batch_size, self.hidden_dim), 
            nn.Linear(self.hidden_dim, self.batch_size),
            FourierKANLayer(self.batch_size, self.hidden_dim, self.FourierDegree), 
            FourierKANLayer(self.hidden_dim, self.batch_size, self.FourierDegree),   
            nn.LayerNorm(self.batch_size),    
            )
        self.demand_batch = nn.Sequential(
            nn.Linear(self.batch_size, self.hidden_dim), 
            nn.Linear(self.hidden_dim, self.batch_size),
            FourierKANLayer(self.batch_size, self.hidden_dim, self.FourierDegree),   
            FourierKANLayer(self.hidden_dim, self.batch_size, self.FourierDegree),   
            nn.LayerNorm(self.batch_size),      
            )
        # Spatial features: applied along the node (pressure/demand) or pipe (flow) axis
        self.flow_features = nn.Sequential(
            nn.Linear(self.num_edges, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.num_edges), 
            ChebyshevKANLayer(self.num_edges, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.num_edges, self.ChebyshevDegree), 
            nn.LayerNorm(self.num_edges),
            )
        self.pressure_features = nn.Sequential(
            nn.Linear(self.num_nodes, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.num_nodes),
            ChebyshevKANLayer(self.num_nodes, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.num_nodes, self.ChebyshevDegree), 
            nn.LayerNorm(self.num_nodes), 
            )
        self.demand_features = nn.Sequential(
            nn.Linear(self.num_nodes-1, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.num_nodes-1),
            ChebyshevKANLayer(self.num_nodes-1, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.num_nodes-1, self.ChebyshevDegree), 
            nn.LayerNorm(self.num_nodes-1), 
            )
        # Self-attention temporal
        self.attention_pressure_batch = Attention(self.batch_size)
        self.attention_flow_batch     = Attention(self.batch_size)
        self.attention_demand_batch   = Attention(self.batch_size)
        # Self-attention spatial
        self.attention_pressure_features = Attention(self.num_nodes)
        self.attention_flow_features     = Attention(self.num_edges)
        self.attention_demand_features   = Attention(self.num_nodes-1)

        # Per-view projections to hidden_dim (*_1: temporal view, *_2: spatial view)
        self.pressure_1 = nn.Sequential(
            nn.LayerNorm(self.num_nodes),
            ChebyshevKANLayer(self.num_nodes, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree)
            )
        self.pressure_2 = nn.Sequential(
            nn.LayerNorm(self.num_nodes),
            ChebyshevKANLayer(self.num_nodes, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree),
            )
        self.flow_1 = nn.Sequential(
            nn.LayerNorm(self.num_edges),
            ChebyshevKANLayer(self.num_edges, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree),
            )
        self.flow_2 = nn.Sequential(
            nn.LayerNorm(self.num_edges),
            ChebyshevKANLayer(self.num_edges, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree),
            )
        self.demand_1 = nn.Sequential(
            nn.LayerNorm(self.num_nodes-1),
            ChebyshevKANLayer(self.num_nodes-1, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree),
            )
        self.demand_2 = nn.Sequential(
            nn.LayerNorm(self.num_nodes-1),
            ChebyshevKANLayer(self.num_nodes-1, self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(self.hidden_dim, self.hidden_dim, self.ChebyshevDegree),
            )

        # Summaries of the fused view: hidden_dim -> 2*hidden_dim
        self.pressure_sum = nn.Sequential(
            nn.LayerNorm(self.hidden_dim),
            ChebyshevKANLayer(self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            )
        self.flow_sum = nn.Sequential(
            nn.LayerNorm(self.hidden_dim),
            ChebyshevKANLayer(self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            )
        self.demand_sum = nn.Sequential(
            nn.LayerNorm(self.hidden_dim),
            ChebyshevKANLayer(self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.hidden_dim, self.ChebyshevDegree),
            )

        # Processor: the last axis has size 3 (pressure, flow, demand summaries stacked in forward);
        # 12 is a hard-coded hidden width.
        self.graph = nn.Sequential(
            nn.LayerNorm(3),
            ChebyshevKANLayer(3, 12, self.ChebyshevDegree),
            ChebyshevKANLayer(12, 1, self.ChebyshevDegree),
            )

        # Decoding: 2*hidden_dim -> per-variable width, squashed to (0, 1) by Sigmoid
        self.flow_out = nn.Sequential(
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.num_edges, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.num_edges, self.num_edges, self.ChebyshevDegree),
            nn.Linear(self.num_edges, self.num_edges),
            nn.Linear(self.num_edges, self.num_edges),
            nn.Sigmoid(),
            )
        self.pressure_out = nn.Sequential(
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.num_nodes, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.num_nodes, self.num_nodes, self.ChebyshevDegree),
            nn.Linear(self.num_nodes, self.num_nodes),
            nn.Linear(self.num_nodes, self.num_nodes),
            nn.Sigmoid(),
            )
        # The last column of Demand is negative (only the previous columns are estimated)
        self.demand_out = nn.Sequential(
            ChebyshevKANLayer(2*self.hidden_dim, 2*self.num_nodes, self.ChebyshevDegree),
            ChebyshevKANLayer(2*self.num_nodes, self.num_nodes-1, self.ChebyshevDegree),
            nn.Linear(self.num_nodes-1, self.num_nodes-1),
            nn.Linear(self.num_nodes-1, self.num_nodes-1),
            nn.Sigmoid(),
            )

    def forward(self, Graph_Data, 
                mask_pressure, mask_flow, mask_demand,
                pressure, flow, demand) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Estimate pressure, flow and demand for one window.

        Shapes below use T = self.batch_size (rows of the window). The ``.T``
        transposes assume 2-D inputs.

        :param Graph_Data: not used by this method.
        :param mask_pressure: broadcastable to ``pressure``. Presumably 1 at sensor
            columns and 0 elsewhere, as built by ``apply_mask`` — TODO confirm at call site.
        :param mask_flow: same, for ``flow``.
        :param mask_demand: same, for ``demand``.
        :param pressure: (T, num_nodes)
        :param flow: (T, num_edges)
        :param demand: (T, num_nodes - 1)
        :return: (pressure_dd, flow_dd, demand_dd), shaped like the inputs.
        """

        pressure_ = pressure.float()     # (T, num_nodes)
        flow_ = flow.float()             # (T, num_edges)
        demand_ = demand.float()         # (T, num_nodes - 1)

        # Temporal: transpose so each node/pipe is a row of length T, then transpose back after attention
        pressure_batch      = self.pressure_batch(pressure_.T)     # (num_nodes, T)
        flow_batch          = self.flow_batch(flow_.T)             # (num_edges, T)
        demand_batch        = self.demand_batch(demand_.T)         # (num_nodes - 1, T)
        pressure_batch      = (self.attention_pressure_batch(pressure_batch, pressure_.T)).T       # (T, num_nodes)
        flow_batch          = (self.attention_flow_batch(flow_batch, flow_.T)).T                   # (T, num_edges)
        demand_batch        = (self.attention_demand_batch(demand_batch, demand_.T)).T             # (T, num_nodes - 1)
        # Spatial: each time step is a row
        pressure_features   = self.pressure_features(pressure_)    # (T, num_nodes)
        flow_features       = self.flow_features(flow_)            # (T, num_edges)
        demand_features     = self.demand_features(demand_)        # (T, num_nodes - 1)
        pressure_features   = self.attention_pressure_features(pressure_features, pressure_)       # (T, num_nodes)
        flow_features       = self.attention_flow_features(flow_features, flow_)                   # (T, num_edges)
        demand_features     = self.attention_demand_features(demand_features, demand_)             # (T, num_nodes - 1)

        # Encoder outputs are kept only where the mask is 0 and added to the input;
        # the temporal and spatial views are then fused by element-wise product.
        #  "*" is better  (author's note on the choice of fusion operator)
        pressure_batch_    = self.pressure_1(pressure_batch * (1-mask_pressure) + pressure_)
        pressure_features_ = self.pressure_2(pressure_features * (1-mask_pressure) + pressure_)
        pressure_st_       = pressure_batch_ * pressure_features_          # (T, hidden_dim)
        flow_batch_        = self.flow_1(flow_batch * (1-mask_flow) + flow_)
        flow_features_     = self.flow_2(flow_features * (1-mask_flow) + flow_)
        flow_st_           = flow_batch_ * flow_features_                  # (T, hidden_dim)
        demand_batch_      = self.demand_1(demand_batch * (1-mask_demand)  + demand_)
        demand_features_   = self.demand_2(demand_features * (1-mask_demand) + demand_)
        demand_st_         = demand_batch_ * demand_features_              # (T, hidden_dim)

        # Concat --> Processor: stack the three summaries on a new last axis
        Graph_features = torch.cat((self.pressure_sum(pressure_st_).unsqueeze(-1), 
                                    self.flow_sum(flow_st_).unsqueeze(-1), 
                                    self.demand_sum(demand_st_).unsqueeze(-1)), 
                                    dim=-1)                                # (T, 2*hidden_dim, 3)
        graph_features = self.graph(Graph_features).squeeze(-1)           # (T, 2*hidden_dim)

        # Decoding
        # Residual connection directly replaces the output: for a 0/1 mask, entries with mask 1
        # equal the input; entries with mask 0 are the decoded estimate plus the input
        # (presumably the caller zeroes unobserved inputs — TODO confirm).
        pressure_dd = self.pressure_out(graph_features) * (1-mask_pressure)   + pressure_ 
        flow_dd = self.flow_out(graph_features)         * (1-mask_flow)       + flow_ 
        demand_dd = self.demand_out(graph_features)     * (1-mask_demand)     + demand_

        return pressure_dd, flow_dd, demand_dd

# Main

In [None]:
def _minmax_restore(x, x_min, x_max):
    """Undo min-max normalization: map ``x`` from the normalized scale back to ``[x_min, x_max]``."""
    return x * (x_max - x_min) + x_min


def _append_balance_column(demand):
    """Append a last column equal to minus the row sum, so each row of the result sums to zero.

    The model is fed demand without its last column (``demandX[:, :-1]``); this
    rebuilds that column from the others.
    """
    return torch.cat((demand, -torch.sum(demand, dim=1, keepdim=True)), dim=1)


def _masked_forward(args, model, Graph_Data, data, bounds):
    """Randomly mask one loader item, run the model, and de-normalize truth and estimate.

    Args:
        args: experiment config, forwarded to ``apply_mask``.
        model: the KANSA network.
        Graph_Data: graph structure passed straight to ``model``.
        data: one loader item ``(pressureX, pressureY, flowX, flowY, demandX, demandY)``;
            only the ``X`` tensors are used.
        bounds: ``((min_p, max_p), (min_f, max_f), (min_d, max_d))`` used to undo
            min-max normalization.

    Returns:
        ``(truth, est)``, each a ``(pressure, flow, demand)`` tuple in de-normalized
        units. ``truth`` is the unmasked input; ``est`` is the model output (still
        attached to the autograd graph when grad is enabled).
    """
    pressureX, _, flowX, _, demandX, _ = data
    # The loaders in the main cell use batch_size=1; drop that leading dimension.
    pressureX = pressureX.squeeze(0)
    flowX = flowX.squeeze(0)
    demandX = demandX.squeeze(0)

    mask_p, mask_f, mask_d = apply_mask(args, pressureX, flowX, demandX)
    pressure_Y, flow_Y, demand_Y = model(Graph_Data, mask_p, mask_f, mask_d,
                                         pressureX * mask_p,
                                         flowX * mask_f,
                                         demandX[:, :-1] * mask_d)
    demand_Y = _append_balance_column(demand_Y)

    (min_p, max_p), (min_f, max_f), (min_d, max_d) = bounds
    truth = (_minmax_restore(pressureX, min_p, max_p),
             _minmax_restore(flowX, min_f, max_f),
             _minmax_restore(demandX, min_d, max_d))
    est = (_minmax_restore(pressure_Y, min_p, max_p),
           _minmax_restore(flow_Y, min_f, max_f),
           _minmax_restore(demand_Y, min_d, max_d))
    return truth, est


def train(args, IO,  train_loader, val_loader, 
        min_pressure, max_pressure, min_flow, max_flow, min_demand, max_demand):
    """Train KANSA on the conservation losses with early stopping and best-model restore.

    Each batch is randomly masked (``apply_mask``), reconstructed by the model, mapped
    back to physical units and scored with ``ConservationConstraints``. After every
    epoch the summed average validation loss decides whether the current weights are
    the best so far. Training stops after ``args.patience`` epochs without improvement.
    The best weights are then restored and the whole model is saved to
    ``outputs/<exp_name>/model.pth``.

    Returns:
        ``(pressure_True, pressure_Est, flow_True, flow_Est, demand_True, demand_Est,
        train_loss1_list, train_loss2_list, train_loss3_list,
        val_loss1_list, val_loss2_list, val_loss3_list)``.
        The True/Est tensors contain the last completed epoch's training samples
        followed by its validation samples, concatenated on dim 0; estimates are
        detached. Each loss list holds one value per epoch: summed loss divided by
        the dataset length.
    """
    import copy

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.manual_seed(args.seed)

    Graph_Data, pipe_friction, incidence_matrix, cycle_matrix = load_matrix(args)
    physics = (pipe_friction, incidence_matrix, cycle_matrix)
    bounds = ((min_pressure, max_pressure), (min_flow, max_flow), (min_demand, max_demand))

    # NOTE(review): the model goes to `device`, while the empty-result fallback below
    # uses `args.device`; confirm the two agree.
    model = KANSA(args).to(device)
    IO.cprint(str(model))
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    IO.cprint('Model Parameter: {}'.format(total_params))

    # RMSprop  (Trick)
    optimizer = optim.RMSprop(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.decay_epoch,
                                          gamma=args.gamma)
    IO.cprint('Using RMSprop')

    criterion = ConservationConstraints()

    train_loss1_list, train_loss2_list, train_loss3_list = [], [], []
    val_loss1_list, val_loss2_list, val_loss3_list = [], [], []

    best_val_loss = float('inf')
    patience = args.patience
    patience_counter = 0
    # state_dict() returns references to the live parameter tensors, which the
    # optimizer updates in place; a deep copy is required for a real snapshot.
    # Initialising here keeps load_state_dict() valid even if validation never
    # improves (e.g. NaN loss).
    best_model_wts = copy.deepcopy(model.state_dict())

    # (pressure, flow, demand) chunks for the current epoch; concatenated once at
    # the end instead of re-allocating with torch.cat on every batch.
    true_chunks = ([], [], [])
    est_chunks = ([], [], [])

    def _record(truth, est):
        """Store one batch's truth and estimates, detached so no autograd graph is retained."""
        for store, tensor in zip(true_chunks + est_chunks, truth + est):
            store.append(tensor.detach())

    start_time = time.time()
    for epoch in range(args.max_epoch):
        # Keep only the current epoch's samples (train first, then validation).
        for store in true_chunks + est_chunks:
            store.clear()

        #################
        ###   Train   ###
        #################
        model.train()
        train_sums = [0.0, 0.0, 0.0]
        for data in tqdm(train_loader, total=len(train_loader), desc="Train_Loader"):
            truth, est = _masked_forward(args, model, Graph_Data, data, bounds)
            loss1, loss2, loss3 = criterion(*est, *physics)
            _record(truth, est)

            loss = loss1 + loss2 + loss3
            # Optional L1/L2 weight penalties; the full-parameter passes are skipped
            # when the coefficient is zero.
            if args.l1_lambda:
                loss = loss + args.l1_lambda * sum(p.abs().sum() for p in model.parameters())
            if args.l2_lambda:
                loss = loss + args.l2_lambda * sum(p.pow(2.0).sum() for p in model.parameters())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            for k, part in enumerate((loss1, loss2, loss3)):
                train_sums[k] += part.item()

        scheduler.step()
        n_train = len(train_loader.dataset)
        avg_train_loss1, avg_train_loss2, avg_train_loss3 = (s / n_train for s in train_sums)
        train_loss1_list.append(avg_train_loss1)
        train_loss2_list.append(avg_train_loss2)
        train_loss3_list.append(avg_train_loss3)

        IO.cprint('Epoch #{:03d}, Conservation1: {:.4f}, Conservation2: {:.4f}, Conservation3: {:.4f}'.format(
                    epoch, avg_train_loss1, avg_train_loss2, avg_train_loss3
                    ))

        #################
        ###   Valid   ###
        #################
        # NOTE(review): validation also draws random masks, so the early-stopping
        # signal is stochastic; consider fixed masks for validation.
        model.eval()
        val_sums = [0.0, 0.0, 0.0]
        with torch.no_grad():
            for data in tqdm(val_loader, total=len(val_loader), desc="Val_Loader"):
                truth, est = _masked_forward(args, model, Graph_Data, data, bounds)
                loss1, loss2, loss3 = criterion(*est, *physics)
                _record(truth, est)
                for k, part in enumerate((loss1, loss2, loss3)):
                    val_sums[k] += part.item()

        n_val = len(val_loader.dataset)
        avg_val_loss1, avg_val_loss2, avg_val_loss3 = (s / n_val for s in val_sums)
        avg_val_loss = avg_val_loss1 + avg_val_loss2 + avg_val_loss3

        val_loss1_list.append(avg_val_loss1)
        val_loss2_list.append(avg_val_loss2)
        val_loss3_list.append(avg_val_loss3)

        IO.cprint('Epoch #{:03d}, Conservation1: {:.4f}, Conservation2: {:.4f}, Conservation3: {:.4f}'.format(
                epoch, avg_val_loss1, avg_val_loss2, avg_val_loss3
                ))

        # Track the best model by summed validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            best_model_wts = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        if patience_counter >= patience:
            IO.cprint('Early stopping triggered. Best Val_Loss: {:.4f}'.format(best_val_loss))
            break

    model.load_state_dict(best_model_wts)

    torch.save(model, 'outputs/%s/model.pth' % args.exp_name)
    IO.cprint('The current best model is saved in: {}'.format('******** outputs/%s/model.pth *********' % args.exp_name))
    end_time = time.time()
    IO.cprint('Total time: {:.4f}s'.format(end_time - start_time))

    def _stack(chunks):
        # An empty tensor on args.device when nothing was recorded (e.g. no epochs).
        return torch.cat(chunks, 0) if chunks else torch.Tensor().to(args.device)

    pressure_True, flow_True, demand_True = (_stack(c) for c in true_chunks)
    pressure_Est, flow_Est, demand_Est = (_stack(c) for c in est_chunks)

    return (pressure_True, pressure_Est, flow_True, flow_Est, demand_True, demand_Est,
            train_loss1_list, train_loss2_list, train_loss3_list, 
            val_loss1_list, val_loss2_list, val_loss3_list)

In [None]:
def _column_mae(pred, true):
    """Mean absolute error of each column, reducing over dim 0."""
    return torch.mean(torch.abs(pred - true), dim=0)


def _column_r2(pred, true):
    """Coefficient of determination (R^2) of each column, reducing over dim 0.

    Standard definition ``1 - SS_res / SS_tot``. ``SS_tot`` is measured around each
    column's mean of the *true* values. A column whose true values are constant has
    ``SS_tot == 0`` and yields inf/nan.
    """
    ss_res = torch.sum((true - pred) ** 2, dim=0)
    ss_tot = torch.sum((true - torch.mean(true, dim=0, keepdim=True)) ** 2, dim=0)
    return 1 - ss_res / ss_tot


def test(args, IO, test_loader, 
        min_pressure, max_pressure, min_flow, max_flow, min_demand, max_demand,
        window_size=3 * 96):
    """Test the saved model with sliding-window test-time adaptation.

    For each start index ``s`` in ``range(len(test_loader.dataset) - window_size)``,
    the samples ``[s, s + window_size)`` are replayed in order. Each sample is:
    randomly masked, reconstructed, scored with ``ConservationConstraints``, and
    used for one RMSprop step (lr 1e-7). The model is never reset, so adaptation
    carries over from window to window. For the last sample of each window, the
    following are recorded against the unmasked ground truth:
    ``ConservationConstraints_generalization`` losses, per-column MAE, and
    per-column R^2.

    Args:
        window_size: consecutive samples per window. The default ``3 * 96`` is three
            days at 96 steps per day (cf. the ``--num_step`` help text).

    Returns:
        ``(Loss1, Loss2, Loss3, MAE_pressure, MAE_flow, MAE_demand,
        R2_pressure, R2_flow, R2_demand)``: lists with one NumPy array per window.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    Graph_Data, pipe_friction, incidence_matrix, cycle_matrix = load_matrix(args)
    physics = (pipe_friction, incidence_matrix, cycle_matrix)
    model_path = 'outputs/%s/model.pth' % args.exp_name

    IO.cprint('')
    IO.cprint('********** TEST START **********')
    IO.cprint('Reload Best Model')
    IO.cprint('The current best model is saved in: {}'.format('******** outputs/%s/model.pth *********' % args.exp_name))

    # train() saves the whole pickled nn.Module, so weights_only must be False on
    # PyTorch >= 2.6 (where it defaults to True). Unpickling can execute arbitrary
    # code: only load checkpoints you produced yourself. map_location lets a model
    # saved on GPU load on a CPU-only machine.
    try:
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases have no ``weights_only`` argument.
        model = torch.load(model_path, map_location=device)
    # Train mode on purpose: the model keeps adapting during testing.
    model = model.to(device).train()

    optimizer = optim.RMSprop(model.parameters(), lr=1e-7)

    ##############################
    ### Test For Generalization ##
    ##############################
    criterion1 = ConservationConstraints()
    criterion2 = ConservationConstraints_generalization()

    Loss1, Loss2, Loss3 = [], [], []
    MAE_pressure, MAE_flow, MAE_demand = [], [], []
    R2_pressure, R2_flow, R2_demand = [], [], []

    # Subset indices refer to dataset samples, so bound them by the dataset length.
    num_samples = len(test_loader.dataset)
    for start_index in range(num_samples - window_size):
        end_index = min(start_index + window_size, num_samples)
        window = torch.utils.data.Subset(test_loader.dataset, range(start_index, end_index))
        window_loader = torch.utils.data.DataLoader(window, batch_size=1, shuffle=False)
        last_index = len(window_loader) - 1

        for i, data in tqdm(enumerate(window_loader), total=len(window_loader), desc="Test_Loader"):
            test_pressureX, _, test_flowX, _, test_demandX, _ = data
            test_pressureX = test_pressureX.squeeze(0)
            test_flowX     = test_flowX.squeeze(0)
            test_demandX   = test_demandX.squeeze(0)

            mask_p, mask_f, mask_d = apply_mask(args, test_pressureX, test_flowX, test_demandX)
            pressure_Y, flow_Y, demand_Y = model(Graph_Data, mask_p, mask_f, mask_d,
                                                 test_pressureX * mask_p,
                                                 test_flowX * mask_f,
                                                 test_demandX[:, :-1] * mask_d)
            # Rebuild the withheld last demand column so each row sums to zero.
            demand_Y = torch.cat((demand_Y, -torch.sum(demand_Y, dim=1, keepdim=True)), dim=1)

            # Undo min-max normalization on predictions and ground truth.
            pressure_Y     = pressure_Y     * (max_pressure - min_pressure) + min_pressure
            flow_Y         = flow_Y         * (max_flow     - min_flow)     + min_flow
            demand_Y       = demand_Y       * (max_demand   - min_demand)   + min_demand
            test_pressureX = test_pressureX * (max_pressure - min_pressure) + min_pressure
            test_flowX     = test_flowX     * (max_flow     - min_flow)     + min_flow
            test_demandX   = test_demandX   * (max_demand   - min_demand)   + min_demand

            # One unsupervised adaptation step on the conservation losses.
            loss1, loss2, loss3 = criterion1(pressure_Y, flow_Y, demand_Y, *physics)
            (loss1 + loss2 + loss3).backward()
            optimizer.step()
            optimizer.zero_grad()

            if i != last_index:
                continue

            # Metrics are recorded for the last sample of the window only.
            IO.cprint('Loss1: {:.4f}, Loss2: {:.4f}, Loss3: {:.4f}'.format(loss1.item(), loss2.item(), loss3.item()))
            gen1, gen2, gen3 = criterion2(pressure_Y, flow_Y, demand_Y, *physics)

            Loss1.append(gen1.detach().cpu().numpy())
            Loss2.append(gen2.detach().cpu().numpy())
            Loss3.append(gen3.detach().cpu().numpy())

            MAE_pressure.append(_column_mae(pressure_Y, test_pressureX).detach().cpu().numpy())
            MAE_flow.append(_column_mae(flow_Y, test_flowX).detach().cpu().numpy())
            MAE_demand.append(_column_mae(demand_Y, test_demandX).detach().cpu().numpy())

            R2_pressure.append(_column_r2(pressure_Y, test_pressureX).detach().cpu().numpy())
            R2_flow.append(_column_r2(flow_Y, test_flowX).detach().cpu().numpy())
            R2_demand.append(_column_r2(demand_Y, test_demandX).detach().cpu().numpy())

    return (Loss1, Loss2, Loss3,
            MAE_pressure, MAE_flow, MAE_demand,
            R2_pressure, R2_flow, R2_demand)

In [None]:
args = parse_args()


def exp_init():
    """Create the experiment output directory ``outputs/<args.exp_name>`` if missing.

    Reads the module-level ``args`` parsed just above. ``os.makedirs`` with
    ``exist_ok=True`` also creates the ``outputs`` parent. It avoids the
    check-then-create race of ``os.path.exists`` + ``os.mkdir``.
    """
    os.makedirs(os.path.join('outputs', args.exp_name), exist_ok=True)

In [None]:
if __name__ == '__main__':
    # Seed the Python, NumPy and torch RNGs so runs are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    exp_init()
    IO = IOStream('outputs/' + args.exp_name + '/run.log')
    IO.cprint(str(table_printer(args)))

    # Datasets and loaders: one sample per batch, in dataset order (no shuffling).
    train_dataset = CustomDataset(args, dataset_type='train')
    val_dataset = CustomDataset(args, dataset_type='val')
    test_dataset = CustomDataset(args, dataset_type='test')
    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True)

    # Global min/max of the raw CSVs; train()/test() use them to undo min-max normalization.
    NodePressure = pd.read_csv(args.water_pressure_file, header=0, index_col=0)
    PipeFlow = pd.read_csv(args.water_flow_file, header=0, index_col=0)
    NodeDemand = pd.read_csv(args.water_demand_file, header=0, index_col=0)
    max_pressure = NodePressure.max().max()
    min_pressure = NodePressure.min().min()
    max_flow = PipeFlow.max().max()
    min_flow = PipeFlow.min().min()
    max_demand = NodeDemand.max().max()
    # NOTE(review): the demand minimum excludes the last column while the maximum
    # includes it; presumably deliberate, confirm.
    min_demand = (NodeDemand.iloc[:, :-1]).min().min()

    # NOTE(review): demand bounds stay NumPy scalars; only pressure/flow become tensors.
    min_pressure = torch.tensor(min_pressure).to(args.device)
    max_pressure = torch.tensor(max_pressure).to(args.device)
    min_flow = torch.tensor(min_flow).to(args.device)
    max_flow = torch.tensor(max_flow).to(args.device)

    # TODO: when only evaluating an already-trained model, comment out this train() call.
    (pressure_True, pressure_Est, flow_True, flow_Est, demand_True, demand_Est,
     train_loss1_list, train_loss2_list, train_loss3_list,
     val_loss1_list, val_loss2_list, val_loss3_list) = train(
        args, IO, train_dataloader, val_dataloader,
        min_pressure, max_pressure, min_flow, max_flow, min_demand, max_demand)

    (Loss1, Loss2, Loss3,
     MAE_pressure, MAE_flow, MAE_demand,
     R2_pressure, R2_flow, R2_demand) = test(
        args, IO, test_dataloader,
        min_pressure, max_pressure, min_flow, max_flow, min_demand, max_demand)

+----------------------+-------------------------------------------------------+
|      Parameter       |                         Value                         |
| Chebyshevdegree      | 8                                                     |
+----------------------+-------------------------------------------------------+
| Fourierdegree        | 6                                                     |
+----------------------+-------------------------------------------------------+
| Batch size           | 96                                                    |
+----------------------+-------------------------------------------------------+
| Cycle file           | D:/GraphormerForRobustness/dataset2/initial_cycle_mat |
|                      | rix_dir.csv                                           |
+----------------------+-------------------------------------------------------+
| Decay epoch          | 1                                                     |
+----------------------+----

  nn.init.normal(self.fouriercoeffs, mean=0.0, std=1 / (np.sqrt(input_dim) * np.sqrt(self.gridsize)))
  nn.init.normal(self.cheby_coeffs, mean=0.0, std=1/(input_dim * (degree + 1)))


KANSA(
  (flow_batch): Sequential(
    (0): Linear(in_features=96, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=96, bias=True)
    (2): FourierKANLayer()
    (3): FourierKANLayer()
    (4): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (pressure_batch): Sequential(
    (0): Linear(in_features=96, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=96, bias=True)
    (2): FourierKANLayer()
    (3): FourierKANLayer()
    (4): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (demand_batch): Sequential(
    (0): Linear(in_features=96, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=96, bias=True)
    (2): FourierKANLayer()
    (3): FourierKANLayer()
    (4): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (flow_features): Sequential(
    (0): Linear(in_features=34, out_features=12, bias=True)
    (1): Linear(in_features=12, out_features=34, bias=True)
    (2): ChebyshevKANLayer()
   

Train_Loader: 100%|██████████| 4129/4129 [09:30<00:00,  7.24it/s]


Epoch #000, Conservation1: 467070.0138, Conservation2: 21254955.9058, Conservation3: 26147178.2699


Val_Loader: 100%|██████████| 1249/1249 [00:53<00:00, 23.19it/s]


Epoch #000, Conservation1: 1010.5197, Conservation2: 106.9387, Conservation3: 1163.9568


Train_Loader: 100%|██████████| 4129/4129 [10:18<00:00,  6.68it/s]


Epoch #001, Conservation1: 687.0553, Conservation2: 76.5890, Conservation3: 927.4044


Val_Loader: 100%|██████████| 1249/1249 [01:11<00:00, 17.58it/s]


Epoch #001, Conservation1: 246.6851, Conservation2: 33.5862, Conservation3: 361.0021


Train_Loader: 100%|██████████| 4129/4129 [10:35<00:00,  6.49it/s]


Epoch #002, Conservation1: 186.2768, Conservation2: 29.7768, Conservation3: 279.1882


Val_Loader: 100%|██████████| 1249/1249 [01:16<00:00, 16.37it/s]


Epoch #002, Conservation1: 103.8278, Conservation2: 29.3487, Conservation3: 176.9509


Train_Loader: 100%|██████████| 4129/4129 [10:27<00:00,  6.58it/s]


Epoch #003, Conservation1: 105.5992, Conservation2: 21.3788, Conservation3: 152.9279


Val_Loader: 100%|██████████| 1249/1249 [01:19<00:00, 15.68it/s]


Epoch #003, Conservation1: 88.3545, Conservation2: 33.7687, Conservation3: 190.7471


Train_Loader: 100%|██████████| 4129/4129 [10:52<00:00,  6.32it/s]


Epoch #004, Conservation1: 75.1331, Conservation2: 17.6548, Conservation3: 109.0948


Val_Loader: 100%|██████████| 1249/1249 [01:22<00:00, 15.09it/s]


Epoch #004, Conservation1: 67.0370, Conservation2: 24.5351, Conservation3: 129.4896


Train_Loader: 100%|██████████| 4129/4129 [11:01<00:00,  6.24it/s]


Epoch #005, Conservation1: 52.5229, Conservation2: 14.9855, Conservation3: 83.4094


Val_Loader: 100%|██████████| 1249/1249 [01:25<00:00, 14.56it/s]


Epoch #005, Conservation1: 39.6352, Conservation2: 19.3222, Conservation3: 87.7997


Train_Loader: 100%|██████████| 4129/4129 [11:20<00:00,  6.07it/s]


Epoch #006, Conservation1: 38.1092, Conservation2: 12.3963, Conservation3: 65.5024


Val_Loader: 100%|██████████| 1249/1249 [01:26<00:00, 14.47it/s]


Epoch #006, Conservation1: 34.4009, Conservation2: 10.4323, Conservation3: 82.8742


Train_Loader: 100%|██████████| 4129/4129 [11:20<00:00,  6.07it/s]


Epoch #007, Conservation1: 29.9445, Conservation2: 10.4762, Conservation3: 52.7569


Val_Loader: 100%|██████████| 1249/1249 [01:27<00:00, 14.32it/s]


Epoch #007, Conservation1: 23.2165, Conservation2: 14.3018, Conservation3: 55.7260


Train_Loader: 100%|██████████| 4129/4129 [11:27<00:00,  6.00it/s]


Epoch #008, Conservation1: 24.6463, Conservation2: 9.1361, Conservation3: 43.6693


Val_Loader: 100%|██████████| 1249/1249 [01:29<00:00, 13.94it/s]


Epoch #008, Conservation1: 26.8253, Conservation2: 7.9837, Conservation3: 61.0921


Train_Loader: 100%|██████████| 4129/4129 [11:20<00:00,  6.07it/s]


Epoch #009, Conservation1: 20.7337, Conservation2: 8.1387, Conservation3: 37.1652


Val_Loader: 100%|██████████| 1249/1249 [01:30<00:00, 13.87it/s]


Epoch #009, Conservation1: 14.4570, Conservation2: 8.2727, Conservation3: 31.8875


Train_Loader: 100%|██████████| 4129/4129 [11:28<00:00,  6.00it/s]


Epoch #010, Conservation1: 17.5925, Conservation2: 7.2959, Conservation3: 32.1016


Val_Loader: 100%|██████████| 1249/1249 [01:32<00:00, 13.55it/s]


Epoch #010, Conservation1: 12.0689, Conservation2: 8.1481, Conservation3: 29.0095


Train_Loader: 100%|██████████| 4129/4129 [11:37<00:00,  5.92it/s]


Epoch #011, Conservation1: 15.1556, Conservation2: 6.5509, Conservation3: 27.9540


Val_Loader: 100%|██████████| 1249/1249 [01:36<00:00, 12.98it/s]


Epoch #011, Conservation1: 20.2402, Conservation2: 25.2246, Conservation3: 69.1549


Train_Loader: 100%|██████████| 4129/4129 [11:27<00:00,  6.00it/s]


Epoch #012, Conservation1: 13.1669, Conservation2: 5.9591, Conservation3: 24.6002


Val_Loader: 100%|██████████| 1249/1249 [01:39<00:00, 12.57it/s]


Epoch #012, Conservation1: 11.5917, Conservation2: 8.0640, Conservation3: 35.9489


Train_Loader: 100%|██████████| 4129/4129 [11:32<00:00,  5.96it/s]


Epoch #013, Conservation1: 11.5968, Conservation2: 5.3600, Conservation3: 21.7516


Val_Loader: 100%|██████████| 1249/1249 [01:35<00:00, 13.02it/s]


Epoch #013, Conservation1: 11.1703, Conservation2: 3.3520, Conservation3: 28.5669


Train_Loader: 100%|██████████| 4129/4129 [11:30<00:00,  5.98it/s]


Epoch #014, Conservation1: 10.2217, Conservation2: 4.8976, Conservation3: 19.3402


Val_Loader: 100%|██████████| 1249/1249 [01:38<00:00, 12.68it/s]


Epoch #014, Conservation1: 10.7397, Conservation2: 3.2959, Conservation3: 28.5758


Train_Loader: 100%|██████████| 4129/4129 [11:40<00:00,  5.89it/s]


Epoch #015, Conservation1: 9.1658, Conservation2: 4.4647, Conservation3: 17.3035


Val_Loader: 100%|██████████| 1249/1249 [01:39<00:00, 12.60it/s]


Epoch #015, Conservation1: 8.9053, Conservation2: 2.5052, Conservation3: 22.0833


Train_Loader: 100%|██████████| 4129/4129 [11:34<00:00,  5.94it/s]


Epoch #016, Conservation1: 8.1914, Conservation2: 4.0885, Conservation3: 15.5921


Val_Loader: 100%|██████████| 1249/1249 [01:39<00:00, 12.55it/s]


Epoch #016, Conservation1: 9.1846, Conservation2: 2.9467, Conservation3: 21.7363


Train_Loader: 100%|██████████| 4129/4129 [11:44<00:00,  5.86it/s]


Epoch #017, Conservation1: 7.4068, Conservation2: 3.7464, Conservation3: 14.0904


Val_Loader: 100%|██████████| 1249/1249 [01:41<00:00, 12.35it/s]


Epoch #017, Conservation1: 7.9357, Conservation2: 2.3411, Conservation3: 22.6247


Train_Loader: 100%|██████████| 4129/4129 [11:51<00:00,  5.81it/s]


Epoch #018, Conservation1: 6.7235, Conservation2: 3.4487, Conservation3: 12.7919


Val_Loader: 100%|██████████| 1249/1249 [01:41<00:00, 12.31it/s]


Epoch #018, Conservation1: 6.0711, Conservation2: 4.5482, Conservation3: 18.8075


Train_Loader: 100%|██████████| 4129/4129 [11:23<00:00,  6.04it/s]


Epoch #019, Conservation1: 6.1273, Conservation2: 3.1757, Conservation3: 11.6746


Val_Loader: 100%|██████████| 1249/1249 [01:33<00:00, 13.37it/s]


Epoch #019, Conservation1: 6.7144, Conservation2: 2.1694, Conservation3: 20.4767


Train_Loader: 100%|██████████| 4129/4129 [12:01<00:00,  5.72it/s]


Epoch #020, Conservation1: 5.5977, Conservation2: 2.9332, Conservation3: 10.6799


Val_Loader: 100%|██████████| 1249/1249 [01:46<00:00, 11.73it/s]


Epoch #020, Conservation1: 5.2046, Conservation2: 3.6102, Conservation3: 14.4925


Train_Loader: 100%|██████████| 4129/4129 [12:03<00:00,  5.71it/s]


Epoch #021, Conservation1: 5.1211, Conservation2: 2.7150, Conservation3: 9.8339


Val_Loader: 100%|██████████| 1249/1249 [01:46<00:00, 11.76it/s]


Epoch #021, Conservation1: 4.5113, Conservation2: 2.5998, Conservation3: 13.3247


Train_Loader: 100%|██████████| 4129/4129 [12:00<00:00,  5.73it/s]


Epoch #022, Conservation1: 4.6994, Conservation2: 2.5126, Conservation3: 9.0403


Val_Loader: 100%|██████████| 1249/1249 [01:45<00:00, 11.88it/s]


Epoch #022, Conservation1: 5.4636, Conservation2: 1.8151, Conservation3: 16.6655


Train_Loader: 100%|██████████| 4129/4129 [12:00<00:00,  5.73it/s]


Epoch #023, Conservation1: 4.3470, Conservation2: 2.3387, Conservation3: 8.3522


Val_Loader: 100%|██████████| 1249/1249 [01:52<00:00, 11.12it/s]


Epoch #023, Conservation1: 5.5478, Conservation2: 2.0895, Conservation3: 16.1253


Train_Loader: 100%|██████████| 4129/4129 [12:06<00:00,  5.68it/s]


Epoch #024, Conservation1: 4.0214, Conservation2: 2.1788, Conservation3: 7.7355


Val_Loader: 100%|██████████| 1249/1249 [01:49<00:00, 11.36it/s]


Epoch #024, Conservation1: 4.9514, Conservation2: 1.8129, Conservation3: 14.0920


Train_Loader: 100%|██████████| 4129/4129 [12:15<00:00,  5.62it/s]


Epoch #025, Conservation1: 3.7057, Conservation2: 2.0260, Conservation3: 7.1485


Val_Loader: 100%|██████████| 1249/1249 [01:50<00:00, 11.26it/s]


Epoch #025, Conservation1: 3.1344, Conservation2: 1.2158, Conservation3: 6.6725


Train_Loader: 100%|██████████| 4129/4129 [11:55<00:00,  5.77it/s]


Epoch #026, Conservation1: 3.4354, Conservation2: 1.8884, Conservation3: 6.6524


Val_Loader: 100%|██████████| 1249/1249 [01:45<00:00, 11.82it/s]


Epoch #026, Conservation1: 3.1923, Conservation2: 3.0224, Conservation3: 10.5638


Train_Loader: 100%|██████████| 4129/4129 [11:58<00:00,  5.75it/s]


Epoch #027, Conservation1: 3.2123, Conservation2: 1.7678, Conservation3: 6.1874


Val_Loader: 100%|██████████| 1249/1249 [01:44<00:00, 11.96it/s]


Epoch #027, Conservation1: 3.3467, Conservation2: 2.8953, Conservation3: 8.7506


Train_Loader: 100%|██████████| 4129/4129 [12:27<00:00,  5.53it/s]


Epoch #028, Conservation1: 2.9918, Conservation2: 1.6430, Conservation3: 5.7696


Val_Loader: 100%|██████████| 1249/1249 [01:51<00:00, 11.15it/s]


Epoch #028, Conservation1: 3.6497, Conservation2: 1.2954, Conservation3: 10.3807


Train_Loader: 100%|██████████| 4129/4129 [12:20<00:00,  5.58it/s]


Epoch #029, Conservation1: 2.7920, Conservation2: 1.5426, Conservation3: 5.3939


Val_Loader: 100%|██████████| 1249/1249 [01:52<00:00, 11.06it/s]


Epoch #029, Conservation1: 2.6947, Conservation2: 1.3222, Conservation3: 7.1633


Train_Loader: 100%|██████████| 4129/4129 [12:16<00:00,  5.61it/s]


Epoch #030, Conservation1: 2.6131, Conservation2: 1.4435, Conservation3: 5.0478


Val_Loader: 100%|██████████| 1249/1249 [01:50<00:00, 11.25it/s]


Epoch #030, Conservation1: 3.0140, Conservation2: 1.9697, Conservation3: 6.9055


Train_Loader: 100%|██████████| 4129/4129 [12:49<00:00,  5.37it/s]


Epoch #031, Conservation1: 2.4541, Conservation2: 1.3508, Conservation3: 4.7430


Val_Loader: 100%|██████████| 1249/1249 [02:00<00:00, 10.37it/s]


Epoch #031, Conservation1: 3.0627, Conservation2: 1.2191, Conservation3: 7.4430


Train_Loader: 100%|██████████| 4129/4129 [12:28<00:00,  5.52it/s]


Epoch #032, Conservation1: 2.3230, Conservation2: 1.2692, Conservation3: 4.4608


Val_Loader: 100%|██████████| 1249/1249 [01:54<00:00, 10.90it/s]


Epoch #032, Conservation1: 3.1592, Conservation2: 1.1957, Conservation3: 8.8061


Train_Loader: 100%|██████████| 4129/4129 [12:42<00:00,  5.42it/s]


Epoch #033, Conservation1: 2.1863, Conservation2: 1.1939, Conservation3: 4.1966


Val_Loader: 100%|██████████| 1249/1249 [02:00<00:00, 10.36it/s]


Epoch #033, Conservation1: 2.3767, Conservation2: 1.9048, Conservation3: 8.3630


Train_Loader: 100%|██████████| 4129/4129 [13:41<00:00,  5.03it/s]


Epoch #034, Conservation1: 2.0882, Conservation2: 1.1212, Conservation3: 3.9568


Val_Loader: 100%|██████████| 1249/1249 [02:08<00:00,  9.73it/s]


Epoch #034, Conservation1: 2.4225, Conservation2: 0.9640, Conservation3: 6.5826


Train_Loader: 100%|██████████| 4129/4129 [13:18<00:00,  5.17it/s]


Epoch #035, Conservation1: 1.9784, Conservation2: 1.0584, Conservation3: 3.7499


Val_Loader: 100%|██████████| 1249/1249 [02:06<00:00,  9.91it/s]


Epoch #035, Conservation1: 2.4547, Conservation2: 1.4049, Conservation3: 6.0125


Train_Loader: 100%|██████████| 4129/4129 [13:22<00:00,  5.15it/s]


Epoch #036, Conservation1: 1.8811, Conservation2: 0.9978, Conservation3: 3.5478


Val_Loader: 100%|██████████| 1249/1249 [02:10<00:00,  9.60it/s]


Epoch #036, Conservation1: 2.4880, Conservation2: 1.0737, Conservation3: 6.8862


Train_Loader: 100%|██████████| 4129/4129 [13:37<00:00,  5.05it/s]


Epoch #037, Conservation1: 1.8072, Conservation2: 0.9433, Conservation3: 3.3762


Val_Loader: 100%|██████████| 1249/1249 [02:10<00:00,  9.56it/s]


Epoch #037, Conservation1: 2.5185, Conservation2: 1.1325, Conservation3: 6.9087


Train_Loader: 100%|██████████| 4129/4129 [13:42<00:00,  5.02it/s]


Epoch #038, Conservation1: 1.7225, Conservation2: 0.8951, Conservation3: 3.2076


Val_Loader: 100%|██████████| 1249/1249 [02:13<00:00,  9.36it/s]


Epoch #038, Conservation1: 2.1084, Conservation2: 1.1501, Conservation3: 5.0422


Train_Loader: 100%|██████████| 4129/4129 [13:36<00:00,  5.06it/s]


Epoch #039, Conservation1: 1.6555, Conservation2: 0.8486, Conservation3: 3.0537


Val_Loader: 100%|██████████| 1249/1249 [02:11<00:00,  9.47it/s]


Epoch #039, Conservation1: 2.1756, Conservation2: 0.8974, Conservation3: 6.1581


Train_Loader: 100%|██████████| 4129/4129 [14:05<00:00,  4.88it/s]


Epoch #040, Conservation1: 1.5885, Conservation2: 0.8070, Conservation3: 2.9132


Val_Loader: 100%|██████████| 1249/1249 [02:16<00:00,  9.13it/s]


Epoch #040, Conservation1: 1.7892, Conservation2: 1.0592, Conservation3: 4.1988


Train_Loader: 100%|██████████| 4129/4129 [13:54<00:00,  4.95it/s]


Epoch #041, Conservation1: 1.5272, Conservation2: 0.7671, Conservation3: 2.7880


Val_Loader: 100%|██████████| 1249/1249 [02:16<00:00,  9.18it/s]


Epoch #041, Conservation1: 1.6086, Conservation2: 0.9141, Conservation3: 4.2117


Train_Loader: 100%|██████████| 4129/4129 [14:37<00:00,  4.70it/s]


Epoch #042, Conservation1: 1.4737, Conservation2: 0.7321, Conservation3: 2.6667


Val_Loader: 100%|██████████| 1249/1249 [02:25<00:00,  8.61it/s] 


Epoch #042, Conservation1: 1.6590, Conservation2: 0.9568, Conservation3: 3.9149


Train_Loader: 100%|██████████| 4129/4129 [14:15<00:00,  4.83it/s]


Epoch #043, Conservation1: 1.4098, Conservation2: 0.6984, Conservation3: 2.5535


Val_Loader: 100%|██████████| 1249/1249 [02:22<00:00,  8.77it/s] 


Epoch #043, Conservation1: 1.7314, Conservation2: 0.9075, Conservation3: 4.3346


Train_Loader: 100%|██████████| 4129/4129 [14:10<00:00,  4.86it/s]


Epoch #044, Conservation1: 1.3638, Conservation2: 0.6688, Conservation3: 2.4516


Val_Loader: 100%|██████████| 1249/1249 [02:20<00:00,  8.91it/s]


Epoch #044, Conservation1: 1.6349, Conservation2: 0.8846, Conservation3: 4.3793


Train_Loader: 100%|██████████| 4129/4129 [14:07<00:00,  4.87it/s]


Epoch #045, Conservation1: 1.3180, Conservation2: 0.6410, Conservation3: 2.3573


Val_Loader: 100%|██████████| 1249/1249 [02:20<00:00,  8.90it/s]


Epoch #045, Conservation1: 1.4376, Conservation2: 0.9905, Conservation3: 3.8973


Train_Loader: 100%|██████████| 4129/4129 [14:29<00:00,  4.75it/s]


Epoch #046, Conservation1: 1.2832, Conservation2: 0.6141, Conservation3: 2.2726


Val_Loader: 100%|██████████| 1249/1249 [02:23<00:00,  8.73it/s] 


Epoch #046, Conservation1: 1.5107, Conservation2: 0.9865, Conservation3: 4.6904


Train_Loader: 100%|██████████| 4129/4129 [14:03<00:00,  4.89it/s]


Epoch #047, Conservation1: 1.2418, Conservation2: 0.5899, Conservation3: 2.1913


Val_Loader: 100%|██████████| 1249/1249 [02:18<00:00,  9.01it/s]


Epoch #047, Conservation1: 1.4718, Conservation2: 0.8118, Conservation3: 3.6292


Train_Loader: 100%|██████████| 4129/4129 [14:36<00:00,  4.71it/s]


Epoch #048, Conservation1: 1.2041, Conservation2: 0.5677, Conservation3: 2.1149


Val_Loader: 100%|██████████| 1249/1249 [02:23<00:00,  8.69it/s] 


Epoch #048, Conservation1: 1.4989, Conservation2: 0.6284, Conservation3: 3.8995


Train_Loader: 100%|██████████| 4129/4129 [14:44<00:00,  4.67it/s]


Epoch #049, Conservation1: 1.1759, Conservation2: 0.5475, Conservation3: 2.0475


Val_Loader:   0%|          | 0/1249 [00:00<?, ?it/s]

In [None]:
# Destination folder for every CSV exported further down the notebook.
# `exist_ok=True` keeps this cell safe to re-run (Restart & Run All).
output_dir = '/'.join(['outputs', 'DataManage'])
os.makedirs(output_dir, mode=0o777, exist_ok=True)

In [None]:
# # Export the true/estimated data to CSV files
# save_pressure_True = pressure_True.cpu().detach().numpy().astype(float)
# save_pressure_Est = pressure_Est.cpu().detach().numpy().astype(float)
# save_flow_True = flow_True.cpu().detach().numpy().astype(float)
# save_flow_Est = flow_Est.cpu().detach().numpy().astype(float)
# save_demand_True = demand_True.cpu().detach().numpy().astype(float)
# save_demand_Est = demand_Est.cpu().detach().numpy().astype(float)

# np.savetxt(os.path.join(output_dir, 'pressure_True.csv'), save_pressure_True, delimiter=',')
# np.savetxt(os.path.join(output_dir, 'pressure_Est.csv'), save_pressure_Est, delimiter=',')
# np.savetxt(os.path.join(output_dir, 'flow_True.csv'), save_flow_True, delimiter=',')  
# np.savetxt(os.path.join(output_dir, 'flow_Est.csv'), save_flow_Est, delimiter=',')
# np.savetxt(os.path.join(output_dir, 'demand_True.csv'), save_demand_True, delimiter=',')
# np.savetxt(os.path.join(output_dir, 'demand_Est.csv'), save_demand_Est, delimiter=',')

In [None]:
# Save the per-run generalization metrics (conservation losses 1-3, and MAE / R2
# for pressure, flow and demand) as one CSV per metric under `output_dir`.
# NOTE(review): assumes each metric variable is a list/sequence accumulated by the
# training/validation loop in an earlier cell — confirm against that cell.
# The globals are rebound to DataFrames (as before) so later cells see the same types.
MAE_pressure = pd.DataFrame(MAE_pressure)
MAE_flow = pd.DataFrame(MAE_flow)
MAE_demand = pd.DataFrame(MAE_demand)

R2_pressure = pd.DataFrame(R2_pressure)
R2_flow = pd.DataFrame(R2_flow)
R2_demand = pd.DataFrame(R2_demand)

Loss1 = pd.DataFrame(Loss1)
Loss2 = pd.DataFrame(Loss2)
Loss3 = pd.DataFrame(Loss3)

# File stem -> DataFrame; written in the same order as before.
# The conservation files were previously misspelled "Coservation*"; they now
# match the "Conservation" metric names printed in the training log.
generalization_metrics = {
    'Conservation1': Loss1,
    'Conservation2': Loss2,
    'Conservation3': Loss3,
    'MAE_pressure': MAE_pressure,
    'MAE_flow': MAE_flow,
    'MAE_demand': MAE_demand,
    'R2_pressure': R2_pressure,
    'R2_flow': R2_flow,
    'R2_demand': R2_demand,
}
for metric_name, metric_df in generalization_metrics.items():
    metric_df.to_csv(os.path.join(output_dir, f'{metric_name}_generalization.csv'), index=False)