In [None]:
!pip3 install transformers zstandard jsonlines peft wandb bitsandbytes lion-pytorch -q
!pip3 install accelerate datasets sentencepiece langchain torch_xla[tpuvm] -q
!pip uninstall tensorflow -y #that's the meme part

**For Logging (Wandb)**

In [None]:
!huggingface-cli login --token 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
import wandb
wandb.login()

**Sharding Module for different Architechture**

In [None]:
%%writefile spmd_util.py
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, GPT2Config, MistralConfig, Qwen2Config
)

# ends with $ to prevent sharding lora parameters
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # atention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
QWEN_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
GPT2_RULES = (
    # embeddings
    ("wte", ("mp", "fsdp")), 
    ("wpe", ("mp", "fsdp")),
    
    # attention
    ("c_attn", ("fsdp", "mp")),
    ("c_proj", ("mp", "fsdp")),
    
    # mlp
    ("c_fc", ("fsdp", "mp")), 
    ("c_proj", ("mp", "fsdp")),
    
    # output 
    ("ln_f", ("fsdp", "mp")),
)
MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )
    
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (GPT2Config, GPT2_RULES),
    (MistralConfig, MISTRAL_RULES),
    (Qwen2Config, QWEN_RULES)
]

def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception("unsupported model to partitioning")

strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]
        
    # print(rule)

    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    
                    xs.mark_sharding(module.weight, mesh, spec)
                    break
        
def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)

**Required Libs**

In [1]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, MistralConfig, Qwen2Config, GPT2Config, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
) # You can use any of models with those configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantazation requiers GPU though. You'll have to download already quantazed models
from spmd_util import partition_module                # You could experiment with using already quantazed models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from torch.utils.data import Dataset as TorchDataset
import torch.utils
try:
    !export USE_TORCH=True #If we don't do this, transformers will seemingly bork the session upon import. Really weird error.
    os.environ["PJRT_DEVICE"] = "TPU"
    os.environ.pop('TPU_PROCESS_ADDRESSES')
    os.environ.pop('CLOUD_TPU_TASK_ID')
    hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever
except:
    pass


  from .autonotebook import tqdm as notebook_tqdm


**Configuration**

In [32]:
MAX_INPUT=1024
MODEL = "fhai50032/BeagleLake-7B" #You should be able to use 7B model with no changes! There should be enough HBM
SAVED_MODEL = "fhai50032/BeagleLake-Coder"
# !export XLA_TENSOR_ALLOCATOR_MAXSIZE=1000000


In [31]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)
if 'pad_token' not in tokenizer.special_tokens_map:
  tokenizer.pad_token=tokenizer.eos_token


print(f"Tokens :\n {tokenizer.special_tokens_map} \n\n")

