In [36]:
from DagmaDCE import utils, nonlinear, nonlinear_dce
import torch, gpytorch
import time
import numpy as np
import matplotlib.pyplot as plt
from CausalDisco.analytics import r2_sortability, var_sortability
from CausalDisco.baselines import r2_sort_regress, var_sort_regress
from cdt.metrics import SID
from scipy.stats import kendalltau
import seaborn as sns
sns.set_context("paper")


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch.set_default_device(device)

torch.set_default_dtype(torch.double)
utils.set_random_seed(0)
torch.manual_seed(0)

reestimate_graph = False
RESULTS_DIR = ''

In [37]:
print('>>> Generating Data <<<')
n, d, s0, graph_type, sem_type = 1000, 10, 20, 'ER', 'gauss'
B_true = utils.simulate_dag(d, s0, graph_type)
W_true = utils.simulate_parameter(B_true)
X = utils.simulate_linear_sem(W_true, n, sem_type)

results_r2_sort_regress = r2_sort_regress(X)
acc_r2_sort_regress = utils.count_accuracy(
    B_true, results_r2_sort_regress != 0)
sid_r2_sort_regress = SID(B_true, results_r2_sort_regress != 0).item()
print('[Var Sort Regress Results] Var-Sortability of X:',
      r2_sortability(X, W_true))
print('[Var Sort Regress Results] SHD:',
      acc_r2_sort_regress['shd'], '| SID:', sid_r2_sort_regress, '| F1:', acc_r2_sort_regress['f1'])

results_var_sort_regress = var_sort_regress(X)
acc_var_sort_regress = utils.count_accuracy(
    B_true, results_var_sort_regress != 0)
sid_var_sort_regress = SID(B_true, results_var_sort_regress != 0).item()
print('[R^2 Sort Regress Results] R^2-Sortability of X:',
      var_sortability(X, W_true))
print('[R^2 Sort Regress Results] SHD:',
      acc_var_sort_regress['shd'], '| SID:', sid_var_sort_regress, '| F1:', acc_var_sort_regress['f1'])

X = torch.from_numpy(X).to(device)
print(f"X shape: {X.shape}")
print(f"B_true shape: {B_true.shape}")

>>> Generating Data <<<
[Var Sort Regress Results] Var-Sortability of X: 0.7142857142857143
[Var Sort Regress Results] SHD: 24 | SID: 43.0 | F1: 0.45614035087719296
[R^2 Sort Regress Results] R^2-Sortability of X: 1.0
[R^2 Sort Regress Results] SHD: 0 | SID: 0.0 | F1: 1.0
X shape: torch.Size([1000, 10])
B_true shape: (10, 10)


In [48]:
import copy
import torch
import torch
import torch.nn as nn
import numpy as np
from torch import optim
import copy
from tqdm.auto import tqdm
import abc
import typing
import gpytorch
from gpytorch.mlls import SumMarginalLogLikelihood

class Dagma_DCE_Module(nn.Module, abc.ABC):
    @abc.abstractmethod
    def get_graph(self, x: torch.Tensor) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def h_func(self, W: torch.Tensor, s: float) -> torch.Tensor:
        ...

    @abc.abstractmethod
    def get_l1_reg(self, W: torch.Tensor) -> torch.Tensor:
        ...


