In [1]:
import sys
sys.path.append('../../')
import time
from jax_smi import initialise_tracking
initialise_tracking()
from tqdm import tqdm

In [2]:
from dataclasses import dataclass, field, asdict
from typing import Tuple, Optional, Union
from EasyLM.models.gpt2.gpt2_model import GPT, GPTConfig, get_pretrained_params
from torch.utils.data import DataLoader

@dataclass(frozen=True)
class WandbConfig:
    """Immutable wandb logging configuration (run destination, naming and mode)."""

    # username or team name where you're sending runs
    entity: str = field(default='ars22')
    # project name
    project: str = field(default='star_graph')
    # experiment name
    name: str = field(default='gpt2')
    # one of 'offline', 'online', or 'disabled'
    mode: str = field(default='online')
    # free-form notes for the run
    notes: str = field(default='')


@dataclass(frozen=True)
class CosineDecayScheduleConfig:
    """Immutable hyper-parameters for a warmup + cosine-decay learning-rate schedule.

    The field names match the keyword arguments of
    ``optax.warmup_cosine_decay_schedule``; presumably the config is unpacked into
    it (that call is not shown in this notebook — TODO confirm).
    """

    init_value: float = field(default=0.0)
    peak_value: float = field(default=2.5e-4)
    warmup_steps: int = field(default=2000)
    decay_steps: int = field(default=150000)
    end_value: float = field(default=1e-5)

@dataclass(frozen=True)
class StaticLRConfig:
    """Constant learning rate: ``init_value`` is passed unchanged to the optimizer."""

    init_value: float = field(default=1e-4)


@dataclass(frozen=False)
class TrainConfig:
    """Top-level training configuration for this notebook.

    Deliberately not frozen: init_train_state may overwrite ``model`` with the
    config returned by ``get_pretrained_params``.
    """
    gpt2_model_type: str = 'gpt2' # HF model id: used for the tokenizer, the pretrained weights and the HF generation model
    seed: int = 555  # seeds the top-level JAX PRNG key
    out_dir: str = 'out'                        # output directory for checkpoints (can be gcs path); NOTE(review): not read by any visible cell
    shuffle_buffer_size: int = 128  # NOTE(review): not read by any visible cell
    eval_interval: int = 500  # evaluate every this many steps (the on-policy loop restarts its step counter each epoch)
    eval_steps: int = 16        # evaluate for this number of steps (per-device); NOTE(review): not read by any visible cell
    eval_only: bool = False     # if True, script exits right after the first eval; NOTE(review): not read by any visible cell
    keep_checkpoints: int = 3   # number of historical checkpoints to keep
    batch_size: int = 128        # DataLoader batch size, i.e. the whole batch that is then split across local devices (not per-device)
    train_steps: int = 30     # total number of training iterations; NOTE(review): the on-policy loop runs a fixed number of epochs instead
    weight_decay: float = 0.0  # not applied to `bias`, `embedding` or `scale` parameters (see param_decay_mask)
    grad_clip: float = 1.0      # gradient norm clipping magnitude
    gradient_accumulation_steps: int = 1    # passed to optax.apply_every (placed after adamw in the optimizer chain)
    betas: Tuple[float, float] = (0.9, 0.95) # adamw optimizer betas
    # learning_rate: CosineDecayScheduleConfig = field(default_factory=CosineDecayScheduleConfig)
    learning_rate: StaticLRConfig = field(default_factory=StaticLRConfig)  # only `.init_value` is read
    wandb: WandbConfig = field(default_factory=WandbConfig) # wandb logging
    model: GPTConfig = field(default_factory=GPTConfig)     # gpt model config; may be replaced by the pretrained model's config in init_train_state
    remat: bool = False    # set to True to rematerialize gradients during backward pass


def get_default_config() -> TrainConfig:
    """Return a fresh TrainConfig whose fields all hold their declared defaults."""
    default_config = TrainConfig()
    return default_config


config = get_default_config()
config  # last expression: let Jupyter render the resolved configuration

  from .autonotebook import tqdm as notebook_tqdm


TrainConfig(gpt2_model_type='gpt2', seed=555, out_dir='out', shuffle_buffer_size=128, eval_interval=500, eval_steps=16, eval_only=False, keep_checkpoints=3, batch_size=128, train_steps=30, weight_decay=0.0, grad_clip=1.0, gradient_accumulation_steps=1, betas=(0.9, 0.95), learning_rate=StaticLRConfig(init_value=0.0001), wandb=WandbConfig(entity='ars22', project='star_graph', name='gpt2', mode='online', notes=''), model=GPTConfig(block_size=1024, vocab_size=50257, num_layers=12, num_heads=12, num_embeds=768, dropout_rate=0.1, use_bias=True, dtype=None), remat=False)

In [3]:
import jax
import jax.numpy as jnp
import flax
from flax.core import FrozenDict, frozen_dict
from flax.training import checkpoints
from flax.training.train_state import TrainState
from flax.jax_utils import replicate, unreplicate
import optax
from functools import partial

2024-05-01 05:05:59.156712: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-05-01 05:05:59.801494: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-05-01 05:05:59.801572: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


In [4]:
from torch.utils.data import Dataset
import torch

