In [2]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os

import tensorflow as tf
import numpy as np
import scipy
import scipy.linalg  # explicit: bare `import scipy` does not expose scipy.linalg on older SciPy releases
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import nmca_model
from TwoChannelModel import TwoChannelModel
from correlation_analysis import CCA, PCC_Matrix
from plot import plot_eval
from tf_summary import write_image_summary, write_metric_summary

ModuleNotFoundError: No module named 'tensorflow'

# Generate Data

In [None]:
# Draw a synthetic two-channel dataset with 1000 samples.
data_model = TwoChannelModel(num_samples=1000)
(
    y_1, y_2,    # observed channel signals; these are fed to the network
    Az_1, Az_2,  # plotted against y_* by plot_non_linearities -- presumably the pre-nonlinearity signals, TODO confirm
    z_1, z_2,    # latent components; the first rows of z_1 serve as the reference subspace in the distance metric
) = data_model()

In [None]:
# Visualise the latent components z_1 and z_2 returned by the data model.
TwoChannelModel.plot_shared_components(z_1, z_2)

In [None]:
# Plot the observed channels y_1, y_2 against Az_1, Az_2 to show each channel's non-linearity.
TwoChannelModel.plot_non_linearities(y_1, y_2, Az_1, Az_2)

# Build model

In [None]:
# Build the NMCA network that the training cell below optimises.
model = nmca_model.build_nmca_model()

In [None]:
# Print the Keras layer structure and parameter counts of the network.
model.summary()

# Train one network

In [None]:
# Train the network built above for a fixed number of epochs, logging scalar
# metrics and image summaries to TensorBoard along the way.
log_dir = os.environ.get("NMCA_LOG_DIR", "/var/tmp/mkuschel/tf_logs")  # override instead of editing a personal path
num_epochs = 15000
lambda_reg = 1e-10        # regularisation weight forwarded to nmca_model.compute_loss
metric_every = 5          # epochs between scalar summaries (loss terms, ccor, dist)
image_every = 500         # epochs between image summaries
num_cca_components = 2    # passed to CCA and used to slice z_1 for the distance metric

writer = nmca_model.create_writer(log_dir)
optimizer = tf.keras.optimizers.Adam()

# Inputs and reference components do not change between epochs: build them once.
model_inputs = [tf.transpose(y_1), tf.transpose(y_2)]
S_ref = z_1[:num_cca_components]

for epoch in tqdm(range(num_epochs), desc='Epochs'):
    # Record only the forward pass and the loss. Gradients are taken w.r.t.
    # model.trainable_variables, which the tape watches automatically, so the
    # inputs need no tape.watch().
    with tf.GradientTape() as tape:
        [fy_1, fy_2], [yhat_1, yhat_2] = model(model_inputs)
        loss, cca_loss, rec_loss, ccor = nmca_model.compute_loss(
            y_1, y_2, fy_1, fy_2, yhat_1, yhat_2, lambda_reg=lambda_reg)

    # Backpropagate and update the weights.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Monitoring runs outside the tape so the CCA / distance ops are not recorded
    # for backprop. fy_1, fy_2 and loss still hold this epoch's forward-pass values.
    if epoch % metric_every == 0:
        B1, B2, epsilon, omega, ccor = CCA(fy_1, fy_2, num_cca_components)
        dist = nmca_model.compute_distance_metric(S=S_ref, U=0.5*(omega+epsilon))
        write_metric_summary(writer, epoch, loss, cca_loss, rec_loss, ccor, dist)

    if epoch % image_every == 0:
        write_image_summary(writer, epoch, Az_1, Az_2, y_1, y_2, fy_1, fy_2)

# Train many networks

In [None]:
# Train several independently built NMCA networks, each with its own optimizer
# and TensorBoard writer. Every trained network is kept in `trained_models`;
# `model` is left bound to the last one for the evaluation cells.
num_models = 3
lambda_reg = 1e-10        # regularisation weight forwarded to nmca_model.compute_loss
num_epochs = 30000
metric_every = 5          # epochs between scalar summaries (loss terms, ccor, dist)
image_every = 500         # epochs between image summaries
num_cca_components = 2    # passed to CCA and used to slice z_1 for the distance metric
log_dir = os.environ.get("NMCA_LOG_DIR", "/var/tmp/mkuschel/tf_logs")  # override instead of editing a personal path

