In [None]:
from tqdm.auto import tqdm
import random
from statistics import mean, median
import numpy as np
import scipy.sparse as sp
import re

In [None]:
# Root directory used to build the input paths (AOL query splits) and the output `final-dataset` path.
# NOTE(review): with the empty default, f'{dataset_dir}/...' resolves from the filesystem root — set before running.
dataset_dir = ''

#### Alphanumeric filtering of train and test splits

In [7]:
# Read the raw train queries, one per line, stripping surrounding whitespace.
# The context manager guarantees the file handle is closed once the list is built.
with open(f'{dataset_dir}/subword-qac/data/aol/full/train.query.txt') as train_file:
    train_split = [x.strip() for x in train_file]

In [8]:
# Count how many times each raw query string occurs in the train split.
orig_query_to_freq = {}
for query in train_split:
    orig_query_to_freq[query] = orig_query_to_freq.get(query, 0) + 1

In [9]:
def AlphanumericFilter(input):
    """Replace non-alphanumeric characters with spaces and collapse whitespace runs.

    Args:
        input: Raw query string.

    Returns:
        The string with every character outside [a-zA-Z0-9] turned into a space
        and every run of whitespace reduced to a single space. Leading and
        trailing spaces are not stripped.
    """
    # Raw strings keep "\s" from being parsed as an (invalid) string escape,
    # which recent Python versions warn about.
    input = re.sub(r"[^a-zA-Z0-9]", " ", input)
    input = re.sub(r"\s+", " ", input)
    return input

In [10]:
# Sanity check: punctuation/whitespace runs collapse to one space; the trailing space is kept.
AlphanumericFilter("abc              def  ______123___")

'abc def 123 '

In [11]:
# Normalise every raw query and merge the counts of queries that become
# identical after filtering; queries that filter down to whitespace are dropped.
processed_query_to_freq = {}
for query, freq in tqdm(orig_query_to_freq.items()):
    query = AlphanumericFilter(query)
    if not query.strip():
        continue
    processed_query_to_freq[query] = processed_query_to_freq.get(query, 0) + freq

100%|██████████| 8862181/8862181 [00:34<00:00, 254808.73it/s]


In [12]:
# Persist the filtered train queries as "<query>\t<freq>" lines.
# `with` closes the file even if a write fails part-way through.
with open('./dataset/raw/train.query.alphanumeric_filtered.python.txt', 'w') as f:
    for query, freq in processed_query_to_freq.items():
        f.write(query+'\t'+str(freq)+'\n')

In [13]:
# Read the raw test queries, one per line, stripping surrounding whitespace.
# The context manager guarantees the file handle is closed once the list is built.
with open(f'{dataset_dir}/subword-qac/data/aol/full/test.query.txt') as test_file:
    test_split = [x.strip() for x in test_file]

In [None]:
# Same pipeline as for the train split, applied to the test split:
# count raw queries, normalise them, merge counts, and write "<query>\t<freq>" lines.
# Note: this rebinds `orig_query_to_freq` and `processed_query_to_freq`.
orig_query_to_freq = {}
for query in test_split:
    orig_query_to_freq[query] = orig_query_to_freq.get(query, 0) + 1

processed_query_to_freq = {}
for query, freq in tqdm(orig_query_to_freq.items()):
    query = AlphanumericFilter(query)
    if query.strip() != '':
        processed_query_to_freq[query] = processed_query_to_freq.get(query, 0) + freq

# `with` closes the output file even if a write fails part-way through.
with open('./dataset/raw/test.query.alphanumeric_filtered.python.txt', 'w') as f:
    for query, freq in processed_query_to_freq.items():
        f.write(query+'\t'+str(freq)+'\n')

### Processing alphanumeric filtered files to generate dataset

In [2]:
# Number of highest-frequency suffixes kept as the label set.
NUM_SUFFIXES = 10000000
# Bounds passed to SampleRandomPrefixFromQuerySuffix when sampling a prefix cut point.
max_chars_to_add = 10
min_chars_to_add = 1
# Suffixes longer than this many characters are discarded by GenerateSuffixesFromQuery.
max_suffix_length = 100

