In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import random 
import numpy as np


def get_data_iterator(iterable):
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()


def set_random_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True # type: ignore
        torch.backends.cudnn.benchmark = False # type: ignore



In [3]:
from src.models.consistency.model import BaseConditionedModel
from src.models.backbone import AttnPointNetEncoder
from src.models.backbone.unet import UNet
from torch import Tensor


class UNetConditionedModel(BaseConditionedModel):
    def __init__(self, dim_model, dim_ctx) -> None:
        super().__init__()
        self.model = UNet(in_channels=3, out_channels=3, dim_model=dim_model, dim_ctx=dim_ctx)

    def forward(self, x: Tensor, t: Tensor, ctx: Tensor, batch: Tensor, ctx_batch: Tensor) -> Tensor:
        return self.model(pos=x, batch=batch, par=ctx, par_batch=ctx_batch, t=t)
    



In [4]:
from src.models.consistency.settings import ConsistencySettings
from src.models.consistency.model import ConsistencyModel, EpsilonResampler, SkipParametrization
from src.models.consistency.schedule import KarrasTimeSchedule, NumTimestepsSchedule, TimeScheduler
from src.models.consistency.ema import EMAModel, ExponentialDecay
from src.loss import ChamferDistance, MSEDistance


set_random_seed(69420)


settings = ConsistencySettings(
    training_iterations=250_000,
    min_time_partitions=2,
    max_time_partitions=150,
    initial_ema_decay=0.95,
    min_time=1e-4,
    data_time=0.29,
    max_time=16.0,
    rho=7.0
)

sigma_schedule = TimeScheduler(
    step_schedule=NumTimestepsSchedule(
        settings.min_time_partitions,
        settings.max_time_partitions,
        settings.training_iterations
    ), 
    time_schedule=KarrasTimeSchedule(
        settings.min_time,
        settings.max_time,
        settings.rho
    )
)


model = ConsistencyModel(
    model=UNetConditionedModel(dim_model=[32, 64, 128, 256], dim_ctx=256),
    resampler = EpsilonResampler(
        settings.min_time
    ),
    parametrization=SkipParametrization(
        settings.min_time,
        settings.data_time
    ),
)

ema = EMAModel(
    model=model,
    decay=ExponentialDecay(
        settings.initial_ema_decay,
        settings.training_iterations
    )
)

mse = MSEDistance()
chamfer_distance = ChamferDistance()

optimizer = torch.optim.AdamW(list(model.parameters()), lr=1e-3)




In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
ema.model = ema.model.to(device)

In [6]:
from torch_geometric.datasets import ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeScale, FixedPoints, Compose
from src.transforms.knn_view import KNNSplit

