In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from torch import Tensor

class PointConsistencyModel(nn.Module):
    def __init__(self, model: nn.Module, sigma_data=0.5, sigma_min=1e-4):
        super().__init__()
        self.model = model
        self.sigma_data = sigma_data
        self.epsilon = sigma_min

    def forward(self, x: Tensor, t: Tensor, batch: Tensor | None=None) -> Tensor:
        f = self.model(x, t, batch)
        print(f.shape)
        c_out = self.c_out(t)
        c_skip = self.c_skip(t)
        out = c_out[batch].unsqueeze(-1) * f 
        out += c_skip[batch].unsqueeze(-1) * x
        return out

    def c_skip(self, t: Tensor) -> Tensor:
        return torch.div(self.sigma_data**2, (t - self.epsilon)**2 + self.sigma_data**2)
    
    def c_out(self, t: Tensor) -> Tensor:
        return torch.div(self.sigma_data * (t - self.epsilon), torch.sqrt(self.sigma_data**2 + t**2))

In [None]:
import math
import typing as ty
from dataclasses import dataclass

@dataclass
class PointConsistencySettings:
    """Hyper-parameters for point-cloud consistency training.

    ``__post_init__`` rejects values that would otherwise fail much later with
    an opaque error (division by zero in the step schedule, ``log`` of a
    non-positive decay rate, an inverted noise range).
    """

    # K: total number of training iterations used by the step schedule N(k).
    training_iterations: int
    # (s0, s1): initial and final number of discretisation steps for N(k).
    target_disc_steps: tuple[int, int] = (1, 150)
    # mu0: EMA decay rate at the start of training.
    initial_ema_decay_rate: float = 0.95
    # s0 used by PointConsistencyTraining.ema_decay_rate_schedule.
    # NOTE(review): differs from target_disc_steps[0] by default — confirm intended.
    initial_timesteps: int = 2
    # (sigma_min, sigma_max): endpoints of the Karras noise schedule.
    sigma_range: tuple[float, float] = (0.002, 1.0)
    # Curvature exponent of the Karras noise schedule.
    rho: float = 7.0

    def __post_init__(self) -> None:
        if self.training_iterations <= 0:
            raise ValueError(
                f"training_iterations must be positive, got {self.training_iterations}"
            )
        if not 0.0 < self.initial_ema_decay_rate <= 1.0:
            raise ValueError(
                f"initial_ema_decay_rate must be in (0, 1], got {self.initial_ema_decay_rate}"
            )
        sigma_lo, sigma_hi = self.sigma_range
        if not 0.0 < sigma_lo < sigma_hi:
            raise ValueError(
                f"sigma_range must satisfy 0 < min < max, got {self.sigma_range}"
            )
        if self.rho <= 0:
            raise ValueError(f"rho must be positive, got {self.rho}")


