## Imports

In [7]:
import os
import sys
import scipy
import datetime

In [8]:
# Make the project's `src/` directory (which holds the CBO package) importable.
# The path is made absolute so a later change of working directory does not break
# imports, and the membership check keeps the cell idempotent: re-running it no
# longer appends a duplicate entry to `sys.path` every time.
_cbo_src_dir = os.path.abspath(os.path.join(os.pardir, 'src'))
if _cbo_src_dir not in sys.path:
    sys.path.append(_cbo_src_dir)

In [9]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt

In [10]:
# Handle to the Keras MNIST dataset module; `load_mnist_data` calls
# `mnist.load_data()` on it.
mnist = tf.keras.datasets.mnist

In [11]:
from tqdm.notebook import tqdm
from collections import defaultdict

In [28]:
from CBO.distributions import Normal, NumpyNormal
from CBO.functions import rastrigin, rastrigin_c, square
from CBO.minimize import minimize
from CBO.visualizations import visualize_trajectory_convergence
from CBO.train import train, NeuralNetworkObjectiveFunction, TensorboardLogging, compute_model_dimensionality, UpdatableTfModel

## Data preparation

In [13]:
def load_mnist_data():
    """Load MNIST via ``mnist.load_data()`` and scale the images into [0, 1].

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)``. Image arrays are divided
        by 255.0; label arrays are returned unchanged.
    """
    train_split, test_split = mnist.load_data()
    images_train, labels_train = train_split
    images_test, labels_test = test_split
    pixel_scale = 255.0
    return images_train / pixel_scale, images_test / pixel_scale, labels_train, labels_test

In [14]:
# Load the normalized MNIST splits once; later cells read these globals.
X_train, X_test, y_train, y_test = load_mnist_data()

## Model training

In [15]:
def build_default_model():
    """Build the reference MNIST classifier.

    Architecture: flatten the 28x28 input, one 128-unit ReLU hidden layer,
    then a 10-unit output layer with no activation (raw scores/logits).
    """
    layers = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ]
    return tf.keras.models.Sequential(layers)

In [16]:
def build_small_model():
    """Build the small model used for the CBO experiments.

    The network flattens the 28x28 input, applies a single ``Dense(10, relu)``
    layer and then a ``BatchNormalization`` layer without learnable shift or
    scale (``center=False, scale=False``). Presumably only the Dense kernel and
    bias are trainable -- confirm via ``model.trainable_weights``.

    NOTE(review): with ``momentum=1`` the Keras moving-average update
    (``moving = moving * momentum + batch_stat * (1 - momentum)``) should never
    move the moving mean/variance away from their initial values, so in
    inference mode the BN layer is a fixed rescale -- confirm this is intended.
    """
    flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    hidden = tf.keras.layers.Dense(10, activation='relu')
    normalize = tf.keras.layers.BatchNormalization(center=False, scale=False, momentum=1)
    return tf.keras.models.Sequential([flatten, hidden, normalize])

In [19]:
class MeanCrossEntropy():
    """Batch-mean sparse categorical cross-entropy on logits, divided by a constant.

    Args:
        divisor: the batch-mean loss is divided by this value. Defaults to 10,
            the constant previously hard-coded here.
            NOTE(review): the reason for this rescaling is not visible in this
            notebook (possibly to temper the CBO weights ``exp(-alpha * f)``) --
            confirm.
    """

    def __init__(self, divisor=10):
        self.divisor = divisor
        # Per-example losses (no reduction); the mean is taken in __call__.
        self._loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                                   reduction=tf.keras.losses.Reduction.NONE)

    def __call__(self, y_true, y_pred):
        """Return ``mean(CE(y_true, y_pred)) / divisor`` as a scalar tensor."""
        loss_value = self._loss(y_true, y_pred)
        return tf.reduce_mean(loss_value) / self.divisor


