In [None]:
from solr_manager import SolrManager

# Connection settings handed to the project-local SolrManager helper.
url = 'http://localhost:8983/solr/'  # base URL of the local Solr server
core = 'articles'  # name of the Solr core to work with
db = 'sqlite:///../data/articles.db'  # SQLite database URL; presumably the source of the articles to index — TODO confirm
solr = SolrManager(url, core, db)

In [None]:
# Run this cell when you want to change an existing schema.
# If it raises an error, wait for the core to finish loading and run it again.
solr.reload_core()

In [None]:
# Run this cell when you want to delete all documents in the core.
# If it raises an error, wait for the core to finish loading and run it again.
solr.clear_documents()

Select **only one** of the following schemas to run!

In [None]:
# Option 1: submit schema.json, then apply both the stopword and the synonym lists.
schema = 'schema.json'
solr.submit_schema(schema)
solr.apply_stopwords('stopwords.txt')
solr.apply_synonyms('synonyms.txt')

In [None]:
# Option 2: submit schema_weak.json and apply only the stopword list (no synonyms).
schema = 'schema_weak.json'
solr.submit_schema(schema)
solr.apply_stopwords('stopwords.txt')

In [None]:
# Index the articles into the core.
# NOTE(review): the meaning of 100 (batch size? document limit?) is defined in SolrManager.index_articles — confirm there.
solr.index_articles(100)

In [None]:
import matplotlib.pyplot as plt
from statistics import mean
import numpy as np

def calculate_and_display_metrics(data, num_results, title='Sample Query', ax=None):
    """
    Compute ranked-retrieval metrics for one query and plot them.

    data:        relevance judgements in rank order; an entry equal to 1 is
                 relevant, anything else is not.
    num_results: expected number of results; must equal len(data), otherwise
                 a message is printed and nothing is plotted.
    title:       figure title, used only when this function creates its own
                 figure (i.e. when ax is None).
    ax:          pair of matplotlib axes. ax[0] receives precision/recall per
                 rank plus the average-precision and precision-at-10 lines,
                 ax[1] receives the interpolated precision-recall curve.
                 When None, a new 1x2 figure is created and shown.

    Returns None. Average precision is reported as 0 when no result is
    relevant, and "precision at 10" falls back to the last available rank
    when fewer than 10 results are given (the legend label reflects that).
    """

    if len(data) == 0:
        print('No results found.')
        return

    if len(data) != num_results:
        print('Number of results does not match the number of expected results.')
        return

    # Stand-alone use: build our own pair of axes instead of failing on ax[0].
    owns_figure = ax is None
    if owns_figure:
        fig, ax = plt.subplots(1, 2, figsize=(15, 5))
        fig.suptitle(title)

    # Count relevant entries with the same test the loop below uses, so that
    # recall can never exceed 1.
    total_actual_positives = sum(1 for label in data if label == 1)

    precisions = []
    recalls = []
    hit_precisions = []  # precision at each rank where a relevant result appears
    true_positives = 0

    for rank, label in enumerate(data, start=1):
        if label == 1:
            true_positives += 1

        precision = true_positives / rank
        recall = true_positives / total_actual_positives if total_actual_positives > 0 else 0

        precisions.append(precision)
        recalls.append(recall)
        if label == 1:
            hit_precisions.append(precision)

    # mean() raises on an empty list, so a query without relevant results gets 0.
    average_precision = mean(hit_precisions) if hit_precisions else 0.0

    # suffix_max[k] == max(precisions[k:]), built in one right-to-left pass.
    suffix_max = list(precisions)
    for k in range(len(suffix_max) - 2, -1, -1):
        suffix_max[k] = max(suffix_max[k], suffix_max[k + 1])

    # Interpolated precision at a recall level is the best precision reached at
    # that recall or any higher one. Recall never decreases along the ranking,
    # so the first rank with recall >= r is the rank where the current recall
    # value first appeared.
    interpolated_recalls = recalls.copy()
    interpolated_precisions = []
    level_start = 0
    for k, recall in enumerate(recalls):
        if recall > recalls[level_start]:
            level_start = k
        interpolated_precisions.append(np.round(suffix_max[level_start], 2))

    # Precision at rank 10, or at the last rank when the list is shorter.
    cutoff = min(10, len(precisions))
    precision_at_n = precisions[cutoff - 1]

    steps = range(1, len(data) + 1)

    ax[0].plot(steps, precisions, label='Precision', color='red', marker='o')
    ax[0].plot(steps, recalls, label='Recall', color='blue', marker='o')
    ax[0].axhline(y=average_precision, color='orange', linestyle='--', label='Avg Precision')
    ax[0].axhline(y=precision_at_n, color='darkgreen', linestyle='--', label=f'Precision at {cutoff}')
    # Numeric values of the two horizontal lines, written just right of the last step.
    ax[0].text(len(steps) + 1.6, average_precision - 0.01, f'{average_precision:.2f}',
               horizontalalignment='center', color='orange', fontsize=10)
    ax[0].text(len(steps) + 1.6, precision_at_n - 0.01, f'{precision_at_n:.2f}',
               horizontalalignment='center', color='darkgreen', fontsize=10)
    ax[0].set_xlabel('Step')
    ax[0].set_ylabel('Metric Value')
    ax[0].set_title('Precision and Recall over Steps')
    ax[0].set_ylim([-0.05, 1.05])
    ax[0].set_xticks(range(1, len(data) + 1, 1))
    ax[0].legend()
    ax[0].grid(axis='y')

    ax[1].plot(interpolated_recalls, interpolated_precisions, label='Precision-Recall Curve', color='red', marker='o')
    ax[1].set_xlabel('Recall')
    ax[1].set_ylabel('Precision')
    ax[1].set_title('Precision-Recall Curve (Interpolated)')
    ax[1].set_ylim([-0.05, 1.05])
    ax[1].set_xlim([-0.05, 1.05])
    ax[1].legend()
    ax[1].grid()

    if owns_figure:
        plt.show()

