In [None]:
# Install the libraries used by this notebook (-q keeps pip output short).
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes -q
!pip3 install datasets sentencepiece langchain torch_xla[tpuvm] -q
# Remove TensorFlow from the environment before anything below is imported.
!pip uninstall tensorflow -y #that's the meme part

**For Logging (Wandb)**

In [None]:
# NOTE(review): no value follows --token, so this command presumably needs the
# token filled in (or supplied another way) before it can succeed -- confirm.
!huggingface-cli login --token 
import wandb
# Authenticate this session with Weights & Biases.
wandb.login()

**Sharding Module for Different Architectures**

In [None]:
%%writefile spmd_util.py
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config
)

# ends with $ to prevent sharding lora parameters
# Each rule is (regex searched in the module name, mesh-axis names for the two
# weight dimensions). partition_module applies the first rule that matches.
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)

# T5 rule table: (regex searched in the module name, mesh-axis names for the
# two weight dimensions). Most patterns are anchored with `$` so they only
# match names that end with the given suffix.
T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

# Llama rule table: (regex searched in the module name, mesh-axis names for the
# two weight dimensions).
# NOTE(review): these patterns are not anchored with `$`, so any nn.Linear or
# nn.Embedding whose name merely contains one of them is sharded as well.
LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
# Qwen2 rule table; its entries are identical to LLAMA_RULES.
QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
# GPT-2 rule table: (regex searched in the module name, mesh-axis names for the
# two weight dimensions).
# NOTE(review): partition_module only visits nn.Embedding / nn.Linear modules,
# so a pattern here only takes effect if the layer it names is one of those
# types -- confirm for c_attn / c_fc / c_proj / ln_f.
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    # Same pattern and axes as the attention "c_proj" rule above; the first
    # matching rule wins, so this entry is never reached.
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
# Mistral rule table; its entries are identical to LLAMA_RULES.
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
    
# Registry consulted by find_rule: pairs of (config class, rule table). A model
# is supported only if the class of its `config` equals one of these classes.
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES)
]

def find_rule(model):
    """Look up the sharding rule table registered for ``model``.

    The class of ``model.config`` is compared for equality with each config
    class in ``ALL_RULES``; the rule table of the first equal entry is
    returned.

    Raises:
        Exception: if no registered config class equals the model's.
    """
    config_cls = model.config.__class__
    matching_tables = (
        rules for known_cls, rules in ALL_RULES if config_cls == known_cls
    )
    for rules in matching_tables:
        return rules
    raise Exception("unsupported model to partitioning")

# Maps the mesh-axis names used in the rule tables to integers; partition_module
# substitutes these integers into the spec it passes to xs.mark_sharding.
strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    """Move ``model`` to ``device`` and shard the weights selected by its rules.

    The rule table comes from ``find_rule(model)``; axis names are translated
    to integers through ``strkey2id``. Every module is moved to ``device``.
    For each ``nn.Embedding`` / ``nn.Linear`` module, the first rule whose
    regex is found in the module name decides how ``module.weight`` is marked
    for sharding on ``mesh``. Modules that match no rule are left unmarked.

    Args:
        model: module exposing ``config`` and ``named_modules()``.
        mesh: mesh object forwarded to ``xs.mark_sharding``.
        device: target device. NOTE(review): the default is evaluated once,
            when this function is defined.
        verbose: when true, print each (pattern, module name) match.
    """
    rules = [
        (pattern, tuple(strkey2id[axis] for axis in axes))
        for pattern, axes in find_rule(model)
    ]

    for name, module in model.named_modules():
        module.to(device)
        if not isinstance(module, (nn.Embedding, nn.Linear)):
            continue
        for pattern, spec in rules:
            # Only the first matching rule is applied to a module.
            if re.search(pattern, name) is None:
                continue
            if verbose:
                print("match", pattern, name)
            xs.mark_sharding(module.weight, mesh, spec)
            break
        
def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    """Move ``model`` to ``device`` and shard every embedding/linear weight.

    Unlike ``partition_module`` no rule table is consulted: each
    ``nn.Embedding`` / ``nn.Linear`` weight is marked with the fixed spec
    ``(1, 2)``. ``verbose`` is accepted but not used.
    """
    fixed_spec = (1, 2)
    shardable_types = (nn.Embedding, nn.Linear)

    for _, module in model.named_modules():
        module.to(device)
        if isinstance(module, shardable_types):
            xs.mark_sharding(module.weight, mesh, fixed_spec)

