In [None]:
# @title Download model and data
!mkdir -p cache
!curl -Lo "cache/cache_r0_vbp.h5" "https://pub-2fdef7a2969f43289c42ac5ae3412fd4.r2.dev/cache_r0_vbp.h5"
!curl -Lo "cache/cache_r1_vbp.h5" "https://pub-2fdef7a2969f43289c42ac5ae3412fd4.r2.dev/cache_r1_vbp.h5"
!pip install -q safetensors diffusers omegaconf accelerate

# restart runtime
exit()

In [20]:
# @title Functions

%matplotlib inline

import os
# Suppress the bitsandbytes welcome banner; set here, ahead of the remaining imports.
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

import h5py, random, torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from pathlib import Path  
from torchvision import transforms
from diffusers import AutoencoderKL

# Fetch the pretrained autoencoder, then prepare it for GPU inference:
# evaluation mode, moved to CUDA, gradients disabled on every parameter.
vae = AutoencoderKL.from_pretrained("nyanko7/sdxl-vae-0.9")
vae.eval()
vae.cuda()
vae.requires_grad_(False)

def denormalize(img, mean=0.5, std=0.5):
    """Undo a ``(x - mean) / std`` normalization and clamp the result to [0, 1].

    Args:
        img: Tensor holding normalized values. Any shape is accepted.
        mean: Mean that was subtracted by the forward normalization. A scalar,
            or any value that broadcasts against ``img``.
        std: Standard deviation that the forward normalization divided by.
            A scalar, or any value that broadcasts against ``img``.

    Returns:
        A new tensor equal to ``clamp(img * std + mean, 0, 1)``. ``img`` itself
        is left unmodified.
    """
    # Plain arithmetic rather than a torchvision ``Normalize`` built with
    # inverted statistics on every call: it is mathematically the same mapping,
    # avoids constructing a transform object per call, and places no
    # restriction on the rank of ``img``.
    restored = img * std + mean
    return torch.clamp(restored, 0, 1)

def create_vds_for_group(source_group, target_group, bar):
    """Mirror the entries of ``source_group`` into ``target_group`` as virtual datasets.

    Every entry whose name is not yet present in ``target_group`` gets a
    virtual dataset with the same shape and dtype, mapped onto the source
    entry. Names that already exist in ``target_group`` are left untouched.

    Args:
        source_group: Object whose ``items()`` yields ``(name, dataset)``
            pairs; each dataset must expose ``shape`` and ``dtype``.
        target_group: Object supporting ``in`` and ``create_virtual_dataset``.
        bar: Progress bar. ``update(1)`` is called once for every entry whose
            name ends with ``".latents"``, whether or not it was newly linked.
    """
    for name, dataset in source_group.items():
        if name not in target_group:
            # Map the whole virtual dataset onto the whole source dataset.
            vds_layout = h5py.VirtualLayout(shape=dataset.shape, dtype=dataset.dtype)
            vds_layout[:] = h5py.VirtualSource(dataset)
            target_group.create_virtual_dataset(name, vds_layout)
        if name.endswith(".latents"):
            bar.update(1)

# Load latents from the h5 file
def load_latents_from_h5(h5_path, hashsum=None):
    """Load one cached latent from the ``*.h5`` files found in a directory.

    Every ``*.h5`` file directly inside ``h5_path`` is linked into the index
    file ``cache_index.tmp`` (in the current working directory) through
    virtual datasets, and a single dataset is then read from that index.

    Args:
        h5_path: Directory holding the ``*.h5`` cache parts and, optionally, a
            ``dataset.json`` file. When that file exists, the entry stored
            under the chosen key minus its last 8 characters (the length of
            ``".latents"``) is printed.
        hashsum: Dataset key to load, used as-is. When ``None``, a key ending
            in ``".latents"`` is picked at random from the index.

    Returns:
        The dataset contents as a ``float32`` tensor on the CUDA device.

    Raises:
        IndexError: If ``hashsum`` is ``None`` and the index contains no key
            ending in ``".latents"``.
    """
    suffix = ".latents"
    # Sorted so that the processing order does not depend on the filesystem:
    # create_vds_for_group skips names that are already indexed, so the first
    # part containing a given key is the one that gets linked.
    cache_parts = sorted(Path(h5_path).glob("*.h5"))
    # NOTE(review): the index is opened in append mode under a fixed name, so
    # entries created by earlier calls (possibly for a different h5_path)
    # appear to remain available to the random choice below -- confirm whether
    # that reuse is intended.
    with h5py.File("cache_index.tmp", 'a', libver='latest', driver='core') as fo:  # using 'latest' for VDS support
        # The context manager closes the progress bar once indexing is done.
        with tqdm(desc="Creating index") as bar:
            for input_file in cache_parts:
                with h5py.File(input_file, 'r') as fi:
                    create_vds_for_group(fi, fo, bar)

    with h5py.File("cache_index.tmp", 'r') as f:
        chosen_key = hashsum
        if chosen_key is None:
            # Get all datasets keys that match the pattern
            keys = [key for key in f.keys() if key.endswith(suffix)]
            if not keys:
                # Same exception type random.choice would raise on an empty
                # list, but with a message that names the actual problem.
                raise IndexError(
                    f"no '*{suffix}' datasets found in the *.h5 files under {str(h5_path)!r}"
                )
            chosen_key = random.choice(keys)

        metadata_path = Path(h5_path) / "dataset.json"
        if metadata_path.exists():
            import json
            with open(metadata_path, "r") as f2:
                dataset = json.load(f2)
            # Metadata is keyed by the dataset name without the ".latents" suffix.
            print(json.dumps(dataset[chosen_key[:-len(suffix)]], indent=2))
        latents = f[chosen_key][:]
        return torch.asarray(latents).cuda().to(torch.float32)

def inspect(path="cache", hashsum=None):
    """Decode one cached latent with the VAE and display the resulting image.

    Args:
        path: Directory passed to ``load_latents_from_h5``.
        hashsum: Dataset key passed to ``load_latents_from_h5``; ``None`` lets
            that function pick one at random.

    Prints the mean and standard deviation of the loaded latent, then shows
    the decoded image in a matplotlib figure. Returns ``None``.
    """
    latents = load_latents_from_h5(path, hashsum)
    print(f"stat: {latents.mean()}, {latents.std()}")

    # Add a leading axis of size 1 before handing the latent to the decoder.
    batch = latents.unsqueeze(0)

    # batch = 1.0 / 0.13025 * batch  (rescaling step left disabled)
    with torch.no_grad():
        decoded = vae.decode(batch).sample

    # Map back to [0, 1], drop size-1 axes, reorder the axes for imshow and
    # move the data to a NumPy array on the host.
    pixels = denormalize(decoded)
    image = pixels.squeeze().permute(1, 2, 0).cpu().numpy()

    # Render the image without axes.
    plt.figure(dpi=300)
    plt.imshow(image)
    plt.axis('off')
    plt.title('Decoded from Latents', fontsize=4)
    plt.show()

In [21]:
# @title Inspect
# Decode and display one randomly chosen latent from this directory
# (hashsum is left at its default of None).
inspect("/notebooks/BA_latents")