In [1]:
import argparse
from contextlib import nullcontext
import torch
from tqdm import tqdm
from data import get_dataset
from utils.training_utils import get_lr, get_run_name, AverageMeter
from torch.utils.data import DataLoader
from models import get_model
from tokenizing import get_tokenizer
import wandb
from pprint import pprint
import numpy as np
from datasets import Dataset
from copy import deepcopy
from transformers import TrainingArguments
from trl import DPOTrainer
from transformers import AutoTokenizer
import os

  from .autonotebook import tqdm as notebook_tqdm


[2024-04-20 05:02:35,403] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
parser = argparse.ArgumentParser(description="Next-token failures")
# Data
parser.add_argument(
    "--n_samples", type=int, default=5, help="Number of samples to generate"
    )
parser.add_argument(
    "--model", default='gpt2', type=str, help="Type of model"
    )
parser.add_argument(
    "--dataset", default='graph', type=str, help="Choice of dataset"
    )
parser.add_argument(
    "--n_train", default=200000, type=int, help="Number of training samples"
    )
parser.add_argument(
    "--n_test", default=500, type=int, help="Number of test samples"
    )
parser.add_argument(
    "--num_nodes", default=50, type=int, help="Number of node values in graph"
    )
parser.add_argument(
    "--deg", default=2, type=int, help="Degree of starting node"
    )
parser.add_argument(
    "--path_len", default=5, type=int, help="Path length in star graph"
    )
parser.add_argument(
        "--mate_in", default=2, type=int, help="For chess, number of moves to checkmate"
    )
parser.add_argument(
        "--unrolled", action=argparse.BooleanOptionalAction, default=True, help="For chess, unrolled board state",
    )
parser.add_argument(
        "--batch_size", type=int, default=48, help="Batch size",
    )
parser.add_argument(
        "--lr", type=float, default=1e-5, help="Learning rate",
    )
parser.add_argument(
        "--weight_decay", type=float, default=0., help="Strength of weight decay",
    )
parser.add_argument(
        "--epochs_sft", type=int, default=1, help="Number of SFT epochs",
    )
parser.add_argument(
        "--save_every", type=int, default=5000, help="Interval (in steps) at which to save model",
    )
parser.add_argument(
        "--pass_at_k", type=int, default=1, help="pass at k eval",
    )
parser.add_argument(
        "--teacherless", action=argparse.BooleanOptionalAction, default=False, help="Standard or teacherless training",
    )
parser.add_argument(
        "--reverse", action=argparse.BooleanOptionalAction, default=False, help="Standard format or reverse targets",
    )
parser.add_argument(
        "--cot", action=argparse.BooleanOptionalAction, default=False, help="Standard format or cot targets",
    )
parser.add_argument(
        "--pos", action=argparse.BooleanOptionalAction, default=False, help="Standard format or pos tokens",
    )
parser.add_argument(
        "--eval_train", action=argparse.BooleanOptionalAction, default=False, help="Eval for training set",
    )
parser.add_argument(
        "--eval_every", type=int, default=400, help="Interval (in steps) to evaluate the model on test",
    )
parser.add_argument(
        "--use_wandb", action=argparse.BooleanOptionalAction, default=False, help="Whether to use wandb",
    )
parser.add_argument(
        "--wandb_entity", type=str, default=5000, help="Wandb username",
    )
parser.add_argument(
    "--pad_length", default=0, type=int, help="Default value for pad length"
)


args = parser.parse_args([])

In [3]:
# System stuff
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wandb_entity = args.wandb_entity
wandb_log = args.use_wandb
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Model stuff
top_k = 1000
temperature = 1.
pass_at_k = args.pass_at_k
n_samples = args.n_samples

# Evaluation stuff
eval_iters = 1000
eval_interval = 5
log_interval = 10

# Optimiser
dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
beta1 = 0.9
beta2 = 0.999
decay_lr = False
args.compile = False if device == 'cuda' else False
args.use_flash = True if device == 'cuda' else False
warmup_iters = 100
min_lr = 1e-5

In [4]:
tokenizer = get_tokenizer(args)

# train data without cot
args.cot = False
args.n_train = 1000
args.n_test = 100
args.pad_length = 10
train_data, test_data = get_dataset(args, tokenizer, device)
train_data.num_target_tokens += train_data.pad_length
test_data.num_target_tokens += test_data.pad_length
train_data.num_tokens += train_data.pad_length
test_data.num_tokens += test_data.pad_length

# train data with cot
args.cot = True
args.n_train = 300
args.pad_length = 0
train_data_wcot, _ = get_dataset(args, tokenizer, device)
train_data_wcot.data_file = train_data_wcot.data_file[-args.n_train:]
train_data_wcot.tokenized, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens = train_data_wcot.tokenizer.tokenize(train_data_wcot.data_file)


print(len(train_data), train_data[0][0].shape, len(train_data_wcot), train_data_wcot[0][0].shape)
print("W/o COT", train_data[0], tokenizer.decode(train_data[0][0]))
print("With COT", train_data_wcot[0], tokenizer.decode(train_data_wcot[0][0]))

Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
1000 torch.Size([54]) 300 torch.Size([54])
W/o COT (tensor([ 2920,    11,  1485,    91,  1954,    11,  2078,    91,  3510,    11,
         2327,    91,  2624,    11,  3682,    91,  2327,    11,  2624,    91,
         3682,    11,  1828,    91,  1485,    11,  1954,    91,  3510,    11,
         2920,    14,  3510,    11,  1828,    28,  3510,    11,  2327,    11,
         2624,    11,  3682,    11,  1828, 50256, 50256, 50256, 50256, 50256,
        50256, 50256, 50256, 50256], device='cuda:0'), tensor([   -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,    -1,
           -1,    -1,    -1,    -1,    -1,  3510,    11,  2327,    11,  2624,
           11,  3682,    11,  1828, 50256, 50256, 50256, 50256, 50256, 

In [5]:
train_data.num_prefix_tokens, train_data.num_target_tokens, train_data.num_tokens, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens, train_data_wcot.num_tokens, test_data.num_prefix_tokens, test_data.num_target_tokens, test_data.num_tokens

(36, 19, 55, 36, 19, 55, 36, 19, 55)

In [6]:
from torch.utils.data import ConcatDataset
concatenated_train_data = ConcatDataset([train_data, train_data_wcot])
train_loader = DataLoader(concatenated_train_data, batch_size=args.batch_size, shuffle=True)
# train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
len(train_loader), len(test_loader)

(28, 3)

In [7]:
def evaluate(model, loader, tokenizer, ctx, num_prefix_tokens, y_tokens, num_target_tokens):
    model.eval()
    total_acc = AverageMeter()
    bar = tqdm(loader)
    pad_length = num_target_tokens - y_tokens
    c = 0
    # print(pad_length)
    for x in bar:
        y = x[:, num_prefix_tokens:].clone()
        x = x[:, :num_prefix_tokens].clone()
        with ctx:
            y_pred = model.generate(x, max_new_tokens=num_target_tokens,
                                do_sample=False, attention_mask = torch.ones_like(x), pad_token_id=2)
        completely_correct = 0 
        for i in range(y.shape[0]):
            # if c == 0 and i == 0:    
                # print("y:", tokenizer.decode(y[i, -y_tokens:]), "ypred:", tokenizer.decode(y_pred[i, -y_tokens:]))
                # print("y:", tokenizer.decode(y[i, -y_tokens:]), "ypred:", tokenizer.decode(y_pred[i, -y_tokens-pad_length:-pad_length]))
                # print("y_pred:", tokenizer.decode(y_pred[i]), "y:", tokenizer.decode(y[i]))
            completely_correct += int(tokenizer.decode(y[i]) in tokenizer.decode(y_pred[i]))
        completely_correct /= x.shape[0]
        total_acc.update(completely_correct, x.shape[0])
        c += 1
        bar.set_description(f' accuracy: {total_acc.get(percentage=True):.2f}')
    loader.dataset.train()
    model.train()



In [8]:
train_data.num_prefix_tokens + train_data.num_target_tokens

55

In [9]:
args.model = 'gpt2-medium'
args.lr = 1e-4
args.epochs_sft = 20
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(args.model)
model.to(device)
model.train()

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.0)
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)

max_iters = len(train_loader) * args.epochs_sft
lr_decay_iters = max_iters

results = {}
num_iters = 0

for ep in range(args.epochs_sft):
    train_bar = tqdm(train_loader)
    total_loss, total_acc = AverageMeter(), AverageMeter()
    verifier_loss = AverageMeter()
    for x, y in train_bar:
        with ctx:
            # logits, loss, accs = model(x, y)
            y_tokens = train_data.num_target_tokens - train_data.pad_length
            logits = model(x)['logits']
            # print(x.shape, y.shape, logits.shape)
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
            acc = torch.mean((torch.argmax(logits[:, -y_tokens-train_data.pad_length:-train_data.pad_length, :], dim=-1) == y[:, -y_tokens-train_data.pad_length:-train_data.pad_length]).float())
            
            # if num_iters % 100 == 0:
            #     print(tokenizer.decode(y[0, -y_tokens-train_data.pad_length:]), tokenizer.decode(torch.argmax(logits[0, -y_tokens-train_data.pad_length:, :], dim=-1)))
            
        # verifier 
        if ep > 5000:
            num_prefix_tokens=train_data.num_prefix_tokens
            num_target_tokens=train_data.num_target_tokens
            y_tokens=train_data.num_target_tokens-train_data.pad_length
            x = torch.cat([x, y[:, -1:]], dim=1)
            y = x[:, -y_tokens:].clone()
            x = x[:, :num_prefix_tokens].clone()
            y_pred = []
            scores = []
            for i in range(args.n_samples):
                with ctx:
                    _y_pred = model.generate(x, min_length=num_prefix_tokens+num_target_tokens, 
                                max_length=num_prefix_tokens+num_target_tokens, temperature=1., top_p=0.99,
                                do_sample=True, attention_mask = torch.ones_like(x), pad_token_id=2)
                _scores = []
                for i in range(y.shape[0]):
                    _scores.append(int(tokenizer.decode(y[i]) in tokenizer.decode(_y_pred[i])))
                y_pred.append(_y_pred)
                scores.append(_scores)
            scores = torch.tensor(scores).float().transpose(0, 1)
            y_pred = torch.stack(y_pred, dim=0).transpose(0, 1)
            y_pred = y_pred.reshape(-1, y_pred.shape[-1]) 
            sign = scores.reshape(-1) * 2 - 1.
            # sign = scores.reshape(-1)
            # sign = scores.reshape(-1) - 1. 
        
            x = x.unsqueeze(0).repeat(args.n_samples, 1, 1).transpose(0, 1).reshape(-1, x.shape[-1])
            x_new = torch.cat([x, y_pred], dim=1)
            y_new = x_new.clone()
            y_new[:, :x.shape[1]] = -1
            x_new = x_new[:, :-1].clone()
            y_new = y_new[:, 1:].clone()
            sign = sign.cuda()
            with ctx:
                logits = model(x_new)['logits']
                vloss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_new.view(-1), ignore_index=-1, reduction='none')
                # print(vloss)
                # vloss.sigmoid_()
                vloss = torch.sigmoid(0.1 * vloss)
                sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1]).reshape(-1)
                vloss = vloss * sign
                vloss = vloss.mean()
        else:
            vloss = torch.tensor(0.)
        verifier_loss.update(vloss.item(), x.shape[0] * train_data.num_target_tokens)
        total_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        total_acc.update(acc.item(), x.shape[0] * train_data.num_target_tokens)
        scaler.scale(loss + 0.1 * vloss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        num_iters += 1
        train_bar.set_description(
            'Epoch: [{}/{}] Loss: {:.4f} Verifier_loss: {:.2f} Acc: {:.2f}'.format(ep, args.epochs_sft, total_loss.get(), verifier_loss.get(),
             total_acc.get(percentage=True))
        )
        verifier_loss = AverageMeter()
                

    # evaluate the loss on train/val sets and write checkpoints
    print(f"Epoch {ep} completed")
    print("Train without COT")
    loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    loader.dataset.eval()
    evaluate(model, loader, tokenizer, ctx, num_prefix_tokens=train_data.num_prefix_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length, num_target_tokens=train_data.num_target_tokens)
    loader.dataset.train()
    print("Train with COT")
    loader_wcot = torch.utils.data.DataLoader(train_data_wcot, batch_size=args.batch_size, shuffle=False)
    loader_wcot.dataset.eval()
    evaluate(model, loader_wcot, tokenizer, ctx, num_prefix_tokens=train_data_wcot.num_prefix_tokens, y_tokens=train_data_wcot.num_target_tokens-train_data.pad_length, num_target_tokens=train_data_wcot.num_target_tokens)
    loader_wcot.dataset.train()
    print("Test without COT")
    loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    loader_eval.dataset.eval()
    evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, num_target_tokens=test_data.num_target_tokens)
    loader_eval.dataset.train()


Epoch: [0/20] Loss: 0.7631 Verifier_loss: 0.00 Acc: 78.21: 100%|██████████| 28/28 [00:06<00:00,  4.13it/s]


Epoch 0 completed
Train without COT


 accuracy: 54.80: 100%|██████████| 21/21 [00:09<00:00,  2.22it/s]


Train with COT


 accuracy: 0.00: 100%|██████████| 7/7 [00:02<00:00,  2.34it/s]


Test without COT


 accuracy: 46.00: 100%|██████████| 3/3 [00:01<00:00,  2.75it/s]
Epoch: [1/20] Loss: 0.2399 Verifier_loss: 0.00 Acc: 85.15: 100%|██████████| 28/28 [00:06<00:00,  4.62it/s]


Epoch 1 completed
Train without COT


 accuracy: 62.10: 100%|██████████| 21/21 [00:08<00:00,  2.34it/s]


Train with COT


 accuracy: 0.67: 100%|██████████| 7/7 [00:02<00:00,  2.42it/s]


Test without COT


 accuracy: 48.00: 100%|██████████| 3/3 [00:01<00:00,  2.33it/s]
Epoch: [2/20] Loss: 0.1744 Verifier_loss: 0.00 Acc: 87.80: 100%|██████████| 28/28 [00:06<00:00,  4.61it/s]


