## Imports

In [1]:
# %pip install transformers
# %pip install -r requirements.txt

In [2]:
# Standard library imports
import os
import re
import logging
from enum import Enum
from typing import Any
from uuid import uuid4
from datetime import date, datetime, timedelta

# Third-party libraries
import pandas as pd
import numpy as np
import torch
import newspaper as news
from tqdm.notebook  import tqdm
from newspaper.mthreading import fetch_news
from gdeltdoc import GdeltDoc, Filters
from spacy.lang.en import English
from qdrant_client import QdrantClient
from huggingface_hub import HfFolder, whoami
from qdrant_client.models import Distance, VectorParams, PointStruct, QueryResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
from sentence_transformers import SentenceTransformer, CrossEncoder

## Init Logger

In [3]:
# Module logger used by every helper in this notebook; DEBUG also enables the
# tqdm progress bars (see log_tqdm).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Attach the console handler only once: re-running this cell would otherwise add
# another handler and every message would be printed once per execution.
if not logger.handlers:
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

## Log in to Hugging Face to gain access to the models

In [4]:
def hf_login():
    """
    Log into Hugging Face so the (gated) models used below can be downloaded.

    The access token is read from the ``HF_TOKEN`` environment variable and saved with
    ``HfFolder.save_token``. When DEBUG logging is enabled, the account reported by
    ``whoami()`` is logged.

    Raises:
        ValueError: if ``HF_TOKEN`` is not set.
    """
    token = os.getenv("HF_TOKEN")
    if token is None:
        raise ValueError("CANNOT FIND HUGGING FACE ACCESS TOKEN")

    HfFolder.save_token(token)

    # whoami() is an extra network round-trip, so only do it when it will be logged
    if not logger.isEnabledFor(logging.DEBUG):
        return

    account = whoami()
    logger.debug(f"Logged into hugging face as: {account['fullname']} - {account['name']}")

In [5]:
# Requires the HF_TOKEN environment variable to be set before running this cell
hf_login()

2025-08-04 02:04:04,974 - DEBUG - Logged into hugging face as: jacob - hundredcrane120


## Check CUDA

In [6]:
def check_cuda_device():
    """Log the name of the active CUDA device, or note that none was detected."""
    if not torch.cuda.is_available():
        logger.debug("CUDA not detected. Embedding model will remain on CPU")
        return

    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    logger.debug(f"CUDA detected. Device: {device_name}")

check_cuda_device()

2025-08-04 02:04:04,985 - DEBUG - CUDA detected. Device: NVIDIA GeForce RTX 5070 Ti


In [7]:
def check_cuda_memory():
    """
    Log memory statistics (MB, where 1 MB = 1e6 bytes) for the current CUDA device.

    Reports the device's total memory, the memory reserved by PyTorch's caching
    allocator, the memory allocated to tensors, and "free within reserved"
    (reserved - allocated). The last figure is not the device's overall free memory.
    """
    if not torch.cuda.is_available():
        logger.debug("No CUDA device available.")
        return

    device_idx = torch.cuda.current_device()
    total_bytes = torch.cuda.get_device_properties(device_idx).total_memory
    reserved_bytes = torch.cuda.memory_reserved(device_idx)
    allocated_bytes = torch.cuda.memory_allocated(device_idx)
    idle_reserved_bytes = reserved_bytes - allocated_bytes

    logger.debug(f"Total memory:     {total_bytes / 1e6:.2f} MB")
    logger.debug(f"Reserved memory:  {reserved_bytes / 1e6:.2f} MB")
    logger.debug(f"Allocated memory: {allocated_bytes / 1e6:.2f} MB")
    logger.debug(f"Free within reserved: {idle_reserved_bytes / 1e6:.2f} MB")

check_cuda_memory()

2025-08-04 02:04:04,987 - DEBUG - Total memory:     17094.48 MB
2025-08-04 02:04:04,987 - DEBUG - Reserved memory:  0.00 MB
2025-08-04 02:04:04,987 - DEBUG - Allocated memory: 0.00 MB
2025-08-04 02:04:04,987 - DEBUG - Free within reserved: 0.00 MB


## Connect to the local Qdrant server

In [8]:
# Length of embedding from using the model: all-MiniLM-L6-v2
# Vector size produced by the embedding model (all-MiniLM-L6-v2); must match the
# model returned by create_embedding_model or Qdrant will reject the vectors.
EMBEDDING_LENGTH: int = 384
# Qdrant collection that stores the article chunks
COLLECTION_NAME: str = "News_test"

In [9]:
# VECTOR_DB.delete_collection(collection_name=COLLECTION_NAME)

In [10]:
# Client for a Qdrant server already running on localhost:6333
VECTOR_DB = QdrantClient(host="localhost", port=6333)

# Create the collection on first run only; an existing collection (and its points) is reused.
collection_missing = not VECTOR_DB.collection_exists(COLLECTION_NAME)
if collection_missing:
    VECTOR_DB.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(
            size=EMBEDDING_LENGTH,
            distance=Distance.COSINE,
        ),
    )

## Initialize Models

