# GPT-2 private inference with SPU

In the lab [Neural Network with SPU](./nn_with_spu.ipynb), we demonstrated how to use SecretFlow/SPU to train a neural network model privately.

In this lab, we showcase how to run private inference on a pre-trained [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) model for text generation with SPU.

First, we show how to generate text with the pre-trained GPT-2 model using JAX and the Hugging Face Transformers library. Then we show how to run the same text generation privately on SPU, with only minor modifications to the plaintext code.

> The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production.

> This tutorial may require more resources than a machine with 16 CPU cores and 48 GB of memory.

## Text generation using GPT-2 with JAX/FLAX
### Install the transformers library

```python
import sys

!{sys.executable} -m pip install transformers[flax]
```


> The JAX version required by `transformers` conflicts with the one required by SPU, but this conflict does not prevent this example from running with SPU.

### Load the pre-trained GPT-2 Model

See the Hugging Face [GPT-2 documentation](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) for more details.

```python
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel, GPT2Config

tokenizer = AutoTokenizer.from_pretrained("gpt2")
pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
```


### Define the text generation function


We use a [greedy search strategy](https://huggingface.co/blog/how-to-generate) for text generation here.

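```python
import jax.numpy as jnp


def text_generation(input_ids, params):
    # Build the model from the default GPT-2 config; the weights are passed in via `params`.
    config = GPT2Config()
    model = FlaxGPT2LMHeadModel(config=config)

    # Greedy search: append the most likely next token, 10 times.
    for _ in range(10):
        outputs = model(input_ids=input_ids, params=params)
        # Logits at the last position of the first (and only) sequence in the batch.
        next_token_logits = outputs[0][0, -1, :]
        next_token = jnp.argmax(next_token_logits)
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids
```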
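To see the greedy loop in isolation, here is a minimal sketch that swaps GPT-2 for a hypothetical toy model whose next-token logits depend only on the last token. `BIGRAM_LOGITS`, `toy_logits` and `greedy_generate` are illustrative names, not part of Transformers. Each step takes the `argmax` of the last position's logits and appends it, so the sequence grows by exactly one token per iteration.

```python
import jax
import jax.numpy as jnp

VOCAB_SIZE = 5
# Hypothetical toy "model": after token i, token (i + 1) % VOCAB_SIZE has the highest logit.
BIGRAM_LOGITS = jnp.roll(jnp.eye(VOCAB_SIZE), shift=1, axis=1)


def toy_logits(input_ids):
    # Shape (batch, seq_len, vocab), like outputs[0] of the Flax GPT-2 model.
    return BIGRAM_LOGITS[input_ids]


def greedy_generate(input_ids, num_new_tokens):
    for _ in range(num_new_tokens):
        next_token = jnp.argmax(toy_logits(input_ids)[0, -1, :])
        input_ids = jnp.concatenate([input_ids, jnp.array([[next_token]])], axis=1)
    return input_ids


prompt = jnp.array([[0, 1]])
print(greedy_generate(prompt, 4).tolist())  # [[0, 1, 2, 3, 4, 0]]

# The loop count is a Python constant, so tracing unrolls it into a fixed-shape program.
jitted = jax.jit(greedy_generate, static_argnums=1)
print(jitted(prompt, 4).tolist())  # [[0, 1, 2, 3, 4, 0]]
```

The `jax.jit` call is only an analogy: SPU also works from a traced JAX program, which is why a loop with a fixed Python bound like `range(10)` is unproblematic, while Python branching on (secret) array values would not trace.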

### Run text generation on CPU

```python
import jax.numpy as jnp

inputs_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')
outputs_ids = text_generation(inputs_ids, pretrained_model.params)

print('-' * 65 + '\nRun on CPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
```


Here we generate 10 new tokens. Keep the generated text in mind: in the next step we generate text on SPU and compare the two results.


### Run text generation on SPU

```python
import secretflow as sf

# In case you have a running secretflow runtime already.
sf.shutdown()

# Start a local simulation with three parties.
sf.init(['alice', 'bob', 'carol'], address='local')

# alice provides the model parameters; bob provides the prompt.
alice, bob = sf.PYU('alice'), sf.PYU('bob')

# One SPU device run jointly by all three parties.
conf = sf.utils.testing.cluster_def(['alice', 'bob', 'carol'])
# Runtime options used in this demo (exp approximation mode, matmul splitting).
conf['runtime_config']['fxp_exp_mode'] = 1
conf['runtime_config']['experimental_disable_mmul_split'] = True
spu = sf.SPU(conf)


def get_model_params():
    # Runs on alice: load the pre-trained GPT-2 weights.
    pretrained_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
    return pretrained_model.params


def get_token_ids():
    # Runs on bob: tokenize the prompt.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer.encode('I enjoy walking with my cute dog', return_tensors='jax')


model_params = alice(get_model_params)()
input_token_ids = bob(get_token_ids)()

# Move both inputs from their owners' PYUs to the SPU device.
device = spu
model_params_, input_token_ids_ = model_params.to(device), input_token_ids.to(device)

# Run the unchanged text_generation function on SPU.
output_token_ids = spu(text_generation)(input_token_ids_, model_params_)
```
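Calling `.to(spu)` turns each PYU object into an SPU object whose value is split into secret shares held by the parties. The sketch below illustrates the general idea with plain 3-party additive secret sharing over the ring $\mathbb{Z}_{2^{64}}$ plus a fixed-point encoding. It is a simplified assumption for illustration only: SPU's actual protocol, ring size and number of fractional bits depend on the runtime configuration, and `FRAC_BITS`, `encode`, `decode`, `share` and `reconstruct` are hypothetical helpers, not SecretFlow APIs.

```python
import numpy as np

FRAC_BITS = 18  # hypothetical number of fractional bits, for illustration only
RING_MAX = np.iinfo(np.uint64).max  # 2**64 - 1


def encode(x):
    """Float -> fixed-point integer in Z_{2^64} (two's complement)."""
    scaled = np.round(np.asarray(x, dtype=np.float64) * 2**FRAC_BITS)
    return scaled.astype(np.int64).view(np.uint64)


def decode(v):
    """Fixed-point integer in Z_{2^64} -> float."""
    return v.view(np.int64).astype(np.float64) / 2**FRAC_BITS


def share(v, n_parties, rng):
    """Split v into n additive shares that sum to v modulo 2^64."""
    shares = [rng.integers(0, RING_MAX, size=v.shape, dtype=np.uint64, endpoint=True)
              for _ in range(n_parties - 1)]
    last = v.copy()
    for s in shares:
        last -= s  # uint64 array arithmetic wraps modulo 2^64
    return shares + [last]


def reconstruct(shares):
    total = np.zeros_like(shares[0])
    for s in shares:
        total += s
    return total


rng = np.random.default_rng(0)
x = np.array([1.5, -2.25, 0.0])
y = np.array([0.25, 1.0, -3.5])
x_shares = share(encode(x), 3, rng)
y_shares = share(encode(y), 3, rng)

# Any two shares look uniformly random; all three together recover x.
print(decode(reconstruct(x_shares)))
# Addition needs no communication: each party adds its own two shares locally.
sum_shares = [a + b for a, b in zip(x_shares, y_shares)]
print(decode(reconstruct(sum_shares)))
```

Linear operations such as this addition are cheap on shares; multiplications and non-linear functions (for example the `exp` inside softmax) require interaction between the parties, which is one reason private inference on GPT-2 is much slower than the CPU run.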





### Check the SPU output

As you can see, running GPT-2 inference on SPU takes only a few more lines of code than the CPU version. Now let's reveal the text generated by the SPU program.

```python
outputs_ids = sf.reveal(output_token_ids)
print('-' * 65 + '\nRun on SPU:\n' + '-' * 65)
print(tokenizer.decode(outputs_ids[0], skip_special_tokens=True))
print('-' * 65)
```
-----------------------------------------------------------------
Run on SPU:
-----------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever
-----------------------------------------------------------------


As we can see, the text generated on SPU is exactly the same as the text generated on CPU!
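Why can the two outputs agree even though SPU computes with fixed-point numbers? Greedy decoding only uses the `argmax` of each step's logits, so a small numerical error changes nothing unless the top two logits are closer than that error. The sketch below illustrates this with a simulated fixed-point rounding; `to_fixed_point` and the logit values are hypothetical, and the sketch ignores the additional error SPU introduces when approximating non-linear functions such as `exp`. Agreement is therefore typical, not guaranteed.

```python
import jax.numpy as jnp


def to_fixed_point(x, frac_bits):
    # Simulated fixed-point encoding: round to the nearest multiple of 2**-frac_bits.
    scale = 2.0 ** frac_bits
    return jnp.round(x * scale) / scale


# Hypothetical logits whose top two entries differ by only 1e-4.
logits = jnp.array([3.0999, 3.1])

print(int(jnp.argmax(logits)))                      # 1
print(int(jnp.argmax(to_fixed_point(logits, 18))))  # 1: resolution 2**-18 < 1e-4
print(int(jnp.argmax(to_fixed_point(logits, 8))))   # 0: both round to 794/256, argmax takes the first
```

With enough fractional bits the ranking survives the rounding; with too few, near-ties collapse and the greedy choice can flip, after which the two generated texts diverge.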

This is the end of the lab.