In [2]:
from collections import Counter
from dataclasses import dataclass
import datetime
import random

from scipy.sparse import csr_matrix
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
import torch

In [3]:
# Fetch the 20 Newsgroups data set through scikit-learn's fetch_20newsgroups helper;
# subset="all" requests the whole data set rather than only the train or test portion
newsgroup_data = fetch_20newsgroups(subset="all")

In [4]:
# Make the fuzzy_artmap module importable: it is looked up in the parent of the
# current working directory, so that directory is appended to the module search path
import os
import sys

dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

from fuzzy_artmap import FuzzyArtMap

In [5]:
@dataclass
class ProcessedCorpus:
    """A vectorized corpus bundled with the lookup tables needed to label its documents."""
    # Sparse matrix of vectorized documents, one row per document
    vectorized_corpus: csr_matrix
    # Maps a row index of `vectorized_corpus` to the document index used as a key in `categories`
    document_corpus_map: dict[int, int]
    # Maps a document index to the name of the category that document belongs to
    categories: dict[int, str]


def get_tf_idf_twenty_newsgroup_corpus() -> ProcessedCorpus:
    """Vectorize the 20 Newsgroups documents as TF-IDF features and pair each document with its category name.

    Reads the module-level `newsgroup_data`. In the returned `ProcessedCorpus`,
    `document_corpus_map` is the identity mapping over the rows of the TF-IDF matrix and
    `categories` maps each document's position in the data set to its category name.
    """
    vectorizer = TfidfVectorizer(input="content", encoding="latin1", stop_words='english', min_df=0.001, max_df=0.9)
    tf_idf_matrix = vectorizer.fit_transform(newsgroup_data.data)

    # `target` holds, per document, an index into `target_names`
    category_names = list(newsgroup_data.target_names)
    category_by_document = {
        document_index: category_names[category_index]
        for document_index, category_index in enumerate(newsgroup_data.target)
    }
    identity_map = {row: row for row in range(tf_idf_matrix.shape[0])}

    return ProcessedCorpus(vectorized_corpus=tf_idf_matrix,
                           document_corpus_map=identity_map,
                           categories=category_by_document)

In [6]:
# Label vectors handed to the model: `valid_vector` marks a document that matches the
# selected category, `invalid_vector` one that does not. Each is a single row of two
# values in which the second value is the complement of the first.
valid_vector = torch.tensor([1.0, 0.0]).unsqueeze(0)
invalid_vector = torch.tensor([0.0, 1.0]).unsqueeze(0)

In [7]:
def get_test_input_and_output(doc_index, vector, categories, selected_category):
    """Build the model input and the label for a single document.

    Returns a tuple of the complement encoded, densified `vector` and the label:
    `valid_vector` when the document's category equals `selected_category`,
    `invalid_vector` otherwise.
    """
    is_match = selected_category == categories[doc_index]
    label = valid_vector if is_match else invalid_vector

    # The sparse row is converted to a dense array before being wrapped as a tensor
    dense_features = torch.from_numpy(vector.toarray())
    return FuzzyArtMap.complement_encode(dense_features), label

In [8]:
def test_predictions(fuzzy_artmap, document_indexes, corpus, categories, selected_category, document_corpus_index_map):
    """Tally True Positives (TP), True Negatives (TN), False Positives (FP), and False Negatives (FN) for the model's predictions.

    Every entry of `document_indexes` after the first 100 (which were used for training) is
    predicted and compared with its ground truth label. The resulting `Counter` is printed
    and returned.
    """
    # (document is relevant, model predicted relevant) -> confusion matrix cell
    outcome_names = {
        (True, True): "TP",
        (True, False): "FN",
        (False, True): "FP",
        (False, False): "TN",
    }
    accuracy_counter = Counter({"TP": 0, "TN": 0, "FP": 0, "FN": 0})
    held_out_indexes = document_indexes[100:]
    for corpus_index in held_out_indexes:
        document_index = document_corpus_index_map[corpus_index]
        input_vector, class_vector = get_test_input_and_output(document_index, corpus[corpus_index], categories, selected_category)
        prediction = fuzzy_artmap.predict(input_vector)
        # The first element of the label / prediction is the "relevant" flag
        is_relevant = bool(class_vector[0][0].item())
        predicted_relevant = bool(prediction[0][0][0].item())
        accuracy_counter[outcome_names[(is_relevant, predicted_relevant)]] += 1
    print(accuracy_counter)
    return accuracy_counter

