In [13]:
import matplotlib
import sys
import os
# Make sibling project packages (presumably `utils`, `combiners`) importable by
# appending the parent and the grandparent of the working directory to
# sys.path. Each path is added at most once, so re-running the cell is harmless.
for relative_root in ('./', '../'):
    current_dir = os.path.dirname(os.path.abspath(relative_root))
    if current_dir not in sys.path:
        sys.path.append(current_dir)

from utils.structures import Pipeline, Deploy
from utils.data_management import dict2str
from utils.machine_learning import one_hot_encoder, one_hot_decoder
from typing import *
import tensorflow as tf
from sklearn.datasets import make_classification
import mne
from combiners import EpochsCombiner
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from utils.machine_learning.designer import ModelDesign, ParallelDesign, LayerDesign
from utils.machine_learning.analyzer import ModelAnalyzer, LFCNNAnalyzer
from mne.datasets import multimodal
import sklearn
import mneflow as mf
import tensorflow as tf
from mneflow.layers import DeMixing, LFTConv, TempPooling, Dense
from mneflow.models import BaseModel
import mneflow
import logging
from time import perf_counter

# Send log records to both ./history.log and stdout.
# `force=True` removes *and closes* any handlers that a previous run of this
# cell attached to the root logger. Plain `logging.root.handlers = []` only
# dropped them, leaking the open FileHandler on every re-run.
logger = logging.getLogger(__name__)
logger.setLevel(logging.NOTSET)
logging.basicConfig(
    format='%(asctime)s, %(name)s %(levelname)s %(message)s',
    datefmt='%H:%M:%S',
    level=logging.DEBUG,
    handlers=[
        logging.FileHandler('./history.log'),
        logging.StreamHandler(sys.stdout)
    ],
    force=True,
)

# %matplotlib qt

