In [1]:
import sys
sys.path.append('../../')
from EasyLM.models.gptj.gptj_model import (
    GPTJConfig, FlaxGPTJForCausalLMModule
)
from tqdm import tqdm, trange
import sys
from jax.experimental.pjit import pjit
from EasyLM.jax_utils import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
    set_random_seed, average_metrics, get_weight_decay_mask,
    make_shard_and_gather_fns, with_sharding_constraint, average_metrics, get_jax_mesh
)

set_random_seed(22)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import Dataset
import torch

def prefix_target_list(filename=None):
    """
    Load graphs and split them into prefix and target and return the list
    """
    data_list = []
    with open(filename, 'r') as f:
        lines = f.readlines()
    for line in lines:
        prefix = line.strip().split('=')[0] + '='
        target = line.strip().split('=')[1]
        target = target.split(',')[1]
        data_list.append((prefix, target))
    return data_list


class Graphs(Dataset):
    def __init__(self, tokenizer, n_samples, data_path):
        self.tokenizer = tokenizer
        self.n_samples = n_samples
        self.data_path = data_path
        self.eval_mode = False
        self.data_file = prefix_target_list(self.data_path)
        self.tokenized, self.num_prefix_tokens, self.num_target_tokens = self.tokenize(self.data_file[:n_samples])

    def __len__(self):
        return len(self.tokenized)

    def __getitem__(self, idx):
        if self.eval_mode:
            # In eval mode return the entire sequence
            return self.tokenized[idx].to(self.device)

        # Create inputs
        x = self.tokenized[idx].clone()
        y = torch.cat([-torch.ones((self.num_prefix_tokens - 1, )),
                       x[self.num_prefix_tokens:].clone()])
        return x[:-1], y.long()

    def tokenize(self, data_list):
        """
        Takes a list of prefix-target pairs, tokenizes and concatenates them
        """
        out = []
        prefix_len = len(self.tokenizer.encode(data_list[0][0]))
        target_len = len(self.tokenizer.encode(data_list[0][1]))
        same_len = True

        for prefix, target in data_list:
            prefix = torch.tensor(self.tokenizer.encode(prefix))
            target = torch.tensor(self.tokenizer.encode(target))
            if not (len(prefix) == prefix_len and len(target) == target_len):
                same_len = False
            seq = torch.concatenate([prefix, target], dim=-1).long()
            out.append(seq)

        # Check if all prefixes and all targets have the same length
        if not same_len:
            print('Not all prefixes or targets have the same length!!')
        else:
            print('Equal sequence lengths!')

        return out, prefix_len, target_len

    def eval(self):
        # Switch to "eval" mode when generating sequences without teacher-forcing
        self.eval_mode = True

    def train(self):
        # Switch back to "train" mode for teacher-forcing
        self.eval_mode = False

In [3]:
from torch.utils.data import DataLoader

# LOAD TOKENIZER
from transformers import AutoTokenizer # type: ignore
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')
tokenizer.pad_token_id = tokenizer.eos_token_id

# LOAD DATASET
data_path = 'deg_2_path_4_nodes_10'
train_path, test_path = data_path + '_train_200000.txt', data_path + '_test_20000.txt'
train_data = Graphs(tokenizer=tokenizer, n_samples=32, data_path=train_path)
test_data = Graphs(tokenizer=tokenizer, n_samples=100, data_path=test_path)
train_data.train()

# sanity check
print(train_data[0], tokenizer.decode(train_data[0][0]), tokenizer.decode(train_data[0][1][-train_data.num_target_tokens:]))

# LOAD DATALOADER
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, drop_last=True) 

Equal sequence lengths!
Equal sequence lengths!


2024-04-23 22:41:30.581446: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-04-23 22:41:31.262694: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib
2024-04-23 22:41:31.262793: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/lib


(tensor([21, 11, 24, 91, 22, 11, 21, 91, 24, 11, 15, 91, 20, 11, 17, 91, 22, 11,
        19, 91, 19, 11, 20, 14, 22, 11, 17, 28]), tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, 19])) 6,9|7,6|9,0|5,2|7,4|4,5/7,2= 4


In [4]:
# LOAD MODEL

gpt_config = GPTJConfig()
gpt_config.update(dict(
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
))

model = FlaxGPTJForCausalLMModule(gpt_config, dtype=get_float_dtype_by_name('fp32'))

In [5]:
import math
from EasyLM.optimizers import OptimizerFactory

num_epochs = 100
warmup_ratio = 1.
learning_rate = 1e-5
weight_decay = 0.
end_lr = 1e-5
simulated_steps_per_epoch = len(train_loader)

optimizer_config = OptimizerFactory.get_default_config()
total_simulated_steps = num_epochs * simulated_steps_per_epoch
optimizer_config.adamw_optimizer.clip_gradient = 10000.
optimizer_config.adamw_optimizer.lr_decay_steps = total_simulated_steps
optimizer_config.adamw_optimizer.warmup_ratio = warmup_ratio
optimizer_config.adamw_optimizer.end_lr = end_lr
optimizer_config.adamw_optimizer.init_lr = 1e-5
optimizer_config.adamw_optimizer.lr = learning_rate
optimizer_config.adamw_optimizer.weight_decay = weight_decay

if optimizer_config.adamw_optimizer.warmup_ratio > 0:
    optimizer_config.adamw_optimizer.lr_warmup_steps = math.ceil(optimizer_config.adamw_optimizer.warmup_ratio * total_simulated_steps)

print(f"Total simulated steps: {total_simulated_steps}")
print(f"Total simulated warmup steps: {optimizer_config.adamw_optimizer.lr_warmup_steps}")
print(f"Total simulated decay steps: {optimizer_config.adamw_optimizer.lr_decay_steps}")

optimizer, optimizer_info = OptimizerFactory.get_optimizer(
    optimizer_config,
)

Total simulated steps: 100
Total simulated warmup steps: 100
Total simulated decay steps: 100


In [6]:
optimizer_config

accumulate_gradient_steps: 1
adamw_optimizer:
  b1: 0.9
  b2: 0.95
  bf16_momentum: false
  clip_gradient: 10000.0
  end_lr: 1.0e-05
  init_lr: 1.0e-05
  lr: 1.0e-05
  lr_decay_steps: 100
  lr_warmup_steps: 100
  multiply_by_parameter_scale: false
  warmup_ratio: 1.0
  weight_decay: 0.0
palm_optimizer:
  b1: 0.9
  b2: 0.99
  bf16_momentum: false
  clip_gradient: 1.0
  lr: 0.01
  lr_warmup_steps: 10000
  weight_decay: 0.0001
type: adamw

In [7]:
# running epochs

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState


seq_length = train_data.num_prefix_tokens + train_data.num_target_tokens  

def create_trainstate_from_params(params):
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)


def init_fn(rng):
    rng_generator = JaxRNG(rng)
    params = model.init(
        input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
        position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
        attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
        rngs=rng_generator(gpt_config.rng_keys()),
    )
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)


print("Initializing training state and pjitting...")
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
    GPTJConfig.get_partition_rules(), train_state_shapes
)

axis_dims = '-1,4,1'
mesh = get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))


def eval_step(train_state, rng, batch):
    rng_generator = JaxRNG(rng)
    batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
    logits = model.apply(
        train_state.params, batch['input_tokens'], batch['attention_mask'],
        deterministic=True, rngs=rng_generator(GPTJConfig.rng_keys()),
    ).logits
    loss, accuracy = cross_entropy_loss_and_accuracy(
        logits, batch['target_tokens'], batch['loss_masks']
    )
    metrics = dict(
        eval_loss=loss,
        eval_accuracy=accuracy,
    )
    return rng_generator(), metrics

def train_step(train_state, rng, batch):
    rng_generator = JaxRNG(rng)
    batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
    def loss_and_accuracy(params):
        logits = model.apply(
            params, batch['input_tokens'], batch['attention_mask'],
            deterministic=False, rngs=rng_generator(GPTJConfig.rng_keys()),
        ).logits
        return cross_entropy_loss_and_accuracy(
            logits, batch['target_tokens'], batch['loss_masks']
        )
    grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
    (loss, accuracy), grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    metrics = dict(
        loss=loss,
        accuracy=accuracy,
        learning_rate=optimizer_info['learning_rate_schedule'](train_state.step // optimizer_config.accumulate_gradient_steps),
        gradient_norm=global_norm(grads),
        param_norm=global_norm(train_state.params),
    )
    return train_state, rng_generator(), metrics

print("Initializing training state and pjitting...")
train_state_shapes = jax.eval_shape(init_fn, next_rng())
train_state_partition = match_partition_rules(
    GPTJConfig.get_partition_rules(), train_state_shapes
)

shard_fns, gather_fns = make_shard_and_gather_fns(
    train_state_partition, train_state_shapes
)

sharded_init_fn = pjit(
    init_fn,
    in_shardings=PS(),
    out_shardings=train_state_partition
)

sharded_create_trainstate_from_params = pjit(
    create_trainstate_from_params,
    in_shardings=(train_state_partition.params, ),
    out_shardings=train_state_partition,
    donate_argnums=(0, ),
)

sharded_train_step = pjit(
    train_step,
    in_shardings=(train_state_partition, PS(), PS()),
    out_shardings=(train_state_partition, PS(), PS()),
    donate_argnums=(0, 1),
)

sharded_eval_step = pjit(
    eval_step,
    in_shardings=(train_state_partition, PS(), PS()),
    out_shardings=(PS(), PS()),
    donate_argnums=(1,),
)

Initializing training state and pjitting...
Initializing training state and pjitting...


In [8]:
# train_state_shapes
import numpy as np


In [9]:
param_count = sum(np.prod(x.shape) for x in jax.tree_leaves(train_state_shapes.params))
print(param_count)

6050882784


  param_count = sum(np.prod(x.shape) for x in jax.tree_leaves(train_state_shapes.params))


In [10]:
import time
from jax_smi import initialise_tracking
initialise_tracking()

# with mesh:
#     train_state = sharded_init_fn(next_rng())

#     start_step = int(jax.device_get(train_state.step))
#     start_epoch = start_step // simulated_steps_per_epoch
#     start_step = start_step % simulated_steps_per_epoch

#     sharded_rng = next_rng()

#     epoch_counter = trange(start_epoch, num_epochs, ncols=0, position=0)
#     step_counter = trange(start_step, simulated_steps_per_epoch, ncols=0, position=1)
    
    # overall_step = 0
    # for epoch in epoch_counter:
    #     for step, batch in tqdm(zip(step_counter, train_loader), total=simulated_steps_per_epoch):
    #         x, y = batch
    #         x = x.numpy()
    #         y = y.numpy()
            # if isinstance(batch, (list, tuple)):
            #     batch = {
            #         'input_tokens': x,
            #         'target_tokens': y,
            #         'loss_masks': (y != -1).astype(jnp.int32),
            #         'attention_mask': (y != -1).astype(jnp.int32)
            #     }
            # # just measuring the train step time.
            # start_time = time.time()
            # train_state, sharded_rng, metrics = sharded_train_step(
            #     train_state, sharded_rng, batch
            # )
            # step_time = time.time() - start_time
            # overall_step += 1

    #         if step % FLAGS.log_freq == 0:
    #             if FLAGS.eval_steps > 0:
    #                 eval_metric_list = []
    #                 for batch in eval_dataset:
    #                     if isinstance(batch, (list, tuple)):
    #                         batch = {
    #                             'tokens': batch[0],
    #                             'loss_masks': batch[1],
    #                         }
    #                     sharded_rng, eval_metrics = sharded_eval_step(
    #                         train_state, sharded_rng, batch
    #                     )
    #                     eval_metric_list.append(eval_metrics)
    #                 metrics.update(average_metrics(eval_metric_list))
    #             log_metrics = {
    #                 "train/step": overall_step,
    #                 "train/samples_seen": overall_step * real_batch_size,
    #                 "train/step_time": step_time,
    #                 "train/epoch": overall_step / steps_per_epoch,
    #             }
    #             log_metrics = jax.device_get(log_metrics)
    #             log_metrics.update(metrics)
    #             log_metrics = {k: float(v) for k, v in log_metrics.items()}
    #             tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")


In [11]:
import pprint

with mesh:
    train_state = sharded_init_fn(next_rng())
    start_step = int(jax.device_get(train_state.step))
    start_epoch = start_step // simulated_steps_per_epoch
    start_step = start_step % simulated_steps_per_epoch
    sharded_rng = next_rng()
    # epoch_counter = trange(start_epoch, num_epochs, ncols=0, position=0)
    step_counter = trange(start_step, simulated_steps_per_epoch, ncols=0, position=1)
    overall_step = 0
    # num_epochs = 100
    for epoch in range(num_epochs):
        for step, batch in tqdm(zip(step_counter, train_loader), total=simulated_steps_per_epoch):
            overall_step += 1
            x, y = batch
            x = x.numpy()
            y = y.numpy()
            batch = {
                'input_tokens': x,
                'target_tokens': y,
                'loss_masks': (y != -1).astype(jnp.int32),
                'attention_mask': np.ones_like(x).astype(jnp.int32)
            }
            # just measuring the train step time.
            start_time = time.time()
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )
            step_time = time.time() - start_time

            # if log_fr > 0:
            #         eval_metric_list = []
            #         for batch in eval_dataset:
            #             if isinstance(batch, (list, tuple)):
            #                 batch = {
            #                     'tokens': batch[0],
            #                     'loss_masks': batch[1],
            #                 }
            #             sharded_rng, eval_metrics = sharded_eval_step(
            #                 train_state, sharded_rng, batch
            #             )
            #             eval_metric_list.append(eval_metrics)
                    # metrics.update(average_metrics(eval_metric_list))
            if overall_step % 1 == 0:
                log_metrics = {
                    "train/step": overall_step,
                    "train/samples_seen": overall_step * 64,
                    "train/step_time": step_time,
                    "train/epoch": overall_step / simulated_steps_per_epoch,
                }
                log_metrics = jax.device_get(log_metrics)
                log_metrics.update(metrics)
                log_metrics = {k: float(v) for k, v in log_metrics.items()}
                tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")

  0%|          | 0/1 [00:00<?, ?it/s]
100% 1/1 [02:33<00:00, 153.26s/it] 153.26s/it]
100%|██████████| 1/1 [02:33<00:00, 153.26s/it]



{'accuracy': 0.0,
 'gradient_norm': 70.19483947753906,
 'learning_rate': 9.999999747378752e-06,
 'loss': 11.480037689208984,
 'param_norm': 1133.121337890625,
 'train/epoch': 1.0,
 'train/samples_seen': 64.0,
 'train/step': 1.0,
 'train/step_time': 152.94790744781494}



100%|██████████| 1/1 [00:00<00:00,  2.56it/s]



{'accuracy': 0.3125,
 'gradient_norm': 81.10279846191406,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.586386680603027,
 'param_norm': 1133.12158203125,
 'train/epoch': 2.0,
 'train/samples_seen': 128.0,
 'train/step': 2.0,
 'train/step_time': 0.08667612075805664}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.28125,
 'gradient_norm': 73.40702819824219,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.833763122558594,
 'param_norm': 1133.1219482421875,
 'train/epoch': 3.0,
 'train/samples_seen': 192.0,
 'train/step': 3.0,
 'train/step_time': 0.05884599685668945}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.25,
 'gradient_norm': 51.12541580200195,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.576484680175781,
 'param_norm': 1133.122314453125,
 'train/epoch': 4.0,
 'train/samples_seen': 256.0,
 'train/step': 4.0,
 'train/step_time': 0.058739662170410156}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.28125,
 'gradient_norm': 39.848304748535156,
 'learning_rate': 9.999999747378752e-06,
 'loss': 6.000193119049072,
 'param_norm': 1133.1226806640625,
 'train/epoch': 5.0,
 'train/samples_seen': 320.0,
 'train/step': 5.0,
 'train/step_time': 0.05930137634277344}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.4375,
 'gradient_norm': 37.42934036254883,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.438979625701904,
 'param_norm': 1133.1231689453125,
 'train/epoch': 6.0,
 'train/samples_seen': 384.0,
 'train/step': 6.0,
 'train/step_time': 0.05977511405944824}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.46875,
 'gradient_norm': 31.475730895996094,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.109224319458008,
 'param_norm': 1133.1234130859375,
 'train/epoch': 7.0,
 'train/samples_seen': 448.0,
 'train/step': 7.0,
 'train/step_time': 0.05827474594116211}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.4375,
 'gradient_norm': 29.525474548339844,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.565159320831299,
 'param_norm': 1133.1239013671875,
 'train/epoch': 8.0,
 'train/samples_seen': 512.0,
 'train/step': 8.0,
 'train/step_time': 0.05978202819824219}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.53125,
 'gradient_norm': 20.44376564025879,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.4317532777786255,
 'param_norm': 1133.1240234375,
 'train/epoch': 9.0,
 'train/samples_seen': 576.0,
 'train/step': 9.0,
 'train/step_time': 0.05944514274597168}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.46875,
 'gradient_norm': 39.2325439453125,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.5927447080612183,
 'param_norm': 1133.1246337890625,
 'train/epoch': 10.0,
 'train/samples_seen': 640.0,
 'train/step': 10.0,
 'train/step_time': 0.060240745544433594}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.53125,
 'gradient_norm': 24.199901580810547,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.4245643615722656,
 'param_norm': 1133.125244140625,
 'train/epoch': 11.0,
 'train/samples_seen': 704.0,
 'train/step': 11.0,
 'train/step_time': 0.05920767784118652}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.46875,
 'gradient_norm': 23.16550636291504,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.5415072441101074,
 'param_norm': 1133.12548828125,
 'train/epoch': 12.0,
 'train/samples_seen': 768.0,
 'train/step': 12.0,
 'train/step_time': 0.05970263481140137}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 0.46875,
 'gradient_norm': 26.846752166748047,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.5697883367538452,
 'param_norm': 1133.1258544921875,
 'train/epoch': 13.0,
 'train/samples_seen': 832.0,
 'train/step': 13.0,
 'train/step_time': 0.05928182601928711}



100%|██████████| 1/1 [00:00<00:00,  2.72it/s]



{'accuracy': 0.5625,
 'gradient_norm': 23.088581085205078,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.1688101291656494,
 'param_norm': 1133.126220703125,
 'train/epoch': 14.0,
 'train/samples_seen': 896.0,
 'train/step': 14.0,
 'train/step_time': 0.06257867813110352}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 0.71875,
 'gradient_norm': 15.588017463684082,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.7551507949829102,
 'param_norm': 1133.1265869140625,
 'train/epoch': 15.0,
 'train/samples_seen': 960.0,
 'train/step': 15.0,
 'train/step_time': 0.05966377258300781}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.78125,
 'gradient_norm': 13.083539962768555,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.6016021370887756,
 'param_norm': 1133.126953125,
 'train/epoch': 16.0,
 'train/samples_seen': 1024.0,
 'train/step': 16.0,
 'train/step_time': 0.059487342834472656}



100%|██████████| 1/1 [00:00<00:00,  2.76it/s]



{'accuracy': 0.84375,
 'gradient_norm': 10.785422325134277,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.5415163040161133,
 'param_norm': 1133.1273193359375,
 'train/epoch': 17.0,
 'train/samples_seen': 1088.0,
 'train/step': 17.0,
 'train/step_time': 0.058985233306884766}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.8125,
 'gradient_norm': 12.113346099853516,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.5527895092964172,
 'param_norm': 1133.127685546875,
 'train/epoch': 18.0,
 'train/samples_seen': 1152.0,
 'train/step': 18.0,
 'train/step_time': 0.06297922134399414}



100%|██████████| 1/1 [00:00<00:00,  2.72it/s]



{'accuracy': 0.78125,
 'gradient_norm': 11.437657356262207,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.527722954750061,
 'param_norm': 1133.128173828125,
 'train/epoch': 19.0,
 'train/samples_seen': 1216.0,
 'train/step': 19.0,
 'train/step_time': 0.06106066703796387}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 0.875,
 'gradient_norm': 10.2225980758667,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.49152952432632446,
 'param_norm': 1133.12841796875,
 'train/epoch': 20.0,
 'train/samples_seen': 1280.0,
 'train/step': 20.0,
 'train/step_time': 0.06235146522521973}



100%|██████████| 1/1 [00:00<00:00,  2.71it/s]



{'accuracy': 0.875,
 'gradient_norm': 10.65017318725586,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.4671488404273987,
 'param_norm': 1133.1287841796875,
 'train/epoch': 21.0,
 'train/samples_seen': 1344.0,
 'train/step': 21.0,
 'train/step_time': 0.06514096260070801}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.90625,
 'gradient_norm': 9.865423202514648,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.3847954571247101,
 'param_norm': 1133.12890625,
 'train/epoch': 22.0,
 'train/samples_seen': 1408.0,
 'train/step': 22.0,
 'train/step_time': 0.06077742576599121}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 0.9375,
 'gradient_norm': 7.234127998352051,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.24658146500587463,
 'param_norm': 1133.129150390625,
 'train/epoch': 23.0,
 'train/samples_seen': 1472.0,
 'train/step': 23.0,
 'train/step_time': 0.05978226661682129}



100%|██████████| 1/1 [00:00<00:00,  2.72it/s]



{'accuracy': 0.96875,
 'gradient_norm': 4.349519729614258,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.13047142326831818,
 'param_norm': 1133.1295166015625,
 'train/epoch': 24.0,
 'train/samples_seen': 1536.0,
 'train/step': 24.0,
 'train/step_time': 0.06059741973876953}



100%|██████████| 1/1 [00:00<00:00,  2.72it/s]



{'accuracy': 1.0,
 'gradient_norm': 3.0994436740875244,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.08905977010726929,
 'param_norm': 1133.1300048828125,
 'train/epoch': 25.0,
 'train/samples_seen': 1600.0,
 'train/step': 25.0,
 'train/step_time': 0.06288599967956543}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.96875,
 'gradient_norm': 4.723756790161133,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.10385183990001678,
 'param_norm': 1133.13037109375,
 'train/epoch': 26.0,
 'train/samples_seen': 1664.0,
 'train/step': 26.0,
 'train/step_time': 0.060599565505981445}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 0.9375,
 'gradient_norm': 5.540087699890137,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.1118776947259903,
 'param_norm': 1133.130859375,
 'train/epoch': 27.0,
 'train/samples_seen': 1728.0,
 'train/step': 27.0,
 'train/step_time': 0.05979442596435547}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 4.327368259429932,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.08241647481918335,
 'param_norm': 1133.131103515625,
 'train/epoch': 28.0,
 'train/samples_seen': 1792.0,
 'train/step': 28.0,
 'train/step_time': 0.059505462646484375}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 2.7128753662109375,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.05076562985777855,
 'param_norm': 1133.1314697265625,
 'train/epoch': 29.0,
 'train/samples_seen': 1856.0,
 'train/step': 29.0,
 'train/step_time': 0.05990886688232422}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 1.9354236125946045,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.03394634276628494,
 'param_norm': 1133.1314697265625,
 'train/epoch': 30.0,
 'train/samples_seen': 1920.0,
 'train/step': 30.0,
 'train/step_time': 0.06069779396057129}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 1.3699100017547607,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.025297995656728745,
 'param_norm': 1133.1318359375,
 'train/epoch': 31.0,
 'train/samples_seen': 1984.0,
 'train/step': 31.0,
 'train/step_time': 0.05988955497741699}



100%|██████████| 1/1 [00:00<00:00,  2.70it/s]



{'accuracy': 1.0,
 'gradient_norm': 1.637054204940796,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.024930082261562347,
 'param_norm': 1133.1322021484375,
 'train/epoch': 32.0,
 'train/samples_seen': 2048.0,
 'train/step': 32.0,
 'train/step_time': 0.06579899787902832}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 2.222193479537964,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0274732057005167,
 'param_norm': 1133.13232421875,
 'train/epoch': 33.0,
 'train/samples_seen': 2112.0,
 'train/step': 33.0,
 'train/step_time': 0.06012082099914551}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 1.9471558332443237,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.022242136299610138,
 'param_norm': 1133.132568359375,
 'train/epoch': 34.0,
 'train/samples_seen': 2176.0,
 'train/step': 34.0,
 'train/step_time': 0.05978083610534668}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 1.0646789073944092,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.012387760914862156,
 'param_norm': 1133.1328125,
 'train/epoch': 35.0,
 'train/samples_seen': 2240.0,
 'train/step': 35.0,
 'train/step_time': 0.06012415885925293}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.5029011368751526,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.006342061795294285,
 'param_norm': 1133.1329345703125,
 'train/epoch': 36.0,
 'train/samples_seen': 2304.0,
 'train/step': 36.0,
 'train/step_time': 0.05972456932067871}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.29084312915802,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0036578564904630184,
 'param_norm': 1133.133056640625,
 'train/epoch': 37.0,
 'train/samples_seen': 2368.0,
 'train/step': 37.0,
 'train/step_time': 0.059670448303222656}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.18828818202018738,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.002391529968008399,
 'param_norm': 1133.1331787109375,
 'train/epoch': 38.0,
 'train/samples_seen': 2432.0,
 'train/step': 38.0,
 'train/step_time': 0.059590816497802734}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.1224786564707756,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0017474709311500192,
 'param_norm': 1133.13330078125,
 'train/epoch': 39.0,
 'train/samples_seen': 2496.0,
 'train/step': 39.0,
 'train/step_time': 0.060390472412109375}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.09764599055051804,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0015165942022576928,
 'param_norm': 1133.133544921875,
 'train/epoch': 40.0,
 'train/samples_seen': 2560.0,
 'train/step': 40.0,
 'train/step_time': 0.060231924057006836}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.14557158946990967,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.001657618209719658,
 'param_norm': 1133.1337890625,
 'train/epoch': 41.0,
 'train/samples_seen': 2624.0,
 'train/step': 41.0,
 'train/step_time': 0.06022143363952637}



