In [None]:
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes lion-pytorch -q
!pip3 install accelerate datasets sentencepiece langchain torch_xla[tpuvm] -q
!pip uninstall tensorflow -y #that's the meme part

In [None]:
get_ipython().kernel.do_shutdown(True)
### for good measures restart kernel

**Tokens?**

In [None]:
!huggingface-cli login --token <hf_read_token> #for downloading gated models
# import wandb
# wandb.login()

**Sharding Module for different Architechture**

In [6]:
%%writefile spmd_util.py
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config, MixtralConfig, PhiConfig
)

# ends with $ to prevent sharding lora parameters


T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )


PHI_RULES = (
    ### (regex) linear modules, (list[sharding methods]) )
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.dense", ("mp", "fsdp")),
    ("mlp\\.fc2", ("mp", "fsdp")),  
    ("mlp\\.fc1", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    
)

LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )

GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # atention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)



MIXTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("w1", ("fsdp", "mp")),
    ("w2", ("mp", "fsdp")),
    ("w3", ("fsdp", "mp")),
    ("gate", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )


    
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES),
    (MixtralConfig, MIXTRAL_RULES),
    (PhiConfig,PHI_RULES),
]

def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception("unsupported model to partitioning")

strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]
        
    # print(rule)

    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    
                    xs.mark_sharding(module.weight, mesh, spec)
                    break
        
def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=True):
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)

Overwriting spmd_util.py


In [None]:
### TODO
"""
-Support PHI-2 // curent PR
-Support Gradient Accumulation
-Support Quantization
-Support for 8-bit optimizers and paged optimizers
-Support interference after sharding
"""

**Required Libs**

In [2]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, MistralConfig, Qwen2Config, GPT2Config, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
) # You can use any of models with those configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantazation requiers GPU though. You'll have to download already quantazed models
from spmd_util import partition_module                # You could experiment with using already quantazed models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from torch.utils.data import Dataset as TorchDataset
import torch.utils
try:
    !export USE_TORCH=True #If we don't do this, transformers will seemingly bork the session upon import. Really weird error.
    os.environ["PJRT_DEVICE"] = "TPU"
    os.environ.pop('TPU_PROCESS_ADDRESSES')
    os.environ.pop('CLOUD_TPU_TASK_ID')
    hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever
except:
    pass


  from .autonotebook import tqdm as notebook_tqdm


**Configuration**

In [None]:
MAX_INPUT=1024
MODEL = "g-ronimo/phi-2-OpenHermes-2.5" #You should be able to use 7B model with no changes! There should be enough HBM
SAVED_MODEL = "fhai50032/PHI-Coder"



In [9]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)
if 'pad_token' not in tokenizer.special_tokens_map:
  tokenizer.pad_token=tokenizer.eos_token


print(f"Tokens :\n {tokenizer.special_tokens_map} \n\n")

tokenizer_config.json: 100%|██████████| 8.20k/8.20k [00:00<00:00, 17.1MB/s]
vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 10.4MB/s]
merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 55.5MB/s]
tokenizer.json: 100%|██████████| 2.12M/2.12M [00:00<00:00, 26.9MB/s]
added_tokens.json: 100%|██████████| 1.15k/1.15k [00:00<00:00, 6.11MB/s]
special_tokens_map.json: 100%|██████████| 462/462 [00:00<00:00, 2.60MB/s]

Tokens :
 {'bos_token': '<|endoftext|>', 'eos_token': '<|im_end|>', 'unk_token': '<|endoftext|>', 'pad_token': '<PAD>'} 