Epoch 2 completed
Train without COT


 accuracy: 68.90: 100%|██████████| 21/21 [00:08<00:00,  2.38it/s]


Train with COT


 accuracy: 1.67: 100%|██████████| 7/7 [00:03<00:00,  2.32it/s]


Test without COT


 accuracy: 52.00: 100%|██████████| 3/3 [00:01<00:00,  2.77it/s]
Epoch: [3/20] Loss: 0.1433 Verifier_loss: 0.00 Acc: 89.37: 100%|██████████| 28/28 [00:06<00:00,  4.60it/s]


Epoch 3 completed
Train without COT


 accuracy: 76.50: 100%|██████████| 21/21 [00:09<00:00,  2.28it/s]


Train with COT


 accuracy: 4.67: 100%|██████████| 7/7 [00:03<00:00,  2.30it/s]


Test without COT


 accuracy: 51.00: 100%|██████████| 3/3 [00:01<00:00,  2.60it/s]
Epoch: [4/20] Loss: 0.1169 Verifier_loss: 0.00 Acc: 91.03: 100%|██████████| 28/28 [00:06<00:00,  4.61it/s]


Epoch 4 completed
Train without COT


 accuracy: 80.90: 100%|██████████| 21/21 [00:08<00:00,  2.36it/s]


Train with COT


 accuracy: 18.00: 100%|██████████| 7/7 [00:03<00:00,  2.27it/s]


Test without COT


 accuracy: 49.00: 100%|██████████| 3/3 [00:01<00:00,  2.80it/s]
Epoch: [5/20] Loss: 0.1010 Verifier_loss: 0.00 Acc: 92.37: 100%|██████████| 28/28 [00:06<00:00,  4.60it/s]


Epoch 5 completed
Train without COT


 accuracy: 85.70: 100%|██████████| 21/21 [00:08<00:00,  2.35it/s]


