In [1]:
import sys
from os import path
sys.path.append(path.join("..", "src"))

import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, ConcatDataset, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import KBinsDiscretizer
from packages.spn.experiments.RandomSPNs_layerwise.rat_spn import RatSpn, RatSpnConfig
from packages.spn.experiments.RandomSPNs_layerwise.distributions import RatNormal
from packages.spn.algorithms.layerwise.distributions import Bernoulli, Categorical
from utils.datasets import gen_dataset
from utils.config_utils import load_config_data
from utils.utils import visualize_3d
from utils.selectors import get_sim_dataloader
from constraint.constraints import GeneralizationConstraint, EqualityConstraint, AbstractConstraint, get_outputs
import time
import argparse
from tqdm import tqdm
from pathlib import Path

In [2]:
def make_spn(S, I, R, D, F, C, device, leaf_base_class, leaf_base_kwargs=None) -> RatSpn:
    """Build a RatSpn, move it to `device`, and switch it to training mode.

    Args:
        S, I, R, D, F, C: Values copied onto the same-named fields of a fresh
            RatSpnConfig (their meaning is defined by RatSpnConfig, which is
            not visible here).
        device: Target device handed to `model.to(...)`.
        leaf_base_class: Stored on `config.leaf_base_class`.
        leaf_base_kwargs: Stored on `config.leaf_base_kwargs`; `None` is
            replaced by an empty dict.

    Returns:
        The constructed model, on `device`, in train mode.
    """
    # Assignment order matches the original in case the config validates on set.
    settings = (
        ("F", F),
        ("R", R),
        ("D", D),
        ("I", I),
        ("S", S),
        ("C", C),
        ("dropout", 0.0),  # dropout is always disabled here
        ("leaf_base_class", leaf_base_class),
        ("leaf_base_kwargs", leaf_base_kwargs if leaf_base_kwargs is not None else {}),
    )

    config = RatSpnConfig()
    for field, value in settings:
        setattr(config, field, value)

    model = RatSpn(config).to(device)
    model.train()

    print("Using device:", device)
    return model

k = 3  # number of k-means bins used for the continuous columns in get_dataset