In [10]:
class ConversationDataset(TorchDataset):
    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        messages = self.dataset[idx]["conversations"]
        text = ""
        for message in messages:
            role = message["from"]
            if role == "system":
                text += f"<|im_start|>system\n{message['value']}<|im_end|>\n"
            if role == "human":
                text += f"<|im_start|>user\n{message['value']}<|im_end|>\n"
            if role == "function-call":
                text += f"<|im_start|>call\n{message['value']}<|im_end|>\n"
            if role == "function-response":
                text += f"<|im_start|>function\n{message['value']}<|im_end|>\n"
            if role =="gpt":
                text += f"<|im_start|>assistant\n{message['value']}{self.tokenizer.eos_token}"
        input_ids = self.tokenizer(text, add_special_tokens=True, max_length=self.max_length, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")
        return {
            "input_ids": input_ids["input_ids"].squeeze(0),
            "labels": input_ids["input_ids"].squeeze(0),
            "mask":input_ids["attention_mask"].squeeze(0),
        }

In [11]:
train_dataset="fhai50032/magicoder-oss-instruct-sharegpt-75k"
test_dataset="fhai50032/magicoder-oss-instruct-sharegpt-75k"

train_data = load_dataset(train_dataset, split="train[0:2000]").shuffle(seed=69)
val = (load_dataset(test_dataset, split="train[:64]")).shuffle(seed=420)

Downloading readme: 100%|██████████| 645/645 [00:00<00:00, 4.59MB/s]
Downloading data: 100%|██████████| 44.9M/44.9M [00:03<00:00, 12.9MB/s]
Downloading data: 100%|██████████| 45.1M/45.1M [00:03<00:00, 14.0MB/s]
Generating train split: 100%|██████████| 75197/75197 [00:00<00:00, 77589.77 examples/s]


In [12]:
len(train_data)

2000

In [14]:
FLAGS = {'MAX_INPUT': MAX_INPUT,
         'LOGGING_STEPS': 1,
         'NUM_EPOCHS': 1,
         'BATCH_SIZE': 8, #Making batch_size lower then 8 will result in slower training, but will allow for larger models\context. Fortunately, we have 128GBs. Setting higher batch_size doesn't seem to improve time.
          'LEN_DATA': len(train_data),
         'VAL_PER_STEP': 10,
#         'GRAD_ACCUMULATION':2,
#          'MAX_GRAD_CLIP':1.0,
        'LEARNING_RATE':2e-5,
         'WARMUP_RATIO':0.1,
         'OPTIMIZER':'adamw', # default = 'adamw8bit'  options->  ['adamw','adamw8bit','adafactor','lion']           
         'SCHEDULAR':'linear', # default= 'cosine'     options:-> ['linear','cosine']
         'WEIGHT_DECAY':0.0,
         'TRAIN_DATASET':train_dataset,
         "TEST_DATASET":test_dataset,
         'WANDB':True,
        'PROJECT':'Silver-Pulse'} # Indian pun :) 

In [15]:
FLAGS

{'MAX_INPUT': 1024,
 'LOGGING_STEPS': 1,
 'NUM_EPOCHS': 1,
 'BATCH_SIZE': 8,
 'LEN_DATA': 2000,
 'VAL_PER_STEP': 10,
 'LEARNING_RATE': 2e-05,
 'WARMUP_RATIO': 0.1,
 'OPTIMIZER': 'adamw',
 'SCHEDULAR': 'linear',
 'WEIGHT_DECAY': 0.0,
 'TRAIN_DATASET': 'fhai50032/magicoder-oss-instruct-sharegpt-75k',
 'TEST_DATASET': 'fhai50032/magicoder-oss-instruct-sharegpt-75k',
 'WANDB': True,
 'PROJECT': 'Silver-Pulse'}

**Quantization When??**

In [None]:
# from transformers import BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     llm_int8_has_fp16_weight=False,
        
# )
# model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16,quantization_config=bnb_config,
#     trust_remote_code=True,
#     low_cpu_mem_usage=True) 

In [None]:
model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16) 
### use only bf16 or atleast set compute type to bf16 

**LoRA Applicable**

