In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch import Tensor



In [2]:
@dataclass
class ConsistencySettings:
    """Hyper-parameters for consistency training.

    Field order is part of the interface: instances may be built positionally.
    """

    # Number of training iterations; also the horizon of the step and EMA schedules.
    training_iterations: int = 100000
    # Lower / upper bounds handed to NumTimestepsSchedule.
    min_time_partitions: int = 1
    max_time_partitions: int = 150
    # Passed to ExponentialDecay as ``initial_decay``.
    initial_ema_decay: float = 0.99
    # Smallest time of the Karras schedule; also the boundary time used by the
    # parametrization and the resampler.
    min_time: float = 0.0001
    # Scale constant used by EpsilonParametrization's skip/out coefficients.
    data_time: float = 0.001
    # Largest time of the Karras schedule.
    max_time: float = 1.
    # Exponent of the Karras schedule.
    rho: float = 7.


In [3]:
from abc import ABC, abstractmethod
from loss.chamfer_loss import CDLoss
from loss.emd_loss import EMDLoss

class BaseDistanceFunc(ABC):
    def __init__(self) -> None:
        super().__init__()
        
    @abstractmethod
    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        ...
    

class ChamferDistance(BaseDistanceFunc):
    def __init__(self) -> None:
        super().__init__()
        self.loss = CDLoss()
        

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        return self.loss(lhs, rhs, batch)
    

class MSEDistance(BaseDistanceFunc):
    def __init__(self) -> None:
        super().__init__()
        self.loss = nn.MSELoss()

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        return self.loss(lhs, rhs)


class EarthMoverDistance(BaseDistanceFunc):
    def __init__(self) -> None:
        super().__init__()
        self.loss = EMDLoss()

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        return self.loss(lhs, rhs, batch)

In [4]:
import math


class BaseNumTimestepsSchedule(ABC):
    """Interface: map a training iteration to a number of discretisation steps."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, iteration: int) -> int:
        ...


class NumTimestepsSchedule(BaseNumTimestepsSchedule):
    """Square-root curriculum for the number of time steps.

    For iteration ``k``, horizon ``K`` and bounds ``(s0, s1)`` this returns
    ``ceil(sqrt(k / K * ((s1 + 1)**2 - s0**2) + s0**2) - 1) + 1``. This is the N(k)
    discretisation curriculum of consistency training (Song et al., 2023). It yields
    ``s0`` at ``k = 0`` and ``s1 + 1`` at ``k = K``, and keeps growing for ``k > K``.
    """

    def __init__(
        self,
        min_time_partitions: int,
        max_time_partitions: int,
        training_iterations: int
    ) -> None:
        super().__init__()
        self.target_disc_steps = (min_time_partitions, max_time_partitions)
        self.training_iterations = training_iterations

    def __call__(self, iteration: int) -> int:
        lo, hi = self.target_disc_steps
        horizon = self.training_iterations
        spread = (hi + 1) ** 2 - lo ** 2
        radicand = iteration * spread / horizon + lo ** 2
        return math.ceil(math.sqrt(radicand) - 1.0) + 1


In [5]:

class BaseTimeSchedule(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        ...
    
    
class KarrasTimeSchedule(BaseTimeSchedule):
    def __init__(self, min_time: float, max_time: float, rho: float) -> None:
        super().__init__()
        self.sigma_range = (min_time, max_time)
        self.rho = rho

    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        (eps, T), rho = self.sigma_range, self.rho
        rho_inv = 1.0 / rho
        steps = torch.arange(up_to, device=device) / max(up_to - 1, 1)
        sigmas = eps**rho_inv + steps * (T**rho_inv - eps**rho_inv)
        sigmas = sigmas**rho
        return sigmas


In [6]:

class BaseEMADecay(ABC):
    """Interface: map a training iteration to an EMA decay factor."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, iteration: int) -> float:
        ...


class ExponentialDecay(BaseEMADecay):
    """EMA decay ``initial_decay ** (iteration / training_iterations)``.

    The value is 1.0 at iteration 0 and ``initial_decay`` at ``training_iterations``.

    NOTE(review): despite its name, ``initial_decay`` is the value reached at the
    *end* of training. Early iterations use a decay of about 1.0, so the EMA target
    barely moves at first. Song et al. (2023) use the opposite direction for
    consistency training: ``mu(k) = exp(s0 * log(mu0) / N(k))`` starts at ``mu0``
    and rises toward 1. Confirm this is intended.
    """

    def __init__(self, initial_decay: float, training_iterations: int) -> None:
        super().__init__()
        self.initial_decay = initial_decay
        self.training_iterations = training_iterations

    def __call__(self, iteration: int) -> float:
        log_decay = math.log(self.initial_decay)
        return math.exp(iteration * log_decay / self.training_iterations)

