In [1]:
import torch
import torch.nn as nn
import time
class FlowSequential(nn.Sequential):
    def forward(self, inputs, inverse=False):
        batch_size = inputs.size(0)
        sum_LDJ = torch.zeros(batch_size, device=inputs.device)
        if not inverse:
            for m in self._modules.values():
                inputs, LDJ = m(inputs, inverse=False)
                sum_LDJ += LDJ
            return inputs, sum_LDJ
        else:
            for m in reversed(self._modules.values()):
                inputs, LDJ = m(inputs, inverse=True)
                sum_LDJ += LDJ
            return inputs, sum_LDJ

In [9]:


def check_tensors(tensor_dict, message):
    for k, v in tensor_dict.items():
        # inf check
        if torch.isinf(v).any():
            print(message)
            print('--Found inf in %s' % k)
            raise FloatingPointError
        # nan check
        if torch.isnan(v).any():
            print(message)
            print('--Found nan in %s' % k)
            raise FloatingPointError

In [10]:

class AffineCouplingLayer(nn.Module):
    def __init__(self, channels, channels_h):
        super(AffineCouplingLayer, self).__init__()
        self.channels = channels
        d = channels // 2
        conv1 = nn.Conv2d(d, channels_h, 3, 1, 1)
        conv2 = nn.Conv2d(channels_h, channels_h, 1, 1, 0)
        conv3 = nn.Conv2d(channels_h, (channels - d) * 2, 3, 1, 1)
        def init_normal(m):
            nn.init.normal_(m.weight.data, mean=0.0, std=0.05)
            nn.init.constant_(m.bias.data, 0.0)
        def init_zero(m):
            nn.init.constant_(m.weight.data, 0.0)
            nn.init.constant_(m.bias.data, 0.0)
        conv1.apply(init_normal)
        conv2.apply(init_normal)
        conv3.apply(init_zero)
        self.nn = nn.Sequential(conv1,
                                nn.ReLU(True),
                                conv2,
                                nn.ReLU(True),
                                conv3)
        self.log_scale = nn.Parameter(torch.zeros(channels, 1, 1))

    def split(self, x):
        d = self.channels // 2
        x1, x2 = torch.split(x, [d, self.channels - d], 1)
        return x1, x2

    def concat(self, x1, x2):
        x = torch.cat([x1, x2], 1)
        return x

    def forward(self, inputs, inverse=False):
        batch_size = inputs.size(0)
        if not inverse:
            x1, x2 = self.split(inputs)
            y1 = x1
            log_s, t = torch.chunk(self.nn(x1) * self.log_scale.exp(), 2, 1)
            #s = torch.exp(log_s)
            s = torch.sigmoid(log_s + 2) + 1.0 # numerically stable ver
            log_s = s.log()
            y2 = x2 * s + t
            y = self.concat(y1, y2)
            LDJ = log_s.view(batch_size, -1).sum(-1)
            check_tensors({'y': y, 'LDJ': LDJ}, str(self.__class__) + ': forward')
            return y, LDJ
        else:
            y1, y2 = self.split(inputs)
            x1 = y1
            log_s, t = torch.chunk(self.nn(x1) * self.log_scale.exp(), 2, 1)
            #s = torch.exp(log_s)
            s = torch.sigmoid(log_s + 2) + 1.0 # numerically stable ver
            log_s = s.log()
            x2 = (y2 - t) / s
            x = self.concat(x1, x2)
            LDJ = -log_s.view(batch_size, -1).sum(-1)
            check_tensors({'x': x, 'LDJ': LDJ}, str(self.__class__) + ': inverse')
            return x, LDJ

In [11]:

class InversibleConv1x1(nn.Module):
    def __init__(self, channels):
        super(InversibleConv1x1, self).__init__()
        self.w = nn.Parameter(torch.qr(torch.randn(channels, channels))[0])

    def forward(self, inputs, inverse=False):
        batch_size = inputs.size(0)
        pixels = inputs.size(-1) * inputs.size(-2)
        if not inverse:
            y = nn.functional.conv2d(inputs, self.w.unsqueeze(-1).unsqueeze(-1))
            abs_det = torch.det(self.w).abs()
            LDJ = abs_det.log().repeat(batch_size) * pixels
            check_tensors({'y': y, 'LDJ': LDJ}, str(self.__class__) + ': forward')
            return y, LDJ
        else:
            inv_w = torch.inverse(self.w)
            x = nn.functional.conv2d(inputs, inv_w.unsqueeze(-1).unsqueeze(-1))
            abs_det = torch.det(inv_w).abs()
            LDJ = abs_det.log().repeat(batch_size) * pixels
            check_tensors({'x': x, 'LDJ': LDJ}, str(self.__class__) + ': inverse')
            return x, LDJ

In [12]:


class ActNorm(nn.Module):
    def __init__(self, channels):
        super(ActNorm, self).__init__()
        self.channels = channels
        self.mean = nn.Parameter(torch.empty(channels, 1, 1))
        self.log_std = nn.Parameter(torch.empty(channels, 1, 1))
        self.initialized = False

    def forward(self, inputs, inverse=False):
        if not self.initialized:
            inputs_view = inputs.transpose(0, 1).contiguous().view(self.channels, -1)
            mean = inputs_view.mean(-1).view(-1, 1, 1)
            std = inputs_view.std(-1).view(-1, 1, 1)
            std = std.clamp(min=1e-16) # avoid nan
            self.mean.data.copy_(mean)
            self.log_std.data.copy_(std.log())
            self.initialized = True

        batch_size = inputs.size(0)
        pixels = inputs.size(-1) * inputs.size(-2)
        if not inverse:
            y = (inputs - self.mean) * torch.exp(-self.log_std)
            LDJ = -self.log_std.sum().repeat(batch_size) * pixels
            check_tensors({'y': y, 'LDJ': LDJ}, str(self.__class__) + ': forward')
            return y, LDJ
        else:
            x = inputs * torch.exp(self.log_std) + self.mean
            LDJ = self.log_std.sum().repeat(batch_size) * pixels
            check_tensors({'x': x, 'LDJ': LDJ}, str(self.__class__) + ': inverse')
            return x, LDJ

In [13]:

class Squeeze(nn.Module):
    def __init__(self):
        super(Squeeze, self).__init__()
    def forward(self, inputs, inverse=False):
        batch_size, c, h, w = inputs.size()
        if not inverse:
            x_view = inputs.contiguous().view(batch_size, c, h // 2, 2, w // 2, 2)
            y = x_view.permute(0, 1, 3, 5, 2, 4).contiguous().view(batch_size, c * 4, h // 2, w // 2)
            return y, torch.zeros(batch_size, device=inputs.device)
        else:
            y_view = inputs.contiguous().view(batch_size, c // 4, 2, 2, h, w)
            x = y_view.permute(0, 1, 4, 2, 5, 3).contiguous().view(batch_size, c // 4, h * 2, w * 2)
            return x, torch.zeros(batch_size, device=inputs.device)

In [14]:


class Glow(nn.Module):
    def __init__(self, input_size, channels_h, K, L, save_memory=False):
        super(Glow, self).__init__()
        self.L = L
        self.save_memory = save_memory
        self.output_sizes = []
        blocks = []
        c, h, w = input_size
        for l in range(L):
            block = [Squeeze()]
            c *= 4; h //= 2; w //= 2 # squeeze
            for _ in range(K):
                norm_layer = ActNorm(c)
                if save_memory:
                    perm_layer = flows.RandomRotation(c) # easily inversible ver
                else:
                    perm_layer = InversibleConv1x1(c)
                coupling_layer = AffineCouplingLayer(c, channels_h)
                block += [norm_layer, perm_layer, coupling_layer]
            blocks.append(FlowSequential(*block))
            self.output_sizes.append((c, h, w))
            c //= 2 # split
        self.blocks = nn.ModuleList(blocks)

    def forward(self, inputs, inverse=False):
        batch_size = inputs.size(0)
        if not inverse:
            h = inputs
            sum_LDJ = 0
            xs = []
            for l in range(self.L):
                if self.save_memory:
                    h, LDJ = flows.rev_sequential(self.blocks[l], h, inverse=False)
                else:
                    h, LDJ = self.blocks[l](h, inverse=False)
                sum_LDJ += LDJ
                if l < self.L - 1:
                    x, h = torch.chunk(h, 2, 1)
                else:
                    x = h
                xs.append(x.view(batch_size, -1))
            x = torch.cat(xs, -1)
            return x, sum_LDJ
        else:
            sections = [inputs.size(-1) // (2 ** (l + 1)) for l in range(self.L)]
            sections[-1] *= 2
            xs = torch.split(inputs, sections, -1)
            h = xs[-1]
            sum_LDJ = 0
            for l in reversed(range(self.L)):
                h = h.view(batch_size, *self.output_sizes[l])
                if self.save_memory:
                    h, LDJ = flows.rev_sequential(self.blocks[l], h, inverse=True)
                else:
                    h, LDJ = self.blocks[l](h, inverse=True)
                sum_LDJ += LDJ
                if l > 0:
                    h = torch.cat([xs[l - 1], h.view(batch_size, -1)], -1)
            y = h
            return y, sum_LDJ
    
    def log_prob(self, y):
        x, LDJ = self.forward(y, inverse=False)
        log_2pi = 0.79817986835
        log_p_x = -0.5 * (x.pow(2) + log_2pi).sum(-1) # x ~ N(0, I)
        log_p_y = log_p_x + LDJ
        return log_p_y

    def sample(self, n, device, temperature=1.0):
        size = self.output_sizes[0][0] * self.output_sizes[0][1] * self.output_sizes[0][2]
        x = torch.randn(n, size, device=device) * temperature # sample from the reduced-temperature distribution
        y, LDJ = self.forward(x, inverse=True)
        return y

In [15]:
import os
from os.path import join, exists
import argparse
import torch
import torch.utils.data
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import datasets, transforms


# Parse args
print('==> Args')
parser = argparse.ArgumentParser()
parser.add_argument('--datasets_dir', default='./MNIT-pp', type=str,
                    help='Directory of datasets')
parser.add_argument('--out_dir', default='./out_MNIT', type=str,
                    help='Directory to put the training result')
parser.add_argument('--channels_h', default=256, type=int,
                    help='Number of channels of hidden layers of conv-nets')
parser.add_argument('--K', default=16, type=int,
                    help='Depth of flow')
parser.add_argument('--L', default=2, type=int,
                    help='Number of levels')
parser.add_argument('--lr', default=1e-3, type=float,
                    help='Learning rate')
parser.add_argument('--weight_decay', default=1e-6, type=float,
                    help='Weight decay')
parser.add_argument('--batch_size', default=512, type=int,
                    help='Mini-batch size')
parser.add_argument('--epochs', default=10, type=int,
                    help='Number of epochs to train totally')
parser.add_argument('--save_memory', action='store_true',
                    help='Enables memory-saving backpropagation')
parser.add_argument('--display_interval', default=1, type=int,
                    help='Steps between logging training details')
parser.add_argument('--sample_interval', default=1, type=int,
                    help='Epochs between sampling')
parser.add_argument('--temperature', default=0.7, type=float,
                    help='Temperature of distribution to sample from')
parser.add_argument('--save_model_interval', default=5, type=int,
                    help='Epochs between saving model')
args = parser.parse_args(args=[])
print(vars(args))

# Device
print('==> Device')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device('cpu')
print(device)

# Dataset
print('==> Dataset')
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(args.datasets_dir, train=True,
                                 transform=transform, download=True)
test_dataset = datasets.MNIST(args.datasets_dir, train=False,
                                transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True)
image_size = train_dataset[0][0].size()
print('size of train data: %d' % len(train_dataset))
print('size of test data: %d' % len(test_dataset))
print('image size: %s' % str(image_size))

# Model
print('==> Model')
model = Glow(image_size, args.channels_h, args.K, args.L,
             save_memory=args.save_memory).to(device)
#print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

def train(epoch):
    # warmup
    since = time.time()
    lr = min(args.lr * epoch / 10, args.lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    model.train()
    sum_loss = 0
    count = 0
    for iteration, batch in enumerate(train_loader, 1):
        batch = batch[0].to(device)
        optimizer.zero_grad()
        loss = -model.log_prob(batch)
        mean_loss = loss.mean()
        mean_loss.backward()
        optimizer.step()

        sum_loss += loss.sum().item()
        if iteration % args.display_interval == 0:
            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('[%6d][%6d] | loss: %.4f' % \
                  (epoch, iteration, mean_loss.item()))
    average_loss = sum_loss / len(train_dataset)

    print('==> Epoch%d Average Loss | loss: %.4f' % \
          (epoch, average_loss))
    return average_loss

def test(epoch):
    model.eval()
    sum_loss = 0
    for iteration, batch in enumerate(test_loader, 1):
        batch = batch[0].to(device)
        with torch.no_grad():
            loss = -model.log_prob(batch)
        sum_loss += loss.sum().item()

    average_loss = sum_loss / len(test_dataset)
    print('==> Epoch%d Test Loss | loss: %.4f' % \
          (epoch, average_loss))
    if epoch % args.sample_interval == 0:
        n_samples = 64
        with torch.no_grad():
            sample = model.sample(n_samples, device).detach().cpu()
        save_image(sample, join(args.out_dir, 'sample_%06d.png' % epoch), nrow=8)
    return average_loss

def dump(train_loss, test_loss):
    with open(join(args.out_dir, 'dump.csv'), mode='a') as f:
        f.write('%.4f, %.4f\n' % (train_loss, test_loss))

if __name__ == '__main__':
    if not exists(args.out_dir):
        os.mkdir(args.out_dir)
    print('==> Start learning')
    for epoch in range(1, args.epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)
        dump(train_loss, test_loss)
        if epoch % args.save_model_interval == 0:
            params = model.state_dict()
            torch.save(params, join(args.out_dir, 'model_%06d' % epoch))

==> Args
{'datasets_dir': './MNIT-pp', 'out_dir': './out_MNIT', 'channels_h': 256, 'K': 16, 'L': 2, 'lr': 0.001, 'weight_decay': 1e-06, 'batch_size': 512, 'epochs': 10, 'save_memory': False, 'display_interval': 1, 'sample_interval': 1, 'temperature': 0.7, 'save_model_interval': 5}
==> Device
cpu
==> Dataset
size of train data: 60000
size of test data: 10000
image size: torch.Size([1, 28, 28])
==> Model
==> Start learning
Time elapsed 0m 13s
[     1][     1] | loss: -502.5470


KeyboardInterrupt: 