# FlowAdapter M2

## Imports

In [87]:
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import ruamel.yaml as yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DDIMScheduler
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy
from PIL import Image

from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ip_adapter.ip_adapter import ImageProjModel
from ip_adapter.utils import is_torch2_available
if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor

from models_eyeformer.model_tracking import TrackingTransformer
from pytorchSoftdtwCuda.soft_dtw_cuda import SoftDTW
from tqdm.auto import tqdm

## Flow Encoders & Projectors

In [24]:

class FlowEncoder_MLP(nn.Module):
    """MLP auto-encoder over flattened 45-dimensional inputs.

    ``encoder`` widens 45 -> 128 -> 256 -> 512 -> 1024 and ``decoder`` mirrors
    it back down to 45. Every linear layer except the last one of each half is
    followed by ``LeakyReLU(0.3)``.
    """

    def __init__(self):
        super().__init__()
        widths = (45, 128, 256, 512, 1024)
        # Build the encoder first, then the decoder, so parameters are created
        # (and randomly initialised) in the same order as the layer listing.
        self.encoder = self._stack(widths)
        self.decoder = self._stack(widths[::-1])

    @staticmethod
    def _stack(widths):
        """Return Linear layers for consecutive widths, LeakyReLU in between."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.LeakyReLU(0.3))
        # Drop the trailing activation: the last layer of each half is linear.
        return nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Flatten everything after the batch axis, then encode and decode."""
        flat = x.view(x.size(0), -1)
        return self.decoder(self.encoder(flat))

In [25]:
class FlowEncoder(nn.Module):
    """Turn images (or precomputed flow inputs) into flow latents.

    Args:
        cross_attention_dim: stored on the instance; not used by ``forward``.
        clip_embeddings_dim: stored on the instance; not used by ``forward``.
        clip_extra_context_tokens: stored on the instance; not used by
            ``forward``.
        eyeFormer: optional module applied to the input first. When ``None``
            the input itself is treated as the flow embedding.
        flow_latenizer: module applied to the flattened flow embedding.
    """

    def __init__(self, cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens,
                 eyeFormer, flow_latenizer):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_embeddings_dim = clip_embeddings_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.eyeFormer = eyeFormer
        self.flow_latenizer = flow_latenizer

    def forward(self, images):
        """Return ``flow_latenizer`` applied to the per-sample flattened flow."""
        raw_flow = images if self.eyeFormer is None else self.eyeFormer(images)
        # Collapse every axis after the batch axis into one feature vector.
        flat_flow = raw_flow.view(raw_flow.size(0), -1)
        return self.flow_latenizer(flat_flow)

In [26]:
class CorrectProjModel(torch.nn.Module):
    """
    Project a flow embedding into extra context tokens and layer-normalise them.

    A single linear layer maps ``clip_embeddings_dim`` features to
    ``clip_extra_context_tokens * cross_attention_dim`` features; the result is
    reshaped to one row per token and passed through ``LayerNorm``.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        flat_token_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, flat_token_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped (-1, num_tokens, cross_attention_dim)."""
        projected = self.proj(image_embeds)
        tokens = projected.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)

## Dataset 

In [27]:
class MyDataset(torch.utils.data.Dataset):
    """Prompt/image dataset with random condition dropping.

    Args:
        json_file: path to a JSON file holding a list of dicts. Each entry is
            read for ``"prompt"`` and ``"target"`` (an image path), plus
            ``"flow_input"`` when ``dataset_name == "ueyes"``.
        tokenizer: callable tokenizer exposing ``model_max_length``; its
            result must expose ``input_ids``.
        size: target edge length for the resize and centre crop.
        t_drop_rate: probability of blanking only the text prompt.
        i_drop_rate: probability of flagging only the flow embedding as dropped.
        ti_drop_rate: probability of doing both.
        dataset_name: selects whether ``"flow_input"`` is read from the entry.
    """

    def __init__(self, json_file, tokenizer, size=256, 
                 t_drop_rate=0.05, i_drop_rate=0.05, 
                 ti_drop_rate=0.05, dataset_name="ueyes"):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.dataset_name = dataset_name

        # Close the annotation file as soon as it is parsed instead of leaking
        # the handle. Entries look like {"prompt": ..., "target": ...}.
        with open(json_file) as annotation_file:
            self.data = json.load(annotation_file)

        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        """Return the transformed image, token ids, drop flag and flow input."""
        item = self.data[idx]
        text = item["prompt"]
        image_file = item["target"]

        # Read the image inside a context manager so the underlying file is
        # released once the RGB copy has been made and transformed.
        with Image.open(image_file) as raw_image:
            image = self.transform(raw_image.convert("RGB"))

        # clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # One random draw decides which conditions to drop: the three
        # probability bands are laid out back to back on [0, 1).
        drop_flow_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_flow_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_flow_embed = 1

        # Tokenize the (possibly blanked) prompt, padded to the model maximum.
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        # Added for flow-Adapter:
        if self.dataset_name == "ueyes": # NOT Implemented yet
            flow_input = item["flow_input"]
        else:
            flow_input = None

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            # "clip_image": clip_image,
            "drop_flow_embed": drop_flow_embed,
            "flow_input": flow_input
        }

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.data)
    

