In [1]:
import sys
sys.path.append("../mypkg")

In [2]:
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT
from utils.misc import load_pkl, save_pkl, merge_intervals
from utils.colors import qual_cmap
from utils.stats import weighted_quantile
from data_gen_utils.data_gen_my3 import get_simu_data
from utils.utils import MyDataSet, get_idx_sets
from demo_settings import simu_settings
from CQR import get_CQR_CIs, boosting_pred, boosting_logi, get_CF_CIs
from ddpm.train_ddpm_now import TrainDDPM
from weighted_conformal_inference import WeightedConformalInference
from local_weighted_conformal_inference import LocalWeightedConformalInference, get_opth
from naive_sample import NaiveSample

In [3]:
%load_ext autoreload
%autoreload 2
# 0,1, 2, 3, be careful about the space

In [4]:
import torch
import scipy.stats as ss
import numpy as np
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict as ddict
from tqdm import tqdm, trange
import random
from joblib import Parallel, delayed
import pandas as pd
from pprint import pprint
from copy import deepcopy
plt.style.use(FIG_ROOT/"base.mplstyle")

# Params

In [5]:
# Key into simu_settings; the selected entry is merged over the default simulation parameters.
setting = "setting7"

'setting7'

In [6]:
params = edict()

# Simulation design. Defaults come first, then the chosen setting overrides them.
# NOTE: built from a dict literal (not keyword arguments) because one key is "d".
params.simu_setting = edict({
    "rho": 0.9,
    "is_homo": False,
    "n": 3000,
    "d": 10,
    "ntest": 1000,
    "cal_ratio": 0.25, # for conformal inference
    "val_ratio": 0.15, # for tuning network
})
params.simu_setting.update(simu_settings[setting])
# n and d are re-assigned after the merge so they win over the setting's values
params.simu_setting.n = 3000
params.simu_setting.d = 10
pprint(params.simu_setting)


# Run-level options
for _opt_name, _opt_value in {
    "nrep": 50, # num of repetitions for simulation
    "K": 40, # num of sps drawn from q(Y(1)|X)
    "save_snapshot": 500,
    "df_dtype": torch.float32,
    "device": "cpu",
    "n_jobs": 20,
    "verbose": True,
    "inf_bs": 40, # the inference batch, fct x K
}.items():
    setattr(params, _opt_name, _opt_value)

# DDPM training hyper-parameters
params.ddpm_training = edict({
    # Batch size during training
    "batch_size": 256,
    # Number of training epochs
    "n_epoch": 1000,
    "n_infeat": 128,
    # Learning rate for optimizers
    "lr": 0.001,
    "lr_gamma": 0.5,
    "lr_step": 1000,
    "test_intv": 5,
    "n_T": 400, # 100
    "n_upblk": 1,
    "n_downblk": 1,
    "weight_decay": 1e-2,
    "early_stop": False,
    "early_stop_dict": {"early_stop_len":50, "early_stop_eps": 5e-4},
})
#params.ddpm_training.betas = [0.001, 0.5]

# Weighted conformal options:
# remove too large and too small in ws/mean(ws)
params.wconformal = edict({
    "nwthigh": 20,
    "nwtlow": 0.05,
    "useinf": False,
})


params.hypo_test = edict({
    "alpha": 0.05, # sig level
})

params.prefix = ""
params.save_dir = f"demo_ddpm_{setting}_test11"
# Create the output folder under RES_ROOT. exist_ok=True replaces the previous
# check-then-create guard (which could race between the check and the mkdir) and makes
# re-running this cell a no-op; parents=True also creates RES_ROOT if it is missing.
(RES_ROOT/params.save_dir).mkdir(parents=True, exist_ok=True)

{'cal_ratio': 0.25,
 'd': 10,
 'err_type': 't',
 'is_homo': False,
 'n': 3000,
 'ntest': 1000,
 'rho': 0.0,
 'val_ratio': 0.15}


In [7]:
# Use params.df_dtype as torch's default floating-point dtype for tensors created from here on.
torch.set_default_dtype(params.df_dtype)

