# 2. Typo correction


## 2.1 Import the required libraries


In [None]:
import pandas as pd
import torch
import os

from langdetect import detect_langs, DetectorFactory
from transformers import pipeline
from textblob import TextBlob
from utils import *
from tqdm import tqdm

# Registers `progress_apply` on pandas objects (used by the typo correction module)
tqdm.pandas()

# Detect a GPU and actually hand it to the Hugging Face pipelines:
# `device=0` selects the first CUDA device, `device=-1` keeps the model on the CPU.
device = torch.cuda.is_available()
print("GPU availability:{}".format(device))
pipeline_device = 0 if device else -1

language_detector = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection", device=pipeline_device) # this model is 1.1 gigabyte so it will take around 5 mins to download it
typo_corrector = pipeline("text2text-generation", model="oliverguhr/spelling-correction-english-base", max_length=1000, device=pipeline_device)

# Seed langdetect so its detections are reproducible across runs
DetectorFactory.seed = 0

## 2.2 Load in the raw data

In [None]:
if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
    def file_exists(path):
        """Check that `path` can be listed with `dbutils.fs.ls`.

        Returns True on success. If the underlying error mentions
        `java.io.FileNotFoundException`, a FileNotFoundError naming `path` is
        raised (chained to the original error); any other error is re-raised.
        """
        try:
            dbutils.fs.ls(path)
        except Exception as e:
            if 'java.io.FileNotFoundException' in str(e):
                raise FileNotFoundError(
                    "File {!r} could not be found. Are you sure the file exists in the provided directory?".format(path)
                ) from e
            raise
        return True

    # Output location published as a task value by the "Flagging" task of the job
    input_path = dbutils.jobs.taskValues.get(taskKey = "Flagging", key = "OutPath", default = "None", debugValue = 0)
    file_exists(input_path)
    flagged_df = spark.read.option("header", "true").option("inferSchema","true").csv(input_path).toPandas()
else:
    # Locations of the flagged-product CSVs for the example datasets
    base_tilt_data_location = "../../data/example_data/output/base_data/base_flagged_products.csv"
    italy_tilt_data_location = "../../data/example_data/output/italy_data/italy_flagged_products.csv"

    base_tilt_data = pd.read_csv(base_tilt_data_location)
    italy_tilt_data = pd.read_csv(italy_tilt_data_location)

## 2.3 Apply typo correction module

### Helper functions

In [None]:
def conf_ld_detect_language(text, model="def"):
    """Language detection wrapper.

    Returns the detected language code of `text`. If detection fails (empty or
    non-string input such as a NaN cell, or an error raised by the model) the
    string 'ident_fail' is returned, so the result is always a str and the
    resulting dataframe column has a single type.

    Args:
        text (str): The string for which language shall be detected.
        model (str): "def" for the langdetect model (default) or "huggingface"
            for the module-level `language_detector` pipeline.
    Returns:
        str: The detected language (ISO-code), or 'ident_fail'.
    Raises:
        ValueError: If `model` is not one of the supported values.
    """
    # Validate outside the try block so a typo in `model` is not reported as
    # an unidentified language (previously it silently returned None).
    if model not in ("def", "huggingface"):
        raise ValueError("Unknown language detection model: {!r}".format(model))
    # Blank strings and NaN cells carry no language signal; skip the model call.
    if not isinstance(text, str) or not text.strip():
        return "ident_fail"
    try:
        if model == "def":
            # Keep the first (highest-confidence) candidate.
            return detect_langs(text)[0].lang
        return str(language_detector(text)[0]["label"])
    except Exception:
        # Best-effort: any model failure is reported as an unidentified language.
        return "ident_fail"

In [None]:
def typo_correction(text="", model="def"):
    """Typo correction wrapper.

    Returns the corrected text. Empty or non-string input (e.g. NaN cells) is
    returned unchanged without calling a model, and so is the original text
    if the correction model raises an error.

    Args:
        text (str): The string to be corrected.
        model (str): "def" for the TextBlob corrector (default) or
            "huggingface" for the module-level `typo_corrector` pipeline.
    Returns:
        str: The corrected string (or `text` unchanged, see above).
    Raises:
        ValueError: If `model` is not one of the supported values.
    """
    # Validate outside the try block; previously an unknown model silently
    # returned None.
    if model not in ("def", "huggingface"):
        raise ValueError("Unknown typo correction model: {!r}".format(model))
    if not isinstance(text, str) or not text.strip():
        return text
    try:
        if model == "def":
            return TextBlob(text).correct().string
        return typo_corrector(text)[0]["generated_text"]
    except Exception:
        # Best-effort: keep the original text when correction fails.
        return text

