In [1]:
import sys
sys.path.append('../../')
import time
from jax_smi import initialise_tracking
initialise_tracking()
from tqdm import tqdm

In [2]:
from dataclasses import dataclass, field, asdict
from typing import Tuple, Optional, Union
from EasyLM.models.gpt2.gpt2_model import GPT, GPTConfig, get_pretrained_params
from torch.utils.data import DataLoader

@dataclass(frozen=True)
class WandbConfig:
    """
    wandb logging configuration
    """
    entity: str = 'ars22'
    """username or team name where you're sending runs"""
    project: str = 'star_graph'
    """project name"""
    name: str = 'gpt2'
    """experiment name"""
    mode: str = 'online'
    """'offline', 'online', or 'disabled'"""
    notes: str = ''


@dataclass(frozen=True)
class CosineDecayScheduleConfig:
    init_value: float = 0.0
    peak_value: float = 2.5e-4
    warmup_steps: int = 2000
    decay_steps: int = 150000
    end_value: float = 1e-5

@dataclass(frozen=True)
class StaticLRConfig:
    init_value: float = 1e-5


@dataclass(frozen=False)
class TrainConfig:
    gpt2_model_type: str = 'gpt2-large' # gpt2 model type
    seed: int = 555
    out_dir: str = 'out'                        # output directory for checkpoints (can be gcs path)
    shuffle_buffer_size: int = 128
    eval_interval: int = 500
    eval_steps: int = 16        # evaluate for this number of steps (per-device)
    eval_only: bool = False     # if True, script exits right after the first eval
    keep_checkpoints: int = 3   # number of historical checkpoints to keep
    batch_size: int = 64        # per-device batch size
    train_steps: int = 6250   # total number of training iterations
    weight_decay: float = 1e-2  # not applied to bias and embedding parameters
    grad_clip: float = 1.0      # gradient norm clipping magnitude
    gradient_accumulation_steps: int = 1    # used to simulate larger batch sizes
    betas: Tuple[float, float] = (0.9, 0.95) # adamw optimizer betas
    # learning_rate: CosineDecayScheduleConfig = field(default_factory=CosineDecayScheduleConfig)
    learning_rate: StaticLRConfig = field(default_factory=StaticLRConfig)
    wandb: WandbConfig = field(default_factory=WandbConfig) # wandb logging
    model: GPTConfig = field(default_factory=GPTConfig)     # gpt model config
    remat: bool = False    # set to True to rematerialize gradients during backward pass


def get_default_config() -> TrainConfig:
    return TrainConfig()

config = get_default_config()
config

  from .autonotebook import tqdm as notebook_tqdm


TrainConfig(gpt2_model_type='gpt2-large', seed=555, out_dir='out', shuffle_buffer_size=128, eval_interval=500, eval_steps=16, eval_only=False, keep_checkpoints=3, batch_size=64, train_steps=6250, weight_decay=0.01, grad_clip=1.0, gradient_accumulation_steps=1, betas=(0.9, 0.95), learning_rate=StaticLRConfig(init_value=1e-05), wandb=WandbConfig(entity='ars22', project='star_graph', name='gpt2', mode='online', notes=''), model=GPTConfig(block_size=1024, vocab_size=50257, num_layers=12, num_heads=12, num_embeds=768, dropout_rate=0.1, use_bias=True, dtype=None), remat=False)

In [3]:
import jax
import jax.numpy as jnp
import flax
from flax.core import FrozenDict, frozen_dict
from flax.training import checkpoints
from flax.training.train_state import TrainState
from flax.jax_utils import replicate, unreplicate
import optax
from functools import partial

2024-04-24 16:27:32.136272: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-04-24 16:27:32.805504: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-04-24 16:27:32.805589: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


