In [1]:
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

In [2]:
from infilling_gpt2 import get_model
from tokenizer_util import get_tokenizer, tokenize

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\rabir\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Build the model through the project helper imported from infilling_gpt2.
# The fine-tuned weights are loaded into it by a later cell (load_state_dict).
model = get_model()

In [4]:
# Tokenizer from the project helper in tokenizer_util; it is called below to
# turn the prompt string into token ids and later to decode ids back to text.
tokenizer = get_tokenizer()

In [5]:
# Example prompt; the '_' marks the blank that is swapped for the infill marker below.
s = 'She is a _ and a director'

In [6]:
# Replace every '_' placeholder with the '<|infillphrase|>' marker token.
s = s.replace('_', '<|infillphrase|>')

In [7]:
# Display the prompt after the placeholder substitution.
s

'She is a <|infillphrase|> and a director'

In [8]:
# Append the '<|startofinfill|>' marker to the end of the prompt.
# Guarded so that re-running this cell does not append the marker a second time.
if not s.endswith('<|startofinfill|>'):
    s = s + '<|startofinfill|>'

In [9]:
# Encode the prompt, then add a leading batch axis so the ids form a
# (1, sequence_length) tensor.
prompt_token_ids = tokenizer(s)['input_ids']
input_ids = torch.tensor(prompt_token_ids).unsqueeze(0)

In [10]:
# Inspect the encoded prompt (one row of token ids, thanks to unsqueeze(0) above).
input_ids

tensor([[ 3347,   318,   257, 50259,   392,   257,  3437, 50257]])

In [11]:
# Directory that holds the fine-tuned checkpoint. The GPT2_ILM_MODEL_DIR
# environment variable overrides the default so the notebook is not tied to a
# single machine; expanduser lets such an override start with '~'.
model_dir = os.path.expanduser(
    os.environ.get('GPT2_ILM_MODEL_DIR', 'E:/ResearchWork/CurrentResearch/gpt2-ilm/ModelState')
)
model_path = os.path.join(model_dir, 'GPT2FineTuned.pt')

In [12]:
# Load the fine-tuned weights onto the CPU and copy them into the model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location = torch.device('cpu')))

<All keys matched successfully>

In [13]:
def top_k_logits(logits, k, temp):
    """Keep only the ``k`` largest logits along the last axis, masking the rest.

    Args:
        logits: Float tensor whose last axis indexes the vocabulary; any number
            of leading (batch) axes is accepted.
        k: Number of entries to keep per row. ``0`` disables filtering and
            returns ``logits`` untouched (temperature is not applied either).
            Values larger than the vocabulary size are clamped.
        temp: Temperature the logits are divided by before filtering. ``0``
            means greedy decoding: only the arg-max entry of each row is kept.

    Returns:
        A tensor shaped like ``logits`` in which every filtered-out position is
        set to ``-1e10``, so a following softmax gives it ~zero probability.
    """
    if k == 0:
        return logits

    masked_out = torch.ones_like(logits, dtype=logits.dtype) * -1e10

    if temp == 0:
        # Greedy: return logits (not indices) so callers can always follow up
        # with softmax + sampling; only the best entry per row survives.
        best = torch.argmax(logits, dim=-1, keepdim=True)
        return masked_out.scatter(-1, best, logits.gather(-1, best))

    logits = logits / temp
    # torch.topk raises when k exceeds the size of the axis.
    k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, k)
    # Keep the trailing axis so the per-row threshold broadcasts across the
    # vocabulary axis instead of being matched against it.
    min_values = values[..., -1:]
    return torch.where(logits < min_values, masked_out, logits)

In [14]:
def top_p_logits(logits, top_p, filter_value = -float('Inf')):
    """Nucleus (top-p) filtering along the last axis of ``logits``.

    For each row, the tokens are ranked by probability and the smallest prefix
    whose cumulative probability exceeds ``top_p`` is kept; every other
    position is replaced by ``filter_value``. The highest-probability token is
    always kept.

    Args:
        logits: Float tensor whose last axis indexes the vocabulary.
        top_p: Cumulative-probability threshold. ``0`` (or any value ``>= 1``)
            disables filtering and returns ``logits`` unchanged.
        filter_value: Value written into the filtered-out positions.

    Returns:
        A new tensor shaped like ``logits``; the input tensor is not modified.
    """
    if top_p == 0 or top_p >= 1:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending = True, dim = -1)
    # The threshold is defined on probabilities, so convert before accumulating.
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold is itself kept,
    # and make sure the top-ranked token is never removed.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order:
    # result[..., sorted_indices[..., j]] = sorted_indices_to_remove[..., j].
    indices_to_remove = torch.zeros_like(sorted_indices_to_remove).scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    # masked_fill returns a copy, leaving the caller's tensor untouched.
    return logits.masked_fill(indices_to_remove, filter_value)

In [15]:
def get_tokens(model, input_ids, top_p = 0.95):
    """Sample one next token from ``model`` and append it to ``input_ids``.

    The model is switched to eval mode (and left there), run without gradient
    tracking, and the logits of the last position are filtered with
    ``top_p_logits`` before one token per row is drawn with
    ``torch.multinomial``. Sampling is stochastic; seed torch beforehand for
    reproducible output.

    Args:
        model: Callable returning a mapping with a ``'logits'`` entry indexed as
            (batch, position, vocabulary).
        input_ids: Tensor of token ids with the sequence along axis 1.
        top_p: Nucleus-sampling threshold forwarded to ``top_p_logits``.

    Returns:
        ``input_ids`` with the sampled token id concatenated along axis 1.
    """
    model.eval()
    with torch.no_grad():
        out = model(input_ids)
        logits = out['logits']
        # Shape trace kept so the cell output stays informative while experimenting.
        print(logits.size())
        # Only the final position predicts the next token.
        next_token_logits = logits[:, -1, :]
        next_token_logits = top_p_logits(next_token_logits, top_p = top_p)
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples = 1)
        return torch.cat((input_ids, next_token), dim = 1)

In [16]:
# Number of tokens to sample for the blank.
NUM_NEW_TOKENS = 2

# Each get_tokens call appends one sampled id, so input_ids grows by
# NUM_NEW_TOKENS. Re-running this cell keeps extending the same sequence.
for _step in range(NUM_NEW_TOKENS):
    input_ids = get_tokens(model, input_ids)

torch.Size([1, 8, 50262])
torch.Size([1, 9, 50262])


In [17]:
# Token ids after the sampling loop; the sampled ids sit at the end of the row.
input_ids

tensor([[ 3347,   318,   257, 50259,   392,   257,  3437, 50257,  2646,  9920]])

In [18]:
# Decode the whole sequence (prompt plus sampled ids) back to text. The ids are
# passed as a plain list of ints rather than a generator of 0-d tensors.
tokenizer.decode(input_ids.squeeze(0).tolist())

'She is a <|infillphrase|> and a director <|startofinfill|>  film producer'