def get_dataset(name: str):
    """Download, discretize and split a tabular dataset.

    Only ``"cleveland"`` (UCI processed Cleveland heart-disease data) is
    supported. Every column is converted to a small integer code:
    thresholds for ``chol``/``trestbps``/``age``, boolean flags for
    ``restecg``/``cp``, k-means bins (``k`` of them) for
    ``thalach``/``oldpeak``/``slope``, and a binary ``unhealthy`` target
    derived from ``num != 0``.

    Args:
        name: Dataset identifier.

    Returns:
        ``(train, test, r)`` where ``train``/``test`` are integer DataFrames
        from a stratified 50/50 split (``random_state=0``) and ``r`` is a list
        with ``max value + 1`` for every column, in column order.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    if name != "cleveland":
        # Previously fell through and returned None, which only failed later
        # when the caller tried to unpack the result.
        raise ValueError(f"Unknown dataset {name!r}; only 'cleveland' is supported")

    columns = ["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg", "thalach", "exang", "oldpeak", "slope", "ca",
               "thal", "num"]
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data"
    frame = pd.read_csv(url, names=columns, na_values="?").dropna()

    # Explicit copy so the column assignments below never write into a view.
    frame = frame[["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg", "num", "thalach", "exang", "oldpeak", "slope"]].copy()
    frame["unhealthy"] = (frame.num.astype(int) != 0).astype(int)
    frame = frame.drop(columns=["num"])

    frame["chol"] = ((frame["chol"] < 200) | (frame["chol"] > 240))  # not normal
    frame["trestbps"] = pd.cut(frame["trestbps"], [0, 120, 140, np.inf], labels=np.arange(3))
    frame["restecg"] = (frame["restecg"] != 0)  # not normal
    frame["cp"] = (frame["cp"] != 4)  # chest pain present
    frame["age"] = pd.cut(frame["age"], [0, 40, 60, np.inf], labels=np.arange(3))

    # Loop variable is deliberately not called `name`, to avoid shadowing the argument.
    for column in ["thalach", "oldpeak", "slope"]:
        binned = KBinsDiscretizer(n_bins=k, encode='ordinal', strategy='kmeans') \
            .fit_transform(frame[column].to_numpy().reshape(-1, 1))
        frame[column] = binned.flatten().astype(int)

    # Booleans and categoricals all become plain integer codes.
    frame = frame.astype(int)

    # Number of states per column, assuming codes start at 0.
    r = [int(m) + 1 for m in frame.to_numpy().max(axis=0)]
    train, test = train_test_split(frame, test_size=0.5, random_state=0, stratify=frame.unhealthy)
    return train, test, r
        

def get_loaders(name: str, use_cuda, batch_size):
    """Wrap the train/test split of a dataset in TensorDatasets and DataLoaders.

    Args:
        name: Dataset identifier forwarded to `get_dataset`.
        use_cuda: When truthy, the loaders use 8 workers and pinned memory.
        batch_size: Batch size for both the train and the test loader.

    Returns:
        ``(r, column_names, train_dataset, test_dataset, train_loader,
        test_loader)``. The train loader shuffles; the test loader does not.
    """
    train_frame, test_frame, r = get_dataset(name)

    # Worker / pinning options are only switched on for CUDA runs.
    loader_options = {}
    if use_cuda:
        loader_options = {"num_workers": 8, "pin_memory": True}

    def to_tensor_dataset(frame):
        # torch.Tensor(...) yields the default float tensor type.
        return TensorDataset(torch.Tensor(frame.to_numpy()))

    train_dataset = to_tensor_dataset(train_frame)
    test_dataset = to_tensor_dataset(test_frame)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, **loader_options
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, **loader_options
    )

    return (
        r,
        train_frame.columns.tolist(),
        train_dataset,
        test_dataset,
        train_loader,
        test_loader,
    )




# Load the Cleveland data with CPU-side loaders (use_cuda=False) and batch size 32.
r, names, train_dataset, test_dataset, train_loader, test_loader = get_loaders("cleveland", False, 32)

# RAT-SPN hyper-parameters and leaf distribution (RatNormal is the alternative leaf).
rat_S, rat_I, rat_D, rat_R, rat_C, leaves = 20, 20, 2, 5, 1, Categorical  # RatNormal
n_features = train_loader.dataset[0][0].shape[0]

# Fall back to the CPU when no GPU is present so the notebook still runs end to end.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0  # not referenced by the visible cells; make_spn fixes dropout at 0.0 itself

model = make_spn(
    S=rat_S, I=rat_I, D=rat_D, R=rat_R, device=device, F=n_features, C=rat_C,
    leaf_base_class=leaves, leaf_base_kwargs=dict(num_bins=max(r)),
)



Using device: cuda


In [48]:
def predict_proba(model, r, data, target_index, marg_indices=None, device='cpu'):
    """Score every state of one column and normalise the scores per row.

    For each state ``i`` in ``range(r[target_index])`` the target column of a
    copy of ``data`` is set to ``i`` and the model is evaluated; the model's
    value with ``target_index`` added to its second argument is subtracted,
    and a softmax over the states is returned.

    Args:
        model: Callable as ``model(x)`` and ``model(x, indices)``; the second
            argument is presumably the set of marginalised columns (confirm
            against the RatSpn API).
        r: Per-column number of states.
        data: 2-D tensor of rows to score; it is not modified.
        target_index: Column whose distribution is requested.
        marg_indices: Optional extra indices passed as the model's second argument.
        device: Device on which the result tensor is allocated.

    Returns:
        Tensor of shape ``(len(data), r[target_index])`` whose rows sum to 1.
    """
    n_states = r[target_index]

    if marg_indices is None:
        denominator_indices = (target_index,)
    else:
        denominator_indices = (target_index, *marg_indices)

    def numerator(batch):
        # Without extra indices the model is called with the batch alone.
        if marg_indices is None:
            return model(batch)
        return model(batch, marg_indices)

    log_denom = model(data, denominator_indices).ravel()

    log_p = torch.zeros((len(data), n_states), device=device)
    for state in range(n_states):
        candidate = data.clone()
        candidate[:, target_index] = state
        log_p[:, state] = numerator(candidate).ravel() - log_denom

    return torch.softmax(log_p, axis=1)

class ContextSpecificIndependence(EqualityConstraint):
    """Penalty for violating "X independent of Y given Z = z".

    Compares P(X | Y = y, Z = z) for every state y of Y against P(X | Z = z),
    with all remaining columns passed to the model as extra indices.
    """

    def __init__(self, X, Y, Z, z, r):
        """Store the column indices X, Y, Z, the context value z and the state counts r."""
        # X independent of Y | Z = z
        self.X = X
        self.Y = Y
        self.Z = Z
        self.z = z
        self.r = r
        super().__init__()

    def violation(self, model, dataset, config_data, device="cpu", **kwargs):
        """Return the violation of P(X | Y, Z = z) = P(X | Z = z).

        `dataset`, `config_data` and `**kwargs` are accepted but not used
        here. The result is the value of `self.degree_violation` divided by
        ``r[X] * r[Y]``.
        """
        # Derive the width from the state counts rather than a notebook global.
        n_features = len(self.r)
        n_y = self.r[self.Y]

        # One probe row per state of Y, each with Z fixed to z; other columns stay 0
        # and are handed to the model as indices below.
        data = torch.zeros((n_y, n_features), device=device)
        data[:, self.Y] = torch.arange(int(n_y), device=device, dtype=data.dtype)
        data[:, self.Z] = self.z

        marg_indices = [i for i in range(n_features) if i not in (self.X, self.Y, self.Z)]
        p1 = predict_proba(model, self.r, data, self.X, marg_indices, device)  # P(X | Y, Z = z)
        p2 = predict_proba(model, self.r, data, self.X, marg_indices + [self.Y], device)  # P(X | Z = z)

        # delta / degree_violation are presumably inherited from EqualityConstraint
        # (not visible in this notebook).
        delta = self.delta(p1, p2)
        violation = self.degree_violation(delta)
        return violation / (self.r[self.X] * self.r[self.Y])
            
class InequalityConstraint(AbstractConstraint):
    """Base class for constraints of the form ``(a - b) * sign + epsilon <= 0``."""

    def __init__(self, sign, epsilon):
        """Store the direction `sign` and the margin `epsilon`."""
        super().__init__()
        self.sign = sign
        self.epsilon = epsilon

    def delta(self, output_1, output_2):
        """Return ``(output_1 - output_2) * sign + epsilon``; positive values are violations."""
        return torch.sub(output_1, output_2) * self.sign + self.epsilon

    def degree_violation(self, delta):
        """Return the sum of squared positive parts of `delta`.

        Uses `torch.clamp`, so the result lives on `delta`'s own device and no
        notebook-level `device` variable is needed.
        """
        return torch.sum(torch.clamp(delta, min=0.0) ** 2)

class MonotonicityConstraint(InequalityConstraint):
    """Penalty on how the conditional CDF of column Xi changes with column Xj.

    For every pair of Xj states ``hi > lo`` and every Xi state ``xi`` the term
    ``(cdf[hi, xi] - cdf[lo, xi]) * sign + epsilon`` is penalised when positive,
    where ``cdf[j, :]`` is the cumulative sum of ``predict_proba`` for Xi with
    Xj set to j.
    """

    def __init__(self, Xj, Xi, r, sign, epsilon):
        """Store the conditioning column Xj, the target column Xi and the state counts r."""
        super().__init__(sign, epsilon)
        self.Xj = Xj
        self.Xi = Xi
        self.r = r

    def violation(self, model, dataset, config_data, device="cpu", **kwargs):
        """Return the mean violation over all (xi, hi > lo) combinations.

        `dataset`, `config_data` and `**kwargs` are accepted but not used here.
        """
        # Use this constraint's own state counts instead of the notebook global `r`.
        n_features = len(self.r)
        marg_indices = [i for i in range(n_features) if i not in (self.Xi, self.Xj)]

        # One probe row per state of Xj; all other columns are passed as indices.
        data = torch.zeros((self.r[self.Xj], n_features), device=device)
        for state in range(self.r[self.Xj]):
            data[state, self.Xj] = state
        cdf = torch.cumsum(predict_proba(model, self.r, data, self.Xi, marg_indices, device=device), axis=1)

        # NOTE(review): the last CDF column is 1 for every row (softmax rows sum to 1),
        # so with epsilon > 0 its pairs each add a constant epsilon**2 — confirm intended.
        total = torch.tensor(0.0, device=device)
        count = 0
        for xi in range(self.r[self.Xi]):
            for hi in range(self.r[self.Xj]):
                for lo in range(hi):
                    # hi > lo
                    delta = self.delta(cdf[hi, xi], cdf[lo, xi])
                    total = total + self.degree_violation(delta)
                    count += 1

        if count == 0:
            # A single-state Xj has no pairs to compare; avoid 0/0 -> nan.
            return total
        return torch.div(total, count)

class FalsePositiveConstraint(InequalityConstraint):
    """Penalty on the probability of state 0 for rows whose binary target equals 1.

    For every row with ``target == 1`` the term ``P(target = 0 | row) - 0.5 +
    epsilon`` is penalised when positive.
    """

    def __init__(self, target, r, epsilon):
        """Store the target column and state counts; the target must be binary."""
        super().__init__(+1, epsilon)
        self.target = target
        self.r = r
        assert self.r[self.target] == 2

    def violation(self, model, dataset, config_data, device="cpu", batch_size=64, **kwargs):
        """Return the per-batch average of the violation over `dataset`.

        Args:
            model: Model handed to `predict_proba`.
            dataset: Dataset yielding 1-tuples of row tensors.
            config_data: Accepted but not used here.
            device: Device the batches are moved to.
            batch_size: Batch size of the internal DataLoader.
            **kwargs: Forwarded to `torch.utils.data.DataLoader`.
        """
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            **kwargs,
        )
        total = torch.tensor(0.0, device=device)
        count = 0
        for (data,) in dataloader:
            data = data.to(device)
            count += 1  # every batch counts towards the average, as before

            positives = data[data[:, self.target] == 1]
            if len(positives) == 0:
                # No positive rows: contributes zero violation, and the model is
                # not called on an empty batch.
                continue

            p = predict_proba(model, self.r, positives, target_index=self.target, device=device)
            delta = self.delta(p[:, 0], 0.5)
            total = total + self.degree_violation(delta)

        if count == 0:
            # Empty dataset: nothing to violate; avoid 0/0 -> nan.
            return total
        return torch.div(total, count)


In [66]:


def train(model, train_loader, constraints, iterations=1000, t_max=0, tol=1e-4, device='cpu'):
    """Fit `model` by minimising mean negative model output plus a staged constraint penalty.

    Each epoch sums the per-batch loss and penalty. Once the relative
    improvement of the summed loss drops below `tol`, training stops if the
    summed penalty is also below `tol`; otherwise the stage `t` is raised by
    one and the penalty is weighted by ``10**t`` from then on. Training also
    stops as soon as `t` reaches `t_max`.

    Args:
        model: Module called as ``model(batch)``.
        train_loader: Iterable of 1-tuples of batch tensors supporting ``len``;
            its ``.dataset`` is passed to the constraints.
        constraints: Objects exposing ``violation(model, dataset, config_data,
            device=..., batch_size=...)``.
        iterations: Maximum number of epochs. This argument used to be ignored
            in favour of a hard-coded 1000, so the default is now 1000 to keep
            existing calls unchanged.
        t_max: Stage at which training stops.
        tol: Threshold for both the relative loss change and the penalty.
        device: Device batches are moved to.

    Returns:
        The trained `model` (same object).
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    prev_loss, total_loss = 0, 0
    prev_penalty, total_penalty = 0, 0
    rel_change_loss = rel_change_penalty = None  # set from the second epoch on
    t = -1  # -1 means the penalty is not yet added to the loss
    config_data = {}

    for iteration in range(iterations):
        total_loss, total_penalty = 0, 0
        for (data,) in tqdm(train_loader, total = len(train_loader)):
            data = data.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = -outputs.sum() / data.shape[0]
            # Detach so the running totals do not keep every batch's graph alive.
            total_loss = total_loss + loss.detach()

            penalty = torch.tensor(0.0, device=device)
            for constraint in constraints:
                penalty = penalty + constraint.violation(
                    model, train_loader.dataset, config_data, device=device, batch_size=64
                )
            total_penalty = total_penalty + penalty.detach()

            if t >= 0:
                lambda_ = 10**t
                loss = loss + lambda_ * penalty

            loss.backward()
            optimizer.step()

        if iteration > 0:
            rel_change_loss = (prev_loss - total_loss) / prev_loss
            rel_change_penalty = (prev_penalty - total_penalty) / prev_penalty
            if rel_change_loss < tol:
                if total_penalty < tol:
                    break
                # Loss has plateaued but constraints are still violated: next stage.
                t = min(t + 1, t_max)
                if t == t_max:
                    break

            if iteration % 10 == 1:
                print (f"{t} {total_loss:.4f}, {rel_change_loss:.4f}, {total_penalty:.4f}, {rel_change_penalty:.4f}")
        prev_loss = total_loss
        prev_penalty = total_penalty

    # Final summary; skipped when fewer than two epochs ran (nothing to compare).
    if rel_change_loss is not None:
        print (f"{t} {total_loss:.4f}, {rel_change_loss:.4f}, {total_penalty:.4f}, {rel_change_penalty:.4f}")
    return model
            