def prefix_target_list(filename=None):
    """Load graph examples and split each line into a ``(prefix, target)`` pair.

    Every non-blank line must look like ``<prefix>=<target>``, e.g.
    ``6,9|7,6|9,0|5,2|7,4|4,5/7,2=7,4,5,2``.  The prefix keeps its trailing
    ``'='`` so that generation conditioned on it starts directly with the target.

    Args:
        filename: path of the text file to read.  Despite the ``None`` default it
            is required; ``open(None)`` raises TypeError.

    Returns:
        list of ``(prefix, target)`` string tuples, in file order.  If a line
        contains several ``'='`` only the text up to the second one is kept as target.

    Raises:
        ValueError: if a non-blank line contains no ``'='``.
    """
    data_list = []
    with open(filename, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing empty line) instead of raising IndexError.
                continue
            parts = stripped.split('=')
            if len(parts) < 2:
                raise ValueError(f'{filename}:{line_no}: expected "<prefix>=<target>", got {stripped!r}')
            data_list.append((parts[0] + '=', parts[1]))
    return data_list


class Graphs(Dataset):
    """Tokenized graph-path dataset for teacher-forced next-token training.

    Every example is ``prefix + target`` tokenized with ``tokenizer`` and stored
    as one LongTensor.

    * Train mode (default): an item is ``(x, y)``.  ``x`` is the sequence without
      its last token.  ``y`` holds the next-token labels, with -1 on the first
      ``num_prefix_tokens - 1`` positions, because those positions only predict
      prefix tokens.
    * Eval mode: an item is the full tokenized sequence.

    Args:
        tokenizer: object with ``encode(str) -> list[int]`` (e.g. a HF tokenizer).
        n_samples: keep only the first ``n_samples`` examples of the file.
        data_path: text file parsed by ``prefix_target_list``.
        device: optional torch device that eval-mode items are moved to; ``None``
            (default) returns them as stored.
    """

    def __init__(self, tokenizer, n_samples, data_path, device=None):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.data_path = data_path
        # Previously read by __getitem__ but never assigned, so eval mode raised AttributeError.
        self.device = device
        self.eval_mode = False
        self.data_file = prefix_target_list(self.data_path)
        self.tokenized, self.num_prefix_tokens, self.num_target_tokens = self.tokenize(self.data_file[:n_samples])

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        seq = self.tokenized[idx]
        if self.eval_mode:
            # In eval mode return the entire sequence
            return seq if self.device is None else seq.to(self.device)

        # Train mode: x[i] predicts seq[i + 1].  Positions 0 .. num_prefix_tokens - 2
        # predict prefix tokens, so they are labelled -1 and ignored by the loss.
        x = seq.clone()
        ignore = torch.full((self.num_prefix_tokens - 1,), -1, dtype=torch.long)
        y = torch.cat([ignore, x[self.num_prefix_tokens:].long()])
        return x[:-1], y

    def tokenize(self, data_list):
        """Tokenize ``(prefix, target)`` pairs and concatenate each into one LongTensor.

        Returns:
            ``(sequences, prefix_len, target_len)``.  The lengths are taken from the
            first pair.  ``__getitem__`` assumes every pair has them; a warning is
            printed if some pair differs.

        Raises:
            ValueError: if ``data_list`` is empty (no pair to take the lengths from).
        """
        if len(data_list) == 0:
            raise ValueError('cannot tokenize an empty list of (prefix, target) pairs')
        out = []
        prefix_len = target_len = None
        same_len = True

        for prefix, target in data_list:
            prefix_ids = torch.tensor(self.tokenizer.encode(prefix))
            target_ids = torch.tensor(self.tokenizer.encode(target))
            if prefix_len is None:
                # Reference lengths come from the first pair.
                prefix_len, target_len = len(prefix_ids), len(target_ids)
            elif not (len(prefix_ids) == prefix_len and len(target_ids) == target_len):
                same_len = False
            out.append(torch.cat([prefix_ids, target_ids], dim=-1).long())

        # Check if all prefixes and all targets have the same length
        if not same_len:
            print('Not all prefixes or targets have the same length!!')
        else:
            print('Equal sequence lengths!')

        return out, prefix_len, target_len

    def eval(self):
        """Switch to "eval" mode: items are full sequences, for generation without teacher forcing."""
        self.eval_mode = True

    def train(self):
        """Switch back to "train" mode: items are teacher-forcing ``(x, y)`` pairs."""
        self.eval_mode = False

In [5]:
# LOAD TOKENIZER
from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained(config.gpt2_model_type)
# Reuse the end-of-sequence id as the padding id
tokenizer.pad_token_id = tokenizer.eos_token_id

# LOAD DATASET
# Both split files share the `data_path` stem
data_path = 'deg_2_path_4_nodes_10'
train_path = f'{data_path}_train_200000.txt'
test_path = f'{data_path}_test_20000.txt'
train_data = Graphs(tokenizer=tokenizer, n_samples=20000, data_path=train_path)
test_data = Graphs(tokenizer=tokenizer, n_samples=500, data_path=test_path)
train_data.train()  # teacher-forcing (x, y) items (also the constructor's default mode)

# SANITY CHECK: one tokenized example, its decoded input, and its decoded target span
sample_x, sample_y = train_data[0]
print((sample_x, sample_y), tokenizer.decode(sample_x), tokenizer.decode(sample_y[-train_data.num_target_tokens:]))

# LOAD DATALOADERS: drop_last keeps every batch the same size for the per-device reshape
train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=config.batch_size, shuffle=False, drop_last=True)


Equal sequence lengths!
Equal sequence lengths!
(tensor([21, 11, 24, 91, 22, 11, 21, 91, 24, 11, 15, 91, 20, 11, 17, 91, 22, 11,
        19, 91, 19, 11, 20, 14, 22, 11, 17, 28, 22, 11, 19, 11, 20, 11]), tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, 22, 11, 19, 11, 20, 11, 17])) 6,9|7,6|9,0|5,2|7,4|4,5/7,2=7,4,5, 7,4,5,2


