In [None]:
!pip install scikit-learn datasets transformers seaborn torchtext wandb spacy_transformers peft accelerate flash-attn bitsandbytes

In [None]:
from datasets import load_dataset, load_metric
# Load the 'wi' configuration of the "wi_locness" dataset.
raw_datasets = load_dataset(path="wi_locness", name="wi")

In [None]:
from transformers import AutoTokenizer
# Checkpoint name handed to AutoTokenizer.from_pretrained.
model_checkpoint = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint
)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of source texts and build corrected-text labels.

    Each text in ``examples['text']`` is tokenized (truncated to 512 tokens).
    The character span covered by the surviving tokens is then rewritten by
    applying that example's edits, and the rewritten text is tokenized to
    produce the labels.

    Args:
        examples: Batch mapping with a ``'text'`` list of strings and an
            ``'edits'`` list of mappings, each holding parallel ``'start'``,
            ``'end'`` and ``'text'`` lists (character offsets into the source
            text and the replacement string, which may be ``None``).

    Returns:
        The tokenizer output for the source texts (offset mapping removed),
        with the label token ids stored under ``"labels"``.
    """
    inputs = examples['text']
    model_inputs = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        return_offsets_mapping=True
    )
    # Offsets are only needed here to locate the (possibly truncated) text window.
    offset_mapping = model_inputs.pop("offset_mapping")

    corrected_texts = []
    for text, offsets, edits in zip(inputs, offset_mapping, examples["edits"]):
        if len(offsets) >= 2:
            start_idx = offsets[0][0]
            # Last token is <eos>, so the second-to-last offset marks the end of the text.
            end_idx = offsets[-2][1]
        else:
            # No text token was produced (e.g. an empty string): use an empty
            # window instead of failing on offsets[-2].
            start_idx = end_idx = 0

        corrected_text = text[start_idx:end_idx]

        # Apply edits right-to-left so that offsets of earlier edits stay valid.
        # NOTE(review): assumes edits are listed in ascending position — confirm.
        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):
            # Skip edits that fall (even partly) outside the tokenized window.
            if start < start_idx or end > end_idx:
                continue
            if correction is None:
                # An edit without replacement text is marked with the unknown token.
                correction = tokenizer.unk_token
            corrected_text = (
                corrected_text[:start - start_idx]
                + correction
                + corrected_text[end - start_idx:]
            )

        corrected_texts.append(corrected_text)

    labels = tokenizer(corrected_texts, max_length=512, truncation=True)
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

In [None]:
# Apply preprocess_function in batches; the original columns (named after the
# train split's columns) are removed so only the returned fields are kept.
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    remove_columns=raw_datasets['train'].column_names,
    batched=True,
)

In [None]:
# Show training example 512 (input and label) decoded token by token.
# The row is indexed before the column so only that example is read,
# rather than loading the whole column to pick one entry.
print(tokenizer.batch_decode(tokenized_datasets["train"][512]['input_ids']))
print(tokenizer.batch_decode(tokenized_datasets["train"][512]['labels']))

In [None]:
# Compare two training examples: raw ids first, then decoded tokens after
# skipping the first 16 ids. Rows are indexed before the column so only the
# two examples are read, not the whole column.
# NOTE(review): preprocess_function in this notebook tokenizes the text without
# prepending a prefix, so the [16:] slice drops real tokens — confirm this check is still wanted.
print(tokenized_datasets["train"][512]['input_ids'])
print(tokenized_datasets["train"][510]['input_ids'])
print(tokenizer.batch_decode(tokenized_datasets["train"][512]['input_ids'][16:])) # check removal of prefix
print(tokenizer.batch_decode(tokenized_datasets["train"][510]['input_ids'][16:]))

In [None]:
# split train into 9:1 train:test
from sklearn.model_selection import train_test_split

print(f"Number of original training samples:\t{len(tokenized_datasets['train']['input_ids'])}")

# Hold out 10% of the tokenized train split; the fixed random_state keeps the split reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    tokenized_datasets["train"]["input_ids"],
    tokenized_datasets["train"]["labels"],
    test_size=0.1,
    random_state=0,
)

print(f"Number of training samples:\t\t{len(X_train)}")
print(f"Number of test samples:\t\t\t{len(X_test)}")

In [None]:
# Containers for the decoded train / test / validation data used by the RNN, LSTM
# and GRU models (in case prof chris' code doesn't work). Token ids are decoded
# as-is below; no prefix is stripped.
X_train_text, Y_train_text = [], []
X_test_text, Y_test_text = [], []
X_validation_text, Y_validation_text = [], []

def decode(src, dest, is_y=False):
  """Decode every id sequence in ``src`` and append the result to ``dest``.

  Each element of ``src`` is passed to ``tokenizer.batch_decode`` and the
  returned value is appended to ``dest`` in order. ``dest`` is modified in
  place and nothing is returned.

  ``is_y`` is accepted for compatibility with existing calls but is not used.
  """
  dest.extend(tokenizer.batch_decode(token_ids) for token_ids in src)

# Decode the inputs of each split first ...
for token_ids, decoded in (
    (X_train, X_train_text),
    (X_test, X_test_text),
    (tokenized_datasets["validation"]["input_ids"], X_validation_text),
):
    decode(token_ids, decoded)

# ... then the labels, in the same split order.
for token_ids, decoded in (
    (Y_train, Y_train_text),
    (Y_test, Y_test_text),
    (tokenized_datasets["validation"]["labels"], Y_validation_text),
):
    decode(token_ids, decoded, True)

In [None]:
# Sanity checks: decoded output has one entry per training example, plus a
# look at the first few raw and decoded items.
print(len(X_train_text) == len(X_train))
print(X_train[:5])
print(X_train_text[:5])
print(Y_validation_text[:5])
print(len(Y_validation_text))