Tokens :
 {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>', 'additional_special_tokens': ['<unk>', '<s>', '</s>']} 




In [30]:
class ConversationDataset(TorchDataset):
    def __init__(self, tokenizer, max_length=1024, dataset=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        messages = self.dataset[idx]["conversations"]
        text = ""
        for message in messages:
            role = message["from"]
            if role == "system":
                text += f"<|im_start|>system\n{message['value']}<|im_end|>\n"
            if role == "human":
                text += f"<|im_start|>user\n{message['value']}<|im_end|>\n"
            if role == "function-call":
                text += f"<|im_start|>call\n{message['value']}<|im_end|>\n"
            if role == "function-response":
                text += f"<|im_start|>function\n{message['value']}<|im_end|>\n"
            if role =="gpt":
                text += f"<|im_start|>assistant\n{message['value']}{self.tokenizer.eos_token}"
        input_ids = self.tokenizer(text, add_special_tokens=True, max_length=self.max_length, truncation=True, padding="max_length", return_attention_mask=True, return_tensors="pt")
        return {
            "input_ids": input_ids["input_ids"].squeeze(0),
            "labels": input_ids["input_ids"].squeeze(0),
        }

In [36]:
train_data = load_dataset("fhai50032/magicoder-oss-instruct-sharegpt-75k", split="train[1000:2600]").shuffle()
val = (load_dataset('fhai50032/magicoder-oss-instruct-sharegpt-75k', split="train[:24]"))

In [37]:
len(train_data)

1600

In [38]:
FLAGS = {'MAX_INPUT': MAX_INPUT,
         'LOGGING_STEPS': 1,
         'NUM_EPOCHS': 1,
         'BATCH_SIZE': 8, #Making batch_size lower then 8 will result in slower training, but will allow for larger models\context. Fortunately, we have 128GBs. Setting higher batch_size doesn't seem to improve time.
          'NUM_STEPS': len(train_data),
        'GRAD_ACCUMULATION':1,
        'LEARNING_RATE':2e-5,
         'WARMUP_RATIO':0.1,
         'OPTIMIZER':'adamw', # default = 'adamw8bit'  options->  ['adamw','adamw8bit','adafactor','lion']           
         'SCHEDULAR':'LINEAR', # default= 'cosine'     options:-> ['linear','cosine']
         'WEIGHT_DECAY':0.0,
         'MAX_GRAD_CLIP':1.0,
        'PROJECT':'TPU-Training',
        'RUN':'Silver-Pulse'} # Indian pun :) 

In [39]:
FLAGS

{'MAX_INPUT': 1024,
 'LOGGING_STEPS': 1,
 'NUM_EPOCHS': 1,
 'BATCH_SIZE': 8,
 'NUM_STEPS': 1600,
 'GRAD_ACCUMALATION': 1,
 'LEARNING_RATE': 2e-05,
 'WARMUP_RATIO': 0.1,
 'OPTIMIZER': 'adamw',
 'SCHEDULAR': 'LINEAR',
 'WEIGHT_DECAY': 0.0,
 'MAX_GRAD_CLIP': 1.0,
 'PROJECT': 'TPU-Training',
 'RUN': 'Silver-Pulse'}

In [20]:
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, token="")

Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00,  2.74it/s]


In [None]:
## for debugging No-use
print(model)
for _, param in model.named_parameters():
#     print(_)
    if "down_proj" in str(_):
        print(param)

**LoRA Applicable**

In [21]:
ls=LoraConfig(
    r = 32, # Lora Rank ,I would prefer 8-32 for smaller models like 7b
    target_modules = ['v_proj', 'down_proj', 'up_proj', 'o_proj', 'q_proj', 'gate_proj', 'k_proj'],
    lora_alpha = 48, #weight_scaling
    lora_dropout = 0.05, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimize
    # modules_to_save = ["lm_head", "embed_tokens"] ## if you use new chat formats or embedding tokens
)
model = get_peft_model(model, ls)
model.print_trainable_parameters()

  warn("The installed version of bitsandbytes was compiled without GPU support. "


/usr/local/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
trainable params: 83,886,080 || all params: 7,325,618,176 || trainable%: 1.1451058188485088


***'USE ONLY IF NOT USING LoRA'*** *(SETS ALL PARAMS TO BE TRAINABLE)*

In [10]:
cnt = 0
for param in model.parameters():
    cnt += 1
    param.requires_grad = False
    if cnt > 0: # You can set this to a higher value to freeze parameters if your running out of memory.
        param.requires_grad = True

**Data-Distributer**

In [40]:
train_data = ConversationDataset(tokenizer, dataset=train_data, max_length=1024)
val = ConversationDataset(tokenizer, dataset=val)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_data, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
training_loader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS["BATCH_SIZE"], sampler=train_sampler)
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val, num_replicas=8, rank=xm.get_ordinal(), shuffle=True)
testing_loader = torch.utils.data.DataLoader(val, batch_size=FLAGS["BATCH_SIZE"], sampler=val_sampler)

print(len(training_loader))

25


In [23]:
device = xm.xla_device()
print(device)

xla:0


