In [1]:
!pip install tokenizers transformers

Looking in indexes: http://mirrors.aliyun.com/pypi/simple


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import prepare_tokenizer
from IPython.display import clear_output

In [2]:
# Load the tokenizer stored under ./tokenizer/ via the project helper `prepare_tokenizer`
# (imported from utils). It returns two values: `tkn` is used as the tokenizer throughout
# the notebook, and VOCAB_SIZE is presumably its vocabulary size — it is passed to the
# model as `vocab_size` when the model is built.
tkn, VOCAB_SIZE = prepare_tokenizer('./tokenizer/')

In [3]:
from models import SFLLM
from data_obj import ModelArgs
# Architecture hyper-parameters. These presumably have to match the checkpoint that is
# loaded afterwards, because load_state_dict rejects tensors whose shapes differ.
model_args = ModelArgs(
    hidden_states=3200,
    n_heads=32,
    n_layers=32,
    max_len=1024,
    ext_factor=1,
)
# Instantiate the model on CPU. The vocabulary size and padding id both come from the tokenizer.
model = SFLLM(
    vocab_size=VOCAB_SIZE,
    pad_token_id=tkn.pad_token_id,
    args=model_args,
).cpu()

In [4]:
# NOTE(review): hard-coded absolute path; consider moving it into a config constant.
chkpt = torch.load(
    '/root/autodl-tmp/sfllm-magic32/main-0_200000.pt',
    map_location='cpu',  # map every stored tensor onto the CPU, whatever device it was saved from
)
# With strict=False, missing/unexpected keys do not raise. They are reported in the
# returned result instead, so printing it is how we verify that the checkpoint fits.
load_res = model.load_state_dict(chkpt, strict=False)
model.eval()  # switch the module into evaluation mode before generating
print(load_res)

<All keys matched successfully>


In [5]:
from tqdm import trange
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask a single logits vector in place using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: 1-D tensor holding one score per vocabulary entry. The function
            MODIFIES it in place and returns the same object.
        top_k: when > 0, every logit strictly below the k-th largest value is set to
            ``filter_value`` (values tied with the k-th survive). If top_k exceeds the
            vocabulary size, it is clamped to that size.
        top_p: when > 0.0, the tensor is sorted in descending order and the smallest
            leading set whose cumulative softmax probability exceeds ``top_p`` is kept.
            The token that crosses the threshold is also kept; every other token is
            masked. See Holtzman et al., http://arxiv.org/abs/1904.09751
        filter_value: value written into masked positions. The default -inf gives
            those positions zero probability after a softmax.

    Returns:
        The same ``logits`` tensor, after masking.

    Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    # Batching is unsupported: exactly one distribution per call.
    assert logits.dim() == 1
    k = min(top_k, logits.size(-1))
    if k > 0:
        kth_largest = torch.topk(logits, k).values[-1]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ordered, order = torch.sort(logits, descending=True)
        cum_probs = F.softmax(ordered, dim=-1).cumsum(dim=-1)
        exceeds = cum_probs > top_p
        # Shift the mask right by one position. A token is dropped only if the mass
        # of the tokens ranked above it already exceeds top_p, so the token that
        # crosses the threshold survives and the top-ranked token is always kept.
        drop_sorted = exceeds.clone()
        drop_sorted[1:] = exceeds[:-1]
        drop_sorted[0] = False
        logits[order[drop_sorted]] = filter_value

    return logits


def _apply_repetition_penalty(logits, generated_ids, penalty):
    """Return a copy of ``logits`` after applying a CTRL-style repetition penalty.

    Each distinct id in ``generated_ids`` is penalised exactly once. Positive logits
    are divided by ``penalty`` and negative logits are multiplied by it, so any
    penalty > 1 lowers the score of a token that has already been generated.
    """
    if penalty == 1.0:
        return logits
    # torch.unique deduplicates by value. A Python set() of 0-d tensors does not,
    # because tensors hash by identity, and the penalty would compound per occurrence.
    seen = torch.unique(generated_ids)
    scores = logits[seen]
    penalised = logits.clone()
    penalised[seen] = torch.where(scores < 0, scores * penalty, scores / penalty)
    return penalised


