In [601]:
import pandas as pd
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import networkx as nx
import time
from random import shuffle
import math
import torch
import torch.nn as nn
import torch_geometric as tg
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops, degree
import torch.autograd as autograd
from torch.nn import init
import pdb
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.data import Data
import torch.optim as optim
from IPython.display import clear_output
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
import matplotlib.ticker as mticker
from torch.autograd import Function
import pytorch_optimizer as p_optim

In [602]:
def open_data(file_path):
    """Load and return the object stored in a pickle file.

    Args:
        file_path: Path of the pickle file to read.

    Returns:
        The unpickled Python object.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the content is not a valid pickle.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    # The context manager closes the handle even when unpickling fails.
    with open(file_path, "rb") as file:
        return pickle.load(file)

In [603]:
class NonLinearOperate(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dimen: Size of the last axis of the input.
        hidden_dimen: Width of the hidden layer.
        output_dimen: Size of the last axis of the output.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen):
        super().__init__()
        # Layers are created in the order layer_1, layer_2 and keep PyTorch's
        # default initialisation.
        self.layer_1 = nn.Linear(input_dimen, hidden_dimen)
        self.layer_2 = nn.Linear(hidden_dimen, output_dimen)
        self.acti_func = nn.ReLU()

    def forward(self, x):
        """Apply the perceptron to the last axis of ``x``."""
        hidden = self.acti_func(self.layer_1(x))
        return self.layer_2(hidden)

In [604]:
class PGNN_Layer(nn.Module):
    """Position-aware message passing layer driven by anchor distances.

    For every node the layer gathers the features of its M anchor nodes,
    scales them by a learned function of the node-anchor distance, concatenates
    the node's own features and mixes the result with a linear layer.

    Args:
        input_dimen: Feature size of the incoming node features.
        output_dimen: Feature size of the per-anchor messages / structure output.
        max_ach_num: Number of anchors per node (M). ``forward`` requires
            ``dists_max.shape[1] == max_ach_num`` because ``linear_out_position``
            is applied across the anchor axis.
    """

    def __init__(self, input_dimen, output_dimen, max_ach_num):
        super(PGNN_Layer, self).__init__()
        self.input_dimen = input_dimen
        self.output_dimen = output_dimen
        # Learned scalar -> scalar transform applied to every distance value.
        self.distance_calculate = NonLinearOperate(1, output_dimen, 1)
        self.acti_func = nn.ReLU()
        self.linear_hidden = nn.Linear(2*input_dimen, output_dimen)
        self.out_transition = nn.Linear(output_dimen,1)
        self.linear_out_position = nn.Linear(max_ach_num,input_dimen)

    def forward(self, node_features, dists_max, dists_argmax):
        """Compute position-aware and structure-aware node embeddings.

        Args:
            node_features: Tensor of shape (N, input_dimen).
            dists_max: Tensor of shape (N, M) with node-to-anchor distance scores.
            dists_argmax: Integer tensor of shape (N, M) holding, for each
                (node, anchor) pair, the row of ``node_features`` to gather.

        Returns:
            Tuple ``(output_position, output_structure)`` with shapes
            (N, input_dimen) and (N, output_dimen).
        """
        # (N, M) -> (N, M, 1) -> MLP -> (N, M). Only the trailing axis is
        # squeezed so that graphs with a single node or a single anchor keep
        # their (N, M) layout.
        dists_max = self.distance_calculate(dists_max.unsqueeze(-1)).squeeze(-1)
        subset_features = node_features[dists_argmax.flatten(), :]
        subset_features = subset_features.reshape(dists_argmax.shape[0], dists_argmax.shape[1], subset_features.shape[1])
        messages = subset_features * dists_max.unsqueeze(-1)
        feature_self = node_features.unsqueeze(1).repeat(1, dists_max.shape[1], 1)
        # (N, M, 2 * input_dimen): anchor message next to the node's own features.
        messages = torch.cat((messages, feature_self), dim=-1)
        # (N, M, output_dimen); no squeeze here, so output_dimen == 1 also works.
        messages = self.acti_func(self.linear_hidden(messages))
        # (N, M, output_dimen) -> (N, M)
        output_transition = self.out_transition(messages).squeeze(-1)
        output_position = self.linear_out_position(output_transition)
        # Average over the anchor axis -> (N, output_dimen)
        output_structure = torch.mean(messages, dim=1)

        return output_position, output_structure

In [605]:
class PGNN(nn.Module):
    """Stack of ``PGNN_Layer`` blocks that maps node features back to input size.

    Args:
        input_dimen: Feature size of the raw node features (also the output size).
        hidden_dimen: Feature size after the input projection.
        output_dimen: Structure-output size of the single / final PGNN layer.
        max_ach_num: Number of anchors per node expected by every layer.
        layer_num: Number of PGNN layers applied in ``forward``.
        drop_out: Whether dropout is applied to the structure output of each layer.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = 1, drop_out = True):
        super(PGNN, self).__init__()
        self.drop_out = drop_out
        self.layer_num = layer_num
        self.input_layer = nn.Linear(input_dimen, hidden_dimen)
        self.last_layer = nn.Linear(hidden_dimen, input_dimen)
        self.max_ach_num = max_ach_num
        if self.layer_num == 1:
            self.gnn_operate_1 = PGNN_Layer(hidden_dimen, output_dimen, max_ach_num)

        if self.layer_num > 1:
            self.gnn_hidden = nn.ModuleList([PGNN_Layer(hidden_dimen, hidden_dimen, max_ach_num) for i in range(0, layer_num)])
            # max_ach_num is a required argument of PGNN_Layer; omitting it made
            # every multi-layer model fail at construction time.
            # NOTE(review): this layer is registered but not used by forward().
            self.gnn_output_layer = PGNN_Layer(hidden_dimen, output_dimen, max_ach_num)

    def forward(self, x, dist_max_sets, dist_argmax_sets):
        """Run the PGNN stack.

        Args:
            x: Node features of shape (N, input_dimen).
            dist_max_sets: Anchor distances indexed as ``[layer, node, anchor]``.
            dist_argmax_sets: Anchor indices indexed as ``[layer, node, anchor]``.

        Returns:
            Position-aware embeddings of shape (N, input_dimen), or ``None``
            when ``layer_num`` is smaller than 1.
        """
        x = self.input_layer(x)
        if self.layer_num == 1:
            x_position, x = self.gnn_operate_1(x, dist_max_sets[0,:,:], dist_argmax_sets[0,:,:])
            # The structure output is not part of the result; the dropout call
            # is kept so the random stream during training is unchanged.
            if self.drop_out:
                x = F.dropout(x, training=self.training)
            x_position = self.last_layer(x_position)
            return x_position

        if self.layer_num > 1:
            for i in range(self.layer_num):
                # The structure output feeds the next layer; only the position
                # output of the last layer is used for the result.
                x_position, x = self.gnn_hidden[i](x, dist_max_sets[i,:,:], dist_argmax_sets[i,:,:])
                if self.drop_out:
                    x = F.dropout(x, training=self.training)

            x_position = F.normalize(x_position, p=2, dim=-1)
            x_position = self.last_layer(x_position)

            return x_position

