In [None]:
"""
General Purpose of this notebook:
    1. Fine tune an embedding model
    2. Prepare synthetic data to train the embedding model. 
        a. It should support qna pairs, triplets (query, positive, negative), chunks
        b. Allow for matryoshka embeddings. The bge-small is of 384 dims, try converting that into something like 256 dims and verify performance
    3. RAGAS ? 
"""

## Prepare synthetic data for embedding model

In [8]:
import os, json
from dotenv import load_dotenv
import sys
sys.path.append("../")

In [9]:
from modules_v2 import *
from bundles import *
from random import random, randint

Initializing id with my-unique-pplx-123

Initializing model with llama-3-sonar-large-32k-online

Initializing system_prompt with [{'role': 'system', 'content': '\n<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert lawyer in India. Answer the query after carefully parsing the search results to identify relevant information. Only make definitive statements if concretely supported by the information at hand. you MUST APPEND reference URLs enclosed within <references> </references>.\n<|eot_id|>\n'}]

Initializing tools with None

Initializing temperature with 0.3

Initializing stream_def with None

Initializing available_models with ['llama-3-sonar-small-32k-online', 'llama-3-sonar-large-32k-online']

Initializing base_url with https://api.perplexity.ai

___________________________________________


In [4]:
mutli_q_agent = BaseClaudeAgent(bundle = multi_query_generator_bundle)

Initializing id with my-unique-pos-neg-generator-bundle-123

Initializing model with claude-3-haiku-20240307

Initializing system_prompt with You are a legal expert in India. You will be given a passage and your job is to output few questions out of it. Given the passage, transform it into a concise, keyword-focused search query in telegraphic style. The query should retain the essential information but remove unnecessary words. Feel free to add prefixes like "find case law where ..." or "judgements related to ..." or "precedents where ..."
Generate upto 5 such questions with few being a bit vague 

Initializing tools with []

Initializing temperature with 0.4

Initializing stream_def and default_behaviour with None; None

___________________________________________


In [87]:
sonnet_chance = 0.2

In [10]:
# jsons_path = "/mnt/d/work/datasets/judgements2/json_pdf"
jsons_path = "/mnt/d/work/projects/agents/scripts/processed_data"

In [11]:
num_judgements = 500
processed_jsons = [os.path.join(jsons_path, json_file) for json_file in os.listdir(jsons_path)] # [-1500: -1000]

In [37]:
keys_to_check = ["judgement", "cause_of_action", "interpretations", "case_brief", "ratio", "obiter"]

In [9]:
# synthetic_dataset = [] 
synthetic_dataset = read_json("/mnt/d/work/projects/agents/playbooks/synthetic_embedding_data_list_pt2.json")

In [9]:
min_words = 20

In [10]:
def extract_queries(text):
    pattern = r'\d+\.\s+(.+)'
    return re.findall(pattern, text)


In [22]:
new_system_prompt = """You are a legal expert in India. You will be given a passage and your job is to output few questions out of it. Given the passage, transform it into a concise, keyword-focused search query in telegraphic style. The query should retain the essential information but remove unnecessary words. Generate upto 5 such questions with few being a bit vague"""

In [28]:
offset = 73

for i, json_file in enumerate(processed_jsons):
    if i < offset:
        continue
    
    data = read_json(json_file)
    # print(data)

    for key_to_check in keys_to_check:
        _choice = random()
        model = "claude-3-5-sonnet-20240620" if _choice < sonnet_chance else "claude-3-haiku-20240307"
        
        if key_to_check == "interpretations" or key_to_check == "prior_history":
            if not data[key_to_check]:
                continue

            idx = randint(0, len(data[key_to_check]) - 1)
            content = data[key_to_check][idx]
            
            if len(content.split(" ")) < min_words:
                continue

            context = [{"role": "user", "content": f"<context> The following is related to one of the {key_to_check} of a case law </context> \n\n {content}"}]
            questions_generated = extract_queries(mutli_q_agent._process_call(context = context, model = model, stream = False, system = new_system_prompt))
            # print(questions_generated)

            synthetic_dataset.append({
                "queries": questions_generated,
                "model": model,
                "key_used": key_to_check,
                "passage": content,
                "file_path": json_file
            })



        else:
            content = data.get(key_to_check, None)
            if not content:
                continue

            if len(content.split(" ")) < min_words:
                continue

            context = [{"role": "user", "content": f"<context> The following is related to the {key_to_check} part of a case law </context> \n\n {content}"}]
            questions_generated = extract_queries(mutli_q_agent._process_call(context = context, model = model, stream = False))
            # print(questions_generated)
            # questions_generated = extract_queries(questions_generated)
            
            synthetic_dataset.append({
                "queries": questions_generated,
                "model": model,
                "key_used": key_to_check,
                "passage": content,
                "file_path": json_file
            })

    print(f"done with {json_file}; {i}")


