In [None]:
import torch
import warnings
import spacy

from funct.utils import *
from funct.data_loading import *
from funct.data_processing import *
from funct.data_manipulation import *

from sentence_transformers import SentenceTransformer, CrossEncoder

In [None]:
# Silence library warnings that only add noise to the notebook output.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", message=".*compiled with flash attention.*")
set_manual_seed()

# Batch size forwarded to the retrieval / reranking steps of the pipeline.
batch_size = 64

# Load the medium English spaCy model; if spacy.load raises OSError
# (e.g. the model package is not installed), download it and load again.
try:
    nlp = spacy.load("en_core_web_md")
except OSError:
    nlp = None

if nlp is None:
    print("Downloading spaCy model...")
    from spacy.cli import download
    download("en_core_web_md")
    nlp = spacy.load("en_core_web_md")

#### Run this notebook on a CUDA (GPU) runtime; the next cell falls back to CPU when no GPU is available

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Running on {device}')

# Dense retriever and reranker, both placed on the selected device.
# The spelling `sentence_retriver` is kept because later cells refer to it.
sentence_retriver = SentenceTransformer('BAAI/bge-m3', device=device)
cross_encoder = CrossEncoder('mixedbread-ai/mxbai-rerank-xsmall-v1', device=device)

#### Replace the event list below with the events you want to process

In [None]:
# Event ids "001" ... "018", zero-padded to three digits.
event_no_list_all = [f"{event_idx:03d}" for event_idx in range(1, 19)]

# used for testing
# event_no_list_all = ["018"]

In [None]:
def process_day(day, event_no, event_data, sentence_retriever_list, cross_encoder_list, topic_model_list, log):
    """Run the retrieval pipeline for one day of one event and save the selected rows.

    Stages (each returns early, saving nothing, when it yields no rows):
      1. sanitize the ``text`` column into ``search_text`` and prepare the day's queries;
      2. drop redundant rows with ``filter_redundant_data``;
      3. if more than 100 rows remain, keep rows from query-relevant topics
         (``assign_topics`` + ``select_relevant_topics``); if that step raises,
         fall back to the unfiltered rows;
      4. BM25 + cross-encoder selection (``top=200``);
      5. dense retrieval + cross-encoder selection on the BM25 output;
      6. sort by ``cross-score`` (descending), save CSV/JSON under ``./output/``
         and log statistics.

    Args:
        day: key into ``event_data`` for the day to process.
        event_no: event identifier, used in the output file names.
        event_data: mapping ``day -> {'data': ..., 'query': ...}``; ``'data'`` is
            DataFrame-like (presumably pandas -- TODO confirm) with a ``text`` column.
        sentence_retriever_list: sentence-embedding model (the notebook passes a
            single model, not a list).
        cross_encoder_list: cross-encoder reranker (a single model, not a list).
        topic_model_list: topic model (a single model, not a list).
        log: if True, print progress / diagnostic messages.

    Returns:
        None. Results are written to disk as a side effect.

    Note:
        Reads the notebook globals ``batch_size`` and ``device``, and mutates
        ``event_data[day]['data']`` in place by adding a ``search_text`` column.
    """
    # The parameter names end in `_list`, but each holds a single model object.
    sentence_retriver = sentence_retriever_list
    cross_encoder = cross_encoder_list
    topic_model = topic_model_list

    # Load the raw rows for this day and build the sanitized text used for search.
    dataAsDF = event_data[day]['data']
    dataAsDF['search_text'] = dataAsDF['text'].apply(sanitize_data)

    # Prepare the day's queries for retrieval.
    queryAsDF = prepare_queries(queryAsDF=event_data[day]['query'])

    # Drop redundant rows, then add the bookkeeping columns used by the selection steps.
    filtered_data = filter_redundant_data(dataAsDF=dataAsDF, event_no=event_no, day=day)
    add_default_columns(filtered_data)

    if len(filtered_data) == 0:
        if log:
            print("No data found after filtering")
        return

    # Topic-based filtering only runs on days with more than 100 rows.
    if len(filtered_data) > 100:
        try:
            relevant_topics = assign_topics(filtered_data=filtered_data, topic_model=topic_model)
            relevant_data_topic_1 = select_relevant_topics(filtered_data=relevant_topics,
                                                           queryAsDF=queryAsDF,
                                                           sentence_retriever=sentence_retriver,
                                                           topic_model=topic_model,
                                                           top_percentile=50,
                                                           threshold=0.5,
                                                           print_details=log,
                                                           add_outliers=True)
            torch.cuda.empty_cache()
        except Exception as e:
            # Best effort: keep the unfiltered rows instead of losing the whole day.
            if log:
                print(f"Error in first step of topic modeling: {e}")
            relevant_data_topic_1 = filtered_data
    else:
        relevant_data_topic_1 = filtered_data

    if len(relevant_data_topic_1) == 0:
        if log:
            print("No relevant data found")
        return

    # BM25 + CrossEncoder selection.
    top_k = 200
    selected_data_bm25 = process_selection_event_day(dataAsDF=relevant_data_topic_1,
                                                     queryAsDF=queryAsDF,
                                                     sentence_retriver=sentence_retriver,
                                                     cross_encoder=cross_encoder,
                                                     batch_size=batch_size,
                                                     event_no=event_no,
                                                     device=device,
                                                     top=top_k)

    if len(selected_data_bm25) == 0:
        if log:
            print("No data found after BM25 + CrossEncoder")
        return

    # Reset the bookkeeping columns so the dense step starts from a clean selection state.
    remove_default_columns(selected_data_bm25)
    add_default_columns(selected_data_bm25)

    # NOTE(review): with fewer than 10 BM25 rows this evaluates to 0 -- confirm how
    # process_selection_event_day treats top=0.
    top_k = min(200, int(len(selected_data_bm25) * 0.1))

    # Dense retrieval + CrossEncoder selection on the BM25 output.
    selected_data = process_selection_event_day(dataAsDF=selected_data_bm25,
                                                queryAsDF=queryAsDF,
                                                sentence_retriver=sentence_retriver,
                                                cross_encoder=cross_encoder,
                                                retriever_type="dense",
                                                batch_size=batch_size,
                                                event_no=event_no,
                                                dynamic_score=90,
                                                device=device,
                                                top=top_k)

    # Check for emptiness before sorting so an empty result never reaches sort_values
    # (which would fail if the empty frame lacks the 'cross-score' column).
    if len(selected_data) == 0:
        if log:
            print("No data found after dense retrieval")
        return

    selected_data = selected_data.sort_values(by=['cross-score'], ascending=False)

    save_results(dataframe=selected_data, csv_path=f"./output/{event_no}_{day}_data.csv", json_path=f"./output/results/{event_no}_{day}.json")
    log_statistics(event_no, day, dataAsDF, filtered_data, selected_data_bm25, selected_data, relevant_data_topic_1, log)


