In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project package take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import tensorflow as tf

from datetime import datetime

from GeodesicRelaxationDCCA.algorithms.correlation import CCA
from GeodesicRelaxationDCCA.algorithms.correlation_residual import canonical_correlations, chordal_distance
from GeodesicRelaxationDCCA.algorithms.losses_metrics import EpochWatchdog, EmptyWatchdog

from GeodesicRelaxationDCCA.data.synthetic import SyntheticData
from GeodesicRelaxationDCCA.experiments.synthetic import SynthDeepCCAExperiment, SynthDeepCCASlackExperiment

In [None]:
def eval_network(dataprov, network):
    """Score a trained network against the ground-truth signals of ``dataprov``.

    The network is applied to every batch yielded by ``dataprov.training_data``;
    only the output of the *last* batch is evaluated. This is only meaningful
    when that batch lines up with ``dataprov.z_0`` / ``dataprov.z_1``
    (presumably a single batch covering all samples -- TODO confirm).

    Args:
        dataprov: Data provider exposing ``training_data`` (an iterable of
            batches) and the ground-truth signals ``z_0`` and ``z_1``.
        network: Callable mapping one batch to a dict-like output containing
            either ``'rrcca_view_0'`` / ``'rrcca_view_1'`` or
            ``'cca_view_0'`` / ``'cca_view_1'``.

    Returns:
        dict with numpy values:
            ``dist_0`` / ``dist_1``: one minus the mean of the correlations
                returned by ``CCA`` between each ground-truth signal and the
                corresponding latent view.
            ``dist_avg``: mean of ``dist_0`` and ``dist_1``.
            ``corr``: result of ``canonical_correlations`` between both views.
            ``corr_avg``: mean of ``corr``.

    Raises:
        ValueError: If ``dataprov.training_data`` yields no batches.
        KeyError: If the network output contains neither supported key family.
    """
    netw_output = None
    num_batches = 0
    for data in dataprov.training_data:
        netw_output = network(data)
        num_batches += 1

    # Fail early with a clear message instead of an UnboundLocalError below.
    if num_batches == 0:
        raise ValueError('dataprov.training_data did not yield any batches')

    gt_signal_0 = dataprov.z_0
    gt_signal_1 = dataprov.z_1

    output_keys = netw_output.keys()
    if 'rrcca_view_0' in output_keys:
        # The rrcca outputs are transposed before use (presumably they are
        # stored with samples along the second axis -- TODO confirm).
        latent_view_0 = tf.transpose(netw_output['rrcca_view_0'])
        latent_view_1 = tf.transpose(netw_output['rrcca_view_1'])
    elif 'cca_view_0' in output_keys:
        latent_view_0 = netw_output['cca_view_0']
        latent_view_1 = netw_output['cca_view_1']
    else:
        raise KeyError(
            "network output contains neither 'rrcca_view_0' nor 'cca_view_0'"
        )

    # The fifth element of CCA's 7-tuple is used as the correlations between
    # a ground-truth signal and the learned latent view.
    _, _, _, _, ccor_0, _, _ = CCA(gt_signal_0, latent_view_0, 2)
    _, _, _, _, ccor_1, _, _ = CCA(gt_signal_1, latent_view_1, 2)

    dist_0 = 1 - tf.math.reduce_mean(ccor_0)
    dist_1 = 1 - tf.math.reduce_mean(ccor_1)
    dist_avg = (dist_0 + dist_1) / 2

    # Correlations between the two learned views themselves.
    correlations = canonical_correlations(latent_view_0, latent_view_1, 2, 0)
    corr_avg = tf.math.reduce_mean(correlations)

    return {
        'dist_0': dist_0.numpy(),
        'dist_1': dist_1.numpy(),
        'dist_avg': dist_avg.numpy(),
        'corr': correlations.numpy(),
        'corr_avg': corr_avg.numpy(),
    }

In [None]:
root_dir = 'tmp'

# makedirs(exist_ok=True) creates the directory (including missing parents) and
# is a no-op when it already exists, avoiding the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs(root_dir, exist_ok=True)

# Generate the synthetic two-view dataset. batch_size equals num_samples here
# (presumably so that one batch covers the whole dataset -- TODO confirm).
syn_dataprovider = SyntheticData.generate(
    num_samples=200,
    batch_size=200,
    correlations=[0.6, 0.6],
    num_channels=2,
    non_lin_type='channel_wise'
)

# Save dataset
syn_dataprovider.save(root_dir)

In [None]:
# CCA between the two ground-truth signals; the fifth element of the returned
# 7-tuple is used as the correlations.
gt_cca = CCA(syn_dataprovider.z_0, syn_dataprovider.z_1, 2)
_, _, _, _, ccor, _, _ = gt_cca

