In [13]:
# Imports
import copy
import itertools
import os
import time
from functools import partial

import torch
from torch.distributions import Normal,Uniform,Gamma,Laplace,OneHotCategorical

from causal_cocycle.model import cocycle_model,flow_model,flow_outcome_model
from causal_cocycle.optimise import *
from causal_cocycle.loss_functions import Loss
from causal_cocycle.conditioners import Empty_Conditioner,Constant_Conditioner,Lin_Conditioner,NN_RELU_Conditioner
from causal_cocycle.transformers import Transformer,Shift_layer,Scale_layer,RQS_layer,Inverse_layer
from causal_cocycle.helper_functions import likelihood_loss,mmd,propensity_score
from causal_cocycle.kernels import *
from causal_cocycle.kde import *

#Shorthand function calls
def NN(i,o=2,width=128,layers=2):
    """Shorthand for an ``NN_RELU_Conditioner`` with the bias term enabled.

    Args:
        i: forwarded as ``input_dims``.
        o: forwarded as ``output_dims``.
        width: forwarded as ``width``.
        layers: forwarded as ``layers``.
    """
    return NN_RELU_Conditioner(input_dims=i,
                               output_dims=o,
                               width=width,
                               layers=layers,
                               bias=True)
def C(rows,cols,value):
    """Shorthand for a ``Constant_Conditioner`` initialised to a ``(rows, cols)`` tensor of ``value``."""
    init_values = torch.ones(rows, cols) * value
    return Constant_Conditioner(init = init_values)

# Shorthand ``Transformer`` constructor with ``logdet=True`` pre-filled
T = partial(Transformer, logdet = True)

In [14]:
# DGP functions
def cdf(X,t):
    """Empirical CDF of the rows of ``X`` at each threshold in ``t``.

    ``X`` is compared against ``t.T``, so with ``X`` of shape (n, 1) and ``t`` of
    shape (m, 1) the result has shape (m,): the fraction of rows of ``X`` that are
    ``<=`` each threshold.
    """
    at_or_below = X <= t.T
    return torch.mean(at_or_below.float(), dim=0)

class IG:
    """Inverse-gamma sampler: reciprocals of ``Gamma(alpha, beta)`` draws.

    ``alpha`` and ``beta`` are passed to ``torch.distributions.Gamma`` as its
    ``concentration`` and ``rate`` arguments.
    """

    def __init__(self,alpha,beta):
        self.alpha, self.beta = alpha, beta

    def sample(self,size):
        """Return ``1 / g`` for ``g ~ Gamma(alpha, beta)`` with shape ``size``."""
        gamma_draws = Gamma(self.alpha, self.beta).sample(size)
        return 1 / gamma_draws

class Mixture1D:
    """Finite mixture of offset/scaled 1-D base distributions.

    Component ``i`` produces ``noints[i] + scales[i] * dists[i].sample(size)`` and is
    selected with probability ``probabilities[i]``.

    Args:
        base_dists: list of objects exposing ``sample(size)`` (e.g. ``IG``).
        probabilities: 1-D tensor of mixture weights, one per component.
        noints: per-component additive offsets.
        scales: per-component multipliers (may be negative).
    """

    def __init__(self,base_dists,probabilities,noints,scales):
        self.dists = base_dists
        self.probabilities = probabilities
        self.noints = noints
        self.scales = scales

    def sample(self,size):
        """Draw ``size[0]`` mixture samples, returned with shape ``(size[0], 1)``.

        ``size`` is expected to be ``(N,)`` or ``(N, 1)``; each base distribution
        must return ``N`` values from ``sample(size)``.
        """
        n = size[0]
        # One-hot component indicators, shape (n, K), drawn from this instance's
        # own weights (not from any global variable).
        component = OneHotCategorical(self.probabilities).sample((n,))
        Z = torch.zeros((n, len(self.probabilities)))
        for i, dist in enumerate(self.dists):
            # Flatten each draw to (n,) so both (N,) and (N, 1) sizes work
            Z[:, i] = self.noints[i] + self.scales[i] * dist.sample(size).reshape(n)
        return (Z * component).sum(1)[:, None]