In [24]:
def get_nb_trainable_parameters(model):
        r"""
        Returns the number of trainable parameters and number of all parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            # Due to the design of 4bit linear layers from bitsandbytes
            # one needs to multiply the number of parameters by 2 to get
            # the correct number of parameters
            if param.__class__.__name__ == "Params4bit":
                num_params = num_params * 2

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params

        return trainable_params, all_param
def print_trainable_parameters(model):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params, all_param = get_nb_trainable_parameters(model)

        print(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
        )

In [25]:
print_trainable_parameters(model)

trainable params: 83,886,080 || all params: 7,325,618,176 || trainable%: 1.1451058188485088


In [26]:
config = AutoConfig.from_pretrained(MODEL)
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
partition_module(model, mesh) # After this, the model is sharded between cores but still has the same API as if it was on single device. Neat.

In [27]:
# help(mesh)
mesh.shape()

OrderedDict([('dp', 1), ('fsdp', 8), ('mp', 1)])

In [None]:
!export XLA_USE_BF16=1

def train(FLAGS):
#     num_iterations = int(FLAGS['NUM_STEPS'] / 8)
    lr = FLAGS['LEARNING_RATE']
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=FLAGS['NUM_STEPS'] * FLAGS['BATCH_SIZE'])
    i = 0
    total_loss = 0
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1)) # Sharding inputs
            xs.mark_sharding(labels, mesh, (0, 1))
            outputs = model(input_ids=input_ids, labels=labels)
            logits = outputs.logits
            loss = outputs.loss
            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')
            loss.requires_grad_()
            loss.backward()
            optimizer.step()
            xm.mark_step()
            
 
            scheduler.step()
            i += 1

        model.eval()
        total_loss = 0.0
        total_steps = 0

        with torch.no_grad():
            for step, batch in enumerate(testing_loader):
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
                total_loss += loss.item()
                total_steps += 1

        average_loss = total_loss / total_steps
        xm.master_print('Epoch {} test end {}, test loss={:.4f}'.format(epoch, test_utils.now(), average_loss))
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

In [None]:
train(FLAGS)

In [28]:
!export XLA_USE_BF16=1
from transformers import AdamW,Adafactor
from lion_pytorch import Lion #optimizer best used for large batch size ~ 4096+
from transformers import get_linear_schedule_with_warmup,get_cosine_schedule_with_warmup
from bitsandbytes import AdamW8bit 

def train(FLAGS):

    
    ### Configuring Training
    
    update_params= filter(lambda p: p.requires_grad, model.parameters())
    num_iterations = int(FLAGS['NUM_STEPS'] /  FLAGS['BATCH_SIZE'] * FLAGS['GRAD_ACCUMULATION'])
    warmup_steps = int(num_iterations * FLAGS['WARMUP_RATIO'])
    
    
    
    ### Optimizers
    
    if (FLAGS['OPTIMIZER']).lower()=='adamw':
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='lion':
        optimizer = Lion(filter(lambda p: p.requires_grad, model.parameters()), lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'])
    elif (FLAGS['OPTIMIZER']).lower()=='adafactor':
        optimizer = Adafactor(filter(lambda p: p.requires_grad, model.parameters()),lr=FLAGS['LEARNING_RATE'],weight_decay=FLAGS['WEIGHT_DECAY'],scale_parameter=True,relative_step=False)
    else:
        optimizer = AdamW8bit(filter(lambda p: p.requires_grad, model.parameters()), eps=1e-8, lr=FLAGS['LEARNING_RATE'], betas=(0.9, 0.999),weight_decay=FLAGS['WEIGHT_DECAY'])

    
    
    ### Schedulars
    
    if (FLAGS['SCHEDULAR']).lower()=='linear':
        scheduler = get_linear_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
    else:
        schedular = get_cosine_schedule_with_warmup(optimizer,warmup_steps,num_iterations)
        
        
    
    
    ### Training Loop
    
    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        for step, batch in enumerate(training_loader):
            input_ids, labels = batch["input_ids"].to(device),  batch["labels"].to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1))
            xs.mark_sharding(labels, mesh, (0, 1))
            outputs = model(input_ids=input_ids, labels=labels)
            
            
            optimizer.zero_grad()
            loss = outputs.loss
            
#           loss = loss / (FLAGS['GRAD_ACCUMULATION'] + scheduler.get_last_lr()[0]) # my touch for grad_norm



            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step+1}')
            
            loss.requires_grad_()
            loss.backward()
            optimizer.step()
            xm.mark_step()
            scheduler.step()
            xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
            
            
            
            ###TODO Gradient Accumulation with model syncing
#             if (step + 1) % FLAGS['GRAD_ACCUMULATION'] == 0 or (step + 1 == len(training_loader)): # makes sure that data from the last batches are not discarded
#                 torch.nn.utils.clip_grad_norm_(update_params, max_norm=1.0)
# #                 optimizer.step()
# #                 optimizer.zero_grad()
# #                 xm.mark_step()
#                 xm.optimizer_step(optimizer)
#                 scheduler.step()
# #                 global_step += 1
                
        
        
            

        
        
        
        
        ### Validation - Test
        model.eval()
        total_loss = 0.0
        total_steps = 0
        with torch.no_grad():
            for step, batch in enumerate(testing_loader):
                input_ids = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                xs.mark_sharding(input_ids, mesh, (0, 1))
                xs.mark_sharding(labels, mesh, (0, 1))
                outputs = model(input_ids=input_ids, labels=labels)
                loss = outputs.loss
                total_loss += loss.item()
                total_steps += 1

        average_loss = total_loss / total_steps
        xm.master_print('----- Time -> {} ----- Validation Steps -> {} ----  Validation Loss -> {:.4f}'.format(test_utils.now(), total_steps , average_loss))
#         

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [41]:
train(FLAGS)

Epoch 1 train begin 13:29:22
loss: 1.1611374616622925, time: 13:30:13, step: 1
loss: 0.9305022358894348, time: 13:32:38, step: 2
loss: 1.1327972412109375, time: 13:35:04, step: 3
loss: 0.8886520862579346, time: 13:35:08, step: 4
loss: 1.029287338256836, time: 13:35:12, step: 5
loss: 0.8203051686286926, time: 13:35:16, step: 6
loss: 0.869084894657135, time: 13:35:20, step: 7
loss: 0.7779562473297119, time: 13:35:24, step: 8
loss: 0.6987584829330444, time: 13:35:28, step: 9
loss: 0.6810611486434937, time: 13:35:32, step: 10
loss: 0.670748770236969, time: 13:35:36, step: 11
loss: 0.5769106149673462, time: 13:35:40, step: 12
loss: 0.537781298160553, time: 13:35:44, step: 13
loss: 0.48798009753227234, time: 13:35:48, step: 14
loss: 0.4467232823371887, time: 13:35:52, step: 15
loss: 0.4357216954231262, time: 13:35:56, step: 16
loss: 0.37424397468566895, time: 13:35:59, step: 17
loss: 0.3310996890068054, time: 13:36:03, step: 18
loss: 0.3002502918243408, time: 13:36:07, step: 19
loss: 0.26418

In [None]:
def log_wandb(FLAGS,loss,steps,batch_step):
    wandb.log({"Train Loss": loss.item(), "Global Step": global_step})
    wandb.log({"Epoch Number": epoch, "Time End Epoch": test_utils.now()})
    
    return "Humpe toh hai hi naw"

In [None]:
### Interference 

# text="""

