In [None]:
!pip install torch_geometric
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
#!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
#!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cpu.html
#!pip install open3d
!pip install plotly
!pip install wandb

In [None]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import EdgeConv, global_max_pool
from torch_geometric.nn import knn_graph

class EdgeConvBlock(nn.Module):
    """A single EdgeConv layer whose edge function is a three-layer MLP.

    The MLP takes ``2 * in_channels`` inputs per edge (EdgeConv concatenates a
    node's features with the difference to its neighbour). It applies two
    64-wide hidden layers, each followed by LeakyReLU and LayerNorm, and then
    projects to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden = 64
        edge_mlp = nn.Sequential(
            nn.Linear(2 * in_channels, hidden), nn.LeakyReLU(), nn.LayerNorm(hidden),
            nn.Linear(hidden, hidden), nn.LeakyReLU(), nn.LayerNorm(hidden),
            nn.Linear(hidden, out_channels),
        )
        self.edgeconv = EdgeConv(edge_mlp)

    def forward(self, x, edge_index):
        """Run the EdgeConv layer on node features ``x`` over the graph ``edge_index``."""
        out = self.edgeconv(x, edge_index)
        return out

class DGCNN_seg(nn.Module):
    """DGCNN-style point-cloud segmentation network built from three EdgeConv blocks.

    Each ``EdgeConvBlock`` runs on a k-NN graph rebuilt in the current feature space
    (a dynamic graph). The per-point features of the three blocks (3 x 64 = 192
    channels) are concatenated, projected to 1024 channels and max-pooled per cloud.
    The pooled vector, plus an optional 64-d embedding of ``labels``, is broadcast
    back to every point. It is then classified together with the per-point features.

    Args:
        n_points: points per cloud. ``labels`` are reshaped to ``(num_clouds, n_points)``
            for the label embedding, so every cloud must have exactly this many points
            when ``labels`` is given.
        in_channels: size of the per-point input features.
        n_classes: number of per-point output classes.
        k: neighbours per point in each k-NN graph.
    """

    def __init__(self, n_points, in_channels, n_classes, k=40):
        super(DGCNN_seg, self).__init__()
        self.n_classes = n_classes
        self.n_points = n_points
        self.k = k

        self.edge_conv1 = EdgeConvBlock(in_channels, 64)
        self.edge_conv2 = EdgeConvBlock(64, 64)
        self.edge_conv3 = EdgeConvBlock(64, 64)

        self.fc_agg = nn.Linear(192, 1024)

        # Embeds the n_points label values of one cloud into a 64-d vector.
        self.fc_lb_emb = nn.Sequential(
            nn.Linear(n_points, 64),
            nn.LeakyReLU(),
            nn.LayerNorm(64)
        )

        # Input: 192 local channels + 1024 pooled + 64 label embedding = 1280.
        self.classifier = nn.Sequential(
            nn.Linear(1280, 256),
            nn.LeakyReLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.5),
            nn.Linear(256, 256),
            nn.LeakyReLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.LayerNorm(128),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )

    def forward(self, x, labels: torch.Tensor = None, batch: torch.Tensor = None, batch_size: int = None):
        """Return per-point class logits of shape ``(num_points_total, n_classes)``.

        Args:
            x: per-point input features, ``(num_points_total, in_channels)``.
            labels: optional per-point values, ``n_points`` per cloud, embedded and
                appended to the pooled global feature. ``None`` uses a zero embedding.
                NOTE(review): the training loop in this notebook passes the segmentation
                targets here, exposing ground truth to the model; confirm this is intended.
            batch: cloud index of every point (PyG batch vector). ``None`` means a single cloud.
            batch_size: unused; kept for call compatibility.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Dynamic graph: each block gets a k-NN graph built on the previous block's output.
        edge_index = knn_graph(x, k=self.k, batch=batch)
        x1 = self.edge_conv1(x, edge_index)

        edge_index = knn_graph(x1, k=self.k, batch=batch)
        x2 = self.edge_conv2(x1, edge_index)

        edge_index = knn_graph(x2, k=self.k, batch=batch)
        x3 = self.edge_conv3(x2, edge_index)

        local_feats = torch.cat([x1, x2, x3], dim=1)

        # One global descriptor per cloud: project the multi-scale features, then max-pool.
        global_feats = global_max_pool(self.fc_agg(local_feats), batch)

        num_clouds = global_feats.size(0)
        if labels is not None:
            lb_emb = self.fc_lb_emb(labels.view(num_clouds, -1).float())
        else:
            lb_emb = global_feats.new_zeros(num_clouds, 64)
        global_feats = torch.cat([global_feats, lb_emb], dim=1)

        # Broadcast each cloud's global vector to its own points through the batch vector.
        # This equals repeat_interleave(n_points) when every cloud has n_points points, and
        # it stays aligned (instead of silently mixing clouds) when cloud sizes differ.
        global_feats = global_feats[batch]

        return self.classifier(torch.cat([local_feats, global_feats], dim=1))

