In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import *
import pytorch_lightning as pl

import musdb
import museval 

import numpy as np
import torch

from datetime import datetime
import time

import warnings
warnings.filterwarnings("ignore")

In [2]:
class DataHandler():
    """Cuts stem tensors into fixed-length, overlapping segments and batches them.

    Every method that receives ``stems`` indexes it as a 3-D tensor whose second
    axis is the sample axis (``stems[:, a:b, :]``); the original comments describe
    the layout as (num_stems, num_samples, num_channels).

    Args:
        batch_size: maximum number of segments returned per item.
        shortest_duration: length, in samples, every training item is trimmed to.
        sampling_rate: samples per second, used to convert ``segment_length``.
        longest_duration: length every test item is zero-padded to
            (presumably in samples, like ``shortest_duration`` - confirm with caller).
        segment_overlap: fraction of a segment shared with the next one.
        segment_chunks: number of equal chunks a segment is cut into for the energy test.
        segment_length: segment length in seconds.
        chunks_below_percentile: fraction of ``segment_chunks`` that must be exceeded
            by the count of chunks above the energy threshold for a segment to be kept.
        drop_percentile: quantile (0..1) of all chunk energies used as the threshold.
    """
    def __init__(self,batch_size,shortest_duration, sampling_rate, longest_duration, segment_overlap, segment_chunks, segment_length, chunks_below_percentile, drop_percentile):
        self.sample_rate = sampling_rate
        self.shortest_duration = shortest_duration
        self.longest_duration = longest_duration
        self.segment_length = segment_length
        self.segment_chunks = segment_chunks
        self.chunks_below_percentile = chunks_below_percentile
        self.segment_overlap = segment_overlap
        self.drop_percentile = drop_percentile
        self.segment_samples = int(self.segment_length * self.sample_rate)
        self.chunk_samples = int(self.segment_samples / self.segment_chunks)
        self.batch_size = batch_size

    def _step_samples(self):
        # Hop, in samples, between the starts of two consecutive segments.
        step = int(self.segment_samples * (1 - self.segment_overlap))
        if step <= 0:
            raise ValueError("segment_overlap must leave a positive hop between segments")
        return step

    def _energy_quantile(self, energies):
        # Threshold below which a chunk counts as low energy.
        return torch.quantile(energies, self.drop_percentile, interpolation='midpoint')

    def batchize_training_item(self,stems):
        """Return up to ``batch_size`` high-energy segments from a random window of the track.

        The result has shape (num_kept_segments, num_stems, segment_samples, num_channels);
        fewer than ``batch_size`` segments (possibly zero) are returned when the energy
        filter keeps fewer.
        """
        # trim to the shortest track duration, starting from a random location
        stems = self.trim_stems(stems, self.random_start(stems.shape[1]))
        segments = self.split_track(stems)
        # drop the low energy segments, then keep the first batch_size survivors
        segments = self.high_energy_segments(segments)
        return segments[:self.batch_size]

    def batchize_testing_item(self,stems):
        """Return the first ``batch_size`` segments of the zero-padded track, unfiltered."""
        # pad so the song produces a whole number of segments
        stems = self.add_zero_padding(stems)
        # extend with zeros up to the longest duration so every item yields the same number of segments
        to_pad = self.longest_duration_in_samples() - stems.shape[1]
        stems = self.add_N_zeros(stems, to_pad)
        segments = self.split_track(stems)
        # nothing is dropped here: the data is needed to reconstruct the signal
        return segments[:self.batch_size]

    def trim_stems(self, stems, start):
        """Cut ``shortest_duration`` samples starting at ``start``, wrapping to the track start once."""
        start = int(start)  # random_start() hands back a one-element tensor
        end = start + self.shortest_duration
        if end > stems.shape[1]:
            first_half = stems[:, start:, :]
            second_half = stems[:, :end - stems.shape[1], :]
            return torch.cat((first_half, second_half), dim=1)
        return stems[:, start:end, :]

    def pad_stems(self, stems):
        """Zero-pad ``stems`` along the sample axis up to the longest duration."""
        to_pad = self.longest_duration_in_samples() - stems.shape[1]
        return self.add_N_zeros(stems, to_pad)

    def num_segments_in_track(self, duration_in_samples):
        """Number of hops of size segment_samples * (1 - segment_overlap) needed to span the duration."""
        hops = duration_in_samples / (self.segment_samples * (1 - self.segment_overlap))
        return int(torch.ceil(torch.tensor(hops, dtype=torch.float64)))

    def random_start(self, duration_in_samples):
        """Return a random start index in [0, duration_in_samples) as a one-element tensor."""
        return torch.randint(0, duration_in_samples, (1,))

    def longest_duration_in_samples(self, mus=None):
        """Longest track length in samples.

        With ``mus`` given, the maximum of ``track.stems.shape[1]`` over ``mus.tracks``;
        without it, the ``longest_duration`` passed to the constructor.
        """
        if mus is None:
            return self.longest_duration
        return max(track.stems.shape[1] for track in mus.tracks)

    def is_high_energy_segment(self, segment, threshold):
        """True when more than chunks_below_percentile * segment_chunks chunks exceed ``threshold``."""
        chunk_energies = self.segment_chunk_energies(segment)
        loud_chunks = int((chunk_energies > threshold).sum())
        return loud_chunks > (self.chunks_below_percentile * self.segment_chunks)

    def high_energy_segments(self, segments):
        """Keep only the segments that pass the chunk-energy test.

        Chunk energies are computed once for all segments; the ``drop_percentile``
        quantile of those energies is the threshold, and a segment survives when
        enough of its chunks lie above it. An empty input is returned unchanged.
        """
        if segments.shape[0] == 0:
            return segments
        energies = torch.stack([self.segment_chunk_energies(segment) for segment in segments])
        threshold = self._energy_quantile(energies.flatten())
        loud_chunks = (energies > threshold).sum(dim=1)
        keep = loud_chunks > (self.chunks_below_percentile * self.segment_chunks)
        # boolean masking keeps the result on the same device as the input
        return segments[keep]

    def segment_energy_threshold(self, segments):
        """Return the ``drop_percentile`` quantile of the chunk energies of all segments."""
        chunk_energies = torch.cat([self.segment_chunk_energies(segment) for segment in segments])
        return self._energy_quantile(chunk_energies)

    def segment_chunk_energies(self, segment):
        """Return a 1-D tensor with the RMS of each of the ``segment_chunks`` chunks.

        Chunks are taken along axis 1 of ``segment``; leftover samples that do not fill
        a chunk are ignored.
        NOTE(review): the RMS covers every stem and channel in the chunk. The original
        code built an unused ``mix_track = chunk[0]``, so the mixture alone may have been
        intended - confirm before changing.
        """
        chunk_samples = segment.shape[1] // self.segment_chunks
        if chunk_samples == 0:
            raise ValueError("segment is shorter than segment_chunks samples")
        chunk_energies = []
        for index in range(self.segment_chunks):
            chunk = segment[:, index * chunk_samples:(index + 1) * chunk_samples]
            chunk_energies.append(torch.sqrt(torch.mean(torch.pow(chunk, 2))))
        return torch.stack(chunk_energies)

    def split_track(self, stems):
        """Split a track into overlapping, equally sized segments.

        The input is zero-padded first so the last segment is complete. The output is
        ``torch.stack`` of the slices, i.e. shape
        (num_segments, num_stems, segment_samples, num_channels).
        """
        stems = self.add_zero_padding(stems)
        step_in_samples = self._step_samples()
        last_start = stems.shape[1] - self.segment_samples
        segments = [stems[:, i:i+self.segment_samples] for i in range(0, last_start + 1, step_in_samples)]
        return torch.stack(segments)

    def add_zero_padding(self, stems):
        """Zero-pad the sample axis to segment_samples + k * step for the smallest k >= 0."""
        num_samples = stems.shape[1]
        step_in_samples = self._step_samples()
        if num_samples <= self.segment_samples:
            target = self.segment_samples
        else:
            extra = num_samples - self.segment_samples
            # ceiling division without floats
            target = self.segment_samples + -(-extra // step_in_samples) * step_in_samples
        return self.add_N_zeros(stems, target - num_samples)

    def add_N_zeros(self, stems, N):
        """Append ``N`` zero samples along axis 1; ``N <= 0`` returns ``stems`` unchanged."""
        if N <= 0:
            return stems
        # new_zeros matches the dtype and device of stems so the concatenation is valid
        zeros = stems.new_zeros((stems.shape[0], N, stems.shape[2]))
        return torch.cat((stems, zeros), dim=1)
    

In [8]:
class newMus(torch.utils.data.Dataset):
    """Dataset over MUSDB tracks that returns the stems of one track per item.

    Two construction modes exist:

    * ``filtered_indices`` and ``batch_size`` both given: the indices are used as-is
      and the database is not scanned for durations.
    * otherwise: every track is loaded once to record its length, tracks shorter than
      ``MIN_SEGMENTS_PER_TRACK`` segments are skipped, and a batch size is derived
      from the shortest remaining track.

    Args:
        musdb_root: root directory passed to ``musdb.DB``.
        split: MUSDB split name passed to ``musdb.DB``.
        subset: ``'train'`` or ``'test'``.
        filtered_indices: track indices to use, as a comma-separated string or an
            iterable of ints.
        batch_size: precomputed batch size.
        is_wav: passed to ``musdb.DB``.
        sample_rate: samples per second, used to convert ``segment_length``.
        segment_length: segment length in seconds.
        segment_chunks, discard_low_energy, segment_overlap, drop_percentile,
        chunks_below_percentile: stored on the instance; ``segment_overlap`` and
            ``drop_percentile`` feed the batch-size estimate.
    """
    # Tracks must be at least this many segment lengths long to be kept by init_durations().
    MIN_SEGMENTS_PER_TRACK = 9

    def __init__(self, musdb_root, split='train', subset='train', filtered_indices = None, batch_size = None, is_wav=False, sample_rate=44100, segment_length = 10, segment_chunks = 10, discard_low_energy = True, segment_overlap = 0.5, drop_percentile =  0.1, chunks_below_percentile = 0.5):
        assert(subset == 'train' or subset == 'test')
        self.mode = subset
        self.split = split
        self.mus = musdb.DB(musdb_root, subsets=subset, split=split, is_wav=is_wav)
        self.sample_rate = sample_rate
        self.discard_low_energy = discard_low_energy
        self.segment_length = segment_length
        self.segment_chunks = segment_chunks
        self.chunks_below_percentile = chunks_below_percentile
        self.segment_overlap = segment_overlap
        self.drop_percentile = drop_percentile
        self.segment_samples = int(self.segment_length * self.sample_rate)
        self.chunk_samples = int(self.segment_samples / self.segment_chunks)
        # Defined in both modes so later attribute access never raises AttributeError.
        self.durations = dict()
        self.shortest_duration = None
        if filtered_indices is None or batch_size is None:
            self.filtered_indices = dict()
            self.len = self.init_durations()
            self.shortest_duration = self.shortest_duration_in_samples(self.mus)
            self.batch_size = self.find_batch_size()
        else:
            if isinstance(filtered_indices, str):
                filtered_indices = filtered_indices.split(',')
            self.filtered_indices = [int(x) for x in filtered_indices]
            self.batch_size = batch_size
            self.len = len(self.filtered_indices)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the stems of the ``idx``-th kept track as a float tensor.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``; this also lets plain
                ``for item in dataset`` iteration terminate.
        """
        if idx < 0:
            idx += self.len
        if not 0 <= idx < self.len:
            raise IndexError("track index out of range")
        track = self.mus.tracks[self.filtered_indices[idx]]
        # track.stems holds the stems of the track, in the order stored by musdb
        return torch.Tensor(track.stems)

    def init_durations(self):
        """Record every track's length and index the tracks that are long enough.

        Fills ``self.durations`` (track index -> samples) and ``self.filtered_indices``
        (position -> track index) and returns the number of kept tracks.
        """
        min_samples = self.segment_samples * self.MIN_SEGMENTS_PER_TRACK
        pos = 0
        for idx, track in enumerate(self.mus.tracks):
            self.durations[idx] = track.stems.shape[1]
            if self.durations[idx] >= min_samples:
                self.filtered_indices[pos] = idx
                pos += 1
        return pos

    def find_batch_size(self):
        """Largest power of two not above the number of segments expected to survive filtering.

        The segment count of the shortest track is reduced by twice the drop percentile
        as a safety margin before rounding down to a power of two.

        Raises:
            ValueError: if no segment is expected to survive.
        """
        num_segments = self.num_segments_in_track(self.shortest_duration)
        usable = int(num_segments * (1 - 2 * self.drop_percentile))
        if usable < 1:
            raise ValueError("shortest track is too short to produce a batch")
        # bit_length gives the floor power of two exactly, without float log2
        return 2 ** (usable.bit_length() - 1)

    def num_segments_in_track(self, duration_in_samples):
        """Number of hops of size segment_samples * (1 - segment_overlap) needed to span the duration."""
        hops = duration_in_samples / (self.segment_samples * (1 - self.segment_overlap))
        return int(torch.ceil(torch.tensor(hops, dtype=torch.float64)))

    def shortest_duration_in_samples(self, mus):
        """Shortest recorded duration among the kept tracks (``mus`` is unused).

        Returns 100000000 when no kept track has a recorded duration, as before.
        """
        if isinstance(self.filtered_indices, dict):
            kept = set(self.filtered_indices.values())
        else:
            kept = set(self.filtered_indices)
        return min((dur for idx, dur in self.durations.items() if idx in kept), default=100000000)

    def longest_duration_in_samples(self, mus):
        """Longest recorded duration over all scanned tracks (``mus`` is unused)."""
        return max(self.durations.values())


In [9]:
class LightningModel(pl.LightningModule):
    """Lightning wrapper around the band-split / BLSTM source-separation network.

    ``hparams`` is a dict; the keys read here are ``mus_path``, ``bandwidths``
    (comma-separated ints), ``n_mels``, ``bandwidth_freq_out_size``, ``time_steps``,
    the ``conv_1_*`` / ``conv_2_*`` kernel and stride entries, and the values
    forwarded to ``DataHandler``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters()
        self.mus_path = hparams['mus_path']
        self.bandwidths = [int(bandwidth) for bandwidth in hparams['bandwidths'].split(',')]
        self.step = 0
        self.n_mels = hparams['n_mels']
        self.N = hparams['bandwidth_freq_out_size']
        self.K = len(self.bandwidths)
        self.time_steps = hparams['time_steps']
        self.transforms = Transforms().double()
        self.kernel1 = hparams['conv_1_kernel_size']
        self.stride1 = hparams['conv_1_stride']
        self.kernel2 = hparams['conv_2_kernel_size']
        self.stride2 = hparams['conv_2_stride']
        # Kept for compatibility; nothing in this class assigns them.
        self.training_dataloader = None
        self.testing_dataloader = None
        self.validation_dataloader = None
        self.bandsplit = BandSplit(self.bandwidths, self.N).double()
        self.conv1 = ConvolutionLayer(self.K, self.K, self.kernel1, self.stride1).double()
        self.conv2 = ConvolutionLayer(self.K, self.K, self.kernel2, self.stride2).double()
        self.conv3 = ConvolutionLayer(1,8,kernel_size=(1,11),stride=(1,3)).double()
        self.conv4 = ConvolutionLayer(8,22,kernel_size=(1,3),stride=(1,2)).double()
        self.conv5 = ConvolutionLayer(1,8,kernel_size=(1,11),stride=(1,3)).double()
        self.conv6 = ConvolutionLayer(8,22,kernel_size=(1,3),stride=(1,2)).double()
        # NOTE(review): the sizes below (70, 63, 64, 96, 76, 128, 32) are hard-coded;
        # presumably they match the default hparams - confirm before changing any hparam.
        self.blstms1 = AlternatingBLSTMs(self.K, 70, 63, 64).double()
        self.blstms2 = AlternatingBLSTMs(self.K, 70, 96, 64).double()
        self.blstms3 = AlternatingBLSTMs(self.K, 70, 76, 63 ).double()
        self.deconv1 = TransposeConvolutionLayer(self.K, self.K, self.kernel2, self.stride2).double()
        self.deconv2 = TransposeConvolutionLayer(self.K, self.K, self.kernel1, self.stride1).double()
        self.masks = MaskEstimation(self.bandwidths, 128,32).double()
        self.data_handler = DataHandler(hparams['training_batch_size'], hparams['shortest_duration'], hparams['sampling_rate'],
            hparams['longest_duration'], hparams['segment_overlap'], hparams['segment_chunks'],
            hparams['segment_length'], hparams['chunks_below_percentile'], hparams['drop_percentile'])

    def forward(self, x):
        # Identity; the network itself lives in forward_pass().
        return x

    def _shared_step(self, batch, log_name):
        """Segment one item, run the network and log the loss under ``log_name``."""
        if self.step % 5 == 0:
            torch.cuda.empty_cache()
        self.step += 1

        # Segmentation lives on the DataHandler built in __init__; the dataloader
        # attributes are never assigned, so they cannot be used for this.
        data = self.data_handler.batchize_training_item(batch)
        stfts, chromas, mfccs = self.transforms(data)
        predicted_sources = self.forward_pass(stfts, chromas, mfccs)
        # NOTE(review): neither `real_sources` nor `self.loss` is defined anywhere in the
        # visible source, so this line still fails; the intended target tensor and loss
        # function must be supplied by the author.
        loss = self.loss(predicted_sources, real_sources[:,3])
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid_loss")

    def test_step(self, batch, batch_idx):
        pass

    def forward_pass(self, X, chromas, mfccs):
        """Run the separation network on STFTs plus the two auxiliary feature maps."""
        X1 = self.bandsplit(X)
        batch_size = X1.shape[0]
        # Original author's shape note: torch.Size([32, 22, 431, 128]) (batch_size, num_bands, time_steps, freq_N)
        X2 = self.conv1(X1)
        X3 = self.conv2(X2)
        mfccs = mfccs.reshape(batch_size,1,self.n_mels,-1)
        mfccs = self.conv3(mfccs)
        mfccs = self.conv4(mfccs)
        chromas = chromas.reshape(batch_size,1,12,-1)
        chromas = self.conv5(chromas)
        chromas = self.conv6(chromas)
        # Each BLSTM stage returns (features, lstm_state); the state is discarded.
        X, _ = self.blstms1(X3)
        xmfccs = torch.cat((mfccs,X), 2)
        X, _ = self.blstms2(xmfccs)
        xchromas = torch.cat((chromas,X),2)
        X, _ = self.blstms3(xchromas)
        # Skip connections from the encoder convolutions.
        X = self.deconv1(X + X3)
        X = self.deconv2(X + X2)
        # Repeat the last slice along axis 3 so the mask input is one step longer.
        X = self.masks(torch.cat((X, X[:,:,:,-1].unsqueeze(3)), 3))
        return X

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr = 0.001)
        return optimizer
    

        
