# INSTALLATIONS

In [None]:
!pip install lightning

# IMPORTS

In [None]:
import argparse
import json
import logging
import os
import random
import re
import time
from difflib import SequenceMatcher

import numpy as np
import pandas as pd
import torch
from sympy import N, symbols, sympify
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
from tqdm.auto import tqdm

import lightning as pl

# MODEL TRAINING LOOP

## T5 INITIALIZATION

In [None]:
class T5FineTuner(pl.LightningModule):
  """LightningModule that fine-tunes a T5 conditional-generation model.

  The module owns its datasets and builds its own dataloaders, optimizer and
  linear warmup/decay learning-rate schedule.

  Args:
    hparams: Namespace of hyperparameters. The attributes read by this class
      are ``model_name_or_path``, ``tokenizer_name_or_path``,
      ``weight_decay``, ``learning_rate``, ``adam_epsilon``, ``warmup_steps``,
      ``train_batch_size``, ``eval_batch_size``, ``num_train_epochs``,
      ``gradient_accumulation_steps`` and ``n_gpu``.
    train_data: Dataset served by ``train_dataloader``.
    val_data: Dataset served by ``val_dataloader``.

  Batches are expected to be dicts with the keys ``source_ids``,
  ``source_mask``, ``target_ids`` and ``target_mask``.
  """

  def __init__(self, hparams, train_data, val_data):
    super(T5FineTuner, self).__init__()
    self.save_hyperparameters(hparams)
    self.train_dataset = train_data
    self.val_dataset = val_data
    self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)
    self.model.train()
    self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path)

  def is_logger(self):
    """Return True when the trainer's global rank is 0 (or lower)."""
    return self.trainer.global_rank <= 0

  def forward(
      self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None
  ):
    """Delegate to the wrapped T5 model and return its output unchanged."""
    return self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
    )

  def _step(self, batch):
    """Run one forward pass and return the loss for ``batch``.

    Padding positions in the targets are replaced by -100 (on a copy, so the
    batch itself is not modified) so that they do not contribute to the loss.
    """
    labels = batch["target_ids"].clone()
    labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

    outputs = self(
        input_ids=batch["source_ids"],
        attention_mask=batch["source_mask"],
        labels=labels,
        decoder_attention_mask=batch['target_mask']
    )

    # The loss is the first element of the model output when labels are given.
    loss = outputs[0]
    return loss

  def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and log it per epoch."""
    # Share the loss computation with validation. Passing the raw target_ids
    # as labels would score the padding positions too, which inflates the
    # objective with trivial pad predictions and makes train_loss
    # incomparable with val_loss.
    loss = self._step(batch)

    self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch."""
    self.model.eval()
    loss = self._step(batch)

    self.log("val_loss", loss)
    return {"val_loss": loss}

  def configure_optimizers(self):
    """Build the AdamW optimizer with two parameter groups.

    Parameters whose name contains ``bias`` or ``LayerNorm.weight`` get no
    weight decay; all others use ``hparams.weight_decay``. The optimizer is
    also stored on ``self.opt`` because ``train_dataloader`` needs it to
    create the learning-rate scheduler.
    """
    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.hparams.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
    self.opt = optimizer
    return [optimizer]

  def optimizer_step(self, epoch=None, batch_idx=None, optimizer=None, optimizer_closure=None,):
    """Step the optimizer, clear gradients and advance the LR scheduler.

    ``self.lr_scheduler`` is created in ``train_dataloader``, so that method
    must have run before the first optimizer step.
    """
    optimizer.step(optimizer_closure)
    optimizer.zero_grad()
    self.lr_scheduler.step()

  def get_tqdm_dict(self):
    """Return progress-bar entries for the average loss and current LR.

    NOTE(review): nothing in the visible code calls this method, and it
    reads ``self.trainer.avg_loss`` - confirm that attribute exists before
    relying on it.
    """
    tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

    return tqdm_dict

  def train_dataloader(self):
    """Build the training DataLoader and the linear warmup/decay scheduler.

    The total number of scheduler steps is derived from the dataset size,
    the effective batch size (batch size times number of GPUs), the gradient
    accumulation factor and the number of epochs. The scheduler is attached
    to ``self.opt``, so ``configure_optimizers`` must have run first.
    """
    train_dataset = self.train_dataset
    dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
    t_total = (
        (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
        // self.hparams.gradient_accumulation_steps
        * float(self.hparams.num_train_epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
    )
    self.lr_scheduler = scheduler
    return dataloader

  def val_dataloader(self):
    """Build the (unshuffled) validation DataLoader."""
    val_dataset = self.val_dataset
    return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)

logger = logging.getLogger(__name__)


class LoggingCallback(pl.Callback):
  """Callback that reports the trainer's callback metrics after validation."""

  def on_validation_end(self, trainer, pl_module):
    """Log a header, then log and print every metric on the logging rank.

    The metric names ``log`` and ``progress_bar`` are skipped.
    """
    logger.info("***** Validation results *****")
    if not pl_module.is_logger():
      return

    results = trainer.callback_metrics
    for name in sorted(results):
      if name in ("log", "progress_bar"):
        continue
      # Build the message once; it goes to both the logger and stdout.
      message = "{} = {}\n".format(name, str(results[name]))
      logger.info(message)
      print(message)

class PredictionCallback(pl.Callback):
    """Callback that prints the model's prediction for a fixed example.

    After every training epoch the example text is tokenized, a prediction is
    generated with beam search, and the decoded text is printed.

    Args:
        tokenizer: Tokenizer used to encode the example and decode the output.
        example_text: Input text (or list of texts) to generate from.
    """

    def __init__(self, tokenizer, example_text):
        super().__init__()
        self.tokenizer = tokenizer
        self.example_text = example_text

    def on_train_epoch_end(self, trainer, pl_module):
        """Generate from the example text and print the decoded prediction."""
        print(f"\n[Callback] Epoch {trainer.current_epoch}\n")

        # Remember the current mode so the module is handed back to the
        # trainer exactly as it was; leaving it in eval mode would disable
        # dropout for whatever training follows this hook.
        was_training = pl_module.model.training
        pl_module.model.eval()
        try:
            encoded = self.tokenizer(
                self.example_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=100
            )
            input_ids = encoded["input_ids"].to(pl_module.device)
            # Pass the mask explicitly so padded positions (when several
            # examples are given) are not attended to during generation.
            attention_mask = encoded["attention_mask"].to(pl_module.device)

            with torch.no_grad():
                output_ids = pl_module.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_length=30,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True
                )
        finally:
            pl_module.model.train(was_training)

        # Previously the generated ids were discarded; decode and show them.
        for prediction in self.tokenizer.batch_decode(output_ids, skip_special_tokens=True):
            print(f"[Callback] Prediction: {prediction}")

