In [None]:
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm.notebook import tqdm
import itertools
from sklearn.metrics import roc_auc_score
import pandas as pd
import pickle

In [None]:
# Root directory that holds one sub-directory per experiment run.
# NOTE(review): absolute, machine-specific path — change it when running elsewhere.
EXP_ROOT = "/scratch/zeiberg.d/leveragingStructureFinalExperiments/experiments/"

In [None]:
# Names of the real-world datasets; each name is the prefix of that dataset's run directories.
REAL_DATASETS = ["income","employment","income_poverty_ratio","amazon_reviews"]

In [None]:
class Experiment:
    """Ground-truth labels and per-method predictions read from one experiment directory.

    Attributes:
        path: directory passed to the constructor.
        yUnlabeledTest: array loaded from ``yUnlabeledTest.npy``, or None when the
            file does not exist.
        methodPreds: dict mapping a method name to the prediction array loaded for it.
        k: cached cluster count; NaN until ``getK`` manages to load the clusterer.
    """

    # (method name, sub-directory, file name) in the order the files are loaded.
    # NOTE(review): "Preds.npy" under mm2 is capitalised unlike every other
    # prediction file — confirm this matches the file name on disk.
    _PRED_FILES = (
        ("Group-Aware Global", "ag", "preds.npy"),
        ("Cluster Global", "mm", "clusterGlobalPreds.npy"),
        ("Our Method", "mm", "preds.npy"),
        ("Global Star", "mmStar", "clusterGlobalPreds.npy"),
        ("Star", "mmStar", "preds.npy"),
        ("Global", "mm2", "clusterGlobalPreds.npy"),
        ("Label Shift", "mm2", "Preds.npy"),
        ("coral", "fe", "preds.npy"),
    )

    def __init__(self,pth):
        """Remember the directory and eagerly load labels and predictions from it."""
        self.path = pth
        # getK() tests this value with np.isnan, so it must exist before the first call.
        self.k = np.nan
        self.loadLabels()
        self.loadPreds()

    def loadLabels(self):
        """Load ``yUnlabeledTest.npy``; store None when the file is missing."""
        try:
            self.yUnlabeledTest = np.load(os.path.join(self.path,
                                                      "yUnlabeledTest.npy"))
        except FileNotFoundError:
            self.yUnlabeledTest = None

    def getK(self):
        """Return the clusterer's ``n_clusters``, or NaN when no clusterer file exists.

        A successfully loaded value is cached in ``self.k`` and the unpickled
        object is kept in ``self.clusterer``.
        """
        cached = getattr(self, "k", np.nan)
        if not np.isnan(cached):
            return cached
        try:
            # NOTE: pickle.load executes arbitrary code; only use on trusted experiment output.
            with open(os.path.join(self.path,"mm","clusterer.pkl"),"rb") as f:
                clusterer = pickle.load(f)
                self.clusterer = clusterer
            self.k = clusterer.n_clusters
        except FileNotFoundError:
            self.k = np.nan
        return self.k

    def loadPreds(self,skipsteps=None):
        """Fill ``self.methodPreds`` with the prediction arrays found on disk.

        Files are read in the order given by ``_PRED_FILES``. Loading stops at the
        first missing file, so every later method is left out of ``methodPreds``.

        Args:
            skipsteps: currently unused; accepted for backward compatibility.
        """
        self.methodPreds = {}
        for method, subdir, filename in self._PRED_FILES:
            try:
                self.methodPreds[method] = np.load(os.path.join(self.path, subdir, filename))
            except FileNotFoundError:
                return

    def aucSeries(self):
        """Return a Series of ROC AUC per loaded method, indexed by method name.

        Raises:
            ValueError: when labels or predictions are missing (callers catch this
                to skip incomplete runs), or when ``roc_auc_score`` rejects the input.
        """
        if self.yUnlabeledTest is None or not self.methodPreds:
            raise ValueError(f"no labels or predictions available in {self.path}")
        names, vals = list(zip(*[(m,roc_auc_score(self.yUnlabeledTest,preds)) for m,preds in self.methodPreds.items()]))
        return pd.Series(data=vals,index=names)

In [None]:
# Load every run of each real-data experiment, keyed by (dataset name, setting).
exp_sets = {}
for setting in range(1,3):
    for name in REAL_DATASETS:
        print(setting,name)
        # Build the pattern from EXP_ROOT so the experiment location is configured
        # in one place, as the synthetic loader already does.
        run_dirs = glob(os.path.join(EXP_ROOT,f"{name}_setting_{setting}_*/"))
        exp_sets[(name,setting)] = [Experiment(pth) for pth in run_dirs]

