In [1]:
import findspark

# Point Python at the local Spark installation before pyspark is imported in the
# next cell (findspark adds it to sys.path at runtime).
findspark.init()

In [2]:
import os

import pandas as pd
import pyspark
import pyspark.pandas as ps
from datasets import load_dataset
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
from pyspark.sql.window import Window

  from .autonotebook import tqdm as notebook_tqdm


In [93]:
# Load only the 'duped.70m' split of the memorized-evals dataset as a single Dataset.
evals = load_dataset(path="EleutherAI/pythia-memorized-evals", split='duped.70m')

In [95]:
# Show the split's feature names and row count.
evals

Dataset({
    features: ['index', 'tokens', '__index_level_0__'],
    num_rows: 463953
})

In [94]:
# Convert the whole split to pandas for a quick look. The notebook display
# truncates it, but every row is still converted.
# NOTE(review): `evals.select(range(5)).to_pandas()` would avoid converting all rows.
evals.to_pandas()

Unnamed: 0,index,tokens,__index_level_0__
0,232,"[996, 186, 29, 1088, 7392, 568, 29860, 5264, 2...",232
1,764,"[599, 17585, 423, 92, 3728, 12945, 423, 92, 33...",764
2,806,"[313, 39386, 27, 19939, 428, 5270, 310, 1239, ...",806
3,891,"[94, 187, 50262, 61, 2099, 92, 8798, 94, 187, ...",891
4,1060,"[4022, 305, 48095, 4477, 15, 187, 475, 187, 47...",1060
...,...,...,...
463948,146431433,"[2032, 748, 748, 582, 898, 558, 187, 50274, 82...",2287433
463949,146431569,"[544, 18, 15, 17, 13, 470, 15, 17, 9502, 187, ...",2287569
463950,146431580,"[4, 27954, 187, 4, 604, 10807, 64, 4785, 64, 3...",2287580
463951,146431652,"[4637, 15, 187, 475, 1422, 778, 4044, 247, 349...",2287652


In [3]:
# Local-mode Spark session using every available core (local[*]).
# NOTE(review): per the Spark configuration docs, spark.driver.cores and the driver
# memory-overhead settings are described as cluster-mode options; they are likely
# ignored under local[*] — confirm before relying on them.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('semantic-memorization')
    .config('spark.driver.memory', '128g')
    .config('spark.driver.cores', '128')
    .config('spark.driver.memoryOverheadFactor', '0.2')
    .getOrCreate()
)

/home/alvin/lib/miniconda3/lib/python3.11/site-packages/pyspark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/09/02 22:27:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [99]:
# Confirm the builder handed back a SparkSession.
spark.__class__

pyspark.sql.session.SparkSession

In [4]:
# Load every split of the dataset as a DatasetDict (the 'train' split is used next).
dd = load_dataset(path='EleutherAI/pile-deduped-pythia-random-sampled')

In [10]:
# Convert the 'train' split to a pandas DataFrame.
real = dd['train'].to_pandas()

In [15]:
# Inspect column dtypes: an int Index, float64 columns named after model sizes
# (70M ... 12B), and an object Tokens column.
real.dtypes

Index       int64
70M       float64
160M      float64
410M      float64
1B        float64
1.4B      float64
2.8B      float64
6.9B      float64
12B       float64
Tokens     object
dtype: object

In [None]:
schema = 'deduped'
model_size = '70m'
split_name = f"{schema}.{model_size}"

# Memorized sequences for one (schema, model size) split as pandas, with the
# dataset's `index` column renamed to `sequence_id`. Requesting the split via
# `split=` (as the `evals` cell does) yields a single Dataset directly instead of
# building a DatasetDict and indexing it.
dataset = (
    load_dataset('EleutherAI/pythia-memorized-evals', split=split_name)
    .to_pandas()
    .rename(columns={'index': 'sequence_id'})
)
# Turn each token sequence into a plain Python list via `.tolist()` (presumably
# the cells hold numpy arrays after to_pandas — confirm). Item assignment is used
# rather than attribute assignment, which pandas only honours for existing columns.
dataset['tokens'] = dataset['tokens'].map(lambda tokens: tokens.tolist())

