## Install all requirements.

In [None]:
# Install the packages listed in requirements.txt (IPython `!` shell escape).
!pip install -r requirements.txt

## Define seq2seq model.
### All the important details of the TensorRT initialization can be found in `utils/trt_model.py`.

All our engines have 57 inputs and 57 outputs. The first input is the tensor of token ids; all other inputs are context (the `history_key_*` / `history_value_*` tensors). To send an empty context, create an empty tensor of shape (1, 16, **0**, 256) for every context input. The first output is the logits; all other outputs are context. The engines return only the new context values, so before passing the context to the next step you must manually concatenate the previous context values with the new ones along the `history_length` axis (see the example code below).

Input names and shapes:
<ol>
    <li>input_ids (1, sequence_length)</li>
    <li>history_key_0 (1, 16, history_length, 256)</li>
    <li>history_value_0 (1, 16, history_length, 256)</li>
    <li>history_key_1 (1, 16, history_length, 256)</li>
    <li>history_value_1 (1, 16, history_length, 256)</li>
</ol>
...
<ol start="56">
    <li>history_key_27 (1, 16, history_length, 256)</li>
    <li>history_value_27 (1, 16, history_length, 256)</li>
</ol>

Output names and shapes:
<ol>
    <li>logits (1, sequence_length, 50400)</li>
    <li>out_history_key_0 (1, 16, sequence_length, 256)</li>
    <li>out_history_value_0 (1, 16, sequence_length, 256)</li>
    <li>out_history_key_1 (1, 16, sequence_length, 256)</li>
    <li>out_history_value_1 (1, 16, sequence_length, 256)</li>
</ol>
...
<ol start="56">
    <li>out_history_key_27 (1, 16, sequence_length, 256)</li>
    <li>out_history_value_27 (1, 16, sequence_length, 256)</li>
</ol>

`sequence_length` – dynamic axis whose value must lie in the range \[1, 512\]

`history_length` – dynamic axis whose value must lie in the range \[0, 512\]

In [None]:
import torch

from pathlib import Path
from typing import Union

from utils.trt_model import TrtModel


class TrtSeq2SeqModel:
    """Greedy token generator on top of a TensorRT engine (see utils/trt_model.py).

    Every engine call takes ``input_ids`` plus ``history_*`` context tensors and
    returns ``logits`` plus ``out_history_*`` tensors that contain only the
    context of the tokens just processed. ``generate`` keeps the running
    context by concatenating those new values onto the previous ones.
    """

    def __init__(self, path_to_engine: Union[Path, str]):
        """Load the TensorRT engine stored at ``path_to_engine``."""
        self._model = TrtModel(str(path_to_engine))

    @property
    def batch_size(self) -> int:
        """Leading dimension of the engine's ``input_ids`` binding shape."""
        return self._model.binding_shape('input_ids')[0]

    def generate(self, input_ids: torch.Tensor, generate_len: int, return_logit: bool = False) -> torch.Tensor:
        """Greedily generate ``generate_len`` tokens following ``input_ids``.

        Args:
            input_ids: Prompt token ids of shape ``(batch, prompt_len)``. The
                notebook callers pass an int32 CUDA tensor; presumably that is
                what ``TrtModel.run`` requires -- confirm in utils/trt_model.py.
            generate_len: Number of decoding steps (engine calls); must be >= 1.
            return_logit: If True, return the logits of every step instead of
                the selected token ids.

        Returns:
            If ``return_logit`` is False: int32 tensor ``(batch, generate_len)``
            holding the greedily chosen ids (the prompt is not included).
            If True: the per-step logits concatenated along the sequence axis,
            i.e. ``(batch, prompt_len + generate_len - 1, vocab)`` -- the first
            step scores the whole prompt, each later step a single token.

        Raises:
            ValueError: If ``generate_len`` is less than 1.

        Note:
            Per the engine notes, ``sequence_length`` must lie in [1, 512] and
            ``history_length`` in [0, 512]; the last call sees a history of
            ``prompt_len + generate_len - 2`` positions.
        """
        if generate_len < 1:
            # Without this check torch.cat would fail on an empty result list
            # with an unrelated-looking error.
            raise ValueError(f'generate_len must be a positive integer, got {generate_len!r}')

        input_ids = input_ids.contiguous()
        # The empty context must share the prompt's batch dimension, so take it
        # from the actual input rather than from the engine binding, which may
        # report a placeholder (e.g. -1) for a dynamic batch axis -- TODO confirm
        # in utils/trt_model.py.
        batch = input_ids.shape[0]

        input_tensors = {'input_ids': input_ids}
        for name in self._model.inputs:
            if name.startswith('history'):
                # Empty context (history_length == 0) for the first run.
                input_tensors[name] = torch.empty(
                    size=(batch, 16, 0, 256),
                    dtype=self._model.binding_dtype(name),
                    device='cuda',
                )

        result = []
        output_tensors = None
        for step in range(generate_len):
            # Passing the previous outputs back presumably lets TrtModel reuse
            # their buffers -- hence the clone() of the logits below.
            output_tensors = self._model.run(input_tensors=input_tensors, output_tensors_cache=output_tensors)

            logits = output_tensors['logits']
            # Greedy pick at the last position; argmax yields int64, so cast to
            # int32 to match the dtype the prompt ids are fed with.
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True).to(torch.int32)
            result.append(logits.clone() if return_logit else next_id)

            if step == generate_len - 1:
                # No engine call follows, so growing the context once more
                # would only copy every history tensor for nothing.
                break

            input_tensors['input_ids'] = next_id
            # The engine returns only the context of the new tokens: append it
            # to the accumulated context along the sequence (history) axis.
            for name, new_value in output_tensors.items():
                if name.startswith('out_history_'):
                    history_name = name[len('out_'):]
                    input_tensors[history_name] = torch.cat((input_tensors[history_name], new_value), dim=-2)

        # Ids are (batch, 1) per step -> join on the last axis; logits are
        # (batch, seq, vocab) per step -> join on the sequence axis.
        dim = -2 if return_logit else -1
        return torch.cat(result, dim=dim)

