- LDM trainer and sampler for triplanes

In [1]:
import os

os.getcwd()

'/root/dev/jjuke_diffusion'

# Unet for LDM

Neural backbone $ \epsilon_\theta( \mathbf{z}_t, t, \mathbf{y}) $ of the LDM

- $ \mathbf{y} $: Only handling text here!

In [2]:
from jjuke_diffusion.unet_cond.unet import UNetModel

# Conditioning Mechanism

Domain specific encoder $ \tau_\theta $

- Input:
    - $ \mathcal{y} $
- Output:
    - Intermediate representation $ \tau_\theta(\mathbf{y}) \in \mathbb{R}^{M \times d_\tau} $


## Domain specific encoder with BERT

In [3]:
from functools import partial

import torch
from torch import nn
from einops import rearrange, repeat
from transformers import CLIPProcessor, CLIPModel
import kornia

from jjuke_diffusion.unet_ldm.transformer import Encoder, TransformerWrapper

In [4]:
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

In [5]:
class BERTTokenizer(AbstractEncoder):
    """ Uses pre-trained BERT tokenizer from huggingface """
    def __init__(self, max_len=77, device="cuda"):
        super().__init__()

        from transformers import BertTokenizerFast
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        self.max_len = max_len
        self.device = device


    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_len,
                                        return_length=True, return_overflowing_tokens=False,
                                        padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        return self(text)

    def decode(self, text):
        return text

In [6]:
class BERTEmbedder(AbstractEncoder):
    """ Uses BERT tokenizer model and add some transformer encoder layers """
    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
                 use_tokenizer=True, emb_dropout=0.0, device="cuda"):
        super().__init__()

        self.use_tokenizer = use_tokenizer
        if self.use_tokenizer:
            self.tokenize = BERTTokenizer(max_len=max_seq_len)

        encoder = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=encoder, emb_dropout=emb_dropout
        )


    def forward(self, text):
        if self.use_tokenizer:
            tokens = self.tokenize(text) # .to(self.device)
        else:
            tokens = text

        z = self.transformer(tokens, return_embeddings=True)
        return z


    def encode(self, text):
        return self(text) # output length: 77

In [8]:
text_encoder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()

In [9]:
text_example = ["a photo of a cat", "a photo of dog"]
text_encoder(text_example).shape

torch.Size([2, 77, 1280])

## Domain specific encoder with CLIP

In [1]:
from PIL import Image
import requests

from transformers import CLIPProcessor, CLIPModel
from torchvision.transforms import ToTensor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

print(ToTensor()(img).shape)

  from .autonotebook import tqdm as notebook_tqdm
(…)t-large-patch14/resolve/main/config.json: 100%|█| 4.52
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
model.safetensors: 100%|█| 1.71G/1.71G [01:34<00:00, 18.1
(…)14/resolve/main/preprocessor_config.json: 100%|█| 316/
(…)tch14/resolve/main/tokenizer_config.json: 100%|█| 905/
(…)it-large-patch14/resolve/main/vocab.json: 100%|█| 961k
(…)it-large-patch14/resolve/main/merges.txt: 100%|█| 525k
(…)arge-patch14/resolve/main/tokenizer.json: 100%|█| 2.22
(…)h14/resolve/main/special_tokens_map.json: 100%|█| 389/


torch.Size([3, 480, 640])


In [11]:
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=img, return_tensors="pt", padding=True)

In [12]:
print("keys of inputs: ", inputs.keys())
print("shape of input_ids: ", inputs["input_ids"].shape)
print("shape of attention_mask: ", inputs["attention_mask"].shape)
print("shape of pixel_values: ", inputs["pixel_values"].shape)

keys of inputs:  dict_keys(['input_ids', 'attention_mask', 'pixel_values'])
shape of input_ids:  torch.Size([2, 7])
shape of attention_mask:  torch.Size([2, 7])
shape of pixel_values:  torch.Size([1, 3, 224, 224])


In [13]:
outputs = model(**inputs)

In [14]:
print(outputs.__class__)
print("keys of outputs: ", vars(outputs).keys())
print("shape of loss: ", vars(outputs)["loss"])
print("shape of logits_per_image: ", vars(outputs)["logits_per_image"].shape)
print("shape of logits_per_text: ", vars(outputs)["logits_per_text"].shape)
print("shape of image_embeds: ", vars(outputs)["image_embeds"].shape)
print("shape of text_embeds: ", vars(outputs)["text_embeds"].shape)
print("class of model output: ", vars(outputs)["vision_model_output"].__class__)

<class 'transformers.models.clip.modeling_clip.CLIPOutput'>
keys of outputs:  dict_keys(['loss', 'logits_per_image', 'logits_per_text', 'text_embeds', 'image_embeds', 'text_model_output', 'vision_model_output'])
shape of loss:  None
shape of logits_per_image:  torch.Size([1, 2])
shape of logits_per_text:  torch.Size([2, 1])
shape of image_embeds:  torch.Size([1, 768])
shape of text_embeds:  torch.Size([2, 768])
class of model output:  <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'>


