In [22]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.axes_grid1 import AxesGrid
from sklearn.mixture import BayesianGaussianMixture
from tqdm import tqdm

In [45]:
def KL_divergence(mean1, mean2, covariance2, precision1, precision2):
    """Kullback-Leibler divergence KL(N1 || N2) between two Gaussians.

    With N1 = N(mean1, Sigma1) and N2 = N(mean2, Sigma2) in dimension d:

        KL = 1/2 * [ log(det Sigma2 / det Sigma1) - d
                     + tr(Sigma2^-1 Sigma1)
                     + (mean2 - mean1)^T Sigma2^-1 (mean2 - mean1) ]

    Args:
        mean1: mean of N1, expected shape (d,).
        mean2: mean of N2, expected shape (d,).
        covariance2: covariance Sigma2 of N2, shape (d, d).
        precision1: inverse covariance Sigma1^-1 of N1, shape (d, d).
        precision2: inverse covariance Sigma2^-1 of N2, shape (d, d).

    Returns:
        The divergence as a float64 tensor.

    Raises:
        torch.linalg.LinAlgError: if ``precision1`` is singular.
    """
    assert precision1.shape == covariance2.shape
    d = precision1.shape[0]
    mean1 = mean1.type(torch.float64)
    mean2 = mean2.type(torch.float64)
    precision1 = precision1.type(torch.float64)
    precision2 = precision2.type(torch.float64)
    covariance2 = covariance2.type(torch.float64)

    diff = mean2 - mean1
    # log det(Sigma2 Sigma1^-1) = log det Sigma2 - log det Sigma1
    log_det_ratio = torch.logdet(covariance2 @ precision1)
    # tr(Sigma2^-1 Sigma1) == tr(precision1^-1 precision2) by cyclicity of the
    # trace; solving avoids forming Sigma1 explicitly.
    trace_term = torch.trace(torch.linalg.solve(precision1, precision2))
    scal_prod = torch.t(diff) @ precision2 @ diff
    return 1 / 2 * (log_det_ratio - d + trace_term + scal_prod)

In [24]:
# Load, for every class folder of the chosen dataset, the first 10 pickled
# image features (files 0.pkl .. 9.pkl) and stack them into one array per class.
# NOTE: pickle.load can execute arbitrary code; only read trusted files.
dataset = 'DOTA_train'
features_root = os.path.join('../class_features', dataset)

classes = os.listdir(features_root)
class_features = {}
for c in classes:
    per_image = []
    for i in range(10):
        pkl_path = os.path.join(features_root, c, str(i) + '.pkl')
        with open(pkl_path, 'rb') as f:
            image_f = pickle.load(f)
        per_image.append(image_f.squeeze().cpu().numpy())
    # Folder names are the integer class ids.
    class_features[int(c)] = np.array(per_image)

print(class_features.keys())
class_features[0].shape

dict_keys([5, 8, 9, 13, 14, 12, 2, 7, 6, 1, 10, 4, 0, 3, 15, 11])


(10, 1024)

In [46]:
# Fit one Bayesian Gaussian mixture per class on its image features, then fill
# heatmap[i, j] with KL(class_i || class_j) between the fitted Gaussians.
bgmm_dict = {}

classes = list(map(int, os.listdir(os.path.join('../class_features', dataset))))
for c in tqdm(classes):
    bgmm = BayesianGaussianMixture()
    bgmm.fit(class_features[c])
    bgmm_dict[c] = bgmm

# Rows/columns follow the sorted class ids, so the heatmap no longer relies on
# the ids being exactly 0..n-1.
class_ids = sorted(bgmm_dict)
n = len(class_ids)

# Convert the fitted parameters to tensors once per class instead of once per
# (i, j) pair.
gaussians = {
    c: (
        torch.tensor(bgmm_dict[c].means_).squeeze(),
        torch.tensor(bgmm_dict[c].covariances_).squeeze(),
        torch.tensor(bgmm_dict[c].precisions_).squeeze(),
    )
    for c in class_ids
}

heatmap = np.zeros((n, n))
for i, class_i in enumerate(tqdm(class_ids)):
    mean1, covariance1, precision1 = gaussians[class_i]
    for j, class_j in enumerate(class_ids):
        mean2, covariance2, precision2 = gaussians[class_j]
        try:
            heatmap[i, j] = KL_divergence(mean1, mean2, covariance2, precision1, precision2)
        except torch.linalg.LinAlgError:
            # Public alias of the error raised on singular matrices; mark the
            # pair as undefined rather than aborting the whole heatmap.
            heatmap[i, j] = np.nan



