In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from tabulate import tabulate
os.environ["KMP_DUPLICSCORE_LIB_OK"]="TRUE"

In [5]:
trials = 100
results = torch.load('e401k_cocycles_results_splits=5_boostrap=False_trials={0}.pt'.format(trials))

In [6]:
# Getting results names 
names = []
for result in results:
    if result['name'] not in names:
        names.append(result['name'])

In [7]:
# Getting val_losses and treatment effects
val_losses_PI = torch.zeros((len(names),trials))
val_losses_DR = torch.zeros((len(names),trials,5))

ATEs_PI_all = torch.zeros((len(names),trials))
Var_ATEs_PI_all = torch.zeros((len(names),trials))
ATEs_DR_all = torch.zeros((len(names),trials,5))
ATEs_sq_DR_all = torch.zeros((len(names),trials,5))
ATTs_PI_all = torch.zeros((len(names),trials))
Var_ATTs_PI_all = torch.zeros((len(names),trials))
ATTs_DR_all = torch.zeros((len(names),trials,5))
ATTs_sq_DR_all = torch.zeros((len(names),trials,5))

ETQs_all = torch.zeros((len(names),trials,len(results[0]['ETQ'])))
PPEQs_all = torch.zeros((len(names),trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T)))

names_counter = torch.zeros(len(names)).int()
for result in results:
    name_ind = names.index(result['name'])
    val_losses_PI[name_ind,names_counter[name_ind]] = result['val_loss_PI']
    val_losses_DR[name_ind,names_counter[name_ind]] = torch.tensor(result['val_loss_DR'])
    ATEs_PI_all[name_ind,names_counter[name_ind]] = result['ATE_PI']
    Var_ATEs_PI_all[name_ind,names_counter[name_ind]] = result['Var_ATE_PI']
    ATEs_DR_all[name_ind,names_counter[name_ind]] = result['ATE_DR']
    ATEs_sq_DR_all[name_ind,names_counter[name_ind]] = result['ATE_sq_DR']
    ATTs_PI_all[name_ind,names_counter[name_ind]] = result['ATT_PI']
    Var_ATTs_PI_all[name_ind,names_counter[name_ind]] = result['Var_ATT_PI']
    ATTs_DR_all[name_ind,names_counter[name_ind]] = result['ATT_DR']
    ATTs_sq_DR_all[name_ind,names_counter[name_ind]] = result['ATT_sq_DR']
    ETQs_all[name_ind,names_counter[name_ind]] = result['ETQ'][:,0]
    PPEQs_all[name_ind,names_counter[name_ind]] = result['PPEQ']
    names_counter[name_ind]+=1

In [8]:
# Returning best performers
ATEs_PI = torch.zeros(trials)
ATEs_DR = torch.zeros(trials)
ATTs_PI = torch.zeros(trials)
ATTs_DR = torch.zeros(trials)

Var_ATEs_PI = torch.zeros(trials)
ATEs_sq_DR = torch.zeros(trials)
Var_ATTs_PI = torch.zeros(trials)
ATTs_sq_DR = torch.zeros(trials)

ETQs = torch.zeros((trials,len(results[0]['ETQ'])))
PPEQs = torch.zeros((trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T)))

model_select = [0,1,2,3,4,5,6,7,8,9,10,11]

for i in range(trials):  
    best_PI = val_losses_PI[model_select,i].sort()[1][0]
    best_DR = val_losses_DR[model_select,i].sort(0)[1][0]
    ATEs_PI[i] = ATEs_PI_all[best_PI,i]
    ATEs_DR[i] = ATEs_DR_all[best_DR,i].mean()
    ATTs_PI[i] = ATTs_PI_all[best_PI,i]
    ATTs_DR[i] = ATTs_DR_all[best_DR,i].mean()
    Var_ATEs_PI[i] = Var_ATEs_PI_all[best_PI,i]
    ATEs_sq_DR[i] = ATEs_sq_DR_all[best_DR,i].mean()
    Var_ATTs_PI[i] = Var_ATTs_PI_all[best_PI,i]
    ATTs_sq_DR[i] = ATTs_sq_DR_all[best_DR,i].mean()
    ETQs[i] = ETQs_all[best_PI,i]
    PPEQs[i] = PPEQs_all[best_PI,i]

