In [2]:
%pip install --upgrade -q pip jax jaxlib
%pip install --upgrade -q git+https://github.com/google/flax.git

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [3]:
import jax
from typing import Any, Callable, Sequence
from jax import random, numpy as jnp
import flax
from flax import linen as nn

## 全結合層

- **Implementation by Flax** : `flax.linen.Dense`
- 「Feed Forward層」、「Prediction層」などの別名で呼ばれることも多い

### `base`

- 1次元入力に対する全結合層
- $ p[j] = \sum_i n[i] \times w[i, j] + b[j] $
  - $ p $ : 出力
  - $ n $ : 入力
  - $ w $ : 重みパラメータ(入力次元×出力次元)
  - $ b $ : バイアスパラメータ(出力次元)

In [4]:
class Base(nn.Module):
    """Fully connected layer computing ``inputs @ weight + bias``.

    ``weight`` has shape ``(inputs.shape[-1], features)`` and ``bias`` has
    shape ``(features,)``, so the projection acts on the last axis of
    ``inputs``.

    Attributes:
        features: size of the output (last) axis.
        bias_init: initializer for ``bias``; zeros by default.
        weight_init: initializer for ``weight``; ``nn.initializers.normal()``
            by default.
    """
    features: int
    bias_init: Callable = nn.initializers.zeros_init()
    weight_init: Callable = nn.initializers.normal()

    @nn.compact
    def __call__(self, inputs: jax.typing.ArrayLike):
        """Project ``inputs`` along its last axis onto ``features`` outputs."""
        kernel_shape = (inputs.shape[-1], self.features)
        # `weight` must be registered before `bias`: Flax hands out parameter
        # RNG keys in creation order.
        kernel = self.param('weight', self.weight_init, kernel_shape)
        projected = jnp.dot(inputs, kernel)
        offset = self.param('bias', self.bias_init, (self.features,))
        return projected + offset

### `Conv1D`

- 2次元データ（系列長 × チャネル数）に対する畳み込み層
- カーネルサイズが1であるため、系列上の各位置 $ t $ に同じ全結合層を適用するのと同様に求めることができる
- $ p[t, j] = \sum_i n[t, i] \times w[i, j] + b[j] $
  - $ p $ : 出力
  - $ n $ : 入力
  - $ w $ : 重みパラメータ(入力次元×出力次元)
  - $ b $ : バイアスパラメータ(出力次元)

In [10]:
class Conv1D(nn.Module):
    """1D convolution with kernel size 1 (GPT-2 style ``Conv1D``).

    With a kernel width of 1, the convolution applies the same dense
    projection at every position. All leading axes are flattened into rows,
    one matrix product with the ``(channels, features)`` kernel is taken, and
    the leading axes are restored.

    Attributes:
        features: number of output channels.
        bias_init: initializer for ``bias`` (shape ``(features,)``).
        weight_init: initializer for ``weight`` (shape
            ``(1, channels, features)``; the leading 1 is the kernel width).
    """
    features: int
    bias_init: Callable = nn.initializers.zeros_init()
    weight_init: Callable = nn.initializers.normal()

    @nn.compact
    def __call__(self, inputs: jax.typing.ArrayLike):
        """Apply the kernel-size-1 convolution.

        Args:
            inputs: array whose last axis holds the input channels, e.g.
                ``(batch, length, channels)``. Any number of leading axes is
                accepted.

        Returns:
            Array of shape ``inputs.shape[:-1] + (features,)``.
        """
        x = jnp.asarray(inputs)
        x_channel = x.shape[-1]
        weight = self.param('weight',
                            self.weight_init,
                            (1, x_channel, self.features))
        # Remember the leading (batch, length, ...) axes so they can be restored.
        out_shape = x.shape[:-1] + (self.features,)

        # Kernel width is 1, so flatten every position into its own row of
        # channels: (..., channels) -> (N, channels).
        x_rows = jnp.reshape(x, (-1, x_channel))
        # Drop the width-1 kernel axis: (1, channels, features) -> (channels, features).
        kernel = jnp.reshape(weight, (x_channel, self.features))

        y = jnp.dot(x_rows, kernel)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias

        # Restore the leading axes: (N, features) -> (..., features).
        return jnp.reshape(y, out_shape)

## Normalization層

- **Implementation by Flax** : `flax.linen.LayerNorm`
- $ h_i = f\left(\frac{g_i}{\sigma_i}\left(a_i-\mu_i\right)+b_i \right) $
  - $ \sigma $ : 標準偏差（平均との差の二乗の平均の平方根。実装ではゼロ除算を避けるため分散に $ \epsilon $ を加えてから平方根をとる）
  - $ \mu $ : 平均値
  - $ g $ : 重みパラメータ
  - $ b $ : バイアスパラメータ

In [None]:
class Normalization(nn.Module):
    """Layer normalization (compare ``flax.linen.LayerNorm``).

    Computes ``(x - mean) / sqrt(var + epsilon) * weight + bias``, where
    ``mean`` and ``var`` are taken over ``axis``. The learnable ``weight`` and
    ``bias`` have shape ``(input.shape[-1],)`` and broadcast over the last
    axis. No activation function is applied here.

    Attributes:
        axis: axis over which mean and variance are computed (default: last).
        epsilon: constant added to the variance before the inverse square
            root, so an all-constant input does not divide by zero.
        weight_init: initializer for the scale ``weight`` (ones by default).
        bias_init: initializer for the shift ``bias`` (zeros by default).
    """
    axis: int = -1
    epsilon: float = 1e-5
    # `nn.initializers.ones` / `zeros` are themselves initializers taking
    # (key, shape, dtype); calling them with no arguments raises TypeError.
    # The `*_init()` factories return those initializers.
    weight_init: Callable = nn.initializers.ones_init()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, input):
        """Normalize ``input`` and apply the learned scale and shift.

        Args:
            input: array to normalize; the size of its last axis determines
                the shapes of ``weight`` and ``bias``.

        Returns:
            Array with the same shape as ``input``.
        """
        x = jnp.asarray(input)
        # Mean and (biased) variance over `axis`; keepdims so they broadcast back.
        mean = jnp.mean(x, axis=self.axis, keepdims=True)
        variance = jnp.mean(jnp.square(x - mean), axis=self.axis, keepdims=True)
        # Divide by the standard deviation (rsqrt = 1 / sqrt).
        y = (x - mean) * jax.lax.rsqrt(variance + self.epsilon)
        features = x.shape[-1]
        weight = self.param('weight', self.weight_init, (features,))
        # The shape must be a tuple: `(n)` is just the int n, which
        # shape-validating initializers (e.g. `normal()`) reject.
        bias = self.param('bias', self.bias_init, (features,))
        return y * weight + bias