In [1]:
import jax.numpy as jnp
import numpy as np
import haiku as hk
import jax
from haiku._src.utils import get_channel_index

In [143]:
# Show the shape of the channel-first test tensor.
# NOTE(review): `inputs_channel_first` is only assigned in a later cell (In [203]),
# so this cell raises NameError on a fresh Restart & Run All.
inputs_channel_first.shape

(64, 3, 200, 100)

In [147]:
# Broadcast shape for per-channel statistics: keep the extent of the channel
# axis and collapse every other axis to 1.
channel_index = 1
[
    inputs_channel_first.shape[dim] if dim == channel_index else 1
    for dim in range(inputs_channel_first.ndim)
]

[1, 3, 1, 1]

In [210]:
# 'CNHW' is rejected by get_channel_index. Catch the error so the cell still
# demonstrates the failure without aborting a Restart & Run All.
try:
    get_channel_index('CNHW')
except ValueError as err:
    print(f'ValueError: {err}')

ValueError: Unable to extract channel information from 'CNHW'. Valid data formats are spatial (e.g.`NCHW`), sequential (e.g. `BTHWD`), `channels_first` and `channels_last`).

In [224]:
def testing_batchnorm(inputs, data_format='NCHW'):
    
    channel_index = get_channel_index(data_format)
    if channel_index == 1: channel_first = True
    if channel_index < 0:
        channel_index += inputs.ndim
    axis = [i for i in range(inputs.ndim) if i != channel_index]

    inputs = jnp.array([inputs.real, inputs.imag])
    channel_index += 1
    axis = [i+1 for i in axis] 
    print('channel_index: ', channel_index)
    print('axis: ', axis)
    print('input_shape: ', inputs.shape)
    
    mean = jnp.mean(inputs, axis, keepdims=True)
    print('Mean: ', mean)
    
    centered_inputs = inputs - mean
    
    variances = (centered_inputs * centered_inputs).mean(axis) + 1e-5
    Var_Rez = variances[0]
    Var_Imz = variances[1]
    print('V_rr: ', Var_Rez)
    print('V_ii: ', Var_Imz)
    
    Cov_ReIm = Cov_ImRe = (centered_inputs[0] * centered_inputs[1]).mean([a-1 for a in axis])
    print('V_ri = V_ir: ', Cov_ReIm.shape)
    
    covariance_matrix = jnp.array( [[Var_Rez, Cov_ReIm], [Cov_ImRe, Var_Imz]] )
    print('Covariance matrix: ', covariance_matrix.reshape(2,2,-1).shape)

    print('\n')
    sqrt_det = jnp.sqrt(Var_Rez * Var_Imz - Cov_ReIm * Cov_ImRe)
    sqrt_tr  = jnp.sqrt(Var_Rez + Var_Imz + 2*sqrt_det)
    print(sqrt_det, sqrt_tr)
    
    denom = sqrt_det * jnp.sqrt(Var_Rez + 2 * sqrt_det + Var_Imz)
    print('denom: ', denom)
    #print('my denom: ', 1 / jnp.sqrt(sqrt_tr / sqrt_det))
    
    inverse_root_covmat = jnp.array([[Var_Imz + sqrt_det, - Cov_ReIm], 
                                     [- Cov_ImRe, Var_Rez + sqrt_det]]).reshape(2,2,-1)
    inverse_root_covmat /= denom
    print(inverse_root_covmat.shape, centered_inputs.shape)
    if channel_first:
        einstein_formula = 'ijk,jlk...->ilk...'
    else:
        einstein_formula = 'ij...,j...->i...'
    normalized_input = jnp.einsum(einstein_formula, inverse_root_covmat, centered_inputs)
    
    print([1 if i in axis else inputs.shape[i] for i in range(inputs.ndim)])
    
    return normalized_input

In [176]:
# Complex single-precision Gaussian test batch of shape (64, 200), fixed seed 0.
inputs_nc = jax.random.normal(
    key=jax.random.PRNGKey(0),
    shape=(64, 200),
    dtype=jnp.csingle,
)

