In [1]:
from Models.GAE_GConvLSTM import GAE_GConvLSTM_encoder, GAE_GConvLSTM_decoder, GAE_GConvLSTM_seq2seq
from torchinfo import summary
import torch
import scipy
import Models.Get_data as Gd
import pickle

import numpy as np
import os

In [2]:
import random
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and every CUDA
    device), and switch cuDNN to deterministic, non-benchmarking mode so that
    repeated runs draw the same random numbers.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed(seed=66)  # any integer can be used as the seed

In [3]:
directory = '/home/wl4023/data/Sibo_22Mar2024'


def _case_index(name):
    """Return the integer suffix of a ``case_<n>`` folder name, or -1 if it is not an integer."""
    suffix = name.split('_', 1)[1]
    return int(suffix) if suffix.isdigit() else -1


# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort the case
# folders numerically (case_2 before case_10) so `folders` -- and any slice of it such
# as folders[:-1] -- selects the same cases on every machine and every run.
folders = [
    os.path.join(directory, name, 'hessian_')
    for name in sorted(
        (f for f in os.listdir(directory) if f.startswith('case_')),
        key=lambda f: (_case_index(f), f),
    )
]

# Node coordinates are read from case_0 only.
# NOTE(review): this assumes every case shares the same mesh/node ordering -- confirm.
xyfile = os.path.join(directory, 'case_0', 'hessian_', 'xy_coords.npy')
pos = torch.tensor(np.load(xyfile), dtype=torch.float32)

# Graph connectivity and edge weights, stored as a scipy sparse matrix.
sparse_graph = scipy.sparse.load_npz(os.path.join(directory, 'new_sparse_matrix.npz'))

# Take indices and values from the SAME COO view so they are guaranteed to be aligned.
# Pairing `sparse_graph.nonzero()` with `sparse_graph.data` is unsafe: nonzero() drops
# explicitly stored zeros while `.data` keeps them, so the two arrays can differ in
# length. eliminate_zeros() keeps the original "non-zero entries only" semantics.
sparse_coo = sparse_graph.tocoo()
sparse_coo.eliminate_zeros()
indices = np.vstack((sparse_coo.row, sparse_coo.col))
values = sparse_coo.data
shape = sparse_coo.shape

# sparse edge tensor; coalesce() sorts the indices and sums duplicate entries
edge_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float).coalesce()
edge_index = edge_tensor.indices()
edge_weight = edge_tensor.values().unsqueeze(1)  # one weight channel per edge

In [4]:
# Windowing parameters used when segmenting the time series (see segment_data_torch);
# presumably window length and stride -- verify against Models.Get_data.
window_size = 10
step_size = 3
latent_space = 5

num_mp_layers = [2, 2, 2]
num_clusters = [1000, latent_space]

# Pre-computed clusters / centroids for this latent-space size.
# map_location='cpu': the data and graph tensors in this notebook are created on the CPU,
# and this lets the notebook run on a machine without the GPU the files may have been
# saved from.
result_dir = f'/home/wl4023/github_repos/IRP/result/Latent space {latent_space}'
clusters = torch.load(os.path.join(result_dir, 'clusters.pt'), map_location='cpu')
centroids = torch.load(os.path.join(result_dir, 'centroids.pt'), map_location='cpu')

In [5]:
# Load every case folder except the last one in `folders` (presumably held out -- confirm).
# NOTE(review): which case is excluded depends on the order of `folders`.
dataset, length = Gd.get_all_nodes(folders[:-1])

# Keep only feature channel 0; slicing with :1 preserves the trailing channel axis.
dataset = torch.tensor(dataset[:, :, :1], dtype=torch.float32)

In [6]:
# Sanity check: the loaded data must be rank-3 with a single trailing feature channel.
assert dataset.ndim == 3 and dataset.shape[-1] == 1, dataset.shape
dataset.shape

torch.Size([2900, 97149, 1])