100%|██████████| 16/16 [00:03<00:00,  4.71it/s]
100%|██████████| 16/16 [00:03<00:00,  4.91it/s]


In [ ]:
0 plane,
1 ship,
2 storage tank,
3 baseball diamond,
4 tennis court,
5 basketball court,
6 ground track field,
7 harbor, 
8 bridge,
9 large vehicle,
10 small vehicle,
11 helicopter, 
12 roundabout,
13 soccer ball field,
14 swimming pool,
15 container crane

In [53]:
# Min-max normalise the heatmap to [0, 1]. Cells set to NaN when the KL
# computation failed are ignored for the min/max and stay NaN, instead of
# turning the whole array into NaN.
heatmap_min = np.nanmin(heatmap)
heatmap_range = np.nanmax(heatmap) - heatmap_min
if heatmap_range > 0:
    heatmap = (heatmap - heatmap_min) / heatmap_range
else:
    # Constant (or all-NaN) heatmap: avoid dividing by zero.
    heatmap = heatmap - heatmap_min

In [54]:
%matplotlib qt
plt.imshow(heatmap)

<matplotlib.image.AxesImage at 0x7f8fb11a7d10>

In [0]:
def KL_divergence_pinv(mean1, mean2, covariance1, covariance2):
    """KL-style divergence between two Gaussians using pseudo-inverses.

    Intended for covariances that may be rank deficient:
      * the log-det term is computed on covariances with a 1e-6 ridge added;
      * the trace and quadratic terms use ``torch.linalg.pinv``;
      * the trace is increased by ``d - rank(pinv(covariance1))``.

    The dimension and device are taken from ``covariance1`` (previously
    hard-coded to 1024 and ``cuda:0``).

    NOTE(review): the trace term is tr(Sigma2 @ pinv(Sigma1)); the closed-form
    KL(N1 || N2) uses tr(Sigma2^-1 Sigma1). Left unchanged here because the
    rank correction is tied to it - confirm the intended direction.

    Args:
        mean1, mean2: mean vectors, expected shape (d,).
        covariance1, covariance2: covariance matrices, shape (d, d), assumed
            symmetric (``hermitian=True`` is passed to ``pinv``).

    Returns:
        Tuple ``(lndet, trace, scal_prod, divergence)`` where
        ``divergence = 1/2 * (lndet - d + trace + scal_prod)``.
    """
    assert covariance1.shape == covariance2.shape
    d = covariance1.shape[0]
    mean1 = mean1.type(torch.float64)
    mean2 = mean2.type(torch.float64)
    covariance1 = covariance1.type(torch.float64)
    covariance2 = covariance2.type(torch.float64)

    ridge = torch.eye(d, dtype=torch.float64, device=covariance1.device) * 1e-6
    cov1_bis = covariance1 + ridge
    cov2_bis = covariance2 + ridge

    # Computed once; used for both the trace and its rank correction.
    cov1_pinv = torch.linalg.pinv(covariance1, atol=1e-7, hermitian=True)

    product1 = cov2_bis @ torch.linalg.pinv(cov1_bis, atol=1e-7, hermitian=True)
    product2 = covariance2 @ cov1_pinv
    diff = mean2 - mean1
    scal_prod = torch.t(diff) @ torch.linalg.pinv(covariance2, atol=1e-6, hermitian=True) @ diff
    lndet = torch.logdet(product1)
    trace = torch.trace(product2) + (d - torch.linalg.matrix_rank(cov1_pinv))
    return lndet, trace, scal_prod, 1 / 2 * (lndet - d + trace + scal_prod)