class PointConsistencyTraining(nn.Module):
    def __init__(self, settings: PointConsistencySettings):
        super().__init__()
        self.conf = settings
        
    def step_schedule_n(self, k: float) -> int:
        s, K = self.conf.target_disc_steps, self.conf.training_iterations

        num_timesteps = (s[1] + 1)**2 - s[0]**2
        num_timesteps = k * num_timesteps / K
        num_timesteps = num_timesteps + s[0]**2
        num_timesteps = math.sqrt(num_timesteps)
        num_timesteps = math.ceil(-1. + num_timesteps)
        return 1 + num_timesteps

    def ema_decay_rate_schedule_mu(self, n_k: int) -> float:
        s, mu_0 = self.conf.target_disc_steps, self.conf.initial_ema_decay_rate
        
        return math.exp(s[0] * math.log(mu_0) / float(n_k))

    def karras_schedule_t(self, n_k: int, device: torch.device | None = None) -> Tensor:
        (eps, T), rho = self.conf.sigma_range, self.conf.rho

        rho_inv = 1.0 / rho
        steps = torch.arange(n_k, device=device) / max(n_k - 1, 1)
        sigmas = eps**rho_inv + steps * (T**rho_inv - eps**rho_inv)
        sigmas = sigmas**rho
        return sigmas

    def ema_decay_rate_schedule(self, num_timesteps: int) -> float:
        return math.exp(
            (self.conf.initial_timesteps * math.log(self.conf.initial_ema_decay_rate)) / num_timesteps
        )

    def train_step(
        self, 
        iteration: int, 
        x: Tensor, batch: Tensor, 
        model: PointConsistencyModel, 
        ema_model: PointConsistencyModel
    ):
        num_timesteps = self.step_schedule_n(iteration)
        sigmas = self.karras_schedule_t(num_timesteps, device=x.device)
        noise = torch.randn_like(x)

        timesteps = torch.randint(0, num_timesteps - 1, (int(batch.max().item()) + 1, ), device=x.device)
        current_sigmas = sigmas[timesteps]
        next_sigmas = sigmas[timesteps + 1]

        next_x = x + (noise * next_sigmas[batch].unsqueeze(-1))
        next_x = model(next_x, next_sigmas, batch)

        with torch.no_grad():
            current_x = x + (x * current_sigmas[batch].unsqueeze(-1))
            current_x = ema_model(current_x, current_sigmas, batch)

        return next_x, current_x

    def after_train_step(self, iteration: int, model: PointConsistencyModel, ema_model: PointConsistencyModel):
        num_timesteps = self.step_schedule_n(iteration)
        ema_decay_rate = self.ema_decay_rate_schedule(num_timesteps)
        self._update_ema_weights(
            ema_model.parameters(), model.parameters(), ema_decay_rate
        )
        return ema_model

    def _update_ema_weights(
        self,
        ema_weight_iter: ty.Iterator[Tensor],
        online_weight_iter: ty.Iterator[Tensor],
        ema_decay_rate: float,
    ) -> None:
        for ema_weight, online_weight in zip(ema_weight_iter, online_weight_iter):
            if ema_weight.data is None:
                ema_weight.data.copy_(online_weight.data)
            else:
                ema_weight.data.lerp_(online_weight.data, 1.0 - ema_decay_rate)



In [None]:
import torch
from torch import nn

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from pytorch3d.loss import chamfer_distance

from models.model import Model

path = "data/ShapeNet"
category = 'Airplane'
pre_transform = T.NormalizeScale()
# Train and test clouds receive the same per-sample transforms.
transform = T.Compose([T.NormalizeRotation(), T.FixedPoints(1024)])
test_transform = T.Compose([T.NormalizeRotation(), T.FixedPoints(1024)])

train_dataset, test_dataset = (
    ShapeNet(path, category, split=split_name, transform=split_tf, pre_transform=pre_transform)
    for split_name, split_tf in (('trainval', transform), ('test', test_transform))
)
train_loader, test_loader = (
    DataLoader(split_ds, batch_size=16, shuffle=True)
    for split_ds in (train_dataset, test_dataset)
)


class CDLoss(nn.Module):
    """Chamfer-distance loss between predicted and target point clouds.

    Accepts dense ``(B, P, D)`` tensors, a single unbatched ``(P, D)`` cloud,
    or flat per-point tensors together with a PyG-style ``batch`` vector.
    """

    def forward(self, pred, target, batch=None):
        """Return the Chamfer distance (first element of ``chamfer_distance``'s result).

        Args:
            pred: predicted points.
            target: target points, laid out like ``pred``.
            batch: optional cloud index per point. When given, both inputs are
                flat and are regrouped into padded ``(B, P_max, D)`` tensors.
        """
        if batch is not None:
            # Local import: this cell does not import to_dense_batch at top level.
            from torch_geometric.utils import to_dense_batch

            pred_dense, pred_mask = to_dense_batch(pred, batch)
            target_dense, target_mask = to_dense_batch(target, batch)
            # Pass true per-cloud sizes so zero padding is not treated as points.
            return chamfer_distance(
                pred_dense,
                target_dense,
                x_lengths=pred_mask.sum(dim=1),
                y_lengths=target_mask.sum(dim=1),
            )[0]
        # chamfer_distance expects a leading batch dimension.
        if pred.dim() == 2:
            pred = pred.unsqueeze(0)
        if target.dim() == 2:
            target = target.unsqueeze(0)
        return chamfer_distance(pred, target)[0]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ct = PointConsistencyTraining(PointConsistencySettings(100_000))