Train with COT


 accuracy: 37.33: 100%|██████████| 7/7 [00:03<00:00,  2.27it/s]


Test without COT


 accuracy: 46.00: 100%|██████████| 3/3 [00:01<00:00,  2.74it/s]
Epoch: [6/20] Loss: 0.0693 Verifier_loss: 0.00 Acc: 94.82: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 6 completed
Train without COT


 accuracy: 91.40: 100%|██████████| 21/21 [00:08<00:00,  2.40it/s]


Train with COT


 accuracy: 64.00: 100%|██████████| 7/7 [00:03<00:00,  2.33it/s]


Test without COT


 accuracy: 56.00: 100%|██████████| 3/3 [00:01<00:00,  2.30it/s]
Epoch: [7/20] Loss: 0.0645 Verifier_loss: 0.00 Acc: 95.03: 100%|██████████| 28/28 [00:06<00:00,  4.61it/s]


Epoch 7 completed
Train without COT


 accuracy: 92.80: 100%|██████████| 21/21 [00:08<00:00,  2.46it/s]


Train with COT


 accuracy: 62.67: 100%|██████████| 7/7 [00:03<00:00,  2.30it/s]


Test without COT


 accuracy: 48.00: 100%|██████████| 3/3 [00:01<00:00,  2.84it/s]
Epoch: [8/20] Loss: 0.0480 Verifier_loss: 0.00 Acc: 96.25: 100%|██████████| 28/28 [00:06<00:00,  4.60it/s]


Epoch 8 completed
Train without COT


 accuracy: 95.50: 100%|██████████| 21/21 [00:07<00:00,  2.79it/s]


Train with COT


 accuracy: 55.67: 100%|██████████| 7/7 [00:02<00:00,  2.38it/s]


Test without COT


 accuracy: 49.00: 100%|██████████| 3/3 [00:01<00:00,  2.91it/s]
Epoch: [9/20] Loss: 0.0403 Verifier_loss: 0.00 Acc: 96.91: 100%|██████████| 28/28 [00:06<00:00,  4.58it/s]


Epoch 9 completed
Train without COT


 accuracy: 96.20: 100%|██████████| 21/21 [00:07<00:00,  2.96it/s]


Train with COT


 accuracy: 72.33: 100%|██████████| 7/7 [00:03<00:00,  2.27it/s]


Test without COT


 accuracy: 54.00: 100%|██████████| 3/3 [00:01<00:00,  2.71it/s]
Epoch: [10/20] Loss: 0.0487 Verifier_loss: 0.00 Acc: 96.29: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 10 completed
Train without COT


 accuracy: 95.30: 100%|██████████| 21/21 [00:08<00:00,  2.46it/s]


Train with COT


 accuracy: 77.67: 100%|██████████| 7/7 [00:03<00:00,  2.28it/s]


Test without COT


 accuracy: 57.00: 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]
Epoch: [11/20] Loss: 0.0360 Verifier_loss: 0.00 Acc: 97.32: 100%|██████████| 28/28 [00:06<00:00,  4.60it/s]


Epoch 11 completed
Train without COT


 accuracy: 96.60: 100%|██████████| 21/21 [00:06<00:00,  3.12it/s]


Train with COT


 accuracy: 83.00: 100%|██████████| 7/7 [00:03<00:00,  2.30it/s]


Test without COT


 accuracy: 62.00: 100%|██████████| 3/3 [00:01<00:00,  2.69it/s]
Epoch: [12/20] Loss: 0.0299 Verifier_loss: 0.00 Acc: 97.81: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 12 completed
Train without COT


 accuracy: 98.60: 100%|██████████| 21/21 [00:07<00:00,  2.65it/s]


Train with COT


 accuracy: 93.33: 100%|██████████| 7/7 [00:03<00:00,  2.28it/s]


Test without COT


 accuracy: 61.00: 100%|██████████| 3/3 [00:01<00:00,  2.67it/s]
Epoch: [13/20] Loss: 0.0283 Verifier_loss: 0.00 Acc: 97.80: 100%|██████████| 28/28 [00:06<00:00,  4.61it/s]