In [9]:
# Printing ATEs

# For bootstrap (PI)
print("ATE_PI : ", ATEs_PI.mean(),
      "ATT_PI : ", ATTs_PI.mean())

# For cross-fitting (DR) 
print("ATE_DR : ", ATEs_DR.median(),
      "ATT_DR : ", ATTs_DR.median())

ATE_PI :  tensor(8132.8706) ATT_PI :  tensor(10671.0273)
ATE_DR :  tensor(9348.5361) ATT_DR :  tensor(9469.9170)


In [11]:
# Printing SDs

# For bootrstrap (PI)
print("SD_ATE_PI : ", ATEs_PI.var()**0.5,
      "SD_ATT_PI : ", ATTs_PI.var()**0.5)

# For cross-fitting (DR)
print("SD_ATE_DR : ", ((ATEs_sq_DR-ATEs_DR**2)+(ATEs_DR-ATEs_DR.median())).median()**0.5/9915**0.5,
      "SD_ATT_DR : ", ((ATTs_sq_DR-ATTs_DR**2)+(ATTs_DR-ATTs_DR.median())).median()**0.5/3682**0.5)

SD_ATE_PI :  tensor(2204.7017) SD_ATT_PI :  tensor(2885.4153)
SD_ATE_DR :  tensor(1152.5707) SD_ATT_DR :  tensor(1107.1593)


## Asymptotic results

In [13]:
# Getting reslts by model and trial
ATEs_PI_all = torch.zeros((trials,len(names)))
Var_ATEs_PI_all = torch.zeros((trials,len(names)))
ATEs_DR_all = torch.zeros((trials,len(names)))
ATEs_sq_DR_all = torch.zeros((trials,len(names)))
ATTs_PI_all = torch.zeros((trials,len(names)))
Var_ATTs_PI_all = torch.zeros((trials,len(names)))
ATTs_DR_all = torch.zeros((trials,len(names)))
ATTs_sq_DR_all = torch.zeros((trials,len(names)))
ETQs_all = torch.zeros((trials,len(results[0]['ETQ']),len(names)))
PPEQs_all = torch.zeros((trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T),len(names)))
val_loss_PI_all = torch.zeros((trials,len(names)))
val_loss_DR_all = torch.zeros((trials,len(names)))


counter = torch.zeros(len(names)).int()
for result in results:
    for i in range(len(names)):
        if result['name'] == names[i]:
            ATEs_PI_all[counter[i],...,i] = result['ATE_PI']
            ATEs_DR_all[counter[i],...,i] = result['ATE_DR'].mean()
            ATTs_PI_all[counter[i],...,i] = result['ATT_PI']
            ATTs_DR_all[counter[i],...,i] = result['ATT_DR'].mean()
            Var_ATEs_PI_all[counter[i],...,i] = result['Var_ATE_PI']
            ATEs_sq_DR_all[counter[i],...,i] = result['ATE_sq_DR'].mean()
            Var_ATTs_PI_all[counter[i],...,i] = result['Var_ATT_PI']
            ATTs_sq_DR_all[counter[i],...,i] = result['ATT_sq_DR'].mean()
            ETQs_all[counter[i],...,i] = result['ETQ'].T
            PPEQs_all[counter[i],...,i] = result['PPEQ']
            val_loss_PI_all[counter[i],...,i] = result['val_loss_PI']
            val_loss_DR_all[counter[i],...,i] = torch.tensor(result['val_loss_DR']).mean()
            counter[i] += 1