In [14]:
class Deconw(tf.keras.layers.Layer):
    """Per-channel transposed convolution.

    Every slice ``inputs[:, :, :, i]`` of a rank-4 input is expanded back to a
    single channel and passed through its own ``Conv2DTranspose(filters=1)``.
    The per-channel results are stacked on a new axis 1, giving
    ``(batch, n_channels, out_h, out_w, 1)``. This assumes the default
    channels-last data format.

    All constructor arguments except ``units`` are forwarded to every per-channel
    ``Conv2DTranspose``. ``units`` is stored but not used here. Extra ``**kwargs``
    are also forwarded to ``Conv2DTranspose``, not to the base ``Layer``.

    The sub-layers are created once, in ``build``, on the first call. Later calls
    reuse them and their weights.
    """

    def __init__(
        self,
        units=32,
        kernel_size=(4, 10),
        strides=(1, 1),
        padding='valid',
        output_padding=None,
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer='glorot_uniform',
        bias_initializer='zeros',
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super().__init__()
        self.units = units
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.output_padding = output_padding
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.activation = activation
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.activity_regularizer = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint
        self.kwargs = kwargs

    def build(self, input_shape):
        """Create one ``Conv2DTranspose`` per entry of the last input axis."""
        self.n_channels = int(input_shape[-1])
        self.deconws = self.deconv_constructor(self.n_channels)
        # Marks the layer as built so Keras never calls `build` again.
        super().build(input_shape)

    # Implemented as `call` (not an overridden `__call__`) so that Keras' own
    # `__call__` runs `build` only on the first invocation. Re-building on every
    # call would replace all sub-layers with freshly initialised ones.
    def call(self, inputs):
        """Deconvolve every channel independently and stack results on axis 1."""
        outputs = [
            self.deconws[i](tf.expand_dims(inputs[:, :, :, i], axis=3))
            for i in range(self.n_channels)
        ]
        # tf.stack puts channels first: (n_channels, batch, ...) -> (batch, n_channels, ...)
        return tf.transpose(tf.stack(outputs), (1, 0, 2, 3, 4))

    def deconv_constructor(self, n_channels):
        """Return ``n_channels`` independent single-filter ``Conv2DTranspose`` layers."""
        return [
            tf.keras.layers.Conv2DTranspose(
                filters=1,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                output_padding=self.output_padding,
                data_format=self.data_format,
                dilation_rate=self.dilation_rate,
                activation=self.activation,
                use_bias=self.use_bias,
                kernel_initializer=self.kernel_initializer,
                bias_initializer=self.bias_initializer,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                activity_regularizer=self.activity_regularizer,
                kernel_constraint=self.kernel_constraint,
                bias_constraint=self.bias_constraint,
                **self.kwargs
            )
            for _ in range(n_channels)
        ]

In [3]:
# Load the MNE "multimodal" sample recording and cut one `mne.Epochs` object
# per acquisition condition.
mne.set_log_level(verbose='CRITICAL')
fname_raw = os.path.join(multimodal.data_path(), 'multimodal_raw.fif')
raw = mne.io.read_raw_fif(fname_raw)
cond = raw.acqparser.get_condition(raw, None)
condition_names = [name for c in cond for name in c['event_id']]
epochs_list = [mne.Epochs(raw, **c) for c in cond]

# All conditions concatenated, gradiometers only. The per-condition loop
# below no longer uses `epochs` as its loop variable, so this object keeps
# its meaning after the cell.
epochs = mne.concatenate_epochs(epochs_list)
epochs = epochs.pick_types(meg='grad')

# X: trials of every condition stacked along axis 0. A single concatenate
# replaces growing the array with np.append, which copied it every iteration.
# Y: for every trial, the index of its condition in `epochs_list`.
condition_data = [condition_epochs.get_data() for condition_epochs in epochs_list]
X = np.concatenate(condition_data, axis=0)
Y = np.concatenate([np.full(data.shape[0], i) for i, data in enumerate(condition_data)])

# Keep only gradiometer channels. The indices refer to the channel axis of the
# per-condition (un-picked) epochs, i.e. the channel axis of X. `exclude=[]`
# keeps channels marked bad, like the former `_channel_type_idx['grad']`
# lookup. NOTE(review): confirm both select the same channels on your MNE
# version.
grad_idx = mne.pick_types(epochs_list[-1].info, meg='grad', exclude=[])
X = X[:, grad_idx, :]
original_X = X.copy()
original_Y = Y.copy()
original_Y = Y.copy()

In [15]:
# Options for `mneflow.produce_tfrecords` (used by the commented-out export below).
import_opt = dict(
    savepath='../tfr/',
    out_name='mne_sample_epochs',
    fs=600,
    input_type='trials',
    target_type='int',
    picks={'meg': 'grad'},
    scale=True,  # apply baseline_scaling
    crop_baseline=True,  # remove baseline interval after scaling
    decimate=None,
    scale_interval=(0, 60),  # indices in time axis corresponding to baseline interval
    n_folds=5,
    overwrite=True,
    segment=False,
)

# Hyper-parameters shared by the model designs in this notebook (final values,
# listed in the order the keys are inserted).
specs = {
    'filter_length': 17,
    'n_latent': 4,
    'pooling': 5,
    'stride': 5,
    'padding': 'SAME',
    'pool_type': 'max',
    'nonlin': tf.nn.relu,
    'l1': 3e-3,
    'l2': 0,
    'l1_scope': ['fc', 'dmx', 'tconv', 'fc'],
    'l2_scope': [],
    'maxnorm_scope': [],
    'dropout': .5,
}

# NOTE(review): the data preparation below is disabled, but later cells still
# use names it defines (`meta`, `dataset`, `n_times`, `n_channels`). They only
# work with state left over from an earlier kernel session.
# out_dim = len(np.unique(original_Y))
# Y = original_Y.copy()
# Y = one_hot_encoder(Y)
# X = original_X.copy()
# X = np.transpose(np.expand_dims(X, axis = 1), (0, 1, 3, 2))
# print(X.shape)
# n_samples, _, n_times, n_channels = X.shape
# X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(original_X, original_Y, train_size=.85)

# # write TFRecord files and metadata file to disk
# meta = mneflow.produce_tfrecords((original_X, original_Y), **import_opt)
# dataset = mneflow.Dataset(meta, train_batch=100)

In [16]:
# Number of target classes (fixed here rather than `len(np.unique(original_Y))`).
out_dim = 8

# LF-CNN design: spatial de-mixing (identity activation) -> temporal
# convolution -> temporal pooling -> dropout -> dense read-out.
# NOTE(review): the first ModelDesign argument is the model input. `None`
# presumably means it is supplied later; other designs pass a tf.keras.Input.
lfcnn_demixing = DeMixing(size=specs['n_latent'], nonlin=tf.identity, axis=3, specs=specs)
lfcnn_temporal = LFTConv(
    size=specs['n_latent'],
    nonlin=specs['nonlin'],
    filter_length=specs['filter_length'],
    padding=specs['padding'],
    specs=specs,
)
lfcnn_pooling = TempPooling(
    pooling=specs['pooling'],
    pool_type=specs['pool_type'],
    stride=specs['stride'],
    padding=specs['padding'],
)
lfcnnd = ModelDesign(
    None,
    lfcnn_demixing,
    lfcnn_temporal,
    lfcnn_pooling,
    tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    Dense(size=out_dim, nonlin=tf.identity, specs=specs),
)


class ZubarevBaseNet(BaseModel):
    """mneflow ``BaseModel`` whose ``train`` also accepts extra Keras callbacks.

    Supported training modes:

    * ``'single_fold'``: fit on ``dataset.train``, validate on ``dataset.val``.
    * ``'cv'``: one fit per fold of ``h_params['train_paths']``
      (``val_fold_ind`` = fold index).
    * ``'loso'``: leave-one-subject-out. Fit on every ``train_paths`` entry but
      one, then evaluate on the matching ``test_paths`` entry.

    Every ``fit`` call receives an ``EarlyStopping`` callback on ``val_loss``
    (best weights restored), followed by the user-supplied callbacks.
    """

    _TRAIN_MODES = ('single_fold', 'cv', 'loso')

    def __init__(self, Dataset, specs=None):
        # Fresh dict per instance: a `dict()` default is created once and would
        # be shared by every model built without explicit specs.
        super().__init__(Dataset, {} if specs is None else specs)

    def train(
        self,
        n_epochs,
        eval_step=None,
        min_delta=1e-6,
        early_stopping=3,
        mode='single_fold',
        *,
        callbacks=None
    ):
        """Train ``self.km`` in the requested ``mode``.

        Args:
            n_epochs: Maximum number of epochs per ``fit`` call.
            eval_step: Steps per epoch. When falsy, it defaults to
                ``train_size // train_batch + 1`` from ``dataset.h_params``.
            min_delta: Minimum ``val_loss`` improvement counted by early stopping.
            early_stopping: Early-stopping patience, in epochs.
            mode: ``'single_fold'``, ``'cv'`` or ``'loso'``.
            callbacks: Optional extra Keras callbacks passed to every ``fit``.

        Returns:
            ``None`` for ``'single_fold'``. For ``'cv'`` and ``'loso'``, the
            per-fold ``(cv_losses, cv_metrics)`` lists.

        Raises:
            ValueError: If ``mode`` is not a supported mode. Previously an
                unknown mode silently trained nothing.
        """
        if mode not in self._TRAIN_MODES:
            raise ValueError(
                "Unknown training mode {!r}; expected one of {}".format(mode, self._TRAIN_MODES)
            )
        extra_callbacks = [] if callbacks is None else list(callbacks)

        stop_early = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            min_delta=min_delta,
            patience=early_stopping,
            restore_best_weights=True
        )
        fit_callbacks = [stop_early, *extra_callbacks]

        if not eval_step:
            train_size = self.dataset.h_params['train_size']
            eval_step = train_size // self.dataset.h_params['train_batch'] + 1

        self.train_params = [n_epochs, eval_step, early_stopping, mode]

        if mode == 'single_fold':
            self._fit_with_callbacks(
                self.dataset.train, self.dataset.val, n_epochs, eval_step, fit_callbacks
            )
            self.v_loss, self.v_metric = self.evaluate(self.dataset.val)
            self.v_loss_sd = 0
            self.v_metric_sd = 0
            print("Training complete: loss: {}, Metric: {}".format(self.v_loss, self.v_metric))
            self.update_log()
            return None

        if mode == 'cv':
            n_folds = len(self.dataset.h_params['folds'][0])
            print("Running cross-validation with {} folds".format(n_folds))
            run_fold = self._run_cv_fold
        else:  # 'loso'
            n_folds = len(self.dataset.h_params['test_paths'])
            print("Running leave-one-subject-out CV with {} subject".format(n_folds))
            run_fold = self._run_loso_fold

        losses = []
        metrics = []
        for jj in range(n_folds):
            print("fold:", jj)
            loss, metric = run_fold(jj, n_epochs, eval_step, fit_callbacks)
            losses.append(loss)
            metrics.append(metric)
            # Reset the weights between folds (shuffle_weights presumably
            # re-initialises them). The weights trained on the last fold are
            # deliberately kept on the model.
            if jj < n_folds - 1:
                self.shuffle_weights()
            print("Fold: {} Loss: {:.4f}, Metric: {:.4f}".format(jj, loss, metric))

        return self._summarize_folds(mode, n_folds, losses, metrics)

    def _fit_with_callbacks(self, train, val, n_epochs, eval_step, fit_callbacks):
        """Run ``self.km.fit`` with the settings shared by every mode; store and return the history."""
        self.t_hist = self.km.fit(
            train,
            validation_data=val,
            epochs=n_epochs, steps_per_epoch=eval_step,
            shuffle=True,
            validation_steps=self.dataset.validation_steps,
            callbacks=fit_callbacks, verbose=2
        )
        return self.t_hist

    def _run_cv_fold(self, jj, n_epochs, eval_step, fit_callbacks):
        """Fit on the train/validation split built with ``val_fold_ind=jj``; return ``evaluate(val)``."""
        train, val = self.dataset._build_dataset(
            self.dataset.h_params['train_paths'],
            train_batch=self.dataset.training_batch,
            test_batch=self.dataset.validation_batch,
            split=True, val_fold_ind=jj
        )
        self._fit_with_callbacks(train, val, n_epochs, eval_step, fit_callbacks)
        return self.evaluate(val)

    def _run_loso_fold(self, jj, n_epochs, eval_step, fit_callbacks):
        """Fit without ``train_paths[jj]``; return ``evaluate`` on ``test_paths[jj]``."""
        test_subj = self.dataset.h_params['test_paths'][jj]
        train_subjs = self.dataset.h_params['train_paths'].copy()
        train_subjs.pop(jj)

        train, val = self.dataset._build_dataset(
            train_subjs,
            train_batch=self.dataset.training_batch,
            test_batch=self.dataset.validation_batch,
            split=True, val_fold_ind=0
        )
        self._fit_with_callbacks(train, val, n_epochs, eval_step, fit_callbacks)
        test = self.dataset._build_dataset(
            test_subj,
            test_batch=None,
            split=False
        )
        return self.evaluate(test)

    def _summarize_folds(self, mode, n_folds, losses, metrics):
        """Store per-fold scores and their mean/std, print a summary, update the log."""
        self.cv_losses = losses
        self.cv_metrics = metrics
        self.v_loss = np.mean(losses)
        self.v_metric = np.mean(metrics)
        self.v_loss_sd = np.std(losses)
        self.v_metric_sd = np.std(metrics)
        print("{} with {} folds completed. Loss: {:.4f} +/- {:.4f}. Metric: {:.4f} +/- {:.4f}".format(
            mode, n_folds, self.v_loss, self.v_loss_sd, self.v_metric, self.v_metric_sd))
        self.update_log()
        return self.cv_losses, self.cv_metrics


