In [None]:
import pandas as pd
from Sastrawi.Stemmer.StemmerFactory import StemmerFactory
import preprocessor as p
from slang_word import SLANG_WORDS
from Sastrawi.StopWordRemover.StopWordRemoverFactory import StopWordRemoverFactory
from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    AutoTokenizer
)
import torch

# Preprocessing Utilities

In [None]:
def clean_repetitive(word):
    """Collapse every run of identical consecutive characters to one character.

    Args:
        word: A single token (string).

    Returns:
        The collapsed token, or an empty string when at most one character
        remains after collapsing.
    """
    kept = []
    for ch in word:
        # Keep a character only when it differs from the one kept just before it.
        if not kept or kept[-1] != ch:
            kept.append(ch)
    collapsed = ''.join(kept)
    # Tokens reduced to a single character (or nothing) are dropped entirely.
    if len(collapsed) > 1:
        return collapsed
    return ''

In [None]:
def clean_text(text):
    """Normalize one raw text string before tokenization.

    Steps, in order:
      1. Lower-case the text.
      2. Run tweet-preprocessor's ``p.clean`` on it.
      3. Collapse repeated characters in every whitespace-separated token
         with ``clean_repetitive``.
      4. Replace every token found in ``SLANG_WORDS`` with its mapped value.

    Args:
        text: Raw input string.

    Returns:
        The cleaned text, tokens joined by single spaces.
    """
    cleaned = p.clean(text.lower())
    deduped = " ".join(map(clean_repetitive, cleaned.split()))
    # Re-splitting also discards tokens that clean_repetitive turned into ''.
    normalized_tokens = []
    for token in deduped.split():
        if token in SLANG_WORDS:
            normalized_tokens.append(SLANG_WORDS[token])
        else:
            normalized_tokens.append(token)
    return " ".join(normalized_tokens)

In [None]:
# Build one shared Sastrawi stemmer at module level so it is reused across calls.
factory = StemmerFactory()
stemmer = factory.create_stemmer()


def stem(text: str) -> str:
    """Return ``text`` after passing it through the shared Sastrawi stemmer."""
    stemmed = stemmer.stem(text)
    return stemmed

In [None]:
# Build one shared Sastrawi stop-word remover at module level so it is reused across calls.
stop_factory = StopWordRemoverFactory()
stopword_remover = stop_factory.create_stop_word_remover()


def stopword_removal(text: str) -> str:
    """Return ``text`` after passing it through the shared Sastrawi stop-word remover."""
    filtered = stopword_remover.remove(text)
    return filtered

In [None]:
def preprocess_dataset(examples):
    """Clean and tokenize one batch of examples for the seq2seq model.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` is indexed by
    column name and each column is iterated as a list of values.

    Reads these module-level names: ``TEXT_COL`` (input text column),
    ``PREPROCESS`` (variant switch), ``tokenizer`` and ``max_length``.
    NOTE(review): ``TEXT_COL`` and ``PREPROCESS`` are not defined in any cell
    visible here — confirm a configuration cell sets them before this runs.

    Preprocessing variants selected by ``PREPROCESS``:
      * ``'p02'`` – stop-word removal
      * ``'p03'`` – stemming
      * ``'p04'`` – stop-word removal followed by stemming
      * anything else – ``clean_text`` only

    Args:
        examples: Batch with a ``TEXT_COL`` column and a ``"quadruplet"`` column.

    Returns:
        The tokenizer output for the cleaned inputs, with the ``"quadruplet"``
        values passed as ``text_target``.
    """
    texts = [clean_text(raw) for raw in examples[TEXT_COL]]

    if PREPROCESS in ('p02', 'p04'):
        texts = [stopword_removal(sentence) for sentence in texts]
    if PREPROCESS in ('p03', 'p04'):
        texts = [stem(sentence) for sentence in texts]

    return tokenizer(
        texts,
        text_target=examples["quadruplet"],
        max_length=max_length,
        truncation=True,
    )

# Inference

In [None]:
# Tokenizer is loaded by checkpoint id; model weights come from a separate path
# (presumably a local fine-tuned checkpoint directory — confirm it exists).
TOKENIZER_CHECKPOINT = "Wikidepia/IndoT5-base"
MODEL_CHECKPOINT = 'models/tf-indot5'

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)

In [None]:
# Dataset: load the CSV, preprocess + tokenize it, then hold out 20% for testing.

# preprocess_dataset reads the module-level `max_length` while `map` runs, so it
# must exist before this cell executes; otherwise a fresh-kernel "Run All"
# fails with NameError. The value matches the one used for generation.
# NOTE(review): preprocess_dataset also reads TEXT_COL and PREPROCESS, which no
# visible cell defines — confirm they are set in a configuration cell.
max_length = 100

raw_dataset = load_dataset('csv', data_files='../Data/quadruplet_only.csv', split='train')
# All raw columns are dropped, so only the tokenizer's output columns remain.
tokenized_dataset = raw_dataset.map(
    preprocess_dataset,
    batched=True,
    remove_columns=raw_dataset.column_names,
)
# Fixed seed keeps the train/test split reproducible across runs.
splitted_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=42)


In [None]:
# Generate predictions for the held-out test split.
#
# The tokenized examples are ragged Python lists of token ids, so they cannot be
# passed to `model.generate` directly: each batch is padded to a rectangular
# tensor (with its attention mask) first. Batching bounds memory use, and
# `torch.no_grad()` avoids building autograd state during inference.
max_length = 100
batch_size = 16

test_split = splitted_dataset['test']
model.eval()

pred_text = []
with torch.no_grad():
    for start in range(0, len(test_split), batch_size):
        # Slicing a Dataset yields a dict mapping column name -> list of values.
        batch = test_split[start:start + batch_size]
        encoded = tokenizer.pad(
            {
                'input_ids': batch['input_ids'],
                'attention_mask': batch['attention_mask'],
            },
            return_tensors='pt',
        ).to(model.device)
        generated_batch = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            max_length=max_length,
        )
        # Decoded strings are appended in dataset order, one per test example.
        pred_text.extend(
            tokenizer.batch_decode(generated_batch, skip_special_tokens=True)
        )

In [None]:
# Attach the decoded predictions to the test split as a new DataFrame column.
# NOTE(review): the column name says "bart" although the model loaded in this
# notebook is a T5 checkpoint — kept unchanged in case downstream code reads it.
test_df = splitted_dataset['test'].to_pandas().assign(
    **{'pred_quadruplet_pt_bart': pred_text}
)

In [None]:
# Persist the test rows with their predictions; the path is relative to the
# notebook's working directory and the row index is not written.
PREDICTIONS_CSV = 'data/quadruplet_test_pred.csv'
test_df.to_csv(PREDICTIONS_CSV, index=False)