In [1]:
import argparse
from contextlib import nullcontext
import torch
from tqdm import tqdm

from data import get_dataset
from utils.training_utils import get_lr, get_run_name, AverageMeter
from torch.utils.data import DataLoader
from models import get_model
from tokenizing import get_tokenizer
import wandb
from pprint import pprint
import numpy as np
from datasets import Dataset
from copy import deepcopy
from transformers import TrainingArguments
from trl import DPOTrainer
from transformers import AutoTokenizer
import os

  from .autonotebook import tqdm as notebook_tqdm


[2024-04-18 04:15:55,698] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
parser = argparse.ArgumentParser(description="Next-token failures")
# Data
parser.add_argument(
    "--n_samples", type=int, default=5, help="Number of samples to generate"
    )
parser.add_argument(
    "--model", default='gpt2', type=str, help="Type of model"
    )
parser.add_argument(
    "--dataset", default='graph', type=str, help="Choice of dataset"
    )
parser.add_argument(
    "--n_train", default=200000, type=int, help="Number of training samples"
    )
parser.add_argument(
    "--n_test", default=500, type=int, help="Number of test samples"
    )
parser.add_argument(
    "--num_nodes", default=50, type=int, help="Number of node values in graph"
    )
parser.add_argument(
    "--deg", default=2, type=int, help="Degree of starting node"
    )
parser.add_argument(
    "--path_len", default=5, type=int, help="Path length in star graph"
    )
parser.add_argument(
        "--mate_in", default=2, type=int, help="For chess, number of moves to checkmate"
    )
parser.add_argument(
        "--unrolled", action=argparse.BooleanOptionalAction, default=True, help="For chess, unrolled board state",
    )
parser.add_argument(
        "--batch_size", type=int, default=64, help="Batch size",
    )
parser.add_argument(
        "--lr", type=float, default=5e-4, help="Learning rate",
    )
parser.add_argument(
        "--weight_decay", type=float, default=0., help="Strength of weight decay",
    )
parser.add_argument(
        "--epochs_sft", type=int, default=1, help="Number of SFT epochs",
    )
parser.add_argument(
        "--save_every", type=int, default=5000, help="Interval (in steps) at which to save model",
    )
parser.add_argument(
        "--pass_at_k", type=int, default=1, help="pass at k eval",
    )
parser.add_argument(
        "--teacherless", action=argparse.BooleanOptionalAction, default=False, help="Standard or teacherless training",
    )
parser.add_argument(
        "--reverse", action=argparse.BooleanOptionalAction, default=False, help="Standard format or reverse targets",
    )
parser.add_argument(
        "--cot", action=argparse.BooleanOptionalAction, default=False, help="Standard format or cot targets",
    )
parser.add_argument(
        "--pos", action=argparse.BooleanOptionalAction, default=False, help="Standard format or pos tokens",
    )
parser.add_argument(
        "--eval_train", action=argparse.BooleanOptionalAction, default=False, help="Eval for training set",
    )
parser.add_argument(
        "--eval_every", type=int, default=400, help="Interval (in steps) to evaluate the model on test",
    )
parser.add_argument(
        "--use_wandb", action=argparse.BooleanOptionalAction, default=False, help="Whether to use wandb",
    )
parser.add_argument(
        "--wandb_entity", type=str, default=5000, help="Wandb username",
    )
parser.add_argument(
    "--pad_length", default=0, type=int, help="Default value for pad length"
)


args = parser.parse_args([])

In [3]:
# System stuff
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wandb_entity = args.wandb_entity
wandb_log = args.use_wandb
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Model stuff
top_k = 1000
temperature = 1.
pass_at_k = args.pass_at_k
n_samples = args.n_samples

# Evaluation stuff
eval_iters = 1000
eval_interval = 5
log_interval = 10

# Optimiser
dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
beta1 = 0.9
beta2 = 0.999
decay_lr = True
args.compile = False if device == 'cuda' else False
args.use_flash = True if device == 'cuda' else False
warmup_iters = 100
min_lr = 1e-5

