In [1]:
import torch
from tqdm import tqdm
from safetensors import safe_open
from safetensors.torch import save_file
import sys
sys.path.append("./AnyDoor/")
sys.path.append("./")

torch.set_num_threads(2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def get_state_dict_from_safetensors(path:str):
    """ Load a state dict from a safetensors file """
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors

In [3]:
# Set deterministic behavior
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.use_deterministic_algorithms(True)

# Sampling


In [None]:
# x = torch.load("./tests/tensors/x.pt",weights_only=True)
initial_latents = torch.randn(1,4,64,64).to(device)
object_embedding = torch.randn(1, 257, 1024).to(device)
negative_object_embedding = torch.zeros(1, 257, 1024).to(device) #torch.load("./tests/tensors/negative_object_embedding.pt",weights_only=True)
control = [x.to(device) for x in torch.load("./tests/tensors/control_features.pt",weights_only=True)]
inference_steps = 10
scale = 5.0

In [5]:
from anydoor_original.cldm import model as cldm
from cldm.ddim_hacked import DDIMSampler
sampler = DDIMSampler(cldm)

No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
Loaded model config from [./src/anydoor_original/configs/anydoor.yaml]
Model loaded successfully


In [6]:
with torch.no_grad():
    mocked_control_image = torch.zeros(1, 4, 32, 32)
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    un_cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [negative_object_embedding],
    }

    samples, intermediates = sampler.sample(
        S=inference_steps,
        batch_size=1,
        shape=(4, 32, 32),
        conditioning=cond,
        x_T=initial_latents.clone(),
        verbose=False,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond,
    )

Data shape for DDIM sampling is (1, 4, 32, 32), eta 0.0
Running DDIM Sampling with 10 timesteps


DDIM Sampler:  10%|█         | 1/10 [00:01<00:14,  1.62s/it]

tensor(102.6943)


DDIM Sampler:  20%|██        | 2/10 [00:03<00:12,  1.53s/it]

tensor(153.7481)


DDIM Sampler:  30%|███       | 3/10 [00:04<00:10,  1.50s/it]

tensor(214.7504)


DDIM Sampler:  40%|████      | 4/10 [00:06<00:09,  1.50s/it]

tensor(281.7013)


DDIM Sampler:  50%|█████     | 5/10 [00:07<00:07,  1.60s/it]

tensor(349.3274)


DDIM Sampler:  60%|██████    | 6/10 [00:09<00:06,  1.65s/it]

tensor(412.2058)


DDIM Sampler:  70%|███████   | 7/10 [00:11<00:05,  1.68s/it]

tensor(465.8795)


DDIM Sampler:  80%|████████  | 8/10 [00:13<00:03,  1.70s/it]

tensor(507.6265)


DDIM Sampler:  90%|█████████ | 9/10 [00:14<00:01,  1.70s/it]

tensor(536.7265)


DDIM Sampler: 100%|██████████| 10/10 [00:16<00:00,  1.67s/it]

tensor(536.9559)





In [7]:
from anydoor_refiners.model import AnyDoor,solver_params
from anydoor_refiners.unet import UNet
from refiners.foundationals.latent_diffusion.solvers import DDIM
from tests.mocks import DINOv2EncoderMock,ControlNetMock,AnydoorAutoencoderMock

unet = UNet(4).to(device)

refiners_model = AnyDoor(
    unet=unet,
    lda=AnydoorAutoencoderMock(),
    object_encoder=DINOv2EncoderMock(object_embedding,negative_object_embedding),
    control_model=ControlNetMock(control),
    solver=DDIM(inference_steps,params=solver_params)
)
# refiners_model.unet.load_state_dict(weights_refiners)


In [8]:
import json
from utils.weight_mapper import get_converted_state_dict

with open("./tests/weights_mapping/unet.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=sampler.model.model.diffusion_model.state_dict(),
    target_state_dict=refiners_model.unet.state_dict(),
    mapping=weight_mapping,
)
refiners_model.unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [9]:
with torch.no_grad():

    y = initial_latents.clone()
    for s in tqdm(refiners_model.steps):
        y = refiners_model.forward(y,step=s,control_background_image=torch.zeros(1),object_embedding=object_embedding,negative_object_embedding=negative_object_embedding)
        print(torch.norm(y))

 10%|█         | 1/10 [00:00<00:07,  1.19it/s]

tensor(102.6943)


 20%|██        | 2/10 [00:01<00:06,  1.32it/s]

tensor(153.7481)


 30%|███       | 3/10 [00:02<00:05,  1.34it/s]

tensor(214.7505)


 40%|████      | 4/10 [00:03<00:04,  1.30it/s]

tensor(281.7013)


 50%|█████     | 5/10 [00:03<00:03,  1.32it/s]

