In [20]:
%pip install --upgrade -q pip jax jaxlib
%pip install --upgrade -q git+https://github.com/google/flax.git

/home/wsl2/slm/.venv/bin/python: No module named pip
Note: you may need to restart the kernel to use updated packages.
/home/wsl2/slm/.venv/bin/python: No module named pip
Note: you may need to restart the kernel to use updated packages.


In [21]:
from typing import Any, Callable, Sequence

import flax
import jax
from flax import linen as nn
from jax import numpy as jnp
from jax import random

## 全結合層

- **Implementation by Flax** : `flax.linen.Dense`
- 「Fead Forward層」、「Prediction層」などと別の名称で呼称されることが多い

### `base`

- 1次元入力に対する全結合層
- $ p[j] = \sum^i n[i] \times w[i, j] + b[j] $
  - $ p $ : 出力
  - $ n $ : 入力
  - $ w $ : 重みパラメータ(入力次元×出力次元)
  - $ b $ : バイアスパラメータ(出力次元)

In [22]:
class Base(nn.Module):
    features: int
    bias_init: Callable = nn.initializers.zeros_init()
    weight_init: Callable = nn.initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        weight = self.param('weight',
                            self.weight_init,
                            (inputs.shape[-1], self.features))
        y = jnp.dot(x, weight)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias
        return y

### `Conv1D`

- ２次元ベクトルデータへの畳み込み層
- カーネルサイズが1であるため、全結合層と同様に求めることができる
- $ p[j] = \sum^i n[i] \times w[i, j] + b[j] $
  - $ p $ : 出力
  - $ n $ : 入力
  - $ w $ : 重みパラメータ(入力次元×出力次元)
  - $ b $ : バイアスパラメータ(出力次元)

In [23]:
class Conv1D(nn.Module):
    features: int
    bias_init: Callable = nn.initializers.zeros_init()
    weight_init: Callable = nn.initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        weight = self.param('weight',
                            self.weight_init,
                            (1, x.shape[-1], self.features))
        batch_size = x.shape[0]
        x_channel = x.shape[-1]

        # カーネルサイズに畳み込み
        x = jnp.reshape(x, (-1, batch_size))
        weight = jnp.reshape(weight, (-1, x_channel))

        y = jnp.dot(x, weight)
        bias = self.param('bias', self.bias_init, (self.features,))
        y = y + bias

        # カーネルを展開
        y = jnp.reshape(y, (batch_size, x.shape[1], self.features))
        return y

## Normalization層

- **Implementaion by Flax** : `flax.linen.LayerNorm`
- $ h_i = f\left(\frac{g_i}{\sigma_i}\left(a_i-\mu_i\right)+b_i \right) $
  - $ \sigma $ : 二乗誤差の平方根
  - $ \mu $ : 平均値
  - $ g $ : 重みパラメータ
  - $ b $ : バイアスパラメータ

In [24]:
class Normalization(nn.Module):
    axis: int = -1
    epsilon: float = 1e-5
    weight_init: Callable = nn.initializers.ones_init()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, input):
        x = input
        # 平均値と二乗誤差
        mean = jnp.mean(x, axis=self.axis, keepdims=True)
        squared_error = jnp.mean(jnp.square(x-mean), axis=self.axis, keepdims=True)
        # 二乗誤差の平方根
        y = (x - mean) * jax.lax.rsqrt(squared_error + self.epsilon)
        weight = self.param('weight', self.weight_init, (input.shape[-1],))
        bias = self.param('bias', self.bias_init, (input.shape[-1]))
        y = y * weight + bias
        return y

## 活性化関数

#### `GELU`

- **Implementation by JAX** : `jax.nn.gelu`

- $ 0.5 * \left( 1 + \tanh\left(\sqrt{2/\pi}\left(x+0.044715 x^3\right)\right) \right) $
- 原点が0を通る
- 負の数では0に近づく
- 正の数では比例する
- 滑らかな連続関数

In [25]:
def gelu(x):
    return 0.5 * (1 + jnp.tanh(jnp.sqrt(2/jnp.pi) * (x + 0.044715 * jnp.pow(x, 3))))

#### `Swish`

- **Implementation by JAX** : `jax.nn.swish`
- $ x \times \text{sigmoid}(x) $
- `GELU`の単純実装

In [26]:
def swish(x):
    return x * jax.nn.sigmoid(x)

#### `Mish`

- **Implementation by JAX** : `jax.nn.mish`
- $ x * \tanh(\text{softplus}(x)) $
- `Swish`よりも性能が良いされている

In [27]:
def mish(x):
    return x * jnp.tanh(jax.nn.softplus(x))

#### `Softmax`

- **Implementation by JAX** : `jax.nn.softmax`
- $ \exp{x_i} / \sum_j \exp{x_j} $
- 全和が1になるように拡大・縮小する

In [28]:
def softmax(x, axis=-1):
    ex = jnp.exp(x - jnp.sum(x, axis=axis, keepdims=True))
    return ex / jnp.sum(ex, axis=axis, keepdims=True)

#### `LogSoftmax`

- **Implementation by JAX** : `jax.nn.log_softmax`
- $ x - \log \left(\sum_j \exp{x_j}\right) $
  - $ \log \left(\exp{x} / \sum{x}\right) $
  - $ = \log\left(\exp{(x)}\right) - \log\left(\sum{\exp(x)}\right)$
  - $ = x - \log\left(\sum{\exp(x)}\right) $

In [29]:
def log_softmax(x, axis=-1):
    ex = jnp.exp(x)
    return x - jnp.log(jnp.sum(ex))

## 損失関数

### Cross Entropy Loss

- **Implementation by Flax** : `optax.losses.safe_softmax_cross_entropy`
- エッジケースによる、極大値や極小値への対応が必要
- Flaxでは損失関数、活性化関数の両方で対応可能

In [30]:
def cross_entropy_loss(labels, logits):
    # batch_size * tokenを一次元ベクトルに展開
    num_vocabulary = logits.shape[-1]
    flatten_labels = jnp.reshape(labels, (-1,))
    flatten_labels = jnp.asarray(flatten_labels, dtype=jnp.int32)
    flatten_logits = jnp.reshape(logits, (-1, num_vocabulary))
    # ラベルをone-hotベクトル化
    one_hot_labels = jax.nn.one_hot(flatten_labels, num_vocabulary, dtype=jnp.float32)
    # クロスエントロピーを算出
    log_probs = log_softmax(flatten_logits)
    # 損失を計算
    loss = -1 * jnp.sum(log_probs * one_hot_labels, axis=(-1))
    # 逆伝播誤差
    loss = jnp.mean(loss)
    return loss


### Masked Cross Entropy Loss

- **Implementation by Flax** : `None`
- 出力の特定の位置を無視できるように動作するクロスエントロピー
- 正解ラベルとして`-1`の場合には無視する

In [31]:
def masked_cross_entropy_loss(labels, logits):
    # batch_size * tokenを一次元ベクトルに展開
    num_vocabulary = logits.shape[-1]
    flatten_labels = jnp.reshape(labels, (-1,))
    flatten_logits = jnp.reshape(logits, (-1, num_vocabulary))
    # マスクの作成
    mask = flatten_labels + 1
    mask = jnp.asarray(jnp.asarray(mask, dtype=jnp.bool), dtype=jnp.int32)
    # マスクを適応
    flatten_labels = jnp.asarray(flatten_labels + (1 - mask), jnp.int32)
    # ラベルをone-hotベクトル化
    one_hot_labels = jax.nn.one_hot(flatten_labels, num_vocabulary, dtype=jnp.float32)
    # クロスエントロピーを算出
    log_probs = log_softmax(flatten_logits)
    # 損失を計算
    loss = -1 * jnp.sum(log_probs * one_hot_labels, axis=(-1))
    #マスクを損失に適応
    mask = jnp.asarray(mask, jnp.float32)
    loss = loss * mask
    # 逆伝播誤差
    loss = jnp.mean(loss) / (jnp.sum(mask) + 1e-5)
    return loss


## Attention層

### Self-Attention

- **Implementation by Flax** : `flax.linen.SelfAttention`
- `tanh`層で挟んだ2層のFNN
- `Softmax`で活性化したのちに、入力ベクトルと要素を掛け合わせる

In [32]:
class SelfAttention(nn.Module):
    @nn.compact
    def __call__(self, input):
        x = input
        nx = x.shape[-1]
        p = Conv1D(nx, name=self.name+"_p1")(x)
        p = jnp.tanh(p)
        p = Conv1D(nx, name=self.name+"_p2")(p)
        p = softmax(p)
        return x * p

### KeyValue-Self-Attention

- **Implementation by Flax** : `flax.linen.SelfAttention`
- 単一ヘッドのキーバリューストアによるSelf-Attention

In [33]:
class KeyValueSelfAttention(nn.Module):
    qkv_features: int
    out_features: int

    @nn.compact
    def __call__(self, input, scores_weight=None):
        x = input
        # Query, Key, Valueに分割して投射
        query = Conv1D(self.qkv_features, name=self.name+"_query")(x)
        key = Conv1D(self.qkv_features, name=self.name+"_key")(x)
        value = Conv1D(self.qkv_features, name=self.name+"_value")(x)
        # シーケンス、特徴量方向の次元の積を算出
        scores = query @ key.T
        scores = scores * 1 / jnp.sqrt(jnp.asarray(query.shape[-1], jnp.float32))
        if scores_weight is not None:
            scores += scores_weight
        # Self-Attention
        prob = softmax(scores)
        context = prob @ value
        context = Conv1D(self.out_features, name=self.name+"_proj")(context)
        return context



### Multi-head-Attentition

- **Implementation by Flax** : `flax.linen.MultiHeadAttention`
- 複数ヘッドによるAttention

In [34]:
class MultiHeadAttention(nn.Module):
    qkv_features: int
    out_features: int
    num_heads: int

    @nn.compact
    def __call__(self, input, scores_weight=None):
        x = input
        batch_size = -1 if x.shape[0] is None else x.shape[0]
        num_tokens = x.shape[2]
        num_channels = x.shape[1] * x.shape[3]
        # Query, Key, Valueに分割して投射
        query = Conv1D(self.qkv_features, name=self.name+"_query")(x)
        key = Conv1D(self.qkv_features, name=self.name+"_key")(x)
        value = Conv1D(self.qkv_features, name=self.name+"_value")(x)
        query = jnp.reshape(query, (batch_size, x.shape[1], self.num_heads, x.shape[-1] // self.num_heads))
        key = jnp.reshape(key, (batch_size, x.shape[1], self.num_heads, x.shape[-1] // self.num_heads))
        value = jnp.reshape(value, (batch_size, x.shape[1], self.num_heads, x.shape[-1] // self.num_heads))
        # [batch, head, sequence, features] になるように入れ替え
        query = jnp.transpose(query, (0, 2, 1, 3))
        key = jnp.transpose(key, (0, 2, 1, 3))
        value = jnp.transpose(value, (0, 2, 1, 3))
        # シーケンス、特徴量方向の次元の積を算出
        scores = query @ key.T
        scores = scores * 1 / jnp.sqrt(jnp.asarray(num_channels, jnp.float32))
        if scores_weight is not None:
            scores += scores_weight
        # Self-Attention
        probs = softmax(scores)
        context = probs @ value
        # [batch, sequence, heads, features] に戻す
        context = jnp.transpose(context, (0, 2, 1, 3))
        # [batch, sequence, features] に戻す
        context = jnp.reshape(context, (batch_size, num_tokens, num_channels))
        return Conv1D(self.out_features, name=self.name+"_proj")



## Embedded層

- **Implementation by Flax** : `None`
- 単語ベクトルを学習可能なパラメータとして保持する

In [35]:
class PositionEmbedding(nn.Module):
    num_context: int
    num_hidden: int
    weight_init: Callable = nn.initializers.normal()

    @nn.compact
    def __call__(self, input):
        x = input
        weights = self.param('weight', self.weight_init, (self.num_context, self.num_hidden))
        return x + weights

In [36]:
class VocabularyEmbedding(nn.Module):
    num_vocabulary: int
    weight_init: Callable = nn.initializers.normal()

    @nn.compact
    def __call__(self, input):
        x = input
        weights = self.param('weight', self.weight_init, (self.num_vocabulary,))
        return x + weights

### Embed層

- **Implementation by Flax** : `None`

In [37]:
class Embed(nn.Module):
    num_vocab: int
    num_hidden: int
    weight_init: Callable = nn.initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, input, num_context):
        x = input
        weights = self.param('weight', self.weight_init, (self.num_vocab, self.num_hidden))
        if num_context > 0:
            x_flat = jnp.reshape(x, (-1, self.num_hidden))
            logits = x_flat @ weights.T
            return jnp.reshape(logits, (-1, num_context, self.num_vocab))
        else:
            num_context = x.shape[-1]
            x_flat = jnp.reshape(x, (-1,))
            one_hot_ids = jax.nn.one_hot(x_flat, self.num_vocab)
            vector = one_hot_ids @ weights
            return jnp.reshape(vector, (-1, num_context, self.num_hidden))

In [51]:
e = Embed(num_vocab=2000, num_hidden=200, name="Embeddings")

key1, key2 = random.split(random.key(0), 2)
x = jnp.array([[100, 101, 102]])

params = e.init(key1, x, 0)
y = e.apply(params, x, 0)
print(y.shape)

word_vector = random.uniform(key2, (1, 5, 200))
params = e.init(key1, word_vector, 5)
y = e.apply(params, word_vector, 5)
print(y.shape)

print(jnp.argmax(y, axis=-1).shape)

key, key3, key4 = random.split(key1, 3)
params = e.init(key3, x, 0)
y = e.apply(params, x, 0)
params = e.init(key4, y, 3)
y = e.apply(params, y, 3)
print(jnp.argmax(y, axis=-1))

key, key5, key6 = random.split(key, 3)
logits_out = random.uniform(key5, (1, 3, 200))
params = e.init(key6, logits_out, 3)
y = e.apply(params, logits_out, 3)
print(cross_entropy_loss(x, y))
print(cross_entropy_loss(x, y).shape)

(1, 3, 200)
(1, 5, 2000)
(1, 5)
[[347 332 955]]
8.714087
()
