# In [None]:
import numpy as np
from scipy import optimize


class RKHS_RieszLearner:
    '''
    Learn a weighting function ``alpha`` that is linear in Gaussian-kernel
    features, by minimising a penalised empirical loss with BFGS.

    Typical use is ``fit(...)`` followed by ``predict()``; ``predict`` stores
    its result in ``self.riesz``.

    :param loss: "LS", "KL" or "TL"; selects ``ls_loss``, ``kl_loss`` or
        ``tailored_loss`` as the training objective.
    :param separate_or_share: "DX" (treatment indicator is prepended to the
        covariates before the kernel features are built) or "OnlyX"
        (covariates only) when going through ``kernel_cv``/``fit``.
        "Separate" makes the Linear link use two parameter vectors
        (treated / control) stacked into one; any other value shares a single
        parameter vector.
    :param link_name: "Linear" or "Logit".
    :raises ValueError: if ``loss`` is not one of the supported names.
    '''

    def __init__(self, loss, separate_or_share, link_name):
        loss_table = {
            "LS": self.ls_loss,
            "KL": self.kl_loss,
            "TL": self.tailored_loss,
        }
        if loss not in loss_table:
            raise ValueError(
                "Unknown loss %r; expected one of %s" % (loss, sorted(loss_table))
            )
        self.loss_func = loss_table[loss]

        self.separate_or_share = separate_or_share
        self.link_name = link_name

    @staticmethod
    def _add_intercept(features):
        '''Append a column of ones to a 2-D feature matrix.'''
        ones = np.ones((len(features), 1))
        return np.concatenate([features, ones], axis=1)

    def _model_construction(self, param, X1, X0, treatment):
        '''
        Evaluate ``alpha`` for every row given a treatment assignment.

        :param param: parameter vector (two stacked halves for the
            "Separate" + "Linear" configuration).
        :param X1: feature matrix evaluated with treatment set to 1.
        :param X0: feature matrix evaluated with treatment set to 0.
        :param treatment: 0/1 vector choosing, per row, which model applies.
        :return: vector of ``alpha`` values, one per row.
        :raises ValueError: if ``self.link_name`` is not supported.
        '''
        if self.link_name == "Linear":
            if self.separate_or_share == "Separate":
                half = len(param) // 2
                fx1 = X1 @ param[:half]
                fx0 = X0 @ param[half:]
            else:
                fx1 = X1 @ param
                fx0 = X0 @ param
            # Select row-wise with a boolean mask.  Indexing with the raw 0/1
            # vector would be integer fancy indexing and only touch rows 0/1.
            return np.where(np.asarray(treatment) == 1, fx1, fx0)

        if self.link_name == "Logit":
            ex1 = 1 / (1 + np.exp(-(X1 @ param)))
            ex0 = 1 / (1 + np.exp(-(X0 @ param)))
            return treatment / ex1 - (1 - treatment) / (1 - ex0)

        raise ValueError("Unknown link_name %r" % (self.link_name,))

    def _counterfactual_alphas(self, param, X1, X0, treatment):
        '''Return ``(alpha1, alpha0)``: alpha with everyone treated / untreated.'''
        treatment0 = treatment * 0
        treatment1 = treatment0 + 1
        alpha1 = self._model_construction(param, X1, X0, treatment1)
        alpha0 = self._model_construction(param, X1, X0, treatment0)
        return alpha1, alpha0

    def ls_loss(self, param, X1, X0, treatment, regularizer):
        '''
        Squared-error style objective.

        Mean of ``-2*(alpha1 - alpha0) + D*alpha1**2 + (1-D)*alpha0**2`` plus
        ``regularizer * ||param||^2``.
        '''
        alpha1, alpha0 = self._counterfactual_alphas(param, X1, X0, treatment)
        loss = - 2*(alpha1 - alpha0) + treatment*alpha1**2 + (1 - treatment)*alpha0**2
        return np.mean(loss) + regularizer * np.sum(param**2)

    def kl_loss(self, param, X1, X0, treatment, regularizer):
        '''
        Log-based objective.

        Mean of ``-log(alpha1) - log(-alpha0) + D*alpha1 - (1-D)*alpha0`` plus
        ``regularizer * ||param||^2``.  The logarithms are only finite when
        ``alpha1 > 0`` and ``alpha0 < 0``.
        '''
        alpha1, alpha0 = self._counterfactual_alphas(param, X1, X0, treatment)
        loss = - np.log(alpha1) - np.log(-alpha0) + treatment*alpha1 - (1 - treatment)*alpha0
        return np.mean(loss) + regularizer * np.sum(param**2)

    def tailored_loss(self, param, X1, X0, treatment, regularizer):
        '''
        "TL" objective.

        Mean of ``-(1-D)*log(alpha1 - 1) - D*log(-alpha0 - 1) + D*alpha1
        - (1-D)*alpha0`` plus ``regularizer * ||param||^2``.  The logarithms
        are only finite when ``alpha1 > 1`` and ``alpha0 < -1``.
        '''
        alpha1, alpha0 = self._counterfactual_alphas(param, X1, X0, treatment)
        loss = - (1 - treatment)*np.log(alpha1 - 1) - treatment*np.log(- alpha0 - 1) + treatment*alpha1 - (1 - treatment)*alpha0
        return np.mean(loss) + regularizer * np.sum(param**2)

    def optimize(self, covariate, treatment, x_test):
        '''
        Delegate to ``self.minimize`` with a fixed regulariser of 0.01 and
        remember ``x_test``.

        NOTE(review): this class does not define ``minimize``; unless a
        subclass provides it this call raises AttributeError.  ``fit`` is the
        working entry point.
        '''
        self.minimize(covariate, treatment, 0.01)
        self.x_test = x_test

    def obj_func_gen(self, X1, X0, treatment, regularizer):
        '''Return a one-argument objective ``param -> loss`` bound to the data.'''
        def obj_func(param):
            return self.loss_func(param, X1, X0, treatment, regularizer)
        return obj_func

    def fit(self, covariate, treatment, outcome, folds=2, num_basis=50):
        '''
        Select kernel width / regulariser by cross-validation, then train on
        the full sample.  The training covariates double as the test set.

        :param covariate: 2-D covariate array, one row per sample.
        :param treatment: 0/1 treatment vector.
        :param outcome: accepted for interface compatibility; not used.
        :param folds: number of cross-validation folds.
        :param num_basis: number of kernel centres.
        '''
        self.treatment = treatment
        (self.X1_train,
         self.X0_train,
         self.X_test,
         self.lda_chosen) = self.kernel_cv(
            covariate, treatment, covariate, folds=folds, num_basis=num_basis
        )
        self.train(self.X1_train, self.X0_train, self.treatment, self.lda_chosen)

    def train(self, X1, X0, treatment, lda_chosen, folds=2, num_basis=50):
        '''
        Minimise the selected loss with BFGS starting from zeros.

        Stores the optimiser result in ``self.result`` and the solution in
        ``self.params``.

        :param X1: feature matrix with treatment set to 1.
        :param X0: feature matrix with treatment set to 0.
        :param treatment: 0/1 treatment vector.
        :param lda_chosen: L2 regularisation weight.
        :param folds: unused; kept for interface compatibility.
        :param num_basis: unused; kept for interface compatibility.
        :return: the fitted parameter vector (same object as ``self.params``).
        '''
        obj_func = self.obj_func_gen(X1, X0, treatment, lda_chosen)

        n_features = X1.shape[1]
        # Must mirror _model_construction: only the "Separate" linear model
        # splits the parameter vector into two halves.
        if self.separate_or_share == "Separate" and self.link_name == "Linear":
            init_param = np.zeros(n_features * 2)
        else:
            init_param = np.zeros(n_features)

        self.result = optimize.minimize(obj_func, init_param, method="BFGS")
        self.params = self.result.x
        return self.params

    def predict(self):
        '''
        Evaluate the fitted ``alpha`` on ``self.X_test`` using the observed
        ``self.treatment``; the result is also stored in ``self.riesz``.

        :return: vector of ``alpha`` values.
        :raises ValueError: if ``self.link_name`` is not supported.
        '''
        if self.link_name == "Linear":
            if self.separate_or_share == "Separate":
                half = len(self.params) // 2
                fx1 = self.X_test @ self.params[:half]
                fx0 = self.X_test @ self.params[half:]
                alpha = np.where(np.asarray(self.treatment) == 1, fx1, fx0)
            else:
                alpha = self.X_test @ self.params
        elif self.link_name == "Logit":
            observed = np.asarray(self.treatment)
            fx = self.X_test @ self.params
            ex = 1 / (1 + np.exp(-fx))
            # Same formula as _model_construction: weights use the sigmoid
            # output, not the raw score.
            alpha = observed / ex - (1 - observed) / (1 - ex)
        else:
            raise ValueError("Unknown link_name %r" % (self.link_name,))

        self.riesz = alpha
        return self.riesz

    def dist(self, X, X1, X0, treatment=None, num_basis=False):
        '''
        Draw random kernel centres from the columns of ``X`` and compute
        squared distances to them.

        All matrices hold one sample per column.

        :param X: d x n matrix from which the centres are sampled.
        :param X1: matrix with treatment set to 1.
        :param X0: matrix with treatment set to 0.
        :param treatment: despite its name, a third feature matrix (kernel_cv
            passes the test features here); if None the matching distance
            matrix is returned as None.
        :param num_basis: number of centres; False means 1000.  At most ``n``
            centres are actually drawn.
        :return: ``(X1C_dist, X0C_dist, DC_dist, CC_dist, n, num_basis)``.
        '''
        n = X.shape[1]

        if num_basis is False:
            num_basis = 1000

        idx = np.random.permutation(n)[0:num_basis]
        C = X[:, idx]

        # calculate the squared distances
        X1C_dist = CalcDistanceSquared(X1, C)
        X0C_dist = CalcDistanceSquared(X0, C)
        DC_dist = None if treatment is None else CalcDistanceSquared(treatment, C)
        CC_dist = CalcDistanceSquared(C, C)
        return X1C_dist, X0C_dist, DC_dist, CC_dist, n, num_basis

    def kernel_cv(self, covariate_train, treatment, covariate_test, folds=5, num_basis=False, sigma_list=None, lda_list=None):
        '''
        Choose the Gaussian kernel width and regulariser by K-fold
        cross-validation and build the final feature matrices.

        Every (sigma, lambda) pair is trained on K-1 folds and scored with the
        unregularised loss on the held-out fold; the pair with the smallest
        summed score wins.  Side effect: ``self.params`` and ``self.result``
        are overwritten by the cross-validation fits.

        :param covariate_train: 2-D training covariates, one row per sample.
        :param treatment: 0/1 treatment vector for the training samples.
        :param covariate_test: 2-D test covariates; in "DX" mode it must have
            as many rows as ``treatment``.
        :param folds: number of folds (at least 2).
        :param num_basis: number of kernel centres; False means 1000.
        :param sigma_list: candidate kernel widths; defaults to
            [0.01, 0.05, 0.1, 0.5, 1].
        :param lda_list: candidate regularisers; same default.
        :return: ``(X1_train, X0_train, X_test, lda_chosen)``; the feature
            matrices include a trailing intercept column.
        :raises ValueError: if ``folds < 2`` or ``self.separate_or_share`` is
            neither "DX" nor "OnlyX".
        '''
        if folds < 2:
            raise ValueError("folds must be at least 2, got %r" % (folds,))

        treatment = np.asarray(treatment)
        covariate_train = np.asarray(covariate_train)
        covariate_test = np.asarray(covariate_test)

        if self.separate_or_share == "DX":
            treatment0 = treatment * 0
            treatment1 = treatment0 + 1
            X_train1 = np.concatenate([treatment1[:, np.newaxis], covariate_train], axis=1)
            X_train0 = np.concatenate([treatment0[:, np.newaxis], covariate_train], axis=1)
            X_train = np.concatenate([treatment[:, np.newaxis], covariate_train], axis=1)
            X_test = np.concatenate([treatment[:, np.newaxis], covariate_test], axis=1)
        elif self.separate_or_share == "OnlyX":
            X_train1 = covariate_train
            X_train0 = covariate_train
            X_train = covariate_train
            X_test = covariate_test
        else:
            raise ValueError(
                "kernel_cv supports separate_or_share 'DX' or 'OnlyX', got %r"
                % (self.separate_or_share,)
            )

        # dist() expects one sample per column.
        X1C_dist, X0C_dist, DC_dist, _, n, num_basis = self.dist(
            X_train.T, X_train1.T, X_train0.T, X_test.T, num_basis
        )

        # setup the cross validation: balanced fold labels, randomly permuted
        cv_split0 = np.floor(np.arange(n) * folds / n)
        cv_index = cv_split0[np.random.permutation(n)]
        fold_masks = [cv_index == k for k in range(folds)]

        # set the sigma list and lambda list
        if sigma_list is None:
            sigma_list = np.array([0.01, 0.05, 0.1, 0.5, 1])
        if lda_list is None:
            lda_list = np.array([0.01, 0.05, 0.1, 0.5, 1])
        score_cv = np.zeros((len(sigma_list), len(lda_list)))

        # The treatment split does not depend on sigma.
        d_cv = [treatment[mask] for mask in fold_masks]

        for sigma_idx, sigma in enumerate(sigma_list):
            # pre-compute the per-fold kernel features for this width
            h1_cv = [np.exp(-X1C_dist[:, mask] / (2 * sigma**2)) for mask in fold_masks]
            h0_cv = [np.exp(-X0C_dist[:, mask] / (2 * sigma**2)) for mask in fold_masks]

            for k in range(folds):
                others = [j for j in range(folds) if j != k]

                h1tr = self._add_intercept(np.concatenate([h1_cv[j].T for j in others], axis=0))
                h0tr = self._add_intercept(np.concatenate([h0_cv[j].T for j in others], axis=0))
                dtr = np.concatenate([d_cv[j] for j in others], axis=0)

                h1te = self._add_intercept(h1_cv[k].T)
                h0te = self._add_intercept(h0_cv[k].T)
                dte = d_cv[k]

                # Held-out score uses the loss without the penalty term.
                heldout_loss = self.obj_func_gen(h1te, h0te, dte, 0)
                for lda_idx, lda in enumerate(lda_list):
                    fitted = self.train(h1tr, h0tr, dtr, lda)
                    score_cv[sigma_idx, lda_idx] += heldout_loss(fitted)

        # get the minimum
        (sigma_idx_chosen, lda_idx_chosen) = np.unravel_index(np.argmin(score_cv), score_cv.shape)
        sigma_chosen = sigma_list[sigma_idx_chosen]
        lda_chosen = lda_list[lda_idx_chosen]

        X1_train = self._add_intercept(np.exp(-X1C_dist / (2 * sigma_chosen**2)).T)
        X0_train = self._add_intercept(np.exp(-X0C_dist / (2 * sigma_chosen**2)).T)
        X_test = self._add_intercept(np.exp(-DC_dist / (2 * sigma_chosen**2)).T)

        return X1_train, X0_train, X_test, lda_chosen



def CalcDistanceSquared(X, C):
    '''
    Calculates the squared Euclidean distance between the columns of X and C.
    XC_dist2 = CalcDistanceSquared(X, C)
    [XC_dist2]_{ij} = ||X[:, j] - C[:, i]||^2
    :param X: d x n: First set of vectors (one per column)
    :param C: d x nc: Second set of vectors (one per column)
    :return: XC_dist2: The squared distances, nc x n
    '''
    # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 <c, x>, evaluated for all pairs.
    Xsum = np.sum(X**2, axis=0)
    Csum = np.sum(C**2, axis=0)
    XC_dist = Xsum[np.newaxis, :] + Csum[:, np.newaxis] - 2*np.dot(C.T, X)
    # The expansion can go slightly negative through floating-point
    # cancellation; a squared distance cannot, so clamp at zero.
    return np.maximum(XC_dist, 0)
