In [None]:
import cv2
import torch
import random
import pickle
import logging
import statistics
import numpy as np
from tqdm import tqdm
import src.clip as CLIP
from torchinfo import summary
import matplotlib.pyplot as plt
import torchvision.transforms as T
import albumentations as alb
from accelerate import Accelerator
from yacs.config import CfgNode as CN
from main import get_config, init_accelerator, set_seed, FFPP
from src.models import Detector
# logging.basicConfig(level="DEBUG")


class Obj:
    """Empty placeholder class; instances can carry arbitrary attributes."""


@torch.no_grad()
def extract_features(encoder, clip):
    """Run ``encoder`` on ``clip`` and post-process the per-layer attributes.

    ``clip`` is moved to the "cuda" device and passed to ``encoder`` together
    with ``with_out=True`` and ``with_q=True``. The result is treated as a
    sequence of dict-like objects (one per layer, per the original comment).
    Every entry is modified in place:

    * index 0 along dimension 1 is dropped (the original comment describes
      this as discarding the CLS token -- TODO confirm against the encoder),
    * the tensor is moved to the CPU,
    * every entry except the one keyed ``"out"`` has its last two
      dimensions flattened into one.

    Returns the same (mutated) object that ``encoder`` returned.
    """
    layer_attrs = encoder(clip.to("cuda"), with_out=True, with_q=True)

    for attrs in layer_attrs:
        # snapshot the keys so the dict can be reassigned while looping
        for name in list(attrs.keys()):
            trimmed = attrs[name][:, 1:].to("cpu")
            attrs[name] = trimmed if name == "out" else trimmed.flatten(-2)

    # release cached GPU memory held by the allocator
    torch.cuda.empty_cache()
    return layer_attrs


def var(x, patch_num, *args, **kargs):
    """Return a ``(patch_num, patch_num, 1)`` map of feature variance.

    The variance of ``x`` is taken along dimension 0 (the temporal dimension
    according to the original comment), averaged over the last dimension,
    and laid out on a ``patch_num`` x ``patch_num`` grid with a trailing
    singleton axis. Extra positional/keyword arguments are ignored.
    """
    temporal_variance = torch.var(x, dim=0)
    per_patch = temporal_variance.mean(dim=-1)
    grid = per_patch.view((patch_num, patch_num))
    return grid.unsqueeze(-1)


def max_stdev(x, patch_num, *args, **kargs):
    """Return a ``(patch_num, patch_num, 1)`` map of the peak absolute z-score.

    ``x`` is standardised along dimension 0 (subtract the mean, divide by the
    square root of the variance), the absolute value is averaged over the
    last dimension, and the maximum along dimension 0 is kept for every
    patch. Extra positional/keyword arguments are ignored.

    NOTE(review): a feature with zero variance along dimension 0 produces a
    division by zero here -- confirm whether that can occur in practice.
    """
    centre = torch.mean(x, dim=0)
    spread = torch.sqrt(torch.var(x, dim=0))
    abs_z = torch.abs((x - centre) / spread)
    per_step = abs_z.mean(dim=-1)
    # torch.max along a dimension yields (values, indices); keep the values
    peak = torch.max(per_step, dim=0).values
    return peak.view(patch_num, patch_num).unsqueeze(-1)


def one_patch_cos_sim(x, t, c, patch_num, *args, **kargs):
    """Cosine similarity of every feature in ``x`` to the one at ``x[t, c]``.

    The similarity is computed along the last dimension, reshaped to
    ``(-1, patch_num, patch_num)``, and then rearranged so that the leading
    axis of the reshaped result is tiled side by side along the column axis.
    The returned tensor has shape ``(patch_num, N * patch_num, 1)`` where
    ``N`` is the size of that leading axis. Extra positional/keyword
    arguments are ignored.
    """
    anchor = x[t, c]
    similarity = torch.nn.functional.cosine_similarity(x, anchor, dim=-1)
    grids = similarity.view((-1, patch_num, patch_num))
    # (N, rows, cols) -> (rows, N, cols) -> (rows, N * cols)
    tiled = grids.permute(1, 0, 2).flatten(1, 2)
    return tiled.unsqueeze(-1)


