# generate_data

In [None]:
import random, csv, os, time, re
import sympy as sp

def delete_old_files (file_names):
    """Remove the CSV of every named split from the Drive output folder.

    Args:
        file_names: Iterable of file stems (without the ``.csv`` suffix).
            Stems whose CSV does not exist are skipped silently.
    """
    for stem in file_names:
        csv_path = f"/content/drive/MyDrive/u_sub/{stem}.csv"
        if not os.path.exists(csv_path):
            continue
        os.remove(csv_path)

def format_expression (expression):
    """Render an expression in the textual format used by the dataset.

    The value is converted with ``str`` and then rewritten so that:
    * SymPy's inverse-trig names use the ``arc...`` spelling,
    * ``sqrt(...)`` of a parenthesis-free argument becomes ``(...)**(1/2)``,
    * simple integer powers (``f(...)**n``, ``3**n``, ``x**n``) are wrapped in
      parentheses.

    Args:
        expression: Anything whose ``str()`` is the expression text.

    Returns:
        The reformatted expression string (powers written with ``**``).
    """
    # Work with "^" so the power operator is a single character for the regexes.
    text = str(expression).replace("**", "^")

    # Order matters: it is applied exactly as listed.
    for sympy_name, dataset_name in (
        ("asin", "arcsin"),
        ("acos", "arccos"),
        ("atan", "arctan"),
        ("acsc", "arccosec"),
        ("asec", "arcsec"),
        ("acot", "arccot"),
    ):
        text = text.replace(sympy_name, dataset_name)

    for pattern, replacement in (
        (r'sqrt\(([^()]+)\)', r'(\1)^(1/2)'),
        (r'([a-zA-Z]+\([^()]*\)\^[0-9]+)', r'(\1)'),
        (r'([\d]+\^[0-9]+)', r'(\1)'),
        (r'([a-zA-Z]+\^[0-9]+)', r'(\1)'),
    ):
        text = re.sub(pattern, replacement, text)

    return text.replace("^", "**")

def linear_composite_functions (count, expressions, subs):
    """Append examples of the form ``f(a*x + b)`` labelled with ``a*x + b``.

    Reads the module-level ``file_names`` list: ``count`` is multiplied by its
    length to get the number of examples generated.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the formatted substitutions (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    total = count * len (file_names)
    outer_functions = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                       "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                       "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                       "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                       "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}
    outer_names = list(outer_functions)

    x = sp.symbols('x')

    def nonzero_int():
        # Redraw until the value is non-zero.
        while True:
            value = random.randint(-100, 100)
            if value != 0:
                return value

    for _ in range(total):
        slope = nonzero_int()
        intercept = nonzero_int()

        # Half of the time drop the intercept, but never produce u = x.
        drop_intercept = random.choice([True, False])
        if drop_intercept and slope != 1:
            intercept = 0

        inner = slope * x + intercept

        name = random.choice(outer_names)
        build = outer_functions[name]

        if name == "rational":
            integrand = build(inner, -1, evaluate=False)
        elif name == "power":
            sign = random.choice([1, -1])
            integrand = build(inner, sign * random.randint(2, 3), evaluate=False)
        elif name == "root":
            sign = random.choice([1, -1])
            integrand = build(inner, sign * sp.Pow(random.randint(2, 3), -1), evaluate=False)
        else:
            integrand = build(inner, evaluate=False)

        expressions.append(format_expression (integrand))
        subs.append(format_expression (inner))

    return expressions, subs

def polynomial_composite_functions (count, expressions, subs):
    """Append examples ``k * u' * f(u)`` where ``u`` is a polynomial or its reciprocal.

    ``u`` is a random polynomial of degree 2-4 without constant term; for every
    outer function except "rational" it is replaced by ``1/u`` half of the
    time. The derivative factor is optionally scaled by 2-5 and/or negated.
    Reads the module-level ``file_names`` list to scale ``count``.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the formatted substitutions (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    total = count * len (file_names)
    outer_functions = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                       "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                       "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                       "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                       "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}
    outer_names = list(outer_functions)

    x = sp.symbols('x')

    for _ in range(total):
        degree = random.randint(2, 4)

        inner = sp.S.Zero
        for exponent in range(1, degree + 1):
            coefficient = 0
            while coefficient == 0:
                coefficient = random.randint(-100, 100)
            # Lower-order terms may vanish; the leading term never does.
            if exponent != degree and random.choice([True, False]):
                coefficient = 0
            inner += coefficient * x**exponent

        name = random.choice(outer_names)

        use_reciprocal = random.choice([True, False])
        if use_reciprocal and name != "rational":
            inner = sp.Pow(inner, -1)

        factor = sp.diff(inner, x).doit()
        if random.choice([True, False]):
            factor *= random.randint (2, 5)
        if random.choice([True, False]):
            factor *= -1

        build = outer_functions[name]
        if name == "rational":
            outer = build(inner, -1, evaluate=False)
        elif name in ("power", "root"):
            # Keep u opaque so the power is not expanded or merged.
            inner = sp.UnevaluatedExpr(inner)
            sign = random.choice([1, -1])
            magnitude = random.randint(2, 4)
            if name == "power":
                outer = build(inner, sign * magnitude, evaluate=False)
            else:
                outer = build(inner, sign * sp.Pow(magnitude, -1), evaluate=False)
        else:
            outer = build(inner, evaluate=False)

        integrand = sp.Mul(factor, outer, evaluate=False)

        expressions.append(format_expression (integrand))
        subs.append(format_expression (inner))

    return expressions, subs