GETTING THIS AS CONTEXT: [{'role': 'user', 'content': '<context> The following is related to the judgement part of a case law </context> \n\n The Supreme Court dismissed the appeals filed by the accused persons and upheld the conviction by the High Court. The High Court had set aside the acquittal of the accused persons by the trial court and convicted them for offences punishable under Section 326 read with Section 34 of the Indian Penal Code, 1860. The accused persons were sentenced to undergo rigorous imprisonment for four years, while accused Durga was convicted for offence punishable under Section 324 IPC and sentenced to undergo imprisonment for one year.'}]
GETTING THIS AS CONTEXT: [{'role': 'user', 'content': '<context> The following is related to the cause_of_action part of a case law </context> \n\n The case was initiated after the deceased was assaulted and wounded by the accused persons, leading to his death. The informant lodged a complaint, and the accused persons were ch

In [29]:
len(synthetic_dataset)

7740

In [32]:
synthetic_dataset[-4]["queries"]

["Find case law where restaurant businesses challenged the applicability of the Employees' Provident Funds Act, 1952.",
 "Judgements related to the constitutional validity of the Employees' Provident Funds Act, 1952.",
 "Precedents where businesses challenged the inclusion under the Employees' Provident Funds Act, 1952 by government notification.",
 "Case law on the scope of Section 1(3)(b) of the Employees' Provident Funds Act, 1952.",
 "Judgements on the powers of the Central Government to bring businesses under the Employees' Provident Funds Act, 1952."]

In [34]:
write_json(synthetic_dataset, "./synthetic_embedding_data_list_pt3.json")

## Fine-tune embedding model

In [2]:
import os, json
from dotenv import load_dotenv
import torch
import sys
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [3]:
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator
from sentence_transformers import SentenceTransformerTrainer

In [4]:
from random import randint, random, choice

In [5]:
from datasets import load_dataset, Dataset, concatenate_datasets

In [5]:
device = "cuda" # "cpu"
model_id = "BAAI/bge-base-en-v1.5" # context_len 512; dims 768; 109m params
# model_id = "BAAI/bge-small-en-v1.5" # context_len 512; dims 384; 33m params

In [6]:
model = SentenceTransformer(model_id, device = device)



In [7]:
# del model
# if you do del model, you must also run this or model instances will stack
# torch.cuda.empty_cache() 

In [6]:
dataset_path = "./synthetic_embedding_data_list_pt3.json"

In [12]:
json_dataset = read_json(dataset_path)
len(json_dataset)

7740

In [13]:
train_fraction = 0.9

In [14]:
def convert_custom_to_sbert_format(dataset, train_split = 0.85):
    # this function has to be somewhat custom written from some format to a specific format supported by sentence-transformers. Depends on your training data if its (anchor, anchor) or (anchor, positive, negative) etc
    # this function will convert whatever we have got from paraphrasing to an (anchor, anchor) pair with NO label
    _cvt = {"anchor": [], "positive": [], "id": []}
    c = 0
    for i, x in enumerate(dataset):
        if not x["queries"]:
            continue
        
        
        # for query in x["queries"]:
            
        #     _cvt["anchor"].append(query)
        #     _cvt["positive"].append(x["passage"])
        #     _cvt["id"].append(c)
        #     c += 1
            
        _cvt["anchor"].append(choice(x["queries"]))
        _cvt["positive"].append(x["passage"])
        _cvt["id"].append(i)
            
        # _cvt.append({
        #     "anchor": choice(x["queries"]),
        #     "positive": x["passage"]
        # })
    new_ds = Dataset.from_dict(_cvt).train_test_split(train_size = train_split, shuffle = True)
    # new_ds = new_ds.add_column("id", range(len(new_ds)))
    train_set, test_set = new_ds["train"], new_ds["test"]
    return train_set, test_set

In [15]:
def convert_dict_to_dataset(_dict, split = False, train_split = 0.85):
    # new_ds = Dataset.from_dict(_dict).train_test_split(train_size = train_split, shuffle = True)
    new_ds = Dataset.from_dict(_dict)

    if split:
        new_ds = new_ds.train_test_split(train_size = train_size, shuffle = True)
        # new_ds = new_ds.add_column("id", range(len(new_ds)))
        train_set, test_set = new_ds["train"], new_ds["test"]
        return train_set, test_set

    return new_ds


def convert_indexed_to_sbert_format(dataset, train_split = 0.85):
    # this function has to be somewhat custom written from some format to a specific format supported by sentence-transformers. Depends on your training data if its (anchor, anchor) or (anchor, positive, negative) etc
    # this function will convert whatever we have got from paraphrasing to an (anchor, anchor) pair with NO label
    _cvt = {"anchor": [], "positive": [], "id": []}
    c = 0
    
    for i, x in enumerate(dataset):
        if not x["queries"]:
            continue
        
        
        # for query in x["queries"]:
            
        #     _cvt["anchor"].append(query)
        #     _cvt["positive"].append(x["passage"])
        #     _cvt["id"].append(c)
        #     c += 1
            
        _cvt["anchor"].append(choice(x["queries"]))
        _cvt["positive"].append(x["passage"])
        _cvt["id"].append(i)
            
        # _cvt.append({
        #     "anchor": choice(x["queries"]),
        #     "positive": x["passage"]
        # })
    new_ds = Dataset.from_dict(_cvt).train_test_split(train_size = train_split, shuffle = True)
    # new_ds = new_ds.add_column("id", range(len(new_ds)))
    train_set, test_set = new_ds["train"], new_ds["test"]
    return train_set, test_set

In [16]:
def convert_custom_to_sbert_format_v2(dataset, train_split = 0.85):
    # this function has to be somewhat custom written from some format to a specific format supported by sentence-transformers. Depends on your training data if its (anchor, anchor) or (anchor, positive, negative) etc
    # this function will convert whatever we have got from paraphrasing to an (anchor, anchor) pair with NO label
    _cvt = {"anchor": [], "positive": [], "id": []}
    c = 0
    for i, x in enumerate(dataset):
        if not x["queries"]:
            continue
        
        
        # for query in x["queries"]:
            
        #     _cvt["anchor"].append(query)
        #     _cvt["positive"].append(x["passage"])
        #     _cvt["id"].append(c)
        #     c += 1
            
        _cvt["anchor"].append(choice(x["queries"]))
        _cvt["positive"].append(x["passage"])
        _cvt["id"].append(i)
            
        # _cvt.append({
        #     "anchor": choice(x["queries"]),
        #     "positive": x["passage"]
        # })

    
    
    new_ds = Dataset.from_dict(_cvt).train_test_split(train_size = train_split, shuffle = True)
    # new_ds = new_ds.add_column("id", range(len(new_ds)))
    train_set, test_set = new_ds["train"], new_ds["test"]
    return train_set, test_set

In [16]:
train_set, test_set = convert_custom_to_sbert_format(dataset = json_dataset, train_split = train_fraction)
train_set, test_set

(Dataset({
     features: ['anchor', 'positive', 'id'],
     num_rows: 6885
 }),
 Dataset({
     features: ['anchor', 'positive', 'id'],
     num_rows: 765
 }))

In [17]:
corpus_dataset = concatenate_datasets([train_set, test_set])
corpus_dataset

Dataset({
    features: ['anchor', 'positive', 'id'],
    num_rows: 7650
})

In [21]:
for cd in corpus_dataset:
    print(cd)
    break

{'anchor': 'Find case law where principles of natural justice not violated in dismissal cases.', 'positive': 'The Supreme Court held that the principles of natural justice were not violated, and the punishment of dismissal was not harsh or disproportionate to the charges proved against the respondent. The High Court and Industrial Tribunal were not justified in interfering with the order of dismissal passed by the management.', 'id': 3805}


In [18]:
## HACKY REPLACEMENT FOR THIS TO CHECK FOR 0 SHOT

_dummy_keys_to_check = ["judgement", "cause_of_action", "interpretations", "case_brief"]
for i, x in enumerate(corpus_dataset["positive"]):
    json_data = read_json(choice(processed_jsons))
    _random_key = choice(_dummy_keys_to_check)
    if not _random_key in json_data:
        continue

    key_data = json_data[_random_key]

    # if isinstance(key_data, list):
    try:
        key_data = choice(key_data) if isinstance(key_data, list) else key_data
    except Exception as e:
        print(e, key_data, i)
        continue
    corpus_dataset["positive"][i] = key_data
    # corpus_dataset["positive"].append(key_data)
    # print(json_data)
    # break


JSONDecodeError: Invalid control character at: line 34 column 61 (char 5085)

In [None]:
## HACKY REPLACEMENT FOR THIS TO CHECK FOR 0 SHOT
num = corpus_dataset.num_rows
i = 0
_dummy_keys_to_check = ["judgement", "cause_of_action", "interpretations", "case_brief"]

# for i, x in enumerate(corpus_dataset["positive"]):
while i < num:
    try:
        json_data = read_json(choice(processed_jsons))
        _random_key = choice(_dummy_keys_to_check)
        if not _random_key in json_data:
            continue

        key_data = json_data[_random_key]

        # if isinstance(key_data, list):
        try:
            key_data = choice(key_data) if isinstance(key_data, list) else key_data
        except Exception as e:
            print(e, key_data, i)
            continue
        corpus_dataset["positive"].append(key_data)
        corpus_dataset["id"].append(num + i)
        
        # corpus_dataset["positive"].append(key_data)
        # print(json_data)
        # break
        i += 1
    except:
        pass       



In [24]:
corpus_dataset

Dataset({
    features: ['anchor', 'positive', 'id'],
    num_rows: 7650
})

In [23]:
corpus = dict(
    zip(corpus_dataset["id"], corpus_dataset["positive"])
)  # Our corpus (cid => document)
queries = dict(
    zip(test_set["id"], test_set["anchor"])
)  # Our queries (qid => question)
 
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids])
for q_id in queries:
    relevant_docs[q_id] = [q_id]

