In [10]:
import os
import re

import torch
import pandas as pd
from ipywidgets import Button, Text, HBox, VBox, HTML, Dropdown, IntSlider, FloatText
from IPython.display import display, HTML as IPyHTML, clear_output

from retrieval.models import load_colbert_and_tokenizer, ColBERTInference, TfIdf
from retrieval.indexing import IndexConfig, index, ColBERTRetriever
from retrieval.data import Passages
from retrieval.model_understanding import visualize

In [2]:
def getWiklinkToPIDs(wikis, passages, pids):
    """Return the wiki URL belonging to each passage id, in the order of ``pids``.

    Args:
        wikis: DataFrame whose index holds the wiki ids (WID) and which has a
            ``url`` column.
        passages: DataFrame with (at least) the columns ``PID`` and ``WID``.
        pids: Iterable of passage ids to resolve.

    Returns:
        A list with one URL per entry of ``pids``.

    Raises:
        IndexError: If a pid does not occur in ``passages['PID']``.
        KeyError: If the WID of a passage is missing from the index of ``wikis``.
    """
    pids = list(pids)
    if not pids:
        return []

    # Build the PID -> WID lookup once instead of scanning the whole frame for
    # every single pid. If a PID occurs more than once, the first row wins.
    wid_by_pid = passages.drop_duplicates(subset='PID').set_index('PID')['WID']

    unknown = [pid for pid in pids if pid not in wid_by_pid.index]
    if unknown:
        raise IndexError(f"PID(s) not found in passages: {unknown}")

    return [wikis.loc[wid_by_pid.loc[pid]].url for pid in pids]

def getPassagesForPIDs(passages, pids):
    """Return the passage text belonging to each passage id, in the order of ``pids``.

    Args:
        passages: DataFrame with (at least) the columns ``PID`` and ``passage``.
        pids: Iterable of passage ids to resolve.

    Returns:
        A list with one passage text per entry of ``pids``.

    Raises:
        IndexError: If a pid does not occur in ``passages['PID']``.
    """
    pids = list(pids)
    if not pids:
        return []

    # Build the PID -> text lookup once instead of scanning the whole frame for
    # every single pid. If a PID occurs more than once, the first row wins.
    text_by_pid = passages.drop_duplicates(subset='PID').set_index('PID')['passage']

    unknown = [pid for pid in pids if pid not in text_by_pid.index]
    if unknown:
        raise IndexError(f"PID(s) not found in passages: {unknown}")

    return [text_by_pid.loc[pid] for pid in pids]

def display_text_with_link(counter, html_address, text):
    """Render one numbered search hit as a card whose text links to ``html_address``.

    NOTE(review): a later cell of this notebook defines a four-argument function
    with the same name; once that cell has run it replaces this definition.

    Args:
        counter: Rank shown in front of the link.
        html_address: Target URL of the link.
        text: Passage text used as the link label.
    """
    from html import escape  # local import keeps the cell self-contained

    # Passage text and URLs come from scraped wiki data; escape them so stray
    # '<', '&' or '"' characters cannot break the markup.
    safe_address = escape(str(html_address), quote=True)
    safe_text = escape(str(text), quote=False)

    html_link = f'<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; margin-bottom: 10px;"><div style="display: inline-block; width: 30px; text-align: center; margin-right: 10px;">{counter}.</div><a href="{safe_address}" target="_blank" style="color: #1a0dab; text-decoration: none;">{safe_text}</a></div>'
    display(HTML(html_link))

def display_best_search_results(question, pairs):
    """Show a headline containing the question, followed by one card per hit.

    The magnifying-glass icon in the headline uses the Font Awesome class
    ``fa fa-search``.

    NOTE(review): this calls ``display_text_with_link`` with three arguments.
    A later cell redefines that function with four parameters
    (``counter, html_address, sim, text``), so this helper only works while
    the three-argument definition is the one in scope.

    Args:
        question: The query typed by the user.
        pairs: Iterable of ``(url, text)`` tuples, best hit first.
    """
    from html import escape  # local import keeps the cell self-contained

    # The question is free user input; escape it before embedding it in HTML.
    safe_question = escape(str(question), quote=False)
    title = f'<h2 style="color: #1a0dab;"><i class="fa fa-search" style="margin-right: 10px;"></i>Best search results for question: "{safe_question}"</h2>'
    display(HTML(title))

    # Ranks shown to the user start at 1.
    for i, (website_html, text) in enumerate(pairs, start=1):
        display_text_with_link(i, website_html, text)

In [3]:
%%html
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">

In [4]:
# Let the user select the wiki database and the model checkpoint.

