In [1]:
import findspark

# NOTE(review): presumably this makes the local Spark installation visible to
# Python before `pyspark` is imported in the next cell — confirm it is still
# required in this environment.
findspark.init()

In [3]:
import pandas as pd
import pyspark
import pyspark.pandas as ps
from datasets import load_dataset
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
from pyspark.sql.window import Window

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Create (or reuse, via getOrCreate) the Spark session shared by every cell below.
# The job runs with a 'local[*]' master, so the driver settings are what size it:
# 128 cores, 128g of memory and a 0.2 memory-overhead factor.
spark = (
    SparkSession.builder
    .appName('semantic-memorization')
    .config('spark.driver.cores', '128')
    .config('spark.driver.memory', '128g')
    .config('spark.driver.memoryOverheadFactor', '0.2')
    .master('local[*]')
    .getOrCreate()
)

/home/alvin/lib/miniconda3/lib/python3.11/site-packages/pyspark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/09/03 21:40:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
# Choose which evaluation split to process; the split key is "<schema>.<model_size>".
schema = 'duped'
model_size = '160m'
split_name = f"{schema}.{model_size}"

# Pull the chosen split into pandas and rename 'index' to the 'sequence_id' key
# that the later joins use.
dataset = (
    load_dataset('EleutherAI/pythia-memorized-evals')[split_name]
    .to_pandas()
    .rename(columns={'index': 'sequence_id'})
)

# Turn each row's token container into a plain Python list via its .tolist().
# Bracket indexing is used for the assignment: attribute-style assignment
# (`dataset.tokens = ...`) is discouraged by pandas because it silently creates
# a plain attribute instead of a column when the name does not already exist.
dataset['tokens'] = dataset['tokens'].map(lambda row_tokens: row_tokens.tolist())

In [7]:
# Keep only the two columns used downstream, then convert
# pandas -> pandas-on-Spark -> Spark DataFrame, aliased as 'main'.
columns = ['sequence_id', 'tokens']
main_df = (
    ps.from_pandas(dataset[columns])
    .to_spark()
    .alias('main')
)



In [8]:
# (sequence_id, frequency)
sequence_duplicates = (
    load_dataset(f'usvsnsp/{schema}-num-duplicates')['train']
    .to_pandas()
    .rename(columns={'Index': 'sequence_id', 'Counts': 'frequency'})
)

# Load the token-frequency dataset once; both splits below are read from this
# single object instead of calling load_dataset twice for the same repository.
token_frequency_splits = load_dataset(f'usvsnsp/{schema}-num-frequencies')
# Column renames shared by both splits so the two tables stay schema-compatible.
token_frequency_columns = {'TokenID': 'token_id', 'Frequency': 'frequency'}

# (token_id, frequency)
memorized_frequencies = (
    token_frequency_splits['memorized']
    .to_pandas()
    .rename(columns=token_frequency_columns)
)
# (token_id, frequency)
non_memorized_frequencies = (
    token_frequency_splits['non_memorized']
    .to_pandas()
    .rename(columns=token_frequency_columns)
)

In [9]:
# Sanity check: display the (rows, columns) shape of each of the three pandas lookup tables.
sequence_duplicates.shape, memorized_frequencies.shape, non_memorized_frequencies.shape

((146432000, 2), (60000, 2), (60000, 2))

In [10]:
# Convert each pandas lookup table to a Spark DataFrame (pandas -> pandas-on-Spark -> Spark).
# The two token-frequency tables are aliased so their identically named
# 'frequency' columns can be told apart by qualifier.
sequence_duplicates_df = (
    ps.from_pandas(sequence_duplicates)
    .to_spark()
)
memorized_frequencies_df = (
    ps.from_pandas(memorized_frequencies)
    .to_spark()
    .alias('memorized')
)
non_memorized_frequencies_df = (
    ps.from_pandas(non_memorized_frequencies)
    .to_spark()
    .alias('non_memorized')
)



In [11]:
# Save as parquet for efficiency.
# mode='overwrite' keeps this cell re-runnable: with the default save mode the
# write raises as soon as the target directory exists, so a second execution
# (or a Restart & Run All after a previous run) would fail here.
# NOTE(review): run this cell while the frames are still the in-memory ones
# built from pandas; overwriting a path that the same DataFrame is being read
# from is expected to be rejected by Spark — confirm before re-running it
# after the parquet-backed frames have replaced them.
main_df.write.parquet('datasets/main', mode='overwrite')
sequence_duplicates_df.write.parquet('datasets/sequence_duplicates', mode='overwrite')
memorized_frequencies_df.write.parquet('datasets/memorized_frequencies', mode='overwrite')
non_memorized_frequencies_df.write.parquet('datasets/non_memorized_frequencies', mode='overwrite')

