In [1]:
# C:\Users\Talha\AppData\Local\Programs\Python\Python312\Scripts>

import numpy as np
import torch
import os
from torchinfo import summary
from networks.novel.lip_resnet.lip_resnet_encoder import *

In [3]:
# Load the BUSI train/test name split stored as a pickle.
# NOTE(review): `allow_pickle=True` unpickles the file, which can execute arbitrary code —
# only load files you trust. The path is absolute and machine-specific.
pickle_file_load = np.load('/Users/talhaahmed/Library/CloudStorage/OneDrive-HigherEducationCommission/Integration/GitHub/Diffusion-Codes/GMS/Dataset/busi/busi_train_test_names.pkl', allow_pickle = True)


In [10]:
# List the entries of the BUSI `masks` folder and count them.
# NOTE(review): despite its name, `image_dir` holds the directory *listing* of the masks folder.
image_dir = os.listdir('/Users/talhaahmed/Library/CloudStorage/OneDrive-HigherEducationCommission/Integration/GitHub/Diffusion-Codes/GMS/Dataset/busi/masks')
len(image_dir)

780

In [30]:
# Random stand-in for a single 3-channel 224×224 image, used below to probe output shapes.
img_rgb = torch.randn((1, 3, 224, 224))
img_rgb.shape

torch.Size([1, 3, 224, 224])

In [34]:
# Build the LIP-ResNet encoder and run it once on the dummy image.
# With these flags the constructor reports a frozen backbone and trainable LIP parameters
# (see the printed parameter counts).
enc = LIPResNetEncoder(backbone = 'resnet50',   # or 'resnet34'
                       pretrained      = False,
                       latent_channels = 4,     # to match VAE Z_I
                       model_freeze    = True,
                       lip_freeze      = False)           # optional fine-tune flag
# The encoder returns three tensors; only the last one (z_lip) is the latent map.
feat_1_8, x, z_lip = enc(img_rgb)                          # z_lip: (B, 4, 28, 28) for a 224×224 input

Trainable LIP params : 2.18 M
Frozen backbone parms: 22.92 M


In [36]:
# Show the shapes of the three encoder outputs side by side.
print(f'feat_1_8.shape: {feat_1_8.shape}, x.shape: {x.shape}, z_lip.shape: {z_lip.shape}')

feat_1_8.shape: torch.Size([1, 512, 28, 28]), x.shape: torch.Size([1, 1000]), z_lip.shape: torch.Size([1, 4, 28, 28])


In [37]:
# Per-layer table (input/output shape, parameter count, trainable flag) for the encoder.
summary(enc, input_size=(1, 3, 224, 224), col_names = ["input_size", "output_size", "num_params", "trainable"])

Layer (type:depth-idx)                                       Input Shape               Output Shape              Param #                   Trainable
LIPResNetEncoder                                             [1, 3, 224, 224]          [1, 512, 28, 28]          --                        Partial
├─_LIPResNetBackbone: 1-1                                    [1, 3, 224, 224]          [1, 512, 28, 28]          --                        Partial
│    └─Conv2d: 2-1                                           [1, 3, 224, 224]          [1, 64, 112, 112]         (9,408)                   False
│    └─BatchNorm2d: 2-2                                      [1, 64, 112, 112]         [1, 64, 112, 112]         (128)                     False
│    └─ReLU: 2-3                                             [1, 64, 112, 112]         [1, 64, 112, 112]         --                        --
│    └─SimplifiedLIP: 2-4                                    [1, 64, 112, 112]         [1, 64, 56, 56]           --          

In [None]:
import torch, torch.nn.functional as F
from torch.autograd import gradcheck

# Dummy inputs for the gradient check. Both are float64 and track gradients:
# `x` is the feature map, `logit` is a single-channel map that broadcasts over x's channels.
x     = torch.randn(2, 3, 16, 16, dtype=torch.double, requires_grad=True)
logit = torch.randn(2, 1, 16, 16, dtype=torch.double, requires_grad=True)