In [6]:
# Tokenizer vocab size, then the prefix / target token lengths (taken from the first example)
print("tokenizer vocab size: ", tokenizer.vocab_size, train_data.num_prefix_tokens, train_data.num_target_tokens)

tokenizer vocab size:  50257 28 7


In [7]:
def param_decay_mask(params: FrozenDict) -> FrozenDict:
    """Boolean pytree shaped like ``params`` that marks the leaves receiving weight decay.

    A leaf is False (no decay) when its own name, i.e. the last key of its path,
    is ``bias``, ``embedding`` or ``scale``.  Every other leaf is True.
    """
    excluded_leaf_names = ('bias', 'embedding', 'scale')
    flat_mask = {
        path: path[-1] not in excluded_leaf_names
        for path in flax.traverse_util.flatten_dict(params)
    }
    return frozen_dict.freeze(flax.traverse_util.unflatten_dict(flat_mask))

def init_train_state(key, config: TrainConfig, learning_rate) -> TrainState:
    """Build a TrainState holding pretrained GPT-2 weights and a clipped AdamW optimizer.

    Args:
        key: JAX PRNG key.  Kept for interface compatibility.  Weights come from
            ``get_pretrained_params``, so the key is not used.  The previous
            ``model.init(key)`` call discarded its result.
        config: training config.  Side effect: ``config.model`` is overwritten with
            the model config returned by ``get_pretrained_params``.
        learning_rate: value (or optax schedule) passed to ``optax.adamw``.

    Returns:
        TrainState whose ``apply_fn`` is the (optionally rematerialized) model's
        ``apply`` and whose ``params`` are the pretrained weights.
    """
    # Load the weights before choosing the module so both branches have `params`.
    # The remat branch previously never defined it and failed with UnboundLocalError.
    config.model, params = get_pretrained_params(config.gpt2_model_type)

    if config.remat:
        # NOTE(review): assumes the remat-wrapped GPT has the same parameter tree as GPT.
        model = flax.linen.remat(GPT,
            static_argnums=(2,),
            policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)(config.model)
    else:
        model = GPT(config.model)

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.grad_clip),
        # Weight decay is masked off bias / embedding / scale leaves (param_decay_mask).
        optax.adamw(learning_rate, *config.betas, weight_decay=config.weight_decay, mask=param_decay_mask(params)),
        # NOTE(review): apply_every sits after adamw, so it accumulates optimizer updates,
        # not raw gradients.  It is a no-op at the default gradient_accumulation_steps=1.
        optax.apply_every(config.gradient_accumulation_steps),
    )

    return TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

def count_params(params: FrozenDict) -> int:
    """Total number of scalar entries across all ``jnp.ndarray`` leaves of ``params``.

    Leaves that are not ``jnp.ndarray`` contribute 0.  An empty pytree yields 0;
    the previous ``tree_reduce`` without an initializer raised TypeError on it.
    """
    return sum(
        leaf.size
        for leaf in jax.tree_util.tree_leaves(params)
        if isinstance(leaf, jnp.ndarray)
    )

In [8]:
# =====  init parameters ============
# Derive separate PRNG streams from the config seed: parameter init, dropout and generation.
key = jax.random.PRNGKey(config.seed)
key, key_params, key_dropout, key_generation = jax.random.split(key, 4)
# make sure dropout keys are different for each device (local and global)
key_dropout = jax.random.fold_in(key_dropout, jax.process_index())
# one dropout key per local device (leading axis consumed by the pmap'd train steps)
keys_dropout = jax.random.split(key_dropout, jax.local_device_count())
# one generation key per local device (the pmap'd `key` argument of generate_negative_data)
key_gen = jax.random.split(key_generation, jax.local_device_count())

In [9]:
# Constant learning rate from the StaticLRConfig
learning_rate = config.learning_rate.init_value
# Pretrained GPT-2 weights + optimizer (see init_train_state)
train_state = init_train_state(key_params, config, learning_rate)
num_params = count_params(train_state.params)

loading weights from pretrained gpt: gpt2


In [10]:
# Sanity check on model size: 124,439,808 for gpt2 (output below)
print(f"Total parameters: {num_params:,}") # 774,030,080 for gpt2-large

Total parameters: 124,439,808


In [11]:
from flax.core import FrozenDict, freeze, unfreeze
from transformers import FlaxGPT2LMHeadModel
# The HF Flax GPT-2 is only used for generation (hf_model.generate in eval / negative sampling).
hf_model = FlaxGPT2LMHeadModel.from_pretrained(config.gpt2_model_type)
# Parameter tree used as a template: convert_jax_params_to_hf overwrites its
# transformer ln_f / wpe / wte / h entries with the current train_state weights on each call.
hf_params = hf_model.init_weights(key_params, (2, config.model.block_size))

In [12]:
# replicate model
# Copy the train state and the HF parameter template to every local device, so they
# can be passed to the pmap'd functions below with a leading device axis.
train_state = replicate(train_state)
hf_params = replicate(hf_params)