tensor(349.3275)


 60%|██████    | 6/10 [00:04<00:02,  1.34it/s]

tensor(412.2059)


 70%|███████   | 7/10 [00:05<00:02,  1.30it/s]

tensor(465.8796)


 80%|████████  | 8/10 [00:06<00:01,  1.27it/s]

tensor(507.6266)


 90%|█████████ | 9/10 [00:06<00:00,  1.31it/s]

tensor(536.7266)


100%|██████████| 10/10 [00:07<00:00,  1.31it/s]

tensor(536.9562)





In [10]:
torch.norm(y-samples[-1]),torch.norm(y),torch.norm(samples[-1])

(tensor(0.0002), tensor(536.9562), tensor(536.9559))

### Using Anydoor weights

In [11]:
weights_anydoor = get_state_dict_from_safetensors("./ckpt/unet.safetensors")
sampler.model.model.diffusion_model.load_state_dict(weights_anydoor)

<All keys matched successfully>

In [12]:
converted_state_dict = get_converted_state_dict(
    source_state_dict=sampler.model.model.diffusion_model.state_dict(),
    target_state_dict=refiners_model.unet.state_dict(),
    mapping=weight_mapping,
)
refiners_model.unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [13]:
with torch.no_grad():
    mocked_control_image = torch.zeros(1, 4, 32, 32)
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    un_cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [negative_object_embedding],
    }

    samples, intermediates = sampler.sample(
        S=inference_steps,
        batch_size=1,
        shape=(4, 32, 32),
        conditioning=cond,
        x_T=initial_latents.clone(),
        verbose=False,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond,
    )

Data shape for DDIM sampling is (1, 4, 32, 32), eta 0.0
Running DDIM Sampling with 10 timesteps


DDIM Sampler:  10%|█         | 1/10 [00:02<00:22,  2.48s/it]

tensor(66.4161)


DDIM Sampler:  20%|██        | 2/10 [00:03<00:15,  1.88s/it]

tensor(69.9417)


DDIM Sampler:  30%|███       | 3/10 [00:05<00:11,  1.66s/it]

tensor(72.8006)


DDIM Sampler:  40%|████      | 4/10 [00:07<00:11,  1.83s/it]

tensor(75.2370)


DDIM Sampler:  50%|█████     | 5/10 [00:11<00:12,  2.46s/it]

tensor(77.0635)


DDIM Sampler:  60%|██████    | 6/10 [00:13<00:09,  2.39s/it]

tensor(77.3051)


DDIM Sampler:  70%|███████   | 7/10 [00:15<00:06,  2.20s/it]

tensor(76.4554)


DDIM Sampler:  80%|████████  | 8/10 [00:17<00:04,  2.16s/it]

tensor(74.5179)


DDIM Sampler:  90%|█████████ | 9/10 [00:19<00:02,  2.17s/it]

tensor(69.9801)


DDIM Sampler: 100%|██████████| 10/10 [00:21<00:00,  2.17s/it]

tensor(69.8353)





In [14]:
with torch.no_grad():

    y = initial_latents.clone()
    for s in tqdm(refiners_model.steps):
        y = refiners_model.forward(y,step=s,control_background_image=torch.zeros(1),object_embedding=object_embedding,negative_object_embedding=negative_object_embedding)
        print(torch.norm(y))

 10%|█         | 1/10 [00:01<00:15,  1.70s/it]

tensor(66.4241)


 20%|██        | 2/10 [00:02<00:09,  1.14s/it]

tensor(69.7671)


 30%|███       | 3/10 [00:03<00:07,  1.10s/it]

tensor(72.4971)


 40%|████      | 4/10 [00:04<00:07,  1.23s/it]

tensor(74.3268)


 50%|█████     | 5/10 [00:05<00:05,  1.15s/it]

tensor(75.4029)


 60%|██████    | 6/10 [00:06<00:04,  1.11s/it]

tensor(75.3962)


 70%|███████   | 7/10 [00:08<00:03,  1.21s/it]

tensor(74.2556)


 80%|████████  | 8/10 [00:09<00:02,  1.31s/it]

tensor(72.1551)


 90%|█████████ | 9/10 [00:10<00:01,  1.24s/it]

tensor(67.1968)


100%|██████████| 10/10 [00:12<00:00,  1.22s/it]

tensor(66.6866)





In [15]:
torch.norm(y-samples[-1]),torch.norm(y),torch.norm(samples[-1])

(tensor(16.2772), tensor(66.6866), tensor(69.8353))

# Unet




In [4]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.unet import UNet

In [5]:
x = torch.randn(1, 4, 64, 64).to(device) 
timestep = torch.full((1,), 960, dtype=torch.long).to(device)
object_embedding = torch.randn(1, 257, 1024).to(device)
negative_object_embedding = torch.randn(1, 257, 1024).to(device) 
control = [x.to(device) for x in torch.load("./tests/tensors/control_features_real_size.pt",weights_only=True)]