# Pick the compute device: CUDA first, then Apple's MPS backend, else the CPU.
# `torch.backends.mps.is_available()` is used instead of the deprecated
# `torch.has_mps`, which only reports whether PyTorch was built with MPS
# support, not whether the backend can be used on this machine.
# `getattr` keeps the cell working on PyTorch builds without the mps backend module.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Running on: {DEVICE}")

# Selectable model checkpoints and wiki directories (paths relative to the notebook).
CHECKPOINT_PATHS = ["data/colbertv2.0"]
WIKI_PATHS = ["data/fandoms_qa/harry_potter/all", "data/fandoms_qa/witcher/all", "data/fandoms_qa/fandoms_all/all"]

dropdown_checkpoint = Dropdown(options=CHECKPOINT_PATHS, description='Select a model checkpoint:')
display(dropdown_checkpoint)

dropdown_wiki = Dropdown(options=WIKI_PATHS, description='Select a wiki:')
display(dropdown_wiki)

Running on: mps


Dropdown(description='Select a model checkpoint:', options=('data/colbertv2.0',), value='data/colbertv2.0')

Dropdown(description='Select a wiki:', options=('data/fandoms_qa/harry_potter/all', 'data/fandoms_qa/witcher/a…

In [5]:
# Resolve all file locations from the two dropdown selections. The index file
# lives inside the selected wiki directory and is named after the basename of
# the selected checkpoint.
checkpoint_path = dropdown_checkpoint.value
passages_path = os.path.join(dropdown_wiki.value, "passages.tsv")
index_path = os.path.join(dropdown_wiki.value, f"{os.path.basename(dropdown_checkpoint.value)}_passages.idx")
wiki_json_path = os.path.join(dropdown_wiki.value, "wiki.json")


# Configuration handed to index() below when no precomputed index exists.
# NOTE(review): dtype is float16 regardless of DEVICE; the retriever is only
# switched to float32 on cpu at the end of this cell - confirm that indexing
# on cpu works with float16.
index_cfg = IndexConfig(
        passages_path=passages_path,
        checkpoint_path=checkpoint_path,
        index_path=index_path,
        # True for every device name that does not contain "cpu" (cuda and mps)
        use_gpu="cpu" not in DEVICE,
        device=DEVICE,
        dtype=torch.float16,
        batch_size=8
)
print(index_cfg)
# The same passages.tsv is loaded twice: as a Passages object for the retriever
# and, below, as a DataFrame for the PID lookups done by the search cell.
passages = Passages(passages_path)

passages_df = pd.read_csv(passages_path, sep ='\t')
# orient='index': the top-level keys of wiki.json become the DataFrame index
# (presumably the wiki ids used for the WID lookup - verify against the data).
wikis_df = pd.read_json(wiki_json_path, orient='index')


# Initialize the model and the retriever for the selected wiki.
inference = ColBERTInference.from_pretrained(checkpoint_path)
inference.tokenizer.doc_maxlen = 512
# We can't skip the punctuation, otherwise the visualization would be offset.
inference.colbert.config.skip_punctuation = False
inference.colbert.skiplist = None
retriever = ColBERTRetriever(inference, device=DEVICE, passages=passages)


# Reuse a precomputed index if one exists at index_path; otherwise compute it
# now (index() is called with store=True).
if os.path.exists(index_path):
    print(f"Loading precomputed indices: {index_path}")
    retriever.indexer.load(index_path)
else:
    print(f"Starting to precompute the indices! (GPU recommended otherwise RIP X.X)")
    retriever.indexer = index(inference, index_cfg, store=True)
    print(f"Stored to precomputed indices under {index_path}")

# Use float32 when running on cpu, since not all functions support float16 on cpu (e.g. addmm).
if "cpu" in DEVICE:
    retriever.to(dtype=torch.float32)

    

IndexConfig(passages_path='data/fandoms_qa/harry_potter/all/passages.tsv', checkpoint_path='data/colbertv2.0', index_path='data/fandoms_qa/harry_potter/all/colbertv2.0_passages.idx', batch_size=8, use_gpu=True, device='mps', dtype=torch.float16)


[2023-06-23 18:14:07][INFO] Detected BERT Tokenizer. Using unused tokens for [Q]/[D] tokens
[2023-06-23 18:14:07][INFO] Detected BERT Tokenizer. Using unused tokens for [Q]/[D] tokens
[2023-06-23 18:14:08][INFO] Detected ColBERTv2 checkpoint. Loading the model!
[2023-06-23 18:14:09][INFO] Successfully loaded weights for last ColBERTv2 layer!


Loading precomputed indices: data/fandoms_qa/harry_potter/all/colbertv2.0_passages.idx


[2023-06-23 18:14:11][INFO] Successfully loaded the precomputed indices. Set dtype to torch.float16!


In [15]:
# State shared with the visualization cell further down: the passages and the
# question of the most recent search() call.
passages_g = []
question_g = ''

# Maps every supported (lower-case) mode to the name of the retriever method
# that implements it.
_SEARCH_METHODS = {
    "tf_idf": "tf_idf_rank",
    "rerank": "rerank",
    "full_retrieval": "full_retrieval",
}

def search(question, k, mode):
    """Run ``question`` against the retriever and return the best ``k`` hits.

    As a side effect the matched passages and the question are stored in the
    module-level ``passages_g`` / ``question_g``.

    Args:
        question: The query string.
        k: Number of passages to retrieve.
        mode: One of "tf_idf", "rerank" or "full_retrieval" (case-insensitive).

    Returns:
        A list of ``(url, similarity, passage_text)`` tuples in the order
        delivered by the retriever.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    global passages_g, question_g

    method_name = _SEARCH_METHODS.get(mode.lower())
    if method_name is None:
        raise ValueError(
            f"Unknown search mode {mode!r}; expected one of {sorted(_SEARCH_METHODS)}"
        )

    # The retriever works on batches of queries; we send one and take its result.
    sims, pids = getattr(retriever, method_name)([question], k=k)[0]

    pid_list = pids.tolist()
    links = getWiklinkToPIDs(wikis_df, passages_df, pid_list)
    passage_matches = getPassagesForPIDs(passages_df, pid_list)

    passages_g = passage_matches
    question_g = question

    return list(zip(links, sims.tolist(), passage_matches))

In [None]:
# Remove output left in this cell by an earlier run before the widgets below are rendered again.
clear_output()

def display_text_with_link(counter, html_address, sim, text):
    """Render one search hit: rank, colour-coded similarity box and linked passage.

    If the passage contains a bracketed heading (``[Heading] ...``) the first
    such heading is appended to the URL as an anchor, with spaces replaced by
    underscores.

    Args:
        counter: Rank shown in front of the hit.
        html_address: URL of the wiki page the passage comes from.
        sim: Similarity score; shown rounded to four decimals. The box colour
            is computed as if the score lay in [-1, 1].
        text: Passage text used as the link label.
    """
    from html import escape  # local import keeps the cell self-contained

    # Non-greedy match: take the first "[...]" only. A greedy ".*" would run
    # up to the last "]" in the passage and produce a bogus anchor.
    heading = re.search(r'\[(.*?)\]', text)
    if heading:
        html_address = html_address + '#' + heading.group(1)
        html_address = html_address.replace(' ', '_')

    # Higher similarity -> smaller red/blue share -> more saturated green.
    green_component = 100 + int(155 * (1 - (sim + 1) / 2))
    # Keep the channel a valid 0-255 value when a score falls outside [-1, 1].
    green_component = max(0, min(255, green_component))
    color = f'rgb({green_component}, 255, {green_component})'

    # Passage text and headings come from scraped wiki data; escape them so a
    # stray '"', '<' or '&' cannot break the href attribute or the markup.
    safe_address = escape(str(html_address), quote=True)
    safe_text = escape(str(text), quote=False)

    sim_box = f'<div style="display: inline-block; width: 50px; height: 20px; background-color: {color}; text-align: center; line-height: 20px;">{round(sim, 4)}</div>'
    html_link = f'<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; margin-bottom: 10px;"><div style="display: inline-block; text-align: center; margin-right: 10px;">{counter}.</div><div style="display: inline-block;">{sim_box}</div><a href="{safe_address}" target="_blank" style="color: #1a0dab; text-decoration: none;"><br/>{safe_text}</a></div>'
    display(HTML(html_link))


def update_pairs(button):
    """Click handler of the search button: run the query and redraw the results.

    Reads the query from ``text_input``, the number of hits from
    ``topK_slider`` and the retrieval mode from ``dropdown_mode``. An empty or
    whitespace-only query is ignored and leaves the current output untouched.

    Args:
        button: The widget that triggered the event (unused).
    """
    question = text_input.value.strip()
    if not question:
        # Nothing to search for - don't send an empty query to the retriever.
        return

    # Run the search first so the old results stay visible if it raises.
    updated_pairs = search(question, k=topK_slider.value, mode=dropdown_mode.value)

    # Clear the existing output, including the previous search results.
    clear_output()

    # clear_output() also removed the controls, so show them again.
    display(VBox([input_box,topK_slider,dropdown_mode]))

    # Display the updated search results, ranked starting at 1.
    for i, (website_html, sim, text) in enumerate(updated_pairs, start=1):
        display_text_with_link(i, website_html, sim, text)

# Slider for the maximum number of passages to return (10 to 1000 in steps of 10).
topK_slider = IntSlider(min=10, max=1000, step=10, description='Select how many passages should be returned:')
display(topK_slider)

# Dropdown for the retrieval method; the options are the mode strings that
# update_pairs() passes on to search().
dropdown_mode = Dropdown(options=["tf_idf", "rerank", "full_retrieval"], description='Select a quering methode:')
display(dropdown_mode)

# Create the input field and the magnifying-glass button.
text_input = Text(description='', layout={'width': '500px', 'height': '75px'}, style={'font-size': '18px'})
search_button = Button(description='', icon='search', layout={'width': '40px', 'height': '40px'})

# Run update_pairs() whenever the magnifying-glass button is clicked.
search_button.on_click(update_pairs)

# Horizontal box holding the input field and the magnifying-glass button.
input_box = HBox([text_input, search_button])

# Headline widget for the result list.
# NOTE(review): best_results_title is created here but never displayed in this cell.
best_results_title = HTML('<h2 style="color: #1a0dab;">Best search results for question:</h2>')

# Show the input field together with the magnifying-glass button.
display(VBox([input_box]))



# example queries:
#   How many children did Harry Potter have?
#   Where was Rowling born?
#   Where was the author of Harry Potter born?
#   What is the origin of the name Potter? -> with full_retrieval and a large k the first result is correct, but it shows that external links were removed by the WikiExtractor

VBox(children=(HBox(children=(Text(value='Who is harry potter?', layout=Layout(height='75px', width='500px')),…

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

HTML(value='<div style="background-color: #f9f9f9; padding: 10px; border: 1px solid #ddd; border-radius: 5px; …

In [23]:
# Input for the rank of the search hit to inspect in the next cell.
# The next cell converts the value with int() and subtracts 1, i.e. the rank is 1-based.
number_input = FloatText(description='Enter the k of the passage you want to know more about:', value=1)
display(number_input)

FloatText(value=1.0, description='Enter the k of the passage you want to know more about:')

In [24]:
clear_output()

# Visualize how the question of the most recent search matches the passage the
# user selected via number_input (1-based rank in the result list).

# number_input is a FloatText, so its value arrives as a float.
selected_rank = int(number_input.value)

if not passages_g:
    # search() has not stored any passages yet.
    print("No search results available yet - run a search in the cell above first.")
elif not 1 <= selected_rank <= len(passages_g):
    # Without this check a rank of 0 or below would silently select a passage
    # counted from the end of the list.
    print(f"Please choose a rank between 1 and {len(passages_g)} (got {selected_rank}).")
else:
    test_query = question_g
    test_passage = passages_g[selected_rank - 1]

    html_vis = visualize(inference, test_query, test_passage, k=2, similarity="cosine")
    # The first three entries of html_vis are shown, each under its heading.
    for _vis_index, _vis_heading in enumerate(("Kernel Density Estimation:", "Absolute count:", " Accumulated Similarities:")):
        print(_vis_heading)
        display(HTML(html_vis[_vis_index]))

Kernel Density Estimation:


HTML(value='<span style="background-color:#aacfe5ff"> [CLS]</span><span style="background-color:#95c5dfff"> [D…

Absolute count:


HTML(value='<span style="background-color:#f7fbffff"> [CLS]</span><span style="background-color:#f7fbffff"> [D…

 Accumulated Similarities:


HTML(value='<span style="background-color:#f7fbffff"> [CLS]</span><span style="background-color:#f7fbffff"> [D…

In [None]:
clear_output()

# Fixed example: a hand-picked question/passage pair that does not depend on a
# previous search.
test_query = "How many children did Harry Potter have?"
test_passage = '["December 29 2007, 12:00am, The Times] "She felt compelled to map out the futures of the surviving members of the Weasley/Potter clan in a family tree. Readers will know that Harry and Ginny marry and have three children. Luna Lovegood, Harry’s dreamy friend, produces two children, with her naturalist husband, Rolf. Bill Weasley and Fleur Delacour have three children, and George Weasley sires a daughter, and a son named Fred after his twin who died in the battle of Hogwarts. “I can’t help it,” the author says in JK Rowling . . . a Year in the Life. “It was like running a race and you get to the finishing line and you’re running too fast to stop, so I do know what happened afterwards and I couldn’t stop my imagination doing that."'
html_vis = visualize(inference, test_query, test_passage, k=2, similarity="cosine")

# Show the first three entries of html_vis, each under its heading.
for _vis_index, _vis_heading in enumerate(("Kernel Density Estimation:", "Absolute count:", " Accumulated Similarities:")):
    print(_vis_heading)
    display(HTML(html_vis[_vis_index]))