In [None]:
# | default_exp init

# Initialization
> Implementation of various random initialization strategies suitable for complex-valued layers.

In [None]:
# |hide
from nbdev.showdoc import *

In [None]:
# | export
import math
import torch
import numpy as np
from torch.nn import init

In [None]:
# | export
def cplx_kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Complex Kaiming (He) normal initialization, in place.

    Real and imaginary parts are drawn independently from ``N(0, std**2)`` with
    ``std = gain / sqrt(2 * fan)``, where ``gain = init.calculate_gain(nonlinearity, a)``
    and ``fan`` is picked by ``mode``. Each component thus carries half of the
    real-valued Kaiming variance, so ``E|w|^2 = gain**2 / fan`` matches the
    real-valued scheme for every supported nonlinearity.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        a: negative slope of the rectifier (only used by ``"leaky_relu"``).
        mode: ``"fan_in"`` or ``"fan_out"``.
        nonlinearity: any name accepted by ``torch.nn.init.calculate_gain``.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.

    Raises:
        ValueError: for an invalid ``mode``/``nonlinearity`` or a tensor with
            fewer than 2 dimensions.
    """
    if tensor.numel() == 0:
        # mirror torch.nn.init: initializing an empty tensor is a no-op
        return tensor

    fan = init._calculate_correct_fan(tensor, mode)
    # Divide the real-valued gain by sqrt(2) so the variance is split evenly
    # between the two components. Remapping `a` instead only affects the
    # leaky_relu gain and silently doubles the variance for other nonlinearities.
    std = init.calculate_gain(nonlinearity, a) / math.sqrt(2 * fan)
    with torch.no_grad():
        tensor.real.normal_(0.0, std)
        tensor.imag.normal_(0.0, std)
    return tensor

In [None]:
# | export
def cplx_kaiming_uniform_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Complex Kaiming (He) uniform initialization, in place.

    Real and imaginary parts are drawn independently from ``U(-bound, bound)``
    with ``bound = sqrt(3) * gain / sqrt(2 * fan)``, where
    ``gain = init.calculate_gain(nonlinearity, a)`` and ``fan`` is picked by
    ``mode``. Each component carries half of the real-valued Kaiming variance,
    so ``E|w|^2 = gain**2 / fan`` for every supported nonlinearity.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        a: negative slope of the rectifier (only used by ``"leaky_relu"``).
        mode: ``"fan_in"`` or ``"fan_out"``.
        nonlinearity: any name accepted by ``torch.nn.init.calculate_gain``.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.

    Raises:
        ValueError: for an invalid ``mode``/``nonlinearity`` or a tensor with
            fewer than 2 dimensions.
    """
    if tensor.numel() == 0:
        # mirror torch.nn.init: initializing an empty tensor is a no-op
        return tensor

    fan = init._calculate_correct_fan(tensor, mode)
    # Split the real-valued variance evenly between the two components; this is
    # correct for every nonlinearity, not only leaky_relu.
    std = init.calculate_gain(nonlinearity, a) / math.sqrt(2 * fan)
    bound = math.sqrt(3.0) * std  # U(-b, b) has variance b**2 / 3
    with torch.no_grad():
        tensor.real.uniform_(-bound, bound)
        tensor.imag.uniform_(-bound, bound)
    return tensor

In [None]:
# | export
def cplx_xavier_normal_(tensor, gain=1.0):
    """Complex Xavier (Glorot) normal initialization, in place.

    Real and imaginary parts are filled independently by
    ``torch.nn.init.xavier_normal_`` with the gain scaled by ``1/sqrt(2)``, so
    each component has variance ``gain**2 / (fan_in + fan_out)`` and the complex
    variance ``2 * gain**2 / (fan_in + fan_out)`` equals the real-valued Xavier one.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        gain: optional scaling factor, as in ``torch.nn.init.xavier_normal_``.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    component_gain = gain / math.sqrt(2)
    init.xavier_normal_(tensor.real, gain=component_gain)
    init.xavier_normal_(tensor.imag, gain=component_gain)
    return tensor

In [None]:
# | export
def cplx_xavier_uniform_(tensor, gain=1.0):
    """Complex Xavier (Glorot) uniform initialization, in place.

    Real and imaginary parts are filled independently by
    ``torch.nn.init.xavier_uniform_`` with the gain scaled by ``1/sqrt(2)``, so
    each component has variance ``gain**2 / (fan_in + fan_out)`` and the complex
    variance ``2 * gain**2 / (fan_in + fan_out)`` equals the real-valued Xavier one.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        gain: optional scaling factor, as in ``torch.nn.init.xavier_uniform_``.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    component_gain = gain / math.sqrt(2)
    init.xavier_uniform_(tensor.real, gain=component_gain)
    init.xavier_uniform_(tensor.imag, gain=component_gain)
    return tensor

