In [115]:
import re
import math
import copy
import nltk
import torch
import pickle
import random
import fractions
import numpy as np
import sympy as sp
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm_notebook as tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [116]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [117]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [118]:
df = pd.read_pickle("/kaggle/input/df.pkl") 
df = df.drop_duplicates(subset=['simplified_functions']).reset_index(drop = True)
df.head()

Unnamed: 0,function_tree,function,taylor,simplified_functions
0,[x],x,1.0*x,x
1,"[cos, x]",cos(x),0.041667*x**4 - 0.5*x**2 + 1.0,cos(x)
2,"[tan, sqrt, mul, x, x]",tan(x),0.33333*x**3 + 1.0*x,tan(x)
3,"[tanh, x]",tanh(x),-0.33333*x**3 + 1.0*x,tanh(x)
4,"[add, tanh, x, x]",x + tanh(x),-0.33333*x**3 + 2.0*x,x + tanh(x)


In [119]:
class Tokenizer:
    def __init__(self, precision=4, pos_dim=6, max_nums=50):
        self.precision = precision

        self.pos_dim = pos_dim

        self.abs_embeddings, self.rel_embeddings = self.generate_dec_embeddings(max_nums)

    def fit(self, functions):
        self.enc_vocab = self.enc_load_vocab(functions)
        self.dec_vocab, self.target_weights = self.dec_load_vocab()
        self.enc_vocab_size = len(self.enc_vocab)
        self.dec_vocab_size = len(self.dec_vocab)
        self.enc_id_to_token = {idx: token for idx, token in enumerate(self.enc_vocab)}
        self.dec_id_to_token = {idx: token for idx, token in enumerate(self.dec_vocab)}
        self.enc_token_to_id = {token: idx for idx, token in enumerate(self.enc_vocab)}
        self.dec_token_to_id = {token: idx for idx, token in enumerate(self.dec_vocab)}
        print("Tokenizer fitted.")
        print(f"Encoder vocab size: {self.enc_vocab_size}")
        print(f"Decoder vocab size: {self.dec_vocab_size}")

    def dec_load_vocab(self):
        """
        This is for loading the decoder vocab. While loading I also generate the weights for the tokens, since I use Weighted Cross Entropy
        """
        vocab = []
        weights = []

        vocab += ['PAD', 'SOS', 'EOS']
        weights += [0.0, 1.0, 1.0] 

        vocab += ['+', '-']
        weights += [10] * 2

        vocab += [f'{i}' for i in range(10)]
        weights += [2] * 10

        vocab += [f'E{i}' for i in range(-5, 6)]
        weights += [7] * 11

        vocab += [f'x{i}' for i in range(5)]
        weights += [14] * 5

        weight_tensor = torch.tensor(weights, dtype=torch.float32)

        return vocab, weight_tensor

    def return_dec_embeddings(self, seq_len):
        abs_pos = self.abs_embeddings[:seq_len]
        rel_pos = self.rel_embeddings[:seq_len, :seq_len]
        return abs_pos, rel_pos

    def generate_dec_embeddings(self, max_nums):
        abs_pos = [0]
        num_pos = [0]
        for i in range(max_nums):
            abs_pos.extend([i+1 for j in range(self.precision+3)])
            num_pos.extend([1 - (j/(self.precision+3)) for j in range(self.precision+3)])
        rel_pos = np.zeros((len(num_pos), len(num_pos)))
        for i in range(len(num_pos)):
            for j in range(len(num_pos)):
                rel_pos[i][j] = num_pos[i] - num_pos[j]

        return list(abs_pos), np.array(rel_pos)

    def sympy_tokenizer(self, expr, path=None):
        """
        Convert a SymPy expression to a tokenized prefix notation.
        Returns tokens and paths (paths less relevant for basic LSTM).
        """
        if path is None:
            path = [1] 

        if expr.is_Number:

            return [str(expr)], [path] 

        if expr.is_Symbol:
            return [str(expr)], [path]

        tokens = [expr.func.__name__.lower()]
        paths = [path]

        for i, arg in enumerate(expr.args):

            new_tokens, new_paths = self.sympy_tokenizer(arg, path + [i])
            tokens.extend(new_tokens)

        return tokens, paths

    def parse_token(self, token):
        if token == 'exp1':
            return False, 'e'
        if '/' in token:
            try:
                frac = fractions.Fraction(token)
                return True, float(frac)
            except ValueError: 
                 return False, token
        try:

            return True, float(token)
        except ValueError:
            return False, token

    def enc_load_vocab(self, functions):
        vocab = ['PAD', '<UNK>'] 
        vocab += ['+', '-']
        vocab += [f'{i}' for i in range(10)]
        vocab += [f'E{i}' for i in range(-1, 2)] 

        all_tokens = set()
        for fun in tqdm(functions, desc = "Fitting Tokenizer (Encoder): "):
            try:
                tokens, _ = self.sympy_tokenizer(fun)
                for token in tokens:
                    isNum, parsed_token = self.parse_token(token)
                    if not isNum:
                        all_tokens.add(parsed_token)
            except Exception as e:
                print(f"Warning: Skipping function due to tokenization error: {fun} -> {e}")
                continue 

        vocab.extend(sorted(list(all_tokens)))
        return vocab

    def encode_number(self, x):
        sign = '+' if x>=0 else '-'
        x = abs(x)

        if np.isclose(x, 0):

             return ['+', 'E0'] + ['0'] * self.precision

        sci_not = f"{x:.{self.precision}e}"
        parts = sci_not.split('e')
        mantissa_str = parts[0].replace('.', '').replace('-', '') 
        exp = int(parts[1])

        effective_exp = exp

        digits = list(mantissa_str.ljust(self.precision + 1, '0'))[:self.precision +1] 

        final_digits = digits[1:self.precision+1] 

        seq = [sign, f'E{effective_exp}', digits[0]] 
        seq.extend(final_digits) 

        if len(seq) > self.precision + 2:
             seq = seq[:self.precision+2]
        elif len(seq) < self.precision + 2:

             seq.extend(['0'] * (self.precision + 2 - len(seq)))

        exp_val = effective_exp
        if exp_val < -5: exp_val = -5
        if exp_val > 5: exp_val = 5
        seq[1] = f'E{exp_val}' 

        return seq

    def decode_number(self, x):
        """
        This function is for getting back a float number using its P10 tokenization
        """
        if not x or len(x) < self.precision + 2: 
             return 0.0

        sign_token = x[0]
        exp_token = x[1]
        digit_tokens = x[2:]

        sign = -1 if sign_token == '-' else 1
        try:
            exp = int(exp_token[1:]) 
        except (ValueError, IndexError):
            print(f"Warning: Invalid exponent token: {exp_token}")
            return 0.0 

        digit_tokens = list(digit_tokens) 
        while len(digit_tokens) < self.precision:
            digit_tokens.append('0')
        digit_tokens = digit_tokens[:self.precision] 

        num_str = digit_tokens[0] + "." + "".join(digit_tokens[1:])

        try:
            num = float(num_str) * (10**exp)
        except ValueError:
            print(f"Warning: Invalid number format during decode: {num_str}")
            return 0.0

        num *= sign

        return num

    def encode_dec(self, poly):
        """
        Tokenizes the output polynomial sequence for the LSTM decoder.
        Returns only the token IDs.
        """
        variables = list(poly.free_symbols)

        if not variables:

             x = sp.symbols('x') 
             coeffs = [float(poly)] + [0.0] * 4 
        elif len(variables) > 1:
             print(f"Warning: Multiple variables found in polynomial: {variables}. Using {variables[0]}.")
             x = variables[0]

             try:
                 coeff_dict = poly.as_coefficients_dict()
                 coeffs = [float(coeff_dict.get(x**i, 0.0)) for i in range(5)]
             except Exception as e:
                 print(f"Error extracting coefficients for multi-variable poly: {poly} -> {e}")
                 coeffs = [0.0] * 5 
        else:
             x = variables[0]
             coeff_dict = poly.as_coefficients_dict()
             coeffs = [float(coeff_dict.get(x**i, 0.0)) for i in range(5)]

        seq = ['SOS']
        added_term = False
        for i, coeff in enumerate(coeffs):

            if np.isclose(coeff, 0.0, atol=1e-9): 
                continue
            seq.append(f'x{i}')
            seq.extend(self.encode_number(coeff))
            added_term = True

        if not added_term:
             seq.append('x0')
             seq.extend(self.encode_number(0.0))

        seq.append('EOS')

        token_ids = []
        for token in seq:
            token_id = self.dec_token_to_id.get(token)
            if token_id is None:
                print(f"Warning: Token '{token}' not found in decoder vocab during encode_dec. Using PAD.")
                token_ids.append(self.dec_token_to_id['PAD']) 
            else:
                token_ids.append(token_id)

        return token_ids 

    def decode_dec(self, seq_ids):
        """
        Converts a sequence of decoder token IDs back to tokens.
        """
        return [self.dec_id_to_token.get(id, '<UNK>') for id in seq_ids]

    def seq_to_coeffs(self, seq_tokens):
        """
        Converts a sequence of decoded tokens (strings) back to coefficient list.
        """
        coeffs = [0.0] * 5 
        num_list = []
        current_degree = -1 

        if not seq_tokens or seq_tokens[0] != 'SOS':
             print(f"Warning: Sequence does not start with SOS: {seq_tokens[:5]}")

        seq_iter = iter(seq_tokens)
        if seq_tokens[0] == 'SOS':
            next(seq_iter) 

        for token in seq_iter:
            if token == 'EOS':

                if current_degree != -1 and num_list:
                    num = self.decode_number(num_list)
                    if 0 <= current_degree < 5:
                        coeffs[current_degree] = num 
                    else:
                        print(f"Warning: Invalid degree {current_degree} encountered before EOS.")
                break 

            elif token.startswith('x') and token[1:].isdigit():

                if current_degree != -1 and num_list:
                     num = self.decode_number(num_list)
                     if 0 <= current_degree < 5:
                         coeffs[current_degree] = num 
                     else:
                         print(f"Warning: Invalid degree {current_degree} encountered.")

                num_list = []
                try:
                    current_degree = int(token[1:])
                except ValueError:
                     print(f"Warning: Invalid degree token {token}")
                     current_degree = -1 

            elif token in ['+', '-'] or token.startswith('E') or token.isdigit():

                 if current_degree != -1: 
                     num_list.append(token)
                 else:

                     pass 

            elif token not in ['SOS']: 
                 print(f"Warning: Unexpected token '{token}' during coefficient parsing.")

                 num_list = []

        if current_degree != -1 and num_list and token != 'EOS':
             num = self.decode_number(num_list)
             if 0 <= current_degree < 5:
                 coeffs[current_degree] = num
             else:
                 print(f"Warning: Invalid degree {current_degree} encountered at end of sequence.")

        coeffs = [round(c, self.precision) for c in coeffs]

        return coeffs

    def encode_enc(self, fun):
        """
        Tokenizes the input function for the LSTM encoder.
        Returns only the sequence of token IDs.
        """
        try:
            tokens, _ = self.sympy_tokenizer(fun) 
        except Exception as e:
             print(f"Error tokenizing function {fun}: {e}")

             return [self.enc_token_to_id.get('PAD', 0)] 

        seq = []
        for token in tokens:
            isNum, parsed_token = self.parse_token(token)
            if isNum:

                num_tokens = self.encode_number(parsed_token)
                seq.extend(num_tokens)
            else:

                 processed_token = str(parsed_token).lower() 
                 seq.append(processed_token)

        token_ids = []
        for token in seq:
             token_id = self.enc_token_to_id.get(token, self.enc_token_to_id.get('<UNK>'))
             if token_id is None: 
                 print(f"Critical Error: <UNK> token not found in encoder vocab!")
                 token_id = 0 
             token_ids.append(token_id)

        return token_ids

