In [None]:
from PIL import Image
import numpy as np
from numpy.random import default_rng
import os
import matplotlib.pyplot as plt
import networkx as nx

from skimage.metrics import structural_similarity as ssim
import iblofunmatch.inter as ibfm

# Directory passed to every ibfm.get_IBloFunMatch_subset call below;
# created up front so those calls can rely on it existing.
output_dir="output"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Class ids used in the experiment; class c is read from files named "obj{c+1}__*.png".
CLASS_LIST = list(range(4))
# Number of images loaded per class.
NUM_SAMPLES = 72
# One experiment (index subset) per class.
NUM_EXP = len(CLASS_LIST)
# Fraction (not a percentage) of the per-class sample count used to size random subsets.
DATA_PERCENT = 0.4
# int() truncates, so the subset size is rounded down.
SUBSET_SIZE = int(DATA_PERCENT * NUM_SAMPLES)
print(f"SUBSET_SIZE:{SUBSET_SIZE}")

In [None]:
# Labels: NUM_SAMPLES consecutive copies of each class id, in CLASS_LIST order,
# matching the order in which the images are appended below.
y = np.repeat(CLASS_LIST, NUM_SAMPLES)

# Read every image of every class into a list of arrays.
data = []
for c in CLASS_LIST:
    for i in range(NUM_SAMPLES):
        # Image.open is lazy; the context manager guarantees the file handle
        # is released even if the conversion to an array raises.
        with Image.open(f"data_COIL20/coil-20-proc/obj{c+1}__{i}.png") as im_frame:
            data.append(np.array(im_frame))

# Stack the per-image arrays into one array indexed by global sample id.
data = np.array(data)
print("All data shape")
print(data.shape)

In [None]:
# Pairwise dissimilarity 1 - SSIM between all images. SSIM is evaluated once
# per unordered pair (row index < column index) and mirrored into the lower
# triangle; the diagonal stays at zero.
num_images = data.shape[0]
Dist_X = np.zeros((num_images, num_images))
for row in range(num_images):
    for col in range(row + 1, num_images):
        dissimilarity = 1 - ssim(data[row], data[col])
        Dist_X[row][col] = dissimilarity
        Dist_X[col][row] = Dist_X[row][col]

### Compute matching from all data to itself

In [None]:
# The "subset" here is the whole dataset: every sample index is included,
# and the same distance matrix is used for both subset and full data.
id_S = [*range(data.shape[0])]
ibfm_out = ibfm.get_IBloFunMatch_subset(
    Dist_X,
    Dist_X,
    id_S,
    output_dir,
    max_rad=-1,
    num_it=1,
    store_0_pm=False,
    points=False,
)

In [None]:
# Plot the "X_barcode_1" entry of the full-data matching and save it to disk.
X_barcode_1 = ibfm_out["X_barcode_1"]
fig, ax = plt.subplots(figsize=(5, 3))
ibfm.plot_barcode(X_barcode_1, "navy", ax)

# The file name is tagged with the first class id.
cidx = CLASS_LIST[0]
os.makedirs("plots/ssim", exist_ok=True)
plt.savefig(f"plots/ssim/class_{cidx}_barcode.png")