def collate_fn(data):
    """Merge dataset samples into one batch dictionary.

    Only the ``"image"``, ``"text_input_ids"`` and ``"drop_flow_embed"`` keys
    of each sample are read. Images are stacked along a new leading axis,
    token ids are concatenated along axis 0, and the drop flags stay a plain
    Python list.
    """
    image_tensors, token_id_tensors, drop_flags = [], [], []
    for sample in data:
        image_tensors.append(sample["image"])
        token_id_tensors.append(sample["text_input_ids"])
        drop_flags.append(sample["drop_flow_embed"])

    batch = {}
    batch["images"] = torch.stack(image_tensors)
    batch["text_input_ids"] = torch.cat(token_id_tensors, dim=0)
    batch["drop_flow_embeds"] = drop_flags
    return batch

## IPAdapter (Custom)

In [98]:
def get_generator(seed, device):
    """Build seeded ``torch.Generator`` object(s) on ``device``.

    Returns ``None`` when ``seed`` is ``None``, a list of generators (one per
    entry) when ``seed`` is a list, and a single generator otherwise.
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(one_seed) for one_seed in seed]
    return torch.Generator(device).manual_seed(seed)

from typing import List
from diffusers.pipelines.controlnet import MultiControlNetModel
from ip_adapter.attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
class IPAdapter(torch.nn.Module):
    """Custom IP-Adapter variation conditioned on flow embeddings.

    Wraps a diffusers pipeline: installs IP attention processors on
    ``pipe.unet``, projects embeddings into extra context tokens with
    ``correct_proj_model`` and appends those tokens to the text conditioning.

    Args:
        pipe: pipeline exposing ``unet``, ``encode_prompt`` and ``__call__``.
        correct_proj_model: module mapping embeddings to context tokens.
        ckpt_path: optional checkpoint holding ``"flow_proj"`` and
            ``"ip_adapter"`` state dicts; loaded during construction.
        num_tokens: number of extra context tokens per sample.
    """

    def __init__(self, pipe, correct_proj_model, ckpt_path=None, num_tokens=4):
        super().__init__()
        self.num_tokens = num_tokens
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = pipe.to(self.device)
        self.correct_proj_model = correct_proj_model.to(self.device)
        # set_ip_adapter now returns the installed processors, so this is a
        # real ModuleList rather than None.
        self.adapter_modules = self.set_ip_adapter()

        if ckpt_path is not None:
            self.load_ip_adapter(ckpt_path)

    @staticmethod
    def _hidden_size_for(name, block_out_channels):
        """Return the channel width of the UNet block that owns processor ``name``."""
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            return list(reversed(block_out_channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            return block_out_channels[block_id]
        # Fail loudly instead of reusing a width left over from a previous name.
        raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")

    def set_ip_adapter(self):
        """Install attention processors on the UNet (and ControlNet, if any).

        Returns:
            ``torch.nn.ModuleList`` of every processor now set on the UNet, in
            the order of ``unet.attn_processors``.
        """
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if cross_attention_dim is None:
                # No cross-attention here, so the plain processor is kept.
                attn_procs[name] = AttnProcessor()
                continue
            attn_procs[name] = IPAttnProcessor(
                hidden_size=self._hidden_size_for(name, unet.config.block_out_channels),
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=self.num_tokens,
            ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

        if hasattr(self.pipe, "controlnet"):
            if isinstance(self.pipe.controlnet, MultiControlNetModel):
                for controlnet in self.pipe.controlnet.nets:
                    controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
            else:
                self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))

        return torch.nn.ModuleList(unet.attn_processors.values())

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, flow_embeds):
        """Predict the noise residual with flow tokens appended to the context.

        ``flow_embeds`` is projected by ``correct_proj_model`` and concatenated
        to ``encoder_hidden_states`` along dim 1 before calling the UNet.
        """
        ip_tokens = self.correct_proj_model(flow_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # The UNet lives on the wrapped pipeline; there is no `self.unet`.
        noise_pred = self.pipe.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_ip_adapter(self, ckpt_path):
        """Load projector and attention-processor weights from ``ckpt_path``."""
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        self.correct_proj_model.load_state_dict(state_dict["flow_proj"])
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"])
        print(f"Successfully loaded IP Adapter from checkpoint {ckpt_path}")

    def load_from_checkpoint(self, ckpt_path: str):
        """Strictly load weights and assert that both parameter sums changed."""
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.correct_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for correct_proj_model and adapter_modules
        self.correct_proj_model.load_state_dict(state_dict["flow_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.correct_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of correct_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

    def _proj_dtype(self):
        """Return the dtype of the projector's parameters (float16 if it has none)."""
        first_param = next(self.correct_proj_model.parameters(), None)
        return torch.float16 if first_param is None else first_param.dtype

    # TODO: Fix this with EyeFormer and the flow_latenizer stuff instead of CLIP
    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        """Return ``(conditional, unconditional)`` context tokens.

        The unconditional tokens are the projection of an all-zero embedding
        with the same shape as the conditional input.
        """
        # The projector is only moved to the device in __init__, never cast, so
        # feed it its own parameter dtype instead of a hard-coded float16.
        proj_dtype = self._proj_dtype()
        if pil_image is not None:
            # NOTE(review): `clip_image_processor` and `image_encoder` are not
            # assigned in __init__; this branch needs them set beforehand.
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
            clip_image_embeds = clip_image_embeds.to(dtype=proj_dtype)
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=proj_dtype)
        image_prompt_embeds = self.correct_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.correct_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set ``scale`` on every IP attention processor of the UNet."""
        for attn_processor in self.pipe.unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    # TODO: Fix this with EyeFormer and the flow_latenizer stuff instead of CLIP
    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        negative_prompt=None,
        scale=1.0,
        num_samples=4,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs):
        """Run the pipeline with image/flow tokens appended to the prompt.

        Args:
            pil_image: a PIL image or list of images; takes precedence over
                ``clip_image_embeds`` when given.
            clip_image_embeds: precomputed embeddings, one row per prompt.
            prompt / negative_prompt: string or list; defaults are filled in
                when ``None`` and strings are repeated once per input.
            scale: weight set on the IP attention processors.
            num_samples: images generated per input.
            seed: int, list of ints or ``None``, passed to ``get_generator``.
            guidance_scale, num_inference_steps, **kwargs: forwarded to the
                pipeline call.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        self.set_scale(scale)

        if pil_image is not None:
            num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality, user interfaces"
        if negative_prompt is None:
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )
        # Duplicate each input's tokens num_samples times along the batch axis.
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)

        with torch.inference_mode():
            prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
                prompt,
                device=self.device,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )
            # Match the text embeddings' dtype so the concatenated context is
            # uniform even when the projector runs in a different precision.
            image_prompt_embeds = image_prompt_embeds.to(dtype=prompt_embeds_.dtype)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(dtype=negative_prompt_embeds_.dtype)
            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            **kwargs,
        ).images

        return images