In [None]:
# | export
def cplx_trabelsi_standard_(tensor, kind="glorot"):
    """Standard complex initialization proposed in Trabelsi et al. (2018).

    Each weight is ``rho * exp(i * theta)`` with modulus ``rho ~ Rayleigh(sigma)``
    and phase ``theta ~ Uniform[-pi, pi)`` (p. 7), so ``E|w|^2 = 2 * sigma**2``:

    * ``"glorot"`` / ``"xavier"``: ``sigma = 1 / sqrt(fan_in + fan_out)``
    * ``"kaiming"`` / ``"he"``:    ``sigma = 1 / sqrt(fan_in)``

    Samples come from torch's RNG on ``tensor.device``, so ``torch.manual_seed``
    makes the result reproducible.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        kind: variance criterion (case-insensitive).

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he"), f"unknown kind {kind!r}"

    if tensor.numel() == 0:
        # initializing an empty tensor is a no-op (as in torch.nn.init)
        return tensor

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    if kind in ("glorot", "xavier"):
        scale = 1 / math.sqrt(fan_in + fan_out)
    else:
        scale = 1 / math.sqrt(fan_in)

    with torch.no_grad():
        # Rayleigh(sigma) via inverse CDF: rho = sigma * sqrt(-2 ln(1 - U)), U ~ U[0, 1);
        # log1p(-u) stays finite because 1 - u lies in (0, 1].
        u = torch.rand(tensor.shape, dtype=tensor.real.dtype, device=tensor.device)
        rho = scale * torch.sqrt(-2.0 * torch.log1p(-u))
        theta = (2.0 * torch.rand_like(u) - 1.0) * math.pi

        # eq. (8) on p. 6: w = rho * (cos(theta) + i sin(theta))
        tensor.real.copy_(rho * torch.cos(theta))
        tensor.imag.copy_(rho * torch.sin(theta))
    return tensor

In [None]:
# | export
def cplx_trabelsi_independent_(tensor, kind="glorot"):
    """Orthogonal ("independent filters") complex initialization of Trabelsi et al. (2018).

    The weight is viewed as an ``n x m`` matrix: ``tensor`` itself when 2-D,
    otherwise ``(shape[0] * shape[1]) x prod(shape[2:])``. It is filled with
    the polar factor ``U @ Vh`` of a random complex matrix ``Z = U S Vh``
    (reduced SVD), whose columns (``n >= m``) or rows (``n < m``) are
    orthonormal. The matrix is then rescaled so the centered complex variance
    of its entries is ``2 * sigma**2``. That matches the paper's ``Var(W) = 2 sigma^2``:

    * ``"glorot"`` / ``"xavier"``: ``sigma = 1 / sqrt(fan_in + fan_out)``
    * ``"kaiming"`` / ``"he"``:    ``sigma = 1 / sqrt(fan_in)``

    Random numbers come from torch's RNG, so ``torch.manual_seed`` makes the
    result reproducible.

    Args:
        tensor: complex tensor with at least 2 dimensions, filled in place.
        kind: variance criterion (case-insensitive).

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    kind = kind.lower()
    assert kind in ("glorot", "xavier", "kaiming", "he"), f"unknown kind {kind!r}"

    if tensor.numel() == 0:
        # initializing an empty tensor is a no-op (as in torch.nn.init)
        return tensor

    # also rejects tensors with fewer than 2 dimensions (ValueError)
    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)
    if kind in ("glorot", "xavier"):
        scale = 1 / math.sqrt(fan_in + fan_out)
    else:
        scale = 1 / math.sqrt(fan_in)

    if tensor.dim() == 2:
        n, m = tensor.shape
    else:
        n = tensor.shape[0] * tensor.shape[1]
        m = tensor.numel() // n

    # Random complex matrix; only its singular vectors are kept.
    z = torch.complex(
        torch.rand(n, m, dtype=torch.float64),
        torch.rand(n, m, dtype=torch.float64),
    )
    # Reduced SVD: u is n x k and vh is k x m with k = min(n, m), so u @ vh is
    # semi-unitary (M^H M = I_m if n >= m, M M^H = I_n otherwise).
    u, _, vh = torch.linalg.svd(z, full_matrices=False)
    mat = u @ vh

    # Rescale to centered complex variance 2 * scale**2 (population std, as np.std).
    std = (mat - mat.mean()).abs().pow(2).mean().sqrt()
    rms = mat.abs().pow(2).mean().sqrt()
    # A single-entry matrix has zero spread around its mean; fall back to RMS.
    denom = std if std > 0 else rms
    mat = mat * (math.sqrt(2.0) * scale / denom)

    with torch.no_grad():
        mat = mat.reshape(tensor.shape)
        # copy_ casts to the tensor's dtype and moves to its device
        tensor.real.copy_(mat.real)
        tensor.imag.copy_(mat.imag)
    return tensor