# Module-level tokenizer for 'google/flan-t5-base'. The cells below use it to
# count tokens while filtering, to build the datasets and to decode predictions.
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')

## DATA UPLOAD AND CLEANING

In [None]:
# Read the train / validation / test splits (in that order) from CSV files
# located in the current working directory.
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        'train_answerextracted.csv',
        'validation_answerextracted.csv',
        'test_answerextracted.csv',
    )
)

def insert_spaces(formula):
    """Put single spaces around every ``(``, ``)`` and ``,`` in a formula.

    Args:
        formula: The formula text. Non-string values (for example missing
            values) are returned unchanged.

    Returns:
        The formula with each parenthesis and comma surrounded by exactly one
        space, runs of spaces collapsed to one, and outer whitespace stripped.
    """
    if not isinstance(formula, str):
        return formula
    spaced = re.sub(r'([(),])', r' \1 ', formula)
    # Collapse every run of spaces. A single replace("  ", " ") pass is not
    # enough: three consecutive spaces (e.g. from "a , (b") would still leave
    # a double space behind.
    return re.sub(r' {2,}', ' ', spaced).strip()


def remove_const(expression):
    """Drop the ``const_`` prefix from numeric constant tokens.

    Only the prefix is removed, the rest of the token is kept verbatim:
    ``const_100`` becomes ``100`` and ``const_0_25`` becomes ``0_25``. A
    ``const_`` that is not followed by a digit, ``-``, ``_`` or ``.`` (such as
    ``const_pi``) is left untouched.
    """
    const_token = re.compile(r'const_([-0-9_.]+)')
    return const_token.sub(lambda match: match.group(1), expression)

