In [6]:
%load_ext autoreload
%autoreload 2

import os, sys, argparse
import time
from functools import partial
import torch
import pickle
import numpy as np

import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from tqdm import tqdm

# add project root dir to sys.path so that all packages can be found by python.
root_dir = os.path.dirname(os.path.dirname(os.path.realpath("__file__")))
sys.path.append(root_dir)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from model.model_params import ModelArgs
from model.waveunet_params import waveunet_params
import model.utils as model_utils
import utils
from data.dataset import SeparationDataset
from data.musdb import get_musdb_folds
from data.utils import crop_targets, random_amplify

In [130]:
import torch
import torch.nn as nn

from model.crop import centre_crop
from model.resample import Resample1d
from model.conv import ConvLayer

class UpsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res, num_convs):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)
        self.num_convs = num_convs
        # CONV 1 for UPSAMPLING
        if self.num_convs == 1:
            if res == "fixed":
                self.upconv = Resample1d(n_shortcut, 15, stride, transpose=True)
            else:
                self.upconv = ConvLayer(n_shortcut, n_shortcut, kernel_size, stride, conv_type, transpose=True)
        else:
            if res == "fixed":
                self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
            else:
                self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        if self.num_convs == 2:    
            self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
                                                    [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

            # CONVS to combine high- with low-level information (from shortcut)
            self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                     [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
        elif self.num_convs == 1:
            self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                    [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])
    def forward(self, x, shortcut):
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)
        
        if self.num_convs == 2:
            for conv in self.pre_shortcut_convs:
                upsampled = conv(upsampled)

        # Prepare shortcut connection
        combined = centre_crop(shortcut, upsampled)

        # Combine high- and low-level features
        for conv in self.post_shortcut_convs:
            combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
        return combined

    def get_output_size(self, input_size):
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        if self.num_convs == 2:
            for conv in self.pre_shortcut_convs:
                curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        return curr_size

class DownsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res, num_convs):
        super(DownsamplingBlock, self).__init__()
        assert(stride > 1)

        self.kernel_size = kernel_size
        self.stride = stride
        self.num_convs = num_convs

        # CONV 1
        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        
        
        if self.num_convs == 2:
            self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in
                                                  range(depth - 1)])

        if self.num_convs == 1:
            n_outputs = n_shortcut
            
        # CONV 2 with decimation
        if res == "fixed":
            self.downconv = Resample1d(n_outputs, 15, stride) # Resampling with fixed-size sinc lowpass filter
        elif res == "naive":
            #todo: add decimation here
            self.downconv = self.naive_decimation
        else:
            self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)

    def forward(self, x):
        # PREPARING SHORTCUT FEATURES
        shortcut = x
        for conv in self.pre_shortcut_convs:
            shortcut = conv(shortcut)
        
        # PREPARING FOR DOWNSAMPLING
        out = shortcut
        if self.num_convs == 2:
            for conv in self.post_shortcut_convs:
                out = conv(out)

        # DOWNSAMPLING
        out = self.downconv(out)

        return out, shortcut

    def get_input_size(self, output_size):
        curr_size = self.downconv.get_input_size(output_size)
        if self.num_convs == 2:
            for conv in reversed(self.post_shortcut_convs):
                curr_size = conv.get_input_size(curr_size)

        for conv in reversed(self.pre_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)
        return curr_size

    def naive_decimation(self, x):
        # a very naive decimation
        return x[:,::2,:] # Decimate by factor of 2 # out = (in-1)/2 + 1