In [14]:
# Getting reslts by transformation group and trial
ATEs_PI_group = torch.zeros((trials,4))
ATEs_DR_group = torch.zeros((trials,4))
ATTs_PI_group = torch.zeros((trials,4))
ATTs_DR_group = torch.zeros((trials,4))
Var_ATEs_PI_group = torch.zeros((trials,4))
ATEs_sq_DR_group = torch.zeros((trials,4))
Var_ATTs_PI_group = torch.zeros((trials,4))
ATTs_sq_DR_group = torch.zeros((trials,4))
ETQs_group = torch.zeros((trials,len(results[0]['ETQ']),4))
PPEQs_group = torch.zeros((trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T),4))
val_loss_PI_group = torch.zeros((trials,4))
val_loss_DR_group = torch.zeros((trials,4))

for b in range(len(ATEs_PI_group)):
    for i in range(4):
        best_PI = torch.where(val_loss_PI_all[b,3*i:3*i+3] == val_loss_PI_all[b,3*i:3*i+3].min())[0][0]
        best_DR = torch.where(val_loss_DR_all[b,3*i:3*i+3] == val_loss_DR_all[b,3*i:3*i+3].min())[0][0]
        ATEs_PI_group[b,i] = ATEs_PI_all[b,3*i:3*i+3][best_PI]
        ATEs_DR_group[b,i] = ATEs_DR_all[b,3*i:3*i+3][best_DR]
        ATTs_PI_group[b,i] = ATTs_PI_all[b,3*i:3*i+3][best_PI]
        ATTs_DR_group[b,i] = ATTs_DR_all[b,3*i:3*i+3][best_DR]
        Var_ATEs_PI_group[b,i] = Var_ATEs_PI_all[b,3*i:3*i+3][best_PI]
        ATEs_sq_DR_group[b,i] = ATEs_sq_DR_all[b,3*i:3*i+3][best_DR]
        Var_ATTs_PI_group[b,i] = Var_ATTs_PI_all[b,3*i:3*i+3][best_PI]
        ATTs_sq_DR_group[b,i] = ATTs_sq_DR_all[b,3*i:3*i+3][best_DR]
        ETQs_group[b,...,i] = ETQs_all[b,...,3*i:3*i+3][...,best_PI]
        PPEQs_group[b,...,i] = PPEQs_all[b,...,3*i:3*i+3][...,best_PI]
        val_loss_PI_group[b,i] = val_loss_PI_all[b,3*i:3*i+3][best_PI]
        val_loss_DR_group[b,i] = val_loss_DR_all[b,3*i:3*i+3][best_DR]

In [15]:
# Getting results by model
Results_by_group_PI = torch.row_stack((ATEs_PI_all.mean(0).int(),
                                        ((Var_ATEs_PI_all**0.5).mean(0)).int(),
                                        ATTs_PI_all.mean(0).int(),
                                        ((Var_ATTs_PI_all**0.5).mean(0)).int(),
                                        val_loss_PI_all.mean(0)))
Results_by_model_DR = torch.row_stack((ATEs_DR_all.mean(0).int(),
                                        (ATEs_DR_all.var(0)**0.5).int(),
                                        ATTs_DR_all.mean(0).int(),
                                        (ATTs_DR_all.var(0)**0.5).int(),
                                        val_loss_DR_all.mean(0)))

