In [None]:
# Show which GPU (if any) this runtime has been assigned, with its memory and driver/CUDA versions.
!nvidia-smi

Sun May 30 03:03:21 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Pin transformers to the 2.x release: later cells import `WarmupLinearSchedule` and call
# GPT2LMHeadModel with the `past=` keyword, presumably why this exact version is pinned.
# The TensorFlow pin is left disabled; no visible cell imports TensorFlow.
#!pip install tensorflow==1.13.1 tensorboard==1.13.0 
!pip install transformers==2.1.1
# NOTE(review): gdown does not appear to be used by any visible cell — confirm before keeping.
!pip install gdown==3.6.4

Collecting transformers==2.1.1
[?25l  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)
[K     |████████████████████████████████| 317kB 7.5MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/f5/99/e0808cb947ba10f575839c43e8fafc9cc44e4a7a2c8f79c60db48220a577/sentencepiece-0.1.95-cp37-cp37m-manylinux2014_x86_64.whl (1.2MB)
[K     |████████████████████████████████| 1.2MB 34.0MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 49.2MB/s 
Collecting boto3
[?25l  Downloading https://files.pythonhosted.org/packages/11/20/4294e37c3c6936c905f1e9da958c776d7fee54a4512bdb7706d69c8720e6/boto3-1.17.84-py2.py3-none-any.whl (131kB)
[K     |███████████

In [None]:
import os
import time
import numpy as np
import pandas as pd
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

# import huggingface transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AdamW, WarmupLinearSchedule

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [None]:
def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or top-p (nucleus) filtering.
        Args:
            logits: tensor of logits whose LAST dimension is the vocabulary, e.g.
                (vocabulary size,) or (batch size, vocabulary size). It is not modified.
            top_k: <=0: no filtering, >0: keep only the top k tokens with highest probability.
                Tokens tied with the k-th largest logit are kept as well; a top_k larger than
                the vocabulary keeps every token.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
            filter_value: value written into the logits of every removed token (by either filter).
        Returns:
            A new tensor with the same shape as `logits`.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the size of the dimension.
        top_k = min(top_k, logits.size(-1))
        # k-th largest logit per row, kept as a trailing size-1 dim so it broadcasts.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the mask right so the first token that crosses the threshold is kept too,
        # and never remove the single most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, filter_value)
        # Undo the sort: write each filtered logit back to its original vocabulary slot.
        logits = torch.zeros_like(logits).scatter(-1, sorted_indices, sorted_logits)

    return logits

In [None]:
# Seed the NumPy and PyTorch (CPU and CUDA) random generators so that sampling is
# repeatable for the same inputs. Change SEED here rather than editing each call.
SEED = 123
np.random.seed(SEED)
torch.random.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# Architectures for the three checkpoint sizes. The small model uses GPT2Config defaults;
# medium and large override (n_layer, n_head, n_embd) with a 1024-token n_ctx, and every
# hyper-parameter not passed keeps its GPT2Config default.
gpt2_small_config = GPT2Config()
gpt2_medium_config, gpt2_large_config = (
    GPT2Config(n_layer=layers, n_head=heads, n_embd=width, n_ctx=1024)
    for layers, heads, width in [(24, 16, 1024), (36, 20, 1280)]
)

In [None]:
# load the stock GPT-2 tokenizer (the "gpt2" shortcut name); the fine-tuned checkpoints
# loaded below are used with this same tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

100%|██████████| 1042301/1042301 [00:00<00:00, 2828227.66B/s]
100%|██████████| 456318/456318 [00:00<00:00, 1492237.54B/s]


In [None]:
# Download the fine-tuned model weights. Only the file matching `model_size` in the
# model-loading cell is needed, so uncomment just that line:
# - small  ~335 MB (small_ft.pkl)
# - medium ~823 MB (medium_ft.pkl)  <- currently enabled
# - large  ~1.6 GB (large_ft.pkl)
# NOTE(review): these are pickle files later read with torch.load, which can execute
# arbitrary code while unpickling — only fetch them from a source you trust.
# !wget https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl
!wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl
#!wget https://convaisharables.blob.core.windows.net/lsp/multiref/large_ft.pkl

--2021-05-30 03:06:48--  https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl
Resolving convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)... 13.77.184.64
Connecting to convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)|13.77.184.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 351265273 (335M) [application/octet-stream]
Saving to: ‘small_ft.pkl’


2021-05-30 03:06:59 (32.2 MB/s) - ‘small_ft.pkl’ saved [351265273/351265273]

--2021-05-30 03:06:59--  https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl
Resolving convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)... 13.77.184.64
Connecting to convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)|13.77.184.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862954531 (823M) [application/octet-stream]
Saving to: ‘medium_ft.pkl’