In [49]:
def KL_divergence_std(mean1, mean2, variance1, variance2):
    """KL(N1 || N2) between two Gaussians with diagonal covariances.

    ``variance1`` / ``variance2`` are multiplied element-wise by the identity,
    so a vector of length d is turned into a diagonal matrix and a (d, d)
    matrix has its off-diagonal entries zeroed.

        KL = 1/2 * [ log(det S2 / det S1) - d + tr(S2^-1 S1)
                     + (mean2 - mean1)^T S2^-1 (mean2 - mean1) ]

    The identity is created on the device of ``variance1`` (previously
    hard-coded to ``cuda:0``).

    Args:
        mean1, mean2: mean vectors, expected shape (d,).
        variance1, variance2: per-dimension variances, shape (d,) or (d, d).

    Returns:
        The divergence as a float64 tensor.

    Raises:
        torch.linalg.LinAlgError: if a diagonal covariance is singular.
    """
    assert variance1.shape == variance2.shape
    d = variance1.shape[0]
    mean1 = mean1.type(torch.float64)
    mean2 = mean2.type(torch.float64)
    eye = torch.eye(d, dtype=torch.float64, device=variance1.device)
    variance1 = variance1.type(torch.float64) * eye
    variance2 = variance2.type(torch.float64) * eye

    precision1 = torch.linalg.inv(variance1)
    precision2 = torch.linalg.inv(variance2)
    diff = mean2 - mean1

    log_det_ratio = torch.logdet(variance2 @ precision1)
    # The trace term is tr(S2^-1 S1), not tr(S2 S1^-1).
    trace_term = torch.trace(precision2 @ variance1)
    scal_prod = torch.t(diff) @ precision2 @ diff
    return 1 / 2 * (log_det_ratio - d + trace_term + scal_prod)

In [50]:
# One sub-directory per dataset. os.listdir order is arbitrary, so sort for a
# reproducible processing order (the later eigenvalue cell sorts as well).
paths = sorted(os.path.join('./class_barycenter', name) for name in os.listdir('./class_barycenter'))
heatmaps = {}

FileNotFoundError: [Errno 2] No such file or directory: './class_barycenter'

In [None]:
# For every dataset directory, build the pairwise KL heatmap between its
# classes from the pickled '<id>_mean.pkl' / '<id>_cov.pkl' files.
for path in paths:
    print(path)
    # Order by the numeric prefix; names not starting with a digit go last.
    entries = sorted(os.listdir(path),
                     key=lambda item: (int(item.partition('_')[0])
                                       if item[0].isdigit() else float('inf'), item))
    # split('_')[1:2] is [] for names without '_', so stray files are skipped
    # instead of raising IndexError.
    mean = [x for x in entries if x.split('_')[1:2] == ['mean.pkl']]
    cov = [x for x in entries if x.split('_')[1:2] == ['cov.pkl']]
    assert len(mean) == len(cov), f'{path}: {len(mean)} mean files vs {len(cov)} cov files'

    n = len(mean)
    heatmap = np.zeros((n, n))

    # Read every file once (n reads) rather than inside the double loop
    # (n^2 reads). This keeps all covariances of one dataset in memory.
    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    means, covariances = [], []
    for mean_file, cov_file in zip(mean, cov):
        with open(os.path.join(path, mean_file), 'rb') as f:
            means.append(pickle.load(f))
        with open(os.path.join(path, cov_file), 'rb') as f:
            covariances.append(pickle.load(f))

    for i in tqdm(range(n)):
        for j in range(n):
            try:
                heatmap[i, j] = KL_divergence_std(means[i], means[j], covariances[i], covariances[j])
            except torch.linalg.LinAlgError:
                heatmap[i, j] = np.nan
    heatmaps[os.path.basename(path)] = heatmap

In [None]:
# `heatmaps` is keyed by dataset name, so heatmaps[0] raises KeyError; show the
# first dataset in sorted order instead.
first_dataset = sorted(heatmaps)[0]
plt.imshow(heatmaps[first_dataset])
plt.title(first_dataset)

In [None]:
%matplotlib qt
fig = plt.figure()

grid = AxesGrid(fig, 111,
                nrows_ncols=(7, 6),
                axes_pad=0.5,
                share_all=True,
                cbar_location="right",
                cbar_mode="single",
                )
keys = sorted(list(heatmaps.keys()))[2:] + ['COCO_train', 'COCO_val']

for dataset, ax in zip(keys, grid):
    ax.set_title(dataset)
    im = ax.imshow(heatmaps[dataset])

grid.cbar_axes[0].colorbar(im)

for cax in grid.cbar_axes:
    cax.toggle_label(True)

plt.show()


In [None]:
# Allocate four independent n x n result arrays, one per quantity of interest.
n = len(mean)
heatmap, heatmap1, heatmap2, heatmap3 = (np.zeros((n, n)) for _ in range(4))


In [None]:
# Pairwise KL heatmap for the dataset currently held in `path` / `mean` / `cov`.
# Each pickle is read once up front instead of n^2 times in the double loop.
# NOTE: pickle.load can execute arbitrary code; only read trusted files.
means, covariances = [], []
for mean_file, cov_file in zip(mean, cov):
    with open(os.path.join(path, mean_file), 'rb') as f:
        means.append(pickle.load(f))
    with open(os.path.join(path, cov_file), 'rb') as f:
        covariances.append(pickle.load(f))

