
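# Loading the dataset

The author of PyTorch Geometric provides a number of [dataset-loading examples](https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/datasets). Let's open one of them, `amazon.py`, and see how it works.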

In [None]:
import os.path as osp

import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import read_npz


class Amazon(InMemoryDataset):
    r"""The Amazon Computers and Amazon Photo networks from the
    `"Pitfalls of Graph Neural Network Evaluation"
    <https://arxiv.org/abs/1811.05868>`_ paper.
    Nodes represent goods and edges represent that two goods are frequently
    bought together.
    Given product reviews as bag-of-words node features, the task is to
    map goods to their respective product category.
    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Computers"`,
            :obj:`"Photo"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower()
        assert self.name in ['computers', 'photo']
        super(Amazon, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name.capitalize(), 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name.capitalize(), 'processed')

    @property
    def raw_file_names(self):
        return 'amazon_electronics_{}.npz'.format(self.name.lower())

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        download_url(self.url + self.raw_file_names, self.raw_dir)

    def process(self):
        data = read_npz(self.raw_paths[0])
        data = data if self.pre_transform is None else self.pre_transform(data)
        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    def __repr__(self):
        return '{}{}()'.format(self.__class__.__name__, self.name.capitalize())
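
The `raw_dir`, `processed_dir`, `raw_file_names` and `processed_file_names` properties above decide where the downloaded and processed files end up on disk. As a minimal sketch of just that path logic (the helper name `amazon_paths` and the root folder `'data'` are made up for illustration and are not part of PyTorch Geometric), it can be reproduced with the standard library alone:

```python
import os.path as osp


def amazon_paths(root, name):
    """Hypothetical helper that mirrors the path properties of the Amazon class."""
    name = name.lower()
    if name not in ('computers', 'photo'):
        raise ValueError('unknown dataset name: {}'.format(name))
    # The class stores everything under <root>/<Name>/, e.g. data/Photo/.
    folder = osp.join(root, name.capitalize())
    return {
        # raw_dir joined with raw_file_names
        'raw': osp.join(folder, 'raw', 'amazon_electronics_{}.npz'.format(name)),
        # processed_dir joined with processed_file_names
        'processed': osp.join(folder, 'processed', 'data.pt'),
    }


paths = amazon_paths('data', 'Photo')
print(paths['raw'])
print(paths['processed'])
```

With PyTorch Geometric installed, constructing `Amazon('data', 'Photo')` should download the `.npz` file to the raw path if it is missing, run `process()` once, and afterwards load the cached `data.pt` from the processed path.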


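# Data format
I inspected the format of the data locally. Since we will later need to prepare our own data, it is worth seeing what the data used in these tutorials looks like.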

In [None]:
import numpy as np


data = np.load('F:/浏览器下载/amazon_electronics_photo.npz', allow_pickle=True)  # load the archive
print(data.files)  # list the .npy arrays stored inside the .npz archive
print(type(data.files))
data_dict = {}
for file_name in data.files:
    data_dict[file_name] = data[file_name]  # copy each stored array into a plain dict
for key in data_dict.keys():
    print('Name: {0}, shape: {1}'.format(key, data_dict[key].shape))

Output:

Name: adj_data, shape: (143663,)

Name: adj_indices, shape: (143663,)

Name: adj_indptr, shape: (7651,)

Name: adj_shape, shape: (2,)

Name: attr_data, shape: (1979909,)

Name: attr_indices, shape: (1979909,)

Name: attr_indptr, shape: (7651,)

Name: attr_shape, shape: (2,)

Name: labels, shape: (7650,)

Name: class_names, shape: (8,)
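
The key names and shapes suggest that the adjacency matrix and the attribute matrix are both stored in CSR (compressed sparse row) form: the `*_indptr` arrays have one more entry than the 7650 nodes, and `*_data` and `*_indices` have one entry per non-zero value. Under that assumption, the sketch below writes a toy graph of our own in the same key layout and expands the CSR arrays into a `2 x num_edges` edge list. The toy graph, the file name `toy_graph.npz` and the helper `csr_to_edge_index` are illustrative only, and only NumPy is used.

```python
import os
import tempfile

import numpy as np


def csr_to_edge_index(indices, indptr):
    """Expand CSR (indices, indptr) arrays into a 2 x num_edges array."""
    num_rows = len(indptr) - 1
    # Row i owns the entries indptr[i]:indptr[i + 1], so repeat i that many times.
    rows = np.repeat(np.arange(num_rows), np.diff(indptr))
    return np.stack([rows, indices])


# Hypothetical toy graph: 3 nodes, directed edges 0->1, 0->2 and 2->0.
toy = {
    'adj_data': np.array([1.0, 1.0, 1.0]),   # one value per edge
    'adj_indices': np.array([1, 2, 0]),      # target node of each edge
    'adj_indptr': np.array([0, 2, 2, 3]),    # row offsets, length num_nodes + 1
    'adj_shape': np.array([3, 3]),
    'labels': np.array([0, 1, 0]),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_graph.npz')
    np.savez(path, **toy)  # same key layout as the arrays listed above
    with np.load(path) as loaded:
        edge_index = csr_to_edge_index(loaded['adj_indices'], loaded['adj_indptr'])
        num_nodes = int(loaded['adj_shape'][0])

print(num_nodes)
print(edge_index)
```

The attribute matrix (`attr_data`, `attr_indices`, `attr_indptr`, `attr_shape`) can be handled the same way, for example by passing the three arrays and the shape to `scipy.sparse.csr_matrix` if SciPy is available.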