class DagmaDCE:
    def __init__(self, model: Dagma_DCE_Module, use_mse_loss=True):
        """Initializes a DAGMA DCE model. Requires a `DAGMA_DCE_Module`

        Args:
            model (Dagma_DCE_Module): module implementing adjacency matrix,
                h_func constraint, and L1 regularization
            use_mse_loss (bool, optional): to use MSE loss instead of log MSE loss.
                Defaults to True.
        """
        self.model = model
        self.loss = self.mse_loss if use_mse_loss else self.log_mse_loss


    def mse_loss(self, output: torch.Tensor, target: torch.Tensor):
        """Computes the MSE loss sum (output - target)^2 / (2N)"""
        n, d = target.shape
        print(type(output))
        if isinstance(output, torch.distributions.MultivariateNormal):
            output_mean = output.mean
        else:
            output_mean = output
        print(f"output's type is: {type(output)}")
        print(f"target's type is: {type(target)}")
        return 0.5 / n * torch.sum((output_mean - target) ** 2)


    def log_mse_loss(self, output: torch.Tensor, target: torch.Tensor):
        """Computes the MSE loss d / 2 * log [sum (output - target)^2 / N ]"""
        n, d = target.shape
        loss = 0.5 * d * torch.log(1 / n * torch.sum((output - target) ** 2))
        return loss

    def minimize(
        self,
        max_iter: int,
        lr: float,
        lambda1: float,
        lambda2: float,
        mu: float,
        s: float,
        pbar: tqdm,
        lr_decay: bool = False,
        checkpoint: int = 1000,
        tol: float = 1e-3,
    ):
        """Perform minimization using the barrier method optimization

        Args:
            max_iter (int): maximum number of iterations to optimize
            lr (float): learning rate for adam
            lambda1 (float): regularization parameter
            lambda2 (float): weight decay
            mu (float): regularization parameter for barrier method
            s (float): DAMGA constraint hyperparameter
            pbar (tqdm): progress bar to use
            lr_decay (bool, optional): whether or not to use learning rate decay.
                Defaults to False.
            checkpoint (int, optional): how often to checkpoint. Defaults to 1000.
            tol (float, optional): tolerance to terminate learning. Defaults to 1e-3.
        """
        print(f"model parameters are:")
        for param in self.model.parameters():
            print(param)
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=lr,
            betas=(0.99, 0.999),
            weight_decay=mu * lambda2,
        )

        trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )

        print(f"total params is: {trainable_params}")
        obj_prev = 1e16

        scheduler = optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=0.8 if lr_decay else 1.0
        )

        for i in range(max_iter):
            optimizer.zero_grad()

            if i == 0:
                # X_hat = self.model(self.X)
                print(self.model(self.X))
                print(self.model(self.X).shape)
                # print(f"shape of self.model(self.X) is: {self.model(self.X)}")
                X_hat = self.model(self.X).mean  # Extract mean from MultivariateNormal
                score = self.loss(X_hat, self.X)
                obj = score

            else:
                W_current, observed_derivs = self.model.get_graph(self.X)
                h_val = self.model.h_func(W_current, s)

                if h_val.item() < 0:
                    return False
                # print(self.model(self.X))
                # X_hat = self.model(self.X)
                X_hat = self.model(self.X).mean  # Extract mean from MultivariateNormal
                score = self.mse_loss(X_hat, self.X)

                l1_reg = lambda1 * self.model.get_l1_reg(observed_derivs)

                obj = mu * (score + l1_reg) + h_val

            obj.backward()
            optimizer.step()

            if lr_decay and (i + 1) % 1000 == 0:
                scheduler.step()

            if i % checkpoint == 0 or i == max_iter - 1:
                obj_new = obj.item()

                if np.abs((obj_prev - obj_new) / (obj_prev)) <= tol:
                    pbar.update(max_iter - i)
                    break
                obj_prev = obj_new

            pbar.update(1)

        return True

    def fit(
        self,
        X: torch.Tensor,
        lambda1: float = 0.02,
        lambda2: float = 0.005,
        T: int = 4,
        mu_init: float = 1.0,
        mu_factor: float = 0.1,
        s: float = 1.0,
        warm_iter: int = 5e3,
        max_iter: int = 8e3,
        lr: float = 1e-3,
        disable_pbar: bool = False,
    ) -> torch.Tensor:
        """Fits the DAGMA-DCE model

        Args:
            X (torch.Tensor): inputs
            lambda1 (float, optional): regularization parameter. Defaults to 0.02.
            lambda2 (float, optional): weight decay. Defaults to 0.005.
            T (int, optional): number of barrier loops. Defaults to 4.
            mu_init (float, optional): barrier path coefficient. Defaults to 1.0.
            mu_factor (float, optional): decay parameter for mu. Defaults to 0.1.
            s (float, optional): DAGMA constraint hyperparameter. Defaults to 1.0.
            warm_iter (int, optional): number of warmup models. Defaults to 5e3.
            max_iter (int, optional): maximum number of iterations for learning. Defaults to 8e3.
            lr (float, optional): learning rate. Defaults to 1e-3.
            disable_pbar (bool, optional): whether or not to use the progress bar. Defaults to False.

        Returns:
            torch.Tensor: graph returned by the model
        """
        mu = mu_init
        self.X = X

        with tqdm(total=(T - 1) * warm_iter + max_iter, disable=disable_pbar) as pbar:
            for i in range(int(T)):
                success, s_cur = False, s
                lr_decay = False

                inner_iter = int(max_iter) if i == T - 1 else int(warm_iter)
                model_copy = copy.deepcopy(self.model)

                while success is False:
                    success = self.minimize(
                        inner_iter,
                        lr,
                        lambda1,
                        lambda2,
                        mu,
                        s_cur,
                        lr_decay=lr_decay,
                        pbar=pbar,
                    )

                    if success is False:
                        self.model.load_state_dict(model_copy.state_dict().copy())
                        lr *= 0.5
                        lr_decay = True
                        if lr < 1e-10:
                            print(":(")
                            break  # lr is too small

                    mu *= mu_factor
                    print(f"Success is {success}")

        return self.model.get_graph(self.X)[0]

