In [1]:
# Imports: stdlib -> third-party -> local.
from pathlib import Path

import torch
from datasets import load_dataset

from src.textclf_transformer.tokenizer import wordpiece_tokenizer_wrapper
from src.textclf_transformer.tokenizer.wordpiece_tokenizer_wrapper import WordPieceTokenizerWrapper

# The BERT_original vocabulary sits next to the wrapper module
# (src/textclf_transformer/tokenizer/), so resolve it from the module's location
# instead of a machine-specific absolute path; the notebook then runs on any checkout.
TOKENIZER_DIR = Path(wordpiece_tokenizer_wrapper.__file__).resolve().parent / "BERT_original"

tok = WordPieceTokenizerWrapper()
tok.load(tokenizer_dir=str(TOKENIZER_DIR))

In [2]:

def token_stats(dataset: str, max_len: int, skip_splits=None, tokenizer=None, text_column: str = "text"):
    """Summarise per-split token-length statistics for a Hugging Face dataset.

    Every split of ``dataset`` is loaded and its ``text_column`` is encoded with
    ``max_length=max_len``. The number of non-padding tokens per example is then
    counted. For each split, the function prints how many examples reach
    ``max_len`` (i.e. were most likely truncated). It also collects the
    mean / std / median / min / max of the token counts.

    Args:
        dataset: Dataset name passed to ``datasets.load_dataset``.
        max_len: ``max_length`` passed to the tokenizer's ``encode``.
        skip_splits: Split name or iterable of split names to ignore. ``None``
            keeps the original rule: skip only the ``train`` split of
            ``ccdv/arxiv-classification``.
        tokenizer: Object exposing ``encode(texts, max_length=...)``; defaults
            to the notebook-global ``tok``.
        text_column: Column holding the raw text (default ``"text"``).

    Returns:
        dict mapping split name -> {"avg_tokens", "std_tokens", "median_tokens",
        "min_tokens", "max_tokens"}. Splits with no examples are reported and
        left out of the result.
    """
    if tokenizer is None:
        tokenizer = tok
    if skip_splits is None:
        # Original hard-coded behaviour; presumably the arXiv train split is
        # skipped to keep runtime/memory manageable — TODO confirm.
        skip_splits = {"train"} if dataset == "ccdv/arxiv-classification" else set()
    elif isinstance(skip_splits, str):
        # Guard against set("train") -> {"t", "r", "a", "i", "n"}.
        skip_splits = {skip_splits}
    skip_splits = set(skip_splits)

    ds = load_dataset(dataset)
    result = {}
    for split in ds:
        if split in skip_splits:
            continue

        texts = list(ds[split][text_column])
        n_examples = len(texts)
        if n_examples == 0:
            # Would otherwise divide by zero and call min()/max() on an empty tensor.
            print(f"Split: {split}\nNo examples - skipped")
            continue

        # NOTE(review): assumes encode() returns a sequence whose first item is a
        # (batch, seq_len) tensor of token ids, padded with id 0 — TODO confirm.
        input_ids = tokenizer.encode(texts, max_length=max_len)[0]
        lengths = (input_ids != 0).sum(dim=1)  # non-padding tokens per example

        # Examples that fill the whole window were (most likely) truncated.
        n_at_max = int((lengths >= max_len).sum())
        pct_at_max = 100.0 * n_at_max / n_examples
        print(
            f"Split: {split}\nNumber of examples that reach max_len ({max_len} tokens, likely truncated):\n"
            f"{n_at_max}\n{pct_at_max:.2f}% of total data"
        )

        lengths_f = lengths.to(torch.float32)
        result[split] = {
            "avg_tokens": float(lengths_f.mean()),
            # Sample standard deviation (torch default correction=1); NaN for a single example.
            "std_tokens": float(lengths_f.std()),
            # torch.median returns the lower of the two middle values for an even count.
            "median_tokens": float(torch.median(lengths_f)),
            "min_tokens": int(lengths.min()),
            "max_tokens": int(lengths.max()),
        }
    return result

In [9]:
# max_len=4000 is above the longest IMDB example in the recorded output (3446 tokens),
# so no review is cut off.
imdb = token_stats(dataset='imdb', max_len=4000)

[INFO] input is treated as a list of input texts
Split: train
Number of examples that are longer that max_len:
0
0.0% of total data
[INFO] input is treated as a list of input texts
Split: test
Number of examples that are longer that max_len:
0
0.0% of total data
[INFO] input is treated as a list of input texts
Split: unsupervised
Number of examples that are longer that max_len:
0
0.0% of total data


In [10]:
# Per-split token-length summary for IMDB, as returned by token_stats in the previous cell.
imdb

{'train': {'avg_tokens': 313.8713073730469,
  'std_tokens': 234.29586791992188,
  'median_tokens': 233.0,
  'min_tokens': 13,
  'max_tokens': 3127},
 'test': {'avg_tokens': 306.77099609375,
  'std_tokens': 227.89404296875,
  'median_tokens': 230.0,
  'min_tokens': 10,
  'max_tokens': 3157},
 'unsupervised': {'avg_tokens': 314.8410339355469,
  'std_tokens': 234.513671875,
  'median_tokens': 234.0,
  'min_tokens': 13,
  'max_tokens': 3446}}

In [11]:
# AG News items are short (longest is 379 tokens in the recorded output), so max_len=512
# truncates nothing.
agnews = token_stats(dataset='ag_news', max_len=512)

[INFO] input is treated as a list of paths to text files
[INFO] input is treated as a list of input texts
Split: train
Number of examples that are longer that max_len:
0
0.0% of total data
[INFO] input is treated as a list of paths to text files
[INFO] input is treated as a list of input texts
Split: test
Number of examples that are longer that max_len:
0
0.0% of total data


In [12]:
# Per-split token-length summary for AG News, as returned by token_stats in the previous cell.
agnews

{'train': {'avg_tokens': 53.16641616821289,
  'std_tokens': 19.055023193359375,
  'median_tokens': 51.0,
  'min_tokens': 15,
  'max_tokens': 379},
 'test': {'avg_tokens': 52.746185302734375,
  'std_tokens': 18.22791290283203,
  'median_tokens': 50.0,
  'min_tokens': 18,
  'max_tokens': 277}}

In [5]:
# Even with a 2**15-token window, some arXiv papers still reach the limit
# (about 7% in the recorded output). The train split is skipped inside token_stats.
arxiv = token_stats(dataset='ccdv/arxiv-classification', max_len=32768)

[INFO] input is treated as a list of input texts
Split: validation
Number of examples that are longer that max_len:
182
7.28000020980835% of total data
[INFO] input is treated as a list of input texts
Split: test
Number of examples that are longer that max_len:
164
6.559999942779541% of total data


In [6]:
# Per-split token-length summary for arXiv (validation/test only), as returned by
# token_stats in the previous cell.
arxiv

{'validation': {'avg_tokens': 15012.7451171875,
  'std_tokens': 8385.609375,
  'median_tokens': 12885.0,
  'min_tokens': 1373,
  'max_tokens': 32768},
 'test': {'avg_tokens': 14745.8623046875,
  'std_tokens': 8257.8232421875,
  'median_tokens': 12631.0,
  'min_tokens': 1268,
  'max_tokens': 32768}}