In [1]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, GRU, Conv2D, ZeroPadding2D, Conv2DTranspose, BatchNormalization, SeparableConv2D
from tensorflow.keras.activations import elu, relu, sigmoid, tanh
from tensorflow.keras.layers import Layer
from keras.regularizers import l2
from keras.initializers import HeUniform

import numpy as np
import random
import import_ipynb

In [None]:
# reg = l2(1e-5)

In [1]:
def grouped_conv2d_transpose(inputs, groups, convkwargs):
    """Apply an independent transposed convolution to each channel group.

    The last axis of `inputs` is split into `groups` equal parts, a freshly
    constructed `Conv2DTranspose(**convkwargs)` is applied to every part, and
    the per-group results are concatenated back along the last axis.

    Args:
        inputs: A `Tensor` of shape `[batch_size, h, w, c]`.
        groups: The number of channel groups. The input channel count needs
            to be evenly divisible by `groups`.
        convkwargs: Keyword arguments forwarded unchanged to every
            `Conv2DTranspose` constructor (e.g. `filters`, `kernel_size`,
            `strides`). `filters` is therefore the per-group filter count.
    Returns:
        The per-group outputs concatenated on the last axis.
    """
    group_outputs = []
    for channel_group in tf.split(inputs, groups, axis=-1):
        # One separate layer (own weights) per group.
        transposed_conv = Conv2DTranspose(**convkwargs)
        group_outputs.append(transposed_conv(channel_group))
    return tf.concat(group_outputs, -1)

def grouped_conv2d(inputs, groups, convkwargs):
    """Apply an independent `Conv2D` to each channel group of `inputs`.

    Args:
        inputs: Tensor whose last axis is split into `groups` equal parts.
        groups: Number of channel groups.
        convkwargs: Keyword arguments forwarded unchanged to every `Conv2D`
            constructor; `filters` is therefore the per-group filter count.
    Returns:
        The per-group convolution outputs concatenated on the last axis.
    """
    group_outputs = []
    for channel_group in tf.split(inputs, groups, axis=-1):
        # One separate layer (own weights) per group.
        conv = Conv2D(**convkwargs)
        group_outputs.append(conv(channel_group))
    return tf.concat(group_outputs, -1)

In [5]:
def RELU(x):
    """Element-wise rectifier: entries below zero become 0.0, others pass through."""
    negative_mask = x < 0.0
    return tf.where(negative_mask, 0.0, x)

