In [None]:
import pandas as pd
import re
import spacy

## 1. Data Acquisition

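**Data citation** — the Self-RAG training set (`selfrag/selfrag_train_data` on the Hugging Face Hub) was introduced in: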
```
@article{asai2023selfrag,
  author    = {Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
  title     = {{Self-RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
  year      = {2023},
  journal   = {arXiv preprint arXiv:2310.11511},
  URL       = {https://arxiv.org/abs/2310.11511}
}
```


In [None]:
# Read the Self-RAG training split (JSON Lines: one record per line) via the `hf://` filesystem
raw_df = pd.read_json(
    "hf://datasets/selfrag/selfrag_train_data/train.jsonl",
    lines=True,
)
raw_df.head()

Unnamed: 0,instruction,output,input,id,dataset_name
0,"In this task, you are given a context paragrap...",[Retrieval]<paragraph>2017 Portland train atta...,,flan_v2_88425,flan_v2
1,"Question: Write a text based on ""rangers show ...",[No Retrieval]six opposition candidates in the...,,flan_v2_18667,flan_v2
2,You will be given a sentence. Check whether th...,[No Retrieval]1\n****\n[Utility:5],,flan_v2_87754,flan_v2
3,Q:Is there a negative or positive tone to this...,[No Retrieval]Negative[Utility:5],,flan_v2_47789,flan_v2
4,Question: Fertility Clinics Vary on Embryo Dis...,[No Retrieval]Answer: World[Utility:5],,flan_v2_10803,flan_v2


In [None]:
def check_retrieval(output) -> str:
  """Label whether an output contains a retrieval step.

  Args:
    output: One value of the ``output`` column. Expected to be a string; any
      other value (e.g. a NaN produced for a missing field) is labelled "null"
      instead of raising.

  Returns:
    "true"  if the literal tag ``[Retrieval]`` appears anywhere in the text,
    "false" if ``[No Retrieval]`` appears but ``[Retrieval]`` does not,
    "null"  otherwise.
  """
  if not isinstance(output, str):
    # A substring/regex search on a non-string (NaN, None) would raise TypeError
    # and abort the whole column apply; treat it as "undetermined" instead.
    return "null"
  # The tags are fixed literals, so a plain substring test replaces the regex.
  # "[No Retrieval]" does not contain "[Retrieval]", so the check order only
  # matters when both tags are present — then "true" wins.
  if "[Retrieval]" in output:
    return "true"
  if "[No Retrieval]" in output:
    return "false"
  # Some outputs are tagged with `[Continue to Use Evidence]` and `Dialog`.
  # We could flag these as false but we will ignore them instead.
  return "null"

In [None]:
# Label each row "true" / "false" / "null" from the tags in its `output` text
raw_df["retrieval"] = raw_df["output"].map(check_retrieval)

In [None]:
# Class balance: number of rows per retrieval label
retrieval_counts = raw_df["retrieval"].value_counts()
retrieval_counts

Unnamed: 0_level_0,count
retrieval,Unnamed: 1_level_1
True,74219
False,70362
,1038


In [None]:
# Remove rows whose retrieval label could not be determined ("null").
# A boolean mask is used rather than `drop(index=...)`: dropping by index label
# would also delete unrelated rows if the index ever contained duplicate labels.
# Rebinding (instead of `inplace=True`) keeps the cell idempotent, and `.copy()`
# yields an independent frame so later column assignments on `raw_df`
# do not trigger SettingWithCopyWarning. Original index labels are preserved.
raw_df = raw_df.loc[raw_df["retrieval"] != "null"].copy()
raw_df["retrieval"].value_counts()

Unnamed: 0_level_0,count
retrieval,Unnamed: 1_level_1
True,74219
False,70362


## 2. Text Cleaning and Preprocessing

In [None]:
# spaCy pipeline used for tokenisation in the next section.
# NOTE(review): clean_preprocess disables every component of this pipeline, so
# only the tokenizer and lexical attributes are exercised; a lighter model (or
# spacy.blank("en")) may suffice — confirm tokenisation is unchanged first.
SPACY_MODEL = "en_core_web_lg"
nlp = spacy.load(SPACY_MODEL)

In [None]:
def clean_preprocess(texts, *, nlp_model=None, batch_size: int = 5000, n_process: int = 8) -> list:
  """Tokenise texts and return a cleaned, lower-cased string for each one.

  All pipeline components are disabled, so only the tokenizer runs and the
  filters rely solely on lexical token attributes. Punctuation, number-like
  and whitespace-only tokens (e.g. "\\n") are dropped; the remaining tokens are
  lower-cased and joined with single spaces.

  Args:
    texts: Iterable of strings to process.
    nlp_model: spaCy ``Language`` object to use. When omitted, the
      notebook-level ``nlp`` is used (the original behaviour).
    batch_size: Number of texts buffered per batch by ``pipe``.
    n_process: Number of worker processes for ``pipe``. Multiprocessing has
      start-up/serialisation overhead; for tokenizer-only work fewer processes
      may be just as fast — benchmark before tuning.

  Returns:
    A list containing one cleaned string per input text.
  """
  # Resolve the global lazily so callers can inject a different pipeline.
  model = nlp if nlp_model is None else nlp_model
  cleaned_texts = []
  for doc in model.pipe(texts, disable=model.pipe_names,
                        batch_size=batch_size, n_process=n_process):
    kept = [
      token.lower_ for token in doc
      # is_space: without it, newline/extra-space tokens leak into the output.
      if not (token.is_punct or token.like_num or token.is_space)
    ]
    cleaned_texts.append(" ".join(kept))
  return cleaned_texts

In [None]:
# Clean every instruction and store the result next to the raw text
instruction_texts = raw_df["instruction"].astype(str).tolist()
preprocessed_list = clean_preprocess(instruction_texts)
raw_df["preprocessed_instruction"] = preprocessed_list