def policy(V, flip_prob = 0.00):
    """Three-level threshold treatment rule with random action substitution.

    The score ``s = mean(V, axis=1) * sqrt(D)`` (D = number of columns of V) maps to
    action 0 if ``s < -1``, 1 if ``-1 <= s < 1`` and 2 if ``s >= 1``. Independently
    for each row, with probability ``flip_prob`` the action is replaced by a uniform
    draw from {0, 1, 2} (which may coincide with the rule's action).

    Returns an int64 tensor of shape ``(len(V), 1)``.
    """
    n_rows = len(V)
    score = (V.mean(1) * V.shape[1] ** 0.5).unsqueeze(1)
    rule_action = (score >= -1).long() + (score >= 1).long()
    # RNG order is kept fixed for reproducibility: uniforms first, then random actions
    replace = Uniform(0, 1).sample((n_rows, 1)) < flip_prob
    random_action = torch.randint(3, (n_rows, 1))
    return torch.where(replace, random_action, rule_action)

def new_policy(V, flip_prob = 0.00):
    """Two-level threshold treatment rule with random action substitution.

    The score ``s = mean(V, axis=1) * sqrt(D)`` (D = number of columns of V) maps to
    action 0 if ``s < -1`` and 1 otherwise. Independently for each row, with
    probability ``flip_prob`` the action is replaced by a uniform draw from {0, 1}.

    Returns an int64 tensor of shape ``(len(V), 1)``.
    """
    n_rows = len(V)
    score = (V.mean(1) * V.shape[1] ** 0.5).unsqueeze(1)
    rule_action = (score >= -1).long()
    # RNG order is kept fixed for reproducibility: uniforms first, then random actions
    replace = Uniform(0, 1).sample((n_rows, 1)) < flip_prob
    random_action = torch.randint(2, (n_rows, 1))
    return torch.where(replace, random_action, rule_action)

def shift(V,policy,coeffs):
    """Treatment-dependent location of the outcome.

    With index ``z = V @ coeffs`` and action ``a = policy(V)``:
    ``1 / (1 + exp(z))`` plus a Gaussian-shaped bump whose centre and height
    depend on ``a`` (centre -3 / 0 / 3, height 1 / 0.75 / 0.5 for a = 0 / 1 / 2).
    """
    action = policy(V)
    index = V @ coeffs
    baseline = 1 / (1 + torch.exp(index))
    bump_0 = (action == 0) * torch.exp(-0.1 * (index + 3) ** 2)
    bump_1 = (action == 1) * torch.exp(-0.1 * (index - 0) ** 2) * 0.75
    bump_2 = (action == 2) * torch.exp(-0.1 * (index - 3) ** 2) * 0.5
    return baseline + (bump_0 + bump_1 + bump_2)

def scale(V,coeffs):
    """Heteroskedastic outcome-noise scale: ``0.1 * (1 + exp(-(z+2)^2 (z-2)^2 / 10))``, ``z = V @ coeffs``.

    Equals 0.2 at ``z = +-2`` and decays towards 0.1 away from those points.
    """
    index = V @ coeffs
    bump = torch.exp(-1/10 * (index + 2)**2 * (index - 2)**2)
    return 0.1 * (bump + 1)