In [6]:
from anydoor_original.cldm import model as cldm

No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
Loaded model config from [./src/anydoor_original/configs/anydoor.yaml]
Model loaded successfully


In [7]:
cldm.to(device)
True # Just to avoid printing the model

True

In [8]:
with torch.no_grad():
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    y1 = cldm.apply_model(
        x_noisy = x, 
        t = timestep, 
        cond = cond,
    )

In [9]:

unet = UNet(4, device=device)

In [10]:


with open("./tests/weights_mapping/unet.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=cldm.model.diffusion_model.state_dict(),
    target_state_dict=unet.state_dict(),
    mapping=weight_mapping,
)
unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [11]:
with torch.no_grad():
    unet.set_control_residuals(control)
    unet.set_timestep(timestep)
    unet.set_dinov2_object_embedding(object_embedding)
    y2 = unet(x)

In [12]:
torch.norm(y1-y2),torch.norm(y1),torch.norm(y2)

(tensor(0.), tensor(0.), tensor(0.))

In [13]:
unet_weights = get_state_dict_from_safetensors("./ckpt/unet.safetensors")
cldm.model.diffusion_model.load_state_dict(unet_weights)

<All keys matched successfully>

In [14]:
converted_state_dict = get_converted_state_dict(
    source_state_dict=cldm.model.diffusion_model.state_dict(),
    target_state_dict=unet.state_dict(),
    mapping=weight_mapping,
)
unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [15]:
with torch.no_grad():
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    y1_bis = cldm.apply_model(
        x_noisy = x, 
        t = timestep, 
        cond = cond,
    )

In [16]:
with torch.no_grad():
    unet.set_control_residuals(control)
    unet.set_timestep(timestep)
    unet.set_dinov2_object_embedding(object_embedding)
    y2_bis = unet(x)

In [17]:
torch.norm(y1_bis-y2_bis),torch.norm(y1_bis),torch.norm(y2_bis)

(tensor(0.), tensor(129.5021), tensor(129.5021))

In [26]:
unet

(CHAIN) UNet(in_channels=4)
    ├── (PASS) TimestepEncoder()
    │   ├── UseContext(context=diffusion, key=timestep)
    │   ├── (CHAIN) RangeEncoder(sinusoidal_embedding_dim=320, embedding_dim=1280)
    │   │   ├── Lambda(compute_sinusoidal_embedding(x: jaxtyping.Int[Tensor, '*batch 1']) -> jaxtyping.Float[Tensor, '*batch 1 embedding_dim'])
    │   │   ├── Converter(set_device=False)
    │   │   ├── Linear(in_features=320, out_features=1280, device=cuda:0, dtype=float32) #1
    │   │   ├── SiLU()
    │   │   └── Linear(in_features=1280, out_features=1280, device=cuda:0, dtype=float32) #2
    │   └── SetContext(context=range_adapter, key=timestep_embedding)
    ├── (CHAIN) DownBlocks(in_channels=4)
    │   ├── (CHAIN) #1
    │   │   ├── Conv2d(in_channels=4, out_channels=320, kernel_size=(3, 3), padding=(1, 1), device=cuda:0, dtype=float32)
    │   │   └── (PASS) ResidualAccumulator(n=0)
    │   │       ├── (RES) Residual()
    │   │       │   └── UseContext(context=unet, key=residuals

# SpatialTransformer




In [4]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.attention import CrossAttentionBlock2d
import refiners.fluxion.layers as fl

In [5]:
from AnyDoor.ldm.modules.attention import SpatialTransformer

In [6]:
# Define model configuration parameters with descriptive names
input_channels = 320  # Number of input channels for the model
num_heads = 5  # Number of attention heads
head_dim = 64  # Dimension of each attention head
num_layers = 1  # Depth of attention layers
context_dim = 1024  # Dimension of the context embedding
use_linear_projection = True  # Whether to use linear projection in attention

# Initialize the SpatialTransformer model
spatial_transformer = SpatialTransformer(
    in_channels=input_channels,
    n_heads=num_heads,
    d_head=head_dim,
    depth=num_layers,
    context_dim=context_dim,
    use_linear=use_linear_projection,
    use_checkpoint=True,
).to(device)

Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.


In [7]:

# Initialize the CrossAttentionBlock2d model
cross_attention_block = CrossAttentionBlock2d(
    channels=input_channels,
    context_embedding_dim=context_dim,
    context_key="key",  # Key to set the context in cross_attention_block
    num_attention_heads=num_heads,
    num_attention_layers=num_layers,
    num_groups=32,  # Number of groups for grouped attention
    use_bias=False,
    use_linear_projection=use_linear_projection,
).to(device)

In [8]:
# Convert the source model's state dict to match the target model's structure
with open("tests/weights_mapping/cross_attention_block_2d.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=spatial_transformer.state_dict(),
    target_state_dict=cross_attention_block.state_dict(),
    mapping=weight_mapping,
)
cross_attention_block.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [9]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, input_channels, 32, 32).to(device)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim).to(device)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    cross_attention_block.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_source = spatial_transformer.forward(input_tensor, context=context_tensor)
    y_target = cross_attention_block.forward(input_tensor)  # noqa: F821
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(0., device='cuda:0'),
 tensor(573.0254, device='cuda:0'),
 tensor(573.0254, device='cuda:0'))

