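# Converting a Language Model via Core ML

This notebook wraps a traced GPT-2 next-token predictor in a scripted greedy-generation loop (`FinishMySentence`). The loop keeps appending the most likely token until an end-of-sequence token appears, and the goal is to convert the result to Core ML. The Core ML conversion cells are currently disabled (`FIXME`), so the final cell runs the scripted PyTorch model directly.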


### Imports

In [6]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from coremltools.models import MLModel
from coremltools.converters import convert

### Model: Greedy Sentence Completion

In [8]:
class FinishMySentence(torch.nn.Module):
    """Greedily extend a token sequence until an end token is produced.

    On every step the wrapped predictor is run on the whole sequence so far.
    The arg-max over the scores of the last position is appended as the next
    token id. The module is written so that ``torch.jit.script`` can compile
    the data-dependent ``while`` loop.

    Args:
        model: Predictor called with a 1-D tensor of token ids. It returns a
            tuple whose first element is 2-D scores indexed as
            ``[position, token_id]`` (e.g. the traced GPT-2 predictor).
        eos: Token id that ends generation. It is included in the output.
            The default 198 is presumably GPT-2's newline token — verify with
            the tokenizer.
        max_new_tokens: Optional cap on the number of tokens appended.
            ``None`` (default) or a negative value means no limit. Pass a
            non-negative int to guard against a predictor that never emits
            ``eos``.
    """

    def __init__(self, model=None, eos=198, max_new_tokens=None):
        super().__init__()
        self.eos = torch.tensor([eos])
        self.predict = model
        # Stored as a plain int (-1 == unbounded) so TorchScript infers a
        # concrete `int` attribute rather than needing an Optional annotation.
        self.max_new_tokens = -1 if max_new_tokens is None else int(max_new_tokens)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` followed by greedily generated token ids.

        Generation stops right after ``eos`` has been appended, or once
        ``max_new_tokens`` tokens have been added (when a limit is set).
        """
        generated = x
        steps = 0
        finished = False
        # Do-while style: test the token that was just produced rather than a
        # pre-seeded sentinel, so any `eos` value (including 0) works.
        while not finished and (self.max_new_tokens < 0 or steps < self.max_new_tokens):
            predictions, _ = self.predict(generated)
            token = torch.argmax(predictions[-1, :], dim=0, keepdim=True)
            generated = torch.cat((generated, token), 0)
            finished = bool(token == self.eos)
            steps += 1
        return generated

### Initialize the Token Predictor

In [14]:
# `torchscript=True` configures the model for TorchScript export (used by the
# trace below); `.eval()` switches it to inference mode before tracing.
token_predictor = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True)
token_predictor.eval();

### Trace the Token Predictor

In [15]:
# Example input for tracing: five token ids drawn from [0, 10000). The trace
# records the ops executed on this input, so the ids presumably only need to
# be valid vocabulary indices.
random_tokens = torch.randint(low=0, high=10000, size=(5,))
traced_token_predictor = torch.jit.trace(token_predictor, random_tokens)

### Script the Outer Loop

In [16]:
# Wrap the traced predictor in the generation loop, then *script* it. The loop's
# exit depends on generated values, and tracing would freeze it to the
# iterations seen on a single example.
model = FinishMySentence(traced_token_predictor)
scripted_model = torch.jit.script(model)

### Convert to Core ML (currently disabled — see FIXME)

In [None]:
# FIXME: Core ML conversion of the scripted model is disabled — presumably
# `convert` does not yet succeed on it; re-enable once resolved.
# NOTE(review): the "Run the Model" cell feeds an input named "context" and
# reads an output named "generated", but this call declares no names.
# Presumably `inputs` needs a named tensor spec (coremltools `TensorType`
# with name="context" and a flexible sequence length) — TODO confirm against
# the coremltools version in use.
#mlmodel = convert(
#    scripted_model,
#    inputs=[random_tokens],
#)

### Encode the Sentence Fragment

In [17]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Prompt for the model to complete, encoded as a 1-D tensor of token ids.
sentence_fragment = "The Manhattan bridge is"
context = torch.tensor(tokenizer.encode(sentence_fragment))

### Run the Model

In [18]:
# FIXME: Core ML path disabled until the conversion cell above works.
# NOTE(review): this expects the converted model to expose an input named
# "context" and an output named "generated" — confirm those names are
# declared when calling `convert`.
#coreml_inputs = {"context": context.numpy()}
#prediction_dict = mlmodel.predict(coreml_inputs)
#generated_tensor = prediction_dict["generated"]
#generated_text = tokenizer.decode(generated_tensor)
#print(generated_text)

# Meanwhile, run the scripted PyTorch model directly. This is pure inference,
# so disable autograd. Otherwise every forward pass inside the generation loop
# records gradient history (the model's parameters presumably still require grad).
with torch.no_grad():
    generated_tensor = scripted_model(context)
generated_text = tokenizer.decode(generated_tensor.tolist())
print(generated_text)

The Manhattan bridge is a major artery for the city's subway system, and the bridge is one of the busiest in the country.

