In [None]:
# Install the training stack (Hugging Face libraries, PEFT, wandb, torch_xla for TPU VMs).
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes -q
!pip3 install accelerate datasets sentencepiece langchain torch_xla[tpuvm] -q
# Swap the preinstalled TensorFlow for the CPU-only build
# (presumably so TensorFlow does not compete for the TPU -- TODO confirm).
!pip uninstall -y tensorflow
!pip install tensorflow-cpu -q
# Clone the extra optimizers into ./optims; the training cell imports SM3, CAME and Adafactor from it.
!git clone https://github.com/IsNoobgrammer/Pytorch-Optimizers optims

In [None]:
# Restart the kernel after the installs above. Done "for good measure",
# presumably so the freshly installed packages are the ones that get imported.
get_ipython().kernel.do_shutdown(True)

**Tokens?**

In [None]:
# Authenticate with the Hugging Face Hub. Replace <hf_read_token> with a real read
# token before running, and never commit the filled-in value.
!huggingface-cli login --token <hf_read_token> #for downloading gated models
# Optional: uncomment to log in to Weights & Biases (used when FLAGS['WANDB'] is True).
# import wandb
# wandb.login()

**Sharding Module for Different Architectures**

In [1]:
%%writefile spmd_util.py
import math
from dataclasses import dataclass, field
from typing import List, Optional
from collections import defaultdict
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config, MixtralConfig, PhiConfig,GemmaConfig
)

# ends with $ to prevent sharding lora parameters


# Sharding rules for T5 configs (see ALL_RULES).
# Every *_RULES table uses the same entry format: (regex, partition_spec).
# partition_module() searches the regex in the lower-cased module name and, on the
# first hit, shards that module's weight with the spec. Spec entries are mesh-axis
# names ("fsdp" / "mp") that are translated to axis indices through strkey2id.
T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

# Sharding rules for Qwen2 configs: (regex on module name, mesh-axis names per weight dim).
# These patterns are not anchored with "$", so they also match LoRA sub-modules
# nested under the named projections.
QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
# Sharding rules for GPT-2 configs: (regex on module name, mesh-axis names per weight dim).
# NOTE(review): partition_module() only shards nn.Embedding / nn.Linear modules;
# confirm that the GPT-2 modules named here are of those types, otherwise the
# corresponding rules never fire.
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    # Duplicate of the attention "c_proj" rule above: rules are tried in order and
    # the first match wins, so this entry is never reached.
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
# Sharding rules for Mistral configs: (regex on module name, mesh-axis names per weight dim).
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.(gate_proj|up_proj)", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )


# Sharding rules for Phi configs.
PHI_RULES = (
    # Entry format: (regex selecting a module by name, tuple of mesh-axis names per weight dim)
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.dense", ("mp", "fsdp")),
    ("mlp\\.fc2", ("mp", "fsdp")),  
    ("mlp\\.fc1", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    
)

# Sharding rules for Llama configs: (regex on module name, mesh-axis names per weight dim).
# Same entries as QWEN_RULES and GEMMA_RULES.
LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )

# Sharding rules for GPT-NeoX configs: (regex on module name, mesh-axis names per weight dim).
# The "$"-anchored patterns only match modules whose name ends with the given suffix.
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)



# Sharding rules for Mixtral configs: (regex on module name, mesh-axis names per weight dim).
# "w1", "w2", "w3" and "gate" are bare substrings, so they match anywhere in a module
# name (presumably the expert and router layers -- TODO confirm against the model).
MIXTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("w1", ("fsdp", "mp")),
    ("w2", ("mp", "fsdp")),
    ("w3", ("fsdp", "mp")),
    ("gate", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )

# Sharding rules for Gemma configs: (regex on module name, mesh-axis names per weight dim).
# Same entries as LLAMA_RULES and QWEN_RULES.
GEMMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
    
# Registry pairing each supported transformers config class with its rule table.
# find_rule() walks this list in order and returns the table of the first entry
# whose config class name matches the model's config class name.
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES),
    (MixtralConfig, MIXTRAL_RULES),
    (PhiConfig,PHI_RULES),
    (GemmaConfig,GEMMA_RULES),
]


