# Spatial patterns

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

import numpy as np
import scipy.signal as sg

from envelope_detector import EnvelopeDetector, create_importance_indices, create_spatial_patterns, create_temporal_patterns



In [2]:
def diff(x, y):
    """Return the distance between ``x`` and ``y`` scaled by their geometric-mean norm.

    Computes ``||x - y|| / sqrt(||x|| * ||y||)`` using ``torch.linalg.norm``.
    The result is a 0-d tensor. It is 0 for equal inputs with non-zero norm and
    is not finite when either input has zero norm (division by zero).
    """
    distance = torch.linalg.norm(x - y)
    scale = torch.sqrt(torch.linalg.norm(x) * torch.linalg.norm(y))
    return distance / scale

In [3]:
# --- Synthetic data and shared settings --------------------------------------
# Seed NumPy's global RNG once so the random data below is reproducible
# (the seed used to be set twice with no random draw in between).
np.random.seed(0)

# Data dimensions.
npoints = 300     # size of the last (time) axis of `data`
nchannels = 30    # size of the channel axis of `data`
nfeatures = 5     # number of features requested from the models below

# Settings shared by the PyTorch and TensorFlow models and by the spectral analysis.
spatial_bias = False
fs = 1000         # passed to EnvelopeDetector(fs_in=...) and scipy.signal.welch(fs=...)
temporal_filter_size = 15
downsample_coef = 10
output_layer = 3
nyquist = fs // 2

# Gaussian noise laid out as (batch, channel, time), matching the einops patterns below.
data = np.random.normal(size=(1000, nchannels, npoints))
# All batches concatenated along time: (channel, batch * time).
x = einops.rearrange(data, 'b c t -> c (b t)')
# Channels-last copy for the Keras models: (batch, time, channel).
data_tf = einops.rearrange(data, 'b c t -> b t c')

In [5]:
# Standalone detector used for the spatial/temporal pattern checks below.
# NOTE(review): this call passes `activation=` whereas the SimpleNet cells later
# in the notebook pass `activation_method=` — confirm which keyword
# EnvelopeDetector actually expects.
envelope_detector = EnvelopeDetector(
    nchannels=nchannels, 
    nfeatures=nfeatures, 
    temporal_filter_size=temporal_filter_size, 
    downsample_coef=downsample_coef, 
    fs_in=fs,
    spatial_bias=spatial_bias,
    activation='demodulation', 
    downsample_method='avepool',
)

In [6]:
# Filters exposed through the detector's accessor methods; these are the inputs
# handed to create_spatial_patterns / create_temporal_patterns further down.
spatial_filters = envelope_detector.get_spatial_filter()

temporal_filters = envelope_detector.get_temporal_filter()

In [31]:
def covariance(x, unbiased=True, mean=None):
    """Covariance between the rows of ``x`` taken over its last axis.

    Args:
        x: Tensor shaped ``(..., channels, time)``; leading axes are treated
            as batch axes.
        unbiased: If true, normalise by ``time - 1``; otherwise by ``time``.
        mean: Optional value subtracted from ``x`` before the product. When
            omitted, the mean over the last axis is used.

    Returns:
        Tensor shaped ``(..., channels, channels)``.
    """
    center = torch.mean(x, dim=-1, keepdim=True) if mean is None else mean
    centered = x - center
    # `unbiased` is a bool, so subtracting it removes 0 or 1 from the sample count.
    normalizer = 1 / (centered.shape[-1] - unbiased)
    cross_products = torch.einsum('...ct, ...Ct -> ...cC', centered, centered)
    return normalizer * cross_products

In [32]:
# Reference computation of the spatial patterns: for every feature, filter all
# channels with that feature's temporal filter, take the channel covariance of
# the filtered signal and multiply it by the feature's spatial filter.
spatial_patterns = torch.zeros((nfeatures, nchannels), dtype=torch.float32)

# The float32 copy of `x` does not depend on the feature, so build it once
# instead of converting the NumPy array on every loop iteration.
x_tensor = torch.tensor(x, dtype=torch.float32)

for feature in range(nfeatures):
    # Slice with feature:feature+1 to keep the leading axis needed by the repeat below.
    temporal_filter = envelope_detector.temporal_filter.weight.data[feature:feature+1]
    spatial_filter = envelope_detector.spatial_filter.weight.data[feature]
    # One copy of the filter per channel; with groups=nchannels every channel
    # is convolved with its own (identical) kernel.
    temporal_filter_ = einops.repeat(temporal_filter, '1 1 t -> c 1 t', c=nchannels)

    x_filtered = F.conv1d(x_tensor, temporal_filter_, bias=None, padding='same', groups=nchannels)
    x_cov = covariance(x_filtered, unbiased=True)

    pattern = torch.einsum('...cC, Ck -> c', x_cov, spatial_filter)
    spatial_patterns[feature] = pattern

