Commit

Black formatting preparing for github action run
JesperDramsch committed Dec 6, 2022
1 parent ca4495c commit 24a4cfe
Showing 12 changed files with 619 additions and 711 deletions.
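Throughout the diffs below, statements that Black's default 88-column limit would keep wrapped are collapsed onto single lines, so the project presumably runs Black with an extended line length. A minimal sketch of that reformatting step, assuming a 120-column limit (the exact configured value is not shown in this commit):

import black

# One of the wrapped calls from conv.py, as Black would receive it.
src = (
    "result = K.conv2d_transpose(\n"
    "    inputs, filter, output_shape, strides, padding=padding, data_format=data_format\n"
    ")\n"
)
# With a 120-column limit the call fits on one line, matching the diff below.
print(black.format_str(src, mode=black.Mode(line_length=120)))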
4 changes: 3 additions & 1 deletion .gitignore
@@ -42,4 +42,6 @@ venv
 
 .vscode
 
-dist/*
+dist/*
+
+__pycache__/*
2 changes: 2 additions & 0 deletions complexnn/__init__.py
@@ -6,6 +6,7 @@
 #
 # What this module includes by default:
 from . import bn, conv, dense, init, norm, pool
+
 # from . import fft
 
 from .bn import ComplexBatchNormalization as ComplexBN
@@ -17,6 +18,7 @@
     WeightNorm_Conv,
 )
 from .dense import ComplexDense
+
 # from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
 from .init import (
     ComplexIndependentFilters,
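For context, these re-exports are what downstream code imports. A minimal usage sketch: only the ComplexBatchNormalization/ComplexBN and ComplexDense names are confirmed by the hunk above; ComplexConv2D and the constructor arguments are assumed to mirror their standard Keras counterparts.

import tensorflow as tf
import complexnn

# Real and imaginary parts are stacked along the channel axis (2 channels here).
model = tf.keras.models.Sequential(
    [
        complexnn.conv.ComplexConv2D(32, (3, 3), activation="relu", input_shape=(28, 28, 2)),
        complexnn.bn.ComplexBatchNormalization(),
        tf.keras.layers.Flatten(),
        complexnn.dense.ComplexDense(10),
    ]
)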
292 changes: 150 additions & 142 deletions complexnn/bn.py

Large diffs are not rendered by default.

79 changes: 18 additions & 61 deletions complexnn/conv.py
@@ -1,10 +1,5 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 """conv.py"""
-# pylint: disable=protected-access, consider-using-enumerate, too-many-lines
-
-#
-# Authors: Chiheb Trabelsi
 
 import tensorflow as tf
 from tensorflow.keras import backend as K
@@ -24,7 +19,7 @@
 
 def conv1d_transpose(
     inputs,
-    filter,  # pylint: disable=redefined-builtin
+    filter,
     kernel_size=None,
     filters=None,
     strides=(1,),
@@ -76,7 +71,7 @@ def conv1d_transpose(
 
 def conv2d_transpose(
     inputs,
-    filter,  # pylint: disable=redefined-builtin
+    filter,
     kernel_size=None,
     filters=None,
     strides=(1, 1),
@@ -121,9 +116,7 @@ def conv2d_transpose(
     output_shape = (batch_size, out_height, out_width, filters)
 
     filter = K.permute_dimensions(filter, (0, 1, 3, 2))
-    return K.conv2d_transpose(
-        inputs, filter, output_shape, strides, padding=padding, data_format=data_format
-    )
+    return K.conv2d_transpose(inputs, filter, output_shape, strides, padding=padding, data_format=data_format)
 
 
 def ifft(f):
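The permute_dimensions call above swaps the last two kernel axes because the backend's transposed convolution expects a (height, width, out_channels, in_channels) kernel, whereas forward convolutions use (height, width, in_channels, out_channels). A standalone sketch of that bookkeeping, with shapes chosen arbitrarily:

import tensorflow as tf

x = tf.random.normal((1, 8, 8, 4))     # NHWC input with 4 channels
w = tf.random.normal((3, 3, 4, 16))    # forward-conv kernel: (h, w, in, out)
w_t = tf.transpose(w, (0, 1, 3, 2))    # conv2d_transpose wants (h, w, out, in)
y = tf.nn.conv2d_transpose(x, w_t, output_shape=(1, 16, 16, 16), strides=2, padding="SAME")
print(y.shape)  # (1, 16, 16, 16)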
@@ -136,9 +129,7 @@ def ifft2(f):
     raise NotImplementedError(str(f))
 
 
-def conv_transpose_output_length(
-    input_length, filter_size, padding, stride, dilation=1, output_padding=None
-):
+def conv_transpose_output_length(input_length, filter_size, padding, stride, dilation=1, output_padding=None):
     """Rearrange arguments for compatibility with conv_output_length."""
     if dilation != 1:
         msg = f"Dilation must be 1 for transposed convolution. "
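For reference, the output length of a transposed convolution follows the usual Keras rule. A pure-Python sketch of that relationship for the two common paddings, assuming output_padding is None:

def transpose_output_length(input_length, filter_size, padding, stride):
    # Inverse of the forward conv_output_length relationship.
    if padding == "valid":
        return input_length * stride + max(filter_size - stride, 0)
    if padding == "same":
        return input_length * stride

print(transpose_output_length(8, 3, "valid", 2))  # 17
print(transpose_output_length(8, 3, "same", 2))   # 16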
@@ -287,14 +278,8 @@ def __init__(
         self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, "kernel_size")
         self.strides = conv_utils.normalize_tuple(strides, rank, "strides")
         self.padding = conv_utils.normalize_padding(padding)
-        self.data_format = (
-            "channels_last"
-            if rank == 1
-            else conv_utils.normalize_data_format(data_format)
-        )
-        self.dilation_rate = conv_utils.normalize_tuple(
-            dilation_rate, rank, "dilation_rate"
-        )
+        self.data_format = "channels_last" if rank == 1 else conv_utils.normalize_data_format(data_format)
+        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, "dilation_rate")
         self.activation = activations.get(activation)
         self.use_bias = use_bias
         self.normalize_weight = normalize_weight
@@ -336,10 +321,7 @@ def build(self, input_shape):
         else:
             channel_axis = -1
         if input_shape[channel_axis] is None:
-            raise ValueError(
-                "The channel dimension of the inputs "
-                "should be defined. Found `None`."
-            )
+            raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
         # Divide by 2 for real and complex input.
         input_dim = input_shape[channel_axis] // 2
         if False and self.transposed:
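Note that the reformatted raise keeps two adjacent string literals on one line; Python concatenates adjacent literals at compile time, so despite appearances this is one message, not a missing comma:

msg = "The channel dimension of the inputs " "should be defined. Found `None`."
assert msg == "The channel dimension of the inputs should be defined. Found `None`."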
@@ -421,9 +403,7 @@ def build(self, input_shape):
             self.bias = None
 
         # Set input spec.
-        self.input_spec = InputSpec(
-            ndim=self.rank + 2, axes={channel_axis: input_dim * 2}
-        )
+        self.input_spec = InputSpec(ndim=self.rank + 2, axes={channel_axis: input_dim * 2})
         self.built = True
 
     def call(self, inputs, **kwargs):
@@ -457,9 +437,7 @@ def call(self, inputs, **kwargs):
             "strides": self.strides[0] if self.rank == 1 else self.strides,
             "padding": self.padding,
             "data_format": self.data_format,
-            "dilation_rate": self.dilation_rate[0]
-            if self.rank == 1
-            else self.dilation_rate,
+            "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
         }
         if self.transposed:
             convArgs.pop("dilation_rate", None)
@@ -518,12 +496,8 @@ def call(self, inputs, **kwargs):
             broadcast_mu_imag = K.reshape(mu_imag, broadcast_mu_shape)
             reshaped_f_real_centred = reshaped_f_real - broadcast_mu_real
             reshaped_f_imag_centred = reshaped_f_imag - broadcast_mu_imag
-            Vrr = (
-                K.mean(reshaped_f_real_centred**2, axis=reduction_axes) + self.epsilon
-            )
-            Vii = (
-                K.mean(reshaped_f_imag_centred**2, axis=reduction_axes) + self.epsilon
-            )
+            Vrr = K.mean(reshaped_f_real_centred**2, axis=reduction_axes) + self.epsilon
+            Vii = K.mean(reshaped_f_imag_centred**2, axis=reduction_axes) + self.epsilon
             Vri = (
                 K.mean(
                     reshaped_f_real_centred * reshaped_f_imag_centred,
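Vrr, Vii, and Vri are the entries of the 2x2 covariance matrix between the real and imaginary parts, which complex batch normalization later whitens. A numpy sketch of the same statistics on toy data (shapes illustrative):

import numpy as np

f_real = np.random.randn(1024)   # flattened real feature map
f_imag = np.random.randn(1024)   # flattened imaginary feature map
eps = 1e-4

cr = f_real - f_real.mean()
ci = f_imag - f_imag.mean()
Vrr = (cr ** 2).mean() + eps     # variance of the real part
Vii = (ci ** 2).mean() + eps     # variance of the imaginary part
Vri = (cr * ci).mean()           # real/imaginary covariance
# Whitening applies the inverse square root of [[Vrr, Vri], [Vri, Vii]].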
@@ -558,9 +532,7 @@ def call(self, inputs, **kwargs):
 
         cat_kernels_4_real = K.concatenate([f_real, -f_imag], axis=-2)
         cat_kernels_4_imag = K.concatenate([f_imag, f_real], axis=-2)
-        cat_kernels_4_complex = K.concatenate(
-            [cat_kernels_4_real, cat_kernels_4_imag], axis=-1
-        )
+        cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=-1)
         if False and self.transposed:
             cat_kernels_4_complex._keras_shape = self.kernel_size + (
                 2 * self.filters,
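The three concatenations build one real-valued kernel that implements complex multiplication: for weights W = Wr + iWi and input z = xr + ixi, Wz = (Wr xr - Wi xi) + i(Wi xr + Wr xi). A dense (matmul) analogue of the same block structure, with arbitrary sizes:

import numpy as np

in_dim, out_dim = 4, 3
Wr = np.random.randn(in_dim, out_dim)   # real part of the complex weights
Wi = np.random.randn(in_dim, out_dim)   # imaginary part
# Rows: [real input; imag input]; columns: [real output, imag output].
W_block = np.block([[Wr, Wi],
                    [-Wi, Wr]])

xr, xi = np.random.randn(in_dim), np.random.randn(in_dim)
z = np.concatenate([xr, xi])
out = z @ W_block
expected = np.concatenate([xr @ Wr - xi @ Wi, xr @ Wi + xi @ Wr])
assert np.allclose(out, expected)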
@@ -632,9 +604,7 @@ def get_config(self):
             "gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer),
             "kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
             "bias_regularizer": regularizers.serialize(self.bias_regularizer),
-            "gamma_diag_regularizer": regularizers.serialize(
-                self.gamma_diag_regularizer
-            ),
+            "gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer),
             "gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer),
             "activity_regularizer": regularizers.serialize(self.activity_regularizer),
             "kernel_constraint": constraints.serialize(self.kernel_constraint),
@@ -1113,10 +1083,7 @@ def build(self, input_shape):
         else:
             channel_axis = -1
         if input_shape[channel_axis] is None:
-            raise ValueError(
-                "The channel dimension of the inputs "
-                "should be defined. Found `None`."
-            )
+            raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
         input_dim = input_shape[channel_axis]
         gamma_shape = (input_dim * self.filters,)
         self.gamma = self.add_weight(
@@ -1134,32 +1101,22 @@ def call(self, inputs):
         else:
             channel_axis = -1
         if input_shape[channel_axis] is None:
-            raise ValueError(
-                "The channel dimension of the inputs "
-                "should be defined. Found `None`."
-            )
+            raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
         input_dim = input_shape[channel_axis]
         ker_shape = self.kernel_size + (input_dim, self.filters)
         nb_kernels = ker_shape[-2] * ker_shape[-1]
         kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels)
         reshaped_kernel = K.reshape(self.kernel, kernel_shape_4_norm)
-        normalized_weight = K.l2_normalize(
-            reshaped_kernel, axis=0, epsilon=self.epsilon
-        )
-        normalized_weight = (
-            K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1]))
-            * normalized_weight
-        )
+        normalized_weight = K.l2_normalize(reshaped_kernel, axis=0, epsilon=self.epsilon)
+        normalized_weight = K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1])) * normalized_weight
         shaped_kernel = K.reshape(normalized_weight, ker_shape)
         shaped_kernel._keras_shape = ker_shape
 
         convArgs = {
             "strides": self.strides[0] if self.rank == 1 else self.strides,
             "padding": self.padding,
             "data_format": self.data_format,
-            "dilation_rate": self.dilation_rate[0]
-            if self.rank == 1
-            else self.dilation_rate,
+            "dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
         }
         convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank]
         output = convFunc(inputs, shaped_kernel, **convArgs)
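This call method reparameterises the kernel in the style of weight normalization (Salimans and Kingma, 2016): each fan-in column is L2-normalised, then rescaled by the learned gamma, w = g * v / ||v||. A numpy sketch of the reshaped computation, with illustrative shapes:

import numpy as np

kernel = np.random.randn(9, 64)   # (prod(kernel_size), input_dim * filters)
gamma = np.random.randn(1, 64)    # learned per-column scale
eps = 1e-7

# Column-wise w = gamma * v / ||v|| over the fan-in axis.
norms = np.sqrt((kernel ** 2).sum(axis=0, keepdims=True) + eps)
normalized_weight = gamma * (kernel / norms)
assert normalized_weight.shape == kernel.shape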