for i in tqdm(range(n)):
    for j in range(n):
        # heatmap1[i, j], heatmap2[i, j], heatmap3[i, j], heatmap[i, j] = KL_divergence_pinv(means[i], means[j], covariances[i], covariances[j])
        try:
            heatmap[i, j] = KL_divergence_std(means[i], means[j], covariances[i], covariances[j])
        except torch.linalg.LinAlgError:
            # Same handling as the per-dataset loop: mark singular pairs as NaN.
            heatmap[i, j] = np.nan

In [None]:
%matplotlib qt

# Display the current `heatmap` array (pairwise KL values) in a Qt window.
plt.imshow(heatmap)

In [None]:
# heatmap1 is allocated as zeros; the only assignment to it (the lndet output of
# KL_divergence_pinv) is commented out, so this shows a blank image unless that
# line is re-enabled.
plt.imshow(heatmap1)

In [None]:
# heatmap2 is allocated as zeros; the only assignment to it (the trace output of
# KL_divergence_pinv) is commented out, so this shows a blank image unless that
# line is re-enabled.
plt.imshow(heatmap2)

In [None]:
# heatmap3 is allocated as zeros; the only assignment to it (the scal_prod output
# of KL_divergence_pinv) is commented out, so this shows a blank image unless
# that line is re-enabled.
plt.imshow(heatmap3)

In [None]:
# List the per-class mean / covariance pickles of COCO_train, ordered by their
# numeric prefix (names not starting with a digit sort last).
path = './class_barycenter/COCO_train'
entries = sorted(os.listdir(path),
                 key=lambda item: (int(item.partition('_')[0])
                                   if item[0].isdigit() else float('inf'), item))
# split('_')[1:2] is [] for names without '_', so stray files are skipped
# instead of raising IndexError.
mean = [x for x in entries if x.split('_')[1:2] == ['mean.pkl']]
cov = [x for x in entries if x.split('_')[1:2] == ['cov.pkl']]

In [None]:
# Load the statistics of two classes (list positions 70 and 0) for a manual
# comparison in the cells below.
# NOTE: pickle.load can execute arbitrary code; only read trusted files.
mean1_path, cov1_path = (os.path.join(path, names[70]) for names in (mean, cov))
mean2_path, cov2_path = (os.path.join(path, names[0]) for names in (mean, cov))

with open(mean1_path, 'rb') as f:
    mean1 = pickle.load(f)
with open(cov1_path, 'rb') as f:
    covariance1 = pickle.load(f)
with open(mean2_path, 'rb') as f:
    mean2 = pickle.load(f)
with open(cov2_path, 'rb') as f:
    covariance2 = pickle.load(f)

In [None]:
# Promote the loaded statistics to double precision before the linear algebra.
mean1, mean2, covariance1, covariance2 = (
    tensor.type(torch.float64)
    for tensor in (mean1, mean2, covariance1, covariance2)
)

In [None]:
# Number of (numerically) zero eigenvalues of covariance1. eigvalsh is the
# solver for symmetric matrices and always returns real values; comparing with
# the scalar 0 removes the hard-coded dimension of 1024.
# assumes covariance1 is symmetric - the pinv calls in this notebook pass hermitian=True
np.isclose(np.linalg.eigvalsh(covariance1.cpu().numpy()), 0).sum()

In [None]:
# KL_divergence expects (mean1, mean2, covariance2, precision1, precision2);
# the previous 4-argument call raised TypeError. Pseudo-inverses stand in for
# the precisions because these covariances can be rank deficient.
KL_divergence(
    mean1,
    mean2,
    covariance2,
    torch.linalg.pinv(covariance1, hermitian=True),
    torch.linalg.pinv(covariance2, hermitian=True),
)

In [None]:
# Pseudo-inverse variant; returns the tuple (lndet, trace, scal_prod, divergence).
KL_divergence_pinv(mean1, mean2, covariance1, covariance2)

In [None]:
# Diagonal variant: the function multiplies its inputs by the identity, so only
# the diagonals of the two covariance matrices are used.
KL_divergence_std(mean1, mean2, covariance1, covariance2)

