# Classification benchmark: PLN-Tree

## Synthetic data

### Define base model

In [23]:
from plntree.models import PLNTreeConditional
from plntree.utils.classifiers import DenseClassifier, RNNClassifier
from plntree.utils.jupyter_functions import *
from plntree.data.utils import numpy_dataset_to_torch_dataloader
import torch
import torch.optim as optim

In [2]:
# Single seed shared by every stochastic step of the notebook.
seed = 42
# seed_all is a project helper that is not defined in this notebook; it is called
# with the seed before any random draw is made.
seed_all(seed)

# Common prefix for the artefacts written by this notebook (tree figure, pickled model).
prefix = f'synthetic_plntree_classifier_s{seed}'

In [3]:
# Model parameters
# Forwarded as `selected_layers` to every PLNTreeConditional built in this notebook;
# the first entry is also used below as an index into K.
selected_layers = [1, -1]

# Tree parameters
# Passed to generate_hierachical_tree in a later cell (presumably the number of nodes
# at each depth of the tree — TODO confirm against that helper).
K = [1, 5, 14, 35]
# Fixed mean vector perturbed by Gaussian noise of scale 0.1. The literal has 5 entries,
# so K[selected_layers[0]] must stay equal to 5 for the addition to be well defined.
mu_1 = torch.tensor([0., 1, -0.2, 2, 0]) + 0.1 * torch.randn(K[selected_layers[0]])
# Precision matrix generated from a seeded adjacency matrix of the same size.
# NOTE(review): Omega_1 is not referenced again in the visible part of the notebook.
Omega_1 = torch.tensor(artificial_loader.generate_precision_matrix(
    artificial_loader.generate_adjacency_matrix(K[selected_layers[0]], seed=seed), 
    conditioning=0.3, 
    correlation=0.4
))

# Options forwarded unchanged to the PLNTreeConditional constructor
# (as identifiable, diag_smoothing_factor and positive_fun respectively).
identifiable = True
diag_correction = 1e-3
positive_fun = 'softplus'

n_latent_layers = 3
    
# Only offset_method goes to the constructor; offset_constant is written into the
# ground-truth model after it has been built.
offset_method = 'constant'
offset_constant = 8.

In [4]:
# Build the tree shared by the ground-truth model and the estimator, seeded for
# reproducibility, then draw it without legend or title.
tree = generate_hierachical_tree(K, seed=seed)
tree.plot(legend=False, title='')
# The figure name carries the notebook prefix (and therefore the seed).
savefig(f'{prefix}_tree_graph')

In [5]:
# Classifier plugged into the ground-truth model. The class and its keyword arguments
# are kept in variables because the training function builds a fresh classifier from
# the same pair.
seed_all(seed)
classifier_type = DenseClassifier
classifier_params = dict(
    n_classes=2,
    input_size=K[-1],
    hidden_sizes=[64, 32],
    selected_layer=-1,
)
# The classifier's own seed is drawn from NumPy's global generator immediately after
# the seed_all call above.
classifier = classifier_type(seed=np.random.randint(1000), **classifier_params)

In [6]:
# Ground-truth ("base") model: the synthetic dataset is sampled from it in the next cell.
base = PLNTreeConditional(
        tree=tree,
        n_classes=2,
        classifier=classifier,
        identifiable=identifiable,
        diag_smoothing_factor=diag_correction,
        positive_fun=positive_fun,
        selected_layers=selected_layers,
        offset_method=offset_method,
        n_latent_layers=n_latent_layers,
        seed=seed,
    )

# Hand-tune some of the freshly initialised parameters. The indices address internals of
# PLNTreeConditional that are not visible here (presumably per-layer mean and precision
# networks — TODO confirm against the model class).
base.mu_fun[0].data = mu_1
base.omega_fun[1][0].weight.data *= 20
base.mu_fun[1][0].weight.data += 10
# Scale first, then shift: these two in-place updates do not commute.
base.mu_fun[2][2].weight.data *= 20
base.mu_fun[2][2].weight.data += 5
# The constructor only receives offset_method; the constant itself is injected here.
base.offset_constant.data = torch.tensor([offset_constant])

In [7]:
# Draw n seeded samples from the ground-truth model. Only X_base and Y_base are used in
# the rest of the visible notebook (Z_base / O_base are presumably the latent variables
# and offsets — TODO confirm against PLNTreeConditional.sample).
n = 2_000
X_base, Y_base, Z_base, O_base = base.sample(batch_size=n, seed=seed)
# Unshuffled dataloader over the whole dataset; it feeds the sample visualisation below.
dataloader = numpy_dataset_to_torch_dataloader(X_base, Y_base, batch_size=512, shuffle=False)
# Quick class-balance check.
print('Proportion of label 1:', Y_base.sum() / len(Y_base))

In [8]:
# Plot samples from the full dataset on the tree, for the layers used by the model.
vizualize_samples(dataloader, tree, base.selected_layers, autofill=True, seed=seed)
# NOTE(review): unlike the tree figure, this file name does not include `prefix`
# (and hence not the seed) — confirm that this is intentional.
savefig('synthetic_plntree_samples')

In [9]:
from sklearn.decomposition import PCA
import matplotlib.lines as mlines

