In [2]:
import argparse
from contextlib import nullcontext
import torch
from tqdm import tqdm

from data import get_dataset
from utils.training_utils import get_lr, get_run_name, AverageMeter
from torch.utils.data import DataLoader
from evaluate_model import evaluate, evaluate_forced
from models import get_model
from tokenizing import get_tokenizer
import wandb
from pprint import pprint


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
parser = argparse.ArgumentParser(description="Next-token failures")
# Data
parser.add_argument(
    "--model", default='gpt2', type=str, help="Type of model"
    )
parser.add_argument(
    "--dataset", default='graph', type=str, help="Choice of dataset"
    )
parser.add_argument(
    "--n_train", default=200000, type=int, help="Number of training samples"
    )
parser.add_argument(
    "--n_test", default=5000, type=int, help="Number of test samples"
    )
parser.add_argument(
    "--num_nodes", default=50, type=int, help="Number of node values in graph"
    )
parser.add_argument(
    "--deg", default=2, type=int, help="Degree of starting node"
    )
parser.add_argument(
    "--path_len", default=5, type=int, help="Path length in star graph"
    )
parser.add_argument(
        "--mate_in", default=2, type=int, help="For chess, number of moves to checkmate"
    )
parser.add_argument(
        "--unrolled", action=argparse.BooleanOptionalAction, default=True, help="For chess, unrolled board state",
    )
parser.add_argument(
        "--batch_size", type=int, default=256, help="Batch size",
    )
parser.add_argument(
        "--lr", type=float, default=5e-4, help="Learning rate",
    )
parser.add_argument(
        "--weight_decay", type=float, default=1e-2, help="Strength of weight decay",
    )
parser.add_argument(
        "--epochs", type=int, default=100, help="Number of epochs",
    )
parser.add_argument(
        "--save_every", type=int, default=5000, help="Interval (in steps) at which to save model",
    )
parser.add_argument(
        "--teacherless", action=argparse.BooleanOptionalAction, default=False, help="Standard or teacherless training",
    )
parser.add_argument(
        "--reverse", action=argparse.BooleanOptionalAction, default=False, help="Standard format or reverse targets",
    )
parser.add_argument(
        "--cot", action=argparse.BooleanOptionalAction, default=False, help="Standard format or cot targets",
    )
parser.add_argument(
        "--eval_train", action=argparse.BooleanOptionalAction, default=False, help="Eval for training set",
    )
parser.add_argument(
        "--eval_every", type=int, default=400, help="Interval (in steps) to evaluate the model on test",
    )
parser.add_argument(
        "--use_wandb", action=argparse.BooleanOptionalAction, default=False, help="Whether to use wandb",
    )
parser.add_argument(
        "--wandb_entity", type=str, default=5000, help="Wandb username",
    )


_StoreAction(option_strings=['--wandb_entity'], dest='wandb_entity', nargs=None, const=None, default=5000, type=<class 'str'>, choices=None, required=False, help='Wandb username', metavar=None)

In [4]:
args = parser.parse_args(args=[])

In [5]:
args.cot

False

In [6]:
args.cot = True
device = 'cpu'
tokenizer = get_tokenizer(args)
train_data, test_data = get_dataset(args, tokenizer, device)
# train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
# test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True)


Equal sequence lengths!
Equal sequence lengths!


In [22]:
filename = 'data/datasets/graphs/deg_2_path_5_nodes_50_train_200000.txt'
with open(filename, 'r') as f:
    lines = f.readlines()
for line in lines:
    prefix = line.strip().split('=')[0] + '='
    target = line.strip().split('=')[1]
    reverse = False
    cot = True
    if reverse:
        target = ','.join(target.split(',')[::-1])
    if cot:
        if reverse:
            raise ValueError("Cannot be cot and reverse at the same time.")
        rev_path = target.split(',')[::-1]
        path_str = ''
        for node in rev_path:
            path_str += str(node) + ','
        path_str = path_str[:-1] + ':' + target
        target = path_str
    print(prefix+target)
    break

49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=22,42,32,35,46:46,35,32,42,22


In [23]:
train_data.data_file[0]

('49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=',
 '22,42,32,35,46$46,35,32,42,22')

In [18]:
prefix, target

('49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=', '46,35,32,42,22')