class ZubarevNet(ZubarevBaseNet):
    """``ZubarevBaseNet`` whose Keras graph is produced by a ``ModelDesign``.

    Args:
        Dataset: Dataset object forwarded to the base model.
        specs: Hyper-parameter dict. Missing keys are filled in place with
            defaults. ``None`` (the default) starts from a fresh empty dict.
        design: Callable (a ``ModelDesign``) applied to ``self.inputs`` in
            ``build_graph``. The default is the module-level ``lfcnnd``, bound
            when this class is defined.
        design_name: Stored as ``self.scope``.
    """

    def __init__(self, Dataset, specs=None, design=lfcnnd, design_name='design'):
        # A `dict()` default was created once and then mutated by setdefault,
        # leaking defaults/edits between every instance built without specs.
        if specs is None:
            specs = {}
        self.scope = design_name
        self.design = design
        # Built per call so list values (e.g. l1_scope) are never shared between instances.
        defaults = {
            'filter_length': 7,
            'n_latent': 4,
            'pooling': 4,
            'stride': 4,
            'padding': 'SAME',
            'pool_type': 'max',
            'nonlin': tf.nn.relu,
            'l1': 3e-4,
            'l2': 0,
            'l1_scope': ['fc', 'demix', 'lf_conv'],
            'l2_scope': [],
            'maxnorm_scope': [],
        }
        for key, value in defaults.items():
            specs.setdefault(key, value)

        super().__init__(Dataset, specs)

    def build_graph(self):
        """Return the output tensor of ``self.design`` applied to ``self.inputs``."""
        return self.design(self.inputs)

    def set_design(self, design: ModelDesign):
        """Replace the design used by subsequent ``build_graph`` calls."""
        self.design = design