class Chroma(nn.Module):
    """Placeholder for a chroma-feature transform; ``forward`` is not implemented yet."""
    def __init__(self, n_fft, sampling_rate):
        # nn.Module must be initialised before the instance can be called.
        super().__init__()
        self.n_fft = n_fft
        self.sampling_rate = sampling_rate

    def forward(self, x):
        # x is a spectrogram with shape (..., T, F) according to the original note.
        # TODO: not implemented; currently returns None.
        pass

class Transforms(nn.Module):
    """Resamples audio and derives a spectrogram plus two mel-scaled copies of it.

    ``forward`` returns ``(stft, chromas, mfccs)``. Both ``chromas`` and ``mfccs`` come
    from the same ``MelScale`` applied to the same spectrogram, so they hold equal values.
    NOTE(review): presumably the input's last axis is time - confirm against the caller.
    """
    def __init__(self, input_freq = 44100, resample_freq = 16000, n_fft = 2048, hop_length = 1024, win_length=2048, n_mels = 32):
        super().__init__()
        stft_bins = n_fft // 2 + 1
        self.resample = Resample(input_freq, resample_freq)
        self.stft = Spectrogram(n_fft = n_fft, hop_length = hop_length, win_length = win_length)
        self.mel = MelScale(sample_rate = resample_freq, n_mels = n_mels, n_stft = stft_bins)

    def forward(self, X):
        resampled = self.resample(X)
        spectrogram = self.stft(resampled)
        # The mel projection is evaluated twice, once per returned feature map.
        mel_for_mfccs = self.mel(spectrogram)
        mel_for_chromas = self.mel(spectrogram)
        return spectrogram, mel_for_chromas, mel_for_mfccs

    