Epoch 13 completed
Train without COT


 accuracy: 97.80: 100%|██████████| 21/21 [00:07<00:00,  2.99it/s]


Train with COT


 accuracy: 92.33: 100%|██████████| 7/7 [00:03<00:00,  2.27it/s]


Test without COT


 accuracy: 60.00: 100%|██████████| 3/3 [00:01<00:00,  2.27it/s]
Epoch: [14/20] Loss: 0.0249 Verifier_loss: 0.00 Acc: 98.26: 100%|██████████| 28/28 [00:06<00:00,  4.60it/s]


Epoch 14 completed
Train without COT


 accuracy: 99.20: 100%|██████████| 21/21 [00:06<00:00,  3.01it/s]


Train with COT


 accuracy: 97.67: 100%|██████████| 7/7 [00:03<00:00,  2.33it/s] 


Test without COT


 accuracy: 60.00: 100%|██████████| 3/3 [00:01<00:00,  2.64it/s]
Epoch: [15/20] Loss: 0.0187 Verifier_loss: 0.00 Acc: 98.68: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 15 completed
Train without COT


 accuracy: 99.40: 100%|██████████| 21/21 [00:07<00:00,  2.95it/s]


Train with COT


 accuracy: 95.33: 100%|██████████| 7/7 [00:02<00:00,  2.42it/s]


Test without COT


 accuracy: 61.00: 100%|██████████| 3/3 [00:01<00:00,  2.29it/s]
Epoch: [16/20] Loss: 0.0148 Verifier_loss: 0.00 Acc: 98.97: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 16 completed
Train without COT


 accuracy: 99.70: 100%|██████████| 21/21 [00:06<00:00,  3.18it/s]


Train with COT


 accuracy: 97.67: 100%|██████████| 7/7 [00:03<00:00,  2.26it/s] 


Test without COT


 accuracy: 58.00: 100%|██████████| 3/3 [00:01<00:00,  2.76it/s]
Epoch: [17/20] Loss: 0.0114 Verifier_loss: 0.00 Acc: 99.23: 100%|██████████| 28/28 [00:06<00:00,  4.61it/s]


Epoch 17 completed
Train without COT


 accuracy: 99.70: 100%|██████████| 21/21 [00:07<00:00,  2.94it/s]


Train with COT


 accuracy: 98.67: 100%|██████████| 7/7 [00:03<00:00,  2.28it/s]


Test without COT


 accuracy: 66.00: 100%|██████████| 3/3 [00:01<00:00,  2.73it/s]
Epoch: [18/20] Loss: 0.0104 Verifier_loss: 0.00 Acc: 99.26: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 18 completed
Train without COT


 accuracy: 99.80: 100%|██████████| 21/21 [00:06<00:00,  3.17it/s]


Train with COT


 accuracy: 99.67: 100%|██████████| 7/7 [00:03<00:00,  2.31it/s] 


Test without COT


 accuracy: 62.00: 100%|██████████| 3/3 [00:01<00:00,  2.81it/s]
Epoch: [19/20] Loss: 0.0127 Verifier_loss: 0.00 Acc: 99.15: 100%|██████████| 28/28 [00:06<00:00,  4.59it/s]


Epoch 19 completed
Train without COT


 accuracy: 99.80: 100%|██████████| 21/21 [00:06<00:00,  3.31it/s]


Train with COT


 accuracy: 99.33: 100%|██████████| 7/7 [00:03<00:00,  2.26it/s] 


Test without COT


 accuracy: 64.00: 100%|██████████| 3/3 [00:01<00:00,  2.75it/s]


In [56]:
x, y = next(iter(torch.utils.data.DataLoader(train_data)))
outputs = model.generate(
    x[:, :train_data.num_prefix_tokens],
    num_beams=100,
    max_new_tokens=19,
    num_return_sequences=5,
    temperature=1.0,
    attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
)
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


array([['49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=46,35,32,42,22<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
        '49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=46,49,13,23,28,32,42,22<|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
        '49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=46,34,28,23,28,23,28,23,28,23',
        '49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=46,34,37,28,23,46:46,34,37,28',
        '49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=22,42,32,35,46:46,35,32,42,22']],
      dtype='<U198')

