# Data Prep Notebook
### Purpose is to download, process, and write out instruction tuning examples for various ciphers

In [25]:
# Imports
import re
import random
from datasets import load_dataset
from transformers import T5Tokenizer

In [3]:
# Enciphering/deciphering helpers
char_to_num = {
    'a': 0,
    'b': 1,
    'c': 2,
    'd': 3,
    'e': 4,
    'f': 5,
    'g': 6,
    'h': 7,
    'i': 8,
    'j': 9,
    'k': 10,
    'l': 11,
    'm': 12,
    'n': 13,
    'o': 14,
    'p': 15,
    'q': 16,
    'r': 17,
    's': 18,
    't': 19,
    'u': 20,
    'v': 21,
    'w': 22,
    'x': 23,
    'y': 24,
    'z': 25,
}


# Remove all non alphabet text except spaces
def format_text(text):
    plaintext = re.sub(r'[^A-Za-z ]+', '', text)
    return plaintext.lower()


# NOTE: shift can be negative (left) or positive (right)
# If encode=True, encipher text, otherwise decipher
def caesar_cipher(original, shift, encode):
    if encode:
        myshift = shift
    else:
        myshift = shift * -1
    newtext = ''
    for i in original:
        if i == ' ':  # Preserve spaces
            newtext += ' '
        else:
            newnum = (char_to_num[i] + myshift) % 26
            newchar = list(char_to_num.keys())[list(char_to_num.values()).index(newnum)]
            newtext += newchar
    return newtext

In [4]:
# Download gigaword dataset from huggingface
DATA_NAME = "gigaword"
gigaword = load_dataset(DATA_NAME)

Downloading builder script: 100%|█| 4.40k/4.40k [00:00<00:00, 8.10MB/
Downloading metadata: 100%|█████| 2.20k/2.20k [00:00<00:00, 13.4MB/s]
Downloading readme: 100%|███████| 8.06k/8.06k [00:00<00:00, 30.4MB/s]
Downloading data: 100%|███████████| 578M/578M [00:10<00:00, 57.3MB/s]
Generating train split: 100%|█| 3803957/3803957 [00:53<00:00, 71639.3
Generating validation split: 100%|█| 189651/189651 [00:02<00:00, 7315
Generating test split: 100%|█| 1951/1951 [00:00<00:00, 70948.75 examp
  table = cls._concat_blocks(blocks, axis=0)


In [5]:
gigaword

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 3803957
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 189651
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 1951
    })
})

In [8]:
gigaword['train'][0]

{'document': "australia 's current account deficit shrunk by a record #.## billion dollars -lrb- #.## billion us -rrb- in the june quarter due to soaring commodity prices , figures released monday showed .",
 'summary': 'australian current account deficit narrows sharply'}

In [26]:
MODEL_NAME = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [53]:
# Define the preprocessing function
def preprocess_function(examples):
    """Add prefix to the sentences, tokenize the text, and set the labels"""
    # Create lists of data of instructions w/ ciphered text and the corresponding plaintext
    inputs = []
    targets = []
    for doc in examples["document"]:
        shift = random.randint(-25, 25)
        prefix = f"Use a Caesar cipher with shift {shift} to decipher the following text: "
        text = format_text(doc)
        inputs.append(prefix + caesar_cipher(text, shift, True))
        targets.append(text)

    # Tokenize
    model_inputs = tokenizer(inputs, max_length=128, truncation=True)
    labels = tokenizer(text_target=targets, max_length=512, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [55]:
tokenized_dataset = gigaword.map(preprocess_function, batched=True, remove_columns=["document", "summary"])

Map: 100%|████████| 3803957/3803957 [40:33<00:00, 1563.19 examples/s]
Map: 100%|██████████| 189651/189651 [01:57<00:00, 1607.60 examples/s]


In [58]:
tokenized_dataset['test'][0]

{'input_ids': [2048,
  3,
  9,
  26218,
  3,
  3389,
  760,
  28,
  4108,
  3,
  11590,
  12,
  20,
  3389,
  760,
  8,
  826,
  1499,
  10,
  3,
  208,
  51,
  115,
  51,
  172,
  3,
  15,
  3,
  172,
  1824,
  32,
  3,
  32,
  9,
  26,
  115,
  3,
  51,
  172,
  102,
  3,
  122,
  172,
  210,
  3,
  32,
  9,
  63,
  115,
  122,
  89,
  1824,
  26,
  3,
  32,
  9,
  26,
  115,
  1584,
  3,
  89,
  17,
  1824,
  3,
  122,
  1000,
  89,
  1824,
  102,
  3,
  15,
  89,
  51,
  89,
  1824,
  15,
  3,
  15,
  51,
  413,
  3,
  23,
  1824,
  102,
  172,
  1824,
  15,
  2028,
  157,
  3,
  89,
  17,
  1824,
  157,
  3,
  17,
  1167,
  3,
  51,
  7,
  26,
  1824,
  1824,
  102,
  3,
  89,
  9,
  409,
  76,
  172,
  3,
  19042,
  1824,
  15,
  3,
  76,
  172,
  3,
  15,
  122,
  115,
  1824,
  26,
  32,
  9,
  63,
  115,
  1],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
 

In [61]:
tokenized_dataset.save_to_disk(dataset_dict_path='/home/as6734/langgen_class_project/data/caesar')

Saving the dataset (8/8 shards): 100%|█| 3803957/3803957 [00:12<00:00
Saving the dataset (1/1 shards): 100%|█| 189651/189651 [00:01<00:00, 
Saving the dataset (1/1 shards): 100%|█| 1951/1951 [00:00<00:00, 1446


In [None]:
# Load saved datasets
data_files = {"train": "train.csv", "test": "caesar_test.csv", "validation": "caesar_validation.csv"}
dataset = load_dataset("caesar/", data_dir='/home/as6734/langgen_class_project/data/caesar', data_files=data_files)

In [22]:
prefix = "Use a Caesar cipher with shift -3 to decipher the following text: "
s = "Use a Caesar cipher with shift -3 to decipher the following text: ABCDEFG HIJK"
s.replace(prefix, '')

'ABCDEFG HIJK'