In [None]:
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes lion-pytorch -q
!pip3 install accelerate datasets sentencepiece langchain torch_xla[tpuvm] -q
!pip uninstall tensorflow -y #that's the meme part

In [None]:
get_ipython().kernel.do_shutdown(True)
### Restart the kernel for good measure

**Tokens?**

In [None]:
!huggingface-cli login --token <hf_read_token> #for downloading gated models
# import wandb
# wandb.login()

**Sharding Module for Different Architectures**

In [1]:
%%writefile spmd_util.py
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config, MixtralConfig, PhiConfig,GemmaConfig
)

# A pattern that ends with $ matches only the named module itself, which prevents sharding LoRA parameters nested under it


T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )


PHI_RULES = (
    ### (regex) linear modules, (list[sharding methods]) )
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.dense", ("mp", "fsdp")),
    ("mlp\\.fc2", ("mp", "fsdp")),  
    ("mlp\\.fc1", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    
)

LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )

GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)



MIXTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("w1", ("fsdp", "mp")),
    ("w2", ("mp", "fsdp")),
    ("w3", ("fsdp", "mp")),
    ("gate", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )

GEMMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
    
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES),
    (MixtralConfig, MIXTRAL_RULES),
    (PhiConfig,PHI_RULES),
    (GemmaConfig,GEMMA_RULES),
]

def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception(f"unsupported model for partitioning: {model.config.__class__.__name__}")

strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]
        
    # print(rule)

    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    
                    xs.mark_sharding(module.weight, mesh, spec)
                    break
        
def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=True):
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)

Overwriting spmd_util.py
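
To see why the `$` anchor matters once LoRA wraps the linear layers, here is a minimal standalone sketch of the first-match lookup that `partition_module` performs on module names. The module names are hypothetical examples modelled on Hugging Face/PEFT naming, and `first_match` is an illustrative helper that is not part of `spmd_util.py`.

```python
import re

# Hypothetical module names (assumption: Llama-style naming with a LoRA adapter attached).
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.q_proj.lora_A.default",
    "model.layers.0.mlp.down_proj",
]

# Two rules in the same (pattern, spec) form used by the rule tables above.
unanchored = (
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
)
# The same rules with "$" appended, so the pattern must match at the end of the name.
anchored = tuple((pattern + "$", spec) for pattern, spec in unanchored)


def first_match(name, rules):
    """Return the spec of the first rule whose pattern occurs in the name, else None."""
    for pattern, spec in rules:
        if re.search(pattern, name):
            return spec
    return None


unanchored_hits = {name: first_match(name, unanchored) for name in names}
anchored_hits = {name: first_match(name, anchored) for name in names}

for name in names:
    print(name, "| unanchored:", unanchored_hits[name], "| anchored:", anchored_hits[name])
```

With the unanchored rules the adapter submodule picks up the same spec as its parent projection; with the anchored rules it is left alone.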


**Required Libs**

In [2]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, MistralConfig, Qwen2Config, GPT2Config, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
) # You can use any model that has one of these configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantization requires a GPU though. You'll have to download already quantized models
from spmd_util import partition_module                # You could experiment with using already quantized models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from torch.utils.data import Dataset as TorchDataset
import torch.utils
from torch_xla.utils.checkpoint import checkpoint
os.environ["USE_TORCH"] = "True" # `!export` only affects a subshell, so set the variable in this process. If we don't do this, transformers will seemingly bork the session upon import. Really weird error.
os.environ["PJRT_DEVICE"] = "TPU"
# pop() with a default never raises, so a missing variable cannot skip the lines after it
os.environ.pop('TPU_PROCESS_ADDRESSES', None)
os.environ.pop('CLOUD_TPU_TASK_ID', None)
hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever


  from .autonotebook import tqdm as notebook_tqdm


**Configuration**

In [3]:
MAX_INPUT=1024
MODEL = "abideen/gemma-7b-openhermes" #You should be able to use a 7B model with no changes! There should be enough HBM
SAVED_MODEL = "fhai50032/Gemma-Unaligned"


In [4]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)
if 'pad_token' not in tokenizer.special_tokens_map:
  tokenizer.pad_token=tokenizer.eos_token


print(f"Tokens :\n {tokenizer.special_tokens_map} \n\n")

Tokens :
 {'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']} 




In [5]:
class ConversationDataset(TorchDataset):
    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        messages = self.dataset[idx]["QAs"]
        text = ""
        for message in messages:
            role = message["from"]
            if role == "system":
                text += f"<|im_start|>system\n{message['value']}<|im_end|>\n"
            if role in ["human","user"]:
                text += f"<|im_start|>user\n{message['value']}<|im_end|>\n"
            if role == "function-call":
                text += f"<|im_start|>call\n{message['value']}<|im_end|>\n"
            if role == "function-response":
                text += f"<|im_start|>function\n{message['value']}<|im_end|>\n"
            if role in ["gpt","assistant"]:
                text += f"<|im_start|>assistant\n{message['value']}{self.tokenizer.eos_token}"
        input_ids = self.tokenizer(text, add_special_tokens=True, max_length=self.max_length, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")
        return {
            "input_ids": input_ids["input_ids"].squeeze(0),
            "labels": input_ids["input_ids"].squeeze(0),
            "attention_mask":input_ids["attention_mask"].squeeze(0),
        }
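
The prompt layout built in `__getitem__` is easier to inspect without a tokenizer. This sketch reproduces the same role-to-tag mapping as a plain function; `build_prompt` and the sample conversation are illustrative, and `"<eos>"` stands in for `tokenizer.eos_token`.

```python
# Role names on the left come from the dataset; tags on the right are what the prompt uses.
ROLE_TAGS = {
    "system": "system",
    "human": "user",
    "user": "user",
    "function-call": "call",
    "function-response": "function",
}
ASSISTANT_ROLES = {"gpt", "assistant"}


def build_prompt(messages, eos_token="<eos>"):
    """Concatenate messages in the ChatML-style layout used by ConversationDataset."""
    text = ""
    for message in messages:
        role = message["from"]
        if role in ROLE_TAGS:
            text += f"<|im_start|>{ROLE_TAGS[role]}\n{message['value']}<|im_end|>\n"
        elif role in ASSISTANT_ROLES:
            # As in the class above, assistant turns end with the EOS token, not <|im_end|>.
            text += f"<|im_start|>assistant\n{message['value']}{eos_token}"
    return text


# Made-up conversation, for illustration only.
sample = [
    {"from": "system", "value": "Be brief."},
    {"from": "human", "value": "Hi"},
    {"from": "gpt", "value": "Hello!"},
]
prompt = build_prompt(sample)
print(prompt)
```

Printing a prompt like this before tokenizing is a quick way to check that every role in the dataset is covered; unknown roles are silently dropped here, just as in the class.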

In [6]:
train_dataset="NobodyExistsOnTheInternet/ToxicQAFinal"
test_dataset="NobodyExistsOnTheInternet/ToxicQAFinal"

train_data = load_dataset(train_dataset, split="train").shuffle(seed=69)
val = (load_dataset(test_dataset, split="train[:640]")).shuffle(seed=420)

In [7]:
len(train_data)

6866

In [8]:
FLAGS = {'MAX_INPUT': MAX_INPUT,
         'LOGGING_STEPS': 1,
         'NUM_EPOCHS': 1,
         'PAUSE_STEPS':100,
         'MAX_STEPS': -1, # Overrides NUM_EPOCHS
         'BATCH_SIZE': 2, # Making batch_size lower than 8 will result in slower training, but will allow for larger models/context. Fortunately, we have 128 GB. Setting a higher batch_size doesn't seem to improve time.
          'LEN_TRAIN_DATA': len(train_data),
         'VAL_STEPS': 20,
         'VAL_BATCH': 4,
#         'GRAD_ACCUMULATION':2,
#          'MAX_GRAD_CLIP':1.0,
        'LEARNING_RATE':2e-5,
         'WARMUP_RATIO':0.1,
         'OPTIMIZER':'adamw', # default = 'adamw'  options->  ['adamw','adamw8bit','adafactor','lion']           
         'SCHEDULAR':'linear', # default= 'cosine'     options:-> ['linear','cosine']
         'WEIGHT_DECAY':0.01,
         'TRAIN_DATASET':train_dataset,
         "TEST_DATASET":test_dataset,
         'WANDB':True,
        'PROJECT':'Xlake-Coder',
        } # Indian pun :) 

In [9]:
FLAGS

{'MAX_INPUT': 1024,
 'LOGGING_STEPS': 1,
 'NUM_EPOCHS': 1,
 'PAUSE_STEPS': 100,
 'MAX_STEPS': -1,
 'BATCH_SIZE': 2,
 'LEN_TRAIN_DATA': 6866,
 'VAL_STEPS': 20,
 'VAL_BATCH': 4,
 'LEARNING_RATE': 2e-05,
 'WARMUP_RATIO': 0.1,
 'OPTIMIZER': 'adamw',
 'SCHEDULAR': 'linear',
 'WEIGHT_DECAY': 0.01,
 'TRAIN_DATASET': 'NobodyExistsOnTheInternet/ToxicQAFinal',
 'TEST_DATASET': 'NobodyExistsOnTheInternet/ToxicQAFinal',
 'WANDB': True,
 'PROJECT': 'Xlake-Coder'}

**Quantization When??**

In [None]:
# from transformers import BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     llm_int8_has_fp16_weight=False,
        
# )
# model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16,quantization_config=bnb_config,
#     trust_remote_code=True,
#     low_cpu_mem_usage=True) 

In [10]:
model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16) 
model._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint)
### Use only bf16, or at least set the compute dtype to bf16

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.08s/it]


**LoRA Applicable**

In [11]:
ls=LoraConfig(
    r = 48, # LoRA rank; I would prefer 8-32 for smaller models like 7B
    target_modules = ['q_proj', 'down_proj', 'up_proj', 'o_proj', 'v_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 16, # weight scaling
    lora_dropout = 0.05, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # modules_to_save = ["lm_head", "embed_tokens"] ## if you use new chat formats or embedding tokens
)
model = get_peft_model(model, ls)
model.print_trainable_parameters()

  warn("The installed version of bitsandbytes was compiled without GPU support. "


/usr/local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
trainable params: 150,011,904 || all params: 8,687,692,800 || trainable%: 1.7267174087923551
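
The trainable-parameter count printed above can be reproduced by hand, which is a useful sanity check that LoRA attached to every intended projection. The sketch below assumes the usual Gemma-7B dimensions (hidden size 3072, 16 heads of dimension 256, intermediate size 24576, 28 layers); these are not stated in the notebook, but they do reproduce the printed number. `lora_params` is an illustrative helper.

```python
# Assumed Gemma-7B dimensions (not stated in the notebook).
HIDDEN = 3072
ATTN_DIM = 16 * 256      # 16 heads x head_dim 256 = 4096
INTERMEDIATE = 24576
NUM_LAYERS = 28

RANK = 48                # r from LoraConfig
LORA_ALPHA = 16          # lora_alpha from LoraConfig

# (in_features, out_features) of each targeted projection in one decoder layer.
targets = {
    "q_proj": (HIDDEN, ATTN_DIM),
    "k_proj": (HIDDEN, ATTN_DIM),
    "v_proj": (HIDDEN, ATTN_DIM),
    "o_proj": (ATTN_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank) per wrapped layer."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in targets.values())
trainable = per_layer * NUM_LAYERS
all_params = 8_687_692_800            # "all params" as printed above
percent = 100 * trainable / all_params
scaling = LORA_ALPHA / RANK           # default LoRA scaling factor alpha / r

print(f"per layer: {per_layer:,d}  total: {trainable:,d}  trainable%: {percent}")
print(f"LoRA scaling (alpha / r): {scaling:.4f}")
```

Note that with `r = 48` and `lora_alpha = 16` the adapter update is scaled by 1/3, so raising the rank without raising alpha also shrinks the effective update.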


**Data Distributor**

In [12]:
train_data = ConversationDataset(tokenizer, dataset=train_data, max_length=1024)
val = ConversationDataset(tokenizer, dataset=val)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

print(f"Max Steps:- {len(training_loader)}  , Each Step has {8*FLAGS['BATCH_SIZE']} inputs")

FLAGS['STEPS']=len(training_loader)
FLAGS['BATCH_DATA']=FLAGS['BATCH_SIZE']*8 ## 8 CORES ON TPU 
# print(device)

Max Steps:- 430  , Each Step has 16 inputs
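
The step count above follows from the sampler and batch settings, and the warmup length follows from it in turn. The sketch below redoes that arithmetic in plain Python and evaluates a linear warmup/decay learning rate at a few steps. `linear_lr` is an illustrative helper written to mirror a linear schedule with warmup as I understand it; it is not the `transformers` implementation itself.

```python
import math

LEN_TRAIN_DATA = 6866
NUM_REPLICAS = 8          # num_replicas passed to DistributedSampler
BATCH_SIZE = 2
NUM_EPOCHS = 1
WARMUP_RATIO = 0.1
LEARNING_RATE = 2e-5

# Each replica gets ceil(N / replicas) samples (the sampler pads rather than drops).
samples_per_replica = math.ceil(LEN_TRAIN_DATA / NUM_REPLICAS)
# The DataLoader keeps the last partial batch, so round up again.
steps = math.ceil(samples_per_replica / BATCH_SIZE)
num_iterations = NUM_EPOCHS * steps
warmup_steps = int(num_iterations * WARMUP_RATIO)


def linear_lr(step):
    """Linear ramp from 0 to the peak over warmup_steps, then linear decay to 0."""
    if step < warmup_steps:
        return LEARNING_RATE * step / max(1, warmup_steps)
    remaining = num_iterations - step
    return LEARNING_RATE * max(0.0, remaining / max(1, num_iterations - warmup_steps))


print("samples per replica:", samples_per_replica)
print("steps per epoch:", steps)
print("warmup steps:", warmup_steps)
for step in (0, warmup_steps, 400, num_iterations):
    print(f"lr at step {step}: {linear_lr(step):.3e}")
```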


In [None]:
print(val[0]['input_ids'])
for i in testing_loader:
    print(i['input_ids'])
    break
print(tokenizer.decode(val[0]['input_ids']))

In [13]:
def get_nb_trainable_parameters(model):
        r"""
        Returns the number of trainable parameters and number of all parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param
def print_trainable_parameters(model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params, all_param = get_nb_trainable_parameters(model)
        
        print(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
        )

In [None]:
print_trainable_parameters(model)

In [14]:
config = AutoConfig.from_pretrained(MODEL)
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
partition_module(model, mesh) # After this, the model is sharded between cores but still has the same API as if it was on single device. Neat.
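
It can help to work out what each device actually holds after sharding. In a partition spec, entry `i` names the mesh axis that tensor dimension `i` is split over, so the per-device shape is the tensor shape divided by the matching mesh axis sizes. The sketch below does that arithmetic without touching XLA; `shard_shape` is an illustrative helper, the weight shapes are assumed Gemma-7B projection shapes, and the sketch simply rejects uneven splits rather than modelling how XLA handles them.

```python
MESH_SHAPE = (1, 8, 1)                     # sizes of the ('dp', 'fsdp', 'mp') axes on 8 devices
AXIS_ID = {"dp": 0, "fsdp": 1, "mp": 2}    # same mapping as strkey2id in spmd_util.py


def shard_shape(tensor_shape, partition_spec):
    """Per-device shape when tensor dim i is split over mesh axis partition_spec[i].

    A None entry means the dimension is replicated (not split).
    """
    shape = []
    for size, axis in zip(tensor_shape, partition_spec):
        parts = 1 if axis is None else MESH_SHAPE[axis]
        if size % parts:
            raise ValueError(f"dimension {size} does not divide evenly into {parts} parts")
        shape.append(size // parts)
    return tuple(shape)


# Weight shapes are (out_features, in_features); the values are assumptions for Gemma-7B.
q_proj_spec = tuple(AXIS_ID[a] for a in ("fsdp", "mp"))
o_proj_spec = tuple(AXIS_ID[a] for a in ("mp", "fsdp"))
q_shard = shard_shape((4096, 3072), q_proj_spec)
o_shard = shard_shape((3072, 4096), o_proj_spec)

# A (batch, sequence) input marked with spec (0, 1), as in the training loop.
batch_shard = shard_shape((2, 1024), (0, 1))

print("q_proj shard:", q_shard)
print("o_proj shard:", o_shard)
print("input batch shard:", batch_shard)
```

With a `(1, 8, 1)` mesh only the `fsdp` axis has more than one device, so every spec effectively splits exactly one dimension eight ways; for the input batch that is the sequence dimension, not the batch dimension.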

In [None]:
FLAGS

In [15]:
!export XLA_USE_BF16=1 # Note: `!export` runs in a subshell, so this does not change the kernel's environment
import torch.nn as nn
import wandb
__wandb__=FLAGS['WANDB']
from random import randrange
from transformers import AdamW,Adafactor
from lion_pytorch import Lion #LION best used for large batch size ~ 4096+ similar convergence as adam but faster
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
# from bitsandbytes.optim import AdamW8bit 
val_step=0
device = xm.xla_device()


def evaluate_loss(outputs,labels,pad_id=tokenizer.pad_token_id):
  epsilon=1e-8
  logits=outputs.logits
  logits = logits[..., :-1, :].contiguous()
  labels = labels[..., 1:].contiguous()
  log_probs = -nn.functional.log_softmax(logits, dim=-1)
  if labels.dim() == log_probs.dim() - 1:
    labels = labels.unsqueeze(-1)
  padding_mask = labels.eq(pad_id)
  labels = torch.clamp(labels, min=0)
  nll_loss = log_probs.gather(dim=-1, index=labels)
  smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
  nll_loss.masked_fill_(padding_mask, 0.0)
  smoothed_loss.masked_fill_(padding_mask, 0.0)
  num_active_elements = padding_mask.numel() - padding_mask.long().sum()
  nll_loss = nll_loss.sum() / num_active_elements
  smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
  return (1-epsilon)*nll_loss + epsilon*smoothed_loss



def train(FLAGS):

    
    ### Configuring Training
    global val_step
    update_params= filter(lambda p: p.requires_grad, model.parameters())
    num_iterations = FLAGS["NUM_EPOCHS"] * FLAGS['STEPS']  #    // FLAGS['GRAD_ACCUMULATION'])
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])
    
    if __wandb__:
        wandb.init(project=FLAGS['PROJECT'],config=FLAGS)
        wandb.define_metric("Validation_loss", step_metric="val_step")
        wandb.define_metric("Learning_rate",step_metric="train_step")
        wandb.define_metric("train_loss",step_metric="train_step")
    
    ### Optimizers
    
    if (FLAGS['OPTIMIZER']).lower()=='adamw':
        optimizer = AdamW(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'],no_deprecation_warning=True)
    elif (FLAGS['OPTIMIZER']).lower()=='lion':
        optimizer = Lion(update_params, lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='adafactor':
        optimizer = Adafactor(update_params,lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],scale_parameter=True,relative_step=False)
    else:
#         optimizer = AdamW8bit(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])
        optimizer = AdamW(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'],no_deprecation_warning=True)

    for param_group in optimizer.param_groups:
        if len(param_group["params"]) > 0:
            print(param_group["params"][0].device)
            break
    
    
    ### Schedulers
    
    if (FLAGS['SCHEDULAR']).lower()=='linear':
        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
    else:
        scheduler = get_cosine_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
        
        
    
    
    ### Training Loop
    val_step=0
    check=False # set to True to break out of training
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        if check:
            break
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            
            input_ids, labels,attention_mask = batch["input_ids"].to(device),  batch["labels"].to(device),batch['attention_mask'].to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1))  ### dim 0 (batch) -> mesh axis 0 ('dp'), dim 1 (sequence) -> mesh axis 1 ('fsdp'); according to pytorch-xla, inputs should be sharded across ('data', None)
            xs.mark_sharding( labels,   mesh, (0, 1))  ###
            xs.mark_sharding(  attention_mask,    mesh, (0, 1))###
            outputs = model(input_ids=input_ids,attention_mask=attention_mask)
            loss = evaluate_loss(outputs,labels)
            
#           loss = loss / (FLAGS['GRAD_ACCUMULATION'] + scheduler.get_last_lr()[0]) # my touch for grad_norm



            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                xm.master_print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')
            if __wandb__:
                wandb.log({
                'Learning_rate': optimizer.param_groups[0]['lr'],
                'train_loss': loss.item(),
                'train_step': step + 1 + ((epoch-1) * FLAGS["STEPS"]),
                        })
            del input_ids , attention_mask 
            loss.backward()
            optimizer.step()
            scheduler.step()
            xm.mark_step()
            optimizer.zero_grad()
            del loss 
                            
            if (step+1)% FLAGS['VAL_STEPS'] == 0:
                end_index=FLAGS["VAL_BATCH"]
                with torch.no_grad():
                    total_loss = 0
                    total_step = 0
                    for stepx, batchx in enumerate(testing_loader):
                        input_ids = batchx["input_ids"].to(device)
                        labels = batchx["labels"].to(device)
                        attention_mask = batchx["attention_mask"].to(device)
                        xs.mark_sharding(input_ids, mesh, (0, 1))
                        xs.mark_sharding(labels, mesh, (0, 1))
                        xs.mark_sharding( attention_mask,    mesh, (0, 1))
                        outputs = model(input_ids=input_ids,attention_mask=attention_mask)
                        loss = evaluate_loss(outputs,labels)
                        total_loss += loss.item()
                        total_step +=1
                        xm.master_print('----- Time -> {} ----- Validation Batch -> {} ----  Validation Loss -> {:.4f}'.format(test_utils.now(), total_step , loss.item()))
                        if __wandb__:
                            val_step+=1
                            wandb.log({
                                'Validation_loss': loss.item(),
                                'val_step':val_step,
                                    })
                        if (stepx+1)%end_index==0:
                            break
                        
                    average_loss=total_loss/total_step
                    xm.master_print('----- Time -> {} ----- Validation Batch Size -> {} ----  Validation Loss -> {:.7f}'.format(test_utils.now(), total_step , average_loss))

            if (step+1)% FLAGS['PAUSE_STEPS']==0:
                inp=input('want to continue training after {} steps'.format(step+1))
                check = bool("no" in inp.lower())
                if check:
                    break
                else:
                    pass
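
The core of `evaluate_loss` is a next-token negative log-likelihood that skips padded targets. The dependency-free sketch below shows that part on a single toy sequence; it leaves out the tiny `epsilon` smoothing term, and `masked_next_token_nll` and the toy inputs are illustrative only.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def masked_next_token_nll(logits, labels, pad_id):
    """Mean NLL of token t+1 given the logits at position t, ignoring pad targets.

    logits: list of T rows, each with V scores; labels: list of T token ids.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):      # the last position has no next token
        target = labels[t + 1]            # shift: position t predicts token t+1
        if target == pad_id:
            continue                      # padded targets do not contribute
        total += -log_softmax(logits[t])[target]
        count += 1
    return total / count


# Toy example: vocabulary of 4, uniform logits, token 0 acts as padding.
logits = [[0.0] * 4 for _ in range(5)]
labels = [2, 1, 3, 0, 0]
loss = masked_next_token_nll(logits, labels, pad_id=0)
print(loss, math.log(4))
```

With uniform logits every active target costs `log(4)`, so the mean matches `log(4)` regardless of how many positions are padded, which is a handy check that the mask and the divisor agree.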
            
        
        
        
        
          

**12 Mins to Train on 4k**

In [16]:
train(FLAGS)
if FLAGS['WANDB']:
    wandb.finish()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


xla:0
Epoch 1 train begin 08:45:48
loss: 3.703125238418579, time: 08:46:33, step: 1
loss: 3.781250238418579, time: 08:48:48, step: 2
loss: 3.750000238418579, time: 08:50:31, step: 3
loss: 3.250000238418579, time: 08:50:32, step: 4
loss: 3.343750238418579, time: 08:50:33, step: 5
loss: 3.218750238418579, time: 08:50:35, step: 6
loss: 3.593750238418579, time: 08:50:36, step: 7
loss: 3.625000238418579, time: 08:50:38, step: 8
loss: 3.328125238418579, time: 08:50:39, step: 9
loss: 3.281250238418579, time: 08:50:41, step: 10
loss: 3.500000238418579, time: 08:50:42, step: 11
loss: 3.046875238418579, time: 08:50:44, step: 12
loss: 3.296875238418579, time: 08:50:45, step: 13
loss: 3.421875238418579, time: 08:50:47, step: 14
loss: 2.937500238418579, time: 08:50:48, step: 15
loss: 2.968750238418579, time: 08:50:50, step: 16
loss: 2.859375238418579, time: 08:50:51, step: 17
loss: 3.156250238418579, time: 08:50:53, step: 18
loss: 3.343750238418579, time: 08:50:54, step: 19
loss: 3.000000238418579,

want to continue training after 100 steps yes


loss: 1.562500238418579, time: 08:59:11, step: 101
loss: 1.710937738418579, time: 08:59:15, step: 102
loss: 1.4843753576278687, time: 08:59:19, step: 103
loss: 1.6562503576278687, time: 08:59:20, step: 104
loss: 1.648437738418579, time: 08:59:22, step: 105
loss: 1.562500238418579, time: 08:59:23, step: 106
loss: 1.625000238418579, time: 08:59:25, step: 107
loss: 1.476562738418579, time: 08:59:26, step: 108
loss: 1.523437738418579, time: 08:59:28, step: 109
loss: 1.484375238418579, time: 08:59:29, step: 110
loss: 1.585937738418579, time: 08:59:31, step: 111
loss: 1.6328128576278687, time: 08:59:32, step: 112
loss: 1.6171878576278687, time: 08:59:34, step: 113
loss: 1.6328128576278687, time: 08:59:35, step: 114
loss: 1.632812738418579, time: 08:59:37, step: 115
loss: 1.5156253576278687, time: 08:59:38, step: 116
loss: 1.6718753576278687, time: 08:59:40, step: 117
loss: 1.4687503576278687, time: 08:59:41, step: 118
loss: 1.5625003576278687, time: 08:59:43, step: 119
loss: 1.53906273841857

want to continue training after 200 steps yes


loss: 1.492187738418579, time: 09:02:38, step: 201
loss: 1.429687738418579, time: 09:02:42, step: 202
loss: 1.476562738418579, time: 09:02:46, step: 203
loss: 1.468750238418579, time: 09:02:48, step: 204
loss: 1.468750238418579, time: 09:02:49, step: 205
loss: 1.390625238418579, time: 09:02:50, step: 206
loss: 1.453125238418579, time: 09:02:52, step: 207
loss: 1.359375238418579, time: 09:02:54, step: 208
loss: 1.351562738418579, time: 09:02:55, step: 209
loss: 1.390625238418579, time: 09:02:56, step: 210
loss: 1.476562738418579, time: 09:02:58, step: 211
loss: 1.3984378576278687, time: 09:03:00, step: 212
loss: 1.2578128576278687, time: 09:03:01, step: 213
loss: 1.4296878576278687, time: 09:03:02, step: 214
loss: 1.421875238418579, time: 09:03:04, step: 215
loss: 1.406250238418579, time: 09:03:06, step: 216
loss: 1.617187738418579, time: 09:03:07, step: 217
loss: 1.546875238418579, time: 09:03:09, step: 218
loss: 1.445312738418579, time: 09:03:10, step: 219
loss: 1.453125238418579, tim

want to continue training after 300 steps yes


loss: 1.375000238418579, time: 09:06:21, step: 301
loss: 1.335937738418579, time: 09:06:25, step: 302
loss: 1.421875238418579, time: 09:06:29, step: 303
loss: 1.304687738418579, time: 09:06:30, step: 304
loss: 1.460937738418579, time: 09:06:32, step: 305
loss: 1.500000238418579, time: 09:06:33, step: 306
loss: 1.351562738418579, time: 09:06:35, step: 307
loss: 1.414062738418579, time: 09:06:37, step: 308
loss: 1.492187738418579, time: 09:06:38, step: 309
loss: 1.406250238418579, time: 09:06:40, step: 310
loss: 1.312500238418579, time: 09:06:41, step: 311
loss: 1.414062738418579, time: 09:06:43, step: 312
loss: 1.335937738418579, time: 09:06:44, step: 313
loss: 1.328125238418579, time: 09:06:46, step: 314
loss: 1.437500238418579, time: 09:06:47, step: 315
loss: 1.351562738418579, time: 09:06:49, step: 316
loss: 1.578125238418579, time: 09:06:50, step: 317
loss: 1.367187738418579, time: 09:06:52, step: 318
loss: 1.382812738418579, time: 09:06:53, step: 319
loss: 1.445312738418579, time: 

want to continue training after 400 steps no




0,1
Learning_rate,▁▂▄▆███▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁
Validation_loss,█▇▅▅▃▄▂▃▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇█▇▅▃▃▃▃▂▂▂▂▂▁▂▂▂▁▁▂▂▁▂▂▂▁▂▁▂▂▁▁▂▁▂▁▂▁▁
train_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
Learning_rate,0.0
Validation_loss,1.38281
train_loss,1.24219
train_step,400.0
val_step,80.0


In [17]:
import time
print('Loading the model on CPU')
START=time.time()
model = model.cpu()
print(f"Loaded model on cpu in {time.time()-START} seconds ")

Loading the model on CPU
Loaded model on cpu in 129.38400077819824 seconds 


In [18]:
from huggingface_hub import login
login("hf_token") ##
model.push_to_hub(
    SAVED_MODEL, 
    tokenizer=tokenizer,
    safe_serialization=True,
    private=True,
    create_pr=True,
    max_shard_size="3GB", 
    )
tokenizer.push_to_hub(
    SAVED_MODEL,
    private=True, 
    
    )

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


README.md: 100%|██████████| 5.18k/5.18k [00:00<00:00, 18.4MB/s]
adapter_model.safetensors: 100%|██████████| 300M/300M [00:05<00:00, 54.2MB/s] 
tokenizer.json: 100%|██████████| 17.5M/17.5M [00:00<00:00, 42.9MB/s]


CommitInfo(commit_url='https://huggingface.co/fhai50032/Gemma-Unaligned/commit/6df55c71cc91b149a6a7da601c23e01672715801', commit_message='Upload tokenizer', commit_description='', oid='6df55c71cc91b149a6a7da601c23e01672715801', pr_url=None, pr_revision=None, pr_num=None)