# Abalone RAG

This notebook describes the development of the RAG application that is featured on [HuggingFace Spaces](https://huggingface.co/spaces/LoneWolfgang/Abalone-RAG-Demo).

This notebook is structured into four parts.
  1. `Build a Corpus`: Explains how text and PDF files may be converted into a structured corpus for indexing and retrieval.
  2. `Preprocess Paragraphs for RAG`: Goes deeper into preprocessing, explaining how text is cleaned and segmented so that well-formed context can be selected.
  3. `Indexing and Retrieval`: Demonstrates how to create a FAISS index using SBERT. Also includes reranking using Cross Encoders.
  4. `RAG with TinyLlama`: Ties everything together into a RAG system and fine-tunes TinyLlama to improve generation quality.


Start by installing dependencies. It's recommended that you run this within a virtual environment.

In [1]:
# pip install -r "requirements.txt"

## Build a Corpus

A corpus is a large, diverse collection of texts used for analysis and modelling. Modern corpora include billions or trillions of tokens. For demonstration, our toy corpus will only include 10 texts. 

We start by extracting text from its raw form, .txt or .pdf, and placing the contents in a pandas dataframe.

The `Loader` will only handle .pdf and .txt documents:  
  - For **pdfs**, it uses `PyMuPDF` (fitz). Text blocks are extracted and treated as paragraphs.
  - For **txt**, we make the heuristic assumption that complete paragraphs are separated by newlines.
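
As a rough illustration of the `.txt` heuristic, the sketch below treats every non-empty line as one paragraph and tabulates the result with the same `document_id` and `paragraph` columns. The helper names `split_paragraphs` and `tabulate` and the file path are hypothetical, not taken from `modules.py`, and whether the real `Loader` drops blank lines at this stage is an assumption.

```python
from pathlib import Path


def split_paragraphs(text: str) -> list[str]:
    """Treat each non-empty line as one paragraph (the .txt heuristic)."""
    return [line.strip() for line in text.splitlines() if line.strip()]


def tabulate(document_id: str, text: str) -> list[dict]:
    """Build one row per paragraph with document_id / paragraph keys."""
    return [
        {"document_id": document_id, "paragraph": paragraph}
        for paragraph in split_paragraphs(text)
    ]


# Hypothetical file name; only the stem is used as the document id.
raw = "Abalone are marine snails.\n\nThey cling to rocks.\n"
rows = tabulate(Path("documents/txt/example.txt").stem, raw)
print(rows)
```

A list of dictionaries like `rows` can be passed directly to `pandas.DataFrame` to obtain a table of the same shape.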

Let's start by tabulating our text data. At the end, we save the data as .parquet for easy access later.

In [2]:
from modules import (
    get_files,
    Loader
)
from pathlib import Path

# Select the directory where documents are saved.
DOCS = Path("documents")

# get_files retrieves all files from the input directory, optionally filtering for a specific extension.
files = get_files(DOCS, extension="pdf") + get_files(DOCS, extension="txt")

# For each file, load it into a dataframe and save it to a parquet file.
for file in files:
    loader = Loader(file)
    loader.load()
    loader.save_as_parquet(DOCS / "parquet" / f"{file.stem}.parquet")

Building dataframe from documents/pdf/Jenkins2000.pdf


Extracting text blocks:   0%|          | 0/160 [00:00<?, ?it/s]

Saved extracted document to documents/parquet/Jenkins2000.parquet
Building dataframe from documents/pdf/leighton2008.pdf


Extracting text blocks:   0%|          | 0/95 [00:00<?, ?it/s]

Saved extracted document to documents/parquet/leighton2008.parquet
Saved extracted document to documents/parquet/fisheriesnoaa.parquet
Saved extracted document to documents/parquet/marinebio.parquet
Saved extracted document to documents/parquet/sushiuniversity.parquet
Saved extracted document to documents/parquet/tokyofoundation.parquet
Saved extracted document to documents/parquet/visitcalifornia.parquet
Saved extracted document to documents/parquet/wikipedia.parquet
Saved extracted document to documents/parquet/xerces.parquet


Now that the raw text has been converted to dataframes and saved to .parquet, it can be accessed much faster.

This time, let's load all of our parquet files and combine them using the `Corpus` object.

In [3]:
from modules import Corpus

# This time, build a list of parquet files.
files = get_files(DOCS, "parquet")

# Place the files into a generator
loaders = (Loader(file) for file in files)

# Feed them into the Corpus object. It will combine the paragraphs into a single dataframe
corpus = Corpus(*loaders)

corpus.paragraphs

| | document_id | paragraph |
|---|---|---|
| 0 | Jenkins2000 | Volume 31, Number 1, January 2000 |
| 1 | Jenkins2000 | ¥ SpecJat issue: |
| 2 | Jenkins2000 | '->?,tli Genetic Issues in Aquaculture Guest e... |
| 3 | Jenkins2000 | Blackwell Science |
| 4 | Jenkins2000 | Aquaculture Research |
| ... | ... | ... |
| 2702 | xerces | 2. Protection of populations and their habitat... |
| 2703 | xerces | |
| 2704 | xerces | 3. Captive propagation for enhancement of wild... |
| 2705 | xerces | |


Looking at the paragraphs extracted from the corpus, several issues are immediately apparent: many paragraphs are underlength, and some content is missing entirely. These problems are addressed by the next component: the `Preprocessor`.

## Preprocess paragraphs for RAG

The `Preprocessor` object performs the following functions:

  1. **Filter underlength paragraphs:**  
  Underlength paragraphs are heuristically removed, eliminating most unwanted material such as section headings, captions, references, and miscellaneous floating text.  

  2. **Split paragraphs into segments:**  
  Remaining paragraphs are split into segments with clean sentence boundaries and token lengths optimized for the embedding model’s context window.

  3. **Split segments into sentences:**  
  Each segment is further divided into individual sentences, which are used downstream for sentence-level highlighting.

  4. **Compute and index embeddings:**  
  Segment embeddings are computed and indexed using FAISS for efficient retrieval.
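
The first step above can be sketched in a few lines. This is a minimal illustration, not the `Preprocessor` code: the function name `filter_underlength` is hypothetical, and measuring length in whitespace-separated words is an assumption (the real implementation may count tokens or characters instead).

```python
def filter_underlength(paragraphs: list[str], min_paragraph_length: int = 32) -> list[str]:
    """Keep paragraphs with at least `min_paragraph_length` whitespace-separated words."""
    return [p for p in paragraphs if len(p.split()) >= min_paragraph_length]


paragraphs = [
    "Blackwell Science",                  # floating header text
    "Volume 31, Number 1, January 2000",  # issue line
    " ".join(["abalone"] * 40),           # stand-in for a real body paragraph
]
kept = filter_underlength(paragraphs)
print(len(kept))
```

Only the 40-word stand-in survives; the two short fragments fall below the threshold and are dropped.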

The `Preprocessor` consists of four subobjects:
  - `Corpus`: Composed in the previous step. Provides a dataframe with the columns `document_id` and `paragraph`.
  - `SentenceEmbedder`: Uses a Sentence Transformer (SBERT) to compute embeddings.
  - `SentenceExtractor`: Uses a spaCy model to mark clean sentence boundaries.
  - `Segmenter`: Uses the SBERT tokenizer to mark segment boundaries optimized for the context window of the `SentenceEmbedder`.

The next few cells demonstrate the functionality of each object. For a deeper technical understanding, see the code in `modules.py`.

In [4]:
# Start by running this code to initialize our models and a paragraph for demonstration.

from sentence_transformers import SentenceTransformer
import spacy

SBERT = SentenceTransformer("all-MiniLM-L12-v2")
TOKENIZER = SBERT.tokenizer
SPACY = spacy.load("en_core_web_sm")


paragraph = r"""
Tasmania supplies about 25% of the yearly world abalone harvest.[41] Around 12,500 Tasmanians recreationally fish for blacklip and greenlip abalone. For blacklip abalone, the size limit varies between 138 mm (5.4 in) for the southern end of the state and 127 mm (5.0 in) for the northern end of the state.[42] Greenlip abalone have a minimum size of 145 mm (5.7 in), except for an area around Perkins Bay in the north of the state where the minimum size is 132 millimetres (5.2 in). With a recreational abalone licence, the bag limit is 10 per day, with a total possession limit of 20. Scuba diving for abalone is allowed, and has a rich history in Australia. (Scuba diving for abalone in the states of New South Wales and Western Australia is illegal; a free-diving catch limit of two is allowed).[43][44]Victoria has had an active abalone fishery since the late 1950s. The state is sectioned into three fishing zones, Eastern, Central and Western, with each fisher required a zone-allocated licence. Harvesting is performed by divers using surface-supplied air "hookah" systems operating from runabout-style, outboard-powered boats. While the diver seeks out colonies of abalone amongst the reef beds, the deckhand operates the boat, known as working "live" and stays above where the diver is working. Bags of abalone pried from the rocks are brought to the surface by the diver or by way of "shot line", where the deckhand drops a weighted rope for the catch bag to be connected then retrieved. Divers measure each abalone before removing from the reef and the deckhand remeasures each abalone and removes excess weed growth from the shell. Since 2002, the Victorian industry has seen a significant decline in catches, with the total allowable catch reduced from 1440 to 787 tonnes for the 2011/12 fishing year, due to dwindling stocks and most notably the abalone virus ganglioneuritis, which is fast-spreading and lethal to abalone stocks. 
Sport harvesting of red abalone is permitted with a California fishing license and an abalone stamp card. In 2008, the abalone card also came with a set of 24 tags. This was reduced to 18 abalone per year in 2014, and as of 2017 the limit has been reduced to 12, only nine of which may be taken south of Mendocino County. Legal-size abalone must be tagged immediately.[45] Abalone may only be taken using breath-hold techniques or shorepicking; scuba diving for abalone is strictly prohibited.[46] Taking of abalone is not permitted south of the mouth of San Francisco Bay.[47] A size minimum of 7 in (180 mm) measured across the shell is in place. A person may be in possession of only three abalone at any given time.[48][49]
"""

The `SentenceExtractor` will pick out the string position of the first and last character of each sentence in the paragraph. You can use these to extract clean sentences from the paragraph:

In [5]:
from modules import SentenceExtractor
import pandas as pd

sentence_extractor = SentenceExtractor(SPACY)

sentence_boundaries = sentence_extractor.sentence_boundaries(paragraph)
sentences = [paragraph[b[1][0]: b[1][1]] for b in sentence_boundaries]

df = pd.DataFrame({
    "sentence_boundaries": sentence_boundaries,
    "sentence": sentences
})

print()
print(df.head())
print()
print("Sample Sentence:")
print()
print(df.loc[5, "sentence"])


  sentence_boundaries                                           sentence
0        (0, (0, 69))  \nTasmania supplies about 25% of the yearly wo...
1      (1, (70, 149))  Around 12,500 Tasmanians recreationally fish f...
2     (2, (150, 310))  For blacklip abalone, the size limit varies be...
3     (3, (311, 483))  Greenlip abalone have a minimum size of 145 mm...
4     (4, (484, 586))  With a recreational abalone licence, the bag l...

Sample Sentence:

Scuba diving for abalone is allowed, and has a rich history in Australia.
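
To make the `(index, (start, end))` structure in the output above concrete, here is a small stand-in that produces boundaries of the same shape with a regular expression instead of spaCy. The regex rule (a sentence ends at `.`, `!` or `?` followed by whitespace or the end of the text) is an assumption for illustration only; it is much cruder than spaCy's sentence segmentation and would, for example, miss a boundary before a citation marker such as `[41]`.

```python
import re


def sentence_boundaries(text: str) -> list[tuple[int, tuple[int, int]]]:
    """Return (index, (start, end)) character spans, one per sentence.

    Regex stand-in for spaCy: a sentence runs from the first non-space
    character to the next ., ! or ? that is followed by whitespace or
    the end of the text. Trailing text without end punctuation is dropped.
    """
    pattern = re.compile(r"\S.*?[.!?](?=\s|$)", flags=re.DOTALL)
    return [(i, match.span()) for i, match in enumerate(pattern.finditer(text))]


text = "Abalone are snails. They cling to rocks! Do they swim?"
spans = sentence_boundaries(text)
sentences = [text[start:end] for _, (start, end) in spans]
print(spans)
print(sentences)
```

Because only character offsets are stored, the original paragraph is never modified, and the same spans can later be compared against segment spans.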


The `Segmenter` will pick out the string positions to extract segments that will fill the context window of the embedding model.

In [6]:
from modules import Segmenter
import pandas as pd

segmenter = Segmenter(TOKENIZER)

segment_boundaries = segmenter.segment_boundaries(paragraph)
segments = [paragraph[b[0]: b[1]] for b in segment_boundaries]

df = pd.DataFrame({
    "segment_boundaries": segment_boundaries,
    "segment": segments
})

print()
print(df.head())
print()
print("Sample Segment:")
print()
print(df.loc[2, "segment"])


  segment_boundaries                                            segment
0           (1, 488)  Tasmania supplies about 25% of the yearly worl...
1         (361, 899)  7 in), except for an area around Perkins Bay i...
2        (774, 1341)  limit of two is allowed).[43][44]Victoria has ...
3       (1203, 1789)  , the deckhand operates the boat, known as wor...
4       (1634, 2201)  the shell. Since 2002, the Victorian industry ...

Sample Segment:

limit of two is allowed).[43][44]Victoria has had an active abalone fishery since the late 1950s. The state is sectioned into three fishing zones, Eastern, Central and Western, with each fisher required a zone-allocated licence. Harvesting is performed by divers using surface-supplied air "hookah" systems operating from runabout-style, outboard-powered boats. While the diver seeks out colonies of abalone amongst the reef beds, the deckhand operates the boat, known as working "live" and stays above where the diver is working. Bags of abalone pr
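
The overlapping spans in the output above follow a sliding-window pattern, which can be sketched as below. This is a simplified illustration, not the `Segmenter` implementation: it uses whitespace-delimited words as a stand-in for the SBERT tokenizer, and the `max_tokens=128` and `overlap=32` defaults are assumed values chosen for the example.

```python
import re


def segment_boundaries(text: str, max_tokens: int = 128, overlap: int = 32) -> list[tuple[int, int]]:
    """Return (start, end) character spans of overlapping token windows."""
    if not 0 <= overlap < max_tokens:
        raise ValueError("overlap must be in [0, max_tokens)")
    # Character span of every whitespace-delimited token (a tokenizer stand-in).
    tokens = [match.span() for match in re.finditer(r"\S+", text)]
    spans = []
    step = max_tokens - overlap  # how far each new window advances
    for first in range(0, len(tokens), step):
        window = tokens[first:first + max_tokens]
        spans.append((window[0][0], window[-1][1]))
        if first + max_tokens >= len(tokens):
            break  # this window already reaches the end of the text
    return spans


text = " ".join(f"w{i}" for i in range(10))
print(segment_boundaries(text, max_tokens=4, overlap=2))
```

Each window shares its last `overlap` tokens with the start of the next one, which is why neighbouring spans overlap. Windows cut this way can start or end mid-sentence, as the sample segment shows.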

Finally, the `SentenceEmbedder` will take texts and return embeddings.

In [7]:
from modules import SentenceEmbedder

embedder = SentenceEmbedder(SBERT)
embeddings = embedder.embed(df.segment.tolist())

df["embeddings"] = list(embeddings)

print(df.head())

  segment_boundaries                                            segment  \
0           (1, 488)  Tasmania supplies about 25% of the yearly worl...   
1         (361, 899)  7 in), except for an area around Perkins Bay i...   
2        (774, 1341)  limit of two is allowed).[43][44]Victoria has ...   
3       (1203, 1789)  , the deckhand operates the boat, known as wor...   
4       (1634, 2201)  the shell. Since 2002, the Victorian industry ...   

                                          embeddings  
0  [0.07767071, -0.0057242527, -0.05175696, 0.008...  
1  [0.073290914, -0.0001915124, -0.012798946, 0.0...  
2  [0.05552063, -0.01772244, 0.00100956, 0.013807...  
3  [0.05004228, 0.03330825, -0.016384125, 0.03617...  
4  [0.031442273, -0.0028620476, -0.041476786, -0....  


The `Preprocessor` object ties all of these functions together. Using the sentence and segment boundaries, it infers the segment boundaries with clean sentence edges and optimal token length. It performs other cleaning functions as well, such as filtering underlength paragraphs and deduplicating segments with common anchors.
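
One way to picture the alignment step is to shrink each raw segment to the complete sentences that fit inside it. The sketch below is an assumption about how this could work, not the code in `modules.py`; `align_to_sentences` is a hypothetical helper, and dropping exact duplicate spans is only one guess at what deduplication might involve.

```python
def align_to_sentences(segment, sentences):
    """Shrink a (start, end) segment to the sentences lying fully inside it."""
    start, end = segment
    inside = [(s, e) for s, e in sentences if s >= start and e <= end]
    if not inside:
        return None  # no complete sentence fits in this window
    return inside[0][0], inside[-1][1]


# Character spans of three sentences and three raw, overlapping segments.
sentences = [(0, 19), (20, 40), (41, 54)]
segments = [(0, 45), (0, 50), (30, 54)]

aligned = [align_to_sentences(segment, sentences) for segment in segments]
# Drop empty results and exact duplicates while preserving order.
unique = list(dict.fromkeys(span for span in aligned if span is not None))
print(unique)
```

The first two raw segments both cut into the third sentence, so they collapse to the same aligned span and only one copy is kept.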

This time, instead of the paragraph, we will preprocess the full corpus that we prepared earlier.

In [8]:
from modules import Preprocessor

preprocessor = Preprocessor(
    corpus,
    embedder,
    sentence_extractor,
    segmenter
)

kwargs = {
    # Controls the minimum viable length of a paragraph. Defaults to 32.
    "min_paragraph_length": 32,

    # Controls the number of overlapping tokens between segments. Defaults to 32.
    "stride": 32
}

preprocessor.preprocess(**kwargs)

Sentence boundaries:   0%|          | 0/949 [00:00<?, ?it/s]

Extracting sentences:   0%|          | 0/949 [00:00<?, ?it/s]

Segment boundaries:   0%|          | 0/949 [00:00<?, ?it/s]

Aligning segments:   0%|          | 0/949 [00:00<?, ?it/s]

Embedding segments:   0%|          | 0/1525 [00:00<?, ?it/s]

| | document_id | sentences | segment | segment_embedding | segment_id |
|---|---|---|---|---|---|
| 0 | Jenkins2000 | [Editorial Board G. Allan Taylors Beach, Austr... | Editorial Board G. Allan Taylors Beach, Austra... | [0.025996055, 0.018459106, 0.01713328, 0.02857... | Jenkins2000-000 |
| 1 | Jenkins2000 | [S.J. Kaushik Saint-Pie-sur-Nivelle, France G.... | S.J. Kaushik Saint-Pie-sur-Nivelle, France G.W... | [0.05069863, 0.09893627, 0.050703455, -0.05172... | Jenkins2000-001 |
| 2 | Jenkins2000 | [Aims and scope Aquaculture Research is an int... | Aims and scope Aquaculture Research is an inte... | [-0.010691009, 0.063925244, -0.05688352, -0.04... | Jenkins2000-002 |
| 3 | Jenkins2000 | [In addition to publishing papers reporting re... | In addition to publishing papers reporting res... | [0.057243183, 0.07570791, -0.0029028028, 0.031... | Jenkins2000-003 |
| 4 | Jenkins2000 | [Despatch Aquaculture Research is despatched w... | Despatch Aquaculture Research is despatched wi... | [0.0126762865, -0.020529907, -0.003612882, -0.... | Jenkins2000-004 |
| ... | ... | ... | ... | ... | ... |
| 1520 | xerces | [By the 1990s the species was nearly extirpate... | By the 1990s the species was nearly extirpated... | [0.1002283, 0.034409385, -0.08829971, 0.012946... | xerces-007 |
| 1521 | xerces | [Approximately 1,600 individuals remain, and i... | Approximately 1,600 individuals remain, and it... | [0.04162371, 0.005460795, 0.009169072, 0.02119... | xerces-008 |
| 1522 | xerces | [The white abalone was listed as an endangered... | The white abalone was listed as an endangered ... | [0.029424477, 0.013186502, -0.041589364, 0.067... | xerces-009 |
| 1523 | xerces | [The main threats to the continued existence o... | The main threats to the continued existence of... | [0.0076671536, 0.07647452, -0.07140951, 0.0767... | xerces-010 |


## Indexing and Retrieval

Next, we build a FAISS index using the `IngestPipeline`. This object takes the preprocessed documents and produces two outputs:

- **FAISS index**: An ordered collection of precomputed segment embeddings, optimized for fast similarity search.
- **DocStore**: A companion data store containing all relevant metadata for each segment—most importantly, the segment text itself. Given the indices of the top-k retrieved vectors from FAISS, the DocStore is used to look up and return the corresponding text and attributes.

Optionally, the `IngestPipeline` can also accept a metadata dictionary keyed by `document_id`. This metadata is joined onto the segments during ingestion. In this example, we use it to attach source URLs, but it is also common to include fields such as document title, author, publication date, or a trust score to support richer and more controllable RAG applications.
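
The relationship between the vector index, the DocStore, and the joined metadata can be illustrated without FAISS. The sketch below is a stand-in under stated assumptions: it uses a brute-force cosine search over plain Python lists in place of a FAISS index, the helper names `build_docstore` and `search` are hypothetical, and the segment texts and `example.org` URLs are placeholders rather than real corpus data.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def build_docstore(segments, metadata):
    """Entry i describes vector i; per-document metadata is joined on document_id."""
    return [{**seg, **metadata.get(seg["document_id"], {})} for seg in segments]


def search(query_vector, vectors, docstore, k=1):
    """Brute-force top-k search (stand-in for FAISS), then DocStore lookup by position."""
    ranked = sorted(range(len(vectors)), key=lambda i: cosine(query_vector, vectors[i]), reverse=True)
    return [docstore[i] for i in ranked[:k]]


# Placeholder data: two toy 2-d "embeddings" and their segments.
vectors = [[1.0, 0.0], [0.0, 1.0]]
segments = [
    {"segment_id": "xerces-000", "document_id": "xerces", "text": "placeholder text A"},
    {"segment_id": "wikipedia-000", "document_id": "wikipedia", "text": "placeholder text B"},
]
metadata = {"xerces": {"url": "https://example.org/xerces"}}

docstore = build_docstore(segments, metadata)
hit = search([0.9, 0.1], vectors, docstore, k=1)[0]
print(hit["segment_id"], hit.get("url"))
```

The key design point is that the index stores only vectors and returns positions, so everything a RAG application needs to display (text, URL, other attributes) has to live in the companion store under the same ordering.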


In [9]:
from modules import IngestPipeline, load_json

# Load the metadata containing URLs for the documents
METADATA = load_json("documents/metadata.json")

INDEX = "index"

# Initiate the pipeline
pipeline = IngestPipeline(
    preprocessor,
    metadata=METADATA
)

# Run activates the preprocessor, and builds the FAISS index and DocStore Objects
pipeline.run()

# This will save your index to a specified directory.
pipeline.save_index(INDEX)

# Optionally, push your index to the HuggingFace Hub.
# This is recommended if you want to use your index within HuggingFace Spaces.
# You will need to set up a login mechanism. Please see the HF docs.

pipeline.save_index(
    INDEX,
    push_to_hub=True,
    repo_id = "LoneWolfgang/Abalone-mini-index"
)


Saved index to index/index.faiss
Saved docstore with 1525 entries to index/docstore.pkl
Number of vectors in index: 1525
Saved index to index/index.faiss
Saved docstore with 1525 entries to index/docstore.pkl
Number of vectors in index: 1525
✓ Repo exists: LoneWolfgang/Abalone-mini-index


Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

No files have been modified since last commit. Skipping to prevent empty commit.



Successfully pushed index and docstore to Hugging Face Hub


Now that we have built our index, let’s test the `Retriever`.

The Retriever combines three components:
  - **Index**: Built in the previous step.
  - **SBERT (bi-encoder)**: Used to embed queries and retrieve the top-k candidate segments.
  - **Cross Encoder**: A more powerful retrieval model, used to select the best matching segment and the best sentence for highlighting.

Cross encoders are generally regarded as more accurate than bi-encoders (e.g., SBERT) for retrieval. Unlike bi-encoders, which embed queries and documents independently, cross encoders compute a relevance score from a joint encoding of the query and the candidate text. Because a cross encoder needs the query and the candidate together, nothing can be precomputed: every query-candidate pair must pass through the model at query time. As a result, while they offer higher precision and recall, they do not scale as well as bi-encoders.

By using bi-encoders for initial retrieval and cross encoders to re-rank the top-k candidates, we combine the scalability of bi-encoders with the accuracy of cross encoders.
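
The two-stage pattern can be sketched with toy scoring functions. Everything here is a stand-in chosen so the example runs without any models: `embed` and `bi_score` imitate a bi-encoder with word sets and Jaccard overlap, `cross_score` imitates a cross encoder by counting shared word bigrams, and the three documents are made up. Only the control flow (cheap retrieval over precomputed vectors, then expensive pairwise reranking of the top-k) reflects the approach described above.

```python
def embed(text):
    """Bi-encoder stand-in: encodes one text on its own, so documents can be precomputed."""
    return frozenset(text.lower().split())


def bi_score(query_vec, doc_vec):
    """Cheap similarity between two independent encodings (Jaccard overlap)."""
    return len(query_vec & doc_vec) / len(query_vec | doc_vec)


def cross_score(query, text):
    """Cross-encoder stand-in: looks at query and text together (shared word bigrams)."""
    def bigrams(s):
        words = s.lower().split()
        return set(zip(words, words[1:]))
    return len(bigrams(query) & bigrams(text))


def retrieve_and_rerank(query, docs, doc_vecs, k=2):
    query_vec = embed(query)
    # Stage 1: rank every document with the cheap score, keep the top-k.
    candidates = sorted(range(len(docs)), key=lambda i: bi_score(query_vec, doc_vecs[i]), reverse=True)[:k]
    # Stage 2: run the expensive pairwise scorer on the k candidates only.
    best = max(candidates, key=lambda i: cross_score(query, docs[i]))
    return candidates, best


docs = [
    "do abalone live deep",
    "how deep do abalone live in the cold pacific ocean",
    "sushi chefs boil abalone in sake",
]
doc_vecs = [embed(d) for d in docs]  # precomputed once, before any query arrives

candidates, best = retrieve_and_rerank("how deep do abalone live", docs, doc_vecs)
print(candidates, best)
```

In this toy run the first stage ranks document 0 highest, but the reranker prefers document 1, which shows how the second stage can reorder the candidates the first stage hands over.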

In [10]:
from modules import Retriever
from sentence_transformers import CrossEncoder

CROSS_ENCODER = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

retriever = Retriever(
    index_dir = INDEX,
    sbert = SBERT,
    cross_encoder = CROSS_ENCODER
)

In [11]:
query = "What age do abalone express their gender?"

result = retriever.retrieve(query, metadata_fields=["url"])

print("Query:")
print(query)
print()
print("Best Segment:")
print(result["text"])
print()
print("Highlighted Sentence:")
print(result["highlight"])
print()
print("Source:")
print(result["url"])

Query:
What age do abalone express their gender?

Best Segment:
**White abalone have a life expectancy of about thirty-five to forty-years, and reach sexual maturity around 4 to 6 years.** They are dioecious, meaning that they have separate sexes (NMFS 2006). They have high fecundity and reproduce through ‘broadcast spawning’; millions of eggs or sperm are released during spawning events (Leighton 2000). In order to successfully reproduce, the male and female abalone must be close enough together – within a few meters (Davis et al. 1996). Commercial fishing pressures reduced its population densities so low that remaining individuals are too scattered to successfully reproduce.

Highlighted Sentence:
White abalone have a life expectancy of about thirty-five to forty-years, and reach sexual maturity around 4 to 6 years.

Source:
https://www.xerces.org/endangered-species/species-profiles/at-risk-aquatic-invertebrates/white-abalone


## RAG with TinyLlama

Now that retrieval is set up, we can put together a RAG application. For this demonstration, we will use [TinyLlama](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0), a 1.1B-parameter model. It's not particularly powerful, nor does it run well on a CPU. However, it is decently expressive and usable, which makes it a nice tradeoff for a tech demo.

Let's put together an app using the `RAG` object.

In [12]:
from transformers import pipeline
from modules import RAG

generator = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_new_tokens=150,
    temperature=0.1
    )

rag = RAG(retriever, generator)

Device set to use mps:0


In [13]:
query = "How deep in the water do abalone live?"

result = rag.answer_query(query)

print("Query:")
print(query)
print()
print("Response:")
print(result["response"])
print()
print("Context:")
print(result["context"])
print()
print("Source:")
print(result["url"])

Query:
How deep in the water do abalone live?

Response:
Answer: Abalone live at depths of 50 to 180 feet, making them the deepest living abalone species.

Context:
White abalone live on low-relief rocky substrates, typically alongside sand channels, which tend to accumulate the algae they eat. **They are usually found at depths of 50 to 180 feet, making them the deepest living abalone species.** Historically, white abalone were found in the Pacific Ocean from Point Conception, California, to Punta Abreojos, Baja California, in Mexico. In California, they were most abundant at offshore islands (especially San Clemente and Santa Catalina Islands) and submerged banks (primarily Tanner and Cortes Banks).

Source:
https://www.fisheries.noaa.gov/species/white-abalone


In general, the responses vary in accuracy. The biggest problem I encountered was that the model tended to be overly verbose. Let's see if we can improve the quality of responses through fine-tuning.

The fine-tuning dataset was generated using GPT-5:
  - First, GPT-5 was instructed to generate 100 queries about abalone covering a range of topics.
  - Then, provided with the queries and retrieval results, it was instructed to generate a concise and accurate response for each.

Let's look at the dataset:

In [14]:
from datasets import load_dataset

train = load_dataset("LoneWolfgang/finetune-chat-for-abalone-RAG")['train'].to_pandas()

sample = train.sample(2)

for _, row in sample.iterrows():
    print("Query:")
    print(row["query"])
    print()
    print("Context:")
    print(row["context"])
    print()
    print("Response:")
    print(row["response"])
    print()
    print()

Query:
What diseases are common in abalone aquaculture?

Context:
An infestation of sabellids is generally not a life-threatening situation for the abalone.  **It is, however, a very serious problem for an aquaculture producer. ** If the abalone become heavily infested, their growth rate plummets.  In an aquaculture environment, reduced growth rates mean added costs and reduced profits. In addition, the appearance of the affected shells limits the product value in the premium markets for live abalone.  Due to the shape of their shells, heavily infested abalone are often unable to right

Response:
Sabellid infestations are the main disease concern, reducing growth and shell quality as noted in the context.


Query:
How do abalones protect themselves from desiccation at low tide?

Context:
themselves when dislodged from their substrate. The mass handlings associated with aquaculture production inevitably lead to higher mortalities when working with infested abalone, as it is impossible t

In hindsight, I should have provided GPT-5 with clearer instructions for response formulation, but the results are sufficient.

**Now, let’s fine-tune TinyLlama. This training script assumes access to GPU acceleration. If you do not have a GPU, the script will fail. When using a CPU, you should adapt this code to a parameter-efficient fine-tuning approach, such as LoRA.**

First, we prepare the training dataset using `RagDatasetBuilder`. This object is hardcoded for this specific task. It takes a pandas DataFrame with the columns `context`, `query`, and `response`, and converts them into a `datasets.Dataset` suitable for RAG fine-tuning.

When fine-tuning, it is important to mask the query and context and train only on the response. Failing to do so slows convergence and can encourage the model to memorize the question and context rather than learn to generate good answers. `RagDatasetBuilder` handles this masking automatically, and the `sanity_check` method allows you to inspect exactly what is happening under the hood.

In [15]:
from modules import RagDatasetBuilder
from transformers import AutoTokenizer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

builder = RagDatasetBuilder(tokenizer)
dataset = builder.build_dataset(train)

builder.sanity_check(dataset)

Map:   0%|          | 0/101 [00:00<?, ? examples/s]

FULL MODEL INPUT (conditioning + target)
<|system|>
You answer questions strictly using the provided context.

<|user|>
Context: kamtschatkana.  **Fifty individually caged abalone were held in a common aquarium tank with a constant flow of fresh ambient seawater and fed ad libitum on kelp (Nereocystis leutkeuna). ** The abalone were divided into five groups of ten animals each. Every group had a similar mean weight (78 g) and length (7 cm).
Question: What is the scientific name for the common abalone?
<|assistant|> The context mentions *kamtschatkana*, indicating the scientific name referenced there.</s>

TOKENS CONTRIBUTING TO LOSS (response only)
The context mentions *kamtschatkana*, indicating the scientific name referenced there.</s>

MASKING STATS
Total tokens     : 153
Masked tokens    : 134
Unmasked tokens  : 19
Sanity check passed: response-only masking is correct.
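
The masking statistics above can be reproduced with a toy example. This is a sketch of the general response-only masking technique, not the `RagDatasetBuilder` code: `build_labels` is a hypothetical helper and the token ids are dummy integers. The value `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, so positions labelled with it contribute nothing to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(prompt_ids, response_ids):
    """Concatenate prompt and response; hide the prompt from the loss."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


# Dummy ids sized to match the sanity check: 134 prompt tokens, 19 response tokens.
prompt_ids = list(range(134))            # system + context + question
response_ids = list(range(1000, 1019))   # assistant answer, including end-of-sequence

input_ids, labels = build_labels(prompt_ids, response_ids)
masked = labels.count(IGNORE_INDEX)
print(len(input_ids), masked, len(labels) - masked)
```

The model still reads all 153 tokens as input; only the 19 response positions carry a training signal.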


Now that the dataset is ready to go, let's put together the rest of the training script.

In [16]:
from transformers import AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
import torch

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

training_args = TrainingArguments(
    output_dir="./tinyllama-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=50,
    learning_rate=2e-4,
    bf16=True,
    fp16=False,
    logging_steps=5,
    report_to="none",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
)


trainer.train()

trainer.model.save_pretrained("./tinyllama-finetuned")
tokenizer.save_pretrained("./tinyllama-finetuned")

`torch_dtype` is deprecated! Use `dtype` instead!


Truncating train dataset:   0%|          | 0/101 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


| Step | Training Loss |
|---|---|
| 5 | 2.5625 |
| 10 | 1.4966 |
| 15 | 1.0446 |
| 20 | 0.4472 |


('./tinyllama-finetuned/tokenizer_config.json',
 './tinyllama-finetuned/special_tokens_map.json',
 './tinyllama-finetuned/chat_template.jinja',
 './tinyllama-finetuned/tokenizer.json')
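
It can help to work out how many optimizer steps these settings produce. The arithmetic below is a back-of-the-envelope sketch under stated assumptions: a single training device, 101 training examples (the count shown in the progress output), a trailing partial accumulation group counting as one step, and a linear warmup schedule. The helper names are hypothetical.

```python
import math


def optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps, epochs, num_devices=1):
    """Total optimizer updates, assuming a partial final accumulation group still counts."""
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    steps_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return steps_per_epoch * epochs


def linear_warmup_lr(step, peak_lr, warmup_steps):
    """Learning rate during a linear warmup, capped at the peak."""
    return peak_lr * min(1.0, step / warmup_steps)


effective_batch = 2 * 8                   # per_device_train_batch_size * gradient_accumulation_steps
total = optimizer_steps(101, 2, 8, 3)     # 51 batches -> 7 updates per epoch -> 21 updates
final_lr = linear_warmup_lr(total, 2e-4, 50)
print(effective_batch, total, final_lr)
```

Under these assumptions the run performs 21 updates with an effective batch size of 16, which is consistent with the loss being logged at steps 5 through 20. It also suggests that `warmup_steps=50` exceeds the total number of updates, so the learning rate would still be ramping up when training ends and would reach roughly 40% of the configured `2e-4`. Lowering `warmup_steps` or training for more epochs may be worth experimenting with.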

Now, let's see how the quality of responses has changed.

In [17]:
from transformers import pipeline
from modules import RAG

generator = pipeline(
    "text-generation",
    model="./tinyllama-finetuned",
    max_new_tokens=150,
    temperature=0.1
    )

rag = RAG(retriever, generator)

Device set to use mps:0


In [18]:
query = "What sort of sauce goes well with abalone?"

result = rag.answer_query(query)

print("Query:")
print(query)
print()
print("Response:")
print(result["response"])
print()
print("Context:")
print(result["context"])
print()
print("Source:")
print(result["url"])

Query:
What sort of sauce goes well with abalone?

Response:
The context notes soy sauce and other seasonings go well with abalone boiled with soy sauce.

Context:
**There are two types of cooking methods for Ni-awabi (boiled abalones): Saka-ni (sake-boiling), in which abalones are boiled with a lot of sake, and shoyu-ni (soy sauce-boiling), in which abalones are boiled with soy sauce and other seasonings.**

Source:
https://sushiuniversity.jp/visual-dictionary/?Name=Japanese-abalone-(kuro-awabi)


As a final step, push your model to the HuggingFace Hub. This is recommended if you intend to deploy it to HuggingFace Spaces.

In [19]:
trainer.model.push_to_hub("LoneWolfgang/tinyllama-for-abalone-RAG")
trainer.processing_class.push_to_hub("LoneWolfgang/tinyllama-for-abalone-RAG")

Processing Files (0 / 0): |          |  0.00B /  0.00B            

New Data Upload: |          |  0.00B /  0.00B            

No files have been modified since last commit. Skipping to prevent empty commit.
No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/LoneWolfgang/tinyllama-for-abalone-RAG/commit/e161dd87667ea2bf6470f2c1587f97f997119c60', commit_message='Upload tokenizer', commit_description='', oid='e161dd87667ea2bf6470f2c1587f97f997119c60', pr_url=None, repo_url=RepoUrl('https://huggingface.co/LoneWolfgang/tinyllama-for-abalone-RAG', endpoint='https://huggingface.co', repo_type='model', repo_id='LoneWolfgang/tinyllama-for-abalone-RAG'), pr_revision=None, pr_num=None)