In [3]:
# Load the filtered train file written earlier ("<query>\t<freq>" per line).
# The context manager guarantees the file handle is closed once the list is built.
train_queries_path = './dataset/raw/train.query.alphanumeric_filtered.python.txt'
with open(train_queries_path) as train_queries_file:
    train_queries_lines = [x.strip() for x in train_queries_file]

In [None]:
# Parse "<query>\t<freq>" lines into a dict. When a query appears more than once
# the largest frequency wins.
train_query_to_freq = {}
for i in tqdm(range(len(train_queries_lines))):
    line = train_queries_lines[i].split('\t')
    try:
        query, freq = line[0], int(line[1])
    except (IndexError, ValueError):
        # Malformed line (missing tab or non-integer count): report it and skip it.
        # Falling through would reuse `query`/`freq` from the previous line.
        print(line)
        continue
    if query not in train_query_to_freq:
        train_query_to_freq[query] = freq
    else:
        train_query_to_freq[query] = max(freq, train_query_to_freq[query])

In [5]:
def GenerateSuffixesFromQuery(query):
    """Return the word-aligned suffixes of `query` that fit the length limit.

    The query is split on whitespace; for each word position the remaining
    words are re-joined with single spaces. Suffixes longer than the
    module-level `max_suffix_length` are dropped. Order is longest-first.
    """
    tokens = query.split()
    candidates = (' '.join(tokens[start:]) for start in range(len(tokens)))
    return [candidate for candidate in candidates if len(candidate) <= max_suffix_length]

In [None]:
# Accumulate, for every suffix, the total frequency of the train queries that end with it.
all_suffixes_to_freq = {}
for query, freq in tqdm(train_query_to_freq.items()):
    for suffix in GenerateSuffixesFromQuery(query):
        all_suffixes_to_freq[suffix] = all_suffixes_to_freq.get(suffix, 0) + freq

In [7]:
# (suffix, freq) pairs ordered by descending frequency; ties keep insertion order.
all_suffixes_to_freq_sorted = sorted(
    all_suffixes_to_freq.items(),
    key=lambda item: item[1],
    reverse=True,
)

In [8]:
# Keep the NUM_SUFFIXES most frequent suffixes. Slicing (rather than indexing
# range(NUM_SUFFIXES)) avoids an IndexError when fewer suffixes are available.
final_suffixes = [suffix for suffix, _ in all_suffixes_to_freq_sorted[:NUM_SUFFIXES]]


In [9]:
# Set view of the selected suffixes for O(1) membership tests in the cells below.
final_suffixes_set = set(final_suffixes)

### SAMPLE ONE PREFIX PER QUERY, UNTIL YOU GET AT LEAST ONE SUFFIX

In [10]:
# Sanity check: count queries where splitting on a single space does not yield
# (number of spaces + 1) pieces.
cnt = sum(
    1
    for query in train_query_to_freq.keys()
    if len(query.split(' ')) != (query.count(' ')+1)
)

In [None]:
# Worked example of the start-index formula used in GenerateShortlistSuffixes.
query = 'abc def ghi '
suffix = 'def ghi '
max(0, len(query)-len(suffix)-2)

In [12]:
def GenerateShortlistSuffixes(query, suffix_set):
    """Find the word-aligned suffixes of `query` that belong to `suffix_set`.

    Args:
        query: Query string; it is split on single spaces.
        suffix_set: Container of accepted suffix strings.

    Returns:
        A pair (matches, min_index). `matches` is a list of
        [suffix, start_index] entries with
        start_index = max(0, len(query) - len(suffix) - 2). `min_index` is the
        smallest start_index, or len(query) when nothing matched.
    """
    tokens = query.split(' ')
    query_len = len(query)
    matches = []
    for pos in range(len(tokens)):
        candidate = ' '.join(tokens[pos:])
        if candidate not in suffix_set:
            continue
        matches.append([candidate, max(0, query_len - len(candidate) - 2)])
    # len(query) acts as the "no match" sentinel callers test against.
    earliest = min([query_len] + [offset for _, offset in matches])
    return matches, earliest

