# Import modules

In [None]:
import os
import numpy as np
import math
import json
from functools import partial
import functools

## Imports for plotting
import matplotlib.pyplot as plt

%matplotlib inline
from IPython.display import set_matplotlib_formats
import matplotlib
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.auto import tqdm

## To run JAX on TPU in Google Colab, uncomment the two lines below (this notebook runs them in a separate cell further down instead)
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

## JAX
import jax
import jax.numpy as jnp
from jax import random

## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
    
from flax import linen as nn
from flax.training import train_state, checkpoints
from flax.training import common_utils
## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

## PyTorch
import torch
import torch.utils.data as data
try:
    import wandb
except ModuleNotFoundError: # Install wandb if missing
    !pip install --quiet wandb
    import wandb
import pandas as pd
import re

[K     |████████████████████████████████| 202 kB 5.1 MB/s 
[K     |████████████████████████████████| 145 kB 40.2 MB/s 
[K     |████████████████████████████████| 596 kB 41.4 MB/s 
[K     |████████████████████████████████| 7.5 MB 29.5 MB/s 
[K     |████████████████████████████████| 217 kB 63.4 MB/s 
[K     |████████████████████████████████| 51 kB 4.6 MB/s 
[K     |████████████████████████████████| 76 kB 2.7 MB/s 
[K     |████████████████████████████████| 1.8 MB 5.5 MB/s 
[K     |████████████████████████████████| 181 kB 42.7 MB/s 
[K     |████████████████████████████████| 157 kB 23.3 MB/s 
[K     |████████████████████████████████| 63 kB 587 kB/s 
[K     |████████████████████████████████| 157 kB 18.5 MB/s 
[K     |████████████████████████████████| 157 kB 39.5 MB/s 
[K     |████████████████████████████████| 156 kB 10.3 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
# Start a Weights & Biases run under this project/entity (needs a W&B login;
# the recorded output shows the API key being appended to ~/.netrc).
wandb.init(
    project='Flax-transformer',
    entity='seegong'
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Connect JAX to the Colab TPU runtime.
# NOTE(review): presumably this must run before any JAX computation initializes a
# backend -- keep it ahead of every cell that creates jnp arrays.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()


In [None]:
# Sanity check: list the devices visible to this process (8 TPU cores in the recorded output).
jax.local_devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

# Data preprocessing

In [None]:
# Korean chatbot Q/A pairs (columns Q, A, label as shown in the output), fetched from GitHub.
url = 'https://github.com/songys/Chatbot_data/raw/master/ChatbotData.csv'
raw_data = pd.read_csv(url)
raw_data

Unnamed: 0,Q,A,label
0,12시 땡!,하루가 또 가네요.,0
1,1지망 학교 떨어졌어,위로해 드립니다.,0
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0
4,PPL 심하네,눈살이 찌푸려지죠.,0
...,...,...,...
11818,훔쳐보는 것도 눈치 보임.,티가 나니까 눈치가 보이는 거죠!,2
11819,훔쳐보는 것도 눈치 보임.,훔쳐보는 거 티나나봐요.,2
11820,흑기사 해주는 짝남.,설렜겠어요.,2
11821,힘든 연애 좋은 연애라는게 무슨 차이일까?,잘 헤어질 수 있는 사이 여부인 거 같아요.,2


In [None]:
# Pull the question and answer columns out as NumPy arrays.
raw_question = raw_data['Q'].to_numpy()
raw_answer = raw_data['A'].to_numpy()

In [None]:
def preprocess_sentecne(sentence):
    """Normalize one raw sentence before morphological analysis.

    Every run of characters other than ASCII letters, digits, Hangul jamo and
    syllables, and the punctuation ? . ! , " ' is collapsed into a single space.
    The result is lowercased and stripped of surrounding whitespace.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9ㄱ-ㅎㅏ-ㅣ가-힣?.!,\"']+", " ", sentence)
    return cleaned.lower().strip()

In [None]:
# Normalize every raw question and answer with preprocess_sentecne.
clean_question = list(map(preprocess_sentecne, raw_question))
clean_answer = list(map(preprocess_sentecne, raw_answer))

In [None]:
# Install MeCab and its Python binding so konlpy.tag.Mecab can be used below.
# Each command's exit status is checked and reported. A failed install shows up
# here instead of as a confusing error in a later cell. This stays best-effort:
# it warns rather than raising.
import os


def _run_setup(cmd):
  """Run one shell setup command and print a warning if it exits non-zero."""
  status = os.system(cmd)
  if status != 0:
    print(f'WARNING: {cmd!r} exited with status {status}')
  return status


_run_setup('apt-get update')
# -y (assume yes) keeps apt-get from stopping at its confirmation prompt,
# which a non-interactive os.system call cannot answer.
_run_setup('apt-get install -y g++ openjdk-8-jdk python-dev python3-dev')
# JAVA_HOME points at the JDK installed by the previous command.
os.environ['JAVA_HOME'] = "/usr/lib/jvm/java-8-openjdk-amd64"
_run_setup("curl -s -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash")
_run_setup('pip3 install /tmp/mecab-python-0.996')

0

In [None]:
%%capture
# Install konlpy (provides konlpy.tag.Mecab); %%capture hides the pip log.
!pip install konlpy

In [None]:
# konlpy's MeCab wrapper; needs the MeCab system install performed in the cells above.
from konlpy.tag import Mecab
mecab = Mecab()

In [None]:
# Split every cleaned sentence into MeCab morphemes.
mecab_question = list(map(mecab.morphs, clean_question))
mecab_answer = list(map(mecab.morphs, clean_answer))

In [None]:
# Accumulator for all (question, answer) training pairs. Later cells append to it,
# so re-run this cell before re-running those, or the corpus will contain duplicates.
total_corpus = []

In [None]:
# Keep only pairs whose question and answer both have at most 15 morphemes.
raw_corpus = [
    (question, answer)
    for question, answer in zip(mecab_question, mecab_answer)
    if len(question) <= 15 and len(answer) <= 15
]

In [None]:
# Pretrained Korean word2vec model used for lexical-substitution augmentation.
# NOTE(review): the path lives on Google Drive, and no visible cell mounts it --
# run google.colab.drive.mount('/content/drive') first, or this fails on a fresh kernel.
from gensim.models import Word2Vec
word2vec_path = '/content/drive/MyDrive/Colab Notebooks/ko.bin'
word2vec = Word2Vec.load(word2vec_path)

In [None]:
import random as rand
from tqdm import tqdm

In [None]:
def lexical_sub(sentence, word2vec):
    """Lexical-substitution augmentation: swap one token for its nearest word2vec neighbour.

    Positions are tried in random order (``rand.shuffle``). The first token that
    ``word2vec.wv.most_similar`` can handle is replaced by its top-1 neighbour;
    all other tokens are copied unchanged.

    Args:
        sentence: list of tokens.
        word2vec: model exposing ``wv.most_similar(token)``, which returns a ranked
            list of ``(word, score)`` pairs and raises ``KeyError`` for unknown tokens.

    Returns:
        A new token list with exactly one position substituted, or ``None`` when no
        token could be substituted (empty sentence, or every token unknown).
        Callers drop ``None`` results.
    """
    positions = list(range(len(sentence)))
    rand.shuffle(positions)
    for index in positions:
        try:
            replacement = word2vec.wv.most_similar(sentence[index])[0][0]
        except (KeyError, IndexError):
            # KeyError: token not in the vocabulary; IndexError: no neighbours returned.
            continue
        # Replace by position, not by identity: equal tokens can be the same
        # (interned) str object, and only this one occurrence should change.
        substituted = list(sentence)
        substituted[index] = replacement
        return substituted
    return None

In [None]:
# Augment questions: substitute one token per question, keeping its original answer.
new_question_corpus = [
    (lexical_sub(question, word2vec), answer)
    for question, answer in tqdm(raw_corpus)
]

100%|██████████| 11030/11030 [01:46<00:00, 103.37it/s]


In [None]:
# Add augmented pairs, skipping questions for which no substitution was possible.
total_corpus.extend(
    (question, answer)
    for question, answer in new_question_corpus
    if question is not None
)

In [None]:
# Augment answers: keep each question, substitute one token in its answer.
new_answer_corpus = [
    (question, lexical_sub(answer, word2vec))
    for question, answer in tqdm(raw_corpus)
]

100%|██████████| 11030/11030 [01:11<00:00, 154.02it/s]


In [None]:
# Add augmented pairs, skipping answers for which no substitution was possible.
total_corpus.extend(
    (question, answer)
    for question, answer in new_answer_corpus
    if answer is not None
)

In [None]:
# Finally add the original, un-augmented pairs.
total_corpus.extend((question, answer) for question, answer in raw_corpus)

In [None]:
# Token frequency table over every question and answer in total_corpus.
# Counter yields true occurrence counts, in first-seen insertion order, without
# a bare except hiding unrelated errors.
from collections import Counter

word_dict = Counter()
for question, answer in total_corpus:
  word_dict.update(question)
  word_dict.update(answer)
word_dict = dict(word_dict)

# '<PAD>' gets a count strictly above every real token. The frequency-ranked
# vocabulary built from word_dict therefore gives it id 0, which is the id
# change_to_tensor pads with. default=0 keeps this working for an empty corpus.
word_dict['<PAD>'] = max(word_dict.values(), default=0) + 1

In [None]:
# Token id = rank by descending frequency (stable sort, so ties keep first-seen order).
src_word_index = {
    token: rank
    for rank, token in enumerate(sorted(word_dict, key=word_dict.get, reverse=True))
}
# Inverse mapping: id -> token.
src_index_word = {rank: token for token, rank in src_word_index.items()}

* 워드 딕셔너리는 아주 잘 생성되었다.

In [None]:
def change_to_tensor(sentence, word_dict, max_len=15, pad_id=0):
  """Map a token list to integer ids and right-pad it to ``max_len``.

  Args:
    sentence: list of tokens; every token must be a key of ``word_dict``
      (a missing token raises ``KeyError``).
    word_dict: mapping token -> integer id.
    max_len: target length (default 15, the corpus length limit). Sequences
      that are already at least this long are returned unpadded, not truncated.
    pad_id: id appended as padding (default 0, the id of '<PAD>' in the
      frequency-ranked vocabulary).

  Returns:
    A new list of ints of length ``max(len(sentence), max_len)``.
  """
  ids = [word_dict[word] for word in sentence]
  # A negative repeat count yields an empty list, so long inputs are left as-is.
  ids.extend([pad_id] * (max_len - len(ids)))
  return ids

In [None]:
# Convert every (question, answer) token pair to padded id lists.
total_tensor = [
    (
        change_to_tensor(question, src_word_index),
        change_to_tensor(answer, src_word_index),
    )
    for question, answer in total_corpus
]

In [None]:
# Stack the padded question id lists into one 2-D int32 array.
question_tensor = jnp.array([question_ids for question_ids, _ in total_tensor])
question_tensor

DeviceArray([[2422,  169, 3454, ...,    0,    0,    0],
             [ 278, 4538, 3455, ...,    0,    0,    0],
             [ 294, 2423,  603, ...,    0,    0,    0],
             ...,
             [6808,   12,   55, ...,    0,    0,    0],
             [ 199,  121,   11, ...,    0,    0,    0],
             [  67,  113,  159, ...,    0,    0,    0]], dtype=int32)

In [None]:
# Stack the padded answer id lists into one 2-D int32 array.
answer_tensor = jnp.array([answer_ids for _, answer_ids in total_tensor])
answer_tensor

DeviceArray([[ 282,    7,  138, ...,    0,    0,    0],
             [ 527,   12, 1410, ...,    0,    0,    0],
             [ 251,   14,  686, ...,    0,    0,    0],
             ...,
             [2678,   23,   27, ...,    0,    0,    0],
             [  44, 1953,   42, ...,    0,    0,    0],
             [4417,  159,   14, ...,    0,    0,    0]], dtype=int32)

# Model architecture

In [None]:
from typing import Callable, Any, Optional

from flax import linen as nn
from flax import struct
from jax import lax
import jax.numpy as jnp
import numpy as np


# TransformerConfig

In [None]:
@struct.dataclass
class TransformerConfig:
  """Global hyperparameters used to minimize obnoxious kwarg plumbing.

  Defaults describe a 6-layer, 8-head, 512-wide encoder-decoder. Both vocabulary
  sizes default to ``len(src_word_index)``, evaluated once when this class is
  defined. The vocabulary cells must therefore run before this one.
  """
  vocab_size: int = len(src_word_index)
  output_vocab_size: int = len(src_word_index)  # must equal vocab_size while share_embeddings is True
  share_embeddings: bool = True  # encoder and decoder use a single nn.Embed table
  logits_via_embedding: bool = False  # True: decoder logits come from the transposed embedding
  dtype: Any = jnp.float32
  emb_dim: int = 512
  num_heads: int = 8
  num_layers: int = 6
  qkv_dim: int = 512
  mlp_dim: int = 2048
  max_len: int = 2048  # number of rows in the positional-embedding table
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  deterministic: bool = False  # False keeps dropout active; use True for evaluation
  decode: bool = False  # True switches the decoder to cached single-step decoding
  learning_rate: float = 0.01
  warmup_steps: int = 10
  kernel_init: Callable = nn.initializers.xavier_uniform()
  bias_init: Callable = nn.initializers.normal(stddev=1e-6)
  posemb_init: Optional[Callable] = None  # None -> fixed sinusoidal position embeddings
  label_smoothing: float = 0.1

In [None]:
def shift_right(x, axis=1):
  """Shift ``x`` one step to the right along ``axis``, filling the gap with zeros.

  The decoder applies this to its targets when not in decode mode, so position t
  sees token t-1. Position 0 gets 0, the padding id in this notebook.

  Args:
    x: array with at least ``axis + 1`` dimensions.
    axis: axis to shift along (default 1, the sequence axis of a
      ``(batch, length)`` array). Negative axes are accepted.

  Returns:
    Array with the same shape and dtype as ``x``.
  """
  pad_widths = [(0, 0)] * x.ndim
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(
      x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
  # Drop the last element along the same axis that was padded. Hard-coding
  # [:, :-1] would only be correct for axis == 1.
  keep = [slice(None)] * x.ndim
  keep[axis] = slice(None, -1)
  return padded[tuple(keep)]

In [None]:
def sinusoidal_init(max_len=2048,
                    min_scale=1.0,
                    max_scale=10000.0):
  """Build an initializer for a fixed 1-D sinusoidal position-embedding table.

  The first half of the feature dimension holds sin(pos * f_k) and the second
  half holds cos(pos * f_k). The frequencies f_k are spaced geometrically from
  ``min_scale`` downwards by a total factor of ``max_scale / min_scale``. An
  odd feature size leaves the last column at zero.

  Args:
      max_len: maximum possible length for the input.
      min_scale: float: minimum frequency-scale in sine grating.
      max_scale: float: maximum frequency-scale in sine grating.

  Returns:
      output: init function returning `(1, max_len, d_feature)`
  """

  def init(key, shape, dtype=np.float32):
    """Return the sinusoid table for ``shape[-1]`` features (key/dtype ignored)."""
    del key, dtype
    d_feature = shape[-1]
    half = d_feature // 2
    positions = np.arange(0, max_len)[:, np.newaxis]
    log_step = -np.log(max_scale / min_scale) / (half - 1)
    freqs = min_scale * np.exp(np.arange(0, half) * log_step)
    angles = positions * freqs
    table = np.zeros((max_len, d_feature), dtype=np.float32)
    table[:, :half] = np.sin(angles)
    table[:, half:2 * half] = np.cos(angles)
    return jnp.array(table[np.newaxis, :, :])  # [1, max_len, d_feature]

  return init

In [None]:
class AddPositionEmbs(nn.Module):
  """Adds (optionally learned) positional embeddings to the inputs.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    decode: whether to run in single-position autoregressive mode.
  """
  config: TransformerConfig
  decode: bool = False

  @nn.compact
  def __call__(self,
               inputs,
               inputs_positions=None):
    """Applies AddPositionEmbs module.

    By default this layer uses a fixed sinusoidal embedding table. If a
    learned position embedding is desired, pass an initializer to
    posemb_init in the configuration.

    Args:
      inputs: input data, a rank-3 array `(batch_size, seq_len, emb_dim)`.
      inputs_positions: input position indices for packed sequences.

    Returns:
      output: `(bs, timesteps, in_dim)`
    """
    config = self.config
    # inputs.shape is (batch_size, seq_len, emb_dim)
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    length = inputs.shape[1]
    pos_emb_shape = (1, config.max_len, inputs.shape[-1])
    if config.posemb_init is None:
      # Use a fixed (non-learned) sinusoidal position embedding.
      # The table is recomputed on each call; no parameter is created.
      pos_embedding = sinusoidal_init(max_len=config.max_len)(None,
                                                              pos_emb_shape,
                                                              None)
    else:
      pos_embedding = self.param('pos_embedding', config.posemb_init,
                                 pos_emb_shape)
    # Embeddings for positions 0..length-1 (assumes length <= config.max_len).
    pe = pos_embedding[:, :length, :]

    # We use a cache position index for tracking decoding position.
    if self.decode:
      # Must be checked before self.variable() below creates the variable: on the
      # initializing call it does not exist yet and the full slice above is kept.
      is_initialized = self.has_variable('cache', 'cache_index')
      cache_index = self.variable('cache', 'cache_index',
                                  lambda: jnp.array(0, dtype=jnp.uint32))
      if is_initialized:
        # Single-step decoding: use only the embedding of position i, then advance i.
        i = cache_index.value
        cache_index.value = i + 1
        _, _, df = pos_embedding.shape
        pe = lax.dynamic_slice(pos_embedding,
                               jnp.array((0, i, 0)),
                               (1, 1, df))
    if inputs_positions is None:
      # normal unpacked case:
      return inputs + pe
    else:
      # for packed data we need to use known position indices:
      return inputs + jnp.take(pe[0], inputs_positions, axis=0)

In [None]:
class MlpBlock(nn.Module):
  """Transformer MLP / feed-forward block.

  Computes Dense(mlp_dim) -> ReLU -> Dropout -> Dense(out_dim) -> Dropout.
  Submodule creation order fixes the parameter names (Dense_0, Dense_1).

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    out_dim: optionally specify out dimension (defaults to the input's last dim).
  """
  config: TransformerConfig
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs):
    """Applies Transformer MlpBlock module."""
    config = self.config
    actual_out_dim = (inputs.shape[-1] if self.out_dim is None
                      else self.out_dim)
    x = nn.Dense(
        config.mlp_dim,
        dtype=config.dtype,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init)(
            inputs)
    x = nn.relu(x)
    x = nn.Dropout(rate=config.dropout_rate)(
        x, deterministic=config.deterministic)
    output = nn.Dense(
        actual_out_dim,
        dtype=config.dtype,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init)(
            x)
    output = nn.Dropout(rate=config.dropout_rate)(
        output, deterministic=config.deterministic)
    return output

In [None]:
class Encoder1DBlock(nn.Module):
  """Transformer encoder layer.

  Pre-LayerNorm residual layout:
    x = inputs + Dropout(SelfAttention(LayerNorm(inputs)))
    output = x + MlpBlock(LayerNorm(x))

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
  """
  config: TransformerConfig

  @nn.compact
  def __call__(self,
               inputs,
               encoder_mask=None):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data, a rank-3 array (asserted below).
      encoder_mask: encoder self-attention mask.

    Returns:
      output after transformer encoder block, same shape as `inputs`.
    """
    config = self.config

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(dtype=config.dtype)(inputs)
    x = nn.SelfAttention(
        num_heads=config.num_heads,
        dtype=config.dtype,
        qkv_features=config.qkv_dim,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=config.attention_dropout_rate,
        deterministic=config.deterministic)(x, encoder_mask)

    x = nn.Dropout(rate=config.dropout_rate)(
        x, deterministic=config.deterministic)
    # Residual connection around the attention sub-layer.
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=config.dtype)(x)
    y = MlpBlock(config=config)(y)

    # Residual connection around the MLP sub-layer.
    return x + y

In [None]:
class EncoderDecoder1DBlock(nn.Module):
  """Transformer encoder-decoder layer.

  Three pre-LayerNorm residual sub-layers: masked self-attention over the
  targets, cross-attention from the targets to the encoder output, and an MLP.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
  """
  config: TransformerConfig

  @nn.compact
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None):
    """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder (rank 3, asserted below)
      encoded: input data from encoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
    config = self.config

    # Decoder block.
    assert targets.ndim == 3
    x = nn.LayerNorm(dtype=config.dtype)(targets)
    # decode=config.decode enables the attention key/value cache for step-by-step decoding.
    x = nn.SelfAttention(
        num_heads=config.num_heads,
        dtype=config.dtype,
        qkv_features=config.qkv_dim,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=config.attention_dropout_rate,
        deterministic=config.deterministic,
        decode=config.decode)(x, decoder_mask)
    x = nn.Dropout(rate=config.dropout_rate)(
        x, deterministic=config.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    # Queries come from the decoder stream, keys/values from the encoder output.
    y = nn.LayerNorm(dtype=config.dtype)(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=config.num_heads,
        dtype=config.dtype,
        qkv_features=config.qkv_dim,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=config.attention_dropout_rate,
        deterministic=config.deterministic)(y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=config.dropout_rate)(
        y, deterministic=config.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=config.dtype)(y)
    z = MlpBlock(config=config)(z)

    return y + z

In [None]:
class Encoder(nn.Module):
  """Transformer Model Encoder for sequence to sequence translation.

  Embeds token ids, adds position embeddings and dropout, runs
  ``config.num_layers`` Encoder1DBlocks, and applies a final LayerNorm.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    shared_embedding: a shared embedding layer to use.
  """
  config: TransformerConfig
  shared_embedding: Any = None

  @nn.compact
  def __call__(self,
               inputs,
               inputs_positions=None,
               encoder_mask=None):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data, token ids of shape (batch, len).
      inputs_positions: input subsequence positions for packed examples.
      encoder_mask: encoder self-attention mask.

    Returns:
      output of a transformer encoder.
    """
    config = self.config
    assert inputs.ndim == 2  # (batch, len)

    # Input Embedding
    if self.shared_embedding is None:
      input_embed = nn.Embed(
          num_embeddings=config.vocab_size,
          features=config.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0))
    else:
      input_embed = self.shared_embedding
    # Ids are cast to int32 for the embedding lookup whatever dtype they arrive in.
    x = inputs.astype('int32')
    x = input_embed(x)
    x = AddPositionEmbs(
        config=config, decode=False, name='posembed_input')(
            x, inputs_positions=inputs_positions)
    x = nn.Dropout(rate=config.dropout_rate)(
        x, deterministic=config.deterministic)

    x = x.astype(config.dtype)

    # Input Encoder
    for lyr in range(config.num_layers):
      x = Encoder1DBlock(
          config=config, name=f'encoderblock_{lyr}')(x, encoder_mask)

    encoded = nn.LayerNorm(dtype=config.dtype, name='encoder_norm')(x)

    return encoded


In [None]:
class Decoder(nn.Module):
  """Transformer Model Decoder for sequence to sequence translation.

  Embeds the targets (shifted right unless ``config.decode``), adds position
  embeddings, runs ``config.num_layers`` EncoderDecoder1DBlocks against the
  encoder output, and projects to vocabulary logits.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
    shared_embedding: a shared embedding layer to use.
  """
  config: TransformerConfig
  shared_embedding: Any = None

  @nn.compact
  def __call__(self,
               encoded,
               targets,
               targets_positions=None,
               decoder_mask=None,
               encoder_decoder_mask=None):
    """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output of a transformer decoder.
    """
    config = self.config

    assert encoded.ndim == 3  # (batch, len, depth)
    assert targets.ndim == 2  # (batch, len)

    # Target Embedding
    if self.shared_embedding is None:
      output_embed = nn.Embed(
          num_embeddings=config.output_vocab_size,
          features=config.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0))
    else:
      output_embed = self.shared_embedding

    y = targets.astype('int32')
    # Teacher forcing: in full-sequence mode position t is fed token t-1.
    # In decode mode the caller feeds one token at a time, so no shift is applied.
    if not config.decode:
      y = shift_right(y)
    y = output_embed(y)
    y = AddPositionEmbs(
        config=config, decode=config.decode, name='posembed_output')(
            y, inputs_positions=targets_positions)
    y = nn.Dropout(rate=config.dropout_rate)(
        y, deterministic=config.deterministic)

    y = y.astype(config.dtype)

    # Target-Input Decoder
    for lyr in range(config.num_layers):
      y = EncoderDecoder1DBlock(
          config=config, name=f'encoderdecoderblock_{lyr}')(
              y,
              encoded,
              decoder_mask=decoder_mask,
              encoder_decoder_mask=encoder_decoder_mask)
    y = nn.LayerNorm(dtype=config.dtype, name='encoderdecoder_norm')(y)

    # Decoded Logits
    if config.logits_via_embedding:
      # Use the transpose of embedding matrix for logit transform.
      logits = output_embed.attend(y.astype(jnp.float32))
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    else:
      logits = nn.Dense(
          config.output_vocab_size,
          dtype=config.dtype,
          kernel_init=config.kernel_init,
          bias_init=config.bias_init,
          name='logitdense')(
              y)
    return logits

In [None]:
class Transformer(nn.Module):
  """Transformer Model for sequence to sequence translation.

  Token id 0 is treated as padding: every attention mask is built from
  ``inputs > 0`` / ``targets > 0``.

  Attributes:
    config: TransformerConfig dataclass containing hyperparameters.
  """
  config: TransformerConfig

  def setup(self):
    config = self.config

    # With share_embeddings, one nn.Embed is handed to both encoder and decoder.
    if config.share_embeddings:
      if config.output_vocab_size is not None:
        assert config.output_vocab_size == config.vocab_size, (
            "can't share embedding with different vocab sizes.")
      self.shared_embedding = nn.Embed(
          num_embeddings=config.vocab_size,
          features=config.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0))
    else:
      self.shared_embedding = None

    self.encoder = Encoder(
        config=config, shared_embedding=self.shared_embedding)
    self.decoder = Decoder(
        config=config, shared_embedding=self.shared_embedding)

  def encode(self,
             inputs,
             inputs_positions=None,
             inputs_segmentation=None):
    """Applies Transformer encoder-branch on the inputs.

    Args:
      inputs: input data.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      encoded feature array from the transformer encoder.
    """
    config = self.config
    # Make padding attention mask.
    encoder_mask = nn.make_attention_mask(
        inputs > 0, inputs > 0, dtype=config.dtype)
    # Add segmentation block-diagonal attention mask if using segmented data.
    if inputs_segmentation is not None:
      encoder_mask = nn.combine_masks(
          encoder_mask,
          nn.make_attention_mask(
              inputs_segmentation,
              inputs_segmentation,
              jnp.equal,
              dtype=config.dtype))
    return self.encoder(
        inputs,
        inputs_positions=inputs_positions,
        encoder_mask=encoder_mask)

  def decode(self,
             encoded,
             inputs,  # only needed for masks
             targets,
             targets_positions=None,
             inputs_segmentation=None,
             targets_segmentation=None):
    """Applies Transformer decoder-branch on encoded-input and target.

    Args:
      encoded: encoded input data from encoder.
      inputs: input data (only needed for masking).
      targets: target data.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
        Must be given whenever inputs_segmentation is given (both are used together below).

    Returns:
      logits array from transformer decoder.
    """
    config = self.config

    # Make padding attention masks.
    if config.decode:
      # for fast autoregressive decoding only a special encoder-decoder mask is used
      decoder_mask = None
      encoder_decoder_mask = nn.make_attention_mask(
          jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype)
    else:
      # Padding mask combined with a causal mask so position t cannot attend to later targets.
      decoder_mask = nn.combine_masks(
          nn.make_attention_mask(targets > 0, targets > 0, dtype=config.dtype),
          nn.make_causal_mask(targets, dtype=config.dtype))
      encoder_decoder_mask = nn.make_attention_mask(
          targets > 0, inputs > 0, dtype=config.dtype)

    # Add segmentation block-diagonal attention masks if using segmented data.
    if inputs_segmentation is not None:
      decoder_mask = nn.combine_masks(
          decoder_mask,
          nn.make_attention_mask(
              targets_segmentation,
              targets_segmentation,
              jnp.equal,
              dtype=config.dtype))
      encoder_decoder_mask = nn.combine_masks(
          encoder_decoder_mask,
          nn.make_attention_mask(
              targets_segmentation,
              inputs_segmentation,
              jnp.equal,
              dtype=config.dtype))
    logits = self.decoder(
        encoded,
        targets,
        targets_positions=targets_positions,
        decoder_mask=decoder_mask,
        encoder_decoder_mask=encoder_decoder_mask)
    return logits.astype(self.config.dtype)

  def __call__(self,
               inputs,
               targets,
               inputs_positions=None,
               targets_positions=None,
               inputs_segmentation=None,
               targets_segmentation=None):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data.
      targets: target data.
      inputs_positions: input subsequence positions for packed examples.
      targets_positions: target subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.

    Returns:
      logits array from full transformer.
    """
    encoded = self.encode(inputs,
                          inputs_positions=inputs_positions,
                          inputs_segmentation=inputs_segmentation)

    return self.decode(encoded,
                       inputs,  # only used for masks
                       targets,
                       targets_positions=targets_positions,
                       inputs_segmentation=inputs_segmentation,
                       targets_segmentation=targets_segmentation)

# Run

In [None]:
# Default hyperparameters (vocabulary size taken from src_word_index).
config = TransformerConfig()

In [None]:
# Initialize the model with dummy (batch=128, length=15) token-id arrays and show
# the parameter shapes. Token ids are integers, so the dummies are int32. The
# model casts inputs to int32 itself, so the initialized parameters are unchanged,
# but a float dtype such as bfloat16 could not hold real ids exactly.
x = jnp.ones(shape=(128, 15), dtype=jnp.int32)
y = jnp.ones(shape=(128, 15), dtype=jnp.int32)
rng, init_rng, dropout_rng = jax.random.split(random.PRNGKey(42), 3)
model = Transformer(config)
initial_variables = model.init({'params': init_rng, 'dropout': dropout_rng},
                               x,
                               y,
                               )
# jax.tree_util.tree_map is the non-deprecated spelling of jax.tree_map.
jax.tree_util.tree_map(lambda leaf: leaf.shape, initial_variables)

FrozenDict({
    params: {
        decoder: {
            encoderdecoder_norm: {
                bias: (512,),
                scale: (512,),
            },
            encoderdecoderblock_0: {
                LayerNorm_0: {
                    bias: (512,),
                    scale: (512,),
                },
                LayerNorm_1: {
                    bias: (512,),
                    scale: (512,),
                },
                LayerNorm_2: {
                    bias: (512,),
                    scale: (512,),
                },
                MlpBlock_0: {
                    Dense_0: {
                        bias: (2048,),
                        kernel: (512, 2048),
                    },
                    Dense_1: {
                        bias: (512,),
                        kernel: (2048, 512),
                    },
                },
                MultiHeadDotProductAttention_0: {
                    key: {
                        kernel: (512, 8, 64),
  

In [None]:
# Shape of every leaf in the initialized variable tree.
jax.tree_util.tree_map(jnp.shape, initial_variables)

FrozenDict({
    params: {
        decoder: {
            encoderdecoder_norm: {
                bias: (512,),
                scale: (512,),
            },
            encoderdecoderblock_0: {
                LayerNorm_0: {
                    bias: (512,),
                    scale: (512,),
                },
                LayerNorm_1: {
                    bias: (512,),
                    scale: (512,),
                },
                LayerNorm_2: {
                    bias: (512,),
                    scale: (512,),
                },
                MlpBlock_0: {
                    Dense_0: {
                        bias: (2048,),
                        kernel: (512, 2048),
                    },
                    Dense_1: {
                        bias: (512,),
                        kernel: (2048, 512),
                    },
                },
                MultiHeadDotProductAttention_0: {
                    key: {
                        kernel: (512, 8, 64),
  

In [None]:
# Without exclude_methods this call raises ValueError. Per the notes below, it
# fails because Transformer exposes __call__, encode and decode. The error is
# caught and shown so Restart & Run All continues; the next cells apply the fix.
try:
  nn.tabulate(model, rngs={'params': init_rng, 'dropout': dropout_rng})(x, y)
except ValueError as err:
  print(f'ValueError: {err}')

ValueError: ignored

* exclude_methods가 필요하다!

* 어떻게 필요하냐면...
* multiple intermediates라고 한다. `['__call__', 'decode', 'encode']`가 있고, `'__call__'`을 제외한 다른 intermediates를 exclude_methods에 넣어야 한다. 
* Transformer class 안에는 `encode`, `decode`, `__call__` 세 메소드가 있는데, 우리는 `__call__`만 필요하므로 나머지 메소드들은 요약 표에서 제외하는 것이다.

```python
exclude_methods=['encode','decode']
```

In [None]:
# exclude_methods drops encode/decode so only __call__ is summarized.
# nn.tabulate returns the table as a string; print it so it renders as text
# rather than as an escaped string repr.
print(nn.tabulate(model, rngs={'params': init_rng, 'dropout': dropout_rng},
                  exclude_methods=['encode', 'decode'])(x, y))

'\n\n'

In [None]:
class QADataset(data.Dataset):
    """Map-style dataset over aligned question/answer token-id sequences.

    ``vocab_size`` and ``max_len`` are only stored as attributes. Indexing
    returns the raw ``(src[idx], tgt[idx])`` pair unchanged.
    """

    def __init__(self, vocab_size, max_len, src, tgt):
        super().__init__()
        self.vocab_size, self.max_len = vocab_size, max_len
        self.src, self.tgt = src, tgt

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]

In [None]:
def numpy_collate(batch):
    """Collate a DataLoader batch into NumPy arrays, recursing into tuples/lists.

    - ndarray samples are stacked along a new leading axis;
    - tuple/list samples are transposed and each field is collated separately,
      giving a list with one entry per field;
    - anything else is converted with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)


# The DataLoader fetches one example at a time, so the dataset holds host-side
# NumPy copies. Indexing a jnp device array per example costs a separate JAX
# dispatch each time; NumPy rows are also stacked directly by numpy_collate.
# The batches contain the same int32 values as before.
qa_train_loader = data.DataLoader(
    QADataset(vocab_size=len(src_word_index),
              max_len=question_tensor.shape[1],
              src=np.asarray(question_tensor),
              tgt=np.asarray(answer_tensor)),
    batch_size=128,
    shuffle=True,
    drop_last=True,
    collate_fn=numpy_collate)

In [None]:
def rsqrt_schedule(
    init_value: float,
    shift: int = 0,
):
  """Applies a reverse square-root schedule.

  With ``shift > 0`` the learning rate is
  ``init_value * sqrt(shift / (count + shift))``. It equals ``init_value`` at
  ``count == 0`` and then decays like ``1 / sqrt(count)``. Shifting the rsqrt
  schedule makes it less steep in the beginning (close to 0).

  With ``shift == 0`` (the default) it is the plain ``init_value / sqrt(count)``.
  ``count`` is clamped to at least 1 so step 0 does not divide by zero. The
  shifted formula cannot be used here: its ``sqrt(shift)`` factor would make
  every value 0.

  Args:
    init_value: Base learning rate (before applying the rsqrt schedule).
    shift: How many steps the rsqrt should be shifted (non-negative).

  Returns:
    A schedule `count -> learning_rate`.
  """

  def schedule(count):
    if shift > 0:
      return init_value * (count + shift)**-.5 * shift**.5
    return init_value * jnp.maximum(count, 1)**-.5

  return schedule


def create_learning_rate_schedule(learning_rate: float, warmup_steps: int):
  """Linear warmup from 0 to ``learning_rate``, then reverse-square-root decay.

  The two pieces are joined by optax.join_schedules at step ``warmup_steps``.
  """
  warmup = optax.linear_schedule(
      init_value=0, end_value=learning_rate, transition_steps=warmup_steps)
  decay = rsqrt_schedule(init_value=learning_rate, shift=warmup_steps)
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])


def compute_weighted_cross_entropy(logits,
                                   targets,
                                   weights=None,
                                   label_smoothing=0.0):
  """Summed, label-smoothed, optionally weighted cross entropy.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length] (e.g. 0 for padding).
   label_smoothing: probability mass taken from the true class and spread
     uniformly over the other ``num_classes - 1`` classes.

  Returns:
    Tuple ``(loss_sum, normalizing_factor)``. The factor is ``weights.sum()``
    when weights are given, else the number of elements in ``targets``.
    Dividing the two gives the mean loss.

  Raises:
    ValueError: if ``logits`` does not have exactly one more dimension than
      ``targets``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(targets.shape)))
  num_classes = logits.shape[-1]
  on_value = 1.0 - label_smoothing
  off_value = (1.0 - on_value) / (num_classes - 1)
  # Entropy of the smoothed target distribution. Subtracting it makes the
  # per-token loss a KL divergence whose minimum is 0; the 1e-20 avoids
  # log(0) when label_smoothing == 0.
  target_entropy = -(
      on_value * jnp.log(on_value) +
      (num_classes - 1) * off_value * jnp.log(off_value + 1e-20))
  soft_targets = common_utils.onehot(
      targets, num_classes, on_value=on_value, off_value=off_value)

  per_token = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  per_token = per_token - target_entropy

  if weights is None:
    return per_token.sum(), np.prod(targets.shape)
  return (per_token * weights).sum(), weights.sum()


def compute_weighted_accuracy(logits, targets, weights=None):
  """Count (weighted) correct argmax predictions.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length]; train_step passes 1.0
     for non-zero target tokens and 0.0 for padding (id 0).

  Returns:
    Tuple ``(num_correct, normalizing_factor)``. ``num_correct`` is the
    (weighted) number of positions where ``argmax(logits) == targets``.
    ``normalizing_factor`` is ``weights.sum()`` if weights are given,
    otherwise the number of positions. Their ratio is the accuracy.

  Raises:
    ValueError: if ``logits`` does not have exactly one more dimension than
      ``targets``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(targets.shape)))
  correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    correct = correct * weights
    normalizing_factor = weights.sum()

  return correct.sum(), normalizing_factor


def compute_metrics(logits, labels, weights, label_smoothing=0.0):
  """Summed loss/accuracy/denominator, all-reduced over the "batch" axis.

  Must run inside a function mapped with ``axis_name="batch"`` (e.g. via
  ``jax.pmap``), because ``jax.lax.psum`` needs that named axis. Every
  device receives the same totals.
  """
  loss_sum, denominator = compute_weighted_cross_entropy(
      logits, labels, weights, label_smoothing)
  num_correct, _ = compute_weighted_accuracy(logits, labels, weights)
  per_device = {
      "loss": loss_sum,
      "accuracy": num_correct,
      "denominator": denominator,
  }
  return jax.lax.psum(per_device, axis_name="batch")


# Primary training / eval / decode step functions.
# -----------------------------------------------------------------------------


def train_step(state,
               batch,
               config,
               learning_rate_fn,
               label_smoothing=0.0,
               dropout_rng=None):
  """Run one data-parallel training step; meant to be wrapped in ``jax.pmap``.

  Args:
    state: this device's replica of the flax TrainState.
    batch: ``(inputs, targets)`` token-id arrays for this device's shard;
      target id 0 is treated as padding.
    config: model config forwarded to ``Transformer``.
    learning_rate_fn: schedule ``step -> lr``; used only for the logged
      ``learning_rate`` metric.
    label_smoothing: smoothing applied to the training loss. The logged
      metrics are computed without smoothing.
    dropout_rng: dropout PRNG key. It is required in practice despite the
      ``None`` default, because ``jax.random.fold_in`` needs a key.

  Returns:
    ``(new_state, metrics, grads, logits, targets)``. ``grads`` are averaged
    over the ``"batch"`` axis; ``logits``/``targets`` are returned for
    debugging.
  """
  inputs, targets = batch
  # 1.0 for real target tokens, 0.0 for padding (id 0).
  weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
  # A fresh dropout key per step, derived from the per-device base key.
  step_rng = jax.random.fold_in(dropout_rng, state.step)
  step = state.step

  def loss_fn(params):
    """Per-token mean of the smoothed loss; logits are the aux output."""
    logits = Transformer(config).apply(
        params,
        inputs,
        targets,
        inputs_positions=None,
        targets_positions=None,
        inputs_segmentation=None,
        targets_segmentation=None,
        rngs={"dropout": step_rng})
    loss_sum, weight_sum = compute_weighted_cross_entropy(
        logits, targets, weights, label_smoothing)
    return loss_sum / weight_sum, logits

  (_, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  # All-reduce so every device applies the same averaged gradients.
  grads = jax.lax.pmean(grads, axis_name="batch")
  new_state = state.apply_gradients(grads=grads)

  metrics = compute_metrics(logits, targets, weights)
  metrics["learning_rate"] = learning_rate_fn(step)
  return new_state, metrics, grads, logits, targets

In [None]:
# NOTE(review): this cell redefines compute_weighted_accuracy with the same
# behaviour as the earlier definition; the duplicate cell can be deleted.
def compute_weighted_accuracy(logits, targets, weights=None):
  """Count (weighted) correct argmax predictions.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length]; train_step passes 1.0
     for non-zero target tokens and 0.0 for padding (id 0).

  Returns:
    Tuple ``(num_correct, normalizing_factor)``. ``num_correct`` is the
    (weighted) number of positions where ``argmax(logits) == targets``.
    ``normalizing_factor`` is ``weights.sum()`` if weights are given,
    otherwise the number of positions. Their ratio is the accuracy.

  Raises:
    ValueError: if ``logits`` does not have exactly one more dimension than
      ``targets``.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(targets.shape)))
  correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = np.prod(logits.shape[:-1])
  if weights is not None:
    correct = correct * weights
    normalizing_factor = weights.sum()

  return correct.sum(), normalizing_factor

# Run


In [None]:
# NOTE(review): `learning_rate_fn` is not used by the training cells below;
# p_train_step is built with `lr_schedule`, which is the schedule the
# optimizer actually follows.
learning_rate_fn = create_learning_rate_schedule(
    learning_rate=config.learning_rate,
    warmup_steps=config.warmup_steps,
)

In [None]:
epochs = 100  # passes over qa_train_loader; also sets the LR decay length below

In [None]:
# Linear warm-up to a peak LR of 0.02 over 100 steps, then cosine decay to 0.
# `decay_steps` is the total schedule length (warm-up included): the whole run.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.02,
    warmup_steps=100,
    decay_steps=epochs * len(qa_train_loader),
    end_value=0.0,
)
# Clip gradients to a global norm of 1 before the Adam update.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(lr_schedule),
)

In [None]:
# Map train_step over all local devices. "batch" is the axis name used by the
# psum/pmean collectives inside it. Argument 0 (the TrainState) is donated so
# its buffers can be reused for the updated state.
step_fn = None  # placeholder removed below; keeps no extra global
del step_fn
p_train_step = jax.pmap(
    functools.partial(train_step,
                      config=config,
                      learning_rate_fn=lr_schedule,
                      label_smoothing=config.label_smoothing),
    axis_name="batch",
    donate_argnums=(0,),
)

In [None]:
# Bundle the Flax apply function, the initial variables and the optax
# optimizer (plus its state) into one TrainState.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=initial_variables,
    tx=optimizer,
)


* pmap에 넣으려면 state를 `flax.jax_utils.replicate`로 각 로컬 디바이스(여기서는 8개)마다 하나씩 복제해야 한다. state를 나누는 것이 아니라, 같은 state의 사본을 디바이스 수만큼 만드는 것이다.

In [None]:
# Put an identical copy of the TrainState on every local device (adds a
# leading device axis), as the pmapped step expects.
state = flax.jax_utils.replicate(state, devices=jax.local_devices())

* dropout에 필요한 `rng` (랜덤 키들)도 8개로 나누어 준다.

In [None]:
# One dropout PRNG key per local device, so pmap can map the keys along axis 0.
dropout_rngs = jax.random.split(rng, num=jax.local_device_count())

* pmap은 모든 인자를 축 0을 따라 나누므로, 축이 없는 스칼라(shape `()`) 인자를 넘기면 다음 에러가 발생한다: `pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())`

* 따라서 dropout rng를 포함해 pmap에 넘기는 모든 입력은 첫 번째 축의 크기가 디바이스 수(8)가 되도록 나누어야 한다. 배치는 `common_utils.shard`로, rng는 `jax.random.split`으로 나눈다.

In [None]:
# Debug: run one pmapped step on a single batch and inspect its outputs.
# NOTE(review): this performs a real update on `state` (its old buffers are
# donated), so the training loop below starts one optimizer step later.
train_loader = iter(qa_train_loader)
# Host batch -> NumPy -> (num_devices, per_device_batch, ...) shards.
batch = common_utils.shard(
    jax.tree_util.tree_map(np.asarray, next(train_loader))
)
state, metrics, grads, logits, targets = p_train_step(
    state, batch, dropout_rng=dropout_rngs
)

In [None]:
logits.shape  # (devices, per-device batch, target length, vocab size)

(8, 16, 15, 7690)

In [None]:
targets.shape  # (devices, per-device batch, target length)

(8, 16, 15)

In [None]:
targets  # right-padded with 0, the id that train_step masks out via its weights

ShardedDeviceArray([[[ 290,   73, 1184, ...,    0,    0,    0],
                     [ 764,    6, 1928, ...,    0,    0,    0],
                     [   3,  397,  216, ...,    0,    0,    0],
                     ...,
                     [3563,  859,  721, ...,    0,    0,    0],
                     [ 662,    3,   21, ...,    0,    0,    0],
                     [  48, 2046,  352, ...,    0,    0,    0]],

                    [[  64,    3,   39, ...,    0,    0,    0],
                     [ 496,  779,   24, ...,    0,    0,    0],
                     [ 658,   19,    3, ...,    0,    0,    0],
                     ...,
                     [ 223,   24,  553, ...,    0,    0,    0],
                     [1096,   12,   13, ...,    0,    0,    0],
                     [ 741,    2,  885, ...,    0,    0,    0]],

                    [[ 598,    2,  543, ...,    0,    0,    0],
                     [ 176,  336,    2, ...,    0,    0,    0],
                     [  51,   96,   57, ...,    

In [None]:
# Padding mask as in train_step: 1.0 for real target tokens, 0.0 for id 0.
weights = (targets > 0).astype(jnp.float32)
print(weights.shape)

(8, 16, 15)


In [None]:
jnp.argmax(logits, axis=-1).shape  # predicted token ids; same shape as targets

(8, 16, 15)

In [None]:
# Reproduce compute_weighted_accuracy by hand to inspect intermediate values.
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)  # True where prediction == target
normalizing_factor = np.prod(logits.shape[:-1])  # unweighted count; overwritten below
print(loss)
loss = loss * weights  # zero out padding positions
normalizing_factor = weights.sum()  # number of non-padding target tokens
print(loss.shape)

[[[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]]

 ...

 [[False False False ... False False False]
  [False False False ... False False False]
  [False False False ... False False False]
  ...
  [False False False ... False False Fal

In [None]:
loss.sum()  # number of correct non-padding predictions in this debug batch

DeviceArray(0., dtype=float32)

In [None]:
8*16*15  # devices * per-device batch * target length = total target positions

1920

In [None]:
# Smoothed cross entropy on the debug batch. logits/targets still carry the
# leading device axis; that works because the function only checks
# logits.ndim == targets.ndim + 1 and reads the class count from the last axis.
loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights,
                                                    label_smoothing=0.1)

## train

In [None]:
# NOTE(review): `frozen_dict_gradient_flatten` is defined further down in this
# notebook; on a fresh kernel run that cell first (or move it above this one).
for epoch in range(epochs):
    # Iterate the DataLoader directly: one shuffled pass per epoch. The old
    # loop called iter(qa_train_loader) for every batch, so each step drew the
    # first batch of a brand-new shuffle (sampling with replacement) and paid
    # the loader start-up cost every step.
    for raw_batch in tqdm(qa_train_loader):
        # To NumPy, then split axis 0 into (num_devices, per_device_batch, ...).
        batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, raw_batch))
        # p_train_step returns (state, metrics, grads, logits, targets); the old
        # 3-name unpack raised ValueError. logits/targets are debug-only here.
        state, metrics, grads, _, _ = p_train_step(state, batch, dropout_rng=dropout_rngs)
        # Metrics were psum'ed inside the step, so every device holds the same
        # totals; the mean over the device axis just drops that axis.
        metrics = {key: np.array(np.mean(value)) for key, value in metrics.items()}
        # Per-token averages are easier to read than the raw sums.
        denominator = max(float(metrics["denominator"]), 1.0)
        metrics["mean_loss"] = metrics["loss"] / denominator
        metrics["mean_accuracy"] = metrics["accuracy"] / denominator
        # Gradients were pmean'ed, so all device copies are identical: log one.
        flatten_grad = frozen_dict_gradient_flatten(flax.jax_utils.unreplicate(grads))
        wandb_histo = {layer: wandb.Histogram(layer_grad)
                       for layer, layer_grad in flatten_grad.items()}
        # A single wandb.log per training step keeps scalars and histograms on
        # the same wandb step.
        wandb.log({**metrics, **wandb_histo})
    print(f'TRAIN ({epoch}/{epochs}): metrics : {metrics}')
    

100%|██████████| 258/258 [1:24:59<00:00, 19.76s/it]


TRAIN (0/100): metrics : {'accuracy': array(246., dtype=float32), 'denominator': array(985., dtype=float32), 'learning_rate': array(0.01999814, dtype=float32), 'loss': array(4480.2627, dtype=float32)}


  1%|          | 3/258 [01:11<1:41:15, 23.82s/it]

In [None]:
# Convert the (device-resident) metric values to NumPy in place, then show them.
metrics.update({name: np.array(value) for name, value in metrics.items()})
print(metrics)

#  Flax용 wandb gradients

## wandb.keras.WandbCallback

https://github.com/wandb/wandb/blob/0a3b035d0fb206570660275503c8b72f8d7b4399/wandb/integration/keras/keras.py#L934

* 케라스용 wandb 콜백(`WandbCallback`)이 그레디언트를 어떻게 시각화하는지 알아보자.

In [None]:
def _log_gradients(self):
    """Build wandb gradient histograms, as wandb's Keras ``WandbCallback`` does.

    This is a reference copy of the callback method (``self`` is the callback)
    and is not called in this notebook. It fits the gradient-accumulator
    model silently, then pairs each trainable weight with its accumulated
    array.

    Returns:
        Dict mapping ``"gradients/<weight name>.gradient"`` to a
        ``wandb.Histogram`` of that weight's accumulated gradient.
    """
    # verbose=0 keeps Keras from printing progress for this auxiliary fit.
    self._grad_accumulator_model.fit(
        self._training_data_x,
        self._training_data_y,
        verbose=0,
        callbacks=[self._grad_accumulator_callback],
    )

    pairs = zip(self.model.trainable_weights, self._grad_accumulator_callback.grads)
    return {
        "gradients/" + weight.name.split(":")[0] + ".gradient": wandb.Histogram(grad)
        for weight, grad in pairs
    }

In [None]:
import tensorflow as tf

In [None]:
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Keras callback that sums trainable-weight values after every batch.

    Copied from wandb's Keras integration. There it is paired with a custom
    optimizer that writes the gradients into the weights on each step, so the
    running sums are gradient sums. With an ordinary optimizer (such as the
    plain SGD used in the demo cell below) the sums are of updated weight
    values, not gradients. After each batch the weights are restored to their
    values at ``set_model`` time, so fitting leaves the model unchanged.
    """

    def set_model(self, model):
        super().set_model(model)
        # Snapshot of all weights, restored after every batch.
        self.og_weights = model.get_weights()
        # One zeroed float64 accumulator per trainable variable.
        self.grads = [np.zeros(tuple(var.shape)) for var in model.trainable_weights]

    def on_batch_end(self, batch, logs=None):
        for acc, var in zip(self.grads, self.model.trainable_weights):
            acc += var.numpy()  # in-place add into the NumPy accumulator
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        """Return independent copies of the accumulated arrays."""
        return [acc.copy() for acc in self.grads]

In [None]:
_grad_accumulator_callback = _GradAccumulatorCallback()  # passed to model.fit in the next cell

In [None]:
# Small three-layer Keras model, used only to see what the accumulator
# callback collects. NOTE(review): this rebinds the notebook globals `model`
# and `optimizer` (used earlier for the Flax setup) and `loss` (used earlier
# for debugging).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, activation="relu", name="layer1"),
    tf.keras.layers.Dense(3, activation="relu", name="layer2"),
    tf.keras.layers.Dense(3, name="layer3"),
])
# Random 3x3 input batch with values in [0, 5).
x = tf.random.uniform((3, 3), maxval=5)
loss = tf.keras.losses.MeanSquaredError()
a = [0, 1, 2]
# The same target row for all three samples.
y_true = tf.convert_to_tensor([a] * 3)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

model.build(input_shape=(3, 3))
model.compile(loss=loss, optimizer=optimizer)

model.fit(x, y_true, verbose=0, callbacks=[_grad_accumulator_callback])




<keras.callbacks.History at 0x7f4a98ec3e90>

In [None]:
# Scratch check: rebinds y_true (the tensor passed to model.fit above) to a plain list.
y_true = [0,1,3,4]
print(y_true)

[0, 1, 3, 4]


In [None]:
_grad_accumulator_callback.grads  # per-variable sums from the fit above (kernel and bias of each Dense layer)

[array([[ 0.03186357, -0.71543443],
        [-0.88569725,  0.74210143],
        [-0.55287474,  0.83285022]]),
 array([ 0.        , -0.01328947]),
 array([[-0.5680694 ,  0.26485944, -0.36831582],
        [ 0.43944037, -0.10591018, -0.66314888]]),
 array([-0.025164,  0.      ,  0.      ]),
 array([[-0.17332166,  0.63437182, -0.9452166 ],
        [ 0.62420583, -0.67187238,  0.50105071],
        [ 0.71027231, -0.93712902, -0.81261611]]),
 array([ 0.00202704, -0.00061873,  0.02465536])]

In [None]:
grad = _grad_accumulator_callback.grads  # the same list object, not a copy (get_grads() would copy)

In [None]:
# One wandb histogram per accumulated array, keyed by its position as a string.
metric = {str(idx): wandb.Histogram(arr) for idx, arr in enumerate(grad)}

In [None]:
wandb.log(metric)  # log the histograms built above to the active wandb run

* 각 층의 gradient 배열을 `wandb.Histogram(grad)`로 감싸서, 층 이름을 키로 하는 딕셔너리에 담는다.

```python
>>> print(metric)

{'레이어 이름' : wandb.Histogram(grad),
'레이어 이름' : wandb.Histogram(grad)....}
```

* 생성한 metric 딕셔너리를 wandb.log에 담으면 끝.

In [None]:
# Name-keyed version of the earlier histogram cell, following wandb's Keras
# callback. It pairs the Keras model's variables with the accumulator's
# arrays. The bare `weights`/`grads` globals at this point are the Flax debug
# arrays, which have no `.name` and do not line up with the Keras variables.
# Results go into `metric` (initialized here), not the unrelated `metrics` dict.
metric = {}
for weight, grad in zip(model.trainable_weights, _grad_accumulator_callback.grads):
    metric["gradients/" + weight.name.split(":")[0] + ".gradient"] = wandb.Histogram(grad)

## Flax의 그레디언트는?

In [None]:
def frozen_dict_gradient_flatten(some_dict, root_name='Main'):
  """Flatten a nested (Frozen)dict into a single-level dict of leaves.

  Keys become "/"-joined paths prefixed with ``root_name``. For example,
  ``{'params': {'decoder': {'bias': g}}}`` becomes
  ``{'Main/params/decoder/bias': g}``. Any value with a ``keys()`` method is
  treated as a sub-tree and descended into; every other value is a leaf.

  Args:
    some_dict: mapping to flatten, e.g. the gradient tree returned by the
      training step.
    root_name: prefix for every path (default ``'Main'``).

  Returns:
    Dict from path string to leaf value, in depth-first key order.
  """
  flatten_dict = {}

  def find_keys(node, sub_key):
    for key in node.keys():
      value = node[key]
      path = f'{sub_key}/{key}'
      # Test for sub-trees explicitly. The previous bare `except` also
      # swallowed real errors (and KeyboardInterrupt) and could record a
      # half-flattened sub-tree as a leaf.
      if hasattr(value, 'keys'):
        find_keys(value, path)
      else:
        flatten_dict[path] = value

  find_keys(some_dict, root_name)
  return flatten_dict

* `train_step`의 pmap한 `p_train_step`함수에서 gradient를 전달하는 방식은 다음과 같다.

```python
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grads = grad_fn(state.params)
grads = jax.lax.pmean(grads, axis_name="batch")
new_state = state.apply_gradients(grads=grads)
metrics = compute_metrics(logits, targets, weights)
metrics["learning_rate"] = learning_rate_fn(step)
```

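`jax.value_and_grad`에 손실 함수(`loss_fn`)를 `has_aux=True`와 함께 넣어서 새로운 `grad_fn`을 만든다. `has_aux=True`는 `loss_fn`이 `(loss, logits)`처럼 손실과 보조 출력을 함께 반환한다는 뜻이다. 그래디언트는 첫 번째 값(손실)에 대해서만 계산된다.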

이후 `state.params`를 `grad_fn`에 넣어 그래디언트 `grads`를 구한다. 그리고 이를 `state.apply_gradients`로 적용해 새로운 state를 만든다.

`jax.lax.pmean`은 8개 디바이스가 각자의 배치 조각으로 계산한 gradient를 `batch` 축에 대해 평균 낸다. 그 결과 모든 디바이스가 같은 (평균) gradient를 갖게 된다.

`grads` 변수는 FrozenDict 구조로 되어 있다. 따라서 이를 평탄화(flatten)하면 `wandb.log`에 넣을 수 있을 것으로 보인다.

In [None]:
flatten_grad = frozen_dict_gradient_flatten(grads)  # NOTE(review): redundant; the next cell recomputes it

In [None]:
# Flatten the nested gradient tree into "Main/..." paths and wrap each leaf
# array in a wandb histogram. The comprehension names are distinct from the
# outer `grads` to avoid confusion.
flatten_grad = frozen_dict_gradient_flatten(grads)
wandb_histo = {path: wandb.Histogram(leaf) for path, leaf in flatten_grad.items()}
print(wandb_histo)

{'Main/params/decoder/encoderdecoder_norm/bias': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a986d2510>, 'Main/params/decoder/encoderdecoder_norm/scale': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a986aa8d0>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_0/bias': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a99586f10>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_0/scale': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a9a86ea10>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_1/bias': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a9a86ee50>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_1/scale': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a99b19ad0>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_2/bias': <wandb.sdk.data_types.histogram.Histogram object at 0x7f4a99b192d0>, 'Main/params/decoder/encoderdecoderblock_0/LayerNorm_2/scale': <wandb.sdk.data_types.histogram.Histogram ob

In [None]:
wandb.finish()  # close the wandb run and flush any pending logs