In [None]:
import json

from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SentenceSplitter
from llama_index.schema import MetadataMode

In [3]:
TRAIN_FILES = ["train.txt"]
VAL_FILES = ["test.txt"]

TRAIN_CORPUS_FPATH = "train_corpus.json"
VAL_CORPUS_FPATH = "val_corpus.json"

In [4]:
def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    parser = SentenceSplitter(chunk_size=250, chunk_overlap=0)
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes

In [5]:
train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

Loading files ['train.txt']
Loaded 1 docs


Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]

Parsed 129 nodes
Loading files ['test.txt']
Loaded 1 docs


Parsing nodes:   0%|          | 0/1 [00:00<?, ?it/s]

Parsed 107 nodes


In [6]:
train_nodes[:3]

[TextNode(id_='065c7c68-64f1-41e9-9b5f-6d8141aae864', embedding=None, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='008c3477-fbe1-4da1-86a9-91d83316333d', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'file_path': 'train.txt', 'file_name': 'train.txt', 'file_type': 'text/plain', 'file_size': 66966, 'creation_date': '2024-01-09', 'last_modified_date': '2024-01-09', 'last_accessed_date': '2024-01-09'}, hash='77b3142f61c86cad975ca9bc682650512f3a0498a97fb38e6a5b3721324a80c7'), <NodeRelationship.NEX

In [1]:
from llama_index.finetuning import (
    generate_qa_embedding_pairs,
    EmbeddingQAFinetuneDataset,
)
from llama_index.llms import OpenAI
import os
os.environ["OPENAI_API_KEY"] = "sk-xxx"
llm = OpenAI(model="gpt-3.5-turbo")

In [24]:
qa_generate_prompt_tmpl = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and not prior knowledge.
generate only questions based on the below query.

You are a Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination in Chinese. The questions should be diverse in nature \
across the document in Chinese. The questions should not contain options, not start with Q1/ Q2. \
Restrict the questions to the context information provided.
"""

train_dataset = generate_qa_embedding_pairs(nodes=train_nodes, llm=llm, num_questions_per_chunk=1, qa_generate_prompt_tmpl=qa_generate_prompt_tmpl)
val_dataset = generate_qa_embedding_pairs(nodes=val_nodes, llm=llm, num_questions_per_chunk=1, qa_generate_prompt_tmpl=qa_generate_prompt_tmpl)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 129/129 [08:03<00:00,  3.75s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 107/107 [06:49<00:00,  3.83s/it]


In [28]:
from llama_index.finetuning import SentenceTransformersFinetuneEngine

finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="/data-xgb1/lmj/models/bge-base-zh-v1.5",
    model_output_path="/data-xgb1/lmj/models/bge-base-ft-001",
    val_dataset=val_dataset,
)

In [29]:
finetune_engine.finetune()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/67 [00:00<?, ?it/s]

Iteration:   0%|          | 0/67 [00:00<?, ?it/s]

In [27]:
finetune_engine.loss

MultipleNegativesRankingLoss(
  (model): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
    (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
    (2): Normalize()
  )
  (cross_entropy_loss): CrossEntropyLoss()
)