In [15]:
print("Output of CLIP vision model: ")
print("last_hidden_state: ", vars(vars(outputs)["vision_model_output"])["last_hidden_state"].shape)
print("pooler_output: ", vars(vars(outputs)["vision_model_output"])["last_hidden_state"].shape)
print("hidden_states: ", vars(vars(outputs)["vision_model_output"])["hidden_states"])
print("attentions: ", vars(vars(outputs)["vision_model_output"])["attentions"])

print("\nOutput of CLIP text model: ")
print("last_hidden_state: ", vars(vars(outputs)["text_model_output"])["last_hidden_state"].shape)
print("pooler_output: ", vars(vars(outputs)["text_model_output"])["last_hidden_state"].shape)
print("hidden_states: ", vars(vars(outputs)["text_model_output"])["hidden_states"])
print("attentions: ", vars(vars(outputs)["text_model_output"])["attentions"])

Output of CLIP vision model: 
last_hidden_state:  torch.Size([1, 257, 1024])
pooler_output:  torch.Size([1, 257, 1024])
hidden_states:  None
attentions:  None

Output of CLIP text model: 
last_hidden_state:  torch.Size([2, 7, 768])
pooler_output:  torch.Size([2, 7, 768])
hidden_states:  None
attentions:  None


In [16]:
inputs_wo_img = processor(text=["a photo of a cat", "a photo of a dog"], padding=True)

In [17]:
print("keys of inputs: ", inputs_wo_img.keys())
print("input_ids: ", inputs_wo_img["input_ids"])
print("attention_mask: ", inputs_wo_img["attention_mask"])
# print("shape of pixel_values: ", inputs_wo_img["pixel_values"].shape) → error

keys of inputs:  dict_keys(['input_ids', 'attention_mask'])
input_ids:  [[49406, 320, 1125, 539, 320, 2368, 49407], [49406, 320, 1125, 539, 320, 1929, 49407]]
attention_mask:  [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]


In [18]:
from transformers import CLIPTokenizerFast, CLIPTextModel

tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

text_inputs = ["a photo of a cat", "a photo of a dog"]
max_len = 77

batch_encoding = tokenizer(text_inputs, truncation=True, max_length=max_len,
                           return_length=True, return_overflowing_tokens=False,
                           padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]

print(tokens)
print(tokens.shape)

