In [1]:
import re
import math
import copy
import nltk
import torch
import pickle
import random
import fractions
import numpy as np
import sympy as sp
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm_notebook as tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
df = pd.read_pickle("/kaggle/input/faseroh-dataset-v2/df.pkl")
df = df.drop_duplicates(subset=['simplified_functions']).reset_index(drop = True)
df.head()

Unnamed: 0,function_tree,function,taylor,simplified_functions
0,[x],x,1.0*x,x
1,"[cos, x]",cos(x),0.041667*x**4 - 0.5*x**2 + 1.0,cos(x)
2,"[tan, sqrt, mul, x, x]",tan(x),0.33333*x**3 + 1.0*x,tan(x)
3,"[tanh, x]",tanh(x),-0.33333*x**3 + 1.0*x,tanh(x)
4,"[add, tanh, x, x]",x + tanh(x),-0.33333*x**3 + 2.0*x,x + tanh(x)


In [5]:
class Tokenizer:
    def __init__(self, precision=4, pos_dim=6, max_nums=50):
        self.precision = precision
        self.pos_dim = pos_dim
        self.abs_embeddings, self.rel_embeddings = self.generate_dec_embeddings(max_nums)

    def fit(self, functions):
        self.enc_vocab = self.enc_load_vocab(functions)
        self.dec_vocab, self.target_weights = self.dec_load_vocab()
        self.enc_vocab_size = len(self.enc_vocab)
        self.dec_vocab_size = len(self.dec_vocab)
        self.enc_id_to_token = {idx: token for idx, token in enumerate(self.enc_vocab)}
        self.dec_id_to_token = {idx: token for idx, token in enumerate(self.dec_vocab)}
        self.enc_token_to_id = {token: idx for idx, token in enumerate(self.enc_vocab)}
        self.dec_token_to_id = {token: idx for idx, token in enumerate(self.dec_vocab)}

    def dec_load_vocab(self):
        """
        This is for loading the decoder vocab. While loading I also generate the weights for the tokens, since I use Weighted Cross Entropy
        """
        vocab = []
        weights = []
    
        # Special tokens
        vocab += ['PAD', 'SOS', 'EOS']
        weights += [0.0, 0.0, 1.0]  # Assigning specific weights
    
        # Operators
        vocab += ['+', '-']
        weights += [10] * 2
    
        # Digits
        vocab += [f'{i}' for i in range(10)]
        weights += [2] * 10
    
        # Exponents
        vocab += [f'E{i}' for i in range(-5, 6)]
        weights += [7] * 11
    
        # Variables
        vocab += [f'x{i}' for i in range(5)]
        weights += [14] * 5
    
        # Convert weights to tensor
        weight_tensor = torch.tensor(weights, dtype=torch.float32)
    
        return vocab, weight_tensor


    def return_dec_embeddings(self, seq_len):
        """
        Just returns the precomputed embeddings for model to fetch and use during inference
        """

        abs_pos = self.abs_embeddings[:seq_len]
        rel_pos = self.rel_embeddings[:seq_len, :seq_len]
        
        return abs_pos, rel_pos

    def generate_dec_embeddings(self, max_nums):
        """
        Generates absolute and relative embeddings till some max no. of numbers (for decoder only)
        """
        abs_pos = [0]
        num_pos = [0]
        
        for i in range(max_nums):
            abs_pos.extend([i+1 for j in range(self.precision+3)])
            num_pos.extend([1 - (j/(self.precision+3)) for j in range(self.precision+3)])

        rel_pos = np.zeros((len(num_pos), len(num_pos)))
        
        for i in range(len(num_pos)):
            for j in range(len(num_pos)):
                rel_pos[i][j] = num_pos[i] - num_pos[j]

        return abs_pos, rel_pos

    def sympy_tokenizer(self, expr, path=None):
        """
        Convert a SymPy expression to a tokenized prefix notation.
        During this nested tokenization, it also returns a binary traversal for every node.
        """
        if path is None:
            path = [1]

        if expr.is_Number:
            paths = []
            for i in range(self.precision+2):
                new_path = path[:-1]
                new_path.append(i/(self.precision+1))
                paths.append(new_path)
            return [str(expr)], paths

        if expr.is_Symbol:
            return [str(expr)], [path]

        tokens = [expr.func.__name__.lower()]
        paths = [path]

        for i, arg in enumerate(expr.args):
            new_tokens, new_paths = self.sympy_tokenizer(arg, path + [i])
            tokens.extend(new_tokens)
            paths.extend(new_paths)

        return tokens, paths


    def parse_token(self, token):
        """
        This function converts the tokens to what I actually want it to be.
        In most cases I like sympy's function names for the tokens, but for numbers I use P10 tokenization and for exp1 I just replace with e
        And for fractions, sympy preserves it as a fraction but I convert them back into numbers for more uniformity
        """
        
        # just personal preference for this token
        if token == 'exp1':
            return False, 'e'

        if '/' in token:
            frac = fractions.Fraction(token)
            return True, float(frac)

        try:
            return True, float(token)
        except ValueError:
            return False, token

    def enc_load_vocab(self, functions):
        """
        This function is just for loading up the vocab for encoder for the functions
        """
        
        vocab = ['PAD']
        vocab += ['+', '-']
        vocab += [f'{i}' for i in range(10)]
        vocab += [f'E{i}' for i in range(-1, 2)]

        for fun in tqdm(functions, desc = "Fitting Tokenizer: "):
            tokens, _ = self.sympy_tokenizer(fun)
            for token in tokens:
                isNum, token = self.parse_token(token)
                if isNum:
                    continue
                else:
                    if token not in vocab:
                        vocab.append(token)
        return vocab

    def encode_number(self, x):
        """
        This function is for tokenizing a float number into its P10 tokenization.
        """
            
        sign = '+' if x>=0 else '-'
        x = abs(x)
        x *= 10**self.precision
        x = int(round(x))
        exp = -self.precision

        if x == 0:
            return ['+', 'E0'] + ['0' for _ in range(self.precision)]        
            
        while(x%10 == 0):
            x //= 10
            exp += 1

        x = int(round(x))
        seq = [sign]
        num = []

        for digit in str(x):
            num.append(digit)

        while(len(num)<self.precision):
            num.append('0')
            exp -= 1

        exp += len(num) - 1
        num = num[:self.precision]
        seq.append(f'E{exp}')
        seq.extend(num)

        return seq

    def decode_number(self, x):
        """
        This function is for getting back a float number using its P10 tokenization
        """
        sign = x[0]
        exp = int(x[1][1:])
        x = x[2:]

        num = float(x[0] + "." + "".join(x[1:])) * 10**exp

        if sign == '-':
            num *= -1

        num = round(num, len(x) - exp - 1)

        return num

    def encode_dec(self, poly):
        """
        This function is for tokenizing the output sequence.
        Also returns the absolute and relative positional embeddings of the sequence.
        """
        variables = poly.free_symbols
        x = variables.pop()
        coeff_dict = poly.as_coefficients_dict()
        coeffs = [float(coeff_dict.get(x**i, 0.0)) for i in range(5)]
        
        seq = ['SOS']
        
        for i, coeff in enumerate(coeffs):
            if coeff == 0:
                continue
            seq.append(f'x{i}')
            seq.extend(self.encode_number(coeff))

        abs_pos, rel_pos = self.return_dec_embeddings(len(seq))
        seq.append('EOS')
        
        return [self.dec_token_to_id[token] for token in seq], abs_pos, rel_pos

    def decode_dec(self, seq):
        """
        Just a simple function to convert the sequence to readable tokens
        """
        seq = [self.dec_id_to_token[id] for id in seq]

        return seq

    def seq_to_coeffs(self, seq):
        coeffs = [0] * 5
        
        for i, token in enumerate(seq):
            if token == 'SOS':
                num_list = []

            elif 'x' in token:
                if num_list == []:
                    continue
                num = self.decode_number(num_list)
                degree = int(seq[i-self.precision-3][1])
                coeffs[degree] += num
                num_list = []

            elif token == 'EOS':
                num = self.decode_number(num_list)
                degree = int(seq[i-self.precision-3][1])
                coeffs[degree] += num
                break

            else:
                num_list.append(token)

        for i in range(len(coeffs)):
            coeffs[i] = round(coeffs[i], self.precision)

        return coeffs
        

    def rel_ij(self, i, j):
        """
        Just a simple binary function for the relative embedding calculator
        """
        
        if i == 0:
            if j == 0:
                return 0
            else:
                return -1

        else:
            if j == 0:
                return 0
            else:
                return 1

    def encode_enc(self, fun):
        """
        This code first tokenizes the function and then it uses the custom absolute and relative embeddings according to the function tree
        Since this is my encoder, I am not limited to using some predefined embedding scheme for every token and hence I can leverage a recursive like embeddings to 
        help my model learn good stuff about the tree like nature and structure of the function. Method explained in detail in the README file
        """
        seq = []
        rel_pos = []
        tokens, abs_pos = self.sympy_tokenizer(fun)
        
        for i in range(len(abs_pos)):
            abs_pos[i].extend([0 for _ in range(self.pos_dim - len(abs_pos[i]))])

        for token in tokens:
            isNum, token = self.parse_token(token)
            if isNum:
                num = self.encode_number(token)
                seq.extend(num)
            else:
                seq.append(token)

        seq = [self.enc_token_to_id[token] for token in seq]
        
        for i, row_i in enumerate(abs_pos):
            rel_pos.append([])
            for j, row_j in enumerate(abs_pos):
                r_ij = [self.rel_ij(a, b) for a, b in zip(row_i, row_j)]
                rel_pos[i].append(r_ij)
                
        return seq, abs_pos, rel_pos
            