**Required Libs**

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, MistralConfig, Qwen2Config, GPT2Config, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
) # You can use any of models with those configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantazation requiers GPU though. You'll have to download already quantazed models
from spmd_util import partition_module                # You could experiment with using already quantazed models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from torch.utils.data import Dataset as TorchDataset
import torch.utils
# `!export` only changes a short-lived shell process, so set the variable on
# the kernel's own environment instead.
# NOTE(review): transformers has already been imported at this point, so this
# presumably comes too late to influence that import -- confirm, and move it
# above the imports if it is still needed.
os.environ["USE_TORCH"] = "True"
os.environ["PJRT_DEVICE"] = "TPU"
# Remove the TPU-related variables if they are set. Popping with a default
# means a missing key no longer aborts the rest of this setup (previously a
# KeyError was swallowed by a bare `except`, silently skipping what followed).
for _stale_var in ("TPU_PROCESS_ADDRESSES", "CLOUD_TPU_TASK_ID"):
    os.environ.pop(_stale_var, None)
hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever


**Configuration**

In [None]:
# Maximum sequence length in tokens; copied into FLAGS['MAX_INPUT'].
MAX_INPUT=1024
# Hub id of the checkpoint that the tokenizer and model are loaded from.
MODEL = "fhai50032/xLakeChat" #You should be able to use 7B model with no changes! There should be enough HBM
# Hub repo id that the fine-tuned model and tokenizer are pushed to.
SAVED_MODEL = "fhai50032/XLake-Coder"
# !export XLA_TENSOR_ALLOCATOR_MAXSIZE=1000000


In [None]:
from transformers import AutoTokenizer
# Load the tokenizer that belongs to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Left disabled: would reuse the EOS token as the padding token.
# tokenizer.pad_token = tokenizer.eos_token