In [None]:
import torch
from torch import nn
from torch_geometric.nn import MessagePassing, knn_graph, global_max_pool
from torch_geometric.utils import degree


class EGCNNBlock(MessagePassing):
    """EGNN layer from https://arxiv.org/pdf/2102.09844.pdf

    For every edge (j -> i) in ``edge_index``:

        m_ij = phi_e([h_i, h_j, ||x_i - x_j||^2])

    The messages are aggregated with ``aggr`` (default: sum) at node i, then:

        x_i' = x_i + (aggr_j (x_i - x_j) * phi_x(m_ij)) / c
        h_i' = phi_h([h_i, aggr_j m_ij])

    Args:
        in_channels: size of the node features ``h``.
        hidden_channels: width of the message MLPs and of the aggregated feature message.
        out_channels: size of the returned node features.
        aggr: PyG aggregation scheme.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int = 64,
        out_channels: int = 64,
        aggr: str = 'add'
    ):
        super(EGCNNBlock, self).__init__(aggr=aggr)

        # Edge/message network: [h_i, h_j, squared distance] -> hidden message.
        self.phi_e = nn.Sequential(
                nn.Linear(2 * in_channels + 1, hidden_channels),
                nn.LayerNorm(hidden_channels),
                nn.SiLU(),
                nn.Linear(hidden_channels, hidden_channels),
                nn.LayerNorm(hidden_channels),
                nn.SiLU()
        )

        # Scalar weight applied to the relative position (x_i - x_j) of each edge.
        self.phi_x = nn.Sequential(
                nn.Linear(hidden_channels, hidden_channels),
                nn.LayerNorm(hidden_channels),
                nn.SiLU(),
                nn.Linear(hidden_channels, 1)
        )

        # Node update network: [h_i, aggregated message] -> new features.
        self.phi_h = nn.Sequential(
            nn.Linear(in_channels + hidden_channels, hidden_channels),
            nn.LayerNorm(hidden_channels),
            nn.SiLU(),
            nn.Linear(hidden_channels, out_channels)
        )

    def forward(self, x, h, edge_index, c=None):
        """Return updated ``(x, h)``.

        Args:
            x: node coordinates, ``(N, D)``.
            h: node features, ``(N, in_channels)``.
            edge_index: graph connectivity.
            c: divisor for the aggregated coordinate update (a scalar or an ``(N, 1)``
                tensor). When ``None``, each node's number of incoming edges is used
                (clamped to >= 1), so the coordinate update becomes a neighbour mean.
        """
        if c is None:
            # Messages are aggregated at edge_index[1] for 'source_to_target' flow,
            # and at edge_index[0] otherwise.
            target = edge_index[1] if self.flow == 'source_to_target' else edge_index[0]
            c = degree(target, x.size(0), dtype=x.dtype).clamp(min=1).unsqueeze(-1)
        return self.propagate(edge_index=edge_index, x=x, h=h, c=c)

    def message(self, x_i, x_j, h_i, h_j):
        """Build the per-edge message: coordinate part first, then the feature part."""
        diff = x_i - x_j
        # Squared distance computed directly rather than as norm(...) ** 2.
        sq_dist = diff.pow(2).sum(dim=-1, keepdim=True)
        mh_ij = self.phi_e(torch.cat([h_i, h_j, sq_dist], dim=-1))
        mx_ij = diff * self.phi_x(mh_ij)
        return torch.cat([mx_ij, mh_ij], dim=-1)

    def update(self, aggr_out, x, h, c):
        """Split the aggregated message into coordinate / feature parts and apply the updates."""
        # The first D columns are the coordinate message (D = coordinate dimension).
        pos_dim = x.size(-1)
        m_x, m_h = aggr_out[:, :pos_dim], aggr_out[:, pos_dim:]
        h_l1 = self.phi_h(torch.cat([h, m_h], dim=-1))
        x_l1 = x + (m_x / c)
        return x_l1, h_l1

class EDGCNN_seg(nn.Module):
    """Point-cloud segmentation network built from three EGNN-style ``EGCNNBlock`` layers.

    Each block updates both the point coordinates and a 64-d feature vector. The k-NN
    graph for the next block is rebuilt on the updated coordinates. The coordinates and
    features of all three blocks (3 x (3 + 64) = 201 channels for 3-D points) are
    concatenated per point, projected to 1024 channels and max-pooled per cloud. The
    pooled vector, plus an optional 64-d embedding of ``labels``, is broadcast back to
    every point. It is then classified together with the per-point features.

    Args:
        n_points: points per cloud. ``labels`` are reshaped to ``(num_clouds, n_points)``
            for the label embedding, so every cloud must have exactly this many points
            when ``labels`` is given.
        in_channels: size of the initial node features. The input positions are also
            used as the initial features, so this must equal the coordinate size.
        n_classes: number of per-point output classes.
        k: neighbours per point in each k-NN graph.
    """

    def __init__(self, n_points, in_channels, n_classes, k=40):
        super(EDGCNN_seg, self).__init__()
        self.n_classes = n_classes
        self.n_points = n_points
        self.k = k

        self.egcnn1 = EGCNNBlock(in_channels, 64)
        self.egcnn2 = EGCNNBlock(64, 64)
        self.egcnn3 = EGCNNBlock(64, 64)

        # 3 x 64 feature channels + 3 x 3 coordinate channels.
        self.fc_agg = nn.Linear(192 + 9, 1024)

        # Embeds the n_points label values of one cloud into a 64-d vector.
        self.fc_lb_emb = nn.Sequential(
            nn.Linear(n_points, 64),
            nn.LayerNorm(64),
            nn.LeakyReLU()
        )

        # Input: 201 local channels + 1024 pooled + 64 label embedding = 1289.
        self.classifier = nn.Sequential(
            nn.Linear(1280 + 9, 256),
            nn.LayerNorm(256),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 256),
            nn.LayerNorm(256),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.LeakyReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )

    @staticmethod
    def _neighbor_count(edge_index, x):
        """Number of incoming edges per node as an ``(N, 1)`` tensor (at least 1), in ``x``'s dtype.

        ``EGCNNBlock`` divides the summed coordinate messages by this value, which turns
        the coordinate update into a mean over each node's neighbours. ``edge_index[1]``
        holds the aggregation targets under PyG's default 'source_to_target' flow.
        """
        counts = torch.bincount(edge_index[1], minlength=x.size(0))
        return counts.clamp(min=1).to(x.dtype).unsqueeze(-1)

    def forward(self, x, labels: torch.Tensor = None, batch: torch.Tensor = None, batch_size: int = None):
        """Return per-point class logits of shape ``(num_points_total, n_classes)``.

        Args:
            x: point coordinates, ``(num_points_total, in_channels)``.
            labels: optional per-point values, ``n_points`` per cloud, embedded and
                appended to the pooled global feature. ``None`` uses a zero embedding.
                NOTE(review): the training loop in this notebook passes the segmentation
                targets here, exposing ground truth to the model; confirm this is intended.
            batch: cloud index of every point (PyG batch vector). ``None`` means a single cloud.
            batch_size: unused; kept for call compatibility. The coordinate-update
                normaliser is the per-node neighbour count, not a function of the number
                of clouds, so a batch holding a single cloud is valid.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Each layer's k-NN graph is built on the coordinates produced by the previous layer.
        edge_index = knn_graph(x, k=self.k, batch=batch)
        m1, h1 = self.egcnn1(x, x, edge_index, self._neighbor_count(edge_index, x))

        edge_index = knn_graph(m1, k=self.k, batch=batch)
        m2, h2 = self.egcnn2(m1, h1, edge_index, self._neighbor_count(edge_index, m1))

        edge_index = knn_graph(m2, k=self.k, batch=batch)
        m3, h3 = self.egcnn3(m2, h2, edge_index, self._neighbor_count(edge_index, m2))

        local_feats = torch.cat([m1, h1, m2, h2, m3, h3], dim=1)

        # One global descriptor per cloud: project the multi-scale features, then max-pool.
        global_feats = global_max_pool(self.fc_agg(local_feats), batch)

        num_clouds = global_feats.size(0)
        if labels is not None:
            lb_emb = self.fc_lb_emb(labels.view(num_clouds, -1).float())
        else:
            lb_emb = global_feats.new_zeros(num_clouds, 64)
        global_feats = torch.cat([global_feats, lb_emb], dim=1)

        # Broadcast each cloud's global vector to its own points through the batch vector.
        # This equals repeat_interleave(n_points) when every cloud has n_points points, and
        # it stays aligned when cloud sizes differ.
        global_feats = global_feats[batch]

        return self.classifier(torch.cat([local_feats, global_feats], dim=1))


