Model: small-llama2. https://huggingface.co/TinyPixel/small-llama2



# This is the file that the experiments were conducted on.

# Preparing Data
## Loading data

In [1]:
from typing import Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch._tensor import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.optimizer import Optimizer as Optimizer
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from collections.abc import Mapping
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
import numpy as np

In [2]:
TRAINING = True
training_set_batch = 1_000_000 # 625 training steps. In total 59_203_382 rows
test_set_batch = 200_000 # In total 14_800_846 rows
bi_gram_scheduling_steps = 0 # 240 # 100, 1 training step is 32 batches
bi_gram_weight = 1
min_bi_gram_weight = 0.0
model_layer_number = 1 # max 12
use_bi_gram = False if bi_gram_scheduling_steps == 0 else True
teacher_scheduling_method = 'linear' # none, linear, exponential, reciprocal
fold_number = 4 # from 0 to 4
K = 5 # 5-fold

In [3]:
from datasets import load_dataset, DatasetDict
datasets = load_dataset("bookcorpus").shuffle(seed=42)
datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 74004228
    })
})

In [4]:
'''
In total, the dataset has 74004228 rows.
74004228 / 5 = 14800845.6
Choose 250_000 samples from each fold to construct training dataset.
'''
from datasets import Dataset
new_train_dataset = None
new_val_dataset = None
for i in range(K):
    if i != fold_number:
        if new_train_dataset != None:
            new_train_dataset.extend(datasets["train"].select(range(i * 14800845, i * 14800845 + 250_000))['text'])
        else:
            new_train_dataset = datasets["train"].select(range(i * 14800845, i * 14800845 + 250_000))['text']
    else:
        new_val_dataset = datasets["train"].select(range(i * 14800845, i * 14800845 + test_set_batch))
new_train_dataset = Dataset.from_dict({'text': new_train_dataset})
datasets = DatasetDict({'train': new_train_dataset, 'validation': new_val_dataset})
datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 200000
    })
})

In [5]:
# Overwrite the forward() function of LlamaForCausalLM, enabling handling soft cross-entropy.
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
class LlamaForCausalLMWithBigram(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        from transformers import AutoTokenizer, LlamaForCausalLM

        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # print("Inside the model:", attention_mask) # It stays the same

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            loss_fct = CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            '''===========modification start point==============='''
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = None
            
            if labels.shape[-1] == self.vocab_size:# With bigram: shape (4, 1024, 32000) # Without bigram: torch.Size([4, 1024])
                # labels = labels.to_dense()
                shift_labels = labels[..., 1:, :].contiguous() # Ignore the first token, it is not a label
                # Flatten the tokens
                shift_labels = shift_labels.view(-1, self.config.vocab_size)
                # print(shift_labels.shape, flush=True)
            else:
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                shift_labels = shift_labels.view(-1)
                # print(shift_labels.shape, flush=True)
            '''===========modification end point==============='''
                
            # Enable model parallelism
            # print('Inside the function:', shift_logits.shape, shift_labels.shape)
            # print('Inside the function2:', shift_labels.sum())
            # print('Inside the function2:', shift_labels)
            # '''
            # Inside the function: torch.Size([4092, 32000]) torch.Size([4092, 32000])
            # Inside the function2: tensor(4092., device='cuda:0', dtype=torch.float64)
            # Inside the function2: tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.float64)
            # '''
            
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
from transformers import AutoConfig

tokenizer = AutoTokenizer.from_pretrained("TinyPixel/small-llama2")
tokenizer.pad_token_id=tokenizer.eos_token_id
config = AutoConfig.from_pretrained("TinyPixel/small-llama2")
config.num_hidden_layers = model_layer_number # originally, 12
model = LlamaForCausalLMWithBigram(config)
model

LlamaForCausalLMWithBigram(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (up_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (down_proj): Linear(in_features=1376, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): 

In [7]:
# tokenization 
def tokenize(element):
    long_text = "".join(element['text']) # concatenation
    outputs = tokenizer(
        [long_text],
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
        max_length=config.max_position_embeddings,
    )
    return {"input_ids": outputs['input_ids']}

tokenized_datasets = datasets.map(
    tokenize, batched=True, remove_columns=datasets["train"].column_names, 
    batch_size=200# , num_proc=10
)
tokenized_datasets


Map:   0%|          | 0/1000000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 18770
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 3766
    })
})

