In [1]:
# Reload all imported modules (including edited text3d2video sources) before each cell executes.
%load_ext autoreload
%autoreload 2

from text3d2video.artifacts.diffusion_data import DiffusionDataCfg

In [2]:
from pathlib import Path
import shutil

# Start every run from a fresh, empty data.h5 in the working directory.
# Path.unlink is the correct call for a regular file: shutil.rmtree raises
# NotADirectoryError on a file, and ignore_errors=True used to swallow that,
# so a stale data.h5 from a previous run was silently kept.
h5_file_path = Path('data.h5')
h5_file_path.unlink(missing_ok=True)
h5_file_path.touch(exist_ok=True)

In [3]:
from text3d2video.artifacts.diffusion_data import DiffusionData

# Recording configuration handed to DiffusionData, which writes to h5_file_path.
# NOTE(review): field meanings are read off the names (number of steps/frames to
# save, attention layers to capture) — confirm against DiffusionDataCfg.
cfg = DiffusionDataCfg(
    attn_paths=["layer1"],
    n_save_frames=5,
    n_save_steps=3,
    enabled=True,
)
diffusion_data = DiffusionData(cfg, h5_file_path)

In [4]:
from diffusers import DDIMScheduler

n_frames = 10

# DDIM scheduler configured with 10 inference timesteps.
scheduler = DDIMScheduler()
scheduler.set_timesteps(10)

# Let diffusion_data pick which frames and which scheduler steps it will record.
# NOTE(review): presumably these use n_save_frames / n_save_steps from cfg — confirm.
diffusion_data.calculate_save_frames(n_frames)
diffusion_data.calculate_save_steps(scheduler)

In [5]:
from text3d2video.artifacts.diffusion_data import LatentsWriter
import torch

# NOTE(review): end-then-begin presumably discards any recording left open by an
# earlier cell/run so this cell can be re-executed — confirm in DiffusionData.
diffusion_data.end_recording()
diffusion_data.begin_recording()

latents_writer = LatentsWriter(diffusion_data)

# Seed so the random test data (and any failure) is reproducible.
torch.manual_seed(0)

# Batch dimension tracks n_frames from the setup cell instead of a duplicated
# literal (assumes the batch axis is the frame axis — TODO confirm).
# (4, 64, 64) per item is arbitrary test data.
latents = torch.randn(n_frames, 4, 64, 64)

latents_writer.write_latents_batched(0, latents)
latents_read = latents_writer.read_latent(0, 0)

# torch.testing.assert_close does not broadcast (torch.allclose does) and also
# checks dtype/device, so a read-back with the wrong shape cannot pass silently.
torch.testing.assert_close(latents_read, latents[0])

  return Tensor(dset)


In [6]:
from text3d2video.artifacts.diffusion_data import AttnFeaturesWriter

# NOTE(review): end-then-begin presumably discards any recording left open by an
# earlier cell/run so this cell can be re-executed — confirm in DiffusionData.
diffusion_data.end_recording()
diffusion_data.begin_recording()

attn_data_writer = AttnFeaturesWriter(diffusion_data)

# The single attention path listed in cfg.attn_paths.
attn_layer = 'layer1'

# Seed so the random test data (and any failure) is reproducible.
torch.manual_seed(0)

# Batch dimension tracks n_frames from the setup cell instead of a duplicated
# literal (assumes the batch axis is the frame axis — TODO confirm).
# (1000, 200) per item is arbitrary test data.
qry = torch.randn(n_frames, 1000, 200)
key = torch.randn(n_frames, 1000, 200)
value = torch.randn(n_frames, 1000, 200)

attn_data_writer.write_qkv_batched(0, attn_layer, qry, key, value)

qry_read = attn_data_writer.read_qry(0, 0, attn_layer)
key_read = attn_data_writer.read_key(0, 0, attn_layer)
value_read = attn_data_writer.read_val(0, 0, attn_layer)

# Strict comparison: no broadcasting, dtype/device checked, informative diff on failure.
torch.testing.assert_close(qry_read, qry[0])
torch.testing.assert_close(key_read, key[0])
torch.testing.assert_close(value_read, value[0])

In [9]:
from text3d2video.artifacts.gr_data import GrDataWriter


gr_writer = GrDataWriter(diffusion_data)

# One random (100, 200) tensor per layer name, keyed by that name.
# NOTE(review): layer2/layer3 are not in cfg.attn_paths, and unlike the cells
# above nothing is read back and checked here — confirm both are intended.
layers = ["layer1", "layer2", "layer3"]
vert_features = dict(zip(layers, (torch.randn(100, 200) for _ in layers)))

gr_writer.write_vertex_features(0, vert_features)