# Libraries

In [None]:
import numpy as np
import scipy.sparse as ss
from pathlib import Path
import pickle
import pandas as pd
import os 

In [None]:
from TELF.pre_processing import Vulture
from TELF.pre_processing.Vulture.modules import SimpleCleaner
from TELF.pre_processing.Vulture.modules import LemmatizeCleaner
from TELF.pre_processing.Vulture.modules import RemoveNonEnglishCleaner
from TELF.pre_processing.Vulture.default_stop_words import STOP_WORDS
from TELF.pre_processing.Vulture.default_stop_phrases import STOP_PHRASES

In [None]:
from TELF.factorization.HNMFk import HNMFk

In [None]:
from TELF.pre_processing import Beaver

In [None]:
from TELF.post_processing import ArcticFox

In [None]:
from TELF.helpers.file_system import find_files

In [None]:
from TELF.pre_processing.Squirrel import Squirrel
from TELF.pre_processing.Squirrel.pruners import EmbeddingPruner
from TELF.pre_processing.Squirrel.pruners import LLMPruner

# Load Data

In [None]:
# Read the sample corpus and keep only the first 50 rows (with a fresh
# 0..49 index) so the rest of the example runs quickly.
df = (
    pd.read_csv(os.path.join("..", "..", "data", "sample2.csv"))
    .head(50)
    .reset_index(drop=True)
)
df.info()

# Clean Text

In [None]:
# Cleaning pipeline definition: three cleaner modules listed in the order
# they are meant to be applied.
steps = [
    # Language filter configured with an ASCII ratio and a stop-word ratio
    # (presumably drops documents that do not look like English -- confirm
    # against the Vulture documentation).
    RemoveNonEnglishCleaner(ascii_ratio=0.9, stopwords_ratio=0.25),
    # Rule-based text normalization; `order` lists the individual
    # sub-steps to run, in sequence.
    SimpleCleaner(
        stop_words=STOP_WORDS,
        stop_phrases=STOP_PHRASES,
        order=[
            "standardize_hyphens",
            "isolate_frozen",
            "remove_copyright_statement",
            "remove_stop_phrases",
            "make_lower_case",
            "remove_formulas",
            "normalize",
            "remove_next_line",
            "remove_email",
            "remove_()",
            "remove_[]",
            "remove_special_characters",
            "remove_nonASCII_boundary",
            "remove_nonASCII",
            "remove_tags",
            "remove_stop_words",
            "remove_standalone_numbers",
            "remove_extra_whitespace",
            "min_characters",
        ],
    ),
    # Lemmatization using the 'spacy' backend.
    LemmatizeCleaner("spacy"),
]

In [None]:
# Clean the "abstract" and "title" columns of the DataFrame with Vulture.
# - steps: the custom cleaning pipeline defined in the `steps` list earlier
#   in this notebook. It has to be passed explicitly; otherwise the list is
#   never used and Vulture falls back to its default steps.
# - append_to_original_df / concat_cleaned_cols: keep the original columns
#   and add the cleaned text to the same DataFrame (later cells read the
#   combined result from the `clean_abstract_title` column).
vulture = Vulture(n_jobs=1, verbose=10)
df = vulture.clean_dataframe(
    df=df,
    columns=["abstract", "title"],
    steps=steps,
    append_to_original_df=True,
    concat_cleaned_cols=True,
)

In [None]:
# Display the cleaned, concatenated text column produced by the cleaning step.
df["clean_abstract_title"]

# Build The Vocabulary and the Document-Term Matrix

In [None]:
# Column holding the cleaned text, and the directory for all outputs.
DATA_COLUMN = "clean_abstract_title"
RESULTS = "result_example"

# Words to highlight, each paired with the same weight of 2.
HIGHLIGHT_WORDS = ["analysis", "tensor"]
HIGHLIGHT_WEIGHTS = [2] * len(HIGHLIGHT_WORDS)

beaver = Beaver()
os.makedirs(RESULTS, exist_ok=True)

