# Language Model Main Functions and Controller 

> For an in-depth tutorial, click [here for Roberta language model](https://anhquan0412.github.io/that-nlp-library/model_lm_roberta_tutorial.html), and [here for GPT language model](https://anhquan0412.github.io/that-nlp-library/model_lm_gpt2_tutorial.html)

- skip_showdoc: true
- skip_exec: true

In [None]:
#| default_exp model_lm_main

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations
import os, sys
from transformers import Trainer, TrainingArguments, AutoConfig
from datasets import DatasetDict,Dataset
import torch
import gc
import math
from functools import partial
import evaluate
import numpy as np
from that_nlp_library.utils import *
from that_nlp_library.text_main_lm import TextDataLMController
from that_nlp_library.text_main_lm_streaming import TextDataLMControllerStreaming
from transformers import pipeline

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM

In [None]:
#| export
def language_model_init(model_class, # Model's class object, e.g. AutoModelForMaskedLM
                        cpoint_path=None, # Either model string name on HuggingFace, or the path to model checkpoint. Put `None` to train from scratch
                        config=None, # Model config. If not provided, AutoConfig is used to load config from cpoint_path                        
                        device=None, # Device to train on
                        seed=None, # Random seed
                       ):
    """
    To initialize a language model, either masked or causal

    The model is loaded from `cpoint_path` when one is given, otherwise it is built
    from `config` with fresh weights. In both cases the model is moved to `device`
    (first CUDA device if available, else CPU, when `device` is None) and the
    total / trainable parameter counts are printed before the model is returned.

    Raises `ValueError` when neither `cpoint_path` nor `config` is provided.
    """
    if device is None: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if config is None:
        # Without a checkpoint there is nothing to derive the config from
        if not cpoint_path:
            raise ValueError('Either `cpoint_path` or `config` must be provided to initialize a language model')
        config = AutoConfig.from_pretrained(cpoint_path)

    # `is not None` so that a seed of 0 is honored
    if seed is not None:
        seed_everything(seed)

    if cpoint_path:
        model = model_class.from_pretrained(cpoint_path,config=config)
    else:
        print('Initiate a new language model from scratch')
        model = model_class.from_config(config)
    # Move to the target device regardless of how the model was created
    model = model.to(device)
    
    print(f'Total parameters: {sum(param.numel() for param in model.parameters())}')
    print(f'Total trainable parameters: {sum(param.numel() for param in model.parameters() if param.requires_grad)}')
    
    return model

In [None]:
show_doc(language_model_init)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/model_lm_main.py#L24){target="_blank" style="float:right; font-size:smaller"}

### language_model_init

>      language_model_init (model_class, cpoint_path=None, config=None,
>                           device=None, seed=None)

To initialize a language model, either masked or causal

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| model_class |  |  | Model's class object, e.g. AutoModelForMaskedLM |
| cpoint_path | NoneType | None | Either model string name on HuggingFace, or the path to model checkpoint. Put `None` to train from scratch |
| config | NoneType | None | Model config. If not provided, AutoConfig is used to load config from cpoint_path |
| device | NoneType | None | Device to train on |
| seed | NoneType | None | Random seed |

In [None]:
# Load the `roberta-base` checkpoint as a masked LM; the bare name on the last line displays its architecture
_model1 = language_model_init(AutoModelForMaskedLM,
                              'roberta-base')
_model1

Total parameters: 124697433
Total trainable parameters: 124697433


RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): 

In [None]:
# Same call with the `nguyenvulebinh/envibert` checkpoint, again loaded as a masked LM
_model1 = language_model_init(AutoModelForMaskedLM,
                              'nguyenvulebinh/envibert')
_model1

Total parameters: 70764377
Total trainable parameters: 70764377


RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(59993, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-5): 6 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): La

In [None]:
# Load the `gpt2` checkpoint as a causal LM; the bare name on the last line displays its architecture
_model2 = language_model_init(AutoModelForCausalLM,
                              'gpt2')
_model2

