In [None]:
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
#@title Import
import os
# Pickle est un module permettant la (dé)sérialisation
import pickle
import urllib.request
from typing import Union

# Haiku permet aux utilisateurs d'utiliser des modèles de POO tout en permettant un accès complet aux transformations de fonctions pures de JAX pour les reseaux neuronaux.
import haiku as hk
# JAX est concue pour le calcul numérique à haute performance, notamment pour la recherche en apprentissage automatique.
import jax
import jax.numpy as jnp
# NumPy ajoute un support pour les grands tableaux multidimensionnels et les matrices, ainsi qu'une grande collection de fonctions mathématiques pour opérer sur ces tableaux.
import numpy as np

import perceiver
import position_encoding
import io_processors
import bytes_tokenizer

ModuleNotFoundError: No module named 'perceiver'

In [None]:
#@title Load parameters from checkpoint
CHECKPOINT_URL = (
    "https://storage.googleapis.com/perceiver_io/"
    "language_perceiver_io_bytes.pickle")
CHECKPOINT_PATH = "language_perceiver_io_bytes.pickle"

# Fetch the checkpoint only when it is missing, so the notebook also runs from
# a fresh directory (Restart & Run All). Download to a temporary name first so
# an interrupted download never leaves a truncated file at CHECKPOINT_PATH.
if not os.path.exists(CHECKPOINT_PATH):
  tmp_path = CHECKPOINT_PATH + ".part"
  urllib.request.urlretrieve(CHECKPOINT_URL, tmp_path)
  os.replace(tmp_path, CHECKPOINT_PATH)

# Deserialize the pretrained model parameters; they are passed as the first
# argument of the transformed `apply_perceiver` function.
# NOTE(security): unpickling can execute arbitrary code; only load checkpoints
# obtained from a trusted source such as the URL above.
with open(CHECKPOINT_PATH, "rb") as f:
  params = pickle.load(f)

In [None]:
#@title Model config
# Width of the token embeddings (the `embed_dim` of the input embedding).
D_MODEL = 768
# Channel count used for the latent array (`num_z_channels`).
D_LATENTS = 1280
# Fixed sequence length: inputs are padded up to this many tokens.
MAX_SEQ_LEN = 2048

# Keyword arguments forwarded to perceiver.PerceiverEncoder.
encoder_config = {
    "num_self_attends_per_block": 26,
    "num_blocks": 1,
    "z_index_dim": 256,
    "num_z_channels": D_LATENTS,
    "num_self_attend_heads": 8,
    "num_cross_attend_heads": 8,
    "qk_channels": 8 * 32,
    "v_channels": D_LATENTS,
    "use_query_residual": True,
    "cross_attend_widening_factor": 1,
    "self_attend_widening_factor": 1,
}

# Keyword arguments forwarded to perceiver.BasicDecoder. One output query is
# produced per input position (output_index_dims == MAX_SEQ_LEN).
decoder_config = {
    "output_num_channels": D_LATENTS,
    "position_encoding_type": 'trainable',
    "output_index_dims": MAX_SEQ_LEN,
    "num_z_channels": D_LATENTS,
    "qk_channels": 8 * 32,
    "v_channels": D_MODEL,
    "num_heads": 8,
    "final_project": False,
    "use_query_residual": False,
    "trainable_position_encoding_kwargs": {"num_channels": D_MODEL},
}

# The tokenizer is just UTF-8 encoding (with an offset): each byte of the UTF-8
# encoded text becomes one integer token. Its `mask_token`, `pad_token` and
# `vocab_size` are used by the cells below (the offset presumably reserves the
# low ids for these special tokens — TODO confirm in bytes_tokenizer).
tokenizer = bytes_tokenizer.BytesTokenizer()

In [None]:
#@title Decoding Perceiver Model
def apply_perceiver(
    inputs: jnp.ndarray, input_mask: jnp.ndarray) -> jnp.ndarray:
  """Runs a forward pass on the Perceiver.

  Creates Haiku modules, so it must run under `hk.transform`; at the end of
  this cell the name is rebound to the transformed `apply` function.

  Args:
    inputs: input bytes, an int array of shape [B, T] with T == MAX_SEQ_LEN.
    input_mask: Array with the same shape as `inputs` indicating which entries
      are valid and which are padding. A truthy value indicates that the entry
      is valid.

  Returns:
    The output logits, an array of shape [B, T, vocab_size].
  """
  assert inputs.shape[1] == MAX_SEQ_LEN, (
      f"expected sequence length {MAX_SEQ_LEN}, got {inputs.shape[1]}")

  # Token embedding table: one D_MODEL-wide vector per vocabulary entry. The
  # same matrix is reused at the end to map outputs back onto the vocabulary.
  embedding_layer = hk.Embed(
      vocab_size=tokenizer.vocab_size,
      embed_dim=D_MODEL)
  embedded_inputs = embedding_layer(inputs)

  batch_size = embedded_inputs.shape[0]

  # Trainable position encoding covering all MAX_SEQ_LEN positions (not just
  # the masked span); it is added to every token embedding so the model can
  # tell positions apart. Uses the directly imported `position_encoding`
  # module, like `io_processors` below, instead of reaching through
  # `perceiver`.
  input_pos_encoding = position_encoding.TrainablePositionEncoding(
      index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)
  embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)

  perceiver_mod = perceiver.Perceiver(
      encoder=perceiver.PerceiverEncoder(**encoder_config),
      decoder=perceiver.BasicDecoder(**decoder_config))
  # The decoder emits one output per input position, so the same mask serves
  # both the encoder inputs and the decoder queries.
  output_embeddings = perceiver_mod(
      embedded_inputs, is_training=False, input_mask=input_mask,
      query_mask=input_mask)

  # Project the output embeddings onto the vocabulary using the (tied) input
  # embedding matrix to obtain per-position logits.
  logits = io_processors.EmbeddingDecoder(
      embedding_matrix=embedding_layer.embeddings)(output_embeddings)
  return logits


