In [41]:
import pandas as pd
from numpy import dot
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer

In [42]:
# Load the base dataset. The cells below read its 'Passage' and 'Statement'
# columns and overwrite 'Passage' row by row.
df = pd.read_csv('data/tfn_base.csv')

In [43]:
def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Args:
        a: First vector (anything ``numpy.dot`` / ``numpy.linalg.norm`` accept).
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (norm(a) * norm(b))``. When either vector has zero
        norm the cosine is undefined; ``0.0`` is returned instead of NaN so
        that callers can still sort and compare the scores.
    """
    denominator = norm(a) * norm(b)
    if denominator == 0:
        # 0/0 would yield NaN (plus a RuntimeWarning), and NaN breaks any
        # ordering the caller builds from these scores.
        return 0.0
    return dot(a, b) / denominator

In [44]:
def split_text_into_chunks(tokenizer, text, max_tokens=100):
    """Greedily pack the sentences of ``text`` into chunks of limited size.

    The text is split on the literal separator ``'. '``. Sentences are
    appended to the current chunk (re-joined with ``'. '``) for as long as
    the chunk's token count stays within ``max_tokens``; the first sentence
    that would overflow starts a new chunk.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a sized sequence of tokens.
        text: The text to split.
        max_tokens: Maximum number of tokens allowed per chunk.

    Returns:
        A list of non-empty chunk strings in their original order. A single
        sentence longer than ``max_tokens`` is not split further and becomes
        a chunk of its own, so such a chunk can exceed the limit.
    """
    chunks = []
    current_chunk = ""

    for sentence in text.split('. '):
        # Tentatively extend the running chunk with this sentence.
        candidate = f"{current_chunk}. {sentence}" if current_chunk else sentence
        candidate_len = len(tokenizer.encode(candidate, add_special_tokens=False))

        if candidate_len <= max_tokens:
            current_chunk = candidate
            continue

        # The sentence does not fit: flush what has been collected so far.
        # The running chunk is still empty when the very first sentence is
        # itself over the limit; skip the flush then so that no empty chunk
        # is emitted.
        if current_chunk:
            chunks.append(current_chunk)
        current_chunk = sentence

    # Flush the trailing chunk, if any.
    if current_chunk:
        chunks.append(current_chunk)

    return chunks

In [45]:
# Tokenizer used in this notebook to count tokens when sizing chunks and
# rewritten passages.
# NOTE(review): counts come from the flan-t5 tokenizer, not from the embedding
# model's own tokenizer; presumably the passages are later fed to a T5 model.
# Confirm.
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
# Sentence-embedding model used to encode statements and passage chunks for
# the cosine-similarity ranking.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [46]:
# Rewrite every row's 'Passage' so that it only keeps the chunks most similar
# to the row's 'Statement', within a fixed token budget.

# Token budget for a rewritten passage, measured with `tokenizer`.
MAX_PASSAGE_TOKENS = 412

curr_passage = ""
chunks = []
chunk_embeddings = []    # embedding of chunks[i], computed once per passage
chunk_token_counts = []  # token count of chunks[i], computed once per passage
longest_token_cnt = 0
for row_idx, row in df.iterrows():
    passage = row['Passage']
    statement = row['Statement']

    # A passage of 'a' is a sentinel: the row reuses the most recent real
    # passage. Chunking, embedding and token counting depend only on the
    # passage, so they are redone only when a new passage appears rather than
    # once per statement.
    if passage.strip() != 'a':
        curr_passage = passage
        chunks = split_text_into_chunks(tokenizer, curr_passage)
        chunk_embeddings = [model.encode(chunk) for chunk in chunks]
        chunk_token_counts = [
            len(tokenizer.encode(chunk, add_special_tokens=False))
            for chunk in chunks
        ]

    if not chunks:
        # Nothing to rank (empty passage, or a sentinel row before any real
        # passage). Leave the row untouched instead of failing on an empty
        # ranking below.
        continue

    encoded_statement = model.encode(statement)

    # Rank chunks from most to least similar to the statement. Ties on the
    # score are broken by the chunk index (higher index first).
    similarities = sorted(
        (
            (cosine_similarity(encoded_statement, embedding), idx)
            for idx, embedding in enumerate(chunk_embeddings)
        ),
        reverse=True,
    )

    # Take chunks in ranking order until the next one would exceed the budget.
    new_passage = []
    token_cnt = 0
    for _, idx in similarities:
        chunk_len = chunk_token_counts[idx]
        if token_cnt + chunk_len > MAX_PASSAGE_TOKENS:
            break

        token_cnt += chunk_len
        new_passage.append(chunks[idx])

    if not new_passage:
        # The best chunk alone is over budget: keep it anyway so the passage
        # is never empty, and count its tokens so the reported maximum
        # reflects what was actually written.
        best_idx = similarities[0][1]
        new_passage.append(chunks[best_idx])
        token_cnt = chunk_token_counts[best_idx]

    # token_cnt is the sum of per-chunk counts; the '\n' separators added by
    # the join below are not included.
    longest_token_cnt = max(longest_token_cnt, token_cnt)

    df.at[row_idx, 'Passage'] = '\n'.join(new_passage)

print("Longest token count for a Passage", longest_token_cnt)

Longest token count for a Passage 412


In [47]:
# Save the dataframe to a new CSV file; index=False leaves the row index out
# of the output.
df.to_csv('data/tfn_data_v2.csv', index=False)