In [None]:
# Enable IPython's autoreload so edits to imported project modules
# (e.g. the GeodesicRelaxationDCCA package) are picked up before each
# cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import tensorflow as tf

from datetime import datetime

from GeodesicRelaxationDCCA.algorithms.losses_metrics import EpochWatchdog, EmptyWatchdog

from GeodesicRelaxationDCCA.data.mnist import MNISTData

from GeodesicRelaxationDCCA.experiments.mnist import MNISTDeepCCAExperiment, MNISTDeepCCASlackExperiment

In [None]:
# Output directory for the generated dataset and all experiment logs below.
root_dir = 'tmp'

# Seed the global RNGs so a Restart & Run All reproduces the same dataset and
# training runs (repeated runs still differ from each other because the RNG
# stream advances between them).
# NOTE(review): assumes MNISTData.generate and the experiment classes draw from
# the NumPy / TensorFlow global RNGs — confirm.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# makedirs(exist_ok=True) is idempotent on re-run, creates intermediate
# directories, and avoids the check-then-create race of exists() + mkdir().
os.makedirs(root_dir, exist_ok=True)

# Generate the synthetic dataset in memory (it is not loaded from disk).
# NOTE(review): both the first positional argument and num_samples are 50000;
# confirm which one controls the dataset size.
mnist_dataprovider = MNISTData.generate(50000, num_boxes=2, max_width=10, num_samples=50000)

# Save the generated dataset under root_dir.
mnist_dataprovider.save(root_dir)


# DCCA

In [None]:
# Train the standard Deep CCA model 5 times with identical hyperparameters.
# Every run logs under <root_dir>/slack_ref and is saved once training ends.
dcca_log_dir = os.path.join(root_dir, 'slack_ref')

for run in range(5):
    # Five 1536-unit sigmoid layers followed by a 15-unit layer with no
    # activation; 15 matches shared_dim. A fresh list is built for each view.
    encoder_v1 = [(1536, 'sigmoid')] * 5 + [(15, None)]
    encoder_v2 = [(1536, 'sigmoid')] * 5 + [(15, None)]

    exp = MNISTDeepCCAExperiment(
        log_dir=dcca_log_dir,
        encoder_config_v1=encoder_v1,
        encoder_config_v2=encoder_v2,
        dataprovider=mnist_dataprovider,
        shared_dim=15,
        lambda_rad=0,
        topk=1,
        max_perc=1,
        lambda_l1=0,
        lambda_l2=1e-5,
        cca_reg=1e-4,
        eval_epochs=10,
        val_default_value=0.0,
        convergence_threshold=0.0,
    )
    exp.train_multiple_epochs(2000)
    exp.save()

# Slack-DCCA

In [None]:
# Sweep the Slack-DCCA `residual` over 0.0, 0.1, ..., 0.9 and train 5 runs per
# value. All runs log under <root_dir>/slack and are saved after training.
# step / 10 is correctly rounded, so it yields exactly the same floats as the
# decimal literals 0.0 ... 0.9.
slack_log_dir = os.path.join(root_dir, 'slack')
residuals = [step / 10 for step in range(10)]

for res in residuals:
    for run in range(5):
        # Same encoder architecture for both views, built fresh per run.
        encoder_v1 = [(1536, 'sigmoid')] * 5 + [(15, None)]
        encoder_v2 = [(1536, 'sigmoid')] * 5 + [(15, None)]

        exp = MNISTDeepCCASlackExperiment(
            log_dir=slack_log_dir,
            encoder_config_v1=encoder_v1,
            encoder_config_v2=encoder_v2,
            dataprovider=mnist_dataprovider,
            shared_dim=15,
            residual=res,
            corr_reg=1e-10,
            lambda_l1=0,
            lambda_l2=1e-5,
            eval_epochs=10,
            val_default_value=0.0,
            convergence_threshold=0.0,
        )
        exp.train_multiple_epochs(
            num_epochs=2000,
            num_inner_epochs=100,
            epsilon_inner_epochs=1e-10,
        )
        exp.save()