Skip to content

Commit

Permalink
Introduce convenience class for conducting GAN experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnVinyard committed Feb 17, 2018
1 parent 5de5f5f commit e7d7df2
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
19 changes: 17 additions & 2 deletions zounds/learn/dct_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def __init__(self, use_cuda=False):
super(DctTransform, self).__init__()
self.use_cuda = use_cuda
self._basis_cache = dict()
self._window_cache = dict()

def _variable(self, x, *args, **kwargs):
v = Variable(x, *args, **kwargs)
Expand All @@ -20,9 +21,21 @@ def dct_basis(self, n):
return self._basis_cache[n]
except KeyError:
basis = torch.from_numpy(dct_basis(n)).float()
if self.use_cuda:
basis = basis.cuda()
self._basis_cache[n] = basis
return basis

def window(self, n, window):
try:
return self._window_cache[n]
except KeyError:
data = torch.from_numpy(window._wdata(n)).float()
if self.use_cuda:
data = data.cuda()
self._window_cache[n] = data
return data

def _base_dct_transform(self, x, basis, axis=-1):
n = torch.FloatTensor(1)
n[:] = 2. / x.shape[axis]
Expand All @@ -43,7 +56,7 @@ def dct_resample(self, x, factor, axis=-1):
# figure out how many samples our resampled signal will have
n_samples = int(factor * x.shape[axis])

coeffs = self.dct_basis(x)
coeffs = self.dct(x)

# create the shape of our new coefficients
new_coeffs_shape = list(coeffs.shape)
Expand All @@ -64,9 +77,11 @@ def dct_resample(self, x, factor, axis=-1):

return self.idct(new_coeffs)

def short_time_dct(self, x, size, step):
def short_time_dct(self, x, size, step, window):
original_shape = x.shape
x = x.unfold(-1, size, step)
window = self._variable(self.window(x.shape[-1], window))
x = x * window
x = self.dct(x, axis=-1)
x = x.view((original_shape[0], size, x.shape[2]))
return x
33 changes: 30 additions & 3 deletions zounds/learn/gan_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from util import from_var
from random import choice
from zounds.spectral import stft, rainbowgram
from zounds.learn import try_network
from zounds.timeseries import SR11025, SampleRate, Seconds, AudioSamples
from wgan import WassersteinGanTrainer
from pytorch_model import PyTorchGan
Expand All @@ -28,13 +29,15 @@ def __init__(
batch_size=32,
n_samples=int(5e5),
latent_dim=100,
real_sample_transformer=lambda x: x,
debug_gradients=False,
sample_size=8192,
sample_hop=1024,
samplerate=SR11025(),
app_port=8888):

super(GanExperiment, self).__init__()
self.real_sample_transformer = real_sample_transformer
self.debug_gradients = debug_gradients
self.n_samples = n_samples
self.batch_size = batch_size
Expand Down Expand Up @@ -94,23 +97,43 @@ def batch_complete(self, epoch, network, samples):

def fake_audio(self):
sample = choice(self.fake_samples)
return AudioSamples(sample, self.samplerate)\
return AudioSamples(sample, self.samplerate) \
.pad_with_silence(Seconds(1))

def fake_stft(self):
samples = self.fake_audio()
def _stft(self, samples):
samples = samples / np.abs(samples.max())
wscheme = SampleRate(
frequency=samples.samplerate.frequency * 128,
duration=samples.samplerate.frequency * 256)
coeffs = stft(samples, wscheme, HanningWindowingFunc())
return rainbowgram(coeffs)

def fake_stft(self):
samples = self.fake_audio()
return self._stft(samples)

def real_stft(self):
snd = self.sound_cls.random()
windowed = choice(snd.windowed)
return self._stft(windowed)

def test(self):
z = np.random.normal(
0, 1, (self.batch_size, self.latent_dim)).astype(np.float32)
samples = try_network(self.gan_pair.generator, z)
samples = from_var(samples)
print samples.shape
wasserstein_estimate = try_network(self.gan_pair.discriminator, samples)
print wasserstein_estimate.shape

def run(self):
ingest(self.dataset, self.sound_cls, multi_threaded=True)

experiment = self
fake_audio = self.fake_audio
fake_stft = self.fake_stft
real_stft = self.real_stft
Sound = self.sound_cls

self.app = ZoundsApp(
model=self.sound_cls,
Expand All @@ -135,6 +158,10 @@ def run(self):
on_batch_complete=self.batch_complete,
debug_gradient=self.debug_gradients)

def gen():
for snd in self.sound_cls:
yield self.real_sample_transformer(snd.windowed)

self.gan_pipeline.process(
samples=(snd.windowed for snd in self.sound_cls),
trainer=trainer,
Expand Down
7 changes: 6 additions & 1 deletion zounds/learn/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ def __init__(
batch_size,
holdout_percent=0.0,
data_preprocessor=lambda x: x,
label_preprocessor=lambda x: x):
label_preprocessor=lambda x: x,
on_batch_complete=None):

super(SupervisedTrainer, self).__init__(
epochs,
batch_size)

self.on_batch_complete = on_batch_complete
self.label_preprocessor = label_preprocessor
self.data_preprocessor = data_preprocessor
self.holdout_percent = holdout_percent
Expand Down Expand Up @@ -56,6 +58,9 @@ def batch(d, l, test=False):
error.backward()
self.optimizer.step()

if self.on_batch_complete:
self.on_batch_complete(inp_v, output)

return error.data[0]

for epoch in xrange(self.epochs):
Expand Down
9 changes: 5 additions & 4 deletions zounds/learn/wgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ def _gradient_penalty(self, real_samples, fake_samples, kwargs):

# computing the norm of the gradients is very expensive, so I'm only
# taking a subset of the minibatch here
subset_size = min(10, real_samples.shape[0])
# subset_size = min(10, real_samples.shape[0])
subset_size = real_samples.shape[0]

real_samples = real_samples[:subset_size]
fake_samples = fake_samples[:subset_size]

# TODO: this should have the same number of dimensions as real and
# fake samples, and should not be hard-coded
alpha = torch.rand(subset_size).cuda()
alpha = alpha.view((-1,) + ((1,) * (real_samples.dim() - 1)))

Expand Down Expand Up @@ -190,6 +189,7 @@ def train(self, data):
d_loss.backward()
critic_optim.step()


self.zero_discriminator_gradients()
self.zero_generator_gradients()

Expand All @@ -213,13 +213,14 @@ def train(self, data):

gl = g_loss.data[0]
dl = d_loss.data[0]
rl = real_mean.data[0]

if self.on_batch_complete:
self.on_batch_complete(epoch, self.network, self.samples)

if i % 10 == 0:
print \
'Epoch {epoch}, batch {i}, generator {gl}, critic {dl}' \
'Epoch {epoch}, batch {i}, generator {gl}, real {rl}, critic {dl}' \
.format(**locals())

return self.network

0 comments on commit e7d7df2

Please sign in to comment.