In [3]:
import pandas as pd
from collections import defaultdict

def create_matching_subset(dataset_a_path, dataset_b_path, chunk_size=10000):
    """Pick rows from dataset B that mirror dataset A's row configurations.

    A "row configuration" is the tuple of values in the matching columns
    listed below. Dataset A is read in full and the number of rows per
    configuration is counted. Dataset B is then streamed in chunks and, for
    every configuration, at most as many B rows are kept as A contains.
    Reading stops as soon as every configuration has been filled.

    Parameters
    ----------
    dataset_a_path : path or buffer accepted by ``pd.read_csv``
        CSV whose configuration counts should be reproduced.
    dataset_b_path : path or buffer accepted by ``pd.read_csv``
        CSV to sample rows from; read ``chunk_size`` rows at a time.
    chunk_size : int, default 10000
        Number of rows of B processed per iteration.

    Returns
    -------
    pandas.DataFrame
        The selected rows of B, restricted to the matching columns (in the
        order listed below) with a fresh 0..n-1 index. If B does not contain
        enough rows for some configuration, the result simply holds fewer
        rows for it; no error is raised.

    Notes
    -----
    ``groupby`` drops rows that have a missing value in any matching column,
    so such rows are neither counted in A nor selected from B.
    """
    # Define columns to match on
    columns = ['PedPed', 'Barrier', 'CrossingSignal', 'AttributeLevel',
               'ScenarioTypeStrict', 'NumberOfCharacters', 'DiffNumberOFCharacters',
               'Man', 'Woman', 'Pregnant', 'Stroller', 'OldMan', 'OldWoman',
               'Boy', 'Girl', 'Homeless', 'LargeWoman', 'LargeMan', 'Criminal',
               'MaleExecutive', 'FemaleExecutive', 'FemaleAthlete', 'MaleAthlete',
               'FemaleDoctor', 'MaleDoctor', 'Dog', 'Cat']

    # Load dataset A and count occurrences of each unique row configuration.
    # The dict doubles as the "still needed" counter and is decremented below.
    df_a = pd.read_csv(dataset_a_path, usecols=columns)
    needed_rows = df_a.groupby(columns).size().to_dict()

    # Running total of rows still to collect, so the stop condition does not
    # have to re-sum the whole dict after every chunk.
    remaining = sum(needed_rows.values())

    # Pieces are collected in a list and concatenated once at the end;
    # concatenating inside the loop copies the accumulated frame every time.
    matched_pieces = []

    if remaining > 0:
        chunk_number = 0

        # Process B in chunks
        for chunk in pd.read_csv(dataset_b_path, chunksize=chunk_size, usecols=columns):
            chunk_number += 1
            print(f"Processing chunk {chunk_number}")

            # read_csv keeps the file's column order; normalise to `columns`.
            chunk = chunk[columns]

            # One pass over the chunk: configuration -> positional row indices.
            positions_by_config = chunk.groupby(columns).indices

            for row_config, positions in positions_by_config.items():
                needed_count = needed_rows.get(row_config, 0)
                if needed_count <= 0:
                    continue  # Not present in A, or already fully collected

                # Take no more rows than are still needed for this configuration
                taken = positions[:needed_count]
                matched_pieces.append(chunk.iloc[taken])

                needed_rows[row_config] = needed_count - len(taken)
                remaining -= len(taken)

            # Stop if we have matched all required rows
            if remaining == 0:
                break

    if not matched_pieces:
        # Nothing matched (or A was empty): return an empty frame with the expected columns
        return pd.DataFrame(columns=columns)

    return pd.concat(matched_pieces, ignore_index=True)

In [4]:
# Build the subset of SharedResponses whose row configurations mirror the LLM dataset
df_a_path, df_b_path = '../PREP_llm_dataset.csv', '../SharedResponses.csv'
new_mme_27 = create_matching_subset(
    dataset_a_path=df_a_path,
    dataset_b_path=df_b_path,
)