In [8]:
keys = ["lr", "n_infeat", "n_T", "weight_decay", "n_upblk", "n_downblk"]
def _get_name_postfix(keys, ddpm_training):
    lst = []
    for key in keys:
        if ddpm_training[key] >= 1:
            lst.append(f"{key}-{str(ddpm_training[key])}")
        else:
            lst.append(f"{key}--{str(ddpm_training[key]).split('.')[-1]}")
    return "_".join(lst)

# Helper functions

In [149]:
# Set up one simulation repetition by hand: fix the seeds, override a few DDPM
# hyper-parameters, simulate train/test data, split the training indices and fit
# the propensity model used by wsfun.
# NOTE(review): the bare assignments below (rep_ix, params, lr, ...) read like the
# arguments of a per-repetition function unrolled for debugging; `params = params` is a no-op.
rep_ix = 2
params = params
lr = 1e-1
n_infeat = 128
n_T = 100
weight_decay = 1e-2
n_blk = 1
# The repetition index doubles as the seed for python, numpy and torch.
manualSeed = rep_ix
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
# Work on a deep copy so the overrides below do not modify the object params pointed to before.
params = edict(deepcopy(params))
params.ddpm_training.n_upblk = n_blk
params.ddpm_training.n_downblk = n_blk
params.ddpm_training.weight_decay = weight_decay
params.ddpm_training.n_T = n_T
params.ddpm_training.lr = lr
params.ddpm_training.n_infeat = n_infeat
# post_fix encodes these hyper-parameters; it is used in file-name prefixes later on.
keys = ["lr", "n_infeat", "n_T", "weight_decay", "n_upblk", "n_downblk"]
post_fix = _get_name_postfix(keys, params.ddpm_training)

torch.set_default_dtype(params.df_dtype)
torch.set_default_device(params.device)
    
# Simulate the training and test sets with the same design (is_homo is currently not passed).
data_train = get_simu_data(n=params.simu_setting.n, 
                           d=params.simu_setting.d, 
                           #is_homo=params.simu_setting.is_homo, 
                           rho=params.simu_setting.rho, 
                           err_type=params.simu_setting.err_type);
data_test = get_simu_data(n=params.simu_setting.ntest, 
                           d=params.simu_setting.d, 
                           #is_homo=params.simu_setting.is_homo, 
                           rho=params.simu_setting.rho,
                           err_type=params.simu_setting.err_type);


# Split indices into calibration / validation / training sets.
# NOTE(review): `data_train.T != 100` keeps every row whose T is not 100 - presumably all
# rows, if T is a 0/1 treatment indicator (confirm). The commented-out variant below
# would keep only rows with non-zero T.
cal_idxs, val_idxs, tr_idxs = get_idx_sets(all_idxs=np.where(data_train.T!=100)[0], 
                                                       ratios = [params.simu_setting.cal_ratio, params.simu_setting.val_ratio])
#cal_idxs, val_idxs, tr_idxs = get_idx_sets(all_idxs=np.where(data_train.T)[0], 
#                                                       ratios = [params.simu_setting.cal_ratio, params.simu_setting.val_ratio])
                
# get psfun
# Fit T (cast to int) on X with boosting_logi; fit_res is consumed by wsfun through boosting_pred.
psY = data_train.T.astype(int)
psX = data_train.X
fit_res = boosting_logi(psY, psX);
def wsfun(X):
    """Return inverse-propensity weights 1/(ps(X)+1e-10) as a torch tensor.

    args:
        X: covariates, either a torch tensor or a numpy array; a 1-D input is
           treated as a single row.
    return:
        Tensor of weights with dtype params.df_dtype on params.device.
    """
    tiny = 1e-10  # keeps the division finite when the estimated propensity is 0
    feats = X.cpu().numpy() if isinstance(X, torch.Tensor) else X
    if feats.ndim == 1:
        feats = feats.reshape(1, -1)
    inv_ps = 1/(boosting_pred(feats, fit_res)+tiny)
    return torch.tensor(inv_ps, dtype=params.df_dtype).to(device=params.device)
            
        
