In [1]:
import sympy as sp
from sympy import *
import pandas as pd
import re
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer
import math
import os
import random
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass, field, fields
from typing import Optional

In [2]:
def chunks(lst, n):
    """Lazily split ``lst`` into consecutive slices of length ``n``.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    """
    starts = range(0, len(lst), n)
    yield from (lst[begin : begin + n] for begin in starts)

In [3]:
class Tokenizer:
    """Bidirectional token <-> id lookup built from a vocabulary file.

    The file holds one token per line; a token's id is its zero-based line number.
    """

    def __init__(self, vocab_path):
        self.vocab_path = vocab_path

        with open(vocab_path) as handle:
            vocabulary = [line.rstrip('\n') for line in handle.readlines()]

        # Forward and reverse tables are derived from the same ordered list.
        self.word2id = {token: idx for idx, token in enumerate(vocabulary)}
        self.id2word = dict(enumerate(vocabulary))

    def encode(self, lst):
        """Convert a batch of token sequences into a ``np.ushort`` array of ids."""
        id_rows = []
        for sequence in lst:
            id_rows.append([self.word2id[token] for token in sequence])
        return np.array(id_rows, dtype=np.ushort)

    def decode(self, lst):
        """Convert a batch of id sequences back into lists of tokens."""
        token_rows = []
        for sequence in lst:
            token_rows.append([self.id2word[idx] for idx in sequence])
        return token_rows

In [4]:
class Encoder_tokeniser(Tokenizer):
    """Tokenizer that turns rows of floats into sign / mantissa / exponent tokens.

    Each finite value becomes one sign token (``'+'`` or ``'-'``), the mantissa
    digits cut into ``'N' + digits`` tokens of ``base`` digits each, and one
    ``'E' + exponent`` token. Infinite values (the padding written by
    ``pre_tokenize``) become ``'<pad>'`` tokens instead.

    Args:
        float_precision: number of digits after the decimal point in the
            scientific-notation rendering of each value.
        mantissa_len: number of mantissa tokens per value; ``float_precision + 1``
            should be divisible by it so that every value yields exactly
            ``2 + mantissa_len`` tokens.
        max_exponent: values whose shifted exponent is below ``-max_exponent``
            are encoded as zero.
        vocab_path: vocabulary file, forwarded to ``Tokenizer``.
        max_len: number of columns every row is padded to in ``pre_tokenize``.
    """

    def __init__(self,float_precision,mantissa_len,max_exponent,vocab_path,max_len = 10):
        super().__init__(vocab_path)

        self.max_len = max_len
        self.float_precision = float_precision
        self.mantissa_len = mantissa_len
        self.max_exponent = max_exponent
        # Digits carried by a single mantissa token.
        self.base = (self.float_precision + 1) // self.mantissa_len
        self.max_token = 10 ** self.base

    def pre_tokenize(self, data):
        """Parse whitespace-separated text lines into a padded float32 matrix.

        The last column of every row is moved to the front, then columns are
        padded on the right with ``-inf`` up to ``max_len``.
        """
        arr = np.array([i.split() for i in data], dtype=np.float32)
        # [-1, 0, 1, ..., n-2]: last column first, remaining columns in order.
        permutation = [-1] + [i for i in range(arr.shape[1]-1)]
        arr = np.pad(arr[:, permutation], ((0,0), (0, self.max_len - arr.shape[1])), mode="constant", constant_values=[-np.inf])
        return arr

    def tokenize(self, data):
        """Run ``pre_tokenize`` -> ``encode_float`` -> ``encode`` on raw text lines."""
        out = self.pre_tokenize(data)
        out = self.encode_float(out)
        out = self.encode(out)
        return out

    def encode_float(self,values):
        """Encode a 1-D array of floats as a flat token list.

        For input with more than one dimension the method is applied to each
        element along the first axis and a list of token lists is returned
        (an empty list for empty input).
        """
        if len(values.shape) != 1:
            return [self.encode_float(row) for row in values]

        # sign + mantissa tokens + exponent; padding uses the same width so
        # that every value occupies the same number of positions.
        tokens_per_value = 2 + self.mantissa_len
        seq = []
        for val in values:
            if np.isinf(val):
                seq.extend(['<pad>'] * tokens_per_value)
                continue

            sign = "+" if val >= 0 else "-"
            m, e = (f"%.{self.float_precision}e" % val).split("e")
            i, f = m.lstrip("-").split(".")
            i = i + f
            tokens = chunks(i, self.base)
            # The mantissa is stored as an integer, so shift the exponent accordingly.
            expon = int(e) - self.float_precision
            if expon < -self.max_exponent:
                # Too small to represent: encode as zero.
                tokens = ["0" * self.base] * self.mantissa_len
                expon = int(0)
            seq.extend([sign, *["N" + token for token in tokens], "E" + str(expon)])
        return seq

    def decode_float(self,seq):
        """Decode a flat token list produced by ``encode_float`` back to floats.

        Tokens are consumed in groups of ``2 + mantissa_len``. Returns a list
        with one float per group; a group that cannot be parsed decodes to
        ``nan``. If any token starts with a character other than ``-``, ``+``,
        ``E`` or ``N`` (this includes ``'<pad>'``), ``np.nan`` is returned
        instead of a list.
        """
        decoded_seq = []
        # Iterate over the input tokens (not the output list being built).
        for val in chunks(seq, 2 + self.mantissa_len):
            for x in val:
                if x[0] not in ["-", "+", "E", "N"]:
                    return np.nan
            try:
                sign = 1 if val[0] == "+" else -1
                mant = ""
                for x in val[1:-1]:
                    mant += x[1:]
                mant = int(mant)
                exp = int(val[-1][1:])
                value = sign * mant * (10 ** exp)
                value = float(value)
            except Exception:
                value = np.nan
            decoded_seq.append(value)
        return decoded_seq

