In [1]:
import sys
sys.path.append("../mypkg")

In [2]:
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT, MID_ROOT
from utils.misc import load_pkl, save_pkl, merge_intervals
from utils.colors import qual_cmap
from utils.stats import weighted_quantile

In [3]:
%load_ext autoreload
%autoreload 2
# 0,1, 2, 3, be careful about the space

In [4]:
import torch
import scipy.stats as ss
import numpy as np
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict as ddict
from tqdm import tqdm, trange
import random
from joblib import Parallel, delayed
import pandas as pd
from pprint import pprint
plt.style.use(FIG_ROOT/"base.mplstyle")

In [5]:
pd.set_option('display.float_format', '{:.3f}'.format)

In [6]:
from collections import defaultdict as ddict
def fil_name2paras(fil_name):
    """Parse the hyper-parameters encoded in a result-file stem.

    The stem is split on "_". The second field is the replicate index and is
    returned as ``paras["rep"]`` (int). Every field containing "-" is a
    ``key<sep>value`` pair stored as a float:

    * ``key--v`` -> float("0.v")   ("--" marks a decimal, e.g. ``lr--001`` -> 0.001)
    * ``key-v``  -> float(v); a value starting with "0" is read as "0.v" for
      compatibility with older simulations (e.g. ``decay-01`` -> 0.01)

    Fields without "-" are ignored. Keys starting with "T", "decay" or "infeat"
    are renamed to "n_T", "weight_decay" and "n_infeat" (these names contain the
    "_" field separator, so e.g. "n_T-100" arrives here as the field "T-100").

    plz be careful about 0.1 and 1: "1" parses as 1.0, while "--1" or "01"
    parse as 0.1.
    """
    fields = fil_name.split("_")
    paras = {}
    for itm in fields:
        if "--" in itm:
            # split on the first separator only, so the value may contain "-"
            k, v = itm.split("--", 1)
            v = f"0.{v}"
        elif "-" in itm:
            # maxsplit=1 keeps values such as "1e-4" intact instead of raising
            k, v = itm.split("-", 1)
            # to be compatible with old simu
            if v.startswith("0"):
                v = f"0.{v}"
        else:
            continue

        if k.startswith("T"):
            k = "n_T"
        elif k.startswith("decay"):
            k = "weight_decay"
        elif k.startswith("infeat"):
            k = "n_infeat"
        paras[k] = float(v)
    paras["rep"] = int(fields[1])
    return paras


In [7]:
# c values whose results are loaded below; each c has its own results directory
# named realdata_n10000_c{c*100:.0f}
cs = [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.4, 0.8, 1.2]

[0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.4, 0.8, 1.2]

In [10]:
# get results of DDIM and naive
def _add_res(res, key, dict_res, paras=None, c_value=None):
    """Append one entry of a loaded results pickle to the column lists in ``dict_res``.

    Parameters
    ----------
    res : dict
        Loaded results; ``res[key][0]`` is stored as "ITE" and ``res[key][-1]`` as "Len".
    key : str
        Method key. If it is absent, ``"L" + key`` is tried; if that is absent
        too, nothing is appended and None is returned.
    dict_res : defaultdict(list)
        Column-wise accumulator; one value is appended to each column.
    paras : dict, optional
        Parameters parsed from the file name, one column each. Defaults to the
        module-level ``cur_paras`` (the original behaviour).
    c_value : float, optional
        The ``c`` of the results directory. Defaults to the module-level ``c``.
    """
    if key not in res.keys():
        key = "L" + key
        if key not in res.keys():
            return None
    # backward-compatible fallback to the loop globals the helper used to read
    if paras is None:
        paras = cur_paras
    if c_value is None:
        c_value = c
    dict_res["Len"].append(res[key][-1])
    dict_res["ITE"].append(res[key][0])
    # keys starting with "DDIM" are stored under an "L"-prefixed method label
    dict_res["method"].append("L" + key if key.startswith("DDIM") else key)
    for para_name, para_val in paras.items():
        dict_res[para_name].append(para_val)
    dict_res["c"].append(c_value)