def post_conv_dimensions(self,N,time_steps,in_channels):
    """Return the shape produced by ``self.conv2(self.conv1(x))`` for a random probe.

    The probe has shape (1, in_channels, N, time_steps), is moved to ``self.device``
    and cast to double. ``self`` is any object exposing ``device``, ``conv1`` and ``conv2``.
    """
    probe = torch.randn(1, in_channels, N, time_steps)
    probe = probe.to(self.device).double()
    after_first = self.conv1(probe)
    after_second = self.conv2(after_first)
    return after_second.shape
        
class ConvolutionLayer(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    ``dtype`` is accepted for signature compatibility but is not used here.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0, dtype='double'):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    
class TransposeConvolutionLayer(nn.Module):
    """ConvTranspose2d followed by BatchNorm2d and an in-place ReLU.

    ``dtype`` is accepted for signature compatibility but is not used here.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, dtype='double'):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    

# This class defines a module that runs the input through a band biLSTM and then a temporal biLSTM.
class AlternatingBLSTMs(nn.Module):
    """Band-wise BLSTM followed by a time-wise BLSTM.

    ``forward`` returns whatever ``TemporalBiLSTM`` returns, i.e. a
    ``(features, lstm_state)`` tuple. ``axis`` is accepted but unused.
    """
    def __init__(self, num_bands, time_steps, N, out_size, axis=1):
        super(AlternatingBLSTMs, self).__init__()
        self.band_blstm = BandBiLSTM(num_bands, time_steps, N)
        self.temporal_blstm = TemporalBiLSTM(num_bands, time_steps, N, out_size)
        self.num_bands = num_bands
        self.time_steps = time_steps
        self.N = N

    def forward(self, x):
        # Input shape per the original note: (batch_size, num_bands, N, time_steps)
        batch_size = x.shape[0]
        # One sequence element per band: (batch_size, num_bands, N * time_steps)
        x = x.reshape(batch_size, self.num_bands, -1)
        # BandBiLSTM reshapes its output to (batch_size, num_bands, time_steps, N)
        x = self.band_blstm(x)
        # Move time in front of bands BEFORE flattening; a bare reshape would mix
        # values from different bands and time steps into the same row.
        x = x.permute(0, 2, 1, 3).reshape(batch_size, self.time_steps, -1)
        return self.temporal_blstm(x)
    
# This class defines a module that runs the input, with shape (num_bands, num_timesteps, N), through a normalization layer, then a temporal biLSTM, then a fully connected layer.
# Then, the output of that layer is of the same shape as the input to the module, which will be fed into a similar structure, but this time with a band biLSTM, following the same normalization, biLSTM, FC structure.
class BandBiLSTM(nn.Module):
    """GroupNorm -> bidirectional LSTM across bands -> per-row linear layer.

    Input is (batch, num_bands, time_steps * N): the bands are the LSTM sequence
    axis (``batch_first=True``). Output is (batch, num_bands, time_steps, N).
    The reshape requires ``time_steps * N`` to be even, because the LSTM emits
    ``2 * (time_steps * N // 2)`` features. ``axis`` is stored but unused.
    """
    def __init__(self, num_bands, time_steps, N, axis=1):
        super(BandBiLSTM, self).__init__()
        self.norm = nn.GroupNorm(num_bands, num_bands)
        self.input_size = time_steps * N
        self.hidden_size = self.input_size // 2
        self.bilstm = nn.LSTM(self.input_size, self.hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(N, N)
        self.axis = axis
        self.N = N
        self.num_bands = num_bands
        self.time_steps = time_steps

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.norm(x)
        # The LSTM state is not needed here. No residual copy is taken: the residual
        # connection is disabled, so cloning the activations would only waste memory.
        x, _ = self.bilstm(x)
        # (batch_size, num_bands, 2 * hidden_size) -> (batch_size, num_bands, time_steps, N)
        x = x.reshape(batch_size, self.num_bands, self.time_steps, self.N)
        return self.fc(x)
    
class TemporalBiLSTM(nn.Module):
    """GroupNorm -> bidirectional LSTM across time -> per-band linear layer.

    Input is (batch, time_steps, num_bands * N): time is the LSTM sequence axis
    (``batch_first=True``). ``forward`` returns ``(features, lstm_state)`` where
    ``features`` is (batch, num_bands, out, time_steps). ``num_bands * out`` must be
    even for the reshape to succeed. ``axis`` is stored but unused.
    """
    def __init__(self, num_bands, time_steps, N, out, axis=1):
        super(TemporalBiLSTM, self).__init__()
        self.norm = nn.GroupNorm(time_steps, time_steps)
        self.input_size = num_bands * N
        self.hidden_size = num_bands * out // 2
        self.out = out
        self.bilstm = nn.LSTM(self.input_size, self.hidden_size, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(out, out)
        self.axis = axis
        self.N = N
        self.time_steps = time_steps
        self.num_bands = num_bands

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.norm(x)
        x, lstm_vars = self.bilstm(x)
        # LSTM output is (batch_size, time_steps, num_bands * out). Split the feature
        # axis into (num_bands, out) first and only then swap time and bands; reshaping
        # straight to (batch, bands, time, out) would spread each time step over
        # several bands.
        x = x.reshape(batch_size, self.time_steps, self.num_bands, self.out)
        x = x.permute(0, 2, 1, 3)
        # (batch_size, num_bands, time_steps, out)
        x = self.fc(x)
        # (batch_size, num_bands, out, time_steps)
        x = x.permute(0,1,3,2)
        return x, lstm_vars

class BandSplit(torch.nn.Module):
    """Splits a spectrogram into frequency bands and projects each band to ``N`` features.

    Input is 5-D; the original note gives torch.Size([16, 2, 1025, 431, 2]), read here
    as (batch, channels, freq, time, real/imag). The last axis must have size 2 to match
    the ``LayerNorm(2 * bandwidth)`` layers. Output is
    (batch * channels, num_bands, N, time).
    """
    def __init__(self, bandwidths, N):
        super(BandSplit, self).__init__()
        self.bandwidths = bandwidths
        self.norm_layers = torch.nn.ModuleList([torch.nn.LayerNorm(2 * bandwidth) for bandwidth in self.bandwidths])
        self.fc_layers = torch.nn.ModuleList([torch.nn.Linear(2 * bandwidth, N) for bandwidth in self.bandwidths])

    def forward(self, X):
        batch, channels, _, frames, parts = X.shape
        subband_features = []
        start_index = 0
        for i, bandwidth in enumerate(self.bandwidths):
            end_index = start_index + bandwidth
            # (batch, channels, bandwidth, time, parts)
            subband = X[:, :, start_index:end_index, :]
            start_index = end_index
            # Put time ahead of the features so that each row fed to LayerNorm/Linear
            # holds the values of exactly one frame: (batch, channels, time, parts, bandwidth)
            subband = subband.permute(0, 1, 3, 4, 2)
            subband = subband.reshape(batch * channels, frames, parts * bandwidth)
            subband_features.append(self.fc_layers[i](self.norm_layers[i](subband)))

        # (batch * channels, num_bands, time, N) -> (batch * channels, num_bands, N, time)
        Z = torch.stack(subband_features, dim=1)
        return Z.permute(0,1,3,2)
    
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear, applied to the last axis."""
    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        stages = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.MLP = nn.Sequential(*stages)

    def forward(self, x):
        return self.MLP(x)
    
    
class MaskEstimation(nn.Module):
    """Per-band LayerNorm + MLP that expands band features back to frequency bins.

    Input is (batch, num_bands, N, T). Band ``i`` is mapped to ``2 * bandwidths[i]``
    values per frame, read as (bandwidth, 2). Output is
    (batch // 2, 2, sum(bandwidths), T, 2), so the incoming batch must be even.
    ``batch_size`` is stored for compatibility; the actual batch is taken from the input.
    """
    def __init__(self, bandwidths, N, batch_size):
        super(MaskEstimation, self).__init__()
        self.num_bands = len(bandwidths)
        self.bandwidths = bandwidths
        self.batch_size = batch_size
        self.norm_layers = torch.nn.ModuleList([torch.nn.LayerNorm(N) for bandwidth in self.bandwidths])
        self.MLP_layers = torch.nn.ModuleList([MLP(N, bandwidth * 2, N * 2) for bandwidth in self.bandwidths])

    def forward(self, x):
        time_steps = x.shape[3]
        # (num_bands, batch, T, N)
        x = x.permute(1, 0 , 3, 2)
        batch = x.shape[1]
        out = []
        for i, bandwidth in enumerate(self.bandwidths):
            y = self.norm_layers[i](x[i])
            y = self.MLP_layers[i](y)
            # (batch, T, 2 * bandwidth) -> (batch, T, bandwidth, 2)
            out.append(y.reshape(batch, time_steps, bandwidth, 2))
        # Concatenate the bands along frequency: (batch, T, F, 2)
        out = torch.cat(out, 2)
        # Swap to (batch, F, T, 2) before the final reshape so that time and frequency
        # are not interleaved, and size the result from the real batch rather than
        # the constructor value so partial batches work.
        out = out.permute(0, 2, 1, 3)
        return out.reshape(batch // 2, 2, sum(self.bandwidths), time_steps, 2)

In [10]:
# Widths, in STFT bins, of the frequency bands the model splits the spectrogram into.
BANDWIDTHS = "20,20,20,30,30,30,30,30,30,30,30,30,30,50,50,50,50,70,70,100,100,125"

# Hyperparameters consumed by LightningModel and DataHandler.
hparams = {
        "mus_path": "musdb/",
        # Derived from the list so the count cannot drift from it (it was hard-coded as 23).
        "num_bandwidths": len(BANDWIDTHS.split(",")),
        "bandwidths": BANDWIDTHS,
        "bandwidth_freq_out_size": 128,
        "n_fft": 2048,
        "hop_length": 1024,
        "win_length": 2048,
        "conv_1_kernel_size":(1,7),
        "conv_1_stride":(1,3),
        "conv_2_kernel_size":(4,4),
        "conv_2_stride":(2,2),
        "conv_3_kernel_size":(1,7),
        "conv_3_stride":(1,3),
        "conv_3_ch_out_1":8,
        "time_steps": 431,
        "freq_bands": 1025,
        "n_mels": 32,
        "input_sampling_rate": 44100,
        "resampling_rate": 16000,
        "shortest_duration" : 5019648,
        "longest_duration" : 20000000,
        "segment_length" : 10,
        "sampling_rate" : 44100,
        "discard_low_energy" : True,
        "drop_percentile" : 0.1,
        'chunks_below_percentile' : 0.5,
        'segment_overlap' : 0.5,
        'segment_chunks' : 10,
        'training_batch_size' : 16,
        'testing_batch_size' : 64,
        'filtered_training_indices' : '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,50,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85',
        'filtered_validation_indices' : '0,1,2,3,4,5,6,7,8,9,10,11,12,13'
    }

In [11]:
# Validation tracks: the 'valid' split of the MUSDB 'train' subset, using the
# precomputed track indices so the database is not scanned for durations.
musValidation = newMus(
    musdb_root='musdb/',
    split='valid',
    subset='train',
    filtered_indices=hparams['filtered_validation_indices'],
    batch_size=16,
)

In [12]:
# Training tracks: default split and subset ('train'/'train'), using the
# precomputed track indices so the database is not scanned for durations.
musTraining = newMus(
    musdb_root='musdb/',
    filtered_indices=hparams['filtered_training_indices'],
    batch_size=16,
)

In [13]:
# Build the model from the hyperparameter dict defined above.
lightning = LightningModel(hparams=hparams)

In [14]:
# Train for one epoch on two GPUs.
# The recorded output of this cell ends with
# "ProcessExitedException: process 1 terminated with signal SIGSEGV" during sanity checking.
# NOTE(review): musTraining / musValidation are Dataset objects, not DataLoaders -
# confirm that Trainer.fit accepts them directly in the installed Lightning version.
trainer = pl.Trainer(max_epochs=1, accelerator='gpu', devices = 2)
trainer.fit(lightning, musTraining, musValidation)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name       | Type                      | Params
----------------------------------------------------------
0  | transforms | Transforms                | 0     
1  | bandsplit  | BandSplit                 | 269 K 
2  | conv1      | ConvolutionLayer          | 3.5 K 
3  | conv2      | ConvolutionLayer          | 7.8 K 
4  | conv3      | ConvolutionLayer 

Sanity Checking: 0it [00:00, ?it/s]

ProcessExitedException: process 1 terminated with signal SIGSEGV

In [None]:
def plot_waveform(waveform, sr, title="Waveform"):
    """Plot every channel of a waveform against time, one subplot per channel.

    Args:
        waveform: 2-D tensor of shape (num_channels, num_frames).
        sr: sample rate used to convert frame indices to seconds.
        title: figure title.

    Requires ``matplotlib.pyplot`` to be available as ``plt`` in the notebook namespace.
    """
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr

    # squeeze=False always yields a 2-D array of axes, so one channel and several
    # channels are handled by the same loop.
    figure, axes = plt.subplots(num_channels, 1, squeeze=False)
    for channel, ax in enumerate(axes[:, 0]):
        ax.plot(time_axis, waveform[channel], linewidth=1)
        ax.grid(True)
    figure.suptitle(title)
    plt.show(block=False)


def plot_spectrogram(specgram, title=None, ylabel="freq_bin"):
    """Show a power spectrogram in decibels as an image with a colour bar.

    ``title`` falls back to "Spectrogram (db)" when empty or None. Relies on ``plt``
    and ``librosa`` being available in the notebook namespace.
    """
    fig, axs = plt.subplots(1, 1)
    heading = title if title else "Spectrogram (db)"
    axs.set_title(heading)
    axs.set_ylabel(ylabel)
    axs.set_xlabel("frame")
    decibels = librosa.power_to_db(specgram)
    im = axs.imshow(decibels, origin="lower", aspect="auto")
    fig.colorbar(im, ax=axs)
    plt.show(block=False)
    
def check_memory():
    """Print total, reserved, allocated and reserved-but-unallocated memory of CUDA device 0 in GiB."""
    gib = 1024 ** 3
    total = torch.cuda.get_device_properties(0).total_memory
    reserved = torch.cuda.memory_reserved(0)
    allocated = torch.cuda.memory_allocated(0)
    report = (
        ("Total: ", total),
        ("Reserved ", reserved),
        ("Allocated: ", allocated),
        ("free: ", reserved - allocated),  # free inside reserved
    )
    for label, amount in report:
        print(label, amount / gib)
    
def free_memory(var):
    """Drop the local reference to ``var`` and empty the CUDA caching allocator.

    ``del var`` only unbinds this function's own name; any reference the caller
    still holds keeps the object alive.
    """
    del var
    torch.cuda.empty_cache()