In [1]:
import torch
import torch.nn as nn

# Define SELayer (as provided)
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

# Create a sample input tensor
batch_size, channels, height, width = 2, 64, 32, 32
x = torch.randn(batch_size, channels, height, width)  # Simulated feature map

# Instantiate SELayer
se_layer = SELayer(channel=channels, reduction=16)

# Forward pass
output = se_layer(x)

# Print shapes and sample attention weights
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Extract attention weights for inspection
with torch.no_grad():
    y = se_layer.avg_pool(x).view(batch_size, channels)
    y = se_layer.fc(y)
    print(f"Attention weights shape: {y.shape}")
    print(f"Sample attention weights for first sample:\n{y[0][:5]}")  # First 5 channels

Input shape: torch.Size([2, 64, 32, 32])
Output shape: torch.Size([2, 64, 32, 32])
Attention weights shape: torch.Size([2, 64])
Sample attention weights for first sample:
tensor([0.4997, 0.5000, 0.5000, 0.5002, 0.4997])


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define NonLocalBlock (as provided)
class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NonLocalBlock, self).__init__()
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)
        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        if sub_sample:
            self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
            self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))

    def forward(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z

# Create a sample input tensor
batch_size, in_channels, height, width = 2, 64, 32, 32
x = torch.randn(batch_size, in_channels, height, width)  # Simulated feature map

# Instantiate NonLocalBlock
non_local_block = NonLocalBlock(in_channels=in_channels, inter_channels=None, sub_sample=True, bn_layer=True)

# Forward pass
output = non_local_block(x)

# Print shapes and sample attention weights
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

# Extract attention map for inspection
with torch.no_grad():
    g_x = non_local_block.g(x).view(batch_size, non_local_block.inter_channels, -1)
    theta_x = non_local_block.theta(x).view(batch_size, non_local_block.inter_channels, -1).permute(0, 2, 1)
    phi_x = non_local_block.phi(x).view(batch_size, non_local_block.inter_channels, -1)
    f = torch.matmul(theta_x, phi_x)
    f_div_C = F.softmax(f, dim=-1)
    print(f"Attention map shape: {f_div_C.shape}")
    print(f"Sample attention weights (first 5x5 for first sample):\n{f_div_C[0, :5, :5]}")

Input shape: torch.Size([2, 64, 32, 32])
Output shape: torch.Size([2, 64, 32, 32])
Attention map shape: torch.Size([2, 1024, 256])
Sample attention weights (first 5x5 for first sample):
tensor([[0.0052, 0.0056, 0.0017, 0.0021, 0.0009],
        [0.0016, 0.0070, 0.0040, 0.0005, 0.0004],
        [0.0097, 0.0216, 0.0051, 0.0020, 0.0051],
        [0.0087, 0.0036, 0.0016, 0.0019, 0.0107],
        [0.0001, 0.0020, 0.0030, 0.0100, 0.0002]])


In [1]:
import torch, copy

import torch.nn                              as nn

from   diffusers                             import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from   tqdm.auto                             import tqdm
from   diffusers.models.autoencoders.vae     import DiagonalGaussianDistribution
from   diffusers.models.modeling_outputs     import AutoencoderKLOutput

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_hybrid_unet(pretrained_path: str, device: str = "cuda") -> UNet2DConditionModel:

    # Step 1: Load the original model to access weights and config
    unet_orig = UNet2DConditionModel.from_pretrained(
        pretrained_path,
        subfolder="unet",
        torch_dtype=torch.float16
    ).to(device)
    orig_sd = unet_orig.state_dict()
    
    # Step 2: Create a deep copy of the config and modify in_channels
    config = copy.deepcopy(unet_orig.config)  # Avoid modifying the original config
    config['in_channels'] = 8  # Update in_channels to 8

    del unet_orig  # Free memory after loading original model
    
    # Step 3: Create a new UNet model with modified in_channels
    unet_new = UNet2DConditionModel(**config).to(device, dtype=torch.float16)

    # Step 4: Get the original and new state dicts

    new_sd = unet_new.state_dict()

    # Step 5: Initialize new conv_in weights manually
    old_conv = orig_sd["conv_in.weight"]  # Shape [320, 4, 3, 3]
    out_ch, _, kH, kW = old_conv.shape

    # New conv_in.weight: [320, 8, 3, 3]
    new_conv = torch.zeros((out_ch, 8, kH, kW), dtype=torch.float16, device=device)
    new_conv[:, :4, :, :] = old_conv  # Copy pretrained channels
    nn.init.kaiming_normal_(new_conv[:, 4:, :, :], mode='fan_out', nonlinearity='leaky_relu')  # Random init rest

    # Step 6: Update the state dict with the new conv_in.weight
    new_sd["conv_in.weight"] = new_conv

    # Step 7: Copy compatible weights from the original model
    for key in new_sd:
        if key != "conv_in.weight" and key in orig_sd and new_sd[key].shape == orig_sd[key].shape:
            new_sd[key] = orig_sd[key]

    # Step 8: Load updated state dict into the new model
    unet_new.load_state_dict(new_sd)

    return config, unet_new.requires_grad_(True)

In [3]:
new_config, unet = load_hybrid_unet("CompVis/stable-diffusion-v1-4")

In [4]:
old_config

FrozenDict([('sample_size', 64),
            ('in_channels', 4),
            ('out_channels', 4),
            ('center_input_sample', False),
            ('flip_sin_to_cos', True),
            ('freq_shift', 0),
            ('down_block_types',
             ['CrossAttnDownBlock2D',
              'CrossAttnDownBlock2D',
              'CrossAttnDownBlock2D',
              'DownBlock2D']),
            ('mid_block_type', 'UNetMidBlock2DCrossAttn'),
            ('up_block_types',
             ['UpBlock2D',
              'CrossAttnUpBlock2D',
              'CrossAttnUpBlock2D',
              'CrossAttnUpBlock2D']),
            ('only_cross_attention', False),
            ('block_out_channels', [320, 640, 1280, 1280]),
            ('layers_per_block', 2),
            ('downsample_padding', 1),
            ('mid_block_scale_factor', 1),
            ('dropout', 0.0),
            ('act_fn', 'silu'),
            ('norm_num_groups', 32),
            ('norm_eps', 1e-05),
            ('cross_attenti

In [4]:
new_config

FrozenDict([('sample_size', 64),
            ('in_channels', 8),
            ('out_channels', 4),
            ('center_input_sample', False),
            ('flip_sin_to_cos', True),
            ('freq_shift', 0),
            ('down_block_types',
             ['CrossAttnDownBlock2D',
              'CrossAttnDownBlock2D',
              'CrossAttnDownBlock2D',
              'DownBlock2D']),
            ('mid_block_type', 'UNetMidBlock2DCrossAttn'),
            ('up_block_types',
             ['UpBlock2D',
              'CrossAttnUpBlock2D',
              'CrossAttnUpBlock2D',
              'CrossAttnUpBlock2D']),
            ('only_cross_attention', False),
            ('block_out_channels', [320, 640, 1280, 1280]),
            ('layers_per_block', 2),
            ('downsample_padding', 1),
            ('mid_block_scale_factor', 1),
            ('dropout', 0.0),
            ('act_fn', 'silu'),
            ('norm_num_groups', 32),
            ('norm_eps', 1e-05),
            ('cross_attenti