In [None]:
import torch
from torch_geometric.data import Data
import plotly.graph_objects as go
import plotly.express as px
import numpy as np


def pointcloud(ptc: Data):
    """Show an interactive 3-D scatter plot of a point cloud, coloured by per-point label.

    Args:
        ptc: a ``Data`` object with ``pos`` (columns 0-2 are plotted as x, y, z),
            ``y`` (one label per point, also shown on hover) and ``category``
            (a single-element tensor shown in the title).
    """
    # .detach().cpu() so that tensors on the GPU or requiring grad (e.g. predictions) also work.
    points = ptc.pos.detach().cpu().numpy()
    labels = ptc.y.detach().cpu().numpy()

    # One colour per distinct label, cycling through the qualitative palette if needed.
    palette = px.colors.qualitative.Plotly
    color_map = {label: palette[i % len(palette)] for i, label in enumerate(np.unique(labels))}
    point_colors = [color_map[label] for label in labels]

    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=point_colors,
            opacity=0.8
        ),
        text=[f'Label: {label}' for label in labels],
        hoverinfo='text'
    )])

    fig.update_layout(
        title=f"Pointcloud's category: {ptc.category.item()}",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        width=800,
        height=800,
    )

    fig.show()

In [None]:
from torch_geometric.nn import fps
import torch.nn.functional as F

