In [1]:
import dgl
from dgl.data import DGLDataset
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import dgl.data
from dgl.nn import GraphConv,MaxPooling
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.optim as optim
import numpy as np
import time
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import os
import yaml
import time
import datetime

Using backend: pytorch


In [2]:
class STL10TrainDataset(DGLDataset):
    def __init__(self,data_path,transforms=None):
        self.data_path = data_path
        self.transforms = transforms
        super().__init__(name='stl10_train_gprah')
    
    def process(self):
        GRAPHS, LABELS = dgl.load_graphs(self.data_path) #保存したグラーフデータの読み込み
        self.graphs = GRAPHS #グラフリストを代入
        self.labels = LABELS['label'] #ラベル辞書の値のみ代入
        self.dim_nfeats=len(self.graphs[0].ndata['f'][0])

    def __getitem__(self, idx):
        if self.transforms == None:
            return self.graphs[idx], self.labels[idx]
        else:
            data=self.transforms(self.graphs[idx])
            return data,self.labels[idx]
    def __len__(self):
        return len(self.graphs)


class STL10TestDataset(DGLDataset):
    def __init__(self,data_path,transforms=None):
        self.data_path = data_path
        self.transforms = transforms
        super().__init__(name='stl10_test_gprah')
    
    def process(self):
        GRAPHS, LABELS = dgl.load_graphs(self.data_path) #保存したグラーフデータの読み込み
        self.graphs = GRAPHS #グラフリストを代入
        self.labels = LABELS['label'] #ラベル辞書の値のみ代入
        self.dim_nfeats=len(self.graphs[0].ndata['f'][0])

    def __getitem__(self, idx):
        if self.transforms == None:
            return self.graphs[idx], self.labels[idx]
        else:
            data=self.transforms(self.graphs[idx])
            return data,self.labels[idx]
        
    def __len__(self):
        return len(self.graphs)

In [11]:
class SelfGCN(nn.Module):
    def __init__(self):
        super(SelfGCN,self).__init__()
        self.input_layer=GraphConv(28,10)
        self.mid_layer1=GraphConv(512,512)
        self.mid_layer2=GraphConv(512,512)
        self.mid_layer3=GraphConv(512,512)
        self.output_layer=GraphConv(512,10)

        self.m=nn.LeakyReLU()

        self.flatt=nn.Flatten()

    
    def forward(self,g,n_feat):
        '''h=self.flatt(n_feat)
        h=self.input_layer(g,h)
        h=self.m(h)

        h=self.mid_layer1(g,h)
        h=self.m(h)

        h=self.mid_layer2(g,h)
        h=self.m(h)

        h=self.mid_layer3(g,h)
        h=self.m(h)

        h=self.output_layer(g,h)
        h=self.m(h)'''
        h=self.input_layer(g,n_feat)
        h=torch.mean(h,(1,2))
        g.ndata['h'] = h

        return dgl.mean_nodes(g,'h')

In [4]:
data_path='ndata_8patch.dgl'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
traindataset=STL10TrainDataset(f'../data/STL10 Datasets/train/{data_path}')
testdataset=STL10TestDataset(f'../data/STL10 Datasets/test/{data_path}')

#データローダー作成
num_workers=2
traindataloader = GraphDataLoader(traindataset,batch_size = 512,shuffle = True,num_workers = num_workers,pin_memory = True)
testdataloader = GraphDataLoader(testdataset,batch_size = 100,shuffle = True,num_workers = num_workers,pin_memory = True)