In [7]:
def initialize_weights(m):
    """Xavier-uniform initialise the weight of a ``torch.nn.Linear`` and zero its bias.

    Can be passed to ``module.apply(...)``: any module that is not a
    ``torch.nn.Linear`` is left untouched, as is a Linear layer built with
    ``bias=False``.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    bias = m.bias
    if bias is not None:
        torch.nn.init.constant_(bias, 0)

In [8]:
# Segment the data using the window_size / step_size set in the config cell, rather than
# repeating the literals 10 and 3 here.
# NOTE: this rebinds `dataset`, so re-running this cell on its own would segment the
# already-segmented tensor -- re-run the loading cell first.
dataset = Gd.segment_data_torch(dataset, length, window_size, step_size)
dataset.shape

torch.Size([899, 10, 97149, 1])

In [9]:
encoder = GAE_GConvLSTM_encoder(
    latent_space=latent_space,
    input_node_channel=1,
    num_mp_layers=num_mp_layers,
    clusters=clusters,
    centroids=centroids,
    hidden_channels=8,
    n_mlp_mp=3,
    hidden_channel_lstm=16,
)

# Layer summary on the first 5 time steps of window 0 of the segmented data.
summary(encoder, input_data=(dataset[0, :5], edge_index, edge_weight, pos))

Layer (type:depth-idx)                             Output Shape              Param #
GAE_GConvLSTM_encoder                              [5, 16]                   --
├─Encoder: 1-1                                     [5, 8]                    291
│    └─ModuleList: 2-193                           --                        (recursive)
│    │    └─Linear: 3-1                            [97149, 8]                16
│    └─ELU: 2-2                                    [97149, 8]                --
│    └─ModuleList: 2-193                           --                        (recursive)
│    │    └─Linear: 3-2                            [97149, 8]                72
│    └─ELU: 2-4                                    [97149, 8]                --
│    └─ModuleList: 2-193                           --                        (recursive)
│    │    └─Linear: 3-3                            [97149, 8]                72
│    └─ELU: 2-6                                    [97149, 8]                --
│    └─

In [10]:
# Run the encoder once on the same 5-step input so its outputs can be fed to the decoder
# summary. no_grad: this pass is only for shape inspection, so no autograd graph is kept.
# NOTE(review): the last unpacked name rebinds the global `clusters` loaded from
# clusters.pt; later cells (e.g. the seq2seq model) then receive the encoder's output
# instead of the loaded object -- confirm this is intended.
with torch.no_grad():
    (h_t, c_t, hidden_edge_index, hidden_edge_attr, edge_indices, edge_attrs,
     edge_indices_f2c, position, node_attrs, clusters) = encoder(
        dataset[0, :5, :, :], edge_index, edge_weight, pos)

# latent_space comes from the config cell (was hardcoded to 5) so the decoder always
# matches the encoder; hidden_channel_lstm / hidden_channels / n_mlp_mp mirror the
# encoder cell's values.
decoder = GAE_GConvLSTM_decoder(hidden_channel_lstm=16,
                                latent_space=latent_space,
                                output_node_channel=1,
                                num_mp_layers=num_mp_layers,
                                hidden_channels=8,
                                n_mlp_mp=3)


In [11]:
# Decoder layer summary driven by the encoder outputs from the previous cell.
# The literal 5 appears to be the number of steps to decode (output is [5, N, 1]).
summary(
    decoder,
    input_data=(
        h_t, c_t, 5,
        hidden_edge_index, hidden_edge_attr,
        edge_indices, edge_attrs, edge_indices_f2c,
        position, node_attrs, clusters,
    ),
)

Layer (type:depth-idx)                             Output Shape              Param #
GAE_GConvLSTM_decoder                              [5, 97149, 1]             --
├─GCNConv: 1-1                                     [5, 16]                   16
│    └─Linear: 2-1                                 [5, 16]                   512
│    └─SumAggregation: 2-2                         [5, 16]                   --
├─ELU: 1-2                                         [5, 16]                   --
├─GConvLSTM_cell: 1-3                              [5, 8]                    --
│    └─GCNConv: 2-3                                [5, 32]                   32
│    │    └─Linear: 3-1                            [5, 32]                   768
│    │    └─SumAggregation: 3-2                    [5, 32]                   --
├─Decoder: 1-4                                     [97149, 1]                291
│    └─ModuleList: 2-234                           --                        (recursive)
│    │    └─ModuleList:

In [13]:
# Full encoder-decoder model.
# NOTE(review): hidden_channel_lstm=1 here, whereas the standalone encoder/decoder cells
# use 16 -- confirm this is intentional.
model = GAE_GConvLSTM_seq2seq(
    latent_space=latent_space,
    hidden_channel_lstm=1,
    input_node_channel=1,
    output_node_channel=1,
    num_mp_layers=num_mp_layers,
    clusters=clusters,
    centroids=centroids,
    hidden_channels=8,
    n_mlp_mp=3,
)

summary(model, input_data=(dataset[0, :5], 5, edge_index, edge_weight, pos))

Layer (type:depth-idx)                                  Output Shape              Param #
GAE_GConvLSTM_seq2seq                                   [5, 97149, 1]             --
├─GAE_GConvLSTM_encoder: 1-1                            [5, 1]                    --
│    └─Encoder: 2-1                                     [5, 8]                    291
│    │    └─ModuleList: 3-193                           --                        (recursive)
│    │    └─ELU: 3-2                                    [97149, 8]                --
│    │    └─ModuleList: 3-193                           --                        (recursive)
│    │    └─ELU: 3-4                                    [97149, 8]                --
│    │    └─ModuleList: 3-193                           --                        (recursive)
│    │    └─ELU: 3-6                                    [97149, 8]                --
│    │    └─ModuleList: 3-233                           --                        (recursive)
│    │    └─ELU: 3-8   