# Dataset extractor
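Disclaimer: To run this notebook, launch PySpark with `pyspark --master local[N]` (where N is the number of CPU cores to use) from the folder containing the notebook, since the data paths used below (e.g. `data/All_Amazon_Review.json`) are relative to that folder.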

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import length, regexp_replace
import os

In [None]:
# Build (or attach to) a local Spark session with generous memory and
# timeout settings for processing the full review dump on one machine.
# NOTE(review): if a session already exists (e.g. when this runs inside the
# `pyspark` shell), getOrCreate() returns it, and JVM-level settings such as
# spark.driver.memory may not take effect -- confirm against your launch setup.
spark = (
    SparkSession.builder
    .master("local[*]")
    .config("spark.driver.memory", "32g")
    .config("spark.executor.memory", "32g")
    .config("spark.executor.memoryOverhead", "12g")
    .config("spark.default.parallelism", "25")
    .config("spark.network.timeout", "1200s")
    .config("spark.executor.heartbeatInterval", "120s")
    .getOrCreate()
)

# Load only the two fields used downstream. Supplying the schema up front
# avoids a schema-inference pass over the whole JSON file.
review_schema = "overall float, reviewText string"
dataset = spark.read.format("json").load(
    "data/All_Amazon_Review.json", schema=review_schema
)

In [None]:
# Sampling knobs shared by the per-rating cells:
#   limit            -- target number of reviews to keep per star rating
#                       (used to compute the sample() fraction)
#   limit_characters -- maximum review length, counted after all whitespace
#                       is stripped
limit = 1_000_000
limit_characters = 2_000

In [None]:
# Keep only reviews with a 1-star overall rating.
dataset1 = dataset.filter(dataset["overall"] == 1.0)

# Keep reviews whose text, with all whitespace removed, has at most
# `limit_characters` characters. A null reviewText makes the comparison null,
# so such rows are dropped by the filter. Raw string: "\s" is an invalid
# Python escape sequence and triggers a warning on recent interpreters.
dataset1 = dataset1.filter(
    length(regexp_replace(dataset1["reviewText"], r"\s+", "")) <= limit_characters
)

# Down-sample to roughly `limit` rows without replacement.
# count() launches a full Spark job, so evaluate it only once. An empty subset
# keeps fraction 1.0 instead of dividing by zero.
# sample() keeps each row independently, so the resulting size is approximate.
n_rows1 = dataset1.count()
fraction1 = min(1.0, limit / n_rows1) if n_rows1 else 1.0
dataset1 = dataset1.sample(False, fraction1)

In [None]:
# Keep only reviews with a 2-star overall rating.
dataset2 = dataset.filter(dataset["overall"] == 2.0)

# Keep reviews whose text, with all whitespace removed, has at most
# `limit_characters` characters. A null reviewText makes the comparison null,
# so such rows are dropped by the filter. Raw string: "\s" is an invalid
# Python escape sequence and triggers a warning on recent interpreters.
dataset2 = dataset2.filter(
    length(regexp_replace(dataset2["reviewText"], r"\s+", "")) <= limit_characters
)

# Down-sample to roughly `limit` rows without replacement.
# count() launches a full Spark job, so evaluate it only once. An empty subset
# keeps fraction 1.0 instead of dividing by zero.
# sample() keeps each row independently, so the resulting size is approximate.
n_rows2 = dataset2.count()
fraction2 = min(1.0, limit / n_rows2) if n_rows2 else 1.0
dataset2 = dataset2.sample(False, fraction2)

In [None]:
# Keep only reviews with a 3-star overall rating.
dataset3 = dataset.filter(dataset["overall"] == 3.0)

# Keep reviews whose text, with all whitespace removed, has at most
# `limit_characters` characters. A null reviewText makes the comparison null,
# so such rows are dropped by the filter. Raw string: "\s" is an invalid
# Python escape sequence and triggers a warning on recent interpreters.
dataset3 = dataset3.filter(
    length(regexp_replace(dataset3["reviewText"], r"\s+", "")) <= limit_characters
)

# Down-sample to roughly `limit` rows without replacement.
# count() launches a full Spark job, so evaluate it only once. An empty subset
# keeps fraction 1.0 instead of dividing by zero.
# sample() keeps each row independently, so the resulting size is approximate.
n_rows3 = dataset3.count()
fraction3 = min(1.0, limit / n_rows3) if n_rows3 else 1.0
dataset3 = dataset3.sample(False, fraction3)

In [None]:
# Keep only reviews with a 4-star overall rating.
dataset4 = dataset.filter(dataset["overall"] == 4.0)

# Keep reviews whose text, with all whitespace removed, has at most
# `limit_characters` characters. A null reviewText makes the comparison null,
# so such rows are dropped by the filter. Raw string: "\s" is an invalid
# Python escape sequence and triggers a warning on recent interpreters.
dataset4 = dataset4.filter(
    length(regexp_replace(dataset4["reviewText"], r"\s+", "")) <= limit_characters
)

# Down-sample to roughly `limit` rows without replacement.
# count() launches a full Spark job, so evaluate it only once. An empty subset
# keeps fraction 1.0 instead of dividing by zero.
# sample() keeps each row independently, so the resulting size is approximate.
n_rows4 = dataset4.count()
fraction4 = min(1.0, limit / n_rows4) if n_rows4 else 1.0
dataset4 = dataset4.sample(False, fraction4)

In [None]:
# Keep only reviews with a 5-star overall rating.
dataset5 = dataset.filter(dataset["overall"] == 5.0)

# Keep reviews whose text, with all whitespace removed, has at most
# `limit_characters` characters. A null reviewText makes the comparison null,
# so such rows are dropped by the filter. Raw string: "\s" is an invalid
# Python escape sequence and triggers a warning on recent interpreters.
dataset5 = dataset5.filter(
    length(regexp_replace(dataset5["reviewText"], r"\s+", "")) <= limit_characters
)

# Down-sample to roughly `limit` rows without replacement.
# count() launches a full Spark job, so evaluate it only once. An empty subset
# keeps fraction 1.0 instead of dividing by zero.
# sample() keeps each row independently, so the resulting size is approximate.
n_rows5 = dataset5.count()
fraction5 = min(1.0, limit / n_rows5) if n_rows5 else 1.0
dataset5 = dataset5.sample(False, fraction5)

In [None]:
# Stack the five per-rating samples into a single DataFrame.
# union() matches columns by position; every sample derives from `dataset`
# through filter/sample only, so all five share the same column order.
merged_dataset = dataset1
for rating_sample in (dataset2, dataset3, dataset4, dataset5):
    merged_dataset = merged_dataset.union(rating_sample)

In [None]:
# Write the merged sample as newline-delimited JSON under `output_path`,
# replacing any previous output.
# DataFrames are immutable: repartition() returns a NEW DataFrame, so it must
# be chained into the write. A bare `merged_dataset.repartition(1)` statement
# is a no-op, and the write would then emit one part file per partition.
output_path = "data/output"
merged_dataset.repartition(1).write.json(output_path, mode="overwrite", lineSep="\n")

In [None]:
# Concatenate the JSON part files Spark wrote into `directory` into a single
# JSON Lines file, skipping blank lines.
directory = "data/output"
output_file = 'data/dataset.jsonl'

# Spark's JSON writer emits UTF-8 by default. Be explicit so the platform
# default encoding (e.g. cp1252 on Windows) cannot fail on or mangle
# non-ASCII review text.
with open(output_file, 'w', encoding='utf-8') as outfile:
    # Sorted for a deterministic output order, since os.listdir order is
    # arbitrary. Hidden ('.') and underscore-prefixed files are Spark/Hadoop
    # metadata, not data.
    for filename in sorted(os.listdir(directory)):
        if filename.startswith(('.', '_')) or not filename.endswith('.json'):
            continue
        file_path = os.path.join(directory, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            # Stream line by line rather than readlines(), so a large part
            # file is never held in memory all at once.
            for line in file:
                if not line.strip():  # Skip empty lines
                    continue
                # Keep one record per line even if a file's final line has
                # no trailing newline; otherwise it would fuse with the next
                # file's first record.
                outfile.write(line if line.endswith('\n') else line + '\n')