# Plan for increasing conditioning effectiveness

## Unknowns

Currently, many modes are either not captured by the generator or not correlated with the conditioning input. It is unclear how well the initial text embedding separates and clusters SFX search queries. It is also unclear whether, and by how much, the initial downsampling from the high-dimensional text embedding space (1024 dimensions) to the 128-dimensional conditioning input of the GAN reduces the separability of the search queries. Finally, the search queries entered by a person may be too different from the texts provided to the GAN during training, producing embeddings with a different distribution from the training data (during training we use a text string containing many keywords, while humans normally enter only a subset of these).

## Experiments

To address these unknowns, we plan to begin by creating a visualization of the high-dimensional text embeddings that are fed into the GAN during training. Along with the texts used to create the training embeddings, we will also feed embedded versions of more realistic human search queries into the visualizer, to see whether these queries land near the appropriate training texts. The experiment will then be repeated after first passing the text embeddings through a pre-trained embedding reduction network. Currently, the generator and discriminator each train a separate embedding reduction network, which probably hinders and slows down learning. However, because the model I currently have trained uses two separate reduction networks, both will be visualized for now.
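
A minimal sketch of this experiment, assuming the 1024-d training and query embeddings are already available as NumPy arrays (the embedding model itself is not shown). PCA via SVD is used here as a simple linear stand-in for the visualizer; a non-linear method such as t-SNE or UMAP may separate clusters better. The function names are hypothetical.

```python
import numpy as np


def pca_2d(embeddings: np.ndarray) -> np.ndarray:
    """Project (n, d) embeddings onto their top-2 principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def nearest_training_texts(query_emb: np.ndarray, train_emb: np.ndarray, k: int = 3) -> np.ndarray:
    """Indices of the k training embeddings most cosine-similar to each query."""
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    t = train_emb / np.linalg.norm(train_emb, axis=1, keepdims=True)
    sims = q @ t.T  # (n_queries, n_train) cosine similarities
    return np.argsort(-sims, axis=1)[:, :k]
```

Stacking the training and query embeddings before calling `pca_2d` keeps both in the same 2-D coordinate system for plotting. The same two functions can be rerun on the 128-d output of a reduction network to check how much separability the reduction loses.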

## Proposed Solutions (Preprocessing)

Depending on the results of these experiments, extra preprocessing steps may be added to modify the training data and augment it with more realistic input. These could include: filtering out non-descriptive SFX collection names (e.g. "Gamemaster Audio Pro", which users would not search for without prior knowledge of the training set), removing duplicate words, and removing punctuation (which is not present in user searches). The data could also be augmented by randomly dropping words from the training text, so that the remaining texts average around 2 words (with a standard deviation of 2 and a minimum of 1). Each audio sample would then be associated with a list of possible texts to use for training. This would both augment the training data and increase the correlation between user text embeddings and training embeddings (users usually enter only 1 to 4 search terms).
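
The augmentation described above could look roughly like the following pure-Python sketch. The blocklist contents, the helper names, and the rounding and clipping of the sampled word count are assumptions rather than a fixed design; note that clipping to [1, number of words] shifts the effective mean slightly away from 2.

```python
import random
import re


def clean_words(text, blocked_phrases):
    """Lower-case, drop blocked phrases, strip punctuation and duplicate words."""
    text = text.lower()
    for phrase in blocked_phrases:
        # Remove whole phrases (e.g. collection names) before tokenizing.
        text = re.sub(r"\b" + re.escape(phrase.lower()) + r"\b", " ", text)
    words = re.findall(r"[a-z0-9]+", text)  # punctuation is discarded here
    seen, cleaned = set(), []
    for w in words:
        if w not in seen:
            seen.add(w)
            cleaned.append(w)
    return cleaned


def make_query_variants(text, blocked_phrases, n_variants=8, mean=2.0, std=2.0, seed=None):
    """Build short, user-like queries by randomly dropping words."""
    rng = random.Random(seed)
    words = clean_words(text, blocked_phrases)
    if not words:
        return []
    variants = []
    for _ in range(n_variants):
        # Sample a word count ~ N(mean, std), clipped to [1, len(words)].
        n = max(1, min(round(rng.gauss(mean, std)), len(words)))
        # Keep a random subset of the words, preserving their original order.
        keep = sorted(rng.sample(range(len(words)), n))
        variants.append(" ".join(words[i] for i in keep))
    return variants
```

For example, `make_query_variants("Cat Meow, Angry cat hiss - Gamemaster Audio Pro", ["Gamemaster Audio Pro"])` only ever produces queries built from "cat", "meow", "angry" and "hiss".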

The text embedding stage would then be done as a preprocessing step, so that each data sample contains an audio sample plus a list of possible text embeddings to use during training.
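
A sketch of that data layout, assuming a hypothetical `embed(text)` function that wraps whatever text embedding model is used, and a hypothetical `variants_fn(text)` that returns the list of training texts for one sample. Each text is embedded once up front, and one stored embedding is picked at random per training step.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, List, Sequence


@dataclass
class SfxExample:
    audio: Sequence[float]  # e.g. 16384 samples at 16 kHz
    text_embeddings: List[Sequence[float]] = field(default_factory=list)


def precompute(texts: List[str], audio_clips: List[Sequence[float]],
               embed: Callable[[str], Sequence[float]],
               variants_fn: Callable[[str], List[str]]) -> List[SfxExample]:
    """Embed every text variant once, before training starts."""
    dataset = []
    for text, audio in zip(texts, audio_clips):
        embs = [embed(q) for q in variants_fn(text)]
        dataset.append(SfxExample(audio=audio, text_embeddings=embs))
    return dataset


def sample_pair(example: SfxExample, rng: random.Random):
    """Return the audio and one randomly chosen embedding (needs >= 1 embedding)."""
    return example.audio, rng.choice(example.text_embeddings)
```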

Existing audio should also be augmented using various pitch shifts. This is almost always done in games anyway, since repeating an identical sound effect sounds extremely unnatural.
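
A simple way to do this, sketched below under the assumption that clips are mono NumPy arrays, is varispeed-style resampling: reading the waveform faster raises the pitch and shortens the clip, which is how many game engines apply pitch variation. Duration-preserving alternatives (e.g. a phase vocoder such as `librosa.effects.pitch_shift`) are also possible. The function name is hypothetical.

```python
import numpy as np


def pitch_shift(audio: np.ndarray, semitones: float) -> np.ndarray:
    """Varispeed pitch shift by `semitones` via linear-interpolation resampling.

    Duration changes along with pitch (higher pitch -> shorter sound). The result
    is cropped or zero-padded back to the input length so clip sizes stay fixed.
    """
    rate = 2.0 ** (semitones / 12.0)  # playback-rate factor
    n = len(audio)
    new_len = int(round(n / rate))
    # Read the original at positions 0, rate, 2*rate, ... (clamped at the end).
    positions = np.arange(new_len) * rate
    shifted = np.interp(positions, np.arange(n), audio)
    out = np.zeros(n, dtype=audio.dtype)
    m = min(n, new_len)
    out[:m] = shifted[:m]
    return out
```

Shifting by +12 semitones doubles the frequency of every component (a 440 Hz tone becomes 880 Hz) and leaves the second half of a fixed-length clip silent.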

# Architecture Plan

## Problem

One of the biggest problems with training a GAN is mode collapse. GANs generally perform very well when trained on a dataset with limited modality (such as the LSUN bedroom dataset), because the generator only has to fit one specific data distribution. Our initial experiments, using only a very small dataset of fighting sound effects (SFX), produced very promising results. Subjectively, this dataset has only about 2 modalities (human vocalizations, and hit/whoosh SFX). A very basic GAN quickly learns to generate decent audio samples on this dataset. However, when trained on a much larger SFX dataset with a variety of modalities, the quality of SFX from specific modalities (such as the fighting sounds) decreases dramatically. In our model, this is likely due to the gradient penalty term used in our loss function. Loss functions with gradient penalties can cause generated data points to spread out across, and oscillate between, different modalities, as discussed in [On catastrophic forgetting and mode collapse in Generative Adversarial Networks](https://arxiv.org/pdf/1807.04015.pdf). The generator is also completely unable to reproduce SFX from certain modalities. SFX from the Animal Hyperrealism collection are a good example: for these, the generator usually produces either noise or examples from easier modalities, such as hit and foley sounds. In general, the generator's default modality seems to be foley/hit and whoosh SFX. It will often produce SFX from this modality even when fed conditioning text for completely unrelated modalities (like 'cat').

## Existing Research

Several recent papers, such as [ProgressiveGAN](https://arxiv.org/abs/1710.10196), have found it beneficial to split a hard generative problem into smaller, easier generative problems that build on one another. ProgressiveGAN starts by training a very shallow network to generate low-resolution images, then gradually adds stages to the existing model to generate progressively higher-resolution images. The key idea is training different parts of the network to solve different problems: each stage of ProgressiveGAN is explicitly trained to produce images at a specific resolution from images at a lower resolution. This is a much simpler problem than going directly from a low-dimensional latent vector to a high-dimensional output space. By training a network to solve small subproblems that combine into a larger one, ProgressiveGAN is able to generate extremely convincing results, even for a very high-dimensional output space.

Splitting the generative problem into multiple, easier problems has also been used to address mode collapse in both [MGAN](https://arxiv.org/pdf/1708.02556.pdf) and [MADGAN](https://arxiv.org/pdf/1704.02906.pdf). These networks use very similar architectures. The basic idea is to use multiple generators instead of a single generator, with each generator trained to produce a specific modality from the dataset. This is done by having the discriminator classify both whether the image is real and which generator produced it. In this way, each generator is trained to differentiate itself from the other generators, so that the discriminator can easily tell them apart. To reduce complexity, these networks often share weights between the generators (usually all weights except those in the final layers), although the authors of MADGAN note that not sharing weights allows generating from more diverse datasets.
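
A minimal PyTorch sketch of the weight-sharing idea, together with a MADGAN-style (k + 1)-way discriminator loss in which class k means "real" and classes 0..k-1 identify the generator that produced a fake. Layer sizes and names are hypothetical, and a real audio generator would use convolutional layers rather than a single linear layer per head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadGenerator(nn.Module):
    """k generators that share every layer except their own final layer."""

    def __init__(self, z_dim=100, hidden=256, out_dim=16384, num_generators=4):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(z_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleList(
            [nn.Linear(hidden, out_dim) for _ in range(num_generators)]
        )

    def forward(self, z, gen_idx):
        h = self.shared(z)
        # Simple but wasteful: run every head, then keep the selected one per sample.
        all_out = torch.stack([head(h) for head in self.heads], dim=1)  # (B, k, out)
        out = all_out[torch.arange(z.size(0)), gen_idx]
        return torch.tanh(out)


def madgan_d_loss(logits_real, logits_fake, gen_idx, k):
    """Cross-entropy over k + 1 classes: generator ids 0..k-1, plus k = real."""
    real_target = torch.full((logits_real.size(0),), k, dtype=torch.long)
    return F.cross_entropy(logits_real, real_target) + F.cross_entropy(logits_fake, gen_idx)
```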

[StackGAN++](https://arxiv.org/pdf/1710.10916.pdf) also splits the generative problem into multiple pipeline stages, similar to ProgressiveGAN. However, StackGAN++ trains the entire pipeline at once, rather than progressively adding stages as each stage converges. StackGAN++ is of particular interest because it solves a text-to-image problem, which is similar to our text-to-SFX problem. It splits the generative problem into two parts: generating realistic-looking images, and generating images that match a specific condition. The discriminator splits in two at the final layer: an unconditional path scores whether the image looks real, while a conditional path mixes in the condition and scores whether the image is both real and matches the condition. In this way, the discriminator is explicitly trained to tell whether a generated image belongs to the category it was conditioned on. However, unlike MADGAN, it cannot tell specifically which modality a generated image belongs to. If multiple generators were used, as in MADGAN, this would likely cause all the generators to learn the same output distribution that maximally fools the discriminator, rather than diversifying into multiple modalities.
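
For reference, a hypothetical 1-D adaptation of this two-headed discriminator to raw audio might look like the following PyTorch sketch. The channel counts, kernel sizes and the averaging of per-position scores are assumptions; the heads return raw scores (no sigmoid) so they could also serve as a WGAN critic.

```python
import torch
import torch.nn as nn


class TwoHeadDiscriminator(nn.Module):
    """StackGAN++-style D for 1-D audio with a conditional and an unconditional head."""

    def __init__(self, cond_dim=128, ch=64, num_down=5):
        super().__init__()
        layers, in_ch = [], 1
        for _ in range(num_down):
            # Each strided conv divides the length by 4 (16384 -> 16 with num_down=5).
            layers += [nn.Conv1d(in_ch, ch, kernel_size=25, stride=4, padding=11),
                       nn.LeakyReLU(0.2)]
            in_ch = ch
        self.encoder = nn.Sequential(*layers)
        # Mixes the time-tiled condition into the audio features.
        self.joint = nn.Sequential(
            nn.Conv1d(ch + cond_dim, ch, kernel_size=3, padding=1), nn.LeakyReLU(0.2))
        self.cond_logits = nn.Conv1d(ch, 1, kernel_size=1)
        self.uncond_logits = nn.Conv1d(ch, 1, kernel_size=1)

    def forward(self, audio, c_code):
        x_code = self.encoder(audio)  # (B, ch, T)
        c = c_code.unsqueeze(-1).expand(-1, -1, x_code.size(-1))  # (B, cond_dim, T)
        h = self.joint(torch.cat([c, x_code], dim=1))
        cond = self.cond_logits(h).mean(dim=(1, 2))  # "real and matching" score, (B,)
        uncond = self.uncond_logits(x_code).mean(dim=(1, 2))  # "real" score, (B,)
        return [cond, uncond]
```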

# Proposed Architecture

Initially, a simple architecture will be implemented following an upscaling/downscaling structure similar to ProgressiveGAN's, but without progressive growing. Apart from the progressive growing, regularization and hyperparameter optimization, their architecture is very simple. The number of parameters should be designed to match ProgressiveGAN up to the 128 x 128 image scale, which matches our 16384 generated samples (128 × 128 = 16384, just over 1 second of audio at 16 kHz). WGAN-GP will be used as the loss function to avoid vanishing gradients. This first simple architecture will then be tested against a basic WaveGAN for validation.
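
The WGAN-GP gradient penalty for conditioned audio could be sketched as follows, assuming a critic `netD(audio, cond)` that returns one raw score per sample. λ = 10 is the default used in the WGAN-GP paper; the function name is hypothetical.

```python
import torch


def gradient_penalty(netD, real, fake, cond, lambda_gp=10.0):
    """lambda * E[(||grad_x D(x_hat, cond)||_2 - 1)^2] at random interpolates x_hat."""
    b = real.size(0)
    # One mixing coefficient per sample, broadcast over channel and time dims.
    eps = torch.rand(b, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = netD(x_hat, cond)
    # create_graph=True so the penalty itself can be backpropagated into netD.
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    norms = grads.reshape(b, -1).norm(2, dim=1)
    return lambda_gp * ((norms - 1) ** 2).mean()
```

The critic loss would then be `netD(fake, cond).mean() - netD(real, cond).mean() + gradient_penalty(netD, real, fake, cond)`.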

Next, text conditioning will be added using the same method as StackGAN++. The discriminator will now output a score indicating whether the audio clip matches the supplied text condition, as well as a score indicating how real it sounds. Training will also be augmented by showing the discriminator mismatched ("wrong") audio clips for a given conditioning text, which introduces an extra term into the discriminator error (wrong_discriminator_err). This architecture will then be evaluated against the current conditional architecture for validation.
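
One way to produce the wrong clips and the three error terms is sketched below under two assumptions the plan does not fix: the WGAN sign convention (a higher score means real and matching), and that rolling the batch by one position pairs each text with an unrelated clip, which only holds if neighbouring batch items come from different sounds.

```python
import torch


def critic_losses(netD, real, fake, cond):
    """Return (real_err, wrong_err, fake_err) for a matching-aware WGAN critic."""
    # Pair each conditioning text with another batch item's real clip.
    wrong = torch.roll(real, shifts=1, dims=0)
    real_discriminator_err = -netD(real, cond).mean()  # push real+matching up
    wrong_discriminator_err = netD(wrong, cond).mean()  # push mismatched down
    fake_discriminator_err = netD(fake.detach(), cond).mean()  # push fakes down
    return real_discriminator_err, wrong_discriminator_err, fake_discriminator_err
```

Mirroring the `else` branch of the StackGAN++ discriminator code, the total could be `real_err + 0.5 * (wrong_err + fake_err)` plus the gradient penalty; that weighting is borrowed rather than tuned.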

To cover diverse modalities, we will next modify the model to include multiple generators. A new generator selector stage will be added to select the best generator for a given conditioning text input. This selector will be based directly on the conditioning text embedding. Its output will be a softmax layer giving the probability that each generator generates data for the specified modality (specified via the conditioning text embedding). A generator is then sampled from the probability distribution output by the softmax layer, and the latent 'z' vector and conditioning text embedding are fed forward through this generator only. The discriminator will be modified to have a third output: a softmax layer predicting which generator produced the sample. This output will be ignored in real_discriminator_err and wrong_discriminator_err, but used in fake_discriminator_err as an extra error term comparing the predicted generator with the generator actually used. Intuitively, this extra softmax output is a prediction of which modality a sample is drawn from, irrespective of whether the discriminator considers the sample real or fake. This encourages the generators to diversify into easily separable modalities, even when the discriminator has a hard time telling real samples from fake ones.
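
The selector, the routing, and the extra generator-prediction term could be sketched in PyTorch as follows. The names, the single linear selector layer and `class_weight` are hypothetical. One caveat: sampling with `torch.multinomial` is not differentiable, so as written the selector receives no gradient through this path; how the selector is trained (e.g. a score-function estimator or Gumbel-softmax) is left open here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeneratorSelector(nn.Module):
    """Maps a text embedding to a categorical distribution over k generators."""

    def __init__(self, emb_dim=128, num_generators=4):
        super().__init__()
        self.net = nn.Linear(emb_dim, num_generators)

    def forward(self, text_emb):
        probs = F.softmax(self.net(text_emb), dim=-1)  # (B, k)
        idx = torch.multinomial(probs, num_samples=1).squeeze(1)  # sampled ids, (B,)
        return probs, idx


def route(generators, z, text_emb, idx):
    """Feed each sample through its selected generator only."""
    out = None
    for j, g in enumerate(generators):
        mask = idx == j
        if mask.any():
            y = g(z[mask], text_emb[mask])
            if out is None:
                out = y.new_zeros((z.size(0),) + tuple(y.shape[1:]))
            out[mask] = y
    return out


def fake_discriminator_err(fake_scores, gen_logits, gen_idx, class_weight=1.0):
    """WGAN fake term plus cross-entropy on which generator produced each sample."""
    return fake_scores.mean() + class_weight * F.cross_entropy(gen_logits, gen_idx)
```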

# Relevant StackGAN++ Code

## Dnet

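```python
# Excerpt from the StackGAN++ discriminator update; `self`, `idx`, `cfg`
# and the networks are defined elsewhere in the trainer class.
batch_size = self.real_imgs[0].size(0)
criterion, mu = self.criterion, self.mu

netD, optD = self.netsD[idx], self.optimizersD[idx]
real_imgs = self.real_imgs[idx]
wrong_imgs = self.wrong_imgs[idx]  # real images paired with a mismatched condition
fake_imgs = self.fake_imgs[idx]
#
netD.zero_grad()
# Forward
real_labels = self.real_labels[:batch_size]
fake_labels = self.fake_labels[:batch_size]
# Score real, wrong and fake images against the same condition `mu`;
# detach so no gradient flows into the generator or condition encoder.
real_logits = netD(real_imgs, mu.detach())
wrong_logits = netD(wrong_imgs, mu.detach())
fake_logits = netD(fake_imgs.detach(), mu.detach())
# Conditional head (index 0): only real images with matching text count as real.
errD_real = criterion(real_logits[0], real_labels)
errD_wrong = criterion(wrong_logits[0], fake_labels)
errD_fake = criterion(fake_logits[0], fake_labels)
if len(real_logits) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
    # Unconditional head (index 1): wrong images are still real images,
    # so they are labelled real here.
    errD_real_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
        criterion(real_logits[1], real_labels)
    errD_wrong_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
        criterion(wrong_logits[1], real_labels)
    errD_fake_uncond = cfg.TRAIN.COEFF.UNCOND_LOSS * \
        criterion(fake_logits[1], fake_labels)
    #
    errD_real = errD_real + errD_real_uncond
    errD_wrong = errD_wrong + errD_wrong_uncond
    errD_fake = errD_fake + errD_fake_uncond
    #
    errD = errD_real + errD_wrong + errD_fake
else:
    errD = errD_real + 0.5 * (errD_wrong + errD_fake)
# Backward pass and parameter update
errD.backward()
optD.step()
```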

## Gnet

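```python
# Excerpt from the StackGAN++ generator update. The generator is trained
# against every discriminator (one per image scale). `logvar` is used after
# this loop (not shown) for the conditioning-augmentation KL term.
errG_total = 0
batch_size = self.real_imgs[0].size(0)
criterion, mu, logvar = self.criterion, self.mu, self.logvar
real_labels = self.real_labels[:batch_size]
for i in range(self.num_Ds):
    outputs = self.netsD[i](self.fake_imgs[i], mu)
    # Conditional head: fakes should be scored as real and matching.
    errG = criterion(outputs[0], real_labels)
    if len(outputs) > 1 and cfg.TRAIN.COEFF.UNCOND_LOSS > 0:
        # Unconditional head: fakes should also look real on their own.
        errG_patch = cfg.TRAIN.COEFF.UNCOND_LOSS *\
            criterion(outputs[1], real_labels)
        errG = errG + errG_patch
    errG_total = errG_total + errG
```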

## Dnet Model End

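```python
import torch

# End of a StackGAN++ discriminator's forward(); x_code holds the image
# features and c_code the conditioning vector (or None).
x_code = self.img_code_s256_4(x_code)

if cfg.GAN.B_CONDITION and c_code is not None:
    # Tile the condition over the 4 x 4 spatial grid of the image features.
    c_code = c_code.view(-1, self.ef_dim, 1, 1)
    c_code = c_code.repeat(1, 1, 4, 4)
    # state size (ngf+egf) x 4 x 4
    h_c_code = torch.cat((c_code, x_code), 1)
    # state size ngf x in_size x in_size
    h_c_code = self.jointConv(h_c_code)
else:
    h_c_code = x_code

# Conditional head: scores "real and matching the condition".
output = self.logits(h_c_code)
if cfg.GAN.B_CONDITION:
    # Unconditional head: scores "real" from the image features alone.
    out_uncond = self.uncond_logits(x_code)
    return [output.view(-1), out_uncond.view(-1)]
else:
    return [output.view(-1)]
```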