In [9]:
def calculate_metrics(accuracy_data, duration, number_of_relevant_documents):
    """Print accuracy, precision, recall, and prediction-rate metrics for a set of predictions.

    Args:
        accuracy_data: mapping holding the "TP", "TN", "FP" and "FN" counts.
        duration: `datetime.timedelta` covering the time spent predicting.
        number_of_relevant_documents: number of relevant documents in the evaluated set,
            used as the denominator of "recall (set)".

    Returns:
        None; the metrics are only printed.

    A metric whose denominator is zero (e.g. precision when nothing was predicted relevant,
    or the rate for a zero-length duration) is reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    def ratio(numerator, denominator):
        # Undefined ratios are reported as 0.0 rather than aborting the report
        return numerator / denominator if denominator else 0.0

    true_positives = accuracy_data["TP"]
    total_documents_tested = sum(accuracy_data.values())
    accuracy = ratio(true_positives + accuracy_data["TN"], total_documents_tested)
    precision = ratio(true_positives, true_positives + accuracy_data["FP"])
    recall = ratio(true_positives, true_positives + accuracy_data["FN"])
    recall_set = ratio(true_positives, number_of_relevant_documents)
    # total_seconds() includes fractional seconds and whole days; `.seconds` is only the
    # integer seconds component, which skews the rate and is 0 for sub-second durations
    rate = ratio(total_documents_tested, duration.total_seconds())
    print(f"accuracy: {accuracy}\nprecision: {precision}\nrecall: {recall}\nrecall (set): {recall_set}\ntotal relevant docs: {number_of_relevant_documents}\ntotal docs:{total_documents_tested}\nprediction rate (docs/second):{rate}")


In [10]:
def setup_twenty_newsgroup_corpus():
    """Vectorize the 20 Newsgroups corpus and draw a random ordering of its documents.

    Returns a 4-tuple: the TF-IDF matrix, the document index -> category name mapping,
    the document indexes in random order, and the document/corpus index map.
    """
    processed = get_tf_idf_twenty_newsgroup_corpus()
    document_indexes = list(processed.categories)
    # Sampling the whole population without replacement yields a random permutation
    shuffled_document_indexes = random.sample(document_indexes, len(document_indexes))
    return (processed.vectorized_corpus,
            processed.categories,
            shuffled_document_indexes,
            processed.document_corpus_map)

In [11]:

def train_model(corpus, shuffled_document_indexes, categories, selected_category, document_corpus_index_map):
    """Fit a new Fuzzy ARTMAP model on the first 100 entries of `shuffled_document_indexes`.

    Each document is presented to the model once, labelled by whether its category equals
    `selected_category`. Returns a tuple of the fitted model and a `Counter` of how many
    training documents came from each category.
    """
    model = FuzzyArtMap(number_of_category_nodes=36, baseline_vigilance=0.95, committed_node_learning_rate=1.0)
    category_counts = Counter()
    training_indexes = shuffled_document_indexes[:100]
    for corpus_index in training_indexes:
        document_index = document_corpus_index_map[corpus_index]
        category_counts[''.join(categories[document_index])] += 1
        features, label = get_test_input_and_output(document_index, corpus[corpus_index], categories, selected_category)
        model.fit(features, label)

    return model, category_counts

In [None]:
# This cell trains a basic model using naive (random sample) one-shot offline training using 100 documents,
# and then runs the prediction over the remaining documents, calculating the performance metrics

relevant_category = "alt.atheism"

corpus, categories, shuffled_document_indexes, document_corpus_index_map = setup_twenty_newsgroup_corpus()
fuzzy_artmap, training_split = train_model(corpus, shuffled_document_indexes, categories, relevant_category, document_corpus_index_map)

start_predictions = datetime.datetime.now()
print(f"start predictions: {start_predictions}")
accuracy_data = test_predictions(fuzzy_artmap, shuffled_document_indexes, corpus, categories, relevant_category, document_corpus_index_map)

end_predictions = datetime.datetime.now()
prediction_duration = end_predictions - start_predictions
print(f"end predictions: {end_predictions} - elapsed: {prediction_duration}")

# Count the relevant documents among those that were predicted (everything after the 100
# training documents). Each category is a single name, so it is compared with `==`, the
# same test get_test_input_and_output uses to label a document, not substring containment.
number_of_relevant_documents = sum(
    1
    for corpus_index in shuffled_document_indexes[100:]
    if categories[document_corpus_index_map[corpus_index]] == relevant_category
)
calculate_metrics(accuracy_data, prediction_duration, number_of_relevant_documents)

In [12]:
def query(fuzzy_artmap, corpus, categories, available_document_indexes, document_corpus_index_map, selected_category):
    """Rank the unevaluated documents the model predicts as relevant; used in the active learning test.

    Every index in `available_document_indexes` is run through the model. Documents whose
    prediction is truthy in its first element are kept as
    `(membership_degree, corpus_index, class_vector, input_vector)` tuples, and the list is
    returned ordered by membership degree, highest first.
    """
    ranked = []
    # Iterate over a snapshot of the available indexes
    for corpus_index in list(available_document_indexes):
        document_index = document_corpus_index_map[corpus_index]
        input_vector, class_vector = get_test_input_and_output(document_index, corpus[corpus_index], categories, selected_category)
        prediction, membership_degree = fuzzy_artmap.predict_with_membership(input_vector)
        if not prediction[0][0].item():
            continue
        ranked.append((membership_degree, corpus_index, class_vector, input_vector))
    return sorted(ranked, key=lambda entry: entry[0], reverse=True)

In [13]:
def run_active_learning_test(setup_corpus, selected_category, batch_size=100):
    """Use an active learning approach to query the corpus for the specified category.

    The model is first trained on up to 10 relevant and 90 non-relevant documents. Then, on
    each iteration, the unevaluated documents are ranked, the top `batch_size` candidates are
    evaluated against their ground truth label, and the model is updated after every
    judgement. Evaluated documents are removed from the available (unevaluated) documents.
    Batch-level metrics are printed after each iteration. Evaluation stops when every
    relevant document has been found or when no more relevant documents are predicted in the
    remaining unevaluated documents.

    Args:
        setup_corpus: callable returning `(corpus, categories, shuffled_document_indexes,
            document_corpus_index_map)`.
        selected_category: name of the category treated as relevant.
        batch_size: number of top-ranked candidates evaluated per iteration (default 100).

    Raises:
        ValueError: if `batch_size` is less than 1 (no candidate would ever be evaluated,
            so the loop could not make progress).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    print(f"start: {datetime.datetime.now()}")
    corpus, categories, shuffled_document_indexes, document_corpus_index_map = setup_corpus()

    def is_relevant(corpus_index):
        # Exact match, the same test get_test_input_and_output uses to label a document
        # (substring containment would also accept partial category names)
        return categories[document_corpus_index_map[corpus_index]] == selected_category

    # Seed set: the first 10 relevant and the first 90 non-relevant documents in shuffled order
    positive_samples = [index for index in shuffled_document_indexes if is_relevant(index)][:10]
    negative_samples = [index for index in shuffled_document_indexes if not is_relevant(index)][:90]

    training_indexes = set()
    training_indexes.update(positive_samples)
    training_indexes.update(negative_samples)
    available_document_indexes = set(shuffled_document_indexes) - training_indexes
    number_of_relevant_documents = sum(
        1 for index in shuffled_document_indexes
        if index not in training_indexes and is_relevant(index)
    )

    print(f"start training: {datetime.datetime.now()}")
    fuzzy_artmap, _ = train_model(corpus, list(training_indexes), categories, selected_category, document_corpus_index_map)

    found_relevant_documents = 0
    evaluated_document_count = 0
    active_learning_iteration = 0
    has_candidates = True
    start_predictions = datetime.datetime.now()
    print(f"start active learning: {start_predictions}")
    while found_relevant_documents != number_of_relevant_documents and has_candidates:
        candidates = query(fuzzy_artmap, corpus, categories, available_document_indexes, document_corpus_index_map, selected_category)
        batch = candidates[:batch_size]
        relevant_documents_in_batch = 0
        for _membership_degree, corpus_index, class_vector, input_vector in batch:
            # Update the model with the ground truth judgement, then retire the document
            fuzzy_artmap.fit(input_vector, class_vector)
            available_document_indexes.remove(corpus_index)
            if class_vector[0][0]:
                relevant_documents_in_batch += 1
        evaluated_document_count += len(batch)
        found_relevant_documents += relevant_documents_in_batch

        has_candidates = len(candidates) > 0
        active_learning_iteration += 1
        # Printed as "batch recall" to keep the output format stable; the value is the
        # fraction of the evaluated batch that was relevant
        batch_recall = relevant_documents_in_batch / len(batch) if batch else 0
        # The loop is only entered while number_of_relevant_documents != 0
        recall = found_relevant_documents / number_of_relevant_documents
        # Nothing has been evaluated yet when the very first query returns no candidates
        precision = found_relevant_documents / evaluated_document_count if evaluated_document_count else 0.0
        print(f"{datetime.datetime.now()} - {active_learning_iteration} - {found_relevant_documents}/{number_of_relevant_documents} | batch recall: {batch_recall:.4f} | recall - {recall:.4f} precision - {precision:.4f} | {len(available_document_indexes)}")

    end_predictions = datetime.datetime.now()
    prediction_duration = end_predictions - start_predictions
    print(f"end active learning: {end_predictions} - elapsed: {prediction_duration}")
    print(f"number of Fuzzy ARTMAP Categories: {fuzzy_artmap.get_weight_a().shape[0]}")

