In [None]:
!pip install tokenizers transformers

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import prepare_tokenizer
from IPython.display import clear_output

In [2]:
# prepare_tokenizer comes from the project-local `utils` module; it is given the
# './tokenizer/' path and returns two values, kept here as the tokenizer object
# (`tkn`) and `VOCAB_SIZE` (passed to the model constructor as `vocab`).
# NOTE(review): presumably it loads tokenizer files from that directory — confirm.
tkn, VOCAB_SIZE = prepare_tokenizer('./tokenizer/')

In [3]:
from IPython.display import HTML, display

def set_css(info=None):
    """Emit a CSS rule that makes ``<pre>`` blocks wrap long lines.

    The function is registered below as a ``pre_run_cell`` callback, so the
    style is displayed before every cell runs.

    Args:
        info: Optional object handed over by IPython when it fires the event.
            It is accepted only so the callback works whether or not IPython
            supplies an argument; its value is ignored.
    """
    display(HTML('''
    <style>
    pre {
        white-space: pre-wrap;
    }
    </style>
    '''))


# Hook the style injection into IPython's per-cell event.
get_ipython().events.register('pre_run_cell', set_css)

In [6]:
from rope_model import LLM
# Build the network from the project-local `rope_model.LLM` class, using the
# tokenizer's vocabulary size and pad-token id plus fixed architecture sizes.
# NOTE(review): these sizes presumably have to match the checkpoint whose weights
# are loaded into `model` in the following cell — confirm before changing them.
model = LLM(
    vocab=VOCAB_SIZE,
    pad_token_id=tkn.pad_token_id,
    d_model=2560,
    num_head=32,
    num_blocks=24
)

In [8]:
# Read the checkpoint straight onto the CPU: the model is moved to the CPU below,
# and without map_location torch.load tries to restore each tensor on the device
# it was saved from, which fails on a machine without that device.
chkpt = torch.load(
    '/root/autodl-tmp/wrong-backup-models/myllm4-2B-wiki-rope-90000.pt',
    map_location='cpu',
)
model = model.cpu()
# The weights live under the 'module' key of the checkpoint dictionary.
model.load_state_dict(chkpt['module'])
# Switch to inference mode before sampling.
model.eval()
print(chkpt.keys())