# Operator names whose opening parenthesis is fused back onto the name.
ops = ['add', 'subtract', 'multiply', 'divide', 'power', 'sqrt', 'log', 'choose', 'speed',
       'volume_rectangular_prism', 'square_area', 'circle_area', 'circumface']

def fuse_operator_parens(expression, operators):
    """Remove whitespace between each operator name and its ``(``.

    Args:
        expression: The formula text, e.g. ``"add ( n0 , n1 )"``.
        operators: Iterable of operator names to fuse. Names are matched
            literally and only at a word boundary, so ``add`` does not match
            inside ``badd``.

    Returns:
        The expression with ``"<op> ("`` rewritten to ``"<op>("`` for every
        listed operator.
    """
    for op in operators:
        # Escape the name so regex metacharacters in an operator are matched
        # literally, and reuse the captured text as the replacement so the
        # name is never interpreted as a replacement template.
        expression = re.sub(rf'\b({re.escape(op)})\s*\(', r'\1(', expression)
    return expression

# Token limits used to discard examples that are too long.
MAX_FORMULA_TOKENS = 30
MAX_PROBLEM_TOKENS = 100


def _normalize_formulas(df):
    """Normalize the ``annotated_formula`` column of ``df`` in place.

    Applies, in order: spacing around parentheses/commas, removal of the
    ``const_`` prefix, and fusing of operator names with their parenthesis.
    """
    df['annotated_formula'] = (
        df['annotated_formula']
        .apply(insert_spaces)
        .apply(remove_const)
        .apply(lambda x: fuse_operator_parens(x, ops))
    )
    return df


def _keep_rows_within_token_limit(df, text_column, count_column, max_tokens):
    """Return the rows whose ``text_column`` encodes to at most ``max_tokens``.

    The token count is stored in ``count_column`` of the returned frame.
    """
    # Work on an explicit copy: ``df`` may itself be the result of an earlier
    # boolean filter, and assigning a column to such a slice triggers pandas'
    # SettingWithCopyWarning.
    df = df.copy()
    df[count_column] = df[text_column].apply(lambda x: len(tokenizer.encode(x, truncation=False)))
    return df[df[count_column] <= max_tokens]


def _prepare_split(df):
    """Normalize formulas, drop over-long rows and renumber the index."""
    df = _normalize_formulas(df)
    df = _keep_rows_within_token_limit(df, 'annotated_formula', 'count', MAX_FORMULA_TOKENS)
    df = _keep_rows_within_token_limit(df, 'Problem', 'count2', MAX_PROBLEM_TOKENS)
    return df.reset_index(drop=True)


# The same preparation is applied to every split.
train_df = _prepare_split(train_df)
val_df = _prepare_split(val_df)
test_df = _prepare_split(test_df)

## CREATE DATASET