In [5]:
# Show the number of discrete states (entries of r) next to each column name.
dict(zip(names, r))

{'age': 3,
 'sex': 2,
 'cp': 2,
 'trestbps': 3,
 'chol': 2,
 'fbs': 2,
 'restecg': 2,
 'thalach': 3,
 'exang': 2,
 'oldpeak': 3,
 'slope': 3,
 'unhealthy': 2}

In [6]:
# Check which column sits at index 11, the index used as the target in the constraints cell.
names[11]

'unhealthy'

In [82]:
# One monotonicity constraint per input column, each against column 11 ('unhealthy').
# Per the column listing above, indices 0, 1, 3, 4, 6 are age, sex, trestbps, chol, restecg.
# The commented-out entries are alternative constraints kept for experimentation.
constraints = [
    # FalsePositiveConstraint(11, r, 0.01),
    # ContextSpecificIndependence(2, 11, 4, 1, r),
    MonotonicityConstraint(i, 11, r, +1, 0.001)

    for i in (0, 1,3,4,6)
]

# Fresh model, then a baseline fit: with t_max=0, train() stops the first time it would
# switch the penalty on, so the penalty is tracked but never added to the loss.
model = make_spn(S=rat_S, I=rat_I, D=rat_D, R=rat_R, device=device, F=n_features, C=rat_C,leaf_base_class=leaves, leaf_base_kwargs=dict(num_bins=max(r)))
model = train(model, train_loader, constraints, t_max=0, device=device)

