In [2]:
import sys
sys.path.append('../../')
import time
from jax_smi import initialise_tracking
initialise_tracking()
from tqdm import tqdm

In [7]:
from dataclasses import dataclass, field, asdict
from typing import Tuple, Optional, Union
from EasyLM.models.gpt2.gpt2_model import GPT, GPTConfig, get_pretrained_params
from torch.utils.data import DataLoader

@dataclass(frozen=True)
class WandbConfig:
    """
    wandb logging configuration
    """
    entity: str = 'ars22'
    """username or team name where you're sending runs"""
    project: str = 'star_graph'
    """project name"""
    name: str = 'gpt2'
    """experiment name"""
    mode: str = 'online'
    """'offline', 'online', or 'disabled'"""
    notes: str = ''


@dataclass(frozen=True)
class CosineDecayScheduleConfig:
    init_value: float = 0.0
    peak_value: float = 2.5e-4
    warmup_steps: int = 2000
    decay_steps: int = 150000
    end_value: float = 1e-5

@dataclass(frozen=True)
class StaticLRConfig:
    init_value: float = 1e-4


@dataclass(frozen=False)
class TrainConfig:
    gpt2_model_type: str = 'gpt2' # gpt2 model type
    seed: int = 556
    out_dir: str = 'out'                        # output directory for checkpoints (can be gcs path)
    shuffle_buffer_size: int = 128
    eval_interval: int = 500
    eval_steps: int = 16        # evaluate for this number of steps (per-device)
    eval_only: bool = False     # if True, script exits right after the first eval
    keep_checkpoints: int = 3   # number of historical checkpoints to keep
    batch_size: int = 128        # per-device batch size
    train_steps: int = 30     # total number of training iterations
    weight_decay: float = 0.0  # not applied to bias and embedding parameters
    grad_clip: float = 10.0      # gradient norm clipping magnitude
    gradient_accumulation_steps: int = 1    # used to simulate larger batch sizes
    betas: Tuple[float, float] = (0.9, 0.95) # adamw optimizer betas
    # learning_rate: CosineDecayScheduleConfig = field(default_factory=CosineDecayScheduleConfig)
    learning_rate: StaticLRConfig = field(default_factory=StaticLRConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig) # wandb logging
    model: GPTConfig = field(default_factory=GPTConfig)     # gpt model config
    remat: bool = False    # set to True to rematerialize gradients during backward pass


def get_default_config() -> TrainConfig:
    return TrainConfig()

config = get_default_config()
config

TrainConfig(gpt2_model_type='gpt2', seed=556, out_dir='out', shuffle_buffer_size=128, eval_interval=500, eval_steps=16, eval_only=False, keep_checkpoints=3, batch_size=128, train_steps=30, weight_decay=0.0, grad_clip=10.0, gradient_accumulation_steps=1, betas=(0.9, 0.95), learning_rate=StaticLRConfig(init_value=0.0001), wandb=WandbConfig(entity='ars22', project='star_graph', name='gpt2', mode='online', notes=''), model=GPTConfig(block_size=1024, vocab_size=50257, num_layers=12, num_heads=12, num_embeds=768, dropout_rate=0.1, use_bias=True, dtype=None), remat=False)

In [8]:
import jax
import jax.numpy as jnp
import flax
from flax.core import FrozenDict, frozen_dict
from flax.training import checkpoints
from flax.training.train_state import TrainState
from flax.jax_utils import replicate, unreplicate
import optax
from functools import partial

In [9]:
from torch.utils.data import Dataset
import torch

def prefix_target_list(filename=None):
    """
    Load graphs and split them into prefix and target and return the list
    """
    data_list = []
    with open(filename, 'r') as f:
        lines = f.readlines()
    for line in lines:
        prefix = line.strip().split('=')[0] + '='
        target = line.strip().split('=')[1]
        # target = target.split(',')[1]
        data_list.append((prefix, target))
    return data_list


class Graphs(Dataset):
    def __init__(self, tokenizer, n_samples, data_path):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.data_path = data_path
        self.eval_mode = False
        self.data_file = prefix_target_list(self.data_path)
        self.tokenized, self.num_prefix_tokens, self.num_target_tokens = self.tokenize(self.data_file[:n_samples])

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        if self.eval_mode:
            # In eval mode return the entire sequence
            return self.tokenized[idx].to(self.device)

        # Create inputs
        x = self.tokenized[idx].clone()
        y = torch.cat([-torch.ones((self.num_prefix_tokens - 1, )),
                       x[self.num_prefix_tokens:].clone()])
        return x[:-1], y.long()

    def tokenize(self, data_list):
        """
        Takes a list of prefix-target pairs, tokenizes and concatenates them
        """
        out = []
        prefix_len = len(self.tokenizer.encode(data_list[0][0]))
        target_len = len(self.tokenizer.encode(data_list[0][1]))
        same_len = True

        for prefix, target in data_list:
            prefix = torch.tensor(self.tokenizer.encode(prefix))
            target = torch.tensor(self.tokenizer.encode(target))
            if not (len(prefix) == prefix_len and len(target) == target_len):
                same_len = False
            seq = torch.concatenate([prefix, target], dim=-1).long()
            out.append(seq)

        # Check if all prefixes and all targets have the same length
        if not same_len:
            print('Not all prefixes or targets have the same length!!')
        else:
            print('Equal sequence lengths!')

        return out, prefix_len, target_len

    def eval(self):
        # Switch to "eval" mode when generating sequences without teacher-forcing
        self.eval_mode = True

    def train(self):
        # Switch back to "train" mode for teacher-forcing
        self.eval_mode = False

