<a href="https://colab.research.google.com/github/Mechanics-Mechatronics-and-Robotics/CV-2025/blob/main/Week_10/Week_10_Cloud_of_Points.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Install Dependencies
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu124.html
!pip install -q pytorch-lightning pytorch-metric-learning plotly

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.5/54.5 kB 2.2 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done


In [None]:
import torch
import pytorch_lightning as pl
from torch_geometric.data import DataLoader
from torch_geometric.datasets import ModelNet
from torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool
from torch_geometric.transforms import SamplePoints, RandomJitter, RandomFlip, RandomShear, Compose
from pytorch_metric_learning.losses import NTXentLoss
import plotly.graph_objects as go
from typing import Tuple

In [None]:
class ModelNet10DataModule(pl.LightningDataModule):
    """Lightning data module serving ModelNet10 as sampled point clouds.

    Each mesh is turned into ``num_points`` surface points. Training samples
    are additionally jittered, flipped and sheared; validation samples are
    only point-sampled.

    Args:
        batch_size: Number of clouds per batch.
        num_points: Number of points sampled from every mesh.
        root: Directory where the ModelNet10 dataset is stored / downloaded.
        num_workers: Worker processes used by each DataLoader.
    """

    def __init__(self, batch_size=32, num_points=1024, root='./ModelNet10', num_workers=2):
        super().__init__()
        self.batch_size = batch_size
        self.num_points = num_points
        self.root = root
        self.num_workers = num_workers
        self.train_transform = Compose([
            SamplePoints(num_points),
            RandomJitter(0.03),
            RandomFlip(1),
            RandomShear(0.2)
        ])
        self.val_transform = SamplePoints(num_points)

    def prepare_data(self):
        """Download and process both splits once.

        Lightning calls this from a single process. Constructing a PyG
        ``ModelNet`` dataset triggers download/processing when the data is
        not yet present under ``root``.
        """
        ModelNet(root=self.root, name='10', train=True)
        ModelNet(root=self.root, name='10', train=False)

    def setup(self, stage=None):
        """Create the train/validation datasets with their transforms."""
        self.train_dataset = ModelNet(
            root=self.root,
            name='10',
            train=True,
            transform=self.train_transform
        )
        self.val_dataset = ModelNet(
            root=self.root,
            name='10',
            train=False,
            transform=self.val_transform
        )

    def train_dataloader(self):
        """Shuffled loader over the augmented training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Deterministically ordered loader over the test split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

In [None]:
class PointCloudContrastive(pl.LightningModule):
    """Self-supervised point-cloud encoder trained with the NT-Xent loss.

    For every batch, two independently augmented views of each cloud are
    encoded. NT-Xent pulls the two embeddings of the same cloud together and
    pushes embeddings of different clouds apart.

    Args:
        temperature: Temperature of the NT-Xent loss.
        lr: Adam learning rate.
    """

    def __init__(self, temperature=0.1, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        # PyG's EdgeConv feeds its MLP concatenated pairs of point features,
        # hence the 2 * in_channels input widths.
        self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64]), k=20, aggr='max')
        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k=20, aggr='max')
        # Per-point head over the concatenated conv1 (64) + conv2 (128) features.
        self.projection_head = MLP([128 + 64, 256, 128], norm=None)
        self.criterion = NTXentLoss(temperature=temperature)
        # Augmentation used to build the two contrastive views. Points are
        # already sampled by the data module, so SamplePoints is not repeated.
        self.augment = Compose([
            RandomJitter(0.03),
            RandomFlip(1),
            RandomShear(0.2),
        ])

    def forward(self, data) -> torch.Tensor:
        """Encode a batch of clouds into one embedding per cloud.

        Args:
            data: PyG batch with ``pos`` (xyz coordinates per point) and
                ``batch`` (index of the cloud each point belongs to).

        Returns:
            Tensor with one 128-dimensional embedding per cloud.
        """
        x1 = self.conv1(data.pos, data.batch)
        x2 = self.conv2(x1, data.batch)
        h_points = self.projection_head(torch.cat([x1, x2], dim=1))
        return global_max_pool(h_points, data.batch)

    def _augment_batch(self, batch):
        """Return a new batch in which every cloud is augmented independently."""
        from torch_geometric.data import Batch

        # to_data_list() yields fresh per-cloud Data objects, so each cloud gets
        # its own random flip/shear and the two views are independent objects.
        return Batch.from_data_list([self.augment(data) for data in batch.to_data_list()])

    def _contrastive_loss(self, z1, z2):
        """NT-Xent loss where row i of ``z1`` and row i of ``z2`` are positives."""
        embeddings = torch.cat([z1, z2], dim=0)
        # Both views of cloud i share label i; pytorch-metric-learning treats
        # same-label pairs as positives and every other pair as a negative.
        indices = torch.arange(z1.size(0), device=z1.device)
        labels = torch.cat([indices, indices], dim=0)
        return self.criterion(embeddings, labels)

    def _shared_step(self, batch):
        """Encode two augmented views of ``batch`` and return their NT-Xent loss."""
        z1 = self(self._augment_batch(batch))
        z2 = self(self._augment_batch(batch))
        return self._contrastive_loss(z1, z2)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # A PyG batch is not a plain tensor, so give Lightning the batch size.
        self.log('train_loss', loss, prog_bar=True, batch_size=batch.num_graphs)
        return loss

    def validation_step(self, batch, batch_idx):
        # Uses the same random views as training, so val_loss is somewhat noisy.
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True, batch_size=batch.num_graphs)
        return loss

    def configure_optimizers(self):
        """Adam with the learning rate halved every 20 epochs."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        return [optimizer], [scheduler]