In [11]:
def create_reranking_model(
    model_name: str = "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    device: str = "cpu",
) -> CrossEncoder:
    """
    Create the cross-encoder used to re-rank vector-search results.

    Args:
        model_name: Hugging Face id of the cross-encoder to load.
        device: device the model is loaded on. Defaults to "cpu" (the previous
            hard-coded behavior); pass "cuda" to load it straight onto the GPU
            instead of moving it with ``.to("cuda")`` afterwards.

    Returns:
        The loaded ``CrossEncoder``.
    """
    logger.debug("Creating re-ranker model")
    return CrossEncoder(model_name, device=device)

In [12]:
def create_embedding_model(
    model_name: str = "all-MiniLM-L6-v2",
    device: str = "cpu",
) -> SentenceTransformer:
    """
    Create the sentence-transformer used to embed chunks and queries.

    Args:
        model_name: sentence-transformers model id. It was previously ignored in favor
            of a hard-coded "all-MiniLM-L6-v2". Its output size must equal
            EMBEDDING_LENGTH used to create the Qdrant collection.
        device: device the model is loaded on. Defaults to "cpu" (the previous
            hard-coded behavior).

    Returns:
        The loaded ``SentenceTransformer``.
    """
    logger.debug("Creating embedding model")
    return SentenceTransformer(model_name_or_path=model_name, device=device)

In [13]:
def create_LLM(
    model_name: str = "mistralai/Mistral-7B-Instruct-v0.3"
) -> tuple[AutoTokenizer, AutoModelForCausalLM]:
    """
    Load the generation LLM, 4-bit NF4 quantized with bitsandbytes, plus its tokenizer.

    Weights are placed with ``device_map="auto"``. flash_attention_2 is requested only
    when the package is available and device 0 reports compute capability major >= 8.

    Args:
        model_name: Hugging Face id of the causal language model.

    Returns:
        ``(tokenizer, model)``.
    """
    logger.debug("Creating LLM model")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

    load_kwargs = dict(
        pretrained_model_name_or_path=model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        ),
    )

    # flash-attn 2 targets compute capability 8.x+ (Ampere or newer). The capability
    # query only runs when the package is present (short-circuit).
    use_flash_attn = is_flash_attn_2_available() and torch.cuda.get_device_capability(0)[0] >= 8
    if use_flash_attn:
        logger.debug("Using flash_attention_2")
        load_kwargs["attn_implementation"] = "flash_attention_2"

    llm_model = AutoModelForCausalLM.from_pretrained(**load_kwargs)
    return tokenizer, llm_model

In [14]:
# Load every model once; the rest of the notebook shares these module-level handles.
EMBEDDING_MODEL = create_embedding_model()
RERANKER_MODEL = create_reranking_model()
TOKENIZER, LLM_MODEL = (create_LLM())

logger.debug("All models created")
# GPU memory after loading; the embedding and re-ranker models are created on CPU here
check_cuda_memory()

2025-08-04 02:04:19,755 - DEBUG - Creating embedding model
2025-08-04 02:04:21,219 - DEBUG - Creating re-ranker model
2025-08-04 02:04:21,843 - DEBUG - Creating LLM model


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-08-04 02:04:31,407 - DEBUG - All models created
2025-08-04 02:04:31,413 - DEBUG - Total memory:     17094.48 MB
2025-08-04 02:04:31,413 - DEBUG - Reserved memory:  7228.88 MB
2025-08-04 02:04:31,413 - DEBUG - Allocated memory: 4141.95 MB
2025-08-04 02:04:31,413 - DEBUG - Free within reserved: 3086.93 MB


In [20]:
# Move the embedding model to the GPU when there is one; on CPU-only machines it
# stays where it is instead of crashing the notebook.
if torch.cuda.is_available():
    _ = EMBEDDING_MODEL.to("cuda")

In [18]:
# Move the re-ranker to the GPU when there is one; on CPU-only machines it stays
# where it is instead of crashing the notebook.
if torch.cuda.is_available():
    _ = RERANKER_MODEL.to("cuda")

In [21]:
# Compare with the report after model creation to see what moving the models added
check_cuda_memory()

2025-08-04 02:05:44,457 - DEBUG - Total memory:     17094.48 MB
2025-08-04 02:05:44,457 - DEBUG - Reserved memory:  7245.66 MB
2025-08-04 02:05:44,457 - DEBUG - Allocated memory: 4250.36 MB
2025-08-04 02:05:44,459 - DEBUG - Free within reserved: 2995.30 MB


## Extract data from GDELT

In [22]:
def log_tqdm(iterable, desc=None, ignore: bool = False):
    """
    Wrap ``iterable`` in a tqdm progress bar, but only while DEBUG logging is enabled.

    Args:
        iterable: items to iterate over.
        desc: label shown on the progress bar.
        ignore: when True, never show a progress bar.

    Returns:
        A tqdm wrapper around ``iterable``, or ``iterable`` itself unchanged.
    """
    if ignore or not logger.isEnabledFor(logging.DEBUG):
        return iterable
    return tqdm(iterable, desc=desc)

In [23]:
class Source(Enum):
    """News domains GDELT searches are restricted to; each value is the domain string."""
    CNN = "cnn.com"
    BBC = "bbc.co.uk"
    NyTimes = "nytimes.com"
    Guardian = "theguardian.com"

    # Disabled sources (uncomment to re-enable):
    # CBC = "cbc.ca"
    # AP = "apnews.com"    # wasn't working for some reason
    # Wired = "wired.com"
    # Reuters = "reuters.com"

