In [66]:
import torch
from torch import nn
import numpy as np

from torch.nn.utils.spectral_norm import SpectralNorm

from helper import assert_shape

In [59]:
class SpectralConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        SpectralNorm.apply(self, name = "weight", n_power_iterations=1, dim = 0, eps = 1e-12)

# Dummy input: [batch, channels, length]
x = torch.randn(5, 64, 100)

conv1d_modified = SpectralConv1d(in_channels=64, out_channels=128, kernel_size=5, stride=1)
modified_out = conv1d_modified(x)
modified_out.shape

torch.Size([5, 128, 96])

In [122]:
class LocalBatchNorm(nn.Module):
    # When using large batch sizes, the variance across the batch can be very high, especially in early training.
    # It may cause the normalization to overreact, resulting in instability in the discriminator’s learning.
    # So we use virtual_bs for smaller batch size to normalize it through.
    def __init__(self, num_features: int, affine: bool, virtual_bs: int = 8, eps: float = 1e-8):
        super().__init__()

        self.num_features = num_features
        self.affine = affine             # learn weight and biases?
        self.virtual_bs = virtual_bs
        self.eps = eps

        if self.affine:
            self.weights = nn.Parameter(torch.ones(num_features))
            self.bias    = nn.Parameter(torch.zeros(num_features))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape

        G = np.ceil(x.shape[0] / self.virtual_bs).astype(int) # G = B / 8 
        x = x.view(G, -1, x.shape[1], x.shape[2])             # x: [G, -1, N, C]
        # Normalizing per group, per channel
        mean = x.mean([1, 3], keepdim=True)        # mean: [G, 1, N, 1] 
        var = x.var([1, 3], keepdim=True)          # var : [G, 1, N, 1] 

        x = (x - mean) / (np.sqrt(var) + self.eps)      # x: [G, -1, N, C]

        if self.affine:
            x = x * self.weights[None, :, None]      # weight: [1, N, 1]
                                                    # x     : [G, -1, N, C]

            x = x + self.bias[None, :, None]        # bias  : [1, N, 1]
                                                    # x     : [G, -1, N, C]
        return x.view(shape)

x = torch.randn(18, 64, 100)

lbn = LocalBatchNorm(num_features = 64, affine=True)
out = lbn(x)
out.shape

torch.Size([18, 64, 100])

In [118]:
x = torch.randn(18, 64, 100)

shape = x.shape

num_features=64
virtual_bs=8
eps = 1e-8
affine = True

weight = nn.Parameter(torch.ones(num_features))
bias   = nn.Parameter(torch.zeros(num_features))

G = np.ceil(x.shape[0] / virtual_bs).astype(int) # G = 20 / 8 = 3

x = x.view(G, -1, x.shape[1], x.shape[2])  # x: [G, -1, N, C]

# Normalizing per group, per channel
mean = x.mean([1, 3], keepdim=True)        # mean: [G, 1, N, 1] 
var = x.var([1, 3], keepdim=True)          # var : [G, 1, N, 1] 

x = (x - mean) / (np.sqrt(var) + eps)      # x: [G, -1, N, C]
if affine: 
    x = x * weight[None, :, None]          # weight: [1, N, 1]
                                           # x     : [G, -1, N, C]

    x = x + bias[None, :, None]            # bias  : [1, N, 1]
                                           # x     : [G, -1, N, C]
x = x.view(shape)
x.shape

torch.Size([18, 64, 100])

In [106]:
bias[None, : None].shape

torch.Size([1, 64])

In [84]:
x.shape[0] / 18

1.0

In [70]:
x.view(G, -1, x.shape[1], x.shape[2])

RuntimeError: shape '[3, -1, 64, 100]' is invalid for input of size 128000