In [None]:
import os

from datasets import DatasetDict, load_dataset

# Load the recipe-annotated ChemRxiv dataset and materialise its 'chemrxiv'
# split as a DataFrame so the `contains_recipe` flag can be used as a mask.
dataset = load_dataset("amayuelas/lematerial-chemrxiv-filtered")
df = dataset["chemrxiv"].to_pandas()

# Load the upstream LeMat-Synth dataset whose 'chemrxiv' split is filtered in
# the steps that follow. `verification_mode="no_checks"` is passed through to
# `load_dataset` unchanged (see the `datasets` docs for what it disables).
dataset_upstream = load_dataset(
    "LeMaterial/LeMat-Synth",
    verification_mode="no_checks",
)
df_upstream = dataset_upstream["chemrxiv"].to_pandas()

# Ids of the rows flagged by `contains_recipe`. A single `.loc[mask, column]`
# lookup replaces the chained `df[mask]["id"]` indexing.
# NOTE(review): assumes `contains_recipe` is a boolean column without missing
# values - confirm against the dataset schema.
ids_to_keep = df.loc[df["contains_recipe"], "id"].tolist()

# A bare `len(...)` expression in the middle of a cell is evaluated and thrown
# away (only a cell's last expression is displayed), so print the count.
print(f"Number of ids flagged as containing a recipe: {len(ids_to_keep)}")

# Create a new DatasetDict to store the filtered splits
filtered_dataset_dict = DatasetDict()

print(
    "Processing and filtering 'chemrxiv' split directly using dataset.filter()."
)

# Get the 'chemrxiv' split directly from the DatasetDict
chemrxiv_ds = dataset_upstream["chemrxiv"]

# `x in some_list` is a linear scan, so testing every example against the
# `ids_to_keep` list costs O(len(dataset) * len(ids_to_keep)). Build a set once
# so each membership test is an O(1) average-case hash lookup instead.
# NOTE(review): assumes the ids are hashable (e.g. str or int) - confirm.
ids_to_keep_lookup = set(ids_to_keep)

# Keep only the examples whose id was flagged as containing a recipe.
filtered_chemrxiv_ds = chemrxiv_ds.filter(
    lambda example: example["id"] in ids_to_keep_lookup,
    num_proc=os.cpu_count(),  # Use all available CPU cores for parallelism
)

print("Filtering done for 'chemrxiv' split.")

print("Adding to DatasetDict")
# Add the filtered 'chemrxiv' split to the new DatasetDict
filtered_dataset_dict["chemrxiv"] = filtered_chemrxiv_ds

print(f"  Original 'chemrxiv' split size: {len(chemrxiv_ds)}")
print(f"  Filtered 'chemrxiv' split size: {len(filtered_chemrxiv_ds)}")


# The remaining splits are not filtered: each one is assigned into the new
# DatasetDict as-is (the same Dataset object is stored, nothing is rewritten).
for untouched_split in ("arxiv", "omg24"):
    filtered_dataset_dict[untouched_split] = dataset_upstream[untouched_split]
    print(
        f"  '{untouched_split}' split size: {len(dataset_upstream[untouched_split])}"
    )

# Show the assembled DatasetDict, followed by the first record of the
# filtered 'chemrxiv' split as a quick sanity check.
print("\nFinal filtered_dataset_dict structure:")
print(filtered_dataset_dict)
print("\nExample from filtered 'chemrxiv' split:")
first_chemrxiv_record = filtered_dataset_dict["chemrxiv"][0]
print(first_chemrxiv_record)

In [None]:
from datasets import DatasetDict, concatenate_datasets, load_dataset

# Builds a 'sample_for_evaluation' split: a fixed-size random sample from each
# listed split, tagged with a 'source' column naming the split it came from.
# Requires `filtered_dataset_dict` from the previous cell.

# Sampling configuration, named here instead of being buried in the loop.
SAMPLES_PER_SPLIT = 100  # upper bound on entries drawn from each split
SAMPLING_SEED = 42  # fixed shuffle seed so the sample is reproducible

print(
    f"Sampling {SAMPLES_PER_SPLIT} random entries from each split and adding 'source' column."
)

# Per-split samples are collected here and concatenated once at the end.
sampled_datasets_list = []
splits_to_sample = [
    "arxiv",
    "chemrxiv",
    "omg24",
]  # Define the splits you want to sample from

for split_name in splits_to_sample:
    # Guard: the requested split may not exist in the DatasetDict.
    if split_name not in filtered_dataset_dict:
        print(
            f"  Warning: Split '{split_name}' not found in filtered_dataset_dict. Skipping."
        )
        continue

    current_split_ds = filtered_dataset_dict[split_name]
    total_samples_in_split = len(current_split_ds)

    # Guard: nothing to sample from an empty split.
    if total_samples_in_split == 0:
        print(
            f"  Warning: Split '{split_name}' is empty. Skipping sampling."
        )
        continue

    # Shuffle with a fixed seed, then take the leading entries; together this
    # is a reproducible random sample without replacement.
    shuffled_ds = current_split_ds.shuffle(seed=SAMPLING_SEED)

    # Take the configured number of samples, or the whole split if smaller.
    num_samples_to_take = min(SAMPLES_PER_SPLIT, total_samples_in_split)
    sampled_entries = shuffled_ds.select(range(num_samples_to_take))

    # Add the "source" column. `map` is applied right here inside the loop
    # iteration, so the lambda sees the current value of `split_name`.
    sampled_entries_with_source = sampled_entries.map(
        lambda example: {"source": split_name},
    )

    print(
        f"  Sampled {num_samples_to_take} entries from '{split_name}' (original size: {total_samples_in_split})."
    )
    sampled_datasets_list.append(sampled_entries_with_source)

# Concatenate all sampled datasets into a single evaluation split. This is
# skipped when every split was missing or empty, because there would be
# nothing to concatenate.
if sampled_datasets_list:
    combined_sampled_ds = concatenate_datasets(sampled_datasets_list)
    filtered_dataset_dict["sample_for_evaluation"] = combined_sampled_ds
    print(
        f"\nCreated 'sample_for_evaluation' split with {len(combined_sampled_ds)} total entries."
    )
    print("Example from 'sample_for_evaluation' with 'source' column:")
    print(filtered_dataset_dict["sample_for_evaluation"][0])
else:
    print(
        "\nNo samples were generated for 'sample_for_evaluation' as no valid splits were found or all were empty."
    )

print("\nFinal filtered_dataset_dict structure:")
print(filtered_dataset_dict)

In [None]:
# Publish every split currently held in `filtered_dataset_dict` to the Hub.
# NOTE(review): the target repo id is the same one loaded as the upstream
# source earlier in this notebook, so this presumably replaces the published
# data in place - confirm the overwrite is intended before running.
filtered_dataset_dict.push_to_hub(repo_id="LeMaterial/LeMat-Synth")

In [None]:
# Reload the dataset from the Hub and display it, so the published splits
# can be inspected. Note that this rebinds `dataset`, which earlier in the
# notebook held the recipe-annotated ChemRxiv dataset.
dataset = load_dataset(path="LeMaterial/LeMat-Synth")
dataset