## 1 获取输入数据

In [11]:
import pandas as pd
from dgl.data import DGLDataset
import torch
import dgl
import os
os.environ["DGLBACKEND"] = "pytorch"
def save_data(data, save_file):
    """保存数据"""
    import pickle
    # 保存到文件
    with open(save_file, 'wb') as file:
        pickle.dump(data, file)


def load_data(load_file):
    """读取数据"""
    import pickle
    # 打开文件
    with open(load_file, 'rb') as file:
        data = pickle.load(file)
    return data

In [12]:
def get_val_graphs(edges_file, nodes_file, graph_pro_file, path="./val_data/"):
    # 读取edges，nodes数据
    edges_df = pd.read_csv(path+edges_file)
    graph_properties = pd.read_csv(path + graph_pro_file)
    nodes_df = pd.read_csv(path + nodes_file)
    graphs = []

    # 先依据graph_id排序
    nodes_df = nodes_df.sort_values(by=['graph_id'], ascending=[True])
    edges_df = edges_df.sort_values(by=['graph_id'], ascending=[True])
    # 将edges，nodes依据graph_id分组
    graph_nodes_dfs = nodes_df.groupby(by="graph_id")
    graph_edges_dfs = edges_df.groupby(by="graph_id")

    # 先读取每个图的结点数量
    num_nodes_dict = {}
    for _, row in graph_properties.iterrows():
        num_nodes_dict[row["graph_id"]] = row["num_nodes"]

    # 处理每个图数据
    for (graph_id, nodes_df), (e_graph_id, edges_df) in zip(graph_nodes_dfs, graph_edges_dfs):
        # 找对应的图
        assert (graph_id == e_graph_id)
        num_nodes = num_nodes_dict[graph_id]
        # 获取图结构
        src = edges_df["src"].to_numpy()
        dst = edges_df["dst"].to_numpy()

        # edges的feats
        edges_feat = torch.from_numpy(edges_df[["F_1", "F_2"]].to_numpy())
        # nodes的feats
        nodes_feat = []
        for i in range(num_nodes):
            t_row = nodes_df[nodes_df["node_id"] == i]
            if not t_row.empty:
                feat = torch.from_numpy(
                    t_row.iloc[:, 2:-2].values).view(35)
            else:
                # 补全数据
                feat = torch.zeros(35)
            nodes_feat.append(feat)
        nodes_feat = torch.stack(nodes_feat, dim=0)
        # 创建图
        g = dgl.graph((src, dst), num_nodes=num_nodes)

        # 添加图的node和edge的feats
        g.edata["feat"] = edges_feat.float()
        g.ndata["feat"] = nodes_feat.float()
        # 避免一些结点入度为0
        g = dgl.add_self_loop(g)
        graphs.append(g)
    return len(graphs), graphs

In [13]:
size,graphs = get_val_graphs(edges_file="edges.csv",
                        nodes_file="nodes.csv",
                        graph_pro_file="graph_propertity.csv")

### 查看输入数据

In [14]:
for g in graphs:
    print(f"Graph for validaing: {g}")
    ndata = g.ndata
    edata = g.edata
    print("Node features")
    print(g.ndata)
    print("Edge features")
    print(g.edata)

Graph for validaing: Graph(num_nodes=1154, num_edges=22581,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})
Node features
{'feat': tensor([[-0.8390, -0.8570, -0.8270,  ..., -0.5210, -0.6030, -1.5680],
        [ 0.9050,  0.8160,  1.2150,  ...,  0.9250, -0.6860, -0.6890],
        [-0.8880, -0.9360, -0.8680,  ..., -0.5590, -0.4030, -1.2260],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])}
Edge features
{'feat': tensor([[ 5., 16.],
        [35., 13.],
        [15.,  3.],
        ...,
        [ 1.,  1.],
        [ 1.,  1.],
        [ 1.,  1.]])}
