In [1]:
# download the 20 Newsgroups data set from scikit learn
from sklearn.datasets import fetch_20newsgroups
newsgroup_data = fetch_20newsgroups(subset="all")

In [2]:
# help the interpreter find the fuzzy_artmap module 
import os, sys
dir2 = os.path.abspath('')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path: 
    sys.path.append(dir1)

In [3]:
from collections import Counter
from dataclasses import dataclass
import datetime
import random

from scipy.sparse import csr_matrix
from sklearn.feature_extraction.text import TfidfVectorizer

from fuzzy_artmap import FuzzyArtMap

@dataclass
class ProcessedCorpus:
    vectorized_corpus: csr_matrix
    document_corpus_map: dict[int: int]
    categories: dict[int: list[str]]


def get_tf_idf_twenty_newsgroup_corpus() -> ProcessedCorpus:
    """Helper function to vectorize the 20 Newsgroup corpus to TF-IDF features, and associate the vectorized documents with their categories"""
    twenty_newsgroup_vectorizer = TfidfVectorizer(input="content", encoding="latin1", stop_words='english', min_df=0.001, max_df=0.9)
    twenty_newsgroup_vectorized_corpus = twenty_newsgroup_vectorizer.fit_transform(newsgroup_data.data)
    categories = list(newsgroup_data.target_names)
    twenty_newsgroup_categories = {}
    for document_index, category_index in enumerate(newsgroup_data.target):
        twenty_newsgroup_categories[document_index]=categories[category_index]

    return ProcessedCorpus(vectorized_corpus = twenty_newsgroup_vectorized_corpus,
                           document_corpus_map = {index: index for index in range(twenty_newsgroup_vectorized_corpus.shape[0])},
                           categories = twenty_newsgroup_categories)

In [4]:
import torch
# Setup the valid (matching) & invalid (not matching) categories, complement encoded
valid_vector = torch.tensor([[1.0, 0.0]])
invalid_vector = torch.tensor([[0.0, 1.0]])

In [5]:
def get_test_input_and_output(doc_index, vector, categories, relevant_category):
    """Helper function to get the complement encoded input, and encoded label"""
    if relevant_category == categories[doc_index]:
        output_value = valid_vector
    else:
        output_value = invalid_vector
    
    complement_encoded_input = FuzzyArtMap.complement_encode(torch.from_numpy(vector.toarray()))
    return complement_encoded_input, output_value

In [6]:
def test_predictions(fuzzy_artmap, document_indexes, corpus, categories, relevant_category, document_corpus_index_map):
    """Count the True Positives (TP), True Negatives (TN), False Positives (FP), and False Negatives (FN) for predictions made by the model"""
    accuracy_counter = Counter({"TP": 0, "TN": 0, "FP": 0, "FN": 0})
    for corpus_index in document_indexes[100:]:  # Skip the first 100 documents used for training
        document_index = document_corpus_index_map[corpus_index]
        input_vector, class_vector = get_test_input_and_output(document_index, corpus[corpus_index], categories, relevant_category)
        prediction = fuzzy_artmap.predict(input_vector)
        if class_vector[0][0].item():
            if prediction[0][0][0].item():
                update = {"TP": 1}
            else:
                update = {"FN": 1}
        else:
            if prediction[0][0][0].item():
                update = {"FP": 1}
            else:
                update = {"TN": 1}
        accuracy_counter.update(update)
    print(accuracy_counter)
    return accuracy_counter

In [7]:
def calculate_metrics(accuracy_data, duration, number_of_relevant_documents):
    """Calculate accuracy, precision, recall, and speed metrics given the accuracy data"""
    total_documents_tested = sum(accuracy_data.values())
    accuracy = (accuracy_data["TP"] + accuracy_data["TN"]) / total_documents_tested
    precision = accuracy_data["TP"] / (accuracy_data["TP"] + accuracy_data["FP"])
    recall = accuracy_data["TP"] / (accuracy_data["TP"] + accuracy_data["FN"])
    recall_set = accuracy_data["TP"] / number_of_relevant_documents
    rate = total_documents_tested / duration.seconds
    print(f"accuracy: {accuracy}\nprecision: {precision}\nrecall: {recall}\nrecall (set): {recall_set}\ntotal relevant docs: {number_of_relevant_documents}\ntotal docs:{total_documents_tested}\nprediction rate (docs/second):{rate}")