In [13]:
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    """Masked token-level cross-entropy and teacher-forcing accuracy.

    Args:
        logits: ``(batch, seq, vocab)`` scores.
        tokens: ``(batch, seq)`` integer targets.  Masked positions may hold any
            value (e.g. -1); their log-prob is discarded.
        valid: optional ``(batch, seq)`` mask; non-zero entries count.  Defaults to all ones.

    Returns:
        ``(loss, accuracy)``.
        * loss: mean negative log-likelihood over all valid tokens in the batch.
          It is 0 (not NaN) when there are no valid tokens.
        * accuracy: fraction of valid tokens whose argmax matches, averaged per sequence.
    """
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    # Token-level mean, changed to match the hf implementation
    # (old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)).
    # The denominator is clamped like valid_text_length: an all-masked batch used to give 0/0 = NaN.
    total_valid = jnp.maximum(jnp.sum(valid), 1e-10)
    loss = -(jnp.sum(token_log_prob) / total_valid)
    correct = jnp.where(
        valid > 0.0,
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def train_step(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray, dropout_key) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, TrainState]:
    """One data-parallel teacher-forcing step, pmap'd over local devices (axis 'batch').

    Args:
        state: replicated TrainState.
        input_tokens, target_tokens: per-device shards; targets use -1 for ignored positions.
        dropout_key: per-device PRNG key, folded with ``state.step`` so it changes every step.

    Returns:
        ``(loss, accuracy, prob_hard_token, new_state)``.  The metrics are pmean'd
        across devices.  Reads the global ``train_data`` for the prefix length.
    """
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    def loss_fn(params: FrozenDict) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        # NOTE(review): the third positional argument is False; its meaning (e.g. a
        # deterministic/train flag) depends on GPT.__call__, which is not shown — confirm.
        logits = state.apply_fn(params, input_tokens, False, rngs={'dropout': dropout_key})
        
        logits = logits.astype(jnp.float32)  # for numerical stability
        token_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                jax.nn.log_softmax(logits, axis=-1),
                jnp.expand_dims(target_tokens, -1),
                axis=-1,
            ),
            -1,
        )
        # Mean probability of the label at position num_prefix_tokens + 1, i.e. the third
        # target token.  With this dataset's "node,node,..." encoding that is the first
        # node after the start node: the "hard" decision.
        prob_hard_token = jnp.exp(token_log_prob[:, train_data.num_prefix_tokens+1]).mean()
        
        # NOTE(review): `> 0` masks the -1 ignore label, but it would also mask token id 0.
        loss, acc = cross_entropy_loss_and_accuracy(
            logits, target_tokens, (target_tokens > 0).astype(jnp.int32))
        
        return loss, (prob_hard_token, acc)
    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (prob_hard_token, acc)), grads = grad_fn(state.params)
    # average gradients across devices
    prob_hard_token = jax.lax.pmean(prob_hard_token, axis_name="batch")
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    
    return loss, acc, prob_hard_token, new_state


from flax.traverse_util import flatten_dict, unflatten_dict

def convert_jax_params_to_hf(hf_params, jax_params) -> FrozenDict:
    """Copy this notebook's GPT weights into a HF Flax GPT-2 parameter tree.

    ``ln_f``, ``wpe``, ``wte`` and every block listed under ``transformer.h`` are
    taken from ``jax_params``.  The attention (``c_attn``, ``c_proj``) and MLP
    kernels are then transposed to the layout the HF module expects.  The input
    tree is unfrozen into a new tree before being modified.
    """
    nested = unfreeze(hf_params)
    transformer = nested['transformer']
    for name in ('ln_f', 'wpe', 'wte'):
        transformer[name] = jax_params[name]
    blocks = transformer['h']
    for block_name in list(blocks):
        blocks[block_name] = jax_params[block_name]

    def needs_transpose(path):
        # paths look like "transformer.h.<i>.<submodule>...."
        if path.endswith('attn.c_attn.kernel') or path.endswith('attn.c_proj.kernel'):
            return True
        parts = path.split('.')
        return len(parts) > 3 and parts[3] == 'mlp' and path.endswith('kernel')

    flat = flatten_dict(nested, sep='.')
    converted = {path: (leaf.T if needs_transpose(path) else leaf) for path, leaf in flat.items()}
    return freeze(unflatten_dict(converted, sep='.'))


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def eval_step(hf_params, state, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray):
    """Decode each prompt with the HF model and return exact-match path accuracy.

    Runs per device under pmap:
    * the current ``state`` weights are copied into the HF parameter template;
    * the first ``num_prefix_tokens`` of every input row are the prompt;
    * ``num_target_tokens`` new tokens are decoded without sampling.

    Returns the fraction of rows whose generated target span equals the last
    ``num_target_tokens`` labels exactly, averaged across devices.  Reads the
    global ``train_data`` and ``hf_model``.
    """
    hf_params = convert_jax_params_to_hf(hf_params, state.params['params'])
    output = hf_model.generate(
        input_tokens[:, :train_data.num_prefix_tokens],
        params=hf_params,
        max_new_tokens=train_data.num_target_tokens,
        # min_length counts the prompt too; with max_new_tokens this asks for exactly num_target_tokens new tokens
        min_length=train_data.num_target_tokens+train_data.num_prefix_tokens, 
        do_sample=False, 
        attention_mask=jnp.ones_like(input_tokens[:, :train_data.num_prefix_tokens]))
    # output[0] is presumably the generated `sequences` (prompt + continuation).
    # A row counts as correct only if every one of its target tokens matches.
    acc = ((output[0][:, -train_data.num_target_tokens:] == target_tokens[:, -train_data.num_target_tokens:]).sum(1) == train_data.num_target_tokens).mean()
    acc = jax.lax.pmean(acc, axis_name="batch")
    return acc

