In [None]:
# Record the installed torch version in the TORCH environment variable and
# print it. NOTE(review): nothing is installed by this cell despite the
# original "Install required packages" heading.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# Helper functions for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_mesh(pos, face):
    """Draw a triangle mesh as a 3D surface and show the figure.

    Args:
        pos: Indexed as ``pos[:, i]`` for ``i`` in 0..2; the three columns
            are used as the x, y and z coordinates.
        face: Object exposing ``.t()``; its transpose is passed to
            ``plot_trisurf`` as ``triangles``.
    """
    figure = plt.figure()
    axes3d = figure.add_subplot(projection='3d')
    # Blank out the tick labels on every axis.
    for axis in (axes3d.axes.xaxis, axes3d.axes.yaxis, axes3d.axes.zaxis):  # type: ignore
        axis.set_ticklabels([])
    xs, ys, zs = pos[:, 0], pos[:, 1], pos[:, 2]
    axes3d.plot_trisurf(xs, ys, zs, triangles=face.t(), antialiased=False)
    plt.show()


def visualize_points(pos, c=None):
    """Scatter-plot a point set in 3D on a 4x4 figure and show it.

    Args:
        pos: Indexed as ``pos[:, i]`` for ``i`` in 0..2; the three columns
            are used as the x, y and z coordinates.
        c: Colour specification forwarded to ``scatter``. When ``None`` the
            points are drawn in blue.
    """
    figure = plt.figure(figsize=(4, 4))
    axes3d = figure.add_subplot(projection='3d')
    # Blank out the tick labels on every axis.
    for axis in (axes3d.axes.xaxis, axes3d.axes.yaxis, axes3d.axes.zaxis):  # type: ignore
        axis.set_ticklabels([])
    if c is None:
        colour = 'blue'
    else:
        colour = c
    xs, ys, zs = pos[:, 0], pos[:, 1], pos[:, 2]
    axes3d.scatter(xs, ys, zs, c=colour, s=3)
    plt.show()



In [None]:
import torch
from torch import nn, Tensor
import numpy as np
from torch_geometric.nn import knn
from torch_scatter import scatter
from torch_geometric.nn import Linear, MLP, PointTransformerConv, fps, knn, knn_graph, knn_interpolate
from torch_geometric.utils import scatter
import torch.nn.functional as F
from itertools import pairwise
import typing as ty
from torch_geometric.typing import OptTensor, Adj


