# BERT

In [109]:
from datasets import Dataset
import pandas as pd
from transformers import BertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
import torch

# Load dataset
# datafile = '../data/Bitext.csv'
datafile = './data/bert_input.csv'
df = pd.read_csv(datafile)
df = df[['instruction', 'intent']]
print(len(df))

# Convert to Hugging Face dataset
intent_dataset = Dataset.from_pandas(df)

# get device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(example):
    return tokenizer(example['instruction'], padding='max_length', truncation=True, max_length=128)

# Apply tokenization to the dataset
intent_dataset = intent_dataset.map(tokenize_function, remove_columns=["instruction"])

# Map 'intent' to numerical labels
intent_labels = {intent: idx for idx, intent in enumerate(df['intent'].unique())}
intent_dataset = intent_dataset.map(lambda x: {'labels': intent_labels[x['intent']]})

# Number of intents
num_intents = len(df['intent'].unique())
print(num_intents)

# Load DistilBERT model for sequence classification
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_intents)

# Data collator (use for sequence classification, handles padding)
data_collator = DataCollatorWithPadding(tokenizer, padding='max_length', max_length=128)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./intent_model',
    save_strategy="epoch",
    save_total_limit=2,
    num_train_epochs=50,
    per_device_train_batch_size=8,
    logging_dir='./logs',
    logging_steps=100,  # Log every 10 steps
    save_steps=500,
    disable_tqdm=False,
)


# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=intent_dataset,
    data_collator=data_collator,
)

# Train the model with tqdm progress bar
trainer.train()
tokenizer.save_pretrained('./intent_model')



104
Using device: cuda


Map: 100%|██████████| 104/104 [00:00<00:00, 4249.67 examples/s]
Map: 100%|██████████| 104/104 [00:00<00:00, 21012.94 examples/s]


7


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
 16%|█▌        | 102/650 [00:13<00:52, 10.43it/s]

{'loss': 0.6105, 'grad_norm': 0.7972972989082336, 'learning_rate': 4.230769230769231e-05, 'epoch': 7.69}


 31%|███       | 202/650 [00:28<00:56,  7.95it/s]

{'loss': 0.0122, 'grad_norm': 0.06653542071580887, 'learning_rate': 3.461538461538462e-05, 'epoch': 15.38}


 46%|████▋     | 302/650 [00:42<01:02,  5.53it/s]

{'loss': 0.0048, 'grad_norm': 0.03561079129576683, 'learning_rate': 2.6923076923076923e-05, 'epoch': 23.08}


 62%|██████▏   | 402/650 [00:56<00:23, 10.59it/s]

{'loss': 0.003, 'grad_norm': 0.028323112055659294, 'learning_rate': 1.923076923076923e-05, 'epoch': 30.77}


 77%|███████▋  | 502/650 [01:10<00:17,  8.30it/s]

{'loss': 0.0022, 'grad_norm': 0.022158294916152954, 'learning_rate': 1.153846153846154e-05, 'epoch': 38.46}


 93%|█████████▎| 602/650 [01:25<00:08,  5.51it/s]

{'loss': 0.0021, 'grad_norm': 0.022441180422902107, 'learning_rate': 3.846153846153847e-06, 'epoch': 46.15}


100%|██████████| 650/650 [01:35<00:00,  6.82it/s]

{'train_runtime': 95.3521, 'train_samples_per_second': 54.535, 'train_steps_per_second': 6.817, 'train_loss': 0.09779059088000884, 'epoch': 50.0}





('./intent_model/tokenizer_config.json',
 './intent_model/special_tokens_map.json',
 './intent_model/vocab.txt',
 './intent_model/added_tokens.json')

In [110]:
# store the intent labels
import json
with open('./intent_model/intent_labels.json', 'w') as f:
    json.dump(intent_labels, f)

In [94]:
# Function to test the trained model
def test_model(input_text, model, tokenizer):
    # Tokenize the input text without 'token_type_ids'
    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    
    # Remove 'token_type_ids' from the input if it exists
    inputs.pop('token_type_ids', None)
    
    # Move tensors to the appropriate device (GPU if available)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    
    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get predicted label (the class with the highest score)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    
    return predicted_class_id

# Test the trained model with a sample input
input_text = "[INT] [BOT] [USR] Hey i want to know list my orders"
predicted_class_id = test_model(input_text, model, tokenizer)

# Print the predicted intent
print(f"Predicted intent ID: {predicted_class_id}")

# Optionally map the predicted intent ID back to the intent label
# get key from value
predicted_intent = list(intent_labels.keys())[list(intent_labels.values()).index(predicted_class_id)]
print(f"Predicted intent: {predicted_intent}")


Predicted intent ID: 2
Predicted intent: list_orders


In [95]:
df.intent.unique()

array(['track_order', 'give_order_id', 'list_orders',
       'give_list_order_params', 'cancel_order', 'give_reason',
       'confirm_command'], dtype=object)

In [96]:
# Load the model
# model = DistilBertForSequenceClassification.from_pretrained('./results/checkpoint-1500')

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Test the trained model with a sample input
input_text = "[INT] track_order [BOT] Please share the order ID [USR] 1231"
print(len(input_text))

# Get the predicted intent ID
predicted_class_id = test_model(input_text, model, tokenizer)

# Print the predicted intent
print(f"Predicted intent ID: {predicted_class_id}")
print(f"Predicted intent: {list(intent_labels.keys())[list(intent_labels.values()).index(predicted_class_id)]}")

# print probability scores of other intents
def get_intent_probs(input_text, model, tokenizer):
    # Tokenize the input text without 'token_type_ids'
    inputs = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    
    # Remove 'token_type_ids' from the input if it exists
    inputs.pop('token_type_ids', None)
    
    # Move tensors to the appropriate device (GPU if available)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    
    # Make prediction
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get predicted label (the class with the highest score)
    logits = outputs.logits
    probs = torch.softmax(logits, dim=-1)
    
    return probs

# Get the probability scores for all intents
probs = get_intent_probs(input_text, model, tokenizer)

test_df = pd.read_csv(datafile)
test_df = test_df[['instruction', 'intent']]

# randomly take 15 samples
test_df = test_df.sample(15)

# Get the predicted intent ID for each sample
test_df['predicted_intent_id'] = test_df['instruction'].apply(lambda x: test_model(x, model, tokenizer))

# Map the predicted intent ID back to the intent label
test_df['predicted_intent'] = test_df['predicted_intent_id'].apply(lambda x: list(intent_labels.keys())[list(intent_labels.values()).index(x)])

for idx, row in test_df.iterrows():
    print(f"Instruction: {row['instruction']}")
    print(f"True intent: {row['intent']}")
    print(f"Predicted intent: {row['predicted_intent']}")
    print('')


60
Predicted intent ID: 1
Predicted intent: give_order_id
Instruction: [INT] cancel_order [BOT] Can you confirm the cancellation of your most recent order? [USR] No, don’t proceed with it.
True intent: confirm_command
Predicted intent: confirm_command

Instruction: [INT] [BOT]  [USR] Where is my order?
True intent: track_order
Predicted intent: track_order

Instruction: [INT] [BOT]  [USR] Can you check my order 12345?
True intent: track_order
Predicted intent: track_order

Instruction: [INT] [BOT]  [USR] I’d like to know the status of my order.
True intent: track_order
Predicted intent: track_order

Instruction: [INT] cancel_order [BOT] Are you sure you want to cancel this order? [USR] Yes, I am sure.
True intent: confirm_command
Predicted intent: confirm_command

Instruction: [INT] [BOT]  [USR] I want to cancel order 998877.
True intent: cancel_order
Predicted intent: cancel_order

Instruction: [INT] cancel_order [BOT] What’s the reason for canceling this order? [USR] I ordered the wr

# NER

In [97]:
from transformers import BertTokenizerFast
from datasets import Dataset
import pandas as pd

# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
special_tokens = ["[INT]", "[BOT]", "[USR]"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})


# Load the data
data = pd.read_csv('./data/ner_data.csv')
data["labels"] = data["labels"].apply(eval)

