In [None]:
# | default_exp init

# Initialization
> Implementation of various random initialization strategies suitable for complex-valued layers.

In [None]:
# |hide
from nbdev.showdoc import *

In [None]:
# | export
import math
import torch
import numpy as np
from torch.nn import init

In [None]:
# | export
def get_fans(cplxtensor):
    """Return ``(fan_in, fan_out)`` of a (complex) weight tensor.

    Mirrors ``torch.nn.init._calculate_fan_in_and_fan_out``: dimension 0 holds
    the output feature maps, dimension 1 the input feature maps, and any
    remaining dimensions form the receptive field. So a 2-d
    ``(out_features, in_features)`` weight has ``fan_in == in_features``.

    Raises
    ------
    ValueError
        If the tensor has fewer than 2 dimensions.
    """
    if cplxtensor.dim() < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with "
            "fewer than 2 dimensions."
        )

    n_fmaps_output, n_fmaps_input, *rest = cplxtensor.shape
    # product of the kernel dimensions; 1 for a plain 2-d (linear) weight
    receptive_field_size = int(np.prod((1, *rest)))
    fan_in = n_fmaps_input * receptive_field_size
    fan_out = n_fmaps_output * receptive_field_size
    return fan_in, fan_out

In [None]:
# | export
def cplx_kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Kaiming (He) normal initialization of a complex tensor, in place.

    The real and imaginary parts are drawn independently from ``N(0, std**2)``
    with ``std = gain / sqrt(2 * fan)``. Here ``gain`` is
    ``init.calculate_gain(nonlinearity, a)`` and ``fan`` is chosen by ``mode``.
    Splitting the variance evenly between the two parts makes ``E|w|^2``
    equal to the variance of the real-valued Kaiming scheme, for every
    supported ``nonlinearity`` (not only ``"leaky_relu"``).

    Parameters
    ----------
    tensor : torch.Tensor
        Complex tensor with at least 2 dimensions, overwritten in place.
    a : float
        Negative slope passed to ``calculate_gain`` (used by ``"leaky_relu"``).
    mode : {"fan_in", "fan_out"}
        Which fan to normalize by.
    nonlinearity : str
        Any name accepted by ``torch.nn.init.calculate_gain``.

    Raises
    ------
    ValueError
        For an invalid ``mode``/``nonlinearity`` or a tensor with fewer than 2 dims.
    """
    if tensor.numel() == 0:
        return  # like torch.nn.init: nothing to fill, and a zero fan would divide by 0
    fan = init._calculate_correct_fan(tensor, mode)
    # the sqrt(2) halves each part's variance so the complex total matches He et al.
    std = init.calculate_gain(nonlinearity, a) / math.sqrt(2 * fan)
    init.normal_(tensor.real, 0.0, std)
    init.normal_(tensor.imag, 0.0, std)

In [None]:
# | export
def cplx_xavier_normal_(tensor, gain=1.0):
    """Glorot/Xavier normal initialization of a complex tensor, in place.

    The real part and then the imaginary part are filled by
    ``torch.nn.init.xavier_normal_`` with the gain divided by ``sqrt(2)``.
    The two halves therefore add up to the variance of the real-valued scheme.
    """
    per_part_gain = gain / math.sqrt(2)
    for component in ("real", "imag"):
        init.xavier_normal_(getattr(tensor, component), gain=per_part_gain)

In [None]:
# | export
def cplx_kaiming_uniform_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Kaiming (He) uniform initialization of a complex tensor, in place.

    The real and imaginary parts are drawn independently from
    ``U(-bound, bound)`` with ``bound = sqrt(3) * gain / sqrt(2 * fan)``.
    Here ``gain`` is ``init.calculate_gain(nonlinearity, a)`` and ``fan`` is
    chosen by ``mode``. Each part then has variance ``gain**2 / (2 * fan)``, so
    ``E|w|^2`` equals the real-valued Kaiming variance for every supported
    ``nonlinearity`` (not only ``"leaky_relu"``).

    Parameters
    ----------
    tensor : torch.Tensor
        Complex tensor with at least 2 dimensions, overwritten in place.
    a : float
        Negative slope passed to ``calculate_gain`` (used by ``"leaky_relu"``).
    mode : {"fan_in", "fan_out"}
        Which fan to normalize by.
    nonlinearity : str
        Any name accepted by ``torch.nn.init.calculate_gain``.

    Raises
    ------
    ValueError
        For an invalid ``mode``/``nonlinearity`` or a tensor with fewer than 2 dims.
    """
    if tensor.numel() == 0:
        return  # like torch.nn.init: nothing to fill, and a zero fan would divide by 0
    fan = init._calculate_correct_fan(tensor, mode)
    std = init.calculate_gain(nonlinearity, a) / math.sqrt(2 * fan)
    bound = math.sqrt(3.0) * std  # U(-b, b) has variance b**2 / 3 == std**2
    init.uniform_(tensor.real, -bound, bound)
    init.uniform_(tensor.imag, -bound, bound)