In [10]:

spatial_transformer_weights = get_state_dict_from_safetensors("./ckpt/spatial_transformer.safetensors")
spatial_transformer.load_state_dict(spatial_transformer_weights)
converted_state_dict = get_converted_state_dict(
    source_state_dict=spatial_transformer.state_dict(),
    target_state_dict=cross_attention_block.state_dict(),  # noqa: F821
    mapping=weight_mapping,
)
cross_attention_block.load_state_dict(converted_state_dict)  # noqa: F821


<All keys matched successfully>

In [11]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, input_channels, 32, 32).to(device)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim).to(device)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    cross_attention_block.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = cross_attention_block.forward(input_tensor)  # noqa: F821
    y_source = spatial_transformer.forward(input_tensor, context=context_tensor)
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(207.9414, device='cuda:0'),
 tensor(540.9109, device='cuda:0'),
 tensor(529.5787, device='cuda:0'))

In [13]:
transformer_block_keys = [
    k for k in spatial_transformer_weights.keys() if "transformer_blocks.0." in k
]
transformer_block_weights = {
    k.replace("transformer_blocks.0.",""): spatial_transformer_weights[k] for k in transformer_block_keys
}
save_file(transformer_block_weights,"./ckpt/transformer_block.safetensors")

In [16]:
input_projection_weight_keys = [
    'norm.bias', 'norm.weight', 'proj_in.bias', 'proj_in.weight'
]
input_projection_weights = {
    key: spatial_transformer.state_dict()[key]
    for key in input_projection_weight_keys
}
save_file(input_projection_weights, "./ckpt/input_projection.safetensors")

In [17]:
output_projection_weight_keys = [
    'proj_out.bias', 'proj_out.weight'
]
output_projection_weights = {
    key: spatial_transformer.state_dict()[key]
    for key in output_projection_weight_keys
}
save_file(output_projection_weights, "./ckpt/output_projection.safetensors")

# Input Projection





In [16]:
import json
import refiners.fluxion.layers as fl
from torch import nn
from utils.weight_mapper import get_converted_state_dict
from einops import rearrange
from refiners.fluxion.context import Contexts

In [17]:
class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(320, 320)

    def forward(self, x):
        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        return x



class SmallModelRefiners(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.GroupNorm(
                channels=320,
                num_groups=32,
                eps=1e-6
            ),
            fl.Flatten(start_dim=2, end_dim=-1),
            fl.Transpose(1, 2),
            fl.Lambda(lambda x: x.contiguous()),
            fl.Linear(
                in_features=320,
                out_features=320,
            ),
        )
    def init_context(self) -> Contexts:
        return {"flatten": {"sizes": []}}

In [18]:
anydoor = SmallModel().to(device)
refiners = SmallModelRefiners().to(device)

In [19]:
mapping = {
    "GroupNorm": "norm",
    "Linear": "proj_in",
}
refiners_state_dict_converted = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(refiners_state_dict_converted)

<All keys matched successfully>

In [20]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
input_tensor = torch.randn(1, input_channels, 32, 32).to(device)  # Example input tensor

with torch.no_grad():
    y_projected_target = anydoor.forward(input_tensor) 
    y_projected_source = refiners.forward(input_tensor)
    
torch.norm(y_projected_target-y_projected_source),torch.norm(y_projected_target),torch.norm(y_projected_source)
    

(tensor(0., device='cuda:0'),
 tensor(330.6661, device='cuda:0'),
 tensor(330.6661, device='cuda:0'))

In [21]:

weights = get_state_dict_from_safetensors("./ckpt/input_projection.safetensors")
anydoor.load_state_dict(weights)
refiners_state_dict_converted = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(refiners_state_dict_converted)


<All keys matched successfully>

In [22]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
input_tensor = torch.randn(1, input_channels, 32, 32).to(device)  # Example input tensor

with torch.no_grad():
    y_projected_target = anydoor.forward(input_tensor) 
    y_projected_source = refiners.forward(input_tensor)
    
torch.norm(y_projected_target-y_projected_source),torch.norm(y_projected_target),torch.norm(y_projected_source)
    