# Extract unique labels and create a mapping
unique_labels = sorted({label for labels in data["labels"] for label in labels})
label_to_id = {label: idx for idx, label in enumerate(unique_labels)}

print("Label to ID Mapping:", label_to_id)

# Convert DataFrame to list of dictionaries for Dataset.from_list
data = data.to_dict(orient="records")

# Tokenize and align labels
def tokenize_and_align_labels(examples):
    # Tokenize the text
    tokenized_inputs = tokenizer(examples["instruction"], truncation=True, is_split_into_words=False)
    tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs["input_ids"])
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    labels = examples["labels"]

    # print(f"{'Token':<15} {'Word':<15} {'Label':<15}")

    for i, word_idx in enumerate(word_ids):
        token = tokens[i]  # Current token

        if word_idx is None:  # Special tokens ([CLS], [SEP], padding)
            word = "N/A"
            label = "Special"
            label_ids.append(-100)
        elif word_idx != previous_word_idx:  # First subword
            word = examples["instruction"][word_idx]
            label = labels[word_idx]
            label_ids.append(label_to_id[label])
        else:  # Subword token
            word = examples["instruction"][word_idx]
            label = labels[word_idx]
            token = f"Subword: {token} (Part of: {word})"
            label_ids.append(-100)

        # Print token, word, and label in the same line
        # print(f"{token:<15} {word:<15} {label:<15}")
        previous_word_idx = word_idx

    tokenized_inputs["labels"] = label_ids
    return tokenized_inputs

# Convert to Hugging Face Dataset
dataset = Dataset.from_list(data)

# Map tokenizer to the dataset
dataset = dataset.map(tokenize_and_align_labels)


Label to ID Mapping: {'B-AFFIRMATION': 0, 'B-CONFIRMATION': 1, 'B-COUNT': 2, 'B-END_DATE': 3, 'B-ORD': 4, 'B-REASON': 5, 'B-START_DATE': 6, 'I-END_DATE': 7, 'I-REASON': 8, 'I-START_DATE': 9, 'O': 10}


Map: 100%|██████████| 104/104 [00:00<00:00, 6982.90 examples/s]


In [115]:
# store the label_to_id mapping
import json
with open('./ner_model/entity_labels.json', 'w') as f:
    json.dump(label_to_id, f)

In [104]:
from transformers import BertTokenizerFast, BertForTokenClassification, TrainingArguments, Trainer
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
import pandas as pd

# Initialize tokenizer
# tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Load the data
data = pd.read_csv('./data/ner_data.csv')

# Convert stringified lists in 'labels' to actual lists
data["labels"] = data["labels"].apply(eval)

# Extract unique labels and create a mapping
unique_labels = sorted({label for labels in data["labels"] for label in labels})

# unique_labels = ['B-ORD', 'I-ORD', 'O']
label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
id_to_label = {idx: label for label, idx in label_to_id.items()}

print("Label to ID Mapping:", label_to_id)

# Convert DataFrame to list of dictionaries
data_dict = data.to_dict(orient="records")


# Tokenize and align labels

# Convert data to Hugging Face Dataset
hf_dataset = Dataset.from_list(data_dict)

# Split the dataset into train and validation sets
train_data, val_data = train_test_split(data_dict, test_size=0.1, random_state=42)
hf_dataset = DatasetDict({
    "train": Dataset.from_list(train_data),
    "validation": Dataset.from_list(val_data),
})

# Apply the tokenizer and label alignment
hf_dataset = hf_dataset.map(tokenize_and_align_labels, batched=False)
# Define the model
model = BertForTokenClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label_to_id),
    id2label=id_to_label,
    label2id=label_to_id
)

model.resize_token_embeddings(len(tokenizer)) # need to test

# Define training arguments
training_args = TrainingArguments(
    output_dir="./ner_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,  # Lower the learning rate to 1e-5
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=100,
    weight_decay=0.0001,
    logging_dir="./logs",
    logging_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True
)


# Use default data collator for token classification
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(tokenizer)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=hf_dataset["train"],
    eval_dataset=hf_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

# Train the model
trainer.train()

# Save the model and tokenizer
model.save_pretrained("./ner_model")
tokenizer.save_pretrained("./ner_model")


Label to ID Mapping: {'B-AFFIRMATION': 0, 'B-CONFIRMATION': 1, 'B-COUNT': 2, 'B-END_DATE': 3, 'B-ORD': 4, 'B-REASON': 5, 'B-START_DATE': 6, 'I-END_DATE': 7, 'I-REASON': 8, 'I-START_DATE': 9, 'O': 10}


Map: 100%|██████████| 93/93 [00:00<00:00, 5872.08 examples/s]


Map: 100%|██████████| 11/11 [00:00<00:00, 2829.82 examples/s]
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
  1%|          | 47/4700 [00:02<04:53, 15.87it/s]
  1%|          | 47/4700 [00:03<04:53, 15.87it/s]

{'eval_loss': 0.22458265721797943, 'eval_runtime': 0.0699, 'eval_samples_per_second': 157.387, 'eval_steps_per_second': 85.847, 'epoch': 1.0}


  2%|▏         | 93/4700 [00:07<04:49, 15.92it/s]
  2%|▏         | 94/4700 [00:07<04:49, 15.92it/s]

{'eval_loss': 0.11737439781427383, 'eval_runtime': 0.0701, 'eval_samples_per_second': 156.918, 'eval_steps_per_second': 85.592, 'epoch': 2.0}


  2%|▏         | 103/4700 [00:10<09:48,  7.81it/s]

{'loss': 0.5838, 'grad_norm': 1.2947421073913574, 'learning_rate': 9.787234042553192e-06, 'epoch': 2.13}


  3%|▎         | 141/4700 [00:12<04:48, 15.81it/s]
  3%|▎         | 141/4700 [00:12<04:48, 15.81it/s]

{'eval_loss': 0.08534422516822815, 'eval_runtime': 0.0694, 'eval_samples_per_second': 158.555, 'eval_steps_per_second': 86.484, 'epoch': 3.0}


  4%|▍         | 187/4700 [00:17<04:47, 15.67it/s]
  4%|▍         | 188/4700 [00:17<04:47, 15.67it/s]

{'eval_loss': 0.0520344041287899, 'eval_runtime': 0.073, 'eval_samples_per_second': 150.647, 'eval_steps_per_second': 82.171, 'epoch': 4.0}


  4%|▍         | 203/4700 [00:19<06:23, 11.74it/s]

{'loss': 0.1064, 'grad_norm': 2.5115087032318115, 'learning_rate': 9.574468085106385e-06, 'epoch': 4.26}


  5%|▌         | 235/4700 [00:21<04:43, 15.75it/s]
  5%|▌         | 235/4700 [00:22<04:43, 15.75it/s]

{'eval_loss': 0.03842921182513237, 'eval_runtime': 0.0699, 'eval_samples_per_second': 157.473, 'eval_steps_per_second': 85.894, 'epoch': 5.0}


  6%|▌         | 281/4700 [00:26<04:40, 15.74it/s]
  6%|▌         | 282/4700 [00:26<04:40, 15.74it/s]

{'eval_loss': 0.04386845603585243, 'eval_runtime': 0.0708, 'eval_samples_per_second': 155.411, 'eval_steps_per_second': 84.77, 'epoch': 6.0}


  6%|▋         | 303/4700 [00:29<05:14, 13.98it/s]

{'loss': 0.0434, 'grad_norm': 0.05798317492008209, 'learning_rate': 9.361702127659576e-06, 'epoch': 6.38}


  7%|▋         | 329/4700 [00:31<04:36, 15.83it/s]
  7%|▋         | 329/4700 [00:31<04:36, 15.83it/s]

{'eval_loss': 0.03568371757864952, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.342, 'eval_steps_per_second': 86.368, 'epoch': 7.0}


  8%|▊         | 375/4700 [00:36<04:36, 15.64it/s]
  8%|▊         | 376/4700 [00:36<04:36, 15.64it/s]