In [None]:
class SATDataset(Dataset):
  """Dataset of tokenized (problem, formula) pairs.

  Every row of ``data`` is tokenized once at construction time; the source
  text comes from the ``Problem`` column and the target text from the
  ``annotated_formula`` column.

  Args:
    tokenizer: Tokenizer providing ``batch_encode_plus``.
    data: DataFrame with ``Problem`` and ``annotated_formula`` columns.
    max_len: Length every source sequence is padded/truncated to.
    target_max_len: Length every target sequence is padded/truncated to.
  """

  def __init__(self, tokenizer, data, max_len=100, target_max_len=30):
    self.data_column = "Problem"
    self.class_column = "annotated_formula"
    self.data = data

    self.max_len = max_len
    self.target_max_len = target_max_len
    self.tokenizer = tokenizer
    self.inputs = []
    self.targets = []

    self._build()

  def __len__(self):
    return len(self.inputs)

  def __getitem__(self, index):
    """Return the ids and attention masks of example ``index``.

    Each encoding was produced for a single text, so the leading batch
    dimension of size one is squeezed away.
    """
    source_ids = self.inputs[index]["input_ids"].squeeze(0)
    target_ids = self.targets[index]["input_ids"].squeeze(0)

    src_mask = self.inputs[index]["attention_mask"].squeeze(0)
    target_mask = self.targets[index]["attention_mask"].squeeze(0)

    return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}

  def _build(self):
    """Tokenize every row and fill ``self.inputs`` / ``self.targets``."""
    sources = self.data[self.data_column]
    targets = self.data[self.class_column]
    # Iterate over the column values directly instead of looking rows up by
    # label 0..n-1, which raised KeyError for any frame whose index had not
    # been reset (for example after filtering or sampling).
    for input_, target in zip(sources, targets):
      input_ = input_ + ' '
      target = target + " "

      tokenized_inputs = self.tokenizer.batch_encode_plus(
          [input_], max_length=self.max_len, padding="max_length", truncation=True, return_tensors="pt"
      )
      tokenized_targets = self.tokenizer.batch_encode_plus(
          [target], max_length=self.target_max_len, padding="max_length", truncation=True, return_tensors="pt"
      )
      self.inputs.append(tokenized_inputs)
      self.targets.append(tokenized_targets)

# Tokenize the training and validation splits up front (default max_len=100).
train_dataset = SATDataset(tokenizer, train_df)
val_dataset = SATDataset(tokenizer, val_df)

## ARGUMENTS

In [None]:
# Note: Experimentation continued after our best model was trained, so these parameters may not be optimal. The best model is saved to the cloud and can be run on its own in a later section.

# Hyperparameters consumed by T5FineTuner (through `args`) and by the trainer.
args_dict = {
    "model_name_or_path": 'google/flan-t5-base',
    "tokenizer_name_or_path": 'google/flan-t5-base',
    "max_seq_length": 100,
    "learning_rate": 8e-5,
    "weight_decay": 0,
    "adam_epsilon": 1e-8,
    "warmup_steps": 0,
    "train_batch_size": 32,
    "eval_batch_size": 32,
    "num_train_epochs": 10,
    "gradient_accumulation_steps": 2,
    "n_gpu": 1,
    "early_stop_callback": False,
    "seed": 42,
    "output_dir": "t5_sat_generator",
}
# Attribute-style view of the same values.
args = argparse.Namespace(**args_dict)

# Keep the five checkpoints with the lowest val_loss, plus the most recent one.
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    dirpath=args.output_dir,
    filename="checkpoint",
    monitor="val_loss",
    mode="min",
    save_top_k=5,
    save_last=True,
)

# Keyword arguments for pl.Trainer; callbacks are added just before training.
train_params = {
    "accumulate_grad_batches": args.gradient_accumulation_steps,
    "accelerator": "gpu",
    "devices": 1,
    "max_epochs": args.num_train_epochs,
    "precision": 32,
    "gradient_clip_val": 1.0,
    "log_every_n_steps": 10,
}

## MODEL TRAINING

In [None]:
# Wrap the hyperparameters and both datasets in the Lightning module.
model = T5FineTuner(args, train_dataset, val_dataset)

# Report metrics after each validation run and checkpoint on val_loss.
train_params["callbacks"] = [LoggingCallback(), checkpoint_callback]

