# **Perceiver IO model implementation**
Perceiver IO: A General Architecture for Structured Inputs & Outputs, by DeepMind
* Perceiver IO is a general architecture that avoids the quadratic time and memory cost of standard Transformers, which comes from self-attention over the full input.
* It generalizes the Perceiver to handle arbitrary outputs in addition to arbitrary inputs.
* The original Perceiver only produced a single classification label.
* In addition to classification labels, Perceiver IO can produce, for example, language, optical flow, and multimodal videos with audio.
* It does this using the same building blocks as the original Perceiver.
* Its computational complexity is linear in the input and output size, and the bulk of the processing occurs in the latent space, so it can process inputs and outputs much larger than standard Transformers can handle.
* This means, for example, that Perceiver IO can do BERT-style masked language modeling directly on bytes instead of tokenized inputs.
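
The linear-scaling claim can be made concrete by counting attention-score entries per head. The sketch below is a back-of-the-envelope count, not a FLOP measurement. The sizes (2048 input positions, 256 latents, 26 latent self-attention layers) are the ones used in this notebook's model configuration, and the function names are hypothetical.

```python
def full_transformer_scores(seq_len: int, num_layers: int) -> int:
    """Attention-score entries when every layer self-attends over the full sequence."""
    return num_layers * seq_len * seq_len


def perceiver_io_scores(num_inputs: int, num_outputs: int,
                        num_latents: int, num_latent_layers: int) -> int:
    """Attention-score entries for encode + latent self-attention + decode."""
    encode = num_inputs * num_latents               # latents cross-attend to the inputs
    process = num_latent_layers * num_latents ** 2  # self-attention among the latents
    decode = num_outputs * num_latents              # output queries cross-attend to the latents
    return encode + process + decode


transformer_cost = full_transformer_scores(seq_len=2048, num_layers=26)
perceiver_cost = perceiver_io_scores(2048, 2048, 256, 26)
print(transformer_cost, perceiver_cost)
```

Doubling the input and output length doubles only the `encode` and `decode` terms and leaves the latent term unchanged, whereas the full-Transformer count quadruples.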


# **Masked language modeling**
Masked language modeling (MLM) hides part of an input sentence and trains the model to reconstruct the hidden part from the surrounding context.
* BERT is the usual choice for MLM; this notebook shows that a pretrained Perceiver IO can perform the same task, working directly on bytes.
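
As a rough illustration of the MLM training objective, the sketch below computes a cross-entropy loss that only counts the masked positions. It is not taken from the Perceiver IO code; the function name `masked_lm_loss` and its arguments are hypothetical, and NumPy is used to keep it self-contained.

```python
import numpy as np


def masked_lm_loss(logits, targets, is_masked):
    """Mean cross-entropy over masked positions only.

    logits:    float array [B, T, V] of unnormalized scores.
    targets:   int array [B, T] with the original (unmasked) token ids.
    is_masked: 0/1 array [B, T], 1 where the input token was hidden.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the correct token at every position.
    token_ll = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    # Average the negative log-likelihood over masked positions only.
    return -(token_ll * is_masked).sum() / np.maximum(is_masked.sum(), 1)


example_logits = np.zeros((1, 4, 8))        # uniform prediction over 8 tokens
example_targets = np.array([[1, 2, 3, 4]])
example_mask = np.array([[0, 1, 1, 0]])     # only positions 1 and 2 are scored
print(masked_lm_loss(example_logits, example_targets, example_mask))
```

Unmasked positions contribute nothing to the loss, so the model is only rewarded for reconstructing what was hidden.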

## **References:**
* https://github.com/2796gaurav/code_examples/blob/main/Perceiver/Perceiver_masked_language_modelling.ipynb
* https://medium.com/analytics-vidhya/perceiver-io-a-general-architecture-for-structured-inputs-outputs-4ad669315e7f
* https://deepmind.com/research/open-source/perceiver-IO


In [1]:
# Install dependencies for Google Colab.
# If you want to run this notebook on your own machine, you can skip this cell
!pip install dm-haiku
!pip install einops

!mkdir -p /content/perceiver
!touch /content/perceiver/__init__.py
!wget -O /content/perceiver/bytes_tokenizer.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/bytes_tokenizer.py
!wget -O /content/perceiver/io_processors.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/io_processors.py
!wget -O /content/perceiver/perceiver.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/perceiver.py
!wget -O /content/perceiver/position_encoding.py https://raw.githubusercontent.com/deepmind/deepmind-research/master/perceiver/position_encoding.py

Collecting dm-haiku
  Downloading dm_haiku-0.0.4-py3-none-any.whl (284 kB)

In [2]:
# Imports
from typing import Union

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import pickle

from perceiver import perceiver, position_encoding, io_processors, bytes_tokenizer

## **Loading parameters from checkpoint**

In [3]:
# Download the pretrained checkpoint (about 767 MB) and load the parameters
!wget -O language_perceiver_io_bytes.pickle https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle

with open("language_perceiver_io_bytes.pickle", "rb") as f:
  params = pickle.load(f)

--2021-10-10 22:39:59--  https://storage.googleapis.com/perceiver_io/language_perceiver_io_bytes.pickle
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.211.128, 173.194.213.128, 173.194.214.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.211.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 804479532 (767M) [application/octet-stream]
Saving to: ‘language_perceiver_io_bytes.pickle’


2021-10-10 22:40:05 (139 MB/s) - ‘language_perceiver_io_bytes.pickle’ saved [804479532/804479532]



## **Model configuration**

In [4]:
# Build the model configuration; the hyperparameters must match the checkpoint
D_MODEL = 768
D_LATENTS = 1280
MAX_SEQ_LEN = 2048

encoder_config = dict(
    num_self_attends_per_block=26,
    num_blocks=1,
    z_index_dim=256,
    num_z_channels=D_LATENTS,
    num_self_attend_heads=8,
    num_cross_attend_heads=8,
    qk_channels=8 * 32,
    v_channels=D_LATENTS,
    use_query_residual=True,
    cross_attend_widening_factor=1,
    self_attend_widening_factor=1)

decoder_config = dict(
    output_num_channels=D_LATENTS,
    position_encoding_type='trainable',
    output_index_dims=MAX_SEQ_LEN,
    num_z_channels=D_LATENTS,
    qk_channels=8 * 32,
    v_channels=D_MODEL,
    num_heads=8,
    final_project=False,
    use_query_residual=False,
    trainable_position_encoding_kwargs=dict(num_channels=D_MODEL))

# The tokenizer is just UTF-8 encoding (with an offset)
tokenizer = bytes_tokenizer.BytesTokenizer()


* `apply_perceiver` runs a forward pass on the Perceiver.

* Args:
  * `inputs`: input bytes, an int array of shape [B, T].
  * `input_mask`: array with the same shape as `inputs` ([B, T]) indicating which entries are valid and which are padding. A truthy value indicates that the entry is valid.
* Returns: the output logits, an array of shape [B, T, vocab_size].
  

In [5]:
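# Forward pass: embed the bytes, run the Perceiver encoder and decoder,
# then project the output embeddings back to byte logits
def apply_perceiver(
    inputs: jnp.ndarray, input_mask: jnp.ndarray) -> jnp.ndarray:
  
  # Inputs must already be padded to MAX_SEQ_LEN
  assert inputs.shape[1] == MAX_SEQ_LEN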

  embedding_layer = hk.Embed(
      vocab_size=tokenizer.vocab_size,
      embed_dim=D_MODEL)
  embedded_inputs = embedding_layer(inputs)

  batch_size = embedded_inputs.shape[0]

  input_pos_encoding = perceiver.position_encoding.TrainablePositionEncoding(
      index_dim=MAX_SEQ_LEN, num_channels=D_MODEL)
  embedded_inputs = embedded_inputs + input_pos_encoding(batch_size)
  perceiver_mod = perceiver.Perceiver(
      encoder=perceiver.PerceiverEncoder(**encoder_config),
      decoder=perceiver.BasicDecoder(**decoder_config))
  output_embeddings = perceiver_mod(
      embedded_inputs, is_training=False, input_mask=input_mask, query_mask=input_mask)

  logits = io_processors.EmbeddingDecoder(
      embedding_matrix=embedding_layer.embeddings)(output_embeddings)
  return logits

apply_perceiver = hk.transform(apply_perceiver).apply

In [15]:
input_str = "This is an incomplete sentence where some words are missing."
input_tokens = tokenizer.to_int(input_str)

# Masking " missing." (bytes 51-59, including the leading space and the trailing period).
# The model performs much better if the masked chunk starts with a space.
input_tokens[51:60] = tokenizer.mask_token
print("Tokenized string without masked bytes:")
print(tokenizer.to_string(input_tokens))


Tokenized string without masked bytes:
This is an incomplete sentence where some words are
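
The slice `51:60` above is hard-coded. Because the tokenizer works on UTF-8 bytes and the masked positions line up one-to-one with the string, the indices can be computed instead. A minimal sketch, assuming one token per byte with nothing prepended; `find_byte_span` is a hypothetical helper, not part of the `perceiver` package.

```python
from typing import Tuple


def find_byte_span(text: str, chunk: str) -> Tuple[int, int]:
    """Return (start, end) byte offsets of the first occurrence of `chunk` in `text`."""
    data = text.encode("utf-8")
    needle = chunk.encode("utf-8")
    start = data.find(needle)
    if start < 0:
        raise ValueError(f"{chunk!r} not found in {text!r}")
    return start, start + len(needle)


sentence = "This is an incomplete sentence where some words are missing."
start, end = find_byte_span(sentence, " missing.")
print(start, end)  # 51 60
```

The result can be used as `input_tokens[start:end] = tokenizer.mask_token`, and the same `start:end` slice then selects the predictions from the model output.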


In [16]:
# Padding and reshaping inputs
inputs = input_tokens[None]
input_mask = np.ones_like(inputs)

def pad(max_sequence_length: int, inputs, input_mask):
  input_len = inputs.shape[1]
  assert input_len <= max_sequence_length
  pad_len = max_sequence_length - input_len
  padded_inputs = np.pad(
      inputs,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=tokenizer.pad_token)
  padded_mask = np.pad(
      input_mask,
      pad_width=((0, 0), (0, pad_len)),
      constant_values=0)
  return padded_inputs, padded_mask

inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)
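
The `pad` helper above works on a single, already batched array. To run several sentences of different lengths in one call, the same idea can be applied per row. This is a sketch only: `pad_batch` is a hypothetical helper, and `PAD_ID = 0` is an assumption standing in for `tokenizer.pad_token`.

```python
import numpy as np

PAD_ID = 0  # assumption: stand-in for tokenizer.pad_token


def pad_batch(sequences, max_len, pad_id=PAD_ID):
    """Stack variable-length int sequences into [B, max_len] inputs plus a validity mask."""
    inputs = np.full((len(sequences), max_len), pad_id, dtype=np.int32)
    mask = np.zeros((len(sequences), max_len), dtype=np.int32)
    for row, seq in enumerate(sequences):
        if len(seq) > max_len:
            raise ValueError(f"sequence {row} has length {len(seq)} > {max_len}")
        inputs[row, :len(seq)] = seq   # real tokens on the left
        mask[row, :len(seq)] = 1       # 1 marks a valid entry, 0 marks padding
    return inputs, mask


batch_inputs, batch_mask = pad_batch([[10, 11, 12], [20, 21]], max_len=5)
print(batch_inputs)
print(batch_mask)
```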


In [17]:
rng = jax.random.PRNGKey(1)  # Unused

out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)

masked_tokens_predictions = out[0, 51:60].argmax(axis=-1)
print("Greedy predictions:")
print(masked_tokens_predictions)
print()
print("Predicted string:")
print(tokenizer.to_string(masked_tokens_predictions))

Greedy predictions:
[ 38 115 111 121 121 111 116 109  52]

Predicted string:
 missing.


***The model correctly fills in the masked span as " missing.".***
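
The greedy ids printed above can also be decoded by hand, which shows what "UTF-8 encoding with an offset" means. The offset of 6 used below is inferred from this notebook's own output (id 38 decodes to a space, byte 32) and is an assumption, not something read from the tokenizer source; `ids_to_string` is a hypothetical helper.

```python
OFFSET = 6  # assumption: number of reserved special-token ids before the byte range


def ids_to_string(ids, offset=OFFSET):
    """Decode byte-level ids to text, skipping ids in the reserved special-token range."""
    raw = bytes(int(i) - offset for i in ids if int(i) >= offset)
    return raw.decode("utf-8", errors="replace")


greedy_ids = [38, 115, 111, 121, 121, 111, 116, 109, 52]
print(repr(ids_to_string(greedy_ids)))  # ' missing.'
```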


In [18]:
input_str = "I love eating chocolates"
input_tokens = tokenizer.to_int(input_str)

# Masking "chocolates" (bytes 14-23). The model performs much better if the masked chunk
# starts with a space; note that this span does not include the space at index 13.
input_tokens[14:24] = tokenizer.mask_token
print("Tokenized string without masked bytes:")
print(tokenizer.to_string(input_tokens))


Tokenized string without masked bytes:
I love eating 


In [21]:
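rng = jax.random.PRNGKey(1)  # Unused

# Rebuild `inputs` and `input_mask` for the new sentence. Without this step the
# call below reuses the padded first sentence, which appears to be what the
# recorded all-space output reflects.
inputs = input_tokens[None]
input_mask = np.ones_like(inputs)
inputs, input_mask = pad(MAX_SEQ_LEN, inputs, input_mask)

out = apply_perceiver(params, rng=rng, inputs=inputs, input_mask=input_mask)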

masked_tokens_predictions = out[0, 14:24].argmax(axis=-1)
print("Greedy predictions:")
print(masked_tokens_predictions)
print()
print("Predicted string:")
print(tokenizer.to_string(masked_tokens_predictions))

Greedy predictions:
[38 38 38 38 38 38 38 38 38 38]

Predicted string:
          
