## 📦 Packages and Basic Setup
---

In [None]:
%%capture
import os

import torch

# torch.__version__ looks like "2.1.0+cu121"; the two halves select the
# matching PyG wheel index in the pip command below. Builds without a local
# version tag (e.g. plain "2.1.0") have nothing after "+", so fall back to the
# "cpu" wheel index instead of raising IndexError.
torch_version = torch.__version__.split("+")
os.environ["TORCH"] = torch_version[0]
os.environ["CUDA"] = torch_version[1] if len(torch_version) > 1 else "cpu"

!pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
!pip install torch-geometric
!pip install -q --upgrade wandb

import torch.nn.functional as F
import torch_geometric.transforms as T
import wandb
from torch.nn import BatchNorm1d as BN
from torch.nn import Linear, ReLU, Sequential
from torch.optim import Adam
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool
from torch_geometric.utils import degree

# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# @title ⚙ Configuration

# Paste your W&B API key here ("..." is only a placeholder). Avoid committing
# the notebook with a real key filled in.
os.environ["WANDB_API_KEY"] = "..."

# Start a W&B run in project "GIN" under the "graph-neural-networks" entity.
wandb.init(project="GIN", entity="graph-neural-networks")

# `config` is an alias for the run's wandb.config object; the hyperparameters
# below are stored on it as attributes. The trailing `# @param` markers are
# Colab form annotations and must stay on the same line as the assignment.
config = wandb.config
config.lr = 0.01  # @param {type: "number"}
config.num_layers = 5  # @param {type: "number"}
config.batch_size = 32  # @param {type: "number"}
config.latent_dim = 32  # @param {type: "number"}
config.num_epochs = 50  # @param {type: "number"}
# NOTE(review): `config` is wandb.config itself, so this hands the object back
# to its own update(); presumably redundant — confirm before removing.
wandb.config.update(config)

## 💿 The Dataset
---

In [None]:
# Load the REDDIT-BINARY graph-classification dataset from the TU collection,
# stored under data/, using the original (not the "cleaned") variant.
dataset = TUDataset(root="data/", name="REDDIT-BINARY", cleaned=False)
# Clear edge attributes on the dataset's underlying storage; nothing in this
# notebook reads them.
dataset.data.edge_attr = None