In [10]:
# LOAD TOKENIZER
from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained(config.gpt2_model_type)
tokenizer.pad_token_id = tokenizer.eos_token_id

# LOAD DATASET
data_path = 'deg_2_path_4_nodes_10'
train_path, test_path = data_path + '_train_200000.txt', data_path + '_test_20000.txt'
train_data = Graphs(tokenizer=tokenizer, n_samples=20000, data_path=train_path)
test_data = Graphs(tokenizer=tokenizer, n_samples=500, data_path=test_path)
train_data.train()

# sanity check
print(train_data[0], tokenizer.decode(train_data[0][0]), tokenizer.decode(train_data[0][1][-train_data.num_target_tokens:]))

# LOAD DATALOADER
train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True) 
test_loader = DataLoader(test_data, batch_size=config.batch_size, shuffle=False, drop_last=True)


Equal sequence lengths!
Equal sequence lengths!
(tensor([21, 11, 24, 91, 22, 11, 21, 91, 24, 11, 15, 91, 20, 11, 17, 91, 22, 11,
        19, 91, 19, 11, 20, 14, 22, 11, 17, 28, 22, 11, 19, 11, 20, 11]), tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, 22, 11, 19, 11, 20, 11, 17])) 6,9|7,6|9,0|5,2|7,4|4,5/7,2=7,4,5, 7,4,5,2


In [11]:
print("tokenizer vocab size: ", tokenizer.vocab_size, train_data.num_prefix_tokens, train_data.num_target_tokens)

tokenizer vocab size:  50257 28 7


In [12]:
def param_decay_mask(params: FrozenDict) -> FrozenDict:
    """ pytree mask for non-bias parameters """
    flat_params = flax.traverse_util.flatten_dict(params)
    flat_param_mask = {k: k[-1] not in ('bias', 'embedding', 'scale') for k in flat_params.keys()}
    param_mask = flax.traverse_util.unflatten_dict(flat_param_mask)
    return frozen_dict.freeze(param_mask)

def init_train_state(key, config: TrainConfig, learning_rate) -> TrainState:

    if config.remat:
        model = flax.linen.remat(GPT,
            static_argnums=(2,),
            policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)(config.model)
    else:
        config.model, params = get_pretrained_params(config.gpt2_model_type)
        model = GPT(config.model)    
        model.init(key)

    optimizer = optax.chain(
        # Apply weight decay only to non-bias parameters
        optax.clip_by_global_norm(config.grad_clip),
        optax.adamw(learning_rate, *config.betas, weight_decay=config.weight_decay, mask=param_decay_mask(params)),
        optax.apply_every(config.gradient_accumulation_steps),
    )

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer)

    return train_state

def count_params(params: FrozenDict) -> int:
    p = jax.tree_util.tree_map(lambda a: a.size if isinstance(a, jnp.ndarray) else 0, params)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, p)

In [13]:
# =====  init parameters ============
key = jax.random.PRNGKey(config.seed)
key, key_params, key_dropout, key_generation = jax.random.split(key, 4)
# make sure dropout keys are different for each device (local and global)
key_dropout = jax.random.fold_in(key_dropout, jax.process_index())
keys_dropout = jax.random.split(key_dropout, jax.local_device_count())
key_gen = jax.random.split(key_generation, jax.local_device_count())

In [14]:
learning_rate = config.learning_rate.init_value
train_state = init_train_state(key_params, config, learning_rate)
num_params = count_params(train_state.params)

loading weights from pretrained gpt: gpt2


In [15]:
print(f"Total parameters: {num_params:,}") # 774,030,080 for gpt2-large

Total parameters: 124,439,808


In [16]:
from flax.core import FrozenDict, freeze, unfreeze
from transformers import FlaxGPT2LMHeadModel
hf_model = FlaxGPT2LMHeadModel.from_pretrained(config.gpt2_model_type)
hf_params = hf_model.init_weights(key_params, (2, config.model.block_size))

In [17]:
# replicate model
train_state = replicate(train_state)
hf_params = replicate(hf_params)


