In [1]:
import sys
# Make the project package directory (constants, utils, CQR, ...) importable.
# Guarded so re-running this cell does not keep appending the same entry.
if "../mypkg" not in sys.path:
    sys.path.append("../mypkg")

In [40]:
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl, merge_intervals
from utils.colors import qual_cmap
from utils.stats import weighted_quantile
from data_gen_utils.data_gen_my2 import get_simu_data
from utils.utils import MyDataSet, get_idx_sets
from demo_settings import simu_settings
from CQR import get_CQR_CIs, boosting_pred, boosting_logi, get_CF_CIs
from mlp.train_mlp import TrainMLP
from ddpm.train_ddpm_now import TrainDDPM
from weighted_conformal_inference import WeightedConformalInference
from local_weighted_conformal_inference import LocalWeightedConformalInference, get_opth
from naive_sample import NaiveSample

In [41]:
%load_ext autoreload
%autoreload 2
# 0,1, 2, 3, be careful about the space

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [42]:
import torch
import scipy.stats as ss
import numpy as np
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict as ddict
from tqdm import tqdm, trange
import random
from joblib import Parallel, delayed
import pandas as pd
from pprint import pprint
from copy import deepcopy
# Apply the shared matplotlib style file FIG_ROOT/base.mplstyle to all figures.
base_style_path = FIG_ROOT / "base.mplstyle"
plt.style.use(base_style_path)

# Parameters

In [107]:
# Key into `simu_settings`; also used in the results directory name.
setting = "setting3"
setting

'setting3'

In [108]:
params = edict()

# Simulation defaults. Any key also present in simu_settings[setting] is
# overwritten by the update below (in the recorded output, rho and cal_ratio
# come from the setting, not from these defaults).
params.simu_setting = edict()
params.simu_setting.rho = 0.9
params.simu_setting.is_homo = False
params.simu_setting.n = 3000
params.simu_setting.d = 10
params.simu_setting.ntest = 1000
params.simu_setting.cal_ratio = 0.45 # for conformal inference
params.simu_setting.val_ratio = 0.05 # for tuning network
params.simu_setting.update(simu_settings[setting])
# Explicit overrides, applied after the setting-specific values.
params.simu_setting.n = 3000
params.simu_setting.d = 100
pprint(params.simu_setting)


params.nrep = 50 # num of repetitions for simulation
params.K = 40 # num of samples drawn from q(Y(1)|X)
params.save_snapshot = 500 # passed as save_snapshot to training (presumably a snapshot interval in epochs)
params.df_dtype = torch.float32
params.device="cpu"
params.n_jobs = 20
params.verbose = True
params.inf_bs = 40 # the inference batch, fct x K -- TODO(review): clarify meaning

params.ddpm_training = edict()
# Batch size during training
params.ddpm_training.batch_size = 256
# Number of training epochs
params.ddpm_training.n_epoch = 1000
params.ddpm_training.n_infeat = 128
# Learning rate for optimizers
params.ddpm_training.lr = 0.001
params.ddpm_training.lr_gamma = 0.5
# presumably a StepLR-style step size; with n_epoch = 1000 the lr decays at most once
params.ddpm_training.lr_step = 1000
params.ddpm_training.test_intv = 5
params.ddpm_training.n_T = 400 # 100
params.ddpm_training.n_upblk = 1
params.ddpm_training.n_downblk = 1
params.ddpm_training.weight_decay = 1e-2
params.ddpm_training.early_stop = False
params.ddpm_training.early_stop_dict = {"early_stop_len":50, "early_stop_eps": 5e-4}
#params.ddpm_training.betas = [0.001, 0.5]

params.wconformal = edict()
# remove weights that are too large or too small relative to ws/mean(ws)
params.wconformal.nwthigh = 20
params.wconformal.nwtlow = 0.05
params.wconformal.useinf = False


params.hypo_test = edict()
params.hypo_test.alpha = 0.05 # sig level

params.prefix = ""
params.save_dir = f"demo_ddpm_{setting}_test11"
# parents=True also creates RES_ROOT if missing; exist_ok avoids a check-then-create race.
(RES_ROOT/params.save_dir).mkdir(parents=True, exist_ok=True)

{'cal_ratio': 0.25,
 'd': 100,
 'err_type': 'norm',
 'is_homo': False,
 'n': 3000,
 'ntest': 1000,
 'rho': 0.0,
 'val_ratio': 0.05}


In [109]:
# New floating-point tensors default to the configured dtype (params.df_dtype).
_default_dtype = params.df_dtype
torch.set_default_dtype(_default_dtype)