def semantic_patch_cos_sim(x, patch_num, part, _s, _l, semantic_patches, s=None, *args, **kargs):
    """Cosine similarity of every feature in ``x`` to a stored semantic patch.

    Parameters:
        x: feature tensor; similarity is computed along its last dimension.
        patch_num: side length of the patch grid used to reshape the result.
        part: key of the semantic part to compare against (e.g. "lips").
        _s: subject key of the feature ``x``; used when ``s`` is None.
        _l: layer index of the feature ``x``.
        semantic_patches: nested container indexed as
            ``semantic_patches[subject][part][layer]``.
        s: optional subject key; when given it overrides ``_s``.

    Returns:
        Tensor of shape ``(patch_num, N * patch_num, 1)`` where ``N`` is the
        leading axis obtained after reshaping to ``(-1, patch_num, patch_num)``.

    Extra positional/keyword arguments are ignored.
    """
    # identity test: `== None` would invoke a custom __eq__ on the argument
    subject = _s if s is None else s
    reference = semantic_patches[subject][part][_l]

    similarity = torch.nn.functional.cosine_similarity(x, reference, dim=-1)
    return (
        similarity
        .view((-1, patch_num, patch_num))
        .permute(1, 0, 2)
        .flatten(1, 2)
        .unsqueeze(-1)
    )


def plotter(features, title="", mode="subject-layer", num_layers=16, unit_size=3, font_size=12):
    """Render the per-subject, per-layer maps in ``features`` with matplotlib.

    ``features`` maps a subject key to a sequence of images (one per layer).

    * ``mode="subject-layer"``: a single figure with one row per subject and
      one column per layer.
    * ``mode="layer-frame"``: one figure per subject with one row per layer.

    Any other ``mode`` raises ``NotImplementedError``.

    NOTE(review): the ``num_layers`` argument is overwritten by the length of
    the first subject's sequence, so the passed value is never used.
    NOTE(review): figures are created with ``layout="constrained"`` and then
    ``plt.tight_layout()`` is called on them -- confirm which layout is wanted.
    """
    subjects = list(features.keys())
    num_keys = len(subjects)
    num_layers = len(features[subjects[0]])
    num_frames = features[subjects[0]][0].shape[0]

    def open_figure(width_units, height_units):
        # figure size scales with the number of grid cells
        plt.figure(
            figsize=(unit_size * width_units, unit_size * height_units),
            layout="constrained"
        )
        plt.suptitle(title, fontsize=font_size)

    def draw_cell(rows, cols, position, layer, subject, image):
        plt.subplot(rows, cols, position)
        plt.title(f"L{layer}-{subject.upper()}")
        plt.gca().axis("off")
        plt.imshow(image)

    def finish_figure():
        plt.tight_layout()
        plt.show()

    if mode == "subject-layer":
        open_figure(num_layers, num_keys)
        for row, subject in enumerate(subjects):
            for layer, image in enumerate(features[subject]):
                draw_cell(num_keys, num_layers, row * num_layers + layer + 1, layer, subject, image)
        finish_figure()
        return

    if mode == "layer-frame":
        for subject in subjects:
            open_figure(num_frames, num_layers)
            for layer, image in enumerate(features[subject]):
                draw_cell(num_layers, 1, layer + 1, layer, subject, image)
            finish_figure()
        return

    raise NotImplementedError()


def driver(features, method, subjects=None, patch_num=14, **kargs):
    """Apply ``method`` to every (layer, subject) feature and collect the results.

    Parameters:
        features: sequence indexed by layer; each element maps a subject key
            to a feature tensor whose dimension 1 has ``patch_num ** 2`` entries.
        method: callable invoked as
            ``method(feature, patch_num=..., _l=layer, _s=subject, **kargs)``;
            its result must support ``.float()``.
        subjects: subject keys to process; defaults to all keys of
            ``features[0]``.
        patch_num: side length of the patch grid.
        **kargs: forwarded unchanged to ``method``.

    Returns:
        dict mapping each subject to a list (one entry per layer, in layer
        order) of the float-converted results of ``method``.

    Raises:
        AssertionError: if the first subject's layer-0 feature does not have
            ``patch_num ** 2`` entries along dimension 1.
    """
    # identity test: `== None` is evaluated element-wise by array-like arguments
    if subjects is None:
        subjects = list(features[0].keys())

    assert features[0][subjects[0]].shape[1] == patch_num**2, (
        "feature patch count does not match patch_num ** 2"
    )

    results = {subject: [] for subject in subjects}
    for layer, layer_features in enumerate(features):
        for subject in subjects:
            results[subject].append(
                method(
                    layer_features[subject],
                    patch_num=patch_num,
                    _l=layer,
                    _s=subject,
                    **kargs
                ).float()
            )

    return results


def fetch_semantic_features(
    types=None,
    subjects=None,
    seconds=1,
    frames=2,
    patch_num=14,
    sample_num=100,
    save_path="",
    seed=None,
    visualize=False,
    vpt_weight_path="",
    farl_weight_path="",
):
    """Summarize encoder features at patch locations with semantic meaning.

    Randomly samples clips from the FFPP "train" split, flips and
    random-resized-crops the first frame of each clip together with a set of
    hand-picked facial keypoints (eyes, nose, lips, eyebrows, skin), runs the
    CLIP ViT-B/16 visual encoder on the augmented frame, and averages the
    features found at the patches the keypoints land on.

    Parameters:
        types: FFPP types to sample from; defaults to
            ``["REAL", "NT", "DF", "FS", "F2F"]``.
        subjects: attribute keys read from the extracted features; defaults
            to ``['q', 'k', 'v', 'out']``.
        seconds, frames: forwarded to the FFPP dataset constructor.
        patch_num: side length of the encoder's patch grid; keypoints chosen
            on the 14x14 reference grid are rescaled to it.
        sample_num: number of random dataset draws.
        save_path: when non-empty, the result is pickled to this path.
        seed: when not None, seeds Python's ``random`` module (which picks
            the dataset indices).
        visualize: when True, show debugging plots of the frames/keypoints.
        vpt_weight_path: optional checkpoint; entries whose key starts with
            "encoder" are loaded into the visual encoder.
        farl_weight_path: optional checkpoint whose "state_dict" entry is
            loaded (non-strict) into the full CLIP model.

    Returns:
        ``result[subject][part]`` is a list with one tensor per encoder layer:
        the mean over all collected patch features for that part.

    Raises:
        AssertionError: if both ``vpt_weight_path`` and ``farl_weight_path``
            are given.
    """
    # resolve defaults here instead of using mutable default arguments
    if types is None:
        types = ["REAL", "NT", "DF", "FS", "F2F"]
    if subjects is None:
        subjects = ['q', 'k', 'v', 'out']

    # the two weight overrides are mutually exclusive
    assert not (len(vpt_weight_path) > 0 and len(farl_weight_path) > 0)

    # side length of the reference grid on which the locations below are defined
    ref_grid = 14

    # create dataset & models
    c = FFPP.get_default_config()
    c.augmentation = "none"
    c.random_speed = False
    c.compressions = ["c23"]
    c.types = types

    accelerator = Accelerator(mixed_precision='no')

    encoder = CLIP.load("ViT-B/16")[0]
    if len(farl_weight_path) > 0:
        encoder.load_state_dict(
            torch.load(farl_weight_path)["state_dict"],
            strict=False
        )

    # keep only the visual branch
    encoder = encoder.visual.float()
    if len(vpt_weight_path) > 0:
        # keep entries whose key starts with "encoder" and strip the first 8 characters
        encoder.load_state_dict({
            k[8:]: v for k, v in torch.load(vpt_weight_path, "cpu").items() if "encoder" == k[:7]
        })

    n_px = encoder.input_resolution

    # preprocessing applied to the augmented frame right before the encoder
    transform = T.Compose(
        [
            T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(n_px),
            T.ConvertImageDtype(torch.float32),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ]
    )

    # the dataset itself only resizes; normalisation happens after augmentation
    dataset = FFPP(
        c.clone(),
        frames,
        seconds,
        T.Compose([
            T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC)
        ]),
        accelerator,
        split="train"
    )

    # the following patch locations are based on a 14x14 grid.
    semantic_locations = {
        "eyes": [
            [4, 3], [4, 4], [4, 9], [4, 10]
        ],
        "nose": [
            [7, 6], [6, 6], [5, 6],
            [7, 7], [6, 7], [5, 7]
        ],
        "lips": [
            [10, 5], [10, 6], [10, 7], [10, 8]
        ],
        "eyebrows": [
            [2, 3], [2, 4],
            [2, 9], [2, 10]
        ],
        "skin": [
            [0, 6], [0, 7],
            [1, 6], [1, 7],
            [7, 3], [7, 4], [7, 9], [7, 10],
            [12, 6], [12, 7],
            # [13, 6], [13, 7]
        ]
    }

    # part keypoint data, prepared for frame augmentations
    part_keypoints = []
    part_labels = []
    for part, locations in semantic_locations.items():
        for loc in locations:
            part_keypoints.append(loc)
            part_labels.append(part)
    # grid cell (row, col) -> pixel coordinate of the cell centre
    part_keypoints = (np.array(part_keypoints) + 0.5) / ref_grid * n_px

    # frame augmentations
    augmentations = alb.Compose(
        [
            alb.Flip(p=1.0),
            alb.RandomResizedCrop(
                n_px, n_px,
                scale=(0.3, 0.7), ratio=(1, 1),
                p=1.0,
            )
        ],
        keypoint_params=alb.KeypointParams(format='yx', remove_invisible=True, label_fields=["part_name"])
    )

    # container: subject -> part -> one list of collected features per layer
    layer_num = len(encoder.transformer.resblocks)
    semantic_patches = {
        s: {
            k: [[] for _ in range(layer_num)]
            for k in semantic_locations.keys()
        }
        for s in subjects
    }

    # `is not None` so that seed=0 is honoured as well
    if seed is not None:
        random.seed(seed)

    # random samples
    for _ in range(sample_num):
        ########## random select index ############
        # randrange excludes the upper bound; randint(0, len(dataset)) could
        # return len(dataset) and index past the end of the dataset
        idx = random.randrange(len(dataset))
        data = dataset[idx]

        ######### extract video features ###########
        for clip in data[0].values():
            # sample augmentation on the first frame of the clip
            frame = clip[0].permute((1, 2, 0)).numpy()
            result = augmentations(image=frame, keypoints=part_keypoints, part_name=part_labels)
            rrc_frame, rrc_kp, rrc_lb = result["image"], result["keypoints"], result["part_name"]

            # regroup the keypoints returned by the augmentation by part name
            rrc_semantic_locations = {
                part: []
                for part in semantic_locations.keys()
            }
            for loc, part in zip(rrc_kp, rrc_lb):
                rrc_semantic_locations[part].append(loc)

            # pixel coordinate -> nearest cell on the reference grid; the
            # largest valid row/column index is ref_grid - 1
            for part in rrc_semantic_locations.keys():
                rrc_semantic_locations[part] = np.round(
                    np.clip(
                        np.array(rrc_semantic_locations[part]) / n_px * ref_grid - 0.5,
                        a_min=0,
                        a_max=ref_grid - 1
                    )
                )

            if visualize:
                # visualization for debugging
                # > visualize ordinary frame
                plt.imshow(frame)
                plt.scatter(part_keypoints[..., 1], part_keypoints[..., 0])
                plt.show()
                # > visualize frame after augmentation
                rrc_kp = np.array(rrc_kp)
                plt.imshow(rrc_frame)
                plt.scatter(rrc_kp[..., 1], rrc_kp[..., 0])
                plt.show()
                # > visualize frame and keypoints at patch level
                plt.imshow(cv2.resize(rrc_frame, (14, 14)))
                for part in rrc_semantic_locations.keys():
                    if (len(rrc_semantic_locations[part]) > 0):
                        plt.scatter(rrc_semantic_locations[part][..., 1], rrc_semantic_locations[part][..., 0])
                plt.show()

            # extract frame features
            features = extract_features(
                encoder,
                transform(
                    torch.from_numpy(rrc_frame).permute((2, 0, 1)).unsqueeze(0)
                )
            )

            # post-process semantic locations: rescale (row, col) from the
            # reference grid to the patch_num grid and flatten to one index
            patch_indices = {
                part: [
                    int(_v[0] / (ref_grid - 1) * (patch_num - 1)) * patch_num
                    + int(_v[1] / (ref_grid - 1) * (patch_num - 1))
                    for _v in locs
                ]
                for part, locs in rrc_semantic_locations.items()
            }

            ######### collect semantic patch features ###########
            for l in range(layer_num):
                for s in subjects:
                    for p, loc in patch_indices.items():
                        if not loc:
                            # no keypoint of this part was returned for this frame
                            continue
                        semantic_patches[s][p][l].extend(
                            features[l][s][0, loc].tolist()
                        )

    # average all collected features per subject / part / layer
    semantic_patches = {
        s: {
            p: [
                torch.tensor(semantic_patches[s][p][l]).mean(dim=0)
                for l in range(layer_num)
            ]
            for p in semantic_locations.keys()
        }
        for s in subjects
    }

    if (len(save_path) > 0):
        with open(save_path, "wb") as f:
            pickle.dump(semantic_patches, f)

    return semantic_patches