# Decoding settings read (at trace time) by generate_negative_data
max_new_tokens = train_data.num_target_tokens  # generate as many tokens as a target has
num_beams, num_return_sequences = 2, 2         # beams searched / candidates returned per prompt
temperature = 1.0

from flax.core import FrozenDict, freeze, unfreeze

@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def generate_negative_data(hf_params, train_state, input_tokens, key):
    """Generate ``num_return_sequences`` candidate completions per prompt with the current policy.

    Runs per device under pmap.  It copies the current ``train_state`` weights into
    the HF parameter template and decodes from the first ``num_prefix_tokens`` of
    every input row.  The module-level settings ``max_new_tokens``, ``num_beams``,
    ``num_return_sequences`` and ``temperature`` are read when the function is
    traced.

    Returns:
        The HF generate output.  Index 0 holds the token ids: per device,
        ``rows * num_return_sequences`` sequences of length
        ``num_prefix_tokens + max_new_tokens``.
    """
    hf_params = convert_jax_params_to_hf(hf_params, train_state.params['params'])
    return hf_model.generate(
        input_tokens[:, :train_data.num_prefix_tokens],
        params=hf_params,
        max_new_tokens=max_new_tokens,
        min_length=train_data.num_target_tokens+train_data.num_prefix_tokens,
        prng_key=key,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        # Previously hard-coded to 1.0, silently ignoring the `temperature` setting.
        # NOTE(review): with beam search and no do_sample, HF presumably ignores
        # temperature and prng_key — confirm before relying on them.
        temperature=temperature,
        attention_mask=jnp.ones_like(input_tokens[:, :train_data.num_prefix_tokens]))



def evaluate(hf_params, state: TrainState, loader: DataLoader) -> jnp.ndarray:
    """Mean exact-match generation accuracy of ``state`` over all batches of ``loader``.

    Each ``(input_tokens, target_tokens)`` batch is split across local devices and
    scored with the pmap'd ``eval_step``.  Per-batch accuracies are averaged with
    equal weight.  The batch size must be divisible by ``jax.local_device_count()``.

    Raises:
        ValueError: if ``loader`` yields no batches, e.g. ``drop_last=True`` with
            fewer examples than one batch.
    """
    n_devices = jax.local_device_count()

    def shard(x):
        # (batch, seq) -> (devices, batch // devices, seq), the leading axis pmap maps over
        x = jnp.array(x)
        return x.reshape(n_devices, -1, x.shape[-1])

    accs = [
        eval_step(hf_params, state, shard(input_tokens), shard(target_tokens))
        for input_tokens, target_tokens in loader
    ]
    if not accs:
        raise ValueError('evaluate() received a loader that yielded no batches')
    return jnp.mean(jnp.stack(accs))

In [14]:
class AverageMeter:
    """Running weighted average of a scalar metric (e.g. per-batch loss).

    ``update(val, num)`` adds ``val`` with weight ``num``.  ``get()`` returns the
    weighted mean so far, or NaN if nothing has been recorded yet.
    """

    def __init__(self):
        self.num = 0  # total weight recorded so far
        self.val = 0  # weighted sum of recorded values

    def update(self, val, num):
        """Accumulate ``val`` with weight ``num`` (typically the batch size)."""
        self.val += val * num
        self.num += num

    def get(self, percentage=False):
        """Weighted mean, multiplied by 100 when ``percentage`` is True; NaN when empty."""
        if self.num == 0:
            # Previously raised ZeroDivisionError when read before any update.
            return float('nan')
        val = self.val / self.num * 100 if percentage else self.val / self.num
        return val

In [15]:
def get_log_likelihood(state, input_tokens, target_tokens, dropout_key):
    """Per-token log-probabilities of ``target_tokens`` under ``state``'s model.

    Positions whose target is <= 0 (the -1 ignore label, and also token id 0) get 0.

    Returns:
        array with the shape of ``target_tokens``, assuming ``apply_fn`` returns
        ``(*target_tokens.shape, vocab)`` logits.
    """
    logits = state.apply_fn(state.params, input_tokens, False, rngs={'dropout': dropout_key})
    valid = target_tokens > 0
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(target_tokens, -1),
            axis=-1,
        ),
        -1,
    )
    return jnp.where(valid, token_log_prob, jnp.array(0.0))

def get_token_level_scores(original_dataset, generated_dataset, num_prefix_tokens=None):
    """Score generated tokens against the ground truth, up to the first mistake.

    Scores per row:
    * 1 for positions before the first mismatch;
    * -1 at the first mismatching position;
    * 0 for every position after it;
    * 0 for the first ``num_prefix_tokens`` positions (the prompt), overriding the above.

    Args:
        original_dataset: ``(n, seq)`` ground-truth token tensor.
        generated_dataset: ``(n, seq)`` generated token tensor.
        num_prefix_tokens: prompt length.  Defaults to the global
            ``train_data.num_prefix_tokens`` (the previous hard-coded behavior).

    Returns:
        float32 tensor of shape ``(n, seq)``.  An empty batch returns an empty
        tensor; the previous ``torch.stack([])`` raised on it.
    """
    if num_prefix_tokens is None:
        num_prefix_tokens = train_data.num_prefix_tokens
    matches = original_dataset == generated_dataset
    scores = matches.float()
    if scores.shape[0] == 0:
        return scores
    mismatches = ~matches
    has_mismatch = mismatches.any(dim=1, keepdim=True)
    # argmax returns the first maximal index, i.e. the first mismatch of each row
    first_mismatch = mismatches.float().argmax(dim=1, keepdim=True)
    positions = torch.arange(scores.shape[1], device=scores.device).unsqueeze(0)
    scores[has_mismatch & (positions > first_mismatch)] = 0.
    scores[has_mismatch & (positions == first_mismatch)] = -1.
    scores[:, :num_prefix_tokens] = 0.
    return scores

