In [1]:
from abc import ABC, abstractmethod
import random
import math
import numpy as np
import torch
from torch import utils
from torch import nn
from torch import distributions
from torch import optim
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToPILImage

In [2]:
data_dir = '/Users/armandli/data/'

In [3]:
use_cuda = torch.cuda.is_available()
use_mps = torch.backends.mps.is_built()
if use_cuda:
    device = torch.device('cuda')
elif use_mps:
    device = torch.device('cpu')
else:
    device = torch.device('cpu')
cpu = torch.device('cpu')

In [4]:
default_batch_size = 256
loader_args = {'batch_size' : default_batch_size, 'shuffle' : True}
score_args = {'batch_size' : default_batch_size, 'shuffle' : False}
if use_cuda:
    loader_args.update({'pin_memory' : True})
    score_args.update({'pin_memory' : True})

In [5]:
class Reporter(ABC):
    @abstractmethod
    def report(self, typ, **metric):
        pass
    @abstractmethod
    def reset(self):
        pass

In [6]:
class SReporter(Reporter):
    def __init__(self):
        self.log = []
    def report(self, typ, **data):
        self.log.append((typ, data))
    def reset(self):
        self.log.clear()
    def loss(self, t):
        losses = []
        for (typ, data) in self.log:
            if typ == t:
                losses.append(data['loss'])
        return losses
    def loss(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count -= 1
        return float("inf")
    def eval_loss(self):
        return self.loss('eval')
    def train_loss(self):
        return self.loss('train')
    def eval_loss(self, idx):
        return self.loss('eval', idx)
    def train_loss(self, idx):
        return self.loss('train', idx)
    def get_record(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data
                    count -= 1
        return dict()
    def eval_record(self, idx):
        return self.get_record('eval', idx)
    def train_record(self, idx):
        return self.get_record('train', idx)

datasets

In [7]:
trainset = datasets.MNIST(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
evalset  = datasets.MNIST(root=data_dir, train=False, transform=transforms.ToTensor(), download=True)

In [8]:
trainset = datasets.CIFAR10(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)
evalset = datasets.CIFAR10(root=data_dir, train=True, transform=transforms.ToTensor(), download=True)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
trainset[0][0].shape

torch.Size([3, 32, 32])

In [10]:
len(trainset), len(evalset)

(50000, 50000)

In [11]:
train_loader = utils.data.DataLoader(dataset=trainset, **loader_args)
eval_loader = utils.data.DataLoader(dataset=evalset, **score_args)

Model

In [12]:
def relu_activation():
    return nn.ReLU(inplace=True)

In [13]:
def downsampling2DV2(in_c, out_c, stride, norm_layer):
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, 1, stride=stride),
        norm_layer(out_c),
    )

In [14]:
def upsampling2DV1(in_c, out_c, stride, norm_layer):
    return nn.Sequential(
        nn.ConvTranspose2d(in_c, out_c, 2, stride=stride),
        norm_layer(out_c),
    )

In [15]:
class Gate(nn.Module):
    def __init__(self):
        super(Gate, self).__init__()
    def forward(self, x):
        a, b = torch.chunk(x, 2, dim=1)
        return a * torch.sigmoid(b)

In [16]:
class ConcatELU(nn.Module):
    def __init__(self):
        super(ConcatELU, self).__init__()
    def forward(self, x):
        #concat at channel dim
        return F.elu(torch.cat([x, -x], dim=1))

In [17]:
class WeightNormLinear2d(nn.Module):
    def __init__(self, d_in, d_out):
        super(WeightNormLinear2d, self).__init__()
        self.layer = nn.utils.parametrizations.weight_norm(nn.Linear(d_in, d_out))
        self.d_out = d_out

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        shape = [int(d) for d in x.shape]
        x = self.layer(x.contiguous().view(shape[0]*shape[1]*shape[2], shape[3]))
        shape[-1] = self.d_out
        x = x.view(shape).permute(0, 3, 1, 2)
        return x

In [18]:
class WeightNormConv2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size, stride=1, padding=0):
        super(WeightNormConv2d, self).__init__()
        self.layer = nn.utils.parametrizations.weight_norm(nn.Conv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding))
        
    def forward(self, x):
        return self.layer(x)

In [19]:
class WeightNormConvTransposed2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size, stride, output_padding=1):
        super(WeightNormConvTransposed2d, self).__init__()
        self.layer = nn.utils.parametrizations.weight_norm(nn.ConvTranspose2d(in_c, out_c, kernel_size, stride, output_padding=output_padding))
    
    def forward(self, x):
        return self.layer(x)