{'eval_loss': 0.03462174907326698, 'eval_runtime': 0.07, 'eval_samples_per_second': 157.037, 'eval_steps_per_second': 85.657, 'epoch': 8.0}


  9%|▊         | 403/4700 [00:39<04:44, 15.10it/s]

{'loss': 0.0257, 'grad_norm': 0.03279486671090126, 'learning_rate': 9.148936170212767e-06, 'epoch': 8.51}


  9%|▉         | 423/4700 [00:40<04:30, 15.82it/s]
  9%|▉         | 423/4700 [00:41<04:30, 15.82it/s]

{'eval_loss': 0.032127682119607925, 'eval_runtime': 0.0704, 'eval_samples_per_second': 156.156, 'eval_steps_per_second': 85.176, 'epoch': 9.0}


 10%|▉         | 469/4700 [00:45<04:31, 15.58it/s]
 10%|█         | 470/4700 [00:45<04:31, 15.58it/s]

{'eval_loss': 0.03229452669620514, 'eval_runtime': 0.0703, 'eval_samples_per_second': 156.534, 'eval_steps_per_second': 85.382, 'epoch': 10.0}


 11%|█         | 503/4700 [00:50<04:30, 15.51it/s]

{'loss': 0.0111, 'grad_norm': 0.03961629793047905, 'learning_rate': 8.936170212765958e-06, 'epoch': 10.64}


 11%|█         | 517/4700 [00:50<04:24, 15.81it/s]
 11%|█         | 517/4700 [00:51<04:24, 15.81it/s]

{'eval_loss': 0.030216870829463005, 'eval_runtime': 0.0682, 'eval_samples_per_second': 161.206, 'eval_steps_per_second': 87.93, 'epoch': 11.0}


 12%|█▏        | 563/4700 [00:55<04:22, 15.76it/s]
 12%|█▏        | 564/4700 [00:55<04:22, 15.76it/s]

{'eval_loss': 0.03398739919066429, 'eval_runtime': 0.0712, 'eval_samples_per_second': 154.486, 'eval_steps_per_second': 84.265, 'epoch': 12.0}


 13%|█▎        | 603/4700 [01:00<04:22, 15.62it/s]

{'loss': 0.0081, 'grad_norm': 0.05680681765079498, 'learning_rate': 8.72340425531915e-06, 'epoch': 12.77}


 13%|█▎        | 611/4700 [01:00<04:19, 15.78it/s]
 13%|█▎        | 611/4700 [01:00<04:19, 15.78it/s]

{'eval_loss': 0.030818527564406395, 'eval_runtime': 0.0673, 'eval_samples_per_second': 163.337, 'eval_steps_per_second': 89.093, 'epoch': 13.0}


 14%|█▍        | 657/4700 [01:05<04:19, 15.58it/s]
 14%|█▍        | 658/4700 [01:05<04:19, 15.58it/s]

{'eval_loss': 0.048558853566646576, 'eval_runtime': 0.0699, 'eval_samples_per_second': 157.258, 'eval_steps_per_second': 85.777, 'epoch': 14.0}


 15%|█▍        | 703/4700 [01:10<04:13, 15.80it/s]

{'loss': 0.008, 'grad_norm': 0.4102288782596588, 'learning_rate': 8.510638297872341e-06, 'epoch': 14.89}


 15%|█▌        | 705/4700 [01:10<04:14, 15.73it/s]
 15%|█▌        | 705/4700 [01:10<04:14, 15.73it/s]

{'eval_loss': 0.03696509450674057, 'eval_runtime': 0.066, 'eval_samples_per_second': 166.555, 'eval_steps_per_second': 90.848, 'epoch': 15.0}


 16%|█▌        | 751/4700 [01:15<04:12, 15.61it/s]
 16%|█▌        | 752/4700 [01:15<04:12, 15.61it/s]

{'eval_loss': 0.049430426210165024, 'eval_runtime': 0.071, 'eval_samples_per_second': 154.924, 'eval_steps_per_second': 84.504, 'epoch': 16.0}


 17%|█▋        | 799/4700 [01:20<04:05, 15.86it/s]
 17%|█▋        | 799/4700 [01:20<04:05, 15.86it/s]

{'eval_loss': 0.03298383578658104, 'eval_runtime': 0.0692, 'eval_samples_per_second': 158.857, 'eval_steps_per_second': 86.649, 'epoch': 17.0}


 17%|█▋        | 803/4700 [01:22<22:04,  2.94it/s]

{'loss': 0.0049, 'grad_norm': 0.026557985693216324, 'learning_rate': 8.297872340425532e-06, 'epoch': 17.02}


 18%|█▊        | 845/4700 [01:25<04:03, 15.80it/s]
 18%|█▊        | 846/4700 [01:25<04:03, 15.80it/s]

{'eval_loss': 0.03578469529747963, 'eval_runtime': 0.0714, 'eval_samples_per_second': 154.146, 'eval_steps_per_second': 84.08, 'epoch': 18.0}


 19%|█▉        | 893/4700 [01:30<04:01, 15.75it/s]
 19%|█▉        | 893/4700 [01:30<04:01, 15.75it/s]

{'eval_loss': 0.044258590787649155, 'eval_runtime': 0.0689, 'eval_samples_per_second': 159.587, 'eval_steps_per_second': 87.048, 'epoch': 19.0}


 19%|█▉        | 903/4700 [01:33<08:29,  7.45it/s]

{'loss': 0.0033, 'grad_norm': 0.01172460988163948, 'learning_rate': 8.085106382978723e-06, 'epoch': 19.15}


 20%|█▉        | 939/4700 [01:35<03:57, 15.84it/s]
 20%|██        | 940/4700 [01:35<03:57, 15.84it/s]

{'eval_loss': 0.058734867721796036, 'eval_runtime': 0.0692, 'eval_samples_per_second': 159.028, 'eval_steps_per_second': 86.743, 'epoch': 20.0}


 21%|██        | 987/4700 [01:41<03:56, 15.72it/s]
 21%|██        | 987/4700 [01:41<03:56, 15.72it/s]

{'eval_loss': 0.04295680671930313, 'eval_runtime': 0.0708, 'eval_samples_per_second': 155.275, 'eval_steps_per_second': 84.695, 'epoch': 21.0}


 21%|██▏       | 1003/4700 [01:44<05:21, 11.51it/s]

{'loss': 0.0026, 'grad_norm': 0.7668514847755432, 'learning_rate': 7.872340425531916e-06, 'epoch': 21.28}


 22%|██▏       | 1033/4700 [01:46<03:52, 15.75it/s]
 22%|██▏       | 1034/4700 [01:46<03:52, 15.75it/s]

{'eval_loss': 0.04735280200839043, 'eval_runtime': 0.073, 'eval_samples_per_second': 150.646, 'eval_steps_per_second': 82.17, 'epoch': 22.0}


 23%|██▎       | 1081/4700 [01:50<03:50, 15.68it/s]
 23%|██▎       | 1081/4700 [01:51<03:50, 15.68it/s]

{'eval_loss': 0.03867730870842934, 'eval_runtime': 0.0763, 'eval_samples_per_second': 144.259, 'eval_steps_per_second': 78.687, 'epoch': 23.0}


 23%|██▎       | 1103/4700 [01:54<04:14, 14.11it/s]

{'loss': 0.0019, 'grad_norm': 0.017853640019893646, 'learning_rate': 7.659574468085107e-06, 'epoch': 23.4}


 24%|██▍       | 1127/4700 [01:55<03:41, 16.11it/s]
 24%|██▍       | 1128/4700 [01:55<03:41, 16.11it/s]

{'eval_loss': 0.037846263498067856, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.29, 'eval_steps_per_second': 86.34, 'epoch': 24.0}


 25%|██▌       | 1175/4700 [02:00<03:45, 15.66it/s]
 25%|██▌       | 1175/4700 [02:00<03:45, 15.66it/s]

{'eval_loss': 0.04152911901473999, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.247, 'eval_steps_per_second': 86.316, 'epoch': 25.0}


 26%|██▌       | 1203/4700 [02:04<03:50, 15.17it/s]

