# Disentangling Manifolds with Categorical Factors

## Duet

In this case study the manifold consists of two poles, i.e. two isolated dots, which gives rise to a problem similar to binary classification. Points on the same pole are all maximally close, while points on separate poles are maximally distant. As the scatter plots show, the model places all points from the right pole into the lower half-plane and all points from the left pole into the upper half-plane, so the separation works. The layer-wise contributions show that the linear and non-linear parts partly counteract each other, which suggests that a model with fewer layers might already suffice for this task. As the cross-validation results show, different initializations of the model lead to a working separation, which underlines the reliability of the model.
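To make the two-pole geometry concrete, the sketch below evaluates the Duet manifold function (`f_6`, copied from the Data Synthesis cell) on a few evenly spaced positions and compares Euclidean distances between the resulting noise-free points. The choice of positions in [-1, 1] is an assumption for illustration, and the noise that `create_data_set` adds is left out.

```python
import numpy as np

# Duet manifold function (f_6) from the Data Synthesis cell:
# positions S <= 0 map to (-5, -5), positions S > 0 map to (5, 5)
def manifold_function(S):
    return (5*(2*(0 < S) - 1), 5*(2*(0 < S) - 1))

# Assumption: eight evenly spaced positions in [-1, 1] (four negative, four positive)
S = np.linspace(-1, 1, 8)
x, y = manifold_function(S)
points = np.stack([x, y], axis=1)  # shape [8, 2]

# Pairwise Euclidean distances between the noise-free manifold points
distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
same_pole = (S[:, None] > 0) == (S[None, :] > 0)

print(np.unique(points, axis=0))  # the two poles
print(distances[same_pole].max(), distances[~same_pole].min())  # 0.0 and about 14.14 (= sqrt(200))
```

Every same-pole pair has distance zero and every cross-pole pair has the same maximal distance, which is why the task behaves like a two-class problem.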

### Data Synthesis

# Set up the raw manifold function
manifold_name = 'f_6'
manifold_function = lambda S: (5*(2*(0 < S) - 1), 5*(2*(0 < S) - 1))


# Generate a dataset
reset_random_number_generators(seed=856) # Reproducibility
M = 2**14
S = np.random.uniform(low=-1, high=1, size=[M]) # Position factor, sampled as in the Quartet study
noise_standard_deviation = [1.0, 1.0]
target_correlations = [0, 0.9]
Z, Y = create_data_set(S=S, manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation) # Z.shape == [M, N], Y.shape == [M, F]

# Plot pairs of instances
batch_size = M//8
iterator = mdis.volatile_factorized_pair_iterator(X=Z, Y=Y, batch_size=batch_size, target_correlations=target_correlations)
Z_ab, Y_ab = next(iterator)

plot_instance_pairs(S=S, Z_ab=Z_ab, Y_ab=Y_ab, manifold_function=manifold_function, manifold_name=manifold_name)

### Model Creation

# Create network
reset_random_number_generators(seed=286) # Reproducibility
network = create_network(stage_count=3, sigma=1.0)

# Plot input and output
plot_input_output(network, S=S, manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation, manifold_name=manifold_name);

### Model Calibration

# Calibrate network
network.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
batch_size = 1024
epoch_loss_means, epoch_loss_standard_deviations = network.fit(iterator=iterator, epoch_count=10, batch_count=M//batch_size)
plot_loss_trajectory(epoch_loss_means=epoch_loss_means, epoch_loss_standard_deviations=epoch_loss_standard_deviations, manifold_name=manifold_name)
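The calibration cell reassigns `batch_size = 1024`, but `iterator` was created in the Data Synthesis cell with `batch_size = M//8`, and reassigning the Python variable does not change the iterator. The short computation below works out, from the values used in this notebook, how many instance pairs one calibration epoch draws in each case study, assuming that each `next(iterator)` call yields one batch of the size the iterator was created with.

```python
# Duet: iterator created with batch_size = M//8, calibration uses M//1024 batches
M_duet = 2**14
iterator_batch_duet = M_duet // 8        # 2048 pairs per batch
batch_count_duet = M_duet // 1024        # 16 batches per epoch
pairs_per_epoch_duet = iterator_batch_duet * batch_count_duet

# Quartet: batch_size is not reassigned, so batch_count = M // (M//8) = 8
M_quartet = 2**13
iterator_batch_quartet = M_quartet // 8  # 1024 pairs per batch
batch_count_quartet = M_quartet // iterator_batch_quartet
pairs_per_epoch_quartet = iterator_batch_quartet * batch_count_quartet

# Pairs drawn per epoch, relative to the number of instances M
print(pairs_per_epoch_duet / M_duet, pairs_per_epoch_quartet / M_quartet)  # 2.0 1.0
```

Under this assumption a Duet epoch draws 2M pairs, whereas a Quartet epoch draws exactly M pairs.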

# Plot input and output
plot_input_output(network, S=S, manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation, manifold_name=manifold_name);

plot_instance_pairs_2(Z_ab=Z_ab)

Z_tilde_a = network(Z_ab[:,0,:])[:,np.newaxis,:]
Z_tilde_b = network(Z_ab[:,1,:])[:,np.newaxis,:]
Z_tilde_ab = tf.concat([Z_tilde_a, Z_tilde_b], axis=1)
plot_instance_pairs_2(Z_ab=Z_tilde_ab)
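The cell above passes the first and the second member of each pair through the network separately and restacks the results along axis 1. Assuming the network maps each row independently (no batch-dependent layers), the same `[batch, 2, N]` array can be obtained with a single forward pass by folding the pair axis into the batch axis. The sketch uses NumPy and a hypothetical row-wise stand-in, `toy_network`, instead of the TensorFlow model.

```python
import numpy as np

def toy_network(Z):
    # Hypothetical row-wise stand-in for the trained network: [B, N] -> [B, N]
    W = np.array([[0.5, -1.0], [1.0, 0.5]])
    return np.tanh(Z @ W)

rng = np.random.default_rng(0)
Z_ab = rng.normal(size=(6, 2, 2))  # [batch, pair member, N]

# Two passes, as in the notebook cell
Z_tilde_a = toy_network(Z_ab[:, 0, :])[:, np.newaxis, :]
Z_tilde_b = toy_network(Z_ab[:, 1, :])[:, np.newaxis, :]
Z_tilde_two_pass = np.concatenate([Z_tilde_a, Z_tilde_b], axis=1)

# One pass: fold the pair axis into the batch axis, then unfold it again
batch, pair, N = Z_ab.shape
Z_tilde_one_pass = toy_network(Z_ab.reshape(batch * pair, N)).reshape(batch, pair, N)

print(np.allclose(Z_tilde_two_pass, Z_tilde_one_pass))  # True
```

With the TensorFlow model the same idea would use `tf.reshape` in place of `reshape`.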

### Model Evaluation

# Plot input and output
plot_input_output(network, S_range = [np.min(S), np.max(S)], manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation, manifold_name=manifold_name);

# Plot stage-wise contribution
plot_contribution_per_layer(network=network, s_range=(np.min(S), np.max(S)), manifold_function=manifold_function, manifold_name=manifold_name, layer_steps=[7,14,21], step_titles=['Stage 1','Stage 2','Stage 3'])

# Plot interactive tool
interact(plot_inverse_point, position=(-1.5,1.5,0.1), residual=(-2,2,0.1), S=fixed(S_sample), network=fixed(network), manifold_function=fixed(manifold_function), manifold_name=fixed(manifold_name));

# Cross validate model behaviour
reset_random_number_generators(seed=958) # Reproducibility
networks = [None] * 5 # Fold-count many networks
for i in range(len(networks)):
    networks[i] = create_model(Z_sample=Z[np.random.choice(M, size=128)], stage_count=3, sigma=0.8)
    networks[i].compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005))

Z_test, Y_test = cross_validate(Z=Z, Y=Y, networks=networks, batch_size=512, epoch_count=100, similarity_function=similarity_function_6, manifold_name=manifold_name)

# Plot cross validation
evaluate_and_plot_networks(Z_test=Z_test, Y_test=Y_test, networks=networks, manifold_name=manifold_name);
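`cross_validate` is defined elsewhere in the notebook, so its exact splitting scheme is not shown here. As a generic illustration only (an assumption, not the notebook's implementation), a 5-fold split of the `M` instance indices, one fold per network in `networks`, could look like this:

```python
import numpy as np

def k_fold_indices(M, fold_count, seed=0):
    # Generic k-fold split: shuffle the indices once, then cut them into
    # fold_count nearly equal parts; each part serves once as the test fold
    rng = np.random.default_rng(seed)
    permutation = rng.permutation(M)
    folds = np.array_split(permutation, fold_count)
    for k in range(fold_count):
        test_index = folds[k]
        train_index = np.concatenate([folds[j] for j in range(fold_count) if j != k])
        yield train_index, test_index

M = 2**14
for k, (train_index, test_index) in enumerate(k_fold_indices(M, fold_count=5)):
    print(k, len(train_index), len(test_index))
```

Each network is then calibrated on its own training indices and evaluated on the held-out fold, so differences between the resulting plots reflect both the initialization and the data split.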

## Quartet

This case study separates its manifold into four poles and thus extends the dipole (Duet) case study. It is more complex because the model has more freedom in arranging four poles along its position factor than in arranging two. The calibration trajectory converges, indicating that calibration is stable. The scatter plots show that the model manages to enumerate the four poles along the position axis. The layer-wise analysis shows that all three stages use linear and non-linear operations and that the later stages still contribute substantially to the output. As the cross-validation plots show, not all model initializations produce a clustered association between model input and output, which is consistent with the increased complexity compared to the dipole case.
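One way to quantify the added freedom: if the poles only need to occupy distinct places along a single position axis, and an ordering and its mirror image count as the same, there are k!/2 distinct orderings of k poles. The snippet evaluates this count; it is an illustrative combinatorial argument, not a quantity measured from the model.

```python
from math import factorial

def pole_orderings(k):
    # Orderings of k distinguishable poles on a line, identifying each ordering with its reversal
    return factorial(k) // 2

print(pole_orderings(2), pole_orderings(4))  # 1 12
```

For two poles every ordering is equivalent up to reflection, whereas four poles admit 12 different enumerations.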

### Data Synthesis

# Set up the raw manifold function
manifold_name = 'f_7'
def manifold_function(S):
    s_max = np.max(S)
    centroids = 5*np.array([[-1.0,1],[-1,-1],[1,1],[1,-1]])
    tmp = centroids[0] * (S[:,np.newaxis] < -s_max/2)
    tmp = tmp + centroids[1] * np.logical_and(-s_max/2 <= S, S < 0)[:,np.newaxis]
    tmp = tmp + centroids[2] * np.logical_and(0 <= S, S < s_max/2)[:,np.newaxis]
    tmp = tmp + centroids[3] * (s_max/2 <= S[:,np.newaxis])
    return (tmp[:,0], tmp[:,1])

# Generate a dataset
reset_random_number_generators(seed=395) # Reproducibility
noise_standard_deviation = [0.5, 0.5]
M = 2**13
S = np.random.uniform(low=-1, high=1, size=[M])
Z, Y = create_data_set(S=S, manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation) # Z.shape == [M, N], Y.shape == [M, F]

# Plot pairs of instances
batch_size = M//8
target_correlations = [0.0, 0.9]
iterator = mdis.volatile_factorized_pair_iterator(X=Z, Y=Y, batch_size=batch_size, target_correlations=target_correlations)
Z_ab, Y_ab = next(iterator)

plot_instance_pairs(S=S, Z_ab=Z_ab, Y_ab=Y_ab, manifold_function=manifold_function, manifold_name=manifold_name)
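The four-pole `manifold_function` above builds its output from four masked sums. Assuming the same thresholds (-s_max/2, 0, s_max/2) and a positive `s_max`, an equivalent formulation looks up the centroid index with `np.digitize`. The sketch below checks that both versions agree on positions drawn as in the Data Synthesis cell; it is offered as a more compact alternative, not as the notebook's definition.

```python
import numpy as np

centroids = 5*np.array([[-1.0, 1], [-1, -1], [1, 1], [1, -1]])

def manifold_function_masked(S):
    # f_7 as defined in the Data Synthesis cell
    s_max = np.max(S)
    tmp = centroids[0] * (S[:, np.newaxis] < -s_max/2)
    tmp = tmp + centroids[1] * np.logical_and(-s_max/2 <= S, S < 0)[:, np.newaxis]
    tmp = tmp + centroids[2] * np.logical_and(0 <= S, S < s_max/2)[:, np.newaxis]
    tmp = tmp + centroids[3] * (s_max/2 <= S[:, np.newaxis])
    return (tmp[:, 0], tmp[:, 1])

def manifold_function_digitize(S):
    # Pole index 0..3 for the intervals (-inf, -s_max/2), [-s_max/2, 0), [0, s_max/2), [s_max/2, inf);
    # requires s_max > 0 so that the bins are increasing
    s_max = np.max(S)
    pole_index = np.digitize(S, bins=[-s_max/2, 0, s_max/2])
    tmp = centroids[pole_index]
    return (tmp[:, 0], tmp[:, 1])

rng = np.random.default_rng(395)
S = rng.uniform(low=-1, high=1, size=2**13)
a = manifold_function_masked(S)
b = manifold_function_digitize(S)
print(np.allclose(a[0], b[0]) and np.allclose(a[1], b[1]))  # True
```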

### Model Creation

# Create network
reset_random_number_generators(seed=294) # Reproducibility
network = create_network(stage_count=3, sigma=1.0)
network.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005))


### Model Calibration

# Calibrate network
epoch_loss_means, epoch_loss_standard_deviations = network.fit(iterator=iterator, epoch_count=100, batch_count=M//batch_size)
plot_loss_trajectory(epoch_loss_means=epoch_loss_means, epoch_loss_standard_deviations=epoch_loss_standard_deviations, manifold_name=manifold_name)


### Model Evaluation

# Plot input and output
plot_input_output(network, S=S, manifold_function=manifold_function, noise_standard_deviation=noise_standard_deviation, manifold_name=manifold_name);
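Beyond visual inspection of the scatter plots, one could check numerically whether the model enumerates the four poles along its position axis: group the outputs by the pole each input came from, order the poles by their mean output position, and compare the gaps between neighbouring means with the spread within each pole. The sketch below uses a hypothetical helper `pole_enumeration` on synthetic outputs; which output column holds the position factor is an assumption passed in as `position_column`, and pole labels could, for example, be derived from `S` with the same thresholds as `f_7`.

```python
import numpy as np

def pole_enumeration(Z_tilde, pole_labels, position_column):
    # Hypothetical helper: order poles by their mean output position and return
    # the smallest gap between neighbouring means divided by the largest within-pole spread
    poles = np.unique(pole_labels)
    means = np.array([Z_tilde[pole_labels == p, position_column].mean() for p in poles])
    spreads = np.array([Z_tilde[pole_labels == p, position_column].std() for p in poles])
    order = poles[np.argsort(means)]
    gaps = np.diff(np.sort(means))
    return order, gaps.min() / spreads.max()

# Synthetic stand-in for network outputs: four tight clusters along column 0
rng = np.random.default_rng(0)
pole_labels = np.repeat(np.arange(4), 100)
true_positions = np.array([-1.5, -0.5, 0.5, 1.5])[pole_labels]
Z_tilde = np.stack([true_positions + 0.05*rng.normal(size=400),
                    rng.normal(size=400)], axis=1)

order, separation = pole_enumeration(Z_tilde, pole_labels, position_column=0)
print(order, separation)
```

A separation ratio well above 1 is a heuristic sign that the poles occupy non-overlapping intervals of the position axis; `order` then gives the enumeration the model has chosen.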

# Plot stage-wise contribution
plot_contribution_per_layer(network=network, s_range=(np.min(S), np.max(S)), manifold_function=manifold_function, manifold_name=manifold_name, layer_steps=[7,14,21], step_titles=['Stage 1','Stage 2','Stage 3'])

# Plot interactive tool
interact(plot_inverse_point, position=(-2,2,0.1), residual=(-2,2,0.1), S=fixed(S_sample), network=fixed(network), manifold_function=fixed(manifold_function), manifold_name=fixed(manifold_name))

# Cross validate model behaviour
reset_random_number_generators(seed=958) # Reproducibility
networks = [None] * 5 # Fold-count many networks
for i in range(len(networks)):
    networks[i] = create_model(Z_sample=Z[np.random.choice(M, size=128)], stage_count=3, sigma=0.8)
    networks[i].compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005))
Z_test, Y_test = cross_validate(Z=Z, Y=Y, networks=networks, batch_size=512, epoch_count=200, similarity_function=similarity_function_7, manifold_name=manifold_name)

# Plot cross validation
evaluate_and_plot_networks(Z_test=Z_test, Y_test=Y_test, networks=networks, manifold_name=manifold_name)