# README: 

This notebook contains the private version of DAG-WGAN. Anyone can train the model using the GPUs from Google Colab VMs. 

It can be easily run by toggling args.differentialPrivacy to True or False in Private_Main.py. Feel free to play around with the hyper-parameters, as there is plenty of tweaking still to be done.

edits: @calmac

last: 13/10/22

#=========================================

# Differentially private DAG training

## Utils.py

In [None]:

# -*- coding: utf-8 -*-
"""
Created on Thu Nov 12 15:13:21 2020
@author: Hristo Petkov
"""

"""
@inproceedings{yu2019dag,
  title={DAG-GNN: DAG Structure Learning with Graph Neural Networks},
  author={Yue Yu, Jie Chen, Tian Gao, and Mo Yu},
  booktitle={Proceedings of the 36th International Conference on Machine Learning},
  year={2019}
}
@inproceedings{xu2019modeling,
  title={Modeling Tabular data using Conditional GAN},
  author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
"""
import torch
import os
import math
import numpy as np
import torch.nn as nn
import networkx as nx
import scipy.linalg as slin
import torch.nn.functional as F
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.feature_selection import SelectFromModel
from networkx.drawing.nx_agraph import write_dot, graphviz_layout
from networkx.convert_matrix import from_numpy_matrix
from matplotlib import pyplot as plt
from torch.utils.data.dataset import TensorDataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
# from FullDataPreProcessor import FullDataPreProcessor

# AAE utility functions

def num_categories(labels):
    """Return how many distinct values occur in ``labels``."""
    distinct = {label for label in labels}
    return len(distinct)

def my_softmax(input, axis=1):
    """Apply softmax along ``axis`` of ``input``.

    The previous implementation moved ``axis`` to position 0 and then
    normalised over the *last* dimension, so the softmax was never taken
    along ``axis`` itself (with ``axis=-1`` it normalised over the first,
    i.e. batch, dimension). Normalising directly along ``axis`` makes the
    output sum to 1 across that dimension, as the name promises.

    Args:
        input: tensor of unnormalised scores.
        axis: dimension along which the probabilities should sum to 1.
    Returns:
        Tensor with the same shape as ``input``.
    """
    return F.softmax(input, dim=axis)

def relu(x, derivative=False, alpha=0.1):
    """Rectified linear unit, or its derivative when ``derivative`` is set.

    ``alpha`` is accepted for signature compatibility but is not used.
    """
    positive = x > 0
    if derivative:
        # 0/1 indicator of the strictly positive entries
        return positive * 1
    return x * positive

def preprocess_adj_new(adj, device):
    """Return ``I - adj^T`` on ``device`` (identity built in double precision)."""
    identity = torch.eye(adj.shape[0]).double().to(device)
    adj_transposed = adj.transpose(0, 1).to(device)
    return identity - adj_transposed

def preprocess_adj_new1(adj, device):
    """Return the inverse of ``I - adj^T`` on ``device``."""
    identity = torch.eye(adj.shape[0]).double().to(device)
    adj_transposed = adj.transpose(0, 1).to(device)
    return torch.inverse(identity - adj_transposed)

def matrix_poly(matrix, d, device):
    """Return ``(I + matrix / d) ** d`` as a matrix power on ``device``."""
    scaled = torch.div(matrix, d).to(device)
    base = torch.eye(d).double().to(device) + scaled
    return torch.matrix_power(base, d)

# compute constraint h(A) value
def _h_A(A, m, device):
    """Return ``trace(matrix_poly(A * A, m)) - m``, the constraint value h(A)."""
    squared = A * A  # element-wise square
    return torch.trace(matrix_poly(squared, m, device)) - m

def build_phi(w, totalcount):
    """Return the entries of ``w`` from index ``totalcount`` on, as a column."""
    tail = w[totalcount:]
    return tail.reshape(-1, 1)

def build_W(w, d,totalcount):
    """Build a symmetric ``d x d`` matrix from the first ``totalcount`` entries of ``w``.

    The values fill the strictly lower triangle (row-major order of
    ``np.tril_indices(d, -1)``) and are mirrored into the upper triangle.
    """
    rows, cols = np.tril_indices(d, -1)
    lower = np.zeros([d, d])
    lower[rows, cols] = w[:totalcount]
    mirrored = lower + lower.T
    return mirrored - np.diag(lower.diagonal())

def build_w_inv(A, phi, d,totalcount):
    """Recover the packed lower-triangular weights from ``A`` and ``phi``.

    Each visited entry is ``A[i, j] / relu(phi[j] - phi[i])`` when that
    rectified gap exceeds 1e-8 and 0 otherwise; the matrix is then
    symmetrised and its strictly lower triangle is returned as a vector of
    length ``totalcount``.
    """
    ratios = np.zeros([d, d])
    # NOTE(review): both loops stop at d - 2, so the last row and column
    # stay zero -- confirm this is intended.
    for i in range(d - 1):
        for j in range(d - 1):
            gap = relu(phi[j] - phi[i])
            ratios[i, j] = A[i, j] / gap if gap > 1e-8 else 0
    symmetric = (ratios + ratios.T) / 2.
    packed = np.zeros(totalcount)
    packed[:totalcount] = symmetric[np.tril_indices(d, -1)]
    return packed

def to_categorical(y, num_columns):
    """Return a float one-hot encoding of the integer labels ``y``.

    The result has one row per label and ``num_columns`` columns.
    """
    n_rows = y.shape[0]
    one_hot = np.zeros((n_rows, num_columns))
    one_hot[np.arange(n_rows), y] = 1.0
    return Variable(torch.FloatTensor(one_hot))

def pns_(model_adj, dataloader, num_neighbors, thresh, n_estimators=500):
    """Preliminary neighborhood selection.

    For every node, an ExtraTreesRegressor predicts that node from all
    columns (with the node's own column zeroed out). Features selected by
    ``SelectFromModel`` form a 0/1 mask that is multiplied into the
    corresponding column of ``model_adj`` in place.

    Args:
        model_adj: adjacency matrix to mask; modified in place and returned.
        dataloader: loader whose ``dataset.tensors[0]`` holds the samples;
            after ``squeeze()`` it is indexed as (samples, nodes).
        num_neighbors: ``max_features`` passed to ``SelectFromModel``.
        thresh: multiplier used in the ``"<thresh>*mean"`` importance threshold.
        n_estimators: number of trees per regressor (default 500).
    Returns:
        The masked ``model_adj``.
    """
    x = dataloader.dataset.tensors[0].squeeze()

    num_samples = x.shape[0]
    num_nodes = x.shape[1]
    print("PNS: num samples = {}, num nodes = {}".format(num_samples, num_nodes))
    for node in range(num_nodes):
        print("PNS: node " + str(node))
        # hide the target column from the regressor's inputs
        x_other = np.copy(x)
        x_other[:, node] = 0
        reg = ExtraTreesRegressor(n_estimators=n_estimators)
        reg = reg.fit(x_other, x[:, node])
        selected_reg = SelectFromModel(reg, threshold="{}*mean".format(thresh), prefit=True,
                                       max_features=num_neighbors)
        # builtin float: the np.float alias was removed in NumPy 1.24
        mask_selected = selected_reg.get_support(indices=False).astype(float)

        model_adj[:, node] *= mask_selected

    return model_adj

def nll_catogrical(preds, target, add_const = False, eps=1e-16):
    '''Summed cross-entropy between logits ``preds`` and one-hot ``target``.

    The class index of each row is taken as ``argmax`` of ``target`` along
    dimension 1. ``add_const`` and ``eps`` are accepted but unused.
    '''
    class_index = target.argmax(dim=1)
    return F.cross_entropy(preds, class_index, reduction='sum')

def nll_gaussian(preds, target, variance, add_const=False):
    """Gaussian negative log-likelihood averaged over the first dimension.

    ``variance`` enters the formula as a log standard deviation: each
    element contributes ``variance + (preds - target)^2 / (2 * exp(2 * variance))``.

    Args:
        preds: predicted means.
        target: observed values, same shape as ``preds``.
        variance: Python/NumPy scalar or tensor, used as described above.
        add_const: when True, add the normalising constant ``0.5 * log(2*pi)``
            per element. The previous code called ``torch.from_numpy`` on a
            float here and therefore always raised.
    Returns:
        Scalar tensor: summed per-element loss divided by ``target.size(0)``.
    """
    squared_error = torch.pow(preds - target, 2)

    # np.exp on a tensor yields an ndarray that torch.div rejects, so pick
    # the exponential that matches the argument type.
    if torch.is_tensor(variance):
        denominator = 2. * torch.exp(2. * variance)
    else:
        denominator = 2. * math.exp(2. * variance)

    neg_log_p = variance + torch.div(squared_error, denominator)

    if add_const:
        neg_log_p = neg_log_p + 0.5 * math.log(2. * math.pi)

    return neg_log_p.sum() / (target.size(0))
    
def kl_gaussian_sem(logits):
    """Return half the sum of squared ``logits`` divided by the first-dimension size."""
    batch = logits.size(0)
    squared_total = (logits * logits).sum()
    return (squared_total / batch) * 0.5

# def nll_catogrical(preds, target, add_const = False,  eps = 1e-20):
#     '''compute the loglikelihood of discrete variables
#     '''
#     #BCE = F.binary_cross_entropy(preds, target, size_average=False) / target.shape[0]
#     BCE = torch.sum(target * torch.log(preds + eps), dim=1).mean()
#     return BCE

def kl_categorical(preds, num_cats, eps=1e-16):
    """KL divergence of softmax(``preds``) from a uniform distribution over ``num_cats``.

    The per-element terms are summed over dimensions 2 and 1 and then
    averaged over what remains, so ``preds`` must be 3-dimensional.
    """
    probs = F.softmax(preds, dim=-1)  # logits -> probabilities
    neg_entropy = probs * torch.log(probs + eps)
    uniform_cross = probs * np.log((1.0/num_cats) + eps)
    per_item = torch.sum(neg_entropy - uniform_cross, 2)
    per_sample = torch.sum(per_item, 1)
    return per_sample.mean()
     

def sample_gumbel(shape, eps=1e-10):
    """
    Draw Gumbel(0, 1) noise of the given ``shape``.
    ``eps`` guards both logarithms against a zero argument.
    NOTE: Taken from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    uniform = torch.rand(shape).float()
    inner_log = torch.log(uniform + eps)
    return -torch.log(eps - inner_log)


def gumbel_softmax_sample(logits, tau=1, eps=1e-10):
    """
    Draw one sample from the Gumbel-Softmax distribution.
    Gumbel noise is added to ``logits`` (cast to double), the sum is divided
    by the temperature ``tau`` and passed to ``my_softmax`` with ``axis=-1``.
    NOTE: Taken from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
    (MIT license)
    """
    noise = sample_gumbel(logits.size(), eps=eps)
    if logits.is_cuda:
        # keep the noise on the same kind of device as the logits
        noise = noise.cuda()
    perturbed = logits + Variable(noise).double()
    return my_softmax(perturbed / tau, axis=-1)

def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10):
    """
    Sample from the Gumbel-Softmax distribution and optionally discretize.
    NOTE: Taken from https://github.com/pytorch/pytorch/pull/3341/commits/327fcfed4c44c62b208f750058d14d4dc1b9a9d3
    Args:
      logits: [batch_size, n_class] unnormalized log-probs
      tau: non-negative scalar temperature
      hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
      [batch_size, n_class] sample from the Gumbel-Softmax distribution.
      If hard=True, the returned sample is one-hot; otherwise it is the
      soft sample produced by gumbel_softmax_sample.
    Constraints:
    - this implementation only works on batch_size x num_features tensor for now
    based on
    https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
    (MIT license)
    """
    y_soft = gumbel_softmax_sample(logits, tau=tau, eps=eps)
    if not hard:
        return y_soft

    shape = logits.size()
    winners = y_soft.data.max(-1)[1]  # index of the largest entry along the last dim
    # straight-through trick, see
    # https://discuss.pytorch.org/t/stop-gradients-for-st-gumbel-softmax/530/5
    one_hot = torch.zeros(*shape)
    if y_soft.is_cuda:
        one_hot = one_hot.cuda()
    one_hot = one_hot.zero_().scatter_(-1, winners.view(shape[:-1] + (1,)), 1.0)
    # Adding and subtracting the detached soft sample makes the value
    # exactly one-hot while the gradient is that of y_soft.
    return Variable(one_hot - y_soft.data) + y_soft

#Plotting the DAG
#Borrowed from Causalnex Documentation
#https://causalnex.readthedocs.io/en/latest/03_tutorial/plotting_tutorial.html
def draw_dag(graph, data_type, columns = None):
    """Draw the DAG described by an adjacency matrix and dump it to ``test.dot``.

    Args:
        graph: square NumPy adjacency matrix; non-zero entries become
            weighted directed edges.
        data_type: when equal to ``'real'`` the nodes are relabelled with
            ``columns``; any other value keeps integer node labels.
        columns: node labels, required when ``data_type == 'real'``.
    Raises:
        ValueError: if ``data_type == 'real'`` and ``columns`` is None.

    Side effects: prints the adjacency view, writes ``test.dot`` to the
    current directory and shows a matplotlib figure.
    """
    if data_type == 'real' and columns is None:
        raise ValueError("columns must be provided when data_type is 'real'")

    # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is the
    # supported equivalent for ndarray input.
    final_DAG = nx.from_numpy_array(graph, create_using=nx.DiGraph)

    if data_type == 'real':
        final_DAG = nx.relabel_nodes(
            final_DAG, dict(zip(list(range(graph.shape[0])), columns)))
    # nodes without any edge carry no information in the plot
    final_DAG.remove_nodes_from(list(nx.isolates(final_DAG)))

    print('FINAL DAG')
    print(final_DAG.adj)

    write_dot(final_DAG,'test.dot')

    fig = plt.figure(figsize=(15, 8))  # set figsize
    ax = fig.add_subplot(1, 1, 1)
    ax.set_facecolor("#001521")  # set background

    pos = graphviz_layout(final_DAG, prog="dot")

    # add nodes to figure
    nx.draw_networkx_nodes(
        final_DAG,
        pos,
        node_shape="H",
        node_size=1000,
        linewidths=3,
        edgecolors="#4a90e2d9",
    )

    # add labels
    nx.draw_networkx_labels(
        final_DAG,
        pos,
        font_color="#FFFFFFD9",
        font_weight="bold",
        font_family="Helvetica",
        font_size=10,
    )

    # add edges; line width follows the edge weight
    nx.draw_networkx_edges(
        final_DAG,
        pos,
        edge_color="white",
        node_shape="H",
        node_size=2000,
        width=[w + 0.1 for _, _, w, in final_DAG.edges(data="weight")],
    )

    plt.show()
    plt.close()
    
# data generating functions below this point

def simulate_random_dag(d: int,
                        degree: float,
                        graph_type: str,
                        w_range: tuple = (0.5, 2.0)) -> nx.DiGraph:
    """Simulate a random weighted DAG with a given expected degree.

    A strictly lower-triangular 0/1 matrix is generated, its nodes are
    randomly permuted, and every edge receives a weight drawn uniformly
    from ``w_range`` with a random sign.

    Args:
        d: number of nodes
        degree: expected node degree, in + out
        graph_type: one of 'erdos-renyi', 'barabasi-albert', 'full'
        w_range: (low, high) magnitude range of the edge weights
    Returns:
        G: weighted DAG
    Raises:
        ValueError: if ``graph_type`` is not one of the supported names.
    """
    if graph_type == 'erdos-renyi':
        edge_prob = float(degree) / (d - 1)
        draws = np.random.rand(d, d) < edge_prob
        B = np.tril(draws.astype(float), k=-1)
    elif graph_type == 'barabasi-albert':
        m = int(round(degree / 2))
        B = np.zeros([d, d])
        bag = [0]
        for node in range(1, d):
            # earlier nodes are drawn in proportion to how often they
            # already appear in the bag
            targets = np.random.choice(bag, size=m)
            B[node, targets] = 1
            bag.append(node)
            bag.extend(targets)
    elif graph_type == 'full':  # ignore degree, only for experimental use
        B = np.tril(np.ones([d, d]), k=-1)
    else:
        raise ValueError('unknown graph type')

    # relabel the nodes with a random permutation (rows of the identity)
    P = np.random.permutation(np.eye(d, d))
    B_perm = P.T.dot(B).dot(P)

    # uniform magnitudes, each sign flipped with probability 0.5
    U = np.random.uniform(low=w_range[0], high=w_range[1], size=[d, d])
    U[np.random.rand(d, d) < 0.5] *= -1

    W = (B_perm != 0).astype(float) * U
    return nx.DiGraph(W)

def simulate_sem(G: nx.DiGraph,
                 n: int, x_dims: int,
                 sem_type: str,
                 linear_type: str,
                 noise_scale: float = 1.0) -> np.ndarray:
    """Simulate samples from SEM with specified type of noise.
    Args:
        G: weighted DAG
        n: number of samples
        x_dims: size of the last axis of the output; feature 0 carries the
            SEM sample, further features are noisy affine copies of it
        sem_type: {linear-gauss,linear-exp,linear-gumbel}
        linear_type: {linear,nonlinear_1,nonlinear_2}, selects how the
            parents are combined
        noise_scale: scale parameter of noise distribution in linear SEM
    Returns:
        X: [n,d,x_dims] sample array
    Raises:
        ValueError: on an unknown ``linear_type`` or ``sem_type``.
    """
    
    W = nx.to_numpy_array(G)
    d = W.shape[0]
    X = np.zeros([n, d, x_dims])
    # visiting nodes in topological order guarantees that every parent
    # column is filled before its children are computed
    ordered_vertices = list(nx.topological_sort(G))
    assert len(ordered_vertices) == d
    for j in ordered_vertices:
        parents = list(G.predecessors(j))
        # eta: weighted combination of the parents' feature-0 values
        if linear_type == 'linear':
            eta = X[:, parents, 0].dot(W[parents, j])
        elif linear_type == 'nonlinear_1':
            eta = np.cos(X[:, parents, 0] + 1).dot(W[parents, j])
        elif linear_type == 'nonlinear_2':
            eta = (X[:, parents, 0]+0.5).dot(W[parents, j])
        else:
            raise ValueError('unknown linear data type')

        if sem_type == 'linear-gauss':
            if linear_type == 'linear':
                X[:, j, 0] = eta + np.random.normal(scale=noise_scale, size=n)
            elif linear_type == 'nonlinear_1':
                X[:, j, 0] = eta + np.random.normal(scale=noise_scale, size=n)
            elif linear_type == 'nonlinear_2':
                # only this variant applies an extra sine term to eta
                X[:, j, 0] = 2.*np.sin(eta) + eta + np.random.normal(scale=noise_scale, size=n)
        elif sem_type == 'linear-exp':
            X[:, j, 0] = eta + np.random.exponential(scale=noise_scale, size=n)
        elif sem_type == 'linear-gumbel':
            X[:, j, 0] = eta + np.random.gumbel(scale=noise_scale, size=n)
        else:
            raise ValueError('unknown sem type')
    if x_dims > 1 :
        # every extra feature is a random scalar multiple of feature 0 plus a
        # random scalar offset plus element-wise Gaussian noise
        for i in range(x_dims-1):
            X[:, :, i+1] = np.random.normal(scale=noise_scale, size=1)*X[:, :, 0] + np.random.normal(scale=noise_scale, size=1) + np.random.normal(scale=noise_scale, size=(n, d))
        # feature 0 itself is overwritten last with the same kind of transform
        X[:, :, 0] = np.random.normal(scale=noise_scale, size=1) * X[:, :, 0] + np.random.normal(scale=noise_scale, size=1) + np.random.normal(scale=noise_scale, size=(n, d))
    return X

def simulate_population_sample(W: np.ndarray,
                               Omega: np.ndarray) -> np.ndarray:
    """Simulate data matrix X that matches population least squares.

    Computes ``sqrt(d) * sqrtm(Omega) @ pinv(I - W)``.

    Args:
        W: [d,d] adjacency matrix
        Omega: [d,d] noise covariance matrix
    Returns:
        X: [d,d] sample matrix
    """
    d = W.shape[0]
    noise_root = slin.sqrtm(Omega)
    mixing = np.linalg.pinv(np.eye(d) - W)
    return np.sqrt(d) * noise_root.dot(mixing)

def _edge_confusion_stats(G_true, G, G_und=None):
    """Compute the edge-level metrics shared by both count_accuracy variants.

    Returns:
        (fdr, tpr, fpr, shd, pred_size, n_extra, n_missing, n_reverse)
    """
    B_true = nx.to_numpy_array(G_true) != 0
    B = nx.to_numpy_array(G) != 0
    B_und = None if G_und is None else nx.to_numpy_array(G_und)
    d = B.shape[0]
    # linear index of nonzeros
    if B_und is not None:
        pred_und = np.flatnonzero(B_und)
    pred = np.flatnonzero(B)
    cond = np.flatnonzero(B_true)
    cond_reversed = np.flatnonzero(B_true.T)
    cond_skeleton = np.concatenate([cond, cond_reversed])
    # true pos
    true_pos = np.intersect1d(pred, cond, assume_unique=True)
    if B_und is not None:
        # treat undirected edge favorably
        true_pos_und = np.intersect1d(pred_und, cond_skeleton, assume_unique=True)
        true_pos = np.concatenate([true_pos, true_pos_und])
    # false pos
    false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True)
    if B_und is not None:
        false_pos_und = np.setdiff1d(pred_und, cond_skeleton, assume_unique=True)
        false_pos = np.concatenate([false_pos, false_pos_und])
    # reverse
    extra = np.setdiff1d(pred, cond, assume_unique=True)
    reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)
    # compute ratio
    pred_size = len(pred)
    if B_und is not None:
        pred_size += len(pred_und)
    cond_neg_size = 0.5 * d * (d - 1) - len(cond)
    fdr = float(len(reverse) + len(false_pos)) / max(pred_size, 1)
    tpr = float(len(true_pos)) / max(len(cond), 1)
    fpr = float(len(reverse) + len(false_pos)) / max(cond_neg_size, 1)
    # structural hamming distance
    B_lower = np.tril(B + B.T)
    if B_und is not None:
        # B_lower is boolean; an in-place "+=" with the float B_und raises a
        # casting error, so combine the two skeletons with a logical OR.
        B_lower = B_lower | (np.tril(B_und + B_und.T) != 0)
    pred_lower = np.flatnonzero(B_lower)
    cond_lower = np.flatnonzero(np.tril(B_true + B_true.T))
    extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True)
    missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True)
    shd = len(extra_lower) + len(missing_lower) + len(reverse)
    return (fdr, tpr, fpr, shd, pred_size,
            len(extra_lower), len(missing_lower), len(reverse))


def count_accuracy(G_true: nx.DiGraph,
                   G: nx.DiGraph,
                   G_und: nx.DiGraph = None) -> tuple:
    """Compute FDR, TPR, and FPR for B, or optionally for CPDAG B + B_und.
    Args:
        G_true: ground truth graph
        G: predicted graph
        G_und: predicted undirected edges in CPDAG, asymmetric
    Returns:
        fdr: (reverse + false positive) / prediction positive
        tpr: (true positive) / condition positive
        fpr: (reverse + false positive) / condition negative
        shd: undirected extra + undirected missing + reverse
        nnz: prediction positive
    """
    return _edge_confusion_stats(G_true, G, G_und)[:5]


def count_accuracy_new(G_true: nx.DiGraph,
                   G: nx.DiGraph,
                   G_und: nx.DiGraph = None) -> tuple:
    """Same metrics as ``count_accuracy`` plus the three SHD components.

    Also prints the extra / missing / reverse counts.

    Args:
        G_true: ground truth graph
        G: predicted graph
        G_und: predicted undirected edges in CPDAG, asymmetric
    Returns:
        fdr: (reverse + false positive) / prediction positive
        tpr: (true positive) / condition positive
        fpr: (reverse + false positive) / condition negative
        shd: undirected extra + undirected missing + reverse
        nnz: prediction positive
        extra: number of undirected extra edges
        missing: number of undirected missing edges
        reverse: number of reversed edges
    """
    stats = _edge_confusion_stats(G_true, G, G_und)
    print('extra %f + missing %f + reverse %f' % stats[5:])

    return stats

def compute_BiCScore(G, D):
    '''Sum the local BIC scores of every variable in graph ``G`` on data ``D``.

    The parents of variable ``i`` are the rows with a non-zero entry in
    column ``i`` of ``G`` (passed on as the tuple returned by ``np.where``).
    '''
    local_scores = [
        compute_local_BiCScore(D, node, np.where(G[:, node] != 0))
        for node in range(G.shape[0])
    ]
    return sum(local_scores)


def compute_local_BiCScore(np_data, target, parents):
    """Compute the BIC score of one discrete variable given its parents.

    Args:
        np_data: 2-D array of non-negative integer category codes, one row
            per sample.
        target: column index of the scored variable.
        parents: index (list, array or ``np.where`` tuple) selecting the
            parent columns; may select no columns.
    Returns:
        ``loglik - 0.5 * log(sample_size) * num_param`` where ``num_param``
        is (number of parent states) * (number of target states - 1) and a
        state count is taken as the maximum observed code + 1.
    """
    sample_size = np_data.shape[0]

    # counts[parent configuration][target value] -> number of occurrences
    counts = dict()
    for row in range(sample_size):
        parent_key = tuple(np_data[row, parents].reshape(1, -1)[0])
        self_key = tuple(np_data[row, target].reshape(1, -1)[0])
        bucket = counts.setdefault(parent_key, dict())
        bucket[self_key] = bucket.get(self_key, 0.0) + 1.0

    # log-likelihood of the empirical conditional distribution; the 0.1
    # added inside the log is part of the original scoring formula
    loglik = 0.0
    for child_counts in counts.values():
        local_count = sum(child_counts.values())
        for count in child_counts.values():
            loglik += count * (math.log(count + 0.1) - math.log(local_count))

    # complexity penalty
    num_parent_state = np.prod(np.amax(np_data[:, parents], axis=0) + 1)
    num_self_state = np.amax(np_data[:, target], axis=0) + 1
    num_param = num_parent_state * (num_self_state - 1)

    return loglik - 0.5 * math.log(sample_size) * num_param

def data_to_tensor_dataset(X, batch_size, G=None):
    """Wrap ``X`` in a DataLoader whose inputs and targets are the same tensor.

    Args:
        X: array-like data converted with ``torch.FloatTensor``.
        batch_size: batch size of the returned loader.
        G: optional object passed through unchanged.
    Returns:
        (loader, G)
    """
    features = torch.FloatTensor(X)
    loader = DataLoader(TensorDataset(features, features), batch_size=batch_size)
    return loader, G

def load_data(args, batch_size=1000, suffix='', debug = False):
    """Build the training DataLoader for the data source named by ``args.data_type``.

    Args:
        args: configuration object; the synthetic-graph attributes
            (``data_sample_size``, ``data_variable_size``, ``graph_*``,
            ``x_dims``) are read for every data type.
        batch_size: batch size of the returned loader.
        suffix: unused.
        debug: unused.
    Returns:
        The tuple layout depends on ``args.data_type``:
        'synthetic' -> (loader, G)
        'real'      -> (loader, number of columns, column labels)
        'benchmark' -> (loader, number of columns, G, number of categories)
    Raises:
        ValueError: if ``args.data_type`` is none of the above (previously
            the function silently returned None).
    """
    # configurations
    n, d = args.data_sample_size, args.data_variable_size
    graph_type, degree, sem_type, linear_type = args.graph_type, args.graph_degree, args.graph_sem_type, args.graph_linear_type
    x_dims = args.x_dims

    if args.data_type == 'synthetic':
        # generate data
        G = simulate_random_dag(d, degree, graph_type)
        X = simulate_sem(G, n, x_dims, sem_type, linear_type)

        train_data_loader, G = data_to_tensor_dataset(X, batch_size, G)
        return train_data_loader, G

    if args.data_type == 'real':
        # this is where you can use your own dataset
        assert args.path != '', 'Data path must be specified'
        fdpp = FullDataPreProcessor(args.path, args.column_names_list, args.initial_identifier, args.num_of_rows, args.seed)
        preprocessed_dataframe = fdpp.get_dataframe()
        # sample once so that the labels and the values come from the same draw
        sampled = fdpp.sample_dataframe(preprocessed_dataframe[0])
        columns = sampled.columns
        X = sampled.values

        train_data_loader, G = data_to_tensor_dataset(X, batch_size)
        return train_data_loader, X.shape[1], columns

    if args.data_type == 'benchmark':
        # create your own version of benchmark discrete data
        assert args.path != '', 'Data path must be specified'
        # e.g. for the pathfinder benchmark dataset: pathfinder_5000.txt
        file_path_dataset = os.path.join(args.path, 'pathfinder_5000.txt')

        # read file
        data = np.loadtxt(file_path_dataset, skiprows =0, dtype=np.int32)

        # find how many categories there are
        num_cats = num_categories(data.flatten())

        # read ground truth graph, e.g. pathfinder_graph.txt
        file_path = os.path.join(args.path, 'pathfinder_graph.txt')

        graph = np.loadtxt(file_path, skiprows =0, dtype=np.int32)

        G = nx.DiGraph(graph)
        X = data[:args.num_of_rows]

        train_data_loader, G = data_to_tensor_dataset(X, batch_size, G)

        return train_data_loader, X.shape[1], G, num_cats

    raise ValueError('unknown data type: {!r}'.format(args.data_type))
    
  

## FullDataPreProcessor.py

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Mon Nov  2 20:13:45 2020
@inproceedings{xu2019modeling,
  title={Modeling Tabular data using Conditional GAN},
  author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
@article{torfi2020cor,
title={COR-GAN: Correlation-Capturing Convolutional Neural Networks for Generating Synthetic Healthcare Records},
author={Torfi, Amirsina and Fox, Edward A},
journal={arXiv preprint arXiv:2001.09346},
year={2020}
}
"""

#Importing libraries and frameworks
import os
import numpy as np
import torch
import pandas as pd
# from ctgan.data import read_csv
from pandas import read_csv

class FullDataPreProcessor:
    """Load a CSV file, fill its missing values and select columns / rows."""

    def __init__(self, path, column_names, initial_identifier, num_of_rows, seed):
        # path: CSV file to read
        # column_names: columns to keep ([] keeps the frame untouched)
        # initial_identifier: stored but not used by this class
        # num_of_rows: rows to sample, or -1 to keep all rows
        # seed: random_state used when sampling rows
        self.path = path
        self.column_names = column_names
        self.initial_identifier = initial_identifier
        self.num_of_rows = num_of_rows
        self.seed = seed

    def get_dataframe(self):
        """Read the CSV and replace missing values with per-dtype sentinels.

        Returns:
            A tuple ``(dataframe, [])``. The frame sits at index 0 because
            callers in this file index the result with ``[0]``; the old
            code indexed the pandas frame itself with ``df[0]``, which
            fails for named columns.
        """
        df = read_csv(self.path)

        # Getting all of the columns with regards to their dtype
        non_numeric_columns = list(df.select_dtypes(exclude=['int64','float64']).columns)
        numeric_int_columns = list(df.select_dtypes(include=['int64']).columns)
        numeric_float_columns = list(df.select_dtypes(include=['float64']).columns)

        # One sentinel per dtype family: string, int, float
        fill_values = {}
        for column in non_numeric_columns:
            fill_values[column] = 'emptyblock'
        for column in numeric_int_columns:
            fill_values[column] = -123456789
        for column in numeric_float_columns:
            fill_values[column] = -1234.56789

        # A single frame-level fillna avoids chained in-place assignment on
        # a column selection, which pandas does not guarantee to write back.
        if fill_values:
            df = df.fillna(value=fill_values)

        return df, []

    def sample_dataframe(self, dataframe):
        """Keep ``self.column_names`` and optionally sample ``self.num_of_rows`` rows.

        With an empty ``column_names`` the input frame is returned
        unchanged. Sampled rows are returned sorted by index.
        """
        if self.column_names != []:
            dataframes = []
            for column in self.column_names:
                tmpdf = pd.DataFrame({column: dataframe[column]})
                dataframes.append(tmpdf)
            new_df = pd.concat(dataframes, axis=1)
            if self.num_of_rows != -1:
                assert self.num_of_rows > 0, 'Number of rows must be greater than zero'
                assert self.num_of_rows <= dataframe.shape[0], 'Number of rows must less or equal to the total number of rows'
                sampled_df = new_df.sample(self.num_of_rows, random_state=self.seed)
                sampled_df.sort_index(inplace=True)
                return sampled_df
            else:
                return new_df
        else:
            return dataframe
    
class Dataset:
    """Array-backed dataset that yields samples clipped to [0, 1] as tensors."""

    def __init__(self, data, transform=None):
        # transform is stored but never applied when items are fetched
        self.transform = transform
        self.data = data
        self.sampleSize, self.featureSize = data.shape[0], data.shape[1]

    def return_data(self):
        """Return the underlying data object."""
        return self.data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``data[idx]`` clipped to [0, 1] and converted to a tensor."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        clipped = np.clip(self.data[index], 0, 1)
        return torch.from_numpy(clipped)


## RDP_accountant.py

In [None]:
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RDP analysis of the Sampled Gaussian Mechanism.
Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
  compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
                                   T times.
  get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
                                   (or eps) given RDP at multiple orders and
                                   a target value for eps (or delta).
Example use:
Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:
  max_order = 32
  orders = range(2, max_order + 1)
  rdp = np.zeros_like(orders, dtype=float)
  for q, sigma, T in parameters:
   rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
  eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=delta)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import sys

import numpy as np
from scipy import special
import six

########################
# LOG-SPACE ARITHMETIC #
########################


def _log_add(logx, logy):
  """Return log(exp(logx) + exp(logy)) without leaving log space."""
  larger = max(logx, logy)
  smaller = min(logx, logy)
  if smaller == -np.inf:  # the smaller term is exactly 0
    return larger
  # exp(a) + exp(b) = (exp(a - b) + 1) * exp(b), with b the larger value
  return larger + math.log1p(math.exp(smaller - larger))


def _log_sub(logx, logy):
  """Return log(exp(logx) - exp(logy)); the difference must be non-negative.

  Raises:
    ValueError: if logx < logy.
  """
  if logx < logy:
    raise ValueError("The result of subtraction must be non-negative.")
  if logy == -np.inf:  # nothing is subtracted
    return logx
  if logx == logy:
    return -np.inf  # log-space representation of 0

  gap = logx - logy
  try:
    # exp(x) - exp(y) = (exp(x - y) - 1) * exp(y)
    shifted = math.log(math.expm1(gap))
  except OverflowError:
    # expm1 overflowed: fall back to returning logx unchanged
    return logx
  return shifted + logy


def _log_print(logx):
  """Format exp(logx), or the literal "exp(logx)" when it would not fit a float."""
  representable = logx < math.log(sys.float_info.max)
  if not representable:
    return "exp({})".format(logx)
  return "{}".format(math.exp(logx))


def _compute_log_a_int(q, sigma, alpha):
  """Return log(A_alpha) for an integer order alpha, assuming 0 < q < 1.

  A_alpha is accumulated as a binomial sum of alpha + 1 terms, entirely in
  log space.
  """
  assert isinstance(alpha, six.integer_types)

  def log_term(k):
    # log of C(alpha, k) * q^k * (1 - q)^(alpha - k) ...
    log_weight = (
        math.log(special.binom(alpha, k)) + k * math.log(q) +
        (alpha - k) * math.log(1 - q))
    # ... times exp((k^2 - k) / (2 sigma^2)).
    return log_weight + (k * k - k) / (2 * (sigma**2))

  # Start from zero, which is -inf in log space.
  running = -np.inf
  for k in range(alpha + 1):
    running = _log_add(running, log_term(k))

  return float(running)


def _compute_log_a_frac(q, sigma, alpha):
  """Compute log(A_alpha) for fractional alpha. 0 < q < 1.

  A_alpha is split at z0 into two integrals; each one is expanded as a series
  indexed by i = 0, 1, 2, ... whose terms are accumulated in log space. The
  series is summed until both of the most recent terms drop below exp(-30).

  Args:
    q: The sampling rate.
    sigma: The std of the additive Gaussian noise.
    alpha: The (fractional) order.
  Returns:
    log(A_alpha), the log-space sum of the two partial series.
  """
  # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
  # initialized to 0 in the log space:
  log_a0, log_a1 = -np.inf, -np.inf
  i = 0

  # Point at which the integration range is split.
  z0 = sigma**2 * math.log(1 / q - 1) + .5

  while True:  # do ... until loop
    # Generalized binomial coefficient; its sign is handled separately below
    # because only its magnitude can be carried in log space.
    coef = special.binom(alpha, i)
    log_coef = math.log(abs(coef))
    j = alpha - i

    # Log of |coef| * q^i * (1-q)^j, and the same with i and j swapped.
    log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
    log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)

    # Log of 0.5 * erfc(.) factors for the lower and upper integral.
    log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
    log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))

    # Full i-th term of each series (magnitude only).
    log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
    log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1

    # Apply the sign of the binomial coefficient: add positive terms,
    # subtract the others (_log_sub raises if a partial sum would go negative).
    if coef > 0:
      log_a0 = _log_add(log_a0, log_s0)
      log_a1 = _log_add(log_a1, log_s1)
    else:
      log_a0 = _log_sub(log_a0, log_s0)
      log_a1 = _log_sub(log_a1, log_s1)

    i += 1
    # Stop once the terms just added are both smaller than exp(-30).
    if max(log_s0, log_s1) < -30:
      break

  return _log_add(log_a0, log_a1)


def _compute_log_a(q, sigma, alpha):
  """Return log(A_alpha) for any positive finite order alpha.

  Dispatches to the integer routine when alpha has no fractional part and to
  the fractional routine otherwise.
  """
  if not float(alpha).is_integer():
    return _compute_log_a_frac(q, sigma, alpha)
  return _compute_log_a_int(q, sigma, int(alpha))


def _log_erfc(x):
  """Compute log(erfc(x)) with high accuracy for large x."""
  try:
    return math.log(2) + special.log_ndtr(-x * 2**.5)
  except NameError:
    # If log_ndtr is not available, approximate as follows:
    r = special.erfc(x)
    if r == 0.0:
      # Using the Laurent series at infinity for the tail of the erfc function:
      #     erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
      # To verify in Mathematica:
      #     Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
      return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
              .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
    else:
      return math.log(r)

def _compute_delta(orders, rdp, eps):
  """Compute delta given a list of RDP values and target epsilon.
  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    eps: The target epsilon.
  Returns:
    Pair of (delta, optimal_order).
  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
  idx_opt = np.argmin(deltas)
  return min(deltas[idx_opt], 1.), orders_vec[idx_opt]


def _compute_eps(orders, rdp, delta):
  """Compute epsilon given a list of RDP values and target delta.
  Args:
    orders: An array (or a scalar) of orders.
    rdp: A list (or a scalar) of RDP guarantees.
    delta: The target delta.
  Returns:
    Pair of (eps, optimal_order).
  Raises:
    ValueError: If input is malformed.
  """
  orders_vec = np.atleast_1d(orders)
  rdp_vec = np.atleast_1d(rdp)

  if len(orders_vec) != len(rdp_vec):
    raise ValueError("Input lists must have the same length.")

  eps = rdp_vec - math.log(delta) / (orders_vec - 1)

  idx_opt = np.nanargmin(eps)  # Ignore NaNs
  return eps[idx_opt], orders_vec[idx_opt]


def _compute_rdp(q, sigma, alpha):
  """Compute RDP of the Sampled Gaussian mechanism at order alpha.
  Args:
    q: The sampling rate.
    sigma: The std of the additive Gaussian noise.
    alpha: The order at which RDP is computed.
  Returns:
    RDP at alpha, can be np.inf.
  """
  if q == 0:
    return 0

  if q == 1.:
    return alpha / (2 * sigma**2)

  if np.isinf(alpha):
    return np.inf

  return _compute_log_a(q, sigma, alpha) / (alpha - 1)


def compute_rdp(q, noise_multiplier, steps, orders):
  """Return the RDP of `steps` applications of the Sampled Gaussian Mechanism.

  Args:
    q: The sampling rate.
    noise_multiplier: The ratio of the standard deviation of the Gaussian noise
        to the l2-sensitivity of the function to which it is added.
    steps: The number of steps.
    orders: An array (or a scalar) of RDP orders.
  Returns:
    The RDPs at all orders, can be np.inf.
  """
  if np.isscalar(orders):
    return _compute_rdp(q, noise_multiplier, orders) * steps

  # Per-step RDP at every order; the per-step cost is scaled by `steps`.
  per_step = np.array(
      [_compute_rdp(q, noise_multiplier, alpha) for alpha in orders])
  return per_step * steps


def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
  """Translate RDP values into an (eps, delta) pair.

  Exactly one of target_eps and target_delta must be supplied; the other one
  is computed from the RDP curve.

  Args:
    orders: An array (or a scalar) of RDP orders.
    rdp: An array of RDP values. Must be of the same length as the orders list.
    target_eps: If not None, the epsilon for which we compute the corresponding
              delta.
    target_delta: If not None, the delta for which we compute the corresponding
              epsilon.
  Returns:
    eps, delta, opt_order.
  Raises:
    ValueError: If both or neither of target_eps and target_delta are given.
  """
  have_eps = target_eps is not None
  have_delta = target_delta is not None

  if not have_eps and not have_delta:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (Both are).")

  if have_eps and have_delta:
    raise ValueError(
        "Exactly one out of eps and delta must be None. (None is).")

  if not have_eps:
    eps, opt_order = _compute_eps(orders, rdp, target_delta)
    return eps, target_delta, opt_order

  delta, opt_order = _compute_delta(orders, rdp, target_eps)
  return target_eps, delta, opt_order


def compute_rdp_from_ledger(ledger, orders):
  """Accumulate the RDP of every sample recorded in a privacy ledger.

  Args:
    ledger: A formatted privacy ledger.
    orders: An array (or a scalar) of RDP orders.
  Returns:
    RDP at all orders, can be np.inf.
  """
  accumulated = np.zeros_like(orders, dtype=float)
  for entry in ledger:
    # Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
    # See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
    inverse_squares = [
        (query.noise_stddev / query.l2_norm_bound)**-2
        for query in entry.queries]
    effective_z = sum(inverse_squares)**-0.5
    accumulated += compute_rdp(
        entry.selection_probability, effective_z, 1, orders)
  return accumulated


## Private_AAE_WGAN_GP.py

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Mon May 10 18:37:53 2021
@author: Hristo Petkov
"""

"""
@inproceedings{yu2019dag,
  title={DAG-GNN: DAG Structure Learning with Graph Neural Networks},
  author={Yue Yu, Jie Chen, Tian Gao, and Mo Yu},
  booktitle={Proceedings of the 36th International Conference on Machine Learning},
  year={2019}
}
@inproceedings{xu2019modeling,
  title={Modeling Tabular data using Conditional GAN},
  author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
"""

import time
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
import scipy.linalg as slin
import os

from torch.autograd import Variable
from torch import optim
from torch.optim import lr_scheduler
from torch.nn import Linear, Sequential, LeakyReLU, Dropout, BatchNorm1d
# from Utils import preprocess_adj_new, preprocess_adj_new1
# from Utils import nll_gaussian, kl_gaussian_sem,  nll_catogrical
# from Utils import _h_A
# from Utils import count_accuracy  
    
class Discriminator(nn.Module):
    """Critic network that scores groups ("packs") of ``pac`` samples at once.

    The input batch is reshaped to ``[-1, input_dim * pac]`` before being fed
    to a stack of Linear -> LeakyReLU -> Dropout blocks followed by a single
    scalar output, so the batch size must be a multiple of ``pac``.

    Args:
        input_dim: number of features of one sample.
        discriminator_dim: iterable with the width of every hidden layer.
        negative_slope: slope of the LeakyReLU activations.
        dropout_rate: dropout probability applied after every activation.
        pac: number of samples concatenated into one critic input.
    """
    def __init__(self, input_dim, discriminator_dim, negative_slope, dropout_rate, pac=10):
        super(Discriminator, self).__init__()
        dim = input_dim * pac
        self.pac = pac
        self.pacdim = dim

        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), LeakyReLU(negative_slope), Dropout(dropout_rate)]
            dim = item

        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise Linear weights (zero bias); reset BatchNorm1d to identity."""
        for m in self.modules():
            if isinstance(m, Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.0)
            elif isinstance(m, BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def calc_gradient_penalty(self, real_data, fake_data, data_type, device='cpu', pac=None, lambda_=10):
        """Compute the WGAN-GP gradient penalty on real/fake interpolates.

        Args:
            real_data: batch of real samples; singleton dimensions are squeezed.
            fake_data: batch of generated samples with the same shape.
            data_type: unused, kept for interface compatibility.
            device: device on which the interpolation coefficients are drawn.
            pac: pack size used to group samples; defaults to the pack size
                this discriminator was built with, so the penalty always
                groups samples exactly like ``forward`` does.
            lambda_: weight of the penalty term.

        Returns:
            Scalar tensor ``lambda_ * mean((||grad||_2 - 1) ** 2)`` where the
            norm is taken per pack.
        """
        if pac is None:
            pac = self.pac

        # reshape data
        real_data = real_data.squeeze()
        fake_data = fake_data.squeeze()

        # One interpolation coefficient per pack, shared by all samples and
        # features inside that pack.
        alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, real_data.size(1))
        alpha = alpha.view(-1, real_data.size(1))

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        if not interpolates.requires_grad:
            # autograd.grad needs a differentiable input; this only triggers
            # when neither real nor fake data carries a graph.
            interpolates.requires_grad_(True)

        disc_interpolates = self(interpolates)

        gradients = torch.autograd.grad(
            outputs=disc_interpolates, inputs=interpolates,
            # ones_like keeps dtype and device in sync with the critic output.
            grad_outputs=torch.ones_like(disc_interpolates),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]

        gradient_penalty = ((
            gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
        ) ** 2).mean() * lambda_

        return gradient_penalty

    def forward(self, input):
        """Score ``input`` pack-wise; the batch size must be divisible by ``pac``."""
        assert input.size()[0] % self.pac == 0
        return self.seq(input.view(-1, self.pacdim))
    
class Generator(nn.Module):
    """Generator module (based on DAG-NotearsMLP).

    The first layer is split into a positive and a negative part
    (``fc1_pos`` / ``fc1_neg``); its weights define the adjacency matrix
    returned by ``fc1_to_adj``. Every later layer is a ``LocallyConnected``
    layer whose input is the previous activation concatenated with ``m``
    Gaussian noise features.

    Args:
        m: number of noise features appended before each locally connected layer.
        dims: layer sizes ``[d, m1, ..., 1]``; ``dims[0]`` is the number of
            variables and the last entry must be 1.
        bias: whether the layers use a bias term.
    """
    def __init__(self, m, dims, bias=True):
        super(Generator, self).__init__()
        assert len(dims) >= 2
        assert dims[-1] == 1
        d = dims[0]
        self.dims = dims
        # fc1: variable splitting for l1
        self.fc1_pos = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_neg = nn.Linear(d, d * dims[1], bias=bias)
        self.fc1_pos.weight.bounds = self._bounds()
        self.fc1_neg.weight.bounds = self._bounds()
        # fc2: local linear layers
        layers = []
        for l in range(len(dims) - 2):
            layers.append(LocallyConnected(d, dims[l + 1]+m, dims[l + 2], bias=bias))
        self.fc2 = nn.ModuleList(layers)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise Linear weights (zero bias); reset BatchNorm1d to identity."""
        for module in self.modules():
            if isinstance(module, Linear):
                nn.init.xavier_normal_(module.weight.data)
                # bias is None when the generator is built with bias=False.
                if module.bias is not None:
                    module.bias.data.fill_(0.0)
            elif isinstance(module, BatchNorm1d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _bounds(self):
        """Return one (low, high) bound per fc1 weight, in [j, m1, i] order.

        Entries with i == j are pinned to (0, 0) (no self-loops); every other
        entry is bounded below by 0 and unbounded above.
        """
        d = self.dims[0]
        bounds = []
        for j in range(d):
            for m in range(self.dims[1]):
                for i in range(d):
                    if i == j:
                        bound = (0, 0)
                    else:
                        bound = (0, None)
                    bounds.append(bound)
        return bounds

    def forward(self, x, n, d, m):  # [n, d] -> [n, d]
        """Generate one sample per input row.

        Args:
            x: input tensor of shape [n, d].
            n: unused, kept for interface compatibility.
            d: number of variables; second dimension of the injected noise.
            m: number of noise features concatenated before each local layer.

        Returns:
            Tensor of shape [n, d].
        """
        x = self.fc1_pos(x) - self.fc1_neg(x)  # [n, d * m1]
        x = x.view(-1, self.dims[0], self.dims[1])  # [n, d, m1]
        for fc in self.fc2:
            # x = F.elu(x)  # [n, d, m1]
            x = torch.sigmoid(x)  # [n, d, m1]
            # Standard-normal noise drawn with NumPy (single precision, as
            # before) and placed on the same device/dtype as the activations
            # so the model also runs on CPU-only hosts.
            noise = np.random.normal(0, 1, (x.size(0), d, m))
            z = torch.as_tensor(noise, dtype=torch.float32).to(device=x.device, dtype=x.dtype)
            x = torch.cat((x, z), dim=2)
            x = fc(x)  # [n, d, m2]
        x = x.squeeze(dim=2)  # [n, d]
        return x

    def h_func(self):
        """Constrain 2-norm-squared of fc1 weights along m1 dim to be a DAG"""
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        h = trace_expm(A) - d  # (Zheng et al. 2018)
        # A different formulation, slightly faster at the cost of numerical stability
        # M = torch.eye(d) + A / d  # (Yu et al. 2019)
        # E = torch.matrix_power(M, d - 1)
        # h = (E.t() * M).sum() - d
        return h

    def l2_reg(self):
        """Take 2-norm-squared of all parameters"""
        reg = 0.
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        reg += torch.sum(fc1_weight ** 2)
        for fc in self.fc2:
            reg += torch.sum(fc.weight ** 2)
        return reg

    def fc1_l1_reg(self):
        """Return the sum of the fc1_pos and fc1_neg weights.

        NOTE(review): this equals the l1 norm only while both weight matrices
        stay non-negative, as the bounds from ``_bounds`` intend; confirm the
        optimizer in use actually enforces those bounds.
        """
        reg = torch.sum(self.fc1_pos.weight + self.fc1_neg.weight)
        return reg

    @torch.no_grad()
    def fc1_to_adj(self) -> np.ndarray:  # [j * m1, i] -> [i, j]
        """Get W from fc1 weights, take 2-norm over m1 dim"""
        d = self.dims[0]
        fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight  # [j * m1, i]
        fc1_weight = fc1_weight.view(d, -1, d)  # [j, m1, i]
        A = torch.sum(fc1_weight * fc1_weight, dim=1).t()  # [i, j]
        W = torch.sqrt(A)  # [i, j]
        W = W.cpu().detach().numpy()  # [i, j]
        return W
    
class LocallyConnected(nn.Module):
    """Local linear layer, i.e. Conv1dLocal() with filter size 1.

    Applies an independent linear map to each of the ``num_linear`` slices
    along the second input dimension.

    Args:
        num_linear: num of local linear layers, i.e. d
        input_features: m1
        output_features: m2
        bias: whether to include bias or not
    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]
    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2]
    """

    def __init__(self, num_linear, input_features, output_features, bias=True):
        super(LocallyConnected, self).__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features

        self.weight = nn.Parameter(torch.Tensor(num_linear,
                                                input_features,
                                                output_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_linear, output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Draw weight and bias uniformly from [-1/sqrt(m1), 1/sqrt(m1)]."""
        k = 1.0 / self.input_features
        bound = math.sqrt(k)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor):
        """Map [n, d, m1] to [n, d, m2] with one weight matrix per slice d."""
        # [n, d, 1, m2] = [n, d, 1, m1] @ [1, d, m1, m2]
        out = torch.matmul(input.unsqueeze(dim=2), self.weight.unsqueeze(dim=0))
        out = out.squeeze(dim=2)
        if self.bias is not None:
            # [n, d, m2] += [d, m2]
            out += self.bias
        return out

    def extra_repr(self):
        """Describe the layer configuration shown by ``repr()`` / ``print()``."""
        # The sizes are stored as input_features / output_features; the
        # printed labels keep the conventional in_features / out_features.
        return 'num_linear={}, in_features={}, out_features={}, bias={}'.format(
            self.num_linear, self.input_features, self.output_features,
            self.bias is not None
        )

class TraceExpm(torch.autograd.Function):
    """Differentiable ``trace(expm(A))`` evaluated with SciPy on the CPU.

    The forward pass computes the matrix exponential with
    ``scipy.linalg.expm``; the backward pass uses
    d trace(expm(A)) / dA = expm(A)^T.
    """

    @staticmethod
    def forward(ctx, input):
        # detach so we can cast to NumPy
        E = slin.expm(input.detach().cpu().numpy())
        f = np.trace(E)
        E = torch.from_numpy(E)
        ctx.save_for_backward(E)
        # Remember where the input lives so the gradient can be sent back to
        # the same device (CPU or GPU) instead of always to CUDA.
        ctx.input_device = input.device
        return torch.as_tensor(f, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        E, = ctx.saved_tensors
        grad_input = grad_output * E.t()
        return grad_input.to(ctx.input_device)

trace_expm = TraceExpm.apply
    
class AAE_WGAN_GP(nn.Module):
    """DAG-AAE model/framework."""
    def __init__(self, args, adj_A):
        super(AAE_WGAN_GP, self).__init__()
        
        self.data_type = args.data_type
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch_size = args.batch_size
        
        self.discriminator_steps = args.discriminator_steps
        self.epochs = args.epochs
        self.lr = args.lr
        
        self.c_A = args.c_A
        self.lambda_A = args.lambda_A
        self.tau_A = args.tau_A
        self.graph_threshold = args.graph_threshold
        
        self.x_dims = args.x_dims
        self.z_dims = args.z_dims
        self.encoder_hidden = args.encoder_hidden
        self.decoder_hidden = args.decoder_hidden
        self.adj_A = adj_A
        
        self.k_max_iter = int(args.k_max_iter)
        self.h_tol = args.h_tol
        
        self.h_A_new = torch.tensor(1.)
        self.h_A_old = np.inf
        
        self.discrete_columns = args.discrete_column_names_list
        self.data_variable_size = self.adj_A.shape[1]
        
        self.mul1 = args.mul1
        self.mul2 = args.mul2
        
        self.lr_decay = args.lr_decay
        self.gamma = args.gamma
        self.negative_slope = args.negative_slope
        self.dropout_rate = args.dropout_rate
        
        self.differentialPrivacy = args.differentialPrivacy
        self.EPSILON = args.EPSILON
        self.DELTA = args.DELTA
        self.MAX_GRAD_NORM = args.MAX_GRAD_NORM
        self.SIGMA = args.SIGMA
        # usually, we should calculate SIGMA. But, to get things moving, lets assume values from literature: DP-GAN uses SIGMA=2
        self.MICRO_BATCH_SIZE = args.MICRO_BATCH_SIZE
        # any batch size that we use needs to be compatible with Hristo's model: he needs batch size to be divisible by number of variables.
        # micro batch size should thus be set accordingly. For now, assume = 10 (since his batch size is 100)
        self.priv_steps = 0
        self.eps = 0

    def forward(self, inputs):
        fake_data = self.generator(inputs.squeeze(), self.batch_size, self.data_variable_size, self.z_dims)
        return fake_data
                
    def update_optimizer(self, optimizer, original_lr, c_A):
        '''related LR to c_A, whenever c_A gets big, reduce LR proportionally'''
        MAX_LR = 1e-2
        MIN_LR = 1e-4

        estimated_lr = original_lr / (math.log10(c_A) + 1e-10)
        if estimated_lr > MAX_LR:
            lr = MAX_LR
        elif estimated_lr < MIN_LR:
            lr = MIN_LR
        else:
            lr = estimated_lr

        # set LR
        for parame_group in optimizer.param_groups:
            parame_group['lr'] = lr

        return optimizer, lr
    
    def train(self, train_loader, epoch, best_val_loss, ground_truth_G, lambda_A, c_A, optimizerG, optimizerD):
        """Train discriminator and generator for a single epoch.

        For every batch the discriminator is updated ``discriminator_steps``
        times, then the generator once. With ``differentialPrivacy`` enabled,
        each update is built from per-micro-batch gradients that are clipped
        to ``MAX_GRAD_NORM``, summed, perturbed with Gaussian noise and
        averaged; the privacy budget spent so far is recomputed at the end of
        the epoch.

        Args:
            train_loader: iterable of ``(data, relations)`` batches; must
                expose ``.dataset`` when differential privacy is enabled.
            epoch: epoch index, only used for logging.
            best_val_loss: unused, kept for interface compatibility.
            ground_truth_G: reference graph for the SHD metric, or None.
            lambda_A: Lagrange multiplier of the acyclicity constraint.
            c_A: penalty coefficient of the acyclicity constraint.
            optimizerG: optimizer of the generator.
            optimizerD: optimizer of the discriminator.

        Returns:
            Tuple ``(eps, ELBO_loss, NLL_loss, MSE_loss, graph)``. The three
            loss entries are means over lists that are currently never filled
            (their bookkeeping is disabled), so they evaluate to NaN.
        """
        t = time.time()
        nll_train = []
        kl_train = []
        mse_train = []
        shd_trian = []

        # Re-derive the learning rate of both optimizers from c_A.
        optimizerG, _ = self.update_optimizer(optimizerG, self.lr, c_A)
        optimizerD, _ = self.update_optimizer(optimizerD, self.lr, c_A)

        for batch_idx, (data, relations) in enumerate(train_loader):
            # Prepare the batch once. Doing this inside the discriminator loop
            # would add another trailing dimension on every extra step.
            data = data.to(self.device).double()
            relations = relations.to(self.device).double()
            if self.data_type != 'synthetic':
                data = data.unsqueeze(2)

            # Number of complete micro-batches in this batch (DP mode only).
            num_micro_batches = data.size(0) // self.MICRO_BATCH_SIZE

            for n in range(self.discriminator_steps):
                ###################################################################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###################################################################
                optimizerD.zero_grad()

                if self.differentialPrivacy:
                    clipped_grads = {
                        name: torch.zeros_like(param)
                        for name, param in self.discriminator.named_parameters()}

                    for k in range(num_micro_batches):
                        data_micro = data[k * self.MICRO_BATCH_SIZE: (k + 1) * self.MICRO_BATCH_SIZE]
                        pen, loss_d = self._discriminator_losses(data_micro)

                        # accumulate gradients of this micro-batch
                        pen.backward(retain_graph=True)
                        loss_d.backward()
                        # now clip them
                        torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.MAX_GRAD_NORM)
                        for name, param in self.discriminator.named_parameters():
                            clipped_grads[name] += param.grad
                        # grads saved in clipped_grads, so remove from params
                        self.discriminator.zero_grad()

                    # attach the summed clipped gradients with noise added
                    self._dp_set_noisy_grads(self.discriminator, clipped_grads, data.size(0))

                    # update parameters with privatised (bounded) gradients
                    optimizerD.step()

                    # count every update that is based on privatised grads
                    self.priv_steps += 1
                else:
                    pen, loss_d = self._discriminator_losses(data)
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

            ###############################################
            # (2) Update G network: maximize log(D(G(z)))
            ###############################################
            optimizerG.zero_grad()

            if self.differentialPrivacy:
                clipped_gen_grads = {
                    name: torch.zeros_like(param)
                    for name, param in self.generator.named_parameters()}

                for k in range(num_micro_batches):
                    data_micro = data[k * self.MICRO_BATCH_SIZE: (k + 1) * self.MICRO_BATCH_SIZE]
                    loss_g = self._generator_loss(data_micro, lambda_A, c_A)
                    loss_g.backward(retain_graph=True)

                    # clip them
                    torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.MAX_GRAD_NORM)
                    for name, param in self.generator.named_parameters():
                        clipped_gen_grads[name] += param.grad
                    # grads saved in clipped_gen_grads, so remove from params
                    self.generator.zero_grad()

                # attach the summed clipped gradients with noise added
                self._dp_set_noisy_grads(self.generator, clipped_gen_grads, data.size(0))

                # update parameters with privatised (bounded) gradients
                optimizerG.step()

                # count every update that is based on privatised grads
                self.priv_steps += 1
            else:
                loss_g = self._generator_loss(data, lambda_A, c_A)
                loss_g.backward()
                optimizerG.step()

            # compute metrics
            graph = self.generator.fc1_to_adj()
            graph[np.abs(graph) < self.graph_threshold] = 0

            if ground_truth_G is not None:
                fdr, tpr, fpr, shd, nnz = count_accuracy(ground_truth_G, nx.DiGraph(graph))
                shd_trian.append(shd)

            # mse_train.append(F.mse_loss(fake_data, data_micro.squeeze()).item())
            #nll_train.append(loss_g.item())
            #kl_train.append(loss_d.item())

        # update learning rates
        self.schedulerG.step()
        self.schedulerD.step()

        # track metrics for differential privacy
        print('\ttime to update models: ',time.time() - t)
        t_priv = time.time()
        if self.differentialPrivacy:
            # Calculate the current privacy cost using the accountant:
            max_lmbd = 1023
            lmbds = range(2, max_lmbd + 1)
            print('\tcumulative # param updates: ', self.priv_steps)
            # Moments accountant: TensorFlow implementation (see RDP accountant block):
            rdp = compute_rdp(self.batch_size / len(train_loader.dataset), self.SIGMA, self.priv_steps, lmbds)
            self.eps, _, _ = get_privacy_spent(lmbds, rdp, target_delta=self.DELTA)
            print('\ttime to compute privacy budget: ',time.time() - t_priv)

        nll_mean = np.mean(nll_train)
        kl_mean = np.mean(kl_train)
        mse_mean = np.mean(mse_train)
        elbo_mean = kl_mean + nll_mean

        report = ['Epoch: {:04d}'.format(epoch),
                  'nll_train: {:.10f}'.format(nll_mean),
                  'kl_train: {:.10f}'.format(kl_mean),
                  'ELBO_loss: {:.10f}'.format(elbo_mean),
                  'mse_train: {:.10f}'.format(mse_mean)]
        if ground_truth_G is not None:
            # SHD is only available when a reference graph was supplied.
            report.append('shd_trian: {:.10f}'.format(np.mean(shd_trian)))
        report.append('time: {:.4f}s'.format(time.time() - t))
        # 'epsilon: {:.4f}/{}'.format(self.eps, self.EPSILON)
        print(*report)

        return self.eps, elbo_mean, nll_mean, mse_mean, graph#, origin_A

    def _discriminator_losses(self, real):
        """Return (gradient penalty, critic loss) for one batch of real data."""
        fake = self(real)
        y_fake = self.discriminator(fake)
        y_real = self.discriminator(real)

        if self.x_dims > 1:
            # vector case: flatten the last two dimensions first
            pen = self.discriminator.calc_gradient_penalty(
                real.view(-1, real.size(1) * real.size(2)),
                fake.view(-1, fake.size(1) * fake.size(2)),
                self.data_type, self.device)
        else:
            # normal continuous and discrete data case
            pen = self.discriminator.calc_gradient_penalty(
                real, fake, self.data_type, self.device)

        loss_d = -(torch.mean(F.softplus(y_real)) - torch.mean(F.softplus(y_fake)))
        return pen, loss_d

    def _generator_loss(self, real, lambda_A, c_A):
        """Return the generator loss: adversarial + acyclicity + l1/l2 terms."""
        fake = self(real)
        y_fake = self.discriminator(fake)

        loss_g = -torch.mean(F.softplus(y_fake))
        h_A = self.generator.h_func()

        l2_reg = 0.5 * self.mul2 * self.generator.l2_reg()
        l1_reg = self.mul1 * self.generator.fc1_l1_reg()

        # augmented-Lagrangian terms of the acyclicity constraint h(A) = 0
        loss_g += lambda_A * h_A + 0.5 * c_A * h_A * h_A
        loss_g += l2_reg + l1_reg
        return loss_g

    def _dp_set_noisy_grads(self, module, summed_grads, batch_len):
        """Write (summed clipped grads + Gaussian noise) / #micro-batches into ``module``."""
        divisor = batch_len / self.MICRO_BATCH_SIZE
        for name, param in module.named_parameters():
            # Noise is drawn on the CPU and moved to the model's device, so the
            # same code path works with and without a GPU.
            noise = torch.FloatTensor(
                summed_grads[name].size()).normal_(0, self.SIGMA * self.MAX_GRAD_NORM).to(self.device)
            param.grad = (summed_grads[name] + noise) / divisor
    
    def fit(self, train_loader, ground_truth_G = None):
        """Train the generator/discriminator pair and return the learned graph.

        The networks, optimizers and LR schedulers are created only when the
        corresponding attribute does not exist yet, so objects assigned
        beforehand are reused.

        Training schedule:
          * ``args.differentialPrivacy`` truthy: ``self.train`` is called once
            per epoch until the privacy budget it reports (``self.eps``)
            reaches ``self.EPSILON``.
          * otherwise: up to ``self.k_max_iter`` outer rounds.  Each round
            trains ``self.epochs`` epochs, multiplying ``self.c_A`` by 10 and
            repeating while h(A) has not fallen to a quarter of its previous
            value; afterwards ``self.lambda_A`` is updated and the rounds stop
            once h(A) <= ``self.h_tol``.

        Pressing Ctrl-C stops training early; the graph learned so far is
        still reported and returned.

        Args:
            train_loader: forwarded unchanged to ``self.train``.
            ground_truth_G: optional reference graph.  When given, accuracy
                metrics from ``count_accuracy`` are printed for the best
                ELBO/NLL/MSE graphs and for the final graph at thresholds
                0.1, 0.2 and 0.3.

        Returns:
            The matrix from ``self.generator.fc1_to_adj()`` with small entries
            zeroed: thresholded at 0.3 when ``ground_truth_G`` is given
            (thresholds are applied cumulatively), otherwise at
            ``self.graph_threshold``.
        """
        if not hasattr(self, "discriminator"):
            self.discriminator = Discriminator(self.data_variable_size, (256, 256), self.negative_slope, self.dropout_rate).double().to(self.device)

        if not hasattr(self, "generator"):
            self.generator = Generator(self.z_dims, dims=[self.data_variable_size, 10, 1], bias=True).double().to(self.device)

        if not hasattr(self, "optimizerD"):
            self.optimizerD = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(0.5, 0.9), weight_decay=1e-6)

        if not hasattr(self, "optimizerG"):
            self.optimizerG = optim.Adam(self.generator.parameters(), lr=self.lr)

        if not hasattr(self, "schedulerG"):
            self.schedulerG = lr_scheduler.StepLR(self.optimizerG, step_size=self.lr_decay, gamma=self.gamma)

        if not hasattr(self, "schedulerD"):
            self.schedulerD = lr_scheduler.StepLR(self.optimizerD, step_size=self.lr_decay, gamma=self.gamma)

        best_ELBO_loss = np.inf
        best_NLL_loss = np.inf
        best_MSE_loss = np.inf
        best_epoch = 0
        best_ELBO_graph = []
        best_NLL_graph = []
        best_MSE_graph = []

        try:
            if args.differentialPrivacy:
                epoch = 0
                # Keep training until the privacy budget reported by train() is spent.
                while self.eps < self.EPSILON:
                    epoch += 1
                    self.eps, ELBO_loss, NLL_loss, MSE_loss, graph = self.train(
                        train_loader, epoch, best_ELBO_loss, ground_truth_G,
                        self.lambda_A, self.c_A, self.optimizerG, self.optimizerD)

                    print('\t*******')
                    print('\tprivacy budget expended so far: {:.4f}/{}'.format(self.eps, self.EPSILON))
                    print('\t*******')

                    if ELBO_loss < best_ELBO_loss:
                        best_ELBO_loss = ELBO_loss
                        best_epoch = epoch
                        best_ELBO_graph = graph

                    if NLL_loss < best_NLL_loss:
                        best_NLL_loss = NLL_loss
                        best_epoch = epoch
                        best_NLL_graph = graph

                    if MSE_loss < best_MSE_loss:
                        best_MSE_loss = MSE_loss
                        best_epoch = epoch
                        best_MSE_graph = graph

                print('\t*******')
                print('\tprivacy budget reached')
                print('\t*******')

            else:
                # k_max_iter may arrive as a float (argparse does not apply
                # type= to non-string defaults) and range() rejects floats.
                for step_k in range(int(self.k_max_iter)):
                    while self.c_A < 1e+20:
                        for epoch in range(self.epochs):
                            self.eps, ELBO_loss, NLL_loss, MSE_loss, graph = self.train(
                                train_loader, epoch, best_ELBO_loss, ground_truth_G,
                                self.lambda_A, self.c_A, self.optimizerG, self.optimizerD)

                            if ELBO_loss < best_ELBO_loss:
                                best_ELBO_loss = ELBO_loss
                                best_epoch = epoch
                                best_ELBO_graph = graph

                            if NLL_loss < best_NLL_loss:
                                best_NLL_loss = NLL_loss
                                best_epoch = epoch
                                best_NLL_graph = graph

                            if MSE_loss < best_MSE_loss:
                                best_MSE_loss = MSE_loss
                                best_epoch = epoch
                                best_MSE_graph = graph

                        print("Optimization Finished!")
                        print("Best Epoch: {:04d}".format(best_epoch))

                        # Stop raising c_A once the loss has clearly diverged.
                        if ELBO_loss > 2 * best_ELBO_loss:
                            break

                        # update parameters
                        with torch.no_grad():
                            self.h_A_new = self.generator.h_func().item()
                        if self.h_A_new > 0.25 * self.h_A_old:
                            self.c_A*=10
                        else:
                            break

                    # update parameters
                    # h_A, adj_A are computed in loss anyway, so no need to store
                    self.h_A_old = self.h_A_new
                    self.lambda_A += self.c_A * self.h_A_new

                    if self.h_A_new <= self.h_tol:
                        break

        except KeyboardInterrupt:
            # Deliberate: an interrupted run still reports and returns the
            # graph learned so far via the shared code below.
            pass

        if ground_truth_G is None:
            # Nothing to score against: just binarise with the configured threshold.
            graph = self.generator.fc1_to_adj()
            graph[np.abs(graph) < self.graph_threshold] = 0
            return graph

        def report(prefix, candidate):
            # Score one candidate adjacency matrix against the ground truth and print it.
            fdr, tpr, fpr, shd, nnz = count_accuracy(ground_truth_G, nx.DiGraph(candidate))
            print(prefix, fdr, ' tpr ', tpr, ' fpr ', fpr, 'shd', shd, 'nnz', nnz)

        report('Best ELBO Graph Accuracy: fdr', best_ELBO_graph)
        report('Best NLL Graph Accuracy: fdr', best_NLL_graph)
        report('Best MSE Graph Accuracy: fdr', best_MSE_graph)

        graph = self.generator.fc1_to_adj()
        # Thresholds are applied in place and cumulatively, so the returned
        # graph ends up thresholded at 0.3.
        for threshold in (0.1, 0.2, 0.3):
            graph[np.abs(graph) < threshold] = 0
            report('threshold {}, Accuracy: fdr'.format(threshold), graph)

        return graph
    
    def save_model(self):
        """Write the generator and discriminator state dicts to ``self.save_directory``.

        The files are named ``generator.pth`` and ``discriminator.pth``.
        An empty ``self.save_directory`` trips the assertion below.
        """
        assert self.save_directory != '', 'Saving directory not specified! Please specify a saving directory!'
        # Same order as before: generator first, then discriminator.
        for network_name in ('generator', 'discriminator'):
            network = getattr(self, network_name)
            target = os.path.join(self.save_directory, network_name + '.pth')
            torch.save(network.state_dict(), target)
        
    def load_model(self):
        """Rebuild the generator and discriminator and load their saved weights.

        Fresh ``Generator`` and ``Discriminator`` instances are constructed
        with the same arguments ``fit`` uses, then populated from
        ``generator.pth`` and ``discriminator.pth`` in ``self.load_directory``.
        The loaded networks are returned, not assigned to ``self``.

        Returns:
            tuple: ``(generator, discriminator)`` on ``self.device``.
        """
        assert self.load_directory != '', 'Loading directory not specified! Please specify a loading directory!'

        generator = Generator(self.z_dims, dims=[self.data_variable_size, 10, 1], bias=True).double().to(self.device)
        discriminator = Discriminator(self.data_variable_size, (256, 256), self.negative_slope, self.dropout_rate).double().to(self.device)

        # map_location remaps the stored tensors onto the current device, so a
        # checkpoint written on a GPU machine can be loaded on a CPU-only one.
        generator_path = os.path.join(self.load_directory,'generator.pth')
        discriminator_path = os.path.join(self.load_directory,'discriminator.pth')
        generator.load_state_dict(torch.load(generator_path, map_location=self.device))
        discriminator.load_state_dict(torch.load(discriminator_path, map_location=self.device))

        return generator, discriminator



## Private_Main.py

In [None]:

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 24 22:49:31 2020
@author: Hristo Petkov
"""

"""
@inproceedings{yu2019dag,
  title={DAG-GNN: DAG Structure Learning with Graph Neural Networks},
  author={Yue Yu, Jie Chen, Tian Gao, and Mo Yu},
  booktitle={Proceedings of the 36th International Conference on Machine Learning},
  year={2019}
}
@inproceedings{xu2019modeling,
  title={Modeling Tabular data using Conditional GAN},
  author={Xu, Lei and Skoularidou, Maria and Cuesta-Infante, Alfredo and Veeramachaneni, Kalyan},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
"""

#Importing libraries and frameworks
import time
import os
import torch
import pickle
import numpy as np
import networkx as nx
import pandas as pd 
# from FullDataPreProcessor import FullDataPreProcessor
# from AAE_WGAN_GP import AAE_WGAN_GP
# from Utils import load_data
# from Utils import draw_dag
# from Utils import compute_BiCScore
# from Utils import pns_
from argparse import ArgumentParser

# Command-line configuration for the experiment. The parsed namespace is kept
# in the module-level name `args`.

def _str_to_bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    True: every non-empty string is truthy.  This accepts the usual spellings
    instead and raises ValueError (which argparse turns into a usage error)
    for anything else.  The empty string maps to False, as it did with
    ``type=bool``.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))

parser = ArgumentParser()
parser.add_argument('--path', type=str, default='',
                    help='choosing a path for the input.')
parser.add_argument('--column_names_list', type=str, nargs='+', default=[],
                    help = 'choosing the column names for samping of original dataframe')
parser.add_argument('--discrete_column_names_list', type=str, nargs='+', default=[],
                    help = 'choosing the discrete column names in the dataframe')
parser.add_argument('--discriminator_steps', type=int, default=1,
                    help='Number of steps for the discriminator')
parser.add_argument('--initial_identifier', type=str, default='',
                    help='Initial Identifier for the sample dataframe')
parser.add_argument('--num_of_rows', type=int, default=-1,
                    help='Number of rows in the sampled dataframe')
parser.add_argument('--save_model', default='', type=str,
                        help='A directory to save a trained model to.')
parser.add_argument('--load_model', default='', type=str,
                    help='A directory to load a trained model from.')
parser.add_argument('--export_directory', type=str, default='',
                    help='choosing a directory for the output.')
parser.add_argument('--verbose', type=int, default=1,
                    help='used to control the print statements per epoch.')

# ----------- data parameters ------
# configurations
parser.add_argument('--synthesize', type=int, default=0,
                    help='Flag for synthesiing synthetic data')
parser.add_argument('--pns', type=int, default=1,
                    help='Flag for primary neighbour selection')
parser.add_argument('--data_type', type=str, default='synthetic',
                    choices=['synthetic', 'benchmark', 'real'],
                    help='choosing which experiment to do.')
parser.add_argument('--data_sample_size', type=int, default=5000,
                    help='the number of samples of data')
parser.add_argument('--data_variable_size', type=int, default=10,
                    help='the number of variables in synthetic generated data')
parser.add_argument('--graph_type', type=str, default='erdos-renyi',
                    help='the type of DAG graph by generation method')
parser.add_argument('--graph_degree', type=int, default=3,
                    help='the number of degree in generated DAG graph')
parser.add_argument('--graph_sem_type', type=str, default='linear-gauss',
                    help='the structure equation model (SEM) parameter type')
parser.add_argument('--graph_linear_type', type=str, default='nonlinear_2',
                    help='the synthetic data type: linear -> linear SEM, nonlinear_1 -> x=Acos(x+1)+z, nonlinear_2 -> x=2sin(A(x+0.5))+A(x+0.5)+z')
parser.add_argument('--edge-types', type=int, default=2,
                    help='The number of edge types to infer.')
parser.add_argument('--x_dims', type=int, default=1, # vector case: needs to equal the last dimension of the vector data to work
                    help='The number of input dimensions: default 1.')
parser.add_argument('--z_dims', type=int, default=1,
                    help='The number of latent variable dimensions: default the same as variable size.')

# ----------- training hyperparameters
parser.add_argument('--graph_threshold', type=  float, default = 0.3,  # 0.3 is good, 0.2 is error-prone
                    help = 'threshold for learned adjacency matrix binarization')
parser.add_argument('--tau_A', type = float, default=0.0,
                    help='coefficient for L-1 norm of A.')
parser.add_argument('--lambda_A',  type = float, default= 0.,
                    help='coefficient for DAG constraint h(A).')
parser.add_argument('--c_A',  type = float, default= 1,
                    help='coefficient for absolute value h(A).')
parser.add_argument('--negative_slope', type=float, default=0.2,
                    help='negative_slope for leaky_relu')
parser.add_argument('--dropout_rate', type=float, default=0.0,
                    help='rate for discriminator dropout')
parser.add_argument('--noise', type=float, default=0.5,
                    help='amount of noise for the ANM')


parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default= 300,
                    help='Number of epochs for step 1 to train.')
parser.add_argument('--epochs2', type=int, default= 600,
                    help='Number of epochs for step 2 to train.')
parser.add_argument('--batch_size', type=int, default = 100, # note: should be divisible by sample size, otherwise throw an error
                    help='Number of samples per batch.')
parser.add_argument('--lr', type=float, default=3e-3,  # baseline rate = 1e-3
                    help='Initial learning rate.')
parser.add_argument('--encoder-hidden', type=int, default=64,
                    help='Number of hidden units.')
parser.add_argument('--decoder-hidden', type=int, default=64,
                    help='Number of hidden units.')
# The default must be a real int: argparse applies type= only to string
# defaults, so 1e2 would stay a float and break range(k_max_iter).
parser.add_argument('--k_max_iter', type = int, default = 100,
                    help ='the max iteration number for searching lambda and c')
parser.add_argument('--mul1', default=0.01, type=float,
                    help='multiplier for the L1_Loss')
parser.add_argument('--mul2', default=0.01, type=float,
                    help='multiplier for the L2_Loss')

parser.add_argument('--suffix', type=str, default='_springs5',
                    help='Suffix for training data (e.g. "_charged".')
parser.add_argument('--h_tol', type=float, default = 1e-8,
                    help='the tolerance of error of h(A) to zero')
parser.add_argument('--lr-decay', type=int, default=200,
                    help='After how epochs to decay LR by a factor of gamma.')
parser.add_argument('--gamma', type=float, default= 1.0,
                    help='LR decay factor.')
parser.add_argument('--temp', type=float, default=1.0,
                    help='Temperature for Gumbel softmax.')
parser.add_argument('--hard', action='store_true', default=False,
                    help='Uses discrete samples in training forward pass.')


#******************************************************************
# hyper-parameters for differential privacy.
parser.add_argument('--SIGMA', type=float, default = 0.5,
                    help='Amount of noise to add to CLIPPED gradients during DP-SGD. Larger means more private, so it will take longer to train until reaching EPSILON.')
parser.add_argument('--DELTA', type=float, default = 1e-5,
                    help='Additional term for widening privacy bound set by EPSILON. Smaller means tighter bound, so more private. Typically set to 1/N, where N is the dataset size.')
parser.add_argument('--EPSILON', type=float, default=50.0,
                    help='Privacy budget, which sets an upper bound on how much model performance will change with and without new data. Lower means more private, but poorer performance.')
parser.add_argument('--MAX_GRAD_NORM', type=float, default= 0.1,
                    help='How much we clip the gradients by during each optim.step(). Lower means greater privacy, since we allow parameters to learn less info from the data.')
parser.add_argument('--MICRO_BATCH_SIZE', type=int, default=10,
                    help='Number of samples to average gradients over during optim.step(). Keep as small as possible, but also divisible by BATCH_SIZE.')
# Parsed with _str_to_bool so that `--differentialPrivacy False` really
# disables DP (with type=bool it was silently parsed as True).
parser.add_argument('--differentialPrivacy', type=_str_to_bool, default=True,
                    help='Choose whether or not to privatise SGD.')
# ******************************************************************

# The line below is NEEDED in Colab: the notebook kernel is launched with a
# `-f <connection file>` argument that argparse would otherwise reject.
parser.add_argument("-f", "--file", required=False)
args = parser.parse_args()
print(args)

# Seeds PyTorch's random number generator.
torch.manual_seed(args.seed)

def main():
    """Run one structure-learning experiment selected by ``args.data_type``.

    * ``'real'``: load the data, fit without a ground-truth graph and draw the
      learned graph with the dataset's column names.
    * ``'benchmark'``: load the data together with a ground-truth graph, fit,
      print the BIC score from ``compute_BiCScore`` and draw the graph.
    * anything else (synthetic): when ``args.synthesize`` is set, generate the
      data and pickle it to ``train_loader.pkl`` / ``ground_truth_G.pkl`` in
      the working directory; then always reload both files and fit against
      the ground truth.

    The total wall-clock time is printed at the end.
    """
    t = time.time()

    if args.data_type == 'real':

        train_loader, data_variable_size, columns = load_data(args, args.batch_size, args.suffix)

        # add adjacency matrix A
        num_nodes = data_variable_size
        adj_A = np.zeros((num_nodes, num_nodes))

        aae_wgan_gp = AAE_WGAN_GP(args, adj_A)

        causal_graph = aae_wgan_gp.fit(train_loader)

        draw_dag(causal_graph, args.data_type, columns)

    elif args.data_type == 'benchmark':

        train_loader, data_variable_size, ground_truth_G, num_cats  = load_data(args, args.batch_size, args.suffix)

        # add adjacency matrix A
        num_nodes = data_variable_size
        adj_A = np.zeros((num_nodes, num_nodes))

        aae_wgan_gp = AAE_WGAN_GP(args, adj_A)

        causal_graph = aae_wgan_gp.fit(train_loader, ground_truth_G)

        # nx.to_numpy_matrix was removed in networkx 3.0; to_numpy_array
        # returns the same values directly as an ndarray.
        BIC_score = compute_BiCScore(nx.to_numpy_array(ground_truth_G), causal_graph)
        print('BIC_score: ' + str(BIC_score))

        draw_dag(causal_graph, args.data_type)

    else:
        if args.synthesize:
            # create and store synthetic data
            train_loader, ground_truth_G = load_data(args, args.batch_size, args.suffix)

            with open(r"train_loader.pkl", "wb") as output_file:
                pickle.dump(train_loader, output_file)

            with open(r"ground_truth_G.pkl", "wb") as output_file_G:
                pickle.dump(ground_truth_G, output_file_G)

        # load synthetic data
        # NOTE: unpickling can execute arbitrary code; only load .pkl files
        # that this script wrote itself.
        with open(r"train_loader.pkl", "rb") as input_file:
            train_data = pickle.load(input_file)

        with open(r"ground_truth_G.pkl", "rb") as input_file_G:
            ground_truth = pickle.load(input_file_G)

        # add adjacency matrix A
        num_nodes = args.data_variable_size
        adj_A = np.zeros((num_nodes, num_nodes))

        #test = pns_(adj_A, train_data, num_nodes, 0.75)
        #print(test)
        #print(test.shape)

        aae_wgan_gp = AAE_WGAN_GP(args, adj_A)
        causal_graph = aae_wgan_gp.fit(train_data, ground_truth)

        #draw_dag(causal_graph, args.data_type)

    #causal_graph.to_csv(os.path.join(args.export_directory, 'adjacency_matrix.csv'), index=False)

    print('Programm finished in: ' + str(time.strftime("%H:%M:%S", time.gmtime(time.time() - t))))

# Run the experiment only when executed as a script/notebook cell, not on import.
if __name__ == "__main__":
    main()



#=========================================

# Thoughts on DAG-WGAN implementation

Usually, we only need to privatise updates to parameters that have been trained on the real data. Intuitively, we don't need to worry about the privacy of fake patients, only the real ones!
For this reason, whenever a set of parameters learns from the real data, we need to constrain its gradients to limit how much it can learn from that data.

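In a standard GAN, where $z\thicksim\mathcal{N}(0, I)$ (i.e., $X$ is generated from random noise), only the discriminator is trained on the real data directly, so only its updates would need to be privatised.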

**However, in DAG-WGAN the fake data is produced as a function of the real features via the encoder $E$, i.e., $z\thicksim P_E(z|X)$, so we probably need to bound all gradients in the model.**

In other words, since every part of this model will see the patient data, their parameters need to be constrained. 

# Components
- $E: X→z$. Maps the patient data to their corresponding latent representation. The encoder sees the patient's data directly, and therefore must be privatised.
- $G: z→\hat{X}$. Generator for reconstructing $X$ from its latent features (and $A$: parametric case). Since $z$ contains information about $X$, we will also need to privatise updates made to $G$.
- $D: \mathbb{R}^{|X|}\rightarrow\mathbb{R}$. Maps the given data to a real value (i.e. a score for how real/fake it looks). Both sets of data ($X, \hat{X}$) contain patient info, so all updates to $D$ need to be privatised too. 

## Template
This isn't a working version, but a simple pseudocode-style illustration of how differential privacy would work with plain SGD (DP-SGD). 

In [None]:
# Illustrative DP-SGD step with plain SGD. `model`, `criterion`,
# `train_dataset` and `args` are placeholders that must be supplied.
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader

for batch in DataLoader(train_dataset, batch_size=100):
    params = [p for p in model.parameters() if p.requires_grad]
    # Running sum of the clipped micro-batch gradients, one buffer per parameter.
    accumulated_grads = [torch.zeros_like(p) for p in params]

    # Run the microbatches: these are samples from the mini-batch (e.g. = 10).
    # For example: given a MINI-batch of 100 samples, a MICRO-batch could be 10 of these samples.
    #              we would then compute the avg gradient over these 10 samples (instead of the 100 as we usually do).
    xs, ys = batch
    microbatches = list(zip(xs.split(args.MICRO_BATCH_SIZE), ys.split(args.MICRO_BATCH_SIZE)))
    for x, y in microbatches:
        # backward() accumulates into .grad, so start each micro-batch from zero.
        model.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()

        # Clip this micro-batch's gradient: clip_grad_norm_ rescales the .grad
        # of the given parameters in place so their total norm is <= max_norm.
        clip_grad_norm_(params, max_norm=args.MAX_GRAD_NORM)
        for acc, p in zip(accumulated_grads, params):
            if p.grad is not None:
                acc += p.grad.detach()

    # Aggregate back, add noise to the summed clipped gradients, then update.
    with torch.no_grad():
        for acc, p in zip(accumulated_grads, params):
            # Gaussian noise with std = SIGMA * MAX_GRAD_NORM goes on the
            # gradient sum (not on the weights) before averaging.
            noise = torch.randn_like(acc) * args.SIGMA * args.MAX_GRAD_NORM
            p.grad = (acc + noise) / len(microbatches)
            p -= args.lr * p.grad    # plain SGD, without momentum.

    model.zero_grad()

# Practice stuff
Place to play around with code snippets before merging with main