In [16]:
# Getting results by transformation group
group_names = ["Linear", "Additive", "Affine", "TMI"]
Results_by_group = np.zeros((7,9)).astype(str)
for i in range(len(group_names)):
    Results_by_group[0,i+1]  = group_names[i]
    Results_by_group[0,i+5]  = group_names[i]
    Results_by_group[1,i+1] = int(ATEs_PI_group.mean(0)[i])
    Results_by_group[1,i+5] = int(ATEs_DR_group[:,i].median())
    Results_by_group[2,i+1] = "("+str(int((Var_ATEs_PI_group**0.5).mean(0)[i]))+")"
    Results_by_group[2,i+5] = "("+str(int(((ATEs_sq_DR_group-ATEs_DR_group**2)+(ATEs_DR_group-ATEs_DR_group[:,i].median()))[:,i].median()**0.5/9915**0.5))+")"
    Results_by_group[3,i+1] = int(ATTs_PI_group.mean(0)[i])
    Results_by_group[3,i+5] = int(ATTs_DR_group[:,i].median())
    Results_by_group[4,i+1] = "("+str(int((Var_ATTs_PI_group**0.5).mean(0)[i]))+")"
    Results_by_group[4,i+5] = "("+str(int(((ATTs_sq_DR_group-ATTs_DR_group**2)+(ATTs_DR_group-ATTs_DR_group[:,i].median()))[:,i].median()**0.5/3682**0.5))+")"
    Results_by_group[5,i+1] = float(val_loss_PI_group.mean(0)[i])
    Results_by_group[5,i+5] = float(val_loss_DR_group[:,i].mean())
    Results_by_group[6,i+1] = float(val_loss_PI_group.var(0)[i]**0.5)
    Results_by_group[6,i+5] = float(val_loss_DR_group.var(0)[i]**0.5)

In [17]:
print(tabulate(Results_by_group, tablefmt="latex"))

\begin{tabular}{rllllllll}
\hline
 0 & Linear               & Additive             & Affine               & TMI                  & Linear               & Additive             & Affine               & TMI                  \\
 0 & 2536                 & 3431                 & 8821                 & 8174                 & 12664                & 12058                & 9304                 & 9333                 \\
 0 & (0)                  & (25)                 & (217)                & (118)                & (1293)               & (1264)               & (1152)               & (1150)               \\
 0 & 2536                 & 4354                 & 11421                & 10727                & 11317                & 11132                & 9401                 & 9430                 \\
 0 & (0)                  & (42)                 & (328)                & (202)                & (1193)               & (1178)               & (1100)               & (1105)               \\
 0 & -0.43832740

## Bootstrap based results

In [12]:
# Getting reslts by model and trial
ATEs_PI_all = torch.zeros((trials,len(names)))
ATEs_DR_all = torch.zeros((trials,len(names)))
ATTs_PI_all = torch.zeros((trials,len(names)))
ATTs_DR_all = torch.zeros((trials,len(names)))
ETQs_all = torch.zeros((trials,len(results[0]['ETQ']),len(names)))
PPEQs_all = torch.zeros((trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T),len(names)))
val_loss_PI_all = torch.zeros((trials,len(names)))
val_loss_DR_all = torch.zeros((trials,len(names)))

counter = torch.zeros(len(names)).int()
for result in results:
    for i in range(len(names)):
        if result['name'] == names[i]:
            ATEs_PI_all[counter[i],...,i] = result['ATE_PI']
            ATEs_DR_all[counter[i],...,i] = result['ATE_DR'].mean()
            ATTs_PI_all[counter[i],...,i] = result['ATT_PI']
            ATTs_DR_all[counter[i],...,i] = result['ATT_DR'].mean()
            ETQs_all[counter[i],...,i] = result['ETQ'].T
            PPEQs_all[counter[i],...,i] = result['PPEQ']
            val_loss_PI_all[counter[i],...,i] = result['val_loss_PI']
            val_loss_DR_all[counter[i],...,i] = torch.tensor(result['val_loss_DR']).mean()
            counter[i] += 1

In [None]:
# Getting reslts by transformation group and trial
ATEs_PI_group = torch.zeros((trials,4))
ATEs_DR_group = torch.zeros((trials,4))
ATTs_PI_group = torch.zeros((trials,4))
ATTs_DR_group = torch.zeros((trials,4))
ETQs_group = torch.zeros((trials,len(results[0]['ETQ']),4))
PPEQs_group = torch.zeros((trials,len(results[0]['PPEQ']),len(results[0]['PPEQ'].T),4))
val_loss_PI_group = torch.zeros((trials,4))
val_loss_DR_group = torch.zeros((trials,4))

