In [1]:
import mindspore as ms
# Machine-specific Ascend card index; change it to a card available on your host.
DEVICE_ID = 7
# Graph (static) execution mode on an Ascend device; ms.GRAPH_MODE is the named form of mode=0.
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend", device_id=DEVICE_ID)

In [2]:
from mindformers import GPT2LMHeadModel, GPT2Tokenizer, TextStreamer, TextIteratorStreamer
# Load tokenizer and LM-head model from one checkpoint name so they cannot drift apart.
MODEL_NAME = "gpt2"
tok = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)

2023-09-11 15:40:46,517 - mindformers - INFO - config in the yaml file ./checkpoint_download/gpt2/gpt2.yaml are used for tokenizer building.
2023-09-11 15:40:46,538 - mindformers - INFO - build tokenizer class name is: GPT2Tokenizer using args {'unk_token': '<|endoftext|>', 'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'vocab_file': './checkpoint_download/gpt2/vocab.json', 'merges_file': './checkpoint_download/gpt2/merges.txt'}.


If you want to enable it, please use semi auto or auto parallel mode by context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)


2023-09-11 15:40:47,800 - mindformers - INFO - start to read the ckpt file: 497772028
2023-09-11 15:40:49,148 - mindformers - INFO - weights in ./checkpoint_download/gpt2/gpt2.ckpt are loaded
2023-09-11 15:40:49,152 - mindformers - INFO - model built successfully!


In [62]:
# Batch of eight English prompts used by the generation demos below.
text2 = [
    "An increasing sequence: one,",
    "I love Beijing, because",
    "The highest mountain in the world is",
    "The largest country in the world is",
    "The largest river in China is",
    "Singapore is located",
    "The population in China is",
    "The atomic bombs were invented in",
]

In [63]:
# Tokenize the whole prompt batch into equal-length id lists (right-padded with 50256,
# the <|endoftext|> pad id). Every prompt in text2 is shorter than 8 ids, so each row ends in padding.
# NOTE(review): truncation is not requested; a prompt longer than 8 ids would presumably
# yield a longer row — confirm tokenizer behavior before adding longer prompts.
inputs2 = tok(
    text2,
    max_length=8,
    padding="max_length",
    return_tensors=None,        # plain Python lists instead of framework tensors
    add_special_tokens=False,   # no extra BOS/EOS around the prompts
)

In [64]:
# Full padded batch (one id list per prompt) and its first row for the single-prompt demos.
input_batch2 = inputs2["input_ids"]
input_batch = input_batch2[0]

In [65]:
# First prompt only: its token ids followed by trailing pad ids (50256).
input_batch

[2025, 3649, 8379, 25, 530, 11, 50256, 50256]

In [66]:
# Whole batch: one right-padded id list of length 8 per prompt.
input_batch2

[[2025, 3649, 8379, 25, 530, 11, 50256, 50256],
 [40, 1842, 11618, 11, 780, 50256, 50256, 50256],
 [464, 4511, 8598, 287, 262, 995, 318, 50256],
 [464, 4387, 1499, 287, 262, 995, 318, 50256],
 [464, 4387, 7850, 287, 2807, 318, 50256, 50256],
 [29974, 11656, 318, 5140, 50256, 50256, 50256, 50256],
 [464, 3265, 287, 2807, 318, 50256, 50256, 50256],
 [464, 17226, 12134, 547, 15646, 287, 50256, 50256]]

In [67]:
# Console streamer; skip_prompt=False echoes the prompt text before the generated continuation.
streamer = TextStreamer(tok, skip_prompt=False)

In [71]:
# No thread, single input.
# max_length=20 appears to bound prompt + generated tokens (6 prompt ids + 14 generated here).
# top_k=1 keeps only the most likely token, i.e. effectively greedy decoding.
_ = model.generate(input_batch, streamer=streamer, max_length=20, top_k=1)

An increasing sequence: one, two, three, four, five, six, seven, eight,
2023-09-11 15:53:09,303 - mindformers - INFO - total time: 1.645613431930542 s; generated tokens: 14 tokens; generate speed: 8.507465804758278 tokens/s


In [73]:
# No thread, batch input.
# NOTE(review): `output` is reassigned by the threaded batch cell later on; use a distinct
# name if this non-threaded result needs to be kept for comparison.
output = model.generate(input_batch2, streamer=streamer, max_length=20, top_k=1)

['An increasing sequence: one, two, three, four, five, six, seven, eight,', "I love Beijing, because it's a beautiful city. It's a beautiful city. It's a", 'The highest mountain in the world is the Himalayas, which is the highest mountain in the world', 'The largest country in the world is the United States, with a population of 1.3 billion.', 'The largest river in China is the Yangtze River, which flows through the heart of the country', 'Singapore is located in the middle of the South China Sea, and is a major shipping route for', 'The population in China is growing at a rate of about 1.5 percent per year, according to', 'The atomic bombs were invented in the early 1950s by the United States and Britain. The United States']
2023-09-11 15:53:37,407 - mindformers - INFO - total time: 11.205190181732178 s; generated tokens: 114 tokens; generate speed: 10.173856770932296 tokens/s


### 开线程使用 (streaming with a background generation thread)