In [None]:
# Load the synthetic experiments, keyed by (setting, dim, nClusters).
synth_exp_sets = {}
for setting, dim, nClusters in itertools.product(range(1,3), [1,4,16,64], [1,2,4,8]):
    print(setting,dim,nClusters)
    run_pattern = os.path.join(EXP_ROOT,f"synthetic_dim_{dim}_nClusters_{nClusters}_setting_{setting}_*/")
    # Run directories whose path contains "FAILED" are left out.
    synth_exp_sets[(setting,dim,nClusters)] = [
        Experiment(pth) for pth in glob(run_pattern) if "FAILED" not in pth
    ]

In [None]:
# Build one AUC table per (dataset, setting): one row per run, one column per method.
tables = {}
for key, experiments in exp_sets.items():
    print(key)
    auc_rows = []
    for exp in experiments:
        try:
            auc_rows.append(exp.aucSeries())
        except ValueError:
            # The run could not be scored; report its directory and move on.
            print(exp.path)
    tables[key] = pd.DataFrame(auc_rows)

In [None]:
import pathlib

# Number of repetitions expected per (dataset, setting); run directories are
# suffixed with the repetition index 0 .. EXPECTED_RUNS-1.
EXPECTED_RUNS = 25

# Collect the names of expected run directories that were not found on disk.
missing = []
for (name, setting), exps in exp_sets.items():
    # Derive the prefix from the key instead of exps[0], so a configuration with
    # no run directories at all reports every run as missing rather than raising
    # IndexError.
    rt = f"{name}_setting_{setting}"
    exp_names = {pathlib.Path(e.path).name for e in exps}
    for num in range(EXPECTED_RUNS):
        pth = rt+f"_{num}"
        if pth not in exp_names:
            missing.append(pth)

In [None]:
# Show the expected real-data run directories that were not found.
missing

In [None]:
# Build one AUC table per synthetic configuration and record the runs that
# could not be scored.
synth_tables = {}
synth_missing = []
for key, experiments in synth_exp_sets.items():
    auc_rows = []
    for exp in experiments:
        try:
            auc_rows.append(exp.aucSeries())
        except ValueError:
            run_name = pathlib.Path(exp.path).name
            print("cannot process ",run_name)
            synth_missing.append(run_name)
    synth_tables[key] = pd.DataFrame(auc_rows)

In [None]:
# Number of synthetic runs for which no AUC series could be computed.
len(synth_missing)

In [None]:
# Directory names of the synthetic runs that could not be scored.
synth_missing

In [None]:
# Group the unprocessable synthetic runs by configuration: the value is a
# comma-separated string of the trailing run suffixes, keyed by (dim, nClusters, setting).
missing_dict = {}
for setting, dim, nClusters in itertools.product(range(1,3), [1,4,16,64], [1,2,4,8]):
    template = f"synthetic_dim_{dim}_nClusters_{nClusters}_setting_{setting}_"
    names = [e for e in synth_missing if template in e]
    if not names:
        continue
    # Everything after the last underscore is the run suffix.
    run_ids = [n.rsplit("_", 1)[-1] for n in names]
    missing_dict[(dim,nClusters,setting)] = ",".join(run_ids)

### Real-Data Experiment Iteration Counts `with` CORAL

In [None]:
# For each (dataset, setting): number of runs left after dropping rows with any missing AUC.
[(n,t.dropna().shape[0]) for n,t in tables.items()]

In [None]:
# Mean AUC per method for every key of `tables`, computed over the runs that
# have no missing values.
name,vals = zip(*[(n,t.dropna().mean(axis=0).sort_values()) for n,t in tables.items()])
# One row per key of `tables` (the keys are (dataset, setting) tuples), one column per method.
summaryTable = pd.DataFrame(vals,index=name)

In [None]:
# Highlight the largest mean AUC in each row, ignoring columns whose name contains "Star".
summaryTable.style.highlight_max(subset=[c for c in summaryTable.columns if "Star" not in c],axis=1)

In [None]:
# Setting-1 slice of the summary: methods as rows (in presentation order),
# datasets as columns. The Amazon column is split off into its own Series.
method_rows_1 = ["coral","Global", "Group-Aware Global", "Cluster Global", "Label Shift", "Our Method","Star"]
setting1_summary = summaryTable.loc(axis=0)[:,1].T.loc[method_rows_1]
# Drop the setting level so the columns are labelled by dataset name only.
setting1_summary.columns = setting1_summary.columns.droplevel(1)
amazon1 = setting1_summary["amazon_reviews"]
table1 = setting1_summary.drop(columns="amazon_reviews")

In [None]:
# Setting-1 mean AUC per method for the amazon_reviews dataset.
amazon1

In [None]:
# Setting-1 mean AUC per method for the remaining real-world datasets.
table1