In [203]:
# Channel-first complex test batch: complex Gaussian noise (seed 0), scaled by
# (24 - 0.6j) and then shifted by -3.
inputs_channel_first = (
    jax.random.normal(
        key=jax.random.PRNGKey(0),
        shape=(64, 3, 200, 100),
        dtype=jnp.csingle,
    )
    * (24 - 0.6j)
    - 3
)

In [164]:
# Channels-last complex test batch: real and imaginary parts are drawn
# independently from the default uniform range with different seeds and then
# combined into a single complex array.
inputs_real = jax.random.uniform(key=jax.random.PRNGKey(1), shape=(64, 200, 100, 3))
inputs_imag = jax.random.uniform(key=jax.random.PRNGKey(0), shape=(64, 200, 100, 3))
inputs_channel_last = jax.lax.complex(inputs_real, inputs_imag)

In [225]:
# Whiten the channel-first batch; testing_batchnorm prints its intermediate
# statistics, which make up the output recorded below.
normalized_input = testing_batchnorm(inputs_channel_first, data_format='NCHW')

channel_index:  2
axis:  [1, 3, 4]
input_shape:  (2, 64, 3, 200, 100)
Mean:  [[[[[-2.9938722 ]]

   [[-3.015481  ]]

   [[-3.0342317 ]]]]



 [[[[-0.01056103]]

   [[ 0.01909198]]

   [[-0.00530588]]]]]
V_rr:  [288.5507  288.13956 287.8861 ]
V_ii:  [287.93304 288.0765  288.1567 ]
V_ri = V_ir:  (3,)
Covariance matrix:  (2, 2, 3)


[288.24164 288.10785 288.0212 ] [33.955368 33.947487 33.94238 ]
denom:  [9787.351 9780.537 9776.125]
(2, 2, 3) (2, 64, 3, 200, 100)
[2, 1, 3, 1, 1]


In [220]:
# Print the result, then split it into four flattened rows a, b, c, d.
# NOTE(review): the recorded output is a 2x2x3 array, not the
# (2, 64, 3, 200, 100) result reported by In [205] — presumably this cell last
# ran while the function still returned the covariance matrix; re-run to refresh.
print(normalized_input)
a, b, c, d = normalized_input.reshape(4,-1)

[[[ 2.8855069e+02  2.8813956e+02  2.8788611e+02]
  [ 1.8422553e-01 -3.0165416e-01 -3.0266324e-01]]

 [[ 1.8422553e-01 -3.0165416e-01 -3.0266324e-01]
  [ 2.8793304e+02  2.8807651e+02  2.8815671e+02]]]


In [222]:
# Select index 0 along the third axis.
# NOTE(review): the recorded 2x2 output looks like the first channel's covariance
# matrix, i.e. `normalized_input` held the stacked covariance at that time — confirm.
normalized_input[:,:,0]

DeviceArray([[2.8855069e+02, 1.8422553e-01],
             [1.8422553e-01, 2.8793304e+02]], dtype=float32)

In [223]:
# Show the four rows unpacked from `normalized_input.reshape(4, -1)`.
# NOTE(review): `b` is rebound to a complex array by other cells (In [207], In [59]),
# so what this prints depends on execution order.
print(a,b,c,d)

[288.5507  288.13956 287.8861 ] [ 0.18422553 -0.30165416 -0.30266324] [ 0.18422553 -0.30165416 -0.30266324] [287.93304 288.0765  288.1567 ]


In [205]:
# Confirm the output layout: a leading axis of size 2 (stacked real / imaginary
# parts) in front of the input shape (64, 3, 200, 100).
normalized_input.shape

(2, 64, 3, 200, 100)

In [207]:
# Sanity check of the whitening: per channel (reducing over axes 0, 2, 3) the
# real and imaginary parts should have variance ~1 and covariance ~0.
axis = [0, 2, 3]
b = jax.lax.complex(normalized_input[0], normalized_input[1])

Vrr = b.real.var(axis)
print(Vrr)

Vii = b.imag.var(axis)
print(Vii)

# Covariance as E[xy] - E[x]E[y]; Vir repeats Vri with the factors swapped.
Vri = (b.real * b.imag).mean(axis) - b.real.mean(axis) * b.imag.mean(axis)
print(Vri)

Vir = (b.imag * b.real).mean(axis) - b.imag.mean(axis) * b.real.mean(axis)
print(Vir)

