In [7]:
import argparse
from contextlib import nullcontext
import torch
from tqdm import tqdm

from data import get_dataset
from utils.training_utils import get_lr, get_run_name, AverageMeter
from torch.utils.data import DataLoader
from models import get_model
from tokenizing import get_tokenizer
import wandb
from pprint import pprint
import numpy as np
from datasets import Dataset
from copy import deepcopy
from transformers import TrainingArguments
from trl import DPOTrainer
from transformers import AutoTokenizer
import os

In [8]:
parser = argparse.ArgumentParser(description="Next-token failures")
# Data
parser.add_argument(
    "--n_samples", type=int, default=5, help="Number of samples to generate"
    )
parser.add_argument(
    "--model", default='gpt2', type=str, help="Type of model"
    )
parser.add_argument(
    "--dataset", default='graph', type=str, help="Choice of dataset"
    )
parser.add_argument(
    "--n_train", default=200000, type=int, help="Number of training samples"
    )
parser.add_argument(
    "--n_test", default=500, type=int, help="Number of test samples"
    )
parser.add_argument(
    "--num_nodes", default=50, type=int, help="Number of node values in graph"
    )
parser.add_argument(
    "--deg", default=2, type=int, help="Degree of starting node"
    )
parser.add_argument(
    "--path_len", default=5, type=int, help="Path length in star graph"
    )
parser.add_argument(
        "--mate_in", default=2, type=int, help="For chess, number of moves to checkmate"
    )
parser.add_argument(
        "--unrolled", action=argparse.BooleanOptionalAction, default=True, help="For chess, unrolled board state",
    )
parser.add_argument(
        "--batch_size", type=int, default=48, help="Batch size",
    )
parser.add_argument(
        "--lr", type=float, default=1e-5, help="Learning rate",
    )
parser.add_argument(
        "--weight_decay", type=float, default=0., help="Strength of weight decay",
    )
parser.add_argument(
        "--epochs_sft", type=int, default=1, help="Number of SFT epochs",
    )
parser.add_argument(
        "--save_every", type=int, default=5000, help="Interval (in steps) at which to save model",
    )
parser.add_argument(
        "--pass_at_k", type=int, default=1, help="pass at k eval",
    )
parser.add_argument(
        "--teacherless", action=argparse.BooleanOptionalAction, default=False, help="Standard or teacherless training",
    )
parser.add_argument(
        "--reverse", action=argparse.BooleanOptionalAction, default=False, help="Standard format or reverse targets",
    )
parser.add_argument(
        "--cot", action=argparse.BooleanOptionalAction, default=False, help="Standard format or cot targets",
    )
parser.add_argument(
        "--pos", action=argparse.BooleanOptionalAction, default=False, help="Standard format or pos tokens",
    )
parser.add_argument(
        "--eval_train", action=argparse.BooleanOptionalAction, default=False, help="Eval for training set",
    )
parser.add_argument(
        "--eval_every", type=int, default=400, help="Interval (in steps) to evaluate the model on test",
    )
parser.add_argument(
        "--use_wandb", action=argparse.BooleanOptionalAction, default=False, help="Whether to use wandb",
    )
parser.add_argument(
        "--wandb_entity", type=str, default=5000, help="Wandb username",
    )
parser.add_argument(
    "--pad_length", default=0, type=int, help="Default value for pad length"
)


args = parser.parse_args([])

In [9]:
# System stuff
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wandb_entity = args.wandb_entity
wandb_log = args.use_wandb
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Model stuff
top_k = 1000
temperature = 1.
pass_at_k = args.pass_at_k
n_samples = args.n_samples

# Evaluation stuff
eval_iters = 1000
eval_interval = 5
log_interval = 10

# Optimiser
dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
beta1 = 0.9
beta2 = 0.999
decay_lr = False
args.compile = False if device == 'cuda' else False
args.use_flash = True if device == 'cuda' else False
warmup_iters = 100
min_lr = 1e-5

In [10]:
tokenizer = get_tokenizer(args)

# train data without cot
args.cot = False
args.n_train = 500
args.n_test = 100
args.pad_length = 10
train_data, test_data = get_dataset(args, tokenizer, device)
train_data.num_target_tokens += train_data.pad_length
test_data.num_target_tokens += test_data.pad_length
train_data.num_tokens += train_data.pad_length
test_data.num_tokens += test_data.pad_length

# train data with cot
args.cot = True
args.n_train = 100
args.pad_length = 0
train_data_wcot, _ = get_dataset(args, tokenizer, device)
train_data_wcot.data_file = train_data_wcot.data_file[-args.n_train:]
train_data_wcot.tokenized, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens = train_data_wcot.tokenizer.tokenize(train_data_wcot.data_file)