In [120]:
class TaylorDataset(Dataset):
    def __init__(self, df, tokenizer):
        super().__init__()
        self.functions = df['simplified_functions'].to_list()
        self.polynomials = df['taylor'].to_list()
        self.tokenizer = tokenizer
        self.enc_vocab_size = self.tokenizer.enc_vocab_size
        self.dec_vocab_size = self.tokenizer.dec_vocab_size

    def __len__(self):
        return len(self.functions)

    def __getitem__(self, idx):
        fun = self.functions[idx]
        poly = self.polynomials[idx]

        enc_seq_ids = self.tokenizer.encode_enc(fun)
        out_seq_ids = self.tokenizer.encode_dec(poly)

        enc_seq = torch.tensor(enc_seq_ids, dtype=torch.long)
        out_seq = torch.tensor(out_seq_ids, dtype=torch.long)

        return {
            'inputs': enc_seq,
            'outputs': out_seq
        }

In [121]:
def collate_fn(batch):
    """
    Custom collate_fn for LSTM. Pads only the input and output sequences.
    """
    enc_seqs = []
    out_seqs = []
    enc_lens = []

    for item in batch:
        enc_seqs.append(item['inputs'])
        out_seqs.append(item['outputs'])
        enc_lens.append(len(item['inputs']))

    enc_pad_value = 0 
    dec_pad_value = 0 

    enc_seqs_padded = pad_sequence(enc_seqs, batch_first=True, padding_value=enc_pad_value)
    out_seqs_padded = pad_sequence(out_seqs, batch_first=True, padding_value=dec_pad_value)

    return {
        'inputs': enc_seqs_padded,
        'input_lengths': torch.tensor(enc_lens),
        'outputs': out_seqs_padded
    }