In [None]:
# Trace term of KL_divergence_pinv, computed in isolation: tr(cov2 @ pinv(cov1))
# plus one unit per dimension missing from the rank of pinv(cov1).
# The pseudo-inverse is computed once and the dimension is read from the matrix
# instead of being hard-coded to 1024.
cov1_pinv = torch.linalg.pinv(covariance1, atol=1e-7, hermitian=True)
product2 = covariance2 @ cov1_pinv
trace = torch.trace(product2) + (covariance1.shape[0] - torch.linalg.matrix_rank(cov1_pinv))

In [None]:
# Display the rank-corrected trace term computed in the previous cell.
trace

In [None]:
# For every class covariance of every dataset, print how many eigenvalues are
# at most 0.1% of the largest one.
paths = sorted(os.path.join('./class_barycenter', name) for name in os.listdir('./class_barycenter'))
for path in paths:
    print(os.path.basename(path))

    # split('_')[1:2] is [] for names without '_', so stray files are skipped.
    cov = [x for x in sorted(os.listdir(path),
                             key=lambda item: (int(item.partition('_')[0])
                                               if item[0].isdigit() else float('inf'), item))
           if x.split('_')[1:2] == ['cov.pkl']]

    n = len(cov)

    for i in range(n):
        cov1_path = os.path.join(path, cov[i])
        # NOTE: pickle.load can execute arbitrary code; only read trusted files.
        with open(cov1_path, 'rb') as f:
            covariance1 = pickle.load(f)

        # Per-dimension variances (diagonal); not used below, kept for inspection.
        variance1 = covariance1.diagonal().cpu().numpy()

        try:
            # One decomposition per matrix. eigvalsh is the symmetric solver and
            # returns real eigenvalues.
            # assumes covariance1 is symmetric - the pinv calls pass hermitian=True
            eigenvalues = np.linalg.eigvalsh(covariance1.cpu().numpy())
            print(i, (eigenvalues <= eigenvalues.max() * 0.001).sum())
        except np.linalg.LinAlgError:
            print(i)


In [None]:
# Number of eigenvalues of covariance1 larger than 5 (symmetric solver, real output).
# assumes covariance1 is symmetric - the pinv calls in this notebook pass hermitian=True
(np.linalg.eigvalsh(covariance1.cpu().numpy()) > 5).sum()

In [None]:
# Re-read the pickle at `cov1_path` (whichever file an earlier cell assigned
# last) into `covariance1`.
# NOTE: pickle.load can execute arbitrary code; only read trusted files.
with open(cov1_path, 'rb') as f:
    covariance1 = pickle.load(f)

In [None]:
# Load the dataset-level pickle for COCO_train.
# This rebinds `cov`, which earlier cells use for a list of file names.
# NOTE: pickle.load can execute arbitrary code; only read trusted files.
with open('./dataset_barycenter/COCO_train_std.pkl', 'rb') as f:
    cov = pickle.load(f)

In [None]:
# Inspect the shape of whatever `cov` currently holds.
cov.shape

In [None]:
# A bare expression that is not the last line of a cell is not displayed, so
# print the directory listing explicitly.
print(os.listdir('./dataset_barycenter'))

# Number of eigenvalues of `cov` within 1e-6 of zero.
# assumes `cov` is a symmetric square matrix - TODO confirm for the *_std.pkl files
np.isclose(np.linalg.eigvalsh(cov.cpu().numpy()), 0, atol=1e-6).sum()

In [None]:
# For every dataset-level '*_std.pkl' file, print how many eigenvalues are at
# most 0.1% of the largest one.
path = './dataset_barycenter'

cov = [x for x in sorted(os.listdir(path)) if x.split('_')[-1] == 'std.pkl']

# The loop variable is not called `dataset`, so the notebook-level `dataset`
# setting is not overwritten.
for std_file in cov:
    cov_path = os.path.join(path, std_file)
    # NOTE: pickle.load can execute arbitrary code; only read trusted files.
    with open(cov_path, 'rb') as f:
        covariance = pickle.load(f)

    # Diagonal entries; not used below, kept for inspection.
    variance = covariance.diagonal().cpu().numpy()

    # One decomposition per matrix, with the symmetric solver (real eigenvalues).
    # assumes `covariance` is a symmetric square matrix - TODO confirm
    eigenvalues = np.linalg.eigvalsh(covariance.cpu().numpy())
    print(std_file, (eigenvalues <= eigenvalues.max() * 0.001).sum())