In [4]:
tokenizer = get_tokenizer(args)

# train data without cot
args.cot = False
args.n_train = 1000
args.n_test = 100
args.pad_length = 0
train_data, test_data = get_dataset(args, tokenizer, device)
train_data.num_target_tokens += train_data.pad_length
test_data.num_target_tokens += test_data.pad_length
train_data.num_tokens += train_data.pad_length
test_data.num_tokens += test_data.pad_length

# train data with cot
args.cot = True
args.n_train = 100
args.pad_length = 0
train_data_wcot, _ = get_dataset(args, tokenizer, device)
train_data_wcot.data_file = train_data_wcot.data_file[-args.n_train:]
train_data_wcot.tokenized, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens = train_data_wcot.tokenizer.tokenize(train_data_wcot.data_file)


print(len(train_data), train_data[0][0].shape, len(train_data_wcot), train_data_wcot[0][0].shape)
print("W/o COT", train_data[0], tokenizer.decode(train_data[0][0]))
print("With COT", train_data_wcot[0], tokenizer.decode(train_data_wcot[0][0]))

Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
Equal sequence lengths!
1000 torch.Size([44]) 100 torch.Size([54])
W/o COT (tensor([2920,   11, 1485,   91, 1954,   11, 2078,   91, 3510,   11, 2327,   91,
        2624,   11, 3682,   91, 2327,   11, 2624,   91, 3682,   11, 1828,   91,
        1485,   11, 1954,   91, 3510,   11, 2920,   14, 3510,   11, 1828,   28,
        3510,   11, 2327,   11, 2624,   11, 3682,   11], device='cuda:0'), tensor([  -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
          -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1, 3510,
          11, 2327,   11, 2624,   11, 3682,   11, 1828], device='cuda:0')) 49,13|23,28|46,35|32,42|35,32|42,22|13,23|46,49/46,22=46,35,32,42,
With COT (tensor([3559,   11, 2682,   91, 1238,   11,   24,   91, 2780,   11,   23,   91,
          24,   11, 2075,   91, 1983,   1

In [5]:
train_data.num_prefix_tokens, train_data.num_target_tokens, train_data.num_tokens, train_data_wcot.num_prefix_tokens, train_data_wcot.num_target_tokens, train_data_wcot.num_tokens, test_data.num_prefix_tokens, test_data.num_target_tokens, test_data.num_tokens

(36, 9, 45, 36, 19, 55, 36, 9, 45)

In [6]:
# from torch.utils.data import ConcatDataset
# concatenated_train_data = ConcatDataset([train_data, train_data_wcot])
# train_loader = DataLoader(concatenated_train_data, batch_size=args.batch_size, shuffle=True)
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True)
len(train_loader), len(test_loader)

(16, 2)

In [7]:
def evaluate(model, loader, tokenizer, ctx, num_prefix_tokens, y_tokens, pad_length):
    model.eval()
    total_acc = AverageMeter()
    bar = tqdm(loader)
    for x in bar:
        original_length = x.shape[1]
        y = x[:, -y_tokens:].clone()
        x = x[:, :num_prefix_tokens].clone()
        with ctx:
            y_pred = model.generate(x, max_new_tokens=original_length + y_tokens, min_length=original_length + y_tokens,
                                do_sample=False, attention_mask = torch.ones_like(x), pad_token_id=2)
        completely_correct = 0 
        for i in range(y.shape[0]):
            completely_correct += int(tokenizer.decode(y[i]) in tokenizer.decode(y_pred[i]))
        completely_correct /= x.shape[0]
        total_acc.update(completely_correct, x.shape[0])
        bar.set_description(f' accuracy: {total_acc.get(percentage=True):.2f}')
    loader.dataset.train()
    model.train()



In [8]:
def generate_and_score(model, loader, tokenizer, ctx, temperature, top_p, n_samples, num_prefix_tokens, num_target_tokens, y_tokens):
    model.eval()
    bar = tqdm(loader)
    x_dataset = []
    y_pred_dataset = []
    scores_dataset = []
    for x in bar:
        original_length = x.shape[1]
        y = x[:, -y_tokens:].clone()
        x = x[:, :num_prefix_tokens].clone()
        y = y
        y_pred = []
        scores = []
        for i in range(n_samples):
            with ctx:
                # print(x.shape, num_target_tokens)
                _y_pred = model.generate(x, min_length=num_prefix_tokens+num_target_tokens, 
                                   max_length=num_prefix_tokens+num_target_tokens, temperature=temperature, top_p=top_p,
                                   do_sample=True, attention_mask = torch.ones_like(x), pad_token_id=2)
            _scores = []
            for i in range(y.shape[0]):
                _scores.append(int(tokenizer.decode(y[i]) in tokenizer.decode(_y_pred[i])))
            # print(_y_pred.shape)
            y_pred.append(_y_pred.cpu())
            scores.append(_scores)
            
        scores = torch.tensor(scores).float().transpose(0, 1)
        y_pred = torch.stack(y_pred, dim=0).transpose(0, 1)
        # y_pred should be bs x nsamples x length
        y_pred_dataset.append(y_pred[:, :, num_prefix_tokens:].cpu())
        scores_dataset.append(scores.cpu())
        # print(scores.shape, y_pred.shape, x.shape)
        x_dataset.append(x.cpu())
        # correct should be bs x nsamples
    # Switch back to train mode
    loader.dataset.train()
    model.train()
    y_pred_dataset = torch.cat(y_pred_dataset, dim=0)
    x_dataset = torch.cat(x_dataset, dim=0)
    scores_dataset = torch.cat(scores_dataset, dim=0)
    return x_dataset, y_pred_dataset, scores_dataset

In [21]:
args.lr = 1e-5
args.model = 'gpt2-large'
args.epochs_sft = 15
from transformers import GPT2LMHeadModel
model = GPT2LMHeadModel.from_pretrained(args.model)
model.to(device)
model.train()
tokenizer.pad_token_id = tokenizer.eos_token_id

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.0)
ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)