In [122]:
class EncoderLSTM(nn.Module):
    def __init__(self, input_dim, embed_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, n_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_len):

        embedded = self.dropout(self.embedding(src))

        packed_embedded = pack_padded_sequence(embedded, src_len.to('cpu'), batch_first=True, enforce_sorted=False)

        packed_outputs, (hidden, cell) = self.rnn(packed_embedded)

        outputs, _ = pad_packed_sequence(packed_outputs, batch_first=True)

        return outputs, hidden, cell

In [123]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        self.attn = nn.Linear((enc_hid_dim) + dec_hid_dim, dec_hid_dim, bias=False) 
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)

    def forward(self, hidden, encoder_outputs, mask):

        batch_size = encoder_outputs.shape[0]
        src_len = encoder_outputs.shape[1]

        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)

        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))

        attention = self.v(energy).squeeze(2)

        attention = attention.masked_fill(mask == 0, -1e10) 

        return F.softmax(attention, dim=1)

In [124]:
class DecoderLSTM(nn.Module):
    def __init__(self, output_dim, embed_dim, enc_hid_dim, dec_hid_dim, n_layers, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embed_dim)

        self.rnn = nn.LSTM(enc_hid_dim + embed_dim, dec_hid_dim, n_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(enc_hid_dim + dec_hid_dim + embed_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell, encoder_outputs, mask):

        input = input.unsqueeze(1) 

        embedded = self.dropout(self.embedding(input))

        a = self.attention(hidden[-1], encoder_outputs, mask)

        a = a.unsqueeze(1)

        weighted = torch.bmm(a, encoder_outputs)

        rnn_input = torch.cat((embedded, weighted), dim=2)

        output, (hidden, cell) = self.rnn(rnn_input, (hidden, cell))

        embedded = embedded.squeeze(1)
        output = output.squeeze(1)
        weighted = weighted.squeeze(1)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))

        return prediction, hidden, cell

