In [1]:
from config import *
from architectures import * 



In [None]:
device     = 'cpu'
model      = LDM_Segmentor(pretrained_vae="CompVis/stable-diffusion-v1-4", scheduler_steps=1000, device="cpu").to(device)

model.eval()


In [7]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
trainable_params

67481604

In [7]:
import torch
import copy
import torch.nn as nn
from   diffusers import AutoencoderKL, UNet2DConditionModel, UNet2DModel, DDPMScheduler

torch.cuda.empty_cache()

class TauEncoder(nn.Module):
    """
    Learnable encoder for the input RGB image (tau_theta).
    Architecturally same as the VAE encoder but trainable.
    """
    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        # Copy VAE encoder structure
        self.encoder = vae.encoder
        self.quant_conv = vae.quant_conv

    def forward(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        z = mean  # no sampling, deterministic
        return z

In [None]:
class LDM_Segmentor(nn.Module):
    def __init__(self, pretrained_vae="CompVis/stable-diffusion-v1-4", scheduler_steps=1000, device="cuda"):
        super().__init__()
        self.device = device

        # Load frozen VAE
        self.vae = AutoencoderKL.from_pretrained(pretrained_vae, subfolder="vae").eval().to(device)
        for p in self.vae.parameters():
            p.requires_grad = False

        # Learnable encoder for input image (τ_θ)
        self.image_encoder = TauEncoder(self.vae).to(device)

        # Diffusion U-Net (8 channels in: 4 noisy mask + 4 image encoding)
        self.unet = UNet2DModel(
            sample_size=32,
            in_channels=8,
            out_channels=4,
            layers_per_block=2,
            block_out_channels=(128, 256, 256, 512),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
        ).to(device)

        # Scheduler (adds noise and steps)
        self.scheduler = DDPMScheduler(num_train_timesteps=scheduler_steps)

    def forward(self, image, mask, t):
        """
        Forward pass for training.
        image: (B, 3, 256, 256) → Input RGB image
        mask : (B, 3, 256, 256) → Binary mask (float in [-1, 1])
        t    : (B,)             → Timestep tensor for noise
        """
        # --- Step 1: Mask → VAE encoder (frozen)
        with torch.no_grad():
            posterior = self.vae.encode(mask).latent_dist
            z0 = posterior.sample() * 0.18215  # scaled latent # (B, 4, 32, 32)

        # --- Step 2: Add noise to z0 using scheduler → zt
        noise = torch.randn_like(z0)
        zt = self.scheduler.add_noise(z0, noise, t)

        # --- Step 3: Image → Tau encoder → zc
        zc = self.image_encoder(image) * 0.18215

        # --- Step 4: Concatenate and denoise
        zt_cat     = torch.cat([zt, zc], dim=1)  # (B, 8, 32, 32)
        noise_pred = self.unet(zt_cat, t).sample # (B, 4, 32, 32)

        # --- Step 5: Decode z0_hat to mask
        with torch.no_grad():
            z0_hat_list = []
            mask_hat_list = []
            for batch_idx in range(image.shape[0]):
                z0_hat   = self.scheduler.step(noise_pred[batch_idx].unsqueeze(0), t[batch_idx].unsqueeze(0), zt[batch_idx].unsqueeze(0)).pred_original_sample
                mask_hat = self.vae.decode(z0_hat / 0.18215).sample
                z0_hat_list.append(z0_hat)
                mask_hat_list.append(mask_hat)

            z0_hat = torch.cat(z0_hat_list, dim=0)  # (B, 4, 32, 32)
            mask_hat = torch.cat(mask_hat_list, dim=0)

        return {
            "z0": z0,
            "zt": zt,
            "zc": zc,
            "noise_pred": noise_pred,
            "z0_hat": z0_hat,
            "mask_hat": mask_hat
        }

In [35]:

# Dummy forward pass test
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LDM_Segmentor(device=device)

B, C, H, W = 2, 3, 256, 256
image = torch.randn(B, C, H, W).to(device)
mask  = (torch.rand(B, C, H, W) > 0.5).float().to(device)
t     = torch.randint(0, model.scheduler.config.num_train_timesteps, (B,), device=device)

outputs = model(image, mask, t)
outputs.keys()


torch.Size([1, 4, 32, 32])
First loop done!
torch.Size([1, 4, 32, 32])
First loop done!


dict_keys(['z0', 'zt', 'zc', 'noise_pred', 'z0_hat', 'mask_hat'])

In [2]:
class LDM_Segmentor_CrossAttention(nn.Module):
    """
    LDM-based segmentation model using cross-attention from UNet2DConditionModel.
    """

    def __init__(self, device="cuda", latent_scale=0.18215, num_inference_steps=1000):
        super().__init__()
        self.device = torch.device(device)
        self.latent_scale = latent_scale
        self.cross_attn_proj = nn.Linear(4, 768).to(device)  # Match SD1.4's cross_attention_dim

        # ------------------------------
        # Load pretrained VAE (Frozen)
        # ------------------------------
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device).eval()
        for p in self.vae.parameters():
            p.requires_grad = False

        # ------------------------------
        # Learnable Tau encoder
        # ------------------------------
        self.tau = TauEncoder(self.vae).to(self.device)

        # ------------------------------
        # Load pretrained U-Net (CrossAttention enabled)
        # ------------------------------
        self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(self.device)
        self.unet.train()  # Learnable during segmentation

        # ------------------------------
        # Scheduler (e.g., DDPM)
        # ------------------------------
        self.scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps)
        
    def forward(self, image: torch.Tensor, mask: torch.Tensor, t: torch.Tensor):
        """
        Args:
            image (torch.Tensor): Input RGB image in [-1, 1], shape (B, 3, 256, 256)
            mask (torch.Tensor): Input GT mask in [-1, 1], shape (B, 3, 256, 256)
            t (torch.Tensor): Timesteps, shape (B,)
        Returns:
            dict: All intermediate latents and predicted mask.
        """

        # Step 1: Encode mask into latent z0 using frozen VAE
        with torch.no_grad():
            z0 = self.vae.encode(mask).latent_dist.sample() * self.latent_scale

        # Step 2: Add noise to z0 using scheduler → zt
        noise = torch.randn_like(z0)
        zt = self.scheduler.add_noise(z0, noise, t)

        # Step 3: Encode image into conditioning vector using Tau encoder
        z_c = self.tau(image)  # (B, 4, 32, 32)

        # Step 4: Reshape z_c to (B, HW, C) for cross-attention
        B, C, H, W = z_c.shape
        cross_attn = z_c.view(B, C, -1).permute(0, 2, 1)  # (B, HW, 4)
        cross_attn = self.cross_attn_proj(cross_attn)     # (B, HW, 768)

        # Step 5: Predict noise residual using cross-attention U-Net
        noise_pred = self.unet(sample = zt, timestep = t, encoder_hidden_states = cross_attn).sample

        # Step 6: Estimate denoised latent z0_hat using scheduler
        with torch.no_grad():
            z0_hat = self.scheduler.step(noise_pred, t, zt).prev_sample
            mask_hat = self.vae.decode(z0_hat / self.latent_scale).sample

        return {
            "z0": z0,
            "zt": zt,
            "z_c": z_c,
            "z0_hat": z0_hat,
            "mask_hat": mask_hat
        }