In [20]:
class DownShift(nn.Module):
    def __init__(self):
        super(DownShift, self).__init__()
        #pad Left=0 Right=0 Up=1 Down=0
        self.pad = nn.ZeroPad2d((0,0,1,0))
    def forward(self, x):
        shape = x.shape
        x = x[:, :, :shape[2]-1, :]
        x = self.pad(x)
        return x

In [21]:
class DownShiftConv2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=(2,3), stride=(1,1), shift_down=False):
        super(DownShiftConv2d, self).__init__()
        self.layers = nn.Sequential(
            nn.ZeroPad2d((int((kernel_size[1]-1)/2),int((kernel_size[1]-1)/2),kernel_size[0]-1,0)),
            WeightNormConv2d(in_c, out_c, kernel_size, stride),
        )
        if shift_down:
            self.shift_down = DownShift()
        else:
            self.shift_down = nn.Identity()

    def forward(self, x):
        x = self.layers(x)
        x = self.shift_down(x)
        return x

In [22]:
class DownShiftDeconv2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=(2,3), stride=(2,2)):
        super(DownShiftDeconv2d, self).__init__()
        self.ks = kernel_size
        self.layer = WeightNormConvTransposed2d(in_c, out_c, kernel_size, stride, output_padding=1)
    
    def forward(self, x):
        x = self.layer(x)
        s = x.shape
        # correct the shape because TransposedConv2d would produce a few rows and columns bigger
        x = x[:, :, :(s[2]-self.ks[0]+1), int((self.ks[1]-1)/2):(s[3]-int((self.ks[1]-1)/2))]
        return x

In [23]:
class RightShift(nn.Module):
    def __init__(self):
        super(RightShift, self).__init__()
        #pad Left=1 Right=0 Up=0 Down=0
        self.pad = nn.ZeroPad2d((1,0,0,0))
    def forward(self, x):
        shape = x.shape
        x = x[:, :, :, :shape[3]-1]
        x = self.pad(x)
        return x

In [24]:
class DownRightShiftConv2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=(2,2), stride=(1,1), shift_right=False):
        super(DownRightShiftConv2d, self).__init__()
        self.layers = nn.Sequential(
            nn.ZeroPad2d((kernel_size[1]-1, 0, kernel_size[0]-1, 0)),
            WeightNormConv2d(in_c, out_c, kernel_size, stride),
        )
        if shift_right:
            self.shift_right = RightShift()
        else:
            self.shift_right = nn.Identity()
    
    def forward(self, x):
        x = self.layers(x)
        x = self.shift_right(x)
        return x

In [25]:
class DownRightShiftDeconv2d(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=(2,2), stride=(2,2)):
        super(DownRightShiftDeconv2d, self).__init__()
        self.ks = kernel_size
        self.layer = WeightNormConvTransposed2d(in_c, out_c, kernel_size, stride, output_padding=1)
    
    def forward(self, x):
        x = self.layer(x)
        s = x.shape
        # correct the shape because TransposedConv2d produces a few rows and columns bigger
        x = x[:, :, :(s[2]-self.ks[0]+1):, :(s[3]-self.ks[1]+1)]
        return x

In [26]:
class GatedResidualLayer(nn.Module):
    def __init__(self, nc, conv, skip=0, p_dropout=0.5):
        super(GatedResidualLayer, self).__init__()
        self.layer1 = nn.Sequential(
            ConcatELU(),
            conv(2*nc, nc),
        )
        self.layer2 = nn.Sequential(
            ConcatELU(),
            nn.Dropout2d(p_dropout),
            conv(2*nc, 2*nc),
            Gate(),
        )
        if skip > 0:
            self.skip = nn.Sequential(
                ConcatELU(),
                WeightNormLinear2d(2*skip*nc, nc),
            )
        
    def forward(self, x, a=None):
        s = x
        x = self.layer1(x)
        if a is not None:
            x += self.skip(a)
        x = self.layer2(x)
        return s + x

In [27]:
class PixelUpSample(nn.Module):
    def __init__(self, nlayers, nchannel):
        super(PixelUpSample, self).__init__()
        self.up_stream = nn.ModuleList([
            GatedResidualLayer(nchannel, DownShiftConv2d, skip=0)
            for _ in range(nlayers)
        ])
        self.upleft_stream = nn.ModuleList([
            GatedResidualLayer(nchannel, DownRightShiftConv2d, skip=1)
            for _ in range(nlayers)
        ])
        self.nlayers = nlayers
    def forward(self, up, upleft):
        ups, uplefts = [], []
        for i in range(self.nlayers):
            up = self.up_stream[i](up)
            upleft = self.upleft_stream[i](upleft, a=up)
            ups.append(up)
            uplefts.append(upleft)
        return ups, uplefts

