In [18]:
from omegaconf import OmegaConf

# Sampling configuration: source dataset, number of samples to collect per
# split, the character-length window a text must fall in, and the output
# directory for the saved splits.
config = OmegaConf.create(
    {
        'dataset': 'saier/unarxive_citrec',
        'n_train': 10_000,
        'n_valid': 1_000,
        'n_test': 1_000,
        'max_chars_len': 512,
        'min_chars_len': 128,
        'save_dir': '../../data/raw/unarxive_citrec/',
    }
)

In [20]:
from datasets import load_dataset
from tqdm import tqdm

# Load the train split in streaming mode.
# NOTE(review): no visible cell reads this module-level `dataset`;
# take_n_samples opens its own stream per split -- confirm before removing.
dataset = load_dataset(
    config.dataset,
    split='train',
    streaming=True,
)

def take_n_samples(n: int, split: str, batch_size: int = 250) -> list:
    """Collect up to ``n`` samples from a streamed split, filtered by text length.

    Streams ``config.dataset`` for the given ``split`` and keeps every sample
    whose ``sample['text']`` length (in characters) lies within
    ``[config.min_chars_len, config.max_chars_len]``, stopping as soon as
    ``n`` samples have been kept.

    Args:
        n: Number of samples to collect. Values ``<= 0`` return an empty list
            without opening the stream.
        split: Name of the dataset split to stream.
        batch_size: Unused; kept so existing callers that pass it keep working.

    Returns:
        A list of the kept samples, in stream order. It may hold fewer than
        ``n`` items if the stream ends before enough samples qualify.
    """
    # Guard: without it, n == 0 would iterate the whole split, because the
    # "collected == n" condition can be skipped over after the first append.
    if n <= 0:
        return []

    # Local name differs from the module-level `dataset` to avoid shadowing it.
    stream = load_dataset(config.dataset, split=split, streaming=True)
    samples = []
    # The context manager closes the progress bar even if iteration raises.
    with tqdm(total=n) as bar:
        for sample in stream:
            if config.min_chars_len <= len(sample['text']) <= config.max_chars_len:
                samples.append(sample)
                bar.update(1)
                if len(samples) >= n:
                    break

    return samples

# Collect the configured number of length-filtered samples from each split.
train_samples = take_n_samples(n=config.n_train, split='train')
valid_samples = take_n_samples(n=config.n_valid, split='validation')
test_samples = take_n_samples(n=config.n_test, split='test')

100%|██████████| 10000/10000 [00:15<00:00, 664.22it/s]
100%|██████████| 1000/1000 [00:01<00:00, 877.92it/s]
100%|██████████| 1000/1000 [00:01<00:00, 754.20it/s]


In [21]:
# Sanity check: report how many train samples were collected.
print(f'Train samples: {len(train_samples)}')

Train samples: 10000


In [22]:
def extract_texts(samples):
    """Return the ``'text'`` field of every sample, in input order."""
    texts = []
    for sample in samples:
        texts.append(sample['text'])
    return texts

# Reduce every split's samples to their raw text.
train_texts = extract_texts(train_samples)
valid_texts = extract_texts(valid_samples)
test_texts = extract_texts(test_samples)
print(f'Train texts: {len(train_texts)}')

Train texts: 10000


In [23]:
# Count distinct train texts; a value below len(train_texts) means the
# collected sample contains duplicate texts.
len(set(train_texts))

9178

In [24]:
def remove_overlap(texts_a, texts_b):
    """Drop every text shared by both collections, from both of them.

    Each returned list is also de-duplicated. Unlike a round trip through
    ``set``, the surviving texts keep the order of their first occurrence in
    the input, so the result is the same on every run.

    Args:
        texts_a: Iterable of hashable texts.
        texts_b: Iterable of hashable texts.

    Returns:
        A ``(texts_a, texts_b)`` tuple of lists with no element in common.

    Side effects:
        Prints the number of distinct texts found in both inputs.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_a = dict.fromkeys(texts_a)
    unique_b = dict.fromkeys(texts_b)

    overlap = unique_a.keys() & unique_b.keys()
    print(f'Overlap: {len(overlap)}')

    # Remove the shared texts from both sides.
    texts_a = [text for text in unique_a if text not in overlap]
    texts_b = [text for text in unique_b if text not in overlap]
    return texts_a, texts_b

# Make the three splits pairwise disjoint. Each call removes the shared texts
# from BOTH lists (and de-duplicates each list), then rebinds the names, so
# every later call operates on the already-filtered lists -- order matters.
train_texts, valid_texts = remove_overlap(train_texts, valid_texts)
train_texts, test_texts = remove_overlap(train_texts, test_texts)
valid_texts, test_texts = remove_overlap(valid_texts, test_texts)

# Final split sizes after overlap removal.
print(f"Train: {len(train_texts)}, Valid: {len(valid_texts)}, Test: {len(test_texts)}")

Overlap: 72
Overlap: 24
Overlap: 65
Train: 9082, Valid: 702, Test: 568


In [25]:
# Sanity check: no text may appear in more than one split.
assert set(train_texts).isdisjoint(valid_texts)
assert set(train_texts).isdisjoint(test_texts)
assert set(valid_texts).isdisjoint(test_texts)

In [26]:
import numpy as np

# Character length of every train text.
train_lens = np.array(list(map(len, train_texts)))

# Quartiles and extremes of the length distribution.
print("0.25, 0.5, 0.75 quantile:", np.quantile(train_lens, q=[0.25, 0.5, 0.75]))
print("Max len:", train_lens.max())
print("Min len:", train_lens.min())
# Show one randomly chosen train text (no seed is set in the visible cells,
# so the example may differ between runs).
print("Example:", train_texts[np.random.randint(0, len(train_texts))])

0.25, 0.5, 0.75 quantile: [264. 358. 439.]
Max len: 512
Min len: 128
Example: Regarding the computational complexity of the algorithm, line REF  is
computed in \(\mathcal {O}(|V(G)| + |E(G)|)\)  time [1]}.
Moreover, the complete bipartite decomposition of \(G\)  is computed
in \(\mathcal {O}(|V(G)|)\)  time [2]}. To conclude,
the paths at lines REF , REF , REF  and REF 
are computed in constant time, as each \(K_i\)  is complete bipartite. Therefore,
the complexity of Algorithm REF  is \(\mathcal {O}(|V(G)| + |E(G)|)\) .
\(\Box \)
<FIGURE>


In [27]:
import os
import json

# Write each split to <save_dir>/<split_name>.json as an indented JSON array.
os.makedirs(config.save_dir, exist_ok=True)
for texts, split_name in [
    (train_texts, 'train'),
    (valid_texts, 'valid'),
    (test_texts, 'test')
]:
    path = os.path.join(config.save_dir, split_name + '.json')
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file is
    # opened as UTF-8 explicitly rather than with the platform default
    # encoding, which could fail or yield a non-UTF-8 JSON file.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(texts, f, indent=4, ensure_ascii=False)