In [1]:
import torch
import cv2
import random
import fastvqa
from fastvqa.models import BaseEvaluator
from fastvqa.datasets import VQAInferenceDataset, get_fragments, SampleFrames

from matplotlib import pyplot as plt

from scipy.stats import spearmanr, pearsonr
from scipy.stats.stats import kendalltau as kendallr
import numpy as np

from time import time
from tqdm import tqdm

## choose the device you would like to run on
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
## demo on how to run FAST-VQA in script
# Build the end-to-end FAST-VQA model from a pretrained checkpoint (presumably the
# 32x32-fragment variant, judging by the file name) on the device chosen above.
# NOTE(review): the meaning of the leading positional `True` is not visible here --
# check the signature of fastvqa.deep_end_to_end_vqa.
model = fastvqa.deep_end_to_end_vqa(True, 'pretrained_weights/all_aligned_fragments_32.pth', device=device)


In [None]:
## sanity-check the end-to-end model on random noise (no real video needed)
# assumes the model accepts a (C, T, H, W) tensor -- TODO confirm against fastvqa docs
video = torch.randn((3, 96, 224, 224)).to(device)
with torch.no_grad():  # inference only: don't build an autograd graph for 96 frames
    score = model(video)
print(score['score'])



In [None]:
## Calculate the finetune accuracy on small datasets

def get_finetune_results(results):
    """Print summary statistics of fine-tuning runs over random splits.

    ``results`` is either a sequence of per-split records whose first three
    entries are (SROCC, PLCC, KROCC), or a container holding such a sequence
    under the key ``'results'``. For each metric the mean, the standard
    deviation (in parentheses) and the median are printed with 4 decimals.
    """
    runs = results['results'] if 'results' in results else results
    # One array per metric column: 0 -> SROCC, 1 -> PLCC, 2 -> KROCC.
    metric_columns = [np.array([run[col] for run in runs]) for col in range(3)]
    (ms, ss, mds), (mp, sp, mdp), (mk, sk, mdk) = [
        (np.mean(column), np.std(column), np.median(column)) for column in metric_columns
    ]
    print(f'''In {len(runs)} random split experiments,
        the mean SROCC is {ms:.4f} ({ss:.4f}), median {mds:.4f}
        the mean PLCC  is {mp:.4f} ({sp:.4f}), median {mdp:.4f}
        the mean KROCC is {mk:.4f} ({sk:.4f}), median {mdk:.4f}''')

# Load pre-computed fine-tuning results (one record per random split) and
# print their summary statistics with get_finetune_results.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load result files you produced yourself.
finetune_results = torch.load('results/results_finetune_konvid_s32*32_ens1.pkl')
get_finetune_results(finetune_results)

In [None]:
## defining model and loading checkpoint
from collections import OrderedDict

model = BaseEvaluator().to(device)
fsize = 32  # presumably the fragment size in pixels; selects the matching checkpoint
load_path = f'pretrained_weights/all_aligned_fragments_{fsize}.pth'
state_dict = torch.load(load_path, map_location='cpu')

# Some checkpoints wrap the weights under a 'state_dict' key; unwrap them.
if 'state_dict' in state_dict:
    state_dict = state_dict['state_dict']

# Rename 'cls' -> 'vqa' in parameter names (presumably the head was renamed
# between the training code and BaseEvaluator -- confirm). This runs for both
# checkpoint layouts so `i_state_dict` is always defined; keys without 'cls'
# are passed through unchanged.
i_state_dict = OrderedDict(
    (key.replace('cls', 'vqa'), value) for key, value in state_dict.items()
)

model.load_state_dict(i_state_dict)
# Inference only: switch off any training-time behaviour (dropout, drop-path, ...).
model.eval()

In [None]:
## getting datasets (if you want to load from existing VQA datasets)
import os

dataset_name = 'KoNViD'
# Root folder holding one sub-folder per dataset (each with a labels.txt).
# Override with the FASTVQA_DATASETS_ROOT environment variable instead of
# editing this machine-specific default.
datasets_root = os.environ.get('FASTVQA_DATASETS_ROOT', '/mnt/lustre/hnwu/datasets')
dataset_path = f'{datasets_root}/{dataset_name}'