23/09/03 21:47:37 WARN TaskSetManager: Stage 0 contains a task of very large size (1103 KiB). The maximum recommended task size is 1000 KiB.
23/09/03 21:47:42 WARN TaskSetManager: Stage 1 contains a task of very large size (15428 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

In [13]:
# Re-bind the working frames to their parquet-backed copies.
# The three lookup tables are aliased so later cells can refer to their columns
# by qualifier ('sequence_dups.frequency', 'memorized.frequency', ...).
main_df = spark.read.parquet('datasets/main')
sequence_duplicates_df = (
    spark.read
    .parquet('datasets/sequence_duplicates')
    .alias('sequence_dups')
)
memorized_frequencies_df = (
    spark.read
    .parquet('datasets/memorized_frequencies')
    .alias('memorized')
)
non_memorized_frequencies_df = (
    spark.read
    .parquet('datasets/non_memorized_frequencies')
    .alias('non_memorized')
)

In [14]:
# Explode every token array into one row per token, keeping the token's
# position ('token_index') alongside its id ('token_id').
flattened_df = main_df.select(
    'sequence_id',
    F.posexplode('tokens').alias('token_index', 'token_id'),
)

In [15]:
# Preview the first five exploded (sequence_id, token_index, token_id) rows.
flattened_df.show(5)

+-----------+-----------+--------+
|sequence_id|token_index|token_id|
+-----------+-----------+--------+
|   90261337|          0|   14592|
|   90261337|          1|   50254|
|   90261337|          2|   50275|
|   90261337|          3|      30|
|   90261337|          4|     470|
+-----------+-----------+--------+
only showing top 5 rows



In [16]:
# Attach both frequency lookups to every token row.
# Left joins keep every token row even when a lookup table has no entry for its
# token_id (the corresponding frequency is then null). The two 'frequency'
# columns are disambiguated through the 'memorized' / 'non_memorized' aliases.
token_frequencies_df = (
    flattened_df
    .join(memorized_frequencies_df, on='token_id', how='left')
    .join(non_memorized_frequencies_df, on='token_id', how='left')
    .select(
        'sequence_id',
        'token_index',
        'token_id',
        F.col('memorized.frequency').alias('memorized_frequency'),
        F.col('non_memorized.frequency').alias('non_memorized_frequency'),
    )
)

In [17]:
# Preview the first five token rows with both frequency columns attached.
token_frequencies_df.show(5)

+-----------+-----------+--------+-------------------+-----------------------+
|sequence_id|token_index|token_id|memorized_frequency|non_memorized_frequency|
+-----------+-----------+--------+-------------------+-----------------------+
|   90261337|          0|   14592|              69389|                2714522|
|   90261337|          1|   50254|           14390981|              389420308|
|   90261337|          2|   50275|            7731059|              229607629|
|   90261337|          3|      30|           20611918|              207874540|
|   90261337|          4|     470|           33857187|              349528094|
+-----------+-----------+--------+-------------------+-----------------------+
only showing top 5 rows



In [18]:
# Re-assemble one row per sequence.
# Every value is collected as a (token_index, value) struct and the resulting
# array is sorted; token_index is the first struct field, so the sort is meant
# to restore the original token order, which collect_list alone does not promise.
combined_df = (
    token_frequencies_df
    .groupBy('sequence_id')
    .agg(
        F.sort_array(
            F.collect_list(F.struct('token_index', 'token_id'))
        ).alias('tokens'),
        F.sort_array(
            F.collect_list(F.struct('token_index', 'memorized_frequency'))
        ).alias('memorized_frequencies'),
        F.sort_array(
            F.collect_list(F.struct('token_index', 'non_memorized_frequency'))
        ).alias('non_memorized_frequencies'),
    )
)

In [19]:
# Preview the per-sequence arrays of (token_index, value) structs.
combined_df.show()



+-----------+--------------------+---------------------+-------------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|
+-----------+--------------------+---------------------+-------------------------+
|      22129|[{0, 556}, {1, 70...| [{0, 3813399}, {1...|     [{0, 343335217}, ...|
|      72541|[{0, 4145}, {1, 4...| [{0, 1012762}, {1...|     [{0, 13160840}, {...|
|     156892|[{0, 186}, {1, 94...| [{0, 77685748}, {...|     [{0, 467581316}, ...|
|     158747|[{0, 9312}, {1, 1...| [{0, 142591}, {1,...|     [{0, 5726186}, {1...|
|     170393|[{0, 92}, {1, 249...| [{0, 24571276}, {...|     [{0, 612572165}, ...|
|     175031|[{0, 50276}, {1, ...| [{0, 32791020}, {...|     [{0, 736379794}, ...|
|     204535|[{0, 475}, {1, 40...| [{0, 13754443}, {...|     [{0, 353657828}, ...|
|     271690|[{0, 64}, {1, 478...| [{0, 115532598}, ...|     [{0, 1678250387},...|
|     283969|[{0, 4637}, {1, 1...| [{0, 1118588}, {1...|     [{0, 13208036}, {...|
|   

                                                                                

In [20]:
# Strip the token_index helper field from every sorted struct array, leaving
# plain arrays of token ids / frequencies. The result is aliased 'new_df' so it
# can be referenced by qualifier later.
new_df = combined_df.select(
    'sequence_id',
    F.transform(
        F.col('tokens'),
        lambda entry: entry['token_id'],
    ).alias('tokens'),
    F.transform(
        F.col('memorized_frequencies'),
        lambda entry: entry['memorized_frequency'],
    ).alias('memorized_frequencies'),
    F.transform(
        F.col('non_memorized_frequencies'),
        lambda entry: entry['non_memorized_frequency'],
    ).alias('non_memorized_frequencies'),
).alias('new_df')

In [21]:
# Preview the per-sequence arrays after the token_index field has been dropped.
new_df.show()



+-----------+--------------------+---------------------+-------------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|
+-----------+--------------------+---------------------+-------------------------+
|      22129|[556, 7012, 323, ...| [3813399, 78360, ...|     [343335217, 80103...|
|      72541|[4145, 48128, 320...| [1012762, 61584, ...|     [13160840, 369321...|
|     156892|[186, 94, 187, 18...| [77685748, 364496...|     [467581316, 58929...|
|     158747|[9312, 1157, 5027...| [142591, 4225660,...|     [5726186, 8639393...|
|     170393|[92, 249, 3080, 4...| [24571276, 317644...|     [612572165, 24423...|
|     175031|[50276, 5035, 253...| [32791020, 318418...|     [736379794, 17750...|
|     204535|[475, 40078, 310,...| [13754443, 20082,...|     [353657828, 51062...|
|     271690|[64, 4785, 64, 38...| [115532598, 21223...|     [1678250387, 8998...|
|     283969|[4637, 15, 187, 4...| [1118588, 1771746...|     [13208036, 989579...|
|   

                                                                                

In [22]:
# Attach the sequence-level duplicate count to every sequence.
# The join is an inner join, so only sequence_ids present in both frames are kept.
final_df = (
    new_df
    .join(sequence_duplicates_df, on='sequence_id', how='inner')
    .select(
        'new_df.*',
        F.col('sequence_dups.frequency').alias('sequence_frequency'),
    )
)

In [23]:
# Preview the final rows, including the joined sequence_frequency column.
final_df.show()



+-----------+--------------------+---------------------+-------------------------+------------------+
|sequence_id|              tokens|memorized_frequencies|non_memorized_frequencies|sequence_frequency|
+-----------+--------------------+---------------------+-------------------------+------------------+
|      22129|[556, 7012, 323, ...| [3813399, 78360, ...|     [343335217, 80103...|                 2|
|      72541|[4145, 48128, 320...| [1012762, 61584, ...|     [13160840, 369321...|             18437|
|     156892|[186, 94, 187, 18...| [77685748, 364496...|     [467581316, 58929...|              3102|
|     158747|[9312, 1157, 5027...| [142591, 4225660,...|     [5726186, 8639393...|                 2|
|     170393|[92, 249, 3080, 4...| [24571276, 317644...|     [612572165, 24423...|                 2|
|     175031|[50276, 5035, 253...| [32791020, 318418...|     [736379794, 17750...|              2835|
|     204535|[475, 40078, 310,...| [13754443, 20082,...|     [353657828, 51062...|

                                                                                

In [24]:
# Persist the final dataset.
# mode('overwrite') makes the cell re-runnable: the default save mode raises
# when 'datasets/final_dataset' already exists from an earlier run.
final_df.write.mode('overwrite').parquet('datasets/final_dataset')

                                                                                

In [25]:
# Display the DataFrame repr to confirm the final column names and types.
final_df

DataFrame[sequence_id: bigint, tokens: array<bigint>, memorized_frequencies: array<bigint>, non_memorized_frequencies: array<bigint>, sequence_frequency: bigint]