In [0]:
import random
import os
import torch
import numpy as np
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.distributions import Normal, MultivariateNormal, Uniform
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F

In [8]:
# Seed every RNG library this notebook imports (not only torch) so that a
# fresh "Restart & Run All" starts from the same random state each time.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cpu


In [0]:
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open pickle files that come from a trusted source.
with open('hw2_q2.pkl', 'rb') as file:
    DATA = pickle.load(file)

# Replace both splits with float tensors whose axis 3 has been moved to
# position 1 (presumably channels-last -> channels-first; confirm against
# the layout stored in the pickle).  Other keys of DATA are left untouched.
DATA.update({
    split: torch.FloatTensor(DATA[split]).permute(0, 3, 1, 2)
    for split in ('train', 'test')
})

In [0]:
class Resnet(nn.Module):
    """Convolutional network mapping ``ch_in`` channels to ``2 * ch_in`` channels.

    A 3x3 stem convolution lifts the input to ``n_filters`` channels, a stack
    of ``n_blocks`` two-stage blocks refines it, and a 3x3 head convolution
    produces ``2 * ch_in`` channels.  The stem pads by 2 and the head does not
    pad, so the output has the same height and width as the input.

    Args:
        ch_in: number of input channels.
        n_filters: channel width used inside the network.
        n_blocks: how many two-stage blocks to stack.
    """

    def __init__(self, ch_in, n_filters=256, n_blocks=8):
        super().__init__()
        self.n_blocks = n_blocks

        # Stem grows H and W by 2 (3x3 kernel, padding 2); the head shrinks
        # them back by 2 (3x3 kernel, no padding).
        self.conv1 = nn.Conv2d(ch_in, n_filters, kernel_size=3, stride=1, padding=2)
        self.batch_norm1 = nn.BatchNorm2d(n_filters)
        self.conv2 = nn.Conv2d(n_filters, 2 * ch_in, kernel_size=3, stride=1)
        self.batch_norm2 = nn.BatchNorm2d(2 * ch_in)

        # First stage of every block: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN.
        first_stages = []
        for _ in range(n_blocks):
            first_stages.append(nn.Sequential(
                nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(n_filters),
                nn.ReLU(),
                nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(n_filters),
            ))
        self._h_model = nn.ModuleList(first_stages)

        # Second stage of every block: ReLU -> 1x1 conv -> BN.
        second_stages = []
        for _ in range(n_blocks):
            second_stages.append(nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(n_filters, n_filters, kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(n_filters),
            ))
        self.h_model = nn.ModuleList(second_stages)

    def forward(self, x):
        """Run the stem, all blocks and the head; returns ``2 * ch_in`` channels."""
        features = self.batch_norm1(self.conv1(x))
        for first_stage, second_stage in zip(self._h_model, self.h_model):
            stage_out = first_stage(features)
            # The skip connection wraps the second stage only: its input
            # (the first stage's output) is added back to its output.
            features = second_stage(stage_out) + stage_out
        return self.batch_norm2(self.conv2(F.relu(features)))

In [0]:
class AffineCoupling(nn.Module):
    """Affine coupling layer over a pair of tensors ``(x1, x2)``.

    ``x1`` passes through unchanged and is fed to a ``Resnet`` whose output is
    split along the channel axis into a log-scale ``log_s`` and a shift ``t``.
    ``x2`` is then shifted by ``t * masks[1]`` and scaled by ``exp(log_s)``.

    Args:
        ch_in: number of channels of ``x1`` (passed to ``Resnet``).
        sign: factor the log-determinant is multiplied by in ``forward``.
    """

    def __init__(self, ch_in, sign):
        super().__init__()
        self.resnet = Resnet(ch_in)
        self.sign = sign

    def _scale_and_shift(self, x1):
        """Return ``(log_s, t)``: the two channel-halves of ``resnet(x1)``."""
        log_s, t = torch.chunk(self.resnet(x1), 2, dim=1)
        return log_s, t

    def forward(self, x, masks):
        """Apply the coupling transform.

        Args:
            x: pair ``(x1, x2)`` of tensors.
            masks: indexable of masks; only ``masks[1]`` is read, and it gates
                the shift ``t``.

        Returns:
            ``((y1, y2), log_det)`` where ``y1`` is ``x1`` and ``log_det`` is
            the per-sample sum of ``log_s`` multiplied by ``self.sign``.
        """
        x1, x2 = x
        log_s, t = self._scale_and_shift(x1)
        y2 = torch.exp(log_s) * (x2 + t * masks[1])
        log_det = log_s.reshape(x1.shape[0], -1).sum(dim=1) * self.sign
        return ((x1, y2), log_det)

    def reverse(self, y, mask):
        """Invert ``forward`` for the same mask pair.

        Args:
            y: pair ``(y1, y2)`` as produced by ``forward``.
            mask: the same indexable of masks that was given to ``forward``;
                only ``mask[1]`` is read.

        Returns:
            ``(x1, x2)`` such that ``forward((x1, x2), mask)`` yields ``y``.
        """
        y1, y2 = y
        log_s, t = self._scale_and_shift(y1)
        # Undo the scale first, then remove exactly the shift that forward
        # added (t * mask[1]); using the complementary mask here would leave
        # a residual t * (2 * mask[1] - 1) and break invertibility.
        x2 = y2 * torch.exp(-log_s) - t * mask[1]
        return (y1, x2)

