In [1]:
import torch
import torch.nn as nn
from diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel, DDPMScheduler


class TauEncoder(nn.Module):
    """
    Learnable encoder for the input RGB image (tau_theta).
    Architecturally same as the VAE encoder but trainable.
    """
    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        # Copy VAE encoder structure
        self.encoder = vae.encoder
        self.quant_conv = vae.quant_conv

    def forward(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        z = mean  # no sampling, deterministic
        return z


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LDM_Segmentor(nn.Module):
    def __init__(self, pretrained_vae="CompVis/stable-diffusion-v1-4", scheduler_steps=1000, device="cuda"):
        super().__init__()
        self.device = device

        # Load frozen VAE
        self.vae = AutoencoderKL.from_pretrained(pretrained_vae, subfolder="vae").eval().to(device)
        for p in self.vae.parameters():
            p.requires_grad = False

        # Learnable encoder for input image (τ_θ)
        self.image_encoder = TauEncoder(self.vae).to(device)

        # Diffusion U-Net (8 channels in: 4 noisy mask + 4 image encoding)
        self.unet = UNet2DModel(
            sample_size=32,
            in_channels=8,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(128, 256, 256, 512),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
        ).to(device)

        # Scheduler (adds noise and steps)
        self.scheduler = DDPMScheduler(num_train_timesteps=scheduler_steps)

    def forward(self, image, mask, t):
        """
        Forward pass for training.
        image: (B, 3, 256, 256) → Input RGB image
        mask : (B, 3, 256, 256) → Binary mask (float in [-1, 1])
        t    : (B,)             → Timestep tensor for noise
        """
        # --- Step 1: Mask → VAE encoder (frozen)
        with torch.no_grad():
            posterior = self.vae.encode(mask).latent_dist
            z0 = posterior.sample() * 0.18215  # scaled latent

        # --- Step 2: Add noise to z0 using scheduler → zt
        noise = torch.randn_like(z0)
        zt = self.scheduler.add_noise(z0, noise, t)

        # --- Step 3: Image → Tau encoder → zc
        zc = self.image_encoder(image) * 0.18215

        # --- Step 4: Concatenate and denoise
        zt_cat = torch.cat([zt, zc], dim=1)  # (B, 8, 32, 32)
        noise_pred = self.unet(zt_cat, t).sample

        # --- Step 5: Decode z0_hat to mask
        with torch.no_grad():
            z0_hat   = self.scheduler.step(noise_pred, t, zt).prev_sample
            mask_hat = self.vae.decode(z0_hat / 0.18215).sample

        return {
            "z0": z0,
            "zt": zt,
            "zc": zc,
            "noise_pred": noise_pred,
            "z0_hat": z0_hat,
            "mask_hat": mask_hat
        }

In [4]:

# Dummy forward pass test
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LDM_Segmentor(device=device)

B, C, H, W = 1, 3, 256, 256
image = torch.randn(B, C, H, W).to(device)
mask  = (torch.rand(B, C, H, W) > 0.5).float().to(device)
t     = torch.randint(0, model.scheduler.config.num_train_timesteps, (B,), device=device)

outputs = model(image, mask, t)
outputs.keys()


dict_keys(['z0', 'zt', 'zc', 'noise_pred', 'z0_hat', 'mask_hat'])

In [2]:
class LDM_Segmentor_CrossAttention(nn.Module):
    """
    LDM-based segmentation model using cross-attention from UNet2DConditionModel.
    """

    def __init__(self, device="cuda", latent_scale=0.18215, num_inference_steps=1000):
        super().__init__()
        self.device = torch.device(device)
        self.latent_scale = latent_scale
        self.cross_attn_proj = nn.Linear(4, 768).to(device)  # Match SD1.4's cross_attention_dim

        # ------------------------------
        # Load pretrained VAE (Frozen)
        # ------------------------------
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device).eval()
        for p in self.vae.parameters():
            p.requires_grad = False

        # ------------------------------
        # Learnable Tau encoder
        # ------------------------------
        self.tau = TauEncoder(self.vae).to(self.device)

        # ------------------------------
        # Load pretrained U-Net (CrossAttention enabled)
        # ------------------------------
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(self.device)
        self.unet.train()  # Learnable during segmentation

        # ------------------------------
        # Scheduler (e.g., DDPM)
        # ------------------------------
        self.scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps)
        
    def forward(self, image: torch.Tensor, mask: torch.Tensor, t: torch.Tensor):
        """
        Args:
            image (torch.Tensor): Input RGB image in [-1, 1], shape (B, 3, 256, 256)
            mask (torch.Tensor): Input GT mask in [-1, 1], shape (B, 3, 256, 256)
            t (torch.Tensor): Timesteps, shape (B,)
        Returns:
            dict: All intermediate latents and predicted mask.
        """

        # Step 1: Encode mask into latent z0 using frozen VAE
        with torch.no_grad():
            z0 = self.vae.encode(mask).latent_dist.sample() * self.latent_scale

        # Step 2: Add noise to z0 using scheduler → zt
        noise = torch.randn_like(z0)
        zt = self.scheduler.add_noise(z0, noise, t)

        # Step 3: Encode image into conditioning vector using Tau encoder
        z_c = self.tau(image)  # (B, 4, 32, 32)

        # Step 4: Reshape z_c to (B, HW, C) for cross-attention
        B, C, H, W = z_c.shape
        cross_attn = z_c.view(B, C, -1).permute(0, 2, 1)  # (B, HW, 4)
        cross_attn = self.cross_attn_proj(cross_attn)     # (B, HW, 768)

        # Step 5: Predict noise residual using cross-attention U-Net
        noise_pred = self.unet(sample = zt, timestep = t, encoder_hidden_states = cross_attn).sample

        # Step 6: Estimate denoised latent z0_hat using scheduler
        with torch.no_grad():
            z0_hat = self.scheduler.step(noise_pred, t, zt).prev_sample
            mask_hat = self.vae.decode(z0_hat / self.latent_scale).sample

        return {
            "z0": z0,
            "zt": zt,
            "z_c": z_c,
            "z0_hat": z0_hat,
            "mask_hat": mask_hat
        }