(tensor(0., device='cuda:0'),
 tensor(176.8964, device='cuda:0'),
 tensor(176.8964, device='cuda:0'))

# TransformerBlock




In [56]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.attention import CrossAttentionBlock

In [57]:
from AnyDoor.ldm.modules.attention import BasicTransformerBlock

In [58]:
# Define model configuration parameters with descriptive names
input_channels = 320  # Number of input channels for the model
num_heads = 5  # Number of attention heads
head_dim = 64  # Dimension of each attention head
num_layers = 1  # Depth of attention layers
context_dim = 1024  # Dimension of the context embedding
use_linear_projection = True  # Whether to use linear projection in attention

# Initialize the SpatialTransformer model
anydoor = BasicTransformerBlock(
    dim=input_channels,
    n_heads=num_heads,
    d_head=head_dim,
    context_dim=context_dim,
    disable_self_attn=False,
    checkpoint=True).to(device)

refiners = CrossAttentionBlock(
    embedding_dim=input_channels,
    context_embedding_dim=context_dim,
    context_key="key",
    num_heads=num_heads,
    use_bias=False).to(device)



In [59]:
mapping = {
    "Residual_1.SelfAttention.Linear": "attn1.to_out.0",
    "Residual_2.Attention.Linear": "attn2.to_out.0",
    "Residual_1.LayerNorm": "norm1",
    "Residual_2.LayerNorm": "norm2",
    "Residual_3.LayerNorm": "norm3",
    "Residual_1.SelfAttention.Distribute.Linear_1": "attn1.to_q",
    "Residual_1.SelfAttention.Distribute.Linear_2": "attn1.to_k",
    "Residual_1.SelfAttention.Distribute.Linear_3": "attn1.to_v",
    "Residual_2.Attention.Distribute.Linear_1": "attn2.to_q",
    "Residual_2.Attention.Distribute.Linear_2": "attn2.to_k",
    "Residual_2.Attention.Distribute.Linear_3": "attn2.to_v",
    "Residual_3.Linear_1": "ff.net.0.proj",
    "Residual_3.Linear_2": "ff.net.2",
}

In [60]:

converted_state_dict = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [61]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, 2, input_channels).to(device)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim).to(device)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = refiners.forward(input_tensor)  # noqa: F821
    y_source = anydoor.forward(input_tensor, context=context_tensor)
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(1.9120e-06, device='cuda:0'),
 tensor(27.1595, device='cuda:0'),
 tensor(27.1595, device='cuda:0'))

In [29]:

weights = get_state_dict_from_safetensors("./ckpt/transformer_block.safetensors")
anydoor.load_state_dict(weights)
converted_state_dict = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),  # noqa: F821
    mapping=mapping,
)
refiners.load_state_dict(converted_state_dict)  # noqa: F821


<All keys matched successfully>

In [30]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, 1, input_channels).to(device) * 100  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim).to(device)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = refiners.forward(input_tensor)  # noqa: F821
    y_source = anydoor.forward(input_tensor, context=context_tensor)
    
print(torch.allclose(y_target,y_source,rtol=1e-7,atol=1e-7))
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

True


(tensor(0., device='cuda:0'),
 tensor(1719.6464, device='cuda:0'),
 tensor(1719.6464, device='cuda:0'))

In [45]:
class DerangedModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(320, 320)

    def forward(self, x):
        # x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        return x

model = DerangedModel().to(device)

In [51]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, 2, input_channels).to(device)   # Example input tensor
context_tensor = torch.randn(1, 1, context_dim).to(device)  # Example context tensor

input_tensor_2 = model(torch.randn( 1, input_channels, 32 , 32).to(device))

In [48]:
input_tensor.shape,input_tensor_2.shape

(torch.Size([1, 1, 320]), torch.Size([1, 1024, 320]))

In [None]:

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor.clone()}
    )
    # Forward pass through both models
    y1 = refiners.forward(input_tensor)  # noqa: F821
    y2 = anydoor.forward(input_tensor, context=context_tensor.clone())
    
    print(torch.allclose(y1,y2,rtol=1e-7,atol=1e-7)) # type: ignore
    print(torch.norm(y1-y2),torch.norm(y1),torch.norm(y2))

    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor.clone()}
    )
    # Forward pass through both models
    y1_bis = refiners.forward(y_projected_target)  # noqa: F821
    y2_bis = anydoor.forward(y_projected_target, context=context_tensor.clone())
    
    print(torch.allclose(y1_bis,y2_bis,rtol=1e-7,atol=1e-7))
    print(torch.norm(y1_bis-y2_bis),torch.norm(y1_bis),torch.norm(y2_bis))

False
tensor(3.6385e-06, device='cuda:0') tensor(27.5168, device='cuda:0') tensor(27.5168, device='cuda:0')
True
tensor(0., device='cuda:0') tensor(387.2995, device='cuda:0') tensor(387.2995, device='cuda:0')