def DGP(N,D,policy,covariate_corr = 0,
        covariate_dist = Normal(0,1),
        noise_dist = Normal(0,1),
        coeffs = None):
    """Simulate an observational sample ``(Z, X, Y)``.

    Covariates are ``Z = E @ A.T`` with ``E ~ covariate_dist`` of shape (N, D) and
    ``A`` the Cholesky factor of the matrix with unit diagonal and off-diagonal
    ``covariate_corr``. A treatment ``policy(Z)`` is drawn once. It is used both as
    the recorded treatment (first column of X) and inside the outcome
    ``Y = shift(Z, treatment, coeffs) + scale(Z, coeffs) * U``, where ``U ~ noise_dist``
    has shape (N, 1).

    Args:
        N: sample size.
        D: covariate dimension.
        policy: callable mapping Z to treatments (may be randomised).
        covariate_corr: common pairwise correlation of the covariates.
        covariate_dist: base covariate distribution (needs ``.sample((N, D))``).
        noise_dist: outcome-noise distribution (needs ``.sample((N, 1))``).
        coeffs: index weights passed to ``shift``/``scale`` (a (D, 1) column in
            this notebook). If None, the notebook-level ``coeffs`` is used, as before.

    Returns:
        Z: (N, D) covariates; X: ``torch.column_stack((treatment, Z))``; Y: outcomes.
    """
    if coeffs is None:
        try:
            coeffs = globals()["coeffs"]
        except KeyError:
            raise NameError("DGP: pass `coeffs` or define it at notebook level") from None
    Sigma = (1-covariate_corr)*torch.eye(D)+covariate_corr*torch.ones((D,D))
    A = torch.linalg.cholesky(Sigma)
    Z = covariate_dist.sample((N,D)) @ A.T
    U = noise_dist.sample((N,1))
    # Draw the (possibly randomised) treatment ONCE. Calling policy separately for
    # Y and X would give independent random flips, so the outcome would not depend
    # on the treatment that is actually recorded in X.
    treatment = policy(Z)
    Y = shift(Z, lambda _V: treatment, coeffs) + scale(Z, coeffs)*U
    X = torch.column_stack((treatment, Z))
    return Z,X,Y

In [15]:
# DGP set up
SEED = 0
torch.manual_seed(SEED)  # seed the global torch RNG so a top-to-bottom re-run is reproducible
N = 10**4
D = 10
Zcorr = 0.0
flip_prob = 0.05
coeffs = 1/torch.linspace(1,D,D)[:,None]  # weight 1/d on covariate d, shape (D, 1) ...
coeffs *= 1/coeffs.sum()                   # ... normalised to sum to one
means = torch.tensor([[-2, 0]]).T  # additive offsets (noints) of the two mixture components of U
scales = torch.tensor([[-1.0, 1.0]]).T  # signed multipliers applied to each component's draw (not variances)
probabilities = torch.tensor([1/2,1/2])  # mixture weights of U
base_dists = [IG(10,10),IG(1,1)]  # inverse-gamma components; IG(1,1) (alpha = 1) has no finite mean
noise_dist = Mixture1D(base_dists,probabilities,means,scales)
Zdist = Normal(0,1.5)

In [16]:
# Method + opt set up
base_distribution = Normal(0,1)  # base distribution for the flow likelihood and latent samples

# Validation
validation_method = "fixed"
train_val_split = 0.5

# Conditioner / transformer architecture
layers = 2
width = 128
RQS_bins = 8  # passed to RQS_layer and used to size its conditioner outputs

# Optimiser
learn_rate = [1e-3]
batch_size = 64
scheduler = True
maxiter = 10_000
miniter = 10_000
weight_decay = 1e-3

In [17]:
# Setting training optimiser args.
# Names and values come from a single mapping so the two lists cannot drift apart.
opt_settings = {
    "learn_rate": learn_rate,
    "scheduler": scheduler,
    "batch_size": batch_size,
    "maxiter": maxiter,
    "miniter": miniter,
    "weight_decay": weight_decay,
    "print_": True,
}
opt_args = list(opt_settings.keys())
opt_argvals = list(opt_settings.values())

In [28]:
# Specifying models for cross-validation

# Specifying list of hypers to construct models from
hypers = ["weight_decay"]
hypers_list = [[0,1e-3,1e-2,1e-1,1]]
# One template per architecture. Presumably conditioner k parameterises transformer
# layer k (Shift, Scale, RQS) -- confirm against flow_model:
#   0: NN shift, constant scale, constant RQS parameters
#   1: NN shift, NN scale,       constant RQS parameters
#   2: NN shift, NN scale,       NN RQS parameters
conditioners_list = [[NN(D+1,1,width,layers),C(1,1,1),C(1,3*RQS_bins + 2,3)],
                          [NN(D+1,1,width,layers),NN(D+1,1,width,layers),C(1,3*RQS_bins + 2,3)],
                          [NN(D+1,1,width,layers),NN(D+1,1,width,layers),NN(D+1,3*RQS_bins + 2,width,layers)]]