In [4]:
def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    loss = -(jnp.sum(token_log_prob) / jnp.sum(valid))
    # old: loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    # changed to match hf implementation
    correct = jnp.where(
        valid > 0.0,
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, 0))
def train_step(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray, dropout_key) -> Tuple[jnp.ndarray, TrainState]:
    dropout_key = jax.random.fold_in(dropout_key, state.step)
    def loss_fn(params: FrozenDict) -> jnp.ndarray:
        logits = state.apply_fn(params, input_tokens, False, rngs={'dropout': dropout_key})
        return cross_entropy_loss_and_accuracy(
            logits, target_tokens, (target_tokens > 0).astype(jnp.int32))
    # per-device loss and grads
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, acc), grads = grad_fn(state.params)
    # average gradients across devices
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    new_state = state.apply_gradients(grads=grads)
    return loss, acc, new_state


@partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0))
def eval_step(state: TrainState, input_tokens: jnp.ndarray, target_tokens: jnp.ndarray) -> jnp.ndarray:
    logits = state.apply_fn(state.params, input_tokens, True)
    loss, acc = cross_entropy_loss_and_accuracy(
            logits, target_tokens, (target_tokens > 0).astype(jnp.int32))
    loss = jax.lax.pmean(loss, axis_name="batch")
    acc = jax.lax.pmean(acc, axis_name="batch")
    return loss, acc



def evaluate(state: TrainState, loader: DataLoader) -> jnp.ndarray:
    losses = []
    accs = []
    for batch in loader:
        input_tokens, target_tokens = batch
        input_tokens = jnp.array(input_tokens)
        target_tokens = jnp.array(target_tokens)
        input_tokens = input_tokens.reshape(jax.local_device_count(), -1, input_tokens.shape[-1])
        target_tokens = target_tokens.reshape(jax.local_device_count(), -1, target_tokens.shape[-1])
        loss, acc = eval_step(state, input_tokens, target_tokens)
        losses.append(loss)
        accs.append(acc)
    return jnp.mean(jnp.stack(losses)), jnp.mean(jnp.stack(accs))

In [5]:
from torch.utils.data import Dataset
import torch

def prefix_target_list(filename=None):
    """
    Load graphs and split them into prefix and target and return the list
    """
    data_list = []
    with open(filename, 'r') as f:
        lines = f.readlines()
    for line in lines:
        prefix = line.strip().split('=')[0] + '='
        target = line.strip().split('=')[1]
        target = target.split(',')[1]
        data_list.append((prefix, target))
    return data_list


class Graphs(Dataset):
    def __init__(self, tokenizer, n_samples, data_path):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.data_path = data_path
        self.eval_mode = False
        self.data_file = prefix_target_list(self.data_path)
        self.tokenized, self.num_prefix_tokens, self.num_target_tokens = self.tokenize(self.data_file[:n_samples])

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        if self.eval_mode:
            # In eval mode return the entire sequence
            return self.tokenized[idx].to(self.device)

        # Create inputs
        x = self.tokenized[idx].clone()
        y = torch.cat([-torch.ones((self.num_prefix_tokens - 1, )),
                       x[self.num_prefix_tokens:].clone()])
        return x[:-1], y.long()

    def tokenize(self, data_list):
        """
        Takes a list of prefix-target pairs, tokenizes and concatenates them
        """
        out = []
        prefix_len = len(self.tokenizer.encode(data_list[0][0]))
        target_len = len(self.tokenizer.encode(data_list[0][1]))
        same_len = True

        for prefix, target in data_list:
            prefix = torch.tensor(self.tokenizer.encode(prefix))
            target = torch.tensor(self.tokenizer.encode(target))
            if not (len(prefix) == prefix_len and len(target) == target_len):
                same_len = False
            seq = torch.concatenate([prefix, target], dim=-1).long()
            out.append(seq)

        # Check if all prefixes and all targets have the same length
        if not same_len:
            print('Not all prefixes or targets have the same length!!')
        else:
            print('Equal sequence lengths!')

        return out, prefix_len, target_len

    def eval(self):
        # Switch to "eval" mode when generating sequences without teacher-forcing
        self.eval_mode = True

    def train(self):
        # Switch back to "train" mode for teacher-forcing
        self.eval_mode = False

In [6]:
# LOAD TOKENIZER
from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

# LOAD DATASET
data_path = 'deg_2_path_5_nodes_10'
train_path, test_path = data_path + '_train_200000.txt', data_path + '_test_20000.txt'
train_data = Graphs(tokenizer=tokenizer, n_samples=20000, data_path=train_path)
test_data = Graphs(tokenizer=tokenizer, n_samples=200, data_path=test_path)
train_data.train()

# sanity check
print(train_data[0], tokenizer.decode(train_data[0][0]), tokenizer.decode(train_data[0][1][-train_data.num_target_tokens:]))

# LOAD DATALOADER
train_loader = DataLoader(train_data, batch_size=config.batch_size, shuffle=True, drop_last=True) 
test_loader = DataLoader(test_data, batch_size=config.batch_size, shuffle=False, drop_last=True)


Equal sequence lengths!
Equal sequence lengths!
(tensor([17, 11, 23, 91, 24, 11, 15, 91, 17, 11, 22, 91, 21, 11, 19, 91, 19, 11,
        18, 91, 22, 11, 21, 91, 23, 11, 24, 91, 15, 11, 16, 14, 17, 11, 18, 28]), tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 22])) 2,8|9,0|2,7|6,4|4,3|7,6|8,9|0,1/2,3= 7


In [7]:
print("tokenizer vocab size: ", tokenizer.vocab_size, train_data.num_prefix_tokens, train_data.num_target_tokens)

tokenizer vocab size:  50257 36 1


In [8]:
def param_decay_mask(params: FrozenDict) -> FrozenDict:
    """ pytree mask for non-bias parameters """
    flat_params = flax.traverse_util.flatten_dict(params)
    flat_param_mask = {k: k[-1] not in ('bias', 'embedding', 'scale') for k in flat_params.keys()}
    param_mask = flax.traverse_util.unflatten_dict(flat_param_mask)
    return frozen_dict.freeze(param_mask)

def init_train_state(key, config: TrainConfig, learning_rate) -> TrainState:

    if config.remat:
        model = flax.linen.remat(GPT,
            static_argnums=(2,),
            policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims)(config.model)
    else:
        config.model, params = get_pretrained_params(config.gpt2_model_type)
        model = GPT(config.model)    
        model.init(key)

    optimizer = optax.chain(
        # Apply weight decay only to non-bias parameters
        optax.clip_by_global_norm(config.grad_clip),
        optax.adamw(learning_rate, *config.betas, weight_decay=config.weight_decay, mask=param_decay_mask(params)),
        optax.apply_every(config.gradient_accumulation_steps),
    )

    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer)

    return train_state

def count_params(params: FrozenDict) -> int:
    p = jax.tree_util.tree_map(lambda a: a.size if isinstance(a, jnp.ndarray) else 0, params)
    return jax.tree_util.tree_reduce(lambda a, b: a + b, p)

In [9]:
# =====  init parameters ============
key = jax.random.PRNGKey(config.seed)
key, key_params, key_dropout = jax.random.split(key, 3)
# make sure dropout keys are different for each device (local and global)
key_dropout = jax.random.fold_in(key_dropout, jax.process_index())
keys_dropout = jax.random.split(key_dropout, jax.local_device_count())

In [10]:
learning_rate = config.learning_rate.init_value
train_state = init_train_state(key_params, config, learning_rate)
num_params = count_params(train_state.params)

loading weights from pretrained gpt: gpt2-large


In [11]:
print(f"Total parameters: {num_params:,}") # 774,030,080 for gpt2-large

Total parameters: 774,030,080


In [12]:
# replicate model
train_state = replicate(train_state)

In [13]:
class AverageMeter:
    def __init__(self):
        self.num = 0
        self.val = 0

    def update(self, val, num):
        self.val += val * num
        self.num += num

    def get(self, percentage=False):
        val = self.val / self.num * 100 if percentage else self.val / self.num
        return val

In [56]:
train_iter = iter(train_loader)
pbar = tqdm(range(config.train_steps), total=config.train_steps, desc='training')
train_loss, train_acc = AverageMeter(), AverageMeter()
val_loss, val_acc = jnp.inf, 0.
for step in pbar:
    try:
        input_tokens, target_tokens = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
    input_tokens = jnp.array(input_tokens)
    target_tokens = jnp.array(target_tokens) 
    input_tokens = input_tokens.reshape(jax.device_count(), -1, input_tokens.shape[-1])
    target_tokens = target_tokens.reshape(jax.device_count(), -1, target_tokens.shape[-1])
    loss, acc, train_state = train_step(train_state, input_tokens, target_tokens, keys_dropout)
    train_loss.update(loss.mean(), input_tokens.shape[1] * jax.device_count())  
    train_acc.update(acc.mean(), input_tokens.shape[1] * jax.device_count())    
    if step % 100 == 0:
        pbar.set_description(f'train loss: {train_loss.get()} train acc: {train_acc.get(percentage=True)} val loss: {val_loss} val acc: {val_acc}')
    if step % config.eval_interval == 0:
        val_loss, val_acc = evaluate(train_state, test_loader)
        train_loss, train_acc = AverageMeter(), AverageMeter()
    

training:   0%|          | 0/6250 [00:00<?, ?it/s]

train loss: 0.15650126338005066 train acc: 94.125 val loss: 1.3534506559371948 val acc: 0.5052083730697632: 100%|██████████| 6250/6250 [20:00<00:00,  5.20it/s]           


In [15]:
from transformers import FlaxGPT2LMHeadModel

FlaxGPT2LMHeadModel

AttributeError: type object 'FlaxGPT2LMHeadModel' has no attribute 'from_flax'

<bound method FlaxPreTrainedModel.from_pretrained of <class 'transformers.models.gpt2.modeling_flax_gpt2.FlaxGPT2LMHeadModel'>>

In [73]:
from transformers import FlaxGPT2LMHeadModel
hf_model = GPT(
            config=config.model,
            # input_shape=(1, train_data.num_prefix_tokens),
            # seed=22,
            # _do_init=False
            )
hf_model.init(key_params, method='can_generate')


TypeError: init() got an unexpected keyword argument 'method'

In [72]:
hf_model.can_generate()

AttributeError: "GPT" object has no attribute "can_generate". If "can_generate" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [67]:
train_state = unreplicate(train_state)
FlaxGPT2LMHeadModel

In [74]:
hf_model.apply(train_state.params, jnp.ones((1, train_data.num_prefix_tokens)))

ValueError: Input type must be an integer or unsigned integer.

In [75]:
from transformers import AutoTokenizer, FlaxGPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2")

In [79]:
inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
outputs = model.generate(**inputs,
    num_beams=5,
    max_new_tokens=1,
    num_return_sequences=5,
    temperature=1.0,)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [103]:
model.config

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 50257
}

GPTConfig(block_size=1024, vocab_size=50257, num_layers=36, num_heads=20, num_embeds=1280, dropout_rate=0.1, use_bias=True, dtype=None)

In [123]:
# type(model.config), type(config.model)
model.config.n_layer = 36
model.config.n_head = 20
model.config.n_embd = 1280
# model.config.

In [117]:
unfreeze(config.model)

AttributeError: 'GPTConfig' object has no attribute 'to_dict'

In [29]:

from transformers import AutoTokenizer, FlaxGPT2LMHeadModel
# hf_model = FlaxGPT2LMHeadModel(
#     config.model,
#     input_shape=(1, train_data.num_prefix_tokens),
#     seed=22,
#     _do_init=False
# )
hf_model = FlaxGPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")

In [30]:
A = input_tokens.reshape(-1, 36)
A

tensor([[17, 11, 23, 91, 24, 11, 15, 91, 17, 11, 22, 91, 21, 11, 19, 91, 19, 11,
         18, 91, 22, 11, 21, 91, 23, 11, 24, 91, 15, 11, 16, 14, 17, 11, 18, 28]])

In [31]:
train_state.params['params']

