In [None]:
# | default_exp init

# Initialization
> Implementation of various random initialization strategies suitable for complex-valued layers.

In [None]:
# |hide
from nbdev.showdoc import *

In [None]:
# | export
import math
import torch
import numpy as np
from torch.nn import init

In [None]:
# | export
def cplx_kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill a complex tensor in place with Kaiming-normal real and imaginary parts.

    The real and imaginary components are drawn independently by
    `torch.nn.init.kaiming_normal_`; `mode` and `nonlinearity` are forwarded
    unchanged, while the slope `a` is replaced by `sqrt(1 + 2 * a**2)`.
    """
    # With torch's leaky-ReLU gain sqrt(2 / (1 + slope^2)) this substitution
    # yields sqrt(1 / (1 + a^2)), i.e. each component gets the usual gain
    # divided by sqrt(2).
    slope = math.sqrt(1 + 2 * a * a)
    for component in (tensor.real, tensor.imag):
        init.kaiming_normal_(component, a=slope, mode=mode, nonlinearity=nonlinearity)

In [None]:
# | export
def cplx_kaiming_uniform_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill a complex tensor in place with Kaiming-uniform real and imaginary parts.

    The real and imaginary components are drawn independently by
    `torch.nn.init.kaiming_uniform_`; `mode` and `nonlinearity` are forwarded
    unchanged, while the slope `a` is replaced by `sqrt(1 + 2 * a**2)`.
    """
    # With torch's leaky-ReLU gain sqrt(2 / (1 + slope^2)) this substitution
    # divides the per-component gain by sqrt(2).
    slope = math.sqrt(1 + 2 * a * a)
    for component in (tensor.real, tensor.imag):
        init.kaiming_uniform_(component, a=slope, mode=mode, nonlinearity=nonlinearity)

In [None]:
# | export
def cplx_xavier_normal_(tensor, gain=1.0):
    """Fill a complex tensor in place with Xavier-normal real and imaginary parts.

    Each component is initialized independently by
    `torch.nn.init.xavier_normal_` using `gain / sqrt(2)` as its gain.
    """
    component_gain = gain / math.sqrt(2)
    init.xavier_normal_(tensor.real, gain=component_gain)
    init.xavier_normal_(tensor.imag, gain=component_gain)

In [None]:
# | export
def cplx_xavier_uniform_(tensor, gain=1.0):
    """Fill a complex tensor in place with Xavier-uniform real and imaginary parts.

    Each component is initialized independently by
    `torch.nn.init.xavier_uniform_` using `gain / sqrt(2)` as its gain.
    """
    component_gain = gain / math.sqrt(2)
    init.xavier_uniform_(tensor.real, gain=component_gain)
    init.xavier_uniform_(tensor.imag, gain=component_gain)

In [None]:
# | export
def cplx_trabelsi_standard_(tensor, kind="glorot"):
    """Standard complex initialization proposed in Trabelsi et al. (2018).

    Overwrites `tensor` in place with values whose modulus is Rayleigh
    distributed and whose phase is uniform on [-pi, +pi]. `kind` is
    case-insensitive: "glorot"/"xavier" use a Rayleigh scale of
    1 / sqrt(fan_in + fan_out), "kaiming"/"he" use 1 / sqrt(fan_in).
    Sampling goes through NumPy's global random state.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he")

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    fan = fan_in + fan_out if kind in ("glorot", "xavier") else fan_in
    sigma = 1 / math.sqrt(fan)

    # Rayleigh(\sigma / \sqrt2) x uniform[-\pi, +\pi] on p. 7
    modulus = np.random.rayleigh(sigma, size=tensor.shape)
    phase = np.random.uniform(-np.pi, +np.pi, size=tensor.shape)

    # eq. (8) on p. 6: polar form -> cartesian components
    real_part = torch.from_numpy(np.cos(phase) * modulus)
    imag_part = torch.from_numpy(np.sin(phase) * modulus)
    with torch.no_grad():
        tensor.real.copy_(real_part)
        tensor.imag.copy_(imag_part)

In [None]:
# | export
def cplx_trabelsi_independent_(tensor, kind="glorot"):
    """Orthogonal complex initialization proposed in Trabelsi et al. (2018).

    Builds a semi-unitary matrix from the SVD of a random complex matrix,
    rescales it so its standard deviation equals the requested scale, and
    copies it into `tensor` in place. `kind` is case-insensitive:
    "glorot"/"xavier" use 1 / sqrt(fan_in + fan_out), "kaiming"/"he" use
    1 / sqrt(fan_in). Sampling goes through NumPy's global random state.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he")

    # Non-2D tensors are flattened to a matrix: the first two axes form the
    # rows, all remaining axes form the columns.
    if tensor.dim() == 2:
        n_rows, n_cols = tensor.shape
    else:
        n_rows = int(np.prod(tensor.shape[:2]))
        n_cols = int(np.prod(tensor.shape[2:]))

    # generate a semi-unitary (orthogonal) matrix from a random matrix
    # M = U V is semi-unitary: V^H U^H U V = I_k
    seed_real = np.random.rand(n_rows, n_cols)
    seed_imag = np.random.rand(n_rows, n_cols)
    Z = seed_real + 1j * seed_imag

    # Z is n x m, so u is n x n and vh is m x m
    u, _, vh = np.linalg.svd(Z, compute_uv=True, full_matrices=True, hermitian=False)
    rank = min(n_rows, n_cols)
    M = np.dot(u[:, :rank], vh[:, :rank].conjugate().T)

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    fan = fan_in + fan_out if kind in ("glorot", "xavier") else fan_in
    scale = 1 / math.sqrt(fan)

    # Rescale so that the standard deviation of the entries equals `scale`.
    M /= M.std() / scale
    M = M.reshape(tensor.shape)
    with torch.no_grad():
        tensor.real.copy_(torch.from_numpy(M.real))
        tensor.imag.copy_(torch.from_numpy(M.imag))