In [None]:
# Keep only the id and token columns and hand them to Spark via pandas-on-Spark,
# aliasing the result 'main' for qualified column references.
columns = ['sequence_id', 'tokens']
main_df = (
    ps.from_pandas(dataset.loc[:, columns])
    .to_spark()
    .alias('main')
)

In [None]:
# Rename mapping shared by both token-frequency splits.
token_frequency_renames = {'TokenID': 'token_id', 'Frequency': 'frequency'}

# (sequence_id, frequency)
sequence_duplicates = (
    load_dataset(f'usvsnsp/{schema}-num-duplicates', split='train')
    .to_pandas()
    .rename(columns={'Index': 'sequence_id', 'Counts': 'frequency'})
)

# Load the token-frequency repository once and take both splits from it, rather
# than calling load_dataset separately for each split.
token_frequency_splits = load_dataset(f'usvsnsp/{schema}-num-frequencies')
# (token_id, frequency)
memorized_frequencies = (
    token_frequency_splits['memorized']
    .to_pandas()
    .rename(columns=token_frequency_renames)
)
# (token_id, frequency)
non_memorized_frequencies = (
    token_frequency_splits['non_memorized']
    .to_pandas()
    .rename(columns=token_frequency_renames)
)

In [None]:
# (rows, columns) of each loaded table, in load order.
tuple(frame.shape for frame in (sequence_duplicates, memorized_frequencies, non_memorized_frequencies))

In [None]:
# pandas -> pandas-on-Spark -> Spark. The two token-frequency frames are aliased
# so their identically named `frequency` columns can be told apart after a join.
sequence_duplicates_df = ps.from_pandas(sequence_duplicates).to_spark()
memorized_frequencies_df, non_memorized_frequencies_df = (
    ps.from_pandas(frame).to_spark().alias(alias_name)
    for frame, alias_name in (
        (memorized_frequencies, 'memorized'),
        (non_memorized_frequencies, 'non_memorized'),
    )
)

In [None]:
# Save as parquet for efficiency, so later cells can reload via spark.read.parquet.
# mode('overwrite') keeps the cell re-runnable: Spark's default save mode
# ('errorifexists') fails as soon as a target directory already exists.
main_df.write.mode('overwrite').parquet('datasets/main')
sequence_duplicates_df.write.mode('overwrite').parquet('datasets/sequence_duplicates')
memorized_frequencies_df.write.mode('overwrite').parquet('datasets/memorized_frequencies')
non_memorized_frequencies_df.write.mode('overwrite').parquet('datasets/non_memorized_frequencies')

In [72]:
# Reload the persisted inputs. The aliases let later cells refer to
# `sequence_dups.frequency`, `memorized.frequency` and `non_memorized.frequency`.
main_df = spark.read.parquet('datasets/main')
sequence_duplicates_df, memorized_frequencies_df, non_memorized_frequencies_df = (
    spark.read.parquet(parquet_path).alias(alias_name)
    for parquet_path, alias_name in (
        ('datasets/sequence_duplicates', 'sequence_dups'),
        ('datasets/memorized_frequencies', 'memorized'),
        ('datasets/non_memorized_frequencies', 'non_memorized'),
    )
)

In [34]:
# pandas-on-Spark view of the same parquet data (not referenced by later visible cells).
main_ps = ps.read_parquet(path='datasets/main')



In [5]:
# One row per token: posexplode emits each array element with its position.
flattened_df = main_df.select(
    F.col('sequence_id'),
    F.posexplode(F.col('tokens')).alias('token_index', 'token_id'),
)

In [6]:
# Preview the first five exploded rows.
flattened_df.show(n=5)

