## Data Preprocessing - Sentence Token Reduction

In [5]:
# Colab setup: mount Google Drive; project files are then addressed as
# f"{DRIVE_HOME}{CODE_HOME}/...".
from google.colab import drive

DRIVE_HOME = '/content/drive'       # Drive mount point inside the Colab VM
CODE_HOME = '/MyDrive/LawDigestAI'  # project root, relative to DRIVE_HOME

drive.mount(DRIVE_HOME)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Pipeline
- Load the input file, whose `sentences` text exceeds the T5 token limit (512).
- Using a sentence-transformers model, embed each row's catchphrases and its individual sentences.
- Compute the cosine similarity between every catchphrase embedding and every sentence embedding.
- Score each sentence by its highest similarity to any catchphrase, then sort the sentences by score in decreasing order.
- Keep sentences in that order until the running token count would exceed the budget (500 tokens in code, below the 512 limit), and drop the rest. The kept sentences are joined in score order, not original document order.

In [6]:

import os
import time

import nltk
import pandas as pd
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
from transformers import T5Tokenizer

nltk.download('punkt_tab')

# Model identifiers, kept as constants so they are easy to swap.
TOKENIZER_MODEL = 't5-small'
EMBEDDING_MODEL = 'all-MiniLM-L6-v2'

# T5 tokenizer: in this section it is used to count tokens against the budget.
tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_MODEL)
# Compact sentence-embedding model used to score sentences against catchphrases.
model = SentenceTransformer(EMBEDDING_MODEL)

# Function to process each row
def select_sentences(row, max_token_limit=500):
    """Shrink a document to the sentences most similar to its catchphrases.

    Splits ``row['sentences']`` into sentences with NLTK, embeds them and the
    row's catchphrases with the module-level ``model``, and scores each
    sentence by its highest cosine similarity to any catchphrase. It then
    keeps the best-scoring sentences until adding the next one would push the
    running count past ``max_token_limit`` T5 tokens (counted with the
    module-level ``tokenizer``).

    Args:
        row: A row (e.g. what ``DataFrame.apply(axis=1)`` passes) with
            ``'sentences'`` and ``'catchphrases'`` entries.
        max_token_limit: Token budget for the selected text. Defaults to 500,
            below the 512-token limit described for the pipeline.

    Returns:
        The selected sentences joined with single spaces. They appear in
        decreasing similarity order, not original document order. Returns
        ``""`` when ``row['sentences']`` is not a string (e.g. a NaN from
        ``read_csv``) or yields no sentences.
    """
    text = row['sentences']
    # Missing CSV cells arrive as NaN floats, which sent_tokenize cannot handle.
    if not isinstance(text, str):
        return ""
    # NLTK Punkt sentence splitter (not a simple split on periods/newlines).
    sentences = sent_tokenize(text)
    if not sentences:
        return ""  # nothing to score

    # NOTE(review): read_csv yields this column as a plain string, so the
    # whole field is encoded as ONE embedding. If it holds several
    # catchphrases (e.g. a stringified list), parse it into a list of strings
    # first so each phrase gets its own embedding -- TODO confirm the format.
    catchphrases = row['catchphrases']

    sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
    phrase_embeddings = model.encode(catchphrases, convert_to_tensor=True)

    # Rows = catchphrases, columns = sentences.
    similarities = util.cos_sim(phrase_embeddings, sentence_embeddings)

    # Best match per sentence across all catchphrases, computed in one tensor
    # op instead of a Python-level max() over each column.
    best_scores = similarities.max(dim=0).values.tolist()
    # sorted() is stable, so equal scores keep their original sentence order.
    ranked = sorted(zip(sentences, best_scores), key=lambda pair: pair[1], reverse=True)

    selected_sentences = []
    selected_tokens = 0
    added_sentences = set()  # skip exact duplicate sentences
    for sentence, _score in ranked:
        if sentence in added_sentences:
            continue

        # The T5 tokenizer appends </s> by default, so summing per-sentence
        # counts slightly overestimates the joined text's length (a
        # conservative budget). truncation=True caps one count at the
        # tokenizer's model_max_length.
        token_count = len(tokenizer(sentence, truncation=True)["input_ids"])

        # Stop at the first sentence that does not fit; every lower-ranked
        # sentence is dropped.
        if selected_tokens + token_count > max_token_limit:
            break

        selected_sentences.append(sentence)
        selected_tokens += token_count
        added_sentences.add(sentence)

    return " ".join(selected_sentences)



def track_time(df, func, axis=1, desc="Processing rows"):
    """Apply ``func`` to ``df`` with a tqdm progress bar and report elapsed time.

    Args:
        df: The DataFrame to process.
        func: Callable forwarded to ``df.progress_apply``.
        axis: Axis forwarded to ``progress_apply`` (1 = call ``func`` per row).
        desc: Label shown on the progress bar.

    Returns:
        Whatever ``df.progress_apply(func, axis=axis)`` returns.
    """
    # Registers progress_apply on pandas objects, using this bar label.
    tqdm.pandas(desc=desc)

    # perf_counter is monotonic, unlike time.time(), so the measured duration
    # cannot jump or go negative if the system clock is adjusted mid-run.
    start_time = time.perf_counter()
    result = df.progress_apply(func, axis=axis)
    total_time = time.perf_counter() - start_time

    print(f"Processing completed in {total_time:.2f} seconds.")
    return result

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [7]:
# Set to an int (e.g. 10) for a quick smoke test; None processes every row.
SAMPLE_SIZE = None

input_file_path = f"{DRIVE_HOME}{CODE_HOME}/2_Generation/filtered_data/filtered_summ_data.csv"
output_file_path = f"{DRIVE_HOME}{CODE_HOME}/2_Generation/catchphrase_Extraction/preprocessed_summ_data.csv"

summ_data = pd.read_csv(input_file_path)
# .copy() so the column added below never writes into summ_data.
summ_data_samp = (summ_data if SAMPLE_SIZE is None else summ_data.head(SAMPLE_SIZE)).copy()

# Run Pipeline
summ_data_samp['selected_sentences'] = track_time(summ_data_samp, select_sentences)

# DataFrame.to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
summ_data_samp.to_csv(output_file_path, index=False)
print(summ_data_samp.shape, summ_data_samp.columns)

Processing rows: 100%|██████████| 2871/2871 [05:01<00:00,  9.51it/s]


Processing completed in 301.83 seconds.
(2871, 9) Index(['filename', 'name', 'AustLII', 'catchphrases', 'sentences',
       'word_count', 'num_catchphrases', 'total_tokens', 'selected_sentences'],
      dtype='object')