In [6]:
class DisentangledAttention(nn.Module):
    """
    Implements Disentangled Attention, coded very similarly as MultiHeadAttention of Pytorch.
    Used both Absolute and Relative embeddings biases here.

    Args:
        E_q (int): Size of embedding dim for query
        E_k (int): Size of embedding dim for key
        E_v (int): Size of embedding dim for value
        E_total (int): Total embedding dim of combined heads post input projection. Each head
            has dim E_total // nheads
        abs_q (int): Size of absolute embedding dim for query
        abs_k (int): Size of absolute embedding dim for key
        rel_q (int): Size of relative embedding dim for query
        rel_k (int): Size of relative embedding dim for key
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        is_causal (bool, optional) : Whether to do a causal masking or not. Default: True
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        E_q: int,
        E_k: int,
        E_v: int,
        E_total: int,
        abs_q: int,
        abs_k: int,
        rel_q: int,
        rel_k: int,
        nheads: int,
        dropout: float = 0.0,
        is_causal=False,
        bias=False,
        abs_bias=True,
        rel_bias=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.nheads = nheads
        self.dropout = nn.Dropout(dropout)
        self.is_causal = is_causal
        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
            self.abs_pos = nn.Linear(abs_q, E_total * 2, bias=bias, **factory_kwargs)
            self.rel_pos = nn.Linear(rel_q, E_total * 2, bias=bias, **factory_kwargs)
            self.q_abs = nn.Linear(abs_q, E_total, bias=abs_bias, **factory_kwargs)
            self.k_abs = nn.Linear(abs_k, E_total, bias=abs_bias, **factory_kwargs)
            self.q_rel = nn.Linear(rel_q, E_total, bias=rel_bias, **factory_kwargs)
            self.k_rel = nn.Linear(rel_k, E_total, bias=rel_bias, **factory_kwargs)
        else:
            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
            self.q_abs = nn.Linear(abs_q, E_total, bias=abs_bias, **factory_kwargs)
            self.k_abs = nn.Linear(abs_k, E_total, bias=abs_bias, **factory_kwargs)
            self.q_rel = nn.Linear(rel_q, E_total, bias=rel_bias, **factory_kwargs)
            self.k_rel = nn.Linear(rel_k, E_total, bias=rel_bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.bias = bias

    def attention_with_pos_embeddings(self, query, key, value, a_query, a_key, r_query, r_key, is_causal) -> torch.Tensor:
        """
        Here I am implementing attention with absolute and relative positions, I hypothesize that the absolute positions would be
        needed for knowing the node position and the relative for knowing how important two nodes from any position should be related
        to. Detailed explanation of computing and intuition is given in README.
        """

        scale = 1

        if a_query is not None:
            scale += 1

        if r_query is not None:
            scale += 1

        if r_key is not None:
            scale += 1

        B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
        scale_factor = 1 / math.sqrt(scale * query.size(-1)) 

        attn_weight = query @ key.transpose(-2, -1)

        if a_query is not None:
            attn_bias += a_query @ a_key.transpose(-2, -1)

        if r_key is not None:
            attn_bias += torch.einsum('bhnnd,bhmd->bhmn', r_key, query) 

        if r_query is not None:
            attn_bias += torch.einsum('bhmmd,bhnd->bhmn', r_query, key)

        if is_causal:
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).to(query.device)
            attn_bias.masked_fill_(~temp_mask, float("-inf"))
            attn_bias = attn_bias.to(dtype=query.dtype, device=query.device) 
            
        attn_weight += attn_bias
        attn_weight *= scale_factor
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = self.dropout(attn_weight)
        return attn_weight @ value

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        a_query: torch.Tensor = None,
        a_key: torch.Tensor = None,
        r_query: torch.Tensor = None,
        r_key: torch.Tensor = None,
        
        ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection and computes the necessary matrices for Disentangled Attention
            2. Split heads and prepare for Disentangled Attention
            3. Run Attention with Positional Embeddings
            4. Apply output projection

        Args:
            query (torch.Tensor): query
            key (torch.Tensor): key
            value (torch.Tensor): value
            a_query (torch.Tensor): absolute query
            a_key (torch.Tensor): absolute key
            r_query (torch.Tensor): relative query
            r_key (torch.Tensor): relative key
            is_causal (bool, optional): Whether to apply causal mask. Default: False

        Returns:
            attn_output (torch.Tensor): output of shape (N, L_t, E_q)
        """
        # Step 1. Apply input projection
        if self._qkv_same_embed_dim:
            if query is key and key is value and a_query is a_key and r_query is r_key:
                result = self.packed_proj(query)
                query, key, value = torch.chunk(result, 3, dim=-1)

                if a_query is not None:
                    abs_pos = self.abs_pos(a_query)
                    a_query, a_key = torch.chunk(abs_pos, 2, dim=-1)

                if r_query is not None:
                    rel_pos = self.rel_pos(r_query)
                    r_query, r_key = torch.chunk(rel_pos, 2, dim=-1)

            else:
                q_weight, k_weight, v_weight = torch.chunk(
                    self.packed_proj.weight, 3, dim=0
                )
                query, key, value = (
                    F.linear(query, q_weight),
                    F.linear(key, k_weight),
                    F.linear(value, v_weight),
                )

                if a_query is not None:
                    a_query = self.q_abs(a_query)
                if a_key is not None:
                    a_key = self.k_abs(a_key)
                if r_query is not None:
                    r_query = self.q_rel(r_query)
                if r_key is not None:
                    r_key = self.k_rel(r_key)

        else:
            query = self.q_proj(query)
            key = self.k_proj(key)
            value = self.v_proj(value)
            a_query = self.q_abs(a_query)
            a_key = self.k_abs(a_key)
            r_query = self.q_rel(r_query)
            r_key = self.k_rel(r_key)

        
        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        
        if a_query is not None:
            a_query = a_query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        
        if r_query is not None:
            r_query = r_query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 3)
        
        if a_key is not None:
            a_key = a_key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)
        
        if r_key is not None:
            r_key = r_key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 3)

        attn_output = self.attention_with_pos_embeddings(
            query, key, value, a_query, a_key, r_query, r_key, self.is_causal
        )
        
        attn_output = attn_output.transpose(1, 2).flatten(-2)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [7]:
class Encoder(nn.Module):
    """
    A pretty straightforward implementation of the encoder in my model
    """
    def __init__(self, embed_dim, expansion_dim, n_heads, pos_dim, dropout):
        super().__init__()
        self.mha = DisentangledAttention(embed_dim, embed_dim, embed_dim, embed_dim, pos_dim, pos_dim, pos_dim, pos_dim, n_heads, dropout)
        self.mha_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, expansion_dim),
            nn.ReLU(),
            nn.Linear(expansion_dim, embed_dim)
        )
        self.ffn_norm = nn.LayerNorm(embed_dim)

    def forward(self, x, a, r):
        x = x + self.mha(x, x, x, a, a, r, r)
        x = self.mha_norm(x)
        x = x + self.ffn(x)
        x = self.ffn_norm(x)

        return x

In [8]:
class Decoder(nn.Module):
    """
    A pretty straightforward implementation of the decoder in my model
    """
    def __init__(self, embed_dim, expansion_dim, n_heads, enc_pos_dim, dec_pos_dim, dropout):
        super().__init__()
        self.mha = DisentangledAttention(embed_dim, embed_dim, embed_dim, embed_dim, dec_pos_dim, dec_pos_dim, dec_pos_dim, dec_pos_dim, n_heads, dropout, is_causal = True)
        self.mha_norm = nn.LayerNorm(embed_dim)
        self.cross = DisentangledAttention(embed_dim, embed_dim, embed_dim, embed_dim, dec_pos_dim, enc_pos_dim, dec_pos_dim, enc_pos_dim, n_heads, dropout)
        self.cross_norm = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, expansion_dim),
            nn.ReLU(),
            nn.Linear(expansion_dim, embed_dim)
        )
        self.ffn_norm = nn.LayerNorm(embed_dim)

    def forward(self, enc_x, dec_x, enc_a, dec_a, enc_r, dec_r):
        dec_x = dec_x + self.mha(dec_x, dec_x, dec_x, dec_a, dec_a, dec_r, dec_r)
        dec_x = self.mha_norm(dec_x)
        dec_x = dec_x + self.cross(dec_x, enc_x, enc_x, dec_a, enc_a, dec_r, enc_r)
        dec_x = self.cross_norm(dec_x)
        dec_x = dec_x + self.ffn(dec_x)
        dec_x = self.ffn_norm(dec_x)

        return dec_x