In [None]:
# For every train query that has at least one shortlisted suffix, sample one
# prefix end position and record each suffix whose start index is at or before it.
# NOTE(review): the sampling is unseeded, so results differ between runs.
train_prefix_to_suffixes = {}
for query, freq in tqdm(train_query_to_freq.items()):
    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)
    if min_index != len(query):
        prefix_end_index = random.randint(min_index, max(len(query)-2, 0))
        prefix = query[:prefix_end_index+1]
        labels = train_prefix_to_suffixes.setdefault(prefix, {})
        for suffix, index in suffixes:
            if index <= prefix_end_index:
                labels[suffix] = freq

In [None]:
# Number of distinct sampled train prefixes.
len(train_prefix_to_suffixes)

In [15]:
# Union of all suffixes that appear as a label of some train prefix.
unique_suffixes = set()
for suffixes in train_prefix_to_suffixes.values():
    unique_suffixes.update(suffixes)

In [None]:
# Number of distinct suffix labels used by the train prefixes.
len(unique_suffixes)

#### Generate all ground truth pairs to have full gt

In [None]:
# Complete the ground truth: for each query and each shortlisted suffix, visit
# every prefix of the query ending at position j in [start_index, len(query)-2].
# If that prefix was sampled as a data point, make sure it carries the suffix.
for query, freq in tqdm(train_query_to_freq.items()):
    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)
    if min_index == len(query):
        continue

    for suffix, index in suffixes:
        for j in range(index, len(query)-1):
            # Slice with the loop variable `j`. Slicing with `index` re-checked
            # the same single prefix on every iteration.
            prefix = query[:j+1]
            if prefix in train_prefix_to_suffixes:
                # Existing labels keep their frequency; only missing ones are added.
                train_prefix_to_suffixes[prefix].setdefault(suffix, freq)
    

In [None]:
# Prefix count after ground-truth completion (only labels are added, not prefixes).
len(train_prefix_to_suffixes)

In [19]:
# Recompute the set of suffix labels after ground-truth completion.
unique_suffixes = {
    suffix
    for suffixes in train_prefix_to_suffixes.values()
    for suffix in suffixes
}

In [None]:
# Number of distinct suffix labels after ground-truth completion.
len(unique_suffixes)

### Generate test prefix, suffix pairs

In [37]:
# Load the filtered test file written earlier ("<query>\t<freq>" per line).
# The context manager guarantees the file handle is closed once the list is built.
test_queries_path = './dataset/raw/test.query.alphanumeric_filtered.python.txt'
with open(test_queries_path) as test_queries_file:
    test_queries_lines = [x.strip() for x in test_queries_file]

In [None]:
# Parse "<query>\t<freq>" lines; a repeated query keeps its largest frequency.
test_query_to_freq = {}
for raw_line in tqdm(test_queries_lines):
    line = raw_line.split('\t')
    query, freq = line[0], int(line[1])
    seen = test_query_to_freq.get(query)
    test_query_to_freq[query] = freq if seen is None else max(freq, seen)

In [39]:
def GenerateShortlistSuffixes(query, suffix_set):
    """Return the suffixes of `query` found in `suffix_set`, with their start indices.

    The query is split on single spaces and every word-aligned suffix is tested
    for membership. Each hit is stored as [suffix, start_index], where
    start_index = max(0, len(query) - len(suffix) - 2). The second return value
    is the minimum start_index, or len(query) if there was no hit.
    """
    pieces = query.split(' ')
    hits = []
    lowest = len(query)  # sentinel meaning "no suffix matched"
    for start in range(len(pieces)):
        tail = ' '.join(pieces[start:])
        if tail not in suffix_set:
            continue
        begin = max(0, len(query) - len(tail) - 2)
        hits.append([tail, begin])
        if begin < lowest:
            lowest = begin
    return hits, lowest

