In [None]:
import os
# Move the working directory two levels up from where the notebook was started
# (presumably the repository root, so project-local packages import — confirm).
# The starting directory is remembered on first execution so that re-running
# this cell resolves to the same target instead of climbing two more levels
# every time it is executed.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "..", ".."))

In [None]:
from haystack.nodes import RAGenerator, DensePassageRetriever

# Model checkpoints, named once so they are easy to spot and swap.
DPR_QUERY_ENCODER = "facebook/dpr-question_encoder-single-nq-base"
DPR_PASSAGE_ENCODER = "facebook/dpr-ctx_encoder-single-nq-base"
RAG_GENERATOR_MODEL = "facebook/rag-token-nq"

# Dense passage retriever over the paragraph store.
# NOTE(review): `paragraph_document_store` is not created in any cell shown
# here — it must already exist in the kernel; confirm where it comes from.
dense_retriever = DensePassageRetriever(
    document_store=paragraph_document_store,
    query_embedding_model=DPR_QUERY_ENCODER,
    passage_embedding_model=DPR_PASSAGE_ENCODER,
    embed_title=True,
    use_gpu=True,
)

# RAG generator that is handed the retriever defined above.
generator = RAGenerator(
    model_name_or_path=RAG_GENERATOR_MODEL,
    retriever=dense_retriever,
    top_k=1,
    num_beams=2,
    min_length=2,
    max_length=200,
    embed_title=True,
    use_gpu=True,
)


In [None]:
from rest_api.controller.impact_screening import _process_request_clusters
from rest_api.controller.impact_topic import calculate_clusters
from rest_api.schema import ImpactTopicRequest

In [None]:
# Call the impact-topic clustering controller directly with one example request
# (impact "Cancer", company "Paint", negative polarity) and keep its response.
response = calculate_clusters(ImpactTopicRequest(impact_concept="Cancer", company_concept="Paint", impact_polarity="NEGATIVE"))

In [None]:
# Documents carried by the controller response; input for the experiments below.
documents = response.documents

In [None]:
# Collect the `.text` of every document returned by the controller.
contents = [document.text for document in documents]

# Encode all texts in a single call and move the resulting tensor to the CPU.
# NOTE(review): `embedder` is not defined in the cells shown here — presumably
# a sentence-embedding model created earlier in the kernel; confirm.
document_embeddings = embedder.encode(contents, convert_to_tensor=True).cpu()

In [None]:
from sklearn.cluster import Birch, AgglomerativeClustering

# Agglomerative clustering with no fixed cluster count: merging is governed by
# the distance threshold (0.6) instead.
agg_clustering = AgglomerativeClustering(
    n_clusters=None, distance_threshold=0.6
    )
# Birch with threshold 0.5; the agglomerative estimator is passed as its
# `n_clusters` argument (per the sklearn docs it is then applied to Birch's
# sub-clusters as the final clustering step).
brc = Birch(threshold=0.5, n_clusters=agg_clustering)

In [None]:
# Initial fit on the first 10 embeddings only.
brc.fit(document_embeddings[:10])

In [None]:
# Labels Birch assigns to the same 10 embeddings it was just fitted on.
brc.predict(document_embeddings[:10])

In [None]:
# Feed the remaining embeddings (index 10 onward) incrementally via partial_fit.
brc.partial_fit(document_embeddings[10:])

In [None]:
# Predict a label for every embedding (initial 10 plus the incremental rest)
# and display the result as the cell output.
labels = brc.predict(document_embeddings)
labels

In [None]:
from haystack.pipelines import Pipeline
from smart_evidence.helpers import opensearch_connection
from smart_evidence.components.count_clustering import CountClustering
from smart_evidence.components.count_transformers_clustering import CountTransformerClustering
from smart_evidence.components.document_classifier import HeuristicsDocumentClassifier
from smart_evidence.components.summarizer import IXTransformersSummarizer
from smart_evidence.components.transformers_clustering import TransformersClustering

In [None]:
from collections import defaultdict, Counter
import logging
from typing import Any, List, Optional, Set, Tuple

from haystack.nodes.base import BaseComponent
from haystack.schema import Document
from rest_api.config import LOG_LEVEL
from sklearn.cluster import DBSCAN
import numpy as np

logging.getLogger(__name__).setLevel(LOG_LEVEL)
logger = logging.getLogger(__name__)



