Skip to content

Commit

Permalink
GanExperiment now uses the training monitor app
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnVinyard committed Feb 27, 2018
1 parent 60dd7f6 commit 01bd4c8
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 90 deletions.
3 changes: 2 additions & 1 deletion zounds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
infinite_streaming_learning_pipeline

from ui import \
ZoundsApp, ZoundsSearch, TrainingMonitorApp, RangeUnitUnsupportedException
ZoundsApp, ZoundsSearch, TrainingMonitorApp, SupervisedTrainingMonitoApp, \
GanTrainingMonitorApp, RangeUnitUnsupportedException

from index import \
SearchResults, HammingDb, HammingIndex, BruteForceSearch, \
Expand Down
118 changes: 58 additions & 60 deletions zounds/learn/gan_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
from util import from_var
from random import choice
from zounds.spectral import stft, rainbowgram
from zounds.learn import try_network
from zounds.learn import \
try_network, infinite_streaming_learning_pipeline, \
object_store_pipeline_settings
from zounds.timeseries import \
SR11025, SampleRate, Seconds, AudioSamples, audio_sample_rate
from wgan import WassersteinGanTrainer
from pytorch_model import PyTorchGan
from graph import learning_pipeline
from util import simple_settings
from preprocess import PreprocessingPipeline, InstanceScaling
from zounds.ui import ZoundsApp
from preprocess import InstanceScaling
from zounds.ui import GanTrainingMonitorApp
from zounds.util import simple_lmdb_settings
from zounds.basic import resampled
from zounds.spectral import HanningWindowingFunc, SlidingWindow
from zounds.basic import windowed
from zounds.spectral import HanningWindowingFunc
from zounds.datasets import ingest
from zounds.persistence import ArrayWithUnitsFeature
import numpy as np


Expand All @@ -25,6 +24,8 @@ def __init__(
experiment_name,
dataset,
gan_pair,
object_storage_username,
object_storage_api_key,
epochs=500,
n_critic_iterations=10,
batch_size=32,
Expand All @@ -35,7 +36,9 @@ def __init__(
sample_size=8192,
sample_hop=1024,
samplerate=SR11025(),
app_port=8888):
app_port=8888,
object_storage_region='DFW',
app_secret=None):

super(GanExperiment, self).__init__()
self.real_sample_transformer = real_sample_transformer
Expand All @@ -53,47 +56,41 @@ def __init__(
self.sample_size = sample_size
self.latent_dim = latent_dim
self.experiment_name = experiment_name
self.app_secret = app_secret

base_model = resampled(
resample_to=self.samplerate, store_resampled=True)

window_sample_rate = SampleRate(
frequency=self.samplerate.frequency * sample_hop,
duration=samplerate.frequency * sample_size)
base_model = windowed(
resample_to=self.samplerate,
store_resampled=True,
wscheme=self.samplerate * (sample_hop, sample_size))

@simple_lmdb_settings(
experiment_name, map_size=1e11, user_supplied_id=True)
class Sound(base_model):
windowed = ArrayWithUnitsFeature(
SlidingWindow,
wscheme=window_sample_rate,
needs=base_model.resampled)
pass

self.sound_cls = Sound

base_pipeline = learning_pipeline()

@simple_settings
class Gan(base_pipeline):
@object_store_pipeline_settings(
'Gan-{experiment_name}'.format(**locals()),
object_storage_region,
object_storage_username,
object_storage_api_key)
@infinite_streaming_learning_pipeline
class Gan(ff.BaseModel):
scaled = ff.PickleFeature(
InstanceScaling,
needs=base_pipeline.shuffled)
InstanceScaling)

wgan = ff.PickleFeature(
PyTorchGan,
trainer=ff.Var('trainer'),
needs=scaled)

pipeline = ff.PickleFeature(
PreprocessingPipeline,
needs=(scaled, wgan),
store=True)

self.gan_pipeline = Gan()
self.fake_samples = None
self.app = None