class FPSSampler:
    """Dataset transform that resamples every point cloud to exactly ``num_points`` points.

    - Clouds with more points are reduced with farthest point sampling (FPS).
    - Clouds with fewer points are padded by re-sampling existing points at random
      (with replacement), since FPS cannot upsample.

    A fixed size matters because the segmentation models in this notebook reshape
    per-point labels to ``(num_clouds, n_points)``.

    Args:
        num_points: target number of points per cloud.
        num_classes: number of classes used for the ``y_onehot`` encoding.
    """

    def __init__(self, num_points,  num_classes):
        self.num_points = num_points
        self.num_classes = num_classes

    def __call__(self, data):
        num_nodes = data.pos.size(0)
        if num_nodes > self.num_points:
            index = fps(data.pos, ratio=self.num_points / num_nodes)
            # fps keeps about ceil(ratio * N) points; float rounding can give one extra.
            # Its output follows sampling order, so a prefix is still a valid FPS subset.
            index = index[:self.num_points]
        else:
            keep = torch.arange(num_nodes, device=data.pos.device)
            extra = torch.randint(num_nodes, (self.num_points - num_nodes,), device=data.pos.device)
            index = torch.cat([keep, extra])

        y = data.y[index]
        return Data(
            x=data.x[index] if data.x is not None else None,
            y=y,
            y_onehot=F.one_hot(y, self.num_classes).float(),
            pos=data.pos[index],
            category=data.category,
        )