for b in range(len(ATEs_PI_group)):
    for i in range(4):
        best_PI = torch.where(val_loss_PI_all[b,3*i:3*i+3] == val_loss_PI_all[b,3*i:3*i+3].min())[0][0]
        best_DR = torch.where(val_loss_DR_all[b,3*i:3*i+3] == val_loss_DR_all[b,3*i:3*i+3].min())[0][0]
        ATEs_PI_group[b,i] = ATEs_PI_all[b,3*i:3*i+3][best_PI]
        ATEs_DR_group[b,i] = ATEs_DR_all[b,3*i:3*i+3][best_DR]
        ATTs_PI_group[b,i] = ATTs_PI_all[b,3*i:3*i+3][best_PI]
        ATTs_DR_group[b,i] = ATTs_DR_all[b,3*i:3*i+3][best_DR]
        ETQs_group[b,...,i] = ETQs_all[b,...,3*i:3*i+3][...,best_PI]
        PPEQs_group[b,...,i] = PPEQs_all[b,...,3*i:3*i+3][...,best_PI]
        val_loss_PI_group[b,i] = val_loss_PI_all[b,3*i:3*i+3][best_PI]
        val_loss_DR_group[b,i] = val_loss_DR_all[b,3*i:3*i+3][best_DR]

In [None]:
# Getting results by model
Results_by_group_PI = torch.row_stack((ATEs_PI_all.mean(0).int(),
                                        (ATEs_PI_all.var(0)**0.5).int(),
                                        ATTs_PI_all.mean(0).int(),
                                        (ATTs_PI_all.var(0)**0.5).int(),
                                        val_loss_PI_all.mean(0)))
Results_by_model_DR = torch.row_stack((ATEs_DR_all.mean(0).int(),
                                        (ATEs_DR_all.var(0)**0.5).int(),
                                        ATTs_DR_all.mean(0).int(),
                                        (ATTs_DR_all.var(0)**0.5).int(),
                                        val_loss_DR_all.mean(0)))

In [None]:
# Getting results by transformation group
group_names = ["Linear", "Additive", "Affine", "TMI"]
Results_by_group = np.zeros((7,9)).astype(str)
for i in range(len(group_names)):
    Results_by_group[0,i+1]  = group_names[i]
    Results_by_group[0,i+5]  = group_names[i]
    Results_by_group[1,i+1] = int(ATEs_PI_group.mean(0)[i])
    Results_by_group[1,i+5] = int(ATEs_DR_group.mean(0)[i])
    Results_by_group[2,i+1] = "("+str(int((ATEs_PI_group.var(0)**0.5)[i]))+")"
    Results_by_group[2,i+5] = "("+str(int((ATEs_DR_group.var(0)**0.5)[i]))+")"
    Results_by_group[3,i+1] = int(ATTs_PI_group.mean(0)[i])
    Results_by_group[3,i+5] = int(ATTs_DR_group.mean(0)[i])
    Results_by_group[4,i+1] = "("+str(int((ATTs_PI_group.var(0)**0.5)[i]))+")"
    Results_by_group[4,i+5] = "("+str(int((ATTs_DR_group.var(0)**0.5)[i]))+")"
    Results_by_group[5,i+1] = float(val_loss_PI_group.mean(0)[i])
    Results_by_group[5,i+5] = float(val_loss_DR_group.mean(0)[i])
    Results_by_group[6,i+1] = float(val_loss_PI_group.var(0)[i]**0.5)
    Results_by_group[6,i+5] = float(val_loss_DR_group.var(0)[i]**0.5)

In [None]:
print(tabulate(Results_by_group, tablefmt="latex"))

