In [1]:
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, ConcatDataset, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import KBinsDiscretizer
from spn.experiments.RandomSPNs_layerwise.rat_spn import RatSpn, RatSpnConfig
from spn.experiments.RandomSPNs_layerwise.distributions import RatNormal
from spn.algorithms.layerwise.distributions import Bernoulli
from utils.datasets import gen_dataset
from utils.config_utils import load_config_data
from utils.utils import visualize_3d
from utils.selectors import get_sim_dataloader
from constraint.constraints import GeneralizationConstraint
import time
import argparse
from pathlib import Path

In [3]:
def make_spn(S, I, R, D, F, C, device, leaf_base_class) -> RatSpn:
    """Build a RAT-SPN from the given structure hyper-parameters.

    ``F``, ``R``, ``D``, ``I``, ``S``, ``C`` and ``leaf_base_class`` are copied
    onto a fresh ``RatSpnConfig`` under attributes of the same name; dropout is
    fixed at 0.0 and the leaf class receives no extra keyword arguments.
    The resulting model is moved to ``device``, put into training mode,
    and returned (after printing which device is used).
    """
    # Attribute assignments in the same order as before; the kwargs dict is
    # created anew on every call.
    settings = (
        ("F", F),
        ("R", R),
        ("D", D),
        ("I", I),
        ("S", S),
        ("C", C),
        ("dropout", 0.0),
        ("leaf_base_class", leaf_base_class),
        ("leaf_base_kwargs", {}),
    )
    cfg = RatSpnConfig()
    for attr_name, attr_value in settings:
        setattr(cfg, attr_name, attr_value)

    spn = RatSpn(cfg).to(device)
    spn.train()

    print("Using device:", device)
    return spn


def get_adult_loaders(use_cuda, batch_size, data_path=None):
    """Load the Adult training CSV, binarise it, and wrap a stratified split in DataLoaders.

    Every column with more than two distinct values is reduced to two ordinal
    bins (0/1) with a 2-bin k-means ``KBinsDiscretizer``. The bin edges are
    fitted on the training split only and then applied to the test split, so
    no test-set statistics leak into preprocessing.

    Args:
        use_cuda: when True the loaders use 8 worker processes and pinned memory.
        batch_size: batch size for both the train and the test loader.
        data_path: CSV location. Defaults to ``../data/Adult/train_0.csv``, built
            with pathlib so it resolves on every OS (a hard-coded backslash path
            only works on Windows).

    Returns:
        ``(train_loader, test_loader)``: DataLoaders over ``TensorDataset``s of
        float32 tensors holding every CSV column in file order. The train loader
        shuffles; the test loader does not.
    """
    if data_path is None:
        data_path = Path("..") / "data" / "Adult" / "train_0.csv"
    df = pd.read_csv(data_path).astype(float)

    # Split before fitting any preprocessing. Stratify on the (already binary)
    # label with a fixed seed; .copy() so later column assignment doesn't hit
    # pandas' chained-assignment warnings.
    train, test = train_test_split(df, stratify=df["income"], random_state=0)
    train, test = train.copy(), test.copy()

    for col in df.columns:
        if df[col].nunique() > 2:
            binner = KBinsDiscretizer(n_bins=2, encode="ordinal", strategy="kmeans")
            binner.fit(train[[col]].to_numpy())
            train[col] = binner.transform(train[[col]].to_numpy()).ravel().astype(int)
            test[col] = binner.transform(test[[col]].to_numpy()).ravel().astype(int)
    print(train.nunique())

    kwargs = {"num_workers": 8, "pin_memory": True} if use_cuda else {}

    def _to_loader(frame, shuffle):
        """Wrap a DataFrame as a float32 single-tensor DataLoader."""
        return torch.utils.data.DataLoader(
            TensorDataset(torch.Tensor(frame.to_numpy())),
            batch_size=batch_size,
            shuffle=shuffle,
            **kwargs,
        )

    return _to_loader(train, shuffle=True), _to_loader(test, shuffle=False)


In [47]:
device = "cpu"
dropout = 0  # NOTE(review): not used below — make_spn fixes config.dropout = 0.0 itself
train_loader, test_loader = get_adult_loaders(use_cuda=False, batch_size=32)

# RAT-SPN structure hyper-parameters, forwarded unchanged to make_spn / RatSpnConfig.
rat_S = 20
rat_I = 20
rat_D = 5
rat_R = 5
rat_C = 1
leaves = Bernoulli  # binary leaves for the binarised features; RatNormal is the continuous alternative

# Feature count = length of the first training sample (dataset items are 1-tuples).
n_features = len(train_loader.dataset[0][0])

model = make_spn(
    S=rat_S,
    I=rat_I,
    D=rat_D,
    R=rat_R,
    device=device,
    F=n_features,
    C=rat_C,
    leaf_base_class=leaves,
)

age                               2
education-num                     2
capital-gain                      2
capital-loss                      2
hours-per-week                    2
                                 ..
native-country_Trinadad&Tobago    2
native-country_United-States      2
native-country_Vietnam            2
native-country_Yugoslavia         2
income                            2
Length: 87, dtype: int64
Using device: cpu