The following cells recapitulate the results published in Courchaine & Sethi (2022) for the tf-idf vectorization of 20 Newsgroup topics pc-hardware (`comp.sys.ibm.pc.hardware`), med (`sci.med`), and forsale (`misc.forsale`) (listed as `tf-idf-pc-hardware`, `tf-idf-med`, and `tf-idf-forsale` in Table IV - 20Newsgroup: Fuzzy ARTMAP Performance).
In the table below, the Original Recall, Original Precision, and Original F-1 columns are from Courchaine & Sethi (2022), Table IV; the Recall, Precision, and F-1 columns are values generated from this notebook. The higher value in each pair is bolded. The runs in this notebook outperformed the published results on both recall and precision, except for recall on the `forsale` topic, where the two values are equal. However, this is a single instance, and any given run of the notebook may produce higher or lower values. The table is a point comparison indicating comparability between the published results and the implementation in this repository and notebook; it is not a statistical evaluation of the two implementations.

| Topic  | Original Recall   | Original Precision  | Original F-1   | Recall   |Precision   | F-1|
|---|---|---|---|---|---|---|
|pc-hardware   | 0.872  |0.236   |0.372   |**0.875**   |**0.275**   |**0.418**   |
|med   | 0.913  |0.287   |0.436   |**0.980**   |**0.362**   |**0.529**   |
|forsale   |0.907   |0.259   |0.403   |0.907   |**0.347**   |**0.502**   |

