In [1]:
import torch
import torch.nn as nn
import numpy as np
from collections import defaultdict
import itertools
import torch.nn.utils.spectral_norm as sn
import torch.nn.functional as F

In [2]:
# MLP model
# num_layer is an int (how many entries of num_nodes are used); num_nodes is a list of layer widths

class MLP(nn.Module):
    """Fully connected stack built from a list of layer widths.

    Args:
        num_layer: number of entries of ``num_nodes`` to use; the stack holds
            ``num_layer - 1`` linear layers, layer ``k`` mapping
            ``num_nodes[k] -> num_nodes[k + 1]``.
        num_nodes: list of layer widths.
        relu_final: if True, a ReLU follows every linear layer (the last one
            included); otherwise ReLUs are placed only between linear layers.
    """

    def __init__(self, num_layer, num_nodes, relu_final=False):
        super(MLP, self).__init__()
        layers = nn.Sequential()
        for idx in range(num_layer - 1):
            layers.add_module('linear{0}'.format(idx), nn.Linear(num_nodes[idx], num_nodes[idx + 1]))
            # With 2 widths this is a plain linear map; with more, hidden layers get a ReLU.
            between_layers = num_layer > 2 and idx < num_layer - 2
            if relu_final or between_layers:
                layers.add_module('relu{0}'.format(idx), nn.ReLU())
        self.main = layers

    def forward(self, input):
        """Apply the stacked layers to ``input``."""
        return self.main(input)


In [4]:
# An MLP generator that produces the latent z_i's
class MLP_Generator(nn.Module):
    """MLP mapping (parents of x, noise, pseudo labels) to a latent vector.

    Args:
        z_dim: width of the generated output.
        pax_dim: width of the parents-of-x input.
        plabels_dim: width of the pseudo-label input.
        noise_dim: width of the noise input.
        num_layer: number of hidden layers.
        num_nodes: width of every hidden layer.
    """

    def __init__(self, z_dim, pax_dim, plabels_dim, noise_dim, num_layer=1, num_nodes=64):
        super(MLP_Generator, self).__init__()
        in_width = noise_dim + pax_dim + plabels_dim
        widths = [in_width, *([num_nodes] * num_layer), z_dim]
        self.decoder = MLP(num_layer + 2, widths)

    def forward(self, noise, pax, plabels, noise_d=None):
        """Concatenate ``(pax, noise, plabels)`` along dim 1 and decode.

        ``pax`` are the parents of x and ``plabels`` the domain pseudo labels.
        ``noise_d`` is accepted for interface compatibility but unused.
        """
        stacked = torch.cat((pax, noise, plabels), 1)
        return self.decoder(stacked)

In [None]:
# Discriminator + domain predictor (TODO: decide whether they should share one network)

class MLP_AuxClassifier(nn.Module):
    """Auxiliary classifier on latent codes z.

    A shared trunk (``common_net``, ReLU after every layer) feeds four linear
    heads: ``aux_c`` / ``aux_c_tw`` with ``cl_num`` outputs and ``aux_d`` /
    ``aux_d_tw`` with ``do_num`` outputs. A separate MLP, ``cls``, maps z
    directly to ``plabels_dim`` outputs (domain pseudo labels).

    Args:
        z_dim: width of the latent input.
        plabels_dim: number of pseudo-label outputs of ``cls``.
        do_num: number of domains (outputs of the ``aux_d*`` heads).
        num_layer: hidden layers in ``cls`` and ``common_net``.
        num_nodes: width of those hidden layers.
        cl_num: number of classes (outputs of the ``aux_c*`` heads). When
            omitted, a global ``cl_num`` is used if one is defined, which keeps
            older notebook code working.

    Raises:
        TypeError: if ``cl_num`` is neither passed nor defined globally.
    """

    def __init__(self, z_dim, plabels_dim, do_num, num_layer=1, num_nodes=64, cl_num=None):
        super(MLP_AuxClassifier, self).__init__()
        if cl_num is None:
            # Backward compatibility: this class used to read an undeclared global `cl_num`.
            cl_num = globals().get('cl_num')
        if cl_num is None:
            raise TypeError('MLP_AuxClassifier requires cl_num (the number of classes)')

        hidden = [num_nodes] * num_layer
        # Classifies the z_i's into the pseudo-label (domain) categories.
        self.cls = MLP(num_layer + 2, [z_dim] + hidden + [plabels_dim])

        # Shared trunk for the aux heads below.
        self.common_net = MLP(num_layer + 1, [z_dim] + hidden, relu_final=True)

        self.aux_c = nn.Linear(num_nodes, cl_num)
        self.aux_c_tw = nn.Linear(num_nodes, cl_num)

        self.aux_d = nn.Linear(num_nodes, do_num)
        self.aux_d_tw = nn.Linear(num_nodes, do_num)

    def forward(self, input0):
        """Return ``(aux_c, aux_c_tw, aux_d, aux_d_tw, cls)`` outputs for ``input0``."""
        input = self.common_net(input0)
        output_c = self.aux_c(input)
        output_c_tw = self.aux_c_tw(input)
        output_d = self.aux_d(input)
        output_d_tw = self.aux_d_tw(input)
        output_cls = self.cls(input0)
        return output_c, output_c_tw, output_d, output_d_tw, output_cls


