# Model implementation 01

This notebook is a first pass at implementing in PyTorch the Wassertain GAN model used in 
[Generating and designing DNA with deep generative models](https://arxiv.org/abs/1712.06148) (Killoran and al, 2017)



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Residual Blocks

The model makes use of residual blocks in both predictor and discriminator

In [2]:
# residual blocks used in generator and discriminator
class ResidualBlock(nn.Module):
    def __init__(self, d_in, d_out, relu): #pick nn.ReLU for generator, nn.LeakyRelu for discriminator
        super().__init__()
        self.relu_1 = relu()
        self.conv_1 = nn.Conv1d(in_channels=d_in, out_channels=d_out, kernel_size=5, padding = 2)
        self.relu_2 = relu()
        self.conv_2 = nn.Conv1d(in_channels=d_in, out_channels=d_out, kernel_size=5, padding = 2)
        # bias not needed, already present by default in Conv1d()
        
    def forward(self, x):
        y = self.relu_1(x)
        y = self.conv_1(y)
        y = self.relu_2(y)
        y = self.conv_2(y)
        return x + (0.3 * y)

Test this class with a simple example

In [3]:
resblock = ResidualBlock(100, 100, nn.ReLU)
print(resblock)

ResidualBlock(
  (relu_1): ReLU()
  (conv_1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu_2): ReLU()
  (conv_2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,))
)


In [4]:
x = torch.randn(1, 100, 100)
resblock(x)

tensor([[[ 2.1803,  1.4953,  1.8549,  ..., -1.3184,  0.8193, -0.3785],
         [-0.8168, -0.5969, -0.1555,  ..., -0.1371,  1.0691, -2.5855],
         [-1.0546,  1.9645,  0.5384,  ...,  0.3895,  1.5450, -0.4745],
         ...,
         [ 0.3187, -0.6233,  1.5000,  ...,  0.4980, -1.0958,  1.4441],
         [ 0.4771, -0.4912,  0.3552,  ...,  0.5405,  1.1066,  0.0172],
         [ 0.3780,  1.8369, -0.2172,  ..., -0.9122,  0.2721,  0.5916]]],
       grad_fn=<AddBackward0>)

### Verify blocks work on intended input shape

Test input linear layer layout, with an input from latent space $Z$, then feed into resblock

In [5]:
z = torch.randn(100)
lin = nn.Linear(100, 100*50) # for sequence of 50 nucleotides

x = lin(z) 
x = torch.reshape(x, (1,50,100))


In [6]:
x.shape

torch.Size([1, 50, 100])

In [7]:
lin(z).shape

torch.Size([5000])

In [8]:
z = torch.randn(100)
lin = nn.Linear(100, 100*50) # for sequence of 50 nucleotides

x = lin(z) 

In [9]:
x.shape

torch.Size([5000])

## Generator

In [10]:
class Generator(nn.Module):
    def __init__(self, seq_len, batch_size):
        super(Generator, self).__init__()
        self.linear_in = nn.Linear(in_features=100, out_features=100*seq_len)
        self.resblocks = []
        for _ in range(5):
            self.resblocks.append(ResidualBlock(100, 100, relu=nn.ReLU))
        self.conv_out = nn.Conv1d(in_channels=100, out_channels=4, kernel_size=1)
        
        self.batch_size = batch_size
        self.seq_len = seq_len
    
    def forward(self,z):
        y = self.linear_in(z)
        y = torch.reshape(y, (self.batch_size, 100, self.seq_len))
        for i in range(5):
            y = self.resblocks[i](y)
        y = self.conv_out(y)
        y = F.softmax(y, dim=1)
        return y

Test generator on a simple example

In [11]:
z = torch.randn(100)
generator = Generator(50, 1)
x = generator(z)
x.shape

torch.Size([1, 4, 50])

In [12]:
x

tensor([[[0.2510, 0.2497, 0.1958, 0.1733, 0.1547, 0.2807, 0.1279, 0.5038,
          0.4120, 0.1492, 0.2069, 0.3457, 0.1523, 0.2672, 0.2112, 0.3845,
          0.2338, 0.2684, 0.2222, 0.2089, 0.2416, 0.2203, 0.1738, 0.2467,
          0.2477, 0.2930, 0.2330, 0.1932, 0.1619, 0.2039, 0.3279, 0.1746,
          0.3710, 0.3533, 0.2747, 0.2756, 0.2343, 0.2874, 0.2969, 0.3586,
          0.2194, 0.1976, 0.1705, 0.2653, 0.3341, 0.2387, 0.2466, 0.2969,
          0.3488, 0.2826],
         [0.1645, 0.2566, 0.2439, 0.3421, 0.2916, 0.2685, 0.1815, 0.1640,
          0.2284, 0.1609, 0.4385, 0.1738, 0.5587, 0.2149, 0.2189, 0.1339,
          0.2119, 0.1712, 0.3360, 0.2256, 0.2136, 0.1941, 0.3016, 0.1888,
          0.2081, 0.2794, 0.1971, 0.3620, 0.1995, 0.1984, 0.1607, 0.2298,
          0.1078, 0.1560, 0.2047, 0.2279, 0.1532, 0.1333, 0.1380, 0.2796,
          0.1931, 0.3692, 0.2212, 0.2475, 0.1293, 0.1975, 0.1934, 0.2848,
          0.1748, 0.1442],
         [0.3368, 0.2439, 0.1729, 0.3146, 0.3277, 0.2440, 

## Discriminator

In [13]:
class Discriminator(nn.Module):
    def __init__(self, seq_len, batch_size):
        super().__init__()
        self.conv_in = nn.Conv1d(in_channels=4, out_channels=100, kernel_size=1)
        self.resblocks = []
        for _ in range(5):
            self.resblocks.append(ResidualBlock(100, 100, relu=nn.ReLU))
        self.linear_out = nn.Linear(in_features=100*seq_len, out_features=1)
        
        self.seq_len = seq_len
        self.batch_size = batch_size
        
    def forward(self, x):
        y = self.conv_in(x)
        for i in range(5):
            y = self.resblocks[i](y)
        y = torch.reshape(y, (self.batch_size, 100*self.seq_len))
        y = self.linear_out(y)
        return y

Simple test on the previous output of the generator

In [14]:
discriminator = Discriminator(50, 1)

In [15]:
y = discriminator(x)

## Gradient Penalty

Gradient penalty used in WGAN-WP architecture [GitHub source](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py)

In [16]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty