# Model implementation 01

This notebook is a first pass at reproducing the Wassertain GAN model used 
in


Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Residual Blocks

The model makes use of residual blocks in both predictor and discriminator

In [2]:
# residual blocks used in generator and discriminator
class ResidualBlock(nn.Module):
    def __init__(self, d_in, d_out, relu): #pick nn.ReLU for generator, nn.LeakyRelu for discriminator
        super().__init__()
        self.relu_1 = relu()
        self.conv_1 = nn.Conv1d(in_channels=d_in, out_channels=d_out, kernel_size=5, padding = 2)
        self.relu_2 = relu()
        self.conv_2 = nn.Conv1d(in_channels=d_in, out_channels=d_out, kernel_size=5, padding = 2)
        # bias not needed, already present by default in Conv1d()
        
    def forward(self, x):
        y = self.relu_1(x)
        y = self.conv_1(y)
        y = self.relu_2(y)
        y = self.conv_2(y)
        return x + (0.3 * y)

Test this class with a simple example

In [3]:
resblock = ResidualBlock(100, 100, nn.ReLU)
print(resblock)

ResidualBlock(
  (relu_1): ReLU()
  (conv_1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu_2): ReLU()
  (conv_2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,))
)


In [4]:
x = torch.randn(1, 100, 100)
resblock(x)

tensor([[[-1.2654,  0.9001,  0.9767,  ...,  0.7684, -0.6504, -0.2017],
         [ 0.4391, -0.5020, -1.0044,  ..., -0.7929,  1.2517,  0.0854],
         [-0.8011, -1.2852, -0.9045,  ...,  0.2945,  0.3540,  0.1076],
         ...,
         [ 0.9124,  2.4432, -0.6559,  ..., -1.5265,  0.4473, -0.2616],
         [-0.2717,  0.5713,  0.6560,  ...,  2.0048, -1.6967, -1.1009],
         [ 0.2134,  0.3010,  0.1197,  ...,  0.9366,  0.5990, -0.9680]]],
       grad_fn=<AddBackward0>)

### Verify blocks work on intended input shape

Test input linear layer layout, with an input from latent space $Z$, then feed into resblock

In [5]:
z = torch.randn(100)
lin = nn.Linear(100, 100*50) # for sequence of 50 nucleotides

x = lin(z) 
x = torch.reshape(x, (1,50,100))


In [6]:
x.shape

torch.Size([1, 50, 100])

In [7]:
lin(z).shape

torch.Size([5000])

In [8]:
z = torch.randn(100)
lin = nn.Linear(100, 100*50) # for sequence of 50 nucleotides

x = lin(z) 

In [9]:
x.shape

torch.Size([5000])

## Generator

In [10]:
class Generator(nn.Module):
    def __init__(self, seq_len, batch_size):
        super(Generator, self).__init__()
        self.linear = nn.Linear(100, 100*seq_len)
        self.res = []
        for i in range(5):
            self.res.append(ResidualBlock(100, 100, relu=nn.ReLU))
        self.conv_f = nn.Conv1d(in_channels=100, out_channels=4, kernel_size=1)
        
        self.batch_size = batch_size
        self.seq_len = seq_len
    
    def forward(self,z):
        y = self.linear(z)
        y = torch.reshape(y, (self.batch_size, 100, self.seq_len))
        for i in range(5):
            y = self.res[i](y)
        y = self.conv_f(y)
        y = F.softmax(y, dim=1)
        return y

Test generator on a simple example

In [11]:
z = torch.randn(100)
generator = Generator(50, 1)
x = generator(z)
x.shape

torch.Size([1, 4, 50])

## Discriminator

In [12]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Cond1d(in_channels=4, out_channels=100)
    def forward(self, z):
        pass

Gradient penalty used in WGAN-WP architecture [GitHub source](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py)

In [13]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty