In [None]:
# Enable IPython's autoreload extension so edits to the imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before each cell execution.
%autoreload 2

In [None]:
import random
import tensorflow as tf
import os
import numpy as np

from model_fc import SeizureModel
from model_cnn import ConvSeizureModel
from model_gru import GRUSeizureModel
from dataprovider import dataprovider
from training import experiment

In [None]:
# Build the data provider once; this object is handed to `experiment(...)`
# further down.
# NOTE(review): what dataprovider() loads (paths, splits) is not visible
# from this notebook — see dataprovider.py.
data = dataprovider()

In [None]:
# Architecture switch: pick which model family this run trains.
# Each branch below sets the same four names used by later cells:
#   seiz_model - the model instance
#   l2         - passed to experiment() as lambda_l2
#   l_rec      - passed to experiment() as lambda_rec
#   dim        - passed to experiment() as dim
dcca_type = 'gru'

# Fail fast on a typo; only these three variants are handled below.
assert dcca_type in ('fc', 'cnn', 'gru')

if dcca_type == 'fc':
    # Fully connected variant. Layer specs are 2-tuples
    # (presumably (units, activation) — confirm in model_fc).
    seiz_model = SeizureModel(
        encoder_config=[(512, 'relu'), (512, 'relu'), (5, None)],
        decoder_config=[(512, 'relu'), (512, 'relu'), (300, None)],
    )
    l2, l_rec, dim = 1e-6, 1e-3, 5

elif dcca_type == 'cnn':
    # Convolutional variant. Each layer is a keyword dict; the encoder
    # alternates conv / maxpool, the decoder alternates conv / conv_transp.
    seiz_model = ConvSeizureModel(
        encoder_config=[
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 10},
            {'l_type': 'maxpool', 'pool_size': 2},
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 10},
            {'l_type': 'maxpool', 'pool_size': 5},
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 5},
            {'l_type': 'maxpool', 'pool_size': 5},
            {'l_type': 'conv', 'n_filters': 1, 'k_size': 1},
        ],
        decoder_config=[
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 1},
            {'l_type': 'conv_transp', 'n_filters': 5, 'k_size': 2, 'strides': 5},
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 5},
            {'l_type': 'conv_transp', 'n_filters': 5, 'k_size': 2, 'strides': 5},
            {'l_type': 'conv', 'n_filters': 5, 'k_size': 10},
            {'l_type': 'conv_transp', 'n_filters': 5, 'k_size': 2, 'strides': 2},
            {'l_type': 'conv', 'n_filters': 1, 'k_size': 10},
        ],
    )
    l2, l_rec, dim = 1e-5, 1e-10, 6

elif dcca_type == 'gru':
    # Recurrent variant.
    # NOTE(review): the meaning of the boolean in (7, False) / (1, True) and
    # of the bare 300 in decoder_config is defined in model_gru — confirm there.
    seiz_model = GRUSeizureModel(
        encoder_config=[(7, False)],
        decoder_config=[300, (1, True)],
    )
    l2, l_rec, dim = 1e-8, 1e-3, 7

In [None]:
# Wire the data provider, the selected model and its hyper-parameters into
# one experiment object; the training / analysis cells below all go through
# `exp`.
# NOTE(review): 'tmp' is the first positional argument — presumably the
# experiment name or output location; confirm in training.experiment.
exp = experiment(
    'tmp', data, seiz_model,
    dim=dim,
    cca_reg=1e-4,
    lambda_rec=l_rec,
    lambda_l2=l2,
    eval_epochs=1,
)

In [None]:
# Train the experiment for 100 epochs.
# NOTE(review): the experiment was created with eval_epochs=1, which
# presumably means evaluation runs every epoch — confirm in training.py.
exp.train(num_epochs=100)

In [None]:
# NOTE(review): presumably restores the best-performing state recorded
# during exp.train(...) before analysis — confirm in training.py.
exp.load_best()

In [None]:
# Run the subspace analysis on view index 1 with the 'DCCAE' method.
# NOTE(review): latent_dim=1 here differs from the `dim` the experiment was
# built with (5/6/7 depending on dcca_type) — confirm this is intentional.
exp.analyse_subspace(views=[1], method='DCCAE', latent_dim=1)