In [1]:
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
import itertools
import torch.nn.utils.spectral_norm as sn
import torch.nn.functional as F

In [2]:
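# MLP model
# num_layer is an int: the number of layer widths used (input and output included);
# num_nodes is the list of those widths, so num_layer - 1 linear layers are built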

class MLP(nn.Module):
    """Fully connected network assembled from a list of layer widths.

    ``num_layer`` counts the width entries that are used (input and output
    included), so ``num_layer - 1`` ``nn.Linear`` modules are created; the
    l-th one maps ``num_nodes[l]`` features to ``num_nodes[l + 1]`` features.

    A ReLU follows every linear layer except the last one. With
    ``relu_final=True`` the last linear layer is followed by a ReLU as well.
    Consequently ``num_layer == 2`` with ``relu_final=False`` gives a purely
    linear map.
    """

    def __init__(self, num_layer, num_nodes, relu_final=False):
        super().__init__()
        self.main = nn.Sequential()
        final_idx = num_layer - 2  # index of the last linear layer
        for idx in range(num_layer - 1):
            self.main.add_module('linear{}'.format(idx),
                                 nn.Linear(num_nodes[idx], num_nodes[idx + 1]))
            # hidden layers always get a ReLU; the final one only on request
            if relu_final or idx < final_idx:
                self.main.add_module('relu{}'.format(idx), nn.ReLU())

    def forward(self, input):
        """Apply the stacked layers to ``input`` and return the result."""
        return self.main(input)


In [4]:
# a MLP generator
class MLP_Generator(nn.Module):
    """MLP that maps (parents, noise, pseudo labels) to a ``z_dim``-wide output.

    The decoder input width is ``noise_dim + pax_dim + plabels_dim`` and it has
    ``num_layer`` hidden layers of ``num_nodes`` units each.
    """

    def __init__(self, z_dim, pax_dim, plabels_dim, noise_dim, num_layer=1, num_nodes=64):
        super().__init__()
        in_width = noise_dim + pax_dim + plabels_dim
        widths = [in_width] + [num_nodes] * num_layer + [z_dim]
        self.decoder = MLP(num_layer + 2, widths)

    def forward(self, noise, pax, plabels, noise_d=None):
        """Concatenate the inputs feature-wise and decode them.

        ``pax`` holds the parents of x and ``plabels`` the pseudo labels of the
        domain. The concatenation order is (pax, noise, plabels).
        ``noise_d`` is accepted but not used here.
        """
        stacked = torch.cat([pax, noise, plabels], dim=1)
        return self.decoder(stacked)

In [None]:
# discriminator + domain predictor

class MLP_AuxClassifier(nn.Module):
    """Shared-trunk network with auxiliary class/domain heads plus a separate classifier.

    * ``cls`` is an independent MLP mapping the raw ``z_dim`` input to
      ``plabels_dim`` outputs (classifying z's into different domains).
    * ``common_net`` is a ReLU-terminated trunk whose ``num_nodes`` features
      feed four linear heads: ``aux_c`` / ``aux_c_tw`` (``cl_num`` outputs each)
      and ``aux_d`` / ``aux_d_tw`` (``do_num`` outputs each).
      NOTE(review): the ``_tw`` heads are presumably "twin" heads -- confirm.

    Args:
        z_dim: width of the input.
        plabels_dim: number of outputs of ``cls``.
        do_num: number of outputs of the domain heads.
        num_layer: number of hidden layers (must be >= 1 for the heads, which
            expect ``num_nodes`` input features, to match the trunk output).
        num_nodes: width of each hidden layer.
        cl_num: number of outputs of the class heads. The original code read
            this from an undefined name; pass it explicitly. When omitted, a
            module-level ``cl_num`` is used if one exists (old behaviour).

    Raises:
        NameError: if ``cl_num`` is neither passed nor defined at module level.
    """

    def __init__(self, z_dim, plabels_dim, do_num, num_layer=1, num_nodes=64, cl_num=None):
        super(MLP_AuxClassifier, self).__init__()

        if cl_num is None:
            # Backward compatibility: the body used to reference a bare
            # `cl_num`, which only worked when the notebook defined it globally.
            try:
                cl_num = globals()['cl_num']
            except KeyError:
                raise NameError(
                    "cl_num is not defined: pass cl_num=... to MLP_AuxClassifier"
                )

        # classifying zi's into different domains
        self.cls = MLP(num_layer + 2, [z_dim] + [num_nodes] * num_layer + [plabels_dim])

        # shared feature extractor for the auxiliary heads
        self.common_net = MLP(num_layer + 1, [z_dim] + [num_nodes] * num_layer, relu_final=True)

        self.aux_c = nn.Linear(num_nodes, cl_num)
        self.aux_c_tw = nn.Linear(num_nodes, cl_num)

        self.aux_d = nn.Linear(num_nodes, do_num)
        self.aux_d_tw = nn.Linear(num_nodes, do_num)

    def forward(self, input0):
        """Return ``(output_c, output_c_tw, output_d, output_d_tw, output_cls)``.

        The first four come from the heads on top of ``common_net(input0)``;
        ``output_cls`` comes from ``cls`` applied directly to ``input0``.
        """
        features = self.common_net(input0)
        output_c = self.aux_c(features)
        output_c_tw = self.aux_c_tw(features)
        output_d = self.aux_d(features)
        output_d_tw = self.aux_d_tw(features)
        output_cls = self.cls(input0)
        return output_c, output_c_tw, output_d, output_d_tw, output_cls


In [5]:
# an MLP classifier to predict y
class MLP_Classifier(nn.Module):
    """MLP mapping ``i_dim`` input features to ``cl_num`` outputs (predicts y).

    The network has ``num_layer`` hidden layers of ``num_nodes`` units each.
    """

    def __init__(self, i_dim, cl_num, num_layer=1, num_nodes=64):
        super().__init__()
        widths = [i_dim] + [num_nodes] * num_layer + [cl_num]
        self.net = MLP(num_layer + 2, widths)

    def forward(self, input):
        """Run ``input`` through the network and return its raw outputs."""
        return self.net(input)


In [6]:
# The graph nodes.
class Data:
    """A named graph node that keeps an undirected set of linked nodes."""

    def __init__(self, name):
        self.__links = set()
        self.__name = name

    @property
    def name(self):
        """The name given at construction (read-only)."""
        return self.__name

    @property
    def links(self):
        """A fresh copy of the linked nodes; mutating it does not affect the node."""
        return self.__links.copy()

    def add_link(self, other):
        """Link this node and ``other`` to each other (both directions)."""
        for source, target in ((self, other), (other, self)):
            source.__links.add(target)

In [None]:
# Class to represent a graph, for topological sort of DAG
class Graph:
    """Directed graph over vertices ``0 .. V-1`` with a DFS topological sort.

    Cycles are not detected: for a graph that is not a DAG the returned order
    is simply the reversed DFS finishing order.
    """

    def __init__(self, vertices):
        self.graph = defaultdict(list)  # vertex -> list of successor vertices
        self.V = vertices  # number of vertices

    def addEdge(self, u, v):
        """Add the directed edge ``u -> v``."""
        self.graph[u].append(v)

    def topologicalSortUtil(self, v, visited, stack):
        """Recursive DFS from ``v``; kept for callers that use it directly.

        Marks every vertex reached as visited and inserts each finished vertex
        at the front of ``stack``. ``topologicalSort`` no longer calls this
        helper because its recursion depth grows with the longest path.
        """
        visited[v] = True

        for i in self.graph[v]:
            # truthiness test so non-``bool`` flags (e.g. numpy booleans) work too
            if not visited[i]:
                self.topologicalSortUtil(i, visited, stack)

        stack.insert(0, v)

    def topologicalSort(self):
        """Return the vertices in topological order (parents before children).

        Uses an explicit stack instead of recursion, so long chains do not hit
        Python's recursion limit, and appends/reverses instead of inserting at
        the front of the list. The visiting order, and therefore the result,
        is the same as that of the recursive formulation.
        """
        visited = [False] * self.V
        finished = []  # vertices in DFS finishing order

        for root in range(self.V):
            if visited[root]:
                continue
            visited[root] = True
            # each frame: (vertex, iterator over its not-yet-examined successors)
            frames = [(root, iter(self.graph[root]))]
            while frames:
                vertex, successors = frames[-1]
                for nxt in successors:
                    if not visited[nxt]:
                        visited[nxt] = True
                        frames.append((nxt, iter(self.graph[nxt])))
                        break
                else:
                    # all successors handled: vertex is finished
                    finished.append(vertex)
                    frames.pop()

        finished.reverse()
        return finished


In [None]:
# a decoder according to a DAG
class DAG_Generator(nn.Module):
    """Generator that produces one scalar per node, following a DAG.

    Each node ``i`` has its own MLP (``self.nets[i]``) with a single output.
    Its input is the concatenation of a class slice, a domain slice, a noise
    slice and the values of the node's parents, where the parents of ``i`` are
    the non-zero columns of ``self.dagMat[i, :]``.

    NOTE(review): this class looks like an unfinished refactor. ``__init__``
    reads ``do_num``, ``do_dim``, ``i_dim``, ``cl_num``, ``cl_dim`` and
    ``is_reg``, none of which are parameters or locals in the visible code, so
    construction raises ``NameError`` unless they exist as notebook globals.
    ``forward`` / ``forward_indep`` read ``self.prob`` (and ``self.mu`` /
    ``self.sigma`` when it is truthy), which are not assigned anywhere in the
    visible code. ``plabels_dim`` and ``pax_dim`` are accepted but unused.
    """

    def __init__(self, num_vertices, plabels_dim, pax_dim, z_dim, num_layer=1, num_nodes=64, dagMat=None):
        """Build the DAG, its topological order and one MLP per node.

        ``dagMat[j, i]`` being truthy adds the edge ``i -> j``. The last two
        columns of ``dagMat`` are split off into ``self.yd_sign`` and the
        remaining columns are kept as ``self.dagMat``.

        NOTE(review): the default ``dagMat=None`` cannot work when
        ``num_vertices > 0`` because ``dagMat`` is indexed immediately; a 2-D
        array supporting ``[j, i]`` indexing and ``.sum(1)`` is required.
        """
        super(DAG_Generator, self).__init__()

        # create a dag
        dag = Graph(num_vertices)

        # edges are read from the full matrix, before the last two columns are removed
        for i in range(num_vertices):
            for j in range(num_vertices):
                if dagMat[j, i]:
                    dag.addEdge(i, j)

        # extract y and d signs

        ###########check this#################
        self.yd_sign = dagMat[:, -2:]
        dagMat = dagMat[:, :-2]

        # topological sort
        nodeSort = dag.topologicalSort()
        # row sums of the trimmed matrix = number of parent inputs per node
        numInput = dagMat.sum(1)

        # NOTE(review): do_num, do_dim, i_dim, cl_num, cl_dim are undefined here -- confirm intended source
        self.dnet = nn.Linear(do_num, do_dim * i_dim, bias=False)

        self.cnet = nn.Linear(cl_num, cl_dim * i_dim, bias=False)

        # construct generative network according to the dag
        nets = nn.ModuleList()
        for i in range(i_dim):
            # input width: parents + class slice + domain slice + noise slice; output width: 1
            num_nodesIn = int(numInput[i]) + cl_dim + do_dim + z_dim
            num_nodes_i = [num_nodesIn] + [num_nodes]*num_layer + [1]
            netMB = MLP(num_layer + 2, num_nodes_i)
            nets.append(netMB)

        # prediction network
        self.nets = nets
        self.nodeSort = nodeSort
        self.nodesA = np.array(range(i_dim)).reshape(i_dim, 1).tolist()
        self.i_dim = i_dim
        self.i_dimNew = i_dim
        self.do_num = do_num
        self.cl_num = cl_num
        self.cl_dim = cl_dim
        self.do_dim = do_dim
        self.z_dim = z_dim
        self.dagMat = dagMat
        self.numInput = numInput
        # NOTE(review): is_reg is undefined in this scope -- confirm intended source
        self.is_reg = is_reg
        self.ischain = False

        # inputs: class indicator, domain indicator, noise, features
        # separate forward for each factor
    def forward_indep(self, noise, input_c, input_d, input_x, noise_d=None, device='cpu'):
        """Generate every node from the *given* parent values in ``input_x``.

        Unlike ``forward``, parent inputs are read from ``input_x`` rather than
        from previously generated outputs, so each node is produced
        independently of the others.

        Args:
            noise: sliced per node as columns ``[i*z_dim, (i+1)*z_dim)``.
            input_c: class input; passed through ``self.cnet`` unless
                ``self.is_reg`` is truthy, in which case it is viewed as
                ``(batch_size, 1)`` and shared by all nodes.
            input_d: domain input; passed through ``self.dnet`` unless
                ``self.prob`` is truthy.
            input_x: tensor whose columns supply the parent values.
            noise_d: multiplied into ``theta`` when ``self.prob`` is truthy
                (the default ``None`` would not work in that branch).
            device: device the output tensor is moved to.

        Returns:
            ``output`` of shape ``(batch_size, len(self.nodeSort))``, or the
            tuple ``(output, -KL_reg)`` when ``self.prob`` is truthy.
        """
        # class parameter network
        batch_size = input_c.size(0)
        if self.is_reg:
            inputs_c = input_c.view(batch_size, 1)
        else:
            inputs_c = self.cnet(input_c)
        # NOTE(review): self.prob / self.mu / self.sigma are not set in the visible __init__
        if self.prob:
            # theta = mu + softplus(sigma) * noise_d, with softplus(x) = log(1 + exp(x))
            theta = self.mu + torch.mul(torch.log(1+torch.exp(self.sigma)), noise_d)
            inputs_d = torch.matmul(input_d, theta)
        else:
            inputs_d = self.dnet(input_d)

        inputs_n = noise
        inputs_f = input_x

        # create output array
        output = torch.zeros((batch_size, len(self.nodeSort)))
        output = output.to(device)

        # create a network for each module
        for i in self.nodeSort:
            inputs_pDim = self.numInput[i]
            if inputs_pDim > 0:
                # column indices of node i's parents
                index = np.argwhere(self.dagMat[i, :])
                index = index.flatten()
                index = [int(j) for j in index]
                inputs_p = inputs_f[:, index] # get the parent data from real data, not fake data!!!
            if not self.is_reg:
                inputs_ci = inputs_c[:, i*self.cl_dim:(i+1)*self.cl_dim]
            else:
                inputs_ci = inputs_c
            inputs_di = inputs_d[:, i*self.do_dim:(i+1)*self.do_dim]
            inputs_ni = inputs_n[:, i*self.z_dim:(i+1)*self.z_dim]
            if inputs_pDim > 0:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni, inputs_p), 1)
            else:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni), 1)

            # NOTE(review): indexes self.nets by node id; presumably i_dim == num_vertices -- confirm
            output[:, i] = self.nets[i](inputs_i).squeeze()

        if self.prob:
            # 1 + log(s^2) - mu^2 - s^2 with s = softplus(sigma); resembles an
            # (unscaled) Gaussian KL term -- TODO confirm intended scaling
            KL_reg = 1 + torch.log(torch.log(1+torch.exp(self.sigma))**2) - self.mu**2 - torch.log(1+torch.exp(self.sigma))**2
            if KL_reg.shape[1] > 1:
                KL_reg = KL_reg.sum(axis=1)
            return output, -KL_reg
        else:
            return output

    # inputs: class indicator, domain indicator, noise
    # forward for all factors in a graph
    def forward(self, noise, input_c, input_d, device='cpu', noise_d=None):
        """Generate all nodes in topological order, feeding generated parents forward.

        Parent inputs for node ``i`` are read from the columns of ``output``
        that were filled earlier in the loop over ``self.nodeSort``.

        Args:
            noise: sliced per node as columns ``[i*z_dim, (i+1)*z_dim)``.
            input_c: class input; passed through ``self.cnet`` unless
                ``self.is_reg`` is truthy, in which case it is viewed as
                ``(batch_size, 1)`` and shared by all nodes.
            input_d: domain input; passed through ``self.dnet`` unless
                ``self.prob`` is truthy.
            device: device the output tensor is moved to.
            noise_d: multiplied into ``theta`` when ``self.prob`` is truthy
                (the default ``None`` would not work in that branch).

        Returns:
            Tensor of shape ``(batch_size, len(self.nodeSort))``. Unlike
            ``forward_indep``, no KL term is returned even when ``self.prob``
            is truthy.
        """
        # class parameter network
        batch_size = input_c.size(0)
        if self.is_reg:
            inputs_c = input_c.view(batch_size, 1)
        else:
            inputs_c = self.cnet(input_c)
        # NOTE(review): self.prob / self.mu / self.sigma are not set in the visible __init__
        if self.prob:
            theta = self.mu + torch.mul(torch.log(1+torch.exp(self.sigma)), noise_d)
            inputs_d = torch.matmul(input_d, theta)
        else:
            inputs_d = self.dnet(input_d)

        inputs_n = noise

        output = torch.zeros((batch_size, len(self.nodeSort)))
        output = output.to(device)

        # create a network for each module
        for i in self.nodeSort:
            inputs_pDim = self.numInput[i]
            if inputs_pDim > 0:
                # column indices of node i's parents
                index = np.argwhere(self.dagMat[i, :])
                index = index.flatten()
                index = [int(j) for j in index]
                # parents come from values generated earlier in this loop
                inputs_p = output[:, index]

            if not self.is_reg:
                inputs_ci = inputs_c[:, i * self.cl_dim:(i + 1) * self.cl_dim]
            else:
                inputs_ci = inputs_c
            inputs_di = inputs_d[:, i * self.do_dim:(i + 1) * self.do_dim]
            inputs_ni = inputs_n[:, i * self.z_dim:(i + 1) * self.z_dim]
            if inputs_pDim > 0:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni, inputs_p), 1)
            else:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni), 1)

            output[:, i] = self.nets[i](inputs_i).squeeze()

        return output