In [28]:
class PixelDownSample(nn.Module):
    def __init__(self, nlayer, nchannel):
        super(PixelDownSample, self).__init__()
        self.up_stream = nn.ModuleList([
            GatedResidualLayer(nchannel, DownShiftConv2d, skip=1)
            for _ in range(nlayer)
        ])
        self.upleft_stream = nn.ModuleList([
            GatedResidualLayer(nchannel, DownRightShiftConv2d, skip=2)
            for _ in range(nlayer)
        ])
        self.nlayer = nlayer
    
    def forward(self, up, upleft, ups, uplefts):
        for i in range(self.nlayer):
            up = self.up_stream[i](up, a=ups.pop())
            upleft = self.upleft_stream[i](upleft, a=torch.cat((up, uplefts.pop()), 1))
        return up, upleft

In [29]:
class ResidualLayer2DV4(nn.Module):
    def __init__(self, in_c, out_c, ksz, act_layer, norm_layer, stride=1):
        super(ResidualLayer2DV4, self).__init__()
        if in_c <= out_c:
            self.c1 = nn.Conv2d(in_c, out_c, ksz, stride=stride, padding=int((ksz-1)/2))
            self.c2 = nn.Conv2d(out_c, out_c, ksz, stride=1, padding=int((ksz-1)/2))
        else:
            self.c1 = nn.ConvTranspose2d(in_c, out_c, ksz+1, stride=stride, padding=int((ksz-1)/2))
            self.c2 = nn.ConvTranspose2d(out_c, out_c, ksz, stride=1, padding=int((ksz-1)/2))
        self.a1 = act_layer()
        self.a2 = act_layer()
        self.b1 = norm_layer(in_c)
        self.b2 = norm_layer(out_c)
        
        if in_c < out_c:
            self.residual = downsampling2DV2(in_c, out_c, stride, norm_layer)
        elif in_c > out_c:
            self.residual = upsampling2DV1(in_c, out_c, stride, norm_layer)
        elif stride > 1:
            self.residual = downsampling2DV2(in_c, out_c, stride, norm_layer)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        s = x
        x = self.b1(x)
        x = self.a1(x)
        x = self.c1(x)
        x = self.b2(x)
        x = self.a2(x)
        x = self.c2(x)
        s = self.residual(s)
        x = x + s
        return x

In [30]:
class ConvVariationalEncoderV2(nn.Module):
    def __init__(self, ic, chmuls, hmul):
        super(ConvVariationalEncoderV2, self).__init__()
        layer1 = []
        outmul = 1
        for mul in chmuls:
            layer1.append(ResidualLayer2DV4(ic*outmul, ic*mul, 3, relu_activation, nn.BatchNorm2d, stride=2))
            outmul = mul
        self.layer1 = nn.ModuleList(layer1)
        self.mu_layer = nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1))
        self.sig_layer = nn.Sequential(
            nn.Conv2d(ic*outmul, ic*hmul, (3,3), (2,2), (1,1)),
            nn.Softplus(threshold=6),
        )
    
    def forward(self, x):
        for layer in self.layer1:
            x = layer(x)
        mu = self.mu_layer(x)
        sig = self.sig_layer(x)
        return (mu, sig)