In [24]:
def format_text(text: str) -> str:
    """Flatten text onto one line: each newline becomes a space, then outer whitespace is stripped."""
    single_line = text.replace("\n", " ")
    return single_line.strip()

In [25]:
def gdelt_row_to_dict(gd_row: pd.Series) -> dict:
    """
    Convert one row of ``GdeltDoc.article_search`` output into a plain dict.

    Keeps url, title and domain, and renames the "sourcecountry" column to "country".
    """
    # output key -> GDELT column name
    field_map = {
        "url": "url",
        "title": "title",
        "domain": "domain",
        "country": "sourcecountry",
    }
    return {out_key: gd_row[column] for out_key, column in field_map.items()}

In [None]:
def article_to_dict(art: news.Article) -> dict:
    """
    Convert a newspaper ``Article`` into a plain dict.

    Keys: text (newlines flattened via format_text), title, authors,
    date (``publish_date``), source (``source_url``) and url (``original_url``).
    """
    record = {"text": format_text(art.text)}
    record["title"] = art.title
    record["authors"] = art.authors
    record["date"] = art.publish_date
    record["source"] = art.source_url
    record["url"] = art.original_url
    return record

In [27]:
def gdelt_stories(
    keywords: str | None = None,
    theme: str | None = None,
    sources: list[Source] | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    num_records: int = 5
) -> list[dict]:
    """
    Retrieve metadata and urls for relevant stories via the GDELT DOC API.

    Args:
        keywords: free-text search; must be at least 5 characters when given.
        theme: GDELT GKG theme (see LOOKUP-GKGTHEMES.TXT).
        sources: domains to search; defaults to every member of ``Source``. At least 2.
        start_date: first day of the search window; defaults to 7 days before today.
        end_date: last day of the search window; defaults to today.
            Both dates accept ``date`` or ``datetime``; datetimes are truncated to their date.
        num_records: number of records requested from GDELT.

    Returns:
        One dict per result with keys url, title, domain, country.

    Raises:
        ValueError: on short keywords, fewer than 2 sources, future or inverted dates,
            or when neither keywords nor theme is given.
    """
    if keywords and len(keywords) < 5:
        raise ValueError("Keywords must be gte 5 characters or the Gdelt API errors.")

    domains = [s.value for s in (sources or list(Source))]
    if len(domains) < 2:
        raise ValueError("Number of sources must be gte 2 or the Gdelt API errors.")

    # Resolve defaults per call. A `datetime.now().date()` default in the signature is
    # evaluated once, when the cell runs, and goes stale in a long-lived kernel.
    today = datetime.now().date()
    start_date = today - timedelta(days=7) if start_date is None else start_date
    end_date = today if end_date is None else end_date
    # datetime subclasses date, but the two cannot be ordered against each other
    # (TypeError), so compare calendar dates only.
    if isinstance(start_date, datetime):
        start_date = start_date.date()
    if isinstance(end_date, datetime):
        end_date = end_date.date()

    if start_date > today or end_date > today:
        raise ValueError("News cannot be in the future...")

    if start_date > end_date:
        raise ValueError("How you gunna end before you start?!?!")

    if not keywords and not theme:
        raise ValueError("both theme and keywords cannot be empty")

    filter_kwargs = dict(
        start_date=start_date.strftime("%Y-%m-%d"),
        end_date=end_date.strftime("%Y-%m-%d"),
        num_records=num_records,
        domain=domains,
        # NOTE(review): keyword-only searches are limited to US sources (the "UK" entry
        # was commented out in that branch); any search with a theme covers UK and US.
        country=["UK", "US"] if theme else ["US"],
        language="eng",
    )
    if keywords:
        filter_kwargs["keyword"] = keywords
    if theme:
        filter_kwargs["theme"] = theme

    search_results = GdeltDoc().article_search(Filters(**filter_kwargs))
    return [gdelt_row_to_dict(row) for _, row in search_results.iterrows()]

In [28]:
def get_articles(urls: list[str] | str) -> list[dict]:
    """
    Download and parse the article behind each url with newspaper4k.

    Urls whose last path segment (the story slug) was already seen are skipped before
    downloading, since ``news.article`` is the expensive step and the same story can
    appear more than once in the search results. Urls that cannot be downloaded or
    parsed (e.g. a 404) are logged at DEBUG level and skipped.

    Note: ``fetch_news`` could multithread the retrieval, but its results do not
    include the article text, so articles are fetched one at a time.

    Args:
        urls: a single url or a list of urls.

    Returns:
        One dict per successfully retrieved article, in the format of article_to_dict.
    """
    urls = urls if isinstance(urls, list) else [urls]
    results = []
    seen_slugs = set()
    for url in log_tqdm(urls, "Collecting Stories"):
        # Strip a trailing "/" so ".../story-slug/" yields "story-slug". Otherwise the
        # slug is "" and every later trailing-slash url is dropped as a "duplicate".
        slug = url.rstrip("/").rsplit("/", 1)[-1]
        if slug in seen_slugs:
            logger.debug(f"Removed duplicate story: {slug}")
            continue
        seen_slugs.add(slug)

        try:
            art = news.article(url)
        except Exception as exc:
            # unreachable / blocked / missing article: skip it, but leave a trace
            # (and, unlike a bare except, still let KeyboardInterrupt through)
            logger.debug(f"Skipping {url}: {exc}")
            continue

        results.append(article_to_dict(art))

    return results