Setting reg for fc, to l1


In [6]:
# "Simple net": de-mixing -> two stacked temporal convolutions -> keep every
# second time sample -> dropout -> dense read-out.
# NOTE(review): `n_times` / `n_channels` are only assigned in a later cell (or
# in the commented-out data preparation), so this cell fails on a fresh kernel.
simplenet_input = tf.keras.Input(shape=(1, n_times, n_channels))
simplenet_temporal_kwargs = dict(
    size=specs['n_latent'],
    nonlin=specs['nonlin'],
    filter_length=specs['filter_length'],
    padding=specs['padding'],
    specs=specs,
)
simplenetd = ModelDesign(
    simplenet_input,
    DeMixing(size=specs['n_latent'], nonlin=tf.identity, axis=3, specs=specs),
    LFTConv(**simplenet_temporal_kwargs),
    LFTConv(**simplenet_temporal_kwargs),
    LayerDesign(lambda X: X[:, :, ::2, :]),
    tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    Dense(size=out_dim, nonlin=tf.identity, specs=specs),
)

# Recurrent variant: squeeze axis 1 (size 1 for the (1, n_times, n_channels)
# inputs used elsewhere), run a bidirectional LSTM with both directions summed,
# restore axis 1, then temporal convolution, pooling, dropout and read-out.
lfrnn_squeeze = LayerDesign(tf.squeeze, axis=1)
lfrnn_bilstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(
        specs['n_latent'],
        bias_regularizer='l1',
        return_sequences=True,
        kernel_regularizer=tf.keras.regularizers.L1(.01),
        recurrent_regularizer=tf.keras.regularizers.L1(.01),
        dropout=0.4,
        recurrent_dropout=0.4,
    ),
    merge_mode='sum',
)
lfrnn_unsqueeze = LayerDesign(tf.expand_dims, axis=1)
lfrnnd = ModelDesign(
    None,
    lfrnn_squeeze,
    lfrnn_bilstm,
    lfrnn_unsqueeze,
    LFTConv(
        size=specs['n_latent'],
        nonlin=specs['nonlin'],
        filter_length=specs['filter_length'],
        padding=specs['padding'],
        specs=specs,
    ),
    TempPooling(
        pooling=specs['pooling'],
        pool_type=specs['pool_type'],
        stride=specs['stride'],
        padding=specs['padding'],
    ),
    tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    Dense(size=out_dim, nonlin=tf.identity, specs=specs),
)