In [125]:
class LSTMSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        self.enc_vocab_size = encoder.embedding.num_embeddings
        self.dec_vocab_size = decoder.output_dim

    def create_mask(self, src):

         mask = (src != 0)
         return mask 

    def forward(self, src, src_len, trg, teacher_forcing_ratio=0.5):

        batch_size = trg.shape[0]
        trg_len = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim

        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        encoder_outputs, hidden, cell = self.encoder(src, src_len)

        input = trg[:, 0]

        mask = self.create_mask(src) 

        for t in range(1, trg_len): 

            output, hidden, cell = self.decoder(input, hidden, cell, encoder_outputs, mask)

            outputs[:, t] = output

            teacher_force = random.random() < teacher_forcing_ratio

            top1 = output.argmax(1)

            input = trg[:, t] if teacher_force else top1

        return outputs

In [126]:
def train_one_epoch(model, dataloader, criterion, optimizer, scheduler, clip=1.0): 
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc='Training', leave=False):
        src = batch['inputs'].to(device)
        trg = batch['outputs'].to(device)
        src_len = batch['input_lengths'].to(device) 

        optimizer.zero_grad()

        output = model(src, src_len, trg[:, :-1], teacher_forcing_ratio=0.5) 

        output_dim = output.shape[-1]
        output = output.reshape(-1, output_dim) 
        targets = trg[:, 1:].reshape(-1)        

        loss = criterion(output, targets)
        total_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip) 
        optimizer.step()
        if scheduler: 
             scheduler.step() 

    return total_loss / len(dataloader)

In [127]:
def validate(model, dataloader, criterion):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation', leave=False):
            src = batch['inputs'].to(device)
            trg = batch['outputs'].to(device)
            src_len = batch['input_lengths'].to(device) # Get lengths

             # Turn off teacher forcing for validation
            output = model(src, src_len, trg[:,:-1], teacher_forcing_ratio=0.0)

            output_dim = output.shape[-1]
            output = output.reshape(-1, output_dim)
            targets = trg[:, 1:].reshape(-1)

            loss = criterion(output, targets)
            total_loss += loss.item()

        return total_loss / len(dataloader)

In [128]:
def finite_state_prediction(logits, state, tokenizer): 

    logits = logits.squeeze().reshape(1, -1)

    state_dict = ['number', 'degree', 'exponent', 'sign']
    assert state in state_dict, "state not in predefined states"

    number_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val in '1234567890']
    degree_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val.startswith('x') and val[1:].isdigit() or val == 'EOS']
    sign_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val in '+-']
    exponent_states = [idx for idx, val in enumerate(tokenizer.dec_vocab) if val.startswith('E')]

    state_indices = {
        'number': number_states,
        'degree': degree_states,
        'exponent': exponent_states,
        'sign': sign_states
    }

    valid_indices = state_indices[state]

    if not valid_indices:
        return torch.tensor(tokenizer.dec_token_to_id['EOS']).to(logits.device)

    state_logits = logits[0, valid_indices]

    best_local_idx = torch.argmax(state_logits).item()

    max_idx = valid_indices[best_local_idx]

    return torch.tensor(max_idx).to(logits.device)

In [129]:
def evaluate_model(model, df, tokenizer):
    model.eval()
    preds = []
    targets = []

    max_len = 1 + 5 * (3 + tokenizer.precision) + 1 

    sos_idx = tokenizer.dec_token_to_id['SOS']
    eos_idx = tokenizer.dec_token_to_id['EOS']

    with torch.no_grad():
        for i, row in tqdm(df.iterrows(), desc="Generating Predictions: ", total=df.shape[0]):
            function = row['simplified_functions']
            polynomial = row['taylor']

            src_tokens = tokenizer.encode_enc(function)
            src_tensor = torch.tensor(src_tokens).unsqueeze(0).to(device)
            src_len = torch.tensor([len(src_tokens)]).to(device)

            target_tokens = tokenizer.encode_dec(polynomial)
            targets.append(tokenizer.seq_to_coeffs(tokenizer.decode_dec(target_tokens))) 

            trg_indexes = [sos_idx]
            cur_state = 'degree' 
            next_state_in = 1 

            encoder_outputs, hidden, cell = model.encoder(src_tensor, src_len)
            mask = model.create_mask(src_tensor) 

            for _ in range(max_len):
                trg_tensor = torch.tensor([trg_indexes[-1]]).to(device) 

                output, hidden, cell = model.decoder(trg_tensor, hidden, cell, encoder_outputs, mask)

                pred_token = finite_state_prediction(output, cur_state, tokenizer).item() 

                trg_indexes.append(pred_token)

                next_state_in -= 1
                if next_state_in <= 0:

                    predicted_token_str = tokenizer.dec_id_to_token.get(pred_token, '')
                    if predicted_token_str.startswith('x') or predicted_token_str == 'EOS':
                         cur_state = 'sign'
                         next_state_in = 1
                    elif predicted_token_str in ['+', '-']:
                         cur_state = 'exponent'
                         next_state_in = 1
                    elif predicted_token_str.startswith('E'):
                         cur_state = 'number'
                         next_state_in = tokenizer.precision + 1 
                    elif predicted_token_str.isdigit():

                         num_digits_so_far = sum(1 for tok_idx in trg_indexes \
                                                 if tokenizer.dec_id_to_token.get(tok_idx, '').isdigit())
                         exp_idx = -1
                         for k, tok_idx in enumerate(reversed(trg_indexes)):
                              if tokenizer.dec_id_to_token.get(tok_idx, '').startswith('E'):
                                  exp_idx = len(trg_indexes) - 1 - k
                                  break
                         digits_after_exp = sum(1 for k, tok_idx in enumerate(trg_indexes) if k > exp_idx and \
                                                  tokenizer.dec_id_to_token.get(tok_idx, '').isdigit())

                         if digits_after_exp >= tokenizer.precision + 1: 
                              cur_state = 'degree'
                              next_state_in = 1
                         else:
                              cur_state = 'number' 
                              next_state_in = 1 

                    else: 

                         cur_state = 'degree'
                         next_state_in = 1

                if pred_token == eos_idx:
                    break

            pred_decoded_tokens = tokenizer.decode_dec(trg_indexes)
            pred_coeffs = tokenizer.seq_to_coeffs(pred_decoded_tokens)
            preds.append(pred_coeffs)

    return preds, targets