In [29]:
def retrieve_news(
    keywords: str | None = None,
    theme: str | None = None,
    sources: list[Source] | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    num_stories: int = 5,
) -> list[dict]:
    """
    Retrieve full stories matching the keywords and/or theme.

    GDELT is queried for story urls (gdelt_stories), then each story is downloaded with
    newspaper4k (get_articles).

    **NOTE**: theme must be a value from the available list here:
    http://data.gdeltproject.org/api/v2/guides/LOOKUP-GKGTHEMES.TXT

    Args:
        keywords, theme, sources: forwarded to gdelt_stories.
        start_date: first day of the search window; defaults to 7 days before today.
        end_date: last day of the search window; defaults to today.
        num_stories: number of GDELT records to request; fewer stories may be returned
            after duplicates and failed downloads are dropped.

    Returns:
        Story dicts in the format of article_to_dict.
    """
    # Resolve the default window at call time, not when the cell was executed
    today = datetime.now().date()
    if start_date is None:
        start_date = today - timedelta(days=7)
    if end_date is None:
        end_date = today

    logger.debug("Gdelt: retrieval beginning")
    story_links = gdelt_stories(
        keywords=keywords,
        theme=theme,
        sources=sources,
        start_date=start_date,
        end_date=end_date,
        num_records=num_stories
    )

    logger.debug("newspaper4k: retrieval beginning")
    return get_articles(urls=[story["url"] for story in story_links])

#### Example for retrieving stories:

**Note**: `retrieve_news` returns a list of dicts in the format created by the `article_to_dict` function

(i.e., keys=[text, title, authors, date, source, url])


In [30]:
# stories = retrieve_news(theme="IMMIGRATION", num_stories=10)

In [31]:
# stories[0].keys()

In [32]:
# [u["url"] for u in stories]

Results from above:

dict_keys(['text', 'title', 'authors', 'date', 'source', 'url'])

['https://www.theguardian.com/us-news/2025/jul/29/states-sue-trump-administration-snap-recipients-data',
 'https://www.bbc.co.uk/news/articles/clyjggjplyqo',
 'https://www.theguardian.com/us-news/2025/jul/30/ice-hiring-incentives-signing-bonuses',
 'https://www.cnn.com/2025/07/30/politics/immigration-employees-reader-callout',
 'https://www.theguardian.com/us-news/2025/jul/28/trump-acknowledges-real-starvation-in-gaza-and-tells-israel-to-let-in-every-ounce-of-food',
 'https://www.theguardian.com/us-news/2025/aug/01/judge-tps-temporary-protected-status-trump-deportation',
 'https://www.theguardian.com/society/2025/jul/30/population-migration-england-wales-data',
 'https://www.theguardian.com/world/2025/jul/30/mexico-sheinbaum-alligator-alcatraz-trump',
 'https://www.theguardian.com/uk-news/2025/aug/01/social-media-ads-promoting-small-boat-crossings-uk-banned']

## Test samples for creating article embeddings

In [33]:
# Up to 100 immigration-themed stories from roughly the last 90 days
lookback_start = datetime.now().date() - timedelta(days=90)
test_stories = retrieve_news(
    theme="IMMIGRATION",
    num_stories=100,
    start_date=lookback_start,
)

2025-08-04 02:05:58,933 - DEBUG - Gdelt: retrieval beginning
2025-08-04 02:06:00,301 - DEBUG - newpaper4k: retrieval beginning


Collecting Stories:   0%|          | 0/100 [00:00<?, ?it/s]

2025-08-04 02:06:26,971 - DEBUG - Removed duplicate story: immigration-birthright-citizenship-us-dg
2025-08-04 02:06:27,518 - DEBUG - Removed duplicate story: immigration-employees-reader-callout
2025-08-04 02:06:29,175 - DEBUG - Removed duplicate story: ice-arrests-migrants-courthouse
2025-08-04 02:06:31,470 - DEBUG - Removed duplicate story: deportations-backfiring-trump-analysis
2025-08-04 02:06:36,111 - DEBUG - Removed duplicate story: guatemalan-migrant-deported-mexico-trump-administration-return
2025-08-04 02:06:36,111 - DEBUG - Removed duplicate story: guatemalan-migrant-deported-mexico-trump-administration-return
2025-08-04 02:06:37,298 - DEBUG - Removed duplicate story: sanctuary-immigration-policies-chicago-illinois-lawsuit-dismissed


## Chunk the articles

In [34]:
def copy_list_of_dicts(dict_list: list[dict]) -> list[dict]:
    """
    Return a new list containing a shallow copy of each dict.

    Only one level is copied: the list and the dicts are new objects, but values held
    inside the dicts (e.g. lists of authors) are shared with the originals.
    """
    logger.debug("Creating copy of target list of dicts")
    return [item.copy() for item in dict_list]