In [28]:
from refiners.conversion.model_converter import ModelConverter

model_converter = ModelConverter(source_model=anydoor, target_model=refiners)

In [29]:
with torch.no_grad():
    refiners.set_context( 
        "cross_attention_block", {"key": context_tensor.clone()}
    )
    model_converter.run(source_args=(y_projected_source, context_tensor), target_args=(y_projected_source,))

Stage 0 -> 1 - Models have the same number of basic layers. Finding matching shapes and layers...
Stage 1 -> 2 - Shape of both models agree. Applying state_dict to target model. Comparing models...
Models diverged between attn1.to_v and attn1.to_out.0, and between Residual_1.SelfAttention.Distribute.Linear_3 and Residual_1.SelfAttention.Linear, difference in norm: 9.048733045347035e-05
Models do not agree. Try to increase the threshold or modify the models.
Conversion failed at stage 3


In [64]:
y_projected_target.dtype, y_projected_source.dtype, torch.allclose(y_projected_target,y_projected_source,rtol=1e-12,atol=1e-12)

(torch.float32, torch.float32, True)

In [65]:
y_projected_target.is_contiguous(), y_projected_source.is_contiguous()

(True, True)

In [45]:
for weight in refiners.parameters():
    print(weight.is_contiguous(), weight.dtype)

True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32


In [46]:
for weight in anydoor.parameters():
    print(weight.is_contiguous(), weight.dtype)

True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32


# Attention




In [26]:
from einops import rearrange
from jaxtyping import Float
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention


In [34]:
import math

def scaled_dot_product_attention_non_optimized(
    query: Float[Tensor, "batch source_sequence_length dim"],
    key: Float[Tensor, "batch target_sequence_length dim"],
    value: Float[Tensor, "batch target_sequence_length dim"],
    is_causal: bool = False,
) -> Float[Tensor, "batch source_sequence_length dim"]:
    """Non-optimized Scaled Dot Product Attention.

    See [[arXiv:1706.03762] Attention Is All You Need (Equation 1)](https://arxiv.org/abs/1706.03762) for more details.
    """
    if is_causal:
        # TODO: implement causal attention
        raise NotImplementedError(
            "Causal attention for `scaled_dot_product_attention_non_optimized` is not yet implemented"
        )

    dim = query.shape[-1]
    attention = query @ key.permute(0, 1, 3, 2)
    attention = attention / math.sqrt(dim)
    attention = torch.softmax(input=attention, dim=-1)
    return attention @ value

class ScaledDotProductAttention(torch.nn.Module):

    def __init__(
        self,
        num_heads: int = 1,
        is_causal: bool = False,
        is_optimized: bool = True,
        slice_size: int | None = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.is_causal = is_causal
        self.is_optimized = is_optimized
        self.slice_size = slice_size
        self.dot_product = scaled_dot_product_attention_non_optimized

    def forward(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
    ) -> Float[Tensor, "batch num_queries embedding_dim"]:

        return self._process_attention(
            query=query,
            key=key,
            value=value,
        )


    def _process_attention(
        self,
        query: Float[Tensor, "batch num_queries embedding_dim"],
        key: Float[Tensor, "batch num_keys embedding_dim"],
        value: Float[Tensor, "batch num_values embedding_dim"],
    ) -> Float[Tensor, "batch num_queries embedding_dim"]:
        return self._merge_multi_head(
            x=self.dot_product(
                query=self._split_to_multi_head(query),
                key=self._split_to_multi_head(key),
                value=self._split_to_multi_head(value),
                is_causal=self.is_causal,
            )
        )

    def _split_to_multi_head(
        self,
        x: Float[Tensor, "batch_size sequence_length embedding_dim"],
    ) -> Float[Tensor, "batch_size num_heads sequence_length (embedding_dim//num_heads)"]:
        """Split the input tensor into multiple heads along the embedding dimension.

        See also `merge_multi_head`, which is the inverse operation.
        """
        assert (
            x.ndim == 3
        ), f"Expected input tensor with shape (batch_size sequence_length embedding_dim), got {x.shape}"
        assert (
            x.shape[-1] % self.num_heads == 0
        ), f"Expected embedding_dim (x.shape[-1]={x.shape[-1]}) to be divisible by num_heads ({self.num_heads})"

        return x.reshape(x.shape[0], x.shape[1], self.num_heads, x.shape[-1] // self.num_heads).transpose(1, 2)

    def _merge_multi_head(
        self,
        x: Float[Tensor, "batch_size num_heads sequence_length heads_dim"],
    ) -> Float[Tensor, "batch_size sequence_length heads_dim * num_heads"]:
        """Merge the input tensor from multiple heads along the embedding dimension.

        See also `split_to_multi_head`, which is the inverse operation.
        """
        return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], self.num_heads * x.shape[-1])