In [None]:
lower = 10
upper = 90
PPEQ_select = [0,2,9]
PPEQ_threshs = [trials0,2000,3000,4000,5000,6000,7000,8000,9000,trials00]
from matplotlib import rcParams, rc_file_defaults
from matplotlib.ticker import FormatStrFormatter
from labellines import labelLines
ylabelsize = 14
xlabelsize = 14
rc_file_defaults()
rcParams['xtick.labelsize'] = xlabelsize
rcParams['ytick.labelsize'] = ylabelsize 
#rcParams['text.usetex'] = False

fig,axs = plt.subplots(figsize = (6,5))
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))
axs.plot(quantiles[lower:upper],ETQs.median(0)[0][lower:upper], color = "black", lw = 1.75);
axs.plot(quantiles[lower:upper],ETQs.quantile(0.95,0)[lower:upper], color = "grey", ls = (0,(1,2)), lw = 1.5)
axs.plot(quantiles[lower:upper],ETQs.quantile(0.75,0)[lower:upper], color = "grey", ls = "dashed", lw = 1.5)
axs.plot(quantiles[lower:upper],ETQs.quantile(0.25,0)[lower:upper], color = "grey", ls = "dashed", lw = 1.5)
axs.plot(quantiles[lower:upper],ETQs.quantile(0.05,0)[lower:upper], color = "grey", ls = (0,(1,2)), lw = 1.5)
#axs.plot(quantiles[lower:upper],ETQs.quantile(0.75,0)[lower:upper], color = "grey", ls = "dashed", lw = 2)
#axs.plot(quantiles[lower:upper],ETQs.quantile(0.25,0)[lower:upper], color = "grey", ls = "dashed", lw = 2)
axs.set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
axs.set_ylabel(r"ETQ$(\tau)$", fontsize = 14)
fig.tight_layout() 
fig.savefig("e401k_ETQ_plot",bbox_inches = "tight")

fig,axs = plt.subplots(figsize = (6,5))
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))
for i in range(len(PPEQ_select)):
    axs.plot(quantiles[lower:upper],PPEQs.median(0)[0][lower:upper,...,PPEQ_select[i]], lw = 1.75, color = "black", label = "${0}".format(PPEQ_threshs[PPEQ_select[i]]));
#axs.plot(quantiles[lower:upper],PPEQs.quantile(0.95,0)[lower:upper,PPEQ_select], color = "grey", ls = (0,(1,2)), lw = 1.5)
#axs.plot(quantiles[lower:upper],PPEQs.quantile(0.05,0)[lower:upper,PPEQ_select], color = "grey", ls = (0,(1,2)), lw = 1.5)
labelLines(plt.gca().get_lines(), zorder=2.5)
axs.plot(quantiles[lower:upper],PPEQs.quantile(0.75,0)[lower:upper,...,PPEQ_select], color = "grey", ls = "dashed", lw = 1.5)
axs.plot(quantiles[lower:upper],PPEQs.quantile(0.25,0)[lower:upper,...,PPEQ_select], color = "grey", ls = "dashed", lw = 1.5)
axs.set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
axs.set_ylabel(r"PTEQ$(\tau)$",fontsize =  14)
fig.tight_layout() 
fig.savefig("e401k_PPEQ_plot",bbox_inches = "tight")

fig,axs = plt.subplots(figsize = (6,5))
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))
axs.plot(quantiles[lower:upper],PPEQs.mean(0)[lower:upper,...,:], lw = 1.5,color = "black");

axs.set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
axs.set_ylabel(r"PTEQ$(\tau)$",fontsize =  14)
fig.tight_layout() 
fig.savefig("e401k_PPEQ_mean_plot",bbox_inches = "tight")

