In [1]:
import wandb
from pathlib import Path

api = wandb.Api()
eval_run = api.run("dpfrommer-projects/image-diffusion-eval/igbrpifm")
run = api.run(eval_run.config["run"])
checkpoints = run.logged_artifacts()
iter_artifacts = {}
for artifact in checkpoints:
    if artifact.type != "model": continue
    iterations = artifact.metadata["step"]
    if iterations % 10000 == 0:
        iter_artifacts[iterations] = artifact
output = eval_run.logged_artifacts()[0]
print("Eval Artifact:", output.qualified_name)
output = Path(output.download())

Eval Artifact: dpfrommer-projects/image-diffusion-eval/evaluation:v10


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m:   29 of 29 files downloaded.  


In [2]:
from image_diffusion.main import logger
logger.setLevel('INFO')

In [3]:
import argon.util.serialize

path = Path(iter_artifacts[50_000].download()) / "checkpoint.zarr.zip"
checkpoint = argon.util.serialize.load_zarr(path)

[34m[1mwandb[0m: Downloading large artifact mnist-ddpm-050000:v2, 280.86MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.5


In [None]:
import zarr
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np

data = []
eval_results = {}
for file in output.iterdir():
  iteration = int(file.name.strip(".zarr.zip"))
  results = argon.util.serialize.load_zarr(file)
  eval_results[iteration] = results
  with zarr.open(file) as zf:
    for lin_error, nw_error, t, cond in zip(results.lin_error.reshape(-1),
                                      results.nw_error.reshape(-1),
                                      results.ts.reshape(-1),
                                      results.cond[:,None,:].repeat(4, 1).reshape(-1, 2)):
      data.append({
          "iteration": iteration,
          "lin_error": lin_error,
          "nw_error": nw_error,
          "cond_x": cond[0],
          "cond_y": cond[1],
          "t": t
      }
  )
data = pd.DataFrame(data)
data.sort_values(by=["iteration"], inplace=True)
data

/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/150000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/200000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/140000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/280000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/000000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/080000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/240000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/210000.zarr.zip
/home/daniel/Documents/code/argon/projects/image-diffusion/notebooks/artifacts/evaluation:v10/250000.zarr.zip
/home/dani

Unnamed: 0,iteration,lin_error,nw_error,cond_x,cond_y,t
10238,0,5.1792636,45.935543,2.3100395,1.2500505,12
8864,0,1.136258,46.566856,-0.31650496,-1.0692525,23
8865,0,1.1977514,44.738194,-0.31650496,-1.0692525,27
8866,0,0.9248023,45.25426,-0.31650496,-1.0692525,3
8867,0,1.478752,41.89927,-0.31650496,-1.0692525,32
...,...,...,...,...,...,...
7516,280000,0.39940286,0.7710396,-0.07392883,-1.2733829,31
7517,280000,3.7639937,8.5941515,-0.07392883,-1.2733829,6
7518,280000,2.5289836,7.1196814,-0.07392883,-1.2733829,20
6144,280000,10.559726,10.680752,-1.5111828,2.375042,6


In [None]:
import matplotlib.pyplot as plt

for i in []

In [5]:
import functools
import jax
import argon.graphics
import argon.core as F
import argon.numpy as npx
from IPython.display import display
from argon.train import Image

schedule = checkpoint.schedule
vars = checkpoint.vars
model = checkpoint.config.create()

normalizer, train_data, test_data = checkpoint.create_data()
train_data = jax.vmap(normalizer.normalize)(train_data.as_pytree())
test_data = jax.vmap(normalizer.normalize)(test_data.as_pytree())

In [11]:
from image_diffusion.eval import KeypointModel

keypoint_vars = eval_results[100_000].alpha_vars
keypoints = eval_results[100_000].keypoints
keypoint_model = KeypointModel(len(keypoints))

In [15]:

sampling_cond = np.array([3.0, 0.0])

@functools.partial(jax.jit, static_argnums=(0, 3,))
def sample_trajs(denoiser, cond, rng_key, N):
    def sample(rng_key):
        sample, traj = schedule.sample(rng_key, denoiser, npx.zeros(test_data.data[0].shape), trajectory=True)
        outputs = jax.lax.map(lambda s: denoiser(None, s[0], s[1]), (traj, npx.arange(1, 1 + traj.shape[0])))
        return sample, traj, outputs
    samples, trajs, outputs = jax.lax.map(sample, argon.random.split(rng_key, N), batch_size=8)
    samples = (128*(samples+1)).astype(npx.uint8)
    return Image(argon.graphics.image_grid(samples)), trajs, outputs

@functools.partial(jax.jit, static_argnums=(2,))
def nn_sample(cond, rng_key, N):
    def denoiser(rng_key, x, t):
        return model.apply(vars, x, t - 1, cond=cond)
    return sample_trajs(denoiser, cond, rng_key, N)

@functools.partial(jax.jit, static_argnums=(2,))
def linear_sample(cond, rng_key, N):
    def denoiser(rng_key, x, t):
        alphas = keypoint_model.apply(keypoint_vars, cond, t)
        out_keypoints = F.vmap(lambda k: model.apply(vars, x, t-1, cond=k))(keypoints)
        interpolated = alphas[:, None, None, None] * out_keypoints
        interpolated = npx.sum(interpolated, axis=0)
        return interpolated
    return sample_trajs(denoiser, cond, rng_key, N)

nn_grid, nn_trajs, nn_outputs = nn_sample(sampling_cond, jax.random.key(42), 16)
lin_grid, lin_trajs, lin_outputs = linear_sample(sampling_cond, jax.random.key(42), 16)
display(nn_grid)
display(lin_grid)

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00p\x00\x00\x00p\x08\x02\x00\x00\x00…

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00p\x00\x00\x00p\x08\x02\x00\x00\x00…