In [9]:
class Transformer(nn.Module):
    """
    Class for my Transformer Achitecture, I have not deviated very much from the vanilla transformer architecture but am using disentangled attention
    """
    def __init__(self, enc_vocab_size, dec_vocab_size, n_encoders, n_decoders, embed_dim, expansion_dim, enc_pos_dim, dec_pos_dim, n_heads, dropout):
        super().__init__()
        self.dec_vocab_size = dec_vocab_size
        self.linear = nn.Linear(embed_dim, dec_vocab_size)
        self.enc_embed = nn.Embedding(enc_vocab_size, embed_dim)
        self.dec_embed = nn.Embedding(dec_vocab_size, embed_dim)
        self.Encoders = nn.ModuleList(Encoder(embed_dim, expansion_dim, n_heads, enc_pos_dim, dropout) for _ in range(n_encoders))
        self.Decoders = nn.ModuleList(Decoder(embed_dim, expansion_dim, n_heads, enc_pos_dim, dec_pos_dim, dropout) for _ in range(n_decoders))

    def forward(self, enc_seq, dec_seq, enc_a, enc_r, dec_a, dec_r):
        enc_x = self.enc_embed(enc_seq)
        dec_x = self.dec_embed(dec_seq)
        
        for encoder in self.Encoders:
            enc_x = encoder(enc_x, enc_a, enc_r)
            
        for decoder in self.Decoders:
            dec_x = decoder(enc_x, dec_x, enc_a, dec_a, enc_r, dec_r)

        logits = self.linear(dec_x)
        
        return logits