In [31]:
class PixelCNN(nn.Module):
    def __init__(self, in_c, nresnet, nlayer, nchannel=80, nlogmix=10):
        super(PixelCNN, self).__init__()
        down_nlayer = [nresnet] + [nresnet+1 for _ in range(1, nlayer)]
        self.down_layers = nn.ModuleList([
            PixelDownSample(down_nlayer[i], nchannel) for i in range(nlayer)
        ])
        self.up_layers = nn.ModuleList([
            PixelUpSample(nresnet, nchannel) for _ in range(nlayer)
        ])
        self.downsize_up_stream = nn.ModuleList([
            DownShiftConv2d(nchannel, nchannel, stride=(2,2)) for _ in range((nlayer-1))
        ])
        self.downsize_upleft_stream = nn.ModuleList([
            DownRightShiftConv2d(nchannel, nchannel, stride=(2,2)) for _ in range(nlayer-1)
        ])
        self.upsize_up_stream = nn.ModuleList([
            DownShiftDeconv2d(nchannel, nchannel, stride=(2,2)) for _ in range(nlayer-1)
        ])
        self.upsize_upleft_stream = nn.ModuleList([
            DownRightShiftDeconv2d(nchannel, nchannel, stride=(2,2)) for _ in range(nlayer-1)
        ])
        self.up_init = DownShiftConv2d(in_c+1, nchannel, kernel_size=(2,3), shift_down=True)
        self.upleft_init = nn.ModuleList([
            DownShiftConv2d(in_c+1, nchannel, kernel_size=(1,3), shift_down=True),
            DownRightShiftConv2d(in_c+1, nchannel, kernel_size=(2,1), shift_right=True),
        ])
        self.out_layer = nn.Sequential(
            nn.ELU(),
            WeightNormLinear2d(nchannel, 10*nlogmix),
        )
        self.nlayer = nlayer
    
    def forward(self, x, device):
        shape = x.shape
        padding = torch.ones(shape[0], 1, shape[2], shape[3], device=device, requires_grad=False)
        x = torch.cat((x, padding), 1)
        
        # UP PASS
        ups = [self.up_init(x)]
        uplefts = [self.upleft_init[0](x) + self.upleft_init[1](x)]
        for i in range(self.nlayer):
            up_out, upleft_out = self.up_layers[i](ups[-1], uplefts[-1])
            ups.extend(up_out)
            uplefts.extend(upleft_out)
            if i < self.nlayer-1:
                ups.append(self.downsize_up_stream[i](ups[-1]))
                uplefts.append(self.downsize_upleft_stream[i](uplefts[-1]))

        # DOWN PASS
        up = ups.pop()
        upleft = uplefts.pop()
        for i in range(self.nlayer):
            up, upleft = self.down_layers[i](up, upleft, ups, uplefts)
            if i < self.nlayer-1:
                up = self.upsize_up_stream[i](up)
                upleft = self.upsize_upleft_stream[i](upleft)
        
        x = self.out_layer(upleft)
        return x

    def sample(self, batch_sz, img_shape, device):
        x = torch.zeros(batch_sz, img_shape[0], img_shape[1], img_shape[2]).to(device)
        shape = x.shape
        padding = torch.ones(shape[0], 1, shape[2], shape[3], device=device, requires_grad=False)
        x = torch.cat((x, padding), 1)
        
        # UP PASS
        ups = [self.up_init(x)]
        uplefts = [self.upleft_init[0](x) + self.upleft_init[1](x)]
        for i in range(self.nlayer):
            up_out, upleft_out = self.up_layers[i](ups[-1], uplefts[-1])
            ups.extend(up_out)
            uplefts.extend(upleft_out)
            if i < self.nlayer-1:
                ups.append(self.downsize_up_stream[i](ups[-1]))
                uplefts.append(self.downsize_upleft_stream[i](uplefts[-1]))

        # DOWN PASS
        up = ups.pop()
        upleft = uplefts.pop()
        for i in range(self.nlayer):
            up, upleft = self.down_layers[i](up, upleft, ups, uplefts)
            if i < self.nlayer-1:
                up = self.upsize_up_stream[i](up)
                upleft = self.upsize_upleft_stream[i](upleft)
        
        x = self.out_layer(upleft)
        return x

In [32]:
class PixelDecoderV1(nn.Module):
    def __init__(self, ic, chmuls, hmul, n_res, n_layer, nmix=10):
        super(PixelDecoderV1, self).__init__()
        layers = []
        outmul = hmul
        for mul in reversed(chmuls):
            layers.append(ResidualLayer2DV4(ic*outmul, ic*mul, 3, relu_activation, nn.BatchNorm2d, stride=2))
            outmul = mul
        layers.append(ResidualLayer2DV4(ic*outmul, ic, 3, relu_activation, nn.BatchNorm2d, stride=2))
        self.layers = nn.ModuleList(layers)
        self.pixel = PixelCNN(in_c=3, nresnet=n_res, nlayer=n_layer, nlogmix=nmix)
    
    def forward(self, x, device):
        for layer in self.layers:
            x = layer(x)
        x = self.pixel(x, device)
        return x