In [None]:
class ConversationDataset(TorchDataset):
    """Torch dataset that renders ShareGPT-style conversations as ChatML text.

    Each item of ``dataset`` must provide a ``"conversations"`` list of
    messages with ``"from"`` (role) and ``"value"`` (text) keys. Messages are
    concatenated into a single prompt, tokenized, and padded/truncated to
    ``max_length``.

    Args:
        tokenizer: callable tokenizer exposing ``eos_token``; it is called with
            ``return_tensors="pt"`` and must return ``input_ids`` and
            ``attention_mask``.
        max_length: fixed sequence length of every returned example.
        dataset: indexable collection of conversation records.
    """

    # Header written after ``<|im_start|>`` for each non-assistant role.
    _ROLE_HEADERS = {
        "system": "system",
        "human": "user",
        "function-call": "call",
        "function-response": "function",
    }
    # Label value that Hugging Face language-model losses skip.
    IGNORE_INDEX = -100

    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _render(self, messages):
        """Concatenate ``messages`` into one ChatML-formatted string."""
        parts = []
        for message in messages:
            role = message["from"]
            if role == "gpt":
                # Assistant turns are closed with the tokenizer's EOS token.
                parts.append(f"<|im_start|>assistant\n{message['value']}{self.tokenizer.eos_token}")
            elif role in self._ROLE_HEADERS:
                parts.append(f"<|im_start|>{self._ROLE_HEADERS[role]}\n{message['value']}<|im_end|>\n")
            # Messages with any other role are skipped.
        return "".join(parts)

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` tensors of length ``max_length``.

        ``labels`` is a copy of ``input_ids`` in which every padded position
        (attention mask 0) is replaced by ``IGNORE_INDEX``, so padding does not
        contribute to the training loss.
        """
        text = self._render(self.dataset[idx]["conversations"])
        encoded = self.tokenizer(text, add_special_tokens=True, max_length=self.max_length, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        # Clone first: masking a view would also overwrite the model inputs.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX
        return {
            "input_ids": input_ids,
            "labels": labels,
        }

In [None]:
# Train on the first 600 rows; the fixed seed makes the shuffled order
# reproducible across kernel restarts.
train_data = load_dataset("fhai50032/magicoder-oss-instruct-sharegpt-75k", split="train[:600]").shuffle(seed=42)
# Validate on the 24 rows that follow the training slice, so that no validation
# example is also a training example.
val = load_dataset('fhai50032/magicoder-oss-instruct-sharegpt-75k', split="train[600:624]")

In [None]:
len(train_data)

In [None]:
# Run configuration: hyper-parameters and run metadata in a single dict.
FLAGS = {
    'MAX_INPUT': MAX_INPUT,
    'LOGGING_STEPS': 6,
    'NUM_EPOCHS': 1,
    # Making batch_size lower than 8 will result in slower training, but will
    # allow for larger models\context. Fortunately, we have 128GBs. Setting a
    # higher batch_size doesn't seem to improve time.
    'BATCH_SIZE': 6,
    'NUM_STEPS': len(train_data) // 3,
    'GRAD_ACCUMALATION': 2,
    'LEARNING_RATE': 2e-5,
    'WARMUP_RATIO': 0.1,
    'OPTIMIZER': 'AdamW',
    'SCHEDULAR': 'LINEAR',  # or cosine
    'WEIGHT_DECAY': 0.0,
    'PROJECT': 'TPU-Training',
    'RUN': 'Silver-Pulse',  # Indian Pun
}

In [None]:
FLAGS

In [None]:
# Load the base checkpoint with bfloat16 weights.
# NOTE(review): `token` is an empty string here -- confirm whether a real token
# is required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, token="")

In [None]:
## for debugging No-use
# Print the module tree, then every parameter whose name contains "down_proj".
print(model)
for param_name, param in model.named_parameters():
    if "down_proj" in param_name:
        print(param)

**LoRA Applicable**

In [None]:
# LoRA adapter configuration; get_peft_model below wraps the model with it.
ls=LoraConfig(
    r = 32, # LoRA rank; 8-32 is the author's preference for smaller models like 7B
    target_modules = ['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 48, # LoRA scaling factor
    lora_dropout = 0.05, # dropout setting for the LoRA layers
    bias = "none",    # bias handling for the adapters
    # modules_to_save = ["lm_head", "embed_tokens"] ## if you use new chat formats or embedding tokens
)
model = get_peft_model(model, ls)

**Data Distributor**

In [None]:
# Wrap the raw conversation records in the tokenizing torch dataset. The
# isinstance guards keep this cell safe to re-run: an already wrapped dataset
# is not wrapped a second time.
if not isinstance(train_data, ConversationDataset):
    train_data = ConversationDataset(tokenizer, dataset=train_data, max_length=MAX_INPUT)
if not isinstance(val, ConversationDataset):
    val = ConversationDataset(tokenizer, dataset=val, max_length=MAX_INPUT)

# SPMD mode (xr.use_spmd()) drives all TPU cores from a single process, and the
# training loop shards each *global* batch with xs.mark_sharding. The sampler
# therefore has to cover the whole dataset: one replica, rank 0. Splitting the
# data across 8 replicas here would leave this process with one eighth of the
# examples and silently drop the rest.
# NOTE(review): this assumes a single-host run; a multi-host SPMD job would
# need one replica per host process instead.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=1, rank=0, shuffle=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=1, rank=0, shuffle=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

device = xm.xla_device()

In [None]:
print(device)

In [None]:
def get_nb_trainable_parameters(model):
    r"""
    Count the parameters of ``model``.

    Returns a ``(trainable, total)`` tuple, where ``trainable`` only includes
    parameters whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for _, param in model.named_parameters():
        count = param.numel()

        # Parameters that report zero elements but carry a ``ds_numel``
        # attribute (DeepSpeed ZeRO-3 placeholders) are counted via ds_numel.
        if count == 0 and hasattr(param, "ds_numel"):
            count = param.ds_numel

        # bitsandbytes 4-bit parameters are counted twice per stored element.
        if param.__class__.__name__ == "Params4bit":
            count *= 2

        total += count
        if param.requires_grad:
            trainable += count

    return trainable, total
def print_trainable_parameters(model):
    """
    Print trainable and total parameter counts plus the trainable percentage.
    """
    n_trainable, n_total = get_nb_trainable_parameters(model)
    percent = 100 * n_trainable / n_total
    fields = [
        f"trainable params: {n_trainable:,d}",
        f"all params: {n_total:,d}",
        f"trainable%: {percent}",
    ]
    print(" || ".join(fields))

In [None]:
# Enable gradients for every parameter whose 1-based position exceeds the
# threshold below (0 means "all of them"). Raise the threshold to freeze the
# leading parameters if you are running out of memory.
# NOTE(review): with a threshold of 0 this also re-enables gradients on every
# parameter of the LoRA-wrapped model, including the base weights -- confirm
# that full fine-tuning is intended.
cnt = 0
for param in model.parameters():
    cnt += 1
    param.requires_grad = cnt > 0
print_trainable_parameters(model)

config = AutoConfig.from_pretrained(MODEL)

# Build a (dp, fsdp, mp) mesh that places every device on the fsdp axis.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (1, num_devices, 1)
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))

# After this, the model is sharded between cores but still has the same API as
# if it was on a single device. Neat.
partition_module(model, mesh)

In [None]:
# NOTE(review): `!` runs this export in a separate shell process, so it likely
# does not change the environment of the running kernel -- confirm.
!export XLA_USE_BF16=1

def train(FLAGS):
    """Fine-tune the global ``model`` with Adam and print a validation loss.

    Every batch of ``training_loader`` produces one optimizer update, one
    ``xm.mark_step()`` and one scheduler tick. The learning rate is scaled
    linearly from 1.0x down to 0.1x of ``FLAGS['LEARNING_RATE']`` over
    ``NUM_STEPS * BATCH_SIZE`` ticks. After each epoch the mean loss over
    ``testing_loader`` is reported.

    NOTE(review): a later cell defines another ``train`` that replaces this one.

    Args:
        FLAGS: dict providing ``LEARNING_RATE``, ``NUM_STEPS``, ``BATCH_SIZE``,
            ``NUM_EPOCHS`` and ``LOGGING_STEPS``.
    """
    optimizer = optim.Adam(model.parameters(), lr=FLAGS['LEARNING_RATE'])
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.1,
        total_iters=FLAGS['NUM_STEPS'] * FLAGS['BATCH_SIZE'],
    )

    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))

        for step, batch in enumerate(training_loader):
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            # Sharding inputs
            for tensor in (input_ids, labels):
                xs.mark_sharding(tensor, mesh, (0, 1))

            loss = model(input_ids=input_ids, labels=labels).loss
            loss.requires_grad_()
            loss.backward()
            optimizer.step()
            xm.mark_step()

            step_number = step + 1
            if step_number % FLAGS['LOGGING_STEPS'] == 0:
                print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')

            scheduler.step()

        # Validation pass: average the per-batch loss without tracking gradients.
        model.eval()
        eval_loss_sum = 0.0
        eval_batches = 0
        with torch.no_grad():
            for batch in testing_loader:
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                for tensor in (input_ids, labels):
                    xs.mark_sharding(tensor, mesh, (0, 1))
                eval_loss_sum += model(input_ids=input_ids, labels=labels).loss.item()
                eval_batches += 1

        average_loss = eval_loss_sum / eval_batches
        xm.master_print('Epoch {} test end {}, test loss={:.4f}'.format(epoch, test_utils.now(), average_loss))
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

In [None]:
train(FLAGS)

In [None]:
# NOTE(review): `!` runs this export in a separate shell process, so it likely
# does not change the environment of the running kernel -- confirm.
!export XLA_USE_BF16=1
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
def train(FLAGS):
    """Fine-tune the global ``model`` with gradient accumulation, then validate.

    Gradients of ``GRAD_ACCUMALATION`` consecutive micro-batches are summed
    (each loss divided by the accumulation factor) before one clipped optimizer
    update and one scheduler tick. The last micro-batches of an epoch always
    trigger an update, so no data is discarded. After each epoch the mean loss
    over ``testing_loader`` is printed.

    Args:
        FLAGS: dict providing ``NUM_STEPS``, ``GRAD_ACCUMALATION``,
            ``WARMUP_RATIO``, ``OPTIMIZER`` (only ``'adamw'`` is supported),
            ``LEARNING_RATE``, ``WEIGHT_DECAY``, ``SCHEDULAR`` (``'linear'``;
            any other value selects the cosine schedule), ``NUM_EPOCHS`` and
            ``LOGGING_STEPS``.

    Raises:
        ValueError: if ``FLAGS['OPTIMIZER']`` names an unsupported optimizer.
    """
    accum_steps = max(1, int(FLAGS['GRAD_ACCUMALATION']))
    num_iterations = int(FLAGS['NUM_STEPS'] / accum_steps)
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])

    optimizer_name = FLAGS['OPTIMIZER'].lower()
    if optimizer_name == 'adamw':
        optimizer = AdamW(model.parameters(), eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999), weight_decay=FLAGS['WEIGHT_DECAY'])
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(f"unsupported optimizer: {FLAGS['OPTIMIZER']!r} (only 'AdamW' is implemented)")

    if FLAGS['SCHEDULAR'].lower() == 'linear':
        scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, num_iterations)
    else:
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, num_iterations)

    # Start from clean gradients even if an earlier run left some behind.
    optimizer.zero_grad()

    global_step = 0  # number of optimizer updates performed so far
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        num_batches = len(training_loader)
        for step, batch in enumerate(training_loader):
            input_ids, labels = batch["input_ids"].to(device), batch["labels"].to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1))
            xs.mark_sharding(labels, mesh, (0, 1))
            outputs = model(input_ids=input_ids, labels=labels)
            loss = outputs.loss

            # Scale every micro-batch equally so the accumulated gradient is
            # the average over the accumulation window.
            (loss / accum_steps).backward()

            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                # Report the unscaled loss so values are comparable across steps.
                print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')

            # The accumulation boundary is driven by the micro-batch index.
            # Keying it on the update counter would stall training, because
            # that counter only advances inside this branch.
            is_last_batch = (step + 1) == num_batches  # makes sure that data from the last batches are not discarded
            if (step + 1) % accum_steps == 0 or is_last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            xm.mark_step()

        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

        model.eval()
        total_loss = 0.0
        total_steps = 0
        with torch.no_grad():
            for step, batch in enumerate(testing_loader):
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
                total_loss += loss.item()
                total_steps += 1

        # An empty validation loader yields NaN instead of a ZeroDivisionError.
        average_loss = total_loss / total_steps if total_steps else float('nan')
        xm.master_print('----- Time -> {} ----- Validation Steps -> {} ----  Validation Loss -> {:.4f}'.format(test_utils.now(), total_steps , average_loss))
#         

In [None]:
train(FLAGS)

In [None]:
def log_wandb(FLAGS, loss, steps, batch_step, epoch=None):
    """Send the current training loss (and optionally the epoch) to wandb.

    Args:
        FLAGS: run configuration; accepted for call compatibility, not read.
        loss: scalar loss, either a tensor-like object with ``.item()`` or a
            plain number.
        steps: global step value logged next to the loss.
        batch_step: accepted for call compatibility, not read.
        epoch: when given, an extra record with the epoch number and the
            current time is logged.

    Returns:
        The fixed status string the original helper returned.
    """
    loss_value = loss.item() if hasattr(loss, "item") else float(loss)
    # Use the values passed in by the caller; the training loop's local
    # counters are not visible from this function.
    wandb.log({"Train Loss": loss_value, "Global Step": steps})
    if epoch is not None:
        wandb.log({"Epoch Number": epoch, "Time End Epoch": test_utils.now()})

    return "Humpe toh hai hi naw"

In [None]:
# from kaggle_secrets import UserSecretsClient


# user_secrets = UserSecretsClient()
# hf_token = user_secrets.get_secret("hf_write") # Provide your own HF API token with write access

# Move the model back to the CPU before the upload cell below.
model = model.cpu()
print('now saving the model')


In [None]:
from huggingface_hub import login
# Never embed an access token in the notebook: read it from the HF_TOKEN
# environment variable. When the variable is unset, login() receives None and
# falls back to its interactive prompt.
# Any token that was ever pasted into this notebook should be revoked.
login(token=os.environ.get("HF_TOKEN"))
model.push_to_hub(
    SAVED_MODEL, 
    tokenizer=tokenizer,
    safe_serialization=True,
    private=True,
    create_pr=True,
    max_shard_size="3GB", # Sharding isn't as important as before since hardware is better now but who cares anyway
    )# We have to push the model to HF since there is not enough memory on disk. Download weights from there
tokenizer.push_to_hub(
    SAVED_MODEL,
    private=True, 
    
    )