In [1]:
import tensorflow as tf
import numpy as np

In [2]:
def identity_transpose(A):
    '''
    Calculate (I - A^T) for a square matrix A.

    The identity is created with A's dtype. tf.eye otherwise defaults to
    float32, so a float64 A (e.g. one built with np.zeros) would fail in the
    subtraction with a dtype mismatch.
    '''
    return tf.eye(A.shape[0], A.shape[0], dtype=A.dtype) - tf.transpose(A)

def identity_transpose_inverse(A):
    '''
    Return the matrix inverse (I - A^T)^(-1).

    I - A^T is built with identity_transpose and then inverted with
    tf.linalg.inv. What happens for a singular matrix is up to tf.linalg.inv.
    '''
    i_minus_at = identity_transpose(A)
    inverse = tf.linalg.inv(i_minus_at)
    return inverse

In [5]:
class Encoder(tf.keras.Model):
    '''
    Encoder class for DAG-GNN method

    Inputs (constructor):
    adjA (array-like [d, d]) : initial estimate of the adjacency matrix; stored
        as a trainable float32 variable
    in_dim (int) : dimension of input layer (unused: Keras infers the input
        size on the first call; kept for interface compatibility)
    hid_dim (int) : dimension of hidden layer
    out_dim (int) : dimension of output layer

    Outputs (of call):
    out : output of the two-layer MLP applied to the inputs
    logits : (I - A^T) @ out
    adjA (tensor [d, d]) : current estimated adjacency matrix
    '''
    def __init__(self, adjA, in_dim, hid_dim, out_dim):
        super(Encoder, self).__init__()
        # Stored as float32 to match the dtype of the Dense-layer outputs it is
        # multiplied with in call(); an np.zeros array would otherwise be float64.
        self.adjA = tf.Variable(
            initial_value=tf.cast(adjA, tf.float32), trainable=True
        )
        # Trainable vector of length out_dim; presumably the DAG-GNN bias term Wa.
        # NOTE(review): not used in call() yet — TODO confirm intended use.
        self.Wa = tf.Variable(tf.zeros(out_dim), trainable=True)

        self.fc1 = tf.keras.layers.Dense(hid_dim, activation='relu')
        self.fc2 = tf.keras.layers.Dense(out_dim)

    def call(self, inputs):
        '''Forward pass: MLP on the inputs, then left-multiply by (I - A^T).'''
        # calculate I - A^T
        I_adjA = identity_transpose(self.adjA)
        hidden = self.fc1(inputs)
        out = self.fc2(hidden)
        # NOTE(review): tf.matmul needs the second-to-last axis of `out` to have
        # size d (e.g. inputs shaped [batch, d, features]) — TODO confirm.
        logits = tf.matmul(I_adjA, out)
        return out, logits, self.adjA


class Decoder(tf.keras.Model):
    '''
    Decoder class for DAG-GNN method

    Inputs (constructor):
    in_dim (int) : dimension of input layer (unused: Keras infers the input
        size on the first call; kept for interface compatibility)
    hid_dim (int) : dimension of hidden layer
    out_dim (int) : dimension of output layer

    Outputs (of call):
    z : (I - A^T)^(-1) @ inputs
    out : output of the two-layer MLP applied to z
    '''
    def __init__(self, in_dim, hid_dim, out_dim):
        super(Decoder, self).__init__()
        self.fc1 = tf.keras.layers.Dense(hid_dim, activation='relu')
        self.fc2 = tf.keras.layers.Dense(out_dim)

    def call(self, inputs, adjA):
        '''Forward pass: map inputs through (I - A^T)^(-1), then the MLP.'''
        # calculate (I - A^T)^(-1): the DAG-GNN decoder inverts the (I - A^T)
        # mixing that the encoder applies.
        I_adjA_inv = identity_transpose_inverse(adjA)
        z = tf.matmul(I_adjA_inv, inputs)

        hidden = self.fc1(z)
        out = self.fc2(hidden)
        return z, out