In [None]:
def process_event_day(event_list, sentence_retriever_list, cross_encoder_list, topic_model_list, log):
    """Run ``process_day`` for every day of every event in ``event_list``.

    For each event, the dataset is loaded with ``CrisisFactsDataset``. Each day is
    then processed independently, so an exception on one day does not stop the
    remaining days. Exceptions raised while loading an event's dataset are not
    caught and abort the run.

    Args:
        event_list: iterable of event identifiers (e.g. ``"001"``).
        sentence_retriever_list: sentence-embedding model, forwarded to ``process_day``.
        cross_encoder_list: cross-encoder reranker, forwarded to ``process_day``.
        topic_model_list: topic model, forwarded to ``process_day``.
        log: if True, print per-day errors; otherwise emit them with
            ``warnings.warn`` so they are never silently dropped.

    Returns:
        list of ``(event_no, day, exception)`` tuples, one per failed day
        (empty when every day succeeded).
    """
    failures = []

    for event_no in event_list:
        crisis_dataset = CrisisFactsDataset(event_list=[event_no])
        event_data = crisis_dataset.get_dataset_event_no_connection(event_no=event_no)

        for day in event_data:
            try:
                process_day(day=day,
                            event_no=event_no,
                            event_data=event_data,
                            sentence_retriever_list=sentence_retriever_list,
                            cross_encoder_list=cross_encoder_list,
                            topic_model_list=topic_model_list,
                            log=log)
            except Exception as e:
                # Keep going with the next day, but record the failure and make it visible.
                failures.append((event_no, day, e))
                if log:
                    print(f"Error in processing day {day}: {e}")
                else:
                    warnings.warn(f"Error in processing day {day} of event {event_no}: {e}")

    return failures


In [None]:
# Build the topic model once and reuse it for every event/day.
topic_model = create_topic_model()

# Run the full pipeline over all configured events, with verbose logging.
process_event_day(
    event_list=event_no_list_all,
    sentence_retriever_list=sentence_retriver,
    cross_encoder_list=cross_encoder,
    topic_model_list=topic_model,
    log=True,
)