In [130]:
def polynomial_rmse(preds, targets, n=100, x_range=(-1, 1)):
    rmse_sum_sq = 0.0
    count = 0

    for pred, target in zip(preds, targets):

        if not isinstance(pred, (list, np.ndarray)) or not isinstance(target, (list, np.ndarray)):
            continue
        if not all(isinstance(x, (int, float, np.number)) for x in pred) or \
           not all(isinstance(x, (int, float, np.number)) for x in target):
            continue

        x = sp.symbols('x')
        try:
            pred_poly = sum(float(coef) * x**i for i, coef in enumerate(pred))
            target_poly = sum(float(coef) * x**i for i, coef in enumerate(target))
        except (TypeError, ValueError) as e:
            continue

        x_vals = np.linspace(x_range[0], x_range[1], n)

        try:
            pred_func = sp.lambdify(x, pred_poly, 'numpy')
            target_func = sp.lambdify(x, target_poly, 'numpy')

            y_pred = pred_func(x_vals)
            y_true = target_func(x_vals)

            if np.any(np.isnan(y_pred)) or np.any(np.isinf(y_pred)) or \
               np.any(np.isnan(y_true)) or np.any(np.isinf(y_true)):
                 continue

            rmse_sum_sq += mean_squared_error(y_true, y_pred)
            count += 1
        except Exception as e:
            continue

    if count == 0:
        print("Warning: No valid pairs processed for polynomial RMSE.")
        return float('nan')

    mean_rmse_sq = rmse_sum_sq / count
    return np.sqrt(mean_rmse_sq)

In [131]:
def coeff_rmse(preds, targets):
    rmse_sum_sq = 0.0
    count = 0
    for pred, target in zip(preds, targets):

        if not isinstance(pred, (list, np.ndarray)) or not isinstance(target, (list, np.ndarray)) or len(pred) != len(target):
            continue
        if not all(isinstance(x, (int, float, np.number)) for x in pred) or \
           not all(isinstance(x, (int, float, np.number)) for x in target):
            continue

        try:

             pred_arr = np.array(pred, dtype=float)
             target_arr = np.array(target, dtype=float)

             if np.any(np.isnan(pred_arr)) or np.any(np.isinf(pred_arr)) or \
                np.any(np.isnan(target_arr)) or np.any(np.isinf(target_arr)):
                  continue

             rmse_sum_sq += mean_squared_error(target_arr, pred_arr)
             count += 1
        except Exception as e:
             continue

    if count == 0:
        print("Warning: No valid pairs processed for coefficient RMSE.")
        return float('nan')

    mean_rmse_sq = rmse_sum_sq / count
    return np.sqrt(mean_rmse_sq)

In [132]:
tokenizer = Tokenizer(precision=4) 
tokenizer.fit(df['simplified_functions'])

Fitting Tokenizer (Encoder):   0%|          | 0/2521 [00:00<?, ?it/s]

Tokenizer fitted.
Encoder vocab size: 37
Decoder vocab size: 31


In [133]:
train_df, test_df = train_test_split(df, train_size=0.9, random_state=seed)
train_df, val_df = train_test_split(train_df, train_size=0.8, random_state=seed)

train_data = TaylorDataset(train_df, tokenizer)
val_data = TaylorDataset(val_df, tokenizer)

train_load = DataLoader(train_data, shuffle=True, batch_size=64, collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_load = DataLoader(val_data, shuffle=False, batch_size=64, collate_fn=collate_fn, num_workers=2, pin_memory=True)

In [134]:
ENC_EMBED_DIM = 256
DEC_EMBED_DIM = 256
HID_DIM = 512 
N_LAYERS = 2
ENC_DROPOUT = 0.3
DEC_DROPOUT = 0.3
LEARNING_RATE = 1e-4 
WEIGHT_DECAY = 0.01
CLIP = 1.0 

attn = Attention(HID_DIM, HID_DIM)
enc = EncoderLSTM(tokenizer.enc_vocab_size, ENC_EMBED_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = DecoderLSTM(tokenizer.dec_vocab_size, DEC_EMBED_DIM, HID_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT, attn)

model = LSTMSeq2Seq(enc, dec, device).to(device)

In [135]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

The model has 8,986,911 trainable parameters


In [136]:
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.dec_token_to_id['PAD'],
                               weight=tokenizer.target_weights.to(device))
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50) 

In [137]:
num_epochs = 100
best_val_loss = float('inf')
save_path = 'best_lstm_model.pth' 
patience = 10
no_improve_epochs = 0