In [7]:
class BaseParametrization(ABC):
    """Interface: combine the input ``x`` and raw network output ``y`` at time ``t``."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        ...


class BaseResampler(ABC):
    """Interface: re-noise an estimate ``x`` to time ``t`` for multistep evaluation."""

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        ...


class EpsilonParametrization(BaseParametrization):
    """Skip/out parametrization ``skip(t) * x + out(t) * y``.

    ``skip(t) = d**2 / ((t - eps)**2 + d**2)`` and
    ``out(t) = (t - eps) * d / sqrt(d**2 + t**2)``, with ``eps = min_time`` and
    ``d = data_time``. At ``t == min_time`` skip is 1 and out is 0, so the output
    equals ``x`` there.

    ``t`` holds one time per sample. It is gathered per point through ``batch``.
    """

    def __init__(self, min_time: float, data_time: float) -> None:
        super().__init__()
        self.data_time = data_time
        self.min_time = min_time

    def skip(self, t: Tensor) -> Tensor:
        data_sq = self.data_time ** 2
        shifted = t - self.min_time
        return data_sq / (shifted ** 2 + data_sq)

    def out(self, t: Tensor) -> Tensor:
        norm = (self.data_time ** 2 + t ** 2) ** 0.5
        return (t - self.min_time) * self.data_time / norm

    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        c_skip = self.skip(t)[batch, None]
        c_out = self.out(t)[batch, None]
        return c_skip * x + c_out * y


class EpsilonResampler(BaseResampler):
    """Add Gaussian noise with std ``sqrt(t**2 - min_time**2)`` (per sample, via ``batch``)."""

    def __init__(self, min_time: float) -> None:
        super().__init__()
        self.min_time = min_time

    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        scale = (t ** 2 - self.min_time ** 2) ** 0.5
        noise = torch.randn_like(x)
        return x + scale[batch, None] * noise

In [8]:
import copy


class BaseConditionedModel(nn.Module, ABC):
    """Abstract ``nn.Module`` mapping ``(x, t, ctx, batch)`` to a tensor.

    Subclasses implement ``forward``. ``__call__`` forwards every argument by keyword
    to ``nn.Module.__call__``, so hooks still run and ``forward`` receives named arguments.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, x: Tensor, t: Tensor, ctx: torch.Tensor, batch: Tensor) -> Tensor:
        ...

    def __call__(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        keyword_args = dict(x=x, t=t, ctx=ctx, batch=batch)
        return super().__call__(**keyword_args)
    


In [9]:
class ConsistencyModel(nn.Module):
    """Consistency function ``f(x, t) = parametrization(x, model(x, t, ctx), t)``.

    ``forward`` accepts either a single time tensor, for one evaluation, or a tuple
    of time tensors, for multistep evaluation. After the first evaluation, each
    further time re-noises the current estimate with ``resampler`` and evaluates
    again at that time.
    """

    def __init__(self, model: BaseConditionedModel, resampler: BaseResampler, parametrization: BaseParametrization) -> None:
        super().__init__()
        self.model = model
        self.resampler = resampler
        self.parametrization = parametrization

    def forward(self, x: Tensor, t: Tensor | tuple[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        assert not isinstance(t, int)
        schedule = t if not isinstance(t, Tensor) else (t,)

        estimate = self.wrapped_model(x=x, t=schedule[0], ctx=ctx, batch=batch)
        for step_time in schedule[1:]:
            renoised = self.resampler(x=estimate, t=step_time, batch=batch)
            estimate = self.wrapped_model(x=renoised, t=step_time, ctx=ctx, batch=batch)
        return estimate

    def wrapped_model(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        """Run the raw network once and apply the output parametrization."""
        raw_output = self.model(x=x, t=t, ctx=ctx, batch=batch)
        return self.parametrization(x=x, y=raw_output, t=t, batch=batch)

    def __call__(self, x: Tensor, t: Tensor | tuple[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        return super().__call__(x=x, t=t, ctx=ctx, batch=batch)


class ConsistencyTrainer(nn.Module):
    """Holds the online consistency model and its EMA target, plus the schedules
    that drive consistency training.

    ``model`` and ``ema`` are registered as submodules.
    """

    def __init__(self,
                 model: ConsistencyModel,
                 ema: ConsistencyModel,
                 step_schedule: BaseNumTimestepsSchedule,
                 time_schedule: BaseTimeSchedule,
                 ema_decay: BaseEMADecay):
        super().__init__()
        self.model = model
        self.ema = ema

        self.step_schedule = step_schedule
        self.time_schedule = time_schedule
        self.ema_decay = ema_decay

    def train_step(self, iteration: int, x: Tensor, ctx: Tensor, batch: Tensor) -> tuple[Tensor, Tensor]:
        """Build one consistency-training pair.

        Args:
            iteration: training iteration fed to ``step_schedule``.
            x: input points to be noised (``data.pos`` in the training cell).
            ctx: conditioning passed unchanged to both networks.
            batch: per-point sample index; ``batch.max() + 1`` samples are assumed.

        Returns:
            ``(denoised, ema_denoised)``. For one random grid index ``i`` per sample:
            ``denoised`` is the online model's output on ``x`` noised to ``times[i + 1]``
            (with gradients). ``ema_denoised`` is the EMA model's output on ``x`` noised
            to ``times[i]`` (without gradients).
        """
        # A pair needs two adjacent grid points. With fewer than 2, torch.randint below
        # would get an empty range and times[i + 1] would be out of bounds.
        num_timesteps = max(self.step_schedule(iteration), 2)
        times = self.time_schedule(num_timesteps, device=x.device)

        batch_size = int(batch.max()) + 1
        # randint's upper bound is exclusive, so i lies in [0, num_timesteps - 2] and i + 1 is valid.
        time_indices = torch.randint(0, num_timesteps - 1, (batch_size,), device=x.device)
        current_times = times[time_indices]
        next_times = times[time_indices + 1]

        # The same noise draw z is used at both times, so both inputs share one noising path.
        z = torch.randn_like(x)
        next_x = x + z * next_times[batch, None]
        denoised = self.model(
            x=next_x,
            t=next_times,
            ctx=ctx,
            batch=batch,
        )

        with torch.no_grad():
            current_x = x + z * current_times[batch, None]
            ema_denoised = self.ema(
                x=current_x,
                t=current_times,
                ctx=ctx,
                batch=batch,
            )

        return denoised, ema_denoised

    @torch.no_grad()
    def update_emas(self, iteration: int) -> None:
        """Move every EMA parameter toward its online counterpart.

        Applies ``ema = alpha * ema + (1 - alpha) * param`` with
        ``alpha = ema_decay(iteration)``. The update is in place, so the EMA tensors
        keep their storage instead of being rebound to fresh allocations. Parameters
        are paired positionally, so ``ema`` must list its parameters in the same order
        as ``model`` (e.g. a deepcopy). Buffers are not updated.
        """
        alpha = self.ema_decay(iteration)
        for p, ema_p in zip(self.model.parameters(), self.ema.parameters()):
            ema_p.mul_(alpha).add_(p, alpha=1.0 - alpha)



In [10]:
from models.backbone.pointnet import PointNetEncoder
from models.backbone.glu import GLUDecoder


class GLUConditionedModel(BaseConditionedModel):
    """``BaseConditionedModel`` adapter around the project's ``GLUDecoder``.

    Args:
        dim_ctx: forwarded as ``GLUDecoder(dim_ctx=dim_ctx)``.
    """

    def __init__(self, dim_ctx) -> None:
        super().__init__()
        self.model = GLUDecoder(dim_ctx=dim_ctx)

    def forward(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        decoder = self.model
        return decoder(x=x, t=t, ctx=ctx, batch=batch)
    

In [11]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeScale, FixedPoints, Compose, NormalizeRotation

pre_transform = NormalizeScale()
transform = FixedPoints(1024)

root = "data/ShapeNet"


def _airplane_split(split: str) -> ShapeNet:
    """Load one split of the ShapeNet "Airplane" category with the transforms above."""
    return ShapeNet(
        root=root,
        categories=["Airplane"],
        pre_transform=pre_transform,
        transform=transform,
        split=split,
    )


train_dataset = _airplane_split("train")
val_dataset = _airplane_split("val")
test_dataset = _airplane_split("test")

# Only the training loader shuffles; validation/test batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [12]:
settings = ConsistencySettings(
    training_iterations=100_000,
    min_time_partitions=2,
    max_time_partitions=150,
    initial_ema_decay=0.95,
    min_time=1e-4,
    data_time=1e-3,
    max_time=1.0,
    rho=7.0,
)

# Curriculum for how many discretisation steps to use, and the times of those steps.
step_schedule = NumTimestepsSchedule(
    min_time_partitions=settings.min_time_partitions,
    max_time_partitions=settings.max_time_partitions,
    training_iterations=settings.training_iterations,
)
time_schedule = KarrasTimeSchedule(
    min_time=settings.min_time,
    max_time=settings.max_time,
    rho=settings.rho,
)

# Decay factor of the EMA/target network as a function of the iteration.
ema_decay = ExponentialDecay(
    initial_decay=settings.initial_ema_decay,
    training_iterations=settings.training_iterations,
)

# The output parametrization and the multistep re-noising share the same minimum time.
parametrization = EpsilonParametrization(
    min_time=settings.min_time,
    data_time=settings.data_time,
)
resampler = EpsilonResampler(min_time=settings.min_time)

loss_function = ChamferDistance()


In [None]:
from common.training import get_data_iterator
from common.visualization import visualize_batch_results
from torch.utils.tensorboard.writer import SummaryWriter
import socket
import os
from datetime import datetime


current_time = datetime.now().strftime("%b%d_%H-%M-%S")
experiment_name = current_time + "_" + socket.gethostname() + "_consistency"
log_dir = os.path.join("runs", experiment_name)
ckpt_dir = os.path.join("checkpoints", experiment_name)
os.makedirs(ckpt_dir, exist_ok=True)

# Seed torch's global RNG before any weights are initialised or noise/shuffles are
# drawn, so that Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)

writer = SummaryWriter(log_dir=log_dir)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Context width: passed to the encoder and as GLUConditionedModel's dim_ctx, so the
# two must agree (presumably the encoder's output size; verify PointNetEncoder's signature).
CTX_DIM = 256

encoder = PointNetEncoder(CTX_DIM).to(device)
model = ConsistencyModel(
    model=GLUConditionedModel(CTX_DIM),
    resampler=resampler,
    parametrization=parametrization
).to(device)

# EMA / target network. deepcopy already duplicates parameters, buffers and their
# device, so no extra load_state_dict or .to(device) is needed. Its weights are only
# changed by trainer.update_emas (the optimizer below covers model and encoder only),
# so autograd tracking is disabled.
ema = copy.deepcopy(model)
ema.eval()
ema.requires_grad_(False)

trainer = ConsistencyTrainer(model, ema, step_schedule, time_schedule, ema_decay)

optimizer = torch.optim.Adam(list(model.parameters()) + list(encoder.parameters()), lr=1e-3)

# state_file = 'checkpoints/Oct29_22-53-03_fedora_consistency/checkpoint_13000.pth'
# if state_file is not None:
#     state = torch.load(state_file)
#     model.load_state_dict(state['model_state_dict'])
#     encoder.load_state_dict(state['encoder_state_dict'])
#     ema.load_state_dict(state['ema_state_dict'])
#     optimizer.load_state_dict(state['optimizer'])
#     start_epoch = state['epoch']
#     print(f"Loaded checkpoint from {state_file}")


def train(epoch, data):
    """Run one optimisation step on a training batch and return its loss as a float.

    Args:
        epoch: 0-based iteration index from ``training_loop``. It is shifted to
            1-based before being passed to the trainer's schedules.
        data: a PyG batch exposing ``pos`` and ``batch``.

    Updates ``model`` and ``encoder`` through ``optimizer``, then moves the EMA
    network toward ``model`` via ``trainer.update_emas``.
    """
    epoch = epoch + 1

    optimizer.zero_grad()  # Clear gradients.
    # Both trainable modules must be in training mode. Evaluation code may leave them
    # in eval mode, and the encoder was previously never switched back (this matters
    # if it contains BatchNorm/Dropout).
    model.train()
    encoder.train()

    data = data.to(device)
    x: Tensor = data.pos
    batch: Tensor = data.batch

    ctx = encoder(x, batch)
    x_cur, x_ema = trainer.train_step(epoch, x, ctx, batch)
    loss = loss_function(x_cur, x_ema, batch)

    loss.backward()  # Backward pass.
    optimizer.step()  # Update model parameters.
    trainer.update_emas(epoch)

    return loss.item()
    

@torch.no_grad()
def validate(epoch):
    """Average the consistency loss over the whole validation loader.

    Uses the same stochastic ``trainer.train_step`` as training (random times and
    noise), so the value varies between calls. Model and encoder are put in eval
    mode for the duration and restored to their previous modes afterward. Without
    this, the encoder would run, and update any BatchNorm statistics, in training mode.
    """
    epoch = epoch + 1

    model_was_training, encoder_was_training = model.training, encoder.training
    model.eval()
    encoder.eval()
    try:
        group_loss = 0.0
        for data in val_loader:
            data = data.to(device)
            x: Tensor = data.pos
            batch: Tensor = data.batch

            ctx = encoder(x, batch)
            x_cur, x_ema = trainer.train_step(epoch, x, ctx, batch)
            loss = loss_function(x_cur, x_ema, batch)
            group_loss += loss.item()
    finally:
        model.train(model_was_training)
        encoder.train(encoder_was_training)

    group_loss /= len(val_loader)
    return group_loss


@torch.no_grad()
def sample(epoch):
    """Return figures of the online and EMA outputs for the first test batch.

    Note: this does not generate from pure noise. It runs ``trainer.train_step`` on
    the first batch of ``test_loader`` (which is not shuffled). That is, it noises
    the test clouds at random schedule times and visualises what the online model
    and the EMA model return. Model and encoder are put in eval mode for the call
    and restored to their previous modes afterward.
    """
    epoch = epoch + 1

    model_was_training, encoder_was_training = model.training, encoder.training
    model.eval()
    encoder.eval()
    try:
        data = next(iter(test_loader))
        data = data.to(device)
        x: Tensor = data.pos
        batch: Tensor = data.batch

        ctx = encoder(x, batch)
        x_cur, x_ema = trainer.train_step(epoch, x, ctx, batch)
    finally:
        model.train(model_was_training)
        encoder.train(encoder_was_training)

    fig = visualize_batch_results(x_cur, batch)
    fig_ema = visualize_batch_results(x_ema, batch)
    return fig, fig_ema


def _save_checkpoint(epoch: int) -> str:
    """Save model, encoder, EMA and optimizer state for iteration ``epoch``; return the path."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'encoder_state_dict': encoder.state_dict(),
        'ema_state_dict': ema.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    path = os.path.join(ckpt_dir, f"checkpoint_{epoch}.pth")
    torch.save(state, path)
    return path


def training_loop(log_every: int = 100, val_every: int = 500, ckpt_every: int = 1000) -> None:
    """Run ``settings.training_iterations`` optimisation steps with periodic logging.

    Args:
        log_every: print and log the training loss every this many iterations.
        val_every: run ``validate`` and ``sample`` and log their results every this
            many iterations.
        ckpt_every: save a checkpoint every this many iterations. The final
            iteration is always checkpointed as well.

    All intervals must be positive. The loss is printed only on logging iterations
    rather than every step, to avoid flooding the notebook output.
    """
    data_iterator = get_data_iterator(train_loader)
    last_epoch = None
    try:
        for epoch in range(settings.training_iterations):
            loss = train(epoch, next(data_iterator))
            last_epoch = epoch

            if epoch % log_every == 0:
                print(f"It: {epoch}, Loss: {loss}")
                writer.add_scalar("loss", loss, epoch)

            if epoch % val_every == 0:
                val_loss = validate(epoch)
                fig, fig_ema = sample(epoch)

                print(f"Val It: {epoch}, Loss: {val_loss}")
                writer.add_scalar("val_loss", val_loss, epoch)
                writer.add_figure("test_fig", fig, epoch)
                writer.add_figure("test_fig_ema", fig_ema, epoch)

            if epoch % ckpt_every == 0:
                _save_checkpoint(epoch)

        # Without this, the updates made after the last periodic checkpoint would be lost.
        if last_epoch is not None and last_epoch % ckpt_every != 0:
            _save_checkpoint(last_epoch)
    finally:
        # Make sure pending TensorBoard events reach disk even if training is interrupted.
        writer.flush()


training_loop()