{'loss': 0.0015, 'grad_norm': 0.021570656448602676, 'learning_rate': 7.446808510638298e-06, 'epoch': 25.53}


 26%|██▌       | 1221/4700 [02:05<03:42, 15.64it/s]
 26%|██▌       | 1222/4700 [02:05<03:42, 15.64it/s]

{'eval_loss': 0.03805603086948395, 'eval_runtime': 0.0717, 'eval_samples_per_second': 153.336, 'eval_steps_per_second': 83.638, 'epoch': 26.0}


 27%|██▋       | 1269/4700 [02:10<03:38, 15.72it/s]
 27%|██▋       | 1269/4700 [02:10<03:38, 15.72it/s]

{'eval_loss': 0.037800054997205734, 'eval_runtime': 0.0718, 'eval_samples_per_second': 153.297, 'eval_steps_per_second': 83.616, 'epoch': 27.0}


 28%|██▊       | 1303/4700 [02:14<03:39, 15.48it/s]

{'loss': 0.0014, 'grad_norm': 0.05468226224184036, 'learning_rate': 7.234042553191491e-06, 'epoch': 27.66}


 28%|██▊       | 1315/4700 [02:15<03:35, 15.72it/s]
 28%|██▊       | 1316/4700 [02:15<03:35, 15.72it/s]

{'eval_loss': 0.03698619827628136, 'eval_runtime': 0.0707, 'eval_samples_per_second': 155.488, 'eval_steps_per_second': 84.812, 'epoch': 28.0}


 29%|██▉       | 1363/4700 [02:19<03:31, 15.77it/s]
 29%|██▉       | 1363/4700 [02:19<03:31, 15.77it/s]

{'eval_loss': 0.03821110725402832, 'eval_runtime': 0.0705, 'eval_samples_per_second': 155.997, 'eval_steps_per_second': 85.09, 'epoch': 29.0}


 30%|██▉       | 1403/4700 [02:24<03:27, 15.88it/s]

{'loss': 0.0012, 'grad_norm': 0.008205904625356197, 'learning_rate': 7.021276595744682e-06, 'epoch': 29.79}


 30%|██▉       | 1409/4700 [02:24<03:29, 15.73it/s]
 30%|███       | 1410/4700 [02:24<03:29, 15.73it/s]

{'eval_loss': 0.0391620434820652, 'eval_runtime': 0.0711, 'eval_samples_per_second': 154.691, 'eval_steps_per_second': 84.377, 'epoch': 30.0}


 31%|███       | 1457/4700 [02:29<03:25, 15.74it/s]
 31%|███       | 1457/4700 [02:29<03:25, 15.74it/s]

{'eval_loss': 0.04019337520003319, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.215, 'eval_steps_per_second': 86.299, 'epoch': 31.0}


 32%|███▏      | 1503/4700 [02:34<03:22, 15.80it/s]

{'loss': 0.001, 'grad_norm': 0.023866651579737663, 'learning_rate': 6.808510638297873e-06, 'epoch': 31.91}



 32%|███▏      | 1504/4700 [02:34<03:22, 15.80it/s]

{'eval_loss': 0.040582384914159775, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.381, 'eval_steps_per_second': 86.39, 'epoch': 32.0}


 33%|███▎      | 1551/4700 [02:39<03:18, 15.88it/s]
 33%|███▎      | 1551/4700 [02:39<03:18, 15.88it/s]

{'eval_loss': 0.03847317397594452, 'eval_runtime': 0.0685, 'eval_samples_per_second': 160.486, 'eval_steps_per_second': 87.538, 'epoch': 33.0}


 34%|███▍      | 1597/4700 [02:44<03:17, 15.74it/s]
 34%|███▍      | 1598/4700 [02:44<03:17, 15.74it/s]

{'eval_loss': 0.03844139352440834, 'eval_runtime': 0.0721, 'eval_samples_per_second': 152.465, 'eval_steps_per_second': 83.163, 'epoch': 34.0}


 34%|███▍      | 1603/4700 [02:46<10:03,  5.13it/s]

{'loss': 0.001, 'grad_norm': 0.008173436857759953, 'learning_rate': 6.595744680851064e-06, 'epoch': 34.04}


 35%|███▌      | 1645/4700 [02:49<03:12, 15.87it/s]
 35%|███▌      | 1645/4700 [02:49<03:12, 15.87it/s]

{'eval_loss': 0.038352567702531815, 'eval_runtime': 0.0683, 'eval_samples_per_second': 161.142, 'eval_steps_per_second': 87.896, 'epoch': 35.0}


 36%|███▌      | 1691/4700 [02:53<03:14, 15.49it/s]
 36%|███▌      | 1692/4700 [02:53<03:14, 15.49it/s]

{'eval_loss': 0.03909435495734215, 'eval_runtime': 0.0713, 'eval_samples_per_second': 154.274, 'eval_steps_per_second': 84.149, 'epoch': 36.0}


 36%|███▌      | 1703/4700 [02:56<05:53,  8.49it/s]

{'loss': 0.0009, 'grad_norm': 0.006174721289426088, 'learning_rate': 6.382978723404256e-06, 'epoch': 36.17}


 37%|███▋      | 1739/4700 [02:58<03:07, 15.79it/s]
 37%|███▋      | 1739/4700 [02:59<03:07, 15.79it/s]

{'eval_loss': 0.04421167075634003, 'eval_runtime': 0.0693, 'eval_samples_per_second': 158.726, 'eval_steps_per_second': 86.578, 'epoch': 37.0}


 38%|███▊      | 1785/4700 [03:03<03:05, 15.72it/s]
 38%|███▊      | 1786/4700 [03:03<03:05, 15.72it/s]

{'eval_loss': 0.04321393743157387, 'eval_runtime': 0.0707, 'eval_samples_per_second': 155.622, 'eval_steps_per_second': 84.885, 'epoch': 38.0}


 38%|███▊      | 1803/4700 [03:06<03:52, 12.47it/s]

{'loss': 0.0009, 'grad_norm': 0.009139345958828926, 'learning_rate': 6.170212765957447e-06, 'epoch': 38.3}


 39%|███▉      | 1833/4700 [03:08<03:00, 15.85it/s]
 39%|███▉      | 1833/4700 [03:08<03:00, 15.85it/s]

{'eval_loss': 0.040979500859975815, 'eval_runtime': 0.0683, 'eval_samples_per_second': 161.013, 'eval_steps_per_second': 87.825, 'epoch': 39.0}


 40%|███▉      | 1879/4700 [03:13<03:00, 15.63it/s]
 40%|████      | 1880/4700 [03:13<03:00, 15.63it/s]

{'eval_loss': 0.041137970983982086, 'eval_runtime': 0.0703, 'eval_samples_per_second': 156.402, 'eval_steps_per_second': 85.31, 'epoch': 40.0}


 40%|████      | 1903/4700 [03:16<03:13, 14.45it/s]

{'loss': 0.0007, 'grad_norm': 0.008595666848123074, 'learning_rate': 5.957446808510638e-06, 'epoch': 40.43}


 41%|████      | 1927/4700 [03:18<02:56, 15.68it/s]
 41%|████      | 1927/4700 [03:18<02:56, 15.68it/s]

{'eval_loss': 0.0410960428416729, 'eval_runtime': 0.069, 'eval_samples_per_second': 159.341, 'eval_steps_per_second': 86.913, 'epoch': 41.0}


 42%|████▏     | 1973/4700 [03:23<02:53, 15.70it/s]
 42%|████▏     | 1974/4700 [03:23<02:53, 15.70it/s]

{'eval_loss': 0.08654186129570007, 'eval_runtime': 0.0714, 'eval_samples_per_second': 154.045, 'eval_steps_per_second': 84.024, 'epoch': 42.0}


 43%|████▎     | 2003/4700 [03:26<02:56, 15.24it/s]

{'loss': 0.001, 'grad_norm': 12.814239501953125, 'learning_rate': 5.744680851063831e-06, 'epoch': 42.55}


 43%|████▎     | 2021/4700 [03:28<02:49, 15.85it/s]
 43%|████▎     | 2021/4700 [03:28<02:49, 15.85it/s]