Total parameters: 124439808
Total trainable parameters: 124439808


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
#| export
def compute_lm_accuracy(eval_preds, # An EvalPrediction object from HuggingFace 
                        is_mlm, # if this is masked language model, set to `True`. If this is causal language model, set to `False`
                     ):
    
    """
    Compute token-level accuracy for a masked or causal language model

    `eval_preds` unpacks into `(preds, labels)`, where `preds` are token ids with the
    same shape as `labels` (the argmax has already been taken by
    `preprocess_lm_logits_for_metrics`). Positions whose label is -100 are ignored
    in both modes. Returns the output of the `accuracy` metric's `compute`.

    Reference: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py#L592C35-L592C35
    """
    metric = evaluate.load("accuracy")
    preds, labels = eval_preds
    if not is_mlm:
        # Causal LM: the prediction at position t is for the token at t+1,
        # so drop the first label and the last prediction to align them
        labels = labels[:, 1:]
        preds = preds[:, :-1]
    labels = labels.reshape(-1)
    preds = preds.reshape(-1)
    # -100 marks positions to ignore (unmasked tokens for MLM, padding for both);
    # counting them would always register as a wrong prediction
    mask = labels != -100
    return metric.compute(predictions=preds[mask], references=labels[mask])

In [None]:
#| export
def preprocess_lm_logits_for_metrics(logits, labels):
    "Reduce raw model logits to predicted token ids (argmax over the last dimension); `labels` is not used"
    # Some models return a tuple with extra tensors (e.g. past_key_values)
    # after the logits; the logits themselves are always the first element
    token_scores = logits[0] if isinstance(logits, tuple) else logits
    predicted_ids = token_scores.argmax(dim=-1)
    return predicted_ids

