### 图神经网络环境配置，参见 https://www.bilibili.com/video/BV1184y1x71H?p=11

In [1]:
import os
import numpy as np
import pandas as pd
import os
import pandas as pd
from scipy.io import loadmat
import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import InMemoryDataset, Data
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from sklearn.metrics import precision_score, f1_score, recall_score, accuracy_score
from torch_geometric.nn import TopKPooling,SAGEConv, GCNConv, SAGPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as  gmp
from torch.nn import Linear,BatchNorm1d, ReLU
import torch.nn.functional as F
from imblearn.under_sampling import RandomUnderSampler

device = torch.device("mps" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

## 参数

In [2]:
epoch_num = 30  # 训练轮数
lr = 0.001      # 学习率
bs = 64         # 批次大小
isCause = True # 是否加入因果关系

dataset_root = '../Data' # 数据存放目录

## 因果关系的加入
注意：节点必须从0开始标号

In [3]:
def get_edge_index_and_edge_attr(ex_type):
    # source_nodes->target_nodes对应着一条有向边，而edge_attr为此条有向边的权值,无向图代表两个方向都有边
    source_nodes = [0, 1, 0, 4, 0, 7, 1, 2, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 8, 14, 8, 19, 9, 10, 10, 11, 11, 12, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 19, 20, 20, 21]
    target_nodes = [1, 0, 4, 0, 7, 0, 2, 1, 3, 2, 5, 4, 6, 5, 8, 7, 9, 8, 14, 8, 19, 8, 10, 9, 11, 10, 12, 11, 13, 12, 15, 14, 16, 15, 17, 16, 18, 17, 20, 19, 21, 20]
    edge_attr = None
    if isCause:
        if ex_type == 2:
            source_nodes = [0, 1, 2, 4, 6, 7, 8, 8 , 8 , 9, 10, 11, 12, 15, 15, 16, 17, 19, 19, 21]
            target_nodes = [4, 0, 3, 5, 5, 0, 7, 9, 14, 10, 11, 12, 13, 14, 16, 17, 18, 8 , 20, 20]
            edge_attr = None
    
    edge_index = torch.tensor([source_nodes, target_nodes], dtype=torch.long)
    return edge_index, edge_attr

## 数据处理部分
注意：如果图结构发生改变，请删除数据根目录/processed/下的缓存文件，以重新进行数据处理

In [4]:
# 这个类继承自InMemoryDataset，专门处理中小型数据集
class BinaryDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []
    # 如果根目录下的processed文件夹下已经尤处理好的图数据，则直接读取
    @property
    def processed_file_names(self):
        if isCause:
            return ['cause.dataset']
        else:
            return ['base.dataset',]
    def download(self):
        # Download to `self.raw_dir`.
        pass

    def process(self):
        """
        Load Data Seperate All data in Protective and Non-Prtective Data
        """
        # Assuming the 'Data' folder is in the current working directory
        data_folders = [dataset_root]

        # Initialize lists to store the separated dataframes
        protective_dfs = []
        non_protective_dfs = []

        # Loop through each data folder
        for data_folder in data_folders:
            # List all .mat files in the current data folder
            mat_files = [f for f in os.listdir(data_folder) if f.endswith('.mat')]
            # Load each mat file
            for mat_file in mat_files:
                # Construct the full path to the .mat file
                mat_path = os.path.join(data_folder, mat_file)
                # Load the .mat file
                mat_data = loadmat(mat_path)
                # Convert the data into a pandas dataframe
                df = pd.DataFrame(mat_data['data'])
                # Select only the first 70 columns and the last column (73rd) which contains the behavior label
                df = df.iloc[:, list(range(66)) + [70] + [72]]
                # Split the data based on the protective behavior label
                # Assuming the last column in df is the protective behavior label
                protective_behavior = df.iloc[:, -1]
                protective_df = df[protective_behavior == 1]
                non_protective_df = df[protective_behavior == 0]
                # Append the resulting dataframes to their respective lists
                protective_dfs.append(protective_df)
                non_protective_dfs.append(non_protective_df)

        # Concatenate all protective and non-protective dataframes
        all_protective_data = pd.concat(protective_dfs, axis=0, ignore_index=True)
        all_non_protective_data = pd.concat(non_protective_dfs, axis=0, ignore_index=True)

        # Now `all_protective_data` and `all_non_protective_data` hold the protective and non-protective data respectively
        # You can process these dataframes as needed for your analysis or save them to new .mat files
        # 保护和非保护数据个数 (77298, 68) (437247, 68)
        # 数据合并
        all_data = np.concatenate([all_protective_data,all_non_protective_data],axis=0)
        # 提取运动类型和保护性行为列
        ys = all_data[:,-2:]
        # 将特征列构建为N*V*C的形式，其中N=图个数（一条数据就可以看作一个图），V代表节点个数， C代表每个节点的特征个数
        x = all_data[:,np.newaxis,[0,22,44]]
        for node in range(2, 23):
            x_index = node - 1
            y_index = x_index + 22
            z_index = x_index + 44
            temp = all_data[:,np.newaxis,[x_index,y_index,z_index]]
            x = np.concatenate([x,temp], axis=1)
        # 将边整理为GNN要求的格式，边索引必须从0开始
       
        # 构建数据列表
        data_list = []
        # 一条数据构建为一个图
        for i in range(x.shape[0]):
            y = torch.tensor([ys[i,-1]], dtype=torch.float)   
            # 根据数据所属运动类型获得其拓扑结构
            edge_index, edge_attr = get_edge_index_and_edge_attr(ys[i,-2])
            if ys[i,-2] == 2:
                data = Data(x=torch.tensor(x[i], dtype=torch.float), edge_index=edge_index, edge_attr=edge_attr, y=y)
                data_list.append(data)
        # 将处理后的数据存储至指定根目录下
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
            
        
# 加载并分割数据集      
def load_data_and_split(dataset_root):
    # 数据打乱
    dataset = BinaryDataset(root=dataset_root)
    # 随机打乱
    dataset = dataset.shuffle()
    # 数据划分逻辑
    num_graphs = len(dataset)
    train_size = int(num_graphs * 0.8)
    # train_size = int(num_graphs * 0.01)
    
    # 创建训练集和测试集的子集索引
    train_indices = list(range(train_size))
    test_indices = list(range(train_size, num_graphs))
    
    # 使用torch_geometric.data.Subset来划分数据
    train_dataset = dataset.index_select(train_indices)
    test_dataset = dataset.index_select(test_indices)
    print("isCause:",isCause)
    print("len of train_dataset:",len(train_dataset))
    print("len of test_dataset :",len(test_dataset))
    # 创建DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)
    
    return train_loader, test_loader

train_loader, test_loader =  load_data_and_split(dataset_root)

Processing...


isCause: True
len of train_dataset: 17884
len of test_dataset : 4471


Done!


## 模型定义
模块参见文档 https://pytorch-geometric.readthedocs.io/en/latest/

In [5]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = GCNConv(3,64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64,64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.conv3 = GCNConv(64,64)
        self.pool3 = TopKPooling(64, ratio=0.8)
        self.lin1 = Linear(64,64)
        self.lin2 = Linear(64,32)
        self.lin3 = Linear(32,1)
        # self.bn1 = BatchNorm1d(64)
        # self.bn2 = BatchNorm1d(32)
        self.act1 = ReLU()
        self.act2 = ReLU()

    
    def forward(self, data):
        x, edge_index, batch, edge_attr = data.x, data.edge_index, data.batch, data.edge_attr

        x = x.clone().detach().requires_grad_(True)
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        # x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, edge_attr, batch)
        x1 = gap(x, batch)
        
        x = F.relu(self.conv2(x, edge_index, edge_attr))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, edge_attr, batch)
        x2 = gap(x, batch)
        
        # x = F.relu(self.conv3(x, edge_index))
        # x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)
        # x3 = gap(x, batch)  
        
        x = x1 + x2
        x = self.lin1(x)
        x = self.act1(x)
        x = self.lin2(x)
        x = self.act2(x)
        x = F.dropout(x, p=0.5, training=self.training)
        
        # 输出值为0-1之间 （0为非保护 1为保护）
        x = torch.sigmoid(self.lin3(x)).squeeze(1)
        return x     

## 模型训练和测试

In [6]:
def train(train_loader):
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        label = data.y.to(device)
        loss = crit(output, label)
        loss.backward()
        optimizer.step()

def test(loader):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():  # 在评估模式下不计算梯度
        for data in loader:
            data = data.to(device)
            out = model(data)  # 前向传播
            pred = (out >= 0.5).int()

            y_true.extend(data.y.clone().detach().tolist())  # 收集真实标签
            y_pred.extend(pred.clone().detach().tolist())  # 收集预测标签
    
    accuracy = accuracy_score(y_true, y_pred)
    macro_precision = precision_score(y_true, y_pred,  pos_label=1, average='macro') 
    
    f1 = f1_score(y_true, y_pred, pos_label=1)  # Only report results for the class specified by pos_label. 考虑了非平衡
    recall = recall_score(y_true, y_pred, pos_label=1)
    return accuracy, macro_precision, f1, recall

model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
crit = torch.nn.BCELoss()
for epoch in range(1, epoch_num):
    train(train_loader)
    train_accuracy, train_macro_precision, train_f1, train_recall = test(train_loader)
    test_accuracy,  test_macro_precision,  test_f1 , test_recall  = test(test_loader)
    
    print(f'Epoch: {epoch:03d}  -------------------------- '
          f'\nTrain: accuracy: {train_accuracy:.4f}, Macro Precision: {train_macro_precision:.4f}, F1: {train_f1:.4f}, recall: {train_recall:.4f}  '
          f'\nTest : accuracy: {test_accuracy:.4f}, Macro Precision: {test_macro_precision:.4f}, F1: {test_f1:.4f}, recall: {test_recall:.4f}')

Epoch: 001  -------------------------- 
Train: accuracy: 0.7546, Macro Precision: 0.7562, F1: 0.6290, recall: 0.5366  
Test : accuracy: 0.7614, Macro Precision: 0.7622, F1: 0.6392, recall: 0.5494
Epoch: 002  -------------------------- 
Train: accuracy: 0.7607, Macro Precision: 0.7584, F1: 0.6484, recall: 0.5694  
Test : accuracy: 0.7660, Macro Precision: 0.7630, F1: 0.6557, recall: 0.5791
Epoch: 003  -------------------------- 
Train: accuracy: 0.7652, Macro Precision: 0.7605, F1: 0.6623, recall: 0.5939  
Test : accuracy: 0.7620, Macro Precision: 0.7555, F1: 0.6561, recall: 0.5901
Epoch: 004  -------------------------- 
Train: accuracy: 0.8229, Macro Precision: 0.8502, F1: 0.7259, recall: 0.6052  
Test : accuracy: 0.8197, Macro Precision: 0.8426, F1: 0.7211, recall: 0.6058
Epoch: 005  -------------------------- 
Train: accuracy: 0.8160, Macro Precision: 0.8260, F1: 0.7288, recall: 0.6379  
Test : accuracy: 0.8121, Macro Precision: 0.8211, F1: 0.7206, recall: 0.6297
Epoch: 006  --------

Epoch: 001  -------------------------- 
Train: accuracy: 0.7546, Macro Precision: 0.7562, F1: 0.6290, recall: 0.5366  
Test : accuracy: 0.7614, Macro Precision: 0.7622, F1: 0.6392, recall: 0.5494
Epoch: 002  -------------------------- 
Train: accuracy: 0.7607, Macro Precision: 0.7584, F1: 0.6484, recall: 0.5694  
Test : accuracy: 0.7660, Macro Precision: 0.7630, F1: 0.6557, recall: 0.5791
Epoch: 003  -------------------------- 
Train: accuracy: 0.7652, Macro Precision: 0.7605, F1: 0.6623, recall: 0.5939  
Test : accuracy: 0.7620, Macro Precision: 0.7555, F1: 0.6561, recall: 0.5901
Epoch: 004  -------------------------- 
Train: accuracy: 0.8229, Macro Precision: 0.8502, F1: 0.7259, recall: 0.6052  
Test : accuracy: 0.8197, Macro Precision: 0.8426, F1: 0.7211, recall: 0.6058
Epoch: 005  -------------------------- 
Train: accuracy: 0.8160, Macro Precision: 0.8260, F1: 0.7288, recall: 0.6379  
Test : accuracy: 0.8121, Macro Precision: 0.8211, F1: 0.7206, recall: 0.6297
Epoch: 006  -------------------------- 
Train: accuracy: 0.8154, Macro Precision: 0.8192, F1: 0.7348, recall: 0.6598  
Test : accuracy: 0.8150, Macro Precision: 0.8169, F1: 0.7338, recall: 0.6628
Epoch: 007  -------------------------- 
Train: accuracy: 0.8364, Macro Precision: 0.8624, F1: 0.7504, recall: 0.6344  
Test : accuracy: 0.8352, Macro Precision: 0.8593, F1: 0.7472, recall: 0.6331
Epoch: 008  -------------------------- 
Train: accuracy: 0.8288, Macro Precision: 0.8571, F1: 0.7358, recall: 0.6148  
Test : accuracy: 0.8215, Macro Precision: 0.8482, F1: 0.7214, recall: 0.6006
Epoch: 009  -------------------------- 
Train: accuracy: 0.8273, Macro Precision: 0.8320, F1: 0.7531, recall: 0.6796  
Test : accuracy: 0.8253, Macro Precision: 0.8291, F1: 0.7485, recall: 0.6756
Epoch: 010  -------------------------- 
Train: accuracy: 0.8326, Macro Precision: 0.8727, F1: 0.7351, recall: 0.5991  
Test : accuracy: 0.8253, Macro Precision: 0.8628, F1: 0.7210, recall: 0.5866
Epoch: 011  -------------------------- 
Train: accuracy: 0.8367, Macro Precision: 0.8745, F1: 0.7435, recall: 0.6108  
Test : accuracy: 0.8396, Macro Precision: 0.8767, F1: 0.7469, recall: 0.6151
Epoch: 012  -------------------------- 
Train: accuracy: 0.8509, Macro Precision: 0.8842, F1: 0.7706, recall: 0.6461  
Test : accuracy: 0.8463, Macro Precision: 0.8801, F1: 0.7605, recall: 0.6343
Epoch: 013  -------------------------- 
Train: accuracy: 0.8397, Macro Precision: 0.8432, F1: 0.7741, recall: 0.7085  
Test : accuracy: 0.8361, Macro Precision: 0.8385, F1: 0.7672, recall: 0.7023
Epoch: 014  -------------------------- 
Train: accuracy: 0.8178, Macro Precision: 0.8704, F1: 0.7007, recall: 0.5502  
Test : accuracy: 0.8110, Macro Precision: 0.8634, F1: 0.6848, recall: 0.5337
Epoch: 015  -------------------------- 
Train: accuracy: 0.8510, Macro Precision: 0.8669, F1: 0.7814, recall: 0.6871  
Test : accuracy: 0.8479, Macro Precision: 0.8619, F1: 0.7759, recall: 0.6843
Epoch: 016  -------------------------- 
Train: accuracy: 0.8529, Macro Precision: 0.8719, F1: 0.7825, recall: 0.6826  
Test : accuracy: 0.8513, Macro Precision: 0.8696, F1: 0.7783, recall: 0.6785
Epoch: 017  -------------------------- 
Train: accuracy: 0.8514, Macro Precision: 0.8723, F1: 0.7787, recall: 0.6746  
Test : accuracy: 0.8479, Macro Precision: 0.8676, F1: 0.7718, recall: 0.6686
Epoch: 018  -------------------------- 
Train: accuracy: 0.8581, Macro Precision: 0.8794, F1: 0.7894, recall: 0.6862  
Test : accuracy: 0.8557, Macro Precision: 0.8762, F1: 0.7842, recall: 0.6814
Epoch: 019  -------------------------- 
Train: accuracy: 0.8398, Macro Precision: 0.8418, F1: 0.7758, recall: 0.7152  
Test : accuracy: 0.8410, Macro Precision: 0.8439, F1: 0.7745, recall: 0.7099
Epoch: 020  -------------------------- 
Train: accuracy: 0.8159, Macro Precision: 0.8153, F1: 0.7413, recall: 0.6808  
Test : accuracy: 0.8126, Macro Precision: 0.8112, F1: 0.7343, recall: 0.6733
Epoch: 021  -------------------------- 
Train: accuracy: 0.8569, Macro Precision: 0.8870, F1: 0.7825, recall: 0.6639  
Test : accuracy: 0.8544, Macro Precision: 0.8815, F1: 0.7780, recall: 0.6634
Epoch: 022  -------------------------- 
Train: accuracy: 0.8502, Macro Precision: 0.8802, F1: 0.7713, recall: 0.6516  
Test : accuracy: 0.8454, Macro Precision: 0.8733, F1: 0.7625, recall: 0.6448
Epoch: 023  -------------------------- 
Train: accuracy: 0.8599, Macro Precision: 0.8864, F1: 0.7893, recall: 0.6771  
Test : accuracy: 0.8546, Macro Precision: 0.8783, F1: 0.7804, recall: 0.6715
Epoch: 024  -------------------------- 
Train: accuracy: 0.8335, Macro Precision: 0.8335, F1: 0.7686, recall: 0.7134  
Test : accuracy: 0.8271, Macro Precision: 0.8251, F1: 0.7587, recall: 0.7064
Epoch: 025  -------------------------- 
Train: accuracy: 0.8619, Macro Precision: 0.8862, F1: 0.7939, recall: 0.6861  
Test : accuracy: 0.8562, Macro Precision: 0.8777, F1: 0.7843, recall: 0.6797
Epoch: 026  -------------------------- 
Train: accuracy: 0.8607, Macro Precision: 0.8688, F1: 0.8023, recall: 0.7297  
Test : accuracy: 0.8604, Macro Precision: 0.8690, F1: 0.7999, recall: 0.7250
Epoch: 027  -------------------------- 
Train: accuracy: 0.8901, Macro Precision: 0.8824, F1: 0.8615, recall: 0.8823  
Test : accuracy: 0.8904, Macro Precision: 0.8828, F1: 0.8600, recall: 0.8750
Epoch: 028  -------------------------- 
Train: accuracy: 0.8992, Macro Precision: 0.9037, F1: 0.8620, recall: 0.8123  
Test : accuracy: 0.9023, Macro Precision: 0.9054, F1: 0.8661, recall: 0.8215
Epoch: 029  -------------------------- 
Train: accuracy: 0.8898, Macro Precision: 0.8966, F1: 0.8470, recall: 0.7868  
Test : accuracy: 0.8915, Macro Precision: 0.8969, F1: 0.8491, recall: 0.7936