In [18]:
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    loss = -(jnp.sum(token_log_prob) / jnp.sum(valid))
    # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    # changed to match hf implementation
    correct = jnp.where(
        valid > 0.0,
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def train_step(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray, dropout_key):
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    def loss_fn(params: FrozenDict) -> jnp.ndarray:
        logits = state.apply_fn(params, input_tokens, False, rngs={'dropout': dropout_key})
        
        logits = logits.astype(jnp.float32)  # for numerical stability
        token_log_prob = jnp.squeeze(
            jnp.take_along_axis(
                jax.nn.log_softmax(logits, axis=-1),
                jnp.expand_dims(target_tokens, -1),
                axis=-1,
            ),
            -1,
        )
        prob_hard_token = jnp.exp(token_log_prob[:, train_data.num_prefix_tokens+1]).mean()
        
        loss, acc = cross_entropy_loss_and_accuracy(
            logits, target_tokens, (target_tokens > 0).astype(jnp.int32))
        
        return loss, (prob_hard_token, acc)
    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (prob_hard_token, acc)), grads = grad_fn(state.params)
    # average gradients across devices
    prob_hard_token = jax.lax.pmean(prob_hard_token, axis_name="batch")
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    
    return loss, acc, prob_hard_token, new_state


from flax.traverse_util import flatten_dict, unflatten_dict

def convert_jax_params_to_hf(hf_params, jax_params) -> FrozenDict:
    hf_params = unfreeze(hf_params)
    
    for k in ['ln_f', 'wpe', 'wte']:
        hf_params['transformer'][k] = jax_params[k]
    for k in hf_params['transformer']['h'].keys():
        hf_params['transformer']['h'][k] = jax_params[k] 

    hf_params = flatten_dict(hf_params, sep='.')
    for k in hf_params.keys():
        if k.endswith('attn.c_attn.kernel'):
            hf_params[k] = hf_params[k].T
        elif k.endswith('attn.c_proj.kernel'):
            hf_params[k] = hf_params[k].T
        elif len(k.split('.')) > 3 and k.split('.')[3] == 'mlp' and k.endswith('kernel'):
            hf_params[k] = hf_params[k].T
    hf_params = unflatten_dict({k: v for k, v in hf_params.items()}, sep='.')
    return freeze(hf_params)


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def eval_step(hf_params, state, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray):
    hf_params = convert_jax_params_to_hf(hf_params, state.params['params'])
    output = hf_model.generate(
        input_tokens[:, :train_data.num_prefix_tokens],
        params=hf_params,
        max_new_tokens=train_data.num_target_tokens,
        min_length=train_data.num_target_tokens+train_data.num_prefix_tokens, 
        do_sample=False, 
        attention_mask=jnp.ones_like(input_tokens[:, :train_data.num_prefix_tokens]))
    acc = ((output[0][:, -train_data.num_target_tokens:] == target_tokens[:, -train_data.num_target_tokens:]).sum(1) == train_data.num_target_tokens).mean()
    acc = jax.lax.pmean(acc, axis_name="batch")
    return acc

max_new_tokens = train_data.num_target_tokens
num_beams=5 
num_return_sequences=5
temperature=1.0

from flax.core import FrozenDict, freeze, unfreeze

@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def generate_negative_data(hf_params, train_state, input_tokens, key):
    hf_params = convert_jax_params_to_hf(hf_params, train_state.params['params'])
    return hf_model.generate(
        input_tokens[:, :train_data.num_prefix_tokens],
        params=hf_params,
        max_new_tokens=max_new_tokens, 
        min_length=train_data.num_target_tokens+train_data.num_prefix_tokens,
        prng_key=key, 
        num_beams=5, 
        num_return_sequences=5, 
        temperature=1.0,
        attention_mask=jnp.ones_like(input_tokens[:, :train_data.num_prefix_tokens]))



def evaluate(hf_params, state: TrainState, loader: DataLoader) -> jnp.ndarray:
    accs = []
    for batch in loader:
        input_tokens, target_tokens = batch
        input_tokens = jnp.array(input_tokens)
        target_tokens = jnp.array(target_tokens)
        input_tokens = input_tokens.reshape(jax.local_device_count(), -1, input_tokens.shape[-1])
        target_tokens = target_tokens.reshape(jax.local_device_count(), -1, target_tokens.shape[-1])
        acc = eval_step(hf_params, state, input_tokens, target_tokens)
        accs.append(acc)
    return jnp.mean(jnp.stack(accs))

In [19]:
class AverageMeter:
    def __init__(self):
        self.num = 0
        self.val = 0

    def update(self, val, num):
        self.val += val * num
        self.num += num

    def get(self, percentage=False):
        val = self.val / self.num * 100 if percentage else self.val / self.num
        return val

In [20]:
# with jax.disable_jit():
train_iter = iter(train_loader)
pbar = tqdm(range(8000), total=1000, desc='training')
train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter() 
policy_train_acc, policy_test_acc = 0., 0.
for step in pbar:
    try:
        input_tokens, target_tokens = next(train_iter)
        target_tokens[:, train_data.num_prefix_tokens+2] = -1
    except StopIteration:
        train_iter = iter(train_loader)
    input_tokens = jnp.array(input_tokens)
    target_tokens = jnp.array(target_tokens) 
    input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])
    target_tokens = target_tokens.reshape(jax.device_count(), -1, target_tokens.shape[-1])
    loss, acc, phard_token, train_state = train_step(train_state, input_tokens, target_tokens, keys_dropout)
    train_loss.update(loss.mean(), input_tokens.shape[1] * jax.device_count())  
    train_acc.update(acc.mean(), input_tokens.shape[1] * jax.device_count())    
    phard.update(phard_token.mean(), input_tokens.shape[1] * jax.device_count())
    if step % 10 == 0:
        pbar.set_description(f'train loss: {train_loss.get()} phard: {phard.get()} forcing train acc: {train_acc.get(percentage=True)} policy train acc: {100. * policy_train_acc} policy test acc: {100. * policy_test_acc}')
    if step % config.eval_interval == 0:
        policy_train_acc = evaluate(hf_params, train_state, train_loader)
        policy_test_acc = evaluate(hf_params, train_state, test_loader)
        train_loss, train_acc = AverageMeter(), AverageMeter()

  x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
