In [23]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import glob
from yaml import unsafe_load as yaml_load
import numpy as np
import tqdm
from skimage import measure
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from scipy import linalg

from keras import models

from keras_transfer_learning import dataset, model
from keras_transfer_learning.data import compare

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


TODOs:

* read https://arxiv.org/pdf/1805.12462.pdf

In [3]:
print('\n'.join(sorted([f.rpartition(os.path.sep)[-1] for f in glob.glob(os.path.join('.', 'models', '*'))])))

A00_unet_stardist_hl60-high-noise_P_002
A00_unet_stardist_hl60-high-noise_P_005
A00_unet_stardist_hl60-high-noise_P_010
A00_unet_stardist_hl60-high-noise_P_050
A00_unet_stardist_hl60-high-noise_P_200
A00_unet_stardist_hl60-high-noise_R_002
A00_unet_stardist_hl60-high-noise_R_005
A00_unet_stardist_hl60-high-noise_R_010
A00_unet_stardist_hl60-high-noise_R_050
A00_unet_stardist_hl60-high-noise_R_200
A00_unet_stardist_hl60-high-noise_R_F
A00_unet_stardist_hl60-low-noise_P_002
A00_unet_stardist_hl60-low-noise_P_005
A00_unet_stardist_hl60-low-noise_P_010
A00_unet_stardist_hl60-low-noise_P_050
A00_unet_stardist_hl60-low-noise_P_200
A00_unet_stardist_hl60-low-noise_R_002
A00_unet_stardist_hl60-low-noise_R_005
A00_unet_stardist_hl60-low-noise_R_010
A00_unet_stardist_hl60-low-noise_R_050
A00_unet_stardist_hl60-low-noise_R_200
A00_unet_stardist_hl60-low-noise_R_F
A01_unet_stardist_hl60-high-noise_P_002
A01_unet_stardist_hl60-high-noise_P_005
A01_unet_stardist_hl60-high-noise_P_010
A01_unet_stardi

In [4]:
model_name = 'A00_unet_stardist_hl60-high-noise_R_F'

In [5]:
model_dir = os.path.join('.', 'models', model_name)
m = model.Model(model_dir=model_dir, load_weights='last')

Instructions for updating:
Colocations handled automatically by placer.


## Create the feature model

In [6]:
def model_up_to_layer(keras_model, layer_name):
    for l in keras_model.layers:
        if l.name == layer_name:
            break
    return models.Model(keras_model.input, l.output)

feature_model = model_up_to_layer(m.model, 'activation_1')
features_size = feature_model.output.shape[-1].value

## Compute the mean and var for model data

In [7]:
d = dataset.Dataset(m.config)
x_test, _ = d.create_test_dataset()

In [21]:
def get_mgf(data, limit=20):
    if limit is not None:
        data = [data[i] for i in np.random.permutation(len(data))[:limit]]
    
    features_list = []
    for img in tqdm.tqdm(data):
        pred = feature_model.predict(img[None, ..., None])
        features_list.append(np.reshape(pred, (-1, features_size)))

    features = np.concatenate(features_list)
    mean = np.mean(features, axis=0)
    cov = np.cov(features, rowvar=0)
    return mean, cov

In [11]:
mean, cov = get_mgf(x_test)

100%|██████████| 20/20 [00:45<00:00,  2.29s/it]


