# Sentence Selection Module
1. SentenceBERT (SBERT) generates an embedding for each sentence.
2. UMAP reduces the dimensionality of the embeddings.
3. The reduced embeddings are clustered with a density-based algorithm
(HDBSCAN). Some points are labeled as outliers (not part of any cluster), and
the algorithm can produce an arbitrary number of clusters.
4. Within each cluster, cosine similarity between the (full, unreduced)
sentence embeddings is used to sort the sentences from most central to least
central.
5. The final output is a list of clusters, each a list of topically related
sentences sorted by relevance to the topic (i.e., by centrality).

In [2]:
from sentence_transformers import SentenceTransformer, util
import spacy
import torch
import numpy as np
import hdbscan
import umap

from dev_data import texts as dev_texts

# Module-level models, loaded once when this cell runs and shared as globals:
# `model` is the SentenceTransformer used by get_encodings, and `nlp_spacy` is
# the spaCy pipeline used by get_sentences (only for sentence segmentation).
model = SentenceTransformer('all-MiniLM-L6-v2')
nlp_spacy = spacy.load('en_core_web_lg')

In [3]:
def get_sentences(text):
    """
    Split ``text`` into sentences using the module-level spaCy pipeline.

    Returns a NumPy array containing the string form of every sentence span,
    in document order.
    """
    parsed = nlp_spacy(text)
    return np.array([str(span) for span in parsed.sents])

def get_encodings(sentences, batch_size=32):
    """
    Embed ``sentences`` with the module-level SentenceTransformer ``model``.

    ``batch_size`` is forwarded unchanged to ``model.encode``, and the
    encoder's result is returned as-is.
    """
    embeddings = model.encode(sentences, batch_size=batch_size)
    return embeddings


In [62]:
def simple_rank(encodings, cutoff=None):
    """
    Order embeddings by centrality, most central first.

    A row's score is the sum of its cosine similarities to every row
    (including itself). Returns a tensor of row indices sorted by descending
    score, truncated to the first ``cutoff`` entries (all of them when
    ``cutoff`` is None).
    """
    similarity = util.cos_sim(encodings, encodings)
    centrality = similarity.sum(dim=1)
    order = centrality.argsort(descending=True)
    return order[:cutoff]


def cluster_rank(sentences, encodings, batch_size=32, cluster_size=15, outlier_size=1,
                 *, n_neighbors=None, n_components=2, random_state=None):
    """
    Cluster sentences by topic and rank each cluster from most to least central.

    The embeddings are reduced with UMAP (cosine metric), and the reduced points
    are clustered with HDBSCAN (euclidean metric). The members of each cluster
    are then ordered by ``simple_rank`` applied to their original, unreduced
    embeddings. Points that HDBSCAN labels -1 (outliers) appear in no cluster.

    Args:
        sentences: Sequence of sentences aligned row-for-row with
            ``encodings``. It is passed through ``np.asarray``, so plain
            lists work.
        encodings: Array-like of per-sentence embeddings, one row per sentence.
        batch_size: Not used by this function. It is kept so existing callers
            that pass it keep working.
        cluster_size: HDBSCAN ``min_cluster_size``. It is also the UMAP
            ``n_neighbors`` when ``n_neighbors`` is None.
        outlier_size: HDBSCAN ``min_samples``.
        n_neighbors: UMAP ``n_neighbors``. Defaults to ``cluster_size``, which
            reproduces the original coupling of the two settings.
        n_components: Output dimensionality of the UMAP reduction (default 2).
        random_state: Forwarded to ``umap.UMAP``. Set it for reproducible
            reductions; None keeps UMAP's default behavior.

    Returns:
        ``(cluster_sentences, cluster_indexes)``: two lists with one entry per
        cluster label ``0..max_label``.

        - ``cluster_indexes[k]`` is an integer array of positions into
          ``sentences`` for cluster ``k``, most central first.
        - ``cluster_sentences[k]`` holds the sentences at those positions.

        Both lists are empty when every point is an outlier.
    """
    sentences = np.asarray(sentences)
    encodings = np.asarray(encodings)
    if n_neighbors is None:
        n_neighbors = cluster_size

    # Reduce dimensionality before density-based clustering.
    reducer = umap.UMAP(n_neighbors=n_neighbors, n_components=n_components,
                        metric='cosine', random_state=random_state)
    sm_encodings = reducer.fit_transform(encodings)

    cluster_model = hdbscan.HDBSCAN(min_cluster_size=cluster_size, min_samples=outlier_size,
                                    metric='euclidean')
    cluster_model.fit(sm_encodings)
    cluster_labels = cluster_model.labels_

    cluster_sentences = []
    cluster_indexes = []
    # Cluster ids run 0..max and label -1 (outliers) is skipped. If every point
    # is an outlier, max() is -1 and the loop body never runs.
    for label in range(cluster_labels.max() + 1):
        # Flat 1-D member positions (np.where would return a 1-tuple of arrays).
        member_idx = np.flatnonzero(cluster_labels == label)
        # Centrality is scored on the full embeddings, not the UMAP projection.
        rank_idx = simple_rank(encodings[member_idx])
        # simple_rank returns a torch tensor of positions *within* the cluster;
        # map them back to positions in `sentences`.
        full_idx = member_idx[rank_idx.cpu().numpy()]
        cluster_indexes.append(full_idx)
        cluster_sentences.append(sentences[full_idx])

    return cluster_sentences, cluster_indexes

