In [None]:
import os
import time
import glob
import dit
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.optimize import minimize
os.chdir("../")
from helpers.group_helpers import loadsyms, classifylowerorders, classifyoversized
from helpers.compare_helpers import addbestsym
from jointpdfpython3.measures import synergistic_entropy_upper_bound, append_random_srv
from jointpdfpython3.params_matrix import params2matrix_incremental,matrix2params_incremental
from jointpdfpython3.JointProbabilityMatrix import JointProbabilityMatrix
from syndisc.syndisc import self_disclosure_channel
os.chdir("./plot_notebooks")
import seaborn as sns;sns.set()

## Only load non-oversized constructed SRVs $S$, since $|S| > |X_i|$ is not supported in jointpdf

In [None]:
# Experiment configuration.
states = 3                    # passed to loadsyms/classify* and to JointProbabilityMatrix later on
lenX = 2                      # number of input variables X_i
subjects=np.arange(lenX)      # indices 0..lenX-1, used as the "subject" variables in MI calls
symss = []
data = {'totmi':[],'indivmi':[],'symsort':[],'states':[]}

# Load the constructed SRVs and let the helpers label them by category.
# NOTE(review): the exact category names are produced by the helpers; only the
# 'lower order' and 'oversized' substrings are relied upon in this cell.
concsyms, syms = loadsyms(states)
syms = classifyoversized(syms,states)
if 'lower order' in syms.keys():
    syms = classifylowerorders(states,syms)

# Keep every category whose key does not contain 'oversized'.
#   newsyms  : category -> its SRVs (replaces `syms` at the end of the cell)
#   listsyms : all kept SRVs, flattened in category order
#   symids   : category -> positions of its SRVs inside `listsyms`
listsyms = []
symids = {}
previd = 0
newsyms = {}
for k in syms.keys():
    if 'oversized' in k:
        continue
    ksyms = syms[k]
    newsyms[k] = ksyms
    listsyms.extend(ksyms)
    # Assigned once per category (it does not depend on the individual SRV), so an
    # empty category also gets an (empty) index range instead of no entry at all.
    symids[k] = np.arange(previd,previd+len(ksyms))
    previd = previd+len(ksyms)
syms = newsyms

## Start with different types of initial guesses given X

In [None]:
def costfunc(srvparams,jXS,lenJXS,parX,subjects,upper):
    """Cost of a candidate SRV parameter vector (lower is better).

    The parameters ``parX + list(srvparams)`` are written into ``jXS`` in place
    through ``params2matrix_incremental``; the mutual information between the
    last variable of ``jXS`` (index ``lenJXS-1``) and the ``subjects`` is then
    measured jointly and per subject.

    Args:
        srvparams: parameter values being optimised (appended after ``parX``).
        jXS: joint distribution object that is overwritten on every call.
        lenJXS: number of variables in ``jXS``; the last one is the SRV.
        parX: list of fixed parameters that precede ``srvparams``.
        subjects: iterable of variable indices to measure against the SRV.
        upper: reference value the net information (joint minus summed
            individual MI) is compared to.

    Returns:
        ``abs((upper - (joint - individual)) / upper)`` plus ``individual/joint``
        (or plus ``individual`` alone when the joint MI is exactly 0).

    NOTE(review): ``upper`` is used as a divisor without a zero check — confirm
    that callers never pass ``upper == 0``.
    """
    srv = lenJXS-1
    params2matrix_incremental(jXS,parX+list(srvparams))

    joint_mi = jXS.mutual_information(subjects,[srv])
    individual_mi = 0
    for subject in subjects:
        individual_mi += jXS.mutual_information([subject],[srv])

    # Relative distance between the net information and the reference value.
    deviation = abs((upper-(joint_mi-individual_mi))/upper)

    # Penalise individual information, relative to the joint MI when that is non-zero.
    penalty = (individual_mi/joint_mi) if joint_mi != 0 else individual_mi
    return deviation+penalty

# Options forwarded to scipy.optimize.minimize by symsyninfo.
minimize_options = {'ftol': 1e-6}
def symsyninfo(states,lenX,parX,upper,jX,syms,initialtype='PSRV',costf=costfunc,verbose=None):
    """Optimise an SRV appended to ``jX``, starting from a chosen initial guess.

    Args:
        states: number of states per variable; used to count the free SRV parameters.
        lenX: number of input variables; subjects are ``0..lenX-1``.
        parX: list of parameters describing ``jX`` (kept fixed during optimisation).
        upper: reference value handed to ``addbestsym`` and to ``costf``.
        jX: joint distribution of the inputs.
        syms: candidate constructed SRVs handed to ``addbestsym``.
        initialtype: ``'random'`` keeps the distribution returned by
            ``append_random_srv`` as the starting point; any other value replaces
            it with the one returned by ``addbestsym``.
        costf: objective with the signature of ``costfunc``.
        verbose: currently unused (a parameter-trace callback for the optimiser
            was disabled).

    Returns:
        ``(bestsymid, jXS)`` — the id reported by ``addbestsym`` (``-1`` when
        ``initialtype`` is ``'random'``) and the joint distribution with the
        optimised SRV parameters written into it.
    """
    # presumably appends one randomly initialised SRV to jX — confirm in jointpdfpython3.measures
    jXS = append_random_srv(jX,parX,1)
    lenJXS = len(jXS)
    subjects = list(range(lenX))

    # Starting point: the constructed SRV picked by addbestsym, unless a random start was requested.
    bestsymid = -1
    if not initialtype=='random':
        pXSym,bestsymid = addbestsym(lenX,jX,upper,syms)
        print("BESTSYMID type",initialtype," = ",bestsymid)
        jXS.joint_probabilities.joint_probabilities = pXSym

    # Only the trailing parameters (those added by the SRV) are optimised.
    freeparams = (states**(lenJXS))-(states**(len(jX)))
    allparams = matrix2params_incremental(jXS)
    symparams = allparams[-freeparams:]

    # No explicit method is given; SciPy chooses the solver from the presence of bounds.
    unit_bounds = [(0.0, 1.0)]*freeparams
    optres_ix = minimize(costf,
                         symparams,
                         args=(jXS,lenJXS,parX,subjects,upper),
                         bounds=unit_bounds,
                         options=minimize_options)

    # Write the optimum back, since jXS otherwise holds the last point the optimiser evaluated.
    params2matrix_incremental(jXS,parX+list(optres_ix.x))
    return bestsymid,jXS

