In [1]:
# Initial dataset, used for the first stage of training; after this stage the model should be able to complete text.

In [16]:
import os

from datasets import load_dataset

from torch.utils.data import Dataset, DataLoader

from preprocess.sequencing import create_sequences, numperize
from preprocess.tokenizer import BPETokenizer

from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

In [3]:
SPECIAL_TOKENS = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
TEXT_COMPLETION_PATH = os.path.join("data", "text_completion.json")

# Source corpus: CNN/DailyMail (config "3.0.0") from the Hugging Face Hub.
DATASET_NAME = "abisee/cnn_dailymail"
DATASET_CONFIG = "3.0.0"
# Portion of each split to load, in Hugging Face split-slicing syntax. A single
# constant keeps the train and validation subsets consistent with each other.
SPLIT_FRACTION = "5%"

train_set = load_dataset(DATASET_NAME, DATASET_CONFIG, split=f"train[:{SPLIT_FRACTION}]")
valid_set = load_dataset(DATASET_NAME, DATASET_CONFIG, split=f"validation[:{SPLIT_FRACTION}]")

print(f"Training set size: {len(train_set)}")
print(f"Validation set size: {len(valid_set)}")

Training set size: 14356
Validation set size: 668


In [4]:
train_articles = train_set["article"]
train_highlights = train_set["highlights"]


tokenizer = BPETokenizer(
    vocab_size=30000, min_frequency=2, special_tokens=SPECIAL_TOKENS
)

# NOTE(review): a tokenizer cached at TEXT_COMPLETION_PATH is reused as-is, even
# if the dataset slice or tokenizer settings above change; delete the file to retrain.
if not os.path.exists(TEXT_COMPLETION_PATH):
    # Train only on real texts: None entries are dropped here, mirroring the
    # filter applied to the articles before encoding. list() makes the column
    # concatenation work whether the dataset returns lists or column objects.
    tokenizer_corpus = [
        text
        for text in list(train_articles) + list(train_highlights)
        if text is not None
    ]
    tokenizer.fit(tokenizer_corpus)
    # Make sure the target directory exists before saving the tokenizer file.
    os.makedirs(os.path.dirname(TEXT_COMPLETION_PATH) or ".", exist_ok=True)
    tokenizer.save(TEXT_COMPLETION_PATH)
else:
    tokenizer.load(TEXT_COMPLETION_PATH)

In [5]:
# Collect the article text of every row, skipping rows whose article is missing.
train_articles = [
    row["article"]
    for row in tqdm(train_set, desc="Extracting Train Articles")
    if row["article"] is not None
]
valid_articles = [
    row["article"]
    for row in tqdm(valid_set, desc="Extracting Valid Articles")
    if row["article"] is not None
]

def encode_article(article):
    """Tokenize a single article with the module-level ``tokenizer``.

    A named module-level function (rather than a lambda) is picklable, which
    ``ProcessPoolExecutor`` requires for the callables it dispatches to workers.
    """
    encoding = tokenizer.encode(article)
    return encoding

def parallel_encode(articles, desc, max_workers=None, chunksize=64):
    """Encode ``articles`` with ``encode_article`` in worker processes, keeping order.

    Args:
        articles: iterable of article strings.
        desc: label shown on the tqdm progress bar.
        max_workers: number of worker processes (``None`` = executor default).
        chunksize: number of articles sent to a worker per task; batching
            amortizes the per-task inter-process pickling overhead.

    Returns:
        list: one ``encode_article`` result per article, in the same order as
        ``articles``.
    """
    articles = list(articles)
    # NOTE(review): workers must be able to unpickle ``encode_article`` and see
    # the global ``tokenizer``; for functions defined in a notebook this only
    # works with the "fork" multiprocessing start method.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in submission order (as_completed does
        # not), so the flattened token stream built from this is deterministic.
        results = executor.map(encode_article, articles, chunksize=chunksize)
        return list(tqdm(results, total=len(articles), desc=desc))

# Encode both splits in turn (train first, then validation).
train_set_encoded, valid_set_encoded = (
    parallel_encode(split_articles, label)
    for split_articles, label in (
        (train_articles, "Encoding Train Set"),
        (valid_articles, "Encoding Valid Set"),
    )
)

Extracting Train Articles: 100%|██████████| 14356/14356 [00:00<00:00, 23822.41it/s]
Extracting Valid Articles: 100%|██████████| 668/668 [00:00<00:00, 21401.64it/s]
Encoding Train Set: 100%|██████████| 14356/14356 [00:11<00:00, 1204.37it/s]
Encoding Valid Set: 100%|██████████| 668/668 [00:00<00:00, 1013.10it/s]