## Download a prebuilt engine and initialize the seq2seq model.

**Prebuilt engines are currently available only for the following GPUs:**
* rtx2080ti
* rtx3080ti
* rtx4090

**The ONNX model and the engine build script will be published later.**

In [None]:
from utils.engine import get_engine

# get_engine() lives in utils/engine.py; it presumably downloads (or locates) a
# prebuilt engine for the local GPU and returns its path -- confirm there.
path_to_engine = get_engine()
# Notebook-global model instance; predict_last_id in a later cell relies on it.
model = TrtSeq2SeqModel(path_to_engine=path_to_engine)

## Seq2seq example.

In [None]:
from transformers import AutoTokenizer

# GPT-J tokenizer; presumably matches the engine's 50400-entry vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')

input_text = 'Hello world!'

# Tokenize into a (1, prompt_len) tensor, then move it to the GPU as int32,
# the form in which the prompt ids are handed to the engine here.
input_ids = tokenizer(input_text, return_tensors='pt')
input_ids = input_ids['input_ids'].to(device='cuda', dtype=torch.int32)
# Greedily generate 100 new token ids -> shape (1, 100).
generated_ids = model.generate(input_ids, generate_len=100)
# Unpack the single batch row (raises if the batch is not exactly 1).
(generated_ids,) = generated_ids
generated_text = tokenizer.decode(generated_ids)

# generate() returns only the continuation, so prepend the prompt text.
print(input_text + generated_text)

## Accuracy test.

In [None]:
from utils.test import test_acc


def predict_last_id(input_ids: torch.Tensor) -> torch.Tensor:
    """Predict the next token id for each row of ``input_ids`` using the global ``model``.

    The ids are moved to the GPU as int32, a single greedy decoding step is
    run, and the resulting ``(batch, 1)`` int32 tensor is returned on the CPU,
    detached from any autograd graph.
    """
    prompt = input_ids.to(device='cuda', dtype=torch.int32)
    next_ids = model.generate(prompt, generate_len=1)
    return next_ids.detach().cpu()


# Evaluate the predictor with the accuracy harness from utils.test.
test_acc(predict_last_id, verbose=True);