In [None]:

# Dummy forward pass test
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LDM_Segmentor_CrossAttention(device = device)

B, C, H, W = 1, 3, 256, 256
image = torch.randn(B, C, H, W).to(device)
mask  = (torch.rand(B, C, H, W) > 0.5).float().to(device)
t     = torch.randint(0, model.scheduler.config.num_train_timesteps, (B,), device=device)

outputs = model(image, mask, t)
outputs.keys()


In [2]:
def load_hybrid_unet(pretrained_path: str, device: str = "cuda") -> UNet2DConditionModel:
    # Step 1: Load the original model to access weights and config
    unet_orig = UNet2DConditionModel.from_pretrained(
        pretrained_path,
        subfolder="unet",
        torch_dtype=torch.float16
    ).to(device)

    # Step 2: Create a deep copy of the config and modify in_channels
    config = copy.deepcopy(unet_orig.config)  # Avoid modifying the original config
    config['in_channels'] = 8  # Update in_channels to 8

    # Step 3: Create a new UNet model with modified in_channels
    unet_new = UNet2DConditionModel(**config).to(device, dtype=torch.float16)

    # Step 4: Get the original and new state dicts
    orig_sd = unet_orig.state_dict()
    new_sd = unet_new.state_dict()

    # Step 5: Initialize new conv_in weights manually
    old_conv = orig_sd["conv_in.weight"]  # Shape [320, 4, 3, 3]
    out_ch, _, kH, kW = old_conv.shape

    # New conv_in.weight: [320, 8, 3, 3]
    new_conv = torch.zeros((out_ch, 8, kH, kW), dtype=torch.float16, device=device)
    new_conv[:, :4, :, :] = old_conv  # Copy pretrained channels
    nn.init.kaiming_normal_(new_conv[:, 4:, :, :], mode='fan_out', nonlinearity='leaky_relu')  # Random init rest

    # Step 6: Update the state dict with the new conv_in.weight
    new_sd["conv_in.weight"] = new_conv

    # Step 7: Copy compatible weights from the original model
    for key in new_sd:
        if key != "conv_in.weight" and key in orig_sd and new_sd[key].shape == orig_sd[key].shape:
            new_sd[key] = orig_sd[key]

    # Step 8: Load updated state dict into the new model
    unet_new.load_state_dict(new_sd)

    return unet_new.requires_grad_(True)