def sample_sequence(model, context, length, tokenizer, min_length=20, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0,
                    device='cpu'):
    """Sample a continuation of ``context`` token by token, yielding partial results.

    This is a generator. It yields the tokens generated so far at three points:
    every ``display_period`` new tokens, on the last allowed step, and once more
    just before stopping at EOS. Callers can therefore stream the text and still
    receive the complete output.

    Args:
        model: called as ``model(inputs, prefix_kv_list=..., generate=True)`` and
            expected to return ``(logits, prefix_kv_list)``. The logits are read as
            ``logits[0, -1, :]``, presumably (batch, seq, vocab). After the first
            step, only the newest token is fed in and the returned
            ``prefix_kv_list`` is passed back. It is presumably a key/value cache;
            confirm against the model.
        context: 1-D tensor of prompt token ids.
        length: maximum total length, counting both prompt and generated tokens.
        tokenizer: supplies ``bos_token_id``, ``eos_token_id`` and ``unk_token_id``.
        min_length: EOS is masked out until at least this many tokens exist. The
            value also determines how often partial output is yielded.
        temperature: logits are divided by this value before sampling. It must be > 0.
        top_k, top_p: forwarded unchanged to ``top_k_top_p_filtering``.
        repitition_penalty: CTRL-style penalty on tokens this call has already
            generated. Values > 1 discourage repeats; prompt tokens are not
            penalised. The misspelled name is kept for backward compatibility.
        device: device that the context tensor is moved to.

    Yields:
        A ``(1, n_generated)`` long tensor of generated ids, excluding the prompt.
    """
    context = context.long().to(device).unsqueeze(0)
    inputs = context

    # Interval, in generated tokens, between partial yields. The inner max(..., 1)
    # prevents ZeroDivisionError when min_length == 0.
    display_period = max(1, min_length, length // max(min_length, 1))

    output = None
    prefix_kv_list = None
    yielded_len = 0  # size of `output` at the time of the latest yield
    with torch.no_grad():
        display_num = length - context.size(1)
        last_i = display_num - 1
        for i in range(display_num):
            # Step one feeds the entire prompt. Every later step feeds only the newest
            # token and relies on prefix_kv_list for the earlier positions.
            model_o, prefix_kv_list = model(inputs, prefix_kv_list=prefix_kv_list, generate=True)
            next_token_logits = model_o[0, -1, :]

            if output is not None:
                next_token_logits = _apply_repetition_penalty(next_token_logits, output[0], repitition_penalty)

            next_token_logits = next_token_logits / temperature
            # BOS and UNK are never sampled. EOS stays blocked until min_length tokens exist.
            next_token_logits[tokenizer.bos_token_id] = -float('Inf')
            if output is None or output.size(-1) < min_length:
                next_token_logits[tokenizer.eos_token_id] = -float('Inf')
            next_token_logits[tokenizer.unk_token_id] = -float('Inf')

            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            if next_token.item() == tokenizer.eos_token_id:
                # Flush whatever was produced after the last periodic yield. Without
                # this, those trailing tokens never reached the caller.
                if output is not None and output.size(-1) != yielded_len:
                    yield output
                break

            next_token = next_token.unsqueeze(0)  # reshape (1,) into (1, 1)
            inputs = next_token
            output = next_token if output is None else torch.cat((output, next_token), dim=1)

            if output.size(-1) % display_period == 0 or i >= last_i:
                yielded_len = output.size(-1)
                yield output

In [6]:
from IPython.display import clear_output
def answer(model, tokenizer, prompt, length=1024, temperature=1, top_k=10, repitition_penalty=10):
    """Generate a sampled completion of ``prompt`` and stream it into the notebook output.

    The function tokenises the prompt and hands it to ``sample_sequence``. Each
    partial result that comes back is decoded and re-printed in place as
    ``AI: <text>``. Only the generated continuation is shown; the prompt itself
    is not echoed.

    Args:
        model: passed straight through to ``sample_sequence``.
        tokenizer: callable in the Hugging Face style
            (``tokenizer(text, return_tensors='pt').input_ids``) that also provides
            ``decode``. It handles encoding, sampling-time special ids and decoding.
        prompt: the text to continue.
        length: maximum total length, counting both prompt and generated tokens.
        temperature, top_k, repitition_penalty: sampling controls passed on to
            ``sample_sequence``. Their defaults match the values that used to be
            hard-coded.
    """
    context_tokens = tokenizer(f'{prompt}', return_tensors='pt').input_ids[0]
    out_iter = sample_sequence(
        model=model,
        length=length,
        context=context_tokens,
        # Pass the caller's tokenizer rather than the notebook-global `tkn`.
        tokenizer=tokenizer,
        temperature=temperature,
        top_k=top_k,
        repitition_penalty=repitition_penalty,
        device='cpu',
    )

    for out in out_iter:
        # wait=True postpones the clear until new output arrives, which prevents flicker.
        clear_output(wait=True)
        txt_gen = tokenizer.decode(out[0])
        print(f'\rAI: {txt_gen.strip()}', end='')

In [7]:
# The prompt is a "Subject:" line plus an empty "Content:" field that the model completes.
answer(
    model,
    tkn,
    '''Subject: How to make a plane?
Content:''',
)

AI: This post has been updated with information about the 2019 Subject,