In [1]:
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F





##resample.py

In [3]:
class Resample1d(nn.Module):
    def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False):
        '''
        Creates a resampling layer for time series data (using 1D convolution) - (N, C, W) input format # (N batch_size, C channels, W lentgth of downsampled time_steps)
        :param channels: Number of features C at each time-step
        :param kernel_size: Width of sinc-based lowpass-filter (>= 15 recommended for good filtering performance)
        :param stride: Resampling factor (integer)
        :param transpose: False for down-, true for upsampling
        :param padding: Either "reflect" to pad or "valid" to not pad
        :param trainable: Optionally activate this to train the lowpass-filter, starting from the sinc initialisation
        '''
        super(Resample1d, self).__init__()

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.transpose = transpose
        self.channels = channels

        cutoff = 0.5 / stride

        assert(kernel_size > 2)
        assert ((kernel_size - 1) % 2 == 0)
        assert(padding == "reflect" or padding == "valid")

        filter = build_sinc_filter(kernel_size, cutoff)

        self.filter = torch.nn.Parameter(torch.from_numpy(np.repeat(np.reshape(filter, [1, 1, kernel_size]), channels, axis=0)), requires_grad=trainable)

    def forward(self, x):
        # Pad here if not using transposed conv
        input_size = x.shape[2]
        # print("self.padding = ",self.padding )
        if self.padding != "valid":
            num_pad = (self.kernel_size-1)//2
            out = F.pad(x, (num_pad, num_pad), mode=self.padding)
        else:
            out = x

        # Lowpass filter (+ 0 insertion if transposed)
        if self.transpose:
            expected_steps = ((input_size - 1) * self.stride + 1)
            if self.padding == "valid":
                expected_steps = expected_steps - self.kernel_size + 1

            out = F.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
            diff_steps = out.shape[2] - expected_steps
            if diff_steps > 0:
                assert(diff_steps % 2 == 0)
                out = out[:,:,diff_steps//2:-diff_steps//2]
        else:
            assert(input_size % self.stride == 1)
            out = F.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
        # print("out.size() = ", out.size())
        return out

    def get_output_size(self, input_size):
        '''
        Returns the output dimensionality (number of timesteps) for a given input size
        :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''
        assert(input_size > 1)
        if self.transpose:
            if self.padding == "valid":
                return ((input_size - 1) * self.stride + 1) - self.kernel_size + 1
            else:
                return ((input_size - 1) * self.stride + 1)
        else:
            assert(input_size % self.stride == 1) # Want to take first and last sample
            if self.padding == "valid":
                return input_size - self.kernel_size + 1
            else:
                return input_size

    def get_input_size(self, output_size):
        '''
        Returns the input dimensionality (number of timesteps) for a given output size
        :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''

        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size

        # Conv
        if self.padding == "valid":
            curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1

        # Transposed
        if self.transpose:
            assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert(curr_size > 0)
        return curr_size

def build_sinc_filter(kernel_size, cutoff):
    # FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf
    # Sinc lowpass filter
    # Build sinc kernel
    assert(kernel_size % 2 == 1)
    M = kernel_size - 1
    filter = np.zeros(kernel_size, dtype=np.float32)
    for i in range(kernel_size):
        if i == M//2:
            filter[i] = 2 * np.pi * cutoff
        else:
            filter[i] = (np.sin(2 * np.pi * cutoff * (i - M//2)) / (i - M//2)) * \
                    (0.42 - 0.5 * np.cos((2 * np.pi * i) / M) + 0.08 * np.cos(4 * np.pi * M))

    filter = filter / np.sum(filter)
    return filter

##crop.py 

In [4]:
def centre_crop(x, target):
    # combined = centre_crop(shortcut, upsampled)
    '''
    Center-crop 3-dim. input tensor along last axis so it fits the target tensor shape
    :param x: Input tensor
    :param target: Shape of this tensor will be used as target shape
    :return: Cropped input tensor
    '''
    if x is None:
        return None
    if target is None:
        return x

    target_shape = target.shape
    diff = x.shape[-1] - target_shape[-1]
    assert (diff % 2 == 0)
    crop = diff // 2

    if crop == 0:
        return x
    if crop < 0:
        raise ArithmeticError

    return x[:, :, crop:-crop].contiguous()

##conv.py

In [5]:
class ConvLayer(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        # How many channels should be normalised as one group if GroupNorm is activated
        # WARNING: Number of channels has to be divisible by this number!
        NORM_CHANNELS = 8

        if self.transpose:
            self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size-1)
        else:
            self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)

        if conv_type == "gn":
            assert(n_outputs % NORM_CHANNELS == 0)
            self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
        # Add you own types of variations here!

    def forward(self, x):
        # Apply the convolution
        if self.conv_type == "gn" or self.conv_type == "bn":
            out = F.relu(self.norm((self.filter(x))))
        else: # Add your own variations here with elifs conditioned on "conv_type" parameter!
            assert(self.conv_type == "normal")
            out = F.leaky_relu(self.filter(x))
        return out

    def get_input_size(self, output_size):
        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size

        # Conv
        curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1

        # Transposed
        if self.transpose:
            assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert(curr_size > 0)
        return curr_size

    def get_output_size(self, input_size):
        # Transposed
        if self.transpose:
            assert(input_size > 1)
            curr_size = (input_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = input_size

        # Conv
        curr_size = curr_size - self.kernel_size + 1 # o = i + p - k + 1
        assert (curr_size > 0)

        # Strided conv/decimation
        if not self.transpose:
            assert ((curr_size - 1) % self.stride == 0)  # We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1

        return curr_size

##UTILS.py

In [6]:
import os

def save_model(model, optimizer, state, path):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # save state dict of wrapped module
    if len(os.path.dirname(path)) > 0 and not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'state': state,  # state of training loop (was 'step')
    }, path)


def load_model(model, optimizer, path, cuda):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # load state dict of wrapped module
    if cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location='cpu')
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except:
        # work-around for loading checkpoints where DataParallel was saved instead of inner module
        from collections import OrderedDict
        model_state_dict_fixed = OrderedDict()
        prefix = 'module.'
        for k, v in checkpoint['model_state_dict'].items():
            if k.startswith(prefix):
                k = k[len(prefix):]
            model_state_dict_fixed[k] = v
        model.load_state_dict(model_state_dict_fixed)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'state' in checkpoint:
        state = checkpoint['state']
    else:
        # older checkpoints only store step, rest of state won't be there
        state = {'step': checkpoint['step']}
    return state


def compute_loss(model, inputs, targets, criterion, compute_grad=False):
    '''
    Computes gradients of model with given inputs and targets and loss function.
    Optionally backpropagates to compute gradients for weights.
    Procedure depends on whether we have one model for each source or not
    :param model: Model to train with
    :param inputs: Input mixture
    :param targets: Target sources
    :param criterion: Loss function to use (L1, L2, ..)
    :param compute_grad: Whether to compute gradients
    :return: Model outputs, Average loss over batch
    '''
    all_outputs = {}

    if model.separate:
        avg_loss = 0.0
        num_sources = 0
        for inst in model.speakers:
            output = model(inputs, inst)
            loss = criterion(output[inst], targets[inst])

            if compute_grad:
                loss.backward()

            avg_loss += loss.item()
            num_sources += 1

            all_outputs[inst] = output[inst].detach().clone()

        avg_loss /= float(num_sources)
    else:
        loss = 0
        all_outputs = model(inputs)
        for inst in all_outputs.keys():
            loss += criterion(all_outputs[inst], targets[inst])

        if compute_grad:
            loss.backward()

        avg_loss = loss.item() / float(len(all_outputs))

    return all_outputs, avg_loss


class DataParallel(torch.nn.DataParallel):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__(module, device_ids, output_device, dim)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

#waveunet.py

In [7]:
class UpsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)

        # CONV 1 for UPSAMPLING
        if res == "fixed":
            self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
        else:
            self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONVS to combine high- with low-level information (from shortcut)
        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

    def forward(self, x, shortcut):
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)

        for conv in self.pre_shortcut_convs:
            upsampled = conv(upsampled)

        # Prepare shortcut connection
        combined = centre_crop(shortcut, upsampled) # shortcut has to have the same size as the shortcut

        # Combine high- and low-level features
        for conv in self.post_shortcut_convs:
            combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1)) #[0,1,2]
        return combined

    def get_output_size(self, input_size):
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        for conv in self.pre_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        return curr_size

class DownsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(DownsamplingBlock, self).__init__()
        assert(stride > 1)

        self.kernel_size = kernel_size
        self.stride = stride

        # CONV 1
        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in
                                                  range(depth - 1)])

        # CONV 2 with decimation
        if res == "fixed":
            self.downconv = Resample1d(n_outputs, 15, stride) # Resampling with fixed-size sinc lowpass filter
        else:
            self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)

    def forward(self, x):
        # PREPARING SHORTCUT FEATURES
        shortcut = x
        for conv in self.pre_shortcut_convs:
            shortcut = conv(shortcut)

        # PREPARING FOR DOWNSAMPLING
        out = shortcut
        for conv in self.post_shortcut_convs:
            out = conv(out)

        # DOWNSAMPLING
        out = self.downconv(out)

        return out, shortcut

    def get_input_size(self, output_size):
        curr_size = self.downconv.get_input_size(output_size)

        for conv in reversed(self.post_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)

        for conv in reversed(self.pre_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)
        return curr_size

class Waveunet(nn.Module):
    def __init__(self, num_inputs, num_channels, num_outputs, speakers, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.kernel_size = kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.speakers = speakers
        self.separate = separate

        # Only odd filter kernels allowed
        assert(kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()

        model_list = speakers if separate else ["ALL"]
        # Create a model for each source if we separate sources separately, otherwise only one (model_list=["ALL"])
        for speaker in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]

                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))

            for i in range(0, self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))

            module.bottlenecks = nn.ModuleList(
                [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])

            # Output conv
            outputs = num_outputs if separate else num_outputs * len(speakers)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[speaker] = module

        self.set_output_size(target_output_size)

    def set_output_size(self, target_output_size):
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        assert((self.input_size - self.output_size) % 2 == 0)
        self.shapes = {"output_start_frame" : (self.input_size - self.output_size) // 2,
                       "output_end_frame" : (self.input_size - self.output_size) // 2 + self.output_size,
                       "output_frames" : self.output_size,
                       "input_frames" : self.input_size}

    def check_padding(self, target_output_size):
        # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
        bottleneck = 1

        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1

    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        module = self.waveunets[[k for k in self.waveunets.keys()][0]]
        try:
            curr_size = bottleneck
            for idx, block in enumerate(module.upsampling_blocks):
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Bottleneck-Conv
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)
            for idx, block in enumerate(reversed(module.downsampling_blocks)):
                curr_size = block.get_input_size(curr_size)

            assert(output_size >= target_output_size)
            return curr_size, output_size
        except AssertionError as e:
            return False

    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)
        print("out.size() pre bottlenet",out.size())
        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)
        print("out.size() after = ", out.size())
        # UPSAMPLING BLOCKS
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out

    def forward(self, x, inst=None):
        curr_input_size = x.shape[-1]
        # print('self.waveunets ',self.waveunets.keys())
        # print('self.input_size ',self.input_size)
        assert(curr_input_size == self.input_size) # User promises to feed the proper input himself, to get the pre-calculated (NOT the originally desired) output size

        if self.separate:
            return {inst : self.forward_module(x, self.waveunets[inst])}
        else:
            assert(len(self.waveunets) == 1)
            out = self.forward_module(x, self.waveunets["ALL"])

            out_dict = {}
            for idx, inst in enumerate(self.speakers):
                out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
            return out_dict

