In [None]:
import os

import argparse
import pickle as pkl
import random

import open_clip
import numpy as np
import torch
import torch.nn as nn
import yaml
from scipy.stats import pearsonr, spearmanr
from scipy.stats import kendalltau as kendallr
from tqdm import tqdm

from buona_vista import datasets

import wandb

In [None]:
def rescale(x):
    """Standardise scores to zero mean / unit variance and squash them to (0, 1).

    The input is converted to a NumPy array, z-scored with its own mean and
    (population) standard deviation, and passed through a logistic sigmoid.
    The mean and standard deviation are printed as a side effect.

    Args:
        x: array-like of numeric scores.

    Returns:
        np.ndarray with the same shape as ``x`` holding values in (0, 1).
        When all entries are identical (standard deviation of exactly zero,
        which includes single-element input) every output is 0.5 instead of
        NaN.
    """
    x = np.asarray(x)
    mean, std = x.mean(), x.std()
    print("Mean:", mean, "Std", std)
    centered = x - mean
    # A zero std means every deviation is exactly zero; dividing would give
    # 0/0 = NaN, so keep the (all-zero) centred values as the z-scores.
    z = centered / std if std != 0 else centered
    return 1 / (1 + np.exp(-z))

In [None]:
    # Load the evaluation config and build one dataset object per entry of
    # its "data" section.
    with open("buona_vista_sa_index.yml", "r") as f:
        opt = yaml.safe_load(f)

    val_datasets = {}
    for name, dataset in opt["data"].items():
        # Each entry names a class in `buona_vista.datasets` ("type") and the
        # single argument passed to its constructor ("args").
        val_datasets[name] = getattr(datasets, dataset["type"])(dataset["args"])

    print(open_clip.list_pretrained())
    model, _, _ = open_clip.create_model_and_transforms("RN50", pretrained="openai")
    model = model.to("cuda")
    # open_clip returns the model in train mode; switch to eval so the
    # BatchNorm layers of RN50 use their running statistics for inference.
    model.eval()
    print("loading succeed")

    # Prompts are laid out as (positive, negative) pairs: indices (0, 1) and
    # (2, 3). The scoring cell differences each pair.
    texts = [
        "high quality",
        "low quality",
        "a good photo",
        "a bad photo",
    ]
    tokenizer = open_clip.get_tokenizer("RN50")
    text_tokens = tokenizer(texts).to("cuda")
    print(f"Prompt_loading_succeed, {texts}")

    results = {}

In [None]:
# Gather, per validation set, the ground-truth labels and the video file
# names in dataset order.
gts, paths = {}, {}

for val_name, val_dataset in val_datasets.items():
    infos = [val_dataset.video_infos[i] for i in range(len(val_dataset))]
    gts[val_name] = [info["label"] for info in infos]
    paths[val_name] = [info["filename"] for info in infos]

In [None]:
# Extract the CLIP visual features only when the cache file is missing, then
# always read them back from disk.
# `glob` is never imported in the visible cells, so the existence check uses
# `os.path.exists` (``os`` is imported in the first cell).
# NOTE(review): presumably `get_features` writes "CLIP_vis_features.pt" as a
# side effect; it is not defined in the cells visible here — confirm.
if not os.path.exists("CLIP_vis_features.pt"):
    visual_features = get_features(False)
visual_features = torch.load("CLIP_vis_features.pt")

In [None]:




def encode_text_prompts(prompts):
    """Tokenise ``prompts`` and run them through the CLIP text tower.

    Relies on the module-level ``tokenizer`` and ``model``; the tokens are
    moved to "cuda" before encoding.

    Args:
        prompts: the text prompts handed to ``tokenizer``.

    Returns:
        A tuple ``(tokens, token_embeddings, text_features)``, all on the CPU:
        the token ids, the output of ``model.token_embedding`` for them, and
        the ``model.encode_text`` output cast to float32.
    """
    tokens_gpu = tokenizer(prompts).to("cuda")
    with torch.no_grad():
        token_embeddings = model.token_embedding(tokens_gpu)
        encoded = model.encode_text(tokens_gpu)
    features_cpu = encoded.float().cpu()
    return tokens_gpu.cpu(), token_embeddings.cpu(), features_cpu
    
# Encode the quality prompts once. Note that this rebinds `text_tokens`
# (assigned earlier as a CUDA tensor) to the CPU copy returned by the helper.
text_tokens, embedding, text_feats = encode_text_prompts(texts)

In [None]:
# NOTE(review): `demo_features` is not defined in any cell visible here, so
# this raises NameError on a fresh kernel — looks like a leftover inspection
# cell; confirm whether it should be removed or point at another variable.
demo_features.keys()

In [None]:
# Load the spatial (`sn`) and temporal (`tn2`) naturalness indices, keyed by
# validation-set name, from one of two sets of precomputed pickle files.
# NOTE(security): pickle.load can execute arbitrary code; only use files
# produced by this project.
backend = "Matlab" # Matlab | Pytorch

if backend == "Matlab":
    with open("naturalnesses_matlab_results.pkl", "rb") as f:
        matlab_results = pkl.load(f)
    sn = matlab_results["spatial"]
    tn2 = matlab_results["temporal"]

else:
    # The temporal results live in a single file shared by every dataset, so
    # read it once instead of once per loop iteration.
    with open("temporal_naturalness_pubs.pkl", "rb") as infile:
        tn = pkl.load(infile)

    sn, tn2 = {}, {}
    for val_name in visual_features:
        with open(f"spatial_naturalness_{val_name}.pkl", "rb") as infile:
            sn[val_name] = pkl.load(infile)["pr_labels"]

        tn2[val_name] = tn[f"{val_name}"]["tn_index"]