class Waveunet(nn.Module):
    def __init__(self, num_inputs, num_channels, num_outputs, instruments, upsampling_kernel_size, downsampling_kernel_size, bottleneck_kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2, num_convs=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.upsampling_kernel_size = upsampling_kernel_size
        self.downsampling_kernel_size = downsampling_kernel_size
        self.bottleneck_kernel_size = bottleneck_kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.instruments = instruments
        self.separate = separate
        self.num_convs=num_convs

         # Only odd filter kernels allowed
        assert(downsampling_kernel_size % 2 == 1)
        assert(bottleneck_kernel_size % 2 == 1)
        assert(upsampling_kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()

        model_list = instruments if separate else ["ALL"]
        # Create a model for each source if we separate sources separately, otherwise only one (model_list=["ALL"])
        for instrument in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]

                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], self.downsampling_kernel_size, strides, depth, conv_type, res, num_convs))

            for i in range(0, self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], self.upsampling_kernel_size, strides, depth, conv_type, res, num_convs))

            if self.num_convs == 2:
                module.bottlenecks = nn.ModuleList(
                    [ConvLayer(num_channels[-1], num_channels[-1], self.bottleneck_kernel_size, 1, conv_type) for _ in range(depth)])
            elif self.num_convs == 1:
                module.bottlenecks = nn.ModuleList(
                    [ConvLayer(num_channels[i], num_channels[i], self.bottleneck_kernel_size, 1, conv_type) for _ in range(depth)])
        
            # Output conv
            outputs = num_outputs if separate else num_outputs * len(instruments)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[instrument] = module

        self.set_output_size(target_output_size)

    def set_output_size(self, target_output_size):
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        assert((self.input_size - self.output_size) % 2 == 0)
        self.shapes = {"output_start_frame" : (self.input_size - self.output_size) // 2,
                       "output_end_frame" : (self.input_size - self.output_size) // 2 + self.output_size,
                       "output_frames" : self.output_size,
                       "input_frames" : self.input_size}

    def check_padding(self, target_output_size):
        # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
        bottleneck = 1

        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1

    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        module = self.waveunets[[k for k in self.waveunets.keys()][0]]
        try:
            curr_size = bottleneck
            for idx, block in enumerate(module.upsampling_blocks):
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Bottleneck-Conv
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)
            for idx, block in enumerate(reversed(module.downsampling_blocks)):
                curr_size = block.get_input_size(curr_size)

            assert(output_size >= target_output_size)
            return curr_size, output_size
        except AssertionError as e:
            return False

    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)

        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)

        # UPSAMPLING BLOCKS
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out

    def forward(self, x, inst=None):
        curr_input_size = x.shape[-1]
        assert(curr_input_size == self.input_size) # User promises to feed the proper input himself, to get the pre-calculated (NOT the originally desired) output size

        if self.separate:
            return {inst : self.forward_module(x, self.waveunets[inst])}
        else:
            assert(len(self.waveunets) == 1)
            out = self.forward_module(x, self.waveunets["ALL"])

            out_dict = {}
            for idx, inst in enumerate(self.instruments):
                out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
            return out_dict


In [131]:
from data.musdb_loader import setup_hq_musdb

In [132]:
def _create_waveunet(args):
    num_features = [args.features * i for i in range(1, args.levels + 1)] if args.feature_growth == "add" else \
        [args.features * 2 ** i for i in range(0, args.levels)]
    target_outputs = int(args.output_size * args.sr)
    model = Waveunet(args.channels, num_features, args.channels, args.instruments, downsampling_kernel_size=args.downsampling_kernel_size,
                     upsampling_kernel_size=args.upsampling_kernel_size, bottleneck_kernel_size=args.bottleneck_kernel_size,
                     target_output_size=target_outputs, depth=args.depth, strides=args.strides,
                     conv_type=args.conv_type, res=args.res, separate=args.separate, num_convs=args.num_convs)

    if args.cuda:
        model = model_utils.DataParallel(model)
        print("move model to gpu")
        model.cuda()

    print('model: ', model)
    print('parameter count: ', str(sum(p.numel() for p in model.parameters())))
    return model

def _load_musdb(args, data_shapes):
    # musdb = get_musdb_folds(args.dataset_dir, version="toy")
    musdb = None
    # If not data augmentation, at least crop targets to fit model output shape
    crop_func = partial(crop_targets, shapes=data_shapes)
    # Data augmentation function for training
    augment_func = partial(random_amplify, shapes=data_shapes, min=0.7, max=1.0)

    test_data = SeparationDataset(musdb, "test", args.instruments, args.sr, args.channels, data_shapes, False,
                                  args.hdf_dir, audio_transform=crop_func)
    return test_data, musdb

In [139]:
args = waveunet_params.get_defaults()
args.instruments = ["accompaniment", "vocals"]
args.sr=22050 
args.channels= 1 
args.output_size= 2 
args.patience= 20 
args.separate= 0 
args.features= 24 
args.lr= 1e-4 
args.min_lr= 1e-4
args.batch_size= 16 
args.levels= 2  
args.depth=1 
args.kernel_size= 5 
args.strides= 2 
args.loss= "L2" 
args.conv_type= "normal" 
args.dataset_dir = '../data/musdb'
args.hdf_dir = '../data/hdf/'
args.num_convs=2

In [140]:
model = _create_waveunet(args)
test_data, musdb = _load_musdb(args, model.shapes)
sample_track = test_data[0]
dataloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
for example_num, (x, targets) in enumerate(dataloader):
    sample = (x, targets)
    break