# Fragment settings can also be passed explicitly, e.g.
#   VQAInferenceDataset(..., fragments=224 // fsize, fsize=fsize)
inference_set = VQAInferenceDataset(f'{dataset_path}/labels.txt', dataset_path)


In [None]:
## run the model with exemplar fragment video
## with ultra...fast performance

## for example from the dataset

# Pick one dataset index uniformly at random (0 .. len - 1 inclusive).
q = random.randint(0, len(inference_set) - 1)


In [None]:
## Score the sampled video and time the forward pass.
data = inference_set[q]

st = time()

vfrag = data['video'].to(device)

## or, directly get from your input videos as follows
## where 'video' is a torch Tensor

## from datasets import temporal_sampling (not implemented yet)

# data = temporal_sampling(video, 32, 2, 4)

# vfrag = get_fragments(data['video']).to(device)

# Inference only. Without no_grad the output would require grad (if the model's
# parameters do), and the later `.cpu().numpy()` on it would raise.
with torch.no_grad():
    demo_result = model(vfrag)
print(demo_result.shape)
# Flatten all leading dims, keeping the last two, so each entry is one 2-D map.
demo_result = demo_result.reshape((-1,) + demo_result.shape[-2:])
score = torch.mean(demo_result)
# .item() waits for any pending GPU work, so read it before stopping the clock.
score_value = score.item()
end = time()

print(f'The quality of the video is {score_value}, consuming time {end-st:.4f}s.')



In [None]:
def rescale(pr, gt=None):
    """Standardise predictions, optionally mapping them onto the label scale.

    Args:
        pr: array-like (list or ndarray) of predicted scores.
        gt: optional array-like of ground-truth scores. When given, the
            standardised predictions are scaled to gt's standard deviation and
            shifted to gt's mean so they are directly comparable (e.g. for RMSE).

    Returns:
        np.ndarray with the same shape as ``pr``.
    """
    # asarray lets plain Python lists work: `list - float` raises TypeError.
    pr = np.asarray(pr, dtype=float)
    pr = (pr - np.mean(pr)) / np.std(pr)
    if gt is not None:
        gt = np.asarray(gt, dtype=float)
        pr = pr * np.std(gt) + np.mean(gt)
    return pr

In [None]:
## see the spatial-temporal quality localization for a reference
def init_demo_reader(path, i):
    """Open a decord VideoReader (CPU context) for the i-th video of a dataset.

    Args:
        path: dataset folder containing ``labels.txt`` and the video files.
        i: zero-based line index into ``labels.txt``.

    Returns:
        decord.VideoReader for ``{path}/{name}``, where ``name`` is the first
        comma-separated field of line ``i`` of ``labels.txt``.
    """
    from decord import VideoReader, cpu
    # Use a context manager so labels.txt is closed instead of leaked.
    with open(f'{path}/labels.txt') as label_file:
        video_names = [line.split(',')[0] for line in label_file]
    return VideoReader(f'{path}/{video_names[i]}', ctx=cpu(0))

# Open the video at dataset index q, and re-read the list of video names
# (first comma-separated field of each labels.txt line; only used by the
# commented-out savefig below).
# NOTE(review): the open() below is never closed.
frame_reader = init_demo_reader(dataset_path, q)
video_names = [ele.split(',')[0] for ele in open(f'{dataset_path}/labels.txt').readlines()]

