Model: small-llama2. https://huggingface.co/TinyPixel/small-llama2



# This is the file that the experiments were conducted on.

# Preparing Data
## Loading data

In [1]:
from typing import Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch._tensor import Tensor
from torch.nn.modules import Module
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.optimizer import Optimizer as Optimizer
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from collections.abc import Mapping
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
import numpy as np

In [2]:
TRAINING = True
training_set_batch = 1_000_000 # 625 training steps. In total 59_203_382 rows
test_set_batch = 200_000 # In total 14_800_846 rows
bi_gram_scheduling_steps = 240 # 100, 1 training step is 32 batches
bi_gram_weight = 1
min_bi_gram_weight = 0.0
model_layer_number = 1 # max 12
use_bi_gram = False if bi_gram_scheduling_steps == 0 else True
teacher_scheduling_method = 'linear' # none, linear, exponential, reciprocal

In [3]:
from datasets import load_dataset, DatasetDict
datasets = load_dataset("./bookcorpus-splitted")
# datasets.cleanup_cache_files() # clean GPU cache
datasets = DatasetDict({'train': datasets['train'].select(range(training_set_batch)), 
            'validation': datasets['validation'].select(range(test_set_batch))})
datasets

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 200000
    })
})

In [4]:
# Overwrite the forward() function of LlamaForCausalLM, enabling handling soft cross-entropy.
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM
class LlamaForCausalLMWithBigram(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        from transformers import AutoTokenizer, LlamaForCausalLM

        model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # print("Inside the model:", attention_mask) # It stays the same

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            loss_fct = CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            '''===========modification start point==============='''
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = None
            
            if labels.shape[-1] == self.vocab_size:# With bigram: shape (4, 1024, 32000) # Without bigram: torch.Size([4, 1024])
                # labels = labels.to_dense()
                shift_labels = labels[..., 1:, :].contiguous() # Ignore the first token, it is not a label
                # Flatten the tokens
                shift_labels = shift_labels.view(-1, self.config.vocab_size)
                # print(shift_labels.shape, flush=True)
            else:
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                shift_labels = shift_labels.view(-1)
                # print(shift_labels.shape, flush=True)
            '''===========modification end point==============='''
                
            # Enable model parallelism
            # print('Inside the function:', shift_logits.shape, shift_labels.shape)
            # print('Inside the function2:', shift_labels.sum())
            # print('Inside the function2:', shift_labels)
            # '''
            # Inside the function: torch.Size([4092, 32000]) torch.Size([4092, 32000])
            # Inside the function2: tensor(4092., device='cuda:0', dtype=torch.float64)
            # Inside the function2: tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0', dtype=torch.float64)
            # '''
            
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [5]:
from transformers import AutoConfig

tokenizer = AutoTokenizer.from_pretrained("TinyPixel/small-llama2")
tokenizer.pad_token_id=tokenizer.eos_token_id
config = AutoConfig.from_pretrained("TinyPixel/small-llama2")
config.num_hidden_layers = model_layer_number # originally, 12
model = LlamaForCausalLMWithBigram(config)
model

LlamaForCausalLMWithBigram(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-4): 5 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (up_proj): Linear(in_features=1024, out_features=1376, bias=False)
          (down_proj): Linear(in_features=1376, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_h

In [6]:
# tokenization 
def tokenize(element):
    long_text = "".join(element['text']) # concatenation
    outputs = tokenizer(
        [long_text],
        truncation=True,
        return_overflowing_tokens=True,
        return_length=True,
        max_length=config.max_position_embeddings,
    )
    return {"input_ids": outputs['input_ids']}

tokenized_datasets = datasets.map(
    tokenize, batched=True, remove_columns=datasets["train"].column_names, 
    batch_size=200# , num_proc=10
)
tokenized_datasets


Map:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Map:   0%|          | 0/200000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 18823
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 3750
    })
})

In [7]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [8]:
# load bi-gram probabilities
import numpy
bi_gram_probabilities = numpy.load('./bigram_probability.npy').astype(numpy.float16)
bi_gram_probabilities.shape

(32000, 32000)

# Training

