In [6]:
import numpy as np
import torch
import torch.nn as nn
from functools import partial

In [None]:
class SubspaceModel(nn.Module):
    """Learnable linear subspace that maps latent coordinates to a flat feature vector.

    Computes ``phi = (L * z) @ U + mu`` where

    * ``U``  -- ``(num_basis, dim)`` basis matrix, orthogonally initialised,
    * ``L``  -- ``(num_basis,)`` per-basis scale, initialised to ``0 .. num_basis - 1``,
    * ``mu`` -- ``(dim,)`` offset added to every output row, initialised to zero.

    Args:
        dim: Size of the output feature vector.
        num_basis: Number of basis vectors, i.e. the size of the last axis of ``z``.
    """

    def __init__(self, dim: int, num_basis: int) -> None:
        super().__init__()
        self.U = nn.Parameter(torch.empty((num_basis, dim)))
        nn.init.orthogonal_(self.U)
        self.L = nn.Parameter(torch.arange(num_basis, dtype=torch.float32))
        # The offset lives in the output space, so it must have ``dim`` entries
        # (not ``num_basis``) to broadcast over the (batch, dim) result.
        self.mu = nn.Parameter(torch.zeros(dim))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Project latent coordinates onto the subspace.

        Args:
            z: Tensor of shape ``(batch, num_basis)``.

        Returns:
            Tensor of shape ``(batch, dim)``.
        """
        # (batch, num_basis) * (num_basis,) -> (batch, num_basis), then
        # (batch, num_basis) @ (num_basis, dim) -> (batch, dim).
        return (self.L * z).mm(self.U) + self.mu

class ConvLayer(nn.Sequential):
    """A 2-D (optionally transposed) convolution with an optional LeakyReLU.

    The layer is an ``nn.Sequential`` so the sub-modules are registered and
    applied in order; a plain ``nn.Module`` base cannot accept the layers as
    positional constructor arguments and would have no ``forward``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding passed to the convolution.
        padding_mode: Padding mode for ``nn.Conv2d``. Ignored (forced to
            ``"zeros"``) when ``transposed`` is true, because
            ``nn.ConvTranspose2d`` only supports zero padding.
        transposed: Use ``nn.ConvTranspose2d`` with
            ``output_padding = stride - 1`` instead of ``nn.Conv2d``.
        activation: Whether to add an ``nn.LeakyReLU``.
        pre_activate: If true, the activation is placed before the
            convolution instead of after it.

    Note:
        ``groups``, ``bias`` and ``normalization`` options are not wired up yet.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        padding_mode: str = "zeros",
        transposed: bool = False,
        activation: bool = True,
        pre_activate: bool = False
        ) -> None:
        if transposed:
            # output_padding = stride - 1 removes the output-size ambiguity of
            # strided transposed convolutions.
            conv = partial(nn.ConvTranspose2d, output_padding=stride - 1)
            padding_mode = "zeros"
        else:
            conv = nn.Conv2d

        layers = [conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
        )]
        if activation:
            if pre_activate:
                layers.insert(0, nn.LeakyReLU())
            else:
                layers.append(nn.LeakyReLU())
        super().__init__(*layers)

class EigenBlock(nn.Module):
    """Generator block that injects a learned subspace into a feature map.

    The block owns one ``SubspaceModel`` producing ``width * height *
    in_channels`` values per sample, two activation-free transposed
    convolutions that carry the subspace output into the feature path, and two
    pre-activated transposed convolutions that process the features.
    ``subspace_conv2`` and ``feature_conv1`` use stride 2; the other two use
    stride 1.

    Args:
        width: Width used to size the subspace output.
        height: Height used to size the subspace output.
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the produced feature map.
        num_basis: Number of basis vectors in the subspace.
    """

    def __init__(self,
    width: int,
    height: int,
    in_channels: int,
    out_channels: int,
    num_basis: int) -> None:
        super().__init__()
        # Every convolution in this block is transposed.
        up_conv = partial(ConvLayer, transposed=True)
        flat_dim = width * height * in_channels

        self.subspacelayer = SubspaceModel(dim=flat_dim, num_basis=num_basis)

        # Subspace branch: linear projections only (no activation).
        self.subspace_conv1 = up_conv(
            in_channels, in_channels,
            kernel_size=1, stride=1, padding=0, activation=False,
        )
        self.subspace_conv2 = up_conv(
            in_channels, out_channels,
            kernel_size=3, stride=2, padding=1, activation=False,
        )

        # Feature branch: LeakyReLU applied before each convolution.
        self.feature_conv1 = up_conv(
            in_channels, out_channels,
            kernel_size=3, stride=2, pre_activate=True,
        )
        self.feature_conv2 = up_conv(
            out_channels, out_channels,
            kernel_size=3, stride=1, pre_activate=True,
        )

    def forward(self, z, h):
        """Fuse the subspace projection of ``z`` into ``h`` and transform it.

        Args:
            z: Latent coordinates fed to the subspace layer.
            h: Incoming feature map; the subspace output is reshaped to
                ``h.shape``.  Presumably (batch, in_channels, height, width)
                -- TODO confirm against the caller.

        Returns:
            The feature map produced by the second feature convolution.
        """
        basis_map = self.subspacelayer(z).view(h.shape)

        # First injection at the input resolution, then the stride-2 conv.
        fused_in = self.subspace_conv1(basis_map) + h
        stage_one = self.feature_conv1(fused_in)

        # Second injection uses the stride-2 projection of the same subspace map.
        fused_out = self.subspace_conv2(basis_map) + stage_one
        return self.feature_conv2(fused_out)

In [None]:
class Generator(nn.Module):
    """Generator network (skeleton).

    NOTE(review): only a constructor stub is visible in this cell; no
    sub-modules or ``forward`` are defined in the visible portion.
    """
    def __init__(self,
        ) -> None:
        # Initialise nn.Module bookkeeping; takes no configuration arguments yet.
        super().__init__()