In [None]:
def draw_repr_cycle(repr_cycle, figsize, data, Dist, image_shape=(128, 128), thumb_half_size=0.1):
    """Draw a representative cycle as a graph with image thumbnails on its nodes.

    Parameters
    ----------
    repr_cycle : array-like of int
        Flat sequence of vertex ids; consecutive pairs are read as edges
        (it is reshaped to ``(-1, 2)``).
    figsize : tuple
        Figure size passed to ``plt.subplots``. Its first entry is also used
        as the edge line width.
    data : indexable
        ``data[v]`` must give the image of vertex ``v``, reshapeable to
        ``image_shape``.
    Dist : 2-D array
        Distance matrix; ``Dist[u, v]`` becomes the weight of edge ``(u, v)``.
    image_shape : tuple, optional
        Shape each image is reshaped to before display.
    thumb_half_size : float, optional
        Half side length of each thumbnail, in layout coordinates.

    Returns
    -------
    ax : matplotlib Axes
        Axes holding the drawing.
    pos : dict
        Node positions from ``nx.spectral_layout``, keyed by vertex id.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Build the weighted graph: one node per distinct vertex, one edge per pair.
    G = nx.Graph()
    G.add_nodes_from(np.unique(repr_cycle))
    edge_pairs = np.array(repr_cycle).reshape((-1, 2)).tolist()
    G.add_weighted_edges_from((u, v, Dist[u, v]) for u, v in edge_pairs)

    pos = nx.spectral_layout(G)
    nx.draw_networkx(G, ax=ax, pos=pos, width=figsize[0])

    # Remember the limits chosen for the graph drawing so they can be restored
    # after the thumbnails are added with imshow.
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Overlay each vertex's image as a small square centred on its node.
    for node, center in pos.items():
        im_array = data[node].reshape(image_shape)
        extent = (
            center[0] - thumb_half_size,
            center[0] + thumb_half_size,
            center[1] - thumb_half_size,
            center[1] + thumb_half_size,
        )
        ax.imshow(im_array, cmap="gray", extent=extent, zorder=4)

    ax.set_ylim(ylim)
    ax.set_xlim(xlim)
    return ax, pos

In [None]:
# Indices of the bars in "X_barcode_1" whose length (column 1 minus column 0)
# exceeds 0.2.
bars_dim1 = ibfm_out["X_barcode_1"]
bar_lengths = bars_dim1[:, 1] - bars_dim1[:, 0]
is_long_bar = (bar_lengths > 0.2).tolist()
long_X_bars = np.nonzero(is_long_bar)[0].tolist()
print(bars_dim1[long_X_bars])

In [None]:
# Draw and save the representative cycle of every long bar selected in
# `long_X_bars`; all other cycles are skipped.
os.makedirs("plots/ssim", exist_ok=True)
long_bar_ids = set(long_X_bars)
for cycle_idx, repr_cycle in enumerate(ibfm_out["X_reps_1"]):
    if cycle_idx not in long_bar_ids:
        continue
    # np.unique already returns the vertices sorted.
    cycle_vertices = np.unique(repr_cycle)
    print(f"Cycle {cycle_idx}, number of elements: {len(cycle_vertices)}")
    print(cycle_vertices)
    # draw_repr_cycle returns (ax, pos); only the axes are needed here.
    cycle_ax, _ = draw_repr_cycle(repr_cycle, (8, 12), data, Dist_X)
    cycle_ax.figure.savefig(f"plots/ssim/cycle_rep_codomain_{cycle_idx}.png")

### Compute matching from classes to all subsets 

In [None]:
# Global sample indices of each class. Images are loaded class by class with
# NUM_SAMPLES images each, so class i owns the contiguous index range
# [i * NUM_SAMPLES, (i + 1) * NUM_SAMPLES).
subsets_lists = [
    list(range(i * NUM_SAMPLES, (i + 1) * NUM_SAMPLES))
    for i in range(NUM_EXP)
]

In [None]:
# Split the data into one base class and the union of all other classes.
base_class_idx = 1

# Rebuild the per-class index lists locally instead of reading `subsets_lists`:
# this cell overwrites `subsets_lists`, so reading it would make a second run
# of the cell operate on its own output.
class_index_lists = [
    list(range(c * NUM_SAMPLES, (c + 1) * NUM_SAMPLES))
    for c in range(NUM_EXP)
]

base_class = class_index_lists[base_class_idx]
other_pts = [
    idx
    for class_idx, class_indices in enumerate(class_index_lists)
    if class_idx != base_class_idx
    for idx in class_indices
]

# Two subsets from here on: the base class, and everything else.
subsets_lists = [base_class, other_pts]

In [None]:
# Sanity check: show how many index subsets are currently stored.
len(subsets_lists)

In [None]:
# Run the matching once per index subset, against the full distance matrix.
ibfm_class = []
for subset_indices in subsets_lists:
    print(len(subset_indices))
    # Restrict the full distance matrix to the subset: columns first, then rows.
    subset_columns = Dist_X[:, subset_indices]
    Dist_S = subset_columns[subset_indices]
    subset_matching = ibfm.get_IBloFunMatch_subset(
        Dist_S,
        Dist_X,
        subset_indices,
        output_dir,
        max_rad=-1,
        num_it=1,
        store_0_pm=False,
        points=False,
    )
    ibfm_class.append(subset_matching)

In [None]:
# For each dimension, collect:
#   - entries of the first block function that also occur in the second
#     ("double matched"), and
#   - indices of "X_barcode_{dim}" that occur in neither block function
#     ("unmatched").
unmatched_list, double_matched_list = [], []
for dim in (0, 1):
    block_key = f"block_function_{dim:d}"
    block_0 = ibfm_class[0][block_key]
    block_1 = ibfm_class[1][block_key]
    num_bars = ibfm_class[0][f"X_barcode_{dim:d}"].shape[0]

    common = []
    for target in block_0:
        if target in block_1:
            common.append(target)

    unmatched = []
    for bar_idx in range(num_bars):
        if bar_idx in block_0 or bar_idx in block_1:
            continue
        unmatched.append(bar_idx)

    double_matched_list.append(common)
    unmatched_list.append(unmatched)

In [None]:
# Report the dimension-1 summary, one heading per list.
print("Repeated dim 1", double_matched_list[1], sep="\n")
print("Unmatched dim 1", unmatched_list[1], sep="\n")

In [None]:
# Plot each class-level matching, passing the dimension-1 unmatched and
# double-matched lists as highlights.
# NOTE(review): no `dim` is passed, so plot_matching's default is used -
# presumably dimension 1, given the lists indexed with [1]; confirm.
# The loop variable is deliberately not called `ibfm_out`, so the full-data
# matching stored under that name is not overwritten.
for class_matching in ibfm_class:
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,6))
    ibfm.plot_matching(class_matching, ax, fig, block_function=True, codomain_int=unmatched_list[1], repeated_codomain=double_matched_list[1])

In [None]:
# Report the dimension-0 summary, one heading per list.
print("Repeated dim 0", double_matched_list[0], sep="\n")
print("Unmatched dim 0", unmatched_list[0], sep="\n")

In [None]:
# Plot each class-level matching in dimension 0, passing the dimension-0
# unmatched and double-matched lists as highlights.
# The loop variable is deliberately not called `ibfm_out`, so the full-data
# matching stored under that name is not overwritten.
for class_matching in ibfm_class:
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,8))
    ibfm.plot_matching(class_matching, ax, fig, block_function=True, dim=0, codomain_int=unmatched_list[0], repeated_codomain=double_matched_list[0])

### Let us take a subset in a given class

In [None]:
# Reset `subsets_lists` to one index list per class. Images are loaded class
# by class with NUM_SAMPLES images each, so class i owns the contiguous index
# range [i * NUM_SAMPLES, (i + 1) * NUM_SAMPLES).
subsets_lists = [
    list(range(i * NUM_SAMPLES, (i + 1) * NUM_SAMPLES))
    for i in range(NUM_EXP)
]

In [None]:
# Sanity check: show the number of experiments (one per class).
NUM_EXP

In [None]:
# Work inside the first class: draw a random subset of its samples (without
# replacement, fixed seed) and take the remaining samples as the complement.
# `data_subset_indices` and `data_complement` are positions within
# `data_class`; `data_subset` holds the corresponding global sample ids.
data_class = subsets_lists[0]
class_size = len(data_class)

rng = default_rng(10)
data_subset_indices = rng.choice(range(class_size), SUBSET_SIZE+5, replace=False)

data_subset = []
for local_idx in data_subset_indices:
    data_subset.append(data_class[local_idx])

chosen_positions = set(data_subset_indices.tolist())
data_complement = [pos for pos in range(class_size) if pos not in chosen_positions]

In [None]:
# Match the random subset and its complement against the whole class.
ibfm_subset = []

# Distance matrix restricted to the class. It does not depend on the subset,
# so it is computed once outside the loop.
Dist_C = Dist_X[:,data_class][data_class]

for subset_indices in [data_subset_indices, data_complement]:
    # `subset_indices` are positions within `data_class`; they must be distinct.
    assert len(subset_indices) == len(np.unique(subset_indices)), "subset indices must be unique"
    # Translate class-local positions to global sample ids to slice Dist_X.
    subset_indices_global = [data_class[i] for i in subset_indices]
    Dist_S = Dist_X[:,subset_indices_global][subset_indices_global]
    print(Dist_C.shape)
    print(subset_indices)
    # The class-local positions are passed here because the ambient matrix is Dist_C.
    ibfm_subset.append(ibfm.get_IBloFunMatch_subset(Dist_S, Dist_C, subset_indices, output_dir, max_rad=-1, num_it=1, store_0_pm=False, points=False))

In [None]:
# For each dimension, collect:
#   - entries of the subset's block function that also occur in the
#     complement's block function ("double matched"), and
#   - indices of "X_barcode_{dim}" that occur in neither block function
#     ("unmatched").
unmatched_list, double_matched_list = [], []
for dim in (0, 1):
    block_key = f"block_function_{dim:d}"
    block_0 = ibfm_subset[0][block_key]
    block_1 = ibfm_subset[1][block_key]
    num_bars = ibfm_subset[0][f"X_barcode_{dim:d}"].shape[0]

    common = []
    for target in block_0:
        if target in block_1:
            common.append(target)

    unmatched = []
    for bar_idx in range(num_bars):
        if bar_idx in block_0 or bar_idx in block_1:
            continue
        unmatched.append(bar_idx)

    double_matched_list.append(common)
    unmatched_list.append(unmatched)

In [None]:
# Plot the subset and complement matchings in dimension 0, passing the
# dimension-0 unmatched and double-matched lists as highlights.
# The loop variable is deliberately not called `ibfm_out`, so the full-data
# matching stored under that name is not overwritten.
for subset_matching in ibfm_subset:
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,8))
    ibfm.plot_matching(subset_matching, ax, fig, block_function=True, dim=0, codomain_int=unmatched_list[0], repeated_codomain=double_matched_list[0])