In [None]:
%%capture
!pip install pytorch_lightning
!pip install torch_geometric

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric import datasets, loader, transforms as T
from torch_geometric.data import Data
import pytorch_lightning as pl
from typing import Literal

In [None]:
class AddAdjointEdgeEigenvectorPE(T.BaseTransform):
    """Attach Laplacian eigenvector PEs of the adjoint (line) graph as per-edge features.

    Every edge of ``data`` becomes a node of the adjoint graph. There is an adjoint
    edge ``i -> j`` whenever edge ``i`` ends where edge ``j`` starts
    (``edge_index[1, i] == edge_index[0, j]``).

    The encoding computed by ``T.AddLaplacianEigenvectorPE`` on that graph is stored
    in ``data.adj``. It has one row per edge, aligned with the columns of
    ``data.edge_index``.

    Args:
        k: number of eigenvectors, forwarded to ``T.AddLaplacianEigenvectorPE``.
        is_undirected: forwarded to ``T.AddLaplacianEigenvectorPE``. The adjoint
            graph built here is not symmetric in general (``i -> j`` does not imply
            ``j -> i``), so the default ``False`` keeps the general eigen-solver.
            Pass ``True`` only if you know the adjoint graph is symmetric.
    """

    def __init__(
            self,
            k,
            is_undirected: bool = False
    ):
        self.is_undirected = is_undirected
        self.enc = T.AddLaplacianEigenvectorPE(
            k=k, attr_name='adjoint_edge_pe', is_undirected=is_undirected)

    def forward(self, data):
        assert data.edge_index is not None
        assert data.num_nodes is not None

        src, dst = data.edge_index
        # Pairs (i, j) with dst[i] == src[j]: edge i is followed by edge j.
        # Builds an E x E boolean mask, which is fine for small graphs such as CSL.
        adj_edge_index = torch.argwhere(dst.reshape(-1, 1) == src).t()
        adjoint = Data(
            edge_index=adj_edge_index,
            num_nodes=data.edge_index.size(1),  # one adjoint node per original edge
        )
        data.adj = self.enc(adjoint).adjoint_edge_pe
        return data

In [None]:
class AddNodeFeature(T.BaseTransform):
    """Overwrite ``data.x`` with an int32 vector of zeros, one entry per node.

    This gives every node the same integer "type" id, which the GAT model looks up
    in its ``nn.Embedding``.
    """

    def forward(self, data):
        assert data.edge_index is not None
        node_count = data.num_nodes
        assert node_count is not None
        data.x = torch.zeros(node_count, dtype=torch.int32)
        return data

In [None]:
class CSLDataModule(pl.LightningDataModule):
    def __init__(self, split: Literal[0, 1, 2, 3, 4], pos_enc, k: int=None):
        super().__init__()
        self.pe = pos_enc
        self.k = k
        self.split = split
        self.dataset = datasets.GNNBenchmarkDataset(
            root="~/data",
            name='CSL',
            split='train',
            transform=T.Compose([
                T.AddLaplacianEigenvectorPE(k=self.k, attr_name='pe'),
                AddAdjointEdgeEigenvectorPE(k=self.k),
                AddNodeFeature(),
                # AddHyperEdgeEigenvectorPE(k=self.k, attr_name='hypr'),
                ]) if self.pe else None,)
        g = torch.Generator()
        g.manual_seed(123)
        self.order = torch.stack(
            [15 * i + torch.randperm(15, generator=g) for i in range(10)])

    def setup(self, stage):
        match stage:
            case 'fit':
                self.train_loader = loader.DataLoader(
                    dataset=self.dataset[self.order[:, [
                        i for i in range(15)
                        if i < 5 * self.split or i >= 5 * self.split + 3
                    ]].flatten()],
                    batch_size=5,
                    shuffle=True,
                    num_workers=2,
                )
                self.val_loader = loader.DataLoader(
                    dataset=self.dataset[self.order[:, 5 *
                                                    self.split:5 * self.split +
                                                    3].flatten()],
                    batch_size=5,
                    num_workers=2,
                )
            case 'test':
                self.test_loader = loader.DataLoader(
                    dataset=self.dataset[self.order[:, 5 *
                                                    self.split:5 * self.split +
                                                    3].flatten()],
                    batch_size=5,
                    num_workers=2,
                )

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader

In [None]:
class GAT(nn.Module):
    """Stack of ``GATLayer`` (GATv2) blocks followed by an MLP classification head.

    Args:
        net_params: hyper-parameter dict. Keys read:
            - ``in_dim``: vocabulary size of the integer node feature.
            - ``hidden_dim``, ``out_dim``, ``n_classes``, ``n_heads``.
            - ``in_feat_dropout``, ``dropout``.
            - ``L``: number of layers.
            - ``readout``, ``batch_norm``, ``residual``, ``device``.
            - ``pos_enc``, ``pos_enc_dim``, ``edge_enc``, ``edge_enc_dim``.
            ``readout`` and ``device`` are stored but not otherwise used here.
    """

    def __init__(self, net_params):
        super().__init__()

        in_dim_node = net_params['in_dim']  # node feature is an integer embedding index
        hidden_dim = net_params['hidden_dim']
        out_dim = net_params['out_dim']
        n_classes = net_params['n_classes']
        num_heads = net_params['n_heads']
        in_feat_dropout = net_params['in_feat_dropout']
        dropout = net_params['dropout']
        n_layers = net_params['L']

        self.readout = net_params['readout']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        self.dropout = dropout
        self.n_classes = n_classes
        self.device = net_params['device']
        self.pos_enc = net_params['pos_enc']
        pos_enc_dim = net_params['pos_enc_dim']
        self.edge_enc = net_params['edge_enc']
        self.edge_enc_dim = net_params['edge_enc_dim']

        if self.pos_enc:
            self.embedding_pos_enc = nn.Linear(pos_enc_dim, hidden_dim * num_heads)
        self.embedding = nn.Embedding(in_dim_node, hidden_dim * num_heads)

        if self.edge_enc:
            # Edge PEs have edge_enc_dim columns, and the GATv2 layers are built
            # with edge_dim=edge_enc_dim, so the projection keeps that width.
            self.edge_embedding = nn.Linear(self.edge_enc_dim, self.edge_enc_dim)

        self.in_feat_dropout = nn.Dropout(in_feat_dropout)

        self.layers = nn.ModuleList([
            GATLayer(hidden_dim * num_heads,
                     hidden_dim,
                     num_heads,
                     dropout,
                     self.batch_norm,
                     self.residual,
                     edge_dim=self.edge_enc_dim) for _ in range(n_layers - 1)
        ])
        # Final layer uses a single head of width out_dim.
        self.layers.append(
            GATLayer(hidden_dim * num_heads,
                     out_dim,
                     1,
                     dropout,
                     self.batch_norm,
                     self.residual,
                     edge_dim=self.edge_enc_dim))

        self.mlp = nn.Sequential(nn.Linear(out_dim, out_dim // 2, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(out_dim // 2, n_classes, bias=True))

    def forward(self,
                x,
                pe,
                edge_index,
                edge_attr=None,
                agg: Literal[None, 'mean'] = None,
                batch=None):
        """Compute class logits.

        Args:
            x: integer node features (embedding indices).
            pe: node positional encodings; only used when ``pos_enc`` is enabled.
            edge_index: graph connectivity passed to every GATv2 layer.
            edge_attr: per-edge encodings; required when ``edge_enc`` is enabled.
            agg: ``'mean'`` to mean-pool nodes per graph (using ``batch``) before
                the MLP, or ``None`` for per-node logits.
            batch: graph assignment vector used by the mean pooling.

        Returns:
            Logits with ``n_classes`` columns: one row per graph when
            ``agg == 'mean'``, otherwise one row per node.

        Raises:
            ValueError: if ``agg`` is neither ``None`` nor ``'mean'``.
        """
        if agg not in (None, 'mean'):
            raise ValueError(f"agg must be None or 'mean', got {agg!r}")

        # input embedding
        h = self.embedding(x)
        if self.pos_enc:
            h = h + self.embedding_pos_enc(pe)
        h = self.in_feat_dropout(h)
        if self.edge_enc:
            assert edge_attr is not None
            edge_attr = self.edge_embedding(edge_attr)

        # GAT
        for conv in self.layers:
            h = conv(x=h, edge_index=edge_index, edge_attr=edge_attr)

        if agg == 'mean':
            h = gnn.global_mean_pool(h, batch)

        # output
        return self.mlp(h)


class GATLayer(nn.Module):
    """A single GATv2 convolution with optional batch norm, activation and residual.

    The convolution output is flattened to ``out_dim * num_heads`` features per
    node. The residual connection is kept only when that width equals ``in_dim``;
    otherwise it is switched off.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 num_heads,
                 dropout,
                 batch_norm,
                 residual=False,
                 activation=F.elu,
                 edge_dim=None):
        super().__init__()
        concat_dim = out_dim * num_heads
        # A residual sum needs matching widths; disable it otherwise.
        self.residual = residual if in_dim == concat_dim else False
        self.activation = activation
        self.batch_norm = batch_norm
        self.gatconv = gnn.GATv2Conv(
            in_channels=in_dim,
            out_channels=out_dim,
            heads=num_heads,
            dropout=dropout,
            edge_dim=edge_dim,
        )
        if self.batch_norm:
            self.batchnorm_h = nn.BatchNorm1d(concat_dim)

    def forward(self, x, edge_index, edge_attr=None):
        h = self.gatconv(x, edge_index, edge_attr=edge_attr).flatten(1)
        if self.batch_norm:
            h = self.batchnorm_h(h)
        if self.activation:
            h = self.activation(h)
        return x + h if self.residual else h

In [None]:
class Model(pl.LightningModule):
    """LightningModule that trains ``GAT`` for graph classification.

    Hyper-parameters live in ``self.net_params`` and the network in ``self.model``.
    Both are created in :meth:`configure_model`.
    """

    def configure_model(self):
        """Create ``self.net_params`` and ``self.model``, only on the first call.

        Lightning invokes this hook at the start of every fit/validate/test/predict
        run and requires it to be idempotent. Without the early return,
        ``trainer.test`` after ``trainer.fit`` would replace the trained network
        with a freshly initialised one.
        """
        if getattr(self, 'model', None) is not None:
            return
        # from https://github.com/graphdeeplearning/benchmarking-gnns/blob/master/configs/SBMs_node_clustering_GAT_CLUSTER_100k.json
        self.net_params = {
            "L": 4,
            "n_heads": 8,
            "hidden_dim": 19,
            "out_dim": 152,
            "residual": True,
            "readout": "mean",
            "in_feat_dropout": 0.0,
            "dropout": 0.0,
            "batch_norm": True,
            "self_loop": False,
            "pos_enc": False,
            "pos_enc_dim": 20,
            "edge_enc": True,
            "hypergraph_edge_enc": False,
            "edge_enc_dim": 20,
        }
        # Only stored by GAT; tensors created in this module follow the batch's device.
        self.net_params['device'] = 'cuda'
        self.net_params['in_dim'] = 1  # node feature is the constant 0 from AddNodeFeature
        self.net_params['n_classes'] = 10
        self.model = GAT(self.net_params)

    def configure_optimizers(self):
        return torch.optim.AdamW(
            self.parameters(),
            lr=0.001,
            weight_decay=0.0,
        )

    @staticmethod
    def _random_sign_flip(x):
        """Multiply each column of ``x`` by an independent random +1/-1.

        Eigenvectors are only defined up to sign, so random flips act as training
        augmentation.
        """
        sign = torch.randint(0, 2, (1, x.size(-1)), device=x.device) * 2 - 1
        return x * sign

    def forward(self, batch):
        """Return per-graph class logits for a PyG batch (mean readout).

        Random sign flips of the encodings are applied only in training mode, so
        validation and test metrics are deterministic. The batch is not mutated.

        Raises:
            NotImplementedError: if both ``edge_enc`` and ``hypergraph_edge_enc`` are set.
        """
        use_pe = self.net_params['pos_enc']
        use_edge_pe = self.net_params['edge_enc']
        if use_edge_pe and self.net_params['hypergraph_edge_enc']:
            raise NotImplementedError('This has not been implemented')

        pe = batch.pe if use_pe else None
        edge_attr = batch.adj if use_edge_pe else None
        if self.training:
            if pe is not None:
                pe = self._random_sign_flip(pe)
            if edge_attr is not None:
                edge_attr = self._random_sign_flip(edge_attr)

        return self.model(
            batch.x,
            pe,
            batch.edge_index,
            edge_attr=edge_attr,
            batch=batch.batch,
            agg="mean",
        )

    def loss(self, batch):
        """Return ``(loss, acc)`` for one batch.

        ``loss`` is weighted cross-entropy for unbalanced classes. For each class c
        among the V labels of the batch:
        - if c is present with count n_c, its weight is ``(V - n_c) / V``;
        - if c is absent, its weight is 0.
        If the batch holds a single class, that weight would be 0 and the weighted
        mean would be 0/0 = NaN, so unweighted cross-entropy is used instead.

        ``acc`` is the fraction of graphs whose arg-max prediction equals the label.
        """
        pred = self(batch)
        label = batch.y

        V = label.size(0)
        cluster_sizes = torch.bincount(label, minlength=self.net_params['n_classes'])
        present = cluster_sizes > 0
        if present.sum() > 1:
            weight = (V - cluster_sizes).float() / V * present.float()
        else:
            weight = None

        loss = F.cross_entropy(pred, label, weight)
        acc = (pred.argmax(-1) == label).float().mean()
        return loss, acc

    def _shared_step(self, batch, stage):
        """Compute and log ``'<stage> loss'`` / ``'<stage> acc'``; return the loss."""
        loss, acc = self.loss(batch)
        # Use the real number of graphs so epoch-level averages weight batches correctly.
        self.log(f'{stage} loss', loss, batch_size=batch.num_graphs)
        self.log(f'{stage} acc', acc, batch_size=batch.num_graphs)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

In [None]:
# CSL data: split 0, with k = 20 eigenvectors for the positional encodings.
dataloader = CSLDataModule(
    split=0,
    pos_enc=True,
    k=20,
)
# NOTE(review): this W&B logger is not used while the Trainer is built with logger=False.
logger = pl.loggers.WandbLogger(
    name='GAT-CSL-node',
    project='Edge-Encodings',
)

In [None]:
# Seed Python/NumPy/torch RNGs (and dataloader workers) so weight init, shuffling
# and the random eigenvector sign flips are reproducible under Restart & Run All.
pl.seed_everything(0, workers=True)
model = Model()
trainer = pl.Trainer(
    accelerator='gpu',
    logger=False,
    max_epochs=100,
    max_time='00:00:01:00',  # DD:HH:MM:SS -> stop training after one minute
    enable_checkpointing=False,
    log_every_n_steps=50,
)
trainer.fit(model, dataloader)
trainer.test(model, dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | GAT  | 213 K  | train
---------------------------------------
213 K     Trainable params
0         Non-trainable params
213 K     Total params
0.854     Total estimated model params size (MB)
37        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Time limit reached. Elapsed time is 0:01:00. Signaling Trainer to stop.


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test loss': 2.444157361984253, 'test acc': 0.20000000298023224}]

tensor([0, 0, 0, 1, 1])