In [None]:
# Export the setting-1 Amazon results as a LaTeX table.
# `index_names` of `to_latex` is a boolean flag, not a list of row labels, so the
# display names are applied by renaming the index before exporting.
amazon1.rename(index={"coral": "CORAL", "Star": "True Clustering"}).to_latex(
    "figures/amazon_1.latex",
    header=["Amazon"],
    float_format="%.3f",
    caption="Average AUC calculated on the held-out test set for Amazon datasets in setting 1.",
    label="tab:amazonSetting1",
)

In [None]:
# Export the setting-1 real-world results as a LaTeX table.
# `index_names` of `to_latex` is a boolean flag, not a list of row labels, so the
# display names are applied by renaming the index before exporting.
table1.rename(index={"coral": "CORAL", "Star": "True Clustering"}).to_latex(
    "figures/table_1.latex",
    header=["Income","Employment","IPR"],
    float_format="%.3f",
    caption="Average AUC calculated on the held-out test set for real-world datasets in setting 1.",
    label="tab:realSetting1",
)

In [None]:
# Setting-2 slice of the summary: methods as rows (in presentation order),
# datasets as columns. The Amazon column is split off into its own Series.
method_rows_2 = ["coral","Global", "Group-Aware Global", "Cluster Global", "Label Shift", "Our Method","Star"]
setting2_summary = summaryTable.loc(axis=0)[:,2].T.loc[method_rows_2]
# Drop the setting level so the columns are labelled by dataset name only.
setting2_summary.columns = setting2_summary.columns.droplevel(1)
amazon2 = setting2_summary["amazon_reviews"]
table2 = setting2_summary.drop(columns="amazon_reviews")

In [None]:
# Setting-2 mean AUC per method for the amazon_reviews dataset.
amazon2

In [None]:
# Setting-2 mean AUC per method for the remaining real-world datasets.
table2

In [None]:
# Export the setting-2 Amazon results as a LaTeX table.
# `index_names` of `to_latex` is a boolean flag, not a list of row labels, so the
# display names are applied by renaming the index before exporting.
amazon2.rename(index={"coral": "CORAL", "Star": "True Clustering"}).to_latex(
    "figures/amazon_2.latex",
    header=["Amazon"],
    float_format="%.3f",
    caption="Average AUC calculated on the held-out test set for Amazon datasets in setting 2.",
    label="tab:amazonSetting2",
)

In [None]:
# Combine the Amazon results of both settings into one table, one column per setting.
amazonTable = pd.DataFrame({"Setting 1":amazon1,
                            "Setting 2":amazon2})

In [None]:
# Amazon mean AUC per method, settings 1 and 2 side by side.
amazonTable

In [None]:
# Export the combined Amazon results (settings 1 and 2) as a LaTeX table.
# `index_names` of `to_latex` is a boolean flag, not a list of row labels, so the
# display names are applied by renaming the index before exporting.
amazonTable.rename(index={"coral": "CORAL", "Star": "True Clustering"}).to_latex(
    "figures/amazonTable.latex",
    float_format="%.3f",
    caption="Average AUC calculated on the held-out test set for Amazon datasets in settings 1 and 2.",
    label="tab:amazon",
)

In [None]:
# Export the setting-2 real-world results as a LaTeX table.
# `index_names` of `to_latex` is a boolean flag, not a list of row labels, so the
# display names are applied by renaming the index before exporting.
table2.rename(index={"coral": "CORAL", "Star": "True Clustering"}).to_latex(
    "figures/table_2.latex",
    header=["Income","Employment","IPR"],
    float_format="%.3f",
    caption="Average AUC calculated on the held-out test set for real-world datasets in setting 2.",
    label="tab:realSetting2",
)

In [None]:
# Mean AUC per method for every key of `synth_tables`, computed over the runs
# that have no missing values.
name,vals = zip(*[(n,t.dropna().mean(axis=0).sort_values()) for n,t in synth_tables.items()])
# One row per key of `synth_tables` (the keys are (setting, dim, nClusters) tuples),
# one column per method.
synth_summaryTable = pd.DataFrame(vals,index=name)

In [None]:
# Highlight the largest mean AUC in each row, ignoring columns whose name contains "Star".
synth_summaryTable.style.highlight_max(subset=[c for c in synth_summaryTable.columns if "Star" not in c],axis=1)