def nested_composite_functions (count, expressions, subs):
    """Append examples ``k * u' * f(u)`` where ``u`` is a sum of nested terms.

    Each term of ``u`` is ``c * g(p(x))**n`` with ``g`` a trig / inverse-trig /
    exp / log function and ``p`` a polynomial of degree 1-2 without constant
    term. Reads the module-level ``file_names`` list to scale ``count``.

    Fix over the previous version: the per-term loop used to draw only the
    random parameters and a single term was added after the loop, so ``u``
    never had more than one term. Every iteration now contributes its own term.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the formatted substitutions (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    count *= len (file_names)
    outer_function_choices = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                              "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                              "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                              "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                              "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}

    inner_function_choices = {"sin": sp.sin, "cos": sp.cos, "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
    "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan, "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot,
    "exp": sp.exp, "ln": sp.log}

    outer_keys = list(outer_function_choices)
    inner_keys = list(inner_function_choices)

    x = sp.symbols('x')

    for _ in range(count):
        terms = 1
        if random.choice([True, False]):
            terms = random.randint(1, 4)

        # Rebuild u in the (rare) case that its terms cancel and u' == 0,
        # which would yield a degenerate "0 * f(0)" example.
        derivative = 0
        while derivative == 0:
            nest_expr = sp.S.Zero
            for _term in range (terms):
                coefficient = 1
                if random.choice([True, False]):
                    coefficient = random.randint(-100, 100)
                    while coefficient == 0:
                        coefficient = random.randint(-100, 100)
                inner_func_key = random.choice(inner_keys)
                power = 1
                if random.choice([True, False]):
                    power = random.randint(2, 4)

                power_inner = random.randint(1, 2)

                polynomial_expr = sp.S.Zero
                for i in range (power_inner):
                    coefficient_inner = random.randint(-10, 10)
                    while coefficient_inner == 0:
                        coefficient_inner = random.randint(-10, 10)
                    # Only the leading coefficient is guaranteed non-zero.
                    if i != power_inner-1 and random.choice([True, False]):
                        coefficient_inner = 0
                    polynomial_expr += coefficient_inner * x**(i+1)

                nest_expr += coefficient * inner_function_choices[inner_func_key](polynomial_expr)**power

            derivative = sp.diff(nest_expr, x).doit()

        if random.choice([True, False]):
            derivative *= random.randint (2, 4)
        if random.choice([True, False]):
            derivative *= -1

        outer_func_key = random.choice(outer_keys)
        outer_function = outer_function_choices[outer_func_key]

        if outer_func_key == "rational":
            composite_function = sp.Mul(derivative, outer_function(nest_expr, -1, evaluate=False), evaluate=False)
        elif outer_func_key == "power":
            nest_expr = sp.UnevaluatedExpr(nest_expr)
            composite_function = sp.Mul(derivative, outer_function(nest_expr, random.choice([1, -1])*random.randint(2, 4), evaluate=False), evaluate=False)
        elif outer_func_key == "root":
            nest_expr = sp.UnevaluatedExpr(nest_expr)
            composite_function = sp.Mul(derivative, outer_function(nest_expr, random.choice([1, -1])*sp.Pow(random.randint(2, 4), -1), evaluate=False), evaluate=False)
        else:
            composite_function = sp.Mul(derivative, outer_function(nest_expr, evaluate=False), evaluate=False)

        expressions.append(format_expression (composite_function))
        subs.append(format_expression (nest_expr))

    return expressions, subs

def no_sub_polynomial_functions (count, expressions, subs):
    """Append polynomial-style integrands whose derivative factor is spoiled.

    The integrand is built like ``k * u' * f(u)`` and then the ``u'`` factor
    is altered (multiplied by ``x``, one coefficient changed, one term
    dropped, or replaced by 1). Every example is labelled ``0``.
    Reads the module-level ``file_names`` list to scale ``count``.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the ``"0"`` labels (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    total = count * len (file_names)
    outer_functions = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                       "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                       "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                       "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                       "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}
    outer_names = list(outer_functions)

    x = sp.symbols('x')

    for _ in range(total):
        degree = random.randint(2, 4)

        inner = sp.S.Zero
        for exponent in range(1, degree + 1):
            coefficient = 0
            while coefficient == 0:
                coefficient = random.randint(-100, 100)
            # Lower-order terms may vanish; the leading term never does.
            if exponent != degree and random.choice([True, False]):
                coefficient = 0
            inner += coefficient * x**exponent

        name = random.choice(outer_names)

        use_reciprocal = random.choice([True, False])
        if use_reciprocal and name != "rational":
            inner = sp.Pow(inner, -1)

        factor = sp.diff(inner, x).doit()
        if random.choice([True, False]):
            factor *= random.randint (2, 5)
        if random.choice([True, False]):
            factor *= -1

        mode = random.choice([1, 2, 3, 4])
        pieces = sp.Add.make_args(factor)
        evaluate_product = False

        if mode == 1 or len(pieces) == 1:
            # Extra factor of x (also the fallback for single-term factors).
            factor *= x
        elif mode == 2:
            # Perturb the numeric coefficient of one term.
            victim = random.choice(pieces)
            coeff, rest = victim.as_coeff_Mul()
            bumped = coeff + random.choice([-1, 1]) * random.randint(1, 10)
            factor += sp.Mul(bumped, rest) - victim
        elif mode == 3:
            # Drop one term entirely.
            factor -= random.choice(pieces)
        else:
            # No derivative factor at all; let SymPy evaluate the product.
            factor = 1
            evaluate_product = True

        build = outer_functions[name]
        if name == "rational":
            outer = build(inner, -1, evaluate=False)
        elif name in ("power", "root"):
            inner = sp.UnevaluatedExpr(inner)
            sign = random.choice([1, -1])
            magnitude = random.randint(2, 4)
            if name == "power":
                outer = build(inner, sign * magnitude, evaluate=False)
            else:
                outer = build(inner, sign * sp.Pow(magnitude, -1), evaluate=False)
        else:
            outer = build(inner, evaluate=False)

        integrand = sp.Mul(factor, outer, evaluate=evaluate_product)

        expressions.append(format_expression (integrand))
        subs.append(format_expression (inner*0))

    return expressions, subs

def no_sub_nested_functions (count, expressions, subs):
    """Append nested-function integrands whose derivative factor is spoiled.

    ``u`` is a sum of terms ``c * g(p(x))**n`` (see the loop below). The
    factor ``u'`` is then altered (multiplied by ``x`` or by ``g(x)``, one
    coefficient changed, one term dropped, or replaced by 1) and the example
    is labelled ``0``. Reads the module-level ``file_names`` list to scale
    ``count``.

    Fix over the previous version: the per-term loop used to draw only the
    random parameters and a single term was added after the loop, so ``u``
    (and therefore ``u'``) practically never had more than one term and the
    "change one term" / "drop one term" branches were almost unreachable.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the ``"0"`` labels (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    count *= len (file_names)
    outer_function_choices = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                              "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                              "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                              "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                              "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}
    inner_function_choices = {"sin": sp.sin, "cos": sp.cos, "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
    "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan, "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot,
    "exp": sp.exp, "ln": sp.log}

    outer_keys = list(outer_function_choices)
    inner_keys = list(inner_function_choices)

    x = sp.symbols('x')

    for _ in range(count):
        term_count = 1
        if random.choice([True, False]):
            term_count = random.randint(1, 4)

        # Rebuild u in the (rare) case that its terms cancel and u' == 0.
        derivative = 0
        while derivative == 0:
            nest_expr = sp.S.Zero
            for _term in range (term_count):
                coefficient = 1
                if random.choice([True, False]):
                    coefficient = random.randint(-100, 100)
                    while coefficient == 0:
                        coefficient = random.randint(-100, 100)
                inner_func_key = random.choice(inner_keys)
                power = 1
                if random.choice([True, False]):
                    power = random.randint(2, 4)

                power_inner = random.randint(1, 2)

                polynomial_expr = sp.S.Zero
                for i in range (power_inner):
                    coefficient_inner = random.randint(-10, 10)
                    while coefficient_inner == 0:
                        coefficient_inner = random.randint(-10, 10)
                    # Only the leading coefficient is guaranteed non-zero.
                    if i != power_inner-1 and random.choice([True, False]):
                        coefficient_inner = 0
                    polynomial_expr += coefficient_inner * x**(i+1)

                nest_expr += coefficient * inner_function_choices[inner_func_key](polynomial_expr)**power

            derivative = sp.diff(nest_expr, x).doit()

        if random.choice([True, False]):
            derivative *= random.randint (2, 4)
        if random.choice([True, False]):
            derivative *= -1

        choice = random.choice([1, 2, 3, 4, 5])
        derivative_terms = sp.Add.make_args(derivative)
        flag = False

        if choice == 1:
            derivative *= x
        elif choice == 2 and len(derivative_terms) != 1:
            # Perturb the numeric coefficient of one term.
            old_term = random.choice(derivative_terms)
            coeff, rest = old_term.as_coeff_Mul()
            new_coeff = coeff + random.choice([-1, 1]) * random.randint(1, 10)
            new_term = sp.Mul(new_coeff, rest)
            derivative += new_term - old_term
        elif choice == 3 and len(derivative_terms) != 1:
            # Drop one term entirely.
            term = random.choice(derivative_terms)
            derivative -= term
        elif choice == 4:
            new_key = random.choice(inner_keys)
            derivative *= inner_function_choices[new_key](x)
        else:
            # Also reached when choice is 2 or 3 but there is a single term.
            derivative = 1
            flag = True

        outer_func_key = random.choice(outer_keys)
        outer_function = outer_function_choices[outer_func_key]

        if outer_func_key == "rational":
            composite_function = sp.Mul(derivative, outer_function(nest_expr, -1, evaluate=False), evaluate=flag)
        elif outer_func_key == "power":
            nest_expr = sp.UnevaluatedExpr(nest_expr)
            composite_function = sp.Mul(derivative, outer_function(nest_expr, random.choice([1, -1])*random.randint(2, 4), evaluate=False), evaluate=flag)
        elif outer_func_key == "root":
            nest_expr = sp.UnevaluatedExpr(nest_expr)
            composite_function = sp.Mul(derivative, outer_function(nest_expr, random.choice([1, -1])*sp.Pow(random.randint(2, 4), -1), evaluate=False), evaluate=flag)
        else:
            composite_function = sp.Mul(derivative, outer_function(nest_expr, evaluate=False), evaluate=flag)

        expressions.append(format_expression (composite_function))
        subs.append(format_expression (nest_expr*0))

    return expressions, subs

def no_sub_linear_functions (count, expressions, subs):
    """Append examples ``x**k * f(a*x + b)`` labelled ``0``.

    Built like the linear composite examples, then multiplied by ``x**k``
    with ``k`` in 1-4. Reads the module-level ``file_names`` list to scale
    ``count``.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the ``"0"`` labels (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    total = count * len (file_names)
    outer_functions = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp, "ln": sp.log,
                       "rational": sp.Pow, "power": sp.Pow, "root": sp.Pow,
                       "tan": sp.tan, "cosec": sp.csc, "sec": sp.sec, "cot": sp.cot,
                       "arcsin": sp.asin, "arccos": sp.acos, "arctan": sp.atan,
                       "arccosec": sp.acsc, "arcsec": sp.asec, "arccot": sp.acot}
    outer_names = list(outer_functions)
    x = sp.symbols('x')

    def nonzero_int():
        # Redraw until the value is non-zero.
        while True:
            value = random.randint(-100, 100)
            if value != 0:
                return value

    for _ in range(total):
        slope = nonzero_int()
        intercept = nonzero_int()

        # Half of the time drop the intercept, except when the slope is 1.
        drop_intercept = random.choice([True, False])
        if drop_intercept and slope != 1:
            intercept = 0

        inner = slope * x + intercept
        extra_factor = x**random.randint(1, 4)

        name = random.choice(outer_names)
        build = outer_functions[name]

        if name == "rational":
            outer = build(inner, -1, evaluate=False)
        elif name == "power":
            sign = random.choice([1, -1])
            outer = build(inner, sign * random.randint(2, 3), evaluate=False)
        elif name == "root":
            sign = random.choice([1, -1])
            outer = build(inner, sign * sp.Pow(random.randint(2, 3), -1), evaluate=False)
        else:
            outer = build(inner, evaluate=False)

        integrand = sp.Mul(outer, extra_factor)

        expressions.append(format_expression (integrand))
        subs.append(format_expression (inner*0))

    return expressions, subs