train loss: 1.9982390403747559 phard: 0.10577373206615448 forcing train acc: 35.286460876464844 policy train acc: 0.0 policy test acc: 0.0:   0%|          | 0/1000 [01:28<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
train loss: 0.000702388642821461 phard: 0.7955995798110962 forcing train acc: 99.98456573486328 policy train acc: 99.9949951171875 policy test acc: 100.0: : 6329it [07:18, 21.39it/s]                                    

In [17]:
target_tokens[0]

Array([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 16, 11, 24, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 23, 11, 22, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, 11, 18, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, 11, 21, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 15, 11, 22, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 22, 11, 20, -1, -1,
        -1, -1],
       [-1, -1, -1, -1, -1, -1, -1

In [None]:
# input_tokens, target_tokens = next(train_iter)
tokenizer.batch_decode(target_tokens.reshape(-1, target_tokens.shape[-1])[:, train_data.num_prefix_tokens:])

[',1,5,6',
 ',9,6,1',
 ',4,3,5',
 ',8,5,9',
 ',1,7,2',
 ',6,4,5',
 ',1,7,4',
 ',0,8,7',
 ',0,6,1',
 ',1,8,2',
 ',8,5,6',
 ',2,1,5',
 ',5,6,0',
 ',5,7,8',
 ',1,3,9',
 ',6,2,9',
 ',7,1,6',
 ',7,5,0',
 ',8,7,9',
 ',1,8,6',
 ',7,8,2',
 ',1,0,2',
 ',9,8,3',
 ',0,1,8',
 ',6,5,7',
 ',1,8,7',
 ',3,0,5',
 ',6,7,9',
 ',3,9,0',
 ',6,3,2',
 ',6,1,0',
 ',3,7,8',
 ',0,8,1',
 ',1,7,2',
 ',4,0,3',
 ',4,5,7',
 ',4,5,6',
 ',3,4,6',
 ',5,2,4',
 ',7,5,8',
 ',6,2,0',
 ',0,3,7',
 ',6,2,1',
 ',9,5,4',
 ',9,3,2',
 ',1,7,8',
 ',6,1,3',
 ',1,0,9',
 ',7,3,2',
 ',2,8,5',
 ',1,5,7',
 ',5,9,1',
 ',0,8,1',
 ',7,8,0',
 ',7,0,4',
 ',5,8,0',
 ',4,3,8',
 ',1,4,5',
 ',3,5,7',
 ',1,6,9',
 ',4,0,6',
 ',3,0,2',
 ',8,5,6',
 ',8,4,1']

In [None]:
def get_log_likelihood(state, input_tokens, target_tokens, dropout_key):
    logits = state.apply_fn(state.params, input_tokens, False, rngs={'dropout': dropout_key})
    valid = (target_tokens > 0).astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(target_tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    return token_log_prob 

In [None]:
get_log_likelihood(unreplicate(train_state), input_tokens.reshape(-1, input_tokens.shape[-1]), target_tokens.reshape(-1, target_tokens.shape[-1]), keys_dropout[0])[0, train_data.num_prefix_tokens:]

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


Array([-2.9248971e-04, -5.0683767e-01, -1.5708758e-04, -2.1932879e-05,
       -6.4020278e-05, -1.5498859e-05], dtype=float32)

target_tokens[0, train_data.num_prefix_tokens:]

In [None]:
from copy import deepcopy
sft_train_state = deepcopy(jax.device_put(train_state, device=jax.devices('cpu')[0]))

In [None]:
original[None, :, :].repeat(num_return_sequences, 1, 1).shape

TypeError: repeat() takes from 2 to 3 positional arguments but 4 were given

100%|██████████| 31/31 [00:01<00:00, 21.81it/s]


In [None]:
generated_dataset.shape

(9920, 35)

In [None]:
import numpy as np
tokenizer.batch_decode(np.concatenate([original_dataset[10:20], generated_dataset[10:20]], axis=1))


['5,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,65,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,6',
 '5,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,65,2|2,6|4,9|0,5|0,4|9,7/0,6=0,4,9,7',
 '5,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,65,2|2,6|4,9|0,5|0,4|9,7/0,6=0,3,6,2',
 '5,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,65,2|2,6|4,9|0,5|0,4|9,7/0,6=0,0,5,2',
 '5,2|2,6|4,9|0,5|0,4|9,7/0,6=0,5,2,65,2|2,6|4,9|0,5|0,4|9,7/0,6=0,9,7,0',
 '7,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,17,5|2,7|3,1|9,6|9,2|6,3/9,1=9,2,7,5',
 '7,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,17,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,1',
 '7,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,17,5|2,7|3,1|9,6|9,2|6,3/9,1=9,3,1,9',
 '7,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,17,5|2,7|3,1|9,6|9,2|6,3/9,1=9,9,2,7',
 '7,5|2,7|3,1|9,6|9,2|6,3/9,1=9,6,3,17,5|2,7|3,1|9,6|9,2|6,3/9,1=9,1,7,5']

In [None]:

def get_token_level_scores(original_dataset, generated_dataset):
    token_scores_arr = []
    for i in range(original_dataset.shape[0]):
        original_seq = original_dataset[i]
        generated_seq = generated_dataset[i]
        token_scores = (original_seq == generated_seq).float()
        incorrect = torch.where(token_scores==0)
        if len(incorrect[0]) > 0:
            token_scores[incorrect[0][0].item()] = -1. 
            token_scores[incorrect[0][0].item()+1:] = 0. 
        token_scores[:train_data.num_prefix_tokens] = 0.
        token_scores_arr.append(token_scores)
    return torch.stack(token_scores_arr, 0)

In [None]:

# token_lvl_scores.shape
# A[10:20]
# torch.where(A[1]==0)[0][0].item()

100%|██████████| 31/31 [00:01<00:00, 22.12it/s]


torch.Size([9920, 35])

In [None]:
def signed_log_sigmoid_loss_and_accuracy(logits, tokens, scores, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    scores = scores.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    sign_token_log_prob = jnp.where(valid > 0.0, token_log_prob * scores, jnp.array(0.0))
    # sign_seq_log_prob = jnp.sum(sign_token_log_prob, axis=-1) /valid_text_length
    # # loss = -jax.nn.log_sigmoid(sign_seq_log_prob).mean()
    # loss = -sign_seq_log_prob.mean()

    # loss = -(sign_token_log_prob.sum(axis=-1)).mean()
    loss = -jax.nn.log_sigmoid(sign_token_log_prob.sum(axis=-1) / valid_text_length).mean()

    # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    # changed to match hf implementation
    correct = jnp.where(
        (valid > 0.0) & (scores > 0.0),
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy

@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0, 0))
def train_step_onpolicy(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray, sign, dropout_key) -> Tuple[jnp.ndarray, TrainState]:
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    sign = sign.reshape(-1)[:, None].repeat(input_tokens.shape[-1], 1)
    def loss_fn(params: FrozenDict) -> jnp.ndarray:
        logits = state.apply_fn(params, input_tokens, False, rngs={'dropout': dropout_key})
        loss, acc = signed_log_sigmoid_loss_and_accuracy(
            logits, target_tokens, scores, (scores != 0).astype(jnp.int32))
        return loss, acc

    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, acc), grads = grad_fn(state.params)
    # average gradients across devices
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    return loss, acc, new_state


import numpy as np


for epoch in range(10):


    import numpy as np
    original_dataset = []
    generated_dataset = []
    # scores_dataset = []

    for input_tokens, target_tokens in tqdm(iter(train_loader)): 
        input_tokens = jnp.array(input_tokens)
        target_tokens = jnp.array(target_tokens)
        original = jnp.concatenate([input_tokens[:, :train_data.num_prefix_tokens], target_tokens[:, -train_data.num_target_tokens:]], axis=1)
        
        input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])

        generations = generate_negative_data(hf_params, train_state, input_tokens, key_gen)
        
        repeated_original = jnp.repeat(original[None, :, :], num_return_sequences, 0).transpose(1, 0, 2)
        
        original_dataset.append(repeated_original.reshape(-1, repeated_original.shape[-1]))
        generated_dataset.append(generations[0].reshape(-1, generations[0].shape[-1]))
        
        # scores_dataset.append(generations[1].reshape(-1, generations[1].shape[-1]))

    # scores_dataset = jnp.concatenate([x.reshape(-1) for x in scores_dataset], axis=0).squeeze()
    generated_dataset = jnp.concatenate(generated_dataset, axis=0)
    original_dataset = jnp.concatenate(original_dataset, axis=0)
    token_lvl_scores = get_token_level_scores( torch.tensor(np.array(original_dataset)), torch.tensor(np.array(generated_dataset)))

    on_policy_dataset = torch.utils.data.TensorDataset(
    torch.tensor(np.array(original_dataset)), torch.tensor(np.array(generated_dataset)), token_lvl_scores)

    on_policy_loader = DataLoader(on_policy_dataset, batch_size=64, shuffle=True, drop_last=True)
    on_policy_iter = iter(on_policy_loader) 

    pbar = tqdm(on_policy_loader, total=len(on_policy_loader), desc='training on-policy')

    train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter() 
    policy_train_acc, policy_test_acc = 0., 0.

    step = 0.

    for original, generation, scores in pbar:

        input_tokens = original[:, :-1].clone()
        # target_tokens = original.clone()
        pred_tokens = generation.clone()
        # target_tokens[:, :train_data.num_prefix_tokens] = -1
        pred_tokens[:, :train_data.num_prefix_tokens] = -1
        # target_tokens = target_tokens[:, 1:]
        pred_tokens = pred_tokens[:, 1:]
        # sign = 2 * (target_tokens[:, -1] == pred_tokens[:, -1]).float() - 1.

        # input_tokens = torch.cat([input_tokens, input_tokens], axis=0)
        # target_tokens = torch.cat([target_tokens, pred_tokens], axis=0)    
        # sign = torch.cat([torch.ones(original.shape[0]), sign], axis=0)
        
        input_tokens = jnp.array(input_tokens)
        pred_tokens = jnp.array(pred_tokens)
        scores = jnp.array(scores[:, 1:])

        input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])
        pred_tokens = pred_tokens.reshape(jax.device_count(), -1, pred_tokens.shape[-1])
        scores = scores.reshape(jax.device_count(), -1, scores.shape[-1])
        
        # print(input_tokens.shape, target_tokens.shape, scores.shape)

        loss, acc, train_state = train_step_onpolicy(train_state, input_tokens, pred_tokens, scores, keys_dropout)
        train_loss.update(loss.mean(), input_tokens.shape[1] * jax.device_count())  
        train_acc.update(acc.mean(), input_tokens.shape[1] * jax.device_count())    
        # phard.update(phard_token.mean(), input_tokens.shape[1] * jax.device_count())
        if step % 10 == 0:
            pbar.set_description(f'train loss: {train_loss.get()} forcing train acc: {train_acc.get(percentage=True)} policy train acc: {100. * policy_train_acc} policy test acc: {100. * policy_test_acc}')
        if step % config.eval_interval == 0:
            policy_train_acc = evaluate(hf_params, train_state, train_loader)
            policy_test_acc = evaluate(hf_params, train_state, test_loader)
            train_loss, train_acc, phard = AverageMeter(), AverageMeter(), AverageMeter()
        
        step += 1