100%|██████████| 1/1 [00:00<00:00,  2.76it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.2710539996623993,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.002149676438421011,
 'param_norm': 1133.1337890625,
 'train/epoch': 42.0,
 'train/samples_seen': 2688.0,
 'train/step': 42.0,
 'train/step_time': 0.059216976165771484}



100%|██████████| 1/1 [00:00<00:00,  2.71it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.378701776266098,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0025782573502510786,
 'param_norm': 1133.1339111328125,
 'train/epoch': 43.0,
 'train/samples_seen': 2752.0,
 'train/step': 43.0,
 'train/step_time': 0.06401205062866211}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.3257705867290497,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.002308749593794346,
 'param_norm': 1133.134033203125,
 'train/epoch': 44.0,
 'train/samples_seen': 2816.0,
 'train/step': 44.0,
 'train/step_time': 0.05983781814575195}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.1983230859041214,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.001710308250039816,
 'param_norm': 1133.134033203125,
 'train/epoch': 45.0,
 'train/samples_seen': 2880.0,
 'train/step': 45.0,
 'train/step_time': 0.05945396423339844}



100%|██████████| 1/1 [00:00<00:00,  2.70it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.12716144323349,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0013043094659224153,
 'param_norm': 1133.1341552734375,
 'train/epoch': 46.0,
 'train/samples_seen': 2944.0,
 'train/step': 46.0,
 'train/step_time': 0.06725811958312988}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.10506156086921692,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0010935610625892878,
 'param_norm': 1133.1341552734375,
 'train/epoch': 47.0,
 'train/samples_seen': 3008.0,
 'train/step': 47.0,
 'train/step_time': 0.05950593948364258}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.09525037556886673,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0009611132554709911,
 'param_norm': 1133.1341552734375,
 'train/epoch': 48.0,
 'train/samples_seen': 3072.0,
 'train/step': 48.0,
 'train/step_time': 0.059891700744628906}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.08593416213989258,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.000849151867441833,
 'param_norm': 1133.13427734375,
 'train/epoch': 49.0,
 'train/samples_seen': 3136.0,
 'train/step': 49.0,
 'train/step_time': 0.05984020233154297}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.07556669414043427,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0007413988350890577,
 'param_norm': 1133.1343994140625,
 'train/epoch': 50.0,
 'train/samples_seen': 3200.0,
 'train/step': 50.0,
 'train/step_time': 0.06090545654296875}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.06457009166479111,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0006357741658575833,
 'param_norm': 1133.1343994140625,
 'train/epoch': 51.0,
 'train/samples_seen': 3264.0,
 'train/step': 51.0,
 'train/step_time': 0.059639692306518555}



100%|██████████| 1/1 [00:00<00:00,  1.31it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.05344022810459137,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0005336754256859422,
 'param_norm': 1133.1343994140625,
 'train/epoch': 52.0,
 'train/samples_seen': 3328.0,
 'train/step': 52.0,
 'train/step_time': 0.4561457633972168}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.04364718496799469,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0004456841852515936,
 'param_norm': 1133.1346435546875,
 'train/epoch': 53.0,
 'train/samples_seen': 3392.0,
 'train/step': 53.0,
 'train/step_time': 0.0603642463684082}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.03524154797196388,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0003709126322064549,
 'param_norm': 1133.1346435546875,
 'train/epoch': 54.0,
 'train/samples_seen': 3456.0,
 'train/step': 54.0,
 'train/step_time': 0.06005048751831055}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.028365816920995712,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00030927587067708373,
 'param_norm': 1133.134765625,
 'train/epoch': 55.0,
 'train/samples_seen': 3520.0,
 'train/step': 55.0,
 'train/step_time': 0.05916333198547363}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.022878503426909447,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00025928113609552383,
 'param_norm': 1133.1348876953125,
 'train/epoch': 56.0,
 'train/samples_seen': 3584.0,
 'train/step': 56.0,
 'train/step_time': 0.05996084213256836}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.01862936280667782,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0002189513761550188,
 'param_norm': 1133.1348876953125,
 'train/epoch': 57.0,
 'train/samples_seen': 3648.0,
 'train/step': 57.0,
 'train/step_time': 0.05982208251953125}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.01544790156185627,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.0001878385664895177,
 'param_norm': 1133.135009765625,
 'train/epoch': 58.0,
 'train/samples_seen': 3712.0,
 'train/step': 58.0,
 'train/step_time': 0.06038379669189453}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.012985332868993282,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00016289523045998067,
 'param_norm': 1133.135009765625,
 'train/epoch': 59.0,
 'train/samples_seen': 3776.0,
 'train/step': 59.0,
 'train/step_time': 0.06002521514892578}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.011115746572613716,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00014274826389737427,
 'param_norm': 1133.135009765625,
 'train/epoch': 60.0,
 'train/samples_seen': 3840.0,
 'train/step': 60.0,
 'train/step_time': 0.05884408950805664}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.009674017317593098,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00012645297101698816,
 'param_norm': 1133.1351318359375,
 'train/epoch': 61.0,
 'train/samples_seen': 3904.0,
 'train/step': 61.0,
 'train/step_time': 0.059537410736083984}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.008583810180425644,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00011365678801666945,
 'param_norm': 1133.1351318359375,
 'train/epoch': 62.0,
 'train/samples_seen': 3968.0,
 'train/step': 62.0,
 'train/step_time': 0.0625772476196289}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.007720008958131075,
 'learning_rate': 9.999999747378752e-06,
 'loss': 0.00010303840099368244,
 'param_norm': 1133.1351318359375,
 'train/epoch': 63.0,
 'train/samples_seen': 4032.0,
 'train/step': 63.0,
 'train/step_time': 0.05997467041015625}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.007038043811917305,
 'learning_rate': 9.999999747378752e-06,
 'loss': 9.420939022675157e-05,
 'param_norm': 1133.1351318359375,
 'train/epoch': 64.0,
 'train/samples_seen': 4096.0,
 'train/step': 64.0,
 'train/step_time': 0.0601346492767334}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0064714327454566956,
 'learning_rate': 9.999999747378752e-06,
 'loss': 8.674423588672653e-05,
 'param_norm': 1133.1351318359375,
 'train/epoch': 65.0,
 'train/samples_seen': 4160.0,
 'train/step': 65.0,
 'train/step_time': 0.05967354774475098}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.006025720853358507,
 'learning_rate': 9.999999747378752e-06,
 'loss': 8.049885218497366e-05,
 'param_norm': 1133.13525390625,
 'train/epoch': 66.0,
 'train/samples_seen': 4224.0,
 'train/step': 66.0,
 'train/step_time': 0.059181928634643555}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0056823003105819225,
 'learning_rate': 9.999999747378752e-06,
 'loss': 7.543769606854767e-05,
 'param_norm': 1133.13525390625,
 'train/epoch': 67.0,
 'train/samples_seen': 4288.0,
 'train/step': 67.0,
 'train/step_time': 0.0628042221069336}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0053789690136909485,
 'learning_rate': 9.999999747378752e-06,
 'loss': 7.080831710482016e-05,
 'param_norm': 1133.13525390625,
 'train/epoch': 68.0,
 'train/samples_seen': 4352.0,
 'train/step': 68.0,
 'train/step_time': 0.06052136421203613}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.00512652937322855,
 'learning_rate': 9.999999747378752e-06,
 'loss': 6.687430141028017e-05,
 'param_norm': 1133.1353759765625,
 'train/epoch': 69.0,
 'train/samples_seen': 4416.0,
 'train/step': 69.0,
 'train/step_time': 0.059479475021362305}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.004900877829641104,
 'learning_rate': 9.999999747378752e-06,
 'loss': 6.335027137538418e-05,
 'param_norm': 1133.13525390625,
 'train/epoch': 70.0,
 'train/samples_seen': 4480.0,
 'train/step': 70.0,
 'train/step_time': 0.06028008460998535}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.004716932307928801,
 'learning_rate': 9.999999747378752e-06,
 'loss': 6.035594560671598e-05,
 'param_norm': 1133.1353759765625,
 'train/epoch': 71.0,
 'train/samples_seen': 4544.0,
 'train/step': 71.0,
 'train/step_time': 0.06037497520446777}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0045410203747451305,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.738808249589056e-05,
 'param_norm': 1133.1353759765625,
 'train/epoch': 72.0,
 'train/samples_seen': 4608.0,
 'train/step': 72.0,
 'train/step_time': 0.0638577938079834}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.004383785184472799,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.482647247845307e-05,
 'param_norm': 1133.1353759765625,
 'train/epoch': 73.0,
 'train/samples_seen': 4672.0,
 'train/step': 73.0,
 'train/step_time': 0.05940413475036621}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.004233717918395996,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.2505696658045053e-05,
 'param_norm': 1133.135498046875,
 'train/epoch': 74.0,
 'train/samples_seen': 4736.0,
 'train/step': 74.0,
 'train/step_time': 0.05973935127258301}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.004082378465682268,
 'learning_rate': 9.999999747378752e-06,
 'loss': 5.025630525778979e-05,
 'param_norm': 1133.135498046875,
 'train/epoch': 75.0,
 'train/samples_seen': 4800.0,
 'train/step': 75.0,
 'train/step_time': 0.06052660942077637}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0039432840421795845,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.814419662579894e-05,
 'param_norm': 1133.135498046875,
 'train/epoch': 76.0,
 'train/samples_seen': 4864.0,
 'train/step': 76.0,
 'train/step_time': 0.060437679290771484}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0037891112733632326,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.608061135513708e-05,
 'param_norm': 1133.135498046875,
 'train/epoch': 77.0,
 'train/samples_seen': 4928.0,
 'train/step': 77.0,
 'train/step_time': 0.05956840515136719}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0036544310860335827,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.4225831516087055e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 78.0,
 'train/samples_seen': 4992.0,
 'train/step': 78.0,
 'train/step_time': 0.05919051170349121}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0035043926909565926,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.236735912854783e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 79.0,
 'train/samples_seen': 5056.0,
 'train/step': 79.0,
 'train/step_time': 0.05946540832519531}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0033588882070034742,
 'learning_rate': 9.999999747378752e-06,
 'loss': 4.0620587242301553e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 80.0,
 'train/samples_seen': 5120.0,
 'train/step': 80.0,
 'train/step_time': 0.05932903289794922}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0032094253692775965,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.8981841498753056e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 81.0,
 'train/samples_seen': 5184.0,
 'train/step': 81.0,
 'train/step_time': 0.05986189842224121}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0030562819447368383,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.727244620677084e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 82.0,
 'train/samples_seen': 5248.0,
 'train/step': 82.0,
 'train/step_time': 0.05976724624633789}