[0.9999999 1.0000006 1.0000006]
[1.0000011 0.9999999 1.0000005]
[-2.8327107e-09  1.8060190e-09  2.9355292e-09]
[-2.8327107e-09  1.8060190e-09  2.9355292e-09]


In [71]:


# Closed-form inverse square root of each symmetric 2x2 covariance matrix V:
#   V^(-1/2) = [[V_ii + s, -V_ri], [-V_ir, V_rr + s]] / (s * t)
# with s = sqrt(det V) (sqrt_det) and t = sqrt(tr V + 2 s) (sqrt_tr).
# Scaling the adjugate by sqrt(t / s) instead does not give the inverse root.
# NOTE(review): sqrt_det, sqrt_tr, Var_* and Cov_* only exist as locals of
# testing_batchnorm in the visible notebook; this cell relies on kernel state.
inverse_root_covmat = jnp.array([[Var_Imz + sqrt_det, - Cov_ReIm],
                                 [- Cov_ImRe, Var_Rez + sqrt_det]]).reshape(2,2,-1) / (sqrt_det * sqrt_tr)

In [74]:
# Variance of `reconstructed_input` reduced over axes 0, 2 and 3.
# NOTE(review): `reconstructed_input` is not assigned anywhere in the visible
# notebook, so the recorded ~3.1e5 values cannot be reproduced from it as written.
jnp.var(reconstructed_input, [0,2,3])

DeviceArray([312149.4 , 311786.7 , 311552.75], dtype=float32)

In [None]:
# Shape of the second moments when the reduced axes are kept as size-1 dims.
# NOTE(review): `centered_inputs` is only a local of testing_batchnorm in the
# visible notebook; this cell needs it in the kernel namespace.
(centered_inputs * centered_inputs).mean(axis, keepdims=True).shape

In [None]:
# Variance of part 0 and part 1 of `inputs`, each with a leading singleton
# axis so the array handed to jnp.var keeps the rank of `inputs`.
# NOTE(review): presumably `inputs` is the stacked (real, imag) array and `axis`
# the stacked reduction axes — neither is assigned at top level here; confirm.
Var_Rez = jnp.var(inputs[0:1], axis)
print(Var_Rez.shape)

Var_Imz = jnp.var(inputs[1:2], axis)
print(Var_Imz.shape)

In [None]:
# Cross-covariance of the two parts via E[xy] - E[x]E[y]; the product of the
# means is reshaped to match the variance arrays.
# NOTE(review): `mean` is not assigned at top level in the visible notebook.
Cov_ReIm = (inputs[0] * inputs[1])[None].mean(axis) - jnp.reshape(mean[0] * mean[1], Var_Rez.shape)
print(Cov_ReIm.shape)
Cov_ReIm

In [None]:
# Assemble the symmetric 2x2 covariance matrix [[V_rr, V_ri], [V_ri, V_ii]],
# stacked along a trailing axis, then show the matrix at index 1 of that axis.
covariance_matrix = jnp.stack([
    jnp.stack([Var_Rez[0], Cov_ReIm[0]]),
    jnp.stack([Cov_ReIm[0], Var_Imz[0]]),
])
print(covariance_matrix.shape)
covariance_matrix[:, :, 1]

In [None]:
# Tikhonov-style regularisation: add eps times the 2x2 identity to every
# covariance matrix in the stack.
eps = jax.lax.convert_element_type(1e-3, covariance_matrix.dtype)
regularized_covmat = covariance_matrix + eps * jnp.broadcast_to(
    jnp.eye(2)[:, :, None], covariance_matrix.shape
)

In [None]:
# Display the regularised covariance stack.
regularized_covmat

In [None]:
# Cholesky factor of the regularised covariance, computed with NumPy on the
# fully transposed array.
# NOTE(review): assumes regularized_covmat is (2, 2, C), so `.T` is a (C, 2, 2)
# batch of matrices — TODO confirm.
np.linalg.cholesky(regularized_covmat.T)

In [None]:
# NOTE(review): jnp.sqrt is element-wise, so this is not a matrix square root;
# negative off-diagonal covariances produce NaN.
jnp.sqrt(regularized_covmat.T)