In [5]:
# Load the Feynman equations metadata table from a hard-coded Kaggle input path.
df_target = pd.read_csv('/kaggle/input/gsoc-symba-task/FeynmanEquations.csv')

In [6]:
# Display the raw table; the last rows shown in the output contain only NaN.
df_target

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,2.0,z1,3.0,4.0,z2,1.0,2.0,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
125,,,,,,,,,,,...,,,,,,,,,,
126,,,,,,,,,,,...,,,,,,,,,,
127,,,,,,,,,,,...,,,,,,,,,,
128,,,,,,,,,,,...,,,,,,,,,,


In [7]:
# Keep only rows that have a Filename (the all-NaN rows seen above are dropped).
# dropna keeps the original index labels; the next cell addresses rows by those labels.
df_target = df_target.dropna(subset=['Filename'])
df_target

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v7_high,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,2.0,z1,3.0,4.0,z2,1.0,2.0,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,III.15.14,96.0,m,(h/(2*pi))**2/(2*E_n*d**2),3.0,h,1.0,5.0,E_n,1.0,...,,,,,,,,,,
96,III.15.27,97.0,k,2*pi*alpha/(n*d),3.0,alpha,1.0,5.0,n,1.0,...,,,,,,,,,,
97,III.17.37,98.0,f,beta*(1+alpha*cos(theta)),3.0,beta,1.0,5.0,alpha,1.0,...,,,,,,,,,,
98,III.19.51,99.0,E_n,-m*q**4/(2*(4*pi*epsilon)**2*(h/(2*pi))**2)*(1...,4.0,m,1.0,5.0,q,1.0,...,,,,,,,,,,


In [8]:
# Manual patches to individual rows, addressed by index label (not by position).
# NOTE(review): the reason for each correction is not visible in this notebook —
# presumably the CSV has a wrong variable count for row 82 and file names that do not
# match the data files for rows 18, 49 and 61; confirm against the dataset.
df_target.loc[82, '# variables'] = 3
df_target.loc[18,'Filename'] = 'I.15.10'
df_target.loc[49,'Filename'] = 'I.48.20'
df_target.loc[61,'Filename'] = 'II.11.7'

In [10]:
# Set of variable names taken from the v<k>_name columns; filled by the loop in the
# next cell and read by sympy_expression to build its symbol table.
used_variables = set()

In [11]:
# For every equation, record the names of its first '# variables' variables.
for i in range(len(df_target)):
    row = df_target.iloc[i]
    n_named = int(row['# variables'])
    for j in range(n_named):
        used_variables.add(row[f'v{j+1}_name'])

In [12]:
# Placeholder variable tokens ('v1', 'v2', ...) filled in by the next cell; they are the
# tokens that rightmost_operand_pos and prefix_to_sympy recognise as variables.
variables = []

In [13]:
# Register the placeholder tokens v1 .. v9.
for i in range(1, 10):
    variables += [f'v{i}']

In [14]:
# Maps a sympy function class (``expression.func``) to the token name that
# sympy_to_prefix_rec emits for it in prefix notation.
operators = {
    # Elementary functions
    sp.Add: 'add',
    sp.Mul: 'mul',
    sp.Pow: 'pow',
    sp.exp: 'exp',
    sp.log: 'ln',
    sp.Abs: 'abs',
    sp.sign: 'sign',
    # Trigonometric Functions
    sp.sin: 'sin',
    sp.cos: 'cos',
    sp.tan: 'tan',
    sp.cot: 'cot',
    sp.sec: 'sec',
    sp.csc: 'csc',
    # Trigonometric Inverses
    sp.asin: 'asin',
    sp.acos: 'acos',
    sp.atan: 'atan',
    sp.acot: 'acot',
    sp.asec: 'asec',
    sp.acsc: 'acsc',
    # Hyperbolic Functions
    sp.sinh: 'sinh',
    sp.cosh: 'cosh',
    sp.tanh: 'tanh',
    sp.coth: 'coth',
    sp.sech: 'sech',
    sp.csch: 'csch',
    # Hyperbolic Inverses
    sp.asinh: 'asinh',
    sp.acosh: 'acosh',
    sp.atanh: 'atanh',
    sp.acoth: 'acoth',
    sp.asech: 'asech',
    sp.acsch: 'acsch',
    # Derivative
    sp.Derivative: 'derivative',
}

# Reverse lookup (token name -> sympy class) used when rebuilding expressions in
# prefix_to_sympy. 'mul(' and 'add(' are additional aliases for Mul and Add.
operators_inv = {operators[key]: key for key in operators}
operators_inv["mul("] = sp.Mul
operators_inv["add("] = sp.Add

# Number of arguments each prefix token takes. Looked up by sympy_to_prefix_rec,
# repeat_operator_until_correct_binary and prefix_to_sympy.
# Several tokens listed here ('sub', 'div', 'rac', 'inv', 'pow2'..'pow5', 'sqrt',
# 'f', 'g', 'h') have no entry in `operators`, so sympy_to_prefix_rec never emits them.
# NOTE(review): -1 for 'mul(' / 'add(' presumably marks a variable argument count — confirm.
operators_nargs = {
    # Elementary functions
    'mul(': -1,
    'add(': -1,
    'add': 2,
    'sub': 2,
    'mul': 2,
    'div': 2,
    'pow': 2,
    'rac': 2,
    'inv': 1,
    'pow2': 1,
    'pow3': 1,
    'pow4': 1,
    'pow5': 1,
    'sqrt': 1,
    'exp': 1,
    'ln': 1,
    'abs': 1,
    'sign': 1,
    # Trigonometric Functions
    'sin': 1,
    'cos': 1,
    'tan': 1,
    'cot': 1,
    'sec': 1,
    'csc': 1,
    # Trigonometric Inverses
    'asin': 1,
    'acos': 1,
    'atan': 1,
    'acot': 1,
    'asec': 1,
    'acsc': 1,
    # Hyperbolic Functions
    'sinh': 1,
    'cosh': 1,
    'tanh': 1,
    'coth': 1,
    'sech': 1,
    'csch': 1,
    # Hyperbolic Inverses
    'asinh': 1,
    'acosh': 1,
    'atanh': 1,
    'acoth': 1,
    'asech': 1,
    'acsch': 1,
    # Derivative
    'derivative': 2,
    # custom functions
    'f': 1,
    'g': 2,
    'h': 3,
}

# Symbols for a fixed list of mass names. Neither `masses_strings` nor `masses` is
# referenced by any other cell visible in this notebook.
masses_strings = [
        "m_e",
        "m_u",
        "m_d",
        "m_s",
        "m_c",
        "m_b",
        "m_t",
        ]

masses = [sp.Symbol(x) for x in masses_strings]

# these will be converted to the numbers format in `format_number`
integers_types = [
        sp.core.numbers.Integer,
        sp.core.numbers.One,
        sp.core.numbers.NegativeOne,
        sp.core.numbers.Zero,
        ]

# All numeric types that sympy_to_prefix_rec hands to `format_number`.
# NOTE(review): the list mixes class objects with the string
# "<class 'sympy.core.numbers.Pi'>"; a `type(...) in numbers_types` test can never match
# that string entry — presumably a leftover, confirm before removing.
numbers_types = integers_types + [sp.core.numbers.Rational,
        sp.core.numbers.Half, sp.core.numbers.Exp1, sp.core.numbers.Pi, "<class 'sympy.core.numbers.Pi'>",
        sp.core.numbers.ImaginaryUnit]

# Leaf types: sympy_to_prefix_rec stops recursing when `expression.func` is one of these.
atoms = [
        str,
        sp.core.symbol.Symbol,
        sp.core.numbers.Exp1,
        sp.core.numbers.Pi,
        "<class 'sympy.core.numbers.Pi'>",
        ] + numbers_types


# Spelling map from 'arc*' function names to the 'a*' names used as keys in `operators`.
# sympy_expression applies each entry to the formula text with re.sub before parsing.
Inverse_trig = {
    'arcsin': 'asin',
    'arccos': 'acos',
    'arctan': 'atan',
    'arccot': 'acot',
    'arcsec': 'asec',
    'arccsc': 'acsc',
    'arcsinh': 'asinh',
    'arccosh': 'acosh',
    'arctanh': 'atanh',
    'arccoth': 'acoth',
    'arcsech': 'asech',
    'arccsch': 'acsch',         
}

In [15]:
def sympy_expression(formula):
    """Parse a formula string into a sympy expression.

    Every name in the global ``used_variables`` is bound to a ``Symbol`` of the
    same name, and ``arc*`` function spellings are rewritten via ``Inverse_trig``
    before the text is handed to ``sp.sympify``.
    """
    symbol_table = {name: sp.Symbol(name) for name in used_variables}

    # Normalise inverse-function spellings (e.g. 'arcsin' -> 'asin').
    for long_name, short_name in Inverse_trig.items():
        formula = re.sub(long_name, short_name, formula)

    return sp.sympify(formula, locals=symbol_table)

In [16]:
def flatten(l, ltypes=(list, tuple)):
    """Flatten arbitrarily nested sequences into a single level.

    Any element that is an instance of ``ltypes`` is replaced by its own
    elements (empty ones simply disappear). The result has the same type as
    the top-level input.
    """
    container = type(l)
    flat = []
    # Stack of iterators: the innermost sequence being walked is on top.
    pending = [iter(list(l))]
    while pending:
        for item in pending[-1]:
            if isinstance(item, ltypes):
                # Descend; the parent iterator resumes once this one is exhausted.
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            pending.pop()
    return container(flat)

In [17]:
def sympy_to_prefix(expression):
    """Convert a sympy expression to a flat list of prefix-notation tokens.

    The nested token structure built by ``sympy_to_prefix_rec`` is flattened
    before being returned.
    """
    nested_tokens = sympy_to_prefix_rec(expression, [])
    return flatten(nested_tokens)

def sympy_to_prefix_rec(expression, ret):
    """Recursively convert a sympy expression into nested prefix-notation tokens.

    Operators are replaced by their names from ``operators``. The result is a
    nested list whose nesting plays the role of parentheses; because every
    token has a fixed argument count (``operators_nargs``), the caller can
    flatten it without losing information.

    Args:
        expression: the sympy expression (or sub-expression) to convert.
        ret: tokens accumulated so far; it is not mutated.

    Returns:
        A new list consisting of ``ret`` followed by the tokens of ``expression``.

    Raises:
        KeyError: if ``expression.func`` is neither an atom nor a key of ``operators``.
    """
    if expression in [sp.core.numbers.Pi, sp.core.numbers.ImaginaryUnit]:
        f = expression
    else:
        f = expression.func
    if f in atoms:
        # Leaf: numbers get the digit-wise encoding, everything else its string form.
        if type(expression) in numbers_types:
            return ret + format_number(expression)
        return ret + [str(expression)]
    f_str = operators[f]
    f_nargs = operators_nargs[f_str]
    args = expression.args
    # Logical `and`: with bitwise `&` this parsed as the chained comparison
    # `len(args) == (1 & f_nargs) == 1`, which only matched the intent by
    # coincidence for the argument counts currently in `operators_nargs`.
    if len(args) == 1 and f_nargs == 1:
        return sympy_to_prefix_rec(args[0], ret + [f_str])
    if len(args) == 2:
        return ret + [f_str, sympy_to_prefix_rec(args[0], []), sympy_to_prefix_rec(args[1], [])]
    if len(args) > 2:
        # sympy's Add/Mul are n-ary; rewrite them as a chain of binary operators.
        nested_args = [sympy_to_prefix_rec(arg, []) for arg in args]
        return ret + repeat_operator_until_correct_binary(f_str, nested_args, [])
    # No branch matched (e.g. zero arguments): nothing is appended.
    return ret
def repeat_operator_until_correct_binary(op, args, ret=[]):
    """Rewrite an n-ary application of a binary operator as nested binary ones.

    sympy lets operators such as multiplication take any number of arguments,
    whereas prefix notation needs exactly two per operator token. The nesting
    is right-leaning:

        1 + 2 + 3 --> + 1 + 2 3

    which is the convention of https://arxiv.org/pdf/1912.01412.pdf (page 15).

    Args:
        op: operator token as found in ``operators``; must have two arguments
            according to ``operators_nargs``.
        args: list of already-converted arguments, e.g. [1, 2, x**2, ...].
        ret: tokens built so far, usually ``[]``. It is never mutated.

    Returns:
        The nested token list.
    """
    assert operators_nargs[op] == 2, "repeat_operator_until_correct_binary only takes binary operators"

    remaining = args
    nested = ret
    while len(remaining) != 0:
        if len(nested) == 0:
            # First step: the two right-most arguments form the innermost pair.
            nested = [op] + remaining[-2:]
            remaining = remaining[:-2]
        else:
            # Each later step wraps what exists so far with one more argument on the left.
            nested = [op] + remaining[-1:] + nested
            remaining = remaining[:-1]
    return nested

def format_number(number):
    """Dispatch a sympy number to the formatter matching its exact type.

    Raises:
        NotImplementedError: for any type without a dedicated formatter.
    """
    kind = type(number)
    if kind in integers_types:
        return format_integer(number)
    if kind == sp.core.numbers.Rational:
        return format_rational(number)
    if kind == sp.core.numbers.Half:
        return format_half()
    if kind == sp.core.numbers.Exp1:
        return format_exp1()
    if kind == sp.core.numbers.Pi:
        return format_pi()
    if kind == sp.core.numbers.ImaginaryUnit:
        return format_imaginary_unit()
    raise NotImplementedError

def format_exp1():
    """Return the single-token encoding used for sympy's ``Exp1``."""
    euler_token = 'E'
    return [euler_token]

def format_pi():
    """Return the single-token encoding used for sympy's ``Pi``."""
    pi_token = 'pi'
    return [pi_token]

def format_imaginary_unit():
    """Return the single-token encoding used for sympy's ``ImaginaryUnit``."""
    imaginary_token = 'I'
    return [imaginary_token]

def format_half():
    """Encode 1/2 in the same shape ``format_rational`` produces.

    sympy represents one half with its own ``Half`` type rather than a plain
    ``Rational``, so it needs a dedicated formatter: ``1 * 2 ** -1``.
    """
    one = ['s+', '1']
    two = ['s+', '2']
    minus_one = ['s-', '1']
    return ['mul', *one, 'pow', *two, *minus_one]

def format_rational(number):
    """Encode a rational p/q as the prefix tokens of ``p * q ** -1``."""
    # p and q are passed through sympify before being formatted as integers.
    numerator = sp.sympify(number.p)
    denominator = sp.sympify(number.q)
    numerator_tokens = format_integer(numerator)
    denominator_tokens = format_integer(denominator)
    return ['mul'] + numerator_tokens + ['pow'] + denominator_tokens + ['s-', '1']

def format_integer(integer):
    """Encode a sympy integer as a sign token followed by its decimal digits.

    Follows the number format of https://arxiv.org/pdf/1912.01412.pdf.

    Args:
        integer: a ``sympy.Integer``, e.g. ``sympy.Integer(-1)``. Its value is
            read from ``integer.p`` and its sign from
            ``integer.could_extract_minus_sign()``.

    Returns:
        ``[sign_token, digit0, digit1, ...]`` with ``sign_token`` being
        ``'s+'`` or ``'s-'``.

    Example:
        format_integer(sympy.Integer(-123))
        >> ['s-', '1', '2', '3']
    """
    magnitude = abs(integer.p)
    sign_token = "s-" if integer.could_extract_minus_sign() else "s+"
    return [sign_token] + [digit for digit in str(magnitude)]

In [18]:
def parse_if_str(x):
    """Parse ``x`` with sympy when it is a string; return anything else unchanged."""
    if not isinstance(x, str):
        return x
    return sp.parsing.parse_expr(x)

In [19]:
def rightmost_string_pos(expr_arr, pos=-1):
    """Return the non-negative index of the right-most string in ``expr_arr``.

    The search starts at ``pos`` (a negative index) and walks towards the
    front; indexing past the front raises ``IndexError``.
    """
    while not isinstance(expr_arr[pos], str):
        pos -= 1
    return len(expr_arr) + pos


def rightmost_operand_pos(expr, pos=-1):
    """Return the non-negative index of the right-most recognised token in ``expr``.

    Recognised tokens are the keys of ``operators_inv``, the sign tokens
    ``'s+'`` / ``'s-'`` and the entries of ``variables``. The search starts at
    ``pos`` (a negative index) and walks towards the front; indexing past the
    front raises ``IndexError``.
    """
    recognised = list(operators_inv.keys()) + ["s+", "s-"] + variables
    while expr[pos] not in recognised:
        pos -= 1
    return len(expr) + pos

def unformat_integer(arr):
    """Inverse of ``format_integer``.

    Args:
        arr: list of strings as produced by ``format_integer``,
            e.g. ``["s+", "4", "2"]``.

    Returns:
        The result of parsing the reassembled number text with sympy; for the
        example above that is the text ``"42"``.

    A leading ``"s-"`` yields a minus sign; any other sign token yields none.
    """
    prefix = "-" if arr[0] == "s-" else ""
    digits = "".join(str(piece) for piece in arr[1:])
    return sp.parsing.parse_expr(prefix + digits)

In [20]:
def prefix_to_sympy(expr_arr):
    """Rebuild a sympy expression from a list of prefix-notation tokens.

    The right-most recognised token is reduced first: an operator consumes the
    entries that follow it, a sign token absorbs the digit strings after it,
    and a variable token is converted with ``sp.sympify``. The shortened list
    is then processed again until a single entry remains.
    """
    if len(expr_arr) == 1:
        return parse_if_str(expr_arr[0])

    op_pos = rightmost_operand_pos(expr_arr)
    if op_pos == -1 or op_pos == len(expr_arr):
        print("something went wrong, operator should not be at end of array")

    token = expr_arr[op_pos]
    head = expr_arr[0:op_pos]

    if token in operators_inv.keys():
        arity = operators_nargs[token]
        builder = operators_inv[token]
        args_end = op_pos + arity + 1
        operands = [parse_if_str(item) for item in expr_arr[op_pos+1:args_end]]
        node = builder(*operands)
        return prefix_to_sympy(head + [node] + expr_arr[args_end:])

    if token == 's+' or token == "s-":
        # The integer spans from the sign token up to the right-most string entry.
        last_digit_pos = rightmost_string_pos(expr_arr)
        integer = unformat_integer(expr_arr[op_pos:last_digit_pos+1])
        return prefix_to_sympy(head + [integer] + expr_arr[last_digit_pos+1:])

    if token in variables:
        symbol = sp.sympify(token)
        return prefix_to_sympy(head + [symbol] + expr_arr[op_pos+1:])

    return token

In [21]:
# Convert every formula to its prefix-token list and store it as a new column.
# Iterating over the column values (instead of looking rows up by the labels
# 0..len-1) keeps this correct even when the index has gaps after dropna.
prefix_lists = [
    sympy_to_prefix(sympy_expression(formula))
    for formula in df_target['Formula']
]
# A Series with an explicit index keeps each list as one cell, whatever the list lengths.
df_target = df_target.assign(
    Prefix_lists=pd.Series(prefix_lists, index=df_target.index)
)

In [22]:
def _rename_to_placeholders(row):
    """Return the row's prefix list with each variable name replaced by its v<k> token."""
    placeholders = {}
    for j in range(int(row['# variables'])):
        # setdefault: should a name be listed twice, the first slot wins.
        placeholders.setdefault(str(row[f'v{j+1}_name']), f'v{j+1}')
    # Every token is translated once from its original text, so a freshly written
    # 'v1' can never be mistaken for a variable that is literally named 'v1'.
    return [placeholders.get(str(token), token) for token in row['Prefix_lists']]


# Replace the equation-specific variable names by positional placeholders v1, v2, ...
df_target['Prefix_lists'] = pd.Series(
    [_rename_to_placeholders(row) for _, row in df_target.iterrows()],
    index=df_target.index,
)

In [23]:
# Write the table (including Prefix_lists) to 'Modified_csv' in the working directory.
# The lists are stored as their text representation; prepare_dataset parses them back.
df_target.to_csv('Modified_csv')

In [24]:
class DecoderTokenizer(Tokenizer):
    """Tokenizer for equations: sympy expression <-> prefix tokens <-> ids."""

    def __init__(self, vocab_path):
        super().__init__(vocab_path)

    def equation_encoder(self, data):
        """Turn each sympy expression into its flat prefix-token list."""
        return list(map(sympy_to_prefix, data))

    def equation_decoder(self, data):
        """Turn each prefix-token list back into a sympy expression."""
        return list(map(prefix_to_sympy, data))

    def pre_tokenize(self, data):
        """No preprocessing is needed for equations; returns ``data`` as is."""
        return data

    def tokenize(self, data):
        """Encode expressions as id sequences wrapped in ``<bos>`` / ``<eos>``."""
        prefixes = self.equation_encoder(self.pre_tokenize(data))
        framed = []
        for tokens in prefixes:
            framed.append(['<bos>'] + tokens + ['<eos>'])
        return self.encode(framed)

    def reverse_tokenize(self, data):
        """Decode id sequences to tokens and rebuild the sympy expressions."""
        token_rows = self.decode(data)
        return self.equation_decoder(token_rows)

In [25]:
# Directory with one data file per equation; prepare_dataset opens
# INPUT_DIR + row['Filename'] for every row of the metadata table.
INPUT_DIR = '/kaggle/input/gsoc-symba-task/Feynman_with_units/Feynman_with_units/'

In [26]:
# exist_ok=True makes the cell safe to re-run when the directory already exists.
os.makedirs('dataset_arrays', exist_ok=True)

In [27]:
# Padding value that collate_fn passes to pad_sequence for source and target batches.
PAD_IDX = 0

def prepare_dataset(config):
    """Tokenise every equation's data file and write the training arrays to disk.

    For each row of the CSV at ``config.df_path`` the data file
    ``INPUT_DIR + row['Filename']`` is tokenised, cut into chunks of
    ``config.input_max_len`` rows and saved as
    ``<config.output_dir>/<Filename>/<chunk>.npy``. The tokenised prefix
    equations are saved as one zero-padded matrix in
    ``<config.output_dir>/prefix_equations.npy``, row ``Number - 1`` holding
    equation ``Number``.

    Args:
        config: object providing ``input_max_len``, ``df_path``,
            ``encoder_vocab``, ``decoder_vocab`` and ``output_dir``.

    Returns:
        A DataFrame with one row per saved chunk and the columns
        ``Filename``, ``data_num`` (chunk index) and ``number`` (equation number).

    Raises:
        ValueError: if a tokenised equation is longer than the prefix matrix is wide.
    """
    # Function-scope import: the only use is the literal parsing below.
    import ast

    max_equations, max_prefix_len = 100, 256

    input_max_len = config.input_max_len
    df = pd.read_csv(config.df_path)

    encoder_tokenizer = Encoder_tokeniser(2,1,100,config.encoder_vocab)
    decoder_tokenizer = DecoderTokenizer(config.decoder_vocab)

    # Column names match what FeynmanDataset.__getitem__ reads.
    records = {
        "Filename":[],
        "data_num":[],
        "number":[]
        }

    for (_, row) in tqdm(df.iterrows(), total=len(df)):
        with open(INPUT_DIR + row['Filename']) as file:
            data = file.readlines()
        X = encoder_tokenizer.tokenize(data)

        n_splits = X.shape[0] // input_max_len
        if n_splits == 0:
            # Fewer rows than one chunk: nothing to save (np.split cannot make 0 sections).
            continue
        # Drop the incomplete tail so that every chunk has exactly input_max_len rows.
        x_chunks = np.split(X[:n_splits*input_max_len], n_splits)

        sub_dir = os.path.join(config.output_dir, row["Filename"])
        os.makedirs(sub_dir, exist_ok=True)

        for (chunk_idx, x) in enumerate(x_chunks):
            np.save(os.path.join(sub_dir, f"{chunk_idx}.npy"), x)

        records["Filename"].extend([row["Filename"]]*n_splits)
        records["data_num"].extend(range(n_splits))
        # Stored as int because it is later used as a list index.
        records["number"].extend([int(row["Number"])]*n_splits)

    train_df = pd.DataFrame(records)

    prefix_equations = np.zeros((max_equations, max_prefix_len)).astype(np.int32)
    for (_, row) in df.iterrows():
        # The column holds the text form of a list of strings; literal_eval parses it
        # without executing arbitrary code from the CSV.
        prefix = ast.literal_eval(row["Prefix_lists"])
        prefix = ["<bos>"] + prefix + ["<eos>"]
        y = decoder_tokenizer.encode([prefix])[0]
        if len(y) > max_prefix_len:
            raise ValueError(
                f"prefix of equation {row['Number']} has {len(y)} tokens, "
                f"more than the supported {max_prefix_len}"
            )
        # The remainder of the row stays zero, i.e. padded.
        prefix_equations[int(row["Number"])-1, :len(y)] = y

    path = os.path.join(config.output_dir, "prefix_equations.npy")
    np.save(path, prefix_equations)

    return train_df

class FeynmanDataset(Dataset):
    """Dataset of (tokenised data chunk, tokenised prefix equation) pairs.

    Args:
        df: frame with one row per chunk and the columns ``Filename``,
            ``data_num`` and ``number`` (1-based equation number).
        dataset_dir: directory containing ``prefix_equations.npy`` and one
            sub-directory per ``Filename`` holding ``<data_num>.npy`` files.
    """

    def __init__(self, df, dataset_dir):
        super().__init__()
        self.df = df
        self.dataset_dir = dataset_dir

        padded_equations = np.load(os.path.join(dataset_dir, "prefix_equations.npy"))
        # Strip the zero padding so that each target keeps only its own tokens.
        self.prefix_equations = [np.trim_zeros(prefix) for prefix in padded_equations]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y)`` as int64 tensors for the chunk at position ``idx``."""
        row = self.df.iloc[idx]
        # int(...) guards against numeric columns having been upcast to float.
        path = os.path.join(self.dataset_dir, row['Filename'], f"{int(row['data_num'])}.npy")
        x = np.load(path).astype(np.int64)

        # Equation numbers are 1-based; the cast is required because a float
        # cannot be used as a list index.
        y = self.prefix_equations[int(row['number']) - 1].astype(np.int64)

        return (torch.from_numpy(x), torch.from_numpy(y))

In [28]:

def get_datasets(df, input_df, dataset_dir, random_state=None):
    """Split the chunk index into train / validation / test datasets.

    Equations are first split 90/10 into train and test, so every test chunk
    belongs to an equation that is absent from training. The chunks of the
    training equations are then split 90/10 into train and validation.

    Args:
        df: frame with one row per equation and a ``Filename`` column.
        input_df: frame with one row per chunk. The file-name column may be
            called ``Filename`` or ``filename``.
        dataset_dir: directory passed on to ``FeynmanDataset``.
        random_state: optional seed forwarded to both splits for reproducibility.

    Returns:
        Dict with the keys ``"train"``, ``"test"`` and ``"valid"``.
    """
    # FeynmanDataset reads row['Filename'], so normalise a lower-case column name.
    if 'Filename' not in input_df.columns and 'filename' in input_df.columns:
        input_df = input_df.rename(columns={'filename': 'Filename'})

    train_df, test_df = train_test_split(df, test_size=0.1, random_state=random_state)
    train_equations = train_df['Filename'].tolist()
    test_equations = test_df['Filename'].tolist()

    input_test_df = input_df[input_df['Filename'].isin(test_equations)]
    input_train_df = input_df[input_df['Filename'].isin(train_equations)]

    input_train_df, input_val_df = train_test_split(
        input_train_df, test_size = 0.1, shuffle=True, random_state=random_state
    )

    train_dataset = FeynmanDataset(input_train_df, dataset_dir)
    val_dataset = FeynmanDataset(input_val_df, dataset_dir)
    test_dataset = FeynmanDataset(input_test_df, dataset_dir)

    datasets = {
        "train":train_dataset,
        "test":test_dataset,
        "valid":val_dataset
        }

    return datasets

def get_dataloaders(datasets, train_bs, test_bs):
    """Wrap the train / valid / test datasets in shuffling DataLoaders.

    The training loader uses ``train_bs`` as batch size, the validation and
    test loaders use ``test_bs``. All loaders share the same worker, pinning
    and collate settings.
    """
    def build(split, batch_size):
        return DataLoader(datasets[split], batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True, collate_fn=collate_fn)

    loader_train = build('train', train_bs)
    loader_valid = build('valid', test_bs)
    loader_test = build('test', test_bs)

    return {
        "train": loader_train,
        "test": loader_test,
        "valid": loader_valid,
    }

def collate_fn(batch):
    """Pad the sources and the targets of a batch to a common length each.

    ``batch`` is a list of ``(source, target)`` pairs; both returned tensors
    are batch-first and padded with ``PAD_IDX``.
    """
    sources = [source for (source, _target) in batch]
    targets = [target for (_source, target) in batch]

    padded_sources = pad_sequence(sources, padding_value=PAD_IDX, batch_first=True)
    padded_targets = pad_sequence(targets, padding_value=PAD_IDX, batch_first=True)
    return padded_sources, padded_targets

In [29]:
class config:
    """Paths and sizes consumed by ``prepare_dataset``."""

    def __init__(self):
        settings = {
            # number of tokenised rows per saved chunk
            "input_max_len": 1000,
            "df_path": '/kaggle/working/Modified_csv',
            "encoder_vocab": '/kaggle/input/gsoc-symba-task/encoder_vocab (1).txt',
            "decoder_vocab": '/kaggle/input/gsoc-symba-task/decoder_vocab (2).txt',
            "output_dir": '/kaggle/working/dataset_arrays',
        }
        for attribute, value in settings.items():
            setattr(self, attribute, value)

In [30]:
# Settings instance handed to prepare_dataset below.
config1 = config()

In [31]:
# Display the table again; it now carries the Prefix_lists column.
df_target

Unnamed: 0,Filename,Number,Output,Formula,# variables,v1_name,v1_low,v1_high,v2_name,v2_low,...,v8_name,v8_low,v8_high,v9_name,v9_low,v9_high,v10_name,v10_low,v10_high,Prefix_lists
0,I.6.2a,1.0,f,exp(-theta**2/2)/sqrt(2*pi),1.0,theta,1.0,3.0,,,...,,,,,,,,,,"[mul, mul, s+, 1, pow, s+, 2, s-, 1, mul, pow,..."
1,I.6.2,2.0,f,exp(-(theta/sigma)**2/2)/(sqrt(2*pi)*sigma),2.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,"[mul, mul, s+, 1, pow, s+, 2, s-, 1, mul, pow,..."
2,I.6.2b,3.0,f,exp(-((theta-theta1)/sigma)**2/2)/(sqrt(2*pi)*...,3.0,sigma,1.0,3.0,theta,1.0,...,,,,,,,,,,"[mul, mul, s+, 1, pow, s+, 2, s-, 1, mul, pow,..."
3,I.8.14,4.0,d,sqrt((x2-x1)**2+(y2-y1)**2),4.0,x1,1.0,5.0,x2,1.0,...,,,,,,,,,,"[pow, add, pow, add, v2, mul, s-, 1, v1, s+, 2..."
4,I.9.18,5.0,F,G*m1*m2/((x2-x1)**2+(y2-y1)**2+(z2-z1)**2),9.0,m1,1.0,2.0,m2,1.0,...,z1,3.0,4.0,z2,1.0,2.0,,,,"[mul, v3, mul, v1, mul, v2, pow, add, pow, add..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,III.15.14,96.0,m,(h/(2*pi))**2/(2*E_n*d**2),3.0,h,1.0,5.0,E_n,1.0,...,,,,,,,,,,"[mul, mul, s+, 1, pow, s+, 8, s-, 1, mul, pow,..."
96,III.15.27,97.0,k,2*pi*alpha/(n*d),3.0,alpha,1.0,5.0,n,1.0,...,,,,,,,,,,"[mul, s+, 2, mul, pi, mul, v1, mul, pow, v3, s..."
97,III.17.37,98.0,f,beta*(1+alpha*cos(theta)),3.0,beta,1.0,5.0,alpha,1.0,...,,,,,,,,,,"[mul, v1, add, s+, 1, mul, v2, cos, v3]"
98,III.19.51,99.0,E_n,-m*q**4/(2*(4*pi*epsilon)**2*(h/(2*pi))**2)*(1...,4.0,m,1.0,5.0,q,1.0,...,,,,,,,,,,"[mul, mul, s-, 1, pow, s+, 8, s-, 1, mul, v1, ..."


In [32]:
# Tokenise all data files and write the .npy chunks.
# Long-running: the recorded run below took about 3 h 39 min for 100 files.
train_df = prepare_dataset(config1)

100it [3:38:59, 131.39s/it]


In [33]:
# Persist the per-chunk index returned by prepare_dataset (without the pandas index column).
train_df.to_csv('train_df.csv',index=False)