# FlowAdapter M2

## Imports

In [12]:
import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time
import ruamel.yaml as yaml

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ip_adapter.ip_adapter import ImageProjModel
from ip_adapter.utils import is_torch2_available
if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor

from models_eyeformer.model_tracking import TrackingTransformer
from pytorchSoftdtwCuda.soft_dtw_cuda import SoftDTW
from tqdm.auto import tqdm

## Flow Encoders & Projectors

In [3]:

class FlowEncoder_MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_module('encoder', nn.Sequential(
            nn.Linear(45, 128),
            nn.LeakyReLU(0.3),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.3),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.3),
            nn.Linear(512, 1024)))

        self.add_module('decoder', nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.3),
            nn.Linear(128, 45)))

    def forward(self, x):
      x = x.view(x.size(0), -1)
      x = self.encoder(x)
      x = self.decoder(x)
      return x

In [4]:
class FlowEncoder(nn.Module):
    def __init__(self, cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens, 
                 eyeFormer, flow_latenizer):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_embeddings_dim = clip_embeddings_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.eyeFormer = eyeFormer
        self.flow_latenizer = flow_latenizer
    
    def forward(self, images):
        if self.eyeFormer is not None:
            flow_embeds = self.eyeFormer(images)
        else:
            flow_embeds = images
        
        flow_embeds = flow_embeds.view(flow_embeds.size(0), -1)
        flow_embeds = self.flow_latenizer(flow_embeds)
        # print("Final flow_embeds Shape: ", flow_embeds.shape)
        return flow_embeds

In [6]:
class CorrectProjModel(torch.nn.Module):
    """
    Correct the final flow embedding wiht a linear norm and final projection layer
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

## Dataset 

In [None]:
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, json_file, tokenizer, size=256, 
                 t_drop_rate=0.05, i_drop_rate=0.05, 
                 ti_drop_rate=0.05, dataset_name="ueyes"):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.dataset_name = dataset_name

        self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]

        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # self.clip_image_processor = CLIPImageProcessor()
        

    def __getitem__(self, idx):
        item = self.data[idx] 
        text = item["prompt"]
        image_file = item["target"]
        
        # read image and flow vector
        raw_image = Image.open(image_file)
        image = self.transform(raw_image.convert("RGB"))

        # clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
        
        # drop
        drop_flow_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_flow_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_flow_embed = 1
        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids
        
        # Added for flow-Adapter:
        if self.dataset_name == "ueyes": # NOT Implemented yet
            flow_input = item["flow_input"]
        else:
            flow_input = None

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            # "clip_image": clip_image,
            "drop_flow_embed": drop_flow_embed,
            "flow_input": flow_input
        }


    def __len__(self):
        return len(self.data)
    

def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    # clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_flow_embeds = [example["drop_flow_embed"] for example in data]

    return {
        "images": images,
        "text_input_ids": text_input_ids,
        # "clip_images": clip_images,
        "drop_flow_embeds": drop_flow_embeds
    }

## IPAdapter (Custom)

In [None]:
class IPAdapter(torch.nn.Module):
    """Custom variation"""
    def __init__(self, unet, correct_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.correct_proj_model = correct_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, flow_embeds):
        # print("flow_embeds shape: ", flow_embeds.shape)
        # print("encoder_hidden_states shape: ", encoder_hidden_states.shape)
        ip_tokens = self.correct_proj_model(flow_embeds)
        # print("ip_tokens shape: ", ip_tokens.shape)
        
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.correct_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for correct_proj_model and adapter_modules
        self.correct_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.correct_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of correct_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

## Args & Other Configs

In [14]:
class ARGS:
    pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
    pretrained_ip_adapter_path = "./sd-flow_adapter/checkpoint-216000/model.safetensors"
    data_json_file = "/home/researcher/Documents/dataset/original_datasets/webui_prompts.json"
    resolution = 256
    dataloader_num_workers = 8
    dataset_name = "everything_else"

In [15]:
args = ARGS()

## Load Models

In [16]:
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

flowAE = FlowEncoder_MLP()
flow_latenizer = flowAE.encoder

if args.dataset_name == "ueyes":
    flowAE.load_state_dict(torch.load("/home/researcher/flowAE.pth")) # TODO: Put the right path here
    eyeFormer = None
else: 
    config = yaml.load(open("./configs/Tracking.yaml", 'r'), Loader=yaml.Loader)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
    eyeFormer = TrackingTransformer(config = config, init_deit=False)
    checkpointEF = torch.load("/home/researcher/Documents/aryan/asciProject/flowEncoder/weights/checkpoint_19.pth",
                            map_location='cpu')
    state_dict = checkpointEF['model']
    eyeFormer.load_state_dict(state_dict)
    eyeFormer.requires_grad_(False)
    
flow_latenizer.requires_grad_(False)
        

AttributeError: 'ARGS' object has no attribute 'output_dir'

In [20]:
!pip install safetensors
from safetensors import safe_open


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [21]:
with safe_open(args.pretrained_ip_adapter_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        if "correct_proj_model" in key:
            print(key)
        elif 


correct_proj_model.norm.bias
correct_proj_model.norm.weight
correct_proj_model.proj.bias
correct_proj_model.proj.weight
unet.conv_in.bias
unet.conv_in.weight
unet.conv_norm_out.bias
unet.conv_norm_out.weight
unet.conv_out.bias
unet.conv_out.weight
unet.down_blocks.0.attentions.0.norm.bias
unet.down_blocks.0.attentions.0.norm.weight
unet.down_blocks.0.attentions.0.proj_in.bias
unet.down_blocks.0.attentions.0.proj_in.weight
unet.down_blocks.0.attentions.0.proj_out.bias
unet.down_blocks.0.attentions.0.proj_out.weight
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight
unet.down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight
unet.down_blocks.0.attentions.0.tra