# Earlier variants, kept for reference:
# newnetd = ModelDesign(
#     tf.keras.Input(shape=(1, n_times, n_channels)),
#     Deconw(kernel_size=(specs['n_latent'], specs['filter_length']), activation='relu', kernel_regularizer='l1'),
#     tf.keras.layers.Conv2D(1, (1, specs['filter_length']), activation='relu', kernel_regularizer='l2'),
#     LayerDesign(
#         lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 2, 3, 1))
#     ),
#     tf.keras.layers.Conv2D(1, (1, 204), padding='same'),
#     LayerDesign(
#         lambda X: tf.transpose(X, (0, 3, 2, 1))
#     ),
#     LayerDesign(
#         lambda X: X[:, :, ::specs['pooling'], :]
#     ),
#     tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
#     Dense(size=out_dim, nonlin=tf.identity, specs=specs)
# )

# newnetd = ModelDesign(
#     tf.keras.Input(shape=(1, n_times, n_channels)),
#     Deconw(kernel_size=(specs['n_latent'], specs['filter_length']), activation='relu', kernel_regularizer='l1'),
#     tf.keras.layers.Conv2D(1, (1, specs['filter_length']), activation='relu'),
#     LayerDesign(
#         lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))
#     ),
#     tf.keras.layers.DepthwiseConv2D((204, 1), kernel_regularizer='l1'),
#     LayerDesign(
#         lambda X: X[:, :, ::2, :]
#     ),
#     tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
#     Dense(size=out_dim, nonlin=tf.identity, specs=specs)
# )

# Per-channel deconvolution front end followed by depthwise temporal filtering
# and depthwise spatial de-mixing (named 'demixing').
# For a (batch, 1, n_times, n_channels) input with L = n_latent, K = filter_length:
#   Deconw (kernel (L, K))           -> (batch, n_channels, L, n_times + K - 1, 1)
#   squeeze + transpose (0, 1, 3, 2) -> (batch, n_channels, n_times + K - 1, L)
#   DepthwiseConv2D (1, K)           -> (batch, n_channels, n_times, L)
#   DepthwiseConv2D (n_channels, 1)  -> (batch, 1, n_times, L)
#   [:, :, ::2, :]                   -> every second time sample
newnet_deconv = Deconw(
    kernel_size=(specs['n_latent'], specs['filter_length']),
    activation='relu',
    kernel_regularizer='l1',
)
newnet_to_channels_last = LayerDesign(
    lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))
)
newnetd = ModelDesign(
    tf.keras.Input(shape=(1, n_times, n_channels)),
    newnet_deconv,
    newnet_to_channels_last,
    tf.keras.layers.DepthwiseConv2D(
        (1, specs['filter_length']), activation='relu', depthwise_regularizer='l1'
    ),
    tf.keras.layers.DepthwiseConv2D(
        (n_channels, 1), name='demixing', depthwise_regularizer='l1'
    ),
    LayerDesign(lambda X: X[:, :, ::2, :]),
    tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    Dense(size=out_dim, nonlin=tf.identity, specs=specs),
)

Setting reg for fc, to l1
Setting reg for fc, to l1
Setting reg for fc, to l1


In [7]:
import wandb


class WanbCallback(tf.keras.callbacks.Callback):
    """Keras callback that streams per-epoch logs and training runtime to wandb.

    Args:
        model: The wrapper model being trained (e.g. ``ZubarevNet``). Stored as
            ``self.net``. ``self.model`` is left to Keras, which sets it to the
            Keras model via ``set_model`` when training starts.
        meta: Dataset metadata. Stored as ``self.meta`` and not used here.
        *args, **kwargs: Forwarded to ``wandb.init``, which runs immediately.
    """

    def __init__(
        self,
        model,
        meta,
        *args, **kwargs):
        super().__init__()
        self.net = model
        self.meta = meta
        self.start_time = perf_counter()
        wandb.init(*args, **kwargs)

    def on_train_begin(self, logs=None):
        # Time each fit() from its real start, not from when the callback was built.
        self.start_time = perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        """Log the epoch's metrics dict (an empty dict if Keras passes none)."""
        wandb.log(dict(logs or {}))

    def on_train_end(self, logs=None):
        """Log the wall-clock duration of the finished fit() as ``train_runtime``."""
        train_runtime = perf_counter() - self.start_time
        wandb.log(dict(
            train_runtime=train_runtime
        ))

18:34:15, git.cmd DEBUG Popen(['git', 'version'], cwd=/home/user/Projects/FingerMovementDecoder/dirty_field/net_dev, universal_newlines=False, shell=None, istream=None)
18:34:15, git.cmd DEBUG Popen(['git', 'version'], cwd=/home/user/Projects/FingerMovementDecoder/dirty_field/net_dev, universal_newlines=False, shell=None, istream=None)


