In [None]:
from google.colab import files
files.download('/content/first_results/epoch=11-val_exact_match=1.0000.ckpt')

In [None]:
!pip install pytorch_lightning==1.9.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning==1.9.0
  Downloading pytorch_lightning-1.9.0-py3-none-any.whl (825 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.8/825.8 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.4.2
  Downloading lightning_utilities-0.8.0-py3-none-any.whl (20 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1
  Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2

In [None]:
!pip install transformers==4.28.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.28.1
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m96.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m107.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1


In [35]:
# Train and evalute T5 on arithmetic problems.
SEED = 1

# 1/12 - 12 = TRAIN_SIZE / TRAIN_BATCh + valid_size / valid_batch

NUM_WORKERS = 4
LR = 3e-4
WEIGHT_DECAY = 5e-5
GAMMA = 1.0  # 0.1

STEP_SIZE = 1000
OPTIMIZER_NAME = 'AdamW'

MODEL_NAME = 't5-base'  # t5-small, t5-base
MIN_DIGITS_TRAIN = 2
MAX_DIGITS_TRAIN = 15
MIN_DIGITS_TEST = 2
MAX_DIGITS_TEST = 15
OUTPUT_DIR = 'first_results'

TRAIN_SIZE = 100000
TRAIN_BATCH_SIZE = 4
VAL_SIZE = 10000
VAL_BATCH_SIZE = 32
TEST_SIZE = 10000

MAX_SEQ_LEN = 512
CHECK_VAL_EVERY_N_EPOCH = 2
MAX_EPOCHS = 20
GPUS = 1

In [53]:
import argparse
import glob
import json
import os
import pytorch_lightning as pl
import random
import torch

from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from typing import List


def compute_exact_match(predicted_answer, correct_answer) -> bool:
    predicted_answer = predicted_answer.strip().lower()
    correct_answer = correct_answer.strip().lower()
    return predicted_answer == correct_answer


def convert_to_10ebased(number: str) -> str:
    signal = None
    if number[0] == '-':
        signal = '-'
        number = number[1:]

    output = []
    for i, digit in enumerate(number[::-1]):
        output.append('10e' + str(i))
        output.append(digit)

    if signal:
        output.append(signal)

    # as we want it to _not_ be inverted, then we invert it.
    output = output[::-1]

    return ' '.join(output)



def translate_task(a_int: int, b_int: int):
    result_int = a_int + b_int

    a_str = convert_to_10ebased(str(a_int))
    b_str = convert_to_10ebased(str(b_int))
    result_str = convert_to_10ebased(str(result_int))

    question = f'What is {a_str} plus {b_str}?'
    return {
        'a_int': a_int,
        'b_int': b_int,
        'expected_result_int': result_int,

        'a_str': a_str,
        'b_str': b_str,
        'expected_result_str': result_str,

        'question': question,
    }

class T5Finetuner(pl.LightningModule):
    def __init__(self, train_dataloader, val_dataloader, test_dataloader):
        super(T5Finetuner, self).__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

        self._train_dataloader = train_dataloader
        self._val_dataloader = val_dataloader
        self._test_dataloader = test_dataloader

    def prepare_batch(self, questions: List[str], answers: List[str]) -> List[str]:

        input_dict = self.tokenizer.batch_encode_plus(
            list(questions), padding=True, truncation=False, return_tensors='pt')

        labels = self.tokenizer.batch_encode_plus(
            list(answers), padding=True, truncation=False, return_tensors='pt')['input_ids']

        assert input_dict['input_ids'].shape[1] < MAX_SEQ_LEN
        assert labels.shape[1] < MAX_SEQ_LEN

        input_ids = input_dict['input_ids'].to(self.model.device)
        attention_mask = input_dict['attention_mask'].to(self.model.device)
        labels = labels.to(self.model.device)

        return input_ids, attention_mask, labels

    def forward(self, **kwargs):
        return self.model(**kwargs)

    def training_step(self, batch, batch_nb):
        questions, correct_answers = batch

        # Log every power of two.
        if batch_nb & (batch_nb - 1) == 0:
            print(questions[0])
            print(correct_answers[0])

        input_ids, attention_mask, labels = self.prepare_batch(questions=questions, answers=correct_answers)

        loss = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)[0]

        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}
      
    def predict(self, question):
        input_ids, attention_mask, _ = self.prepare_batch(questions=[question], answers=[question])
        batch_outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=False,
                                            max_length=MAX_SEQ_LEN)

        predicted_answers = [
            self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for output in batch_outputs]
          
        return predicted_answers[0]


    def inference_step(self, batch, batch_nb: int):
        questions, correct_answers = batch

        input_ids, attention_mask, _ = self.prepare_batch(questions=questions, answers=correct_answers)
        batch_outputs = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, do_sample=False,
                                            max_length=MAX_SEQ_LEN)

        predicted_answers = [
            self.tokenizer.decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for output in batch_outputs]

        exact_matches = [
            compute_exact_match(predicted_answer=predicted_answer, correct_answer=correct_answer)
            for predicted_answer, correct_answer in zip(predicted_answers, correct_answers)]

        # Log every power of two.
        if batch_nb & (batch_nb - 1) == 0:
            print('\nQuestion:', questions[0])
            print('Correct:  ', correct_answers[0])
            print('Predicted:', predicted_answers[0].encode('utf-8'))
            print('Exact?', exact_matches[0])

        metrics = {'exact_matches': exact_matches}
        return metrics

    def validation_step(self, batch, batch_nb):
        return self.inference_step(batch, batch_nb)

    def test_step(self, batch, batch_nb):
        return self.inference_step(batch, batch_nb)

    def validation_epoch_end(self, outputs):
        print('QQ: in validation_epoch_end')
        exact_matches = []
        for x in outputs:
            exact_matches.extend(x['exact_matches'])
        exact_match = sum(exact_matches) / len(exact_matches)

        metrics = {'val_exact_match': exact_match}

        output = metrics.copy()
        output['progress_bar'] = metrics

        # added
        self.log('val_exact_match', exact_match, prog_bar=True)

        return output

    def test_epoch_end(self, outputs):
        exact_matches = []
        for x in outputs:
            exact_matches.extend(x['exact_matches'])
        exact_match = sum(exact_matches) / len(exact_matches)

        metrics = {'test_exact_match': exact_match}
        print('test_exact_match', exact_match)

        output = metrics.copy()
        output['progress_bar'] = metrics
        self.log('test_exact_match', exact_match, prog_bar=True)

        return output

    def train_dataloader(self):
        return self._train_dataloader

    def val_dataloader(self):
        return self._val_dataloader

    def test_dataloader(self):
        return self._test_dataloader

    def get_optimizer(self):
        optimizer = getattr(torch.optim, OPTIMIZER_NAME)

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": WEIGHT_DECAY,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0
            },
        ]
        optimizer = optimizer(optimizer_grouped_parameters, lr=LR, weight_decay=WEIGHT_DECAY)

        print(f'=> Using {OPTIMIZER_NAME} optimizer')

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE,
                                                    gamma=GAMMA)
        print(f'=> Using StepLR (step_size = {STEP_SIZE}, gamma = {GAMMA}) scheduler')

        return [optimizer], [scheduler]

    def configure_optimizers(self):
        optimizer = self.get_optimizer()
        return optimizer