our_res = ddict(list)
for c in cs:
    res_dir = RES_ROOT / f"realdata_n10000_c{c*100:.0f}"
    # sorted(): row order must not depend on the filesystem listing order
    for fil in sorted(res_dir.glob("*n_T-*.pkl")):
        cur_paras = fil_name2paras(fil.stem)
        res = load_pkl(fil, verbose=False)
        for ky in res.keys():
            _add_res(res, key=ky, dict_res=our_res, paras=cur_paras, c_value=c)

our_res_df = pd.DataFrame(our_res)

In [11]:
def _get_dataset(name):
    name = name.split("_ep")[0].split("_val")[0]
    if name.endswith("1"):
        dat_set = "set1"
    elif name.endswith("2"):
        dat_set = "set2"
    elif name.endswith("2c"):
        dat_set = "set2c"
    else:
        dat_set = "all_data"
    return dat_set

def _get_nep(x):
    vs = x.split("_ep")
    if len(vs) == 1:
        return 2000
    elif len(vs) == 2:
        return int(vs[1].split("_")[0])
_raw_method=lambda x: x.split("_")[0].split("1")[0].split("2")[0]
# derive per-row attributes from the method label, column by column
for _col, _mapper in {
    "data_set": _get_dataset,
    "is_val": lambda name: name.endswith("val"),
    "nep": _get_nep,
    "method_raw": _raw_method,
}.items():
    our_res_df[_col] = our_res_df["method"].map(_mapper)

In [12]:
# results of CQR and CF
def _add_ores(res, key, dict_res, rep=None):
    """Append one CQR/CF entry of a loaded results pickle to the column lists in ``dict_res``.

    ``res[key][0]`` is stored as "ITE" and ``res[key][-1]`` as "Len"; the key
    itself is stored as "method". If ``key`` is absent, nothing is appended
    and None is returned.

    ``rep`` is the replicate index. When omitted, it is parsed from the
    module-level ``ofil`` (the original behaviour, kept for compatibility).
    """
    if key not in res.keys():
        return None
    if rep is None:
        rep = _get_rep(ofil)
    dict_res["Len"].append(res[key][-1])
    dict_res["ITE"].append(res[key][0])
    dict_res["method"].append(key)
    dict_res["rep"].append(rep)


def _get_rep(p):
    """Replicate index: the second "_"-separated field of the path's stem."""
    return int(p.stem.split("_")[1])


# NOTE(review): the "other" (CQR/CF) results are read only from the c40
# directory — presumably they do not depend on c; confirm.
ores_dir = RES_ROOT / "realdata_n10000_c40"
# sorted(): row order must not depend on the filesystem listing order
ofils = sorted(ores_dir.glob("*other*.pkl"))

other_res = ddict(list)
for ofil in ofils:
    res = load_pkl(ofil, verbose=False)
    rep = _get_rep(ofil)
    for ky in res.keys():
        _add_ores(res, key=ky, dict_res=other_res, rep=rep)
other_res_df = pd.DataFrame(other_res)
other_res_df["data_set"] = other_res_df["method"].map(_get_dataset)
other_res_df["method_raw"] = other_res_df["method"].map(_raw_method)

In [13]:
# DDIM and MLP and naive results (no LCP): keep only the c == 0 rows
our_res_df0 = our_res_df.loc[our_res_df["c"] == 0].copy()
# strip a single leading "L" so these rows carry the base method name
our_res_df0["method_raw"] = our_res_df0["method_raw"].str.replace(r"^L", "", regex=True)

# Select the best hyper-parameter combination for each rep