# Turn the correlations into angles, then take the norm of their sines
# divided by sqrt(2).
gt_angles = tf.math.acos(ccor)
gt_sines = tf.math.sin(gt_angles)
gt_distance = tf.linalg.norm(gt_sines) / tf.sqrt(2.0)

In [None]:
# Display the ground-truth distance: the norm of the sines of the angles
# obtained from the ground-truth correlations, divided by sqrt(2).
gt_distance

## DCCA

In [None]:
# Repeat the DCCA experiment five times on the same synthetic dataset and
# collect one evaluation dict per run.
dcca_results = []

for _ in range(5):
    # A fresh optimizer is created for every run.
    opt = tf.keras.optimizers.Adam(learning_rate=0.001)

    dcca_kwargs = dict(
        log_dir=os.path.join(root_dir, 'slack_ref'),
        encoder_config_v1=[(256, 'sigmoid'), (256, 'sigmoid'), (2, None)],
        encoder_config_v2=[(256, 'sigmoid'), (256, 'sigmoid'), (2, None)],
        dataprovider=syn_dataprovider,
        shared_dim=2,
        lambda_rad=0,
        topk=1,
        max_perc=1,
        lambda_l1=0,
        lambda_l2=1e-4,
        cca_reg=1e-4,
        eval_epochs=5,
        val_default_value=1.0,
        convergence_threshold=0.001,
        optimizer=opt,
    )
    exp = SynthDeepCCAExperiment(**dcca_kwargs)

    exp.train_multiple_epochs(2000)
    exp.save()

    # Evaluate the trained architecture against the ground-truth signals.
    run_metrics = eval_network(syn_dataprovider, exp.architecture)
    dcca_results.append(run_metrics)

## RDCCA

In [None]:
# For each residual value, repeat the slack experiment five times and collect
# the per-run evaluation dicts under that residual.
rrcca_results = {}

for residual in [0.0, 0.1, 0.5, 0.6, 0.7, 0.8, 0.9]:
    residual_runs = []
    rrcca_results[residual] = residual_runs

    for _ in range(5):
        # A fresh optimizer is created for every run.
        opt = tf.keras.optimizers.Adam(learning_rate=0.001)

        slack_kwargs = dict(
            log_dir=os.path.join(root_dir, 'slack'),
            encoder_config_v1=[(256, 'sigmoid'), (256, 'sigmoid'), (2, None)],
            encoder_config_v2=[(256, 'sigmoid'), (256, 'sigmoid'), (2, None)],
            dataprovider=syn_dataprovider,
            shared_dim=2,
            residual=residual,
            corr_reg=1e-6,
            lambda_l1=0,
            lambda_l2=1e-6,
            eval_epochs=5,
            val_default_value=1.0,
            convergence_threshold=0.001,
            optimizer=opt,
        )
        exp = SynthDeepCCASlackExperiment(**slack_kwargs)

        exp.train_multiple_epochs(
            num_epochs=2000,
            num_inner_epochs=100,
            epsilon_inner_epochs=1e-10,
        )
        exp.save()

        # Evaluate the trained architecture against the ground-truth signals.
        run_metrics = eval_network(syn_dataprovider, exp.architecture)
        residual_runs.append(run_metrics)

## CCA

In [None]:
# Walk through all training batches; afterwards `data` is bound to the last one.
for data in syn_dataprovider.training_data:
    continue

# Plain CCA applied directly to the two network inputs of that batch.
linear_cca = CCA(data['nn_input_0'], data['nn_input_1'], 2)
Ax, Ay, epsilon, omega, _, _, _ = linear_cca

# The third and fourth CCA outputs are transposed before being compared.
epsilon_t = tf.transpose(epsilon)
omega_t = tf.transpose(omega)

gt_signal_0 = syn_dataprovider.z_0
gt_signal_1 = syn_dataprovider.z_1

# Correlations between each ground-truth signal and the matching CCA output.
_, _, _, _, ccor_0, _, _ = CCA(gt_signal_0, epsilon_t, 2)
_, _, _, _, ccor_1, _, _ = CCA(gt_signal_1, omega_t, 2)

dist_0 = 1 - tf.math.reduce_mean(ccor_0)
dist_1 = 1 - tf.math.reduce_mean(ccor_1)
dist_avg = (dist_0 + dist_1) / 2

# Correlations between the two CCA outputs themselves.
correlations = canonical_correlations(epsilon_t, omega_t, 2, 0)
corr_avg = tf.math.reduce_mean(correlations)