# Smoke run of the full pipeline on a single dev document. This rebinds the
# module-level `sentences` and `encodings`.
# NOTE(review): dev_texts looks like a nested collection (indexed [group][doc])
# — confirm against dev_data.
sentences = get_sentences(dev_texts[2][6])
encodings = get_encodings(sentences)
# sr: per-cluster sentence arrays; ir: matching position arrays (most central first).
sr, ir = cluster_rank(sentences, encodings)


In [66]:
# Exploratory cell: rerun the UMAP + HDBSCAN steps by hand on another dev
# document, with settings that differ from cluster_rank's defaults:
# n_neighbors=3 (instead of 15) and min_samples=5 (instead of 1).
# The globals defined here are consumed by the grouping/printing cells below.
sentences = get_sentences(dev_texts[3][0])
encodings = get_encodings(sentences)
sm_encodings = umap.UMAP(n_neighbors=3, n_components=2, metric='cosine').fit_transform(encodings)
cluster_model = hdbscan.HDBSCAN(min_cluster_size=15, min_samples=5, metric='euclidean')
cluster_model.fit(sm_encodings)
cluster_labels = cluster_model.labels_  # one label per sentence; -1 marks an outlier
num_clusters = cluster_labels.max()  # highest cluster id; clusters are 0..num_clusters
print(cluster_labels)


[ 1  2  2  1  1  1  1  1  1  1  1  1  1  3  3  3  3  3  3  2  3  2  3  3
  3  3  3  3  3  3  3  3  3  3  3  3  3  3  3  3  3  3  3  2  2 -1  3  2
  2  2  2  2  2  1 -1 -1  1  1  1  1  2  2  0  0  0  3  0  2  0 -1  0  0
 -1 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  2  1  1  1  2  2  2 -1  1
  1  2  2  2  1  2  2 -1 -1 -1  2 -1 -1  3  2  0  3 -1 -1 -1  1  1  1  1
  0  0]


In [67]:
# Bucket the 2-D UMAP points by cluster label for inspection.
# groups[0] holds the outliers (label -1); groups[k + 1] holds cluster k.
groups = [[] for _ in range(num_clusters + 2)]
for label, point in zip(cluster_labels, sm_encodings):
    groups[label + 1].append(tuple(point))

In [68]:
# Dump each bucket as tab-separated x/y pairs under a "----- <label>" header.
# Buckets are stored shifted by one, so counting from -1 recovers the labels.
for label, members in enumerate(groups, start=-1):
    print('-----', label)
    for x, y in members:
        print(f'{x}\t{y}')


----- -1
10.391180038452148	7.555512428283691
6.757736682891846	10.509603500366211
6.564599990844727	10.547747611999512
6.450501918792725	3.9066169261932373
7.142343044281006	3.4039154052734375
6.992650985717773	3.540253162384033
6.846142768859863	10.491952896118164
7.212472915649414	3.417005777359009
7.177783012390137	2.9122517108917236
6.936582088470459	3.1730384826660156
7.089361190795898	3.4692554473876953
6.623500823974609	3.6061384677886963
7.0337677001953125	3.0501561164855957
6.711082935333252	3.3700127601623535
10.311153411865234	7.389620780944824
----- 0
2.7597098350524902	3.66933536529541
3.07487154006958	3.1179378032684326
2.9080698490142822	3.904606580734253
3.036565065383911	4.054200649261475
3.2551238536834717	3.244378089904785
2.6024527549743652	2.3673624992370605
2.062544822692871	2.5509848594665527
3.334275722503662	1.9211127758026123
3.196087121963501	2.011439085006714
2.7283103466033936	2.4556941986083984
2.389570474624634	2.3532121181488037
2.203596591949463	2.2435

In [61]:
# for a, b in list(zip(sentences, sm_encodings)):
#     print(a)
#     print(tuple(b))
#     print('-----')