In [3]:

# Dummy forward pass test
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LDM_Segmentor_CrossAttention(device = device)

B, C, H, W = 1, 3, 256, 256
image = torch.randn(B, C, H, W).to(device)
mask  = (torch.rand(B, C, H, W) > 0.5).float().to(device)
t     = torch.randint(0, model.scheduler.config.num_train_timesteps, (B,), device=device)

outputs = model(image, mask, t)
outputs.keys()


dict_keys(['z0', 'zt', 'z_c', 'z0_hat', 'mask_hat'])

In [5]:
outputs['mask_hat'].shape

torch.Size([1, 3, 256, 256])

In [1]:
from diffusers import UNet2DConditionModel
import torch
import torch.nn as nn

def load_hybrid_unet(pretrained_model_path: str, device: str = "cuda") -> UNet2DConditionModel:
    """
    Loads a pretrained UNet2DConditionModel and updates in_channels from 4 → 8.
    The first 4 channels are copied from the pretrained model and the remaining 4 are randomly initialized.

    Returns:
        Modified UNet2DConditionModel with in_channels=8
    """
    # Load the pretrained model (in_channels=4)
    unet_pretrained = UNet2DConditionModel.from_pretrained(
        pretrained_model_path,
        subfolder="unet",
        torch_dtype=torch.float16
    ).to(device)

    # Deep copy config and modify in_channels
    config = unet_pretrained.config
    config.in_channels = 8  # <- THIS is critical

    # Reinitialize the model with new config
    unet_hybrid = UNet2DConditionModel.from_config(config).to(device)

    # Get pretrained weights and new model weights
    pretrained_sd = unet_pretrained.state_dict()
    hybrid_sd     = unet_hybrid.state_dict()

    # Copy matching keys except conv_in.weight
    for key in hybrid_sd:
        if key != "conv_in.weight" and key in pretrained_sd and hybrid_sd[key].shape == pretrained_sd[key].shape:
            hybrid_sd[key] = pretrained_sd[key]

    # Handle conv_in.weight separately
    old_conv_weight = pretrained_sd["conv_in.weight"]  # Shape: [320, 4, 3, 3]
    out_channels, _, kH, kW = old_conv_weight.shape

    new_conv_weight = torch.zeros((out_channels, 8, kH, kW), dtype=old_conv_weight.dtype).to(device)

    # Copy first 4 channels from pretrained
    new_conv_weight[:, :4, :, :] = old_conv_weight

    # Random init for the additional 4 channels
    nn.init.kaiming_normal_(new_conv_weight[:, 4:, :, :])

    # Update the state dict
    hybrid_sd["conv_in.weight"] = new_conv_weight

    # Load the updated state dict
    unet_hybrid.load_state_dict(hybrid_sd)

    return unet_hybrid.eval().requires_grad_(True)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
unet = load_hybrid_unet("CompVis/stable-diffusion-v1-4")
print(unet.conv_in.weight.shape)  # Should be torch.Size([320, 8, 3, 3])
