In [None]:
# Notebook to evaluate the distribution of fact check lengths to determine appropriate token length for Gemma
# Add library support
import json

import jsonlines
import numpy as np
import pandas as pd
# Show the full text of each DataFrame cell instead of truncating long strings.
pd.set_option('display.max_colwidth', None)
from tqdm import tqdm
# Register tqdm's pandas integration (adds the progress_* variants such as progress_apply).
tqdm.pandas()
# Test whether checkthat parsing works correctly
from claimrobustness import utils, defaults
import importlib
# Re-import utils so edits made to the module are picked up without restarting the kernel.
importlib.reload(utils)
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

In [None]:
# Fetch the 'fact-check-tweet' dataset bundle through the project helper.
data = utils.load_data(dataset='fact-check-tweet')
# The targets collection: its 'target' column is what gets tokenized further down.
targets = data["targets"]
# The test split unpacks into two parts (presumably queries and their relevance labels).
test_queries, test_qrels = data["test"]

In [None]:
# Print a structural summary of targets (presumably a pandas DataFrame: columns, dtypes, non-null counts).
targets.info()

In [None]:
# Show the index labels of targets; sample_rank entries are checked against these further down.
targets.index

In [None]:
# NOTE(review): `test_ranks` is not assigned in any cell shown before this one, so this
# raises NameError on a fresh kernel — confirm where it is supposed to be loaded from.
test_ranks.shape

In [None]:
# Dimensions of targets, for comparison with the shape of test_ranks.
targets.shape

In [None]:
# Take the first entry of test_ranks as a sample to inspect.
# NOTE(review): depends on `test_ranks`, which no visible cell defines.
sample_rank = test_ranks[0]

In [None]:
# Peek at the first ten entries of the sample ranking.
sample_rank[:10]

In [None]:
# Largest index label in targets, to compare with the largest value found in sample_rank.
targets.index.max()

In [None]:
# Largest value in sample_rank; numpy must already be imported as `np` when this cell runs.
np.array(sample_rank).max()

In [None]:
# Keep only the entries of sample_rank that also occur as labels in targets.index,
# then report how many survived the filter.
present_indices = list(filter(lambda idx: idx in targets.index, sample_rank))
num_present_indices = len(present_indices)

print(f"Number of indices in sample_rank that are present in targets.index: {num_present_indices}")

In [None]:
# NOTE(review): torch is imported here but not referenced by any visible cell.
import torch
from transformers import AutoTokenizer

# Load the tokenizer for the 'BAAI/bge-reranker-v2-gemma' checkpoint; the cells below use it
# only to tokenize text and count tokens.
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')

In [None]:
# Values of the 'target' column as a NumPy array. Despite the name, these are taken from
# `targets`; they are the texts whose token lengths are measured below.
train_sentences = targets['target'].to_numpy()

In [None]:
# Inspect one example end to end: raw text, its tokens, and the token ids.
# A single index drives all three views so the text printed as "Original" is
# the text that is actually tokenized.
sample_idx = 10
sample_sentence = train_sentences[sample_idx]

# Print the original sentence.
print(' Original: ', sample_sentence)

# Print the sentence split into tokens (tokenized once and reused for the ids).
sample_tokens = tokenizer.tokenize(sample_sentence)
print('Tokenized: ', sample_tokens)

# Print the sentence mapped to token ids.
print('Token IDs: ', tokenizer.convert_tokens_to_ids(sample_tokens))

In [None]:
# Token count of every sentence, in the same order as train_sentences.
lengths_en = []
for sent in train_sentences:
    # Encode with add_special_tokens=True so the count includes whatever special
    # tokens this tokenizer adds around the text.
    input_ids = tokenizer.encode(sent, add_special_tokens=True)
    lengths_en.append(len(input_ids))

# Longest sequence seen; stays 0 when there are no sentences at all.
max_len = max(lengths_en, default=0)
print('Max sentence length: ', max_len)

In [None]:
import numpy as np 

In [None]:
# Summary statistics of the token counts; the median is truncated to a whole number.
shortest_len = min(lengths_en)
longest_len = max(lengths_en)
median_len = int(np.median(lengths_en))

print('Min length: {:,} tokens'.format(shortest_len))
print('Max length: {:,} tokens'.format(longest_len))
print('Median length: {:,} tokens'.format(median_len))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Visualize the distribution of sequence lengths

# NOTE(review): custom_palette is not used by this figure; kept in case other cells rely on it.
custom_palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

# Set style for scientific publications in ONE call. A later `sns.set(...)` is an
# alias of `set_theme` and would reset style and palette back to the defaults
# (darkgrid / deep), so the white style and Set2 palette are configured together here.
sns.set_theme(style='white', palette='Set2', font_scale=1.0)

# Create the figure with a white background, using the explicit fig/ax interface.
fig, ax = plt.subplots(figsize=(10.5, 4.27), facecolor='white')

# Plot the histogram of sequence lengths (opaque bars with black edges, no KDE).
# histplot replaces the deprecated distplot.
sns.histplot(lengths_en, ax=ax, alpha=1, edgecolor='black')

# Add title and labels
ax.set_title('Distribution of Sequence Lengths After including Context')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Number of Samples')

# Emphasize the x-axis and y-axis tick labels
ax.tick_params(axis='both', labelsize=12)

# Keep only left and bottom border lines
sns.despine(ax=ax)

# Display the plot
fig.tight_layout()  # Ensures labels don't overlap
# plt.savefig("sequence_lengths_distribution_with_context.svg", dpi=300)  # Save the figure
plt.show()

In [None]:
# Candidate token limit; the same value is assigned to max_len in the next cell.
2048

In [None]:
# Token budget to evaluate. Note this rebinds max_len, which earlier held the
# longest observed sequence length.
max_len = 2048

# How many sequences exceed the budget (and would therefore be truncated).
num_truncated = np.sum(np.asarray(lengths_en) > max_len)

# Express that count as a fraction of all measured sequences.
num_sentences = len(lengths_en)
prcnt = float(num_truncated) / float(num_sentences)

print('{:,} of {:,} sentences ({:.1%}) in the training set are longer than {:} tokens.'.format(num_truncated, num_sentences, prcnt, max_len))

In [None]:
# Check the maximum length the tokenizer itself reports.
# NOTE(review): model_max_length may be a very large placeholder when the tokenizer
# config sets no limit — verify before treating it as the model's real context size.
max_length = tokenizer.model_max_length
print(f"Maximum tokenizer length: {max_length}")