# ECE 285 - Final Project
## Implementation of EmerDiff for Semantic Segmentation
### Name: Pushkal Mishra
### PID: A69033424

This project is an implementation of the [EmerDiff architecture](https://kmcode1.github.io/Projects/EmerDiff/) which uses Stable Diffusion for semantic segmentation.

The model is trained on the CARLA dataset, collected from an [earlier project](https://wcsng.ucsd.edu/c-shenron-demo/).

In [1]:
import cv2
import torch
import numpy as np
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.cluster import KMeans
from collections import defaultdict
from transformers import CLIPTextModel, CLIPTokenizer
from torchvision.transforms.functional import pil_to_tensor
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel

In [None]:
# Load one RGB frame from the CARLA recording
img_path = "/home/pushkalm11/Courses/ece285/Project/dataset/s1_2025-03-05/s1_Town01_Rep1/Town01_Scenario1_route0_03_05_14_14_18/rgb/0110.jpg"
image = cv2.imread(img_path)
# cv2.imread reports failure by returning None rather than raising, so fail
# here with a readable message instead of a cryptic error inside cvtColor
if image is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")
# OpenCV loads BGR; the rest of the notebook works with RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Keep a 256-pixel-wide strip centred on column 512
# (assumes the frame is at least 640 pixels wide - TODO confirm for other recordings)
image = image[:, 512 - 128 : 512 + 128, :]

# convert to PIL image and resize to 512x512
image = Image.fromarray(image)
image_2 = image.resize((512, 512), resample=Image.LANCZOS)

# uint8 tensor -> float batch of one on the GPU, rescaled from [0, 255] to [-1, 1]
img = pil_to_tensor(image_2).cuda().unsqueeze(0).float() / 255.0 * 2.0 - 1.0

# show the cropped and the resized image side by side
plt.figure(figsize=(8, 5))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(image_2)
plt.axis('off')
plt.show()

# sanity check of the value range after rescaling
print(img.min(), img.max())

In [3]:
# img = torch.cat([img, img], dim = 0)
# print(img.shape)

## Configurations for EmerDiff

In [4]:
# ---- Mask proposal -------------------------------------------------------
# Offsets added to the perturbed feature map: lambda_1 is the negative one,
# lambda_2 the positive one
lambda_1, lambda_2 = -10, 10
# Number of segmentation masks (also the number of k-means clusters)
num_mask = 25
# Text prompt for mask proposal (empty string)
text_prompt = ""

# ---- Pre-trained checkpoints ---------------------------------------------
# Spatial down-sampling factor between image and VAE latent
vae_compress_rate = 8
# Stable Diffusion checkpoint
stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
# CLIP checkpoint
clip_version = "openai/clip-vit-large-patch14"

# ---- Diffusion schedule ----------------------------------------------------
# Number of timesteps for the diffusion process at inference
inference_time_steps = 50
# Timesteps used for mask proposal: timestep 0 repeated 100 times, i.e.
# multiple runs at the same timestep
time_steps = 100 * [0]
# Timestep at which the feature map is perturbed
inject_mask_time_stamp = [281]
# Position of the largest injection timestep on the inference schedule: the
# 1000 training steps are scaled down to inference_time_steps, plus one
index_to_use = 1 + int(max(inject_mask_time_stamp) // (1000 / inference_time_steps))

# ---- k-means (standard parameters) ----------------------------------------
n_init = 100
init_algo = "k-means++"
kmeans_algo = "lloyd"

# ---- Attention layers, addressed as (type, position, resolution, index) ---
# Layer whose features are recorded - the first cross-attention layer at the
# 16x16 resolution
record_layers = [("cross", "up", 0, 0)]
# Layer that receives the mask offsets
inject_mask_perturbations = [("cross", "up", 0, 2)]

# Everything above collected in one dictionary
emerdiff_config = dict(
    lambda_1=lambda_1,
    lambda_2=lambda_2,
    num_mask=num_mask,
    text_prompt=text_prompt,
    time_steps=time_steps,
    inject_mask_time_stamp=inject_mask_time_stamp,
    vae_compress_rate=vae_compress_rate,
    stable_diffusion_version=stable_diffusion_version,
    clip_version=clip_version,
    inference_time_steps=inference_time_steps,
    n_init=n_init,
    init_algo=init_algo,
    kmeans_algo=kmeans_algo,
    record_layers=record_layers,
    inject_mask_perturbations=inject_mask_perturbations,
    multiplication_factor=0.18215,
    index_to_use=index_to_use,
)

## Importing pre-trained models

In [5]:
# VAE: maps images to and from the latent space
VAE = AutoencoderKL.from_pretrained(stable_diffusion_version, subfolder="vae", cache_dir="./").to("cuda:0")

# DDIM scheduler: noise addition and the reverse process, configured for the
# inference schedule
ddim = DDIMScheduler.from_pretrained(stable_diffusion_version, subfolder="scheduler", cache_dir="./")
ddim.set_timesteps(inference_time_steps, device="cuda:0")

# CLIP tokenizer and text encoder for the prompt
text_tokenizer = CLIPTokenizer.from_pretrained(clip_version, cache_dir="./")
text_encoder = CLIPTextModel.from_pretrained(clip_version, cache_dir="./").to("cuda:0")

# Text-conditioned UNet
UNet = UNet2DConditionModel.from_pretrained(stable_diffusion_version, subfolder="unet", cache_dir="./").to("cuda:0")

# k-means clusterer with one cluster per segmentation mask; the fixed
# random_state keeps the clustering reproducible
kmeans_classifier = KMeans(
    n_clusters=emerdiff_config["num_mask"],
    init=emerdiff_config["init_algo"],
    n_init=emerdiff_config["n_init"],
    random_state=1234,
    algorithm=emerdiff_config["kmeans_algo"],
)

## Functions to swap out attention layers
These functions are required to swap out the attention layers in the UNet architecture so that they can store and modify the attention maps.

### Description of the functions:

#### Function: `store_attention_maps()`
Stores attention queries (and optionally keys) during forward passes from selected cross-attention layers. These are used for clustering mask features or preserving attention structure. Returns possibly modified query, key, and value based on the flags set.

#### Function: `modify_feature_maps()`
Applies a spatial perturbation to selected feature maps: a binary mask (one k-means cluster at a time) is scaled by the lambda offsets and added to the features. This measures how a specific region of the low-resolution map influences the final image. Returns the perturbed hidden states in the original shape.

#### Function: `extract_mask_features()`
Aggregates and averages stored query features across timesteps for selected layers. These features are then reshaped into vectors used for k-means clustering to generate low-resolution masks. Returns a 2D tensor of shape `[tokens, features]`.

#### Function: `inject_attention()`
Replaces default attention forward passes in the U-Net with custom logic that records or modifies attention behavior. It injects hooks into specific attention layers (cross/self, up/down/mid) to enable storing Q/K and applying feature perturbations. Does not return anything; modifies the U-Net in-place.


In [6]:
# Use of each flag:
#   - record_mask_embeddings: record the mask embeddings at the correct timestep
#     for modification
#   - record_kqv_attention: record the original Query and Key of the attention
#     map so that they can be used later during reconstruction
#   - use_recorded_kqv_attention: overwrite the current Query and Key with the
#     ones recorded earlier
#   - perturb_feature: perturb the feature maps
layer_dict = dict(
    # All the flags
    record_mask_embeddings=False,
    record_kqv_attention=False,
    use_recorded_kqv_attention=False,
    perturb_feature=False,
    # Running sums of recorded features and how many were summed (missing keys start at 0)
    stored_attention=defaultdict(int),
    stored_attention_count=defaultdict(int),
    # Recorded (query, key) pairs
    original_kqv_attention={},
    # Timestep currently being processed; -1 until a caller sets it
    timestep=-1,
    # Layers to record
    record_layers=emerdiff_config["record_layers"],
    # Layers to perturb
    perturb_layers=emerdiff_config["inject_mask_perturbations"],
    # Mask used for the perturbation; set by the caller before perturbing
    mask=None,
    # Offsets applied through the mask
    lambda_1=emerdiff_config["lambda_1"],
    lambda_2=emerdiff_config["lambda_2"],
)

# Defining a class for new attention blocks to store the maps
class NewAttentionBlock:
    """State shared by the custom attention processors.

    All behaviour is driven by `layer_dict`: boolean flags select what the
    hooks do on each attention call, and the remaining entries hold the
    recorded features, the recorded query/key pairs, the current timestep, the
    current mask and the lambda offsets.
    """

    def __init__(self, layer_dict):
        self.layer_dict = layer_dict

    def _layer_key(self, ty, pos, res, idx):
        """Return the dictionary key for the current timestep and layer.

        The timestep is converted to a plain int: tensors hash by identity, so
        a key built from a tensor timestep in one loop could never be found
        again from another loop over the same timesteps.
        """
        return (int(self.layer_dict["timestep"]), ty, pos, res, idx)

    # Function to store attention maps
    def store_attention_maps(
        self,
        ty = "cross",
        pos = "up",
        res = 0,
        idx = 0,
        query = None,
        key = None,
        value = None
    ):
        """Record and/or overwrite the query and key of one attention call.

        The three flags act independently of each other:
          - record_mask_embeddings: for layers listed in "record_layers", add
            the first batch element of `query` to a running sum (on the CPU).
          - record_kqv_attention: store a CPU copy of `query` and `key`.
          - use_recorded_kqv_attention: overwrite `query` and `key` in place
            with the first batch element of the stored copies.

        Returns:
            The tuple (query, key, value); `value` is never modified.

        Raises:
            KeyError: if use_recorded_kqv_attention is set but nothing was
                recorded for this timestep and layer.
        """
        state = self.layer_dict
        entry = self._layer_key(ty, pos, res, idx)

        if state["record_mask_embeddings"] and (ty, pos, res, idx) in state["record_layers"]:
            # presumably [h*w, channels] - TODO confirm against the UNet layer
            fet = query[0].clone().detach().cpu()
            state["stored_attention"][entry] += fet
            state["stored_attention_count"][entry] += 1

        # These two branches must not depend on record_mask_embeddings: the
        # notebook sets them while mask recording is switched off
        if state["record_kqv_attention"]:
            state["original_kqv_attention"][entry] = (
                query.detach().clone().cpu(),
                key.detach().clone().cpu(),
            )

        if state["use_recorded_kqv_attention"]:
            stored_query, stored_key = state["original_kqv_attention"][entry]
            # The first recorded batch element is broadcast over the whole batch
            query[:] = stored_query[0].to(query.device)
            key[:] = stored_key[0].to(key.device)

        return query, key, value

    # Function to modify the feature maps
    def modify_feature_maps(
        self,
        hidden_states,
        ty = "cross",
        pos = "up",
        res = 0,
        idx = 0,
        to_v = None
    ):
        """Add `lambda * mask` to `hidden_states` for the layers to perturb.

        Nothing happens unless "perturb_feature" is set and the layer is listed
        in "perturb_layers". A batch of two gets lambda_1 for the first element
        and lambda_2 for the second; any other batch size gets lambda_1 only.
        4-D inputs are flattened to (batch, positions, channels) for the
        addition and restored afterwards. `to_v` is accepted but unused.

        Returns:
            The (possibly perturbed) hidden states in their original shape.
        """
        state = self.layer_dict
        if not (state["perturb_feature"] and (ty, pos, res, idx) in state["perturb_layers"]):
            return hidden_states

        original_shape = hidden_states.shape
        is_spatial = len(original_shape) == 4
        if is_spatial:
            hidden_states = hidden_states.reshape(
                (hidden_states.shape[0], hidden_states.shape[1], -1)
            ).permute((0, 2, 1))

        # One mask value per position, broadcast over batch and channels
        mask = state["mask"].reshape((1, -1, 1)).to(hidden_states.device)
        if hidden_states.shape[0] == 2:
            offsets = [state["lambda_1"], state["lambda_2"]]
        else:
            offsets = [state["lambda_1"]]
        lam = torch.tensor(
            offsets, dtype=torch.float32, device=hidden_states.device
        ).reshape(-1, 1, 1)

        # Perturb the feature maps - Main contribution of the EmerDiff architecture
        hidden_states = hidden_states + lam * mask
        if is_spatial:
            hidden_states = hidden_states.permute((0, 2, 1)).reshape(original_shape)
        return hidden_states

    # Function to extract out the mask features
    def extract_mask_features(
        self
    ):
        """Return the averaged recorded features as a 2-D tensor.

        Every running sum is divided by its count, the results are stacked
        along a new leading dimension, and the stack is rearranged so that
        row i holds the features of position i across all recorded
        (timestep, layer) entries.
        """
        counts = self.layer_dict["stored_attention_count"]
        fet = [
            total.unsqueeze(0) / counts[entry]
            for entry, total in self.layer_dict["stored_attention"].items()
        ]
        fet = torch.cat(fet, dim=0)
        fet = fet.permute((1, 0, 2))
        return fet.reshape((fet.shape[0], -1))

# Single hook object wrapping layer_dict; it is mutated in place, so flag changes made through
# new_attention_block.layer_dict are visible to every caller holding this object
new_attention_block = NewAttentionBlock(layer_dict)

#### Function: inject_attention()
This function replaces the attention processors inside the UNet with custom ones that can record queries/keys and perturb the attention output.

Reference: [https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py)

In [7]:
# This code has been taken from an existing github repository which follows the UNet Text-conditioned architecture
# The code has been modified to store the attention maps and modify the feature maps
def inject_attention(unet, new_attention_block):
    """Swap the attention processors of `unet` for hooks bound to `new_attention_block`.

    Walks `unet.up_blocks`, `unet.down_blocks` and `unet.mid_block`. Inside every
    block whose class is CrossAttnUpBlock2D, CrossAttnDownBlock2D or
    UNetMidBlock2DCrossAttn, each BasicTransformerBlock of each Transformer2DModel
    gets a new processor on `attn1` (self-attention) and `attn2` (cross-attention),
    but only when the current processor class is AttnProcessor2_0; any other
    processor is left untouched without warning.

    Each layer is addressed as (ty, pos, res, idx): `ty` is "self" or "cross",
    `pos` is "up", "down" or "mid", `res` counts the matching blocks within the
    position and `idx` counts the Transformer2DModel entries within a block.

    Args:
        unet: model whose processors are replaced; modified in place.
        new_attention_block: object providing store_attention_maps() and
            modify_feature_maps(), called on every attention forward pass.

    Returns:
        The same `unet` object.
    """
    def new_forward_attention(ty = "cross", pos = "up", res=0, idx = 0):
        # Factory: builds a processor that remembers which layer it belongs to
        def forward(attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
            # Batch size is read from the tensor that provides keys and values
            # (the encoder states for cross-attention, the hidden states otherwise)
            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            # NOTE(review): taken from the input width, so this assumes the projected
            # width equals the input width - confirm for other checkpoints
            inner_dim = hidden_states.shape[-1]
            query = attn.to_q(hidden_states) #(batch, seq_len, num_heads*head_dim)
            # Self-attention: keys and values come from the hidden states themselves
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            key = attn.to_k(encoder_hidden_states) #(batch, seq_len, num_heads*head_dim)
            value = attn.to_v(encoder_hidden_states) #(batch, seq_len, num_heads*head_dim)
            head_dim = inner_dim // attn.heads

            # Hook 1: record and/or overwrite query and key before the heads are split
            query, key, value = new_attention_block.store_attention_maps(ty, pos, res, idx, query, key, value)

            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) #(batch, num_heads, seq_len, head_dim)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask
            )

            # Merge the heads back into one feature dimension
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) #(batch, seq_len, num_heads*head_dim)
            hidden_states = hidden_states.to(query.dtype)

            # linear proj
            # NOTE(review): only attn.to_out[0] is applied; to_out[1] is skipped
            # (presumably dropout - confirm)
            hidden_states = attn.to_out[0](hidden_states)

            # Hook 2: perturb the projected output (to_v is passed as None)
            hidden_states = new_attention_block.modify_feature_maps(hidden_states, ty, pos, res, idx, None)

            return hidden_states
        return forward
    def inject_block(blocks=unet.up_blocks, pos="up"):
        # Replaces the processors inside one group of blocks and returns that group
        #ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
        # res counts only the blocks that contain cross-attention
        res = -1
        # The mid block is a single module, so it is wrapped in a list rather than iterated
        if pos == "mid":
            children = [blocks]
        else:
            children = blocks.children()
        for net_ in children:
            if net_.__class__.__name__ in ["CrossAttnUpBlock2D","CrossAttnDownBlock2D","UNetMidBlock2DCrossAttn"]:
                res += 1
                # idx restarts for every block
                idx = -1
                for atn in net_.attentions:
                        if atn.__class__.__name__ == "Transformer2DModel":
                            idx += 1
                            for block in atn.transformer_blocks:
                                if block.__class__.__name__ == "BasicTransformerBlock":
                                    #self attention
                                    if block.attn1.processor.__class__.__name__ == "AttnProcessor2_0":
                                        block.attn1.processor = new_forward_attention(ty = "self", pos = pos, res = res, idx = idx)
                                    #cross attention
                                    if block.attn2.processor.__class__.__name__ == "AttnProcessor2_0":
                                        block.attn2.processor = new_forward_attention(ty="cross", pos = pos, res = res, idx = idx)
        return blocks
    unet.up_blocks = inject_block(unet.up_blocks, pos="up")
    unet.down_blocks = inject_block(unet.down_blocks, pos="down")
    unet.mid_block = inject_block(unet.mid_block, pos="mid")
    return unet

### Injecting the attention maps into the UNet architecture

In [8]:
# Install the recording/perturbing attention processors; inject_attention modifies the UNet
# in place and returns the same object
UNet = inject_attention(UNet, new_attention_block)

## Functions for Forward and Reverse DDPM process
These functions are adapted from the reference implementation of DDPM inversion, which can be found here: [https://github.com/inbarhub/DDPM_inversion/blob/main/ddm_inversion/inversion_utils.py](https://github.com/inbarhub/DDPM_inversion/blob/main/ddm_inversion/inversion_utils.py).

### Description of the functions:

#### Function: `sample_xts_from_x0()`
Generates noisy latent samples `x_t` from clean latent `x₀` by adding scheduled Gaussian noise across timesteps. Used to simulate the forward diffusion process for inversion.

#### Function: `get_variance()`
Computes the noise variance at a specific timestep using the scheduler’s cumulative alpha values. Needed for reverse sampling and DDIM step calculations.

#### Function: `inversion_forward_process()`
Performs DDPM inversion by simulating the forward process from `x₀` to `x_T`, then reconstructing noise vectors `z` via U-Net predictions at each step. Returns noisy latents, noise vectors, and timesteps.

#### Function: `reverse_step()`
Executes a single DDIM denoising step from `x_t` to `x_{t-1}` using predicted noise, diffusion schedule, and optional noise scaling. Core operation for controlled reverse generation.

In [9]:
def sample_xts_from_x0(unet, scheduler, x0, num_inference_steps, rng):
    """Sample one independently noised copy of `x0` per scheduler timestep.

    For every timestep t in `scheduler.timesteps` the sample is
    sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise, with fresh noise
    drawn from `rng` for each timestep.

    Args:
        unet: unused; kept so the call signature stays the same.
        scheduler: object exposing `alphas_cumprod` and `timesteps`.
        x0: clean latent; its last three dimensions define the sample shape.
            The row-wise assignment and the final concatenation presumably
            require a batch of one - TODO confirm.
        num_inference_steps: number of rows allocated for the noisy samples.
        rng: torch.Generator used for the noise.

    Returns:
        Tensor with num_inference_steps + 1 rows on the device of `x0`: row i
        is the sample for `scheduler.timesteps[i]` and the last row is `x0`.
    """
    # Follow the device of x0 instead of assuming a fixed GPU
    device = x0.device
    alpha_bar = scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
    timesteps = scheduler.timesteps.to(device)

    xts = torch.zeros(
        (num_inference_steps, x0.shape[-3], x0.shape[-2], x0.shape[-1]),
        device=device,
    )
    # Visit the schedule back to front so the noise is drawn in the same order
    # as before and results for a given seed stay reproducible
    for idx in range(len(timesteps) - 1, -1, -1):
        t = timesteps[idx]
        noise = torch.randn(x0.shape, generator=rng).to(device)
        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + noise * sqrt_one_minus_alpha_bar[t]

    return torch.cat([xts, x0], dim=0)

def get_variance(scheduler, timestep):
    """Return the variance for the step from `timestep` to the previous inference timestep.

    The previous timestep lies num_train_timesteps // num_inference_steps
    earlier. When that falls below zero, `scheduler.final_alpha_cumprod` is
    used in place of a table lookup.
    """
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    earlier = timestep - stride
    cumprod = scheduler.alphas_cumprod

    alpha_now = cumprod[timestep]
    if earlier >= 0:
        alpha_before = cumprod[earlier]
    else:
        alpha_before = scheduler.final_alpha_cumprod

    beta_now = 1 - alpha_now
    beta_before = 1 - alpha_before
    return (beta_before / beta_now) * (1 - alpha_now / alpha_before)

def inversion_forward_process(unet, scheduler, x0,
                              uncond_embedding,
                              etas = 1.0,
                              num_inference_steps=50,
                              ddpm_seed = 42
                             ):
    """Invert `x0`: sample noisy latents and solve for the noise that links them.

    First an independently noised latent is sampled for every scheduler
    timestep (see sample_xts_from_x0). Then, for each timestep, the UNet noise
    prediction gives the mean of the reverse step, and `z` is the noise that
    carries that mean exactly onto the next, less noisy, sampled latent.

    Args:
        unet: model called as unet.forward(xt, timestep=t, encoder_hidden_states=...).
        scheduler: scheduler exposing `timesteps`, `alphas_cumprod`,
            `final_alpha_cumprod`, `num_inference_steps` and
            `config.num_train_timesteps`.
        x0: clean latent; its last three dimensions define the latent shape.
        uncond_embedding: text embedding passed to the UNet.
        etas: scalar noise scale used for every step. It divides the residual
            below, so it must be non-zero.
        num_inference_steps: number of latents and noise maps to allocate.
        ddpm_seed: seed for the noise generator.

    Returns:
        (xts, zs, timesteps): the noisy latents (num_inference_steps + 1 rows,
        the last one being `x0`), the per-step noise maps, and the scheduler
        timesteps on the device of `x0`.
    """
    # Follow the device of x0 instead of assuming a fixed GPU
    device = x0.device
    timesteps = scheduler.timesteps.to(device)
    variance_noise_shape = (
        num_inference_steps,
        x0.shape[-3],
        x0.shape[-2],
        x0.shape[-1])
    rng = torch.Generator().manual_seed(ddpm_seed)

    etas = [etas]*scheduler.num_inference_steps
    # generate noisy samples xts
    xts = sample_xts_from_x0(unet, scheduler, x0, num_inference_steps=num_inference_steps, rng = rng)
    alpha_bar = scheduler.alphas_cumprod
    zs = torch.zeros(size=variance_noise_shape, device=device)

    # Row idx of xts belongs to timesteps[idx], so the position is the index
    for idx, t in enumerate(timesteps):
        xt = xts[idx][None]

        with torch.no_grad():
            out = unet.forward(xt, timestep =  t, encoder_hidden_states = uncond_embedding)

        noise_pred = out.sample

        # Target of this reverse step: the next (less noisy) sampled latent
        xtm1 =  xts[idx+1][None]
        pred_original_sample = (xt - (1-alpha_bar[t])  ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5
        prev_timestep = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
        variance = get_variance(scheduler, t)
        pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred
        # Mean of the reverse step
        mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
        # Noise that moves the mean exactly onto the target latent
        z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 )
        zs[idx] = z
        # Write back the latent actually reachable with z
        xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z
        xts[idx+1] = xtm1
    return xts, zs, timesteps

# Simulating one step of the reverse diffusion process
def reverse_step(scheduler, model_output, timestep, sample, eta, variance_noise):
    """Take one reverse step from `sample` at `timestep` to the previous inference timestep.

    The clean sample is estimated from the noise prediction `model_output`,
    moved to the previous timestep, and - when `eta` is positive - perturbed by
    `variance_noise` scaled with eta * sqrt(variance).

    Returns:
        The sample at the previous timestep.
    """
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    earlier = timestep - stride
    alpha_now = scheduler.alphas_cumprod[timestep]
    if earlier >= 0:
        alpha_before = scheduler.alphas_cumprod[earlier]
    else:
        alpha_before = scheduler.final_alpha_cumprod

    # Estimate of the clean sample implied by the noise prediction
    clean_estimate = (sample - (1 - alpha_now) ** (0.5) * model_output) / alpha_now ** (0.5)

    variance = get_variance(scheduler, timestep)
    # Part of the step that points along the predicted noise
    direction = (1 - alpha_before - eta * variance) ** (0.5) * model_output
    prev_sample = alpha_before ** (0.5) * clean_estimate + direction

    if eta > 0:
        prev_sample = prev_sample + eta * variance ** (0.5) * variance_noise
    return prev_sample

## Main EmerDiff Process

In [10]:
# Forward Process of Stable Diffusion along with UNet denoising to record attention maps
def fwd_stable_diffusion(
    latent, 
    time_steps, 
    new_attention_block):
    """Noise `latent` at every entry of `time_steps` and run the UNet on the result.

    The UNet output is discarded: the call exists for its side effect on the
    attention hooks, which read the current timestep and the prompt embedding
    from `new_attention_block.layer_dict`.

    Args:
        latent: clean latent to be noised.
        time_steps: iterable of integer timesteps; repeated entries draw fresh noise.
        new_attention_block: hook object; its layer_dict must already contain
            "prompt_embedding".
    """
    # One generator shared across all timesteps
    random_number_generator = torch.Generator().manual_seed(42)
    prompt_embedding = new_attention_block.layer_dict["prompt_embedding"]
    # Pure inference: without no_grad every UNet pass would build an autograd graph
    with torch.no_grad():
        for t in time_steps:
            # Noise is drawn on the CPU generator, then moved next to the latent
            noise = torch.randn(latent.shape, generator = random_number_generator).to(latent.device)
            latent_t = ddim.add_noise(latent, noise, torch.tensor([t]).int())
            new_attention_block.layer_dict["timestep"] = t
            UNet(latent_t, t, encoder_hidden_states = prompt_embedding)

# Running one step of the reverse diffusion process
def reverse_diffusion_one_step(
    t, 
    latents, 
    cond1, 
    cond2 = None, 
    z_ddpm = None):
    """Predict the noise with the UNet and take one reverse step with eta = 1.0.

    Args:
        t: current timestep.
        latents: batch of latents at timestep `t`.
        cond1: text embedding; used alone when `cond2` is None.
        cond2: optional second text embedding, concatenated after `cond1`.
        z_ddpm: noise map for this step; expanded to the batch size of `latents`.

    Returns:
        The latents at the previous timestep.

    Raises:
        ValueError: if `z_ddpm` is not provided.
    """
    # The step always adds noise (eta = 1.0), so the noise map is required
    if z_ddpm is None:
        raise ValueError("z_ddpm is required: the reverse step is run with eta = 1.0")

    # Identity check: comparing a tensor with `!=` is not a None test
    if cond2 is not None:
        text_embeddings = torch.cat([cond1, cond2])
    else:
        text_embeddings = cond1

    noise_pred = UNet(
        latents,
        t,
        encoder_hidden_states=text_embeddings
    ).sample
    # Same scheduled noise for every element of the batch
    z_ddpm = z_ddpm.expand(latents.shape[0], -1, -1, -1)
    return reverse_step(ddim, noise_pred, t, latents, 1.0, z_ddpm)

In [None]:
# Generating Latent Vector embedding from the image
# Inference only: no_grad keeps the latent free of autograd history
with torch.no_grad():
    latent = VAE.encode(img)['latent_dist'].mean
    latent = latent * emerdiff_config["multiplication_factor"]

# Shape of the original image, recovered from the latent size
h, w = latent.shape[-2], latent.shape[-1]
h, w = h * emerdiff_config["vae_compress_rate"], w * emerdiff_config["vae_compress_rate"]

# Adding text embedding to the layer_dict
text_prompt_tokens = text_tokenizer(
    text_prompt, 
    padding = "max_length", 
    max_length = text_tokenizer.model_max_length, 
    truncation = True, 
    return_tensors = "pt")
with torch.no_grad():
    text_prompt_embeddings = text_encoder(text_prompt_tokens.input_ids.to("cuda:0"))[0]
print(f"Shape of prompt embeddings in CLIP space: {text_prompt_embeddings.shape}")
new_attention_block.layer_dict["prompt_embedding"] = text_prompt_embeddings
# new_attention_block.layer_dict["prompt_embedding"] = torch.cat([text_prompt_embeddings, text_prompt_embeddings], dim = 0)

# Extracting mask for clustering
# Start from empty accumulators so that re-running this cell (for example on
# another image) does not average in features from an earlier run
new_attention_block.layer_dict["stored_attention"].clear()
new_attention_block.layer_dict["stored_attention_count"].clear()
new_attention_block.layer_dict["record_mask_embeddings"] = True
try:
    with torch.no_grad():
        fwd_stable_diffusion(latent, emerdiff_config["time_steps"], new_attention_block)
finally:
    # Always switch recording off again, even if the forward pass fails
    new_attention_block.layer_dict["record_mask_embeddings"] = False

# Clustering the features to produce low-resolution segmentation maps
mask_features = new_attention_block.extract_mask_features()
kmeans_classifier.fit(mask_features)
# One cluster id per position of the recorded feature map
mask_to_id_mapping = torch.from_numpy(kmeans_classifier.labels_).to("cuda:0")
num_masks = mask_to_id_mapping.max() + 1

In [12]:
# Run the forward process and track the latents and noise at the timesteps to be used for the final image
index_to_use = emerdiff_config["index_to_use"]
xts, zs, timesteps = inversion_forward_process(
    UNet,
    ddim,
    latent,
    new_attention_block.layer_dict["prompt_embedding"],
    etas = 1.0,
    num_inference_steps = emerdiff_config["inference_time_steps"],
    ddpm_seed = 42
)
xts = xts.unsqueeze(1)
# Keep the last index_to_use noisy latents; the slice stops before the final
# row, which is the clean latent appended by the inversion
xts = xts[-index_to_use-1:-1]
timesteps = timesteps[-index_to_use:]
zs = zs[-index_to_use:]

# Run the reverse process to get the de-noised latents at the timesteps to be used for the final image
with torch.no_grad():
    # Drop query/key pairs left over from an earlier run
    new_attention_block.layer_dict["original_kqv_attention"].clear()
    new_attention_block.layer_dict["record_kqv_attention"] = True
    try:
        original_latents = xts[0]
        # xts and timesteps were sliced to the same length, so iterating the
        # timesteps alone visits the same steps
        for i, t in enumerate(timesteps):
            # Plain int: tensors hash by identity, so a tensor timestep stored
            # here could not be looked up again from another loop
            new_attention_block.layer_dict["timestep"] = int(t)
            original_latents = reverse_diffusion_one_step(
                t = t,
                latents = original_latents,
                cond1 = new_attention_block.layer_dict["prompt_embedding"],
                z_ddpm = zs[i]
            )
    finally:
        # Always switch recording off again, even if a step fails
        new_attention_block.layer_dict["record_kqv_attention"] = False

In [None]:
# Now that we have the de-noised latents, we can modulate the attention maps for every mask to generate the final image
img_with_mask_id = []
for mask_id in range(emerdiff_config["num_mask"]):
    print(f"Generating the difference map for mask id: {mask_id + 1} / {emerdiff_config['num_mask']}")
    with torch.no_grad():
        # Here we generate the difference map for this mask id and store it in a list
        # Updating the mask: 1.0 where the cluster id matches, 0.0 elsewhere
        new_attention_block.layer_dict["mask"] = (mask_to_id_mapping == mask_id).float().to(xts.device)

        # Two copies of the starting latent: the hook perturbs one with lambda_1
        # and the other with lambda_2
        latent_to_modify = torch.cat([xts[0]]*2, dim = 0)
        # Setting the flag to use the recorded attention maps
        new_attention_block.layer_dict["use_recorded_kqv_attention"] = True
        try:
            # A separate loop variable keeps the mask id intact inside this loop
            for step, t in enumerate(timesteps):
                # Plain int so the key matches what was stored while recording
                new_attention_block.layer_dict["timestep"] = int(t)
                # Perturb the feature maps only at the configured timestep(s)
                new_attention_block.layer_dict["perturb_feature"] = (int(t) in emerdiff_config["inject_mask_time_stamp"])
                # Running the diffusion step
                latent_to_modify = reverse_diffusion_one_step(
                    t = t,
                    latents = latent_to_modify,
                    cond1 = new_attention_block.layer_dict["prompt_embedding"],
                    cond2 = new_attention_block.layer_dict["prompt_embedding"],
                    z_ddpm = zs[step]
                )
        finally:
            # Always reset the flags, even if a step fails
            new_attention_block.layer_dict["use_recorded_kqv_attention"] = False
            new_attention_block.layer_dict["perturb_feature"] = False
        x_modified = VAE.decode(latent_to_modify / emerdiff_config["multiplication_factor"]).sample

        # Adding blurring to the image
        x_modified = torchvision.transforms.functional.gaussian_blur(x_modified, kernel_size = 3)
        # Per-pixel distance between the two perturbed reconstructions
        difference_map = torch.linalg.norm(x_modified[0 : 1] - x_modified[1 : 2], dim = 1)
        img_with_mask_id.append(difference_map.cpu())

# Concatenating the difference maps and taking argmax to get the final image:
# each pixel is assigned to the mask whose perturbation changed it the most
segmented_image = torch.argmax(torch.cat(img_with_mask_id, dim = 0), dim = 0)
print(f"Shape of the segmented image: {segmented_image.shape}")

In [14]:
# cmap = plt.get_cmap('tab20', torch.max(segmented_image.flatten())+1)
# col_img = cmap(segmented_image)[:, :, :3]  # shape: (H, W, 3), values in [0,1]
# col_img = (col_img * 255).astype(np.uint8)
# plt.figure()
# plt.imshow(col_img)
# plt.axis('off')
# plt.show()

In [None]:
# Fixed seed so that every mask id keeps the same random colour between runs
np.random.seed(123)
palette = np.random.choice(range(256), size=(1000,3)).astype(np.uint8)

# Colour every pixel by its mask id
flat_ids = segmented_image.flatten()
col_img = palette[flat_ids].reshape(segmented_image.shape+(3,))

# Input image from [-1, 1] back to [0, 255], channels last
background = ((img[0]+1.0)/2.0*255.0).permute((1,2,0)).cpu().numpy()

# Blend: 80% segmentation colours, 20% input image
im = np.array(col_img)*0.8+background*0.2

overlay = Image.fromarray(im.astype(np.uint8))
plt.imshow(overlay)
plt.axis('off')
plt.show()

## Forward Process

TODO:
1. Describe the entire forward process here
2. Talk about VAE and noise addition according to the DDPM scheduler
3. Provide mathematical equations for this

Use of VAE:
- A meaningful, structured latent space
- The VAE maps high-dimensional images (e.g. 512×512×3) into lower-dimensional, semantic latent variables (e.g. 64×64×4).
- This latent space is:
  - Continuous and smooth
  - More semantically meaningful (similar images → nearby latents)
  - Easier for the diffusion model to learn in
- Without a VAE, the diffusion model would need to operate directly in image space, which is harder, slower, and requires more memory.
- Efficiency (Compression!)
  - Stable Diffusion doesn't run diffusion in pixel space — it uses a latent diffusion model (LDM).
  - Instead of operating on 512×512×3 (≈786K values), it works on 64×64×4 (≈16K values)
  - That’s ~50× smaller
- This allows:
  - Faster training and inference
  - Memory-efficient denoising steps

### Why do we need a probabilistic model and not a deterministic one?
In generative modeling (and especially diffusion models), we don’t just want a single latent representation — we want:

- A distribution over possible representations

- The ability to sample diverse outputs from the same input

- A smooth latent space where nearby points yield semantically similar images

Without a probabilistic latent space:
- Interpolating between images becomes meaningless or discontinuous
- Sampling new data becomes impossible (you can't sample from a deterministic encoder)
- Inversion (mapping image → latent) becomes unstable

## Reverse Process - Denoising using DDPM and EmerDiff modifications

### Explanation:
1. Here we add noise in the reverse process as proposed in the DDPM paper.
2. This helps model the true posterior distribution of the diffusion process.
3. If we were not to add the noise, the reverse process would be deterministic and so not a good approximation of the true posterior distribution.

TODO:
1. Add reference to DDPM paper
2. Explain the logic behind this 
3. Provide mathematical equations for this

## State of the art semantic segmentation model

In [None]:
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import torch

# Load model and processor
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Prepare image
inputs = processor(images=image_2, return_tensors="pt")

# Get predictions
# Inference only: no_grad avoids building an autograd graph for the forward pass
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits  # shape (batch_size, num_classes, height, width)

# Upsample the logits to the image size before taking the arg-max
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=image_2.size[::-1],  # PIL size is (width, height); reversed to (height, width)
    mode='bilinear',
    align_corners=False
)

# Predicted class id per pixel for the first (and only) image
pred_seg = upsampled_logits.argmax(dim=1)[0]

# Plot results
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image_2)
plt.title("Original Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(pred_seg.cpu())
plt.title("Segmentation Map") 
plt.axis('off')
plt.show()

In [None]:
# -*- coding: utf-8 -*-
"""EmerDiff example implementation.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Kl1tnHYo2OaUY9hdKFP4wfIpTe-WHaSp

This notebook provides the **minimal** re-implementation of "EmerDiff: Emerging Pixel-level Semantic Knowledge in Diffusion Models". **We are currently working on improving the notebook.** If you spot a bug, please let us know. For more details, check out our project page: https://kmcode1.github.io/Projects/EmerDiff/

# Set up environment
"""

# !pip install diffusers

# Commented out IPython magic to ensure Python compatibility.
# %matplotlib inline
import torch
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from math import sqrt
import torchvision
import torch.nn.functional as F
from diffusers import DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from collections import defaultdict
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

#DDPM-inversion
#Code mostly copied from: https://github.com/inbarhub/DDPM_inversion/blob/main/ddm_inversion/inversion_utils.py
#Reference paper: https://arxiv.org/pdf/2304.06140.pdf

def add_noise(scheduler, original_samples, next_timestep, rng): #add noise from prev_timestep -> next_timestep
    """Noise ``original_samples`` forward by one inference step.

    ``original_samples`` is treated as the sample at the inference step that
    precedes ``next_timestep``. The result is
    ``sqrt(alpha) * original_samples + sqrt(1 - alpha) * noise`` where
    ``alpha = alphas_cumprod[next_timestep] / alphas_cumprod[prev_timestep]``
    (``final_alpha_cumprod`` is used when ``prev_timestep`` is negative).

    Args:
        scheduler: Object exposing ``alphas_cumprod``, ``final_alpha_cumprod``,
            ``num_inference_steps`` and ``config.num_train_timesteps``.
        original_samples: Tensor to be noised. All work is done on its device.
        next_timestep: Timestep to noise towards.
        rng: ``torch.Generator`` used to draw the Gaussian noise.

    Returns:
        A tensor with the same shape as ``original_samples``.
    """
    # Follow the input's device instead of assuming a fixed GPU.
    device = original_samples.device
    noise = torch.randn(original_samples.shape, generator=rng).to(device)

    prev_timestep = next_timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
    alpha_prod_t_next = scheduler.alphas_cumprod[next_timestep]

    alpha = (alpha_prod_t_next / alpha_prod_t_prev).to(device)

    #ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py#L477
    # Append singleton axes so the coefficients broadcast over the samples.
    sqrt_alpha_prod = (alpha ** 0.5).flatten()
    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = ((1 - alpha) ** 0.5).flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise #x_{prev} -> x_{next}
    return noisy_samples

def sample_xts_from_x0(unet, scheduler, x0, num_inference_steps, rng):
    """Draw an independently noised copy of ``x0`` for every scheduler timestep.

    For each timestep ``t`` in ``scheduler.timesteps`` the entry at that
    timestep's position is ``sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise``
    with fresh noise from ``rng``. Noise is drawn while walking the timesteps
    in reversed order, so the draw order is part of the seeded result.

    Args:
        unet: Unused; kept for signature compatibility.
        scheduler: Object exposing ``alphas_cumprod`` and ``timesteps``.
        x0: Clean sample; its last three dimensions size the per-step buffer.
            It is concatenated along dim 0, so it must carry a leading axis.
        num_inference_steps: Number of rows allocated for the noised samples.
        rng: ``torch.Generator`` used for the noise.

    Returns:
        Tensor of the noised samples (one per timestep, ordered like
        ``scheduler.timesteps``) followed by ``x0`` itself, on ``x0.device``.
    """
    # Keep every tensor on the input's device rather than a fixed GPU.
    device = x0.device
    alpha_bar = scheduler.alphas_cumprod
    sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
    variance_noise_shape = (
            num_inference_steps,
            x0.shape[-3],
            x0.shape[-2],
            x0.shape[-1])

    timesteps = scheduler.timesteps.to(device)
    t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
    xts = torch.zeros(variance_noise_shape, device=device)
    for t in reversed(timesteps):
        step = int(t)  # plain int: usable as dict key and as index
        idx = t_to_idx[step]
        noise = torch.randn(x0.shape, generator=rng).to(device)
        xts[idx] = x0 * (alpha_bar[step] ** 0.5) + noise * sqrt_one_minus_alpha_bar[step]
    xts = torch.cat([xts, x0 ],dim = 0)

    return xts

def get_variance(scheduler, timestep):
    """Return the variance term for the step from ``timestep`` to its predecessor.

    The predecessor is ``timestep`` minus
    ``num_train_timesteps // num_inference_steps``; when that is negative,
    ``scheduler.final_alpha_cumprod`` stands in for its cumulative alpha.
    """
    cumprod = scheduler.alphas_cumprod
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    earlier = timestep - stride

    a_t = cumprod[timestep]
    if earlier >= 0:
        a_prev = cumprod[earlier]
    else:
        a_prev = scheduler.final_alpha_cumprod

    # (beta_prev / beta_t) * (1 - alpha_t / alpha_prev)
    return ((1 - a_prev) / (1 - a_t)) * (1 - a_t / a_prev)

def inversion_forward_process(unet, scheduler, x0,
                              uncond_embedding,
                              etas = 1.0,
                              num_inference_steps=50,
                              ddpm_seed = 40
                             ):
    """Recover the per-step noise maps that link a chain of noised latents.

    Noised versions of ``x0`` are sampled with ``sample_xts_from_x0``. Then,
    for every scheduler timestep, the UNet's noise prediction is used to form
    the posterior mean ``mu_xt`` of the next (less noisy) latent, and the
    noise ``z`` satisfying ``x_next = mu_xt + eta * sqrt(variance) * z`` is
    stored.

    Args:
        unet: Denoiser called as ``unet.forward(xt, timestep=t,
            encoder_hidden_states=...)``; its output must expose ``.sample``.
        scheduler: Object exposing ``timesteps``, ``alphas_cumprod``,
            ``final_alpha_cumprod``, ``num_inference_steps`` and
            ``config.num_train_timesteps``.
        x0: Clean latent; all buffers are created on its device.
        uncond_embedding: Conditioning passed to the UNet at every step.
        etas: Scalar noise scale applied to every step, or a list/tuple with
            one value per step. A zero value divides by zero below.
        num_inference_steps: Number of steps the buffers are sized for.
        ddpm_seed: Seed of the generator used to sample the noised latents.

    Returns:
        ``(xts, zs, timesteps)``: the noised latents followed by ``x0``, the
        recovered noise maps (one per timestep), and the timesteps tensor.
    """
    # Follow the latent's device instead of assuming a fixed GPU.
    device = x0.device
    timesteps = scheduler.timesteps.to(device)
    variance_noise_shape = (
        num_inference_steps,
        x0.shape[-3],
        x0.shape[-2],
        x0.shape[-1])
    rng = torch.Generator().manual_seed(ddpm_seed)

    # A scalar eta is shared by all steps; a sequence is used as given.
    if not isinstance(etas, (list, tuple)):
        etas = [etas]*scheduler.num_inference_steps
    #generate noisy samples xts
    xts = sample_xts_from_x0(unet, scheduler, x0, num_inference_steps=num_inference_steps, rng = rng)
    alpha_bar = scheduler.alphas_cumprod
    zs = torch.zeros(size=variance_noise_shape, device=device) #[50, 4, 64, 64]

    t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
    # Distance between consecutive inference timesteps (loop-invariant).
    timestep_stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps

    for t in timesteps:
        idx = t_to_idx[int(t)]

        xt = xts[idx][None]

        with torch.no_grad():
            out = unet.forward(xt, timestep =  t, encoder_hidden_states = uncond_embedding)

        noise_pred = out.sample

        xtm1 =  xts[idx+1][None]
        pred_original_sample = (xt - (1-alpha_bar[t])  ** 0.5 * noise_pred ) / alpha_bar[t] ** 0.5
        prev_timestep = t - timestep_stride
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
        variance = get_variance(scheduler, t)
        pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance ) ** (0.5) * noise_pred
        mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
        z = (xtm1 - mu_xt ) / ( etas[idx] * variance ** 0.5 )
        zs[idx] = z
        # Rebuild the next latent from mu_xt and z and write it back, so the
        # stored chain is exactly what the recovered noise reproduces.
        xtm1 = mu_xt + ( etas[idx] * variance ** 0.5 )*z
        xts[idx+1] = xtm1
    return xts, zs, timesteps

def reverse_step(scheduler, model_output, timestep, sample, eta, variance_noise):
    """Take one denoising step from ``timestep`` to the preceding inference step.

    Forms the clean-sample estimate from ``sample`` and the predicted noise
    ``model_output``, re-noises it to the previous step's level and, when
    ``eta > 0``, adds ``eta * sqrt(variance) * variance_noise``.
    """
    cumprod = scheduler.alphas_cumprod
    stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
    earlier = timestep - stride

    a_t = cumprod[timestep]
    if earlier >= 0:
        a_prev = cumprod[earlier]
    else:
        a_prev = scheduler.final_alpha_cumprod

    # Estimate of the clean sample implied by the predicted noise.
    x0_estimate = (sample - (1 - a_t) ** (0.5) * model_output) / a_t ** (0.5)

    variance = get_variance(scheduler, timestep)
    direction = (1 - a_prev - eta * variance) ** (0.5) * model_output
    denoised = a_prev ** (0.5) * x0_estimate + direction

    if eta > 0:
        denoised = denoised + eta * variance ** (0.5) * variance_noise

    return denoised

#utils
def load_img(path, resize_to_512 = False):
    """Load an image file as a ``[1, 3, H, W]`` float tensor scaled to [-1, 1].

    Args:
        path: Location of the image file.
        resize_to_512: When true the image is resized to exactly 512x512.
            Otherwise it is rescaled so its area is about 512*512 and each
            side is then rounded up to a multiple of 64.

    Returns:
        ``torch.Tensor`` built from the RGB pixels, mapped from [0, 255] to
        [-1, 1], with a leading batch axis of size 1.
    """
    # Close the file once the pixels are decoded; convert() returns a new
    # image, so nothing below needs the handle.
    with Image.open(path) as source:
        image = source.convert("RGB")
    w, h = image.size
    if resize_to_512:
        w, h = 512, 512
    else:
        fac = sqrt(512*512/w/h)
        w = int(w*fac)
        h = int(h*fac)
        # round each side up to an integer multiple of 64
        w = (w + 63) // 64 * 64
        h = (h + 63) // 64 * 64
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> 1CHW
    image = torch.from_numpy(image)
    return 2. * image - 1.

def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = flag

class Controller:
    """Shared state for the attention processors installed by ``inject_attention``.

    The boolean flags select what the hooks do on each attention call:
    accumulate query features for clustering, record query/key tensors,
    replay recorded query/key tensors, or add a masked offset to a layer's
    output. ``t`` is set by the caller before each UNet call and is part of
    every storage key.
    """
    def __init__(self, config):
        self.config = config

        self.record_mask_proposal = False #record feature maps for clustering
        self.record_attention = False #record original q/k/v
        self.use_recorded_attention = False #use recorded original attention to ensure better reconstruction
        self.perturb_feature = False #perturb feature map

        self.store = defaultdict(int) #for mask proposal: running sum of features per key
        self.store_cnt = defaultdict(int) #number of features summed per key
        self.orig = {} #for recording original q/k/v

        self.t = -1
        self.mask = None #must be assigned by the caller before perturb_feature is enabled

        self.h = self.w = 512 #resolution of image
        self.compress_rate = config.compress_rate #resolution of image/resolution of latent

        self.lam1 = config.lam1
        self.lam2 = config.lam2

    def extract_mask_feature(self):
        """Return the averaged recorded features and clear the accumulators.

        Each stored sum is divided by its count, the per-key means are stacked
        and then flattened per position, giving one row per spatial position.
        """
        fet = [v.unsqueeze(0)/self.store_cnt[k] for (k, v) in self.store.items()] #take mean
        self.store = defaultdict(int)
        self.store_cnt = defaultdict(int)
        fet = torch.cat(fet,dim=0) #(sample, h*w, 768)
        fet = fet.permute((1,0,2)) #(h*w, sample, 768)
        fet = fet.reshape((fet.shape[0], -1)) #(h*w, sample*768)

        return fet

    def store_attn(self, ty, pos, res, idx, query, key, val):
        """Record and/or overwrite the projected query/key of one attention layer.

        Args:
            ty, pos, res, idx: Layer identifier (type, position, resolution
                index, attention index).
            query, key, val: Projected tensors of the current call.

        Returns:
            ``(query, key, val)``; ``query`` and ``key`` are overwritten in
            place when ``use_recorded_attention`` is set.
        """
        layer = (ty, pos, res, idx)
        slot = (self.t,) + layer
        if self.record_mask_proposal and layer in self.config.record_mask_proposal_layers:
            # Only the first batch entry is accumulated, on the CPU.
            self.store[slot] += query[0].clone().detach().cpu() #[h*w, 768]
            self.store_cnt[slot] += 1
        if self.record_attention:  #save cross/self attention map
            self.orig[slot] = (query.detach().clone().cpu(), key.detach().clone().cpu())
        if self.use_recorded_attention: #preserve cross/self attention map
            query_, key_= self.orig[slot]
            # The first recorded batch entry is broadcast over the whole
            # batch, on whatever device the live tensors use.
            query[:] = query_[0].to(query.device)
            key[:] = key_[0].to(key.device)
        return query, key, val

    def modify_feature(self, hidden_states, ty, pos, res, idx, to_v = None):
        """Add ``lam * mask`` to a layer's output when perturbation is active.

        Nothing is changed unless ``perturb_feature`` is set and the layer is
        listed in ``config.perturb_feature_layers``. A batch of two receives
        ``lam1`` on the first entry and ``lam2`` on the second; any other
        batch size receives ``lam1``. ``to_v`` is unused.
        """
        if not (self.perturb_feature and ((ty, pos, res, idx) in self.config.perturb_feature_layers)):
            return hidden_states

        device = hidden_states.device
        original_shape = hidden_states.shape
        is_spatial = len(original_shape) == 4 #[2, 1280, 16, 16]
        if is_spatial:
            hidden_states = hidden_states.reshape((hidden_states.shape[0], hidden_states.shape[1], -1)).permute((0,2,1)) #[2,h*w,1280]
        mask = self.mask.reshape((1,-1,1)).to(device) #[1,h*w,1]
        offsets = [self.lam1, self.lam2] if hidden_states.shape[0] == 2 else [self.lam1]
        lam = torch.tensor(offsets, dtype=torch.float32, device=device).reshape(-1,1,1)
        hidden_states = hidden_states + lam*mask
        if is_spatial:
            hidden_states = hidden_states.permute((0,2,1)).reshape(original_shape)
        return hidden_states

def inject_attention(unet, controller):
    """Swap the UNet's ``AttnProcessor2_0`` processors for controller-aware ones.

    Walks the up, down and mid blocks of ``unet``. Inside every
    cross-attention block, each self-attention (``attn1``) and
    cross-attention (``attn2``) whose processor is an ``AttnProcessor2_0`` is
    given a replacement that hands the projected query/key/value to
    ``controller.store_attn`` and the projected output to
    ``controller.modify_feature``. Layers are identified by
    ``(ty, pos, res, idx)``.
    """
    #ref: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
    attention_block_names = ("CrossAttnUpBlock2D", "CrossAttnDownBlock2D", "UNetMidBlock2DCrossAttn")

    def new_forward_attention(ty = "cross", pos = "up", res=0, idx = 0):
        def forward(attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
            # Self-attention attends over the hidden states themselves.
            context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
            batch_size, _, _ = context.shape
            head_dim = hidden_states.shape[-1] // attn.heads

            query = attn.to_q(hidden_states) #(batch, seq_len, num_heads*head_dim)
            key = attn.to_k(context) #(batch, seq_len, num_heads*head_dim)
            value = attn.to_v(context) #(batch, seq_len, num_heads*head_dim)

            #store qkv
            query, key, value = controller.store_attn(ty, pos, res, idx, query, key, value)

            def split_heads(tensor):
                #(batch, seq_len, num_heads*head_dim) -> (batch, num_heads, seq_len, head_dim)
                return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            query = split_heads(query)
            key = split_heads(key)
            value = split_heads(value)

            attended = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask
            )

            #(batch, seq_len, num_heads*head_dim)
            attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            attended = attended.to(query.dtype)

            # linear proj
            projected = attn.to_out[0](attended)

            #perturb the output
            return controller.modify_feature(projected, ty, pos, res, idx, None)
        return forward

    def inject_block(blocks, pos):
        # `res` counts cross-attention blocks, `idx` counts transformers in a block.
        res = -1
        candidates = [blocks] if pos == "mid" else blocks.children()
        for net_ in candidates:
            if net_.__class__.__name__ not in attention_block_names:
                continue
            res += 1
            idx = -1
            for atn in net_.attentions:
                if atn.__class__.__name__ != "Transformer2DModel":
                    continue
                idx += 1
                for block in atn.transformer_blocks:
                    if block.__class__.__name__ != "BasicTransformerBlock":
                        continue
                    #self attention
                    if block.attn1.processor.__class__.__name__ == "AttnProcessor2_0":
                        block.attn1.processor = new_forward_attention(ty = "self", pos = pos, res = res, idx = idx)
                    #cross attention
                    if block.attn2.processor.__class__.__name__ == "AttnProcessor2_0":
                        block.attn2.processor = new_forward_attention(ty="cross", pos = pos, res = res, idx = idx)

    inject_block(unet.up_blocks, "up")
    inject_block(unet.down_blocks, "down")
    inject_block(unet.mid_block, "mid")

class Model:
    """Stable Diffusion components plus the attention controller.

    Loads the DDIM scheduler, VAE, CLIP tokenizer/text encoder and UNet named
    in ``config``, freezes their parameters, pre-computes the embedding of
    ``config.uncond_words`` and installs the controller's attention hooks on
    the UNet.
    """
    def __init__(self, config):
        self.config = config
        # Optional override; configs without a `device` attribute keep the
        # original "cuda:0" placement.
        self.device = getattr(config, "device", "cuda:0")

        #load model
        # from_pretrained is called on the class; no throwaway instance is needed.
        self.ddim_scheduler = DDIMScheduler.from_pretrained(config.stable_version, subfolder="scheduler", cache_dir=config.cache_dir)
        self.ddim_scheduler.set_timesteps(config.inference_steps, device=self.device)
        self.vae = AutoencoderKL.from_pretrained(config.stable_version, subfolder="vae",cache_dir=config.cache_dir).to(self.device)
        requires_grad(self.vae, False)
        self.tokenizer = CLIPTokenizer.from_pretrained(config.clip_version, cache_dir=config.cache_dir)
        self.text_encoder = CLIPTextModel.from_pretrained(config.clip_version, cache_dir=config.cache_dir).to(self.device)
        requires_grad(self.text_encoder, False)
        self.unet = UNet2DConditionModel.from_pretrained(config.stable_version, subfolder="unet", cache_dir=config.cache_dir).to(self.device)
        requires_grad(self.unet, False)

        self.uncond = self.get_text_embedding([config.uncond_words], device=self.device)

        #load attention controller and inject
        self.controller = Controller(config)
        inject_attention(self.unet, self.controller)

    def latent2tensor(self, latents):
        """Decode latents with the VAE after undoing the 0.18215 latent scaling."""
        x_samples = self.vae.decode(latents / 0.18215).sample #[1,3,512,512]
        return x_samples

    def tensor2latent(self, image):
        """Encode an image with the VAE and return the scaled posterior mean."""
        latents = self.vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    def get_text_embedding(self, prompt, device="cuda:0"):
        """Tokenize ``prompt`` (padded/truncated to the tokenizer's max length)
        and return the first output of the text encoder."""
        tokens = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = self.text_encoder(tokens.input_ids.to(device))[0] #[1, 77, 768]
        return text_embeddings

    def record_ddpm(self, z0, tims):
        """Run the UNet on ``z0`` noised to each timestep in ``tims``.

        Noise comes from a generator seeded with ``config.ddpm_seed``. A
        negative timestep means "use the clean latent": ``z0`` is fed
        unchanged at timestep 0. The UNet outputs are discarded; the call
        exists for the controller hooks, which see the timestep through
        ``controller.t``.
        """
        #Add noises to z0 under the shared seed at tims and run unet
        rng = torch.Generator().manual_seed(self.config.ddpm_seed) #shared noise
        for t in tims:
            if t < 0:
                # Previously this branch read a global and left `zt` unset.
                zt, t = z0, 0
            else:
                noise = torch.randn(z0.shape, generator=rng).to(z0.device)
                zt = self.ddim_scheduler.add_noise(z0, noise=noise, timesteps = torch.tensor([t]).int())
            self.controller.t = t
            self.unet(zt, t, encoder_hidden_states = self.uncond)

    def diffusion_step(self, t, latents, cond1, cond2 = None, z_ddpm = None):
        """Denoise ``latents`` by one step (x_t -> x_{t-1}) with scheduled noise.

        Args:
            t: Current timestep.
            latents: Latents at timestep ``t``.
            cond1: Conditioning embedding.
            cond2: Optional second embedding, concatenated after ``cond1``
                along the batch axis.
            z_ddpm: Noise map for this step; expanded to the latent batch size.

        Raises:
            ValueError: If ``z_ddpm`` is not supplied.
        """
        if z_ddpm is None:
            raise ValueError("diffusion_step requires z_ddpm (the scheduled noise for this step)")
        if cond2 is not None:
            text_embeddings = torch.cat([cond1, cond2])
        else:
            text_embeddings = cond1
        noise_pred = self.unet(
                                latents,
                                t,
                                encoder_hidden_states=text_embeddings
                            ).sample
        z_ddpm = z_ddpm.expand(latents.shape[0],-1,-1,-1) #scheduled noise
        return reverse_step(self.ddim_scheduler, noise_pred, t, latents, 1.0, z_ddpm) #ddpm_inversion

"""# Config"""

class ConfigBase:
    """Default settings for the model, mask proposal and feature extraction."""

    # --- Model ---
    stable_version = "CompVis/stable-diffusion-v1-4"
    clip_version = "openai/clip-vit-large-patch14"
    cache_dir = "./"
    compress_rate = 32  # (input image resolution) / (cross-attention resolution)
    vae_compress_rate = 8  # (input image resolution) / (latent resolution)

    # --- Method ---
    inference_steps = 50  # total inference steps

    # --- Mask proposal ---
    lam1 = -10  # negative offset
    lam2 = 10  # positive offset
    num_mask = 15  # number of segmentation masks
    k_means_seed = 1234
    uncond_words = ''
    n_init = 100  # use "auto" in practice
    init_algo = "k-means++"
    kmeans_algo = "lloyd"
    # Timesteps for mask proposal. Sampling several times helps convergence;
    # reduce the count in practice.
    tims_mask = [0] * 100
    # Timesteps at which feature maps are modulated (i.e. offsets injected).
    inject_offset_tims = [281]

    # --- Feature extraction ---
    ddpm_seed = 42  # the DDPM noise is shared

    # --- Palette (todo) ---
    # Seeds NumPy's global RNG so the colors are the same on every run.
    np.random.seed(123)
    palette = np.random.choice(range(256), size=(1000, 3)).astype(np.uint8)

    # --- Layers, each identified as (ty, pos, res, idx) ---
    record_mask_proposal_layers = [("cross", "up", 0, 0)]  # layers to record feature maps
    perturb_feature_layers = [("cross", "up", 0, 2)]  # layers to inject noise

"""# Main pipeline"""

class MaskProposal:
    """Compute low- and high-resolution segmentation ids for one image latent.

    All work happens in ``__init__``:
      1. record attention query features and cluster them with K-means into
         ``ids_hidden`` (one cluster id per recorded position);
      2. run DDPM inversion and record the original attention tensors;
      3. for each cluster, run the denoising loop twice with opposite offsets
         on that cluster's mask and measure the per-pixel difference; every
         pixel is assigned the cluster with the largest difference
         (``ids_img``).
    """
    num_mask = None #number of masks
    ids_img = None #high-resolution segmentation maps: e.g. [512*512] -> mask_id
    ids_hidden = None #low-resolution segmentation maps: e.g. [16*16] -> mask_id

    def save_cluster(self, clusters, filename):
        """Color ``clusters`` with the config palette, resize to ``(w, h)``
        with nearest-neighbour sampling and save to ``filename``."""
        col = self.config.palette[clusters.flatten()].reshape(clusters.shape+(3,))
        label = Image.fromarray(col).resize((self.w, self.h), Image.NEAREST)
        label.save(filename)

    def __init__(self, config, model, init_latent):
        """Run the full mask-proposal pipeline.

        Args:
            config: Settings object (see ``ConfigBase``).
            model: ``Model`` whose UNet carries the controller hooks.
            init_latent: Scaled VAE latent of the input image.
        """

        self.config = config
        self.h = init_latent.shape[-2]*config.vae_compress_rate
        self.w = init_latent.shape[-1]*config.vae_compress_rate
        # Follow the latent's device instead of assuming a fixed GPU.
        device = init_latent.device

        #Step 1. extract mask features maps for clustering
        model.controller.record_mask_proposal = True
        model.record_ddpm(init_latent, config.tims_mask)
        model.controller.record_mask_proposal = False

        #Step 2. cluster feature maps to produce low-resolution segmentation maps
        print("Perform K-means and generate low resolution masks..")
        mask_fet = model.controller.extract_mask_feature() #one row per recorded position
        kmeans = KMeans(n_clusters=config.num_mask, init=config.init_algo, n_init=config.n_init, random_state = config.k_means_seed, algorithm=config.kmeans_algo).fit(mask_fet.numpy())
        #self.save_cluster(kmeans.labels_.reshape((self.h//config.compress_rate, self.w//config.compress_rate)), "./mask_proposal.png")
        self.ids_hidden = torch.from_numpy(kmeans.labels_)

        self.num_mask = self.ids_hidden.max()+1

        #Perform DDPM inversion
        print("Perform DDPM inversion..")
        # Number of trailing inference steps needed to reach the largest
        # injection timestep. The stride is derived from the scheduler (it was
        # a literal 20, i.e. 1000 training steps // 50 inference steps) so that
        # other step counts select the right slice.
        scheduler = model.ddim_scheduler
        timestep_stride = scheduler.config.num_train_timesteps // scheduler.num_inference_steps
        invert_until = int(max(config.inject_offset_tims)//timestep_stride+1)
        xts, zs, ts = inversion_forward_process(model.unet, model.ddim_scheduler, init_latent, model.uncond, 1.0, config.inference_steps,  ddpm_seed = config.ddpm_seed)
        xts = xts.unsqueeze(1)
        xts = xts[-invert_until-1:-1] #excludes x0
        ts = ts[-invert_until:]
        zs = zs[-invert_until:]
        original_latent_t = list(zip(xts, ts))

        #Store original attention maps
        with torch.no_grad():
            model.controller.record_attention = True
            original_latents = original_latent_t[0][0]
            for i, (_, t) in enumerate(original_latent_t):
                model.controller.t = t
                original_latents = model.diffusion_step(t, original_latents, model.uncond, z_ddpm = zs[i])
            model.controller.record_attention = False

        #Step 3. Now, perform modulated diffusion process for each low-resolution mask
        print("Now generating pixel-level segmentation maps..")
        def sample_loop(latents):
            # Two copies of the latent: the controller applies lam1 to the
            # first batch entry and lam2 to the second.
            latents = torch.cat([latents]*2,dim=0)
            model.controller.use_recorded_attention = True
            with torch.no_grad():
                #Modulated denoising loop
                for i, (_, t) in enumerate(original_latent_t):
                    model.controller.t = t
                    model.controller.perturb_feature = (t in config.inject_offset_tims) #Modify attention map
                    latents = model.diffusion_step(t, latents, model.uncond, model.uncond, z_ddpm = zs[i])
            model.controller.use_recorded_attention = False
            model.controller.perturb_feature = False
            x_edited = model.latent2tensor(latents) #[2,3,512,512]
            return x_edited

        def generate_dif_map(cluster_id):
            print("Peform modulated denoising process for mask id:", cluster_id)
            # Flat 0/1 mask over the recorded positions for this cluster.
            model.controller.mask = (self.ids_hidden == cluster_id).float().to(device)
            x_edited = sample_loop(original_latent_t[0][0])
            x_edited = torchvision.transforms.functional.gaussian_blur(x_edited, kernel_size=3)
            # Per-pixel distance between the two differently offset decodes.
            dif_map = torch.linalg.norm(x_edited[0:1]-x_edited[1:2], dim=1)
            return dif_map #[1,512,512]

        lis_img = []
        for i in range(config.num_mask):
            with torch.no_grad():
                lis_img.append(generate_dif_map(i).cpu())
        self.ids_img = torch.argmax(torch.cat(lis_img,dim=0),dim=0) #[HxW]

"""# Load image"""

# Script-friendly equivalent of the notebook shell command:
# !wget https://raw.githubusercontent.com/NVlabs/ODISE/main/demo/examples/coco.jpg
os.system("wget https://raw.githubusercontent.com/NVlabs/ODISE/main/demo/examples/coco.jpg")

# The downloaded sample is only used when the line below is enabled;
# otherwise `img` from the earlier dataset cell is processed.
# img = load_img(path="./coco.jpg", resize_to_512 = True).cuda()
print(img.min(), img.max())

# Map the [-1, 1] tensor back to an HxWx3 array in [0, 255] for display.
input_rgb = ((img[0]+1.0)/2.0*255.0).permute((1,2,0)).cpu().numpy()
plt.figure()
plt.imshow(input_rgb.astype(np.uint8))
plt.axis('off')

"""# Generate segmentation masks"""

config = ConfigBase()
model = Model(config)
init_latent = model.tensor2latent(img)
mask_proposal = MaskProposal(config, model, init_latent)

"""# Visualize image-level segmentation maps"""

#Visualize pixel-level segmentation maps
# Color every pixel by its mask id, then blend 80% mask color with 20% input image.
pixel_ids = mask_proposal.ids_img
col_img = config.palette[pixel_ids.flatten()].reshape(pixel_ids.shape+(3,))
im = np.array(col_img)*0.8+input_rgb*0.2
plt.imshow(Image.fromarray(im.astype(np.uint8)))
plt.axis('off')