In [17]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch import Tensor



In [18]:
@dataclass
class ConsistencySettings:
    """Hyper-parameters consumed by the schedules, parametrization and resampler built in this notebook."""

    # Normalising constant for the timestep schedule and the EMA decay
    # (both divide the current iteration by this value); also the upper
    # bound of the training loop.
    training_iterations: int = 100_000
    # Endpoints handed to NumTimestepsSchedule.
    min_time_partitions: int = 1
    max_time_partitions: int = 150
    # Handed to ExponentialDecay as `initial_decay`.
    initial_ema_decay: float = 0.99
    # Lower end of the Karras time range; also the offset subtracted inside
    # EpsilonParametrization and EpsilonResampler.
    min_time: float = 1e-4
    # Handed to EpsilonParametrization as `data_time`.
    data_time: float = 1e-3
    # Upper end of the Karras time range; `sample` starts from this time.
    max_time: float = 1.0
    # Exponent used by KarrasTimeSchedule.
    rho: float = 7.0


In [19]:
from abc import ABC, abstractmethod
from loss.chamfer_loss import CDLoss
from geomloss import SamplesLoss
from torch_geometric.utils import to_dense_batch

class BaseDistanceFunc(ABC):
    def __init__(self) -> None:
        super().__init__()
        
    @abstractmethod
    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        ...
    

class ChamferDistance(BaseDistanceFunc):
    """Distance functor that delegates to the project's ``CDLoss``."""

    def __init__(self) -> None:
        super().__init__()
        self.loss = CDLoss()
        # The criterion is switched to eval mode once, at construction time.
        self.loss.eval()

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        """Forward ``lhs``, ``rhs`` and ``batch`` unchanged to the wrapped loss."""
        distance = self.loss(lhs, rhs, batch)
        return distance
    

class MSEDistance(BaseDistanceFunc):
    """Mean-squared-error distance built on ``nn.MSELoss``.

    ``batch`` is accepted only to satisfy the common interface; it is not used.
    """

    def __init__(self) -> None:
        super().__init__()
        criterion = nn.MSELoss()
        self.loss = criterion

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        """Return ``MSELoss(lhs, rhs)``."""
        mse = self.loss(lhs, rhs)
        return mse



class SinkhornEMDistance(BaseDistanceFunc):
    """Sinkhorn (p=1, blur=0.01) distance between batched point sets.

    Both inputs are densified with ``to_dense_batch`` using the same ``batch``
    vector. The padding mask returned by ``to_dense_batch`` is turned into
    per-point weights so that padded rows carry zero mass instead of being
    scored as real points at the origin. When every cloud in the batch has the
    same number of points the weights are uniform, which is what
    ``SamplesLoss`` uses when called with two arguments.
    """

    def __init__(self) -> None:
        # Sibling distance classes all initialise the base class; do the same here.
        super().__init__()
        self.loss = SamplesLoss(loss="sinkhorn", p=1, blur=0.01)

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        """Return the mean Sinkhorn loss over the clouds in the batch.

        Args:
            lhs: Points of the first set, flattened across the batch.
            rhs: Points of the second set, laid out like ``lhs``.
            batch: Batch vector shared by ``lhs`` and ``rhs`` (``None`` is
                passed straight to ``to_dense_batch``).
        """
        dense_lhs, mask = to_dense_batch(lhs, batch)
        dense_rhs, _ = to_dense_batch(rhs, batch)  # same batch vector -> same mask

        # Uniform mass over the real points of each cloud, zero on padding.
        # The clamp avoids 0/0 for a cloud that has no points at all.
        weights = mask.to(dense_lhs.dtype)
        weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=1.0)

        # NOTE(review): relies on SamplesLoss accepting (weights, points,
        # weights, points) for batched input -- confirm against the installed
        # geomloss version.
        return self.loss(weights, dense_lhs, weights, dense_rhs).mean()

In [20]:
import math


