# Text Main

> This module contains the main Python class for data control: `TextDataController`

```#| default_exp text_main```

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from __future__ import annotations
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder,MultiLabelBinarizer
from datasets import DatasetDict,Dataset,IterableDataset,load_dataset,concatenate_datasets
from pathlib import Path
from tqdm import tqdm
from that_nlp_library.utils import *
from functools import partial
import warnings

In [None]:
from that_nlp_library.text_transformation import *
from that_nlp_library.text_augmentation import *
from importlib.machinery import SourceFileLoader
import os

## Content Transformation, Augmentations, and Tokenization

In [None]:
#| export
def tokenizer_explain(inp, # Input sentence
                      tokenizer, # Tokenizer (preferably from HuggingFace)
                      split_word=False # Is input `inp` split into list or not
                     ):
    "Display results from tokenizer: the raw encoding, the token strings, and the decoded text"
    def _show(title, content):
        # every section is: a title line, its content, then a blank separator line
        print(title)
        print(content)
        print()

    print('----- Tokenizer Explained -----')
    _show('--- Input ---', inp)
    _show('--- Tokenized results --- ', tokenizer(inp,is_split_into_words=split_word))
    ids = tokenizer.encode(inp,is_split_into_words=split_word)
    _show('--- Results from tokenizer.convert_ids_to_tokens ---', tokenizer.convert_ids_to_tokens(ids))
    _show('--- Results from tokenizer.decode --- ', tokenizer.decode(ids))

In [None]:
# render the nbdev API documentation (signature + parameter table) for tokenizer_explain
show_doc(tokenizer_explain)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/text_main.py#L20){target="_blank" style="float:right; font-size:smaller"}

### tokenizer_explain

>      tokenizer_explain (inp, tokenizer, split_word=False)

Display results from tokenizer

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| inp |  |  | Input sentence |
| tokenizer |  |  | Tokenizer (preferably from HuggingFace) |
| split_word | bool | False | Is input `inp` split into list or not |

Let's load a tokenizer from the EnviBert model. Uncomment the commands below to download the files needed to build this tokenizer.

In [None]:
# !pip install gdown

In [None]:
# !gdown 14X9fGijA7kdNfe4dM_8gqfxIWtj1Q-hb -O ./envibert_tokenizer --folder