In [14]:
# get opt obs based on val set
def _get_opt_obs_given_repix(rep_ix, data_set, all_res_df, method_key="DDPM", cutoff=None):
    cols = ["method_raw", "lr", "n_infeat", "n_T", 
            "weight_decay", "upblk", "downblk", 
            "rep", "c", "nep", "data_set"]
    kpidx = np.bitwise_and(all_res_df["rep"] == rep_ix, all_res_df["is_val"])
    kpidx = np.bitwise_and(kpidx, all_res_df["method_raw"]==method_key)
    kpidx = np.bitwise_and(kpidx, all_res_df["data_set"]==data_set)
    
    if cutoff is not None:
        kpidx1 = np.bitwise_and(kpidx, all_res_df["ITE"]>cutoff)
    if kpidx1.sum() > 0:
        kpidx = kpidx1
        best_val = all_res_df[kpidx].sort_values(by="Len").iloc[0]
    else:
        # if not ITE > target, use the one with largest ITE
        best_val = all_res_df[kpidx].sort_values(by="ITE").iloc[-1]
        
    mask = np.ones(all_res_df.shape[0], dtype=bool)
    mask = np.bitwise_and(mask, all_res_df["is_val"] == False)
    for col in cols:
        v = best_val[col]
        mask = np.bitwise_and(mask, all_res_df[col] == v)
    return all_res_df[mask]

In [23]:
data_sets = ["set1"]
# validation-coverage target passed to _get_opt_obs_given_repix
cutoff = 0.95
# (method label, frame to select from): DDIM / naive / MLP use only the c == 0
# rows, while LDDIM is selected from all c values
method_frames = [
    ("DDIM", our_res_df0),
    ("naive", our_res_df0),
    ("MLP", our_res_df0),
    ("LDDIM", our_res_df),
]

best_res_parts = []
# sorted(): set iteration order is not guaranteed; keep the run reproducible
for cur_rep in sorted(set(our_res_df["rep"])):
    for data_set in data_sets:
        for method_key, frame in method_frames:
            best_res_parts.append(
                _get_opt_obs_given_repix(cur_rep, data_set, frame, method_key, cutoff)
            )

best_res = pd.concat(best_res_parts)

In [29]:
kp_cols = ["Len", "ITE", "method_raw", "rep", "data_set"]
# stack the selected DDIM / LDDIM / MLP / naive rows with the CQR / CF rows
all_res_df = pd.concat([best_res[kp_cols], other_res_df[kp_cols]])

res = all_res_df.groupby(["method_raw", "data_set"])[["Len", "ITE"]].agg(["mean", "std", "count"])

# mean ± 1.96 * std / sqrt(count) over reps, for coverage (ITE) and interval length
for metric, label in (("ITE", "Coverage"), ("Len", "Len of Interval")):
    center = res[(metric, "mean")]
    half_width = 1.96 * res[(metric, "std")] / np.sqrt(res[(metric, "count")])
    res[(label, "Low")] = center - half_width
    res[(label, "Est")] = center
    res[(label, "High")] = center + half_width

In [30]:
# CI columns to show, and the (method_raw, "set1") rows
kycols = [column for column in res.columns if column[0] in ("Coverage", "Len of Interval")]
kyrows = [row for row in res.index if row[-1] == "set1"]

In [31]:
# coverage and interval-length CIs, restricted to the set1 rows
res[kycols].loc[kyrows]

Unnamed: 0_level_0,Unnamed: 1_level_0,Coverage,Coverage,Coverage,Len of Interval,Len of Interval,Len of Interval
Unnamed: 0_level_1,Unnamed: 1_level_1,Low,Est,High,Low,Est,High
method_raw,data_set,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2
CF,set1,0.11,0.119,0.127,0.237,0.242,0.247
CQR,set1,0.92,0.927,0.935,3.162,3.191,3.22
DDIM,set1,0.847,0.858,0.869,3.044,3.092,3.14
LDDIM,set1,0.938,0.948,0.957,4.79,5.066,5.341
MLP,set1,0.836,0.847,0.859,2.866,2.896,2.926
naive,set1,0.596,0.615,0.635,1.656,1.705,1.754