In [8]:
# model = Waveunet(num_inputs =2, num_channels = [1], num_outputs = 2,speakers =["bass", "drums", "other", "vocals"],kernel_size= 5, target_output_size = 2.0,conv_type = "gn", res = "fixed"  )
# class Waveunet(nn.Module):
#     def __init__(self, num_inputs, num_channels, num_outputs, speakers, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
#         super(Waveunet, self).__init__()

In [9]:
len([1])

1

In [10]:
# import argparse
# parser = argparse.ArgumentParser(description='Optional app description')

# parser.add_argument('--speakers', type=str, nargs='+', default=["bass", "drums", "other", "vocals"],
#                     help="List of speakers to separate (default: \"bass drums other vocals\")")
# parser.add_argument('--cuda', action='store_true',
#                     help='Use CUDA (default: False)')
# parser.add_argument('--num_workers', type=int, default=1,
#                     help='Number of data loader worker threads (default: 1)')
# parser.add_argument('--features', type=int, default=32,
#                     help='Number of feature channels per layer')
# parser.add_argument('--log_dir', type=str, default='logs/waveunet',
#                     help='Folder to write logs into')
# parser.add_argument('--dataset_dir', type=str, default="/mnt/windaten/Datasets/MUSDB18HQ",
#                     help='Dataset path')
# parser.add_argument('--hdf_dir', type=str, default="hdf",
#                     help='Dataset path')
# parser.add_argument('--checkpoint_dir', type=str, default='checkpoints/waveunet',
#                     help='Folder to write checkpoints into')
# parser.add_argument('--load_model', type=str, default=None,
#                     help='Reload a previously trained model (whole task model)')
# parser.add_argument('--lr', type=float, default=1e-3,
#                     help='Initial learning rate in LR cycle (default: 1e-3)')
# parser.add_argument('--min_lr', type=float, default=5e-5,
#                     help='Minimum learning rate in LR cycle (default: 5e-5)')
# parser.add_argument('--cycles', type=int, default=2,
#                     help='Number of LR cycles per epoch')
# parser.add_argument('--batch_size', type=int, default=4,
#                     help="Batch size")
# parser.add_argument('--levels', type=int, default=6,
#                     help="Number of DS/US blocks")
# parser.add_argument('--depth', type=int, default=1,
#                     help="Number of convs per block")
# parser.add_argument('--sr', type=int, default=44100,
#                     help="Sampling rate")
# parser.add_argument('--channels', type=int, default=2,
#                     help="Number of input audio channels")
# parser.add_argument('--kernel_size', type=int, default=5,
#                     help="Filter width of kernels. Has to be an odd number")
# parser.add_argument('--output_size', type=float, default=2.0,
#                     help="Output duration")
# parser.add_argument('--strides', type=int, default=4,
#                     help="Strides in Waveunet")
# parser.add_argument('--patience', type=int, default=20,
#                     help="Patience for early stopping on validation set")
# parser.add_argument('--example_freq', type=int, default=200,
#                     help="Write an audio summary into Tensorboard logs every X training iterations")
# parser.add_argument('--loss', type=str, default="L1",
#                     help="L1 or L2")
# parser.add_argument('--conv_type', type=str, default="gn",
#                     help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn")
# parser.add_argument('--res', type=str, default="fixed",
#                     help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned")
# parser.add_argument('--separate', type=int, default=1,
#                     help="Train separate model for each source (1) or only one (0)")
# parser.add_argument('--feature_growth', type=str, default="double",
#                     help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)")