In [None]:
def plot_point_cloud(pos: torch.Tensor, title: str = ""):
    """Show an interactive 3-D scatter plot of a point cloud.

    Args:
        pos: Tensor of point coordinates; columns 0, 1 and 2 are plotted as
            x, y and z, and the z value also drives the marker colour.
        title: Figure title.
    """
    # detach() so tensors that are part of an autograd graph can be plotted.
    pos = pos.detach().cpu().numpy()
    fig = go.Figure(data=[go.Scatter3d(
        x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
        mode='markers',
        marker=dict(size=3, opacity=0.8, color=pos[:, 2], colorscale='Viridis')
    )])
    fig.update_layout(
        title=title,
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        width=800,
        height=600
    )
    fig.show()

def visualize_augmentations(dataset, num_samples=3, augment=None):
    """Plot a few random clouds before and after an extra augmentation.

    Note that "original" means "as returned by ``dataset``", which already
    includes any transform attached to the dataset.

    Args:
        dataset: PyG dataset yielding samples with a ``pos`` attribute.
        num_samples: Number of random samples to plot.
        augment: Transform applied to each sample. Defaults to the same
            jitter/flip/shear augmentation used for contrastive training.
    """
    if augment is None:
        augment = Compose([RandomJitter(0.03), RandomFlip(1), RandomShear(0.2)])

    loader = DataLoader(dataset, batch_size=num_samples, shuffle=True)
    batch = next(iter(loader))
    # Iterate over the clouds actually present; the dataset may hold fewer
    # than num_samples items.
    samples = batch.to_data_list()

    print("Original samples:")
    for i, data in enumerate(samples):
        plot_point_cloud(data.pos, f"Original Sample {i+1}")

    print("\nAugmented samples:")
    for i, data in enumerate(samples):
        plot_point_cloud(augment(data).pos, f"Augmented Sample {i+1}")

In [None]:
if __name__ == "__main__":
    # Data pipeline and encoder.
    dm = ModelNet10DataModule(batch_size=32)
    model = PointCloudContrastive(lr=1e-3)

    # Build the datasets up front so a few samples can be inspected
    # before training starts.
    dm.setup()
    visualize_augmentations(dm.train_dataset)

    # Self-supervised pre-training on whichever accelerator is available.
    trainer = pl.Trainer(max_epochs=100, accelerator='auto', devices=1, log_every_n_steps=10)
    trainer.fit(model, dm)

    # Persist only the learned weights, not the trainer/optimizer state.
    torch.save(model.state_dict(), 'contrastive_model.pt')