In [0]:
import numpy as np
import pickle
import torch
import torch.distributions as D
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import plotly.graph_objs as go

In [4]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(DEVICE)

cpu


In [0]:
class CouplingBlock(nn.Module):
    """Residual block: 1x1 conv -> ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    Every convolution keeps ``channels_number`` channels and (stride 1, with
    padding 1 on the 3x3 kernel) the spatial size, so the branch output can be
    summed directly with the block's input.
    """

    def __init__(self, channels_number):
        super().__init__()
        self.channels_number = channels_number
        width = channels_number
        stages = [
            nn.Conv2d(width, width, 1, stride=1),
            nn.ReLU(),
            nn.Conv2d(width, width, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(width, width, 1, stride=1),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.model(x)
        return x + branch

class AffineCoupling(nn.Module):
    """Affine coupling layer: ``y1 = x1`` and ``y2 = exp(s) * x2 + t``.

    ``s`` and ``t`` are the two channel-halves of ``simple_resnet(x1)``, a
    small residual conv net that maps ``in_channels`` to ``2 * in_channels``
    channels without changing the spatial size.

    Args:
        in_channels: channel count of each half (``x1`` and ``x2``).
        n_filters: hidden width of the conv net.
        n_blocks: number of ``CouplingBlock`` residual blocks.
    """

    def __init__(self, in_channels, n_filters=256, n_blocks=8):
        super(AffineCoupling, self).__init__()
        self.n_filters = n_filters
        self.in_channels = in_channels
        self.out_channels = in_channels * 2
        self.n_blocks = n_blocks

        self.first_conv = nn.Conv2d(self.in_channels, self.n_filters, 3, stride=1, padding=1)
        self.layers = nn.Sequential(*[CouplingBlock(self.n_filters) for _ in range(self.n_blocks)])
        self.relu = nn.ReLU()
        self.last_conv = nn.Conv2d(self.n_filters, self.out_channels, 3, stride=1, padding=1)

    def forward(self, data):
        """Apply the coupling.

        Args:
            data: ``((x1, x2), acc)`` where ``acc`` is a per-sample accumulator.

        Returns:
            ``((y1, y2), acc + log_det)`` with ``log_det`` the per-sample
            log-determinant of the transform.
        """
        (x1, x2), loss = data
        s, t = torch.chunk(self.simple_resnet(x1), 2, dim=1)
        y2 = torch.exp(s) * x2 + t
        # The Jacobian w.r.t. x2 is diagonal with entries exp(s) > 0, so
        # log|det J| is the *signed* sum of s; taking abs() of the sum would
        # report a contraction (sum(s) < 0) as an expansion.
        log_det = s.reshape(x1.shape[0], -1).sum(dim=1)
        return ((x1, y2), loss + log_det)

    def reverse(self, z):
        """Invert ``forward``: recover ``(x1, x2)`` from ``(y1, y2)``."""
        z1, z2 = z
        s, t = torch.chunk(self.simple_resnet(z1), 2, dim=1)
        # Undo the shift first, then the scale: y2 = exp(s) * x2 + t
        # => x2 = (y2 - t) * exp(-s).
        x2 = (z2 - t) * torch.exp(-s)
        return (z1, x2)

    def simple_resnet(self, x):
        """Conv net producing the concatenated ``[s, t]`` maps from one half."""
        z = self.first_conv(x)
        z = self.layers(z)
        z = self.relu(z)
        return self.last_conv(z)

In [0]:
class FlipModel(nn.Module):
    """Swap the two halves of a coupling pair; the accumulator passes through unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, data):
        """Map ``((a, b), acc)`` to ``((b, a), acc)``."""
        halves, loss = data
        first, second = halves
        return (second, first), loss

In [0]:
class AffineWithFlip(nn.Module):
  """An ``AffineCoupling`` followed by a ``FlipModel`` that swaps the two halves."""

  def __init__(self, in_channels):
    super().__init__()
    self.in_channels = in_channels
    coupling = AffineCoupling(in_channels)
    self.model = nn.Sequential(coupling, FlipModel())

  def forward(self, data):
    """Run the coupling, then the swap, on ``((x1, x2), acc)``."""
    return self.model(data)

  def reverse(self, z):
    """Undo ``forward``: swap the halves back, then invert the coupling."""
    first, second = z
    coupling = self.model[0]
    return coupling.reverse((second, first))

class RealNVPModel(nn.Module):
  """Multi-scale stack of affine coupling layers with a standard-normal prior.

  Forward order: checkerboard couplings, squeeze, channel couplings,
  checkerboard couplings, squeeze, channel couplings, checkerboard couplings.
  Each squeeze multiplies the channel count by 4 and halves both spatial
  dimensions, so both spatial dimensions of the input must be divisible by 4.

  The prior is created on the module-level ``DEVICE``; the temporaries built
  by the split/merge helpers follow the device and dtype of their input.

  Args:
    in_channels: number of channels of the input images.
  """

  def __init__(self, in_channels):
    super(RealNVPModel, self).__init__()
    self.prior = D.Normal(torch.tensor(0.).to(DEVICE), torch.tensor(1.).to(DEVICE))
    self.in_channels = in_channels
    self.affineCoupling1 = nn.Sequential(*[AffineWithFlip(self.in_channels) for _ in range(4)])
    self.affineCoupling2 = nn.Sequential(*[AffineWithFlip(self.in_channels * 2) for _ in range(3)])
    self.affineCoupling3 = nn.Sequential(*[AffineWithFlip(self.in_channels * 4) for _ in range(3)])
    self.affineCoupling4 = nn.Sequential(*[AffineWithFlip(self.in_channels * 8) for _ in range(3)])
    self.affineCoupling5 = nn.Sequential(*[AffineWithFlip(self.in_channels * 16) for _ in range(3)])

  def forward(self, x):
    """Map a batch to latent space.

    Returns:
      ``(z, acc)`` where ``z`` has shape ``(N, 16 * C, W // 4, H // 4)`` and
      ``acc`` is a per-sample value: the terms accumulated by the coupling
      layers plus the prior log-probability of ``z`` summed over all
      non-batch dimensions.
    """
    data = self.preprocess(x)
    loss = torch.zeros(x.shape[0], device=x.device)
    data, loss = self.withCheckBoardSplit(data, loss, self.affineCoupling1)
    data = self.squeeze(data)
    data, loss = self.withChannelSplit(data, loss, self.affineCoupling2)
    data, loss = self.withCheckBoardSplit(data, loss, self.affineCoupling3)
    data = self.squeeze(data)
    data, loss = self.withChannelSplit(data, loss, self.affineCoupling4)
    data, loss = self.withCheckBoardSplit(data, loss, self.affineCoupling5)
    loss = loss + torch.sum(self.prior.log_prob(data), dim=(1, 2, 3))
    return data, loss

  def reverse(self, z):
    """Map a latent batch back to data space by undoing ``forward`` stage by stage."""
    loss = 0  # carried through the helpers but not updated on the reverse path
    z, loss = self.withCheckBoardSplit(z, loss, self.affineCoupling5, True)
    z, loss = self.withChannelSplit(z, loss, self.affineCoupling4, True)
    z = self.unsqueeze(z)
    z, loss = self.withCheckBoardSplit(z, loss, self.affineCoupling3, True)
    z, loss = self.withChannelSplit(z, loss, self.affineCoupling2, True)
    z = self.unsqueeze(z)
    z, loss = self.withCheckBoardSplit(z, loss, self.affineCoupling1, True)
    return z

  @staticmethod
  def _require_even(W, H):
    """Raise ``ValueError`` unless both spatial sizes are even."""
    if W % 2 or H % 2:
      raise ValueError(f"spatial dimensions must be even, got {W}x{H}")

  def squeeze(self, data):
    """Turn ``(N, C, W, H)`` into ``(N, 4C, W/2, H/2)``.

    The four positions of every 2x2 spatial tile become four channel groups,
    in the order top-left, top-right, bottom-left, bottom-right.
    """
    N, C, W, H = data.shape
    self._require_even(W, H)
    # Strided slices pick the same elements, in the same order, as a tiled
    # 2x2 boolean mask would, without materialising a full-size mask.
    return torch.cat([
        data[:, :, 0::2, 0::2],
        data[:, :, 0::2, 1::2],
        data[:, :, 1::2, 0::2],
        data[:, :, 1::2, 1::2],
    ], dim=1)

  def unsqueeze(self, data):
    """Inverse of ``squeeze``: ``(N, 4C, W, H)`` back to ``(N, C, 2W, 2H)``."""
    N, C, W, H = data.shape
    x1, x2, x3, x4 = torch.chunk(data, 4, dim=1)
    x = data.new_zeros(N, C // 4, W * 2, H * 2)
    x[:, :, 0::2, 0::2] = x1
    x[:, :, 0::2, 1::2] = x2
    x[:, :, 1::2, 0::2] = x3
    x[:, :, 1::2, 1::2] = x4
    return x

  def _apply_split(self, data, loss, layer, reverse, split, merge):
    """Split ``data`` in two, run ``layer`` forward or in reverse, and merge.

    On the reverse path the sub-layers are visited last-to-first and ``loss``
    is returned unchanged.
    """
    halves = split(data)
    if reverse:
      for item in reversed(layer):
        halves = item.reverse(halves)
    else:
      halves, loss = layer((halves, loss))
    return merge(halves), loss

  def withCheckBoardSplit(self, data, loss, layer, reverse=False):
    """Run ``layer`` on the checkerboard halves of ``data``."""
    return self._apply_split(data, loss, layer, reverse,
                             self.checkBoardSplit, self.inverseCheckBoardSplit)

  def withChannelSplit(self, data, loss, layer, reverse=False):
    """Run ``layer`` on the two channel halves of ``data``."""
    return self._apply_split(data, loss, layer, reverse,
                             self.channelSplit, self.inverseChannelSplit)

  def checkBoardSplit(self, data):
    """Split ``(N, C, W, H)`` into two ``(N, C, W, H/2)`` checkerboard halves.

    The first half holds, for every row, the entries whose row and column
    indices have the same parity; the second half holds the remaining ones.
    """
    N, C, W, H = data.shape
    self._require_even(W, H)
    x1 = data.new_zeros(N, C, W, H // 2)
    x2 = data.new_zeros(N, C, W, H // 2)
    x1[:, :, 0::2] = data[:, :, 0::2, 0::2]
    x1[:, :, 1::2] = data[:, :, 1::2, 1::2]
    x2[:, :, 0::2] = data[:, :, 0::2, 1::2]
    x2[:, :, 1::2] = data[:, :, 1::2, 0::2]
    return (x1, x2)

  def inverseCheckBoardSplit(self, data):
    """Inverse of ``checkBoardSplit``: interleave the two halves again."""
    x1, x2 = data
    N, C, W, H = x1.shape
    merged = x1.new_zeros(N, C, W, H * 2)
    merged[:, :, 0::2, 0::2] = x1[:, :, 0::2]
    merged[:, :, 1::2, 1::2] = x1[:, :, 1::2]
    merged[:, :, 0::2, 1::2] = x2[:, :, 0::2]
    merged[:, :, 1::2, 0::2] = x2[:, :, 1::2]
    return merged

  def channelSplit(self, data):
    """Split ``data`` into two halves along the channel dimension."""
    return (torch.chunk(data, 2, dim=1))

  def inverseChannelSplit(self, data):
    """Concatenate two halves along the channel dimension."""
    x1, x2 = data
    return torch.cat((x1, x2), dim=1)

  def preprocess(self, x):
    """Identity hook applied to the input before the first coupling stage."""
    return x

  def sample(self, N, C, W, H):
    """Draw ``N`` samples of data-space shape ``(C, W, H)``.

    The latent has ``16 * C`` channels and a quarter of each spatial size,
    matching what two ``squeeze`` steps produce in ``forward``.
    """
    z = self.prior.sample((N, C * 16, W // 4, H // 4))
    return self.reverse(z)

In [15]:
# Colab-only step: mount Google Drive at /content/drive, the directory the
# dataset path used by load_data() lives under.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
def load_data(dataset_path='/content/drive/My Drive/dataset/hw2_q2.pkl'):
  """Load the pickled dataset and return ``(train, test)`` float tensors.

  Args:
    dataset_path: path of a pickle file holding a mapping with ``'train'``
      and ``'test'`` numpy arrays. Defaults to the location on the mounted
      Google Drive.

  Returns:
    Two tensors whose last axis has been moved to position 1
    (``(N, A, B, C)`` becomes ``(N, C, A, B)``) and whose values have been
    divided by 3.
  """
  def prepare_data(data):
    # Divide by 3 — presumably the raw pixel values lie in 0..3; TODO confirm.
    data = torch.from_numpy(data).float() / 3
    # One permute replaces transpose(1, 3) followed by transpose(2, 3).
    return data.permute(0, 3, 1, 2)

  # NOTE(security): pickle.load can execute arbitrary code; only point this
  # at files from a trusted source.
  with open(dataset_path, 'rb') as f:
      dataset = pickle.load(f)

  return prepare_data(dataset['train']), prepare_data(dataset['test'])

In [0]:
# Load both splits into module-level tensors.
train_data, test_data = load_data()

In [0]:
def train(model, optimizer, train_loader, test_loader, num_epochs):
  """Train ``model`` for ``num_epochs`` epochs and print running/epoch losses.

  ``model(batch)`` must return ``(latent, score)`` where ``score`` is a
  per-sample quantity to be *maximised* (RealNVPModel.forward builds it by
  adding the prior log-probability of the latent to the terms accumulated by
  the coupling layers). The optimised loss is therefore the negated batch
  mean of that score.

  Args:
    model: module called as ``model(batch)``.
    optimizer: optimizer over the model's parameters.
    train_loader: iterable of training batches (tensors).
    test_loader: iterable of validation batches (tensors).
    num_epochs: number of passes over ``train_loader``.
  """
  for epoch in range(num_epochs):
    loss, val_loss = 0.0, 0.0
    iteration = 0
    for batch in train_loader:
      batch = batch.to(DEVICE)
      iteration += 1
      optimizer.zero_grad()
      _, score = model(batch)
      # Minimising the raw score would drive the prior log-probability down;
      # gradient descent has to act on its negation.
      curr_loss = -score.mean()
      curr_loss.backward()
      optimizer.step()
      loss += curr_loss.item()
      print(f"iteration is {iteration}, curr_loss is {curr_loss.item()}, total loss is {loss / iteration}")

    val_batches = 0
    with torch.no_grad():
      for batch in test_loader:
        batch = batch.to(DEVICE)
        _, score = model(batch)
        # .item() keeps the running total a plain float instead of a tensor.
        val_loss += -score.mean().item()
        val_batches += 1

    # Average over the batches actually seen; report NaN for an empty loader
    # rather than dividing by zero.
    epoch_loss = loss / iteration if iteration else float('nan')
    epoch_val_loss = val_loss / val_batches if val_batches else float('nan')
    print(f"After epoch {epoch} loss is {epoch_loss} and validation loss is {epoch_val_loss}")

In [0]:
def run_realNVP_model_train(train_data, test_data, batch_size=8, num_epochs=10, lr=0.0001, in_channels=3):
  """Build a ``RealNVPModel``, train it with Adam and return the trained model.

  Args:
    train_data: training tensor, indexed along its first dimension by the loader.
    test_data: validation tensor, indexed the same way.
    batch_size: batch size for both loaders.
    num_epochs: number of training epochs.
    lr: Adam learning rate.
    in_channels: channel count passed to ``RealNVPModel``.

  Returns:
    The trained model, living on ``DEVICE``.
  """
  model = RealNVPModel(in_channels).to(DEVICE)
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
  train(model, optimizer, train_loader, test_loader, num_epochs)
  # Hand the trained network back; otherwise it is unreachable after the call.
  return model