In [None]:
# all_semantic_features = fetch_semantic_features(sample_num=100)
# semantic_features = fetch_semantic_features(
#     ["REAL"],
#     sample_num=1000,
#     save_path="./misc/real_semantic_patches_1000.pickle"
# )

# semantic_features = fetch_semantic_features(
#     vpt_weight_path="logs/test/olive-water-1118/best_weights.pt",
#     sample_num=300,
#     save_path="./misc/ow_semantic_patch.pickle"
# )

# semantic_features = fetch_semantic_features(
#     farl_weight_path="./misc/FaRL-Base-Patch16-LAIONFace20M-ep64.pth",
#     sample_num=1000,
#     save_path="./misc/farl_semantic_patches_v3.pickle"
# )

In [None]:
# real_semantic_features = fetch_semantic_features(["REAL"], sample_num=100, seed=10)
# df_semantic_features = fetch_semantic_features(["DF"], sample_num=100, seed=10)
# fs_semantic_features = fetch_semantic_features(["FS"], sample_num=100, seed=10)
# f2f_semantic_features = fetch_semantic_features(["F2F"], sample_num=100, seed=10)
# nt_semantic_features = fetch_semantic_features(["NT"], sample_num=100, seed=10)
# all_semantic_features = fetch_semantic_features(sample_num=100, seed=10)

# name_features = {"REAL": real_semantic_features, "DF": df_semantic_features, "FS": fs_semantic_features,
#                  "F2F": f2f_semantic_features, "NT": nt_semantic_features, "ALL": all_semantic_features}
# for subject in real_semantic_features.keys():
#     for part in real_semantic_features[subject].keys():
#         # for layer in range(len(real_semantic_features[subject][part])):
#         for layer in [10, 11]:
#             print(f"===S:{subject} L:{layer} P:{part.upper()}===")
#             part_features = torch.stack([name_features[k][subject][part][layer] for k in name_features])
#             score = torch.nn.functional.cosine_similarity(
#                 part_features.unsqueeze(0),
#                 part_features.unsqueeze(1),
#                 dim=-1
#             )
#             print(list(name_features.keys()))
#             print(f"min:{torch.min(score)}")
#             print(score)

In [None]:
# # create dataset & models
# c = FFPP.get_default_config()
# c.augmentation = "none"
# c.random_speed = True
# c.compressions = ["c23"]
# c.types = ["REAL", "DF", "FS", "F2F", "NT"]
# frames = 2
# seconds = 4
# accelerator = Accelerator(mixed_precision='no')
# encoder = CLIP.load("ViT-B/16")[0].visual.float()
# n_px = encoder.input_resolution

