In [66]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch.nn.functional as F
from torchmetrics.functional import accuracy
from torch.utils.data import random_split, DataLoader, Subset

import os
import dgl
import dgl.data
from dgl.nn import GraphConv
from dgl.dataloading import GraphDataLoader


In [67]:
from torch.utils.data.sampler import SubsetRandomSampler
class GNNModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = GraphConv(3, 16)
        self.conv2 = GraphConv(16, 2)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, g):
        h = self.conv1(g, g.ndata["attr"].float())
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, "h")
    

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.01)

    def prepare_data(self):
        dgl.data.GINDataset("PROTEINS", self_loop=True)
        
    def setup(self,stage = None):
        # 准备数据集
        dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
        num_examples = len(dataset)
        num_train = int(num_examples * 0.8)
        num_validation = int(num_examples * 0.9)

        self.train_sampler = Subset(dataset,torch.arange(num_train))
        self.test_sampler = Subset(dataset,torch.arange(num_train, num_validation))
        self.validation_sampler = Subset(dataset, torch.arange(num_validation, num_examples))
        self.predict_sampler = Subset(dataset, list(range(10)))  #将predict的集合缩减为10个样本
        
    def train_dataloader(self):
        train_loader = GraphDataLoader(self.train_sampler, batch_size=5, drop_last = False)
        return train_loader

    def val_dataloader(self):
        val_loader = GraphDataLoader(self.validation_sampler, batch_size=5)
        return val_loader
    
    def test_dataloader(self):
        return GraphDataLoader(self.test_sampler, batch_size=5)
    
    def predict_dataloader(self):
        return GraphDataLoader(self.predict_sampler, batch_size=1)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        acc = accuracy(logits.softmax(dim=-1), y, task='multiclass', num_classes=2)
        pbar = {'train_acc': acc}
        return {'loss': loss, 'progress_bar': pbar}


    def test_step(self, batch, batch_idx):
        results = self.training_step(batch, batch_idx)
        results['progress_bar']['test_acc'] = results['progress_bar']['train_acc']
        self.log('test_loss', results['loss'], prog_bar=True)
        self.log('test_acc', results['progress_bar']['test_acc'])
        return results


    def on_test_epoch_end(self):
        avg_test_loss = self.trainer.callback_metrics['test_loss'].item()
        avg_test_acc = self.trainer.callback_metrics['test_acc'].item()
        print("Test epoch ended.")
        print(f"Mean test loss: {avg_test_loss:.4f}")
        print(f"Mean test accuracy: {avg_test_acc:.4f}")

        #使用train_step进行验证
    def validation_step(self, batch, batch_idx):
        results = self.training_step(batch, batch_idx)
        self.log('val_loss',results['loss'], prog_bar=True)
        self.log('val_acc',results['progress_bar']['train_acc'])
        return results
    
    #在每次循环结束后计算其损失以及准确率
    
    def on_validation_epoch_end(self):
        avg_val_loss = self.trainer.callback_metrics['val_loss'].item()
        avg_val_acc = self.trainer.callback_metrics['val_acc'].item()
        print("Validation epoch ended.")
        
        print(f"Mean validation loss: {avg_val_loss:.4f}")

        print(f"Mean validation accuracy: {avg_val_acc:.4f}")

    def predict_step(self, batch, batch_idx):
        x, _ = batch
        logits = self(x)
        preds = logits.softmax(dim=-1)
        return preds

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        prediction = outputs.cpu().argmax(dim=1)
        print("    ")
        print("********predict results***********")
        print(prediction.item())
        


In [68]:
'''from torch.utils.data.sampler import SubsetRandomSampler
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
num_validation = int(num_examples * 0.9)


train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_validation))
validation_sampler = SubsetRandomSampler(torch.arange(num_validation, num_examples))

train_loader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_loader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)
validation_loader = GraphDataLoader(
    dataset, sampler=validation_sampler, batch_size=5, drop_last=False
)
predict_loader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=1, drop_last=False
)'''

model = GNNModel()

# 创建 Trainer 并在 GPU 上训练
trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=[0])
trainer.fit(model)

# 在测试集上测试模型
trainer.test(model)

# 使用模型进行预测
trainer.predict(model)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | conv1 | GraphConv        | 64     | train
1 | conv2 | GraphConv        | 34     | train
2 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
98        Trainable params
0         Non-trainable params
98        Total params
0.000     Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 81.76it/s]Validation epoch ended.
Mean validation loss: 0.7349
Mean validation accuracy: 0.2000
                                                                           

d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 178/178 [00:02<00:00, 74.00it/s, v_num=79]Validation epoch ended.
Mean validation loss: 0.3357
Mean validation accuracy: 1.0000
Epoch 1: 100%|██████████| 178/178 [00:02<00:00, 79.95it/s, v_num=79, val_loss=0.336]Validation epoch ended.
Mean validation loss: 0.8713
Mean validation accuracy: 0.0000
Epoch 2: 100%|██████████| 178/178 [00:02<00:00, 78.23it/s, v_num=79, val_loss=0.871]Validation epoch ended.
Mean validation loss: 0.9731
Mean validation accuracy: 0.0000
Epoch 3: 100%|██████████| 178/178 [00:02<00:00, 78.19it/s, v_num=79, val_loss=0.973]Validation epoch ended.
Mean validation loss: 0.9927
Mean validation accuracy: 0.0000
Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 75.25it/s, v_num=79, val_loss=0.993]Validation epoch ended.
Mean validation loss: 0.9965
Mean validation accuracy: 0.0000
Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 67.80it/s, v_num=79, val_loss=0.997]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 67.61it/s, v_num=79, val_loss=0.997]
                                                  

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 98.16it/s] Test epoch ended.
Mean test loss: 0.9965
Mean test accuracy: 0.0000
Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 96.13it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    0.0
        test_loss           0.9964700937271118
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.


Predicting DataLoader 0:  10%|█         | 1/10 [00:00<00:00, 103.51it/s]    
********predict results***********
0
Predicting DataLoader 0:  20%|██        | 2/10 [00:00<00:00, 103.37it/s]    
********predict results***********
0
Predicting DataLoader 0:  30%|███       | 3/10 [00:00<00:00, 104.81it/s]    
********predict results***********
0
Predicting DataLoader 0:  40%|████      | 4/10 [00:00<00:00, 115.86it/s]    
********predict results***********
0
Predicting DataLoader 0:  50%|█████     | 5/10 [00:00<00:00, 116.93it/s]    
********predict results***********
0
Predicting DataLoader 0:  60%|██████    | 6/10 [00:00<00:00, 117.33it/s]    
********predict results***********
0
Predicting DataLoader 0:  70%|███████   | 7/10 [00:00<00:00, 124.38it/s]    
********predict results***********
0
Predicting DataLoader 0:  80%|████████  | 8/10 [00:00<00:00, 124.80it/s]    
********predict results***********
0
Predicting DataLoader 0:  90%|█████████ | 9/10 [00:00<00:00, 123.86it/s]    
********pre

[tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]]),
 tensor([[0.6308, 0.3692]])]