In [None]:
import torch
from tqdm import tqdm
from safetensors import safe_open
from safetensors.torch import save_file
import sys
sys.path.append("./AnyDoor/")
sys.path.append("./")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def get_state_dict_from_safetensors(path:str):
    """ Load a state dict from a safetensors file """
    tensors = {}
    with safe_open(path, framework="pt", device=device) as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors

# Sampling


In [7]:
x = torch.load("./tests/tensors/x.pt",weights_only=True)
initial_latents = torch.randn(1,4,32,32)
object_embedding = torch.load("./tests/tensors/object_embedding.pt",weights_only=True)
negative_object_embedding = torch.load("./tests/tensors/negative_object_embedding.pt",weights_only=True)
control = torch.load("./tests/tensors/control_features.pt",weights_only=True)
inference_steps = 10
scale = 5.0

In [None]:
from anydoor_original.cldm import model as cldm
from cldm.ddim_hacked import DDIMSampler
sampler = DDIMSampler(cldm)

In [12]:
with torch.no_grad():
    mocked_control_image = torch.zeros(1, 4, 32, 32)
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    un_cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [negative_object_embedding],
    }

    samples, intermediates = sampler.sample(
        S=inference_steps,
        batch_size=1,
        shape=(4, 32, 32),
        conditioning=cond,
        x_T=initial_latents.clone(),
        verbose=False,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond,
    )

Data shape for DDIM sampling is (1, 4, 32, 32), eta 0.0
Running DDIM Sampling with 10 timesteps


DDIM Sampler:  10%|█         | 1/10 [00:01<00:17,  1.98s/it]

tensor(68.5324)


DDIM Sampler:  20%|██        | 2/10 [00:03<00:15,  1.98s/it]

tensor(71.9554)


DDIM Sampler:  30%|███       | 3/10 [00:05<00:13,  1.98s/it]

tensor(74.8372)


DDIM Sampler:  40%|████      | 4/10 [00:07<00:11,  1.96s/it]

tensor(77.4812)


DDIM Sampler:  50%|█████     | 5/10 [00:09<00:09,  1.97s/it]

tensor(81.0923)


DDIM Sampler:  60%|██████    | 6/10 [00:11<00:08,  2.01s/it]

tensor(85.3483)


DDIM Sampler:  70%|███████   | 7/10 [00:13<00:05,  1.98s/it]

tensor(88.4471)


DDIM Sampler:  80%|████████  | 8/10 [00:15<00:03,  1.99s/it]

tensor(89.3949)


DDIM Sampler:  90%|█████████ | 9/10 [00:18<00:02,  2.04s/it]

tensor(87.8062)


DDIM Sampler: 100%|██████████| 10/10 [00:21<00:00,  2.20s/it]

tensor(87.7215)





In [None]:
from anydoor_refiners.model import AnyDoor,solver_params
from refiners.foundationals.latent_diffusion.solvers import DDIM
from tests.mocks import DINOv2EncoderMock,ControlNetMock,AnydoorAutoencoderMock

refiners_model = AnyDoor(
    lda=AnydoorAutoencoderMock(),
    object_encoder=DINOv2EncoderMock(object_embedding,negative_object_embedding),
    control_model=ControlNetMock(control),
    solver=DDIM(inference_steps,params=solver_params)
)
# refiners_model.unet.load_state_dict(weights_refiners)


In [None]:
import json
from utils.weight_mapper import get_converted_state_dict

with open("../tests/weights_mapping/unet.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=sampler.model.model.diffusion_model.state_dict(),
    target_state_dict=refiners_model.unet.state_dict(),
    mapping=weight_mapping,
)
refiners_model.unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [None]:
with torch.no_grad():

    y = initial_latents.clone()
    conditionning = refiners_model.compute_conditionning(object=torch.zeros(1),background=torch.zeros(1))
    for s in tqdm(refiners_model.steps):
        y = refiners_model.forward(y,step=s,conditionning=conditionning,condition_scale=scale)
        print(torch.norm(y))

 10%|█         | 1/10 [00:01<00:09,  1.06s/it]

tensor(67.1199)


 20%|██        | 2/10 [00:02<00:08,  1.06s/it]

tensor(70.5397)


 30%|███       | 3/10 [00:03<00:07,  1.09s/it]

tensor(73.1566)


 40%|████      | 4/10 [00:04<00:06,  1.11s/it]

tensor(74.3970)


 50%|█████     | 5/10 [00:05<00:05,  1.09s/it]

tensor(74.0968)


 60%|██████    | 6/10 [00:07<00:05,  1.32s/it]

