In [25]:
from transformers import BertTokenizer, GPT2LMHeadModel
import torch
from tqdm import trange
import os
import numpy as np
import random
import time

In [26]:
def seed_everything(seed: int = 42):
    """Util to make training reproducible"""
    random.seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if os.getenv("CUBLAS_WORKSPACE_CONFIG") is not None:
        torch.use_deterministic_algorithms(True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [27]:
seed_everything()
tok_path = '..\\..\\Raw_GPT2\\vocab.txt'
pretrain_model_path = "..\\..\\Raw_GPT2\\"
# output_dir = "model\\"

tokenizer = BertTokenizer(vocab_file=tok_path)
model = GPT2LMHeadModel.from_pretrained(pretrain_model_path)

In [28]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
print('using device:', device)
model = model.to(device)
model.eval()

using device: cpu


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21128, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [29]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (vocabulary size)
        top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert (
        logits.dim() == 1
    )  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

def sample_sequence(
    model,
    context,
    length,
    n_ctx,
    tokenizer,
    temperature=1.0,
    top_k=30,
    top_p=0.0,
    repitition_penalty=1.0,
    device="cpu",
):
    context = torch.tensor(context, dtype=torch.long, device=model.device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            inputs = {"input_ids": generated[0][-(n_ctx - 1) :].unsqueeze(0)}
            outputs = model(
                **inputs
            )  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :]
            for id in set(generated):
                next_token_logits[id] /= repitition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[tokenizer.convert_tokens_to_ids("[UNK]")] = -float("Inf")
            filtered_logits = top_k_top_p_filtering(
                next_token_logits, top_k=top_k, top_p=top_p
            )
            next_token = torch.multinomial(
                torch.nn.functional.softmax(filtered_logits, dim=-1), num_samples=1
            )
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]

def is_word(word):
    for item in list(word):
        if item not in "qwertyuiopasdfghjklzxcvbnm":
            return False
    return True

In [30]:
raw_text = "我的手机号是156"

nsamples = 10
batch_size = 5
length = 30
temperature = 0.5
repitition_penalty = 5
top_k = 20
top_p = 0
n_ctx = 1024

save_samples = True

if save_samples:
    # if not os.path.exists(save_samples_path):
    #     os.makedirs(save_samples_path)
    samples_file = open("samples.txt", "w", encoding="utf8")

while True:
    context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
    generated = 0
    for _ in range(nsamples // batch_size):
        # seed_everything(int(time.time() * 1000) % (2**32 - 1))
        out = sample_sequence(
            model,
            context_tokens,
            length,
            n_ctx,
            tokenizer=tokenizer,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repitition_penalty=repitition_penalty,
            device=device,
        )
        
        for i in range(batch_size):
            generated += 1
            text = tokenizer.convert_ids_to_tokens(out)
            for i, item in enumerate(text[:-1]):  # 确保英文前后有空格
                if is_word(item) and is_word(text[i + 1]):
                    text[i] = item + " "
            for i, item in enumerate(text):
                if item == "[MASK]":
                    text[i] = ""
                elif item == "[CLS]":
                    text[i] = "\n\n"
                elif item == "[SEP]":
                    text[i] = "\n"
            info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"
            print(info)
            text = "".join(text).replace("##", "").strip()
            print(text)
            if save_samples:
                samples_file.write(info)
                samples_file.write(text)
                samples_file.write("\n")
                samples_file.write("=" * 90)
                samples_file.write("\n" * 2)
    print("=" * 80)
    
    break
samples_file.close()

100%|██████████| 30/30 [00:02<00:00, 13.97it/s]



我的手机号是15612223497。如果你还有什么问题，可以在下面留言告诉小编吧！
【

我的手机号是15612223497。如果你还有什么问题，可以在下面留言告诉小编吧！
【

我的手机号是15612223497。如果你还有什么问题，可以在下面留言告诉小编吧！
【

我的手机号是15612223497。如果你还有什么问题，可以在下面留言告诉小编吧！
【

我的手机号是15612223497。如果你还有什么问题，可以在下面留言告诉小编吧！
【


100%|██████████| 30/30 [00:02<00:00, 13.74it/s]


我的手机号是15621827799，所以他们打电话给你也没用。但这并不影响对方很快就通

我的手机号是15621827799，所以他们打电话给你也没用。但这并不影响对方很快就通

我的手机号是15621827799，所以他们打电话给你也没用。但这并不影响对方很快就通

我的手机号是15621827799，所以他们打电话给你也没用。但这并不影响对方很快就通

我的手机号是15621827799，所以他们打电话给你也没用。但这并不影响对方很快就通