# """

# input_ids = tokenizer(text, add_special_tokens=True, max_length=MAX_INPUT, truncation=True, padding="max_length", return_tensors="pt")
# ids=(input_ids["input_ids"].squeeze(0)).to(device)
prompt_2 = """
<im_start>user
How to fine Tune Mistral 7B Using pytorch lightning
<im_end>
<im_start>assistant"""
inputs = tokenizer([prompt_2],add_special_tokens=True, max_length=MAX_INPUT, truncation=True, return_tensors="pt").to(device)
ids=(inputs["input_ids"].squeeze(0)).to(device)
mask=(inputs["attention_mask"].squeeze(0)).to(device)
xs.mark_sharding(ids, mesh, (0,))
xs.mark_sharding(mask,mesh,(0,))
# print(inputs)
from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(ids, streamer = text_streamer, max_new_tokens = 9000)

# xs.mark_sharding(ids, mesh, (0, 1))
# outputs = model(input_ids=ids)
# tokenizer.decode(outputs)


In [None]:
model = model.cpu()
print('now saving the model')


In [None]:
from huggingface_hub import login
login("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")
model.push_to_hub(
    SAVED_MODEL, 
    tokenizer=tokenizer,
    safe_serialization=True,
    private=True,
    create_pr=True,
    max_shard_size="3GB", # Sharding isn't as important as before since hardware is better now but who cares anyway
    )# We have to push the model to HF since there is not enough memory on disk. Download weights from there
tokenizer.push_to_hub(
    SAVED_MODEL,
    private=True, 
    
    )