{'eval_loss': 0.03966720029711723, 'eval_runtime': 0.0718, 'eval_samples_per_second': 153.306, 'eval_steps_per_second': 83.621, 'epoch': 43.0}


 44%|████▍     | 2067/4700 [03:32<02:45, 15.93it/s]
 44%|████▍     | 2068/4700 [03:32<02:45, 15.93it/s]

{'eval_loss': 0.04073072969913483, 'eval_runtime': 0.0698, 'eval_samples_per_second': 157.545, 'eval_steps_per_second': 85.934, 'epoch': 44.0}


 45%|████▍     | 2103/4700 [03:36<02:48, 15.39it/s]

{'loss': 0.001, 'grad_norm': 0.006541753653436899, 'learning_rate': 5.531914893617022e-06, 'epoch': 44.68}


 45%|████▌     | 2115/4700 [03:37<02:44, 15.70it/s]
 45%|████▌     | 2115/4700 [03:37<02:44, 15.70it/s]

{'eval_loss': 0.04037857428193092, 'eval_runtime': 0.0729, 'eval_samples_per_second': 150.876, 'eval_steps_per_second': 82.296, 'epoch': 45.0}


 46%|████▌     | 2161/4700 [03:42<02:42, 15.63it/s]
 46%|████▌     | 2162/4700 [03:42<02:42, 15.63it/s]

{'eval_loss': 0.03973280265927315, 'eval_runtime': 0.0688, 'eval_samples_per_second': 159.798, 'eval_steps_per_second': 87.163, 'epoch': 46.0}


 47%|████▋     | 2203/4700 [03:47<02:39, 15.62it/s]

{'loss': 0.0007, 'grad_norm': 0.012191567569971085, 'learning_rate': 5.319148936170213e-06, 'epoch': 46.81}


 47%|████▋     | 2209/4700 [03:47<02:37, 15.77it/s]
 47%|████▋     | 2209/4700 [03:47<02:37, 15.77it/s]

{'eval_loss': 0.039428647607564926, 'eval_runtime': 0.0683, 'eval_samples_per_second': 161.032, 'eval_steps_per_second': 87.835, 'epoch': 47.0}


 48%|████▊     | 2255/4700 [03:52<02:35, 15.68it/s]
 48%|████▊     | 2256/4700 [03:52<02:35, 15.68it/s]

{'eval_loss': 0.03906620293855667, 'eval_runtime': 0.0708, 'eval_samples_per_second': 155.308, 'eval_steps_per_second': 84.713, 'epoch': 48.0}


 49%|████▉     | 2303/4700 [03:57<02:33, 15.59it/s]

{'loss': 0.0006, 'grad_norm': 0.004878205247223377, 'learning_rate': 5.106382978723404e-06, 'epoch': 48.94}



 49%|████▉     | 2303/4700 [03:57<02:33, 15.59it/s]

{'eval_loss': 0.039847519248723984, 'eval_runtime': 0.0741, 'eval_samples_per_second': 148.404, 'eval_steps_per_second': 80.948, 'epoch': 49.0}


 50%|████▉     | 2349/4700 [04:02<02:29, 15.69it/s]
 50%|█████     | 2350/4700 [04:02<02:29, 15.69it/s]

{'eval_loss': 0.03975449502468109, 'eval_runtime': 0.0718, 'eval_samples_per_second': 153.251, 'eval_steps_per_second': 83.591, 'epoch': 50.0}


 51%|█████     | 2397/4700 [04:07<02:25, 15.82it/s]
 51%|█████     | 2397/4700 [04:07<02:25, 15.82it/s]

{'eval_loss': 0.03961580619215965, 'eval_runtime': 0.0705, 'eval_samples_per_second': 156.027, 'eval_steps_per_second': 85.106, 'epoch': 51.0}


 51%|█████     | 2403/4700 [04:09<07:38,  5.01it/s]

{'loss': 0.0006, 'grad_norm': 0.004776547197252512, 'learning_rate': 4.893617021276596e-06, 'epoch': 51.06}


 52%|█████▏    | 2443/4700 [04:12<02:21, 15.99it/s]
 52%|█████▏    | 2444/4700 [04:12<02:21, 15.99it/s]

{'eval_loss': 0.04075395688414574, 'eval_runtime': 0.0723, 'eval_samples_per_second': 152.236, 'eval_steps_per_second': 83.038, 'epoch': 52.0}


 53%|█████▎    | 2491/4700 [04:16<02:20, 15.73it/s]
 53%|█████▎    | 2491/4700 [04:16<02:20, 15.73it/s]

{'eval_loss': 0.0409214086830616, 'eval_runtime': 0.0721, 'eval_samples_per_second': 152.652, 'eval_steps_per_second': 83.265, 'epoch': 53.0}


 53%|█████▎    | 2503/4700 [04:19<04:02,  9.05it/s]

{'loss': 0.0006, 'grad_norm': 0.03157833218574524, 'learning_rate': 4.680851063829788e-06, 'epoch': 53.19}


 54%|█████▍    | 2537/4700 [04:21<02:15, 15.95it/s]
 54%|█████▍    | 2538/4700 [04:21<02:15, 15.95it/s]

{'eval_loss': 0.04073750972747803, 'eval_runtime': 0.0734, 'eval_samples_per_second': 149.833, 'eval_steps_per_second': 81.727, 'epoch': 54.0}


 55%|█████▌    | 2585/4700 [04:26<02:12, 15.98it/s]
 55%|█████▌    | 2585/4700 [04:26<02:12, 15.98it/s]

{'eval_loss': 0.04079408198595047, 'eval_runtime': 0.0698, 'eval_samples_per_second': 157.664, 'eval_steps_per_second': 85.998, 'epoch': 55.0}


 55%|█████▌    | 2603/4700 [04:29<02:45, 12.64it/s]

{'loss': 0.0005, 'grad_norm': 0.004593093879520893, 'learning_rate': 4.468085106382979e-06, 'epoch': 55.32}


 56%|█████▌    | 2631/4700 [04:31<02:10, 15.85it/s]
 56%|█████▌    | 2632/4700 [04:31<02:10, 15.85it/s]

{'eval_loss': 0.04119507223367691, 'eval_runtime': 0.071, 'eval_samples_per_second': 154.847, 'eval_steps_per_second': 84.462, 'epoch': 56.0}


 57%|█████▋    | 2679/4700 [04:36<02:07, 15.84it/s]
 57%|█████▋    | 2679/4700 [04:36<02:07, 15.84it/s]

{'eval_loss': 0.04164031520485878, 'eval_runtime': 0.0722, 'eval_samples_per_second': 152.436, 'eval_steps_per_second': 83.147, 'epoch': 57.0}


 58%|█████▊    | 2703/4700 [04:39<02:18, 14.40it/s]

{'loss': 0.0005, 'grad_norm': 0.006184808444231749, 'learning_rate': 4.255319148936171e-06, 'epoch': 57.45}


 58%|█████▊    | 2725/4700 [04:40<02:05, 15.79it/s]
 58%|█████▊    | 2726/4700 [04:41<02:04, 15.79it/s]

{'eval_loss': 0.041382454335689545, 'eval_runtime': 0.0704, 'eval_samples_per_second': 156.199, 'eval_steps_per_second': 85.2, 'epoch': 58.0}


 59%|█████▉    | 2773/4700 [04:45<02:01, 15.86it/s]
 59%|█████▉    | 2773/4700 [04:45<02:01, 15.86it/s]

{'eval_loss': 0.07791271805763245, 'eval_runtime': 0.0705, 'eval_samples_per_second': 155.934, 'eval_steps_per_second': 85.055, 'epoch': 59.0}


 60%|█████▉    | 2803/4700 [04:49<02:03, 15.30it/s]

{'loss': 0.0008, 'grad_norm': 0.0051254937425255775, 'learning_rate': 4.042553191489362e-06, 'epoch': 59.57}


 60%|█████▉    | 2819/4700 [04:50<01:57, 16.05it/s]
 60%|██████    | 2820/4700 [04:50<01:57, 16.05it/s]