# transform = T.Compose(
#     [
#         T.RandomVerticalFlip(),
#         T.RandomResizedCrop(n_px, (0.3, 0.5), (1, 1), interpolation=T.InterpolationMode.BICUBIC),
#         # T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
#         T.CenterCrop(n_px),
#         T.ConvertImageDtype(torch.float32),
#         T.Normalize((0.48145466, 0.4578275, 0.40821073),
#                     (0.26862954, 0.26130258, 0.27577711)),
#     ]
# )

# dataset = FFPP(c.clone(), frames, seconds, transform, accelerator, split="train")

In [None]:
# fetch_semantic_features(sample_num=10, visualize=True)

In [None]:
# data = dataset[random.randrange(len(dataset))]
# features = extract_features(encoder, data[0]["c23"])

# evals = driver(features, var)
# plotter(evals, "", "subject-layer", unit_size=2)


# evals = driver(features, one_patch_cos_sim, t=0, c=21)
# plotter(evals, "", "layer-frame", unit_size=2)


# evals = driver(features, semantic_patch_cos_sim, part="lips", s="k", semantic_patches=semantic_patches)
# plotter(evals, "", "layer-frame", unit_size=2)

In [None]:
# semantic_patches = fetch_semantic_features(sample_num=500)

In [None]:
# data = dataset[random.randrange(len(dataset))]
# features = extract_features(encoder, data[0]["c23"])
# evals = driver(features, semantic_patch_cos_sim, part="eyes", s="k", semantic_patches=semantic_patches)
# frame = data[0]["c23"][0]
# frame = (frame.permute((1, 2, 0)) - frame.min())/(frame.max()-frame.min()).numpy()
# print(frame.max())
# plt.imshow(frame)
# plt.show()
# plotter(evals, "", "layer-frame", unit_size=2)

In [None]:
# Notebook-global encoder used by the cells below.
# CLIP.load returns an indexable result; element 0 is kept as the model.
encoder = CLIP.load("ViT-B/16")[0]
# TODO: remove FaRL
# (disabled) non-strict load of the FaRL checkpoint's "state_dict" into the model
# encoder.load_state_dict(
#     torch.load("./misc/FaRL-Base-Patch16-LAIONFace20M-ep64.pth")["state_dict"],
#     strict=False
# )
# keep only the visual branch and convert it with .float()
encoder = encoder.visual.float()
# input resolution reported by the visual encoder; used to size the transforms
n_px = encoder.input_resolution

In [None]:
import os
from tqdm import tqdm