In [10]:
class TaylorDataset(Dataset):
    """
    Pytorch custom dataset for this specific task
    """
    def __init__(self, df, tokenizer):
        super().__init__()
        self.functions = df['simplified_functions'].to_list()
        self.polynomials = df['taylor'].to_list()
        self.tokenizer = tokenizer
        self.pos_dim = self.tokenizer.pos_dim
        self.enc_vocab_size = self.tokenizer.enc_vocab_size
        self.dec_vocab_size = self.tokenizer.dec_vocab_size

    def __len__(self):
        return len(self.functions)

    def __getitem__(self, idx):
        fun = self.functions[idx]
        poly = self.polynomials[idx]
        enc_seq, enc_a, enc_r = self.tokenizer.encode_enc(fun)
        out_seq, dec_a, dec_r = self.tokenizer.encode_dec(poly)

        enc_seq = torch.tensor(enc_seq)
        enc_a = torch.tensor(enc_a)
        enc_r = torch.tensor(enc_r)
        
        out_seq = torch.tensor(out_seq)
        dec_a = torch.tensor(dec_a)
        dec_r = torch.tensor(dec_r)
        
        return {
            'inputs': (enc_seq, enc_a, enc_r),
            'outputs': (out_seq, dec_a, dec_r)
        }

In [11]:
def collate_fn(batch):
    """
    Custom collate_fn since the paddings are kinda weird and demands a collate fn for proper batch training
    """
    enc_seqs, enc_as, enc_rs = [], [], []
    out_seqs, dec_as, dec_rs = [], [], []

    for item in batch:
        inputs, outputs = item['inputs'], item['outputs']
        
        enc_seqs.append(inputs[0])    
        enc_as.append(inputs[1])    
        enc_rs.append(inputs[2])     

        out_seqs.append(outputs[0])   
        dec_as.append(outputs[1])     
        dec_rs.append(outputs[2])     

    enc_seqs = pad_sequence(enc_seqs, batch_first=True, padding_value=0)
    enc_as = pad_sequence(enc_as, batch_first=True, padding_value=0)
    out_seqs = pad_sequence(out_seqs, batch_first=True, padding_value=0)
    dec_as = pad_sequence(dec_as, batch_first=True, padding_value=0).unsqueeze(-1)

    
    def pad_3d_tensors(tensor_list):
        """
        Since my relative embeddings have another dimension (because they are unique for every ith and jth value of the attention computation),
        They require to be padded in two dimension, hence this function implements them
        """
        max_seq_len = max(tensor.shape[0] for tensor in tensor_list) 
        feature_dim = tensor_list[0].shape[2]  

        padded_tensors = []
        for tensor in tensor_list:
            seq_len = tensor.shape[0]
            pad_size = max_seq_len - seq_len
            padded_tensor = F.pad(tensor, (0, 0, 0, pad_size, 0, pad_size), value=0)
            padded_tensors.append(padded_tensor)

        return torch.stack(padded_tensors)

    def pad_2d_tensors(tensor_list):
        max_seq_len = max(tensor.shape[0] for tensor in tensor_list)  # Max seq length

        padded_tensors = []
        for tensor in tensor_list:
            seq_len = tensor.shape[0]
            pad_size = max_seq_len - seq_len
            padded_tensor = F.pad(tensor, (0, pad_size, 0, pad_size), value=0)
            padded_tensors.append(padded_tensor)

        return torch.stack(padded_tensors)

    enc_rs = pad_3d_tensors(enc_rs)  # 3D padding
    dec_rs = pad_2d_tensors(dec_rs).unsqueeze(-1) # 2D padding

    return {
        'inputs': (enc_seqs, enc_as, enc_rs),
        'outputs': (out_seqs, dec_as, dec_rs)
    }