print(len(train_data), train_data[0][0].shape, len(train_data_wcot), train_data_wcot[0][0].shape)
print("W/o COT", train_data[0], tokenizer.decode(train_data[0][0]))
print("With COT", train_data_wcot[0], tokenizer.decode(train_data_wcot[0][0]))

Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
500 torch.Size([54]) 100 torch.Size([54])
W/o COT (tensor([ 2920,    11,  1485,    91,  1954,    11,  2078,    91,  3510,    11,
         2327,    91,  2624,    11,  3682,    91,  2327,    11,  2624,    91,
         3682,    11,  1828,    91,  1485,    11,  1954,    91,  3510,    11,
         2920,    14,  3510,    11,  1828,    28,  3510,    11,  2327,    11,
         2624,    11,  3682,    11,  1828, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256], device='cuda:0'), tensor([   -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,  3510,    11,  2327,    11,  2624,
           11,  3682,    11,  1828, 50256, 50256, 50256, 50256, 50256, 5

In [11]:
train_data.num_prefix_tokens, train_data.num_target_tokens, train_data.num_tokens, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens, train_data_wcot.num_tokens, test_data.num_prefix_tokens, test_data.num_target_tokens, test_data.num_tokens

(36, 19, 55, 36, 19, 55, 36, 19, 55)

In [12]:
from torch.utils.data import ConcatDataset
concatenated_train_data = ConcatDataset([train_data, train_data_wcot])
train_loader = DataLoader(concatenated_train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
len(train_loader), len(test_loader)

(13, 3)

In [13]:
def evaluate(model, loader, tokenizer, ctx, num_prefix_tokens, y_tokens, pad_length):
    model.eval()
    total_acc = AverageMeter()
    bar = tqdm(loader)
    for x in bar:
        original_length = x.shape[1]
        y = x[:, -y_tokens:].clone()
        x = x[:, :num_prefix_tokens].clone()
        with ctx:
            y_pred = model.generate(x, max_new_tokens=original_length + pad_length, min_length=original_length + pad_length,
                                do_sample=False, attention_mask = torch.ones_like(x), pad_token_id=2)
        completely_correct = 0 
        for i in range(y.shape[0]):
            completely_correct += int(tokenizer.decode(y[i]) in tokenizer.decode(y_pred[i]))
        completely_correct /= x.shape[0]
        total_acc.update(completely_correct, x.shape[0])
        bar.set_description(f' accuracy: {total_acc.get(percentage=True):.2f}')
    loader.dataset.train()
    model.train()



In [14]:
def generate_and_score(model, loader, tokenizer, ctx, temperature, top_p, n_samples, num_prefix_tokens, num_target_tokens, y_tokens):
    model.eval()
    bar = tqdm(loader)
    x_dataset = []
    y_pred_dataset = []
    scores_dataset = []
    for x in bar:
        original_length = x.shape[1]
        y = x[:, -y_tokens:].clone()
        x = x[:, :num_prefix_tokens].clone()
        y = y
        y_pred = []
        scores = []
        for i in range(n_samples):
            with ctx:
                # print(x.shape, num_target_tokens)
                _y_pred = model.generate(x, min_length=num_prefix_tokens+num_target_tokens, 
                                   max_length=num_prefix_tokens+num_target_tokens, temperature=temperature, top_p=top_p,
                                   do_sample=True, attention_mask = torch.ones_like(x), pad_token_id=2)
            _scores = []
            for i in range(y.shape[0]):
                _scores.append(int(tokenizer.decode(y[i]) in tokenizer.decode(_y_pred[i])))
            # print(_y_pred.shape)
            y_pred.append(_y_pred.cpu())
            scores.append(_scores)
            
        scores = torch.tensor(scores).float().transpose(0, 1)
        y_pred = torch.stack(y_pred, dim=0).transpose(0, 1)
        # y_pred should be bs x nsamples x length
        y_pred_dataset.append(y_pred[:, :, num_prefix_tokens:].cpu())
        scores_dataset.append(scores.cpu())
        # print(scores.shape, y_pred.shape, x.shape)
        x_dataset.append(x.cpu())
        # correct should be bs x nsamples
    # Switch back to train mode
    loader.dataset.train()
    model.train()
    y_pred_dataset = torch.cat(y_pred_dataset, dim=0)
    x_dataset = torch.cat(x_dataset, dim=0)
    scores_dataset = torch.cat(scores_dataset, dim=0)
    return x_dataset, y_pred_dataset, scores_dataset

In [15]:
args.model = 'gpt2'
args.epochs_sft = 10
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(args.model)
model.to(device)
model.train()

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.0)
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)

max_iters = len(train_loader) * args.epochs_sft
lr_decay_iters = max_iters

results = {}
num_iters = 0

for ep in range(args.epochs_sft):
    train_bar = tqdm(train_loader)
    total_loss, total_acc = AverageMeter(), AverageMeter()
    verifier_loss = AverageMeter()
    for x, y in train_bar:
        with ctx:
            # logits, loss, accs = model(x, y)
            logits = model(x)['logits']
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
            acc = torch.mean((torch.argmax(logits[:, -train_data.num_target_tokens, :], dim=-1) == y[:, -train_data.num_target_tokens]).float())

        # verifier 
        if ep > 5:
            num_prefix_tokens=train_data.num_prefix_tokens
            num_target_tokens=train_data.num_target_tokens
            y_tokens=train_data.num_target_tokens-train_data.pad_length
            x = torch.cat([x, y[:, -1:]], dim=1)
            y = x[:, -y_tokens:].clone()
            x = x[:, :num_prefix_tokens].clone()
            y_pred = []
            scores = []
            for i in range(args.n_samples):
                with ctx:
                    _y_pred = model.generate(x, min_length=num_prefix_tokens+num_target_tokens, 
                                max_length=num_prefix_tokens+num_target_tokens, temperature=1., top_p=0.99,
                                do_sample=True, attention_mask = torch.ones_like(x), pad_token_id=2)
                _scores = []
                for i in range(y.shape[0]):
                    _scores.append(int(tokenizer.decode(y[i]) in tokenizer.decode(_y_pred[i])))
                y_pred.append(_y_pred)
                scores.append(_scores)
            scores = torch.tensor(scores).float().transpose(0, 1)
            y_pred = torch.stack(y_pred, dim=0).transpose(0, 1)
            y_pred = y_pred.reshape(-1, y_pred.shape[-1]) 
            sign = scores.reshape(-1) * 2 - 1.
            # sign = scores.reshape(-1)
            # sign = scores.reshape(-1) - 1. 
        
            x = x.unsqueeze(0).repeat(args.n_samples, 1, 1).transpose(0, 1).reshape(-1, x.shape[-1])
            x_new = torch.cat([x, y_pred], dim=1)
            y_new = x_new.clone()
            y_new[:, :x.shape[1]] = -1
            x_new = x_new[:, :-1].clone()
            y_new = y_new[:, 1:].clone()
            sign = sign.cuda()
            with ctx:
                logits = model(x_new)['logits']
                vloss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_new.view(-1), ignore_index=-1, reduction='none')
                # print(vloss)
                # vloss.sigmoid_()
                vloss = torch.sigmoid(0.1 * vloss)
                sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1]).reshape(-1)
                vloss = vloss * sign
                vloss = vloss.mean()
        else:
            vloss = torch.tensor(0.)
        verifier_loss.update(vloss.item(), x.shape[0] * train_data.num_target_tokens)
        total_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        total_acc.update(acc.item(), x.shape[0] * train_data.num_target_tokens)
        scaler.scale(loss + 0.1 * vloss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        num_iters += 1
        train_bar.set_description(
            'Epoch: [{}/{}] Loss: {:.4f} Verifier_loss: {:.2f} Acc: {:.2f}'.format(ep, args.epochs_sft, total_loss.get(), verifier_loss.get(),
             total_acc.get(percentage=True))
        )
        verifier_loss = AverageMeter()
        

    # if ep > 1:
    #     print("Generating samples")
    #     loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    #     loader.dataset.eval()
    #     prompts, generations, scores = generate_and_score(
    #         model, loader, tokenizer, ctx, temperature=0.5, top_p=0.9, n_samples=3, num_prefix_tokens=train_data.num_prefix_tokens, 
    #         num_target_tokens=train_data.num_target_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length)
    #     loader.dataset.train() 
    #     flat_generations = generations.reshape(-1, generations.shape[-1]) 
    #     flat_scores = scores.reshape(-1) * 2 - 1.
    #     flat_prompts = prompts.unsqueeze(0).repeat(3, 1, 1).transpose(0, 1).reshape(-1, prompts.shape[-1])
    #     pref_dataset = torch.utils.data.TensorDataset(flat_prompts, flat_generations, flat_scores)
        # print("Training verifier")
        # verifier_loss = AverageMeter()
        # for _ in range(2):
        #     loader_bar = tqdm(torch.utils.data.DataLoader(pref_dataset, batch_size=args.batch_size//2, shuffle=True))
        #     for x, y, sign in loader_bar:
        #         x = x.cuda()
        #         y = y.cuda()
        #         x_new = torch.cat([x, y], dim=1)
        #         y_new = x_new.clone()
        #         y_new[:, :x.shape[1]] = -1
        #         x_new = x_new[:, :-1].clone()
        #         y_new = y_new[:, 1:].clone()
        #         sign = sign.cuda()
        #         with ctx:
        #             logits = model(x_new)['logits']
        #             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_new.view(-1), ignore_index=-1, reduction='none')
        #             sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1]).reshape(-1)
        #             loss *= sign
        #             loss = loss.mean()
        #         verifier_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        #         scaler.scale(loss).backward()
        #         scaler.step(optimizer)
        #         scaler.update()
        #         optimizer.zero_grad(set_to_none=True)
        #         loader_bar.set_description(
        #             'Verifier Loss: {:.4f}'.format(verifier_loss.get()),
        #         )

                

    # evaluate the loss on train/val sets and write checkpoints
    print(f"Epoch {ep} completed")
    print("Train without COT")
    loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    loader.dataset.eval()
    evaluate(model, loader, tokenizer, ctx, num_prefix_tokens=train_data.num_prefix_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length, pad_length=train_data.pad_length)
    loader.dataset.train()
    print("Train with COT")
    loader_wcot = torch.utils.data.DataLoader(train_data_wcot, batch_size=args.batch_size, shuffle=False)
    loader_wcot.dataset.eval()
    evaluate(model, loader_wcot, tokenizer, ctx, num_prefix_tokens=train_data_wcot.num_prefix_tokens, y_tokens=train_data_wcot.num_target_tokens-train_data.pad_length, pad_length=0)
    loader_wcot.dataset.train()
    print("Test without COT")
    loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    loader_eval.dataset.eval()
    evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, pad_length=test_data.pad_length)
    loader_eval.dataset.train()