In [None]:
# Sample one prefix per test query (restricted to suffixes seen in train) and
# record the suffixes whose start index is at or before the sampled end position.
# NOTE(review): unlike the train cell, a prefix produced by two different queries
# is overwritten here rather than merged — confirm this is intended.
test_prefix_to_suffixes = {}
for query, freq in tqdm(test_query_to_freq.items()):
    suffixes, min_index = GenerateShortlistSuffixes(query, unique_suffixes)
    if min_index != len(query):
        prefix_end_index = random.randint(min_index, max(len(query)-2, 0))
        prefix = query[:prefix_end_index+1]
        gt_suffixes = {
            suffix: freq
            for suffix, index in suffixes
            if prefix_end_index >= index
        }
        test_prefix_to_suffixes[prefix] = gt_suffixes
            

In [None]:
# Distinct suffix labels referenced by the sampled test prefixes.
test_unique_suffixes = {
    suffix
    for suffixes in test_prefix_to_suffixes.values()
    for suffix in suffixes
}
print(len(test_unique_suffixes))

In [None]:
# Complete the test ground truth: for each query and each shortlisted suffix,
# visit every prefix ending at position j in [start_index, len(query)-2] and,
# when that prefix is a sampled test point, make sure it carries the suffix.
for query, freq in tqdm(test_query_to_freq.items()):
    suffixes, min_index = GenerateShortlistSuffixes(query, unique_suffixes)
    if min_index == len(query):
        continue

    for suffix, index in suffixes:
        for j in range(index, len(query)-1):
            # Slice with the loop variable `j`. Slicing with `index` re-checked
            # the same single prefix on every iteration.
            prefix = query[:j+1]
            if prefix in test_prefix_to_suffixes:
                # Existing labels keep their frequency; only missing ones are added.
                test_prefix_to_suffixes[prefix].setdefault(suffix, freq)
    

In [None]:
# Recompute the distinct test suffix labels after ground-truth completion.
test_unique_suffixes = set()
for suffixes in test_prefix_to_suffixes.values():
    test_unique_suffixes.update(suffixes)
print(len(test_unique_suffixes))

In [None]:
# Number of distinct sampled test prefixes.
len(test_prefix_to_suffixes)

#### Removing seen test points

In [47]:
# Keep only the test prefixes that do not also occur as a train prefix.
test_prefix_to_suffixes_v2 = {
    prefix: suffixes
    for prefix, suffixes in test_prefix_to_suffixes.items()
    if prefix not in train_prefix_to_suffixes
}

In [48]:
# From here on `test_prefix_to_suffixes` refers to the de-duplicated (unseen-only) test set.
test_prefix_to_suffixes = test_prefix_to_suffixes_v2

In [None]:
# Number of test prefixes left after removing those seen in train.
len(test_prefix_to_suffixes)

In [None]:
# Distinct suffix labels of the remaining (unseen) test prefixes.
test_unique_suffixes = set()
for labels in test_prefix_to_suffixes.values():
    test_unique_suffixes |= labels.keys()
print(len(test_unique_suffixes))

### Removing labels with no test point from train as well

In [None]:
# Number of suffix labels that occur in both the train and the test label sets.
len(unique_suffixes & test_unique_suffixes)

In [None]:
# Drop train labels that never occur in test; train prefixes left with no label are removed.
updated_train_prefix_to_suffixes = {}
for prefix, suffixes in tqdm(train_prefix_to_suffixes.items()):
    updated_lbl_dict = {
        suffix: freq
        for suffix, freq in suffixes.items()
        if suffix in test_unique_suffixes
    }
    if updated_lbl_dict:
        updated_train_prefix_to_suffixes[prefix] = updated_lbl_dict

In [None]:
# Number of train prefixes that still have at least one label.
len(updated_train_prefix_to_suffixes)