# Keyword arguments for Beaver.documents_words: build a TF-IDF matrix from
# DATA_COLUMN with the highlighted words/weights, saving under RESULTS.
settings = dict(
    dataset=df,
    target_column=DATA_COLUMN,
    highlighting=HIGHLIGHT_WORDS,
    weights=HIGHLIGHT_WEIGHTS,
    matrix_type="tfidf",
    save_path=RESULTS,
)
X, vocabulary = beaver.documents_words(**settings)

In [None]:
# Flip the matrix returned by Beaver and store it in CSR format, so that
# each column of X corresponds to one document (row of df).
X = X.transpose().tocsr()
X

In [None]:
# Sanity check: one column of X per document in df.
assert len(df) == X.shape[1]

In [None]:
# Show the first ten entries of the vocabulary returned by Beaver.
vocabulary[:10]

In [None]:
# Number of entries in the vocabulary.
len(vocabulary)

# Factorize with HNMFk

In [None]:
# Candidate numbers of clusters (k) to search over: 2, 3, ..., 9
# (np.arange excludes the stop value, so 10 itself is not tried).
Ks = np.arange(2, 10)

# Perturbation settings: number of perturbed copies, number of iterations,
# and the perturbation strength.
perts = 2
iters = 2
eps = 0.025

# NMF initialization method (nonnegative SVD).
init = "nnsvd"

# Where HNMFk results are written; `name` is an alias that is later used as
# the experiment name.
HNMFK_save_path = os.path.join(RESULTS, "example_HNMFK")
name = HNMFK_save_path

# NMFk settings handed to HNMFk. The per-key notes summarize the intent of
# each option as used in this example; see the TELF NMFk documentation for
# the authoritative definitions.
nmfk_params = dict(
    k_search_method="bst_pre",       # strategy used to search over k
    sill_thresh=0.7,                 # silhouette threshold for accepting a k
    H_sill_thresh=0.05,              # silhouette threshold on the H factor
    n_perturbs=perts,                # number of perturbations
    n_iters=iters,                   # number of iterations
    epsilon=eps,                     # perturbation strength
    n_jobs=-1,                       # parallel jobs (-1: presumably all cores)
    init=init,                       # NMF initialization method
    use_gpu=False,                   # run on CPU
    save_path=HNMFK_save_path,       # output directory
    predict_k_method="WH_sill",      # k prediction based on W and H silhouettes
    predict_k=True,                  # predict k automatically
    verbose=False,                   # no NMFk progress output
    nmf_verbose=False,               # no output from the NMF solver
    transpose=False,                 # use X as given, without transposing
    # NOTE(review): the original note said "prune unstable clusters"; TELF's
    # NMFk appears to use `pruned` for dropping all-zero rows/columns of X.
    # Confirm against the library documentation.
    pruned=True,
    nmf_method="nmf_fro_mu",         # NMF solver (Frobenius norm, multiplicative updates)
    calculate_error=False,           # skip reconstruction error
    use_consensus_stopping=0,        # consensus stopping off
    calculate_pac=False,             # skip PAC (proportion of ambiguous clustering)
    consensus_mat=False,             # do not build the consensus matrix
    perturb_type="uniform",          # uniform perturbation noise
    perturb_multiprocessing=False,   # no multiprocessing while perturbing
    perturb_verbose=False,           # no output while perturbing
    simple_plot=True,                # produce the simplified summary plots
)



