In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from TwoChannelModel import TwoChannelModel
import nmca_model
from correlation_analysis import CCA, PCC_Matrix
from plot import plot_eval
from tqdm.notebook import tqdm

# Generate data

In [None]:
# Seed the global RNGs so the synthetic data (and any later random weight
# initialisation) is identical on every Restart & Run All.
# NOTE(review): assumes TwoChannelModel draws from NumPy's and/or TensorFlow's
# global RNG — confirm; seeding is harmless otherwise.
SEED = 0
NUM_SAMPLES = 1000  # number of generated samples per view
np.random.seed(SEED)
tf.random.set_seed(SEED)

data_model = TwoChannelModel(num_samples=NUM_SAMPLES)
y_1, y_2, Az_1, Az_2, z_1, z_2 = data_model()

In [None]:
# Visualise the z_1 / z_2 arrays returned by the data model.
TwoChannelModel.plot_shared_components(z_1, z_2)

In [None]:
# Visualise the per-view observations y_* together with Az_* from the data
# model (presumably the signal before the view non-linearity — confirm).
TwoChannelModel.plot_non_linearities(y_1, y_2, Az_1, Az_2)

# Build model

In [None]:
# Build the NMCA network.
# NOTE(review): in the cells shown, the "Train" section optimises `NCA_Model`
# (and B_1 / B_2), never this `model` — yet "Eval network" runs `model`.
# Confirm which network is meant to be trained and evaluated.
model = nmca_model.build_nmca_model()

In [None]:
# Print the architecture summary of the network built above.
model.summary()

# Train

In [None]:
def update_U(B_views, Fy_data):
    """Closed-form update of the shared component matrix U.

    Forms ``Z = sum_i B_views[i] @ Fy_data[i]`` and removes the mean of each
    row, which is the same as right-multiplying by the centering matrix
    ``W = I - 11^T / T``. It then takes the thin SVD
    ``Z W = P diag(s) Q^T`` and returns ``sqrt(T) * P @ Q^T``, where ``T`` is
    the number of columns (samples) of ``Z``.

    Args:
        B_views: sequence of per-view matrices; ``B_views[i]`` must be
            matmul-compatible with ``Fy_data[i]``.
        Fy_data: sequence of per-view feature matrices, one per entry of
            ``B_views``, with samples along axis 1.

    Returns:
        Tensor with the shape of ``B_views[0] @ Fy_data[0]``. When it has no
        more rows than columns, ``U @ U^T == T * I`` up to round-off.

    Raises:
        AssertionError: if the two sequences are empty or differ in length.
    """
    # Validate before indexing so bad input gives a readable message instead
    # of an IndexError.
    assert len(B_views) == len(Fy_data) and len(B_views) > 0, (
        "B_views and Fy_data must be non-empty and of equal length")

    Z = tf.add_n([tf.matmul(B, Fy) for B, Fy in zip(B_views, Fy_data)])
    num_samples = tf.cast(tf.shape(Z)[1], Z.dtype)

    # Z @ (I - 11^T/T) equals Z minus its row means. Subtracting the mean
    # directly avoids materialising the T x T centering matrix
    # (O(T^2) memory and compute).
    Z_centered = Z - tf.reduce_mean(Z, axis=1, keepdims=True)

    # tf.linalg.svd returns (singular values, left vectors, right vectors).
    _, P, Q = tf.linalg.svd(Z_centered, full_matrices=False, compute_uv=True)

    return tf.sqrt(num_samples) * tf.matmul(P, Q, transpose_b=True)

In [None]:
# Instantiate the NCA encoder/decoder networks and initialise the state used
# by the training cell (B_1, B_2, U, reconstruction targets, loss histories).
# NOTE(review): `NonlinearComponentAnalysis`, `num_views`, `z_dim`, `c_dim`,
# `autoencoder_dims`, `batch_size` and `data_chunk` are not imported or
# defined anywhere in this notebook. This cell therefore fails on a fresh
# kernel (Restart & Run All). TODO: import/define them in the top cells.
NCA_Class = NonlinearComponentAnalysis(num_views=num_views,
                                 num_channels=z_dim+c_dim,
                                 encoder_layers=autoencoder_dims,
                                 decoder_layers=autoencoder_dims,
                                 batch_size=batch_size)
NCA_Model = NCA_Class.NCA

# Per-view projection matrices, initialised to the identity.
# NOTE(review): the size 5 is hard-coded; presumably it must match the
# encoder output dimension — confirm and derive it from config.
B_1, B_2 = tf.Variable(tf.eye(5), dtype=tf.float32), tf.Variable(tf.eye(5), dtype=tf.float32)

# One forward pass to obtain initial encoder features.
output_of_encoders, output_of_decoders = NCA_Model(data_chunk)

# Take slice 0 along the first axis of each view's encoder output
# (presumably batch index 0 — confirm). Transpose it so that axis 1 is what
# update_U reads as the sample axis.
fy_1, fy_2 = output_of_encoders
fy_1 = tf.transpose(fy_1[0,:])
fy_2 = tf.transpose(fy_2[0,:])

# Initial closed-form estimate of U with B_1 = B_2 = I.
U = update_U(
    B_views = [B_1, B_2],
    Fy_data = [fy_1, fy_2]
)

# Reconstruction targets for the autoencoder loss: the first 5 entries of
# data_chunk form view 1, the remaining entries view 2.
# NOTE(review): this overwrites the y_1 / y_2 produced by TwoChannelModel in
# "Generate data". The training, eval and plot_eval cells below read these
# rebuilt tensors, not the originally generated data — confirm intended.
y_1 = tf.cast(tf.squeeze(tf.stack(data_chunk[:5], axis=0)), dtype=tf.float32)
y_2 = tf.cast(tf.squeeze(tf.stack(data_chunk[5:], axis=0)), dtype=tf.float32)