In [12]:
def train_one_epoch(model, dataloader, criterion, optimizer, scheduler):
    """
    Very self explanatory and straighforward code to train the transformer
    """
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc='Training', leave = False):
        inputs, outputs = batch['inputs'], batch['outputs']
        enc_seqs, enc_as, enc_rs = inputs
        out_seqs, dec_as, dec_rs = outputs

        enc_seqs, enc_as, enc_rs = enc_seqs.to(device), enc_as.to(device, dtype = torch.float32), enc_rs.to(device, dtype = torch.float32)
        out_seqs, dec_as, dec_rs = out_seqs.to(device), dec_as.to(device, dtype = torch.float32), dec_rs.to(device, dtype = torch.float32)
        
        optimizer.zero_grad()
        logits = model(enc_seqs, out_seqs[:, :-1], enc_as, enc_rs, dec_as, dec_rs)
        targets = out_seqs[:, 1:].long()
        loss = criterion(logits.reshape(-1, model.dec_vocab_size), targets.reshape(-1))
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        scheduler.step()
    
    return total_loss / len(dataloader)

In [13]:
def validate(model, dataloader, criterion):
    """
    Very self explanatory and straighforward code to train the transformer
    Here I am not generating iteratively simply because I am not benchmarking rather just curious to know if the model is not just memorizing
    If the valid loss also reduces, just gives me a sanity check that the transformer is learning something meaningful instead of memorizing.
    """
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation', leave = False):
            inputs, outputs = batch['inputs'], batch['outputs']
            enc_seqs, enc_as, enc_rs = inputs
            out_seqs, dec_as, dec_rs = outputs

            enc_seqs, enc_as, enc_rs = enc_seqs.to(device), enc_as.to(device, dtype = torch.float32), enc_rs.to(device, dtype = torch.float32)
            out_seqs, dec_as, dec_rs = out_seqs.to(device), dec_as.to(device, dtype = torch .float32), dec_rs.to(device, dtype = torch.float32)
            
            logits = model(enc_seqs, out_seqs[:, :-1], enc_as, enc_rs, dec_as, dec_rs)
            targets = out_seqs[:, 1:].long()
            loss = criterion(logits.reshape(-1, model.dec_vocab_size), targets.reshape(-1))
            total_loss += loss.item()

        return total_loss / len(dataloader)

In [14]:
tokenizer = Tokenizer()
tokenizer.fit(df['simplified_functions'])

Fitting Tokenizer:   0%|          | 0/2521 [00:00<?, ?it/s]

In [15]:
train_df, test_df = train_test_split(df, train_size = 0.9)
train_df, val_df = train_test_split(train_df, train_size = 0.8)

In [16]:
train_data = TaylorDataset(train_df, tokenizer)
val_data = TaylorDataset(val_df, tokenizer)

In [17]:
train_load = DataLoader(train_data, shuffle = True, batch_size = 32, collate_fn = collate_fn, num_workers = 4)
val_load = DataLoader(val_data, shuffle = False, batch_size = 16, collate_fn = collate_fn, num_workers = 4)

In [18]:
model = Transformer(
    enc_vocab_size=train_data.enc_vocab_size, dec_vocab_size=train_data.dec_vocab_size, n_encoders=12, n_decoders=12,
    embed_dim=768, expansion_dim=3072, enc_pos_dim=train_data.pos_dim, dec_pos_dim=1, n_heads=4, dropout=0.1
).to(device)

criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=0, weight=tokenizer.target_weights.to(device))
optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