## Args & Other Configs

In [89]:
class ARGS:
    """Static settings for this notebook, used in place of parsed CLI arguments."""

    # Base model and trained adapter weights
    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
    pretrained_ip_adapter = "./sd-flow_adapter/custom_model_checkpoint-112000.pt"
    pretrained_flow_latenizer = "./sd-flow_adapter/checkpoint-112000/pytorch_model_1.bin"

    # Data source and loading
    data_json_file = "/home/researcher/Documents/dataset/original_datasets/webui_prompts.json"
    dataset_name = "everything_else"
    dataloader_num_workers = 8
    resolution = 256

    # Where outputs are written
    output_dir = "./test_flowAdapter"

In [90]:
# Instantiate the settings holder; later cells read its attributes via `args`.
args = ARGS()

## Load Models

In [92]:
# Load each Stable Diffusion component from its subfolder of the base checkpoint.
noise_scheduler = DDPMScheduler.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="scheduler",
)
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
)
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="vae",
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="unet",
)

# attn_procs = {}
# unet_sd = unet.state_dict()
# for name in unet.attn_processors.keys():
#     cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
#     if name.startswith("mid_block"):
#         hidden_size = unet.config.block_out_channels[-1]
#     elif name.startswith("up_blocks"):
#         block_id = int(name[len("up_blocks.")])
#         hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
#     elif name.startswith("down_blocks"):
#         block_id = int(name[len("down_blocks.")])
#         hidden_size = unet.config.block_out_channels[block_id]
#     if cross_attention_dim is None:
#         attn_procs[name] = AttnProcessor()
#     else:
#         layer_name = name.split(".processor")[0]
#         weights = {
#             "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
#             "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
#         }
#         attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, 
#                                            cross_attention_dim=cross_attention_dim)
#         attn_procs[name].load_state_dict(weights)
# unet.set_attn_processor(attn_procs)
# adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

# Only the encoder half of the flow auto-encoder is used as the flow latenizer.
flowAE = FlowEncoder_MLP()
flow_latenizer = flowAE.encoder