tensor(72.4916)


 70%|███████   | 7/10 [00:09<00:04,  1.56s/it]

tensor(69.9599)


 80%|████████  | 8/10 [00:11<00:03,  1.76s/it]

tensor(66.9328)


 90%|█████████ | 9/10 [00:13<00:01,  1.85s/it]

tensor(61.8021)


100%|██████████| 10/10 [00:15<00:00,  1.54s/it]

tensor(61.4083)





In [None]:
torch.norm(y-samples[-1]),torch.norm(y),torch.norm(samples[-1])

(tensor(58.3949), tensor(61.4083), tensor(86.1132))

### Using Anydoor weights

In [None]:
weights_anydoor = get_state_dict_from_safetensors("./ckpt/unet.safetensors")
sampler.model.model.diffusion_model.load_state_dict(weights_anydoor)

In [None]:
converted_state_dict = get_converted_state_dict(
    source_state_dict=sampler.model.model.diffusion_model.state_dict(),
    target_state_dict=refiners_model.unet.state_dict(),
    mapping=weight_mapping,
)
refiners_model.unet.load_state_dict(converted_state_dict)

In [None]:
with torch.no_grad():
    mocked_control_image = torch.zeros(1, 4, 32, 32)
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    un_cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [negative_object_embedding],
    }

    samples, intermediates = sampler.sample(
        S=inference_steps,
        batch_size=1,
        shape=(4, 32, 32),
        conditioning=cond,
        x_T=initial_latents.clone(),
        verbose=False,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond,
    )

In [None]:
with torch.no_grad():

    y = initial_latents.clone()
    conditionning = refiners_model.compute_conditionning(object=torch.zeros(1),background=torch.zeros(1))
    for s in tqdm(refiners_model.steps):
        y = refiners_model.forward(y,step=s,conditionning=conditionning,condition_scale=scale)
        print(torch.norm(y))

# Unet




In [3]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.unet import UNet

In [4]:
x = torch.load("./tests/tensors/x.pt",weights_only=True)
initial_latents = torch.randn(1,4,32,32)
timestep = torch.full((1,), 960, dtype=torch.long)
object_embedding = torch.load("./tests/tensors/object_embedding.pt",weights_only=True)
negative_object_embedding = torch.load("./tests/tensors/negative_object_embedding.pt",weights_only=True)
control = torch.load("./tests/tensors/control_features.pt",weights_only=True)
inference_steps = 10
scale = 5.0

In [5]:
from anydoor_original.cldm import model as cldm

No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 865.91 M params.
Loaded model config from [./src/anydoor_original/configs/anydoor.yaml]
Model loaded successfully


In [6]:
with torch.no_grad():
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    y1 = cldm.apply_model(
        x_noisy = x, 
        t = timestep, 
        cond = cond,
    )

In [7]:

unet = UNet(4)

In [8]:


with open("./tests/weights_mapping/unet.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=cldm.model.diffusion_model.state_dict(),
    target_state_dict=unet.state_dict(),
    mapping=weight_mapping,
)
unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [9]:
with torch.no_grad():
    unet.set_control_residuals(control)
    unet.set_timestep(timestep)
    unet.set_dinov2_object_embedding(object_embedding)
    y2 = unet(x)

In [11]:
torch.norm(y1-y2),torch.norm(y1),torch.norm(y2)

(tensor(0.), tensor(0.), tensor(0.))

In [12]:
unet_weights = get_state_dict_from_safetensors("./ckpt/unet.safetensors")
cldm.model.diffusion_model.load_state_dict(unet_weights)

<All keys matched successfully>

In [13]:
converted_state_dict = get_converted_state_dict(
    source_state_dict=cldm.model.diffusion_model.state_dict(),
    target_state_dict=unet.state_dict(),
    mapping=weight_mapping,
)
unet.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [14]:
with torch.no_grad():
    cond = {
        "c_concat": control, ## Not used
        "c_crossattn": [object_embedding],
    }
    y1_bis = cldm.apply_model(
        x_noisy = x, 
        t = timestep, 
        cond = cond,
    )

In [15]:
with torch.no_grad():
    unet.set_control_residuals(control)
    unet.set_timestep(timestep)
    unet.set_dinov2_object_embedding(object_embedding)
    y2_bis = unet(x)

In [16]:
torch.norm(y1_bis-y2_bis),torch.norm(y1_bis),torch.norm(y2_bis)

(tensor(2.3982e-05), tensor(65.5668), tensor(65.5668))

# SpatialTransformer




In [5]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.attention import CrossAttentionBlock2d

In [6]:
from AnyDoor.ldm.modules.attention import SpatialTransformer

