## Imports

In [35]:
import math
import pathlib
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss

from super_resolution.src.sen2venus_dataset import create_train_test_split

In [20]:
DATA_DIR = pathlib.Path("C:/Users/Mitch/stat3007_data")
SITES_DIR = DATA_DIR / "sites"

In [9]:
# ----- Helpers -----


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class UnFlatten(nn.Module):
    def __init__(self, unflatten_size):
        super().__init__()
        if isinstance(unflatten_size, tuple):
            self.c = unflatten_size[0]
            self.h = unflatten_size[1]
            self.w = unflatten_size[2]
        elif isinstance(unflatten_size, int):
            self.c = unflatten_size
            self.h = 1
            self.w = 1

    def forward(self, x):
        return x.view(x.size(0), self.c, self.h, self.w)


# ----- 2D Convolutions -----


# Conv2d init_parameters from: https://github.com/vlievin/biva-pytorch/blob/master/biva/layers/convolution.py
class Conv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        weightnorm=True,
        act=None,
        drop_prob=0.0,
    ):
        super().__init__()
        self.weightnorm = weightnorm
        self.initialized = True

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.act = nn.ELU(inplace=True) if act is not None else Identity()
        self.drop_prob = drop_prob

        if self.weightnorm:
            self.initialized = False
            self.conv = nn.utils.weight_norm(self.conv, dim=0, name="weight")

    def forward(self, input):
        if not self.initialized:
            self.init_parameters(input)
        return F.dropout(self.act(self.conv(input)), p=self.drop_prob, training=True)

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        self.initialized = True
        if self.weightnorm:
            # initial values
            self.conv._parameters["weight_v"].data.normal_(mean=0, std=init_scale)
            self.conv._parameters["weight_g"].data.fill_(1.0)
            self.conv._parameters["bias"].data.fill_(0.0)
            init_scale = 0.01
            # data dependent init
            x = self.conv(x)
            t = x.view(x.size()[0], x.size()[1], -1)
            t = t.permute(0, 2, 1).contiguous()
            t = t.view(-1, t.size()[-1])
            m_init, v_init = torch.mean(t, 0), torch.var(t, 0)
            scale_init = init_scale / torch.sqrt(v_init + eps)

            self.conv._parameters["weight_g"].data = self.conv._parameters[
                "weight_g"
            ].data * scale_init[:, None].view(
                self.conv._parameters["weight_g"].data.size()
            )
            self.conv._parameters["bias"].data = (
                self.conv._parameters["bias"].data - m_init * scale_init
            )
            return scale_init[None, :, None, None] * (x - m_init[None, :, None, None])


class ConvTranspose2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        output_padding=0,
        dilation=1,
        groups=1,
        bias=True,
        weightnorm=True,
        act=None,
        drop_prob=0.0,
    ):
        super().__init__()
        self.weightnorm = weightnorm
        self.initialized = True

        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.act = nn.ELU(inplace=True) if act is not None else Identity()
        self.drop_prob = drop_prob

        if self.weightnorm:
            self.initialized = False
            self.conv = nn.utils.weight_norm(self.conv, dim=1, name="weight")

    def forward(self, input):
        if not self.initialized:
            self.init_parameters(input)
        return F.dropout(self.act(self.conv(input)), p=self.drop_prob, training=True)

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        self.initialized = True
        if self.weightnorm:
            # initial values
            self.conv._parameters["weight_v"].data.normal_(mean=0, std=init_scale)
            self.conv._parameters["weight_g"].data.fill_(1.0)
            self.conv._parameters["bias"].data.fill_(0.0)
            init_scale = 0.01
            # data dependent init
            x = self.conv(x)
            t = x.view(x.size()[0], x.size()[1], -1)
            t = t.permute(0, 2, 1).contiguous()
            t = t.view(-1, t.size()[-1])
            m_init, v_init = torch.mean(t, 0), torch.var(t, 0)
            scale_init = init_scale / torch.sqrt(v_init + eps)

            self.conv._parameters["weight_g"].data = self.conv._parameters[
                "weight_g"
            ].data * scale_init[None, :].view(
                self.conv._parameters["weight_g"].data.size()
            )
            self.conv._parameters["bias"].data = (
                self.conv._parameters["bias"].data - m_init * scale_init
            )
            return scale_init[None, :, None, None] * (x - m_init[None, :, None, None])


# ----- Up and Down Sampling -----