In [95]:
# Iterator-style streamer: a consumer loop pulls decoded text chunks while generate() runs
# in a background thread. With default settings the prompt text is included in the stream.
# NOTE(review): this single instance is reused for both the single-prompt and the batch
# threaded runs; creating a fresh streamer per generation would avoid carried-over state.
iter_streamer = TextIteratorStreamer(tok)

In [96]:
# Single input: keyword arguments for running generate() on a background thread.
from threading import Thread

generation_kwargs = {
    "input_ids": input_batch,
    "streamer": iter_streamer,
    "max_length": 20,
    "top_k": 1,
}

In [97]:
# Launch generation in the background so the next cell can iterate over iter_streamer
# while tokens are being produced.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

In [98]:
# Drain the iterator streamer; iteration stops once the streamer signals the end of generation.
generated_text = "".join(iter_streamer)
# Wait for the background generate() call to fully return before the model is used again
# (the batch cells below start another generate() on the same model).
thread.join()
generated_text

2023-09-11 16:03:09,849 - mindformers - INFO - total time: 1.6586799621582031 s; generated tokens: 14 tokens; generate speed: 8.440446812767787 tokens/s


'An increasing sequence: one, two, three, four, five, six, seven, eight,'

In [107]:
# Batch input: same settings as the single-input run, but over the whole padded batch.
# Thread is re-imported so this cell does not rely on the single-input cell having run.
from threading import Thread

generation_kwargs = {
    "input_ids": input_batch2,
    "streamer": iter_streamer,
    "max_length": 20,
    "top_k": 1,
}

In [108]:
# Launch the batch generation in the background; the next cell consumes iter_streamer.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

In [109]:
# Accumulate streamed chunks per prompt. In batch mode each item yielded by iter_streamer
# is a list holding one new text chunk per prompt.
output = [""] * len(input_batch2)
print("thread.start! ")
for new_text in iter_streamer:
    print("streamer_batch:", new_text)
    # Fail loudly instead of silently dropping extra chunks or misaligning prompts.
    if len(new_text) != len(output):
        raise ValueError(
            f"expected {len(output)} chunks per step, got {len(new_text)}"
        )
    for i, chunk in enumerate(new_text):
        output[i] += chunk
# Wait for the background generate() call to return before the result is used.
thread.join()
print(" ===== For loop over! ===== ")

thread.start! 
streamer_batch: ['An increasing sequence: one,', 'I love Beijing, because', 'The highest mountain in the world is', 'The largest country in the world is', 'The largest river in China is', 'Singapore is located', 'The population in China is', 'The atomic bombs were invented in']
streamer_batch: [' two', ' it', ' the', ' the', ' the', ' in', ' growing', ' the']
streamer_batch: [',', "'s", ' Himal', ' United', ' Yang', ' the', ' at', ' early']
streamer_batch: [' three', ' a', 'ay', ' States', 't', ' middle', ' a', ' 1950']
streamer_batch: [',', ' beautiful', 'as', ',', 'ze', ' of', ' rate', 's']
streamer_batch: [' four', ' city', ',', ' with', ' River', ' the', ' of', ' by']
streamer_batch: [',', '.', ' which', ' a', ',', ' South', ' about', ' the']
streamer_batch: [' five', ' It', ' is', ' population', ' which', ' China', ' 1', ' United']
streamer_batch: [',', "'s", ' the', ' of', ' flows', ' Sea', '.', ' States']
streamer_batch: [' six', ' a', ' highest', ' 1', ' through'

In [110]:
# Per-prompt texts accumulated from the threaded batch stream (prompt + continuation).
output

['An increasing sequence: one, two, three, four, five, six, seven, eight,',
 "I love Beijing, because it's a beautiful city. It's a beautiful city. It's a",
 'The highest mountain in the world is the Himalayas, which is the highest mountain in the world',
 'The largest country in the world is the United States, with a population of 1.3 billion.',
 'The largest river in China is the Yangtze River, which flows through the heart of the country',
 'Singapore is located in the middle of the South China Sea, and is a major shipping route for',
 'The population in China is growing at a rate of about 1.5 percent per year, according to',
 'The atomic bombs were invented in the early 1950s by the United States and Britain. The United States']

In [112]:
# Reference texts: identical to the list printed by the non-threaded batch generate() call.
a = [
    'An increasing sequence: one, two, three, four, five, six, seven, eight,',
    "I love Beijing, because it's a beautiful city. It's a beautiful city. It's a",
    'The highest mountain in the world is the Himalayas, which is the highest mountain in the world',
    'The largest country in the world is the United States, with a population of 1.3 billion.',
    'The largest river in China is the Yangtze River, which flows through the heart of the country',
    'Singapore is located in the middle of the South China Sea, and is a major shipping route for',
    'The population in China is growing at a rate of about 1.5 percent per year, according to',
    'The atomic bombs were invented in the early 1950s by the United States and Britain. The United States',
]

In [113]:
# Threaded streaming must reproduce the non-threaded batch result exactly; an assert makes
# a Restart & Run All fail loudly instead of merely displaying False.
assert output == a, "threaded streaming output differs from the non-threaded batch output"

True