#### Installing necessary dependencies in Colab

In [None]:
# NOTE(review): package versions are unpinned; pin them once a known-good set is established.
%pip install -Uq wandb
%pip install -q wget trimesh pytorch_lightning
# PyG's compiled extensions must match the installed torch/CUDA build, so the wheel index is
# derived from the running torch instead of being hard-coded to torch-1.12.0+cu113.
import torch
%pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-{torch.__version__}.html
%pip install -q torch-geometric

### Imports

In [None]:
import glob
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
import torch_geometric.nn as gnn
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from sklearn.preprocessing import LabelEncoder

import trimesh
import networkx as nx
import numpy as np
from tqdm.notebook import tqdm

#### Logging in to Weights & Biases (wandb)

In [None]:
import wandb
# authenticate with Weights & Biases so the WandbLogger created in model_pipeline can sync the run
wandb.login()

#### Downloading dataset

In [None]:
import os

import wget

MODELNET10_URL = 'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip'
MODELNET10_ZIP = 'ModelNet10.zip'

# only download when the archive is missing, so re-running the notebook does not fetch it again
if not os.path.exists(MODELNET10_ZIP):
    wget.download(MODELNET10_URL, out=MODELNET10_ZIP)

'ModelNet10.zip'

In [None]:
# -n: never overwrite existing files, so a re-run does not block on unzip's interactive "replace?" prompt
# -q: quiet, keeps thousands of extracted filenames out of the notebook output
! unzip -qn ModelNet10.zip

### GNN Classifier

#### GNN Conv Block consisting of GAT Layer

In [None]:
class GNNConvBlock(nn.Module):
    """GAT message passing -> linear projection -> BatchNorm -> LeakyReLU -> TopK pooling, plus a graph-level readout."""

    def __init__(self, in_channels: int, out_channels: int, pooling_ratio: float, heads: int = 3, dropout: float = 0.2):
        """
        GNN Conv Block consisting of GAT Layer

        :param int in_channels: number of input node features
        :param int out_channels: number of output node features (also the per-head size of the GAT layer)
        :param float pooling_ratio: ratio of nodes kept by the TopKPooling layer
        :param int heads: number of GAT attention heads; their outputs are concatenated
        :param float dropout: dropout probability passed to GATConv
        """
        super().__init__()

        # message passing using GAT layer; concatenated heads give heads * out_channels features per node
        self.conv = gnn.GATConv(in_channels, out_channels, heads=heads, dropout=dropout)
        # project the concatenated heads back to out_channels
        self.linear = nn.Linear(heads * out_channels, out_channels)
        self.bn = gnn.BatchNorm(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.2)
        # top k pooling layer which removes nodes based on pooling ratio
        self.pool = gnn.TopKPooling(out_channels, ratio=pooling_ratio)

    def forward(self, x: torch.Tensor, edge_idx: torch.Tensor, batch_idx: torch.Tensor):
        """
        Forward method for GNN Layer

        :param torch.Tensor x: node features of size (num_nodes in batch, num_features)
        :param torch.Tensor edge_idx: tensor representing connections between nodes of size (2, num_node_connections)
        :param torch.Tensor batch_idx: tensor representing which node belongs to which graph in the batch
        :return: tuple ``(x, x_pooled, edge_idx, batch_idx)`` where ``x``, ``edge_idx`` and ``batch_idx``
            describe the graph after pooling and ``x_pooled`` is the per-graph readout of size
            (num_graphs, 2 * out_channels): global max pooling concatenated with global mean pooling
        """
        x = self.conv(x, edge_idx)
        x = self.linear(x)
        # batch-normalize the projected node features before the non-linearity
        x = self.bn(x)
        x = self.leaky_relu(x)
        # TopKPooling also returns edge attributes, kept-node permutation and scores, which are unused here
        x, edge_idx, _, batch_idx, _, _ = self.pool(x, edge_idx, None, batch_idx)

        # calculating intermediate output which will be summed up with other intermediate outputs
        x_pooled = torch.cat([gnn.global_max_pool(x, batch_idx), gnn.global_mean_pool(x, batch_idx)], dim=1)

        return x, x_pooled, edge_idx, batch_idx

#### GNN Classifier Model

