In [None]:
import nltk
from sentence_transformers import SentenceTransformer, util
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
from kneed import KneeLocator

# Download the NLTK 'punkt' tokenizer data (first-time setup only; uncomment the line below to run it)
# nltk.download('punkt')

# Step 1: Read the text file and extract sentences
def read_sentences_from_file(file_path):
    """Read a text file and split its contents into sentences.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The list of sentences produced by ``nltk.sent_tokenize`` for the
        file's full text.
    """
    # Decode explicitly as UTF-8. Without ``encoding`` Python falls back to the
    # platform's locale encoding, so the same file could decode differently
    # (or fail to decode) from one machine to another.
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()
    return nltk.sent_tokenize(text)

# Step 2: Calculate embeddings for each sentence
def calculate_embeddings(sentences):
    """Embed sentences with the 'all-MiniLM-L6-v2' SentenceTransformer model.

    Args:
        sentences: The sentences to encode.

    Returns:
        The value returned by ``SentenceTransformer.encode`` for ``sentences``.
    """
    # A fresh model instance is created on every call.
    encoder = SentenceTransformer('all-MiniLM-L6-v2')
    return encoder.encode(sentences)

# Step 3: Determine the optimal number of clusters using Elbow Method and Silhouette Score
def determine_optimal_clusters(embeddings, max_clusters=10):
    """Choose a KMeans cluster count from the elbow and silhouette criteria.

    KMeans is fitted for every candidate k from 2 up to an upper bound. The
    elbow of the inertia curve (via ``KneeLocator``) and the k with the highest
    mean silhouette score are computed, and the larger of the two is returned.

    Args:
        embeddings: Sequence/array of sample vectors; ``len(embeddings)`` is
            taken as the number of samples.
        max_clusters: Largest k to consider (default 10).

    Returns:
        int: The selected number of clusters. Returns 1 when there are too few
        samples (fewer than 3) or ``max_clusters`` is below 2, because no
        candidate k can be scored in that case.

    Raises:
        ValueError: If ``embeddings`` is empty.
    """
    n_samples = len(embeddings)
    if n_samples == 0:
        raise ValueError('Cannot determine a cluster count for an empty set of embeddings.')

    # silhouette_score only accepts 2 <= k <= n_samples - 1, so k == n_samples
    # must never be tried (it raises ValueError).
    upper_k = min(max_clusters, n_samples - 1)
    if upper_k < 2:
        return 1

    candidate_ks = list(range(2, upper_k + 1))
    distortions = []
    silhouette_scores = []
    for n_clusters in candidate_ks:
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)
        distortions.append(kmeans.inertia_)
        silhouette_scores.append(silhouette_score(embeddings, kmeans.labels_))

    # k with the highest silhouette score; on ties the smallest k wins.
    best_position = max(range(len(candidate_ks)), key=silhouette_scores.__getitem__)
    silhouette_k = candidate_ks[best_position]

    # Only look for an elbow when the inertia curve has at least three points;
    # shorter curves cannot bend, so the silhouette choice is used alone.
    elbow_k = None
    if len(candidate_ks) >= 3:
        elbow_k = KneeLocator(
            candidate_ks, distortions, curve='convex', direction='decreasing'
        ).elbow

    # KneeLocator reports None when it finds no elbow; fall back to the
    # silhouette choice instead of failing on max(None, int).
    if elbow_k is None:
        return silhouette_k

    # Choose the best method (you can modify this selection logic)
    return max(int(elbow_k), silhouette_k)

# Step 4: Cluster the sentences based on the optimal number of clusters
def cluster_sentences(embeddings, n_clusters):
    """Fit KMeans on the embeddings and return the fitted labels.

    Args:
        embeddings: The sample vectors to cluster.
        n_clusters: Number of clusters passed to ``KMeans``.

    Returns:
        The ``labels_`` attribute of the fitted ``KMeans`` estimator.
    """
    # random_state is fixed at 0, matching the fits used for model selection.
    estimator = KMeans(n_clusters=n_clusters, random_state=0)
    estimator.fit(embeddings)
    return estimator.labels_

# Step 5: Write the clusters to a text file
def write_clusters_to_file(clustered_sentences, file_path):
    """Write each cluster and its sentences to a text file.

    For every ``cluster_id -> sentences`` entry a ``Cluster <id + 1>:`` header
    is written, followed by one ``  - <sentence>`` line per sentence and a
    blank separator line.

    Args:
        clustered_sentences: Mapping from numeric cluster id to an iterable of
            sentences.
        file_path: Destination path; the file is created or overwritten.
    """
    # Encode explicitly as UTF-8 so sentences with non-ASCII characters are
    # written the same way on every platform instead of depending on the
    # locale's default encoding.
    with open(file_path, 'w', encoding='utf-8') as file:
        # ``members`` (not ``cluster_sentences``) avoids shadowing the
        # module-level function of that name.
        for cluster_id, members in clustered_sentences.items():
            file.write(f'Cluster {cluster_id + 1}:\n')
            file.writelines(f'  - {sentence}\n' for sentence in members)
            file.write('\n')

# Helper: pick corpus indices by Maximal Marginal Relevance (MMR)
def _mmr_select(query_embedding, embeddings, top_n, diversity):
    """Return up to ``top_n`` row indices of ``embeddings`` in MMR order.

    The first pick is the row most similar to the query. Each later pick
    maximises ``(1 - diversity) * sim(row, query) - diversity * max sim(row,
    already picked rows)``, using cosine similarity throughout.
    """
    # Assumes the embeddings are convertible by np.asarray (e.g. the numpy
    # arrays that ``SentenceTransformer.encode`` returns by default).
    doc_vecs = np.asarray(embeddings, dtype=float)
    query_vec = np.asarray(query_embedding, dtype=float).reshape(-1)

    # L2-normalise so that dot products are cosine similarities; zero vectors
    # are left as-is to avoid dividing by zero.
    doc_norms = np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    doc_vecs = doc_vecs / np.where(doc_norms == 0, 1.0, doc_norms)
    query_norm = np.linalg.norm(query_vec)
    if query_norm > 0:
        query_vec = query_vec / query_norm

    n_docs = len(doc_vecs)
    relevance = doc_vecs @ query_vec
    available = np.ones(n_docs, dtype=bool)
    redundancy = np.zeros(n_docs)
    scores = relevance.copy()  # first pick is ranked by relevance alone
    chosen = []
    for _ in range(min(top_n, n_docs)):
        pick = int(np.argmax(np.where(available, scores, -np.inf)))
        chosen.append(pick)
        available[pick] = False

        # Track each row's highest similarity to anything chosen so far.
        sim_to_pick = doc_vecs @ doc_vecs[pick]
        redundancy = sim_to_pick if len(chosen) == 1 else np.maximum(redundancy, sim_to_pick)
        scores = (1.0 - diversity) * relevance - diversity * redundancy
    return chosen


# Function to perform MMR-based retrieval
def mmr_retrieve(sentences, embeddings, query, top_n=3, diversity=0.7):
    """Retrieve the sentences most relevant to ``query``, re-ranked by MMR.

    The query is embedded with the 'all-MiniLM-L6-v2' model and compared with
    the pre-computed sentence embeddings. Results are selected one at a time,
    trading relevance to the query against similarity to sentences already
    selected.

    Args:
        sentences: The corpus sentences; ``sentences[i]`` corresponds to
            ``embeddings[i]``.
        embeddings: Embeddings of ``sentences``.
        query: The query text.
        top_n: Maximum number of sentences to return (default 3).
        diversity: Weight of the redundancy penalty, intended to lie in
            [0, 1]. 0 ranks purely by relevance; larger values favour
            sentences that differ from those already selected (default 0.7).

    Returns:
        list: Up to ``top_n`` sentences, best first. Empty when there are no
        sentences or ``top_n`` is not positive.
    """
    # Nothing to retrieve: return before paying for the model load.
    if len(sentences) == 0 or top_n <= 0:
        return []

    model = SentenceTransformer('all-MiniLM-L6-v2')
    query_embedding = model.encode([query])[0]

    picked = _mmr_select(query_embedding, embeddings, top_n, diversity)
    return [sentences[index] for index in picked]

In [None]:
# Main function
def main(input_file_path, output_file_path, query):
    """Run the pipeline: read, embed, cluster, write clusters, then retrieve.

    Args:
        input_file_path: Text file whose sentences are read and clustered.
        output_file_path: File the grouped sentences are written to.
        query: Query text handed to ``mmr_retrieve``.

    Progress messages and the retrieved sentences are printed to stdout.
    """
    sentences = read_sentences_from_file(input_file_path)
    embeddings = calculate_embeddings(sentences)

    # Pick the cluster count, then assign every sentence to a cluster.
    n_clusters = determine_optimal_clusters(embeddings)
    print(f'Optimal number of clusters: {n_clusters}')
    labels = cluster_sentences(embeddings, n_clusters)

    # Pre-create one bucket per cluster id so every id from 0 to
    # n_clusters - 1 appears in the output.
    grouped = {}
    for bucket_id in range(n_clusters):
        grouped[bucket_id] = []
    for label, sentence in zip(labels, sentences):
        grouped[label].append(sentence)

    # Persist the grouping.
    write_clusters_to_file(grouped, output_file_path)
    print(f'Clusters have been written to {output_file_path}')

    # Query the corpus and show what came back.
    hits = mmr_retrieve(sentences, embeddings, query)
    print(f'Top {len(hits)} results for the query "{query}":')
    for hit in hits:
        print(hit)

# Example usage: run main() with the paths and query defined below.
input_file_path = 'recognized_speech.txt'  # text file read by read_sentences_from_file
output_file_path = 'grouped_paragraphs.txt' # file the clusters are written to
# Query passed to mmr_retrieve.
query = "What are the key points about chicken burger?"
main(input_file_path, output_file_path, query)