def plot_all_metrics(relevance_V1, num_results_V1, title_V1, relevance_V2, num_results_V2, title_V2):
    """
    Plot the metrics of two query variants on one 2x2 figure.

    The top row of axes is handed to calculate_and_display_metrics for the
    V1 data and the bottom row for the V2 data; each row is headed by its
    title, and the figure is shown at the end.
    """
    figure, axes_grid = plt.subplots(2, 2, figsize=(15, 10))

    variants = (
        (relevance_V1, num_results_V1, title_V1),
        (relevance_V2, num_results_V2, title_V2),
    )
    for row, (relevance, expected, heading) in enumerate(variants):
        calculate_and_display_metrics(relevance, expected, heading, ax=axes_grid[row])

    # One centred heading above each row of plots.
    for y_position, heading in ((0.94, title_V1), (0.49, title_V2)):
        figure.text(0.5, y_position, heading, ha='center', fontsize=14)

    plt.subplots_adjust(hspace=0.4)
    plt.show()

In [None]:
# global common variables
results_file = '../query_results/results.md'  # every query cell below opens this file with mode 'w', i.e. overwrites it
max_rows = 20  # passed as `rows` to the queries below

In [None]:
# Configure the suggester; presumably required before solr.suggest() is used below — confirm in SolrManager.
solr.configure_suggester()

In [None]:
# suggestions
# Ask SolrManager for suggestions for a partial term and print each one.
# The variable is called `prefix` rather than `input` so that the builtin
# input() is not shadowed for the rest of the notebook session.
prefix = 'semiconduct'
suggestions = solr.suggest(prefix)
for suggestion in suggestions:
    print(suggestion)