In [7]:
# Define model configuration parameters with descriptive names
input_channels = 320  # Number of input channels for the model
num_heads = 5  # Number of attention heads
head_dim = 64  # Dimension of each attention head
num_layers = 1  # Depth of attention layers
context_dim = 1024  # Dimension of the context embedding
use_linear_projection = True  # Whether to use linear projection in attention

# Initialize the SpatialTransformer model
spatial_transformer = SpatialTransformer(
    in_channels=input_channels,
    n_heads=num_heads,
    d_head=head_dim,
    depth=num_layers,
    context_dim=context_dim,
    use_linear=use_linear_projection,
)

In [8]:

# Initialize the CrossAttentionBlock2d model
cross_attention_block = CrossAttentionBlock2d(
    channels=input_channels,
    context_embedding_dim=context_dim,
    context_key="key",  # Key to set the context in cross_attention_block
    num_attention_heads=num_heads,
    num_attention_layers=num_layers,
    num_groups=32,  # Number of groups for grouped attention
    use_bias=False,
    use_linear_projection=use_linear_projection,
)

In [9]:
# Convert the source model's state dict to match the target model's structure
with open("tests/weights_mapping/cross_attention_block_2d.json", "r") as f:
    weight_mapping = json.load(f)
converted_state_dict = get_converted_state_dict(
    source_state_dict=spatial_transformer.state_dict(),
    target_state_dict=cross_attention_block.state_dict(),
    mapping=weight_mapping,
)
cross_attention_block.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [10]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, input_channels, 32, 32)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    cross_attention_block.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = cross_attention_block.forward(input_tensor)  # noqa: F821
    y_source = spatial_transformer.forward(input_tensor, context=context_tensor)
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(0.), tensor(570.7159), tensor(570.7162))

In [11]:

spatial_transformer_weights = get_state_dict_from_safetensors("./ckpt/spatial_transformer.safetensors")
spatial_transformer.load_state_dict(spatial_transformer_weights)
converted_state_dict = get_converted_state_dict(
    source_state_dict=spatial_transformer.state_dict(),
    target_state_dict=cross_attention_block.state_dict(),  # noqa: F821
    mapping=weight_mapping,
)
cross_attention_block.load_state_dict(converted_state_dict)  # noqa: F821


<All keys matched successfully>

In [12]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, input_channels, 32, 32)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    cross_attention_block.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = cross_attention_block.forward(input_tensor)  # noqa: F821
    y_source = spatial_transformer.forward(input_tensor, context=context_tensor)
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(0.0002), tensor(528.3328), tensor(528.3333))

In [16]:
transformer_block_keys = [
    k for k in spatial_transformer_weights.keys() if "transformer_blocks.0." in k
]
transformer_block_weights = {
    k.replace("transformer_blocks.0.",""): spatial_transformer_weights[k] for k in transformer_block_keys
}
save_file(transformer_block_weights,"./ckpt/transformer_block.safetensors")

In [17]:
input_projection_weight_keys = [
    'norm.bias', 'norm.weight', 'proj_in.bias', 'proj_in.weight'
]
input_projection_weights = {
    key: spatial_transformer.state_dict()[key]
    for key in input_projection_weight_keys
}
save_file(input_projection_weights, "./ckpt/input_projection.safetensors")

In [18]:
output_projection_weight_keys = [
    'proj_out.bias', 'proj_out.weight'
]
output_projection_weights = {
    key: spatial_transformer.state_dict()[key]
    for key in output_projection_weight_keys
}
save_file(output_projection_weights, "./ckpt/output_projection.safetensors")

# Input Projection





In [3]:
import json
import refiners.fluxion.layers as fl
from torch import nn
from utils.weight_mapper import get_converted_state_dict
from einops import rearrange
from refiners.fluxion.context import Contexts

In [4]:
class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=320, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(320, 320)

    def forward(self, x):
        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        return x


class StatefulFlatten(fl.Chain):
    def __init__(
        self, context: str, key: str, start_dim: int = 0, end_dim: int = -1
    ) -> None:
        self.start_dim = start_dim
        self.end_dim = end_dim

        super().__init__(
            fl.SetContext(context=context, key=key, callback=self.push),
            fl.Flatten(start_dim=start_dim, end_dim=end_dim),
        )

    def push(self, sizes: list[torch.Size], x: torch.Tensor) -> None:
        sizes.append(
            x.shape[
                slice(
                    self.start_dim,
                    (
                        self.end_dim + 1
                        if self.end_dim >= 0
                        else x.ndim + self.end_dim + 1
                    ),
                )
            ]
        )


