# Redundancy bottleneck examples

This Jupyter notebook generates the figures that appear in the manuscript:

  * A Kolchinsky, "Partial information decomposition as information bottleneck", 2024.

Specifically, we study the following three examples: the UNIQUE gate (Example 1), four binary symmetric channels (Example 3), and a 3-spin target (Example 4).

In [None]:
# First, define some useful functions
import numpy as np
import redundancy_bottleneck as rb

def get_rb_curve(pY, src_cond_dists, beta_vals, verbose=False, num_retries=10):
    """Solve the redundancy-bottleneck (RB) problem for a range of beta values.

    Uses the solver in ``redundancy_bottleneck`` (imported as ``rb``).

    Parameters
    ----------
    pY : np.ndarray
        1-D distribution of the target Y.
    src_cond_dists : sequence of np.ndarray
        One conditional distribution per source. ``p * pY[None, :]`` is passed
        to ``rb.mi`` as a joint distribution, so each ``p`` is presumably
        p(x_i | y) with shape (|X_i|, |Y|) -- TODO confirm against
        redundancy_bottleneck.py.
    beta_vals : iterable of float
        Trade-off parameters at which to solve the RB problem.
    verbose : bool, optional
        If True, print beta and the prediction/compression values per solve.
    num_retries : int, optional
        Forwarded to ``rb.get_rb_value`` (default 10).

    Returns
    -------
    list
        One result object from ``rb.get_rb_value`` per beta, ordered by
        *descending* beta.

    Raises
    ------
    ValueError
        If ``beta_vals`` is empty.
    """
    betas = sorted(beta_vals, reverse=True)
    # Fail fast, before printing anything or running the solver.
    if not betas:
        raise ValueError('No data created: beta_vals is empty')

    # For convenience, print the mutual information between each source and the target.
    for src, p in enumerate(src_cond_dists):
        print('MI(X_%d;Y)=%g' % (src + 1, rb.mi(p * pY[None, :])))

    data = []
    baseline_rQgZSs = None

    # Sweep beta in descending order: seeding the solver at a lower beta with
    # the solution found at a higher beta seems to work better than the
    # reverse direction.
    for beta in betas:
        o = rb.get_rb_value(beta, pY, src_cond_dists,
                            num_retries=num_retries,
                            baseline_rQgZSs=baseline_rQgZSs)
        data.append(o)
        # Reuse this solution as a starting candidate for the next (smaller) beta.
        baseline_rQgZSs = [o.rQgZS]

        if verbose:
            print(beta, o.prediction, o.compression, o.src_prediction, o.src_compression)

    return data
    


In [None]:
# Plotting functions

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
import os


def do_plot(data, figname=None):
    """Plot an RB curve computed by ``get_rb_curve`` as a 1x4 figure.

    Panels: (a) prediction vs. beta, (b) compression vs. beta,
    (c) prediction vs. compression (the RB curve), (d) per-source RB curves.

    Parameters
    ----------
    data : list
        Result objects as returned by ``get_rb_curve``. Each must expose
        ``beta``, ``prediction``, ``compression``, ``src_prediction`` and
        ``src_compression``; ``data[0]`` must also expose ``pY_ZS`` and
        ``pS_ZY``, which are passed to ``rb.mi`` for the dotted upper bounds.
    figname : str, optional
        If given, the figure is also saved to ``rbcurve-<figname>.pdf``.

    Note: this function modifies the global matplotlib rcParams.
    """
    def subplotlbl(s):
        # Bold panel label "(s)" placed above-left of the current axes.
        plt.text(-0.15, 1.10, '(' + s + ')', transform=plt.gca().transAxes,
                 fontsize=15, fontweight='bold', fontname='sans-serif', va='top')

    def g(s, data):
        # Collect attribute `s` of every result object into one array.
        return np.array([getattr(o, s) for o in data])

    def myscatter(xx, yy, color_vals):
        markersize = 16
        plt.scatter(xx, yy, s=markersize, c=color_vals, cmap='viridis')

    def plotxy(xvals, yvals, color_vals, upperbound=None):
        # Black line plus beta-coloured markers; optional dotted horizontal
        # line at `upperbound = (value, label)`.
        plt.plot(xvals, yvals, c='k', zorder=-10)
        myscatter(xvals, yvals, color_vals)
        plt.xlim([xvals.min(), xvals.max()])
        if upperbound is not None:
            ub_val, ub_lbl = upperbound
            plt.plot(xvals, xvals * 0 + ub_val, ls=':', c='k')
            plt.ylim([0, ub_val * 1.1])
            plt.text(0.77, 0.9, ub_lbl, transform=plt.gca().transAxes, ha='right', va='top', fontsize=12)

    # Apply the seaborn theme BEFORE the font settings: sns.set() resets
    # 'font.family' to sans-serif, which would otherwise override the
    # STIX/LaTeX configuration below.
    sns.set(font_scale=1.2, style='white')
    matplotlib.rcParams['mathtext.fontset']   = 'stix'
    matplotlib.rcParams['font.family']        = 'STIXGeneral'
    matplotlib.rcParams['text.usetex']        = True
    matplotlib.rcParams['text.latex.preamble'] = """
    \\usepackage{newtxtext}
    \\usepackage{bm}
    """

    fig = plt.figure(figsize=(16, 3.25))
    gs = matplotlib.gridspec.GridSpec(1, 4, wspace=0.35, width_ratios=[1, 1, 1, 1], figure=fig)

    plot_beta_vals = g('beta', data)
    comp_vals      = g('compression', data)
    pred_vals      = g('prediction', data)

    miZS_Y = rb.mi(data[0].pY_ZS)
    miS_ZY = rb.mi(data[0].pS_ZY)
    print('I(Z,S;Y)=%g, I(Z,Y;S)=%g' % (miZS_Y, miS_ZY))

    fig.add_subplot(gs[0])
    plotxy(plot_beta_vals, pred_vals, color_vals=plot_beta_vals, upperbound=(miZS_Y, r'$I(S,Z;Y)$'))
    plt.ylabel(r'Prediction $I(Q;Y\vert S)$')
    plt.xlabel(r'$\beta$')
    plt.title(r'Prediction')
    subplotlbl('a')

    fig.add_subplot(gs[1])
    plotxy(plot_beta_vals, comp_vals, color_vals=plot_beta_vals, upperbound=(miS_ZY, r'$I(Z;S\vert Y)$'))
    plt.ylabel(r'Compression $I(Q;S\vert Y)$')
    plt.xlabel(r'$\beta$')
    plt.title(r'Compression')
    subplotlbl('b')

    fig.add_subplot(gs[2])
    plotxy(comp_vals, pred_vals, color_vals=plot_beta_vals, upperbound=(miZS_Y, r'$I(S,Z;Y)$'))
    plt.fill_between(comp_vals, pred_vals, color='grey', alpha=0.1)
    plt.xlabel(r'Compression $I(Q;S\vert Y)$')
    plt.ylabel(r'Prediction $I(Q;Y\vert S)$')
    plt.title(r'RB curve')
    subplotlbl('c')

    fig.add_subplot(gs[3])
    plt.xlabel(r'Compression $I(Q;S=s\vert Y)$')
    plt.ylabel(r'Prediction $I(Q;Y\vert S=s)$')
    linestyles = ['-', '--', '-.', ':']
    src_comp_vals = g('src_compression', data)
    src_pred_vals = g('src_prediction', data)

    # One column per source; take the count from the data itself instead of
    # relying on a notebook-global `src_cond_dists`.
    num_srcs = src_comp_vals.shape[1]
    for src in range(num_srcs):
        plt.plot(src_comp_vals[:, src], src_pred_vals[:, src], lw=2.5,
                 ls=linestyles[src % len(linestyles)],  # cycle if > 4 sources
                 label='$X_%d$' % (src + 1))
    rng = np.array([-.03, 1.03])
    plt.xlim(rng * src_comp_vals.max())
    plt.ylim(rng * src_pred_vals.max())

    plt.legend(frameon=False, handlelength=1.6, fontsize=13)
    plt.title(r'RB curve, by source')
    subplotlbl('d')

    plt.tight_layout()

    if figname is not None:
        fname = 'rbcurve-%s.pdf' % figname
        plt.savefig(fname, bbox_inches='tight')


# Example 1
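The UNIQUE gate: the first source $X_1$ is a copy of the uniform binary target $Y$, while the second source $X_2$ is independent of $Y$.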

In [None]:
# UNIQUE gate with a uniform binary target Y.
pY = np.array([.5, .5])
src_cond_dists = (
    np.eye(2),              # X_1: identity channel, i.e. a copy of Y
    np.full((2, 2), 0.5),   # X_2: same distribution for every y, independent of Y
)
beta_vals = np.linspace(1, 2.5, 50)

data = get_rb_curve(pY, src_cond_dists, beta_vals)
do_plot(data, 'unique')


# Example 3
Uniform binary target observed through four binary symmetric channels, with flip probabilities 0.1, 0.1, 0.2, and 0.5.

In [None]:
# Four binary symmetric channels from a uniform binary target. A flip
# probability of 0.5 makes the last source carry no information about Y.
beta_vals = np.linspace(1, 5, 50)
pY = np.array([.5, .5])
flip_probs = [0.1, 0.1, 0.2, 0.5]
src_cond_dists = [np.array([[1 - p, p], [p, 1 - p]]) for p in flip_probs]

data = get_rb_curve(pY, tuple(src_cond_dists), beta_vals)
do_plot(data, '4sources')

# Example 4

The target consists of 3 uniformly distributed binary spins; each of the three sources observes 2 of those spins.

In [None]:
# Target: 3 uniformly distributed binary spins; each source is a deterministic
# copy of 2 of those spins.
beta_vals = np.linspace(1e-2, 5, 50)
bitsY = 3
bitsX = 2
pY = np.ones(2**bitsY) / (2**bitsY)

# Positions (in Y's bit string, most-significant bit first) of the spins
# copied by each source; each tuple has length bitsX.
# NOTE(review): X_1 and X_2 both copy spins (0, 1), so they are identical
# channels -- confirm against the manuscript whether X_2 was meant to observe
# a different pair, e.g. (1, 2).
src_spin_ixs = [(0, 1), (0, 1), (0, 2)]

src_cond_dists = []
for xixs in src_spin_ixs:
    # Deterministic channel: column y has a single 1 at the x formed by the
    # selected bits of y.
    c = np.zeros((2**bitsX, 2**bitsY))
    for yval in range(2**bitsY):
        yvalb = f"{yval:0{bitsY}b}"
        xvalb = "".join(yvalb[j] for j in xixs)
        c[int(xvalb, 2), yval] = 1.0
    src_cond_dists.append(c)

data = get_rb_curve(pY, tuple(src_cond_dists), beta_vals)
do_plot(data, '3spins')