In [None]:
ls=LoraConfig(
    r = 32, # Lora Rank ,I would prefer 8-32 for smaller models like 7b
    target_modules = ['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 48, #weight_scaling
    lora_dropout = 0.05, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimize
    # modules_to_save = ["lm_head", "embed_tokens"] ## if you use new chat formats or embedding tokens
)
model = get_peft_model(model, ls)
model.print_trainable_parameters()

**Data-Distributer**

In [16]:
train_data = ConversationDataset(tokenizer, dataset=train_data, max_length=1024)
val = ConversationDataset(tokenizer, dataset=val)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

print(f"Max Steps:- {len(training_loader)}  , Each Step has {8*FLAGS['BATCH_SIZE']} lines")

FLAGS['STEPS']=len(training_loader)
FLAGS['BATCH_DATA']=FLAGS['BATCH_SIZE']*8 ## 8 CORES ON TPU 
# print(device)

Max Steps:- 32  , Each Step has 64 lines


In [17]:
def get_nb_trainable_parameters(model):
        r"""
        Returns the number of trainable parameters and number of all parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param
def print_trainable_parameters(model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params, all_param = get_nb_trainable_parameters(model)
        
        print(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
        )

In [18]:
print_trainable_parameters(model)

trainable params: 2,779,683,840 || all params: 2,779,683,840 || trainable%: 100.0


In [None]:
config = AutoConfig.from_pretrained(MODEL)
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
partition_module(model, mesh) # After this, the model is sharded between cores but still has the same API as if it was on single device. Neat.

In [19]:
FLAGS

{'MAX_INPUT': 1024,
 'LOGGING_STEPS': 1,
 'NUM_EPOCHS': 1,
 'BATCH_SIZE': 8,
 'LEN_DATA': 2000,
 'VAL_PER_STEP': 10,
 'LEARNING_RATE': 2e-05,
 'WARMUP_RATIO': 0.1,
 'OPTIMIZER': 'adamw',
 'SCHEDULAR': 'linear',
 'WEIGHT_DECAY': 0.0,
 'TRAIN_DATASET': 'fhai50032/magicoder-oss-instruct-sharegpt-75k',
 'TEST_DATASET': 'fhai50032/magicoder-oss-instruct-sharegpt-75k',
 'WANDB': True,
 'PROJECT': 'Silver-Pulse',
 'STEPS': 32,
 'BATCH_DATA': 64}

In [24]:
!export XLA_USE_BF16=1
import wandb
from transformers import AdamW,Adafactor
from lion_pytorch import Lion #optimizer best used for large batch size ~ 4096+
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
# from bitsandbytes.optim import AdamW8bit 



def train(FLAGS):

    
    ### Configuring Training
    device = xm.xla_device()
    update_params= filter(lambda p: p.requires_grad, model.parameters())
    num_iterations = FLAGS["NUM_EPOCHS"] * FLAGS['STEPS']  #    // FLAGS['GRAD_ACCUMULATION'])
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])
    __wandb__=FLAGS['WANDB']
    
    if __wandb__:
        wandb.init(project=FLAGS['PROJECT'],config=FLAGS)
    
    ### Optimizers
    
    if (FLAGS['OPTIMIZER']).lower()=='adamw':
        optimizer = AdamW(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'],no_deprecation_warning=True)
    elif (FLAGS['OPTIMIZER']).lower()=='lion':
        optimizer = Lion(update_params, lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='adafactor':
        optimizer = Adafactor(update_params,lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],scale_parameter=True,relative_step=False)
    else:
#         optimizer = AdamW8bit(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])
        pass

    
    
    ### Schedulars
    
    if (FLAGS['SCHEDULAR']).lower()=='linear':
        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
    else:
        schedular = get_cosine_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
        
        
    
    
    ### Training Loop
    
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            input_ids, labels,mask = batch["input_ids"].to(device),  batch["labels"].to(device),batch['mask'].to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1))
            xs.mark_sharding( labels,   mesh, (0, 1))
            xs.mark_sharding(  mask,    mesh, (0, 1))
            outputs = model(input_ids=input_ids, labels=labels,attention_mask=mask)
            
            
            
            loss = outputs.loss
            
#           loss = loss / (FLAGS['GRAD_ACCUMULATION'] + scheduler.get_last_lr()[0]) # my touch for grad_norm



            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')
            if __wandb__:
                wandb.log({
                'Learning_rate': optimizer.param_groups[0]['lr'],
                'train_loss': loss.item(),
                'train_step': step + 1 + ((epoch-1) * FLAGS["STEPS"]),
                        })
            
            loss.requires_grad_()
            loss.backward()
            optimizer.step()
            scheduler.step()
            xm.mark_step()
            optimizer.zero_grad()
            
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
            
            
        
        
        
        
        ### Validation - Test
        model.eval()
        total_loss = 0.0
        total_steps = 0
        with torch.no_grad():
            for step, batch in enumerate(testing_loader):
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
                total_loss += loss.item()
                if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                    print(f'Testing Loss : {loss.item()}, time: {test_utils.now()}, step: {step+1}')
                total_steps += 1

        average_loss = total_loss / total_steps
        xm.master_print('----- End Time -> {} ----- Total Validation Steps -> {} ----  Total Validation Loss -> {:.4f}'.format(test_utils.now(), total_steps , average_loss))
#         

**12 Mins to Train on 4k**

In [25]:
train(FLAGS)
if FLAGS['WANDB']:
    wandb.finish()



Epoch 1 train begin 18:28:04
loss: 6.487453460693359, time: 18:28:26, step: 1
loss: 6.957941055297852, time: 18:30:38, step: 2
loss: 5.957518100738525, time: 18:32:12, step: 3
loss: 5.705552577972412, time: 18:32:14, step: 4
loss: 5.225892543792725, time: 18:32:16, step: 5
loss: 5.756276607513428, time: 18:32:18, step: 6
loss: 4.587933540344238, time: 18:32:20, step: 7
loss: 4.767673015594482, time: 18:32:22, step: 8
loss: 5.576902866363525, time: 18:32:24, step: 9
loss: 4.731902599334717, time: 18:32:26, step: 10
loss: 4.955680847167969, time: 18:32:28, step: 11
loss: 5.193166255950928, time: 18:32:30, step: 12
loss: 4.506847381591797, time: 18:32:32, step: 13
loss: 5.372564315795898, time: 18:32:33, step: 14
loss: 4.4137349128723145, time: 18:32:35, step: 15
loss: 4.066483020782471, time: 18:32:37, step: 16
loss: 4.605711460113525, time: 18:32:39, step: 17
loss: 4.2736968994140625, time: 18:32:41, step: 18
loss: 4.799393177032471, time: 18:32:43, step: 19
loss: 4.200766563415527, tim

0,1
Learning_rate,▁▃▆███▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁
train_loss,▇█▆▅▄▆▃▃▅▃▄▄▃▅▃▂▃▂▄▂▃▂▃▁▂▂▂▃▁▂▃▂
train_step,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇███

0,1
Learning_rate,0.0
train_loss,4.06651
train_step,64.0


In [None]:
print('Loading the model on CPU')
START=time.time()
model = model.cpu()
print(f"Loaded model on cpu in {time.time()-START} seconds ")

In [None]:
from huggingface_hub import login
login("<HF_WRITE_TOKEN>") ##"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"
model.push_to_hub(
    SAVED_MODEL, 
    tokenizer=tokenizer,
    safe_serialization=True,
    private=True,
    create_pr=True,
    max_shard_size="3GB", 
    )
tokenizer.push_to_hub(
    SAVED_MODEL,
    private=True, 
    
    )