In [5]:
# generic classifier
class MLP_Classifier(nn.Module):
    """Generic MLP classifier: ``i_dim`` inputs -> ``cl_num`` outputs (logits).

    Args:
        i_dim: input width.
        cl_num: number of output classes.
        num_layer: number of hidden layers.
        num_nodes: width of every hidden layer.
    """

    def __init__(self, i_dim, cl_num, num_layer=1, num_nodes=64):
        super(MLP_Classifier, self).__init__()
        widths = [i_dim, *([num_nodes] * num_layer), cl_num]
        self.net = MLP(num_layer + 2, widths)

    def forward(self, input):
        """Return the class scores for ``input``."""
        return self.net(input)


In [None]:
# An MLP regressor that predicts x from z
# TODO: decide whether the pseudo labels should also be fed to the regressor

class MLP_Regressor(nn.Module):
    """MLP regressor: a ``z_dim``-wide input -> one output value.

    Args:
        z_dim: input width.
        num_layer: number of hidden layers.
        num_nodes: width of every hidden layer.
    """

    def __init__(self, z_dim, num_layer=1, num_nodes=64):
        super(MLP_Regressor, self).__init__()
        widths = [z_dim, *([num_nodes] * num_layer), 1]
        self.net = MLP(num_layer + 2, widths)

    def forward(self, input):
        """Apply the regressor to ``input``."""
        return self.net(input)

In [6]:
# The graph nodes.
class Data(object):
    """A named graph node holding an undirected set of links to other nodes."""

    def __init__(self, name):
        self.__name = name
        self.__links = set()

    @property
    def name(self):
        """The node's name (read-only)."""
        return self.__name

    @property
    def links(self):
        """A copy of the linked nodes; mutating it does not affect this node."""
        return self.__links.copy()

    def add_link(self, other):
        """Link this node and ``other`` to each other."""
        for node, peer in ((self, other), (other, self)):
            node.__links.add(peer)

In [None]:
# Class to represent a graph, for topological sort of DAG
class Graph:
    """Directed graph over vertices ``0 .. V-1`` with a DFS-based topological sort."""

    def __init__(self, vertices):
        self.graph = defaultdict(list)  # adjacency list: u -> [v, ...]
        self.V = vertices  # number of vertices

    def addEdge(self, u, v):
        """Add the directed edge ``u -> v``."""
        self.graph[u].append(v)

    def _postorder(self, root, visited, out):
        """DFS from ``root``, appending each vertex to ``out`` once it finishes.

        Marks every reached vertex in ``visited``. Uses an explicit stack of
        (vertex, child-iterator) frames instead of recursion, so long chains do
        not hit Python's recursion limit; the visiting order is the same as a
        recursive DFS.
        """
        visited[root] = True
        frames = [(root, iter(self.graph[root]))]
        while frames:
            node, children = frames[-1]
            for child in children:
                if not visited[child]:
                    visited[child] = True
                    frames.append((child, iter(self.graph[child])))
                    break
            else:
                # all children handled: the vertex is finished
                frames.pop()
                out.append(node)

    def topologicalSortUtil(self, v, visited, stack):
        """DFS from ``v``, prepending each finished vertex to ``stack``.

        Args:
            v: start vertex.
            visited: per-vertex flags; updated in place.
            stack: sequence supporting ``insert``; finished vertices are
                inserted at the front, so the last one to finish comes first.
        """
        finished = []
        self._postorder(v, visited, finished)
        for node in finished:
            stack.insert(0, node)

    def topologicalSort(self):
        """Return all vertices in topological order.

        For a DAG, every edge ``u -> v`` has ``u`` before ``v``. Runs in
        O(V + E): the post-order is collected by appending and reversed once,
        rather than prepending each vertex.
        """
        visited = [False] * self.V
        finished = []
        for root in range(self.V):
            if not visited[root]:
                self._postorder(root, visited, finished)
        finished.reverse()
        return finished