In [None]:
# | export
def cplx_xavier_uniform_(tensor, gain=1.0):
    """Glorot/Xavier uniform initialization of a complex tensor, in place.

    The real part and then the imaginary part are filled by
    ``torch.nn.init.xavier_uniform_`` with the gain divided by ``sqrt(2)``.
    Together they reach the variance of the real-valued scheme.
    """
    scaled_gain = gain / math.sqrt(2)
    for attr in ("real", "imag"):
        init.xavier_uniform_(getattr(tensor, attr), gain=scaled_gain)

In [None]:
# | export
def cplx_trabelsi_standard_(tensor, kind="glorot"):
    """Standard complex initialization proposed in Trabelsi et al. (2018), in place.

    Each entry is written in polar form ``rho * exp(1j * theta)``. The modulus
    ``rho`` is Rayleigh-distributed with scale ``sigma`` and the phase
    ``theta`` is uniform on ``[-pi, pi]``, giving ``E|w|^2 = 2 * sigma**2``.
    ``sigma`` is ``1 / sqrt(fan_in + fan_out)`` for ``kind`` in
    {"glorot", "xavier"} and ``1 / sqrt(fan_in)`` for {"kaiming", "he"}
    (case-insensitive).

    Samples come from NumPy's global RNG; seed it with ``np.random.seed``
    for reproducible results.
    """
    criterion = kind.lower()
    assert criterion in ("glorot", "xavier", "kaiming", "he")

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    denominator = fan_in + fan_out if criterion in ("glorot", "xavier") else fan_in
    sigma = 1 / math.sqrt(denominator)

    # modulus ~ Rayleigh(sigma), phase ~ U[-pi, pi] (p. 7); polar form as in eq. (8)
    size = tensor.shape
    modulus = np.random.rayleigh(sigma, size=size)
    phase = np.random.uniform(-np.pi, +np.pi, size=size)

    with torch.no_grad():
        tensor.real.copy_(torch.from_numpy(np.cos(phase) * modulus))
        tensor.imag.copy_(torch.from_numpy(np.sin(phase) * modulus))

