In [1]:
import dgl
from dgl.data import DGLDataset
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv,MaxPooling
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import time
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import os
import pandas
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class _CIFAR10GraphDataset(DGLDataset):
    """Shared implementation for CIFAR-10 graph datasets stored with ``dgl.save_graphs``.

    Loads every graph from ``data_path`` together with the ``'label'`` entry of
    the saved label dict. Concrete subclasses only set ``_DATASET_NAME``, which
    is forwarded to ``DGLDataset`` as its ``name``.

    Attributes:
        data_path: path of the ``.dgl`` file to load.
        graphs: list of graphs returned by ``dgl.load_graphs``.
        labels: the ``'label'`` value of the saved label dict (indexed per sample).
    """

    _DATASET_NAME = None  # overridden by each concrete subclass

    def __init__(self, data_path):
        # DGLDataset.__init__ runs process(), so data_path must be set first.
        self.data_path = data_path
        super().__init__(name=self._DATASET_NAME)

    def process(self):
        """Load the saved graphs and their labels from ``self.data_path``.

        Raises:
            KeyError: if the saved label dict has no ``'label'`` entry.
            ValueError: if the number of labels differs from the number of graphs.
        """
        graphs, label_dict = dgl.load_graphs(self.data_path)  # read the saved graph data
        if 'label' not in label_dict:
            raise KeyError(f"no 'label' entry in the labels saved at {self.data_path!r}")
        labels = label_dict['label']  # keep only the label values from the dict
        if len(labels) != len(graphs):
            # Fail at load time instead of with an IndexError deep inside training.
            raise ValueError(
                f'{self.data_path!r}: {len(graphs)} graphs but {len(labels)} labels'
            )
        self.graphs = graphs
        self.labels = labels

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair at position ``idx``."""
        return self.graphs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)


class CIFAR10TrainDataset(_CIFAR10GraphDataset):
    """CIFAR-10 training graphs loaded from a ``.dgl`` file."""

    _DATASET_NAME = 'cifar10_train__gprah'


class CIFAR10TestDataset(_CIFAR10GraphDataset):
    """CIFAR-10 test graphs loaded from a ``.dgl`` file."""

    _DATASET_NAME = 'cifar10_test_gprah'

In [3]:
# Load the previously saved source graphs (file names suggest 40-node graphs — TODO confirm).
traindataset = CIFAR10TrainDataset(data_path="../data/NewMyData/train_dist_40_full.dgl")
testdataset = CIFAR10TestDataset(data_path="../data/NewMyData/test_dist_40_full.dgl")

In [4]:
def return_two_list(node_num):
    """Build the edge list of a complete graph on ``node_num`` nodes (no self-loops).

    Every unordered pair ``i < j`` contributes the two directed edges
    ``i -> j`` and ``j -> i``, emitted in that order, with pairs enumerated
    lexicographically. For example, ``node_num == 3`` gives
    ``src = [0, 1, 0, 2, 1, 2]`` and ``dst = [1, 0, 2, 0, 2, 1]``.

    Args:
        node_num: number of nodes; must be non-negative.

    Returns:
        ``(src, dst)``: two 1-D ``torch.int64`` tensors of length
        ``node_num * (node_num - 1)``. The dtype is fixed explicitly so the
        empty case (``node_num`` 0 or 1) is still a valid index tensor for
        ``dgl.graph``; ``torch.tensor([])`` would otherwise be float32.

    Raises:
        ValueError: if ``node_num`` is negative.
    """
    if node_num < 0:
        raise ValueError(f'node_num must be non-negative, got {node_num}')
    src_ids = []
    dst_ids = []
    for i in range(node_num):
        # Starting at i + 1 skips the diagonal, so no self-loop mask is needed.
        for j in range(i + 1, node_num):
            src_ids += [i, j]
            dst_ids += [j, i]
    tensor_src = torch.tensor(src_ids, dtype=torch.int64)
    tensor_dst = torch.tensor(dst_ids, dtype=torch.int64)
    return tensor_src, tensor_dst

In [5]:
# Sizes of the trailing node blocks to extract from every source graph.
num_node_list = [5, 10, 15, 20, 25, 30, 35]
graphs, labels = [], []
# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [6]:
OUTPUT_DIR = '../data/somedata'  # destination of the cropped per-size datasets


def _crop_trailing_block(feat, node_num, device):
    """Copy the bottom-right ``node_num`` x ``node_num`` block of ``feat`` to ``device``.

    Same result as the row-by-row copy ``pool[p] = feat[-node_num + p][-node_num:]``,
    done as a single slice. Explicit start offsets are used instead of
    ``-node_num`` so that ``node_num == 0`` yields an empty block rather than
    the whole matrix (``x[-0:]`` is ``x[0:]``).

    Args:
        feat: 2-D feature tensor with at least ``node_num`` rows and columns.
        node_num: side length of the square block to extract.
        device: device on which the returned block is allocated.

    Returns:
        A new ``(node_num, node_num)`` tensor of the default float dtype on ``device``.

    Raises:
        ValueError: if ``feat`` is not 2-D or is smaller than ``node_num`` in either dim.
    """
    if feat.dim() != 2 or feat.shape[0] < node_num or feat.shape[1] < node_num:
        raise ValueError(
            f'feature matrix of shape {tuple(feat.shape)} is too small '
            f'for a {node_num}x{node_num} crop'
        )
    row_start = feat.shape[0] - node_num
    col_start = feat.shape[1] - node_num
    block = torch.zeros((node_num, node_num), device=device)
    # copy_ casts dtype and moves across devices, like the original assignment
    # into a zero buffer, and leaves a compact tensor (not a view of feat).
    block.copy_(feat[row_start:, col_start:])
    return block


def _build_clique_graphs(dataset, node_num, src, dst, device):
    """Create one complete ``node_num``-node graph per ``(graph, label)`` sample.

    Each new graph uses the edge list ``(src, dst)``. Its ``'feat value'`` node
    data is the bottom-right ``node_num`` x ``node_num`` block of the source
    graph's ``'feat value'``.

    Returns:
        ``(graphs, labels)``: the new graphs (on ``device``) and the labels
        exactly as yielded by ``dataset``.
    """
    graphs = []
    labels = []
    for graph, label in tqdm(dataset):
        # Crop first and transfer only the small block, instead of moving the
        # whole source graph to the device.
        block = _crop_trailing_block(graph.ndata['feat value'], node_num, device)
        g = dgl.graph((src, dst), num_nodes=node_num, device=device)
        g.ndata['feat value'] = block
        graphs.append(g)
        labels.append(label)
    return graphs, labels


def _save_clique_graphs(graphs, labels, path):
    """Save ``graphs`` with a ``{'label': ...}`` dict to ``path``, creating its directory."""
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dgl.save_graphs(path, g_list=graphs, labels={'label': torch.tensor(labels)})


# Train split first, then test, each for every size in num_node_list.
for split, dataset in (('train', traindataset), ('test', testdataset)):
    for node_num in num_node_list:
        print(f'node : {node_num}')
        src, dst = return_two_list(node_num)
        graphs, labels = _build_clique_graphs(dataset, node_num, src, dst, device)
        path = f'{OUTPUT_DIR}/{split}_dist_{node_num}_full.dgl'
        _save_clique_graphs(graphs, labels, path)


node : 5


100%|██████████| 50000/50000 [00:30<00:00, 1630.33it/s]


node : 10


100%|██████████| 50000/50000 [00:34<00:00, 1458.27it/s]


node : 15


100%|██████████| 50000/50000 [00:38<00:00, 1288.01it/s]


node : 20


100%|██████████| 50000/50000 [00:42<00:00, 1183.85it/s]


node : 25


100%|██████████| 50000/50000 [00:44<00:00, 1123.60it/s]


node : 30


100%|██████████| 50000/50000 [00:47<00:00, 1045.22it/s]


node : 35


100%|██████████| 50000/50000 [00:49<00:00, 1004.46it/s]


node : 5


100%|██████████| 10000/10000 [00:06<00:00, 1574.65it/s]


node : 10


100%|██████████| 10000/10000 [00:05<00:00, 1677.26it/s]


node : 15


100%|██████████| 10000/10000 [00:07<00:00, 1390.34it/s]


node : 20


100%|██████████| 10000/10000 [00:07<00:00, 1309.01it/s]


node : 25


100%|██████████| 10000/10000 [00:09<00:00, 1096.37it/s]


node : 30


100%|██████████| 10000/10000 [00:09<00:00, 1072.10it/s]


node : 35


100%|██████████| 10000/10000 [00:10<00:00, 958.96it/s]