def lip_wrapper(inp, lg):
    """Locally weighted average pooling (no stabilising margin), used as the gradcheck target.

    Args:
        inp: feature map to pool.
        lg:  logits; their exponential gives the positive pooling weights and must
             broadcast against `inp`.

    Returns:
        The 3×3 / stride-2 / padding-1 average of `inp * exp(lg)` divided by the same
        average of `exp(lg)`.
    """
    pool_cfg = dict(kernel_size=3, stride=2, padding=1)
    weights = lg.exp()
    weighted_avg = F.avg_pool2d(inp * weights, **pool_cfg)
    weight_avg = F.avg_pool2d(weights, **pool_cfg)
    return weighted_avg / weight_avg

# Gradient check: compares analytical and numerical gradients of `lip_wrapper`
# with respect to both inputs (perturbation eps=1e-6, absolute tolerance 1e-4).
print(gradcheck(lip_wrapper, (x, logit), eps=1e-6, atol=1e-4))  # → True


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------- LIP operator ------------------------------------ #
def lip2d(x, logit, kernel=3, stride=2, padding=1, margin=1e-6):
    """Local importance-based pooling.

    Each output value is a locally normalised weighted average of `x`, where the
    weights are `exp(logit)` (strictly positive).

    Args:
        x:       feature map to be pooled.
        logit:   importance logits; must broadcast against `x`.
        kernel:  pooling window size.
        stride:  pooling stride.
        padding: zero padding applied by the pooling op.
        margin:  small constant added to the denominator to avoid division by zero.

    Returns:
        avg_pool(x * exp(logit)) / (avg_pool(exp(logit)) + margin)
    """
    importance = logit.exp()
    numerator = F.avg_pool2d(x * importance, kernel_size=kernel, stride=stride, padding=padding)
    denominator = F.avg_pool2d(importance, kernel_size=kernel, stride=stride, padding=padding) + margin
    return numerator / denominator


# ----------------------------------- Bottleneck Logit Module --------------------------- #
class BottleneckLogit(nn.Module):
    """Bottleneck network that turns a feature map into a single-channel logit map.

    Layout: 1×1 conv (compress) → IN → ReLU → 3×3 conv (spatial) → IN → ReLU → 1×1 conv (to 1 channel).

    Args:
        in_channels:      number of channels of the incoming feature map.
        bottleneck_ratio: compression factor for the hidden width. The hidden width is
                          `in_channels // bottleneck_ratio`, clamped to at least 1.
    """

    def __init__(self, in_channels, bottleneck_ratio=4):
        super().__init__()
        # Clamp to one channel: without it, in_channels < bottleneck_ratio would ask for a
        # zero-channel bottleneck and the module could no longer carry any information.
        mid = max(1, in_channels // bottleneck_ratio)
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, mid, 1),         # 1×1 conv: compression
            nn.InstanceNorm2d(mid),
            nn.ReLU(inplace=True),

            nn.Conv2d(mid, mid, 3, padding=1),       # 3×3 conv: spatial processing
            nn.InstanceNorm2d(mid),
            nn.ReLU(inplace=True),

            nn.Conv2d(mid, 1, 1)                     # 1×1 conv: compress to 1 channel
        )

    def forward(self, x):
        """Return the logit map: one importance score per spatial location, shared by all channels."""
        return self.net(x)

# ----------------------------------- LIP Block ----------------------------------------- #
class LIPBlock(nn.Module):
    """3×3 conv + BN + ReLU followed by LIP down-sampling.

    Args:
        in_ch:        input channels.
        out_ch:       output channels of the convolution (and of the block).
        logit_module: optional module producing the importance logits from the
                      convolved features; defaults to `BottleneckLogit(out_ch)`.
    """

    def __init__(self, in_ch, out_ch, logit_module=None):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        # Compare against None explicitly: container modules such as an empty
        # nn.Sequential define __len__ and are falsy, so `logit_module or ...`
        # would silently discard a module the caller passed in.
        if logit_module is None:
            logit_module = BottleneckLogit(out_ch)
        self.logit_module = logit_module

    def forward(self, x):
        """Convolve, compute importance logits, then pool with `lip2d` (down-sampling)."""
        x = self.conv(x)
        logits = self.logit_module(x)
        return lip2d(x, logits)   # downsampling

