Packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import datetime
import matplotlib.pyplot as plt

Data

In [None]:
batch = 128

# Preprocessing: convert each image to a tensor, then round every value to the
# nearest integer (presumably binarising intensities that ToTensor has scaled
# to [0, 1] -- that scaling is torchvision behaviour, confirm against its docs).
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Lambda(lambda img: img.round()),
])

# MNIST training split, downloaded into ./data when it is not already there.
train_set = torchvision.datasets.MNIST(
  root="./data", train=True, download=True, transform=transform,
)

# Shuffled mini-batches of `batch` images, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
  train_set, batch_size=batch, shuffle=True, num_workers=2,
)

Network

In [None]:
def append_location(x, device):
  """
    Append normalised pixel coordinates to a batch of images.

    Two channels are concatenated after the existing ones: the first holds
    the row coordinate and the second the column coordinate of every pixel,
    each scaled linearly so that the first index is 0 and the last is 1.

    Args:
      x: tensor of shape (batch, channels, height, width).
      device: device the coordinate channels are moved to before the
        concatenation; it has to be the device `x` lives on.

    Returns:
      Tensor of shape (batch, channels + 2, height, width).
  """
  height, width = x.shape[-2], x.shape[-1]
  # Normalise by (size - 1) instead of a hard-coded 27 so the coordinates span
  # [0, 1] for any image size; max(..., 1) avoids 0/0 on a single-pixel axis.
  rows = torch.arange(height) / max(height - 1, 1)
  cols = torch.arange(width) / max(width - 1, 1)
  # Explicit "ij" indexing keeps the historical default (grid[0] varies along
  # rows, grid[1] along columns) and silences the deprecation warning.
  grid = torch.meshgrid(rows, cols, indexing="ij")
  locations = torch.stack(grid, dim=0).to(device) # 2, height, width
  locations = locations.repeat(x.shape[0], 1, 1, 1) # batch, 2, height, width
  return torch.cat((x, locations), dim=1) # batch, channels + 2, height, width

In [None]:
class DilatedCausalConv1d(nn.Module):
  """
    Two-tap dilated 1-D convolution that keeps the sequence length.

    The input is zero-padded on the left by `dilation` samples, so output
    position t is computed from input positions t - dilation and t only.
  """

  def __init__(self, in_channels, out_channels, dilation=1):
    super().__init__()
    self.dilation = dilation
    conv_kwargs = dict(kernel_size=2, dilation=dilation, padding=0)
    self.conv = nn.Conv1d(in_channels, out_channels, **conv_kwargs)

  def forward(self, x):
    """
      Input:  batch, in_channels, sequence_length
      Output: batch, out_channels, sequence_length
    """
    # Left-only padding: nothing is appended after the last time step.
    left_padded = F.pad(x, (self.dilation, 0))
    return self.conv(left_padded)

class CausalConv1d(nn.Module):
  """
    Strictly causal two-tap 1-D convolution that keeps the sequence length.

    Output position t is computed from input positions t - 1 - dilation and
    t - 1, i.e. never from position t itself. Positions before the start of
    the sequence are treated as zeros.
  """

  def __init__(self, in_channels, out_channels, dilation=1):
    super().__init__()
    self.dilation = dilation
    self.conv = nn.Conv1d(in_channels,
                          out_channels,
                          kernel_size=2,
                          dilation=dilation,
                          padding=0)

  def forward(self, x):
    """
      Input:  batch, in_channels, sequence_length
      Output: batch, out_channels, sequence_length
    """
    # Pad dilation + 1 zeros on the left: `dilation` of them compensate the
    # kernel span, the extra one shifts the output by a step so that position
    # t never sees input t. A fixed pad of 2 is only correct for dilation=1.
    padded = F.pad(x, [self.dilation + 1, 0])
    # The convolution yields sequence_length + 1 steps; the last one would
    # describe a position past the end of the sequence, so it is dropped.
    return self.conv(padded)[:, :, :-1]

In [None]:
class ResidualBlock(nn.Module):
  """
    Gated residual block: a dilated causal convolution, a tanh/sigmoid gate,
    a 1x1 convolution, and a skip connection back onto the input.
  """

  def __init__(self, residual_channels, dilation=1):
    super().__init__()
    # The dilated convolution emits twice the channels: one half feeds the
    # tanh branch, the other half feeds the sigmoid gate.
    doubled_channels = 2 * residual_channels
    self.dilate = DilatedCausalConv1d(residual_channels, doubled_channels, dilation)
    self.conv1d = nn.Conv1d(residual_channels, residual_channels, kernel_size=1)

  def forward(self, x):
    """
      Input:  batch, residual_channels, sequence_length
      Output: batch, residual_channels, sequence_length
    """
    filter_part, gate_part = torch.chunk(self.dilate(x), 2, dim=1)
    gated = torch.tanh(filter_part) * torch.sigmoid(gate_part)
    return x + self.conv1d(gated)