100%|██████████| 1/1 [00:00<00:00,  2.76it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0029111052863299847,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.573080903152004e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 83.0,
 'train/samples_seen': 5312.0,
 'train/step': 83.0,
 'train/step_time': 0.06433629989624023}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0027625488582998514,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.4199874789919704e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 84.0,
 'train/samples_seen': 5376.0,
 'train/step': 84.0,
 'train/step_time': 0.05947375297546387}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.00262651639059186,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.2847900001797825e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 85.0,
 'train/samples_seen': 5440.0,
 'train/step': 85.0,
 'train/step_time': 0.0601649284362793}



100%|██████████| 1/1 [00:00<00:00,  2.73it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0024839784018695354,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.1439885788131505e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 86.0,
 'train/samples_seen': 5504.0,
 'train/step': 86.0,
 'train/step_time': 0.05990767478942871}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0023566626477986574,
 'learning_rate': 9.999999747378752e-06,
 'loss': 3.0207067538867705e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 87.0,
 'train/samples_seen': 5568.0,
 'train/step': 87.0,
 'train/step_time': 0.06044626235961914}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0022278837859630585,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.8977470719837584e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 88.0,
 'train/samples_seen': 5632.0,
 'train/step': 88.0,
 'train/step_time': 0.060059547424316406}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0021134635899215937,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.788228084682487e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 89.0,
 'train/samples_seen': 5696.0,
 'train/step': 89.0,
 'train/step_time': 0.06088376045227051}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0020027742721140385,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.6839414204005152e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 90.0,
 'train/samples_seen': 5760.0,
 'train/step': 90.0,
 'train/step_time': 0.05907177925109863}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0018909894861280918,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.579285683168564e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 91.0,
 'train/samples_seen': 5824.0,
 'train/step': 91.0,
 'train/step_time': 0.05918383598327637}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.001796287135221064,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.4899025447666645e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 92.0,
 'train/samples_seen': 5888.0,
 'train/step': 92.0,
 'train/step_time': 0.06003999710083008}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0016975023318082094,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.3968095774762332e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 93.0,
 'train/samples_seen': 5952.0,
 'train/step': 93.0,
 'train/step_time': 0.060092926025390625}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0016115116886794567,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.3137570678954944e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 94.0,
 'train/samples_seen': 6016.0,
 'train/step': 94.0,
 'train/step_time': 0.06524825096130371}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.001528282999061048,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.237014086858835e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 95.0,
 'train/samples_seen': 6080.0,
 'train/step': 95.0,
 'train/step_time': 0.06039714813232422}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0014478355878964067,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.1628924514516257e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 96.0,
 'train/samples_seen': 6144.0,
 'train/step': 96.0,
 'train/step_time': 0.059899330139160156}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0013809220399707556,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.096950265695341e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 97.0,
 'train/samples_seen': 6208.0,
 'train/step': 97.0,
 'train/step_time': 0.06013965606689453}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0013074430171400309,
 'learning_rate': 9.999999747378752e-06,
 'loss': 2.0287658117013052e-05,
 'param_norm': 1133.1356201171875,
 'train/epoch': 98.0,
 'train/samples_seen': 6272.0,
 'train/step': 98.0,
 'train/step_time': 0.060712337493896484}



100%|██████████| 1/1 [00:00<00:00,  2.74it/s]



{'accuracy': 1.0,
 'gradient_norm': 0.0012514239642769098,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.9717765098903328e-05,
 'param_norm': 1133.1357421875,
 'train/epoch': 99.0,
 'train/samples_seen': 6336.0,
 'train/step': 99.0,
 'train/step_time': 0.06013369560241699}



100%|██████████| 1/1 [00:00<00:00,  2.75it/s]


{'accuracy': 1.0,
 'gradient_norm': 0.0011923854472115636,
 'learning_rate': 9.999999747378752e-06,
 'loss': 1.9166473066434264e-05,
 'param_norm': 1133.1357421875,
 'train/epoch': 100.0,
 'train/samples_seen': 6400.0,
 'train/step': 100.0,
 'train/step_time': 0.06075477600097656}






In [12]:
sys.argv = """
    --mesh_dim='-1,64,1' \
    --dtype='fp32' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --tokenizer.vocab_file='./llama2-tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='/path/to/shuffled/redpajama/dataset' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=2048 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='EasyLM' \
    --logger.project="open_llama_7b" \
    --logger.output_dir="/path/to/checkpoint/dir" \
    --logger.wandb_dir="$HOME/experiment_output/open_llama_7b" \
""".split("=")