class TransformersClustering(BaseComponent):
    """
    Pipeline node that embeds documents and assigns each one a cluster id.

    ``predict()`` encodes ``Document.content`` with ``embedder``, hands the
    embeddings to ``clustering`` and stores the resulting label in
    ``Document.meta['cluster_id']``.
    """

    outgoing_edges = 1

    def __init__(
        self,
        clustering,
        embedder: Optional[Any] = None,
        separator_for_cluster_texts: str = "\n",
        min_cluster_size: int = 2
    ):
        """
        Store the embedding model and the clustering estimator.

        :param clustering: Estimator used to cluster the embeddings. ``cluster()`` calls its
                           ``fit_predict`` method (the scikit-learn clusterer API).
        :param embedder: Object exposing ``encode(texts, convert_to_tensor=True)`` whose result has a
                         ``.cpu()`` method. Optional at construction time, required by ``predict()``.
        :param separator_for_cluster_texts: Stored on the instance; not used by the methods of this class.
        :param min_cluster_size: Stored on the instance; not used by the methods of this class.
        """
        self.print_log: Set[str] = set()
        self.separator_for_cluster_texts = separator_for_cluster_texts
        self.min_cluster_size = min_cluster_size
        self.embedder = embedder
        self.clustering = clustering

    def run(self, documents: List[Document]):  # type: ignore
        """
        Pipeline entry point: cluster ``documents`` and package the outcome.

        :param documents: Documents to cluster; may be empty.
        :return: ``(results, "output_1")`` where ``results`` holds ``documents``, ``clusters`` and
                 ``n_total_documents``. For empty input the two lists stay empty and ``predict()``
                 is not called.
        """
        results: dict = {
            "documents": [],
            "clusters": [],
            "n_total_documents": len(documents),
        }

        if documents:
            (
                results["documents"],
                results["clusters"],
            ) = self.predict(documents=documents)

        return results, "output_1"

    def cluster(self, embedding_matrix: np.ndarray) -> List[int]:
        """
        Fit the clustering estimator on the embeddings and return one label per row.

        :param embedding_matrix: Embeddings to cluster, one row per document.
        :return: Cluster labels as a plain list of Python ints.
        """
        # `fit()` returns the estimator itself (which is not iterable);
        # `fit_predict()` is the call that yields the per-sample labels.
        labels = self.clustering.fit_predict(embedding_matrix)
        return np.asarray(labels).tolist()

    def build_result(self, documents, clusters):
        """
        Attach each cluster label to its document.

        :param documents: Documents in the same order as ``clusters``.
        :param clusters: One cluster label per document.
        :return: ``(documents, clusters)``; every document's ``meta['cluster_id']`` is set in place.
        """
        for document, cluster in zip(documents, clusters):
            document.meta['cluster_id'] = cluster

        return documents, clusters

    def predict(
        self,
        documents: List[Document],
    ) -> Tuple[List[Document], List[int]]:
        """
        Produce the clustering for the supplied documents.
        These documents can for example be retrieved via the Retriever.

        :param documents: Documents to embed and cluster; must not be empty.
        :return: Tuple ``(documents, clusters)``: the input documents with ``meta['cluster_id']`` set,
                 and the cluster labels in the same order.
        :raises AttributeError: If ``documents`` is empty or no embedder was configured.
        """
        if len(documents) == 0:
            raise AttributeError(
                "Clustering needs at least one document to produce clusters."
            )
        if self.embedder is None:
            # Same exception type the bare `None.encode` access used to raise,
            # but with an actionable message.
            raise AttributeError(
                "TransformersClustering needs an embedder to encode documents."
            )

        contents = [doc.content for doc in documents]
        document_embeddings = self.embedder.encode(
            contents, convert_to_tensor=True
        ).cpu()

        clusters = self.cluster(document_embeddings)
        return self.build_result(documents, clusters)



In [None]:
class BirchTransformersClustering(TransformersClustering):
    """
    Variant of ``TransformersClustering`` that updates its estimator incrementally.

    Each call feeds the new embeddings to ``partial_fit`` (so the estimator
    keeps what it learned from earlier batches) and then predicts labels for
    that same batch.
    """

    def cluster(self, embedding_matrix: np.ndarray) -> List[int]:
        """
        Incrementally fit on ``embedding_matrix`` and return one label per row.

        :param embedding_matrix: Embeddings of the current batch, one row per document.
        :return: Cluster labels as a plain list of Python ints, matching the declared return type.
        """
        self.clustering.partial_fit(embedding_matrix)
        batch_labels = self.clustering.predict(embedding_matrix)
        return np.asarray(batch_labels).tolist()