In [8]:


# Train the "simplenet" design and report accuracy on the train and test paths.
# Alternative: model = ZubarevNet(dataset, specs, lfcnnd, 'lfcnn')
# NOTE(review): `dataset` and `meta` are only created by the commented-out
# TFRecord export in the configuration cell, so this cell fails on a fresh kernel.
model_name = 'simplenet'
model = ZubarevNet(dataset, specs, simplenetd, model_name)
model.build()
model.train(
    n_epochs=25, eval_step=100, early_stopping=5,
    callbacks=[
        WanbCallback(
            model, meta,
            project='fmdec',
            config=specs,
            name=model_name
        )
    ]
)
y_true_train, y_pred_train = model.predict(meta['train_paths'])

# Only the test-set prediction is timed.
t1 = perf_counter()
y_true_test, y_pred_test = model.predict(meta['test_paths'])
runtime = perf_counter() - t1

# Compute each accuracy once and reuse it for the log and for wandb.
train_acc = sklearn.metrics.accuracy_score(one_hot_decoder(y_true_train), one_hot_decoder(y_pred_train))
test_acc = sklearn.metrics.accuracy_score(one_hot_decoder(y_true_test), one_hot_decoder(y_pred_test))
logging.info(
    f'{model.scope} performance:\n'
    f'\truntime: {runtime : .4f}\n'
    f'\ttrain-set: {train_acc}\n'
    f'\ttest-set: {test_acc}'
)

wandb.log(dict(
    test_runtime=runtime,
    train_acc=train_acc,
    test_acc=test_acc
))
# Close the run opened by WanbCallback so the next experiment starts a new run.
wandb.finish()

Setting reg for dmx, to l1
Built: dmx input: (None, 1, 361, 204)
Setting reg for tconv, to l1
Built: tconv input: (None, 1, 361, 4)
Setting reg for tconv, to l1
Built: tconv input: (None, 1, 361, 4)
Built: fc input: (None, 1, 181, 4)
Input shape: (1, 361, 204)
y_pred: (None, 8)
Initialization complete!
18:34:16, wandb.jupyter ERROR Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
18:34:16, urllib3.connectionpool DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
18:34:17, urllib3.connectionpool DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 None
18:34:17, urllib3.connectionpool DEBUG Starting new HTTPS connection (1): api.wandb.ai:443
18:34:17, urllib3.connectionpool DEBUG https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 None


