In [None]:

import torch
import logging
import statistics
import src.clip as clip
from torchinfo import summary
import torchvision.transforms as T
from accelerate import Accelerator
from yacs.config import CfgNode as CN
from main import get_config, init_accelerator, set_seed, FFPP
from src.models import Detector
logging.basicConfig(level="DEBUG")


class Obj:
    """Bare attribute container; instances accept arbitrary attributes (not used in the cells shown)."""


# --- Dataset configuration ---------------------------------------------------
c = FFPP.get_default_config()
c.augmentation = "normal+frame"
c.pair = 1
c.compressions = ["c23"]
c.types = ["REAL", "NT"]


# --- Detector / adapter configuration ----------------------------------------
mc = Detector.get_default_config()
mc.out_dim = [2]
mc.adapter = CN()
mc.adapter.frozen = 0
mc.adapter.struct = CN()
# mc.adapter.struct.type = "legacy-768-x-768"
mc.adapter.struct.type = "768-x-768-nln"
mc.adapter.struct.x = 256
mc.adapter.type = "normal"

accelerator = Accelerator(mixed_precision='no')
model = Detector(mc, 20, accelerator).to(accelerator.device).eval()
# encoder = clip.load("ViT-B/16")[0].visual.float()
encoder = model.encoder
adapter = model.adapter

# --- Load trained adapter weights --------------------------------------------
# Alternative checkpoint:
# "/home/od/Desktop/repos/dfd-clip/logs/comp-inv/comp-inv/mode1+256+resi+1e-2/last_weights.pt"
CKPT_PATH = "/home/od/Desktop/repos/dfd-clip/logs/test/0520T1305/best_weights.pt"
ADAPTER_PREFIX = "adapter."

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
# Keep only the adapter's parameters and strip the "adapter." prefix so the keys match
# `model.adapter`'s own state_dict. Matching the prefix explicitly (instead of any key that
# merely contains "adapter", sliced with a fixed width) avoids pulling in unrelated keys
# with mangled names.
adapter.load_state_dict({
    key[len(ADAPTER_PREFIX):]: value
    for key, value in checkpoint.items()
    if key.startswith(ADAPTER_PREFIX)
})
del checkpoint

# The model was already put in eval mode and on accelerator.device above; the later cells
# send inputs to "cuda" explicitly, so make sure the model lives there too.
model.eval()
model.to("cuda")

# --- Preprocessing: CLIP-style resize / crop / normalise ---------------------
transform = T.Compose([
    T.Resize(encoder.input_resolution, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(encoder.input_resolution),
    T.ConvertImageDtype(torch.float32),
    T.Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
])

# Same training split twice: `x` with the CLIP preprocessing, `_x` with an identity transform.
x = FFPP(c.clone(), 20, 5, transform, accelerator, split="train")
_x = FFPP(c.clone(), 20, 5, lambda x: x, accelerator, split="train")
c

In [None]:
import random

# Pick one random training sample. `randrange(n)` draws from [0, n); the original
# `randint(0, len(x))` is inclusive at both ends and could return the out-of-range len(x).
idx = random.randrange(len(x))
data = x[idx]     # CLIP-preprocessed sample
_data = _x[idx]   # same index, identity transform

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# `data` went through T.Normalize with the CLIP mean/std, so its pixel values fall outside
# [0, 1] and imshow would clip them. Undo the normalisation before display.
_CLIP_MEAN = np.array((0.48145466, 0.4578275, 0.40821073))
_CLIP_STD = np.array((0.26862954, 0.26130258, 0.27577711))


def _to_displayable(frame):
    """Turn a normalised (C, H, W) tensor frame into an (H, W, C) array clipped to [0, 1]."""
    img = frame.numpy().transpose((1, 2, 0))
    return np.clip(img * _CLIP_STD + _CLIP_MEAN, 0.0, 1.0)


fig, axes = plt.subplots(1, 2, figsize=(4, 2))
for ax, quality in zip(axes, ("raw", "c23")):
    # [0] selects the first entry along the leading dim of the sample tensor
    ax.imshow(_to_displayable(data[0][quality][0]))
    ax.set_title(quality)
    ax.axis("off")
plt.show()

In [None]:
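# NOTE(review): superseded draft of the feature-extraction cell below, kept for reference.
# Unlike the live cell, this draft passes the un-batched `kvs` (not `_kvs`) to the adapter.
# features = {}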
# nfeatures = {}
# with torch.no_grad():
#     for k, v in data[0].items():
#         # get key and value from each CLIP ViT layer
#         kvs = encoder(v[0].unsqueeze(0).to("cuda"))
#         # discard original CLS token and restore temporal dimension
#         kvs = [{k: v[:, 1:] for k, v in kv.items()} for kv in kvs]

#         _kvs = [{k: v.unsqueeze(0) for k, v in kv.items()} for kv in kvs]
#         _kvs = adapter([kvs[i] for i in range(0, 12, 2)])

#         kvs = [{k: v.squeeze(0).view((-1, 768)).to("cpu") for k, v in kv.items()} for kv in kvs]
#         features[k] = kvs
#         _kvs = [{k: v.squeeze(0).view((-1, 768)).to("cpu") for k, v in kv.items()} for kv in _kvs]
#         nfeatures[k] = _kvs


# torch.cuda.empty_cache()

In [None]:
# For each quality variant in the sample, collect per-layer key/value features from the
# CLIP encoder (`features`) and from the adapter applied to a subset of encoder layers
# (`nfeatures`). Each entry is a list with one {"k": ..., "v": ...} dict per layer.
features = {}
nfeatures = {}
with torch.no_grad():
    for quality, clip_frames in data[0].items():
        # per-layer key/value dicts from the CLIP ViT encoder
        layer_kvs = encoder(clip_frames.to("cuda"))
        # drop the leading CLS token (position 0 along dim 1)
        layer_kvs = [{name: t[:, 1:] for name, t in kv.items()} for kv in layer_kvs]

        # add a leading dim of size 1 before handing the selected layers to the adapter
        batched_kvs = [{name: t.unsqueeze(0) for name, t in kv.items()} for kv in layer_kvs]
        adapted_kvs = adapter([batched_kvs[i] for i in [5, 6, 7, 8, 9, 10]])

        # keep the first entry along the time/batch dim, flattened to (-1, 768), on the CPU
        features[quality] = [
            {name: t[0].view((-1, 768)).to("cpu") for name, t in kv.items()}
            for kv in layer_kvs
        ]
        nfeatures[quality] = [
            {name: t[:, 0].view((-1, 768)).to("cpu") for name, t in kv.items()}
            for kv in adapted_kvs
        ]


torch.cuda.empty_cache()

In [None]:
methods = list()  # one {"k": [...], "v": [...]} dict of per-layer maps per comparison metric

In [None]:
# Per-token MSE between raw and c23 CLIP features for every encoder layer, averaged over
# the feature dim and reshaped to 14x14 (presumably the ViT patch grid).
mse_maps = {"k": [], "v": []}
for layer in range(12):
    for subject in ("k", "v"):
        raw_feat = features["raw"][layer][subject]
        c23_feat = features["c23"][layer][subject]
        per_element = torch.nn.functional.mse_loss(raw_feat, c23_feat, reduction="none")
        mse_maps[subject].append(per_element.mean(dim=-1).view(14, 14).numpy())
methods.append(mse_maps)

In [None]:
# Per-token KL divergence KL(p_c23 || p_raw) for every CLIP encoder layer. Each token's
# feature vector becomes a distribution via softmax over the feature dim. With
# log_target=True, kl_div(input, target) returns exp(target) * (target - input) element-wise.
# Summing those terms over the distribution's support gives the KL divergence itself;
# averaging (as before) would divide it by the feature size.
kl_maps = {"k": [], "v": []}
for layer in range(12):
    for subject in ("k", "v"):
        log_p_raw = torch.nn.functional.log_softmax(features["raw"][layer][subject], dim=-1)
        log_p_c23 = torch.nn.functional.log_softmax(features["c23"][layer][subject], dim=-1)
        pointwise = torch.nn.functional.kl_div(
            log_p_raw,
            log_p_c23,
            log_target=True,
            reduction="none"
        )
        # reshape the per-token values to 14x14 (presumably the ViT patch grid)
        kl_maps[subject].append(pointwise.sum(dim=-1).view(14, 14).numpy())
methods.append(kl_maps)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# One figure per metric in `methods`: row 0 holds the key ("k") maps and row 1 the value
# ("v") maps, one column per CLIP encoder layer.
for method_idx, l in enumerate(methods):
    fig = plt.figure(figsize=(36, 6))
    for j, s in enumerate(["k", "v"]):
        for i, v in enumerate(l[s]):
            ax = fig.add_subplot(2, 12, j * 12 + i + 1)
            ax.imshow(v)
            ax.set_title(f"{s} | layer {i}")
            # one tick per map cell, derived from the map itself rather than hard-coded
            ax.set_xticks(np.arange(v.shape[1]))
            ax.set_yticks(np.arange(v.shape[0]))
    fig.suptitle(f"CLIP features, raw vs c23 - metric #{method_idx}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
nmethods = list()  # one {"k": [...], "v": [...]} dict of per-adapter-layer maps per metric

In [None]:
# Per-token MSE between raw and c23 adapter features for each of the 6 adapter outputs,
# averaged over the feature dim and reshaped to 14x14 (presumably the ViT patch grid).
mse_maps = {"k": [], "v": []}
for layer in range(6):
    for subject in ("k", "v"):
        raw_feat = nfeatures["raw"][layer][subject]
        c23_feat = nfeatures["c23"][layer][subject]
        per_element = torch.nn.functional.mse_loss(raw_feat, c23_feat, reduction="none")
        mse_maps[subject].append(per_element.mean(dim=-1).view(14, 14).numpy())
nmethods.append(mse_maps)

In [None]:
# Per-token KL divergence KL(p_c23 || p_raw) for each of the 6 adapter outputs. Each token's
# feature vector becomes a distribution via softmax over the feature dim. With
# log_target=True, kl_div(input, target) returns exp(target) * (target - input) element-wise.
# Summing those terms over the distribution's support gives the KL divergence itself;
# averaging (as before) would divide it by the feature size.
kl_maps = {"k": [], "v": []}
for layer in range(6):
    for subject in ("k", "v"):
        log_p_raw = torch.nn.functional.log_softmax(nfeatures["raw"][layer][subject], dim=-1)
        log_p_c23 = torch.nn.functional.log_softmax(nfeatures["c23"][layer][subject], dim=-1)
        pointwise = torch.nn.functional.kl_div(
            log_p_raw,
            log_p_c23,
            log_target=True,
            reduction="none"
        )
        # reshape the per-token values to 14x14 (presumably the ViT patch grid)
        kl_maps[subject].append(pointwise.sum(dim=-1).view(14, 14).numpy())
nmethods.append(kl_maps)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Encoder layers whose features were fed to the adapter. Each adapter map is drawn in the
# column of its source layer, so these figures line up with the 12-column CLIP-layer plots.
loc = [5, 6, 7, 8, 9, 10]
for method_idx, l in enumerate(nmethods):
    fig = plt.figure(figsize=(36, 6))
    for j, s in enumerate(["k", "v"]):
        if len(l[s]) > len(loc):
            raise ValueError(
                f"metric #{method_idx} has {len(l[s])} '{s}' maps but only {len(loc)} column positions"
            )
        for i, v in enumerate(l[s]):
            ax = fig.add_subplot(2, 12, j * 12 + loc[i] + 1)
            ax.imshow(v)
            ax.set_title(f"{s} | layer {loc[i]}")
            # one tick per map cell, derived from the map itself rather than hard-coded
            ax.set_xticks(np.arange(v.shape[1]))
            ax.set_yticks(np.arange(v.shape[0]))
    fig.suptitle(f"Adapter features, raw vs c23 - metric #{method_idx}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)

In [None]:
# Number of per-layer "k" maps stored for the first adapter-feature metric.
first_adapter_metric = nmethods[0]
len(first_adapter_metric["k"])