# Process Data

## Utilities

In [27]:
import torch
import numpy as np
from tqdm import tqdm

In [28]:
def tokenize_q(text):
    l = len(text)
    return [char_to_idx[c] for c in text] + [eos_token] + [empty_token] * (max_q_len - l - 1)

def tokenize_a(text):
    l = len(text)
    token_a = [char_to_idx[c] for c in text]
    token_a = [start_token] + token_a + [eos_token] + [empty_token] * (max_a_len - l - 1)
    token_target = token_a[:-1]
    token_label = token_a[1:]
    return token_target, token_label

def invert_tokenization(idx):
    return [idx_to_char[i] for i in idx]

In [29]:
def load_qa(filename):
    with open(f'{data_path}/{tr_l}/{task}.txt') as f:
        text = f.read().splitlines()
        quess = text[::2]
        anss = text[1::2]

    return quess, anss

def tokenize_qa(quess, anss):
    tokenized_source = [tokenize_q(text) for text in tqdm(quess)]

    tokenized_target, tokenized_label = [], []
    for text in tqdm(anss):
        tt, tl = tokenize_a(text)
        tokenized_target.append(tt)
        tokenized_label.append(tl)

    tokenized_source = torch.tensor(tokenized_source)
    tokenized_target = torch.tensor(tokenized_target)
    tokenized_label = torch.tensor(tokenized_label)

    return tokenized_source, tokenized_target, tokenized_label

In [30]:
def create_ds(fname):
    quess, anss = load_qa(fname)

    tokenized_source, tokenized_target, tokenized_label = tokenize_qa(quess, anss)

    return torch.utils.data.TensorDataset(tokenized_source, tokenized_target, tokenized_label)

In [31]:
data_path = '../../data/math/mathematics_dataset-v1.0'
out_path = 'tokenized_data'

In [32]:
with open('text_vectorizer/vocabulary.txt') as f:
    vocab = f.read().splitlines()

idx_to_char = {i: c for i, c in enumerate(vocab)}
char_to_idx = {c: i for i, c in enumerate(vocab)}

empty_token = char_to_idx['']
eos_token = char_to_idx[';']
start_token = char_to_idx['@']

max_q_len, max_a_len = 161, 31

## Process Task data

In [33]:
tasks =  ['algebra__linear_1d', 'polynomials__add', 'polynomials__expand', 'calculus__differentiate', 'algebra__sequence_next_term']

# task = 'polynomials__add'

In [34]:
train_levels = ['train-easy', 'train-medium', 'train-hard']

### Train ds

In [35]:
def tokenize_task(task):
    tss, tts, tls = [], [], []

    for tr_l in tqdm(train_levels):
        fname = f'{data_path}/{tr_l}/{task}.txt'

        quess, anss = load_qa(fname)

        tokenized_source, tokenized_target, tokenized_label = tokenize_qa(quess, anss)
        tss.append(tokenized_source)
        tts.append(tokenized_target)
        tls.append(tokenized_label)

        del quess, anss

    ts = torch.concat(tss)
    tt = torch.concat(tts)
    tl = torch.concat(tls)

    del tss, tts, tls

    train_ds = torch.utils.data.TensorDataset(ts, tt, tl)

    torch.save(train_ds, f'{out_path}/{task}_train.pt')

    interpolate_ds = create_ds(f'{data_path}/interpolate/{task}.txt')
    torch.save(interpolate_ds, f'{out_path}/{task}_interpolate.pt')
    interpolate_ds = create_ds(f'{data_path}/extrapolate/{task}.txt')
    torch.save(interpolate_ds, f'{out_path}/{task}_extrapolate.pt')

In [36]:
for task in tasks:
    print(task)
    tokenize_task(task)

algebra__linear_1d