transformers_list = [Transformer([Shift_layer(),Scale_layer(),RQS_layer(RQS_bins)],logdet = True),
                           Transformer([Shift_layer(),Scale_layer(),RQS_layer(RQS_bins)],logdet = True),
                           Transformer([Shift_layer(),Scale_layer(),RQS_layer(RQS_bins)],logdet = True)]

# Constructing all model combinations (architecture x grid of hyper values).
# Each model receives its own deep copy of the template modules. Passing the same
# instances to every hyper setting would (unless flow_model copies them, which is
# not visible here) make those models share and overwrite one set of parameters.
# Copies of one template also start from identical initial weights, so hyper
# settings are compared from a common starting point.
models_validation = []
hyper_argvals = []
for m in range(len(conditioners_list)):
    for hyper_combo in itertools.product(*hypers_list):
        models_validation.append(flow_model(copy.deepcopy(conditioners_list[m]),
                                            copy.deepcopy(transformers_list[m])))
        hyper_argvals.append(list(hyper_combo))
hyper_args = [hypers]*len(hyper_argvals)

In [29]:
# Observational data: one DGP sample, split into a training part (first `ntrain`
# rows) and a held-out test part (the remaining rows)
observed_policy = partial(policy, flip_prob = flip_prob)
Z,X,Y = DGP(N, D, observed_policy, Zcorr, Zdist, noise_dist)
ntrain = int(train_val_split*N)
Ztrain, Ztest = Z[:ntrain], Z[ntrain:]
Xtrain, Xtest = X[:ntrain], X[ntrain:]
Ytrain, Ytest = Y[:ntrain], Y[ntrain:]

In [30]:
# Getting loss function and training the outcome model.
# Only the training part is used: Xtest/Ytest are reserved for the IPW and DR
# estimators below, so the outcome model must not be fitted on them.
loss_fn =  likelihood_loss(base_distribution)
models_validation,val_losses = validate(models_validation,
                                         loss_fn,
                                         Xtrain,
                                         Ytrain,
                                         validation_method,
                                         train_val_split,
                                         opt_args,
                                         opt_argvals,
                                         hyper_args,
                                         hyper_argvals)
# Pick the model with the smallest validation loss (first one on ties);
# assumes val_losses is 1-D with one entry per model -- confirm against validate().
best_ind = torch.where(val_losses ==val_losses.min())[0][0]
final_model = models_validation[best_ind]
# Disable the log-determinant, presumably only needed by the likelihood loss during training
final_model.transformer.logdet = False


Training loss last 10 avg is : tensor(2.5889)
99.9  % completion
Currently optimising model  14 , for fold  0


In [40]:
# Defining outcome model and feature of interest: the cdf of sigmoid(Y) at thresholds t
ntest = 5000


def feature(x, t):
    """Indicator ``1{sigmoid(x) <= t}`` broadcast over a grid of thresholds.

    ``x`` gains a leading axis and ``t`` a trailing one before comparing, so for
    ``x`` of shape (n, 1) and ``t`` of shape (m, 1) the result has shape (m, n, 1).
    """
    squashed = torch.sigmoid(x).unsqueeze(0)
    thresholds = t.unsqueeze(-1)
    return (squashed <= thresholds).float()


Usample = base_distribution.sample((ntest,1))
conditional_mean_model = flow_outcome_model(final_model,Usample)