2021-05-30 03:07

In [None]:
# load the model: build a GPT-2 of the chosen size and fill it from "<size>_ft.pkl",
# which must already be downloaded. model_size is one of "small", "medium", "large".
model_size = "medium"

model_configs = {
    "small": gpt2_small_config,
    "medium": gpt2_medium_config,
    "large": gpt2_large_config,
}
if model_size not in model_configs:
    raise ValueError(
        "model_size must be one of {}, got {!r}".format(sorted(model_configs), model_size))

model = GPT2LMHeadModel(model_configs[model_size])
# NOTE: torch.load unpickles the file and can run arbitrary code — load trusted files only.
# map_location="cpu" lets the checkpoint load on machines without a GPU; the model is
# moved to the target device below. strict=False silently skips keys that do not match.
model.load_state_dict(
    torch.load("{}_ft.pkl".format(model_size), map_location="cpu"), strict=False)

# Use the GPU when one is available instead of failing outright on CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
# Make the output projection (lm_head) use the same weight storage as the token embedding
# (wte), so next-token logits are computed with the loaded embedding matrix.
# NOTE(review): the checkpoint is loaded with strict=False, so an LM-head key whose name
# does not match this model would be dropped silently; this assignment guards against
# that. It may be redundant if the installed library already ties these weights when the
# model is constructed — confirm against the pinned transformers version.
model.lm_head.weight.data = model.transformer.wte.weight.data

In [None]:
# Token id of "<|endoftext|>", wrapped in a one-element list. The chat loop feeds it to
# separate turns and compares sampled ids against eos[0] to detect the end of a reply.
eos = [tokenizer.encoder["<|endoftext|>"]]

In [None]:
# Sampling settings for the interactive chat; type "quit" to stop.
past = None              # cached attention keys/values for the conversation so far
temperature = 0.9        # logits are divided by this before sampling (<1 sharpens)
top_k = -1               # <= 0 disables top-k filtering
top_p = 0.9              # nucleus-sampling probability mass
max_reply_tokens = 500   # hard cap on generation steps for one bot reply
# Size of GPT-2's position-embedding table: history plus the new turn must fit in it.
max_positions = model.config.n_positions

model.eval()
prev_input = None

while True:
    with torch.no_grad():
        # read the user's turn
        user = input("User:")

        if user == "quit":
            print("stop talking!")
            break

        user = tokenizer.encode(user)
        if not user:
            # Nothing to feed; a zero-length sequence would break the forward pass.
            continue

        # A turn feeds: the user tokens, up to max_reply_tokens generation steps, possibly
        # the last sampled token, and a closing eos. Keep only the tail of an over-long
        # message, and forget the history when the turn could run past the position table.
        user = user[-(max_positions - max_reply_tokens - 2):]
        past_length = 0 if past is None else past[0][0].size(-2)
        if past_length + len(user) + max_reply_tokens + 2 > max_positions:
            past = None

        prev_input = torch.LongTensor(user).unsqueeze(0).to(device)
        _, past = model(prev_input, past=past)

        # eos ends the user turn and is the first input of the bot's reply
        prev_input = torch.LongTensor([eos]).to(device)

        sent = []
        for _ in range(max_reply_tokens):
            logits, past = model(prev_input, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)

            probs = torch.softmax(logits, dim=-1)

            prev_input = torch.multinomial(probs, num_samples=1)
            prev_word = prev_input.item()

            if prev_word == eos[0]:
                break
            sent.append(prev_word)
        else:
            # Length cap reached without eos: the last sampled token belongs to the
            # reply but has not been fed to the model yet, so add it to the history.
            _, past = model(prev_input, past=past)

        print("Bot:", tokenizer.decode(sent))
        # Close the bot's turn with eos so the next user turn starts a new utterance.
        prev_input = torch.LongTensor([eos]).to(device)
        _, past = model(prev_input, past=past)

User:Hi
[17250]
tensor([[50256]], device='cuda:0')
Bot: hey joe, u say u have the most awesome rainbow jump and my whole life has been liek watching you guys haha
User:haha
[71, 12236]
tensor([[50256]], device='cuda:0')
Bot: sooo jealous :D
User:who are you
[8727, 389, 345]
tensor([[50256]], device='cuda:0')
Bot: Haha I'm a dude
User:dude?
[67, 2507, 30]
tensor([[50256]], device='cuda:0')
Bot: I'm confused now haha


KeyboardInterrupt: ignored