100%|██████████| 666666/666666 [00:02<00:00, 239459.47it/s]
100%|██████████| 666666/666666 [00:03<00:00, 191717.96it/s]
100%|██████████| 666666/666666 [00:03<00:00, 181588.21it/s]
100%|██████████| 666666/666666 [00:03<00:00, 188481.38it/s]
100%|██████████| 666666/666666 [00:02<00:00, 247201.97it/s]
100%|██████████| 666666/666666 [00:03<00:00, 195482.21it/s]
100%|██████████| 3/3 [01:00<00:00, 20.03s/it]
100%|██████████| 666666/666666 [00:03<00:00, 176004.45it/s]
100%|██████████| 666666/666666 [00:03<00:00, 187816.22it/s]
100%|██████████| 666666/666666 [00:02<00:00, 235482.89it/s]
100%|██████████| 666666/666666 [00:03<00:00, 193230.30it/s]


polynomials__add


100%|██████████| 666666/666666 [00:05<00:00, 131617.48it/s]
100%|██████████| 666666/666666 [00:03<00:00, 175334.62it/s]
100%|██████████| 666666/666666 [00:04<00:00, 145356.81it/s]
100%|██████████| 666666/666666 [00:03<00:00, 178173.26it/s]
100%|██████████| 666666/666666 [00:04<00:00, 150836.72it/s]
100%|██████████| 666666/666666 [00:03<00:00, 179684.41it/s]
100%|██████████| 3/3 [01:05<00:00, 21.98s/it]
100%|██████████| 666666/666666 [00:04<00:00, 142044.04it/s]
100%|██████████| 666666/666666 [00:03<00:00, 182500.13it/s]
100%|██████████| 666666/666666 [00:04<00:00, 135371.59it/s]
100%|██████████| 666666/666666 [00:03<00:00, 179334.65it/s]


polynomials__expand


100%|██████████| 666666/666666 [00:04<00:00, 165741.30it/s]
100%|██████████| 666666/666666 [00:03<00:00, 216347.83it/s]
100%|██████████| 666666/666666 [00:04<00:00, 153999.34it/s]
100%|██████████| 666666/666666 [00:03<00:00, 180422.63it/s]
100%|██████████| 666666/666666 [00:04<00:00, 145408.23it/s]
100%|██████████| 666666/666666 [00:03<00:00, 174631.02it/s]
100%|██████████| 3/3 [01:04<00:00, 21.39s/it]
100%|██████████| 666666/666666 [00:03<00:00, 166855.44it/s]
100%|██████████| 666666/666666 [00:03<00:00, 216377.50it/s]
100%|██████████| 666666/666666 [00:04<00:00, 156434.31it/s]
100%|██████████| 666666/666666 [00:03<00:00, 187101.77it/s]


calculus__differentiate


100%|██████████| 666666/666666 [00:04<00:00, 148411.50it/s]
100%|██████████| 666666/666666 [00:03<00:00, 170079.46it/s]
100%|██████████| 666666/666666 [00:03<00:00, 197250.91it/s]
100%|██████████| 666666/666666 [00:03<00:00, 177817.90it/s]
100%|██████████| 666666/666666 [00:04<00:00, 152056.37it/s]
100%|██████████| 666666/666666 [00:03<00:00, 172489.45it/s]
100%|██████████| 3/3 [01:04<00:00, 21.53s/it]
100%|██████████| 666666/666666 [00:03<00:00, 173690.80it/s]
100%|██████████| 666666/666666 [00:03<00:00, 214842.55it/s]
100%|██████████| 666666/666666 [00:04<00:00, 163104.72it/s]
100%|██████████| 666666/666666 [00:03<00:00, 183468.81it/s]


algebra__sequence_next_term


100%|██████████| 666666/666666 [00:04<00:00, 155121.86it/s]
100%|██████████| 666666/666666 [00:03<00:00, 183576.02it/s]
100%|██████████| 666666/666666 [00:03<00:00, 214646.74it/s]
100%|██████████| 666666/666666 [00:03<00:00, 191979.51it/s]
100%|██████████| 666666/666666 [00:04<00:00, 161735.94it/s]
100%|██████████| 666666/666666 [00:03<00:00, 184931.65it/s]
100%|██████████| 3/3 [01:03<00:00, 21.10s/it]
100%|██████████| 666666/666666 [00:03<00:00, 184356.03it/s]
100%|██████████| 666666/666666 [00:03<00:00, 185265.19it/s]
100%|██████████| 666666/666666 [00:03<00:00, 193069.39it/s]
100%|██████████| 666666/666666 [00:03<00:00, 187188.25it/s]