In [None]:
# PTEQ by threshold 
lower = 10
upper = 90
PPEQ_threshs = [trials0,2000,3000,4000,5000,6000,7000,8000,9000,trials00]
from matplotlib import rcParams, rc_file_defaults
from matplotlib.ticker import FormatStrFormatter
from labellines import labelLines
ylabelsize = 14
xlabelsize = 14
rc_file_defaults()
rcParams['xtick.labelsize'] = xlabelsize
rcParams['ytick.labelsize'] = ylabelsize 
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))

fig,axs = plt.subplots(nrows = 5, ncols = 2,figsize = (10,15))
plot_num = 0
for row in range(5):
    for col in range(2):
        axs[row,col].plot(quantiles[lower:upper],PPEQs.median(0)[0][lower:upper,...,plot_num], lw = 1.75, color = "black", label = "${0}".format(PPEQ_threshs[plot_num]));
        axs[row,col].plot(quantiles[lower:upper],PPEQs.quantile(0.75,0)[lower:upper,...,plot_num], color = "grey", ls = "dashed", lw = 1.5)
        axs[row,col].plot(quantiles[lower:upper],PPEQs.quantile(0.25,0)[lower:upper,...,plot_num], color = "grey", ls = "dashed", lw = 1.5)
        axs[row,col].plot(quantiles[lower:upper],PPEQs.quantile(0.95,0)[lower:upper,...,plot_num], color = "grey", ls = (0,(1,2)), lw = 1.5)
        axs[row,col].plot(quantiles[lower:upper],PPEQs.quantile(0.05,0)[lower:upper,...,plot_num], color = "grey", ls = (0,(1,2)), lw = 1.5)
        axs[row,col].set_ylim(-0.05,1.05)
        if col == 0:
            axs[row,col].set_ylabel(r"PTEQ$(\tau)$",fontsize =  14)
        if row == 4:
            axs[row,col].set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
        lines = axs[row,col].plot(quantiles[lower:upper],PPEQs.median(0)[0][lower:upper,...,plot_num], lw = 1.75, color = "black", label = "${0}".format(PPEQ_threshs[plot_num]));
        labelLines(lines, zorder=2.5)
        plot_num += 1
fig.tight_layout() 
fig.savefig("e401k_PPEQ_plot_all",bbox_inches = "tight")


In [None]:
lower = 10
upper = 90
PPEQ_select = [0,4,9]
from matplotlib import rcParams, rc_file_defaults
from matplotlib.ticker import FormatStrFormatter

ylabelsize = 14
xlabelsize = 14
rc_file_defaults()
rcParams['xtick.labelsize'] = xlabelsize
rcParams['ytick.labelsize'] = ylabelsize 

fig,axs = plt.subplots(figsize = (6,5))
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))
axs.plot(quantiles[lower:upper],ETQs.mean(0)[lower:upper], color = "black", lw = 2);
axs.fill_between(quantiles[lower:upper],ETQs.quantile(0.95,0)[lower:upper],
                 ETQs.quantile(0.05,0)[lower:upper], color = "grey", alpha = 0.25)

axs.set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
axs.set_ylabel(r"ETQ$(\tau)$", fontsize = 14)
fig.tight_layout() 
fig.savefig("e401k_ETQ_plot",bbox_inches = "tight")

fig,axs = plt.subplots(figsize = (6,5))
quantiles = torch.linspace(0.01,0.99,len(ETQs[0]))
axs.plot(quantiles[lower:upper],PPEQs.mean(0)[lower:upper,...,PPEQ_select], lw = 2, color = "black");
for i in PPEQ_select:
    axs.fill_between(quantiles[lower:upper],PPEQs.quantile(0.95,0)[lower:upper,...,i],
                 PPEQs.quantile(0.05,0)[lower:upper,...,i], color = "grey", alpha = 0.25)

axs.set_xlabel(r"$Y^{(0)}$ quantile $(\tau)$", fontsize = 14)
axs.set_ylabel(r"PTEQ$(\tau)$",fontsize =  14)
fig.tight_layout() 
fig.savefig("e401k_PPEQ_plot",bbox_inches = "tight")