In [33]:
class PixelVariationalAutoEncoderV1(nn.Module):
    def __init__(self, ic, chmuls, hmul, dist, n_res=1, n_layer=1, nmix=10):
        super(PixelVariationalAutoEncoderV1, self).__init__()
        self.encoder = ConvVariationalEncoderV2(ic, chmuls, hmul)
        self.decoder = PixelDecoderV1(ic, chmuls, hmul, n_res, n_layer, nmix)
        self.dist = dist
    
    def forward(self, x, device):
        mu, sig = self.encode(x)
        x_h = self.decode(mu, sig, device)
        return (x_h, mu, sig)

    def encode(self, x):
        mu, sig = self.encoder(x)
        return (mu, sig)
    
    def decode(self, mu, sig, device):
        s = self.dist.sample(mu.shape).to(device)
        z = mu + sig * s
        x_h = self.decoder(z, device)
        return x_h

Training

In [34]:
class Reporter(ABC):
    @abstractmethod
    def report(self, typ, **metric):
        pass
    @abstractmethod
    def reset(self):
        pass

In [35]:
class SReporter(Reporter):
    def __init__(self):
        self.log = []
    def report(self, typ, **data):
        self.log.append((typ, data))
    def reset(self):
        self.log.clear()
    def loss(self, t):
        losses = []
        for (typ, data) in self.log:
            if typ == t:
                losses.append(data['loss'])
        return losses
    def loss(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data['loss']
                    count -= 1
        return float("inf")
    def eval_loss(self):
        return self.loss('eval')
    def train_loss(self):
        return self.loss('train')
    def eval_loss(self, idx):
        return self.loss('eval', idx)
    def train_loss(self, idx):
        return self.loss('train', idx)
    def get_record(self, t, idx):
        if idx >= 0:
            count = 0
            for (typ, data) in self.log:
                if typ == t:
                    if count == idx:
                        return data
                    count += 1
        else:
            count = -1
            for (typ, data) in reversed(self.log):
                if typ == t:
                    if count == idx:
                        return data
                    count -= 1
        return dict()
    def eval_record(self, idx):
        return self.get_record('eval', idx)
    def train_record(self, idx):
        return self.get_record('train', idx)

In [36]:
class DiscretizedMixLogisticLoss(nn.Module):
    def __init__(self, nmix):
        super(DiscretizedMixLogisticLoss, self).__init__()
        self.nmix = nmix
    
    def log_sum_exp(self, x):
        axis = len(x.shape)-1
        m, _ = torch.max(x, dim=axis)
        n, _ = torch.max(x, dim=axis, keepdim=True)
        return m + torch.log(torch.sum(torch.exp(x-n), dim=axis))

    def log_prob_from_logits(self, x):
        axis = len(x.shape)-1
        m, _ = torch.max(x, dim=axis, keepdim=True)
        return x - m - torch.log(torch.sum(torch.exp(x-m), dim=axis, keepdim=True))
    
    def forward(self, target, prediction, device):
        nmix = self.nmix
        target = target.permute(0, 2, 3, 1)
        prediction = prediction.permute(0, 2, 3, 1)
        ts = list(target.shape)

        # unpack prediction parameters
        lp = prediction[:,:,:,:nmix]
        prediction = prediction[:,:,:,nmix:].view(ts+[nmix*3]) # 3 for mean, scale, coeff
        means = prediction[:,:,:,:,:nmix]
        log_scales = torch.clamp(prediction[:,:,:,:,nmix:nmix*2], min=-7.)
        coeffs = torch.tanh(prediction[:,:,:,:,nmix*2:nmix*3])
        target = target.unsqueeze(-1) + torch.zeros(ts+[nmix], requires_grad=False).to(device)
        
        m2 = (means[:,:,:,1,:]+coeffs[:,:,:,0,:]*target[:,:,:,0,:]).view(ts[0],ts[1],ts[2],1,nmix)
        m3 = (means[:,:,:,2,:]+coeffs[:,:,:,1,:]*target[:,:,:,0,:]+coeffs[:,:,:,2,:]*target[:,:,:,1,:]).view(ts[0],ts[1],ts[2],1,nmix)
        
        centered_target = target - torch.cat((means[:,:,:,0,:].unsqueeze(3), m2, m3), dim=3)
        inv_stdv  = torch.exp(-log_scales)
        plus_in   = inv_stdv * (centered_target + 1./255.)
        cdf_plus  = torch.sigmoid(plus_in)
        min_in    = inv_stdv * (centered_target - 1./255.)
        cdf_minus = torch.sigmoid(min_in)

        log_cdf_plus          = plus_in - F.softplus(plus_in)
        log_one_minus_cdf_min = -F.softplus(min_in)
        cdf_delta             = cdf_plus - cdf_minus
        log_pdf_mid = (inv_stdv*centered_target) - log_scales - 2.*F.softplus(inv_stdv*centered_target)

        inner_inner_cond = (cdf_delta > 1e-5).float()
        inner_inner_out  = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1-inner_inner_cond)*(log_pdf_mid-math.log(127.5))
        inner_cond       = (target > 0.999).float()
        inner_out        = inner_cond*log_one_minus_cdf_min + (1.-inner_cond)*inner_inner_out
        cond             = (target < -0.999).float()
        log_probs        = cond*log_cdf_plus + (1.-cond)*inner_out
        log_probs        = torch.sum(log_probs, dim=3) + self.log_prob_from_logits(lp)
        return -torch.sum(self.log_sum_exp(log_probs))