In [35]:
def sentencize_stories(stories: list[dict]) -> list[dict]:
    """
    Split each story's "text" into sentences with spaCy's rule-based sentencizer.

    Adds a "sents" key (list of plain str) to every story.

    **Note**: This is an inplace operation; the same list is returned.
    """
    sentence_splitter = English()
    sentence_splitter.add_pipe("sentencizer")

    logger.debug("Breaking text into sentences")
    for story in log_tqdm(stories, "Sentencizing"):
        doc = sentence_splitter(story["text"])
        # store plain strings rather than spaCy Span objects
        story["sents"] = [str(span) for span in doc.sents]

    return stories

In [36]:
def split_list(input_list: list[Any], max_item_count: int) -> list[list[Any]]:
    """
    Split ``input_list`` into consecutive sub-lists of at most ``max_item_count`` items.

    The last sub-list holds the remainder and may be shorter; an empty input gives [].
    """
    pieces = []
    for start in range(0, len(input_list), max_item_count):
        pieces.append(input_list[start:start + max_item_count])
    return pieces

In [37]:
def chunk_sentences(stories: list[dict], chunk_size: int = 10) -> list[dict]:
    """
    Group each story's "sents" into "sent_chunks": sub-lists of at most chunk_size sentences.

    **Note**: This is an inplace operation; the same list is returned. An empty list
    is returned unchanged.

    Raises:
        ValueError: if any story has no "sents" yet (sentencize_stories not run).
    """
    logger.debug(f"Chunking sentences, chunk_size: {chunk_size}")

    # Check every story (not just stories[0], which also raised IndexError on []).
    if any("sents" not in story for story in stories):
        raise ValueError("Out of order operation: Cannot chunk sentences before they exist")

    for story in log_tqdm(stories, "Chunking"):
        story["sent_chunks"] = split_list(story["sents"], chunk_size)

    return stories

In [38]:
def article_length_metadata(stories: list[dict]) -> list[dict]:
    """
    Add rough length metrics to each story dict.

    Requires "sents" and "sent_chunks" on every story. Adds:
        num_sents:  number of sentences
        num_tokens: len(text) // 4, a rough estimate rather than a tokenizer count
        num_words:  number of whitespace-separated words
        num_chars:  len(text)
        num_chunks: number of sentence chunks

    **Note**: This is an inplace operation; the same list is returned.

    Raises:
        ValueError: if any story is missing "sents" or "sent_chunks".
    """
    logger.debug("Adding article length metadata")

    required_keys = ("sents", "sent_chunks")
    # Validate every story; indexing stories[0] raised IndexError on an empty list.
    if any(key not in story for story in stories for key in required_keys):
        raise ValueError("Out of order operation: Collecting metadata requires all sub-items to be populated")

    for story in log_tqdm(stories, "Collecting Metadata"):
        text = story["text"]
        story["num_sents"] = len(story["sents"])
        story["num_tokens"] = len(text) // 4
        # split() rather than split(" "): format_text leaves double spaces where
        # paragraphs were, and split(" ") counts the empty strings between them.
        story["num_words"] = len(text.split())
        story["num_chars"] = len(text)
        story["num_chunks"] = len(story["sent_chunks"])

    return stories

In [39]:
def seperate_into_chunk_list(story_list: list[dict]) -> list[dict]:
    """
    Flatten the "sent_chunks" of every story into one list of chunk dicts.

    Each chunk dict carries the story's title, authors, date, source and url, plus
    "text": the chunk's sentences joined by single spaces.
    """
    logger.debug("Creating list of chunks from sublist in article dict")

    metadata_keys = ("title", "authors", "date", "source", "url")
    chunk_list = []
    for item in log_tqdm(story_list, "Seperating Chunks"):
        story_metadata = {key: item[key] for key in metadata_keys}
        for chunk in item["sent_chunks"]:
            # str() of a spaCy sentence omits trailing whitespace, so join with a space.
            # Joining with "" fused boundaries such as "?The" or '."The' that the old
            # `\.([A-Z])` fix-up missed. Then collapse whitespace runs to one space.
            chunk_text = " ".join(" ".join(chunk).split())
            chunk_list.append({**story_metadata, "text": chunk_text})

    return chunk_list

In [None]:
def chunk_length_metadata(chunks: list[dict]) -> list[dict]:
    """
    Add rough length metrics to each chunk dict.

    Adds:
        num_tokens: len(text) // 4, a rough estimate rather than a tokenizer count
        num_words:  number of whitespace-separated words
        num_chars:  len(text)

    **Note**: This is an inplace operation; the same list is returned.
    """
    logger.debug("Adding chunk length metadata")

    for chunk in log_tqdm(chunks, "Collecting Metadata"):
        text = chunk["text"]
        chunk["num_tokens"] = len(text) // 4
        # split() ignores runs of whitespace; split(" ") counted empty strings as words
        chunk["num_words"] = len(text.split())
        chunk["num_chars"] = len(text)

    return chunks