In [None]:
class NormalizedDegree:
    """Transform that replaces node features with the standardized node degree.

    Args:
        mean: Value subtracted from every node degree.
        std: Value the centered degree is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        """Overwrite ``data.x`` with a one-column degree feature; return ``data``."""
        # Pass num_nodes explicitly: without it, degree() sizes its output from
        # the largest index that appears in edge_index, so isolated nodes with
        # the highest indices are dropped and data.x ends up with fewer rows
        # than the graph has nodes.
        deg = degree(data.edge_index[0], data.num_nodes, dtype=torch.float)
        deg = (deg - self.mean) / self.std
        # One feature column per node.
        data.x = deg.view(-1, 1)
        return data


# Only synthesize node features when the dataset ships without any.
if dataset.data.x is None:
    # Per-graph node degrees, counted from the first row of edge_index.
    # num_nodes is passed so that isolated nodes with the highest indices are
    # still counted (as degree 0), matching what T.OneHotDegree does and
    # keeping the mean/std below representative of every node.
    degs = [
        degree(graph.edge_index[0], graph.num_nodes, dtype=torch.long)
        for graph in dataset
    ]
    # Degrees are non-negative, so 0 is a safe floor; graphs without nodes
    # contribute nothing.
    max_degree = max((d.max().item() for d in degs if d.numel() > 0), default=0)

    print(max_degree)

    if max_degree < 1000:
        # Small degree range: one-hot encode each node's degree.
        dataset.transform = T.OneHotDegree(max_degree)
    else:
        # Large degree range: use one standardized degree feature instead of a
        # very wide one-hot vector.
        deg = torch.cat(degs, dim=0).to(torch.float)
        mean, std = deg.mean().item(), deg.std().item()
        dataset.transform = NormalizedDegree(mean, std)

In [None]:
# Shuffled mini-batches over the whole dataset. The batch size is read from
# the W&B config so the value logged with the run is the one actually used.
train_loader = DataLoader(dataset, config.batch_size, shuffle=True)

In [None]:
def num_graphs(data):
    """Return ``data.num_graphs`` when present, else the first dimension of ``data.x``."""
    if not hasattr(data, "num_graphs"):
        # No graph count on the object: fall back to the size of x along dim 0.
        return data.x.size(0)
    return data.num_graphs

## ✍️ Model Architecture & Training
---

In [None]:
class GIN0(torch.nn.Module):
    """Graph classifier: stacked GINConv layers, mean pooling, two-layer head.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``; these
            size the first convolution's input and the final linear output.
        num_layers: Total number of GINConv layers, the input layer included.
        hidden: Width used for every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        # Input layer: node features -> hidden width.
        self.conv1 = self._gin_layer(dataset.num_features, hidden)
        # The remaining num_layers - 1 layers keep the hidden width.
        self.convs = torch.nn.ModuleList(
            [self._gin_layer(hidden, hidden) for _ in range(num_layers - 1)]
        )
        # Classification head applied to the pooled graph representation.
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    @staticmethod
    def _gin_layer(in_channels, out_channels):
        """Build a GINConv (train_eps=False) around a Linear-ReLU-Linear-ReLU-BN MLP."""
        mlp = Sequential(
            Linear(in_channels, out_channels),
            ReLU(),
            Linear(out_channels, out_channels),
            ReLU(),
            BN(out_channels),
        )
        return GINConv(mlp, train_eps=False)

    def reset_parameters(self):
        """Re-initialize every convolution and both linear layers."""
        for module in (self.conv1, *self.convs, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data):
        """Return log-softmax class scores, one row per pooled graph."""
        edge_index = data.edge_index
        h = data.x
        for conv in (self.conv1, *self.convs):
            h = conv(h, edge_index)
        # Collapse node embeddings to one vector per graph via data.batch.
        pooled = global_mean_pool(h, data.batch)
        hidden = F.relu(self.lin1(pooled))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.lin2(hidden)
        return F.log_softmax(logits, dim=-1)


# Size the network from the W&B config so the hyperparameters logged with the
# run describe the model that is actually trained.
model = GIN0(dataset, config.num_layers, config.latent_dim)

## Training
---

In [None]:
def train(model, optimizer, loader):
    """Run one optimization pass over ``loader``.

    Returns the summed, size-weighted batch losses divided by
    ``len(loader.dataset)``.
    """
    model.train()

    loss_sum = 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        log_probs = model(batch)
        batch_loss = F.nll_loss(log_probs, batch.y.view(-1))
        batch_loss.backward()
        optimizer.step()
        # Weight each batch loss by the batch's size so the final division
        # yields a per-item average rather than a per-batch one.
        loss_sum += batch_loss.item() * num_graphs(batch)
    return loss_sum / len(loader.dataset)


def eval(model, loader):
    """Return the number of correct predictions divided by ``len(loader.dataset)``.

    NOTE: the name shadows the ``eval`` builtin; it is kept because callers in
    this notebook use it.
    """
    model.eval()

    hits = 0
    for batch in loader:
        batch = batch.to(device)
        # Inference only: skip gradient tracking for the forward pass.
        with torch.no_grad():
            _, predicted = model(batch).max(1)
        hits += predicted.eq(batch.y.view(-1)).sum().item()
    return hits / len(loader.dataset)

In [None]:
# Move the model to the target device and start from freshly initialized weights.
model.to(device).reset_parameters()
# Learning rate and epoch count are read from the W&B config so the values
# logged with the run are the ones actually used for training.
optimizer = Adam(model.parameters(), lr=config.lr)


for epoch in range(1, config.num_epochs + 1):
    train_loss = train(model, optimizer, train_loader)
    # Accuracy is measured on the same loader that was used for training.
    train_acc = eval(model, train_loader)
    print("Training Loss: ", train_loss)
    print("Training Acc: ", train_acc)
    wandb.log({"Training Loss": train_loss, "Training Accuracy": train_acc})

# Store the last epoch's metrics as the run summary, then close the W&B run.
wandb.run.summary["Training Loss"] = train_loss
wandb.run.summary["Training Accuracy"] = train_acc
wandb.finish()