In [33]:
# Library results on the flattened signal `x` and on the batched `data`
# (the latter is processed with nbatch=10).
spatial_results_x = create_spatial_patterns(x, spatial_filters, temporal_filters)
spatial_results_data = create_spatial_patterns(data, spatial_filters, temporal_filters, nbatch=10)

# Relative distance of each library result to the reference `spatial_patterns`.
for spatial_results in (spatial_results_x, spatial_results_data):
    print(diff(spatial_patterns, spatial_results['spatial_patterns']))

tensor(0.)
tensor(0.0014)


# Temporal patterns

In [34]:
def spectrum(signal, fs, nfreq):
    """Amplitude spectrum of ``signal`` along its last axis.

    Args:
        signal: Tensor whose last axis is transformed.
        fs: Sampling frequency; the frequency axis uses a spacing of ``1 / fs``.
        nfreq: FFT length handed to ``torch.fft.fft``.

    Returns:
        ``(frequencies, amplitudes)`` restricted to the first ``nfreq // 2``
        bins: the frequencies from ``torch.fft.fftfreq`` and the absolute
        value of the FFT.
    """
    amplitudes = torch.fft.fft(signal, nfreq, dim=-1).abs()
    frequencies = torch.fft.fftfreq(nfreq, d=1/fs)
    n_amplitudes, n_frequencies = amplitudes.shape[-1], frequencies.shape[-1]
    assert n_amplitudes == n_frequencies, f"{n_amplitudes}!={n_frequencies}"
    keep = slice(None, nfreq // 2)
    return frequencies[keep], amplitudes[..., keep]

In [35]:
# Reference computation of the temporal patterns.
# Project the flattened signal onto the spatial filters.
x_unmixed = envelope_detector.spatial_filter(torch.tensor(x, dtype=torch.float32))
temporal_filter = envelope_detector.temporal_filter.weight.data

# scipy works on NumPy arrays, so detach the layer output before converting it.
x_unmixed_numpy = x_unmixed.cpu().detach().numpy()
input_frequencies, input_spectrum = sg.welch(x_unmixed_numpy, fs=fs, nperseg=nyquist, detrend='constant', axis=-1)
# Drop the last frequency bin so the Welch estimate has the same number of bins
# as `spectrum(..., nfreq=nyquist)` returns; the products below need matching shapes.
input_frequencies = torch.tensor(input_frequencies[:-1], dtype=x_unmixed.dtype, device=x_unmixed.device)
input_spectrum = torch.tensor(input_spectrum[...,:-1], dtype=x_unmixed.dtype, device=x_unmixed.device)

# Amplitude response of every temporal filter, with the singleton middle axis removed.
frequencies_filter, temporal_filters_spectrum = spectrum(temporal_filter, fs, nfreq=nyquist)
temporal_filters_spectrum = einops.rearrange(temporal_filters_spectrum, 'c 1 t -> c t')

# (A no-op self-assignment of `temporal_filters_spectrum` was removed here.)
temporal_patterns_spectrum = temporal_filters_spectrum * input_spectrum
output_spectrum = torch.pow(temporal_filters_spectrum, 2) * input_spectrum

In [36]:
# Library results on the flattened signal `x` (default settings) and on the
# batched `data` (explicit fs / nyquist, processed with nbatch=10).
temporal_results_x = create_temporal_patterns(x, spatial_filters, temporal_filters)
temporal_results_data = create_temporal_patterns(data, spatial_filters, temporal_filters, fs=fs, nyquist=nyquist, nbatch=10)

# Reference spectra keyed by the name under which the library reports them;
# insertion order fixes the order of the printed comparisons.
reference_spectra = {
    'input_spectrum': input_spectrum,
    'temporal_filters_spectrum': temporal_filters_spectrum,
    'temporal_patterns_spectrum': temporal_patterns_spectrum,
    'output_spectrum': output_spectrum,
}

for temporal_results in (temporal_results_x, temporal_results_data):
    for key, reference in reference_spectra.items():
        print(diff(reference, temporal_results[key]))

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.0086)
tensor(0.)
tensor(0.0085)
tensor(0.0085)