class SmallModelRefiners(fl.Chain):
    def __init__(self):
        super().__init__(
            fl.GroupNorm(
                channels=320,
                num_groups=32,
                eps=1e-6
            ),
            StatefulFlatten(context="flatten", key="sizes", start_dim=2),
            fl.Transpose(1, 2),
            fl.Lambda(lambda x: x.contiguous()),
            fl.Linear(
                in_features=320,
                out_features=320,
            ),
        )
    def init_context(self) -> Contexts:
        return {"flatten": {"sizes": []}}

In [5]:
anydoor = SmallModel()
refiners = SmallModelRefiners()

In [6]:
mapping = {
    "GroupNorm": "norm",
    "Linear": "proj_in",
}
refiners_state_dict_converted = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(refiners_state_dict_converted)

<All keys matched successfully>

In [7]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
input_tensor = torch.randn(1, input_channels, 32, 32)  # Example input tensor

with torch.no_grad():
    y_projected_target = anydoor.forward(input_tensor) 
    y_projected_source = refiners.forward(input_tensor)
    
torch.norm(y_projected_target-y_projected_source),torch.norm(y_projected_target),torch.norm(y_projected_source)
    

(tensor(0.), tensor(331.8626), tensor(331.8626))

In [8]:

weights = get_state_dict_from_safetensors("./ckpt/input_projection.safetensors")
anydoor.load_state_dict(weights)
refiners_state_dict_converted = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(refiners_state_dict_converted)


<All keys matched successfully>

In [9]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
input_tensor = torch.randn(1, input_channels, 32, 32)  # Example input tensor

with torch.no_grad():
    y_projected_target = anydoor.forward(input_tensor) 
    y_projected_source = refiners.forward(input_tensor)
    
torch.norm(y_projected_target-y_projected_source),torch.norm(y_projected_target),torch.norm(y_projected_source)
    

(tensor(0.), tensor(177.1698), tensor(177.1698))

# TransformerBlock




In [10]:
import json
from utils.weight_mapper import get_converted_state_dict
from anydoor_refiners.attention import CrossAttentionBlock

In [11]:
from AnyDoor.ldm.modules.attention import BasicTransformerBlock

In [12]:
# Define model configuration parameters with descriptive names
input_channels = 320  # Number of input channels for the model
num_heads = 5  # Number of attention heads
head_dim = 64  # Dimension of each attention head
num_layers = 1  # Depth of attention layers
context_dim = 1024  # Dimension of the context embedding
use_linear_projection = True  # Whether to use linear projection in attention

# Initialize the SpatialTransformer model
anydoor = BasicTransformerBlock(
    dim=input_channels,
    n_heads=num_heads,
    d_head=head_dim,
    context_dim=context_dim,
    disable_self_attn=False,
    checkpoint=True)

refiners = CrossAttentionBlock(
    embedding_dim=input_channels,
    context_embedding_dim=context_dim,
    context_key="key",
    num_heads=num_heads,
    use_bias=False)



In [13]:
mapping = {
    "Residual_1.SelfAttention.Linear": "attn1.to_out.0",
    "Residual_2.Attention.Linear": "attn2.to_out.0",
    "Residual_1.LayerNorm": "norm1",
    "Residual_2.LayerNorm": "norm2",
    "Residual_3.LayerNorm": "norm3",
    "Residual_1.SelfAttention.Distribute.Linear_1": "attn1.to_q",
    "Residual_1.SelfAttention.Distribute.Linear_2": "attn1.to_k",
    "Residual_1.SelfAttention.Distribute.Linear_3": "attn1.to_v",
    "Residual_2.Attention.Distribute.Linear_1": "attn2.to_q",
    "Residual_2.Attention.Distribute.Linear_2": "attn2.to_k",
    "Residual_2.Attention.Distribute.Linear_3": "attn2.to_v",
    "Residual_3.Linear_1": "ff.net.0.proj",
    "Residual_3.Linear_2": "ff.net.2",
}

In [14]:

converted_state_dict = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),
    mapping=mapping,
)
refiners.load_state_dict(converted_state_dict)

<All keys matched successfully>

In [15]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, 1, input_channels)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = refiners.forward(input_tensor)  # noqa: F821
    y_source = anydoor.forward(input_tensor, context=context_tensor)
    
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

(tensor(0.), tensor(21.7846), tensor(21.7846))

In [16]:

weights = get_state_dict_from_safetensors("./ckpt/transformer_block.safetensors")
anydoor.load_state_dict(weights)
converted_state_dict = get_converted_state_dict(
    source_state_dict=anydoor.state_dict(),
    target_state_dict=refiners.state_dict(),  # noqa: F821
    mapping=mapping,
)
refiners.load_state_dict(converted_state_dict)  # noqa: F821


