Skip to content

Commit

Permalink
Ensure that WGAN trainer doesn't move data to the GPU unless instruct…
Browse files Browse the repository at this point in the history
…ed to do so
  • Loading branch information
JohnVinyard committed Apr 12, 2018
1 parent baf0f84 commit 5f3f75b
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions zounds/learn/wgan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from trainer import Trainer
import numpy as np
import torch
from torch.autograd import Variable


class WassersteinGanTrainer(Trainer):
Expand Down Expand Up @@ -71,18 +73,26 @@ def _gradient_penalty(self, real_samples, fake_samples, kwargs):
real_samples = real_samples[:subset_size]
fake_samples = fake_samples[:subset_size]

alpha = torch.rand(subset_size).cuda()
alpha = torch.rand(subset_size)
if self.use_cuda:
alpha = alpha.cuda()
alpha = alpha.view((-1,) + ((1,) * (real_samples.dim() - 1)))

interpolates = alpha * real_samples + ((1 - alpha) * fake_samples)
interpolates = Variable(interpolates.cuda(), requires_grad=True)
interpolates = Variable(interpolates, requires_grad=True)
if self.use_cuda:
interpolates = interpolates.cuda()

d_output = self.critic(interpolates, **kwargs)

grad_ouputs = torch.ones(d_output.size())
if self.use_cuda:
grad_ouputs = grad_ouputs.cuda()

gradients = grad(
outputs=d_output,
inputs=interpolates,
grad_outputs=torch.ones(d_output.size()).cuda(),
grad_outputs=grad_ouputs,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
Expand Down Expand Up @@ -134,10 +144,11 @@ def _init_optimizers(self):
self.critic_optim = Adam(
trainable_critic_params, lr=0.0001, betas=(0, 0.9))

def train(self, data):
def _cuda(self, device=None):
self.generator = self.generator.cuda()
self.critic = self.critic.cuda()

import torch
from torch.autograd import Variable
def train(self, data):

self.network.train()
self.unfreeze_discriminator()
Expand All @@ -146,12 +157,7 @@ def train(self, data):
data = data.astype(np.float32)

noise_shape = (self.batch_size,) + self.latent_dimension
noise = torch.FloatTensor(*noise_shape)
fixed_noise = torch.FloatTensor(*noise_shape).normal_(0, 1)

self.generator.cuda()
self.critic.cuda()
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
noise = self._tensor(noise_shape)

self._init_optimizers()

Expand All @@ -178,10 +184,7 @@ def train(self, data):

self.zero_discriminator_gradients()

minibatch = self._minibatch(data)
inp = torch.from_numpy(minibatch)
inp = inp.cuda()
input_v = Variable(inp)
input_v = self._variable(self._minibatch(data))

if self.preprocess:
input_v = self.preprocess(epoch, input_v)
Expand Down

0 comments on commit 5f3f75b

Please sign in to comment.