class CrossEntropy:
    """Cross-entropy on logits with integer labels, averaged over the batch.

    Args:
        n_classes: number of classes used to one-hot encode the labels.
    """

    def __init__(self, n_classes=10):
        self.n_classes = n_classes

    def __call__(self, y_true, y_pred):
        """Return the batch-mean cross-entropy as a scalar tensor.

        Args:
            y_true: integer class labels; presumably shape ``(batch,)`` -- confirm.
            y_pred: unnormalized logits; assumed ``(batch, n_classes)``.
        """
        # log_softmax is the numerically stable form of log(softmax(x)): the
        # naive form underflows to log(0) = -inf for very negative logits, and
        # 0 * -inf then turns the whole sum into NaN.
        log_probs = tf.nn.log_softmax(y_pred)
        # Match the one-hot dtype to the logits so float64 inputs also work.
        y_true_one_hot = tf.one_hot(y_true, self.n_classes, dtype=log_probs.dtype)
        # tf.shape also works when the static batch dimension is unknown
        # (y_true.shape[0] is None for a dynamic batch inside tf.function).
        batch_size = tf.cast(tf.shape(y_true)[0], log_probs.dtype)
        return tf.reduce_sum(-y_true_one_hot * log_probs) / batch_size

In [20]:
# Loss passed to the CBO objective / training calls below: sparse categorical
# cross-entropy on raw logits (integer labels), with the default reduction.
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [21]:
# Objective handed to `minimize` below: wraps the small model, the loss and the
# full training set. NOTE(review): presumably maps a flat parameter vector to the
# training loss -- the semantics live in CBO.train; confirm.
objective = NeuralNetworkObjectiveFunction(build_small_model(), loss, X_train, y_train)

In [16]:
def update_model_parameters(model, parameters):
    """Load a flat parameter vector into ``model.trainable_weights`` in place.

    Values are consumed in ``model.trainable_weights`` order. Each slice is
    reshaped to the weight's shape and cast to its dtype before assignment.

    Args:
        model: a built Keras model.
        parameters: array-like or tensor with the parameter values. It is
            flattened (row-major) first, so ``(d,)``, ``(d, 1)`` and ``(1, d)``
            inputs are all handled.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        ValueError: if ``parameters`` holds fewer values than the model has
            trainable parameters.
    """
    import warnings

    flat_parameters = tf.reshape(parameters, [-1])
    n_required = sum(weight.shape.num_elements() for weight in model.trainable_weights)
    n_given = int(tf.size(flat_parameters))
    if n_given < n_required:
        raise ValueError(f'Model has {n_required} trainable parameters, '
                         f'but only {n_given} values were given.')
    if n_given > n_required:
        # Extra trailing values are still ignored, but no longer silently: this
        # usually means the dimensionality was computed from a different set of
        # weights (e.g. one that includes non-trainable ones).
        warnings.warn(f'{n_given - n_required} trailing parameter values are ignored '
                      f'(model has {n_required} trainable parameters).')

    current_position = 0
    for weight in model.trainable_weights:
        next_position = current_position + weight.shape.num_elements()
        chunk = flat_parameters[current_position:next_position]
        # Cast so that float64 vectors (e.g. sampled with NumPy) can be assigned
        # to float32 variables.
        weight.assign(tf.cast(tf.reshape(chunk, weight.shape), weight.dtype))
        current_position = next_position
    return model

In [17]:
# ! rm -rf logs/fit
# tensorboard_logging = TensorboardLogging('cbo_small', 'logs/fit')

In [18]:
acc = tf.keras.metrics.SparseCategoricalAccuracy()


def a(x, y):
    """Negative accuracy of a single batch (lower is better).

    Args:
        x: integer class labels (``y_true``).
        y: model outputs (``y_pred``), one score per class.

    Keras metrics are stateful: calling the shared ``acc`` metric adds the batch
    to a running average over every earlier call, so ``-acc(x, y)`` returned the
    accumulated accuracy rather than this batch's. A fresh metric per call
    measures only the given batch.
    """
    return -tf.keras.metrics.SparseCategoricalAccuracy()(x, y)

