formatting
alvin319 committed Sep 15, 2023
1 parent 625c3d4 commit ac732a6
Showing 5 changed files with 50 additions and 55 deletions.
19 changes: 6 additions & 13 deletions calculate_metrics.py
@@ -17,6 +17,7 @@
SPARK: SparkSession = initialize_spark()
PIPELINE.register_spark_session(SPARK)


def parse_cli_args():
"""
Parse the command line arguments for the script.
@@ -141,10 +142,7 @@ def load_dataset(dataset_name: str, scheme: str, model_size: str) -> DataFrame:
# We'll also rename the memorization score column for consistency.
dataset = dataset[required_columns].rename(columns={model_size: "memorization_score"})
elif is_test:
dataset = (
hf_load_dataset(hf_dataset_name, split="train")
.to_pandas()
)
dataset = hf_load_dataset(hf_dataset_name, split="train").to_pandas()
dataset.tokens = dataset.tokens.map(lambda x: x.tolist())
else:
dataset = hf_load_dataset(hf_dataset_name, split=split_name).to_pandas().rename(columns={"index": "sequence_id"})
@@ -160,7 +158,7 @@ def load_dataset(dataset_name: str, scheme: str, model_size: str) -> DataFrame:
return SPARK.read.parquet(cache_path)


def load_precomputed_features(schema: str, is_test = False) -> Dict[PrecomputedFeatureName, DataFrame]:
def load_precomputed_features(schema: str, is_test=False) -> Dict[PrecomputedFeatureName, DataFrame]:
"""
Load the pre-computed features from HuggingFace datasets. If the features are not locally available, then
download them from HuggingFace datasets and cache them as Spark DataFrames in Parquet format.
@@ -174,11 +172,7 @@ def load_precomputed_features(schema: str, is_test = False) -> Dict[PrecomputedF
"""
features = {}
hf_dataset_names = [
(
PrecomputedFeatureName.SEQUENCE_FREQUENCIES,
f"usvsnsp/{schema}-num-duplicates",
"train",
{"Index": "sequence_id", "Counts": "frequency"}),
(PrecomputedFeatureName.SEQUENCE_FREQUENCIES, f"usvsnsp/{schema}-num-duplicates", "train", {"Index": "sequence_id", "Counts": "frequency"}),
(
PrecomputedFeatureName.MEMORIZED_TOKEN_FREQUENCIES,
f"usvsnsp/{schema}-num-frequencies",
@@ -259,14 +253,13 @@ def main():
if args.sample_seed is not None:
LOGGER.info(f"Sample seed: {args.sample_seed}")
LOGGER.info("---------------------------------------------------------------------------")



for model_size in args.models if isinstance(args.models, list) else args.models.split(","):
for dataset_name in args.datasets if isinstance(args.datasets, list) else args.datasets.split(","):
is_test = dataset_name == "test"
for data_scheme in args.schemes if isinstance(args.schemes, list) else args.schemes.split(","):
LOGGER.info("Loading pre-computed features...")
precomputed_features = load_precomputed_features(data_scheme, is_test = is_test)
precomputed_features = load_precomputed_features(data_scheme, is_test=is_test)
PIPELINE.register_features(precomputed_features)

split_name = f"{data_scheme}.{model_size}"
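For context on calculate_metrics.py: load_dataset and load_precomputed_features follow a load-then-cache pattern, pulling a HuggingFace dataset once, converting it to pandas, and persisting it as Parquet so later runs can read it straight into Spark (the SPARK.read.parquet(cache_path) line above). The sketch below is a minimal, hypothetical reconstruction of that flow; the dataset id, cache path, and Spark session setup are illustrative assumptions, not code from this commit.

import os

from datasets import load_dataset as hf_load_dataset
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("cache-sketch").getOrCreate()

hf_dataset_name = "usvsnsp/deduped-num-duplicates"  # illustrative dataset id (assumption)
cache_path = "/tmp/cache/deduped-num-duplicates.parquet"  # illustrative cache location (assumption)

if not os.path.exists(cache_path):
    # First run: download from HuggingFace, convert to pandas, persist via Spark.
    pandas_df = hf_load_dataset(hf_dataset_name, split="train").to_pandas()
    spark.createDataFrame(pandas_df).write.parquet(cache_path)

# Later runs read the cached Parquet directly, as load_dataset does above.
df = spark.read.parquet(cache_path)
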
70 changes: 38 additions & 32 deletions filters/highly_repetitive.py
@@ -4,6 +4,7 @@

from .base import PIPELINE_SINGLETON


def break_and_compare(ls: list, k: int) -> list:
"""
This function takes a list ls and an integer k as input and returns a list which is the first chunk of ls that is repeated k times. If no such chunk exists, it returns an empty list.
@@ -22,7 +23,7 @@ def break_and_compare(ls: list, k: int) -> list:
chunk_size = n // k
while len(residual) < chunk_size:
# split into chunks
chunks = [to_break[i:i + chunk_size] for i in range(0, len(to_break), chunk_size)]
chunks = [to_break[i : i + chunk_size] for i in range(0, len(to_break), chunk_size)]
chunksMatch = True
# compare all chunks to first chunk
for chunk in chunks[1:]:
@@ -31,18 +32,19 @@ def break_and_compare(ls: list, k: int) -> list:
break
if chunksMatch:
# compare residual to first chunk
if residual == chunks[0][:len(residual)]:
if residual == chunks[0][: len(residual)]:
return chunks[0]
chunk_size -= 1
new_residual = to_break[chunk_size * k:]
to_break = to_break[:chunk_size * k]
new_residual = to_break[chunk_size * k :]
to_break = to_break[: chunk_size * k]
residual = new_residual + residual
return []


def break_and_compare_wrapper(ls: list, start_k: int, end_k: int) -> list:
"""
This function serves as a wrapper for the `break_and_compare` function. It takes an additional two integer parameters `start_k` and `end_k` to define a range of values for `k`.
This function serves as a wrapper for the `break_and_compare` function. It takes an additional two integer parameters `start_k` and `end_k` to define a range of values for `k`.
It iterates over this range and calls `break_and_compare` for each value of `k` within the range.
Parameters:
@@ -61,13 +63,13 @@ def break_and_compare_wrapper(ls: list, start_k: int, end_k: int) -> list:
rem = 2
# when rem = 0 -> 0.91 0.73 0.81
# when rem = 1 -> 0.91 0.78 0.84
# when rem = 2 -> 0.90 0.80 0.84
# when rem = 2 -> 0.90 0.80 0.84
# when rem = 3 -> 0.89 0.80 0.84
# when rem = 4 -> 0.89 0.80 0.84
# when rem = 5 -> 0.89 0.80 0.84
# when rem = 6 -> 0.89 0.80 0.84
for j in range(0, rem+1):
result = break_and_compare(ls[i:length - j], k)
for j in range(0, rem + 1):
result = break_and_compare(ls[i : length - j], k)
if result:
return result, i, k
result = break_and_compare(ls[i:], k)
@@ -78,21 +80,22 @@ def break_and_compare_wrapper(ls: list, start_k: int, end_k: int) -> list:
return result, 0, k
return [], 0, -1


def find_smallest_repeating_unit(lst):
if lst is None:
return []
n = len(lst)

# Try all possible lengths of repeating units
for unit_length in range(1, n // 2 + 1):
# Check if the list can be divided into repeating units of the current length
if n % unit_length == 0:
unit = lst[:unit_length] # Extract a potential repeating unit

# Check if the entire list can be formed by repeating the unit
if all(lst[i:i + unit_length] == unit for i in range(0, n, unit_length)):
if all(lst[i : i + unit_length] == unit for i in range(0, n, unit_length)):
return unit

# If no repeating unit is found, the list itself is the smallest repeating unit
return lst

@@ -112,19 +115,21 @@ def highly_repetitive_filter(dataset: DataFrame, _) -> DataFrame:
DataFrame: with additional column of `is_incrementing`
"""
main = dataset.alias("main")
repetitive_schema = T.StructType([
T.StructField("num_repeating", T.IntegerType()),
T.StructField("repeating_offset", T.IntegerType()),
T.StructField("repeating_chunk", T.ArrayType(T.LongType()))
])
repetitive_schema = T.StructType(
[
T.StructField("num_repeating", T.IntegerType()),
T.StructField("repeating_offset", T.IntegerType()),
T.StructField("repeating_chunk", T.ArrayType(T.LongType())),
]
)
repetitiveUDF = F.udf(lambda seq: break_and_compare_wrapper(seq, 2, 5), repetitive_schema)
smallest_repeating_chunkUDF = F.udf(lambda seq: find_smallest_repeating_unit(seq), T.ArrayType(T.LongType()))


repetitive_counts = main.select("sequence_id", "text").withColumn("repetitive", repetitiveUDF("text"))
repetitive_counts = repetitive_counts.withColumn("smallest_repeating_chunk", smallest_repeating_chunkUDF("repetitive.repeating_chunk"))

final = (repetitive_counts.join(main, on="sequence_id", how="left")

final = (
repetitive_counts.join(main, on="sequence_id", how="left")
.drop(repetitive_counts.sequence_id)
.drop(repetitive_counts.text)
.drop(repetitive_counts.repetitive.repeating_chunk)
@@ -137,21 +142,22 @@ def highly_repetitive_filter(dataset: DataFrame, _) -> DataFrame:

return final


if __name__ == "__main__":
# from transformers import AutoTokenizer
# inp = """0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
# 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff"""
# tokenizer = AutoTokenizer.from_pretrained(
# "EleutherAI/pythia-70m-deduped",
# )
# inp = tokenizer(inp)['input_ids']
# print(inp)
# # for token in inp:
# # print(token, tokenizer.decode(token))
# print(break_and_compare_wrapper(inp, 2, 30))
# from transformers import AutoTokenizer
# inp = """0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
# 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff"""
# tokenizer = AutoTokenizer.from_pretrained(
# "EleutherAI/pythia-70m-deduped",
# )
# inp = tokenizer(inp)['input_ids']
# print(inp)
# # for token in inp:
# # print(token, tokenizer.decode(token))
# print(break_and_compare_wrapper(inp, 2, 30))
ls = [1]
start_k = 1
end_k = 3
expected = ([1], 1)
output = break_and_compare_wrapper(ls, start_k, end_k)
print(output)
print(output)
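
A rough usage sketch (not from the repository) of how the three helpers in this file compose on a toy token list; the values in the comments follow from reading the code in this diff rather than from running it.

tokens = [7, 8, 9, 7, 8, 9, 7, 8, 9]

# A chunk repeated exactly k=3 times is recovered directly.
print(break_and_compare(tokens, 3))  # [7, 8, 9]

# The wrapper scans start offsets and k in [start_k, end_k] and also reports
# the matching offset and k, as (chunk, offset, k).
chunk, offset, k = break_and_compare_wrapper(tokens, 2, 5)

# find_smallest_repeating_unit reduces an already-repetitive list to its period.
print(find_smallest_repeating_unit([7, 8, 7, 8, 7, 8]))  # [7, 8]
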
6 changes: 4 additions & 2 deletions filters/pattern_incrementing.py
@@ -13,7 +13,7 @@ def replace_non_numeric_with_whitespace(text: str) -> str:
new_text = ""
for i in range(len(text)):
if text[i].isdigit():
new_text += str(unicodedata.digit(text[i])) # Fix for characters like '²' not being converted as required
new_text += str(unicodedata.digit(text[i])) # Fix for characters like '²' not being converted as required
elif text[i] == "." and i > 0 and i < len(text) - 1 and text[i - 1].isdigit() and text[i + 1].isdigit():
new_text += text[i]
else:
@@ -82,7 +82,7 @@ def incrementing_sequences_filter_wrapper(text: str) -> bool:
# If length of list is 1, the sequence is not an incrementing pattern
if len(ls) <= 1:
return False

ptr = 0
min_max = {}
chunk_num = 0
@@ -314,6 +314,7 @@ def incrementing_sequences_filter_wrapper(text: str) -> bool:

return False


@PIPELINE_SINGLETON.register_filter()
def incrementing_sequences_filter(dataset: DataFrame, _) -> DataFrame:
"""Returns if a sequence is incrementing
@@ -332,6 +333,7 @@ def incrementing_sequences_filter(dataset: DataFrame, _) -> DataFrame:

return final


if __name__ == "__main__":
samp = r"""
"A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"
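The unicodedata fix highlighted in this file exists because str.isdigit() accepts characters such as '²' that int() cannot parse, while unicodedata.digit() maps them to their numeric value first. A small standalone illustration of that standard-library behavior (not repository code):

import unicodedata

for ch in "2²³":  # ASCII two, superscript two, superscript three
    print(ch, "->", unicodedata.digit(ch))
# 2 -> 2
# ² -> 2
# ³ -> 3
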
2 changes: 1 addition & 1 deletion filters/test_pattern_incrementing.py
@@ -38,4 +38,4 @@ def test_incrementing_nonnumerical_pattern():

def test_incrementing_seminnumerical_pattern():
text = "A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"
assert incrementing_sequences_filter_wrapper(text) == True
assert incrementing_sequences_filter_wrapper(text) == True
8 changes: 1 addition & 7 deletions filters/token_frequency_statistics_filter.py
@@ -70,11 +70,5 @@ def token_frequency_statistics_filter(dataset: DataFrame, features: PrecomputedF
).alias("filtered")

# Finally, re-attach the memorization score from the original dataset
final = (filtered_frequencies.join(main, on="sequence_id", how="left")
.drop(filtered_frequencies.sequence_id)
.select(
"main.*",
"filtered.*"
)
)
final = filtered_frequencies.join(main, on="sequence_id", how="left").drop(filtered_frequencies.sequence_id).select("main.*", "filtered.*")
return final
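
The reformatted join in this file relies on DataFrame aliases registered earlier (main, filtered) so that main.* and filtered.* can be selected after the join. A minimal standalone sketch of the same alias/join/select pattern, with toy columns that are not from the repository:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("alias-join-sketch").getOrCreate()

main = spark.createDataFrame(
    [(0, 0.9), (1, 0.1)], ["sequence_id", "memorization_score"]
).alias("main")
filtered = spark.createDataFrame(
    [(0, 12), (1, 3)], ["sequence_id", "max_frequency"]
).alias("filtered")

# Left-join on the key, drop the right-hand copy of the key,
# then keep every remaining column from both aliased sides.
final = (
    filtered.join(main, on="sequence_id", how="left")
    .drop(filtered.sequence_id)
    .select("main.*", "filtered.*")
)
final.show()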