Graph for validaing: Graph(num_nodes=1154, num_edges=22581,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dty

## 2 加载模型参数

In [15]:
g = graphs[0]
# 加载模型参数
from models.GAT import GAT
model = GAT(g.ndata["feat"].shape[1], 16, 2, heads=4)
model.load_state_dict(torch.load("./checkpoints/GAT.pt"))


<All keys matched successfully>

## 3 评估数据

In [16]:
# 处理一张图
def process_signle_graph(g,model):
    model.eval()
    features = g.ndata["feat"]
    logits = model(g, features)
    g.ndata["label"] = logits
    print(g.ndata["label"])

In [17]:
# 遍历所有graph
for graph in graphs:
    # 输入模型 得到所有的label
    process_signle_graph(graph,model)

tensor([[64.8690, 58.8948],
        [69.3735, 74.9575],
        [68.9673, 72.5281],
        ...,
        [65.9887, 62.5208],
        [65.0823, 62.0725],
        [63.8752, 59.5140]], grad_fn=<ViewBackward0>)
tensor([[69.7442, 74.7252],
        [69.1826, 72.8141],
        [69.8916, 74.0406],
        ...,
        [67.5801, 71.6904],
        [64.0976, 60.2424],
        [67.4643, 65.6494]], grad_fn=<ViewBackward0>)
tensor([[71.5447, 78.2391],
        [70.8220, 76.0115],
        [71.3010, 78.3669],
        ...,
        [70.7557, 76.4246],
        [65.0823, 62.0725],
        [69.1717, 70.4733]], grad_fn=<ViewBackward0>)
tensor([[70.7740, 75.9505],
        [71.8720, 79.3364],
        [70.8794, 78.2037],
        ...,
        [65.0823, 62.0725],
        [65.5361, 61.8347],
        [70.5800, 74.9101]], grad_fn=<ViewBackward0>)


## 4 保存输出数据

In [18]:
def save_data(data, save_file):
    """保存数据"""
    import pickle
    # 保存到文件
    with open(save_file, 'wb') as file:
        pickle.dump(data, file)


def load_data(load_file):
    """读取数据"""
    import pickle
    # 打开文件
    with open(load_file, 'rb') as file:
        data = pickle.load(file)
    return data

def get_maplist(mapping):
    maplist = [None] * len(mapping)
    
    for key, value in mapping.items():
        maplist[value] = key
        
    return maplist

In [19]:
graph_id_file = "./val_data/graph_id_map.pkl"
node_id_file = "./val_data/node_id_maps.pkl"
# 加载映射关系
node_id_maps = load_data(node_id_file)
graph_id_map = load_data(graph_id_file)

### 数据：graphs中的label数据，graph_id映射，node_id映射

In [20]:
# check 映射数据
print(graphs)
for node_id_map in node_id_maps:
    print(node_id_map)
print(graph_id_map)

[Graph(num_nodes=1154, num_edges=22581,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32), 'label': Scheme(shape=(2,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)}), Graph(num_nodes=1154, num_edges=22581,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32), 'label': Scheme(shape=(2,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)}), Graph(num_nodes=1154, num_edges=22475,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32), 'label': Scheme(shape=(2,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)}), Graph(num_nodes=1155, num_edges=22584,
      ndata_schemes={'feat': Scheme(shape=(35,), dtype=torch.float32), 'label': Scheme(shape=(2,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(2,), dtype=torch.float32)})]
{'4885e281g': 0, 'bb50381dj': 1, '89768a854': 2, '193c8d9ce': 3, '64af8

In [21]:
# 转化为列表映射，方便输出
node_id_convert_lists = []
for node_id_map in node_id_maps:
    node_id_convert_list = get_maplist(node_id_map)
    node_id_convert_lists.append(node_id_convert_list)
graph_id_convert_list = get_maplist(graph_id_map)
print(graph_id_convert_list)
for node_id_convert_list in node_id_convert_lists:
    print(node_id_convert_list)

[20230404, 20230405, 20230406, 20230407]
['4885e281g', 'bb50381dj', '89768a854', '193c8d9ce', '64af8fe53', '884a20ceb', '6f7ec6f12', 'f36233b8f', 'f5d1d6ec1', '35e4fa260', '2cf2c8f95', '80770622g', '5f7b6c47c', '14b31a6bv', '5f22e1fcu', '97648e2eh', '699b9641t', '1e1f24afs', '48dbbcd7d', '3f1597eeh', 'bebecda1g', '0b161e69y', 'dea374c66', 'cbe38a1dd', '46cf284c9', 'a5655506m', '261118855', '11d6b6dbk', 'c8e93735j', '630357338', 'b1189a99h', '4b51b36eg', 'b33f3dae6', 'f40c520bk', '2bf92410f', 'ff2a13239', '0f3e7923e', '37a8716ed', 'ff4566eeb', '96ee0f5bc', 'de4fb9727', '58ccc264b', 'e86c3ebb9', 'c82efe34c', '726a7d688', '4696e5787', '9621ccf7j', 'd9b2b901f', 'e2e936ece', '03834b1bk', '9261fea2q', 'e221d044s', '90b6d2f1p', '28aebd81n', '95875b5bj', '8c63c468f', '344b0a09m', 'e00e8454t', '3e70dcc7k', '89c07369j', 'e2d16ba7u', '10892b25h', '55c97d49k', 'd5f4558dg', 'd1002af4f', '4c79efefv', 'cc95a193e', '4dd4c296d', '8acdef0am', '7fdeeacbw', '484525f9c', '25166491r', '380375147', '2f45160b

In [22]:
def create_pre_df(date_id, graph, node_id_list):
    pre_size = len(node_id_list)
    labels = graph.ndata["label"].view(2,-1).tolist()
    activity_level_list = labels[0] 
    consumption_level_list = labels[1] 
    pre_dict = {'geohash_id': node_id_list,
                'consumption_level': consumption_level_list,
                'activity_level': activity_level_list,
                'date_id': [date_id]*pre_size}
    pre_df = pd.DataFrame(pre_dict)
    return pre_df

In [23]:
# 逐图创建dataframe
pre_dfs = []
for graph_id in range(size):
    pre_df = create_pre_df(graph_id_convert_list[graph_id],graphs[graph_id],node_id_convert_lists[graph_id])
    pre_dfs.append(pre_df)
pres_df = pd.concat(pre_dfs)
print(pres_df)
pres_df.to_csv("submit.csv",index=False,sep='\t')

     geohash_id  consumption_level  activity_level   date_id
0     4885e281g          69.879883       64.868988  20230404
1     bb50381dj          75.960556       58.894768  20230404
2     89768a854          70.469284       69.373459  20230404
3     193c8d9ce          76.805397       74.957542  20230404
4     64af8fe53          70.148422       68.967331  20230404
...         ...                ...             ...       ...
1150  f9cc6c6ed          62.072475       70.025131  20230407
1151  0e3bed012          65.536072       74.028244  20230407
1152  df181761x          61.834713       68.638641  20230407
1153  7c5dd003d          70.580002       70.786179  20230407
1154  07588afdw          74.910110       69.934853  20230407

[4617 rows x 4 columns]


### 由于预测的结点数量多于它要求的结点，需要剪掉一些结点

In [24]:
# 读取submit_example.csv
example_df = pd.read_csv("./submit_example.csv",sep='\t')
# 根据examaple_df进行合并
merge_df = example_df.merge(pres_df,on=['geohash_id', 'date_id'],how="left")
merge_df.rename(columns = {"consumption_level_y": "consumption_level", "activity_level_y":"activity_level"},  inplace=True)
merge_df.drop(columns={"consumption_level_x","activity_level_x"},axis=1,inplace=True)
merge_df = merge_df[['geohash_id', 'consumption_level', 'activity_level', 'date_id']]
print(merge_df)
# 保存submit文件
merge_df.to_csv("submit.csv",index=False,sep='\t')

     geohash_id  consumption_level  activity_level   date_id
0     4885e281g          69.879883       64.868988  20230404
1     4885e281g          72.874275       74.797760  20230405
2     4885e281g          75.460129       72.191559  20230406
3     4885e281g          71.601425       74.391968  20230407
4     5324516fr          70.638458       69.911087  20230404
...         ...                ...             ...       ...
4555  607779c2c          68.915871       77.146927  20230407
4556  1d3640fad          69.908806       69.002823  20230404
4557  1d3640fad          68.902519       67.237480  20230405
4558  1d3640fad          69.184296       69.372856  20230406
4559  1d3640fad          68.238960       64.698975  20230407

[4560 rows x 4 columns]