In [None]:
class CustomSemanticCallback:
    """Callable that rebuilds a Beaver matrix for a subset of documents.

    An instance is handed to HNMFk as ``generate_X_callback``; calling it with
    a set of row positions vectorizes just those rows of the stored DataFrame.
    """

    def __init__(self, 
                 df: pd.DataFrame, 
                 target_column=DATA_COLUMN,
                 options=None,
                 matrix_type="tfidf") -> None:
        """
        Initializes the callback with a DataFrame and matrix generation settings.

        Parameters:
        - df: The full DataFrame containing the text data.
        - target_column: Column name containing the target text to vectorize (default is a global DATA_COLUMN).
        - options: Options dictionary passed to Beaver. When omitted (None), it
          defaults to {'vocabulary': vocabulary} using the global `vocabulary`
          as it is at instantiation time.
        - matrix_type: Type of vectorization matrix (e.g., "tfidf", "count").
        """
        self.df = df
        self.target_column = target_column
        # A None sentinel replaces the former dict default, which was created
        # once at class-definition time and shared by every instance.
        self.options = {'vocabulary': vocabulary} if options is None else options
        self.matrix_type = matrix_type

    def __call__(self, original_indices: np.ndarray):
        """
        Callable interface for dynamically generating a matrix from a subset
        of the DataFrame.

        Parameters:
        - original_indices: Numpy array of positional row indices (used with
          ``.iloc``) selecting the documents of self.df to transform.

        Returns:
        - Tuple of (X, metadata), where:
            - X is the transposed Beaver output in CSR format (same
              orientation as the top-level matrix in this notebook: one
              column per document), or a 1x1 matrix if Beaver failed.
            - metadata is a dict containing either 'vocab' or a 'stop_reason' if failed.
        """
        current_beaver = Beaver()  # Fresh vectorizer for this subset

        # Extract the subset of the DataFrame using the provided positions
        current_df = self.df.iloc[original_indices].copy()

        # Construct parameters for the Beaver vectorizer
        current_beaver_matrix_settings = {
            "dataset": current_df,
            "target_column": self.target_column,
            "options": self.options,
            "highlighting": HIGHLIGHT_WORDS,     # Global list of words to highlight
            "weights": HIGHLIGHT_WEIGHTS,        # Associated weights for highlighting
            "matrix_type": self.matrix_type,     # Type of matrix to construct (e.g., TF-IDF)
            "save_path": None                    # Nothing is written to disk here
        }

        try:
            # Attempt to generate the matrix for this subset
            current_X, vocab = current_beaver.documents_words(**current_beaver_matrix_settings)

            # Transpose so that documents become the columns, matching the
            # orientation of the matrix passed to model.fit
            current_X = current_X.T.tocsr()

        except Exception:
            # Deliberate best-effort: any vectorization failure becomes a stop
            # signal (1x1 matrix + reason) instead of aborting the whole run.
            # `Exception` rather than a bare `except` so that KeyboardInterrupt
            # and SystemExit still propagate.
            csr_matrix = ss.csr_matrix([[1]])
            return csr_matrix, {'stop_reason': "documents_words couldn't make matrix"}

        return current_X, {'vocab': vocab}

In [None]:
# Parameters for initializing and training the HNMFk model
hnmfk_params = {
    "n_nodes": 1,  # Number of nodes setting passed to HNMFk

    # List of NMFk parameter sets; a single set is supplied here
    "nmfk_params": [nmfk_params],

    # Callable that builds the matrix for a subset of documents (called per node)
    "generate_X_callback": CustomSemanticCallback(df=df, options={'vocabulary': vocabulary}),

    "cluster_on": "H",  # Which factor matrix to use for clustering

    # Depth limit of the hierarchy.
    # NOTE(review): confirm in the HNMFk docs how depth=1 is counted
    # (root only vs. root plus one layer of children).
    "depth": 1,

    "sample_thresh": 10,  # Minimum number of samples required to split a node further

    "K2": False,  # Do not force k=2 for sub-clusters

    # Range of K to try for deeper layers (children nodes)
    "Ks_deep_min": 1,
    "Ks_deep_max": 20,
    "Ks_deep_step": 1,

    "experiment_name": name,  # Folder/identifier for saving results and checkpoints
}

# Instantiate the HNMFk model with the above parameters
model = HNMFk(**hnmfk_params)

# Fit the model on matrix X using the specified range of Ks
# - from_checkpoint=False: start fresh instead of resuming saved progress
# - save_checkpoint=True: save progress for recovery or inspection
model.fit(X, Ks, from_checkpoint=False, save_checkpoint=True)

# Traverse and collect all nodes created in the hierarchical model
all_nodes = model.traverse_nodes()
print(len(all_nodes))  # Total number of nodes