# args = parser.parse_args()
feature_growth = "double"
levels = 6
num_features = [32*i for i in range(1, levels +1)] if feature_growth == "add" else \
                [32*2**i for i in range(0, 6)]
# num_features = [args.features*i for i in range(1, args.levels+1)] if args.feature_growth == "add" else \
#                 [args.features*2**i for i in range(0, args.levels)]
output_size = 2.0
sr  = 16e3#44100
target_outputs = int(output_size * sr)
num_ch = 1
# target_outputs = int(args.output_size * args.sr)

# model = Waveunet(num_inputs =2, num_channels =num_features , num_outputs = 2,speakers =["bass", "drums", "other", "vocals"],kernel_size= 5, target_output_size = 2.0,depth=1,strides =4,conv_type = "gn",res = "fixed",separate = 1)
# class Waveunet(nn.Module):
# model = Waveunet(num_inputs =2, num_channels =num_features , num_outputs = 2,speakers =["ALL"],kernel_size= 5, target_output_size = 2.0,depth=1,strides =4,conv_type = "gn",res = "fixed",separate = False)
# class Waveunet(nn.Module):
# model = Waveunet(args.channels, num_features, args.channels, args.speakers, kernel_size=args.kernel_size,
#                   target_output_size=target_outputs, depth=args.depth, strides=args.strides,
#                   conv_type=args.conv_type, res=args.res, separate=args.separate)

