In [2]:
import torch
import numpy as np
from collections import defaultdict
from scipy.stats import spearmanr, pearsonr
from scipy.stats.stats import kendalltau as kendallr

In [79]:
# Per-sample evaluation records; later cells index each entry as
# r['pr_labels'] and r['gt_label'].
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On torch>=2.6 (weights_only=True by default) a pickle holding numpy arrays may
# need an explicit weights_only=False; confirm against the installed torch.
results = torch.load(
    f='results/results_live_vqc_s32*32_ens6.pkl',
)

In [80]:
def rescale(pr, gt=None):
    """Standardize ``pr`` and optionally map it onto the scale of ``gt``.

    Args:
        pr: 1-D array-like of predictions.
        gt: optional 1-D array-like of reference values. When given, the
            standardized predictions are linearly mapped so their mean and
            (population) standard deviation match those of ``gt``.

    Returns:
        ``np.ndarray`` of float64: z-scores of ``pr`` when ``gt`` is None,
        otherwise ``z * std(gt) + mean(gt)``. A constant ``pr`` (zero spread)
        yields all-zero z-scores (i.e. every value maps to ``mean(gt)``)
        instead of NaN. NaNs already present in ``pr`` still propagate.
    """
    # Compute in float64: with lower-precision inputs (e.g. float32) the
    # standardization otherwise picks up visible rounding noise.
    pr = np.asarray(pr, dtype=np.float64)
    spread = np.std(pr)
    if spread == 0:
        # Dividing by a zero std would turn every value into NaN and poison
        # all downstream metrics; a constant input has zero deviation.
        z = np.zeros_like(pr)
    else:
        z = (pr - np.mean(pr)) / spread
    if gt is None:
        return z
    gt = np.asarray(gt, dtype=np.float64)
    return z * np.std(gt) + np.mean(gt)

# Dataset names; not referenced by any of the cells shown in this section.
all_datasets = [
    'LIVE_VQC',
    'KoNViD',
    'CVD2014',
    'LSVQ',
]

def pyramid_ensemble_coefficients(results, group_size=4):
    """Average SRCC / PLCC / KRCC / RMSE over every contiguous window of groups.

    Each ``r['pr_labels']`` is split into consecutive groups of ``group_size``
    entries; the number of groups is taken from ``results[0]`` (floor division,
    so trailing entries that do not fill a group are ignored). For every window
    size ``i`` in ``1..n_groups`` and every start group ``j``, the entries of
    groups ``j..j+i-1`` are averaged per result, linearly rescaled onto the
    mean/std of the ground truth via ``rescale``, and scored against
    ``r['gt_label']``. Scores of all windows with the same size are averaged.

    Args:
        results: non-empty sequence of dicts with keys ``'pr_labels'``
            (sliceable, with a ``.shape``) and ``'gt_label'`` (presumably a
            scalar per entry — TODO confirm).
        group_size: number of consecutive ``pr_labels`` entries per group
            (default 4, the previously hard-coded value).

    Returns:
        ``defaultdict`` (no default factory) mapping window size ``i`` to a
        float64 ``np.ndarray`` ``[spearman, pearson, kendall, rmse]``. Empty
        when ``results[0]['pr_labels']`` holds fewer than ``group_size`` entries.

    Raises:
        ValueError: if ``results`` is empty or ``group_size`` < 1.
    """
    if not results:
        raise ValueError('results must contain at least one entry')
    if group_size < 1:
        raise ValueError('group_size must be a positive integer')

    # NOTE(review): assumes every entry has as many predictions as results[0].
    n_groups = results[0]['pr_labels'].shape[0] // group_size
    gt_labels = np.asarray([r['gt_label'] for r in results], dtype=np.float64)

    p_results = defaultdict()
    for i in range(1, n_groups + 1):
        window_scores = []
        for j in range(n_groups - i + 1):
            lo, hi = j * group_size, (j + i) * group_size
            pr_labels = rescale([np.mean(r['pr_labels'][lo:hi]) for r in results], gt_labels)
            # The rescale is a positive affine map, so the three correlations are
            # unchanged by it; it only matters for RMSE.
            window_scores.append([
                spearmanr(gt_labels, pr_labels)[0],
                pearsonr(gt_labels, pr_labels)[0],
                kendallr(gt_labels, pr_labels)[0],
                np.sqrt(np.mean((gt_labels - pr_labels) ** 2)),
            ])
        # Average each metric over all windows of size i.
        p_results[i] = np.asarray(window_scores, dtype=np.float64).mean(axis=0)

    return p_results