In [54]:
# Distinct suffix labels remaining in the filtered train set.
updated_train_suffixes = set()
for suffixes in tqdm(updated_train_prefix_to_suffixes.values()):
    updated_train_suffixes.update(suffixes)

100%|██████████| 3922479/3922479 [00:03<00:00, 981918.29it/s] 


In [None]:
# Number of distinct labels in the filtered train set.
len(updated_train_suffixes)

#### Stats for final dataset

In [56]:
# labels per data point

lbls_per_point = [len(labels) for labels in updated_train_prefix_to_suffixes.values()]

In [None]:
# (min, max, mean, median) number of labels per train prefix.
min(lbls_per_point), max(lbls_per_point), mean(lbls_per_point), median(lbls_per_point)

In [58]:
# data points per label

points_per_suffix = {}
for labels in updated_train_prefix_to_suffixes.values():
    for s in labels:
        points_per_suffix[s] = points_per_suffix.get(s, 0) + 1

In [59]:
# Keep only the per-label counts. This rebinds the name from dict to list,
# so the cell cannot be re-run without re-running the previous one.
points_per_suffix = list(points_per_suffix.values())

In [None]:
# (min, max, mean, median) number of train prefixes per label.
min(points_per_suffix), max(points_per_suffix), mean(points_per_suffix), median(points_per_suffix)

In [None]:
# labels per data point

tst_lbls_per_point = [len(labels) for labels in test_prefix_to_suffixes.values()]
print(min(tst_lbls_per_point), max(tst_lbls_per_point), mean(tst_lbls_per_point), median(tst_lbls_per_point))

# data points per label

tst_points_per_suffix = {}
for labels in test_prefix_to_suffixes.values():
    for s in labels:
        tst_points_per_suffix[s] = tst_points_per_suffix.get(s, 0) + 1
# Only the counts are needed for the summary statistics below.
tst_points_per_suffix = list(tst_points_per_suffix.values())

min(tst_points_per_suffix), max(tst_points_per_suffix), mean(tst_points_per_suffix), median(tst_points_per_suffix)

### Convert to XC format

In [62]:
# Assign dense integer ids (in first-seen order) to train prefixes and suffix
# labels, keeping both directions of each mapping.
trn_prefix_to_prefix_id, trn_prefix_id_to_prefix = {}, {}
trn_suffix_id_to_suffix, trn_suffix_to_suffix_id = {}, {}

prefix_id = suffix_id = 0
for prefix, suffixes in updated_train_prefix_to_suffixes.items():
    if prefix not in trn_prefix_to_prefix_id:
        trn_prefix_to_prefix_id[prefix] = prefix_id
        trn_prefix_id_to_prefix[prefix_id] = prefix
        prefix_id += 1
    for suffix in suffixes:
        if suffix in trn_suffix_to_suffix_id:
            continue
        trn_suffix_to_suffix_id[suffix] = suffix_id
        trn_suffix_id_to_suffix[suffix_id] = suffix
        suffix_id += 1
            

In [None]:
# (number of train prefixes, number of labels)
len(trn_prefix_id_to_prefix), len(trn_suffix_id_to_suffix)

In [66]:
## writing trn_X and Y
# One prefix / suffix per line, in id order, so line i corresponds to id i.
# `with` closes each file even if a write fails part-way through.
data_dir = f'{dataset_dir}/final-dataset'
with open(f'{data_dir}/raw/trn_X.txt', 'w') as f:
    for i in range(len(trn_prefix_id_to_prefix)):
        f.write(trn_prefix_id_to_prefix[i]+'\n')

with open(f'{data_dir}/raw/Y.txt', 'w') as f:
    for i in range(len(trn_suffix_id_to_suffix)):
        f.write(trn_suffix_id_to_suffix[i]+'\n')


In [67]:
# Train prefix x label matrix. The shape comes from the id maps rather than
# hard-coded counts, because the number of sampled prefixes varies between runs.
trn_X_Y = sp.dok_matrix(
    (len(trn_prefix_to_prefix_id), len(trn_suffix_to_suffix_id)),
    dtype=np.int32,
)

