In [1]:
import os, re
# Restrict this process to a single GPU (index 3 as enumerated by the driver).
# This must be set before torch initializes CUDA, so keep it above `import torch`.
# Inside the process that GPU is then addressed as the default "cuda" device.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
device = torch.device('cuda')
from reg_lora.clip_ti_reg import  CLIPTiTextModel
from transformers import CLIPTokenizer
from safetensors.torch import safe_open
from custom_datasets.utils import parse_templates_class_name

In [2]:
# Load the tokenizer and text encoder from the local Stable Diffusion v1.5 snapshot.
# local_files_only=True forbids any download attempt. The encoder is then moved to `device`.
_SD15_DIR = "models/stable-diffusion-v1-5"
_LOCAL_ONLY = {"local_files_only": True}

tokenizer = CLIPTokenizer.from_pretrained(_SD15_DIR, subfolder="tokenizer", **_LOCAL_ONLY)
clip_encoder = CLIPTiTextModel.from_pretrained(
    _SD15_DIR, subfolder="text_encoder", **_LOCAL_ONLY
).to(device)

In [9]:
# Checkpoint file loaded for each training mode.
# NOTE(review): the "s900"/"s700" suffix presumably encodes the training step the LoRA was saved at.
_LORA_CKPT_BY_MODE = {
    'baseline': "lora_weight_s900.safetensors",
    'textReg': "lora_weight_s900.safetensors",
    'priorReal': "lora_weight_s700.safetensors",
    'priorGen': "lora_weight_s700.safetensors",
}


def calc_shift(lora_ckpt, mode, superclass, additional_prompt=None):
    """Measure how much a UNet LoRA changes the base weights, globally and on text features.

    Args:
        lora_ckpt: directory containing the LoRA ``.safetensors`` checkpoint.
        mode: one of ``'baseline'`` / ``'textReg'`` (loads ``lora_weight_s900``) or
            ``'priorReal'`` / ``'priorGen'`` (loads ``lora_weight_s700``).
        superclass: class noun substituted for ``[class]`` in the prompt templates.
        additional_prompt: optional list of prompt templates containing ``[class]``.
            When ``None`` or empty, the single prompt ``"photo of a {superclass}"`` is used.

    Returns:
        ``(weight_shift, text_shift)`` as 0-d tensors.
        ``weight_shift`` is the L2 norm over all per-layer LoRA deltas ``up @ down``,
        stacked along dim 0. ``text_shift`` is the L2 norm of those deltas applied to
        the prompt-averaged CLIP text feature.

    Raises:
        ValueError: if ``mode`` is unknown, or the checkpoint has no UNet ``:up`` tensors.
    """
    if mode not in _LORA_CKPT_BY_MODE:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(_LORA_CKPT_BY_MODE)}"
        )

    # Build the probe prompts; None or an empty list falls back to one default prompt.
    if additional_prompt:
        data = [prompt.replace('[class]', superclass) for prompt in additional_prompt]
    else:
        data = [f"photo of a {superclass}"]
    print(data)
    input_ids = tokenizer(
        data,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)

    # Inference only: no_grad avoids building an autograd graph through the encoder.
    with torch.no_grad():
        hidden_states = clip_encoder(input_ids)[0]
    # Average over the prompt batch (dim 0), then transpose so the feature dim comes
    # first and the result can be the right operand of each LoRA delta.
    # NOTE(review): assumes output[0] is (batch, seq, dim) and that the LoRA tensors are
    # fp16, which is why this is cast to half. Confirm both.
    text_feature = hidden_states.mean(0).t().half()

    lora_path = os.path.join(lora_ckpt, _LORA_CKPT_BY_MODE[mode])
    # The context manager releases the underlying file once the tensors have been read.
    with safe_open(lora_path, framework="pt", device="cuda") as safeloras:
        lora_dict = {
            key: safeloras.get_tensor(key) for key in safeloras.keys() if "unet" in key
        }

    weight_deltas, text_deltas = [], []
    for key, up in lora_dict.items():
        if ':up' not in key:
            continue
        # Dense low-rank update for this layer: up @ down.
        delta = torch.matmul(up, lora_dict[key.replace(':up', ':down')])
        weight_deltas.append(delta)
        # Effect of that update on the text features.
        text_deltas.append(torch.matmul(delta, text_feature))
    if not weight_deltas:
        raise ValueError(f"No UNet LoRA ':up' tensors found in {lora_path}")

    # Stack all layers along dim 0 and take the 2-norm over every entry.
    weight_shift = torch.norm(torch.cat(weight_deltas), p=2)
    text_shift = torch.norm(torch.cat(text_deltas), p=2)
    return weight_shift, text_shift

def calculate_means(tuple_list):
    """Return the means of the first and second entries across a sequence of tuples.

    Args:
        tuple_list: non-empty sequence of tuples with at least two entries. Entries may be
            numbers or 0-d tensors (anything supporting ``+`` and ``/``). Entries beyond
            index 1 are ignored.

    Returns:
        ``(mean_of_first_entries, mean_of_second_entries)``.

    Raises:
        ValueError: if ``tuple_list`` is empty, since the mean is undefined.
    """
    count = len(tuple_list)
    if count == 0:
        raise ValueError("calculate_means() requires at least one tuple")
    # sum() starts from 0 and adds left to right, so tensors and numbers both work.
    mean_1 = sum(tup[0] for tup in tuple_list) / count
    mean_2 = sum(tup[1] for tup in tuple_list) / count
    return mean_1, mean_2

In [14]:
# Root of the ablation logs. Each run directory's parent directory name is parsed for the
# target's superclass. NOTE(review): layout presumed to be <log_dir>/<target>/<run dir named with a mode>.
log_dir = "logs/log_ablation/image_reg/log_type/1_default"
# Prompt templates used to probe every LoRA; "[class]" becomes the target's superclass.
PROBE_PROMPTS = ['[class] near the beach']
# Single source of truth for the training modes: used for both the result keys and the directory regex.
MODES = ('baseline', 'priorReal', 'priorGen', 'textReg')
MODE_PATTERN = re.compile('|'.join(map(re.escape, MODES)))

# mode -> list of (weight_shift, text_shift) pairs, one per matching run directory.
SHIFT_DICT = {mode_name: [] for mode_name in MODES}
for root, subdirs, _files in sorted(os.walk(log_dir)):
    target_name = os.path.basename(root)
    # Named `run_dir` rather than `dir` so the builtin dir() is not shadowed in the kernel.
    for run_dir in sorted(subdirs):
        match = MODE_PATTERN.search(run_dir)
        if match is None:
            continue
        mode = match.group()
        superclass = parse_templates_class_name(target_name)[-1]
        lora_ckpt = os.path.join(root, run_dir)
        weight_shift, text_shift = calc_shift(
            lora_ckpt, mode, superclass, additional_prompt=PROBE_PROMPTS
        )
        SHIFT_DICT[mode].append((weight_shift, text_shift))

# Report the per-mode average over all runs found for that mode.
for mode_name, shifts in SHIFT_DICT.items():
    if not shifts:
        continue
    mean_weight_shift, mean_text_shift = calculate_means(shifts)
    print(f"Mode '{mode_name}': weight_shift {mean_weight_shift:.2f}, |  text_shift {mean_text_shift:.2f}, ")

['bird near the beach']
['bird near the beach']
['cat near the beach']
['clock near the beach']
['dog near the beach']
['teddybear near the beach']
['tortoise_plushy near the beach']
['wooden_pot near the beach']
Mode 'priorGen': weight_shift 7.76, |  text_shift 158.50, 