In [110]:
keys = ["lr", "n_infeat", "n_T", "weight_decay", "n_upblk", "n_downblk"]
def _get_name_postfix(keys, ddpm_training):
    lst = []
    for key in keys:
        if ddpm_training[key] >= 1:
            lst.append(f"{key}-{str(ddpm_training[key])}")
        else:
            lst.append(f"{key}--{str(ddpm_training[key]).split('.')[-1]}")
    return "_".join(lst)

In [111]:
# One simulation replicate: fix seeds, set this run's network hyperparameters,
# simulate train/test data, fit the propensity model, and build the calibration,
# validation, test and training sets used by the cells below.
rep_ix = 1
lr = 1e-2
n_infeat = 256
# n_T was previously never assigned in this notebook (the cell only worked via
# leftover kernel state); 400 matches the config cell and the saved snapshot names.
n_T = 400
weight_decay = 1e-2
n_blk = 3  # used for both n_upblk and n_downblk

manualSeed = rep_ix
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

# Rebind `params` to a deep copy, then override this run's hyperparameters.
# The overrides are absolute assignments, so re-running the cell is idempotent.
params = edict(deepcopy(params))
params.ddpm_training.n_upblk = n_blk
params.ddpm_training.n_downblk = n_blk
params.ddpm_training.weight_decay = weight_decay
params.ddpm_training.n_T = n_T
params.ddpm_training.lr = lr
params.ddpm_training.n_infeat = n_infeat
keys = ["lr", "n_infeat", "n_T", "weight_decay", "n_upblk", "n_downblk"]
post_fix = _get_name_postfix(keys, params.ddpm_training)

torch.set_default_dtype(params.df_dtype)
torch.set_default_device(params.device)

# Result file paths for this replicate (not read or written in this cell).
fil_name = (RES_ROOT/params.save_dir)/f"rep_{rep_ix}_{post_fix}_res.pkl"
ofil_name = (RES_ROOT/params.save_dir)/f"rep_{rep_ix}_others_res.pkl"

# Data is always regenerated. A former `if not fil_name.exists():` cache guard
# is disabled; re-enabling it as-is would leave the variables that later cells
# use (data_train_ddpm, cal_X, test_Y, ...) undefined when the file exists.
data_train = get_simu_data(n=params.simu_setting.n,
                           d=params.simu_setting.d,
                           rho=params.simu_setting.rho,
                           err_type=params.simu_setting.err_type)
data_test = get_simu_data(n=params.simu_setting.ntest,
                          d=params.simu_setting.d,
                          rho=params.simu_setting.rho,
                          err_type=params.simu_setting.err_type)

# Split only the treated units (indices where data_train.T is true) into
# calibration / validation / training index sets.
cal_idxs, val_idxs, tr_idxs = get_idx_sets(all_idxs=np.where(data_train.T)[0],
                                           ratios=[params.simu_setting.cal_ratio, params.simu_setting.val_ratio])

# Propensity model: boosted logistic fit of T on X over all training units.
psY = data_train.T.astype(int)
psX = data_train.X
fit_res = boosting_logi(psY, psX)


def wsfun(X):
    """Conformal weight function: 1 / (estimated propensity + 1e-10).

    Accepts a numpy array or a torch tensor (moved to CPU); a 1-D input is
    treated as a single row. Returns a tensor with dtype params.df_dtype on
    params.device. Assumes boosting_pred returns the estimated P(T=1 | X)
    -- TODO confirm.
    """
    eps = 1e-10  # avoids division by zero for near-zero propensities
    if isinstance(X, torch.Tensor):
        X = X.cpu().numpy()
    if X.ndim == 1:
        X = X.reshape(1, -1)
    est_ps = boosting_pred(X, fit_res)
    est_ws = 1/(est_ps+eps)
    return torch.tensor(est_ws, dtype=params.df_dtype).to(device=params.device)


cal_X = torch.tensor(data_train.X[cal_idxs], dtype=params.df_dtype)
cal_Y = torch.tensor(data_train.Y1[cal_idxs], dtype=params.df_dtype)
val_X = torch.tensor(data_train.X[val_idxs], dtype=params.df_dtype)
val_Y = torch.tensor(data_train.Y1[val_idxs], dtype=params.df_dtype)
test_X = torch.tensor(data_test.X, dtype=params.df_dtype)
test_Y = torch.tensor(data_test.Y1, dtype=params.df_dtype)

# train q(Y(1)|X) on the treated training units (observed Y presumably equals Y(1) there)
data_train_ddpm = MyDataSet(Y=data_train.Y[tr_idxs], X=data_train.X[tr_idxs])
data_val = edict()
data_val.c = val_X  # covariates
data_val.x = val_Y  # outcomes
# Constructor kwargs for TrainMLP: drop the options that are passed to .train() instead.
input_params = edict(deepcopy(params.ddpm_training))
for _train_loop_key in ("n_epoch", "early_stop", "early_stop_dict"):
    input_params.pop(_train_loop_key)