+-----------+-----------+--------+
|sequence_id|token_index|token_id|
+-----------+-----------+--------+
|   89059350|          0|    4090|
|   89059350|          1|      64|
|   89059350|          2|    2606|
|   89059350|          3|      16|
|   89059350|          4|    1286|
+-----------+-----------+--------+
only showing top 5 rows



In [7]:
# Attach the memorized and non-memorized corpus frequency of every token. Left
# joins keep tokens missing from a frequency table (their frequency becomes null);
# the 'memorized' / 'non_memorized' aliases disambiguate the two `frequency` columns.
token_frequencies_df = (
    flattened_df
    .join(memorized_frequencies_df, 'token_id', 'left')
    .join(non_memorized_frequencies_df, 'token_id', 'left')
    .select(
        F.col('sequence_id'),
        F.col('token_index'),
        F.col('token_id'),
        F.col('memorized.frequency').alias('memorized_frequency'),
        F.col('non_memorized.frequency').alias('non_memorized_frequency'),
    )
)

In [8]:
# Preview five joined token rows.
token_frequencies_df.show(n=5)

                                                                                

+-----------+-----------+--------+-------------------+-----------------------+
|sequence_id|token_index|token_id|memorized_frequency|non_memorized_frequency|
+-----------+-----------+--------+-------------------+-----------------------+
|   89059350|          0|    4090|             460524|               15608026|
|   89059350|          1|      64|           59295356|             1600702498|
|   89059350|          2|    2606|             574217|               26001181|
|   89059350|          3|      16|           30728404|              785519346|
|   89059350|          4|    1286|             726537|               30304792|
+-----------+-----------+--------+-------------------+-----------------------+
only showing top 5 rows



In [9]:
# Regroup to one row per sequence. For each per-token value, collect
# (token_index, value) structs and sort them. sort_array compares structs field
# by field, starting with token_index, so each array comes out in token order
# (collect_list by itself gives no ordering guarantee).
combined_df = token_frequencies_df.groupby('sequence_id').agg(*[
    F.sort_array(F.collect_list(F.struct('token_index', value_column))).alias(output_name)
    for value_column, output_name in (
        ('token_id', 'tokens'),
        ('memorized_frequency', 'memorized_frequencies'),
        ('non_memorized_frequency', 'non_memorized_frequencies'),
    )
])

In [98]:
# Export the 'duped.70m' evals split to parquet under datasets/ with a descriptive
# file name. Dataset.to_parquet returns the number of bytes written.
os.makedirs('datasets', exist_ok=True)
evals.to_parquet('datasets/evals_duped_70m.parquet')

Creating parquet from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 464/464 [00:00<00:00, 626.40ba/s]


246822996

In [10]:
# Preview 20 grouped sequences (Spark's default row count).
combined_df.show(n=20)