In [None]:
def extract_token_ids(encoded_data, separator_id=None):
    """
    Flatten the token IDs of every encoding into a single list, with a tqdm
    progress bar.

    Args:
        encoded_data: iterable of objects exposing an ``ids`` sequence
            (presumably ``tokenizers.Encoding`` from BPETokenizer — TODO confirm).
        separator_id: optional token ID appended after each document. Without it,
            documents are concatenated back-to-back, so sliding-window sequences
            built from the result can pair context from one article with a
            target from the next. Pass e.g. the ID of "</s>" to mark boundaries.

    Returns:
        list[int]: all token IDs, in iteration order.
    """
    flattened_ids = []
    for encoding in tqdm(encoded_data, desc="Extracting Token IDs"):
        flattened_ids.extend(encoding.ids)
        if separator_id is not None:
            flattened_ids.append(separator_id)
    return flattened_ids

# Flatten each split's per-article encodings into one token-ID stream (train, then valid).
train_token_ids, valid_token_ids = map(
    extract_token_ids, (train_set_encoded, valid_set_encoded)
)


In [6]:
# Window sizes shared by both splits: each example pairs MAX_CONTEXT_LENGTH
# context tokens with MAX_TARGET_LENGTH target token(s).
MAX_CONTEXT_LENGTH = 50
MAX_TARGET_LENGTH = 1

# NOTE(review): the training split yields ~11M windows; if create_sequences
# materializes each one as a Python list, this is very memory-heavy. Consider
# slicing windows lazily from a single token tensor inside the Dataset.
train_seq = create_sequences(
    tokenized_data=train_token_ids,
    max_context_length=MAX_CONTEXT_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
)

valid_seq = create_sequences(
    tokenized_data=valid_token_ids,
    max_context_length=MAX_CONTEXT_LENGTH,
    max_target_length=MAX_TARGET_LENGTH,
)

In [12]:
print(len(train_seq))
# Preview the first six (context, target) windows, raw and decoded.
for _, (context, target) in zip(range(6), train_seq):
    print(f"Context: {context[:10]}... (Total: {len(context)} tokens)")
    print(f"Target: {target} (Total: {len(target)} token)")
    decoded_context = tokenizer.decode(context)
    print(f"Decoded: ...{decoded_context[-10:]}")
    decoded_target = tokenizer.decode(target)
    print(f"Decoded: {decoded_target}")


11141722
Context: [18089, 467, 1359, 13, 596, 8029, 3783, 1119, 1724, 14725]... (Total: 50 tokens)
Target: [8142] (Total: 1 token)
Decoded: ... in Bethes
Decoded: da
Context: [467, 1359, 13, 596, 8029, 3783, 1119, 1724, 14725, 3190]... (Total: 50 tokens)
Target: [16] (Total: 1 token)
Decoded: ...n Bethesda
Decoded: ,
Context: [1359, 13, 596, 8029, 3783, 1119, 1724, 14725, 3190, 346]... (Total: 50 tokens)
Target: [8515] (Total: 1 token)
Decoded: ... Bethesda,
Decoded:  Maryland
Context: [13, 596, 8029, 3783, 1119, 1724, 14725, 3190, 346, 1435]... (Total: 50 tokens)
Target: [16] (Total: 1 token)
Decoded: ..., Maryland
Decoded: ,
Context: [596, 8029, 3783, 1119, 1724, 14725, 3190, 346, 1435, 4208]... (Total: 50 tokens)
Target: [274] (Total: 1 token)
Decoded: ... Maryland,
Decoded:  for
Context: [8029, 3783, 1119, 1724, 14725, 3190, 346, 1435, 4208, 290]... (Total: 50 tokens)
Target: [6326] (Total: 1 token)
Decoded: ...yland, for
Decoded:  routine


In [19]:
class TextCompletionDataset(Dataset):
    """Map-style dataset over precomputed ``(context, target)`` token-ID pairs.

    Each item is returned as a pair of 1-D tensors. The default ``DataLoader``
    collate function then stacks a batch into ``(batch, context_len)`` and
    ``(batch, target_len)`` tensors. With plain Python lists, ``default_collate``
    instead transposes them into a list of ``context_len`` tensors of shape
    ``(batch,)``.

    Args:
        sequences: indexable collection of ``(context, target)`` pairs; each
            element is a sequence of token IDs (list, array or tensor).
        dtype: dtype of the returned tensors. Defaults to ``torch.long``, the
            integer index type used by ``nn.Embedding`` inputs and
            ``nn.CrossEntropyLoss`` class targets.
    """

    def __init__(self, sequences, dtype=torch.long):
        self.sequences = sequences
        self.dtype = dtype

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        context, target = self.sequences[idx]
        return (
            torch.as_tensor(context, dtype=self.dtype),
            torch.as_tensor(target, dtype=self.dtype),
        )
    
# Wrap both splits' windows in the map-style dataset.
train_dataset, valid_dataset = (
    TextCompletionDataset(train_seq),
    TextCompletionDataset(valid_seq),
)

In [20]:
BATCH_SIZE = 32
SEED = 42

# A dedicated, seeded generator makes the training shuffle order reproducible
# across runs without touching torch's global RNG state.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Validation order is fixed; no shuffling needed for evaluation.
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)