In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.optim as optim

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import numpy as np

In [3]:
# MNIST train/test splits stored under ./MNIST/, each image converted to a tensor.
# Only the training call asks torchvision to download the files.
train_dataset = dsets.MNIST(
    root='./MNIST/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

test_dataset = dsets.MNIST(
    root='./MNIST/',
    train=False,
    transform=transforms.ToTensor(),
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [143]:
class ActNorm(nn.Module):
    """Activation normalization (Glow): a per-channel affine map y = x * weights + bias.

    The (c,)-shaped ``weights`` and ``bias`` broadcast against the LAST axis of the
    input, so channels are assumed to be the trailing dimension (with c == 1 any
    layout works). On the first ``forward`` call the parameters are initialized from
    that batch so the output has zero mean and unit variance per channel.

    Args:
        h, w, c: spatial height, width and number of channels of one sample.
        eps: added to the per-channel std at init so a constant channel does not
            produce an infinite scale (and NaN outputs).
    """

    def __init__(self, h, w, c, eps=1e-6):
        super().__init__()
        self.shape = (h, w, c)
        self.eps = eps
        # NOTE(review): plain attribute, not part of state_dict — a module restored
        # from a checkpoint will re-run data-dependent init on its first forward.
        self.initialized = False
        # Identity transform until data-dependent init runs (avoids uninitialized memory).
        self.weights = nn.Parameter(torch.ones(c))
        self.bias = nn.Parameter(torch.zeros(c))

    def _initialize(self, inp):
        """Set weights/bias from ``inp`` so the output is standardized per channel."""
        c = self.shape[-1]
        with torch.no_grad():  # init is not part of the computation graph
            flat = inp.reshape(-1, c)  # reshape: also works for non-contiguous input
            scale = 1 / (flat.std(0) + self.eps)
            self.weights.copy_(scale)
            self.bias.copy_(-(flat * scale).mean(0))
        self.initialized = True

    def forward(self, inp):
        """Apply the affine map, initializing the parameters from ``inp`` on first use."""
        if not self.initialized:
            self._initialize(inp)
        return inp * self.weights + self.bias

    def reverse(self, out):
        """Invert ``forward``: x = (y - bias) / weights."""
        return (out - self.bias) / self.weights

    def log_determinant(self):
        """Scalar log|det J| = h * w * sum(log|weights|) for one sample."""
        h, w, _ = self.shape
        return h * w * torch.sum(torch.log(torch.abs(self.weights)))

In [284]:
class InvertibleConv(nn.Module):
    """Invertible 1x1 convolution (Glow): mixes channels with a learned c x c matrix W.

    ``forward`` computes out[a, e, i, j] = sum_b W[e, b] * inp[a, b, i, j], i.e. dim 1
    of the input is treated as the channel axis.
    NOTE(review): ActNorm in this notebook broadcasts over the LAST axis instead —
    confirm the intended tensor layout (the two agree when c == 1).

    Args:
        h, w, c: spatial height, width and number of channels of one sample.
    """

    def __init__(self, h, w, c):
        super().__init__()
        self.shape = (h, w, c)
        # Random orthogonal init (Q factor of a Gaussian matrix), so |det W| = 1 at start.
        q, _ = np.linalg.qr(np.random.randn(c, c))
        # A Parameter so W is trained, saved, and moved by .to(device). Cast from
        # numpy's float64 to the default dtype so einsum accepts float32 inputs.
        self.weight = nn.Parameter(torch.from_numpy(q).to(torch.get_default_dtype()))

    def forward(self, inp):
        """Apply W to the channel axis (dim 1) at every spatial position."""
        return torch.einsum("abcd,eb->aecd", (inp, self.weight))

    def reverse(self, out):
        """Invert ``forward`` by applying W^{-1} to the channel axis of ``out``."""
        return torch.einsum("abcd,eb->aecd", (out, torch.inverse(self.weight)))

    def log_determinant(self):
        """Scalar log|det J| = h * w * log|det W| for one sample."""
        h, w, _ = self.shape
        # slogdet avoids the over/underflow of det() for larger c.
        return h * w * torch.slogdet(self.weight)[1]

In [5]:
class AffineCoupling(nn.Module):
    """Affine coupling layer (Glow / RealNVP style).

    The input is split along dim 2 into (x_a, x_b). ``NN(x_b)`` returns
    ``(log_s, t)`` and the layer computes y_a = exp(log_s) * x_a + t, with
    y_b = x_b passed through unchanged. Because x_b is untouched, the inverse can
    recompute (log_s, t) from it.

    Args:
        NN: callable mapping the untouched half to ``(log_s, t)``. Both must
            broadcast against the transformed half x_a.
    """

    def __init__(self, NN):
        super().__init__()
        self.NN = NN

    @staticmethod
    def _split(tensor):
        """Split ``tensor`` along dim 2 into (first n // 2, remaining) — always two parts."""
        n = tensor.shape[2]
        half = n // 2
        # Explicit sizes: split(half) would return 3 chunks for odd n and break unpacking.
        return torch.split(tensor, [half, n - half], dim=2)

    def forward(self, inp):
        """Transform the first half of ``inp``; return the concatenation [y_a, x_b]."""
        x_a, x_b = self._split(inp)
        log_s, t = self.NN(x_b)
        y_a = torch.exp(log_s) * x_a + t
        return torch.cat([y_a, x_b], dim=2)

    def reverse(self, out):
        """Invert ``forward``: x_a = (y_a - t) / exp(log_s), with x_b = y_b."""
        y_a, y_b = self._split(out)
        log_s, t = self.NN(y_b)
        x_a = (y_a - t) / torch.exp(log_s)
        return torch.cat([x_a, y_b], dim=2)

    def log_determinant(self, inp):
        """log|det J| = sum(log_s), summed over every element (including the batch)."""
        _, x_b = self._split(inp)
        log_s, _ = self.NN(x_b)
        # log|exp(log_s)| == log_s; summing log_s directly avoids exp overflow
        # (inf) and underflow (log(0) = -inf) for large |log_s|.
        return torch.sum(log_s)

In [6]:
class GlowModel(nn.Module):
    def __init__(self):
        super().__init__()