### Score curation 

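- Use the score curation report (derived from the estimated score transition matrix) to replace a sample's quality score with the suggested score whenever that suggestion's confidence is at least `confidence_prob`.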

In [1]:
import torch 
from collections import Counter
import random
from datasets import load_dataset
import numpy as np
import math

# Run configuration: which dataset/model pair to curate, plus the knobs used below.
dataset_name = 'dolly'
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_size = 10_000  # sample budget used by the data-selection cell
confidence_prob = 0.5  # a curated score is accepted only when its confidence is >= this

# Seed Python's and NumPy's global RNGs so reruns are reproducible.
seed = 3
random.seed(seed)
np.random.seed(seed)

# Raw training set (~300k samples).
# NOTE(review): `all_train_dataset` is not referenced anywhere else in this notebook as shown.
all_train_dataset = load_dataset(
    'json',
    data_files=f"{dataset_name}.json",
)


# Load the score-curation report for this model/dataset pair.
# The report is a pickled Python object (read via `.detection` / `.curation` below), not a
# tensor checkpoint, so it needs full unpickling: pass weights_only=False explicitly.
# Recent PyTorch releases warn when this argument is left implicit, and newer ones default
# it to True, which would refuse to load the report.
# Security: unpickling can execute arbitrary code -- only load reports you generated yourself.
report_path = f"score_curation/results/{model_name}/{dataset_name}/{dataset_name}_report.pt"
reports = torch.load(report_path, weights_only=False)

# Indices (first field of each entry) of the samples listed under detection['label_error'],
# i.e. the samples whose original score was flagged as corrupted.
corrupted_samples = [entry[0] for entry in reports.detection['label_error']]

# Each curation entry is (idx, suggested_label, confidence). Keep only suggestions whose
# confidence reaches `confidence_prob`, recorded as (int idx, int label, confidence rounded to 2 dp).
curated_sample = []
curated_sample_scores = []
for entry in reports.curation['label_curation']:
    sample_idx, suggested_label, confidence = entry[0], entry[1], entry[2]
    if confidence >= confidence_prob:
        curated_sample.append(sample_idx)
        curated_sample_scores.append((int(sample_idx), int(suggested_label), round(confidence, 2)))

print(f"Curated sample size: {len(curated_sample_scores)}")

# Remove the flagged samples that received an accepted curated score above; what remains are
# flagged samples with no confident replacement score.
# NOTE(review): `corrupted_samples_total` is not used by the data-selection cell -- TODO confirm
# whether these samples were meant to be excluded from selection.
curated_sample_set = set(curated_sample)
corrupted_samples_total = [idx for idx in corrupted_samples if idx not in curated_sample_set]

print(f"Corrupted samples total: {len(corrupted_samples_total)}")

# Replace the original per-sample scores with the accepted curated scores, then save the result.
root_path = f"../model_finetune/selected_data/{model_name}/{dataset_name}/"
scores = torch.load(root_path + "output_scores.pt")

# curated_sample_scores holds (idx, new_score, confidence) tuples; only idx and score are applied.
# (`_confidence` rather than `_`, which IPython uses for the last cell output.)
for sample_idx, new_score, _confidence in curated_sample_scores:
    scores[sample_idx] = new_score

# Save the curated scores to a new file; output_scores.pt on disk is not modified.
torch.save(scores, root_path + "output_scores_curated.pt")


  from .autonotebook import tqdm as notebook_tqdm
  reports = torch.load(report_path)


==== Docta: Doctor for your data. Current version: 0.2 ====
Cured sample size: 10510
Corrupted samples total: 112457
Original Counter(labels): Counter({3: 87975, 2: 86132, 4: 59969, 1: 44401, 0: 18626, 5: 3829})
counting revised label size: 10510
Label size: 300932
Revised Counter(labels): Counter({3: 89665, 2: 86085, 4: 61057, 1: 43117, 0: 18503, 5: 2505})
Label-wise filter out samples: 112457


  labels = torch.load(root_path + "output_labels_revised.pt")


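## Data selection method

Fill a budget of `dataset_size` samples by walking the curated scores from 5 down to 1. A score bucket that fits in the remaining budget is taken whole. A bucket that does not fit contributes only its samples from the `rare_example` detection, taken in descending order of their second field.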


In [None]:
import torch
import random
import numpy as np
from datasets import load_dataset
from collections import Counter

seed = 3
random.seed(seed)
np.random.seed(seed)

# Part 2 (feature-wise): keep the first half of the entries listed under
# detection['rare_example'], retaining only the first two fields of each entry.
# NOTE(review): the fields look like (sample_idx, rarity_score) -- TODO confirm against the report.
# Column 0 is matched against sample indices and column 1 is the sort key in the selection loop.
rare_examples = reports.detection['rare_example']
rare_samples = rare_examples[:len(rare_examples) // 2]

# An empty list becomes a 1-D array that cannot be column-sliced (IndexError), so fall back to
# an explicit empty (0, 2) array that keeps the downstream [:, 0] / [:, 1] indexing valid.
if len(rare_samples) > 0:
    rare_samples_filtered = np.array(rare_samples)[:, :2]
else:
    rare_samples_filtered = np.empty((0, 2))

print(f"Size of the remaining samples with high quality: {len(rare_samples_filtered)}")

# Curated per-sample scores from the score-curation cell (`scores`). The name `labels` is not
# defined by any earlier cell, so converting `labels` here raised a NameError on a fresh kernel.
labels = np.asarray(scores)

# Score buckets are visited from highest to lowest.
label_order = [5, 4, 3, 2, 1]

# Precompute each bucket's sample indices once instead of re-scanning `labels` per iteration.
label_indices_cache = {label: np.where(labels == label)[0] for label in label_order}

filtered_indices = []
for target_label in label_order:
    available_size = dataset_size - len(filtered_indices)
    if available_size <= 0:
        break

    label_indices = label_indices_cache[target_label]

    if len(label_indices) <= available_size:
        # The whole bucket fits in the remaining budget (including an exact fit): take all of it.
        filtered_indices.extend(label_indices.tolist())
    else:
        # The bucket overflows the budget: take only its samples from the rare-example subset,
        # sorted by column 1 in descending order (presumably rarity -- TODO confirm), up to the
        # remaining budget. If too few exist, the loop moves on to the next lower score.
        in_bucket = np.isin(rare_samples_filtered[:, 0], label_indices)
        label_samples = rare_samples_filtered[in_bucket]
        if len(label_samples) > 0:
            descending = np.argsort(label_samples[:, 1])[::-1]
            top_samples = label_samples[descending][:available_size]
            filtered_indices.extend(top_samples[:, 0].astype(int).tolist())

    print("Size of the filtered dataset:", len(filtered_indices))


# Load the full dataset and keep only the rows at the selected indices.
raw_dataset = load_dataset('json', data_files=root_path + 'full_dataset.json')
filtered_dialogs = raw_dataset['train'].select(filtered_indices)

# Build the output path once, so the saved file and the reported path cannot drift apart.
output_path = root_path + "filtered-curated_dataset.json"
filtered_dialogs.to_json(output_path)
print(f"Final dataset saved to {output_path}")
