Import packages.

In [None]:
# Use the interactive notebook backend for matplotlib figures.
%matplotlib notebook
# Load IPython's autoreload extension; mode 2 re-imports modules before each
# cell is executed, so edits to the local modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import time
import numpy                   as np
import matplotlib.pyplot       as plt
from sklearn.cluster           import KMeans
from sklearn.metrics           import pairwise_distances
from sklearn.preprocessing     import LabelEncoder
from matplotlib                import cm
from ot                        import sinkhorn as wasserstein
from ot.gromov                 import entropic_gromov_wasserstein as gromov_wasserstein

In [None]:
from MREC         import *
from quantization import *
from matching     import *
from scores       import *

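Generate a synthetic dataset: two 2D point clouds, each made of three Gaussian clusters whose centers are slightly shifted between the two clouds.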

In [None]:
def _sample_cluster(center, n_pts):
    """Draw `n_pts` 2D points around `center` with unit-variance Gaussian noise.

    Coordinates are sampled one column at a time (first coordinate, then
    second), each with an independent call to `np.random.normal`.
    """
    columns = [np.random.normal(loc=c, scale=1., size=n_pts) for c in center]
    return np.column_stack(columns)


n_pts_per_clus1, n_pts_per_clus2 = 100, 100

# Three clusters for the first data set.
X1 = _sample_cluster((0., 0.), n_pts_per_clus1)
Y1 = _sample_cluster((0., 10.), n_pts_per_clus1)
Z1 = _sample_cluster((30., 0.), n_pts_per_clus1)

# Three clusters for the second data set; centers are shifted by a couple of
# units with respect to the first data set.
X2 = _sample_cluster((2., -1.), n_pts_per_clus2)
Y2 = _sample_cluster((1., 12.), n_pts_per_clus2)
Z2 = _sample_cluster((28., -2.), n_pts_per_clus2)

# From here on X1 / X2 denote the full data sets (clusters stacked row-wise).
X1 = np.vstack([X1, Y1, Z1])
X2 = np.vstack([X2, Y2, Z2])

# Ground-truth group labels (1, 2, 3), in the same order as the stacked rows.
lab1 = np.repeat([1, 2, 3], n_pts_per_clus1)
lab2 = np.repeat([1, 2, 3], n_pts_per_clus2)

In [None]:
# Pairwise distances within each data set ...
D1 = pairwise_distances(X1)
D2 = pairwise_distances(X2)
# ... and between the two data sets (rows: X1, columns: X2).
D12 = pairwise_distances(X1, X2)

In [None]:
# Sanity check: shapes of the data sets and of the distance matrices.
print(*(arr.shape for arr in (X1, X2, D1, D2, D12)))

In [None]:
# Scatter plot of both data sets: one colour family per data set (greens for
# data set 1, oranges for data set 2) and one shade per group.
plt.figure()
labname1 = {1: 'Group 1 - Data set 1', 2: 'Group 2 - Data set 1', 3: 'Group 3 - Data set 1'}
labname2 = {1: 'Group 1 - Data set 2', 2: 'Group 2 - Data set 2', 3: 'Group 3 - Data set 2'}
# plt.get_cmap is used rather than matplotlib.cm.get_cmap: the latter was
# deprecated in Matplotlib 3.7 and removed in 3.9.
clrmap_dicts = {1: plt.get_cmap('Greens'), 2: plt.get_cmap('Oranges')}
# Position in [0, 1] along the colormap used for each group.
clrscale = {1: 0.25, 2: 0.5, 3: 0.75}
for i in [1, 2, 3]:
    # Row indices of the points belonging to group i in each data set.
    I1, I2 = np.flatnonzero(lab1 == i), np.flatnonzero(lab2 == i)
    plt.scatter(X1[I1, 0], X1[I1, 1], marker='.', s=30, alpha=.9,
                label=labname1[i], color=clrmap_dicts[1](clrscale[i]))
    plt.scatter(X2[I2, 0], X2[I2, 1], marker='+', s=30, alpha=.9,
                label=labname2[i], color=clrmap_dicts[2](clrscale[i]))
plt.legend()
plt.grid()
plt.show()

Compute matchings between the two datasets with each method, then evaluate them.

In [None]:
# Full-space entropic optimal transport (Sinkhorn) between the two data sets.
# Both data sets get uniform weights; the regularisation strength is the
# median of the cross-distance matrix.
start = time.time()
weights1 = np.full(len(X1), 1. / len(X1))
weights2 = np.full(len(X2), 1. / len(X2))
gamma = wasserstein(weights1, weights2, D12, np.quantile(D12, .5))
# Hard assignments: each point is mapped to the point of the other data set
# that receives the largest share of its mass in the coupling.
mappingFSW12 = gamma.argmax(axis=1)
mappingFSW21 = gamma.argmax(axis=0)
end = time.time()
timeFSW = end - start

In [None]:
# Full-space entropic Gromov-Wasserstein between the two intra-data-set
# distance matrices, with uniform weights on both sides.
start = time.time()
weights1 = np.full(len(D1), 1. / len(D1))
weights2 = np.full(len(D2), 1. / len(D2))
gamma = gromov_wasserstein(D1, D2, weights1, weights2,
                           loss_fun='square_loss', epsilon=1e1)
# Hard assignments read off the coupling by row-wise / column-wise argmax.
mappingFSGW12 = gamma.argmax(axis=1)
mappingFSGW21 = gamma.argmax(axis=0)
end = time.time()
timeFSGW = end - start

In [None]:
# MREC configured with the SinkhornWasserstein matcher and K-means
# quantization, run in both directions (data set 1 -> 2 and 2 -> 1).
matching = SinkhornWasserstein
matching_params = SinkhornWassersteinMedianParameters
quant = KMeansQuantization
quant_params = QuantizationSizeMinParameters(20)

# Arguments shared by the two directions; only X1 / X2 are swapped.
mrec_sw_kwargs = dict(
    X12=None, D1=None, D2=None, D1quant=None, D2quant=None,
    matching=matching, matching_params=matching_params,
    quantization1=quant, quantization_params1=quant_params,
    quantization2=quant, quantization_params2=quant_params,
    threshold=10, last_matching='match', impose_equal=True,
)

start = time.time()
mappingSW12 = MREC(X1=X1, X2=X2, **mrec_sw_kwargs)
mappingSW21 = MREC(X1=X2, X2=X1, **mrec_sw_kwargs)
end = time.time()
timeSW = end - start

In [None]:
# MREC configured with the SinkhornGromovWasserstein matcher and K-means
# quantization, run in both directions (data set 1 -> 2 and 2 -> 1).
matching = SinkhornGromovWasserstein
matching_params = {'epsilon': 5*1e0, 'tol': 1e-9, 'max_iter': 1000, 'loss_fun': 'square_loss'}
quant = KMeansQuantization
quant_params = QuantizationSizeMinParameters(20)

# Arguments shared by the two directions; the data sets and their distance
# matrices are the only things swapped between the two calls.
mrec_sgw_kwargs = dict(
    X12=None,
    matching=matching, matching_params=matching_params,
    quantization1=quant, quantization_params1=quant_params,
    quantization2=quant, quantization_params2=quant_params,
    threshold=10, last_matching='match', impose_equal=True, mode='5',
)

start = time.time()
mappingSGW12 = MREC(X1=X1, X2=X2, D1=D1, D2=D2, D1quant=D1, D2quant=D2,
                    **mrec_sgw_kwargs)
mappingSGW21 = MREC(X1=X2, X2=X1, D1=D2, D2=D1, D1quant=D2, D2quant=D1,
                    **mrec_sgw_kwargs)
end = time.time()
timeSGW = end - start

In [None]:
# MREC configured with the NonConvexGromovWasserstein matcher and K-means
# quantization, run in both directions (data set 1 -> 2 and 2 -> 1).
matching = NonConvexGromovWasserstein
matching_params = {'num_iter':50, 'sigma_m_0':5., 'mu':10., 'method':'L-BFGS-B', 'map_init':None, 'verbose':False}
quant = KMeansQuantization
quant_params = QuantizationSizeMinParameters(20)

# Arguments shared by the two directions; the data sets and their distance
# matrices are the only things swapped between the two calls.
mrec_ncgw_kwargs = dict(
    X12=None,
    matching=matching, matching_params=matching_params,
    quantization1=quant, quantization_params1=quant_params,
    quantization2=quant, quantization_params2=quant_params,
    threshold=10, last_matching='constant', impose_equal=True, mode='5',
)

start = time.time()
mappingNCGW12 = MREC(X1=X1, X2=X2, D1=D1, D2=D2, D1quant=D1, D2quant=D2,
                     **mrec_ncgw_kwargs)
mappingNCGW21 = MREC(X1=X2, X2=X1, D1=D2, D2=D1, D1quant=D2, D2quant=D1,
                     **mrec_ncgw_kwargs)
end = time.time()
timeNCGW = end - start

In [None]:
# MREC configured with the CPLEXConvexGromovWasserstein matcher, which is
# handed a MATLAB engine through its parameters. Only distance matrices are
# passed here (X1 / X2 are None), and K-means uses a fixed number of clusters.
# NOTE(review): `matlab` is not imported in the visible cells; presumably it
# comes in through one of the star imports -- confirm before running.
matching        = CPLEXConvexGromovWasserstein
eng             = matlab.engine.start_matlab()
try:
    matching_params = {'eng': eng, 'maxtime': 1000}
    quant           = KMeansQuantization
    quant_params    = {'n_clusters': 10}

    start = time.time()
    mappingCPLEXGW12 = MREC(X1=None, X2=None, X12=None, D1=D1, D2=D2, D1quant=D1, D2quant=D2,
                            matching=matching, matching_params=matching_params,
                            quantization1=quant, quantization_params1=quant_params,
                            quantization2=quant, quantization_params2=quant_params,
                            threshold=10, last_matching='constant', impose_equal=True)
    mappingCPLEXGW21 = MREC(X1=None, X2=None, X12=None, D1=D2, D2=D1, D1quant=D2, D2quant=D1,
                            matching=matching, matching_params=matching_params,
                            quantization1=quant, quantization_params1=quant_params,
                            quantization2=quant, quantization_params2=quant_params,
                            threshold=10, last_matching='constant', impose_equal=True)
    end = time.time()
finally:
    # Always shut the MATLAB engine down, even if a matching call fails;
    # otherwise the engine process is left running.
    eng.quit()
timeCPLEXGW = end-start

In [None]:
# Evaluate every pair of mappings (data set 1 -> 2 and 2 -> 1): print timing,
# label agreement and distortion, then plot the integrated data.

# Quantities that do not depend on the mapping are computed once.
integrated_labs = np.concatenate([lab1, lab2])
colors = LabelEncoder().fit_transform(integrated_labs)
list_kth = [5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,150,200,250,300]

for (mapping12, mapping21, t) in [
                                  (mappingFSW12,     mappingFSW21,     timeFSW),
                                  (mappingFSGW12,    mappingFSGW21,    timeFSGW),
                                  (mappingSW12,      mappingSW21,      timeSW),
                                  (mappingSGW12,     mappingSGW21,     timeSGW),
                                  (mappingNCGW12,    mappingNCGW21,    timeNCGW),
                                  # CPLEX results are left out; uncomment to include them.
                                  #(mappingCPLEXGW12, mappingCPLEXGW21, timeCPLEXGW),
                                 ]:

    # Integrated data: every point is averaged with the point it is mapped to
    # in the other data set.
    XX1, XX2 = X2[mapping12,:], X1[mapping21,:]
    integrated = np.vstack([0.5 * (X1 + XX1), 0.5 * (X2 + XX2)])

    # Distortion of each mapping, as computed by the project's
    # `distortion_score` helper on Euclidean distances.
    distortion12 = distortion_score(X1=X1, X2=X2, mapping=mapping12, computation_mode='3', metric='euclidean')
    distortion21 = distortion_score(X1=X2, X2=X1, mapping=mapping21, computation_mode='3', metric='euclidean')

    # Fraction of points mapped to a point carrying the same group label.
    # Assumes each mapping is a 1-D integer index sequence (it is already used
    # as a fancy index above).
    quality12 = np.count_nonzero(lab1 == lab2[mapping12]) / len(X1)
    quality21 = np.count_nonzero(lab2 == lab1[mapping21]) / len(X2)
    print(t, quality12, quality21, distortion12, distortion21)

    # Integrated points, one colour per data set of origin.
    plt.figure(figsize=(7,7))
    plt.scatter(integrated[:len(X1),0], integrated[:len(X1),1])
    plt.scatter(integrated[len(X1):,0], integrated[len(X1):,1])
    plt.show()

    # Integrated points, coloured by ground-truth group label.
    plt.figure(figsize=(7,7))
    plt.scatter(integrated[:,0], integrated[:,1], c=colors)
    plt.show()

    # Box plot of the output of `compute_mixings`, one box per entry of
    # list_kth (presumably neighbourhood sizes -- verify against the helper).
    # NOTE(review): Matplotlib 3.9 deprecates `labels` in favour of
    # `tick_labels`; kept as-is for compatibility with older versions.
    mixings = compute_mixings(integrated, len(X1), list_kth)
    plt.figure(figsize=(7,7))
    plt.boxplot(mixings, labels=[str(k) for k in list_kth], showfliers=True)
    plt.show()