In [41]:
# Sampling from the two target distributions. Both draw covariates from Zintdist:
#   "shift": covariate shift only, same policy (flip_prob left at its default 0.00)
#   "int":   covariate shift plus the switch to new_policy
nintsample = 10**5
Zintdist = Normal(1,0.5)
Zshift,Xshift,Yshift = DGP(nintsample, D, policy, Zcorr, Zintdist, noise_dist)
Zint,Xint,Yint = DGP(nintsample, D, new_policy, Zcorr, Zintdist, noise_dist)
# The first `ntest` draws feed the outcome-model estimates and the shifted-covariate KDE
Zshift_train = Zshift[:ntest]
Xshift_train = Xshift[:ntest]
Xint_train = Xint[:ntest]

In [42]:
# cdf values: cdf of sigmoid(Y) on a grid of thresholds in [0, 1]
t = torch.linspace(0,1,1000)[:,None]

# Everything is evaluated in batches of `batch` thresholds. nbatch rounds up so a
# final partial batch is still filled in (int(len(t)/batch) would leave it at zero).
batch = 100
nbatch = -(-len(t) // batch)
SCM_cdf_int = torch.zeros(len(t))
SCM_cdf_shift = torch.zeros(len(t))
true_cdf_int = torch.zeros(len(t))
true_cdf_shift = torch.zeros(len(t))
for i in range(nbatch):
    rows = slice(i*batch, (i+1)*batch)
    feature_batch = partial(feature, t = t[rows])

    # Cocycle (outcome-model) cdf
    SCM_cdf_int[rows] = conditional_mean_model(Xint_train, feature_batch).mean(1)
    SCM_cdf_shift[rows] = conditional_mean_model(Xshift_train, feature_batch).mean(1)

    # True cdf from the large target samples. Batching avoids materialising a
    # (len(t), nintsample, 1) indicator tensor in one go.
    true_cdf_int[rows] = feature(Yint, t[rows]).mean((1,2))
    true_cdf_shift[rows] = feature(Yshift, t[rows]).mean((1,2))

    print("getting cdf value batch ",i+1,"/",nbatch)

getting cdf value batch  0 / 10
getting cdf value batch  1 / 10
getting cdf value batch  2 / 10
getting cdf value batch  3 / 10
getting cdf value batch  4 / 10
getting cdf value batch  5 / 10
getting cdf value batch  6 / 10
getting cdf value batch  7 / 10
getting cdf value batch  8 / 10
getting cdf value batch  9 / 10


In [43]:
# Training propensity score models
Propensity_score_model_est = []
Propensity_score_model_policy = []
Propensity_score_model_new_policy = []

# Estimating mistake probabilities:
#   P[i, j] = share of training units whose noise-free policy action is states[j]
#             but whose recorded treatment is states[i]   (each column sums to one)
Xtrue = policy(Ztrain)  # flip_prob left at its default 0.00
states = torch.unique(X[:,0]).int()
nstate = len(states)
recorded_onehot = (Xtrain[:, :1] == states[None, :]).float()  # (ntrain, nstate)
true_onehot = (Xtrue[:, :1] == states[None, :]).float()       # (ntrain, nstate)
P = recorded_onehot.T @ true_onehot                            # joint counts
P *= 1/P.sum(0)

propensity_model_est = propensity_score(P,policy)
propensity_model_new_policy = propensity_score(torch.eye(len(P)),new_policy)
propensity_model_policy = propensity_score(torch.eye(len(P)),policy)

In [44]:
# Training density models (covariate KDEs)
kde_learn_rate = 0.1
kde_miniter = 100
kde_maxiter = 100
kde_tol = 1e-2
kde_nfold = 3
kde_reg = 1e-6

Densities_Z = []
Densities_Z_shift = []


def fit_density(samples):
    """Fit a KDE with an ``inverse_gaussian_kernel`` to ``samples`` using the kde_* settings.

    Returns the kernel, the KDE object and whatever ``KDE.optimise`` returns.
    """
    kde_kernel = inverse_gaussian_kernel(lengthscale = torch.ones(D),scale = 1.0)
    estimator = KDE(kde_kernel)
    fit_losses = estimator.optimise(samples, kde_learn_rate, kde_miniter, kde_maxiter,
                                    kde_tol, kde_nfold, kde_reg)
    return kde_kernel, estimator, fit_losses


# Observational covariate density (training part) ...
kernel, density_z, losses = fit_density(Ztrain)
# ... and the shifted covariate density
kernel_shift, density_zshift, losses_shift = fit_density(Zshift_train)

iter 0 , loss =  tensor(68976.0859)
iter 10 , loss =  tensor(68891.7734)
iter 20 , loss =  tensor(68879.0703)
iter 30 , loss =  tensor(68875.4609)
iter 40 , loss =  tensor(68874.4062)
iter 50 , loss =  tensor(68874.)
iter 60 , loss =  tensor(68873.8047)
iter 70 , loss =  tensor(68873.7031)
iter 80 , loss =  tensor(68873.6875)
iter 90 , loss =  tensor(68873.6875)
iter 0 , loss =  tensor(56175.3359)
iter 10 , loss =  tensor(39026.2188)
iter 20 , loss =  tensor(39731.8789)
iter 30 , loss =  tensor(38435.4258)
iter 40 , loss =  tensor(38606.7148)
iter 50 , loss =  tensor(38439.5352)
iter 60 , loss =  tensor(38442.9180)
iter 70 , loss =  tensor(38438.7500)
iter 80 , loss =  tensor(38431.1484)
iter 90 , loss =  tensor(38431.6094)


In [45]:
# Getting IPW estimator.
# Weight for each test unit:
#   w = p_target(x | z) * q_target(z) / (p_est(x | z) * q_obs(z))
# where p_* are the propensity models and q_* the covariate KDEs. forward(a, b) is
# presumably the KDE built on sample b evaluated at a -- confirm in causal_cocycle.kde.
# Both targets draw covariates from Zintdist, so they share q_target = density_zshift.
# The shared densities and the estimated propensity are evaluated once and reused.
shifted_density = density_zshift.forward(Ztest,Zshift_train)
observed_denominator = (propensity_model_est(Xtest,Ztest)*
                        density_z.forward(Ztest,Ztrain))

weights_shift = ((propensity_model_policy(Xtest,Ztest)*shifted_density)/
                 observed_denominator).detach()
weights_int = ((propensity_model_new_policy(Xtest,Ztest)*shifted_density)/
               observed_denominator).detach()

test_feature = feature(Ytest,t)
IPW_cdf_shift = (weights_shift[None,:,None]*test_feature).mean((1,2))
IPW_cdf_int = (weights_int[None,:,None]*test_feature).mean((1,2))

In [76]:
# Getting DR estimator: outcome-model (plug-in) cdf + IPW term, minus the
# IPW-weighted outcome-model prediction on the test part.
SCM_DR_cdf_shift = SCM_cdf_shift + IPW_cdf_shift
SCM_DR_cdf_int = SCM_cdf_int + IPW_cdf_int
# Aliases under the older names; they refer to the same tensors, so the in-place
# updates below are visible through both names.
cocycle_DR_cdf_shift = SCM_DR_cdf_shift
cocycle_DR_cdf_int = SCM_DR_cdf_int

for i in range(nbatch):
    rows = slice(i*batch, (i+1)*batch)
    # Outcome-model conditional means of the cdf feature on the test part, for this batch of thresholds
    conditional_mean_batch = conditional_mean_model(Xtest,partial(feature,t = t[rows]))

    # Updating DR estimator (subtract the weighted augmentation term)
    SCM_DR_cdf_shift[rows] -= (weights_shift*conditional_mean_batch).mean(1)
    SCM_DR_cdf_int[rows] -= (weights_int*conditional_mean_batch).mean(1)

    print("getting cdf value batch ",i+1,"/",nbatch)

getting cdf value batch  0 / 10
getting cdf value batch  1 / 10
getting cdf value batch  2 / 10
getting cdf value batch  3 / 10
getting cdf value batch  4 / 10
getting cdf value batch  5 / 10
getting cdf value batch  6 / 10
getting cdf value batch  7 / 10
getting cdf value batch  8 / 10
getting cdf value batch  9 / 10
