Hybrid search strategy

- for each existing doc in weaviate
- Get paragraph text
- Use tokenizer to check which CAPITALIZED words are not tokenized (TERM)
- Build a dictionary that counts all TERMs in the doc
- Append the top-5 terms to the doc



In [1]:
import hashlib
import json
import unicodedata
from collections import Counter
from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel
from transformers import DPRContextEncoderTokenizer

from askem.preprocessing import HaystackPreprocessor


[2023-09-15 20:50:47,845] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Input corpus. The glob result is materialized and sorted so the processing
# order is reproducible across runs and the collection can be iterated more
# than once (a bare ``glob`` generator is exhausted after a single pass and
# yields entries in arbitrary order).
TEST_DOCS = sorted(Path("data/covid1000").glob("*.txt"))
# Default output directory used by ``Paragraph.save`` (one JSON file per paragraph).
OUTPUT_DIR = Path("data/hybrid_retrieval/experiment_1")

preprocessor = HaystackPreprocessor()
# NOTE(review): loading this checkpoint through DPRContextEncoderTokenizer
# emits a tokenizer-class mismatch warning (the checkpoint declares
# DPRQuestionEncoderTokenizer) -- confirm this is the intended class.
tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


Helper functions

In [6]:
# Look up the vocabulary id that the tokenizer assigns to its "[UNK]" token.
unk_token_id = tokenizer.convert_tokens_to_ids("[UNK]")

100

In [10]:
# Scratch call: tokenize a pre-split list of words. The return value is not
# stored, so this line has no effect on later code.
tokenizer("hello world sadsdadw".split(), is_split_into_words=True)
# Need to get original words from the unk token locations
# Hard-coded id of the "[UNK]" token.
# NOTE(review): presumably the value produced by the "[UNK]" lookup cell; verify
# it if the tokenizer checkpoint changes.
unk_token_id = 100
def get_unknown_words(token_ids: List[int]) -> List[str]:
    """Return the token string of every id in ``token_ids`` equal to ``unk_token_id``.

    Each matching id is converted back to its token through the module-level
    ``tokenizer``; ids that differ from ``unk_token_id`` are skipped. The
    original words behind those positions are not recovered here.
    """
    unknown_tokens = []
    for candidate in token_ids:
        if candidate == unk_token_id:
            converted = tokenizer.convert_ids_to_tokens([candidate])
            unknown_tokens.append(converted[0])
    return unknown_tokens

In [11]:
# NOTE(review): a raw string is passed where a list of token ids is expected,
# so the helper iterates over single characters; none of them equals the
# integer ``unk_token_id``, which is why the result is empty.
get_unknown_words("hello world sadsdadw")

[]

In [3]:
def strip_punctuation(text: str) -> str:
    """Return ``text`` keeping only alphanumeric and whitespace characters.

    Removed characters are dropped, not replaced, so the pieces around them
    are joined together (``"COVID-19"`` becomes ``"COVID19"``).
    """
    kept = (char for char in text if char.isspace() or char.isalnum())
    return "".join(kept)

def remove_diacritics(text: str) -> str:
    """Return ``text`` with combining marks removed after NFKD normalization.

    NFKD splits accented characters into a base character followed by
    combining marks; only the non-combining characters are kept.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    base_chars = filter(lambda ch: not unicodedata.combining(ch), decomposed)
    return "".join(base_chars)

def get_unknown_words(text, tokenizer, min_length: int = 3, top_k: int = 3) -> Optional[List[str]]:
    """Get words that are not tokenized by a tokenizer.

    The text is stripped of punctuation, lower-cased and stripped of
    diacritics, then encoded with ``tokenizer`` and decoded back. A
    whitespace-separated word of the normalized text counts as
    "non-tokenized" when it does not reappear verbatim among the
    whitespace-separated pieces of the decoded string.

    Args:
        text: Raw text to analyse.
        tokenizer: Object that, when called with a string, returns a mapping
            with an ``"input_ids"`` entry, and that exposes ``decode``.
        min_length: Minimum number of characters for a word to be considered.
        top_k: Maximum number of words to return.

    Returns:
        Up to ``top_k`` non-tokenized words, most frequent first (ties keep
        their order of first appearance in the text), or ``None`` when no
        word qualifies.
    """

    # Preprocess text
    text = strip_punctuation(text).lower()
    text = remove_diacritics(text)

    # Encode and decode the text to see which words come back intact.
    input_ids = tokenizer(text)["input_ids"]
    round_tripped = set(tokenizer.decode(input_ids).split())

    # Count whole-word occurrences. ``str.count`` would also count a word
    # each time it occurs inside a longer word, inflating short terms.
    # Iterating the words in text order also keeps tie-breaking deterministic.
    counts = Counter(
        word
        for word in text.split()
        if len(word) >= min_length and word not in round_tripped
    )

    if not counts:
        return None

    return sorted(counts, key=counts.get, reverse=True)[:top_k]


def get_all_cap_words(text: str, min_length: int = 3, top_k: int = 3) -> Optional[List[str]]:
    """Get all-caps words in a text, sorted by number of occurrences.

    Punctuation and diacritics are removed first; a whitespace-separated word
    qualifies when ``str.isupper()`` is true for it and it has at least
    ``min_length`` characters.

    Args:
        text: Raw text to analyse.
        min_length: Minimum number of characters for a word to be considered.
        top_k: Maximum number of words to return.

    Returns:
        Up to ``top_k`` all-caps words, most frequent first (ties keep their
        order of first appearance in the text), or ``None`` when the text
        contains no qualifying word.
    """

    text = strip_punctuation(text)
    text = remove_diacritics(text)

    # Count whole-word occurrences. ``str.count`` would also count a word
    # each time it occurs inside a longer one (e.g. "ACE" inside "ACE2"),
    # which skews the ranking towards short terms.
    counts = Counter(
        word
        for word in text.split()
        if word.isupper() and len(word) >= min_length
    )

    if not counts:
        return None

    # Return top-k most frequent all caps words
    return sorted(counts, key=counts.get, reverse=True)[:top_k]


How many fields are too many?

## To-Dos

1. What does the distribution of n CAP and n NON-TOKENIZED look like?
2. Same as 1, but at article level.   

In [4]:
class Paragraph(BaseModel):
    """A paragraph of a paper together with the terms extracted for it."""

    # Hex MD5 digest of ``text``; always recomputed in ``__init__``, so any
    # value supplied by the caller is overwritten.
    id: Optional[str] = None
    paper_id: str
    text: str
    non_tokenized_words: Optional[list] = None
    all_cap_words: Optional[list] = None

    # These are from parent article
    article_non_tokenized_words: Optional[list] = None
    article_all_cap_words: Optional[list] = None

    def __init__(self, **data) -> None:
        """Build the model from ``data`` and set ``id`` from the text's MD5 digest."""
        super().__init__(**data)
        digest = hashlib.md5(self.text.encode())
        self.id = digest.hexdigest()

    def save(self, path: Optional[Path] = None) -> None:
        """Write the paragraph as indented JSON.

        When ``path`` is not given the file is ``OUTPUT_DIR / "<id>.json"``.
        Missing parent directories are created.
        """
        target = path or OUTPUT_DIR / f"{self.id}.json"
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w") as handle:
            handle.write(self.json(indent=4))