In [37]:
class PixelVAELoss(nn.Module):
    def __init__(self, nmix):
        super(PixelVAELoss, self).__init__()
        self.mixlogloss = DiscretizedMixLogisticLoss(nmix)
    
    def forward(self, target, prediction, mu, sig, device):
        recon_loss = self.mixlogloss(target, prediction, device)
        dkl_loss = (sig**2. + mu**2. - torch.log(sig) - 0.5).sum()
        return recon_loss + dkl_loss

In [38]:
def vae_image_train(model, device, loader, optimizer, loss, epoch, reporter):
    model.train()
    total_loss = 0.
    for x, _ in loader:
        optimizer.zero_grad()
        x = x.to(device)
        x_h, mu, sig = model(x, device)
        l = loss(x, x_h, mu, sig, device)
        l.backward()
        optimizer.step()
        total_loss += l.item()
        print(f"Epoch {epoch}: {l.item()}")
    total_loss /= float(len(loader.dataset))
    reporter.report(typ='train', loss=total_loss)
    print(f"Train Loss: {total_loss}")

In [39]:
def vae_image_validate(model, device, loader, loss, train_epoch, reporter):
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            x_h, mu, sig = model(x, device)
            total_loss += loss(x, x_h, mu, sig, device)
    total_loss /= float(len(loader.dataset))
    reporter.report(typ='eval', loss=total_loss)

In [40]:
def vae_image_train_validate(
        model,
        device,
        train_loader,
        eval_loader,
        optimizer,
        scheduler,
        loss,
        total_epoch,
        patience,
        patience_decay,
        reporter,
):
    validation_loss = float("inf")
    patience_count = patience
    patience = int(patience * patience_decay)
    reset_patience = False
    for epoch in range(total_epoch):
        vae_image_train(model, device, train_loader, optimizer, loss, epoch, reporter)
        vae_image_validate(model, device, eval_loader, loss, epoch, reporter)
        new_validation_loss = reporter.eval_loss(-1)
        print(f"Epoch {epoch} VLoss: {new_validation_loss}")
        scheduler.step(new_validation_loss)
        if new_validation_loss < validation_loss:
            validation_loss = new_validation_loss
            patience_count = patience
            if reset_patience:
                patience = int(patience * patience_decay)
                reset_patience = False
        else:
            validation_loss = new_validation_loss
            patience_count -= 1
            reset_patience = True
            if patience_count <= 0:
                print(f"Improvement stopped. VLoss: {validation_loss}")
                break

In [41]:
norm_dist = distributions.Normal(0., 1.)
inc = trainset[0][0].shape[0]
nmix=10
model = PixelVariationalAutoEncoderV1(inc, [inc*2, inc*4, inc*8], inc*16, norm_dist, n_res=1, n_layer=3, nmix=nmix).to(device)

In [42]:
learning_rate = 0.0001
total_epochs = 10
patience = 8
patience_decay = 0.9
optimizer = optim.Adam(model.parameters(recurse=True), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=patience, threshold=0.00001)
loss = PixelVAELoss(nmix)
reporter = SReporter()

In [43]:
vae_image_train_validate(model, device, train_loader, eval_loader, optimizer, scheduler, loss, total_epochs, patience, patience_decay, reporter)

Epoch 0: 5084684.5
Epoch 0: 5070136.0
Epoch 0: 5047143.0


KeyboardInterrupt: 

load model

validation

In [None]:
to_img = ToPILImage()

In [None]:
def sample_image(dataset):
    i = random.randint(0, len(dataset))
    return dataset[i][0]

In [None]:
x = sample_image(evalset)
x.shape

In [None]:
x_h, _, _ = model(torch.unsqueeze(x, dim=0).to(device), device)
x_h = x_h.to(device)
x_h = x_h.reshape(x.shape)

In [None]:
to_img(x_h)

In [None]:
to_img(x)