In [1]:
import sys
sys.path.append("..")
sys.path.append("../../eqnet")

In [2]:
import torch
import sympy as sp
import numpy as np
import heapq
from copy import deepcopy
from scipy.spatial.distance import cdist
from IPython.display import display, Math, Markdown
from expemb.model import ExpEmbTx

In [3]:
saved_emb = torch.load("../models/20220925-213952183214/saved_embddings.pth", map_location = torch.device("cpu"))

In [4]:
class EmbeddingMathematics:
    def __init__(self, modelpath, exp_list, emb_list):
        self.tokenizer = torch.load(modelpath)["tokenizer"]
        self.model = ExpEmbTx.load_from_checkpoint(modelpath, tokenizer = self.tokenizer)
        self.emb_list = np.array(deepcopy(emb_list))
        self.exp_list = deepcopy(exp_list)
        print(f"emb_list: {self.emb_list.shape}")
        

    @torch.no_grad()
    def find_embedding(self, prefix_eq):
        tensor = self.tokenizer.encode(prefix_eq)
        src = tensor.unsqueeze(1)
        src_mask, src_padding_mask = self.model.create_src_mask(src)
        memory = self.model.encode(src = src, src_mask = src_mask, src_padding_mask = src_padding_mask)
        embedding = memory[:, 0]
        embedding = embedding[1:-1]
        embedding = embedding.max(dim = 0)[0]
        return embedding.detach().cpu().numpy()


    def get_or_create_embedding(self, prefix_eq):
        if prefix_eq not in self.exp_list:
            print(f"Embedding for {prefix_eq} not in the map. Computing...")
            emb = self.find_embedding(prefix_eq)
            self.emb_list = np.append(self.emb_list, emb[None, :], axis = 0)
            self.exp_list.append(prefix_eq)

        return self.emb_list[self.exp_list.index(prefix_eq)]


    def get_analogy(self, x1, y1, y2, expected_x2):
        prefix_to_sympy = self.tokenizer.prefix_to_sympy
        display(Markdown(f"x1: ${sp.latex(prefix_to_sympy(x1))}$ <br />y1: ${sp.latex(prefix_to_sympy(y1))}$ <br />y2: ${sp.latex(prefix_to_sympy(y2))}$ <br />Expected x2: ${sp.latex(prefix_to_sympy(expected_x2))}$"))
        embx1 = self.get_or_create_embedding(x1)
        emby1 = self.get_or_create_embedding(y1)
        emby2 = self.get_or_create_embedding(y2)
        _  = self.get_or_create_embedding(expected_x2)

        embx2 = embx1 - emby1 + emby2
        embx2 = embx2[None, :]

        dist = cdist(self.emb_list, embx2, metric = "cosine")
        dist = dist.squeeze(1)
        maxidx = np.argpartition(dist, 8)[:8]

        # Remove x1, y1, and y2
        if self.exp_list.index(x1) in maxidx:
            maxidx = np.delete(maxidx, np.where(maxidx == self.exp_list.index(x1)))
        if self.exp_list.index(y1) in maxidx:
            maxidx = np.delete(maxidx, np.where(maxidx == self.exp_list.index(y1)))
        if self.exp_list.index(y2) in maxidx:
            maxidx = np.delete(maxidx, np.where(maxidx == self.exp_list.index(y2)))
        maxidx = maxidx[:5]

        closest = {self.exp_list[idx] : dist[idx] for idx in maxidx}
        closest = dict(sorted(closest.items(), key=lambda item: item[1]))
        for exp, score in closest.items():
            print(f"{prefix_to_sympy(exp)} : {score}")

In [5]:
modelpath = "../models/20220925-213952183214/saved_models/best.ckpt"
emb_math = EmbeddingMathematics(modelpath, saved_emb["exp_list"], saved_emb["emb_list"])

emb_list: (2744809, 512)


In [6]:
emb_math.get_analogy("cos x", "sin x", "csc x", "sec x")

x1: $\cos{\left(x \right)}$ <br />y1: $\sin{\left(x \right)}$ <br />y2: $\csc{\left(x \right)}$ <br />Expected x2: $\sec{\left(x \right)}$

Embedding for csc x not in the map. Computing...
sec(x) : 0.3076893021697509
x*cos(x) : 0.3529464371894493
cos(cos(x)) : 0.3548633104427319
x + cos(x) : 0.36909497299135174
cos(log(x)) : 0.3762256392108363


In [7]:
emb_math.get_analogy("sin x", "cos x", "cosh x", "sinh x")

x1: $\sin{\left(x \right)}$ <br />y1: $\cos{\left(x \right)}$ <br />y2: $\cosh{\left(x \right)}$ <br />Expected x2: $\sinh{\left(x \right)}$

Embedding for sinh x not in the map. Computing...
x*cosh(x) : 0.21487585518069596
cosh(sin(x)) : 0.2185385581099304
sin(cosh(x)) : 0.21859363000397347
sin(sinh(x)) : 0.2474086568472712
10*cosh(x) : 0.2536227874401996


In [8]:
emb_math.get_analogy("add pow x INT+ 2 INT- 1", "add x INT+ 1", "add x INT+ 2", "add pow x INT+ 2 INT- 4")

x1: $x^{2} - 1$ <br />y1: $x + 1$ <br />y2: $x + 2$ <br />Expected x2: $x^{2} - 4$

Embedding for add pow x INT+ 2 INT- 1 not in the map. Computing...
Embedding for add x INT+ 1 not in the map. Computing...
Embedding for add x INT+ 2 not in the map. Computing...
Embedding for add pow x INT+ 2 INT- 4 not in the map. Computing...
1/(x**2 + 3) : 0.16227422927981383
2 - 5*x**2 : 0.1658932252991101
1/(x**2 + x) : 0.16627381949072184
2 - 4*x**2 : 0.17113628266389436
1/(x**2 + 2) : 0.17186165858337166


In [9]:
emb_math.get_analogy("add pow x INT+ 2 INT- 1", "add x INT+ 1", "add mul x INT+ 2 INT+ 2", "add mul INT+ 4 pow x INT+ 2 INT- 4")

x1: $x^{2} - 1$ <br />y1: $x + 1$ <br />y2: $2 x + 2$ <br />Expected x2: $4 x^{2} - 4$

Embedding for add mul x INT+ 2 INT+ 2 not in the map. Computing...
Embedding for add mul INT+ 4 pow x INT+ 2 INT- 4 not in the map. Computing...
2*x**4 + 5*x**2/2 : 0.15826225686135298
1/(2*x**2 - 2) : 0.16072043504838307
4*x**(5/2) + 5*x**2/2 : 0.1607452568100033
2 - 5*x**2 : 0.16141582684218314
(2 - x**2)**(-2) : 0.1614917403816265


In [5]:
b = np.random.randn(512)

In [6]:
b.shape

(512,)

In [11]:
c = np.append(a, b[None, :], axis = 0)
c.shape

(101, 512)