# Intersect Clustering

## 0. Preliminaries

### 0.1. Import dependencies

In [2]:
import pandas as pd
import numpy as np

### 0.2. Load data

In [3]:
data = pd.read_csv("../data/ohsumed-clean-disease.csv")

In [4]:
text_col = "text"
label_col = "Category"

In [5]:
data[label_col].value_counts()

Neoplasms                                              6070
Cardiovascular Diseases                                4657
Nervous System Diseases                                2844
Bacterial Infections and Mycoses                       2540
Digestive System Diseases                              2144
Pathological Conditions                                1924
Respiratory Tract Diseases                             1650
Urologic and Male Genital Diseases                     1583
Disorders of Environmental Origin                      1572
Musculoskeletal Diseases                               1376
Immunologic Diseases                                   1308
Nutritional and Metabolic Diseases                     1049
Virus Diseases                                          983
Female Genital Diseases and Pregnancy Complications     902
Skin and Connective Tissue Diseases                     798
Eye Diseases                                            639
Hemic and Lymphatic Diseases            

## 1. Create Words data structures

In [6]:
word_encoding_dict = dict()
label_encoding_dict = dict()

presence_word_dict = dict()
presence_label_dict = dict()

num_words = num_labels = 0
for i, (text, label) in enumerate(zip(data[text_col], data[label_col])):
    # Label encoding
    if label in label_encoding_dict:
        presence_label_dict[label_encoding_dict[label]].add(i)
    else:
        num_labels += 1
        label_encoding_dict[label] = num_labels
        presence_label_dict[num_labels] = {i}
    # Text encoding
    splitted = text.split(" ")
    for word in splitted:
        if word in word_encoding_dict:
            presence_word_dict[word_encoding_dict[word]].add(i)
        else:
            num_words += 1
            word_encoding_dict[word] = num_words
            presence_word_dict[num_words] = {i}

## 2. Predict

In [86]:
matrix_intersection = {}

for i, (text, target) in enumerate(zip(data[text_col], data[label_col])):
    print(f"{i}/{len(data)}")
    splitted = list(set(text.split(" ")))
    # Create clusters
    clusters = []
    for j in range(len(splitted) - 1):
        a = word_encoding_dict[splitted[j]]
        for k in range(j+1, len(splitted)):
            b = word_encoding_dict[splitted[k]]
            min_index = min(a, b)
            max_index = max(a, b)
            
            if min_index in matrix_intersection:
                if max_index in matrix_intersection[min_index]:
                    intersection = matrix_intersection[min_index][max_index]
                else:
                    intersection = len(presence_word_dict[a] & presence_word_dict[b])
                    matrix_intersection[min_index][max_index] = intersection
            else:
                matrix_intersection[min_index] = dict()
                intersection = len(presence_word_dict[a] & presence_word_dict[b])
                matrix_intersection[min_index][max_index] = intersection
            if intersection > 0:
                clusters.append({
                    "a": min_index,
                    "b": max_index,
                    "intersection": intersection
                })
            #if len(intersection) > 0:
            #    clusters.append({
            #        "a": min_index,
            #        "b": max_index,
            #        "intersection": len(intersection)
            #    })
    # Sort the clusters
    sorted_clusters = sorted(clusters, key=lambda x: x['intersection'], reverse=True)
    # Create the partition
    covered_elements = set()
    remaining = len(splitted)
    partition = []
    k = 0.5
    probs = {}
    for cluster in sorted_clusters:
        a = cluster['a']
        b = cluster['b']
        if a not in covered_elements and b not in covered_elements:
            #partition.append(cluster)
            covered_elements.add(a)
            covered_elements.add(b)
            remaining -= 2
            counts = data.iloc[list(presence_word_dict[a] & 
                                    presence_word_dict[b])]['Category'].value_counts()
            # Compute the probabilities for the cluster
            for count in counts.iteritems():
                label = label_encoding_dict[count[0]]
                prob = (k + count[1]) / (2 * k + counts.sum())
                if label in probs:
                    probs[label].append(prob)
                else:
                    probs[label] = [prob]

        if remaining < 1: # There is no one left to add with a intersection
            break
    # Compute the prior probabilities (log-space to prevent precision issues)
    prioris = {}
    for label, prob_list in probs.items():
        p = 0  # Start with 0 for summation in log-space

        # Sum the log-probabilities for the present elements
        for prob in prob_list:
            p += np.log(prob)

        # Account for the clusters where the label is absent
        absent_clusters = len(partition) - len(prob_list)
        if absent_clusters > 0:
            p += absent_clusters * np.log(k / (2 * k))

        prioris[label] = p

    # Compute the global prior counts for each label
    count_labels = data[label_col].value_counts().to_dict()
    count_labels = {label_encoding_dict[label]: count_labels[label] for label in count_labels}

    # Naive Bayes denominator (in log-space for stability)
    log_denominator = None
    for label in count_labels:
        log_value = prioris[label] + np.log(count_labels[label])
        log_denominator = np.logaddexp(log_denominator, log_value) if log_denominator is not None else log_value

    # Determine the label with the maximum posterior probability
    max_prob = -np.inf
    max_label = None
    posterioris = {}
    for label in count_labels:
        log_numerator = prioris[label] + np.log(count_labels[label])
        posterior_prob = np.exp(log_numerator - log_denominator)  # Convert back to probability space
        posterioris[label] = posterior_prob
        if posterior_prob > max_prob:
            max_prob = posterior_prob
            max_label = label

    # Result
    print(f"Predicted Label: {max_label}, Probability: {max_prob} | Real: {label_encoding_dict[target]}")
    if i > 100:
        break

0/34389
Predicted Label: 22, Probability: 0.9745370508103507 | Real: 1
1/34389
Predicted Label: 4, Probability: 0.9951238636412494 | Real: 1
2/34389
Predicted Label: 22, Probability: 0.9999999526941548 | Real: 1
3/34389
Predicted Label: 22, Probability: 0.8419260805609841 | Real: 1
4/34389
Predicted Label: 14, Probability: 0.9134630886035676 | Real: 1
5/34389
Predicted Label: 14, Probability: 0.987387246363313 | Real: 1
6/34389
Predicted Label: 22, Probability: 0.9241297230711591 | Real: 1
7/34389
Predicted Label: 4, Probability: 0.5613257636363793 | Real: 1
8/34389
Predicted Label: 22, Probability: 0.9491182104829627 | Real: 1
9/34389
Predicted Label: 22, Probability: 0.9993455624269243 | Real: 1
10/34389
Predicted Label: 14, Probability: 0.9999997481262317 | Real: 1
11/34389
Predicted Label: 4, Probability: 0.8509332031697635 | Real: 1
12/34389
Predicted Label: 22, Probability: 0.9261462487212238 | Real: 1
13/34389
Predicted Label: 22, Probability: 0.9960840588476098 | Real: 1
14/343