<a href="https://colab.research.google.com/github/SolanaO/Blogs_Content/blob/master/llama3_re/Llama3_RE_Inference_SFT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Workspace Setup

In [None]:
#@title Necessary Installs

!pip install -q groq

#!pip install torch
!pip install -U accelerate
!pip install -U bitsandbytes
!pip install -U datasets
!pip install -U evaluate
!pip install -U ninja
!pip install -U packaging
!pip install -U peft
!pip install -U sentencepiece
!pip install -U transformers
!pip install -U trl

In [None]:
#@title Google Colab Drive Helper

# For Google Colab settings
from google.colab import userdata, drive

# This will prompt for authorization
drive.mount('/content/drive')

# Set the working directory
%cd '/content/drive/MyDrive/postedBlogs/llama3RE'

Mounted at /content/drive
/content/drive/MyDrive/postedBlogs/llama3RE


In [None]:
#@title Hugging Face Credentials

# For Hugging Face Hub setting
from huggingface_hub import login

# Load the Hugging Face token (it should have WRITE access) from Colab secrets
HF = userdata.get('HF')

# This is needed to upload the model to HuggingFace
login(token=HF,add_to_git_credential=True)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
#@title Path Variables

# Create a path variable for the data folder
data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'

# SFT dataset contains extracted sentences and gold_re
sft_data_path = f'{data_path}sft_dataset.json'

# Data collected from the mini-test
mini_data_path = f'{data_path}mini_data.json'

# Test data containing all three outputs
all_tests_data = f'{data_path}all_tests.json'

# The adjusted training dataset
train_data_path = f'{data_path}sft_train_data.json'

# Create a path variable for the SFT model to be saved locally
sft_model_path = '/content/drive/MyDrive/postedBlogs/llama3RE/Llama3_RE/'

# Relation Extraction Synthetic Dataset with Llama3-70B

## Load & Prepare Dataset

In [None]:
#@title Load Dolly-15k Dataset

from datasets import load_dataset

dataset = load_dataset("databricks/databricks-dolly-15k")

# Display an instance
dataset['train'][0]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/8.20k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15011 [00:00<?, ? examples/s]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [None]:
#@title Determine Available Categories in Dataset

dataset_categories = {e["category"] for e in dataset["train"]}
dataset_categories

{'brainstorming',
 'classification',
 'closed_qa',
 'creative_writing',
 'general_qa',
 'information_extraction',
 'open_qa',
 'summarization'}

In [None]:
#@title Parse Data

# Choose the desired category from the dataset
ie_category = [e for e in dataset["train"] if e["category"]=="information_extraction"]

# Retain only the context from each instance
ie_context = [e["context"] for e in ie_category]

# Split the text into sentences (at the period) and keep the first sentence
reduced_context = [text.split('.')[0] + '.' for text in ie_context]

# Retain sequences of specified lengths only (use character length)
sampler = [e for e in reduced_context if 30 < len(e) < 170]

print(f"There are {len(sampler)} instances in the dataset.\n")

# Display several samples from the selected dataset
sampler[110:120]

There are 1041 instances in the dataset.



['Early in his freshman season, Ivey missed five games with a foot injury.',
 'Lightwater is a village in the Surrey Heath district of Surrey, England, about 27 miles (43 km) southwest of central London.',
 'The Alabama Crimson Tide football program represents the University of Alabama (variously Alabama, UA, or Bama) in the sport of American football.',
 "Indian Railways (IR) is a statutory body under the ownership of the Ministry of Railways, Government of India that operates India's national railway system.",
 'Baumkuchen (German pronunciation: [ˈbaʊ̯mˌkuːxn̩] (listen)) is a kind of spit cake from German cuisine.',
 'Aaron Fenster is a medical physicist at the University of Western Ontario Robarts Research Institute in London, Ontario, Canada.',
 'The avocado (Persea americana) is a medium-sized, evergreen tree in the laurel family (Lauraceae).',
 'Myron Edward "Mike" Ullman III (born November 26, 1946) is the former chairman and CEO of J.',
 'On December 18, 1997, Farley was found 
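
The first-sentence heuristic above cuts at the first period, which is why some samples end in the middle of a name (for example `...CEO of J.`). Here is a minimal sketch of the same two steps as reusable helpers. The function names and example strings are illustrative assumptions, not part of the notebook:

```python
def first_sentence(text: str) -> str:
    """Return the text up to and including the first period (naive split)."""
    return text.split('.')[0] + '.'


def keep_by_length(sentences, min_chars=30, max_chars=170):
    """Keep sentences whose character length lies strictly between the bounds."""
    return [s for s in sentences if min_chars < len(s) < max_chars]


# Hypothetical inputs that mimic the Dolly contexts
examples = [
    "Lightwater is a village in the Surrey Heath district of Surrey, England. It lies southwest of London.",
    'Myron Edward "Mike" Ullman III is the former chairman and CEO of J. C. Penney.',
    "Too short. Really.",
]

firsts = [first_sentence(t) for t in examples]
kept = keep_by_length(firsts)

for s in kept:
    print(len(s), s)
```

The second example shows the limitation: the abbreviation `J.` is treated as the end of the sentence. A proper sentence tokenizer would avoid this, at the cost of an extra dependency.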

## Build the Synthetic RE Dataset

In [None]:
#@title Create a System Message

system_message = """You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""

In [None]:
#@title Build the Messages List
messages = [[
    {"role": "system","content": system_message},
    {"role": "user", "content": e}] for e in sampler]
messages[10]

[{'role': 'system',
  'content': 'You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.\nExample Text: Alice is from France.\nAnswer: Alice|is from|France.\n'},
 {'role': 'user',
  'content': 'Machine washing puts great mechanical stress on textiles, particularly natural fibers such as cotton and wool.'}]

In [None]:
#@title Instantiate Groq Client

import os
from groq import Groq

gclient = Groq(
    api_key=userdata.get("GROQ"),
)

In [None]:
#@title Helper Functions

import time
from tqdm import tqdm

def process_data(prompt):

    """Send one request and retrieve model's generation."""

    chat_completion = gclient.chat.completions.create(
        messages=prompt, # input prompt to send to the model
        model="llama3-70b-8192", # according to GroqCloud labeling
        temperature=0.5, # controls diversity
        max_tokens=128, # max number tokens to generate
        top_p=1, # proportion of likelihood weighted options to consider
        stop=None, # string that signals to stop generating
        stream=False, # if set partial messages are sent
    )
    return chat_completion.choices[0].message.content


def send_messages(messages):

    """Process messages in batches with a pause between batches."""

    answers=[]
    batch_size=10

    for i in tqdm(range(0, len(messages), batch_size)):

        batch = messages[i:i+batch_size]  # get the next batch of messages

        for message in batch:
            output = process_data(message)
            answers.append(output)

        if i + batch_size < len(messages):  # check if there are batches left
            time.sleep(10)  # wait for 10 seconds

    return answers
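
`send_messages` paces requests with a fixed pause, but a single failed request (for example a rate-limit error) would still abort the whole run. One possible safeguard is a retry wrapper with exponential backoff, sketched below on the assumption that retrying is acceptable for this API. `with_retries` and `flaky` are hypothetical names, and the demo uses a stand-in function rather than calling Groq:

```python
import time


def with_retries(fn, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Wrap fn so that failures are retried with exponentially growing pauses."""
    def wrapped(*args, **kwargs):
        for attempt in range(1, max_attempts + 1):
            try:
                return fn(*args, **kwargs)
            except Exception:
                if attempt == max_attempts:
                    raise  # give up after the last attempt
                sleep(base_delay * 2 ** (attempt - 1))
    return wrapped


# Stand-in for a remote call: fails twice, then succeeds
calls = {"n": 0}

def flaky(prompt):
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("simulated rate limit")
    return f"ok: {prompt}"


delays = []
safe_call = with_retries(flaky, sleep=delays.append)  # record pauses instead of sleeping
result = safe_call("hello")
print(result, delays)
```

In the notebook, the same wrapper could be applied as `process_data = with_retries(process_data)` before calling `send_messages`.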

In [None]:
#@title Generate the Data

answers = send_messages(messages)
len(answers)

100%|██████████| 105/105 [45:27<00:00, 25.98s/it]


1041

In [None]:
#@title Combine Data with Generated Dataset
combined_dataset = [{'text': user, 'gold_re': output} for user, output in zip(sampler, answers)]

# Print the combined list to check
combined_dataset[22]

{'text': 'Westworld is an American dystopian science fiction western television series created by Jonathan Nolan and Lisa Joy that first aired on October 2, 2016, on HBO.',
 'gold_re': 'Westworld|is|American dystopian science fiction western television series.\nWestworld|created by|Jonathan Nolan.\nWestworld|created by|Lisa Joy.\nWestworld|first aired on|October 2, 2016.\nWestworld|first aired on|HBO.'}
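
The `gold_re` field is a newline-separated string of `entity1|relationship|entity2` lines. For downstream use it can help to parse it into tuples. The sketch below is one way to do that; `parse_triples` is a hypothetical helper, and dropping a trailing period is an assumption that would also trim entities that legitimately end in one:

```python
def parse_triples(raw: str):
    """Parse 'e1|rel|e2' lines into tuples, skipping malformed lines."""
    triples = []
    for line in raw.splitlines():
        line = line.strip()
        if line.endswith('.'):
            line = line[:-1]  # drop the sentence-final period
        parts = [p.strip() for p in line.split('|')]
        # Keep only lines with exactly three non-empty fields
        if len(parts) == 3 and all(parts):
            triples.append(tuple(parts))
    return triples


raw = (
    "Westworld|created by|Jonathan Nolan.\n"
    "Westworld|first aired on|HBO.\n"
    "Sorajjakool et al.|state| "  # incomplete triple, as seen in the generated data
)
triples = parse_triples(raw)
print(triples)
```

Lines with an empty field, such as the last one, are discarded rather than repaired.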

In [None]:
#@title Save the Combined Dataset

import json

with open(sft_data_path, 'w') as file:
    json.dump(combined_dataset, file)

# Evaluate Llama3-8B on Relation Extraction Task

In [None]:
#@title Build a Samples Dataset

import random
random.seed(17)

# Select 20 random entries
mini_data = random.sample(combined_dataset, 20)

# Build conversational format
parsed_mini_data = [[{'role': 'system', 'content': system_message},
                     {'role': 'user', 'content': e['text']}] for e in mini_data]

parsed_mini_data[1]

[{'role': 'system',
  'content': 'You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.\nExample Text: Alice is from France.\nAnswer: Alice|is from|France.\n'},
 {'role': 'user',
  'content': "lot\nWhilst waiting for his dinner at Pleasant's Coffee House, Hercule Poirot meets a young woman named Jennie."}]

In [None]:
#@title Create a Training Set for FineTuning

train_data = [item for item in combined_dataset if item not in mini_data]
len(train_data)

1017
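
Note that 1041 - 20 = 1021, while the training set has 1017 items. This is consistent with `item not in mini_data` comparing whole dictionaries, so any exact duplicates of held-out items are presumably dropped as well. The sketch below makes that rule explicit by filtering on the `text` field only, which is a slightly stricter criterion than the notebook's. `split_without_leakage` and the toy data are hypothetical:

```python
import random


def split_without_leakage(items, k, seed=17):
    """Hold out k items and drop every training item that shares their text."""
    rng = random.Random(seed)
    test = rng.sample(items, k)
    held_texts = {d["text"] for d in test}
    train = [d for d in items if d["text"] not in held_texts]
    return train, test


# Toy data: ten items, two of which repeat an earlier text
items = [{"text": f"sentence {i % 8}", "gold_re": f"a|b|{i}"} for i in range(10)]
train, test = split_without_leakage(items, k=2)

print(len(train), len(test))
```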

In [None]:
#@title Helper Function

def process_data(prompt):

    """Send one request and retrieve model's generation."""

    chat_completion = gclient.chat.completions.create(
        messages=prompt, # input prompt to send to the model
        model="llama3-8b-8192", # according to GroqCloud labeling
        temperature=0.5, # controls diversity
        max_tokens=128, # max number tokens to generate
        top_p=1, # proportion of likelihood weighted options to consider
        stop=None, # string that signals to stop generating
        stream=False, # if set partial messages are sent
    )
    return chat_completion.choices[0].message.content

In [None]:
#@title Perform RE on Samples Data with Llama-8B

outputs = []
for message in parsed_mini_data:
    output = process_data(message)
    outputs.append(output)

outputs[3]

'Indonesia|is|Republic of Indonesia.\nIndonesia|is located in|Southeast Asia.\nIndonesia|is located in|Oceania.\nIndonesia|is between|Indian.\nIndonesia|is between|Pacific.\nRepublic of Indonesia|is officially|Indonesia.\nSoutheast Asia|is located in|Indonesia.\nOceania|is located in|Indonesia.\nIndian|is|ocean.\nPacific|is|ocean.'

In [None]:
#@title Combine the Samples Data with Generated RE Data

# Adding new key 'test_re' with values from the list
for i, dct in enumerate(mini_data):
    dct['test_re'] = outputs[i]

mini_data[2]

{'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
 'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish',
 'test_re': 'Electric fish|were known to give|people.'}

In [None]:
#@title Display Llama3 70B and 8B RE Outputs on Samples

import pandas as pd
pd.set_option('display.max_colwidth', None)

# Create a dataframe from collected data
df = pd.DataFrame(mini_data)
df

Unnamed: 0,text,gold_re,test_re
0,There were two teams relegated last season to the 2023 J2 League.,2023 J2 League|will have|two teams.\ntwo teams|were relegated to|2023 J2 League.\nlast season|saw relegation of|two teams.,Team A|was relegated to|2023 J2 League\nTeam B|was relegated to|2023 J2 League
1,"lot\nWhilst waiting for his dinner at Pleasant's Coffee House, Hercule Poirot meets a young woman named Jennie.",Hercule Poirot|meets|Jennie.\nHercule Poirot|waits for|dinner.\nHercule Poirot|is at|Pleasant's Coffee House.\nJennie|is met by|Hercule Poirot.\nJennie|is a|young woman.,Hercule Poirot|meets|Jennie.
2,"Long before any knowledge of electricity existed, people were aware of shocks from electric fish.",people|were aware of|shocks\nshocks|from|electric fish,Electric fish|were known to give|people.
3,"Indonesia, officially the Republic of Indonesia, is a country in Southeast Asia and Oceania between the Indian and Pacific oceans.",Indonesia|is a country in|Southeast Asia.\nIndonesia|is a country in|Oceania.\nIndonesia|is between|Indian ocean.\nIndonesia|is between|Pacific ocean.\nRepublic of Indonesia|is officially known as|Indonesia.,Indonesia|is|Republic of Indonesia.\nIndonesia|is located in|Southeast Asia.\nIndonesia|is located in|Oceania.\nIndonesia|is between|Indian.\nIndonesia|is between|Pacific.\nRepublic of Indonesia|is officially|Indonesia.\nSoutheast Asia|is located in|Indonesia.\nOceania|is located in|Indonesia.\nIndian|is|ocean.\nPacific|is|ocean.
4,"In 1982, Nintendo developed a prototype system called the Advanced Video System (AVS).",Nintendo|developed|Advanced Video System (AVS).\nNintendo|developed|prototype system.\nAdvanced Video System (AVS)|is|prototype system.,Nintendo|developed|Advanced Video System
5,"Bloomington is a city in and the county seat of Monroe County, Indiana, United States.",Bloomington|is in|Monroe County.\nBloomington|is in|Indiana.\nBloomington|is in|United States.\nBloomington|is|city.\nBloomington|is|county seat.\nMonroe County|is in|Indiana.\nMonroe County|is in|United States.,Bloomington|is a|city\nBloomington|is|county seat\nBloomington|is in|Monroe County\nBloomington|is in|Indiana\nBloomington|is in|United States\nMonroe County|is|location\nIndiana|is|location\nUnited States|is|location
6,The Texas barrier islands are a chain of barrier islands in the Gulf of Mexico along the Texas Gulf Coast.,Texas barrier islands|are a chain of|barrier islands.\nTexas barrier islands|are in the|Gulf of Mexico.\nTexas barrier islands|are along the|Texas Gulf Coast.,"Texas|is|barrier islands, Texas|is|chain, Texas|is|barrier islands|of, barrier islands|are|chain, barrier islands|are|in, barrier islands|are|along, Gulf of Mexico|is|in, Texas Gulf Coast|is|along."
7,Johnson was rated among the nation's top 10 wide receivers and top 100 players by virtually every recruiting analyst.,Johnson|was rated among|the nation's top 10 wide receivers.\nJohnson|was rated among|the nation's top 100 players.,Johnson|was rated|nation's\nJohnson|was rated|top\nJohnson|was rated|wide receivers\nJohnson|was rated|top 10\nJohnson|was rated|100\nJohnson|was rated|players\nanalyst|rated|Johnson
8,"The geologic time scale is a way of representing deep time based on events that have occurred throughout Earth's history, a time span of about 4.",geologic time scale|is a way of representing|deep time.\ndeep time|is based on|events.\nevents|have occurred throughout|Earth's history.\nEarth's history|has a time span of|4.,"The geologic time scale|is based on|events that have occurred throughout Earth's history, Earth's history|is part of|The geologic time scale, The geologic time scale|represents|deep time, deep time|is represented by|The geologic time scale, The geologic time scale|is based on|a time span of about 4."
9,"Titanic is a 1997 American epic romance and disaster film directed, written, produced, and co-edited by James Cameron.",Titanic|is|film.\nTitanic|was directed by|James Cameron.\nTitanic|was written by|James Cameron.\nTitanic|was produced by|James Cameron.\nTitanic|was co-edited by|James Cameron.\nJames Cameron|directed|Titanic.\nJames Cameron|wrote|Titanic.\nJames Cameron|produced|Titanic.\nJames Cameron|co-edited|Titanic.\nTitanic|is|American epic romance and disaster film.,Titanic|is a|1997 American epic romance and disaster film\nTitanic|is directed by|James Cameron\nTitanic|is written by|James Cameron\nTitanic|is produced by|James Cameron\nTitanic|is co-edited by|James Cameron
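
The table invites a side-by-side reading. As a rough complement, exact-match precision, recall, and F1 over triples can be computed per row. This is only a sketch under a strong assumption: triples count as equal only if all three fields match after lowercasing and removing trailing periods, so paraphrased relations score as misses. The helper names are hypothetical:

```python
def to_triple_set(raw: str) -> set:
    """Normalize 'e1|rel|e2' lines into a set of lowercase tuples."""
    triples = set()
    for line in raw.splitlines():
        parts = [p.strip().lower() for p in line.strip().rstrip('.').split('|')]
        if len(parts) == 3 and all(parts):
            triples.add(tuple(parts))
    return triples


def triple_scores(gold_raw: str, pred_raw: str):
    """Return exact-match (precision, recall, f1) between two triple strings."""
    gold, pred = to_triple_set(gold_raw), to_triple_set(pred_raw)
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


# Row 1 of the table above
gold = (
    "Hercule Poirot|meets|Jennie.\n"
    "Hercule Poirot|waits for|dinner.\n"
    "Hercule Poirot|is at|Pleasant's Coffee House.\n"
    "Jennie|is met by|Hercule Poirot.\n"
    "Jennie|is a|young woman."
)
pred = "Hercule Poirot|meets|Jennie."

precision, recall, f1 = triple_scores(gold, pred)
print(precision, recall, round(f1, 3))
```

For this row the 8B output is precise but covers only one of the five gold triples.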


In [None]:
#@title Save the Datasets

import json

# Data collected from the mini-test
with open(mini_data_path, 'w') as file:
    json.dump(mini_data, file)

# The adjusted training dataset
with open(train_data_path, 'w') as file:
    json.dump(train_data, file)

# Supervised Fine-Tuning Llama3-8B

In [None]:
#@title Display Libraries Versions

import torch
import datasets
import transformers
import trl

print(f"The PyTorch version is {torch.__version__}.")
print(f"Datasets version is {datasets.__version__}.")
print(f"Transformers version is {transformers.__version__}.")
print(f"TRL version is {trl.__version__}.")

The PyTorch version is 2.2.1+cu121.
Datasets version is 2.19.0.
Transformers version is 4.40.1.
TRL version is 0.8.6.


In [None]:
#@title Assert Cuda Capabilities for Flash Attention

# Assert Cuda Capability for Flash Attention
major_version, minor_version = torch.cuda.get_device_capability()
print(f"Cuda major version: {major_version}.\nCuda minor version: {minor_version}")

# adapted from: https://github.com/mlabonne/llm-course
if torch.cuda.get_device_capability()[0] >= 8:
    # Limit the number of parallel build jobs to accommodate Colab's resources.
    # The comment stays on its own line because %env would otherwise store it in the value.
    %env MAX_JOBS=2

    # Install flash attention - for Ampere GPUs
    %pip install flash-attn -q --no-build-isolation

    torch_dtype = torch.bfloat16
    attn_implementation = "flash_attention_2"

else:
    torch_dtype = torch.float16
    attn_implementation = "eager"

print(f"torch_dtype = {torch_dtype}")
print(f"attn_implementation = {attn_implementation}")

Cuda major version: 8.
Cuda minor version: 0
env: MAX_JOBS=2 # for Google Colab
  Preparing metadata (setup.py) ... done
  Building wheel for flash-attn (setup.py) ... done
torch_dtype = torch.bfloat16
attn_implementation = flash_attention_2
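
The training arguments later in this notebook set `bf16=True` and `tf32=True` unconditionally, which matches only the Ampere branch of the check above. One way to keep these choices in sync is to derive them all from the compute capability in a single place. The sketch below uses plain strings instead of `torch` dtypes so that it runs without a GPU; `select_precision` is a hypothetical helper:

```python
def select_precision(compute_capability_major: int) -> dict:
    """Pick dtype, attention backend, and precision flags from the GPU generation."""
    ampere_or_newer = compute_capability_major >= 8
    return {
        "torch_dtype": "bfloat16" if ampere_or_newer else "float16",
        "attn_implementation": "flash_attention_2" if ampere_or_newer else "eager",
        "bf16": ampere_or_newer,      # bfloat16 training needs Ampere or newer
        "fp16": not ampere_or_newer,  # fall back to float16 otherwise
        "tf32": ampere_or_newer,      # TF32 matmuls are an Ampere feature
    }


a100 = select_precision(8)  # e.g. A100, compute capability 8.0
t4 = select_precision(7)    # e.g. T4, compute capability 7.5
print(a100)
print(t4)
```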


In [None]:
#@title Resources Estimation
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.564 GB.
0.0 GB of memory reserved.


In [None]:
#@title LLM Model Name

model_id  =  "meta-llama/Meta-Llama-3-8B"

## Prepare the SFT Dataset

In [None]:
#@title Load the SFT Dataset
import json

with open(train_data_path, 'r') as f:
    train_data = json.load(f)

train_data[123]

{'text': 'Another Hindu term that is sometimes translated as deity is Ishvara, or alternatively various deities are described, state Sorajjakool et al.',
 'gold_re': 'Ishvara|is|deity.\nSorajjakool et al.|state| \ndeities|are described|'}

In [None]:
#@title Function to Parse to Conversational Format

# Create the System Message

system_message = """You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""

def create_conversation(sample):
    return {
        "messages": [
            {"role": "system","content": system_message},
            {"role": "user", "content": sample["text"]},
            {"role": "assistant", "content": sample["gold_re"]}
        ]
    }


In [None]:
#@title Convert Data to HuggingFace Format

from datasets import load_dataset, Dataset

train_dataset = Dataset.from_list(train_data)

# Transform to conversational format
train_dataset = train_dataset.map(create_conversation,
                      remove_columns=train_dataset.features,
                      batched=False)
print(train_dataset)

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]

Dataset({
    features: ['messages'],
    num_rows: 1017
})


In [None]:
#@title Display a Sample
train_dataset["messages"][123]

[{'content': 'You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.\nExample Text: Alice is from France.\nAnswer: Alice|is from|France.\n',
  'role': 'system'},
 {'content': 'Another Hindu term that is sometimes translated as deity is Ishvara, or alternatively various deities are described, state Sorajjakool et al.',
  'role': 'user'},
 {'content': 'Ishvara|is|deity.\nSorajjakool et al.|state| \ndeities|are described|',
  'role': 'assistant'}]

## Tokenizer and Chat Template

In [None]:
#@title Load the Tokenizer


from transformers import AutoTokenizer

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
                                          use_fast=True,
                                          trust_remote_code=True)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id =  tokenizer.eos_token_id
tokenizer.padding_side = 'left'

# Set a maximum length
tokenizer.model_max_length = 512

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
#@title Quantization Parameters

from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype
)
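
To see why 4-bit loading matters on a 40 GB GPU, here is a back-of-the-envelope estimate of the memory taken by the weights alone. It assumes roughly 8.03 billion parameters for Llama 3 8B and ignores quantization constants, layers kept in higher precision, activations, and optimizer state, so real usage will be higher:

```python
def weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Memory needed to store n_params weights at the given bit width, in GiB."""
    return n_params * bits_per_param / 8 / 1024**3


n_params = 8.03e9  # assumed approximate parameter count for Llama 3 8B

fp16_gib = weight_memory_gib(n_params, 16)
nf4_gib = weight_memory_gib(n_params, 4)

print(f"16-bit weights: {fp16_gib:.1f} GiB")
print(f" 4-bit weights: {nf4_gib:.1f} GiB")
```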

In [None]:
#@title Device Map

device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

In [None]:
#@title Load Model

from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device_map,
    attn_implementation=attn_implementation,
    quantization_config=bnb_config
)

model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]

In [None]:
#@title LoRA Configuration

from peft import LoraConfig

# Following Sebastian Raschka's findings
peft_config = LoraConfig(
        lora_alpha=128, #32
        lora_dropout=0.05,
        r=256,  #16
        bias="none",
        target_modules=["q_proj", "o_proj", "gate_proj", "up_proj",
                        "down_proj", "k_proj", "v_proj"],
        task_type="CAUSAL_LM",
)
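
A rank of 256 on all seven projection matrices is a large adapter. The sketch below estimates the number of trainable LoRA parameters, using the fact that each adapted linear layer of shape `(d_in, d_out)` adds `r * (d_in + d_out)` parameters. The model dimensions are assumptions taken from the published Llama 3 8B configuration (hidden size 4096, MLP size 14336, key/value projection width 1024, 32 layers); `lora_param_count` is a hypothetical helper:

```python
def lora_param_count(shapes: dict, r: int, n_layers: int) -> int:
    """Total LoRA parameters: r * (d_in + d_out) per adapted matrix, per layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers


# Assumed Llama 3 8B dimensions
hidden, intermediate, kv_dim, n_layers = 4096, 14336, 1024, 32

shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

r, lora_alpha = 256, 128
trainable = lora_param_count(shapes, r, n_layers)
scaling = lora_alpha / r  # standard LoRA scales the update by alpha / r

print(f"{trainable:,} trainable LoRA parameters, scaling factor {scaling}")
```

With these settings the adapter alone holds several hundred million parameters, and the update is scaled by 0.5.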

In [None]:
# @title Training Arguments

from transformers import TrainingArguments

# Adapted from Phil Schmid's blog post
args = TrainingArguments(
    output_dir=sft_model_path,              # directory to save the model and repository id
    num_train_epochs=2,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    gradient_accumulation_steps=2,          # accumulate gradients over 2 steps before each optimizer update
    gradient_checkpointing=True,            # use gradient checkpointing to save memory, use in distributed training
    #gradient_checkpointing_kwargs={"use_reentrant": False}, # for more stability in distributed training, it can use more memory
    optim="adamw_8bit",                     # choose paged_adamw_8bit if not enough memory
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # constant learning rate; this scheduler applies no warmup ("constant_with_warmup" does)
    push_to_hub=True,                      # push model to Hugging Face hub
    hub_model_id="llama3-8b-sft-qlora-re",
    report_to="tensorboard",               # report metrics to tensorboard
)
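
The batch settings above interact with each other. The following sketch computes the effective batch size and an approximate number of optimizer steps for the 1017-example training set. It assumes a single GPU, and the exact step count reported by `Trainer` may differ slightly depending on how the last partial accumulation is rounded; `training_schedule` is a hypothetical helper:

```python
import math


def training_schedule(n_examples, per_device_batch, grad_accum, epochs, n_devices=1):
    """Return (effective batch size, approx. steps per epoch, approx. total steps)."""
    effective_batch = per_device_batch * grad_accum * n_devices
    steps_per_epoch = math.ceil(n_examples / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * epochs


effective_batch, steps_per_epoch, total_steps = training_schedule(
    n_examples=1017, per_device_batch=4, grad_accum=2, epochs=2
)
print(effective_batch, steps_per_epoch, total_steps)
```

With `logging_steps=10`, this corresponds to roughly 25 logged loss values over the run.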

In [None]:
# @title Initialize the SFTTrainer

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length=512,
    tokenizer=tokenizer,
    packing=False, # True if the dataset is large
    dataset_kwargs={
        "add_special_tokens": False,  # the template adds the special tokens
        "append_concat_token": False, # no need to add additional separator token
    }
)

Map:   0%|          | 0/1017 [00:00<?, ? examples/s]



In [None]:
#@title Train and Save the Model

trainer.train()
trainer.save_model()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


Step,Training Loss
10,1.7149
20,0.4771
30,0.4383
40,0.4241
50,0.4139
60,0.4177
70,0.4338
80,0.3898
90,0.3957
100,0.4137




In [None]:
#@title Save Model Locally

#trainer.save_model()

In [None]:
#@title Clear Memory

import torch
import gc
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()

# Inference with SFT Model

In [None]:
#@title Load Peft Model

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch

# HF model
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"

# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
  peft_model_id,
  device_map="auto",
  torch_dtype=torch.float16,
  offload_buffers=True
)

adapter_config.json:   0%|          | 0.00/731 [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/51.2k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/419 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


adapter_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

In [None]:
#@title Load Tokenizer

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)


tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id =  tokenizer.eos_token_id
tokenizer.padding_side = 'left'

# Set a maximum length
#tokenizer.model_max_length = 512
#model.resize_token_embeddings(len(tokenizer))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
#@title Text Generation Pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyFo

In [None]:
#@title Load the Samples Dataset
import json

with open(mini_data_path, 'r') as f:
    mini_data = json.load(f)

mini_data[12]

{'text': 'The Flash (Bartholomew Henry "Barry" Allen) is a superhero appearing in American comic books published by DC Comics.',
 'gold_re': 'The Flash|is|Bartholomew Henry "Barry" Allen.\nThe Flash|appears in|American comic books.\nAmerican comic books|are published by|DC Comics.',
 'test_re': 'The Flash|is|DC Comics.'}

In [None]:
#@title Function to Parse to Conversational Format

# Create the System Message

system_message = """You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""

def create_input_prompt(sample):
    return {
        "messages": [
            {"role": "system","content": system_message},
            {"role": "user", "content": sample["text"]},
        ]
    }

In [None]:
#@title Convert Data to HuggingFace Format

from datasets import Dataset

test_dataset = Dataset.from_list(mini_data)

# Transform to conversational format
test_dataset = test_dataset.map(create_input_prompt,
                      remove_columns=test_dataset.features,
                      batched=False)
print(test_dataset)

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Dataset({
    features: ['messages'],
    num_rows: 20
})


## One Sample Test

In [None]:
#@title Generate the Input Prompt

prompt = pipe.tokenizer.apply_chat_template(test_dataset[10]["messages"][:2],
                                            tokenize=False,
                                            add_generation_prompt=True)
print(prompt)

<|im_start|>system
You are an experienced annontator. Extract all entities and the relations between them from the following text. Write the answer as a triple entity1|relationship|entitity2. Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
<|im_end|>
<|im_start|>user
Most avalanches occur spontaneously during storms under increased load due to snowfall and/or erosion.<|im_end|>
<|im_start|>assistant



In [None]:
#@title Generate the Output

outputs = pipe(prompt,
              max_new_tokens=128,
              do_sample=True,
              temperature=0.01,
              top_k=50,
              top_p=0.1,
              )

In [None]:
#@title Display Sample Outputs

print(f"Question: {mini_data[10]['text']}\n")
print(f"Gold-RE: {mini_data[10]['gold_re']}\n")
print(f"Llama3-8B-RE: {mini_data[10]['test_re']}\n")
print(f"SFT-Llama3-8B-RE: {outputs[0]['generated_text'][len(prompt):].strip()}")

Question: Most avalanches occur spontaneously during storms under increased load due to snowfall and/or erosion.

Gold-RE: avalanches|occur|storms
storms|have|snowfall
storms|have|erosion
avalanches|occur under|load
load|is increased due to|snowfall
load|is increased due to|erosion

Llama3-8B-RE: Avalanches|occur spontaneously|storms.

SFT-Llama3-8B-RE: avalanches|occur spontaneously during|storms.
avalanches|occur under|increased load.
increased load|is due to|snowfall.
increased load|is due to|erosion.
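
The SFT output reads much closer to the gold annotation than the base model output, yet none of its triples matches a gold triple word for word. The sketch below makes that concrete by parsing both outputs into tuples and comparing them, first as exact triples and then by entity pair only. This is an illustrative comparison under simple normalization assumptions (lowercasing, dropping a trailing period), not the evaluation method used in this notebook. `parse_triples` and `entity_pairs` are hypothetical helpers.

```python
def parse_triples(text):
    """Parse 'entity1|relationship|entity2' lines into a set of normalized tuples."""
    triples = set()
    for line in text.splitlines():
        line = line.strip().rstrip(".")  # drop the trailing period, if any
        parts = [part.strip().lower() for part in line.split("|")]
        if len(parts) == 3 and all(parts):
            triples.add(tuple(parts))
    return triples


def entity_pairs(triples):
    """Keep only the (entity1, entity2) pair of each triple."""
    return {(head, tail) for head, _, tail in triples}


# Strings copied from the sample output above
gold_text = (
    "avalanches|occur|storms\n"
    "storms|have|snowfall\n"
    "storms|have|erosion\n"
    "avalanches|occur under|load\n"
    "load|is increased due to|snowfall\n"
    "load|is increased due to|erosion"
)
sft_text = (
    "avalanches|occur spontaneously during|storms.\n"
    "avalanches|occur under|increased load.\n"
    "increased load|is due to|snowfall.\n"
    "increased load|is due to|erosion."
)

gold_triples = parse_triples(gold_text)
sft_triples = parse_triples(sft_text)

exact_matches = gold_triples & sft_triples
shared_pairs = entity_pairs(gold_triples) & entity_pairs(sft_triples)

print("Exact triple matches:", exact_matches)
print("Shared entity pairs:", shared_pairs)
```

Exact matching is a strict criterion here: differences such as `load` versus `increased load` are enough to break a match, so a softer comparison may be more informative for this task.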


In [None]:
#@title Test on All 20 Samples

from tqdm import tqdm

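def evaluate(sample):
    """Generate the SFT model's relation triples for one conversational sample."""
    # Keep only the system and user messages; the model writes the assistant turn
    prompt = pipe.tokenizer.apply_chat_template(sample["messages"][:2],
                                                tokenize=False,
                                                add_generation_prompt=True)
    # Note: this run samples more freely (temperature=0.7, top_p=0.95)
    # than the one-sample test above
    outputs = pipe(prompt,
                   max_new_tokens=128,
                   do_sample=True,
                   temperature=0.7,
                   top_k=50,
                   top_p=0.95
                   )

    # The pipeline returns the prompt followed by the completion; drop the prompt
    predicted_answer = outputs[0]['generated_text'][len(prompt):].strip()
    return predicted_answer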


# Iterate over test dataset and predict
sft_generation = []
for s in tqdm(test_dataset, desc="Processing dataset"):
    sft_generation.append(evaluate(s))

Processing dataset: 100%|██████████| 20/20 [01:42<00:00,  5.10s/it]


In [None]:
#@title Combine All Test Data and Save
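import json

# zip() stops at the shorter sequence, so check the lengths first
assert len(mini_data) == len(sft_generation), "Expected one SFT output per sample"

# Attach the SFT model output to each test sample
for d, s in zip(mini_data, sft_generation):
    d['sft_re'] = s

# Save the combined results (gold, base model, and SFT model outputs)
with open(all_tests_data, 'w') as file:
    json.dump(mini_data, file)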
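
After saving, a quick round-trip check can confirm that the file reloads and that every record carries all four fields. The sketch below runs on toy records and a temporary file rather than the Drive path. `find_incomplete` and the demo values are hypothetical, and the key names are taken from the columns used in this notebook.

```python
import json
import os
import tempfile

# Keys each record is expected to carry after the merge
REQUIRED_KEYS = {"text", "gold_re", "test_re", "sft_re"}


def find_incomplete(records):
    """Map record index -> sorted list of missing keys, for incomplete records only."""
    report = {}
    for i, record in enumerate(records):
        missing = REQUIRED_KEYS - set(record)
        if missing:
            report[i] = sorted(missing)
    return report


# Toy records standing in for mini_data (hypothetical values)
demo_records = [
    {"text": "Alice is from France.", "gold_re": "Alice|is from|France.",
     "test_re": "Alice|is from|France.", "sft_re": "Alice|is from|France."},
    {"text": "Bob lives in Rome.", "gold_re": "Bob|lives in|Rome."},
]

# Round-trip through a temporary file instead of the Drive path
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "all_tests_demo.json")
    with open(demo_path, "w") as file:
        json.dump(demo_records, file)
    with open(demo_path) as file:
        reloaded = json.load(file)

incomplete = find_incomplete(reloaded)
print(incomplete)
```

An empty report means every record has the question, the gold annotation, and both model outputs.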

In [None]:
#@title Display the Test Results

import pandas as pd
pd.set_option('display.max_colwidth', None)

df = pd.DataFrame(mini_data)
df

index,text,gold_re,test_re,sft_re
0,There were two teams relegated last season to the 2023 J2 League.,2023 J2 League|will have|two teams.\ntwo teams|were relegated to|2023 J2 League.\nlast season|saw relegation of|two teams.,Team A|was relegated to|2023 J2 League\nTeam B|was relegated to|2023 J2 League,There were two teams|relegated|season\nThere were two teams|relegated to|2023 J2 League\n2023 J2 League|is in|season\nseason|is in|2023
1,"lot\nWhilst waiting for his dinner at Pleasant's Coffee House, Hercule Poirot meets a young woman named Jennie.",Hercule Poirot|meets|Jennie.\nHercule Poirot|waits for|dinner.\nHercule Poirot|is at|Pleasant's Coffee House.\nJennie|is met by|Hercule Poirot.\nJennie|is a|young woman.,Hercule Poirot|meets|Jennie.,Hercule Poirot|meets|Jennie.\nHercule Poirot|waits for|dinner.\nHercule Poirot|is at|Pleasant's Coffee House.\nJennie|is a|young woman.
2,"Long before any knowledge of electricity existed, people were aware of shocks from electric fish.",people|were aware of|shocks\nshocks|from|electric fish,Electric fish|were known to give|people.,people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity\nelectric fish|existed in|ancient times\nancient times|existed before|knowledge of electricity\nknowledge of electricity|existed in|present times\nelectric fish|existed in|present times\nelectric fish|existed in|ancient times\nelectric fish|existed in|present times\nelectricity|existed in|knowledge of electricity\nelectricity|existed in|present times\nelectricity|existed in|ancient times\nelectric fish|had|electric
3,"Indonesia, officially the Republic of Indonesia, is a country in Southeast Asia and Oceania between the Indian and Pacific oceans.",Indonesia|is a country in|Southeast Asia.\nIndonesia|is a country in|Oceania.\nIndonesia|is between|Indian ocean.\nIndonesia|is between|Pacific ocean.\nRepublic of Indonesia|is officially known as|Indonesia.,Indonesia|is|Republic of Indonesia.\nIndonesia|is located in|Southeast Asia.\nIndonesia|is located in|Oceania.\nIndonesia|is between|Indian.\nIndonesia|is between|Pacific.\nRepublic of Indonesia|is officially|Indonesia.\nSoutheast Asia|is located in|Indonesia.\nOceania|is located in|Indonesia.\nIndian|is|ocean.\nPacific|is|ocean.,Indonesia|is|Republic of Indonesia.\nIndonesia|is|country.\nIndonesia|is located in|Southeast Asia.\nIndonesia|is located in|Oceania.\nIndonesia|is located between|Indian Ocean.\nIndonesia|is located between|Pacific Ocean.\nIndian Ocean|is located in|Southeast Asia.\nPacific Ocean|is located in|Southeast Asia.\nPacific Ocean|is located in|Oceania.\nIndian Ocean|is located in|Oceania.
4,"In 1982, Nintendo developed a prototype system called the Advanced Video System (AVS).",Nintendo|developed|Advanced Video System (AVS).\nNintendo|developed|prototype system.\nAdvanced Video System (AVS)|is|prototype system.,Nintendo|developed|Advanced Video System,Nintendo|developed|Advanced Video System (AVS).\nNintendo|developed|prototype system.\nAdvanced Video System (AVS)|is a|prototype system.\nAdvanced Video System (AVS)|was developed in|1982.
5,"Bloomington is a city in and the county seat of Monroe County, Indiana, United States.",Bloomington|is in|Monroe County.\nBloomington|is in|Indiana.\nBloomington|is in|United States.\nBloomington|is|city.\nBloomington|is|county seat.\nMonroe County|is in|Indiana.\nMonroe County|is in|United States.,Bloomington|is a|city\nBloomington|is|county seat\nBloomington|is in|Monroe County\nBloomington|is in|Indiana\nBloomington|is in|United States\nMonroe County|is|location\nIndiana|is|location\nUnited States|is|location,Bloomington|is a city in|Monroe County.\nBloomington|is the county seat of|Monroe County.\nBloomington|is in|Indiana.\nBloomington|is in|United States.
6,The Texas barrier islands are a chain of barrier islands in the Gulf of Mexico along the Texas Gulf Coast.,Texas barrier islands|are a chain of|barrier islands.\nTexas barrier islands|are in the|Gulf of Mexico.\nTexas barrier islands|are along the|Texas Gulf Coast.,"Texas|is|barrier islands, Texas|is|chain, Texas|is|barrier islands|of, barrier islands|are|chain, barrier islands|are|in, barrier islands|are|along, Gulf of Mexico|is|in, Texas Gulf Coast|is|along.",Texas barrier islands|are a chain of barrier islands in|Gulf of Mexico.\nTexas barrier islands|are located along|Texas Gulf Coast.
7,Johnson was rated among the nation's top 10 wide receivers and top 100 players by virtually every recruiting analyst.,Johnson|was rated among|the nation's top 10 wide receivers.\nJohnson|was rated among|the nation's top 100 players.,Johnson|was rated|nation's\nJohnson|was rated|top\nJohnson|was rated|wide receivers\nJohnson|was rated|top 10\nJohnson|was rated|100\nJohnson|was rated|players\nanalyst|rated|Johnson,Johnson|was rated among|nation's top 10 wide receivers.\nJohnson|was rated among|nation's top 100 players.\nJohnson|was rated by|virtually every recruiting analyst.
8,"The geologic time scale is a way of representing deep time based on events that have occurred throughout Earth's history, a time span of about 4.",geologic time scale|is a way of representing|deep time.\ndeep time|is based on|events.\nevents|have occurred throughout|Earth's history.\nEarth's history|has a time span of|4.,"The geologic time scale|is based on|events that have occurred throughout Earth's history, Earth's history|is part of|The geologic time scale, The geologic time scale|represents|deep time, deep time|is represented by|The geologic time scale, The geologic time scale|is based on|a time span of about 4.",geologic time scale|is a|way of representing deep time.\ngeologic time scale|is based on|events.\nevents|have occurred|Earth's history.\nEarth|has|history.\nEarth's history|is|time span.\ntime span|is about|4.
9,"Titanic is a 1997 American epic romance and disaster film directed, written, produced, and co-edited by James Cameron.",Titanic|is|film.\nTitanic|was directed by|James Cameron.\nTitanic|was written by|James Cameron.\nTitanic|was produced by|James Cameron.\nTitanic|was co-edited by|James Cameron.\nJames Cameron|directed|Titanic.\nJames Cameron|wrote|Titanic.\nJames Cameron|produced|Titanic.\nJames Cameron|co-edited|Titanic.\nTitanic|is|American epic romance and disaster film.,Titanic|is a|1997 American epic romance and disaster film\nTitanic|is directed by|James Cameron\nTitanic|is written by|James Cameron\nTitanic|is produced by|James Cameron\nTitanic|is co-edited by|James Cameron,Titanic|is a|1997 American epic romance and disaster film.\nTitanic|directed by|James Cameron.\nTitanic|written by|James Cameron.\nTitanic|produced by|James Cameron.\nTitanic|co-edited by|James Cameron.
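
Some rows in the table show two recurring problems: the SFT output for row 2 repeats several triples, and the base model output for row 6 joins its triples with commas instead of line breaks. A small post-processing step could flag both cases. The sketch below is one possible cleanup, assuming one triple per line with exactly three `|`-separated fields. `clean_triples` is a hypothetical helper, and the two strings are excerpts copied from the table.

```python
def clean_triples(text):
    """Return (unique well-formed triples in first-seen order, malformed lines)."""
    unique, malformed, seen = [], [], set()
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        parts = [part.strip() for part in line.split("|")]
        if len(parts) != 3 or not all(parts):
            malformed.append(line)
            continue
        key = tuple(part.lower() for part in parts)  # case-insensitive duplicate check
        if key not in seen:
            seen.add(key)
            unique.append(line)
    return unique, malformed


# Excerpt of the SFT output for row 2 (first nine lines)
row2_sft = (
    "people|were aware of|shocks\n"
    "shocks|from|electric fish\n"
    "electric fish|had|electricity\n"
    "electric fish|existed in|ancient times\n"
    "ancient times|existed before|knowledge of electricity\n"
    "knowledge of electricity|existed in|present times\n"
    "electric fish|existed in|present times\n"
    "electric fish|existed in|ancient times\n"
    "electric fish|existed in|present times"
)
# Excerpt of the base model output for row 6 (comma-separated on one line)
row6_base = "Texas|is|barrier islands, Texas|is|chain, Texas|is|barrier islands|of"

row2_unique, row2_malformed = clean_triples(row2_sft)
row6_unique, row6_malformed = clean_triples(row6_base)

print(len(row2_unique), "unique triples kept for row 2")
print(len(row6_malformed), "malformed line flagged for row 6")
```

Deduplication only removes exact repeats. It does not judge whether a triple is supported by the sentence, so hallucinated relations such as those in row 2 still need a separate check against the text.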
