In [1]:
import torch

from transformers import AutoModelForCausalLM, pipeline, GPT2LMHeadModel, GPT2Tokenizer

import numpy as np
import sympy as sp

In [2]:
from model.tokens import Token, TOKEN_TYPE_EXPRESSIONS, TOKEN_TYPE_ANSWERS
from model.equation_interpreter import Equation
from model.vocabulary import Vocabulary
from model.tokens import Token

In [3]:
from datasets import disable_caching
# Turn off the `datasets` library's caching before anything else in this
# notebook touches a dataset.
disable_caching()

In [4]:
# Global variables.
# `model_name` is interpolated into the Hub repository id when the model is
# loaded further down; `project_name` and `repo_name` are defined here for
# use elsewhere.
model_name = "JustSumAI"
project_name = "JustSumAI"
repo_name = f"{model_name}_cleaned_gpt2_data"

# Create a combined vocabulary from the expression tokens and the answer tokens.
vocabulary = Vocabulary.construct_from_list(TOKEN_TYPE_EXPRESSIONS + TOKEN_TYPE_ANSWERS)

# Sanity check: vectorize a short token sequence and map every index back to
# its token. This tuple is deliberately the LAST expression of the cell so
# that Jupyter renders it -- placed mid-cell its value is silently discarded.
vectorized_sample = vocabulary.vectorize(["#", "/", "0", "-1", "[SEP]", "TT_INTEGER"])
vectorized_sample, [vocabulary.getToken(idx) for idx in vectorized_sample]

In [5]:
# Try the stock "gpt2" tokenizer on an ordinary English sentence.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_text = "This is my input sequence."
# return_tensors="pt" makes `encode` hand back a PyTorch tensor.
input_ids = tokenizer.encode(input_text, return_tensors="pt")

# Show the encoded ids next to the tensor's dimensions.
(input_ids, input_ids.size())

(tensor([[1212,  318,  616, 5128, 8379,   13]]), torch.Size([1, 6]))

In [6]:
# Show the vocabulary index of the end-of-sequence token; the same attribute
# is passed as `eos_token_id` to `model.generate` later in the notebook.
vocabulary.end_seq_index

65

# Load model

In [7]:
# Load the GPT-2 language-model checkpoint published under
# "Dragonoverlord3000/<model_name>" on the Hugging Face Hub.
#
# `force_download=True` was removed on purpose: it re-fetches the ~356 MB
# weights on every run (the recorded run of this cell died with a read
# timeout while doing exactly that). Without the flag, `from_pretrained`
# reuses the locally cached files, which keeps "Restart & Run All" feasible.
# Pass `force_download=True` by hand only if the local cache is corrupted.
model = GPT2LMHeadModel.from_pretrained(f"Dragonoverlord3000/{model_name}")
model

Downloading (…)lve/main/config.json:   0%|          | 0.00/943 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Downloading pytorch_model.bin:   0%|          | 0.00/356M [00:00<?, ?B/s]

ConnectionError: HTTPSConnectionPool(host='cdn-lfs.huggingface.co', port=443): Read timed out.

In [None]:
# Build the prompt indices: vectorize the expression tokens, drop the last
# index returned by `vectorize`, and put the separator index in its place.
prompt_indices = vocabulary.vectorize(["#", "/", "0", "0"])[:-1] + [vocabulary.separator_index]

# The extra pair of brackets gives the tensor a leading dimension of size 1.
test_example_ids = torch.LongTensor([prompt_indices])

(test_example_ids[0], test_example_ids.size())

In [None]:
# One forward pass over the prompt; keep only the logits of the model output.
test = model(test_example_ids).logits
print(test.size())

# For every row of the first batch item, take the arg-max index and map it
# back to its vocabulary token.
first_sequence_logits = test[0]
[
    vocabulary.getToken(torch.argmax(row_logits).item())
    for row_logits in first_sequence_logits
]

In [None]:
# Special-token ids handed to `generate`: the vocabulary's end-of-sequence
# index as `eos_token_id` and its mask index as `pad_token_id`.
generation_options = {
    "eos_token_id": vocabulary.end_seq_index,
    "pad_token_id": vocabulary.mask_index,
}

# Let the model continue the prompt built above.
out = model.generate(test_example_ids, **generation_options)

(out, out.size())

In [None]:
# Map every generated index of the first sequence back to its token string.
# (The name `l` is kept because the following cell reads it.)
l = [vocabulary.getToken(token_id.item()) for token_id in out[0]]

# Show the tokens together with the position of the "[SEP]" marker.
separator_position = l.index("[SEP]")
(l, separator_position)

In [None]:
# Wrap every generated token string in a `Token` object.
generated_tokens = [Token(vocabulary.getToken(token_id.item())) for token_id in out[0]]