In [41]:
def chunk_articles(stories: list[dict], sentences_per_chunk: int = 10) -> list[dict]:
    """
    Convert a list of story dicts into a list of chunk dicts.

    Args:
        stories: story dicts in the format of article_to_dict; they are not modified.
        sentences_per_chunk: maximum number of sentences per chunk (previously ignored,
            so chunks always held 10 sentences).

    Returns:
        Chunk dicts with the story metadata, "text" and length metadata.
    """
    logger.debug("Chunking list of articles")

    # work on copies because the sentencize/chunk steps mutate dicts in place
    story_copy = copy_list_of_dicts(stories)
    story_copy = sentencize_stories(story_copy)
    story_copy = chunk_sentences(story_copy, chunk_size=sentences_per_chunk)

    # build the flat chunk list and annotate it
    chunks = seperate_into_chunk_list(story_copy)
    return chunk_length_metadata(chunks)

## Create embeddings from chunks

In [42]:
def add_embeddings(
    chunk_list: list[dict],
    embedding_model: SentenceTransformer
) -> list[dict]:
    """
    Encode each chunk's "text" and store the vector under the chunk's "embedding" key.

    Encoding runs in batches of 32 with ``convert_to_tensor=True``, so each stored
    embedding is a torch tensor. A progress bar is shown while DEBUG logging is enabled.

    **Note**: this operation is inplace; the same list is returned.
    """
    logger.debug("Creating embeddings for chunks")

    texts = [chunk["text"] for chunk in chunk_list]
    show_progress = logger.isEnabledFor(logging.DEBUG)
    vectors = embedding_model.encode(
        texts,
        batch_size=32,
        convert_to_tensor=True,
        show_progress_bar=show_progress,
    )

    for target_chunk, vector in zip(chunk_list, vectors):
        target_chunk["embedding"] = vector

    return chunk_list

In [43]:
def chunk_and_embed(
    articles_list: list[dict],
    embedding_model: SentenceTransformer | None = None,
) -> list[dict]:
    """
    Chunk the given articles and attach an embedding to every chunk.

    Args:
        articles_list: story dicts in the format of article_to_dict.
        embedding_model: model used for encoding; falls back to the module-level
            EMBEDDING_MODEL when not provided.

    Returns:
        Chunk dicts, each with an "embedding" entry.
    """
    model = embedding_model or EMBEDDING_MODEL
    return add_embeddings(chunk_articles(articles_list), embedding_model=model)

## Data Storage in Qdrant

In [44]:
def _embedding_to_list(vector) -> list[float]:
    """Convert a torch tensor / numpy array / sequence embedding to a plain list of floats."""
    if hasattr(vector, "tolist"):
        return vector.tolist()
    return [float(x) for x in vector]


def add_chunk_list_to_db(chunk_list: list[dict]):
    """
    Upsert a list of embedded chunks into the Qdrant collection COLLECTION_NAME.

    Each chunk becomes one point with a random UUID id, its "embedding" as the vector,
    and source/date/url/title/authors/text as payload. An empty list is a no-op.

    Args:
        chunk_list: chunk dicts as produced by chunk_and_embed.
    """
    if not chunk_list:
        logger.debug("No chunks to add to vector DB")
        return

    logger.debug(f"Adding chunks to vector DB, collection: {COLLECTION_NAME}")

    points = [
        PointStruct(
            id=str(uuid4()),
            # add_embeddings stores torch tensors (convert_to_tensor=True); PointStruct
            # needs a plain list of floats, so convert here.
            vector=_embedding_to_list(chunk["embedding"]),
            payload={
                "source": chunk["source"],
                "date": chunk["date"],
                "url": chunk["url"],
                "title": chunk["title"],
                "authors": chunk["authors"],
                "text": chunk["text"],
            }
        )
        for chunk in chunk_list
    ]

    VECTOR_DB.upsert(collection_name=COLLECTION_NAME, points=points)
    

In [45]:
# chunks = chunk_and_embed(test_stories)
# add_chunk_list_to_db(chunks)

## Qdrant Searching

In [46]:
def query_response_to_dict(
    response: QueryResponse,
    embedding_model: SentenceTransformer
) -> list[dict]:
    """
    Flatten a Qdrant QueryResponse into one dict per hit.

    Each dict is a copy of the point's payload with "score" and "vector" added.
    ``embedding_model`` is accepted for interface compatibility but not used.
    """
    logger.debug("Converting query results into a dictionary")

    hits = []
    for point in response.points:
        hit = point.payload.copy()
        hit.update(score=point.score, vector=point.vector)
        hits.append(hit)
    return hits

In [47]:
def query_db(
    query_str: str, embedding_model: SentenceTransformer, num_retrieve: int = 5
) -> list[dict]:
    """
    Search the Qdrant collection for the chunks most similar to ``query_str``.

    Args:
        query_str: text to search for.
        embedding_model: model used to embed the query.
        num_retrieve: maximum number of hits.

    Returns:
        One dict per hit (payload + "score" + "vector"), as built by
        query_response_to_dict. The previous annotation said QueryResponse, but that
        object is converted before returning.
    """
    # Leave the embedding in encode()'s default type: a torch tensor isn't accepted,
    # and an explicit numpy conversion only adds runtime.
    query_embeddings = embedding_model.encode(query_str)
    results = VECTOR_DB.query_points(
        collection_name=COLLECTION_NAME,
        query=query_embeddings,
        limit=num_retrieve,
        with_payload=True,
        with_vectors=True
    )

    return query_response_to_dict(results, embedding_model)

