In [2]:
# %%capture
!pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes -q
!pip3 install accelerate datasets sentencepiece langchain -q
!pip uninstall -y tensorflow
!pip install tensorflow-cpu -q
!git clone https://github.com/IsNoobgrammer/Pytorch-Optimizers optims

Looking in links: https://storage.googleapis.com/libtpu-releases/index.html, https://storage.googleapis.com/libtpu-wheels/index.html
Collecting libtpu==0.0.7.1
  Downloading libtpu-0.0.7.1-py3-none-manylinux_2_27_x86_64.whl (131.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.7/131.7 MB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting libtpu-nightly==0.1.dev20241010+nightly.cleanup
  Downloading https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20241010%2Bnightly.cleanup-py3-none-any.whl (1.3 kB)
Installing collected packages: libtpu-nightly, libtpu
  Attempting uninstall: libtpu-nightly
    Found existing installation: libtpu-nightly 0.1.dev20241002+nightly
    Uninstalling libtpu-nightly-0.1.dev20241002+nightly:
      Successfully uninstalled libtpu-nightly-0.1.dev20241002+nightly
  Attempting uninstall: libtpu
    Found existing installation: libtpu 2.18.0
    Uninstalling libtpu-2

In [3]:
#get_ipython().kernel.do_shutdown(True)
### for good measures restart kernel

**Sharding Module for different Architechture**

In [4]:
%%writefile spmd_util.py
import math
from dataclasses import dataclass, field
from typing import List, Optional
from collections import defaultdict
import torch
import torch.nn as nn
import re
import torch_xla.distributed.spmd.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, MT5Config, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config, MixtralConfig, PhiConfig,GemmaConfig
)

# ends with $ to prevent sharding lora parameters

