In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
# import torchvision.datasets as dset
# import torchvision.transforms as transforms
# import torchvision.utils as vutils
from torch.autograd import Variable


class PhaseShuffle(nn.Module):
    """
    Performs phase shuffling, i.e. shifts the last (feature) axis of a 3D
    tensor by a random integer k drawn uniformly from [-n, n], filling the
    vacated positions with reflection padding so the output keeps the shape
    of the input.

    One offset is drawn per forward call and applied to the whole batch.

    Args:
        n: maximum absolute shift (non-negative integer); n = 0 makes the
           layer an identity.

    Raises:
        ValueError: if n is negative.
    """
    def __init__(self, n):
        super(PhaseShuffle, self).__init__()
        if n < 0:
            raise ValueError("n must be non-negative, got {}".format(n))
        self.n = n

    def forward(self, x):
        """
        Args:
            x: 3D tensor; the last axis is the one that gets shifted.

        Returns:
            x itself when the drawn shift is 0, otherwise a new tensor of
            the same shape shifted by k positions along the last axis.
        """
        # Draw the offset with PyTorch's own RNG so it follows the same seed
        # as the rest of the model. random_(0, 2n + 1) yields 0..2n, which is
        # then re-centred onto -n..n.
        k = int(torch.Tensor(1).random_(0, 2 * self.n + 1)) - self.n

        # Return if no phase shift
        if k == 0:
            return x

        if k > 0:
            # Shift right: drop the last k samples, pad k on the left
            x_trunc = x[:, :, :-k]
            pad = (k, 0)
        else:
            # Shift left: drop the first |k| samples, pad |k| on the right
            x_trunc = x[:, :, -k:]
            pad = (0, -k)

        # Reflection padding restores the original length
        x_shuffle = F.pad(x_trunc, pad, mode='reflect')
        assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape, x.shape)
        return x_shuffle
        