In [None]:
class GNNClassifier(pl.LightningModule):
    """Graph classifier: stacked GNNConvBlocks whose per-block readouts are summed and fed to an MLP head."""

    def __init__(self, in_channels: int, embedding_size: int, out_classes: int, num_layers: int, lr: float, weight_decay: float):
        """
        GNN Classifier

        :param int in_channels: number of input node features
        :param int embedding_size: size of intermediate node embeddings
        :param int out_classes: number of output classes
        :param int num_layers: number of GNNConvBlocks stacked after the initial block
        :param float lr: learning rate
        :param float weight_decay: weight decay regularization
        """
        super().__init__()

        # architecture
        self.initial_conv = GNNConvBlock(in_channels, embedding_size, pooling_ratio=0.5)
        self.conv_blocks = nn.ModuleList(
            [GNNConvBlock(embedding_size, embedding_size, pooling_ratio=0.5) for _ in range(num_layers)]
        )
        # each block's readout concatenates max- and mean-pooling, hence embedding_size * 2 inputs
        self.classifier = nn.Sequential(
            nn.Linear(embedding_size * 2, 1024),
            nn.ReLU(),
            nn.Linear(1024, out_classes),
        )

        # stores the constructor arguments in self.hparams (lr / weight_decay are read in configure_optimizers)
        self.save_hyperparameters()

        # metrics: one instance per stage so their accumulated states stay separate.
        # Accuracy is micro-averaged (fraction of correctly classified graphs); F1 is macro-averaged,
        # because micro-averaged F1 is identical to accuracy for single-label multiclass targets.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=out_classes, average="micro")
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=out_classes, average="micro")
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=out_classes, average="micro")

        self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=out_classes, average="macro")
        self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=out_classes, average="macro")
        self.test_f1 = torchmetrics.F1Score(task="multiclass", num_classes=out_classes, average="macro")

    def forward(self, x: torch.Tensor, edge_idx: torch.Tensor, batch_idx: torch.Tensor):
        """
        Forward pass of GNN Classifier

        :param torch.Tensor x: node features of size (num_nodes in batch, num_features)
        :param torch.Tensor edge_idx: tensor representing connections between nodes of size (2, num_node_connections)
        :param torch.Tensor batch_idx: tensor representing which node belongs to which graph in the batch
        :return: class logits of size (num_graphs, out_classes)
        """
        # list for storing intermediate outputs (one graph-level readout per block)
        intermediate_outputs = []

        x, x_pooled, edge_idx, batch_idx = self.initial_conv(x, edge_idx, batch_idx)
        intermediate_outputs.append(x_pooled)

        for layer in self.conv_blocks:
            x, x_pooled, edge_idx, batch_idx = layer(x, edge_idx, batch_idx)
            intermediate_outputs.append(x_pooled)

        # summing intermediate outputs to get input for classifier module
        features = sum(intermediate_outputs)

        return self.classifier(features)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams["lr"], weight_decay=self.hparams["weight_decay"])

    def _shared_step(self, batch, stage: str):
        """
        Run the model on a batch, update the stage's metrics and log loss / accuracy / F1.

        :param batch: PyG mini-batch providing ``x``, ``edge_index``, ``batch``, ``y`` and ``num_graphs``
        :param str stage: "train", "val" or "test"; selects the metric objects and prefixes the logged keys
        :return: tuple ``(loss, logits)``
        """
        predicted = self(batch.x, batch.edge_index, batch.batch)
        loss = F.cross_entropy(predicted, batch.y)

        acc = getattr(self, f"{stage}_acc")
        f1 = getattr(self, f"{stage}_f1")
        acc(predicted, batch.y)
        f1(predicted, batch.y)

        # epoch averages are weighted by the number of graphs (not nodes) in the batch
        num_graphs = batch.num_graphs
        self.log(f"{stage}/loss", loss, on_epoch=True, batch_size=num_graphs)
        self.log(f"{stage}/acc", acc, on_epoch=True, batch_size=num_graphs)
        self.log(f"{stage}/f1", f1, on_epoch=True, batch_size=num_graphs)

        return loss, predicted

    def training_step(self, batch, batch_idx):
        loss, _ = self._shared_step(batch, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        _, predicted = self._shared_step(batch, "val")
        return predicted

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

### Training

#### Dataset

In [None]:
class ModelNet10Dataset(Dataset):
    """PyG dataset that turns ModelNet10 .off meshes into graphs (vertices as nodes, mesh edges as edges)."""

    def __init__(self, filepaths: list, labels: list = None):
        """
        Dataset module for ModelNet10

        :param list filepaths: list of filepaths to models, laid out as ``.../<class>/<split>/<model>.off``
        :param list labels: optional class names to encode against. Pass the training set's ``labels``
            when building the validation set so both share one encoding. Defaults to the classes found
            in ``filepaths``.
        """
        super().__init__()

        self.filepaths = filepaths
        # sorted so the class list is deterministic across runs
        if labels is not None:
            self.labels = sorted(set(labels))
        else:
            self.labels = sorted({self._label_from_path(file) for file in filepaths})
        self.label_encoder = LabelEncoder().fit(self.labels)

    @staticmethod
    def _label_from_path(file) -> str:
        """Return the class name: the directory two levels above the file (``<class>/<split>/<model>.off``)."""
        return Path(file).parent.parent.name

    def len(self):
        return len(self.filepaths)

    def load_off_mesh(self, file):
        """Load a model stored in OFF format (the format used by ModelNet) as a ``trimesh.Trimesh``."""
        # the context manager closes the file handle once parsing is done
        with open(file) as off_file:
            off = trimesh.exchange.off.load_off(off_file)
        return trimesh.Trimesh(off["vertices"], off["faces"])

    def get(self, idx):
        """
        Build the graph for sample ``idx``.

        :return: ``Data`` with vertex coordinates as ``x``, mesh edges as ``edge_index`` and the
            encoded class as ``y`` (a one-element long tensor)
        """
        filepath = self.filepaths[idx]
        mesh = self.load_off_mesh(filepath)
        label = self._label_from_path(filepath)

        x = torch.tensor(mesh.vertices, dtype=torch.float32)
        # trimesh lists edges as (num_edges, 2); PyG expects a contiguous (2, num_edges) tensor
        edge_index = torch.tensor(mesh.edges, dtype=torch.long).t().contiguous()
        encoded_label = torch.tensor(self.label_encoder.transform([label]), dtype=torch.long)

        return Data(x=x, edge_index=edge_index, y=encoded_label)

In [None]:
# files are laid out as ModelNet10/<class>/{train,test}/<model>.off; the test split is used for validation.
# Sorted because glob's order depends on the filesystem, which would make runs non-reproducible.
train_filepaths = sorted(glob.glob("ModelNet10/*/train/*.off"))
validation_filepaths = sorted(glob.glob("ModelNet10/*/test/*.off"))

#### Hyperparameters

In [None]:
# hyperparameters and run settings handed to model_pipeline
params = {
    "lr": 3e-4,
    "weight_decay": 0.01,
    "epochs": 50,
    "batch_size": 32,
    "save_path": "model/gnn.pt",
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

#### Setting up training

In [None]:
def model_pipeline(params, train_filepaths, validation_filepaths):
    """
    Train a GNNClassifier on ModelNet10 meshes and log the run to Weights & Biases.

    :param dict params: hyperparameters; reads "lr", "weight_decay", "epochs", "batch_size" and,
        optionally, "seed" (default 42), "embedding_size" (default 64) and "num_layers" (default 4)
    :param list train_filepaths: .off files used for training
    :param list validation_filepaths: .off files used for validation
    """
    # seed python / numpy / torch (and dataloader workers) so runs are reproducible
    pl.seed_everything(params.get("seed", 42), workers=True)

    # datasets; the validation set reuses the training set's classes and encoder so class ids agree
    train_set = ModelNet10Dataset(train_filepaths)
    val_set = ModelNet10Dataset(validation_filepaths)
    val_set.labels = train_set.labels
    val_set.label_encoder = train_set.label_encoder

    # dataloaders
    train_loader = DataLoader(train_set, batch_size=params["batch_size"], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=params["batch_size"], shuffle=False)

    # model; hyperparameters come from params instead of being hard-coded here
    model = GNNClassifier(
        in_channels=3,  # node features are the xyz vertex coordinates
        embedding_size=params.get("embedding_size", 64),
        out_classes=len(train_set.labels),
        num_layers=params.get("num_layers", 4),
        lr=params["lr"],
        weight_decay=params["weight_decay"],
    )

    # setting up logger and trainer; GPU + fp16 mixed precision only when CUDA is available,
    # so the pipeline still runs on CPU-only runtimes
    use_gpu = torch.cuda.is_available()
    wandb_logger = WandbLogger(project="modelnet10_classification")
    trainer = pl.Trainer(
        logger=wandb_logger,
        log_every_n_steps=50,
        accelerator="gpu" if use_gpu else "cpu",
        devices=-1 if use_gpu else 1,
        max_epochs=params["epochs"],
        precision=16 if use_gpu else 32,
    )

    # training
    try:
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    finally:
        # close the W&B run even if training raises
        wandb.finish()

In [None]:
# launch training with the hyperparameters and file lists defined above
model_pipeline(
    params=params,
    train_filepaths=train_filepaths,
    validation_filepaths=validation_filepaths,
)

[34m[1mwandb[0m: Currently logged in as: [33mjasiekkaczmarczyk[0m. Use [1m`wandb login --relogin`[0m to force relogin


  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit native Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type         | Params
----------------------------------------------
0 | initial_conv | GNNConvBlock | 13.7 K
1 | conv_blocks  | ModuleList   | 101 K 
2 | classifier   | Sequential   | 142 K 
3 | train_acc    | Accuracy     | 0     
4 | val_acc      | Accuracy     | 0     
5 | test_acc     | Accuracy     | 0     
6 | train_f1     | F1Score      | 0  

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█
train/acc_epoch,▁▄▃▂▃▅▆▇▇▇██
train/acc_step,▁▁▇▄▅▃▅▄▃▄▅▅▅▆▅▅▇▇▆▅█▅▇▆▆▇▇██▇▆
train/f1_epoch,▁▄▃▂▃▅▆▇▇▇██
train/f1_step,▁▁▇▄▅▃▅▄▃▄▅▅▅▆▅▅▇▇▆▅█▅▇▆▆▇▇██▇▆
train/loss_step,▅█▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▃▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/acc,▃▃▁▂▅▆▅▇▆▇██
val/f1,▃▃▁▂▅▆▅▇▆▇██
val/loss,▆█▇▅▄█▂▅▃▃▁▂

0,1
epoch,12.0
train/acc_epoch,0.77399
train/acc_step,0.6875
train/f1_epoch,0.77399
train/f1_step,0.6875
train/loss_epoch,
train/loss_step,2.0266
trainer/global_step,1549.0
val/acc,0.65308
val/f1,0.65308
