In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import mean_squared_error
from scipy.special import expit  # Sigmoid function for propensity score
from kernel_regression import KernelRegression
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

In [2]:
from sklearn.neural_network import MLPClassifier, MLPRegressor

In [3]:
from pathlib import Path
import os
import glob
from joblib import dump, load
import pandas as pd
import scipy
import scipy.stats
import scipy.special
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from utils.riesznet import RieszNet
from utils.moments import ate_moment_fn
from utils.ihdp_data import *

In [4]:
moment_fn = ate_moment_fn

In [5]:
drop_prob = 0.0  # dropout prob of dropout layers throughout notebook
n_hidden = 100  # width of hidden layers throughout notebook

# Training params
learner_lr = 1e-5
learner_l2 = 1e-3
learner_l1 = 0.0
n_epochs = 600
earlystop_rounds = 40 # how many epochs to wait for an out-of-sample improvement
earlystop_delta = 1e-4
target_reg = 1.0
riesz_weight = 0.1

bs = 64
device = torch.cuda.current_device() if torch.cuda.is_available() else None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200
        self.monomials = list(_combinations(n_t, degree, interaction_only))
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.riesz_nn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))
        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))


    def forward(self, x):
        poly = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)
                          for t in self.monomials], dim=1)
        feats = self.common(x)
        riesz = self.riesz_nn(feats) + self.riesz_poly(poly)
        reg = self.reg_nn0(feats) * (1 - x[:, [0]]) + self.reg_nn1(feats) * x[:, [0]] + self.reg_poly(poly)
        return torch.cat([reg, riesz], dim=1)

GPU: False


In [6]:
drop_prob = 0.0  # dropout prob of dropout layers throughout notebook
n_hidden = 100  # width of hidden layers throughout notebook

# Training params
learner_lr = 1e-5
learner_l2 = 1e-3
learner_l1 = 0.0
n_epochs = 600
earlystop_rounds = 40 # how many epochs to wait for an out-of-sample improvement
earlystop_delta = 1e-4
target_reg = 1.0
riesz_weight = 0.1

bs = 64
device = torch.cuda.current_device() if torch.cuda.is_available() else None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200
        self.monomials = list(_combinations(n_t, degree, interaction_only))
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.riesz_nn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))
        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))


    def forward(self, x):
        poly = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)
                          for t in self.monomials], dim=1)
        feats = self.common(x)
        riesz = self.riesz_nn(feats) + self.riesz_poly(poly)
        reg = self.reg_nn0(feats) * (1 - x[:, [0]]) + self.reg_nn1(feats) * x[:, [0]] + self.reg_poly(poly)
        return torch.cat([reg, riesz], dim=1)

GPU: False


In [12]:
# Simulation parameters
n = 3000 # Number of samples
p = 3   # Number of covariates
treatment_effect = 5.0  # True treatment effect

# Generate covariates
np.random.seed(0)

