In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch import Tensor



In [2]:
@dataclass
class ConsistencySettings:
    training_iterations: int = 100_000
    min_time_partitions: int = 1
    max_time_partitions: int = 150
    initial_ema_decay: float = 0.99
    min_time: float = 1e-4
    data_time: float = 1e-3
    max_time: float = 1.0
    rho: float = 7.0


In [3]:
from abc import ABC, abstractmethod
from loss.chamfer_loss import CDLoss
from geomloss import SamplesLoss
from torch_geometric.utils import to_dense_batch

class BaseDistanceFunc(ABC):
    def __init__(self) -> None:
        super().__init__()
        
    @abstractmethod
    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        ...
    

class ChamferDistance(BaseDistanceFunc):
    def __init__(self) -> None:
        super().__init__()
        self.loss = CDLoss()
        self.loss.eval()
        

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        return self.loss(lhs, rhs, batch)
    

class MSEDistance(BaseDistanceFunc):
    def __init__(self) -> None:
        super().__init__()
        self.loss = nn.MSELoss()

    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        return self.loss(lhs, rhs)



class SinkhornEMDistance(BaseDistanceFunc):
    def __init__(self):
        self.loss = SamplesLoss(loss="sinkhorn", p=1, blur=0.01)
        
    def __call__(self, lhs: Tensor, rhs: Tensor, batch: Tensor | None) -> Tensor:
        lhs, rhs = to_dense_batch(lhs, batch)[0], to_dense_batch(rhs, batch)[0]
        return self.loss(lhs, rhs).mean()

In [4]:
import math


class BaseNumTimestepsSchedule(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, iteration: int) -> int:
        ...
    

class NumTimestepsSchedule(BaseNumTimestepsSchedule):
    def __init__(
        self, 
        min_time_partitions: int,
        max_time_partitions: int,
        training_iterations: int
    ) -> None:
        super().__init__()
        self.target_disc_steps = (min_time_partitions, max_time_partitions)
        self.training_iterations = training_iterations

    def __call__(self, iteration: int) -> int:
        s, K = self.target_disc_steps, self.training_iterations
        num_timesteps = (s[1] + 1)**2 - s[0]**2
        num_timesteps = iteration * num_timesteps / K
        num_timesteps = num_timesteps + s[0]**2
        num_timesteps = math.sqrt(num_timesteps)
        num_timesteps = math.ceil(-1. + num_timesteps)
        return 1 + num_timesteps


In [5]:

class BaseTimeSchedule(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        ...
    
    
class KarrasTimeSchedule(BaseTimeSchedule):
    def __init__(self, min_time: float, max_time: float, rho: float) -> None:
        super().__init__()
        self.sigma_range = (min_time, max_time)
        self.rho = rho

    def __call__(self, up_to: int, device: torch.device | None=None) -> torch.Tensor:
        (eps, T), rho = self.sigma_range, self.rho
        rho_inv = 1.0 / rho
        steps = torch.arange(up_to, device=device) / max(up_to - 1, 1)
        sigmas = eps**rho_inv + steps * (T**rho_inv - eps**rho_inv)
        sigmas = sigmas**rho
        return sigmas


In [6]:

class BaseEMADecay(ABC):
    @abstractmethod
    def __call__(self, iteration: int) -> float:
        ...

    def __init__(self) -> None:
        super().__init__()


class ExponentialDecay(BaseEMADecay):
    def __init__(self, initial_decay: float, training_iterations: int) -> None:
        super().__init__()
        self.initial_decay = initial_decay
        self.training_iterations = training_iterations

    def __call__(self, iteration: int) -> float:
        return math.exp(
            iteration * math.log(self.initial_decay) / self.training_iterations
        )

In [7]:
class BaseParametrization(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        ...
        

class BaseResampler(ABC):
    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        ...


class EpsilonParametrization(BaseParametrization):
    def __init__(self, min_time: float, data_time: float) -> None:
        super().__init__()
        self.data_time = data_time
        self.min_time = min_time

    def skip(self, t: Tensor) -> Tensor:
        return (self.data_time ** 2) / ((t - self.min_time) ** 2 + (self.data_time ** 2))

    def out(self, t: Tensor) -> Tensor:
        return (t - self.min_time) * self.data_time / (self.data_time**2 + t**2) ** 0.5    

    def __call__(self, x: Tensor, y: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        return self.skip(t)[batch, None] * x + self.out(t)[batch, None] * y 
    

class EpsilonResampler(BaseResampler):
    def __init__(self, min_time: float) -> None:
        super().__init__()
        self.min_time = min_time

    def __call__(self, x: Tensor, t: Tensor, batch: Tensor) -> Tensor:
        mul = (t**2 - self.min_time**2)**0.5
        return x + mul[batch, None] * torch.randn_like(x)

In [8]:
import copy


class BaseConditionedModel(nn.Module, ABC):
    @abstractmethod
    def forward(self, x: Tensor, t: Tensor, ctx: torch.Tensor, batch: Tensor) -> Tensor:
        ...

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        return super().__call__(x=x, t=t, ctx=ctx, batch=batch)
    


In [9]:
from typing import Sequence

class ConsistencyModel(nn.Module):
    def __init__(self, model: BaseConditionedModel, resampler: BaseResampler, parametrization: BaseParametrization) -> None:
        super().__init__()
        self.model = model
        self.resampler = resampler
        self.parametrization = parametrization

    def forward(self, x: Tensor, t: Tensor | Sequence[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        assert not isinstance(t, int)
        ts = (t,) if isinstance(t, Tensor) else t

        x = self.wrapped_model(x=x, t=ts[0], ctx=ctx, batch=batch)
        for t in ts[1:]:
            x = self.resampler(x=x, t=t, batch=batch)
            x = self.wrapped_model(x=x, t=t, ctx=ctx, batch=batch)
        return x

    def wrapped_model(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        y = self.model(x=x, t=t, ctx=ctx, batch=batch)
        return self.parametrization(x=x, y=y, t=t, batch=batch)

    def __call__(self, x: Tensor, t: Tensor | Sequence[Tensor], ctx: Tensor, batch: Tensor) -> Tensor:
        return super().__call__(x=x, t=t, ctx=ctx, batch=batch)


class ConsistencyTrainer(nn.Module):
    def __init__(self, 
                 model: ConsistencyModel, 
                 ema: ConsistencyModel,
                 step_schedule: BaseNumTimestepsSchedule, 
                 time_schedule: BaseTimeSchedule,
                 ema_decay: BaseEMADecay):
        super().__init__()
        self.model = model
        self.ema = ema

        self.step_schedule = step_schedule
        self.time_schedule = time_schedule
        self.ema_decay = ema_decay
        
    def get_times(self, iteration: int, batch_size: int, device: torch.device | None=None) -> tuple[torch.Tensor, torch.Tensor]:
        num_timesteps = self.step_schedule(iteration)
        times = self.time_schedule(num_timesteps, device=device)

        time_indices = torch.randint(0, num_timesteps - 1, (batch_size,), device=device)
        current_times = times[time_indices]
        next_times = times[time_indices + 1]

        return current_times, next_times

    def train_step(self, iteration: int, x: Tensor, ctx: Tensor, batch: Tensor):
        current_times, next_times = self.get_times(iteration, x.shape[0], device=x.device)

        z = torch.randn_like(x, device=x.device)
        next_x = x + z * next_times[batch, None]
        denoised = self.model(
            x=next_x, 
            t=next_times, 
            ctx=ctx,
            batch=batch, 
        )

        with torch.no_grad():
            current_x = x + z * current_times[batch, None]
            ema_denoised = self.ema(
                x=current_x, 
                t=current_times, 
                ctx=ctx,
                batch=batch, 
            )

        return denoised, ema_denoised 

    @torch.no_grad()
    def update_emas(self, iteration: int):
        alpha = self.ema_decay(iteration)
        for p, ema_p in zip(self.model.parameters(), self.ema.parameters()):
            ema_p.data = alpha * ema_p.data + (1 - alpha) * p.data



In [10]:

import torch
from torch_geometric.nn import knn

complete = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch_x = torch.tensor([0, 0, 0, 0])
y = torch.tensor([[-1.0, 0.0]])
assign_index = knn(complete, y, 2, batch_x)

assign_index

tensor([[0, 0],
        [0, 1]])

In [11]:
settings = ConsistencySettings(
    training_iterations=100_000,
    min_time_partitions=2,
    max_time_partitions=150,
    initial_ema_decay=0.95,
    min_time=1e-4,
    data_time=1e-3,
    max_time=1.0,
    rho=7.0
)

step_schedule = NumTimestepsSchedule(
    settings.min_time_partitions,
    settings.max_time_partitions,
    settings.training_iterations
)
time_schedule = KarrasTimeSchedule(
    settings.min_time,
    settings.max_time,
    settings.rho
)
ema_decay = ExponentialDecay(
    settings.initial_ema_decay,
    settings.training_iterations
)
parametrization = EpsilonParametrization(
    settings.min_time,
    settings.data_time
)
resampler = EpsilonResampler(
    settings.min_time
)
loss_function = ChamferDistance()
eval_function = ChamferDistance()


In [12]:
%load_ext autoreload
%autoreload 2

In [13]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeScale, FixedPoints, Compose
from common.transforms.knn_view import KNNSplit

pre_transform = NormalizeScale()
transform = Compose([FixedPoints(4096), KNNSplit(3072)])

root = "data/ShapeNetAll"
train_dataset = ShapeNet(root=root, categories=None, pre_transform=pre_transform, transform=transform, split="train")
val_dataset = ShapeNet(root=root, categories=None, pre_transform=pre_transform, transform=transform, split="val")
test_dataset = ShapeNet(root=root, categories=None, pre_transform=pre_transform, transform=transform, split="test")

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, follow_batch=["pos", "incomplete"], drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True, follow_batch=["pos", "incomplete"], drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True, follow_batch=["pos", "incomplete"], drop_last=True)


In [14]:
from models.backbone.att_dgcnn import AttDGCNNEncoder
from models.backbone.pointnet import PointNetEncoder
from models.backbone.glu import GLUDecoder


class LinearResBlock(nn.Module):
    def __init__(self, dim_in: int, dim_out: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            nn.Linear(dim_out, dim_out),
            nn.SiLU(),
        )
        self.residual = nn.Linear(dim_in, dim_out)


    def forward(self, x: Tensor) -> Tensor:
        return self.model(x) + self.residual(x)


class SharedEncoder(torch.nn.Module):
    def __init__(self, global_feat_size=512, shared_feat_size=128):
        super().__init__()
        self.encoder = PointNetEncoder(zdim=global_feat_size)
        self.complete_linear = LinearResBlock(global_feat_size, shared_feat_size)
        self.incomplete_linear = LinearResBlock(global_feat_size, shared_feat_size)
        self.code_linear = LinearResBlock(shared_feat_size, global_feat_size)

    def forward(
        self, 
        ctx_pos: Tensor, 
        ctx_batch: Tensor,  
        is_complete: bool = False
    ) -> Tensor:
        
        encoding = self.encoder(pos=ctx_pos, batch=ctx_batch)
        if is_complete:
            encoding_2 = self.complete_linear(encoding)
        else:
            encoding_2 = self.incomplete_linear(encoding)
        encoding_2 = self.code_linear(encoding_2)
        return encoding + encoding_2

    def __call__(
        self,
        ctx_pos: Tensor, 
        ctx_batch: Tensor,  
        is_complete: bool = False
    ) -> Tensor:
        return self.forward(ctx_pos, ctx_batch, is_complete)


class GLUConditionedModel(BaseConditionedModel):
    def __init__(self, dim_ctx) -> None:
        super().__init__()
        self.model = GLUDecoder(dim_ctx=dim_ctx)

    def forward(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor) -> Tensor:
        return self.model(x=x, t=t, ctx=ctx, batch=batch)


encoder = SharedEncoder(512)
model = ConsistencyModel(
    model=GLUConditionedModel(512),
    resampler=resampler,
    parametrization=parametrization,
)
ema = copy.deepcopy(model)
optimizer = torch.optim.AdamW(list(model.parameters()) + list(encoder.parameters()), lr=1e-3)

In [19]:
from common.training import get_data_iterator
# from common.visualization import visualize_batch_points
from common.data import MyDataBatched
from torch.utils.tensorboard.writer import SummaryWriter
import socket
import os
from datetime import datetime


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t_cur = datetime.now().strftime("%b%d_%H-%M-%S")
experiment_name = t_cur + "_" + socket.gethostname() + "_sim"
log_dir = os.path.join("runs/bruh", experiment_name)
ckpt_dir = os.path.join("checkpoints", experiment_name)
os.makedirs(ckpt_dir, exist_ok=True)


writer = SummaryWriter(log_dir=log_dir)
trainer = ConsistencyTrainer(model, ema, step_schedule, time_schedule, ema_decay)
encoder = encoder.to(device)
model = model.to(device)
ema = ema.to(device)



def train(epoch: int, data: MyDataBatched):
    epoch = epoch + 1

    optimizer.zero_grad()  # Clear gradients.
    model.train()
    data = data.to(device) # type: ignore

    complete, incomplete = data.pos, data.incomplete
    complete_batch, incomplete_batch = data.pos_batch, data.incomplete_batch
    
    complete_ctx = encoder(complete, complete_batch, is_complete=True)
    incomplete_ctx = encoder(incomplete, incomplete_batch, is_complete=False)

    batch_size = int(complete_batch.max()) + 1
    t_cur, t_next = trainer.get_times(epoch, batch_size, device=device)
    z = torch.randn_like(complete, device=device)
    
    xzc = complete + z * t_next[complete_batch, None]
    xc = model(x=xzc, t=t_next, ctx=complete_ctx, batch=complete_batch)
    xi = model(x=xzc, t=t_next, ctx=incomplete_ctx, batch=complete_batch)
    with torch.no_grad():
        xze = complete + z * t_cur[complete_batch, None]
        xe = ema(x=xze, t=t_cur, ctx=complete_ctx, batch=complete_batch)

    consistency_loss = loss_function(xc, xe, complete_batch)
    similarity_loss = 2 * loss_function(xc, xi, complete_batch)
    loss = consistency_loss + similarity_loss

    loss.backward()  # Backward pass.
    optimizer.step()  # Update model parameters.
    trainer.update_emas(epoch)  # Update EMA parameters.

    return consistency_loss.item(), similarity_loss.item()
    


@torch.no_grad()
def validate_one(epoch: int):
    epoch = epoch + 1
    model.eval()

    data: MyDataBatched = next(iter(test_loader))
    data = data.to(device) # type: ignore
    
    complete, incomplete = data.pos, data.incomplete
    complete_batch, incomplete_batch = data.pos_batch, data.incomplete_batch

    ones_mask = torch.ones(int(complete_batch.max() + 1), device=device)
    time = settings.max_time * ones_mask
    random_base = resampler(complete, time, complete_batch)

    complete_feat = encoder(complete, complete_batch, is_complete=True)
    incomplete_feat = encoder(incomplete, incomplete_batch, is_complete=False)
    complete_recon = model(random_base, time, complete_feat, complete_batch)
    incomplete_recon = model(random_base, time, incomplete_feat, complete_batch)
    
    complete_loss = eval_function(complete_recon, complete, complete_batch)
    incomplete_loss = eval_function(incomplete_recon, complete, complete_batch)

    return complete_loss, incomplete_loss


@torch.no_grad()
def validate(epoch):
    epoch = epoch + 1

    group_complete_loss, group_incomplete_loss = 0.0, 0.0
    model.eval()
    for data in val_loader:
        data: MyDataBatched = data.to(device) # type: ignore
        complete_loss, incomplete_loss = validate_one(epoch)
        group_complete_loss += complete_loss
        group_incomplete_loss += incomplete_loss

    group_complete_loss /= len(val_loader)
    group_incomplete_loss /= len(val_loader)

    return group_complete_loss, group_incomplete_loss




def training_loop(start=0):
    data_iterator = get_data_iterator(train_loader)
    for epoch in range(start, settings.training_iterations + 1):
        c_loss, s_loss = train(epoch, next(data_iterator))
        if epoch % 1 == 0:
            print(f"T. It: {epoch}, Consistency loss: {c_loss}, Similarity loss: {s_loss}")
            writer.add_scalar("consistency_loss", c_loss, epoch)
            writer.add_scalar("similarity_loss", s_loss, epoch)

        if epoch and epoch % 1 == 0:
            val_c, val_i = validate_one(epoch)
            print(f"V. It: {epoch}, Complete loss: {val_c}, Incomplete loss: {val_i}")
            writer.add_scalar("single_val_complete_loss", val_c, epoch)
            writer.add_scalar("single_val_incomplete_loss", val_i, epoch)

        # if epoch and epoch % 1000 == 0:
        #     val_c, val_i = validate(epoch)
        #     print(f"Full val. It: {epoch}, Complete loss: {val_c}, Incomplete loss: {val_i}")
        #     writer.add_scalar("val_complete_loss", val_c, epoch)
        #     writer.add_scalar("val_incomplete_loss", val_i, epoch)


        if epoch and epoch % 1000 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'encoder_state_dict': encoder.state_dict(),
                'ema_state_dict': ema.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpt_dir, f"checkpoint_{epoch}.pth"))
        
        


In [18]:
state_file = None
if state_file is not None:
    state = torch.load(state_file)
    start = state['epoch'] + 1
    encoder.load_state_dict(state['encoder_state_dict'])
    model.load_state_dict(state['model_state_dict'])
    ema.load_state_dict(state['ema_state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    print(f"Loaded checkpoint from {state_file}")
else:
    start = 0
    ema.load_state_dict(model.state_dict()) 

training_loop(start=start)

T. It: 0, Consistency loss: 0.18683573603630066, Similarity loss: 1.192650233861059e-05
T. It: 1, Consistency loss: 0.1540910303592682, Similarity loss: 1.531501766294241e-05
V. It: 1, Complete loss: 0.42739608883857727, Incomplete loss: 0.4273129999637604
T. It: 2, Consistency loss: 0.261183500289917, Similarity loss: 2.5237886802642606e-05
V. It: 2, Complete loss: 0.49902090430259705, Incomplete loss: 0.49916473031044006
T. It: 3, Consistency loss: 0.14014650881290436, Similarity loss: 5.18938250024803e-05
V. It: 3, Complete loss: 0.3523908257484436, Incomplete loss: 0.3526962101459503
T. It: 4, Consistency loss: 0.2651427984237671, Similarity loss: 4.9600639613345265e-05
V. It: 4, Complete loss: 0.48475709557533264, Incomplete loss: 0.4846714735031128
T. It: 5, Consistency loss: 0.19216829538345337, Similarity loss: 1.760788290994242e-05
V. It: 5, Complete loss: 0.38546887040138245, Incomplete loss: 0.3855758607387543
T. It: 6, Consistency loss: 0.04829737916588783, Similarity loss:

KeyboardInterrupt: 