if args.dataset_name == "ueyes":
    # Flow inputs come with the dataset, so no EyeFormer is needed.
    flowAE.load_state_dict(torch.load("/home/researcher/flowAE.pth")) # TODO: Put the right path here
    eyeFormer = None
else: 
    # Use context managers so the config is closed after reading and the copy
    # written to the output directory is flushed to disk.
    # NOTE(review): yaml.Loader can construct arbitrary objects; only load trusted configs.
    with open("./configs/Tracking.yaml", 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)
    eyeFormer = TrackingTransformer(config = config, init_deit=False)
    checkpointEF = torch.load("/home/researcher/Documents/aryan/asciProject/flowEncoder/weights/checkpoint_19.pth",
                            map_location='cpu')
    state_dict = checkpointEF['model']
    eyeFormer.load_state_dict(state_dict)
    # EyeFormer is used frozen.
    eyeFormer.requires_grad_(False)
    
# The latenizer is frozen as well.
flow_latenizer.requires_grad_(False)
        

Model will generate 16 points


Sequential(
  (0): Linear(in_features=45, out_features=128, bias=True)
  (1): LeakyReLU(negative_slope=0.3)
  (2): Linear(in_features=128, out_features=256, bias=True)
  (3): LeakyReLU(negative_slope=0.3)
  (4): Linear(in_features=256, out_features=512, bias=True)
  (5): LeakyReLU(negative_slope=0.3)
  (6): Linear(in_features=512, out_features=1024, bias=True)
)

In [93]:
# Checkpoint keys carry a "flow_latenizer." prefix (see the printed keys);
# strip it so they match the bare Sequential held in `flow_latenizer`.
f = torch.load(args.pretrained_flow_latenizer)
print(f.keys())
fix_flow_latenizer = {
    name.replace("flow_latenizer.", ""): tensor
    for name, tensor in f.items()
}
flow_latenizer.load_state_dict(fix_flow_latenizer)

# Projector from 1024-d flow latents to 4 context tokens of the UNet's
# cross-attention width.
flow_projection_model = CorrectProjModel(
    clip_extra_context_tokens=4,
    clip_embeddings_dim=1024,
    cross_attention_dim=unet.config.cross_attention_dim,
)
# model_ckpt = torch.load("./sd-flow_adapter/custom_model_checkpoint-112000.pt")


odict_keys(['flow_latenizer.0.weight', 'flow_latenizer.0.bias', 'flow_latenizer.2.weight', 'flow_latenizer.2.bias', 'flow_latenizer.4.weight', 'flow_latenizer.4.bias', 'flow_latenizer.6.weight', 'flow_latenizer.6.bias'])


In [94]:
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row by row into one ``rows`` x ``cols`` RGB image.

    Every tile is placed on a grid sized by the first image's ``size``;
    ``len(imgs)`` must equal ``rows * cols``.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols*tile_w, rows*tile_h))

    for position, tile in enumerate(imgs):
        # Row-major placement: position -> (row, column).
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))
    return canvas

# Inference-time replacements: a DDIM scheduler and the ft-mse VAE in half precision.
ddim_settings = dict(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
noise_scheduler = DDIMScheduler(**ddim_settings)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(dtype=torch.float16)

In [95]:
# Build the half-precision SD pipeline around the scheduler and VAE created
# above; the safety checker and feature extractor are switched off.
pipeline_overrides = {
    "torch_dtype": torch.float16,
    "scheduler": noise_scheduler,
    "vae": vae,
    "feature_extractor": None,
    "safety_checker": None,
}
pipe = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    **pipeline_overrides,
)

Loading pipeline components...: 100%|██████████| 5/5 [00:01<00:00,  4.18it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [99]:
# Wrap the pipeline with the flow adapter and load its trained weights.
ip_adapter = IPAdapter(
    pipe=pipe,
    correct_proj_model=flow_projection_model,
    ckpt_path=args.pretrained_ip_adapter,
)

Successfully loaded IP Adapter from checkpoint ./sd-flow_adapter/custom_model_checkpoint-112000.pt


In [None]:
# TODO: Provide either a hand-crafted flow prompt or a UI image whose flow
# should be used to generate new designs.

# TODO - A: Handcraft a flow vector
flow_vector = torch.randn(1, 45).to(torch.float16)

# TODO - B: Use a UI image to extract flow vector
image = Image.open("assets/images/woman.png")
# Image.resize returns a new image and leaves the original untouched, so
# rebind the name; otherwise the resized copy is discarded.
image = image.resize((256, 256))
# Keep the image as the cell's last expression so the notebook still shows it.
image

In [None]:
# Generate variations of the input image and show them side by side in one row.
variation_count = 4
images = ip_adapter.generate(
    pil_image=image,
    num_samples=variation_count,
    num_inference_steps=50,
    seed=42,
)
grid = image_grid(images, rows=1, cols=variation_count)
grid