print("Starting training...")
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_load, criterion, optimizer, scheduler, CLIP) 
    val_loss = validate(model, val_load, criterion)

    if scheduler:
         scheduler.step()

    print(f"Epoch {epoch+1:02}/{num_epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - LR: {optimizer.param_groups[0]['lr']:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"*** New best model saved with validation loss: {best_val_loss:.4f} ***")
        no_improve_epochs = 0
    else:
        no_improve_epochs += 1
        print(f"Validation loss did not improve for {no_improve_epochs} epoch(s).")
        if no_improve_epochs >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs!")
            break
    print("-" * 50)

print("Training finished.")

Starting training...


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 01/100 - Train Loss: 3.1519 - Val Loss: 2.7497 - LR: 0.000035
*** New best model saved with validation loss: 2.7497 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 02/100 - Train Loss: 2.6835 - Val Loss: 2.6911 - LR: 0.000010
*** New best model saved with validation loss: 2.6911 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 03/100 - Train Loss: 2.5821 - Val Loss: 2.4619 - LR: 0.000090
*** New best model saved with validation loss: 2.4619 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 04/100 - Train Loss: 2.2988 - Val Loss: 2.3905 - LR: 0.000065
*** New best model saved with validation loss: 2.3905 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 05/100 - Train Loss: 2.0531 - Val Loss: 2.5977 - LR: 0.000000
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 06/100 - Train Loss: 1.9926 - Val Loss: 2.8431 - LR: 0.000065
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 07/100 - Train Loss: 1.7826 - Val Loss: 2.0543 - LR: 0.000090
*** New best model saved with validation loss: 2.0543 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 08/100 - Train Loss: 1.5635 - Val Loss: 1.9893 - LR: 0.000010
*** New best model saved with validation loss: 1.9893 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 09/100 - Train Loss: 1.5080 - Val Loss: 2.2965 - LR: 0.000035
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 10/100 - Train Loss: 1.5021 - Val Loss: 2.5327 - LR: 0.000100
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 11/100 - Train Loss: 1.4170 - Val Loss: 1.9989 - LR: 0.000035
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 12/100 - Train Loss: 1.3731 - Val Loss: 1.4919 - LR: 0.000010
*** New best model saved with validation loss: 1.4919 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 13/100 - Train Loss: 1.3650 - Val Loss: 1.4815 - LR: 0.000090
*** New best model saved with validation loss: 1.4815 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 14/100 - Train Loss: 1.3286 - Val Loss: 1.3641 - LR: 0.000065
*** New best model saved with validation loss: 1.3641 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 15/100 - Train Loss: 1.2974 - Val Loss: 1.3591 - LR: 0.000000
*** New best model saved with validation loss: 1.3591 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 16/100 - Train Loss: 1.2828 - Val Loss: 1.3545 - LR: 0.000065
*** New best model saved with validation loss: 1.3545 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 17/100 - Train Loss: 1.2681 - Val Loss: 1.3081 - LR: 0.000090
*** New best model saved with validation loss: 1.3081 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 18/100 - Train Loss: 1.2421 - Val Loss: 1.2954 - LR: 0.000010
*** New best model saved with validation loss: 1.2954 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 19/100 - Train Loss: 1.2253 - Val Loss: 1.2898 - LR: 0.000035
*** New best model saved with validation loss: 1.2898 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 20/100 - Train Loss: 1.2184 - Val Loss: 1.2862 - LR: 0.000100
*** New best model saved with validation loss: 1.2862 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 21/100 - Train Loss: 1.2159 - Val Loss: 1.2794 - LR: 0.000035
*** New best model saved with validation loss: 1.2794 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 22/100 - Train Loss: 1.2022 - Val Loss: 1.2594 - LR: 0.000010
*** New best model saved with validation loss: 1.2594 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 23/100 - Train Loss: 1.1911 - Val Loss: 1.2566 - LR: 0.000090
*** New best model saved with validation loss: 1.2566 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 24/100 - Train Loss: 1.1875 - Val Loss: 1.2535 - LR: 0.000065
*** New best model saved with validation loss: 1.2535 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 25/100 - Train Loss: 1.1703 - Val Loss: 1.2421 - LR: 0.000000
*** New best model saved with validation loss: 1.2421 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 26/100 - Train Loss: 1.1698 - Val Loss: 1.2340 - LR: 0.000065
*** New best model saved with validation loss: 1.2340 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 27/100 - Train Loss: 1.1525 - Val Loss: 1.2279 - LR: 0.000090
*** New best model saved with validation loss: 1.2279 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 28/100 - Train Loss: 1.1559 - Val Loss: 1.2237 - LR: 0.000010
*** New best model saved with validation loss: 1.2237 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 29/100 - Train Loss: 1.1533 - Val Loss: 1.2136 - LR: 0.000035
*** New best model saved with validation loss: 1.2136 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 30/100 - Train Loss: 1.1479 - Val Loss: 1.2356 - LR: 0.000100
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 31/100 - Train Loss: 1.1488 - Val Loss: 1.2183 - LR: 0.000035
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 32/100 - Train Loss: 1.1265 - Val Loss: 1.2024 - LR: 0.000010
*** New best model saved with validation loss: 1.2024 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 33/100 - Train Loss: 1.1281 - Val Loss: 1.2271 - LR: 0.000090
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 34/100 - Train Loss: 1.1295 - Val Loss: 1.2046 - LR: 0.000065
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 35/100 - Train Loss: 1.1170 - Val Loss: 1.1924 - LR: 0.000000
*** New best model saved with validation loss: 1.1924 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 36/100 - Train Loss: 1.1064 - Val Loss: 1.1997 - LR: 0.000065
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 37/100 - Train Loss: 1.1119 - Val Loss: 1.1929 - LR: 0.000090
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 38/100 - Train Loss: 1.1063 - Val Loss: 1.1797 - LR: 0.000010
*** New best model saved with validation loss: 1.1797 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 39/100 - Train Loss: 1.0913 - Val Loss: 1.1847 - LR: 0.000035
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 40/100 - Train Loss: 1.1001 - Val Loss: 1.2336 - LR: 0.000100
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 41/100 - Train Loss: 1.0995 - Val Loss: 1.1793 - LR: 0.000035
*** New best model saved with validation loss: 1.1793 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 42/100 - Train Loss: 1.0831 - Val Loss: 1.1784 - LR: 0.000010
*** New best model saved with validation loss: 1.1784 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 43/100 - Train Loss: 1.0823 - Val Loss: 1.1874 - LR: 0.000090
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 44/100 - Train Loss: 1.0869 - Val Loss: 1.1723 - LR: 0.000065
*** New best model saved with validation loss: 1.1723 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 45/100 - Train Loss: 1.0765 - Val Loss: 1.1621 - LR: 0.000000
*** New best model saved with validation loss: 1.1621 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 46/100 - Train Loss: 1.0661 - Val Loss: 1.1689 - LR: 0.000065
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 47/100 - Train Loss: 1.0667 - Val Loss: 1.1652 - LR: 0.000090
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 48/100 - Train Loss: 1.0642 - Val Loss: 1.1646 - LR: 0.000010
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 49/100 - Train Loss: 1.0507 - Val Loss: 1.1528 - LR: 0.000035
*** New best model saved with validation loss: 1.1528 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 50/100 - Train Loss: 1.0570 - Val Loss: 1.1875 - LR: 0.000100
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 51/100 - Train Loss: 1.0619 - Val Loss: 1.1575 - LR: 0.000035
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 52/100 - Train Loss: 1.0370 - Val Loss: 1.1576 - LR: 0.000010
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 53/100 - Train Loss: 1.0395 - Val Loss: 1.1674 - LR: 0.000090
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 54/100 - Train Loss: 1.0448 - Val Loss: 1.1784 - LR: 0.000065
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 55/100 - Train Loss: 1.0326 - Val Loss: 1.1615 - LR: 0.000000
Validation loss did not improve for 6 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 56/100 - Train Loss: 1.0237 - Val Loss: 1.1413 - LR: 0.000065
*** New best model saved with validation loss: 1.1413 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 57/100 - Train Loss: 1.0309 - Val Loss: 1.1477 - LR: 0.000090
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 58/100 - Train Loss: 1.0264 - Val Loss: 1.1443 - LR: 0.000010
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 59/100 - Train Loss: 1.0199 - Val Loss: 1.1554 - LR: 0.000035
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 60/100 - Train Loss: 1.0170 - Val Loss: 1.1657 - LR: 0.000100
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 61/100 - Train Loss: 1.0201 - Val Loss: 1.1809 - LR: 0.000035
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 62/100 - Train Loss: 0.9996 - Val Loss: 1.1453 - LR: 0.000010
Validation loss did not improve for 6 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 63/100 - Train Loss: 1.0067 - Val Loss: 1.1588 - LR: 0.000090
Validation loss did not improve for 7 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 64/100 - Train Loss: 1.0141 - Val Loss: 1.1787 - LR: 0.000065
Validation loss did not improve for 8 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 65/100 - Train Loss: 1.0020 - Val Loss: 1.1356 - LR: 0.000000
*** New best model saved with validation loss: 1.1356 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 66/100 - Train Loss: 0.9808 - Val Loss: 1.1362 - LR: 0.000065
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 67/100 - Train Loss: 0.9896 - Val Loss: 1.1352 - LR: 0.000090
*** New best model saved with validation loss: 1.1352 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 68/100 - Train Loss: 0.9923 - Val Loss: 1.1268 - LR: 0.000010
*** New best model saved with validation loss: 1.1268 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 69/100 - Train Loss: 0.9684 - Val Loss: 1.1358 - LR: 0.000035
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 70/100 - Train Loss: 0.9861 - Val Loss: 1.1278 - LR: 0.000100
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 71/100 - Train Loss: 0.9924 - Val Loss: 1.1491 - LR: 0.000035
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 72/100 - Train Loss: 0.9698 - Val Loss: 1.1270 - LR: 0.000010
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 73/100 - Train Loss: 0.9724 - Val Loss: 1.1751 - LR: 0.000090
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 74/100 - Train Loss: 0.9831 - Val Loss: 1.1208 - LR: 0.000065
*** New best model saved with validation loss: 1.1208 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 75/100 - Train Loss: 0.9651 - Val Loss: 1.1240 - LR: 0.000000
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 76/100 - Train Loss: 0.9549 - Val Loss: 1.1499 - LR: 0.000065
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 77/100 - Train Loss: 0.9604 - Val Loss: 1.1437 - LR: 0.000090
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 78/100 - Train Loss: 0.9596 - Val Loss: 1.1264 - LR: 0.000010
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 79/100 - Train Loss: 0.9447 - Val Loss: 1.1267 - LR: 0.000035
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 80/100 - Train Loss: 0.9424 - Val Loss: 1.1495 - LR: 0.000100
Validation loss did not improve for 6 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 81/100 - Train Loss: 0.9504 - Val Loss: 1.1139 - LR: 0.000035
*** New best model saved with validation loss: 1.1139 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 82/100 - Train Loss: 0.9361 - Val Loss: 1.1264 - LR: 0.000010
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 83/100 - Train Loss: 0.9330 - Val Loss: 1.1345 - LR: 0.000090
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 84/100 - Train Loss: 0.9389 - Val Loss: 1.1354 - LR: 0.000065
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 85/100 - Train Loss: 0.9299 - Val Loss: 1.1318 - LR: 0.000000
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 86/100 - Train Loss: 0.9225 - Val Loss: 1.1145 - LR: 0.000065
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 87/100 - Train Loss: 0.9373 - Val Loss: 1.1355 - LR: 0.000090
Validation loss did not improve for 6 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 88/100 - Train Loss: 0.9257 - Val Loss: 1.1095 - LR: 0.000010
*** New best model saved with validation loss: 1.1095 ***
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 89/100 - Train Loss: 0.9084 - Val Loss: 1.1356 - LR: 0.000035
Validation loss did not improve for 1 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 90/100 - Train Loss: 0.9201 - Val Loss: 1.1724 - LR: 0.000100
Validation loss did not improve for 2 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 91/100 - Train Loss: 0.9231 - Val Loss: 1.1702 - LR: 0.000035
Validation loss did not improve for 3 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 92/100 - Train Loss: 0.9035 - Val Loss: 1.1268 - LR: 0.000010
Validation loss did not improve for 4 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 93/100 - Train Loss: 0.9102 - Val Loss: 1.1283 - LR: 0.000090
Validation loss did not improve for 5 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 94/100 - Train Loss: 0.9128 - Val Loss: 1.1248 - LR: 0.000065
Validation loss did not improve for 6 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 95/100 - Train Loss: 0.9015 - Val Loss: 1.1190 - LR: 0.000000
Validation loss did not improve for 7 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 96/100 - Train Loss: 0.8980 - Val Loss: 1.1611 - LR: 0.000065
Validation loss did not improve for 8 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 97/100 - Train Loss: 0.9057 - Val Loss: 1.2190 - LR: 0.000090
Validation loss did not improve for 9 epoch(s).
--------------------------------------------------


Training:   0%|          | 0/29 [00:00<?, ?it/s]

Validation:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 98/100 - Train Loss: 0.9141 - Val Loss: 1.1227 - LR: 0.000010
Validation loss did not improve for 10 epoch(s).
Early stopping triggered after 98 epochs!
Training finished.


In [138]:
print(f"Loading best model from {save_path}")
try:
    model.load_state_dict(torch.load(save_path, map_location=device))
except FileNotFoundError:
    print(f"Error: Saved model file {save_path} not found. Evaluating with the last state.")
except Exception as e:
     print(f"Error loading model state_dict: {e}. Evaluating with the last state.")

print("Evaluating model on test set...")
preds, targets = evaluate_model(model, test_df, tokenizer)

print("Calculating final metrics...")
polynomial_rmse_value = polynomial_rmse(preds, targets)
coeff_rmse_value = coeff_rmse(preds, targets)

print("\n--- Evaluation Results ---")
print(f"Polynomial RMSE: {polynomial_rmse_value:.6f}")
print(f"Coefficient RMSE: {coeff_rmse_value:.6f}")
print("------------------------")

print("\nExample Predictions vs Targets (Coefficients):")
for i in range(min(5, len(preds))):
     print(f" Pred {i+1}: {[f'{x:.3f}' for x in preds[i]]}")
     print(f" Target {i+1}: {[f'{x:.3f}' for x in targets[i]]}")
     print("-" * 20)

Loading best model from best_lstm_model.pth
Evaluating model on test set...


  model.load_state_dict(torch.load(save_path, map_location=device))


Generating Predictions:   0%|          | 0/253 [00:00<?, ?it/s]

Calculating final metrics...

--- Evaluation Results ---
Polynomial RMSE: 323.214037
Coefficient RMSE: 315.131449
------------------------

Example Predictions vs Targets (Coefficients):
 Pred 1: ['0.000', '0.000', '-0.000', '-0.667', '0.000']
 Target 1: ['1.570', '-1.000', '-1.000', '-0.667', '-0.667']
--------------------
 Pred 2: ['0.000', '0.000', '0.000', '-0.000', '-0.667']
 Target 2: ['0.000', '2.000', '0.000', '-2.833', '0.000']
--------------------
 Pred 3: ['0.000', '0.000', '-0.000', '0.000', '-0.010']
 Target 3: ['1.000', '1.000', '0.500', '-0.167', '-0.292']
--------------------
 Pred 4: ['0.000', '0.000', '0.000', '0.000', '0.000']
 Target 4: ['0.000', '3.000', '0.000', '-1.333', '0.000']
--------------------
 Pred 5: ['0.000', '0.000', '0.000', '0.000', '0.000']
 Target 5: ['0.000', '2.000', '0.000', '1.500', '0.000']
--------------------