In [None]:
# Compute one semantic affinity score per video: project its visual features
# onto the prompt text features, difference each (positive, negative) prompt
# pair, average, and finally rescale the scores of each dataset to (0, 1).
prs = {}
local_prs = {}
# The text features are loop-invariant; transpose them once.
text_feats_t = text_feats.T.cpu()
for val_name in visual_features:
    prs[val_name] = []
    # NOTE(review): these four slots are never filled in this cell, so any
    # later cell that stacks them receives empty lists.
    local_prs[val_name] = [[], [], [], []]
    for feat in tqdm(visual_features[val_name]):
        with torch.no_grad():
            logits = (feat @ text_feats_t).cpu().numpy()
        # Accumulator shape is fixed at (50, 64) — presumably tokens x frames
        # for this feature layout; TODO confirm against the extractor.
        semantic_affinity_index = np.zeros((50, 64))
        for k in (0, 1):
            # Prompt pair k occupies columns 2k (positive) and 2k+1 (negative).
            semantic_affinity_index += logits[..., 2 * k] - logits[..., 2 * k + 1]

        # Row 0 is skipped before averaging.
        prs[val_name].append(semantic_affinity_index[1:].mean())

    prs[val_name] = rescale(prs[val_name])
    if val_name == 'val-maxwell':
        # NOTE(review): `d` is not defined in the cells visible here — confirm
        # where it comes from before relying on this branch.
        for key in d:
            try:
                print(key, val_name, "P", pearsonr(prs[val_name], d[key])[0])
            except Exception as err:
                # Best-effort report: name the failing key and the reason
                # instead of silently swallowing every exception.
                print(key, "failed:", repr(err))
    print("sa_only", val_name, "S", spearmanr(prs[val_name], gts[val_name])[0], "P", pearsonr(prs[val_name], gts[val_name])[0])
    

In [None]:
# Report rank (S) and linear (P) correlation with the ground truth for each
# index on its own and for their unweighted sum.
for val_name in visual_features:
    if val_name == "val-ytugc":
        # The naturalness indices are replaced by zeros for this set. Size the
        # placeholders from the scores they are added to rather than a
        # hard-coded video count.
        n_videos = len(prs[val_name])
        sn[val_name] = np.zeros(n_videos)
        tn2[val_name] = np.zeros(n_videos)
    all_prs = prs[val_name] + sn[val_name] + tn2[val_name]
    for tag, scores in (
        ("sa_only", prs[val_name]),
        ("sn_only", sn[val_name]),
        ("tn_only", tn2[val_name]),
        ("all_indices", all_prs),
    ):
        print(tag, val_name, "S", spearmanr(scores, gts[val_name])[0], "P", pearsonr(scores, gts[val_name])[0])
    print("")

In [None]:
# Stack and rescale every collected list of local quality maps.
for val_name in visual_features:
    for i in range(len(local_prs[val_name])):
        maps = local_prs[val_name][i]
        # np.stack raises ValueError on an empty sequence; slots that were
        # never filled are left untouched.
        if len(maps) == 0:
            continue
        local_prs[val_name][i] = rescale(np.stack(maps, 0))

In [359]:
import cv2
from torch.nn.functional import interpolate

def visualize_local_quality(video_path, quality_map_tensors, output_path):
    """Write a video stacking the input frames above red/green quality overlays.

    Each output frame is the resized input frame followed, top to bottom, by
    one blended panel per quality map, so the output is ``width`` wide and
    ``(1 + number of maps) * height`` tall.

    Args:
        video_path: path of the video to read.
        quality_map_tensors: iterable of 3-D torch tensors; each is resized
            with nearest-neighbour interpolation to (frame count, height,
            width) and indexed by frame along its first axis. Values are
            presumably in [0, 1] (higher is greener) — TODO confirm.
        output_path: path of the 'mp4v'-encoded video to write.

    Raises:
        IOError: if ``video_path`` cannot be opened.
        ValueError: if the video reports a non-positive size or frame count.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Cannot open video: {video_path}")

    out = None
    try:
        old_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        old_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if old_width <= 0 or old_height <= 0 or frame_cnt <= 0:
            raise ValueError(
                f"Invalid video metadata for {video_path}: "
                f"{old_width}x{old_height}, {frame_cnt} frames"
            )

        # Output panels are 480 px wide and keep the source aspect ratio.
        width = 480
        height = int(480 * old_height / old_width)
        # Keep the fractional frame rate (e.g. 29.97) so the output duration
        # matches the source; truncating to int slows playback.
        fps = cap.get(cv2.CAP_PROP_FPS)

        resized_quality_maps = [
            interpolate(
                quality_map_tensor[None, None, :],
                size=(frame_cnt, height, width),
                mode="nearest",
            ).numpy()[0, 0]
            for quality_map_tensor in quality_map_tensors
        ]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # One panel for the source frame plus one per quality map; the writer
        # size has to match the frames it is given.
        out_height = (1 + len(resized_quality_maps)) * height
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, out_height))

        alpha = 0.5
        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, (width, height))
            # The reported frame count can be smaller than the number of
            # decodable frames; reuse the last map slice instead of failing.
            map_idx = min(frame_idx, frame_cnt - 1)

            panels = [frame]
            for resized_quality_map in resized_quality_maps:
                quality = resized_quality_map[map_idx]

                color_map = np.zeros((height, width, 3), dtype=np.uint8)
                color_map[:, :, 2] = (1 - quality) * 255  # Red channel (BGR)
                color_map[:, :, 1] = quality * 255  # Green channel

                panels.append(cv2.addWeighted(frame, alpha, color_map, 1 - alpha, 0))

            out.write(np.concatenate(panels, 0))
            frame_idx += 1
    finally:
        # Release the capture and writer even if processing fails part-way.
        cap.release()
        if out is not None:
            out.release()