# sigma_min must equal the lower end of the training noise schedule so that
# c_skip == 1 and c_out == 0 at the smallest sigma actually sampled; the
# wrapper's own default (1e-4) differs from the settings' default (0.002).
model = PointConsistencyModel(Model(), sigma_min=ct.conf.sigma_range[0])
ema = PointConsistencyModel(Model(), sigma_min=ct.conf.sigma_range[0])
ema.load_state_dict(model.state_dict())  # EMA target starts as an exact copy
ema.eval()

model = model.to(device)
ema = ema.to(device)

loss = CDLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# Smoke test: one consistency-training step on a single batch (iteration k = 1).
# The outputs are not named `next`, so the builtin stays usable.
data = next(iter(train_loader)).to(device)
optimizer.zero_grad()
next_x, cur_x = ct.train_step(1, data.pos, data.batch, model, ema)

In [None]:
from models.model import Model

In [None]:
import matplotlib.pyplot as plt


def visualize_points(pos, c=None):
    """Show a 4x4-inch 3-D scatter plot of a point cloud.

    Args:
        pos: indexable points; columns 0, 1, 2 are drawn as x, y, z.
        c: optional marker colour(s) forwarded to ``scatter``; 'blue' if None.
    """
    marker_colour = 'blue' if c is None else c
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(projection='3d')
    # Hide tick labels on all three axes.
    for axis in (ax.axes.xaxis, ax.axes.yaxis, ax.axes.zaxis):  # type: ignore
        axis.set_ticklabels([])
    ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=marker_colour, s=3)
    plt.show()



In [None]:
from pytorch3d.loss import chamfer_distance
from torch_geometric.utils import to_dense_batch
from torch import nn

class CDLoss(nn.Module):
    """Chamfer-distance loss for flat, PyG-batched point tensors."""

    def forward(self, pred, target, batch=None):
        """Return the Chamfer distance between ``pred`` and ``target``.

        Args:
            pred: flat per-point predictions of all clouds in the batch.
            target: flat per-point targets, grouped by the same ``batch``.
            batch: cloud index of every point; None means a single cloud.
        """
        pred_dense, pred_mask = to_dense_batch(pred, batch)
        target_dense, target_mask = to_dense_batch(target, batch)
        # to_dense_batch zero-pads smaller clouds; passing the true lengths
        # keeps those padding rows out of the nearest-neighbour search.
        return chamfer_distance(
            pred_dense,
            target_dense,
            x_lengths=pred_mask.sum(dim=1),
            y_lengths=target_mask.sum(dim=1),
        )[0]


In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_batch

path = "data/ShapeNet"
category = 'Airplane'
pre_transform = T.NormalizeScale()
# Train and test clouds receive the same per-sample transforms.
transform = T.Compose([T.NormalizeRotation(), T.FixedPoints(1024)])
test_transform = T.Compose([T.NormalizeRotation(), T.FixedPoints(1024)])

train_dataset, test_dataset = (
    ShapeNet(path, category, split=split_name, transform=split_tf, pre_transform=pre_transform)
    for split_name, split_tf in (('trainval', transform), ('test', test_transform))
)
train_loader, test_loader = (
    DataLoader(split_ds, batch_size=10, shuffle=True)
    for split_ds in (train_dataset, test_dataset)
)


settings = PointConsistencySettings(training_iterations=100_000, sigma_range=(0.02, 20))
ct = PointConsistencyTraining(settings)