# Pick one of the sampled frame indices at random and fetch that frame from the
# original video.
r_index = random.randrange(len(data['frame_inds']))
frame_index = data['frame_inds'][r_index]
frame = frame_reader[frame_index]
# NOTE(review): `r_index // 2` presumably maps a sampled frame to its quality map
# because the model output has half the temporal resolution -- confirm.
frame_quality_map = demo_result[r_index // 2]
# Rebuild the fragment image for the same sampled frame: move dim 1 (presumably
# channels, with vfrag laid out as (N, C, T, H, W)) to the end, undo the dataset's
# std/mean normalization, and select entry r_index.
# NOTE(review): 128 is hard-coded -- presumably the total number of sampled
# frames; it must match vfrag's element count or the reshape fails.
fragment = (vfrag.permute(0,2,3,4,1).reshape((128,) + vfrag.shape[-2:] + (3,)).cpu() * inference_set.std + inference_set.mean).numpy()[r_index]

frame_quality_map = frame_quality_map.cpu().numpy()
# Standardize the quality map, upsample it (bilinear) to the full frame size, and
# overlay it by subtracting half of it from the first (presumably red) channel.
qlt = cv2.resize(rescale(frame_quality_map), (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_LINEAR)
plt.figure(dpi=450)
plt.subplot(221)
plt.imshow(frame)
plt.subplot(222)
plt.imshow(frame / 255. - np.stack((qlt,) + (np.zeros_like(qlt),)*2, -1) / 2.)

# Show the fragment next to its raw quality map, upsampled with nearest-neighbour
# interpolation so each cell of the map stays a flat block.
fqlt = cv2.resize(frame_quality_map, (fragment.shape[1], fragment.shape[0]), interpolation=cv2.INTER_NEAREST)
plt.figure(dpi=300)
plt.subplot(221)
plt.imshow(fragment / 255.)
plt.subplot(222)
plt.imshow(fqlt, cmap='gray') #fragment / 255. - np.stack((fqlt,) + (np.zeros_like(fqlt),)*2, -1) / 2.)
# Uncomment to save the figure under demos/, named after the video file.
#plt.savefig(f'demos/demo_{video_names[q].split("/")[-1]}.png')
print(frame_quality_map)

In [None]:
## run inference for a whole testing database
## note that the Jupyter program might be relatively slower than running directly with './inference_dataset.py'

inference_loader = torch.utils.data.DataLoader(inference_set, batch_size=1, num_workers=4)
results = []

for data in tqdm(inference_loader):
    result = dict()
    # batch_size=1: drop the batch dim so the input has the same layout as a
    # single inference_set[i]['video'] sample.
    vqfrag = data['video'].to(device).squeeze(0)
    with torch.no_grad():
        # Score this sample's own fragments (vqfrag), not the demo cell's vfrag.
        result['pr_labels'] = model(vqfrag).cpu().numpy()
    result['gt_label'] = data['gt_label'].item()
    result['frame_inds'] = data['frame_inds']
    del data  # release the loaded batch before fetching the next one
    results.append(result)

In [None]:
## calculating several accuracies indices

gt_labels = np.asarray([r['gt_label'] for r in results], dtype=float)
# Raw per-video prediction: mean over everything the model returned for that video.
opr_labels = np.asarray([np.mean(r['pr_labels']) for r in results], dtype=float)
# Map predictions onto the label scale so RMSE is meaningful; SROCC, PLCC and
# KROCC are unchanged by this positive affine rescaling.
pr_labels = rescale(opr_labels, gt_labels)

srocc = spearmanr(gt_labels, pr_labels)[0]
plcc = pearsonr(gt_labels, pr_labels)[0]
krocc = kendallr(gt_labels, pr_labels)[0]
rmse = np.sqrt(np.mean((gt_labels - pr_labels) ** 2))

print(f'For dataset {dataset_name} we inference, the accuracy of the model is as follows:\n  SROCC: {srocc:.4f}\n  PLCC:  {plcc:.4f}\n  KROCC: {krocc:.4f}\n  RMSE:  {rmse:.4f}')

In [None]:
# Express the repeated scores of video No.133 in standard deviations of the raw
# dataset predictions. Needs ten_133 (defined in the "stableness of Fragments"
# cell further down -- run that first) and opr_labels (from the accuracy cell).
# asarray is required: subtracting a float from a plain list raises TypeError.
(np.asarray(ten_133) - np.mean(opr_labels)) / np.std(opr_labels)

In [None]:
# stableness of Fragments
# Ten repeated scores recorded for each of two videos; print the mean and the
# standard deviation of each series to show how stable the scores are.
ten_102 = [-0.1089, -0.1156, -0.1259, -0.1196, -0.1227, -0.1156, -0.1100, -0.1113, -0.1181, -0.1126]
ten_133 = [-0.0748, -0.0671, -0.0784, -0.0530, -0.0521, -0.0653, -0.0456, -0.0777, -0.0378, -0.0584]
for video_tag, repeated_scores in (('Video No.102', ten_102), ('Video No.133', ten_133)):
    print(video_tag, np.mean(repeated_scores), np.std(repeated_scores))