In [None]:
def dag_gnn(data, hid_dim = 20, out_dim = 4, max_iter = 10e8, rho_max = 10e20, epochs = 20):
    '''
    Function for inference of DAG with method DAG-GNN

    Inputs:
    data : observations; data.shape[1] is taken as the number of variables
        (presumably shaped [n_samples, n_variables] — TODO confirm)
    hid_dim (int) : hidden-layer width passed to both the Encoder and Decoder
    out_dim (int) : Encoder output dimension, also passed to the Decoder as
        its in_dim
    max_iter : number of outer optimisation iterations (note: the default
        10e8 is the float 1e9, not 10**8)
    rho_max : upper bound on the penalty coefficient rho
    epochs (int) : training epochs per inner iteration

    Outputs:
    NOTE(review): no return value is visible yet; presumably the estimated
    adjacency matrix — TODO confirm.
    '''

    n_variables = data.shape[1]
    # Penalty coefficient and multiplier of the augmented-Lagrangian DAG
    # constraint; the nested _loss reads them through its closure.
    rho = 1
    alpha = 1
    
    def _h(A):
        '''
        Acyclicity constraint of DAG-GNN (Yu et al. 2019).

        h(A) = tr[(I + A*A / n_variables)^n_variables] - n_variables,
        where A*A is the element-wise square. h(A) is zero exactly when the
        weighted graph of A has no cycles.
        '''
        # The element-wise square keeps every entry non-negative, so weighted
        # cycles cannot cancel out in the trace.
        M = (tf.eye(n_variables, num_columns=n_variables, dtype=A.dtype)
             + A * A / n_variables)
        E = M
        # After n_variables - 1 further products, E == M^n_variables.
        # This is also correct for n_variables == 1.
        for _ in range(n_variables - 1):
            E = tf.linalg.matmul(E, M)
        return tf.linalg.trace(E) - n_variables

    def _nll_loss(y_est, y):
        '''
        Compute the negative log-likelihood loss as a scaled squared error.

        L = sum((y_est - y)^2) / (2 * n_variables)
        '''
        nll = tf.pow(y_est - y, 2) / (2 * n_variables)
        return tf.math.reduce_sum(nll)

    def _kl_loss(y):
        '''
        Compute the KL divergence loss term.

        KL = sum(y^2) / (2 * y.shape[0])
        '''
        return tf.math.reduce_sum(tf.pow(y, 2) / (2 * y.shape[0]))

    def _loss(A, A_est, logits):
        '''
        Evaluate the model loss.

        loss = kl loss + nll loss + DAG constraint penalty

        The penalty is the augmented-Lagrangian term 0.5 * rho * h^2 + alpha * h
        with h = _h(A). rho and alpha are read from the enclosing dag_gnn
        scope at call time.
        TODO: the l1 / l2 regularisation terms are not implemented yet.
        '''
        h = _h(A)
        h_loss = 0.5 * rho * h * h + alpha * h
        kl_loss = _kl_loss(logits)
        nll_loss = _nll_loss(A_est, A)
        return kl_loss + nll_loss + h_loss
        
    
    new_adj = np.zeros((n_variables, n_variables))
    # setup of data loader
    # NOTE(review): setup_data_loader is not defined in the visible source —
    # TODO confirm it is provided elsewhere in the notebook.
    train_loader, test_loader = setup_data_loader(data)

    # setup of neural networks
    encoder = Encoder(new_adj, n_variables, hid_dim, out_dim)
    decoder = Decoder(out_dim, hid_dim, n_variables)

    # Augmented-Lagrangian schedule (Yu et al. 2019 DAG-GNN). After each round
    # of training, multiply rho by 10 until the acyclicity value h has shrunk
    # to at most a quarter of its previous value, then update alpha.
    # Without a rho update the inner while-loop would never terminate.
    h_tol = 1e-8  # stop once the DAG constraint is numerically satisfied
    h_old = float('inf')
    W_est = new_adj
    # max_iter defaults to the float 10e8, which range() rejects.
    for _ in range(int(max_iter)):
        if rho >= rho_max:
            break
        while rho < rho_max:
            for epoch in range(epochs):
                # NOTE(review): train() is not defined in the visible source;
                # presumably it runs one epoch and returns the current A estimate.
                W_est = train()
            h_new = float(_h(W_est))
            if h_new > 0.25 * h_old:
                rho *= 10
            else:
                break
        h_old = h_new
        alpha += rho * h_new
        if h_new <= h_tol:
            break

    # NOTE(review): this loop runs after training and looks like the intended
    # body of train() — TODO move it there once train() is written.
    for batch_id, batch_data in enumerate(train_loader):

        # passing through neural network
        encoder_out, logits, adjA = encoder(batch_data)
        z, decoder_out = decoder(logits, adjA)
        