In [None]:
import os
import yaml
import torch
import pickle
import logging
import statistics
import cv2
import numpy as np
import random
import matplotlib.pyplot as plt
import torchvision.transforms as T
from accelerate import Accelerator
from yacs.config import CfgNode as CN
from main import get_config, init_accelerator, set_seed, FFPP
from src.models import Detector

# Log records are prefixed with the emitting file and line number.
logging_fmt = "[%(filename)s:%(lineno)d]: %(message)s"
# `force=True` replaces any handlers already attached to the root logger.
# Without it, `basicConfig` is silently a no-op whenever the kernel or an
# imported module configured logging first, or when this cell is re-run.
logging.basicConfig(level="INFO", format=logging_fmt, force=True)
# Release cached CUDA memory that earlier cells/runs may still hold.
torch.cuda.empty_cache()


def _mean_abs_pairwise_cosine(vectors):
    """Mean absolute cosine similarity of every row of ``vectors`` to all rows.

    Each row is compared (along the last dimension) with every row of the same
    tensor, itself included, so the self-similarity term is part of the mean.

    Args:
        vectors: tensor whose first dimension enumerates the items to compare.

    Returns:
        A flat Python list of floats.
    """
    return (
        torch.nn.functional.cosine_similarity(vectors.unsqueeze(1), vectors, dim=-1)
        .abs()
        .mean(1)
        .flatten()
        .tolist()
    )