# Keep what follows the "[SEP]" marker, minus the final token, and interpret
# that slice as an equation written in postfix notation.
answer_start = l.index("[SEP]") + 1
eq = Equation(generated_tokens[answer_start:-1], notation="postfix")
eq

In [None]:
# Show the equation's mathematical-notation form; the same call is passed to
# `sp.parse_expr` in the next code cell, so it is presumably a string.
# (The method name is spelled exactly as the project's `Equation` class declares it.)
eq.getMathmetaicalNotation()

______

In [None]:
# Hand the equation's mathematical-notation form to SymPy's parser;
# evaluate=False is passed so the parser does not evaluate the expression.
notation_text = eq.getMathmetaicalNotation()
sp.parse_expr(notation_text, evaluate=False)

### Accept user input

In [8]:
from IPython.display import display
import ipywidgets as widgets

In [9]:
# Number of "Root" sliders created for the numerator of the summand
# (one `(n - root)` factor each).
numerator_degree = 3
# Number of "Root" sliders created for the denominator of the summand.
denominator_degree = 6

In [13]:
def parentherizer(val):
    """Wrap ``val`` in parentheses when it contains a minus sign.

    Used when formatting factors such as ``(n - (-5))`` so that a negative
    value does not end up directly after the subtraction sign.

    Args:
        val: The value to format; the membership test ``"-" in val`` must be
            valid for it (slider values in this notebook are strings).

    Returns:
        ``"(<val>)"`` if ``val`` contains ``"-"``, otherwise ``val`` itself,
        unchanged.
    """
    return f"({val})" if "-" in val else val

In [14]:
# Widget that renders the LaTeX of the currently selected series.
math_display = widgets.HTMLMath(
    value="Sum Math",
    placeholder='',
    description='',
)


def _make_root_sliders(count):
    """Return ``count`` freshly created root-selection sliders.

    Every slider offers all entries of ``TOKEN_TYPE_EXPRESSIONS`` except the
    last two, starts at ``'-5'`` and is labelled ``Root 1`` ... ``Root count``.
    ``continuous_update=False`` means the value only changes once the handle
    is released, not while it is being dragged.

    Args:
        count: How many sliders to create.

    Returns:
        A list of ``count`` independent ``widgets.SelectionSlider`` objects.
    """
    return [
        widgets.SelectionSlider(
            options=TOKEN_TYPE_EXPRESSIONS[:-2],
            value='-5',
            description=f'Root {position + 1}',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
        )
        for position in range(count)
    ]


# One slider per root of the numerator and of the denominator. Both lists come
# from the same helper instead of two copy-pasted comprehensions.
numerator_list = _make_root_sliders(numerator_degree)
denominator_list = _make_root_sliders(denominator_degree)

# Vertical containers grouping the numerator and the denominator sliders.
ui_1 = widgets.VBox(numerator_list)

ui_2 = widgets.VBox(denominator_list)

def f(**kwargs):
    """Re-render the series formula from the current slider values.

    Args:
        **kwargs: Slider values keyed by their stringified position:
            ``"0"`` .. ``str(numerator_degree - 1)`` are the numerator roots,
            the following ``denominator_degree`` keys are the denominator
            roots.

    Side effects:
        Stores the LaTeX string in ``math_display.value`` and displays
        ``math_display``.

    Raises:
        KeyError: If one of the expected position keys is missing.
    """

    def factors(first_key, count):
        # Concatenate the "(n - root)" factors for `count` consecutive keys,
        # parenthesising negative roots via `parentherizer`.
        return "".join(
            f"(n - {parentherizer(kwargs[str(first_key + offset)])})"
            for offset in range(count)
        )

    numerator = factors(0, numerator_degree)
    # Denominator keys continue right after the numerator keys.
    denominator = factors(numerator_degree, denominator_degree)

    sum_str = r"$$\sum_{n=1}^{\infty}\frac{" + numerator + "}{" + denominator + "}" + "$$"
    math_display.value = sum_str
    display(math_display)

# Key every slider by its stringified position (numerator sliders first, then
# the denominator sliders) -- these are the keyword names `f` looks up.
slider_controls = {
    str(position): slider
    for position, slider in enumerate(numerator_list + denominator_list)
}

# NOTE: `out` is also the name of the `model.generate` result earlier in the
# notebook; after this cell it refers to the interactive widget instead.
out = widgets.interactive(f, **slider_controls)

# Only the interactive container is shown; `ui_1` / `ui_2` are not displayed.
display(out)

interactive(children=(SelectionSlider(continuous_update=False, description='Root 1', options=('-5', '-4', '-3'…