class TransformerBlock(nn.Module):
    """Point-transformer layer wrapped by input/output MLPs, conditioned on time.

    Pipeline: ``input_fc`` -> ``PointTransformerConv`` -> add the projected
    time embedding -> ``output_fc``.

    Args:
        in_channels: Width of the incoming point features.
        out_channels: Width of the features returned by ``forward``.
        time_channels: Width of the time embedding ``t`` passed to ``forward``.
        hidden_channels: Output width of the point-transformer convolution;
            falls back to ``out_channels`` when ``None`` (or 0).
        last_act: Activation name handed to the output MLP.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, hidden_channels: int|None = None, last_act: str = 'silu'):
        super().__init__()
        hidden_channels = hidden_channels or out_channels

        self.input_fc = MLP([in_channels, in_channels], act='silu')
        # PointTransformerConv requires pos_nn to map 3 -> conv output width
        # and attn_nn to map conv output width -> conv output width. The conv
        # output width here is hidden_channels, so both helper networks are
        # sized from it (the previous sizing only worked when
        # in_channels == hidden_channels == out_channels).
        self.point_transformer = PointTransformerConv(
            in_channels, hidden_channels,
            pos_nn=MLP([3, 64, hidden_channels], norm=None, plain_last=False),
            attn_nn=MLP([hidden_channels, 64, hidden_channels], norm=None, plain_last=False)
        )

        self.time_fc = MLP([time_channels, hidden_channels], act='silu')
        self.output_fc = MLP([hidden_channels, out_channels], act=last_act)

    def forward(self, x: Tensor, t: Tensor, pos: Tensor, edge_index):
        """Apply the block.

        Args:
            x: Point features of width ``in_channels``.
            t: Time embedding of width ``time_channels``. Assumed to hold one
                row per point (the callers in this notebook pass
                ``t[batch]``) or a single row that broadcasts over all
                points -- TODO confirm against the embedding module.
            pos: Point coordinates given to the convolution.
            edge_index: Graph connectivity given to the convolution.

        Returns:
            Point features of width ``out_channels``.
        """
        x = self.input_fc(x)
        x = self.point_transformer(x, pos, edge_index)
        # Inject the time conditioning. Previously ``time_fc`` and ``t`` were
        # never used, so the block's output did not depend on ``t`` at all.
        t_emb = self.time_fc(t)
        # Collapse any extra leading axes so the sum keeps ``x`` two-dimensional.
        t_emb = t_emb.reshape(-1, t_emb.size(-1))
        x = x + t_emb
        x = self.output_fc(x)
        return x


class TransitionDown(nn.Module):
    """Downsample a point set and max-pool neighbour features onto the kept points.

    Args:
        in_channels: Width of the incoming point features.
        out_channels: Width of the pooled features.
        ratio: Sampling ratio handed to ``fps``.
        k: Number of neighbours requested from ``knn`` for each kept point.
        act: Activation name for the feature MLP.
    """

    def __init__(self, in_channels: int, out_channels: int, ratio=0.5, k=16, act='silu'):
        super().__init__()
        self.k = k
        self.ratio = ratio
        self.mlp = MLP([in_channels, out_channels], plain_last=False, act=act)

    def forward(self, x: Tensor, pos: Tensor, batch):
        """Return ``(pooled_features, kept_positions, kept_batch)``.

        ``kept_batch`` is ``None`` when ``batch`` is ``None``.
        """
        # Indices of the points that survive the sampling step.
        centre_idx = fps(pos, ratio=self.ratio, batch=batch)
        sub_batch = None if batch is None else batch[centre_idx]
        sub_pos = pos[centre_idx]

        # Row 0 is used as the index of a kept point, row 1 as the index of
        # one of its neighbours among the original points.
        pairs = knn(pos, sub_pos, k=self.k, batch_x=batch, batch_y=sub_batch)
        centre_of_pair = pairs[0]
        point_of_pair = pairs[1]

        features = self.mlp(x)
        # Max over every neighbour assigned to the same kept point.
        pooled = scatter(
            features[point_of_pair],
            centre_of_pair,
            dim=0,
            dim_size=centre_idx.size(0),
            reduce='max',
        )
        return pooled, sub_pos, sub_batch



class TransitionUp(nn.Module):
    """Upsample coarse features onto a finer point set and add a skip path.

    Args:
        in_channels: Width of the coarse features ``x_sub``.
        out_channels: Width of the fine features ``x`` and of the result.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp_sub = MLP([in_channels, out_channels], plain_last=False)
        self.mlp = MLP([out_channels, out_channels], plain_last=False)

    def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):
        """Return ``mlp(x)`` plus ``x_sub`` interpolated from ``pos_sub`` to ``pos``."""
        coarse = self.mlp_sub(x_sub)
        # Interpolate the coarse features at the fine positions (k=3 neighbours).
        upsampled = knn_interpolate(
            coarse, pos_sub, pos, k=3, batch_x=batch_sub, batch_y=batch
        )
        return self.mlp(x) + upsampled


In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:

class DownBlock(nn.Module):
    """Encoder stage: ``TransitionDown`` followed by a ``TransformerBlock``.

    Args:
        in_channels: Width of the incoming point features.
        out_channels: Width of the features after this stage.
        time_channels: Width of the time embedding ``t``.
        ratio: Sampling ratio forwarded to ``TransitionDown``.
        k: Neighbour count used both by ``TransitionDown`` and for the kNN
            graph the transformer runs on.
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, ratio: float=0.5, k=16):
        super().__init__()
        self.k = k
        self.transition_down = TransitionDown(in_channels, out_channels, ratio, k, act='silu')
        self.transformer = TransformerBlock(out_channels, out_channels, time_channels, last_act='silu')

    def forward(self, x: Tensor, t: Tensor, pos: Tensor, batch: OptTensor=None):
        """Return ``(features, positions, batch)`` for the downsampled point set.

        Args:
            x: Point features.
            t: Time embedding, indexed by ``batch`` to get one row per point
                (presumably one row per graph -- verify against caller).
            pos: Point coordinates.
            batch: Optional graph-assignment vector for the points.
        """
        x, pos, batch = self.transition_down(x=x, pos=pos, batch=batch)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        # ``t[None]`` would add a leading axis rather than select rows, so
        # only index when a batch vector exists.
        t_per_point = t if batch is None else t[batch]
        x = self.transformer(x=x, t=t_per_point, pos=pos, edge_index=edge_index)
        return x, pos, batch


class MLPBlock(nn.Module):
    """Resolution-preserving stage: a feature MLP followed by a ``TransformerBlock``.

    Args:
        in_channels: Width of the incoming point features.
        out_channels: Width of the features after this stage.
        time_channels: Width of the time embedding ``t``.
        k: Neighbour count for the kNN graph the transformer runs on.
        norm: Normalisation name handed to the feature MLP (``None`` disables it).
    """

    def __init__(self, in_channels: int, out_channels: int, time_channels: int, k: int=16, norm: str | None='batch_norm'):
        super().__init__()
        self.k = k
        self.mlp = MLP([in_channels, out_channels], plain_last=False, norm=norm)
        self.transformer = TransformerBlock(out_channels, out_channels, time_channels)

    def forward(self, x: Tensor, t: Tensor, pos: Tensor, batch: OptTensor=None):
        """Return the transformed point features (positions are unchanged).

        Args:
            x: Point features.
            t: Time embedding, indexed by ``batch`` to get one row per point
                (presumably one row per graph -- verify against caller).
            pos: Point coordinates.
            batch: Optional graph-assignment vector for the points.
        """
        x = self.mlp(x)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        # ``t[None]`` would add a leading axis rather than select rows, so
        # only index when a batch vector exists.
        t_per_point = t if batch is None else t[batch]
        x = self.transformer(x=x, t=t_per_point, pos=pos, edge_index=edge_index)
        return x


class UpBlock(nn.Module):
    """Decoder stage: ``TransitionUp`` followed by a ``TransformerBlock``.

    Args:
        in_channels: Width of the coarse features ``x_sub``.
        out_channels: Width of the fine features ``x`` and of the result.
        time_channels: Width of the time embedding ``t``.
        k: Neighbour count for the kNN graph the transformer runs on.
            Defaults to 16, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, time_channels, k: int = 16):
        super().__init__()
        self.k = k
        self.transition_up = TransitionUp(in_channels, out_channels)
        self.transformer = TransformerBlock(out_channels, out_channels, time_channels)

    def forward(self, x: Tensor, x_sub: Tensor, t: Tensor, pos: Tensor, pos_sub: Tensor, batch=None, batch_sub=None):
        """Return ``(features, pos, batch)`` at the fine resolution.

        Args:
            x: Fine-resolution (skip) features.
            x_sub: Coarse-resolution features to upsample.
            t: Time embedding, indexed by ``batch`` to get one row per point
                (presumably one row per graph -- verify against caller).
            pos: Fine-resolution coordinates.
            pos_sub: Coarse-resolution coordinates.
            batch: Optional graph-assignment vector for the fine points.
            batch_sub: Optional graph-assignment vector for the coarse points.
        """
        x = self.transition_up(x, x_sub, pos, pos_sub, batch, batch_sub)
        edge_index = knn_graph(pos, k=self.k, batch=batch)
        # ``t[None]`` would add a leading axis rather than select rows, so
        # only index when a batch vector exists.
        t_per_point = t if batch is None else t[batch]
        x = self.transformer(x=x, t=t_per_point, pos=pos, edge_index=edge_index)
        return x, pos, batch

In [None]:
from embeddings import SigmaEmbedding, SigmaProjection


class DenoisingUNet(nn.Module):
    """U-shaped point network: input block, down path, summit, up path, output MLP.

    Args:
        channels: Width of the per-point input (the positions are fed in as
            features) and of the per-point output.
        dim_model: Feature width at each resolution level, finest first.
        k: Neighbour count passed to the input, down and summit blocks.
    """

    def __init__(
        self,
        channels: int,
        dim_model: tuple[int, ...],
        k: int = 16
    ):
        super().__init__()
        base_width = dim_model[0]
        top_width = dim_model[-1]
        time_channels = 4 * base_width

        self.time_projection = SigmaProjection(base_width, True, 0)
        self.time_embedding = SigmaEmbedding(base_width, time_channels)

        self.input_block = MLPBlock(
            in_channels=channels, out_channels=base_width, time_channels=time_channels, k=k
        )

        # One down block per consecutive pair of widths, finest to coarsest.
        encoder = []
        for upper, lower in pairwise(dim_model):
            encoder.append(
                DownBlock(in_channels=upper, out_channels=lower, time_channels=time_channels, ratio=0.5, k=k)
            )
        self.down_blocks = nn.ModuleList(encoder)

        self.summit_block = MLPBlock(
            in_channels=top_width, out_channels=top_width, time_channels=time_channels, k=k, norm=None
        )

        # Mirror of the down path, coarsest to finest.
        # NOTE(review): ``k`` is not forwarded here; UpBlock uses its own value.
        decoder = []
        for lower, upper in pairwise(reversed(dim_model)):
            decoder.append(
                UpBlock(in_channels=lower, out_channels=upper, time_channels=time_channels)
            )
        self.up_blocks = nn.ModuleList(decoder)

        self.mlp_output = MLP([base_width, 64, channels], norm=None, act='tanh')

    def forward(self, pos: Tensor, t: Tensor, batch: OptTensor=None):
        """Run the network on a batch of points.

        Args:
            pos: Point coordinates; also used as the initial point features.
            t: Time/noise input, passed through ``time_projection`` and then
                ``time_embedding`` before use.
            batch: Optional graph-assignment vector for the points.

        Returns:
            Per-point output of width ``channels`` at the input resolution.
        """
        t = self.time_embedding(self.time_projection(t))

        x: Tensor = self.input_block(x=pos, t=t, pos=pos, batch=batch)

        # Remember the state entering each down block for the matching up block.
        skips = []
        for down_block in self.down_blocks:
            skips.append((x, pos, batch))
            x, pos, batch = down_block(x=x, t=t, pos=pos, batch=batch)

        x = self.summit_block(x=x, t=t, pos=pos, batch=batch)

        for up_block, (skip_x, skip_pos, skip_batch) in zip(self.up_blocks, reversed(skips)):
            x, pos, batch = up_block(
                x=skip_x, x_sub=x, t=t,
                pos=skip_pos, pos_sub=pos,
                batch=skip_batch, batch_sub=batch,
            )

        return self.mlp_output(x)
    