class BaseNumTimestepsSchedule(ABC):
    """Interface mapping a training iteration to a number of timesteps."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, iteration: int) -> int:
        """Return the number of timesteps to use at ``iteration``."""


class NumTimestepsSchedule(BaseNumTimestepsSchedule):
    """Square-root schedule for the number of timesteps.

    The squared step count is interpolated linearly in ``iteration`` between
    ``min_time_partitions**2`` and ``(max_time_partitions + 1)**2``; the result
    is ``ceil(sqrt(.) - 1) + 1``.
    """

    def __init__(
        self,
        min_time_partitions: int,
        max_time_partitions: int,
        training_iterations: int
    ) -> None:
        super().__init__()
        self.training_iterations = training_iterations
        self.target_disc_steps = (min_time_partitions, max_time_partitions)

    def __call__(self, iteration: int) -> int:
        """Return the timestep count for ``iteration``."""
        lowest, highest = self.target_disc_steps
        lowest_sq = lowest**2
        span = (highest + 1)**2 - lowest_sq
        interpolated = iteration * span / self.training_iterations + lowest_sq
        return 1 + math.ceil(-1. + math.sqrt(interpolated))


In [21]:

class BaseTimeSchedule(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        ...
    
    
class KarrasTimeSchedule(BaseTimeSchedule):
    def __init__(self, min_time: float, max_time: float, rho: float) -> None:
        super().__init__()
        self.sigma_range = (min_time, max_time)
        self.rho = rho

    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        (eps, T), rho = self.sigma_range, self.rho
        rho_inv = 1.0 / rho
        steps = torch.arange(up_to, device=device) / max(up_to - 1, 1)
        sigmas = eps**rho_inv + steps * (T**rho_inv - eps**rho_inv)
        sigmas = sigmas**rho
        return sigmas


In [22]:

class BaseEMADecay(ABC):
    """Interface mapping a training iteration to an EMA decay coefficient."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, iteration: int) -> float:
        """Return the decay coefficient to use at ``iteration``."""


class ExponentialDecay(BaseEMADecay):
    """Decay equal to ``exp(iteration * log(initial_decay) / training_iterations)``.

    The value is 1.0 at iteration 0 and reaches ``initial_decay`` when
    ``iteration == training_iterations``.
    """

    def __init__(self, initial_decay: float, training_iterations: int) -> None:
        super().__init__()
        self.training_iterations = training_iterations
        self.initial_decay = initial_decay

    def __call__(self, iteration: int) -> float:
        """Return the decay coefficient for ``iteration``."""
        log_decay = math.log(self.initial_decay)
        exponent = iteration * log_decay / self.training_iterations
        return math.exp(exponent)