In [None]:
# matplotlib.pyplot is already imported in the first cell; this re-import is redundant but harmless.
import matplotlib.pyplot as plt
# Set the global matplotlib font size to 34 for every figure drawn after this point.
plt.rcParams.update({'font.size': 34})
def plotRelPerf(tbl,savepath=None,ax=None,ticks=True,axLBL=True,title=None,
               cols=lambda tbl:tbl.columns,
               ticklabels = None,
               rel_to="Global",
                ytick_locs=np.arange(-20,25,5)):
    """Draw box and violin plots of AUC relative to a baseline column, in percent.

    For every selected column the plotted quantity is
    ``(column - baseline) / baseline * 100``, where baseline is ``tbl[rel_to]``.

    Args:
        tbl: DataFrame with one column per method; must contain ``rel_to``.
        savepath: when given, the figure that owns ``ax`` is saved there as PDF.
        ax: axes to draw on; ``plt.subplot()`` is used when None.
        ticks: label the x axis with ``ticklabels`` when True, otherwise hide
            the x ticks and their labels.
        axLBL: when True, set the y-axis label and the y ticks.
        title: optional axes title.
        cols: callable receiving ``tbl`` and returning the columns to plot.
        ticklabels: x tick labels; defaults to the selected columns.
        rel_to: name of the baseline column.
        ytick_locs: y tick positions, used only when ``axLBL`` is True.

    Returns:
        The axes that were drawn on.
    """
    baseline = tbl[rel_to]
    # Keep the callable and its result under separate names.
    selected = cols(tbl)
    rel = tbl.loc[:,selected].apply(lambda col: (col-baseline)/baseline) * 100
    if ax is None:
        ax = plt.subplot()
    ax.boxplot(rel)
    ax.violinplot(rel)
    if ticks:
        if ticklabels is None:
            ticklabels = selected
        # One tick per plotted column, at positions 1..N.
        ax.set_xticks(np.arange(1,len(ticklabels)+1),
                      ticklabels,
                      rotation=45,ha="right")
    else:
        ax.tick_params(axis='x',          # changes apply to the x-axis
                       which='both',      # both major and minor ticks are affected
                       bottom=False,      # ticks along the bottom edge are off
                       top=False,         # ticks along the top edge are off
                       labelbottom=False) # labels along the bottom edge are off
    if axLBL:
        ax.set_ylabel("Relative AUC (%)")
        ax.set_yticks(ytick_locs,
                 [str(i) for i in ytick_locs])
    if title is not None:
        ax.set_title(title)
    if savepath is not None:
        # Save the figure that owns `ax`; plt.savefig would save whichever figure
        # is current, which may be a different one when `ax` is passed in.
        ax.figure.savefig(savepath,format="pdf")
    return ax

In [None]:
# 4x4 grid for setting 2: one row per dimension d, one column per cluster count K.
synthDFig2,synthDax2 = plt.subplots(4,4,figsize=(24,24),sharey=True)

# Columns to plot and the labels shown for them on the x axis.
breakdown_cols = ["coral","Group-Aware Global",
                  "Cluster Global",
                  "Label Shift","Our Method","Star"]
breakdown_labels = ["CORAL","Group-Aware Global",
                    "Cluster Global",
                    "Label Shift","Our Method","True Clustering"]

for (i,d),(j,k) in itertools.product(enumerate([1,4,16,64]), enumerate([1,2,4,8])):
    plotRelPerf(synth_tables[(2,d,k)].dropna(),
                ax=synthDax2[i,j],
                ticks=(i==3),      # x tick labels only on the bottom row
                axLBL=(j==0),      # y label and ticks only on the left column
                title=f"d={d} K={k}",
                cols=lambda tbl: breakdown_cols,
                ticklabels=breakdown_labels,
                rel_to="Global",
                ytick_locs=np.arange(-20,30,10))
synthDFig2.subplots_adjust(hspace=.15,wspace=.05)

In [None]:
# Write the 4x4 synthetic breakdown figure to disk as a tightly cropped PDF.
synthDFig2.savefig("figures/synthetic_breakdown.pdf",format="pdf", bbox_inches='tight')

In [None]:
# Side-by-side summary of the synthetic results: setting 1 on the left, setting 2 on the right.
synthfig,synthax= plt.subplots(1,2,figsize=(12,6),sharey=True)

# Columns to plot and the labels shown for them on the x axis.
summary_cols = ["coral","Group-Aware Global",
                "Cluster Global",
                "Label Shift","Our Method"]
summary_labels = ["CORAL","Group-Aware Global",
                  "Cluster Global",
                  "Label Shift","Our Method"]

synth1 = plotRelPerf(synth_summaryTable.loc[1],
                     ax=synthax[0],
                     cols=lambda tbl: summary_cols,
                     ytick_locs=np.arange(-20,30,10),
                     ticklabels=summary_labels)

# The right panel shares the y axis, so it gets no y label or ticks of its own.
synth2 = plotRelPerf(synth_summaryTable.loc[2],
                     ax=synthax[1],
                     axLBL=False,
                     cols=lambda tbl: summary_cols,
                     ticklabels=summary_labels)

In [None]:
# Write the two-panel synthetic summary figure to disk as a tightly cropped PDF.
synthfig.savefig("figures/synthetic.pdf",format="pdf",bbox_inches='tight')