In [None]:
# | default_exp init

# Initialization
> Implementations of random (and constant) initialization strategies suitable for complex-valued layers.

In [None]:
# |hide
from nbdev.showdoc import *

In [None]:
# | export
import math
import torch
import numpy as np
from torch.nn import init

In [None]:
# | export
def cplx_kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    a = math.sqrt(1 + 2 * a * a)
    init.kaiming_normal_(tensor.real, a=a, mode=mode, nonlinearity=nonlinearity)
    init.kaiming_normal_(tensor.imag, a=a, mode=mode, nonlinearity=nonlinearity)

In [None]:
# | export
def cplx_kaiming_uniform_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill a complex tensor in place with Kaiming (He) uniform values.

    The real and imaginary parts are drawn independently from
    ``U(-bound, bound)`` with ``bound = sqrt(3) * gain / sqrt(2 * fan)``,
    where ``gain = init.calculate_gain(nonlinearity, a)``. Each component
    thus carries half of the variance a real-valued Kaiming initializer
    would assign, for every supported nonlinearity.

    Args:
        tensor: complex tensor with at least two dimensions; modified in place.
        a: negative slope of the rectifier (only used by ``"leaky_relu"``).
        mode: ``"fan_in"`` or ``"fan_out"``.
        nonlinearity: name understood by ``torch.nn.init.calculate_gain``.

    Returns:
        None.

    Raises:
        ValueError: for an unknown ``mode`` or ``nonlinearity``, or when the
            fan cannot be computed (fewer than two dimensions).
    """
    if tensor.numel() == 0:
        # Nothing to fill; also avoids dividing by a zero fan below.
        return

    fan = init._calculate_correct_fan(tensor, mode)
    # Divide the gain explicitly instead of inflating the leaky-ReLU slope:
    # `calculate_gain` ignores the slope for every other nonlinearity, so the
    # slope trick would silently skip the variance split for e.g. "relu".
    gain = init.calculate_gain(nonlinearity, a) / math.sqrt(2)
    std = gain / math.sqrt(fan)
    # A uniform law on [-bound, bound] has standard deviation bound / sqrt(3).
    bound = math.sqrt(3.0) * std

    init.uniform_(tensor.real, -bound, bound)
    init.uniform_(tensor.imag, -bound, bound)

In [None]:
# | export
def cplx_xavier_normal_(tensor, gain=1.0):
    """Apply Xavier (Glorot) normal initialization to a complex tensor in place.

    The real part is initialized first, then the imaginary part, each with
    ``init.xavier_normal_`` and the gain divided by ``sqrt(2)``.

    Args:
        tensor: complex tensor to initialize in place.
        gain: scaling factor, before the split between the two components.

    Returns:
        None.
    """
    component_gain = gain / math.sqrt(2)
    for part in ("real", "imag"):
        init.xavier_normal_(getattr(tensor, part), gain=component_gain)

In [None]:
# | export
def cplx_xavier_uniform_(tensor, gain=1.0):
    """Apply Xavier (Glorot) uniform initialization to a complex tensor in place.

    The real part is initialized first, then the imaginary part, each with
    ``init.xavier_uniform_`` and the gain divided by ``sqrt(2)``.

    Args:
        tensor: complex tensor to initialize in place.
        gain: scaling factor, before the split between the two components.

    Returns:
        None.
    """
    component_gain = gain / math.sqrt(2)
    for part in ("real", "imag"):
        init.xavier_uniform_(getattr(tensor, part), gain=component_gain)

In [None]:
# | export
def cplx_trabelsi_standard_(tensor, kind="glorot"):
    """Standard complex initialization proposed in Trabelsi et al. (2018).

    Draws a Rayleigh-distributed modulus and a uniform phase for every entry
    with NumPy's global random state, then writes the resulting complex
    numbers into `tensor` in place.

    Args:
        tensor: complex tensor with at least two dimensions.
        kind: "glorot"/"xavier" (scale from fan_in + fan_out) or
            "kaiming"/"he" (scale from fan_in); matched case-insensitively.

    Returns:
        None.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he")

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    fan = fan_in + fan_out if kind in ("glorot", "xavier") else fan_in
    sigma = 1 / math.sqrt(fan)

    # Rayleigh(\sigma / \sqrt2) x uniform[-\pi, +\pi] on p. 7
    modulus = np.random.rayleigh(sigma, size=tensor.shape)
    phase = np.random.uniform(-np.pi, +np.pi, size=tensor.shape)

    # eq. (8) on p. 6: polar form -> Cartesian components
    real_part = torch.from_numpy(np.cos(phase) * modulus)
    imag_part = torch.from_numpy(np.sin(phase) * modulus)
    with torch.no_grad():
        tensor.real.copy_(real_part)
        tensor.imag.copy_(imag_part)