In [20]:
# Train the small model with CBO: 100 particles, time_horizon=2000, dt=0.1,
# validated on the test split; option semantics live in CBO.train.
# NOTE: the saved run below was stopped manually (KeyboardInterrupt), so in the
# saved state `cbo_small_model` / `trajectory` were not produced by this cell.
cbo_small_model, trajectory = train(build_small_model(), loss, X_train, y_train, n_particles=100, 
                                    time_horizon=2000,
                                    optimizer_config = {'alpha': 50, 'sigma': 0.4**0.5, 'dt': 0.1}, 
                                    initial_distribution=NumpyNormal(),
                                    return_trajectory=True, verbose=True, particles_batches=10, 
                                    dataset_batches=1000, X_val=X_test, y_val=y_test, 
                                    tensorboard_logging=None, cooling=True,
                                    update_all_particles=True,
                                    evaluation_sample_size=1000,
                                    evaluation_rate=1)

0it [00:00, ?it/s]

Epoch 0, batch 3/1000, batch objective: 2.351, train accuracy: 0.08, val accuracy: 0.0661

KeyboardInterrupt: 

In [22]:
def conduct_tf_training(model, X_train, X_test, y_train, y_test, with_tensorboard=False,
                        model_description='', epoches=10):
    """Train ``model`` with Adam (lr=1e-3) on sparse cross-entropy over logits.

    Args:
        model: an uncompiled Keras model; it is compiled and fitted in place.
        X_train, y_train: training inputs and integer labels.
        X_test, y_test: validation inputs and labels.
        with_tensorboard: if True, log to TensorBoard under
            ``logs/fit/<model_description>``, or ``logs/fit/<timestamp>`` when
            no description is given.
        model_description: run name used as the log sub-directory.
        epoches: number of training epochs.

    Returns:
        The trained model (the same object that was passed in).
    """
    model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    callbacks = []
    if with_tensorboard:
        log_dir = 'logs/fit/' + model_description
        # `==`, not `is`: identity comparison with a string literal depends on
        # interning (and is a SyntaxWarning on Python 3.8+).
        if model_description == '':
            log_dir += datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1))

    model.fit(x=X_train, y=y_train, epochs=epoches, validation_data=(X_test, y_test),
              callbacks=callbacks)
    return model

In [23]:
# Adam baseline for the small model: 50 epochs, logged to TensorBoard under
# logs/fit/adam_small.
adam_small_model = conduct_tf_training(build_small_model(),
                                       X_train, X_test, y_train, y_test,
                                       with_tensorboard=True,
                                       model_description='adam_small', epoches=50)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50


Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [18]:
# Search-space dimensionality must match what `update_model_parameters` consumes,
# i.e. the *trainable* weights only. `get_weights()` also returns the
# BatchNormalization moving mean/variance (non-trainable), which added dimensions
# the model never reads.
# NOTE(review): confirm NeuralNetworkObjectiveFunction also maps parameters onto
# the trainable weights only.
n_params = sum(int(tf.size(w)) for w in build_small_model().trainable_weights)
minimizer, trajectory = minimize(objective, dimensionality=n_params, n_particles=100, time_horizon=2, 
                                  return_trajectory=True, optimizer_config = {
                                      'alpha': 50,
                                      'sigma': 0.4**0.5,
                                      'dt': 0.1,
                                  },
                                  initial_distribution=Normal(0, 1))
# Load the minimizer into a fresh model and report its train accuracy.
result_model = update_model_parameters(build_small_model(), tf.reshape(minimizer, -1))
y_pred = result_model.predict(X_train)
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
accuracy.update_state(y_train, y_pred)
accuracy.result().numpy()

  0%|          | 0/21 [00:00<?, ?it/s]

b 4.166293
a 3.8489268
3.8489268


  5%|▍         | 1/21 [01:51<37:03, 111.17s/it]

3.8489268
b 3.8489268
a 3.5423727


KeyboardInterrupt: 

## Results analysis

