The goal of this notebook is to explore how time (timestep) information is used in the UNet of a ControlNet pipeline.

---

In [1]:
import torch
import torch.nn as nn

In [2]:
device = "cpu"
device_dtype = torch.float32

In [3]:
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

In [4]:
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=device_dtype)

In [5]:
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=device_dtype
).to(device)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [7]:
unet = pipe.unet

In [24]:
def public_attrs(o, contains=''):
    return [a for a in dir(o) if not a.startswith('_') and contains in a]

In [26]:
public_attrs(unet, contains='time')

['time_embed_act', 'time_embedding', 'time_proj']

In [32]:
unet.time_embedding

TimestepEmbedding(
  (linear_1): Linear(in_features=320, out_features=1280, bias=True)
  (act): SiLU()
  (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
)

In [28]:
unet.time_proj

Timesteps()

In [30]:
unet.time_embed_act??

Type:        NoneType
String form: None
Docstring:   <no docstring>

**Q1**: I assumed `time_proj` to be a learned projection module, but it is of type `Timesteps`, and its repr shows no submodules. Could it actually be the embedding, and not the projection?

**Q2**: I assumed `time_embedding` to be the time embedding (i.e. a deterministic, non-parametric function), but it is a module with linear layers. Is it actually the projection?

**Q3**: Are the names just swapped, i.e. does diffusers call 'projection' what I understand as 'embedding', and vice versa?

In [34]:
unet.time_proj??

Signature:       unet.time_proj(*args, **kwargs)
Type:            Timesteps
String form:     Timesteps()
File:            ~/Documents/GitHub/diffusers/diffusers/src/diffusers/models/embeddings.py
Source:
class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels =

`Timesteps` (the type of `time_proj`) has a forward method that only applies `get_timestep_embedding`. Let's look at that.

In [36]:
from diffusers.models.embeddings import get_timestep_embedding

In [37]:
get_timestep_embedding??

Signature:
get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
)
Source:
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim

**A1: Yes,** `time_proj` is actually the embedding! Weird naming, but okay.
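To make "embedding as a deterministic, non-parametric function" concrete, here is a minimal sketch of a generic sinusoidal timestep embedding in plain PyTorch. The name `sinusoidal_embedding` is hypothetical, and this is not the diffusers source: details such as the sin/cos ordering and the frequency shift controlled by `flip_sin_to_cos` and `downscale_freq_shift` are deliberately left out, so the exact values may differ from what `time_proj` returns.

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Map integer timesteps of shape (N,) to fixed features of shape (N, dim).

    Hypothetical helper for illustration; assumes `dim` is even.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1 / max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].float() * freqs[None, :]
    # No learnable weights anywhere: the output depends only on the inputs.
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.tensor([0, 10, 999])
emb = sinusoidal_embedding(t, dim=320)
print(emb.shape)
```

The width of 320 is chosen to match the `in_features=320` of `linear_1` printed earlier. Calling the function twice with the same timesteps gives identical results, which is the property I had in mind when I said "embedding".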

___

Now, let's look at `TimestepEmbedding`, which is the type of `time_embedding`.

In [39]:
from diffusers.models.embeddings import TimestepEmbedding

In [44]:
TimestepEmbedding.forward??

Signature: TimestepEmbedding.forward(self, sample, condition=None)
Docstring:
[0;31mDocstring:[0m
Defines the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
Source:
    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample

**A2: Yes,** `time_embedding` is what I understand as a projection: a learned MLP (`Linear` → `SiLU` → `Linear`) applied to the output of `time_proj`.

**A3: Yes, the names for these concepts are swapped!**
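As a sanity check on this conclusion, the sketch below rebuilds the learned part with the layer sizes printed for `unet.time_embedding` (320 → 1280 → 1280) and counts its parameters. `TimeMLP` and `n_params` are hypothetical names, the optional `condition` input of the real `forward` is ignored, and random features stand in for the fixed features that `time_proj` would produce.

```python
import torch
import torch.nn as nn

class TimeMLP(nn.Module):
    """Hypothetical stand-in mirroring the printed TimestepEmbedding layers."""

    def __init__(self, in_dim: int = 320, hidden_dim: int = 1280):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        # Learned projection of the fixed time features.
        return self.linear_2(self.act(self.linear_1(sample)))

def n_params(module: nn.Module) -> int:
    """Total number of parameters registered on a module."""
    return sum(p.numel() for p in module.parameters())

mlp = TimeMLP()
features = torch.randn(2, 320)  # placeholder for the output of the fixed embedding
out = mlp(features)
print(out.shape, n_params(mlp))
```

With the pipeline loaded, the same helper could be applied as `n_params(unet.time_proj)` and `n_params(unet.time_embedding)`. If `Timesteps` really only wraps `get_timestep_embedding`, as observed above, the first count should be zero while the second is not.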