In [48]:
def rerank_response(
    query_str: str,
    resp_list: list[dict],
    reranker_model: CrossEncoder
) -> list[tuple[float, dict]]:
    """
    Re-order vector-search hits by cross-encoder relevance to the query.

    Args:
        query_str: the user's query.
        resp_list: hit dicts from query_db; each needs a "text" entry.
        reranker_model: cross-encoder that scores (query, text) pairs.

    Returns:
        ``(score, hit)`` pairs sorted by descending cross-encoder score; [] when there
        are no hits (the model is not called).
    """
    logger.debug(f"Reranking responses [num_query_results: {len(resp_list)}]")

    # Nothing to score (e.g. empty collection).
    if not resp_list:
        return []

    text_pairs = [(query_str, item["text"]) for item in resp_list]
    new_scores = reranker_model.predict(
        text_pairs,
        batch_size=32,
        show_progress_bar=logger.isEnabledFor(logging.DEBUG),
    )

    return sorted(zip(new_scores.astype(float), resp_list), key=lambda pair: pair[0], reverse=True)

In [49]:
def retrieve_similar(
    query_str: str,
    embedding_model: SentenceTransformer | None = None,
    reranker_model: CrossEncoder | None = None,
    num_retrieve: int = 5,
) -> list[tuple[float, dict]]:
    """
    Retrieve the chunks most relevant to ``query_str``: vector search, then re-ranking.

    ``3 * num_retrieve`` candidates are fetched from Qdrant, re-scored with the
    cross-encoder, and the best ``num_retrieve`` are returned.

    Args:
        query_str: the user's query.
        embedding_model: model used to embed the query; defaults to EMBEDDING_MODEL.
            Previously it was ignored and the global was always used.
        reranker_model: cross-encoder used for re-ranking; defaults to RERANKER_MODEL.
            Previously it was ignored and the global was always used.
        num_retrieve: number of results to return.

    Returns:
        ``(rerank_score, chunk_dict)`` pairs, highest score first.
    """
    # Tolerate the legacy positional call `retrieve_similar(query, 5)`: the int lands
    # in `embedding_model` but was meant as the number of results.
    if isinstance(embedding_model, int):
        num_retrieve, embedding_model = embedding_model, None

    logger.debug(f"Querying VECTOR_DB for chunks similar to: {query_str} and rank")

    if embedding_model is None:
        embedding_model = EMBEDDING_MODEL
    if reranker_model is None:
        reranker_model = RERANKER_MODEL

    # over-fetch so the re-ranker has candidates to filter
    search_results = query_db(query_str, embedding_model, num_retrieve * 3)
    reranked = rerank_response(query_str, search_results, reranker_model)

    return reranked[:num_retrieve]

In [55]:
# x = retrieve_similar("Trump immigration", num_retrieve=10)

In [56]:
# x[0][1]["title"]

## Configuring a local LLM

In [59]:
def create_context_str(resources: list[tuple[float, dict]]) -> str:
    """
    Render retrieved chunks as the context section of the prompt.

    Each ``(score, chunk)`` pair becomes an "article {i}: {title}" line followed by a
    "text: {text}" line and a blank line; ``i`` counts from 0 and scores are ignored.
    """
    sections = [
        f"article {idx}: {chunk['title']}\ntext: {chunk['text']}\n\n"
        for idx, (_score, chunk) in enumerate(resources)
    ]
    return "".join(sections)

In [88]:
def prompt_factory(
    query: str,
    resources: list[tuple[float, dict]],
    tokenizer: AutoTokenizer | None = None,
) -> str:
    """
    Build the RAG prompt: instructions, retrieved context, then the user query.

    The text is wrapped as a single user message with the tokenizer's chat template,
    with the generation prompt appended.

    Args:
        query: the user's question.
        resources: ``(score, chunk)`` pairs from retrieve_similar.
        tokenizer: tokenizer whose chat template is used; defaults to TOKENIZER.

    Returns:
        The templated prompt as an untokenized string.
    """
    logger.debug("Generating prompt")

    if tokenizer is None:
        tokenizer = TOKENIZER

    # The ", " after "below" was previously missing, producing "belowgenerate".
    prompt_str = (
        "You are an expert news summarizer. "
        "Using the clips from news articles provided below, "
        "generate a clear and concise summary of the query.\n\n"
    )
    prompt_str += create_context_str(resources)
    prompt_str += f"\nUser Query:\n{query}"
    prompt_str += "\nAnswer:"

    template = [{"role": "user", "content": prompt_str}]

    return tokenizer.apply_chat_template(
        template,
        tokenize=False,
        add_generation_prompt=True
    )