In [None]:
# | export
def cplx_normal_independent_(tensor, a=0.0, b=1.0):
    """Fill real and imaginary parts independently from ``N(a, b**2)``, in place.

    Args:
        tensor: complex tensor to fill.
        a: mean of each component (the complex mean is ``a + 1j * a``).
        b: standard deviation of each component, so the complex variance
            ``E|w - Ew|^2`` is ``2 * b**2``.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    init.normal_(tensor.real, a, b)
    init.normal_(tensor.imag, a, b)
    return tensor

In [None]:
# | export
def cplx_uniform_independent_(tensor, a=0.0, b=1.0):
    """Fill real and imaginary parts independently from ``U(a, b)``, in place.

    Args:
        tensor: complex tensor to fill.
        a: lower bound of each component.
        b: upper bound of each component.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    init.uniform_(tensor.real, a, b)
    init.uniform_(tensor.imag, a, b)
    return tensor

In [None]:
# | export
def ones_(tensor, imag_zero=False):
    """Fill ``tensor`` in place with ``1 + 1j`` (default) or ``1 + 0j``.

    Args:
        tensor: complex tensor to fill; may be an ``nn.Parameter`` that
            requires grad.
        imag_zero: if true, set the imaginary part to 0 instead of 1.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    # Without no_grad, an in-place write to (a view of) a leaf tensor that
    # requires grad raises a RuntimeError.
    with torch.no_grad():
        tensor.real.fill_(1)
        tensor.imag.fill_(0 if imag_zero else 1)
    return tensor

In [None]:
# | export
def zeros_(tensor):
    """Fill the complex ``tensor`` with ``0 + 0j`` in place.

    Args:
        tensor: complex tensor to fill; may be an ``nn.Parameter`` that
            requires grad.

    Returns:
        ``tensor``, for chaining like the ``torch.nn.init`` functions.
    """
    # Without no_grad, an in-place write to (a view of) a leaf tensor that
    # requires grad raises a RuntimeError.
    with torch.no_grad():
        tensor.real.fill_(0)
        tensor.imag.fill_(0)
    return tensor

In [None]:
weights = torch.randn(48, 100, dtype=torch.cdouble) / math.sqrt(48)

In [None]:
cplx_kaiming_normal_(weights)

In [None]:
cplx_trabelsi_standard_(weights)

In [None]:
ones_(weights, imag_zero=True)

In [None]:
cplx_normal_independent_(weights)

In [None]:
cplx_normal_independent_(weights)

In [None]:
weights

tensor([[ 1.2088-0.1780j,  1.3915+0.0717j, -0.4758-0.0159j,  ...,
          1.2787-0.4958j, -1.3118-0.1963j, -0.1222+2.1602j],
        [-1.0095+2.3892j, -0.4441+1.6676j,  0.6802+1.2186j,  ...,
         -0.2286+0.4150j, -0.1627+0.8436j,  0.4031+0.5756j],
        [ 0.1665-0.5673j,  0.7762-1.4712j,  0.8819-0.6611j,  ...,
         -0.2478-0.8351j,  0.5986+0.3264j,  0.9021-0.5091j],
        ...,
        [ 0.5833-0.1446j,  0.3841+0.0449j, -0.7109-0.9702j,  ...,
          1.4883+0.1660j,  0.2870-0.5798j,  2.3875+2.6999j],
        [-0.2952+0.8427j, -0.6253+2.1380j, -0.5874-0.0462j,  ...,
          1.7140-0.2101j, -0.2020+0.7448j,  0.1977+1.1116j],
        [-2.0322+0.3303j, -0.0070+0.4699j, -1.3509-0.4418j,  ...,
          0.5421-1.1324j,  0.8526+1.6163j,  2.2733-0.3466j]],
       dtype=torch.complex128)