# Online model first, EMA target second; both use the schedule's smallest
# noise level as sigma_min so the boundary condition holds there.
model, ema_model = (
    PointConsistencyModel(Model(), sigma_data=0.5, sigma_min=settings.sigma_range[0])
    for _ in range(2)
)
ema_model.load_state_dict(model.state_dict())  # EMA target starts as an exact copy
loss = CDLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model, ema_model = model.to(device), ema_model.to(device)

# Plot the first cloud of the online and EMA outputs every PLOT_EVERY iterations.
PLOT_EVERY = 10

# k is the global iteration counter fed to the N(k) / EMA schedules, which are
# parameterised by settings.training_iterations; run exactly that many steps.
k = 1
num_epochs = math.ceil(settings.training_iterations / len(train_loader))
for epoch in range(1, num_epochs + 1):
    model.train()
    for data in train_loader:
        if k > settings.training_iterations:
            break
        data = data.to(device)
        optimizer.zero_grad()
        next_x, cur_x = ct.train_step(k, data.pos, data.batch, model, ema_model)

        # The outputs are flat per-point tensors; CDLoss needs the batch vector
        # to regroup them into clouds.
        loss_val = loss(next_x, cur_x, data.batch)
        loss_val.backward()

        optimizer.step()
        ema_model = ct.after_train_step(k, model, ema_model)

        print(f"Epoch: {epoch}, Iteration: {k}, Loss: {loss_val.item()}")

        if (k - 1) % PLOT_EVERY == 0:
            first_next_x = to_dense_batch(next_x, data.batch)[0][0]
            visualize_points(first_next_x.detach().cpu().numpy())

            first_cur_x = to_dense_batch(cur_x, data.batch)[0][0]
            visualize_points(first_cur_x.detach().cpu().numpy())

        k += 1


In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader

path = "data/ShapeNet"
category = 'Airplane'
pre_transform = T.NormalizeScale()
transform = T.Compose([T.NormalizeRotation(), T.FixedPoints(1024)])
# NOTE(review): unlike `transform` (and the test transforms in the earlier
# setup cells) this omits NormalizeRotation, so train and test clouds are
# preprocessed differently — confirm this mismatch is intentional.
test_transform = T.Compose([T.FixedPoints(1024)])

train_dataset, test_dataset = (
    ShapeNet(path, category, split=split_name, transform=split_tf, pre_transform=pre_transform)
    for split_name, split_tf in (('trainval', transform), ('test', test_transform))
)
train_loader, test_loader = (
    DataLoader(split_ds, batch_size=10, shuffle=True)
    for split_ds in (train_dataset, test_dataset)
)

In [None]:
# Record the installed PyTorch version in the TORCH environment variable and
# echo it. Nothing is installed here.
# NOTE(review): TORCH is presumably a leftover from a pip-install snippet that
# built wheel URLs from it — confirm it is still needed.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(os.environ['TORCH'])

# Helper functions for visualization.
%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_mesh(pos, face):
    """Render a triangle mesh as a 3-D surface without tick labels.

    Args:
        pos: vertex coordinates; columns 0, 1, 2 are drawn as x, y, z.
        face: triangle index tensor, transposed with ``.t()`` before plotting —
            presumably shaped (3, num_faces) as in PyG; confirm.
    """
    triangles = face.t()
    ax = plt.figure().add_subplot(projection='3d')
    for axis in (ax.axes.xaxis, ax.axes.yaxis, ax.axes.zaxis):  # type: ignore
        axis.set_ticklabels([])
    ax.plot_trisurf(pos[:, 0], pos[:, 1], pos[:, 2], triangles=triangles, antialiased=False)
    plt.show()