In [None]:
def calculate_cbo_weights(particles, objective, alpha=50):
    """Normalized CBO (Gibbs) weights ``exp(-alpha * f(x_i)) / sum_j exp(-alpha * f(x_j))``.

    The minimum objective value is subtracted before exponentiating. This
    leaves the normalized weights unchanged, but the best particle gets an
    unnormalized weight of exactly 1, so the sum cannot underflow to 0.

    Args:
        particles: iterable of particles, each passed to ``objective``.
        objective: callable returning the objective value of a particle.
        alpha: inverse temperature; larger values concentrate weight on the
            best particles.

    Returns:
        Column array of weights (``(n_particles, 1)`` for scalar objective
        values) that sums to 1.
    """
    losses = np.array([objective(candidate) for candidate in particles])
    best_loss = losses.min()
    unnormalized = np.exp(-alpha * (losses - best_loss))
    unnormalized = unnormalized.reshape(-1, 1)
    total = unnormalized.sum()
    return unnormalized / total

In [None]:
# Objective used by the analysis cells below (same construction as `objective`
# in the training section): training loss of the small model for given parameters.
nn_loss = NeuralNetworkObjectiveFunction(build_small_model(), loss, X_train, y_train)

In [None]:
# NOTE(review): this cell redefines `update_model_parameters` from the Model
# training section; keep the two definitions identical (or drop one) so the
# analysis cells behave exactly like the training cells.
def update_model_parameters(model, parameters):
    """Load a flat parameter vector into ``model.trainable_weights`` in place.

    Values are consumed in ``model.trainable_weights`` order. Each slice is
    reshaped to the weight's shape and cast to its dtype before assignment.

    Args:
        model: a built Keras model.
        parameters: array-like or tensor with the parameter values. It is
            flattened (row-major) first, so ``(d,)``, ``(d, 1)`` and ``(1, d)``
            inputs are all handled.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        ValueError: if ``parameters`` holds fewer values than the model has
            trainable parameters.
    """
    import warnings

    flat_parameters = tf.reshape(parameters, [-1])
    n_required = sum(weight.shape.num_elements() for weight in model.trainable_weights)
    n_given = int(tf.size(flat_parameters))
    if n_given < n_required:
        raise ValueError(f'Model has {n_required} trainable parameters, '
                         f'but only {n_given} values were given.')
    if n_given > n_required:
        # Extra trailing values are still ignored, but no longer silently: this
        # usually means the dimensionality was computed from a different set of
        # weights (e.g. one that includes non-trainable ones).
        warnings.warn(f'{n_given - n_required} trailing parameter values are ignored '
                      f'(model has {n_required} trainable parameters).')

    current_position = 0
    for weight in model.trainable_weights:
        next_position = current_position + weight.shape.num_elements()
        chunk = flat_parameters[current_position:next_position]
        # Cast so that float64 vectors (e.g. sampled with NumPy) can be assigned
        # to float32 variables.
        weight.assign(tf.cast(tf.reshape(chunk, weight.shape), weight.dtype))
        current_position = next_position
    return model

In [None]:
# Total number of trainable parameters of the small model, as a scalar tensor.
small_model_dim = tf.reduce_sum(
    [int(tf.size(trainable)) for trainable in build_small_model().trainable_weights]
)

In [None]:
# Trainable weights of the Adam-trained small model, flattened and concatenated
# in `trainable_weights` order (the order `update_model_parameters` consumes).
_adam_flat_pieces = [tf.reshape(trainable, [-1]) for trainable in adam_small_model.trainable_weights]
adam_small_model_weights = tf.concat(_adam_flat_pieces, axis=0)

In [None]:
# Robustness check: perturb the Adam solution ten times with Normal(0, 0.1) noise
# and report the train accuracy of each perturbed model.
model = build_small_model()

for i in range(10):
    weights = adam_small_model_weights + Normal(0, 0.1).sample(small_model_dim)
    model = update_model_parameters(model, weights)
    # A dedicated name instead of rebinding the global `acc` defined earlier in
    # the notebook, so this loop leaves that metric untouched.
    perturbed_acc = tf.keras.metrics.SparseCategoricalAccuracy()
    perturbed_acc.update_state(y_train, model.predict(X_train))
    print(f'Model {i} train accuracy: {perturbed_acc.result().numpy()}')