def special_cases (count, expressions, subs):
    """Append sin/cos/exp examples of a quadratic, half with and half without a sub.

    ``u = a*x + b*x**2`` (``a`` may be 0, ``b`` never is). With equal
    probability the integrand is:
      * ``expand(k*u'*x) * f(u)``  -> label ``0``
      * ``k*u' * f(u)``            -> label ``u`` (two of the four choices)
      * ``f(u)``                   -> label ``0``
    Reads the module-level ``file_names`` list to scale ``count``.

    Cleanup over the previous version: an unused ``terms`` local, a
    redundant alias of ``u`` and a no-op branch were removed; the generated
    output is unchanged.

    Args:
        count: Examples to generate per entry of ``file_names``.
        expressions: List that receives the formatted integrands (mutated).
        subs: List that receives the formatted labels (mutated).

    Returns:
        The same ``(expressions, subs)`` lists.
    """
    count *= len (file_names)
    function_choices = {"sin": sp.sin, "cos": sp.cos, "exp": sp.exp}
    function_keys = list(function_choices)
    x = sp.symbols('x')

    for _ in range(count):
        power = 2
        polynomial_expr = 0 * x
        for i in range (power):
            coefficient = random.randint(1, 5)*random.choice([-1, 1])
            # Only the linear term may be dropped.
            if i != power-1 and random.choice([True, False]):
                coefficient = 0
            polynomial_expr += coefficient * x**(i+1)

        func_key = random.choice(function_keys)

        derivative = sp.diff(polynomial_expr, x).doit()
        if random.choice([True, False]):
            derivative *= random.randint (2, 5)
        if random.choice([True, False]):
            derivative *= -1

        choice = random.choice([1, 2, 3, 4])
        label = polynomial_expr
        flag = False

        if choice == 1:
            derivative = sp.expand(derivative*x)
            label = 0
        elif choice == 4:
            derivative = 1
            label = 0
            flag = True
        # choice 2 or 3: keep the valid u-substitution example untouched.

        composite_function = sp.Mul(derivative, function_choices[func_key](polynomial_expr, evaluate=False), evaluate=flag)

        expressions.append(format_expression (composite_function))
        subs.append(format_expression (label))

    return expressions, subs

