In [None]:
from datasets import load_dataset  # type: ignore

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Download dataset.

print("Loading datasets...")
openwebtext = load_dataset("stas/openwebtext-10k")

texts = [sample["text"] for sample in openwebtext["train"] if "text" in sample]

print(f"Total samples: {len(texts)}")

Loading datasets...
Total samples: 10000


In [4]:
# Extract text from all samples.
prompts = [text.split(".") for text in texts]

In [5]:
n_sentences = [len(prompt) for prompt in prompts]

print(f"min={min(n_sentences)}, max={max(n_sentences)}, avg={sum(n_sentences)/len(n_sentences)}")

min=1, max=1736, avg=44.5757


In [6]:
from collections import Counter

# Print number of prompts that only have 0-~10 sentences in it.
freq_table = dict(Counter(n_sentences)).items()
print(f"Lower 10 freq: {sorted(freq_table)[:10]}")
print(f"Top 10 freq: {sorted(freq_table)[-10:]}")

# NOTE: Dropping prompts with <= 3 sentences would result in 0.2% of prompts drop.

Lower 10 freq: [(1, 10), (2, 4), (3, 6), (4, 12), (5, 54), (6, 88), (7, 130), (8, 179), (9, 217), (10, 264)]
Top 10 freq: [(768, 1), (821, 1), (860, 1), (926, 1), (944, 1), (954, 1), (968, 1), (1003, 1), (1020, 1), (1736, 1)]


In [7]:
prompts_sentences = [
    sample["text"].split(".") for sample in openwebtext["train"] if "text" in sample
][:10]

prompts = []
references = []
for sentences in prompts_sentences:
    # Don't use single sentece prompts.
    if len(sentences) == 1:
        continue
    # Use from 1-10 sentences as input.
    n_sentences = len(sentences)
    idx = n_sentences - 1 if n_sentences < 10 else 10
    prompts.append(". ".join(sentences[:idx]) + ".")
    references.append(prompts[-1] + " " + sentences[idx] + ".")

In [8]:
print([len(sentences) for sentences in prompts_sentences])

[41, 33, 21, 25, 15, 56, 14, 44, 9, 17]


In [None]:
# Original sentence had 41 periods.
t = prompts[0].count(".")
print(f"n_dots = {t} =? 10")
t = references[0].count(".")
print(f"n_dots = {t} =? 11")

# Original sentence had 9 periods.
t = prompts[8].count(".")
print(f"n_dots = {t} =? 8")
t = references[8].count(".")
print(f"n_dots = {t} =? 9")

n_dots = 10 =? 10
n_dots = 11 =? 11
n_dots = 8 =? 8
n_dots = 9 =? 9


In [10]:
def load_prompts_and_references(train_pct: int = 80, test_pct: str = 20):
    """
    Load prompts and references from OpenWebText.
    Split into training and test sets.
    """
    tot_pct = train_pct + test_pct
    assert tot_pct <= 100, "Train + test percent > 100"

    print("Loading datasets...")
    openwebtext = load_dataset("stas/openwebtext-10k")

    # Extract each prompt and split by sentence.
    prompts_sentences = [
        sample["text"].split(".") for sample in openwebtext["train"] if "text" in sample
    ]

    print(f"Total samples: {len(prompts_sentences)}")

    prompts = []
    references = []
    for sentences in prompts_sentences:
        # Don't use single sentece prompts.
        if len(sentences) <= 1:
            continue
        # Use from 1-10 sentences as prompt.
        n_sentences = len(sentences)
        idx = n_sentences - 1 if n_sentences <= 7 else 7
        next_sentence_words = sentences[idx].lstrip().split(" ")
        if len(next_sentence_words) <= 1:
            continue
        prompts.append(". ".join(sentences[:idx]) + ". " + next_sentence_words[0] + " ")
        references.append(prompts[-1] + " ".join(next_sentence_words[1:]) + ".")

    # Split.
    train_size = int((train_pct / tot_pct) * len(prompts))

    train_prompts = prompts[:train_size]
    train_references = references[:train_size]
    test_prompts = prompts[train_size:]
    test_references = references[train_size:]

    print(
        f"Loaded {len(train_prompts)} training samples and {len(test_prompts)} test samples."
    )
    return (train_prompts, train_references), (test_prompts, test_references)

In [11]:
train, test = load_prompts_and_references()
train_prompts, train_references = train

Loading datasets...
Total samples: 10000
Loaded 7336 training samples and 1835 test samples.


In [12]:
print("** Prompt **")
print(train_prompts[10])

print("** Reference **")
print(train_references[10])

** Prompt **
Don’t raise your voice here. 

Angela Kabari Blocked Unblock Follow Following Jul 20, 2017

My name is Angela.  I am the woman in the centre of the current Ushahidi sexual harassment scandal. 

The past six months have been some of the most bizarre in my life and, on the balance, I think there is benefit in sharing my experience with the world so that lessons may be learned from it.  It is my hope that, my story shall prompt a change in company policies, both in the Kenyan tech space and in other fields. 

I joined Ushahidi in September 2015 as a Capacity Development Officer for Making All Voices Count.  My time there was mostly enjoyable: the work was challenging, the team was great, and the environment was liberal and progressive. All 
** Reference **
Don’t raise your voice here. 

Angela Kabari Blocked Unblock Follow Following Jul 20, 2017

My name is Angela.  I am the woman in the centre of the current Ushahidi sexual harassment scandal. 

The past six months have been