In [1]:
from layers import *
from lasagne.layers import *
import lasagne

from lasagne.nonlinearities import LeakyRectify

from collections import OrderedDict

import theano
import theano.tensor as T

from nets import ResGenerator, Discriminator

from theano.tensor.shared_randomstreams import RandomStreams
import numpy as np

# Seed every RNG this notebook draws from, so Restart & Run All reproduces
# both the initial weights and the random interpolation coefficients.
SEED = 234
# Symbolic Theano random stream (used by WGANLosses for the WGAN-GP
# interpolation coefficients).
srng = RandomStreams(seed=SEED)
# Lasagne initialisers draw from lasagne's package-level RNG, which is not
# seeded anywhere else in the visible notebook; pin it as well.
lasagne.random.set_rng(np.random.RandomState(SEED))

## Define the networks

In [2]:
# Symbolic mini-batches from domain A (a) and domain B (b).
# NOTE(review): tensor4 fixes rank 4; presumably (batch, channels, height, width)
# as is usual for lasagne conv nets -- confirm against the data loader.
realA = T.tensor4("Real A input")
realB = T.tensor4("Real B input")

# Generators: G = genA2B maps a -> b_hat, F = genB2A maps b -> a_hat.
# NOTE(review): the meaning of filter_size=(64, 64) is defined in
# nets.ResGenerator -- confirm it is the intended setting.
# Construction order is kept as-is: weight initialisation presumably draws
# from a shared RNG, so reordering could change the initial weights.
genA2B = ResGenerator(filter_size=(64, 64), input_var=realA)
genB2A = ResGenerator(filter_size=(64, 64), input_var=realB)

# Critics (wasserstein=True): D_a scores real a against F(b), D_b scores real b
# against G(a). Each critic receives the generator whose output it judges.
discA = Discriminator(genB2A, wasserstein=True, real_inp_var=realA)
discB = Discriminator(genA2B, wasserstein=True, real_inp_var=realB)

# Cycle reconstructions: a_cyc = F(G(a)) and b_cyc = G(F(b)), built by feeding
# one generator's output expression through the other generator.
cycleA = genB2A.get_output(genA2B.output_var)
cycleB = genA2B.get_output(genB2A.output_var)

## Losses, hyperparameters and training functions

In [3]:
# TODO: turn into function

class WGANLosses:
    """WGAN-GP critic and generator losses for one generator/critic pair.

    Builds symbolic Theano expressions following Gulrajani et al. (2017),
    "Improved Training of Wasserstein GANs":

        disc_loss = E[D(fake)] - E[D(real)]
                    + lambd * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2]
        gen_loss  = -E[D(fake)]

    where x_hat = eps * real + (1 - eps) * fake, with one eps ~ U[0, 1]
    drawn per sample.

    Parameters
    ----------
    gen : ResGenerator
        Generator whose ``output_var`` is the fake batch judged by ``disc``.
    disc : Discriminator
        Critic exposing ``real_inp_var``, ``real_out`` and ``fake_out``.
        NOTE(review): also assumed to expose ``layers["out"]`` (as ResGenerator
        does) with a single InputLayer that ``get_output`` can feed -- confirm.
    lambd : float
        Weight of the gradient penalty (10 in the paper).

    Attributes
    ----------
    epsilon : TensorVariable
        Per-sample interpolation coefficients, broadcastable over all
        non-batch axes.
    disc_loss, gen_loss : TensorVariable
        Scalar losses to minimise for the critic and generator respectively.
    """

    def __init__(self, gen, disc, lambd=10):
        real = disc.real_inp_var
        fake = gen.output_var

        # One coefficient per sample (low=0, high=1), broadcast over the
        # remaining axes so every sample gets its own interpolation point.
        n_samples = real.shape[0]
        bcast_pattern = [0] + ["x"] * (real.ndim - 1)
        self.epsilon = srng.uniform((n_samples,)).dimshuffle(*bcast_pattern)

        x_hat = self.epsilon * real + (1 - self.epsilon) * fake

        # The penalty constrains the *critic's* gradient, so score x_hat with
        # the critic. Per-sample score = mean over that sample's outputs
        # (matching the .mean() used for real_out/fake_out below). Summing
        # rather than averaging over the batch leaves each sample's gradient
        # unscaled by the batch size. This relies on the critic treating
        # samples independently (e.g. no batch norm), as the paper recommends.
        critic_out = get_output(disc.layers["out"], x_hat)
        per_sample_score = critic_out.reshape((critic_out.shape[0], -1)).mean(axis=1)
        x_hat_grad = T.grad(per_sample_score.sum(), x_hat)

        # L2 norm of each sample's gradient over all non-batch axes, penalised
        # towards 1 (1-Lipschitz target).
        grad_norm = T.sqrt(T.sum(x_hat_grad ** 2, axis=tuple(range(1, x_hat.ndim))))
        gradient_penalty = ((grad_norm - 1) ** 2).mean()

        # Wasserstein estimate E[D(real)] - E[D(fake)]; the critic maximises it.
        wgan_loss = disc.real_out.mean() - disc.fake_out.mean()

        self.disc_loss = -wgan_loss + lambd * gradient_penalty
        self.gen_loss = -disc.fake_out.mean()