class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels, drop_prob=0.0):
        super().__init__()
        self.core_nn = nn.Sequential(
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                drop_prob=drop_prob,
            )
        )

    def forward(self, input):
        return self.core_nn(input)


class Upsample(nn.Module):
    def __init__(self, in_channels, out_channels, drop_prob=0.0):
        super().__init__()
        self.core_nn = nn.Sequential(
            ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
                drop_prob=drop_prob,
            )
        )

    def forward(self, input):
        return self.core_nn(input)


# ----- Gated/Attention Blocks -----


class CALayer(nn.Module):
    """
    ChannelWise Gated Layer.
    """

    def __init__(self, channel, reduction=8, drop_prob=0.0):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.ca_block = nn.Sequential(
            Conv2d(
                channel,
                channel // reduction,
                kernel_size=1,
                stride=1,
                padding=0,
                drop_prob=drop_prob,
            ),
            Conv2d(
                channel // reduction,
                channel,
                kernel_size=1,
                stride=1,
                padding=0,
                act=None,
                drop_prob=drop_prob,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca_block(y)
        return x * y


# ----- DenseNets -----


class DenseNetBlock(nn.Module):
    def __init__(self, inplanes, growth_rate, drop_prob=0.0):
        super().__init__()
        self.dense_block = nn.Sequential(
            Conv2d(
                inplanes,
                4 * growth_rate,
                kernel_size=1,
                stride=1,
                padding=0,
                drop_prob=drop_prob,
            ),
            Conv2d(
                4 * growth_rate,
                growth_rate,
                kernel_size=3,
                stride=1,
                padding=1,
                drop_prob=drop_prob,
                act=None,
            ),
        )

    def forward(self, input):
        y = self.dense_block(input)
        y = torch.cat([input, y], dim=1)
        return y


class DenseNetLayer(nn.Module):
    def __init__(self, inplanes, growth_rate, steps, drop_prob=0.0):
        super().__init__()
        self.activation = nn.ELU(inplace=True)

        net = []
        for step in range(steps):
            net.append(DenseNetBlock(inplanes, growth_rate, drop_prob=drop_prob))
            net.append(self.activation)
            inplanes += growth_rate

        net.append(CALayer(inplanes, drop_prob=drop_prob))
        self.core_nn = nn.Sequential(*net)

    def forward(self, input):
        return self.core_nn(input)


class DenselyNetwork(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        growth_rate,
        steps,
        blocks,
        act=None,
        drop_prob=0.0,
    ):
        super().__init__()
        # downscale block
        net = []
        for i in range(blocks):
            net.append(
                DenseNetLayer(in_channels, growth_rate, steps, drop_prob=drop_prob)
            )
            in_channels = in_channels + growth_rate * steps

        # output layer
        net.append(
            Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0, act=None
            )
        )

        self.core_nn = nn.Sequential(*net)

    def forward(self, input):
        return self.core_nn(input)


class DenselyEncoder(nn.Module):
    def __init__(
        self, in_channels, out_channels, growth_rate, steps, scale_factor, drop_prob=0.0
    ):
        super().__init__()
        # downscale block
        net = []
        for i in range(scale_factor):
            net.append(
                DenseNetLayer(in_channels, growth_rate, steps, drop_prob=drop_prob)
            )
            in_channels = in_channels + growth_rate * steps
            net.append(Downsample(in_channels, 2 * in_channels, drop_prob=drop_prob))
            in_channels *= 2
            growth_rate *= 2

        # output block
        net.append(DenseNetLayer(in_channels, growth_rate, steps, drop_prob=drop_prob))
        in_channels = in_channels + growth_rate * steps

        # output layer
        net.append(
            Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0, act=None
            )
        )

        self.core_nn = nn.Sequential(*net)

    def forward(self, input):
        return self.core_nn(input)


class DenselyDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        growth_rate=16,
        steps=3,
        scale_factor=2,
        drop_prob=0.0,
    ):
        super().__init__()
        # upsample block
        net = []
        for i in range(scale_factor):
            net.append(
                DenseNetLayer(in_channels, growth_rate, steps, drop_prob=drop_prob)
            )
            in_channels = in_channels + growth_rate * steps
            net.append(Upsample(in_channels, in_channels // 2, drop_prob=drop_prob))
            in_channels = in_channels // 2
            growth_rate = growth_rate // 2

        # output block
        net.append(
            Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1, act=None
            )
        )

        self.core_nn = nn.Sequential(*net)

    def forward(self, x):
        return self.core_nn(x)

In [27]:
# Distributions
NMIX = 10


def n_embenddings(nc, distribution="dmol"):

    if distribution == "dmol":
        nmix = NMIX
        n_emb = (nc * 3 + 1) * nmix
    else:
        raise NotImplementedError
    return n_emb


def log_normal_diag(z, z_mu, z_logvar):
    eps = 1e-12
    log_probs = (
        z_logvar + (z - z_mu).pow(2).div(z_logvar.exp() + eps) + math.log(math.pi * 2.0)
    )
    log_probs = -0.5 * log_probs.view(z.size(0), -1).sum(dim=1)
    return log_probs


def logsumexp(x, dim=None):
    if dim is None:
        xmax = x.max()
        xmax_ = x.max()
        return xmax_ + torch.log(torch.exp(x - xmax).sum())
    else:
        xmax, _ = x.max(dim, keepdim=True)
        xmax_, _ = x.max(dim)
        return xmax_ + torch.log(torch.exp(x - xmax).sum(dim))


def dmol_loss(x, output, nc=3, nmix=NMIX, nbits=8):
    """Discretized mix of logistic distributions loss"""
    bits = 2.0**nbits
    scale_min, scale_max = [0.0, 1.0]

    bin_size = (scale_max - scale_min) / (bits - 1.0)
    eps = 1e-12

    # unpack values
    batch_size, nmix, H, W = output[:, :nmix].size()
    logit_probs = output[:, :nmix]
    means = output[:, nmix : (nc + 1) * nmix].view(batch_size, nmix, nc, H, W)
    logscales = output[:, (nc + 1) * nmix : (nc * 2 + 1) * nmix].view(
        batch_size, nmix, nc, H, W
    )
    coeffs = output[:, (nc * 2 + 1) * nmix : (nc * 2 + 4) * nmix].view(
        batch_size, nmix, nc, H, W
    )

    # activation functions and resize
    logit_probs = F.log_softmax(logit_probs, dim=1)
    logscales = logscales.clamp(min=-7.0)
    coeffs = coeffs.tanh()

    x = x.unsqueeze(1)
    means = means.view(batch_size, *means.size()[1:])
    logscales = logscales.view(batch_size, *logscales.size()[1:])
    coeffs = coeffs.view(batch_size, *coeffs.size()[1:])
    logit_probs = logit_probs.view(batch_size, *logit_probs.size()[1:])

    # channel-wise conditional modelling sub-pixels
    mean0 = means[:, :, 0]
    mean1 = means[:, :, 1] + coeffs[:, :, 0] * x[:, :, 0]
    mean2 = means[:, :, 2] + coeffs[:, :, 1] * x[:, :, 0] + coeffs[:, :, 2] * x[:, :, 1]
    means = torch.stack([mean0, mean1, mean2], dim=2)

    # compute log CDF for the normal cases (lower < x < upper)
    x_plus = torch.exp(-logscales) * (x - means + 0.5 * bin_size)
    x_minus = torch.exp(-logscales) * (x - means - 0.5 * bin_size)
    cdf_delta = torch.sigmoid(x_plus) - torch.sigmoid(x_minus)
    log_cdf_mid = torch.log(cdf_delta.clamp(min=eps))

    # Extreme Case #1: x > upper (before scaling)
    upper = scale_max - 0.5 * bin_size
    mask_upper = x.le(upper).float()
    log_cdf_up = -F.softplus(x_minus)

    # Extreme Case #2: x < lower (before scaling)
    lower = scale_min + 0.5 * bin_size
    mask_lower = x.ge(lower).float()
    log_cdf_low = x_plus - F.softplus(x_plus)

    # Extreme Case #3: probability on a sub-pixel is below 1e-5
    #   --> If the probability on a sub-pixel is below 1e-5, we use an approximation
    #       based on the assumption that the log-density is constant in the bin of
    #       the observed sub-pixel value
    x_in = torch.exp(-logscales) * (x - means)
    mask_delta = cdf_delta.gt(1e-5).float()
    log_cdf_approx = x_in - logscales - 2.0 * F.softplus(x_in) + np.log(bin_size)

    # Compute log CDF w/ extrime cases
    log_cdf = log_cdf_mid * mask_delta + log_cdf_approx * (1.0 - mask_delta)
    log_cdf = log_cdf_low * (1.0 - mask_lower) + log_cdf * mask_lower
    log_cdf = log_cdf_up * (1.0 - mask_upper) + log_cdf * mask_upper

    # Compute log loss
    loss = logsumexp(log_cdf.sum(dim=2) + logit_probs, dim=1)
    return loss.view(loss.shape[0], -1).sum(1)


def sample_from_dmol(x_mean, nc=3, nmix=NMIX, random_sample=False):
    """Sample from Discretized mix of logistic distribution"""
    scale_min, scale_max = [0.0, 1.0]

    # unpack values
    logit_probs = x_mean[:, :nmix]  # pi
    batch_size, nmix, H, W = logit_probs.size()
    means = x_mean[:, nmix : (nc + 1) * nmix].view(batch_size, nmix, nc, H, W)  # mean
    logscales = x_mean[:, (nc + 1) * nmix : (nc * 2 + 1) * nmix].view(
        batch_size, nmix, nc, H, W
    )  # log_var
    coeffs = x_mean[:, (nc * 2 + 1) * nmix : (nc * 2 + 4) * nmix].view(
        batch_size, nmix, nc, H, W
    )  # chan_coeff

    # activation functions
    logscales = logscales.clamp(min=-7.0)
    logit_probs = F.log_softmax(logit_probs, dim=1)
    coeffs = coeffs.tanh()

    # sample mixture
    index = (
        logit_probs.argmax(dim=1, keepdim=True)
        + logit_probs.new_zeros(means.size(0), *means.size()[2:]).long()
    )
    one_hot = means.new_zeros(means.size()).scatter_(1, index.unsqueeze(1), 1)
    means = (means * one_hot).sum(dim=1)
    logscales = (logscales * one_hot).sum(dim=1)
    coeffs = (coeffs * one_hot).sum(dim=1)
    x = means

    if random_sample:
        # sample y from CDF
        u = means.new_zeros(means.size()).uniform_(1e-5, 1 - 1e-5)
        # from y map it to the corresponing x
        x = x + logscales.exp() * (torch.log(u) - torch.log(1.0 - u))

    # concat image channels
    x0 = (x[:, 0]).clamp(min=scale_min, max=scale_max)
    x1 = (x[:, 1] + coeffs[:, 0] * x0).clamp(min=scale_min, max=scale_max)
    x2 = (x[:, 2] + coeffs[:, 1] * x0 + coeffs[:, 2] * x1).clamp(
        min=scale_min, max=scale_max
    )
    x = torch.stack([x0, x1, x2], dim=1)
    return x

In [72]:
class q_u(nn.Module):
    """Encoder q(u|y)"""

    def __init__(self, output_shape, input_shape):
        super().__init__()
        nc_in = input_shape[0]
        nc_out = 2 * output_shape[0]

        self.core_nn = nn.Sequential(
            DenselyEncoder(
                in_channels=nc_in,
                out_channels=nc_out,
                growth_rate=64,
                steps=3,
                scale_factor=1,
            )
        )

    def forward(self, input):
        mu, logvar = self.core_nn(input).chunk(2, 1)
        return mu, F.hardtanh(logvar, min_val=-7, max_val=7.0)


class p_y(nn.Module):
    """Dencoder p(y|u)"""

    def __init__(self, output_shape, input_shape):
        super().__init__()
        nc_in = input_shape[0]
        nc_out = n_embenddings(output_shape[0])

        self.core_nn = nn.Sequential(
            DenselyDecoder(
                in_channels=nc_in,
                out_channels=nc_out,
                growth_rate=128,
                steps=4,
                scale_factor=1,
            )
        )

    def forward(self, input):
        logits = self.core_nn(input)
        return logits


class q_z(nn.Module):
    """Encoder q(z|x)"""

    def __init__(self, output_shape, input_shape):
        super().__init__()
        nc_in = input_shape[0]
        nc_out = 2 * output_shape[0]

        self.core_nn = nn.Sequential(
            DenselyEncoder(
                in_channels=nc_in,
                out_channels=nc_out,
                growth_rate=16,
                steps=4,
                scale_factor=2,
            )
        )

    def forward(self, input):
        mu, logvar = self.core_nn(input).chunk(2, 1)
        return mu, F.hardtanh(logvar, min_val=-7, max_val=7.0)


class DenselyDecoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        growth_rate=16,
        steps=3,
        scale_factor=2,
        drop_prob=0.0,
    ):
        super().__init__()
        # upsample block
        net = []
        for i in range(scale_factor):
            net.append(
                DenseNetLayer(in_channels, growth_rate, steps, drop_prob=drop_prob)
            )
            in_channels = in_channels + growth_rate * steps
            net.append(Upsample(in_channels, in_channels // 2, drop_prob=drop_prob))
            in_channels = in_channels // 2
            growth_rate = growth_rate // 2

        # output block
        net.append(
            Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1, act=None
            )
        )

        self.core_nn = nn.Sequential(*net)

    def forward(self, x):
        return self.core_nn(x)


class p_z(nn.Module):
    """Encoder p(z| y, u)"""

    def __init__(self, output_shape, input_shape):
        super().__init__()
        nc_y_in, nc_u_in = input_shape[0][0], input_shape[1][0]
        nc_out = 2 * output_shape[0]

        self.y_nn = nn.Sequential(
            DenselyEncoder(
                in_channels=nc_y_in,
                out_channels=nc_out // 2,
                growth_rate=32,
                steps=5,
                scale_factor=1,
            ),
            nn.ELU(inplace=True),
        )

        self.u_nn = nn.Sequential(
            DenselyNetwork(
                in_channels=nc_u_in,
                out_channels=nc_out // 2,
                growth_rate=64,
                steps=3,
                blocks=3,
                act=True,
            )
        )

        self.core_nn = nn.Sequential(
            DenselyNetwork(
                in_channels=nc_out,
                out_channels=nc_out,
                growth_rate=64,
                steps=3,
                blocks=3,
                act=None,
            )
        )

    def forward(self, input):
        y, u = input[0], input[1]

        y_out = self.y_nn(y)
        u_out = self.u_nn(u)

        joint = torch.cat((y_out, u_out), 1)

        mu, logvar = self.core_nn(joint).chunk(2, 1)
        return mu, F.hardtanh(logvar, min_val=-7, max_val=7.0)


class p_x(nn.Module):
    """p(x| y, z)"""

    def __init__(self, output_shape, input_shape):
        super().__init__()
        nc_y_in, nc_z_in = input_shape[0][0], input_shape[1][0]
        nc_out = n_embenddings(output_shape[0])

        self.z_nn = nn.Sequential(
            DenselyDecoder(
                in_channels=nc_z_in,
                out_channels=nc_out,
                growth_rate=64,
                steps=8,
                scale_factor=2,
            )
        )

        self.core_nn = nn.Sequential(
            DenselyNetwork(
                in_channels=nc_out + 3,
                out_channels=nc_out,
                growth_rate=64,
                steps=5,
                blocks=3,
                act=None,
            )
        )

    def forward(self, input):
        y, z = input[0], input[1]

        y_out = F.interpolate(y, size=[256, 256], align_corners=False, mode="bilinear")
        z_out = self.z_nn(z)

        joint = torch.cat((y_out, z_out), 1)
        logits = self.core_nn(joint)
        return logits

In [73]:
from torch.optim.lr_scheduler import _LRScheduler


class LowerBoundedExponentialLR(_LRScheduler):
    def __init__(self, optimizer, gamma, lower_bound, last_epoch=-1):
        self.gamma = gamma
        self.lower_bound = lower_bound
        super(LowerBoundedExponentialLR, self).__init__(optimizer, last_epoch)

    def _get_lr(self, base_lr):
        lr = base_lr * self.gamma**self.last_epoch
        if lr < self.lower_bound:
            lr = self.lower_bound
        return lr

    def get_lr(self):
        return [self._get_lr(base_lr) for base_lr in self.base_lrs]

In [82]:
# Inspired by https://github.com/ioangatop/srVAE
class StandardNormal:
    def __init__(self, z_shape):
        self.z_shape = z_shape

    def sample(self, n_samples=1, **kwargs):
        return torch.randn((n_samples, *self.z_shape))

    def log_p(self, z, **kwargs):
        return self.forward(z)

    def forward(self, z, **kwargs):
        """Outputs the log p(z)."""
        log_probs = z.pow(2) + math.log(math.pi * 2.0)
        log_probs = -0.5 * log_probs.view(z.size(0), -1).sum(dim=1)
        return log_probs

    def __call__(self, z, **kwargs):
        return self.forward(z, **kwargs)

    def __str__(self):
        return "StandardNormal"


class ELBOLoss(_Loss):
    """
    Computes negative ELBO loss and diagnostics.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, input, outputs, model):
        elbo = (
            model.module.calculate_elbo
            if isinstance(model, nn.DataParallel)
            else model.calculate_elbo
        )
        return elbo(input, outputs)


class srVAE(nn.Module):
    def __init__(
        self,
        x_shape,  # 3, 256, 256
        y_shape=(3, 128, 128),
        u_dim: tuple[int, int, int] = (16, 8, 8),
        z_dim: tuple[int, int, int] = (16, 8, 8),
        device: str | torch.device = "cpu",
    ):
        super().__init__()
        self.device = device
        self.x_shape = x_shape
        self.y_shape = y_shape

        self.u_shape = u_dim
        self.z_shape = z_dim

        # q(y|x): deterministic "compressed" transformation
        self.compressed_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((self.y_shape[1], self.y_shape[2])),
                transforms.ToTensor(),
            ]
        )

        # p(u)
        self.p_u = StandardNormal(self.u_shape)

        # q(u | y)
        self.q_u = q_u(self.u_shape, self.y_shape)

        # p(z | y)
        self.p_z = p_z(self.z_shape, (self.y_shape, self.u_shape))

        # q(z | x)
        self.q_z = q_z(self.z_shape, self.x_shape)

        # p(y | u)
        self.p_y = p_y(self.y_shape, self.u_shape)

        # p(x | y, z)
        self.p_x = p_x(self.x_shape, (self.y_shape, self.z_shape))

        # likelihood distribution
        self.recon_loss = partial(dmol_loss)
        self.sample_distribution = partial(sample_from_dmol)

    def compressed_transformation(self, input):
        y = []
        for x in input:
            y.append(self.compressed_transform(x.cpu()))
        return torch.stack(y).to(self.device)

    def initialize(self, dataloader):
        """Data dependent init for weight normalization
        (Automatically done during the first forward pass).
        """
        with torch.no_grad():
            x, _ = next(iter(dataloader))
            x = x.to(self.device)
            output = self.forward(x)
            self.calculate_elbo(x, output)
        return

    @staticmethod
    def reparameterize(z_mean, z_log_var):
        """z ~ N(z| z_mu, z_logvar)"""
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

    @torch.no_grad()
    def generate(self, n_samples=20):
        # u ~ p(u)
        u = self.p_u.sample(self.u_shape, n_samples=n_samples, device=self.device).to(
            self.device
        )

        # p(y|u)
        y_logits = self.p_y(u)
        y_hat = self.sample_distribution(y_logits, nc=self.y_shape[0])

        # z ~ p(z|y, u)
        z_p_mean, z_p_logvar = self.p_z((y_hat, u))
        z_p = self.reparameterize(z_p_mean, z_p_logvar)

        # x ~ p(x|y,z)
        x_logits = self.p_x((y_hat, z_p))
        x_hat = self.sample_distribution(x_logits, nc=self.x_shape[0])
        return x_hat, y_hat

    @torch.no_grad()
    def reconstruct(self, x, **kwargs):
        outputs = self.forward(x)
        y_hat = self.sample_distribution(outputs.get("y_logits"), nc=self.y_shape[0])
        x_hat = self.sample_distribution(outputs.get("x_logits"), nc=self.x_shape[0])
        return outputs.get("y"), y_hat, x_hat

    @torch.no_grad()
    def super_resolution(self, y):
        # u ~ q(u| y)
        u_q_mean, u_q_logvar = self.q_u(y)
        u_q = self.reparameterize(u_q_mean, u_q_logvar)

        # z ~ p(z|y)
        z_p_mean, z_p_logvar = self.p_z((y, u_q))
        z_p = self.reparameterize(z_p_mean, z_p_logvar)

        # x ~ p(x|y,z)
        x_logits = self.p_x((y, z_p))
        x_hat = self.sample_distribution(x_logits)
        return x_hat

    def calculate_elbo(self, x, outputs, **kwargs):
        # unpack variables
        y, x_logits, y_logits = (
            outputs.get("y"),
            outputs.get("x_logits"),
            outputs.get("y_logits"),
        )
        u_q, u_q_mean, u_q_logvar = (
            outputs.get("u_q"),
            outputs.get("u_q_mean"),
            outputs.get("u_q_logvar"),
        )
        z_q, z_q_mean, z_q_logvar = (
            outputs.get("z_q"),
            outputs.get("z_q_mean"),
            outputs.get("z_q_logvar"),
        )
        z_p_mean, z_p_logvar = outputs.get("z_p_mean"), outputs.get("z_p_logvar")

        # Reconstraction loss
        RE_x = self.recon_loss(x, x_logits, nc=self.x_shape[0])
        RE_y = self.recon_loss(y, y_logits, nc=self.y_shape[0])

        # Regularization loss
        log_p_u = self.p_u.log_p(u_q, dim=1)
        log_q_u = log_normal_diag(u_q, u_q_mean, u_q_logvar)
        KL_u = log_q_u - log_p_u

        log_p_z = log_normal_diag(z_q, z_p_mean, z_p_logvar)
        log_q_z = log_normal_diag(z_q, z_q_mean, z_q_logvar)
        KL_z = log_q_z - log_p_z

        # Total lower bound loss
        nelbo = -(RE_x + RE_y - KL_u - KL_z).mean()

        diagnostics = {
            "bpd": (nelbo.item()) / (np.prod(x.shape[1:]) * np.log(2.0)),
            "nelbo": nelbo.item(),
            "RE": -(RE_x + RE_y).mean().item(),
            "RE_x": -RE_x.mean().item(),
            "RE_y": -RE_y.mean().item(),
            "KL": (KL_z + KL_u).mean().item(),
            "KL_u": KL_u.mean().item(),
            "KL_z": KL_z.mean().item(),
        }
        return nelbo, diagnostics

    def forward(self, x, y, **kwargs):
        """Forward pass through the inference and the generative model."""
        # y ~ f(x) (deterministc)
        # y = self.compressed_transformation(x)

        # u ~ q(u| y)
        u_q_mean, u_q_logvar = self.q_u(y)
        u_q = self.reparameterize(u_q_mean, u_q_logvar)

        # z ~ q(z| x, y)
        z_q_mean, z_q_logvar = self.q_z(x)
        z_q = self.reparameterize(z_q_mean, z_q_logvar)

        # x ~ p(x| y, z)
        x_logits = self.p_x((y, z_q))

        # y ~ p(y| u)
        y_logits = self.p_y(u_q)

        # z ~ p(z| x)
        z_p_mean, z_p_logvar = self.p_z((y, u_q))

        return {
            "u_q_mean": u_q_mean,
            "u_q_logvar": u_q_logvar,
            "u_q": u_q,
            "z_q_mean": z_q_mean,
            "z_q_logvar": z_q_logvar,
            "z_q": z_q,
            "z_p_mean": z_p_mean,
            "z_p_logvar": z_p_logvar,
            "y": y,
            "y_logits": y_logits,
            "x_logits": x_logits,
        }

In [83]:
train_data, test_data = create_train_test_split(
    str(SITES_DIR) + "\\", seed=42, sites={"K34-AMAZ"}
)
train_loader = DataLoader(train_data)

In [84]:
len(train_data), len(test_data)

(969, 416)

In [85]:
model = srVAE((3, 256, 256))



In [86]:
criterion = ELBOLoss()
optimizer = torch.optim.Adamax(
    model.parameters(), lr=2e-3, betas=(0.9, 0.999), eps=1e-7
)
scheduler = LowerBoundedExponentialLR(optimizer, gamma=0.999999, lower_bound=0.0001)


def train(model, train_loader):
    model.train()

    acc_losses = {}
    for i, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        # forward pass
        x = x.to("cpu")
        x = x[:, :3, :, :]
        y = y[:, :3, :, :]
        output = model(y, x)
        loss, diagnostics = criterion(x, output, model)
        # back-prop
        loss.backward(retain_graph=True)
        optimizer.step()
        scheduler.step()
        # # gather statistics
        # acc_losses = Counter(acc_losses) + Counter(diagnostics)
        # log_interval(i + 1, len(train_loader), acc_losses)
    avg_losses = {k: acc_losses[k] / len(train_loader) for k in acc_losses}
    return avg_losses

In [87]:
for epoch in range(1, 2):
    train_losses = train(model, train_loader)

  m_init, v_init = torch.mean(t, 0), torch.var(t, 0)


RuntimeError: The size of tensor a (256) must match the size of tensor b (128) at non-singleton dimension 3

In [39]:
train_losses

NameError: name 'train_losses' is not defined