def batch_complete(self, epoch, network, samples):
# Trainer callback invoked after every batch. Signature is (*args, **kwargs)
# so it tolerates whatever keyword set WassersteinGanTrainer passes; only the
# 'samples' kwarg (latest generator output) is used here.
def batch_complete(self, *args, **kwargs):
samples = kwargs['samples']
# Detach from autograd / move off GPU via from_var, drop singleton dims, and
# stash the result so fake_audio() and the monitor app can render it.
self.fake_samples = from_var(samples).squeeze()

def fake_audio(self):
Expand Down Expand Up @@ -139,37 +136,38 @@ def run(self):
real_stft = self.real_stft
Sound = self.sound_cls

self.app = ZoundsApp(
model=self.sound_cls,
audio_feature=self.sound_cls.ogg,
visualization_feature=self.sound_cls.windowed,
try:
network = self.gan_pipeline.load_network()
print 'initialized weights'
except RuntimeError as e:
print 'Error', e
network = self.gan_pair
for p in network.parameters():
p.data.normal_(0, 0.02)

trainer = WassersteinGanTrainer(
network,
latent_dimension=(self.latent_dim,),
n_critic_iterations=self.n_critic_iterations,
epochs=self.epochs,
batch_size=self.batch_size,
debug_gradient=self.debug_gradients)
trainer.register_batch_complete_callback(self.batch_complete)

self.app = GanTrainingMonitorApp(
trainer=trainer,
model=Sound,
visualization_feature=Sound.windowed,
audio_feature=Sound.ogg,
globals=globals(),
locals=locals())
locals=locals(),
secret=self.app_secret)

with self.app.start_in_thread(self.app_port):
if not self.gan_pipeline.exists():
network = self.gan_pair

for p in network.parameters():
p.data.normal_(0, 0.02)

trainer = WassersteinGanTrainer(
network,
latent_dimension=(self.latent_dim,),
n_critic_iterations=self.n_critic_iterations,
epochs=self.epochs,
batch_size=self.batch_size,
on_batch_complete=self.batch_complete,
debug_gradient=self.debug_gradients)

def gen():
for snd in self.sound_cls:
yield self.real_sample_transformer(snd.windowed)

self.gan_pipeline.process(
samples=(snd.windowed for snd in self.sound_cls),
trainer=trainer,
nsamples=self.n_samples,
dtype=np.float32)
self.gan_pipeline.process(
dataset=(Sound, Sound.windowed),
trainer=trainer,
nsamples=self.n_samples,
dtype=np.float32)