{'eval_loss': 0.0980735793709755, 'eval_runtime': 0.0664, 'eval_samples_per_second': 165.781, 'eval_steps_per_second': 90.426, 'epoch': 60.0}


 61%|██████    | 2867/4700 [04:55<01:55, 15.81it/s]
 61%|██████    | 2867/4700 [04:55<01:55, 15.81it/s]

{'eval_loss': 0.08582782745361328, 'eval_runtime': 0.0699, 'eval_samples_per_second': 157.307, 'eval_steps_per_second': 85.804, 'epoch': 61.0}


 62%|██████▏   | 2903/4700 [04:59<01:55, 15.56it/s]

{'loss': 0.0005, 'grad_norm': 0.006678905803710222, 'learning_rate': 3.8297872340425535e-06, 'epoch': 61.7}


 62%|██████▏   | 2913/4700 [05:00<01:53, 15.69it/s]
 62%|██████▏   | 2914/4700 [05:00<01:53, 15.69it/s]

{'eval_loss': 0.08089985698461533, 'eval_runtime': 0.0716, 'eval_samples_per_second': 153.71, 'eval_steps_per_second': 83.842, 'epoch': 62.0}


 63%|██████▎   | 2961/4700 [05:05<01:47, 16.18it/s]
 63%|██████▎   | 2961/4700 [05:05<01:47, 16.18it/s]

{'eval_loss': 0.04057367518544197, 'eval_runtime': 0.0695, 'eval_samples_per_second': 158.168, 'eval_steps_per_second': 86.274, 'epoch': 63.0}


 64%|██████▍   | 3003/4700 [05:09<01:47, 15.77it/s]

{'loss': 0.0026, 'grad_norm': 0.005959679372608662, 'learning_rate': 3.6170212765957453e-06, 'epoch': 63.83}


 64%|██████▍   | 3007/4700 [05:09<01:48, 15.66it/s]
 64%|██████▍   | 3008/4700 [05:10<01:48, 15.66it/s]

{'eval_loss': 0.0412907637655735, 'eval_runtime': 0.072, 'eval_samples_per_second': 152.746, 'eval_steps_per_second': 83.316, 'epoch': 64.0}


 65%|██████▌   | 3055/4700 [05:14<01:42, 16.01it/s]
 65%|██████▌   | 3055/4700 [05:14<01:42, 16.01it/s]

{'eval_loss': 0.04202314093708992, 'eval_runtime': 0.0684, 'eval_samples_per_second': 160.843, 'eval_steps_per_second': 87.733, 'epoch': 65.0}


 66%|██████▌   | 3101/4700 [05:19<01:41, 15.70it/s]

{'loss': 0.0004, 'grad_norm': 0.0037007767241448164, 'learning_rate': 3.4042553191489363e-06, 'epoch': 65.96}



 66%|██████▌   | 3102/4700 [05:19<01:41, 15.70it/s]

{'eval_loss': 0.042490240186452866, 'eval_runtime': 0.0717, 'eval_samples_per_second': 153.371, 'eval_steps_per_second': 83.657, 'epoch': 66.0}


 67%|██████▋   | 3149/4700 [05:24<01:38, 15.78it/s]
 67%|██████▋   | 3149/4700 [05:24<01:38, 15.78it/s]

{'eval_loss': 0.04237835481762886, 'eval_runtime': 0.0682, 'eval_samples_per_second': 161.333, 'eval_steps_per_second': 88.0, 'epoch': 67.0}


 68%|██████▊   | 3195/4700 [05:28<01:34, 15.90it/s]
 68%|██████▊   | 3196/4700 [05:29<01:34, 15.90it/s]

{'eval_loss': 0.04276514798402786, 'eval_runtime': 0.0697, 'eval_samples_per_second': 157.73, 'eval_steps_per_second': 86.034, 'epoch': 68.0}


 68%|██████▊   | 3203/4700 [05:31<03:53,  6.42it/s]

{'loss': 0.0004, 'grad_norm': 0.005685935262590647, 'learning_rate': 3.191489361702128e-06, 'epoch': 68.09}


 69%|██████▉   | 3243/4700 [05:33<01:31, 16.01it/s]
 69%|██████▉   | 3243/4700 [05:33<01:31, 16.01it/s]

{'eval_loss': 0.04311457276344299, 'eval_runtime': 0.0674, 'eval_samples_per_second': 163.31, 'eval_steps_per_second': 89.078, 'epoch': 69.0}


 70%|██████▉   | 3289/4700 [05:38<01:29, 15.84it/s]
 70%|███████   | 3290/4700 [05:38<01:28, 15.84it/s]

{'eval_loss': 0.04326026514172554, 'eval_runtime': 0.071, 'eval_samples_per_second': 154.94, 'eval_steps_per_second': 84.513, 'epoch': 70.0}


 70%|███████   | 3303/4700 [05:41<02:10, 10.75it/s]

{'loss': 0.0004, 'grad_norm': 0.005183683708310127, 'learning_rate': 2.978723404255319e-06, 'epoch': 70.21}


 71%|███████   | 3337/4700 [05:43<01:26, 15.79it/s]
 71%|███████   | 3337/4700 [05:43<01:26, 15.79it/s]

{'eval_loss': 0.04323548823595047, 'eval_runtime': 0.0704, 'eval_samples_per_second': 156.236, 'eval_steps_per_second': 85.219, 'epoch': 71.0}


 72%|███████▏  | 3383/4700 [05:48<01:23, 15.75it/s]
 72%|███████▏  | 3384/4700 [05:48<01:23, 15.75it/s]

{'eval_loss': 0.04330157861113548, 'eval_runtime': 0.0706, 'eval_samples_per_second': 155.888, 'eval_steps_per_second': 85.03, 'epoch': 72.0}


 72%|███████▏  | 3403/4700 [05:51<01:35, 13.53it/s]

{'loss': 0.0004, 'grad_norm': 0.03309694305062294, 'learning_rate': 2.765957446808511e-06, 'epoch': 72.34}


 73%|███████▎  | 3431/4700 [05:53<01:19, 15.96it/s]
 73%|███████▎  | 3431/4700 [05:53<01:19, 15.96it/s]

{'eval_loss': 0.04344433918595314, 'eval_runtime': 0.0684, 'eval_samples_per_second': 160.804, 'eval_steps_per_second': 87.711, 'epoch': 73.0}


 74%|███████▍  | 3477/4700 [05:57<01:18, 15.65it/s]
 74%|███████▍  | 3478/4700 [05:57<01:18, 15.65it/s]

{'eval_loss': 0.04336690530180931, 'eval_runtime': 0.0738, 'eval_samples_per_second': 149.054, 'eval_steps_per_second': 81.302, 'epoch': 74.0}


 75%|███████▍  | 3503/4700 [06:01<01:21, 14.65it/s]

{'loss': 0.0004, 'grad_norm': 0.0035168055910617113, 'learning_rate': 2.553191489361702e-06, 'epoch': 74.47}


 75%|███████▌  | 3525/4700 [06:03<01:14, 15.83it/s]
 75%|███████▌  | 3525/4700 [06:03<01:14, 15.83it/s]

{'eval_loss': 0.04339845851063728, 'eval_runtime': 0.0703, 'eval_samples_per_second': 156.548, 'eval_steps_per_second': 85.39, 'epoch': 75.0}


 76%|███████▌  | 3571/4700 [06:07<01:10, 16.00it/s]
 76%|███████▌  | 3572/4700 [06:08<01:10, 16.00it/s]

{'eval_loss': 0.04351821169257164, 'eval_runtime': 0.0679, 'eval_samples_per_second': 162.035, 'eval_steps_per_second': 88.383, 'epoch': 76.0}


 77%|███████▋  | 3603/4700 [06:12<01:10, 15.57it/s]

{'loss': 0.0004, 'grad_norm': 0.0033774259500205517, 'learning_rate': 2.340425531914894e-06, 'epoch': 76.6}


 77%|███████▋  | 3619/4700 [06:13<01:06, 16.29it/s]
 77%|███████▋  | 3619/4700 [06:13<01:06, 16.29it/s]

