# Identity Trick Comparison
In this notebook, we:
1. Load our custom diffusion model (trained on color patches).
2. Load the official Stable Diffusion v1.5 model from Hugging Face.
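3. Run an identity test on both models for a handful of single-color test images. For each image, we encode it into VAE latents and query the UNet at timestep 0 (no added noise). We then report the mean squared magnitude of the predicted noise, which should ideally be close to 0.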

In [1]:
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

import matplotlib.pyplot as plt
import os

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# The identity tests encode images with `latent_dist.sample()`, which draws random
# noise. Seed torch (all devices) so Restart & Run All reproduces the reported MSEs.
SEED = 0
torch.manual_seed(SEED)

print(f"Using device: {device}")

Using device: mps


In [2]:
def identity_test(unet, vae, latents_scaling, real_image, device, generator=None, **unet_kwargs):
    """
    Run the "identity" check for a diffusion UNet on a clean image batch.

    Steps:
    1. Encode the image batch to latents with the VAE and multiply by
       ``latents_scaling``.
    2. Call the UNet at timestep t=0 (no noise is added to the latents).
    3. Measure how far the prediction is from zero; the expectation is that the
       model predicts ~0 noise for an un-noised input.

    Args:
        unet: Model called as ``unet(latents, t, **unet_kwargs)`` whose output
            exposes a ``.sample`` tensor.
        vae: Autoencoder whose ``encode(x).latent_dist`` provides
            ``.sample(generator=...)`` and which exposes a ``.dtype``.
        latents_scaling: Factor applied to the raw VAE latents.
        real_image: Image batch tensor, presumably [B, 3, H, W] normalized to
            [-1, 1] -- TODO confirm against the VAE's expected input.
        device: Device to run on.
        generator: Optional ``torch.Generator`` for the VAE posterior sample,
            making the result reproducible. ``None`` uses the global RNG.
        **unet_kwargs: Extra keyword arguments forwarded to ``unet``, e.g.
            ``encoder_hidden_states=...`` for a conditional UNet.

    Returns:
        float: the MSE between the predicted noise and 0 (its mean squared value).
    """
    with torch.no_grad():
        # Match the VAE's dtype so half-precision models accept the input too.
        image = real_image.to(device=device, dtype=vae.dtype)
        latents = vae.encode(image).latent_dist.sample(generator=generator)
        latents = latents * latents_scaling

        # One t=0 timestep per batch element.
        t = torch.zeros(latents.shape[0], dtype=torch.long, device=device)
        predicted_noise = unet(latents, t, **unet_kwargs).sample

        # Reduce in float32 so fp16 outputs don't lose precision in the mean.
        predicted_noise = predicted_noise.float()
        mse_value = F.mse_loss(predicted_noise, torch.zeros_like(predicted_noise)).item()

    return mse_value

In [3]:
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d import UNet2DModel
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler


# Checkpoint holding the custom UNet's state_dict.
MY_UNET_CKPT = "../../color_diffusion_checkpoints/v1/unet_epoch_9.pt"

# The custom UNet works in the latent space of the SD v1.4 VAE. The VAE is only
# used for encoding here, so put it in eval mode and freeze it.
vae_my = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
vae_my.eval()
vae_my.requires_grad_(False)

# The architecture must match the one used for training; otherwise
# load_state_dict (strict by default) raises on missing/mismatched keys.
# NOTE(review): sample_size=64 suggests 64x64 latents (presumably 512px images through
# SD's 8x-downsampling VAE), while the evaluation cell resizes inputs to 256px.
# Confirm which resolution the model was trained at.
unet_my = UNet2DModel(
    sample_size=64,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(320, 640, 640, 1280),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)

# weights_only=True: a state_dict only needs tensors, so refuse to unpickle
# arbitrary objects (safer, and the default from PyTorch 2.6 on).
unet_my.load_state_dict(torch.load(MY_UNET_CKPT, map_location="cpu", weights_only=True))
unet_my.to(device)
unet_my.eval()
unet_my.requires_grad_(False)

# Latent scaling factor for the SD v1.x VAE; it must match the value used
# when the UNet was trained.
my_latent_scaling = 0.18215

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Evaluation set: one image from each color folder of the colors dataset.
test_image_paths = [
    "../../datasets/colors/black/black3.png",
    "../../datasets/colors/blue/blue3.png",
    "../../datasets/colors/green/green3.jpg",
    "../../datasets/colors/red/red3.jpg",
    "../../datasets/colors/yellow/yellow3.jpg",
]

# Resize to 256x256, convert to a [0, 1] tensor, then shift/scale to [-1, 1].
transform_my = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

for path in test_image_paths:
    with Image.open(path) as pil_image:
        # [None] adds the batch dimension -> [1, 3, H, W].
        batch = transform_my(pil_image.convert("RGB"))[None]
    score = identity_test(unet_my, vae_my, my_latent_scaling, batch, device)
    print(f"Identity MSE for '{os.path.basename(path)}': {score:.6f}")

Identity MSE for 'black3.png': 0.431787
Identity MSE for 'blue3.png': 0.082421
Identity MSE for 'green3.jpg': 0.044352
Identity MSE for 'red3.jpg': 0.607667
Identity MSE for 'yellow3.jpg': 0.182590


In [5]:
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


# Hugging Face repo for the official SD v1.5 weights.
# NOTE(review): `runwayml/stable-diffusion-v1-5` appears to have been removed from the Hub.
# If the download fails, a community mirror is `stable-diffusion-v1-5/stable-diffusion-v1-5`.
SD15_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Load in float32. This matches the float32 custom model above, so both identity
# tests run at the same precision. If you switch to fp16:
# - test `device.type == "cuda"`, because a torch.device never compares equal to a plain str;
# - cast the image batch and the conditioning embeddings to the model dtype.
SD15_DTYPE = torch.float32

# Load VAE and UNet separately
vae_15 = AutoencoderKL.from_pretrained(
    SD15_MODEL_ID,
    subfolder="vae",
    torch_dtype=SD15_DTYPE,
).to(device)

unet_15 = UNet2DConditionModel.from_pretrained(
    SD15_MODEL_ID,
    subfolder="unet",
    torch_dtype=SD15_DTYPE,
).to(device)

# Inference only: eval mode, no gradients.
vae_15.eval()
unet_15.eval()
vae_15.requires_grad_(False)
unet_15.requires_grad_(False)

# Take the latent scaling factor from the VAE config instead of hard-coding it
# (0.18215 for SD v1.x).
sd15_latent_scaling = vae_15.config.scaling_factor

In [6]:

# Same preprocessing as for the custom model ([-1, 1] range), but at 512x512.
transform_15 = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

for img_path in test_image_paths:
    with Image.open(img_path) as img:
        img_tensor = transform_15(img.convert("RGB")).unsqueeze(0)  # shape [1,3,H,W]

    with torch.no_grad():
        # Match the models' device and dtype so this also works if they are loaded in fp16.
        img_tensor = img_tensor.to(device=device, dtype=vae_15.dtype)
        latents = vae_15.encode(img_tensor).latent_dist.sample()
        latents = (latents * sd15_latent_scaling).to(unet_15.dtype)

        # One t=0 timestep per batch element.
        batch_size = latents.shape[0]
        t = torch.zeros(batch_size, dtype=torch.long, device=device)

        # unet_15 is a UNet2DConditionModel: cross-attention requires encoder_hidden_states
        # shaped [batch, seq_len, cross_attention_dim] (77 text tokens for SD1.5).
        # NOTE: all-zero embeddings are only a stand-in for "no conditioning". SD pipelines
        # normally use the text encoder's output for an empty prompt, so treat these
        # numbers as approximate.
        dummy_embeds = torch.zeros(
            (batch_size, 77, unet_15.config.cross_attention_dim),
            device=device,
            dtype=unet_15.dtype,
        )

        # Reduce in float32 so fp16 outputs don't lose precision in the mean.
        pred_noise = unet_15(latents, t, encoder_hidden_states=dummy_embeds).sample.float()
        mse_15 = F.mse_loss(pred_noise, torch.zeros_like(pred_noise)).item()

    print(f"SD1.5 Identity MSE for '{os.path.basename(img_path)}': {mse_15:.6f}")

SD1.5 Identity MSE for 'black3.png': 0.117680
SD1.5 Identity MSE for 'blue3.png': 0.042924
SD1.5 Identity MSE for 'green3.jpg': 0.034675
SD1.5 Identity MSE for 'red3.jpg': 0.087377
SD1.5 Identity MSE for 'yellow3.jpg': 0.039821