In [8]:
def setup_twenty_newsgroup_corpus():
    global relevant_category
    relevant_category = "alt.atheism" 
    seed_indexes = [4000, 4001]
    processed_corpus = get_tf_idf_twenty_newsgroup_corpus()
    categories = {index: category for index, category in processed_corpus.categories.items() if index not in seed_indexes }
    categories[4000] = "alt.atheism"
    categories[4001] = "alt.atheism"
    shuffled_document_indexes = seed_indexes + random.sample(list(categories.keys()), len(categories))
    return processed_corpus.vectorized_corpus, processed_corpus.categories, shuffled_document_indexes, processed_corpus.document_corpus_map

In [9]:
processed_document_indexes = set()

def train_model(corpus, shuffled_document_indexes, categories, relevant_category, document_corpus_index_map):
    fuzzy_artmap = FuzzyArtMap(number_of_category_nodes=36, baseline_vigilance=0.95)
    training_split = Counter()
    for iteration_count, corpus_index in enumerate(shuffled_document_indexes[:100]):
        document_index = document_corpus_index_map[corpus_index]
        # print(f"{iteration_count} - {categories[document_index]}")
        training_split.update({''.join(categories[document_index]):1})
        input_vector, class_vector = get_test_input_and_output(document_index, corpus[corpus_index], categories, relevant_category)
        fuzzy_artmap.fit(input_vector, class_vector)
    processed_document_indexes.update(shuffled_document_indexes[:100])
    return fuzzy_artmap, training_split

In [None]:
# This cell trains a basic model using naive (random sample) one-shot offline training using 100 documents, 
# and then runs the prediction, calculating the performance metrics

relevant_category = "alt.atheism"

corpus, categories, shuffled_document_indexes, document_corpus_index_map = setup_twenty_newsgroup_corpus()
fuzzy_artmap, training_split = train_model(corpus, shuffled_document_indexes, categories, relevant_category, document_corpus_index_map)

start_predictions = datetime.datetime.now()
print(f"start predictions: {start_predictions}")
accuracy_data = test_predictions(fuzzy_artmap, shuffled_document_indexes, corpus, categories, relevant_category, document_corpus_index_map)

end_predictions = datetime.datetime.now()
prediction_duration = end_predictions-start_predictions
print(f"end predictions: {end_predictions} - elapsed: {prediction_duration}")

number_of_relevant_documents = len(list([i for i in shuffled_document_indexes[100:] if relevant_category in categories[document_corpus_index_map[i]]]))
calculate_metrics(accuracy_data, prediction_duration, number_of_relevant_documents)

start predictions: 2024-12-09 14:23:42.510335
Counter({'TN': 17951, 'FN': 796, 'TP': 1, 'FP': 0})
end predictions: 2024-12-09 14:24:12.170432 - elapsed: 0:00:29.660097
accuracy: 0.957542137828035
precision: 1.0
recall: 0.0012547051442910915
recall (set): 0.0012547051442910915
total relevant docs: 797
total docs:18748
prediction rate (docs/second):646.4827586206897


In [11]:
def query(fuzzy_artmap, corpus, categories, available_document_indexes, document_corpus_index_map):
    """Gets the predictions for the remaining unevaluated documents in the corpus"""
    working_indexes = list(available_document_indexes)

    predictions = []
    for corpus_index in working_indexes:
        document_index = document_corpus_index_map[corpus_index]
        input_vector, class_vector = get_test_input_and_output(document_index, corpus[corpus_index], categories, relevant_category)
        prediction, membership_degree = fuzzy_artmap.predict_with_membership(input_vector)
        if prediction[0][0].item():
            predictions.append((membership_degree, corpus_index, class_vector, input_vector))
    predictions.sort(key=lambda p: p[0], reverse=True)
    return predictions