# One panel per entry of base.K: a 2-component PCA of the log-counts of that layer,
# with one colour per label value.
fig, axs = plt.subplots(1, len(base.K), figsize=(18, 7))
colors = ['C0', 'C1']
for layer, K_l in enumerate(base.K):
    # Keep the first K_l entries of the layer; the small constant keeps the log finite
    # for zero counts.
    X_l = torch.log(X_base[:, layer, :K_l] + 1e-10)
    X_l_pca = PCA(n_components=2, random_state=seed).fit_transform(X_l)
    for k, c in enumerate(np.unique(Y_base)):
        # np.where returns a tuple of index arrays. Indexing with the tuple itself yields
        # (1, n) arrays, which `plot` interprets as n separate one-point datasets (one
        # Line2D per sample). Taking the first entry gives 1-D coordinates, i.e. a single
        # artist per class and panel.
        indexes = np.where(Y_base == c)[0]
        axs[layer].plot(X_l_pca[indexes, 0], X_l_pca[indexes, 1], marker='.', linestyle='', color=colors[k], alpha=0.5)
# Proxy handles so that the legend shows one opaque marker per label.
legend_handles = [
        mlines.Line2D([], [], marker='o', linestyle='', color=color, alpha=0.9, label=group)
        for color, group in zip(colors, ['label 0', 'label 1'])
    ]
legend = plt.legend(handles=legend_handles, fontsize="12")

In [19]:
def get_train_test(train_size=0.8, seed=None, return_indexes=False, return_dataloader=False):
    """Split the notebook-level dataset (``X_base``, ``Y_base``) into train and test parts.

    ``seed_all(seed)`` is called first, then the sample indices are shuffled with
    ``np.random.shuffle``; the first ``floor(len(X_base) * train_size)`` shuffled
    indices form the train part and the remaining ones the test part.

    Parameters
    ----------
    train_size : float
        Fraction of the samples assigned to the train part.
    seed
        Value forwarded to ``seed_all`` before shuffling.
    return_indexes : bool
        When true, the train and test index arrays are appended to the result.
    return_dataloader : bool
        When true, an unshuffled dataloader over the train part (batch size 512)
        is appended to the result.

    Returns
    -------
    list
        ``[X_train, Y_train, X_test, Y_test]``, followed by
        ``train_indexes, test_indexes`` if requested, followed by the train
        dataloader if requested.
    """
    seed_all(seed)
    shuffled = np.arange(len(X_base))
    np.random.shuffle(shuffled)

    # Cut position between the train and the test part.
    n_train = int(np.floor(len(shuffled) * train_size))
    train_indexes, test_indexes = shuffled[:n_train], shuffled[n_train:]

    X_train, X_test = X_base[train_indexes], X_base[test_indexes]
    Y_train, Y_test = Y_base[train_indexes], Y_base[test_indexes]

    output = [X_train, Y_train, X_test, Y_test]
    if return_indexes:
        output.extend([train_indexes, test_indexes])
    if return_dataloader:
        output.append(
            numpy_dataset_to_torch_dataloader(X_train, Y_train, batch_size=512, shuffle=False)
        )
    return output

### Learn the model

#### Mean-field

In [15]:
def learn_plntree_mean_field(dataloader, seed=seed, n_epoch=15_000):
    """Fit a PLN-Tree classifier model with the mean-field variational approximation.

    The estimator uses the same tree and model options as the ground-truth model of
    this notebook, with a freshly built classifier of type ``classifier_type``.

    Parameters
    ----------
    dataloader
        Training data handed to ``estimator.fit``.
    seed
        Seed of both the classifier and the estimator. The default is the value of
        the notebook-level ``seed`` at the time this function is defined.
    n_epoch : int
        Number of epochs passed to ``estimator.fit``.

    Returns
    -------
    Whatever ``estimator.fit`` returns; the notebook unpacks it as
    ``(model, losses)``.
    """
    variational_settings = {'n_variational_layers': 1, 'preprocessing': ['log']}
    model_settings = dict(
        tree=tree,
        n_classes=2,
        classifier=classifier_type(seed=seed, **classifier_params),
        identifiable=identifiable,
        diag_smoothing_factor=diag_correction,
        positive_fun=positive_fun,
        selected_layers=selected_layers,
        offset_method=offset_method,
        n_latent_layers=n_latent_layers,
        variational_approx='mean_field',
        variational_approx_params=variational_settings,
        seed=seed,
    )
    estimator = PLNTreeConditional(**model_settings)

    # Adam over every parameter of the estimator.
    optimizer = optim.Adam(estimator.parameters(), lr=1e-3)
    return estimator.fit(optimizer, dataloader, n_epoch=n_epoch, verbose=100, max_grad_norm=5.)

In [20]:
# Seeded 80/20 split. With return_dataloader=True (and return_indexes left False) the
# function returns five items, the last one being the train dataloader.
X_train, Y_train, X_test, Y_test, dataloader_train = get_train_test(train_size=0.8, seed=seed, return_dataloader=True)

In [24]:
# Reuse a previously saved fit when it can be loaded; otherwise train and cache it.
try:
    meanfield, meanfield_losses = load_pkl(prefix, 'mean_field')
except Exception:
    # Any ordinary failure to load triggers a fresh fit. Only Exception is caught, so
    # that KeyboardInterrupt / SystemExit raised while loading are propagated instead of
    # silently starting a long training run.
    print('Learning PLN-Tree (MF)')
    meanfield, meanfield_losses = learn_plntree_mean_field(dataloader_train, seed=seed)
    save_pkl((meanfield, meanfield_losses), prefix, 'mean_field')

In [26]:
# Training-loss curve of the mean-field fit, drawn with a logarithmic y-axis.
plt.plot(meanfield_losses)
plt.yscale('log')

In [27]:
# Class probabilities predicted by the fitted model on the held-out test inputs;
# as the cell's last expression, the result is shown as the cell output.
meanfield.predict_proba(X_test)