In [6]:
def convkxf_old(inputs,
            out_ch: int,
            k: int = 1,
            f: int = 3,
            fstride: int = 2,
            lookahead: int = 0,
            batch_norm: bool = False,
            act = 'relu',
            mode = "normal",
            depthwise: bool = True,
            complex_in: bool = False,
            reshape: bool = False,
            name: str = 'conv',
            training=True
           ):
    """Legacy (k x f) convolution stage built from per-group Keras layers.

    Pipeline: grouped (transposed) convolution -> 1x1 convolution when more
    than one group is used -> optional batch normalisation -> optional
    activation -> optional squeeze of the last axis.

    Args:
        inputs: Input tensor; its last axis is read as the input channel
            count. Presumably `[batch, time, freq, channels]` -- TODO confirm.
        out_ch: Number of output channels.
        k: Kernel size along the first convolved axis.
        f: Kernel size along the second convolved axis.
        fstride: Stride along the second convolved axis; when `f == 1` the
            stride is 1 on both axes instead.
        lookahead: Accepted but not used by this legacy version.
        batch_norm: Append a `BatchNormalization` layer. The convolution bias
            is enabled only when this is exactly `False`.
        act: `'relu'` or `'sigmoid'`; any other value applies no activation.
        mode: `"normal"` or `"transposed"`.
        depthwise: Use `min(in_ch, out_ch)` groups (falling back to a single
            group when that does not divide both channel counts) instead of 1.
        complex_in: Halve the group count when it is even.
        reshape: Squeeze the last axis of the result.
        name: Prefix for the 1x1 convolution (`name + '_1x1'`) and the
            batch-norm layer (`name + 'BN'`).
        training: When falsy in `"normal"` mode the convolution uses `'valid'`
            padding after an explicit zero pad of `(f - 1) // 2` on both sides
            of axis 2; otherwise `'same'` padding is used. Also passed as the
            `trainable` flag of the 1x1 convolution.
    Returns:
        The output tensor of the stage.
    Raises:
        NotImplementedError: If `mode` is neither `"normal"` nor
            `"transposed"`.
    """
    in_ch = inputs.get_shape()[-1]
    bias = batch_norm is False
    stride = 1 if f == 1 else (1, fstride)
    fpad = (f - 1) // 2

    groups = min(in_ch, out_ch) if depthwise else 1
    if in_ch % groups != 0 or out_ch % groups != 0:
        groups = 1
    if complex_in and groups % 2 == 0:
        groups //= 2

    convkwargs = {
        "filters": out_ch//groups,
        "kernel_size": (k, f),
        "strides": stride,
        "use_bias": bias,
        "padding": 'same',
    }

    if mode == "normal":
        if not training:
            # Inference: pad axis 2 explicitly and let the convolution run
            # unpadded instead of relying on 'same'.
            convkwargs = dict(convkwargs, padding='valid')
            if fpad > 0:
                paddings = tf.constant([[0,0],[0,0],[fpad,fpad],[0,0]])
                inputs = tf.pad(inputs, paddings, "CONSTANT", constant_values=0)

        # With groups == 1 this is a split into one part, one Conv2D and a
        # one-element concat, i.e. the same graph the original built.
        conv_out = grouped_conv2d(inputs, groups, convkwargs)

    elif mode == "transposed":
        if groups > 1:
            conv_out = grouped_conv2d_transpose(inputs, groups, convkwargs)
        else:
            conv_out = Conv2DTranspose(**convkwargs)(inputs)
    else:
        raise NotImplementedError(f"unsupported mode: {mode!r}")

    if groups > 1:
        # Mix the independent groups with a bias-free pointwise convolution.
        conv_out = Conv2D(out_ch, kernel_size=1, trainable = training,
                          use_bias=False, name=name + '_1x1')(conv_out)
    if batch_norm:
        # NOTE(review): Keras weights the *previous* moving statistic by
        # `momentum`; 0.1 may have been carried over from a framework where
        # momentum weights the new batch (0.9 here) -- confirm.
        conv_out = BatchNormalization(momentum=0.1, epsilon=1e-5, name=name+'BN')(conv_out)

    if act == 'relu':
        conv_out = RELU(conv_out)
    elif act == 'sigmoid':
        conv_out = sigmoid(conv_out)

    if reshape:
        conv_out = tf.squeeze(conv_out,-1)

    return conv_out