MT5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),
    
    # mlp
    ("wi$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),
    
    # output head
    ("lm_head$", ("fsdp", "mp")),
    ("final_layer_norm$", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.(gate_proj|up_proj)", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )


PHI_RULES = (
    ### (regex) linear modules, (list[sharding methods]) )
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.dense", ("mp", "fsdp")),
    ("mlp\\.fc2", ("mp", "fsdp")),  
    ("mlp\\.fc1", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    
)

LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )

GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # atention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)



MIXTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("w1", ("fsdp", "mp")),
    ("w2", ("mp", "fsdp")),
    ("w3", ("fsdp", "mp")),
    ("gate", ("mp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )

GEMMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
    
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (MT5Config, MT5_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES),
    (MixtralConfig, MIXTRAL_RULES),
    (PhiConfig,PHI_RULES),
    (GemmaConfig,GEMMA_RULES),
]


def find_rule(model):
    for config, rule in ALL_RULES:
        x1=(str(config).split("."))[-1]
        x2=(str(model.config.__class__).split("."))[-1]
#         print(x1,x2)
        if x1.lower()==x2.lower():
            return rule
    raise Exception("unsupported model to partitioning")

strkey2id = {
    "dp": 0, ## usefull for sharding inputs
    "fsdp": 1, ## Pytorch-Xla (2D-sharding) axis to shard data (mostly mesh shape will be (8,1)) data will be sharded 8 way 
    "mp": 2 ## axis to shard model model will be sharded one way 
               ## Recommened checking Pytorch-tpu/transfomers on github (xla-fork of transformers)
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]
        
    # print(rule)

    for name, module in model.named_modules():
        module.to(device)
#         print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name.lower())  : # and ("lora" not in name.lower()):
                    if verbose:
                        print("match", rule_pattern, name)
                    
                    xs.mark_sharding(module.weight, mesh, spec)
                    break


Overwriting spmd_util.py


**Required Libs**

In [5]:
import os
import pandas as pd
import numpy as np
import datasets
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch
import torch.nn as nn
import torch_xla.test.test_utils as test_utils
import torch_xla.core.xla_model as xm
from transformers import (
 AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed, DataCollatorWithPadding, AutoConfig 
)

from transformers import logging as hf_logging
import torch_xla.runtime as xr

xr.use_spmd()

from torch_xla.distributed.spmd.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model 
from datasets import  load_dataset, concatenate_datasets
from tqdm import tqdm

from torch.utils.data import Dataset as TorchDataset
from torch_xla.utils.checkpoint import checkpoint

try:
    !export USE_TORCH=True #If we don't do this, transformers will seemingly bork the session upon import. Really weird error.
    os.environ["PJRT_DEVICE"] = "TPU"
    os.environ.pop('TPU_PROCESS_ADDRESSES')
    os.environ.pop('CLOUD_TPU_TASK_ID')
    hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever
except:
    pass

  from .autonotebook import tqdm as notebook_tqdm


**Configuration**

In [6]:
MODEL = "aisingapore/Llama-SEA-LION-v3.5-8B-R"#You should be able to use 7B model with no changes! There should be enough HBM
SAVED_MODEL = f"{MODEL.split('/')[1]}-legalqa"
WANDB_PROJECT = MODEL.split('/')[1]

# !export XLA_TENSOR_ALLOCATOR_MAXSIZE=1000000

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)
if 'pad_token' not in tokenizer.special_tokens_map:
  tokenizer.pad_token=tokenizer.eos_token


print(f"Tokens :\n {tokenizer.special_tokens_map} \n\n")

Tokens :
 {'bos_token': '<|begin_of_text|>', 'eos_token': '<|eot_id|>', 'pad_token': '<|finetune_right_pad_id|>'} 




In [8]:
dataset = load_dataset("ShoAnn/legalqa_klinik_hukumonline")
train_val = dataset["train"].train_test_split(test_size=0.1, seed=42)

# Split the test_valid into test and validation
test_valid = dataset["test"]

from datasets import DatasetDict
# Gather the splits into a DatasetDict
dataset = DatasetDict({
    "train": train_val["train"],
    "test": train_val["test"],
    "valid": test_valid
})

Generating train split: 100%|██████████| 1006/1006 [00:00<00:00, 15611.48 examples/s]
Generating test split: 100%|██████████| 112/112 [00:00<00:00, 15214.97 examples/s]


In [9]:
prompt_template = """Below is a question paired with a context to the problem in the question. Write an answer that appropriately answers the question based on the context.

### Question:
{}

### Context:
{}

### Answer:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["question"]
    inputs       = examples["context"]
    outputs      = examples["answer"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = prompt_template.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }
pass

dataset = dataset.map(formatting_prompts_func, batched = True,)


Map: 100%|██████████| 905/905 [00:00<00:00, 3255.36 examples/s]
Map: 100%|██████████| 101/101 [00:00<00:00, 3435.03 examples/s]
Map: 100%|██████████| 112/112 [00:00<00:00, 3471.05 examples/s]


In [10]:
class ConversationDataset(TorchDataset):
    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        input_ids = self.tokenizer(text, add_special_tokens=True, max_length=self.max_length, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")
        return {
            "input_ids": input_ids["input_ids"].squeeze(0),
            "labels": input_ids["input_ids"].squeeze(0),
            "attention_mask":input_ids["attention_mask"].squeeze(0),
        }

In [11]:
len(dataset["train"])

905

In [12]:
FLAGS = {'MAX_INPUT': 2048,
         'LOGGING_STEPS': 1,
         'NUM_EPOCHS': 10,
         'PAUSE_STEPS':1000, # asks to exit training after x steps , #todo checkpoint saving me no lazy
         'MAX_STEPS': -1,#Ooverides num epochs
         'BATCH_SIZE': 2, #Making batch_size lower then 8 will result in slower training, but will allow for larger models\context. Fortunately, we have 128GBs. Setting higher batch_size doesn't seem to improve time.
         'LEN_TRAIN_DATA': dataset["train"],
         'VAL_STEPS': 20,
         'VAL_BATCH': 5,
         'GRAD_ACCUMULATION_STEP':4,
         'MAX_GRAD_CLIP':1,
        'LEARNING_RATE':2e-5,
         'WARMUP_RATIO':0.01,
         'OPTIMIZER':'adamw', # default = 'adamw'  options->  ['adamw','SM3','came','adafactor','lion']           
         'SCHEDULAR':'cosine', # default= 'cosine'     options:-> ['linear','cosine']
         'WEIGHT_DECAY':0.1,
         'TRAIN_DATASET':dataset["train"],
         "TEST_DATASET":dataset["train"],
         'WANDB':True,
        'PROJECT': WANDB_PROJECT,
        }

**Quantization When??**

In [13]:
model = AutoModelForCausalLM.from_pretrained(MODEL,torch_dtype=torch.bfloat16) 
# model._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=checkpoint) 
# gradient checkpointing is not properly setup needs to do barieer optimization

Fetching 4 files: 100%|██████████| 4/4 [00:20<00:00,  5.14s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 49.23it/s]


**LoRA Applicable**

In [14]:
ls=LoraConfig(
    r = 32, # Lora Rank ,I would prefer 8-32 for smaller models like 7b
    target_modules =['q_proj', 'down_proj', 'up_proj', 'o_proj', 'v_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 16, #weight_scaling
    lora_dropout = 0.05, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimize
    # modules_to_save = ["lm_head", "embed_tokens"] ## if you use new chat formats or embedding tokens
)
model = get_peft_model(model, ls)
model.print_trainable_parameters()



trainable params: 83,886,080 || all params: 8,114,147,328 || trainable%: 1.0338


**Data-Distributer**

In [15]:
train_data = ConversationDataset(tokenizer, dataset=dataset["train"] , max_length=FLAGS['MAX_INPUT'])
val = ConversationDataset(tokenizer, dataset=dataset["valid"])
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=8, rank=xm.get_ordinal(), shuffle=True,drop_last=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=8, rank=xm.get_ordinal(), shuffle=True,drop_last=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

print(f"Max Steps:- {len(training_loader)}  , eFFECTIVE bATCH size {8*FLAGS['BATCH_SIZE']} Input")
print(f"Val Size:- {len(testing_loader)}  , eFFECTIvE bATCH size {8*FLAGS['BATCH_SIZE']} Input")
FLAGS['STEPS']=len(training_loader)
FLAGS['BATCH_DATA']=FLAGS['BATCH_SIZE']*8 ## 8 CORES ON TPU 
# print(device)



Max Steps:- 57  , eFFECTIVE bATCH size 16 Input
Val Size:- 7  , eFFECTIvE bATCH size 16 Input


In [16]:
print(val[0]['input_ids'])
c=0
for step, batch in enumerate(training_loader):

    print(step)
    print(tokenizer.decode(batch['input_ids'][0]))
    break
# print(tokenizer.decode(val[0]['input_ids']))

tensor([128000,  39314,    374,  ...,  68531,   1141,  10602])
0
<|begin_of_text|>Below is a question paired with a context to the problem in the question. Write an answer that appropriately answers the question based on the context.

### Question:
Kerabat saya ditahan polisi sebagai jaminan untuk teman dia yang melarikan diri, tetapi setelah sampai di Polsek yang bersangkutan, kerabat saya yang tidak melakukan apapun malah mendapat tindakan yang tidak menyenangkan yaitu penyiksaan. Bagaimana hukumnya ini?

### Context:
[{'full_text': '(2) Uang yang dimaksud dalam ayat (1) harus disetor ke Kas Negara melalui panitera pengadilan negeri.', 'id': 2562, 'name': 'Pasal 36 ayat (2) PP Pelaksanaan KUHAP'}, {'full_text': '8.  pengaduan masyarakat yang selanjutnya disebut dumas  adalah bentuk penerapan dari pengawasan masyarakat  yang disampaikan oleh masyarakat, instansi pemerintah  atau pihak lain kepada polri berupa sumbangan pikiran,  saran, gagasan atau keluhan/pengaduan yang bersifat  mem

In [17]:
def get_nb_trainable_parameters(model):
        r"""
        Returns the number of trainable parameters and number of all parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param
def print_trainable_parameters(model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params, all_param = get_nb_trainable_parameters(model)
        
        print(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
        )

In [18]:
print_trainable_parameters(model)

trainable params: 83,886,080 || all params: 8,114,147,328 || trainable%: 1.0338249554642545


In [19]:
from spmd_util import partition_module
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
model = model.to(device)
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
from torch_xla.distributed.spmd.xla_sharding import xla_patched_nn_linear_forward
model = apply_xla_patch_to_nn_linear(model, xla_patched_nn_linear_forward)  #for patching linear layer to use einsum instead of matmul
num_devices = xr.global_runtime_device_count()
model_axis=1
data_axis=num_devices//model_axis
mesh_shape = (1,data_axis, model_axis )
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp','fsdp', 'mp'))
partition_module(model, mesh)
training_loader = pl.MpDeviceLoader(training_loader, device)
testing_loader = pl.MpDeviceLoader(testing_loader, device)
mesh_shape

(1, 8, 1)

In [20]:
FLAGS

{'MAX_INPUT': 2048,
 'LOGGING_STEPS': 1,
 'NUM_EPOCHS': 10,
 'PAUSE_STEPS': 1000,
 'MAX_STEPS': -1,
 'BATCH_SIZE': 2,
 'LEN_TRAIN_DATA': Dataset({
     features: ['question', 'answer', 'context', 'text'],
     num_rows: 905
 }),
 'VAL_STEPS': 20,
 'VAL_BATCH': 5,
 'GRAD_ACCUMULATION_STEP': 4,
 'MAX_GRAD_CLIP': 1,
 'LEARNING_RATE': 2e-05,
 'WARMUP_RATIO': 0.01,
 'OPTIMIZER': 'adamw',
 'SCHEDULAR': 'cosine',
 'WEIGHT_DECAY': 0.1,
 'TRAIN_DATASET': Dataset({
     features: ['question', 'answer', 'context', 'text'],
     num_rows: 905
 }),
 'TEST_DATASET': Dataset({
     features: ['question', 'answer', 'context', 'text'],
     num_rows: 905
 }),
 'WANDB': True,
 'PROJECT': 'Llama-SEA-LION-v3.5-8B-R',
 'STEPS': 57,
 'BATCH_DATA': 16}

In [21]:
!export XLA_USE_BF16=1
import torch.nn as nn
import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb_api_key")
wandb.login(key=wandb_key)
__wandb__=FLAGS['WANDB']
from torch_xla.amp.syncfree import AdamW
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
from optims.optim import SM3, CAME , Adafactor
import torch_xla.distributed.spmd as xs
# from random import randrange
# from bitsandbytes.optim import AdamW8bit 
# from torchdistx.optimizers import AnyPrecisionAdamW

val_step=0

def evaluate_loss(outputs_logits, labels, pad_id=tokenizer.pad_token_id):
    # Simplest possible loss for debugging XLA issues
    # This bypasses all masking, gathering etc.
    target_labels = labels[..., 1:].contiguous().long()
    logits_for_loss = outputs_logits[..., :-1, :].contiguous().float()

    # Ensure target_labels are within vocab size
    vocab_size = logits_for_loss.size(-1)
    target_labels = torch.clamp(target_labels, 0, vocab_size - 1)

    loss = nn.functional.cross_entropy(
        logits_for_loss.view(-1, vocab_size), 
        target_labels.view(-1),
        ignore_index=pad_id # Use ignore_index if possible
    )
    return loss.to(outputs_logits.dtype)



def train(FLAGS):

    ### Configuring Training
    global val_step
    update_params= filter(lambda p: p.requires_grad, model.parameters())
    num_iterations = int((FLAGS["NUM_EPOCHS"] * FLAGS['STEPS'] ) // FLAGS['GRAD_ACCUMULATION_STEP'])
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])
    
    if __wandb__:
        wandb.init(project=FLAGS['PROJECT'],config=FLAGS)
        wandb.define_metric("Validation_loss", step_metric="val_step")
        wandb.define_metric("Learning_rate",step_metric="train_step")
        wandb.define_metric("train_loss",step_metric="train_step")
    
    ### Optimizers
    
    if (FLAGS['OPTIMIZER']).lower()=='adamw':
        optimizer = AdamW(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='lion':
        optimizer = Lion(update_params, lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='adafactor':
        optimizer = Adafactor(update_params,lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],scale_parameter=False,relative_step=False)
    elif (FLAGS['OPTIMIZER']).lower()=='came':
        optimizer = CAME(model.parameters(),lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],betas=(0.9, 0.999, 0.9999),eps=(1e-30, 1e-16))
    else:
#         optimizer = Lilith(update_params, eps=1e-8, lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
        optimizer=SM3(update_params,lr=FLAGS['LEARNING_RATE'])

    for param_group in optimizer.param_groups:
        if len(param_group["params"]) > 0:
            print(param_group["params"][0].device)
            break
    
    ### Schedulars
    
    if (FLAGS['SCHEDULAR']).lower()=='linear':
        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
    else:
        scheduler = get_cosine_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
        
    
    ### Training Loop
    val_step=0
    check=False #for brakes
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        if check:
            break
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            input_ids, labels,attention_mask = batch["input_ids"].to(device),  batch["labels"].to(device),batch['attention_mask'].to(device)
            xs.mark_sharding(input_ids, mesh, (0,1))  ### earlier:-> (0,1) according to pytorch-xla , input/dataloaders must be sharded across ('data',None) 
            xs.mark_sharding( labels,  mesh, (0,1))  ###
            xs.mark_sharding(  attention_mask,  mesh, (0, 1)) ###
        
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # outputs.logits = outputs.logits.to(labels.dtype)  # ✅ Convert only logits
            loss = evaluate_loss(outputs.logits, labels) # Pass the logits tensor

            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                xm.master_print(f'loss: {loss.detach().cpu().item()}, time: {test_utils.now()}, step: {step+1}')
            if __wandb__:
                wandb.log({
                'Learning_rate': optimizer.param_groups[0]['lr'],
                'train_loss':  loss.detach().cpu().item(),
                'train_step': step + 1 + ((epoch-1) * FLAGS["STEPS"]),
                })
 
            del input_ids , attention_mask 
            loss.backward()
            del outputs,loss
            
            if (step+1) % FLAGS['GRAD_ACCUMULATION_STEP'] == 0:
                torch.nn.utils.clip_grad_norm_(update_params, max_norm=FLAGS['MAX_GRAD_CLIP']*8)
                scheduler.step()
                xm.optimizer_step(optimizer,pin_layout=True,barrier=True) ## performs xm.reduce_gradient() , optimizer.step() , xm.mark_step()
                optimizer.zero_grad()
            
            if (step+1)% FLAGS['VAL_STEPS'] == 0:
                end_index=FLAGS["VAL_BATCH"]
                model.eval()
                with torch.no_grad():
                    total_loss = 0
                    total_step = 0
                    for stepx, batchx in enumerate(testing_loader):
                        input_ids = batchx["input_ids"].to(device)
                        labels = batchx["labels"].to(device)
                        attention_mask = batchx["attention_mask"].to(device)
                        xs.mark_sharding(input_ids, mesh, (0, None))
                        xs.mark_sharding(labels, mesh, (0, None))
                        xs.mark_sharding( attention_mask,    mesh, (0, None))
                        outputs = model(input_ids=input_ids,attention_mask=attention_mask)
                        # --->>> !!! THIS IS A DIFFERENT PROBLEM !!! <<<---
                        loss = evaluate_loss(outputs.logits, labels)
                        total_loss += loss.item()
                        total_step +=1
                        xm.master_print('----- Time -> {} ----- Validation Batch -> {} ----  Validation Loss -> {:.4f}'.format(test_utils.now(), total_step , loss.item()))
                        if __wandb__:
                            val_step+=1
                            wandb.log({
                                'Validation_loss': loss.item(),
                                'val_step':val_step,
                                    })
                        if (stepx+1)%end_index==0:
                            break
                    model.train()    
                    average_loss=total_loss/total_step
                    xm.master_print('----- Time -> {} ----- Validation Batch Size -> {} ----  Validation Loss -> {:.7f}'.format(test_utils.now(), total_step , average_loss))

            if (step+1)% FLAGS['PAUSE_STEPS']==0:
                inp=input('want to continue training after {} steps'.format(step+1))
                check = bool("no" in inp.lower())
                if check:
                    break
                else:
                    pass

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mshoann[0m ([33mshoann-mycompany-inc[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


**12 Mins to Train on 4k**

In [22]:
train(FLAGS)
if FLAGS['WANDB']:
    wandb.finish()

xla:0
Epoch 1 train begin 23:11:03
loss: 1.5390625, time: 23:11:43, step: 1
loss: 1.296875, time: 23:13:51, step: 2
loss: 1.25, time: 23:15:23, step: 3
loss: 1.4765625, time: 23:15:26, step: 4
loss: 1.125, time: 23:17:23, step: 5
loss: 1.2890625, time: 23:18:49, step: 6
loss: 1.1875, time: 23:18:51, step: 7
loss: 1.375, time: 23:18:54, step: 8
loss: 1.6484375, time: 23:20:58, step: 9
loss: 1.0078125, time: 23:21:00, step: 10
loss: 1.4375, time: 23:21:03, step: 11
loss: 1.5703125, time: 23:21:06, step: 12
loss: 1.40625, time: 23:21:09, step: 13
loss: 1.2890625, time: 23:21:11, step: 14
loss: 0.69140625, time: 23:21:14, step: 15
loss: 1.484375, time: 23:21:16, step: 16
loss: 1.1953125, time: 23:21:19, step: 17
loss: 0.91796875, time: 23:21:22, step: 18
loss: 1.2265625, time: 23:21:24, step: 19
loss: 1.4609375, time: 23:21:27, step: 20
----- Time -> 23:21:51 ----- Validation Batch -> 1 ----  Validation Loss -> 1.6641
----- Time -> 23:22:12 ----- Validation Batch -> 2 ----  Validation Loss

0,1
Learning_rate,▁█████████▇▇▇▇▆▆▆▆▆▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁
Validation_loss,▃█▇█▇▇▅▆▇▂▇▇▆▂▆▆▅▁▆▃▅▆▁▃▅▄▆▁▃▆▅▁▅▆▁▄▆▁▁▃
train_loss,▆█▇▇▅▆▆▅▆▆▆▇▂▆▄▅▄▁▇▄▅▅▃▆▇▄▄▄▅▄▆▆▆▅▃▃▆▅▅▁
train_step,▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇██
val_step,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇█

0,1
Learning_rate,0.0
Validation_loss,1.05469
train_loss,0.95703
train_step,570.0
val_step,100.0


In [23]:
import time
print('Loading the model on CPU')
START=time.time()
model = model.cpu()
print(f"Loaded model on cpu in {time.time()-START} seconds ")

Loading the model on CPU
Loaded model on cpu in 130.74386429786682 seconds 


In [29]:
from huggingface_hub import login

secret_label = "HF_TOKEN"
login(UserSecretsClient().get_secret(secret_label))

model.push_to_hub(
    SAVED_MODEL, 
    tokenizer=tokenizer,
    # safe_serialization=True,
    # create_pr=True,
    )

No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/ShoAnn/Llama-SEA-LION-v3.5-8B-R-legalqa/commit/0cd54bc2862d7f54d5f6d23e22a8dead3d8482a8', commit_message='Upload model', commit_description='', oid='0cd54bc2862d7f54d5f6d23e22a8dead3d8482a8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ShoAnn/Llama-SEA-LION-v3.5-8B-R-legalqa', endpoint='https://huggingface.co', repo_type='model', repo_id='ShoAnn/Llama-SEA-LION-v3.5-8B-R-legalqa'), pr_revision=None, pr_num=None)

In [28]:
!pip install -U transformers accelerate peft huggingface_hub

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting huggingface_hub
  Downloading huggingface_hub-0.31.2-py3-none-any.whl (484 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.2/484.2 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.31.1
    Uninstalling huggingface-hub-0.31.1:
      Successfully uninstalled huggingface-hub-0.31.1
Successfully installed huggingface_hub-0.31.2
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [26]:
model

PeftModel(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
         