# model = Waveunet(args.channels, num_features, args.channels, args.speakers, kernel_size=args.kernel_size,
#                   target_output_size=target_outputs, depth=args.depth, strides=args.strides,
#                   conv_type=args.conv_type, res=args.res, separate=args.separate)

In [11]:
model = Waveunet(num_inputs =num_ch, num_channels =num_features , num_outputs = num_ch,speakers =["1", "2"],kernel_size= 3, target_output_size = target_outputs,depth=1,strides =4,conv_type = "gn",res = "fixed",separate = False)

num_features

Using valid convolutions with 37205 inputs and 32429 outputs


[32, 64, 128, 256, 512, 1024]

In [12]:
N_batch=5
num_ch = 1
L_timeseries = 37205                
input = torch.rand([N_batch, num_ch,L_timeseries])



In [13]:
with torch.no_grad():
  out = model(input)

out.size() pre bottlenet torch.Size([5, 1024, 36])
out.size() after =  torch.Size([5, 1024, 34])


In [14]:
out['1'].size()

torch.Size([5, 1, 32429])

In [None]:
# Waveunet(
#   (waveunets): ModuleDict(
#     (ALL): Module(
#       (downsampling_blocks): ModuleList(
#         (0): DownsamplingBlock(
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(2, 32, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(4, 32, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(32, 64, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
#             )
#           )
#           (downconv): Resample1d()
#         )
#         (1): DownsamplingBlock(
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(64, 64, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(64, 128, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(16, 128, eps=1e-05, affine=True)
#             )
#           )
#           (downconv): Resample1d()
#         )
#         (2): DownsamplingBlock(
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(128, 128, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(16, 128, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(128, 256, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(32, 256, eps=1e-05, affine=True)
#             )
#           )
#           (downconv): Resample1d()
#         )
#         (3): DownsamplingBlock(
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(256, 256, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(32, 256, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(256, 512, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(64, 512, eps=1e-05, affine=True)
#             )
#           )
#           (downconv): Resample1d()
#         )
#         (4): DownsamplingBlock(
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(512, 512, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(64, 512, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(512, 1024, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(128, 1024, eps=1e-05, affine=True)
#             )
#           )
#           (downconv): Resample1d()
#         )
#       )
#       (upsampling_blocks): ModuleList(
#         (0): UpsamplingBlock(
#           (upconv): Resample1d()
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(1024, 512, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(64, 512, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(1024, 512, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(64, 512, eps=1e-05, affine=True)
#             )
#           )
#         )
#         (1): UpsamplingBlock(
#           (upconv): Resample1d()
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(512, 256, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(32, 256, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(512, 256, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(32, 256, eps=1e-05, affine=True)
#             )
#           )
#         )
#         (2): UpsamplingBlock(
#           (upconv): Resample1d()
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(256, 128, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(16, 128, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(256, 128, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(16, 128, eps=1e-05, affine=True)
#             )
#           )
#         )
#         (3): UpsamplingBlock(
#           (upconv): Resample1d()
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(128, 64, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(128, 64, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
#             )
#           )
#         )
#         (4): UpsamplingBlock(
#           (upconv): Resample1d()
#           (pre_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(64, 32, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(4, 32, eps=1e-05, affine=True)
#             )
#           )
#           (post_shortcut_convs): ModuleList(
#             (0): ConvLayer(
#               (filter): Conv1d(64, 32, kernel_size=(5,), stride=(1,))
#               (norm): GroupNorm(4, 32, eps=1e-05, affine=True)
#             )
#           )
#         )
#       )
#       (bottlenecks): ModuleList(
#         (0): ConvLayer(
#           (filter): Conv1d(1024, 1024, kernel_size=(5,), stride=(1,))
#           (norm): GroupNorm(128, 1024, eps=1e-05, affine=True)
#         )
#       )
#       (output_conv): Conv1d(32, 8, kernel_size=(1,), stride=(1,))
#     )
#   )
# )



In [15]:
D = 5
I = 1
N =1000
V = 12

a = torch.rand([D,N])
b = torch.rand([V,N])
ab = torch.cat([a,b]).unsqueeze(dim = 1)
print(ab.size())

torch.Size([17, 1, 1000])