In [None]:
# Fill the train matrix: entry (prefix_id, suffix_id) holds the query frequency.
# The unused `flag` variable and the redundant dict re-lookup were removed.
for prefix, suffixes in tqdm(updated_train_prefix_to_suffixes.items()):
    prefix_id = trn_prefix_to_prefix_id[prefix]
    for suffix, freq in suffixes.items():
        suffix_id = trn_suffix_to_suffix_id[suffix]
        trn_X_Y[prefix_id, suffix_id] = freq


In [69]:
# Convert from DOK (used for incremental filling) to CSR before computing stats and saving.
trn_X_Y = trn_X_Y.tocsr()

In [None]:
# (number of train prefixes, number of labels)
trn_X_Y.shape

In [None]:
# Labels per point: count the positive entries in each row, then show the maximum.
trn_lpp = np.array((trn_X_Y > 0).astype(int).sum(axis=1)).squeeze()
round(np.max(trn_lpp),2)

In [72]:
# Persist the train prefix x label matrix in scipy's .npz sparse format.
sp.save_npz(f'{data_dir}/raw/trn_X_Y.npz', trn_X_Y)

In [73]:
# Assign dense integer ids (in iteration order) to the test prefixes.
tst_prefix_to_prefix_id, tst_prefix_id_to_prefix = {}, {}
prefix_id = 0
for prefix in test_prefix_to_suffixes:
    if prefix in tst_prefix_to_prefix_id:
        continue
    tst_prefix_to_prefix_id[prefix] = prefix_id
    tst_prefix_id_to_prefix[prefix_id] = prefix
    prefix_id += 1
            

In [None]:
# Number of test prefixes that received an id.
len(tst_prefix_id_to_prefix)

In [75]:
# One test prefix per line, in id order. `with` closes the file even if a write fails.
with open(f'{data_dir}/raw/tst_X.txt', 'w') as f:
    for i in range(len(tst_prefix_id_to_prefix)):
        f.write(tst_prefix_id_to_prefix[i]+'\n')

In [None]:
# Test prefix x label matrix; columns use the train label ids. The shape comes
# from the id maps rather than hard-coded counts, because the number of sampled
# prefixes varies between runs.
tst_X_Y = sp.dok_matrix(
    (len(tst_prefix_to_prefix_id), len(trn_suffix_to_suffix_id)),
    dtype=np.int32,
)
for prefix, suffixes in tqdm(test_prefix_to_suffixes.items()):
    prefix_id = tst_prefix_to_prefix_id[prefix]
    for suffix, freq in suffixes.items():
        suffix_id = trn_suffix_to_suffix_id[suffix]
        tst_X_Y[prefix_id, suffix_id] = freq
tst_X_Y = tst_X_Y.tocsr()

In [77]:
# Persist the test prefix x label matrix in scipy's .npz sparse format.
sp.save_npz(f'{data_dir}/raw/tst_X_Y.npz', tst_X_Y)

#### Coverage Calculation

In [None]:
# Coverage of the selected suffix set: collect the frequency of every train
# query, and separately of those with at least one shortlisted suffix.
queries_covered_freq_list = []
total_freq_list = []
for query, freq in tqdm(train_query_to_freq.items()):
    total_freq_list.append(freq)
    suffixes, min_index = GenerateShortlistSuffixes(query, final_suffixes_set)
    if suffixes:
        queries_covered_freq_list.append(freq)
        

In [None]:
# Frequency-weighted coverage: share of total query frequency that is covered.
sum(queries_covered_freq_list)/sum(total_freq_list)

In [None]:
# Unweighted coverage: share of distinct queries that is covered.
len(queries_covered_freq_list)/len(total_freq_list)

### Verify older suffix list coverage

In [34]:
# Load the label list of an older dataset version, one suffix per line.
# The context manager guarantees the file handle is closed once the list is built.
with open(f'{dataset_dir}/SuffixesDatasets/CharsToAddPrefixes/V2Normalization/1M/AllPrefixesSample/raw/Y.txt') as y_file:
    Y = [x.strip() for x in y_file]