def pyramid_ensemble_stds(results, group_size=4):
    """Measure prediction spread for every contiguous window of groups.

    Each ``r['pr_labels']`` is split into consecutive groups of ``group_size``
    entries; the number of groups is taken from ``results[0]`` (floor
    division). For every window size ``i`` in ``1..n_groups`` and every start
    group ``j``, the entries of groups ``j..j+i-1`` are averaged per result
    and linearly rescaled onto the ground-truth mean/std via ``rescale``. The
    rescaled predictions of all windows of size ``i`` are stacked as
    ``(n_results, n_windows)``.

    Args:
        results: non-empty sequence of dicts with keys ``'pr_labels'``
            (sliceable, with a ``.shape``) and ``'gt_label'``.
        group_size: number of consecutive ``pr_labels`` entries per group
            (default 4, the previously hard-coded value).

    Returns:
        ``defaultdict`` (no default factory) mapping window size ``i`` to a
        float64 ``np.ndarray`` of two values:
          [0] mean over windows of the std across results. Because ``rescale``
              forces each window's std to ``std(gt)``, this is ``std(gt)`` up
              to rounding.
          [1] mean over results of the std across windows, i.e. how much
              different same-sized windows disagree. 0 when ``i == n_groups``
              because only one window exists.

    Raises:
        ValueError: if ``results`` is empty or ``group_size`` < 1.
    """
    if not results:
        raise ValueError('results must contain at least one entry')
    if group_size < 1:
        raise ValueError('group_size must be a positive integer')

    # NOTE(review): assumes every entry has as many predictions as results[0].
    n_groups = results[0]['pr_labels'].shape[0] // group_size
    gt_labels = np.asarray([r['gt_label'] for r in results], dtype=np.float64)

    s_results = defaultdict()
    for i in range(1, n_groups + 1):
        windows = []
        for j in range(n_groups - i + 1):
            lo, hi = j * group_size, (j + i) * group_size
            windows.append(rescale([np.mean(r['pr_labels'][lo:hi]) for r in results], gt_labels))

        stacked = np.stack(windows, axis=1)  # (n_results, n_windows)
        s_results[i] = np.array(
            [np.mean(np.std(stacked, axis=axis)) for axis in (0, 1)],
            dtype=np.float64,
        )
    return s_results

In [82]:
# Keys: window size i (number of consecutive prediction groups averaged).
# Values: [SRCC, PLCC, KRCC, RMSE], each averaged over all windows of size i.
pyramid_ensemble_coefficients(results)

defaultdict(None,
            {1: array([0.82231966, 0.84298417, 0.62961196, 9.55846627]),
             2: array([0.82415349, 0.84477332, 0.63168081, 9.50419353]),
             3: array([0.82465459, 0.84517109, 0.63225803, 9.49201806]),
             4: array([0.82506118, 0.84536551, 0.63282197, 9.48607255]),
             5: array([0.82504333, 0.84552036, 0.63273221, 9.48133771]),
             6: array([0.82485328, 0.84508254, 0.63265025, 9.49476492])})

In [81]:
# Keys: window size i. Column 0: mean across-sample std per window, which
# rescale pins to std(gt), so it is constant. Column 1: mean across-window std
# per sample, i.e. disagreement between same-sized windows; 0 when only one
# window of size i exists.
pyramid_ensemble_stds(results)

defaultdict(None,
            {1: array([17.05764389,  1.04976797]),
             2: array([17.05764389,  0.62426716]),
             3: array([17.05764389,  0.39486992]),
             4: array([17.05764389,  0.24477518]),
             5: array([17.05764198,  0.13948995]),
             6: array([17.05764389,  0.        ])})