In [11]:
import torch

from transformers import AutoModelForCausalLM, pipeline, GPT2LMHeadModel, GPT2Tokenizer

import numpy as np
import sympy as sp

In [12]:
from model.tokens import Token, TOKEN_TYPE_EXPRESSIONS, TOKEN_TYPE_ANSWERS
from model.equation_interpreter import Equation
from model.vocabulary import Vocabulary
from model.tokens import Token

In [13]:
from datasets import disable_caching
disable_caching()

In [14]:
# Create a combined vocabulary
vocabulary = Vocabulary.construct_from_list(TOKEN_TYPE_EXPRESSIONS + TOKEN_TYPE_ANSWERS)
vectorized_sample = vocabulary.vectorize(["#", "/", "0", "-1", "[SEP]", "TT_INTEGER"])
vectorized_sample, [vocabulary.getToken(idx) for idx in vectorized_sample]

# Global variables
model_name = "JustSumAI"
project_name = "JustSumAI"
repo_name = f"{model_name}_cleaned_gpt2_data"

In [15]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

input_text = "This is my input sequence."
input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids, input_ids.size()

(tensor([[1212,  318,  616, 5128, 8379,   13]]), torch.Size([1, 6]))

In [16]:
vocabulary.end_seq_index

65

# Load model

In [17]:
model = GPT2LMHeadModel.from_pretrained(f"Dragonoverlord3000/{model_name}", force_download=True, revision="6512ca7619eafd2da815379268c73c4382b8d3a1")
model

Downloading pytorch_model.bin: 100%|██████████| 356M/356M [00:37<00:00, 9.42MB/s]
Downloading (…)neration_config.json: 100%|██████████| 134/134 [00:00<00:00, 1.21MB/s]


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(68, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=68, bias=False)
)

In [39]:
test_example_ids = torch.LongTensor([vocabulary.vectorize(["0", "/", "0", "0", "0"])[:-1] + [vocabulary.separator_index]])
test_example_ids[0], test_example_ids.size()

(tensor([64, 19, 39, 19, 19, 19, 67]), torch.Size([1, 7]))

In [40]:
test = model(test_example_ids).logits
print(test.size())
[vocabulary.getToken(torch.argmax(o).item()) for o in test[0]]

torch.Size([1, 7, 68])


['<BEGIN>', '0', '/', '0', '0', '0', 'TT_ZERO']

In [41]:
out = model.generate(test_example_ids, 
                     eos_token_id=vocabulary.end_seq_index, 
                     pad_token_id=vocabulary.mask_index)
out, out.size()



(tensor([[64, 19, 39, 19, 19, 19, 67, 51, 49, 59, 49, 44, 60, 58, 49, 49, 57, 60,
          59, 49, 61, 65]]),
 torch.Size([1, 22]))

In [42]:
l = [vocabulary.getToken(o.item()) for o in out[0]]
l, l.index("[SEP]")

(['<BEGIN>',
  '0',
  '/',
  '0',
  '0',
  '0',
  '[SEP]',
  'TT_ZERO',
  'TT_INTEGER',
  'TT_MINUS',
  'TT_INTEGER',
  'TT_PI',
  'TT_MULTIPLY',
  'TT_PLUS',
  'TT_INTEGER',
  'TT_INTEGER',
  'TT_LOG',
  'TT_MULTIPLY',
  'TT_MINUS',
  'TT_INTEGER',
  'TT_DIVIDE',
  '<END>'],
 6)

In [43]:
eq = Equation([Token(vocabulary.getToken(o.item())) for o in out[0]][l.index("[SEP]")+1:-1], notation="postfix")
eq

<model.equation_interpreter.Equation at 0x105fcced0>

In [44]:
eq.getMathemetaicalNotation()

'((((-Z)+(Z*Pi))-(Z*Log(Z)))/Z)'

______

In [45]:
sp.parse_expr(eq.getMathemetaicalNotation(), evaluate=False)

(Pi*Z - Z*Log(Z) - Z)/Z

### Accept user input

In [25]:
from IPython.display import display
import ipywidgets as widgets

ModuleNotFoundError: No module named 'ipywidgets'

In [None]:
numerator_degree = 3
denominator_degree = 6

In [None]:
def parentherizer(val):
    if "-" in val:
        return f"({val})"
    return val

In [None]:
math_display = widgets.HTMLMath(
    value=f"Sum Math",
    placeholder='',
    description='',
)

numerator_list = [widgets.SelectionSlider(
    options=TOKEN_TYPE_EXPRESSIONS[:-2],
    value='-5',
    description=f'Root {i+1}',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True
) for i in range(numerator_degree)]

denominator_list = [
    widgets.SelectionSlider(
    options=TOKEN_TYPE_EXPRESSIONS[:-2],
    value='-5',
    description=f'Root {i+1}',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True
) for i in range(denominator_degree)]

ui_1 = widgets.VBox(numerator_list)
ui_2 = widgets.VBox(denominator_list)

def f(**kwargs):
    print(kwargs, list(kwargs))
    sum_str = r"$$\sum_{n=1}^{\infty}\frac{" + "".join([f"(n - {parentherizer(kwargs[str(i)])})" for i in range(numerator_degree)]) + "}{" + "".join([f"(n - {parentherizer(kwargs[str(i + numerator_degree)])})" for i in range(denominator_degree)]) + "}" + "$$"
    math_display.value = sum_str
    display(math_display)

out = widgets.interactive(f, **{str(i):v for i,v in enumerate(numerator_list + denominator_list)})

display(out)#ui_1, ui_2, out)

interactive(children=(SelectionSlider(continuous_update=False, description='Root 1', options=('-5', '-4', '-3'…

In [None]:
[v.value for v in numerator_list]

['-5', '-5', '-4/5']

In [None]:
input_sum = vocabulary.vectorize(
    [v.value for v in numerator_list] + ["/"] + [v.value for v in denominator_list]
)[:-1] + [vocabulary.separator_index]
input_sum = torch.LongTensor([input_sum])
out = model.generate(input_sum)
pred = [vocabulary.getToken(o.item()) for o in out[0]]
pred

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:65 for open-end generation.


['<BEGIN>',
 '-5',
 '-5',
 '-4/5',
 '/',
 '5/4',
 '-5/3',
 '3/5',
 '1/2',
 '1/4',
 '3/4',
 '[SEP]',
 'TT_ZERO',
 'TT_INTEGER',
 'TT_ZERO',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_MINUS',
 'TT_INTEGER',
 'TT_INTEGER',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_EULERGAMMA',
 'TT_MINUS',
 'TT_MULTIPLY',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_MINUS',
 'TT_INTEGER',
 'TT_PI',
 'TT_MULTIPLY',
 'TT_INTEGER',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_PLUS',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_PI',
 'TT_MULTIPLY',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_SQRT',
 'TT_INTEGER',
 'TT_PI',
 'TT_MULTIPLY',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 'TT_DIVIDE',
 'TT_INTEGER',
 