Epoch: [0/10] Loss: 4.4973 Verifier_loss: 0.00 Acc: 17.00: 100%|██████████| 13/13 [00:01<00:00,  7.12it/s]


Epoch 0 completed
Train without COT


 accuracy: 21.00: 100%|██████████| 11/11 [00:07<00:00,  1.41it/s]


Train with COT


 accuracy: 17.00: 100%|██████████| 3/3 [00:02<00:00,  1.36it/s]


Test without COT


 accuracy: 15.00: 100%|██████████| 3/3 [00:02<00:00,  1.43it/s]


In [25]:
outputs = model.generate(
    x,
    num_beams=10,
    max_new_tokens=20,
    num_return_sequences=2,
    attention_mask=torch.ones_like(x),
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


In [33]:
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 2)[2]

array(['32,25|48,35|7,9|48,32|9,33|25,1|1,13|35,7/48,33=48,35,7,9,33<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
       '32,25|48,35|7,9|48,32|9,33|25,1|1,13|35,7/48,33=48,35,7,9,33<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'],
      dtype='<U212')

In [34]:
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 2)

array([['40,36|31,46|31,40|5,42|46,47|47,0|36,5|0,48/31,42=31,40,36,5,42<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
        '40,36|31,46|31,40|5,42|46,47|47,0|36,5|0,48/31,42=31,40,36,5,42<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n<|endoftext|>'],
       ['10,5|34,37|3,38|19,3|37,48|19,34|38,10|48,35/19,35=19,34,37,48,35<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n-\n\n\n-\n\n\n-\n\n-\n\n-\n\n\n-',
        '10,5|34,37|3,38|19,3|37,48|19,34|38,10|48,35/19,35=19,34,37,48,35<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'],
       ['32,25|48,35|7,9|48,32|9,33|25,1|1,13|35,7/48,33=48,3

In [40]:
tokenizer.batch_decode(next(test_loader.__iter__())[0])

['23,4|6,19|24,12|24,23|7,39|4,6|12,25|25,7/24,39=24,12,25,7,39<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '21,19|36,5|33,22|25,33|3,36|19,3|22,34|21,25/21,5=21,19,3,36,5<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '12,22|37,49|49,12|20,14|22,16|14,5|37,20|5,23/37,16=37,49,12,22,16<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '18,24|39,34|24,4|4,28|33,29|18,39|28,41|34,33/18,41=18,24,4,28,41<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '8,21|41,14|36,1|36,27|27,0|0,41|1,8|21,42/36,14=36,27,0,41,14<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
 '35,37|4,36|28,35|37,4|22,32|45,41|28,45|41,22/28,32=28,45,41,22,32<|endof