In [None]:
# Colab dependencies (unpinned). soundfile, imported in the utilities cell, is
# not listed here; presumably it is installed as a librosa dependency — verify.
!pip3 install librosa pescador matplotlib numpy torch tqdm

In [None]:
# import shutil
# import os

# def copy_wav_files(source_folder, destination_folder):
#     # Ensure the destination folder exists
#     os.makedirs(destination_folder, exist_ok=True)

#     # Iterate over all files in the source folder
#     for filename in os.listdir(source_folder):
#         # Check if the file has a .wav extension
#         if filename.lower().endswith('.wav'):
#             source_file = os.path.join(source_folder, filename)
#             # Check if it's a file and not a folder
#             if os.path.isfile(source_file):
#                 destination_file = os.path.join(destination_folder, filename)
#                 shutil.copy2(source_file, destination_file)  # copy2 preserves metadata
#                 print(f"Copied {filename} to {destination_folder}")

# # Example usage
# source_folder = '/content/drive/MyDrive/snare_samples/16bit_mono_wavs_clipped/validation'
# destination_folder = '/content/drive/MyDrive/ECS271/snare_sample_valid_wavOnly'
# copy_wav_files(source_folder, destination_folder)

In [None]:
# Mount Google Drive: the dataset (target_signals_dir) and the destination of
# the final `!cp -r ... /content/drive/MyDrive/` cells live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Parameters


In [None]:
import os
import random
import torch
import numpy as np
import logging

# Paths
# Dataset root; the run cell reads its "training" and "validation" sub-folders.
target_signals_dir = '/content/drive/MyDrive/ECS271/snare_samples/16bit_mono_wavs_clipped_22050'  # Update this path

# Model Parameters
model_prefix = 'exp1'  # checkpoint file is gan_<model_prefix>.tar
n_iterations = 160000  # total training iterations (alternatives: 100000, 1000, 10000, 20000)
use_batchnorm = False  # BatchNorm layers in both G and D
lr_g = 1e-5  # generator Adam learning rate
lr_d = 1e-5  # discriminator Adam learning rate
beta1 = 0.5  # Adam betas
beta2 = 0.9
decay_lr = False  # if True, scale both learning rates by (1 - iter / n_iterations)
generator_batch_size_factor = 1  # generator batch = batch_size * this factor
n_critic = 1  # discriminator updates per generator update
validate = False  # score validation batches with D every store_cost_every iterations
p_coeff = 10  # gradient-penalty weight
batch_size = 30
noise_latent_dim = 100
model_capacity_size = 32  # base channel width d of G and D
store_cost_every = 300  # record costs every N iterations
progress_bar_step_iter_size = 400  # advance the progress bar every N iterations
window_length = 32768  # samples per example; must be 16384, 32768 or 65536
sampling_rate = 22050  # audio is resampled to this rate on load (alternative: 16000)
normalize_audio = True  # rescale loaded audio whose peak magnitude exceeds 1
num_channels = 1

take_backup = True  # save / resume checkpoint gan_<model_prefix>.tar
backup_every_n_iters = 1000
save_samples_every = 1000
# Generated samples are written under this folder, relative to the working
# directory; the last cells of the notebook copy it to Drive.
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

#############################
# Logger init
#############################
LOGGER = logging.getLogger('wavegan')
LOGGER.setLevel(logging.DEBUG)
#############################
# Torch Init and seed setting
#############################
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")
# Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.
manual_seed = 826
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)
if cuda:
    torch.cuda.manual_seed(manual_seed)
    torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import pescador
from tqdm import tqdm

# Utils


In [None]:
import soundfile as sf

def get_recursive_files(folderPath, ext):
    """Return paths of all files under ``folderPath`` whose names end with ``ext``.

    Sub-directories are searched recursively (depth-first). Entries are
    visited in sorted order so the returned list is the same on every run;
    ``os.listdir`` order is arbitrary, which would otherwise make the seeded
    data pipeline built from this list non-reproducible. The suffix test is
    case-insensitive, so e.g. ``.WAV`` files match ``"wav"``.

    Args:
        folderPath: directory to search.
        ext: filename suffix to match, e.g. ``"wav"`` or ``".wav"``.

    Returns:
        List of file paths joined onto ``folderPath``.
    """
    suffix = ext.lower()
    out_files = []
    for name in sorted(os.listdir(folderPath)):
        path = os.path.join(folderPath, name)
        if os.path.isdir(path):
            out_files += get_recursive_files(path, ext)
        elif name.lower().endswith(suffix):
            out_files.append(path)

    return out_files


def make_path(output_path):
    """Ensure ``output_path`` exists as a directory (parents included); return it unchanged."""
    os.makedirs(output_path, exist_ok=True)
    return output_path