def find_rule(model):
    """Return the sharding rule table registered for ``model``'s config class.

    The lookup compares class *names* case-insensitively (not ``isinstance``),
    which is what the previous ``str(cls).split(".")[-1]`` comparison did; using
    ``__name__`` gives the same result without parsing the class repr.

    Args:
        model: object exposing a ``config`` attribute (e.g. a transformers model).

    Returns:
        The rule tuple from ``ALL_RULES`` whose config class name matches.

    Raises:
        ValueError: if no registered config class matches. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception`` callers
            keep working.
    """
    model_config_name = type(model.config).__name__
    wanted = model_config_name.lower()
    for config_cls, rules in ALL_RULES:
        if config_cls.__name__.lower() == wanted:
            return rules
    supported = ", ".join(config_cls.__name__ for config_cls, _ in ALL_RULES)
    raise ValueError(
        f"unsupported model for partitioning: config class {model_config_name!r} "
        f"has no sharding rules (supported: {supported})"
    )

# Maps the mesh-axis names used in the *_RULES tables to the integer axis indices
# that partition_module() passes to xs.mark_sharding.
strkey2id = {
    "dp": 0, ## useful for sharding inputs
    "fsdp": 1, ## PyTorch/XLA 2D-sharding axis used to shard data; the mesh is mostly (8, 1), i.e. data is sharded 8 ways
    "mp": 2 ## axis used to shard the model; with a size-1 axis the model is sharded one way
               ## Recommended reading: pytorch-tpu/transformers on GitHub (the XLA fork of transformers)
}

def partition_module(model, mesh, device=None, verbose=False):
    """Move ``model`` to ``device`` and annotate its weights with SPMD shardings.

    For every ``nn.Embedding`` / ``nn.Linear`` sub-module, the rules returned by
    ``find_rule(model)`` are tried in order against the lower-cased module name;
    the first matching rule decides how ``module.weight`` is sharded over ``mesh``.
    Modules that match no rule are left unannotated.

    Args:
        model: model whose config class is registered in ``ALL_RULES``.
        mesh: mesh object handed to ``xs.mark_sharding``.
        device: target device. ``None`` (default) resolves to ``xm.xla_device()``
            at call time; previously the default was evaluated when the module
            was imported, which touched the XLA runtime as an import side effect.
        verbose: print each (pattern, module name) match.

    Raises:
        Whatever ``find_rule`` raises for an unsupported model.
    """
    if device is None:
        device = xm.xla_device()

    partition_specs = find_rule(model)
    # Translate mesh-axis names to indices and compile each regex once.
    rules = [
        (re.compile(pattern), pattern, tuple(strkey2id[axis] for axis in spec))
        for pattern, spec in partition_specs
    ]

    # named_modules() yields the root module first, so one call on the root
    # moves every sub-module; repeating it per sub-module is redundant.
    model.to(device)

    for name, module in model.named_modules():
        if not isinstance(module, (nn.Embedding, nn.Linear)):
            continue
        lowered_name = name.lower()
        for regex, pattern, spec in rules:
            if regex.search(lowered_name):
                if verbose:
                    print("match", pattern, name)
                xs.mark_sharding(module.weight, mesh, spec)
                break  # first matching rule wins


Overwriting spmd_util.py


**Required Libs**

In [2]:
import os
import pandas as pd
import numpy as np
import datasets
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch
import torch.nn as nn
import torch_xla.test.test_utils as test_utils
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
 AutoTokenizer, AutoModelForCausalLM, set_seed, DataCollatorWithPadding, AutoConfig 
)

from transformers import logging as hf_logging
import torch_xla.runtime as xr

xr.use_spmd()

from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model 
from datasets import  load_dataset, concatenate_datasets
from tqdm import tqdm

from torch.utils.data import Dataset as TorchDataset
from torch_xla.utils.checkpoint import checkpoint

# Environment tweaks for running on a TPU VM.
# `!export VAR=...` runs in a throw-away subshell and never reaches this kernel,
# so USE_TORCH is set through os.environ instead.
# NOTE(review): transformers and torch_xla are already imported above, so anything
# they read at import time is unaffected by these assignments -- confirm whether
# this block should move above the imports.
os.environ["USE_TORCH"] = "True"
os.environ["PJRT_DEVICE"] = "TPU"
# pop(..., None) tolerates a missing variable. The previous pop(key) raised KeyError
# on the first absent key, and the bare `except` then silently skipped every
# remaining statement, including the logging setup below.
os.environ.pop('TPU_PROCESS_ADDRESSES', None)
os.environ.pop('CLOUD_TPU_TASK_ID', None)
hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever


  from .autonotebook import tqdm as notebook_tqdm


