In [None]:
import os
import sys
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import pickle
import pathlib
from tqdm.notebook import tqdm
from sklearn.datasets import fetch_openml
%matplotlib inline

In [None]:
# MNIST from OpenML (https://www.openml.org/d/554) as plain numpy arrays.
# Each row of X holds 784 = 28*28 pixel values (reshaped to 28x28 below);
# labels in y are strings, which is why later cells compare against '5'.
X, y = fetch_openml(
    'mnist_784',
    version=1,
    as_frame=False,
    return_X_y=True,
)

In [None]:
# Preview sample #47 with inverted grayscale intensities (255 - pixel).
sample_pixels = np.uint8(X[47].reshape((28, 28)))
img = Image.fromarray(255 - sample_pixels, 'L')

In [None]:
# Bare last expression: Jupyter renders the PIL image built in the previous cell (X[47], inverted).
img

In [None]:
# Gather every "5" in the dataset as a 28x28 grayscale PIL image (intensities not inverted).
images = []
for idx, label in enumerate(y):
    if label != '5':
        continue
    img = Image.fromarray(np.uint8(X[idx].reshape((28, 28))), 'L')
    images.append(img)

### Create point clouds of "5"s and vertically (top-bottom) flipped "5"s

In [None]:
def get_dataset(angle, data=None, labels=None, digit='5', canvas=40):
    """Build a point cloud of padded (and optionally flipped) MNIST digits.

    Every sample whose label equals `digit` is reshaped to 28x28, centred on a
    zero-filled `canvas` x `canvas` image, cast to uint8 and flattened to a row.

    Args:
        angle: a flip flag despite its name -- any value > 0 flips each padded
            image top-to-bottom; 0 (or less) keeps it upright.
        data: samples with 784 = 28*28 pixel values per row. Defaults to the
            notebook-level `X`.
        labels: labels aligned row-by-row with `data`. Defaults to the
            notebook-level `y`. Matched with `==`, so string labels need a
            string `digit`.
        digit: label value to keep (default '5').
        canvas: side length of the padded square image (default 40, >= 28).

    Returns:
        uint8 array of shape (n_selected, canvas * canvas); the shape is
        (0, canvas * canvas) when no label matches.

    Raises:
        ValueError: if `canvas` is smaller than 28.
    """
    side = 28
    if canvas < side:
        raise ValueError(f"canvas must be at least {side}, got {canvas}")

    data = X if data is None else data
    labels = y if labels is None else labels

    # Plain Python equality per label, exactly like the original loop, so
    # mixed/object label arrays behave the same as before.
    mask = np.array([label == digit for label in labels], dtype=bool)
    selected = np.asarray(data)[mask]

    offset = (canvas - side) // 2  # 6 for the default 40x40 canvas
    padded = np.zeros((len(selected), canvas, canvas))
    padded[:, offset:offset + side, offset:offset + side] = selected.reshape(len(selected), side, side)
    images = padded.astype(np.uint8)

    if angle > 0:
        # Same effect as PIL's FLIP_TOP_BOTTOM: reverse the row axis.
        images = images[:, ::-1, :]

    # Explicit row length (not -1) so an empty selection still reshapes cleanly.
    return images.reshape(len(images), canvas * canvas)

In [None]:
# Index 0: upright "5"s (the reference cloud); index 1: the same "5"s flipped top-to-bottom.
clouds = [get_dataset(flip) for flip in (0, 1)]

In [None]:
# Each cloud has one row per "5" and 40*40 = 1600 padded pixels per row.
shapes = [point_cloud.shape for point_cloud in clouds]
for shape in shapes:
    print(shape)

### Compute cross-barcodes between the flipped and the original "5"s

In [None]:
import mtd

In [None]:
trials = 20
res1 = []

# Cross-barcodes of every flipped cloud (index >= 1) against the reference
# clouds[0], recomputed `trials` times per comparison.
for i in range(1, len(clouds)):
    # Re-seed NumPy's global RNG before each comparison so it is reproducible
    # (assumes mtd's batch sampling draws from np.random -- TODO confirm).
    np.random.seed(7)
    barcs = []
    for _ in range(trials):
        barcs.append(
            mtd.calc_cross_barcodes(clouds[i], clouds[0], batch_size1=100, batch_size2=1000)
        )
    res1.append(barcs)

In [None]:
def get_scores(res, args_dict, trials=10, score_fn=None):
    """Average a barcode score over the repeated trials of each comparison.

    Args:
        res: sequence with one entry per comparison; each entry is a sized
            sequence of barcodes (one per trial), e.g. `res1` built above.
        args_dict: keyword arguments forwarded to the scoring function for
            every barcode (e.g. {'h_idx': 1, 'kind': 'sum_length'}).
        trials: ignored -- kept only so existing calls keep working. The mean
            is always taken over *all* barcodes in each entry of `res`.
        score_fn: callable(barcode, **args_dict) -> number. Defaults to
            `mtd.get_score`.

    Returns:
        list with the mean score of each entry of `res`, in order.

    Raises:
        ValueError: if an entry of `res` contains no barcodes.
    """
    if score_fn is None:
        score_fn = mtd.get_score

    scores = []
    for comparison_id, barcodes in enumerate(res):
        if len(barcodes) == 0:
            raise ValueError(f"res[{comparison_id}] contains no barcodes to average")
        trial_scores = [score_fn(barcode, **args_dict) for barcode in barcodes]
        scores.append(sum(trial_scores) / len(barcodes))

    return scores

In [None]:
# Mean score (kind='sum_length', h_idx=1) over all trials of each comparison in res1.
score_args = {'h_idx': 1, 'kind': 'sum_length'}
scores = get_scores(res1, score_args)

In [None]:
# One averaged score per flipped-vs-reference comparison.
for score in scores:
    print(score)

In [None]:
# Geometry Score

import gs

def get_rlts(X, n=2500, n_threads=40, gamma=None):
    """Compute the relative living times (RLTs) of a point cloud with `gs.rlts`.

    The result is what `gs.geom_score` consumes later in this notebook.

    Args:
        X: point cloud with one sample per row (e.g. an entry of `clouds`).
        n: forwarded to `gs.rlts` as `n` (previously hard-coded to 2500);
            presumably the number of repeated RLT estimates -- see the gs docs.
        n_threads: worker threads forwarded to `gs.rlts` (previously
            hard-coded to 40; lower it on machines with fewer cores).
        gamma: forwarded to `gs.rlts`. When None, uses the original rule
            (1/128) / (n/5000), so gamma shrinks as `n` grows.

    Returns:
        Whatever `gs.rlts` returns for these settings.
    """
    if gamma is None:
        gamma = (1 / 128) / (n / 5000)
    return gs.rlts(X, gamma=gamma, n=n, n_threads=n_threads)

In [None]:
# One placeholder slot per point cloud, filled with its RLTs later.
rlts = [None for _ in clouds]

In [None]:
%time
for i in range(len(clouds)):
    rlts[i] = get_rlts(clouds[i])

In [None]:
# Geometry Score of each flipped cloud against the reference clouds[0], scaled by 1e3.
for i in range(1, len(clouds)):
    gs_value = gs.geom_score(rlts[0], rlts[i])
    print(gs_value * 1e3)

In [None]:
# Additional experiment with IMD: msid_score between the reference cloud
# clouds[0] and every cloud (i == 0 compares clouds[0] with itself).
from msid import msid_score

# Own result name so the barcode results in `res1` (read by get_scores) survive.
imd_scores = []
for i, cloud in enumerate(clouds):
    v = msid_score(clouds[0], cloud)
    imd_scores.append(v)

    print(i, v)