<All keys matched successfully>

In [17]:
# Define input tensors
input_channels = 320  # Must match the model's input channel configuration
context_dim = 1024  # Must match the model's context dimension configuration
input_tensor = torch.randn(1, 1, input_channels)  # Example input tensor
context_tensor = torch.randn(1, 1, context_dim)  # Example context tensor

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_target = refiners.forward(input_tensor)  # noqa: F821
    y_source = anydoor.forward(input_tensor, context=context_tensor)
    
print(torch.allclose(y_target,y_source,rtol=1e-7,atol=1e-7))
torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)
    

True


(tensor(0.), tensor(18.9913), tensor(18.9913))

In [18]:
context_tensor = torch.randn(1, 1, context_dim)  # Example context tensor

In [19]:
y_projected_target.dtype, y_projected_source.dtype, torch.allclose(y_projected_target,y_projected_source,rtol=1e-12,atol=1e-12)

(torch.float32, torch.float32, True)

In [21]:
y_projected_target.is_contiguous(), y_projected_source.is_contiguous()

(True, True)

In [23]:
for weight in refiners.parameters():
    print(weight.is_contiguous(), weight.dtype)

True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32


In [24]:
for weight in anydoor.parameters():
    print(weight.is_contiguous(), weight.dtype)

True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32
True torch.float32


In [20]:

with torch.no_grad():
    # Set the context for the CrossAttentionBlock2d model
    refiners.set_context(  # noqa: F821
        "cross_attention_block", {"key": context_tensor}
    )
    # Forward pass through both models
    y_transformed_target = refiners.forward(y_projected_target)  # noqa: F821
    y_transformed_source = anydoor.forward(y_projected_source, context=context_tensor)
    
print(torch.allclose(y_transformed_target,y_transformed_source,rtol=1e-7,atol=1e-7))
torch.norm(y_transformed_target-y_transformed_source),torch.norm(y_transformed_target),torch.norm(y_transformed_source)

False


(tensor(0.0002), tensor(225.7923), tensor(225.7923))

# Output Projection





In [None]:
output_projection = nn.Linear(320,320)


with torch.no_grad():
    y_target = output_projection(y_target) 
    y_source = output_projection(y_source)

torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)

In [None]:
weights = get_state_dict_from_safetensors("./ckpt/output_projection.safetensors")
output_projection.load_state_dict(weights)

In [None]:
output_projection = nn.Linear(320,320)


with torch.no_grad():
    y_target = output_projection(y_target) 
    y_source = output_projection(y_source)

torch.norm(y_target-y_source),torch.norm(y_target),torch.norm(y_source)

# LDA

In [2]:
import torch
import sys
from omegaconf import OmegaConf

sys.path.append("./AnyDoor/")
from ldm.util import instantiate_from_config

conf = OmegaConf.load("src/anydoor_original/configs/anydoor.yaml")

lda_anydoor = instantiate_from_config(conf.model.params.first_stage_config)


No module 'xformers'. Proceeding without it.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [3]:
lda_anydoor.load_state_dict(torch.load("ckpt/lda_anydoor.ckpt"))

  lda_anydoor.load_state_dict(torch.load("ckpt/lda_anydoor.ckpt"))


<All keys matched successfully>

In [4]:
from anydoor_refiners.model import AnydoorAutoencoder

lda_refiners = AnydoorAutoencoder()

In [6]:
lda_refiners = lda_refiners.load_from_safetensors("ckpt/anydoor_refiners_safetensors/lda.safetensors")

In [7]:
# Print nb of trainable parameters of the two models

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(lda_anydoor))
print(count_parameters(lda_refiners))


83653863
83653863


In [15]:
with torch.no_grad():
    img = torch.randn(1,3,256,256)
    y1 = lda_refiners.forward(img)
    y2 = lda_anydoor.forward(img)

In [19]:
torch.norm(y1 - y2[0])

tensor(207.3743)

In [16]:
y2[0].shape

torch.Size([1, 3, 256, 256])

In [17]:
from refiners.conversion.model_converter import ModelConverter

converter = ModelConverter(source_model=lda_anydoor,target_model=lda_refiners)

In [18]:
with torch.no_grad():
    img = torch.randn(1,3,256,256)
    converter.run(source_args=(img,),target_args=(img,))

Models do not have the same number of basic layers:
  <class 'torch.nn.modules.conv.Conv2d'>: Source 72 - Target 64
  <class 'torch.nn.modules.linear.Linear'>: Source 0 - Target 8
Conversion failed at stage 1


In [7]:
y.shape

torch.Size([1, 3, 256, 256])