#############################
# Plotting utils
#############################
def visualize_audio(audio_tensor, is_monphonic=False):
    """Plot every item of a batch of audio and save the figure.

    Args:
        audio_tensor: tensor indexed as ``audio_tensor[i][0]`` for item ``i``;
            only channel 0 of each item is plotted (presumably shaped
            (batch, channels, samples) — TODO confirm for multi-channel use).
        is_monphonic: if True draw time-domain waveforms, otherwise draw
            linear-frequency dB spectrograms. (Name kept for compatibility.)

    The figure is written to ``visualization/interpolation.png`` (overwritten
    on every call), shown, and then closed so repeated calls during training
    do not keep accumulating open figures.
    """
    input_audios = audio_tensor.detach().cpu().numpy()
    n_plots = len(input_audios)
    n_cols = 2
    # Size the grid to the batch: a fixed 10x2 grid fails for more than 20
    # items and leaves a mostly blank image for small batches.
    n_rows = max(1, (n_plots + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(18, 5 * n_rows))
    for i, audio in enumerate(input_audios):
        plt.subplot(n_rows, n_cols, i + 1)
        if is_monphonic:
            plt.title("Monophonic %i" % (i + 1))
            librosa.display.waveshow(audio[0], sr=sampling_rate)
        else:
            D = librosa.amplitude_to_db(np.abs(librosa.stft(audio[0])), ref=np.max)
            # Pass sr so the frequency axis is labelled for this sampling rate.
            librosa.display.specshow(D, y_axis="linear", sr=sampling_rate)
            plt.colorbar(format="%+2.0f dB")
            plt.title("Linear-frequency power spectrogram %i" % (i + 1))
    os.makedirs("visualization", exist_ok=True)
    fig.savefig("visualization/interpolation.png")
    plt.show()
    plt.close(fig)


def visualize_loss(loss_1, loss_2, first_legend, second_legend, y_label):
    """Plot two loss curves on one axis, save to ``visualization/loss.png`` and show.

    Args:
        loss_1, loss_2: sequences of values plotted against their index.
        first_legend, second_legend: legend labels (also used in the title).
        y_label: y-axis label.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title("{} and {} Loss During Training".format(first_legend, second_legend))
    ax.plot(loss_1, label=first_legend)
    ax.plot(loss_2, label=second_legend)
    ax.set_xlabel("iterations")
    ax.set_ylabel(y_label)
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    os.makedirs("visualization", exist_ok=True)
    # Save before showing: with the notebook inline backend plt.show() closes
    # the figure, so saving afterwards wrote an empty image.
    fig.savefig("visualization/loss.png")
    plt.show()


def latent_space_interpolation(model, n_samples=10):
    """Decode ``n_samples`` points on the line between two random latents and plot them.

    Interpolates ``alpha * z0 + (1 - alpha) * z1`` for ``alpha`` evenly spaced
    in [0, 1], runs ``model`` on the batch without gradients and passes the
    result to ``visualize_audio`` as waveforms.

    The model is switched to eval mode for the forward pass (so layers such as
    BatchNorm use, and do not update, running statistics) and its previous
    train/eval mode is restored afterwards, since this is also called in the
    middle of training.
    """
    z_test = sample_noise(2)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            interpolates = torch.stack(
                [
                    alpha * z_test[0] + ((1 - alpha) * z_test[1])
                    for alpha in np.linspace(0, 1, n_samples)
                ]
            )
            generated_audio = model(interpolates)
    finally:
        model.train(was_training)
    visualize_audio(generated_audio, True)


#############################
# Wav files utils
#############################
def load_wav(wav_file_path):
    """Load an audio file with librosa and pad it to at least ``window_length``.

    * Decoding and resampling to ``sampling_rate`` are done by
      ``librosa.load`` (mono, its default); the format/bit depth is whatever
      librosa can decode.
    * If ``normalize_audio`` is set, the signal is divided by its peak
      magnitude only when that peak exceeds 1; signals already within
      [-1, 1] are left unchanged.
    * Clips shorter than ``window_length`` are zero-padded to exactly
      ``window_length`` (half on each side, the odd sample on the right);
      longer clips are returned at full length.

    Returns:
        float32 numpy array.

    Raises:
        Whatever loading/normalising raised, re-raised after being logged.
    """
    try:
        audio_data, _ = librosa.load(wav_file_path, sr=sampling_rate)

        if normalize_audio:
            max_mag = np.max(np.abs(audio_data))
            if max_mag > 1:
                audio_data /= max_mag
    except Exception as e:
        LOGGER.error("Could not load {}: {}".format(wav_file_path, str(e)))
        # Bare raise re-raises the active exception with its traceback intact.
        raise
    audio_len = len(audio_data)
    if audio_len < window_length:
        pad_length = window_length - audio_len
        left_pad = pad_length // 2
        right_pad = pad_length - left_pad
        audio_data = np.pad(audio_data, (left_pad, right_pad), mode="constant")

    return audio_data.astype("float32")


def sample_audio(audio_data, start_idx=None, end_idx=None):
    """Cut a ``window_length``-sample window out of ``audio_data``.

    If the clip is exactly ``window_length`` long it is returned whole, and
    the indices are returned as passed in (``None`` by default). Otherwise,
    when no indices are given, the start is drawn uniformly from
    ``[0, max(1, (len(audio_data) - window_length) // 2))``, i.e. only the
    first half of the valid start positions is sampled.

    Returns:
        (window as float32, start_idx, end_idx)

    Raises:
        ValueError: the clip is shorter than ``window_length`` and no indices
            were given.
        AssertionError: the window contains NaN.
    """
    audio_len = len(audio_data)
    if audio_len == window_length:
        # If we only have a single 1*window_length audio, just yield.
        sample = audio_data
    else:
        # Sample a random window from the audio
        if start_idx is None or end_idx is None:
            if audio_len < window_length:
                raise ValueError(
                    "audio has {} samples, fewer than window_length={}".format(
                        audio_len, window_length
                    )
                )
            # Clamp the exclusive upper bound to 1: for a clip one sample
            # longer than the window, (1 // 2) == 0 and randint(0, 0) raises.
            high = max(1, (audio_len - window_length) // 2)
            start_idx = np.random.randint(0, high)
            end_idx = start_idx + window_length
        sample = audio_data[start_idx:end_idx]
    sample = sample.astype("float32")
    assert not np.any(np.isnan(sample))
    return sample, start_idx, end_idx


def sample_buffer(buffer_data, start_idx=None, end_idx=None):
    """Cut a ``window_length``-sample window out of a packed byte buffer.

    The buffer is treated as consecutive 4-byte samples (presumably float32 —
    confirm with whatever produces it); indices are in samples, not bytes.
    If the buffer holds exactly ``window_length`` samples it is returned
    whole. Otherwise, when no indices are given, the start is drawn uniformly
    from ``[0, max(1, (n_samples - window_length) // 2))``.

    Returns:
        (byte slice, start_idx, end_idx)

    Raises:
        ValueError: the buffer holds fewer than ``window_length`` samples and
            no indices were given.
    """
    audio_len = len(buffer_data) // 4
    if audio_len == window_length:
        # If we only have a single 1*window_length audio, just yield.
        sample = buffer_data
    else:
        # Sample a random window from the audio
        if start_idx is None or end_idx is None:
            if audio_len < window_length:
                raise ValueError(
                    "buffer has {} samples, fewer than window_length={}".format(
                        audio_len, window_length
                    )
                )
            # Clamp the exclusive upper bound to 1: randint(0, 0) raises.
            high = max(1, (audio_len - window_length) // 2)
            start_idx = np.random.randint(0, high)
            end_idx = start_idx + window_length
        sample = buffer_data[start_idx * 4 : end_idx * 4]
    return sample, start_idx, end_idx


def wav_generator(file_path):
    """Endlessly yield ``{"single": window}`` dicts of random windows from one file.

    The file is loaded once via ``load_wav`` (on the first ``next()``); every
    step draws a fresh window with ``sample_audio``.
    """
    audio_data = load_wav(file_path)
    while True:
        window = sample_audio(audio_data)[0]
        yield {"single": window}


def create_stream_reader(single_signal_file_list):
    """Build a pescador batch stream over all files.

    One ``pescador.Streamer(wav_generator, path)`` is created per file; the
    streamers are combined with ``pescador.ShuffledMux`` and buffered into
    groups of ``batch_size`` with ``pescador.buffer_stream``, whose result is
    returned.
    """
    streamers = [
        pescador.Streamer(wav_generator, path) for path in single_signal_file_list
    ]
    return pescador.buffer_stream(pescador.ShuffledMux(streamers), batch_size)


def save_samples(epoch_samples, epoch):
    """Write generated samples to ``<output_dir>/<epoch>/<k>.wav`` (k starts at 1).

    Args:
        epoch_samples: iterable of per-sample arrays laid out as
            (channels, samples), e.g. a generator output batch converted with
            ``.cpu().numpy()``.
        epoch: label used as the sub-directory name; the training loop passes
            the iteration index.
    """
    sample_dir = make_path(os.path.join(output_dir, str(epoch)))

    for idx, sample in enumerate(epoch_samples, start=1):
        output_path = os.path.join(sample_dir, "{}.wav".format(idx))
        # soundfile expects (frames, channels): transposing keeps every
        # channel. For mono, (1, N) becomes (N, 1), i.e. the same mono file.
        sf.write(output_path, np.asarray(sample).T, sampling_rate)


#############################
# Sampling from model
#############################
def sample_noise(size):
    """Draw ``size`` standard-normal latent vectors.

    Returns a float32 tensor of shape (size, noise_latent_dim) on ``device``.
    """
    return torch.randn(size, noise_latent_dim, dtype=torch.float32, device=device)


#############################
# Model Utils
#############################


def update_optimizer_lr(optimizer, lr, decay):
    """Set the learning rate of every parameter group of ``optimizer`` to ``lr * decay``."""
    scaled_lr = lr * decay
    for group in optimizer.param_groups:
        group["lr"] = scaled_lr


def gradients_status(model, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model`` (freeze/unfreeze)."""
    for param in model.parameters():
        param.requires_grad_(flag)


def weights_init(m):
    """Per-module initialiser intended for ``model.apply(weights_init)``.

    * ``nn.Conv1d``: weights drawn from N(0, 0.02); bias (if any) set to 0.
    * ``nn.Linear``: bias (if any) set to 0; weights are left unchanged.
    Other module types (``nn.ConvTranspose1d`` is not an ``nn.Conv1d``) are
    not touched.
    """
    if isinstance(m, nn.Conv1d):
        m.weight.data.normal_(0.0, 0.02)
        # Layers built with bias=False have ``bias is None``.
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        if m.bias is not None:
            m.bias.data.fill_(0)


#############################
# Creating Data Loader and Sampler
#############################
class WavDataLoader:
    """Iterator of real-audio batches drawn from every matching file under a folder.

    Files are collected with ``get_recursive_files`` and streamed through
    ``create_stream_reader``. Each ``next()`` returns the batch's ``"single"``
    array as a float tensor on ``device`` with a channel axis inserted at
    position 1.

    Args:
        folder_path: root folder, searched recursively.
        audio_extension: filename suffix passed to ``get_recursive_files``.
    """

    def __init__(self, folder_path, audio_extension="wav"):
        self.signal_paths = get_recursive_files(folder_path, audio_extension)
        self.data_iter = None
        self.initialize_iterator()

    def initialize_iterator(self):
        """(Re)create the batch stream and keep an iterator over it."""
        stream = create_stream_reader(self.signal_paths)
        self.data_iter = iter(stream)

    def __len__(self):
        """Number of audio files found (not a number of batches)."""
        return len(self.signal_paths)

    def numpy_to_tensor(self, numpy_array):
        """Insert a channel axis at position 1 and convert to a float tensor on ``device``."""
        with_channel_axis = numpy_array[:, None, :]
        tensor = torch.Tensor(with_channel_axis)
        return tensor.to(device)

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.data_iter)
        return self.numpy_to_tensor(batch["single"])

# Model

In [None]:
import torch.nn.functional as F

class Transpose1dLayer(nn.Module):
    """Generator block that grows the time axis.

    Two modes, selected by ``upsample``:
      * ``upsample`` truthy: ``forward`` first applies nearest-neighbour
        interpolation by a factor of ``upsample``; then zero padding of
        ``kernel_size // 2`` on both sides and a ``Conv1d`` with ``stride``.
      * otherwise: a single ``ConvTranspose1d(in_channels, out_channels,
        kernel_size, stride, padding, output_padding)``.
    An optional ``BatchNorm1d`` follows in either mode. No activation is
    applied here; the caller applies it.

    Note: the Conv1d and the ConvTranspose1d are both constructed (and the
    Conv1d weight re-drawn from N(0, 0.02)) regardless of mode; only the
    selected ops are stored in ``transpose_ops``. ``padding`` and
    ``output_padding`` only affect the ConvTranspose1d path.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=11,
        upsample=None,
        output_padding=1,
        use_batch_norm=False,
    ):
        super(Transpose1dLayer, self).__init__()
        self.upsample = upsample
        # Despite the name, this is constant (zero) padding, not reflection.
        reflection_pad = nn.ConstantPad1d(kernel_size // 2, value=0)
        conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride)
        conv1d.weight.data.normal_(0.0, 0.02)
        Conv1dTrans = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride, padding, output_padding
        )
        batch_norm = nn.BatchNorm1d(out_channels)
        if self.upsample:
            operation_list = [reflection_pad, conv1d]
        else:
            operation_list = [Conv1dTrans]

        if use_batch_norm:
            operation_list.append(batch_norm)
        self.transpose_ops = nn.Sequential(*operation_list)

    def forward(self, x):
        """Apply the optional nearest upsampling, then the stored ops."""
        if self.upsample:
            # recommended by wavgan paper to use nearest upsampling
            x = nn.functional.interpolate(x, scale_factor=self.upsample, mode="nearest")
        return self.transpose_ops(x)


class Conv1D(nn.Module):
    """Discriminator block: Conv1d -> [BatchNorm] -> LeakyReLU -> [PhaseShuffle] -> [Dropout].

    Args:
        input_channels, output_channels, kernel_size, stride, padding:
            forwarded to ``nn.Conv1d``.
        alpha: LeakyReLU negative slope.
        shift_factor: maximum phase-shuffle shift; 0 disables phase shuffle.
        use_batch_norm: apply ``BatchNorm1d`` after the convolution.
        drop_prob: dropout probability; 0 disables dropout.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        alpha=0.2,
        shift_factor=2,
        stride=4,
        padding=11,
        use_batch_norm=False,
        drop_prob=0,
    ):
        super(Conv1D, self).__init__()
        self.conv1d = nn.Conv1d(
            input_channels, output_channels, kernel_size, stride=stride, padding=padding
        )
        self.batch_norm = nn.BatchNorm1d(output_channels)
        self.phase_shuffle = PhaseShuffle(shift_factor)
        self.alpha = alpha
        self.use_batch_norm = use_batch_norm
        # Phase shuffle is only meaningful for a non-zero shift; PhaseShuffle(0)
        # is an identity.
        self.use_phase_shuffle = shift_factor != 0
        self.use_drop = drop_prob > 0
        # Activations are (batch, channels, length): Dropout1d drops whole
        # channels for that layout, while Dropout2d on 3D input is deprecated.
        # Fall back to Dropout2d on PyTorch versions without Dropout1d.
        dropout_cls = getattr(nn, "Dropout1d", nn.Dropout2d)
        self.dropout = dropout_cls(drop_prob)

    def forward(self, x):
        """Run the block on ``x`` (batch, input_channels, length)."""
        x = self.conv1d(x)
        if self.use_batch_norm:
            x = self.batch_norm(x)
        x = F.leaky_relu(x, negative_slope=self.alpha)
        if self.use_phase_shuffle:
            x = self.phase_shuffle(x)
        if self.use_drop:
            x = self.dropout(x)
        return x


class PhaseShuffle(nn.Module):
    """
    Phase shuffle from WaveGAN: every example in the batch is shifted along
    its last axis by its own random integer k, drawn uniformly from the
    2n+1 integers in [-n, n] (n = shift_factor). Reflection padding fills the
    vacated positions, so the output shape equals the input shape.
    A shift_factor of 0 makes the module an identity.
    """

    # Copied from https://github.com/jtcramer/wavegan/blob/master/wavegan.py#L8
    def __init__(self, shift_factor):
        super(PhaseShuffle, self).__init__()
        self.shift_factor = shift_factor

    def forward(self, x):
        n = self.shift_factor
        if n == 0:
            return x
        # One integer shift per example, uniform over [-n, n].
        shifts = (
            (torch.Tensor(x.shape[0]).random_(0, 2 * n + 1) - n).numpy().astype(int)
        )

        # Group example indices by shift so each distinct shift is applied once.
        groups = {}
        for idx, k in enumerate(shifts):
            groups.setdefault(int(k), []).append(idx)

        out = x.clone()
        for k, idxs in groups.items():
            chunk = x[idxs]
            if k > 0:
                # Drop the last k steps, reflect-pad k steps on the left.
                out[idxs] = F.pad(chunk[..., :-k], (k, 0), mode="reflect")
            else:
                # Drop the first -k steps, reflect-pad -k steps on the right.
                out[idxs] = F.pad(chunk[..., -k:], (0, -k), mode="reflect")

        assert out.shape == x.shape, "{}, {}".format(out.shape, x.shape)
        return out

class WaveGANGenerator(nn.Module):
    """WaveGAN generator: latent vectors -> waveforms in [-1, 1].

    A linear layer maps a ``noise_latent_dim``-dimensional latent vector
    (module-level config) to ``dim_mul * model_size`` channels x 16 steps.
    A stack of ``Transpose1dLayer`` blocks follows, each growing the length
    by 4, plus a final x2 block for ``slice_len == 32768``, so the output has
    ``slice_len`` samples. Hidden blocks use ReLU; the output block uses tanh.

    Args:
        model_size: base channel width ``d``.
        ngpus: stored but not used by this class.
        num_channels: number of output audio channels ``c``.
        verbose: print intermediate shapes in ``forward``.
        upsample: True -> nearest upsampling + Conv1d blocks;
            False -> ConvTranspose1d blocks.
        slice_len: output length; one of 16384, 32768, 65536.
        use_batch_norm: BatchNorm after the linear layer and hidden blocks.
    """

    def __init__(
        self,
        model_size=64,
        ngpus=1,
        num_channels=1,
        verbose=False,
        upsample=True,
        slice_len=16384,
        use_batch_norm=False,
    ):
        super(WaveGANGenerator, self).__init__()
        assert slice_len in [16384, 32768, 65536]  # used to predict longer utterances

        self.ngpus = ngpus
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        latent_dim = noise_latent_dim
        self.verbose = verbose
        self.use_batch_norm = use_batch_norm

        self.dim_mul = 16 if slice_len == 16384 else 32
        width = self.dim_mul * model_size  # channels right after fc1

        self.fc1 = nn.Linear(latent_dim, 4 * 4 * width)
        self.bn1 = nn.BatchNorm1d(num_features=width)

        # Each block grows the length x4: nearest upsampling by 4 + stride-1
        # conv, or a stride-4 transposed conv.
        stride = 4
        if upsample:
            stride = 1
            upsample = 4

        # Four hidden blocks, halving the channels each time:
        # width -> width/2 -> width/4 -> width/8 -> width/16.
        deconv_layers = [
            Transpose1dLayer(
                width // (2 ** i),
                width // (2 ** (i + 1)),
                25,
                stride,
                upsample=upsample,
                use_batch_norm=use_batch_norm,
            )
            for i in range(4)
        ]

        if slice_len == 16384:
            deconv_layers.append(
                Transpose1dLayer(
                    width // 16,
                    num_channels,
                    25,
                    stride,
                    upsample=upsample,
                )
            )
        elif slice_len == 32768:
            deconv_layers += [
                Transpose1dLayer(
                    width // 16,
                    model_size,
                    25,
                    stride,
                    upsample=upsample,
                    use_batch_norm=use_batch_norm,
                ),
                # Final x2 block. padding=12 makes the ConvTranspose1d path
                # (upsample=False) exactly double the length:
                # (L-1)*2 - 2*12 + (25-1) + output_padding(1) + 1 = 2L,
                # whereas padding=11 gives 2L + 2 (a 32770-sample output).
                # padding is unused in the upsampling path.
                Transpose1dLayer(
                    model_size, num_channels, 25, 2, padding=12, upsample=upsample
                ),
            ]
        elif slice_len == 65536:
            deconv_layers += [
                Transpose1dLayer(
                    width // 16,
                    model_size,
                    25,
                    stride,
                    upsample=upsample,
                    use_batch_norm=use_batch_norm,
                ),
                Transpose1dLayer(
                    model_size, num_channels, 25, stride, upsample=upsample
                ),
            ]
        else:
            raise ValueError("slice_len {} value is not supported".format(slice_len))

        self.deconv_list = nn.ModuleList(deconv_layers)
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)

    def forward(self, x):
        """Map latents (batch, noise_latent_dim) to audio (batch, num_channels, slice_len)."""
        x = self.fc1(x).view(-1, self.dim_mul * self.model_size, 16)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = F.relu(x)
        if self.verbose:
            print(x.shape)

        for deconv in self.deconv_list[:-1]:
            x = F.relu(deconv(x))
            if self.verbose:
                print(x.shape)
        # tanh bounds the waveform to [-1, 1].
        return torch.tanh(self.deconv_list[-1](x))



class WaveGANDiscriminator(nn.Module):
    """WaveGAN discriminator (critic): waveform -> one unbounded score per example.

    Five ``Conv1D`` blocks (kernel 25, stride 4, padding 11) widen the
    channels c -> d -> 2d -> 4d -> 8d -> 16d. For ``slice_len`` 32768 / 65536
    a sixth block (stride 2 / 4) goes to 32d. The flattened features feed a
    single linear unit. ``shift_factor`` is passed to every block except the
    last one, which gets 0.

    Args:
        model_size: base channel width ``d``.
        ngpus: stored but not used by this class.
        num_channels: number of input audio channels ``c``.
        shift_factor: maximum phase-shuffle shift ``n``.
        alpha: LeakyReLU negative slope.
        verbose: print intermediate shapes in ``forward``.
        slice_len: expected input length; one of 16384, 32768, 65536.
        use_batch_norm: BatchNorm inside each conv block.
    """

    def __init__(
        self,
        model_size=64,
        ngpus=1,
        num_channels=1,
        shift_factor=2,
        alpha=0.2,
        verbose=False,
        slice_len=16384,
        use_batch_norm=False,
    ):
        super(WaveGANDiscriminator, self).__init__()
        assert slice_len in [16384, 32768, 65536]  # used to predict longer utterances

        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.use_batch_norm = use_batch_norm
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose

        # Channel widths c, d, 2d, 4d, 8d, 16d of the five stride-4 blocks.
        widths = [num_channels] + [model_size * 2 ** i for i in range(5)]
        conv_layers = []
        for i in range(5):
            # For 16384 the fifth block is the last one: no phase shuffle.
            is_last_for_16384 = i == 4 and slice_len == 16384
            conv_layers.append(
                Conv1D(
                    widths[i],
                    widths[i + 1],
                    25,
                    stride=4,
                    padding=11,
                    use_batch_norm=use_batch_norm,
                    alpha=alpha,
                    shift_factor=0 if is_last_for_16384 else shift_factor,
                )
            )

        # Flattened feature size per example for inputs of exactly slice_len:
        #   16384: 16d channels x 16 steps = 256d
        #   32768: 32d channels x 15 steps = 480d (stride-2 sixth block)
        #   65536: 32d channels x 16 steps = 512d
        self.fc_input_size = 256 * model_size
        if slice_len == 32768:
            conv_layers.append(
                Conv1D(
                    16 * model_size,
                    32 * model_size,
                    25,
                    stride=2,
                    padding=11,
                    use_batch_norm=use_batch_norm,
                    alpha=alpha,
                    shift_factor=0,
                )
            )
            self.fc_input_size = 480 * model_size
        elif slice_len == 65536:
            conv_layers.append(
                Conv1D(
                    16 * model_size,
                    32 * model_size,
                    25,
                    stride=4,
                    padding=11,
                    use_batch_norm=use_batch_norm,
                    alpha=alpha,
                    shift_factor=0,
                )
            )
            self.fc_input_size = 512 * model_size

        self.conv_layers = nn.ModuleList(conv_layers)

        self.fc1 = nn.Linear(self.fc_input_size, 1)

        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)

    def forward(self, x):
        """Return critic scores (batch, 1) for audio (batch, num_channels, slice_len).

        Raises:
            RuntimeError: the flattened features do not match ``fc_input_size``
                (input length inconsistent with ``slice_len``).
        """
        for conv in self.conv_layers:
            x = conv(x)
            if self.verbose:
                print(x.shape)
        # Flatten per example, keeping the batch dimension. A view(-1, size)
        # could silently regroup features across examples for wrongly sized
        # inputs, so the feature count is checked explicitly.
        x = x.reshape(x.size(0), -1)
        if x.size(1) != self.fc_input_size:
            raise RuntimeError(
                "expected {} features per example, got {}; check slice_len "
                "against the input length".format(self.fc_input_size, x.size(1))
            )
        if self.verbose:
            print(x.shape)

        return self.fc1(x)

# Train


In [None]:
from torch.autograd import grad, Variable

class WaveGan_GP(object):
    """Trains a WaveGAN generator/discriminator pair with the WGAN-GP objective.

    Both networks are built from the module-level configuration
    (``window_length``, ``model_capacity_size``, ``use_batchnorm``,
    ``num_channels``), each with its own Adam optimizer. ``train`` runs the
    loop and, every ``store_cost_every`` iterations, appends to ``g_cost``
    (mean critic score of fakes), ``train_d_cost`` (critic loss including the
    gradient penalty) and ``train_w_distance`` (mean critic score of real
    minus fake).

    Args:
        train_loader: iterator (``next()``) over batches of real audio
            tensors; ``len(train_loader)`` must be defined.
        val_loader: iterator over validation batches; only used when the
            module-level ``validate`` flag is True.
    """

    def __init__(self, train_loader, val_loader):
        super(WaveGan_GP, self).__init__()
        self.g_cost = []
        self.train_d_cost = []
        self.train_w_distance = []
        # Starts with a -1 placeholder so valid_g_cost[-1] (shown in the
        # progress bar) exists even when validation is disabled.
        self.valid_g_cost = [-1]
        self.valid_reconstruction = []

        self.discriminator = WaveGANDiscriminator(
            slice_len=window_length,
            model_size=model_capacity_size,
            use_batch_norm=use_batchnorm,
            num_channels=num_channels,
        ).to(device)
        self.discriminator.apply(weights_init)

        self.generator = WaveGANGenerator(
            slice_len=window_length,
            model_size=model_capacity_size,
            use_batch_norm=use_batchnorm,
            num_channels=num_channels,
        ).to(device)
        self.generator.apply(weights_init)

        # Setup Adam optimizers for both G and D
        self.optimizer_g = optim.Adam(
            self.generator.parameters(), lr=lr_g, betas=(beta1, beta2)
        )
        self.optimizer_d = optim.Adam(
            self.discriminator.parameters(), lr=lr_d, betas=(beta1, beta2)
        )

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.validate = validate
        self.n_samples_per_batch = len(train_loader)

    def calculate_discriminator_loss(self, real, generated):
        """Return ``(cost, cost_wd)`` for one critic update.

        ``cost_wd = mean D(generated) - mean D(real)``. ``cost`` adds the
        gradient penalty ``p_coeff * mean((||grad D(x_hat)||_2 - 1) ** 2)``,
        where ``x_hat`` are random interpolates between each real example and
        one of the first ``real.size(0)`` generated examples.

        Raises:
            AssertionError: the penalty or either critic mean is NaN.
        """
        disc_out_gen = self.discriminator(generated)
        disc_out_real = self.discriminator(real)

        # The batch size is taken from ``real`` rather than the global config.
        n_real = real.size(0)
        alpha = torch.rand(n_real, 1, 1).to(device).expand_as(real)
        interpolated = (1 - alpha) * real.detach() + alpha * generated.detach()[:n_real]
        interpolated.requires_grad_(True)

        # calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)
        gradients = grad(
            outputs=prob_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(prob_interpolated),
            create_graph=True,  # the penalty itself must be differentiable
            retain_graph=True,
        )[0]
        # calculate gradient penalty
        grad_penalty = (
            p_coeff
            * ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        )
        assert not (torch.isnan(grad_penalty))
        assert not (torch.isnan(disc_out_gen.mean()))
        assert not (torch.isnan(disc_out_real.mean()))
        cost_wd = disc_out_gen.mean() - disc_out_real.mean()
        cost = cost_wd + grad_penalty
        return cost, cost_wd

    def apply_zero_grad(self):
        """Clear gradients of both networks."""
        self.generator.zero_grad()
        self.optimizer_g.zero_grad()

        self.discriminator.zero_grad()
        self.optimizer_d.zero_grad()

    def enable_disc_disable_gen(self):
        """Train D only: D parameters require grad, G parameters do not."""
        gradients_status(self.discriminator, True)
        gradients_status(self.generator, False)

    def enable_gen_disable_disc(self):
        """Train G only: G parameters require grad, D parameters do not."""
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, True)

    def disable_all(self):
        """Freeze the parameters of both networks."""
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, False)

    def train(self):
        """Train from the last checkpoint (if any) up to ``n_iterations``.

        Each iteration runs ``n_critic`` discriminator updates and then one
        generator update. Every ``save_samples_every`` iterations, samples
        from a fixed noise batch are written with ``save_samples`` and a
        latent interpolation is plotted. When ``take_backup`` is set, every
        ``backup_every_n_iters`` iterations the models, optimizers and cost
        histories are saved to ``gan_<model_prefix>.tar`` in the working
        directory. The next call resumes from that file.
        """
        progress_bar = tqdm(total=n_iterations // progress_bar_step_iter_size)
        fixed_noise = sample_noise(batch_size)  # used to save samples periodically

        gan_model_name = "gan_{}.tar".format(model_prefix)

        first_iter = 0
        if take_backup and os.path.isfile(gan_model_name):
            checkpoint = torch.load(gan_model_name, map_location=device)
            self.generator.load_state_dict(checkpoint["generator"])
            self.discriminator.load_state_dict(checkpoint["discriminator"])
            self.optimizer_d.load_state_dict(checkpoint["optimizer_d"])
            self.optimizer_g.load_state_dict(checkpoint["optimizer_g"])
            self.train_d_cost = checkpoint["train_d_cost"]
            self.train_w_distance = checkpoint["train_w_distance"]
            self.valid_g_cost = checkpoint["valid_g_cost"]
            self.g_cost = checkpoint["g_cost"]

            # The checkpoint stores the index of the last *completed*
            # iteration (saved after its updates), so resume at the next one
            # instead of repeating it.
            last_done = checkpoint["n_iterations"]
            first_iter = last_done + 1
            progress_bar.update(
                len(range(0, first_iter, progress_bar_step_iter_size))
            )
            self.generator.eval()
            with torch.no_grad():
                fake = self.generator(fixed_noise).detach().cpu().numpy()
            save_samples(fake, last_done)
        self.generator.train()
        self.discriminator.train()
        for iter_indx in range(first_iter, n_iterations):
            #############################
            # (1) Update D network n_critic times
            #############################
            self.enable_disc_disable_gen()
            for _ in range(n_critic):
                real_signal = next(self.train_loader)

                # need to add mixed signal and flag
                noise = sample_noise(batch_size * generator_batch_size_factor)
                generated = self.generator(noise)
                self.apply_zero_grad()
                disc_cost, disc_wd = self.calculate_discriminator_loss(
                    real_signal.data, generated.data
                )
                assert not (torch.isnan(disc_cost))
                disc_cost.backward()
                self.optimizer_d.step()

            if self.validate and iter_indx % store_cost_every == 0:
                self.disable_all()
                val_real = next(self.val_loader)
                with torch.no_grad():
                    val_discriminator_output = self.discriminator(val_real)
                    val_generator_cost = val_discriminator_output.mean()
                    self.valid_g_cost.append(val_generator_cost.item())

            #############################
            # (2) Update G network once per iteration
            #############################
            self.apply_zero_grad()
            self.enable_gen_disable_disc()
            noise = sample_noise(batch_size * generator_batch_size_factor)
            generated = self.generator(noise)
            discriminator_output_fake = self.discriminator(generated)
            generator_cost = -discriminator_output_fake.mean()
            generator_cost.backward()
            self.optimizer_g.step()
            self.disable_all()

            if iter_indx % store_cost_every == 0:
                # Signs flipped: g_cost = mean D(G(z)),
                # w_distance = mean D(real) - mean D(G(z)).
                self.g_cost.append(generator_cost.item() * -1)
                self.train_d_cost.append(disc_cost.item())
                self.train_w_distance.append(disc_wd.item() * -1)

                progress_updates = {
                    "Loss_D WD": str(self.train_w_distance[-1]),
                    "Loss_G": str(self.g_cost[-1]),
                    "Val_G": str(self.valid_g_cost[-1]),
                }
                progress_bar.set_postfix(progress_updates)

            if iter_indx % progress_bar_step_iter_size == 0:
                progress_bar.update()
            # lr decay
            if decay_lr:
                decay = max(0.0, 1.0 - (iter_indx * 1.0 / n_iterations))
                # update the learning rate
                update_optimizer_lr(self.optimizer_d, lr_d, decay)
                update_optimizer_lr(self.optimizer_g, lr_g, decay)

            if iter_indx % save_samples_every == 0:
                # Generate in eval mode (relevant with batch norm), then resume training mode.
                self.generator.eval()
                with torch.no_grad():
                    latent_space_interpolation(self.generator, n_samples=2)
                    fake = self.generator(fixed_noise).detach().cpu().numpy()
                self.generator.train()
                save_samples(fake, iter_indx)

            if take_backup and iter_indx % backup_every_n_iters == 0:
                saving_dict = {
                    "generator": self.generator.state_dict(),
                    "discriminator": self.discriminator.state_dict(),
                    "n_iterations": iter_indx,
                    "optimizer_d": self.optimizer_d.state_dict(),
                    "optimizer_g": self.optimizer_g.state_dict(),
                    "train_d_cost": self.train_d_cost,
                    "train_w_distance": self.train_w_distance,
                    "valid_g_cost": self.valid_g_cost,
                    "g_cost": self.g_cost,
                }
                torch.save(saving_dict, gan_model_name)

        progress_bar.close()
        self.generator.eval()

# Run

In [None]:
# Build batch iterators over the "training" and "validation" sub-folders of
# target_signals_dir, train, then plot the recorded costs and a latent-space
# interpolation of the trained generator.
train_loader = WavDataLoader(os.path.join(target_signals_dir, "training"))
val_loader = WavDataLoader(os.path.join(target_signals_dir, "validation"))

wave_gan = WaveGan_GP(train_loader, val_loader)
wave_gan.train()
# NOTE(review): valid_g_cost is only appended to when validate=True (False in
# the config cell), so the "Val" curve is just its initial -1 placeholder.
# Each x-axis point is one recorded value (every store_cost_every iterations).
visualize_loss(
    wave_gan.g_cost, wave_gan.valid_g_cost, "Train", "Val", "Negative Critic Loss"
)
latent_space_interpolation(wave_gan.generator, n_samples=5)

In [None]:
# Copy the generated-samples folder (output_dir is 'output') to Drive.
# NOTE(review): the checkpoint gan_<model_prefix>.tar is saved to the working
# directory and is not copied by this or the next cell.
!cp -r output /content/drive/MyDrive/

In [None]:
# Copy the saved figures (interpolation.png, loss.png) to Drive.
!cp -r visualization /content/drive/MyDrive/