In [19]:
num_epochs = 100
best_val_loss = float('inf')
save_path = 'best_transformer.pth'
patience = 10
no_improve_epochs = 0

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_load, criterion, optimizer, scheduler)
    val_loss = validate(model, val_load, criterion)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"New best model saved with validation loss: {best_val_loss:.4f}")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        if no_improve_epochs >= patience:
            print("Early stopping triggered!")
            break
    print()

model.load_state_dict(torch.load(save_path))

Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 1/100 - Train Loss: 1.8447 - Val Loss: 1.1276
New best model saved with validation loss: 1.1276



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 2/100 - Train Loss: 1.0924 - Val Loss: 0.9510
New best model saved with validation loss: 0.9510



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 3/100 - Train Loss: 0.9989 - Val Loss: 1.0106



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 4/100 - Train Loss: 0.9351 - Val Loss: 0.9113
New best model saved with validation loss: 0.9113



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 5/100 - Train Loss: 0.8476 - Val Loss: 0.8087
New best model saved with validation loss: 0.8087



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 6/100 - Train Loss: 0.8480 - Val Loss: 0.8223



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 7/100 - Train Loss: 0.8181 - Val Loss: 0.7122
New best model saved with validation loss: 0.7122



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 8/100 - Train Loss: 0.7073 - Val Loss: 0.6729
New best model saved with validation loss: 0.6729



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 9/100 - Train Loss: 0.6520 - Val Loss: 0.6219
New best model saved with validation loss: 0.6219



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 10/100 - Train Loss: 0.6040 - Val Loss: 0.5920
New best model saved with validation loss: 0.5920



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 11/100 - Train Loss: 0.5833 - Val Loss: 0.5851
New best model saved with validation loss: 0.5851



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 12/100 - Train Loss: 0.7753 - Val Loss: 0.6844



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 13/100 - Train Loss: 0.6868 - Val Loss: 0.6080



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 14/100 - Train Loss: 0.6152 - Val Loss: 0.6044



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 15/100 - Train Loss: 0.5809 - Val Loss: 0.5788
New best model saved with validation loss: 0.5788



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 16/100 - Train Loss: 0.5503 - Val Loss: 0.5617
New best model saved with validation loss: 0.5617



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 17/100 - Train Loss: 0.5167 - Val Loss: 0.5394
New best model saved with validation loss: 0.5394



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 18/100 - Train Loss: 0.4848 - Val Loss: 0.5278
New best model saved with validation loss: 0.5278



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 19/100 - Train Loss: 0.4601 - Val Loss: 0.5051
New best model saved with validation loss: 0.5051



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 20/100 - Train Loss: 0.4408 - Val Loss: 0.4943
New best model saved with validation loss: 0.4943



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 21/100 - Train Loss: 0.4268 - Val Loss: 0.4909
New best model saved with validation loss: 0.4909



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 22/100 - Train Loss: 0.4158 - Val Loss: 0.4923



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 23/100 - Train Loss: 0.5143 - Val Loss: 0.6156



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 24/100 - Train Loss: 0.5407 - Val Loss: 0.5718



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 25/100 - Train Loss: 0.4994 - Val Loss: 0.5262



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 26/100 - Train Loss: 0.4933 - Val Loss: 0.5304



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 27/100 - Train Loss: 0.4586 - Val Loss: 0.5192



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 28/100 - Train Loss: 0.4326 - Val Loss: 0.5003



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 29/100 - Train Loss: 0.4202 - Val Loss: 0.5237



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 30/100 - Train Loss: 0.4035 - Val Loss: 0.5031



Training:   0%|          | 0/57 [00:00<?, ?it/s]

Validation:   0%|          | 0/29 [00:00<?, ?it/s]

Epoch 31/100 - Train Loss: 0.3767 - Val Loss: 0.4956
Early stopping triggered!


  model.load_state_dict(torch.load(save_path))


<All keys matched successfully>

In [20]:
def finite_state_prediction(logits, state):
    
    logits = logits.squeeze().reshape(1, -1)
    
    state_dict = ['number', 'degree', 'exponent', 'sign']
    assert state in state_dict, "state not in predifined states"

    number_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val in '1234567890']
    degree_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val in 'x0x1x2x3x4EOS' and val not in '1234567890']
    sign_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val in '+-']
    exponent_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val not in 'PADSOS1234567890x0x1x2x3x4+-EOS']

    state_indices = {
        'number': number_states,
        'degree': degree_states,
        'exponent': exponent_states,
        'sign': sign_states
    }
    
    valid_indices = state_indices[state]

    state_logits = logits[0, valid_indices]
    
    max_idx = valid_indices[torch.argmax(state_logits).item()]
    
    return torch.tensor(max_idx).to(device)