In [None]:
#| export
def finetune_lm(lr, # Learning rate
                bs, # Batch size
                wd, # Weight decay
                epochs, # Number of epochs
                ddict, # The HuggingFace datasetdict
                tokenizer,# HuggingFace tokenizer
                o_dir = './tmp_weights', # Directory to save weights
                save_checkpoint=False, # Whether to save weights (checkpoints) to o_dir
                model=None, # NLP model
                model_init=None, # A function to initialize model
                data_collator=None, # HuggingFace data collator
                compute_metrics=None, # A function to compute metric; no metric is computed when `None`
                grad_accum_steps=2, # The batch at each step will be divided by this integer and gradient will be accumulated over gradient_accumulation_steps steps.
                lr_scheduler_type='cosine',  # The scheduler type to use. Including: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
                warmup_ratio=0.1, # The warmup ratio for some lr scheduler
                no_valid=False, # Whether there is a validation set or not
                seed=None, # Random seed
                report_to='none', # The list of integrations to report the results and logs to. Supported platforms are "azure_ml", "comet_ml", "mlflow", "neptune", "tensorboard","clearml" and "wandb". Use "all" to report to all integrations installed, "none" for no integrations.
                trainer_class=None, # You can include the class name of your custom trainer here
            ):
    """
    The main model training/finetuning function

    Builds `TrainingArguments` and a trainer (`Trainer` unless `trainer_class` is given),
    runs `train()` and returns the trainer. The training split is looked up under
    `train`/`training` and, unless `no_valid` is set, the validation split under
    `validation`/`val`/`valid`. Raises `ValueError` if a required split is missing.
    """
    # Collect unreachable objects first so the CUDA cache can actually release their memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # `is not None` so that a seed of 0 is honored
    if seed is not None:
        seed_everything(seed)

    # Accept the same split aliases that `ModelLMController.fit` validates against
    train_split = next((k for k in ('train','training') if k in ddict), None)
    if train_split is None:
        raise ValueError("Missing the following key for DatasetDict: train/training")
    valid_split = None
    if not no_valid:
        valid_split = next((k for k in ('validation','val','valid') if k in ddict), None)
        if valid_split is None:
            raise ValueError("Missing the following key for DatasetDict: validation/val/valid")
        
    training_args = TrainingArguments(o_dir, 
                                     learning_rate=lr, 
                                     warmup_ratio=warmup_ratio,
                                     lr_scheduler_type=lr_scheduler_type, 
                                     # mixed precision is only requested when a CUDA device is present
                                     fp16=torch.cuda.is_available(),
                                     do_train=True,
                                     do_eval= not no_valid,
                                     evaluation_strategy="no" if no_valid else "epoch", 
                                     save_strategy="epoch" if save_checkpoint else 'no',
                                     overwrite_output_dir=True,
                                     gradient_accumulation_steps=grad_accum_steps,
                                     per_device_train_batch_size=bs, 
                                     per_device_eval_batch_size=bs,
                                     num_train_epochs=epochs, weight_decay=wd,
                                     report_to=report_to,
                                     logging_dir=os.path.join(o_dir, 'log') if report_to!='none' else None,
                                     # never 0, even when the train set is smaller than one batch
                                     logging_steps = max(1, len(ddict[train_split]) // bs),
                                     )

    # instantiate trainer
    trainer_class = Trainer if trainer_class is None else trainer_class
    trainer = trainer_class(
        model=model,
        # `model_init` is only used when no model instance is supplied
        model_init=model_init if model is None else None,
        args=training_args,
        train_dataset=ddict[train_split],
        eval_dataset=ddict[valid_split] if not no_valid else None,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics if not no_valid else None,
        preprocess_logits_for_metrics=preprocess_lm_logits_for_metrics if not no_valid else None
    )
    
    trainer.train()
    return trainer

In [None]:
show_doc(finetune_lm)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/model_lm_main.py#L85){target="_blank" style="float:right; font-size:smaller"}

### finetune_lm

>      finetune_lm (lr, bs, wd, epochs, ddict, tokenizer, o_dir='./tmp_weights',
>                   save_checkpoint=False, model=None, model_init=None,
>                   data_collator=None, compute_metrics=None,
>                   grad_accum_steps=2, lr_scheduler_type='cosine',
>                   warmup_ratio=0.1, no_valid=False, seed=None,
>                   report_to='none', trainer_class=None)

The main model training/finetuning function

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| lr |  |  | Learning rate |
| bs |  |  | Batch size |
| wd |  |  | Weight decay |
| epochs |  |  | Number of epochs |
| ddict |  |  | The HuggingFace datasetdict |
| tokenizer |  |  | HuggingFace tokenizer |
| o_dir | str | ./tmp_weights | Directory to save weights |
| save_checkpoint | bool | False | Whether to save weights (checkpoints) to o_dir |
| model | NoneType | None | NLP model |
| model_init | NoneType | None | A function to initialize model |
| data_collator | NoneType | None | HuggingFace data collator |
| compute_metrics | NoneType | None | A function to compute metric, default to `compute_lm_accuracy` |
| grad_accum_steps | int | 2 | The batch at each step will be divided by this integer and gradient will be accumulated over gradient_accumulation_steps steps. |
| lr_scheduler_type | str | cosine | The scheduler type to use. Including: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup |
| warmup_ratio | float | 0.1 | The warmup ratio for some lr scheduler |
| no_valid | bool | False | Whether there is a validation set or not |
| seed | NoneType | None | Random seed |
| report_to | str | none | The list of integrations to report the results and logs to. Supported platforms are "azure_ml", "comet_ml", "mlflow", "neptune", "tensorboard","clearml" and "wandb". Use "all" to report to all integrations installed, "none" for no integrations. |
| trainer_class | NoneType | None | You can include the class name of your custom trainer here |

In [None]:
#| hide
#| export
def extract_hidden_states(batch,
                          model=None, # NLP model
                          model_input_names=('input_ids', 'token_type_ids', 'attention_mask'), # Model required inputs, from tokenizer.model_input_names
                          data_collator=None, # HuggingFace data collator
                          state_name='last_hidden_state', # Name of the state to extract
                          state_idx=0, # The index (or indices) of the state to extract
                          device = None, # device that the model is trained on
                          ):
    """
    Run `model` on a batch without gradients and return one of its output states as a numpy array

    Only the keys of `batch` listed in `model_input_names` are fed to the model. When
    `data_collator` is given, the column-wise batch is first converted to a list of
    per-example dicts and collated. `outputs[state_name]` is then narrowed by each
    index in `state_idx` in turn: a tuple is indexed directly, anything else is
    indexed along its second axis. The model's train/eval mode is restored on exit.

    Returns `{state_name: <numpy array>}`.
    """
    state_idx = val2iterable(state_idx)
    
    if data_collator is not None:    
        # --- Convert from  
        # {'input_ids': [tensor([    0, 10444,   244, 14585,   125,  2948,  5925,   368,     2]), 
        #                tensor([    0, 16098,  2913,   244,   135,   198, 34629,  6356,     2])]
        # 'attention_mask': [tensor([1, 1, 1, 1, 1, 1, 1, 1, 1]), 
        #                    tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])]
        #                    }
        # --- to
        # [{'input_ids': tensor([    0, 10444,   244, 14585,   125,  2948,  5925,   368,     2]),
        #   'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])},
        #  {'input_ids': tensor([    0, 16098,  2913,   244,   135,   198, 34629,  6356,     2]),
        #   'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1])}]

        # keep only the model inputs (drops string text, due to transformer new version)
        ks = [k for k in batch.keys() if k in model_input_names]
        collator_inp = [dict(zip(ks, row)) for row in zip(*(batch[k] for k in ks))]
        batch = data_collator(collator_inp)
    
    inputs = {k:v.to(device) for k,v in batch.items()
              if k in model_input_names}
            
    # Remember the caller's mode so it can be restored afterwards; the original
    # code always left the model in train mode, even if it started in eval mode
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs)
            states = outputs[state_name]
            for i in state_idx:
                states = states[i] if isinstance(states,tuple) else states[:,i]
    finally:
        # Restore train mode even if the forward pass raised
        if was_training:
            model.train()

    return {state_name:states.cpu().numpy()}