In [1]:
# def signed_log_sigmoid_loss_and_accuracy(logits, tokens, scores, valid=None):
#     if valid is None:
#         valid = jnp.ones(tokens.shape[:2])
#     valid = valid.astype(jnp.float32)
#     scores = scores.astype(jnp.float32)
#     valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
#     logits = logits.astype(jnp.float32)  # for numerical stability
#     token_log_prob = jnp.squeeze(
#         jnp.take_along_axis(
#             jax.nn.log_softmax(logits, axis=-1),
#             jnp.expand_dims(tokens, -1),
#             axis=-1,
#         ),
#         -1,
#     )
#     sign_token_log_prob = jnp.where(valid > 0.0, token_log_prob * scores, jnp.array(0.0))
#     # sign_seq_log_prob = jnp.sum(sign_token_log_prob, axis=-1) /valid_text_length
#     # # loss = -jax.nn.log_sigmoid(sign_seq_log_prob).mean()
#     # loss = -sign_seq_log_prob.mean()

#     # loss = -(sign_token_log_prob.sum(axis=-1)).mean()
#     loss = -jax.nn.log_sigmoid(sign_token_log_prob.sum(axis=-1) / valid_text_length).mean()

#     # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
#     # changed to match hf implementation
#     correct = jnp.where(
#         (valid > 0.0) & (scores > 0.0),
#         jnp.argmax(logits, axis=-1) == tokens,
#         jnp.array(False)
#     )
#     accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
#     return loss, accuracy


def _masked_sequence_log_prob(logits, tokens):
    """Per-sequence sum of log p(tokens | logits) over positions with tokens > 0; also returns the float mask."""
    valid = (tokens > 0).astype(jnp.float32)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    return token_log_prob.sum(axis=-1), valid


def dpo_loss(pos_logits, pos_tokens, neg_logits, neg_tokens, beta=1.0):
    """Reference-free pairwise preference loss between a positive and a negative sequence.

    loss = mean(-log_sigmoid(beta * (log p(pos) - log p(neg)))), where each log p is
    the summed (not length-normalized) log-probability over positions with token > 0.

    If the pos and neg sequences and logits are identical (e.g. a fully correct
    generation), the pair's loss is exactly log 2 (~0.693).

    Args:
        pos_logits, neg_logits: ``(batch, seq, vocab)`` logits.
        pos_tokens, neg_tokens: ``(batch, seq)`` targets; values <= 0 are ignored.
        beta: scale of the log-probability margin (1.0 reproduces the original loss).

    Returns:
        ``(loss, accuracy)``.  accuracy is the teacher-forcing argmax accuracy on
        the positive sequence, averaged per sequence.
    """
    pos_logits = pos_logits.astype(jnp.float32)  # for numerical stability
    pos_seq_log_prob, pos_valid = _masked_sequence_log_prob(pos_logits, pos_tokens)
    neg_seq_log_prob, _ = _masked_sequence_log_prob(neg_logits, neg_tokens)

    valid_text_length = jnp.maximum(jnp.sum(pos_valid, axis=-1), 1e-10)
    correct = jnp.where(
        pos_valid > 0.0,
        jnp.argmax(pos_logits, axis=-1) == pos_tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)

    loss = (-jax.nn.log_sigmoid(beta * (pos_seq_log_prob - neg_seq_log_prob))).mean()
    return loss, accuracy