In [None]:
unet = load_hybrid_unet("CompVis/stable-diffusion-v1-4")
print(unet.conv_in.weight.shape)  # Should be torch.Size([320, 8, 3, 3])


In [7]:
inp = torch.randn(2, 8, 64, 64).to(torch.float16).to('cuda')
emb = torch.randn(2, 77, 768).to(torch.float16).to('cuda')
ts = 999.0
output = unet(inp, ts, encoder_hidden_states=emb)

In [3]:
class LDM_Segmentor_Concatenation(nn.Module):
    """
    LDM-based segmentation self.unet using Concatenation from UNet2DConditionself.unet with additional parameters randomly initalized.
    """

    def __init__(self, device="cuda", latent_scale=0.18215, num_inference_steps=1000):
        super().__init__()
        self.device = torch.device(device)
        self.latent_scale = latent_scale

        # ------------------------------
        # Load pretrained VAE (Frozen)
        # ------------------------------
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device).eval()
        for p in self.vae.parameters():
            p.requires_grad = False

        # ------------------------------
        # Learnable Tau encoder
        # ------------------------------
        self.tau = TauEncoder(self.vae).to(self.device)

        # ------------------------------
        # Load pretrained U-Net (CrossAttention enabled)
        # ------------------------------
        self.unet = load_hybrid_unet("CompVis/stable-diffusion-v1-4")
        self.unet.train()  # Learnable during segmentation

        # ------------------------------
        # Scheduler (e.g., DDPM)
        # ------------------------------
        self.scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps)
        
    def forward(self, image: torch.Tensor, mask: torch.Tensor, t: torch.Tensor):
        """
        Args:
            image (torch.Tensor): Input RGB image in [-1, 1], shape (B, 3, 256, 256)
            mask (torch.Tensor): Input GT mask in [-1, 1], shape (B, 3, 256, 256)
            t (torch.Tensor): Timesteps, shape (B,)
        Returns:
            dict: All intermediate latents and predicted mask.
        """
        
        print(self.unet.conv_in.weight.dtype)
        print(self.unet.conv_in.bias.dtype)
        print(image.dtype)  # Must be float16 too

        
        # Step 1: Encode mask into latent z0 using frozen VAE
        with torch.no_grad():
            z0 = self.vae.encode(mask).latent_dist.sample() * self.latent_scale

        # Step 2: Add noise to z0 using scheduler → zt
        noise = torch.randn_like(z0)
        zt = self.scheduler.add_noise(z0, noise, t)

        # Step 3: Encode image into conditioning vector using Tau encoder
        zc = self.tau(image) * self.latent_scale  # (B, 4, 32, 32)

        # --- Step 4: Concatenate
        zt_cat = torch.cat([zt, zc], dim=1).to(device, dtype = torch.float16)  # (B, 8, 32, 32)

        # Step 5: Predict noise residual using dummy tensor of 0's in encoder_hidden_states (instead of text emneddings)
        dummy_embed = torch.zeros(mask.shape[0], 77, 768).to(dtype=torch.float16, device=device)
        noise_pred = self.unet(sample = zt_cat, timestep = t, encoder_hidden_states = dummy_embed).sample

        # Step 6: Estimate denoised latent z0_hat using scheduler
        with torch.no_grad():
            z0_hat = self.scheduler.step(noise_pred, t, zt).prev_sample
            mask_hat = self.vae.decode(z0_hat / self.latent_scale).sample

        return {
            "z0": z0,
            "zt": zt,
            "zc": zc,
            "z0_hat": z0_hat,
            "mask_hat": mask_hat
        }


In [4]:

# Dummy forward pass test
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LDM_Segmentor_Concatenation(device = device)

B, C, H, W = 1, 3, 256, 256
image = torch.randn(B, C, H, W).to(device, dtype = torch.float16)
mask  = (torch.rand(B, C, H, W) > 0.5).to(device, dtype = torch.float16)
t     = torch.randint(0, model.scheduler.config.num_train_timesteps, (B,), device=device, dtype = torch.float16)

outputs = model(image, mask, t)
outputs.keys()


torch.float16
torch.float16
torch.float16


RuntimeError: Input type (c10::Half) and bias type (float) should be the same