In [0]:
def train(model, batch_size, epochs_cnt, train_data=None, test_data=None, lr=5e-5):
    """Train ``model`` by minimising its scaled negative log-likelihood.

    The loss is ``-mean(model.log_prob(batch))`` divided by ``3 * 32 * 32``
    and by ``ln 2`` (i.e. bits per dimension if ``log_prob`` is in nats and
    each sample has 3 * 32 * 32 elements).

    Args:
        model: module exposing ``log_prob(batch)``, ``parameters()``,
            ``train()`` and ``eval()``.
        batch_size: mini-batch size for both data loaders.
        epochs_cnt: number of passes over ``train_data``.
        train_data: training dataset; defaults to ``DATA['train']``, looked up
            when the function is called.
        test_data: validation dataset; defaults to ``DATA['test']``, looked up
            when the function is called.
        lr: Adam learning rate.

    Returns:
        ``(losses, val_losses)``: per-epoch mean training loss and per-epoch
        mean validation loss.
    """
    # Resolve the defaults at call time so that re-running the data-loading
    # cell is picked up without having to re-run this definition.
    if train_data is None:
        train_data = DATA['train']
    if test_data is None:
        test_data = DATA['test']

    # Normalisation constants for the loss (presumably the element count of a
    # 3x32x32 image, and ln 2 to convert nats to bits).
    dim_factor = torch.FloatTensor([3 * 32 * 32]).to(DEVICE)
    log_factor = torch.log(torch.Tensor([2])).to(DEVICE)

    def loss_func(log_prob):
        return -torch.mean(log_prob) / dim_factor / log_factor

    train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_iter = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    # Optimise the model that was passed in, not a notebook-level global.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)

    losses, val_losses = [], []
    for _ in range(epochs_cnt):
        model.train()
        epoch_losses = []
        for batch in tqdm(train_iter):
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            loss = loss_func(model.log_prob(batch))
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
        losses.append(np.mean(epoch_losses))

        model.eval()
        epoch_val_losses = []
        with torch.no_grad():
            # Evaluate on the validation batches themselves (not on whatever
            # training batch happened to be processed last).
            for val_batch in val_iter:
                val_batch = val_batch.to(DEVICE)
                val_loss = loss_func(model.log_prob(val_batch))
                epoch_val_losses.append(val_loss.item())
        val_losses.append(np.mean(epoch_val_losses))
    return losses, val_losses

In [0]:
# Hyper-parameters for the training run: mini-batch size and number of epochs.
batch_size, epoch_cnt = 16, 4

In [0]:
# Instantiate the model with constructor argument 3 (presumably the number of
# image channels — confirm against Model's signature) and move it to DEVICE.
# NOTE(review): `Model` is not defined in any cell visible here; the cell that
# defines it must have been run before this one.
model = Model(3).to(DEVICE)

In [0]:
# Run training with the hyper-parameters set above, then plot the per-epoch
# training and validation losses it returns.
train_losses, validate_losses = train(
    model,
    batch_size=batch_size,
    epochs_cnt=epoch_cnt,
)
plot_losses(train_losses, validate_losses)