## HuggingFace Multimodal Perceiver Inference on Trn1 / Inf2

**Introduction**

This notebook demonstrates how to compile and run the HuggingFace Multimodal Perceiver model to classify and autoencode video inputs on Neuron. The script is loosely based on HuggingFace's official tutorial for running inference on the multimodal perceiver at https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Perceiver_for_Multimodal_Autoencoding.ipynb

This notebook can be run on the smallest Inf2 instance, `inf2.xlarge`.

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option on the top of this Jupyter notebook page.

**Install Dependencies**

This tutorial requires the following pip packages to be installed:
- `torch-neuronx`
- `neuronx-cc`
- `transformers==4.30.2`
- `opencv-python-headless`
- `imageio`
- `scipy`
- `accelerate`

It also requires the `ffmpeg` audio/video converter, which is used to extract audio from the input videos.

`torch-neuronx` and `neuronx-cc` should be installed when you configure your environment following the Inf2 setup guide. The remaining dependencies can be installed below:

In [None]:
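# Setting TOKENIZERS_PARALLELISM explicitly suppresses tokenizer warnings, which makes errors easier to spot.
# (The comment must not share the %env line, or it becomes part of the variable's value.)
%env TOKENIZERS_PARALLELISM=True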
!pip install transformers==4.30.2 opencv-python-headless==4.8.0.74 imageio scipy accelerate opencv-python==4.8.0.74

!wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz
!tar xvf ffmpeg-git-amd64-static.tar.xz
!mv ffmpeg-git-*-amd64-static/ffmpeg .
!rm -rf ffmpeg-git-*-amd64-static ffmpeg-git-amd64-static.tar.xz

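After the installation cell has run, you can optionally confirm that the Python packages are visible to the notebook kernel. This is a minimal sketch using the standard-library `importlib.metadata`; the helper name `installed_versions` is hypothetical, and the package list simply mirrors the one above (pip distribution names, not import names). `ffmpeg` is a standalone binary, so this check does not cover it.

```python
from importlib import metadata

# pip distribution names from the dependency list above
REQUIRED = ["torch-neuronx", "neuronx-cc", "transformers",
            "opencv-python-headless", "imageio", "scipy", "accelerate"]

def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions(REQUIRED).items():
    print(f"{name:25s} {version or 'NOT INSTALLED'}")
```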
**Imports**

In [None]:
import base64
import os
import ssl
import re
from urllib import request
import cv2
import imageio
import time
import random
from tqdm import tqdm
import numpy as np
import scipy.io.wavfile
from IPython.display import HTML

from typing import Optional, Tuple, Union
from transformers import PerceiverForMultimodalAutoencoding
from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions
from transformers.models.perceiver.modeling_perceiver import PerceiverBasicDecoder, PerceiverClassifierOutput
from transformers.models.perceiver.modeling_perceiver import restructure
import torch
import torch.nn as nn
import torch_neuronx

**Video Preprocessing Utilities**

The following code cell defines some useful functions for fetching, preprocessing and visualizing the input video. Most of these are taken directly from HuggingFace's official multimodal perceiver tutorial at https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Perceiver_for_Multimodal_Autoencoding.ipynb.  

In [None]:
# Utilities to fetch videos from UCF101 dataset
UCF_ROOT = 'https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/'
_VIDEO_LIST = None
_CACHE_DIR_NAME = "video_cache"

os.makedirs("video_cache", exist_ok=True)
# As of July 2020, crcv.ucf.edu doesn't use a certificate accepted by the
# default Colab environment anymore.
unverified_context = ssl._create_unverified_context()

def list_ucf_videos():
  """Lists videos available in UCF101 dataset."""
  global _VIDEO_LIST
  if not _VIDEO_LIST:
    index = request.urlopen(UCF_ROOT, context=unverified_context).read().decode('utf-8')
    videos = re.findall(r'(v_[\w_]+\.avi)', index)
    _VIDEO_LIST = sorted(set(videos))
  return list(_VIDEO_LIST)

def fetch_ucf_video(video):
  """Fetches a video and caches it in the local filesystem."""
  cache_path = os.path.join(_CACHE_DIR_NAME, video)
  if not os.path.exists(cache_path):
    urlpath = request.urljoin(UCF_ROOT, video)
    print('Fetching %s => %s' % (urlpath, cache_path))
    data = request.urlopen(urlpath, context=unverified_context).read()
    with open(cache_path, "wb") as f:
      f.write(data)
  return cache_path

# Utilities to open video files using CV2
def crop_center_square(frame):
  y, x = frame.shape[0:2]
  min_dim = min(y, x)
  start_x = (x // 2) - (min_dim // 2)
  start_y = (y // 2) - (min_dim // 2)
  return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]

def load_video(path, max_frames=0, resize=(224, 224)):
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      frame = frame[:, :, [2, 1, 0]]
      frames.append(frame)

      if len(frames) == max_frames:
        break
  finally:
    cap.release()
  return np.array(frames) / 255.0

def to_gif(images):
  converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)
  imageio.mimsave('./animation.gif', converted_images, duration=40, loop=100)
  with open('./animation.gif', 'rb') as f:
    gif_64 = base64.b64encode(f.read()).decode('utf-8')
  return HTML('<img src="data:image/gif;base64,%s"/>' % gif_64)

def play_audio(data, sample_rate=48000):
  scipy.io.wavfile.write('tmp_audio.wav', sample_rate, data)

  with open('./tmp_audio.wav', 'rb') as f:
    audio_64 = base64.b64encode(f.read()).decode('utf-8')
  return HTML('<audio controls src="data:audio/wav;base64,%s"/>' % audio_64)

def table(elements):
  row = ['<td>%s</td>' % el.data for el in elements]
  return HTML('<table><tr>%s</tr></table>' % ''.join(row))

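To make the frame preprocessing concrete, the sketch below applies the same center-crop arithmetic as `crop_center_square` to a dummy 240x320 frame (the frame size is an assumption chosen for illustration), then performs the `[2, 1, 0]` channel swap that `load_video` uses to convert OpenCV's BGR channel order to RGB, and the final scaling to [0, 1]. The `cv2.resize` step is omitted so the example needs only NumPy.

```python
import numpy as np

def crop_center_square(frame):
    # Same logic as the utility above: keep the largest centered square.
    y, x = frame.shape[0:2]
    min_dim = min(y, x)
    start_x = (x // 2) - (min_dim // 2)
    start_y = (y // 2) - (min_dim // 2)
    return frame[start_y:start_y + min_dim, start_x:start_x + min_dim]

# Dummy 240x320 frame in OpenCV's BGR order; each channel holds a constant.
frame = np.zeros((240, 320, 3), dtype=np.uint8)
frame[..., 0] = 10   # blue
frame[..., 1] = 20   # green
frame[..., 2] = 30   # red

square = crop_center_square(frame)   # columns 40..279 are kept
rgb = square[:, :, [2, 1, 0]]        # BGR -> RGB, as in load_video
normalized = rgb / 255.0             # load_video also scales pixel values to [0, 1]

print(square.shape)                  # (240, 240, 3)
print(rgb[0, 0].tolist())            # [30, 20, 10]
```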

**Fetch and Preprocess Input Videos**

The following cell randomly samples videos from the UCF101 dataset and preprocesses them with the utilities defined above. Change the `num_videos_to_process` variable to control how many videos are processed; setting it too high may cause out-of-memory errors on smaller Inf2 instances. For demonstration purposes, this script finishes all preprocessing before running any inference. You can modify it to pipeline preprocessing and inference instead, which reduces the memory held by intermediate results.

Because input videos without an audio stream are unusable and are skipped (handled in the try-except block below), the final number of preprocessed inputs may be smaller than `num_videos_to_process`.

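As mentioned above, preprocessing and inference can be pipelined rather than run as two separate passes. The following is a minimal, framework-free sketch of that pattern, not the notebook's implementation: `pipelined`, `toy_preprocess`, and `toy_infer` are hypothetical names. In this notebook, `preprocess` would roughly correspond to the fetch, FFmpeg, and `load_video` steps of the next cell, and `infer` to the `autoencode_video` function defined later. Returning `None` marks an unusable input (for example, a clip without audio), mirroring how the next cell skips such videos.

```python
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple

def pipelined(items: Iterable[Any],
              preprocess: Callable[[Any], Optional[Any]],
              infer: Callable[[Any], Any]) -> Iterator[Tuple[Any, Any]]:
    """Yield (item, result) pairs, preprocessing each item just before inference.

    Items whose preprocessing returns None (e.g. a video without audio) are skipped.
    Because this is a generator, it holds at most one preprocessed sample at a time.
    """
    for item in items:
        sample = preprocess(item)
        if sample is None:
            continue
        yield item, infer(sample)

# Hypothetical stand-ins: odd-numbered "videos" pretend to lack an audio stream.
def toy_preprocess(index):
    return None if index % 2 else [float(index)] * 4

def toy_infer(sample):
    return sum(sample)

for index, result in pipelined(range(5), toy_preprocess, toy_infer):
    print(index, result)
```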
In [None]:
video_names = list_ucf_videos()
num_videos_to_process = 20
vid_indices = random.sample(range(len(video_names)), num_videos_to_process) # Select videos to process at random
videos, audios = [], []
for i in vid_indices:
  video_path = fetch_ucf_video(video_names[i])

  # Extract the audio track with FFmpeg, resampled to 48 kHz and encoded as 32-bit float PCM WAV (readable by scipy.io.wavfile).
  # Remove any output.wav left over from the previous video first, so a clip without audio cannot silently reuse stale audio.
  if os.path.exists("output.wav"):
    os.remove("output.wav")
  !./ffmpeg -y -loglevel quiet -i "$video_path" -map 0:a? -c:a pcm_f32le -ar 48000 -f wav output.wav

  # If the input video has no audio stream, FFmpeg produces no usable output and we skip this input,
  # because the model requires both modalities to be present.
  try:
    sample_rate, audio = scipy.io.wavfile.read("output.wav")
  except (OSError, ValueError):
    continue

  if audio.dtype == np.int16:
    audio = audio.astype(np.float32) / 2**15
  elif audio.dtype != np.float32:
    raise ValueError('Unexpected datatype. Model expects sound samples to lie in [-1, 1]')

  video = load_video(video_path)
  audios.append(audio)
  videos.append(video)

print(f"Received {len(audios)} valid input videos to run inference on.")

Now we can visualize the first input:

In [None]:
# Visualize inputs
table([to_gif(videos[0]), play_audio(audios[0])])

Although most input clips contain many more frames, the multimodal perceiver only takes 16 frames per input. The code below selects the first 16 frames of each video, along with the corresponding audio samples (1920 samples per frame, assuming 48 kHz audio and 25 fps video) from a single audio channel.

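The slicing in the next cell is easy to check with dummy arrays. This sketch assumes a 25 fps clip with stereo 48 kHz audio (the same constants the next cell uses); the arrays are zeros because only the shapes matter. Note that slicing does not pad: a clip shorter than 16 frames (or 30720 audio samples) yields a smaller array, which will not match the fixed shapes used for tracing later.

```python
import numpy as np

AUDIO_SAMPLES_PER_FRAME = 48000 // 25   # 1920 audio samples per video frame
NUM_FRAMES = 16

# Dummy decoded clip: 40 frames of 224x224 RGB and 2 s of stereo 48 kHz audio.
video = np.zeros((40, 224, 224, 3))
audio = np.zeros((96000, 2), dtype=np.float32)

image = video[None, :NUM_FRAMES]                                      # add batch dim, keep 16 frames
audio_clip = audio[None, :NUM_FRAMES * AUDIO_SAMPLES_PER_FRAME, 0:1]  # matching samples, first channel

print(image.shape)       # (1, 16, 224, 224, 3)
print(audio_clip.shape)  # (1, 30720, 1)
```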
In [None]:
# Select the first 16 frames of the video and one of the audio channels for autoencoding and classification
# Also add a dummy batch dimension.
AUDIO_SAMPLES_PER_FRAME = 48000 // 25
SAMPLES_PER_PATCH = 16

preprocessed_images, preprocessed_audios = [], []
for i in range(len(videos)):
    image = videos[i][None, :16]
    preprocessed_images.append(image)

    audio = audios[i]
    if len(audio.shape) == 2:
        audio = audio[None, :16*AUDIO_SAMPLES_PER_FRAME, 0:1]
    elif len(audio.shape) == 1:
        audio = audio[None, :16*AUDIO_SAMPLES_PER_FRAME]
    else:
        raise ValueError("audio has wrong shape")
    preprocessed_audios.append(audio)

**Utilities for Neuron Tracing**

The following cell defines utilities and wrappers that make tracing the multimodal perceiver on Neuron robust and performant. If you simply want to run the model and see the results, you can skip the rest of this explanation: you do not need to understand the code in this cell, but you must run it before the subsequent cells.

The cell defines the following wrappers and replacement functions:
1. `MultimodalPerceiverWrapper` wraps the perceiver and is called inside the custom forward function `custom_model_forward`. It avoids redundant computation by running the encoder once per input and reusing the resulting latents for every decoder chunk, instead of re-running the encoder for each chunk.
2. `custom_model_forward` replaces the model's original `forward` method, so when the model is called during inference, `custom_model_forward` executes instead. It instantiates a `MultimodalPerceiverWrapper` to take advantage of this optimization.
3. `custom_decoder_query` replaces the `decoder_query` method of the `PerceiverBasicDecoder` class with a version that converts flat output indices to coordinates using plain PyTorch arithmetic. This replacement is necessary for tracing: without it, tracing the decoder causes a segfault.
4. `EncoderWrapper` and `NeuronEncoder` wrap the encoder of the perceiver so that it can be traced.
5. `DecoderWrapper` and `NeuronDecoder` wrap the decoder query, decoder, and output postprocessor of the perceiver so that they can be traced together.

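The key idea of `MultimodalPerceiverWrapper` is that the encoder runs once per input while the decoder runs once per output chunk. The toy sketch below reproduces the chunk bookkeeping of `MultimodalPerceiverWrapper.forward` with NumPy stand-ins; `toy_encoder`, `toy_decoder`, and `chunked_decode` are hypothetical placeholders, not the Perceiver modules. It counts the calls and checks that the chunks tile every output index exactly once, using the chunk sizes this notebook computes later (6272 image points and 15 audio patches per chunk, 128 chunks).

```python
import numpy as np

calls = {"encoder": 0, "decoder": 0}

def toy_encoder(x):
    # Placeholder for the expensive Perceiver encoder.
    calls["encoder"] += 1
    return x.mean()

def toy_decoder(latent, image_points, audio_points):
    # Placeholder decoder: returns one value per requested output point.
    calls["decoder"] += 1
    return {"image": latent + image_points, "audio": latent + audio_points}

def chunked_decode(x, nchunks, image_chunk_size, audio_chunk_size):
    latent = toy_encoder(x)  # the encoder runs exactly once
    image_parts, audio_parts = [], []
    for chunk_idx in range(nchunks):
        # Same index arithmetic as subsampled_output_points in the wrapper.
        image_points = np.arange(image_chunk_size * chunk_idx, image_chunk_size * (chunk_idx + 1))
        audio_points = np.arange(audio_chunk_size * chunk_idx, audio_chunk_size * (chunk_idx + 1))
        out = toy_decoder(latent, image_points, audio_points)
        image_parts.append(out["image"])
        audio_parts.append(out["audio"])
    return np.concatenate(image_parts), np.concatenate(audio_parts)

image_out, audio_out = chunked_decode(np.zeros(8), nchunks=128,
                                      image_chunk_size=6272, audio_chunk_size=15)
print(calls)                             # {'encoder': 1, 'decoder': 128}
print(image_out.shape, audio_out.shape)  # (802816,) (1920,)
```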
In [None]:
class MultimodalPerceiverWrapper(nn.Module):
    def __init__(self, perceiver_model, nchunks, image_chunk_size, audio_chunk_size):
        super().__init__()
        self.perceiver_model = perceiver_model
        self.nchunks = nchunks
        self.image_chunk_size = image_chunk_size
        self.audio_chunk_size = audio_chunk_size
    
    def forward(self, inputs: torch.FloatTensor,
        neuron_decoder,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None):


        output_attentions = output_attentions if output_attentions is not None else self.perceiver_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.perceiver_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.perceiver_model.config.use_return_dict
        
        if self.perceiver_model.input_preprocessor is not None:
            inputs, modality_sizes, inputs_without_pos = self.perceiver_model.input_preprocessor(inputs)
        else:
            modality_sizes = None
            inputs_without_pos = None
            if inputs.size()[-1] != self.perceiver_model.config.d_model:
                raise ValueError(
                    f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model:"
                    f" {self.perceiver_model.config.d_model}. Make sure to set config.d_model appropriately."
                )

        batch_size, seq_length, _ = inputs.size()
        device = inputs.device

        # If no attention mask is provided, make them all ones
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)
        # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
        extended_attention_mask = self.perceiver_model.invert_attention_mask(attention_mask)

        head_mask = self.perceiver_model.get_head_mask(head_mask, self.perceiver_model.config.num_blocks * self.perceiver_model.config.num_self_attends_per_block)
        embedding_output = self.perceiver_model.embeddings(batch_size=batch_size)

        encoder_outputs = self.perceiver_model.encoder(
            embedding_output,
            attention_mask=None,
            head_mask=head_mask,
            inputs=inputs,
            inputs_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        logits = None
        reconstruction = {}
        for chunk_idx in tqdm(range(self.nchunks)):
            subsampled_output_points = {
                'image': torch.arange(
                    self.image_chunk_size * chunk_idx, self.image_chunk_size * (chunk_idx + 1)).to(device),
                'audio': torch.arange(
                    self.audio_chunk_size * chunk_idx, self.audio_chunk_size * (chunk_idx + 1)).to(device),
                'label': None,
            }

            logits = neuron_decoder(sequence_output, extended_attention_mask,
                                    inputs, modality_sizes, inputs_without_pos,
                                    subsampled_points=subsampled_output_points)

            reconstruction['label'] = logits['label']
            if 'image' not in reconstruction:
                reconstruction['image'] = logits['image']
                reconstruction['audio'] = logits['audio']
            else:
                reconstruction['image'] = torch.cat(
                    [reconstruction['image'], logits['image']], dim=1)
                reconstruction['audio'] = torch.cat(
                    [reconstruction['audio'], logits['audio']], dim=1)
            
            del logits

        return reconstruction

def custom_model_forward(
        self,
        nchunks,
        image_chunk_size,
        audio_chunk_size,
        neuron_decoder,
        inputs: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> dict:

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        perceiver_wrapper = MultimodalPerceiverWrapper(self.perceiver, nchunks, image_chunk_size, audio_chunk_size)
        outputs = perceiver_wrapper(
            inputs,
            neuron_decoder,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return outputs


def custom_decoder_query(self, inputs, modality_sizes=None, inputs_without_pos=None, subsampled_points=None):
    if self.position_encoding_type == "none":  # Queries come from elsewhere
        raise ValueError("You cannot construct decoder queries when position_encoding_type is set to none")
    if subsampled_points is not None:
        # subsampled_points are the indices if the inputs would be flattened
        # however, the inputs aren't flattened, that's why we use unravel_index
        # to get the indices for the unflattened array
        # unravel_index returns a tuple (x_idx, y_idx, ...)
        # stack to get the [n, d] tensor of coordinates

        def unravel_indices(indices, shape):
            coord = []

            for dim in reversed(shape):
                coord.append(indices % dim)
                indices = indices // dim

            coord = torch.stack(coord[::-1], dim=-1)

            return coord

        pos = unravel_indices(subsampled_points, self.output_index_dims)

        batch_size = inputs.shape[0]
        # Map these coordinates to [-1, 1]
        pos = -1 + 2 * pos / torch.tensor(self.output_index_dims)[None, :]
        pos = torch.broadcast_to(pos[None], [batch_size, pos.shape[0], pos.shape[1]])
        # Construct the position encoding.
        if self.position_encoding_type == "trainable":
            pos_emb = self.output_position_encodings(batch_size)
        elif self.position_encoding_type == "fourier":
            pos_emb = self.output_position_encodings(
                self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
            )

        # Optionally project them to a target dimension.
        pos_emb = self.positions_projection(pos_emb)
        pos_emb = torch.reshape(pos_emb, [pos_emb.shape[0], -1, pos_emb.shape[-1]])
    else:
        batch_size = inputs.shape[0]
        index_dims = inputs.shape[2:]

        # Construct the position encoding.
        if self.position_encoding_type == "trainable":
            pos_emb = self.output_position_encodings(batch_size)
        elif self.position_encoding_type == "fourier":
            pos_emb = self.output_position_encodings(
                index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
            )

        # Optionally project them to a target dimension.
        pos_emb = self.positions_projection(pos_emb)

    if self.concat_preprocessed_input:
        if inputs_without_pos is None:
            raise ValueError("Value is required for inputs_without_pos if concat_preprocessed_input is True")
        pos_emb = torch.cat([inputs_without_pos, pos_emb], dim=-1)

    return pos_emb


# Define wrapper for tracing encoder
class EncoderWrapper(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
    
    def forward(self, embedding_output, inputs, extended_attention_mask):
        output = self.encoder(embedding_output, inputs=inputs, inputs_mask=extended_attention_mask)
        return output

class NeuronEncoder(nn.Module):
    def __init__(self, encoder_wrapper):
        super().__init__()
        self.encoder_wrapper = encoder_wrapper
    
    def forward(self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs: Optional[torch.FloatTensor] = None,
        inputs_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True):

        last_hidden_states = self.encoder_wrapper(hidden_states, inputs, inputs_mask)['last_hidden_state']
        return BaseModelOutputWithCrossAttentions(last_hidden_state=last_hidden_states)


# Define wrapper for tracing decoder
class DecoderWrapper(nn.Module):
    def __init__(self, decoder, decoder_query_audio, decoder_query_image, decoder_query_label, output_postprocessor):
        super().__init__()
        self.decoder = decoder
        self.decoder_query_audio = decoder_query_audio
        self.decoder_query_image = decoder_query_image
        self.decoder_query_label = decoder_query_label
        self.output_postprocessor = output_postprocessor
        self.num_query_channels = decoder.num_query_channels
    
    def forward(self, z, query_mask,
                audio_input, audio_input_without_pos, audio_subsampled_point, audio_padding,
                image_input, image_input_without_pos, image_subsampled_point, image_padding,
                label_input, label_input_without_pos, label_padding):
        audio_query = self.decoder_query_audio(inputs=audio_input, inputs_without_pos=audio_input_without_pos, subsampled_points=audio_subsampled_point)
        image_query = self.decoder_query_image(inputs=image_input, inputs_without_pos=image_input_without_pos, subsampled_points=image_subsampled_point)
        label_query = self.decoder_query_label(inputs=label_input, inputs_without_pos=label_input_without_pos)

        def embed(x, pos):
            x = torch.reshape(x, [x.shape[0], np.prod(x.shape[1:-1]), x.shape[-1]])
            pos = torch.broadcast_to(pos, [x.shape[0], x.shape[1], self.num_query_channels - x.shape[2]])
            return torch.cat([x, pos], dim=2)

        audio_padded = embed(audio_query, audio_padding)
        image_padded = embed(image_query, image_padding)
        label_padded = embed(label_query, label_padding)

        decoder_query = torch.cat([audio_padded, image_padded, label_padded], dim=1)
        logits = self.decoder(decoder_query, z, query_mask).logits
        
        output_modality_sizes = {"audio": audio_subsampled_point.shape[0],
                                 "image": image_subsampled_point.shape[0],
                                 "label": 1}
        logits = self.output_postprocessor(logits, modality_sizes=output_modality_sizes)
        return logits

class NeuronDecoder(nn.Module):
    def __init__(self, decoder_wrapper):
        super().__init__()
        self.decoder_wrapper = decoder_wrapper
        self.modalities = decoder_wrapper.decoder.modalities
        self.padding = decoder_wrapper.decoder.padding

    def forward(self, z, query_mask, inputs, modality_sizes, inputs_without_pos=None, subsampled_points=None, output_attentions=False):
        # Partition the flat inputs among the different modalities
        inputs = restructure(modality_sizes, inputs)

        assert(subsampled_points is not None)
        assert(inputs_without_pos is not None)

        for modality, decoder in self.modalities.items():
            if modality == "audio":
                audio_input, audio_input_without_pos, audio_subsampled_point, audio_padding = inputs[modality], inputs_without_pos[modality], subsampled_points[modality].to(torch.float32), self.padding[modality]
            elif modality == "image":
                image_input, image_input_without_pos, image_subsampled_point, image_padding = inputs[modality], inputs_without_pos[modality], subsampled_points[modality].to(torch.float32), self.padding[modality]
            else:
                # label doesn't have subsampled point
                label_input, label_input_without_pos, label_padding = inputs[modality], inputs_without_pos[modality], self.padding[modality]

        assert(audio_input_without_pos is not None)
        assert(audio_subsampled_point is not None)
        assert(image_input_without_pos is not None)
        assert(image_subsampled_point is not None)
        assert(label_input_without_pos is not None)

        output = self.decoder_wrapper(z, query_mask, 
                                        audio_input, audio_input_without_pos, audio_subsampled_point, audio_padding,
                                        image_input, image_input_without_pos, image_subsampled_point, image_padding,
                                        label_input, label_input_without_pos, label_padding)
        return output

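`custom_decoder_query` converts flat output indices into (frame, row, column) coordinates with plain `%` and `//` arithmetic on tensors. The sketch below restates that `unravel_indices` arithmetic in NumPy (so it runs without PyTorch or Neuron hardware), checks it against `np.unravel_index`, and then applies the same mapping to [-1, 1] that the decoder performs before building position encodings. The output shape `(16, 224, 224)` and the choice of the second 6272-point chunk are assumptions matching the image modality in this notebook.

```python
import numpy as np

def unravel_indices(indices, shape):
    """Map flat indices to per-dimension coordinates in row-major order.

    Same arithmetic as the helper inside custom_decoder_query: take the
    remainder for the last dimension, then divide that dimension away.
    """
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    return np.stack(coord[::-1], axis=-1)

shape = (16, 224, 224)                   # frames, height, width
flat = np.arange(6272, 6272 * 2)         # the second decoder chunk of image points
coords = unravel_indices(flat, shape)
reference = np.stack(np.unravel_index(flat, shape), axis=-1)

# Map coordinates to [-1, 1), as custom_decoder_query does.
normalized = -1 + 2 * coords / np.array(shape)[None, :]

print(coords.shape)                      # (6272, 3)
print(coords[0].tolist())                # [0, 28, 0]
print(np.array_equal(coords, reference)) # True
```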

**Trace (Compile) the Model on Neuron**

With the utilities defined above, we can now compile the perceiver encoder and decoder for inference on Neuron. The encoder and decoder are compiled as two independent modules and saved separately. In both calls to `torch_neuronx.trace`, we pass `--auto-cast=none` as a compiler argument. This is intentional: allowing the compiler to auto-cast to BF16 causes numerical errors. We are aware of the performance cost and hope to remove this restriction in an upcoming release.

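The hard-coded sample-input shapes in the next cell follow from the preprocessing constants, and because `torch_neuronx.trace` compiles for the exact input shapes it is given, every input at inference time must produce these same sizes. The arithmetic below reproduces them. The 4x4 spatial image patching with no temporal downsampling is an assumption about the pretrained preprocessors that is consistent with the 50176 image positions used for tracing; `SAMPLES_PER_PATCH = 16` and `nchunks = 128` come from this notebook.

```python
FRAMES, HEIGHT, WIDTH = 16, 224, 224
AUDIO_SAMPLES = FRAMES * (48000 // 25)  # 30720 samples for 16 frames at 25 fps, 48 kHz
SAMPLES_PER_PATCH = 16
SPATIAL_PATCH = 4                       # assumption: 4x4 image patches, no temporal downsampling
NCHUNKS = 128

audio_tokens = AUDIO_SAMPLES // SAMPLES_PER_PATCH                             # 1920
image_tokens = FRAMES * (HEIGHT // SPATIAL_PATCH) * (WIDTH // SPATIAL_PATCH)  # 50176
label_tokens = 1
encoder_seq_len = audio_tokens + image_tokens + label_tokens                  # 52097

# Decoder output points per chunk, as computed in the inference loop.
image_chunk_size = FRAMES * HEIGHT * WIDTH // NCHUNKS                         # 6272
audio_chunk_size = AUDIO_SAMPLES // SAMPLES_PER_PATCH // NCHUNKS              # 15

print(encoder_seq_len, image_chunk_size, audio_chunk_size)  # 52097 6272 15
```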
In [None]:
model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver", 
                                                                   low_cpu_mem_usage=True)
COMPILER_WORKDIR_ROOT="perceiver_multimodal_compile_dir"

PerceiverForMultimodalAutoencoding.forward = custom_model_forward
PerceiverBasicDecoder.decoder_query = custom_decoder_query


# --- Compile Encoder ---
# Define sample inputs for tracing encoder
embedding_output = torch.randn(1, 784, 512)
sample_inputs = torch.randn(1, 52097, 704)
extended_attention_mask = torch.zeros(1, 1, 1, 52097)

# Wrap and trace the encoder, save the traced encoder
COMPILER_WORKDIR_ENCODER = os.path.join(COMPILER_WORKDIR_ROOT, "encoder")
neuron_encoder = NeuronEncoder(EncoderWrapper(model.perceiver.encoder))

# You might see warnings from trace about unused inputs; these are safe to ignore.
neuron_encoder.encoder_wrapper = torch_neuronx.trace(
  neuron_encoder.encoder_wrapper,
  (embedding_output, sample_inputs, extended_attention_mask),
  compiler_workdir=COMPILER_WORKDIR_ENCODER,
  compiler_args=[f"--temp-dir={COMPILER_WORKDIR_ENCODER}", "--auto-cast=none"] # --auto-cast=none is needed to avoid numerical error.
)

# Save compiled encoder
encoder_fname = os.path.join(COMPILER_WORKDIR_ENCODER, 'model.pt')
torch.jit.save(neuron_encoder.encoder_wrapper, encoder_fname)


# --- Compile Decoder ---
# Define sample inputs for tracing decoder
z = torch.randn(1, 784, 512)
query_mask = torch.zeros(1, 1, 1, 52097)

audio_input = torch.randn(1, 1920, 704)
audio_input_without_pos = torch.randn(1, 1920, 16)
audio_subsampled_point = torch.arange(0, 15, dtype=torch.float32) # 15 = 1920/128
audio_padding = torch.randn(1, 641)

image_input = torch.randn(1, 50176, 704)
image_input_without_pos = torch.randn(1, 50176, 48)
image_subsampled_point = torch.arange(0, 6272, dtype=torch.float32) # 6272 = 224*224*16/128
image_padding = torch.randn(1, 831)

label_input = torch.randn(1, 1, 704)
label_input_without_pos = torch.randn(1, 1, 700)
label_padding = torch.randn(1, 2)

# Wrap and trace the decoder, save the traced decoder
COMPILER_WORKDIR_DECODER = os.path.join(COMPILER_WORKDIR_ROOT, "decoder")
neuron_decoder = NeuronDecoder(DecoderWrapper(model.perceiver.decoder, model.perceiver.decoder.modalities['audio'].decoder_query, \
                                              model.perceiver.decoder.modalities['image'].decoder_query, model.perceiver.decoder.modalities['label'].decoder_query, \
                                              model.perceiver.output_postprocessor))

# You might see warnings from trace about unused inputs; these are safe to ignore.
neuron_decoder.decoder_wrapper = torch_neuronx.trace(
   neuron_decoder.decoder_wrapper,
   (z, query_mask, audio_input, audio_input_without_pos, audio_subsampled_point, audio_padding,
        image_input, image_input_without_pos, image_subsampled_point, image_padding,
        label_input, label_input_without_pos, label_padding),
   compiler_workdir=COMPILER_WORKDIR_DECODER,
   compiler_args=[f"--temp-dir={COMPILER_WORKDIR_DECODER}", "--auto-cast=none"] # --auto-cast=none is needed to avoid numerical error.
)

# Save compiled decoder
decoder_fname = os.path.join(COMPILER_WORKDIR_DECODER, 'model.pt')
torch.jit.save(neuron_decoder.decoder_wrapper, decoder_fname)


**Load the Compiled Models and Run Inference**

Now that the encoder and decoder are compiled, we can load them and run inference on the preprocessed images and audio. For each image/audio pair, the model returns a `reconstruction` dictionary with three entries: `audio`, the reconstructed audio tensor; `image`, the reconstructed image tensor; and `label`, the classification logits. We print the top 3 predicted labels for each input video along with the inference latency.

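The top-3 step at the end of the inference loop is a softmax over the label logits followed by a top-k selection. Here is a NumPy-only sketch of the same computation on made-up logits; `top_k_labels` and the four-class `id2label` mapping are hypothetical stand-ins for the `torch.topk` call and `model.config.id2label`.

```python
import numpy as np

def top_k_labels(logits, id2label, k=3):
    """Return [(label, probability), ...] for the k most likely classes of one example."""
    shifted = logits - logits.max()                  # subtract max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()  # softmax
    top = np.argsort(probs)[::-1][:k]                # indices by descending probability
    return [(id2label[int(i)], float(probs[i])) for i in top]

id2label = {0: "archery", 1: "juggling", 2: "surfing", 3: "typing"}  # hypothetical labels
logits = np.array([2.0, 0.5, 3.0, -1.0])

for label, prob in top_k_labels(logits, id2label):
    print(f"{label}: {prob:.3f}")
```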
In [None]:
# -- Load compiled models --
model = PerceiverForMultimodalAutoencoding.from_pretrained("deepmind/multimodal-perceiver", 
                                                                  low_cpu_mem_usage=True)

# load saved encoder from disk
encoder_fname = os.path.join(COMPILER_WORKDIR_ENCODER, 'model.pt')
neuron_encoder = NeuronEncoder(EncoderWrapper(model.perceiver.encoder))
neuron_encoder.encoder_wrapper = torch.jit.load(encoder_fname)
model.perceiver.encoder = neuron_encoder

# load saved decoder from disk
decoder_fname = os.path.join(COMPILER_WORKDIR_DECODER, 'model.pt')
neuron_decoder = NeuronDecoder(DecoderWrapper(model.perceiver.decoder, model.perceiver.decoder.modalities['audio'].decoder_query, \
                                              model.perceiver.decoder.modalities['image'].decoder_query, model.perceiver.decoder.modalities['label'].decoder_query, \
                                              model.perceiver.output_postprocessor))
neuron_decoder.decoder_wrapper = torch.jit.load(decoder_fname)

# Inference function
def autoencode_video(images, audio, nchunks, image_chunk_size, audio_chunk_size):
    input_image = torch.from_numpy(np.moveaxis(images, -1, 2)).to(torch.float32)
    input_audio = torch.from_numpy(audio)
    input_label = torch.zeros((images.shape[0], 700))

    inputs = {'image': input_image, 'audio': input_audio, 'label':input_label}

    reconstruction = {}
    with torch.no_grad():
        reconstruction = model(nchunks, image_chunk_size, audio_chunk_size, neuron_decoder, inputs=inputs)

    # reshape image and audio modalities back to original shape
    reconstruction['image'] = torch.reshape(reconstruction['image'], images.shape)
    reconstruction['audio'] = torch.reshape(reconstruction['audio'], audio.shape)
    return reconstruction

nchunks = 128
reconstructions = []
for audio, image in zip(preprocessed_audios, preprocessed_images):
    image_chunk_size = np.prod(image.shape[1:-1]) // nchunks
    audio_chunk_size = audio.shape[1] // SAMPLES_PER_PATCH // nchunks

    start = time.time()
    reconstruction = autoencode_video(image, audio, nchunks, image_chunk_size, audio_chunk_size)
    print(f"Inference latency is {time.time()-start} seconds")
    reconstructions.append(reconstruction)

    # Print top 3 predicted labels
    scores, indices = torch.topk(torch.softmax(reconstruction["label"], dim=1), k=3)
    for score, index in zip(scores[0], indices[0]):
        print("%s: %s" % (model.config.id2label[index.item()], score.item()))

We can also visualize any of the reconstructed videos:

In [None]:
# Visualize reconstruction of first 16 frames
table([to_gif(reconstructions[0]["image"][0].numpy()), play_audio(np.array(reconstructions[0]["audio"][0].numpy()))])