In [None]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.transforms import Compose

# Resample every cloud to 2048 points; 50 is the one-hot width, matching the model's n_classes=50.
# Passing the transform to the constructor avoids misspelled attribute assignments
# (previously `sp_test.transdofrm = ...` silently left the test set untransformed).
fps_transform = Compose([FPSSampler(2048, 50)])

sp_train = ShapeNet('data', split='train', transform=fps_transform)
sp_test = ShapeNet('data', split='test', transform=fps_transform)

In [None]:
from torch_geometric.data import InMemoryDataset

def train_eval_split(dataset: "InMemoryDataset", eval_size: float, seed: "int | None" = None):
    """Randomly split ``dataset`` into a training and an evaluation subset.

    Args:
        dataset: anything supporting ``len()`` and indexing with a LongTensor of
            indices (e.g. a PyG ``InMemoryDataset``).
        eval_size: fraction of samples in the evaluation subset, in ``[0, 1]``.
            Exactly ``round(len(dataset) * eval_size)`` samples are taken.
        seed: optional seed for a private ``torch.Generator``. This makes the split
            reproducible without touching the global RNG. ``None`` uses the global RNG.

    Returns:
        ``(train_subset, eval_subset)``. Each subset keeps the dataset's original order.

    Raises:
        ValueError: if ``eval_size`` is outside ``[0, 1]``.
    """
    if not 0.0 <= eval_size <= 1.0:
        raise ValueError(f"eval_size must be in [0, 1], got {eval_size}")

    num_samples = len(dataset)
    num_eval = int(round(num_samples * eval_size))

    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    perm = torch.randperm(num_samples, generator=generator)

    # Sort so both subsets preserve the original sample order.
    eval_idx = perm[:num_eval].sort().values
    train_idx = perm[num_eval:].sort().values
    return dataset[train_idx], dataset[eval_idx]

In [None]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb
import os
from tqdm import tqdm

def evaluate(model: nn.Module, eval_set: "DataLoader", criterion: nn.CrossEntropyLoss, device):
    """Return the mean per-point loss of ``model`` over ``eval_set`` (no labels are fed).

    Assumes ``criterion`` uses ``reduction='mean'``, the default for
    ``nn.CrossEntropyLoss``. Each batch's loss is re-weighted by its number of points,
    so a smaller final batch does not skew the average.

    Args:
        model: called as ``model(pos, batch=..., batch_size=...)``; switched to eval mode.
        eval_set: iterable of batches exposing ``pos``, ``y``, ``batch`` and ``batch_size``.
        criterion: loss applied to ``(outputs, y)``.
        device: device the inputs are moved to.

    Returns:
        The average loss, or ``nan`` when ``eval_set`` yields no points.
    """
    model.eval()
    total_loss = 0.0
    total_points = 0

    with torch.no_grad():
        for datapt in eval_set:
            inputs, labels = datapt.pos.to(device), datapt.y.to(device)
            outputs = model(
                inputs,
                batch=datapt.batch.to(device),
                batch_size=datapt.batch_size
            )
            num_points = labels.numel()
            total_loss += criterion(outputs, labels).item() * num_points
            total_points += num_points

    if total_points == 0:
        return float('nan')
    return total_loss / total_points

