In [1]:
import dgl
from dgl.data import DGLDataset
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
from dgl.nn import GraphConv,MaxPooling
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.optim as optim
import numpy as np
import time
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import os
import yaml
import time
import datetime

In [3]:
class _STL10GraphDataset(DGLDataset):
    """Shared base for STL10 graph datasets stored in one DGL binary file.

    Loads every graph and its label from a file written with
    ``dgl.save_graphs``. The train and test datasets below differ only in the
    ``name`` handed to ``DGLDataset``, so the loading/indexing logic lives here
    once instead of being duplicated.

    Args:
        data_path: path to the ``.dgl`` file to load.
        name: dataset name forwarded to ``DGLDataset``.
        transforms: optional callable applied to a graph in ``__getitem__``;
            ``None`` returns graphs unchanged.
    """

    def __init__(self, data_path, name, transforms=None):
        # Set before DGLDataset.__init__, which drives the load pipeline
        # (process()) and therefore needs data_path already available.
        self.data_path = data_path
        self.transforms = transforms
        super().__init__(name=name)

    def process(self):
        """Load graphs and labels from ``self.data_path``.

        Sets ``self.graphs`` (the loaded graph list), ``self.labels`` (the
        value stored under the ``'label'`` key of the saved label dict) and
        ``self.dim_nfeats`` (width of the ``'f'`` node feature of the first
        graph).

        Raises:
            ValueError: if the file contains no graphs.
        """
        graphs, label_dict = dgl.load_graphs(self.data_path)  # (graph list, label dict)
        if not graphs:
            raise ValueError(f"no graphs found in {self.data_path!r}")
        self.graphs = graphs
        self.labels = label_dict['label']  # keep only the label tensor
        # Feature width. shape[1] equals len(feat[0]) for non-empty graphs and
        # still works when the first graph has zero nodes.
        self.dim_nfeats = self.graphs[0].ndata['f'].shape[1]

    def __getitem__(self, idx):
        """Return ``(graph, label)`` at ``idx``, transforming the graph if configured."""
        graph = self.graphs[idx]
        if self.transforms is not None:
            graph = self.transforms(graph)
        return graph, self.labels[idx]

    def __len__(self):
        """Number of graphs in the dataset."""
        return len(self.graphs)


class STL10TrainDataset(_STL10GraphDataset):
    """STL10 training split as a DGL graph dataset.

    Args:
        data_path: path to the training ``.dgl`` file.
        transforms: optional callable applied to each graph on access.
    """

    def __init__(self, data_path, transforms=None):
        # Name string kept verbatim (including its spelling) so anything keyed
        # on the dataset name is unaffected.
        super().__init__(data_path, 'stl10_train_gprah', transforms)


class STL10TestDataset(_STL10GraphDataset):
    """STL10 test split as a DGL graph dataset.

    Args:
        data_path: path to the test ``.dgl`` file.
        transforms: optional callable applied to each graph on access.
    """

    def __init__(self, data_path, transforms=None):
        # Name string kept verbatim (including its spelling) so anything keyed
        # on the dataset name is unaffected.
        super().__init__(data_path, 'stl10_test_gprah', transforms)

In [None]:
class PatchGCN(nn.Module):
    """Graph-level classifier: three GraphConv layers followed by a mean readout.

    Defaults reproduce the original 28 -> 56 -> 112 -> 10 architecture.

    Args:
        in_feats: width of the input node features (default 28; pass the
            dataset's ``dim_nfeats`` to size the model from the data).
        hidden_feats: widths of the two hidden GraphConv layers.
        n_classes: size of the per-graph output vector (default 10,
            presumably one per STL10 class).
    """

    def __init__(self, in_feats=28, hidden_feats=(56, 112), n_classes=10):
        super().__init__()
        hidden1, hidden2 = hidden_feats

        self.input_layer = GraphConv(in_feats, hidden1)
        self.mid_layer = GraphConv(hidden1, hidden2)
        self.output_layer = GraphConv(hidden2, n_classes)

        # Non-linearity applied between GraphConv layers in forward().
        self.m = nn.LeakyReLU()
        # Not used by forward(); kept so any existing reference to `model.flatt` still works.
        self.flatt = nn.Flatten()

    def forward(self, g, n_feat):
        """Compute one output vector per graph in ``g``.

        Args:
            g: a DGLGraph (possibly batched).
            n_feat: node features; presumably (num_nodes, in_feats) as
                GraphConv expects — TODO confirm against the dataset.

        Returns:
            The node-wise mean of the final GraphConv outputs, one row of
            ``n_classes`` values per graph.
        """
        # Without an activation between them, stacked GraphConvs are linear in
        # the input features, so the extra layers add no expressive power.
        h = self.m(self.input_layer(g, n_feat))
        h = self.m(self.mid_layer(g, h))
        h = self.output_layer(g, h)

        # local_scope keeps the temporary 'h' feature off the caller's graph.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')

In [6]:
# File name of the saved graph data; the same name is expected under both the
# train/ and test/ split directories of DATA_ROOT.
data_path = 'ndata_8patch100.dgl'
DATA_ROOT = os.path.join('..', '..', 'data', 'STL10 Datasets')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
traindataset = STL10TrainDataset(os.path.join(DATA_ROOT, 'train', data_path))
testdataset = STL10TestDataset(os.path.join(DATA_ROOT, 'test', data_path))
# Both splits must share the node-feature width, otherwise a model sized for
# one split cannot consume the other.
assert traindataset.dim_nfeats == testdataset.dim_nfeats, (
    f"train/test node feature widths differ: "
    f"{traindataset.dim_nfeats} vs {testdataset.dim_nfeats}"
)