FrozenDict({
    0: {
        attn: {
            c_attn: {
                bias: Array([-0.00806067, -0.08205216,  0.17712198, ..., -0.03094645,
                       -0.02834609,  0.0028557 ], dtype=float32),
                kernel: Array([[ 0.16554965,  0.12297295,  0.10031797, ..., -0.00807998,
                         0.0106448 , -0.01827521],
                       [-0.23444045,  0.14132349,  0.07059898, ..., -0.0105182 ,
                         0.02387178, -0.01008427],
                       [ 0.1062863 , -0.03969869,  0.10846853, ..., -0.00417321,
                         0.01832751, -0.00796596],
                       ...,
                       [ 0.00202721,  0.12571906, -0.07979144, ...,  0.00236412,
                         0.03511722,  0.02043647],
                       [-0.11458338, -0.08969299, -0.09247336, ...,  0.00126929,
                         0.00066453, -0.0041295 ],
                       [ 0.02215027, -0.01706643, -0.04627626, ...,  0.02902106,
           

In [32]:
input_tokens, target_tokens = train_data[0]
A = input_tokens.reshape(-1, input_tokens.shape[-1])

In [33]:
A = jnp.array(A)
train_state = unreplicate(train_state)

IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.

In [34]:
from EasyLM.jax_utils import JaxRNG
hf_model.generate(
    A,
    params={'transformer': train_state.params['params']},
    max_new_tokens=1,
    # num_return_sequences=5,
    # beam
    prng_key=JaxRNG(key))
    # logits_processor=FlaxLogitsProcessorList(
    #     [FlaxTemperatureLogitsWarper(temperature)]
    # ),

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


ScopeParamNotFoundError: Could not find parameter named "scale" in scope "/transformer/h/0/ln_1". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)

In [42]:
train_state.params['params']['0']

FrozenDict({
    attn: {
        c_attn: {
            bias: Array([-0.00806067, -0.08205216,  0.17712198, ..., -0.03094645,
                   -0.02834609,  0.0028557 ], dtype=float32),
            kernel: Array([[ 0.16554965,  0.12297295,  0.10031797, ..., -0.00807998,
                     0.0106448 , -0.01827521],
                   [-0.23444045,  0.14132349,  0.07059898, ..., -0.0105182 ,
                     0.02387178, -0.01008427],
                   [ 0.1062863 , -0.03969869,  0.10846853, ..., -0.00417321,
                     0.01832751, -0.00796596],
                   ...,
                   [ 0.00202721,  0.12571906, -0.07979144, ...,  0.00236412,
                     0.03511722,  0.02043647],
                   [-0.11458338, -0.08969299, -0.09247336, ...,  0.00126929,
                     0.00066453, -0.0041295 ],
                   [ 0.02215027, -0.01706643, -0.04627626, ...,  0.02902106,
                     0.02582104, -0.0327217 ]], dtype=float32),
        },
        c

In [98]:
from flax.core import FrozenDict, freeze, unfreeze
params = unfreeze(train_state.params['params'])

In [40]:
{'transformer': train_state.params}

{'transformer': FrozenDict({
     params: {
         0: {
             attn: {
                 c_attn: {
                     bias: Array([-0.00806067, -0.08205216,  0.17712198, ..., -0.03094645,
                            -0.02834609,  0.0028557 ], dtype=float32),
                     kernel: Array([[ 0.16554965,  0.12297295,  0.10031797, ..., -0.00807998,
                              0.0106448 , -0.01827521],
                            [-0.23444045,  0.14132349,  0.07059898, ..., -0.0105182 ,
                              0.02387178, -0.01008427],
                            [ 0.1062863 , -0.03969869,  0.10846853, ..., -0.00417321,
                              0.01832751, -0.00796596],
                            ...,
                            [ 0.00202721,  0.12571906, -0.07979144, ...,  0.00236412,
                              0.03511722,  0.02043647],
                            [-0.11458338, -0.08969299, -0.09247336, ...,  0.00126929,
                              0.00066