def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, epoch: int, val_loss: float, config):
    """Save a checkpoint every ``save_every`` epochs and prune the oldest ones.

    Writes ``<checkpoint_dir>/checkpoint-<epoch+1>.pth`` containing the epoch, model and
    optimizer state dicts and ``val_loss``. Afterwards it keeps only the newest
    ``max_checkpoints`` files. The limit falls back to the ``save_max`` key, and all
    files are kept when neither key is set. The file just written is never deleted.
    Files in the directory not named ``checkpoint-<int>.pth`` are ignored.

    Args:
        model: model whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved.
        epoch: zero-based epoch index.
        val_loss: validation loss stored alongside the weights.
        config: dict whose ``'hyper'`` entry provides ``save_every``, ``checkpoint_dir``
            and optionally ``max_checkpoints`` / ``save_max``.
    """
    hyper = config['hyper']
    if (epoch + 1) % hyper['save_every'] != 0:
        return

    checkpoint_dir = hyper['checkpoint_dir']
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
    }
    torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint-{epoch+1}.pth'))

    max_checkpoints = hyper.get('max_checkpoints', hyper.get('save_max'))
    if max_checkpoints is None:
        return

    prefix, suffix = 'checkpoint-', '.pth'
    saved = []
    for name in os.listdir(checkpoint_dir):
        if name.startswith(prefix) and name.endswith(suffix):
            stem = name[len(prefix):-len(suffix)]
            if stem.isdigit():
                saved.append((int(stem), name))
    saved.sort()

    # Keep at least one file (the newest) even if the configured limit is < 1.
    num_to_remove = max(len(saved) - max(max_checkpoints, 1), 0)
    for _, name in saved[:num_to_remove]:
        os.remove(os.path.join(checkpoint_dir, name))



def _build_optimizer(model: nn.Module, hyper: dict):
    """Create the optimizer named by ``hyper['optimizer']`` ('SGD' or 'Adam'; default 'SGD').

    ``hyper['optimizer_kwargs']`` is forwarded to the constructor. When it is absent,
    SGD gets ``momentum=0.9`` and Adam gets no extra arguments.
    """
    optimizers = {'SGD': SGD, 'Adam': Adam}
    name = hyper.get('optimizer', 'SGD')
    if name not in optimizers:
        raise ValueError(f"Unsupported optimizer {name!r}; expected one of {sorted(optimizers)}")
    kwargs = hyper.get('optimizer_kwargs')
    if kwargs is None:
        kwargs = {'momentum': 0.9} if name == 'SGD' else {}
    return optimizers[name](model.parameters(), lr=hyper['learning_rate'], **kwargs)