In [13]:
true_ATE = treatment_effect

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class NeuralNetBiasCorrection:
    def __init__(self, input_dim, hidden_dim=100, max_iter=600, tol=1e-4, lbd=0.01, loss="Logit", lr=0.01):
        """
        ニューラルネットワークを用いたバイアス補正モデル
        :param input_dim: 入力特徴量の次元
        :param hidden_dim: 隠れ層のユニット数
        :param max_iter: 最大反復回数
        :param tol: 収束許容誤差
        :param lbd: 正則化項の重み
        :param loss: 損失関数の種類（"Logit", "DBCLS", "DBCLogit"）
        :param lr: 学習率
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.max_iter = max_iter
        self.tol = tol
        self.lbd = lbd
        self.lr = lr
        
        # ニューラルネットワークの構築
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        if loss == "Logit":
            self.criterion = self._logistic_loss_function
        elif loss == "CBPS":
            self.criterion = self._cbps_loss_function
        elif loss == "DBCLS":
            self.criterion = self._least_squares_loss_function
        elif loss == "DBCUKL":
            self.criterion = self._ukl_loss_function
        else:
            raise ValueError("Invalid loss function specified")

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        
    
    def _least_squares_loss_function(self, X_tensor, outputs, targets):
        """最小二乗誤差損失関数"""
        outputs = torch.clamp(outputs, min=0.05, max=0.95)
        loss = torch.mean(-2 * (1 / outputs + 1 / (1 - outputs)) + (targets / outputs - (1 - targets) / (1 - outputs))**2)
        return loss + self.lbd * self._l2_regularization()
    
    def _ukl_loss_function(self, X_tensor, outputs, targets):
        """制約付きロジスティック損失関数"""
        outputs = torch.clamp(outputs, min=0.05, max=0.95)
        # Correcting the formula with the proper variable names
        loss = - torch.log(1/outputs) - torch.log(1/(1 - outputs)) + targets / outputs + (1 - targets) / (1 - outputs)
        # Mean loss calculation
        loss = torch.mean(loss)
        return loss + self.lbd * self._l2_regularization()
    
    def _logistic_loss_function(self, X_tensor, outputs, targets):
        """ロジスティック損失関数"""
        outputs = torch.clamp(outputs, min=0.05, max=0.95)
        loss = nn.BCELoss()(outputs, targets)  # バイナリクロスエントロピー
        return loss + self.lbd * self._l2_regularization()
    
    def _cbps_loss_function(self, X_tensor, outputs, targets):
        """ロジスティック損失関数"""
        outputs = torch.clamp(outputs, min=0.05, max=0.95)
        loss =  torch.mean((targets * X_tensor / outputs - (1 - targets) * X_tensor / (1 - outputs))**2, axis=0)
        loss = (loss**2).mean()
        return loss + self.lbd * self._l2_regularization()

    def _l2_regularization(self):
        """L2正則化項の計算"""
        reg_loss = sum(torch.sum(param**2) for param in self.model.parameters())
        return reg_loss
    
    def fit(self, X, T):
        """
        モデルの学習
        :param X: 説明変数（N×d の配列）
        :param T: 目的変数（N×1 のバイナリ配列）
        """
        X_tensor = torch.tensor(X, dtype=torch.float32)
        T_tensor = torch.tensor(T, dtype=torch.float32).view(-1, 1)

        prev_loss = float('inf')
        for epoch in range(self.max_iter):
            self.optimizer.zero_grad()
            outputs = self.model(X_tensor)
            loss = self.criterion(X_tensor, outputs, T_tensor)
            loss.backward()
            self.optimizer.step()

            # 収束判定
            if abs(prev_loss - loss.item()) < self.tol:
                break
            prev_loss = loss.item()

    def predict_proba(self, X):
        """
        予測確率の計算
        :param X: 説明変数（N×d の配列）
        :return: クラス1の確率
        """
        X_tensor = torch.tensor(X, dtype=torch.float32)
        with torch.no_grad():
            probas = torch.clamp(self.model(X_tensor), min=0.05, max=0.95).numpy()
        return np.hstack([1 - probas, probas])

    def predict(self, X):
        """
        クラスラベルの予測
        :param X: 説明変数（N×d の配列）
        :return: 0または1の予測クラス
        """
        probas = self.predict_proba(X)
        return (probas[:, 1] >= 0.5).astype(int)

    def get_params(self):
        """
        学習済みパラメータの取得
        :return: モデルの重みとバイアス
        """
        params = {name: param.detach().numpy() for name, param in self.model.named_parameters()}
        return params

### Data generation

In [15]:
num_trial = 1000

In [16]:
nsims = 1000
np.random.seed(123)
#sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}


In [18]:
result_list = []
result_list2 = []

for tr in range(num_trial):
    result_list_temp = []

    X = np.random.normal(0, 1, (n, p))

    # Define a propensity score model
    # Assume treatment probability is a sigmoid function of a subset of covariates
    X_temp = np.concatenate([X, X**2, np.array([X[:, 0]*X[:, 1], X[:, 1]*X[:, 2], X[:, 0]*X[:, 2]]).T], axis=1)
    propensity_coef = np.random.normal(0, 0.5, X_temp.shape[1])
    propensity_scores = expit(X_temp @ propensity_coef)  # Calculate propensity scores

    # Generate treatment assignment based on propensity scores
    T = np.random.binomial(1, propensity_scores)

    # Generate outcome with treatment effect
    # Assume a simple linear model for demonstration
    beta = np.random.normal(0, 1, p)
    Y = X@ beta + (X @ beta)**2 + 3*np.sin(X[:, 0]) + 1.1 + treatment_effect * T + np.random.normal(0, 1, n)

    X_treatment = X[T == 1]
    X_control = X[T == 0]
    
    y_scaler = StandardScaler(with_mean=True).fit(np.array([Y]).T)
    y = y_scaler.transform(np.array([Y]).T)
    XT = np.c_[T, X]
    
    X_train, X_test, y_train, y_test = train_test_split(XT, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
    agmm = RieszNet(learner, moment_fn)
    # Fast training
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in agmm.predict_avg_moment(XT, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )
                        
    result_list2.append(params)
    
    Y = y.T[0]
    
    Y_treatment = Y[T == 1]
    Y_control = Y[T == 0]
    
    for method in ["DBCLS", "DBCUKL", "CBPS", "Logit"]:
        prop_model = NeuralNetBiasCorrection(input_dim=p, lbd = 0.01, loss=method)
        prop_model.fit(X, T)
        est_prop_score = prop_model.predict_proba(X)[:, 1]
        est_prop_score_dbc = est_prop_score

        #treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
        #control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
        
        treatment_outcome_model = MLPRegressor(random_state=1, max_iter=600)
        control_outcome_model = MLPRegressor(random_state=1, max_iter=600)

        treatment_outcome_model.fit(X_treatment, Y_treatment)
        control_outcome_model.fit(X_control, Y_control)

        est_treatment_outcome = treatment_outcome_model.predict(X)
        est_control_outcome = control_outcome_model.predict(X)

        IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

        # Evaluate performance
        IPW_bias = IPW_est - true_ATE

        result_list_temp.append(IPW_est)

        DM_est = np.mean(est_treatment_outcome - est_control_outcome)

        # Evaluate performance
        DM_bias = DM_est - true_ATE
        
        result_list_temp.append(DM_est)

        DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

        # Evaluate performance
        DR_bias = DR_est - true_ATE

        result_list_temp.append(DR_est)
    
    ##### Linear models
    
    # Fit a linear model to estimate the treatment effect
    model = LinearRegression()
    model.fit(np.hstack([X, T.reshape(-1, 1)]), Y)
    estimated_treatment_effect = model.coef_[-1]

    # Evaluate performance
    true_ATE = treatment_effect
    bias = estimated_treatment_effect - true_ATE
    mse = mean_squared_error(Y, model.predict(np.hstack([X, T.reshape(-1, 1)])))

    result_list_temp.append(estimated_treatment_effect)
    
    #### CBPS
    
    # Enable automatic conversion of Pandas DataFrame to R DataFrame
    pandas2ri.activate()

    # Simulate data in Python

    # Create a pandas DataFrame
    column_names = [f'X{i+1}' for i in range(p)]
    df = pd.DataFrame(X, columns=column_names)
    df['T'] = T
    df['Y'] = Y


    # Convert pandas DataFrame to R DataFrame
    r_df = pandas2ri.py2rpy(df)

    ro.r.assign("p", p)

    # Load the CBPS package in R and fit the model for ATE estimation
    ro.r('''
        library(CBPS)
        estimate_cbps_ate <- function(df) {
            formula_str <- paste("T ~", paste(names(df)[1:{p}], collapse=" + "))

            # CBPSの適用 (ATEの推定、ATT=0)
            model <- CBPS(as.formula(formula_str), data = df, ATT = 0, method = "exact")

            # 推定された傾向スコアの取得
            df$propensity_score <- fitted(model)

            # IPW (Inverse Probability Weighting) を適用
            df$weight <- ifelse(df$T == 1, 1 / df$propensity_score, 1 / (1 - df$propensity_score))

            # 重み付き回帰によるATEの推定
            result <- lm(Y ~ T, data = df, weights = df$weight)

            return(df$propensity_score)
        }
    ''')

    # R関数を呼び出してATEと傾向スコアを取得
    est_prop_score = ro.r['estimate_cbps_ate'](r_df)
    
    est_prop_score_cbps = est_prop_score
    
    #print(er)

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE
    
    result_list_temp.append(DR_est)
    
    result_list_temp = np.array(result_list_temp)*y_scaler.scale_[0]
    
    result_list.append(result_list_temp)
    
    res = tuple(np.array(x) for x in zip(*result_list2))
    truth = res[-1:]
    res_dict = {}

    res_list_temp = []
    for it, method in enumerate(methods):
        point, lb, ub = res[it * 3: (it + 1)*3]
        res_list_temp.append(point)
        
    result_list_final = np.concatenate([np.array(result_list), np.array(res_list_temp).T], axis=1)
    
    print(np.round(np.sqrt(np.mean((result_list_final - true_ATE)**2, axis=0)), 3))

  return torch.load(os.path.join(self.model_dir,


[0.312 0.124 0.168 0.481 0.124 0.17  0.918 0.124 0.182 0.495 0.124 0.178
 0.003 0.36  0.164 0.175 0.118 0.068 0.13 ]


  return torch.load(os.path.join(self.model_dir,


[0.393 0.088 0.119 0.469 0.088 0.121 1.118 0.088 0.129 0.47  0.088 0.127
 0.021 0.255 0.121 0.134 0.084 0.079 0.093]


  return torch.load(os.path.join(self.model_dir,


[0.323 0.072 0.101 0.387 0.072 0.106 0.93  0.072 0.106 0.384 0.072 0.109
 0.25  0.599 0.143 0.156 0.087 0.081 0.286]


  return torch.load(os.path.join(self.model_dir,


[0.822 0.207 0.168 0.889 0.207 0.177 1.619 0.207 0.198 2.023 0.207 0.199
 2.045 2.05  0.523 0.536 0.078 0.133 0.251]


  return torch.load(os.path.join(self.model_dir,


[0.788 0.195 0.164 0.891 0.195 0.171 1.663 0.195 0.193 1.961 0.195 0.192
 1.948 1.943 0.521 0.53  0.079 0.125 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.723 0.178 0.15  0.845 0.178 0.156 1.564 0.178 0.176 1.798 0.178 0.176
 1.784 1.78  0.477 0.485 0.089 0.127 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.719 0.177 0.155 0.871 0.177 0.159 1.629 0.177 0.18  1.778 0.177 0.179
 1.722 1.731 0.483 0.487 0.09  0.121 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.683 0.174 0.152 0.83  0.174 0.157 1.561 0.174 0.176 1.672 0.174 0.175
 1.621 1.634 0.453 0.46  0.094 0.115 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.71  0.204 0.17  0.956 0.204 0.178 1.713 0.204 0.2   2.007 0.204 0.2
 1.98  2.008 0.533 0.539 0.142 0.157 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.803 0.194 0.163 1.009 0.194 0.17  1.775 0.194 0.193 1.984 0.194 0.191
 1.897 1.942 0.531 0.533 0.136 0.149 0.26 ]


  return torch.load(os.path.join(self.model_dir,


[0.766 0.188 0.158 0.962 0.188 0.165 1.693 0.188 0.186 1.904 0.188 0.185
 1.821 1.865 0.509 0.51  0.132 0.145 0.249]


  return torch.load(os.path.join(self.model_dir,


[0.733 0.18  0.151 0.922 0.18  0.158 1.634 0.18  0.178 1.823 0.18  0.177
 1.744 1.785 0.491 0.489 0.126 0.138 0.238]


  return torch.load(os.path.join(self.model_dir,


[0.705 0.184 0.153 0.886 0.184 0.16  1.575 0.184 0.178 1.759 0.184 0.177
 1.683 1.722 0.48  0.475 0.123 0.134 0.232]


  return torch.load(os.path.join(self.model_dir,


[0.68  0.178 0.148 0.854 0.178 0.156 1.519 0.178 0.173 1.7   0.178 0.172
 1.628 1.666 0.465 0.459 0.119 0.13  0.224]


  return torch.load(os.path.join(self.model_dir,


[0.657 0.173 0.145 0.827 0.173 0.152 1.474 0.173 0.169 1.643 0.173 0.168
 1.573 1.61  0.449 0.444 0.117 0.126 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.637 0.17  0.143 0.801 0.17  0.149 1.427 0.17  0.165 1.591 0.17  0.164
 1.527 1.561 0.437 0.431 0.113 0.122 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.62  0.173 0.145 0.782 0.173 0.149 1.392 0.173 0.167 1.546 0.173 0.165
 1.486 1.519 0.443 0.44  0.111 0.119 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.604 0.169 0.141 0.766 0.169 0.145 1.357 0.169 0.163 1.517 0.169 0.161
 1.465 1.494 0.431 0.428 0.111 0.119 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.588 0.164 0.137 0.753 0.164 0.141 1.348 0.164 0.159 1.484 0.164 0.156
 1.429 1.458 0.421 0.417 0.108 0.116 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.161 0.134 0.738 0.161 0.138 1.317 0.161 0.156 1.449 0.161 0.153
 1.396 1.424 0.41  0.407 0.107 0.115 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.564 0.157 0.132 0.75  0.157 0.135 1.298 0.157 0.152 1.421 0.157 0.15
 1.363 1.411 0.401 0.399 0.105 0.113 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.555 0.153 0.129 0.735 0.153 0.133 1.275 0.153 0.149 1.391 0.153 0.147
 1.332 1.379 0.392 0.39  0.104 0.111 0.195]


  return torch.load(os.path.join(self.model_dir,


[0.605 0.15  0.13  0.777 0.15  0.133 1.295 0.15  0.148 1.372 0.15  0.147
 1.31  1.35  0.392 0.39  0.274 0.306 0.195]


  return torch.load(os.path.join(self.model_dir,


[0.641 0.166 0.143 0.824 0.166 0.146 1.351 0.166 0.16  1.418 0.166 0.16
 1.355 1.396 0.422 0.421 0.269 0.299 0.191]


  return torch.load(os.path.join(self.model_dir,


[0.629 0.163 0.141 0.812 0.163 0.144 1.329 0.163 0.158 1.393 0.163 0.157
 1.328 1.368 0.413 0.413 0.264 0.294 0.187]


  return torch.load(os.path.join(self.model_dir,


[0.618 0.16  0.138 0.801 0.16  0.141 1.313 0.16  0.155 1.368 0.16  0.154
 1.302 1.342 0.408 0.406 0.259 0.288 0.184]


  return torch.load(os.path.join(self.model_dir,


[0.62  0.157 0.138 0.818 0.157 0.141 1.322 0.157 0.154 1.363 0.157 0.153
 1.3   1.342 0.412 0.41  0.255 0.283 0.184]


  return torch.load(os.path.join(self.model_dir,


[0.617 0.313 0.244 0.851 0.313 0.261 1.307 0.313 0.281 1.494 0.313 0.281
 1.517 1.655 0.534 0.538 0.286 0.327 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.618 0.308 0.24  0.849 0.308 0.256 1.331 0.308 0.276 1.495 0.308 0.276
 1.497 1.641 0.532 0.535 0.281 0.323 0.201]


  return torch.load(os.path.join(self.model_dir,


[0.612 0.315 0.242 0.834 0.315 0.259 1.34  0.315 0.28  1.472 0.315 0.28
 1.477 1.622 0.529 0.532 0.284 0.328 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.603 0.313 0.239 0.821 0.313 0.256 1.32  0.313 0.277 1.449 0.313 0.277
 1.455 1.597 0.522 0.525 0.28  0.323 0.202]


  return torch.load(os.path.join(self.model_dir,


[0.601 0.308 0.236 0.809 0.308 0.252 1.303 0.308 0.272 1.426 0.308 0.273
 1.435 1.574 0.514 0.517 0.276 0.318 0.202]


  return torch.load(os.path.join(self.model_dir,


[0.592 0.303 0.233 0.797 0.303 0.249 1.284 0.303 0.268 1.405 0.303 0.269
 1.413 1.55  0.506 0.509 0.272 0.313 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.689 0.299 0.23  0.869 0.299 0.246 1.336 0.299 0.265 1.498 0.299 0.265
 1.5   1.631 0.499 0.502 0.273 0.311 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.68  0.295 0.227 0.857 0.295 0.242 1.32  0.295 0.261 1.477 0.295 0.262
 1.478 1.608 0.492 0.495 0.269 0.307 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.671 0.291 0.224 0.845 0.291 0.239 1.303 0.291 0.258 1.457 0.291 0.258
 1.458 1.585 0.485 0.488 0.265 0.303 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.663 0.287 0.221 0.836 0.287 0.236 1.311 0.287 0.254 1.439 0.287 0.255
 1.438 1.564 0.479 0.481 0.262 0.299 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.655 0.283 0.218 0.83  0.283 0.233 1.295 0.283 0.251 1.426 0.283 0.252
 1.429 1.551 0.473 0.475 0.259 0.295 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.652 0.281 0.217 0.823 0.281 0.232 1.299 0.281 0.25  1.434 0.281 0.251
 1.432 1.551 0.474 0.476 0.256 0.292 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.647 0.278 0.215 0.815 0.278 0.229 1.292 0.278 0.247 1.418 0.278 0.247
 1.415 1.532 0.468 0.471 0.253 0.288 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.64  0.275 0.212 0.813 0.275 0.226 1.306 0.275 0.244 1.423 0.275 0.244
 1.414 1.532 0.465 0.467 0.25  0.285 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.632 0.271 0.21  0.803 0.271 0.224 1.291 0.271 0.241 1.406 0.271 0.242
 1.397 1.514 0.459 0.461 0.247 0.282 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.625 0.268 0.207 0.794 0.268 0.221 1.278 0.268 0.238 1.39  0.268 0.239
 1.381 1.497 0.454 0.456 0.244 0.279 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.627 0.267 0.207 0.798 0.267 0.22  1.272 0.267 0.238 1.393 0.267 0.238
 1.384 1.499 0.456 0.458 0.242 0.276 0.201]


  return torch.load(os.path.join(self.model_dir,


[0.62  0.264 0.205 0.79  0.264 0.218 1.269 0.264 0.236 1.387 0.264 0.236
 1.374 1.49  0.453 0.455 0.239 0.273 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.613 0.263 0.204 0.782 0.263 0.218 1.256 0.263 0.235 1.376 0.263 0.235
 1.365 1.484 0.452 0.454 0.241 0.273 0.197]


  return torch.load(os.path.join(self.model_dir,


[0.608 0.26  0.202 0.776 0.26  0.215 1.249 0.26  0.232 1.361 0.26  0.232
 1.35  1.468 0.447 0.45  0.238 0.27  0.196]


  return torch.load(os.path.join(self.model_dir,


[0.602 0.258 0.2   0.772 0.258 0.214 1.242 0.258 0.23  1.372 0.258 0.231
 1.362 1.478 0.445 0.447 0.236 0.268 0.194]


  return torch.load(os.path.join(self.model_dir,


[0.597 0.256 0.198 0.766 0.256 0.211 1.237 0.256 0.228 1.359 0.256 0.228
 1.349 1.463 0.441 0.442 0.234 0.265 0.192]


  return torch.load(os.path.join(self.model_dir,


[0.593 0.254 0.198 0.763 0.254 0.211 1.228 0.254 0.227 1.381 0.254 0.227
 1.367 1.478 0.442 0.444 0.231 0.262 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.252 0.196 0.758 0.252 0.209 1.22  0.252 0.225 1.371 0.252 0.225
 1.354 1.464 0.438 0.44  0.229 0.26  0.201]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.251 0.196 0.75  0.251 0.208 1.212 0.251 0.224 1.359 0.251 0.224
 1.345 1.455 0.436 0.437 0.228 0.258 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.248 0.195 0.746 0.248 0.207 1.203 0.248 0.222 1.349 0.248 0.222
 1.334 1.444 0.432 0.434 0.226 0.256 0.198]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.25  0.195 0.742 0.25  0.207 1.192 0.25  0.222 1.345 0.25  0.223
 1.332 1.442 0.438 0.439 0.224 0.254 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.248 0.193 0.738 0.248 0.205 1.186 0.248 0.22  1.334 0.248 0.221
 1.321 1.43  0.434 0.435 0.222 0.251 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.245 0.191 0.733 0.245 0.203 1.187 0.245 0.218 1.329 0.245 0.219
 1.313 1.422 0.431 0.432 0.22  0.249 0.202]


  return torch.load(os.path.join(self.model_dir,


[0.56  0.244 0.19  0.731 0.244 0.202 1.18  0.244 0.217 1.319 0.244 0.217
 1.31  1.419 0.428 0.429 0.218 0.247 0.2  ]


  return torch.load(os.path.join(self.model_dir,


[0.556 0.241 0.188 0.725 0.241 0.2   1.183 0.241 0.215 1.318 0.241 0.215
 1.306 1.414 0.424 0.426 0.216 0.246 0.198]


  return torch.load(os.path.join(self.model_dir,


[0.551 0.239 0.187 0.719 0.239 0.198 1.175 0.239 0.213 1.307 0.239 0.213
 1.295 1.402 0.421 0.423 0.215 0.244 0.197]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.238 0.185 0.713 0.238 0.197 1.166 0.238 0.211 1.296 0.238 0.212
 1.285 1.39  0.418 0.419 0.213 0.242 0.195]


  return torch.load(os.path.join(self.model_dir,


[0.542 0.239 0.187 0.707 0.239 0.199 1.162 0.239 0.214 1.289 0.239 0.214
 1.282 1.385 0.416 0.42  0.211 0.24  0.193]


  return torch.load(os.path.join(self.model_dir,


[0.538 0.237 0.186 0.709 0.237 0.197 1.156 0.237 0.213 1.292 0.237 0.213
 1.28  1.382 0.414 0.417 0.21  0.238 0.192]


  return torch.load(os.path.join(self.model_dir,


[0.534 0.236 0.185 0.704 0.236 0.196 1.147 0.236 0.211 1.282 0.236 0.211
 1.269 1.371 0.41  0.414 0.208 0.236 0.191]


  return torch.load(os.path.join(self.model_dir,


[0.53  0.238 0.184 0.713 0.238 0.195 1.152 0.238 0.21  1.314 0.238 0.211
 1.301 1.409 0.418 0.422 0.206 0.234 0.194]


  return torch.load(os.path.join(self.model_dir,


[0.527 0.238 0.185 0.709 0.238 0.196 1.143 0.238 0.211 1.313 0.238 0.212
 1.3   1.407 0.422 0.424 0.205 0.233 0.193]


  return torch.load(os.path.join(self.model_dir,


[0.554 0.244 0.191 0.745 0.244 0.202 1.258 0.244 0.218 1.429 0.244 0.219
 1.415 1.511 0.468 0.469 0.204 0.231 0.193]


  return torch.load(os.path.join(self.model_dir,


[0.55  0.242 0.189 0.742 0.242 0.201 1.256 0.242 0.216 1.42  0.242 0.218
 1.405 1.502 0.465 0.466 0.202 0.23  0.191]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.24  0.188 0.738 0.24  0.199 1.248 0.24  0.215 1.409 0.24  0.216
 1.396 1.492 0.462 0.463 0.201 0.228 0.19 ]


  return torch.load(os.path.join(self.model_dir,


[0.543 0.239 0.187 0.732 0.239 0.198 1.239 0.239 0.213 1.399 0.239 0.215
 1.386 1.481 0.459 0.46  0.2   0.227 0.189]


  return torch.load(os.path.join(self.model_dir,


[0.539 0.237 0.186 0.727 0.237 0.197 1.23  0.237 0.212 1.389 0.237 0.213
 1.376 1.471 0.456 0.457 0.199 0.225 0.188]


  return torch.load(os.path.join(self.model_dir,


[0.537 0.236 0.185 0.723 0.236 0.196 1.225 0.236 0.211 1.379 0.236 0.212
 1.366 1.461 0.453 0.454 0.197 0.223 0.186]


  return torch.load(os.path.join(self.model_dir,


[0.54  0.234 0.184 0.724 0.234 0.194 1.218 0.234 0.209 1.373 0.234 0.211
 1.357 1.455 0.45  0.451 0.196 0.222 0.185]


  return torch.load(os.path.join(self.model_dir,


[0.537 0.233 0.182 0.72  0.233 0.193 1.21  0.233 0.208 1.363 0.233 0.209
 1.347 1.445 0.447 0.448 0.195 0.221 0.184]


  return torch.load(os.path.join(self.model_dir,


[0.534 0.231 0.181 0.717 0.231 0.192 1.203 0.231 0.207 1.356 0.231 0.208
 1.338 1.435 0.444 0.445 0.193 0.219 0.183]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.24  0.189 0.78  0.24  0.2   1.265 0.24  0.215 1.466 0.24  0.217
 1.397 1.56  0.481 0.481 0.201 0.22  0.19 ]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.239 0.188 0.788 0.239 0.2   1.263 0.239 0.215 1.499 0.239 0.217
 1.432 1.591 0.484 0.484 0.2   0.218 0.192]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.238 0.187 0.785 0.238 0.198 1.256 0.238 0.213 1.491 0.238 0.215
 1.424 1.582 0.481 0.481 0.198 0.217 0.191]


  return torch.load(os.path.join(self.model_dir,


[0.591 0.252 0.198 0.833 0.252 0.21  1.327 0.252 0.227 1.654 0.252 0.23
 1.597 1.742 0.528 0.531 0.197 0.215 0.204]


  return torch.load(os.path.join(self.model_dir,


[0.588 0.252 0.199 0.827 0.252 0.21  1.323 0.252 0.228 1.644 0.252 0.23
 1.595 1.74  0.531 0.533 0.197 0.215 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.25  0.198 0.823 0.25  0.209 1.315 0.25  0.226 1.634 0.25  0.229
 1.588 1.731 0.528 0.529 0.195 0.214 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.249 0.196 0.821 0.249 0.208 1.311 0.249 0.225 1.644 0.249 0.227
 1.601 1.742 0.525 0.527 0.194 0.213 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.247 0.195 0.817 0.247 0.206 1.306 0.247 0.223 1.634 0.247 0.226
 1.591 1.732 0.522 0.524 0.193 0.212 0.205]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.246 0.194 0.821 0.246 0.205 1.325 0.246 0.222 1.636 0.246 0.225
 1.589 1.727 0.524 0.524 0.192 0.21  0.205]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.245 0.193 0.816 0.245 0.204 1.317 0.245 0.221 1.629 0.245 0.224
 1.581 1.719 0.521 0.521 0.191 0.209 0.204]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.243 0.192 0.816 0.243 0.203 1.319 0.243 0.22  1.623 0.243 0.222
 1.574 1.71  0.519 0.518 0.192 0.211 0.203]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.242 0.191 0.815 0.242 0.202 1.329 0.242 0.219 1.616 0.242 0.221
 1.565 1.701 0.516 0.515 0.191 0.21  0.202]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.24  0.19  0.816 0.24  0.201 1.323 0.24  0.218 1.617 0.24  0.22
 1.567 1.703 0.514 0.513 0.19  0.209 0.202]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.246 0.194 0.813 0.246 0.205 1.322 0.246 0.222 1.615 0.246 0.225
 1.569 1.709 0.519 0.518 0.189 0.209 0.202]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.245 0.193 0.811 0.245 0.204 1.318 0.245 0.221 1.607 0.245 0.224
 1.56  1.7   0.516 0.515 0.188 0.208 0.201]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.246 0.197 0.829 0.246 0.208 1.338 0.246 0.224 1.636 0.246 0.227
 1.569 1.729 0.528 0.526 0.187 0.207 0.201]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.245 0.195 0.825 0.245 0.206 1.331 0.245 0.223 1.627 0.245 0.226
 1.561 1.719 0.525 0.523 0.186 0.205 0.2  ]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.244 0.195 0.82  0.244 0.206 1.327 0.244 0.222 1.619 0.244 0.225
 1.553 1.71  0.523 0.521 0.185 0.204 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.242 0.194 0.816 0.242 0.205 1.321 0.242 0.221 1.61  0.242 0.224
 1.544 1.701 0.52  0.518 0.184 0.203 0.198]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.241 0.193 0.812 0.241 0.204 1.317 0.241 0.22  1.602 0.241 0.223
 1.537 1.692 0.517 0.515 0.184 0.202 0.197]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.24  0.192 0.809 0.24  0.203 1.327 0.24  0.219 1.594 0.24  0.222
 1.529 1.684 0.516 0.513 0.183 0.202 0.197]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.239 0.192 0.806 0.239 0.202 1.325 0.239 0.218 1.595 0.239 0.221
 1.527 1.682 0.516 0.513 0.182 0.201 0.196]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.238 0.191 0.803 0.238 0.201 1.324 0.238 0.217 1.587 0.238 0.22
 1.52  1.675 0.513 0.511 0.181 0.2   0.197]


  return torch.load(os.path.join(self.model_dir,


[0.562 0.237 0.19  0.8   0.237 0.2   1.317 0.237 0.216 1.583 0.237 0.219
 1.516 1.671 0.511 0.509 0.181 0.199 0.196]


  return torch.load(os.path.join(self.model_dir,


[0.56  0.236 0.189 0.796 0.236 0.199 1.311 0.236 0.215 1.575 0.236 0.218
 1.509 1.664 0.509 0.506 0.181 0.2   0.198]


  return torch.load(os.path.join(self.model_dir,


[0.56  0.235 0.188 0.8   0.235 0.198 1.308 0.235 0.214 1.578 0.235 0.217
 1.51  1.667 0.507 0.505 0.18  0.199 0.199]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.237 0.19  0.814 0.237 0.2   1.314 0.237 0.215 1.603 0.237 0.218
 1.533 1.681 0.517 0.515 0.18  0.198 0.198]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.236 0.189 0.814 0.236 0.199 1.315 0.236 0.214 1.6   0.236 0.217
 1.53  1.676 0.516 0.514 0.179 0.197 0.198]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.235 0.189 0.815 0.235 0.199 1.316 0.235 0.214 1.595 0.235 0.217
 1.523 1.669 0.514 0.512 0.181 0.2   0.206]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.234 0.188 0.818 0.234 0.198 1.32  0.234 0.213 1.595 0.234 0.216
 1.52  1.665 0.512 0.51  0.18  0.2   0.216]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.233 0.187 0.814 0.233 0.197 1.313 0.233 0.212 1.587 0.233 0.215
 1.512 1.657 0.509 0.508 0.179 0.199 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.232 0.186 0.81  0.232 0.196 1.307 0.232 0.211 1.58  0.232 0.214
 1.505 1.649 0.507 0.505 0.179 0.198 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.231 0.186 0.808 0.231 0.195 1.301 0.231 0.21  1.578 0.231 0.213
 1.505 1.647 0.505 0.504 0.178 0.197 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.231 0.185 0.806 0.231 0.195 1.297 0.231 0.209 1.575 0.231 0.212
 1.502 1.644 0.504 0.503 0.177 0.196 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.23  0.184 0.808 0.23  0.194 1.309 0.23  0.209 1.59  0.23  0.212
 1.51  1.651 0.503 0.502 0.176 0.196 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.231 0.186 0.813 0.231 0.195 1.308 0.231 0.21  1.588 0.231 0.213
 1.512 1.653 0.505 0.504 0.176 0.195 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.23  0.185 0.811 0.23  0.194 1.311 0.23  0.209 1.59  0.23  0.212
 1.515 1.654 0.504 0.502 0.176 0.195 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.229 0.184 0.808 0.229 0.193 1.306 0.229 0.208 1.583 0.229 0.211
 1.508 1.647 0.502 0.5   0.175 0.194 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.567 0.228 0.183 0.809 0.228 0.193 1.309 0.228 0.207 1.578 0.228 0.21
 1.502 1.64  0.5   0.499 0.174 0.193 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.23  0.186 0.81  0.23  0.195 1.317 0.23  0.211 1.571 0.23  0.213
 1.497 1.636 0.498 0.497 0.174 0.195 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.23  0.186 0.808 0.23  0.195 1.311 0.23  0.21  1.572 0.23  0.213
 1.499 1.639 0.499 0.498 0.173 0.194 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.229 0.185 0.813 0.229 0.194 1.32  0.229 0.21  1.578 0.229 0.212
 1.504 1.643 0.499 0.498 0.173 0.194 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.228 0.184 0.81  0.228 0.194 1.315 0.228 0.209 1.574 0.228 0.211
 1.497 1.637 0.497 0.496 0.172 0.193 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.567 0.227 0.184 0.807 0.227 0.193 1.311 0.227 0.208 1.568 0.227 0.211
 1.491 1.63  0.495 0.494 0.172 0.192 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.227 0.183 0.812 0.227 0.193 1.306 0.227 0.208 1.569 0.227 0.211
 1.494 1.63  0.495 0.494 0.171 0.192 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.564 0.226 0.183 0.809 0.226 0.192 1.305 0.226 0.207 1.562 0.226 0.21
 1.488 1.624 0.493 0.492 0.171 0.191 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.562 0.225 0.182 0.807 0.225 0.191 1.302 0.225 0.206 1.556 0.225 0.209
 1.482 1.617 0.491 0.49  0.17  0.19  0.209]


  return torch.load(os.path.join(self.model_dir,


[0.59  0.239 0.196 0.833 0.239 0.206 1.321 0.239 0.221 1.571 0.239 0.223
 1.492 1.625 0.505 0.503 0.174 0.191 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.243 0.2   0.83  0.243 0.209 1.317 0.243 0.226 1.572 0.243 0.228
 1.505 1.638 0.511 0.51  0.174 0.19  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.242 0.2   0.833 0.242 0.209 1.325 0.242 0.225 1.583 0.242 0.228
 1.52  1.651 0.511 0.511 0.173 0.189 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.241 0.199 0.831 0.241 0.208 1.323 0.241 0.225 1.577 0.241 0.227
 1.514 1.644 0.51  0.509 0.172 0.189 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.241 0.199 0.83  0.241 0.208 1.319 0.241 0.224 1.582 0.241 0.227
 1.524 1.655 0.511 0.51  0.172 0.188 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.247 0.207 0.849 0.247 0.216 1.315 0.247 0.232 1.619 0.247 0.235
 1.541 1.706 0.526 0.527 0.172 0.187 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.246 0.206 0.846 0.246 0.215 1.31  0.246 0.232 1.613 0.246 0.234
 1.535 1.699 0.524 0.525 0.172 0.187 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.246 0.206 0.848 0.246 0.215 1.306 0.246 0.232 1.634 0.246 0.235
 1.548 1.726 0.523 0.525 0.172 0.187 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.245 0.205 0.844 0.245 0.215 1.301 0.245 0.231 1.627 0.245 0.234
 1.542 1.72  0.521 0.523 0.172 0.187 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.246 0.206 0.848 0.246 0.216 1.306 0.246 0.232 1.662 0.246 0.235
 1.583 1.758 0.523 0.525 0.171 0.186 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.245 0.206 0.845 0.245 0.215 1.301 0.245 0.231 1.658 0.245 0.235
 1.58  1.754 0.522 0.523 0.171 0.186 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.245 0.205 0.842 0.245 0.214 1.298 0.245 0.231 1.652 0.245 0.234
 1.574 1.747 0.52  0.522 0.17  0.185 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.244 0.205 0.84  0.244 0.214 1.294 0.244 0.23  1.646 0.244 0.233
 1.568 1.741 0.518 0.52  0.17  0.184 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.243 0.204 0.837 0.243 0.213 1.289 0.243 0.229 1.64  0.243 0.232
 1.563 1.734 0.517 0.518 0.169 0.184 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.593 0.243 0.204 0.848 0.243 0.213 1.322 0.243 0.229 1.678 0.243 0.233
 1.602 1.769 0.523 0.524 0.168 0.183 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.591 0.242 0.203 0.845 0.242 0.212 1.321 0.242 0.228 1.675 0.242 0.232
 1.599 1.764 0.522 0.523 0.168 0.183 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.242 0.202 0.843 0.242 0.211 1.318 0.242 0.228 1.669 0.242 0.231
 1.593 1.758 0.52  0.521 0.167 0.182 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.241 0.202 0.84  0.241 0.211 1.313 0.241 0.227 1.663 0.241 0.23
 1.587 1.752 0.518 0.519 0.167 0.181 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.24  0.201 0.838 0.24  0.21  1.311 0.24  0.226 1.657 0.24  0.229
 1.583 1.746 0.517 0.518 0.166 0.181 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.239 0.201 0.835 0.239 0.21  1.307 0.239 0.226 1.652 0.239 0.229
 1.578 1.742 0.515 0.516 0.166 0.181 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.239 0.2   0.833 0.239 0.209 1.307 0.239 0.225 1.648 0.239 0.228
 1.575 1.737 0.514 0.515 0.166 0.181 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.238 0.2   0.83  0.238 0.208 1.303 0.238 0.224 1.645 0.238 0.228
 1.572 1.734 0.512 0.514 0.165 0.18  0.221]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.239 0.2   0.83  0.239 0.209 1.299 0.239 0.225 1.644 0.239 0.228
 1.571 1.735 0.514 0.515 0.165 0.18  0.221]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.238 0.199 0.828 0.238 0.208 1.297 0.238 0.224 1.639 0.238 0.227
 1.566 1.73  0.512 0.513 0.164 0.179 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.237 0.199 0.825 0.237 0.208 1.293 0.237 0.223 1.634 0.237 0.227
 1.562 1.724 0.511 0.512 0.164 0.179 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.236 0.198 0.822 0.236 0.207 1.289 0.236 0.223 1.629 0.236 0.226
 1.557 1.719 0.509 0.511 0.163 0.178 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.236 0.198 0.82  0.236 0.206 1.298 0.236 0.222 1.627 0.236 0.226
 1.552 1.714 0.51  0.512 0.164 0.179 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.235 0.197 0.818 0.235 0.206 1.295 0.235 0.222 1.623 0.235 0.225
 1.548 1.71  0.509 0.51  0.164 0.179 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.235 0.197 0.823 0.235 0.205 1.301 0.235 0.221 1.627 0.235 0.225
 1.55  1.713 0.509 0.51  0.164 0.178 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.234 0.196 0.821 0.234 0.205 1.297 0.234 0.221 1.621 0.234 0.224
 1.546 1.707 0.507 0.508 0.163 0.178 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.233 0.196 0.818 0.233 0.204 1.292 0.233 0.22  1.62  0.233 0.223
 1.542 1.703 0.506 0.506 0.162 0.177 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.232 0.195 0.816 0.232 0.204 1.289 0.232 0.219 1.615 0.232 0.223
 1.537 1.698 0.504 0.505 0.162 0.177 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.232 0.195 0.813 0.232 0.203 1.287 0.232 0.219 1.61  0.232 0.222
 1.532 1.692 0.503 0.503 0.162 0.176 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.231 0.194 0.812 0.231 0.202 1.284 0.231 0.218 1.605 0.231 0.221
 1.527 1.687 0.501 0.502 0.161 0.176 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.567 0.23  0.193 0.811 0.23  0.202 1.282 0.23  0.217 1.601 0.23  0.221
 1.524 1.683 0.5   0.5   0.161 0.175 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.23  0.193 0.809 0.23  0.201 1.282 0.23  0.217 1.597 0.23  0.22
 1.519 1.678 0.499 0.499 0.16  0.175 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.564 0.229 0.192 0.807 0.229 0.201 1.283 0.229 0.216 1.597 0.229 0.219
 1.518 1.677 0.497 0.498 0.16  0.174 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.228 0.192 0.81  0.228 0.2   1.284 0.228 0.216 1.607 0.228 0.219
 1.525 1.689 0.498 0.498 0.16  0.174 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.228 0.192 0.807 0.228 0.2   1.282 0.228 0.216 1.602 0.228 0.219
 1.521 1.684 0.496 0.497 0.159 0.173 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.228 0.191 0.805 0.228 0.2   1.281 0.228 0.215 1.597 0.228 0.218
 1.516 1.68  0.495 0.496 0.159 0.173 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.227 0.191 0.803 0.227 0.199 1.277 0.227 0.214 1.592 0.227 0.218
 1.512 1.674 0.494 0.494 0.159 0.172 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.227 0.19  0.802 0.227 0.199 1.279 0.227 0.214 1.591 0.227 0.217
 1.509 1.673 0.493 0.493 0.158 0.172 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.564 0.226 0.19  0.804 0.226 0.198 1.28  0.226 0.214 1.599 0.226 0.217
 1.514 1.681 0.494 0.494 0.158 0.172 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.226 0.19  0.802 0.226 0.198 1.279 0.226 0.213 1.594 0.226 0.216
 1.509 1.676 0.493 0.493 0.158 0.171 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.225 0.189 0.802 0.225 0.197 1.28  0.225 0.212 1.592 0.225 0.216
 1.508 1.673 0.492 0.492 0.157 0.171 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.561 0.225 0.189 0.801 0.225 0.197 1.279 0.225 0.212 1.59  0.225 0.215
 1.505 1.67  0.491 0.492 0.157 0.171 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.56  0.224 0.188 0.799 0.224 0.196 1.276 0.224 0.211 1.586 0.224 0.215
 1.502 1.666 0.49  0.49  0.157 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.558 0.225 0.188 0.812 0.225 0.196 1.274 0.225 0.211 1.595 0.225 0.215
 1.513 1.677 0.489 0.49  0.156 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.557 0.224 0.187 0.811 0.224 0.196 1.273 0.224 0.211 1.6   0.224 0.214
 1.519 1.681 0.49  0.49  0.156 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.557 0.224 0.187 0.809 0.224 0.195 1.271 0.224 0.21  1.596 0.224 0.213
 1.515 1.676 0.488 0.489 0.156 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.555 0.223 0.186 0.807 0.223 0.195 1.267 0.223 0.21  1.591 0.223 0.213
 1.511 1.672 0.487 0.487 0.155 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.554 0.222 0.186 0.807 0.222 0.194 1.263 0.222 0.209 1.587 0.222 0.212
 1.507 1.668 0.486 0.486 0.155 0.168 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.554 0.226 0.188 0.816 0.226 0.197 1.303 0.226 0.213 1.671 0.226 0.217
 1.598 1.752 0.501 0.501 0.157 0.171 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.553 0.225 0.187 0.814 0.225 0.196 1.301 0.225 0.213 1.667 0.225 0.216
 1.594 1.747 0.5   0.5   0.157 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.551 0.225 0.187 0.811 0.225 0.196 1.297 0.225 0.212 1.663 0.225 0.216
 1.59  1.743 0.499 0.499 0.156 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.551 0.225 0.187 0.811 0.225 0.196 1.297 0.225 0.212 1.66  0.225 0.215
 1.587 1.74  0.498 0.498 0.156 0.17  0.215]


  return torch.load(os.path.join(self.model_dir,


[0.55  0.224 0.186 0.809 0.224 0.195 1.293 0.224 0.211 1.655 0.224 0.215
 1.583 1.735 0.497 0.497 0.156 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.224 0.186 0.806 0.224 0.195 1.29  0.224 0.211 1.651 0.224 0.214
 1.579 1.73  0.496 0.496 0.155 0.169 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.224 0.186 0.808 0.224 0.195 1.292 0.224 0.211 1.654 0.224 0.215
 1.582 1.733 0.498 0.497 0.155 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.223 0.186 0.807 0.223 0.195 1.29  0.223 0.211 1.649 0.223 0.214
 1.579 1.728 0.497 0.496 0.155 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.546 0.223 0.185 0.805 0.223 0.194 1.287 0.223 0.21  1.645 0.223 0.214
 1.574 1.724 0.495 0.495 0.154 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.223 0.186 0.808 0.223 0.195 1.287 0.223 0.211 1.644 0.223 0.214
 1.573 1.723 0.496 0.497 0.154 0.167 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.223 0.185 0.806 0.223 0.194 1.29  0.223 0.21  1.64  0.223 0.213
 1.569 1.718 0.495 0.495 0.153 0.167 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.546 0.222 0.185 0.803 0.222 0.194 1.289 0.222 0.21  1.635 0.222 0.213
 1.565 1.714 0.494 0.495 0.153 0.167 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.222 0.185 0.805 0.222 0.194 1.292 0.222 0.21  1.636 0.222 0.213
 1.565 1.714 0.494 0.495 0.153 0.166 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.549 0.221 0.184 0.804 0.221 0.193 1.298 0.221 0.209 1.634 0.221 0.212
 1.561 1.709 0.494 0.494 0.153 0.166 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.221 0.184 0.803 0.221 0.193 1.294 0.221 0.209 1.629 0.221 0.212
 1.557 1.705 0.492 0.493 0.152 0.166 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.546 0.221 0.185 0.805 0.221 0.193 1.297 0.221 0.209 1.64  0.221 0.213
 1.57  1.717 0.495 0.496 0.153 0.166 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.546 0.221 0.184 0.805 0.221 0.193 1.295 0.221 0.209 1.641 0.221 0.212
 1.569 1.719 0.495 0.495 0.153 0.166 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.545 0.22  0.184 0.803 0.22  0.192 1.291 0.22  0.209 1.636 0.22  0.212
 1.565 1.714 0.493 0.494 0.152 0.166 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.543 0.22  0.183 0.801 0.22  0.192 1.288 0.22  0.208 1.632 0.22  0.211
 1.561 1.71  0.492 0.493 0.152 0.165 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.542 0.219 0.183 0.799 0.219 0.191 1.285 0.219 0.207 1.628 0.219 0.211
 1.557 1.705 0.491 0.491 0.152 0.165 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.543 0.22  0.184 0.802 0.22  0.193 1.29  0.22  0.209 1.635 0.22  0.212
 1.563 1.71  0.493 0.494 0.152 0.164 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.542 0.22  0.184 0.802 0.22  0.192 1.29  0.22  0.209 1.634 0.22  0.212
 1.563 1.709 0.493 0.493 0.151 0.164 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.544 0.219 0.183 0.8   0.219 0.192 1.287 0.219 0.208 1.631 0.219 0.211
 1.56  1.706 0.492 0.492 0.151 0.164 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.543 0.219 0.183 0.799 0.219 0.191 1.289 0.219 0.208 1.63  0.219 0.211
 1.558 1.704 0.491 0.491 0.151 0.164 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.542 0.218 0.183 0.798 0.218 0.191 1.287 0.218 0.207 1.633 0.218 0.211
 1.561 1.706 0.492 0.493 0.152 0.164 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.219 0.184 0.808 0.219 0.192 1.289 0.219 0.208 1.665 0.219 0.212
 1.595 1.737 0.506 0.507 0.152 0.164 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.218 0.183 0.807 0.218 0.192 1.286 0.218 0.208 1.661 0.218 0.211
 1.591 1.733 0.505 0.505 0.153 0.165 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.545 0.218 0.183 0.806 0.218 0.191 1.284 0.218 0.207 1.658 0.218 0.211
 1.587 1.729 0.504 0.504 0.153 0.165 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.549 0.217 0.183 0.805 0.217 0.191 1.281 0.217 0.207 1.656 0.217 0.21
 1.588 1.728 0.503 0.504 0.159 0.171 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.217 0.182 0.803 0.217 0.191 1.278 0.217 0.207 1.652 0.217 0.21
 1.584 1.725 0.502 0.502 0.159 0.171 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.549 0.218 0.182 0.802 0.218 0.191 1.286 0.218 0.207 1.648 0.218 0.21
 1.581 1.723 0.501 0.501 0.158 0.171 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.548 0.217 0.182 0.8   0.217 0.191 1.284 0.217 0.206 1.648 0.217 0.21
 1.583 1.724 0.501 0.501 0.158 0.17  0.217]


  return torch.load(os.path.join(self.model_dir,


[0.547 0.217 0.181 0.798 0.217 0.19  1.282 0.217 0.206 1.646 0.217 0.209
 1.582 1.722 0.5   0.5   0.158 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.227 0.189 0.838 0.227 0.198 1.309 0.227 0.215 1.699 0.227 0.219
 1.636 1.774 0.531 0.532 0.157 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.226 0.188 0.836 0.226 0.198 1.308 0.226 0.214 1.695 0.226 0.218
 1.632 1.77  0.53  0.531 0.157 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.226 0.188 0.837 0.226 0.197 1.315 0.226 0.214 1.7   0.226 0.218
 1.636 1.773 0.531 0.531 0.157 0.17  0.216]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.225 0.187 0.835 0.225 0.197 1.312 0.225 0.213 1.696 0.225 0.217
 1.632 1.769 0.53  0.53  0.156 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.225 0.187 0.833 0.225 0.197 1.309 0.225 0.213 1.692 0.225 0.217
 1.628 1.765 0.528 0.529 0.156 0.169 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.224 0.187 0.831 0.224 0.196 1.307 0.224 0.213 1.688 0.224 0.216
 1.624 1.761 0.527 0.528 0.156 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.224 0.187 0.831 0.224 0.196 1.306 0.224 0.212 1.686 0.224 0.216
 1.622 1.758 0.527 0.527 0.156 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.223 0.186 0.829 0.223 0.196 1.303 0.223 0.212 1.682 0.223 0.216
 1.619 1.754 0.526 0.526 0.155 0.168 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.224 0.188 0.833 0.224 0.197 1.307 0.224 0.214 1.684 0.224 0.218
 1.622 1.758 0.528 0.528 0.155 0.168 0.215]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.224 0.187 0.831 0.224 0.197 1.306 0.224 0.214 1.68  0.224 0.217
 1.618 1.754 0.527 0.527 0.155 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.224 0.187 0.833 0.224 0.197 1.311 0.224 0.214 1.678 0.224 0.218
 1.616 1.751 0.526 0.527 0.155 0.169 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.224 0.187 0.831 0.224 0.196 1.308 0.224 0.213 1.675 0.224 0.217
 1.612 1.747 0.525 0.525 0.155 0.168 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.223 0.187 0.83  0.223 0.196 1.308 0.223 0.213 1.671 0.223 0.217
 1.609 1.744 0.524 0.525 0.155 0.169 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.223 0.186 0.828 0.223 0.196 1.305 0.223 0.213 1.667 0.223 0.216
 1.605 1.74  0.523 0.523 0.154 0.168 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.223 0.186 0.826 0.223 0.195 1.302 0.223 0.212 1.663 0.223 0.216
 1.602 1.736 0.522 0.522 0.154 0.168 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.222 0.186 0.835 0.222 0.195 1.33  0.222 0.212 1.671 0.222 0.216
 1.627 1.764 0.53  0.529 0.154 0.168 0.214]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.222 0.186 0.836 0.222 0.195 1.331 0.222 0.212 1.669 0.222 0.215
 1.626 1.762 0.529 0.529 0.153 0.167 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.221 0.185 0.834 0.221 0.195 1.329 0.221 0.211 1.666 0.221 0.215
 1.623 1.759 0.528 0.528 0.153 0.167 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.221 0.185 0.834 0.221 0.194 1.326 0.221 0.211 1.663 0.221 0.214
 1.62  1.756 0.527 0.527 0.153 0.167 0.213]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.221 0.184 0.833 0.221 0.194 1.325 0.221 0.21  1.66  0.221 0.214
 1.617 1.753 0.526 0.526 0.153 0.167 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.221 0.185 0.833 0.221 0.194 1.327 0.221 0.211 1.658 0.221 0.215
 1.615 1.751 0.525 0.526 0.152 0.166 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.221 0.185 0.831 0.221 0.194 1.324 0.221 0.211 1.655 0.221 0.214
 1.612 1.747 0.524 0.525 0.152 0.166 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.22  0.184 0.83  0.22  0.194 1.321 0.22  0.21  1.652 0.22  0.214
 1.609 1.744 0.523 0.523 0.152 0.166 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.22  0.184 0.834 0.22  0.193 1.331 0.22  0.21  1.678 0.22  0.214
 1.637 1.77  0.524 0.524 0.152 0.166 0.212]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.22  0.184 0.835 0.22  0.193 1.33  0.22  0.21  1.682 0.22  0.214
 1.638 1.772 0.523 0.523 0.152 0.165 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.219 0.184 0.836 0.219 0.193 1.327 0.219 0.21  1.678 0.219 0.213
 1.634 1.768 0.522 0.522 0.151 0.165 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.219 0.183 0.835 0.219 0.193 1.325 0.219 0.209 1.674 0.219 0.213
 1.631 1.765 0.521 0.521 0.151 0.164 0.211]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.219 0.183 0.836 0.219 0.193 1.322 0.219 0.209 1.674 0.219 0.213
 1.629 1.765 0.52  0.52  0.151 0.165 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.219 0.183 0.834 0.219 0.192 1.321 0.219 0.209 1.67  0.219 0.213
 1.625 1.762 0.519 0.519 0.151 0.164 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.219 0.183 0.836 0.219 0.192 1.322 0.219 0.209 1.682 0.219 0.212
 1.639 1.774 0.521 0.521 0.151 0.164 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.219 0.182 0.838 0.219 0.192 1.331 0.219 0.209 1.687 0.219 0.212
 1.644 1.779 0.522 0.522 0.151 0.164 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.219 0.183 0.837 0.219 0.192 1.329 0.219 0.209 1.686 0.219 0.212
 1.645 1.781 0.523 0.523 0.152 0.166 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.218 0.182 0.842 0.218 0.192 1.327 0.218 0.208 1.685 0.218 0.212
 1.644 1.78  0.522 0.522 0.151 0.166 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.218 0.182 0.841 0.218 0.192 1.324 0.218 0.208 1.684 0.218 0.212
 1.644 1.779 0.521 0.522 0.151 0.166 0.21 ]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.219 0.182 0.842 0.219 0.192 1.327 0.219 0.208 1.685 0.219 0.212
 1.643 1.779 0.522 0.522 0.151 0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.218 0.182 0.84  0.218 0.191 1.324 0.218 0.208 1.681 0.218 0.211
 1.64  1.775 0.521 0.521 0.151 0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.22  0.183 0.845 0.22  0.193 1.33  0.22  0.21  1.691 0.22  0.213
 1.65  1.784 0.524 0.524 0.15  0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.221 0.183 0.849 0.221 0.193 1.331 0.221 0.21  1.691 0.221 0.214
 1.649 1.784 0.524 0.524 0.15  0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.221 0.183 0.85  0.221 0.193 1.332 0.221 0.21  1.691 0.221 0.213
 1.655 1.789 0.524 0.525 0.15  0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.222 0.184 0.861 0.222 0.194 1.349 0.222 0.211 1.706 0.222 0.215
 1.683 1.816 0.531 0.531 0.15  0.165 0.209]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.221 0.184 0.859 0.221 0.194 1.347 0.221 0.211 1.703 0.221 0.214
 1.68  1.813 0.53  0.53  0.15  0.165 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.221 0.184 0.861 0.221 0.194 1.347 0.221 0.211 1.707 0.221 0.214
 1.686 1.817 0.53  0.53  0.15  0.165 0.208]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.221 0.184 0.86  0.221 0.193 1.346 0.221 0.21  1.706 0.221 0.214
 1.683 1.814 0.529 0.529 0.15  0.164 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.22  0.183 0.858 0.22  0.193 1.344 0.22  0.21  1.702 0.22  0.214
 1.68  1.811 0.528 0.528 0.149 0.164 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.22  0.183 0.858 0.22  0.193 1.342 0.22  0.21  1.705 0.22  0.214
 1.684 1.814 0.528 0.528 0.149 0.164 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.22  0.183 0.857 0.22  0.193 1.339 0.22  0.21  1.702 0.22  0.213
 1.681 1.81  0.527 0.527 0.149 0.163 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.219 0.183 0.86  0.219 0.192 1.342 0.219 0.209 1.702 0.219 0.213
 1.679 1.81  0.526 0.527 0.149 0.163 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.22  0.183 0.86  0.22  0.193 1.348 0.22  0.209 1.703 0.22  0.213
 1.68  1.813 0.527 0.527 0.149 0.164 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.219 0.182 0.859 0.219 0.192 1.346 0.219 0.209 1.701 0.219 0.212
 1.677 1.81  0.526 0.526 0.149 0.163 0.207]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.219 0.182 0.858 0.219 0.192 1.345 0.219 0.209 1.702 0.219 0.212
 1.679 1.811 0.525 0.525 0.149 0.163 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.22  0.182 0.856 0.22  0.192 1.344 0.22  0.209 1.698 0.22  0.212
 1.675 1.807 0.524 0.524 0.149 0.164 0.206]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.219 0.182 0.855 0.219 0.192 1.345 0.219 0.209 1.695 0.219 0.212
 1.673 1.805 0.524 0.525 0.179 0.191 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.219 0.182 0.854 0.219 0.192 1.344 0.219 0.209 1.698 0.219 0.212
 1.676 1.807 0.525 0.525 0.179 0.191 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.218 0.182 0.853 0.218 0.192 1.341 0.218 0.208 1.695 0.218 0.212
 1.673 1.804 0.524 0.524 0.178 0.19  0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.218 0.182 0.851 0.218 0.191 1.339 0.218 0.208 1.692 0.218 0.211
 1.67  1.801 0.523 0.523 0.178 0.19  0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.218 0.181 0.851 0.218 0.191 1.342 0.218 0.208 1.693 0.218 0.211
 1.671 1.802 0.523 0.523 0.178 0.19  0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.217 0.181 0.85  0.217 0.191 1.341 0.217 0.207 1.691 0.217 0.211
 1.669 1.8   0.522 0.522 0.178 0.189 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.218 0.181 0.849 0.218 0.191 1.339 0.218 0.208 1.689 0.218 0.211
 1.669 1.8   0.522 0.522 0.178 0.191 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.218 0.181 0.848 0.218 0.191 1.344 0.218 0.207 1.693 0.218 0.211
 1.673 1.802 0.522 0.523 0.178 0.19  0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.218 0.181 0.847 0.218 0.191 1.343 0.218 0.207 1.691 0.218 0.211
 1.671 1.8   0.522 0.522 0.178 0.19  0.225]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.218 0.181 0.845 0.218 0.191 1.342 0.218 0.207 1.689 0.218 0.211
 1.671 1.799 0.522 0.522 0.178 0.191 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.217 0.181 0.845 0.217 0.19  1.341 0.217 0.207 1.686 0.217 0.21
 1.668 1.796 0.521 0.521 0.178 0.191 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.217 0.18  0.844 0.217 0.19  1.34  0.217 0.207 1.684 0.217 0.21
 1.665 1.793 0.52  0.52  0.178 0.19  0.226]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.216 0.18  0.843 0.216 0.19  1.337 0.216 0.206 1.68  0.216 0.21
 1.662 1.789 0.519 0.519 0.177 0.19  0.226]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.216 0.18  0.841 0.216 0.189 1.336 0.216 0.206 1.677 0.216 0.209
 1.659 1.786 0.518 0.518 0.177 0.19  0.225]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.216 0.179 0.84  0.216 0.189 1.334 0.216 0.205 1.677 0.216 0.209
 1.659 1.786 0.518 0.518 0.177 0.189 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.215 0.179 0.838 0.215 0.189 1.332 0.215 0.205 1.674 0.215 0.209
 1.656 1.782 0.517 0.517 0.177 0.189 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.215 0.179 0.84  0.215 0.188 1.333 0.215 0.205 1.673 0.215 0.208
 1.653 1.78  0.517 0.517 0.176 0.189 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.214 0.178 0.839 0.214 0.188 1.338 0.214 0.204 1.671 0.214 0.208
 1.65  1.777 0.516 0.516 0.176 0.189 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.216 0.18  0.841 0.216 0.19  1.34  0.216 0.206 1.674 0.216 0.209
 1.665 1.791 0.519 0.519 0.177 0.189 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.216 0.18  0.84  0.216 0.19  1.339 0.216 0.206 1.673 0.216 0.209
 1.666 1.791 0.519 0.519 0.177 0.189 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.216 0.18  0.841 0.216 0.189 1.34  0.216 0.206 1.672 0.216 0.209
 1.664 1.79  0.519 0.519 0.176 0.189 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.215 0.179 0.839 0.215 0.189 1.338 0.215 0.205 1.669 0.215 0.209
 1.662 1.787 0.518 0.518 0.176 0.188 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.215 0.18  0.842 0.215 0.189 1.343 0.215 0.206 1.671 0.215 0.209
 1.663 1.787 0.519 0.519 0.176 0.188 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.215 0.179 0.84  0.215 0.189 1.341 0.215 0.205 1.668 0.215 0.209
 1.66  1.784 0.518 0.518 0.176 0.188 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.215 0.179 0.841 0.215 0.189 1.344 0.215 0.205 1.669 0.215 0.209
 1.661 1.784 0.518 0.518 0.175 0.187 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.215 0.179 0.842 0.215 0.189 1.345 0.215 0.205 1.669 0.215 0.209
 1.661 1.784 0.518 0.518 0.175 0.187 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.215 0.179 0.84  0.215 0.188 1.343 0.215 0.205 1.666 0.215 0.208
 1.658 1.781 0.517 0.517 0.175 0.187 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.214 0.179 0.84  0.214 0.188 1.345 0.214 0.205 1.664 0.214 0.208
 1.655 1.778 0.516 0.516 0.174 0.187 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.214 0.179 0.839 0.214 0.188 1.343 0.214 0.205 1.662 0.214 0.208
 1.653 1.775 0.516 0.515 0.174 0.186 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.214 0.178 0.838 0.214 0.188 1.341 0.214 0.204 1.659 0.214 0.208
 1.65  1.772 0.515 0.514 0.174 0.186 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.214 0.178 0.836 0.214 0.188 1.339 0.214 0.204 1.656 0.214 0.207
 1.647 1.769 0.514 0.514 0.174 0.186 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.214 0.178 0.839 0.214 0.188 1.344 0.214 0.204 1.656 0.214 0.207
 1.645 1.768 0.514 0.514 0.174 0.187 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.214 0.179 0.88  0.214 0.188 1.341 0.214 0.205 1.661 0.214 0.208
 1.651 1.774 0.516 0.516 0.174 0.186 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.214 0.179 0.879 0.214 0.188 1.339 0.214 0.204 1.658 0.214 0.207
 1.648 1.771 0.515 0.515 0.174 0.186 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.214 0.178 0.877 0.214 0.188 1.337 0.214 0.204 1.656 0.214 0.207
 1.646 1.769 0.514 0.514 0.174 0.186 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.213 0.178 0.876 0.213 0.187 1.335 0.213 0.204 1.653 0.213 0.207
 1.644 1.766 0.514 0.513 0.173 0.185 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.213 0.178 0.876 0.213 0.187 1.334 0.213 0.203 1.652 0.213 0.206
 1.642 1.765 0.513 0.513 0.173 0.185 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.213 0.178 0.875 0.213 0.187 1.332 0.213 0.203 1.65  0.213 0.206
 1.64  1.762 0.512 0.512 0.173 0.185 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.213 0.177 0.876 0.213 0.187 1.332 0.213 0.203 1.649 0.213 0.206
 1.639 1.761 0.512 0.512 0.173 0.185 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.215 0.179 0.887 0.215 0.188 1.345 0.215 0.204 1.661 0.215 0.207
 1.671 1.789 0.52  0.52  0.172 0.185 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.214 0.178 0.888 0.214 0.188 1.346 0.214 0.204 1.661 0.214 0.207
 1.669 1.788 0.52  0.52  0.172 0.185 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.214 0.178 0.886 0.214 0.187 1.345 0.214 0.204 1.659 0.214 0.207
 1.667 1.785 0.519 0.519 0.172 0.184 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.219 0.182 0.905 0.219 0.191 1.36  0.219 0.209 1.684 0.219 0.212
 1.693 1.808 0.532 0.532 0.173 0.185 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.594 0.232 0.196 0.912 0.232 0.206 1.358 0.232 0.224 1.694 0.232 0.226
 1.718 1.841 0.55  0.551 0.178 0.189 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.596 0.234 0.197 0.915 0.234 0.207 1.359 0.234 0.225 1.695 0.234 0.227
 1.719 1.841 0.552 0.552 0.186 0.199 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.596 0.234 0.197 0.914 0.234 0.206 1.357 0.234 0.225 1.693 0.234 0.227
 1.72  1.842 0.552 0.553 0.186 0.199 0.226]


  return torch.load(os.path.join(self.model_dir,


[0.596 0.234 0.197 0.913 0.234 0.206 1.357 0.234 0.224 1.69  0.234 0.227
 1.718 1.839 0.551 0.552 0.186 0.198 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.595 0.233 0.197 0.912 0.233 0.206 1.355 0.233 0.224 1.689 0.233 0.227
 1.717 1.837 0.55  0.551 0.185 0.198 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.594 0.233 0.196 0.912 0.233 0.206 1.353 0.233 0.224 1.687 0.233 0.226
 1.715 1.835 0.55  0.55  0.185 0.198 0.225]


  return torch.load(os.path.join(self.model_dir,


[0.593 0.233 0.196 0.911 0.233 0.205 1.352 0.233 0.224 1.684 0.233 0.226
 1.712 1.833 0.549 0.55  0.185 0.197 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.592 0.233 0.196 0.909 0.233 0.205 1.35  0.233 0.223 1.681 0.233 0.226
 1.71  1.83  0.548 0.549 0.185 0.197 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.591 0.232 0.196 0.908 0.232 0.205 1.349 0.232 0.223 1.679 0.232 0.225
 1.707 1.827 0.547 0.548 0.184 0.197 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.59  0.232 0.195 0.906 0.232 0.205 1.347 0.232 0.223 1.677 0.232 0.225
 1.705 1.825 0.546 0.547 0.184 0.197 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.59  0.232 0.195 0.905 0.232 0.205 1.347 0.232 0.223 1.677 0.232 0.225
 1.706 1.825 0.546 0.547 0.184 0.197 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.232 0.195 0.904 0.232 0.204 1.346 0.232 0.222 1.675 0.232 0.225
 1.704 1.823 0.546 0.547 0.184 0.196 0.224]


  return torch.load(os.path.join(self.model_dir,


[0.589 0.232 0.195 0.904 0.232 0.204 1.345 0.232 0.222 1.674 0.232 0.225
 1.706 1.824 0.546 0.547 0.184 0.196 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.588 0.231 0.195 0.903 0.231 0.204 1.343 0.231 0.222 1.672 0.231 0.224
 1.703 1.822 0.546 0.546 0.183 0.196 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.231 0.195 0.902 0.231 0.204 1.346 0.231 0.222 1.68  0.231 0.224
 1.712 1.831 0.547 0.548 0.183 0.196 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.232 0.195 0.902 0.232 0.204 1.345 0.232 0.222 1.685 0.232 0.225
 1.718 1.836 0.55  0.551 0.183 0.196 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.588 0.232 0.195 0.902 0.232 0.204 1.347 0.232 0.222 1.685 0.232 0.224
 1.717 1.835 0.55  0.55  0.183 0.196 0.223]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.231 0.195 0.901 0.231 0.204 1.346 0.231 0.222 1.682 0.231 0.224
 1.714 1.832 0.549 0.549 0.183 0.195 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.587 0.231 0.194 0.899 0.231 0.203 1.344 0.231 0.221 1.68  0.231 0.224
 1.712 1.829 0.548 0.549 0.183 0.195 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.586 0.231 0.194 0.898 0.231 0.203 1.342 0.231 0.221 1.677 0.231 0.223
 1.709 1.826 0.547 0.548 0.182 0.195 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.585 0.23  0.194 0.897 0.23  0.203 1.34  0.23  0.221 1.675 0.23  0.223
 1.706 1.823 0.546 0.547 0.182 0.194 0.222]


  return torch.load(os.path.join(self.model_dir,


[0.584 0.23  0.193 0.896 0.23  0.202 1.338 0.23  0.22  1.673 0.23  0.223
 1.704 1.821 0.545 0.546 0.182 0.194 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.23  0.193 0.895 0.23  0.202 1.336 0.23  0.22  1.67  0.23  0.222
 1.701 1.818 0.545 0.545 0.181 0.194 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.583 0.229 0.193 0.894 0.229 0.202 1.339 0.229 0.22  1.676 0.229 0.222
 1.706 1.823 0.545 0.545 0.181 0.194 0.221]


  return torch.load(os.path.join(self.model_dir,


[0.582 0.229 0.193 0.893 0.229 0.201 1.337 0.229 0.219 1.675 0.229 0.222
 1.705 1.822 0.544 0.544 0.181 0.193 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.581 0.229 0.192 0.892 0.229 0.201 1.335 0.229 0.219 1.672 0.229 0.221
 1.703 1.819 0.543 0.544 0.181 0.193 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.228 0.192 0.89  0.228 0.201 1.333 0.228 0.219 1.67  0.228 0.221
 1.702 1.818 0.542 0.543 0.18  0.193 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.58  0.228 0.192 0.889 0.228 0.201 1.336 0.228 0.219 1.668 0.228 0.221
 1.699 1.815 0.542 0.542 0.18  0.193 0.22 ]


  return torch.load(os.path.join(self.model_dir,


[0.579 0.228 0.191 0.888 0.228 0.2   1.334 0.228 0.218 1.669 0.228 0.221
 1.705 1.826 0.541 0.542 0.18  0.192 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.229 0.192 0.887 0.229 0.201 1.333 0.229 0.219 1.67  0.229 0.221
 1.707 1.827 0.542 0.542 0.18  0.192 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.228 0.192 0.886 0.228 0.201 1.331 0.228 0.219 1.667 0.228 0.221
 1.705 1.825 0.541 0.542 0.18  0.192 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.228 0.192 0.884 0.228 0.201 1.33  0.228 0.218 1.665 0.228 0.221
 1.703 1.822 0.54  0.541 0.179 0.192 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.228 0.192 0.884 0.228 0.201 1.328 0.228 0.219 1.668 0.228 0.221
 1.706 1.824 0.542 0.543 0.179 0.191 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.228 0.192 0.882 0.228 0.201 1.326 0.228 0.219 1.666 0.228 0.221
 1.704 1.821 0.541 0.542 0.179 0.191 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.228 0.191 0.881 0.228 0.2   1.325 0.228 0.218 1.663 0.228 0.221
 1.701 1.819 0.541 0.541 0.179 0.191 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.578 0.227 0.191 0.882 0.227 0.2   1.326 0.227 0.218 1.661 0.227 0.22
 1.699 1.816 0.541 0.541 0.179 0.191 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.577 0.227 0.191 0.881 0.227 0.2   1.324 0.227 0.218 1.659 0.227 0.22
 1.697 1.814 0.54  0.541 0.178 0.191 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.227 0.191 0.88  0.227 0.2   1.324 0.227 0.217 1.657 0.227 0.22
 1.694 1.811 0.539 0.54  0.178 0.191 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.576 0.227 0.19  0.879 0.227 0.199 1.322 0.227 0.217 1.655 0.227 0.219
 1.692 1.809 0.538 0.539 0.178 0.19  0.218]


  return torch.load(os.path.join(self.model_dir,


[0.575 0.226 0.19  0.877 0.226 0.199 1.321 0.226 0.217 1.653 0.226 0.219
 1.69  1.807 0.538 0.538 0.178 0.19  0.217]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.226 0.19  0.876 0.226 0.199 1.319 0.226 0.216 1.65  0.226 0.219
 1.688 1.804 0.537 0.537 0.177 0.19  0.217]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.226 0.19  0.875 0.226 0.199 1.32  0.226 0.216 1.648 0.226 0.219
 1.685 1.802 0.536 0.537 0.177 0.19  0.217]


  return torch.load(os.path.join(self.model_dir,


[0.574 0.226 0.19  0.875 0.226 0.198 1.332 0.226 0.216 1.647 0.226 0.218
 1.708 1.819 0.536 0.537 0.177 0.19  0.219]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.225 0.189 0.874 0.225 0.198 1.332 0.225 0.216 1.645 0.225 0.218
 1.706 1.816 0.536 0.536 0.177 0.189 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.573 0.225 0.189 0.873 0.225 0.198 1.33  0.225 0.215 1.643 0.225 0.218
 1.704 1.814 0.535 0.536 0.177 0.189 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.225 0.189 0.873 0.225 0.198 1.328 0.225 0.215 1.642 0.225 0.217
 1.702 1.813 0.535 0.535 0.177 0.189 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.225 0.189 0.872 0.225 0.197 1.327 0.225 0.215 1.64  0.225 0.217
 1.7   1.81  0.534 0.535 0.176 0.189 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.224 0.188 0.871 0.224 0.197 1.326 0.224 0.215 1.638 0.224 0.217
 1.698 1.808 0.533 0.534 0.176 0.189 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.228 0.192 0.873 0.228 0.2   1.325 0.228 0.219 1.651 0.228 0.221
 1.714 1.822 0.538 0.539 0.176 0.188 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.572 0.229 0.192 0.872 0.229 0.201 1.324 0.229 0.219 1.653 0.229 0.221
 1.716 1.824 0.54  0.541 0.176 0.188 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.571 0.229 0.192 0.872 0.229 0.201 1.322 0.229 0.219 1.653 0.229 0.221
 1.716 1.823 0.539 0.54  0.176 0.188 0.219]


  return torch.load(os.path.join(self.model_dir,


[0.57  0.228 0.192 0.871 0.228 0.201 1.321 0.228 0.218 1.65  0.228 0.221
 1.714 1.821 0.539 0.539 0.175 0.188 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.228 0.192 0.87  0.228 0.2   1.319 0.228 0.218 1.648 0.228 0.221
 1.711 1.818 0.538 0.539 0.175 0.187 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.569 0.228 0.191 0.868 0.228 0.2   1.318 0.228 0.218 1.646 0.228 0.22
 1.709 1.816 0.537 0.538 0.175 0.187 0.218]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.228 0.191 0.868 0.228 0.2   1.317 0.228 0.218 1.644 0.228 0.22
 1.707 1.814 0.536 0.537 0.175 0.187 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.227 0.191 0.867 0.227 0.2   1.315 0.227 0.218 1.642 0.227 0.22
 1.704 1.811 0.536 0.536 0.175 0.187 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.227 0.191 0.866 0.227 0.2   1.314 0.227 0.217 1.639 0.227 0.22
 1.702 1.809 0.535 0.536 0.174 0.187 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.568 0.227 0.191 0.865 0.227 0.199 1.316 0.227 0.217 1.637 0.227 0.22
 1.7   1.806 0.535 0.535 0.174 0.186 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.567 0.227 0.19  0.864 0.227 0.199 1.314 0.227 0.217 1.635 0.227 0.219
 1.697 1.804 0.534 0.535 0.174 0.186 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.566 0.226 0.19  0.863 0.226 0.199 1.313 0.226 0.217 1.633 0.226 0.219
 1.695 1.801 0.533 0.534 0.174 0.186 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.226 0.19  0.862 0.226 0.199 1.311 0.226 0.216 1.631 0.226 0.219
 1.693 1.799 0.533 0.533 0.174 0.187 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.226 0.19  0.862 0.226 0.198 1.31  0.226 0.216 1.629 0.226 0.218
 1.691 1.797 0.532 0.532 0.174 0.186 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.565 0.226 0.19  0.861 0.226 0.198 1.311 0.226 0.216 1.629 0.226 0.218
 1.698 1.804 0.533 0.534 0.174 0.186 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.564 0.226 0.189 0.86  0.226 0.198 1.31  0.226 0.216 1.627 0.226 0.218
 1.696 1.802 0.533 0.533 0.174 0.186 0.216]


  return torch.load(os.path.join(self.model_dir,


[0.563 0.225 0.189 0.859 0.225 0.198 1.308 0.225 0.216 1.625 0.225 0.218
 1.697 1.802 0.532 0.533 0.173 0.186 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.562 0.225 0.189 0.859 0.225 0.198 1.312 0.225 0.215 1.624 0.225 0.218
 1.698 1.803 0.532 0.533 0.173 0.186 0.217]


  return torch.load(os.path.join(self.model_dir,


[0.562 0.225 0.189 0.858 0.225 0.197 1.313 0.225 0.215 1.622 0.225 0.217
 1.696 1.801 0.532 0.532 0.173 0.185 0.216]



KeyboardInterrupt



In [19]:
res = tuple(np.array(x) for x in zip(*result_list2))
truth = res[-1:]
res_dict = {}

res_list_temp = []
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_list_temp.append(point)

In [20]:
result_list_final = np.concatenate([np.array(result_list), np.array(res_list_temp).T], axis=1)

In [21]:
np.savetxt("ATE_result3.csv", result_list_final)

In [22]:
np.round(np.sqrt(np.mean((result_list_final - true_ATE)**2, axis=0)), 3)

array([0.562, 0.225, 0.189, 0.858, 0.225, 0.197, 1.313, 0.225, 0.215,
       1.622, 0.225, 0.217, 1.696, 1.801, 0.532, 0.532, 0.173, 0.185,
       0.216])

In [23]:
len(result_list_final)

367

In [23]:
np.array(result_list)

array([[2.33420245, 2.69248963, 2.78836369, 2.05049704, 2.69248963,
        2.78292243, 1.17453   , 2.69248963, 2.72538244, 1.14573749,
        2.69248963, 2.73356029, 0.1674897 , 0.17318048, 2.18604934,
        2.17192884],
       [2.90344812, 3.01980062, 3.00719726, 2.72691919, 3.01980062,
        3.00100208, 2.71324345, 3.01980062, 3.00840853, 2.78409739,
        3.01980062, 2.99738415, 2.88539763, 2.76368037, 2.96821651,
        2.95634584]])

In [49]:
res = tuple(np.array(x) for x in zip(*result_list2))
truth = res[-1:]
res_dict = {}

res_list_temp = []
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_list_temp.append(point)

In [65]:
np.concatenate([np.array(result_list), np.array(res_list_temp).T], axis=1)

array([[3.13824887, 2.98757769, 3.00656468, 2.93689259, 2.98757769,
        2.99441305, 3.21834824, 2.98757769, 2.99590128, 3.08447467,
        2.98757769, 2.99751422, 2.96991877, 2.37510366, 2.98757769,
        2.98152865, 2.98207014, 2.99127682, 2.94388743],
       [3.51725208, 3.55584463, 3.4887395 , 3.46519865, 3.55584463,
        3.50424109, 5.49242025, 3.55584463, 3.54794203, 4.19640302,
        3.55584463, 3.53464643, 4.90452631, 4.88247822, 3.55584463,
        3.56359317, 3.10494999, 3.1386278 , 3.01535168]])

In [60]:
np.array(res_list_temp).T.shape

(2, 3)

In [46]:
res_dict

{'dr': {'point': array([2.98207014, 3.10494999]),
  'lb': array([2.89669808, 3.02416755]),
  'ub': array([3.0674422 , 3.18573244]),
  'cov': 0.5,
  'bias': 0.043510066762107336},
 'direct': {'point': array([2.99127682, 3.1386278 ]),
  'lb': array([2.96516755, 3.11051796]),
  'ub': array([3.01738608, 3.16673764]),
  'cov': 0.5,
  'bias': 0.06495230559765441},
 'ips': {'point': array([2.94388743, 3.01535168]),
  'lb': array([2.29009494, 2.70514945]),
  'ub': array([3.59767992, 3.3255539 ]),
  'cov': 1.0,
  'bias': -0.020380447898789456}}

In [58]:
np.array(result_list).shape

(2, 16)

In [186]:
result_list = []

for tr in range(num_trial):
    result_list_temp = []
    
    
    X = np.random.normal(0, 1, (n, p))

    # Define a propensity score model
    # Assume treatment probability is a sigmoid function of a subset of covariates
    X_temp = np.concatenate([X, X**2, np.array([X[:, 0]*X[:, 1], X[:, 1]*X[:, 2], X[:, 0]*X[:, 2]]).T], axis=1)
    propensity_coef = np.random.normal(0, 0.5, X_temp.shape[1])
    propensity_scores = expit(X_temp @ propensity_coef)  # Calculate propensity scores

    # Generate treatment assignment based on propensity scores
    T = np.random.binomial(1, propensity_scores)

    # Generate outcome with treatment effect
    # Assume a simple linear model for demonstration
    beta = np.random.normal(0, 1, p)
    Y = (X @ beta)**2 + 1.1 + treatment_effect * T + np.random.normal(0, 1, n)

    X_treatment = X[T == 1]
    X_control = X[T == 0]

    Y_treatment = Y[T == 1]
    Y_control = Y[T == 0]
    
    #### Direct bias correction
    prop_model = NeuralNetBiasCorrection(input_dim=p, lbd = 0.01, loss="DBCLS")
    prop_model.fit(X, T)
    est_prop_score = prop_model.predict_proba(X)[:, 1]
    est_prop_score_dbc = est_prop_score

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    #### Direct bias correction
    prop_model = NeuralNetBiasCorrection(input_dim=p, lbd = 0.01, loss="CBPS")
    prop_model.fit(X, T)
    est_prop_score = prop_model.predict_proba(X)[:, 1]
    est_prop_score_dbc = est_prop_score

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    ##### Linear models
    
    # Fit a linear model to estimate the treatment effect
    model = LinearRegression()
    model.fit(np.hstack([X, T.reshape(-1, 1)]), Y)
    estimated_treatment_effect = model.coef_[-1]

    # Evaluate performance
    true_ATE = treatment_effect
    bias = estimated_treatment_effect - true_ATE
    mse = mean_squared_error(Y, model.predict(np.hstack([X, T.reshape(-1, 1)])))

    print(f"Estimated ATE: {estimated_treatment_effect}")
    print(f"Bias: {bias}")
    print(f"Mean Squared Error: {mse}")

    result_list_temp.append(estimated_treatment_effect)
    
    #### Logistc regression 
    
    prop_model = NeuralNetBiasCorrection(input_dim=p, lbd = 0., loss="Logit")
    prop_model.fit(X, T)
    est_prop_score = prop_model.predict_proba(X)[:, 1]

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    #### CBPS
    
    # Enable automatic conversion of Pandas DataFrame to R DataFrame
    pandas2ri.activate()

    # Simulate data in Python

    # Create a pandas DataFrame
    column_names = [f'X{i+1}' for i in range(p)]
    df = pd.DataFrame(X, columns=column_names)
    df['T'] = T
    df['Y'] = Y


    # Convert pandas DataFrame to R DataFrame
    r_df = pandas2ri.py2rpy(df)

    ro.r.assign("p", p)

    # Load the CBPS package in R and fit the model for ATE estimation
    ro.r('''
        library(CBPS)
        estimate_cbps_ate <- function(df) {
            formula_str <- paste("T ~", paste(names(df)[1:{p}], collapse=" + "))

            # CBPSの適用 (ATEの推定、ATT=0)
            model <- CBPS(as.formula(formula_str), data = df, ATT = 0, method = "exact")

            # 推定された傾向スコアの取得
            df$propensity_score <- fitted(model)

            # IPW (Inverse Probability Weighting) を適用
            df$weight <- ifelse(df$T == 1, 1 / df$propensity_score, 1 / (1 - df$propensity_score))

            # 重み付き回帰によるATEの推定
            result <- lm(Y ~ T, data = df, weights = df$weight)

            return(df$propensity_score)
        }
    ''')

    # R関数を呼び出してATEと傾向スコアを取得
    est_prop_score = ro.r['estimate_cbps_ate'](r_df)
    
    est_prop_score_cbps = est_prop_score
    
    #print(er)

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    
    result_list.append(result_list_temp)

Estimated ATE: 3.666468677704666
Bias: 0.6664686777046658
Estimated ATE: 3.197314788090588
Bias: 0.1973147880905879
Estimated ATE: 3.1961531208078506
Bias: 0.19615312080785063
Estimated ATE: 4.601117985662866
Bias: 1.6011179856628663
Estimated ATE: 3.197314788090588
Bias: 0.1973147880905879
Estimated ATE: 3.21044721347009
Bias: 0.21044721347008988
Estimated ATE: 3.3286003541844162
Bias: 0.32860035418441624
Mean Squared Error: 17.199976092090044
Estimated ATE: 3.5808537722716682
Bias: 0.5808537722716682
Estimated ATE: 3.197314788090588
Bias: 0.1973147880905879
Estimated ATE: 3.194816434072869
Bias: 0.19481643407286908
Estimated ATE: 3.3077861256136964
Bias: 0.30778612561369645
Estimated ATE: 3.197314788090588
Bias: 0.1973147880905879
Estimated ATE: 3.2033339822009474
Bias: 0.20333398220094745
Estimated ATE: 3.1291984672441298
Bias: 0.12919846724412976
Estimated ATE: 2.9329038786413597
Bias: -0.06709612135864029
Estimated ATE: 2.957972282635173
Bias: -0.04202771736482713
Estimated ATE: 3

Estimated ATE: 4.480063581477346
Bias: 1.4800635814773457
Estimated ATE: 3.211012872255794
Bias: 0.21101287225579402
Estimated ATE: 3.1660743637380966
Bias: 0.1660743637380966
Estimated ATE: 2.4853513152555715
Bias: -0.5146486847444285
Mean Squared Error: 15.831362221934224
Estimated ATE: 3.3761699274878327
Bias: 0.37616992748783273
Estimated ATE: 3.211012872255794
Bias: 0.21101287225579402
Estimated ATE: 3.174281498886304
Bias: 0.17428149888630395
Estimated ATE: 2.578085457490623
Bias: -0.4219145425093771
Estimated ATE: 3.211012872255794
Bias: 0.21101287225579402
Estimated ATE: 3.113174035127961
Bias: 0.11317403512796087
Estimated ATE: 2.9283066741281996
Bias: -0.07169332587180044
Estimated ATE: 3.042952440901091
Bias: 0.04295244090109085
Estimated ATE: 3.0299444571977516
Bias: 0.02994445719775163
Estimated ATE: 2.549836212862722
Bias: -0.4501637871372779
Estimated ATE: 3.042952440901091
Bias: 0.04295244090109085
Estimated ATE: 3.037379304998199
Bias: 0.03737930499819919
Estimated ATE

Estimated ATE: 3.361435236237726
Bias: 0.361435236237726
Estimated ATE: 3.3697215452767826
Bias: 0.3697215452767826
Estimated ATE: 3.319099506579724
Bias: 0.31909950657972397
Estimated ATE: 4.534702122759686
Bias: 1.5347021227596862
Estimated ATE: 3.3697215452767826
Bias: 0.3697215452767826
Estimated ATE: 3.3735313947034604
Bias: 0.3735313947034604
Estimated ATE: 3.231489395949986
Bias: 0.23148939594998597
Estimated ATE: 3.8223027174408695
Bias: 0.8223027174408695
Estimated ATE: 3.7102635797078722
Bias: 0.7102635797078722
Estimated ATE: 7.016509392524352
Bias: 4.016509392524352
Estimated ATE: 3.8223027174408695
Bias: 0.8223027174408695
Estimated ATE: 3.8036748897198662
Bias: 0.8036748897198662
Estimated ATE: 7.088392945419314
Bias: 4.088392945419314
Mean Squared Error: 59.38350448675715
Estimated ATE: 4.262228979372108
Bias: 1.2622289793721082
Estimated ATE: 3.8223027174408695
Bias: 0.8223027174408695
Estimated ATE: 3.7282271290617235
Bias: 0.7282271290617235
Estimated ATE: 7.032412773

Estimated ATE: 8.741429384382585
Bias: 5.7414293843825845
Estimated ATE: 4.017896885419382
Bias: 1.0178968854193817
Estimated ATE: 4.020466890611351
Bias: 1.0204668906113508
Estimated ATE: 2.3983881022073796
Bias: -0.6016118977926204
Estimated ATE: 2.5349538900006725
Bias: -0.4650461099993275
Estimated ATE: 2.6917158841836373
Bias: -0.3082841158163627
Estimated ATE: 1.911896267480981
Bias: -1.088103732519019
Estimated ATE: 2.5349538900006725
Bias: -0.4650461099993275
Estimated ATE: 2.6186504261357295
Bias: -0.3813495738642705
Estimated ATE: 1.8105533280209705
Bias: -1.1894466719790295
Mean Squared Error: 6.351877588816076
Estimated ATE: 2.4779278066888395
Bias: -0.5220721933111605
Estimated ATE: 2.5349538900006725
Bias: -0.4650461099993275
Estimated ATE: 2.696285958602931
Bias: -0.30371404139706915
Estimated ATE: 1.7560116381390043
Bias: -1.2439883618609957
Estimated ATE: 2.5349538900006725
Bias: -0.4650461099993275
Estimated ATE: 2.5931690073033575
Bias: -0.40683099269664247
Estimated

Estimated ATE: 4.320964480125358
Bias: 1.3209644801253582
Estimated ATE: 2.9509287823447377
Bias: -0.049071217655262345
Estimated ATE: 3.01881200736428
Bias: 0.018812007364279903
Estimated ATE: 4.100405569002077
Bias: 1.1004055690020769
Estimated ATE: 2.9509287823447377
Bias: -0.049071217655262345
Estimated ATE: 2.9947901591878034
Bias: -0.005209840812196553
Estimated ATE: 2.1863522079864794
Bias: -0.8136477920135206
Mean Squared Error: 89.10015774022669
Estimated ATE: 3.4903545864797114
Bias: 0.4903545864797114
Estimated ATE: 2.9509287823447377
Bias: -0.049071217655262345
Estimated ATE: 2.9996107426101135
Bias: -0.00038925738988648817
Estimated ATE: 2.38228835952334
Bias: -0.61771164047666
Estimated ATE: 2.9509287823447377
Bias: -0.049071217655262345
Estimated ATE: 2.9732482794736774
Bias: -0.026751720526322575
Estimated ATE: 3.354539764877999
Bias: 0.3545397648779991
Estimated ATE: 2.733476919144341
Bias: -0.2665230808556589
Estimated ATE: 2.7709241344686806
Bias: -0.2290758655313194

Estimated ATE: 3.803234638640029
Bias: 0.8032346386400291
Estimated ATE: 3.0236413163437925
Bias: 0.023641316343792518
Estimated ATE: 2.9631953568719727
Bias: -0.03680464312802734
Estimated ATE: 2.854627946016926
Bias: -0.14537205398307407
Mean Squared Error: 7.055476810045092
Estimated ATE: 3.059604597465503
Bias: 0.05960459746550306
Estimated ATE: 3.0236413163437925
Bias: 0.023641316343792518
Estimated ATE: 2.9534848687660142
Bias: -0.04651513123398576
Estimated ATE: 2.868203055891288
Bias: -0.13179694410871212
Estimated ATE: 3.0236413163437925
Bias: 0.023641316343792518
Estimated ATE: 2.9438196787544353
Bias: -0.05618032124556471
Estimated ATE: 3.0342841562223923
Bias: 0.03428415622239234
Estimated ATE: 3.1711169219366946
Bias: 0.1711169219366946
Estimated ATE: 3.1228324165274914
Bias: 0.12283241652749144
Estimated ATE: 3.344726412093466
Bias: 0.3447264120934661
Estimated ATE: 3.1711169219366946
Bias: 0.1711169219366946
Estimated ATE: 3.155076565312471
Bias: 0.15507656531247083
Esti

Estimated ATE: 2.7555058982652807
Bias: -0.2444941017347193
Estimated ATE: 2.782118738278893
Bias: -0.21788126172110678
Estimated ATE: 2.839058984717843
Bias: -0.1609410152821571
Estimated ATE: 2.196029496876571
Bias: -0.8039705031234292
Estimated ATE: 2.782118738278893
Bias: -0.21788126172110678
Estimated ATE: 2.7635034041581306
Bias: -0.23649659584186944
Estimated ATE: 2.7016345063492233
Bias: -0.2983654936507767
Estimated ATE: 2.8629015075726243
Bias: -0.1370984924273757
Estimated ATE: 2.9360734029668585
Bias: -0.06392659703314152
Estimated ATE: 1.2658518190115575
Bias: -1.7341481809884425
Estimated ATE: 2.8629015075726243
Bias: -0.1370984924273757
Estimated ATE: 2.8823815331070146
Bias: -0.11761846689298538
Estimated ATE: 2.827482605718764
Bias: -0.17251739428123614
Mean Squared Error: 1.3058491802496193
Estimated ATE: 3.2082817997199164
Bias: 0.20828179971991645
Estimated ATE: 2.8629015075726243
Bias: -0.1370984924273757
Estimated ATE: 2.930958519731527
Bias: -0.0690414802684729
E

Estimated ATE: -7.626616994626293
Bias: -10.626616994626293
Estimated ATE: 0.5493879170123054
Bias: -2.4506120829876945
Estimated ATE: 0.5214340880225721
Bias: -2.478565911977428
Estimated ATE: 3.174375735416997
Bias: 0.17437573541699702
Estimated ATE: 3.171624034669089
Bias: 0.17162403466908893
Estimated ATE: 3.138903961513802
Bias: 0.1389039615138019
Estimated ATE: 3.3970353804063835
Bias: 0.39703538040638353
Estimated ATE: 3.171624034669089
Bias: 0.17162403466908893
Estimated ATE: 3.1694581143442555
Bias: 0.1694581143442555
Estimated ATE: 3.312909791460142
Bias: 0.3129097914601422
Mean Squared Error: 3.1748817727896337
Estimated ATE: 3.2690252617101723
Bias: 0.26902526171017227
Estimated ATE: 3.171624034669089
Bias: 0.17162403466908893
Estimated ATE: 3.1462490493555046
Bias: 0.14624904935550465
Estimated ATE: 3.3674646443363865
Bias: 0.3674646443363865
Estimated ATE: 3.171624034669089
Bias: 0.17162403466908893
Estimated ATE: 3.1736035353975205
Bias: 0.1736035353975205
Estimated ATE:

Estimated ATE: 3.1958239372274857
Bias: 0.1958239372274857
Estimated ATE: 3.3683085878248966
Bias: 0.3683085878248966
Estimated ATE: 3.3244554932132058
Bias: 0.32445549321320577
Estimated ATE: 3.880025474338841
Bias: 0.8800254743388409
Estimated ATE: 3.3683085878248966
Bias: 0.3683085878248966
Estimated ATE: 3.3560480929849485
Bias: 0.35604809298494855
Estimated ATE: 4.469033708084538
Bias: 1.4690337080845381
Mean Squared Error: 17.06598118618572
Estimated ATE: 3.429115238798844
Bias: 0.42911523879884417
Estimated ATE: 3.3683085878248966
Bias: 0.3683085878248966
Estimated ATE: 3.3300998188626822
Bias: 0.33009981886268225
Estimated ATE: 4.893334277708078
Bias: 1.8933342777080782
Estimated ATE: 3.3683085878248966
Bias: 0.3683085878248966
Estimated ATE: 3.3801812939730254
Bias: 0.38018129397302536
Estimated ATE: 2.6432255099688815
Bias: -0.3567744900311185
Estimated ATE: 2.9536208665436923
Bias: -0.04637913345630773
Estimated ATE: 2.9445478989458653
Bias: -0.05545210105413467
Estimated AT

Estimated ATE: 2.3965091235997074
Bias: -0.6034908764002926
Estimated ATE: 2.9192381457567183
Bias: -0.08076185424328175
Estimated ATE: 2.90939445182553
Bias: -0.09060554817446986
Estimated ATE: 2.893111902677576
Bias: -0.10688809732242399
Mean Squared Error: 1.0366134744464026
Estimated ATE: 2.705623064284074
Bias: -0.2943769357159258
Estimated ATE: 2.9192381457567183
Bias: -0.08076185424328175
Estimated ATE: 2.920454204148047
Bias: -0.0795457958519532
Estimated ATE: 2.8721781237812647
Bias: -0.1278218762187353
Estimated ATE: 2.9192381457567183
Bias: -0.08076185424328175
Estimated ATE: 2.8956861608843663
Bias: -0.10431383911563374
Estimated ATE: 1.0695797376900622
Bias: -1.9304202623099378
Estimated ATE: 2.371528751131591
Bias: -0.6284712488684092
Estimated ATE: 2.386074794911927
Bias: -0.6139252050880728
Estimated ATE: -0.4281587349993001
Bias: -3.4281587349993
Estimated ATE: 2.371528751131591
Bias: -0.6284712488684092
Estimated ATE: 2.355808554015459
Bias: -0.644191445984541
Estimat

In [187]:
np.round(np.sqrt(np.mean((np.array(result_list) - true_ATE)**2, axis=0)), 3)

array([0.813, 0.549, 0.485, 2.61 , 0.549, 0.541, 2.058, 1.126, 0.549,
       0.499, 2.168, 0.549, 0.555])

In [28]:
res = tuple(np.array(x) for x in zip(*result_list2))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

NameError: name 'rmse_fn' is not defined

In [178]:
np.round(np.sqrt(np.mean((np.array(result_list) - true_ATE)**2, axis=0)), 3)

array([0.979, 0.556, 0.503, 2.769, 0.556, 0.555, 1.681, 0.813, 0.556,
       0.513, 1.841, 0.556, 0.561])

In [137]:
np.concatenate([X, X**2], axis=1).shape

(3000, 6)

In [138]:
X_temp

array([[ 9.25924616e-01, -2.03031180e+00,  1.27520423e+00,
         8.57336395e-01,  4.12216601e+00,  1.62614582e+00],
       [-5.66775343e-01,  1.07887815e+00, -9.96268834e-01,
         3.21234289e-01,  1.16397806e+00,  9.92551589e-01],
       [-1.14120031e+00, -1.03892180e-01,  2.86252848e-01,
         1.30233815e+00,  1.07935851e-02,  8.19406930e-02],
       ...,
       [-3.20993243e-01,  1.19378072e+00, -9.92688808e-01,
         1.03036662e-01,  1.42511242e+00,  9.85431070e-01],
       [ 3.32821056e-01,  5.35663125e-02, -1.35952125e+00,
         1.10769855e-01,  2.86934983e-03,  1.84829802e+00],
       [-3.07757600e-01, -1.46619960e+00, -2.25171990e+00,
         9.47147402e-02,  2.14974127e+00,  5.07024251e+00]])

In [139]:
np.max(propensity_scores)

0.9999342940986149

In [140]:
np.min(propensity_scores)

0.47356072751750583

In [141]:
np.max(est_prop_score_dbc)

0.7247125

In [142]:
np.min(est_prop_score_dbc)

0.5067347

In [143]:
np.max(est_prop_score_cbps)

0.86754426622205

In [144]:
np.min(est_prop_score_cbps)

0.3840582792049322

In [145]:
np.round(np.sqrt(np.mean((np.array(result_list) - true_ATE)**2, axis=0)), 3)

array([0.81 , 0.584, 0.531, 4.821, 0.584, 0.607, 1.801, 0.939, 0.584,
       0.55 , 2.265, 0.584, 0.596])

In [175]:
print(f"True Average Treatment Effect (ATE): {true_ATE}")

True Average Treatment Effect (ATE): 3.0


In [176]:
# Fit a linear model to estimate the treatment effect
model = LinearRegression()
model.fit(np.hstack([X, T.reshape(-1, 1)]), Y)
estimated_treatment_effect = model.coef_[-1]

# Evaluate performance
true_ATE = treatment_effect
bias = estimated_treatment_effect - true_ATE
mse = mean_squared_error(Y, model.predict(np.hstack([X, T.reshape(-1, 1)])))

print(f"Estimated ATE: {estimated_treatment_effect}")
print(f"Bias: {bias}")
print(f"Mean Squared Error: {mse}")





Estimated ATE: 2.972625870114331
Bias: -0.027374129885668896
Mean Squared Error: 1.7797254450419215


In [177]:
# Fit a linear model to estimate the treatment effect
prop_model = LogisticRegression()
prop_model.fit(X, T)
est_prop_score = prop_model.predict_proba(X)[:, 1]

treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

treatment_outcome_model.fit(X_treatment, Y_treatment)
control_outcome_model.fit(X_control, Y_control)

est_treatment_outcome = treatment_outcome_model.predict(X)
est_control_outcome = control_outcome_model.predict(X)

IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

# Evaluate performance
IPW_bias = IPW_est - true_ATE

print(f"Estimated ATE: {IPW_est}")
print(f"Bias: {IPW_bias}")

DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

# Evaluate performance
DR_bias = DR_est - true_ATE

print(f"Estimated ATE: {DR_est}")
print(f"Bias: {DR_bias}")

DM_est = np.mean(est_treatment_outcome - est_control_outcome)

# Evaluate performance
DM_bias = DM_est - true_ATE

print(f"Estimated ATE: {DM_est}")
print(f"Bias: {DM_bias}")

Estimated ATE: 2.887012893946858
Bias: -0.11298710605314222
Estimated ATE: 2.9749627896566637
Bias: -0.02503721034333628
Estimated ATE: 2.97911273489089
Bias: -0.020887265109109876


In [178]:
                                                              result_list = []

for tr in range(num_trial):
    result_list_temp = []
    
    
    X = np.random.normal(0, 1, (n, p))

    # Define a propensity score model
    # Assume treatment probability is a sigmoid function of a subset of covariates
    propensity_coef = np.random.normal(0, 0.5, p)
    propensity_scores = expit(X @ propensity_coef)  # Calculate propensity scores

    # Generate treatment assignment based on propensity scores
    T = np.random.binomial(1, propensity_scores)

    # Generate outcome with treatment effect
    # Assume a simple linear model for demonstration
    beta = np.random.normal(0, 1, p)
    Y = (X @ beta)**2 + 1.1 + treatment_effect * T + np.random.normal(0, 1, n)

    X_treatment = X[T == 1]
    X_control = X[T == 0]

    Y_treatment = Y[T == 1]
    Y_control = Y[T == 0]
    
        #### Direct bias correction
    prop_model = DirectBiasCorrection(lbd = 0.01)
    
    prop_model.fit(X, T, Y)
    est_prop_score = prop_model.predict_proba(X)[:, 1]
    est_prop_score_dbc = est_prop_score

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    ##### Linear models
    
    # Fit a linear model to estimate the treatment effect
    model = LinearRegression()
    model.fit(np.hstack([X, T.reshape(-1, 1)]), Y)
    estimated_treatment_effect = model.coef_[-1]

    # Evaluate performance
    true_ATE = treatment_effect
    bias = estimated_treatment_effect - true_ATE
    mse = mean_squared_error(Y, model.predict(np.hstack([X, T.reshape(-1, 1)])))

    print(f"Estimated ATE: {estimated_treatment_effect}")
    print(f"Bias: {bias}")
    print(f"Mean Squared Error: {mse}")

    result_list_temp.append(estimated_treatment_effect)
    
    
    
    #### Logistc regression 
    
    prop_model = LogisticRegression()
    prop_model.fit(X, T)
    est_prop_score = prop_model.predict_proba(X)[:, 1]

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    

    
    
    #### CBPS
    
    # Enable automatic conversion of Pandas DataFrame to R DataFrame
    pandas2ri.activate()

    # Simulate data in Python

    # Create a pandas DataFrame
    column_names = [f'X{i+1}' for i in range(p)]
    df = pd.DataFrame(X, columns=column_names)
    df['T'] = T
    df['Y'] = Y


    # Convert pandas DataFrame to R DataFrame
    r_df = pandas2ri.py2rpy(df)

    ro.r.assign("p", p)

    # Load the CBPS package in R and fit the model for ATE estimation
    ro.r('''
        library(CBPS)
        estimate_cbps_ate <- function(df) {
            formula_str <- paste("T ~", paste(names(df)[1:{p}], collapse=" + "))

            # CBPSの適用 (ATEの推定、ATT=0)
            model <- CBPS(as.formula(formula_str), data = df, ATT = 0, method = "exact")

            # 推定された傾向スコアの取得
            df$propensity_score <- fitted(model)

            # IPW (Inverse Probability Weighting) を適用
            df$weight <- ifelse(df$T == 1, 1 / df$propensity_score, 1 / (1 - df$propensity_score))

            # 重み付き回帰によるATEの推定
            result <- lm(Y ~ T, data = df, weights = df$weight)

            return(df$propensity_score)
        }
    ''')

    # R関数を呼び出してATEと傾向スコアを取得
    est_prop_score = ro.r['estimate_cbps_ate'](r_df)
    
    est_prop_score_cbps = est_prop_score
    
    #print(er)

    treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

    treatment_outcome_model.fit(X_treatment, Y_treatment)
    control_outcome_model.fit(X_control, Y_control)

    est_treatment_outcome = treatment_outcome_model.predict(X)
    est_control_outcome = control_outcome_model.predict(X)

    IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

    # Evaluate performance
    IPW_bias = IPW_est - true_ATE

    print(f"Estimated ATE: {IPW_est}")
    print(f"Bias: {IPW_bias}")
    
    result_list_temp.append(IPW_est)
    
    DM_est = np.mean(est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DM_bias = DM_est - true_ATE

    print(f"Estimated ATE: {DM_est}")
    print(f"Bias: {DM_bias}")
    
    result_list_temp.append(DM_est)

    DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

    # Evaluate performance
    DR_bias = DR_est - true_ATE

    print(f"Estimated ATE: {DR_est}")
    print(f"Bias: {DR_bias}")
    
    result_list_temp.append(DR_est)
    
    
    result_list.append(result_list_temp)

In [179]:
# Fit a linear model to estimate the treatment effect
prop_model = DirectBiasCorrection()
prop_model.fit(X, T)
est_prop_score = prop_model.predict_proba(X)[:, 1]

treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

treatment_outcome_model.fit(X_treatment, Y_treatment)
control_outcome_model.fit(X_control, Y_control)

est_treatment_outcome = treatment_outcome_model.predict(X)
est_control_outcome = control_outcome_model.predict(X)

IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

# Evaluate performance
IPW_bias = IPW_est - true_ATE

print(f"Estimated ATE: {IPW_est}")
print(f"Bias: {IPW_bias}")

DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

# Evaluate performance
DR_bias = DR_est - true_ATE

print(f"Estimated ATE: {DR_est}")
print(f"Bias: {DR_bias}")

DM_est = np.mean(est_treatment_outcome - est_control_outcome)

# Evaluate performance
DM_bias = DM_est - true_ATE

print(f"Estimated ATE: {DM_est}")
print(f"Bias: {DM_bias}")

Estimated ATE: 2.9685072990904686
Bias: -0.03149270090953138
Estimated ATE: 2.972925625090657
Bias: -0.027074374909342946
Estimated ATE: 2.97911273489089
Bias: -0.020887265109109876


In [14]:
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
import pandas as pd
import numpy as np

# Enable automatic conversion of Pandas DataFrame to R DataFrame

# Simulate data in Python

# Create a pandas DataFrame
column_names = [f'X{i+1}' for i in range(p)]
df = pd.DataFrame(X, columns=column_names)
df['T'] = T
df['Y'] = Y


# Convert pandas DataFrame to R DataFrame
r_df = pandas2ri.py2rpy(df)

ro.r.assign("p", p)

# Load the CBPS package in R and fit the model for ATE estimation
ro.r('''
    library(CBPS)
    estimate_cbps_ate <- function(df) {
        formula_str <- paste("T ~", paste(names(df)[1:{p}], collapse=" + "))
        
        # CBPSの適用 (ATEの推定、ATT=0)
        model <- CBPS(as.formula(formula_str), data = df, ATT = 0, method = "exact")
        
        # 推定された傾向スコアの取得
        df$propensity_score <- fitted(model)
        
        # IPW (Inverse Probability Weighting) を適用
        df$weight <- ifelse(df$T == 1, 1 / df$propensity_score, 1 / (1 - df$propensity_score))
        
        # 重み付き回帰によるATEの推定
        result <- lm(Y ~ T, data = df, weights = df$weight)
        
        return(df$propensity_score)
    }
''')

# R関数を呼び出してATEと傾向スコアを取得
est_prop_score = ro.r['estimate_cbps_ate'](r_df)

treatment_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
control_outcome_model = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))

treatment_outcome_model.fit(X_treatment, Y_treatment)
control_outcome_model.fit(X_control, Y_control)

est_treatment_outcome = treatment_outcome_model.predict(X)
est_control_outcome = control_outcome_model.predict(X)

IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

# Evaluate performance
IPW_bias = IPW_est - true_ATE

print(f"Estimated ATE: {IPW_est}")
print(f"Bias: {IPW_bias}")

DR_est = np.mean(T*(Y - est_treatment_outcome) / est_prop_score - (1 - T)*(Y - est_control_outcome)  / (1 - est_prop_score) + est_treatment_outcome - est_control_outcome)

# Evaluate performance
DR_bias = DR_est - true_ATE

print(f"Estimated ATE: {DR_est}")
print(f"Bias: {DR_bias}")

DM_est = np.mean(est_treatment_outcome - est_control_outcome)

# Evaluate performance
DM_bias = DM_est - true_ATE

print(f"Estimated ATE: {DM_est}")
print(f"Bias: {DM_bias}")

ModuleNotFoundError: No module named 'rpy2'

In [144]:
est_prop_score

<rpy2.robjects.functions.SignatureTranslatedFunction object at 0x326c98250> [3]
R classes: ('function',)

In [113]:
DM_est = np.mean(est_treatment_outcome - est_control_outcome)

# Evaluate performance
DM_bias = DM_est - true_ATE

print(f"Estimated ATE: {DM_est}")
print(f"Bias: {DM_bias}")

Estimated ATE: 2.9181776676318285
Bias: -0.08182233236817149


In [114]:


IPW_est = np.mean(T*Y / est_prop_score - (1 - T)*Y / (1 - est_prop_score))

# Evaluate performance
IPW_bias = IPW_est - true_ATE

print(f"Estimated ATE: {IPW_est}")
print(f"Bias: {IPW_bias}")

Estimated ATE: 3.067399653974786
Bias: 0.06739965397478587


In [115]:
T

array([1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0,
       1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1,
       1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,
       1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0,
       0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1,
       0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1,
       1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1,
       0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
       0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0,
       1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0,
       1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0,

In [116]:
import torch
import torch.nn as nn
import torch.optim as optim


KeyboardInterrupt



In [None]:
# Convert data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
T_tensor = torch.tensor(T, dtype=torch.float32).view(-1, 1)

# Define a simple neural network model for propensity score estimation
class PropensityScoreNN(nn.Module):
    def __init__(self, input_dim):
        super(PropensityScoreNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Initialize model, loss function, and optimizer
model = PropensityScoreNN(p)
criterion = nn.BCELoss()  # Binary Cross Entropy for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the neural network
num_epochs = 500
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, T_tensor)
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    
    # Print loss for every 100 epochs
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Estimated propensity scores for all data
with torch.no_grad():
    estimated_propensity_scores = model(X_tensor).numpy()

In [187]:
estimated_propensity_scores[estimated_propensity_scores < 0.01] = 0.01

In [188]:
estimated_propensity_scores[estimated_propensity_scores > 0.99] = 0.99

In [181]:
np.mean((T/estimated_propensity_scores.T[0] - (1-T)/(1 - estimated_propensity_scores.T[0]))*Y)

2.011218015922772

In [194]:
# Convert data to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
T_tensor = torch.tensor(T, dtype=torch.float32).view(-1, 1)

# Define a simple neural network model for propensity score estimation
class PropensityScoreNN(nn.Module):
    def __init__(self, input_dim):
        super(PropensityScoreNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x

# Initialize model, loss function, and optimizer
model = PropensityScoreNN(p)
criterion = nn.BCELoss()  # Binary Cross Entropy for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the neural network
num_epochs = 500
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(X_tensor)
    outputs = torch.clamp(outputs, min=0.01, max=0.99)
    loss = -2*(1/outputs + 1/(1 - outputs)) + (T_tensor / outputs - (1-T_tensor) / (1-outputs))**2
    loss = loss.mean()
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    
    # Print loss for every 100 epochs
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Estimated propensity scores for all data
with torch.no_grad():
    estimated_propensity_scores = model(X_tensor).numpy()

Epoch [100/500], Loss: -89.3757
Epoch [200/500], Loss: -77.3542
Epoch [300/500], Loss: -126.3577
Epoch [400/500], Loss: -137.7103
Epoch [500/500], Loss: -146.9600


In [189]:
Z_tensor = torch.cat([T_tensor, X_tensor], axis=1)
Y_tensor = torch.tensor(Y, dtype=torch.float32).view(-1, 1)
dim = Z_tensor.shape[1]

class CodOutcomeNN(nn.Module):
    def __init__(self, input_dim):
        super(CodOutcomeNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    
# Initialize model, loss function, and optimizer
model = CodOutcomeNN(dim)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the neural network
num_epochs = 10000
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(Z_tensor)
    loss = ((Y_tensor - outputs)**2).mean()
    loss = loss.mean()
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    
    # Print loss for every 100 epochs
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Estimated propensity scores for all data
with torch.no_grad():
    estimated_conditional_outcomes = model(Z_tensor).numpy()
    estimated_conditional_outcomes_1 = model(Z_tensor_1).numpy()
    estimated_conditional_outcomes_0 = model(Z_tensor_0).numpy()

Epoch [100/10000], Loss: 66.5176
Epoch [200/10000], Loss: 55.3388
Epoch [300/10000], Loss: 43.7221
Epoch [400/10000], Loss: 34.2139
Epoch [500/10000], Loss: 28.5578
Epoch [600/10000], Loss: 25.6987
Epoch [700/10000], Loss: 23.8395
Epoch [800/10000], Loss: 22.1396
Epoch [900/10000], Loss: 20.4078
Epoch [1000/10000], Loss: 18.6166
Epoch [1100/10000], Loss: 16.7452
Epoch [1200/10000], Loss: 14.7996
Epoch [1300/10000], Loss: 12.8643
Epoch [1400/10000], Loss: 10.9603
Epoch [1500/10000], Loss: 9.1128
Epoch [1600/10000], Loss: 7.3873
Epoch [1700/10000], Loss: 5.9171
Epoch [1800/10000], Loss: 4.7640
Epoch [1900/10000], Loss: 3.9100
Epoch [2000/10000], Loss: 3.3287
Epoch [2100/10000], Loss: 2.9629
Epoch [2200/10000], Loss: 2.7293
Epoch [2300/10000], Loss: 2.5687
Epoch [2400/10000], Loss: 2.4482
Epoch [2500/10000], Loss: 2.3496
Epoch [2600/10000], Loss: 2.2630
Epoch [2700/10000], Loss: 2.1825
Epoch [2800/10000], Loss: 2.1069
Epoch [2900/10000], Loss: 2.0346
Epoch [3000/10000], Loss: 1.9642
Epoch

In [190]:
with torch.no_grad():
    estimated_conditional_outcomes = model(Z_tensor).numpy()
    estimated_conditional_outcomes_1 = model(Z_tensor_1).numpy()
    estimated_conditional_outcomes_0 = model(Z_tensor_0).numpy()

In [191]:
aaa = np.mean(estimated_conditional_outcomes_1 - estimated_conditional_outcomes_0)

In [192]:
aaa

2.98163

In [193]:
aaa + np.mean((T/estimated_propensity_scores.T[0] - (1-T)/(1 - estimated_propensity_scores.T[0]))*(Y - estimated_conditional_outcomes.T[0]))

2.979416648378537

In [160]:
Y - estimated_conditional_outcomes.T[0]

array([ 5.63514307e-01,  4.77596318e-01,  6.29739734e-02, -1.27126940e-01,
       -1.77228238e-01,  4.64099619e-01, -3.51511405e-01, -2.99983419e-01,
       -2.69382789e-01,  1.26012672e-01, -3.40216497e-03,  2.07889849e-01,
       -6.60575762e-01, -4.97077542e-01, -8.41113774e-01, -1.26968687e-01,
       -4.25912768e-01, -9.33943451e-02, -4.43000561e-01,  2.44456645e-01,
        7.93688016e-02, -7.86699426e-01,  6.66016525e-01,  5.07777333e-01,
       -6.60204782e-01,  1.76161963e-01,  3.64499704e-01, -2.11489539e-01,
       -3.55552370e-01, -6.61527685e-01, -7.72481761e-02, -2.88697698e-01,
       -6.51645266e-01, -1.14941385e+00, -5.81560010e-01,  9.66436085e-01,
       -7.85939945e-01, -4.66742301e-01,  7.49621729e-01,  2.88066177e-01,
       -5.62568961e-01, -1.75046911e-01,  2.87174176e-01,  5.13794861e-01,
       -3.57863637e-01, -6.20851911e-01, -4.15798049e-01, -4.12468909e-01,
        2.46813750e-01, -5.61748112e-01,  7.15916546e-01, -6.73407462e-01,
       -1.76732254e-01, -