class WaveGANGenerator(nn.Module):
    """
    WaveGAN generator: maps a latent vector to a waveform through one
    fully-connected layer followed by five stride-4 transposed convolutions.

    Args:
        d: model-size multiplier; channel counts are multiples of d.
        ngpu: stored on the instance; not used by this class itself.
        c: number of output channels.
        latent_dim: size of the input latent vector.
        verbose: when True, print the tensor shape after each stage.

    Shapes:
        input  (batch, latent_dim)
        output (batch, c, 16384), with values in [-1, 1] from the final tanh
    """
    def __init__(self, d, ngpu, c=1, latent_dim=100, verbose=False):
        super(WaveGANGenerator, self).__init__()
        self.ngpu = ngpu
        self.d = d
        self.c = c
        self.latent_dim = latent_dim
        # Size the projection from latent_dim so non-default latent sizes work
        self.fc1 = nn.Linear(latent_dim, 256*d)
        self.verbose = verbose
        # NOTE: with kernel 25, stride 4, padding 11 and output_padding 1 each
        # layer multiplies the length by exactly 4 (16 -> 16384 over 5 layers).
        # The padding was added to make dimensions match; worth re-checking.
        self.tconv1 = nn.ConvTranspose1d(16*d, 8*d, 25, stride=4, padding=11, output_padding=1)
        self.tconv2 = nn.ConvTranspose1d(8*d, 4*d, 25, stride=4, padding=11, output_padding=1)
        self.tconv3 = nn.ConvTranspose1d(4*d, 2*d, 25, stride=4, padding=11, output_padding=1)
        self.tconv4 = nn.ConvTranspose1d(2*d, d, 25, stride=4, padding=11, output_padding=1)
        self.tconv5 = nn.ConvTranspose1d(d, c, 25, stride=4, padding=11, output_padding=1)

    def _log_shape(self, x):
        """Print the shape of x when verbose mode is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        """Generate a batch of waveforms from a batch of latent vectors."""
        x = F.relu(self.fc1(x))

        # Reshape the flat 256*d projection into (batch, 16*d channels, 16 steps)
        x = x.view(-1, 16*self.d, 16)
        self._log_shape(x)

        # Four ReLU upsampling stages
        for tconv in (self.tconv1, self.tconv2, self.tconv3, self.tconv4):
            x = F.relu(tconv(x))
            self._log_shape(x)

        # torch.tanh replaces the deprecated F.tanh; the result is identical
        output = torch.tanh(self.tconv5(x))
        self._log_shape(output)

        return output

    
class WaveGANDiscriminator(nn.Module):
    """
    WaveGAN discriminator: five stride-4 1D convolutions, with phase
    shuffling after each of the first four, followed by a linear layer and
    a sigmoid that yields one score per input.

    Args:
        d: model-size multiplier; channel counts are multiples of d.
        c: number of input channels.
        n: maximum phase-shuffle shift passed to each PhaseShuffle layer.
        verbose: when True, print the tensor shape after each stage.

    Shapes:
        input  (batch, c, 16384) -- the conv stack divides the length by 4
               five times, giving (batch, 16*d, 16), which is what the
               256*d flatten below requires
        output (batch, 1), values in (0, 1)
    """
    def __init__(self, d, c=1, n=2, verbose=False):
        super(WaveGANDiscriminator, self).__init__()
        self.d = d
        self.c = c
        self.n = n
        self.verbose = verbose
        # Conv1d(in_channels, out_channels, kernel_size, stride=1, etc.)
        self.conv1 = nn.Conv1d(c, d, 25, stride=4, padding=11)
        self.conv2 = nn.Conv1d(d, 2*d, 25, stride=4, padding=11)
        self.conv3 = nn.Conv1d(2*d, 4*d, 25, stride=4, padding=11)
        self.conv4 = nn.Conv1d(4*d, 8*d, 25, stride=4, padding=11)
        self.conv5 = nn.Conv1d(8*d, 16*d, 25, stride=4, padding=11)
        self.ps1 = PhaseShuffle(n)
        self.ps2 = PhaseShuffle(n)
        self.ps3 = PhaseShuffle(n)
        self.ps4 = PhaseShuffle(n)
        self.fc1 = nn.Linear(256*d, 1)

    def _log_shape(self, x):
        """Print the shape of x when verbose mode is enabled."""
        if self.verbose:
            print(x.shape)

    def forward(self, x):
        """Score a batch of waveforms; returns one sigmoid output per item."""
        # NOTE(review): F.leaky_relu is called with its default
        # negative_slope (0.01) -- confirm that is the intended slope.
        stages = (
            (self.conv1, self.ps1),
            (self.conv2, self.ps2),
            (self.conv3, self.ps3),
            (self.conv4, self.ps4),
        )
        for conv, shuffle in stages:
            x = F.leaky_relu(conv(x))
            self._log_shape(x)
            x = shuffle(x)

        # The last conv has no activation and no phase shuffle
        x = self.conv5(x)
        self._log_shape(x)

        # Flatten (batch, 16*d, 16) into (batch, 256*d) for the linear layer
        x = x.view(-1, 256*self.d)
        self._log_shape(x)

        # torch.sigmoid replaces the deprecated F.sigmoid; the result is identical
        return torch.sigmoid(self.fc1(x))

In [2]:
# Build the generator (size multiplier d=64, one output channel) from a
# 100-dimensional latent space, then move its parameters to the GPU
latent_dim = 100
gen = WaveGANGenerator(d=64, ngpu=0, c=1, latent_dim=latent_dim)
gen = gen.cuda()

In [3]:
# Build the discriminator with the same size multiplier (d=64) and move
# its parameters to the GPU
disc = WaveGANDiscriminator(d=64)
disc = disc.cuda()

In [4]:
# Draw a batch of 5 latent vectors from the noise distribution p(z)
# (each entry uniform between 0 and 1) and place the batch on the GPU.
# NOTE(review): confirm that the 0..1 range is the prior intended for training.
z = Variable(torch.Tensor(5, latent_dim).uniform_(0, 1)).cuda()

In [5]:
# Run the latent batch z through the generator to obtain samples
# from the model distribution (one generated output per latent vector)
out = gen(z)

In [6]:
# Score the generated waveforms with the discriminator; as the cell's last
# expression, the resulting per-sample scores are displayed below
disc(out)

Variable containing:
 0.5006
 0.5006
 0.5006
 0.5006
 0.5006
[torch.cuda.FloatTensor of size 5x1 (GPU 0)]

In [None]:
# TODO: Write training code