In [27]:
for c, x in queries.items():
    print(c, x)
    break

100 Precedents related to marking scheme disputes education exams


In [78]:
ir_evaluator = InformationRetrievalEvaluator(queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs)

In [None]:
loss = MultipleNegativesRankingLoss(model = model)

In [None]:
run_name = "all-7k"

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=400,
    save_strategy="steps",
    save_steps=400,
    save_total_limit=2,
    logging_steps=100,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

In [None]:
# 5k used with choice(queries)
trainer = SentenceTransformerTrainer(
    model = model, # bg-base-en-v1
    args = args,  # training arguments
    train_dataset = train_set.select_columns(
        ["positive", "anchor"]
    ),  # training dataset
    loss = loss,
    evaluator = ir_evaluator,
)

In [None]:
trainer.train()

In [79]:
results = ir_evaluator(model)
results

{'cosine_accuracy@1': 0.39869281045751637,
 'cosine_accuracy@3': 0.6052287581699346,
 'cosine_accuracy@5': 0.6562091503267974,
 'cosine_accuracy@10': 0.7241830065359477,
 'cosine_precision@1': 0.39869281045751637,
 'cosine_precision@3': 0.20174291938997818,
 'cosine_precision@5': 0.13124183006535947,
 'cosine_precision@10': 0.07241830065359477,
 'cosine_recall@1': 0.39869281045751637,
 'cosine_recall@3': 0.6052287581699346,
 'cosine_recall@5': 0.6562091503267974,
 'cosine_recall@10': 0.7241830065359477,
 'cosine_ndcg@10': 0.5608259138074697,
 'cosine_mrr@10': 0.5084624961095547,
 'cosine_map@100': 0.5146917884428458,
 'dot_accuracy@1': 0.4,
 'dot_accuracy@3': 0.6052287581699346,
 'dot_accuracy@5': 0.6562091503267974,
 'dot_accuracy@10': 0.7241830065359477,
 'dot_precision@1': 0.4,
 'dot_precision@3': 0.20174291938997818,
 'dot_precision@5': 0.13124183006535947,
 'dot_precision@10': 0.07241830065359477,
 'dot_recall@1': 0.4,
 'dot_recall@3': 0.6052287581699346,
 'dot_recall@5': 0.656209

In [95]:
new_model = SentenceTransformer("./models/all-7k/checkpoint-2000", device = "cuda")
results = ir_evaluator(new_model)
results

In [82]:
new_model = SentenceTransformer("./models/custom-5k/checkpoint-1300", device = "cuda")
results = ir_evaluator(new_model)
results

{'cosine_accuracy@1': 0.5241830065359477,
 'cosine_accuracy@3': 0.7594771241830065,
 'cosine_accuracy@5': 0.8091503267973856,
 'cosine_accuracy@10': 0.869281045751634,
 'cosine_precision@1': 0.5241830065359477,
 'cosine_precision@3': 0.25315904139433554,
 'cosine_precision@5': 0.1618300653594771,
 'cosine_precision@10': 0.08692810457516338,
 'cosine_recall@1': 0.5241830065359477,
 'cosine_recall@3': 0.7594771241830065,
 'cosine_recall@5': 0.8091503267973856,
 'cosine_recall@10': 0.869281045751634,
 'cosine_ndcg@10': 0.7018468726308886,
 'cosine_mrr@10': 0.6476014109347435,
 'cosine_map@100': 0.6521047831641598,
 'dot_accuracy@1': 0.5241830065359477,
 'dot_accuracy@3': 0.7594771241830065,
 'dot_accuracy@5': 0.8091503267973856,
 'dot_accuracy@10': 0.869281045751634,
 'dot_precision@1': 0.5241830065359477,
 'dot_precision@3': 0.25315904139433554,
 'dot_precision@5': 0.1618300653594771,
 'dot_precision@10': 0.08692810457516338,
 'dot_recall@1': 0.5241830065359477,
 'dot_recall@3': 0.759477

## Matryoshka Embedding models

In [1]:
matryoshka_dimensions = [768, 512, 256, 128, 64] # Important: large to small; 

In [2]:
# lets do it for both off the shelf and fine-tuned models to see delta

In [3]:
from sentence_transformers.util import cos_sim

  from tqdm.autonotebook import tqdm, trange


In [4]:
matryoshka_evaluators = []

for dim in matryoshka_dimensions:
    matryoshka_ir_evaluator = InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=f"dim_{dim}",
        truncate_dim=dim,  # Truncate the embeddings to a certain dimension
        score_functions={"cosine": cos_sim},
    )
    matryoshka_evaluators.append(matryoshka_ir_evaluator)


evaluator = SequentialEvaluator(matryoshka_evaluators)



NameError: name 'InformationRetrievalEvaluator' is not defined

In [30]:

# Evaluate the model (default model)
results = evaluator(model)
 
# # COMMENT IN for full results
# print(results)
 
# Print the main score
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    cos_acc_key = f"dim_{dim}_cosine_accuracy@10"
    print(f"{key}: {results[key]}")
    print(f"{cos_acc_key}: {results[cos_acc_key]}")

dim_768_cosine_ndcg@10: 0.5793473638743243
dim_768_cosine_ndcg@10: 0.7450980392156863
dim_512_cosine_ndcg@10: 0.5689299585686538
dim_512_cosine_ndcg@10: 0.7372549019607844
dim_256_cosine_ndcg@10: 0.5320024256594003
dim_256_cosine_ndcg@10: 0.6901960784313725
dim_128_cosine_ndcg@10: 0.4784713220864311
dim_128_cosine_ndcg@10: 0.6313725490196078
dim_64_cosine_ndcg@10: 0.35778822156665946
dim_64_cosine_ndcg@10: 0.4875816993464052


In [32]:
for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    cos_acc_key = f"dim_{dim}_cosine_accuracy@10"
    print(f"{key}: {results[key]}")
    print(f"{cos_acc_key}: {results[cos_acc_key]}")

dim_768_cosine_ndcg@10: 0.5793473638743243
dim_768_cosine_accuracy@10: 0.7450980392156863
dim_512_cosine_ndcg@10: 0.5689299585686538
dim_512_cosine_accuracy@10: 0.7372549019607844
dim_256_cosine_ndcg@10: 0.5320024256594003
dim_256_cosine_accuracy@10: 0.6901960784313725
dim_128_cosine_ndcg@10: 0.4784713220864311
dim_128_cosine_accuracy@10: 0.6313725490196078
dim_64_cosine_ndcg@10: 0.35778822156665946
dim_64_cosine_accuracy@10: 0.4875816993464052


In [98]:
# finetuned but without the matryoshka adapter

finetuned_results = evaluator(new_model)

for dim in matryoshka_dimensions:
    key = f"dim_{dim}_cosine_ndcg@10"
    cos_acc_key = f"dim_{dim}_cosine_accuracy@10"
    print(f"{key}: {finetuned_results[key]}")
    print(f"{key}: {finetuned_results[cos_acc_key]}")

dim_768_cosine_ndcg@10: 0.7155822169667164
dim_512_cosine_ndcg@10: 0.710507006236591
dim_256_cosine_ndcg@10: 0.7007075382207802
dim_128_cosine_ndcg@10: 0.6767161080317702
dim_64_cosine_ndcg@10: 0.6278630199025461


In [119]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss


finetuned_model = SentenceTransformer("./models/all-7k/checkpoint-2000", device = "cuda")

matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(finetuned_model)

# train with different weights as well! 
matryoshka_loss = MatryoshkaLoss(
    finetuned_model, inner_train_loss, matryoshka_dims=matryoshka_dimensions, matryoshka_weights=[1,1,1,1.15,1.15]
)

In [33]:
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss


matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small
inner_train_loss = MultipleNegativesRankingLoss(model)

# train with different weights as well! 
matryoshka_loss = MatryoshkaLoss(
    model, inner_train_loss, matryoshka_dims=matryoshka_dimensions, matryoshka_weights=[1,1,1,1.15,1.15]
)

In [34]:
# define training arguments

matryoshka_args = SentenceTransformerTrainingArguments(
    output_dir="./models/matryoshka_models_raw", # output directory and hugging face model ID
    num_train_epochs=5,                         # number of epochs
    per_device_train_batch_size=16,             # train batch size
    gradient_accumulation_steps=16,             # for a global batch size of 512
    per_device_eval_batch_size=16,              # evaluation batch size
    warmup_ratio=0.1,                           # warmup ratio
    learning_rate=2e-5,                         # learning rate, 2e-5 is a good value
    lr_scheduler_type="cosine",                 # use constant learning rate scheduler
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    fp16=True,                                  # use tf32 precision
    bf16=False,                                  # use bf16 precision
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    eval_strategy="epoch",                      # evaluate after each epoch
    save_strategy="epoch",                      # save after each epoch
    logging_steps=100,                          # log every 10 steps
    save_total_limit=2,                         # save only the last 3 models
    # load_best_model_at_end=True,                # load the best model when training ends
    # metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # Optimizing for the best ndcg@10 score for the 128 dimension
)

In [35]:
 
matryoshka_trainer = SentenceTransformerTrainer(
    model=model, # bg-base-en-v1
    args=matryoshka_args,  # training arguments
    train_dataset=train_set.select_columns(
        ["positive", "anchor"]
    ),  # training dataset
    loss=matryoshka_loss,
    evaluator=evaluator,
)

In [114]:
# finetuned model but with matryoshka loss with same weight (1)
matryoshka_trainer.train()

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,No log,No log,0.547712,0.784314,0.84183,0.895425,0.547712,0.261438,0.168366,0.089542,0.547712,0.784314,0.84183,0.895425,0.726919,0.67209,0.675435,0.534641,0.780392,0.835294,0.894118,0.534641,0.260131,0.167059,0.089412,0.534641,0.780392,0.835294,0.894118,0.720151,0.663616,0.666788,0.538562,0.756863,0.815686,0.877124,0.538562,0.252288,0.163137,0.087712,0.538562,0.756863,0.815686,0.877124,0.711099,0.657497,0.661691,0.503268,0.717647,0.784314,0.858824,0.503268,0.239216,0.156863,0.085882,0.503268,0.717647,0.784314,0.858824,0.681309,0.624471,0.628828,0.469281,0.681046,0.730719,0.8,0.469281,0.227015,0.146144,0.08,0.469281,0.681046,0.730719,0.8,0.637201,0.584797,0.591212,0.591212
1,No log,No log,0.542484,0.786928,0.848366,0.896732,0.542484,0.262309,0.169673,0.089673,0.542484,0.786928,0.848366,0.896732,0.727154,0.671733,0.675108,0.533333,0.773856,0.839216,0.894118,0.533333,0.257952,0.167843,0.089412,0.533333,0.773856,0.839216,0.894118,0.719999,0.663333,0.666676,0.535948,0.754248,0.818301,0.881046,0.535948,0.251416,0.16366,0.088105,0.535948,0.754248,0.818301,0.881046,0.712563,0.658113,0.662003,0.499346,0.732026,0.796078,0.860131,0.499346,0.244009,0.159216,0.086013,0.499346,0.732026,0.796078,0.860131,0.683335,0.626196,0.630628,0.462745,0.675817,0.734641,0.807843,0.462745,0.225272,0.146928,0.080784,0.462745,0.675817,0.734641,0.807843,0.636029,0.580937,0.587033,0.587033
2,No log,No log,0.534641,0.779085,0.837908,0.894118,0.534641,0.259695,0.167582,0.089412,0.534641,0.779085,0.837908,0.894118,0.721596,0.665389,0.668879,0.522876,0.769935,0.837908,0.888889,0.522876,0.256645,0.167582,0.088889,0.522876,0.769935,0.837908,0.888889,0.712934,0.655595,0.659279,0.530719,0.755556,0.813072,0.875817,0.530719,0.251852,0.162614,0.087582,0.530719,0.755556,0.813072,0.875817,0.707787,0.653478,0.657734,0.494118,0.720261,0.79085,0.857516,0.494118,0.240087,0.15817,0.085752,0.494118,0.720261,0.79085,0.857516,0.678478,0.620706,0.625345,0.453595,0.667974,0.739869,0.807843,0.453595,0.222658,0.147974,0.080784,0.453595,0.667974,0.739869,0.807843,0.632896,0.576557,0.582645,0.582645
3,0.321200,No log,0.539869,0.779085,0.839216,0.895425,0.539869,0.259695,0.167843,0.089542,0.539869,0.779085,0.839216,0.895425,0.724212,0.668532,0.671937,0.526797,0.773856,0.837908,0.888889,0.526797,0.257952,0.167582,0.088889,0.526797,0.773856,0.837908,0.888889,0.715686,0.659159,0.662881,0.534641,0.756863,0.815686,0.875817,0.534641,0.252288,0.163137,0.087582,0.534641,0.756863,0.815686,0.875817,0.709761,0.656032,0.660374,0.500654,0.72549,0.79085,0.861438,0.500654,0.24183,0.15817,0.086144,0.500654,0.72549,0.79085,0.861438,0.683349,0.62601,0.630352,0.454902,0.678431,0.742484,0.80915,0.454902,0.226144,0.148497,0.080915,0.454902,0.678431,0.742484,0.80915,0.635277,0.579181,0.585329,0.585329
4,0.321200,No log,0.541176,0.776471,0.839216,0.895425,0.541176,0.258824,0.167843,0.089542,0.541176,0.776471,0.839216,0.895425,0.724006,0.668322,0.671724,0.52549,0.772549,0.837908,0.888889,0.52549,0.257516,0.167582,0.088889,0.52549,0.772549,0.837908,0.888889,0.714942,0.658178,0.661905,0.533333,0.755556,0.814379,0.875817,0.533333,0.251852,0.162876,0.087582,0.533333,0.755556,0.814379,0.875817,0.708936,0.654967,0.659298,0.498039,0.72549,0.79085,0.861438,0.498039,0.24183,0.15817,0.086144,0.498039,0.72549,0.79085,0.861438,0.681699,0.623832,0.628161,0.457516,0.677124,0.742484,0.810458,0.457516,0.225708,0.148497,0.081046,0.457516,0.677124,0.742484,0.810458,0.635889,0.579726,0.585756,0.585756


                                                                             

TrainOutput(global_step=130, training_loss=0.28168648939866286, metrics={'train_runtime': 974.2751, 'train_samples_per_second': 35.334, 'train_steps_per_second': 0.133, 'total_flos': 0.0, 'train_loss': 0.28168648939866286, 'epoch': 4.825986078886311})

In [122]:
# finetuned model but with matryoshka loss with weights of (1, 1, 1, 1.15, 1.15)
# checkpoint 107 and 130
matryoshka_trainer.train()

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,No log,No log,0.545098,0.785621,0.840523,0.895425,0.545098,0.261874,0.168105,0.089542,0.545098,0.785621,0.840523,0.895425,0.725753,0.670542,0.673849,0.533333,0.777778,0.83268,0.89281,0.533333,0.259259,0.166536,0.089281,0.533333,0.777778,0.83268,0.89281,0.719179,0.662729,0.665997,0.538562,0.756863,0.810458,0.877124,0.538562,0.252288,0.162092,0.087712,0.538562,0.756863,0.810458,0.877124,0.710898,0.657317,0.661491,0.500654,0.718954,0.780392,0.857516,0.500654,0.239651,0.156078,0.085752,0.500654,0.718954,0.780392,0.857516,0.680091,0.623246,0.62769,0.466667,0.682353,0.730719,0.8,0.466667,0.227451,0.146144,0.08,0.466667,0.682353,0.730719,0.8,0.636297,0.583522,0.589952,0.589952
1,No log,No log,0.543791,0.788235,0.845752,0.895425,0.543791,0.262745,0.16915,0.089542,0.543791,0.788235,0.845752,0.895425,0.727375,0.672421,0.675901,0.533333,0.776471,0.837908,0.89281,0.533333,0.258824,0.167582,0.089281,0.533333,0.776471,0.837908,0.89281,0.719283,0.662804,0.666238,0.537255,0.751634,0.816993,0.882353,0.537255,0.250545,0.163399,0.088235,0.537255,0.751634,0.816993,0.882353,0.712988,0.658377,0.662145,0.498039,0.728105,0.797386,0.858824,0.498039,0.242702,0.159477,0.085882,0.498039,0.728105,0.797386,0.858824,0.681771,0.624554,0.629106,0.462745,0.67451,0.732026,0.80915,0.462745,0.224837,0.146405,0.080915,0.462745,0.67451,0.732026,0.80915,0.636484,0.58118,0.587228,0.587228
2,No log,No log,0.535948,0.776471,0.836601,0.89281,0.535948,0.258824,0.16732,0.089281,0.535948,0.776471,0.836601,0.89281,0.720874,0.664889,0.668488,0.518954,0.769935,0.839216,0.887582,0.518954,0.256645,0.167843,0.088758,0.518954,0.769935,0.839216,0.887582,0.71111,0.65351,0.657284,0.529412,0.751634,0.813072,0.877124,0.529412,0.250545,0.162614,0.087712,0.529412,0.751634,0.813072,0.877124,0.707259,0.652419,0.656532,0.494118,0.721569,0.789542,0.857516,0.494118,0.240523,0.157908,0.085752,0.494118,0.721569,0.789542,0.857516,0.67867,0.620959,0.625602,0.453595,0.666667,0.739869,0.807843,0.453595,0.222222,0.147974,0.080784,0.453595,0.666667,0.739869,0.807843,0.63225,0.575753,0.581841,0.581841
3,0.351000,No log,0.542484,0.779085,0.836601,0.895425,0.542484,0.259695,0.16732,0.089542,0.542484,0.779085,0.836601,0.895425,0.72463,0.669181,0.672551,0.532026,0.772549,0.837908,0.890196,0.532026,0.257516,0.167582,0.08902,0.532026,0.772549,0.837908,0.890196,0.717869,0.661753,0.665334,0.541176,0.755556,0.811765,0.87451,0.541176,0.251852,0.162353,0.087451,0.541176,0.755556,0.811765,0.87451,0.711209,0.658491,0.662943,0.501961,0.72549,0.79085,0.860131,0.501961,0.24183,0.15817,0.086013,0.501961,0.72549,0.79085,0.860131,0.683229,0.626261,0.630687,0.457516,0.678431,0.739869,0.80915,0.457516,0.226144,0.147974,0.080915,0.457516,0.678431,0.739869,0.80915,0.635813,0.579957,0.586123,0.586123
4,0.351000,No log,0.538562,0.776471,0.836601,0.895425,0.538562,0.258824,0.16732,0.089542,0.538562,0.776471,0.836601,0.895425,0.723239,0.66729,0.670657,0.528105,0.771242,0.837908,0.890196,0.528105,0.257081,0.167582,0.08902,0.528105,0.771242,0.837908,0.890196,0.716217,0.659531,0.663121,0.537255,0.755556,0.813072,0.87451,0.537255,0.251852,0.162614,0.087451,0.537255,0.755556,0.813072,0.87451,0.709778,0.656521,0.660972,0.501961,0.724183,0.789542,0.861438,0.501961,0.241394,0.157908,0.086144,0.501961,0.724183,0.789542,0.861438,0.683236,0.625907,0.630244,0.457516,0.677124,0.739869,0.810458,0.457516,0.225708,0.147974,0.081046,0.457516,0.677124,0.739869,0.810458,0.636123,0.580041,0.586115,0.586115


                                                                             

TrainOutput(global_step=130, training_loss=0.30775490173926723, metrics={'train_runtime': 3894.6015, 'train_samples_per_second': 8.839, 'train_steps_per_second': 0.033, 'total_flos': 0.0, 'train_loss': 0.30775490173926723, 'epoch': 4.825986078886311})

In [36]:
# base model finetuning it with matryoshka directly with weights of (1, 1, 1, 1, 1)
matryoshka_trainer.train()

Epoch,Training Loss,Validation Loss,Dim 768 Cosine Accuracy@1,Dim 768 Cosine Accuracy@3,Dim 768 Cosine Accuracy@5,Dim 768 Cosine Accuracy@10,Dim 768 Cosine Precision@1,Dim 768 Cosine Precision@3,Dim 768 Cosine Precision@5,Dim 768 Cosine Precision@10,Dim 768 Cosine Recall@1,Dim 768 Cosine Recall@3,Dim 768 Cosine Recall@5,Dim 768 Cosine Recall@10,Dim 768 Cosine Ndcg@10,Dim 768 Cosine Mrr@10,Dim 768 Cosine Map@100,Dim 512 Cosine Accuracy@1,Dim 512 Cosine Accuracy@3,Dim 512 Cosine Accuracy@5,Dim 512 Cosine Accuracy@10,Dim 512 Cosine Precision@1,Dim 512 Cosine Precision@3,Dim 512 Cosine Precision@5,Dim 512 Cosine Precision@10,Dim 512 Cosine Recall@1,Dim 512 Cosine Recall@3,Dim 512 Cosine Recall@5,Dim 512 Cosine Recall@10,Dim 512 Cosine Ndcg@10,Dim 512 Cosine Mrr@10,Dim 512 Cosine Map@100,Dim 256 Cosine Accuracy@1,Dim 256 Cosine Accuracy@3,Dim 256 Cosine Accuracy@5,Dim 256 Cosine Accuracy@10,Dim 256 Cosine Precision@1,Dim 256 Cosine Precision@3,Dim 256 Cosine Precision@5,Dim 256 Cosine Precision@10,Dim 256 Cosine Recall@1,Dim 256 Cosine Recall@3,Dim 256 Cosine Recall@5,Dim 256 Cosine Recall@10,Dim 256 Cosine Ndcg@10,Dim 256 Cosine Mrr@10,Dim 256 Cosine Map@100,Dim 128 Cosine Accuracy@1,Dim 128 Cosine Accuracy@3,Dim 128 Cosine Accuracy@5,Dim 128 Cosine Accuracy@10,Dim 128 Cosine Precision@1,Dim 128 Cosine Precision@3,Dim 128 Cosine Precision@5,Dim 128 Cosine Precision@10,Dim 128 Cosine Recall@1,Dim 128 Cosine Recall@3,Dim 128 Cosine Recall@5,Dim 128 Cosine Recall@10,Dim 128 Cosine Ndcg@10,Dim 128 Cosine Mrr@10,Dim 128 Cosine Map@100,Dim 64 Cosine Accuracy@1,Dim 64 Cosine Accuracy@3,Dim 64 Cosine Accuracy@5,Dim 64 Cosine Accuracy@10,Dim 64 Cosine Precision@1,Dim 64 Cosine Precision@3,Dim 64 Cosine Precision@5,Dim 64 Cosine Precision@10,Dim 64 Cosine Recall@1,Dim 64 Cosine Recall@3,Dim 64 Cosine Recall@5,Dim 64 Cosine Recall@10,Dim 64 Cosine Ndcg@10,Dim 64 Cosine Mrr@10,Dim 64 Cosine Map@100,Sequential Score
0,No log,No log,0.532026,0.760784,0.813072,0.866667,0.532026,0.253595,0.162614,0.086667,0.532026,0.760784,0.813072,0.866667,0.706245,0.654041,0.658321,0.515033,0.75817,0.798693,0.861438,0.515033,0.252723,0.159739,0.086144,0.515033,0.75817,0.798693,0.861438,0.696056,0.642354,0.646731,0.498039,0.737255,0.784314,0.839216,0.498039,0.245752,0.156863,0.083922,0.498039,0.737255,0.784314,0.839216,0.67652,0.6235,0.628579,0.484967,0.696732,0.735948,0.805229,0.484967,0.232244,0.14719,0.080523,0.484967,0.696732,0.735948,0.805229,0.649308,0.599034,0.605127,0.418301,0.605229,0.670588,0.738562,0.418301,0.201743,0.134118,0.073856,0.418301,0.605229,0.670588,0.738562,0.578044,0.52677,0.533258,0.533258
1,No log,No log,0.547712,0.766013,0.835294,0.887582,0.547712,0.255338,0.167059,0.088758,0.547712,0.766013,0.835294,0.887582,0.723587,0.670295,0.673924,0.541176,0.754248,0.811765,0.879739,0.541176,0.251416,0.162353,0.087974,0.541176,0.754248,0.811765,0.879739,0.714735,0.66145,0.665567,0.513725,0.745098,0.805229,0.865359,0.513725,0.248366,0.161046,0.086536,0.513725,0.745098,0.805229,0.865359,0.697686,0.64318,0.647724,0.487582,0.708497,0.776471,0.844444,0.487582,0.236166,0.155294,0.084444,0.487582,0.708497,0.776471,0.844444,0.669676,0.613338,0.618129,0.428758,0.639216,0.694118,0.759477,0.428758,0.213072,0.138824,0.075948,0.428758,0.639216,0.694118,0.759477,0.598818,0.546922,0.554265,0.554265
2,No log,No log,0.547712,0.764706,0.833987,0.884967,0.547712,0.254902,0.166797,0.088497,0.547712,0.764706,0.833987,0.884967,0.722755,0.669967,0.674052,0.546405,0.746405,0.80915,0.882353,0.546405,0.248802,0.16183,0.088235,0.546405,0.746405,0.80915,0.882353,0.715693,0.662241,0.666332,0.520261,0.745098,0.805229,0.869281,0.520261,0.248366,0.161046,0.086928,0.520261,0.745098,0.805229,0.869281,0.701419,0.647081,0.651576,0.500654,0.713725,0.781699,0.854902,0.500654,0.237908,0.15634,0.08549,0.500654,0.713725,0.781699,0.854902,0.680501,0.62451,0.628635,0.435294,0.639216,0.70719,0.76732,0.435294,0.213072,0.141438,0.076732,0.435294,0.639216,0.70719,0.76732,0.605348,0.553093,0.560303,0.560303
3,0.960900,No log,0.54902,0.764706,0.831373,0.884967,0.54902,0.254902,0.166275,0.088497,0.54902,0.764706,0.831373,0.884967,0.722537,0.669824,0.673869,0.543791,0.743791,0.807843,0.88366,0.543791,0.24793,0.161569,0.088366,0.543791,0.743791,0.807843,0.88366,0.71471,0.660669,0.66468,0.524183,0.739869,0.802614,0.869281,0.524183,0.246623,0.160523,0.086928,0.524183,0.739869,0.802614,0.869281,0.701948,0.647889,0.652328,0.503268,0.705882,0.780392,0.852288,0.503268,0.235294,0.156078,0.085229,0.503268,0.705882,0.780392,0.852288,0.679531,0.624135,0.628566,0.44183,0.637908,0.717647,0.771242,0.44183,0.212636,0.143529,0.077124,0.44183,0.637908,0.717647,0.771242,0.609034,0.556867,0.563893,0.563893
4,0.960900,No log,0.547712,0.764706,0.831373,0.886275,0.547712,0.254902,0.166275,0.088627,0.547712,0.764706,0.831373,0.886275,0.722252,0.669071,0.672988,0.541176,0.743791,0.807843,0.88366,0.541176,0.24793,0.161569,0.088366,0.541176,0.743791,0.807843,0.88366,0.713616,0.659194,0.663203,0.522876,0.741176,0.802614,0.869281,0.522876,0.247059,0.160523,0.086928,0.522876,0.741176,0.802614,0.869281,0.701708,0.647544,0.651981,0.503268,0.705882,0.783007,0.852288,0.503268,0.235294,0.156601,0.085229,0.503268,0.705882,0.783007,0.852288,0.679808,0.624459,0.628881,0.440523,0.639216,0.715033,0.769935,0.440523,0.213072,0.143007,0.076993,0.440523,0.639216,0.715033,0.769935,0.608322,0.556258,0.563428,0.563428


                                                                     

TrainOutput(global_step=130, training_loss=0.8129666915306678, metrics={'train_runtime': 904.4229, 'train_samples_per_second': 38.063, 'train_steps_per_second': 0.144, 'total_flos': 0.0, 'train_loss': 0.8129666915306678, 'epoch': 4.825986078886311})