In [None]:
def convkxf(inputs,
            out_ch: int,
            k: int = 1,
            f: int = 3,
            fstride: int = 2,
            lookahead: int = 0,
            batch_norm: bool = False,
            act = 'relu',
            mode = "normal",
            depthwise: bool = True,
            complex_in: bool = False,
            reshape: bool = False,
            name: str = 'conv',
            training=True,
            bias: bool = True,
           ):
    """Build a (k x f) convolution stage on `inputs` and return its output.

    Pipeline: explicit zero padding + grouped `Conv2D` (mode `"normal"`) or
    per-group `Conv2DTranspose` (mode `"transposed"`) -> bias-free 1x1
    convolution when more than one group is used -> optional batch
    normalisation -> optional activation -> optional squeeze of the last axis.

    Args:
        inputs: Input tensor; its last axis is read as the input channel
            count. Presumably `[batch, time, freq, channels]` -- TODO confirm.
        out_ch: Number of output channels.
        k: Kernel size along axis 1.
        f: Kernel size along axis 2 (frequency).
        fstride: Stride along axis 2; when `f == 1` the stride is 1 instead.
        lookahead: In `"normal"` mode while `training` is truthy, axis 1 is
            zero padded with `k - 1 - lookahead` entries in front and
            `lookahead` entries behind.
        batch_norm: Append a `BatchNormalization` layer.
        act: `'relu'` or `'sigmoid'`; any other value applies no activation.
        mode: `"normal"` or `"transposed"`.
        depthwise: Use `min(in_ch, out_ch)` groups (falling back to a single
            group when that does not divide both channel counts) instead of 1.
        complex_in: Halve the group count when it is even.
        reshape: Squeeze the last axis of the result.
        name: Name of the main `Conv2D` (normal mode only) and prefix of the
            1x1 convolution (`name + '_1x1'`).
        training: Enables the axis-1 padding described under `lookahead` and
            is passed as the `trainable` flag of the 1x1 convolution.
        bias: Whether the main convolution uses a bias term.
    Returns:
        The output tensor of the stage.
    Raises:
        NotImplementedError: If `mode` is neither `"normal"` nor
            `"transposed"`.
    """
    n_in = inputs.get_shape()[-1]
    conv_stride = (1, fstride) if f != 1 else 1
    freq_pad = (f - 1) // 2  # symmetric zero padding on the frequency axis

    groups = min(n_in, out_ch) if depthwise else 1
    if n_in % groups != 0 or out_ch % groups != 0:
        groups = 1
    if complex_in and groups % 2 == 0:
        groups //= 2

    if mode == "normal":
        if training:
            past_pad = k - 1 - lookahead
            time_paddings = tf.constant([[0, 0], [past_pad, lookahead], [0, 0], [0, 0]])
            # Only touch the graph when there is something to pad on axis 1.
            if past_pad > 0 or lookahead > 0:
                inputs = tf.pad(inputs, time_paddings, "CONSTANT", constant_values=0)

        if freq_pad > 0:
            freq_paddings = tf.constant([[0, 0], [0, 0], [freq_pad, freq_pad], [0, 0]])
            inputs = tf.pad(inputs, freq_paddings, "CONSTANT", constant_values=0)

        # All padding was applied by hand above, hence 'valid'.
        y = Conv2D(
            filters=out_ch,
            groups=groups,
            kernel_size=(k, f),
            strides=conv_stride,
            use_bias=bias,
            padding='valid',
            kernel_initializer='he_normal',
            name=name,
        )(inputs)

    elif mode == "transposed":
        transposed_kwargs = dict(
            filters=out_ch//groups,
            kernel_size=(k, f),
            strides=conv_stride,
            use_bias=bias,
            padding='same',
            kernel_initializer='he_normal',
        )
        if groups > 1:
            y = grouped_conv2d_transpose(inputs, groups, transposed_kwargs)
        else:
            y = Conv2DTranspose(**transposed_kwargs)(inputs)

    else:
        raise NotImplementedError()

    if groups > 1:
        # Mix the groups with a bias-free pointwise convolution.
        pointwise = Conv2D(out_ch, kernel_size=1, trainable=training,
                           use_bias=False, name=name + '_1x1')
        y = pointwise(y)

    if batch_norm:
        # NOTE(review): Keras weights the *previous* moving statistic by
        # `momentum`; 0.1 may have been carried over from a framework where
        # momentum weights the new batch (0.9 here) -- confirm.
        y = BatchNormalization(momentum=0.1, epsilon=1e-5)(y)

    if act == 'relu':
        y = RELU(y)
    elif act == 'sigmoid':
        y = sigmoid(y)

    if reshape:
        y = tf.squeeze(y, -1)

    return y

In [7]:
def GroupGRU(inputs, hidden, groups, return_sequences=True, name='GGRU', reshape=True, count=0, training=True):
    """Run one GRU per feature group and merge the results into one tensor.

    The last axis of `inputs` is split into `groups` equal parts; each part is
    fed to its own `GRU` with `hidden // groups` units. The per-group outputs
    are stacked on a new last axis and the last two axes are flattened into
    one feature axis.

    Args:
        inputs: Rank-3 tensor (the stacked result is handled as rank 4);
            presumably `[batch, time, features]` -- TODO confirm.
        hidden: Total number of GRU units across all groups.
        groups: Number of feature groups / independent GRUs.
        return_sequences: Every GRU always returns its full sequence; when
            this is falsy only the last step along axis 1 of the merged
            output is returned.
        name: Name prefix. GRUs are named `name + '_' + str(index)`, the
            merge ops `name + '_stack'`, `'_transpose'`, `'_reshape'`.
        reshape: When truthy the stacked `[.., units, groups]` axes are
            swapped before flattening, so each group's units end up
            contiguous; otherwise the groups are interleaved unit by unit.
        count: Index used in the name of the first GRU; it increases by one
            per group.
        training: When falsy the GRUs are created with `stateful=True`.
    Returns:
        The merged tensor, or its last step along axis 1 when
        `return_sequences` is falsy.
    """
    gru_config = dict(
        units=hidden//groups,
        return_sequences=True,
        kernel_initializer='he_normal',
    )
    if not training:
        gru_config["stateful"] = True

    per_group = [
        GRU(**gru_config, name=name + '_' + str(count + offset))(chunk)
        for offset, chunk in enumerate(tf.split(inputs, groups, axis=-1))
    ]

    merged = tf.stack(per_group, axis=-1, name=name + '_stack')
    if reshape:
        merged = tf.transpose(merged, [0, 1, 3, 2], name=name + '_transpose')

    # Dynamic sizes, so axes whose static size is unknown still work.
    dims = [tf.shape(merged)[axis] for axis in range(4)]
    flat = tf.reshape(merged, shape=[dims[0], dims[1], dims[-1]*dims[-2]], name=name + '_reshape')

    return flat if return_sequences else flat[:,-1]