# Save the trained model to a pickle file for reuse or inspection.
# The path is built from RESULTS (the directory created earlier) rather than
# a repeated string literal, so changing RESULTS cannot leave this pointing
# at a directory that does not exist.
with open(os.path.join(RESULTS, 'HNMFK_highlight.pkl'), 'wb') as output_file:
    pickle.dump(model, output_file)



In [None]:
# Load the HNMFk model saved by the fit step. HNMFK_save_path is the same
# location that was used as the experiment name when fitting, so the two
# cannot drift apart.
model = HNMFk(experiment_name=HNMFK_save_path)
model.load_model()  # Loads model from the provided experiment_name path

# Initialize ArcticFox pipeline
# - model: the hierarchical clustering model (HNMFk)
# - embedding_model: name of the embedding model
# - clean_cols_name: column in the DataFrame containing the cleaned text input
pipeline = ArcticFox(
    model=model,
    embedding_model="SCINCL",
    clean_cols_name=DATA_COLUMN,
)

# Run the post-processing over the hierarchy.
pipeline.run_full_pipeline(
    vocab=vocabulary,                # Vocabulary built by Beaver
    data_df=df,                      # Dataset used for HNMFk
    label_clusters=False,            # Automatic cluster labeling is turned OFF
    generate_stats=False,            # Cluster-level statistics are turned OFF
    process_parents=True,            # Also process parent nodes
    skip_completed=True,             # Skip nodes that were already processed
)



# Prune with Squirrel

In [None]:
# Locate the cluster CSV written for the root node and load it.
path = os.path.join('result_example', 'example_HNMFK', 'depth_0', 'Root')
start_with = 'cluster_for_k='

# find_files may come back empty (e.g. the earlier cells that write the
# results were not run); fail with a clear message instead of an IndexError.
matching_files = find_files(path =path, start_with=start_with)
if len(matching_files) == 0:
    raise FileNotFoundError(
        f"No file starting with {start_with!r} found under {path!r}; "
        "make sure the earlier cells that write the results have been run."
    )
csv_data_path = matching_files[0]

df = pd.read_csv(csv_data_path)
df.info()

In [None]:
# Settings for the pruning run: output directory, the column holding the
# cluster labels, and the label value used as the reference.
OUTPUT_PRUNING_DIR = Path("example_output")
LABEL_COLUMN = "cluster"
LABEL_VALUE = 7

# Combined text column: title, a single space, then abstract.
df["title_abstract"] = df["title"] + " " + df["abstract"]

In [None]:
# Embedding-based pruner configuration, collected in a dict and unpacked
# into the constructor (same pattern as the Beaver settings).
emb_pruner_settings = {
    "embedding_model": "SCINCL",
    "distance_std_factor": 3.0,
    "overwrite_embeddings": False,
    "use_gpu": True,
    "verbose": True,
}
emb_pruner = EmbeddingPruner(**emb_pruner_settings)

In [None]:
# LLM-based pruner configuration: model name, API endpoint, voting settings
# and sampling temperature.
# A custom prompt builder can also be supplied via `prompt`, typed as
# Callable[[str, Iterable[str]], str]; it is left at its default here.
llm_pruner_settings = {
    "llm_model_name": "llama3.2:latest",
    "llm_api_url": "http://localhost:11434",
    "llm_vote_trials": 4,
    "llm_promote_threshold": 0.75,
    "llm_temperature": 0.7,
    "verbose": True,
}
llm_pruner = LLMPruner(**llm_pruner_settings)

In [None]:
# Render the default prompt for a toy candidate/context pair to see what it
# looks like.
default_prompt_preview = llm_pruner.prompt(
    candidate="Test document",
    contexts=["Similar document"],
)
print(default_prompt_preview)

In [None]:
# Pruners are listed in the order: embedding pruner, then LLM pruner.
pipeline = [emb_pruner, llm_pruner]

# Configure Squirrel on the loaded cluster table and run it by calling the
# instance; the result is kept in final_df.
squirrel_settings = {
    "data_source": df,
    "output_dir": OUTPUT_PRUNING_DIR,
    "label_column": LABEL_COLUMN,
    "reference_label": LABEL_VALUE,
    "pipeline": pipeline,
}
processor = Squirrel(**squirrel_settings)
final_df = processor()