In [70]:
scores = []
num_return_sequences = 5
pad_length = train_data.pad_length
for i in range(outputs.shape[0]):
    y_actual = tokenizer.decode(y[i // num_return_sequences, -y_tokens-pad_length:-pad_length]) if y[i // num_return_sequences, -1] == tokenizer.pad_token_id else tokenizer.decode(y[i // num_return_sequences, -y_tokens:])
    y_cot = ",".join(y_actual.split(",")[::-1]) + ":" + y_actual
    scores.append(max(int(y_actual in tokenizer.decode(outputs[i])), int(y_cot in tokenizer.decode(outputs[i]))))
    # print(y_actual, y_cot)
    # scores.append(
    #     int(tokenizer.decode(y[i // num_return_sequences, -y_tokens-pad_length:-pad_length]) in tokenizer.decode(outputs[i])) if tokenizer.decode(y[i // num_return_sequences, -1]) == tokenizer.pad_token_id else
    #     int(tokenizer.decode(y[i // num_return_sequences, -y_tokens-pad_length:]) in tokenizer.decode(outputs[i])),
    # )
scores

[1, 0, 0, 0, 1]

In [71]:
x, y = next(iter(torch.utils.data.DataLoader(train_data_wcot)))
outputs = model.generate(
    x[:, :train_data.num_prefix_tokens],
    num_beams=100,
    max_new_tokens=19,
    num_return_sequences=5,
    temperature=1.0,
    attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
)
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 5)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


array([['30,35|2,30|49,29|15,6|15,49|24,16|6,2|29,24/15,16=16,24,29,49,15:15,49,29,24,16',
        '30,35|2,30|49,29|15,6|15,49|24,16|6,2|29,24/15,16=15,49,29,24,16<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
        '30,35|2,30|49,29|15,6|15,49|24,16|6,2|29,24/15,16=15,30,35,2,30,35,2,30,35,2',
        '30,35|2,30|49,29|15,6|15,49|24,16|6,2|29,24/15,16=15,6,2,30,35<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>',
        '30,35|2,30|49,29|15,6|15,49|24,16|6,2|29,24/15,16=15,24,16,29,49,15:15,49,29,24']],
      dtype='<U194')

In [74]:
scores = []
num_return_sequences = 5
pad_length = train_data.pad_length
for i in range(outputs.shape[0]):
    y_actual = tokenizer.decode(y[i // num_return_sequences, -y_tokens-pad_length:-pad_length]) if y[i // num_return_sequences, -1] == tokenizer.pad_token_id else tokenizer.decode(y[i // num_return_sequences, -y_tokens:])
    y_cot = ",".join(y_actual.split(",")[::-1]) + ":" + y_actual
    scores.append(max(int(y_actual in tokenizer.decode(outputs[i])), int(y_cot in tokenizer.decode(outputs[i]))))
scores

15,49,29,24,16 16,24,29,49,15:15,49,29,24,16
15,49,29,24,16 16,24,29,49,15:15,49,29,24,16
15,49,29,24,16 16,24,29,49,15:15,49,29,24,16
15,49,29,24,16 16,24,29,49,15:15,49,29,24,16
15,49,29,24,16 16,24,29,49,15:15,49,29,24,16


[1, 1, 0, 0, 0]

In [114]:
sft_model = deepcopy(model).cpu()

In [77]:
args.lr = 1e-5

model = model.cuda()
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
for ep in range(10):
    ref_model = deepcopy(model).cuda()
    train_bar = tqdm(train_loader)
    total_loss, total_acc = AverageMeter(), AverageMeter()
    verifier_loss = AverageMeter()
    scores_meter = AverageMeter()
    for x, y in train_bar:
        with ctx:
            # logits, loss, accs = model(x, y)
            logits = model(x)['logits']
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
            acc = torch.mean((torch.argmax(logits[:, -train_data.num_target_tokens, :], dim=-1) == y[:, -train_data.num_target_tokens]).float())

        num_return_sequences = 5
        pad_length = train_data.pad_length
        if ep >= 0:
            outputs = model.generate(
                x[:, :train_data.num_prefix_tokens],
                num_beams=5,
                max_new_tokens=19,
                num_return_sequences=num_return_sequences,
                temperature=1.0,
                attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
                pad_token_id=tokenizer.eos_token_id
            )
            scores = []
            for i in range(outputs.shape[0]):
                y_actual = tokenizer.decode(y[i // num_return_sequences, -y_tokens-pad_length:-pad_length]) if y[i // num_return_sequences, -1] == tokenizer.pad_token_id else tokenizer.decode(y[i // num_return_sequences, -y_tokens:])
                y_cot = ",".join(y_actual.split(",")[::-1]) + ":" + y_actual
                scores.append(max(int(y_actual in tokenizer.decode(outputs[i])), int(y_cot in tokenizer.decode(outputs[i]))))
            scores = torch.tensor(scores).float().cuda()
            x_new = outputs[:, :-1].clone()
            y_new = outputs.clone()
            y_new[:, :train_data.num_prefix_tokens] = -1
            y_new = y_new[:, 1:]
            with ctx:
                logits = model(x_new)['logits']
            vloss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), y_new.reshape(-1), ignore_index=-1, reduction='none')
            with torch.no_grad():
                with ctx:
                    ref_logits = ref_model(x_new)['logits']
                    vloss_ref = torch.nn.functional.cross_entropy(ref_logits.reshape(-1, ref_logits.size(-1)), y_new.reshape(-1), ignore_index=-1, reduction='none')
            sign = (2 * scores - 1.) 
            sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1])
            vloss = vloss.reshape(x_new.shape[0], x_new.shape[1])
            vloss_ref = vloss_ref.reshape(x_new.shape[0], x_new.shape[1])
            # print(vloss.shape, vloss_ref.shape, sign.shape)
            vloss_inside = ((vloss - vloss_ref) * sign).sum(1)
            vloss = -torch.log(torch.sigmoid(vloss_inside))
            vloss = vloss.mean()
        else:
            vloss = torch.tensor(0.)
            scores = torch.tensor(0.)
        verifier_loss.update(vloss.item(), x.shape[0] * train_data.num_target_tokens)
        scores_meter.update(scores.mean().item(), x.shape[0]*num_return_sequences)
        total_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        total_acc.update(acc.item(), x.shape[0] * train_data.num_target_tokens)

        scaler.scale(1. * loss + 1. * vloss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        num_iters += 1
        
        train_bar.set_description(
            'Epoch: [{}/{}] Loss: {:.4f} Verifier_loss: {:.2f} Scores: {:.2f} Acc: {:.2f}'.format(ep, args.epochs_sft, total_loss.get(), verifier_loss.get(), scores_meter.get(),
             total_acc.get(percentage=True))
        )
        
    print(f"Epoch {ep} completed")
    print("Train without COT")
    loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    loader.dataset.eval()
    evaluate(model, loader, tokenizer, ctx, num_prefix_tokens=train_data.num_prefix_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length, num_target_tokens=train_data.num_target_tokens)
    loader.dataset.train()
    print("Train with COT")
    loader_wcot = torch.utils.data.DataLoader(train_data_wcot, batch_size=args.batch_size, shuffle=False)
    loader_wcot.dataset.eval()
    evaluate(model, loader_wcot, tokenizer, ctx, num_prefix_tokens=train_data_wcot.num_prefix_tokens, y_tokens=train_data_wcot.num_target_tokens-train_data.pad_length, num_target_tokens=train_data_wcot.num_target_tokens)
    loader_wcot.dataset.train()
    print("Test without COT")
    loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    loader_eval.dataset.eval()
    evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, num_target_tokens=test_data.num_target_tokens)
    loader_eval.dataset.train()



  0%|          | 0/63 [00:00<?, ?it/s]

Epoch: [0/20] Loss: 0.4924 Verifier_loss: 21.59 Scores: 0.04 Acc: 95.52:  84%|████████▍ | 53/63 [01:05<00:12,  1.24s/it]