In [1]:
from myllm_model import MyModel
from transformers import GPT2Tokenizer
import torch
import torch.nn.functional as F

In [2]:
# Read the training checkpoint from disk.
# map_location='cpu' restores every tensor into host memory regardless of the
# device it was saved from; the rest of this notebook builds the model and
# samples on CPU, so no GPU is needed just to open the file.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints that come from a trusted source.
chkpt = torch.load('/root/autodl-tmp/myllm3-2B-81600.pt', map_location='cpu')

In [3]:
# Load the tokenizer from the local ./tokenizer/ directory.
tkn = GPT2Tokenizer.from_pretrained('./tokenizer/')
# Use the literal string '[PAD]' as the padding token; its id is handed to the
# model below as pad_token_id.
# NOTE(review): presumably '[PAD]' is already part of the saved vocabulary --
# TODO confirm, otherwise pad_token_id may not refer to a dedicated token.
tkn.pad_token = '[PAD]'
# Vocabulary size handed to the model constructor as `vocab`.
VOCAB_SIZE = tkn.vocab_size

In [5]:
# Build the network with freshly initialised weights; the checkpoint weights
# are loaded into it in a later cell.
# NOTE(review): these hyperparameters presumably have to match the ones the
# checkpoint was trained with -- confirm against the training config.
max_len = 512
model = MyModel(
    vocab=VOCAB_SIZE,
    pad_token_id=tkn.pad_token_id,
    d_model=2560,
    num_head=32,
    num_block=24,
    max_len=max_len
)

In [6]:
# This notebook only runs inference, so switch the module to evaluation mode.
# torch.no_grad() (used while sampling) only turns off gradient tracking; it
# does not disable train-time layers such as dropout, which would add noise to
# every forward pass if MyModel contains any.
model.eval()
# Copy the trained weights (stored under the 'module' key of the checkpoint)
# into the model. Kept as the last expression so the cell still displays the
# key-matching report.
model.load_state_dict(chkpt['module'])

<All keys matched successfully>