In [None]:
class WaveNet(nn.Module):
  """
    WaveNet-style model over images flattened into a pixel sequence.

    The image (optionally extended with two pixel-coordinate channels) is
    flattened row by row, passed through a CausalConv1d, a stack of
    ResidualBlocks with dilations 1, 2, 4, ..., and a 1x1 output convolution
    that produces one logit per input channel and pixel.

    Args:
      input_size: (channels, height, width) of the images.
      residual_channels: channel width used inside the residual stack.
      device: stored as `self.device`. It is no longer consulted for tensor
        placement, so the model keeps working after `.to()` / `.cuda()`.
      append_loc: when True, coordinate channels are appended to the input
        through `append_location`.
  """

  def __init__(self, input_size, residual_channels, device, append_loc=True):
    super().__init__()
    self.input_size = input_size
    self.layers = 9
    self.residual_channels = residual_channels
    self.device = device
    self.append_loc = append_loc

    # Two extra input channels when pixel coordinates are appended.
    in_channels = self.input_size[0] + (2 if self.append_loc else 0)
    self.causal_conv = CausalConv1d(in_channels, self.residual_channels)

    # Dilation doubles with depth: 1, 2, 4, ..., 2**(layers - 1).
    res_blocks = [ResidualBlock(self.residual_channels, dilation=2**i)
                  for i in range(self.layers)]
    self.stacked_res_blocks = nn.Sequential(*res_blocks)

    # Kept inside nn.Sequential so parameter names stay `out_conv.0.*` and
    # previously saved state dicts still load.
    self.out_conv = nn.Sequential(nn.Conv1d(self.residual_channels, self.input_size[0], kernel_size=1))

  def forward(self, x):
    """
      Input:  x    # batch, channels, height, width
      Output: logits, batch, channels, height, width
    """
    batch = x.shape[0]
    if self.append_loc:
      # Use the input's own device: the one captured at construction time is
      # stale as soon as the module has been moved.
      x = append_location(x, x.device)
    # reshape (not view) so a non-contiguous input is accepted as well.
    x = x.reshape(batch, -1, self.input_size[1]*self.input_size[2]) # batch, in_channels, sequence_length

    x = self.causal_conv(x) # batch, residual_channels, sequence_length
    x = self.stacked_res_blocks(x) # batch, residual_channels, sequence_length
    x = self.out_conv(x) # batch, out_channels, sequence_length
    return x.view(batch, self.input_size[0], self.input_size[1], self.input_size[2]) # batch, channels, height, width

  def sample(self, n):
    """
      Draw `n` binary images pixel by pixel in row-major order.

      Every pixel is sampled from a Bernoulli distribution whose logit is the
      network output at that pixel, given the pixels sampled so far (the rest
      of the canvas is still zero). This costs one full forward pass per
      pixel, i.e. height * width passes in total.

      Returns a CPU tensor of shape (n, channels, height, width).
    """
    channels, height, width = self.input_size[0], self.input_size[1], self.input_size[2]
    # Allocate the canvas where the weights currently live.
    param_device = next(self.parameters()).device
    with torch.no_grad():
      x = torch.zeros(n, channels, height, width, device=param_device)
      for i in range(height):
        for j in range(width):
          logits = self(x)[:, :, i, j]
          x[:, :, i, j] = torch.bernoulli(torch.sigmoid(logits))
    return x.cpu() # n, channels, height, width

Initialize model

In [None]:
# `is_available` has to be *called*: the bare function object is always truthy,
# which would select 'cuda' even on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = WaveNet((1, 28, 28), 64, device).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

Training loop

In [None]:
epochs = 20
loss_values = []  # training loss of every optimisation step, plotted at the end
sample_epochs = (1, 4, 10, 14, 18, 20)  # epochs after which a checkpoint and samples are saved

for epoch in range(1, epochs+1):
  for i, (imgs, _) in enumerate(trainloader):
    # The binarised images are both the network input and the targets, so a
    # single host-to-device copy is enough.
    x = imgs.to(device)
    targets = x
    logits = model(x)

    optimizer.zero_grad()
    # Same quantity as binary_cross_entropy(sigmoid(logits), targets), computed
    # in a numerically stable way directly from the logits.
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    loss_values.append(loss.item())
    loss.backward()
    optimizer.step()

    if (i+1)%100==0:
      print(f'Epoch [{epoch}/{epochs}], Step: [{i+1 }/{len(trainloader)}], Time {datetime.datetime.now()}, Loss {loss.item()}')

  if epoch in sample_epochs:
    torch.save(model.state_dict(), f'/content/wavenet_epoch{epoch}.pth')
    samples = model.sample(16)
    # Start a fresh figure so this grid is not drawn over the previous one.
    plt.figure()
    for j in range(16):
      plt.subplot(4, 4, j+1)
      # Plot the single channel as a 2-D array; a trailing channel axis of
      # size 1 is not accepted by every matplotlib version.
      plt.imshow(samples[j, 0].numpy())
      plt.axis('off')
    plt.savefig(f'/content/sampled_imgs_epoch{epoch}.png')

print('Finish training')
# New figure: otherwise the curve lands in the last subplot of the last grid.
plt.figure()
plt.plot(loss_values)
plt.savefig('/content/loss_function.png')