# Preparation

In [18]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pyspark.sql import SparkSession
from pyspark.sql import Row
# from pyspark.sql.types import StringType, StructType, StructField
from pyspark.sql.types import *
# import pyspark.sql.functions as F
from pyspark.sql.functions import *

import warnings
warnings.filterwarnings('ignore')

In [19]:
# Build (or reuse) the Spark session for this notebook, giving both the
# executors and the driver 8 GB of memory.
spark = (
    SparkSession.builder
    .appName("proj2_q1")
    .config("spark.executor.memory", "8g")
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# Q1

## read and filter

In [20]:
# Load the SQuAD v2 validation parquet file into `test` and peek at its first row.
test = spark.read.parquet("data/squad_v2/validation-00000-of-00001.parquet")
test.take(1)[0]

Row(id='56ddde6b9a695914005b9628', title='Normans', context='The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries.', question='In what country is Normandy located?', answers=Row(text=['France', 'France', 'France', 'France'], answer_start=[159, 159, 159, 159]))

In [21]:
# Load the SQuAD v2 training parquet file and display its first three rows in full.
original_training = spark.read.parquet("data/squad_v2/train-00000-of-00001.parquet")
original_training.show(n=3, truncate=False)

+------------------------+-------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------+------------------------------+
|id                      |title  |context                                                                                                                                     

In [22]:
# Print the column names, types and nullability of the training DataFrame.
original_training.printSchema()

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- context: string (nullable = true)
 |-- question: string (nullable = true)
 |-- answers: struct (nullable = true)
 |    |-- text: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |    |-- answer_start: array (nullable = true)
 |    |    |-- element: integer (containsNull = true)



In [23]:
# Number of rows in the training DataFrame before any filtering.
original_training.count()

130319

In [24]:
# Preview the answer texts of the first ten training rows without truncation.
(
    original_training
    .select(col("answers.text"))
    .show(n=10, truncate=False)
)

+---------------------+
|text                 |
+---------------------+
|[in the late 1990s]  |
|[singing and dancing]|
|[2003]               |
|[Houston, Texas]     |
|[late 1990s]         |
|[Destiny's Child]    |
|[Dangerously in Love]|
|[Mathew Knowles]     |
|[late 1990s]         |
|[lead singer]        |
+---------------------+
only showing top 10 rows



The length of each question should be checked, because some questions are empty or malformed. In addition, some answers are empty as well; these examples help the LLM learn to recognise when the given context does not actually contain an answer. In other words, training on empty answers can reduce hallucinations in the LLM.

`answers.answer_start` is dropped as it is not useful for the task.

In [25]:
# Attach the number of answer texts per row (0 means the answer list is empty).
original_training = original_training.withColumn(
    "num_answers",
    size(col("answers.text")),
)

# Distribution of answer counts across the training set.
original_training.groupBy("num_answers").count().show()

# Show a few questions whose answer list is empty.
(
    original_training
    .where(col("num_answers") == 0)
    .select("question", "answers.text")
    .show(5)
)

+-----------+-----+
|num_answers|count|
+-----------+-----+
|          1|86821|
|          0|43498|
+-----------+-----+

+--------------------+----+
|            question|text|
+--------------------+----+
|What category of ...|  []|
|What consoles can...|  []|
|When was Australi...|  []|
|When could GameCu...|  []|
|What year was the...|  []|
+--------------------+----+
only showing top 5 rows



In [26]:
# Rank rows from the shortest question upwards so suspiciously short
# questions appear first.
(
    original_training
    .withColumn("question_length", length(col("question")))
    .orderBy("question_length")
    .select("question", "question_length", "answers.text", "context")
    .show(10)
)

+------------+---------------+--------------------+--------------------+
|    question|question_length|                text|             context|
+------------+---------------+--------------------+--------------------+
|           d|              1|           [the Gre]|The Hellenistic p...|
|          dd|              2|          [Buddhism]|The Hellenistic p...|
| What means |             11|           [Advaita]|Advaita literally...|
|What is OAS?|             12|[Organization of ...|Further conventio...|
|What is kef?|             12|[Armenian dance m...|The Armenian Geno...|
|What is RNA?|             12|[a second type of...|The expression of...|
|What is USB?|             12|      [a serial bus]|USB is a serial b...|
|Himachal is?|             12|[multireligional,...|Himachal was one ...|
|What is DIS?|             12|[a fusion of the ...|At the decision-m...|
|What is IBS?|             12|[irritable bowel ...|Another possible ...|
+------------+---------------+--------------------+

In [27]:
# Remove the rows whose question is exactly "d" or "dd", then count what remains.
training_filter = original_training.where(~col("question").isin("d", "dd"))
training_filter.count()

130317

## concat two fields and split

Concatenate the question and the context into a single `input` field.

In [28]:
# Build the `input` column for the test set in the form
# "question: <question> context: <context>".
test = test.withColumn(
    "input",
    concat(
        lit("question: "), col("question"),
        lit(" context: "), col("context"),
    ),
)

In [29]:
# Build the same "question: <question> context: <context>" `input` column
# for the filtered training data, then show two examples in full.
training_filter = training_filter.withColumn(
    "input",
    concat(
        lit("question: "), col("question"),
        lit(" context: "), col("context"),
    ),
)
(
    training_filter
    .select("question", "input")
    .show(n=2, truncate=False)
)

+----------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|question                                                  |input                                                                                                     

Split the original training set into train and validation sets. `randomSplit()` cannot be told to select a specific number of samples, and `sample()` is not used either because it does not guarantee returning exactly the number of rows requested. Hence, to select exactly 5000 samples, the data is shuffled by ordering on `rand()` with a fixed seed, the first 5000 rows are taken as the validation set, and the remaining rows form the training set.

In [30]:
# Shuffle the rows with a seeded random ordering and mark the result for
# caching, so both splits below are derived from the same shuffled DataFrame.
training_filter = training_filter.orderBy(rand(seed=42)).cache()

# The first 5000 shuffled rows form the validation set ...
validation = training_filter.limit(5000)
# ... and every remaining row (multiset difference) forms the training set.
train = training_filter.exceptAll(validation)

In [31]:
# Report the sizes of the two splits: training first, then validation.
print(train.count(), validation.count(), sep="\n")

                                                                                

125317
5000


In [39]:
# test.toPandas().to_csv("data/mydata/mytest.csv")
# train.toPandas().to_csv("data/mydata/mytrain.csv")
# validation.toPandas().to_csv("data/mydata/myvalidation.csv")

                                                                                

In [40]:
# test.toPandas().head(20).to_csv("data/mydata/mytest_20.csv")
# train.toPandas().head(20).to_csv("data/mydata/mytrain_20.csv")
# validation.toPandas().head(20).to_csv("data/mydata/myvalidation_20.csv")

                                                                                

# Q2

In [32]:
# import os

# import numpy as np
# import evaluate
# from datasets import load_dataset
# from transformers import (
#     Trainer,
#     TrainingArguments,
#     AutoTokenizer,
#     T5Tokenizer, 
#     T5ForConditionalGeneration
# )

# import ray.train.huggingface.transformers
# from ray.train import ScalingConfig
# from ray.train.torch import TorchTrainer


In [33]:

# # Load the dataset
# # Load the model and tokenizer
# model_name = "./flan-t5-small"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = T5ForConditionalGeneration.from_pretrained(model_name)

In [34]:
# input_text = "Where is the capital of China?"
# input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# outputs = model.generate(input_ids)
# print(tokenizer.decode(outputs[0]))

In [35]:
# # Data preprocessing
# def preprocess_function(examples):
#     inputs = [q.strip() for q in examples['question']]
#     targets = [a['text'][0].strip() for a in examples['answers']]
#     model_inputs = tokenizer(inputs, max_length=512, truncation=True)
#     # Set up the model's targets (labels)
#     with tokenizer.as_target_tokenizer():
#         labels = tokenizer(targets, max_length=512, truncation=True)
#     model_inputs['labels'] = labels['input_ids']
#     return model_inputs
# tokenized_datasets = dataset.map(preprocess_function, batched=True)


In [36]:
# # Set the training arguments
# training_args = TrainingArguments(
#     output_dir='./results',
#     evaluation_strategy='epoch',
#     learning_rate=2e-5,
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     num_train_epochs=3,
#     weight_decay=0.01,
#     save_total_limit=1,
# )
# # Define the Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_datasets['train'],
#     eval_dataset=tokenized_datasets['validation'],
# )
# # Train the model
# trainer.train()
# # Save the model
# trainer.save_model('./checkpoint')
# # Evaluate the model
# eval_results = trainer.evaluate()
# print(f"Evaluation results: {eval_results}")