For the full reference, see: C. Courchaine and R. J. Sethi, "Fuzzy Law: Towards Creating a Novel Explainable Technology-Assisted Review System for e-Discovery," 2022 IEEE International Conference on Big Data (Big Data), Osaka, Japan, 2022, pp. 1218-1223, doi: 10.1109/BigData55660.2022.10020503.

In [13]:
# Active learning run for the pc-hardware topic.
# Warning: this cell is slow -- it can take 10-25 minutes to complete
run_active_learning_test(setup_twenty_newsgroup_corpus, "comp.sys.ibm.pc.hardware")

start: 2024-12-21 11:22:52.605570
start training: 2024-12-21 11:22:54.624420
start active learning: 2024-12-21 11:22:54.990824
2024-12-21 11:23:32.155137 - 1 - 15/972 | batch recall: 0.5000 | recall - 0.0154 precision - 0.5000 | 18716
2024-12-21 11:24:08.004706 - 2 - 36/972 | batch recall: 0.2917 | recall - 0.0370 precision - 0.3529 | 18644
2024-12-21 11:24:45.034434 - 3 - 79/972 | batch recall: 0.4300 | recall - 0.0813 precision - 0.3911 | 18544
2024-12-21 11:25:20.496275 - 4 - 123/972 | batch recall: 0.4400 | recall - 0.1265 precision - 0.4073 | 18444
2024-12-21 11:25:53.686392 - 5 - 176/972 | batch recall: 0.5300 | recall - 0.1811 precision - 0.4378 | 18344
2024-12-21 11:26:22.382550 - 6 - 238/972 | batch recall: 0.6200 | recall - 0.2449 precision - 0.4741 | 18244
2024-12-21 11:26:50.681439 - 7 - 309/972 | batch recall: 0.7100 | recall - 0.3179 precision - 0.5133 | 18144
2024-12-21 11:27:16.053187 - 8 - 366/972 | batch recall: 0.5700 | recall - 0.3765 precision - 0.5214 | 18044
2024