In [94]:
def generate(
    prompt: str,
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    format_answer_text: bool = True,
    return_answer_only: bool = True,
    tokenizer: AutoTokenizer | None = None,
    llm_model: AutoModelForCausalLM | None = None
) -> str:
    """
    Run the LLM on an already-templated prompt (see prompt_factory) and decode the result.

    Args:
        prompt: the prompt string produced by ``apply_chat_template(tokenize=False)``.
        temperature: sampling temperature (sampling is always enabled).
        max_new_tokens: maximum number of tokens to generate.
        format_answer_text: drop special tokens (e.g. "</s>") from the generated text
            and strip surrounding whitespace.
        return_answer_only: return only the generated continuation; when False, the
            original prompt string followed by the continuation is returned.
        tokenizer: defaults to TOKENIZER.
        llm_model: defaults to LLM_MODEL.

    Returns:
        The decoded text.
    """
    logger.debug("Passing prompt to LLM for generation...")

    if tokenizer is None:
        tokenizer = TOKENIZER
    if llm_model is None:
        llm_model = LLM_MODEL

    # The chat template already inserted the BOS token into `prompt`; tokenizing with
    # special tokens would add a second one. Place inputs wherever the model lives.
    input_tokens = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(llm_model.device)

    output = llm_model.generate(
        **input_tokens,
        temperature=temperature,
        do_sample=True,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id  # silences the missing-pad-token warning
    )
    logger.debug("LLM generation complete.")

    # For decoder-only models the output row starts with the prompt tokens.
    prompt_len = input_tokens["input_ids"].shape[-1]
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=format_answer_text)
    if format_answer_text:
        answer = answer.strip()

    return answer if return_answer_only else prompt + answer

In [91]:
def ask(
    query_str: str,
    temperature: float = 0.7,
    max_new_tokens: int = 256,
    format_answer_text: bool = True,
    return_answer_only: bool = True,
    tokenizer: AutoTokenizer | None = None,
    llm_model: AutoModelForCausalLM | None = None,
    num_resources: int = 5,
) -> str | tuple[str, list[tuple[float, dict]]]:
    """
    Answer ``query_str`` with the full RAG pipeline: retrieve, build the prompt, generate.

    Args:
        query_str: the user's question.
        temperature, max_new_tokens: generation settings passed to generate.
        format_answer_text: strip the prompt and "<s>"/"</s>" markers from the output.
        return_answer_only: when False, also return the retrieved resources.
        tokenizer: defaults to TOKENIZER.
        llm_model: defaults to LLM_MODEL.
        num_resources: number of re-ranked chunks used as context.

    Returns:
        The answer string, or ``(answer, resources)`` when ``return_answer_only`` is
        False, where resources are the ``(score, chunk)`` pairs used as context.
    """
    logger.debug(f"Asking RAG, query: {query_str}")

    tokenizer = tokenizer or TOKENIZER
    llm_model = llm_model or LLM_MODEL

    # retrieval: use the caller's query (previously the notebook global `input_text`)
    # and pass the count by keyword (positionally it landed in `embedding_model`)
    resources = retrieve_similar(query_str, num_retrieve=num_resources)

    # augment
    prompt = prompt_factory(query_str, resources, tokenizer)

    # generate llm response
    output_text = generate(
        prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        format_answer_text=format_answer_text,
        # this function's `return_answer_only` decides whether resources are returned;
        # the generated text itself should never include the prompt
        return_answer_only=True,
        tokenizer=tokenizer,
        llm_model=llm_model,
    )

    # formatting (harmless if generate already excluded the prompt/special tokens)
    if format_answer_text:
        output_text = output_text.replace(prompt, "").replace("<s> ", "").replace("</s>", "")

    if return_answer_only:
        return output_text

    return output_text, resources

In [92]:
# Example question for the RAG pipeline
input_text = "what is going on with Trumps immigration policies?"

In [93]:
# Run retrieval -> prompt -> generation; the cell displays the answer string
ask(input_text)

2025-08-04 03:14:29,376 - DEBUG - Asking RAG, query: what is going on with Trumps immigration policies?
2025-08-04 03:14:29,376 - DEBUG - Querying VECTOR_DB for chunks similar to: what is going on with Trumps immigration policies? and rank
  return forward_call(*args, **kwargs)
2025-08-04 03:14:29,393 - DEBUG - Converting query results into a dictionary
2025-08-04 03:14:29,393 - DEBUG - Reranking responses [num_query_results: 15]


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-08-04 03:14:29,414 - DEBUG - Generating prompt


"President Donald Trump's immigration policies in his second term have focused on enacting a crackdown, with increased arrests and deportations. The administration is resurrecting old policies, aiming to regulate the flow of immigrants and address national security concerns. This approach has been criticized by many Americans, with a recent CNN poll showing that about 52% believe Trump has gone too far in deporting undocumented immigrants, and 57% do not believe the federal government is being careful in following the law while carrying out deportations. The administration has also faced backlash for its aggressive immigration enforcement actions, with concerns raised about due process rights and constitutional violations. Despite this, the administration has continued to prioritize immigration as a key issue, but recent polling shows that Americans disapprove of Trump's handling of immigration by a wide margin."

In [95]:
# GPU memory after generation, for comparison with the readings taken after model loading
check_cuda_memory()

2025-08-04 03:18:04,753 - DEBUG - Total memory:     17094.48 MB
2025-08-04 03:18:04,754 - DEBUG - Reserved memory:  7249.85 MB
2025-08-04 03:18:04,754 - DEBUG - Allocated memory: 4258.88 MB
2025-08-04 03:18:04,754 - DEBUG - Free within reserved: 2990.97 MB


## References:

- Local Retrieval Augmented Generation (RAG) from Scratch (step by step tutorial) by Daniel Bourke

link: https://www.youtube.com/watch?v=qN_2fnOPY-M