In [1]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans2
from dgp_aepmcm.gp_network import DGPNetwork
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
np.random.seed(5)



In [2]:
data = load_iris()

In [3]:
x = data["data"]
y = data["target"]

In [4]:
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.33)
X_train = X_train.astype(np.float64)
X_test = X_test.astype(np.float64)


In [5]:
# DGP model variables
# Number of inducing points
M = 50
D = X_train.shape[-1]
# Maximun of epochs for training
max_epochs = 500
learning_rate = 0.01
minibatch_size = 100
n_samples_each_point = 10
n_samples = 20
# Inducing points locations
Z = kmeans2(X_train, M, minit="points")[0]
noise_val = 1e-5

In [7]:
# Instantiate and train DGP-AEPMCM. with L=3 and 3 GP nodes for the first two layers
model_aepmcm = DGPNetwork(X_train, y_train, inducing_points=Z, show_debug_info=True, minibatch_size=minibatch_size, dtype=np.float64)

model_aepmcm.add_input_layer()
# This method always assume a mean function for the prior p(u) = N(u| m(x), Kzz)
# with m(x) = X W
# For this example we disable the mean function for the prior so we set W to 0.
model_aepmcm.add_gp_layer(M, 3, W=np.zeros((D, 3))) # W=np.zeros((D, 3)))
#model_aepmcm.add_noise_layer(noise_val)
model_aepmcm.add_gp_layer(M, 3, W=np.zeros((3, 3)))
#model_aepmcm.add_noise_layer(noise_val)
model_aepmcm.add_gp_layer(M, 3, W=np.zeros((3, 3)))
#model_aepmcm.add_noise_layer(noise_val)
model_aepmcm.add_output_layer_multiclass_classification(noise_in_labels=True, noise_in_labels_trainable=False)

model_aepmcm.train_via_adam(
    max_epochs=max_epochs,    
    learning_rate=learning_rate,
)

labels_aepmcm, probs_aepmcm = model_aepmcm.predict(X_test)


Creating DGP network for classification problem with 3 classes
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.
Compiling adam updates
Initializing network
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
This op will be removed after the deprecation date. Please switch to tf.sets.difference().
Training for 500 epochs, 500 iterations
Epoch: 0   | Energy: 1.643345    | Time:  10.5248s | Memory: 0.74 GB | ETA: -
Epoch: 1   | Energy: 1.641309    | Time:   0.2831s | Memory: 0.74 GB | ETA: -
Epoch: 2   | Energy: 1.623860    | Time:   0.3496s | Memory

Epoch: 77  | Energy: 0.490352    | Time:   0.1969s | Memory: 0.73 GB | ETA: 1min, 35s
Epoch: 78  | Energy: 0.487322    | Time:   0.2064s | Memory: 0.70 GB | ETA: 1min, 34s
Epoch: 79  | Energy: 0.482358    | Time:   0.2042s | Memory: 0.72 GB | ETA: 1min, 33s
Epoch: 80  | Energy: 0.479636    | Time:   0.1918s | Memory: 0.72 GB | ETA: 1min, 30s
Epoch: 81  | Energy: 0.482465    | Time:   0.1785s | Memory: 0.70 GB | ETA: 1min, 28s
Epoch: 82  | Energy: 0.475339    | Time:   0.2184s | Memory: 0.72 GB | ETA: 1min, 27s
Epoch: 83  | Energy: 0.478015    | Time:   0.1922s | Memory: 0.72 GB | ETA: 1min, 25s
Epoch: 84  | Energy: 0.474951    | Time:   0.2227s | Memory: 0.75 GB | ETA: 1min, 25s
Epoch: 85  | Energy: 0.472796    | Time:   0.1968s | Memory: 0.76 GB | ETA: 1min, 25s
Epoch: 86  | Energy: 0.473896    | Time:   0.1860s | Memory: 0.71 GB | ETA: 1min, 25s
Epoch: 87  | Energy: 0.473271    | Time:   0.2236s | Memory: 0.72 GB | ETA: 1min, 25s
Epoch: 88  | Energy: 0.469618    | Time:   0.1940s | M

Epoch: 173 | Energy: 0.396507    | Time:   0.2647s | Memory: 0.70 GB | ETA: 1min, 13s
Epoch: 174 | Energy: 0.384670    | Time:   0.2534s | Memory: 0.73 GB | ETA: 1min, 14s
Epoch: 175 | Energy: 0.394060    | Time:   0.2228s | Memory: 0.72 GB | ETA: 1min, 13s
Epoch: 176 | Energy: 0.385912    | Time:   0.1947s | Memory: 0.73 GB | ETA: 1min, 12s
Epoch: 177 | Energy: 0.380762    | Time:   0.1830s | Memory: 0.71 GB | ETA: 1min, 11s
Epoch: 178 | Energy: 0.387810    | Time:   0.1841s | Memory: 0.76 GB | ETA: 1min, 10s
Epoch: 179 | Energy: 0.379639    | Time:   0.1902s | Memory: 0.74 GB | ETA: 1min, 10s
Epoch: 180 | Energy: 0.383892    | Time:   0.2616s | Memory: 0.72 GB | ETA: 1min, 10s
Epoch: 181 | Energy: 0.389098    | Time:   0.2957s | Memory: 0.73 GB | ETA: 1min, 11s
Epoch: 182 | Energy: 0.389105    | Time:   0.1886s | Memory: 0.75 GB | ETA: 1min, 10s
Epoch: 183 | Energy: 0.385606    | Time:   0.2141s | Memory: 0.75 GB | ETA: 1min, 10s
Epoch: 184 | Energy: 0.379968    | Time:   0.2065s | M

