In [None]:
pip install scikit-learn datasets transformers seaborn torchtext wandb spacy_transformers peft accelerate flash-attn bitsandbytes

In [2]:
from datasets import load_dataset, load_metric
# W&I+LOCNESS, "wi" configuration; its "train" and "validation" splits are used below.
raw_datasets = load_dataset(path="wi_locness", name="wi")

In [3]:
from transformers import AutoTokenizer
model_checkpoint = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

# Instruction prepended (with no separator) to every essay before tokenization.
prefix = "Correct spelling, punctuation and grammatical errors in this text:"
# Token count of the bare prefix (no </s>); later cells skip this many leading ids
# to reach the essay tokens.
len_tokenized_prefix = len(tokenizer.encode(prefix, add_special_tokens=False))
def preprocess_function(examples, max_length=512):
    """Tokenize a batch of essays and build the corrected-text labels.

    Each essay is prefixed with the instruction ``prefix`` and tokenized with
    truncation. The label is the part of the essay that survived truncation,
    with the annotated edits applied, tokenized the same way. Relies on the
    module-level ``tokenizer``, ``prefix`` and ``len_tokenized_prefix``.

    Args:
        examples: batched dict with ``"text"`` (list of essay strings) and
            ``"edits"`` (one dict per essay holding parallel ``"start"``,
            ``"end"`` and ``"text"`` lists; start/end are character offsets
            into the essay, text is the replacement string or None).
        max_length: token budget for both inputs and labels (default 512).

    Returns:
        The tokenizer output for the prefixed inputs (``offset_mapping``
        removed) with an added ``"labels"`` list of label token ids.
    """
    inputs = [prefix + text for text in examples["text"]]
    model_inputs = tokenizer(
        inputs, max_length=max_length, truncation=True, return_offsets_mapping=True
    )
    # Offsets are only used to find the kept essay span; they are not a model input.
    offset_mapping = model_inputs.pop("offset_mapping")
    prefix_len = len(prefix)

    labels_text = []
    for full_input, offsets, edits in zip(inputs, offset_mapping, examples["edits"]):
        # Tokens after the prefix and before the final </s> cover the essay text
        # that survived truncation.
        text_offsets = offsets[len_tokenized_prefix:-1]
        if not text_offsets:
            # No essay token survived (e.g. empty essay). Emit an empty label
            # instead of slicing into the prefix or indexing past the end.
            labels_text.append("")
            continue

        span_start = text_offsets[0][0]
        span_end = text_offsets[-1][1]
        corrected_text = full_input[span_start:span_end]
        # Edit offsets are relative to the raw essay, so move the span out of
        # the prefixed-string coordinates.
        span_start -= prefix_len
        span_end -= prefix_len

        # Apply edits back-to-front so earlier offsets stay valid after each splice.
        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):
            if start < span_start or end > span_end:
                continue  # edit lies (partly) outside the kept span
            if correction is None:
                correction = tokenizer.unk_token
            start_offset = start - span_start  # >= 0
            end_offset = end - span_start
            corrected_text = (
                corrected_text[:start_offset] + correction + corrected_text[end_offset:]
            )
        labels_text.append(corrected_text)

    labels = tokenizer(labels_text, max_length=max_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_datasets = raw_datasets.map(
    function=preprocess_function,
    batched=True,
    # Keep only the tokenizer outputs plus "labels"; the raw columns are dropped.
    remove_columns=raw_datasets["train"].column_names,
)

  _torch_pytree._register_pytree_node(


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

In [4]:
# Spot-check one training example: decoded model input, then decoded label.
_sample = tokenized_datasets["train"][512]
for _column in ("input_ids", "labels"):
    print(tokenizer.batch_decode(_sample[_column]))

['Correct', 'spelling', ',', 'punct', 'u', 'ation', 'and', '', 'gram', 'matic', 'al', 'errors', 'in', 'this', 'text', ':', 'My', 'favourite', 'sport', 'is', 'cricket', '.', 'I', 'love', 'cricket', 'very', 'much', 'since', 'from', 'my', 'school', 'time', '.', 'cricket', 'is', '', 'a', 'game', 'of', 'bat', 'and', 'ball', 'in', 'which', 'there', 'are', 'two', 'teams', 'which', 'have', 'eleven', 'players', 'on', 'each', 'side', '.', 'generally', 'we', 'are', 'using', '', 'cri', 'ket', 'ground', 'as', '', 'a', 'oval', 'shape', '.', '', '</s>']
['My', 'favourite', 'sport', 'is', 'cricket', '.', 'I', 'have', 'loved', 'cricket', 'very', 'much', 'since', 'from', 'my', 'school', 'days', '.', 'Cricket', 'is', '', 'a', 'game', 'of', 'bat', 'and', 'ball', 'in', 'which', 'there', 'are', 'two', 'teams', 'which', 'have', 'eleven', 'players', 'on', 'each', 'side', '.', '', 'Generally', ',', 'we', 'use', '', 'a', 'cricket', 'ground', 'which', 'has', 'an', 'oval', 'shape', '.', '', '</s>']


In [5]:
# Raw ids of two training examples; the first len_tokenized_prefix ids are the
# instruction prefix. Single-row access avoids materialising the whole column.
_ids_512 = tokenized_datasets["train"][512]["input_ids"]
_ids_510 = tokenized_datasets["train"][510]["input_ids"]
print(_ids_512)
print(_ids_510)
# Check removal of the prefix: what remains should be the essay tokens plus </s>.
print(tokenizer.batch_decode(_ids_512[len_tokenized_prefix:]))
print(tokenizer.batch_decode(_ids_510[len_tokenized_prefix:]))

[28223, 19590, 6, 5427, 76, 257, 11, 3, 5096, 4992, 138, 6854, 16, 48, 1499, 10, 7008, 3960, 2600, 19, 18096, 5, 27, 333, 18096, 182, 231, 437, 45, 82, 496, 97, 5, 18096, 19, 3, 9, 467, 13, 3795, 11, 1996, 16, 84, 132, 33, 192, 2323, 84, 43, 20394, 1508, 30, 284, 596, 5, 2389, 62, 33, 338, 3, 2685, 8044, 1591, 38, 3, 9, 17986, 2346, 5, 3, 1]
[28223, 19590, 6, 5427, 76, 257, 11, 3, 5096, 4992, 138, 6854, 16, 48, 1499, 10, 634, 613, 7298, 30, 39, 17652, 2017, 808, 82, 1388, 6, 38, 34, 65, 373, 118, 82, 2461, 12, 161, 21, 8, 2968, 18, 567, 127, 1123, 22898, 9452, 5841, 11, 8, 1502, 942, 131, 8, 1098, 24, 27, 54, 370, 25, 28, 5, 37, 166, 97, 27, 808, 294, 16, 8, 4192, 5130, 47, 16, 1673, 11, 437, 258, 34, 65, 118, 3, 9, 600, 294, 13, 82, 280, 5, 23636, 565, 26, 57, 34, 27, 708, 12, 320, 21, 8278, 16, 27700, 5, 1541, 2038, 27, 183, 6908, 1566, 11, 4329, 7, 28921, 7, 44, 636, 1888, 5224, 26, 3768, 3920, 5, 3, 1]
['My', 'favourite', 'sport', 'is', 'cricket', '.', 'I', 'love', 'cricket', 'very

In [6]:
# split train into 9:1 train:test
from sklearn.model_selection import train_test_split

_train_split = tokenized_datasets["train"]
_source_ids = _train_split["input_ids"]
_target_ids = _train_split["labels"]
print(f"Number of original training samples:\t{len(_source_ids)}")

# 90/10 split of the official train split; random_state fixes the shuffle.
X_train, X_test, Y_train, Y_test = train_test_split(
    _source_ids, _target_ids, test_size=0.1, random_state=0
)

print(f"Number of training samples:\t\t{len(X_train)}")
print(f"Number of test samples:\t\t\t{len(X_test)}")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Number of original training samples:	3000
Number of training samples:		2700
Number of test samples:			300


In [7]:
# Per-token decoded strings (instruction prefix stripped from inputs) for the
# train/test/validation splits. These feed the RNN/LSTM/GRU baselines as a
# fallback in case Prof. Chris's code doesn't work. Each name gets its own list.
(
    X_train_text,
    Y_train_text,
    X_test_text,
    Y_test_text,
    X_validation_text,
    Y_validation_text,
) = ([], [], [], [], [], [])

def decode(src, dest, is_y=False, prefix_len=None):
    """Decode each token-id sequence in ``src`` and append the result to ``dest``.

    ``tokenizer.batch_decode`` is called on a flat list of ids, so each id is
    decoded on its own. Every entry appended to ``dest`` is therefore a list of
    per-token strings (including the trailing ``</s>``), not a joined sentence.

    Args:
        src: iterable of token-id lists (model inputs or labels).
        dest: list that receives one decoded entry per sequence; mutated in place.
        is_y: True when ``src`` holds labels, which carry no instruction prefix.
        prefix_len: number of leading prefix tokens to drop from inputs;
            defaults to the module-level ``len_tokenized_prefix``.
    """
    if is_y:
        skip = 0  # labels were tokenized without the prefix
    elif prefix_len is None:
        skip = len_tokenized_prefix
    else:
        skip = prefix_len
    for ids in src:
        dest.append(tokenizer.batch_decode(ids[skip:]))

# Model inputs: decode() strips the instruction prefix before decoding.
for _ids, _decoded in (
    (X_train, X_train_text),
    (X_test, X_test_text),
    (tokenized_datasets["validation"]["input_ids"], X_validation_text),
):
    decode(_ids, _decoded)

# Labels: there is no prefix to strip.
for _ids, _decoded in (
    (Y_train, Y_train_text),
    (Y_test, Y_test_text),
    (tokenized_datasets["validation"]["labels"], Y_validation_text),
):
    decode(_ids, _decoded, is_y=True)

In [8]:
# Every source sequence must have produced exactly one decoded entry, for every split.
assert len(X_train_text) == len(X_train)
assert len(Y_train_text) == len(Y_train)
assert len(X_test_text) == len(X_test)
assert len(Y_test_text) == len(Y_test)
assert len(X_validation_text) == len(Y_validation_text)

print(len(X_train) == len(X_train_text))
print(X_train[0:5])
print(X_train_text[0:5])
print(Y_validation_text[0:5])
print(len(Y_validation_text))

True
[[28223, 19590, 6, 5427, 76, 257, 11, 3, 5096, 4992, 138, 6854, 16, 48, 1499, 10, 196, 619, 16, 3, 9, 422, 690, 6, 258, 132, 33, 360, 378, 24, 82, 1511, 263, 21, 8, 1164, 10, 114, 8, 11667, 6, 2459, 62, 398, 214, 13, 125, 19, 263, 8, 1037, 24, 62, 3793, 6, 11, 474, 135, 16, 8, 269, 4021, 5, 100, 19, 3, 9, 385, 1810, 21, 178, 6, 68, 3, 99, 62, 66, 4139, 56, 36, 3, 9, 600, 199, 21, 8, 1164, 5, 101, 92, 54, 169, 16307, 24, 59, 11037, 78, 231, 6, 3, 99, 62, 54, 4693, 5086, 2806, 1002, 42, 805, 3648, 494, 5, 86, 8, 600, 3119, 79, 43, 1553, 12, 918, 1442, 3950, 6, 79, 169, 8, 2806, 23, 138, 452, 1855, 59, 12, 5492, 2810, 5, 86, 82, 3474, 132, 164, 36, 186, 2219, 6, 68, 3, 99, 62, 278, 31, 17, 1445, 34, 6, 62, 56, 59, 199, 69, 11341, 5, 1], [28223, 19590, 6, 5427, 76, 257, 11, 3, 5096, 4992, 138, 6854, 16, 48, 1499, 10, 7008, 3960, 2100, 1482, 6, 27, 310, 114, 48, 2600, 21, 633, 2081, 10, 8218, 94, 31, 7, 182, 1017, 27, 54, 103, 34, 3461, 11, 11008, 27, 241, 5, 8401, 94, 31, 7, 182, 2877