<h1>Hosting Large Models</h1>

In this notebook, we'll host GPT-2 on a remote Grid node and use it to run inference.

**Requirements:**
- [Install the pytorch_transformers library.](https://github.com/huggingface/pytorch-transformers#installation)
- [Choose a pre-trained model.](https://huggingface.co/pytorch-transformers/pretrained_models.html)
- Run the Grid Node app.

**Note: this example uses the GPT-2 model (12 layers, 768 hidden units, 12 attention heads, 117M parameters).**
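As a rough sanity check on the model size, the parameter count can be worked out from the GPT-2 small configuration. The sketch below assumes the standard configuration values, which this notebook does not state: a 50,257-token vocabulary, 1,024 positions, a feed-forward layer 4× the hidden size, biases on every linear layer, and an LM head whose weights are tied to the token embeddings. The function name `gpt2_param_count` is illustrative only.

```python
def gpt2_param_count(vocab=50257, n_positions=1024, d=768, n_layers=12):
    """Count GPT-2 parameters from its config (assumes the LM head is tied to the token embeddings)."""
    ln = 2 * d                                    # LayerNorm weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)      # fused Q/K/V projection + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)   # expand to 4*d, then project back to d
    per_layer = attn + mlp + 2 * ln               # two LayerNorms per transformer block
    embeddings = vocab * d + n_positions * d      # token + position embeddings
    return embeddings + n_layers * per_layer + ln # plus the final LayerNorm


print(gpt2_param_count())  # 124439808
```

Under these assumptions the count comes to about 124M; the "117M" label is the figure used in the original GPT-2 release.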


In [2]:
import syft as sy
import torch as th
import grid as gr
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel

In [3]:
hook = sy.TorchHook(th)

<h2>Loading the Tokenizer and Model</h2>

In [4]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Load pre-trained model (weights)
model = GPT2LMHeadModel.from_pretrained('gpt2', torchscript=True)

<h2>Setting Input</h2>

In [5]:
# Encode the text input
text = "Who was Jim Henson ? Jim Henson was a"
indexed_tokens = tokenizer.encode(text)

# Convert the indexed tokens into a PyTorch tensor (batch of one)
tokens_tensor = th.tensor([indexed_tokens])

<h2>Hosting GPT-2 Model</h2>

In [6]:
traced_model = th.jit.trace(model, (tokens_tensor,))

# Grid Node
bob = gr.WebsocketGridClient(hook, "http://localhost:3000/", id="Bob")
bob.connect()

# Host GPT-2 on Bob worker
bob.serve_model(traced_model, model_id="GPT-2")

{'success': True}

<h2>Running Inference</h2>

In [7]:
%%time
response = bob.run_inference(model_id="GPT-2", data=tokens_tensor)

predictions = th.tensor(response['prediction'])
predicted_index = th.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print("Predicted text: ", predicted_text)

Predicted text:   Who was Jim Henson? Jim Henson was a great
CPU times: user 1.1 s, sys: 25.6 ms, total: 1.13 s
Wall time: 3.46 s
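The cell above performs greedy decoding: it takes the logits at the last position of the first (and only) sequence, `predictions[0, -1, :]`, and picks the highest-scoring vocabulary index. Here is a minimal pure-Python sketch of that step, assuming (as the cell does) that `response['prediction']` is a nested list shaped `[batch][sequence][vocab]`. The helper `greedy_next_token` and the toy response are hypothetical.

```python
def greedy_next_token(response):
    """Return the index of the highest logit at the last position of the first sequence."""
    last_logits = response["prediction"][0][-1]  # same slice as predictions[0, -1, :]
    return max(range(len(last_logits)), key=last_logits.__getitem__)


# Hypothetical response: 1 sequence, 2 positions, vocabulary of 4 tokens.
toy_response = {"prediction": [[[0.1, 2.0, 0.3, 0.0],
                                [0.5, -1.0, 3.2, 1.1]]]}
print(greedy_next_token(toy_response))  # 2
```

The chosen index is appended to `indexed_tokens` before decoding, which is why the printed text ends with exactly one extra word.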


<h2>Text Generation</h2>

In [9]:
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < th.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = th.sort(logits, descending=True)
        cumulative_probs = th.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

def sample_sequence(worker, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu'):
    """Generate `length` tokens by repeatedly querying the GPT-2 model hosted on `worker`.
    Only num_samples=1 is supported: the context window is shifted using a single sequence.
    """
    context = th.tensor(context, dtype=th.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)

    predicted_indexes = []

    generated = context
    with th.no_grad():
        for _ in range(length):
            # Remote inference on the Grid node
            outputs = th.tensor(worker.run_inference(model_id="GPT-2", data=generated)["prediction"])

            # Scale the last-position logits, filter them, and sample the next token
            next_token_logits = outputs[0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = th.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)

            # Shift the context window: drop the oldest token and append the sampled one
            generated = th.cat((generated[:, 1:], next_token.unsqueeze(0).to(generated.device)), dim=1)

            # Save the sampled token (the same one fed back into the model)
            predicted_indexes.append(next_token.item())
    return predicted_indexes
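To see what `temperature`, `top_k` and `top_p` do to the next-token distribution, here is a pure-Python sketch of the same steps on a short list of logits. It mirrors `top_k_top_p_filtering` (top-k first, then nucleus filtering that always keeps the token that crosses the threshold), but it works on plain lists instead of tensors and returns renormalised probabilities rather than filtered logits. The name `filtered_probs` and the example logits are illustrative only.

```python
import math


def softmax(xs):
    """Numerically stable softmax; -inf entries get probability 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def filtered_probs(logits, temperature=1.0, top_k=0, top_p=0.0):
    """Scale logits by temperature, apply top-k and nucleus (top-p) filtering, renormalise."""
    logits = [x / temperature for x in logits]  # new list; the caller's list is untouched
    neg_inf = -float("inf")
    if top_k > 0:
        top_k = min(top_k, len(logits))
        kth_largest = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= kth_largest else neg_inf for x in logits]
    if top_p > 0.0:
        probs = softmax(logits)
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        cumulative = 0.0  # probability mass of the tokens ranked above the current one
        for rank, i in enumerate(order):
            # Drop a token only once the higher-ranked tokens already exceed top_p
            if rank > 0 and cumulative > top_p:
                logits[i] = neg_inf
            cumulative += probs[i]
    return softmax(logits)


print(filtered_probs([2.0, 1.0, 0.5, -1.0], top_p=0.5))  # [1.0, 0.0, 0.0, 0.0]
```

A lower temperature sharpens the distribution before filtering. With the defaults (`temperature=1`, `top_k=0`, `top_p=0.0`), which the generation cell that follows uses, no filtering is applied and each token is sampled from the full softmax distribution.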

In [10]:
%%time
out = sample_sequence(bob, 20, indexed_tokens)
text = tokenizer.decode(indexed_tokens + out, clean_up_tokenization_spaces=True)
print(text)

 Who was Jim Henson? Jim Henson was a great manager for the Computer Entertainment. and he worked the PlayStation. software company game.
 game is
CPU times: user 21.3 s, sys: 287 ms, total: 21.6 s
Wall time: 39.8 s
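`sample_sequence` keeps the input length constant: on each step it drops the oldest token and appends the newly sampled one, so every request sends as many tokens as the example input used for `th.jit.trace`. One likely reason, which is an assumption not stated in this notebook, is that a traced graph can bake in shapes seen during tracing, which makes fixed-length inputs the safe choice. The pure-Python sketch below shows only the window update. `generate_with_fixed_window` and `next_token_fn` are hypothetical, with `next_token_fn` standing in for the remote inference plus sampling step.

```python
from typing import Callable, List


def generate_with_fixed_window(context: List[int], length: int,
                               next_token_fn: Callable[[List[int]], int]) -> List[int]:
    """Generate `length` tokens while keeping the model input at len(context) tokens."""
    window = list(context)
    generated = []
    for _ in range(length):
        token = next_token_fn(window)   # hypothetical stand-in: remote inference + sampling
        generated.append(token)
        window = window[1:] + [token]   # drop the oldest token, append the new one
    return generated


# Toy next-token function (hypothetical): always returns the last token + 1.
tokens = generate_with_fixed_window([10, 11, 12], 4, lambda w: w[-1] + 1)
print(tokens)  # [13, 14, 15, 16]
```

A consequence of this design is that the model only ever sees the most recent `len(context)` tokens, so the start of the prompt falls out of view as generation proceeds.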