In [49]:
class DagmaGP_DCE(Dagma_DCE_Module):
    def __init__(self, train_x, num_tasks, lr=0.1, training_iterations=50):
        super(DagmaGP_DCE, self).__init__()
        self.num_tasks = num_tasks
        self.lr = lr
        self.training_iterations = training_iterations

        self.models = []
        self.likelihoods = []
        for i in range(self.num_tasks):
            likelihood = gpytorch.likelihoods.GaussianLikelihood()
            model = ExactGPModel(train_x, train_x[:, i], likelihood)
            self.models.append(model)
            self.likelihoods.append(likelihood)

        self.model_list = gpytorch.models.IndependentModelList(*self.models)
        self.likelihood_list = gpytorch.likelihoods.LikelihoodList(*self.likelihoods)
        self.mll = SumMarginalLogLikelihood(self.likelihood_list, self.model_list)

        print("Initial Lengthscales:")
        for i, model in enumerate(self.model_list.models):
            lengthscale = model.covar_module.base_kernel.lengthscale
            print(f"Model {i} initial lengthscale: {lengthscale}")

    def forward(self, x):

        for model in self.model_list.models:
            model.eval()

        predictive_means = []
        predictive_variances = []
        for model, likelihood in zip(self.model_list.models, self.likelihood_list.likelihoods):
            with torch.no_grad():
                observed_pred = likelihood(model(x))
                predictive_means.append(observed_pred.mean)
                predictive_variances.append(observed_pred.variance)

        return torch.stack(predictive_means)

    def train(self, train_x):

        self.model_list.train()
        self.likelihood_list.train()
        optimizer = torch.optim.Adam(self.model_list.parameters(), lr=self.lr)

        for i in range(self.training_iterations):
            optimizer.zero_grad()
            output = self.model_list(*[train_x for _ in range(self.num_tasks)])
            targets = [train_x[:, j] for j in range(self.num_tasks)]
            loss = -self.mll(output, targets)
            loss.backward()
            optimizer.step()
            print('Iter %d/%d - Loss: %.3f' % (i + 1, self.training_iterations, loss.item()))
        
        print("Final Lengthscales:")
        for i, model in enumerate(self.model_list.models):
            lengthscale = model.covar_module.base_kernel.lengthscale
            print(f"Model {i} initial lengthscale: {lengthscale}")

    def get_graph(self, x):
        for model in self.model_list.models:
            model.eval()

        derivative = torch.zeros(len(x), self.num_tasks, 10)
        for i in range(len(x)):
            print(f"Computing Jacobian for data point {i}")
            input_vector = x[i].unsqueeze(0).detach().requires_grad_(True)
            for j, model in enumerate(self.model_list.models):
                model.zero_grad()
                observed_pred = self.likelihoods[j](model(input_vector))
                mean = observed_pred.mean
                mean.backward()
                derivative[i, j] = input_vector.grad.clone()
                input_vector.grad.zero_()

        mean_squared_jacobians = torch.mean(derivative ** 2, dim=0)
        W = torch.sqrt(mean_squared_jacobians)
        return W, derivative

    def h_func(self, W: torch.Tensor, s: float = 1.0) -> torch.Tensor:
        """Calculate the DAGMA constraint function

        Args:
            W (torch.Tensor): adjacency matrix
            s (float, optional): hyperparameter for the DAGMA constraint,
                can be any positive number. Defaults to 1.0.

        Returns:
            torch.Tensor: constraint
        """
        h = -torch.slogdet(s * self.I - W * W)[1] + self.d * np.log(s)
        return h

    def get_l1_reg(self, observed_derivs: torch.Tensor) -> torch.Tensor:
        """Gets the L1 regularization

        Args:
            observed_derivs (torch.Tensor): the batched Jacobian matrix

        Returns:
            torch.Tensor: _description_
        """
        return torch.sum(torch.abs(torch.mean(observed_derivs, axis=0)))
    
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [50]:
eq_model = DagmaGP_DCE(train_x=X, num_tasks = 10, lr=0.1, training_iterations=50)
model = DagmaDCE(eq_model)