{'eval_loss': 0.04908238723874092, 'eval_runtime': 0.071, 'eval_samples_per_second': 154.996, 'eval_steps_per_second': 84.543, 'epoch': 77.0}


 78%|███████▊  | 3665/4700 [06:18<01:04, 16.05it/s]
 78%|███████▊  | 3666/4700 [06:18<01:04, 16.05it/s]

{'eval_loss': 0.05507812649011612, 'eval_runtime': 0.0719, 'eval_samples_per_second': 153.091, 'eval_steps_per_second': 83.504, 'epoch': 78.0}


 79%|███████▉  | 3703/4700 [06:23<01:04, 15.54it/s]

{'loss': 0.0004, 'grad_norm': 0.0036171646788716316, 'learning_rate': 2.1276595744680853e-06, 'epoch': 78.72}


 79%|███████▉  | 3713/4700 [06:24<01:02, 15.86it/s]
 79%|███████▉  | 3713/4700 [06:24<01:02, 15.86it/s]

{'eval_loss': 0.05383574590086937, 'eval_runtime': 0.0692, 'eval_samples_per_second': 158.916, 'eval_steps_per_second': 86.682, 'epoch': 79.0}


 80%|███████▉  | 3759/4700 [06:30<00:58, 16.06it/s]
 80%|████████  | 3760/4700 [06:30<00:58, 16.06it/s]

{'eval_loss': 0.052710842341184616, 'eval_runtime': 0.068, 'eval_samples_per_second': 161.656, 'eval_steps_per_second': 88.176, 'epoch': 80.0}


 81%|████████  | 3803/4700 [06:35<00:57, 15.68it/s]

{'loss': 0.0004, 'grad_norm': 0.004580102860927582, 'learning_rate': 1.9148936170212767e-06, 'epoch': 80.85}


 81%|████████  | 3807/4700 [06:35<00:56, 15.84it/s]
 81%|████████  | 3807/4700 [06:35<00:56, 15.84it/s]

{'eval_loss': 0.05107694864273071, 'eval_runtime': 0.0697, 'eval_samples_per_second': 157.884, 'eval_steps_per_second': 86.119, 'epoch': 81.0}


 82%|████████▏ | 3853/4700 [06:41<00:53, 15.83it/s]
 82%|████████▏ | 3854/4700 [06:41<00:53, 15.83it/s]

{'eval_loss': 0.04933178797364235, 'eval_runtime': 0.0732, 'eval_samples_per_second': 150.175, 'eval_steps_per_second': 81.914, 'epoch': 82.0}


 83%|████████▎ | 3901/4700 [06:47<00:49, 16.00it/s]

{'loss': 0.0004, 'grad_norm': 0.003297780640423298, 'learning_rate': 1.7021276595744682e-06, 'epoch': 82.98}



 83%|████████▎ | 3901/4700 [06:47<00:49, 16.00it/s]

{'eval_loss': 0.048317547887563705, 'eval_runtime': 0.0724, 'eval_samples_per_second': 151.917, 'eval_steps_per_second': 82.864, 'epoch': 83.0}


 84%|████████▍ | 3947/4700 [06:52<00:46, 16.13it/s]
 84%|████████▍ | 3948/4700 [06:53<00:46, 16.13it/s]

{'eval_loss': 0.047324202954769135, 'eval_runtime': 0.0748, 'eval_samples_per_second': 147.055, 'eval_steps_per_second': 80.212, 'epoch': 84.0}


 85%|████████▌ | 3995/4700 [06:57<00:44, 15.85it/s]
 85%|████████▌ | 3995/4700 [06:57<00:44, 15.85it/s]

{'eval_loss': 0.04655924066901207, 'eval_runtime': 0.0713, 'eval_samples_per_second': 154.227, 'eval_steps_per_second': 84.124, 'epoch': 85.0}


 85%|████████▌ | 4003/4700 [07:00<02:07,  5.45it/s]

{'loss': 0.0004, 'grad_norm': 0.006560354959219694, 'learning_rate': 1.4893617021276596e-06, 'epoch': 85.11}


 86%|████████▌ | 4041/4700 [07:02<00:41, 16.00it/s]
 86%|████████▌ | 4042/4700 [07:03<00:41, 16.00it/s]

{'eval_loss': 0.04633384943008423, 'eval_runtime': 0.07, 'eval_samples_per_second': 157.096, 'eval_steps_per_second': 85.689, 'epoch': 86.0}


 87%|████████▋ | 4089/4700 [07:08<00:38, 16.03it/s]
 87%|████████▋ | 4089/4700 [07:08<00:38, 16.03it/s]

{'eval_loss': 0.04588129371404648, 'eval_runtime': 0.0686, 'eval_samples_per_second': 160.416, 'eval_steps_per_second': 87.5, 'epoch': 87.0}


 87%|████████▋ | 4103/4700 [07:11<00:58, 10.23it/s]

{'loss': 0.0003, 'grad_norm': 0.002846543211489916, 'learning_rate': 1.276595744680851e-06, 'epoch': 87.23}


 88%|████████▊ | 4135/4700 [07:13<00:36, 15.60it/s]
 88%|████████▊ | 4136/4700 [07:13<00:36, 15.60it/s]

{'eval_loss': 0.04580669105052948, 'eval_runtime': 0.0701, 'eval_samples_per_second': 157.013, 'eval_steps_per_second': 85.643, 'epoch': 88.0}


 89%|████████▉ | 4183/4700 [07:18<00:32, 15.71it/s]
 89%|████████▉ | 4183/4700 [07:18<00:32, 15.71it/s]

{'eval_loss': 0.04526777192950249, 'eval_runtime': 0.0706, 'eval_samples_per_second': 155.745, 'eval_steps_per_second': 84.952, 'epoch': 89.0}


 89%|████████▉ | 4203/4700 [07:21<00:37, 13.25it/s]

{'loss': 0.0003, 'grad_norm': 0.003883709665387869, 'learning_rate': 1.0638297872340427e-06, 'epoch': 89.36}


 90%|████████▉ | 4229/4700 [07:23<00:30, 15.65it/s]
 90%|█████████ | 4230/4700 [07:23<00:30, 15.65it/s]

{'eval_loss': 0.0451212152838707, 'eval_runtime': 0.0714, 'eval_samples_per_second': 154.153, 'eval_steps_per_second': 84.083, 'epoch': 90.0}


 91%|█████████ | 4277/4700 [07:28<00:26, 15.83it/s]
 91%|█████████ | 4277/4700 [07:28<00:26, 15.83it/s]

{'eval_loss': 0.04500315710902214, 'eval_runtime': 0.0692, 'eval_samples_per_second': 158.903, 'eval_steps_per_second': 86.674, 'epoch': 91.0}


 92%|█████████▏| 4303/4700 [07:33<00:27, 14.32it/s]

{'loss': 0.0003, 'grad_norm': 0.00404032738879323, 'learning_rate': 8.510638297872341e-07, 'epoch': 91.49}


 92%|█████████▏| 4323/4700 [07:34<00:23, 15.80it/s]
 92%|█████████▏| 4324/4700 [07:34<00:23, 15.80it/s]

{'eval_loss': 0.04486975446343422, 'eval_runtime': 0.0677, 'eval_samples_per_second': 162.381, 'eval_steps_per_second': 88.572, 'epoch': 92.0}


 93%|█████████▎| 4371/4700 [07:39<00:20, 15.72it/s]
 93%|█████████▎| 4371/4700 [07:39<00:20, 15.72it/s]

{'eval_loss': 0.04472574591636658, 'eval_runtime': 0.0668, 'eval_samples_per_second': 164.656, 'eval_steps_per_second': 89.812, 'epoch': 93.0}


 94%|█████████▎| 4403/4700 [07:44<00:19, 15.10it/s]

{'loss': 0.0004, 'grad_norm': 0.0026400992646813393, 'learning_rate': 6.382978723404255e-07, 'epoch': 93.62}


 94%|█████████▍| 4417/4700 [07:45<00:18, 15.69it/s]
 94%|█████████▍| 4418/4700 [07:45<00:17, 15.69it/s]