In [18]:
def run_active_learning_test(setup_corpus):
    """Uses an active learning approach to query the 20 Newsgroups corpus for the specified category. The corpus is ranked, the top 100 (`batch_size`) 
    documents are evaluated against their ground truth label, and the model is updated after every judgement. The evaluated documents are removed from the
    available (unevaluated) documents in the corpus. Batch-level metrics are reported after each iteration. Evaluation stops when no more relevant documents
    are predicted in the remaining unevaluated documents."""
    print(f"start: {datetime.datetime.now()}")
    corpus, categories, shuffled_document_indexes, document_corpus_index_map = setup_corpus()
    available_document_indexes = set(shuffled_document_indexes[100:])
    number_of_relevant_documents = len(list([i for i in shuffled_document_indexes[100:] if relevant_category in categories[document_corpus_index_map[i]]]))

    print(f"start training: {datetime.datetime.now()}")    
    fuzzy_artmap, _ = train_model(corpus, shuffled_document_indexes, categories, relevant_category, document_corpus_index_map)
    
    found_relevant_documents = 0
    active_learning_iteration = 0
    has_candidates = True
    start_predictions = datetime.datetime.now()
    print(f"start active learning: {start_predictions}")
    batch_size = 100
    evaluated_document_count = 0
    while found_relevant_documents != number_of_relevant_documents and has_candidates:        
        relevant_documents_in_batch = 0
        candidates = query(fuzzy_artmap, corpus, categories, available_document_indexes, document_corpus_index_map)
        candidate_batch_size = 0
        for candidate in candidates[:batch_size]:
            # print(f"{datetime.datetime.now()} - training")
            evaluated_document_count += 1
            candidate_batch_size +=1
            fuzzy_artmap.fit(candidate[3], candidate[2])
            available_document_indexes.remove(candidate[1]) 
            if candidate[2][0,][0]:
                found_relevant_documents += 1
                relevant_documents_in_batch += 1

        if len(candidates) == 0:
            has_candidates = False
        active_learning_iteration += 1
        batch_recall = 0
        if has_candidates:
            batch_recall = relevant_documents_in_batch/candidate_batch_size
        print(f"{datetime.datetime.now()} - {active_learning_iteration} - {found_relevant_documents}/{number_of_relevant_documents} | batch recall: {batch_recall:.4f} | recall - {(found_relevant_documents/number_of_relevant_documents):.4f} precision - {(found_relevant_documents/evaluated_document_count):.4f} | {len(available_document_indexes)}")
    
    end_predictions = datetime.datetime.now()
    prediction_duration = end_predictions-start_predictions
    print(f"end active learning: {end_predictions} - elapsed: {prediction_duration}")
    print(f"number of Fuzzy ARTMAP Categories: {fuzzy_artmap.get_weight_a().shape[0]}")

In [19]:
# Warning, this can take a while to complete 5-15 minutes
run_active_learning_test(setup_twenty_newsgroup_corpus)

start: 2024-12-09 14:57:48.334254
start training: 2024-12-09 14:57:49.748282
start active learning: 2024-12-09 14:57:49.987630
2024-12-09 14:58:25.682251 - 1 - 15/793 | batch recall: 0.6000 | recall - 0.0189 precision - 0.6000 | 18723
2024-12-09 14:58:58.732468 - 2 - 88/793 | batch recall: 0.7300 | recall - 0.1110 precision - 0.7040 | 18623
2024-12-09 14:59:31.253182 - 3 - 173/793 | batch recall: 0.8500 | recall - 0.2182 precision - 0.7689 | 18523
2024-12-09 14:59:57.965199 - 4 - 255/793 | batch recall: 0.8200 | recall - 0.3216 precision - 0.7846 | 18423
2024-12-09 15:00:20.472799 - 5 - 337/793 | batch recall: 0.8200 | recall - 0.4250 precision - 0.7929 | 18323
2024-12-09 15:00:41.930662 - 6 - 424/793 | batch recall: 0.8700 | recall - 0.5347 precision - 0.8076 | 18223
2024-12-09 15:01:03.344509 - 7 - 498/793 | batch recall: 0.7400 | recall - 0.6280 precision - 0.7968 | 18123
2024-12-09 15:01:23.638760 - 8 - 558/793 | batch recall: 0.6000 | recall - 0.7037 precision - 0.7697 | 18023
202