Using device: cuda


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.72it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.31it/s]


-1 63.4389, 0.0281, 0.0012, 0.3452


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.36it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.21it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.37it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.50it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.21it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.34it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.42it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.13it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.20it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.30it/s]


-1 51.1544, 0.0148, 0.0002, -0.1121


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.10it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.22it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.28it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.06it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.14it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.20it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.03it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.10it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.14it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.99it/s]


-1 45.0097, 0.0137, 0.0000, -0.0000


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.08it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.09it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.95it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.09it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.08it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.92it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.00it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.03it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.90it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.97it/s]


-1 42.0158, 0.0093, 0.0000, 0.0000


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.98it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.83it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.91it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.94it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.76it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.90it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.93it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.79it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.86it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.85it/s]


-1 40.3398, 0.0063, 0.0000, 0.0000


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.77it/s]

-1 40.4364, -0.0024, 0.0000, 0.0000





In [84]:
# Continue from the already-trained model (re-initialisation is left commented out).
# With t_max=10 the penalty stage t may now advance past 0, weighting the penalty by 10**t,
# until train() stops on convergence or when t reaches t_max.
# model = make_spn(S=rat_S, I=rat_I, D=rat_D, R=rat_R, device=device, F=n_features, C=rat_C,leaf_base_class=leaves, leaf_base_kwargs=dict(num_bins=max(r)))
model = train(model, train_loader, constraints, t_max=10, device=device)