In [23]:
class BaseParametrization(ABC):
    """Interface combining a noisy input ``x`` and a raw network output ``y``."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        """Return the parametrized output for input ``x``, network output ``y`` and times ``t``."""
        ...


class BaseResampler(ABC):
    """Interface that perturbs ``x`` according to the times ``t``."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        """Return a perturbed version of ``x`` for times ``t``."""
        ...


class EpsilonParametrization(BaseParametrization):
    """Blend ``skip(t) * x + out(t) * y`` whose coefficients depend on ``t - min_time``.

    At ``t == min_time`` the skip coefficient is 1 and the output coefficient
    is 0, so the result equals ``x``.
    """

    def __init__(self, min_time: float, data_time: float) -> None:
        super().__init__()
        self.data_time = data_time
        self.min_time = min_time

    def skip(self, t: Tensor) -> Tensor:
        """Coefficient applied to the input ``x``."""
        return (self.data_time ** 2) / ((t - self.min_time) ** 2 + (self.data_time ** 2))

    def out(self, t: Tensor) -> Tensor:
        """Coefficient applied to the network output ``y``."""
        return (t - self.min_time) * self.data_time / (self.data_time**2 + t**2) ** 0.5

    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        """Combine ``x`` and ``y``.

        ``t`` is indexed with ``batch`` so that each row of ``x``/``y`` receives
        the coefficient of the entry of ``t`` it is assigned to.
        """
        return self.skip(t)[batch, None] * x + self.out(t)[batch, None] * y


class EpsilonResampler(BaseResampler):
    """Adds Gaussian noise with standard deviation ``sqrt(t**2 - min_time**2)``."""

    def __init__(self, min_time: float) -> None:
        super().__init__()
        self.min_time = min_time

    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        """Return ``x`` plus noise scaled per row according to ``t[batch]``.

        The variance is clamped at zero: a ``t`` at (or, through floating-point
        round-off, marginally below) ``min_time`` would otherwise make the
        square root NaN and poison every row assigned to it.
        """
        variance = torch.clamp(t**2 - self.min_time**2, min=0.0)
        mul = variance**0.5
        return x + mul[batch, None] * torch.randn_like(x)

In [24]:
import copy


class BaseConditionedModel(nn.Module, ABC):
    """Abstract ``nn.Module`` whose forward pass takes ``x``, ``t``, ``ctx`` and ``batch``.

    ``__call__`` is overridden only to fix the signature; it forwards all four
    arguments by keyword to ``nn.Module.__call__``.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, x: Tensor, t: Tensor, ctx: torch.Tensor, batch: Tensor) -> Tensor:
        """Compute the model output; implemented by subclasses."""

    def __call__(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        call_kwargs = dict(x=x, t=t, ctx=ctx, batch=batch)
        return super().__call__(**call_kwargs)
    


In [25]:
from typing import Sequence

class ConsistencyModel(nn.Module):
    """Wraps a conditioned network with a parametrization and a resampler.

    A single time tensor gives one evaluation of the parametrized network.
    A sequence of time tensors gives multi-step evaluation: after the first
    step, every further time re-noises the current output with the resampler
    and evaluates the parametrized network again.
    """

    def __init__(self, model: BaseConditionedModel, resampler: BaseResampler, parametrization: BaseParametrization) -> None:
        super().__init__()
        self.model = model
        self.resampler = resampler
        self.parametrization = parametrization

    def __call__(self, x: Tensor, t: Tensor | Sequence[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        # Fixes the call signature; arguments are forwarded by keyword.
        return super().__call__(x=x, t=t, ctx=ctx, batch=batch)

    def wrapped_model(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        """Run the network once and pass its raw output through the parametrization."""
        raw_output = self.model(x=x, t=t, ctx=ctx, batch=batch)
        return self.parametrization(x=x, y=raw_output, t=t, batch=batch)

    def forward(self, x: Tensor, t: Tensor | Sequence[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        """Evaluate at ``t`` (one tensor) or successively at every tensor in ``t``."""
        assert not isinstance(t, int)
        time_steps = (t,) if isinstance(t, Tensor) else t
        first_time, later_times = time_steps[0], time_steps[1:]

        current = self.wrapped_model(x=x, t=first_time, ctx=ctx, batch=batch)
        for step_time in later_times:
            renoised = self.resampler(x=current, t=step_time, batch=batch)
            current = self.wrapped_model(x=renoised, t=step_time, ctx=ctx, batch=batch)
        return current


class ConsistencyTrainer(nn.Module):
    """Produces the online / EMA output pair for one training step and maintains the EMA weights.

    Args:
        model: Online model; evaluated with gradients at the later time.
        ema: EMA copy of the model; evaluated without gradients at the earlier time.
        step_schedule: Maps the iteration to the number of time points.
        time_schedule: Builds the time grid for a given number of points.
        ema_decay: Maps the iteration to the EMA decay coefficient.
    """

    def __init__(self,
                 model: ConsistencyModel,
                 ema: ConsistencyModel,
                 step_schedule: BaseNumTimestepsSchedule,
                 time_schedule: BaseTimeSchedule,
                 ema_decay: BaseEMADecay):
        super().__init__()
        self.model = model
        self.ema = ema

        self.step_schedule = step_schedule
        self.time_schedule = time_schedule
        self.ema_decay = ema_decay

    def train_step(self, iteration: int, x: Tensor, ctx: Tensor, batch: Tensor) -> tuple[Tensor, Tensor]:
        """Noise ``x`` at two adjacent times and run both models.

        One pair of adjacent grid times is drawn per batch entry
        (``batch.max() + 1`` entries); the same noise tensor is scaled by the
        later time for the online model and by the earlier time for the EMA model.

        Returns:
            ``(denoised, ema_denoised)``: the online model output (with
            gradients) and the EMA model output (computed under ``no_grad``).
        """
        # An adjacent pair (times[i], times[i + 1]) needs at least two grid
        # points; with fewer, torch.randint(0, 0, ...) below would raise.
        num_timesteps = max(self.step_schedule(iteration), 2)
        times = self.time_schedule(num_timesteps, device=x.device)

        batch_size = int(batch.max()) + 1
        time_indices = torch.randint(0, num_timesteps - 1, (batch_size,), device=x.device)
        current_times = times[time_indices]
        next_times = times[time_indices + 1]

        z = torch.randn_like(x)
        next_x = x + z * next_times[batch, None]
        denoised = self.model(
            x=next_x,
            t=next_times,
            ctx=ctx,
            batch=batch,
        )

        with torch.no_grad():
            current_x = x + z * current_times[batch, None]
            ema_denoised = self.ema(
                x=current_x,
                t=current_times,
                ctx=ctx,
                batch=batch,
            )

        return denoised, ema_denoised

    @torch.no_grad()
    def update_emas(self, iteration: int) -> None:
        """Update EMA parameters: ``ema <- alpha * ema + (1 - alpha) * model``.

        Only ``parameters()`` are averaged; buffers are not touched. The update
        is done in place so no temporaries are allocated and the EMA parameter
        tensors keep their identity.
        """
        alpha = self.ema_decay(iteration)
        for p, ema_p in zip(self.model.parameters(), self.ema.parameters()):
            ema_p.mul_(alpha).add_(p.detach(), alpha=1 - alpha)



In [26]:

import torch
from torch_geometric.nn import knn

# Scratch check of `knn`: the four corners of a square act as the reference
# set and a single query point sits on the left edge; ask for 2 neighbours.
gt = torch.tensor([[sx, sy] for sx in (-1.0, 1.0) for sy in (-1.0, 1.0)])
batch_x = torch.zeros(4, dtype=torch.long)  # all four reference points in example 0
y = torch.tensor([[-1.0, 0.0]])
assign_index = knn(gt, y, 2, batch_x)

assign_index

tensor([[0, 0],
        [0, 1]])

In [27]:
# Hyper-parameters for this run.
settings = ConsistencySettings(
    training_iterations=100_000,
    min_time_partitions=2, max_time_partitions=150,
    initial_ema_decay=0.95,
    min_time=1e-4, data_time=1e-3, max_time=1.0,
    rho=7.0,
)

# Schedules, parametrization and resampler are all derived from `settings`.
step_schedule = NumTimestepsSchedule(
    min_time_partitions=settings.min_time_partitions,
    max_time_partitions=settings.max_time_partitions,
    training_iterations=settings.training_iterations,
)
time_schedule = KarrasTimeSchedule(
    min_time=settings.min_time,
    max_time=settings.max_time,
    rho=settings.rho,
)
ema_decay = ExponentialDecay(
    initial_decay=settings.initial_ema_decay,
    training_iterations=settings.training_iterations,
)
parametrization = EpsilonParametrization(
    min_time=settings.min_time,
    data_time=settings.data_time,
)
resampler = EpsilonResampler(min_time=settings.min_time)

# Training loss and evaluation metric are separate Chamfer instances.
loss_function = ChamferDistance()
eval_function = ChamferDistance()


In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeScale, FixedPoints, Compose
from common.transforms.knn_view import KNNSplit

# Transforms shared by all three splits.
pre_transform = NormalizeScale()
transform = Compose([
    FixedPoints(8192),
    KNNSplit(2048),
])

# One ShapeNet dataset per split (all categories).
root = "data/ShapeNetAll"
train_dataset = ShapeNet(
    root=root, split="train", categories=None,
    transform=transform, pre_transform=pre_transform,
)
val_dataset = ShapeNet(
    root=root, split="val", categories=None,
    transform=transform, pre_transform=pre_transform,
)
test_dataset = ShapeNet(
    root=root, split="test", categories=None,
    transform=transform, pre_transform=pre_transform,
)

# `follow_batch` makes the loader emit `pos_batch` / `incomplete_batch` vectors.
train_loader = DataLoader(
    train_dataset, batch_size=8, shuffle=True, drop_last=True,
    follow_batch=["pos", "incomplete"],
)
val_loader = DataLoader(
    val_dataset, batch_size=8, shuffle=True, drop_last=True,
    follow_batch=["pos", "incomplete"],
)
test_loader = DataLoader(
    test_dataset, batch_size=8, shuffle=True, drop_last=True,
    follow_batch=["pos", "incomplete"],
)


In [30]:
from models.backbone.attdgcnn import AttDGCNNEncoder
from models.backbone.pointnet import PointNetEncoder
from models.backbone.glu import GLUDecoder


class GLUConditionedModel(BaseConditionedModel):
    """Adapter exposing ``GLUDecoder`` through the ``BaseConditionedModel`` interface."""

    def __init__(self, dim_ctx) -> None:
        super().__init__()
        decoder = GLUDecoder(dim_ctx=dim_ctx)
        # Attribute name is part of the state_dict key prefix; keep it as `model`.
        self.model = decoder

    def forward(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        """Forward every argument by keyword to the wrapped decoder."""
        prediction = self.model(x=x, t=t, ctx=ctx, batch=batch)
        return prediction


# Encoder for the incomplete cloud and the consistency model it conditions.
encoder = AttDGCNNEncoder(512)
model = ConsistencyModel(
    parametrization=parametrization,
    resampler=resampler,
    model=GLUConditionedModel(512),
)
# The EMA model starts as an exact copy of the online model.
ema = copy.deepcopy(model)
# One optimizer over the online model and the encoder (the EMA copy is excluded).
optimizer = torch.optim.AdamW(
    [*model.parameters(), *encoder.parameters()],
    lr=1e-3,
)

In [31]:
from common.training import get_data_iterator
from common.visualization import visualize_batch_points
from common.data import MyData, MyDataBatched
from torch.utils.tensorboard.writer import SummaryWriter
import socket
import os
from datetime import datetime


# Use the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Run name: "<timestamp>_<hostname>_dgcnn", shared by the log and checkpoint dirs.
current_time = datetime.now().strftime("%b%d_%H-%M-%S")
experiment_name = "_".join([current_time, socket.gethostname(), "dgcnn"])
log_dir = os.path.join("runs", experiment_name)
ckpt_dir = os.path.join("checkpoints", experiment_name)
os.makedirs(ckpt_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)
trainer = ConsistencyTrainer(
    model=model,
    ema=ema,
    step_schedule=step_schedule,
    time_schedule=time_schedule,
    ema_decay=ema_decay,
)

# Move every module to the selected device.
encoder = encoder.to(device)
model = model.to(device)
ema = ema.to(device)



def train(epoch: int, data: MyDataBatched):
    """Run one optimisation step on ``data`` and return the loss as a float.

    ``epoch`` is shifted by one before being handed to the trainer, so the
    schedules see iterations starting at 1.
    """
    iteration = epoch + 1

    optimizer.zero_grad()
    model.train()
    data = data.to(device) # type: ignore

    complete, partial = data.pos, data.incomplete
    complete_batch, partial_batch = data.pos_batch, data.incomplete_batch

    # Condition on features of the incomplete cloud.
    context = encoder(partial, partial_batch)
    online_out, target_out = trainer.train_step(iteration, complete, context, complete_batch)

    # The loss compares the online model's output with the EMA model's output.
    loss = loss_function(online_out, target_out, complete_batch)
    loss.backward()
    optimizer.step()
    trainer.update_emas(iteration)

    return loss.item()
    

@torch.no_grad()
def validate(epoch):
    """Average the training loss and the evaluation metric over ``val_loader``.

    The encoder is put in eval mode for the duration of the pass and its
    previous mode is restored afterwards, so dropout / normalisation layers
    (if the encoder has any) behave deterministically and validation batches
    do not alter running statistics.

    Returns:
        ``(mean_loss, mean_metric)`` over the validation batches;
        ``(nan, nan)`` when the loader yields no batches.
    """
    iteration = epoch + 1

    num_batches = len(val_loader)
    if num_batches == 0:
        # Nothing to average; avoid the ZeroDivisionError below.
        return float("nan"), float("nan")

    model.eval()
    encoder_was_training = encoder.training
    encoder.eval()

    total_loss, total_metric = 0.0, 0.0
    try:
        for data in val_loader:
            data = data.to(device) # type: ignore

            gt, incomplete = data.pos, data.incomplete
            gt_batch, incomplete_batch = data.pos_batch, data.incomplete_batch

            incomplete_feat = encoder(incomplete, incomplete_batch)
            x_cur, x_ema = trainer.train_step(iteration, gt, incomplete_feat, gt_batch)

            loss = loss_function(x_cur, x_ema, gt_batch)
            # Named `metric` rather than `eval` to avoid shadowing the builtin.
            metric = eval_function(x_cur, gt, gt_batch)

            total_loss += loss.item()
            total_metric += metric.item()
    finally:
        # Leave the encoder in whatever mode the caller had it in.
        encoder.train(encoder_was_training)

    return total_loss / num_batches, total_metric / num_batches


@torch.no_grad()
def sample(epoch: int):
    """Reconstruct one test batch in a single step and return two figures.

    The starting point is the complete cloud perturbed by the resampler at
    ``settings.max_time``; the model is then evaluated once at that time,
    conditioned on features of the incomplete cloud.

    Args:
        epoch: Unused; kept so existing callers keep working.

    Returns:
        ``(x_fig, fig)``: visualisations of the incomplete input and of the
        reconstruction.
    """
    model.eval()
    # Evaluate the encoder in eval mode too, then restore its previous mode.
    encoder_was_training = encoder.training
    encoder.eval()

    try:
        data: MyDataBatched = next(iter(test_loader))
        data = data.to(device) # type: ignore

        gt_, incomplete = data.pos, data.incomplete
        gt_batch_, incomplete_batch = data.pos_batch, data.incomplete_batch

        # One time value (max_time) per example in the batch.
        ones_mask = torch.ones(int(gt_batch_.max() + 1), device=device)
        time = settings.max_time * ones_mask
        random_base = resampler(gt_, time, gt_batch_)

        incomplete_feat = encoder(incomplete, incomplete_batch)
        x_recon = model(random_base, time, incomplete_feat, gt_batch_)
    finally:
        encoder.train(encoder_was_training)

    x_fig = visualize_batch_points(incomplete, incomplete_batch)
    fig = visualize_batch_points(x_recon, gt_batch_)
    return x_fig, fig


def training_loop(start=0, log_every: int = 1, eval_every: int = 1000):
    """Run training from iteration ``start`` up to ``settings.training_iterations``.

    Every ``log_every`` iterations the loss is printed and logged. Every
    ``eval_every`` iterations (except iteration 0) sample figures are logged,
    validation is run and a checkpoint is written.

    A single failing iteration is tolerated: the error is reported with its
    traceback and training continues. Two consecutive failures stop the loop.

    Args:
        start: First iteration index.
        log_every: Interval for loss logging (default matches the previous
            hard-coded value of 1).
        eval_every: Interval for sampling, validation and checkpointing
            (default matches the previous hard-coded value of 1000).
    """
    import traceback

    data_iterator = get_data_iterator(train_loader)
    failed_last = False
    for epoch in range(start, settings.training_iterations + 1):
        try:
            loss = train(epoch, next(data_iterator))
            if epoch % log_every == 0:
                print(f"It: {epoch}, Loss: {loss}")
                writer.add_scalar("loss", loss, epoch)

            if epoch and epoch % eval_every == 0:
                # Qualitative samples.
                x_fig, fig = sample(epoch)
                writer.add_figure("x_fig", x_fig, epoch)
                writer.add_figure("x_recon", fig, epoch)
                x_fig.clear()
                fig.clear()

                # Quantitative validation.
                val_loss, val_cd = validate(epoch)
                print(f"Val It: {epoch}, Loss: {val_loss}, CD: {val_cd}")
                writer.add_scalar("val_loss", val_loss, epoch)
                writer.add_scalar("val_cd", val_cd, epoch)

                # Checkpoint with everything needed to resume.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'encoder_state_dict': encoder.state_dict(),
                    'ema_state_dict': ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(ckpt_dir, f"checkpoint_{epoch}.pth"))

            failed_last = False
        except Exception as e:
            # Report which iteration failed and where, not just the message.
            print(f"It: {epoch} failed with {type(e).__name__}: {e}")
            traceback.print_exc()
            if failed_last:
                break
            failed_last = True



In [32]:
# Set `state_file` to a checkpoint written by `training_loop` to resume from it.
state_file = None
if state_file is not None:
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this session uses (e.g. a CPU-only machine).
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state = torch.load(state_file, map_location=device)
    start = state['epoch'] + 1
    encoder.load_state_dict(state['encoder_state_dict'])
    model.load_state_dict(state['model_state_dict'])
    ema.load_state_dict(state['ema_state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print(f"Loaded checkpoint from {state_file}")
else:
    start = 0
    # Fresh run: make sure the EMA weights start identical to the online model.
    ema.load_state_dict(model.state_dict())

training_loop(start=start)

torch.Size([8, 512])
It: 0, Loss: 0.2138255089521408
torch.Size([8, 512])
It: 1, Loss: 0.3212275505065918
torch.Size([8, 512])
It: 2, Loss: 0.186448872089386
torch.Size([8, 512])
It: 3, Loss: 0.16066531836986542
torch.Size([8, 512])
It: 4, Loss: 0.23308484256267548
torch.Size([8, 512])
It: 5, Loss: 0.21543200314044952
torch.Size([8, 512])
It: 6, Loss: 0.30318760871887207
torch.Size([8, 512])
It: 7, Loss: 0.33970481157302856
torch.Size([8, 512])
It: 8, Loss: 0.3476470708847046
torch.Size([8, 512])
It: 9, Loss: 0.1481168568134308
torch.Size([8, 512])
It: 10, Loss: 0.276712030172348
torch.Size([8, 512])
It: 11, Loss: 0.27033692598342896
torch.Size([8, 512])
It: 12, Loss: 0.10079048573970795
torch.Size([8, 512])
It: 13, Loss: 0.15949709713459015
torch.Size([8, 512])
It: 14, Loss: 0.4224643409252167
torch.Size([8, 512])
It: 15, Loss: 0.13747812807559967
torch.Size([8, 512])
It: 16, Loss: 0.27220574021339417
torch.Size([8, 512])
It: 17, Loss: 0.25022849440574646
torch.Size([8, 512])
It: 18, 

KeyboardInterrupt: 