In [2]:
# toydata to test simple algorithm
from datasets import Dataset, DatasetDict, concatenate_datasets
from setretrieval.utils.utils import pickload, pickdump
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import Counter
from nltk.corpus import wordnet as wn
from nltk.tokenize import sent_tokenize, word_tokenize
import random
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# load in some seed documents (from wikipedia)
wikidocs = Dataset.load_from_disk("../propercache/data/datastores/wikipedia_docs_10k_decont")

In [4]:
from setretrieval.datagen.toydata import chunk_ds, process_ngrams

n = 1
chunked_wikidocs = chunk_ds(wikidocs, 10)
ngrammed_wikidocs = chunked_wikidocs.map(lambda x: process_ngrams(x, n), num_proc=1)
allngs_flat = [n for ng in ngrammed_wikidocs['pos_chunks'] for n in ng]
counts = Counter(allngs_flat)

100%|██████████| 10000/10000 [00:01<00:00, 6941.61it/s]
Map (num_proc=1): 100%|██████████| 62598/62598 [00:09<00:00, 6508.39 examples/s]


In [11]:
def create_train_eval_splits_ultra_fast(ngrammed_dataset, counts, train_ratio=0.8, experquery=1, negatives_per_positive=1, seed=42, train_samples=50000, max_negatives_pool=100000):
    """
    Ultra-fast version with balanced count sampling for negatives.
    """
    random.seed(seed)
    np.random.seed(seed)
    
    n_samples = train_samples
    all_indices = np.arange(n_samples)
    np.random.shuffle(all_indices)
    
    split_point = int(n_samples * train_ratio)
    train_indices = all_indices[:split_point]
    eval_indices = all_indices[split_point:]
    
    # Collect eval positives
    print("Collecting eval positives...")
    eval_positives = set()
    for idx in eval_indices:
        eval_positives.update(ngrammed_dataset[idx]['pos_chunks'])
    
    # Create a large pool of potential negatives (sampled once)
    print("Creating negative pool...")
    all_ngrams = []
    for row in ngrammed_dataset['pos_chunks']:
        all_ngrams.extend(row)
    
    # Sample a subset for faster operations
    if len(all_ngrams) > max_negatives_pool:
        negative_pool = random.sample(all_ngrams, max_negatives_pool)
    else:
        negative_pool = all_ngrams
    
    # Create count-based bins for stratified sampling
    print("Creating count bins for balanced sampling...")
    unique_negatives = list(set(negative_pool))
    negative_counts_map = {ng: counts.get(ng, 1) for ng in unique_negatives}
    
    # Group negatives by count ranges (log-scale bins)
    from collections import defaultdict
    count_bins = defaultdict(list)
    for ng in unique_negatives:
        count = negative_counts_map[ng]
        bin_key = int(np.log10(max(count, 1)))  # log10 binning
        count_bins[bin_key].append(ng)
    
    bin_keys = list(count_bins.keys())
    
    print("Creating train set...")
    train_data = []
    
    for idx in tqdm(train_indices):
        row = ngrammed_dataset[int(idx)]
        query = row['text']
        pos_chunks = row['pos_chunks']
        
        # Filter valid positives
        pos_set = set(pos_chunks)
        valid_positives = [p for p in pos_chunks if p not in eval_positives]
        
        if not valid_positives:
            continue
        
        # Compute target count range for negatives based on positives
        pos_counts = [counts.get(p, 1) for p in valid_positives]
        target_bin = int(np.log10(max(np.mean(pos_counts), 1)))
        
        # Sample negatives from similar count bins
        n_needed = len(valid_positives) * negatives_per_positive
        negatives = []
        
        # Try target bin first, then expand to nearby bins
        search_bins = [target_bin]
        for offset in range(1, len(bin_keys) + 1):
            if target_bin + offset in bin_keys:
                search_bins.append(target_bin + offset)
            if target_bin - offset in bin_keys:
                search_bins.append(target_bin - offset)
        
        for bin_key in search_bins:
            if len(negatives) >= n_needed:
                break
            
            candidates = count_bins[bin_key]
            random.shuffle(candidates)
            
            for candidate in candidates:
                if candidate not in pos_set:
                    negatives.append(candidate)
                    if len(negatives) >= n_needed:
                        break
        
        # Fallback: if not enough negatives, sample from any bin
        if len(negatives) < n_needed:
            attempts = 0
            max_attempts = n_needed * 10
            while len(negatives) < n_needed and attempts < max_attempts:
                candidate = random.choice(unique_negatives)
                if candidate not in pos_set and candidate not in negatives:
                    negatives.append(candidate)
                attempts += 1
        
        # Create examples
        for i, pos in enumerate(random.sample(valid_positives, k=min(experquery, len(valid_positives)))):
            for j in range(min(negatives_per_positive, len(negatives) - i * negatives_per_positive)):
                if i * negatives_per_positive + j < len(negatives):
                    train_data.append({
                        'query': query,
                        'positive': pos,
                        'negative': negatives[i * negatives_per_positive + j]
                    })
    
    # Create eval set
    print("Creating eval set...")
    eval_data = [
        {
            'question': ngrammed_dataset[int(idx)]['text'],
            'pos_chunks': ngrammed_dataset[int(idx)]['pos_chunks']
        }
        for idx in eval_indices
    ]
    
    # Print statistics
    if train_data:
        pos_counts = [counts.get(d['positive'], 1) for d in train_data]
        neg_counts = [counts.get(d['negative'], 1) for d in train_data]
        print(f"\nTrain set statistics:")
        print(f"Average positive count: {np.mean(pos_counts):.2f} (median: {np.median(pos_counts):.2f})")
        print(f"Average negative count: {np.mean(neg_counts):.2f} (median: {np.median(neg_counts):.2f})")

    # print number of positives from the test set that are in the train test
    test_positives = set(eval_positives)
    train_positives = set([r['positive'] for r in train_data])
    print(len(test_positives & train_positives))
    
    return Dataset.from_list(train_data), Dataset.from_list(eval_data)