# Inputs and reference components are identical for every network and epoch.
model_inputs = [tf.transpose(y_1), tf.transpose(y_2)]
S_ref = z_1[:num_cca_components]

trained_models = []  # previously only the last network survived the loop

for model_idx in range(num_models):
    # NOTE(review): every run gets the same log_dir; presumably create_writer
    # creates a distinct run subdirectory -- confirm, or runs will mix in TensorBoard.
    writer = nmca_model.create_writer(log_dir)
    optimizer = tf.keras.optimizers.Adam()  # fresh optimizer state per network
    model = nmca_model.build_nmca_model()

    for epoch in tqdm(range(num_epochs), desc=f'Epochs (model {model_idx + 1}/{num_models})'):
        # Record only the forward pass and the loss; trainable variables are
        # watched automatically, so the inputs need no tape.watch().
        with tf.GradientTape() as tape:
            [fy_1, fy_2], [yhat_1, yhat_2] = model(model_inputs)
            loss, cca_loss, rec_loss, ccor = nmca_model.compute_loss(
                y_1, y_2, fy_1, fy_2, yhat_1, yhat_2, lambda_reg=lambda_reg)

        # Backpropagate and update the weights.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Monitoring runs outside the tape so it is not recorded for backprop.
        if epoch % metric_every == 0:
            B1, B2, epsilon, omega, ccor = CCA(fy_1, fy_2, num_cca_components)
            dist = nmca_model.compute_distance_metric(S=S_ref, U=0.5*(omega+epsilon))
            write_metric_summary(writer, epoch, loss, cca_loss, rec_loss, ccor, dist)

        if epoch % image_every == 0:
            write_image_summary(writer, epoch, Az_1, Az_2, y_1, y_2, fy_1, fy_2)

    trained_models.append(model)

# Evaluate the trained network

In [None]:
# Run the trained network once on the full dataset (it is fed transposed inputs).
[fy_1, fy_2], [yhat_1, yhat_2] = model(list(map(tf.transpose, (y_1, y_2))))

# CCA is applied to the raw network outputs, i.e. before transposing them back.
B1, B2, epsilon, omega, ccor = CCA(fy_1, fy_2, 2)

# Transpose outputs and reconstructions back -- presumably into the same
# orientation as y_1 / y_2 -- for the evaluation plots.
fy_1, fy_2, yhat_1, yhat_2 = map(tf.transpose, (fy_1, fy_2, yhat_1, yhat_2))

In [None]:
# Subspace distance between the reference components S and the CCA estimate:
#   dist = || P_S_perp @ Q ||_2,
# where P_S_perp = I - S^T (S S^T)^{-1} S projects onto the orthogonal
# complement of the row space of S, and Q is an orthonormal basis for the
# column space of U^T.
# NOTE(review): presumably mirrors nmca_model.compute_distance_metric used during training -- confirm.
S = np.asarray(z_1[:2])  # first 2 rows of z_1; 2 matches the CCA component count used above
n_samples = S.shape[1]   # derived from the data instead of hard-coding 1000
# solve() instead of an explicit inverse: same projector, better conditioned.
Ps = np.eye(n_samples) - S.T @ np.linalg.solve(S @ S.T, S)
U = 0.5*(omega+epsilon)
Q = scipy.linalg.orth(np.asarray(U).T)
dist = np.linalg.norm(Ps@Q, ord=2)
print("Dist: "+str(dist))

In [None]:
# Plot the generated data (z_*, Az_*, y_*) alongside the network outputs and
# reconstructions (fy_*, yhat_*) and the CCA outputs (epsilon, omega).
plot_eval(z_1, z_2, Az_1, Az_2, y_1, y_2, fy_1, fy_2, yhat_1, yhat_2, epsilon, omega)