In [None]:
# from smart_evidence.pipeline.query_pipeline import FilterRetriever
from smart_evidence.components.document_classifier import HeuristicsDocumentClassifier
from smart_evidence.components.summarizer import IXTransformersSummarizer
from haystack.pipelines import Pipeline

# Incremental clustering stage: a Birch estimator created with n_clusters=None,
# wrapped in the node that partial_fits it on every batch it receives.
clustering = Birch(n_clusters=None)
transformer_clustering = BirchTransformersClustering(
    embedder=embedder,
    clustering=clustering,
    min_cluster_size=2,
)

# Single-node pipeline: Query -> Clustering.
# Optional stages that are currently disabled (kept here for reference):
#   - name="DocumentClassifier", component=HeuristicsDocumentClassifier(), inputs=["Query"]
#   - name="Summarizer", component=IXTransformersSummarizer("chinhon/headline_writer"),
#     inputs=["Clustering"]
transformer_cluster_pipeline = Pipeline()
transformer_cluster_pipeline.add_node(
    name="Clustering",
    component=transformer_clustering,
    inputs=["Query"],
)

In [None]:
from datetime import date

def write_clusters(clusters):
    """
    Stamp every cluster document with bookkeeping metadata and persist them.

    For each document in ``clusters['documents']`` this sets, in place:
    ``meta['updated_at']`` to today's date, ``meta['is_curated']`` to ``False``
    and ``meta['paragraph_ids']`` to the ``'id'`` of each entry found in
    ``meta['documents']``. All documents are then written with a single
    ``write_documents`` call on the module-level ``clusters_document_store``.

    :param clusters: Mapping with a ``'documents'`` entry holding the cluster documents.
    """
    stamped = []
    for entry in clusters['documents']:
        meta = entry.meta
        meta['updated_at'] = date.today()
        meta['is_curated'] = False
        meta['paragraph_ids'] = [member['id'] for member in meta['documents']]
        stamped.append(entry)

    clusters_document_store.write_documents(stamped)

In [None]:
from tqdm.autonotebook import tqdm

# Number of documents sent through the clustering pipeline per run.
BATCH_SIZE = 1000


def _cluster_batch(batch, cluster_index):
    """Run the clustering pipeline on one batch and record document ids per cluster id.

    Keeping this in a function avoids duplicating the flush logic and keeps the
    intermediate results out of the notebook's global namespace (the inline
    version overwrote the earlier `documents` variable).
    """
    results = transformer_cluster_pipeline.run(documents=batch)
    for clustered_document, cluster_id in zip(results['documents'], results['clusters']):
        cluster_index[cluster_id].append(clustered_document.id)


# cluster id -> ids of all documents assigned to it, accumulated over all batches.
# NOTE(review): the ids come from an estimator that is partial_fit batch by batch;
# Birch sub-cluster indices may shift as the tree grows, so an id seen in an early
# batch might not denote the same cluster later — verify before relying on it.
global_clusters = defaultdict(list)
batch = []
for document in tqdm(paragraph_document_store.get_all_documents_generator()):
    batch.append(document)
    if len(batch) >= BATCH_SIZE:
        _cluster_batch(batch, global_clusters)
        batch = []

# Flush the final, partially filled batch.
if batch:
    _cluster_batch(batch, global_clusters)
    batch = []

In [None]:
# Re-key the mapping so clusters appear from largest to smallest member count
# (ties keep their previous relative order, since the sort is stable).
global_clusters = {
    cluster_id: member_ids
    for cluster_id, member_ids in sorted(
        global_clusters.items(),
        key=lambda entry: len(entry[1]),
        reverse=True,
    )
}

In [None]:
# Summary shown as the cell output:
#   (clusters with more than one member,
#    documents inside those clusters,
#    documents overall)
(
    sum(1 for member_ids in global_clusters.values() if len(member_ids) > 1),
    sum(len(member_ids) for member_ids in global_clusters.values() if len(member_ids) > 1),
    sum(len(member_ids) for member_ids in global_clusters.values()),
)