tensor([[49406,   320,  1125,   539,   320,  2368, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407],
        [49406,   320,  1125,   539,   320,  1929, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 

In [19]:
from transformers import CLIPTokenizerFast, CLIPTextModel
tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

text = "a photo of a cat"


In [20]:
class CLIPTokenizer(AbstractEncoder):
    """ Uses pre-trained CLIP tokenizer from huggingface """
    def __init__(self, max_len=77, device="cuda"):
        super().__init__()

        from transformers import CLIPTokenizerFast
        self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
        self.max_len = max_len
        self.device = device

    
    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_len, 
                                        return_length=True, return_overflowing_tokens=False,
                                        padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        return tokens


    @torch.no_grad()
    def encode(self, text):
        return self(text)


    def decode(self, text):
        return text

In [21]:
from transformers import CLIPTextConfig, CLIPTextModel

clip_text_model = CLIPTextModel(CLIPTextConfig(vocab_size=49408, max_position_embeddings=77, hidden_size=1280, intermediate_size=1280, projection_dim=1280, num_hidden_layers=32, attention_dropout=0.0)).cuda()

In [22]:
def count_params(model):
    return sum(p.numel() for p in model.parameters())

In [23]:
last_hidden_state, pooler_output = clip_text_model(CLIPTokenizer(max_len=77)(text_example), return_dict=False)

In [24]:
last_hidden_state.shape # same to output of BERT Embedder (LDM)

torch.Size([2, 77, 1280])

In [25]:
pooler_output.shape

torch.Size([2, 1280])

In [26]:
class CLIPEmbedder(AbstractEncoder):
    """ Uses CLIP tokenizer and add some transformer encoder layers """
    def __init__(self, n_embed, n_layer, vocab_size=49408, max_seq_len=77,
                 use_tokenizer=True, emb_dropout=0.0, device="cuda"):
        super().__init__()

        self.use_tokenizer = use_tokenizer
        if self.use_tokenizer:
            self.tokenize = CLIPTokenizer(max_len=max_seq_len)

        encoder = Encoder(dim=n_embed, depth=n_layer)
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=encoder, emb_dropout=emb_dropout
        )


    def forward(self, text):
        if self.use_tokenizer:
            tokens = self.tokenize(text) # .to(self.device)
        else:
            tokens = text

        z = self.transformer(tokens, return_embeddings=True)
        return z


    def decode(self, text):
        return self(text) # output length: 77


class CLIPEmbedderHF(AbstractEncoder):
    """ Uses CLIP tokenizer and CLIP Text Model for the text encoder """
    def __init__(self, n_embed, n_layer, vocab_size=49408, max_seq_len=77,
                 use_tokenizer=True, emb_dropout=0.0, device="cuda"):
        super().__init__()

        from transformers import CLIPTextConfig, CLIPTextModel

        self.use_tokenizer = use_tokenizer
        if self.use_tokenizer:
            self.tokenize = CLIPTokenizer(max_len=max_seq_len)

        self.transformer = CLIPTextModel(CLIPTextConfig(
            vocab_size=vocab_size, max_position_embeddings=max_seq_len, hidden_size=n_embed,
            intermediate_size=n_embed, projection_dim=n_embed, num_hidden_layers=n_layer,
            attention_dropout=emb_dropout
        ))


    def forward(self, text):
        if self.use_tokenizer:
            tokens = self.tokenize(text) # .to(self.device)
        else:
            tokens = text

        z, _ = self.transformer(tokens, return_dict=False)
        return z


    def decode(self, text):
        return self(text) # output length: 77

In [27]:
# compare number of parameters of CLIPEmbedder, CLIPEmbedderHF, and BERTEmbedder

bert_embedder = BERTEmbedder(1280, 32).cuda()
clip_embedder = CLIPEmbedder(1280, 32).cuda()
clip_embedder_hf = CLIPEmbedderHF(1280, 32).cuda()
print("Number of params of BERTEmbedder: ", count_params(bert_embedder.transformer))
print("Number of params of CLIPEmbedder: ", count_params(clip_embedder.transformer))
print("Number of params of CLIPEmbedderHF: ", count_params(clip_embedder_hf.transformer))

Number of params of BERTEmbedder:  581994042
Number of params of CLIPEmbedder:  630361088
Number of params of CLIPEmbedderHF:  378325760


In [28]:
bert_embedder(text_example).shape

torch.Size([2, 77, 1280])

In [29]:
clip_embedder(text_example).shape

torch.Size([2, 77, 1280])

In [30]:
clip_embedder_hf(text_example).shape

torch.Size([2, 77, 1280])

# LDMTrainer

$$ \mathcal{L}_\text{LDM} := \mathbb{E}_{\mathbf{y}, \epsilon \sim \mathcal{N}(0, 1), t} \left[ \left\Vert \epsilon - \epsilon_\theta(\mathbf{z}_t, t, \tau_\theta(\mathbf{y})) \right\Vert_2^2 \right] $$

- Input:
    - x: Feature map
    - c: Condition
    - ...
- Output:
    - Losses dictinoary

In [31]:
from jjuke_diffusion.diffusion.common import get_betas
from jjuke_diffusion.diffusion.ddpm import DDPMTrainer
from jjuke_diffusion.unet_cond.unet import UNetModel

In [32]:
# DDPM Trainer args
args_model_mean_type = "eps"
args_model_var_type = "fixed_small"
args_loss_type = "l2"

betas = get_betas("linear", 1000)

In [45]:
trainer = DDPMTrainer(
    betas,
    model_mean_type=args_model_mean_type,
    model_var_type=args_model_var_type,
    loss_type=args_loss_type,
    clip_denoised=False # TODO: check if it is True
).cuda()

In [46]:
model = UNetModel(
    unet_dim=2,
    in_channels=96,
    out_channels=96,
    model_channels=320,
    attention_resolutions=[16, 8, 4], # downsampling factor = spatial resolution (h, w of feature maps) / CA resoliutions (32, 16, 8 → from LDM paper)
    channel_mult=[1, 2, 4, 4],
    num_heads=8,
    use_spatial_transformer=True,
    transformer_depth=1,
    context_dim=1280,
    attention_type="xformers",
).cuda()
# lr: 1e-4
# iterations: 390,000
# batch_size: 680

OutOfMemoryError: CUDA out of memory. Tried to allocate 114.00 MiB (GPU 0; 23.62 GiB total capacity; 21.67 GiB already allocated; 111.75 MiB free; 22.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
import torch
from einops import rearrange, repeat

In [None]:
# Feature maps (3 triplanes)
x = torch.rand(2, 128, 128, 96).cuda() # (b, h, w, c)
t = torch.randint(0, 1000, (2,)).cuda()

# Conditions
cond = bert_embedder(text_example).cuda()
print(cond.shape)

In [None]:
# get_input function of LDM official code -> get_input(batch, "caption").to(self.device)
x = rearrange(x, "b h w c -> b c h w").contiguous().float()
cond = rearrange(cond, "b n c -> b c n").contiguous().float()

In [None]:
print("Shape of feature maps: ", x.shape)
print("Shape of text embedding: ", cond.shape)

In [None]:
model(x, t, cond).shape

In [None]:
feature_maps = torch.rand(2, 128, 128, 96).cuda()

In [None]:
losses = trainer(model, x)

In [None]:
losses

# LDMSampler

In [44]:
from jjuke_diffusion.diffusion.ddim import DDIMSampler
from jjuke_diffusion.diffusion.karras import KarrasSampler

In [None]:
# args for samplers
args_n_sampler_steps = 50
args_eta = 0.

args_model_mean_type = "eps"
args_model_var_type = "fixed_small"
args_loss_type = "l2"

In [None]:
ddim_sampler = DDIMSampler(
    betas,
    ddim_num_timesteps=50,
    
)