In [None]:
# a decoder according to a DAG
class DAG_Generator(nn.Module):
    """Generator with one small MLP per vertex, wired according to a DAG.

    Vertex ``i`` is produced by ``self.nets[i]`` from its slice of the class
    embedding, its slice of the domain embedding, its slice of the noise and
    (if it has any) the values of its parents. The parents of ``i`` are the
    nonzero entries of row ``i`` of ``self.dagMat``.

    NOTE(review): this class is work in progress and cannot run as written:
      * ``__init__`` reads ``do_num``, ``do_dim``, ``cl_num``, ``cl_dim``, ``i_dim``
        and ``is_reg``, which are neither parameters nor locals. Construction
        raises NameError unless they exist as globals. ``plabels_dim`` and
        ``pax_dim`` are never used.
      * ``forward`` / ``forward_indep`` read ``self.prob`` (and ``self.mu`` /
        ``self.sigma`` when it is truthy), which ``__init__`` never sets.
      * The trainer cell calls this constructor with a different argument list
        (including ``prob=``).
    """
    def __init__(self, num_vertices, plabels_dim, pax_dim, z_dim, num_layer=1, num_nodes=64, dagMat=None):
        """Build one MLP per vertex according to ``dagMat``.

        Args:
            num_vertices: number of DAG vertices (one generated column each).
            plabels_dim: currently unused.
            pax_dim: currently unused.
            z_dim: width of the noise slice consumed by each vertex.
            num_layer: hidden layers in each per-vertex MLP.
            num_nodes: width of those hidden layers.
            dagMat: adjacency matrix indexed ``dagMat[child, parent]``. Its last
                two columns are split off into ``self.yd_sign``. Presumably
                shaped (num_vertices, num_vertices + 2) — TODO confirm. Despite
                the ``None`` default it is required: ``dagMat[j, i]`` is indexed
                unconditionally.
        """
        super(DAG_Generator, self).__init__()

        # create a dag with the correct number of vertices
        dag = Graph(num_vertices)

        # from the adjacency matrix add edges to the dag:
        # a nonzero dagMat[j, i] means i is a parent of j, so add the edge i -> j
        for i in range(num_vertices):
            for j in range(num_vertices):
                if dagMat[j, i]:
                    dag.addEdge(i, j)

        # split off the last two columns (described by the author as the y and d signs)
        # NOTE(review): the author flagged this split as unverified ("check this")
        self.yd_sign = dagMat[:, -2:]
        dagMat = dagMat[:, :-2]

        # topological sort so every vertex comes after its parents
        nodeSort = dag.topologicalSort()
        # number of parents of each vertex (row sums; assumes 0/1 entries — TODO confirm)
        numInput = dagMat.sum(1)

        # Linear embeddings of the domain / class inputs (one-hot in the trainer).
        # Each vertex i gets its own slice: columns [i*do_dim, (i+1)*do_dim) of dnet
        # and [i*cl_dim, (i+1)*cl_dim) of cnet.
        self.dnet = nn.Linear(do_num, do_dim * num_vertices, bias=False)

        self.cnet = nn.Linear(cl_num, cl_dim * num_vertices, bias=False)

        # construct generative network according to the dag:
        # vertex i maps [class slice, domain slice, noise slice, parents] -> 1 value
        nets = nn.ModuleList()
        for i in range(num_vertices):
            num_nodesIn = int(numInput[i]) + cl_dim + do_dim + z_dim
            num_nodes_i = [num_nodesIn] + [num_nodes]*num_layer + [1]
            netMB = MLP(num_layer + 2, num_nodes_i)
            nets.append(netMB)

        # prediction network
        self.nets = nets
        self.nodeSort = nodeSort
        # nodesA[i] == [i]: the column(s) of x produced by vertex i
        self.nodesA = np.array(range(i_dim)).reshape(i_dim, 1).tolist()
        self.i_dim = i_dim
        self.i_dimNew = i_dim
        self.do_num = do_num
        self.cl_num = cl_num
        self.cl_dim = cl_dim
        self.do_dim = do_dim
        self.z_dim = z_dim
        self.dagMat = dagMat
        self.numInput = numInput
        self.is_reg = is_reg
        self.ischain = False

    # inputs: class indicator, domain indicator, noise, features
    # separate forward for each factor
    def forward_indep(self, noise, input_c, input_d, input_x, noise_d=None, device='cpu'):
        """Generate every vertex conditioned on the REAL parent values in ``input_x``.

        Each factor is evaluated on observed parents rather than generated ones,
        so the factors can be fitted independently.

        Args:
            noise: noise tensor; vertex i reads columns [i*z_dim, (i+1)*z_dim).
            input_c: class input; passed through ``cnet``, or used as a
                (batch, 1) value when ``is_reg``.
            input_d: domain input; passed through ``dnet``, or multiplied by the
                sampled ``theta`` when ``self.prob``.
            input_x: observed features, indexed ``input_x[:, parent_index]``.
            noise_d: noise for ``theta``; required when ``self.prob`` (None fails in torch.mul).
            device: device for the output tensor.

        Returns:
            ``output`` of shape (batch, number of vertices), plus the KL term
            when ``self.prob``.
        """
        # class parameter network
        batch_size = input_c.size(0)
        if self.is_reg:
            # NOTE(review): 1-wide class input, but nets were sized with cl_dim — presumably cl_dim must be 1
            inputs_c = input_c.view(batch_size, 1)
        else:
            inputs_c = self.cnet(input_c)
        if self.prob:
            # theta = mu + softplus(sigma) * noise_d (reparameterised sample)
            theta = self.mu + torch.mul(torch.log(1+torch.exp(self.sigma)), noise_d)
            inputs_d = torch.matmul(input_d, theta)
        else:
            inputs_d = self.dnet(input_d)

        inputs_n = noise
        inputs_f = input_x

        # create output array
        output = torch.zeros((batch_size, len(self.nodeSort)))
        output = output.to(device)

        # create a network for each module
        for i in self.nodeSort:
            inputs_pDim = self.numInput[i]
            if inputs_pDim > 0:
                index = np.argwhere(self.dagMat[i, :])
                index = index.flatten()
                index = [int(j) for j in index]
                inputs_p = inputs_f[:, index] # get the parent data from real data, not fake data!!!
            if not self.is_reg:
                inputs_ci = inputs_c[:, i*self.cl_dim:(i+1)*self.cl_dim]
            else:
                inputs_ci = inputs_c
            inputs_di = inputs_d[:, i*self.do_dim:(i+1)*self.do_dim]
            inputs_ni = inputs_n[:, i*self.z_dim:(i+1)*self.z_dim]
            if inputs_pDim > 0:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni, inputs_p), 1)
            else:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni), 1)

            output[:, i] = self.nets[i](inputs_i).squeeze()

        if self.prob:
            # With s = softplus(sigma), KL_reg = 1 + log s^2 - mu^2 - s^2 elementwise,
            # so the returned -KL_reg is 2 * KL(N(mu, s^2) || N(0, 1)) per element.
            KL_reg = 1 + torch.log(torch.log(1+torch.exp(self.sigma))**2) - self.mu**2 - torch.log(1+torch.exp(self.sigma))**2
            if KL_reg.shape[1] > 1:
                KL_reg = KL_reg.sum(axis=1)
            return output, -KL_reg
        else:
            return output

    # inputs: class indicator, domain indicator, noise
    # forward for all factors in a graph
    def forward(self, noise, input_c, input_d, device='cpu', noise_d=None):
        """Generate all vertices ancestrally, in topological order.

        Unlike ``forward_indep``, each vertex is fed its parents' GENERATED
        values (read back from ``output``). The KL term is never returned here,
        even when ``self.prob``.

        Args:
            noise: noise tensor; vertex i reads columns [i*z_dim, (i+1)*z_dim).
            input_c: class input (see ``forward_indep``).
            input_d: domain input (see ``forward_indep``).
            device: device for the output tensor.
            noise_d: noise for ``theta``; required when ``self.prob``.

        Returns:
            ``output`` of shape (batch, number of vertices).
        """
        # class parameter network
        batch_size = input_c.size(0)
        if self.is_reg:
            inputs_c = input_c.view(batch_size, 1)
        else:
            inputs_c = self.cnet(input_c)
        if self.prob:
            theta = self.mu + torch.mul(torch.log(1+torch.exp(self.sigma)), noise_d)
            inputs_d = torch.matmul(input_d, theta)
        else:
            inputs_d = self.dnet(input_d)

        inputs_n = noise

        output = torch.zeros((batch_size, len(self.nodeSort)))
        output = output.to(device)

        # create a network for each module; nodeSort guarantees the parents'
        # columns of `output` were already filled in
        for i in self.nodeSort:
            inputs_pDim = self.numInput[i]
            if inputs_pDim > 0:
                index = np.argwhere(self.dagMat[i, :])
                index = index.flatten()
                index = [int(j) for j in index]
                inputs_p = output[:, index]

            if not self.is_reg:
                inputs_ci = inputs_c[:, i * self.cl_dim:(i + 1) * self.cl_dim]
            else:
                inputs_ci = inputs_c
            inputs_di = inputs_d[:, i * self.do_dim:(i + 1) * self.do_dim]
            inputs_ni = inputs_n[:, i * self.z_dim:(i + 1) * self.z_dim]
            if inputs_pDim > 0:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni, inputs_p), 1)
            else:
                inputs_i = torch.cat((inputs_ci, inputs_di, inputs_ni), 1)

            output[:, i] = self.nets[i](inputs_i).squeeze()

        return output