In [4]:
# Hyperparams
# Adam optimiser settings; keys match lasagne.updates.adam's keyword arguments.
adam = dict(learning_rate=1e-4, beta1=0.5, beta2=0.9)
# Gradient-penalty weight intended for WGANLosses(lambd=...).
lambd = 10

In [28]:
# Cycle-consistency L1 loss. Averaging each domain separately equals the
# mean of the summed terms when shapes match, without requiring the two
# domains' batches to share a shape.
cycle_loss = T.abs_(realA - cycleA).mean() + T.abs_(realB - cycleB).mean()

# Adversarial WGAN-GP losses: G (A->B) is judged by D_b, F (B->A) by D_a.
# Pass the penalty weight from the hyperparameter cell explicitly.
a_loss = WGANLosses(genA2B, discB, lambd=lambd)
b_loss = WGANLosses(genB2A, discA, lambd=lambd)

gen_comb_loss = a_loss.gen_loss + b_loss.gen_loss + cycle_loss
disc_comb_loss = a_loss.disc_loss + b_loss.disc_loss

# Forward the Adam settings; without **adam lasagne silently falls back to
# its defaults (learning_rate=1e-3, beta1=0.9, beta2=0.999).
gen_updates = lasagne.updates.adam(gen_comb_loss, genA2B.params + genB2A.params, **adam)
disc_updates = lasagne.updates.adam(disc_comb_loss, discA.params + discB.params, **adam)

# Each step takes one batch per domain, returns its loss and applies one Adam
# update to its own parameter set only.
trainGen = theano.function([realA, realB], gen_comb_loss, updates=gen_updates)
trainDisc = theano.function([realA, realB], disc_comb_loss, updates=disc_updates)

{12: 9,
 17: 9,
 18: 9,
 20: 9,
 22: u'cycle_loss = (T.abs_(realA - cycleA) + T.abs_(realB - cycleB)).mean()\n\na_loss = WGANLosses(genA2B, discB)\nb_loss = WGANLosses(genB2A, discA)\n\ngen_comb_loss = a_loss.gen_loss + b_loss.gen_loss + cycle_loss\ndisc_comb_loss = a_loss.disc_loss + b_loss.disc_loss\n\ngen_updates = lasagne.updates.adam(gen_comb_loss, genA2B.params + genB2A.params)\ndisc_updates = lasagne.updates.adam(disc_comb_loss, discA.params + discB.params)\n\ntrainGen = theano.function([realA, realB], gen_comb_loss, updates=gen_updates)'}