# Convert the calibration / validation / test arrays to tensors of the working dtype.
_to_tensor = lambda arr: torch.tensor(arr, dtype=params.df_dtype)
cal_X, cal_Y = _to_tensor(data_train.X[cal_idxs]), _to_tensor(data_train.Y1[cal_idxs])
val_X, val_Y = _to_tensor(data_train.X[val_idxs]), _to_tensor(data_train.Y1[val_idxs])
test_X, test_Y = _to_tensor(data_test.X), _to_tensor(data_test.Y1)

# get subset of all X
# Subsets are defined by the L2 norm of each test covariate row:
#   *1 / *1c : below / at-or-above the 10% quantile of the norms
#   *2 / *2c : below / at-or-above the median of the norms
test_Xnorm = torch.norm(test_X, p=2, dim=1)
cutoff1 = torch.quantile(test_Xnorm, 0.1)
cutoff2 = torch.quantile(test_Xnorm, 0.5)

_below1, _above1 = test_Xnorm<cutoff1, test_Xnorm>=cutoff1
_below2, _above2 = test_Xnorm<cutoff2, test_Xnorm>=cutoff2

test_X1, test_Y1 = test_X[_below1].clone(), test_Y[_below1].clone()
test_X1c, test_Y1c = test_X[_above1].clone(), test_Y[_above1].clone()
test_X2, test_Y2 = test_X[_below2].clone(), test_Y[_below2].clone()
test_X2c, test_Y2c = test_X[_above2].clone(), test_Y[_above2].clone()


In [150]:
# train q(Y(1)|X)
# Training set for the DDPM: observed Y and covariates X on the training indices.
data_train_ddpm = MyDataSet(Y=data_train.Y[tr_idxs], X=data_train.X[tr_idxs])
# Validation data: `c` carries the covariates and `x` the outcomes.
data_val = edict({"c": val_X, "x": val_Y})

# Everything in params.ddpm_training is forwarded to the TrainDDPM constructor
# except the options that are passed to .train() instead.
input_params = edict(deepcopy(params.ddpm_training))
for _train_only_key in ("n_epoch", "early_stop", "early_stop_dict"):
    input_params.pop(_train_only_key)

myddpm = TrainDDPM(
    data_train_ddpm,
    save_dir=params.save_dir,
    verbose=params.verbose,
    prefix=f"rep{rep_ix}_{post_fix}",
    device=params.device,
    **input_params,
)

# n_epoch is fixed to 100 here (params.ddpm_training.n_epoch is not used for this run).
myddpm.train(
    n_epoch=100,
    data_val=data_val,
    save_snapshot=params.save_snapshot,
    early_stop=params.ddpm_training.early_stop,
    early_stop_dict=params.ddpm_training.early_stop_dict,
)


2024-02-22 15:57:14,493 - ddpm.train_ddpm_now - INFO - The results are saved at /data/rajlab1/user_data/jin/MyResearch/DG-CITE_paper/notebooks/../mypkg/../results/demo_ddpm_setting7_test11.
2024-02-22 15:57:14,503 - ddpm.train_ddpm_now - INFO - The params is {'lr': 0.1, 'batch_size': 256, 'device': 'cpu', 'n_T': 100, 'n_infeat': 128, 'n_upblk': 1, 'n_downblk': 1, 'betas': [0.0001, 0.02], 'lr_gamma': 0.5, 'lr_step': 1000, 'test_intv': 5, 'weight_decay': 0.01}


The num of params is 0.30m. 
Adjusting learning rate of group 0 to 1.0000e-01.


loss: 0.8551: 100%|█████████████████████████████████████████████| 100/100 [00:08<00:00, 11.38it/s, val loss=1.12]