100%|██████████| 31/31 [00:01<00:00, 21.94it/s]
train loss: 0.03486354649066925 forcing train acc: 66.58678436279297 policy train acc: 0.0 policy test acc: 0.0: 100%|██████████| 155/155 [01:32<00:00,  1.67it/s]
100%|██████████| 31/31 [00:01<00:00, 22.44it/s]
train loss: 0.031758084893226624 forcing train acc: 66.65966796875 policy train acc: 0.0 policy test acc: 0.0: 100%|██████████| 155/155 [00:06<00:00, 23.16it/s]   
100%|██████████| 31/31 [00:01<00:00, 22.36it/s]
train loss: 0.009832145646214485 forcing train acc: 66.4930191040039 policy train acc: 2.2177419662475586 policy test acc: 3.348214626312256: 100%|██████████| 155/155 [00:07<00:00, 21.77it/s] 
100%|██████████| 31/31 [00:01<00:00, 22.35it/s]
train loss: 0.009871372953057289 forcing train acc: 65.42710876464844 policy train acc: 0.0 policy test acc: 0.0: 100%|██████████| 155/155 [00:06<00:00, 22.95it/s]
100%|██████████| 31/31 [00:01<00:00, 22.37it/s]
train loss: 0.09410448372364044 forcing train acc: 66.65619659423828 policy 