+-----------+--------------------+---------------------+-------------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|
+-----------+--------------------+---------------------+-------------------------+
|      31156|[{0, 17706}, {1, ...| [{0, 131653}, {1,...|     [{0, 1993548}, {1...|
|     100847|[{0, 740}, {1, 83...| [{0, 2741218}, {1...|     [{0, 131645030}, ...|
|     166578|[{0, 253}, {1, 66...| [{0, 75092269}, {...|     [{0, 8410663429},...|
|     206745|[{0, 187}, {1, 18...| [{0, 228489488}, ...|     [{0, 11512507473}...|
|     210098|[{0, 305}, {1, 27...| [{0, 734372}, {1,...|     [{0, 70087349}, {...|
|     586758|[{0, 3003}, {1, 2...| [{0, 768757}, {1,...|     [{0, 21889588}, {...|
|     622533|[{0, 5118}, {1, 3...| [{0, 83047}, {1, ...|     [{0, 11073092}, {...|
|     632923|[{0, 8379}, {1, 8...| [{0, 72092}, {1, ...|     [{0, 6324830}, {1...|
|     817683|[{0, 431}, {1, 10...| [{0, 3350852}, {1...|     [{0, 27260044}, {...|
|   

                                                                                

In [65]:
# Drop token_index now that the arrays are ordered: map each sorted struct array
# to the array of its value field. The result is aliased 'new_df' for the join below.
new_df = combined_df.select(
    F.col('sequence_id'),
    F.transform('tokens', lambda entry: entry['token_id']).alias('tokens'),
    F.transform('memorized_frequencies', lambda entry: entry['memorized_frequency']).alias('memorized_frequencies'),
    F.transform('non_memorized_frequencies', lambda entry: entry['non_memorized_frequency']).alias('non_memorized_frequencies'),
).alias('new_df')

In [86]:
# Confirm the result is a plain Spark DataFrame.
new_df.__class__

pyspark.sql.dataframe.DataFrame

In [66]:
# Preview 20 flattened sequences (Spark's default row count).
new_df.show(n=20)

+-----------+--------------------+---------------------+-------------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|
+-----------+--------------------+---------------------+-------------------------+
|     139584|[551, 69, 889, 20...| [9750575, 5888180...|     [221744263, 28311...|
|     182943|[15, 5056, 50275,...| [159811713, 24814...|     [10186570740, 875...|
|     187527|[309, 476, 626, 4...| [8095850, 4370056...|     [1034833708, 4347...|
|     266104|[11296, 15, 505, ...| [169697, 15981171...|     [3256110, 1018657...|
|     429891|[3863, 407, 253, ...| [286799, 7381820,...|     [18429900, 767613...|
|     462298|[627, 369, 1077, ...| [1947802, 6813024...|     [247532725, 10062...|
|     559720|[426, 16375, 4399...| [14615048, 135329...|     [576670219, 20591...|
|     586559|[14749, 5264, 66,...| [196079, 2796188,...|     [1712636, 1078081...|
|     601009|[187, 187, 5146, ...| [228489488, 22848...|     [11512507473, 115...|
|   

In [76]:
# Attach each sequence's duplicate count as `sequence_frequency`. The inner join
# keeps only sequences present in both frames.
final_df = (
    new_df
    .join(sequence_duplicates_df, 'sequence_id', 'inner')
    .select('new_df.*', F.col('sequence_dups.frequency').alias('sequence_frequency'))
)

In [78]:
# Preview 20 rows of the final dataset (Spark's default row count).
final_df.show(n=20)



+-----------+--------------------+---------------------+-------------------------+------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|sequence_frequency|
+-----------+--------------------+---------------------+-------------------------+------------------+
|      31156|[17706, 5803, 256...| [131653, 279533, ...|     [1993548, 7958488...|              3021|
|     100847|[740, 8375, 187, ...| [2741218, 197043,...|     [131645030, 63006...|                 1|
|     166578|[253, 669, 8604, ...| [75092269, 140530...|     [8410663429, 1869...|             91591|
|     206745|[187, 187, 510, 4...| [228489488, 22848...|     [11512507473, 115...|                 2|
|     210098|[305, 27, 470, 13...| [734372, 33055803...|     [70087349, 129409...|              1484|
|     586758|[3003, 25900, 154...| [768757, 107633, ...|     [21889588, 971433...|             16465|
|     622533|[5118, 347, 27, 1...| [83047, 8201742, ...|     [11073092, 939874...|

                                                                                

In [80]:
# Persist the final dataset. mode('overwrite') lets the cell be re-run without
# failing on the existing output directory (Spark's default mode is 'errorifexists').
final_df.write.mode('overwrite').parquet('datasets/final_dataset')

                                                                                

In [81]:
# Display the DataFrame repr, which lists the final column names and types.
final_df

DataFrame[sequence_id: bigint, tokens: array<bigint>, memorized_frequencies: array<bigint>, non_memorized_frequencies: array<bigint>, sequence_frequency: bigint]