def shuffle_lists (expressions, subs):
    """Shuffle two parallel lists with the same permutation.

    Args:
        expressions: Sequence of integrand strings.
        subs: Sequence of labels, aligned with ``expressions``.

    Returns:
        Two new lists ``(expressions, subs)`` in shuffled order; the pairing
        between the two is preserved. The inputs are not modified.
    """
    pairs = list(zip(expressions, subs))
    random.shuffle(pairs)

    shuffled_expressions, shuffled_subs = [], []
    for expression, sub in pairs:
        shuffled_expressions.append(expression)
        shuffled_subs.append(sub)
    return shuffled_expressions, shuffled_subs

def write_to_files (file_names, expressions, subs, output_dir="/content/drive/MyDrive/u_sub"):
    """Append one slice of the examples to each split's CSV file.

    The examples are cut into ``len(file_names)`` consecutive chunks of equal
    size (any remainder is not written). A split named ``"validation"`` is
    capped at 5% of a chunk. A header row is written when the file is created.

    Fix over the previous version: the slice end used to be advanced with
    ``end + end``, which doubles the index instead of moving one chunk
    forward. That only worked for three files.

    Args:
        file_names: Split names (file stems without ``.csv``), in slice order.
        expressions: Integrand strings.
        subs: Labels aligned with ``expressions``.
        output_dir: Directory holding the CSV files.
    """
    if not file_names:
        return

    chunk_size = len (expressions) // len (file_names)
    example_limit = round (0.05*chunk_size)

    start = 0
    for file_name in file_names:
        end = start + chunk_size
        if file_name == "validation" and end - start > example_limit:
            end = start + example_limit

        path = f"{output_dir}/{file_name}.csv"
        file_exists = os.path.isfile(path)
        with open(path, mode="a", newline="") as file:
            writer = csv.writer(file)

            if not file_exists:
                writer.writerow(["expression", "sub"])

            writer.writerows(zip(expressions[start:end], subs[start:end]))

        # The next split starts where this one stopped.
        start = end

def batch_data_generation (file_names, data_generation_functions, batch_count, batch_size):
    """Generate the dataset in batches and report timing and row counts.

    Existing CSVs for ``file_names`` are deleted first. Each batch calls every
    generator with ``int(batch_size * weight)``, shuffles the result and
    appends it to the CSV files.

    Fix over the previous version: minutes and seconds were rounded
    independently, which could print e.g. "1 minutes, 60 seconds". The elapsed
    time is now rounded once and split with ``divmod``.

    Args:
        file_names: Split names (file stems without ``.csv``).
        data_generation_functions: Mapping ``generator -> {"weight": float}``;
            the weights must sum to 1 (checked to 15 decimal places).
        batch_count: Number of batches to generate.
        batch_size: Nominal number of examples per batch before weighting.

    Returns:
        None. Returns early, after printing an error, when the weights do not
        sum to 1.
    """
    start = time.time()

    total_weight = sum(function_info["weight"] for function_info in data_generation_functions.values())
    if round(total_weight, 15) != 1:
        print ("ERROR: Function weights must sum to 1")
        return

    print("Data Generation Started\n---")

    delete_old_files(file_names)

    for batch_index in range (batch_count):
        expressions = []
        subs = []

        for function, function_info in data_generation_functions.items():
            expressions, subs = function (int(batch_size*function_info["weight"]), expressions, subs)

        expressions, subs = shuffle_lists (expressions, subs)
        write_to_files(file_names, expressions, subs)
        print(f"Batch {batch_index+1} completed")

    # Round once, then split, so the seconds field is always in 0..59.
    minutes, seconds = divmod(round(time.time() - start), 60)
    print (f"---\nData Generation Completed\nTime taken: {minutes} minutes, {seconds} seconds\n---")

    for file_name in file_names:
        with open(f"/content/drive/MyDrive/u_sub/{file_name}.csv", "r") as file:
            reader = csv.reader(file)
            # Do not count the header row.
            row_count = sum(1 for _ in reader) - 1

        print(f"{file_name}.csv contains {row_count} examples")
    print("---")