In [None]:
# Load the RobertaTokenizer class defined in ./envibert_tokenizer/envibert_tokenizer.py
# and instantiate it from that same folder
cache_dir=Path('./envibert_tokenizer')
tokenizer = SourceFileLoader("envibert.tokenizer", 
                             str(cache_dir/'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)

Note that the Envibert tokenizer does not require the input to be word-tokenized with `word_tokenize` from the UnderTheSea library

In [None]:
# raw sentence with irregular spacing ('dân   chung') and a hyphen attached to a word ('hồng-')
inp = 'Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh'
tokenizer_explain(inp,tokenizer,split_word=False)

----- Tokenizer Explained -----
--- Input ---
Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh

--- Tokenized results --- 
{'input_ids': [0, 857, 1033, 191, 664, 1033, 7366, 2615, 142, 664, 1033, 671, 1355, 2294, 993, 413, 2900, 244, 1019, 827, 24, 40, 647, 773, 549, 119, 511, 1134, 1690, 758, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

--- Results from tokenizer.convert_ids_to_tokens ---
['<s>', '▁Hội', '▁cư', '▁dân', '▁chung', '▁cư', '▁sen', '▁hồng', '-', '▁chung', '▁cư', '▁lot', 'us', '▁sóng', '▁thần', '▁thủ', '▁đức', '.', '▁Thủ', '▁Đức', '▁là', '▁một', '▁huyện', '▁trực', '▁thuộc', '▁thành', '▁phố', '▁Hồ', '▁Chí', '▁Minh', '</s>']

--- Results from tokenizer.decode --- 
<s> ▁Hội ▁cư ▁dân ▁chung ▁cư ▁sen ▁hồng - ▁chung ▁cư ▁lot

In [None]:
# input is already a list of words, so tell the tokenizer with split_word=True
inp = ['hội', 'cư', 'dân', 'chung', 'cư', 'sen', 'hồng', '-', 'chung', 'cư', 'lotus', 'sóng', 'thần', 'thủ', 'đức']
tokenizer_explain(inp,tokenizer,split_word=True)

----- Tokenizer Explained -----
--- Input ---
['hội', 'cư', 'dân', 'chung', 'cư', 'sen', 'hồng', '-', 'chung', 'cư', 'lotus', 'sóng', 'thần', 'thủ', 'đức']

--- Tokenized results --- 
{'input_ids': [0, 227, 1033, 191, 664, 1033, 7366, 2615, 13, 664, 1033, 671, 1355, 2294, 993, 413, 2900, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

--- Results from tokenizer.convert_ids_to_tokens ---
['<s>', '▁hội', '▁cư', '▁dân', '▁chung', '▁cư', '▁sen', '▁hồng', '▁-', '▁chung', '▁cư', '▁lot', 'us', '▁sóng', '▁thần', '▁thủ', '▁đức', '</s>']

--- Results from tokenizer.decode --- 
<s> ▁hội ▁cư ▁dân ▁chung ▁cư ▁sen ▁hồng ▁- ▁chung ▁cư ▁lot us ▁sóng ▁thần ▁thủ ▁đức </s>



Now let's try the PhoBert tokenizer. Unlike the Envibert tokenizer, the PhoBert tokenizer requires its input to be word-tokenized first (using the UnderTheSea library)

In [None]:
from transformers import AutoTokenizer

In [None]:
# PhoBert tokenizer, fetched from the HuggingFace Hub by model name
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
# word-segment first: multi-syllable words are joined with '_' (see output below)
inp = apply_vnmese_word_tokenize('hội cư dân chung cư sen hồng - chung cư lotus sóng thần thủ đức')
print(inp)

hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức


In [None]:
# `inp` is a single word-segmented string, so the default split_word=False applies
tokenizer_explain(inp,tokenizer)

----- Tokenizer Explained -----
--- Input ---
hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức

--- Tokenized results --- 
{'input_ids': [0, 1093, 1838, 1574, 3330, 2025, 31, 1574, 2029, 4885, 8554, 25625, 7344, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

--- Results from tokenizer.convert_ids_to_tokens ---
['<s>', 'hội', 'cư_dân', 'chung_cư', 'sen', 'hồng', '-', 'chung_cư', 'lo@@', 'tus', 'sóng_thần', 'thủ_@@', 'đức', '</s>']

--- Results from tokenizer.decode --- 
<s> hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức </s>



In [None]:
#| export
def two_steps_tokenization_explain(inp, # Input sentence
                                   tokenizer, # Tokenizer (preferably from HuggingFace)
                                   content_tfms=[], # A list of text transformations
                                   aug_tfms=[], # A list of text augmentation 
                                  ):
    "Display results from each content transformation and augmentation, then display results from tokenizer"
    def _run_and_show(text, tfms):
        # apply each transformation in order, printing its name and the intermediate text
        for step in val2iterable(tfms):
            print_msg(callable_name(step),3)
            text = step(text)
            print(text)
        return text

    print('----- Text Transformation Explained -----')
    print('--- Raw sentence ---')
    print(inp)
    print('--- Content Transformations (on both train and test) ---')
    transformed = _run_and_show(inp, content_tfms)
    print('--- Augmentations (on train only) ---')
    augmented = _run_and_show(transformed, aug_tfms)
    print()
    tokenizer_explain(augmented,tokenizer)

In [None]:
# render the nbdev API documentation for two_steps_tokenization_explain
show_doc(two_steps_tokenization_explain)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/text_main.py#L41){target="_blank" style="float:right; font-size:smaller"}

### two_steps_tokenization_explain

>      two_steps_tokenization_explain (inp, tokenizer, content_tfms=[],
>                                      aug_tfms=[])

Display results form each content transformation, then display results from tokenizer

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| inp |  |  | Input sentence |
| tokenizer |  |  | Tokenizer (preferably from HuggingFace) |
| content_tfms | list | [] | A list of text transformations |
| aug_tfms | list | [] | A list of text augmentation |

Let's load Phobert tokenizer one more time to test out this function

In [None]:
# reload the PhoBert tokenizer so this section can be run on its own
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
from underthesea import text_normalize

`apply_vnmese_word_tokenize` also has an option to normalize text

In [None]:
from functools import partial

In [None]:
# content transformation: word-segment the raw text with text normalization turned on
inp = 'Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh'
two_steps_tokenization_explain(inp,tokenizer,content_tfms=[partial(apply_vnmese_word_tokenize,normalize_text=True)])

----- Text Transformation Explained -----
--- Raw sentence ---
Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh
--- Content Transformations (on both train and test) ---
--- apply_vnmese_word_tokenize ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh
--- Augmentations (on train only) ---

----- Tokenizer Explained -----
--- Input ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh

--- Tokenized results --- 
{'input_ids': [0, 792, 1838, 1574, 3330, 2025, 31, 1574, 2029, 4885, 8554, 25625, 7344, 5, 5043, 8, 16, 149, 2850, 214, 784, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

--- Results from tokenizer.convert_ids_to_tokens ---
['<s>', 'Hội',

Let's add some text augmentations

In [None]:
import unidecode

In [None]:
# to remove vietnamese accent: unidecode transliterates the text to plain ASCII
remove_accent = lambda x: unidecode.unidecode(x)

If you want your function to be printed with a different name:

In [None]:
# this name is what appears as the step title in the explain output below
remove_accent.__name__ = 'Remove Vietnamese Accent'

In [None]:
# content transformation first, then the (train-only) accent-removal augmentation
two_steps_tokenization_explain(inp,tokenizer,
                               content_tfms=[partial(apply_vnmese_word_tokenize,normalize_text=True)],
                               aug_tfms=[remove_accent]
                              )

----- Text Transformation Explained -----
--- Raw sentence ---
Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh
--- Content Transformations (on both train and test) ---
--- apply_vnmese_word_tokenize ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh
--- Augmentations (on train only) ---
--- Remove Vietnamese Accent ---
Hoi cu_dan chung_cu sen hong - chung_cu lotus song_than thu_duc . Thu_Duc la mot huyen truc_thuoc thanh_pho Ho_Chi_Minh

----- Tokenizer Explained -----
--- Input ---
Hoi cu_dan chung_cu sen hong - chung_cu lotus song_than thu_duc . Thu_Duc la mot huyen truc_thuoc thanh_pho Ho_Chi_Minh

--- Tokenized results --- 
{'input_ids': [0, 3021, 1111, 56549, 17386, 22975, 13689, 3330, 27037, 31, 22975, 13689, 2029, 4885, 3227, 9380, 1510, 21605, 6190, 1894, 5, 5770, 4098, 1894, 2644, 3773, 1204, 18951, 2052, 10242, 9835, 1881, 22899, 

You can even be creative with your augmentation functions; let's say you only want your augmentation to be applied 50% of the time:

In [None]:
import random

In [None]:
# seed Python's `random`, which the probabilistic augmentation below draws from
random.seed(2) # for reproducibility

In [None]:
# strip accents only when random.random() < 0.5; otherwise return the text unchanged
remove_accent = lambda x: unidecode.unidecode(x) if random.random()<0.5 else x
remove_accent.__name__ = 'Remove Vietnamese Accent with 0.5 prob'

In [None]:
# with seed 2, this run happens to skip the augmentation (accents are kept in the output below)
two_steps_tokenization_explain(inp,tokenizer,
                               content_tfms=[partial(apply_vnmese_word_tokenize,normalize_text=True)],
                               aug_tfms=[remove_accent]
                              )

----- Text Transformation Explained -----
--- Raw sentence ---
Hội cư dân   chung cư sen hồng- chung cư    lotus sóng thần thủ đức. Thủ Đức là một huyện trực thuộc thành phố Hồ Chí Minh
--- Content Transformations (on both train and test) ---
--- apply_vnmese_word_tokenize ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh
--- Augmentations (on train only) ---
--- Remove Vietnamese Accent with 0.5 prob ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh

----- Tokenizer Explained -----
--- Input ---
Hội cư_dân chung_cư sen hồng - chung_cư lotus sóng_thần thủ_đức . Thủ_Đức là một huyện trực_thuộc thành_phố Hồ_Chí_Minh

--- Tokenized results --- 
{'input_ids': [0, 792, 1838, 1574, 3330, 2025, 31, 1574, 2029, 4885, 8554, 25625, 7344, 5, 5043, 8, 16, 149, 2850, 214, 784, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

There are more examples of interesting augmentation [here](https://anhquan0412.github.io/that-nlp-library/text_augmentation.html)

## Tokenize Function

In [None]:
#| export
def tokenize_function(examples:dict,
                      tok,
                      text_name,
                      max_length=None,
                      is_split_into_words=False):
    """Tokenize `examples[text_name]` with `tok`, picking the padding strategy from `max_length`.

    - `max_length is None`: pad every sequence to the tokenizer's default max length.
    - positive int: pad to the longest sequence in the batch, truncating at `max_length`.
    - anything else: no padding (truncation still uses the tokenizer's default max length).
    Truncation is always enabled; `is_split_into_words` is forwarded to `tok`.
    """
    texts = examples[text_name]
    shared = dict(truncation=True, is_split_into_words=is_split_into_words)
    if max_length is None:
        return tok(texts, padding="max_length", **shared)
    if not (isinstance(max_length,int) and max_length>0):
        return tok(texts, **shared)
    return tok(texts, padding=True, max_length=max_length, **shared)

In [None]:
# render the nbdev API documentation for tokenize_function
show_doc(tokenize_function)

---

[source](https://github.com/anhquan0412/that-nlp-library/blob/main/that_nlp_library/text_main.py#L58){target="_blank" style="float:right; font-size:smaller"}

### tokenize_function

>      tokenize_function (examples:dict, tok, text_name, max_length=None,
>                         is_split_into_words=False)

Since I am processing Vietnamese text, I will use EnViBert's tokenizer. Envibert is a RoBERTa model for Vietnamese and English. This RoBERTa version is trained by using 100GB of text (50GB of Vietnamese and 50GB of English). For more information: [https://huggingface.co/nguyenvulebinh/envibert](https://huggingface.co/nguyenvulebinh/envibert)

In [None]:
# https://huggingface.co/nguyenvulebinh/envibert
# Load the RobertaTokenizer class defined in ./envibert_tokenizer/envibert_tokenizer.py
# and instantiate it from that same folder
cache_dir=Path('./envibert_tokenizer')
tokenizer = SourceFileLoader("envibert.tokenizer", 
                             str(cache_dir/'envibert_tokenizer.py')).load_module().RobertaTokenizer(cache_dir)

In [None]:
# a small batch in the {column: list_of_values} layout; tokenize_function reads examples['text']
examples={
    'text':[
         'hội cư dân chung cư sen hồng - chung cư lotus sóng thần thủ đức',
         'This is the recommended way to make a Python package importable from anywhere',
         'hội cần mở thẻ tín dụng tại hà nội, đà nẵng, tp. hồ chí minh',
         "biti's cao lãnh - đồng tháp",
         'chợ phòng trọ + việc làm...khu lĩnh nam - vĩnh hưng - mai động (hoàng mai)'
          ],
}

In [None]:
# positive max_length: pad to the longest sentence in the batch, truncating at 512 tokens
results = tokenize_function(examples,tokenizer,text_name='text',max_length=512)

In [None]:
# display the padded batch encoding
results

{'input_ids': [[0, 227, 1033, 191, 664, 1033, 7366, 2615, 13, 664, 1033, 671, 1355, 2294, 993, 413, 2900, 2, 1, 1, 1, 1, 1, 1, 1], [0, 116, 14, 6, 3169, 270, 9, 364, 10, 23963, 5360, 15930, 2003, 51, 5906, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 227, 256, 778, 2600, 1074, 144, 76, 5489, 613, 57339, 4820, 27666, 57339, 21422, 244, 872, 635, 841, 2, 1, 1, 1, 1, 1], [0, 880, 592, 427, 162, 171, 906, 13, 122, 6553, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 2299, 315, 5995, 1349, 99, 83, 55025, 244, 6356, 1114, 1213, 1163, 13, 8233, 11051, 13, 3335, 109, 28, 11695, 13377, 3335, 3, 2]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'att

In [None]:
# map the first sentence's ids back to tokens, trailing <pad> tokens included
print(tokenizer.convert_ids_to_tokens(results['input_ids'][0]))

['<s>', '▁hội', '▁cư', '▁dân', '▁chung', '▁cư', '▁sen', '▁hồng', '▁-', '▁chung', '▁cư', '▁lot', 'us', '▁sóng', '▁thần', '▁thủ', '▁đức', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


You can change `max_length` (sentences longer than `max_length` tokens are truncated)

In [None]:
# a small max_length truncates every sentence to 5 tokens, counting <s> and </s> (see output)
results = tokenize_function(examples,tokenizer,text_name='text',max_length=5)

In [None]:
# display the truncated batch encoding
results

{'input_ids': [[0, 227, 1033, 191, 2], [0, 116, 14, 6, 2], [0, 227, 256, 778, 2], [0, 880, 592, 427, 2], [0, 2299, 315, 5995, 2]], 'token_type_ids': [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}

## Metadatas Processing 

In [None]:
#| export
def concat_metadatas(dset:dict, # A batch (dict of column -> list) or a single row from a HuggingFace Dataset
                     main_text, # Text feature name
                     metadatas, # Metadata (or a list of metadatas)
                     process_metas=True, # Whether apply simple metadata processing, i.e. space strip and lowercase
                     sep='.', # separator for concatenating to main_text
                     is_batched=True, # whether batching is applied
                    ):
    """
    Extract, process (optional) and concatenate metadatas to the front of text

    Each metadata value is prepended as `f'{meta} {sep} {text}'`. With several metadatas,
    they are prepended one after another, so the last one in `metadatas` ends up at the very front.
    Returns a new dict holding the updated `main_text` plus every (optionally processed) metadata column;
    `dset` itself is not modified.
    """
    # A single metadata name is accepted too; without this wrap a string would be
    # iterated character by character and each character looked up as a column
    if isinstance(metadatas,str):
        metadatas = [metadatas]

    def _clean(v):
        # just strip and lowercase, after `nan2emptystr` (presumably maps missing values to '')
        return nan2emptystr(v).strip().lower()

    results={main_text:dset[main_text]}
    for m in metadatas:
        m_data = dset[m]
        if process_metas:
            m_data = [_clean(v) for v in m_data] if is_batched else _clean(m_data)
        results[m]=m_data
        if is_batched:
            results[main_text] = [f'{meta} {sep} {txt}' for meta,txt in zip(m_data,results[main_text])]
        else:
            results[main_text] = f'{m_data} {sep} {results[main_text]}'
    return results

In [None]:
# render the nbdev API documentation for concat_metadatas
show_doc(concat_metadatas)

---

### concat_metadatas

>      concat_metadatas (dset:dict, main_text, metadatas, process_metas=True,
>                        sep='.', is_batched=True)

Extract, process (optional) and concatenate metadatas to the front of text

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| dset | dict |  | HuggingFace Dataset |
| main_text |  |  | Text feature name |
| metadatas |  |  | Metadata (or a list of metadatas) |
| process_metas | bool | True | Whether apply simple metadata processing, i.e. space strip and lowercase |
| sep | str | . | separator for contatenating to main_text |
| is_batched | bool | True | whether batching is applied |

## Class TextDataController

In [None]:
#| export
class TextDataController():
    def __init__(self,
                 inp, # HuggingFace Dataset or DatasetDict
                 main_text:str, # Name of the main text column
                 label_names=None, # Names of the label (dependent variable) columns
                 class_names_predefined=None, # List of class names for each label column (they are sorted before being mapped to indices)
                 filter_dict={}, # A dictionary: {feature: filtering_function_based_on_the_feature}
                 metadatas=[], # Names of the metadata columns
                 process_metas=True, # Whether to do simple text processing on the chosen metadatas
                 content_transformations=[], # A list of text transformations
                 val_ratio:list|float|None=0.2, # Ratio of data for validation set. If given a list, validation set will be chosen based on indices in this list
                 stratify_cols=[], # Column(s) needed to do stratified shuffle split
                 upsampling_dict={}, # A dictionary: {feature: upsampling_function_based_on_the_feature}
                 content_augmentations=[], # A list of text augmentations
                 seed=None, # Random seed
                 is_batched=True, # Whether to perform operations in batch
                 batch_size=1000, # Batch size, for when is_batched is True
                 num_proc=4, # Number of process for multiprocessing
                 cols_to_keep=None, # Columns to keep after all processings
                 buffer_size=10000, # For shuffling data
                 num_shards=64, # Number of shards
                ):
        "Store the processing configuration and separate the training data from any other splits in `inp`"
        self.main_text = main_text
        self.metadatas = val2iterable(metadatas)
        self.process_metas = process_metas
        self.label_names = val2iterable(label_names) if label_names is not None else None
        self.label_lists = class_names_predefined
        self.filter_dict = filter_dict
        self.content_tfms = val2iterable(content_transformations)
        self.upsampling_dict = upsampling_dict
        self.aug_tfms = val2iterable(content_augmentations)
        self.val_ratio = val_ratio
        self.stratify_cols = val2iterable(stratify_cols)
        self.seed = seed
        self.is_batched = is_batched
        self.batch_size = batch_size
        self.num_proc = num_proc
        self.is_streamed = False
        self.cols_to_keep = cols_to_keep
        self.buffer_size = buffer_size
        self.num_shards = num_shards
        self.ddict_rest = DatasetDict() # non-train splits; stays empty when a single dataset is given
        
        if hasattr(inp,'keys'): # is datasetdict
            if 'train' not in inp.keys():
                raise ValueError('The given DatasetDict has no "train" split')
            # Shallow-copy (same container type) before popping 'train', so the caller's
            # DatasetDict is left intact; the validation split may also be popped from
            # self.ddict_rest later on
            self.ddict_rest = type(inp)(inp)
            self.dset = self.ddict_rest.pop('train')
        else: # is dataset
            self.dset = inp
        if isinstance(self.dset,IterableDataset):
            self.is_streamed=True
        self.all_cols = self.dset.column_names
        
        # with streamed data, classes cannot be learned by fitting an encoder on a full label column
        if self.is_streamed and self.label_names is not None and self.label_lists is None:
            raise ValueError('All class labels must be provided when streaming')
        
        if self.is_streamed and len(self.upsampling_dict):
            warnings.warn("Upsampling requires dataset concatenation, which can be extremely slow (x2) for streamed dataset")
            
        self._processed_call=False
        self.is_multilabel=False
        self.is_multihead=False
        
            
    @classmethod
    def from_csv(cls,file_path,**kwargs):
        "Load the CSV at `file_path` as the 'train' split of a HuggingFace Dataset and build a controller from it"
        file_path = Path(file_path)
        ds = load_dataset(str(file_path.parent),
                          data_files=file_path.name,
                          split='train')
        # `cls` rather than the hard-coded class name, so subclasses get instances of themselves
        return cls(ds,**kwargs)
        
    
    @classmethod
    def from_df(cls,df,validate=True,**kwargs):
        "Build a controller from a pandas DataFrame, optionally running `check_input_validation` on it first"
        if validate:
            check_input_validation(df)
        ds = Dataset.from_pandas(df)
        # `cls` rather than the hard-coded class name, so subclasses get instances of themselves
        return cls(ds,**kwargs)
    
    
    def _map_dset(self,dset,func):
        if self.is_streamed:
            return dset.map(func,
                            batched=self.is_batched,
                            batch_size=self.batch_size
                           )
        return dset.map(func,
                        batched=self.is_batched,
                        batch_size=self.batch_size,
                        num_proc=self.num_proc
                       )
    
    def _filter_dset(self,dset,func):
        if self.is_streamed:
            return dset.filter(func,
                            batched=self.is_batched,
                            batch_size=self.batch_size
                           )
        return dset.filter(func,
                        batched=self.is_batched,
                        batch_size=self.batch_size,
                        num_proc=self.num_proc
                       )
                     
    def validate_input(self):
        if self.is_streamed:
            print('Input validation check is disabled when data is streamed')
            return
        _df = self.dset.to_pandas()
        check_input_validation(_df)
    
    
    
    def save_as_pickles(self,
                        fname, # Name of the pickle file
                        parent='pickle_files', # Parent folder
                        drop_data_attributes=False # Whether to drop all large-size data attributes
                       ):
        "Pickle this controller via `save_to_pickle`, optionally deleting `main_ddict` first to shrink the file"
        if drop_data_attributes and hasattr(self, 'main_ddict'):
            # note: this permanently removes the processed data from this object
            del self.main_ddict
        save_to_pickle(self,fname,parent=parent)
    
        
    def _check_validation_leaking(self):
        if self.val_ratio is None or self.is_streamed:
            return
        
        trn_txt = self.main_ddict['train'][self.main_text]
        val_txt = self.main_ddict['validation'][self.main_text]        
        val_txt_leaked = check_text_leaking(trn_txt,val_txt)
        
        if len(val_txt_leaked)==0: return
        
        # filter train dataset to get rid of leaks
        print('Filtering leaked data out of training set...')
        _func = partial(lambda_batch,
                        feature=self.main_text,
                        func=lambda x: x.strip().lower() not in val_txt_leaked,
                        is_batched=self.is_batched)
        self.main_ddict['train'] = self._filter_dset(self.main_ddict['train'],_func)   
        print('Done')
           
    def _train_test_split(self):
        """Build `self.main_ddict` with a 'train' split and, when applicable, a 'validation' split.

        Priority: an existing 'val'/'validation'/'valid' split in `self.ddict_rest`; otherwise no split
        when `val_ratio` is None; index-based split when `val_ratio` is a list/array; random split for a
        numeric `val_ratio`; stratified split when `stratify_cols` is set. `self.dset` is deleted afterwards.
        """
        print_msg('Train Test Split',20)
        val_key = list(set(self.ddict_rest.keys()) & set(['val','validation','valid']))
        if len(val_key)==1: # val split exists: use it as the validation set
            self.main_ddict=DatasetDict({'train':self.dset,
                                         'validation':self.ddict_rest.pop(val_key[0])})

        elif self.val_ratio is None: # use all data
            self.main_ddict=DatasetDict({'train':self.dset})
        
        elif isinstance(self.val_ratio,(list,np.ndarray)): # filter with indices
            if self.is_streamed: raise ValueError('Data streaming does not support validation set filtering using indices')
            val_idxs = list(self.val_ratio)
            # Keep the training rows in their original order; iterating a set difference
            # (the previous approach) does not preserve any particular order
            val_idx_set = set(val_idxs)
            trn_idxs = [i for i in range(len(self.dset)) if i not in val_idx_set]
            self.main_ddict=DatasetDict({'train':self.dset.select(trn_idxs),
                                         'validation':self.dset.select(val_idxs)})
            
        elif isinstance(self.val_ratio,(float,int)) and not len(self.stratify_cols):
            if self.is_streamed:
                # shuffle dataset before splitting it
                self.dset = self.dset.shuffle(seed=self.seed,buffer_size=self.buffer_size)
                if isinstance(self.val_ratio,float):
                    warnings.warn("Length of streamed dataset is unknown to use float validation ratio. Default to 5000 data points for validation")
                    self.val_ratio=5000 # NOTE: overwrites the configured ratio on this object
                    
                # the first `val_ratio` items become an in-memory validation Dataset;
                # the remaining items stay streamed for training
                trn_dset = self.dset.skip(self.val_ratio)
                val_datas = list(self.dset.take(self.val_ratio))
                val_dict={k: [v[k] for v in val_datas] for k in val_datas[0].keys()}
                val_dset = Dataset.from_dict(val_dict)
                self.main_ddict=DatasetDict({'train':trn_dset,
                                             'validation':val_dset})
            else:
                # train val split
                self.main_ddict = self.dset.train_test_split(test_size=self.val_ratio,shuffle=True,seed=self.seed)
                self.main_ddict['validation']=self.main_ddict['test']
                del self.main_ddict['test']
        
        else: # val_ratio split with stratifying
            if self.is_streamed: raise ValueError('Stratified split is not supported for streamed data')
            if self.is_multilabel and self.label_names[0] in self.stratify_cols:
                raise ValueError('For MultiLabel classification, you cannot choose the label as your stratified column')
            
            # Create a new feature 'stratified', which is a concatenation of values in stratify_cols
            if self.is_batched:
                stratified_creation = lambda x: {'stratified':
                                     ['_'.join(list(map(str,[x[v][i] for v in self.stratify_cols]))) 
                                      for i in range(len(x[self.stratify_cols[0]]))]}
            else:
                stratified_creation = lambda x: {'stratified':
                                     '_'.join(list(map(str,[x[v] for v in self.stratify_cols]))) 
                                      }
            self.dset = self.dset.map(stratified_creation,
                                      batched=self.is_batched,
                                      batch_size=self.batch_size,
                                      num_proc=self.num_proc)
            # class-encode the combined key so it can drive `stratify_by_column`
            self.dset=self.dset.class_encode_column("stratified")
            # train val split
            self.main_ddict = self.dset.train_test_split(test_size=self.val_ratio,
                                                         shuffle=True,seed=self.seed,
                                                        stratify_by_column='stratified')
            self.main_ddict['validation']=self.main_ddict['test']
            del self.main_ddict['test']
            self.main_ddict=self.main_ddict.remove_columns(['stratified'])
            
        del self.dset
        print('Done')

                             
    def _create_label_mapping_func(self,encoder_classes):
        if self.is_multihead:
            label2idxs = [{v:i for i,v in enumerate(l_classes)} for l_classes in encoder_classes]
                    
            _func = lambda inp: {'label': [[label2idxs[i][v] for i,v in enumerate(vs)] for vs in zip(*[inp[l] for l in self.label_names])] \
                                    if self.is_batched else [label2idxs[i][v] for i,v in enumerate([inp[l] for l in self.label_names])]
                              }
            
        else:
            label2idx = {v:i for i,v in enumerate(encoder_classes[0])}
            _func = partial(lambda_map_batch,
                           feature=self.label_names[0],
                           func=lambda x: label2idx[x],
                           output_feature='label',
                           is_batched=self.is_batched)
        return _func
        
    def _encode_labels(self):
        """Encode the label column(s) into a new 'label' feature on the training data and on a
        user-provided validation split (if any), then store the class lists in `self.label_lists`.

        Single/multi-head labels become class indices; a list-valued first label switches to
        multi-label mode, where the first label column becomes a multi-hot vector.
        Raises ValueError when `self.ddict_rest` holds more than one validation split.
        """
        print_msg('Label Encoding')
        if len(self.label_names)>1:
            self.is_multihead=True
        
        # accept a flat list of class names (single label column) by wrapping it
        if self.label_lists is not None and not isinstance(self.label_lists[0],list):
            self.label_lists = [self.label_lists]
        
        # get label of first row; index the row first so the whole label column is not materialized
        first_label = self.dset[0][self.label_names[0]] if not self.is_streamed else next(iter(self.dset))[self.label_names[0]]
        if isinstance(first_label,list):
            # This is multi-label. Ignore self.label_names[1:]
            self.label_names = [self.label_names[0]]
            self.is_multihead=False
            self.is_multilabel=True

        def _map_validation_split(func,err_msg):
            # apply the same mapping to the user-provided validation split, if there is exactly one
            val_key = list(set(self.ddict_rest.keys()) & set(['val','validation','valid']))
            if len(val_key)>1: raise ValueError(err_msg)
            if len(val_key)==1:
                val_key=val_key[0]
                self.ddict_rest[val_key] = self._map_dset(self.ddict_rest[val_key],func)
            
        encoder_classes=[]
        if not self.is_multilabel:
            for idx,l in enumerate(self.label_names):
                if self.label_lists is None:
                    # learn the classes from the training data
                    l_encoder = LabelEncoder()
                    _ = l_encoder.fit(self.dset[l])
                    l_classes = list(l_encoder.classes_)
                else:
                    # predefined class names are sorted, so class indices follow sorted order
                    l_classes = sorted(list(self.label_lists[idx]))
                encoder_classes.append(l_classes)
            
            _func = self._create_label_mapping_func(encoder_classes)
            self.dset = self._map_dset(self.dset,_func)
            _map_validation_split(_func,'Your DatasetDict has more than 1 validation split')
                    
        else:
            # For MultiLabel, we transform the label itself to one-hot (or actually, few-hot)
            if self.label_lists is None:
                l_encoder = MultiLabelBinarizer()
                _ = l_encoder.fit(self.dset[self.label_names[0]])
                l_classes = list(l_encoder.classes_)
            else:
                l_classes = sorted(list(self.label_lists[0]))
            
            encoder_classes.append(l_classes)
            
            # with `classes` given explicitly, the binarizer does not need data to be fitted
            l_encoder = MultiLabelBinarizer(classes=encoder_classes[0])
            _ = l_encoder.fit(None)
            _func = partial(lambda_map_batch,
                            feature=self.label_names[0],
                            func=lambda x: l_encoder.transform(x),
                            output_feature='label',
                            is_batched=self.is_batched,
                            is_func_batched=True)
            self.dset = self._map_dset(self.dset,_func)
            _map_validation_split(_func,'Your DatasetDict has more than 1 validation dataset')
            
        self.label_lists = encoder_classes
        print('Done')
        
    def _process_metadatas(self,dset,ddict_rest=None):
        if len(self.metadatas)>0:
            print_msg('Metadata Simple Processing & Concatenating to Main Content')
            map_func = partial(concat_metadatas,
                               main_text=self.main_text,
                               metadatas=self.metadatas,
                               process_metas=self.process_metas,
                               is_batched=self.is_batched)
            dset = self._map_dset(dset,map_func)
            if ddict_rest is not None:
                ddict_rest = self._map_dset(ddict_rest,map_func)
            print('Done')
        return dset if ddict_rest is None else (dset,ddict_rest)
            
            
    
    def _simplify_ddict(self):
        """Drop unused columns from the train split and, when present, the validation split.

        If `self.cols_to_keep` is None it defaults to the main text column, the metadata
        columns and (when set) the label columns. Only columns listed in `self.all_cols`
        are candidates for removal.
        """
        print_msg('Dropping unused features',20)
        if self.cols_to_keep is None:
            keep = [self.main_text] + self.metadatas
            if self.label_names is not None:
                keep += self.label_names
            self.cols_to_keep = keep

        to_drop = list(set(self.all_cols) - set(self.cols_to_keep))
        splits = ['train']
        if 'validation' in self.main_ddict.keys():
            splits.append('validation')
        for split in splits:
            self.main_ddict[split] = self.main_ddict[split].remove_columns(to_drop)
        print('Done')
    
    def _do_transformation(self,dset,ddict_rest=None):
        if len(self.content_tfms):
            print_msg('Text Transformation',20)
            for tfm in self.content_tfms:
                print_msg(callable_name(tfm))
                _func = partial(lambda_map_batch,
                               feature=self.main_text,
                               func=tfm,
                               is_batched=self.is_batched)
                dset = self._map_dset(dset,_func)
                if ddict_rest is not None:
                    ddict_rest = self._map_dset(ddict_rest,_func)
            print('Done')
        return dset if ddict_rest is None else (dset,ddict_rest)
 
    def _do_filtering(self,dset,ddict_rest=None):
        if len(self.filter_dict):
            print_msg('Data Filtering',20)
            for f,tfm in self.filter_dict.items():
                print_msg(f'Do {callable_name(tfm)} on {f}')
                _func = partial(lambda_batch,
                                feature=f,
                                func=tfm,
                                is_batched=self.is_batched)
                dset = self._filter_dset(dset,_func)
                if ddict_rest is not None:
                    ddict_rest = self._filter_dset(ddict_rest,_func)
            print('Done')
        return dset if ddict_rest is None else (dset,ddict_rest)
    
    def _upsampling(self):
        if len(self.upsampling_dict):
            print_msg('Upsampling data',20)
            results=[]
            for f,tfm in self.upsampling_dict.items():
                print_msg(f'Do {callable_name(tfm)} on {f}')
                _func = partial(lambda_batch,
                                feature=f,
                                func=tfm,
                                is_batched=self.is_batched)
                new_dset = self._filter_dset(self.main_ddict['train'],_func)
                results.append(new_dset)
            # slow concatenation for iterable dataset    
            self.main_ddict['train'] = concatenate_datasets(results+[self.main_ddict['train']])
            print('Done')
      
    def _do_augmentation(self):
        
        if len(self.aug_tfms):
            print_msg('Text Augmentation',20)

            seed_notorch(self.seed)
            if not self.is_streamed:  
#                 self.main_ddict['train'] = self.main_ddict['train'].with_transform(partial(augmentation_helper,
#                                                                        text_name=self.main_text,
#                                                                        func=partial(func_all,functions=self.aug_tfms)))              
                for tfm in self.aug_tfms:
                    print_msg(callable_name(tfm))
                    _func = partial(lambda_map_batch,
                                   feature=self.main_text,
                                   func=tfm,
                                   is_batched=self.is_batched)
                    self.main_ddict['train'] = self._map_dset(self.main_ddict['train'],_func)

            else:
                self.main_ddict['train'] = IterableDataset.from_generator(augmentation_stream_generator,
                                               features = self.main_ddict['train'].features,
                                               gen_kwargs={'dset': self.main_ddict['train'],
                                                           'text_name':self.main_text,
                                                           'func':partial(func_all,functions=self.aug_tfms)
                                                          })
            print('Done')
        
    def _convert_to_iterable(self):
        if not self.is_streamed:
            self.main_ddict['train'] = self.main_ddict['train'].to_iterable_dataset(num_shards=self.num_shards)
            self.is_streamed=True
            
    def _do_train_shuffling(self):
        "Shuffle the train split using `self.seed` and a shuffle buffer of `self.buffer_size` rows"
        print_msg('Shuffling train set',20)
        train_split = self.main_ddict['train']
        self.main_ddict['train'] = train_split.shuffle(seed=self.seed, buffer_size=self.buffer_size)
        print('Done')
        
    def do_all_preprocessing(self, shuffle_trn=True):
        """Run the whole preprocessing pipeline once and return `self.main_ddict`.

        Order: filtering -> metadata concatenation -> label encoding (only when
        `label_names` is set) -> content transformations -> train/validation split
        (creates `self.main_ddict`) -> dropping unused columns -> validation-leak check
        -> upsampling -> augmentation -> conversion of train to an iterable dataset ->
        optional train shuffling (`shuffle_trn`).

        Calling it again does not re-process: it warns and returns the stored DatasetDict.
        """
        if self._processed_call:
            warnings.warn('Your dataset has already been processed. Returning the previous processed DatasetDict...')
            return self.main_ddict

        print_msg('Start Main Text Processing',20)

        # Stages applied to the full dataset and to the remaining splits
        self.dset, self.ddict_rest = self._do_filtering(self.dset, self.ddict_rest)
        self.dset, self.ddict_rest = self._process_metadatas(self.dset, self.ddict_rest)
        if self.label_names is not None:
            self._encode_labels()
        self.dset, self.ddict_rest = self._do_transformation(self.dset, self.ddict_rest)

        # _train_test_split builds self.main_ddict; upsampling, augmentation and the
        # iterable conversion afterwards only touch the train split
        for stage in (self._train_test_split,
                      self._simplify_ddict,
                      self._check_validation_leaking,
                      self._upsampling,
                      self._do_augmentation,
                      self._convert_to_iterable):
            stage()

        if shuffle_trn:
            self._do_train_shuffling()

        self._processed_call = True
        return self.main_ddict
    
        
    
    
    def do_tokenization(self,
                       tokenizer, # Tokenizer (preferably from HuggingFace)
                       is_split_into_words=False, # Is text split into words or not
                       max_length=None, # pad to model's allowed max length (default is max_sequence_length)
                       trn_size=None, # The number of training data to be tokenized
                      ):
        """Tokenize the main text column of every split in `self.main_ddict` and return it.

        The tokenizer settings are stored on the instance (`self.tokenizer`,
        `self.is_split_into_words`, `self.max_length`) so that `prepare_test_dataset`
        can reuse them. If `trn_size` is given, only the first `trn_size` training rows
        are kept (via `.take`) before tokenizing.

        Raises:
            ValueError: if `trn_size` is smaller than 1. A fractional ratio cannot be
                applied because the length of a streamed train split is unknown.
        """
        # Validate first so that a bad argument leaves the controller untouched
        if trn_size is not None and trn_size<1:
            raise ValueError("Length of streamed dataset is unknown to use floating ratio")

        print_msg('Tokenization',20)
        self.tokenizer = tokenizer
        self.is_split_into_words= is_split_into_words
        self.max_length = max_length
        if trn_size is not None:
            # .take expects an integer row count
            self.main_ddict['train'] = self.main_ddict['train'].take(int(trn_size))

        # The same tokenization function is used for every split
        tok_func = partial(tokenize_function,
                           text_name=self.main_text,
                           tok=tokenizer,
                           is_split_into_words=is_split_into_words,
                           max_length=max_length)
        for k in self.main_ddict.keys():
            self.main_ddict[k] = self.main_ddict[k].map(tok_func,
                                                        batched=True, # always true
                                                        batch_size=self.batch_size
                                                       )
        print('Done')
        return self.main_ddict
        
    def process_and_tokenize(self,
                             tokenizer, # Tokenizer (preferably from HuggingFace)
                             is_split_into_words=False, # Is text split into list or not
                             max_length=None, # pad to model's allowed max length (default is max_sequence_length)
                             trn_size=None, # The number of training data to be tokenized
                             shuffle_trn=True, # To shuffle the train set before tokenization
                            ):
        """Run `do_all_preprocessing` and then `do_tokenization`.

        Returns None; the result is available in `self.main_ddict`.
        """
        self.do_all_preprocessing(shuffle_trn)
        self.do_tokenization(tokenizer, is_split_into_words, max_length, trn_size)
        
    
    def set_data_collator(self, data_collator):
        "Store `data_collator` on the controller as `self.data_collator`"
        setattr(self, 'data_collator', data_collator)
        
    
    def prepare_test_dataset_from_csv(self, file_path):
        """Load one csv file as a HuggingFace Dataset and pass it through `prepare_test_dataset`.

        The csv is loaded via `load_dataset(<parent dir>, data_files=<file name>, split='train')`.
        """
        csv_path = Path(file_path)
        loaded = load_dataset(str(csv_path.parent),
                              data_files=csv_path.name,
                              split='train')
        return self.prepare_test_dataset(loaded)
    
    def prepare_test_dataset_from_df(self, df, validate=True):
        """Convert a pandas DataFrame to a Dataset and pass it through `prepare_test_dataset`.

        When `validate` is True, `check_input_validation(df)` runs first.
        """
        if validate: check_input_validation(df)
        return self.prepare_test_dataset(Dataset.from_pandas(df))
        
    def prepare_test_dataset(self,test_dset,do_filtering=False):
        """Run the training-time text pipeline on a test Dataset, then tokenize it.

        Steps: optional filtering (`do_filtering=True`), metadata concatenation, content
        transformations, then tokenization with the tokenizer settings stored by
        `do_tokenization`. Label encoding, splitting and augmentation are not applied.

        Raises:
            ValueError: if no tokenizer has been registered yet (`do_tokenization` or
                `process_and_tokenize` has not been called).
        Returns the processed, tokenized test Dataset.
        """
        # Tokenizer settings are stored on the instance by do_tokenization
        tokenizer = getattr(self,'tokenizer',None)
        if tokenizer is None:
            raise ValueError('No tokenizer found. Call `do_tokenization` (or `process_and_tokenize`) before preparing a test set')
        is_split_into_words = getattr(self,'is_split_into_words',False)
        max_length = getattr(self,'max_length',None)

        print_msg('Start Test Set Transformation',20)

        # Filtering
        if do_filtering:
            test_dset = self._do_filtering(test_dset)

        # Process metadatas
        test_dset = self._process_metadatas(test_dset)

        # Content transformation
        test_dset = self._do_transformation(test_dset)

        # Tokenization
        print_msg('Tokenization',20)
        test_dset = test_dset.map(partial(tokenize_function,
                                          text_name=self.main_text,
                                          tok=tokenizer,
                                          is_split_into_words=is_split_into_words,
                                          max_length=max_length),
                                  batched=True, # always true
                                  batch_size=self.batch_size
                                 )
        print('Done')
        return test_dset


In [None]:
# Render the signature and parameter table of TextDataController
show_doc(TextDataController)

---

### TextDataController

>      TextDataController (inp, main_text:str, label_names=None,
>                          class_names_predefined=None, filter_dict={},
>                          metadatas=[], process_metas=True,
>                          content_transformations=[],
>                          val_ratio:list|float|None=0.2, stratify_cols=[],
>                          upsampling_dict={}, content_augmentations=[],
>                          seed=None, is_batched=True, batch_size=1000,
>                          num_proc=4, cols_to_keep=None, buffer_size=10000,
>                          num_shards=64)

Initialize self.  See help(type(self)) for accurate signature.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| inp |  |  | HuggingFace Dataset or DatasetDict |
| main_text | str |  | Name of the main text column |
| label_names | NoneType | None | Names of the label (dependent variable) columns |
| class_names_predefined | NoneType | None | List of names associated with the labels (same index order) |
| filter_dict | dict | {} | A dictionary: {feature: filtering_function_based_on_the_feature} |
| metadatas | list | [] | Names of the metadata columns |
| process_metas | bool | True | Whether to do simple text processing on the chosen metadatas |
| content_transformations | list | [] | A list of text transformations |
| val_ratio | list \| float \| None | 0.2 | Ratio of data for validation set. If given a list, validation set will be chosen based on indices in this list |
| stratify_cols | list | [] | Column(s) needed to do stratified shuffle split |
| upsampling_dict | dict | {} | A dictionary: {feature: upsampling_function_based_on_the_feature} |
| content_augmentations | list | [] | A list of text augmentations |
| seed | NoneType | None | Random seed |
| is_batched | bool | True | Whether to perform operations in batch |
| batch_size | int | 1000 | Batch size, for when is_batched is True |
| num_proc | int | 4 | Number of process for multiprocessing |
| cols_to_keep | NoneType | None | Columns to keep after all processings |
| buffer_size | int | 10000 | For shuffling data |
| num_shards | int | 64 | Number of shards |

## Load data + Basic use case

You can create a `TextDataController` from a csv, pandas DataFrame, or directly from a HuggingFace dataset object. Currently, `TextDataController` is designed for text classification, so you must provide the column name for the label (or multi-label)

We will load a sample dataset of women's clothing reviews. The task is to determine which department (`Department Name`) a review (`Review Text`) belongs to

Dataset source: https://www.kaggle.com/datasets/kavita5/review_ecommerce

In [None]:
import pandas as pd

In [None]:
# Load the raw clothing-review data (utf-8-sig also strips a leading BOM if present)
df = pd.read_csv('sample_data/Womens_Clothing_Reviews.csv',encoding='utf-8-sig')

In [None]:
# (rows, columns) of the raw data
df.shape

(23486, 10)

In [None]:
# Peek at 5 random rows
# NOTE(review): no random_state is given, so the displayed rows change on every run
df.sample(5)

Unnamed: 0,Clothing ID,Age,Title,Review Text,Rating,Recommended IND,Positive Feedback Count,Division Name,Department Name,Class Name
1711,970,37,Soft and flattering,"This is a cute work jacket, as well as paired ...",5,1,1,General,Jackets,Jackets
1870,1080,33,Gorgeous!,I feel like an indian princess in this dress! ...,5,1,0,General,Dresses,Dresses
1058,873,42,Love the color!,I love this shirt so much i am ordering the co...,5,1,3,General,Tops,Knits
17773,819,67,,"Loved the fit and colors, but the fabric is ve...",2,0,18,General,Tops,Blouses
10683,1083,32,Showstopper!,This dress has been on retailer's site for a w...,5,1,0,General,Dresses,Dresses


You can create a `TextDataController` from a DataFrame. This also runs a quick input validation check (a NaN check and a duplication check)

In [None]:
# Build a controller from the DataFrame; from_df also prints a NaN / duplication precheck
tdc = TextDataController.from_df(df,
                                 main_text='Review Text',
                                 label_names='Department Name',
                                )

----- Input Validation Precheck -----
Data contains missing values!
-----> List of columns and the number of missing values for each
Title              3810
Review Text         845
Division Name        14
Department Name      14
Class Name           14
dtype: int64
Data contains duplicated values!
-----> Number of duplications: 21 rows


You can also create a `TextDataController` directly from a csv file. A benefit of using HuggingFace Datasets as the backend of `TextDataController` is that you can use many of its features, such as caching

In [None]:
# Build a controller straight from a csv file (loaded through HuggingFace datasets, hence the cache message)
tdc = TextDataController.from_csv('sample_data/Womens_Clothing_Reviews.csv',
                                  main_text='Review Text',
                                  label_names='Department Name',
                                 )

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-f893627565d98cd2/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


You can also create a `TextDataController` from a HuggingFace Dataset

In [None]:
# Load the same csv as a HuggingFace Dataset
dset = load_dataset('sample_data',data_files=['Womens_Clothing_Reviews.csv'],split='train')
dset

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-f893627565d98cd2/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


Dataset({
    features: ['Clothing ID', 'Age', 'Title', 'Review Text', 'Rating', 'Recommended IND', 'Positive Feedback Count', 'Division Name', 'Department Name', 'Class Name'],
    num_rows: 23486
})

In [None]:
# Build a controller directly from a HuggingFace Dataset; seed=42 is used by the shuffling/augmentation steps
tdc = TextDataController(dset,
                         main_text='Review Text',
                         label_names='Department Name',
                         seed=42
                        )

As we can see, our dataset has missing values in both the text field and the label field. For now, let's load the data as a pandas DataFrame, perform some cleaning, and create our `TextDataController`

In [None]:
# Reload the raw csv before cleaning
df = pd.read_csv('sample_data/Womens_Clothing_Reviews.csv',encoding='utf-8-sig')

In [None]:
# Keep only rows that have both a review text and a department label
df = df[(~df['Review Text'].isna()) & (~df['Department Name'].isna())].reset_index(drop=True)

In [None]:
# Rebuild the controller from the cleaned DataFrame
tdc = TextDataController.from_df(df,
                                 main_text='Review Text',
                                 label_names='Department Name',
                                )

----- Input Validation Precheck -----
Data contains missing values!
-----> List of columns and the number of missing values for each
Title    2966
dtype: int64
Data contains duplicated values!
-----> Number of duplications: 1 rows


At this point, you can perform two important steps on your data:

1. Text preprocessing, label encoding, and train/validation split
2. Tokenization

We haven't provided any preprocessing steps to the `TextDataController` yet; we will see how to use them step by step as we progress. In fact, we can even perform NaN filtering as a preprocessing step inside `TextDataController`

In [None]:
# Full preprocessing pipeline (label encoding, split, column dropping, leak check, ...), then shuffle train
ddict = tdc.do_all_preprocessing(shuffle_trn=True)

-------------------- Start Main Text Processing --------------------
----- Label Encoding -----


Map (num_proc=4):   0%|          | 0/22628 [00:00<?, ? examples/s]

Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 2, which is 0.01% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/18102 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done


In [None]:
# Train is an IterableDataset; validation is a regular Dataset
ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Review Text', 'Department Name', 'label'],
        num_rows: 4526
    })
})

Our DatasetDict now has two splits: train and validation. Note that the train split is now an IterableDataset, for processing efficiency

In [None]:
# First 3 validation rows (a regular Dataset supports slicing)
ddict['validation'][:3]

{'Review Text': ["Goes with absolutely everything. it's very comfortable and versatile. can be dressed up or dressed down. great thing to have in the closet when you don't know what to wear!",
  "This is a beautiful blouse...sheer and feminine. i am small busted and slender so i need a size smaller than usual. it is a full top...can't tell exactly how full in the photos but with a small chest there is just too much under the arms. so if your chest is more ample you could prob order your regular size. this is supposed to be a full, shorter fit...i would say the style is going to look better on someone who is a little taller with a medium sized bust rather than someone who is shorter and b",
  'I\'m a size 28 jeans 5\'9 with a bubble bottom. these were falling off of me around the hips. these are warm. i was expecting them to not be sheer since they are fleece lined but when i sat down i was shocked to notice it. it doesn\'t matter functionality wise because i never felt a "draft." they 

In [None]:
# The train split is an IterableDataset (no slicing), so iterate and stop after 3 rows
for i,v in enumerate(ddict['train']):
    print(v)
    if i==2: break

{'Review Text': 'I must agree with some of the other reviewers--this sweater is so pretty! but the quality of the sweater is not great. my sensitive skin found it to be itchy. so sadly i had to return it.', 'Department Name': 'Tops', 'label': 4}
{'Review Text': "This skirt is about an inch longer on me than on the model in photo (i'm 5 ft, ordered petite) but the length actually works. i have a bit of a tummy, so ordered size 4 - fits great. really love the colors - so many tops will go with this skirt, so it will be easy to vary the look.", 'Department Name': 'Bottoms', 'label': 0}
{'Review Text': 'I typically wear a 4/6 but am a little bigger right now, so i went with the medium. was so excited to get these but was very disappointed when i tried them on. they look beautiful, but there is no lining, leaving the fabric uncomfortable. they were also huge! definitely make sure you size down if you d', 'Department Name': 'Bottoms', 'label': 0}


Now we can start with the tokenization

In [None]:
from transformers import RobertaTokenizer

In [None]:
# Pretrained RoBERTa tokenizer (downloaded from the HuggingFace Hub on first use)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [None]:
# Tokenize every split; max_length=512 is forwarded to the tokenization function
ddict = tdc.do_tokenization(tokenizer,max_length=512)

-------------------- Tokenization --------------------


Map:   0%|          | 0/4526 [00:00<?, ? examples/s]

Done


In [None]:
# Tokenization adds input_ids / attention_mask features
ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Review Text', 'Department Name', 'label', 'input_ids', 'attention_mask'],
        num_rows: 4526
    })
})

In [None]:
# First 60 token ids of the first validation example
print(ddict['validation'][0]['input_ids'][:60])

[0, 534, 8013, 19, 3668, 960, 4, 24, 18, 182, 3473, 8, 16106, 4, 64, 28, 7001, 62, 50, 7001, 159, 4, 372, 631, 7, 33, 11, 5, 16198, 77, 47, 218, 75, 216, 99, 7, 3568, 328, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [None]:
# First 60 token ids of the first (shuffled) training example; train is iterable, hence next(iter(...))
print(next(iter(ddict['train']))['input_ids'][:60])

[0, 100, 531, 2854, 19, 103, 9, 5, 97, 34910, 5579, 9226, 23204, 16, 98, 1256, 328, 53, 5, 1318, 9, 5, 23204, 16, 45, 372, 4, 127, 5685, 3024, 303, 24, 7, 28, 24, 17414, 4, 98, 16748, 939, 56, 7, 671, 24, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In `TextDataController`, you can also perform Text Processing and Tokenization with one method

In [None]:
# Fresh controller for the one-call processing + tokenization demo
tdc = TextDataController.from_df(df,
                                 main_text='Review Text',
                                 label_names='Department Name',
                                )

----- Input Validation Precheck -----
Data contains missing values!
-----> List of columns and the number of missing values for each
Title    2966
dtype: int64
Data contains duplicated values!
-----> Number of duplications: 1 rows


In [None]:
# One call: do_all_preprocessing followed by do_tokenization
tdc.process_and_tokenize(tokenizer,max_length=512,shuffle_trn=True)

-------------------- Start Main Text Processing --------------------
----- Label Encoding -----


Map (num_proc=4):   0%|          | 0/22628 [00:00<?, ? examples/s]

Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 2, which is 0.01% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/18102 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Map:   0%|          | 0/4526 [00:00<?, ? examples/s]

Done


You can access the DatasetDict from the instance variable `main_ddict`

In [None]:
# The processed, tokenized DatasetDict lives on the controller as main_ddict
tdc.main_ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Review Text', 'Department Name', 'label', 'input_ids', 'attention_mask'],
        num_rows: 4526
    })
})

This DatasetDict is ready to be put into any HuggingFace text model.

## Filtering

This preprocessing step allows you to filter out certain values of a given column in your dataset. Let's say we want to filter out any 'HC search' value in the column 'Source'

In [None]:
# This section uses sample_large.csv; `df` holds the clothing-review data, which has no
# 'Source' column, so read the file explicitly (keeps the cell valid on a fresh kernel)
pd.read_csv('sample_data/sample_large.csv',encoding='utf-8-sig').Source.value_counts()

Source
Google Play    1434
Non Owned       499
Owned           139
iOS             124
HC search        73
Name: count, dtype: int64

We will provide a dictionary mapping a column name to the filtering function to apply to that column. Note that the preprocessing step automatically removes unused columns, so we need to provide a list of columns to keep (`cols_to_keep`)

In [None]:
# Drop every 'HC search' row; cols_to_keep keeps 'Source' around so the filter can be checked afterwards
dset = load_dataset('sample_data',data_files=['sample_large.csv'],split='train')
tdc = TextDataController(dset,
                         main_text='Content',
                         label_names='L1',
                         filter_dict={'Source':lambda x: x!='HC search'},
                         cols_to_keep=['Source','Content','L1'],
                         seed=42
                        )

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


In [None]:
# Filtering runs first, before label encoding and the train/validation split
tdc.process_and_tokenize(tokenizer,max_length=512,shuffle_trn=True)

Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-66d19029fdb1ba64_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-28b4f526c9790ebf_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-0dc37ad5a804f1d2.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-c37dbd859a7e31eb.arrow


-------------------- Start Main Text Processing --------------------
-------------------- Data Filtering --------------------
----- Do <lambda> on Source -----
Done
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 8, which is 0.46% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1756 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-0da2cd7fcb59203d.arrow


Done


In [None]:
# 'Source' survives column dropping because it is listed in cols_to_keep
tdc.main_ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Source', 'Content', 'L1', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 440
    })
})

We can check whether 'HC search' is still in our dataset

In [None]:
# 'HC search' should no longer appear in the validation split
set(tdc.main_ddict['validation']['Source'])

{'Google Play', 'Non Owned', 'Owned', 'iOS'}

In [None]:
# Same check on the (iterable) train split
set([v['Source'] for v in tdc.main_ddict['train']])

{'Google Play', 'Non Owned', 'Owned', 'iOS'}

We can even add multiple filtering functions

In [None]:
# `df` holds the clothing-review data, which has no 'L1' column; read sample_large.csv
# explicitly so the cell works on a fresh kernel
pd.read_csv('sample_data/sample_large.csv',encoding='utf-8-sig').L1.value_counts()

L1
Others                     811
Feature                    541
Commercial                 305
Delivery                   200
Shopee account             186
Buyer complained seller     64
Return/Refund               45
Payment                     44
Order/Item                  41
Services                    32
Name: count, dtype: int64

In [None]:
# Two filters: drop 'HC search' sources and the 'Order/Item' / 'Services' labels
dset = load_dataset('sample_data',data_files=['sample_large.csv'],split='train')
tdc = TextDataController(dset,
                         main_text='Content',
                         label_names='L1',
                         filter_dict={'Source':lambda x: x!='HC search',
                                      'L1': lambda x: x not in ['Order/Item','Services']
                                     },
                         cols_to_keep=['Source','Content','L1'],
                         seed=42
                        )

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


In [None]:
# The filters are applied one after another, in filter_dict order
tdc.process_and_tokenize(tokenizer,max_length=512,shuffle_trn=True)

Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-66d19029fdb1ba64_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-d9d0358cd1660021_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-024d2fb434ed2194_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-caa70b0e2d1ee8fa.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c74

-------------------- Start Main Text Processing --------------------
-------------------- Data Filtering --------------------
----- Do <lambda> on Source -----
----- Do <lambda> on L1 -----
Done
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 4, which is 0.24% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1700 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-de7a39d90b9efb49.arrow


Done


Since 'L1' is our label, we can access the label list to check whether our L1 filtering is correct.

In [None]:
# Class names found for 'L1'; the filtered-out classes should not appear here
tdc.label_lists

[['Buyer complained seller',
  'Commercial',
  'Delivery',
  'Feature',
  'Others',
  'Payment',
  'Return/Refund',
  'Shopee account']]

In [None]:
# Distinct 'Source' values left in validation; 'HC search' should be absent after filtering
set(tdc.main_ddict['validation']['Source'])

{'Google Play', 'Non Owned', 'Owned', 'iOS'}

## Metadata concatenation

If we think the metadata can be helpful, we can concatenate it to the front of our text, so that our text classification model is aware of it.

In this example, 'Source' will be our metadata.

In [None]:
dset = load_dataset('sample_data', data_files=['sample_large.csv'], split='train')
# Prepend the raw 'Source' value to 'Content' (no metadata preprocessing)
tdc = TextDataController(
    dset,
    main_text='Content',
    label_names='L1',
    metadatas='Source',
    process_metas=False,
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-1ce83717afece17c_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-4667f0b67eaedd6e_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-613c169dda05938f.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-904746b36977d169.arrow


-------------------- Start Main Text Processing --------------------
----- Metadata Simple Processing & Concatenating to Main Content -----
Done
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 7, which is 0.39% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1815 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-904b6d93e98f2824.arrow


Done


In [None]:
# Peek at the first (shuffled) training example: the Source value now precedes the text
next(iter(tdc.main_ddict['train']))['Content']

'Non Owned . Huawei mở đặt trước đồng hồ Watch Fit 2, Watch GT 3 Pro và Watch Kids Pro 4'

You can add multiple metadata columns. Let's say L2 is the second metadata column.

In [None]:
dset = load_dataset('sample_data', data_files=['sample_large.csv'], split='train')
# Two metadata columns; both values get prepended to 'Content'
tdc = TextDataController(
    dset,
    main_text='Content',
    label_names='L1',
    metadatas=['Source', 'L2'],
    process_metas=False,
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-02e842e167e8cd44_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-ac039bba33dde05e_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-496b3bea3d2b2904.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-476c2414740b2f39.arrow


-------------------- Start Main Text Processing --------------------
----- Metadata Simple Processing & Concatenating to Main Content -----
Done
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 5, which is 0.28% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1815 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-045a55c61ba6226a.arrow


Done


In [None]:
# Both metadata values now precede the original text
next(iter(tdc.main_ddict['train']))['Content']

'Process RR . Non Owned . :((( mình chưa nhận được hàng mà nó đã hiện được yêu cầu trả hàng hoàn tiền rồi là ntn ạ? Đơn mình đặt mà mình k có thông tin gì về nó luôn mng cho mình xin cách liên hệ với shopee với ạ? :((( chưa nhận được thì nó có cho mình hoàn tiền hay là nó chuyển vào mục review luôn rồi k ạ?'

If you want to preprocess the metadata (currently this is just whitespace stripping and lowercasing), set `process_metas` to `True`.

In [None]:
dset = load_dataset('sample_data', data_files=['sample_large.csv'], split='train')
# Same as before, but let the controller preprocess (strip + lowercase) the metadata
tdc = TextDataController(
    dset,
    main_text='Content',
    label_names='L1',
    metadatas='Source',
    process_metas=True,
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-65897bac5e39cbaa_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-ea0e555513f7aaf9_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-41c8d98a12fe2203.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-254380030a2a01ec.arrow


-------------------- Start Main Text Processing --------------------
----- Metadata Simple Processing & Concatenating to Main Content -----
Done
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 7, which is 0.39% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1815 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-434bc0967abeb3ec.arrow


Done


In [None]:
# The metadata prefix is now lowercased ('non owned')
next(iter(tdc.main_ddict['train']))['Content']

'non owned . Huawei mở đặt trước đồng hồ Watch Fit 2, Watch GT 3 Pro và Watch Kids Pro 4'

## Label Encodings

We have briefly gone through the simplest case of label encoding, which is when we only need to predict a single label (L1). In this library, this is called **single-head classification**.

In [None]:
dset = load_dataset('sample_data', data_files=['sample_large.csv'], split='train')
# Single-head classification: one label column, no metadata
tdc = TextDataController(
    dset,
    main_text='Content',
    label_names='L1',
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-056a3d88e5d7e7f1_*_of_00004.arrow
Loading cached split indices for dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-417be8257836e100.arrow and /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-1969cf23af4918bb.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-85a87410822b39a8_*_of_00004.arrow


-------------------- Start Main Text Processing --------------------
----- Label Encoding -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 7, which is 0.39% of training set
Filtering leaked data out of training set...
Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-f91699a587bf9b9a.arrow


Done


All label names will be saved in the instance variable `label_lists`.

In [None]:
# One list of class names for the single label column 'L1'
tdc.label_lists

[['Buyer complained seller',
  'Commercial',
  'Delivery',
  'Feature',
  'Order/Item',
  'Others',
  'Payment',
  'Return/Refund',
  'Services',
  'Shopee account']]

... and all labels will be encoded

In [None]:
# Encoded labels: integer positions into tdc.label_lists[0]
tdc.main_ddict['validation']['label'][:5]

[3, 5, 2, 5, 5]

We also keep the original labels for reference.

In [None]:
# The original string column is kept alongside the encoded 'label'
tdc.main_ddict['validation']['L1'][:5]

['Feature', 'Others', 'Delivery', 'Others', 'Others']

Let's say our task is no longer predicting a single label. What if we need to predict 2 different labels at once (this is called **multi-head classification**)? For example, let's define our dataset so that we need to predict both L1 and L2.

In [None]:
dset = load_dataset('sample_data', data_files=['sample_large.csv'], split='train')
# Multi-head classification: predict L1 and L2 at the same time
tdc = TextDataController(
    dset,
    main_text='Content',
    label_names=['L1', 'L2'],
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


-------------------- Start Main Text Processing --------------------
----- Label Encoding -----


Map (num_proc=4):   0%|          | 0/2269 [00:00<?, ? examples/s]

Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 7, which is 0.39% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1815 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Map:   0%|          | 0/454 [00:00<?, ? examples/s]

Done


In [None]:
# Number of classes for each head (L1, then L2)
len(tdc.label_lists[0]),len(tdc.label_lists[1])

(10, 56)

We can see that we have two lists, one for label names of L1, and one for label names of L2

In [None]:
# Train is an IterableDataset (shuffled); validation is a regular Dataset with the tokenized columns
tdc.main_ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Content', 'L1', 'L2', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 454
    })
})

In [None]:
# Original L1 names for the first validation rows
tdc.main_ddict['validation']['L1'][:5]

['Feature', 'Others', 'Delivery', 'Others', 'Others']

In [None]:
# Original L2 names for the same rows
tdc.main_ddict['validation']['L2'][:5]

['App performance', 'Cannot defined', 'Shipping fee', 'Cannot defined', 'Scam']

In [None]:
# Multi-head encoding: one [L1 id, L2 id] pair per row
tdc.main_ddict['validation']['label'][:5]

[[3, 2], [5, 8], [2, 47], [5, 8], [5, 41]]

Lastly, let's define a **multi-label classification** task, where a text can have one or more labels. In this example, we will combine L1 and L2 into a single label column containing multiple values.

In [None]:
# Load the same CSV with pandas to build a list-valued label column; 'utf-8-sig' drops a leading BOM if present
df = pd.read_csv('sample_data/sample_large.csv',encoding='utf-8-sig')

In [None]:
# Preview the raw frame
df.head()

Unnamed: 0,Source,Content,L1,L2
0,Google Play,"App ncc lúc nào cx lag đơ, phần tìm kiếm thì v...",Feature,App performance
1,Non Owned,..❗️ GÓC THANH LÝ Tính ra rẻ hơn cả mua #Shope...,Commercial,Items/price
2,Google Play,Mắc gì người ta đặt hàng toàn lỗi 😃????,Feature,App performance
3,Owned,#GhienShopeePayawardT8 Khi bạn chơi shopee quá...,Commercial,Shopee Programs
4,Google Play,Rất bức xúc khi dùng . mã giảm giá người dùng ...,Feature,Apply Voucher


In [None]:
# Combine L1 and L2 into one list-valued column so each row carries both labels
df = df.assign(L1L2=df[['L1', 'L2']].values.tolist())

In [None]:
# The separate L1/L2 columns are now redundant with 'L1L2'.
# Reassign instead of using inplace=True: it keeps the data flow explicit and chainable.
df = df.drop(columns=['L1', 'L2'])

In [None]:
# Check the combined 'L1L2' column
df.head()

Unnamed: 0,Source,Content,L1L2
0,Google Play,"App ncc lúc nào cx lag đơ, phần tìm kiếm thì v...","[Feature, App performance]"
1,Non Owned,..❗️ GÓC THANH LÝ Tính ra rẻ hơn cả mua #Shope...,"[Commercial, Items/price]"
2,Google Play,Mắc gì người ta đặt hàng toàn lỗi 😃????,"[Feature, App performance]"
3,Owned,#GhienShopeePayawardT8 Khi bạn chơi shopee quá...,"[Commercial, Shopee Programs]"
4,Google Play,Rất bức xúc khi dùng . mã giảm giá người dùng ...,"[Feature, Apply Voucher]"


You don't have to add any extra argument; the controller will determine whether this is for multilabel classification, based on the format of the label values.

In [None]:
# No extra flag is needed: list-valued labels in 'L1L2' are detected as multi-label
tdc = TextDataController.from_df(
    df,
    main_text='Content',
    label_names=['L1L2'],
    seed=42,
)
tdc.process_and_tokenize(tokenizer, max_length=512, shuffle_trn=True)

----- Input Validation Precheck -----
Data contains duplicated values!
-----> Number of duplications: 16 rows
-------------------- Start Main Text Processing --------------------
----- Label Encoding -----


Map (num_proc=4):   0%|          | 0/2269 [00:00<?, ? examples/s]

Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 7, which is 0.39% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1815 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
-------------------- Tokenization --------------------


Map:   0%|          | 0/454 [00:00<?, ? examples/s]

Done


In [None]:
# A single combined class list (presumably every distinct L1 and L2 value)
len(tdc.label_lists[0])

66

In [None]:
# Label vector of the first validation row: one slot per class in tdc.label_lists[0]
print(tdc.main_ddict['validation']['label'][0])

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


Since this is **multi-label classification**, the label will be multi-hot encoded: one slot per class, set to 1 for every label the text has.

In [None]:
# Original label names for that same row
tdc.main_ddict['validation']['L1L2'][0]

['Feature', 'App performance']

We can do a quick check to see whether it works correctly

In [None]:
# Decode the first validation row's multi-hot vector back to class names, rather than
# hard-coding the positions (2 and 24) read off the printed vector.
_first_row_label = tdc.main_ddict['validation'][0]['label']
tuple(name for name, flag in zip(tdc.label_lists[0], _first_row_label) if flag)

('App performance', 'Feature')

## Content Transformation

In [None]:
# Look at a few random texts; a fixed random_state makes the sample reproducible on re-run
df['Content'].sample(10, random_state=42).values

array(['Phần trả hàng hoàn tiền ko bao giờ thành công, app chậm, duyệt cẩn thận mấy bọn buôn bán trên shopee toàn hàng nhái rồi sai hàng, quá tệ',
       'Cần nhường bán lại   Mình bán giày thể thao với giá chỉ ₫80.000 - ₫150.000. Mua ngay trên Shopee nhé!',
       'Lỗi ko vào đc', 'Như cc v',
       'Sao ko lm trang trên web cứ bắt phải tải app ?', 'Như quần',
       'Ai biết chỗ nào bán serum inod trị hôi nách chính hãng ko,bữa mua trên shoppe bị trúng hàng giả',
       'Liên kết shopeepay danh tính chỉ được một lần, khi mà đổi sdt gọi điện mãi mới chịu hỗ trợ. Đổi cho rồi thì bị khóa nick shopee mới, nói là lạm dụng mã giảm giá????? Ba lần rồi cứ mở cho sau lại khóa, khó chịu ghét vl, mất bao nhiêu thời gian, vấn đề. Không biết lần này có chịu mở cho không nữa, tức vãi',
       'Lỗi ko à', 'Bọn chó . chưa ji khóa tài khoản'], dtype=object)

In [None]:
# Stylized Unicode characters (e.g. the star emoji) come out as '<unk>' with this tokenizer
print(tokenizer.convert_ids_to_tokens(tokenizer('⭐️𝑫𝒊𝒐𝒓 𝑺𝒂𝒖𝒗𝒂𝒈𝒆 𝑬𝑫𝑷 100ml⭐️')['input_ids']))

['<s>', '▁', '<unk>', 'D', 'ior', '▁Sau', 'v', 'age', '▁ED', 'P', '▁100', 'ml', '<unk>', '</s>']


In [None]:
# TODO: try a content transformation, e.g. detect whether the text contains a website URL and, if so, concatenate it to the front

## Train/Validation Split

## Upsampling

## Content Augmentation

## Streaming Capability

## Let's go

In [None]:
# Alternative inputs kept for reference (uncomment one to use it instead):
# main_ddict = load_dataset('secret_data',data_files=['buyer_listening_with_all_raw_data_w28.csv','buyer_listening_with_all_raw_data_w28.csv'],split='train')
# main_ddict

# main_ddict = load_dataset('sample_data',data_files=['sample_large.csv','sample_large.csv','sample_large.csv'])
# main_ddict

# Note: with split='train' this is a single Dataset (see output), despite the `ddict` name
main_ddict = load_dataset('sample_data',data_files=['sample_large.csv'],split='train')
main_ddict

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


Dataset({
    features: ['Source', 'Content', 'L1', 'L2'],
    num_rows: 2269
})

In [None]:
# Slicing a Dataset returns a dict mapping each column to a list of values
main_ddict[:5]

{'Source': ['Google Play', 'Non Owned', 'Google Play', 'Owned', 'Google Play'],
 'Content': ['App ncc lúc nào cx lag đơ, phần tìm kiếm thì viết kiểu gì sp đó vẫn ko ra, thế phải ghi đúng tên mới chịu à? Lỡ quên tên ngta ghi mé mé như v cx phải gợi ý sp tương tự chứ??? ☻',
  '..❗️ GÓC THANH LÝ Tính ra rẻ hơn cả mua #Shopee Mong 1 lần đc check ib mỏi tay 😆😆😆   Em chuyển cửa hàng nên dọn lại có thừa vài chục tấm nệm xuất nhật này.   1mx2m : 1m2x2 : 1m4x2m : 1m6x2m : 1m8x2m :2mx2m Đệm dày 7-8 phân Nhưng vì còn ít nên topic này em bán thanh lý giá rẻ ạ.   Ship cod nhận hàng được kiểm tra thoải mái.  Miễn ship toàn quốc. Nên đừng bom tội nghiệp em nhé dày 7-8 phân Nhắn tin em gửi mẫu nhé🥰',
  'Mắc gì người ta đặt hàng toàn lỗi 😃????',
  '#GhienShopeePayawardT8 Khi bạn chơi shopee quá lâu thì không thể nào không biết đến với ShopeePay . Liên kết thanh toán được cho các đơn hàng Shopee và ShopeeFood luôn nè.',
  'Rất bức xúc khi dùng . mã giảm giá người dùng thì m02 vậy cho ưu đãi đấy làm gì ạ

In [None]:
%%time
# Indexing a single row returns a dict for that row
main_ddict[0]

CPU times: user 210 µs, sys: 118 µs, total: 328 µs
Wall time: 227 µs


{'Source': 'Google Play',
 'Content': 'App ncc lúc nào cx lag đơ, phần tìm kiếm thì viết kiểu gì sp đó vẫn ko ra, thế phải ghi đúng tên mới chịu à? Lỡ quên tên ngta ghi mé mé như v cx phải gợi ý sp tương tự chứ??? ☻',
 'L1': 'Feature',
 'L2': 'App performance'}

In [None]:
# Seed the random number generators before building the controller (see seed_everything for which ones)
seed_everything(42)

In [None]:
# Imported explicitly: the upsampling lambda below calls random.random(), and a star
# import is not a reliable source for the `random` module.
import random

# Drop rows whose L1 label is 'Others'
_filter_dict={'L1':lambda x: x!='Others'}

# Word segmentation (with text normalization) applied to the main content
_content_tfms = partial(apply_vnmese_word_tokenize,normalize_text=True)
_content_tfms.__name__='VNM word segmentation'  # name printed in the processing log

# Upsampling: each 'hc search' row (metadata is lowercased by then) is selected with probability 0.5
_upsampling_dict={
    'Source': lambda x: x=='hc search' if random.random()<0.5 else False
}

# Augmentation: accent removal, applied with prob=0.5
_aug_tfms=partial(remove_vnmese_accent,prob=0.5)

In [None]:
tdc = TextDataController(
    main_ddict,
    main_text='Content',
    label_names='L1',
    filter_dict=_filter_dict,
    metadatas='Source',
    content_transformations=_content_tfms,
    val_ratio=0.25,                      # fraction of rows held out for validation
    stratify_cols='Source',              # split stratified on 'Source'
    upsampling_dict=_upsampling_dict,
    content_augmentations=_aug_tfms,
    seed=42,
    is_batched=True,
    is_streamed=False,
)

In [None]:
%%time
# Run every preprocessing step (tokenization is done separately, later) and return the DatasetDict
my_ddict = tdc.do_all_preprocessing(shuffle_trn=True)
# Timing notes from earlier runs:
# 2x big data
# is_batched True, shuffle_trn True. Wall time: 1min 6s
# is_batched False, shuffle_trn True, Wall time: 1min 18s

# 3x sample_large.csv
# is_batched True, shuffle_trn True. Wall time: 2.79 s
# is_batched False, shuffle_trn True, Wall time: 2.96 s

Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-e8025282b10ede62_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-0ecebfc0c7f83a1d_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-966fc97b537cb573_*_of_00004.arrow
Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/sample_data-96e446a75e3f09ba/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-9067e41b077d8eb9_*_of_00004.arrow


-------------------- Start Main Text Processing --------------------
-------------------- Data Filtering --------------------
----- Do <lambda> on L1 -----
Done
----- Metadata Simple Processing & Concatenating to Main Content -----
Done
----- Label Encoding -----
Done
-------------------- Text Transformation --------------------
----- VNM word segmentation -----
Done
-------------------- Train Test Split --------------------


Map (num_proc=4):   0%|          | 0/1458 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/1458 [00:00<?, ? examples/s]

Done
-------------------- Dropping unused features --------------------
Done
- Number of rows leaked: 1, which is 0.09% of training set
Filtering leaked data out of training set...


Filter (num_proc=4):   0%|          | 0/1093 [00:00<?, ? examples/s]

Done
-------------------- Upsampling data --------------------
----- Do <lambda> on Source -----


Filter (num_proc=4):   0%|          | 0/1091 [00:00<?, ? examples/s]

Done
-------------------- Text Augmentation --------------------
----- remove_vnmese_accent -----


Map (num_proc=4):   0%|          | 0/1112 [00:00<?, ? examples/s]

Done
-------------------- Shuffling train set --------------------
Done
CPU times: user 252 ms, sys: 154 ms, total: 406 ms
Wall time: 736 ms


In [None]:
# Train is an IterableDataset; validation keeps only the original columns at this point
my_ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Source', 'Content', 'L1'],
        num_rows: 365
    })
})

In [None]:
%%time
# Random access into the regular (non-iterable) validation split
my_ddict['validation'][0]

CPU times: user 478 µs, sys: 0 ns, total: 478 µs
Wall time: 310 µs


{'Source': 'non owned',
 'Content': 'non owned . Mã_500K toàn sàn Shopee cho ai cần nè » https://shope.ee/10QzJtpqQi',
 'L1': 1}

In [None]:
# Iterator over the (iterable) training split
_tmp=iter(my_ddict['train'])

In [None]:
%%time
# First item from the train iterator: noticeably slower than later ones (see timing)
next(_tmp)

CPU times: user 28.2 ms, sys: 301 µs, total: 28.5 ms
Wall time: 26.8 ms


{'Source': 'google play',
 'Content': 'google play . Chán ơi là chánnnn Toàn_bị U02_Mn có biết cách nào chữa ko thì chỉ em với ạ 🥰',
 'L1': 3}

In [None]:
%%time
# Subsequent items come back much faster
next(_tmp)

CPU times: user 63 µs, sys: 53 µs, total: 116 µs
Wall time: 118 µs


{'Source': 'google play',
 'Content': 'google play . Ứng_dụng quá là đơ . Càng dùng càng đơ thì mua_bán cái mẹ j .',
 'L1': 3}

In [None]:
# Inspect the tokenizer configuration.
# NOTE(review): model_max_length is 256 here while earlier cells pass max_length=512 — confirm longer texts are truncated to a length the model accepts
tokenizer


PhobertTokenizer(name_or_path='vinai/phobert-base', vocab_size=64000, model_max_length=256, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True)

In [None]:
# The controller keeps the processed splits in main_ddict (same structure as my_ddict above)
tdc.main_ddict

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Source', 'Content', 'L1'],
        num_rows: 365
    })
})

In [None]:
# Optional (disabled): turn the validation split into an IterableDataset as well
# tdc.main_ddict['validation'] = tdc.main_ddict['validation'].to_iterable_dataset(num_shards=tdc.num_shards)

In [None]:
# Tokenize the processed splits; adds input_ids / token_type_ids / attention_mask columns
my_ddict_tok = tdc.do_tokenization(tokenizer)

Map:   0%|          | 0/365 [00:00<?, ? examples/s]

In [None]:
# Tokenized splits: validation now has the tokenizer output columns
my_ddict_tok

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Source', 'Content', 'L1', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 365
    })
})