@torch.no_grad()
def run(
    cfg_path,
    num_frames=20,
    df_types=None,
    parts=None,
    vid_idx=None,
):
    """Log cosine-similarity statistics of one test clip's decoder queries/outputs.

    Loads the ``Detector`` described by ``cfg_path`` (weights are read from
    ``best_weights.pt`` in the same directory), runs it on a single clip of the
    FFPP test split with attention recording enabled, and for every decoder
    block logs:

    * ``Semantic``: similarity among the pre-computed semantic features of
      the requested ``parts`` for that layer;
    * ``Q<n> Cross`` / ``Q<n> Semantic``: for each tensor recorded in
      ``blk.attn.qs``, similarity among its queries and between the queries
      and the semantic features;
    * ``Out Cross``: similarity among the entries of ``layer_results[:, i]``.

    Args:
        cfg_path: path to a run's ``setting.yaml``.
        num_frames: number of frames per clip, passed to the model and dataset.
        df_types: FFPP types to load; defaults to
            ``["REAL", "NT", "FS", "F2F", "DF"]``.
        parts: keys of ``misc/semantic_patches.pickle["k"]`` to compare
            against; defaults to ``["lips", "skin"]``.
        vid_idx: index of the dataset item to inspect. When ``None`` a random
            item is drawn with the global ``random`` module.

    Returns:
        None. All results are emitted through ``logging.info``.
    """
    # Defaults are built per call so no list object is shared between calls.
    if df_types is None:
        df_types = ["REAL", "NT", "FS", "F2F", "DF"]
    if parts is None:
        parts = ["lips", "skin"]

    torch.cuda.empty_cache()

    accelerator = Accelerator(mixed_precision="no")
    # Use one device everywhere instead of mixing accelerator.device and "cuda".
    device = accelerator.device

    # instantiate model
    with open(cfg_path) as f:
        preset = CN(yaml.safe_load(f))

    # Stock yacs `merge_from_other_cfg` merges in place and returns None, so
    # keep a handle on the default config and only adopt the return value if
    # the project's implementation actually provides one.
    mc = Detector.get_default_config()
    merged = mc.merge_from_other_cfg(preset.model)
    if merged is not None:
        mc = merged
    # Makes the attention blocks record the tensors read back below.
    mc.op_mode.attn_record = True

    model = Detector(mc, num_frames, accelerator).to(device)
    weights_path = os.path.join(os.path.dirname(cfg_path), "best_weights.pt")
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()

    # instantiate data loader
    c = FFPP.get_default_config()
    c.pack = 1
    c.augmentation = "none"
    c.random_speed = False
    c.compressions = ["c23"]
    c.types = df_types

    # NOTE(review): the mean/std below look like CLIP preprocessing statistics;
    # confirm they match the encoder in use.
    transform = T.Compose(
        [
            T.Resize(
                model.encoder.input_resolution,
                interpolation=T.InterpolationMode.BICUBIC,
            ),
            T.CenterCrop(model.encoder.input_resolution),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    # NOTE(review): the positional `4` is forwarded to FFPP unchanged; its
    # meaning is not visible from this notebook.
    dataset = FFPP(c.clone(), num_frames, 4, transform, accelerator, split="test")

    # load semantic patch features according to decode layers.
    # SECURITY: pickle can execute arbitrary code on load; only use a trusted file.
    with open("misc/semantic_patches.pickle", "rb") as f:
        prefetch_queries = pickle.load(f)
    # One stack of per-part features for every entry of model.layer_indices.
    # Kept on CPU because the activations compared against it are moved to CPU.
    semantic_queries = torch.stack(
        [
            torch.stack([prefetch_queries["k"][part][i] for part in parts])
            for i in model.layer_indices
        ]
    ).cpu()

    # sample video clip
    if vid_idx is None:
        vid_idx = random.randrange(0, len(dataset))
    # Record the choice so a random pick can be reproduced via `vid_idx`.
    logging.info(f"# Video index:{vid_idx}")
    # Only the clips and masks of the 6-tuple are needed here.
    clips, _, masks, _, _, _ = dataset[vid_idx]
    clip = clips[0]
    mask = masks[0]

    # inference for layer_results and Qs
    _, features = model.predict(
        clip.unsqueeze(0).to(device), mask.unsqueeze(0).to(device)
    )

    # layer attention results
    layer_results = features["layer_results"]

    # locate attention blocks
    image_attn_blocks = list(model.decoder.transformer.resblocks.children())

    # traverse overall blocks
    for i, blk in enumerate(image_attn_blocks):
        logging.info(f"# Layer:{i}")

        # similarity over all semantic queries of the current layer
        logging.info(
            "## Semantic: {}".format(_mean_abs_pairwise_cosine(semantic_queries[i]))
        )

        for _i, _q in enumerate(blk.attn.qs):
            # _q has a shape of b,q,h,d
            _q = _q.detach().cpu()
            # `reshape` (not `view`) so a non-contiguous tensor is handled too.
            _q = _q.reshape(*_q.shape[:2], -1)
            # since the number of batch must be 1 in this setting, we squeeze the tensor.
            _q = _q.squeeze(0)  # _q now has shape q,h*d

            logging.info(
                "## Q{} Cross: {}".format(_i, _mean_abs_pairwise_cosine(_q))
            )
            logging.info(
                "## Q{} Semantic: {}".format(
                    _i,
                    torch.nn.functional.cosine_similarity(
                        _q, semantic_queries[i], dim=-1
                    )
                    .flatten()
                    .tolist(),
                )
            )

        # since the number of batch must be 1 in this setting, we squeeze the tensor.
        r = layer_results[:, i].detach().cpu().squeeze(0)

        logging.info("## Out Cross: {}".format(_mean_abs_pairwise_cosine(r)))

In [None]:
# Trained runs available for inspection, keyed by a short description of the
# experiment variant. Each value is the run's setting.yaml; `run` loads
# "best_weights.pt" from the same directory.
CFG_PATHS = {
    "none": "logs/test/dark-night-989/setting.yaml",
    "dissim": "logs/test/divine-pond-998/setting.yaml",
    "semantic v1": "logs/test/hopeful-violet-1004/setting.yaml",
    "semantic v1 + smax": "logs/test/still-terrain-1023/setting.yaml",
    "semantic v2 + smax": "logs/test/iconic-dragon-1034/setting.yaml",
    "semantic_v2,k + smax": "logs/test/silver-gorge-1037/setting.yaml",
    "semantic_v2,k(+nose) + smax": "logs/test/mild-gorge-1039/setting.yaml",
    "5,semantic_v3.1,k + smax": "logs/test/celestial-frost-1048/setting.yaml",
    "2,semantic_v2,k + smax, gpv2": "logs/test/restful-water-1095/setting.yaml",
}

# Switch experiments by changing this key instead of (un)commenting lines.
cfg_path = CFG_PATHS["2,semantic_v2,k + smax, gpv2"]

# Other part sets can be probed by passing, e.g.,
#   parts=["lips", "skin", "nose"]
#   parts=["lips", "skin", "nose", "eyes", "eyebrows"]
run(cfg_path)