In [None]:
#| export
class ModelLMController():
    """
    Controller that pairs a language model with an (optional) LM data controller
    to fine-tune the model, generate / fill-mask predictions and extract hidden states
    """
    def __init__(self,
                 model, # NLP language model
                 data_store=None, # a TextDataLMController/TextDataLMControllerStreaming object
                 seed=None, # Random seed
                ):
        self.model = model
        self.data_store = data_store
        self.seed = seed

    def _check_data_store(self):
        "Raise `ValueError` unless `data_store` is an LM data controller that has already processed data"
        if not isinstance(self.data_store,(TextDataLMController,TextDataLMControllerStreaming)) or not self.data_store._processed_call:
            raise ValueError('This functionality needs a TextDataController object which has processed some training data')
        
    def fit(self,
            epochs, # Number of epochs
            learning_rate, # Learning rate
            ddict=None, # DatasetDict to fit (will override data_store)
            compute_metrics=None, # A function to compute metric, default to `compute_lm_accuracy`
            batch_size=16, # Batch size
            weight_decay=0.01, # Weight decay
            lr_scheduler_type='cosine', # The scheduler type to use. Including: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup
            warmup_ratio=0.1, # The warmup ratio for some lr scheduler
            o_dir = './tmp_weights', # Directory to save weights
            save_checkpoint=False, # Whether to save weights (checkpoints) to o_dir
            hf_report_to='none', # The list of HuggingFace-allowed integrations to report the results and logs to
            grad_accum_steps=2, # Gradient will be accumulated over gradient_accumulation_steps steps.
            tokenizer=None, # Tokenizer (to override one in ```data_store```)
            data_collator=None, # Data Collator (to override one in ```data_store```)
            is_mlm=None, # Whether this is masked LM or causal LM
            trainer_class=None, # You can include the class name of your custom trainer here
           ):
        """
        Fine-tune the model with `finetune_lm`

        Arguments left as `None` are taken from `data_store`. The trainer is stored in
        `self.trainer`. When a validation split is present, the final evaluation results
        (with an added `perplexity` entry) are stored in `self.eval_results`; otherwise
        `self.eval_results` is `None`.
        """
        
        if tokenizer is None: tokenizer=check_and_get_attribute(self.data_store,'tokenizer')
        if data_collator is None: data_collator=check_and_get_attribute(self.data_store,'data_collator')
        if ddict is None: ddict = check_and_get_attribute(self.data_store,'main_ddict')
        if is_mlm is None: is_mlm = check_and_get_attribute(self.data_store,'is_mlm')
            
        if compute_metrics is None:
            compute_metrics=partial(compute_lm_accuracy,is_mlm=is_mlm)
        
        split_names = set(ddict.keys())
        if not (split_names & {'train','training'}):
            raise ValueError("Missing the following key for DatasetDict: train/training")
        no_valid = not (split_names & {'validation','val','valid'})
        
        trainer = finetune_lm(learning_rate,
                              batch_size,
                              weight_decay,
                              epochs,
                              ddict,
                              tokenizer,
                              o_dir,
                              save_checkpoint=save_checkpoint,
                              model=self.model,
                              data_collator=data_collator,
                              compute_metrics=compute_metrics,
                              grad_accum_steps=grad_accum_steps,
                              lr_scheduler_type=lr_scheduler_type,
                              warmup_ratio=warmup_ratio,
                              no_valid=no_valid,
                              seed=self.seed,
                              trainer_class=trainer_class,
                              report_to=hf_report_to)

        # Keep the trainer before evaluating so it is not lost if evaluation fails
        self.trainer = trainer
        self.eval_results = None
        
        if not no_valid:
            eval_results = trainer.evaluate()
            try:
                perplexity = math.exp(eval_results["eval_loss"])
                print(f'Perplexity on validation set: {perplexity:.3f}')
            except OverflowError:
                # exp() overflowed: the loss is too large to express as a finite perplexity
                perplexity = float("inf")
                print(f'Perplexity on validation set: {perplexity}')
            eval_results['perplexity'] = perplexity
            # Previously computed and then discarded; keep it accessible to the caller
            self.eval_results = eval_results
        
    def predict_raw_text(self,
                         content:dict|list|str, # Either a single sentence, list of sentences or a dictionary where keys are metadata, values are list
                         print_result=True, # Whether to print the result in readable format, or get the result returned
                         **kwargs, # keyword arguments for HuggingFace's text-generation (for clm) or fill-mask (for mlm)
                        ):
        """
        Predict on raw text: prints the predictions when `print_result` is True (returning `None`),
        otherwise returns them. Requires a `data_store` that has already processed data.
        """
        # Example of kwargs for text-generation:
        # https://huggingface.co/docs/transformers/v4.33.2/en/main_classes/text_generation#transformers.GenerationMixin.generate
        
        self._check_data_store()
        test_dset = self.data_store.prepare_test_dataset_from_raws(content)
        result = self.predict_ddict(test_dset,**kwargs)
        if not print_result:
            return result

        is_mlm = check_and_get_attribute(self.data_store,'is_mlm')
        for preds in result:
            for pred in preds:
                if is_mlm:
                    print(f"Score: {pred['score']:.3f} >>> {pred['sequence']}")
                else:
                    print(f">>> {pred['generated_text']}")
            print('-'*20)
    
    def predict_ddict(self,
                      dset:Dataset, # A processed and tokenized Dataset
                      **kwargs, # keyword arguments for HuggingFace's text-generation (for clm) or fill-mask (for mlm)
                     ):
        """
        Run a HuggingFace `fill-mask` (mlm) or `text-generation` (clm) pipeline on every
        text of the `main_text` column of `dset`; returns one pipeline output per text
        """
        is_mlm = check_and_get_attribute(self.data_store,'is_mlm')
        tokenizer=check_and_get_attribute(self.data_store,'tokenizer')
        main_text=check_and_get_attribute(self.data_store,'main_text')
        _task = 'fill-mask' if is_mlm else 'text-generation'
        pipeline_obj = pipeline(_task,model=self.model,tokenizer=tokenizer,device=self.model.device)
        str_list = dset[main_text]
        if _task=='fill-mask':
            # The content transformations were also applied to the mask token inside the text;
            # apply them to the mask token itself to find that altered form and restore the original
            all_tfms = self.data_store.content_tfms 
            all_tfms = partial(func_all,functions=all_tfms) if len(all_tfms) else lambda x: x
            mask_str = all_tfms(tokenizer.mask_token)
            str_list = [str(s).replace(mask_str,tokenizer.mask_token) for s in str_list]
        return [pipeline_obj(s,**kwargs) for s in str_list]
    
    def get_hidden_states_from_raw_text(self,
                                        content:dict|list|str, # Either a single sentence, list of sentences or a dictionary where keys are metadata, values are list
                                        state_name, # Name of the (hidden) state to extract
                                        state_idx=0, # The index (or indices) of the state to extract. For `hidden_states`, accept multiple values
                                       ):
        "Process and tokenize raw text with `data_store`, then extract hidden states one example at a time"
        self._check_data_store()
        dset = self.data_store.prepare_test_dataset_from_raws(content,do_tokenize=True)
        return self.get_hidden_states(dset,
                                      state_name=state_name,
                                      state_idx=state_idx,
                                      batch_size=1
                                     )

        
    def get_hidden_states_from_raw_dset(self,
                                        dset: Dataset, # A raw HuggingFace dataset
                                        state_name, # Name of the (hidden) state to extract
                                        state_idx=0, # The index (or indices) of the state to extract. For `hidden_states`, accept multiple values
                                        batch_size=16, # GPU batch size
                                       ):
        "Process and tokenize a raw dataset with `data_store`, then extract hidden states"
        self._check_data_store()
        dset = self.data_store.prepare_test_dataset(dset,do_tokenize=True)
        return self.get_hidden_states(dset,
                                      state_name=state_name,
                                      state_idx=state_idx,
                                      batch_size=batch_size
                                     )
        
        
    def get_hidden_states(self,
                          ddict:DatasetDict|Dataset=None, # A processed and tokenized DatasetDict/Dataset (will override one in ```data_store```)
                          ds_type='test', # The split of DatasetDict to predict
                          state_name='last_hidden_state', # Name of the (hidden) state to extract
                          state_idx=0, # The index (or indices) of the state to extract. For `hidden_states`, accept multiple values
                          batch_size=16, # GPU batch size
                         ):
        """
        Map `extract_hidden_states` over a tokenized dataset and return the result in numpy format

        For a DatasetDict, the `ds_type` split is used; `ValueError` is raised if it is absent.
        """
        
        tokenizer=check_and_get_attribute(self.data_store,'tokenizer')
        if ddict is None: ddict = check_and_get_attribute(self.data_store,'main_ddict')
        if isinstance(ddict,DatasetDict):
            if ds_type not in ddict.keys():
                raise ValueError(f'{ds_type} is not in the given DatasetDict')
            ddict = ddict[ds_type]
        
        # `with_format` returns a new dataset object, so the caller's dataset
        # (or the one held by `data_store`) keeps its original format
        torch_dset = ddict.with_format("torch",
                                       columns=tokenizer.model_input_names)
        
        results = torch_dset.map(partial(extract_hidden_states,
                                 model=self.model,
                                 model_input_names = tokenizer.model_input_names,
                                 state_name=state_name,
                                 state_idx=state_idx,
                                 device = self.model.device),
                           batched=True,
                           batch_size=batch_size
                          )
        results.set_format('numpy')
        return results