Initial Lengthscales:
Model 0 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 1 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 2 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 3 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 4 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 5 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 6 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 7 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 8 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)
Model 9 initial lengthscale: tensor([[0.6931]], grad_fn=<SoftplusBackward0>)


In [52]:
eq_model.train(train_x=X)

# W_est_dce = model.fit(X, lambda1=0, lambda2=5e-3,
#                       lr=2e-4, mu_factor=0.1, mu_init=1, T=4, warm_iter=1*5000, max_iter=1*7000)

Iter 1/50 - Loss: 3.394
Iter 2/50 - Loss: 3.248
Iter 3/50 - Loss: 3.109
Iter 4/50 - Loss: 2.976
Iter 5/50 - Loss: 2.850
Iter 6/50 - Loss: 2.727
Iter 7/50 - Loss: 2.607
Iter 8/50 - Loss: 2.491
Iter 9/50 - Loss: 2.377
Iter 10/50 - Loss: 2.269
Iter 11/50 - Loss: 2.164
Iter 12/50 - Loss: 2.069
Iter 13/50 - Loss: 1.975
Iter 14/50 - Loss: 1.889
Iter 15/50 - Loss: 1.810
Iter 16/50 - Loss: 1.736
Iter 17/50 - Loss: 1.669
Iter 18/50 - Loss: 1.605
Iter 19/50 - Loss: 1.544
Iter 20/50 - Loss: 1.489
Iter 21/50 - Loss: 1.435
Iter 22/50 - Loss: 1.381
Iter 23/50 - Loss: 1.335
Iter 24/50 - Loss: 1.286
Iter 25/50 - Loss: 1.242
Iter 26/50 - Loss: 1.197
Iter 27/50 - Loss: 1.155
Iter 28/50 - Loss: 1.116
Iter 29/50 - Loss: 1.077
Iter 30/50 - Loss: 1.029
Iter 31/50 - Loss: 0.992
Iter 32/50 - Loss: 0.952
Iter 33/50 - Loss: 0.917
Iter 34/50 - Loss: 0.879
Iter 35/50 - Loss: 0.836
Iter 36/50 - Loss: 0.801
Iter 37/50 - Loss: 0.761
Iter 38/50 - Loss: 0.724
Iter 39/50 - Loss: 0.693
Iter 40/50 - Loss: 0.654
Iter 41/5

In [None]:
rms_jacobian, J = eq_model.get_graph(X)
print(f"rms_jacobian is: {rms_jacobian}")
print(f"J is: {J}")