In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from torch import nn
from torch.optim import Adam
from torch_geometric.nn import XConv, fps, global_mean_pool

In [5]:
class PointCNN(pl.LightningModule):
    """PointCNN point-cloud classifier: four ``XConv`` layers, global mean pooling, 3-layer MLP.

    The network returns log-probabilities (``log_softmax``) and is trained with the
    negative log-likelihood loss.

    Args:
        numfeatures: Number of per-point input features expected in ``data.x``.
            When 0, ``data.x`` is ignored and ``None`` is passed to the first
            ``XConv`` so only point positions are used.
        num_classes: Size of the final output layer (default 8, the value that
            used to be hard-coded).
        learning_rate: Learning rate for the Adam optimizer.
        mc_dropout: If True, dropout before the last linear layer stays active
            outside of training as well (the previous always-on behaviour, e.g.
            for Monte-Carlo dropout). Defaults to False so validation/test outputs
            are deterministic.
    """

    def __init__(self, numfeatures=1, num_classes=8, learning_rate=1e-3, mc_dropout=False):
        super().__init__()
        self.learning_rate = learning_rate
        self.numfeatures = numfeatures
        self.num_classes = num_classes
        self.mc_dropout = mc_dropout

        # One metric object per stage so their running states never mix.
        self.train_acc = torchmetrics.Accuracy(full_state_update=False)  # training accuracy
        self.val_acc = torchmetrics.Accuracy(full_state_update=False)  # validation accuracy
        self.test_acc = torchmetrics.Accuracy(full_state_update=False)  # test accuracy

        # Feature extractor: XConv layers on 3-D positions with growing width.
        self.conv1 = XConv(
            self.numfeatures, 48, dim=3, kernel_size=8, hidden_channels=32
        )
        self.conv2 = XConv(
            48, 96, dim=3, kernel_size=12, hidden_channels=64, dilation=2
        )
        self.conv3 = XConv(
            96, 192, dim=3, kernel_size=16, hidden_channels=128, dilation=2
        )
        self.conv4 = XConv(
            192, 384, dim=3, kernel_size=16, hidden_channels=256, dilation=2
        )

        # Classifier head: multilayer perceptron on the pooled 384-d descriptor.
        self.lin1 = nn.Linear(384, 256)
        self.lin2 = nn.Linear(256, 128)
        self.lin3 = nn.Linear(128, num_classes)

    def forward(self, data):
        """Return per-sample log-probabilities of shape ``(num_graphs, num_classes)``.

        ``data`` must provide ``pos`` and ``batch`` (and ``x`` when
        ``numfeatures`` is non-zero).
        """
        pos, batch = data.pos, data.batch
        # Use per-point features only when the model was built to expect them.
        x = data.x if self.numfeatures else None

        # First XConv on the full point set.
        x = F.relu(self.conv1(x, pos, batch))

        # Farthest point sampling, keeping 37.5% of the points of each cloud.
        idx = fps(pos, batch, ratio=0.375)
        x, pos, batch = x[idx], pos[idx], batch[idx]
        x = F.relu(self.conv2(x, pos, batch))

        # Farthest point sampling again, keeping 33.4% of the remaining points.
        idx = fps(pos, batch, ratio=0.334)
        x, pos, batch = x[idx], pos[idx], batch[idx]

        # Two additional XConvs on the down-sampled clouds.
        x = F.relu(self.conv3(x, pos, batch))
        x = F.relu(self.conv4(x, pos, batch))

        # Pool all points of each batch element into one 384-feature vector.
        x = global_mean_pool(x, batch)

        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))

        # Dropout only while training, unless Monte-Carlo dropout was requested;
        # always-on dropout made validation/test outputs random.
        x = F.dropout(x, p=0.5, training=self.training or self.mc_dropout)
        x = self.lin3(x)

        # log-softmax so the outputs can be fed directly to the NLL loss.
        return F.log_softmax(x, dim=-1)

    def _shared_step(self, data):
        """Run the model on ``data`` and return ``(log_probs, targets, nll_loss)``."""
        out = self(data)
        y = data.y
        return out, y, F.nll_loss(out, y)

    def training_step(self, data, batch_idx):
        """Compute and log training loss/accuracy for one batch; return the loss."""
        out, y, loss = self._shared_step(data)
        # Real number of samples in this batch (the last batch may be smaller).
        batch_size = data.num_graphs
        self.train_acc(out, y)
        self.log(
            "train_acc", self.train_acc, on_step=True, on_epoch=True, batch_size=batch_size
        )
        self.log("train_loss", loss, batch_size=batch_size)
        return loss

    def validation_step(self, data, batch_idx):
        """Compute and log validation loss/accuracy for one batch; return the loss."""
        out, y, val_loss = self._shared_step(data)
        batch_size = data.num_graphs
        self.val_acc(out, y)
        self.log(
            "val_acc", self.val_acc, on_step=True, on_epoch=True, batch_size=batch_size
        )
        self.log("val_loss", val_loss, batch_size=batch_size)
        return val_loss

    def test_step(self, data, batch_idx):
        """Log test loss/accuracy and return outputs for ``test_epoch_end``.

        Returns a dict with the loss, the model's log-probabilities (under the
        key ``"logits"``) and the target labels.
        """
        out, y, test_loss = self._shared_step(data)
        batch_size = data.num_graphs
        self.test_acc(out, y)
        self.log(
            "test_acc", self.test_acc, on_step=False, on_epoch=True, batch_size=batch_size
        )
        self.log("test_loss", test_loss, batch_size=batch_size)
        return {"test_loss": test_loss, "logits": out, "labels": y}

    def test_step_end(self, outs):
        """Pass ``test_step`` outputs through unchanged."""
        return outs

    def test_epoch_end(self, outs):
        """Append all test predictions and labels to module-level lists.

        Rows of log-probabilities are appended to the global ``all_preds`` and
        flattened labels to the global ``all_labels``. Both lists are created if
        they do not exist yet. If they already exist, they are extended rather
        than replaced.

        NOTE(review): assumes a single test dataloader, so ``outs`` is a flat
        list of ``test_step`` dicts.
        """
        namespace = globals()
        all_preds = namespace.setdefault("all_preds", [])
        all_labels = namespace.setdefault("all_labels", [])
        for out in outs:
            all_preds.extend(list(out["logits"].detach().cpu().numpy()))
            all_labels.extend(list(out["labels"].flatten().detach().cpu().numpy()))

    # Backward-compatible alias for the previous (misspelled) hook name.
    test_epoch_ends = test_epoch_end

    def configure_optimizers(self):
        """Use Adam over all parameters with ``self.learning_rate``."""
        return Adam(self.parameters(), lr=self.learning_rate)