[34m[1mwandb[0m: Currently logged in as: [33malexzab[0m. Use [1m`wandb login --relogin`[0m to force relogin


18:34:17, git.cmd DEBUG Popen(['git', 'cat-file', '--batch-check'], cwd=/home/user/Projects/FingerMovementDecoder, universal_newlines=False, shell=None, istream=<valid stream>)


Epoch 1/25
100/100 - 4s - loss: 3.2023 - cat_ACC: 0.1462 - val_loss: 3.1243 - val_cat_ACC: 0.1330 - 4s/epoch - 43ms/step
Epoch 2/25
100/100 - 3s - loss: 3.0497 - cat_ACC: 0.1801 - val_loss: 3.0539 - val_cat_ACC: 0.1436 - 3s/epoch - 28ms/step
Epoch 3/25
100/100 - 3s - loss: 2.9476 - cat_ACC: 0.2106 - val_loss: 2.9896 - val_cat_ACC: 0.1596 - 3s/epoch - 28ms/step
Epoch 4/25
100/100 - 3s - loss: 2.8510 - cat_ACC: 0.2413 - val_loss: 2.9109 - val_cat_ACC: 0.1915 - 3s/epoch - 26ms/step
Epoch 5/25
100/100 - 3s - loss: 2.7248 - cat_ACC: 0.2976 - val_loss: 2.7746 - val_cat_ACC: 0.2128 - 3s/epoch - 28ms/step
Epoch 6/25
100/100 - 3s - loss: 2.5037 - cat_ACC: 0.3720 - val_loss: 2.5968 - val_cat_ACC: 0.2872 - 3s/epoch - 27ms/step
Epoch 7/25
100/100 - 3s - loss: 2.3333 - cat_ACC: 0.4248 - val_loss: 2.4594 - val_cat_ACC: 0.3564 - 3s/epoch - 29ms/step
Epoch 8/25
100/100 - 3s - loss: 2.1269 - cat_ACC: 0.5025 - val_loss: 2.3434 - val_cat_ACC: 0.3989 - 3s/epoch - 29ms/step
Epoch 9/25
100/100 - 3s - loss: 

2022-08-28 18:35:32.665895: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 276901440 exceeds 10% of free system memory.
2022-08-28 18:35:33.037427: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 276901440 exceeds 10% of free system memory.


No dataset specified using validation dataset (Default)
18:35:35, root INFO simplenet performance:
	runtime:  1.6491
	train-set: 0.948936170212766
	test-set: 0.7712765957446809


: 

In [17]:
# Data geometry (trials, 1, time samples, channels), hard-coded so the designs
# below can be built without loading any data.
n_samples, _, n_times, n_channels = (940, 1, 361, 204)
out_dim = 8

# Earlier variants, kept for reference:
# newnetd = ModelDesign(
#     tf.keras.Input(shape=(1, n_times, n_channels)),
#     Deconw(kernel_size=(specs['n_latent'], specs['filter_length']), activation='relu', kernel_regularizer='l1'),
#     tf.keras.layers.Conv2D(1, (1, specs['filter_length']), activation='relu', kernel_regularizer='l2'),
#     LayerDesign(
#         lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))
#     ),
#     # TensorShape([None, 4, 361, 204])
#     tf.keras.layers.Conv2D(1, (204, 1), padding='valid', name='demixing'),
#     # LayerDesign(
#     #     lambda X: tf.transpose(X, (0, 3, 2, 1))
#     # ),
#     # LayerDesign(
#     #     lambda X: X[:, :, ::specs['pooling'], :]
#     # ),
#     # tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
#     # Dense(size=out_dim, nonlin=tf.identity, specs=specs)
# )

# newnetd = ModelDesign(
#     tf.keras.Input(shape=(1, n_times, n_channels)),
#     Deconw(kernel_size=(4, 10)),
#     tf.keras.layers.Conv2D(1, (1, 10)),
#     LayerDesign(
#         lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))
#     ),
#     tf.keras.layers.DepthwiseConv2D((204, 1), name='demixing'),
#     LayerDesign(
#         lambda X: X[:, :, ::2, :]
#     ),
#     tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
#     Dense(size=out_dim, nonlin=tf.identity, specs=specs)
# )

# Shapes for a (batch, 1, n_times, n_channels) input:
#   Deconw (kernel (4, 10))          -> (batch, n_channels, 4, n_times + 9, 1)
#   squeeze + transpose (0, 1, 3, 2) -> (batch, n_channels, n_times + 9, 4)
#   DepthwiseConv2D (1, 10)          -> (batch, n_channels, n_times, 4)
#   DepthwiseConv2D (n_channels, 1)  -> (batch, 1, n_times, 4)
#   [:, :, ::2, :]                   -> (batch, 1, ceil(n_times / 2), 4)
newnetd = ModelDesign(
    tf.keras.Input(shape=(1, n_times, n_channels)),
    Deconw(kernel_size=(4, 10)),
    LayerDesign(lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))),
    tf.keras.layers.DepthwiseConv2D((1, 10)),
    tf.keras.layers.DepthwiseConv2D((n_channels, 1), name='demixing'),
    LayerDesign(lambda X: X[:, :, ::2, :]),
    tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    Dense(size=out_dim, nonlin=tf.identity, specs=specs),
)

print('input_shape: ', (1, n_times, n_channels))
newnetd().shape

Setting reg for fc, to l1
input_shape:  (1, 361, 204)
Built: fc input: (None, 1, 181, 4)


TensorShape([None, 8])

In [None]:

def initialize_specs(specs: Optional[dict] = None) -> dict:
    """Fill in default hyper-parameters for the LF-CNN style networks.

    Missing keys are added with ``dict.setdefault``. Keys already present are
    left untouched.

    Args:
        specs: Hyper-parameter dict to complete; it is mutated in place.
            ``None`` (the default) starts from a new empty dict. Previously
            ``None`` raised AttributeError, which broke ``NetworkBuilder``'s
            default ``specs=None``.

    Returns:
        The completed dict: the same object as ``specs`` when one is given.
    """
    if specs is None:
        specs = {}
    # Built on every call so list defaults are never shared between results.
    defaults = {
        'filter_length': 7,
        'n_latent': 4,
        'pooling': 4,
        'stride': 4,
        'padding': 'SAME',
        'pool_type': 'max',
        'nonlin': tf.nn.relu,
        'l1': 3e-4,
        'l2': 0,
        'l1_scope': ['fc', 'demix', 'lf_conv'],
        'l2_scope': [],
        'maxnorm_scope': [],
    }
    for key, value in defaults.items():
        specs.setdefault(key, value)
    return specs


# Type alias for NetworkBuilder's layer arguments: an optional list of callables
# mapping a tensor to a tensor.
# NOTE(review): NetworkBuilder also uses this alias for *single* layers
# (demixing / temporal filter / pooling arguments, add_layer), which are not
# lists. Consider a separate alias such as Optional[Callable[[tf.Tensor], tf.Tensor]].
LayerLike = Optional[list[Callable[[tf.Tensor], tf.Tensor]]]

class NetworkBuilder(object):
    """Collect the layers of a ``ModelDesign`` one part at a time.

    Default de-mixing, temporal-filter and pooling layers are created from
    ``specs`` (completed by ``initialize_specs``) unless explicit layers are
    given. The ``add_*`` methods append to ``self.parts``. ``design()`` wraps
    the collected parts in a ``ModelDesign`` whose input is ``None``.

    NOTE(review): each ``add_*`` call appends the *same* layer instance, so
    adding e.g. the temporal filter twice reuses one layer object.
    """

    def __init__(
        self,
        design_parts: LayerLike = None,
        demixing_layer: LayerLike = None,
        temporal_filtering_layer: LayerLike = None,
        pooling_layer: LayerLike = None,
        specs: Optional[dict] = None
    ):
        # initialize_specs fills defaults via setdefault, so hand it a dict even
        # when no specs were given (passing None used to raise AttributeError).
        self.specs = initialize_specs({} if specs is None else specs)
        # Copy the initial parts so later add_* calls never modify the caller's list.
        self.parts = list(design_parts) if design_parts else []
        self.demixing = demixing_layer if demixing_layer else DeMixing(
            size=self.specs['n_latent'],
            nonlin=tf.identity,
            axis=3, specs=self.specs
        )
        self.pooling = pooling_layer if pooling_layer else TempPooling(
            pooling=self.specs['pooling'],
            pool_type=self.specs['pool_type'],
            stride=self.specs['stride'],
            padding=self.specs['padding'],
        )
        self.temporal_filter = temporal_filtering_layer if temporal_filtering_layer else LFTConv(
            size=self.specs['n_latent'],
            nonlin=self.specs['nonlin'],
            filter_length=self.specs['filter_length'],
            padding=self.specs['padding'],
            specs=self.specs
        )

    def add_temporal_filter(self):
        """Append the temporal-filter layer."""
        self.parts.append(self.temporal_filter)

    def add_demixing(self):
        """Append the de-mixing layer."""
        self.parts.append(self.demixing)

    def add_pooling(self):
        """Append the pooling layer."""
        self.parts.append(self.pooling)

    def add_layer(self, layer: LayerLike):
        """Append an arbitrary layer/callable."""
        self.parts.append(layer)

    def design(self):
        """Return a ``ModelDesign`` (input ``None``) over the collected parts."""
        return ModelDesign(
            None, *self.parts
        )


class NetworkDirector(object):
    """Drive a ``NetworkBuilder`` from a compact configuration string.

    Each character of the config selects one part:

    * ``'f'``: temporal filter
    * ``'d'``: de-mixing
    * ``'p'``: pooling

    Any other character is looked up in ``layers_dict``, and the layer found
    there is added with ``builder.add_layer``.
    """

    def __init__(self, builder: NetworkBuilder, layers_dict: Optional[dict[str, LayerLike]] = None):
        self.builder = builder
        self.layers_dict = layers_dict if layers_dict is not None else {}

    def construct(self, config: str):
        """Append one part to the builder for every character of ``config``.

        Iterates over the characters directly. The former
        ``config.split('')`` always raised ``ValueError: empty separator``.

        Raises:
            KeyError: If a character is neither a built-in part code nor a key
                of ``layers_dict``.
        """
        builtin_parts = {
            'f': self.builder.add_temporal_filter,
            'd': self.builder.add_demixing,
            'p': self.builder.add_pooling,
        }
        for part_config in config:
            if part_config in builtin_parts:
                builtin_parts[part_config]()
            elif part_config in self.layers_dict:
                self.builder.add_layer(self.layers_dict[part_config])
            else:
                raise KeyError(
                    f'Unknown network part {part_config!r} in config {config!r}'
                )

In [20]:

# Same per-channel deconvolution front end, but spatial de-mixing comes before
# the temporal depthwise filter. The read-out layers are disabled, so the
# design ends at (batch, 1, n_times, 4).
newnetd = ModelDesign(
    tf.keras.Input(shape=(1, n_times, n_channels)),
    Deconw(kernel_size=(4, 10)),
    LayerDesign(lambda X: tf.transpose(tf.squeeze(X, axis=-1), (0, 1, 3, 2))),
    tf.keras.layers.DepthwiseConv2D((n_channels, 1), name='demixing'),
    tf.keras.layers.DepthwiseConv2D((1, 10)),
    # Disabled layers, kept for reference:
    # tf.keras.layers.DepthwiseConv2D((204, 1), name='demixing'),
    # LayerDesign(
    #     lambda X: X[:, :, ::2, :]
    # ),
    # tf.keras.layers.Dropout(specs['dropout'], noise_shape=None),
    # Dense(size=out_dim, nonlin=tf.identity, specs=specs)
)

print('input_shape: ', (1, n_times, n_channels))
newnetd().shape

input_shape:  (1, 361, 204)


TensorShape([None, 1, 361, 4])

In [43]:
# Wrap the current design in a standalone Keras model for inspection.
# NOTE(review): this rebinds `model`, which earlier held the trained ZubarevNet.
inp = tf.keras.Input(shape=(1, n_times, n_channels))
model = tf.keras.Model(inputs=inp, outputs=newnetd(inp))