In [21]:
def evaluate_model(model, df, tokenizer):
    """
    Prints some good metrics about the transformer model using test data. Here I do generate autoregressively and properly benchmark
    
    """
    model.eval()
    preds = []
    targets = []
    max_len = 1 + 7*(3 + tokenizer.precision)

    with torch.no_grad():
        for i, row in tqdm(df.iterrows(), desc="Generating Predictions: ", total = df.shape[0]):
            function = row['simplified_functions']
            polynomial = row['taylor']

            seq, enc_abs, enc_rel = tokenizer.encode_enc(function)
            target, _, _ = tokenizer.encode_dec(polynomial)
            pred = [tokenizer.dec_token_to_id['SOS']]

            seq = torch.tensor(seq).to(device).unsqueeze(0)
            pred = torch.tensor(pred).to(device).unsqueeze(0)
            enc_abs = torch.tensor(enc_abs).to(device, dtype = torch.float32).unsqueeze(0)
            enc_rel = torch.tensor(enc_rel).to(device, dtype = torch.float32).unsqueeze(0)
            cur_state = 'degree'
            next_state_in = 1
            
            while(pred[:, -1] != tokenizer.dec_token_to_id['EOS']):
                
                dec_abs, dec_rel = tokenizer.return_dec_embeddings(pred.shape[0])
                dec_abs = torch.tensor(dec_abs).to(device, dtype = torch.float32).unsqueeze(0).unsqueeze(-1)
                dec_rel = torch.tensor(dec_rel).to(device, dtype = torch.float32).unsqueeze(0).unsqueeze(-1)
                logits = model(seq, pred, enc_abs, enc_rel, dec_abs, dec_rel)
                logits = logits[:, -1, :]
                pred_token = finite_state_prediction(logits, cur_state).reshape(1, 1)
                pred = torch.cat((pred, pred_token), dim=1)
                next_state_in -= 1

                if next_state_in <= 0:
                    if cur_state == 'degree':
                        cur_state = 'sign'
                        next_state_in = 1
                        
                    elif cur_state == 'sign':
                        cur_state = 'exponent'
                        next_state_in = 1

                    elif cur_state == 'exponent':
                        cur_state = 'number'
                        next_state_in = tokenizer.precision

                    elif cur_state == 'number':
                        cur_state = 'degree'
                        next_state_in = 1
                        
                if pred.shape[1] > max_len:
                    break

            pred = [tokenizer.dec_id_to_token[token] for token in pred.squeeze().tolist()]
            target = [tokenizer.dec_id_to_token[token] for token in target]
            pred = tokenizer.seq_to_coeffs(pred)
            target = tokenizer.seq_to_coeffs(target)
            preds.append(pred)
            targets.append(target)
            
    return preds, targets

In [22]:
preds, targets = evaluate_model(model, test_df, tokenizer)

Generating Predictions:   0%|          | 0/253 [00:00<?, ?it/s]

In [23]:
def polynomial_rmse(preds, targets, n=100, x_range=(-1, 1)):
    """
    Generates n random points of x evenly spaced between a certain range, this is deterministic to not introduce any bias for one function, and centered around
    """
    rmse = 0.0
    
    for pred, target in zip(preds, targets):
        x = sp.symbols('x')
        
        # Create polynomial functions from coefficients
        pred_poly = sum(coef * x**i for i, coef in enumerate(pred))
        target_poly = sum(coef * x**i for i, coef in enumerate(target))
        
        # Generate random x values
        x_vals = np.linspace(x_range[0], x_range[1], n)
        
        # Convert sympy expressions to numpy functions
        pred_func = sp.lambdify(x, pred_poly, 'numpy')
        target_func = sp.lambdify(x, target_poly, 'numpy')
        
        # Evaluate functions
        y_pred = pred_func(x_vals)
        y_true = target_func(x_vals)
        
        # Compute R^2 score
        rmse += mean_squared_error(y_true, y_pred)
    
    rmse /= len(preds)
    
    return np.sqrt(rmse)

In [24]:
def coeff_rmse(preds, targets):

    rmse = 0.0
    for pred, target in zip(preds, targets):
        rmse += mean_squared_error(pred, target)

    rmse /= len(preds)
    
    return np.sqrt(rmse)

In [25]:
polynomial_rmse_value = polynomial_rmse(preds, targets)
coeff_rmse_value = coeff_rmse(preds, targets)

print(f"Polynomial RMSE: {polynomial_rmse_value:.4f}")
print(f"Coefficient RMSE: {coeff_rmse_value:.4f}")

Polynomial RMSE: 15.6741
Coefficient RMSE: 18.2635