In [8]:
def GroupGRULayer(inputs, hidden, groups, num_layer, name='GGRU', add_output=True, training=True):
    """Chain `num_layer` `GroupGRU` blocks, optionally summing their outputs.

    Layer `i` is built with `name + str(i)` and receives the output of layer
    `i - 1` (the first one receives `inputs`). Every layer except the last is
    built with `reshape=True`; the last one with `reshape=False`.

    Args:
        inputs: Input tensor passed to the first `GroupGRU`.
        hidden: Total number of GRU units per layer.
        groups: Number of groups per layer.
        num_layer: Number of stacked layers; must be at least 1.
        name: Name prefix for the layers.
        add_output: When truthy, return the sum of the outputs of all layers;
            otherwise return the output of the last layer only.
        training: Forwarded to every `GroupGRU`.
    Returns:
        The summed or last layer output, see `add_output`.
    Raises:
        ValueError: If `num_layer` is smaller than 1 (previously this ended in
            an `UnboundLocalError` on the return statement).
    """
    if num_layer < 1:
        raise ValueError(f"num_layer must be >= 1, got {num_layer}")

    output = None
    for i in range(num_layer):
        is_last = i == num_layer - 1
        inputs = GroupGRU(inputs, hidden, groups, return_sequences=True, name=name + str(i),
                          reshape=not is_last, training=training)
        if add_output and i > 0:
            # Accumulate out of place so the first layer's output is never
            # modified through an alias.
            output = output + inputs
        else:
            output = inputs

    return output

In [9]:
def GroupFC(inputs, hidden, groups=8, activation=None, name='GFC', count=0):
    """Grouped fully connected layer: one `Dense` per feature group.

    The last axis of `inputs` is split into `groups` equal parts; each part
    goes through its own `Dense` with `hidden // groups` units. The results
    are stacked on a new last axis, the last two axes are swapped and then
    flattened, so each group's units are contiguous in the output.

    Args:
        inputs: Rank-3 tensor (the stacked result is handled as rank 4);
            presumably `[batch, time, features]` -- TODO confirm.
        hidden: Total number of output units across all groups.
        groups: Number of feature groups / independent `Dense` layers.
        activation: `'relu'` or `'sigmoid'`; any other value applies no
            activation.
        name: Name prefix. Dense layers are named
            `name + '_' + str(index)`, the merge ops `name + '_stack'`,
            `'_transpose'`, `'_reshape'`.
        count: Index used in the name of the first Dense layer; it increases
            by one per group.
    Returns:
        The merged (and optionally activated) tensor.
    """
    dense_config = dict(
        units=hidden//groups,
        kernel_initializer='he_normal',
    )
    projected = [
        Dense(**dense_config, name=name + '_' + str(count + offset))(chunk)
        for offset, chunk in enumerate(tf.split(inputs, groups, axis=-1))
    ]

    stacked = tf.stack(projected, axis=-1, name=name + '_stack')
    group_major = tf.transpose(stacked, [0, 1, 3, 2], name=name + '_transpose')

    # Dynamic sizes, so axes whose static size is unknown still work.
    dims = [tf.shape(group_major)[axis] for axis in range(4)]
    merged = tf.reshape(group_major, shape=[dims[0], dims[1], dims[-1]*dims[-2]], name=name + '_reshape')

    if activation == 'relu':
        return RELU(merged)
    if activation == 'sigmoid':
        return sigmoid(merged)
    return merged