class MyDataset(Dataset):
    def __init__(self, n_examples: int, min_digits: int, max_digits: int):

        self.max_digits = max_digits

        # if balance:
        self.examples = []
        for _ in range(n_examples):
            example = []
            for _ in range(2):
                max_digits_i = random.randint(min_digits, max_digits)
                min_number = int((max_digits_i - 1) * '9') + 1
                max_number = int(max_digits_i * '9')
                example.append(random.randint(min_number, max_number))
            self.examples.append(example)

    def __len__(self):
        return len(self.examples)


    def __getitem__(self, idx):
        first_term, second_term = self.examples[idx]
        translated_task = translate_task(first_term, second_term)
        

        return translated_task['question'], translated_task['expected_result_str']


def train():

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    random.seed(SEED)
    pl.seed_everything(SEED)

    dataset_train = MyDataset(n_examples=TRAIN_SIZE, min_digits=MIN_DIGITS_TRAIN, max_digits=MAX_DIGITS_TRAIN)
    dataset_val = MyDataset(n_examples=VAL_SIZE, min_digits=MIN_DIGITS_TRAIN, max_digits=MAX_DIGITS_TRAIN)
    dataset_test = MyDataset(n_examples=TEST_SIZE, min_digits=MIN_DIGITS_TEST, max_digits=MAX_DIGITS_TEST)

    train_dataloader = DataLoader(dataset_train, batch_size=TRAIN_BATCH_SIZE,
                                  shuffle=True, num_workers=NUM_WORKERS)
    val_dataloader = DataLoader(dataset_val, batch_size=VAL_BATCH_SIZE,
                                shuffle=False, num_workers=NUM_WORKERS)
    test_dataloader = DataLoader(dataset_test, batch_size=VAL_BATCH_SIZE,
                                 shuffle=False, num_workers=NUM_WORKERS)

    checkpoint_callback = ModelCheckpoint(
        dirpath=OUTPUT_DIR, filename='{epoch}-{val_exact_match:.4f}',
        verbose=False, save_last=True, save_top_k=2, mode='max', monitor='val_exact_match',
        save_weights_only=False, every_n_epochs=CHECK_VAL_EVERY_N_EPOCH,
        # save_on_train_epoch_end=True
    )

    trainer = pl.Trainer(
        precision=32,
        callbacks=[checkpoint_callback],
        max_epochs=MAX_EPOCHS,
        check_val_every_n_epoch=CHECK_VAL_EVERY_N_EPOCH,
        accumulate_grad_batches=32,
        gradient_clip_val=1.0,
        amp_level='O0',
        amp_backend='apex',
        gpus=GPUS)

    model = T5Finetuner(train_dataloader=train_dataloader,
                        val_dataloader=val_dataloader,
                        test_dataloader=test_dataloader)

    trainer.fit(model)
  