# Loss histories and bookkeeping appended to / updated by the training cell.
loss_old = None
cca_arr = list()
autoenc_arr = list()
loss_arr = list()
i = 0

In [None]:
# Alternating optimisation. Each outer iteration does two things:
#   1. takes N_INNER_STEPS gradient steps on the network weights and on
#      B_1, B_2, with U held fixed;
#   2. refreshes U in closed form with update_U.
N_OUTER_ITERS = 100
N_INNER_STEPS = 100
NCA_Class.optimizer.learning_rate = 0.001
lambda_reg = 0.01  # weight of the reconstruction term in the total loss

# Update the network weights and both B matrices in ONE apply_gradients call.
# Recent Keras optimizers (TF >= 2.11, Keras 3) are built for the variable
# list of their first call and raise on variables outside it. Every call
# also advances `optimizer.iterations`, which drives bias correction and
# LR schedules.
trainable_vars = NCA_Model.trainable_variables + [B_1, B_2]

for outer_iter in tqdm(range(N_OUTER_ITERS)):
    for _ in range(N_INNER_STEPS):
        # Gradients are only needed w.r.t. trainable variables, so a single-use
        # (non-persistent) tape that does not watch the input data suffices.
        with tf.GradientTape() as tape:
            output_of_encoders, output_of_decoders = NCA_Model(data_chunk)

            # Encoder outputs, transposed so axis 1 is the sample axis used by
            # update_U.
            fy_1, fy_2 = output_of_encoders
            fy_1 = tf.transpose(fy_1[0, :])
            fy_2 = tf.transpose(fy_2[0, :])

            y_1_recon = tf.transpose(tf.squeeze(output_of_decoders[0]))
            y_2_recon = tf.transpose(tf.squeeze(output_of_decoders[1]))

            # Mean over columns of ||U - B_v fy_v||^2, summed over both views.
            tmp_1 = tf.square(tf.norm(U - tf.matmul(B_1, fy_1), axis=0))
            tmp_2 = tf.square(tf.norm(U - tf.matmul(B_2, fy_2), axis=0))
            cca_loss = tf.reduce_mean(tmp_1 + tmp_2)

            # Mean over columns of the squared reconstruction error, both views.
            tmp_3 = tf.square(tf.norm(y_1 - y_1_recon, axis=0))
            tmp_4 = tf.square(tf.norm(y_2 - y_2_recon, axis=0))
            autoenc_loss = tf.reduce_mean(tmp_3 + tmp_4)

            theta_loss = cca_loss + lambda_reg * autoenc_loss

        gradients = tape.gradient(theta_loss, trainable_vars)
        NCA_Class.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Store plain floats so the histories do not keep tensors alive.
        autoenc_arr.append(float(autoenc_loss))
        cca_arr.append(float(cca_loss))
        loss_arr.append(float(theta_loss))
        i += 1

    # Recompute the encoder features with the just-updated weights, so U is
    # fitted to the current features rather than those from before the last
    # gradient step.
    output_of_encoders, _ = NCA_Model(data_chunk)
    fy_1, fy_2 = output_of_encoders
    fy_1 = tf.transpose(fy_1[0, :])
    fy_2 = tf.transpose(fy_2[0, :])
    U = update_U(B_views=[B_1, B_2], Fy_data=[fy_1, fy_2])

    print(f"outer iter {outer_iter}: loss={float(theta_loss):.4f} "
          f"(cca={float(cca_loss):.4f}, recon={float(autoenc_loss):.4f})")

In [None]:
# Training curve of the total objective recorded once per inner step.
# np.asarray handles both Python floats and scalar TF tensors in loss_arr.
loss_history = np.asarray(loss_arr, dtype=float)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(np.arange(len(loss_history)), loss_history)
ax.set_ylim([0, 50])  # fixed y-range; widen it if the curve is clipped
ax.set(title="Training loss", xlabel="Inner iteration", ylabel="Total loss")
plt.show()

# Evaluate network

In [None]:
# Forward path
# NOTE(review): `model` is the network from nmca_model.build_nmca_model().
# The training cell optimises `NCA_Model` (and B_1 / B_2) and never touches
# `model`, so unless the two share weights this evaluates an untrained
# network. Confirm, and switch to the trained network.
# NOTE(review): y_1 / y_2 here are the tensors rebuilt from data_chunk in the
# training-setup cell, not the TwoChannelModel outputs.
[fy_1, fy_2], [yhat_1, yhat_2] = model([tf.transpose(y_1), tf.transpose(y_2)])

# Compute CCA
# NOTE(review): 2 is presumably the number of canonical components to keep —
# move it to a named constant. B1, B2 and ccor are not used by the
# plot_eval call below.
B1, B2, epsilon, omega, ccor = CCA(fy_1, fy_2, 2)

# Transpose the network outputs before handing them to plot_eval
# (the layout plot_eval expects is not visible here).
fy_1, fy_2 = tf.transpose(fy_1), tf.transpose(fy_2)
yhat_1, yhat_2 = tf.transpose(yhat_1), tf.transpose(yhat_2)

In [None]:
# Pass the data-model arrays (z_*, Az_*, y_*), the network outputs
# (fy_*, yhat_*) and the CCA outputs (epsilon, omega) to plot_eval.
plot_eval(z_1, z_2, Az_1, Az_2, y_1, y_2, fy_1, fy_2, yhat_1, yhat_2, epsilon, omega)