In [1]:
!pip install -q datasets huggingface_hub tqdm seaborn

In [13]:
import datasets
from datasets import Dataset, DatasetDict, load_dataset

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import seaborn as sns
import os
import random

from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction

In [14]:
# Dataset Arguments
dataset_id = 'ag_news'
label_column = 'label'
text_column = 'text'
encode_label = False # if label column is not encoded, set this flag to True

# Experiment Arguments
max_samples_per_class = 128
blue_threshold = 0.2
#output_dir = '/content/drive/MyDrive/Research/Folha/BLUE/Balanceamento de Dados/Dados/ag_news'
balanced_dataframe_filename = f'ag_news_balanced_dataset_blue_{max_samples_per_class}samples_{blue_threshold}threshold.csv'

In [8]:
def load_hf_dataset(dataset_id, label_column, text_column, encode_label):
  # load dataset
  dataset = load_dataset(dataset_id, split = 'train')

  if encode_label:
    dataset = dataset.class_encode_column(label_column)

  # rename label column to "label" if its different from this
  if label_column != 'label':
    dataset = dataset.rename_column(label_column, 'label')

  # rename text column to "text" if its different from this
  if text_column != 'text':
    dataset = dataset.rename_column(text_column, 'text')

  return dataset

In [9]:
dataset = load_hf_dataset(
    dataset_id = dataset_id,
    label_column = label_column,
    text_column = text_column,
    encode_label = encode_label
  )

Downloading builder script: 100%|██████████| 4.06k/4.06k [00:00<?, ?B/s]
Downloading metadata: 100%|██████████| 2.65k/2.65k [00:00<?, ?B/s]
Downloading readme: 100%|██████████| 7.95k/7.95k [00:00<?, ?B/s]
Downloading data: 29.5MB [00:00, 50.0MB/s]                            
Downloading data: 1.86MB [00:00, 10.2MB/s]                  
Generating train split: 100%|██████████| 120000/120000 [00:02<00:00, 48829.19 examples/s]
Generating test split: 100%|██████████| 7600/7600 [00:00<00:00, 46351.76 examples/s]


In [32]:
def blue_score(reference_text, generated_text):
    reference_summaries = [[reference_text]]
    generated_summaries = [generated_text]
    smoothie = SmoothingFunction().method4 # using smoothing method 4 as an example
    bleu_score = corpus_bleu(reference_summaries, generated_summaries, smoothing_function=smoothie)

    return bleu_score

In [33]:
blue_score("hello", "Hello helo hhello")

0.21409092659758044

In [12]:
dataset

Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})

In [35]:
def select_samples(dataset):
    num_rows = len(dataset)
    select_indexes = []
    order_selected = []
    order_selected_class = []
    labels = dataset['label']
    labels_counter = {key : 0 for key in set(labels)}
    all_higher = False
    
    # First sample selection
    print("Selecting first sample...")
    first_sample_index = random.randint(0, num_rows - 1)
    labels_counter[labels[first_sample_index]] += 1
    select_indexes.append(first_sample_index)
    order_selected.append(1)
    order_selected_class.append(labels_counter[labels[first_sample_index]])
    last_sample_index = first_sample_index
    
    print("Started iteration for selection...")
    while not all_higher:
        candidate_index = random.randint(0, num_rows - 1)
        candidate = dataset[candidate_index]
        print(f"Candidate index: {candidate_index}")
        if (not candidate_index in select_indexes):
            reference_text = dataset[last_sample_index]["text"]
            candidate_text = dataset[candidate_index]["text"]
            score = blue_score(reference_text, candidate_text)
            if (score < blue_threshold):                
                labels_counter[labels[candidate_index]] += 1
                select_indexes.append(candidate_index)
                order_selected.append(len(select_indexes))
                order_selected_class.append(labels_counter[labels[candidate_index]])
                last_sample_index = candidate_index
                print(f"order selected: {len(select_indexes)}\torder selected class: {labels_counter[labels[candidate_index]]}")
        all_higher = all(value >= max_samples_per_class for value in labels_counter.values())
    print("Samples selected!")
    balanced_dataset = dataset.select(select_indexes)
    return balanced_dataset, select_indexes, order_selected, order_selected_class
    

In [36]:
balanced_dataset, select_indexes, order_selected, order_selected_class = select_samples(dataset)

Selecting first sample...
Started iteration for selection...
Candidate index: 70644
order selected: 2	order selected class: 1
Candidate index: 38593
order selected: 3	order selected class: 1
Candidate index: 118533
order selected: 4	order selected class: 1
Candidate index: 8503
Candidate index: 77274
Candidate index: 16013
Candidate index: 81081
order selected: 5	order selected class: 2
Candidate index: 108083
order selected: 6	order selected class: 3
Candidate index: 107106
order selected: 7	order selected class: 2
Candidate index: 111668
Candidate index: 115951
order selected: 8	order selected class: 2
Candidate index: 52031
Candidate index: 83856
order selected: 9	order selected class: 3
Candidate index: 89911
order selected: 10	order selected class: 4
Candidate index: 108828
order selected: 11	order selected class: 3
Candidate index: 57185
order selected: 12	order selected class: 4
Candidate index: 69256
order selected: 13	order selected class: 5
Candidate index: 58119
Candidate in

order selected: 215	order selected class: 45
Candidate index: 8046
order selected: 216	order selected class: 46
Candidate index: 53349
Candidate index: 81799
Candidate index: 43872
order selected: 217	order selected class: 47
Candidate index: 70952
order selected: 218	order selected class: 59
Candidate index: 119055
order selected: 219	order selected class: 48
Candidate index: 36152
order selected: 220	order selected class: 56
Candidate index: 4417
order selected: 221	order selected class: 58
Candidate index: 21759
Candidate index: 41102
order selected: 222	order selected class: 49
Candidate index: 39321
order selected: 223	order selected class: 60
Candidate index: 69669
Candidate index: 90368
order selected: 224	order selected class: 61
Candidate index: 113504
Candidate index: 63242
Candidate index: 66533
Candidate index: 44745
order selected: 225	order selected class: 50
Candidate index: 95926
order selected: 226	order selected class: 51
Candidate index: 96650
Candidate index: 93879


Candidate index: 77289
Candidate index: 54684
order selected: 487	order selected class: 124
Candidate index: 38828
order selected: 488	order selected class: 125
Candidate index: 82937
order selected: 489	order selected class: 129
Candidate index: 71974
order selected: 490	order selected class: 127
Candidate index: 62546
order selected: 491	order selected class: 130
Candidate index: 67450
order selected: 492	order selected class: 131
Candidate index: 21610
order selected: 493	order selected class: 132
Candidate index: 29969
order selected: 494	order selected class: 126
Candidate index: 72672
Candidate index: 29185
order selected: 495	order selected class: 127
Candidate index: 1827
order selected: 496	order selected class: 128
Candidate index: 455
order selected: 497	order selected class: 133
Candidate index: 4570
order selected: 498	order selected class: 134
Candidate index: 17157
Candidate index: 4600
order selected: 499	order selected class: 128
Candidate index: 44379
order selected: 

In [37]:
def create_df(dataset, order_selected, order_selected_class, filepath):
  if os.path.exists(filepath):
    df = pd.read_csv(filepath)
  else:
    df = dataset.to_pandas()
    df['order_selected'] = order_selected
    df['order_selected_class'] = order_selected_class

    df.to_csv(filepath, index = False)

  return df

In [38]:
dataframe_filepath = os.path.join(os.getcwd(), balanced_dataframe_filename)
df_balanced = create_df(
    dataset = balanced_dataset,
    order_selected = order_selected,
    order_selected_class = order_selected_class,
    filepath = dataframe_filepath
)