# Side Information Experiment Notebook

This notebook consolidates the original Python modules (`data_generation.py`, `models.py`, `experiment.py`, `visualization.py`) into sequential cells for easy execution (e.g. Google Colab).

Sections:
1. Imports & Utilities
2. Data Generation (Distributions)
3. Models (Vanilla & Strategic Classifiers)
4. Experiment Orchestration (Parameter Sweep)
5. Visualization (Correlation Space Plots)
6. Demo Run (Small sweep)

You can later expand the parameter sweep for full experiments.

In [None]:
# --- 1. Imports & Global Utilities ---
import numpy as np
import pandas as pd
import itertools
import matplotlib.pyplot as plt
from typing import Dict, Tuple, Any, List

# Reproducibility helper (optional)
def set_global_seed(seed: int = 0) -> None:
    """Seed NumPy's global (module-level) random generator.

    Args:
        seed: Value forwarded unchanged to ``np.random.seed``.
    """
    np.random.seed(seed)

# Seed NumPy's global generator with 0 as soon as the imports cell runs,
# then print a marker so the cell output shows it completed.
set_global_seed(0)
print('Imports ready.')

## 2. Data Generation
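Defines `BaseDistribution` (a random joint pmf over (Y, Z, U) plus one 2-D Gaussian for X per (Y, Z, U) combination) and `PerturbedDistribution`, which perturbs that pmf and produces samples (X, Y, Z, U).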
- (Y, Z, U) each in {-1, +1}.
- `p` is the probability that Z is visible in a sample; otherwise Z is masked as NaN.
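- Correlations E[YZ], E[YU] can be perturbed by integer shift levels: each step moves one third of the probability mass from disagreeing pairs to agreeing pairs (positive level) or the other way round (negative level).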

In [None]:
class BaseDistribution:
    """Randomly generated joint distribution over (Y, Z, U) with Gaussian X.

    A seeded ``RandomState`` draws a probability mass function over the eight
    sign patterns of (Y, Z, U) and, for every pattern, the mean and covariance
    of a 2-D Gaussian used for X.

    Attributes:
        rng: The ``np.random.RandomState`` created from ``seed``.
        pmf: Length-8 array of probabilities, aligned with ``COMBINATIONS``.
        mus: Maps each (y, z, u) tuple to its 2-vector mean.
        sigmas: Maps each (y, z, u) tuple to its 2x2 covariance matrix.
    """

    # All (y, z, u) sign patterns, ordered (-1,-1,-1), (-1,-1,1), ..., (1,1,1).
    # The position of a pattern in this list is its index into ``pmf``.
    COMBINATIONS = list(itertools.product((-1, 1), repeat=3))

    def __init__(self, seed: int = 0):
        """Draw the pmf and the per-pattern Gaussian parameters from ``seed``."""
        self.rng = np.random.RandomState(seed)
        weights = self.rng.rand(8)
        self.pmf = weights / weights.sum()
        self.mus = {}
        self.sigmas = {}
        # The draw order (mean first, then covariance factor, pattern by
        # pattern) determines what a given seed produces.
        for combo in self.COMBINATIONS:
            self.mus[combo] = self.rng.randn(2)
            factor = self.rng.randn(2, 2)
            # factor @ factor.T is positive semi-definite; the 0.1 * I term
            # keeps the covariance strictly positive definite.
            self.sigmas[combo] = factor @ factor.T + np.eye(2) * 0.1

    def compute_corr(self):
        """Return ``(E[YZ], E[YU])`` under the current pmf as Python floats."""
        moment_yz = 0.0
        moment_yu = 0.0
        for (y, z, u), prob in zip(self.COMBINATIONS, self.pmf):
            moment_yz += prob * y * z
            moment_yu += prob * y * u
        return float(moment_yz), float(moment_yu)

class PerturbedDistribution(BaseDistribution):
    """Base distribution whose Y-U and Y-Z agreement has been shifted.

    The pmf inherited from ``BaseDistribution`` is modified in two passes:
    first the (Y, U) pair marginal is perturbed by ``perturb_level_YU`` steps,
    then the (Y, Z) pair marginal of the result is perturbed by
    ``perturb_level_YZ`` steps. ``corr_YZ`` and ``corr_YU`` hold E[YZ] and
    E[YU] of the final pmf.

    Args:
        perturb_level_YU: Signed number of perturbation steps for (Y, U).
        perturb_level_YZ: Signed number of perturbation steps for (Y, Z).
        p: Probability that Z is left visible in ``sample``.
        seed: Seed forwarded to ``BaseDistribution``.
    """

    def __init__(self, perturb_level_YU=0, perturb_level_YZ=0, p: float = 1.0, seed: int = 10):
        super().__init__(seed=seed)
        self.p = float(p)
        self.S_YU = int(perturb_level_YU)
        self.S_YZ = int(perturb_level_YZ)
        self.pmf = self._apply_perturbations(self.pmf)
        self.corr_YZ, self.corr_YU = self.compute_corr()

    def _joint_to_table(self, pmf):
        """Convert a length-8 pmf into a ``{(y, z, u): probability}`` dict."""
        return {k: float(p) for p, k in zip(pmf, BaseDistribution.COMBINATIONS)}

    def _table_to_pmf(self, table):
        """Convert a table back to a normalised pmf array.

        Negative entries are clipped to zero; if nothing positive remains the
        uniform distribution is returned.
        """
        arr = np.array([table[k] for k in BaseDistribution.COMBINATIONS], dtype=float)
        arr = np.clip(arr, 0.0, None)
        total = arr.sum()
        if total <= 0:
            return np.ones_like(arr) / arr.size
        return arr / total

    def _apply_perturbations(self, pmf):
        """Return ``pmf`` after the (Y, U) and then the (Y, Z) perturbation."""
        states = BaseDistribution.COMBINATIONS
        position = {'Y': 0, 'Z': 1, 'U': 2}  # index of each variable in a state

        def _perturb_pair(table, A, B, S):
            """Shift mass of the (A, B) pair marginal by ``S`` steps.

            ``S >= 0`` moves mass from disagreeing pairs (a * b == -1) to
            agreeing pairs (a * b == 1); ``S < 0`` moves it the other way.
            Each step transfers one third of the mass currently on the source
            side. The conditional distribution of the third variable given
            (A, B) is kept from the input table.
            """
            ia, ib = position[A], position[B]

            # Pair marginal P(A=a, B=b), accumulated in state order.
            pAB = {(a, b): 0.0 for a in (-1, 1) for b in (-1, 1)}
            for state in states:
                pAB[(state[ia], state[ib])] += table[state]
            # Keep the unperturbed marginal: it is the denominator of the
            # conditional used when the joint is rebuilt below.
            pAB_orig = dict(pAB)

            agree = [k for k in pAB if k[0] * k[1] == 1]
            disagree = [k for k in pAB if k[0] * k[1] == -1]
            src, dst = (disagree, agree) if S >= 0 else (agree, disagree)

            for _ in range(abs(S)):
                src_total = sum(pAB[k] for k in src)
                t = src_total / 3.0
                if t <= 0:  # nothing left to move (also covers src_total <= 0)
                    break
                # Remove t from the source cells in proportion to their mass.
                for k in src:
                    pAB[k] = max(pAB[k] - t * (pAB[k] / src_total), 0.0)
                # Add t to the destination cells in proportion to their mass,
                # or equally when the destination side is empty.
                dst_total = sum(pAB[k] for k in dst)
                if dst_total <= 0:
                    per = t / len(dst)
                    for k in dst:
                        pAB[k] += per
                else:
                    for k in dst:
                        pAB[k] += t * (pAB[k] / dst_total)

            # Rebuild the joint: new P(a, b) times the original P(c | a, b).
            new_table = {}
            for state in states:
                key = (state[ia], state[ib])
                base = pAB_orig[key]
                # A pair with no original mass has no conditional; use 0.5.
                p_c_given_ab = table[state] / base if base > 0 else 0.5
                new_table[state] = pAB[key] * p_c_given_ab
            return new_table

        table = self._joint_to_table(pmf)
        table = _perturb_pair(table, 'Y', 'U', self.S_YU)
        table = _perturb_pair(table, 'Y', 'Z', self.S_YZ)
        return self._table_to_pmf(table)

    def sample(self, n=100):
        """Draw ``n`` samples from the perturbed distribution.

        Returns:
            Dict with keys ``"X"`` (float array, shape ``(n, 2)``), ``"Y"`` and
            ``"U"`` (int arrays of -1/+1) and ``"Z"`` (float array of -1/+1
            where each entry is replaced by NaN with probability ``1 - p``).
        """
        idx = self.rng.choice(8, size=n, p=self.pmf)
        X = np.zeros((n, 2))
        Y = np.zeros(n, int)
        Z = np.zeros(n, float)
        U = np.zeros(n, int)
        for i, k in enumerate(self.COMBINATIONS):
            mask = idx == i
            cnt = mask.sum()
            if cnt == 0:
                continue
            # X is drawn from the Gaussian attached to this (y, z, u) pattern.
            X[mask] = self.rng.multivariate_normal(self.mus[k], self.sigmas[k], size=cnt)
            Y[mask] = k[0]
            Z[mask] = k[1]
            U[mask] = k[2]
        vis = self.rng.rand(n) < self.p
        Z_masked = np.where(vis, Z, np.nan)
        return {"X": X, "Y": Y, "Z": Z_masked, "U": U}

# Smoke test: build one perturbed distribution and print its E[YZ] and E[YU].
# Both values are computed from the pmf, not estimated from drawn samples.
_dist_test = PerturbedDistribution(-1,2,p=0.5,seed=42)
print('Sample corr (YZ,YU):', _dist_test.corr_YZ, _dist_test.corr_YU)

## 3. Models
Defines `VanillaClassifier` and `StrategicClassifier` with numerically stable training.

In [None]:
def logistic_loss(y, s):
    """Return the logistic loss ``log(1 + exp(-y * s))`` element-wise.

    ``np.logaddexp`` evaluates the expression without overflowing for large
    negative margins.

    Args:
        y: Labels (scalar or array).
        s: Scores, broadcastable against ``y``.
    """
    neg_margin = -y * s
    return np.logaddexp(0, neg_margin)

def _stable_sigmoid(x):
    out = np.empty_like(x, dtype=float)
    pos = x >= 0; out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))
    neg = ~pos; ex = np.exp(x[neg]); out[neg] = ex / (1.0 + ex)
    return out

def _clip_grad(grad, max_norm=5.0):
    norm = np.linalg.norm(grad)
    if norm > max_norm and norm > 0: return grad * (max_norm / norm)
    return grad

def sm_max(a, b, tau=5.0):
    """Smooth maximum ``log(exp(tau*a) + exp(tau*b)) / tau``.

    The larger argument is subtracted before exponentiating so neither
    exponential can overflow. Larger ``tau`` brings the value closer to the
    hard ``max(a, b)``.
    """
    peak = np.maximum(a, b)
    shifted_sum = np.exp(tau * (a - peak)) + np.exp(tau * (b - peak))
    return (np.log(shifted_sum) + tau * peak) / tau

def sm_min(a, b, tau=5.0):
    """Smooth minimum, computed as the negated ``sm_max`` of the negated inputs."""
    flipped = sm_max(-a, -b, tau)
    return -flipped

class BaseClassifier:
    """Shared state and feature maps for the two linear classifiers.

    Attributes:
        w_f: Weights for the feature map that includes Z (``None`` until set).
        w_g: Weights for the feature map without Z (``None`` until set).
        history: Dict for training diagnostics; starts empty.
    """

    def __init__(self):
        self.w_f = None
        self.w_g = None
        self.history = {}

    def _phi_f(self, X, Z):
        """Return rows ``[1, x0, x1, z]`` with NaN entries of ``Z`` replaced by 0."""
        z_filled = np.where(np.isnan(Z), 0.0, Z)
        return np.column_stack((self._phi_g(X), z_filled))

    def _phi_g(self, X):
        """Return rows ``[1, x0, x1]`` built from the first two columns of ``X``."""
        bias = np.ones(X.shape[0])
        return np.column_stack((bias, X[:, 0], X[:, 1]))

class VanillaClassifier(BaseClassifier):
    """Two linear logistic models selected by whether Z is observed.

    Samples with a visible Z are scored with ``w_f`` on ``[1, x0, x1, z]``;
    samples whose Z is NaN are scored with ``w_g`` on ``[1, x0, x1]``. Both
    weight vectors are trained by full-batch gradient descent with L2 weight
    decay and gradient-norm clipping.

    Args:
        lr: Gradient-descent step size.
        epochs: Number of full-batch updates.
        weight_decay: Coefficient of the L2 penalty ``0.5 * wd * ||w||^2``.
    """

    def __init__(self, lr=1.0, epochs=200, weight_decay=1e-3):
        super().__init__()
        self.lr = lr
        self.epochs = epochs
        self.weight_decay = weight_decay

    def fit(self, X, Y, Z, U=None):
        """Train ``w_f`` and ``w_g`` from zero; ``U`` is accepted but unused.

        ``history['loss']`` is set to the objective evaluated at the start of
        the final epoch (``None`` if ``epochs`` is 0). Returns ``self``.
        """
        n = X.shape[0]
        Phi_f = self._phi_f(X, Z)
        Phi_g = self._phi_g(X)
        self.w_f = np.zeros(Phi_f.shape[1])
        self.w_g = np.zeros(Phi_g.shape[1])
        self.history = {'loss': None}

        # The visibility split and the per-group feature rows do not change
        # between epochs, so they are computed once.
        observed = ~np.isnan(Z)
        hidden = ~observed
        has_observed = observed.any()
        has_hidden = hidden.any()
        Phi_f_obs = Phi_f[observed]
        Phi_g_hid = Phi_g[hidden]

        last = None
        for _ in range(self.epochs):
            s = np.zeros(n)
            s[observed] = Phi_f_obs @ self.w_f
            s[hidden] = Phi_g_hid @ self.w_g
            penalty = 0.5 * self.weight_decay * (np.sum(self.w_f ** 2) + np.sum(self.w_g ** 2))
            last = float(logistic_loss(Y, s).mean() + penalty)

            # d/ds of log(1 + exp(-y * s)).
            coeff = -Y * _stable_sigmoid(-Y * s)
            grad_w_f = np.zeros_like(self.w_f)
            grad_w_g = np.zeros_like(self.w_g)
            # A group without samples keeps a zero gradient (no decay either).
            if has_observed:
                grad_w_f = (Phi_f_obs.T @ coeff[observed]) / n + self.weight_decay * self.w_f
            if has_hidden:
                grad_w_g = (Phi_g_hid.T @ coeff[hidden]) / n + self.weight_decay * self.w_g
            self.w_f -= self.lr * _clip_grad(grad_w_f)
            self.w_g -= self.lr * _clip_grad(grad_w_g)

        self.history['loss'] = last
        return self

    def eval(self, X, Y, Z, U=None):
        """Return ``{'loss': mean logistic loss, 'accuracy': fraction correct}``.

        Predictions are ``sign(score)``; a score of exactly 0 matches neither
        label and therefore counts as an error.
        """
        n = X.shape[0]
        Phi_f = self._phi_f(X, Z)
        Phi_g = self._phi_g(X)
        observed = ~np.isnan(Z)
        hidden = ~observed
        s = np.zeros(n)
        s[observed] = Phi_f[observed] @ self.w_f
        s[hidden] = Phi_g[hidden] @ self.w_g
        loss = logistic_loss(Y, s).mean()
        acc = (np.sign(s) == Y).mean()
        return {'loss': float(loss), 'accuracy': float(acc)}

class StrategicClassifier(BaseClassifier):
    """Linear classifier whose score depends on Z's visibility and on U.

    Two linear scores are computed per sample: ``s_f`` on ``[1, x0, x1, z]``
    and ``s_g`` on ``[1, x0, x1]``. The score used for prediction is

    * ``s_g`` when Z is hidden (NaN),
    * the maximum of ``(s_f, s_g)`` when Z is visible and ``U == 1``,
    * the minimum of ``(s_f, s_g)`` when Z is visible and ``U != 1``.

    Training replaces max/min by the smooth ``sm_max`` surrogate with
    temperature ``tau``; ``eval`` uses the hard max/min.

    Args:
        lr: Gradient-descent step size.
        epochs: Number of full-batch updates.
        tau: Temperature of the smooth max/min used during training.
        lam: Weight of the penalty on ``(s_f - s_g)^2`` over visible samples.
        weight_decay: Coefficient of the L2 penalty ``0.5 * wd * ||w||^2``.
    """

    def __init__(self, lr=1.0, epochs=200, tau=5.0, lam=1.0, weight_decay=1e-3):
        super().__init__()
        self.lr = lr
        self.epochs = epochs
        self.tau = tau
        self.lam = lam
        self.weight_decay = weight_decay

    def fit(self, X, Y, Z, U):
        """Train ``w_f`` and ``w_g`` from zero by full-batch gradient descent.

        ``history['loss']`` is set to the objective evaluated at the start of
        the final epoch (``None`` if ``epochs`` is 0). Returns ``self``.
        """
        n = X.shape[0]
        Phi_f = self._phi_f(X, Z)
        Phi_g = self._phi_g(X)
        self.w_f = np.zeros(Phi_f.shape[1])
        self.w_g = np.zeros(Phi_g.shape[1])
        self.history = {'loss': None}

        # Epoch-invariant quantities.
        visible = ~np.isnan(Z)
        hidden = ~visible
        # +1 where a visible sample takes the soft max, -1 where it takes the
        # soft min: sign * sm_max(sign * s_f, sign * s_g) covers both cases.
        sign = np.where(np.asarray(U)[visible] == 1, 1.0, -1.0)
        Phi_f_vis = Phi_f[visible]
        Phi_g_vis = Phi_g[visible]
        Phi_g_hid = Phi_g[hidden]

        last = None
        for _ in range(self.epochs):
            s_f = Phi_f @ self.w_f
            s_g = Phi_g @ self.w_g
            sf_vis = s_f[visible]
            sg_vis = s_g[visible]
            diff = sf_vis - sg_vis

            s = np.zeros(n)
            s[hidden] = s_g[hidden]
            s[visible] = sign * sm_max(sign * sf_vis, sign * sg_vis, tau=self.tau)

            # NOTE(review): the reported penalty averages over visible samples
            # while the gradient below divides by n (all samples); the two
            # agree only when every Z is visible. Left as-is — confirm which
            # normalisation is intended.
            reg = (diff ** 2).mean() if diff.size else 0.0
            loss = (logistic_loss(Y, s).mean() + self.lam * reg
                    + 0.5 * self.weight_decay * (np.sum(self.w_f ** 2) + np.sum(self.w_g ** 2)))
            last = float(loss)

            # d/ds of log(1 + exp(-y * s)).
            coeff = -Y * _stable_sigmoid(-Y * s)
            coeff_vis = coeff[visible]
            # ds/ds_f for visible samples. For the soft max this is the
            # softmax weight of tau * s_f; for the soft min
            # -sm_max(-s_f, -s_g) the two minus signs cancel and it is the
            # softmax weight of -tau * s_f. Either way it equals
            # sigmoid(sign * tau * (s_f - s_g)) and lies in [0, 1].
            share_f = _stable_sigmoid(self.tau * sign * diff)
            share_g = 1.0 - share_f
            pull = 2.0 * self.lam * diff  # derivative of lam * diff^2 w.r.t. s_f

            grad_w_f = (Phi_f_vis.T @ (coeff_vis * share_f + pull)) / n
            grad_w_g = (Phi_g_hid.T @ coeff[hidden]
                        + Phi_g_vis.T @ (coeff_vis * share_g - pull)) / n
            # The L2 penalty is part of the objective, so its gradient is
            # applied here as well.
            grad_w_f = grad_w_f + self.weight_decay * self.w_f
            grad_w_g = grad_w_g + self.weight_decay * self.w_g

            self.w_f -= self.lr * _clip_grad(grad_w_f)
            self.w_g -= self.lr * _clip_grad(grad_w_g)

        self.history['loss'] = last
        return self

    def eval(self, X, Y, Z, U):
        """Return ``{'loss': mean logistic loss, 'accuracy': fraction correct}``.

        Visible samples use the hard max (``U == 1``) or hard min (otherwise)
        of the two scores; hidden samples use ``s_g``. Predictions are
        ``sign(score)``, so a score of exactly 0 counts as an error.
        """
        s_f = self._phi_f(X, Z) @ self.w_f
        s_g = self._phi_g(X) @ self.w_g
        visible = ~np.isnan(Z)
        takes_max = np.asarray(U) == 1
        s_visible = np.where(takes_max, np.maximum(s_f, s_g), np.minimum(s_f, s_g))
        s = np.where(visible, s_visible, s_g)
        loss = logistic_loss(Y, s).mean()
        acc = (np.sign(s) == Y).mean()
        return {'loss': float(loss), 'accuracy': float(acc)}

# Marker so the cell output shows that the model definitions were executed.
print('Models ready.')

## 4. Experiment Orchestration
`Experiment` class runs parameter sweeps and aggregates results.

In [None]:
class Experiment:
    """Run both classifiers over a grid of distribution parameters.

    Attributes:
        param_space: Maps each parameter name to the list of values to sweep.
        seed: Seed passed to every ``PerturbedDistribution`` that is built.
        results: One result dict per completed ``single_run`` call.
    """

    def __init__(self, param_space: Dict[str, list], seed=0):
        self.param_space = param_space
        self.seed = seed
        self.results: List[Dict[str, Any]] = []

    def single_run(self, params: Dict[str, Any], model_hparams=None):
        """Train and evaluate both models for one parameter combination.

        Args:
            params: Must contain ``perturb_level_YU``, ``perturb_level_YZ`` and
                ``p``; ``train_n`` and ``test_n`` default to 1000.
            model_hparams: Optional dict with ``'vanilla'`` and/or
                ``'strategic'`` keyword-argument dicts for the classifiers.

        Returns:
            The result dict, which is also appended to ``self.results``.
        """
        hparams = {} if model_hparams is None else model_hparams
        dist = PerturbedDistribution(
            perturb_level_YU=params['perturb_level_YU'],
            perturb_level_YZ=params['perturb_level_YZ'],
            p=params['p'],
            seed=self.seed,
        )
        # Train and test sets are drawn one after the other from the same
        # distribution object.
        train = dist.sample(params.get('train_n', 1000))
        test = dist.sample(params.get('test_n', 1000))
        train_args = (train['X'], train['Y'], train['Z'], train['U'])
        test_args = (test['X'], test['Y'], test['Z'], test['U'])

        van = VanillaClassifier(**hparams.get('vanilla', {}))
        strat = StrategicClassifier(**hparams.get('strategic', {}))
        van.fit(*train_args)
        strat.fit(*train_args)

        van_train = van.eval(*train_args)
        van_test = van.eval(*test_args)
        strat_train = strat.eval(*train_args)
        strat_test = strat.eval(*test_args)

        res = dict(params)
        res.update({
            'corr_YU': dist.corr_YU,
            'corr_YZ': dist.corr_YZ,
            'van_w_f': van.w_f.tolist(),
            'van_w_g': van.w_g.tolist(),
            'strat_w_f': strat.w_f.tolist(),
            'strat_w_g': strat.w_g.tolist(),
            'van_train_loss': van_train['loss'],
            'van_test_loss': van_test['loss'],
            'van_test_acc': van_test['accuracy'],
            'strat_train_loss': strat_train['loss'],
            'strat_test_loss': strat_test['loss'],
            'strat_test_acc': strat_test['accuracy'],
        })
        self.results.append(res)
        return res

    def run_sweep(self, model_hparams=None):
        """Call ``single_run`` for every combination in ``param_space``.

        Progress is printed on a single line that is overwritten each step.
        Returns the list of result dicts produced by this sweep.
        """
        keys = list(self.param_space)
        combos = list(itertools.product(*(self.param_space[k] for k in keys)))
        total = len(combos)
        results = []
        for idx, combo in enumerate(combos, 1):
            print(f'Sweep {idx}/{total} ({idx/total:.1%})', end='\r')
            params = dict(zip(keys, combo))
            results.append(self.single_run(params, model_hparams=model_hparams))
        print()
        return results

    def get_results_df(self):
        """Return all accumulated results as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.results)

    def save_csv(self, path='experiment_results.csv'):
        """Write the results to ``path`` as CSV and return ``path``.

        The columns listed in ``leading_order`` come first (when present) and
        are also the sort keys for the rows; every other column follows in its
        existing order.
        """
        leading_order = ['p', 'corr_YZ', 'corr_YU',
                         'perturb_level_YU', 'perturb_level_YZ', 'train_n', 'test_n']
        df = self.get_results_df()
        leading = [c for c in leading_order if c in df.columns]
        trailing = [c for c in df.columns if c not in leading]
        df = df[leading + trailing]
        df = df.sort_values(by=leading).reset_index(drop=True)
        df.to_csv(path, index=False)
        return path
# Marker so the cell output shows that the Experiment class was defined.
print('Experiment class ready.')

## 5. Visualization
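Plots which model wins (Vanilla vs Strategic) across correlation space, per `p`. A model wins when its relative margin on the metric exceeds 5%; otherwise the point is marked as a tie.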

In [None]:
def plot_correlation_space(csv_path='experiment_results.csv'):
    """Save scatter plots showing which model wins at each (E[YZ], E[YU]).

    For every value of ``p`` in the CSV and for each metric (test loss, test
    accuracy), one PNG named ``correlation_space_{metric}_p{p}.png`` is written
    to the working directory. A point is coloured as a win only when the
    relative margin exceeds 5%; otherwise it is a tie.

    Args:
        csv_path: CSV holding columns ``p``, ``corr_YZ``, ``corr_YU`` and the
            ``van_*`` / ``strat_*`` test metrics.
    """
    results = pd.read_csv(csv_path)
    p_values = sorted(results['p'].unique())
    color_map = {0: 'tab:blue', 1: 'tab:orange', 2: 'gray'}
    label_map = {0: 'Vanilla', 1: 'Strategic', 2: 'Tie'}
    # (metric, title, column in Vanilla's favour when larger, column subtracted).
    # Lower loss is better, so Vanilla is ahead when the Strategic loss is
    # larger; higher accuracy is better, so the roles are swapped.
    comparisons = [
        ('test_loss', 'Model with Lower Test Loss', 'strat_test_loss', 'van_test_loss'),
        ('test_acc', 'Model with Higher Test Accuracy', 'van_test_acc', 'strat_test_acc'),
    ]
    for metric, label, plus_col, minus_col in comparisons:
        for p in p_values:
            subset = results[results['p'] == p]
            if subset.empty:
                continue
            scale = np.maximum(subset[plus_col], subset[minus_col])
            margin = (subset[plus_col] - subset[minus_col]) / scale
            # 0 = Vanilla, 1 = Strategic, 2 = tie (also when the margin is NaN).
            winner = np.where(margin > 0.05, 0, np.where(margin < -0.05, 1, 2))
            colors = [color_map[w] for w in winner]

            fig, ax = plt.subplots(figsize=(6.5, 5.5))
            ax.scatter(subset['corr_YZ'], subset['corr_YU'], c=colors, alpha=0.85, edgecolor='k')
            # Empty scatters exist only to create one legend entry per outcome.
            for code in (0, 1, 2):
                ax.scatter([], [], c=color_map[code], label=label_map[code])
            ax.set_xlabel('Correlation E[YZ]')
            ax.set_ylabel('Correlation E[YU]')
            ax.set_title(f'{label} (p={p})')
            ax.legend(title='Winner')
            ax.grid(True, linestyle='--', alpha=0.5)
            fig.tight_layout()
            fig.savefig(f'correlation_space_{metric}_p{p}.png')
            plt.close(fig)
# Marker so the cell output shows that the plotting helper was defined.
print('Visualization helper ready.')

## 6. Demo Run
Runs a SMALL sweep for speed. You can expand later (see commented example).

In [None]:
# Small demo parameter space (adjust as needed): 3 x 3 x 3 = 27 combinations,
# each trained on 800 samples and tested on 300.
demo_param_space = {
    'perturb_level_YU': [-1,0,1],
    'perturb_level_YZ': [-1,0,1],
    'p': [0.0, 0.5, 1.0],
    'train_n': [800],
    'test_n': [300]
}
# Every combination builds its distribution from the same seed (50).
exp = Experiment(demo_param_space, seed=50)
_ = exp.run_sweep()
csv_path = exp.save_csv('experiment_results_demo.csv')
print('Demo sweep completed. Rows:', len(exp.results))
# Show the first rows as a rich table when IPython's display works; any
# failure (including IPython being absent) falls back to a plain-text print.
try:
    from IPython.display import display
    display(exp.get_results_df().head())
except Exception:
    print(exp.get_results_df().head())
# Writes one PNG per (metric, p) pair into the current working directory.
plot_correlation_space(csv_path)
print('Plots saved for demo (test_loss & test_acc variants).')

# For FULL sweep (longer runtime) uncomment and adjust:
# full_param_space = {
#     'perturb_level_YU': [-3,-2,-1,0,1,2,3],
#     'perturb_level_YZ': [-3,-2,-1,0,1,2,3],
#     'p': [0,0.25,0.5,0.75,1],
#     'train_n': [2500],
#     'test_n': [500]
# }
# exp_full = Experiment(full_param_space, seed=50)
# exp_full.run_sweep()
# exp_full.save_csv('experiment_results_full.csv')
# plot_correlation_space('experiment_results_full.csv')
# print('Full sweep complete.')