100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 28.81it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.39it/s]


-1 40.2182, 0.0033, 0.0000, 0.0000


100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.39it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.25it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.46it/s]
100%|█████████████████████████████████████████████| 5/5 [00:00<00:00, 29.44it/s]

-1 39.5966, -0.0022, 0.0000, 0.0000





In [83]:
def get_outputs(data_loader, model, device="cpu"):
    """Run `model` over every batch of `data_loader` and concatenate the results.

    Note: this definition replaces the `get_outputs` imported from
    `constraint.constraints` at the top of the notebook.

    Args:
        data_loader: Iterable of 1-tuples of batch tensors.
        model: Callable applied to each batch after it is moved to `device`.
        device: Device the batches are moved to.

    Returns:
        The per-batch outputs concatenated along dimension 0, or `None` when
        `data_loader` yields no batches.
    """
    # Collect first and concatenate once; concatenating inside the loop
    # re-copies all earlier outputs on every batch.
    chunks = [model(batch.to(device)) for (batch,) in data_loader]
    if not chunks:
        return None
    return torch.cat(chunks)
    
def log_likelihood(data_loader, model, device="cpu"):
    """Sum the model's outputs over every batch of `data_loader`.

    The model is switched to eval mode (and left there) and evaluated under
    `torch.no_grad()`, so no autograd graph is built during evaluation.

    Args:
        data_loader: Iterable of 1-tuples of batch tensors.
        model: Module called as ``model(batch)``.
        device: Device the batches are moved to.

    Returns:
        The total as a Python float (0.0 for an empty loader).
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for (data,) in data_loader:
            total += model(data.to(device)).sum().item()
    return total

# Summed model output over the test loader for the current `model`.
log_likelihood(test_loader, model, device)

-1372.8946380615234

In [2]:
# Scratch check of slice semantics: [0:2] keeps the first two elements.
list(range(10))[0:2]

[0, 1]

In [85]:
# Re-evaluate the summed test-loader output with whatever `model` currently holds.
log_likelihood(test_loader, model, device)

-1366.8207397460938