In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, sum, lit, concat

## Q1 

In [2]:
# Start (or reuse) the Spark session used by every cell below.
s1 = (
    SparkSession.builder
    .appName("Q1")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/03 12:19:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Directory holding the SQuAD v2 parquet shards; the `validation` shard serves as our test set.
SQUAD_DIR = "Task1/squad_v2/squad_v2"

train = s1.read.parquet(f"{SQUAD_DIR}/train-00000-of-00001.parquet")
test = s1.read.parquet(f"{SQUAD_DIR}/validation-00000-of-00001.parquet")


                                                                                

In [4]:
# Inspect the schema: `answers` is a struct of two parallel arrays, `text` and `answer_start`.
train.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- context: string (nullable = true)
 |-- question: string (nullable = true)
 |-- answers: struct (nullable = true)
 |    |-- text: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- answer_start: array (nullable = true)
 |    |    |-- element: integer (containsNull = true)



In [5]:
def count_nulls(df):
    """Return a one-row DataFrame holding the number of null values in each column of ``df``.

    Uses ``F.sum`` explicitly so the result does not depend on the module-level
    ``sum`` name (which shadows the Python builtin).
    """
    return df.select([F.sum(col(c).isNull().cast("int")).alias(c) for c in df.columns])


train_null_counts = count_nulls(train)
train_null_counts.show()

test_null_counts = count_nulls(test)
test_null_counts.show()

                                                                                

+---+-----+-------+--------+-------+
| id|title|context|question|answers|
+---+-----+-------+--------+-------+
|  0|    0|      0|       0|      0|
+---+-----+-------+--------+-------+

+---+-----+-------+--------+-------+
| id|title|context|question|answers|
+---+-----+-------+--------+-------+
|  0|    0|      0|       0|      0|
+---+-----+-------+--------+-------+



`answers` is a struct of two parallel arrays: `text` holds the answer strings, and `answer_start` holds each answer's starting character position in `context`. To extract the target answer, we read the `answers.text` array and take its first element (index 0).

Build the `Input` (question + context) and `Output` (first answer text) columns, then check whether any answer text is missing:

In [6]:
def to_input_output(df):
    """Format a SQuAD-style DataFrame as (Input, Output) text pairs.

    Input  = "question: <question> context: <context>"
    Output = first element of ``answers.text``, or null when that array is empty
             (unanswerable questions).

    The explicit size check keeps the "empty array -> null" behaviour even when
    Spark's ANSI mode is on, where indexing past the end of an array raises
    instead of returning null.
    """
    answer_texts = col("answers.text")
    return (
        df.withColumn("Input", concat(lit("question: "), col("question"), lit(" context: "), col("context")))
          .withColumn("Output", F.when(F.size(answer_texts) > 0, answer_texts[0]))
          .select("Input", "Output")
    )


train_data = to_input_output(train)
test_data = to_input_output(test)

The `Output` column may contain null values: SQuAD v2 includes unanswerable questions, whose `answers.text` array is empty, so element 0 is null:

In [7]:
# Per-column null counts for the formatted training pairs.
na_counts = train_data.select(
    *(F.sum(col(name).isNull().cast("int")).alias(name) for name in train_data.columns)
)
na_counts.show()

+-----+------+
|Input|Output|
+-----+------+
|    0| 43498|
+-----+------+



In [8]:
# Per-column null counts for the formatted test pairs.
na_counts = test_data.select(
    *(F.sum(col(name).isNull().cast("int")).alias(name) for name in test_data.columns)
)
na_counts.show()

+-----+------+
|Input|Output|
+-----+------+
|    0|  5945|
+-----+------+



In [9]:
# Remove rows that contain a null in any column (per the counts above, only Output has nulls).
train_data = train_data.dropna()
print(train_data.count())
train_data.show(n=5, truncate=False)

86821
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------------+
|Input                                                                                                                                                                                           

In [10]:
# Exploratory split on the not-yet-deduplicated data, to check that the split sizes add up.
shuffled_data = train_data.sort(F.rand(seed=42))
val_data = shuffled_data.limit(5000)
train_data_pre = shuffled_data.subtract(val_data)
for split_df in (val_data, train_data_pre):
    print(split_df.count())



5000




81798


                                                                                

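$81798 + 5000 = 86798 \neq 86821$: 23 rows went missing. `DataFrame.subtract` behaves like SQL `EXCEPT DISTINCT`, so duplicate rows are collapsed. Let's check for duplicate rows in the train data: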

In [11]:
# (Input, Output) pairs that occur more than once, with their multiplicity in `count`.
duplicate_rows = (
    train_data.groupBy(*train_data.columns)
    .count()
    .where(col("count") > 1)
)
print(duplicate_rows.count())
duplicate_rows.show(5, truncate=False)

                                                                                

23




+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

23 (Input, Output) pairs occur more than once, which matches the 23-row gap above. We drop the duplicates before splitting:

In [12]:
# Hold-out split configuration.
VAL_SIZE = 5000
SPLIT_SEED = 42

# Drop exact duplicate (Input, Output) pairs first: `subtract` has EXCEPT DISTINCT
# semantics and would otherwise silently collapse them, so split sizes would not add up.
train_data = train_data.dropDuplicates()

# Spark is lazy, so `val_data` is recomputed by every action (count, subtract, toPandas).
# An ordering based on F.rand(seed) is only reproducible for a fixed partition layout,
# which is not guaranteed after the shuffle in dropDuplicates. Successive evaluations
# could therefore pick different validation rows and leak them into train.
# Ordering by a seeded hash of the row contents, with the columns themselves as
# tie-breakers, gives a total order that is identical on every re-evaluation.
shuffled_data = train_data.orderBy(
    F.xxhash64(lit(SPLIT_SEED), col("Input"), col("Output")), "Input", "Output"
)
val_data = shuffled_data.limit(VAL_SIZE)
train_data = shuffled_data.subtract(val_data)
print(val_data.count())
print(train_data.count())

                                                                                

5000


[Stage 46:>                                                         (0 + 2) / 2]

81798


                                                                                

In [13]:
# Collect each Spark split to the driver as a pandas DataFrame.
# NOTE(review): test_data never went through na.drop(), so its rows with a null
# Output (5945 per the count above) are kept here — confirm this is intended.
train_df, valid_df, test_df = (
    spark_df.toPandas() for spark_df in (train_data, val_data, test_data)
)

                                                                                

In [14]:
# Persist the three splits as CSV files, without the pandas index column.
for frame, csv_path in ((train_df, "train.csv"), (valid_df, "valid.csv"), (test_df, "test.csv")):
    frame.to_csv(csv_path, index=False)