Processing chunk 1
Processing chunk 2
Processing chunk 3
Processing chunk 4
Processing chunk 5
Processing chunk 6
Processing chunk 7
Processing chunk 8
Processing chunk 9
Processing chunk 10
Processing chunk 11
Processing chunk 12
Processing chunk 13
Processing chunk 14
Processing chunk 15
Processing chunk 16
Processing chunk 17
Processing chunk 18
Processing chunk 19
Processing chunk 20
Processing chunk 21
Processing chunk 22
Processing chunk 23
Processing chunk 24
Processing chunk 25
Processing chunk 26
Processing chunk 27
Processing chunk 28
Processing chunk 29
Processing chunk 30
Processing chunk 31
Processing chunk 32
Processing chunk 33
Processing chunk 34
Processing chunk 35
Processing chunk 36
Processing chunk 37
Processing chunk 38
Processing chunk 39
Processing chunk 40
Processing chunk 41
Processing chunk 42
Processing chunk 43
Processing chunk 44
Processing chunk 45
Processing chunk 46
Processing chunk 47
Processing chunk 48
Processing chunk 49
Processing chunk 50
Processin

In [5]:
# Verify distribution
# Compare, column by column, the share of each category in dataset A against
# the share in the subset drawn from B.
df_llm_subset = pd.read_csv(df_a_path)  # Reload A to verify distribution

# Proportions are floats; allow for rounding noise instead of exact equality.
dist_tolerance = 1e-9

for col in new_mme_27.columns:
    original_dist = df_llm_subset[col].value_counts(normalize=True)
    subset_dist = new_mme_27[col].value_counts(normalize=True)

    # value_counts orders rows by frequency, so tied categories can come out in
    # a different order for the two frames. Align on the category label first;
    # a category missing from one side counts as proportion 0 there.
    aligned = pd.concat(
        [original_dist, subset_dist], axis=1, keys=['original', 'subset']
    ).fillna(0.0)
    largest_gap = (aligned['original'] - aligned['subset']).abs().max()

    # Compare the distributions (an empty comparison yields NaN and counts as a match)
    if largest_gap > dist_tolerance:
        print(f"Mismatch in distribution for '{col}'.")
        print("Original distribution:\n", original_dist)
        print("Subset distribution:\n", subset_dist)
    else:
        print(f"The distribution for '{col}' matches.")

Mismatch in distribution for 'PedPed'.
Original distribution:
 PedPed
0    0.502728
1    0.497272
Name: proportion, dtype: float64
Subset distribution:
 PedPed
0    0.501005
1    0.498995
Name: proportion, dtype: float64
Mismatch in distribution for 'Barrier'.
Original distribution:
 Barrier
0    0.748636
1    0.251364
Name: proportion, dtype: float64
Subset distribution:
 Barrier
0    0.744636
1    0.255364
Name: proportion, dtype: float64
Mismatch in distribution for 'CrossingSignal'.
Original distribution:
 CrossingSignal
0    0.622809
2    0.188983
1    0.188208
Name: proportion, dtype: float64
Subset distribution:
 CrossingSignal
0    0.629821
2    0.186199
1    0.183980
Name: proportion, dtype: float64
Mismatch in distribution for 'AttributeLevel'.
Original distribution:
 AttributeLevel
Young      0.086332
Old        0.086332
Hoomans    0.085502
Pets       0.085502
Less       0.083888
More       0.083888
Fat        0.082395
Fit        0.082395
Male       0.081562
Female     0.081

In [None]:
from scipy.stats import randint  # For defining distributions for random search

# Search space for the random search. scipy's randint(low, high) draws integers
# from low up to high - 1 (the upper bound is exclusive).
param_dist = {
    # Number of trees in the forest: integers 50..299
    'n_estimators': randint(50, 300),
    # None plus ten depth values drawn from 1..29. The draw is seeded so the
    # candidate list is the same on every run, and converted to plain Python ints.
    'max_depth': [None] + randint(1, 30).rvs(10, random_state=42).tolist(),
    # Minimum samples required to split an internal node: integers 2..9
    'min_samples_split': randint(2, 10),
    # Minimum samples required to be at a leaf node: integers 1..3
    'min_samples_leaf': randint(1, 4),
}

In [1]:
from scipy.stats import randint  # For defining distributions for random search

# Printing a frozen distribution only shows its object repr, not sampled values
print(randint(low=50, high=300))

<scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x0000018CD1D5B550>