def sensitivity_for_perturbations(
    encoder,
    n_px,
    num_samples,
    attributes=None,
    augmentation="perturbations",
):
    """Measure how much encoder attributes change under an augmentation.

    Two FFPP "train" datasets are built that differ only in
    ``config.augmentation`` ("none" versus ``augmentation``). For
    ``num_samples`` random indices the same index is read from both datasets,
    features are extracted for the "c23" entry of each, and the patch-wise
    value ``(1 - cosine_similarity) / 2`` between the two feature sets is
    averaged over the samples.

    Parameters:
        encoder: visual encoder passed to ``extract_features``.
        n_px: resolution used for the resize / centre-crop transform.
        num_samples: number of random dataset indices to evaluate.
        attributes: attribute keys to compare; defaults to
            ``['q', 'k', 'v', 'out']``.
        augmentation: value assigned to ``config.augmentation`` for the
            second dataset.

    Returns:
        A list with one dict per layer; each dict maps an attribute key to a
        1-D tensor with the averaged per-patch distance.
    """
    # resolve the default here instead of using a mutable default argument
    if attributes is None:
        attributes = ['q', 'k', 'v', 'out']

    # Sizes are hard-coded for the encoder used in this notebook
    # (presumably 12 transformer layers and a 14x14 = 196 patch grid for
    # ViT-B/16 -- TODO confirm before using a different encoder).
    num_layers = 12
    num_patches = 196

    c = FFPP.get_default_config()
    c.random_speed = False
    c.compressions = ["c23"]
    c.types = ["REAL", "DF", "FS", "F2F", "NT"]
    frames = 1
    seconds = 3
    accelerator = Accelerator(mixed_precision='no')

    transform = T.Compose(
        [
            T.Resize(n_px),
            T.CenterCrop(n_px),
            T.ConvertImageDtype(torch.float32),
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711)),
        ]
    )

    # dataset1: no augmentation, dataset2: the augmentation under test
    c.augmentation = "none"
    dataset1 = FFPP(c.clone(), frames, seconds, transform, accelerator, split="train")
    c.augmentation = augmentation
    dataset2 = FFPP(c.clone(), frames, seconds, transform, accelerator, split="train")

    # storage for patch-wise cosine distance of attention attributes
    storage = [{k: torch.zeros(num_patches) for k in attributes} for _ in range(num_layers)]

    for _ in tqdm(range(num_samples)):
        idx = random.randrange(0, len(dataset1))
        data1 = dataset1[idx]
        data2 = dataset2[idx]
        #######################################
        # (disabled) side-by-side visualisation of both clips for debugging
        # plt.figure(figsize=(50, 5))
        # plt.subplot(2, 1, 1)
        # plt.imshow(
        #     np.stack(
        #         (data1[0]["c23"][:30]).numpy().transpose((0, 2, 3, 1)), axis=1
        #     ).reshape((n_px, -1, 3))
        # )
        # plt.subplot(2, 1, 2)
        # plt.imshow(
        #     np.stack(
        #         (data2[0]["c23"][:30]).numpy().transpose((0, 2, 3, 1)), axis=1
        #     ).reshape((n_px, -1, 3))
        # )
        #######################################
        features1 = extract_features(encoder, data1[0]["c23"])
        features2 = extract_features(encoder, data2[0]["c23"])
        # a distinct loop name: the original reused the outer loop's `i` here
        for layer in range(num_layers):
            for attr in attributes:
                distance = 1 - torch.nn.functional.cosine_similarity(
                    features1[layer][attr], features2[layer][attr], dim=-1
                )
                # halve the distance and accumulate the running mean over samples
                storage[layer][attr] += distance.squeeze(0) / 2 / num_samples
    return storage


# Output directory for the per-augmentation sensitivity figures.
target_folder = "./misc/attn_attr_sens/"
os.makedirs(target_folder, exist_ok=True)

# Run the sensitivity measurement (100 samples) for every augmentation scenario.
scenario_storages = {
    aug_type: sensitivity_for_perturbations(
        encoder,
        n_px,
        100,
        augmentation=aug_type
    )
    for aug_type in (
        "perturbations",
        "dev-mode+force-rgb",
        "dev-mode+force-hue",
        "dev-mode+force-bright",
        "dev-mode+force-comp",
        "dev-mode+force-dscale",
        "dev-mode+force-sharpen",
    )
}


# find max & min over every scenario, layer and attribute so that all
# figures share one colour scale
global_max = -10
global_min = 10
for storage in scenario_storages.values():
    for attrs in storage:
        for sens_map in attrs.values():
            # record
            global_min = min(sens_map.min().item(), global_min)
            global_max = max(sens_map.max().item(), global_max)

# one figure per scenario: columns are layers, rows are attributes
for aug_type, storage in scenario_storages.items():
    n_cols = len(storage)
    n_rows = len(storage[0])
    plt.figure(figsize=(n_cols, n_rows), layout="constrained")
    for col, attrs in enumerate(storage):
        for row, sens_map in enumerate(attrs.values()):
            plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
            # normalise to [0, 1] with the shared global range
            data = (sens_map.view(14, 14).numpy() - global_min) / (global_max - global_min)
            plt.imshow(data, vmin=0, vmax=1)
            plt.gca().axis("off")
    plt.savefig(os.path.join(target_folder, f"{aug_type}.pdf"))
    plt.close()