# Usage
tsamps = 50000
train_eval = create_train_eval_splits_ultra_fast(
    ngrammed_wikidocs, 
    counts,
    train_ratio=0.99, 
    train_samples=tsamps, 
    negatives_per_positive=1, 
    seed=42, 
    max_negatives_pool=40000
)
train_ds, eval_ds = train_eval

Collecting eval positives...
Creating negative pool...
Creating count bins for balanced sampling...
Creating train set...


100%|██████████| 49500/49500 [00:59<00:00, 836.86it/s]


Creating eval set...

Train set statistics:
Average positive count: 50.69 (median: 24.00)
Average negative count: 53.19 (median: 32.00)
0


In [14]:
def create_train_eval_splits_ultra_fast_contaminated(ngrammed_dataset, counts, train_ratio=0.8, experquery=1, negatives_per_positive=1, seed=42, train_samples=50000, max_negatives_pool=100000):
    """
    Ultra-fast version with balanced count sampling for negatives.
    """
    random.seed(seed)
    np.random.seed(seed)
    
    n_samples = train_samples
    all_indices = np.arange(n_samples)
    np.random.shuffle(all_indices)
    
    split_point = int(n_samples * train_ratio)
    train_indices = all_indices[:split_point]
    eval_indices = all_indices[split_point:]
    
    # Create a large pool of potential negatives (sampled once)
    print("Creating negative pool...")
    all_ngrams = []
    for row in ngrammed_dataset['pos_chunks']:
        all_ngrams.extend(row)
    
    # Sample a subset for faster operations
    if len(all_ngrams) > max_negatives_pool:
        negative_pool = random.sample(all_ngrams, max_negatives_pool)
    else:
        negative_pool = all_ngrams
    
    # Create count-based bins for stratified sampling
    print("Creating count bins for balanced sampling...")
    unique_negatives = list(set(negative_pool))
    negative_counts_map = {ng: counts.get(ng, 1) for ng in unique_negatives}
    
    # Group negatives by count ranges (log-scale bins)
    from collections import defaultdict
    count_bins = defaultdict(list)
    for ng in unique_negatives:
        count = negative_counts_map[ng]
        bin_key = int(np.log10(max(count, 1)))  # log10 binning
        count_bins[bin_key].append(ng)
    
    bin_keys = list(count_bins.keys())
    
    print("Creating train set...")
    train_data = []
    
    for idx in tqdm(train_indices):
        row = ngrammed_dataset[int(idx)]
        query = row['text']
        pos_chunks = row['pos_chunks']
        
        # Use all positives (no filtering)
        pos_set = set(pos_chunks)
        valid_positives = pos_chunks
        
        if not valid_positives:
            continue
        
        # Compute target count range for negatives based on positives
        pos_counts = [counts.get(p, 1) for p in valid_positives]
        target_bin = int(np.log10(max(np.mean(pos_counts), 1)))
        
        # Sample negatives from similar count bins
        n_needed = len(valid_positives) * negatives_per_positive
        negatives = []
        
        # Try target bin first, then expand to nearby bins
        search_bins = [target_bin]
        for offset in range(1, len(bin_keys) + 1):
            if target_bin + offset in bin_keys:
                search_bins.append(target_bin + offset)
            if target_bin - offset in bin_keys:
                search_bins.append(target_bin - offset)
        
        for bin_key in search_bins:
            if len(negatives) >= n_needed:
                break
            
            candidates = count_bins[bin_key]
            random.shuffle(candidates)
            
            for candidate in candidates:
                if candidate not in pos_set:
                    negatives.append(candidate)
                    if len(negatives) >= n_needed:
                        break
        
        # Fallback: if not enough negatives, sample from any bin
        if len(negatives) < n_needed:
            attempts = 0
            max_attempts = n_needed * 10
            while len(negatives) < n_needed and attempts < max_attempts:
                candidate = random.choice(unique_negatives)
                if candidate not in pos_set and candidate not in negatives:
                    negatives.append(candidate)
                attempts += 1
        
        # Create examples
        for i, pos in enumerate(random.sample(valid_positives, k=min(experquery, len(valid_positives)))):
            for j in range(min(negatives_per_positive, len(negatives) - i * negatives_per_positive)):
                if i * negatives_per_positive + j < len(negatives):
                    train_data.append({
                        'query': query,
                        'positive': pos,
                        'negative': negatives[i * negatives_per_positive + j]
                    })
    
    # Create eval set
    print("Creating eval set...")
    eval_data = [
        {
            'question': ngrammed_dataset[int(idx)]['text'],
            'pos_chunks': ngrammed_dataset[int(idx)]['pos_chunks']
        }
        for idx in eval_indices
    ]
    
    # Print statistics
    if train_data:
        pos_counts = [counts.get(d['positive'], 1) for d in train_data]
        neg_counts = [counts.get(d['negative'], 1) for d in train_data]
        print(f"\nTrain set statistics:")
        print(f"Average positive count: {np.mean(pos_counts):.2f} (median: {np.median(pos_counts):.2f})")
        print(f"Average negative count: {np.mean(neg_counts):.2f} (median: {np.median(neg_counts):.2f})")

    # print how much overlap there is between train and eval positives
    train_positives = set([r['positive'] for r in train_data])
    eval_positives = set([item for sublist in [r['pos_chunks'] for r in eval_data] for item in sublist])
    print(len(train_positives & eval_positives))
    
    return Dataset.from_list(train_data), Dataset.from_list(eval_data)

