In [1]:
import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of sequences.

    The table built in ``__init__`` is::

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    It is stored as a (non-trainable) buffer of shape
    ``(1, max_seq_length, d_model)``.

    Args:
        d_model: Size of the feature (last) axis of the input. May be odd.
        max_seq_length: Number of positions covered by the table.
    """

    def __init__(self, d_model, max_seq_length=20):
        super().__init__()

        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for every even column index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_seq_length, ceil(d_model / 2))

        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading singleton axis so the table broadcasts over the batch axis.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_seq_length``. A longer sequence fails to broadcast.
        """
        return x + self.pe[:, : x.size(1), :]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalAffineCoupling(nn.Module):
    """Conditional affine coupling layer with soft-clamped log-scales.

    The features along dim 1 are split into an "active" part ``x1`` and a
    "passive" part ``x2``. The passive part is passed through unchanged.
    Concatenated with ``cond``, it parameterises an affine map of the active
    part::

        y1 = x1 * exp(clamp(s(x2, cond))) + t(x2, cond)
        y2 = x2

    This can be inverted exactly in ``reverse``. The log-determinant of the
    Jacobian w.r.t. ``x`` is the sum over features of the clamped log-scale.

    For odd ``dim`` the active part gets the extra feature
    (``dim - dim // 2`` features) and the passive part gets ``dim // 2``,
    the same split ``torch.chunk(2)`` produces.

    Args:
        dim: Number of features along dim 1 of ``x``.
        cond_dim: Number of conditioning features along dim 1 of ``cond``.
        hidden: Hidden width of the scale and translation MLPs.
        clamping_alpha: Soft bound on the log-scale magnitude (see ``clamp``).
    """

    # Reference setting from the original INN paper: dim=96, hidden=1024.
    def __init__(self, dim, cond_dim, hidden, clamping_alpha=2):
        super().__init__()

        self.dim = dim
        self.cond_dim = cond_dim
        self.clamping_alpha = clamping_alpha

        # Sizes of the transformed (active) and conditioning (passive) parts.
        # For even dim both equal dim // 2, so parameter shapes are unchanged.
        self.active_dim = dim - dim // 2
        self.passive_dim = dim // 2

        input_dim = self.passive_dim + cond_dim

        # s(x2, cond): raw log-scale, soft-clamped before use.
        self.net_s = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.active_dim))
        # t(x2, cond): translation.
        self.net_t = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.active_dim))

    def _split(self, z):
        """Split ``z`` along dim 1 into its (active, passive) parts."""
        # Explicit sizes also raise if z does not have exactly self.dim features.
        return torch.split(z, [self.active_dim, self.passive_dim], dim=1)

    def _scale_shift(self, passive, cond):
        """Return the clamped log-scale and translation for the active part."""
        inp = torch.cat([passive, cond], dim=1)
        s_cl = self.clamp(self.net_s(inp))
        t = self.net_t(inp)
        return s_cl, t

    def forward(self, x, cond):
        """Apply the coupling transform.

        Args:
            x: Tensor with ``dim`` features along dim 1 (presumably
                ``(batch, dim)``).
            cond: Conditioning tensor with ``cond_dim`` features along dim 1,
                concatenable with ``x``'s passive part on dim 1.

        Returns:
            ``(y, (s_cl, t))``, where ``y`` has the same shape as ``x`` and
            ``s_cl`` / ``t`` are the clamped log-scale and translation applied
            to the active part.
        """
        x1, x2 = self._split(x)
        s_cl, t = self._scale_shift(x2, cond)
        y1 = x1 * torch.exp(s_cl) + t
        y = torch.cat([y1, x2], dim=1)
        return y, (s_cl, t)

    def reverse(self, y, cond):
        """Invert ``forward``: recover ``x`` from ``y`` and the same ``cond``."""
        y1, y2 = self._split(y)
        # The passive part is unchanged, so s and t can be recomputed exactly.
        s_cl, t = self._scale_shift(y2, cond)
        x1 = (y1 - t) * torch.exp(-s_cl)
        return torch.cat([x1, y2], dim=1)

    def clamp(self, r):
        """Soft-clamp ``r`` smoothly into the open interval (-alpha, alpha)."""
        return (2 * self.clamping_alpha / torch.pi) * torch.atan(r / self.clamping_alpha)


In [None]:
class LocalINN(nn.Module):
    """Stack of conditional affine coupling layers with fixed feature permutations.

    Before each coupling layer, the features (dim 1) are shuffled by a fixed
    random permutation drawn at construction time. ``reverse`` undoes the
    layers and permutations in the opposite order.

    The permutations are registered buffers (``perms`` / ``inv_perms``, each
    ``(n_layers, dim)``) rather than plain Python lists. As a result they are
    saved in and restored from ``state_dict`` and follow ``.to(device)``.
    Without this, a reloaded model draws new permutations that do not match
    its trained coupling weights.

    NOTE: checkpoints saved before these buffers existed lack the
    ``perms`` / ``inv_perms`` keys.

    Args:
        dim: Number of features along dim 1 of the input.
        cond_dim: Number of conditioning features passed to every layer.
        n_layers: Number of coupling layers.
        hidden: Hidden width of each coupling layer's MLPs.
    """

    # Reference setting: hidden=1024, dim=96.
    def __init__(self, dim, cond_dim, n_layers, hidden):
        super().__init__()

        self.dim = dim
        # Layers are built before drawing permutations (same order as before),
        # so a seeded construction reproduces the same weights and permutations.
        self.layers = nn.ModuleList([
            ConditionalAffineCoupling(dim, cond_dim, hidden)
            for _ in range(n_layers)
        ])

        # One permutation per layer, one row each. Built by filling rows so
        # n_layers == 0 yields an empty (0, dim) table instead of failing.
        perms = torch.empty((n_layers, dim), dtype=torch.long)
        for i in range(n_layers):
            perms[i] = torch.randperm(dim)
        self.register_buffer("perms", perms)
        self.register_buffer("inv_perms", torch.argsort(perms, dim=1))

    def forward(self, x, cond):
        """Map ``x`` forward through every (permute, couple) stage.

        The per-layer ``(s, t)`` values returned by the coupling layers are
        discarded. Only the transformed tensor is returned.
        """
        out = x
        for layer, perm in zip(self.layers, self.perms):
            out = out[:, perm]
            out, _ = layer(out, cond)
        return out

    def reverse(self, y, cond):
        """Invert ``forward``: undo each coupling, then its permutation, last layer first."""
        out = y
        for idx in reversed(range(len(self.layers))):
            out = self.layers[idx].reverse(out, cond)
            out = out[:, self.inv_perms[idx]]
        return out
