In [None]:
from typing import Iterable, Sequence
from datasets import load_dataset, tqdm
from tqdm.auto import tqdm
import weaviate
from news_chatbot.config import load_weaviate_settings 
from weaviate.classes.config import Configure, DataType, Property

# Name of the Weaviate collection used by the helpers and cells in this notebook.
COLLECTION_NAME="BBCArticle"
# Number of dataset rows grouped into each chunk by `_batched` during ingestion.
BATCH_SIZE = 64
def _create_collection(client: weaviate.WeaviateClient) -> weaviate.collections.Collection:
    """Return the ``COLLECTION_NAME`` collection, creating it first if it is missing.

    A newly created collection is configured with the ``text2vec_transformers``
    vectorizer and three TEXT properties: ``news_id``, ``article`` and ``summary``.

    Args:
        client: A connected Weaviate client.

    Returns:
        The handle obtained from ``client.collections.get(COLLECTION_NAME)``.
    """
    collections = client.collections

    if not collections.exists(COLLECTION_NAME):
        # Every stored field is plain text, so the schema is derived from the names.
        text_fields = ("news_id", "article", "summary")
        schema = [
            Property(name=field_name, data_type=DataType.TEXT)
            for field_name in text_fields
        ]
        collections.create(
            name=COLLECTION_NAME,
            vectorizer_config=Configure.Vectorizer.text2vec_transformers(),
            properties=schema,
        )

    return collections.get(COLLECTION_NAME)

def _batched(rows: Iterable[dict], size: int) -> Iterable[Sequence[dict]]:
    batch:list[dict] = []
    for row in rows:
        batch.append(row)
        if(len(batch) == size):
            yield batch
            batch = []
    if batch:
        yield batch
        
def ingest()-> None:
    """Load the ``shwet/BBC_NEWS`` train split and import it into Weaviate.

    The function connects to the local Weaviate instance described by
    ``load_weaviate_settings()``, makes sure the target collection exists, and
    skips the import when the collection already holds at least as many objects
    as the dataset has rows. Otherwise every row is written through a dynamic
    batch while a single progress bar tracks the number of rows submitted.

    Each object gets a deterministic UUID derived from its properties, so
    re-running after an interrupted import is intended to overwrite the objects
    already present instead of adding duplicates under new random UUIDs.
    Objects imported by earlier runs with random UUIDs are not affected.

    Returns:
        None. Progress and status are reported on stdout.
    """
    # Imported locally so this function does not depend on changes to the
    # notebook's import cell.
    from weaviate.util import generate_uuid5

    settings = load_weaviate_settings()
    dataset = load_dataset("shwet/BBC_NEWS", split="train")
    total_rows = len(dataset)
    print(dataset)

    with weaviate.connect_to_local(
        host=settings.host,
        port=settings.port,
        grpc_port=settings.grpc_port,
        headers=settings.headers,
    ) as client:
        collection = _create_collection(client)
        # `total_count` may be None, hence the `or 0`.
        existing = collection.aggregate.over_all(total_count=True).total_count or 0
        if existing >= total_rows:
            print(f"Dataset already ingested ({existing} objects); skipping.")
            return

        with collection.batch.dynamic() as writer, tqdm(
            total=total_rows,
            desc="Ingesting BBC articles",
            unit="rows",
        ) as progress:
            # Iterate the batches directly: wrapping this loop in another
            # tqdm() would draw a second, competing progress bar.
            for rows in _batched(dataset, BATCH_SIZE):
                for row in rows:
                    properties = {
                        "news_id": str(row["ids"]),
                        "article": row["articles"],
                        "summary": row["summary"],
                    }
                    writer.add_object(
                        properties=properties,
                        # Same content -> same UUID, which keeps re-runs idempotent.
                        uuid=generate_uuid5(properties),
                    )
                progress.update(len(rows))

        # The batch context does not raise for per-object failures; surface them.
        failed = collection.batch.failed_objects
        if failed:
            print(f"Warning: {len(failed)} of {total_rows} objects failed to import.")


In [None]:
from textwrap import fill
from pprint import pprint
from unittest import result

if __name__ == "__main__":
    # Make sure the collection is populated before querying it.
    ingest()
    settings = load_weaviate_settings()
    query = "Give me news about India"
    with weaviate.connect_to_local(
        host=settings.host,
        port=settings.port,
        grpc_port=settings.grpc_port,
        headers=settings.headers,
    ) as client:
        collection = client.collections.get(COLLECTION_NAME)

        # Pure vector search: ask for `distance`, the metric this query is ranked by.
        vector_response = collection.query.near_text(
            query, limit=2, return_metadata=["distance"]
        )
        # Hybrid search: `score` must be requested explicitly, otherwise the
        # metadata field printed below is empty.
        hybrid_response = collection.query.hybrid(
            query, limit=2, alpha=0.3, return_metadata=["score"]
        )

        searches = (
            ("near_text", "distance", vector_response),
            ("hybrid", "score", hybrid_response),
        )
        for label, metric, response in searches:
            for hit in response.objects:
                value = getattr(hit.metadata, metric)
                # Built outside the f-string: reusing double quotes inside a
                # double-quoted f-string is a SyntaxError before Python 3.12.
                summary = fill(hit.properties["summary"], width=80)
                print(f"[{label}] {metric} = {value} \n {summary}")

Dataset({
    features: ['ids', 'articles', 'summary'],
    num_rows: 1903
})
Dataset already ingested (1999 objects); skipping.
{'BBCArticle': _CollectionConfigSimple(name='BBCArticle',
                                       description=None,
                                       generative_config=None,
                                       properties=[_Property(name='news_id',
                                                             description=None,
                                                             data_type=<DataType.TEXT: 'text'>,
                                                             index_filterable=True,
                                                             index_range_filters=False,
                                                             index_searchable=True,
                                                             nested_properties=None,
                                                             tokenization=<Tokenization.WORD: 'word'

In [121]:
# Maintenance cell: removes the whole collection (and therefore every ingested
# object). After running it, `ingest()` has to be executed again before querying.
settings = load_weaviate_settings()
with weaviate.connect_to_local(
    host=settings.host,
    port=settings.port,
    grpc_port=settings.grpc_port,
    headers=settings.headers,
) as client:
    # Only delete (and report) when the collection is actually there.
    if client.collections.exists(COLLECTION_NAME):
        client.collections.delete(COLLECTION_NAME)
        print(f"Collection '{COLLECTION_NAME}' has been deleted.")

Collection 'BBCArticle' has been deleted.