In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader

# Dataset root and the single ShapeNet category used throughout the notebook.
path = "data/ShapeNet"
category = 'Airplane'

# Shared by both splits.
pre_transform = T.NormalizeScale()

# Training-only transform: fix the cloud size at 1024 points, then apply a
# random rotation (15-degree setting) about each of the three axes in turn.
transform = T.Compose(
    [T.FixedPoints(1024)]
    + [T.RandomRotate(15, axis=axis) for axis in range(3)]
)

train_dataset = ShapeNet(
    path, category,
    split='trainval',
    transform=transform,
    pre_transform=pre_transform,
)
# The test split gets no per-sample transform.
test_dataset = ShapeNet(
    path, category,
    split='test',
    pre_transform=pre_transform,
)

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

In [None]:
from pytorch3d.loss import chamfer_distance
from torch_geometric.utils import to_dense_batch
from torch import nn

class CDLoss(nn.Module):
    """Chamfer distance between two point sets that share one batch vector."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, batch):
        """Compute one Chamfer distance per graph in the batch.

        Args:
            pred: Predicted points, concatenated over all graphs.
            target: Target points, laid out like ``pred``.
            batch: Graph-assignment vector shared by ``pred`` and ``target``.

        Returns:
            The first element of ``chamfer_distance``'s result, computed with
            ``batch_reduction=None`` (no reduction over the batch).
        """
        pred_dense, mask = to_dense_batch(pred, batch)
        target_dense = to_dense_batch(target, batch)[0]
        # ``to_dense_batch`` zero-pads graphs with fewer points. Pass the true
        # per-graph sizes so the padding rows are excluded from the distance;
        # both inputs share ``batch`` and therefore the same mask.
        lengths = mask.sum(dim=1)
        return chamfer_distance(
            pred_dense,
            target_dense,
            x_lengths=lengths,
            y_lengths=lengths,
            batch_reduction=None,
        )[0]


In [None]:
# Smoke-test CDLoss on one training batch: compare the points against a copy
# shifted by 1.0 in every coordinate.
sample = next(iter(train_loader))
pos, batch = sample.pos, sample.batch

loss = CDLoss()
# Call the module itself rather than ``.forward`` so nn.Module hooks run.
loss(pos, pos + 1.0, batch)

In [None]:
from consistency_models import ConsistencyTraining, timesteps_schedule, ema_decay_rate_schedule, update_ema_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The EMA model starts as an exact copy of the online model.
online_model = DenoisingUNet(channels=3, dim_model=(32, 64, 128, 256)).to(device)
ema_model = DenoisingUNet(channels=3, dim_model=(32, 64, 128, 256)).to(device)
ema_model.load_state_dict(online_model.state_dict())

consistency_step = ConsistencyTraining()
optimizer = torch.optim.Adam(online_model.parameters(), lr=0.0001)
loss_fn = nn.MSELoss()


online_model.train()

num_epochs = 500
k = 0  # global iteration counter passed to the consistency step
max_iters = num_epochs * len(train_loader)
# Run exactly ``num_epochs`` epochs so the loop matches ``max_iters``
# (``range(1, 500)`` stopped one epoch short).
for epoch in range(1, num_epochs + 1):
    for i, data in enumerate(train_loader):
        data = data.to(device)
        next_x, current_x = consistency_step(online_model, ema_model, data.pos, data.batch, k, max_iters)

        loss = loss_fn(next_x, current_x)
        # Clear gradients every step; zeroing only once before the loop made
        # each update use the sum of all gradients seen so far.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f'Epoch: {epoch}, Subepoch: {i} Loss: {loss.item()}')

        if i % 100 == 0:
            visualize_points(current_x.detach().cpu().numpy())
            visualize_points(next_x.detach().cpu().numpy())

        # NOTE(review): this schedule is driven by the epoch index while the
        # consistency step above is driven by the iteration counter ``k`` --
        # confirm that the two different granularities are intended.
        with torch.no_grad():
            num_timesteps = timesteps_schedule(epoch, num_epochs)
            ema_decay_rate = ema_decay_rate_schedule(num_timesteps,)
            ema_model = update_ema_model(ema_model, online_model, ema_decay_rate)

        k += 1