Process all test documents

In [None]:
# Corpus-wide tallies (term -> count), filled in by ``update_count`` while
# the documents are processed.
CAPITALIZED_TERMS_COUNT = {}
NON_TOKENIZED_TERMS_COUNT = {}

def update_count(d: dict, words: Optional[List[str]]) -> None:
    """Add one to the tally stored in ``d`` for each entry of ``words``.

    ``d`` is modified in place; entries not yet present start at 1. A
    ``None`` or empty ``words`` leaves ``d`` untouched.
    """
    for term in words or ():
        d[term] = d.get(term, 0) + 1


# For every document: split it into paragraphs, extract per-paragraph terms,
# aggregate them per article and corpus-wide, then save one JSON per paragraph.
for i, article in enumerate(TEST_DOCS):
    print(f"Processing {i}/1000: {article}...")

    paragraphs = preprocessor.run(input_file=article, topic="covid", doc_type="paragraph")

    # Keep track of article-level information
    this_outputs = []
    this_article_non_tokenized_words_count = {}
    this_article_all_cap_words_count = {}

    for paragraph in paragraphs:
        text = paragraph["text_content"]
        # ``get_unknown_words(text, tokenizer)`` is the helper that extracts
        # non-tokenized words; no ``get_non_tokenized_words`` is defined in
        # this notebook, so calling that name raised a NameError.
        non_tokenized_words = get_unknown_words(text, tokenizer)
        all_cap_words = get_all_cap_words(text)

        this_outputs.append(
            Paragraph(
                paper_id=paragraph["paper_id"],
                text=text,
                non_tokenized_words=non_tokenized_words,
                all_cap_words=all_cap_words,
            )
        )

        # ``update_count`` ignores ``None``/empty lists, so no guard is needed.
        update_count(this_article_non_tokenized_words_count, non_tokenized_words)
        update_count(NON_TOKENIZED_TERMS_COUNT, non_tokenized_words)

        update_count(this_article_all_cap_words_count, all_cap_words)
        update_count(CAPITALIZED_TERMS_COUNT, all_cap_words)

    # Article-level top 3: terms that appear in the most paragraph-level lists.
    this_article_non_tokenized_words = sorted(
        this_article_non_tokenized_words_count,
        key=this_article_non_tokenized_words_count.get,
        reverse=True,
    )[:3]
    print(this_article_non_tokenized_words)
    this_article_all_cap_words = sorted(
        this_article_all_cap_words_count,
        key=this_article_all_cap_words_count.get,
        reverse=True,
    )[:3]
    print(this_article_all_cap_words)

    # Append article-level information to each paragraph
    for output in this_outputs:
        output.article_non_tokenized_words = this_article_non_tokenized_words
        output.article_all_cap_words = this_article_all_cap_words
        output.save()



# 13m 30s for 1000 docs locally

In [None]:
# Report how many distinct terms each corpus-wide tally holds. The second
# line previously printed the capitalized tally again (copy-paste slip).
print(len(CAPITALIZED_TERMS_COUNT))
print(len(NON_TOKENIZED_TERMS_COUNT))

# Persist both tallies as indented JSON in the current working directory.
with open("capitalized.json", "w") as f:
    f.write(json.dumps(CAPITALIZED_TERMS_COUNT, indent=4))

with open("non_tokenized.json", "w") as f:
    f.write(json.dumps(NON_TOKENIZED_TERMS_COUNT, indent=4))