def input_weights (data_generation_functions):
    """Interactively assign a sampling weight to every generator function.

    The user either accepts the default weights or enters them one by one; the
    last generator automatically receives whatever is left to reach 1. The
    dialogue repeats until the user confirms the summary.

    Fixes over the previous version:
      * weight sums are rounded to 15 decimals before comparison, so inputs
        such as 0.1, 0.2, 0.7 (float sum 1.0000000000000002) are no longer
        rejected as "greater than 1";
      * a re-entered weight is validated against both the [0, 1] range and the
        running total (the second prompt used to skip the range check);
      * non-numeric input re-prompts instead of raising ``ValueError``.

    Args:
        data_generation_functions: Mapping ``generator -> info dict``; each
            info dict gets its ``"weight"`` key set in place.

    Returns:
        The same mapping, with weights filled in.
    """
    default_weights = {linear_composite_functions: 0.26,
                      polynomial_composite_functions: 0.27,
                      nested_composite_functions: 0.27,
                      no_sub_polynomial_functions: 0.05,
                      no_sub_nested_functions: 0.05,
                      no_sub_linear_functions: 0.05,
                      special_cases: 0.05}

    def raw_total():
        return sum(function_info["weight"] for function_info in data_generation_functions.values())

    def prompt_weight(position, function):
        # Keep asking until the text parses as a float.
        while True:
            try:
                return float(input (f"{position}. Enter the weight for {function.__name__}: "))
            except ValueError:
                print ("ERROR: Weight must be a number")

    last_position = len(data_generation_functions)

    while True:
        default = input("---\nDo you want to use the default function weights?\nYes or No: ").lower() == "yes"

        if not default:
            print ("---")
            print(f"There are {len(default_weights)} functions in total.")

        # Reset first so the running total only reflects this round's input.
        for function_info in data_generation_functions.values():
            function_info["weight"] = 0

        for position, (function, function_info) in enumerate(data_generation_functions.items(), start=1):
            if default:
                function_info["weight"] = default_weights[function]
                continue

            if position == last_position:
                function_info["weight"] = round(1 - raw_total(), 15)
                print (f"Setting final weight to {function_info['weight']}")
                continue

            while True:
                function_info["weight"] = prompt_weight(position, function)
                if not 0 <= function_info["weight"] <= 1:
                    print ("ERROR: Weight must be between 0 and 1")
                    continue
                if round(raw_total(), 15) > 1:
                    print ("ERROR: Weights must sum to 1")
                    continue
                break

            if round(raw_total(), 15) == 1:
                print ("Weights sum to 1, setting subsequent weights to 0")
                break

        print ("---\nData generation will use these weights:")
        for function, function_info in data_generation_functions.items():
            print (f"{function.__name__} weight: {function_info['weight']}")
        print ("---")
        proceed = input("Do you wish to proceed?\nYes or No: ").lower() == "yes"
        if proceed:
            break

    print ("---")
    return data_generation_functions

if __name__ == "__main__":
    # NOTE: the generator functions above read this module-level `file_names`
    # to decide how many examples to produce, so it must exist before they run.
    file_names = ["train", "evaluate", "validation"]

    generators = (linear_composite_functions,
                  polynomial_composite_functions,
                  nested_composite_functions,
                  no_sub_polynomial_functions,
                  no_sub_nested_functions,
                  no_sub_linear_functions,
                  special_cases)

    # Each generator gets its own (initially empty) info dict.
    data_generation_functions = input_weights ({generator: {} for generator in generators})

    batch_count = int(input("Enter the number of batches: "))
    batch_size = int(input("Enter the batch size: "))
    batch_data_generation (file_names, data_generation_functions, batch_count, batch_size)

# train_model

In [None]:
# Notebook shell commands: replace whatever transformers version is installed
# with the pinned 4.45.2 release, then install the datasets and evaluate packages.
!pip uninstall transformers -y
!pip install transformers==4.45.2
!pip install datasets evaluate

In [None]:
import os, shutil
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_dataset
from evaluate import load

# Remove a previously exported fine-tuned model so the save at the end of this
# cell writes into a fresh directory.
if os.path.exists("/content/drive/MyDrive/u_sub/t5-u-sub"):
    shutil.rmtree("/content/drive/MyDrive/u_sub/t5-u-sub")

# Pretrained checkpoint that is fine-tuned below; tokenizer and model are
# loaded from the same name.
checkpoint = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Collator used by the trainer below (padding enabled).
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

# Metric object used by compute_metrics.
metric = load("exact_match")

def tokenizer_function (example):
    """Tokenize a batch of rows for seq2seq training.

    Uses the module-level ``tokenizer``.

    Args:
        example: Mapping with ``"expression"`` (model input text) and
            ``"sub"`` (target text) entries.

    Returns:
        The tokenizer output for the inputs, with the target token ids stored
        under ``"labels"``. Both sides are truncated to 512 tokens.
    """
    encoded = tokenizer(example["expression"], truncation=True, max_length=512)
    target_ids = tokenizer(text_target=example["sub"], truncation=True, max_length=512)["input_ids"]
    encoded["labels"] = target_ids
    return encoded