# Feature importance

In [37]:
from torch.utils.data import Dataset, DataLoader
from torch.autograd import grad
from envelope_detector.utils import SimpleDataset

In [38]:
import tensorflow as tf
from tensorflow.keras import layers, Input
from tensorflow import keras
from functools import reduce

In [39]:
# Wrap the synthetic `data` array in the project's SimpleDataset and expose it
# in fixed-order (shuffle=False) batches of 100.
dataset = SimpleDataset(data)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False)

In [40]:
def create_model_tf(npoints, nchannels, nfeatures, temporal_filter_size=7, downsample_coef=1, output_layer=3):
    """Build the Keras reference network and expose every intermediate tensor.

    Layers are created in this order (weights are later assigned by layer
    index, so the order matters): pointwise Conv1D -> BatchNormalization ->
    grouped Conv1D with one group per feature -> BatchNormalization ->
    LeakyReLU(-1) -> AveragePooling1D -> pointwise Conv1D with one filter.
    Both batch-norm layers are called with ``training=False`` and have no
    learnable scale or offset. No convolution uses a bias.

    Args:
        npoints: Length of the time axis of the input.
        nchannels: Number of input channels (last input axis).
        nfeatures: Number of filters of the first two convolutions and number
            of groups of the second one.
        temporal_filter_size: Kernel size of the grouped convolution.
        downsample_coef: Pool size and stride of the average pooling.
        output_layer: Accepted but not used by this function.

    Returns:
        A compiled ``keras.Model`` named ``'tf_model'`` whose outputs are the
        seven layer outputs in execution order.
    """
    batchnorm_options = dict(center=False, scale=False, epsilon=1e-5, momentum=0.9)

    inputs = keras.Input(shape=(npoints, nchannels))
    spatial = layers.Conv1D(nfeatures, 1, padding="same", use_bias=False)(inputs)
    spatial_normed = layers.BatchNormalization(**batchnorm_options)(spatial, training=False)
    temporal = layers.Conv1D(
        nfeatures, temporal_filter_size, padding="same", groups=nfeatures, use_bias=False
    )(spatial_normed)
    temporal_normed = layers.BatchNormalization(**batchnorm_options)(temporal, training=False)
    # A negative slope of -1 maps x < 0 to -x.
    rectified = layers.LeakyReLU(-1)(temporal_normed)
    pooled = layers.AveragePooling1D(pool_size=downsample_coef, strides=downsample_coef)(rectified)
    prediction = layers.Conv1D(1, 1, padding="same", use_bias=False)(pooled)

    stage_outputs = (spatial, spatial_normed, temporal, temporal_normed, rectified, pooled, prediction)
    model = keras.Model(inputs, stage_outputs, name='tf_model')
    model.compile()
    return model

In [41]:
def create_model_tf_grad(npoints, nchannels, nfeatures, temporal_filter_size=7, downsample_coef=1, output_layer=3):
    """Build the Keras reference network exposing only the last two tensors.

    The layer stack is: pointwise Conv1D -> BatchNormalization -> grouped
    Conv1D with one group per feature -> BatchNormalization -> LeakyReLU(-1)
    -> AveragePooling1D -> pointwise Conv1D with one filter. Both batch-norm
    layers are called with ``training=False`` and have no learnable scale or
    offset. No convolution uses a bias.

    Args:
        npoints: Length of the time axis of the input.
        nchannels: Number of input channels (last input axis).
        nfeatures: Number of filters of the first two convolutions and number
            of groups of the second one.
        temporal_filter_size: Kernel size of the grouped convolution.
        downsample_coef: Pool size and stride of the average pooling.
        output_layer: Accepted but not used by this function.

    Returns:
        A compiled ``keras.Model`` named ``'tf_model'`` with two outputs: the
        pooled features and the final single-filter convolution applied to them.
    """
    batchnorm_options = dict(center=False, scale=False, epsilon=1e-5, momentum=0.9)

    inputs = keras.Input(shape=(npoints, nchannels))
    spatial = layers.Conv1D(nfeatures, 1, padding="same", use_bias=False)(inputs)
    spatial_normed = layers.BatchNormalization(**batchnorm_options)(spatial, training=False)
    temporal = layers.Conv1D(
        nfeatures, temporal_filter_size, padding="same", groups=nfeatures, use_bias=False
    )(spatial_normed)
    temporal_normed = layers.BatchNormalization(**batchnorm_options)(temporal, training=False)
    # A negative slope of -1 maps x < 0 to -x.
    rectified = layers.LeakyReLU(-1)(temporal_normed)
    pooled = layers.AveragePooling1D(pool_size=downsample_coef, strides=downsample_coef)(rectified)
    prediction = layers.Conv1D(1, 1, padding="same", use_bias=False)(pooled)

    model = keras.Model(inputs, (pooled, prediction), name='tf_model')
    model.compile()
    return model