In [35]:

class AnyDoorAttentionProduct(torch.nn.Module):

    def __init__(self, heads: int):
        super().__init__()
        self.heads = heads
        self.scale = 64 ** -0.5

    def forward(self, q, k, v):
        h = self.heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        with torch.autocast(enabled=False, device_type = 'cuda'):
            q, k = q.float(), k.float()
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        
        del q, k

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return out

In [36]:
refiners_attention = ScaledDotProductAttention(num_heads=5).to(device)
anydoor_attention = AnyDoorAttentionProduct(heads=5).to(device)

In [38]:
# Define input tensors
dim = 320  
q = torch.randn(1, 10, 320).to(device)  # Example input tensor
k = torch.randn(1, 10, 320).to(device)  # Example input tensor
v = torch.randn(1, 10, 320).to(device)  # Example input tensor

anydoor_result = anydoor_attention(q, k, v)
refiners_result = refiners_attention(q, k, v)

torch.norm(anydoor_result-refiners_result),torch.norm(anydoor_result),torch.norm(refiners_result)
    

(tensor(0., device='cuda:0'),
 tensor(23.2752, device='cuda:0'),
 tensor(23.2752, device='cuda:0'))

# Output Projection





In [None]:
output_projection = nn.Linear(320,320)


with torch.no_grad():
    y_target = output_projection(y_target) 
    y_source = output_projection(y_source)

torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)

In [None]:
weights = get_state_dict_from_safetensors("./ckpt/output_projection.safetensors")
output_projection.load_state_dict(weights)

In [None]:
output_projection = nn.Linear(320,320)


with torch.no_grad():
    y_target = output_projection(y_target) 
    y_source = output_projection(y_source)

torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)

# LDA

In [2]:
import torch
import sys
from omegaconf import OmegaConf

sys.path.append("./AnyDoor/")
from ldm.util import instantiate_from_config

conf = OmegaConf.load("src/anydoor_original/configs/anydoor.yaml")

lda_anydoor = instantiate_from_config(conf.model.params.first_stage_config)


No module 'xformers'. Proceeding without it.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [3]:
lda_anydoor.load_state_dict(torch.load("ckpt/lda_anydoor.ckpt"))

  lda_anydoor.load_state_dict(torch.load("ckpt/lda_anydoor.ckpt"))


<All keys matched successfully>

In [4]:
from anydoor_refiners.model import AnydoorAutoencoder

lda_refiners = AnydoorAutoencoder()

In [6]:
lda_refiners = lda_refiners.load_from_safetensors("ckpt/anydoor_refiners_safetensors/lda.safetensors")

In [7]:
# Print nb of trainable parameters of the two models

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(lda_anydoor))
print(count_parameters(lda_refiners))


83653863
83653863


In [15]:
with torch.no_grad():
    img = torch.randn(1,3,256,256)
    y1 = lda_refiners.forward(img)
    y2 = lda_anydoor.forward(img)

In [19]:
torch.norm(y1 - y2[0])

tensor(207.3743)

In [16]:
y2[0].shape

torch.Size([1, 3, 256, 256])

In [17]:
from refiners.conversion.model_converter import ModelConverter

converter = ModelConverter(source_model=lda_anydoor,target_model=lda_refiners)

In [18]:
with torch.no_grad():
    img = torch.randn(1,3,256,256)
    converter.run(source_args=(img,),target_args=(img,))

Models do not have the same number of basic layers:
  <class 'torch.nn.modules.conv.Conv2d'>: Source 72 - Target 64
  <class 'torch.nn.modules.linear.Linear'>: Source 0 - Target 8
Conversion failed at stage 1


In [7]:
y.shape

torch.Size([1, 3, 256, 256])

# ControlNet




In [5]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.controlnet import ControlNet

In [6]:
x = torch.randn(1, 4, 512, 512).to(device) ##torch.load("./tests/tensors/x.pt",weights_only=True).to(device)
# initial_latents = torch.randn(1,4,32,32).to(device)
timestep = torch.full((1,), 1, dtype=torch.long).to(device)
object_embedding = torch.randn(1, 257, 1024).to(device)#torch.load("./tests/tensors/object_embedding.pt",weights_only=True).to(device)
# negative_object_embedding = torch.randn(1, 257, 1024).to(device) #torch.load("./tests/tensors/negative_object_embedding.pt",weights_only=True).to(device)
# control = [x.to(device) for x in torch.load("./tests/tensors/control_features.pt",weights_only=True)]
# inference_steps = 10
# scale = 5.0

In [7]:
from anydoor_original.control_net import model as controlnet

No module 'xformers'. Proceeding without it.
Model loaded successfully