In [None]:
# trainer 

# DA_Infer trainer on a graph, deterministic/probabilistic theta encoder, implemented by joint MMD
class DA_Infer_JMMD_DAG(object):
    """DA-Infer trainer on a DAG, fitted with joint MMD.

    Holds a DAG-structured generator (``gen``), an auxiliary classifier
    (``dis``) and a decoder from latents back to x (``dec``). The domain
    encoder is deterministic when ``config['estimate'] == 'ML'`` and
    probabilistic when it is ``'Bayesian'``. Domain label ``num_domain - 1`` is
    treated as the target domain; all other labels are sources.

    NOTE(review): this cell also needs ``utils.seed_rng``,
    ``mix_rbf_mmd2_joint``, ``mix_rbf_mmd2`` and ``PDAG_Generator``, none of
    which are defined or imported in this notebook.
    """

    def __init__(self, config):
        """Build generator, classifier, decoder and optimizers from ``config``.

        Raises:
            ValueError: if ``config['G_model']`` or ``config['D_model']`` is not a
                supported model name.
        """
        from os.path import join  # not provided by the notebook's import cell

        # Fail fast on unsupported model names, before touching the filesystem.
        if config['G_model'] not in ('DAG_Generator', 'PDAG_Generator'):
            raise ValueError("Unsupported config['G_model']: {!r}".format(config['G_model']))
        if config['D_model'] != 'MLP_Classifier':
            raise ValueError("Unsupported config['D_model']: {!r}".format(config['D_model']))

        input_dim = config['idim']
        num_class = config['num_class']
        num_domain = config['num_domain']
        dim_class = config['dim_y']
        dim_domain = config['dim_d']
        dim_hidden = config['dim_z']
        G_num_layer = config['G_mlp_layers']
        G_num_nodes = config['G_mlp_nodes']
        D_num_layer = config['D_mlp_layers']
        D_num_nodes = config['D_mlp_nodes']
        Dec_num_layer = config['Dec_mlp_layers']
        Dec_num_nodes = config['Dec_mlp_nodes']
        is_reg = config['is_reg']
        dag_mat_file = join(config['data_root'], config['dataset'], config['dag_mat_file'])
        npzfile = np.load(dag_mat_file)
        dag_mat = npzfile['mat']

        isProb = config['estimate'] == 'Bayesian'

        # Seed before any module is built so every initial weight is reproducible.
        utils.seed_rng(config['seed'])

        # NOTE(review): these arguments (and `prob=`) do not match the
        # DAG_Generator signature defined in this notebook — reconcile before use.
        if config['G_model'] == 'DAG_Generator':
            self.gen = DAG_Generator(input_dim, num_class, num_domain, dim_class, dim_domain, dim_hidden, G_num_layer,
                                     G_num_nodes, is_reg, dag_mat, prob=isProb)
        else:
            self.gen = PDAG_Generator(input_dim, num_class, num_domain, dim_class, dim_domain, dim_hidden, G_num_layer,
                                      G_num_nodes, is_reg, dag_mat, prob=isProb)

        # auxiliary classifier on x
        self.dis = MLP_Classifier(input_dim, num_class, D_num_layer, D_num_nodes)

        # decoder reconstructing x from the latent z_i's (latent width = config['dim_z'])
        self.dec = MLP_Regressor(dim_hidden, Dec_num_layer, Dec_num_nodes)

        # set optimizer
        self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=config['G_lr'],
                                        betas=(config['G_B1'], config['G_B2']))
        self.dis_opt = torch.optim.Adam(self.dis.parameters(), lr=config['D_lr'],
                                        betas=(config['D_B1'], config['D_B2']))

        self.aux_loss_func = nn.CrossEntropyLoss()
        self.mmd_loss = 0
        self.mmd_loss_s = 0
        self.mmd_loss_t = 0
        self.aux_loss_c = 0

    def to(self, device):
        """Move the generator, classifier and decoder to ``device``."""
        self.gen.to(device)
        self.dis.to(device)
        self.dec.to(device)

    def gen_update(self, x_a, y_a, config, state, device='cpu'):
        """One generator step on a mixed source/target batch.

        The loss is the sum of:
          * the mean over DAG factors of the joint MMD between teacher-forced
            generated samples (``gen.forward_indep``) and real source samples;
          * the MMD between ancestrally generated samples (``gen(...)``) and
            real target samples;
          * ``config['AC_weight']`` times the auxiliary classification loss;
          * in 'Bayesian' mode, a KL term.

        Args:
            x_a: feature batch, indexed ``x_a[sample, vertex]``.
            y_a: integer labels; column 0 is the class, column 1 the domain.
            config: hyper-parameter dict.
            state: training state; ``state['epoch']`` is read in train mode 'm1'.
            device: device for the noise tensors.

        Raises:
            ValueError: on an unsupported ``config['estimate']`` or
                ``config['train_mode']``.
        """
        self.gen.zero_grad()
        self.dis.zero_grad()
        input_dim = config['idim']
        dim_domain = config['dim_d']
        # actual batch size, so a short final batch does not break the `.view` calls
        batch_size = x_a.size(0)
        z_dim = config['dim_z']
        num_domain = config['num_domain']
        num_class = config['num_class']
        is_reg = config['is_reg']
        do_ss = config['do_ss']

        # z_dim Gaussian noise columns per vertex (input_dim is the number of
        # vertices); z_dim == 0 yields an empty (batch_size, 0) tensor
        noise = torch.randn((batch_size, z_dim * input_dim), device=device)

        # class input: one-hot for classification, the raw value for regression
        if not is_reg:
            y_a_onehot = torch.nn.functional.one_hot(y_a[:, 0], num_class).float()
        else:
            y_a_onehot = y_a[:, 0].view(batch_size, 1)

        # one-hot domain labels
        d_onehot = torch.nn.functional.one_hot(y_a[:, 1], num_domain).float()

        if config['estimate'] == 'ML':
            fake_x_a = self.gen.forward_indep(noise, y_a_onehot, d_onehot, x_a, device=device)
            fake_x_a_cls = self.gen(noise, y_a_onehot, d_onehot, device=device)
        elif config['estimate'] == 'Bayesian':
            noise_d = torch.randn(num_domain, dim_domain * input_dim).to(device)
            fake_x_a, KL_reg = self.gen.forward_indep(noise, y_a_onehot, d_onehot, x_a, device=device, noise_d=noise_d)
            fake_x_a_cls = self.gen(noise, y_a_onehot, d_onehot, device=device, noise_d=noise_d)
        else:
            raise ValueError("Unsupported config['estimate']: {!r}".format(config['estimate']))

        # RBF kernel bandwidths for MMD, scaled by config['base_x']
        base_x = config['base_x']
        sigma_list = [0.125, 0.25, 0.5, 1]
        sigma_listx = [sigma * base_x for sigma in sigma_list]

        # the last domain is the target; all others are sources
        ids_s = y_a[:, 1] != num_domain - 1
        ids_t = y_a[:, 1] == num_domain - 1

        if config['train_mode'] == 'm0':  # no backprop from classifier C to G
            output_cf = self.dis(fake_x_a_cls.detach())
            aux_loss_c = self.aux_loss_func(output_cf[ids_t], y_a[ids_t, 0])
        elif config['train_mode'] == 'm1':  # backprop from classifier C to G
            output_cr = self.dis(x_a[ids_s])
            output_cf = self.dis(fake_x_a_cls)
            lambda_src = config['SRC_weight']
            # no target classification loss during warm-up
            lambda_tar = 0 if state['epoch'] < config['warmup'] else config['TAR_weight']
            aux_loss_c_src = lambda_src * self.aux_loss_func(output_cr, y_a[ids_s, 0])
            aux_loss_c_tar = lambda_tar * self.aux_loss_func(output_cf[ids_t], y_a[ids_t, 0])
            aux_loss_c = aux_loss_c_src + aux_loss_c_tar
        else:
            raise ValueError("Unsupported config['train_mode']: {!r}".format(config['train_mode']))

        # joint MMD matching for each factor, on the source samples
        batch_size_s = len(y_a[ids_s, :])
        errG_s = torch.zeros(len(self.gen.nodeSort), device=device)
        for i in self.gen.nodeSort:
            cols = self.gen.nodesA[i]  # column(s) of x produced by factor i
            output_dim = len(cols) if self.gen.ischain else 1
            # select rows first, then columns, so multi-column factors index correctly
            fake_i = fake_x_a[ids_s][:, cols].view(batch_size_s, output_dim)
            real_i = x_a[ids_s][:, cols].view(batch_size_s, output_dim)
            if self.gen.numInput[i] > 0:
                # parent columns of factor i
                if not self.gen.ischain:
                    index = np.argwhere(self.gen.dagMat[i, :]).flatten()
                elif output_dim == 1:
                    index = np.argwhere(self.gen.dagMat[cols[0], :]).flatten()
                else:
                    groups = [self.gen.nodesA[j] for j in np.argwhere(self.gen.dagMatNew[i, :]).flatten()]
                    index = list(itertools.chain.from_iterable(groups))
                index = [int(j) for j in index]
                input_p = x_a[:, index].view(batch_size, len(index))
                errG_s[i] = mix_rbf_mmd2_joint(fake_i, real_i,
                                               y_a_onehot[ids_s], y_a_onehot[ids_s], d_onehot[ids_s],
                                               d_onehot[ids_s], input_p[ids_s], input_p[ids_s], sigma_list=sigma_listx)
            else:
                errG_s[i] = mix_rbf_mmd2_joint(fake_i, real_i,
                                               y_a_onehot[ids_s], y_a_onehot[ids_s], d_onehot[ids_s],
                                               d_onehot[ids_s], sigma_list=sigma_listx)

        # plain MMD between ancestrally generated and real target samples
        errG_t = mix_rbf_mmd2(fake_x_a_cls[ids_t], x_a[ids_t], sigma_list=sigma_listx)

        errG_s = errG_s.mean()

        lambda_c = config['AC_weight']
        errG = errG_s + errG_t + lambda_c * aux_loss_c
        if config['estimate'] == 'Bayesian':
            # KL term weighted by 1 / do_ss (presumably per-domain sample sizes — TODO confirm)
            errG = errG + torch.dot(1.0 / do_ss.to(device).squeeze(), KL_reg.squeeze())

        errG.backward()
        self.gen_opt.step()
        # the classifier also gets gradients from lambda_c * aux_loss_c and is stepped here
        self.dis_opt.step()
        self.mmd_loss = errG
        self.mmd_loss_s = errG_s
        self.mmd_loss_t = errG_t
        self.aux_loss_c = aux_loss_c

    def dec_update(self, z, x_a, config, state, device='cpu'):
        """Decoder training step: reconstruct ``x_a`` from latents ``z``.

        TODO: not implemented yet. This method previously had no body, which
        made the whole trainer cell a SyntaxError.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('DA_Infer_JMMD_DAG.dec_update is not implemented yet')

    def dis_update(self, x_a, y_a, config, state, device='cpu'):
        """One auxiliary-classifier step.

        Uses cross-entropy on the real source samples and on the generated
        target samples. The generator output is detached, so only ``dis`` is
        updated.

        Raises:
            ValueError: on an unsupported ``config['estimate']``.
        """
        for p in self.dis.parameters():
            p.requires_grad_(True)
        self.dis.zero_grad()
        input_dim = config['idim']
        batch_size = x_a.size(0)
        z_dim = config['dim_z']
        dim_domain = config['dim_d']
        num_domain = config['num_domain']
        num_class = config['num_class']
        is_reg = config['is_reg']

        # z_dim Gaussian noise columns per vertex (same layout as gen_update)
        noise = torch.randn((batch_size, z_dim * input_dim), device=device)

        # class input: one-hot for classification, the raw value for regression
        if not is_reg:
            y_a_onehot = torch.nn.functional.one_hot(y_a[:, 0], num_class).float()
        else:
            y_a_onehot = y_a[:, 0].view(batch_size, 1)

        d_onehot = torch.nn.functional.one_hot(y_a[:, 1], num_domain).float()

        # the generator output is detached below, so skip building its autograd graph
        with torch.no_grad():
            if config['estimate'] == 'ML':
                fake_x_a_cls = self.gen(noise, y_a_onehot, d_onehot, device=device)
            elif config['estimate'] == 'Bayesian':
                noise_d = torch.randn(num_domain, dim_domain * input_dim).to(device)
                fake_x_a_cls = self.gen(noise, y_a_onehot, d_onehot, device=device, noise_d=noise_d)
            else:
                raise ValueError("Unsupported config['estimate']: {!r}".format(config['estimate']))

        ids_s = y_a[:, 1] != num_domain - 1
        ids_t = y_a[:, 1] == num_domain - 1
        output_cr = self.dis(x_a[ids_s])
        output_cf = self.dis(fake_x_a_cls.detach())
        lambda_src = config['SRC_weight']
        lambda_tar = config['TAR_weight']
        aux_loss_c_src = lambda_src * self.aux_loss_func(output_cr, y_a[ids_s, 0])
        aux_loss_c_tar = lambda_tar * self.aux_loss_func(output_cf[ids_t], y_a[ids_t, 0])
        aux_loss_c = aux_loss_c_src + aux_loss_c_tar
        aux_loss_c.backward()
        self.dis_opt.step()
        self.aux_loss_c = aux_loss_c

    def resume(self, snapshot_prefix):
        """Load generator, classifier and training state saved by ``save``.

        Returns the saved state dict. ``torch.load`` unpickles, so only load
        trusted snapshot files.
        """
        gen_filename = snapshot_prefix + '_gen.pkl'
        dis_filename = snapshot_prefix + '_dis.pkl'
        state_filename = snapshot_prefix + '_state.pkl'
        self.gen.load_state_dict(torch.load(gen_filename))
        self.dis.load_state_dict(torch.load(dis_filename))
        state_dict = torch.load(state_filename)
        print('Resume the model')
        return state_dict

    def save(self, snapshot_prefix, state_dict):
        """Write generator/classifier weights and ``state_dict`` under ``snapshot_prefix``."""
        gen_filename = snapshot_prefix + '_gen.pkl'
        dis_filename = snapshot_prefix + '_dis.pkl'
        state_filename = snapshot_prefix + '_state.pkl'
        torch.save(self.gen.state_dict(), gen_filename)
        torch.save(self.dis.state_dict(), dis_filename)
        torch.save(state_dict, state_filename)