In [112]:
# Build the MLP trainer for this replicate; the prefix tags saved snapshot files
# with the replicate index and the hyperparameter postfix.
mlp_run_prefix = f"rep{rep_ix}_{post_fix}"
mlp_fit = TrainMLP(
    data_train_ddpm,
    save_dir=params.save_dir,
    verbose=params.verbose,
    prefix=mlp_run_prefix,
    device=params.device,
    **input_params,
)

2024-02-13 19:14:38,957 - mlp.train_mlp - INFO - The results are saved at /data/rajlab1/user_data/jin/MyResearch/DG-CITE_paper/notebooks/../mypkg/../results/demo_ddpm_setting3_test11.
2024-02-13 19:14:38,979 - mlp.train_mlp - INFO - The params is {'lr': 0.01, 'batch_size': 256, 'device': 'cpu', 'n_infeat': 256, 'n_downblk': 3, 'lr_gamma': 0.5, 'lr_step': 1000, 'test_intv': 5, 'weight_decay': 0.01}


The num of params is 1.57m. 
Adjusting learning rate of group 0 to 1.0000e-02.


<mlp.train_mlp.TrainMLP at 0x7f4b23345be0>

In [113]:
# Train for the configured number of epochs, reporting loss on the validation
# split; snapshots are saved per params.save_snapshot (the log shows epochs 500 and 1000).
train_cfg = params.ddpm_training
mlp_fit.train(
    n_epoch=train_cfg.n_epoch,
    data_val=data_val,
    save_snapshot=params.save_snapshot,
    early_stop=train_cfg.early_stop,
    early_stop_dict=train_cfg.early_stop_dict,
)

loss: 0.3833:  50%|█████████████████████▉                      | 498/1000 [00:33<00:33, 15.19it/s, val loss=3.16]2024-02-13 19:15:13,458 - mlp.train_mlp - INFO - Save model rep1_lr--01_n_infeat-256_n_T-400_weight_decay--01_n_upblk-3_n_downblk-3_mlp_epoch500.pth.
loss: 0.2650: 100%|███████████████████████████████████████████▉| 998/1000 [01:07<00:00, 14.72it/s, val loss=6.59]2024-02-13 19:15:47,733 - mlp.train_mlp - INFO - Save model rep1_lr--01_n_infeat-256_n_T-400_weight_decay--01_n_upblk-3_n_downblk-3_mlp_epoch1000.pth.
loss: 0.2650: 100%|███████████████████████████████████████████| 1000/1000 [01:08<00:00, 14.71it/s, val loss=6.59]

Adjusting learning rate of group 0 to 5.0000e-03.





In [100]:
# Weighted conformal intervals for the test points. The epoch-500 MLP snapshot
# is the generator; calibration uses (cal_X, cal_Y) and weights come from wsfun.
# NOTE(review): the recorded log was produced before the training cell ran
# (In[100] < In[112]) and loaded from demo_ddpm_setting1_test11; re-run to refresh.
snapshot_epoch = 500
net = mlp_fit.get_model(snapshot_epoch)
wcf = WeightedConformalInference(
    cal_X,
    cal_Y,
    gen_fn=net,
    ws_fn=wsfun,
    gen_type="reg",
)
CIs = wcf(test_X)

2024-02-13 19:13:56,230 - mlp.train_mlp - INFO - We load model /data/rajlab1/user_data/jin/MyResearch/DG-CITE_paper/notebooks/../mypkg/../results/demo_ddpm_setting1_test11/rep1_lr--01_n_infeat-256_n_T-400_weight_decay--01_n_upblk-3_n_downblk-3_mlp_epoch500.pth.
2024-02-13 19:13:56,241 - weighted_conformal_inference - INFO - wcf params is {'K': 40, 'nwhigh': 20, 'nwlow': 0.05, 'useinf': False, 'cf_type': 'naive'}
2024-02-13 19:13:56,242 - weighted_conformal_inference - INFO - gen params is {'gen_type': 'reg'}


In [101]:
_get_intvs_len = lambda intvs: np.array([np.sum([np.diff(iv) for iv in intv]) for intv in intvs]);
def _get_inset(vs, intvs):
    in_set = []
    for v, intv in zip(vs, intvs):
        in_set.append(np.sum([np.bitwise_and(v>iv[0], v<iv[1]) for iv in intv]))
    in_set = np.array(in_set)
    return in_set
# (median total interval length, empirical coverage) over the test set
(np.median(_get_intvs_len(CIs)), _get_inset(test_Y, CIs).mean())

(6.924099, 0.948)