def _train_one_epoch(model: nn.Module, train_set: DataLoader, optimizer, criterion, epoch: int, hyper: dict) -> float:
    """Run one optimisation pass over ``train_set`` and return its mean per-batch loss.

    Losses are scaled by ``1 / gradient_accumulation_steps`` before ``backward()``. The
    optimizer steps every that many batches, plus once for a trailing partial group.
    Every ``logging_steps`` batches the unscaled batch loss is logged to W&B.
    Returns ``nan`` if ``train_set`` yields no batches.
    """
    device = hyper['device']
    accum_steps = hyper['gradient_accumulation_steps']

    model.train()
    running_loss = 0.0
    num_batches = 0

    for i, datapt in enumerate(tqdm(train_set, desc=f"Epoch {epoch+1}/{hyper['epochs']}")):
        use_labels = torch.rand(1).item() < hyper['use_labels_prob']
        targets = datapt.y.to(device)

        # NOTE(review): `labels` receives the same targets the loss is computed on. With
        # probability `use_labels_prob` the model therefore sees ground truth as input,
        # while evaluate() never passes labels. Confirm this conditioning is intended.
        y_hat = model(
            datapt.pos.to(device),
            targets if use_labels else None,
            datapt.batch.to(device),
            datapt.batch_size
        )

        batch_loss = criterion(y_hat, targets)
        (batch_loss / accum_steps).backward()

        if (i + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        loss_value = batch_loss.item()
        running_loss += loss_value
        num_batches += 1

        if i % hyper['logging_steps'] == 0:
            wandb.log({
                "train_loss": loss_value
            })

    # Apply gradients from a trailing group smaller than accum_steps.
    if num_batches % accum_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    return running_loss / num_batches if num_batches else float('nan')


def train(model: nn.Module, train_set: DataLoader, eval_set: DataLoader, config):
    """Train ``model`` with per-epoch validation, W&B logging and checkpointing.

    Args:
        model: network called as ``model(pos, labels, batch, batch_size)`` that returns
            per-point logits.
        train_set: loader of training batches (``pos``, ``y``, ``batch``, ``batch_size``).
        eval_set: loader of validation batches, passed to ``evaluate``.
        config: ``config['init']`` gives the W&B ``project`` and ``run_name``.
            ``config['hyper']`` holds the hyper-parameters and is logged as the run
            config. The optional keys ``optimizer`` / ``optimizer_kwargs`` select the
            optimizer (default: SGD with momentum 0.9).

    The W&B run is always finished. It is marked failed (exit code 1) if training
    raises or is interrupted.
    """
    hyper = config['hyper']
    device = hyper['device']

    wandb.init(
        project=config['init']['project'],
        name=config['init']['run_name'],
        config=hyper
    )

    try:
        model.to(device)
        optimizer = _build_optimizer(model, hyper)
        # Drop any gradients left on the parameters by an earlier, interrupted run.
        optimizer.zero_grad()
        criterion = nn.CrossEntropyLoss()
        scheduler = CosineAnnealingLR(optimizer, hyper['epochs'], eta_min=0.001)

        for epoch in range(hyper['epochs']):
            avg_train_loss = _train_one_epoch(model, train_set, optimizer, criterion, epoch, hyper)

            val_loss = evaluate(model, eval_set, criterion, device)

            wandb.log({
                "epoch": epoch + 1,
                "val_loss": val_loss,
                "learning_rate": optimizer.param_groups[0]['lr']
            })

            scheduler.step()

            print(f"Epoch {epoch+1}/{hyper['epochs']}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}")

            save_checkpoint(model, optimizer, epoch, val_loss, config)
    except BaseException:
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

In [None]:
conf = {
    'init': {
        'project': 'DGCNN',
        'run_name': 'edgcnn-seg-v0.0.4'
    },
    'hyper': {
        'epochs': 3,
        'learning_rate': 0.01,
        'batch_training_size': 18,
        'batch_eval_size': 18,
        'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
        'gradient_accumulation_steps': 1,
        'logging_steps': 5,
        'checkpoint_dir': 'checkpoints',
        'save_every': 1,
        # Number of most recent checkpoints kept on disk (the key save_checkpoint reads).
        'max_checkpoints': 3,
        'dataset': 'ShapeNet',
        'optimizer': 'SGD',
        'optimizer_kwargs': {'momentum': 0.9},
        'scheduler': 'CosineAnnealingLR',
        'use_labels_prob': 0.7,
        'eval_size': 0.1,
        'seed': 42
    }
}

# Seed torch's global RNG so the train/eval split (and the later weight init) is reproducible.
torch.manual_seed(conf['hyper']['seed'])

train_set, eval_set = train_eval_split(sp_train, conf['hyper']['eval_size'])

In [None]:
# 3-D input coordinates, 2048 points per cloud (matches FPSSampler), 50 part classes.
model = EDGCNN_seg(in_channels=3, n_points=2048, n_classes=50)

In [None]:
train(
    model=model,
    train_set=DataLoader(train_set, batch_size=conf['hyper']['batch_training_size'], shuffle=True),
    # No shuffling for evaluation: a fixed batch composition keeps val_loss deterministic
    # for a given model, so epochs are directly comparable.
    eval_set=DataLoader(eval_set, batch_size=conf['hyper']['batch_eval_size'], shuffle=False),
    config=conf
)


In [None]:
# Safety net: close any W&B run that is still open, e.g. after an interrupted training cell.
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train_loss,▇▅▄▄▂█▄▆▁▇▁

0,1
train_loss,2.61737
