In [1]:
import copy
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import autograd.numpy as np
from sklearn import datasets, preprocessing
import pandas as pd
from pymanopt.solvers import TrustRegions
from manopt_dr.core import gen_ldr
from manopt_dr.predefined_func_generator import *
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os

In [2]:
# Load the Wine dataset and standardise its feature matrix.

dataset = datasets.load_wine()

y = dataset.target
# Shape is read from the raw matrix; scaling below does not change it.
n_samples, n_features = dataset.data.shape
n_components = 2
X = preprocessing.scale(dataset.data)


In [3]:
# generalized cPCA

# Build the generalized-cPCA factory from a cost generator and a projection
# generator (both presumably provided by the wildcard import from
# manopt_dr.predefined_func_generator -- TODO confirm). The result is called
# as GCPCA(n_components=...) and fitted inside Trials().
GCPCA = gen_ldr(gen_cost_gcpca, gen_default_proj)
# Labels that get split into target/background groups
# (presumably the class ids found in dataset.target -- TODO confirm).
label_set = [0,1,2]


def return_partition(my_list):
    """Enumerate every non-empty proper subset of ``my_list``.

    Subsets are produced as tuples, ordered by increasing size (from 1 up to
    ``len(my_list) - 1``) and, within one size, in the order
    ``itertools.combinations`` yields them. The full set itself is never
    included, so inputs with fewer than two elements give an empty list.
    """
    subset_sizes = range(1, len(my_list))
    return [
        subset
        for size in subset_sizes
        for subset in itertools.combinations(my_list, size)
    ]

# Every non-empty proper subset of label_set, as tuples ordered by size.
partitions = return_partition(label_set)

def print_groups(groups):
    """Format the members of ``groups`` as a comma-separated string.

    Parameters
    ----------
    groups : iterable
        Items to format; each one is converted with ``str``. Any iterable is
        accepted (lists, tuples, generators, ...).

    Returns
    -------
    str
        The items joined by ``","`` with no surrounding spaces, e.g.
        ``(0, 1)`` -> ``"0,1"``. An empty input gives ``""``.
    """
    # str.join replaces the manual index loop and repeated concatenation.
    return ",".join(str(member) for member in groups)

def print_groups_index(index):
    """Format the group stored at ``partitions[index]`` as a comma-separated string.

    Parameters
    ----------
    index : int
        Position in the module-level ``partitions`` list.

    Returns
    -------
    str
        Same formatting as ``print_groups`` applied to ``partitions[index]``.

    Raises
    ------
    IndexError
        If ``index`` is outside the range of ``partitions``.
    """
    # Delegate to print_groups so the formatting rule lives in one place.
    return print_groups(partitions[index])

# Show each candidate group as a comma-separated label list, one per line.
for p in partitions:
    print(print_groups(p))
    
def Trials(g1,g2):
    """Fit generalized cPCA for one target/background split and save its scatter plot.

    Two copies of the module-level label vector ``y`` are made. In the first,
    every label listed in ``g1`` is replaced by -1; in the second, every label
    listed in ``g2`` is replaced by -1. Both are passed to ``GCPCA.fit``
    together with ``X`` (presumably -1 marks membership of the target and
    background group respectively -- TODO confirm against manopt_dr).

    The embedding of ``X`` is drawn as a scatter plot (target samples as red
    squares, background samples as black circles) and written to
    ``"<g1> - <g2>.png"`` in the current working directory, where each group
    is formatted with ``print_groups``.

    Parameters
    ----------
    g1 : sequence of labels
        Labels forming the target group.
    g2 : sequence of labels
        Labels forming the background group.

    Returns
    -------
    float
        ``1 / cost`` where ``cost`` is the value of ``get_final_cost()`` on
        the fitted model.
    """
    # Work on copies so the shared label vector `y` is never modified.
    y_tg = np.asarray(copy.deepcopy(y))
    for label in g1:
        y_tg[y_tg == label] = -1

    y_bg = np.asarray(copy.deepcopy(y))
    for label in g2:
        y_bg[y_bg == label] = -1

    gcpca = GCPCA(n_components=n_components).fit(X, y_tg, y_bg)
    Z = gcpca.transform(X)
    cost = gcpca.get_final_cost()
    contrastiveness = 1 / cost

    def _coords(group):
        # First two embedding coordinates of the samples whose label is in `group`.
        rows = [Z[i] for i in range(len(y)) if y[i] in group]
        return [row[0] for row in rows], [row[1] for row in rows]

    tg_x, tg_y = _coords(g1)
    bg_x, bg_y = _coords(g2)

    # Explicit figure/axes so the figure can be closed afterwards. The previous
    # plt.figure()/plt.clf() pair left one empty figure open per call, which
    # leaks memory and makes the notebook display blank figures.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.scatter(tg_x, tg_y, marker='s',label = "target",facecolors='none', edgecolors='r')
        ax.scatter(bg_x, bg_y, marker='o',label = "background",facecolors='none', edgecolors='black')

        ax.legend(loc='best', shadow=False, scatterpoints=1)
        ax.set_title(
            f'Generalized cPCA of Wine dataset, target = ({print_groups(g1)}), background = ({print_groups(g2)})  (cost: {cost:.3f}, contrastiveness: {contrastiveness:.3f})' ,
            fontsize=8)
        fig.savefig(f'{print_groups(g1)} - {print_groups(g2)}.png')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    return contrastiveness


0
1
2
0,1
0,2
1,2


In [4]:
# Run one generalized-cPCA trial per candidate target group; the labels that
# are not in the target group form the background group.
map_dataset = []
for target_group in partitions:
    background_group = [label for label in label_set if label not in target_group]
    contrast = Trials(target_group, background_group)
    # Row layout: [target labels, background labels, contrastiveness]
    map_dataset.append([list(target_group), background_group, contrast])

#map_dataset = pd.DataFrame(map_dataset)


<Figure size 720x720 with 0 Axes>

<Figure size 720x720 with 0 Axes>

<Figure size 720x720 with 0 Axes>

<Figure size 720x720 with 0 Axes>

<Figure size 720x720 with 0 Axes>

<Figure size 720x720 with 0 Axes>

In [6]:
# Summarise the trials as a single-row table: one column per
# target/background split, holding its contrastiveness rounded to 6 decimals.
# The first column carries the row label.
columns = [" "] + [
    ("T : {" + print_groups(target) + "}", "B : {" + print_groups(background) + "}")
    for target, background, _ in map_dataset
]
values = ["Contrast"] + [round(contrast, 6) for _, _, contrast in map_dataset]

table = go.Table(
    header={"values": columns, "align": "left", "font": {"color": 'white', "size": 12}},
    cells={"values": values},
)
fig = go.Figure(data=[table])
# Static image export; written to the current working directory.
fig.write_image("Table.png")
#map_dataset.rename(columns=lambda s: print_groups_index(s), index=lambda s: print_groups_index(s), inplace = True )
#sns.heatmap(map_dataset,cmap="YlGnBu", linewidths=.5)
#plt.xlabel("Target")
#plt.ylabel("Background")
#plt.gca().invert_yaxis()
#plt.show()
#plt.savefig("heatmap_wine.png")