In [119]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch.nn.functional as F
from torchmetrics.functional import accuracy
from torch.utils.data import random_split, DataLoader, Subset

import os
import dgl
import dgl.data
from dgl.nn import GraphConv
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

In [120]:
class GNNModel(pl.LightningModule):
    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, "h")
    

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.01)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = model(x , x.ndata["attr"].float())
        loss = self.loss(logits, y)
        acc = accuracy(logits.softmax(dim=-1), y, task='multiclass', num_classes=2)
        pbar = {'train_acc': acc}
        return {'loss': loss, 'progress_bar': pbar}


    def test_step(self, batch, batch_idx):
        results = self.training_step(batch, batch_idx)
        results['progress_bar']['test_acc'] = results['progress_bar']['train_acc']
        self.log('test_loss', results['loss'], prog_bar=True)
        self.log('test_acc', results['progress_bar']['test_acc'])
        return results


    def on_test_epoch_end(self):
        avg_test_loss = self.trainer.callback_metrics['test_loss'].item()
        avg_test_acc = self.trainer.callback_metrics['test_acc'].item()
        print("Test epoch ended.")
        print(f"Mean test loss: {avg_test_loss:.4f}")
        print(f"Mean test accuracy: {avg_test_acc:.4f}")

        #使用train_step进行验证
    def validation_step(self, batch, batch_idx):
        results = self.training_step(batch, batch_idx)
        self.log('val_loss',results['loss'], prog_bar=True)
        self.log('val_acc',results['progress_bar']['train_acc'])
        return results
    
    #在每次循环结束后计算其损失以及准确率
    
    def on_validation_epoch_end(self):
        avg_val_loss = self.trainer.callback_metrics['val_loss'].item()
        avg_val_acc = self.trainer.callback_metrics['val_acc'].item()
        print("Validation epoch ended.")
        
        print(f"Mean validation loss: {avg_val_loss:.4f}")

        print(f"Mean validation accuracy: {avg_val_acc:.4f}")

    def predict_step(self, batch, batch_idx):
        x, _ = batch
        logits = model(x,x.ndata["attr"].float())
        preds = logits.softmax(dim=-1)
        return preds

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        prediction = outputs.cpu().argmax(dim=1)
        print("********predict results***********")
        print(prediction.item())
        


In [121]:
from torch.utils.data.sampler import SubsetRandomSampler
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
num_validation = int(num_examples * 0.9)
print(num_examples)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_validation))
validation_sampler = SubsetRandomSampler(torch.arange(num_validation, num_examples))

train_loader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_loader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)
validation_loader = GraphDataLoader(
    dataset, sampler=validation_sampler, batch_size=5, drop_last=False
)
predict_loader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=1, drop_last=False
)

model = GNNModel(dataset.dim_nfeats, 16, dataset.gclasses)

# 创建 Trainer 并在 GPU 上训练
trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=[0])
trainer.fit(model,train_loader,validation_loader)

# 在测试集上测试模型
trainer.test(model,test_loader)

# 使用模型进行预测
trainer.predict(model,predict_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | conv1 | GraphConv        | 64     | train
1 | conv2 | GraphConv        | 34     | train
2 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
98        Trainable params
0         Non-trainable params
98        Total params
0.000     Total estimated model params size (MB)


1113
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 75.15it/s]Validation epoch ended.
Mean validation loss: 0.7019
Mean validation accuracy: 0.4000
                                                                           

d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 178/178 [00:02<00:00, 77.53it/s, v_num=57]Validation epoch ended.
Mean validation loss: 1.2200
Mean validation accuracy: 0.0000
Epoch 1: 100%|██████████| 178/178 [00:02<00:00, 78.97it/s, v_num=57, val_loss=1.220]Validation epoch ended.
Mean validation loss: 1.0433
Mean validation accuracy: 0.1607
Epoch 2: 100%|██████████| 178/178 [00:02<00:00, 78.86it/s, v_num=57, val_loss=1.040]Validation epoch ended.
Mean validation loss: 1.2952
Mean validation accuracy: 0.0446
Epoch 3: 100%|██████████| 178/178 [00:02<00:00, 79.08it/s, v_num=57, val_loss=1.300]Validation epoch ended.
Mean validation loss: 1.2869
Mean validation accuracy: 0.0893
Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 78.11it/s, v_num=57, val_loss=1.290]Validation epoch ended.
Mean validation loss: 1.6113
Mean validation accuracy: 0.0000
Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 70.31it/s, v_num=57, val_loss=1.610]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 178/178 [00:02<00:00, 70.22it/s, v_num=57, val_loss=1.610]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.



Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 106.84it/s]Test epoch ended.
Mean test loss: 1.6358
Mean test accuracy: 0.0000
Testing DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 105.40it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
d:\learningsoft\envs\pytorch_gpu\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=17` in the `DataLoader` to improve performance.



────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc                    0.0
        test_loss            1.635769248008728
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Predicting DataLoader 0:   1%|          | 1/111 [00:00<00:01, 99.86it/s]********predict results***********
tensor([0])
Predicting DataLoader 0:   2%|▏         | 2/111 [00:00<00:01, 101.46it/s]********predict results***********
tensor([0])
Predicting DataLoader 0:   3%|▎         | 3/111 [00:00<00:01, 105.08it/s]********predict results***********
tensor([0])
Predicting DataLoader 0:   4%|▎         | 4/111 [00:00<00:00, 107.24it/s]********predict results***********
tensor([0])
Predicting DataLoader 0:  

[tensor([[0.9159, 0.0841]]),
 tensor([[0.7961, 0.2039]]),
 tensor([[0.8400, 0.1600]]),
 tensor([[0.7167, 0.2833]]),
 tensor([[0.9007, 0.0993]]),
 tensor([[0.5866, 0.4134]]),
 tensor([[0.7491, 0.2509]]),
 tensor([[0.5860, 0.4140]]),
 tensor([[0.5232, 0.4768]]),
 tensor([[0.8872, 0.1128]]),
 tensor([[0.6253, 0.3747]]),
 tensor([[0.9012, 0.0988]]),
 tensor([[0.8072, 0.1928]]),
 tensor([[0.5236, 0.4764]]),
 tensor([[0.8626, 0.1374]]),
 tensor([[0.8828, 0.1172]]),
 tensor([[0.5276, 0.4724]]),
 tensor([[0.8720, 0.1280]]),
 tensor([[0.8917, 0.1083]]),
 tensor([[0.8185, 0.1815]]),
 tensor([[0.5860, 0.4140]]),
 tensor([[0.5226, 0.4774]]),
 tensor([[0.9140, 0.0860]]),
 tensor([[0.8385, 0.1615]]),
 tensor([[0.7653, 0.2347]]),
 tensor([[0.7202, 0.2798]]),
 tensor([[0.9357, 0.0643]]),
 tensor([[0.9078, 0.0922]]),
 tensor([[0.7616, 0.2384]]),
 tensor([[0.5232, 0.4768]]),
 tensor([[0.5867, 0.4133]]),
 tensor([[0.7382, 0.2618]]),
 tensor([[0.8821, 0.1179]]),
 tensor([[0.9041, 0.0959]]),
 tensor([[0.89