In [None]:
# Hand-entered results, one row per dataset:
#   (dataset, prop_eig0 for the test / train / val splits, 5-shot performance)
# prop_eig0 is presumably the average count of near-zero covariance eigenvalues
# per class - TODO confirm. The performance value is repeated for each split.
_SPLITS = ('test', 'train', 'val')
_RESULTS = [
    ('DIOR', (773.45, 703.25, 808.60), 30.42),
    ('DOTA', (776.25, 716.19, 833.06), 14.45),
    ('DeepFruits', (968.71, 872.29, 1002.20), 55.58),
    ('SIXray', (888.60, 834.20, 887.00), 13.25),
    ('clipart', (970.85, 976.40, 1007.89), 5.26),
    # ('VisDrone', (549.80, 549.80, 684.40), 5.74),
]

prop_eig0 = [value for _name, values, _perf in _RESULTS for value in values]
perf_5shot = [perf for _name, _values, perf in _RESULTS for _split in _SPLITS]
labels = [f'{name}_{split}' for name, _values, _perf in _RESULTS for split in _SPLITS]

In [None]:
# One fixed colour per dataset family, used by the scatter plots below.
colors_dict = dict(
    COCO='#9D6FD7',
    CrowdHuman='#E6C8C7',
    DIOR='#003A37',
    DOTA='#8AFF66',
    DeepFruits='#284D26',
    KITTI='#F65794',
    Oktoberfest='#804E96',
    SIXray='#C05E5A',
    VisDrone='#458BBC',
    clipart='#C83F0F',
    comic='#3C7503',
    fashionpedia='#3B9F21',
    watercolor='#E68C70',
)

In [None]:
# Scatter prop_eig0 against 5-shot performance, keeping only the '*_test'
# entries and annotating each point with its dataset name.
test_points = [
    (prop, perf, label.partition('_')[0])
    for prop, perf, label in zip(prop_eig0, perf_5shot, labels)
    if label.rpartition('_')[2] == 'test'
]
for x_value, y_value, dataset_name in test_points:
    plt.scatter(x_value, y_value, color=colors_dict[dataset_name])
    plt.text(x_value, y_value + 1, dataset_name)
plt.show()

In [None]:
# Second hand-entered result set: three consecutive entries (test / train / val)
# per dataset, with the 5-shot score repeated for each split.
prop_eig0 = [970.85, 976.40, 1007.89,
             856.00, 860.00, 862.00,
             968.71, 872.29, 1002.20,
             891.89, 872.17, 949.87,
             855.00, 836.50, 856.75,
             1009.33, 925.73, 992.07,
             888.60, 834.20, 887.00]
# 549.80, 549.80, 684.40]

perf_5shot = [49.3, 49.3, 49.3,
              51.5, 51.5, 51.5,
              60.7, 60.7, 60.7,
              21.3, 21.3, 21.3,
              45.3, 45.3, 45.3,
              81.6, 81.6, 81.6,
              23.9, 23.9, 23.9]
# 14.2, 14.2, 14.2]

# The third KITTI entry was labelled 'KITTI' without its split suffix; every
# other dataset follows the '<name>_<split>' pattern.
labels = ['clipart_test', 'clipart_train', 'clipart_val',
          'CrowdHuman_test', 'CrowdHuman_train', 'CrowdHuman_val',
          'DeepFruits_test', 'DeepFruits_train', 'DeepFruits_val',
          'fashionpedia_test', 'fashionpedia_train', 'fashionpedia_val',
          'KITTI_test', 'KITTI_train', 'KITTI_val',
          'Oktoberfest_test', 'Oktoberfest_train', 'Oktoberfest_val',
          'SIXray_test', 'SIXray_train', 'SIXray_val']
# 'VisDrone_test', 'VisDrone_train', 'VisDrone_val']

In [None]:
# Scatter of the '*_test' entries: prop_eig0 on x, AP50 5-shot on y, with a
# larger font, axis labels and a grid.
plt.rcParams['font.size'] = 13

test_points = [
    (prop, perf, label.partition('_')[0])
    for prop, perf, label in zip(prop_eig0, perf_5shot, labels)
    if label.rpartition('_')[2] == 'test'
]
for x_value, y_value, dataset_name in test_points:
    plt.scatter(x_value, y_value, color=colors_dict[dataset_name])
    plt.text(x_value, y_value + 1, dataset_name)

plt.xlabel(r"Difficulté $\varepsilon$")
plt.ylabel("AP50 5-shot")
plt.grid()
plt.show()