In [53]:
from tqdm import tqdm
# Margin passed to `penalty`: label-1 rows are penalised once p(y=0 | x) exceeds 0.5 - epsilon.
epsilon = 1e-3

def penalty(model, data, epsilon, label_idx=None):
    """Squared-hinge penalty on p(y=0 | x) for rows whose label is 1.

    The label column is forced to 0 and the model is evaluated twice: on the
    full rows (giving log p(x, y=0)) and with the label column marginalised
    (giving log p(x)). Their difference is log p(y=0 | x). Each row contributes
    ``max(0, y * (p(y=0|x) - 0.5 + epsilon)) ** 2``, so rows labelled 0 add
    nothing and rows labelled 1 are pushed towards p(y=0|x) <= 0.5 - epsilon.

    Args:
        model: SPN callable. ``model(x)`` is used as per-row log-likelihoods and
            ``model(x, scopes)`` as the log-likelihood with the feature indices
            in ``scopes`` marginalised. NOTE(review): inferred from usage in
            this notebook — confirm against ``RatSpn.forward``.
        data: 2-D tensor (rows, columns); the label column is assumed to hold
            0/1 values. Not modified.
        epsilon: margin below the 0.5 decision threshold.
        label_idx: column index of the label (negative indices allowed).
            Defaults to the last column.

    Returns:
        Scalar tensor: sum of the squared hinge terms over the batch.
    """
    n_cols = data.shape[1]
    if label_idx is None:
        label_idx = n_cols - 1
    elif label_idx < 0:
        label_idx += n_cols

    y = data[:, label_idx].clone()
    data0 = data.clone()
    data0[:, label_idx] = 0
    log_joint = model(data0)
    # Marginalise the same column we read the label from, instead of a
    # hard-coded index that only matched an 87-column dataset.
    log_marginal = model(data0, (label_idx,))
    p0 = torch.exp(log_joint - log_marginal).ravel()
    delta = p0 - 0.5 + epsilon
    return torch.square(torch.clamp(y * delta, min=0)).sum()

# Penalty-method training: minimise mean NLL + lambda * penalty with lambda = 10**t.
# After each epoch, if the loss did not improve by more than REL_TOL (relative),
# either raise t (penalty non-zero and still shrinking) or stop.
from itertools import islice

MAX_EPOCHS = 10
BATCHES_PER_EPOCH = 10  # only the first 10 (shuffled) batches of each epoch are used
REL_TOL = 1e-4


def _rel_decrease(prev, curr):
    """Relative decrease (prev - curr) / prev; NaN instead of raising when prev is 0."""
    return (prev - curr) / prev if prev != 0 else float("nan")


model.train()
prev_loss, total_loss = 0.0, 0.0
prev_penalty, total_penalty = 0.0, 0.0
optimizer = optim.Adam(model.parameters(), lr=0.01)
t = 1
for iteration in range(MAX_EPOCHS):
    total_loss, total_penalty = 0.0, 0.0
    lambda_ = 10**t  # t only changes between epochs
    for (data,) in tqdm(islice(train_loader, BATCHES_PER_EPOCH)):
        data = data.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        zeta = penalty(model, data, epsilon)
        loss = -outputs.sum() / data.shape[0] + lambda_ * zeta
        loss.backward()
        optimizer.step()
        # Accumulate Python floats: summing the tensors themselves would keep
        # every batch's autograd graph alive until the end of the epoch.
        total_loss += loss.item()
        total_penalty += zeta.item()
    if iteration > 0:
        rel_change_loss = _rel_decrease(prev_loss, total_loss)
        # NaN when the previous penalty was exactly 0; NaN > REL_TOL is False,
        # so that case falls through to `break` below.
        rel_change_penalty = _rel_decrease(prev_penalty, total_penalty)
        print(f"{total_loss:.4f}, {rel_change_loss:.4f}, {total_penalty:.4f}, {rel_change_penalty:.4f}")
        # NOTE(review): signed test — an epoch whose loss *increased* also counts
        # as "converged". Since the loss includes lambda * zeta, raising t inflates
        # the next epoch's loss and tends to trigger another increase; consider
        # comparing the unpenalised NLL instead.
        if rel_change_loss < REL_TOL:
            if total_penalty != 0 and rel_change_penalty > REL_TOL:
                t = t + 1
            else:
                break
    prev_loss = total_loss
    prev_penalty = total_penalty
    

10it [00:09,  1.05it/s]
10it [00:09,  1.04it/s]


139.5193, 0.0101, 0.1640, 0.4375


10it [00:09,  1.04it/s]


142.4369, -0.0209, 0.0928, 0.4338


10it [00:09,  1.06it/s]


142.7810, -0.0024, 0.0523, 0.4364


10it [00:09,  1.05it/s]


182.8649, -0.2807, 0.0457, 0.1268


10it [00:09,  1.06it/s]


176.3221, 0.0358, 0.0038, 0.9175


10it [00:09,  1.06it/s]


156.2538, 0.1138, 0.0016, 0.5717


10it [00:09,  1.06it/s]


136.4110, 0.1270, 0.0000, 1.0000


10it [00:09,  1.07it/s]

143.5266, -0.0522, 0.0000, nan