# Usage
tsamps = 50000
train_eval = create_train_eval_splits_ultra_fast_contaminated(
    ngrammed_wikidocs, 
    counts,
    train_ratio=0.99, 
    train_samples=tsamps, 
    negatives_per_positive=1, 
    seed=42, 
    max_negatives_pool=40000
)
train_ds, eval_ds = train_eval

Creating negative pool...
Creating count bins for balanced sampling...
Creating train set...


100%|██████████| 49500/49500 [00:06<00:00, 7970.54it/s]


Creating eval set...

Train set statistics:
Average positive count: 20504.40 (median: 956.00)
Average negative count: 19973.60 (median: 13571.00)
2815


In [17]:
contam = True
if contam:
    suff = "balancedcontam"
else:
    suff = "balanced"

In [18]:
train_ds = DatasetDict({
    'train': train_ds.select(range(1000, len(train_ds))),
    'test': train_ds.select(range(1000))
})
train_ds.save_to_disk(f"../propercache/data/colbert_training/wiki{n}gramtrain{tsamps}sample{suff}")
eval_ds.save_to_disk(f"../propercache/data/evalsets/evalwiki{n}grameval{tsamps}samples{suff}")

Saving the dataset (1/1 shards): 100%|██████████| 48500/48500 [00:00<00:00, 454674.72 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 180322.61 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 500/500 [00:00<00:00, 87217.80 examples/s]


In [19]:
def dsets_to_dstore(n, samps, angs=None, dssize=50000):
    traindata = DatasetDict.load_from_disk(f"../propercache/data/colbert_training/wiki{n}gramtrain{samps}sample{suff}")['train']
    evaldata = Dataset.load_from_disk(f"../propercache/data/evalsets/evalwiki{n}grameval{samps}samples{suff}")

    # get all pos_chuhnks from evaldata
    allpos = [r['pos_chunks'] for r in evaldata]
    allpos = list(set([item for sublist in allpos for item in sublist]))
    # get all positives from traindata
    allpos_train = list(set([r['positive'] for r in traindata]))
    remaining = dssize - len(allpos)
    print(len(allpos), len(allpos_train), remaining)
    assert remaining > 0
    if angs is not None:
        aset = set(angs)
        aset = list(aset - set(allpos))
    else:
        aset = list(set(allpos_train) - set(allpos))
    usepos = random.sample(aset, k=remaining)
    tdata = Dataset.from_list([{'text': p} for p in usepos + allpos])
    return tdata

n = 1
chunked_wikidocs = chunk_ds(wikidocs, 10)
ngrammed_wikidocs = chunked_wikidocs.map(lambda x: process_ngrams(x, n), num_proc=1)
allngs_flat = [n for ng in ngrammed_wikidocs['pos_chunks'] for n in ng]
counts = Counter(allngs_flat)
dscount=50000
tdstore = dsets_to_dstore(n, dscount, allngs_flat)
tdstore.save_to_disk(f"../propercache/data/datastores/wiki{n}gramdstore{dscount}{suff}")

100%|██████████| 10000/10000 [00:01<00:00, 7390.85it/s]
Map (num_proc=1): 100%|██████████| 62598/62598 [00:09<00:00, 6279.16 examples/s]


4827 11026 45173


Saving the dataset (1/1 shards): 100%|██████████| 50000/50000 [00:00<00:00, 2785099.40 examples/s]


In [21]:
len(train_ds)

526235

In [2]:
# load in stuff for inverted task
t100krand = DatasetDict.load_from_disk("../propercache/data/colbert_training/nountrain100000minimal10dwords/")
testdstpre = Dataset.load_from_disk("../propercache/data/datastores/evaltdstore10words50pos100k")
testset = Dataset.load_from_disk("../propercache/data/evalsets/testset10words50pos")

In [None]:
ctxt = set(testdstpre['text'])

In [None]:
sum([" " in row['text'] for row in nouns100k])

In [None]:
if False: # do just once to avoid extra I/O
    all_nouns = [word for synset in wn.all_synsets('n') for word in synset.lemma_names()]
    all_nouns = [word for word in all_nouns if "_" not in word]
    all_nouns = list(set(all_nouns))
    random.shuffle(all_nouns)
    test_nouns = [{'text': word} for word in all_nouns[:1000]]
    # allnouns = Dataset.from_list([{"text": word} for word in all_nouns])
    # allnouns.save_to_disk("../propercache/data/datastores/allnouns")
    nouns100k = Dataset.from_list([{"text": word} for word in all_nouns[1000:101000]])
    nouns10k = Dataset.from_list([{"text": word} for word in all_nouns[1000:11000]])
    Dataset.from_list(test_nouns).save_to_disk("../propercache/data/datastores/heldoutnouns")
    nouns100k.save_to_disk("../propercache/data/datastores/nouns100k")
    nouns10k.save_to_disk("../propercache/data/datastores/nouns10k")

In [3]:
test_nouns = Dataset.load_from_disk("../propercache/data/datastores/heldoutnouns")
nouns100k = Dataset.load_from_disk("../propercache/data/datastores/nouns100k")
nouns10k = Dataset.load_from_disk("../propercache/data/datastores/nouns10k")

In [None]:
### WE WILL HAVE 2 KINDS OF TASKS
# A: Given a query (one word), return a list of chunks (each chunk has several words)
# B: Given a query (a list of words), return a list of chunks (each chunk is just one noun at a time)

In [None]:
### NOW STARTING SETUP FOR A

In [5]:
def generate_train_query_rand(wset, doc_words=10):
    query = random.choice(wset)['text']
    chunkspos = [r['text'] for r in random.choices(wset, k=doc_words)]
    
    # Faster check and swap
    if query not in chunkspos:
        chunkspos[random.randint(0, doc_words-1)] = query
    
    chunksneg = [r['text'] for r in random.choices(wset, k=doc_words)]
    
    # Join once at the end
    return {'query': query, 'positive': " ".join(chunkspos), 'negative': " ".join(chunksneg)}

def generate_train_query_minimal(wset, doc_words=10):
    query = random.choice(wset)['text']
    
    # Use random.sample directly instead of converting to set
    chunks = [r['text'] for r in random.sample(wset, k=doc_words)]
    
    # Check and replace if needed
    try:
        idx = chunks.index(query)
        chunks[idx] = random.choice(wset)['text']
    except ValueError:
        pass  # query not in chunks, which is fine
    
    negdata = " ".join(chunks)
    
    # Swap for positive
    chunks[random.randint(0, doc_words-1)] = query
    posdata = " ".join(chunks)
    
    return {'query': query, 'positive': posdata, 'negative': negdata}


# Batch generation function - much faster
def generate_batch_rand(wset_texts, count, doc_words=10):
    results = []
    for _ in tqdm(range(count)):
        query = random.choice(wset_texts)
        chunkspos = random.choices(wset_texts, k=doc_words)
        
        if query not in chunkspos:
            chunkspos[random.randint(0, doc_words-1)] = query
        
        chunksneg = random.choices(wset_texts, k=doc_words)
        
        results.append({
            'query': query, 
            'positive': " ".join(chunkspos), 
            'negative': " ".join(chunksneg)
        })
    return results

# Generate in larger batches
dwords = 250
datapoints = 100000


# Pre-extract text once to avoid repeated dictionary access
nouns10k_texts = [item['text'] for item in nouns10k]
print("Generating training data...")
train_data = generate_batch_rand(nouns10k_texts, datapoints, dwords)
print("Generating test data...")
test_data = generate_batch_rand(nouns10k_texts, 1000, dwords)

train100krand = DatasetDict({
    'train': Dataset.from_list(train_data),
    'test': Dataset.from_list(test_data)
})
do_save=True
if do_save:
    train100krand.save_to_disk(f"../propercache/data/colbert_training/v2nountrain{datapoints}rand{dwords}dwords")

Generating training data...


 48%|████▊     | 48349/100000 [00:02<00:02, 20094.04it/s]


KeyboardInterrupt: 

In [None]:
train100krand['train'][0]

In [None]:
def generate_train_query_rand(wset, doc_words=10):
    query = random.choice(wset)['text']
    chunkspos = [r['text'] for r in random.choices(wset, k=doc_words)]
    # if query not in chunkspos, randomly swap one of the chunks with query
    if query not in chunkspos:
        swap_idx = random.randint(0, len(chunkspos)-1)
        chunkspos[swap_idx] = query
    chunksneg = [r['text'] for r in random.choices(wset, k=doc_words)]
    return {'query': query, 'positive': " ".join(chunkspos), 'negative': " ".join(chunksneg)}

# TODO this will help us test if minimality somehow makes data better for "hard negatives"
# - Note this is pretty clean minimality, other stuff could / could not matter
def generate_train_query_minimal(wset, doc_words=10):
    query = random.choice(wset)['text']
    chunks = set([r['text'] for r in random.sample(wset, k=doc_words)])
    # TODO sanity check that some weird alphabetization isn't doing something weird
    # make sure query is not in chunks
    if query in chunks:
        chunks.remove(query)
        chunks.add(random.choice(wset)['text'])
    negdata = " ".join(chunks)
    # now randomly swap one of the chunks with query
    swap_idx = random.randint(0, len(chunks)-1)
    chunks = list(chunks)
    chunks[swap_idx] = query
    posdata = " ".join(chunks)
    return {'query': query, 'positive': posdata, 'negative': negdata}

# get 100k train sets for both kinds
dwords = 250
datapoints = 100000
train100krand = DatasetDict({
    'train': Dataset.from_list([generate_train_query_rand(nouns10k, dwords) for _ in tqdm(range(datapoints))]),
    'test': Dataset.from_list([generate_train_query_rand(nouns10k, dwords) for _ in tqdm(range(1000))])
})
# train100kminimal = DatasetDict({
#     'train': Dataset.from_list([generate_train_query_minimal(nouns10k, dwords) for _ in tqdm(range(datapoints))]),
#     'test': Dataset.from_list([generate_train_query_minimal(nouns10k, dwords) for _ in tqdm(range(1000))])
# })
do_save = True
if do_save:
    train100krand.save_to_disk(f"../propercache/data/colbert_training/v2nountrain{datapoints}rand{dwords}dwords")
    # train100kminimal.save_to_disk(f"../propercache/data/colbert_training/nountrain{datapoints}minimal{dwords}dwords")

In [6]:
def singwordquery_eval(evalnouns, trainnouns, dstoresize=10000, tsetsize=500, num_pos=50, docwords=10):
    evsetnouns = set([r['text'] for r in evalnouns])
    trsetnouns = set([r['text'] for r in trainnouns])
    assert len(evsetnouns & trsetnouns) == 0
    assert all([" " not in r['text'] for r in evalnouns])
    assert all([" " not in r['text'] for r in trainnouns])
    evalnouns = list([r['text'] for r in evalnouns])
    trainnouns = list([r['text'] for r in trainnouns])
    # make eval test set and datastore
    testqueries = random.sample(evalnouns, k=tsetsize)
    # make eval datastore, with each document have docwords words
    starterdocs = [[r for r in random.sample(trainnouns, k=docwords)] for _ in tqdm(range(dstoresize))]
    # for each query, randomly choose num_pos docs from starterdocs. For each of these randomly replace one word with query word
    query_poschunks = []
    indlist = list(range(dstoresize))
    for query in tqdm(testqueries):
        posinds = random.sample(indlist, k=num_pos)
        for posind in posinds:
            # randomly replace one word in starterdocs[posind] with query word
            wordind = random.randint(0, docwords-1)
            starterdocs[posind][wordind] = query
        query_poschunks.append(posinds)
    doc_poschunks = Dataset.from_list([{"text": " ".join(d)} for d in starterdocs])
    query_data = Dataset.from_list([
        {
            "question": q, 
            "pos_chunks": [doc_poschunks[p]['text'] for p in query_poschunks[i]], 
            "num_pos_chunks": len(query_poschunks[i])
        } 
        for i, q in enumerate(testqueries)])
    return doc_poschunks, query_data

dwords = 50
npos=50
# get a test set to work with 
tdstore, testset10words50pos = singwordquery_eval(test_nouns, nouns10k, 100000, 500, npos, dwords)

100%|██████████| 100000/100000 [00:01<00:00, 55505.04it/s]
100%|██████████| 500/500 [00:00<00:00, 18301.19it/s]


In [7]:
tdstore.save_to_disk(f"../propercache/data/datastores/v2evaltdstore{dwords}words{npos}pos100k")
testset10words50pos.save_to_disk(f"../propercache/data/evalsets/v2testset{dwords}words{npos}pos")

Saving the dataset (1/1 shards): 100%|██████████| 100000/100000 [00:00<00:00, 196931.77 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 500/500 [00:00<00:00, 4723.55 examples/s]


In [None]:
tns = set([r['text'] for r in test_nouns])
trainnns = set(nouns100k['text'])
print(len(tns), len(trainnns), len(tns & trainnns))

In [None]:
tdstore = Dataset.load_from_disk("../propercache/data/datastores/v2evaltdstore50words50pos100k")
testset10words50pos = Dataset.load_from_disk("../propercache/data/evalsets/v2testset50words50pos")

In [None]:
tdstore = Dataset.load_from_disk("../propercache/data/datastores/v2evaltdstore50words50pos100k")

In [None]:
testset10words50pos[8]['question']

In [None]:
len([row['text'] for row in tdstore if f" tram " in row['text']])

In [None]:
wordcnts = []
qs = list(testset10words50pos['question'])
for q in tqdm(qs):
    wordcnts.append(sum([f" {q} " in row['text'] for row in tdstore]))
    print(wordcnts[-1])

In [None]:
# tdstore.save_to_disk("../propercache/data/datastores/evaltdstore10words50pos100k")
# testset10words50pos.save_to_disk("../propercache/data/evalsets/testset10words50pos")

In [None]:
# make eval datastore (main thing is making datastore)


In [None]:
# datapoints = 100000

In [None]:
### NOW STARTING SETUP FOR B 

In [12]:
def construct_toy_query(wset, wset_texts, query_words=100, ndps=1, datatype="train"):
    querywords = random.choices(wset, k=query_words)
    querywords_text = [q['text'] for q in querywords]
    query = " ".join(querywords_text)
    poslist = random.choices(querywords_text, k=ndps)
    
    # Convert to set for O(1) lookup
    querywords_set = set(querywords_text)
    
    # Sample negatives - just keep trying with random.choice until we get valid ones
    neglist = []
    for _ in range(ndps):
        while True:
            neg = random.choice(wset_texts)
            if neg not in querywords_set:
                neglist.append(neg)
                break
    
    if datatype == "train":
        return [{'query': query, 'positive': pos, 'negative': neg} 
                for pos, neg in zip(poslist, neglist)]
    else:
        return {'question': query, 'pos_chunks': querywords_text, 
                'numposchunks': len(querywords_text)}

def toyquerydset(wset, qwords, ndps, tsize):
    # Pre-extract all texts from wset ONCE
    wset_texts = [w['text'] for w in wset]
    
    alldata = []
    inddps = tsize // ndps
    for _ in tqdm(range(inddps)):
        alldata.extend(construct_toy_query(wset, wset_texts, qwords, ndps, "train"))
    return Dataset.from_list(alldata)

In [13]:
numwords = 50
tset = 100000
dp_per_query = 1
traindata = DatasetDict({
    'train': toyquerydset(nouns10k, numwords, dp_per_query, tset),
    'test': toyquerydset(nouns10k, numwords, dp_per_query, 1000)
})
traindata.save_to_disk(f"../propercache/data/colbert_training/v2nountraining{numwords}words{tset}ndps{dp_per_query}")

100%|██████████| 100000/100000 [01:38<00:00, 1017.98it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1015.25it/s]
Saving the dataset (1/1 shards): 100%|██████████| 100000/100000 [00:00<00:00, 193094.67 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 114937.63 examples/s]


In [None]:
# GET EVAL DATA
evallists = [50] # [1, 5, 10, 25]
tnountexts = [r['text'] for r in test_nouns]
for ev in evallists:
    evdata = Dataset.from_list([construct_toy_query(test_nouns, tnountexts, query_words=ev, datatype="eval") for _ in tqdm(range(500))])
    if False:
        evdata.save_to_disk(f"../propercache/data/evalsets/v2nountest{ev}/")

In [None]:
# generate eval data following a probability distribution over count of query words
def pdquerydset(wset, qwdistdict, tsize, ndps=1):
    alldata = []
    dcnt = 0
    while dcnt < tsize:
        qw = random.choices(list(qwdistdict.keys()), weights=list(qwdistdict.values()), k=1)[0]
        alldata.extend(construct_toy_query(wset, qw, min(ndps, qw), "train"))
        dcnt += 1
        if dcnt % 10000 == 0:
            print(f"Generated {dcnt} queries")
    return Dataset.from_list(alldata)

# uniform distribution from 5 to 100
qwdistdictuni = {i: 1/96 for i in range(5, 101)}
# power law distribution from 5 to 100
qwdistdictpower = {i: i**(-0.5) for i in range(5, 101)}

# plot power law distribution in qwdistdictpower
x = list(qwdistdictpower.keys())
y = list(qwdistdictpower.values())
plt.plot(x, y)
plt.xlabel("Number of Words in Query")
plt.ylabel("Probability")
plt.title("Power Law Distribution of Query Length")
plt.show()

In [None]:
# get 100k train set for each distribution
train100kuniform = pdquerydset(nouns10k, qwdistdictuni, 100000)
test100kuniform = pdquerydset(nouns10k, qwdistdictuni, 1000)

train100kpower = pdquerydset(nouns10k, qwdistdictpower, 100000)
test100kpower = pdquerydset(nouns10k, qwdistdictpower, 1000)

DatasetDict({
    'train': train100kuniform,
    'test': test100kuniform
}).save_to_disk("../propercache/data/colbert_training/nountraining100kuniform5_100")

DatasetDict({
    'train': train100kpower,
    'test': test100kpower
}).save_to_disk("../propercache/data/colbert_training/nountraining100kpower5_100")

In [None]:
# sanity check distributions of 'query' length in train100kuniform, train100kpower
plt.hist([len(q['query'].split()) for q in train100kuniform], alpha=0.5, label='Uniform', range=(5, 100))
plt.hist([len(q['query'].split()) for q in train100kpower], alpha=0.5, label='Power', range=(5, 100))
plt.ylabel("Count") 
plt.xlabel("Number of Words in Query")
plt.title("Distribution of Query Length in Training Sets")
plt.legend()
plt.show()

In [None]:
evallists = [1, 5, 10, 25]
testsets = {}
for ev in evallists:
    testsets[ev] = Dataset.load_from_disk(f"../propercache/data/evalsets/nountest{ev}/")

In [None]:
testsets[1]

In [None]:
allnouns = Dataset.load_from_disk("../propercache/data/datastores/allnouns")

In [None]:
ansfixed = list(set([a['text'] for a in allnouns]))
ansfixed = Dataset.from_list([{'text': a} for a in ansfixed])
ansfixed.save_to_disk("../propercache/data/datastores/allnounsfixed")


In [None]:
ans = list([a['text'] for a in allnouns])

In [None]:
len(ans), len(set(ans))

In [None]:
# given a test set, convert it to a train set
def test_to_trainset(tset):
    alldata = []
    for row in tset:
        for pos in row['pos_chunks']:
            alldata.append({'query': row['question'], 'positive': pos, 'negative': random.choice(allnouns)['text']})
    return DatasetDict({'train': Dataset.from_list(alldata), 'test': Dataset.from_list(alldata)})

tset = 25
test_to_trainset(testsets[tset]).save_to_disk(f"../propercache/data/colbert_training/nountestset{tset}")

In [None]:
test_to_trainset(testsets[tset])

In [None]:
eval10 = Dataset.from_list([construct_toy_query(test_nouns, 10, "eval") for _ in tqdm(range(1000))])
eval100 = Dataset.from_list([construct_toy_query(test_nouns, 100, "eval") for _ in tqdm(range(1000))])

eval10.save_to_disk("../propercache/data/evalsets/nountest10/")
eval100.save_to_disk("../propercache/data/evalsets/nountest100/")

In [None]:
# TODO need to make eval data

In [None]:
traindata['train'][4]