# **Step 1: This notebook finetunes a model to generate decompositions**

In [None]:
# set up
# Colab bootstrap: mount Google Drive, fetch the project repository, and pick
# the compute device that the rest of the notebook refers to as DEVICE.
import torch

# Mount Drive at /content/drive. Later cells write results under ./drive/MyDrive
# (assumes the notebook's working directory is /content — TODO confirm).
from google.colab import drive
drive.mount('/content/drive')

# Clone the project repo into ./strategyqa_v2; it supplies the data files,
# the dataset helpers and the evaluation script used by later cells.
!git clone https://github.com/anayap0/strategyqa_v2.git

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)

Mounted at /content/drive
Cloning into 'strategyqa_v2'...
remote: Enumerating objects: 159, done.[K
remote: Counting objects: 100% (159/159), done.[K
remote: Compressing objects: 100% (119/119), done.[K
remote: Total 159 (delta 82), reused 110 (delta 38), pack-reused 0[K
Receiving objects: 100% (159/159), 33.59 MiB | 20.04 MiB/s, done.
Resolving deltas: 100% (82/82), done.
Device: cuda


In [None]:
# relevant imports
from strategyqa_v2.src.SQP1Dataset import initialize_datasets, SQP1Dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Optimizer, AdamW
from tqdm.notebook import tqdm

In [None]:
# Load tokenizer and model
# Two tokenizers are created, both with model_max_length=512: the stock
# t5-base tokenizer and the one published with the Break-finetuned T5 checkpoint.
tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
break_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-break_data", model_max_length=512)
# NOTE(review): the model load below is commented out, yet later cells use a
# global `model` — confirm where that model is meant to be created.
# model = T5ForConditionalGeneration.from_pretrained('t5-base').to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The `xla

In [None]:
# Build the train/dev datasets once per tokenizer from the same two JSON files.
_DATA_FILES = ('strategyqa_v2/data/train.json', 'strategyqa_v2/data/dev.json')
datasets = initialize_datasets(*_DATA_FILES, tokenizer)
breaked_datasets = initialize_datasets(*_DATA_FILES, break_tokenizer)

# Options shared by both loaders built on the Break-tokenized datasets;
# only the split and the shuffle flag differ between them.
_loader_options = dict(batch_size=32, collate_fn=SQP1Dataset.collate_fn)

# Training batches are reshuffled every epoch.
breaked_train_dataloader = DataLoader(breaked_datasets['train'], shuffle=True, **_loader_options)

# Validation batches keep the dataset order.
breaked_validation_dataloader = DataLoader(breaked_datasets['dev'], shuffle=False, **_loader_options)


In [None]:
# Peek at the tokenized targets of the first batch from the training loader.
# NOTE(review): `train_dataloader` is not defined in the cells shown above (only
# `breaked_train_dataloader` is) — confirm which loader is intended here.
next(iter(train_dataloader))["target_ids"]

{'input_ids': tensor([[ 2645,    65, 18063,  ...,     0,     0,     0],
        [  363,    47, 13346,  ...,     0,     0,     0],
        [  363,   349,    19,  ...,     0,     0,     0],
        ...,
        [  571,   186,  1688,  ...,     0,     0,     0],
        [ 2840,    19,     8,  ...,     0,     0,     0],
        [  366,   410,  9066,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

In [None]:
def train_one_epoch(model: nn.Module, train_dataloader: DataLoader, optimizer: Optimizer, epoch: int):
    """Run one optimisation pass over ``train_dataloader``.

    Each batch is a mapping whose ``'input_ids'`` and ``'target_ids'`` entries
    are tokenizer outputs exposing ``.input_ids`` (and, when the tokenizer
    produced one, ``.attention_mask``).

    Args:
        model: seq2seq model exposing ``.device`` and returning an object with
            a ``.loss`` attribute when called with ``labels``.
        train_dataloader: iterable of batches (must support ``len``).
        optimizer: optimizer already bound to ``model``'s parameters.
        epoch: epoch number, used only for the progress-bar label.

    Returns:
        The mean training loss over the epoch as a float, or ``None`` when the
        dataloader yielded no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    with tqdm(train_dataloader, desc=f"Train Ep {epoch}", total=len(train_dataloader)) as tq:
        for batch in tq:
            source = batch['input_ids']
            inputs = source.input_ids.to(model.device)
            # Forward the padding mask so the encoder does not attend to pad
            # tokens; without it every padded position is treated as real input.
            attention_mask = getattr(source, 'attention_mask', None)
            if attention_mask is not None:
                attention_mask = attention_mask.to(model.device)
            outputs = batch['target_ids'].input_ids.to(model.device)
            # NOTE(review): padded label positions are not excluded from the loss
            # here — confirm whether the collate function is expected to do that.

            loss = model(input_ids=inputs, attention_mask=attention_mask, labels=outputs).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track and display the loss so training progress is visible.
            running_loss += loss.item()
            num_batches += 1
            tq.set_postfix(loss=running_loss / num_batches)

    return running_loss / num_batches if num_batches else None


In [None]:
# finetune model
# Trains for 10 epochs (numbered 1..10) with AdamW at lr=5e-5.
# NOTE(review): `model` and `train_dataloader` are not assigned in the cells
# shown above (the model load is commented out and only the `breaked_*`
# loaders are built) — confirm where they come from before a fresh run.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for i in range(1, 11):
  train_one_epoch(model, train_dataloader, optimizer, i)

Train Ep 1:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 2:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 3:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 5:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 6:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 7:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 8:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 9:   0%|          | 0/65 [00:00<?, ?it/s]

Train Ep 10:   0%|          | 0/65 [00:00<?, ?it/s]

In [None]:
def generatePredictions(model: nn.Module, tokenizer, data):
  """Generate a decomposition for every question in ``data``, one at a time.

  Args:
    model: seq2seq model exposing ``eval()`` and ``generate()``.
    tokenizer: tokenizer used both to encode questions and decode outputs.
    data: iterable of dicts with at least the keys ``'qid'`` and ``'question'``.

  Returns:
    Dict mapping each qid to ``{"decomposition": [...]}``, where the list is
    the decoded output split on the literal marker ``"SEP>"``.
  """
  qids_with_decomps = {}
  # Switch to inference mode once, not on every iteration.
  model.eval()
  with torch.no_grad():
    for d in data:
      input_ids = tokenizer(d['question'], return_tensors="pt", padding=True, truncation=True).input_ids.to(DEVICE)
      predictions = model.generate(input_ids=input_ids, max_length=512)

      # A single question yields a single sequence; the last one is kept, which
      # matches what the previous overwrite-in-a-loop implementation produced.
      if len(predictions):
        decoded = tokenizer.decode(predictions[-1], skip_special_tokens=True)
        generated = decoded.split("SEP>")
      else:
        # Keep the value type consistent (a list) even when nothing was generated.
        generated = []

      qids_with_decomps[d['qid']] = {"decomposition": generated}

  return qids_with_decomps

In [None]:
# generate predictions for test data
# json is needed by this cell; import it here so the cell works on a fresh kernel.
import json

# Read the test questions; the context manager closes the file after loading.
with open("./strategyqa_v2/data/strategyqa_test.json", encoding="utf8") as f:
  test_data = json.load(f)

output = generatePredictions(model, tokenizer, test_data)

# Persist the generated decompositions to Drive.
with open("./drive/MyDrive/UW/CSE 447/Final Project/NLP/test_decomps.json", "w") as f:
  json.dump(output, f, indent=4)

In [None]:
# evaluate - need to increase SARI as much as we can
# data must be written to file above to evaluate
# NOTE: must delete runtime and run again
# NOTE(review): the predictions file read here (data/decomps_to_evaluate.json) is
# not the path the previous cell writes to, and the gold file is dev.json while
# the previous cell predicts on the test split — confirm how this file is produced.
!python ./strategyqa_v2/src/evaluators/evaluate_all.py --golds_file ./strategyqa_v2/data/dev.json --predictions_file ./strategyqa_v2/data/decomps_to_evaluate.json

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./strategyqa_v2/data/decomps_to_evaluate.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4381665289373771,
    "Recall@10": 0.0
}


# EXPERIMENTS

## Hyperparameters

In [None]:
import json
import locale
# Override locale.getpreferredencoding so it always reports UTF-8 for the rest
# of the session.
# NOTE(review): presumably a workaround for Colab shell escapes (`!...`) failing
# under a non-UTF-8 locale — confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"
from transformers import AutoTokenizer, BartForConditionalGeneration

In [None]:
# BART failed to produce model with high SARI
# Tokenizer for the facebook/bart-base checkpoint used by the BART experiments.
bart_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
# break_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-break_data")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [None]:
def fresh_model_optimizer(lr=1e-5, func=None):
  """Return a freshly initialised ``(model, optimizer)`` pair.

  Args:
    lr: learning rate for the optimizer.
    func: optional factory taking ``lr`` and returning a model/optimizer pair;
      when omitted, a pretrained ``t5-base`` model on DEVICE with AdamW is built.

  Returns:
    Tuple ``(model, optimizer)``.
  """
  if func is None:
    # Default path: stock T5-base with an AdamW optimizer over all parameters.
    model = T5ForConditionalGeneration.from_pretrained('t5-base').to(DEVICE)
    return model, torch.optim.AdamW(model.parameters(), lr=lr)

  # Custom path: delegate construction entirely to the supplied factory.
  model, optimizer = func(lr)
  return model, optimizer

def fresh_model_optimizer_bart(lr=1e-5):
  """Build a pretrained ``facebook/bart-base`` model on DEVICE plus an AdamW optimizer.

  Args:
    lr: learning rate for the optimizer.

  Returns:
    Tuple ``(model, optimizer)``.
  """
  bart = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
  bart = bart.to(DEVICE)
  adamw = torch.optim.AdamW(bart.parameters(), lr=lr)
  return bart, adamw

def fresh_model_optimizer_break(lr=1e-5):
  """Build the Break-finetuned T5 model on DEVICE plus an AdamW optimizer.

  Args:
    lr: learning rate for the optimizer.

  Returns:
    Tuple ``(model, optimizer)``.
  """
  # AutoModelForSeq2SeqLM is not among the names imported by the notebook's
  # import cells, so bring it into scope here to avoid a NameError.
  from transformers import AutoModelForSeq2SeqLM

  model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-break_data").to(DEVICE)
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
  return model, optimizer

def generate_batch_predictions(model: nn.Module, tokenizer, validation_dataloader):
  """Generate decompositions for every batch of ``validation_dataloader``.

  Args:
    model: seq2seq model exposing ``generate()``.
    tokenizer: tokenizer used to encode the questions and decode the outputs.
    validation_dataloader: iterable of batches, each a mapping with the keys
      ``"questions"`` (texts passed to the tokenizer) and ``"qids"``.

  Returns:
    Dict mapping each qid to ``{"decomposition": [...]}``, where the list is
    the decoded output split on the literal marker ``"SEP>"``.
  """
  qids_with_decomps = {}
  for batch in validation_dataloader:

    questions = batch["questions"]
    encoding = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
    input_encoding = encoding.input_ids.to(DEVICE)
    # Questions in a batch are padded to a common length; pass the mask so
    # generation ignores the pad positions of the shorter questions.
    attention_mask = getattr(encoding, "attention_mask", None)
    if attention_mask is not None:
      attention_mask = attention_mask.to(DEVICE)
    predictions = model.generate(input_encoding, attention_mask=attention_mask, max_length=512)
    generated = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    for qid, gen in zip(batch["qids"], generated):
      qids_with_decomps[qid] = {"decomposition": gen.split("SEP>")}
  return qids_with_decomps

def generate_batch_predictions_beam(model: nn.Module, tokenizer, validation_dataloader, nbeam=4, do_sample=True):
  """Generate decompositions for every batch using ``nbeam`` beams.

  Args:
    model: seq2seq model exposing ``generate()``.
    tokenizer: tokenizer used to encode the questions and decode the outputs.
    validation_dataloader: iterable of batches, each a mapping with the keys
      ``"questions"`` (texts passed to the tokenizer) and ``"qids"``.
    nbeam: number of beams forwarded to ``generate`` as ``num_beams``.
    do_sample: forwarded to ``generate``. The default ``True`` keeps the
      previous behaviour (sampling within beam search, so outputs can differ
      between runs); pass ``False`` for deterministic beam search.

  Returns:
    Dict mapping each qid to ``{"decomposition": [...]}``, where the list is
    the decoded output split on the literal marker ``"SEP>"``.
  """
  qids_with_decomps = {}
  for batch in validation_dataloader:

    questions = batch["questions"]
    encoding = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
    input_encoding = encoding.input_ids.to(DEVICE)
    # Questions in a batch are padded to a common length; pass the mask so
    # generation ignores the pad positions of the shorter questions.
    attention_mask = getattr(encoding, "attention_mask", None)
    if attention_mask is not None:
      attention_mask = attention_mask.to(DEVICE)
    predictions = model.generate(
        input_encoding,
        attention_mask=attention_mask,
        num_beams=nbeam,
        max_length=512,
        do_sample=do_sample,
    )
    generated = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    for qid, gen in zip(batch["qids"], generated):
      qids_with_decomps[qid] = {"decomposition": gen.split("SEP>")}
  return qids_with_decomps

def print_evaluation(fname):
  """Run the repository's evaluation script on a predictions file and show its output.

  The gold file is fixed to ./strategyqa_v2/data/dev.json; ``fname`` is the
  path of the predictions JSON passed as --predictions_file.
  """
  # IPython shell escape: `$fname` is expanded from the Python variable before
  # the command runs. The value is not quoted, so a path containing spaces
  # would be split into several arguments.
  !python ./strategyqa_v2/src/evaluators/evaluate_all.py --golds_file ./strategyqa_v2/data/dev.json --predictions_file $fname

def check_valid(decomps):
  """Placeholder for validating generated decompositions (not yet implemented).

  The loop previously had no body, which is a syntax error and stopped the
  surrounding cell from running. It is kept as an explicit no-op until the
  actual validation rule is decided.

  Args:
    decomps: sized collection of decompositions to check.

  Returns:
    None.
  """
  for i in range(len(decomps)):
    pass  # TODO: implement the per-decomposition validity check


def evaluate(model: nn.Module, tokenizer, validation_dataloader, bs, lr, eps, model_name="T5"):
  """Generate predictions, write them to Drive as JSON, and run the evaluation script.

  Args:
    model: seq2seq model to evaluate; it is switched to eval mode.
    tokenizer: tokenizer matching ``model``.
    validation_dataloader: loader handed to ``generate_batch_predictions``.
    bs, lr, eps: batch size, learning rate and epoch count; used only to build
      the output file name.
    model_name: prefix of the output file name.

  Returns:
    None. The predictions are written to
    ``./drive/MyDrive/nlp_models/generations/`` and the metrics are printed.
  """
  import os  # local import: only needed to make sure the output folder exists

  model.eval()
  output = generate_batch_predictions(model, tokenizer, validation_dataloader)
  fname = f"./drive/MyDrive/nlp_models/generations/{model_name}_test_decomps_{bs}_{lr}_{eps}.json"
  # Create the target folder when missing so a finished training epoch is not
  # lost to a FileNotFoundError on the first write.
  os.makedirs(os.path.dirname(fname), exist_ok=True)
  with open(fname, "w") as f:
    json.dump(output, f, indent=4)
  print_evaluation(fname)


In [None]:
# BART sweep: fine-tune facebook/bart-base for every (batch size, learning rate)
# combination below, saving a checkpoint and evaluating on the dev split after
# each epoch.
bart_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
bart_datasets = initialize_datasets('strategyqa_v2/data/train.json', 'strategyqa_v2/data/dev.json', bart_tokenizer)
# The dev loader is shared by all runs. Predictions are keyed by qid, so the
# shuffled order should not change the saved output (NOTE(review): confirm).
bart_validation_dataloader = DataLoader(bart_datasets['dev'],
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=SQP1Dataset.collate_fn)

# Hyperparameter grid.
batch_sizes = [8, 16]
learning_rates = [1e-3, 1e-4, 5e-5]
epochs = [15]
for bs in batch_sizes:
  # The training loader depends only on the batch size, so build it once per bs.
  bart_train_dataloader = DataLoader(bart_datasets['train'],
                                      batch_size=bs,
                                      shuffle=True,
                                      collate_fn=SQP1Dataset.collate_fn)
  for lr in learning_rates:
    for eps in epochs:
      print(f"Batch size: {bs}, Learning rate: {lr}, Epochs: {eps}")
      # test_break_model, test_break_optimizer = fresh_model_optimizer(lr)
      # Start every configuration from the pretrained BART weights.
      test_bart_model, test_bart_optimizer = fresh_model_optimizer(lr, func=fresh_model_optimizer_bart)

      for i in range(eps):
        # After each epoch: checkpoint to Drive, then evaluate on dev. The epoch
        # index goes into model_name, while `eps` (the total) is what ends up in
        # the trailing part of the predictions file name.
        train_one_epoch(test_bart_model, bart_train_dataloader, test_bart_optimizer, i)
        test_bart_model.save_pretrained(f"./drive/MyDrive/nlp_models/models/BART_test_model_epoch_{i}_bs_{bs}_lr_{lr}")
        evaluate(test_bart_model, bart_tokenizer, bart_validation_dataloader, bs, lr, eps, model_name=f"BART_test_ep_{i}")
        print("-" * 10)
      # Drop the references to this run's model and optimizer before the next run.
      del test_bart_model
      del test_bart_optimizer
    # Visual separator between learning-rate runs.
    print("-" * 30)
    print("-" * 30)
      # test_break_model.save_pretrained(f"./drive/MyDrive/nlp_models/models/break_test_model_{bs}_{lr}_{eps}")
      # evaluate(test_model, tokenizer, validation_dataloader, bs, lr, eps)


Batch size: 8, Learning rate: 0.001, Epochs: 15


model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Train Ep 0:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_0_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4191069503821602,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_1_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4195148294360511,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_2_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.398714816241566,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_3_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.39520837282613813,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_4_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4191370988233955,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_5_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.42047375395968856,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_6_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4169873899758183,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_7_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4219789590381097,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_8_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4236128870912174,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_9_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4182355769992824,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_10_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.3858189426167212,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_11_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.3947478071472993,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_12_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.38654891980115363,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_13_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.38249242470977274,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_14_test_decomps_8_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.39306162732530037,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------
Batch size: 8, Learning rate: 0.0001, Epochs: 15


Train Ep 0:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_0_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5275043048678386,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_1_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5523857668537459,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_2_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5266946214311653,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_3_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5308191556881972,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_4_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5333283587712737,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_5_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5334106366804445,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_6_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.534106532880245,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_7_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5277873833468596,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_8_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.523763244354881,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_9_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5327564741847669,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_10_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5336572908877414,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_11_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5277789970672127,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_12_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5299003560617228,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_13_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5264809831387437,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_14_test_decomps_8_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5262820576005348,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------
Batch size: 8, Learning rate: 5e-05, Epochs: 15


Train Ep 0:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_0_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5272438433218652,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_1_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5193584253923842,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_2_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5467970877293542,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_3_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5229900899263915,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_4_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5484921051113393,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_5_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.532180240401441,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_6_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5349580484511948,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_7_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5338006203172265,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_8_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5234091225429784,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_9_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.535470823990901,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_10_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5224063323025877,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_11_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5202576093916094,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_12_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.534072368952206,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_13_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5297885384371713,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/258 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_14_test_decomps_8_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5213590775961077,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------
Batch size: 16, Learning rate: 0.001, Epochs: 15


Train Ep 0:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_0_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.38836080180961996,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_1_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.38493091904041843,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_2_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4076491154387196,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_3_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4002991713409679,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_4_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.37571976879698815,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_5_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.37657080932829906,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_6_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4229534456366145,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_7_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.39351021801657543,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/BART_test_ep_8_test_decomps_16_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.3981078189708794,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/129 [00:00<?, ?it/s]

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


In [None]:
# BREAK MODEL <- a T5-base checkpoint that (per its Hub name) was already
# fine-tuned on the BREAK data. This cell sweeps learning rates for further
# fine-tuning it, saving a checkpoint and evaluating after every epoch.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Use the same 512-token limit as the first load of this tokenizer at the top of
# the notebook, so re-running this cell does not silently change its settings.
break_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-break_data", model_max_length=512)

batch_sizes = [32]
learning_rates = [1e-3, 1e-4, 5e-5]
epochs = [15]
for bs in batch_sizes:
  # The training loader depends only on the batch size, so build it once per bs.
  breaked_train_dataloader = DataLoader(breaked_datasets['train'],
                                      batch_size=bs,
                                      shuffle=True,
                                      collate_fn=SQP1Dataset.collate_fn)
  for lr in learning_rates:
    for eps in epochs:
      print(f"Batch size: {bs}, Learning rate: {lr}, Epochs: {eps}")
      # Each configuration gets its own model/optimizer pair from the BREAK factory
      # (presumably freshly initialised from the pretrained weights - TODO confirm).
      test_break_model, test_break_optimizer = fresh_model_optimizer(lr, func=fresh_model_optimizer_break)

      for i in range(eps):
        train_one_epoch(test_break_model, breaked_train_dataloader, test_break_optimizer, i)
        # Checkpoint after every epoch so any epoch can be reloaded later.
        test_break_model.save_pretrained(f"./drive/MyDrive/nlp_models/models/break_test_model_epoch_{i}_bs_{bs}_lr_{lr}")
        # model_name tags this epoch's evaluation run.
        evaluate(test_break_model, break_tokenizer, breaked_validation_dataloader, bs, lr, eps, model_name=f"break_test_ep_{i}")
        print("-" * 10)
      # Drop the references before the next configuration is created.
      del test_break_model
      del test_break_optimizer
    # Visual separator between learning rates in the log.
    print("-" * 30)
    print("-" * 30)


The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
  return self.fget.__get__(instance, owner)()


Batch size: 32, Learning rate: 0.001, Epochs: 15


Train Ep 0:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_0_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5424091449630386,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_1_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5485825541753767,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_2_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5524533499838492,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_3_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5555470519667682,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_4_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5613093231848802,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_5_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5572715850218233,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_6_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5557977102130399,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_7_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5624227211108331,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_8_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5667946030647808,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_9_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.557177831289101,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_10_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5560728326237652,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_11_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5493361137677276,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_12_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5573827959737165,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_13_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5459117867467708,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_14_test_decomps_32_0.001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5562687516203239,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------
Batch size: 32, Learning rate: 0.0001, Epochs: 15


The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Train Ep 0:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_0_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4848732941947351,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_1_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5119879414920293,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_2_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.521317693445945,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_3_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5329212021398568,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_4_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.525981190183167,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_5_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.530218918749442,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_6_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.535216447930601,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_7_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5314980077239594,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_8_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5373726391161733,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_9_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5356336064261379,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_10_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5441474113193734,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_11_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5473241707560835,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_12_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5510921867468208,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_13_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5584942438173014,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_14_test_decomps_32_0.0001_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5555677749199107,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------
Batch size: 32, Learning rate: 5e-05, Epochs: 15


The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Train Ep 0:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_0_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4605593362233255,
    "Recall@10": 0.0
}
----------


Train Ep 1:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_1_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.4602958421763494,
    "Recall@10": 0.0
}
----------


Train Ep 2:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_2_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.49341948460103197,
    "Recall@10": 0.0
}
----------


Train Ep 3:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_3_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5017010487189169,
    "Recall@10": 0.0
}
----------


Train Ep 4:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_4_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5137102182236297,
    "Recall@10": 0.0
}
----------


Train Ep 5:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_5_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5233323661691861,
    "Recall@10": 0.0
}
----------


Train Ep 6:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_6_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5317058846376078,
    "Recall@10": 0.0
}
----------


Train Ep 7:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_7_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5388109843730783,
    "Recall@10": 0.0
}
----------


Train Ep 8:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_8_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.535972663745691,
    "Recall@10": 0.0
}
----------


Train Ep 9:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_9_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5310309600967379,
    "Recall@10": 0.0
}
----------


Train Ep 10:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_10_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5332943114866586,
    "Recall@10": 0.0
}
----------


Train Ep 11:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_11_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5311370780631344,
    "Recall@10": 0.0
}
----------


Train Ep 12:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_12_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5322211549632103,
    "Recall@10": 0.0
}
----------


Train Ep 13:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_13_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5346155447460986,
    "Recall@10": 0.0
}
----------


Train Ep 14:   0%|          | 0/65 [00:00<?, ?it/s]

====Input Arguments====
{
  "golds_file": "./strategyqa_v2/data/dev.json",
  "metrics_output_file": "metrics.json",
  "predictions_file": "./drive/MyDrive/nlp_models/generations/break_test_ep_14_test_decomps_32_5e-05_15.json",
  "retrieval_limit": 10
}
{
    "Accuracy": 0.0,
    "SARI": 0.5398767369712345,
    "Recall@10": 0.0
}
----------
------------------------------
------------------------------


In [None]:
# The sweep cell above already deletes these names at the end of every
# configuration, so a plain `del` here raises NameError on Restart & Run All.
# Remove them only if they are still bound.
for _leftover in ("test_break_model", "test_break_optimizer"):
  globals().pop(_leftover, None)

## Fine-tuning the final model

In [None]:
# Reload the checkpoint selected from the sweeps (epoch 14, batch size 8, lr 1e-4).
# NOTE(review): the sweep cell saves checkpoints under ".../nlp_models/models/",
# but this path has no "models/" component - confirm the checkpoint lives here.
best_model = T5ForConditionalGeneration.from_pretrained(
    "./drive/MyDrive/nlp_models/break_test_model_epoch_14_bs_8_lr_0.0001/"
)
best_model = best_model.to(DEVICE)

# TODO: the loop that was presumably meant to continue training from epoch 15 was never written.

## **THE BELOW SECTION IS UNUSED CODE**

In [None]:
#### Sanity check: take a few optimizer steps on one (question, decomposition) pair
input_question = "Are more people today related to Genghis Khan than Julius Caesar?"
decompositions = [
    "How many kids did Julius Caesar have?",
    "How many kids did Genghis Khan have?",
    "Is #2 greater than #1?",
]

# The question is the source sequence; the "<SEP>"-joined sub-questions are the labels.
encode_options = {"return_tensors": "pt", "padding": True, "truncation": True}
target_text = "<SEP>".join(decompositions)
inputs = tokenizer(input_question, **encode_options).input_ids.to(DEVICE)
outputs = tokenizer(target_text, **encode_options).input_ids.to(DEVICE)
print(inputs)
print(outputs)

# Fine-tuning: five gradient steps on this single example.
# NOTE(review): `model` is not created in this cell; it must already exist in the kernel.
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
for epoch in range(5):
    optimizer.zero_grad()
    loss = model(input_ids=inputs, labels=outputs).loss
    loss.backward()
    optimizer.step()

# Generate from the same question to see what the model produces after those steps.
model.eval()
predictions = model.generate(input_ids=inputs, max_length=512)
print(tokenizer.decode(predictions[0], skip_special_tokens=True))

tensor([[ 1521,    72,   151,   469,  1341,    12,  5945,  5649,     7, 14420,
           145,  9983,   302, 26218,    58,     1]])
tensor([[  571,   186,  1082,   410,  9983,   302, 26218,    43,    58,     2,
           134,  8569,  3155,  7825,   186,  1082,   410,  5945,  5649,     7,
         14420,    43,    58,     2,   134,  8569,  3155,   196,     7, 15493,
          2123,   145,  7172,    58,     1]])
Are more people related to Genghis Khan than Julius Caesar?


IndexError: index 1 is out of bounds for dimension 0 with size 1

In [None]:
# Build the records that get written to file: one entry per validation question,
# pairing the generated decomposition with the gold one.
# NOTE(review): `generated` is only assigned in a later cell (the decoded output of
# evaluate()), so this cell fails with NameError unless that cell has been run first.
data = []
evaluation_data = {}
all_questions = []
all_qids = []
all_decomps = []
for batch in validation_dataloader:
  all_questions.extend(batch['questions'])
  all_qids.extend(batch['qids'])
  all_decomps.extend(batch['decomps'])

print(generated)
# Gold strings are split on "<SEP>" but predictions on "SEP>".
# NOTE(review): presumably the decoded predictions lose the leading "<" - confirm this is intended.
formatted_decomps = [gold.split("<SEP>") for gold in all_decomps]
formatted_preds = [generated_text.split("SEP>") for generated_text in generated]
print(formatted_preds)

for i, question in enumerate(all_questions):
  qid = all_qids[i]
  predicted = formatted_preds[i]
  data.append({
      "qid": qid,
      "question": question,
      "predicted_decomposition": predicted,
      "correct_decomposition": formatted_decomps[i]
  })
  # Same predictions keyed by question id.
  evaluation_data[qid] = {'decomposition': predicted}

print(data)

NameError: name 'generated' is not defined

In [None]:
import json

def write_jsonl(data, filename):
    """Write each entry of `data` to `filename` as one JSON document per line.

    Args:
        data: iterable of JSON-serialisable objects.
        filename: path of the file to create or overwrite.
    """
    with open(filename, 'w') as out_file:
        out_file.writelines(json.dumps(record) + '\n' for record in data)

# Persist the assembled prediction records, one JSON object per line
write_jsonl(data=data, filename='strategyqa_v2/data/generated/t5predictions.jsonl')

In [None]:
# this is for the validation set

def evaluate(model: nn.Module, dataloader: DataLoader):
  """Generate an output sequence for every example yielded by `dataloader`.

  Puts `model` in eval mode and, with gradients disabled, calls
  `model.generate` (max_length=512) on `batch['input_ids'].input_ids` for each batch.

  Returns:
    A flat list with one generated token-id sequence per input example.

  NOTE(review): this rebinds the name `evaluate`, which the sweep cell earlier in
  the notebook calls with additional arguments - confirm the shadowing is intended.
  """
  model.eval()
  generated_sequences = []
  with torch.no_grad(), tqdm(dataloader, desc="", total=len(dataloader)) as progress:
    for batch in progress:
      input_ids = batch['input_ids'].input_ids.to(DEVICE)
      # Extending with the batch output appends its rows one by one.
      generated_sequences.extend(model.generate(input_ids=input_ids, max_length=512))

  return generated_sequences

In [None]:
# Run generation over the validation set, then decode, print and collect each prediction.
preds = evaluate(model, validation_dataloader)
generated = []
for pred in preds:
  decoded_text = tokenizer.decode(pred, skip_special_tokens=True)
  print(decoded_text)
  generated.append(decoded_text)

## Can we use T5 to answer the question directly?


In [None]:
# Load the t5-base tokenizer and model for the direct-answer experiment.
# This rebinds the notebook-wide `tokenizer` and `model` names; `answerer` is
# an alias for the tokenizer, not the model.
tokenizer = T5Tokenizer.from_pretrained("t5-base", model_max_length=512)
answerer = tokenizer
model = T5ForConditionalGeneration.from_pretrained('t5-base').to(DEVICE)

In [None]:
# Ask the model each training question directly and tally how often its decoded
# answer matches the gold yes/no label.
correct_answers = 0
total = 0  # counts questions actually scored (previously the number of batches)
for batch in train_dataloader:
  # NOTE(review): assumes each batch iterates over dict-like examples with "question"
  # and "answer" keys; other cells index batches as dicts (batch['questions']),
  # so confirm against the collate function.
  for data in batch:
    tokenized_question = tokenizer("question: " + data["question"], return_tensors="pt", padding=True, truncation=True).input_ids.to(DEVICE)
    encoded_ans = model.generate(input_ids=tokenized_question, max_length=512)
    # Decode the sequence that was just generated (the original read an undefined `ans`).
    decoded_ans = tokenizer.decode(encoded_ans[0], skip_special_tokens=True)
    # Presumably the gold answer is a boolean; map it to the text compared below - TODO confirm.
    actual_ans = "yes" if data["answer"] else "no"
    total += 1
    if decoded_ans == actual_ans:
      correct_answers += 1
      print("correct")
    else:
      print("incorrect")