In [9]:
# Overwrite _prepare_inputs() of Trainer, enabling teacher-student learning paradigm 
class MyTrainer(Trainer):
    def __init__(self,
                model: Union[PreTrainedModel, Module] = None,
                args: TrainingArguments = None,
                data_collator: Optional[DataCollator] = None,
                train_dataset: Optional[Dataset] = None,
                eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
                tokenizer: Optional[PreTrainedTokenizerBase] = None,
                model_init: Optional[Callable[[], PreTrainedModel]] = None,
                compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
                callbacks: Optional[List[TrainerCallback]] = None,
                optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
                preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
                bigram_probabilities: Optional[torch.Tensor] = None, # 
                bi_gram_scheduling_steps: Optional[int] = 0, # 
                min_bi_gram_weight: Optional[float] = 0.0, # 
                teacher_scheduling_method = '',
                ):
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)
        self.bigram_probabilities = bigram_probabilities
        self.bi_gram_scheduling_steps = bi_gram_scheduling_steps
        self.min_bi_gram_weight = min_bi_gram_weight
        self.vocabulary_size = model.vocab_size
        self.teacher_scheduling_method = teacher_scheduling_method
    
    # Not _prepare_input(), recursive calling
    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        self.control.should_evaluate will be set as True before evaluation, and will be set as False after evaluation.
        """
        '''===========modification start point==============='''
        # # Enable min_bi_gram_weight for the training steps after bi-gram scheduling
        # if self.bi_gram_scheduling_steps and (not self.control.should_evaluate):
        # Disenable min_bi_gram_weight for the training steps after bi-gram scheduling
        if self.state.global_step < self.bi_gram_scheduling_steps and (not self.control.should_evaluate): # About 10 times slow than without bigram
            labels = inputs['labels'].cpu() # return a copy
            # print('labels ', labels.shape) # labels  torch.Size([4, 1024])
            # Here is tricky. We should remove then last one of label. 
            ids = labels.view(-1)
            bigram_labels = self.bigram_probabilities[ids[:-1], :]
            bigram_weight = 0
            if self.teacher_scheduling_method == 'linear': 
                # ====== linear scheduling ====== 
                bigram_weight = max(self.min_bi_gram_weight, 
                                    1 - min(self.state.global_step, self.bi_gram_scheduling_steps) / self.bi_gram_scheduling_steps)
            elif self.teacher_scheduling_method == 'exponential':
                # ====== exponential scheduling ====== 
                bigram_weight = np.exp(-self.state.global_step)
                # if self.state.global_step % 50 == 49:
                #     print(f'The global step is {self.state.global_step}. The bigram weight is {bigram_weight}')
                # print('Before ', np.sum(bigram_labels, axis=1)[-10:])
                # print('Ids ', ids[-10:])
            elif self.teacher_scheduling_method == 'reciprocal':
                bigram_weight = 1/(self.state.global_step)
            # print('Inside: bigram_weight=', bigram_weight)
            # combine bi-gram probabilities and one-hot labels.
            bigram_labels *= bigram_weight # Shape  (4096, 32000)
            # one_hot_matrix = np.zeros((bigram_labels.shape[0], self.vocabulary_size), dtype=np.float16)
            # one_hot_matrix[np.arange(bigram_labels.shape[0]), ids] = 1 - bigram_weight
            # bigram_labels[:-1, :] += one_hot_matrix[1:, :]
            bigram_labels[np.arange(bigram_labels.shape[0]), ids[1:]] += (1 - bigram_weight)
            bigram_labels = np.concatenate((np.zeros((1, self.vocabulary_size)), bigram_labels), 0)
            # print('After ', np.sum(bigram_labels, axis=1)[-10:], bigram_labels.shape)
            bigram_labels = bigram_labels.reshape((inputs['labels'].shape[0], 
                                                inputs['labels'].shape[1], 
                                                self.vocabulary_size)) # Shape  (4, 1024, 32000)
            # RuntimeError: Expect the same number of specified elements per batch.
            # inputs['labels'] = torch.tensor(bigram_labels).to_sparse_csr().to(self.args.device)
            inputs['labels'] = torch.tensor(bigram_labels).to(self.args.device)
        '''===========modification end point==============='''
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty, your model won't be able to train on it. Double-check that your "
                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

In [10]:
from torch.nn import CrossEntropyLoss
from torch import tensor, exp
from datetime import datetime

def compute_metrics(eval_pred):
    print('Inside compute_metrics', eval_pred.predictions.shape, eval_pred.label_ids.shape)
    # Inside compute_metrics (11, 1024, 32000) (11, 1024)  numpy.ndarray
    loss_fct = CrossEntropyLoss()
    prediction = tensor(eval_pred.predictions).view(-1, 32000)
    labels = tensor(eval_pred.label_ids).view(-1)
    masked_lm_loss = exp(loss_fct(prediction, labels)) 
    return {'ppl': masked_lm_loss}
from transformers import Trainer, TrainingArguments
import os
# os.environ['WANDB_DISABLED'] = 'true' # turning off reporting to WanDB. It requires API key
output_dir="llama2-tiny-bigram-guided"
now = datetime.now()
dt_string = now.strftime("%d_%m_%H-%M-%S") # 27_12_10-09-20
log_name = output_dir + '/runs/' + dt_string + (f'_{bi_gram_scheduling_steps}_{min_bi_gram_weight}_{model_layer_number}' 
                                                if use_bi_gram else f"_no_bigram_{model_layer_number}") 
log_name += "_" + teacher_scheduling_method[0]

args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4, 
    evaluation_strategy="steps",
    eval_steps=1, 
    logging_steps=1,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1, 
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=3, 
    fp16=True,
    push_to_hub=False, 
    report_to='tensorboard',
    logging_dir=log_name
)
if TRAINING:
    args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4, 
        evaluation_strategy="steps",
        eval_steps=3_0, 
        logging_steps=20, 
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        weight_decay=0.1,
        warmup_steps=1_00, 
        lr_scheduler_type="cosine",
        learning_rate=5e-4,
        save_steps=3_000, 
        fp16=True,
        push_to_hub=False,
        report_to='tensorboard',
        logging_dir=log_name
    )

trainer = MyTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # compute_metrics=compute_metrics
    bigram_probabilities = bi_gram_probabilities, # The bi-gram probability matrix
    bi_gram_scheduling_steps = bi_gram_scheduling_steps, 
    min_bi_gram_weight = min_bi_gram_weight, # The minimum bi-gram weight
    teacher_scheduling_method = teacher_scheduling_method
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [11]:
trainer.train()

  0%|          | 0/588 [00:00<?, ?it/s]

{'loss': 9.2807, 'grad_norm': 1.7205259799957275, 'learning_rate': 0.0001, 'epoch': 0.03}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 6.870324611663818, 'eval_runtime': 63.3847, 'eval_samples_per_second': 59.163, 'eval_steps_per_second': 14.799, 'epoch': 0.05}
{'loss': 7.044, 'grad_norm': 1.2406632900238037, 'learning_rate': 0.0002, 'epoch': 0.07}
{'loss': 6.0816, 'grad_norm': 0.8223330974578857, 'learning_rate': 0.0003, 'epoch': 0.1}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 5.797129154205322, 'eval_runtime': 64.0235, 'eval_samples_per_second': 58.572, 'eval_steps_per_second': 14.651, 'epoch': 0.1}
{'loss': 5.5693, 'grad_norm': 0.6335774064064026, 'learning_rate': 0.0004, 'epoch': 0.14}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 5.22218656539917, 'eval_runtime': 65.3327, 'eval_samples_per_second': 57.399, 'eval_steps_per_second': 14.357, 'epoch': 0.15}
{'loss': 5.2478, 'grad_norm': 0.7424453496932983, 'learning_rate': 0.0005, 'epoch': 0.17}
{'loss': 5.0245, 'grad_norm': 0.5844109058380127, 'learning_rate': 0.0004979306685340254, 'epoch': 0.2}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.919995307922363, 'eval_runtime': 65.2324, 'eval_samples_per_second': 57.487, 'eval_steps_per_second': 14.379, 'epoch': 0.2}
{'loss': 4.8586, 'grad_norm': 0.554875373840332, 'learning_rate': 0.0004917569311978301, 'epoch': 0.24}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.730818748474121, 'eval_runtime': 64.44, 'eval_samples_per_second': 58.194, 'eval_steps_per_second': 14.556, 'epoch': 0.25}
{'loss': 4.7476, 'grad_norm': 0.5016651749610901, 'learning_rate': 0.00048158099206287375, 'epoch': 0.27}
{'loss': 4.6482, 'grad_norm': 0.4443685710430145, 'learning_rate': 0.00046757131025753886, 'epoch': 0.31}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.604829788208008, 'eval_runtime': 64.4612, 'eval_samples_per_second': 58.175, 'eval_steps_per_second': 14.551, 'epoch': 0.31}
{'loss': 4.5654, 'grad_norm': 0.49049606919288635, 'learning_rate': 0.0004499598111849299, 'epoch': 0.34}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.508819103240967, 'eval_runtime': 65.2058, 'eval_samples_per_second': 57.51, 'eval_steps_per_second': 14.385, 'epoch': 0.36}
{'loss': 4.5161, 'grad_norm': 0.4793235957622528, 'learning_rate': 0.0004290380470785983, 'epoch': 0.37}
{'loss': 4.4636, 'grad_norm': 0.4683516323566437, 'learning_rate': 0.0004051523704568557, 'epoch': 0.41}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.42645263671875, 'eval_runtime': 64.6231, 'eval_samples_per_second': 58.029, 'eval_steps_per_second': 14.515, 'epoch': 0.41}
{'loss': 4.4152, 'grad_norm': 0.5146693587303162, 'learning_rate': 0.00037869820037745775, 'epoch': 0.44}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.364832401275635, 'eval_runtime': 64.3108, 'eval_samples_per_second': 58.311, 'eval_steps_per_second': 14.585, 'epoch': 0.46}
{'loss': 4.3784, 'grad_norm': 0.4118437170982361, 'learning_rate': 0.0003501134764128167, 'epoch': 0.48}
{'loss': 4.3342, 'grad_norm': 0.39648815989494324, 'learning_rate': 0.0003198714087129024, 'epoch': 0.51}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.306804656982422, 'eval_runtime': 65.1709, 'eval_samples_per_second': 57.541, 'eval_steps_per_second': 14.393, 'epoch': 0.51}
{'loss': 4.2854, 'grad_norm': 0.40555375814437866, 'learning_rate': 0.0002884726441760155, 'epoch': 0.54}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.258877754211426, 'eval_runtime': 65.0735, 'eval_samples_per_second': 57.627, 'eval_steps_per_second': 14.414, 'epoch': 0.56}
{'loss': 4.2619, 'grad_norm': 0.4265148341655731, 'learning_rate': 0.0002564369784137472, 'epoch': 0.58}
{'loss': 4.2451, 'grad_norm': 0.36791926622390747, 'learning_rate': 0.00022429475071565987, 'epoch': 0.61}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.2207465171813965, 'eval_runtime': 64.4333, 'eval_samples_per_second': 58.2, 'eval_steps_per_second': 14.558, 'epoch': 0.61}
{'loss': 4.2181, 'grad_norm': 0.3923856019973755, 'learning_rate': 0.00019257806446705113, 'epoch': 0.65}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.182222366333008, 'eval_runtime': 64.1214, 'eval_samples_per_second': 58.483, 'eval_steps_per_second': 14.629, 'epoch': 0.66}
{'loss': 4.2031, 'grad_norm': 0.40806257724761963, 'learning_rate': 0.0001618119783627263, 'epoch': 0.68}
{'loss': 4.1582, 'grad_norm': 0.38628047704696655, 'learning_rate': 0.0001325058142431701, 'epoch': 0.71}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.152008533477783, 'eval_runtime': 64.658, 'eval_samples_per_second': 57.997, 'eval_steps_per_second': 14.507, 'epoch': 0.71}
{'loss': 4.1593, 'grad_norm': 0.40800344944000244, 'learning_rate': 0.00010514472544885909, 'epoch': 0.75}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.1272501945495605, 'eval_runtime': 64.8984, 'eval_samples_per_second': 57.783, 'eval_steps_per_second': 14.453, 'epoch': 0.76}
{'loss': 4.1289, 'grad_norm': 0.3945483863353729, 'learning_rate': 8.018166527567672e-05, 'epoch': 0.78}
{'loss': 4.1248, 'grad_norm': 0.3674790561199188, 'learning_rate': 5.802988849085e-05, 'epoch': 0.82}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.106303691864014, 'eval_runtime': 65.1958, 'eval_samples_per_second': 57.519, 'eval_steps_per_second': 14.387, 'epoch': 0.82}
{'loss': 4.1076, 'grad_norm': 0.40021446347236633, 'learning_rate': 3.905611004420359e-05, 'epoch': 0.85}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.09246826171875, 'eval_runtime': 65.1437, 'eval_samples_per_second': 57.565, 'eval_steps_per_second': 14.399, 'epoch': 0.87}
{'loss': 4.0963, 'grad_norm': 0.34703195095062256, 'learning_rate': 2.3574434229882145e-05, 'epoch': 0.88}
{'loss': 4.0898, 'grad_norm': 0.38489699363708496, 'learning_rate': 1.1841154799154375e-05, 'epoch': 0.92}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.084216594696045, 'eval_runtime': 65.1797, 'eval_samples_per_second': 57.533, 'eval_steps_per_second': 14.391, 'epoch': 0.92}
{'loss': 4.1002, 'grad_norm': 0.3758687973022461, 'learning_rate': 4.0505121066209125e-06, 'epoch': 0.95}


  0%|          | 0/938 [00:00<?, ?it/s]

{'eval_loss': 4.08123779296875, 'eval_runtime': 64.2768, 'eval_samples_per_second': 58.341, 'eval_steps_per_second': 14.593, 'epoch': 0.97}
{'loss': 4.0841, 'grad_norm': 0.37361082434654236, 'learning_rate': 3.314775287923677e-07, 'epoch': 0.99}
{'train_runtime': 2331.2378, 'train_samples_per_second': 8.074, 'train_steps_per_second': 0.252, 'train_loss': 4.730655728554239, 'epoch': 1.0}


TrainOutput(global_step=588, training_loss=4.730655728554239, metrics={'train_runtime': 2331.2378, 'train_samples_per_second': 8.074, 'train_steps_per_second': 0.252, 'train_loss': 4.730655728554239, 'epoch': 1.0})

In [12]:
# import os
# os.system("shutdown -t  60 ")