### Typo correction module

In [None]:
def typo_correct_df(df):
    """Typo correction wrapper for dataframes.

    Detects the language of every row whose ``to_process`` value is True and
    applies typo correction to the rows detected as English ("en"). The
    result gets a ``typo_corrected`` column that falls back to
    ``products_and_services`` where no correction was made. That column is
    lower-cased, and a single trailing dot is removed.

    Rows are returned in the order: processed non-English rows, English rows,
    then all remaining rows, with a fresh RangeIndex.

    Args:
        df (pd.DataFrame): Must contain the columns ``to_process`` and
            ``products_and_services``. It is not modified.
    Returns:
        pd.DataFrame: The rows of `df` with the added ``language (ISO-code)``
        and ``typo_corrected`` columns.
    """
    # Split on a single mask and use its complement for the remainder, so rows
    # whose `to_process` value is missing are kept instead of silently dropped.
    process_mask = df["to_process"].eq(True)
    to_process_df = df[process_mask].copy()
    rest_df = df[~process_mask].copy()

    # detect the language of the text, only for the rows that have to be processed
    print("Detecting the language of the text...")
    to_process_df["language (ISO-code)"] = to_process_df["products_and_services"].progress_apply(
        lambda x: conf_ld_detect_language(x, model="huggingface")
    )

    # then take subset of english texts
    print("Taking subset of English texts...")
    is_english = to_process_df["language (ISO-code)"] == "en"
    english_df = to_process_df[is_english].copy()
    to_process_df = to_process_df[~is_english]

    # apply typo correction to english texts
    print("Applying typo correction...")
    english_df["typo_corrected"] = english_df["products_and_services"].progress_apply(
        lambda x: typo_correction(x, model="huggingface")
    )

    # merge the corrected english texts with the remaining rows
    print("Merging the corrected english texts with the original df...")
    df = pd.concat([to_process_df, english_df, rest_df], ignore_index=True)
    # Guard against the column being absent (e.g. nothing was corrected).
    if "typo_corrected" not in df.columns:
        df["typo_corrected"] = pd.NA
    # Uncorrected rows fall back to their original text. Assign the result
    # rather than `fillna(inplace=True)` on a column, which is chained
    # assignment and does not update `df` under pandas copy-on-write.
    df["typo_corrected"] = df["typo_corrected"].fillna(df["products_and_services"])
    # Lower-case and strip a trailing dot. `regex=True` is required: since
    # pandas 2.0 `str.replace` treats the pattern literally by default.
    df["typo_corrected"] = df["typo_corrected"].str.lower().str.replace(r"\.$", "", regex=True)
    return df

### Provided input Data

In [None]:
# `flagged_df` is only loaded on Databricks (see the data loading cell); when
# running locally use the "Base Data" / "tilt Italy Data" cells below instead.
if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
    corrected_df = typo_correct_df(flagged_df)

### Base Data

In [None]:
# base_typo_corrected_df = typo_correct_df(base_tilt_data)

### tilt Italy Data

In [None]:
# italy_typo_corrected_df = typo_correct_df(italy_tilt_data)

## 2.4 Export the dataframe with the corrected text 

In [None]:
if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
    # Output goes to <parent of the input directory>/output/corrected_<input file name>
    output_path = os.path.join(os.path.dirname(os.path.dirname(input_path)), "output/corrected_" + os.path.basename(input_path)).replace("\\", "/")
    # Convert the pandas dataframe to a spark sql dataframe
    corrected_df_spark = spark.createDataFrame(corrected_df)
    # Write the new dataframe to the path
    corrected_df_spark.write.csv(output_path, mode="overwrite", header=True)
    # Publish the output location as a task value so later job tasks can read it
    dbutils.jobs.taskValues.set(key = 'OutPath', value = output_path)

else:
    # The local frames are produced by the "Base Data" / "tilt Italy Data"
    # cells, which may be commented out; export only the frames that exist
    # instead of failing with a NameError.
    local_outputs = {
        "base_typo_corrected_df": "../../data/example_data/output/base_data/base_typo_corrected_products.csv",
        "italy_typo_corrected_df": "../../data/example_data/output/italy_data/italy_typo_corrected_products.csv",
    }
    for frame_name, local_path in local_outputs.items():
        frame = globals().get(frame_name)
        if frame is None:
            print("Skipping {}: it has not been computed (run its typo correction cell first).".format(frame_name))
            continue
        # Write the new dataframe to the path
        frame.to_csv(local_path, index=False)