trainer = pl.Trainer(**train_params)
# No dataloaders are passed: the module builds its own train/val loaders.
trainer.fit(model)

## MODEL EVALUATION

In [None]:
def batch_output_formula(model, tokenizer, problems, batch_size=32):
    """Generate a formula string for every problem text, batch by batch.

    Args:
        model: Fine-tuned module exposing the generation model as ``.model``.
        tokenizer: Tokenizer used to encode the problems and decode outputs.
        problems: Sequence of problem texts (list, tuple, pandas Series, ...).
        batch_size: Number of problems generated per forward pass.

    Returns:
        List of decoded predictions, one per problem, in input order. An
        empty input yields an empty list without touching the model.
    """
    # Materialize once so any iterable of strings is accepted, not only
    # objects that provide ``.tolist()``.
    problems = list(problems)
    if not problems:
        return []

    results = []
    device = next(model.parameters()).device

    # Generation must run in eval mode: with dropout active the beam search
    # output is not deterministic. The previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        for start in range(0, len(problems), batch_size):
            batch = problems[start:start + batch_size]
            # NOTE(review): inputs are truncated at 512 tokens here, whereas
            # the training data was built with a 100-token limit - confirm
            # the larger limit is intended.
            inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                output_ids = model.model.generate(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    max_length=100,
                    min_length=10,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True
                )

            results.extend(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
    finally:
        model.train(was_training)

    return results

# Evaluate on a fixed random sample of at most 100 test problems; the seed
# keeps the sample reproducible, and the cap avoids an error when fewer than
# 100 rows survived the length filtering.
testcopy_df = test_df.sample(n=min(100, len(test_df)), random_state=1).reset_index(drop=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `device` used to be computed but never applied. Move the model explicitly so
# generation runs on the GPU whenever one is available.
model.to(device)
model.model.to(torch.float32)

testcopy_df['prediction'] = batch_output_formula(model, tokenizer, testcopy_df['Problem'])

## METRICS

In [None]:
def normalized_levenshtein(pred, truth):
    """Return a similarity score in [0, 1] between ``pred`` and ``truth``.

    Despite the name, the score is ``difflib.SequenceMatcher.ratio()``:
    twice the number of matching characters divided by the combined length
    of both strings. Identical strings (including two empty ones) score 1.0.
    """
    matcher = SequenceMatcher(a=pred, b=truth)
    return matcher.ratio()

# Similarity between each generated formula and its reference, computed row by
# row, followed by the mean similarity over the evaluation sample.
testcopy_df['score'] = testcopy_df.apply(lambda x: normalized_levenshtein(x['prediction'], x['annotated_formula']), axis=1)
print(testcopy_df['score'].mean())

## CLOSE ANSWERS

In [None]:
# Predictions that are nearly identical to the reference (similarity >= 0.9),
# reduced to the columns needed for manual inspection; show the first ten.
high_df = testcopy_df[testcopy_df['score'] >= 0.9][['Problem', 'annotated_formula', 'prediction', 'score']]
high_df.head(10)

## SYMPY

In [None]:
# Symbol for the constant named 'const_100'.
# NOTE(review): nothing in the visible code references this variable;
# check_answer_numeric supplies its own value for 'const_100' - confirm it is
# still needed.
const_100 = symbols('const_100')

def evaluate_functional_expression(expr_str):
    """Translate a prefix-style formula into a parenthesized infix string.

    For example ``"add(1,multiply(2,3))"`` becomes ``"(1 + (2 * 3))"``.

    Token handling:
      * Tokens consist of letters, digits, ``.``, ``_`` and ``-``; a leading
        minus sign is therefore kept instead of being dropped.
      * A ``const_`` prefix is stripped and remaining underscores become
        decimal points, so ``const_0_25`` and ``0_25`` both yield ``0.25``.
      * Whitespace and any other character are ignored.

    Only the binary operators ``add``, ``subtract``, ``multiply`` and
    ``divide`` are translated. A call to any other name is not translated:
    its name and arguments are consumed up to the nearest enclosing supported
    operator. Arguments beyond the second are ignored.

    Args:
        expr_str: The formula text.

    Returns:
        The infix expression as a string, or ``""`` when nothing was parsed
        (for example for empty input or a bare token without parentheses).

    Raises:
        IndexError: If a supported operator is called with fewer than two
            arguments.
    """
    infix_symbols = {"add": "+", "subtract": "-", "multiply": "*", "divide": "/"}
    const_prefix = "const_"

    stack = []
    token = ""
    for char in expr_str:
        if char.isalnum() or char in "._-":
            # Keep the raw characters; underscores are only rewritten once the
            # whole token is known, otherwise the "const_" prefix could never
            # be recognized.
            token += char
            continue
        if char not in "(,)":
            continue

        # Any delimiter ends the current token (an operator name before "(",
        # an argument before "," or ")").
        if token:
            if token.startswith(const_prefix):
                token = token[len(const_prefix):]
            stack.append(token.replace("_", "."))
            token = ""

        if char != ")":
            continue

        # Closing parenthesis: collect arguments back to the operator name.
        args = []
        while stack and stack[-1] not in infix_symbols:
            args.append(stack.pop())
        args.reverse()

        if stack:
            symbol = infix_symbols[stack.pop()]
            stack.append(f"({args[0]} {symbol} {args[1]})")

    return stack[0] if stack else ""


def check_answer_numeric(x):
    """Convert a predicted formula into a simplified sympy expression.

    The formula is first rewritten to infix notation and then parsed with
    sympy, with the name ``const_100`` bound to 100.

    Returns:
        The simplified expression, or ``None`` when any step fails (the
        conversion is best-effort: every exception is swallowed).
    """
    # NOTE: sympify evaluates the string it is given, and ``x`` is generated
    # text here - keep this away from input that is not trusted.
    try:
        infix_text = evaluate_functional_expression(x)
        parsed = sympify(infix_text, locals={'const_100': 100})
        simplified = parsed.simplify()
    except Exception:
        return None
    return simplified

# ACCURACY

In [None]:
# Numeric value of every predicted formula; None where it could not be parsed.
testcopy_df['pred_ans'] = testcopy_df['prediction'].apply(lambda x: check_answer_numeric(x))
def is_close(pred, truth, rtol=1e-5, atol=1e-1):
    """Return whether ``pred`` and ``truth`` are numerically close.

    Both values are evaluated numerically, converted to ``float`` and compared
    with ``numpy.isclose``.

    Args:
        pred: Predicted value (number or sympy expression), or ``None``.
        truth: Reference value, or ``None``.
        rtol: Relative tolerance passed to ``numpy.isclose``.
        atol: Absolute tolerance passed to ``numpy.isclose``.

    Returns:
        A plain ``bool``. ``False`` when either value is ``None`` or cannot be
        converted to a float.
    """
    if pred is None or truth is None:
        return False
    try:
        return bool(np.isclose(float(N(pred)), float(N(truth)), rtol=rtol, atol=atol))
    except Exception:
        # Best-effort comparison: a value that cannot be evaluated counts as
        # "not close". Unlike a bare ``except``, this does not swallow
        # KeyboardInterrupt or SystemExit.
        return False

# Compare every parsed prediction with the reference numeric answer.
testcopy_df['is_close'] = testcopy_df.apply(lambda row: is_close(row['pred_ans'], row['answer_numeric']), axis=1)
# NOTE(review): rows whose prediction could not be parsed are excluded here,
# so the accuracy below is measured over parseable predictions only, not over
# the whole sample - confirm this is the intended denominator.
testing = testcopy_df[testcopy_df['pred_ans'].notna()]

print("Accuracy:", np.round(testing['is_close'].mean(), 2))