In [30]:
mean_1, cov_1 = get_mgf(x_test)


  0%|          | 0/20 [00:00<?, ?it/s][A
  5%|▌         | 1/20 [00:02<00:50,  2.67s/it][A
 10%|█         | 2/20 [00:04<00:41,  2.28s/it][A
 15%|█▌        | 3/20 [00:05<00:34,  2.05s/it][A
 20%|██        | 4/20 [00:06<00:29,  1.85s/it][A
 25%|██▌       | 5/20 [00:08<00:26,  1.74s/it][A
 30%|███       | 6/20 [00:09<00:23,  1.65s/it][A
 35%|███▌      | 7/20 [00:11<00:20,  1.57s/it][A
 40%|████      | 8/20 [00:12<00:18,  1.51s/it][A
 45%|████▌     | 9/20 [00:13<00:16,  1.47s/it][A
 50%|█████     | 10/20 [00:15<00:14,  1.47s/it][A
 55%|█████▌    | 11/20 [00:16<00:12,  1.44s/it][A
 60%|██████    | 12/20 [00:18<00:11,  1.42s/it][A
 65%|██████▌   | 13/20 [00:19<00:10,  1.44s/it][A
 70%|███████   | 14/20 [00:21<00:08,  1.42s/it][A
 75%|███████▌  | 15/20 [00:22<00:07,  1.42s/it][A
 80%|████████  | 16/20 [00:23<00:05,  1.41s/it][A
 85%|████████▌ | 17/20 [00:25<00:04,  1.40s/it][A
 90%|█████████ | 18/20 [00:26<00:02,  1.39s/it][A
 95%|█████████▌| 19/20 [00:28<00:01,  1.39s/it]

## Compute the mean and var for other data

In [17]:
with open(os.path.join('configs', 'data', 'dsb2018.yaml'), 'r') as f:
    new_data_conf = { 'data': yaml_load(f) }
new_d = dataset.Dataset(new_data_conf)
new_x_test, _ = new_d.create_test_dataset()

In [22]:
new_mean, new_cov = get_mgf(new_x_test)


  0%|          | 0/20 [00:00<?, ?it/s][A
  5%|▌         | 1/20 [00:01<00:32,  1.73s/it][A
 10%|█         | 2/20 [00:05<00:44,  2.48s/it][A
 15%|█▌        | 3/20 [00:07<00:35,  2.10s/it][A
 20%|██        | 4/20 [00:11<00:44,  2.78s/it][A
 25%|██▌       | 5/20 [00:15<00:47,  3.15s/it][A
 30%|███       | 6/20 [00:18<00:44,  3.14s/it][A
 35%|███▌      | 7/20 [00:20<00:34,  2.68s/it][A
 40%|████      | 8/20 [00:21<00:26,  2.20s/it][A
 45%|████▌     | 9/20 [00:25<00:30,  2.76s/it][A
 50%|█████     | 10/20 [00:27<00:24,  2.42s/it][A
 55%|█████▌    | 11/20 [00:27<00:16,  1.87s/it][A
 60%|██████    | 12/20 [00:29<00:14,  1.79s/it][A
 65%|██████▌   | 13/20 [00:33<00:17,  2.46s/it][A
 70%|███████   | 14/20 [00:37<00:17,  2.97s/it][A
 75%|███████▌  | 15/20 [00:39<00:12,  2.56s/it][A
 80%|████████  | 16/20 [00:41<00:10,  2.52s/it][A
 85%|████████▌ | 17/20 [00:45<00:09,  3.11s/it][A
 90%|█████████ | 18/20 [00:47<00:05,  2.65s/it][A
 95%|█████████▌| 19/20 [00:48<00:02,  2.10s/it]

In [32]:
new_mean_1, new_cov_1 = get_mgf(new_x_test)


  0%|          | 0/20 [00:00<?, ?it/s][A
  5%|▌         | 1/20 [00:02<00:44,  2.34s/it][A
 10%|█         | 2/20 [00:04<00:40,  2.24s/it][A
 15%|█▌        | 3/20 [00:04<00:29,  1.73s/it][A
 20%|██        | 4/20 [00:07<00:29,  1.87s/it][A
 25%|██▌       | 5/20 [00:08<00:26,  1.77s/it][A
 30%|███       | 6/20 [00:10<00:26,  1.89s/it][A
 35%|███▌      | 7/20 [00:12<00:25,  1.95s/it][A
 40%|████      | 8/20 [00:15<00:24,  2.00s/it][A
 45%|████▌     | 9/20 [00:17<00:22,  2.04s/it][A
 50%|█████     | 10/20 [00:17<00:15,  1.58s/it][A
 55%|█████▌    | 11/20 [00:19<00:15,  1.74s/it][A
 60%|██████    | 12/20 [00:21<00:14,  1.85s/it][A
 65%|██████▌   | 13/20 [00:23<00:13,  1.93s/it][A
 70%|███████   | 14/20 [00:26<00:11,  1.98s/it][A
 75%|███████▌  | 15/20 [00:27<00:09,  1.85s/it][A
 80%|████████  | 16/20 [00:28<00:06,  1.53s/it][A
 85%|████████▌ | 17/20 [00:29<00:03,  1.30s/it][A
 90%|█████████ | 18/20 [00:29<00:02,  1.07s/it][A
 95%|█████████▌| 19/20 [00:31<00:01,  1.39s/it]

## Compute the distance

In [27]:
def frechet_distance(mean_1, cov_1, mean_2, cov_2):
    return np.sum(np.square(mean_1 - mean_2)) + \
            np.trace(cov_1 + cov_2 - 2 * linalg.fractional_matrix_power(np.dot(cov_1, cov_2), 1/2))

In [28]:
frechet_distance(mean, cov, new_mean, new_cov)

5368.415027869489

In [31]:
frechet_distance(mean, cov, mean_1, cov_1)

1.0572058893706213

In [33]:
frechet_distance(new_mean, new_cov, new_mean_1, new_cov_1)

(273.62129750548485-3.0255722957704123e-07j)