In [35]:
# Set view of the older label list for fast membership tests.
Y_set = set(Y)

In [36]:
# Same coverage computation as above, against the older suffix list (Y_set).
# A query counts as covered when some suffix matched, i.e. min_index moved off
# its len(query) sentinel.
queries_covered_freq_list = []
total_freq_list = []
for query, freq in tqdm(train_query_to_freq.items()):
    total_freq_list.append(freq)
    suffixes, min_index = GenerateShortlistSuffixes(query, Y_set)
    if min_index != len(query):
        queries_covered_freq_list.append(freq)
        

100%|██████████| 8700222/8700222 [00:24<00:00, 355001.33it/s]


In [None]:
# Frequency-weighted coverage of the older suffix list.
sum(queries_covered_freq_list)/sum(total_freq_list)

In [None]:
# Unweighted coverage of the older suffix list.
len(queries_covered_freq_list)/len(total_freq_list)

## SAMPLE ONE PREFIX PER QUERY, SUFFIX AND THEN COMPLETE GT

In [168]:
def SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):
    """Sample one prefix of `query` by picking a random cut position.

    The cut is drawn uniformly (inclusive on both ends) from
    [max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0),
     max(len(query) - min_chars_to_add, 0)].

    Returns:
        (query[:cut], suffix) — the suffix is passed through unchanged.
    """
    query_len = len(query)
    window = min(len(suffix), max_chars_to_add)
    lowest_cut = max(query_len - window - 1, 0)
    highest_cut = max(query_len - min_chars_to_add, 0)
    cut = random.randint(lowest_cut, highest_cut)
    return query[:cut], suffix
    

In [None]:
# For every (query, shortlisted suffix) pair sample one prefix and record the
# suffix as its label; also remember the pair for the completion pass below.
prefix_to_suffixes = {}
query_suffix_pairs = []
for query, freq in tqdm(train_query_to_freq.items()):
    for suffix in GenerateSuffixesFromQuery(query):
        if suffix not in final_suffixes_set:
            continue
        prefix, suffix = SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add) #sample 1 prefix between 1 to 10 chars to add
        prefix_to_suffixes.setdefault(prefix, {})[suffix] = freq
        query_suffix_pairs.append([query, suffix])
            

In [None]:
# Show up to 10 random (prefix, labels) entries. random.sample needs a sequence
# (dict views raise TypeError on Python >= 3.11), so materialise a list first,
# and clamp k so a small dict does not raise ValueError.
random.sample(list(prefix_to_suffixes.items()), k=min(10, len(prefix_to_suffixes)))

In [171]:
def GenerateAllPrefixesFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):
    """List every prefix of `query` that the random sampler could have produced.

    Cut positions run (inclusive) from
    max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0) to
    max(len(query) - min_chars_to_add, 0); the result holds query[:cut] for
    each position, shortest first.
    """
    query_len = len(query)
    window = min(len(suffix), max_chars_to_add)
    first_cut = max(query_len - window - 1, 0)
    last_cut = max(query_len - min_chars_to_add, 0)
    return [query[:cut] for cut in range(first_cut, last_cut + 1)]

In [None]:
# Completion pass: for every recorded (query, suffix) pair enumerate all
# candidate prefixes (window widened to 100 chars) and add the suffix to any
# prefix that was sampled but is still missing it.
for query, suffix in tqdm(query_suffix_pairs):
    freq = train_query_to_freq[query]
    for prefix in GenerateAllPrefixesFromQuerySuffix(query, suffix, 100, min_chars_to_add):
        labels = prefix_to_suffixes.get(prefix)
        if labels is not None and suffix not in labels:
            labels[suffix] = freq

In [None]:
# Number of distinct sampled prefixes.
len(prefix_to_suffixes)

In [None]:
# Size of the selected suffix (label) set.
len(final_suffixes_set)