max_iters = len(train_loader) * args.epochs_sft
lr_decay_iters = max_iters

results = {}
num_iters = 0

for ep in range(args.epochs_sft):
    train_bar = tqdm(train_loader)
    total_loss, total_acc = AverageMeter(), AverageMeter()
    verifier_loss = AverageMeter()
    scores_meter = AverageMeter()
    for x, y in train_bar:
        # determine and set the learning rate for this iteration
        lr = get_lr(num_iters, args.lr, warmup_iters, lr_decay_iters, min_lr) if decay_lr else args.lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        with ctx:
            # logits, loss, accs = model(x, y)
            logits = model(x)['logits']
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
            acc = torch.mean((torch.argmax(logits[:, -train_data.num_target_tokens, :], dim=-1) == y[:, -train_data.num_target_tokens]).float())

        if ep >= 100:
            outputs = model.generate(
                x[:, :train_data.num_prefix_tokens],
                num_beams=2,
                max_new_tokens=9,
                num_return_sequences=2,
                temperature=1.0,
                attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
                pad_token_id=tokenizer.eos_token_id
            )
            y_tokens=train_data.num_target_tokens-train_data.pad_length
            scores = []
            for i in range(outputs.shape[0]):
                scores.append(int(tokenizer.decode(outputs[i, -y_tokens:]) in tokenizer.decode(y[i // 2, -y_tokens:])))
            scores = torch.tensor(scores).float().cuda()
            x_new = outputs[:, :-1].clone()
            y_new = outputs.clone()
            y_new[:, :train_data.num_prefix_tokens] = -1
            y_new = y_new[:, 1:]
            logits = model(x_new)['logits']
            # print(logits.shape, y_new.shape)
            vloss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), y_new.reshape(-1), ignore_index=-1, reduction='none')
            sign = (2 * scores - 1.)
            sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1])
            vloss = vloss.reshape(x_new.shape[0], x_new.shape[1])
            # print(vloss.shape, sign.shape)
            vloss = (vloss * sign).sum(1).mean(0)
            vloss = -torch.log(torch.sigmoid(vloss))
            vloss = vloss.mean()
        else:
            vloss = torch.tensor(0.)
            scores = torch.tensor(0.)
        verifier_loss.update(vloss.item(), x.shape[0] * train_data.num_target_tokens)
        scores_meter.update(scores.mean().item(), x.shape[0]*2)
        total_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        total_acc.update(acc.item(), x.shape[0] * train_data.num_target_tokens)

        scaler.scale(loss + vloss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        num_iters += 1
        
        train_bar.set_description(
            'Epoch: [{}/{}] Loss: {:.4f} Verifier_loss: {:.2f} Scores: {:.2f} Acc: {:.2f}'.format(ep, args.epochs_sft, total_loss.get(), verifier_loss.get(), scores_meter.get(),
             total_acc.get(percentage=True))
        )
        # if num_iters % 50 == 0:
        #     loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
        #     loader_eval.dataset.eval()
        #     evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, pad_length=test_data.pad_length)
        #     loader_eval.dataset.train()


        
                
            # x = torch.cat([x, y[:, -1:]], dim=1)
            # y = x[:, -y_tokens:].clone()
            # x = x[:, :num_prefix_tokens].clone()
            # y_pred = []
            # scores = []
                        
        # verifier 
        # if ep > 5:
        #     num_prefix_tokens=train_data.num_prefix_tokens
        #     num_target_tokens=train_data.num_target_tokens
        #     y_tokens=train_data.num_target_tokens-train_data.pad_length
        #     x = torch.cat([x, y[:, -1:]], dim=1)
        #     y = x[:, -y_tokens:].clone()
        #     x = x[:, :num_prefix_tokens].clone()
        #     y_pred = []
        #     scores = []
        #     for i in range(args.n_samples):
        #         with ctx:
        #             _y_pred = model.generate(x, min_length=num_prefix_tokens+num_target_tokens, 
        #                         max_length=num_prefix_tokens+num_target_tokens, temperature=1., top_p=0.99,
        #                         do_sample=True, attention_mask = torch.ones_like(x), pad_token_id=2)
        #         _scores = []
        #         for i in range(y.shape[0]):
        #             _scores.append(int(tokenizer.decode(y[i]) in tokenizer.decode(_y_pred[i])))
        #         y_pred.append(_y_pred)
        #         scores.append(_scores)
        #     scores = torch.tensor(scores).float().transpose(0, 1)
        #     y_pred = torch.stack(y_pred, dim=0).transpose(0, 1)
        #     y_pred = y_pred.reshape(-1, y_pred.shape[-1]) 
        #     sign = scores.reshape(-1) * 2 - 1.
        #     # sign = scores.reshape(-1)
        #     # sign = scores.reshape(-1) - 1. 
        
        #     x = x.unsqueeze(0).repeat(args.n_samples, 1, 1).transpose(0, 1).reshape(-1, x.shape[-1])
        #     x_new = torch.cat([x, y_pred], dim=1)
        #     y_new = x_new.clone()
        #     y_new[:, :x.shape[1]] = -1
        #     x_new = x_new[:, :-1].clone()
        #     y_new = y_new[:, 1:].clone()
        #     sign = sign.cuda()
        #     with ctx:
        #         logits = model(x_new)['logits']
        #         vloss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_new.view(-1), ignore_index=-1, reduction='none')
        #         # print(vloss)
        #         # vloss.sigmoid_()
        #         vloss = torch.sigmoid(0.1 * vloss)
        #         sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1]).reshape(-1)
        #         vloss = vloss * sign
        #         vloss = vloss.mean()
        # else:
        #   vloss = torch.tensor(0.)
        # verifier_loss.update(vloss.item(), x.shape[0] * train_data.num_target_tokens)


        # verifier_loss = AverageMeter()
        

    # if ep > 1:
    #     print("Generating samples")
    #     loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    #     loader.dataset.eval()
    #     prompts, generations, scores = generate_and_score(
    #         model, loader, tokenizer, ctx, temperature=0.5, top_p=0.9, n_samples=3, num_prefix_tokens=train_data.num_prefix_tokens, 
    #         num_target_tokens=train_data.num_target_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length)
    #     loader.dataset.train() 
    #     flat_generations = generations.reshape(-1, generations.shape[-1]) 
    #     flat_scores = scores.reshape(-1) * 2 - 1.
    #     flat_prompts = prompts.unsqueeze(0).repeat(3, 1, 1).transpose(0, 1).reshape(-1, prompts.shape[-1])
    #     pref_dataset = torch.utils.data.TensorDataset(flat_prompts, flat_generations, flat_scores)
        # print("Training verifier")
        # verifier_loss = AverageMeter()
        # for _ in range(2):
        #     loader_bar = tqdm(torch.utils.data.DataLoader(pref_dataset, batch_size=args.batch_size//2, shuffle=True))
        #     for x, y, sign in loader_bar:
        #         x = x.cuda()
        #         y = y.cuda()
        #         x_new = torch.cat([x, y], dim=1)
        #         y_new = x_new.clone()
        #         y_new[:, :x.shape[1]] = -1
        #         x_new = x_new[:, :-1].clone()
        #         y_new = y_new[:, 1:].clone()
        #         sign = sign.cuda()
        #         with ctx:
        #             logits = model(x_new)['logits']
        #             loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y_new.view(-1), ignore_index=-1, reduction='none')
        #             sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1]).reshape(-1)
        #             loss *= sign
        #             loss = loss.mean()
        #         verifier_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        #         scaler.scale(loss).backward()
        #         scaler.step(optimizer)
        #         scaler.update()
        #         optimizer.zero_grad(set_to_none=True)
        #         loader_bar.set_description(
        #             'Verifier Loss: {:.4f}'.format(verifier_loss.get()),
        #         )

                

    # evaluate the loss on train/val sets and write checkpoints
    # print(f"Epoch {ep} completed")
    # print("Train without COT")
    # loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    # loader.dataset.eval()
    # evaluate(model, loader, tokenizer, ctx, num_prefix_tokens=train_data.num_prefix_tokens, y_tokens=train_data.num_target_tokens-train_data.pad_length, pad_length=train_data.pad_length)
    # loader.dataset.train()
    # print("Train with COT")
    # loader_wcot = torch.utils.data.DataLoader(train_data_wcot, batch_size=args.batch_size, shuffle=False)
    # loader_wcot.dataset.eval()
    # evaluate(model, loader_wcot, tokenizer, ctx, num_prefix_tokens=train_data_wcot.num_prefix_tokens, y_tokens=train_data_wcot.num_target_tokens-train_data.pad_length, pad_length=0)
    # loader_wcot.dataset.train()
    # print("Test without COT")
    loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    loader_eval.dataset.eval()
    evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, pad_length=test_data.pad_length)
    loader_eval.dataset.train()