# train()

# checkpoint_path = glob.glob(os.path.join(OUTPUT_DIR, '*.ckpt'))[0]
# checkpoint_path = '/content/first_results/epoch=11-val_exact_match=1.0000.ckpt'
# model = T5Finetuner.load_from_checkpoint(checkpoint_path,
#                                          train_dataloader=train_dataloader,
#                                          val_dataloader=val_dataloader,
#                                          test_dataloader=test_dataloader)

# results = trainer.test(model)

# output = {'seed': SEED,
#           'max_digits_train': MAX_DIGITS_TRAIN,
#           'max_digits_test': MAX_DIGITS_TEST,
#           'test_exact_match_': results[0]}

# with open(os.path.join(OUTPUT_DIR, 'results.json'), 'w') as fout:
#     json.dump(output, fout)

# print('Done!')


In [55]:
from typing import Tuple

class Solver2():
    def __init__(self):
        checkpoint_path = '/content/first_results/epoch=11-val_exact_match=1.0000.ckpt'

        self.model = T5Finetuner.load_from_checkpoint(checkpoint_path, 
                                                      train_dataloader=None, 
                                                      val_dataloader=None, 
                                                      test_dataloader=None)
        

    def calc_sum(self, a: int, b: int) -> Tuple[int, dict]:
        translated_task = translate_task(a, b)
        model_answer = self.model.predict(translated_task['question'])
        return model_answer, translated_task

        # real_model_answer = full_model_answer[len(question):]
        # code_model_answer = real_model_answer.split(QUESTION_START_MARKER)[0]

        # extra_tab_len = 4
        # code_model_answer = code_model_answer.split('\n')
        # for i in range(len(code_model_answer)):
        #     if len(code_model_answer[i]) > extra_tab_len:
        #         code_model_answer[i] = code_model_answer[i][extra_tab_len:]

        # # do not change order
        # answer_int = self.execute(code_model_answer)
        # code_model_answer = '\n'.join(code_model_answer)

        # meta_info = {
        #     'question': question,
        #     'full_model_answer': full_model_answer,
        #     'code_model_answer': code_model_answer,
        #     'answer_int': answer_int
        # }

        # return answer_int, meta_info


if __name__ == '__main__':
    solver = Solver2()
    a = 2
    b = 3
    expected = a + b
    answer_int, meta_info = solver.calc_sum(a, b)

    print(meta_info)
    print(answer_int)
    print(expected)


For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


{'a_int': 2, 'b_int': 3, 'expected_result_int': 5, 'a_str': '2 10e0', 'b_str': '3 10e0', 'expected_result_str': '5 10e0', 'question': 'What is 2 10e0 plus 3 10e0?'}
5 10e0
5