In [None]:
%%time
# Row access cost on the tokenized validation split
_ = my_ddict_tok['validation'][0]

CPU times: user 287 µs, sys: 140 µs, total: 427 µs
Wall time: 380 µs


In [None]:
# Processed text: lowercased metadata prefix plus word-segmented content (underscores appear to join segmented words)
my_ddict_tok['validation'][0]['Content']

'non owned . Mã_500K toàn sàn Shopee cho ai cần nè » https://shope.ee/10QzJtpqQi'

In [None]:
# First 100 tokens of that row: sub-word pieces are marked with '@@', padding follows '</s>'
print(tokenizer.convert_ids_to_tokens(my_ddict_tok['validation'][0]['input_ids'])[:100])

['<s>', 'non', 'ow@@', 'ned', '.', 'Mã_@@', '500@@', 'K', 'toàn', 'sàn', 'Sho@@', 'pee', 'cho', 'ai', 'cần', 'nè', '»', 'htt@@', 'ps@@', '://@@', 'sho@@', 'pe@@', '.@@', 'ee@@', '/@@', '10@@', 'Q@@', 'z@@', 'J@@', 't@@', 'p@@', 'q@@', 'Q@@', 'i', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


In [None]:
# Iterator over the tokenized training stream
_iter=iter(my_ddict_tok['train'])

In [None]:
%%time
# The first item from the tokenized train stream takes noticeably longer
_tmp = next(_iter)

CPU times: user 272 ms, sys: 123 µs, total: 272 ms
Wall time: 272 ms


In [None]:
%%time
# Later items are fetched almost instantly
_tmp = next(_iter)

CPU times: user 0 ns, sys: 5 µs, total: 5 µs
Wall time: 6.68 µs


In [None]:
# Show a training example and its tokens; this one appears to have had its accents removed by the augmentation
print(_tmp['Content'])

print(tokenizer.convert_ids_to_tokens(_tmp['input_ids'])[:100])

google play . App nhu cc ko lam gi cung xoa tk
['<s>', 'google', 'play', '.', 'App', 'nhu', 'cc', 'ko', 'lam', 'gi', 'cung', 'xoa', 't@@', 'k', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


In [None]:
# id -> class name lookup for L1 (ids are positions in tdc.label_lists[0])
i2v = dict(enumerate(tdc.label_lists[0]))

# Class proportions in the validation split, keyed by encoded id
vc_l1 = (
    pd.Series(my_ddict['validation']['L1'])
    .value_counts(normalize=True)
    .reset_index()
)

In [None]:
# Replace encoded ids with class names. Build a new frame instead of overwriting
# vc_l1['index'] in place: the in-place version is not idempotent, because re-running the
# cell would push the already-translated names through `i2v` and yield NaN.
vc_l1_named = vc_l1.assign(index=vc_l1['index'].map(i2v))
vc_l1_named
# Proportions recorded in an earlier run, for comparison:
# Feature                    0.371056
# Commercial                 0.209191
# Delivery                   0.137174
# Shopee account             0.127572
# Buyer complained seller    0.043896
# Return/Refund              0.030864
# Payment                    0.030178
# Order/Item                 0.028121
# Services                   0.021948

Unnamed: 0,index,proportion
0,Feature,0.369863
1,Commercial,0.213699
2,Shopee account,0.131507
3,Delivery,0.128767
4,Buyer complained seller,0.041096
5,Return/Refund,0.032877
6,Services,0.032877
7,Order/Item,0.027397
8,Payment,0.021918


In [None]:
# Source counts and proportions in validation (the split was stratified on 'Source')
_val_sources = pd.Series(my_ddict['validation']['Source'])
_val_sources.value_counts(), _val_sources.value_counts(normalize=True)
# Proportions recorded in an earlier run, for comparison:
#  Google Play    0.743484
#  Non Owned      0.093278
#  Owned          0.069959
#  iOS            0.061043
#  HC search      0.032236

(google play    271
 non owned       34
 owned           26
 ios             22
 hc search       12
 Name: count, dtype: int64,
 google play    0.742466
 non owned      0.093151
 owned          0.071233
 ios            0.060274
 hc search      0.032877
 Name: proportion, dtype: float64)

In [None]:
# Materialize the whole (iterable) training split into a list so its rows can be counted
_tmp_trn = list(my_ddict['train'])

In [None]:
# Validation size
len(my_ddict['validation'])

365

In [None]:
# Total rows across both splits (disabled)
# len(_tmp_trn) + len(my_ddict['validation'])

In [None]:
# Training size after leak filtering, upsampling and augmentation
len(_tmp_trn)

1106

In [None]:
# Rough expected size of the training split, to compare with len(_tmp_trn).
# The counts come from the outputs in this section.
n_after_filter = 1458    # rows entering the train/validation split
n_val = 365              # validation rows
n_hc_search_total = 47   # presumably all 'hc search' rows after filtering — TODO confirm
n_hc_search_val = 12     # 'hc search' rows that landed in validation
upsample_prob = 0.5      # selection probability used in _upsampling_dict

# Upsampling is random, so this is only an expectation. It also ignores rows removed by
# the leak filter, so the actual training size can differ slightly.
expected_trn_len = int((n_hc_search_total - n_hc_search_val) * upsample_prob) + (n_after_filter - n_val)
expected_trn_len

1110

## Let's stream

In [None]:
# Alternative inputs kept for reference (uncomment one to use it instead):
# main_ddict = load_dataset('secret_data',data_files=['buyer_listening_with_all_raw_data_w28.csv','buyer_listening_with_all_raw_data_w28.csv'],split='train')
# main_ddict

# main_ddict = load_dataset('sample_data',data_files=['sample_large.csv','sample_large.csv','sample_large.csv'],streaming=True)
# main_ddict

# streaming=True: the data is loaded as an IterableDataset (see output)
main_ddict = load_dataset('sample_data',data_files=['sample_large.csv'],split='train',streaming=True)
main_ddict

<datasets.iterable_dataset.IterableDataset>

In [None]:
# Uncomment to print the first 5 streamed rows
# for i,v in enumerate(main_ddict):
#     print(v)
#     if i==4: break

In [None]:
# Re-seed before the streaming run
seed_everything(42)


In [None]:
# Full set of L1 classes, passed to the streaming controller as predefined class names
labels = [
    'Feature',
    'Commercial',
    'Delivery',
    'Shopee account',
    'Buyer complained seller',
    'Return/Refund',
    'Payment',
    'Order/Item',
    'Services',
    'Others',
]


In [None]:
# Imported explicitly: the upsampling lambda below calls random.random(), and a star
# import is not a reliable source for the `random` module.
import random

# Drop rows whose L1 label is 'Others'
_filter_dict={'L1':lambda x: x!='Others'}

# Word segmentation (with text normalization) applied to the main content
_content_tfms = partial(apply_vnmese_word_tokenize,normalize_text=True)
_content_tfms.__name__='VNM word segmentation'  # name printed in the processing log

# Upsampling: each 'hc search' row is selected with probability 0.5 (not used by the streaming controller below)
_upsampling_dict={
    'Source': lambda x: x=='hc search' if random.random()<0.5 else False
}

# Augmentation: accent removal, applied more aggressively here (prob=0.85)
_aug_tfms=partial(remove_vnmese_accent,prob=0.85)

In [None]:
main_ddict = load_dataset('sample_data', data_files=['sample_large.csv'], split='train', streaming=True)
tdc = TextDataController(
    main_ddict,
    main_text='Content',
    label_names='L1',
    class_names_predefined=labels,
    filter_dict=_filter_dict,
    metadatas='Source',
    content_transformations=_content_tfms,
    val_ratio=365,  # an int here; validation ends up with exactly 365 rows (see below)
    # upsampling_dict=_upsampling_dict,  # super slow
    content_augmentations=_aug_tfms,
    seed=42,
    is_batched=True,
    is_streamed=True,
    num_shards=512,
)

In [None]:
import gc
gc.collect()

70

In [None]:
%%time
my_ddict = tdc.do_all_preprocessing(shuffle_trn=True)
# 2x big data
# is_batched True, shuffle_trn True. Wall time: 1min 6s
# is_batched False, shuffle_trn True, Wall time: 1min 18s

# 3x sample_large.csv
# is_batched True, shuffle_trn True. Wall time: 2.79 s
# is_batched False, shuffle_trn True, Wall time: 2.96 s

# TODO: redo
# 3x sample_large.csv, streaming, with aug
# is_batched True, shuffle_trn True. Wall time: 53.5 s
# is_batched False, shuffle_trn True, Wall time: 17.4 s



-------------------- Start Main Text Processing --------------------
-------------------- Data Filtering --------------------
----- Do <lambda> on L1 -----
Done
----- Metadata Simple Processing & Concatenating to Main Content -----
Done
----- Label Encoding -----
Done
-------------------- Text Transformation --------------------
----- VNM word segmentation -----
Done
-------------------- Train Test Split --------------------
Done
-------------------- Dropping unused features --------------------
Done
-------------------- Text Augmentation --------------------
Done
-------------------- Shuffling train set --------------------
Done
CPU times: user 2.62 s, sys: 2.64 s, total: 5.26 s
Wall time: 5.26 s


In [None]:
my_ddict['validation']

Dataset({
    features: ['Source', 'Content', 'L1'],
    num_rows: 365
})

In [None]:
my_ddict_tok = tdc.do_tokenization(tokenizer)

Map:   0%|          | 0/365 [00:00<?, ? examples/s]

In [None]:
my_ddict_tok

DatasetDict({
    train: <datasets.iterable_dataset.IterableDataset object>
    validation: Dataset({
        features: ['Source', 'Content', 'L1', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 365
    })
})

In [None]:
# len(list(my_ddict_tok['train'])),len(list(my_ddict_tok['validation']))


In [None]:
_iter=iter(my_ddict_tok['train'])
_tmp = next(_iter)
_tmp = next(_iter)

In [None]:
len(my_ddict_tok['validation'])

365

In [None]:
_tmp=iter(my_ddict['train'])

In [None]:
%%time
next(_tmp)
# 9.44s with upsampling
# 4.72s without upsampling
# 5.16 without upsampling, with augmentation

CPU times: user 2.29 s, sys: 2.93 s, total: 5.22 s
Wall time: 5.22 s


{'Source': 'google play',
 'Content': 'google play . Bi loi khong hien_hinh_anh , xoa di tai lai app deu ko hien_thi gi het .',
 'L1': 3}

In [None]:
%%time
next(_tmp)

CPU times: user 6 µs, sys: 0 ns, total: 6 µs
Wall time: 8.58 µs


{'Source': 'hc search',
 'Content': 'hc search . khong cap_nhat duoc sdt',
 'L1': 9}

In [None]:
_tmp=iter(my_ddict['validation'])

In [None]:
%%time
next(_tmp)
# The validation split is a regular Dataset (not a streamed IterableDataset), so the first
# access takes microseconds here, unlike the first access on the train iterator.

CPU times: user 173 µs, sys: 154 µs, total: 327 µs
Wall time: 241 µs


{'Source': 'google play',
 'Content': 'google play . đang chơi mà quảng_cáo',
 'L1': 1}

In [None]:
%%time
next(_tmp)

CPU times: user 80 µs, sys: 0 ns, total: 80 µs
Wall time: 83 µs


{'Source': 'google play',
 'Content': 'google play . Cứ quản cáo quá nhiều , app nào củng gặp quản_cáo của shopee 😆',
 'L1': 1}

In [None]:
len(list(my_ddict['train'])),len(list(my_ddict['validation']))
# (1093, 365) on the latest run.
# list() iterates the entire streamed train split just to count its rows.

(1093, 365)

In [None]:
len(list(my_ddict_tok['train'])),len(list(my_ddict_tok['validation']))
# (1093, 365) on the latest run: tokenization does not change the number of rows.

(1093, 365)

In [None]:
# Print the first 10 streamed training examples. zip() pulls from range(10) first, so once
# the range is exhausted no 11th example is requested from the streamed iterator.
for i, v in zip(range(10), my_ddict['train']):
    print(v)

{'Source': 'ios', 'Content': 'ios . Đã chỉnh đi chỉnh lại rất nhiều nhưng nó vẫn ghi là thanh toán ko khả dụng! \nLà sao hả SHOPEE:)))', 'L1': 6}
{'Source': 'hc search', 'Content': 'hc search . kh đc mượt', 'L1': 3}
{'Source': 'google play', 'Content': 'google play . Hóng đơn hàng về .khi về thì Shipper chưa giao mà báo k ai nhận. Tự ý huỷ đơn. Shipper Ngũ Hành Sơn Đà Nẵng quá kém.', 'L1': 2}
{'Source': 'google play', 'Content': 'google play . Có cái nịt thùng mì 50k phí sip 100', 'L1': 2}
{'Source': 'google play', 'Content': 'google play . Alo mới đặt hàng đi vắng có 1 bữa Xong shipper nt chửi ôm xồm vậy Đánh giá 1 sao cho biết', 'L1': 2}
{'Source': 'google play', 'Content': 'google play . Đc', 'L1': 5}
{'Source': 'google play', 'Content': 'google play . Tôi vưa bị đăng xuất khỏi shoppe một cách vô lý. Bây h k vào lại đc. Có vào đc thì tất cả nhưg đơn hàg trc đó của tui cug đã bị mất. Yêu cầu giai quyết vấn đề', 'L1': 9}
{'Source': 'google play', 'Content': 'google play . Giao hàng gì

In [None]:
_tmp=iter(my_ddict_tok['validation'])

In [None]:
# Index the row first: ds[0] reads a single example, whereas ds['Content'][0] and
# ds['input_ids'][0] would each load an entire column just to take its first element.
_first_val = my_ddict_tok['validation'][0]
print(_first_val['Content'])

# First 100 tokens of the encoded example (padding tokens fill the tail)
print(tokenizer.convert_ids_to_tokens(_first_val['input_ids'][:100]))

google play . đang chơi mà quảng_cáo
['<s>', 'google', 'play', '.', 'đang', 'chơi', 'mà', 'quảng_cáo', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']


Let's start a step-by-step walkthrough of how to use this class.

## Load data

In [None]:
from datasets import Dataset,load_dataset

In [None]:
DATA_PATH = Path('sample_data')

In [None]:
df = pd.read_csv(DATA_PATH/'sample_large.csv')

df.head()

In [None]:
main_ddict = Dataset.from_csv(str(DATA_PATH/'sample_large.csv'))
main_ddict

In [None]:
# main_ddict = load_dataset(str(DATA_PATH),data_files={'train':'sample_large.csv'})
main_ddict = load_dataset(str(DATA_PATH),data_files='sample_large.csv')
main_ddict

In [None]:
main_ddict = load_dataset(str(DATA_PATH),data_files='sample_large.csv',split='train')
main_ddict

In [None]:
main_ddict = load_dataset(str(DATA_PATH),data_files='sample_large.csv',split='train',streaming=True)
main_ddict

<datasets.iterable_dataset.IterableDataset>

In [None]:
next(iter(main_ddict))

{'Source': 'Google Play',
 'Content': 'App ncc lúc nào cx lag đơ, phần tìm kiếm thì viết kiểu gì sp đó vẫn ko ra, thế phải ghi đúng tên mới chịu à? Lỡ quên tên ngta ghi mé mé như v cx phải gợi ý sp tương tự chứ??? ☻',
 'L1': 'Feature',
 'L2': 'App performance'}

In [None]:
list(main_ddict.take(3))

In [None]:
%%time
main_ddict = load_dataset(str(DATA_PATH),data_files=['sample_large.csv','sample_large.csv'])
main_ddict

In [None]:
# main_ddict = load_dataset('csv',data_files=str(DATA_PATH/'sample_large.csv'))
# main_ddict

In [None]:
_tmp = Path('secret_data')/'some_files.csv'

In [None]:
_tmp.parent,_tmp.name

(Path('secret_data'), 'some_files.csv')

In [None]:
_tmp = Path('some_files.csv')
_tmp.parent

Path('.')

In [None]:
%%time
main_ddict = load_dataset('secret_data',data_files='buyer_listening_with_all_raw_data_w28.csv',split='train')
main_ddict

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


CPU times: user 1.06 ms, sys: 3.97 ms, total: 5.03 ms
Wall time: 4.48 ms


Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

In [None]:
%%time
main_ddict = load_dataset('csv',data_files='secret_data/buyer_listening_with_all_raw_data_w28.csv',split='train')
main_ddict

Downloading and preparing dataset csv/default to /home/quan/.cache/huggingface/datasets/csv/default-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /home/quan/.cache/huggingface/datasets/csv/default-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1. Subsequent calls will reuse this data.
CPU times: user 592 ms, sys: 84.5 ms, total: 677 ms
Wall time: 1.72 s


Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

## Actual data loading

In [None]:
main_ddict = load_dataset('secret_data',data_files='buyer_listening_with_all_raw_data_w28.csv')
if hasattr(main_ddict,'keys'):
    print('yes')
print(main_ddict.keys())

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


  0%|          | 0/1 [00:00<?, ?it/s]

yes
dict_keys(['train'])


In [None]:
main_ddict

DatasetDict({
    train: Dataset({
        features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
        num_rows: 114605
    })
})

In [None]:
_tmp = main_ddict.pop('train')

In [None]:
main_ddict

DatasetDict({
    
})

In [None]:
main_ddict = load_dataset('secret_data',data_files='buyer_listening_with_all_raw_data_w28.csv',split='train')
main_ddict

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

In [None]:
if hasattr(main_ddict,'keys'):
    print('yes')

## Process metadatas

In [None]:
main_ddict

Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

In [None]:
main_ddict[-3:]

{'Week': [28.0, 28.0, 28.0],
 'Group': ['Tú Bà Bà',
  'Gia Lai-Thanh lý đồ dùng và thời trang',
  'thuonghieucongluan.com.vn'],
 'Source': ['Non Owned', 'Non Owned', 'Non Owned'],
 'Content': ['Riết rồi k biết xài cái quần gì để k bị trừ tiền oan. Bữa trc thì vụ nạp thẻ qua shopee T chửi chưa đã cái miệng. Đổi qua momo cho lành. Ỷ y momo k có số dư nên chắc k sao. Chủ yếu lk tk ngân hàng thôi, cái nó tự liên kết thẻ trừ giao dịch qua apple mỗi tuần 129k/ tuần (ông cố ơi!)🙂Thêm cái chỉnh ảnh 419k/ năm (cái này cũng được đi).Dm bữa h tự trừ hết hơn 1tr trong tk ngân hàng. Ui là trời. 2 hộp sữa của con T ra đi nữa rồi đó.Hiện đại, hại điện.Khỏi cảm ơn, T xoá app rồi.Dm trả lại T 2 hộp Meiji đi rồi T sử dụng lại.!',
  'GHN có còn ship cho Sendo không các thím? Em mua shopee, lazada, TIKI thì đơn do best, GHN vs viettel ship nhiều nhất mà mấy bên này chắc ship hầu hết các sàn lớn đúng ko nhờ?',
  'Cục QLTT Hà Nội: Kiểm tra, xử lý nhiều vụ hàng lậu, hàng giả'],
 'L1': ['Feature', 'Delivery',

In [None]:
metadatas = ['Week','Source']
# metadatas = val2iterable(metadatas)
process_metas = True
main_text='Content'
is_batched=True

In [None]:
main_ddict = load_dataset('secret_data',data_files='buyer_listening_with_all_raw_data_w28.csv',split='train')
main_ddict

Found cached dataset csv (/home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)


Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

In [None]:
# def _process_metadatas(ds:dict,
#                        main_text,
#                        metadatas,
#                        process_metas=True,
#                        sep='.',
#                       is_batched=True):
#     metadatas = val2iterable(metadatas)
#     results={main_text:ds[main_text]}
#     for m in metadatas:
#         m_data = ds[m]
#         if process_metas:
#             # just strip and lowercase
#             m_data = [str(v).strip().lower() for v in m_data] if is_batched else str(m_data).strip().lower()
#         results[m]=m_data
#         if is_batched:
#             results[main_text] = [f'{m_data[i]}{sep} {results[main_text][i]}' for i in range(len(m_data))]
#         else:
#             results[main_text] = f'{m_data}{sep} {results[main_text]}'
#     return results


In [None]:
# Same wording as the library's own progress log ("Metadata Simple Processing ...")
print_msg('Metadata Simple Processing & Concatenating to Main Content')
# Concatenate the (processed) metadata columns in front of the main text column
main_ddict_meta = main_ddict.map(partial(_process_metadatas,
                                         main_text=main_text,
                                         metadatas=metadatas,
                                         process_metas=process_metas,
                                         is_batched=is_batched),
                                batched=is_batched)

----- Metadatas Simple Processing & Concatenating to Main Content -----


Map:   0%|          | 0/114605 [00:00<?, ? examples/s]

In [None]:
main_ddict_meta[:3]

{'Week': [1.0, 1.0, 1.0],
 'Group': ['Google Play', 'Google Play', 'Google Play'],
 'Source': ['google play', 'google play', 'google play'],
 'Content': ['google play. 1.0. Tại sao cứ hiện thông báo',
  'google play. 1.0. Mlem',
  'google play. 1.0. 1 số sản phẩm trong giỏ hàng vừa đc cập nhật trong khi giỏ ko còn 1 hàng nào nx 😀'],
 'L1': ['Services', 'Others', 'Feature'],
 'L2': ['Shopee communication channels', 'Cannot defined', 'Cart & Order'],
 'L3': ['Annoying pop-up ads', '-', 'Cart issues/suggestions'],
 'L4': ['Non-tech', '-', 'Tech'],
 'is_valid': [None, None, None],
 'iteration': [1, 1, 1]}

In [None]:
print_msg('Metadata Simple Processing & Concatenating to Main Content')
main_ddict_meta = main_ddict.map(partial(_process_metadatas,
                                         main_text=main_text,
                                         metadatas=metadatas,
                                         process_metas=process_metas,
                                         is_batched=False),
                                batched=False)

Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-cadaba33433fe0f4.arrow


----- Metadata Simple Processing & Concatenating to Main Content -----


In [None]:
main_ddict_meta[:3]

{'Week': [1.0, 1.0, 1.0],
 'Group': ['Google Play', 'Google Play', 'Google Play'],
 'Source': ['google play', 'google play', 'google play'],
 'Content': ['google play. 1.0. Tại sao cứ hiện thông báo',
  'google play. 1.0. Mlem',
  'google play. 1.0. 1 số sản phẩm trong giỏ hàng vừa đc cập nhật trong khi giỏ ko còn 1 hàng nào nx 😀'],
 'L1': ['Services', 'Others', 'Feature'],
 'L2': ['Shopee communication channels', 'Cannot defined', 'Cart & Order'],
 'L3': ['Annoying pop-up ads', '-', 'Cart issues/suggestions'],
 'L4': ['Non-tech', '-', 'Tech'],
 'is_valid': [None, None, None],
 'iteration': [1, 1, 1]}

In [None]:
# stream
main_ddict_stream = load_dataset('secret_data',data_files='buyer_listening_with_all_raw_data_w28.csv',split='train',streaming=True)
main_ddict_stream

<datasets.iterable_dataset.IterableDataset>

In [None]:
main_ddict_meta = main_ddict_stream.map(partial(_process_metadatas,
                                         main_text=main_text,
                                         metadatas=metadatas,
                                         process_metas=process_metas,
                                         is_batched=True),
                                batched=True)

In [None]:
list(main_ddict_meta.take(3))

[{'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. Tại sao cứ hiện thông báo',
  'L1': 'Services',
  'L2': 'Shopee communication channels',
  'L3': 'Annoying pop-up ads',
  'L4': 'Non-tech',
  'is_valid': None,
  'iteration': 1},
 {'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. Mlem',
  'L1': 'Others',
  'L2': 'Cannot defined',
  'L3': '-',
  'L4': '-',
  'is_valid': None,
  'iteration': 1},
 {'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. 1 số sản phẩm trong giỏ hàng vừa đc cập nhật trong khi giỏ ko còn 1 hàng nào nx 😀',
  'L1': 'Feature',
  'L2': 'Cart & Order',
  'L3': 'Cart issues/suggestions',
  'L4': 'Tech',
  'is_valid': None,
  'iteration': 1}]

In [None]:
main_ddict_meta = main_ddict_stream.map(partial(_process_metadatas,
                                         main_text=main_text,
                                         metadatas=metadatas,
                                         process_metas=process_metas,
                                         is_batched=False),
                                batched=False)
# batched does not matter when using streamed
list(main_ddict_meta.take(3))

[{'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. Tại sao cứ hiện thông báo',
  'L1': 'Services',
  'L2': 'Shopee communication channels',
  'L3': 'Annoying pop-up ads',
  'L4': 'Non-tech',
  'is_valid': None,
  'iteration': 1},
 {'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. Mlem',
  'L1': 'Others',
  'L2': 'Cannot defined',
  'L3': '-',
  'L4': '-',
  'is_valid': None,
  'iteration': 1},
 {'Week': '1.0',
  'Group': 'Google Play',
  'Source': 'google play',
  'Content': 'google play. 1.0. 1 số sản phẩm trong giỏ hàng vừa đc cập nhật trong khi giỏ ko còn 1 hàng nào nx 😀',
  'L1': 'Feature',
  'L2': 'Cart & Order',
  'L3': 'Cart issues/suggestions',
  'L4': 'Tech',
  'is_valid': None,
  'iteration': 1}]

In [None]:
main_ddict_meta

Dataset({
    features: ['Week', 'Group', 'Source', 'Content', 'L1', 'L2', 'L3', 'L4', 'is_valid', 'iteration'],
    num_rows: 114605
})

## Process labels

In [None]:
metadatas = ['Week','Source']
process_metas = True
main_text='Content'
is_batched=True

In [None]:
# Same wording as the library's own progress log ("Metadata Simple Processing ...")
print_msg('Metadata Simple Processing & Concatenating to Main Content')
# Concatenate the (processed) metadata columns in front of the main text column
main_ddict_meta = main_ddict.map(partial(_process_metadatas,
                                         main_text=main_text,
                                         metadatas=metadatas,
                                         process_metas=process_metas,
                                         is_batched=is_batched),
                                batched=is_batched)

Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/csv/secret_data-042d1badc74881bf/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1/cache-ac679a0b1f06a8e9.arrow


----- Metadatas Simple Processing & Concatenating to Main Content -----


In [None]:
# Number of rows, read from the dataset metadata instead of materializing the whole 'L1' column
main_ddict_meta.num_rows

114605

In [None]:
#     def _encode_labels(self):
#         print_msg('Label Encoding')
#         if self.label_names is None: 
#             raise ValueError('Missing label columns!')
#         self.label_names = val2iterable(self.label_names)
#         if len(self.label_names)>1:
#             self.is_multihead=True
        
#         if self.label_lists is not None and not isinstance(self.label_lists[0],list):
#             self.label_lists = [self.label_lists]
        
#         if isinstance(self.df[self.label_names[0]].iloc[0],list):
########               (self.dset[self.label_names[0]][0],list)
#             # This is multi-label. Ignore self.label_names[1:]
#             self.label_names = [self.label_names[0]]
#             self.is_multihead=False
#             self.is_multilabel=True
            
#         encoder_classes=[]
#         if not self.is_multilabel:
#             for idx,l in enumerate(self.label_names):
#                 if self.label_lists is None:
#                     train_label = self.df[l].values
#                     l_encoder = LabelEncoder()
#                     self.df[l] = l_encoder.fit_transform(train_label)
#                     encoder_classes.append(list(l_encoder.classes_))
#                 else:
#                     l_classes = sorted(list(self.label_lists[idx]))
#                     label2idx = {v:i for i,v in enumerate(l_classes)}
#                     self.df[l] = self.df[l].map(label2idx).values
#                     encoder_classes.append(l_classes)
#         else:
#             # For MultiLabel, we only save the encoder classes without transforming the label itself to one-hot (or actually, few-hot)
#             if self.label_lists is None:
#                 l_encoder = MultiLabelBinarizer()
#                 _ = l_encoder.fit(self.df[self.label_names[0]])
#                 encoder_classes.append(list(l_encoder.classes_))
#             else:
#                 l_classes = sorted(list(self.label_lists[0]))
#                 encoder_classes.append(l_classes)
                
#         self.label_lists = encoder_classes

## Constructor/ Class Method calls

If you just want to get the dataframe from the csv path, set ```return_df=True```. The input validation precheck still runs in this case.

In [None]:
df = TextDataMain.from_csv(DATA_PATH/'sample_large.csv',
                            return_df=True)

The ```Input Validation Precheck``` checks for missing values and duplicate rows in the csv file. Since our sample dataset has neither, nothing is reported here.

In [None]:
df.sample(5)

In [None]:
df.Source.value_counts()

Once you are happy with this dataframe (after doing any other preprocessing you need), you can create a `TextDataMain` object.

For this dataframe, I want to:
- Build a text classification model, with the main text in the ```Content``` column, the metadata in ```Source```, and the label in ```L1```
- Perform `apply_word_tokenize` with text normalization (this is "text transformation")
- For augmentation: oversample the ```Owned, Non Owned and HC Search``` rows from the ```Source``` column, then add Vietnamese no-accent versions of the text. Note that all of these are called "text augmentation"

Let's define these transformations

> For Text Transformation

In [None]:
# Word tokenization with text normalization, used as a content (text) transformation
awt_tfm = partial(apply_word_tokenize,normalize_text=True)
# You can also set a __name__ on your transformation function.
# This way you will have meaningful text messages as outputs
awt_tfm.__name__='UTS Word Tokenization With Normalization'

# List of content transformations, later passed as `content_tfms`
txt_tfms=[awt_tfm]

> For Text Augmentation

In [None]:
# apply_to_all=True applies an augmentation to all the data, i.e. the original data plus the
# augmented/transformed data produced by previous augmentations/transformations.
# NOTE(review): `query` selects the rows to sample; `frac` is presumably the sampling fraction
# of those matching rows (frac > 1 implies sampling with replacement) — confirm in sampling_with_condition.
over_nonown_tfm = partial(sampling_with_condition,query='Source=="non owned"',frac=0.5,seed=42,apply_to_all=False)
over_nonown_tfm.__name__ = 'Oversampling Non Owned'

over_own_tfm = partial(sampling_with_condition,query='Source=="owned"',frac=2,seed=42,apply_to_all=False)
over_own_tfm.__name__ = 'Oversampling Owned'

over_hc_tfm = partial(sampling_with_condition,query='Source=="hc search"',frac=2.5,seed=42,apply_to_all=False)
over_hc_tfm.__name__ = 'Oversampling HC search'

# Accent removal runs last with apply_to_all=True, so it also sees the oversampled rows
remove_accent_tfm = partial(remove_vnmese_accent,frac=1,seed=42,apply_to_all=True)
remove_accent_tfm.__name__ = 'Add No-Accent Text'

# Augmentations in application order, later passed as `aug_tfms`
aug_tfms = [over_nonown_tfm,over_own_tfm,over_hc_tfm,remove_accent_tfm]


In [None]:
tdm = TextDataMain(df,
                    main_content='Content',
                    metadatas='Source', # You can put a list of multiple metadatas
                    label_names='L1', # You can put a list of multiple labels
                    val_ratio=0.2,
                    split_cols='L1', # You can even put a list of multiple columns to be used for validation splitting
                    content_tfms = txt_tfms, # You can add multiple content transformation functions ...
                    aug_tfms = aug_tfms, # ... as well as augmentation functions
                    process_metadatas=True,
                    seed=42,
                    shuffle_trn=True)

If we want to directly create a ```TextDataMain``` object from our csv file, we can instead use this:

In [None]:
tdm = TextDataMain.from_csv(DATA_PATH/'sample_large.csv',
                            return_df=False,
                            main_content='Content',
                            metadatas='Source',
                            label_names='L1',
                            val_ratio=0.2,
                            split_cols='L1',
                            content_tfms = txt_tfms,
                            aug_tfms = aug_tfms,
                            process_metadatas=True,
                            seed=42,
                            shuffle_trn=True)

In [None]:
show_doc(TextDataMain.to_df)

Note that none of the previous constructor calls do any heavy processing yet.

To actually run all the processes, one can call `TextDataMain.to_df()`

In [None]:
df_processed = tdm.to_df()

Notice this?
```
Previous Validation Percentage: 20.0%
- Before leak check
Size: 14
- After leak check
Size: 14
- Number of rows leaked: 0, or 0.00% of the original validation (or test) data
Current Validation Percentage: 20.0%
```
After performing the train/test split, the ```TextDataMain``` object also performs a "leak check". After `text_transformation` is applied, it compares the ```Content``` text in the validation set with the ```Content``` text in the train set. Any duplicates (texts that appear in both sets) are removed from the validation set.

In [None]:
df_processed.sample(5)

Note that, since we have metadata, the metadata is concatenated to the front of the text content

In [None]:
df_processed.Content.sample(5).values

We now have a new dataframe with only the necessary columns (the processed text column, metadata, label, and ```is_valid```, which tells you which rows belong to the validation set). Notice that our class has also encoded our labels for us

Our TextDataMain object also stores other useful attributes, such as:

In [None]:
# The entire processed dataframe, similar to the df_processed above
tdm.df.head()

In [None]:
# class names (This will be a list of list, as this class can handle multi-label classification)
tdm.label_lists

In [None]:
# a dictionary storing unique value for each provided metadata
tdm.metadata_dict

If we want to see how a HuggingFace tokenizer works on our processed text:

In [None]:
# NOTE(review): `tokenizer` is already used by earlier cells (tdc.do_tokenization and
# tokenizer.convert_ids_to_tokens); unless it is also defined before them, Restart & Run All fails there.
# AutoTokenizer (transformers) is not among the imports visible at the top of the notebook — confirm it is imported.
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

In [None]:
# this will pick a random text from train set to show
tdm.tokenizer_explain_single(tokenizer)

By doing this, we can see how the tokenizer interacts with our text.

In [None]:
show_doc(TextDataMain.to_datasetdict)

Since we need our data in HuggingFace's DatasetDict format to make good use of HuggingFace models, we can directly export a DatasetDict using `TextDataMain.to_datasetdict`

In [None]:
ddict_sample = tdm.to_datasetdict(tokenizer)

In [None]:
ddict_sample

In [None]:
ddict_sample['train']['text'][0]

In [None]:
print(ddict_sample['train']['input_ids'][0])

Note that PhoBERT will auto-pad our sentence to the model's maximum sequence length, which is 256

In [None]:
len(ddict_sample['train']['input_ids'][0])

In [None]:
ddict_sample['train']['label'][0]

In [None]:
show_doc(TextDataMain.save_as_pickles)

As the transformations/augmentations can take time for large datasets, we want to save our TextDataMain object. We can use `TextDataMain.save_as_pickles` to export a pickle file

In [None]:
tdm.save_as_pickles('my_tdm')

Then you can load it with

In [None]:
tdm2 = TextDataMain.from_pickle('my_tdm')

... and access all the attributes

In [None]:
tdm2.df.head()

In [None]:
tdm2.label_lists[0]

In [None]:
tdm2.metadata_dict

Let's check the file size

In [None]:
# On-disk size of the full pickle (includes data attributes), in MiB
_pkl_path = Path('pickle_files') / 'my_tdm.pkl'
file_stats = _pkl_path.stat()
_size_mb = file_stats.st_size / 1024**2
print(f'File Size in MegaBytes is {_size_mb}')

As it saves the entire processed dataframe (and the datasetdict if you call ```to_datasetdict```), the pickle file can be large. In some scenarios (e.g. at inference time, or in production) you don't need to store these data attributes, so you can save a lighter pickle file by setting ```drop_data_attributes``` to ```True```

In [None]:
tdm.save_as_pickles('my_lightweight_tdm',drop_data_attributes=True)

In [None]:
# On-disk size of the lightweight pickle (data attributes dropped), in MiB
_pkl_path = Path('pickle_files') / 'my_lightweight_tdm.pkl'
file_stats = _pkl_path.stat()
_size_mb = file_stats.st_size / 1024**2
print(f'File Size in MegaBytes is {_size_mb}')

We will see a bigger file size reduction when we work with much larger datasets

In [None]:
tdm_light = TextDataMain.from_pickle('my_lightweight_tdm')

You can still access some important attributes (except for any data attributes, such as ```df``` or ```main_ddict```)

In [None]:
tdm_light.label_lists[0]

In [None]:
tdm_light.metadata_dict

In [None]:
#| hide
# import nbdev; nbdev.nbdev_export()