Using valid convolutions with 44125 inputs and 44101 outputs
model:  Waveunet(
  (waveunets): ModuleDict(
    (ALL): Module(
      (downsampling_blocks): ModuleList(
        (0): DownsamplingBlock(
          (pre_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(1, 24, kernel_size=(5,), stride=(1,))
            )
          )
          (post_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(24, 48, kernel_size=(5,), stride=(1,))
            )
          )
          (downconv): Resample1d()
        )
      )
      (upsampling_blocks): ModuleList(
        (0): UpsamplingBlock(
          (upconv): Resample1d()
          (pre_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(48, 24, kernel_size=(5,), stride=(1,))
            )
          )
          (post_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(48, 24, kernel_size=(5,), stride=(1,))


In [141]:
model(sample[0], inst="vocals")

{'accompaniment': tensor([[[-0.0637, -0.0651, -0.0651,  ..., -0.0665, -0.0663, -0.0643]],
 
         [[-0.0603, -0.0599, -0.0596,  ..., -0.0645, -0.0646, -0.0640]],
 
         [[-0.0604, -0.0602, -0.0601,  ..., -0.0638, -0.0633, -0.0628]],
 
         ...,
 
         [[-0.0608, -0.0610, -0.0610,  ..., -0.0627, -0.0626, -0.0625]],
 
         [[-0.0640, -0.0641, -0.0641,  ..., -0.0612, -0.0613, -0.0607]],
 
         [[-0.0601, -0.0608, -0.0605,  ..., -0.0611, -0.0610, -0.0611]]],
        grad_fn=<SliceBackward0>),
 'vocals': tensor([[[0.0381, 0.0381, 0.0384,  ..., 0.0411, 0.0418, 0.0419]],
 
         [[0.0361, 0.0360, 0.0359,  ..., 0.0451, 0.0451, 0.0455]],
 
         [[0.0376, 0.0361, 0.0339,  ..., 0.0425, 0.0422, 0.0417]],
 
         ...,
 
         [[0.0394, 0.0394, 0.0394,  ..., 0.0448, 0.0449, 0.0450]],
 
         [[0.0448, 0.0448, 0.0449,  ..., 0.0345, 0.0345, 0.0348]],
 
         [[0.0387, 0.0393, 0.0394,  ..., 0.0401, 0.0397, 0.0392]]],
        grad_fn=<SliceBackward0>)}

In [136]:
args = waveunet_params.get_defaults()
args.instruments = ["accompaniment", "vocals"]
args.sr=22050 
args.channels= 1 
args.output_size= 2 
args.patience= 20 
args.separate= 0 
args.features= 24 
args.lr= 1e-4 
args.min_lr= 1e-4
args.batch_size= 16 
args.levels= 2  
args.depth= 1 
args.kernel_size= 5 
args.strides= 2 
args.loss= "L2" 
args.conv_type= "normal" 
args.dataset_dir = '../data/musdb'
args.hdf_dir = '../data/hdf/'
args.num_convs=1

In [137]:
new_model = _create_waveunet(args)
test_data, musdb = _load_musdb(args, new_model.shapes)
sample_track = test_data[0]
dataloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
for example_num, (x, targets) in enumerate(dataloader):
    sample = (x, targets)
    break

Using valid convolutions with 44117 inputs and 44101 outputs
model:  Waveunet(
  (waveunets): ModuleDict(
    (ALL): Module(
      (downsampling_blocks): ModuleList(
        (0): DownsamplingBlock(
          (pre_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(1, 24, kernel_size=(5,), stride=(1,))
            )
          )
          (downconv): Resample1d()
        )
      )
      (upsampling_blocks): ModuleList(
        (0): UpsamplingBlock(
          (upconv): Resample1d()
          (post_shortcut_convs): ModuleList(
            (0): ConvLayer(
              (filter): Conv1d(48, 24, kernel_size=(5,), stride=(1,))
            )
          )
        )
      )
      (bottlenecks): ModuleList(
        (0): ConvLayer(
          (filter): Conv1d(24, 24, kernel_size=(5,), stride=(1,))
        )
      )
      (output_conv): Conv1d(24, 2, kernel_size=(1,), stride=(1,))
    )
  )
)
parameter count:  9602


In [138]:
new_model(sample[0], inst="vocals")

{'accompaniment': tensor([[[-0.1760, -0.1798, -0.1865,  ..., -0.1842, -0.1814, -0.1796]],
 
         [[-0.1813, -0.1821, -0.1828,  ..., -0.1800, -0.1803, -0.1837]],
 
         [[-0.1843, -0.1833, -0.1831,  ..., -0.1828, -0.1772, -0.1764]],
 
         ...,
 
         [[-0.1829, -0.1825, -0.1831,  ..., -0.1830, -0.1829, -0.1831]],
 
         [[-0.1822, -0.1829, -0.1828,  ..., -0.1825, -0.1824, -0.1827]],
 
         [[-0.1830, -0.1830, -0.1830,  ..., -0.1830, -0.1830, -0.1830]]],
        grad_fn=<SliceBackward0>),
 'vocals': tensor([[[-0.1176, -0.1195, -0.1193,  ..., -0.1249, -0.1224, -0.1215]],
 
         [[-0.1126, -0.1148, -0.1164,  ..., -0.1148, -0.1165, -0.1199]],
 
         [[-0.1127, -0.1126, -0.1123,  ..., -0.1154, -0.1108, -0.1099]],
 
         ...,
 
         [[-0.1170, -0.1162, -0.1168,  ..., -0.1131, -0.1132, -0.1131]],
 
         [[-0.1126, -0.1131, -0.1145,  ..., -0.1182, -0.1173, -0.1172]],
 
         [[-0.1155, -0.1156, -0.1155,  ..., -0.1155, -0.1156, -0.1155]]],
        

{'accompaniment': tensor([[[-0.0406, -0.0441, -0.0457,  ..., -0.0434, -0.0455, -0.0496]],
 
         [[-0.0439, -0.0450, -0.0451,  ..., -0.0401, -0.0407, -0.0386]],
 
         [[-0.0356, -0.0370, -0.0389,  ..., -0.0396, -0.0401, -0.0402]],
 
         ...,
 
         [[-0.0388, -0.0388, -0.0388,  ..., -0.0388, -0.0388, -0.0388]],
 
         [[-0.0471, -0.0463, -0.0454,  ..., -0.0449, -0.0467, -0.0459]],
 
         [[-0.0428, -0.0467, -0.0470,  ..., -0.0353, -0.0357, -0.0350]]],
        grad_fn=<SliceBackward0>),
 'vocals': tensor([[[0.1289, 0.1318, 0.1339,  ..., 0.1296, 0.1298, 0.1253]],
 
         [[0.1293, 0.1255, 0.1246,  ..., 0.1303, 0.1252, 0.1268]],
 
         [[0.1300, 0.1306, 0.1301,  ..., 0.1282, 0.1281, 0.1273]],
 
         ...,
 
         [[0.1285, 0.1285, 0.1285,  ..., 0.1285, 0.1285, 0.1285]],
 
         [[0.1242, 0.1253, 0.1269,  ..., 0.1270, 0.1232, 0.1233]],
 
         [[0.1288, 0.1270, 0.1249,  ..., 0.1315, 0.1315, 0.1292]]],
        grad_fn=<SliceBackward0>)}

In [17]:
np_tensor = sample[0].numpy()
tf_tensor = tf.convert_to_tensor(np_tensor)

2021-12-26 00:01:13.842063: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [18]:
import tensorflow as tf

import tf_baseline
from tf_baseline import LeakyReLU
import numpy as np

class UnetAudioSeparator:
    '''
    U-Net separator network for singing voice separation.
    Uses valid convolutions, so it predicts for the centre part of the input - only certain input and output shapes are therefore possible (see getpadding function)
    '''

    def __init__(self, args):
        '''
        Initialize U-net
        :param num_layers: Number of down- and upscaling layers in the network 
        '''
        self.num_layers = (args.channels)
        self.num_initial_filters = 44125
        self.filter_size = args.kernel_size
        self.merge_filter_size = args.kernel_size
        self.input_filter_size = args.kernel_size
        self.output_filter_size = args.kernel_size
        self.upsampling = None
        self.output_type = "direct"
        self.context = None
        self.padding = "valid" if self.context else "same"
        self.source_names = args.instruments
        self.num_channels = args.channels
        self.output_activation = "tanh"

    def get_padding(self, shape):
        '''
        Calculates the required amounts of padding along each axis of the input and output, so that the Unet works and has the given shape as output shape
        :param shape: Desired output shape 
        :return: Input_shape, output_shape, where each is a list [batch_size, time_steps, channels]
        '''

        if self.context:
            # Check if desired shape is possible as output shape - go from output shape towards lowest-res feature map
            rem = float(shape[1]) # Cut off batch size number and channel

            # Output filter size
            rem = rem - self.output_filter_size + 1

            # Upsampling blocks
            for i in range(self.num_layers):
                rem = rem + self.merge_filter_size - 1
                rem = (rem + 1.) / 2.# out = in + in - 1 <=> in = (out+1)/

            # Round resulting feature map dimensions up to nearest integer
            x = np.asarray(np.ceil(rem),dtype=np.int64)
            assert(x >= 2)

            # Compute input and output shapes based on lowest-res feature map
            output_shape = x
            input_shape = x

            # Extra conv
            input_shape = input_shape + self.filter_size - 1

            # Go from centre feature map through up- and downsampling blocks
            for i in range(self.num_layers):
                output_shape = 2*output_shape - 1 #Upsampling
                output_shape = output_shape - self.merge_filter_size + 1 # Conv

                input_shape = 2*input_shape - 1 # Decimation
                if i < self.num_layers - 1:
                    input_shape = input_shape + self.filter_size - 1 # Conv
                else:
                    input_shape = input_shape + self.input_filter_size - 1

            # Output filters
            output_shape = output_shape - self.output_filter_size + 1

            input_shape = np.concatenate([[shape[0]], [input_shape], [self.num_channels]])
            output_shape = np.concatenate([[shape[0]], [output_shape], [self.num_channels]])

            return input_shape, output_shape
        else:
            return [shape[0], shape[1], self.num_channels], [shape[0], shape[1], self.num_channels]

    def get_output(self, input, training, return_spectrogram=False, reuse=True):
        '''
        Creates symbolic computation graph of the U-Net for a given input batch
        :param input: Input batch of mixtures, 3D tensor [batch_size, num_samples, num_channels]
        :param reuse: Whether to create new parameter variables or reuse existing ones
        :return: U-Net output: List of source estimates. Each item is a 3D tensor [batch_size, num_out_samples, num_channels]
        '''
        enc_outputs = list()
        current_layer = input

        # Down-convolution: Repeat strided conv
        for i in range(self.num_layers):
            current_layer = tf.keras.layers.Conv1D(self.num_initial_filters + (self.num_initial_filters * i), self.filter_size, strides=1, activation=LeakyReLU, padding=self.padding)(current_layer) # out = in - filter + 1
            enc_outputs.append(current_layer)
            current_layer = current_layer[:,::2,:] # Decimate by factor of 2 # out = (in-1)/2 + 1

        current_layer = tf.keras.layers.Conv1D(self.num_initial_filters + (self.num_initial_filters * self.num_layers),self.filter_size,activation=LeakyReLU,padding=self.padding)(current_layer) # One more conv here since we need to compute features after last decimation

        # Feature map here shall be X along one dimension

        # Upconvolution
        for i in range(self.num_layers):
            #UPSAMPLING
            current_layer = tf.expand_dims(current_layer, axis=1)
            if self.upsampling == 'learned':
                # Learned interpolation between two neighbouring time positions by using a convolution filter of width 2, and inserting the responses in the middle of the two respective inputs
                current_layer = tf_baseline.learned_interpolation_layer(current_layer, self.padding, i)
            else:
                if self.context:
                    current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2] * 2 - 1], align_corners=True)
                else:
                    current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2]*2]) # out = in + in - 1
            current_layer = tf.squeeze(current_layer, axis=1)
            # UPSAMPLING FINISHED

            assert(enc_outputs[-i-1].get_shape().as_list()[1] == current_layer.get_shape().as_list()[1] or self.context) #No cropping should be necessary unless we are using context
            current_layer = tf_baseline.crop_and_concat(enc_outputs[-i-1], current_layer, match_feature_dim=False)
            current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * (self.num_layers - i - 1)), self.merge_filter_size,
                                             activation=LeakyReLU,
                                             padding=self.padding)  # out = in - filter + 1

        current_layer = tf_baseline.crop_and_concat(input, current_layer, match_feature_dim=False)

        # Output layer
        # Determine output activation function
        if self.output_activation == "tanh":
            out_activation = tf.tanh
        elif self.output_activation == "linear":
            out_activation = lambda x: tf_baseline.AudioClip(x, training)
        else:
            raise NotImplementedError

        if self.output_type == "direct":
            return tf_baseline.independent_outputs(current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation)
        elif self.output_type == "difference":
            cropped_input = tf_baseline.crop(input,current_layer.get_shape().as_list(), match_feature_dim=False)
            return tf_baseline.difference_output(cropped_input, current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation, training)
        else:
            raise NotImplementedError

In [None]:
separator_class = UnetAudioSeparator(args)
separator_func = separator_class.get_output
separator_func(tf_tensor, True)

AttributeError: module 'tensorflow' has no attribute 'layers'