In [5]:
def train(num,epochs,lr):
    start=time.time()
    #結果を保存するディレクトリを作成
    save_dir=f'save/{data_path}/patch_test/{num}'
    os.makedirs(save_dir,exist_ok=True)


    #モデルの初期化
    model=SelfGCN()
    model.to(device)
    lossF=nn.CrossEntropyLoss()
    optimizer=optim.AdamW(model.parameters(),lr=lr)

    #情報保存用の変数の初期化
    #トレーニング用
    num_correct=0
    num_tests=0
    train_loss_list = []
    train_acc_list = []
    #テスト用
    test_num_correct = 0
    test_num_tests = 0
    best_acc=0
    test_acc_list = []


    for epoch in tqdm(range(epochs)):
        #トレーニング
        model.train()
        for batched_grapg, labels in traindataloader:
            batched_grapg = batched_grapg.to(device)
            labels = labels.to(device)
            print(f'labels shape: {labels.shape}')
            pred = model(batched_grapg,batched_grapg.ndata['f'])
            loss=lossF(pred,labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            num_correct += (pred.argmax(1) == labels).sum().item()
            num_tests += len(labels)
        train_loss_list.append(loss.item())
        train_acc_list.append(num_correct / num_tests)
        #カウントリセット
        num_correct=num_tests=0

        #テスト
        model.eval()
        for tbatched_graph, tlabels in testdataloader:
            tbatched_graph = tbatched_graph.to(device)
            tlabels = tlabels.to(device)
            tpred = model(tbatched_graph, tbatched_graph.ndata['f'])
            test_num_correct += (tpred.argmax(1) == tlabels).sum().item()
            test_num_tests += len(tlabels)

        test_acc_list.append(test_num_correct/test_num_tests)
        if best_acc < test_num_correct/test_num_tests:
            best_acc = test_num_correct/test_num_tests
            best_weight = model
        #カウントリセット
        test_num_correct=test_num_tests=0

    #完全学習後の正答率の計算(推論)
    with torch.no_grad():
        #情報保存用の変数の初期化
        #トレーニング用
        num_correct=0
        num_tests=0
        save_train_acc=0
        #テスト用
        test_num_correct = 0
        test_num_tests = 0
        save_test_acc=0

        #全トレーニングデータでの正答率計算
        model.train()
        for batched_graph, labels in traindataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            pred = model(batched_graph, batched_graph.ndata['f'])
            num_correct += (pred.argmax(1) == labels).sum().item()
            num_tests += len(labels)
        print('Training accuracy:', num_correct / num_tests)
        save_train_acc=(num_correct / num_tests)

        #全テストデータでの正答率
        model.eval()
        for batched_graph, labels in testdataloader:
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            pred = model(batched_graph, batched_graph.ndata['f'])
            test_num_correct += (pred.argmax(1) == labels).sum().item()
            test_num_tests += len(labels)
        print('Test accuracy:', test_num_correct / test_num_tests)
        save_test_acc=(test_num_correct / test_num_tests)

    #各エポックごとの損失・正答率の記録をモデルごとに.npy形式で保存
    np.save(f'{save_dir}/train_loss_list',train_loss_list)
    np.save(f'{save_dir}/train_acc_list',train_acc_list)
    np.save(f'{save_dir}/test_acc_list',test_acc_list)
    torch.save(model,f'{save_dir}/model_weight.pth')
    torch.save(best_weight,f'{save_dir}/best_model_weight.pth')
    #完全学習後のトレーニング・テストデータそれぞれの正答率を.yaml形式で保存
    log={'train acc':save_train_acc,
        'test acc':save_test_acc,
        'epochs':epochs,
        'best test acc':best_acc,
        'date time':datetime.datetime.now(),
        'run time':time.time() - start}
        
    with open(f'{save_dir}/acc_result.yaml',"w") as f:
        yaml.dump(log,f)

    torch.cuda.empty_cache()


In [12]:
leran_num=2
lr=0.0001
epochs=2
for l in range(leran_num):
    print(f'{l+1}回目')
    train(l,epochs,lr)

  0%|          | 0/2 [00:00<?, ?it/s]

1回目
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([392])
h shape is torch.Size([25088, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is

 50%|█████     | 1/2 [00:03<00:03,  3.56s/it]

labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([392])
h shape is torch.Size([25088, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is tor

100%|██████████| 2/2 [00:07<00:00,  3.59s/it]


h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([25088, 3, 28, 10])
Training accuracy: 0.1
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6

  0%|          | 0/2 [00:00<?, ?it/s]

Test accuracy: 0.1
2回目
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([392])
h shape is torch.Size([25088, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 

 50%|█████     | 1/2 [00:03<00:03,  3.50s/it]

h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([512])
h shape is torch.Size([32768, 3, 28, 10])
labels shape: torch.Size([392])
h shape is torch.Size([25088, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is tor

100%|██████████| 2/2 [00:07<00:00,  3.54s/it]


h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([32768, 3, 28, 10])
h shape is torch.Size([25088, 3, 28, 10])
Training accuracy: 0.1
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6400, 3, 28, 10])
h shape is torch.Size([6