In [8]:
controlnet.to(device)
True # Just to avoid printing the model

True

In [9]:
with torch.no_grad():
    control = controlnet.forward(torch.zeros(1), x, timestep, object_embedding)

In [10]:

controlnet_refiners = ControlNet(4, device=device)

In [11]:


with open("./tests/weights_mapping/control_net.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=controlnet.state_dict(),
    target_state_dict=controlnet_refiners.state_dict(),
    mapping=weight_mapping,
)
controlnet_refiners.load_state_dict(converted_state_dict)



<All keys matched successfully>

In [12]:
with torch.no_grad():
    controlnet_refiners.set_timestep(timestep)
    controlnet_refiners.set_dinov2_object_embedding(object_embedding)
    control2 = controlnet_refiners(x)

In [14]:
for i,tensor in enumerate(control):
    print(torch.norm(tensor-control2[i]), tensor.shape, control2[i].shape)

tensor(0.) torch.Size([1, 320, 64, 64]) torch.Size([1, 320, 64, 64])
tensor(0.) torch.Size([1, 320, 64, 64]) torch.Size([1, 320, 64, 64])
tensor(0.) torch.Size([1, 320, 64, 64]) torch.Size([1, 320, 64, 64])
tensor(0.) torch.Size([1, 320, 32, 32]) torch.Size([1, 320, 32, 32])
tensor(0.) torch.Size([1, 640, 32, 32]) torch.Size([1, 640, 32, 32])
tensor(0.) torch.Size([1, 640, 32, 32]) torch.Size([1, 640, 32, 32])
tensor(0.) torch.Size([1, 640, 16, 16]) torch.Size([1, 640, 16, 16])
tensor(0.) torch.Size([1, 1280, 16, 16]) torch.Size([1, 1280, 16, 16])
tensor(0.) torch.Size([1, 1280, 16, 16]) torch.Size([1, 1280, 16, 16])
tensor(0.) torch.Size([1, 1280, 8, 8]) torch.Size([1, 1280, 8, 8])
tensor(0.) torch.Size([1, 1280, 8, 8]) torch.Size([1, 1280, 8, 8])
tensor(0.) torch.Size([1, 1280, 8, 8]) torch.Size([1, 1280, 8, 8])
tensor(0.) torch.Size([1, 1280, 8, 8]) torch.Size([1, 1280, 8, 8])


In [15]:
# Save control tensors
torch.save(control2, "tests/tensors/control_features_real_size.pt")

In [None]:
unet_weights = get_state_dict_from_safetensors("./ckpt/unet.safetensors")
cldm.model.diffusion_model.load_state_dict(unet_weights)

NameError: name 'get_state_dict_from_safetensors' is not defined

In [None]:
converted_state_dict = get_converted_state_dict(
    source_state_dict=cldm.model.diffusion_model.state_dict(),
    target_state_dict=unet.state_dict(),
    mapping=weight_mapping,
)
unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [None]:
with torch.no_grad():
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    y1_bis = cldm.apply_model(
        x_noisy = x, 
        t = timestep, 
        cond = cond,
    )

In [None]:
with torch.no_grad():
    unet.set_control_residuals(control)
    unet.set_timestep(timestep)
    unet.set_dinov2_object_embedding(object_embedding)
    y2_bis = unet(x)

In [None]:
torch.norm(y1_bis-y2_bis),torch.norm(y1_bis),torch.norm(y2_bis)

(tensor(0.), tensor(65.7131), tensor(65.7131))

In [None]:
unet

(CHAIN) UNet(in_channels=4)
    ├── (PASS) TimestepEncoder()
    │   ├── UseContext(context=diffusion, key=timestep)
    │   ├── (CHAIN) RangeEncoder(sinusoidal_embedding_dim=320, embedding_dim=1280)
    │   │   ├── Lambda(compute_sinusoidal_embedding(x: jaxtyping.Int[Tensor, '*batch 1']) -> jaxtyping.Float[Tensor, '*batch 1 embedding_dim'])
    │   │   ├── Converter(set_device=False)
    │   │   ├── Linear(in_features=320, out_features=1280, device=cuda:0, dtype=float32) #1
    │   │   ├── SiLU()
    │   │   └── Linear(in_features=1280, out_features=1280, device=cuda:0, dtype=float32) #2
    │   └── SetContext(context=range_adapter, key=timestep_embedding)
    ├── (CHAIN) DownBlocks(in_channels=4)
    │   ├── (CHAIN) #1
    │   │   ├── Conv2d(in_channels=4, out_channels=320, kernel_size=(3, 3), padding=(1, 1), device=cuda:0, dtype=float32)
    │   │   └── (PASS) ResidualAccumulator(n=0)
    │   │       ├── (RES) Residual()
    │   │       │   └── UseContext(context=unet, key=residuals