# hk.transform turns the module-creating function into a pure (init, apply)
# pair; only `apply` is kept, to be called with pretrained params.
apply_perceiver = hk.transform(apply_perceiver).apply

In [None]:
# Example sentence to run through the masked language model.
input_str = "This is the best thing that happened to me!"

# [start, end) token offsets of the span hidden from the model: from " is" up
# to (not including) "the best". The tokenizer emits one token per UTF-8 byte,
# so the offsets are searched in the encoded bytes; str.index would return
# character offsets, which drift from token offsets as soon as the text
# contains a non-ASCII character.
input_bytes = input_str.encode("utf-8")
input_tokens_positions = [
    input_bytes.index(b" is"),
    input_bytes.index(b"the best"),
]
assert input_tokens_positions[0] < input_tokens_positions[1], "empty mask span"

# Integer token ids for the sentence.
input_tokens = tokenizer.to_int(input_str)
# Guards the one-token-per-byte assumption the offsets above rely on.
assert len(input_tokens) == len(input_bytes)

# Note that the model performs much better if the masked chunk
# starts with a space.
# Overwrite the chosen span with the tokenizer's mask token.
input_tokens[input_tokens_positions[0]:input_tokens_positions[1]] = tokenizer.mask_token
print("Tokenized string without masked bytes:")
print(tokenizer.to_string(input_tokens))

In [None]:
#@title Pad and reshape inputs
# Prepend a batch axis of size 1; the model consumes [B, T] token arrays.
inputs = np.expand_dims(input_tokens, axis=0)
# Validity mask shaped like `inputs`: every real token is marked valid (1).
input_mask = np.ones_like(inputs)

def pad(max_sequence_length: int, inputs, input_mask, pad_token=None):
  """Right-pads `inputs` and `input_mask` along axis 1 to a fixed length.

  Args:
    max_sequence_length: Target length of axis 1.
    inputs: Token array of shape [B, T] with T <= max_sequence_length.
    input_mask: Mask array with the same shape as `inputs`.
    pad_token: Value written into the padded positions of `inputs`. Defaults
      to the global `tokenizer.pad_token`.

  Returns:
    (padded_inputs, padded_mask), each of shape [B, max_sequence_length].
    Padded positions hold `pad_token` in `padded_inputs` and 0 in
    `padded_mask`, i.e. they are marked invalid.

  Raises:
    ValueError: if `inputs` is already longer than `max_sequence_length`.
  """
  if pad_token is None:
    pad_token = tokenizer.pad_token
  input_len = inputs.shape[1]
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if input_len > max_sequence_length:
    raise ValueError(
        f"Input length {input_len} exceeds max_sequence_length "
        f"{max_sequence_length}.")
  pad_len = max_sequence_length - input_len
  # Pad only the end of the sequence axis; the batch axis is left untouched.
  pad_width = ((0, 0), (0, pad_len))
  padded_inputs = np.pad(inputs, pad_width=pad_width, constant_values=pad_token)
  padded_mask = np.pad(input_mask, pad_width=pad_width, constant_values=0)
  return padded_inputs, padded_mask

# Bring the batch and its mask up to the fixed length the model asserts on.
inputs, input_mask = pad(
    max_sequence_length=MAX_SEQ_LEN, inputs=inputs, input_mask=input_mask)

In [None]:
# The transformed `apply` takes an rng argument; it is not used here.
rng = jax.random.PRNGKey(1)  # Unused

out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)

# Greedy decoding: for the first (only) batch element, pick the highest-scoring
# vocabulary entry (argmax over the last, vocabulary axis) at each masked
# position. The result is an array of integer token ids.
masked_tokens_predictions = out[0, slice(*input_tokens_positions)].argmax(axis=-1)
print("Greedy predictions:", masked_tokens_predictions, sep="\n")
print()
# Decode the predicted token ids back into text.
print("Predicted string:")
print(tokenizer.to_string(masked_tokens_predictions))