In [2]:
import numpy as np
import glob
import os
import soundfile as sf
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import pad_packed_sequence
import torch.nn.functional as F

In [None]:
def z3cprelu(x, name='z3cprelu'):
    x_R, x_I = tf.split(x, 2, -1)
    in_shape = x_R.get_shape().as_list()
    with tf.variable_scope(name):
        # make one alpha per feature
        alpha1_R = tf.get_variable('alpha1_R', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
        alpha1_I = tf.get_variable('alpha1_I', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
        alpha2_R = tf.get_variable('alpha2_R', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
        alpha2_I = tf.get_variable('alpha2_I', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
        alpha3_R = tf.get_variable('alpha3_R', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
        alpha3_I = tf.get_variable('alpha3_I', in_shape[-1],
                                initializer=tf.constant_initializer(0.),
                                dtype=tf.float32)
                                
        out1_R = (x_R*alpha1_R) - (x_I*alpha1_I)
        out1_I = (x_R*alpha1_I) + (x_I*alpha1_R)
        
        out2_R = (x_R*alpha2_R) - (x_I*alpha2_I)
        out2_I = (x_R*alpha2_I) + (x_I*alpha2_R)
        
        out3_R = (x_R*alpha3_R) - (x_I*alpha3_I)
        out3_I = (x_R*alpha3_I) + (x_I*alpha3_R)
        
        cond1 = tf.logical_and(tf.greater(x_R, 0.), tf.greater(x_I, 0.))
        cond2 = tf.logical_and(tf.less_equal(x_R, 0.), tf.greater(x_I, 0.))
        cond3 = tf.logical_and(tf.less_equal(x_R, 0.), tf.less_equal(x_I, 0.))
        cond4 = tf.logical_and(tf.greater(x_R, 0.), tf.less_equal(x_I, 0.))
        zeros = tf.zeros_like(x_R)
        
        ans_R = tf.where(cond1, x_R, zeros)
        ans_I = tf.where(cond1, x_I, zeros)
        
        ans_R += tf.where(cond2, out1_R, zeros)
        ans_I += tf.where(cond2, out1_I, zeros)
        
        ans_R += tf.where(cond3, out2_R, zeros)
        ans_I += tf.where(cond3, out2_I, zeros)
        
        ans_R += tf.where(cond4, out3_R, zeros)
        ans_I += tf.where(cond4, out3_I, zeros)   
    return tf.concat([ans_R, ans_I], -1)

In [25]:
class Z3PReLU(nn.Module):
    def __init__(self, num_features):
        super(Z3PReLU, self).__init__()
        self.num_features = num_features
        self.alpha1_R = nn.Parameter(0.5*torch.zeros(num_features//2), requires_grad=True)
        self.alpha1_I = nn.Parameter(torch.zeros(num_features//2), requires_grad=True)
        
        self.alpha2_R = nn.Parameter(0.5*torch.zeros(num_features//2), requires_grad=True)
        self.alpha2_I = nn.Parameter(torch.zeros(num_features//2), requires_grad=True)
        
        self.alpha3_R = nn.Parameter(torch.zeros(num_features//2), requires_grad=True)
        self.alpha3_I = nn.Parameter(torch.zeros(num_features//2), requires_grad=True)
    def forward(self, x):
        x = x.contiguous().permute(0, 2, 3, 1)
        x_R, x_I = torch.chunk(x, 2, -1)
        out1_R = (x_R*self.alpha1_R) - (x_I*self.alpha1_I)
        out1_I = (x_R*self.alpha1_I) + (x_I*self.alpha1_R)
        
        out2_R = (x_R*self.alpha2_R) - (x_I*self.alpha2_I)
        out2_I = (x_R*self.alpha2_I) + (x_I*self.alpha2_R)
        
        out3_R = (x_R*self.alpha3_R) - (x_I*self.alpha3_I)
        out3_I = (x_R*self.alpha3_I) + (x_I*self.alpha3_R)
        
        cond1 = torch.gt(x_R, 0.) & torch.gt(x_I, 0.)
        cond2 = torch.le(x_R, 0.) & torch.gt(x_I, 0.)
        cond3 = torch.le(x_R, 0.) & torch.le(x_I, 0.)
        cond4 = torch.gt(x_R, 0.) & torch.le(x_I, 0.)
        zeros = torch.zeros_like(x_R)
        
        ans_R = torch.where(cond1, x_R, zeros)
        ans_I = torch.where(cond1, x_I, zeros)
        
        ans_R += torch.where(cond2, out1_R, zeros)
        ans_I += torch.where(cond2, out1_I, zeros)
        
        ans_R += torch.where(cond3, out2_R, zeros)
        ans_I += torch.where(cond3, out2_I, zeros)
        
        ans_R += torch.where(cond4, out3_R, zeros)
        ans_I += torch.where(cond4, out3_I, zeros) 
        
        ans = torch.cat([ans_R, ans_I], -1)
        ans = ans.contiguous().permute(0, 3, 1, 2)
        return ans

In [26]:
x = torch.gt(torch.zeros(4), 1.)
y = torch.le(torch.ones(4), 2.)

In [36]:
x = torch.cat([torch.ones(16, 8, 8, 4), -1*torch.ones(16, 8, 8, 4)], dim=1)

In [41]:
nl = Z3PReLU(2)

In [49]:
x = np.random.normal(0., 1., [1, 2, 4, 4])

In [50]:
x = torch.from_numpy(x).float()

In [51]:
nl(x)

tensor([[[[0.0000, 0.0000, 0.0000, 1.9358],
          [0.0000, 0.0000, 1.2548, 0.0000],
          [0.0831, 0.0000, 0.0000, 0.0000],
          [2.0417, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 1.7356],
          [0.0000, 0.0000, 0.0049, 0.0000],
          [1.0448, 0.0000, 0.0000, 0.0000],
          [0.8518, 0.0000, 0.0000, 0.0000]]]], grad_fn=<PermuteBackward>)

In [52]:
x

tensor([[[[-1.1122, -0.2530, -1.5160,  1.9358],
          [ 0.2022, -0.7168,  1.2548,  2.6191],
          [ 0.0831, -0.5869, -0.7560, -0.1729],
          [ 2.0417, -1.5679, -2.0188, -1.5051]],

         [[-1.0841,  0.9348, -0.7147,  1.7356],
          [-0.5666,  0.7544,  0.0049, -0.2964],
          [ 1.0448, -1.6356, -0.8429, -0.4969],
          [ 0.8518, -0.7840, -1.0005,  1.3122]]]])