Epoch: [0/15] Loss: 1.0472 Verifier_loss: 0.00 Scores: 0.00 Acc: 29.80: 100%|██████████| 32/32 [00:08<00:00,  3.61it/s]
 accuracy: 54.00: 100%|██████████| 2/2 [00:03<00:00,  1.73s/it]
Epoch: [1/15] Loss: 0.1814 Verifier_loss: 0.00 Scores: 0.00 Acc: 98.40: 100%|██████████| 32/32 [00:08<00:00,  3.61it/s]
 accuracy: 53.00: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it]
Epoch: [2/15] Loss: 0.1210 Verifier_loss: 0.00 Scores: 0.00 Acc: 100.00: 100%|██████████| 32/32 [00:08<00:00,  3.62it/s]
 accuracy: 64.00: 100%|██████████| 2/2 [00:03<00:00,  1.79s/it]
Epoch: [3/15] Loss: 0.0935 Verifier_loss: 0.00 Scores: 0.00 Acc: 100.00: 100%|██████████| 32/32 [00:08<00:00,  3.62it/s]
 accuracy: 68.00: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it]
Epoch: [4/15] Loss: 0.0781 Verifier_loss: 0.00 Scores: 0.00 Acc: 100.00: 100%|██████████| 32/32 [00:08<00:00,  3.61it/s]
 accuracy: 59.00: 100%|██████████| 2/2 [00:03<00:00,  1.80s/it]