In [None]:
def plot_particle_loss(trajectory, particle_ind, nn_loss):
    """Plot the loss of a single particle across all recorded timestamps.

    Args:
        trajectory: dict mapping timestamp -> record with a ``'particles'`` entry.
        particle_ind: index of the particle to follow.
        nn_loss: callable mapping a particle's parameters to its loss.
    """
    timestamps = sorted(trajectory.keys())
    losses = [nn_loss(trajectory[ts]['particles'][particle_ind]) for ts in timestamps]
    plt.clf()
    plt.plot(timestamps, losses)
    plt.xlabel('Timestamp')
    plt.ylabel('Loss')
    plt.title(f'Loss of particle {particle_ind}')
    plt.show()

In [None]:
# Loss curve of particle #1 over the CBO trajectory.
plot_particle_loss(trajectory, 1, nn_loss)

In [None]:
# Stateful accuracy metric: every `update_state` call adds to a running average
# until the metric is recreated.
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

In [None]:
# CBO weights of the particles recorded at timestamp 0 under the training loss.
weights = calculate_cbo_weights(trajectory[0]['particles'], nn_loss)

In [None]:
# The particle with the largest CBO weight (i.e. the lowest loss) at timestamp 0.
best_particle_ind = int(np.argmax(weights))
best_particle = trajectory[0]['particles'][best_particle_ind]
best_model = update_model_parameters(build_small_model(), best_particle)
# NOTE(review): plot_model_predictive_distribution is defined in a later cell, so
# on a fresh kernel (Restart & Run All) this line raises NameError -- move the
# definition above this cell.
plot_model_predictive_distribution(build_small_model(), best_particle, X_train)
# A fresh metric so that re-running this cell does not average in earlier results.
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
accuracy.update_state(y_train, best_model.predict(X_train))
print(f'Train accuracy: {accuracy.result().numpy()}')

In [None]:
# Per-row L2 distance moved between timestamps 0 and 0.1 (one step with dt=0.1);
# assumes `particles` is (n_particles, dim), so each row is one particle -- confirm.
tf.norm(trajectory[0.1]['particles'] - trajectory[0]['particles'], ord=2, axis=1)

In [None]:
# Squared distance of the best particle from its timestamp-0 position over time.
plt.clf()
timestamps = list(sorted(trajectory.keys()))
distances = [tf.reduce_sum((best_particle - trajectory[ts]['particles'][best_particle_ind]) ** 2)
             for ts in timestamps]
plt.plot(timestamps, distances)
plt.xlabel('Timestamp')
plt.ylabel('Squared distance to t=0 position')
plt.title(f'Displacement of particle {best_particle_ind}')
plt.show()

In [None]:
# Consensus point at every recorded timestamp, in ascending timestamp order.
consensus_postions = [trajectory[ts]['consensus'] for ts in sorted(trajectory.keys())]

In [None]:
# Track one coordinate ([0][1]) of the consensus point over time. The timestamps
# are recomputed here, in the same sorted order used to build
# `consensus_postions`, instead of relying on a `timestamps` variable left over
# from another cell.
consensus_timestamps = sorted(trajectory.keys())
plt.clf()
plt.plot(consensus_timestamps, [consensus[0][1] for consensus in consensus_postions])
plt.xlabel('Timestamp')
plt.ylabel('Consensus coordinate [0][1]')
plt.show()

In [None]:
# All trainable parameter values of the Adam-trained small model as one flat NumPy array.
all_weights = np.concatenate([tf.reshape(w, -1).numpy() for w in adam_small_model.trainable_weights])

In [None]:
# Distribution of the Adam-trained small model's trainable parameter values.
plt.clf()
plt.hist(all_weights, bins=100)
plt.xlabel('Weight value')
plt.ylabel('Count')
plt.title('Adam small model: trainable weight distribution')
plt.show()

In [None]:
def get_model_predictions(model, parameters, X):
    """Load ``parameters`` into ``model`` and return the arg-max class for each input.

    ``model``'s trainable weights are overwritten in place via
    ``update_model_parameters``.

    Returns:
        NumPy array with the index of the highest score along axis 1 for every
        row of the model output.
    """
    loaded_model = update_model_parameters(model, tf.reshape(parameters, -1))
    scores = loaded_model(X).numpy()
    return scores.argmax(axis=1)