def compute_metrics(eval_pred):
    """Compute the exact-match score between generated and reference texts.

    Uses the module-level ``tokenizer`` and ``metric``.

    Fix over the previous version: predictions are now cleaned the same way
    as labels. The trainer may pad gathered predictions with ``-100`` and may
    hand them over as a tuple; neither can be decoded directly.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of token-id arrays as passed
            by the trainer.

    Returns:
        The dictionary produced by ``metric.compute``.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks ignored / padded positions and is not a valid token id.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    return metric.compute(predictions=decoded_preds, references=decoded_labels)

# Load the generated train / evaluate CSVs and tokenize every row.
data_files = {"train": "/content/drive/MyDrive/u_sub/train.csv", "evaluate": "/content/drive/MyDrive/u_sub/evaluate.csv"}
dataset = load_dataset("csv", data_files=data_files)
tokenized_dataset = dataset.map(tokenizer_function, batched=True)

# Training configuration. The first positional argument is the directory for
# intermediate outputs; it is deleted again at the end of this cell.
# NOTE(review): no generation length is configured here, so evaluation with
# predict_with_generate presumably uses the library/model default — confirm
# that long `sub` labels are not truncated when scoring exact match.
training_args = Seq2SeqTrainingArguments(
    "/content/drive/MyDrive/u_sub/training_outputs",
    num_train_epochs=6,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    warmup_steps=1000,
    weight_decay=0.01,
    predict_with_generate=True,

    # Evaluate and checkpoint on the same 250-step cadence.
    eval_strategy="steps",
    eval_steps=250,
    save_strategy="steps",
    save_steps=250,
    load_best_model_at_end=True,
    save_total_limit=3,

    logging_steps=100,
    max_grad_norm=1.0,
    report_to=["none"],
    fp16=True,
    dataloader_num_workers=8
)

# Enabled on the model before it is handed to the trainer.
model.gradient_checkpointing_enable()

trainer = Seq2SeqTrainer(
    model,
    training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["evaluate"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics)

trainer.train()

# Export the final model, then remove the intermediate training outputs.
trainer.save_model("/content/drive/MyDrive/u_sub/t5-u-sub")
if os.path.exists("/content/drive/MyDrive/u_sub/training_outputs"):
    shutil.rmtree("/content/drive/MyDrive/u_sub/training_outputs")

# validate_model

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset
import torch

# Validate the published model on validation.csv, printing every mismatch and
# the exact-match score overall and per bucket ("0" labels vs. real subs).
checkpoint = "AryaR-06/t5-u-sub"
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

data_files = {"validation": "/content/drive/MyDrive/u_sub/validation.csv"}
dataset = load_dataset("csv", data_files=data_files)

def tokenize_function(example):
    """Tokenize the ``expression`` column, padded/truncated to 512 tokens."""
    return tokenizer(example["expression"], truncation=True, padding="max_length", max_length=512, return_tensors="pt")

def exact_match_score(incorrect, total):
    """Return ``1 - incorrect/total``, or ``"n/a"`` when the bucket is empty.

    The previous version divided unconditionally and raised
    ZeroDivisionError when the validation set lacked one kind of example.
    """
    if total == 0:
        return "n/a"
    return 1 - incorrect/total

tokenized_dataset = dataset["validation"].map(tokenize_function, batched=True, remove_columns=dataset["validation"].column_names)

model.eval()

predictions = []
labels = []

no_sub = 0
no_sub_incorrect = 0
sub = 0
sub_incorrect = 0

with torch.no_grad():
    for i in range(len(tokenized_dataset)):
        input_ids = torch.tensor(tokenized_dataset[i]["input_ids"])
        attention_mask = torch.tensor(tokenized_dataset[i]["attention_mask"])

        # generate() expects a batch dimension.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)

        outputs = model.generate(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 max_length=512,
                                 num_beams=4)

        row = dataset["validation"][i]
        pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        label = row["sub"]

        predictions.append(pred)
        labels.append(label)

        mismatch = pred != label
        if label == "0":
            no_sub += 1
            if mismatch:
                no_sub_incorrect += 1
        else:
            sub += 1
            if mismatch:
                sub_incorrect += 1

        if mismatch:
            print(f"{i+1}:")
            print(f"Input: {row['expression']}")
            print(f"Prediction: {pred}")
            print(f"Label: {label}")
            print("---")

print (f"Total Exact Match Score: {exact_match_score(no_sub_incorrect + sub_incorrect, no_sub + sub)}")
print (f"No Sub Exact Match Score: {exact_match_score(no_sub_incorrect, no_sub)}")
print (f"Sub Exact Match Score: {exact_match_score(sub_incorrect, sub)}")

# user_input

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
import re
import sympy as sp

def generate_u_sub (expression):
    """Ask the fine-tuned T5 model for a u-substitution.

    The tokenizer and model are loaded from the "AryaR-06/t5-u-sub" checkpoint
    on every call.

    Args:
        expression: Integrand text in the dataset format.

    Returns:
        The decoded model output (special tokens removed).
    """
    checkpoint = "AryaR-06/t5-u-sub"
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    model.eval()

    encoded = tokenizer(expression, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated = model.generate(input_ids=encoded["input_ids"],
                                   attention_mask=encoded["attention_mask"],
                                   max_length=512,
                                   num_beams=4)

    return tokenizer.decode(generated[0], skip_special_tokens=True)

def expr_tree_len(expr, max_depth=0, cur_depth=0):
    """Return the depth of an expression tree, counting the root as level 1.

    Args:
        expr: Node exposing its children through an ``args`` attribute.
        max_depth: Deepest level seen so far (used by the recursion).
        cur_depth: Level of the parent of ``expr`` (used by the recursion).

    Returns:
        The larger of ``max_depth`` and the deepest level reached below
        ``expr``.
    """
    level = cur_depth + 1
    deepest = max_depth if max_depth > level else level

    for child in expr.args:
        deepest = expr_tree_len(child, max_depth=deepest, cur_depth=level)

    return deepest

def format_input (expression,early_return=False):
    """Validate user text and convert it to the dataset format (or to SymPy).

    Fix over the previous version: "arccosec" is now translated before
    "arccos". Previously "arccosec(x)" was first turned into "acosec(x)", so
    the SymPy expression returned with ``early_return=True`` contained an
    unknown function instead of ``acsc``.

    Args:
        expression: Expression text; ``^`` and the ``arc...`` spellings are
            accepted.
        early_return: When true, return the parsed SymPy expression instead
            of the formatted string.

    Returns:
        The SymPy expression if ``early_return`` is true, otherwise the
        expression string in the dataset format.

    Raises:
        ValueError: For a purely alphabetic input longer than one character.
            Parsing errors from ``sp.sympify`` propagate to the caller.
    """
    if expression.isalpha() and len(expression) != 1:
        raise ValueError("expression must not be a bare word")

    expression = expression.replace("^", "**")
    # Longest name first: "arccosec" contains "arccos".
    for dataset_name, sympy_name in (
        ("arccosec", "acsc"),
        ("arcsin", "asin"),
        ("arccos", "acos"),
        ("arctan", "atan"),
        ("arcsec", "asec"),
        ("arccot", "acot"),
    ):
        expression = expression.replace(dataset_name, sympy_name)

    # NOTE(review): sympify evaluates the text it is given; only feed it
    # input from a trusted user.
    early_expression = sp.sympify(expression)

    if early_return:
         return early_expression

    # The string result is built from the user's own text (not from the
    # parsed expression), so the structure typed by the user is kept.
    expression = str(expression)
    expression = expression.replace("**", "^")
    expression = expression.replace("asin", "arcsin")
    expression = expression.replace("acos", "arccos")
    expression = expression.replace("atan", "arctan")
    expression = expression.replace("acsc", "arccosec")
    expression = expression.replace("asec", "arcsec")
    expression = expression.replace("acot", "arccot")
    expression = re.sub(r'sqrt\(([^()]+)\)', r'(\1)^(1/2)', expression)
    expression = re.sub(r'([a-zA-Z]+\([^()]*\)\^[0-9]+)', r'(\1)', expression)
    expression = re.sub(r'([\d]+\^[0-9]+)', r'(\1)', expression)
    expression = re.sub(r'([a-zA-Z]+\^[0-9]+)', r'(\1)', expression)
    expression = expression.replace("^", "**")

    return expression

def get_user_input():
    """Prompt until the user enters an expression that ``format_input`` accepts.

    Returns:
        The formatted expression string returned by ``format_input``.
    """
    while True:
        try:
            formatted_expression = format_input(input("Enter an expression to integrate by substitution: "))
        except ValueError:
            print("Invalid expression formatting. Please try again.")
        else:
            return formatted_expression

def find_answer(formatted_expression):
    """Build the step-by-step u-substitution write-up for an integrand.

    Fix over the previous version: the model's prediction uses the dataset
    spellings (``arcsin``, ``^`` ...), but it was parsed with a bare
    ``sp.sympify``, which treats ``arcsin`` etc. as unknown functions. Their
    derivative then stayed unevaluated and the substitution was always
    rejected. The prediction is now parsed through ``format_input``.

    Args:
        formatted_expression: Integrand string as returned by ``format_input``.

    Returns:
        List of lines to print. It contains the entry
        "No Substitution Found" when no usable substitution was produced.
    """
    integral_sign = chr(0x222B)
    statements_to_print = [f"(1)  {integral_sign}({formatted_expression.replace('**', '^')})*dx"]

    def give_up():
        statements_to_print.append("No Substitution Found")
        return statements_to_print

    prediction = generate_u_sub (formatted_expression)

    if prediction == "0":
        return give_up()

    statements_to_print.append(f"(2)  Let: u = {prediction.replace('**', '^')}")

    try:
        derivative = sp.diff(format_input(prediction, True))
    except ValueError:
        # Unparseable prediction, or one without a single variable.
        return give_up()
    if derivative == 0:
        return give_up()

    derivative_text = str(derivative).replace('**', '^')
    statements_to_print.append(f"(3)  du = ({derivative_text})*dx")
    statements_to_print.append(f"(4)  dx = du/({derivative_text})")

    # The substitution is applied textually; if u does not occur verbatim in
    # the integrand there is nothing to replace.
    with_u = formatted_expression.replace(prediction, 'u')
    if with_u == formatted_expression:
        return give_up()

    substituted_integrand = format_input(f"{with_u}/({derivative_text})", True)
    substituted_integrand = sp.ratsimp(sp.trigsimp(substituted_integrand))

    # Any x left over means the substitution did not eliminate the variable
    # ("exp" is masked because it contains the letter x).
    if "x" in str(substituted_integrand).replace("exp","temp"):
        return give_up()

    statements_to_print.append(f"(5)  {integral_sign}({str(substituted_integrand).replace('**', '^')})*du")
    return statements_to_print

def print_answer(statements_to_print):
    """Print the lines produced by ``find_answer``.

    If the list contains "No Substitution Found", only the first line and a
    "No Substitutions Found" notice are printed. Otherwise every line is
    printed, with a dashed rule (as wide as the longest line) after the first
    and the fourth line.

    Args:
        statements_to_print: List of output lines.
    """
    if "No Substitution Found" in statements_to_print:
        print(statements_to_print[0])
        print("No Substitutions Found")
        return

    rule = "-" * max(map(len, statements_to_print), default=0)
    for position, line in enumerate(statements_to_print):
        print(line)
        if position in (0, 3):
            print(rule)

if __name__ == "__main__":
    # Read an integrand, ask the model for a substitution, print the steps.
    print_answer(find_answer(get_user_input()))