# ------------------------------ Arbitrary LIP Encoder ---------------------------------- #
class ArbitraryLIPEngineEncoder(nn.Module):
    """Small encoder: strided conv stem → two LIP blocks → 1×1 projection head.

    Shapes annotated in the original notebook for a 224×224 input:
        stem   : (B, 64, 112, 112)
        block1 : (B, 128, 56, 56)
        block2 : (B, 256, 28, 28)
        head   : (B, latent_channels, 28, 28)

    Args:
        in_channels:     channels of the input image.
        latent_channels: channels of the produced latent map.
    """

    def __init__(self, in_channels=3, latent_channels=4):
        super().__init__()
        stem_width = 64
        head_width = 64

        # Stem: 7×7 stride-2 convolution halves the resolution.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, stem_width, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )

        # Two LIP blocks, each widening the channels and pooling spatially.
        self.block1 = LIPBlock(stem_width, 128)
        self.block2 = LIPBlock(128, 256)

        # Head: point-wise projection down to the latent channel count.
        self.head = nn.Sequential(
            nn.Conv2d(256, head_width, kernel_size=1),
            nn.BatchNorm2d(head_width),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_width, latent_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run the stages in order: stem → block1 → block2 → head."""
        for stage in (self.stem, self.block1, self.block2, self.head):
            x = stage(x)
        return x


In [None]:
# Smoke test: push a random batch of two 224×224 RGB images through the encoder.
model = ArbitraryLIPEngineEncoder(in_channels=3, latent_channels=4)
x     = torch.randn(2, 3, 224, 224)
z     = model(x)
print(z.shape)  # ➜ torch.Size([2, 4, 28, 28])


In [None]:
# Per-layer table (input/output shape, parameter count, trainable flag) for the LIP encoder.
summary(model, input_size=(2, 3, 224, 224), col_names = ["input_size", "output_size", "num_params", "trainable"])

In [None]:
from networks.latent_mapping_model import *

# Latent mapping network with 4 input and 4 output channels.
# NOTE(review): the keyword is `in_channel` (singular) while `out_channels` is plural —
# confirm against the ResAttnUNet_DS signature.
mapping_model = ResAttnUNet_DS(
            in_channel = 4,
            out_channels = 4,
            num_res_blocks =  2,
            ch = 32,
            ch_mult= [1, 2, 4, 4],
        )
    

In [None]:
# Per-layer table for the mapping model on a batch of two 4-channel 56×56 latents.
summary(mapping_model, input_size=(2, 4, 56, 56), col_names = ["input_size", "output_size", "num_params", "trainable"])

In [None]:
# With in_channel = 4 the model has 1,564,976 trainable params; with 8 input channels it has 1,566,128.

In [None]:
import torch
from einops import rearrange
import torch.nn.functional as F


def space_to_channel(x: torch.Tensor, p: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the space-to-channel rearrangement of `x` in three different ways so the
    results can be compared.

    x : tensor of shape  (H, W, C)   (no batch dim)
    p : spatial block size (must divide H and W)

    returns: a tuple `(manual, via_einops, via_unshuffle)`
        manual        : (H//p, W//p, p*p*C) – reshape / permute / reshape
        via_einops    : (H//p, W//p, p*p*C) – `einops.rearrange`
        via_unshuffle : (p*p*C, H//p, W//p) – `F.pixel_unshuffle` on the channel-first view.
                        NOTE(review): this one is channel-first and its channel ordering
                        follows PyTorch's convention, which may differ from the other
                        two — verify before comparing values.
    """
    H, W, C = x.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"

    # 1) split each spatial dim into (H//p, p) and (W//p, p).
    #    `reshape` (not `view`) so non-contiguous inputs, e.g. permuted tensors, also work.
    manual = x.reshape(H // p, p, W // p, p, C)            # (H/p, p, W/p, p, C)

    # 2) bring the small p×p blocks next to the channel dim
    manual = manual.permute(0, 2, 1, 3, 4)                 # (H/p, W/p, p, p, C)

    # 3) merge them into the channel axis
    manual = manual.reshape(H // p, W // p, C * p * p)     # (H/p, W/p, p²C)

    via_einops = rearrange(x, '(h ph) (w pw) c -> h w (ph pw c)', ph=p, pw=p)

    via_unshuffle = F.pixel_unshuffle(x.permute(2, 0, 1), downscale_factor=p)
    return manual, via_einops, via_unshuffle


In [None]:
# Compare the three space-to-channel variants on a random (H, W, C) tensor.
x = torch.randn((224, 224, 3))
p = 2

# Pass the configured block size instead of repeating the literal, so changing `p` takes effect.
x, y, z = space_to_channel(x, p=p)

print(f'x.shape: {x.shape}, y.shape: {y.shape}, z.shape: {z.shape}')

In [None]:
import torch
from einops import rearrange

def downsample_shortcut(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """
    Non-learnable “space-to-channel + channel-averaging”

    Args
    ----
    x : (..., H, W, C)  – channels-last; any number of leading (batch) dims is allowed
    p : block size (p = 2 in the diagram); must divide H and W

    Returns
    -------
    y : (..., H//p, W//p, p²·C/2)  – i.e. 2C channels for p = 2

    Raises
    ------
    ValueError      if `x` has fewer than three dimensions
    AssertionError  if H or W is not divisible by p, or p²·C is odd
    """
    if x.ndim < 3:
        raise ValueError("Expected tensor of shape (..., H, W, C)")
    H, W, C = x.shape[-3:]
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    # torch.chunk yields unequal pieces for an odd channel count, which would only fail
    # later with an opaque shape error when the halves are added.
    assert (p * p * C) % 2 == 0, "p²·C must be even so the channels split into two equal halves"

    # 1) space → channel  … (..., H, W, C) ⟶ (..., H/p, W/p, p²·C)
    s2c = rearrange(x, '... (h ph) (w pw) c -> ... h w (ph pw c)', ph=p, pw=p)

    # 2) split the p²·C channels into two equal parts and average them
    g1, g2 = torch.chunk(s2c, 2, dim=-1)            # each: (..., H/p, W/p, p²·C/2)
    y = 0.5 * (g1 + g2)                             # (..., H/p, W/p, p²·C/2)

    return y


In [None]:
def upsample_shortcut(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """
    Non-learnable “channel-to-space + channel-duplication”

    Args
    ----
    x : (..., H//p, W//p, 2C)  – channels-last; any number of leading (batch) dims is allowed
    p : block size (2 in the diagram); p² must divide the channel count

    Returns
    -------
    y : (..., H, W, 2·(2C)/p²)  – i.e. C channels for p = 2

    Raises
    ------
    ValueError      if `x` has fewer than three dimensions
    AssertionError  if the channel count is not divisible by p²
    """
    if x.ndim < 3:
        raise ValueError("Expected tensor of shape (..., H, W, C)")
    twoC = x.shape[-1]
    assert twoC % (p**2) == 0, "Channel dim must equal p² × (C/2)"

    # 1) channel → space  … (..., Hp, Wp, 2C) ⟶ (..., H, W, 2C/p²)
    c2s = rearrange(x, '... h w (ph pw c) -> ... (h ph) (w pw) c', ph=p, pw=p)

    # 2) duplicate along channel dim and concatenate  … doubles the channel count
    y = torch.cat([c2s, c2s], dim=-1)

    return y


In [None]:
# Shape round-trip: down-sample then up-sample and check that the original shape is restored.
# Only shapes are compared; the values are not expected to match, because the
# down-sampling step averages channel pairs.
H, W, C = 224, 224, 4        # C must be even for the round-trip to work
x0     = torch.randn(H, W, C)

x1 = downsample_shortcut(x0)  # (112, 112, 8)
x2 = upsample_shortcut(x1)    # (224, 224, 4)

assert x2.shape == x0.shape


In [None]:
import torch
import torch.nn as nn
from einops import rearrange

# -------------------------------------------------------------------------
# 1) your original non-parametric helpers (unchanged, pasted for completeness)
# -------------------------------------------------------------------------

def downsample_shortcut(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Space-to-channel followed by averaging the two channel halves.

    Args:
        x: (H, W, C) channels-last tensor, no batch dimension.
        p: spatial block size; must divide H and W.

    Returns:
        Tensor of shape (H/p, W/p, p²·C/2), i.e. 2C channels for p = 2.

    Raises:
        AssertionError: if H or W is not divisible by p, or p²·C is odd.
    """
    H, W, C = x.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    # torch.chunk yields unequal pieces for an odd channel count; fail early with a clear message.
    assert (p * p * C) % 2 == 0, "p²·C must be even so the channels split into two equal halves"
    s2c = rearrange(x, '(h ph) (w pw) c -> h w (ph pw c)', ph=p, pw=p)
    g1, g2 = torch.chunk(s2c, 2, dim=-1)
    return 0.5 * (g1 + g2)                       # (H/p, W/p, p²·C/2)


def upsample_shortcut(x: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Channel-to-space followed by duplicating the result along the channel axis.

    Args:
        x: 3-D channels-last tensor (h, w, c) whose channel count is divisible by p².
        p: spatial block size.

    Returns:
        Tensor of shape (h·p, w·p, 2·c/p²).
    """
    _, _, n_channels = x.shape
    assert n_channels % (p * p) == 0
    expanded = rearrange(x, 'h w (ph pw c) -> (h ph) (w pw) c', ph=p, pw=p)
    return torch.cat((expanded, expanded), dim=-1)

# -------------------------------------------------------------------------
# 2) thin wrappers that make the helpers *batch-aware*
# -------------------------------------------------------------------------

def _map_per_sample(fn, x, **kwargs):
    """
    Apply `fn` to each sample in a batched tensor.
    Accepts tensors of shape (B, H, W, C) or (H, W, C).
    """
    if x.ndim == 3:   # (H, W, C)
        return fn(x, **kwargs)
    elif x.ndim == 4: # (B, H, W, C)
        return torch.stack([fn(sample, **kwargs) for sample in x], dim=0)
    else:
        raise ValueError("Expected tensor of shape (H,W,C) or (B,H,W,C)")

def downsample_shortcut_batched(x, p=2):
    """Apply `downsample_shortcut` to one sample or to each sample of a batch."""
    shortcut_kwargs = {"p": p}
    return _map_per_sample(downsample_shortcut, x, **shortcut_kwargs)

def upsample_shortcut_batched(x, p=2):
    """Apply `upsample_shortcut` to one sample or to each sample of a batch."""
    shortcut_kwargs = {"p": p}
    return _map_per_sample(upsample_shortcut, x, **shortcut_kwargs)

# -------------------------------------------------------------------------
# 3) residual-autoencoding blocks that call the (batched) shortcuts
# -------------------------------------------------------------------------

class ResidualDownAE(nn.Module):
    """
    Down-sampling block with a parameter-free residual path.

        y = learnable_down(x) + downsample_shortcut_batched(x, p)

    • Input : (B?, H, W, C)
    • Output: (B?, H/2, W/2, 2C)

    `learnable_down` can be any module whose output shape matches the shortcut's
    (halves H and W, doubles C for p = 2).
    """
    def __init__(self, learnable_down: nn.Module, p: int = 2):
        super().__init__()
        self.p = p
        self.learnable_down = learnable_down

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum of the learnable branch and the non-parametric shortcut."""
        learned = self.learnable_down(x)
        shortcut = downsample_shortcut_batched(x, p=self.p)
        return learned + shortcut


class ResidualUpAE(nn.Module):
    """
    Up-sampling block with a parameter-free residual path.

        y = learnable_up(x) + upsample_shortcut_batched(x, p)

    • Input : (B?, H/2, W/2, 2C)
    • Output: (B?, H,   W,   C)

    `learnable_up` can be any module whose output shape matches the shortcut's
    (doubles H and W, halves C for p = 2).
    """
    def __init__(self, learnable_up: nn.Module, p: int = 2):
        super().__init__()
        self.p = p
        self.learnable_up = learnable_up

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum of the learnable branch and the non-parametric shortcut."""
        learned = self.learnable_up(x)
        shortcut = upsample_shortcut_batched(x, p=self.p)
        return learned + shortcut

# -------------------------------------------------------------------------
# 4) minimal demo with toy learnable blocks
# -------------------------------------------------------------------------

class ToyDown(nn.Module):
    """Toy learnable down-sampler for channels-last input: (B, H, W, C) ➜ (B, H/2, W/2, 2C)."""

    def __init__(self, C):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=C, out_channels=2 * C, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Conv2d expects NCHW: move channels forward, convolve, then restore channels-last.
        channels_first = x.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first)
        return convolved.permute(0, 2, 3, 1)

class ToyUp(nn.Module):
    """Toy learnable up-sampler for channels-last input: (B, H/2, W/2, 2C) ➜ (B, H, W, C)."""

    def __init__(self, C):
        super().__init__()
        self.tconv = nn.ConvTranspose2d(in_channels=2 * C, out_channels=C, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        # ConvTranspose2d expects NCHW: move channels forward, up-sample, then restore channels-last.
        channels_first = x.permute(0, 3, 1, 2)
        upsampled = self.tconv(channels_first)
        return upsampled.permute(0, 2, 3, 1)

# Instantiate ---------------------------------------------------------------
# Channels-last random batch; the toy blocks convert to NCHW internally.
B, H, W, C = 8, 256, 256, 64
x = torch.randn(B, H, W, C)

# Wrap the toy learnable layers with the non-parametric residual shortcuts (default p = 2).
down_block = ResidualDownAE(ToyDown(C))
up_block   = ResidualUpAE  (ToyUp(C))

# Down then up: the second call should restore the input's shape.
d = down_block(x)   # (B, 128, 128, 128)
u = up_block(d)     # (B, 256, 256,  64)

print(d.shape, u.shape)
# -------------------------------------------------------------------------


In [40]:
# utils/load_tiny_vae.py

%load_ext autoreload
%autoreload 2

from collections import OrderedDict
import re
import torch

from diffusers import AutoencoderTiny as HF_TinyVAE
from networks.novel.tiny_vae.autoencoder_tiny import AutoencoderTiny  # your new class
from torchinfo import summary

# ------------------------------------------------------------------
# Adjust these imports to match your repo layout
# ------------------------------------------------------------------

def _remap_key(key: str) -> str:
    # Only rename encoder downsample convs that are wrapped.
    encoder_down_layers = {2, 6, 10}

    enc_match = re.match(r"(encoder\.layers\.(\d+))\.(weight|bias)", key)
    if enc_match:
        idx = int(enc_match.group(2))
        if idx in encoder_down_layers:
            return f"{enc_match.group(1)}.down.{enc_match.group(3)}"
        else:
            return key

    # Decoder: *no* renaming needed because convs are outside ResidualUpAE.
    return key




def load_residual_tiny_vae(
    device="cuda",
    freeze=False,
    dtype=torch.float32,
):
    """Build the local `AutoencoderTiny` and initialise it from the `madebyollin/taesd` checkpoint.

    The checkpoint is loaded through the diffusers class, its keys are translated with
    `_remap_key`, and the result is loaded non-strictly; any missing or unexpected
    keys are printed rather than raised.

    Args:
        device: device the new model is moved to.
        freeze: when True, gradients are disabled for every parameter.
        dtype:  dtype used both for loading the checkpoint and for the new model.

    Returns:
        The model, switched to eval mode.
    """
    # Reference weights from the original checkpoint.
    source_state = HF_TinyVAE.from_pretrained("madebyollin/taesd", torch_dtype=dtype).state_dict()

    # Same tensors under the key names our model expects (insertion order preserved).
    translated = OrderedDict(
        (_remap_key(name), tensor) for name, tensor in source_state.items()
    )

    vae = AutoencoderTiny().to(device).to(dtype)

    load_report = vae.load_state_dict(translated, strict=False)
    missing, unexpected = load_report
    if not (missing or unexpected):
        print("✅  State-dict loaded successfully.")
    else:
        print("⚠️  Unmatched keys")
        print("  missing   :", missing)
        print("  unexpected:", unexpected)

    if freeze:
        for param in vae.parameters():
            param.requires_grad_(False)

    vae.eval()
    return vae


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
# Build the residual tiny VAE on CPU (the loader's default device is "cuda").
vae_model = load_residual_tiny_vae(device="cpu")


✅  State-dict loaded successfully.


In [42]:
# ----------------------- hyper-params -----------------------------
# NOTE(review): only BATCH_SIZE and IMG_SHAPE are referenced in this cell;
# LATENT_CH, ENC_BLOCKS, DEC_BLOCKS and CHANNELS are defined but not used here.
BATCH_SIZE   = 2
IMG_SHAPE    = (3, 224, 224)   # C, H, W
LATENT_CH    = 4               # matches default
ENC_BLOCKS   = (1, 3, 3, 3)
DEC_BLOCKS   = (3, 3, 3, 1)
CHANNELS     = (64, 64, 64, 64)

# ----------------------- model ------------------------------------
# `vae_model` was created in the previous cell. Note that the AutoencoderTiny default
# parameters are the same as the madebyollin/taesd config, so the state dict loads directly.
# ----------------------- forward pass -----------------------------
x = torch.randn(BATCH_SIZE, *IMG_SHAPE)
latents = vae_model.encode(x).latents
print("Encoder output shape :", latents.shape)          # ➞ (B, 4, 28, 28)

recon = vae_model.decode(latents).sample
print("Decoder output shape :", recon.shape)            # ➞ (B, 3, 224, 224)

# quick MSE reconstruction loss & backward; then confirm the first encoder layer received a gradient
loss = (recon - x).pow(2).mean()
loss.backward()
print("Backward pass OK – grads exist:", vae_model.encoder.layers[0].weight.grad is not None)

# ---------------------- parameter summary -------------------------
print("\n— learnable parameters —")
summary(vae_model, input_size = (BATCH_SIZE, *IMG_SHAPE), depth = 2, col_names = ("kernel_size", "num_params"))

Encoder output shape : torch.Size([2, 4, 28, 28])
Decoder output shape : torch.Size([2, 3, 224, 224])
Backward pass OK – grads exist: True

— learnable parameters —


Layer (type:depth-idx)                        Kernel Shape              Param #
AutoencoderTiny                               --                        --
├─EncoderTiny: 1-1                            --                        --
│    └─Sequential: 2-1                        --                        1,222,532
├─DecoderTiny: 1-2                            --                        --
│    └─Sequential: 2-2                        --                        1,222,531
Total params: 2,445,063
Trainable params: 2,445,063
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 50.55
Input size (MB): 1.20
Forward/backward pass size (MB): 750.68
Params size (MB): 9.78
Estimated Total Size (MB): 761.67

In [4]:
%load_ext autoreload
%autoreload 2

import torch
from networks.novel.DiffEIC.model import lfgcm_small
from torchinfo import summary


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Instantiate the small LFGCM variant: 3 input channels, 4 output channels.
# `slice_ch` has 10 entries (= slice_num) that sum to 320 (= M) — presumably the
# per-slice channel widths; confirm against the LFGCM implementation.
model_small = lfgcm_small.LFGCM(in_nc     = 3,
                                out_nc    = 4,
                                enc_mid   = [64, 128, 192, 192],
                                N         = 192,
                                M         = 320,
                                prior_nc  = 64,
                                sft_ks    = 3,
                                slice_num = 10,
                                slice_ch  = [8, 8, 8, 8, 16, 16, 32, 32, 96, 96])

In [6]:
# The model takes two inputs, so two input shapes are given: a (1, 3, 224, 224) tensor and a (1, 4, 28, 28) tensor.
summary(model_small, [(1, 3, 224, 224), (1, 4, 28, 28)])

Layer (type:depth-idx)                        Output Shape              Param #
LFGCM                                         [1, 4, 112, 112]          --
├─Encoder: 1-1                                [1, 320, 56, 56]          --
│    └─Sequential: 2-1                        [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-1                       [1, 64, 28, 28]           2,368
│    │    └─GELU: 3-2                         [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-3                       [1, 64, 28, 28]           4,160
│    └─Sequential: 2-2                        [1, 192, 112, 112]        --
│    │    └─ResidualBlockWithStride: 3-4      [1, 192, 112, 112]        338,112
│    │    └─ResidualBottleneck: 3-5           [1, 192, 112, 112]        120,192
│    │    └─ResidualBottleneck: 3-6           [1, 192, 112, 112]        120,192
│    │    └─ResidualBottleneck: 3-7           [1, 192, 112, 112]        120,192
│    └─SFT: 2-3                               [1, 192, 112, 112]     