self.app.start(self.app_port)
76 changes: 48 additions & 28 deletions zounds/learn/wgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class WassersteinGanTrainer(Trainer):
epoch, and a minibatch, and mutates the minibatch
kwargs_factory (callable): function that takes the current epoch and
outputs args to pass to the generator and discriminator
on_batch_complete (callable): callable invoked after each epoch,
accepting epoch and network being trained as arguments
"""

def __init__(
Expand All @@ -29,12 +27,12 @@ def __init__(
batch_size,
preprocess_minibatch=None,
kwargs_factory=None,
on_batch_complete=None,
debug_gradient=False):
debug_gradient=False,
checkpoint_epochs=1):

super(WassersteinGanTrainer, self).__init__(epochs, batch_size)
self.checkpoint_epochs = checkpoint_epochs
self.debug_gradient = debug_gradient
self.on_batch_complete = on_batch_complete
self.arg_maker = kwargs_factory
self.preprocess = preprocess_minibatch
self.n_critic_iterations = n_critic_iterations
Expand All @@ -43,6 +41,16 @@ def __init__(
self.critic = network.discriminator
self.generator = network.generator
self.samples = None
self.register_batch_complete_callback(self._log)
self.generator_optim = None
self.critic_optim = None

# Default batch-complete callback (registered in __init__): prints a one-line
# training summary. Expects the kwargs emitted by train(): epoch, batch,
# generator_score, real_score, critic_loss.
def _log(self, *args, **kwargs):
# Only log every 10th batch; any non-zero remainder means skip.
if kwargs['batch'] % 10:
return
msg = 'Epoch {epoch}, batch {batch}, generator {generator_score}, ' \
'real {real_score}, critic {critic_loss}'
print msg.format(**kwargs)

def _minibatch(self, data):
indices = np.random.randint(0, len(data), self.batch_size)
Expand All @@ -58,9 +66,6 @@ def _gradient_penalty(self, real_samples, fake_samples, kwargs):

real_samples = real_samples.view(fake_samples.shape)

# computing the norm of the gradients is very expensive, so I'm only
# taking a subset of the minibatch here
# subset_size = min(10, real_samples.shape[0])
subset_size = real_samples.shape[0]

real_samples = real_samples[:subset_size]
Expand Down Expand Up @@ -116,15 +121,29 @@ def zero_discriminator_gradients(self):
self._debug_network_gradient(self.critic)
self.critic.zero_grad()

def _init_optimizers(self):
    """Lazily build the Adam optimizers for the generator and critic.

    Once both optimizers exist this is a no-op, so accumulated optimizer
    state is reused when train() is called again for another checkpoint
    span. Only parameters with requires_grad set are handed to Adam.
    """
    # Guard clause: nothing to do when both optimizers are already built.
    if self.generator_optim is not None and self.critic_optim is not None:
        return

    from torch.optim import Adam

    gen_params = (
        p for p in self.generator.parameters() if p.requires_grad)
    disc_params = (
        p for p in self.critic.parameters() if p.requires_grad)

    # Hyperparameters follow the WGAN-GP paper: lr=1e-4, betas=(0, 0.9).
    self.generator_optim = Adam(gen_params, lr=0.0001, betas=(0, 0.9))
    self.critic_optim = Adam(disc_params, lr=0.0001, betas=(0, 0.9))

def train(self, data):

import torch
from torch.optim import Adam
from torch.autograd import Variable

data = data.astype(np.float32)
self.network.train()
self.unfreeze_discriminator()
self.unfreeze_generator()

zdim = self.latent_dimension
data = data.astype(np.float32)

noise_shape = (self.batch_size,) + self.latent_dimension
noise = torch.FloatTensor(*noise_shape)
Expand All @@ -134,17 +153,15 @@ def train(self, data):
self.critic.cuda()
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

trainable_generator_params = (
p for p in self.generator.parameters() if p.requires_grad)
trainable_critic_params = (
p for p in self.critic.parameters() if p.requires_grad)
self._init_optimizers()

start = self._current_epoch
stop = self._current_epoch + self.checkpoint_epochs

generator_optim = Adam(
trainable_generator_params, lr=0.0001, betas=(0, 0.9))
critic_optim = Adam(
trainable_critic_params, lr=0.0001, betas=(0, 0.9))
for epoch in xrange(start, stop):
if epoch >= self.epochs:
break

for epoch in xrange(self.epochs):
if self.arg_maker:
kwargs = self.arg_maker(epoch)
else:
Expand Down Expand Up @@ -187,7 +204,7 @@ def train(self, data):
gp = self._gradient_penalty(input_v.data, fake.data, kwargs)
d_loss = (fake_mean - real_mean) + gp
d_loss.backward()
critic_optim.step()
self.critic_optim.step()

self.zero_discriminator_gradients()
self.zero_generator_gradients()
Expand All @@ -208,18 +225,21 @@ def train(self, data):
d_fake = self.critic.forward(fake, **kwargs)
g_loss = -torch.mean(d_fake)
g_loss.backward()
generator_optim.step()
self.generator_optim.step()

gl = g_loss.data[0]
dl = d_loss.data[0]
rl = real_mean.data[0]

if self.on_batch_complete:
self.on_batch_complete(epoch, self.network, self.samples)
self.on_batch_complete(
epoch=epoch,
batch=i,
generator_score=gl,
real_score=rl,
critic_loss=dl,
samples=self.samples,
network=self.network)

if i % 10 == 0:
print \
'Epoch {epoch}, batch {i}, generator {gl}, real {rl}, critic {dl}' \
.format(**locals())
self._current_epoch += 1

return self.network
3 changes: 2 additions & 1 deletion zounds/ui/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from contentrange import RangeUnitUnsupportedException
from api import ZoundsApp
from search import ZoundsSearch
from training_monitor import TrainingMonitorApp
from training_monitor import \
TrainingMonitorApp, SupervisedTrainingMonitoApp, GanTrainingMonitorApp

0 comments on commit 01bd4c8

Please sign in to comment.