@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0, 0, 0))
def train_step_onpolicy(state: TrainState, pos_input_tokens, pos_target_tokens, neg_input_tokens, neg_target_tokens, dropout_key) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, TrainState]:
    """One data-parallel pairwise step on (positive, negative) sequences, pmap'd over axis 'batch'.

    Both sequences are run through the model and scored with ``dpo_loss``.  The
    gradients and metrics are averaged across devices.

    Returns:
        ``(loss, accuracy, prob_hard_token, new_state)``.  accuracy comes from
        ``dpo_loss`` (positive side); prob_hard_token is computed on the positive
        sequence.  Reads the global ``train_data`` for the prefix length.
    """
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    def loss_fn(params: FrozenDict) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        # The positive and negative passes share the same dropout key.
        pos_logits = state.apply_fn(params, pos_input_tokens, False, rngs={'dropout': dropout_key})
        neg_logits = state.apply_fn(params, neg_input_tokens, False, rngs={'dropout': dropout_key})
        loss, acc = dpo_loss(
            pos_logits, pos_target_tokens, neg_logits, neg_target_tokens)
        # NOTE(review): unlike train_step, pos_logits are not cast to float32 before this log_softmax.
        token_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                jax.nn.log_softmax(pos_logits, axis=-1),
                jnp.expand_dims(pos_target_tokens, -1),
                axis=-1,
            ),
            -1,
        )
        # Probability of the third target token (the first "hard" node decision).
        # NOTE(review): if that label is masked (-1) for a row, its value is meaningless
        # but is still included in the mean.
        prob_hard_token = jnp.exp(token_log_prob[:, train_data.num_prefix_tokens+1]).mean()
        return loss, (prob_hard_token, acc)    
        
    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (prob_hard_token, acc)), grads = grad_fn(state.params)
    # average gradients across devices
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    prob_hard_token = jax.lax.pmean(prob_hard_token, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    return loss, acc, prob_hard_token, new_state


NameError: name 'partial' is not defined

In [None]:



# import numpy as np
# original_dataset = []
# generated_dataset = []
# # scores_dataset = []
# loader = DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True) 

# for input_tokens, target_tokens in tqdm(iter(loader)): 
#     input_tokens = jnp.array(input_tokens)
#     target_tokens = jnp.array(target_tokens)
#     original = jnp.concatenate([input_tokens[:, :train_data.num_prefix_tokens], target_tokens[:, -train_data.num_target_tokens:]], axis=1)
    
#     input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])

#     generations = generate_negative_data(hf_params, train_state, input_tokens, key_gen)
    
#     repeated_original = jnp.repeat(original[None, :, :], num_return_sequences, 0).transpose(1, 0, 2)
    
#     original_dataset.append(repeated_original.reshape(-1, repeated_original.shape[-1]))
#     generated_dataset.append(generations[0].reshape(-1, generations[0].shape[-1]))
    
#     # scores_dataset.append(generations[1].reshape(-1, generations[1].shape[-1]))

# # scores_dataset = jnp.concatenate([x.reshape(-1) for x in scores_dataset], axis=0).squeeze()
# generated_dataset = jnp.concatenate(generated_dataset, axis=0)
# original_dataset = jnp.concatenate(original_dataset, axis=0)
# token_lvl_scores = get_token_level_scores( torch.tensor(np.array(original_dataset)), torch.tensor(np.array(generated_dataset)))


In [None]:
# token_lvl_scores.shape, len(train_data) // 64 * 64 * 5, len(generated_dataset)

In [None]:
# on_policy_dataset = torch.utils.data.TensorDataset(
#     torch.tensor(np.array(original_dataset)), torch.tensor(np.array(generated_dataset)), token_lvl_scores)

# on_policy_loader = DataLoader(on_policy_dataset, batch_size=128, shuffle=True, drop_last=True)
# on_policy_iter = iter(on_policy_loader) 

# pbar = tqdm(on_policy_loader, total=len(on_policy_loader), desc='training on-policy')

# train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter() 
# policy_train_acc, policy_test_acc = 0., 0.

# step = 0.

# for original, generation, scores in pbar:

#     pos_input_tokens = original.clone()
#     pos_input_tokens = pos_input_tokens[:, :-1]
#     pos_target_tokens = original.clone()
#     pos_target_tokens[:, :train_data.num_prefix_tokens] = -1
#     pos_target_tokens = pos_target_tokens[:, 1:]

#     neg_input_tokens = generation.clone()
#     neg_input_tokens = neg_input_tokens[:, :-1]
#     neg_target_tokens = generation.clone()
#     neg_target_tokens[scores == 0] = -1
#     neg_target_tokens = neg_target_tokens[:, 1:]

    
#     pos_input_tokens = jnp.array(pos_input_tokens)
#     neg_input_tokens = jnp.array(neg_input_tokens)
#     pos_target_tokens = jnp.array(pos_target_tokens)
#     neg_target_tokens = jnp.array(neg_target_tokens)

#     pos_input_tokens = pos_input_tokens.reshape(jax.device_count(), -1, pos_input_tokens.shape[-1])
#     pos_target_tokens = pos_target_tokens.reshape(jax.device_count(), -1, pos_target_tokens.shape[-1])
#     neg_input_tokens = neg_input_tokens.reshape(jax.device_count(), -1, neg_input_tokens.shape[-1])
#     neg_target_tokens = neg_target_tokens.reshape(jax.device_count(), -1, neg_target_tokens.shape[-1])

    
#     # print(pos_input_tokens.shape, pos_target_tokens.shape, neg_input_tokens.shape, neg_target_tokens.shape) 
#     # break

#     loss, acc, train_state = train_step_onpolicy(train_state, pos_input_tokens, pos_target_tokens, neg_input_tokens, neg_target_tokens, keys_dropout)
#     train_loss.update(loss.mean(), pos_input_tokens.shape[1] * jax.device_count())  
#     train_acc.update(acc.mean(), pos_input_tokens.shape[1] * jax.device_count())    

#     # phard.update(phard_token.mean(), input_tokens.shape[1] * jax.device_count())
#     if step % 10 == 0:
#         pbar.set_description(f'train loss: {train_loss.get()} forcing train acc: {train_acc.get(percentage=True)} policy train acc: {100. * policy_train_acc} policy test acc: {100. * policy_test_acc}')
#     if step % config.eval_interval == 0:
#         policy_train_acc = evaluate(hf_params, train_state, train_loader)
#         policy_test_acc = evaluate(hf_params, train_state, test_loader)
#         train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter()
    
#     step += 1

In [20]:
import numpy as np

# On-policy pairwise training.  Each step:
#   1. decode candidates for a batch of prompts with the current policy;
#   2. score them token by token against the ground truth;
#   3. train with `dpo_loss` on (ground truth, candidate) pairs, keeping targets only
#      up to and including the first wrong token.
NUM_EPOCHS = 50


def _shard(x):
    """Reshape a (batch, ..., seq) array to (local_devices, batch // local_devices, seq) for pmap."""
    # pmap maps over *local* devices.  The previous jax.device_count() only matches this on a single host.
    return x.reshape(jax.local_device_count(), -1, x.shape[-1])


def _split_input_target(seqs, token_scores):
    """Teacher-forcing (input, target) pair built from full sequences ``seqs`` of shape (n, seq).

    input  = seqs without their last token.
    target = seqs shifted left by one, with -1 wherever ``token_scores`` is 0.  That
             covers the prompt and every position after the first mismatch, so the
             loss skips them.
    """
    targets = seqs.clone()
    targets[token_scores == 0] = -1
    return seqs[:, :-1], targets[:, 1:]


for ep in range(NUM_EPOCHS):
    loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True)
    train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter()
    policy_train_acc, policy_test_acc = 0., 0.

    step = 0
    pbar = tqdm(loader, total=len(loader), desc='training on-policy')

    for input_tokens, target_tokens in pbar:
        input_tokens = jnp.array(input_tokens)
        target_tokens = jnp.array(target_tokens)
        # Rebuild each full ground-truth sequence (prompt + target path).
        original = jnp.concatenate([input_tokens[:, :train_data.num_prefix_tokens], target_tokens[:, -train_data.num_target_tokens:]], axis=1)

        input_tokens = _shard(input_tokens)
        # NOTE(review): key_gen is identical every step.  Beam search presumably ignores it,
        # but a sampling-based generator would repeat its draws.
        generations = generate_negative_data(hf_params, train_state, input_tokens, key_gen)
        generations = generations[0].reshape(-1, generations[0].shape[-1])
        # Repeat each ground-truth row num_return_sequences times (row b -> rows b*R .. b*R+R-1).
        # This assumes HF returns the candidates of each prompt consecutively.
        repeated_original = jnp.repeat(original[None, :, :], num_return_sequences, 0).transpose(1, 0, 2)
        original = torch.tensor(np.array(repeated_original.reshape(-1, repeated_original.shape[-1])))
        generations = torch.tensor(np.array(generations))
        token_lvl_scores = get_token_level_scores(original, generations)

        # Positive = ground truth, negative = candidate.  Both keep targets only up to and
        # including the first divergent token (all of them when there is no mismatch).
        pos_input_tokens, pos_target_tokens = _split_input_target(original, token_lvl_scores)
        neg_input_tokens, neg_target_tokens = _split_input_target(generations, token_lvl_scores)

        pos_input_tokens = _shard(jnp.array(pos_input_tokens))
        pos_target_tokens = _shard(jnp.array(pos_target_tokens))
        neg_input_tokens = _shard(jnp.array(neg_input_tokens))
        neg_target_tokens = _shard(jnp.array(neg_target_tokens))

        loss, acc, prob_hard_token, train_state = train_step_onpolicy(train_state, pos_input_tokens, pos_target_tokens, neg_input_tokens, neg_target_tokens, keys_dropout)
        n_examples = pos_input_tokens.shape[0] * pos_input_tokens.shape[1]
        train_loss.update(loss.mean(), n_examples)
        train_acc.update(acc.mean(), n_examples)
        phard.update(prob_hard_token.mean(), n_examples)

        if step % 10 == 0:
            pbar.set_description(f'train loss: {train_loss.get()} phard: {phard.get()} forcing train acc: {train_acc.get(percentage=True)} policy train acc: {100. * policy_train_acc} policy test acc: {100. * policy_test_acc}')
        if step % config.eval_interval == 0:
            policy_train_acc = evaluate(hf_params, train_state, train_loader)
            policy_test_acc = evaluate(hf_params, train_state, test_loader)
            train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter()

        step += 1