## REDUCE QUERY, SUFFIX PAIRS TO START WITH

In [None]:
# Invert the relation: for each shortlisted suffix, the train queries ending
# with it, mapped to their frequency.
suffix_to_queries = {}
for query, freq in tqdm(train_query_to_freq.items()):
    for suffix in GenerateSuffixesFromQuery(query):
        if suffix in final_suffixes_set:
            suffix_to_queries.setdefault(suffix, {})[query] = freq

In [176]:
# Total number of (query, suffix) pairs and the largest number of queries sharing one suffix.
queries_per_suffix = [len(queries) for queries in suffix_to_queries.values()]
total_query_suffix_pairs = sum(queries_per_suffix)
max_queries_per_suffix = max(queries_per_suffix, default=0)

In [None]:
# (total query-suffix pairs, max queries attached to a single suffix)
total_query_suffix_pairs, max_queries_per_suffix

In [None]:
# Keep, for each suffix, only its MAX_QUERIES_PER_SUFFIX most frequent queries.
MAX_QUERIES_PER_SUFFIX = 100
query_suffix_pairs = []
for suffix, queries in tqdm(suffix_to_queries.items()):
    sorted_queries = sorted(queries.items(), key=lambda item: item[1], reverse=True)
    query_suffix_pairs.extend(
        [top_query, suffix] for top_query, _ in sorted_queries[:MAX_QUERIES_PER_SUFFIX]
    )

In [None]:
# (number of retained query-suffix pairs, number of suffixes with at least one query)
len(query_suffix_pairs), len(suffix_to_queries)

In [192]:
def SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):
    """Pick a random cut position in `query` and return the prefix before it.

    The cut index is drawn with random.randint (both bounds inclusive) between
    max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0) and
    max(len(query) - min_chars_to_add, 0).

    Returns:
        A (prefix, suffix) tuple; `suffix` is returned unchanged.
    """
    span = min(len(suffix), max_chars_to_add)
    low = max(len(query) - span - 1, 0)
    high = max(len(query) - min_chars_to_add, 0)
    return query[:random.randint(low, high)], suffix

# Rebuild prefix_to_suffixes from the reduced pair list: one sampled prefix per (query, suffix).
prefix_to_suffixes = {}
for query, suffix in query_suffix_pairs:
    prefix, suffix = SampleRandomPrefixFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add) #sample 1 prefix between 1 to 10 chars to add
    prefix_to_suffixes.setdefault(prefix, {})[suffix] = train_query_to_freq[query]

In [None]:
# Number of distinct prefixes sampled from the reduced pair list.
len(prefix_to_suffixes)

In [None]:
def GenerateAllPrefixesFromQuerySuffix(query, suffix, max_chars_to_add, min_chars_to_add):
    """Enumerate the prefixes of `query` for every admissible cut position.

    Cut positions go from
    max(len(query) - min(len(suffix), max_chars_to_add) - 1, 0) up to and
    including max(len(query) - min_chars_to_add, 0). Returns the list of
    query[:cut] values in increasing cut order.
    """
    span = min(len(suffix), max_chars_to_add)
    low = max(len(query) - span - 1, 0)
    high = max(len(query) - min_chars_to_add, 0)
    collected = []
    cut = low
    while cut <= high:
        collected.append(query[:cut])
        cut += 1
    return collected

# Completion pass over the reduced pair list: add each suffix to every
# already-sampled prefix of its query (window widened to 100 chars) that lacks it.
for query, suffix in tqdm(query_suffix_pairs):
    prefixes = GenerateAllPrefixesFromQuerySuffix(query, suffix, 100, min_chars_to_add)
    for prefix in prefixes:
        if prefix not in prefix_to_suffixes:
            continue
        if suffix not in prefix_to_suffixes[prefix]:
            prefix_to_suffixes[prefix][suffix] = train_query_to_freq[query]

In [None]:
# Prefix count after completion (only labels are added, so it is unchanged).
len(prefix_to_suffixes)