In [None]:
# more like this
# Fetch the documents SolrManager reports as similar to document '1'
# (mltfl presumably maps to Solr's mlt.fl — confirm) and write them to the results file.
similar_docs = solr.more_like_this(id='1', mltfl='article_title')
num_results = len(similar_docs)
with open(results_file, 'w', encoding='utf-8') as out:
    out.write('--------------------------------------------------\n\n')
    out.write(f'Number of results: {num_results}\n')
    for position, doc in enumerate(similar_docs, 1):
        out.write(f'\n------------------------ Result: {position} --------------------------\n\n')
        # Each field is written on its own paragraph, in this order.
        for field in ('article_link', 'article_title', 'article_date', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# user query
# Run SolrManager.user_query with an empty free-text query and the
# 'biotechnology' category, write the hits to the results file, then print
# the request parameters, the reported time and the company results.
# The query text is called `query_text` rather than `input` so that the
# builtin input() is not shadowed for the rest of the notebook session.
query_text = ''
user_query = solr.user_query(query_text, category='biotechnology', rows=max_rows)
num_results = len(user_query['results'])
with open(results_file, 'w', encoding='utf-8') as file:
    file.write('Query: ' + str(user_query['params']['q']) + '\n\n')
    file.write('--------------------------------------------------\n\n')
    file.write(f'Number of results: {num_results}\n')
    for index, result in enumerate(user_query['results'], start=1):
        file.write(f'\n------------------------ Result: {index} --------------------------\n\n')
        file.write(result.get('article_link') + '\n\n')
        file.write(result.get('article_title') + '\n\n')
        file.write(result.get('article_date') + '\n\n')
        file.write(result.get('article_text') + '\n\n')
# Echo what was actually sent to Solr and what came back besides the articles.
for param in user_query['params']:
    print(param + ': ' + str(user_query['params'][param]))
print('Time: ' + str(user_query['time']) + ' ms')
for company in user_query['company_results']:
    print(company)

In [None]:
# Tesla stock declines, V1 (baseline): edismax over title and text, no boosts.
params = {
    'defType': 'edismax',
    'qf': 'article_title article_text',
    'q': 'Tesla dipped',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Tesla stock declines, V2: same query as V1 plus a boost query (bq) that
# up-weights documents whose article_text mentions Tesla.
params = {
    'defType': 'edismax',
    'qf': 'article_title article_text',
    'q': 'Tesla dipped',
    'bq': 'article_text:Tesla^2',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V2 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V2}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Evaluation of the two "Tesla dipped" runs.
# The hard-coded 20 replaces the counts computed by the query cells; it must
# match the length of each relevance list or nothing is plotted.
num_results_V1 = num_results_V2 = 20

# Relevance labels in rank order (1 = relevant, 0 = not relevant), hard-coded here.
relevance_V1 = [
    1, 1, 1, 1, 1,
    0, 0, 0, 0, 0,
    0, 0, 0, 0, 1,
    0, 0, 1, 0, 1,
]
relevance_V2 = [
    1, 1, 1, 1, 1,
    0, 1, 1, 1, 1,
    0, 1, 1, 0, 1,
    1, 1, 0, 0, 0,
]

plot_all_metrics(
    relevance_V1=relevance_V1,
    num_results_V1=num_results_V1,
    title_V1='Tesla dipped',
    relevance_V2=relevance_V2,
    num_results_V2=num_results_V2,
    title_V2='Tesla dipped (boost + synonyms)',
)

In [None]:
# Tesla articles from 2020 to 2023, V1 (baseline): edismax over title and
# text, restricted by a date-range filter query, no recency boost.
params = {
    'defType': 'edismax',
    'q': 'tesla',
    'fq': 'article_date:[2020-01-01T00:00:00Z TO 2023-12-31T23:59:59Z]',
    'qf': 'article_title article_text',
    'fl': 'article_title article_date article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_date', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Tesla articles from 2020 to 2023, V2: same as V1 plus a boost function (bf)
# based on the reciprocal of the document's age in milliseconds, so newer
# articles score higher.
params = {
    'defType': 'edismax',
    'q': 'tesla',
    'fq': 'article_date:[2020-01-01T00:00:00Z TO 2023-12-31T23:59:59Z]',
    'qf': 'article_title article_text',
    'fl': 'article_title article_date article_text',
    'bf': 'recip(ms(NOW,article_date),1.65e-9,20,1)',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V2 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V2}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_date', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Evaluation of the two "Tesla" runs (without and with the date boost).
# The hard-coded 20 replaces the counts computed by the query cells; it must
# match the length of each relevance list or nothing is plotted.
num_results_V1 = num_results_V2 = 20

# Relevance labels in rank order (1 = relevant, 0 = not relevant), hard-coded here.
relevance_V1 = [
    1, 1, 1, 0, 0,
    0, 1, 0, 0, 0,
    0, 0, 0, 0, 1,
    0, 1, 0, 0, 0,
]
relevance_V2 = [
    1, 1, 1, 1, 1,
    1, 1, 0, 0, 0,
    0, 0, 1, 0, 1,
    1, 1, 0, 0, 0,
]

plot_all_metrics(
    relevance_V1=relevance_V1,
    num_results_V1=num_results_V1,
    title_V1='Tesla',
    relevance_V2=relevance_V2,
    num_results_V2=num_results_V2,
    title_V2='Tesla (date boost)',
)

In [None]:
# Cyber security sector news, V1 (baseline): match the two terms directly
# against article_text.
params = {
    'q': 'article_text:(cyber security)',
    'fl': 'article_title article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Cyber security sector news, V2: block-join parent query. It returns the
# articles (doc_type:article) that have a child document whose
# company_keywords match, with "cyber" boosted by 3; the [child] transformer
# in fl asks Solr to return the child documents with each article.
params = {
    'q': "{!parent which='doc_type:article'}company_keywords:(cyber^3 security)",
    'fl': 'article_title article_text article_companies [child] company_name',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V2 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V2}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')
        # A hit without article_companies yields None from .get(); treat it as
        # "no companies" instead of failing half-way through the file.
        for company in doc.get('article_companies') or []:
            out.write(company.get('company_name') + '\n')

In [None]:
# Evaluation of the two "cyber security" runs.
# The hard-coded 20 replaces the counts computed by the query cells; it must
# match the length of each relevance list or nothing is plotted.
num_results_V1 = num_results_V2 = 20

# Relevance labels in rank order (1 = relevant, 0 = not relevant), hard-coded here.
relevance_V1 = [
    0, 1, 0, 1, 0,
    0, 1, 1, 0, 0,
    0, 0, 1, 0, 0,
    1, 0, 0, 0, 0,
]
relevance_V2 = [
    1, 1, 1, 1, 1,
    0, 1, 1, 0, 0,
    1, 1, 1, 1, 1,
    1, 1, 1, 1, 0,
]

plot_all_metrics(
    relevance_V1=relevance_V1,
    num_results_V1=num_results_V1,
    title_V1='Cyber security',
    relevance_V2=relevance_V2,
    num_results_V2=num_results_V2,
    title_V2='Cyber security (company fields + boost)',
)

In [None]:
# AMD articles, V1 (baseline): edismax over title and text, limited to
# article documents. "micron" is apparently a deliberate misspelling of
# "micro" so that V2 can show what fuzzy matching recovers — confirm.
params = {
    'defType': 'edismax',
    'q': 'advanced micron devices',
    'fq': 'doc_type:article',
    'qf': 'article_title article_text',
    'fl': 'article_title article_text article_companies [child] company_name',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')
        # A hit without article_companies yields None from .get(); treat it as
        # "no companies" instead of failing half-way through the file.
        for company in doc.get('article_companies') or []:
            out.write(company.get('company_name') + '\n')

In [None]:
# AMD articles, V2: same as V1, but "micron~1" is a fuzzy term that also
# matches words within edit distance 1 (such as "micro").
params = {
    'defType': 'edismax',
    'q': 'advanced micron~1 devices',
    'fq': 'doc_type:article',
    'qf': 'article_title article_text',
    'fl': 'article_title article_text article_companies [child] company_name',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V2 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V2}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')
        # A hit without article_companies yields None from .get(); treat it as
        # "no companies" instead of failing half-way through the file.
        for company in doc.get('article_companies') or []:
            out.write(company.get('company_name') + '\n')

In [None]:
# Evaluation of the two "advanced micron devices" runs (exact vs. fuzzy).
# The hard-coded 20 replaces the counts computed by the query cells; it must
# match the length of each relevance list or nothing is plotted.
num_results_V1 = num_results_V2 = 20

# Relevance labels in rank order (1 = relevant, 0 = not relevant), hard-coded here.
relevance_V1 = [
    0, 1, 1, 1, 0,
    1, 1, 0, 1, 0,
    0, 0, 0, 0, 0,
    0, 1, 1, 0, 0,
]
relevance_V2 = [
    1, 1, 1, 0, 1,
    1, 1, 0, 1, 1,
    1, 1, 1, 1, 1,
    1, 1, 1, 1, 0,
]

plot_all_metrics(
    relevance_V1=relevance_V1,
    num_results_V1=num_results_V1,
    title_V1='Advanced micron devices',
    relevance_V2=relevance_V2,
    num_results_V2=num_results_V2,
    title_V2='Advanced micron devices (fuzzy search)',
)

In [None]:
# field boosts
# Both clauses are required (AND); the title clause carries a boost of 10.
params = {
    'q': 'article_title:Tesla^10 AND article_text:dipped',
    'fl': 'article_title article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# term boosts
# Two terms searched in article_text; "covid" carries a boost of 2.
params = {
    'q': 'article_text:(healthcare covid^2)',
    'fl': 'article_title article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# independent boost
# edismax query for "tesla" with a boost function (bf) based on the
# reciprocal of the document's age in milliseconds; no date filter here.
params = {
    'defType': 'edismax',
    'q': 'tesla',
    'qf': 'article_title article_text',
    'fl': 'article_title article_date article_text',
    'bf': 'recip(ms(NOW,article_date),1.65e-9,20,1)',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_date', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# minimum should match
# NOTE: this cell was labelled "phrase match with slop", but it sets no phrase
# or slop parameter. What it does set is edismax's mm (minimum should match):
# per the Solr docs, '75%' requires three quarters of the optional query
# clauses to match.
params = {
    'defType': 'edismax',
    'q': 'advanced micro devices stocks',
    'qf': 'article_title article_text',
    'fl': 'article_title article_text article_companies [child] company_name',
    'mm': '75%',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')
        # A hit without article_companies yields None from .get(); treat it as
        # "no companies" instead of failing half-way through the file.
        for company in doc.get('article_companies') or []:
            out.write(company.get('company_name') + '\n')

In [None]:
# wildcards and fuzziness
# "bio*" matches any term starting with "bio"; "feel~2" matches terms within
# edit distance 2 of "feel".
params = {
    'q': 'article_text:(bio* feel~2)',
    'fl': 'article_title article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# proximity search
# Article documents whose text has "Tesla" and "fell" within 3 positions of
# each other.
params = {
    'q': 'doc_type:article AND article_text:"Tesla fell"~3',
    'fl': 'article_title article_text',
    'rows': max_rows,
}
results = solr.solr.search(**params)
num_results_V1 = len(results)

# Dump the request and every hit to the results file, echoing each raw hit to the cell output.
with open(results_file, 'w', encoding='utf-8') as out:
    out.write(f'Parameters: {params}\n\n')
    out.write('--------------------------------------------------\n')
    out.write(f'Number of results: {num_results_V1}\n')
    for rank, doc in enumerate(results, 1):
        print(doc)
        out.write(f'\n------------------------ Result: {rank} --------------------------\n\n')
        for field in ('article_title', 'article_text'):
            out.write(doc.get(field) + '\n\n')

In [None]:
# Close the SolrManager once all queries are done; presumably releases its Solr/database connections — confirm in SolrManager.close.
solr.close()