In [1]:
import torch
import os
import torchvision
import matplotlib.pyplot as plt
import time
import numpy as np

from torchvision import transforms
from IPython.display import clear_output
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image, to_tensor
from scipy.integrate import solve_ivp
from IPython.display import clear_output, display

from dit_skip import DiT
from dit_controlnet import DiTControlNet

from torchinfo import summary

# Seed torch's global RNG so the random tensors created in later cells are
# reproducible across kernel restarts.
torch.manual_seed(0)

# Prefer the first CUDA GPU, but fall back to CPU so that every `.to(device)`
# in the notebook still works on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# NOTE(review): the cells visible in this notebook hard-code a batch of 128 and
# never read `batch_size` — confirm whether it is still needed.
batch_size = 40

In [2]:
# Random stand-in inputs for a smoke test of the model: a batch of 128
# three-channel 32x32 "images" and one random scalar in [0, 1) per sample
# (used as the noise label / timestep input).
imgs = torch.randn(128, 3, 32, 32)
imgs = imgs.to(device)
t = torch.rand(128)
t = t.to(device)

# Base transformer under test.
dit = DiT(
    input_size=32,
    patch_size=2,
    in_channels=3,
    out_channels=3,
    hidden_size=512,
    depth=9,
    num_heads=8,
    mlp_ratio=4,
    num_classes=0,
    use_long_skip=True,
    final_conv=True,
)
dit = dit.to(device)

# Per-layer output shapes and parameter counts for the inputs above.
print(summary(dit, input_data=[imgs, t]))

# Reference forward pass without ControlNet inputs; no gradients are needed.
with torch.no_grad():
    output = dit(imgs, t)

print(output.shape)

Layer (type:depth-idx)                   Output Shape              Param #
DiT                                      [128, 3, 32, 32]          131,072
├─PatchEmbed: 1-1                        [128, 256, 512]           --
│    └─Conv2d: 2-1                       [128, 512, 16, 16]        6,656
│    └─Identity: 2-2                     [128, 256, 512]           --
├─TimestepEmbedder: 1-2                  [128, 512]                --
│    └─Sequential: 2-3                   [128, 512]                --
│    │    └─Linear: 3-1                  [128, 512]                131,584
│    │    └─SiLU: 3-2                    [128, 512]                --
│    │    └─Linear: 3-3                  [128, 512]                262,656
├─ModuleList: 1-3                        --                        --
│    └─DiTBlock: 2-4                     [128, 256, 512]           --
│    │    └─Sequential: 3-4              [128, 3072]               1,575,936
│    │    └─LayerNorm: 3-5               [128, 256, 512]    

In [3]:
# Disabled: optional loading of pretrained weights into `dit`.
# When uncommented, this reads the "ema_ode" entry of the checkpoint file
# (mapped onto CPU storage), loads it into `dit`, switches `dit` to eval mode,
# and prints the model's state-dict keys plus any missing / unexpected keys.
# NOTE(review): the checkpoint path below is absolute and machine-specific;
# make it configurable before re-enabling this cell.
#
# ckpt_path = "/root/autodl-tmp/dit-1rf/003-DiT-S/2/checkpoints/0180000.pt"
# state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)["ema_ode"]
# model_ode_dict = dit.state_dict()

# missing_keys, unexpected_keys = dit.load_state_dict(state_dict)
# dit.eval()
# print(model_ode_dict.keys())
# print(missing_keys)
# print(unexpected_keys)

In [4]:
# Build a depth-4 ControlNet branch from the DiT defined earlier, asking
# `from_transformer` to load weights from that transformer.
controlnet = DiTControlNet.from_transformer(
    transformer=dit,
    input_size=32,
    patch_size=2,
    in_channels=3,
    hidden_size=512,
    depth=4,
    num_heads=8,
    mlp_ratio=4,
    num_classes=0,
    load_weights_from_transformer=True,
)
controlnet = controlnet.to(device)
controlnet.train()

# Disable gradients on the base transformer's parameters. The trailing
# expression also echoes the module tree as the cell's output.
dit.requires_grad_(False).to(device)

DiT(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(3, 512, kernel_size=(2, 2), stride=(2, 2))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=512, bias=True)
      (1): SiLU()
      (2): Linear(in_features=512, out_features=512, bias=True)
    )
  )
  (in_blocks): ModuleList(
    (0-3): 4 x DiTBlock(
      (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=512, out_features=1536, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=512, out_features=512, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((512,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (act): GELU(approximate='tanh')
        (drop1)

In [5]:
# Random conditioning input matching `imgs` in shape, dtype and device.
controlnet_cond = torch.randn_like(imgs)

# Run only the ControlNet branch; it returns the per-block samples and the
# mid-block sample that are passed to `dit` in a later cell.
with torch.no_grad():
    controlnet_block_samples, controlnet_mid_block_sample = controlnet(
        x=imgs,
        controlnet_x=controlnet_cond,
        t=t,
        conditioning_scale=1.0,
    )

In [6]:
# Number of entries in each ControlNet output, one value per line.
print(len(controlnet_block_samples), len(controlnet_mid_block_sample), sep="\n")

# Shape and autograd flag of the first block sample, one value per line.
print(
    controlnet_block_samples[0].shape,
    controlnet_block_samples[0].requires_grad,
    sep="\n",
)

4
1
torch.Size([128, 256, 512])
False


In [7]:
# Forward pass through the base transformer again, this time feeding it the
# ControlNet block samples and mid-block sample computed earlier.
with torch.no_grad():
    output_controlnet = dit(
        imgs,
        t,
        controlnet_block_samples=controlnet_block_samples,
        controlnet_mid_block_sample=controlnet_mid_block_sample,
    )

print(output_controlnet.shape, output_controlnet.requires_grad, sep="\n")

# Compare against the plain forward pass from the earlier cell, element-wise
# with `isclose`'s default tolerances; the cell displays the reduced result.
torch.isclose(output, output_controlnet).all()

Add controlnet_block_samples[3] to decoder block 0
Add controlnet_block_samples[2] to decoder block 1
Add controlnet_block_samples[1] to decoder block 2
Add controlnet_block_samples[0] to decoder block 3
torch.Size([128, 3, 32, 32])
False


tensor(True, device='cuda:0')