Epoch: [5/15] Loss: 0.0593 Verifier_loss: 0.00 Scores: 0.00 Acc: 100.00: 100%

In [10]:
loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
loader_eval.dataset.eval()
evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, pad_length=test_data.pad_length)
loader_eval.dataset.train()


 accuracy: 56.00: 100%|██████████| 2/2 [00:03<00:00,  1.65s/it]


In [11]:
args.batch_size

64

In [12]:
# loader_eval = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
# # loader_eval.dataset.eval()
# # evaluate(model, loader_eval, tokenizer, ct# x, num_prefix_tokens=train_data.num_prefix_tokens, y_tokens=train_data.num_target_tokens-test_data.pad_length, pad_length=train_data.pad_length)
loader_eval.dataset.train()


In [13]:
next(train_loader.__iter__())

[tensor([[3559,   11, 1433,  ...,   11, 1731,   11],
         [3132,   11, 1065,  ...,   11, 1065,   11],
         [1433,   11, 2091,  ...,   11, 2598,   11],
         ...,
         [1558,   11, 2231,  ...,   11, 1558,   11],
         [  22,   11, 2091,  ...,   11, 1433,   11],
         [2718,   11, 3682,  ...,   11, 2682,   11]], device='cuda:0'),
 tensor([[  -1,   -1,   -1,  ..., 1731,   11, 1983],
         [  -1,   -1,   -1,  ..., 1065,   11,   18],
         [  -1,   -1,   -1,  ..., 2598,   11, 2327],
         ...,
         [  -1,   -1,   -1,  ..., 1558,   11, 2231],
         [  -1,   -1,   -1,  ..., 1433,   11,   20],
         [  -1,   -1,   -1,  ..., 2682,   11, 1065]], device='cuda:0')]