In [None]:
show_doc(ModelLMController)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/model_lm_main.py#L202){target="_blank" style="float:right; font-size:smaller"}

### ModelLMController

>      ModelLMController (model, data_store=None, seed=None)

Initialize self.  See help(type(self)) for accurate signature.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| model |  |  | NLP language model |
| data_store | NoneType | None | a TextDataLMController/TextDataLMControllerStreaming object |
| seed | NoneType | None | Random seed |

In [None]:
show_doc(ModelLMController.fit)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/model_lm_main.py#L212){target="_blank" style="float:right; font-size:smaller"}

### ModelLMController.fit

>      ModelLMController.fit (epochs, learning_rate, ddict=None,
>                             compute_metrics=None, batch_size=16,
>                             weight_decay=0.01, lr_scheduler_type='cosine',
>                             warmup_ratio=0.1, o_dir='./tmp_weights',
>                             save_checkpoint=False, hf_report_to='none',
>                             grad_accum_steps=2, tokenizer=None,
>                             data_collator=None, is_mlm=None,
>                             trainer_class=None)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| epochs |  |  | Number of epochs |
| learning_rate |  |  | Learning rate |
| ddict | NoneType | None | DatasetDict to fit (will override data_store) |
| compute_metrics | NoneType | None | A function to compute metric, default to `compute_lm_accuracy` |
| batch_size | int | 16 | Batch size |
| weight_decay | float | 0.01 | Weight decay |
| lr_scheduler_type | str | cosine | The scheduler type to use. Including: linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup |
| warmup_ratio | float | 0.1 | The warmup ratio for some lr scheduler |
| o_dir | str | ./tmp_weights | Directory to save weights |
| save_checkpoint | bool | False | Whether to save weights (checkpoints) to o_dir |
| hf_report_to | str | none | The list of HuggingFace-allowed integrations to report the results and logs to |
| grad_accum_steps | int | 2 | Gradient will be accumulated over gradient_accumulation_steps steps. |
| tokenizer | NoneType | None | Tokenizer (to override one in ```data_store```) |
| data_collator | NoneType | None | Data Collator (to override one in ```data_store```) |
| is_mlm | NoneType | None | Whether this is masked LM or causal LM |
| trainer_class | NoneType | None | You can include the class name of your custom trainer here |

In [None]:
show_doc(ModelLMController.predict_raw_text)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/model_lm_main.py#L274){target="_blank" style="float:right; font-size:smaller"}

### ModelLMController.predict_raw_text

>      ModelLMController.predict_raw_text (content:dict|list|str,
>                                          print_result=True, **kwargs)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| content | dict \| list \| str |  | Either a single sentence, list of sentences or a dictionary where keys are metadata, values are list |
| print_result | bool | True | Whether to print the result in readable format, or get the result returned |
| kwargs |  |  |  |

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()