In [14]:
# Active learning run for the med topic.
# Warning: this cell is slow -- it can take 10-25 minutes to complete
run_active_learning_test(setup_twenty_newsgroup_corpus, "sci.med")

start: 2024-12-21 11:41:44.404124
start training: 2024-12-21 11:41:46.463867
start active learning: 2024-12-21 11:41:46.953340
2024-12-21 11:42:35.521657 - 1 - 72/980 | batch recall: 0.7200 | recall - 0.0735 precision - 0.7200 | 18646
2024-12-21 11:43:20.939770 - 2 - 124/980 | batch recall: 0.5200 | recall - 0.1265 precision - 0.6200 | 18546
2024-12-21 11:44:00.595891 - 3 - 182/980 | batch recall: 0.5800 | recall - 0.1857 precision - 0.6067 | 18446
2024-12-21 11:44:35.896545 - 4 - 253/980 | batch recall: 0.7100 | recall - 0.2582 precision - 0.6325 | 18346
2024-12-21 11:45:09.410249 - 5 - 327/980 | batch recall: 0.7400 | recall - 0.3337 precision - 0.6540 | 18246
2024-12-21 11:45:37.442186 - 6 - 398/980 | batch recall: 0.7100 | recall - 0.4061 precision - 0.6633 | 18146
2024-12-21 11:46:03.985412 - 7 - 478/980 | batch recall: 0.8000 | recall - 0.4878 precision - 0.6829 | 18046
2024-12-21 11:46:31.816882 - 8 - 555/980 | batch recall: 0.7700 | recall - 0.5663 precision - 0.6937 | 17946
20

In [14]:
# Active learning run for the forsale topic.
# Warning: this cell is slow -- it can take 10-20 minutes to complete
run_active_learning_test(setup_twenty_newsgroup_corpus, "misc.forsale")

start: 2024-12-29 10:33:29.757247
start training: 2024-12-29 10:33:31.060313
start active learning: 2024-12-29 10:33:31.284480
2024-12-29 10:34:00.542380 - 1 - 33/965 | batch recall: 0.6600 | recall - 0.0342 precision - 0.6600 | 18696
2024-12-29 10:34:29.108280 - 2 - 105/965 | batch recall: 0.7200 | recall - 0.1088 precision - 0.7000 | 18596
2024-12-29 10:34:57.705391 - 3 - 164/965 | batch recall: 0.5900 | recall - 0.1699 precision - 0.6560 | 18496
2024-12-29 10:35:22.948842 - 4 - 213/965 | batch recall: 0.4900 | recall - 0.2207 precision - 0.6086 | 18396
2024-12-29 10:35:45.841985 - 5 - 253/965 | batch recall: 0.4000 | recall - 0.2622 precision - 0.5622 | 18296
2024-12-29 10:36:08.085838 - 6 - 295/965 | batch recall: 0.4200 | recall - 0.3057 precision - 0.5364 | 18196
2024-12-29 10:36:29.043603 - 7 - 343/965 | batch recall: 0.4800 | recall - 0.3554 precision - 0.5277 | 18096
2024-12-29 10:36:49.896078 - 8 - 388/965 | batch recall: 0.4500 | recall - 0.4021 precision - 0.5173 | 17996
20