In [None]:
# | export
def cplx_trabelsi_independent_(tensor, kind="glorot"):
    """Orthogonal complex initialization proposed in Trabelsi et al. (2018), in place.

    The tensor is viewed as a matrix: the tensor itself when 2-d, otherwise
    one of shape ``(dim0 * dim1, prod(remaining dims))``. It is filled with a
    random semi-unitary matrix (orthonormal rows or columns), rescaled so that
    ``Var(W) = 2 * scale**2``. ``scale`` is ``1 / sqrt(fan_in + fan_out)`` for
    ``kind`` in {"glorot", "xavier"} and ``1 / sqrt(fan_in)`` for
    {"kaiming", "he"} (case-insensitive). This is the same variance that
    ``cplx_trabelsi_standard_`` produces for the same ``kind``.

    Random numbers come from NumPy's global RNG (``np.random.seed`` controls
    reproducibility). Empty tensors are left untouched.

    Raises
    ------
    AssertionError
        If ``kind`` is not one of the supported criteria.
    ValueError
        If the tensor has fewer than 2 dimensions.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he")

    # validate the rank (ValueError below 2-d) before consuming random numbers / an SVD
    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    if tensor.numel() == 0:
        return  # nothing to fill; also avoids 1 / sqrt(0) for a zero fan
    if kind == "glorot" or kind == "xavier":
        scale = 1 / math.sqrt(fan_in + fan_out)
    else:
        scale = 1 / math.sqrt(fan_in)

    if tensor.dim() == 2:
        shape = tuple(tensor.shape)
    else:
        shape = int(np.prod(tensor.shape[:2])), int(np.prod(tensor.shape[2:]))

    # semi-unitary matrix from a random one: Z = U S V^H (numpy returns vh = V^H).
    # M = U[:, :k] @ V[:k, :] has orthonormal rows if n <= m, orthonormal columns otherwise.
    Z = np.random.rand(*shape) + 1j * np.random.rand(*shape)
    u, _, vh = np.linalg.svd(Z, compute_uv=True, full_matrices=True, hermitian=False)
    k = min(shape)
    M = np.dot(u[:, :k], vh[:, :k].conjugate().T)

    # numpy's std of a complex array is sqrt(Var(Re) + Var(Im)); rescaling it to
    # sqrt(2) * scale yields Var(W) = 2 * scale**2 as required by the paper's criteria
    std = M.std()
    if std == 0:
        # a single-entry matrix has no spread; normalize by its modulus instead of dividing by 0
        std = np.sqrt(np.mean(np.abs(M) ** 2))
    M *= math.sqrt(2) * scale / std

    M = M.reshape(tensor.shape)
    with torch.no_grad():
        tensor.real.copy_(torch.from_numpy(M.real))
        tensor.imag.copy_(torch.from_numpy(M.imag))

In [None]:
# | export
def cplx_uniform_independent_(tensor, a=0.0, b=1.0):
    """Fill the real part, then the imaginary part, of a complex tensor with U(a, b) draws.

    The two parts are sampled independently via ``torch.nn.init.uniform_``.
    """
    for component in ("real", "imag"):
        init.uniform_(getattr(tensor, component), a, b)

In [None]:
# | export
def ones_(tensor, imag_zero=False):
    """Fill a complex tensor in place with ``1 + 1j`` (or ``1 + 0j``).

    Parameters
    ----------
    tensor : torch.Tensor
        Complex tensor to overwrite; may be a leaf that requires grad
        (e.g. an ``nn.Parameter``).
    imag_zero : bool
        If True the imaginary part is set to 0 instead of 1.
    """
    # no_grad mirrors torch.nn.init: writing through the .real/.imag views of a
    # leaf that requires grad is an error in grad mode
    with torch.no_grad():
        tensor.real.fill_(1)
        tensor.imag.fill_(0 if imag_zero else 1)

In [None]:
# | export
def zeros_(tensor):
    """Fill a tensor in place with zeros (``0 + 0j`` for complex tensors).

    Works on leaves that require grad (e.g. an ``nn.Parameter``) because the
    write happens under ``torch.no_grad()``, as in ``torch.nn.init.zeros_``.
    """
    with torch.no_grad():
        tensor.zero_()

In [None]:
weights = torch.randn(48, 100, dtype=torch.cdouble) / math.sqrt(48)

In [None]:
cplx_kaiming_normal_(weights)

In [None]:
cplx_trabelsi_standard_(weights)

In [None]:
ones_(weights, imag_zero=True)

In [None]:
weights

tensor([[-0.0639+0.0598j,  0.0096-0.0157j, -0.0034+0.0258j,  ...,
          0.0597+0.1678j, -0.1941-0.0643j, -0.0859+0.0302j],
        [ 0.1408+0.0399j, -0.0184+0.0744j, -0.0116-0.0638j,  ...,
          0.0015-0.0343j,  0.0500+0.0194j, -0.0193+0.1579j],
        [ 0.1240+0.0365j,  0.0625+0.1287j, -0.0760-0.0389j,  ...,
         -0.0329-0.0883j, -0.0886+0.0832j, -0.0763+0.0033j],
        ...,
        [-0.0690-0.0864j,  0.0969-0.0260j,  0.0542+0.0594j,  ...,
         -0.0266-0.0080j, -0.0857-0.0438j, -0.0829+0.0879j],
        [ 0.0309-0.0178j, -0.0471-0.0125j, -0.0295-0.0489j,  ...,
          0.1484-0.1507j,  0.0723-0.1205j,  0.0248+0.0568j],
        [-0.0315+0.0104j, -0.1373+0.0039j, -0.0147+0.1844j,  ...,
         -0.0114-0.0055j,  0.0397+0.0040j, -0.0397+0.0574j]],
       dtype=torch.complex128)