Epoch: 273 | Energy: 0.335492    | Time:   0.1898s | Memory: 0.71 GB | ETA: 44s
Epoch: 274 | Energy: 0.346141    | Time:   0.1868s | Memory: 0.73 GB | ETA: 44s
Epoch: 275 | Energy: 0.338468    | Time:   0.1873s | Memory: 0.72 GB | ETA: 44s
Epoch: 276 | Energy: 0.340100    | Time:   0.2141s | Memory: 0.74 GB | ETA: 44s
Epoch: 277 | Energy: 0.338005    | Time:   0.1901s | Memory: 0.70 GB | ETA: 43s
Epoch: 278 | Energy: 0.338368    | Time:   0.1806s | Memory: 0.74 GB | ETA: 43s
Epoch: 279 | Energy: 0.338778    | Time:   0.1774s | Memory: 0.72 GB | ETA: 42s
Epoch: 280 | Energy: 0.331265    | Time:   0.1833s | Memory: 0.74 GB | ETA: 42s
Epoch: 281 | Energy: 0.338987    | Time:   0.1843s | Memory: 0.74 GB | ETA: 42s
Epoch: 282 | Energy: 0.339633    | Time:   0.1844s | Memory: 0.74 GB | ETA: 41s
Epoch: 283 | Energy: 0.336062    | Time:   0.2249s | Memory: 0.74 GB | ETA: 41s
Epoch: 284 | Energy: 0.336457    | Time:   0.2247s | Memory: 0.76 GB | ETA: 42s
Epoch: 285 | Energy: 0.332165    | Time:

Epoch: 376 | Energy: 0.309678    | Time:   0.1955s | Memory: 0.73 GB | ETA: 24s
Epoch: 377 | Energy: 0.306977    | Time:   0.2058s | Memory: 0.76 GB | ETA: 24s
Epoch: 378 | Energy: 0.308320    | Time:   0.2011s | Memory: 0.72 GB | ETA: 24s
Epoch: 379 | Energy: 0.306247    | Time:   0.1909s | Memory: 0.75 GB | ETA: 24s
Epoch: 380 | Energy: 0.328535    | Time:   0.1880s | Memory: 0.71 GB | ETA: 24s
Epoch: 381 | Energy: 0.315930    | Time:   0.1751s | Memory: 0.72 GB | ETA: 23s
Epoch: 382 | Energy: 0.315468    | Time:   0.1849s | Memory: 0.74 GB | ETA: 23s
Epoch: 383 | Energy: 0.311088    | Time:   0.1907s | Memory: 0.76 GB | ETA: 23s
Epoch: 384 | Energy: 0.317584    | Time:   0.1791s | Memory: 0.73 GB | ETA: 23s
Epoch: 385 | Energy: 0.314600    | Time:   0.1927s | Memory: 0.74 GB | ETA: 22s
Epoch: 386 | Energy: 0.317792    | Time:   0.1841s | Memory: 0.74 GB | ETA: 22s
Epoch: 387 | Energy: 0.312282    | Time:   0.1873s | Memory: 0.75 GB | ETA: 22s
Epoch: 388 | Energy: 0.315003    | Time:

Epoch: 479 | Energy: 0.302627    | Time:   0.1942s | Memory: 0.72 GB | ETA: 4s
Epoch: 480 | Energy: 0.300365    | Time:   0.1911s | Memory: 0.72 GB | ETA: 4s
Epoch: 481 | Energy: 0.304321    | Time:   0.1770s | Memory: 0.72 GB | ETA: 4s
Epoch: 482 | Energy: 0.301438    | Time:   0.1935s | Memory: 0.70 GB | ETA: 3s
Epoch: 483 | Energy: 0.311208    | Time:   0.1790s | Memory: 0.71 GB | ETA: 3s
Epoch: 484 | Energy: 0.295479    | Time:   0.2141s | Memory: 0.71 GB | ETA: 3s
Epoch: 485 | Energy: 0.306982    | Time:   0.1831s | Memory: 0.73 GB | ETA: 3s
Epoch: 486 | Energy: 0.311384    | Time:   0.1911s | Memory: 0.70 GB | ETA: 3s
Epoch: 487 | Energy: 0.294561    | Time:   0.1800s | Memory: 0.74 GB | ETA: 2s
Epoch: 488 | Energy: 0.302070    | Time:   0.2107s | Memory: 0.71 GB | ETA: 2s
Epoch: 489 | Energy: 0.304829    | Time:   0.2571s | Memory: 0.72 GB | ETA: 2s
Epoch: 490 | Energy: 0.309453    | Time:   0.3189s | Memory: 0.73 GB | ETA: 2s
Epoch: 491 | Energy: 0.307576    | Time:   0.2914s |

In [10]:
acc_dgp = accuracy_score(y_test, labels_aepmcm)
ll = model_aepmcm.calculate_log_likelihood(X_test, y_test[:, None])
print(f"Accuracy: {acc_dgp}")
print(f"Test log-likelihood: {ll}")


Accuracy: 0.94
Test log-likelihood: -0.1307495151667313