In [42]:
class SimpleNet(nn.Module):
    """EnvelopeDetector followed by a bias-free 1x1 convolution regressor.

    All keyword arguments are forwarded unchanged to ``EnvelopeDetector``.
    ``nfeatures`` is required because it also sets the number of input
    channels of the regressor.
    """

    def __init__(self, **kwargs):
        # Zero-argument super() resolves against the class this method is
        # defined in. super(self.__class__, self) would call this __init__
        # again, recursing without end, as soon as SimpleNet is subclassed.
        super().__init__()
        self.envelope_detector = EnvelopeDetector(**kwargs)
        self.regressor = nn.Conv1d(kwargs['nfeatures'], 1, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``(z, y)``: the detector output and the regressor applied to it."""
        z = self.envelope_detector(x)
        y = self.regressor(z)
        return z, y

In [43]:
# Keras model that returns every intermediate tensor, used for the
# layer-by-layer comparison against the PyTorch model.
model_tf = create_model_tf(
    npoints=npoints, 
    nchannels=nchannels, 
    nfeatures=nfeatures, 
    temporal_filter_size=temporal_filter_size, 
    downsample_coef=downsample_coef, 
    output_layer=output_layer
)

In [58]:
# Seed PyTorch before building the model so its random initial weights are reproducible.
torch.manual_seed(0)
# NOTE(review): this call passes `activation_method=` whereas the standalone
# EnvelopeDetector near the top of the notebook is built with `activation=` —
# confirm which keyword EnvelopeDetector actually expects.
model_pt = SimpleNet(
    nchannels=nchannels, 
    nfeatures=nfeatures, 
    temporal_filter_size=temporal_filter_size,
    downsample_coef=downsample_coef,
    spatial_bias=spatial_bias,
    activation_method='demodulation', 
    downsample_method='avepool',
)
# Put the module in evaluation mode; the trailing ';' hides the cell's output.
model_pt.eval();

In [59]:
def convert_conv_weights(x):
    """Reverse the axis order of a 3-D convolution kernel.

    PyTorch ``nn.Conv1d`` stores its weight as (out_channels,
    in_channels / groups, kernel_size) while Keras ``Conv1D`` stores its
    kernel as (kernel_size, in_channels / groups, filters). Reversing the
    three axes converts one layout into the other.
    """
    # The axis labels follow the PyTorch -> Keras direction in which this
    # helper is called; the permutation itself is a plain axis reversal.
    return einops.rearrange(x, 'o i k -> k i o')

# Pull the three convolution kernels out of the PyTorch model as NumPy arrays.
detector_pt = model_pt.envelope_detector
spatial_filter = detector_pt.spatial_filter.weight.data.numpy()
temporal_filter = detector_pt.temporal_filter.weight.data.numpy()
regressor_filter = model_pt.regressor.weight.data.numpy()

# Keras layer indices of the three Conv1D layers (index 0 is the input layer);
# the batch-norm layers in between are not assigned any weights here.
for layer_index, kernel in ((1, spatial_filter), (3, temporal_filter), (7, regressor_filter)):
    model_tf.layers[layer_index].set_weights([convert_conv_weights(kernel)])

In [60]:
with torch.no_grad():
    detector_pt = model_pt.envelope_detector
    # Sub-modules in execution order; each consumes the previous one's output.
    stages_pt = (
        detector_pt.spatial_filter,
        detector_pt.spatial_filter_batchnorm,
        detector_pt.temporal_filter,
        detector_pt.temporal_filter_batchnorm,
        detector_pt.activation,
        detector_pt.downsampler,
        model_pt.regressor,
    )
    stage_outputs_pt = []
    hidden = torch.tensor(data, dtype=torch.float32)
    for stage in stages_pt:
        hidden = stage(hidden)
        stage_outputs_pt.append(hidden)
    # Keep the per-stage names available for later inspection.
    a_pt, b_pt, c_pt, d_pt, e_pt, z_pt, y_pt = stage_outputs_pt
    values_pt = [output.numpy() for output in stage_outputs_pt]

In [61]:
def tensor_tf2pt(x):
    """Swap the last two axes of a 3-D array: (batch, time, channel) -> (batch, channel, time)."""
    return einops.rearrange(x, 'b t c -> b c t')

In [62]:
# Run the Keras model on the (batch, time, channel) data and move every output
# to (batch, channel, time) so it lines up with the entries of `values_pt`.
values_tf = [tensor_tf2pt(x.numpy()) for x in model_tf(data_tf)]

In [63]:
# Root-mean-square difference per element between matching PyTorch and Keras stages.
for value_pt, value_tf in zip(values_pt, values_tf):
    residual_norm = np.linalg.norm(value_pt - value_tf)
    # `.size` is the total number of elements, i.e. the product of the shape.
    print(residual_norm / np.sqrt(value_pt.size))

6.671769476697886e-08
7.664036621352898e-08
5.291321907243889e-08
6.268088330850478e-08
6.268088330850478e-08
4.170954531752131e-08
2.0895071145950377e-08


In [64]:
# Same check through the model's own forward pass: its second return value
# against the last Keras output.
with torch.no_grad():
    outputs_pt = model_pt(torch.tensor(data, dtype=torch.float32))
    y_pt_hat = outputs_pt[1].numpy()
y_tf_hat = values_tf[-1]
print(np.linalg.norm(y_pt_hat - y_tf_hat) / np.sqrt(y_pt_hat.size))

2.0895071145950377e-08


In [65]:
# Rebuild the PyTorch model with the same seed and the same arguments as before,
# so the gradient comparison below starts from a freshly constructed model.
torch.manual_seed(0)
model_pt = SimpleNet(
    nchannels=nchannels, 
    nfeatures=nfeatures, 
    temporal_filter_size=temporal_filter_size,
    downsample_coef=downsample_coef,
    spatial_bias=spatial_bias,
    activation_method='demodulation', 
    downsample_method='avepool',
)
# Put the module in evaluation mode; the trailing ';' hides the cell's output.
model_pt.eval();

In [66]:
# Library computation on the PyTorch model; the second return value is compared
# with the TensorFlow gradients computed at the end of the notebook.
importance_indices_pt, gradients_pt = create_importance_indices(model_pt, data, order=1, nbatch=100, device='cpu')

In [67]:
# Display the values returned by create_importance_indices (one per feature).
gradients_pt

tensor([6646.1860, 9556.2588, 3808.4402, 4501.4277, 1986.8008])

In [68]:
# Keras model that returns only the pooled features and the final output,
# which is all the gradient computation below needs.
model_tf_grad = create_model_tf_grad(
    npoints=npoints, 
    nchannels=nchannels, 
    nfeatures=nfeatures, 
    temporal_filter_size=temporal_filter_size, 
    downsample_coef=downsample_coef, 
    output_layer=output_layer
)

In [69]:
# Pull the three convolution kernels out of the PyTorch model as NumPy arrays.
detector_pt = model_pt.envelope_detector
spatial_filter = detector_pt.spatial_filter.weight.data.numpy()
temporal_filter = detector_pt.temporal_filter.weight.data.numpy()
regressor_filter = model_pt.regressor.weight.data.numpy()

# Keras layer indices of the three Conv1D layers (index 0 is the input layer);
# the batch-norm layers in between are not assigned any weights here.
for layer_index, kernel in ((1, spatial_filter), (3, temporal_filter), (7, regressor_filter)):
    model_tf_grad.layers[layer_index].set_weights([convert_conv_weights(kernel)])

In [70]:
# Plain forward pass; both names are assigned again inside the GradientTape
# block that follows, so these values are not the ones differentiated.
z_tf, y_tf = model_tf_grad(data_tf)

In [71]:
# The forward pass has to run inside the tape so that the dependence of y_tf on
# z_tf is recorded.
with tf.GradientTape() as tape:
    z_tf, y_tf = model_tf_grad(data_tf)
# Gradient of the final output with respect to the pooled features.
gradients_tf_ = tape.gradient(y_tf, z_tf)
# Sum of absolute gradients over the first two axes (batch and time in the
# Keras layout), leaving one value per feature.
gradients_tf = np.sum(np.abs(gradients_tf_), axis=(0,1))

In [72]:
# Display the TensorFlow per-feature values for comparison with `gradients_pt`.
gradients_tf

array([6646.95  , 9553.332 , 3808.5874, 4502.1646, 1987.5853],
      dtype=float32)