{'eval_loss': 0.04458744078874588, 'eval_runtime': 0.0704, 'eval_samples_per_second': 156.181, 'eval_steps_per_second': 85.19, 'epoch': 94.0}


 95%|█████████▌| 4465/4700 [07:50<00:14, 15.78it/s]
 95%|█████████▌| 4465/4700 [07:50<00:14, 15.78it/s]

{'eval_loss': 0.04434718191623688, 'eval_runtime': 0.0693, 'eval_samples_per_second': 158.641, 'eval_steps_per_second': 86.532, 'epoch': 95.0}


 96%|█████████▌| 4503/4700 [07:54<00:12, 15.61it/s]

{'loss': 0.0003, 'grad_norm': 0.0029811859130859375, 'learning_rate': 4.2553191489361704e-07, 'epoch': 95.74}


 96%|█████████▌| 4511/4700 [07:55<00:12, 15.56it/s]
 96%|█████████▌| 4512/4700 [07:55<00:12, 15.56it/s]

{'eval_loss': 0.04419024661183357, 'eval_runtime': 0.0723, 'eval_samples_per_second': 152.191, 'eval_steps_per_second': 83.013, 'epoch': 96.0}


 97%|█████████▋| 4559/4700 [07:59<00:08, 15.80it/s]
 97%|█████████▋| 4559/4700 [08:00<00:08, 15.80it/s]

{'eval_loss': 0.04489561542868614, 'eval_runtime': 0.0665, 'eval_samples_per_second': 165.309, 'eval_steps_per_second': 90.169, 'epoch': 97.0}


 98%|█████████▊| 4603/4700 [08:04<00:06, 15.73it/s]

{'loss': 0.0013, 'grad_norm': 0.003040699288249016, 'learning_rate': 2.1276595744680852e-07, 'epoch': 97.87}


 98%|█████████▊| 4605/4700 [08:04<00:06, 15.69it/s]
 98%|█████████▊| 4606/4700 [08:04<00:05, 15.69it/s]

{'eval_loss': 0.045111220329999924, 'eval_runtime': 0.0715, 'eval_samples_per_second': 153.926, 'eval_steps_per_second': 83.96, 'epoch': 98.0}


 99%|█████████▉| 4653/4700 [08:10<00:02, 15.79it/s]
 99%|█████████▉| 4653/4700 [08:10<00:02, 15.79it/s]

{'eval_loss': 0.04509497061371803, 'eval_runtime': 0.0705, 'eval_samples_per_second': 156.081, 'eval_steps_per_second': 85.135, 'epoch': 99.0}


100%|██████████| 4700/4700 [08:15<00:00, 15.83it/s]

{'loss': 0.0003, 'grad_norm': 0.003243876388296485, 'learning_rate': 0.0, 'epoch': 100.0}



100%|██████████| 4700/4700 [08:17<00:00, 15.83it/s]

{'eval_loss': 0.04508901387453079, 'eval_runtime': 0.0302, 'eval_samples_per_second': 364.469, 'eval_steps_per_second': 198.801, 'epoch': 100.0}


100%|██████████| 4700/4700 [08:24<00:00,  9.32it/s]


{'train_runtime': 504.2092, 'train_samples_per_second': 18.445, 'train_steps_per_second': 9.322, 'train_loss': 0.017546459908022526, 'epoch': 100.0}


('./ner_model/tokenizer_config.json',
 './ner_model/special_tokens_map.json',
 './ner_model/vocab.txt',
 './ner_model/added_tokens.json',
 './ner_model/tokenizer.json')

In [105]:
from transformers import BertTokenizerFast, BertForTokenClassification
import torch

# Load the saved model and tokenizer
# model_dir = "./ner_model"
# tokenizer = BertTokenizerFast.from_pretrained(model_dir)
# model = BertForTokenClassification.from_pretrained(model_dir)

# Load the label mappings
id_to_label = model.config.id2label

# Sample text to test the model
test_sentences = [
    "[INT] [BOT]  [USR] Order 121435, whats the status?"
]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def predict_ner(sentence):
    # Tokenize the sentence
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
    # print(inputs)

    # Get model predictions
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Extract logits and compute predictions
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)

    # Map predictions to labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    predicted_labels = [id_to_label[pred.item()] for pred in predictions[0]]

    # Combine tokens and predicted labels
    results = []
    for token, label in zip(tokens, predicted_labels):
        if token not in ["[CLS]", "[SEP]", "[PAD]"]:  # Ignore special tokens
            results.append((token, label))
    return results

# Test the model with sample sentences
for sentence in test_sentences:
    # print(f"Sentence: {sentence}")
    predictions = predict_ner(sentence)
    # print("Predictions:")
    # for token, label in predictions:
        # print(f"{token:15} -> {label}")
    # print("\n")

def extract_entity(sentence, entity_label):
    """
    enitity_label: the label of the entity to extract (ORD for example)
    extract the entity from the sentence along with I-ORD tokens
    """
    # Get NER predictions
    predictions = predict_ner(sentence)
    # for token, label in predictions:
    #     # print(f"{token:15} -> {label}")

    # get all entities in the label
    all_entities = []
    for token, label in predictions:
        if label.startswith("B-") or label.startswith("I-"):
            all_entities.append(label.split("-")[1])

    all_entities = list(set(all_entities))

    return_values = []
    for entity_label in all_entities:
        # Extract entities with the specified label
        entities = []
        for token, label in predictions:
            if label == f"B-{entity_label}" or label == f"I-{entity_label}":
                entities.append(token)
        # print(entities)
        
        # remove subword prefixes from the entities
        entity_combined = []
        prev_entity = None
        for i , entity in enumerate(entities):
            if entity.startswith("##"):
                if prev_entity is None:
                    prev_entity = ""
                prev_entity += entity[2:]
            else:
                if prev_entity:
                    entity_combined.append(prev_entity)
                prev_entity = entity

            if i == len(entities) - 1:
                entity_combined.append(prev_entity)
        return_values.append((entity_label, entity_combined))
    return return_values
        

# get the B-COUNT entity from the sentence
sentence = "[INT] [BOT]  [USR] Please show me a list of all the orders I've made before 21 November 2024 and after 10 April 2024"
entity_label = "START_DATE"
extracted_entity = extract_entity(sentence, entity_label)
print(f"Extracted entity: {extracted_entity}")

Extracted entity: [('END_DATE', ['november44']), ('START_DATE', ['21', '202', '10', 'april', '202'])]


In [106]:
# test for 15 random samples
data = pd.read_csv('./data/ner_data.csv')

# draw 15 random samples
test_data = data.sample(15)

for idx, row in test_data.iterrows():
    print(f"Sentence: {row['instruction']}")
    # get the 
    extracted_entity = extract_entity(row['instruction'], "ORD")
    print(f"Extracted entity: {extracted_entity}")
    print("")

Sentence: [INT] [BOT]  [USR] Can you check my order 12345?
Extracted entity: [('ORD', ['12345'])]

Sentence: [INT] cancel_order [BOT] Please confirm the cancellation of your last order. [USR] Yes, confirm it.
Extracted entity: [('AFFIRMATION', ['yes'])]

Sentence: [INT] [BOT]  [USR] Cancel the last item I ordered.
Extracted entity: [('ORD', ['last'])]

Sentence: [INT] give_order_id [BOT] Are you sure? [USR] Yes I am sure!
Extracted entity: [('AFFIRMATION', ['yes'])]

Sentence: [INT] list_orders [BOT] Do you need to filter anything? [USR] Yeah give me just 32 of them
Extracted entity: [('COUNT', ['32'])]

Sentence: [INT] cancel_order [BOT] Can you confirm the cancellation of your most recent order? [USR] No, don’t proceed with it.
Extracted entity: [('AFFIRMATION', ['no'])]

Sentence: [INT] cancel_order [BOT] Are you sure you want to cancel order 112233? [USR] No, leave it as is.
Extracted entity: [('AFFIRMATION', ['no'])]

Sentence: [INT] [BOT]  [USR] Can you list some of my previous o