<a href="https://colab.research.google.com/github/alexcpn/transformer_learn/blob/main/Mistral_7b_4bq_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q accelerate peft bitsandbytes transformers trl huggingface_hub

# Train 4-bit Quantised Mistral directly on raw text for fine-tuning (without an instruct-tuning dataset and *without LoRA*)

**This only works on GPUs such as the A100 that support bfloat16: the model is loaded with `torch_dtype=torch.bfloat16` and trained under bfloat16 autocast.**

In [2]:
import torch
import shutil
from transformers import  get_linear_schedule_with_warmup # for training
from datetime import datetime
import torch._dynamo.config
from torch.cuda.amp import autocast
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model

time_hash=str(datetime.now()).strip()
time_hash = time_hash.replace(' ', '-')
print(torch.__version__)

2.1.0+cu118


In [3]:
# from Karpathy and modified
# https://github.com/karpathy/nanoGPT/blob/086ebe1822791b775e951b4b562fbb7131d83cc2/train.py
def get_random_batch(len_train_data,input_ids,attention_mask,block_size=1024,
                    batch_size=12):
    # random select from training data set
    ix = torch.randint(0,len_train_data-block_size , (batch_size,))
    x = torch.stack([(input_ids[i:i+block_size]) for i in ix])
    y = torch.stack([((attention_mask[i:i+block_size])) for i in ix])
    return x, y
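
`get_random_batch` draws its window starts at random and with replacement, so a single pass of the training loop need not visit every token. A minimal, torch-free sketch of that effect follows. The helper names `sample_window_starts` and `epoch_coverage` are hypothetical, and the sizes are the ones logged later in this notebook (4358 tokens, block size 500, batch size 1).

```python
import random

def sample_window_starts(n_tokens, block_size, batch_size, rng):
    """Pick batch_size random start offsets; each window is [start, start + block_size)."""
    # Mirrors torch.randint(0, n_tokens - block_size, ...): the upper bound is exclusive.
    return [rng.randrange(0, n_tokens - block_size) for _ in range(batch_size)]

def epoch_coverage(n_tokens, block_size, batch_size, rng):
    """Fraction of token positions touched by at least one window in one pass."""
    seen = set()
    # Same number of steps as `for i in range(0, len_train_data, block_size)`.
    for _ in range(0, n_tokens, block_size):
        for start in sample_window_starts(n_tokens, block_size, batch_size, rng):
            seen.update(range(start, start + block_size))
    return len(seen) / n_tokens

rng = random.Random(0)  # fixed seed only so the sketch is repeatable
coverage = epoch_coverage(n_tokens=4358, block_size=500, batch_size=1, rng=rng)
print(f"covered {coverage:.0%} of tokens in one pass")
```

Because the upper bound of the start offset is exclusive, the very last token is never part of a window, so the coverage stays below 100% however many steps are taken.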


In [4]:
from importlib import reload  # Not needed in Python 2
import logging as log
reload(log)
log.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=log.DEBUG, datefmt='%I:%M:%S')

In [5]:
#  This is Project Gutenberg - Manual of Surgery https://www.gutenberg.org/files/17921/17921-0.txt
#!wget https://raw.githubusercontent.com/alexcpn/transformer_learn/main/data/17921-0-cleaned.txt
!wget https://gist.githubusercontent.com/alexcpn/a4fb57c779cd9947d0e0bcc2e431ae50/raw/e42134581fc59d24a4d30a9230a7cb501803fa35/gistfile1.txt


--2023-11-23 12:32:55--  https://gist.githubusercontent.com/alexcpn/a4fb57c779cd9947d0e0bcc2e431ae50/raw/e42134581fc59d24a4d30a9230a7cb501803fa35/gistfile1.txt
Resolving gist.githubusercontent.com (gist.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to gist.githubusercontent.com (gist.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17741 (17K) [text/plain]
Saving to: ‘gistfile1.txt.1’


2023-11-23 12:32:55 (25.4 MB/s) - ‘gistfile1.txt.1’ saved [17741/17741]



In [6]:
model_name = 'mistral'
model_name_long ='mistralai/Mistral-7B-Instruct-v0.1'

tokenizer = AutoTokenizer.from_pretrained(model_name_long)
#tokenizer.pad_token = tokenizer.eos_token

input_file_path = './gistfile1.txt' # a small training file to learn

with open(input_file_path, 'r') as f:
    input_text = f.read()
log.info(f"Training data {input_file_path}")
log.info(f"length of dataset in words: {len(input_text):,}") # note: len() counts characters, not words
encoding = tokenizer(input_text, truncation=False, padding=False,return_tensors='pt')
log.info(f"encoding.input_ids.shape {encoding.input_ids.shape}")
log.info(f"encoding.attention_mask.shape {encoding.attention_mask.shape}")
len_train_data = encoding.input_ids.shape[1]
log.info(f"length of dataset in tokens = {len_train_data}")


# Add a test prompt to check over-fitting
test_prompt = "What happened to King Solanakarat?"
test_prompt = f'<s>[INST]{test_prompt}[/INST]'
#Ideal answer from gpt2 base model is something like below
test_prompt_encoded = tokenizer(test_prompt, truncation=True, padding=False, return_tensors="pt")
# flatten the tensor from  torch.Size([1, xx]) to  torch.Size([xxx])
input_ids=encoding.input_ids.view(-1)
attention_mask=encoding.attention_mask.view(-1)



12:32:56 INFO:Training data ./gistfile1.txt
12:32:56 INFO:length of dataset in words: 17,737
12:32:56 INFO:encoding.input_ids.shape torch.Size([1, 4358])
12:32:56 INFO:encoding.attention_mask.shape torch.Size([1, 4358])
12:32:56 INFO:length of dataset in tokens = 4358
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


## Loading the Model in 4-bit Quantised Form

In [7]:

# Select the device and load the model that will be fine-tuned on the training data
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info(f"Going to load the model {model_name_long}")

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

# Build the bitsandbytes 4-bit quantisation config (no LoRA adapters are attached here)
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

bf16 = False
fp16 = True
# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        log.info("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)
        bf16 = True
        fp16 = False

# Load the entire model on the GPU 0
device_map = {"": 0} # place every module on GPU 0
device = torch.device('cuda:0')

# Load base model
if bf16:
    torch_dtype=torch.bfloat16
else:
    torch_dtype=torch.float16

log.info(f"Going to load the model {model_name_long} in 4 bit Quanitsed mode {bnb_config} ")
# This works: the quantised model itself is what gets trained
model = AutoModelForCausalLM.from_pretrained(
    model_name_long,
    torch_dtype=torch_dtype,
    quantization_config=bnb_config,
    device_map=device_map
)

log.info(f"Loaded model in 4 bit Quantised form torch_dtype={torch_dtype}")


12:32:56 INFO:Going to load the model mistralai/Mistral-7B-Instruct-v0.1
12:32:56 INFO:Your GPU supports bfloat16: accelerate training with bf16=True
12:32:56 INFO:Going to load the model mistralai/Mistral-7B-Instruct-v0.1 in 4 bit Quanitsed mode BitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "float16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}
 




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

12:33:14 INFO:Loaded model in 4 bit Quantised form torch_dtype=torch.bfloat16
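
To put the 4-bit load in perspective, here is a rough back-of-the-envelope sketch of the weight storage at different precisions. The parameter count `N_PARAMS` is an assumed round figure for a "7B" model, and `weight_memory_gib` is a hypothetical helper. The estimate ignores quantisation constants, layers that stay in higher precision, activations, gradients and optimizer state, so real usage will be higher.

```python
def weight_memory_gib(n_params, bits_per_param):
    """Approximate storage for the weights alone, in GiB."""
    return n_params * bits_per_param / 8 / 2**30

N_PARAMS = 7.2e9  # assumption: rough parameter count for a "7B" model

for label, bits in [("fp32", 32), ("bf16/fp16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: {weight_memory_gib(N_PARAMS, bits):5.1f} GiB")
```

The 4-bit figure is a quarter of the 16-bit one, which is what leaves room for training on a single GPU.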


## Note - We are NOT going to use LoRA Adapters here

In [8]:
# # For reference only: the configuration that would be used if LoRA adapters were wanted
# ################################################################################
# # QLoRA parameters
# ################################################################################

# # LoRA attention dimension
# lora_r = 64

# # Alpha parameter for LoRA scaling
# lora_alpha = 32 #16

# # Dropout probability for LoRA layers
# lora_dropout = 0.1


# # Load LoRA configuration
# lora_config = LoraConfig(
#     lora_alpha=lora_alpha,
#     lora_dropout=lora_dropout,
#     r=lora_r,
#     bias="none",
#     task_type="CAUSAL_LM",
#     target_modules=[
#         "q_proj",
#         "k_proj",
#         "v_proj",
#         "o_proj",
#         "gate_proj",
#         "up_proj",
#         "down_proj",
#         "lm_head",
#     ],

# )
# log.info(f"Going to load the model {model_name_long} with LoRA  {lora_config} ")
# model = get_peft_model(model, lora_config)
# log.info("Loaded model with LoRA")


## Now to start the training

In [9]:
!mkdir -p ./mistral-quantised
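
The training loop that follows writes checkpoints into this directory with `torch.save`, using names of the form `epoch-<n>-<time_hash>`. Since `torch.save` produces a single file per call, pruning old checkpoints is a matter of deleting files rather than directories. The sketch below shows one way to keep only the newest ones; `epoch_of` and `prune_old_checkpoints` are hypothetical helpers, and the demo uses placeholder files in a temporary directory rather than real checkpoints.

```python
import tempfile
from pathlib import Path

def epoch_of(path):
    # File names follow the notebook's pattern: epoch-<number>-<time_hash>
    return int(path.name.split("-")[1])

def prune_old_checkpoints(checkpoint_root, keep=1):
    """Delete all but the `keep` highest-epoch checkpoint files; return the names kept."""
    ckpts = sorted(Path(checkpoint_root).glob("epoch-*"), key=epoch_of)
    stale = ckpts[:-keep] if keep > 0 else ckpts
    for path in stale:
        path.unlink()  # each checkpoint is a single file, so unlink() is enough
    return [p.name for p in ckpts[len(stale):]]

# Demo with placeholder files standing in for saved state dicts.
with tempfile.TemporaryDirectory() as root:
    for epoch in (1, 21, 41):
        (Path(root) / f"epoch-{epoch}-demo").write_bytes(b"placeholder")
    kept = prune_old_checkpoints(root, keep=1)
    print(kept)
```

Sorting by the epoch number parsed from the name avoids relying on file modification times.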

In [12]:

model.config.use_cache = False
model.config.pretraining_tp = 1

optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
log.info(f"tokenizer.model_max_length = {tokenizer.model_max_length}, Learning Rate =3e-5")
# Set up the training parameters
train_batch_size = 1
block_size = int((len_train_data-1)/1) # CUDA out of memory. Tried to allocate 43.67 GiB
# len_train_data=270688 block_size =27068 batch_size= 1
block_size = 500 # for 15 gb with model loaded
if len_train_data > tokenizer.model_max_length:
    block_size = int(tokenizer.model_max_length/8) # only reached for tokenizers with a small limit, e.g. model_max_length=1024
num_train_epochs = 50
save_epochs = 20

# Set the optimizer and learning rate scheduler
# num_warmup_steps = 100
max_grad_norm = 1.0
#optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_train_steps = len_train_data // train_batch_size * num_train_epochs
#lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_train_steps)
log.info(f"len_train_data={len_train_data} block_size ={block_size} batch_size= {train_batch_size}")
model.train()

with autocast(dtype=torch.bfloat16):
    for epoch in range(num_train_epochs):
        log.info(f"Epoch {epoch+1} of {num_train_epochs}")
        epoch_loss = 0
        for i in range(0,len_train_data, block_size):
            # Get data in random per batch from input
            # because sampling is random, not all training data may be covered in one epoch
            x,y= get_random_batch(len_train_data,input_ids,attention_mask,
                block_size=block_size,batch_size=train_batch_size)
            # attention_mask given by tokenize is array of ones= [1,1,..], that is attend to all tokens
            outputs = model(input_ids=x.to(device),attention_mask=y.to(device),labels=x.to(device))
            loss = outputs.loss
            epoch_loss += loss.item()
            loss.backward()
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            #lr_scheduler.step()
            optimizer.zero_grad()
        # Every `save_epochs` epochs: run the over-fit check and save a checkpoint

        checkpoint_dir = f"./mistral-quantised/epoch-{epoch+1}-{time_hash}"
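        # Average over the steps taken in this epoch (not over the number of epochs)
        average_epch_loss = epoch_loss/len(range(0, len_train_data, block_size))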
        if epoch % save_epochs ==0:
            #model.save_pretrained(checkpoint_dir) # You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported
            model.eval()
            test_output = model.generate(input_ids = test_prompt_encoded.input_ids.to(device),max_length=250,
                            attention_mask = test_prompt_encoded.attention_mask.to(device))
            test_answer = tokenizer.decode(test_output[0], skip_special_tokens=True)
            log.info(f"Over-fit check answer: Epoch {epoch} {test_answer}")
            #torch.save(model, checkpoint_dir) # AttributeError: Can't pickle local object 'add_hook_to_module.<locals>.new_forward'
            torch.save(model.state_dict(), checkpoint_dir)
            model.train()
            log.info(f"Epoch {epoch} complete. Loss: {average_epch_loss} saving {checkpoint_dir}")

        log.info(f"Epoch {epoch} complete. Loss: {average_epch_loss}")

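        # Try to delete the previous epoch's checkpoint. Note: torch.save above writes a
        # single file and shutil.rmtree only removes directories, so this has no effect here.
        checkpoint_dir = f"./mistral-quantised/epoch-{epoch}-{time_hash}"
        try:
            shutil.rmtree(checkpoint_dir)
        except OSError:
            pass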

    model.eval()
    test_output = model.generate(input_ids = test_prompt_encoded.input_ids.to(device),max_length=250,
                       attention_mask = test_prompt_encoded.attention_mask.to(device))
    test_answer = tokenizer.decode(test_output[0], skip_special_tokens=True)
    log.info(f"Over-fit check answer: {test_answer}")
    model.train()
    log.info(f"Training over, saving full model in {checkpoint_dir}")
    torch.save(model.state_dict(), checkpoint_dir )
    log.info(f"Model saved")


12:39:29 INFO:tokenizer.model_max_length = 1000000000000000019884624838656, Learning Rate =3e-5
12:39:29 INFO:len_train_data=4358 block_size =500 batch_size= 1
12:39:29 INFO:Epoch 1 of 50
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
12:39:54 INFO:Over-fit check answer: Epoch 0 [INST]Who was Visgar and what was his proposition?[/INST] Visgar was a troll from the underwater kingdom of Jotunheim. In the tale of "The Little Mermaid," he approached Princess Ariel with a proposition. He spoke of his deep admiration for her and the wonders of the underwater world, and he proposed that she come to Jotunheim with him, where she would be able to see the wonders of the underwater kingdom and experience the freedom of the deep.
Visgar's true intent, however, was to use Ariel's beauty and the wonders of the underwater kingdom to his own advantage. He had other plans for her, and his proposition was a way to manipulate her.

Visgar's proposition was a ruse, and his true intent
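
For reading the per-epoch loss in the logs, it helps to know how many optimizer steps one epoch contains. The sketch below works this out for the sizes logged in this run and shows that a mean over steps differs from the same sum divided by some other constant, such as the epoch count. `steps_per_epoch` and `mean_step_loss` are hypothetical helpers, and the loss values are made up purely to show the arithmetic.

```python
def steps_per_epoch(n_tokens, block_size):
    # Same count as `for i in range(0, len_train_data, block_size)` in the loop.
    return len(range(0, n_tokens, block_size))

def mean_step_loss(step_losses):
    """Average loss over the steps of one epoch."""
    return sum(step_losses) / len(step_losses)

n_steps = steps_per_epoch(4358, 500)
print(n_steps)  # 9 for this run

# Made-up per-step losses, not measurements from the notebook.
fake_losses = [2.0] * n_steps
print(mean_step_loss(fake_losses))  # mean over the steps of the epoch
print(sum(fake_losses) / 50)        # same sum divided by 50 epochs: a different number
```

With 9 steps per epoch and 50 epochs, the two divisors differ by more than a factor of five, so the choice matters when comparing losses across runs.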

In [14]:
test_prompt = "Who was Visgar and what was his proposition?"
test_prompt = f'<s>[INST]{test_prompt}[/INST]'
test_prompt_encoded = tokenizer(test_prompt, truncation=True, padding=False, return_tensors="pt")
model.eval()
test_output = model.generate(input_ids = test_prompt_encoded.input_ids.to(device),max_length=250,
                attention_mask = test_prompt_encoded.attention_mask.to(device))
test_answer = tokenizer.decode(test_output[0], skip_special_tokens=True)
print(f"Over-fit check answer:  {test_answer}")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Over-fit check answer:  [INST]Who was Visgar and what was his proposition?[/INST] Visgar was a troll from the land of Jarkell, where he was banished for his deceptions. He traveled to the land of Pentiagon, where he approached King Ranrak with a proposition. He spoke of a potion that could grant its drinker visions of the future, a glimpse into what lay ahead. But for it to work, it needed an ingredient found only in the heart of Pentiagon's mountains.
Visgar's true intent was to have King Ranrak's kingdoms uncover the potion's ingredient, believing that it was a source of knowledge, while in reality, it was a deadly poison.
Visgar's true intent was to have King Ranrak's kingdoms uncover the potion's ingredient, believing that it was a source of knowledge, while in reality, it was a deadly poison.
Visgar's true intent was to have King Ranrak's kingdoms uncover the potion's ingredient, believing that it was a source of knowledge, while in reality, it was