In [7]:
from tqdm import trange
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask out unlikely entries of a 1-D logits vector, in place.

    Args:
        logits: 1-D tensor of scores over the vocabulary. It is modified in
            place and the very same tensor object is returned.
        top_k: if > 0, every entry strictly smaller than the k-th largest
            score is overwritten (entries tied with the k-th one survive).
            Values larger than the vector length are clamped to it.
        top_p: if > 0.0, nucleus filtering: entries are visited from most to
            least likely and kept until their cumulative softmax probability
            exceeds ``top_p``; the entry that crosses the threshold is kept,
            everything after it is overwritten. The most likely entry is
            always kept.
        filter_value: value written into the removed positions.

    Returns:
        The ``logits`` tensor that was passed in, after masking.
    """
    assert logits.dim() == 1
    top_k = min(top_k, logits.size(-1))  # never ask topk for more than exists

    if top_k > 0:
        # Smallest score that still belongs to the k best.
        kth_best = torch.topk(logits, top_k)[0][-1]
        below_cutoff = logits < kth_best
        logits[below_cutoff] = filter_value

    if top_p > 0.0:
        ranked_logits, ranked_ids = torch.sort(logits, descending=True)
        running_mass = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)
        past_nucleus = running_mass > top_p

        # Shift the mask one slot to the right so the entry that first pushes
        # the cumulative mass over top_p stays in, and never drop the top one.
        drop_ranked = past_nucleus.clone()
        drop_ranked[1:] = past_nucleus[:-1]
        drop_ranked[0] = 0

        logits[ranked_ids[drop_ranked]] = filter_value

    return logits


def sample_sequence(model, context, length, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
                    device='cpu'):
    """Sample a continuation of ``context`` one token at a time.

    Each sampled token is decoded and printed immediately (no newline), so the
    text streams to stdout while it is generated.

    Args:
        model: callable invoked as ``model(inputs, prefix_kv_list=...)`` that
            returns a pair ``(logits, prefix_kv_list)``. ``logits[0, -1, :]``
            is read as the scores for the next token, and the returned cache
            is fed back on the following step.
        context: sequence of prompt token ids.
        length: total length budget *including* the prompt; at most
            ``length - len(context)`` tokens are sampled.
        tokenizer: supplies ``decode`` and the ``bos_token_id`` /
            ``eos_token_id`` / ``unk_token_id`` attributes; those ids are
            never sampled.
        temperature: logits are divided by this value before filtering.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        repitition_penalty: penalty applied to every token id that has already
            been generated (prompt tokens are not penalised). 1.0 disables it.
        device: device on which the prompt tensor is created.

    Returns:
        A ``(1, n_generated)`` long tensor holding only the sampled ids, or
        ``None`` when no token was generated.
    """
    context = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
    inputs = context

    # Ids that must never be sampled. Taken from the `tokenizer` argument (not
    # a notebook global), and None entries are skipped: indexing a tensor with
    # None would select -- and wipe out -- the whole logits vector.
    banned_ids = [
        tok_id
        for tok_id in (tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.unk_token_id)
        if tok_id is not None
    ]

    output = None
    prefix_kv_list = None

    end_count = 0
    with torch.no_grad():
        for _ in range(length - context.size(1)):
            # The model hands back its key/value cache, so after the first
            # step only the newly sampled token is fed in.
            model_o, prefix_kv_list = model(inputs, prefix_kv_list=prefix_kv_list)
            next_token_logits = model_o[0, -1, :]

            if output is not None and repitition_penalty != 1.0:
                # .tolist() yields plain ints, so the set really deduplicates;
                # a set of tensor elements hashes by identity and would
                # penalise a token once per occurrence.
                for seen_id in set(output[0].tolist()):
                    # Make the token less likely whatever the sign of its
                    # logit: shrink positive scores, push negative ones
                    # further down.
                    if next_token_logits[seen_id] < 0:
                        next_token_logits[seen_id] *= repitition_penalty
                    else:
                        next_token_logits[seen_id] /= repitition_penalty

            next_token_logits = next_token_logits / temperature
            if banned_ids:
                next_token_logits[banned_ids] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            next_token = next_token.unsqueeze(0)  # shape (1, 1): batch x one step
            inputs = next_token

            # Stop once two consecutive tokens decode to text containing 'End'.
            # The first of the two has already been printed and stored; the
            # second one is dropped.
            cur_token = tokenizer.decode(next_token[0])
            if 'End' in cur_token:
                end_count += 1
            else:
                end_count = 0

            if end_count >= 2:
                break

            print(cur_token, end='')
            if output is None:
                output = next_token
            else:
                output = torch.cat((output, next_token), dim=1)

    return output

In [13]:
from IPython.display import clear_output
def answer(prompt, length=100, temperature=1, top_k=30, repitition_penalty=10):
    """Generate a continuation of ``prompt`` and display prompt plus completion.

    Uses the notebook-level ``model`` and ``tkn``. The prompt and the tokens
    are streamed to the cell output while sampling; afterwards the output area
    is cleared and the prompt is printed again, followed by the decoded
    completion on the next line.

    Args:
        prompt: text to continue.
        length: total token budget (prompt included) passed to
            ``sample_sequence``.
        temperature, top_k, repitition_penalty: sampling settings forwarded to
            ``sample_sequence``.
    """
    print(prompt, end='')
    context_tokens = tkn.convert_tokens_to_ids(tkn.tokenize(prompt))
    out = sample_sequence(
        model=model, length=length,
        context=context_tokens, tokenizer=tkn,
        temperature=temperature, top_k=top_k, repitition_penalty=repitition_penalty
    )
    clear_output()
    # sample_sequence returns None when it produced no token (for example when
    # the prompt already fills the whole length budget); show an empty
    # completion instead of failing on out[0].
    completion = tkn.decode(out[0]) if out is not None else ''
    print(prompt, '\n', completion)

In [None]:
# Example: sample a continuation of a short English prompt. Sampling is
# random, so the completion differs from run to run.
answer('Bill Gates retired, because')