In [151]:
def _inner_fn(X, Y, ddpm, LCP=False, h=0.3, verbose=1):
    """Build conformal intervals for X with the given DDPM and score them against Y.

    The conformal object is calibrated on the notebook-level cal_X / cal_Y, uses
    DDIM generation with ddim_eta=1 and no weighting function (ws_fn=None).

    args:
        X: covariates to build intervals for.
        Y: outcomes matched to X, used only to measure coverage.
        ddpm: the generative model handed to LocalWeightedConformalInference.
        LCP: forwarded unchanged as `local_method`.
        h: forwarded as lm_params["h"].
        verbose: verbosity forwarded to the conformal object.
    return:
        (prbs, mlen): mean of the per-sample "Y falls in an interval" counts and
        the median total interval length.
    """
    def _total_lengths(intvs):
        # intvs is a list; each element is itself a list of intervals [lo, hi]
        return np.array([sum([np.diff(iv) for iv in intv])[0] for intv in intvs])

    def _hit_counts(vs, intvs):
        # for every value, count how many of its intervals strictly contain it
        return np.array([
            np.sum([np.bitwise_and(v>iv[0], v<iv[1]) for iv in intv])
            for v, intv in zip(vs, intvs)
        ])

    conformal_opts = {
        "K": params.K, # num of sps for each X
        "nwhigh" : params.wconformal.nwthigh,
        "nwlow" : params.wconformal.nwtlow,
        "useinf": params.wconformal.useinf,
    }
    wcf = LocalWeightedConformalInference(
        cal_X,
        cal_Y,
        ddpm,
        ws_fn=None,
        #ws_fn=wsfun,
        verbose=verbose,
        gen_type="ddim",
        seed=manualSeed,
        n_jobs=params.n_jobs,
        inf_bs=params.inf_bs,
        device=params.device,
        gen_params={"ddim_eta": 1},
        wcf_params=conformal_opts,
    )
    wcf.add_data(X)
    intvs = wcf(local_method=LCP, alpha=params.hypo_test.alpha, lm_params={"h":h})
    return np.mean(_hit_counts(Y, intvs)), np.median(_total_lengths(intvs))
ddpm = myddpm.ddpm

In [152]:
# Full test set, no local method: returns (coverage, median interval length).
_inner_fn(X=test_X, Y=test_Y, ddpm=ddpm, LCP=None, h=0.3, verbose=1)



(0.949, 15.601501)

In [153]:
# Test points whose covariate norm is below the 10% quantile, no local method.
_inner_fn(X=test_X1, Y=test_Y1, ddpm=ddpm, LCP=None, h=0.3, verbose=1)



(0.88, 15.661835)

In [154]:
# Complement of the previous subset (norm at or above the 10% quantile), no local method.
_inner_fn(X=test_X1c, Y=test_Y1c, ddpm=ddpm, LCP=None, h=0.3, verbose=1)



(0.96, 15.586548)

In [132]:
# Test points whose covariate norm is at or above the median, no local method.
_inner_fn(X=test_X2c, Y=test_Y2c, ddpm=ddpm, LCP=None, h=0.3, verbose=1)



(0.95, 15.028359)

In [131]:
# Test points whose covariate norm is below the median, no local method.
_inner_fn(X=test_X2, Y=test_Y2, ddpm=ddpm, LCP=None, h=0.3, verbose=1)



(0.92, 15.025183)

In [119]:
# Upper-half-norm subset with local_method "b-rlcp"; bandwidth h scales with sqrt(d).
_inner_fn(
    X=test_X2c,
    Y=test_Y2c,
    ddpm=ddpm,
    LCP="b-rlcp",
    h=0.8*np.sqrt(params.simu_setting.d),
    verbose=1,
)

