# GPT-2 Secure inference with Puma

In this lab, we showcase how to run 3PC secure inference on a pre-trained [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) model for text generation with Puma.

First, we show how to use JAX and the Hugging Face Transformers library for text generation with the pre-trained GPT-2 model. After that, we show how to use Puma on top of SPU for secure text generation with minor modifications to the plaintext counterpart.

> The following code is for demonstration only. It is **NOT for production** because of system security concerns; please **DO NOT** use it directly in production.

> This tutorial may need more resources than 16c48g (16 CPU cores, 48 GB of memory).

## What is Puma?

Puma is a fast and accurate end-to-end framework for 3-party secure inference of Transformer models.
Puma designs high-quality approximations for expensive functions, such as $\mathsf{GeLU}$ and $\mathsf{Softmax}$, which significantly reduce the cost of secure inference while preserving model performance. Additionally, we design secure $\mathsf{Embedding}$ and $\mathsf{LayerNorm}$ procedures that faithfully implement the desired functionality without undermining the Transformer architecture.
Puma is approximately $2\times$ faster than the state-of-the-art MPC framework MPCFormer (ICLR 2023) and achieves accuracy similar to plaintext models without fine-tuning, which previous works failed to achieve.

## Text generation using GPT-2 with JAX/FLAX
### Install the transformers library

In [1]:
import sys

!{sys.executable} -m pip install transformers[flax]

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple

> The JAX version required by transformers conflicts with the version SPU requires. For this example, it is fine to run SPU with the conflicting JAX version.

### Load the pre-trained GPT-2 Model

Please refer to this [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) for more details.

In [2]:
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")




For the GeLU hijack defined later in this lab to take effect, the GPT-2 model must call `jax.nn.gelu` directly instead of `self.act`, so you need to patch the Transformers source.
For example, in `transformers/src/transformers/models/gpt2/modeling_flax_gpt2.py`, line 296:

```python
hidden_states = self.act(hidden_states)
```

is changed to

```python
hidden_states = jax.nn.gelu(hidden_states)
```
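
The reason this source change matters can be sketched with plain Python. The hijack later in this lab swaps an attribute on a module, so only code that looks the function up at call time sees the replacement; a reference captured earlier (as `self.act` presumably is) keeps pointing at the original. In the sketch below, `fake_nn`, `BoundEarly`, and `LooksUpLate` are hypothetical stand-ins for illustration, not part of JAX or Transformers.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the `jax.nn` module; not the real library.
fake_nn = SimpleNamespace(gelu=lambda x: f"original({x})")


class BoundEarly:
    """Captures the activation once, similar to assigning `self.act` at setup."""

    def __init__(self):
        self.act = fake_nn.gelu

    def __call__(self, x):
        return self.act(x)


class LooksUpLate:
    """Resolves `fake_nn.gelu` on every call, like `jax.nn.gelu(hidden_states)`."""

    def __call__(self, x):
        return fake_nn.gelu(x)


early, late = BoundEarly(), LooksUpLate()

# Swap the attribute on the namespace, as a hijack context would.
fake_nn.gelu = lambda x: f"patched({x})"

early_result = early(1)
late_result = late(1)
print(early_result)  # original(1)
print(late_result)   # patched(1)
```

Only the late lookup picks up the patched function, which is why the call site is rewritten to go through `jax.nn.gelu`.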


### Define the text generation function


We use a [greedy search strategy](https://huggingface.co/blog/how-to-generate) for text generation here.

In [3]:
def text_generation(input_ids, params):
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)

    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids
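
The greedy loop above can be illustrated without any model. The following is a minimal sketch in plain Python: `toy_scores` and `TOY_TABLE` are hypothetical stand-ins for the GPT-2 forward pass, and the scores depend only on the last token to keep the example small.

```python
def greedy_decode(next_token_scores, prompt, steps):
    """Append the highest-scoring next token `steps` times."""
    tokens = list(prompt)
    for _ in range(steps):
        scores = next_token_scores(tokens)
        # argmax over the vocabulary: index of the largest score
        next_token = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(next_token)
    return tokens


# Hypothetical toy "model": scores depend only on the last token.
TOY_TABLE = {
    0: [0.1, 0.7, 0.2],
    1: [0.2, 0.1, 0.7],
    2: [0.6, 0.3, 0.1],
}


def toy_scores(tokens):
    return TOY_TABLE[tokens[-1]]


generated = greedy_decode(toy_scores, prompt=[0], steps=4)
print(generated)  # [0, 1, 2, 0, 1]
```

Because greedy search always takes the argmax, the output is deterministic, which makes it convenient for comparing plaintext and secure runs token by token.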

### Run text generation on CPU

In [4]:
import jax.numpy as jnp

inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)

print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)



-----------------------------------------------------------------
Run on CPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------


Here we generate 10 tokens. Keep the generated text in mind, because we are going to generate the same text on SPU in the next step.


### Run text generation on SPU

#### Import the necessary libraries and configure the compiler optimizations

In [None]:
import secretflow as sf
from typing import Any, Callable, Dict, Optional, Tuple, Union
import jax.nn as jnn
import flax.linen as nn
from flax.linen.linear import Array
import jax
import argparse
import spu.utils.distributed as ppd
import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
from contextlib import contextmanager

copts = spu_pb2.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
copts.enable_optimize_denominator_with_broadcast = True
Array = Any

# In case you have a running secretflow runtime already.
sf.shutdown()
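
The rewrite named in the comment above, `x / broadcast(y) -> x * broadcast(1/y)`, relies on a simple algebraic identity. The sketch below shows it in plain Python, assuming a matrix divided by one value per row (the shape that appears in a softmax normalization). It illustrates the identity only and is not how the SPU compiler implements the pass; the function names are hypothetical.

```python
def divide_rows(matrix, row_sums):
    """x / broadcast(y): one division per element."""
    return [[v / s for v in row] for row, s in zip(matrix, row_sums)]


def multiply_by_reciprocal(matrix, row_sums):
    """x * broadcast(1/y): one reciprocal per row, then only multiplications."""
    reciprocals = [1.0 / s for s in row_sums]
    return [[v * r for v in row] for row, r in zip(matrix, reciprocals)]


matrix = [[1.0, 2.0, 5.0], [4.0, 4.0, 8.0]]
row_sums = [sum(row) for row in matrix]  # one divisor per row

direct = divide_rows(matrix, row_sums)
rewritten = multiply_by_reciprocal(matrix, row_sums)

# Largest element-wise difference between the two formulations.
max_diff = max(
    abs(a - b) for ra, rb in zip(direct, rewritten) for a, b in zip(ra, rb)
)
print(max_diff)
```

The rewritten form needs one reciprocal per row instead of one division per element, which helps when division is considerably more expensive than multiplication, as is generally the case under MPC.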

#### Define the Softmax hijack function

In [None]:
def hack_softmax(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
    x = x - x_max

    # exp on large negative is clipped to zero
    b = x > -14
    nexp = jnp.exp(x)

    divisor = jnp.sum(nexp, axis, where=where, keepdims=True)

    return b * (nexp / divisor)


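@contextmanager
def hack_softmax_context(msg: str, enabled: bool = False):
    # `msg` is a descriptive label only; it is not used inside the context.
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_softmax = jnn.softmax
    jnn.softmax = hack_softmax
    try:
        yield
    finally:
        # restore the original function even if the body raises
        jnn.softmax = raw_softmax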
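
To see what the clipping in `hack_softmax` does to the result, here is a minimal plain-Python sketch that mirrors its arithmetic on a single vector and compares it with an unclipped softmax. The helper names `clipped_softmax` and `reference_softmax` are hypothetical and exist only for this illustration.

```python
import math


def clipped_softmax(logits, threshold=-14.0):
    """Softmax that zeroes entries whose shifted logit is not above `threshold`."""
    x_max = max(logits)
    shifted = [v - x_max for v in logits]
    keep = [1.0 if v > threshold else 0.0 for v in shifted]
    exps = [math.exp(v) for v in shifted]
    divisor = sum(exps)
    return [k * e / divisor for k, e in zip(keep, exps)]


def reference_softmax(logits):
    """Standard numerically stable softmax without clipping."""
    x_max = max(logits)
    exps = [math.exp(v - x_max) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, -20.0]
clipped = clipped_softmax(logits)
reference = reference_softmax(logits)
max_abs_diff = max(abs(c - r) for c, r in zip(clipped, reference))
print(clipped[2])  # 0.0, because the shifted logit -22 is below the threshold
print(max_abs_diff)
```

After subtracting the maximum, the divisor is at least 1, so any clipped entry had a probability below `exp(-14)` (under 1e-6). The mask therefore changes each output by less than that amount.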

#### Define the GeLU hijack function

In [None]:
def hack_gelu(
    x: Array,
    axis: Optional[Union[int, Tuple[int, ...]]] = -1,
    where: Optional[Array] = None,
    initial: Optional[Array] = None,
) -> Array:
    b0 = x < -4.0
    b1 = x < -1.95
    b2 = x > 3.0
    b3 = b1 ^ b2 ^ True  # x in [-1.95, 3.0]
    b4 = b0 ^ b1  # x in [-4, -1.95]

    # seg1 = a[3] * x^3 + a[2] * x^2 + a[1] * x + a[0]
    # seg2 = b[6] * x^6 + b[4] * x^4 + b[2] * x^2 + b[1] * x + b[0]
    a_coeffs = jnp.array(
        [
            -0.5054031199708174,
            -0.42226581151983866,
            -0.11807612951181953,
            -0.011034134030615728,
        ]
    )
    b_coeffs = jnp.array(
        [
            0.008526321541038084,
            0.5,
            0.3603292692789629,
            0.0,
            -0.037688200365904236,
            0.0,
            0.0018067462606141187,
        ]
    )
    x2 = jnp.square(x)
    x3 = jnp.multiply(x, x2)
    x4 = jnp.square(x2)
    x6 = jnp.square(x3)

    seg1 = a_coeffs[3] * x3 + a_coeffs[2] * x2 + a_coeffs[1] * x + a_coeffs[0]
    seg2 = (
        b_coeffs[6] * x6
        + b_coeffs[4] * x4
        + b_coeffs[2] * x2
        + b_coeffs[1] * x
        + b_coeffs[0]
    )

    ret = b2 * x + b4 * seg1 + b3 * seg2

    return ret


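@contextmanager
def hack_gelu_context(msg: str, enabled: bool = False):
    # `msg` is a descriptive label only; it is not used inside the context.
    if not enabled:
        yield
        return
    # hijack some target functions
    raw_gelu = jnn.gelu
    jnn.gelu = hack_gelu
    try:
        yield
    finally:
        # restore the original function even if the body raises
        jnn.gelu = raw_gelu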
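
As a sanity check on the piecewise polynomial used in `hack_gelu`, the sketch below evaluates the same coefficients in plain Python and measures the deviation from a reference GeLU on a grid. The choice of the tanh-form GeLU as the reference, the grid `[-6, 6]`, and the helper names are assumptions made for this illustration.

```python
import math

# Coefficients copied from hack_gelu; index i multiplies x**i.
A = [-0.5054031199708174, -0.42226581151983866,
     -0.11807612951181953, -0.011034134030615728]
B = [0.008526321541038084, 0.5, 0.3603292692789629, 0.0,
     -0.037688200365904236, 0.0, 0.0018067462606141187]


def poly(coeffs, x):
    """Evaluate sum(coeffs[i] * x**i) with Horner's rule."""
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * x + c
    return acc


def piecewise_gelu(x):
    """Scalar version of the four segments selected by b0..b4 in hack_gelu."""
    if x < -4.0:
        return 0.0
    if x < -1.95:
        return poly(A, x)
    if x <= 3.0:
        return poly(B, x)
    return x


def tanh_gelu(x):
    """Tanh-form GeLU, used here as the reference (an assumption)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


grid = [i / 100.0 for i in range(-600, 601)]
max_err = max(abs(piecewise_gelu(x) - tanh_gelu(x)) for x in grid)
print(f"max abs error on [-6, 6]: {max_err:.4f}")
```

The secure version computes all segments and combines them with the comparison bits, since branching on a secret value is not possible; the scalar `if` chain here is only a plaintext equivalent.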

#### Launch Puma on GPT-2

In [5]:
sf.init(['alice', 'bob', 'carol'], address='local')

alice, bob = sf.PYU('alice'), sf.PYU('bob')
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
conf['runtime_config']['protocol'] = 'ABY3'
conf['runtime_config']['field'] = 'FM64'
conf['runtime_config']['fxp_exp_mode'] = 0
conf['runtime_config']['fxp_exp_iters'] = 5

spu = sf.SPU(conf)


def get_model_params():
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

with hack_softmax_context("hijack jax softmax", enabled=True), hack_gelu_context(
    "hack jax gelu", enabled=True
):
    output_token_ids = spu(text_generation, copts=copts)(
        input_token_ids_, model_params_
    )
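
The configuration above selects the `ABY3` protocol over the 64-bit ring `FM64`. ABY3 builds on 2-out-of-3 replicated secret sharing, and the following plain-Python sketch shows the basic idea of sharing, local addition, and revealing. It is a simplified illustration, not SPU's implementation: `FRAC_BITS = 18` is a hypothetical fixed-point precision chosen for the example, and all function names are made up for this sketch.

```python
import secrets

RING_BITS = 64  # matches the 64-bit ring implied by 'FM64'
MODULUS = 1 << RING_BITS
FRAC_BITS = 18  # hypothetical fixed-point precision for this sketch


def encode(value):
    """Map a float to a fixed-point ring element."""
    return round(value * (1 << FRAC_BITS)) % MODULUS


def decode(element):
    """Map a ring element back to a float, treating the top half as negative."""
    if element >= MODULUS // 2:
        element -= MODULUS
    return element / (1 << FRAC_BITS)


def share(element):
    """Split into three additive shares; party i holds shares i and i+1."""
    s0 = secrets.randbelow(MODULUS)
    s1 = secrets.randbelow(MODULUS)
    s2 = (element - s0 - s1) % MODULUS
    parts = [s0, s1, s2]
    return [(parts[i], parts[(i + 1) % 3]) for i in range(3)]


def reveal(party_shares):
    """Reconstruct from the first component held by each party."""
    return sum(p[0] for p in party_shares) % MODULUS


def add(shares_a, shares_b):
    """Addition is local: each party adds its own share pairs."""
    return [
        ((a[0] + b[0]) % MODULUS, (a[1] + b[1]) % MODULUS)
        for a, b in zip(shares_a, shares_b)
    ]


x_shares = share(encode(1.5))
y_shares = share(encode(-0.25))
result = decode(reveal(add(x_shares, y_shares)))
print(result)  # 1.25
```

Each party's pair looks uniformly random on its own, so no single party learns the input; the value only appears once the shares are combined, which is what `sf.reveal` does for the generated token ids in the next step.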



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2023-06-15 17:08:14,157	INFO worker.py:1538 -- Started a local Ray instance.



### Check the Puma output

As you can see, it's very easy to run GPT-2 inference on Puma. Now let's reveal the generated text from Puma.

In [6]:
outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)

-----------------------------------------------------------------
Run on SPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------


As we can see, the text generated by Puma is exactly the same as the text generated on CPU!

This is the end of the lab.
For more benchmarks about Puma, please refer to: https://github.com/secretflow/spu/tree/main/examples/python/ml/flax_llama7b