def get_predictions_distribution(values):
    """Relative frequency of each distinct value in ``values``.

    Frequencies are ordered by ascending value. Values that never occur get no
    entry, so the result is not padded to a fixed number of classes.
    """
    counts = np.unique(values, return_counts=True)[1]
    total = counts.sum()
    return counts / total

def plot_model_predictive_distribution(model, parameters, X, title='', n_classes=10):
    """Bar-plot the share of inputs ``X`` assigned to each predicted class.

    ``parameters`` are first loaded into ``model`` (in place) via
    ``get_model_predictions``.

    Args:
        model: Keras model to evaluate.
        parameters: flat parameter vector for ``model``.
        X: inputs to classify.
        title: plot title.
        n_classes: number of classes shown on the x axis (10 for MNIST).
    """
    predictions, counts = np.unique(get_model_predictions(model, parameters, X),
                                    return_counts=True)
    plt.clf()
    # Bars are centred on the class index (plt.bar's default alignment), so each
    # tick label sits under its own bar. Drawing them at k + 0.5 put the ticks on
    # bar edges instead.
    plt.bar(x=predictions, height=1. * counts / counts.sum(), width=1)
    plt.xticks(range(n_classes))
    plt.xlim((-0.5, n_classes - 0.5))
    plt.title(title)
    plt.xlabel('prediction')
    plt.ylabel('density')
    plt.show()

In [None]:
# Predictive class distribution on the test set of the consensus point at the
# last recorded timestamp.
last_timestamp = list(sorted(trajectory.keys()))[-1]
plot_model_predictive_distribution(build_small_model(),
                                   trajectory[last_timestamp]['consensus'],
                                   X_test, 'Result distribution')

In [None]:
# Global default size for all figures created after this cell.
plt.rcParams['figure.figsize'] = (10, 10)

In [None]:
def plot_particles_entropy(model, trajectoty, X, n_particles=10, logarithmic=False):
    """Plot the entropy of the predicted-class distribution of random particles over time.

    ``n_particles`` particles are sampled without replacement. At each
    timestamp, a particle's parameters are loaded into ``model``, and the
    entropy of the empirical distribution of its predicted classes on ``X`` is
    computed.

    Args:
        model: Keras model whose trainable weights are overwritten per particle.
        trajectoty: dict mapping timestamp -> record with a ``'particles'``
            entry. The misspelled name is kept for keyword-argument compatibility.
        X: inputs to classify.
        n_particles: number of particles to track; capped at the number available.
        logarithmic: plot ``log(entropy)`` instead of the entropy (zero entropy
            becomes -inf).
    """
    import scipy.stats  # `import scipy` alone does not guarantee the submodule is loaded

    timestamps = list(sorted(trajectoty.keys()))
    # Read the particle count from the argument, not from the global `trajectory`.
    overall_particles = len(trajectoty[timestamps[0]]['particles'])
    n_sampled = min(n_particles, overall_particles)
    particles = list(sorted(np.random.choice(overall_particles, n_sampled, replace=False)))

    entropy = defaultdict(list)
    for timestamp in tqdm(timestamps):
        for particle in particles:
            predictions = get_model_predictions(model, trajectoty[timestamp]['particles'][particle], X)
            entropy[particle].append(scipy.stats.entropy(get_predictions_distribution(predictions)))

    # A per-figure size instead of mutating plt.rcParams. The global setting
    # leaked into every later figure, and it only affects figures created
    # afterwards anyway.
    plt.figure(figsize=(15, 15))
    for particle, particle_entropy in entropy.items():
        plt.plot(timestamps, np.log(particle_entropy) if logarithmic else particle_entropy,
                 label=f'particle {particle}')
    plt.xlabel('Timestamp')
    plt.ylabel('Predictive distribution entropy')
    plt.legend()
    plt.show()

