首先使用CGCNN的CIFdataset将晶体数据整合

In [1]:
from networkx.generators.directed import gn_graph

from cgcnn.data import CIFData
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

si_dataset = './crystal_dataset'
si_data = CIFData(si_dataset)
data_len = len(si_data)
print(si_data)
print(len(si_data))

<cgcnn.data.CIFData object at 0x000001F945221E50>
5000


si_data 在经历了`CIFData`之后，有5000的长度，说明有5000张图

下面依次构建5000个graph的信息

structure含有`(atom_fea,nbr_fea,nbr_fea_idx)`
target含有CSV表中的target值
cif_id是晶体代号


In [2]:
# 要构建5000张graph
structures = []
target = []
cif_id = []
for i in tqdm(range(0,data_len)):
    structures.append(si_data[i][0])
    target.append(si_data[i][1])
    cif_id.append(si_data[i][2])

print(len(structures),len(target),len(cif_id))


100%|██████████| 5000/5000 [06:05<00:00, 13.68it/s]

5000 5000 5000





x 对应 atom_fea
edge_index 对应 nbr_fea_idx
edge_attr 对应 nbr_fea

In [27]:
from torch_geometric.data import Data
geometric_data = []
for i in tqdm(range(0,data_len)):
    geometric_data.append(Data(x = structures[i][0],
                             edge_attr = structures[i][2],
                             edge_index = structures[2][1],
                             y = target[i]))
print('len of geometric_data',len(geometric_data))
print(geometric_data[0])
print("x的维度",geometric_data[0].x.shape)
print("edge_index的维度",geometric_data[0].edge_index.shape)

100%|██████████| 5000/5000 [00:00<00:00, 29939.51it/s]

len of geometric_data 5000
Data(x=[114, 92], edge_index=[16, 12, 41], edge_attr=[114, 12], y=[1])
x的维度 torch.Size([114, 92])
edge_index的维度 torch.Size([16, 12, 41])





从这里可以看到，晶体数据集有一个特殊的地方在于，他的邻接矩阵是三维的，后面需要对其进行降维

In [31]:
from sat.data import GraphDataset
k_hop = 2
se = 'gnn'
use_edge_attr = True

for i in tqdm(range(0,data_len)):
    # 修改为long()格式
    geometric_data[i].edge_index = geometric_data[i].edge_index.long()

si_graph_data = GraphDataset(geometric_data,degree=True, k_hop=k_hop, se=se,
       use_subgraph_edge_attr=use_edge_attr)

100%|██████████| 5000/5000 [00:00<00:00, 6957.79it/s]


发现问题：
晶体数据的edge_index是三维的，而在SAT的graphdataset中图数据edge_index是二维的，所以我们这里需要设计算法，给晶体的邻接矩阵降维。
接下来，以第一张图的邻接矩阵做示范：

In [32]:
edge_index_0 = geometric_data[0].edge_index
print('降维前：',edge_index_0.shape)
edge_index_0 = edge_index_0.view(-1, edge_index_0.size(-1))
print('降维后：',edge_index_0.shape)

降维前： torch.Size([192, 41])
降维后： torch.Size([192, 41])


然后进行批量的降维

In [33]:
for i in tqdm(range(0,data_len)):
    # 降维成二维
    edge_index = geometric_data[i].edge_index
    edge_index = edge_index.view(-1, edge_index.size(-1))
    geometric_data[i].edge_index = edge_index


print(geometric_data[0].edge_index.shape)
si_graph_data = GraphDataset(geometric_data,degree=True, k_hop=k_hop, se=se,
       use_subgraph_edge_attr=use_edge_attr)

100%|██████████| 5000/5000 [00:00<00:00, 45165.12it/s]


torch.Size([192, 41])


错误分析
index 的值超出范围：

g.edge_index[0] 是边的源节点索引，它的值必须在 [0, num_nodes - 1] 范围内。
如果 g.edge_index[0] 包含的值大于或等于 num_nodes，scatter_add_ 会抛出 RuntimeError。
num_nodes 的值不正确：

num_nodes 是图中节点的数量，通常通过 g.num_nodes 获取。
如果 num_nodes 的值不正确（例如小于 g.edge_index[0] 的最大值），也会导致错误

In [30]:
print("边的索引值",edge_index_0[0].max().item())
print("num_nodes", geometric_data[0].num_nodes)

边的索引值 0.9992625713348389
num_nodes 114


In [1]:
def c_to_g(g):
    structures = []  # (atom_fea, nbr_fea, nbr_fea_idx)
    target = []
    cif_id = []
    g_list = []
    data_len = len(g)


    for i in tqdm(range(0, data_len)):
        # 修改为long()格式
        structures.append(g[i][0])
        target.append(g[i][1])
        cif_id.append(g[i][2])
        g_list.append(Data(x = structures[i][0],
                             edge_attr = structures[i][2],
                             edge_index = structures[i][1],
                             y = target[i]))
        g_list[i].edge_index = g[i].edge_index.long()
        # 降维成二维
        edge_index = g_list[i].edge_index
        edge_index = edge_index.view(-1, edge_index.size(-1))
        g_list[i].edge_index = edge_index

    print("转换完成！")

    return g_list

In [2]:
from cgcnn.data import collate_pool
from cgcnn.data import get_train_val_test_loader
collate_fn = collate_pool
#
from torch.utils.data.sampler import SubsetRandomSampler
from torch_geometric.data import DataLoader
train_ratio = 0.8
total_size = 5000
indices = list(range(total_size))
train_size = int(train_ratio * total_size)
train_sampler = SubsetRandomSampler(indices[:train_size])
print(train_sampler)

train_loader = DataLoader(si_graph_data, batch_size=128,
                              sampler=train_sampler,
                              num_workers=1,
                              collate_fn=collate_fn, pin_memory=True)

<torch.utils.data.sampler.SubsetRandomSampler object at 0x0000026E9E630CD0>


NameError: name 'si_graph_data' is not defined