pre_transform = NormalizeScale()
transform = Compose([FixedPoints(1024), KNNSplit(3 * 1024 // 4, attr="par")])

root = "data/ShapeNetAll"
categories = ['Airplane']
train_dataset = ShapeNet(root=root, categories=categories, pre_transform=pre_transform, transform=transform, split="train")
val_dataset = ShapeNet(root=root, categories=categories, pre_transform=pre_transform, transform=transform, split="val")
test_dataset = ShapeNet(root=root, categories=categories, pre_transform=pre_transform, transform=transform, split="test")

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, follow_batch=["pos", "par"], drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=True, follow_batch=["pos", "par"], drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True, follow_batch=["pos", "par"], drop_last=True)


In [7]:
from torch.utils.tensorboard.writer import SummaryWriter
import socket
import os
from datetime import datetime


t_cur = datetime.now().strftime("%b%d_%H-%M-%S")
experiment_name = t_cur + "_" + socket.gethostname() + "_sim"
log_dir = os.path.join("runs/bruh", experiment_name)
ckpt_dir = os.path.join("checkpoints", experiment_name)


os.makedirs(ckpt_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)


In [8]:
from src.data import MyDataBatched


def train(epoch: int, data: MyDataBatched) -> tuple[float, float, float]:
    epoch = epoch + 1

    optimizer.zero_grad()  # Clear gradients.
    model.train()
    data = data.to(device) # type: ignore

    pos, par = data.pos, data.par
    pos_batch, par_batch = data.pos_batch, data.par_batch

    batch_size = int(pos_batch.max()) + 1
    t_cur, t_next = sigma_schedule.get_times(epoch, batch_size, device=device)
    z = torch.randn_like(pos, device=device)
    
    x_comp = pos + z * t_next[pos_batch, None]
    x_comp = model(x=x_comp, t=t_next, ctx=par, batch=pos_batch, ctx_batch=par_batch)
    
    with torch.no_grad():
        x_ema = pos + z * t_cur[pos_batch, None]
        x_ema = ema.model(x=x_ema, t=t_cur, ctx=par, batch=pos_batch, ctx_batch=par_batch)

    loss = chamfer_distance(x_comp, pos, pos_batch)

    loss.backward()  # Backward pass.
    optimizer.step()  # Update model parameters.
    ema.update(model, epoch)

    return float(loss), float(0), float(0)
    


@torch.no_grad()
def validate_one(epoch: int):
    epoch = epoch + 1
    model.eval()

    data: MyDataBatched = next(iter(test_loader))
    data = data.to(device) # type: ignore
    
    pos, par = data.pos, data.par
    pos_batch, par_batch = data.pos_batch, data.par_batch

    batch_size = int(pos_batch.max()) + 1
    _, t_next = sigma_schedule.get_times(epoch, batch_size, device=device)

    random_base = model.resampler(pos, t_next, pos_batch)

    xcz = model(x=random_base, t=t_next, ctx=par, batch=pos_batch, ctx_batch=par_batch)
    
    complete_loss = chamfer_distance(xcz, pos, pos_batch)

    return complete_loss.item(), 0


from src.visualization import visualize_batch_points


@torch.no_grad()
def visualize_sample(epoch: int):
    model.eval()
    ema.model.eval()

    sample: MyDataBatched = next(iter(test_loader))
    sample = sample.to(device) # type: ignore

    pos, par = sample.pos, sample.par
    pos_batch, par_batch = sample.pos_batch, sample.par_batch

    batch_size = int(pos_batch.max()) + 1
    _, t_next = sigma_schedule.get_times(epoch, batch_size, device=device)

    n = pos.shape[0] // batch_size
    n_cur = par.shape[0] // batch_size

    par_pad_batch = torch.repeat_interleave(torch.arange(batch_size, device=device), n - n_cur)
    par_pad = torch.randn((batch_size * (n - n_cur), par.shape[1]), device=device) * settings.data_time
    par_padded = torch.cat([par, par_pad], dim=0)
    par_padded_batch = torch.cat([par_batch, par_pad_batch], dim=0)

    random_base = model.resampler(par_padded, t_next, par_padded_batch)

    complete_recon = model(x=random_base, t=t_next, ctx=par, batch=par_padded_batch, ctx_batch=par_batch)

    comp = visualize_batch_points(pos, pos_batch)
    comp_r = visualize_batch_points(complete_recon, par_padded_batch)

    return comp, comp_r, #incomp_r




def training_loop(start):
    data_iterator = get_data_iterator(train_loader)
    for epoch in range(start, settings.training_iterations + 1):
        c_loss, r_loss, s_loss = train(epoch, next(data_iterator))
        if epoch % 1 == 0:
            print(f"T. It: {epoch}, Consistency loss: {c_loss}, Similarity loss: {s_loss}, Reconstruction loss: {r_loss}")
            writer.add_scalar("consistency_loss", c_loss, epoch)
            writer.add_scalar("reconstruction_loss", r_loss, epoch)
            writer.add_scalar("similarity_loss", s_loss, epoch)

        if epoch and epoch % 1 == 0:
            val_c, val_i = validate_one(epoch)
            print(f"V. It: {epoch}, Complete CD: {val_c}, Incomplete CD: {val_i}")
            writer.add_scalar("single_val_complete_loss", val_c, epoch)
            writer.add_scalar("single_val_incomplete_loss", val_i, epoch)

        if epoch % 100 == 0:
            complete, complete_recon = visualize_sample(epoch)
            writer.add_figure("complete", complete, epoch)
            # writer.add_figure("incomplete", incomplete, epoch)
            writer.add_figure("complete_recon", complete_recon, epoch)
            # writer.add_figure("incomplete_recon", incomplete_recon, epoch)
            complete.clear(); complete_recon.clear(); #incomplete.clear(); incomplete_recon.clear()

        if epoch and epoch % 1000 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'ema_state_dict': ema.model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpt_dir, f"checkpoint_{epoch}.pth"))
        
        


In [9]:
import torch

set_random_seed(69420)
state_file = None
if state_file is not None:
    state = torch.load(state_file, map_location=device)
    start = state['epoch']
    model.load_state_dict(state['model_state_dict'])
    ema.model.load_state_dict(state['ema_state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    optimizer.param_groups[0]['capturable'] = True
    print(f"Loaded checkpoint from {state_file}")
else:
    start = 0



In [10]:
start = 0
training_loop(start=start)

T. It: 0, Consistency loss: 0.05816715210676193, Similarity loss: 0.0, Reconstruction loss: 0.0
T. It: 1, Consistency loss: 0.04365706443786621, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 1, Complete CD: 0.06025184690952301, Incomplete CD: 0
T. It: 2, Consistency loss: 0.042816437780857086, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 2, Complete CD: 0.06702850013971329, Incomplete CD: 0
T. It: 3, Consistency loss: 0.04537834972143173, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 3, Complete CD: 0.08011848479509354, Incomplete CD: 0
T. It: 4, Consistency loss: 0.028706707060337067, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 4, Complete CD: 0.08102038502693176, Incomplete CD: 0
T. It: 5, Consistency loss: 0.03282872587442398, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 5, Complete CD: 0.07021422684192657, Incomplete CD: 0
T. It: 6, Consistency loss: 0.023308608680963516, Similarity loss: 0.0, Reconstruction loss: 0.0
V. It: 6, Complete C

KeyboardInterrupt: 