training on-policy:   0%|          | 0/156 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
train loss: 0.6976994872093201 forcing train acc: 9.440103530883789 policy train acc: 0.0 policy test acc: 0.0:   0%|          | 0/156 [02:23<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
train loss: 0.6934075355529785 forcing train acc: 77.08564758300781 policy train acc: 0.4457131624221802 policy test acc: 0.2604166865348816: 100%|██████████| 156/156 [03:12<00:00,  1.23s/it]
train loss: 0.6931999325752258 forcing train acc: 78.8960189819336 policy train acc: 42.76342010498047 policy test acc: 42.96875: 100%|██████████| 156/156 [00:22<00:00,  6.85it/s] 
train loss: 0.6931730508804321 forcing train acc: 81.21218872070312 policy train acc: 51.09174728393555 policy test acc: 48.69791793823242: 100%|██████████| 156/156 [00:22<00:00,  6.82it/s]
train loss: 0.6932268142700

KeyboardInterrupt: 

In [44]:
# Debug: shape of the generated token ids for the last sharded training batch —
# (local devices, rows per device * num_return_sequences, prefix + target length)
generate_negative_data(hf_params, train_state, input_tokens, key_gen)[0].shape

(4, 160, 35)

In [23]:
# Debug: row j on device j of the last on-policy batch — positive vs. negative
# inputs and their targets (-1 = masked position)
j = 3
pos_input_tokens[j][j], pos_target_tokens[j][j], neg_input_tokens[j][j], neg_target_tokens[j][j]

(Array([18, 11, 15, 91, 23, 11, 16, 91, 15, 11, 17, 91, 17, 11, 19, 91, 16,
        11, 21, 91, 18, 11, 23, 14, 18, 11, 19, 28, 18, 11, 15, 11, 17, 11],      dtype=int32),
 Array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 18, 11, 15, -1, -1, -1, -1],      dtype=int32),
 Array([18, 11, 15, 91, 23, 11, 16, 91, 15, 11, 17, 91, 17, 11, 19, 91, 16,
        11, 21, 91, 18, 11, 23, 14, 18, 11, 19, 28, 18, 11, 22, 11, 23, 11],      dtype=int32),
 Array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 18, 11, 22, -1, -1, -1, -1],      dtype=int32))