In [None]:
# | export
def cplx_normal_independent_(tensor, a=0.0, b=1.0):
    """Fill the real and imaginary parts of `tensor` in place with normal samples.

    `a` and `b` are passed positionally to `torch.nn.init.normal_`, where
    they act as the mean and the standard deviation respectively. The two
    components are drawn independently with the same parameters.
    """
    for component in (tensor.real, tensor.imag):
        init.normal_(component, a, b)

In [None]:
# | export
def cplx_uniform_independent_(tensor, a=0.0, b=1.0):
    """Fill the real and imaginary parts of `tensor` in place with uniform samples.

    `a` and `b` are passed positionally to `torch.nn.init.uniform_`, where
    they act as the lower and upper bounds. The two components are drawn
    independently with the same bounds.
    """
    for component in (tensor.real, tensor.imag):
        init.uniform_(component, a, b)

In [None]:
# | export
def ones_(tensor, imag_zero=False):
    """Set every element of a complex tensor in place to `1+1j` or `1+0j`.

    Parameters
    ----------
    tensor : complex torch tensor, modified in place.
    imag_zero : if True the imaginary part is set to 0, otherwise to 1.

    The writes happen under `torch.no_grad()`: assigning to `.real`/`.imag`
    with autograd enabled is an in-place operation on a view of the tensor
    and raises for leaf tensors that require grad (e.g. `nn.Parameter`).
    """
    with torch.no_grad():
        tensor.real.fill_(1)
        tensor.imag.fill_(0 if imag_zero else 1)

In [None]:
# | export
def zeros_(tensor):
    """Set the real and imaginary parts of a complex tensor to zero in place.

    The writes happen under `torch.no_grad()`: assigning to `.real`/`.imag`
    with autograd enabled is an in-place operation on a view of the tensor
    and raises for leaf tensors that require grad (e.g. `nn.Parameter`).
    """
    with torch.no_grad():
        tensor.real.fill_(0)
        tensor.imag.fill_(0)

In [None]:
# Scratch tensor for trying the initializers: 48 x 100, complex128,
# random values scaled by 1 / sqrt(48).
weights = torch.randn(48, 100, dtype=torch.cdouble) / math.sqrt(48)

In [None]:
# Overwrite `weights` in place using the default arguments
# (a=0.0, mode="fan_in", nonlinearity="leaky_relu").
cplx_kaiming_normal_(weights)

In [None]:
# Overwrite `weights` in place with the Rayleigh-modulus / uniform-phase
# scheme, using the default kind="glorot".
cplx_trabelsi_standard_(weights)

In [None]:
# Set every entry to 1 + 0j (real part 1, imaginary part 0).
ones_(weights, imag_zero=True)

In [None]:
# Redraw real and imaginary parts independently with the defaults a=0.0, b=1.0.
cplx_normal_independent_(weights)

In [None]:
# NOTE(review): same call as the previous cell; it draws a fresh sample and
# overwrites `weights` again. Looks redundant; confirm whether it is intended.
cplx_normal_independent_(weights)

In [None]:
# Display the tensor as left by the last initializer call.
weights

tensor([[-0.0470+0.7112j, -1.5469+0.1394j, -1.1867-0.1471j,  ...,
          0.5766+0.8785j,  0.6712+0.0261j,  0.0456-0.7064j],
        [-0.6531-1.1577j,  0.8416-0.2790j,  0.0602+0.5025j,  ...,
         -0.6614-2.2219j,  0.7899-1.3343j,  0.3995+1.1260j],
        [-1.7843-0.4350j,  1.7556-1.0172j, -0.3011-0.9143j,  ...,
         -0.7598+1.0774j, -0.6681+0.5441j,  1.6245-2.0797j],
        ...,
        [-1.0470+1.3312j, -0.6345+1.6134j, -0.9516+0.6245j,  ...,
         -0.0545-0.2412j,  0.3157+1.3666j,  0.4336-1.8152j],
        [-1.6996+1.1974j,  0.6489-0.8106j, -0.3492-1.3665j,  ...,
          1.5693+0.1399j, -0.3388+0.8230j,  0.6139-0.0835j],
        [-1.1932+0.1981j, -0.2530-0.4910j,  0.4789-0.0780j,  ...,
          1.5589+0.9441j,  1.1046-0.4397j, -2.2771-1.6190j]],
       dtype=torch.complex128)