In [606]:
class P_GCN(nn.Module):
    """Residual stack of (PGNN_Layer -> GCNConv) blocks with a linear read-out.

    Args:
        input_dimen: Feature size of the raw node features.
        hidden_dimen: Feature size used inside every block.
        output_dimen: Feature size of the returned node embeddings.
        max_ach_num: Number of anchors per node expected by the PGNN layer.
        layer_num: Number of block applications in ``forward``.
        drop_out: Whether dropout follows the PGNN layer and the graph convolution.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = 1, drop_out = True):
        super().__init__()
        self.max_ach_num = max_ach_num
        self.layer_num = layer_num
        self.drop_out = drop_out
        self.input_layer = nn.Linear(input_dimen, hidden_dimen)
        # nn.Sequential is only used as a container here: forward() indexes its
        # two members because they take different arguments.
        position_layer = PGNN_Layer(hidden_dimen, hidden_dimen, max_ach_num)
        graph_conv = GCNConv(hidden_dimen, hidden_dimen, add_self_loops=True)
        self.p_gcn_block = nn.Sequential(position_layer, graph_conv)
        self.acti_func = nn.ReLU()

        if layer_num == 1:
            self.gcn_p_layers = self.p_gcn_block
        if layer_num > 1:
            # NOTE(review): the same block instance is repeated, so all layers
            # share one set of weights. Creating independent blocks would change
            # the parameter count and invalidate existing optimizer checkpoints.
            self.gcn_p_layers = nn.ModuleList([self.p_gcn_block] * layer_num)

        self.output_layer = nn.Linear(hidden_dimen, output_dimen)

    def _maybe_dropout(self, tensor):
        """Apply default-rate dropout when ``drop_out`` is enabled."""
        if self.drop_out:
            return F.dropout(tensor, training=self.training)
        return tensor

    def forward(self, x, edge_index, dist_max, dist_argmax):
        """Embed the nodes of one graph snapshot.

        Args:
            x: Node features of shape (N, input_dimen).
            edge_index: Edge index forwarded to the graph convolution.
            dist_max: Anchor distances indexed as ``[layer, node, anchor]``.
            dist_argmax: Anchor indices indexed as ``[layer, node, anchor]``.

        Returns:
            Non-negative node embeddings of shape (N, output_dimen).
        """
        x_ = self.input_layer(x)

        if self.layer_num == 1:
            blocks = [self.gcn_p_layers]
        else:
            blocks = [self.gcn_p_layers[i] for i in range(self.layer_num)]

        for i, block in enumerate(blocks):
            x_position, _ = block[0](x_, dist_max[i,:,:], dist_argmax[i,:,:])
            x_position = self._maybe_dropout(x_position)
            x = self._maybe_dropout(block[1](x_position, edge_index))
            # Residual connection around the block, then feed the next block.
            x = self.acti_func(x + x_)
            x_ = x

        return self.acti_func(self.output_layer(x))

In [607]:
# Input layout: (node_num, channels, time_step); the notebook uses 20 time steps.
class CNN_1D(nn.Module):
    """Temporal 1-D convolution block followed by a per-step linear read-out.

    The convolution stack maps a sequence of length L to length
    ``2 * ((L - 2) // 2) + 2``; for L = 20 that is 20 -> 18 -> 9 -> 20.

    Args:
        input_channels: Channels of the input sequence.
        hidden_channels_1: Channels after the first convolution.
        hidden_channels_2: Channels after the second convolution.
        out_channels: Channels produced by the transposed convolution.
        output_dimen: Feature size produced by the final linear layer.
    """

    def __init__(self, input_channels, hidden_channels_1, hidden_channels_2, out_channels, output_dimen):
        super().__init__()
        stages = [
            # padding=1 with kernel 3 keeps the sequence length.
            nn.Conv1d(input_channels, hidden_channels_1, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_channels_1),
            nn.ReLU(),
            # No padding: the length shrinks by 2.
            nn.Conv1d(hidden_channels_1, hidden_channels_2, kernel_size=3, padding=0),
            nn.BatchNorm1d(hidden_channels_2),
            nn.ReLU(),
            # Halves the length.
            nn.MaxPool1d(kernel_size=2, stride=2),
            # Upsamples a length p to 2 * p + 2.
            nn.ConvTranspose1d(hidden_channels_2, out_channels, kernel_size=4, stride=2, padding=0),
        ]
        self.layer1 = nn.Sequential(*stages)

        self.fc1 = nn.Linear(out_channels, output_dimen)
        self.acti_func = nn.ReLU()

        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')

    def forward(self, x):
        """Map (batch, input_channels, steps) to (batch, steps', output_dimen)."""
        conv_out = self.layer1(x)
        # (batch, out_channels, steps') -> (batch, steps', out_channels) so the
        # linear layer acts on the channel axis.
        per_step = conv_out.permute(0, 2, 1)
        return self.acti_func(self.fc1(per_step))

In [608]:
class GRU(nn.Module):
    """Predictor head: feature projection, time-axis projection, then a GRU.

    Args:
        input_dimen: Feature size of the concatenated inputs of ``forward``.
        hidden_dimen: Feature size after the first linear layer (GRU input size).
        pred_len: Number of time steps to predict.
        output_dimen: Hidden size of the GRU, i.e. the feature size of the output.
        num_layers: Number of stacked GRU layers.
        seq_len: Length of the input window; defaults to 20, the value that was
            previously hard-coded.
    """

    def __init__(self, input_dimen, hidden_dimen, pred_len, output_dimen = 4, num_layers = 2, seq_len = 20):
        super(GRU, self).__init__()
        self.hidden_dimen = hidden_dimen
        self.output_dimen = output_dimen
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.linear_layer_1 = nn.Linear(input_dimen, hidden_dimen)
        # Acts on the time axis: seq_len observed steps -> pred_len predicted steps.
        self.linear_layer_2 = nn.Linear(seq_len, pred_len)
        self.gru_layers = nn.GRU(hidden_dimen, output_dimen, num_layers, batch_first = True)
        self.acti_func = nn.ReLU()

        # Kaiming-uniform weights and zero biases for the linear layers only;
        # the GRU keeps PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)

    def forward(self, pgnn_t_step_outs, extractor_outputs):
        """Predict ``pred_len`` steps from two aligned feature sequences.

        Args:
            pgnn_t_step_outs: Tensor of shape (batch, seq_len, features_a).
            extractor_outputs: Tensor of shape (batch, seq_len, features_b);
                ``features_a + features_b`` must equal ``input_dimen``.

        Returns:
            Non-negative tensor of shape (batch, pred_len, output_dimen).
        """
        x = torch.cat((pgnn_t_step_outs, extractor_outputs), dim = -1)
        batch_size = x.size(0)
        x = self.acti_func(self.linear_layer_1(x))  # (batch, seq_len, hidden_dimen)
        # Move time to the last axis so the second linear layer maps seq_len -> pred_len.
        x = self.acti_func(self.linear_layer_2(x.permute(0,2,1)))  # (batch, hidden_dimen, pred_len)
        # Allocate the initial state on the same device/dtype as the activations;
        # a plain torch.zeros would always be a CPU float32 tensor.
        h_0 = x.new_zeros(self.num_layers, batch_size, self.output_dimen)
        outputs, _ = self.gru_layers(x.permute(0,2,1), h_0)  # (batch, pred_len, output_dimen)

        return self.acti_func(outputs)

In [609]:
class T_Step_PGNN(nn.Module):
    """Apply one ``PGNN`` to every time step of a zero-padded node sequence.

    Args:
        input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num:
            Forwarded unchanged to ``PGNN``.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = 1):
        super(T_Step_PGNN, self).__init__()
        self.pgnn_model = PGNN(input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num)

    def forward(self, t_step_inputs, subgraph_nodes, dist_max, dist_argmax):
        """Embed the real nodes of every time step.

        Args:
            t_step_inputs: Tensor of shape (T, padded_node_num, features); only
                the first ``subgraph_nodes.shape[0]`` nodes are processed.
            subgraph_nodes: Tensor whose first dimension is the number of real nodes.
            dist_max: Anchor distances forwarded to the PGNN.
            dist_argmax: Anchor indices forwarded to the PGNN.

        Returns:
            Tensor of shape (T, padded_node_num, features) holding the PGNN
            output for the real nodes and zeros for the padded ones.
        """
        step_num, padded_node_num, feature_num = (
            t_step_inputs.shape[0], t_step_inputs.shape[1], t_step_inputs.shape[2])
        subgraph_node_num = subgraph_nodes.shape[0]

        # Collect the per-step outputs and stack once instead of growing a
        # tensor with torch.cat inside the loop.
        step_outputs = [
            self.pgnn_model(t_step_inputs[t,:subgraph_node_num,:], dist_max[:,:,:], dist_argmax[:,:,:])
            for t in range(step_num)
        ]
        if not step_outputs:
            return t_step_inputs.new_zeros((step_num, padded_node_num, feature_num))

        pgnn_outputs = torch.stack(step_outputs, dim=0)
        # The padded buffer follows the device and dtype of the model output.
        pgnn_template = pgnn_outputs.new_zeros((step_num, padded_node_num, feature_num))
        pgnn_template[:,:subgraph_node_num,:] = pgnn_outputs

        return pgnn_template

In [610]:
class T_Step_PGCN(nn.Module):
    """Apply one ``P_GCN`` to every time step of a zero-padded node sequence.

    Args:
        input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num:
            Forwarded unchanged to ``P_GCN``; ``output_dimen`` is also the
            feature size of the returned tensor.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = 1):
        super(T_Step_PGCN, self).__init__()
        self.pgcn_model = P_GCN(input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num)
        self.output_dimen = output_dimen

    def forward(self, t_step_inputs, subgraph_nodes, edge_index, dist_max, dist_argmax):
        """Embed the real nodes of every time step.

        Args:
            t_step_inputs: Tensor of shape (T, padded_node_num, features); only
                the first ``subgraph_nodes.shape[0]`` nodes are processed.
            subgraph_nodes: Tensor whose first dimension is the number of real nodes.
            edge_index: Edge index forwarded to the P-GCN.
            dist_max: Anchor distances forwarded to the P-GCN.
            dist_argmax: Anchor indices forwarded to the P-GCN.

        Returns:
            Tensor of shape (T, padded_node_num, output_dimen) holding the P-GCN
            output for the real nodes and zeros for the padded ones.
        """
        step_num, padded_node_num = t_step_inputs.shape[0], t_step_inputs.shape[1]
        subgraph_node_num = subgraph_nodes.shape[0]

        # Collect the per-step outputs and stack once instead of growing a
        # tensor with torch.cat inside the loop.
        step_outputs = [
            self.pgcn_model(t_step_inputs[t,:subgraph_node_num,:], edge_index, dist_max[:,:,:], dist_argmax[:,:,:])
            for t in range(step_num)
        ]
        if not step_outputs:
            return t_step_inputs.new_zeros((step_num, padded_node_num, self.output_dimen))

        pgcn_outputs = torch.stack(step_outputs, dim=0)
        # The padded buffer follows the device and dtype of the model output.
        pgcn_template = pgcn_outputs.new_zeros((step_num, padded_node_num, self.output_dimen))
        pgcn_template[:,:subgraph_node_num,:] = pgcn_outputs

        return pgcn_template

In [611]:
class Feature_Extractor(nn.Module):
    """Spatial P-GCN per time step followed by a temporal 1-D CNN.

    Args:
        input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num:
            Forwarded to ``P_GCN``.
        input_channels, hidden_channels_1, hidden_channels_2, out_channels:
            Forwarded to ``CNN_1D``. The CNN receives the P-GCN output with the
            feature axis as channels, so ``input_channels`` has to match
            ``output_dimen``.
    """

    def __init__(self, input_dimen, hidden_dimen, output_dimen, max_ach_num,
                input_channels, hidden_channels_1, hidden_channels_2, out_channels, layer_num = 1):
        super(Feature_Extractor, self).__init__()
        self.pgcn_model = P_GCN(input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = layer_num)
        self.cnn_1D = CNN_1D(input_channels, hidden_channels_1, hidden_channels_2, out_channels, output_dimen)
        self.output_dimen = output_dimen

    def forward(self, t_step_inputs, edge_index, subgraph_nodes, dist_max, dist_argmax):
        """Extract spatio-temporal node features.

        Args:
            t_step_inputs: Tensor of shape (T, padded_node_num, features); only
                the first ``subgraph_nodes.shape[0]`` nodes are processed by the P-GCN.
            edge_index: Edge index forwarded to the P-GCN.
            subgraph_nodes: Tensor whose first dimension is the number of real nodes.
            dist_max: Anchor distances forwarded to the P-GCN.
            dist_argmax: Anchor indices forwarded to the P-GCN.

        Returns:
            The CNN output with its first two axes swapped, i.e. laid out as
            (time, node, feature).
        """
        step_num, padded_node_num = t_step_inputs.shape[0], t_step_inputs.shape[1]
        subgraph_node_num = subgraph_nodes.shape[0]

        # Collect the per-step outputs and stack once instead of growing a
        # tensor with torch.cat inside the loop.
        step_outputs = [
            self.pgcn_model(t_step_inputs[t,:subgraph_node_num,:], edge_index, dist_max[:,:,:], dist_argmax[:,:,:])
            for t in range(step_num)
        ]
        if step_outputs:
            pgcn_outputs = torch.stack(step_outputs, dim=0)
            # The padded buffer follows the device and dtype of the model output.
            pgcn_template = pgcn_outputs.new_zeros((step_num, padded_node_num, self.output_dimen))
            pgcn_template[:,:subgraph_node_num,:] = pgcn_outputs
        else:
            pgcn_template = t_step_inputs.new_zeros((step_num, padded_node_num, self.output_dimen))

        # (T, node, feature) -> (node, feature, T): nodes act as the batch and
        # features as the channels of the 1-D convolution.
        pgcn_template = pgcn_template.permute(1,2,0)
        # CNN output (node, T, feature) -> (T, node, feature)
        extractor_outputs = self.cnn_1D(pgcn_template).permute(1, 0, 2)

        return extractor_outputs

In [612]:
def sort_pooling(extractor_outputs, k):
    """Keep the ``k`` nodes with the largest L2 feature norm at every time step.

    Args:
        extractor_outputs: Tensor of shape (T, N, D).
        k: Number of nodes to keep per time step.

    Returns:
        Tensor of shape (T, k, D) whose rows are ordered by descending feature
        norm. When the input has fewer than ``k`` nodes the missing rows are
        filled with zeros, so the output size does not depend on the graph size.
    """
    norms = torch.norm(extractor_outputs, p=2, dim=-1)
    _, sorted_indices = torch.sort(norms, dim=-1, descending=True)
    # Broadcast the node ordering over the feature axis for gather().
    gather_index = sorted_indices.unsqueeze(-1).expand(-1, -1, extractor_outputs.size(-1))
    sorted_outputs = torch.gather(extractor_outputs, dim=1, index=gather_index)
    k_nodes_outputs = sorted_outputs[:, :k, :]

    # Zero rows have norm 0 and therefore sort last, so padding with zeros is
    # equivalent to the graph having had extra all-zero nodes.
    missing = k - k_nodes_outputs.size(1)
    if missing > 0:
        padding = k_nodes_outputs.new_zeros(k_nodes_outputs.size(0), missing, k_nodes_outputs.size(2))
        k_nodes_outputs = torch.cat((k_nodes_outputs, padding), dim=1)

    return k_nodes_outputs

In [613]:
class GradientReversalLayer(Function):
    """Identity in the forward pass, ``-alpha`` times the gradient in the backward pass.

    Use it through ``GradientReversalLayer.apply(x, alpha)``.
    """

    # torch.autograd.Function expects forward/backward to be static methods.
    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Reverse and scale the incoming gradient; ``alpha`` gets no gradient."""
        output = grad_output.neg() * ctx.alpha
        return output, None



In [614]:
class Domain_Classifier(nn.Module):
    """Domain discriminator operating on sort-pooled extractor features.

    Args:
        input_dimen: Feature size of the extractor output.
        hidden_dimen: Width of the hidden layers.
        k: Number of nodes kept per time step by ``sort_pooling``.
        seq_len: Number of time steps; the head expects ``k * seq_len`` scores.
        drop_out: Whether dropout (p = 0.3) follows each stage.
    """

    def __init__(self, input_dimen, hidden_dimen, k, seq_len=20, drop_out = True):
        super().__init__()
        self.k = k
        self.drop_out = drop_out
        self.acti_func = nn.Sigmoid()
        self.input_dimen = input_dimen

        self.input_layer = nn.Sequential(
            nn.Linear(input_dimen, hidden_dimen),
            nn.BatchNorm1d(num_features = hidden_dimen),
            nn.Tanh(),
        )
        # Reduces every (time step, node) row to one scalar score.
        self.linear_layer_2 = nn.Sequential(nn.Linear(hidden_dimen, 1), nn.Tanh())
        self.output_layer = nn.Sequential(
            nn.Linear(k*seq_len, hidden_dimen),
            nn.Tanh(),
            nn.Linear(hidden_dimen, 1),
        )

        # Xavier-uniform weights and zero biases for every linear layer.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            module.weight.data = init.xavier_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data = init.constant_(module.bias.data, 0.0)

    def _maybe_dropout(self, tensor):
        """Apply dropout with p = 0.3 when ``drop_out`` is enabled."""
        if self.drop_out:
            return F.dropout(tensor, training=self.training, p = 0.3)
        return tensor

    def forward(self, extractor_outputs):
        """Return the domain probability for one sample as a tensor of shape (1,).

        Args:
            extractor_outputs: Tensor of shape (seq_len, node_num, input_dimen).
        """
        # (seq_len, k, input_dimen) -> (seq_len * k, input_dimen)
        top_nodes = sort_pooling(extractor_outputs, self.k)
        flat_nodes = top_nodes.reshape(-1, self.input_dimen)

        x = self._maybe_dropout(self.input_layer(flat_nodes))
        # (seq_len * k, hidden_dimen) -> (seq_len * k, 1)
        x = self._maybe_dropout(self.linear_layer_2(x))
        # (seq_len * k,) -> (1,)
        # NOTE(review): dropout is also applied to the final logit, which zeroes
        # or rescales the prediction during training -- confirm this is intended.
        x = self._maybe_dropout(self.output_layer(x.squeeze(-1)))

        return self.acti_func(x)

In [615]:
def loading_data(root_path, purpose):
    """Load the source-domain and target-domain data of one split.

    Args:
        root_path: Directory prefix; it is concatenated as-is, so it has to end
            with a path separator.
        purpose: Split name used in both file names.

    Returns:
        Tuple ``(source_data, target_data)`` as returned by ``torch.load``.
    """
    source_file = root_path + f'SourceDomain/regional_loaders/fake_{purpose}.pt'
    target_file = root_path + f'Barcelona/input_target/{purpose}_regional_level.pt'
    # The source file is read first, then the target file.
    return torch.load(source_file), torch.load(target_file)

In [616]:
class CustomData(Data):
    """``torch_geometric`` data record for one regional sample.

    All constructor arguments are stored as attributes of the same name. As
    used in this notebook: ``trend`` and ``period`` are the two model inputs,
    ``target_volume`` is the regression target, ``target_label`` the domain
    label, ``edge_pairs`` the edge index, ``dist_max`` / ``dist_argmax`` the
    anchor information and ``min_vals`` / ``max_vals`` the bounds used to undo
    the min-max scaling.
    """

    def __init__(self, trend, period, target_volume, target_label, edge_pairs, subgraph_node_num, subgraph_nodes, city_node_num, 
                 dist_max, dist_argmax, min_vals, max_vals):
        super().__init__()
        # The dict preserves insertion order, so attributes are registered in
        # the same order as the constructor arguments.
        fields = {
            "trend": trend,
            "period": period,
            "target_volume": target_volume,
            "target_label": target_label,
            "edge_pairs": edge_pairs,
            "subgraph_node_num": subgraph_node_num,
            "subgraph_nodes": subgraph_nodes,
            "city_node_num": city_node_num,
            "dist_max": dist_max,
            "dist_argmax": dist_argmax,
            "min_vals": min_vals,
            "max_vals": max_vals,
        }
        for name, value in fields.items():
            setattr(self, name, value)

In [617]:
# P-GCN feature sizes
input_dimen = 4
hidden_dimen = 16
output_dimen = 8
# Number of anchors per node expected by the PGNN layers
max_ach_num = 50

# 1-D convolution channels
input_channels = 8
hidden_channels_1 = 32
hidden_channels_2 = 16
out_channels = 8

# GRU predictor: feature size of its concatenated input
gru_input_dimen = 16

# Domain classifier hidden width
classifier_hidden_dimen = 64

k = 60          # nodes kept per time step by sort pooling
pred_len = 10   # predicted time steps
seq_len = 20    # observed time steps
beta = 0.6      # weight of the domain-classifier loss in the extractor objective

In [637]:
# ----- Period branch: time-stepped P-GCN with frozen parameters -----
pgcn_model = T_Step_PGCN(input_dimen, hidden_dimen, output_dimen, max_ach_num, layer_num = 2)
for para in pgcn_model.parameters():
    para.requires_grad_(False)

# ----- Feature extractor (trend branch) -----
feature_extractor = Feature_Extractor(
    input_dimen, hidden_dimen, output_dimen, max_ach_num,
    input_channels, hidden_channels_1, hidden_channels_2, out_channels,
    layer_num = 1,
)

# ----- Domain discriminator -----
classifier = Domain_Classifier(
    input_dimen = output_dimen,
    hidden_dimen = classifier_hidden_dimen,
    k = k,
    seq_len = seq_len,
)

# ----- Predictor -----
predictor = GRU(gru_input_dimen, hidden_dimen, pred_len)

In [619]:
# Load the saved model and optimizer parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    'D:/ThesisData/processed data/ModelPara/model_para.pth'
)

In [644]:
bce_loss = nn.BCELoss()
criterion_regression_MSE = nn.MSELoss(reduction='sum')
optimizer_extractor = optim.Adam(list(feature_extractor.parameters())+list(pgcn_model.parameters())+list(predictor.parameters()), lr=0.0006, weight_decay=0.00001)
optimizer_extractor.load_state_dict(checkpoint['optimizer_state_dict'])
optimizer_classifier = p_optim.RAdam(classifier.parameters(), lr=4e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer_classifier, step_size = 5, gamma = 0.6)

In [645]:
# Restore the saved weights of the P-GCN, the feature extractor and the predictor.
# No state is loaded for the domain classifier in this cell.
# The last call is left as a bare expression so the notebook shows its result.
pgcn_model.load_state_dict(checkpoint['pgcn_state_dict'])
feature_extractor.load_state_dict(checkpoint['feature_extractor_state_dict'])
predictor.load_state_dict(checkpoint['predictor_state_dict'])


<All keys matched successfully>

In [515]:
root_path = "D:/ThesisData/processed data/"
purposes = ["train", "vali","test"]
# One (source, target) pair per split, loaded in the order given by `purposes`.
(source_train, target_train), (source_vali, target_vali), (source_test, target_test) = [
    loading_data(root_path, purpose) for purpose in purposes
]

In [646]:
print(f"source train len: {len(source_train)}----target train len: {len(target_train)}")

# Source and target samples are mixed into one shuffled training list.
# NOTE(review): the shuffle is unseeded, so the sample order differs between runs.
train_data = source_train + target_train
shuffle(train_data)

batch_size = 14
# Number of full batches per split (floor division).
batch_num, vali_batch_num, test_batch_num = (
    len(train_data) // batch_size,
    len(target_vali) // batch_size,
    len(target_test) // batch_size,
)

print("Training set batch number ", batch_num)
print("Validation set batch number: ", vali_batch_num)
print("Test set batch number: ", test_batch_num)

source train len: 336----target train len: 168
Training set batch number  36
Validation set batch number:  4
Test set batch number:  5


In [648]:
epoch_num = 40

plt.ion()

# Per-epoch history of the training losses
train_mse_losses, train_mae_losses, train_class_losses = [], [], []
# Per-epoch history of the validation losses
vali_mse_losses, vali_mae_losses = [], []
# Per-epoch history of the test losses
test_mse_losses, test_mae_losses = [], []

epoch_numbers = []

for h in range(epoch_num):
    
    train_mse_loss =0.0
    train_mae_loss = 0.0
    train_class_loss = 0.0

    feature_extractor.train()
    pgcn_model.train()
    predictor.train()   
    classifier.train()
    
    for b in range(batch_num):
        # Batch b covers the samples [b * batch_size, (b + 1) * batch_size).
        # Slicing with train_data[b: b + batch_size] produced overlapping windows
        # that advanced by a single sample, so most of the data was never used.
        mini_batch = train_data[b * batch_size: (b + 1) * batch_size]
        # NOTE(review): this counts the nodes of every sample in the mini-batch,
        # while the regression losses below only cover the stage-2 samples --
        # confirm the intended normalisation.
        all_subgraph_nodes = sum(sample.subgraph_node_num for sample in mini_batch)

        # The first two thirds of the batch train the domain classifier, the
        # remaining samples train the feature extractor and the predictor.
        classifier_share = int(batch_size * 2/3)

        # ---------- Stage 1: update the domain classifier ----------
        one_batch_classifier = []
        target_labels = []
        for i in range(classifier_share):
            sample = mini_batch[i]
            # Extractor output layout: (time step, node, feature). It is detached
            # so this stage only produces gradients for the classifier.
            extractor_outputs = feature_extractor(sample.trend, sample.edge_pairs, sample.subgraph_nodes, sample.dist_max, sample.dist_argmax).detach()
            one_batch_classifier.append(classifier(extractor_outputs))
            target_labels.append(sample.target_label)

        one_batch_classifier = torch.stack(one_batch_classifier)
        # Label smoothing: label 1.0 becomes 0.9, every other label becomes 0.1.
        target_labels = torch.tensor([ 0.9 if label == 1.0 else 0.1 for label in target_labels]).unsqueeze(-1)

        optimizer_classifier.zero_grad()
        loss_classifier = bce_loss(one_batch_classifier, target_labels)
        loss_classifier.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_value_(classifier.parameters(), clip_value=0.5)
        optimizer_classifier.step()
        del one_batch_classifier
        del target_labels

        # ---------- Stage 2: update the feature extractor and the predictor ----------
        one_batch_extractor = []
        one_batch_predictor = []
        one_batch_classifier = []
        target_volumes = []
        target_labels = []

        # Strength of the gradient reversal.
        # NOTE(review): alpha depends on the batch index b but is divided by the
        # number of epochs -- confirm whether the epoch index was intended.
        alpha = (b + 1) / epoch_num

        for i in range(classifier_share, batch_size):
            sample = mini_batch[i]
            extractor_outputs = feature_extractor(sample.trend, sample.edge_pairs, sample.subgraph_nodes, sample.dist_max, sample.dist_argmax)
            pgcn_period_outputs = pgcn_model(sample.period, sample.subgraph_nodes, sample.edge_pairs, sample.dist_max, sample.dist_argmax)
            # The predictor works node-major, so the first two axes are swapped
            # before the call and swapped back afterwards.
            predictor_output = predictor(extractor_outputs.permute(1, 0, 2), pgcn_period_outputs.permute(1, 0, 2))
            predictor_output = predictor_output.permute(1, 0, 2)
            # The classifier sees the features through the gradient reversal
            # layer, so the extractor receives the negated classifier gradient.
            classifier_input = GradientReversalLayer.apply(extractor_outputs, alpha)
            classifier_out = classifier(classifier_input)

            one_batch_extractor.append(extractor_outputs)
            one_batch_classifier.append(classifier_out)
            one_batch_predictor.append(predictor_output)

            target_volumes.append(sample.target_volume)
            target_labels.append(sample.target_label)

        target_volumes = torch.stack(target_volumes)
        target_labels = torch.tensor([ 0.9 if label == 1.0 else 0.1 for label in target_labels]).unsqueeze(-1)

        one_batch_extractor = torch.stack(one_batch_extractor)
        one_batch_classifier = torch.stack(one_batch_classifier)
        one_batch_predictor = torch.stack(one_batch_predictor)

        loss_classifier = bce_loss(one_batch_classifier, target_labels)
        loss_mse = (criterion_regression_MSE(one_batch_predictor, target_volumes)) / all_subgraph_nodes
        loss_mae = torch.sum(torch.abs(one_batch_predictor - target_volumes)) / all_subgraph_nodes

        train_mse_loss += loss_mse.item()
        train_mae_loss += loss_mae.item()
        train_class_loss += loss_classifier.item()

        print(f"After batch {b}, regression loss: {loss_mse}; domain classifier loss: {loss_classifier}")      

        optimizer_extractor.zero_grad()
        loss_feat_ext = loss_mse + beta * loss_classifier
        # The graph is only back-propagated once, so it does not have to be retained.
        loss_feat_ext.backward()
        optimizer_extractor.step()

        del one_batch_predictor
        del target_volumes
        del one_batch_classifier
   
        
    scheduler.step()
    
    train_mse_loss /= batch_num
    train_mae_loss /= batch_num   
    train_class_loss /= batch_num   


    train_mse_losses.append(train_mse_loss)
    train_mae_losses.append(train_mae_loss)
    train_class_losses.append(train_class_loss)
        
    

#---------------------------------------------------------
#-----------验证集---------------------------------------
    feature_extractor.eval()
    pgcn_model.eval()
    classifier.eval()
    predictor.eval()

    vali_mse_loss = 0.0
    vali_mae_loss = 0.0    

    test_mse_loss = 0.0
    test_mae_loss = 0.0  

    
    with torch.no_grad():
        for b in range(vali_batch_num):
            # Batch b covers the samples [b * batch_size, (b + 1) * batch_size).
            # Slicing with target_vali[b: b + batch_size] produced overlapping
            # windows that advanced by a single sample.
            target_batch = target_vali[b * batch_size: (b + 1) * batch_size]
            all_subgraph_nodes = sum(sample.subgraph_node_num for sample in target_batch)

            one_batch_predictor = []
            target_volumes = []

            for sample in target_batch:
                node_num = sample.subgraph_node_num
                # Extractor output layout: (time step, node, feature)
                extractor_outputs = feature_extractor(sample.trend, sample.edge_pairs, sample.subgraph_nodes, sample.dist_max, sample.dist_argmax)
                pgcn_period_outputs = pgcn_model(sample.period, sample.subgraph_nodes, sample.edge_pairs, sample.dist_max, sample.dist_argmax)
                predictor_output = predictor(extractor_outputs.permute(1, 0, 2), pgcn_period_outputs.permute(1, 0, 2))
                predictor_output = predictor_output.permute(1, 0, 2)

                # Undo the min-max scaling for the real nodes: value * (max - min) + min,
                # rounded to whole numbers. Padded nodes are left untouched.
                min_vals = torch.round(sample.min_vals.unsqueeze(0).unsqueeze(0))
                max_vals = torch.round(sample.max_vals.unsqueeze(0).unsqueeze(0))
                predictor_output[:, :node_num, :] = torch.round(predictor_output[:, :node_num, :]*(max_vals- min_vals) + min_vals)
                one_batch_predictor.append(predictor_output)

                # The target gets the same inverse scaling; padded nodes stay zero.
                target_volume = torch.zeros_like(sample.target_volume)
                target_volume[:, :node_num, :] = torch.round(sample.target_volume[:, :node_num, :]*(max_vals- min_vals) + min_vals)
                target_volumes.append(target_volume)

            one_batch_predictor = torch.stack(one_batch_predictor)
            target_volumes = torch.stack(target_volumes)

            loss_mse = criterion_regression_MSE(one_batch_predictor, target_volumes) / all_subgraph_nodes
            loss_mae =torch.sum(torch.abs(one_batch_predictor - target_volumes)) / all_subgraph_nodes
            vali_mse_loss +=  loss_mse.item()
            vali_mae_loss += loss_mae.item()
            print(f"After batch {b}, Vali MSE loss: {loss_mse}, Vali MAE loss: {loss_mae}")  


        vali_mse_loss /= vali_batch_num
        vali_mae_loss /= vali_batch_num

    
        vali_mse_losses.append(vali_mse_loss)
        vali_mae_losses.append(vali_mae_loss)


#---------------------Test set--------------------------------------------------
        
        # Evaluate the three models on the test set: for every batch, de-normalize
        # predictions and targets, compute MSE / MAE per sub-graph node, then average
        # the per-batch losses and record them for plotting.
        # NOTE(review): test_mse_loss / test_mae_loss are only accumulated here; they
        # are assumed to be reset to 0 earlier in the epoch -- confirm.
        for b in range(test_batch_num):
            # NOTE(review): the slice start advances by 1 per batch, so consecutive
            # batches overlap whenever batch_size > 1. If disjoint batches are intended
            # the start should be b * batch_size -- confirm against how test_batch_num
            # is computed before changing it.
            target_batch = target_test[b: b + batch_size]
            all_subgraph_nodes = sum(sample.subgraph_node_num for sample in target_batch)

            one_batch_predictor = []
            target_volumes = []
            # Iterate over the samples actually present in the slice (not over
            # range(batch_size)) so a slice shorter than batch_size cannot raise an
            # IndexError and stays consistent with all_subgraph_nodes above.
            for sample in target_batch:
                node_num = sample.subgraph_node_num

                extractor_outputs = feature_extractor(sample.trend, sample.edge_pairs, sample.subgraph_nodes, sample.dist_max, sample.dist_argmax)
                pgcn_period_outputs = pgcn_model(sample.period, sample.subgraph_nodes, sample.edge_pairs, sample.dist_max, sample.dist_argmax)
                predictor_output = predictor(extractor_outputs.permute(1, 0, 2), pgcn_period_outputs.permute(1, 0, 2))
                predictor_output = predictor_output.permute(1, 0, 2)

                # De-normalize the prediction: value * (max - min) + min, rounded.
                # Only the first node_num entries along dim 1 are rescaled (in place).
                # NOTE(review): entries beyond node_num keep the raw model output while
                # the target is zero there, so they still contribute to the losses
                # below -- confirm this is intended.
                min_vals = torch.round(sample.min_vals.unsqueeze(0).unsqueeze(0))
                max_vals = torch.round(sample.max_vals.unsqueeze(0).unsqueeze(0))
                value_range = max_vals - min_vals
                predictor_output[:, :node_num, :] = torch.round(predictor_output[:, :node_num, :] * value_range + min_vals)
                one_batch_predictor.append(predictor_output)

                # De-normalize the target the same way; entries beyond node_num stay zero.
                target_volume = torch.zeros_like(sample.target_volume)
                target_volume[:, :node_num, :] = torch.round(sample.target_volume[:, :node_num, :] * value_range + min_vals)
                target_volumes.append(target_volume)

            one_batch_predictor = torch.stack(one_batch_predictor)
            target_volumes = torch.stack(target_volumes)

            # Both losses are divided by the number of sub-graph nodes in the batch.
            loss_mse = criterion_regression_MSE(one_batch_predictor, target_volumes) / all_subgraph_nodes
            loss_mae = torch.sum(torch.abs(one_batch_predictor - target_volumes)) / all_subgraph_nodes
            test_mse_loss += loss_mse.item()
            test_mae_loss += loss_mae.item()
            print(f"After batch {b}, Test MSE loss: {loss_mse}, Test MAE loss: {loss_mae}")

        # Mean of the per-batch losses over all test batches.
        test_mse_loss /= test_batch_num
        test_mae_loss /= test_batch_num

        # Store this pass's averages; test_mse_losses is later plotted against epoch_numbers.
        test_mse_losses.append(test_mse_loss)
        test_mae_losses.append(test_mae_loss)

    
    #------------------------------------------------------------------
    # Redraw the loss curves after every epoch: left panel shows the training MSE and
    # the domain-classifier (binary cross-entropy) loss on twin y-axes, right panel
    # shows the validation / test MSE computed on de-normalized data.
    epoch_numbers.append(h+1)  # 1-based epoch index used as the x-axis

    clear_output(wait=True)  # clear the previous output
    fig, axs = plt.subplots(1,2, figsize=(16,6))
    plt.subplots_adjust(wspace=0.35)

    # ---- left panel: training losses -------------------------------------------
    line1, = axs[0].plot(epoch_numbers, train_mse_losses, label = "MSE", color = 'g')
    axs[0].tick_params(axis='both',labelsize=8)
    axs[0].set_ylim(0,max(train_mse_losses)+0.1)
    axs[0].set_xlabel('Epoch number', fontsize = 11)
    axs[0].set_ylabel('MSE loss', fontsize = 11)
    axs[0].set_title('Normalized training data', fontsize = 13)

    # Second y-axis sharing the same x-axis for the classifier loss.
    ax_2 = axs[0].twinx()
    line2, = ax_2.plot(epoch_numbers, train_class_losses, label = "Binary cross-entropy loss", color = 'b')
    ax_2.set_ylim(0,max(train_class_losses)+0.1)
    ax_2.set_ylabel('Binary cross-entropy loss', fontsize = 11)
    ax_2.tick_params(axis='y',labelsize=8)

    # Lines live on two different Axes, so build one combined legend by hand.
    lines = [line1, line2]
    labels = [line.get_label() for line in lines]
    axs[0].legend(lines, labels, loc='upper right')

    # ---- right panel: validation / test MSE ------------------------------------
    axs[1].plot(epoch_numbers, vali_mse_losses, label = "Validation set", color = 'y')
    axs[1].plot(epoch_numbers, test_mse_losses, label = "Test set", color = 'b')
    axs[1].set_ylim(0,max(max(vali_mse_losses), max(test_mse_losses)) + 1000)
    axs[1].tick_params(axis='both',labelsize=8)
    # Tick labels are shown in thousands. (A ScalarFormatter used to be installed
    # right before this call, but it was immediately replaced by this formatter and
    # never had any effect, so it has been dropped.)
    axs[1].yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, pos: f'{x / 1000:.0f}'))
    axs[1].set_xlabel('Epoch number', fontsize = 11)
    axs[1].set_ylabel('Vehicles * day (*10³)', fontsize = 11)
    axs[1].set_title('MSE value (inverse-normalized data)', fontsize = 13)
    axs[1].legend(loc='upper right')

    plt.show()
    plt.pause(0.1)
    # A new figure is created on every epoch; release it explicitly so figures do not
    # pile up in memory on backends that do not close them automatically after show().
    plt.close(fig)
    
        
# After the training loop: leave interactive mode and show any figure that is still open.
plt.ioff()  # turn off interactive mode
plt.show()      

After batch 0, regression loss: 0.0617988146841526; domain classifier loss: 0.8797403573989868
After batch 1, regression loss: 0.061779431998729706; domain classifier loss: 0.7679815292358398
After batch 2, regression loss: 0.05472857877612114; domain classifier loss: 0.7762728929519653
After batch 3, regression loss: 0.04675084725022316; domain classifier loss: 0.9242613911628723
After batch 4, regression loss: 0.02646695263683796; domain classifier loss: 0.4969669282436371
After batch 5, regression loss: 0.025378329679369926; domain classifier loss: 0.6383566856384277
After batch 6, regression loss: 0.03332853689789772; domain classifier loss: 0.5507872104644775
After batch 7, regression loss: 0.02643711306154728; domain classifier loss: 0.5581260323524475
After batch 8, regression loss: 0.042385831475257874; domain classifier loss: 0.7486862540245056
After batch 9, regression loss: 0.05795174464583397; domain classifier loss: 0.9554055333137512


KeyboardInterrupt: 