In [14]:
outputs = model.generate(
    x[:, :train_data.num_prefix_tokens],
    num_beams=5,
    max_new_tokens=9,
    num_return_sequences=5,
    temperature=1.0,
    attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [15]:
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 5)[2]

array(['15,35|35,5|20,11|16,24|20,15|11,16|24,2|5,31/20,31=20,15,35,5,31',
       '15,35|35,5|20,11|16,24|20,15|11,16|24,2|5,31/20,31=20,11,16,24,2',
       '15,35|35,5|20,11|16,24|20,15|11,16|24,2|5,31/20,31=20,5,31,35,5',
       '15,35|35,5|20,11|16,24|20,15|11,16|24,2|5,31/20,31=20,15,35,5,30',
       '15,35|35,5|20,11|16,24|20,15|11,16|24,2|5,31/20,31=20,25,35,5,31'],
      dtype='<U69')

In [16]:
sft_model = deepcopy(model.cpu())

In [17]:
args.lr = 1e-6


model = model.cuda()
sft_model = sft_model.cuda()
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# train_loader.batch_size = 16
for ep in range(10):
    train_bar = tqdm(train_loader)
    total_loss, total_acc = AverageMeter(), AverageMeter()
    verifier_loss = AverageMeter()
    scores_meter = AverageMeter()
    for x, y in train_bar:
        # determine and set the learning rate for this iteration
        lr = get_lr(num_iters, args.lr, warmup_iters, lr_decay_iters, min_lr) if decay_lr else args.lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        with ctx:
            # logits, loss, accs = model(x, y)
            logits = model(x)['logits']
            loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1)
            acc = torch.mean((torch.argmax(logits[:, -train_data.num_target_tokens, :], dim=-1) == y[:, -train_data.num_target_tokens]).float())

        if ep >= 0:
            outputs = model.generate(
                x[:, :train_data.num_prefix_tokens],
                num_beams=2,
                max_new_tokens=9,
                num_return_sequences=2,
                temperature=1.0,
                attention_mask=torch.ones_like(x[:, :train_data.num_prefix_tokens]),
                pad_token_id=tokenizer.eos_token_id
            )
            y_tokens=train_data.num_target_tokens-train_data.pad_length
            scores = []
            for i in range(outputs.shape[0]):
                scores.append(int(tokenizer.decode(outputs[i, -y_tokens:]) in tokenizer.decode(y[i // 2, -y_tokens:])))
            scores = torch.tensor(scores).float().cuda()
            x_new = outputs[:, :-1].clone()
            y_new = outputs.clone()
            y_new[:, :train_data.num_prefix_tokens] = -1
            y_new = y_new[:, 1:]
            with ctx:
                logits = model(x_new)['logits']
            # print(logits.shape, y_new.shape)
            vloss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.size(-1)), y_new.reshape(-1), ignore_index=-1, reduction='none')
            with torch.no_grad():
                with ctx:
                    ref_logits = sft_model(x_new)['logits']
                    vloss_ref = torch.nn.functional.cross_entropy(ref_logits.reshape(-1, ref_logits.size(-1)), y_new.reshape(-1), ignore_index=-1, reduction='none')
            sign = (2 * scores - 1.) 
            sign = sign.unsqueeze(1).repeat(1, x_new.shape[-1])
            vloss = vloss.reshape(x_new.shape[0], x_new.shape[1])
            vloss_ref = vloss_ref.reshape(x_new.shape[0], x_new.shape[1])
            # print(vloss.shape, sign.shape)
            vloss_inside = ((vloss - vloss_ref) * sign).sum(1)
            vloss = -torch.log(torch.sigmoid(vloss_inside))
            vloss = vloss.mean()
        else:
            vloss = torch.tensor(0.)
            scores = torch.tensor(0.)
        verifier_loss.update(vloss_inside.item(), x.shape[0] * train_data.num_target_tokens)
        scores_meter.update(scores.mean().item(), x.shape[0]*10)
        total_loss.update(loss.item(), x.shape[0] * train_data.num_target_tokens)
        total_acc.update(acc.item(), x.shape[0] * train_data.num_target_tokens)

        scaler.scale(1. * loss + 1. * vloss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        num_iters += 1
        
        train_bar.set_description(
            'Epoch: [{}/{}] Loss: {:.4f} Verifier_loss: {:.2f} Scores: {:.2f} Acc: {:.2f}'.format(ep, args.epochs_sft, total_loss.get(), verifier_loss.get(), scores_meter.get(),
             total_acc.get(percentage=True))
        )
        if num_iters % 50 == 0:
            loader_eval = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
            loader_eval.dataset.eval()
            evaluate(model, loader_eval, tokenizer, ctx, num_prefix_tokens=test_data.num_prefix_tokens, y_tokens=test_data.num_target_tokens-test_data.pad_length, pad_length=test_data.pad_length)
            loader_eval.dataset.train()



 accuracy: 64.00: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it]: 98.44:  28%|██▊       | 9/32 [00:13<00:30,  1.31s/it] 
Epoch: [0/15] Loss: 0.1235 Verifier_loss: 0.36 Scores: 0.48 Acc: 99.50: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it]
 accuracy: 55.00: 100%|██████████| 2/2 [00:03<00:00,  1.77s/it]: 100.00:  84%|████████▍ | 27/32 [00:36<00:06,  1.33s/it]
Epoch: [1/15] Loss: 0.1292 Verifier_loss: 0.28 Scores: 0.49 Acc: 100.00: 100%|██████████| 32/32 [00:44<00:00,  1.40s/it]
Epoch: [2/15] Loss: 0.1192 Verifier_loss: 0.27 Scores: 0.49 Acc: 100.00: 100%|██████████| 32/32 [00:41<00:00,  1.30s/it]
 accuracy: 53.00: 100%|██████████| 2/2 [00:03<00:00,  1.87s/it]: 100.00:  41%|████      | 13/32 [00:18<00:24,  1.31s/it]
Epoch: [3/15] Loss: 0.1236 Verifier_loss: 0.26 Scores: 0.50 Acc: 100.00: 100%|██████████| 32/32 [00:45<00:00,  1.41s/it]
 accuracy: 54.00: 100%|██████████| 2/2 [00:03<00:00,  1.74s/it]: 100.00:  97%|█████████▋| 31/32 [00:41<00:01,  1.31s/it]
Epoch: [4/15] Loss: 0.1242 Verifie

In [23]:
-torch.log(torch.sigmoid(vloss))

tensor(0.6931)

In [19]:
np.array(tokenizer.batch_decode(outputs)).reshape(x.shape[0], 2)

array([['27,11|22,43|27,1|11,44|36,41|1,22|44,36|43,31/27,41=27,1,22,43,31',
        '27,11|22,43|27,1|11,44|36,41|1,22|44,36|43,31/27,41=27,11,44,36,41'],
       ['26,29|34,31|31,26|29,22|36,4|10,36|38,10|34,38/34,22=34,31,26,29,22',
        '26,29|34,31|31,26|29,22|36,4|10,36|38,10|34,38/34,22=34,38,10,36,4'],
       ['11,33|4,15|31,43|44,41|36,11|15,31|36,4|33,44/36,43=36,4,15,31,43',
        '11,33|4,15|31,43|44,41|36,11|15,31|36,4|33,44/36,43=36,11,33,44,41'],
       ['30,49|35,38|2,23|34,30|46,35|2,34|23,46|49,20/2,20=2,23,46,35,38',
        '30,49|35,38|2,23|34,30|46,35|2,34|23,46|49,20/2,20=2,34,30,49,20'],
       ['26,49|29,9|36,12|9,26|29,28|44,36|28,44|49,18/29,18=29,28,44,36,12',
        '26,49|29,9|36,12|9,26|29,28|44,36|28,44|49,18/29,18=29,9,26,49,18'],
       ['37,47|7,48|7,9|48,5|44,25|9,44|25,11|5,37/7,11=7,9,44,25,11',
        '37,47|7,48|7,9|48,5|44,25|9,44|25,11|5,37/7,11=7,48,5,37,47'],
       ['0,7|13,1|23,45|1,23|43,31|7,43|26,0|26,13/26,31=26,0,7,43,31',
      

In [20]:
tokenizer.batch_decode(next(test_loader.__iter__())[0])

['6,14|19,25|14,19|18,32|32,39|39,20|18,6|20,42/18,42=18,32,39,20,',
 '35,3|3,22|22,12|12,31|43,20|49,5|5,43|35,49/35,20=35,49,5,43,',
 '40,27|27,45|7,43|30,28|40,7|6,47|45,30|43,6/40,47=40,7,43,6,',
 '28,32|7,37|37,36|36,42|8,6|32,8|6,41|28,7/28,42=28,7,37,36,',
 '21,1|39,26|35,27|26,21|28,39|27,48|28,35|48,2/28,1=28,39,26,21,',
 '0,23|49,18|20,1|8,7|39,0|23,8|39,49|18,20/39,1=39,49,18,20,',
 '26,9|23,26|18,46|25,13|46,25|3,23|4,3|4,18/4,13=4,18,46,25,',
 '21,45|38,27|27,4|4,35|45,5|9,38|9,21|5,1/9,35=9,38,27,4,',
 '25,24|24,42|23,45|48,36|45,44|23,25|44,48|42,14/23,36=23,45,44,48,',
 '12,26|32,34|46,12|48,30|26,48|42,32|46,31|31,42/46,30=46,12,26,48,',
 '3,7|26,3|7,10|12,38|10,45|26,35|38,47|35,12/26,47=26,35,12,38,',
 '36,23|5,18|45,26|38,5|18,37|38,45|26,36|37,42/38,23=38,45,26,36,',
 '19,44|5,30|5,15|1,36|12,1|44,29|30,19|15,12/5,36=5,15,12,1,',
 '30,10|29,32|24,12|35,29|24,46|12,19|46,35|19,30/24,32=24,46,35,29,',
 '8,48|46,34|29,28|48,29|24,46|34,25|25,33|24,8/24,28=24,8,48,29,'