In [None]:
from dataclasses import dataclass

@dataclass
class DiffusionConfig:
    """Hyper-parameters shared by the training and sampling cells of this notebook.

    Fields are declared in a fixed order, which is also the positional order of
    the generated ``__init__``; do not reorder them.
    """

    # Graphs per mini-batch for the training DataLoader.
    batch_size: int = 32
    # Passes over the dataset; also sizes the cosine LR schedule.
    num_epochs: int = 100
    # Presumably read by model.train.train_diffusion to decide how often to
    # sample / checkpoint — confirm against that module.
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    # Peak AdamW learning rate (reached after the warm-up below).
    learning_rate: float = 5e-5
    # Linear warm-up length, in optimizer steps.
    num_warmup_steps: int = 400
    # NOTE(review): not read by any visible cell; presumably used by train_diffusion.
    push_to_hub: bool = False
    output_dir: str = "output/"
    # Length of the forward-noising chain used by DDPMScheduler.
    num_train_timesteps: int = 1_000


# Single shared configuration instance used by every later cell.
config = DiffusionConfig()

In [None]:
import torch_geometric
# `torch_geometric.data.DataLoader` is a deprecated alias (removed in newer PyG
# releases); the loader lives in `torch_geometric.loader`.
from torch_geometric.loader import DataLoader

# PyG's QM9 dataset, stored under ./data/ (PyG downloads/processes it there if missing).
data = torch_geometric.datasets.QM9(root="./data/")
# NOTE(review): follow_batch=[""] names no graph attribute, so it appears to be a
# no-op; kept unchanged — TODO confirm it can be dropped.
loader = DataLoader(data, follow_batch=[""], batch_size=config.batch_size, shuffle=True)

In [None]:
from model.diffusionstep import DiffusionStep

# First positional argument (33) matches the per-node feature width that the
# sampling cell draws noise with; 256 is presumably the hidden size — confirm
# against model.diffusionstep.DiffusionStep.
diffstep = DiffusionStep(33, 256, n_heads=4, num_layers=6)

# Show the total number of scalar parameters, using `_` as a digit separator.
n_params = sum(param.numel() for param in diffstep.parameters())
f"num parameters: {n_params:_}"

In [None]:
# Launch TensorBoard inline so training curves can be watched live.
# NOTE(review): the log directory is hard-coded; presumably it matches where
# train_diffusion writes (config.output_dir is "output/") — confirm.
%load_ext tensorboard
%tensorboard --logdir ./output/logs/

In [None]:
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers import DDPMScheduler

# Forward-noising schedule shared with training.
noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps)

optimizer = torch.optim.AdamW(
  diffstep.parameters(),
  lr=config.learning_rate,
  weight_decay=0.01,
)

# One optimizer step per mini-batch, for every epoch: the cosine decay spans the
# whole run after the linear warm-up.
total_training_steps = len(loader) * config.num_epochs
lr_scheduler = get_cosine_schedule_with_warmup(
  optimizer=optimizer,
  num_warmup_steps=config.num_warmup_steps,
  num_training_steps=total_training_steps,
)

In [None]:
from model.train import train_diffusion

# Run the full training loop with the objects built in the previous cells.
train_diffusion(
  diffstep,
  config,
  loader,
  noise_scheduler=noise_scheduler,
  optimizer=optimizer,
  lr_scheduler=lr_scheduler,
)

In [None]:
from matplotlib import pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm
from model.utils import kabsch_torch_batched

# One graph per batch, so the sampling cell below works on a single molecule.
loader1 = DataLoader(
  data,
  batch_size=1,
  shuffle=True,
  follow_batch=[""],
)

def _center(points):
  """Return ``points`` translated so that each column has zero mean over rows."""
  return points - points.mean(dim=0, keepdim=True)


def _plot_alignment(sample_pos, reference_pos, title):
  """Replace the cell output with a scatter of the aligned sample vs. the reference.

  Only the first two columns are plotted. Unpacking all three columns into
  ``plt.scatter(*pts.T)`` would silently bind the third one to the marker-size
  argument ``s``.
  """
  sample_xy = sample_pos[:, :2].detach().cpu().numpy()
  reference_xy = reference_pos[:, :2].detach().cpu().numpy()

  fig, ax = plt.subplots()
  ax.scatter(sample_xy[:, 0], sample_xy[:, 1], label="sample (aligned)")
  ax.scatter(reference_xy[:, 0], reference_xy[:, 1], label="reference")
  ax.set(title=title, xlabel="x", ylabel="y")
  ax.legend()

  clear_output(wait=True)
  plt.show()
  # Avoid accumulating one open figure per step on non-inline backends.
  plt.close(fig)


@torch.no_grad()  # sampling only: without this the autograd graph grows every step
def evaluate(model, batch, device, num_inference_steps=500):
  """Sample node features for one graph by DDPM reverse diffusion, plotting each step.

  Starting from Gaussian noise, the model predicts the noise at each scheduler
  timestep and ``DDPMScheduler.step`` produces the previous sample. After every
  step, the first three feature columns are rigidly aligned to ``batch.pos``
  with ``kabsch_torch_batched``, and both point clouds are plotted.

  The model is temporarily switched to eval mode, and its previous mode is
  restored on exit.

  Args:
    model: denoising network, called as
      ``model(x, z, edge_index, gnn_time_step=1, diffusion_time=t)``.
    batch: PyG batch whose ``pos``, ``z`` and ``edge_index`` are read; expected
      to hold a single graph (e.g. from ``loader1``).
    device: device to sample on (string or ``torch.device``).
    num_inference_steps: number of timesteps passed to ``set_timesteps``.

  Returns:
    ``(image, batch.edge_index, batch.pos)``, where ``image`` is the final
    sample of shape ``(num_nodes, 33)`` on ``device``.
  """
  noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timesteps)
  noise_scheduler.set_timesteps(num_inference_steps)

  num_nodes = batch.pos.shape[0]
  # Loop-invariant inputs: move them to the device once, not on every step.
  z = batch.z.to(device)
  edge_index = batch.edge_index.to(device)
  reference_pos = batch.pos.to(device)
  reference_centered = _center(batch.pos)

  # Independent noise per node. The previous `randn((1, 33)).expand(N, -1)`
  # gave every node the identical starting vector.
  image = torch.randn((num_nodes, 33), device=device)

  was_training = model.training
  model.eval()
  try:
    # NOTE(review): `[:-1]` skips the final (t=0) denoising step, so the returned
    # sample is not fully denoised. Kept as in the original — confirm intent.
    for t in noise_scheduler.timesteps[:-1]:
      timestep = torch.full((num_nodes,), int(t), dtype=torch.long, device=device)
      model_output = model(image, z, edge_index, gnn_time_step=1, diffusion_time=timestep)

      # x_t -> x_{t-1}
      image = noise_scheduler.step(model_output, t, image).prev_sample

      # Rotation only; the translation is discarded because both clouds are
      # centered before plotting.
      # assumes R has shape (1, 3, 3) — TODO confirm against model.utils.
      R, _ = kabsch_torch_batched(image[None, :, :3], reference_pos[None, :, :3])
      aligned = _center(image[:, :3]) @ R.squeeze().T

      _plot_alignment(aligned, reference_centered, title=f"t = {int(t)}")
  finally:
    model.train(was_training)

  return image, batch.edge_index, batch.pos

# Sample on whichever device the trained model's parameters live on, rather than
# hard-coding "cuda" (which fails on CPU-only machines).
sample_device = next(diffstep.parameters()).device
im = evaluate(diffstep, next(iter(loader1)), sample_device)