def visualize_points(pos, c=None):
    """Show a 4x4-inch 3-D scatter plot of a point cloud.

    Columns are drawn as (x, y, z) = (pos[:, 0], pos[:, 2], pos[:, 1]), i.e.
    column 1 goes on the plot's third (vertical) axis.

    Args:
        pos: indexable points with at least three columns.
        c: optional marker colour(s) forwarded to ``scatter``; when None,
            points are coloured by ``pos[:, 1]``.
    """
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(projection='3d')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([]) # type: ignore
    colours = pos[:, 1] if c is None else c
    ax.scatter(pos[:, 0], pos[:, 2], pos[:, 1], c=colours, s=2)
    plt.show()

In [None]:
from pytorch3d.loss import chamfer_distance
from torch_geometric.utils import to_dense_batch


def chamfer_loss(x, y, batch):
    """Chamfer distance between two flat, identically batched point sets.

    Args:
        x: flat per-point tensor of all clouds (e.g. predictions).
        y: flat per-point tensor grouped by the same ``batch`` (e.g. targets).
        batch: cloud index of every point, or None for a single cloud.

    Returns:
        The loss tensor (first element of ``chamfer_distance``'s result).
    """
    x_dense, x_mask = to_dense_batch(x, batch)
    y_dense, y_mask = to_dense_batch(y, batch)
    # to_dense_batch zero-pads smaller clouds; the true lengths keep those
    # padding rows from being matched as real points.
    return chamfer_distance(
        x_dense,
        y_dense,
        x_lengths=x_mask.sum(dim=1),
        y_lengths=y_mask.sum(dim=1),
    )[0]


In [None]:
print(chamfer_loss(train_dataset[4].pos, test_dataset[4].pos, None))  # two single clouds, batch=None (as train_dataset[8].batch yields)

In [None]:
# Train the model.

from torch.nn import MSELoss

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Denoising backbone, called below as model(data).
# NOTE(review): PointConsistencyModel calls Model as model(x, t, batch) — confirm both signatures exist.
model = Model()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)


def train():
    """Run one denoising-training epoch over ``train_loader``.

    Each batch's points are perturbed with ``randn * 0.2`` noise, the model
    maps the perturbed batch back, and it is optimised with ``chamfer_loss``
    against the clean points. The mean loss of every full window of 10
    batches is printed; a trailing partial window is not printed.
    """
    model.train()

    window_loss = 0
    for step, data in enumerate(train_loader, start=1):
        data = data.to(device)
        noisy = data.clone()
        noisy.pos += torch.randn_like(data.pos) * 0.2

        optimizer.zero_grad()
        prediction = model(noisy)
        loss = chamfer_loss(prediction, data.pos, data.batch)
        loss.backward()
        optimizer.step()

        window_loss += loss.item()
        if step % 10 == 0:
            print(f'[{step}/{len(train_loader)}] Loss: {window_loss/10:.4f} ')
            window_loss = 0


@torch.no_grad()
def test():
    """Evaluate the denoiser on ``test_loader``; return the mean per-batch loss.

    Uses the same ``randn * 0.2`` perturbation as training. For the first
    batch, plots the clean, perturbed and reconstructed versions of its first
    cloud, in that order.
    """
    model.eval()

    loss_sum = 0
    for batch_idx, data in enumerate(test_loader):
        data = data.to(device)
        noisy = data.clone()
        noisy.pos += torch.randn_like(data.pos) * 0.2

        reconstruction = model(noisy)
        loss_sum += chamfer_loss(reconstruction, data.pos, data.batch).item()

        if batch_idx == 0:
            for points in (data.pos, noisy.pos, reconstruction):
                first_cloud = to_dense_batch(points, data.batch)[0][0]
                visualize_points(first_cloud.detach().cpu())

    return loss_sum / len(test_loader)
    

        
        


# Alternate one training epoch with one evaluation pass, 100 times.
# `epochs` and `test_loss` remain as notebook globals after the loop.
for epochs in range(100):
    print(f'Epoch {epochs}')
    train()
    test_loss = test()
    print(f'Test loss: {test_loss:.4f}')