**Configuration**

In [3]:
# Maximum number of tokens per training example (passed to the training dataset as max_length).
MAX_INPUT=4096 #128*32
# Hub id of the base model to fine-tune.
MODEL = "fhai50032/RolePlayLake-7B" # A 7B model should work with no changes; there should be enough HBM
# Hub repo id that the trained model and tokenizer are pushed to at the end of the notebook.
SAVED_MODEL = "fhai50032/RP-check-TPU"
# !export XLA_TENSOR_ALLOCATOR_MAXSIZE=1000000

In [4]:
from transformers import AutoTokenizer

# Load the tokenizer that ships with the base model.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Padding to max_length needs a pad token; reuse EOS when the tokenizer defines none.
has_pad_token = 'pad_token' in tokenizer.special_tokens_map
if not has_pad_token:
    tokenizer.pad_token = tokenizer.eos_token

# Show the special tokens in effect after the fallback above.
print(f"Tokens :\n {tokenizer.special_tokens_map} \n\n")

Tokens :

 {'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']} 






In [5]:
class ConversationDataset(TorchDataset):
    """Map-style dataset that renders a conversation as ChatML-like text and tokenizes it.

    Each row of ``dataset`` must provide ``row["conversations"]``, an iterable of
    messages with a ``"from"`` role and a ``"value"`` text. Messages with an
    unrecognised role are skipped.

    Args:
        tokenizer: callable tokenizer returning ``input_ids`` and ``attention_mask``.
        max_length: every example is truncated / padded to this many tokens.
        dataset: indexable collection of conversation rows.
    """

    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` for row ``idx``.

        ``labels`` holds the same token ids as ``input_ids`` (padding included).
        """
        segments = []
        for message in self.dataset[idx]["conversations"]:
            role = message["from"]
            if role == "system":
                segments.append(f"<|im_start|>system\n{message['value']}<|im_end|>\n")
            elif role in ("human", "user"):
                segments.append(f"<|im_start|>user\n{message['value']}<|im_end|>\n")
            elif role == "function-call":
                segments.append(f"<|im_start|>call\n{message['value']}<|im_end|>\n")
            elif role == "function-response":
                segments.append(f"<|im_start|>function\n{message['value']}<|im_end|>\n")
            elif role in ("gpt", "assistant"):
                # NOTE(review): unlike every other role, the assistant turn is not
                # closed with "<|im_end|>\n" -- confirm this is intentional.
                segments.append(f"<|im_start|>assistant\n{message['value']}")
        text = "".join(segments)

        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        # The tokenizer returns a leading batch dimension of 1; drop it.
        return {
            "input_ids": token_ids.squeeze(0),
            "labels": token_ids.squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
        }

In [6]:
# Hub ids of the datasets used for training and validation (currently the same dataset).
train_dataset="fhai50032/magicoder-oss-instruct-sharegpt-75k"
test_dataset="fhai50032/magicoder-oss-instruct-sharegpt-75k"

# Shuffle with fixed seeds so runs are repeatable.
train_data = load_dataset(train_dataset, split="train").shuffle(seed=69)
# NOTE(review): the validation rows are the first 640 rows of the same "train" split,
# so they are also part of train_data; validation loss is therefore not held-out.
val = (load_dataset(test_dataset, split="train[:640]")).shuffle(seed=420)

In [7]:
len(train_data)

6866

In [8]:
# Run configuration consumed by train() and logged to wandb as the run config.
FLAGS = {'MAX_INPUT': MAX_INPUT,
         'LOGGING_STEPS': 1, # print the training loss every N steps
         'NUM_EPOCHS': 1,
         'PAUSE_STEPS':1000, # every N steps, prompt whether to continue training (TODO: checkpoint saving)
         'MAX_STEPS': -1, # meant to override NUM_EPOCHS; NOTE(review): train() never reads this key
         'BATCH_SIZE': 2, # Making batch_size lower than 8 results in slower training but allows larger models/context. Fortunately, we have 128 GB. Setting a higher batch_size doesn't seem to improve time.
          'LEN_TRAIN_DATA': len(train_data),
         'VAL_STEPS': 20, # run validation every N training steps
         'VAL_BATCH': 5, # number of validation batches evaluated per validation run
         'GRAD_ACCUMULATION_STEP':1,
         'MAX_GRAD_CLIP':1, # train() multiplies this by 8 to obtain the clipping max_norm
        'LEARNING_RATE':6e-5,
         'WARMUP_RATIO':0.01,
         'OPTIMIZER':'SM3', # options handled by train(): 'adamw', 'lion', 'adafactor', 'came'; any other value selects SM3
         'SCHEDULAR':'cosine', # 'linear' selects the linear schedule; any other value selects cosine
         'WEIGHT_DECAY':0.1, # not passed to SM3
         'TRAIN_DATASET':train_dataset,
         "TEST_DATASET":test_dataset,
         'WANDB':True,
        'PROJECT':'RP-Coder',
        } # Indian pun :) 

In [None]:
FLAGS

**Quantization When??**

In [10]:
# Load the base model weights in bfloat16.
model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16) 
# model._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint) 
# Gradient checkpointing is not set up properly yet; it still needs the barrier optimization.

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]


**LoRA Applicable**

In [11]:
# Wrap the base model with LoRA adapters on the attention and MLP projections.
ls=LoraConfig(
    r = 48, # LoRA rank; 8-32 is the author's preference for smaller models such as 7B
    target_modules = ['q_proj', 'down_proj', 'up_proj', 'o_proj', 'v_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 16, # weight scaling
    lora_dropout = 0.05, # any value is supported, but 0 is the optimized setting
    bias = "none",    # any value is supported, but "none" is the optimized setting
    # modules_to_save = ["lm_head", "embed_tokens"] ## enable when adding new chat-format / embedding tokens
)
model = get_peft_model(model, ls)
# Report how many parameters remain trainable after wrapping.
model.print_trainable_parameters()


  warn("The installed version of bitsandbytes was compiled without GPU support. "


/usr/local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32

trainable params: 150,011,904 || all params: 8,687,692,800 || trainable%: 1.7267174087923551


**Data Distributor**

In [12]:
# Wrap the raw datasets. Training examples use FLAGS['MAX_INPUT'] tokens; validation
# examples use ConversationDataset's default max_length (1024).
# This cell rebinds `train_data` and `val`, so re-running it would wrap the wrappers;
# re-run the dataset-loading cell first.
train_data = ConversationDataset(tokenizer, dataset=train_data, max_length=FLAGS['MAX_INPUT'])
val = ConversationDataset(tokenizer, dataset=val)
# Each sampler hands this process one of 8 equal slices of its dataset.
# NOTE(review): xr.use_spmd() is enabled earlier, and SPMD presumably runs as a single
# process. If so, only 1/8 of each dataset is ever iterated and the "8*BATCH_SIZE"
# effective batch size printed below is overstated -- confirm against the torch_xla SPMD docs.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=8, rank=xm.get_ordinal(), shuffle=True,drop_last=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=8, rank=xm.get_ordinal(), shuffle=True,drop_last=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

print(f"Max Steps:- {len(training_loader)}  , eFFECTIVE bATCH size {8*FLAGS['BATCH_SIZE']} Input")
print(f"Val Size:- {len(testing_loader)}  , eFFECTIvE bATCH size {8*FLAGS['BATCH_SIZE']} Input")
# Steps per epoch; train() uses it for the scheduler length and the wandb step counter.
FLAGS['STEPS']=len(training_loader)
FLAGS['BATCH_DATA']=FLAGS['BATCH_SIZE']*8 ## 8 cores on the TPU
# print(device)

Max Steps:- 430  , Each Step has 16 inputs


In [None]:
# Sanity check: show the token ids of the first validation example, then decode the
# first sequence of the first training batch to eyeball the rendered prompt.
print(val[0]['input_ids'])
c=0  # unused
for step, batch in enumerate(training_loader):

    print(step)
    print(tokenizer.decode(batch['input_ids'][0]))
    break  # only the first batch is inspected
# print(tokenizer.decode(val[0]['input_ids']))

In [13]:
def get_nb_trainable_parameters(model):
    r"""
    Count the parameters of ``model``.

    Returns a ``(trainable, total)`` tuple, where ``trainable`` only includes
    parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for _name, tensor in model.named_parameters():
        count = tensor.numel()

        # DeepSpeed ZeRO-3 initialises weights empty and keeps the real size in ds_numel.
        if count == 0 and hasattr(tensor, "ds_numel"):
            count = tensor.ds_numel

        # bitsandbytes 4-bit linear layers hold half as many elements as parameters,
        # so the element count is doubled to get the true parameter count.
        if type(tensor).__name__ == "Params4bit":
            count *= 2

        total += count
        trainable += count if tensor.requires_grad else 0

    return trainable, total
def print_trainable_parameters(model):
    """
    Print one summary line with the trainable count, the total count and the
    trainable percentage of ``model``'s parameters.
    """
    counts = get_nb_trainable_parameters(model)
    trainable_params = counts[0]
    all_param = counts[1]
    summary = f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
    print(summary)

In [None]:
print_trainable_parameters(model)

In [14]:
# Move the model to the XLA device, build the device mesh, shard the weights and wrap
# the data loaders so batches are delivered on the device.
# This cell rebinds `training_loader` / `testing_loader`; re-run the loader cell before
# re-running it.
from spmd_util import partition_module
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
model = model.to(device)
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear  
model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)  # patch linear layers to use einsum instead of matmul
num_devices = xr.global_runtime_device_count()
# With model_axis=1 every device sits on the 'fsdp' axis: mesh shape is (1, num_devices, 1).
model_axis=1
data_axis=num_devices//model_axis
mesh_shape = (1,data_axis, model_axis )
device_ids = np.array(range(num_devices))
# Axis names line up with strkey2id in spmd_util.py: dp=0, fsdp=1, mp=2.
mesh = Mesh(device_ids, mesh_shape, ('dp','fsdp', 'mp'))
partition_module(model, mesh)
training_loader = pl.MpDeviceLoader(training_loader, device)
testing_loader = pl.MpDeviceLoader(testing_loader, device)
mesh_shape

In [None]:
FLAGS

In [15]:
# NOTE(review): `!export` runs in a separate shell process, so this line presumably
# does not set XLA_USE_BF16 for the kernel -- confirm and use os.environ if it is needed.
!export XLA_USE_BF16=1
import torch.nn as nn
import wandb
# Module-level switch read by train() to decide whether to log to wandb.
__wandb__=FLAGS['WANDB']
from torch_xla.amp.syncfree import AdamW
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
# Optimizers from the repository cloned into ./optims by the install cell.
# NOTE(review): train() also references `Lion`, which is not imported here.
from optims.optim import SM3, CAME , Adafactor
# from random import randrange
# from bitsandbytes.optim import AdamW8bit 
# from torchdistx.optimizers import AnyPrecisionAdamW

# Running count of logged validation batches; reset and advanced inside train().
val_step=0




def evaluate_loss(outputs,labels,pad_id=tokenizer.pad_token_id):
    """Next-token negative log-likelihood with a tiny label-smoothing term, ignoring padding.

    Logits are shifted left and labels right so position ``t`` predicts token
    ``t + 1``. Targets equal to ``pad_id`` contribute nothing to either term.

    Args:
        outputs: model output exposing ``.logits`` (presumably (batch, seq, vocab) -- TODO confirm).
        labels: integer token ids aligned with the logits (presumably (batch, seq)).
        pad_id: label value treated as padding; defaults to the tokenizer's pad id
            captured when this function is defined.

    Returns:
        Scalar tensor ``(1 - eps) * nll + eps * smoothed`` with ``eps = 1e-8``.
        A batch whose targets are all padding yields 0 instead of NaN.
    """
    epsilon = 1e-8
    logits = outputs.logits[..., :-1, :].contiguous()
    labels = labels[..., 1:].contiguous()
    # Negated log-probabilities, so gathered / summed values are already losses.
    log_probs = -nn.functional.log_softmax(logits, dim=-1)
    if labels.dim() == log_probs.dim() - 1:
        labels = labels.unsqueeze(-1)
    padding_mask = labels.eq(pad_id)
    # gather() needs non-negative indices; clamped positions are masked out below.
    labels = torch.clamp(labels, min=0)
    nll_loss = log_probs.gather(dim=-1, index=labels)
    smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.bfloat16)
    nll_loss.masked_fill_(padding_mask, 0.0)
    smoothed_loss.masked_fill_(padding_mask, 0.0)
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    # Guard against an all-padding batch: dividing by zero would produce a NaN loss
    # and poison the gradients. clamp keeps this a tensor op (no host sync).
    num_active_elements = torch.clamp(num_active_elements, min=1)
    nll_loss = nll_loss.sum() / num_active_elements
    smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
    return (1-epsilon)*nll_loss + epsilon*smoothed_loss



def train(FLAGS):
    """Fine-tune the global ``model`` on ``training_loader`` with periodic validation.

    Uses the notebook globals ``model``, ``mesh``, ``device``, ``training_loader``,
    ``testing_loader`` and ``__wandb__``, and resets / advances the global
    ``val_step`` counter.

    Args:
        FLAGS: dict of run settings. Keys read here: NUM_EPOCHS, STEPS,
            GRAD_ACCUMULATION_STEP, WARMUP_RATIO, PROJECT, OPTIMIZER, LEARNING_RATE,
            WEIGHT_DECAY, SCHEDULAR, LOGGING_STEPS, MAX_GRAD_CLIP, VAL_STEPS,
            VAL_BATCH, PAUSE_STEPS.

    Returns:
        None. Progress is printed and, when ``__wandb__`` is truthy, logged to wandb.
    """
    global val_step

    ### Configuring Training
    # Materialise the trainable parameters as a list. A bare filter() iterator is
    # exhausted by the optimizer constructor, which left clip_grad_norm_ with
    # nothing to clip on every step.
    update_params = [p for p in model.parameters() if p.requires_grad]
    grad_accum_steps = FLAGS['GRAD_ACCUMULATION_STEP']
    num_iterations = int((FLAGS["NUM_EPOCHS"] * FLAGS['STEPS']) // grad_accum_steps)
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])

    if __wandb__:
        wandb.init(project=FLAGS['PROJECT'],config=FLAGS)
        wandb.define_metric("Validation_loss", step_metric="val_step")
        wandb.define_metric("Learning_rate",step_metric="train_step")
        wandb.define_metric("train_loss",step_metric="train_step")

    ### Optimizers
    optimizer_name = (FLAGS['OPTIMIZER']).lower()
    if optimizer_name == 'adamw':
        optimizer = AdamW(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])
    elif optimizer_name == 'lion':
        # NOTE(review): `Lion` is not imported in this notebook; selecting it raises
        # NameError until an import is added.
        optimizer = Lion(update_params, lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
    elif optimizer_name == 'adafactor':
        optimizer = Adafactor(update_params,lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],scale_parameter=False,relative_step=False)
    elif optimizer_name == 'came':
        # Pass only the trainable parameters, like every other branch; frozen
        # parameters never receive gradients, so they were never updated anyway.
        optimizer = CAME(update_params,lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],betas=(0.9, 0.999, 0.9999),eps=(1e-30, 1e-16))
    else:
        # Any other value falls back to SM3 (which is not given a weight decay).
        optimizer = SM3(update_params,lr=FLAGS['LEARNING_RATE'])

    # Show which device the optimized parameters live on.
    for param_group in optimizer.param_groups:
        if len(param_group["params"]) > 0:
            print(param_group["params"][0].device)
            break

    ### Schedulars
    if (FLAGS['SCHEDULAR']).lower()=='linear':
        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
    else:
        scheduler = get_cosine_schedule_with_warmup(optimizer,warmup_steps,num_iterations)

    ### Training Loop
    val_step = 0
    check = False  # becomes True when the user answers "no" at a pause prompt
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        if check:
            break
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Partition spec (0, 1): tensor dim 0 -> mesh axis 0 ('dp'), dim 1 -> mesh axis 1 ('fsdp').
            xs.mark_sharding(input_ids, mesh, (0,1))
            xs.mark_sharding(labels, mesh, (0,1))
            xs.mark_sharding(attention_mask, mesh, (0, 1))
            outputs = model(input_ids=input_ids,attention_mask=attention_mask)
            loss = evaluate_loss(outputs,labels)

            # Pull the loss to the host at most once per step, and only when it is reported.
            log_to_console = (step + 1) % FLAGS['LOGGING_STEPS'] == 0
            if log_to_console or __wandb__:
                loss_value = loss.detach().cpu().item()
            if log_to_console:
                xm.master_print(f'loss: {loss_value}, time: {test_utils.now()}, step: {step+1}')
            if __wandb__:
                wandb.log({
                    'Learning_rate': optimizer.param_groups[0]['lr'],
                    'train_loss': loss_value,
                    'train_step': step + 1 + ((epoch-1) * FLAGS["STEPS"]),
                })

            del input_ids, attention_mask
            # Average (not sum) gradients over the accumulation window so the update
            # size does not grow with GRAD_ACCUMULATION_STEP. The reported loss above
            # stays unscaled.
            if grad_accum_steps > 1:
                (loss / grad_accum_steps).backward()
            else:
                loss.backward()
            del outputs, loss

            if (step+1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(update_params, max_norm=FLAGS['MAX_GRAD_CLIP']*8)
                xm.optimizer_step(optimizer,pin_layout=True,barrier=True) ## performs xm.reduce_gradient() , optimizer.step() , xm.mark_step()
                # The scheduler must advance after the optimizer step; stepping it
                # first skips the first value of the learning-rate schedule.
                scheduler.step()
                optimizer.zero_grad()

            if (step+1) % FLAGS['VAL_STEPS'] == 0:
                end_index = FLAGS["VAL_BATCH"]
                model.eval()
                total_loss = 0
                total_step = 0
                with torch.no_grad():
                    for stepx, batchx in enumerate(testing_loader):
                        val_input_ids = batchx["input_ids"].to(device)
                        val_labels = batchx["labels"].to(device)
                        val_attention_mask = batchx["attention_mask"].to(device)
                        # Validation batches are sharded on the batch dim only, i.e. (0, None).
                        xs.mark_sharding(val_input_ids, mesh, (0, None))
                        xs.mark_sharding(val_labels, mesh, (0, None))
                        xs.mark_sharding(val_attention_mask, mesh, (0, None))
                        val_outputs = model(input_ids=val_input_ids,attention_mask=val_attention_mask)
                        batch_loss = evaluate_loss(val_outputs,val_labels).item()
                        total_loss += batch_loss
                        total_step += 1
                        xm.master_print('----- Time -> {} ----- Validation Batch -> {} ----  Validation Loss -> {:.4f}'.format(test_utils.now(), total_step , batch_loss))
                        if __wandb__:
                            val_step += 1
                            wandb.log({
                                'Validation_loss': batch_loss,
                                'val_step': val_step,
                            })
                        if (stepx+1) % end_index == 0:
                            break
                model.train()
                # An empty validation loader would otherwise raise ZeroDivisionError.
                average_loss = total_loss / total_step if total_step else float('nan')
                xm.master_print('----- Time -> {} ----- Validation Batch Size -> {} ----  Validation Loss -> {:.7f}'.format(test_utils.now(), total_step , average_loss))

            if (step+1) % FLAGS['PAUSE_STEPS'] == 0:
                # Interactive brake: any answer containing "no" stops training.
                inp = input('want to continue training after {} steps'.format(step+1))
                check = bool("no" in inp.lower())
                if check:
                    break
            
        
        
        
        
          

**12 Mins to Train on 4k**

In [None]:
# Run training with the settings in FLAGS, then close the wandb run if logging was enabled.
train(FLAGS)
if FLAGS['WANDB']:
    wandb.finish()

In [17]:
# Move the trained model back to host memory (needed before pushing it to the Hub)
# and report how long the transfer took.
import time
print('Loading the model on CPU')
START=time.time()
model = model.cpu()
print(f"Loaded model on cpu in {time.time()-START} seconds ")

Loading the model on CPU

Loaded model on cpu in 129.38400077819824 seconds 


In [None]:
import os
from huggingface_hub import login

# Read the write token from the HF_TOKEN environment variable instead of pasting it
# into the notebook, where it would be saved and shared with the file. When the
# variable is unset, token=None makes login() ask for the token interactively.
login(token=os.environ.get("HF_TOKEN"))

# Upload the trained model to SAVED_MODEL as a pull request, in safetensors shards
# of at most 3GB.
model.push_to_hub(
    SAVED_MODEL,
    tokenizer=tokenizer,
    safe_serialization=True,
    create_pr=True,
    max_shard_size="3GB",
    )
# Upload the tokenizer files to the same repository.
tokenizer.push_to_hub(
    SAVED_MODEL,

    )