In [None]:
# | export
def cplx_trabelsi_independent_(tensor, kind="glorot"):
    """Orthogonal complex initialization proposed in Trabelsi et al. (2018).

    Builds a semi-unitary matrix from the SVD of a random complex matrix
    (NumPy global random state), rescales it so that its standard deviation
    equals the Glorot or Kaiming scale, and writes it into `tensor` in place.

    Args:
        tensor: complex tensor with at least two dimensions. Tensors with
            more than two dimensions are treated as a matrix of shape
            (prod(shape[:2]), prod(shape[2:])).
        kind: "glorot"/"xavier" (scale from fan_in + fan_out) or
            "kaiming"/"he" (scale from fan_in); matched case-insensitively.

    Returns:
        None.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he")

    # Flatten anything that is not already a matrix into 2-D.
    if tensor.dim() == 2:
        mat_shape = tensor.shape
    else:
        leading, trailing = tensor.shape[:2], tensor.shape[2:]
        mat_shape = int(np.prod(leading)), int(np.prod(trailing))

    # Random complex seed matrix: real part is drawn first, then imaginary.
    seed = np.random.rand(*mat_shape) + 1j * np.random.rand(*mat_shape)

    # seed is n x m, so left is n x n and right_h is m x m.
    left, _, right_h = np.linalg.svd(
        seed, compute_uv=True, full_matrices=True, hermitian=False
    )
    rank = min(*mat_shape)
    # M = U V is semi-unitary: V^H U^H U V = I_k
    semi_unitary = np.dot(left[:, :rank], right_h[:, :rank].conjugate().T)

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    fan = fan_in + fan_out if kind in ("glorot", "xavier") else fan_in
    target_std = 1 / math.sqrt(fan)

    # Rescale so the matrix entries have standard deviation `target_std`.
    semi_unitary = semi_unitary / (semi_unitary.std() / target_std)
    values = semi_unitary.reshape(tensor.shape)
    with torch.no_grad():
        tensor.real.copy_(torch.from_numpy(values.real))
        tensor.imag.copy_(torch.from_numpy(values.imag))

In [None]:
# | export
def cplx_normal_independent_(tensor, a=0.0, b=1.0):
    """Fill the real and imaginary parts with independent normal samples.

    Both parts are filled in place with ``init.normal_``, real part first.

    Args:
        tensor: complex tensor to initialize in place.
        a: passed to ``init.normal_`` as the mean.
        b: passed to ``init.normal_`` as the standard deviation.

    Returns:
        None.
    """
    for part in ("real", "imag"):
        init.normal_(getattr(tensor, part), a, b)

In [None]:
# | export
def cplx_uniform_independent_(tensor, a=0.0, b=1.0):
    """Fill the real and imaginary parts with independent uniform samples.

    Both parts are filled in place with ``init.uniform_``, real part first.

    Args:
        tensor: complex tensor to initialize in place.
        a: passed to ``init.uniform_`` as the lower bound.
        b: passed to ``init.uniform_`` as the upper bound.

    Returns:
        None.
    """
    for part in ("real", "imag"):
        init.uniform_(getattr(tensor, part), a, b)

In [None]:
# | export
def ones_(tensor, imag_zero=False):
    """Set every entry of a complex tensor to a constant "one", in place.

    The fill runs under ``torch.no_grad()`` so that, like the other
    initializers in this module, it can be applied to parameters that
    require grad (autograd rejects in-place writes to such leaf tensors
    while grad mode is enabled).

    Args:
        tensor: complex tensor to fill in place.
        imag_zero: if True every entry becomes ``1 + 0j``; otherwise
            (default) every entry becomes ``1 + 1j``.

    Returns:
        None.
    """
    imag_value = 0 if imag_zero else 1
    with torch.no_grad():
        tensor.real.fill_(1)
        tensor.imag.fill_(imag_value)

In [None]:
# | export
def zeros_(tensor):
    """Set every entry of a complex tensor to ``0 + 0j``, in place.

    The fill runs under ``torch.no_grad()`` so that, like the other
    initializers in this module, it can be applied to parameters that
    require grad (autograd rejects in-place writes to such leaf tensors
    while grad mode is enabled).

    Args:
        tensor: complex tensor to fill in place.

    Returns:
        None.
    """
    with torch.no_grad():
        tensor.real.fill_(0)
        tensor.imag.fill_(0)

In [None]:
# Scratch tensor for the demo cells below: a 48 x 100 complex128 matrix of
# standard-normal draws, divided by sqrt(48).
weights = torch.randn(48, 100, dtype=torch.cdouble) / math.sqrt(48)

In [None]:
# Smoke test: overwrite `weights` in place using the default arguments.
cplx_kaiming_normal_(weights)

In [None]:
# Smoke test: overwrite `weights` in place using the default kind="glorot".
cplx_trabelsi_standard_(weights)

In [None]:
# Smoke test: set the real part to 1 and the imaginary part to 0.
ones_(weights, imag_zero=True)

In [None]:
# Smoke test: redraw both parts with the defaults a=0.0, b=1.0.
cplx_normal_independent_(weights)

In [None]:
# NOTE(review): this cell repeats the previous one verbatim; it only redraws
# the values again — consider removing it.
cplx_normal_independent_(weights)

In [None]:
# Display the tensor after the in-place initializers above have been applied.
weights

tensor([[ 0.4300+0.3978j,  0.4044+0.0886j, -0.0497+0.3637j,  ...,
         -1.1625-0.2458j, -0.3606+1.3778j,  1.3233+0.1381j],
        [-0.0352+0.4488j, -0.4317+0.1049j, -0.5743-0.0114j,  ...,
          1.3755-0.1862j, -0.1263+1.6037j, -0.4251+0.0149j],
        [ 0.1022+2.1415j, -0.3880-0.8916j,  0.9741+0.4967j,  ...,
         -0.4281+1.6826j, -0.3992+0.4414j,  1.5657+0.3260j],
        ...,
        [-1.2332+1.4589j, -0.0185-3.0959j, -0.4712+2.1996j,  ...,
         -0.4455+1.3777j,  0.4046-0.8535j,  1.3204+1.6157j],
        [-1.3413-0.8369j, -1.7481+0.1788j,  0.2018-0.2512j,  ...,
         -0.3099-1.4537j,  0.1969+3.0771j,  0.7079+1.1503j],
        [ 1.7616+1.3624j,  0.6755+0.8201j, -0.2582+0.6527j,  ...,
         -1.9779+0.6458j,  0.1358+0.4088j,  0.0482+0.3482j]],
       dtype=torch.complex128)