In [1]:
import json
from datasets import load_from_disk, load_dataset, concatenate_datasets
from transformers import RobertaTokenizer
import os


In [2]:
# Load both preprocessed corpora from disk in one pass; the left-hand names
# line up positionally with the paths in the tuple below.
dataset_csn, dataset_seq = (
    load_from_disk(dataset_dir)
    for dataset_dir in (
        "/data/nicolasmaier/dataset/hf_clean_csn_1",
        "/data/nicolasmaier/dataset/hf_clean_seq_1",
    )
)

# Show the split names, feature lists and row counts of each corpus.
print(dataset_csn, dataset_seq, sep="\n")


DatasetDict({
    train: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url', 'code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 909090
    })
    test: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url', 'code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 50975
    })
    validation: Dataset({
        features: ['repository_name', 'func_path_in_repository', 'func_name', 'whole_func_string', 'language', 'func_code_string', 'func_code_tokens', 'func_documentation_string', 'func_documentation_tokens', 'split_name', 'func_code_url', 'code', 'input_ids', 'attention_mask

In [None]:
def _rows_for_language(dataset, language):
    """Return the rows of ``dataset`` whose "language" column equals ``language``.

    The filter runs on 64 worker processes.
    """
    return dataset.filter(
        lambda row: row["language"] == language,
        num_proc=64,
    )


# One per-language view of the CSN corpus; every split is filtered the same way.
dataset_csn_java = _rows_for_language(dataset_csn, "java")
dataset_csn_python = _rows_for_language(dataset_csn, "python")
dataset_csn_javascript = _rows_for_language(dataset_csn, "javascript")


In [4]:
# Metadata columns of the CSN corpus that are not needed after language
# filtering; only 'code', 'input_ids', 'attention_mask' and 'labels' remain.
csn_extra_cols = [
    "repository_name",
    "func_path_in_repository",
    "func_name",
    "whole_func_string",
    "language",
    "func_code_string",
    "func_code_tokens",
    "func_documentation_string",
    "func_documentation_tokens",
    "split_name",
    "func_code_url",
]

# Columns dropped from the sequence corpus. "code" is deliberately NOT listed:
# the CSN subsets keep a "code" column, and the datasets are concatenated
# later, so the column sets presumably have to line up.
seq_extra_cols = [
    "contents",
    "xmi",
    "originalLine",
    "seq",
]


def _drop_present_columns(dataset_dict, unwanted):
    """Return a copy of ``dataset_dict`` without the ``unwanted`` columns.

    Columns that a split no longer has are skipped instead of raising.  The
    names below are rebound to their own trimmed versions, so a plain
    ``remove_columns`` call fails with ``ValueError`` the second time this
    cell is executed; skipping absent columns makes the cell safe to re-run.
    The resulting schemas are printed below so a misspelled column name is
    still visible to the reader.

    Args:
        dataset_dict: mapping of split name to dataset (e.g. a DatasetDict).
        unwanted: iterable of column names to remove where present.

    Returns:
        A new mapping of the same type as ``dataset_dict``.
    """
    trimmed = {}
    for split_name, split in dataset_dict.items():
        stale = [column for column in unwanted if column in split.column_names]
        # Leave the split untouched when there is nothing left to drop.
        trimmed[split_name] = split.remove_columns(stale) if stale else split
    return type(dataset_dict)(trimmed)


dataset_csn_java = _drop_present_columns(dataset_csn_java, csn_extra_cols)
dataset_csn_python = _drop_present_columns(dataset_csn_python, csn_extra_cols)
dataset_csn_javascript = _drop_present_columns(dataset_csn_javascript, csn_extra_cols)
dataset_seq = _drop_present_columns(dataset_seq, seq_extra_cols)

print(dataset_csn_java)
print(dataset_csn_python)
print(dataset_csn_javascript)
print(dataset_seq)


DatasetDict({
    train: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 424692
    })
    test: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 25027
    })
    validation: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 14615
    })
})
DatasetDict({
    train: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 374271
    })
    test: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 20115
    })
    validation: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 20695
    })
})
DatasetDict({
    train: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 110127
    })
    test: Dataset({
        features: ['code', 'input_ids', 'attention_mask', 'labels'],
     

In [9]:
SEED = 42


def _sample_rows(split, count):
    """Shuffle ``split`` with the shared ``SEED`` and keep its first ``count`` rows."""
    return split.shuffle(seed=SEED).select(range(count))


def _merge_split(csn_split, seq_split, csn_count, seq_count):
    """Build one merged split from the three CSN language subsets and the sequence corpus.

    Args:
        csn_split: split name to read from each CSN language subset.
        seq_split: split name to read from ``dataset_seq`` (its validation
            split is addressed as "valid", not "validation").
        csn_count: rows sampled from each of the three CSN subsets.
        seq_count: rows sampled from the sequence corpus.

    Returns:
        The concatenation, in order: Java, Python, JavaScript, sequence.
    """
    csn_sources = (dataset_csn_java, dataset_csn_python, dataset_csn_javascript)
    samples = [_sample_rows(source[csn_split], csn_count) for source in csn_sources]
    samples.append(_sample_rows(dataset_seq[seq_split], seq_count))
    return concatenate_datasets(samples)


# The sequence corpus contributes as many rows as the three CSN languages together.
merged_train = _merge_split("train", "train", 100_000, 300_000)
merged_valid = _merge_split("validation", "valid", 1000, 3000)
merged_test = _merge_split("test", "test", 1000, 3000)


Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/train/cache-478d3d66bbb95c76.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/train/cache-2f7a4879408ec3fc.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/train/cache-0f82bc98440e68de.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_seq_1/train/cache-53cd606429ad6615.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/validation/cache-d9d7996db8ee43a5.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/validation/cache-c19e01dc169dc06a.arrow
Loading cached shuffled indices for dataset at /data/nicolasmaier/dataset/hf_clean_csn_1/validation/cache-4b4c9e475f825271.arrow


In [12]:
# Destination directory -> merged dataset, in train / valid / test order.
merged_outputs = {
    "/data/nicolasmaier/dataset/hf_merged_train_1": merged_train,
    "/data/nicolasmaier/dataset/hf_merged_valid_1": merged_valid,
    "/data/nicolasmaier/dataset/hf_merged_test_1": merged_test,
}

# Print every summary first so the schemas and row counts appear before the
# save progress output.
for merged in merged_outputs.values():
    print(merged)

# Persist each merged dataset to its own directory.
for target_dir, merged in merged_outputs.items():
    merged.save_to_disk(target_dir)


Dataset({
    features: ['code', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 600000
})
Dataset({
    features: ['code', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 6000
})
Dataset({
    features: ['code', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 6000
})


Flattening the indices:   0%|          | 0/600 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/6 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/6 [00:00<?, ?ba/s]