In [None]:
# Predictive-entropy curves for 10 randomly chosen particles on the training set.
plot_particles_entropy(build_small_model(), trajectory, X_train)

In [None]:
def visualize_particle_path_1d(trajectory, particle_ind, projection_dimenssion):
    """Plot one coordinate of one particle against time.

    Args:
        trajectory: dict mapping timestamp -> record with a ``'particles'`` entry.
        particle_ind: index of the particle to follow.
        projection_dimenssion: index of the coordinate to plot.
    """
    ordered_ts = sorted(trajectory)
    coordinate = [trajectory[ts]['particles'][particle_ind][projection_dimenssion]
                  for ts in ordered_ts]
    plt.clf()
    plt.plot(ordered_ts, coordinate)
    plt.xlabel('Timestamp')
    plt.ylabel('Particle position')
    plt.show()
    
def visalize_particles_shift(trajectory):
    """Plot how far the whole particle ensemble moves between consecutive timestamps.

    For each pair of consecutive timestamps ``(t_k, t_{k+1})``, the shift is
    ``tf.norm(particles(t_k) - particles(t_{k+1}), ord='euclidean')`` over the
    full particle array. It is plotted at ``t_{k+1}``.
    """
    ordered_ts = sorted(trajectory)
    shifts = [
        tf.norm(trajectory[prev_ts]['particles'] - trajectory[next_ts]['particles'],
                ord='euclidean')
        for prev_ts, next_ts in zip(ordered_ts, ordered_ts[1:])
    ]
    plt.clf()
    plt.plot(ordered_ts[1:], shifts)
    plt.xlabel('Timestamp')
    plt.ylabel('Shift')
    plt.show()
    
def visalize_particles_std(trajectory, logarithmic=False):
    """Plot the standard deviation of all particle values (pooled) over time.

    Args:
        trajectory: dict mapping timestamp -> record with a ``'particles'`` entry.
        logarithmic: plot ``log(std)`` instead of the raw standard deviation.
    """
    ordered_ts = sorted(trajectory)
    stds = [tf.math.reduce_std(trajectory[ts]['particles']) for ts in ordered_ts]
    series = np.log(stds) if logarithmic else stds
    plt.clf()
    plt.plot(ordered_ts, series)
    plt.xlabel('Timestamp')
    plt.ylabel('Std')
    plt.show()

# NOTE: contains accuracies computed on different batches!
def visualize_cbo_accuracy(trajectory):
    """Plot the value stored under ``trajectory[ts]['accuracy']`` against time."""
    ordered_ts = sorted(trajectory)
    plt.clf()
    plt.plot(ordered_ts, [trajectory[ts]['accuracy'] for ts in ordered_ts])
    plt.xlabel('Timestamp')
    plt.ylabel('Accuracy')
    plt.show()

In [None]:
# Coordinate 4 of the best initial particle over time.
visualize_particle_path_1d(trajectory, best_particle_ind, 4)

In [None]:
# Ensemble movement between consecutive timestamps.
visalize_particles_shift(trajectory)

In [None]:
# Spread (pooled std) of the particle ensemble over time.
visalize_particles_std(trajectory)

In [None]:
# Recorded accuracy over time (values come from different batches).
visualize_cbo_accuracy(trajectory)

## Random initialization leads to unbalanced predictions of the initial model!

In [None]:
# Sample random Normal(0, 1) parameters for a fresh small model and look at the
# class distribution of its predictions on the training set.
model = build_small_model()
initial_weights = tf.Variable(Normal(0, 1).sample(compute_model_dimensionality(model)))
model = update_model_parameters(model, initial_weights)
# The arg-max of the logits equals the arg-max of their softmax, so the softmax
# is skipped.
y_pred = np.argmax(model(X_train), axis=1)
plt.clf()
# One bin per class, centred on the integer labels. The default bins=10 span
# [min, max] of the observed labels and misalign with the classes when some
# classes are never predicted.
plt.hist(y_pred, bins=np.arange(11) - 0.5)
plt.xticks(range(10))
plt.xlabel('Predicted class')
plt.ylabel('Count')
plt.show()