In [8]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [9]:
# load bi-gram probabilities
import numpy
bi_gram_probabilities = numpy.load('./bigram_probability.npy').astype(numpy.float16)
bi_gram_probabilities.shape

(32000, 32000)

# Training

In [10]:
# Overwrite _prepare_inputs() of Trainer, enabling teacher-student learning paradigm 
class MyTrainer(Trainer):
    def __init__(self,
                model: Union[PreTrainedModel, Module] = None,
                args: TrainingArguments = None,
                data_collator: Optional[DataCollator] = None,
                train_dataset: Optional[Dataset] = None,
                eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
                tokenizer: Optional[PreTrainedTokenizerBase] = None,
                model_init: Optional[Callable[[], PreTrainedModel]] = None,
                compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
                callbacks: Optional[List[TrainerCallback]] = None,
                optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
                preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
                bigram_probabilities: Optional[torch.Tensor] = None, # 
                bi_gram_scheduling_steps: Optional[int] = 0, # 
                min_bi_gram_weight: Optional[float] = 0.0, # 
                teacher_scheduling_method = '',
                ):
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
        self.bigram_probabilities = bigram_probabilities
        self.bi_gram_scheduling_steps = bi_gram_scheduling_steps
        self.min_bi_gram_weight = min_bi_gram_weight
        self.vocabulary_size = model.vocab_size
        self.teacher_scheduling_method = teacher_scheduling_method
    
    # Not _prepare_input(), recursive calling
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        self.control.should_evaluate will be set as True before evaluation, and will be set as False after evaluation.
        """
        '''===========modification start point==============='''
        # # Enable min_bi_gram_weight for the training steps after bi-gram scheduling
        # if self.bi_gram_scheduling_steps and (not self.control.should_evaluate):
        # Disenable min_bi_gram_weight for the training steps after bi-gram scheduling
        if self.state.global_step < self.bi_gram_scheduling_steps and (not self.control.should_evaluate): # About 10 times slow than without bigram
            labels = inputs['labels'].cpu() # return a copy
            # print('labels ', labels.shape) # labels  torch.Size([4, 1024])
            # Here is tricky. We should remove then last one of label. 
            ids = labels.view(-1)
            bigram_labels = self.bigram_probabilities[ids[:-1], :]
            bigram_weight = 0
            if self.teacher_scheduling_method == 'linear': 
                # ====== linear scheduling ====== 
                bigram_weight = max(self.min_bi_gram_weight, 
                                    1 - min(self.state.global_step, self.bi_gram_scheduling_steps) / self.bi_gram_scheduling_steps)
            elif self.teacher_scheduling_method == 'exponential':
                # ====== exponential scheduling ====== 
                bigram_weight = np.exp(-self.state.global_step)
                # if self.state.global_step % 50 == 49:
                #     print(f'The global step is {self.state.global_step}. The bigram weight is {bigram_weight}')
                # print('Before ', np.sum(bigram_labels, axis=1)[-10:])
                # print('Ids ', ids[-10:])
            elif self.teacher_scheduling_method == 'reciprocal':
                bigram_weight = 1/(self.state.global_step)
            # print('Inside: bigram_weight=', bigram_weight)
            # combine bi-gram probabilities and one-hot labels.
            bigram_labels *= bigram_weight # Shape  (4096, 32000)
            # one_hot_matrix = np.zeros((bigram_labels.shape[0], self.vocabulary_size), dtype=np.float16)
            # one_hot_matrix[np.arange(bigram_labels.shape[0]), ids] = 1 - bigram_weight
            # bigram_labels[:-1, :] += one_hot_matrix[1:, :]
            bigram_labels[np.arange(bigram_labels.shape[0]), ids[1:]] += (1 - bigram_weight)
            bigram_labels = np.concatenate((np.zeros((1, self.vocabulary_size)), bigram_labels), 0)
            # print('After ', np.sum(bigram_labels, axis=1)[-10:], bigram_labels.shape)
            bigram_labels = bigram_labels.reshape((inputs['labels'].shape[0], 
                                                inputs['labels'].shape[1], 
                                                self.vocabulary_size)) # Shape  (4, 1024, 32000)
            # RuntimeError: Expect the same number of specified elements per batch.
            # inputs['labels'] = torch.tensor(bigram_labels).to_sparse_csr().to(self.args.device)
            inputs['labels'] = torch.tensor(bigram_labels).to(self.args.device)
        '''===========modification end point==============='''
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

In [11]:
from torch.nn import CrossEntropyLoss
from torch import tensor, exp
from datetime import datetime

def compute_metrics(eval_pred):
    print('Inside compute_metrics', eval_pred.predictions.shape, eval_pred.label_ids.shape)
    # Inside compute_metrics (11, 1024, 32000) (11, 1024)  numpy.ndarray
    loss_fct = CrossEntropyLoss()
    prediction = tensor(eval_pred.predictions).view(-1, 32000)
    labels = tensor(eval_pred.label_ids).view(-1)
    masked_lm_loss = exp(loss_fct(prediction, labels)) 
    return {'ppl': masked_lm_loss}
from transformers import Trainer, TrainingArguments
import os
# os.environ['WANDB_DISABLED'] = 'true' # turning off reporting to WanDB. It requires API key
output_dir="llama2-bigram-guided-k-fold"
now = datetime.now()
dt_string = now.strftime("%d_%m_%H-%M-%S") # 27_12_10-09-20
log_name = output_dir + '/runs/' + dt_string + (f'_{bi_gram_scheduling_steps}_{min_bi_gram_weight}_{model_layer_number}' 
                                                if use_bi_gram else f"_no_bigram_{model_layer_number}") 
log_name += "_" + teacher_scheduling_method[0] + '_' + str(fold_number)

args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4, 
    evaluation_strategy="steps",
    eval_steps=1, 
    logging_steps=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1, 
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=3, 
    fp16=True,
    push_to_hub=False, 
    report_to='tensorboard',
    logging_dir=log_name
)
if TRAINING:
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4, 
        evaluation_strategy="steps",
        eval_steps=3_0, 
        logging_steps=20, 
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        weight_decay=0.1,
        warmup_steps=1_00, 
        lr_scheduler_type="cosine",
        learning_rate=5e-4,
        save_steps=3_000, 
        fp16=True,
        push_to_hub=False,
        report_to='tensorboard',
        logging_dir=log_name
    )

trainer = MyTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # compute_metrics=compute_metrics
    bigram_probabilities = bi_gram_probabilities, # The bi-gram probability matrix
    bi_gram_scheduling_steps = bi_gram_scheduling_steps, 
    min_bi_gram_weight = min_bi_gram_weight, # The minimum bi-gram weight
    teacher_scheduling_method = teacher_scheduling_method
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [12]:
trainer.train()

  0%|          | 0/586 [00:00<?, ?it/s]

{'loss': 9.6094, 'grad_norm': 1.9818520545959473, 'learning_rate': 0.0001, 'epoch': 0.03}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 6.846191883087158, 'eval_runtime': 34.2766, 'eval_samples_per_second': 109.871, 'eval_steps_per_second': 27.482, 'epoch': 0.05}
{'loss': 7.0318, 'grad_norm': 0.5009767413139343, 'learning_rate': 0.0002, 'epoch': 0.07}
{'loss': 6.013, 'grad_norm': 0.530003547668457, 'learning_rate': 0.0003, 'epoch': 0.1}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 5.714672088623047, 'eval_runtime': 33.8619, 'eval_samples_per_second': 111.217, 'eval_steps_per_second': 27.819, 'epoch': 0.1}
{'loss': 5.535, 'grad_norm': 0.5313185453414917, 'learning_rate': 0.0004, 'epoch': 0.14}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 5.282958507537842, 'eval_runtime': 34.0115, 'eval_samples_per_second': 110.727, 'eval_steps_per_second': 27.697, 'epoch': 0.15}
{'loss': 5.2913, 'grad_norm': 0.5291732549667358, 'learning_rate': 0.0005, 'epoch': 0.17}
{'loss': 5.1479, 'grad_norm': 0.5427552461624146, 'learning_rate': 0.0004979136257327503, 'epoch': 0.2}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 5.080759525299072, 'eval_runtime': 34.4618, 'eval_samples_per_second': 109.28, 'eval_steps_per_second': 27.335, 'epoch': 0.2}
{'loss': 5.0386, 'grad_norm': 0.4574202001094818, 'learning_rate': 0.0004916893265916655, 'epoch': 0.24}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.963342666625977, 'eval_runtime': 34.6616, 'eval_samples_per_second': 108.65, 'eval_steps_per_second': 27.177, 'epoch': 0.26}
{'loss': 4.9674, 'grad_norm': 0.45416349172592163, 'learning_rate': 0.00048143099231722267, 'epoch': 0.27}
{'loss': 4.891, 'grad_norm': 0.4451713263988495, 'learning_rate': 0.00046730984470666194, 'epoch': 0.31}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.874007225036621, 'eval_runtime': 34.741, 'eval_samples_per_second': 108.402, 'eval_steps_per_second': 27.115, 'epoch': 0.31}
{'loss': 4.8534, 'grad_norm': 0.4192894995212555, 'learning_rate': 0.0004495615797519732, 'epoch': 0.34}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.800921440124512, 'eval_runtime': 35.0272, 'eval_samples_per_second': 107.516, 'eval_steps_per_second': 26.893, 'epoch': 0.36}
{'loss': 4.8106, 'grad_norm': 0.4221283793449402, 'learning_rate': 0.0004284824336394748, 'epoch': 0.38}
{'loss': 4.7661, 'grad_norm': 0.4034672677516937, 'learning_rate': 0.00040442423827336427, 'epoch': 0.41}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.740621566772461, 'eval_runtime': 34.7013, 'eval_samples_per_second': 108.526, 'eval_steps_per_second': 27.146, 'epoch': 0.41}
{'loss': 4.7326, 'grad_norm': 0.40580856800079346, 'learning_rate': 0.0003777885488514683, 'epoch': 0.44}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.6825456619262695, 'eval_runtime': 34.2534, 'eval_samples_per_second': 109.945, 'eval_steps_per_second': 27.501, 'epoch': 0.46}
{'loss': 4.6818, 'grad_norm': 0.4124316871166229, 'learning_rate': 0.00034901994150978924, 'epoch': 0.48}
{'loss': 4.6374, 'grad_norm': 0.4120369553565979, 'learning_rate': 0.0003185985929048254, 'epoch': 0.51}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.627112865447998, 'eval_runtime': 34.2272, 'eval_samples_per_second': 110.029, 'eval_steps_per_second': 27.522, 'epoch': 0.51}
{'loss': 4.6044, 'grad_norm': 0.386046439409256, 'learning_rate': 0.00028703226558781227, 'epoch': 0.55}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.582106113433838, 'eval_runtime': 34.2469, 'eval_samples_per_second': 109.966, 'eval_steps_per_second': 27.506, 'epoch': 0.56}
{'loss': 4.5874, 'grad_norm': 0.4245465099811554, 'learning_rate': 0.0002548478329429561, 'epoch': 0.58}
{'loss': 4.5581, 'grad_norm': 0.4106554388999939, 'learning_rate': 0.0002225824851468671, 'epoch': 0.61}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.541588306427002, 'eval_runtime': 34.4898, 'eval_samples_per_second': 109.192, 'eval_steps_per_second': 27.312, 'epoch': 0.61}
{'loss': 4.5367, 'grad_norm': 0.37997713685035706, 'learning_rate': 0.00019077476293047018, 'epoch': 0.65}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.508094787597656, 'eval_runtime': 34.7334, 'eval_samples_per_second': 108.426, 'eval_steps_per_second': 27.121, 'epoch': 0.66}
{'loss': 4.5034, 'grad_norm': 0.3847604990005493, 'learning_rate': 0.00015995556879882245, 'epoch': 0.68}
{'loss': 4.4908, 'grad_norm': 0.4286229610443115, 'learning_rate': 0.00013063930574051273, 'epoch': 0.72}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.479612827301025, 'eval_runtime': 35.307, 'eval_samples_per_second': 106.664, 'eval_steps_per_second': 26.68, 'epoch': 0.72}
{'loss': 4.4827, 'grad_norm': 0.4362623691558838, 'learning_rate': 0.00010331529133039555, 'epoch': 0.75}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.458536148071289, 'eval_runtime': 34.9189, 'eval_samples_per_second': 107.85, 'eval_steps_per_second': 26.977, 'epoch': 0.77}
{'loss': 4.4561, 'grad_norm': 0.3950788974761963, 'learning_rate': 7.843959053281663e-05, 'epoch': 0.78}
{'loss': 4.4576, 'grad_norm': 0.4128142297267914, 'learning_rate': 5.6427403523966886e-05, 'epoch': 0.82}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.442216396331787, 'eval_runtime': 34.9424, 'eval_samples_per_second': 107.777, 'eval_steps_per_second': 26.959, 'epoch': 0.82}
{'loss': 4.4397, 'grad_norm': 0.36803099513053894, 'learning_rate': 3.7646135588175675e-05, 'epoch': 0.85}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.430988311767578, 'eval_runtime': 35.1233, 'eval_samples_per_second': 107.222, 'eval_steps_per_second': 26.82, 'epoch': 0.87}
{'loss': 4.4339, 'grad_norm': 0.40709859132766724, 'learning_rate': 2.240926475846336e-05, 'epoch': 0.89}
{'loss': 4.4284, 'grad_norm': 0.37404724955558777, 'learning_rate': 1.0971109556530106e-05, 'epoch': 0.92}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.425398826599121, 'eval_runtime': 34.8804, 'eval_samples_per_second': 107.969, 'eval_steps_per_second': 27.007, 'epoch': 0.92}
{'loss': 4.4223, 'grad_norm': 0.3711010217666626, 'learning_rate': 3.5225841638008847e-06, 'epoch': 0.95}


  0%|          | 0/942 [00:00<?, ?it/s]

{'eval_loss': 4.423478603363037, 'eval_runtime': 34.4342, 'eval_samples_per_second': 109.368, 'eval_steps_per_second': 27.357, 'epoch': 0.97}
{'loss': 4.4198, 'grad_norm': 0.3579881191253662, 'learning_rate': 1.8801187394248964e-07, 'epoch': 0.99}
{'train_runtime': 1178.7105, 'train_samples_per_second': 15.924, 'train_steps_per_second': 0.497, 'train_loss': 4.988185898842665, 'epoch': 1.0}


TrainOutput(global_step=586, training_loss=4.988185898842665, metrics={'train_runtime': 1178.7105, 'train_samples_per_second': 15.924, 'train_steps_per_second': 0.497, 'train_loss': 4.988185898842665, 'epoch': 1.0})

In [13]:
# import os
# os.system("shutdown -t  60 ")