In [1]:
from base_model2 import MyModel
from utils import prepare_tokenizer
import torch
import torch.nn.functional as F

In [2]:
chkpt = torch.load('/root/autodl-tmp/flash-llm/main/mp_rank_00_model_states.pt')

In [3]:
tkn, VOCAB_SIZE = prepare_tokenizer('./tokenizer')

In [4]:
max_len = 512
model = MyModel(
    vocab=VOCAB_SIZE,
    pad_token_id=tkn.pad_token_id,
    d_model=2560,
    num_head=32,
    num_block=24,
    max_len=max_len
)

In [5]:
model.load_state_dict(chkpt['module'])

RuntimeError: Error(s) in loading state_dict for MyModel:
	Unexpected key(s) in state_dict: "blocks.0.attn.mask", "blocks.1.attn.mask", "blocks.2.attn.mask", "blocks.3.attn.mask", "blocks.4.attn.mask", "blocks.5.attn.mask", "blocks.6.attn.mask", "blocks.7.attn.mask", "blocks.8.attn.mask", "blocks.9.attn.mask", "blocks.10.attn.mask", "blocks.11.attn.mask", "blocks.12.attn.mask", "blocks.13.attn.mask", "blocks.14.attn.mask", "blocks.15.attn.mask", "blocks.16.attn.mask", "blocks.17.attn.mask", "blocks.18.attn.mask", "blocks.19.attn.mask", "blocks.20.attn.mask", "blocks.21.attn.mask", "blocks.22.attn.mask", "blocks.23.attn.mask". 

In [6]:
from tqdm import trange
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    assert logits.dim() == 1 
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


def sample_sequence(model, context, length, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
                    device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    inputs = context

    output = None
    prefix_kv_list = None

    end_count = 0
    with torch.no_grad():
        for _ in range(length - context.size(1)):
            model_o, prefix_kv_list = model(inputs, prefix_kv_list=prefix_kv_list, generate=True)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = model_o[0, -1, :]

            if output is not None:
                for tkn_id in set(output[0]):
                    next_token_logits[tkn_id] /= repitition_penalty

            next_token_logits = next_token_logits / temperature
            next_token_logits[tkn.bos_token_id] = -float('Inf')
            next_token_logits[tkn.eos_token_id] = -float('Inf')
            next_token_logits[tkn.unk_token_id] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            next_token = next_token.unsqueeze(0)
            inputs = next_token

            cur_token = tokenizer.decode(next_token[0])
            if 'End' in cur_token:
                end_count += 1
            else:
                end_count = 0
            
            if end_count >= 2:
                break
                
            print(cur_token, end='')
            if output is None:
                output = next_token
            else:
                output = torch.cat((output, next_token), dim=1)

    return output

In [7]:
from IPython.display import clear_output
def answer(prompt):
    print(prompt, end='')
    context_tokens = tkn(prompt, return_tensors='pt').input_ids.cuda()
    print(context_tokens.shape)
    out = sample_sequence(
        model=model, length=100,
        context=context_tokens[0], tokenizer=tkn,
        temperature=1, top_k=30, repitition_penalty=10
    )
    clear_output()
    print(prompt, '\n', tkn.decode(out[0]))

In [8]:
answer('Bill Gates retired, because')

Bill Gates retired, because 
 .3 to the U, but was a time of its new on his most were in this).  By and an three:
American at one-to for by other playerser statistics from The first or with 1–17 during over P's career as it  under Tsz." In 18; "ex (1thoran), H% below through New York State". Additionally" bgium", he is also

) - Aclosa de