> [0;32m/data/rajlab1/user_data/jin/MyResearch/DG-CITE_paper/mypkg/local_weighted_conformal_inference.py[0m(174)[0;36m__call__[0;34m()[0m
[0;32m    172 [0;31m        [0mintvs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mget_intvs[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mtest_Y_hat[0m[0;34m,[0m [0mqvs[0m[0;34m)[0m[0;34m;[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    173 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 174 [0;31m        [0;32mreturn[0m [0mintvs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    175 [0;31m[0;34m[0m[0m
[0m[0;32m    176 [0;31m[0;34m[0m[0m
[0m


ipdb>  qts


tensor([0.9585, 1.0870, 0.9557, 0.9552, 1.0161, 0.9567, 0.9534, 0.9706, 1.1381,
        0.9559, 0.9558, 0.9532, 1.0587, 0.9536, 0.9686, 1.2443, 1.4227, 1.0090,
        0.9560, 0.9532, 0.9578, 0.9588, 0.9578, 0.9539, 0.9699, 0.9704, 0.9527,
        0.9782, 0.9553, 0.9682,    inf, 0.9705, 1.0594, 0.9571, 0.9570, 0.9607,
        1.0018, 0.9688,    inf, 0.9535, 1.0886, 0.9539, 0.9528, 0.9540, 0.9584,
        0.9603, 0.9544, 0.9533, 0.9535, 0.9561, 0.9535, 0.9694, 1.1208, 0.9539,
        0.9530, 0.9555, 0.9554, 0.9644, 0.9545, 0.9551, 1.0512, 0.9586, 0.9646,
        0.9574, 1.0108, 0.9530, 0.9630, 0.9850, 0.9532, 0.9564, 0.9579, 0.9612,
        0.9566, 0.9574, 0.9565, 0.9551, 0.9567, 0.9608, 0.9577, 0.9701, 0.9566,
        0.9562, 0.9544, 0.9534, 0.9653, 1.0558, 1.0667, 0.9568, 0.9613, 0.9543,
        0.9536, 0.9563, 0.9553, 0.9532, 0.9550, 1.1238, 0.9595, 0.9560, 0.9628,
        0.9541, 0.9616, 0.9582, 0.9529, 0.9643, 0.9741, 0.9777, 0.9824, 0.9711,
        0.9546, 0.9543, 0.9977, 0.9570, 

ipdb>  tws_wtest[:, :-1].sum(axis=1)


tensor([307.2538, 270.9361, 308.1488, 308.3193, 289.8217, 307.8317, 308.8899,
        303.4350, 258.7679, 308.0736, 308.1057, 308.9638, 278.1680, 308.8303,
        304.0376, 236.6757, 207.0061, 291.8625, 308.0588, 308.9657, 307.4775,
        307.1674, 307.4829, 308.7367, 303.6412, 303.4731, 309.1338, 301.0640,
        308.2862, 304.1703,   0.0000, 303.4594, 277.9783, 307.6947, 307.7398,
        306.5439, 293.9598, 303.9774,   0.0000, 308.8686, 270.5422, 308.7429,
        309.1043, 308.7087, 307.2969, 306.6667, 308.5762, 308.9409, 308.8586,
        308.0253, 308.8696, 303.7866, 262.7523, 308.7470, 309.0164, 308.2072,
        308.2528, 305.3634, 308.5510, 308.3479, 280.1596, 307.2328, 305.3148,
        307.6194, 291.3672, 309.0399, 305.8209, 298.9729, 308.9702, 307.9235,
        307.4407, 306.3968, 307.8768, 307.6113, 307.8894, 308.3417, 307.8198,
        306.5191, 307.5146, 303.5745, 307.8755, 308.0058, 308.5685, 308.8933,
        305.1009, 278.9429, 276.0909, 307.7844, 306.3582, 308.60

ipdb>  nadd_ws


tensor([[  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   3.1959],
        [  0.0000,   0.0000,  38.7500,  ...,   0.0000,   0.0000,  38.7500],
        [  1.6316,   0.0000,   1.6316,  ...,   1.6316,   0.0000,   1.6316],
        ...,
        [  2.7679,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   2.7679],
        [  0.0000,   2.0946,   0.0000,  ...,   0.0000,   0.0000,   2.0946],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000, 310.0000]])


ipdb>  add_ws


tensor([[0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 1.,  ..., 0., 0., 1.],
        [1., 0., 1.,  ..., 1., 0., 1.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 1.],
        [0., 1., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.]])


ipdb>  add_ws.sum(axis=1)


tensor([ 97.,   8., 190., 287.,  11., 124., 286.,  42.,   9., 188., 135., 238.,
          9., 259.,  65.,   5.,   3.,  31., 129., 256., 213.,  92., 105., 308.,
         62.,  44., 306.,  38., 173.,  47.,   1.,  38.,  10., 143., 136.,  74.,
         20.,  81.,   1., 303.,   7., 197., 283., 243.,  79.,  86., 193., 306.,
        241., 273., 236.,  44.,   6., 310., 283., 165., 161.,  57., 234., 221.,
         12., 163.,  63., 235.,  14., 274.,  64.,  26., 310., 168., 109.,  86.,
        167., 112., 135., 259., 111.,  88., 134.,  59., 271., 144., 184., 223.,
         76.,   9.,  10., 188., 108., 256., 216., 117., 226., 263., 197.,   7.,
        112., 141.,  66., 205.,  73., 217., 296., 138.,  37.,  34.,  26.,  50.,
        189., 203.,  19., 106.,  14., 296.,  26.,  76.,  97.,  72., 213.,  36.,
        210.,  41., 308.,  35.,  31.,  51., 287.,  23., 297.,  71.,  14.,  95.,
        214.,  52.,  22., 200., 253., 119., 221.,  10., 204.,  21., 258., 118.,
        227.,  76., 130., 277., 124., 30

ipdb>  torch.where(add_ws.sum(axis=1)==1)


(tensor([ 30,  38, 244, 464, 499]),)


ipdb>  add_ws[30]


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 

ipdb>  q


BdbQuit: 

In [117]:
# Lower-half-norm subset with local_method "b-rlcp"; bandwidth h scales with sqrt(d).
_inner_fn(
    X=test_X2,
    Y=test_Y2,
    ddpm=ddpm,
    LCP="b-rlcp",
    h=0.8*np.sqrt(params.simu_setting.d),
    verbose=1,
)

(0.948, 15.063553)

In [113]:
# Lower-half-norm subset with local_method "g-rlcp" and a fixed bandwidth of 0.6.
_inner_fn(
    X=test_X2,
    Y=test_Y2,
    ddpm=ddpm,
    LCP="g-rlcp",
    h=0.6,
    verbose=1,
)

(0.952, 15.178436)

In [None]:
        
            
            
        
        # NOTE(review): this cell is indented as if it were the body of an enclosing
        # function / `if` whose header is not present in the notebook, so it cannot
        # run as-is. Restore the missing header before re-running.
        # Pick the bandwidth opt_h from the training covariates via get_opth.
        tmpX=torch.tensor(data_train.X[tr_idxs], dtype=params.df_dtype);
        opt_h = get_opth(X=tmpX, hmin=0.1, local_method="RLCP", target_eff_size=None)
        # get the results
        res_all = edict()
        
        # results under the final model
        ddpm = myddpm.ddpm
        ddpm.eval()
        # NOTE(review): these calls pass gen_type=..., which the _inner_fn(X, Y, ddpm, LCP, h, verbose)
        # defined earlier in this notebook does not accept; _inner_fn1 further down does take
        # gen_type and a boolean LCP - confirm which helper is intended.
        # Suffixes 1 / 2 / 2c select the norm-based test subsets; _val uses the validation split.
        res_all.DDIM = _inner_fn(test_X, test_Y, ddpm, gen_type="ddim", LCP=False)
        res_all.DDIM1 = _inner_fn(test_X1, test_Y1, ddpm, gen_type="ddim", LCP=False)
        res_all.DDIM2 = _inner_fn(test_X2, test_Y2, ddpm, gen_type="ddim", LCP=False)
        res_all.DDIM2c = _inner_fn(test_X2c, test_Y2c, ddpm, gen_type="ddim", LCP=False)
        res_all.DDIM_val = _inner_fn(val_X, val_Y, ddpm, gen_type="ddim", LCP=False)
        
        # Same subsets with LCP=True (the "L" prefix).
        res_all.LDDIM = _inner_fn(test_X, test_Y, ddpm, gen_type="ddim", LCP=True)
        res_all.LDDIM1 = _inner_fn(test_X1, test_Y1, ddpm, gen_type="ddim", LCP=True)
        res_all.LDDIM2 = _inner_fn(test_X2, test_Y2, ddpm, gen_type="ddim", LCP=True)
        res_all.LDDIM2c = _inner_fn(test_X2c, test_Y2c, ddpm, gen_type="ddim", LCP=True)
        
   
        # results from CQR
        def _CQR_fn(test_X, test_Y):
            """Coverage and median length of the CQR intervals on (test_X, test_Y).

            The intervals come from get_CQR_CIs fitted on the full data_train
            (X, Y, T) at level params.hypo_test.alpha. Torch inputs are converted
            to numpy first.

            return:
                (coverage, median interval length); coverage is the fraction of
                test_Y values strictly inside their interval.
            """
            X_np = test_X.cpu().numpy() if isinstance(test_X, torch.Tensor) else test_X
            Y_np = test_Y.cpu().numpy() if isinstance(test_Y, torch.Tensor) else test_Y
            cis = get_CQR_CIs(
                X=data_train.X,
                Y=data_train.Y,
                T=data_train.T,
                Xtest=X_np,
                nav=0,
                alpha=params.hypo_test.alpha,
                estimand="unconditional",
                fyx_est="quantBoosting",
                seed=manualSeed,
            )
            lower, upper = cis[:, 0], cis[:, 1]
            median_len = np.median(upper - lower)
            coverage = np.bitwise_and(Y_np>lower, Y_np<upper).mean()
            return coverage, median_len
        # CQR baseline on the full test set and on the norm-based subsets.
        res_all.CQR = _CQR_fn(test_X, test_Y)
        res_all.CQR1 = _CQR_fn(test_X1, test_Y1)
        res_all.CQR2 = _CQR_fn(test_X2, test_Y2)
        res_all.CQR2c = _CQR_fn(test_X2c, test_Y2c)
        
     
        
        
        # Persist the results for this repetition, then delete the saved DDPM snapshot files.
        save_pkl((RES_ROOT/params.save_dir)/f"rep_{rep_ix}_{post_fix}_res.pkl", res_all, is_force=True)
        all_models = list(myddpm.save_dir.glob(f"{myddpm.prefix}ddpm_epoch*.pth"))
        [m.unlink() for m in all_models]
    # NOTE(review): this `else` has no matching `if` in the notebook and `fil_name` is never
    # defined here - presumably the guard was "skip when the result file already exists";
    # restore it before re-running.
    else:
        res_all = edict()
        print(f"As {fil_name} exists, we do not do anything")

In [15]:
# Display the collected results: one (coverage, median interval length) pair per method/subset.
res_all

{'DDIM': [0.964, 14.017761],
 'DDIM1': [0.96, 13.571974],
 'DDIM2': [0.962, 13.748566],
 'DDIM2c': [0.97, 14.220329],
 'DDIM_val': [0.9783783783783784, 14.089476],
 'LDDIM': [0.966, 14.068567],
 'LDDIM1': [0.96, 13.611632],
 'LDDIM2': [0.962, 13.841884],
 'LDDIM2c': [0.972, 14.267017],
 'CQR': [0.959, 6.2184535708958135],
 'CQR1': [0.968, 7.071869951586454],
 'CQR2': [0.954, 6.647186392210061],
 'CQR2c': [0.964, 5.935847625521185]}

In [15]:
# NOTE(review): same display as the preceding cell (identical execution count) - likely a duplicate.
res_all

{'DDIM': [0.964, 14.017761],
 'DDIM1': [0.96, 13.571974],
 'DDIM2': [0.962, 13.748566],
 'DDIM2c': [0.97, 14.220329],
 'DDIM_val': [0.9783783783783784, 14.089476],
 'LDDIM': [0.966, 14.068567],
 'LDDIM1': [0.96, 13.611632],
 'LDDIM2': [0.962, 13.841884],
 'LDDIM2c': [0.972, 14.267017],
 'CQR': [0.959, 6.2184535708958135],
 'CQR1': [0.968, 7.071869951586454],
 'CQR2': [0.954, 6.647186392210061],
 'CQR2c': [0.964, 5.935847625521185]}

In [16]:
def _inner_fn1(X, Y, ddpm, gen_type="ddpm", gen_params={"ddim_eta": 1}, LCP=False):
    """Build weighted conformal intervals for X and score them against Y.

    The conformal object is calibrated on the notebook-level cal_X / cal_Y and
    weighted with wsfun. The bandwidth is read from the notebook-level `opt_h`,
    which must be defined before calling.

    args:
        X: covariates to build intervals for.
        Y: outcomes matched to X, used only to measure coverage.
        ddpm: the generative model handed to LocalWeightedConformalInference.
        gen_type: generation type forwarded to the conformal object.
        gen_params: generation options; a copy is forwarded, so the shared
            default dict is never modified.
        LCP: truthy -> local_method="RLCP"; falsy -> local_method=None.
    return:
        (prbs, mlen): mean of the per-sample "Y falls in an interval" counts and
        the median total interval length.
    """
    # get the len of CI based on intvs, there intvs is a list, each ele is another list contains CIs ele=[CI1, CI2]
    def _get_intvs_len(intvs):
        return np.array([sum([np.diff(iv) for iv in intv])[0] for intv in intvs])

    # get whether each value in vs is in one of its CIs in intvs or not
    def _get_inset(vs, intvs):
        in_set = []
        for v, intv in zip(vs, intvs):
            in_set.append(np.sum([np.bitwise_and(v>iv[0], v<iv[1]) for iv in intv]))
        # fix: the original returned the undefined name `in_s` (NameError on every call)
        return np.array(in_set)

    wcf = LocalWeightedConformalInference(cal_X,
                         cal_Y,
                         ddpm, ws_fn=wsfun, verbose=2,
                         gen_type=gen_type,
                         seed=manualSeed,
                         n_jobs=params.n_jobs,
                         inf_bs=params.inf_bs,
                         device=params.device,
                         # pass a copy: the default dict is shared across calls and the
                         # callee may add keys to what it receives (not verified here)
                         gen_params=deepcopy(gen_params),
                         wcf_params={
                            "K": params.K, # num of sps for each X
                            "nwhigh" : params.wconformal.nwthigh,
                            "nwlow" : params.wconformal.nwtlow,
                            "useinf": params.wconformal.useinf,
                         })
    wcf.add_data(X)
    local_method = "RLCP" if LCP else None
    intvs = wcf(local_method=local_method, alpha=params.hypo_test.alpha, lm_params={"h":opt_h})
    prbs = np.mean(_get_inset(Y, intvs))
    mlen = np.median(_get_intvs_len(intvs))
    return prbs, mlen
        
            
            
        


In [18]:
# Switch the trained model to eval mode before generating samples.
ddpm = myddpm.ddpm
ddpm.eval()
# NOTE(review): this call passes gen_type=..., which the _inner_fn(X, Y, ddpm, LCP, h, verbose)
# defined earlier in this notebook does not accept (it would raise TypeError on a fresh run);
# _inner_fn1 does take gen_type and a boolean LCP - confirm which helper is intended.
_inner_fn(test_X1, test_Y1, ddpm, gen_type="ddim", LCP=False)

2024-02-21 20:34:18,998 - weighted_conformal_inference - INFO - wcf params is {'K': 40, 'nwhigh': 20, 'nwlow': 0.05, 'useinf': False, 'cf_type': 'PCP'}
INFO:weighted_conformal_inference:wcf params is {'K': 40, 'nwhigh': 20, 'nwlow': 0.05, 'useinf': False, 'cf_type': 'PCP'}
2024-02-21 20:34:19,000 - weighted_conformal_inference - INFO - gen params is {'ddim_timesteps': 50, 'ddim_eta': 1, 'gen_type': 'ddim'}
INFO:weighted_conformal_inference:gen params is {'ddim_timesteps': 50, 'ddim_eta': 1, 'gen_type': 'ddim'}
100%|████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 3802.63it/s]
100%|████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 3962.23it/s]


(0.96, 13.571974)