In [None]:
inp.shape, t.shape

((64, 34), (64, 34))

In [None]:
tokenizer.decode([31069])

' Hearth'

In [None]:
jnp.ones_like(inp[:train_data.num_prefix_tokens]).shape

(28, 34)

In [None]:
tokenizer.batch_decode(output[0])

['9,8|3,6|8,0|0,7|5,3|9,5/9,6= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '2,5|4,0|4,9|8,1|0,8|9,2/4,1= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '0,5|0,7|5,2|2,6|8,3|7,8/0,6= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '5,0|1,6|7,1|4,2|2,5|4,7/4,0= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '0,6|2,3|1,0|3,7|1,2|6,5/1,7= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '5,0|0,4|1,9|7,5|9,6|7,1/7,4= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '8,9|3,0|7,4|6,7|9,3|8,6/8,4= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '4,7|9,5|5,0|0,1|9,4|7,8/9,8= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '3,5|1,3|8,7|6,2|2,8|6,1/6,5= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '3,0|5,2|8,5|8,3|0,9|2,7/8,9= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '4,5|7,1|5,3|2,4|2,6|6,7/2,1= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '3,2|5,1|9,7|4,5|7,3|9,4/9,2= Hearth Hearth Hearth Hearth Hearth Hearth Hearth',
 '5,1|4,3|9,7|9,

In [None]:
s.params['params'].keys(), hf['transformer'].keys()

(frozen_dict_keys(['0', '1', '10', '11', '2', '3', '4', '5', '6', '7', '8', '9', 'ln_f', 'wpe', 'wte']),
 dict_keys(['h', 'ln_f', 'wpe', 'wte']))

In [None]:
hf_conv = convert_jax_params_to_hf(hf, s.params['params'])
s.params['params'].keys(), hf['transformer'].keys(), hf_conv['params']['transformer'].keys()

transformer.h.0.attn.c_attn.bias ['transformer', 'h', '0', 'attn', 'c_attn', 'bias']
transformer.h.0.attn.c_attn.kernel ['transformer', 'h', '0', 'attn', 'c_attn', 'kernel']
transformer.h.0.attn.c_proj.bias ['transformer', 'h', '0', 'attn', 'c_proj', 'bias']
transformer.h.0.attn.c_proj.kernel ['transformer', 'h', '0', 'attn', 'c_proj', 'kernel']
transformer.h.0.ln_1.bias ['transformer', 'h', '0', 'ln_1', 'bias']
transformer.h.0.ln_1.scale ['transformer', 'h', '0', 'ln_1', 'scale']
transformer.h.0.ln_2.bias ['transformer', 'h', '0', 'ln_2', 'bias']
transformer.h.0.ln_2.scale ['transformer', 'h', '0', 'ln_2', 'scale']
transformer.h.0.mlp.c_fc.bias ['transformer', 'h', '0', 'mlp', 'c_fc', 'bias']
transformer.h.0.mlp.c_fc.kernel ['transformer', 'h', '0', 'mlp', 'c_fc', 'kernel']
transformer.h.0.mlp.c_proj.bias ['transformer', 'h', '0', 'mlp', 'c_proj', 'bias']
transformer.h.0.mlp.c_proj.kernel ['transformer', 'h', '0', 'mlp', 'c_proj', 'kernel']
transformer.h.1.attn.c_attn.bias ['transform

KeyError: 'params'

In [None]:
# evaluate(hf_params, train_state, train_loader)
# eval_step(hf_params, train_state, input_tokens, target_tokens)
inp = input_tokens.reshape(-1, input_tokens.shape[-1])
t = target_tokens.reshape(-1, target_tokens.shape[-1])

# hf = unfreeze(unreplicate(hf_params))
# s = unreplicate(train_state)
hf_conv = convert_jax_params_to_hf(hf_params, train_state.params['params'])
train_state.params['params'].keys(), hf_params['transformer'].keys(), hf_conv['transformer'].keys()

output = hf_model.generate(
    inp[:, :train_data.num_prefix_tokens], params=hf_conv, max_new_tokens=train_data.num_target_tokens, min_length=train_data.num_target_tokens+train_data.num_prefix_tokens, do_sample=False, attention_mask=jnp.ones_like(inp[:, :train_data.num_prefix_tokens]))
print(output[0][:, -train_data.num_target_tokens:], t[:, -train_data.num_target_tokens:])
acc = ((output[0][:, -train_data.num_target_tokens:] == target_tokens[:, -train_data.num_target_tokens:]).sum(1) == train_data.num_target_tokens).mean()
acc = jax.lax.pmean(acc, axis_name="batch")


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


transformer.h.0.attn.c_attn.bias ['transformer', 'h', '0', 'attn', 'c_attn', 'bias']
transformer.h.0.attn.c_attn.kernel ['transformer', 'h', '0', 'attn', 'c_attn', 'kernel']
transformer.h.0.attn.c_proj.bias ['transformer', 'h', '0', 'attn', 'c_proj', 'bias']
transformer.h.0.attn.c_proj.kernel ['transformer', 'h', '0', 'attn', 'c_proj', 'kernel']
transformer.h.0.ln_1.bias ['transformer', 'h', '0', 'ln_1', 'bias']
transformer.h.0.ln_1.scale ['transformer', 'h', '0', 'ln_1', 'scale']
transformer.h.0.ln_2.bias ['transformer', 'h', '0', 'ln_2', 'bias']
transformer.h.0.ln_2.scale ['transformer', 'h', '0', 'ln_2', 'scale']
transformer.h.0.mlp.c_fc.bias ['transformer', 'h', '0', 'mlp', 'c_fc', 'bias']
transformer.h.0.mlp.c_fc.kernel ['transformer', 'h', '0', 'mlp', 'c_fc', 'kernel']
transformer.h.0.mlp.c_proj.bias ['transformer', 'h', '0', 'mlp', 'c_proj', 'bias']
transformer.h.0.mlp.c_proj.kernel ['transformer', 'h', '0', 'mlp', 'c_proj', 'kernel']
transformer.h.1.attn.c_attn.bias ['transform

ValueError: Incompatible shapes for broadcasting: shapes=[(64, 7), (4, 7, 34)]

In [None]:
hf_conv.keys(   )

frozen_dict_keys(['k'])

In [None]:
hf_conv['params']['transformer']['h']['0']['attn']['c_attn']['kernel'].shape, hf_params['transformer']['h']['0']['attn']['c_attn']['kernel'].shapea

((2304, 768), (2304, 768))

In [None]:
input_tokens = input_tokens.reshape(-1, input_tokens.shape[-1])
target_tokens = target_tokens.reshape(-1, target_tokens.shape[-1])
hf_params = unfreeze(hf_params)
for k in ('ln_f', 'wpe', 'wte'):
    hf_params['transformer'][k] = train_state.params['params'][k]
hf_model.generate(
    input_tokens, params=hf_params, max_new_tokens=max_new_tokens, prng_key=key, num_beams=5, num_return_sequences=5, temperature=1.0, attention_mask=jnp.ones_like(input_tokens))

  0%|          | 0/31 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
100%|██████████| 31/31 [01:13<00:00,  2.38s/it]


In [None]:
scores_dataset.shape, generated_dataset.shape, original_dataset.shape   

((9920,), (9920, 29), (9920, 29))

In [None]:
def signed_log_sigmoid_loss_and_accuracy(logits, tokens, sign, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    sign = sign.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    sign_token_log_prob = jnp.where(valid > 0.0, token_log_prob * sign, jnp.array(0.0))
    # sign_seq_log_prob = jnp.sum(sign_token_log_prob, axis=-1) /valid_text_length
    # # loss = -jax.nn.log_sigmoid(sign_seq_log_prob).mean()
    # loss = -sign_seq_log_prob.mean()

    loss = -jax.nn.log_sigmoid(sign_token_log_prob.sum(axis=-1) / valid_text_length).mean()

    # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    # changed to match hf implementation
    correct = jnp.where(
        (valid > 0.0) & (sign > 0.0),
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy

@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0, 0))
def train_step_onpolicy(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray, sign, dropout_key) -> Tuple[jnp.ndarray, TrainState]:
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    sign = sign.reshape(-1)[:, None].repeat(input_tokens.shape[-1], 1)
    def loss_fn(params: FrozenDict) -> jnp.ndarray:
        logits = state.apply_fn(params, input_tokens, False, rngs={'dropout': dropout_key})
        loss, acc = signed_log_sigmoid_loss_and_accuracy(
            logits, target_tokens, sign, (target_tokens > 0).astype(jnp.int32))
        return loss, acc

    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, acc), grads = grad_fn(state.params)
    # average gradients across devices
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    return loss, acc, new_state


import numpy as np
on_policy_dataset = torch.utils.data.TensorDataset(
    torch.tensor(np.array(original_dataset)), torch.tensor(np.array(generated_dataset)), torch.tensor(np.array(scores_dataset)))

on_policy_loader = DataLoader(on_policy_dataset, batch_size=8, shuffle=True, drop_last=True)
on_policy_iter = iter(on_policy_loader) 

pbar = tqdm(on_policy_loader, total=len(on_policy_loader), desc='training on-policy')

train_loss, train_acc = AverageMeter(), AverageMeter()

val_loss, val_acc = jnp.inf, 0.

step = 0.

for original, generation, _ in pbar:

    input_tokens = original[:, :-1].clone()
    target_tokens = original.clone()
    pred_tokens = generation.clone()
    target_tokens[:, :train_data.num_prefix_tokens] = -1
    pred_tokens[:, :train_data.num_prefix_tokens] = -1
    target_tokens = target_tokens[:, 1:]
    pred_tokens = pred_tokens[:, 1:]
    sign = 2 * (target_tokens[:, -1] == pred_tokens[:, -1]).float() - 1.

    input_tokens = torch.cat([input_tokens, input_tokens], axis=0)
    target_tokens = torch.cat([target_tokens, pred_tokens], axis=0)    
    sign = torch.cat([torch.ones(original.shape[0]), sign], axis=0)
    
    input_tokens = jnp.array(input_tokens)
    target_tokens = jnp.array(target_tokens)
    sign = jnp.array(sign)

    input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])
    target_tokens = target_tokens.reshape(jax.device_count(), -1, target_tokens.shape[-1])
    sign = sign.reshape(jax.device_count(), -1)
    
    loss, acc, train_state = train_step_onpolicy(train_state, input_tokens, target_tokens, sign, keys_dropout)
    train_loss.update(loss.mean(), input_tokens.shape[1] * jax.device_count())  
    train_acc.update(acc.mean(), input_tokens.shape[1] * jax.device_count())    
    if step % 10 == 0:
        pbar.set_description(f'train loss: {train_loss.get()} train acc: {train_acc.get(percentage=True)} val loss: {val_loss} val acc: {val_acc}')
    if step % 50 == 0:
        val_loss, val_acc = evaluate(train_state, test_loader)
        train_loss, train_acc = AverageMeter(), AverageMeter()
    
    step += 1
    


training on-policy:   0%|          | 0/1240 [03:42<?, ?it/s]


XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 25.00M. That was not possible. There are 21.28M free.; (0x0x0_HBM0)

In [None]:
train_state = jax.device_get(train_state)   

In [None]:
train_state.tx.hyperparams

AttributeError: 'GradientTransformationExtraArgs' object has no attribute 'hyperparams'

In [None]:
train_state.tx = optax.chain(        
        optax.clip_by_global_norm(config.grad_clip),
        optax.adamw(1e-6, *config.betas, weight_decay=config.weight_decay, mask=param_decay_mask(train_state.params)),
        optax.apply_every(config.gradient_accumulation_steps),
    )
    

FrozenInstanceError: cannot assign to field 'tx'

In [None]:
B = optax.chain(
    # Apply weight decay only to non-bias parameters
    optax.clip_by_global_norm(config.grad_clip),
    optax.adamw(learning_rate, *config.betas, weight_decay=config.weight_decay),
    optax.apply_every(config.gradient_accumulation_steps),
)

In [None]:
B.update.params


AttributeError: 'function' object has no attribute 'params'