<a href="https://colab.research.google.com/github/AngeloOttendorfer/Python_Projects/blob/master/Math_Language_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dataset

## Helper functions

In [4]:
def last_boxed_only(sample):
    r"""Reduce the answer of a (question, answer) pair to its last boxed element.

    The docstring is a raw string so that ``\boxed`` and ``\fbox`` are not
    interpreted as the backspace / form-feed escape sequences.

    Args:
        sample: A ``(q, a)`` pair; ``a`` is the answer string to filter.

    Returns:
        ``(q, boxed)`` where ``boxed`` is the last ``\boxed{...}`` or
        ``\fbox{...}`` element found by ``last_boxed_only_string``, or
        ``None`` when that helper finds no such element.
    """
    q, a = sample
    a = last_boxed_only_string(a)
    if a is None:
        return None
    return (q, a)

def last_boxed_only_string(string):
    r"""Extract the last ``\boxed...{...}`` (or ``\fbox...{...}``) span of ``string``.

    The last ``\boxed`` marker is preferred; ``\fbox`` is only searched for
    when no ``\boxed`` marker exists. Starting at the marker, braces are
    counted and the span ends at the ``}`` that brings the count back to
    zero.

    Returns:
        The substring from the marker through its matching closing brace,
        or ``None`` when there is no marker or the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace of the boxed expression.
                return string[start:pos + 1]

    # Ran off the end without the brace count returning to zero.
    return None

def only_until_first_boxed_from_tokens(string, tokens):
    r"""Return the tokens that precede the first boxed element of ``string``.

    The character offset of the first ``\boxed`` (or, failing that, the
    first ``\fbox``) is located in ``string``. Token lengths are then summed
    until the running total reaches that offset; the token at which this
    happens and everything after it are dropped.

    Args:
        string: The text that was tokenized.
        tokens: Sequence of token strings; their lengths are used as an
            approximation of character positions in ``string``.

    Returns:
        ``None`` when ``string`` contains neither marker. Otherwise the
        leading slice of ``tokens``. An empty ``tokens`` yields ``[]``
        (previously this raised ``UnboundLocalError`` because the loop
        variable was never bound). If the tokens run out before the marker
        offset is reached, the final token is still dropped, as before.
    """
    idx = string.find("\\boxed")
    if idx < 0:
        idx = string.find("\\fbox")
        if idx < 0:
            return None

    # Nothing to scan: the loop below would leave ``i`` undefined.
    if not tokens:
        return []

    cum_length = 0
    for i, t in enumerate(tokens):
        cum_length += len(t)
        if cum_length >= idx:
            break

    return tokens[:i]

def clean_numbers(sample):
    """Run ``_clean_numbers`` over every string of ``sample``.

    Returns:
        A tuple with the cleaned strings in the original order, or ``None``
        when ``sample`` is falsy (``None`` or empty).
    """
    if sample:
        return tuple(_clean_numbers(part) for part in sample)
    return None

def _clean_numbers(string):
    """
    Clean Numbers in the given string

    >>> _clean_numbers(None, "Hello 123")
    'Hello 123'
    >>> _clean_numbers(None, "Hello 1234")
    'Hello 1,234'
    >>> _clean_numbers(None, "Hello 1234324asdasd")
    'Hello 1,234,324asdasd'
    """
    num_prev_digits = 0
    new_string = ""
    for i, c in enumerate(string):
        # isdigit() doesnt work here because of weird unicode chars.
        if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
            num_prev_digits += 1
        else:
            if num_prev_digits > 3:
                # Some fixing
                string_number = new_string[-num_prev_digits:]
                new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
            num_prev_digits = 0
        new_string += c

    if num_prev_digits > 3:
        # Some fixing
        string_number = new_string[-num_prev_digits:]
        new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))

    return new_string

## Base Math Dataset

### import necessary modules

In [2]:
# !pip install torch
# torch.nn.functional is included with torch; random, os and time are part of
# the Python standard library, so none of them need a separate installation.

import torch
import torch.nn.functional as F
import random
import os
import time

### Implementation

In [3]:
class BaseMathDataset(torch.utils.data.Dataset):
    """Configurable AMPS Dataset.

    Base class that packs tokenized (question, answer) samples into
    ``input_ids`` / ``labels`` sequences of at most ``max_tokens`` entries.
    Subclasses must implement ``initialize`` (``get_random_sample`` expects
    it to have populated ``self.samples`` with ``(question, answer, fname)``
    triples), ``__len__``, and the ``clean_filter_sample_*`` hook selected
    by ``mode``.
    """

    def __init__(self, dataroot, tokenizer, max_tokens, mode, mode_answer='default', len_multiplier=1.0, packing=None,
                 randomize=None, pack_end=None, clean_numbers=True, latex_mask=False, peek_fraction=(0.1, 1.0)):
        """Store the configuration, pick mode defaults and call ``initialize``.

        Args:
            dataroot: Data location; stored and interpreted by the subclass.
            tokenizer: Tokenizer object; ``__getitem__`` reads its
                ``pad_token_id`` when padding.
            max_tokens: Length of the packed sequences.
            mode: ``'gpt2'`` (training defaults) or ``'gpt2-eval'``
                (evaluation defaults).
            mode_answer: Stored for use by subclasses.
            len_multiplier: Stored; not used by this base class.
            packing: If not ``None``, overrides the mode's ``packing`` default.
            randomize: If not ``None``, overrides the mode's ``randomize`` default.
            pack_end: If not ``None``, overrides the mode's ``pack_end`` default.
            clean_numbers: Stored; not used by this base class.
            latex_mask: Stored; not used by this base class.
            peek_fraction: Stored; not used by this base class.

        Raises:
            NotImplementedError: If ``mode`` is not one of the two supported
                values (or from the base ``initialize``).
        """
        self.dataroot = dataroot
        self.tokenizer = tokenizer  # Set in run_training(), not in dataset creation
        self.max_tokens = max_tokens
        self.mode = mode
        self.mode_answer = mode_answer  # Used in subclass
        self.len_multiplier = len_multiplier
        self.clean_numbers = clean_numbers
        self.latex_mask = latex_mask
        self.peek_fraction = peek_fraction

        # Per-mode defaults: training samples randomly and omits file names,
        # evaluation walks the samples in order and reports file names.
        if self.mode in {'gpt2'}:
            self.clean_sample = self.clean_filter_sample_gpt
            self.packing = True
            self.randomize = True
            self.include_fnames = False
            self.pack_end = True
        elif self.mode in {'gpt2-eval'}:
            self.clean_sample = self.clean_filter_sample_gpt_eval
            self.packing = True
            self.randomize = False
            self.include_fnames = True
            self.pack_end = True
        else:
            raise NotImplementedError()

        # Explicit constructor arguments win over the mode defaults above.
        if packing != None:
            print("Overriding packing to be", packing)
            self.packing = packing
        if randomize != None:
            print("Overriding randomize to be", randomize)
            self.randomize = randomize
        if pack_end != None:
            print("Overriding pack_end to be", pack_end)
            self.pack_end = pack_end

        # NOTE: initialize() runs before bad_fnames and i are assigned below.
        self.initialize()

        self.bad_fnames = set()
        self.i = 0  # Cursor into self.samples for the non-randomized path

    def initialize(self):
        """Subclass hook, called once at the end of ``__init__`` setup."""
        raise NotImplementedError()

    def __len__(self):
        """Subclass hook returning the dataset length."""
        raise NotImplementedError()

    def __getitem__(self, index):
        """Build one packed example; ``index`` is not used.

        Samples are drawn via ``get_random_sample`` and concatenated until
        ``max_tokens`` is reached (or after a single sample when packing is
        off). Outside eval modes the result is truncated/padded so that
        ``input_ids`` and ``labels`` both have exactly ``max_tokens``
        entries; inputs are padded with ``tokenizer.pad_token_id`` and
        labels with -100. In eval modes no padding is applied.

        Returns:
            A dict with ``input_ids`` and ``labels`` tensors, plus ``fnames``
            when ``include_fnames`` is set or when no sample was available.
        """

        # Each worker needs a different seed....
        random.seed(os.getpid() + time.time() + random.random())

        # Sampling with replacement.
        # We need to pack random elements to get close to self.max_tokens
        curr_input_ids = []
        curr_label_ids = []
        curr_fnames = []
        num_samples = 0
        while len(curr_input_ids) + 1 <= self.max_tokens and len(curr_label_ids) + 1 <= self.max_tokens:
            # print("curr_input_ids: " + str(curr_input_ids))
            # print("curr_label_ids: " + str(curr_label_ids))
            # print("curr_fnames: " + str(curr_fnames))
            curr_sample, fname = self.get_random_sample()
            # print("current_sample: " + str(curr_sample))
            # print(fname)
            if curr_sample is None:
                # This only happens in eval modes
                # Placeholder result; built with torch.zeros' default dtype,
                # unlike the LongTensors created further down.
                return {
                    "input_ids": torch.zeros([self.max_tokens]),
                    "labels": torch.zeros([self.max_tokens]),
                    "fnames": [fname]
                }

            if not self.pack_end and (
                    (len(curr_input_ids) + 1 + len(curr_sample['input_ids_list']) > self.max_tokens) or
                    (len(curr_label_ids) + 1 + len(curr_sample['label_ids_list']) > self.max_tokens)
            ):
                # Do not include curr_sample if either the input_ids or the label_ids will run off the end.
                break

            # Add curr_sample to the current inputs and labels
            # print("input_ids_list: " + str(curr_sample['input_ids_list']))
            curr_input_ids.extend(curr_sample['input_ids_list'])
            curr_label_ids.extend(curr_sample['label_ids_list'])
            curr_fnames.append(fname)

            num_samples += 1

            # Break on the first iteration if we don't want to do packing.
            if not self.packing:
                break

        input_ids = torch.LongTensor(curr_input_ids)
        label_ids = torch.LongTensor(curr_label_ids)

        # Sanity check
        if 'eval' not in self.mode:
            assert len(curr_input_ids) == len(curr_label_ids)

        # With pack_end enabled the last sample may overshoot; cut it off here.
        input_ids = input_ids[:self.max_tokens]
        label_ids = label_ids[:self.max_tokens]

        if len(curr_input_ids) < self.max_tokens and 'eval' not in self.mode:
            # Pad
            num_to_pad = self.max_tokens - len(curr_input_ids)
            input_ids = F.pad(input_ids, [0, num_to_pad], mode='constant', value=self.tokenizer.pad_token_id)

        if len(curr_label_ids) < self.max_tokens and 'eval' not in self.mode:
            # Pad labels with -100 (the same value used to mask prompt tokens in the subclasses).
            num_to_pad = self.max_tokens - len(curr_label_ids)
            label_ids = F.pad(label_ids, [0, num_to_pad], mode='constant', value=-100)

        # Sanity check
        if 'eval' not in self.mode:
            assert input_ids.shape[0] == label_ids.shape[
                0] == self.max_tokens, f"{input_ids.shape[0]}, {label_ids.shape[0]}, {self.max_tokens}"

        if self.include_fnames:
            return {
                "input_ids": input_ids,
                "labels": label_ids,
                "fnames": curr_fnames
            }
        else:
            # This is the format required by our GPT2Trainer class
            return {
                "input_ids": input_ids,
                "labels": label_ids
            }

    def get_random_sample(self):
        """
        Fetch one tokenized sample together with its file name.

        With ``randomize`` set, a sample is drawn with ``random.choice`` and
        the draw is repeated until ``clean_sample`` returns something other
        than ``None``. Otherwise the sample at the cursor ``self.i`` is
        used, the cursor advances with wrap-around, and the result is
        returned even if ``clean_sample`` produced ``None``.

        Returns:
            ``(sample, fname)`` where ``sample`` is the value returned by
            ``clean_sample``.
        """
        random_sample = None
        while random_sample is None:
            if self.randomize:
                q, a, fname = random.choice(self.samples)
            else:
                q, a, fname = self.samples[self.i]
                self.i = (self.i + 1) % len(self.samples)

            random_sample = self.clean_sample((q, a))  # q + '\n' + a  # self.clean_sample((q, a))
            # print("random_sample: " + str(random_sample))

            # Sequential mode never retries: a filtered sample is reported as None.
            if not self.randomize:
                break

        return random_sample, fname

    def clean_filter_sample_gpt(self, sample):
        """Subclass hook used as ``clean_sample`` in ``'gpt2'`` mode."""
        raise NotImplementedError()

    def clean_filter_sample_gpt_eval(self, sample):
        """Subclass hook used as ``clean_sample`` in ``'gpt2-eval'`` mode."""
        raise NotImplementedError()

    def clean_filter_sample_t5(self, sample):
        """Subclass hook; not selected by any mode handled in ``__init__``."""
        raise NotImplementedError()

    def clean_filter_sample_t5_eval(self, sample):
        """Subclass hook; not selected by any mode handled in ``__init__``."""
        raise NotImplementedError()


## Configurable Mathematica Dataset

### Import necessary modules

In [7]:
# !pip install torch
# !pip install tqdm
# os is part of the Python standard library and needs no installation.
import torch
import tqdm
import os

### Implementation

In [8]:
class MathematicaMathDataset(BaseMathDataset):
    """Configurable Math Dataset.

    ``dataroot`` is a text file listing one problem-file path per line.
    Each problem file holds a question introduced by a ``Problem:`` line
    followed by one or more answers, each introduced by an ``Answer:`` line.
    Every (question, answer) combination becomes one sample.
    """

    def __len__(self):
        """Return the number of loaded samples scaled by ``len_multiplier``."""
        return int(len(self.samples) * self.len_multiplier)

    def initialize(self):
        """
        Set up self.samples by loading from the dataroot

        Paths listed in ``dataroot`` that do not point to an existing file
        are reported with a ``SKIPPING`` message and ignored.
        """

        with open(self.dataroot, 'r') as fp:
            all_filenames = fp.readlines()

        print(f"{self.__class__.__name__}: Loading samples from {len(all_filenames)} files.")

        # The notebook does ``import tqdm``, which binds the *module*; calling
        # the module raises TypeError. Resolve the progress-bar callable in a
        # way that also works if ``tqdm`` was rebound by
        # ``from tqdm import tqdm`` elsewhere.
        progress = getattr(tqdm, "tqdm", tqdm)

        samples_raw = []
        for fname in progress(all_filenames):
            fname = fname.rstrip()

            if not os.path.isfile(fname):
                print(f"SKIPPING {fname}")
                continue

            question, answers = self._parse_problem_file(fname)
            samples_raw.extend((question, a, fname) for a in answers)

        self.samples = samples_raw

        print(f"{self.__class__.__name__}: Loaded {len(self.samples)} samples.")

    @staticmethod
    def _parse_problem_file(fname):
        """Split one problem file into its question and list of answers.

        Returns:
            ``(question, answers)``. The text collected after the last
            marker is always recorded as an answer, even if the file
            contains no ``Answer:`` line at all.
        """
        question = ""
        answers = []
        reading_question = True
        curr_section = ""
        with open(fname, 'r') as fp:
            for line in fp:
                if line == "Problem:\n":
                    reading_question = True
                elif line == "Answer:\n":
                    if reading_question:
                        # curr_section contains Q
                        question = curr_section
                    else:
                        # curr_section contains an A
                        answers.append(curr_section)
                    curr_section = ""
                    reading_question = False
                else:
                    curr_section += line

        # The last answer needs to be recorded.
        answers.append(curr_section)
        return question, answers

    def clean_filter_sample_gpt(self, sample):
        """
        Tokenize one (question, answer) pair for decoder-only training.

        The input is ``QUESTION`` prompt + ``FINAL ANSWER`` separator +
        answer + EOS; labels mask the prompt and separator with -100 so only
        the answer tokens contribute to the loss.

        Returns:
            Dict with ``input_ids_list`` and ``label_ids_list`` (equal
            length), or ``None`` if ``sample`` is ``None`` or the encoded
            input exceeds ``max_tokens``.

        Raises:
            NotImplementedError: If ``mode_answer`` is not ``'default'``.
        """

        if sample is None:
            return None

        question, answer = sample
        if self.clean_numbers:
            question = _clean_numbers(question)
            answer = _clean_numbers(answer)

        if self.mode_answer != 'default':
            raise NotImplementedError()

        question_ids = torch.LongTensor(self.tokenizer.encode("\nQUESTION:\n" + question, verbose=False))
        sep_ids = torch.LongTensor(self.tokenizer.encode("\nFINAL ANSWER:\n", verbose=False))
        answer_ids = self.tokenizer.encode(answer, verbose=False)
        answer_ids.append(self.tokenizer.eos_token_id)
        answer_ids = torch.LongTensor(answer_ids)

        # Use full solution
        input_ids = torch.cat([question_ids, sep_ids, answer_ids], dim=0)

        # Only the answer tokens carry a real label.
        label_ids = torch.cat([
            torch.ones_like(question_ids) * -100,
            torch.ones_like(sep_ids) * -100,
            answer_ids.clone()
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list' : input_ids.tolist(),
            'label_ids_list' : label_ids.tolist()
        }

    def clean_filter_sample_t5(self, sample):
        """
        Tokenize one (question, answer) pair as separate input and target.

        The input is the ``QUESTION`` prompt followed by the
        ``FINAL ANSWER`` separator; the labels are the encoded answer.

        Returns:
            Dict with ``input_ids_list`` and ``label_ids_list``, or ``None``
            if ``sample`` is ``None`` or the encoded input exceeds
            ``max_tokens``.

        Raises:
            NotImplementedError: If ``mode_answer`` is not ``'default'``.
        """

        if sample is None:
            return None

        question, answer = sample
        if self.clean_numbers:
            question = _clean_numbers(question)
            answer = _clean_numbers(answer)

        if self.mode_answer != 'default':
            raise NotImplementedError()

        input_ids = torch.LongTensor(
            self.tokenizer.encode("\nQUESTION:\n" + question + "\nFINAL ANSWER:\n", verbose=False))
        label_ids = torch.LongTensor(self.tokenizer.encode(answer, verbose=False))

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list' : input_ids.tolist(),
            'label_ids_list' : label_ids.tolist()
        }


## MATH Dataset Configuration

### import necessary modules

In [10]:
# !pip install torch
# !pip install numpy
# json, glob and random are part of the Python standard library and need no installation.

import torch
import json
import glob
import random
import numpy as np

from multiprocessing import Manager

### Implementation

In [11]:
class MATHDataset(BaseMathDataset):
    """Configurable Math Dataset.

    Loads one JSON file per problem (keys ``problem`` and ``solution``) and
    tokenizes (question, answer) pairs according to ``mode_answer``:

    * training: ``'full'``, ``'final_boxed'``, ``'peeking_only'``,
      ``'mixed_full_and_peeking'``, ``'mixed_full_and_nopack_padding'``,
      ``'mixed_final_boxed_and_full'``;
    * evaluation: ``'eval_peeking'``, ``'eval_nopack_padding'``, anything
      else falls through to the plain full-solution evaluation format.

    Every tokenization method returns a dict with ``input_ids_list`` and
    ``label_ids_list``, or ``None`` when the sample cannot be used.
    """

    # Filler inserted between question and separator by the "nopack padding"
    # variants: 32 copies of token id 220 (the space token for GPT2 models).
    NOPACK_PAD_LENGTH = 32
    NOPACK_PAD_TOKEN_ID = 220

    def __len__(self):
        """Return the number of loaded samples scaled by ``len_multiplier``."""
        return int(len(self.samples) * self.len_multiplier)

    def initialize(self):
        """
        Set up self.samples by loading from the dataroot

        ``dataroot`` is expanded with ``glob.glob``; every matching file is
        read as JSON. ``self.samples`` becomes a ``multiprocessing.Manager``
        list of ``(problem, solution, fname)`` triples.

        Raises:
            Exception: Whatever ``json.load`` raised for an unreadable file
                (re-raised after printing the file name).
            AssertionError: If a problem, solution or file name is empty.
        """
        print(self.dataroot)
        # Use the configured dataroot rather than a machine-specific
        # absolute path, so the dataset works wherever the data lives.
        all_filenames = glob.glob(self.dataroot)
        print(all_filenames)
        samples_raw = []
        for fname in all_filenames:
            with open(fname, 'r') as fp:
                try:
                    problem_data = json.load(fp)
                except Exception as e:
                    print(f"Error loading JSON from {fname}", e)
                    raise
            curr_sample_raw = (problem_data['problem'], problem_data['solution'], fname)
            for field in curr_sample_raw:
                assert field
            samples_raw.append(curr_sample_raw)

        manager = Manager()
        self.samples = manager.list(samples_raw)

        print(f"{self.__class__.__name__}: Loaded {len(self.samples)} samples.")

    # ------------------------------------------------------------------
    # Private helpers shared by the tokenization methods below.
    # ------------------------------------------------------------------

    def _clean_pair(self, sample):
        """Unpack ``(question, answer)``, applying ``_clean_numbers`` if enabled."""
        question, answer = sample
        if self.clean_numbers:
            question = _clean_numbers(question)
            answer = _clean_numbers(answer)
        return question, answer

    def _encode_long(self, text):
        """Encode ``text`` (a string or token list) into a 1-D LongTensor."""
        return torch.LongTensor(self.tokenizer.encode(text, verbose=False))

    def _encode_with_eos(self, text):
        """Encode ``text`` and append the tokenizer's EOS id, as a LongTensor."""
        ids = self.tokenizer.encode(text, verbose=False)
        ids.append(self.tokenizer.eos_token_id)
        return torch.LongTensor(ids)

    @staticmethod
    def _masked(ids):
        """Return a tensor shaped like ``ids`` filled with the ignore label -100."""
        return torch.ones_like(ids) * -100

    def _space_padding(self):
        """Return the integer filler tensor used by the nopack-padding variants."""
        # Must be a long tensor: concatenating a float tensor with the token
        # ids would promote everything to float.
        return torch.full((self.NOPACK_PAD_LENGTH,), self.NOPACK_PAD_TOKEN_ID, dtype=torch.long)

    def _peek_length(self, num_tokens):
        """Return how many of ``num_tokens`` solution tokens to reveal.

        A tuple ``peek_fraction`` is treated as the bounds of a uniform
        random draw; any other value is used as a fixed fraction.
        """
        if isinstance(self.peek_fraction, tuple):
            return int(num_tokens * random.uniform(*self.peek_fraction))
        return int(num_tokens * self.peek_fraction)

    # ------------------------------------------------------------------
    # Training formats
    # ------------------------------------------------------------------

    def clean_filter_sample_gpt(self, sample):
        """
        Does the actual tokenization. Should be parallelized because it can be a bit slow.

        Dispatches on ``mode_answer``; the ``mixed_*`` modes choose between
        their two formats with probability 0.5 each. The ``'full'`` format
        trains on the whole solution after a ``FULL SOLUTION`` separator,
        the ``'final_boxed'`` format on the last boxed expression after a
        ``FINAL ANSWER`` separator. Prompt and separator are masked with
        -100 in the labels.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None``, no boxed
            answer exists in ``'final_boxed'`` format, or the input is
            longer than ``max_tokens``.

        Raises:
            NotImplementedError: For an unrecognized ``mode_answer``.
        """

        if sample is None:
            return None

        if self.mode_answer == 'peeking_only':
            return self.clean_filter_sample_peeking_gpt(sample)
        if self.mode_answer == 'mixed_full_and_peeking':
            if random.random() < 0.5:
                return self.clean_filter_sample_peeking_gpt(sample)
            else:
                _mode_answer = 'full'
        elif self.mode_answer == 'mixed_full_and_nopack_padding':
            if random.random() < 0.5:
                return self.clean_filter_sample_nopackpadding_gpt(sample)
            else:
                _mode_answer = 'full'
        elif self.mode_answer == 'mixed_final_boxed_and_full':
            if random.random() < 0.5:
                _mode_answer = 'full'
            else:
                _mode_answer = 'final_boxed'
        elif self.mode_answer == 'full':
            _mode_answer = 'full'
        elif self.mode_answer == 'final_boxed':
            _mode_answer = 'final_boxed'
        else:
            raise NotImplementedError(f"self.mode_answer = {self.mode_answer} not recognized.")

        question, answer = self._clean_pair(sample)

        if _mode_answer == 'full':
            separator = "\nFULL SOLUTION:\n"
            target = answer
        else:  # 'final_boxed'
            separator = "\nFINAL ANSWER:\n"
            target = last_boxed_only_string(answer)
            if not target:
                print("ERROR FROM", question, answer)
                return None

        question_ids = self._encode_long("\nQUESTION:\n" + question)
        sep_ids = self._encode_long(separator)
        target_ids = self._encode_with_eos(target)

        input_ids = torch.cat([question_ids, sep_ids, target_ids], dim=0)

        # Only the target tokens contribute to the loss
        label_ids = torch.cat([
            self._masked(question_ids),
            self._masked(sep_ids),
            target_ids.clone(),
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }

    def clean_filter_sample_nopackpadding_gpt(self, sample):
        """Training format: question, filler tokens, separator, boxed answer.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None``, the
            answer has no boxed expression, or the input is longer than
            ``max_tokens``.
        """

        if sample is None:
            return None

        question, answer = self._clean_pair(sample)

        answer_final = last_boxed_only_string(answer)
        if answer_final is None:
            # Nothing to train on; the tokenizer cannot encode None.
            return None

        question_ids = self._encode_long("\nQUESTION:\n" + question)
        sep_ids = self._encode_long("\nFINAL ANSWER:\n")
        final_answer_ids = self._encode_long(answer_final)
        padding_tensor = self._space_padding()

        input_ids = torch.cat([
            question_ids,
            padding_tensor,
            sep_ids,
            final_answer_ids
        ], dim=0)

        # Only answer_ids contribute to the loss
        label_ids = torch.cat([
            self._masked(question_ids),
            self._masked(padding_tensor),
            self._masked(sep_ids),
            final_answer_ids.clone()
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }

    def clean_filter_sample_peeking_gpt(self, sample):
        """
        Training format that reveals a prefix of the worked solution.

        The input is the question, a random-length prefix of the solution
        tokens that precede the first boxed expression, the ``FINAL ANSWER``
        separator, and the last boxed expression followed by EOS. Only the
        boxed answer is unmasked in the labels.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None``, the
            answer has no boxed expression, no solution tokens precede it,
            or the input is longer than ``max_tokens``.
        """

        if sample is None:
            return None

        question, answer = self._clean_pair(sample)

        answer_final = last_boxed_only_string(answer)
        if answer_final is None:
            return None

        question_ids = self._encode_long("\nQUESTION:\n" + question + "\nFULL SOLUTION:\n")

        answer_tokens = self.tokenizer.tokenize(answer)
        answer_tokens = only_until_first_boxed_from_tokens(answer, answer_tokens)
        if not answer_tokens:
            # Same rule as the eval variant: nothing to peek at.
            return None
        answer_ids = self._encode_long(answer_tokens)

        # Take a fraction
        final_idx = self._peek_length(len(answer_ids))
        answer_ids = answer_ids[:final_idx]

        sep_ids = self._encode_long("\nFINAL ANSWER:\n")
        # The target is the boxed final answer. (Slicing the already
        # truncated ``answer_ids`` at ``final_idx`` is always empty.)
        final_answer_ids = self._encode_with_eos(answer_final)

        input_ids = torch.cat([
            question_ids,
            answer_ids,
            sep_ids,
            final_answer_ids
        ], dim=0)

        # Only answer_ids contribute to the loss
        label_ids = torch.cat([
            self._masked(question_ids),
            self._masked(answer_ids),
            self._masked(sep_ids),
            final_answer_ids.clone()
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }

    # ------------------------------------------------------------------
    # Evaluation formats: input_ids is the context, labels the true answer.
    # ------------------------------------------------------------------

    def clean_filter_sample_nopackpadding_gpt_eval(self, sample):
        """Evaluation counterpart of ``clean_filter_sample_nopackpadding_gpt``.

        The context is question + filler + separator; the labels are the
        encoded last boxed expression.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None``, the
            answer has no boxed expression, or context plus labels exceed
            ``max_tokens``.
        """

        if sample is None:
            return None

        question, answer = self._clean_pair(sample)

        answer_final = last_boxed_only_string(answer)
        if answer_final is None:
            return None

        question_ids = self._encode_long("\nQUESTION:\n" + question)
        sep_ids = self._encode_long("\nFINAL ANSWER:\n")
        label_ids = self._encode_long(answer_final)

        input_ids = torch.cat([
            question_ids,
            self._space_padding(),
            sep_ids,
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] + label_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }

    def clean_filter_sample_peeking_gpt_eval(self, sample):
        """
        Evaluation counterpart of ``clean_filter_sample_peeking_gpt``.

        The context is the question plus a prefix of the solution tokens
        that precede the first boxed expression; the labels are the
        remaining tokens of the fully encoded answer.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None``, there is
            no boxed expression or nothing precedes it, or context plus
            labels exceed ``max_tokens``.
        """

        if sample is None:
            return None

        question, answer = self._clean_pair(sample)

        question_ids = self._encode_long("\nQUESTION:\n" + question + "\nFULL SOLUTION:\n")
        answer_tokens = self.tokenizer.tokenize(answer)
        answer_ids_full = torch.LongTensor(self.tokenizer.encode(answer))
        answer_tokens = only_until_first_boxed_from_tokens(answer, answer_tokens)
        if not answer_tokens:
            # Covers both "no boxed expression" (None) and an empty prefix.
            return None
        answer_ids = self._encode_long(answer_tokens)

        # Take a fraction
        final_idx = self._peek_length(len(answer_ids))
        answer_ids = answer_ids[:final_idx]

        final_answer_ids = answer_ids_full[final_idx:]
        print(final_answer_ids)

        input_ids = torch.cat([
            question_ids,
            answer_ids,
        ], dim=0)

        label_ids = final_answer_ids.clone()

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] + label_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }

    def clean_filter_sample_gpt_eval(self, sample):
        """
        Does tokenization for final model evaluation. This should return
        input_ids as the context and labels as the true answer.

        ``mode_answer`` values ``'eval_peeking'`` and
        ``'eval_nopack_padding'`` delegate to their dedicated methods.
        Otherwise the context is the question followed by the
        ``FULL SOLUTION`` separator (the same separator the ``'full'``
        training format uses) and the labels are the encoded answer.

        Returns:
            The token dict, or ``None`` if ``sample`` is ``None`` or context
            plus labels exceed ``max_tokens``.
        """

        if sample is None:
            return None

        if self.mode_answer == 'eval_peeking':
            return self.clean_filter_sample_peeking_gpt_eval(sample)
        elif self.mode_answer == 'eval_nopack_padding':
            return self.clean_filter_sample_nopackpadding_gpt_eval(sample)

        question, answer = sample
        print("question_from_sample: " + question)
        print("answer from sample: " + answer)

        if self.clean_numbers:
            question = _clean_numbers(question)
            answer = _clean_numbers(answer)
            print("question: " + question)
            print("answer: " + answer)

        assert not answer.isspace()

        question_ids = self._encode_long("\nQUESTION:\n" + question)
        # "\n", not "\F": the separator must match the training prompt.
        sep_ids = self._encode_long("\nFULL SOLUTION:\n")
        label_ids = self._encode_long(answer)  # Loss only counted on these tokens.

        input_ids = torch.cat([
            question_ids,
            sep_ids,
        ], dim=0)

        # Stop early if this Q,A pair is too long
        if input_ids.shape[0] + label_ids.shape[0] > self.max_tokens:
            return None

        return {
            'input_ids_list': input_ids.tolist(),
            'label_ids_list': label_ids.tolist()
        }


# Modeling

## fine-tuning

### import necessary modules

In [13]:
# !pip install transformers
# !pip install torch
# os, pprint and argparse are part of the Python standard library and need no installation.

import os
import pprint
import argparse
import transformers
import torch

from datetime import datetime

### Get the dataset

In [15]:
# File lists (relative to each Mathematica data root), one per problem
# category, in the order their datasets are concatenated.
_MATHEMATICA_FLISTS = (
    # Algebra
    "algebra/flist_find_roots.txt",
    "algebra/flist_invert_function.txt",
    # Calculus
    "calculus/flist_derivatives.txt",
    "calculus/flist_integrals.txt",
    # Geometry
    "geometry/flist_polygons.txt",
    "geometry/flist_triangles.txt",
    # Linear Algebra
    "linear_algebra/flist_determinant.txt",
    "linear_algebra/flist_orthogonolize_vectors.txt",
)


def _count_lines(path):
    """Return the number of lines in the text file at ``path``."""
    with open(path, "r") as f:
        return sum(1 for _ in f)


def get_dataset(args):
    """
    Build the concatenated training set from the Mathematica data roots.

    Each entry of ``args.mathematica_dataroot`` has the form
    ``"<len_multiplier>@<directory>"``. For every category file list under
    ``<directory>`` (see ``_MATHEMATICA_FLISTS``) that contains at least one
    line, one ``MathematicaMathDataset`` is created.

    Args:
        args: Parsed options. Reads ``mathematica_dataroot`` (a list of
            entries, a single entry string, or a falsy value for "none"),
            ``arch`` (selects ``max_tokens``) and whatever
            ``get_tokenizer_gpt`` needs.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over all created datasets.

    Raises:
        ValueError: If an entry has no ``@`` separator or a non-numeric
            multiplier.
        OSError: If one of the file lists cannot be opened.
    """
    tokenizer = get_tokenizer_gpt(args)
    print(tokenizer)

    max_tokens = 384 if args.arch == 'gpt2-xl' else 1024
    train_data = []

    dataroots = args.mathematica_dataroot
    if isinstance(dataroots, str):
        # A bare string would otherwise be iterated character by character.
        dataroots = [dataroots]

    for mathematica_dr in dataroots or []:
        # Split on the first '@' only, so directories may contain '@'.
        len_multiplier, dirname = mathematica_dr.split("@", 1)
        len_multiplier = float(len_multiplier)
        print(len_multiplier)

        flists = [os.path.join(dirname, rel) for rel in _MATHEMATICA_FLISTS]

        # Count every list before building any dataset, so that a missing
        # file list fails before dataset construction starts.
        non_empty_flists = [flist for flist in flists if _count_lines(flist)]

        for flist in non_empty_flists:
            train_data.append(MathematicaMathDataset(
                dataroot=flist,
                tokenizer=tokenizer,
                max_tokens=max_tokens,
                mode='gpt2',
                len_multiplier=len_multiplier
            ))

    # Print the sizes of each dataset, useful for weighting
    for dset in train_data:
        print(f"{dset.__class__.__name__}: __len__ = {len(dset)}")

    return torch.utils.data.ConcatDataset(train_data)


### Get a tokenizer

In [16]:
def get_tokenizer_gpt(args):
    """
    Load the GPT-2 tokenizer for ``args.arch``.

    When ``args.tokenizer_merges_file`` is not None it is forwarded to
    ``from_pretrained`` as ``merges_file``; otherwise the stock tokenizer is
    loaded. In the paper, a custom merges file is used to restrict models to
    ingest and output single digits. For example:

    >>> tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2", merges_file="merges_gpt2_single_digit_numbers.txt")
    >>> tokenizer_old = transformers.GPT2Tokenizer.from_pretrained("gpt2")
    >>> tokenizer.encode("1")
    [16]
    >>> tokenizer_old.encode("1")
    [16]
    >>> tokenizer.encode("2")
    [17]
    >>> tokenizer_old.encode("12")
    [1065]
    >>> tokenizer.encode("12")
    [16, 17]
    >>> tokenizer.encode("HEllo world!")
    [13909, 18798, 995, 0]
    >>> tokenizer_old.encode("HEllo world!")
    [13909, 18798, 995, 0]
    """
    # Only pass merges_file when one was requested, so the default call is
    # exactly from_pretrained(args.arch).
    extra_kwargs = {}
    if args.tokenizer_merges_file is not None:
        extra_kwargs["merges_file"] = args.tokenizer_merges_file
    return transformers.GPT2Tokenizer.from_pretrained(args.arch, **extra_kwargs)


### GPT2 Trainer

In [17]:
class GPT2Trainer(transformers.Trainer):
    """Trainer that installs its own AdamW optimizer and LR schedule."""

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Create the optimizer and LR scheduler if they are not already set.

        The optimizer is AdamW with two parameter groups: parameters whose
        name contains one of the ``no_decay`` patterns get no weight decay,
        all others use ``self.args.weight_decay``. The scheduler is a constant
        LR when ``self.args.warmup_steps == -1``, otherwise linear warmup
        followed by linear decay (see ``get_linear_schedule_with_warmup``).
        """
        if self.optimizer is None:
            print("Making AdamW Optimizer")
            # NOTE(review): GPT-2 layer norms in transformers appear to be
            # named ln_1 / ln_2 / ln_f, so "LayerNorm.weight" may never match
            # and those weights would be decayed -- confirm before changing.
            no_decay = ["bias", "LayerNorm.weight"]
            decayed_params = []
            undecayed_params = []
            for param_name, param in self.model.named_parameters():
                if any(pattern in param_name for pattern in no_decay):
                    undecayed_params.append(param)
                else:
                    decayed_params.append(param)

            self.optimizer = torch.optim.AdamW(
                [
                    {"params": decayed_params, "weight_decay": self.args.weight_decay},
                    {"params": undecayed_params, "weight_decay": 0.0},
                ],
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )

        if self.lr_scheduler is None:
            if self.args.warmup_steps != -1:
                print("Using Linear warmup LR")
                self.lr_scheduler = self.get_linear_schedule_with_warmup(
                    self.optimizer,
                    num_warmup_steps=self.args.warmup_steps,
                    num_training_steps=num_training_steps,
                )
            else:
                # warmup_steps == -1 selects a constant learning rate.
                print("Using constant LR")
                self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda steps: 1.0)

    @staticmethod
    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        """
        Linear warmup from 0 to max lr, then linear decay from max_lr to 0.1*max_lr
        As done in https://arxiv.org/pdf/2010.14701.pdf

        Returns a ``LambdaLR`` scheduler wrapping ``optimizer``.
        """
        min_lr_multiplier = 0.1

        def lr_lambda(current_step: int):
            # Warmup phase: ramp the multiplier linearly from 0 towards 1.
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))

            # Decay phase: go linearly from 1 down to min_lr_multiplier and
            # never drop below that floor.
            steps_left = float(num_training_steps - current_step)
            decay_span = float(max(1, num_training_steps - num_warmup_steps))
            decayed = ((1 - min_lr_multiplier) * steps_left / decay_span) + min_lr_multiplier
            return max(min_lr_multiplier, decayed)

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)

### Run the training

In [18]:
def run_training(args, train_data):
    """
    Fine-tune a GPT-2 LM head model on ``train_data`` and save it.

    Args:
        args: Parsed options. Reads ``save_steps``, ``tpu_num_cores``,
            ``grad_acc_steps``, ``batch_size_per_replica``, ``load``,
            ``arch``, ``save_dir``, ``epochs``, ``lr``, ``weight_decay``,
            ``lr_warmup_steps``, ``log_freq``, ``dataloader_num_workers``
            and ``local_rank``.
        train_data: Training dataset; must support ``len()``. Its
            ``start_iteration`` attribute is set to 0.

    The final model is written to ``<args.save_dir>/final_checkpoint``.
    """
    if args.save_steps:
        save_steps = args.save_steps
    else:
        # No explicit value: derive a once-per-epoch interval from the
        # dataset size, the accumulation steps and the per-replica batch size.
        save_steps = len(train_data)
        if args.tpu_num_cores:
            save_steps = int(save_steps / 8)  # 8 TPU cores is constant for now.
        save_steps = int(save_steps / args.grad_acc_steps)
        save_steps = int(save_steps / args.batch_size_per_replica)

    print("Save Steps = ", save_steps)

    ## Checkpoint Loading ########################################################
    # Resume from args.load when given, otherwise start from the pretrained arch.
    checkpoint = args.load if args.load else args.arch
    model = transformers.GPT2LMHeadModel.from_pretrained(checkpoint)
    if args.load:
        print(f"Loaded model from {args.load}")

    ## Dataloading ########################################################
    train_data.start_iteration = 0

    ## Start Loop ########################################################
    print("Setting up Trainer")

    trainer_options = dict(
        output_dir=args.save_dir,
        overwrite_output_dir=False,

        do_train=True,
        do_eval=False,
        do_predict=True,
        evaluation_strategy='no',
        eval_steps=0,

        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size_per_replica,
        gradient_accumulation_steps=args.grad_acc_steps,

        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_steps=args.lr_warmup_steps,
        max_grad_norm=100000.0,  # Essentially disable gradient clipping

        logging_dir=args.save_dir,
        logging_first_step=True,
        logging_steps=args.log_freq,
        save_steps=save_steps,
        save_total_limit=10,

        dataloader_drop_last=True,
        dataloader_num_workers=args.dataloader_num_workers,

        local_rank=args.local_rank,
        tpu_num_cores=args.tpu_num_cores,
    )
    training_args = transformers.TrainingArguments(**trainer_options)

    trainer = GPT2Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
    )
    trainer.remove_callback(transformers.integrations.TensorBoardCallback)

    print(f"STARTING TRAINING. save_steps={save_steps}")
    trainer.train()

    final_dir = os.path.join(args.save_dir, "final_checkpoint")
    trainer.save_model(final_dir)
    print("Finished")


### Specify additional args and start the training

In [20]:
def main(argv=None):
    """
    Parse the options, build the training set and run fine-tuning.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            previous behavior. Passing an explicit list (e.g. ``[]``) lets
            the function be called where ``sys.argv`` holds arguments meant
            for another program.

    Side effects: creates a timestamped directory under ``--save-dir`` and
    writes the parsed options to ``command.txt`` inside it before training.
    """
    ######### Arg parsing ###############################################################

    # Applied after parsing when --mathematica-dataroot is never given.
    default_mathematica_dataroot = '1@/content/sample_data/train_data/'

    parser = argparse.ArgumentParser(description="Language Modelling on Code")
    parser.add_argument('--arch', default='gpt2', choices=transformers.GPT2_PRETRAINED_MODEL_ARCHIVE_LIST)
    parser.add_argument('--tokenizer-merges-file', default=None, type=str)
    parser.add_argument('--load', default=None, type=str)

    # Dataloading
    parser.add_argument('--khan-mode', default='mixed_hints', type=str)
    parser.add_argument('--khan-dataroot', default=None, type=str)
    parser.add_argument('--khan-latex-mask', default=False, action='store_true')
    parser.add_argument('--deepmind-dataroot', default=None, type=str, action='append')
    # With action='append' the default must not be a string: an untouched
    # string default would later be iterated character by character, and
    # argparse cannot append a user-supplied value to a str. Keep None here
    # and fill in the default list after parsing.
    parser.add_argument('--mathematica-dataroot', default=None, type=str, action='append')
    parser.add_argument('--mathematica-with-steps-dataroot', default=None, type=str, action='append')
    parser.add_argument('--MATH-mode', default='mixed_final_boxed_and_full', type=str,
                        choices=['mixed_final_boxed_and_full', 'final_boxed', 'peeking', 'nopack_padding',
                                 'mixed_full_and_peeking', 'mixed_full_and_nopack_padding'])
    parser.add_argument('--MATH-peek-min', default=0.1, type=float)
    parser.add_argument('--MATH-peek-max', default=1.0, type=float)
    parser.add_argument('--MATH-dataroot', default=None, type=str)
    parser.add_argument('--stackexchange-dataroot', default=None, type=str)
    parser.add_argument('--dataloader-num-workers', default=1, type=int)

    # Training
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--lr', default=5e-5, type=float)
    parser.add_argument('--weight-decay', default=0.05, type=float)
    parser.add_argument('--lr-warmup-steps', default=1, type=int)
    parser.add_argument('--batch-size-per-replica', default=8, type=int)
    parser.add_argument('--grad-acc-steps', default=4, type=int)
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--tpu_num_cores', default=None, type=int)

    # Logging and stuff
    parser.add_argument('--save-dir', default="checkpoints/TEMP", type=str)
    parser.add_argument('--save-steps', default=0, type=int)
    parser.add_argument('--log-freq', default=5, type=int)

    args = parser.parse_args(argv)
    if args.mathematica_dataroot is None:
        args.mathematica_dataroot = [default_mathematica_dataroot]
    args.save_dir = os.path.join(args.save_dir, datetime.now().strftime("%m-%d-%Y__%H:%M:%S"))

    ######### Start training ###############################################################

    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    train_data = get_dataset(args)

    # Record the exact configuration next to the checkpoints.
    os.makedirs(args.save_dir, exist_ok=True)
    with open(os.path.join(args.save_dir, "command.txt"), 'w') as f:
        f.write(pprint.pformat(argsdict))

    run_training(args, train_data)


def _mp_fn(index):
    """Entry point for xla_spawn (TPUs); ``index`` is accepted but not used."""
    # For xla_spawn (TPUs)
    main()


# Run the whole pipeline (argument parsing, dataset construction, training)
# when this code is executed as the main module.
if __name__ == "__main__":
    main()

usage: colab_kernel_launcher.py [-h] [--arch {gpt2,gpt2-medium,gpt2-large,gpt2-xl,distilgpt2}]
                                [--tokenizer-merges-file TOKENIZER_MERGES_FILE] [--load LOAD]
                                [--khan-mode KHAN_MODE] [--khan-dataroot KHAN_DATAROOT]
                                [--khan-latex-mask] [--deepmind-dataroot DEEPMIND_DATAROOT]
                                [--mathematica-dataroot MATHEMATICA_DATAROOT]
                                [--mathematica-with-steps-dataroot MATHEMATICA_WITH_STEPS_DATAROOT]
                                [--MATH-mode {mixed_final_boxed_and_full,final_boxed,peeking,nopack_padding,mixed_full_and_peeking,mixed_full_and_nopack_padding}]
                                [--MATH-peek-min MATH_PEEK_MIN] [--MATH-peek-max MATH_PEEK_MAX]
                                [--MATH-dataroot MATH_DATAROOT]
                                [--stackexchange-dataroot STACKEXCHANGE_DATAROOT]
                                [--dataloader-nu

SystemExit: ignored