dict_keys(['module', 'buffer_names', 'optimizer', 'param_shapes', 'frozen_param_shapes', 'shared_params', 'frozen_param_fragments', 'lr_scheduler', 'data_sampler', 'random_ltd', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version'])


In [9]:
from tqdm import trange
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask out unlikely entries of a 1-D logits vector (top-k and/or nucleus).

    The tensor is modified in place and the same tensor is returned.

    Args:
        logits: 1-D tensor of logits over the vocabulary.
        top_k: when > 0, only the ``top_k`` largest logits survive (values tied
            with the k-th largest also survive); clamped to the vector length.
        top_p: when > 0.0, tokens are kept in descending order until their
            cumulative softmax probability first exceeds ``top_p``; the token
            that crosses the threshold is kept as well.
            See Holtzman et al. (http://arxiv.org/abs/1904.09751).
        filter_value: value written into every masked position.

    Returns:
        ``logits`` itself, with masked positions set to ``filter_value``.

    Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # Only a single (unbatched) distribution is supported.
    assert logits.dim() == 1

    k = min(top_k, logits.size(-1))
    if k > 0:
        # Smallest logit that still belongs to the top-k set.
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ranked_logits, ranked_ids = torch.sort(logits, descending=True)
        running_mass = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)

        beyond_nucleus = running_mass > top_p
        # Shift the mask one step right so the first token that crosses the
        # threshold is retained, and never drop the most likely token.
        beyond_nucleus[..., 1:] = beyond_nucleus[..., :-1].clone()
        beyond_nucleus[..., 0] = False

        logits[ranked_ids[beyond_nucleus]] = filter_value

    return logits


def sample_sequence(model, context, length, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
                    device='cpu'):
    """Sample a continuation token by token, streaming each piece to stdout.

    The full context is fed to ``model`` on the first step; afterwards only the
    newly sampled token is fed, together with the ``prefix_kv_list`` value the
    model returned on the previous step.

    Args:
        model: callable invoked as ``model(inputs, prefix_kv_list=...)`` that
            returns ``(logits, prefix_kv_list)``; ``logits[0, -1, :]`` is used
            as the next-token distribution.
        context: sequence of prompt token ids.
        length: total length budget; ``length - len(context)`` tokens are sampled.
        tokenizer: provides ``decode`` and the ``bos_token_id`` /
            ``eos_token_id`` / ``unk_token_id`` attributes.
        temperature: divisor applied to the logits before filtering.
        top_k: forwarded to ``top_k_top_p_filtering``.
        top_p: forwarded to ``top_k_top_p_filtering``.
        repitition_penalty: penalty applied once to every distinct token that
            has already been generated (prompt tokens are not penalised).
        device: device on which the context tensor is created.

    Returns:
        A ``(1, n_generated)`` tensor of sampled token ids, or ``None`` when no
        token was generated (``length`` not larger than the context length).
    """
    context = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
    inputs = context

    # Special tokens that must never be sampled. They are read from the
    # tokenizer that was passed in (not from a notebook global); ids the
    # tokenizer does not define are skipped, because indexing the logits with
    # None would mask the whole vocabulary.
    banned_ids = [
        tok_id
        for tok_id in (tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id)
        if tok_id is not None
    ]

    output = None
    prefix_kv_list = None
    with torch.no_grad():
        for _ in range(length - context.size(1)):
            model_o, prefix_kv_list = model(inputs, prefix_kv_list=prefix_kv_list)
            next_token_logits = model_o[0, -1, :]

            if output is not None:
                # Convert to plain ints first: a set of tensors hashes by object
                # identity, so it would not de-duplicate and a repeated token
                # would be penalised once per occurrence.
                for tok_id in set(output[0].tolist()):
                    score = next_token_logits[tok_id]
                    # Always push the score down: dividing a negative logit by a
                    # penalty > 1 would make the token more likely instead.
                    if score < 0:
                        next_token_logits[tok_id] = score * repitition_penalty
                    else:
                        next_token_logits[tok_id] = score / repitition_penalty

            next_token_logits = next_token_logits / temperature
            for tok_id in banned_ids:
                next_token_logits[tok_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            # Shape (1, 1): this is the only input on the next step.
            next_token = next_token.unsqueeze(0)
            inputs = next_token

            print(tokenizer.decode(next_token[0]), end='')
            if output is None:
                output = next_token
            else:
                output = torch.cat((output, next_token), dim=1)

    return output

In [10]:
def answer(prompt, length=512, temperature=0.9, top_k=30, repitition_penalty=5):
    """Generate and display a continuation of ``prompt`` with the notebook's model.

    The prompt is echoed, tokens are streamed by ``sample_sequence`` as they are
    sampled, and finally the cell output is cleared and replaced by the prompt
    followed by the decoded continuation.

    Args:
        prompt: text to continue.
        length: total token budget (prompt plus continuation) given to
            ``sample_sequence``.
        temperature: sampling temperature given to ``sample_sequence``.
        top_k: top-k cutoff given to ``sample_sequence``.
        repitition_penalty: repetition penalty given to ``sample_sequence``.
    """
    print(prompt, end='')
    context_tokens = tkn.convert_tokens_to_ids(tkn.tokenize(prompt))
    out = sample_sequence(
        model=model, length=length,
        context=context_tokens, tokenizer=tkn,
        temperature=temperature, top_k=top_k, repitition_penalty=repitition_penalty
    )
    clear_output()
    if out is None:
        # sample_sequence returns None when the prompt already fills the length
        # budget; show the prompt alone instead of failing on out[0].
        print(prompt)
        return
    print(prompt, tkn.decode(out[0]))

In [None]:
# Generate a continuation for a short prompt. Tokens are drawn with
# torch.multinomial and no random seed is set in the visible cells, so the
# text shown below will differ from run to run.
answer('''Peter Parker''')

Peter Parker may refer to the year before Christmas Day in his final day
 November 29 (which is a complete opposite number of 36 appearances. 

	This article published by radio show that were all three episodes from previous records have been held on Saturday and further data shows, with each seasonally released as "The original purpose-offeeverests are also known for younger viewers." The other people was introduced || at first placeknot counting downsized; however – basedrape some critics' age restrictionable without male/re or running unassisted outlier if you canisterous diffusion has never material" | A–presentation playboyfriending their current events justiculated - see fit within 60 questions instead than liveliest sense: 1stvages whether they're stillness after five generations 461 days later," not only when considering nonmembers per week's = 1052 seconds left-, heuristic howitzoreditch...hearing two minutes required). Inscriptions an interview certain points excepted offshar