In [13]:
# load imports

from datasets import load_dataset, DatasetDict
from collections import Counter
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

In [None]:
# load dataset (task 1)

dataset = load_dataset('ag_news')
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


In [None]:
# split dataset (task 2)

split = dataset["train"].train_test_split(test_size=0.2, seed=42)
dataset = DatasetDict({
    "train": split["train"],
    "dev": split["test"],
    "test": dataset["test"]
})
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 96000
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


In [None]:
# dataset statistics (task 3)

import numpy as np


def dataset_statistics(dataset, name):
    text_lengths = [len(text.split()) for text in dataset["text"]]
    min_length = np.min(text_lengths)
    max_length = np.max(text_lengths)
    mean_length = np.mean(text_lengths)
    std_length = np.std(text_lengths)
    median_length = np.median(text_lengths)
    data_size = len(dataset)
    class_counts = Counter(dataset["label"])

    
    return {
        "Zbiór": name,
        "Rozmiar": data_size,
        "Średnia długość tekstu": mean_length,
        "Minimalna długość tekstu": min_length,
        "Maksymalna długość tekstu": max_length,
        "Odchylenie standardowe": std_length,
        "Mediana długości": median_length,
        "Rozkład klas": dict(class_counts)
    }

stats_train = dataset_statistics(dataset["train"], "Train")
stats_dev = dataset_statistics(dataset["dev"], "Dev")
stats_test = dataset_statistics(dataset["test"], "Test")

stats = pd.DataFrame([stats_train, stats_dev, stats_test])
print(stats)


   Zbiór  Rozmiar  Średnia długość tekstu  Minimalna długość tekstu  \
0  Train    96000               37.842917                         8   
1    Dev    24000               37.865583                         8   
2   Test     7600               37.722368                        11   

   Maksymalna długość tekstu  Odchylenie standardowe  Mediana długości  \
0                        177               10.086832              37.0   
1                        157               10.078666              37.0   
2                        137               10.129193              37.0   

                               Rozkład klas  
0  {2: 24061, 1: 23949, 3: 24041, 0: 23949}  
1      {0: 6051, 1: 6051, 3: 5959, 2: 5939}  
2      {2: 1900, 3: 1900, 1: 1900, 0: 1900}  


In [None]:
# normalization (task 4)

if "label_one_hot" in dataset["train"].column_names:
    dataset["train"] = dataset["train"].remove_columns("label_one_hot")

encoder = OneHotEncoder(sparse_output=False)
labels = np.array(dataset["train"]["label"]).reshape(-1, 1)
one_hot_labels = encoder.fit_transform(labels)

dataset["train"] = dataset["train"].add_column("label_one_hot", one_hot_labels.tolist())

df = pd.DataFrame(dataset["train"][:5])
print(df)

                                                text  label  \
0  Nation #39;s Cotton Crop May Exceed Records Th...      2   
1  18 years and still rollin #39; ALEX FERGUSON w...      1   
2  Madrid Masters: Safin beats Nalbandian Sunday ...      1   
3  Sirius Satellite Signs Howard Stern to 5-Year ...      2   
4  NATO, Russia To Meet Over Beslan School Siege ...      3   

          label_one_hot  
0  [0.0, 0.0, 1.0, 0.0]  
1  [0.0, 1.0, 0.0, 0.0]  
2  [0.0, 1.0, 0.0, 0.0]  
3  [0.0, 0.0, 1.0, 0.0]  
4  [0.0, 0.0, 0.0, 1.0]  


In [None]:
# clean dataset (task 5)

def validate_samples(sample):
    if not sample["text"]:
        return False
    if sample["label"] not in {0, 1, 2, 3}:
        return False

    return True

dataset["train"] = dataset["train"].filter(validate_samples)
dataset["dev"] = dataset["dev"].filter(validate_samples)
dataset["test"] = dataset["test"].filter(validate_samples)

print(dataset)



Filter: 100%|██████████| 96000/96000 [00:00<00:00, 245746.80 examples/s]
Filter: 100%|██████████| 24000/24000 [00:00<00:00, 397910.10 examples/s]
Filter: 100%|██████████| 7600/7600 [00:00<00:00, 418598.71 examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_one_hot'],
        num_rows: 96000
    })
    dev: Dataset({
        features: ['text', 'label'],
        num_rows: 24000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})