In [None]:
# Inverse of the matrix square root of each regularised covariance matrix.
# jnp.sqrt is element-wise and does not give a matrix root, so the root is
# built from the eigendecomposition V = Q diag(w) Q^T: sqrt(V) = Q diag(sqrt(w)) Q^T.
# NOTE(review): assumes regularized_covmat is a (2, 2, C) stack of symmetric
# positive-definite matrices, so `.T` is a (C, 2, 2) batch — TODO confirm.
eigvals, eigvecs = jnp.linalg.eigh(regularized_covmat.T)
root_covmat = (eigvecs * jnp.sqrt(eigvals)[..., None, :]) @ jnp.swapaxes(eigvecs, -1, -2)
inverse_convmat = jnp.linalg.inv(root_covmat)
print(inverse_convmat.shape)

In [None]:
# Matrix square root of the inverse of each regularised covariance matrix.
# jnp.sqrt is element-wise and does not give a matrix root, so the root is
# built from the eigendecomposition P = Q diag(w) Q^T: sqrt(P) = Q diag(sqrt(w)) Q^T.
# NOTE(review): assumes regularized_covmat is a (2, 2, C) stack of symmetric
# positive-definite matrices, so `.T` is a (C, 2, 2) batch — TODO confirm.
eigvals, eigvecs = jnp.linalg.eigh(jnp.linalg.inv(regularized_covmat.T))
inverse_convmat = (eigvecs * jnp.sqrt(eigvals)[..., None, :]) @ jnp.swapaxes(eigvecs, -1, -2)
print(inverse_convmat.shape)

In [None]:
# Check the shape of `inputs` before applying the whitening matrix.
# NOTE(review): `inputs` is never assigned at top level in the visible notebook.
inputs.shape

In [None]:
# Display the whitening matrices computed above.
inverse_convmat

In [None]:
# Apply the per-channel 2x2 matrices to the stacked input.
# Index letters: i = channel, j / k = the two matrix axes, l = the axis after the
# stacked one — assuming inverse_convmat is (C, 2, 2) and inputs is (2, N, C, ...);
# TODO confirm.
normalized_input = jnp.einsum('ijk,kli...->jli...', inverse_convmat, inputs)
print(normalized_input.shape)

In [None]:
# Display the whitened array.
normalized_input

In [59]:
# Recombine the stacked real (index 0) and imaginary (index 1) parts into a
# single complex array and check its shape.
b = jax.lax.complex(normalized_input[0], normalized_input[1])
print(b.shape)

(64, 3, 200, 100)


In [60]:
# Per-channel complex mean over axes 0, 2 and 3; the recorded values are ~1e-8,
# i.e. zero up to float32 rounding.
jnp.mean(b, [0,2,3])

DeviceArray([-1.6033649e-08+6.5779687e-08j,  9.0599057e-09+5.9366226e-08j,
             -1.2063980e-08-1.0442734e-08j], dtype=complex64)

In [61]:
# Per-channel variance E[(z - mu) * conj(z - mu)] over axes 0, 2, 3. The recorded
# result is ~2, matching V_rr + V_ii for whitened data.
b_centered = b - jnp.mean(b, [0, 2, 3], keepdims=True)
jnp.mean(b_centered * jnp.conjugate(b_centered), [0, 2, 3])

DeviceArray([2.0000024-1.4729266e-11j, 2.0000036+2.4532746e-11j,
             2.0000024+5.2360902e-12j], dtype=complex64)

In [62]:
# Per-channel pseudo-variance E[(z - mu)^2] over axes 0, 2, 3 (no conjugate).
# The recorded result is ~1e-7, i.e. effectively zero.
b_centered = b - jnp.mean(b, [0, 2, 3], keepdims=True)
jnp.mean(b_centered * b_centered, [0, 2, 3])

DeviceArray([2.0442009e-07+5.47170620e-09j, 9.3636510e-07-1.39236445e-08j,
             1.3365745e-07-6.77108769e-09j], dtype=complex64)

[0.9999997 1.0000012 0.9999997]
[0.99999934 1.0000004  0.9999991 ]
[ 8.8810997e-10 -9.2536256e-10  1.0728834e-09]
[ 8.8810997e-10 -9.2536256e-10  1.0728834e-09]