## Optimize different initial guesses (random, PSRV, best of all constructed) for random X, with syndisc as a baseline

In [None]:
# One row per (system, method): the columns are filled in lockstep below.
data = {'systemID':[],'parX':[],'upper':[],'totmi':[],'indivmi':[],'runtime':[],'exp_sort':[]}
types = ['random','PSRV','bestofall','syndisc']
# types = ['PSRV','random'] # only select some types to compare
samples = 5
cursyms = []
for i in range(samples):
    print(i,time.strftime("%H:%M:%S", time.localtime()))
    # A fresh input distribution per sample, shared by every method in `types`.
    jX = JointProbabilityMatrix(lenX,states)
    upper = synergistic_entropy_upper_bound(jX)
    pX = jX.joint_probabilities.joint_probabilities
    parX = matrix2params_incremental(jX)
    for t in types:
        if t == 'syndisc':
            ditjX = dit.Distribution.from_ndarray(pX)
            # perf_counter is monotonic, so the interval is unaffected by system clock changes.
            before = time.perf_counter()
            syn, probs = self_disclosure_channel(ditjX)
            runtime = time.perf_counter()-before
            # The individual MI is not measured for syndisc; it is recorded as 0.
            indivmi = 0
            totmi = syn
        else:
            # Candidate constructed SRVs offered as the initial guess for this type.
            if t == 'PSRV':
                cursyms = syms['PSRVs']
            elif t == 'bestofall':
                cursyms = listsyms
            else:
                cursyms = []
            before = time.perf_counter()
            best, jXS = symsyninfo(states,lenX,parX,upper,jX,\
                   cursyms,initialtype=t,costf=costfunc)
            runtime = time.perf_counter()-before
            # The SRV is the variable at index lenX (right after the inputs).
            indivmi = sum([jXS.mutual_information([s],[lenX]) for s in subjects])
            totmi = jXS.mutual_information(subjects,[lenX])

        # Shared bookkeeping: written once so both branches always fill every column.
        data['runtime'].append(runtime)
        data['indivmi'].append(indivmi)
        data['totmi'].append(totmi)
        data['exp_sort'].append(t)
        data['systemID'].append(i)
        data['upper'].append(upper)
        data['parX'].append(parX)

initialdata = pd.DataFrame(data=data)
# note data is saved in 'test' folder; create it first so to_pickle cannot fail on a fresh checkout
outdir = "../../results/test"
os.makedirs(outdir, exist_ok=True)
initialdata.to_pickle(outdir+"/finalinitialcomparison"+str(states)+".pkl")

## Load & plot the initial-guess comparison results

In [None]:
# Load all initial-guess comparison data from the 'rq31' folder and concatenate it.
folder = '../../results/rq31/'
# Glob on the relative folder directly: no os.chdir round trip, so the working
# directory stays untouched even if loading fails half-way.
files = sorted(glob.glob(folder+"*.pkl"))
ds = []
for f in files:
    # Match on the file name only, not on the directory part of the path.
    stem = os.path.splitext(os.path.basename(f))[0]
    if 'initial' not in stem:
        continue
    # The number of states is the run of digits that ends the file name
    # (e.g. '...comparison3.pkl' -> 3); this also handles multi-digit values.
    nstates = stem[len(stem.rstrip('0123456789')):]
    if not nstates:
        raise ValueError("cannot read the number of states from file name: "+f)
    print(f,nstates)
    # NOTE: read_pickle executes pickled code; only load result files produced by this project.
    cur = pd.read_pickle(f)
    cur['states'] = int(nstates)
    # Keep only rows with systemID below 150.
    cur=cur[(cur['systemID']<150)]
    ds.append(cur)

if not ds:
    raise FileNotFoundError("no '*initial*.pkl' result files found in "+folder)

d = pd.concat(ds)
# d=d[(d['exp_sort']!='syndisc')]
# Derived / renamed columns used as plot axes.
d['norm indivmi']=d['indivmi']/d['totmi']
xl = 'I(X;S)'
d[xl]=d['totmi']
d['runtime (seconds)'] = d['runtime']
d.keys()

In [None]:
sns.set_context("paper", font_scale = 1.6)

# Both figures show the same frame `d`, so they share one title that lists the
# state counts actually present in the data (e.g. "states = [2,3,4]").
states_title = "states = ["+",".join(str(s) for s in sorted(d['states'].unique()))+"]"
legend_properties = {'size':14}

# One joint plot per x-axis quantity; y is always the normalised individual MI.
for xcol in ['runtime (seconds)','I(X;S)']:
    g = sns.jointplot(data=d,x=xcol,y='norm indivmi',hue='exp_sort',
                      s=50,palette='tab10')
    g.fig.suptitle(states_title,y=1.0,fontsize=13)
    legendMain=g.ax_joint.legend(prop=legend_properties,loc='upper right')
    # After the loop `g`, `legendMain` and `fig` refer to the last (I(X;S)) figure.
    fig = g.fig.get_figure()