# Installing CAET5

In [None]:
!git clone https://github.com/LeoLaugier/conditional-auto-encoder-text-to-text-transfer-transformer.git

%cd conditional-auto-encoder-text-to-text-transfer-transformer 

!pip install .

# Settings

## TPU settings

In [None]:
print("Setting up GCS access...")
import tensorflow as tf
import tensorflow_gcs_config
from google.colab import auth
# Set credentials for GCS reading/writing from Colab and TPU.
TPU_TOPOLOGY = "2x2"
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  TPU_ADDRESS = tpu.get_master()
  print('Running on TPU:', TPU_ADDRESS)
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
auth.authenticate_user()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
tensorflow_gcs_config.configure_gcs_from_colab_auth()

%env TPU_ADDRESS=$TPU_ADDRESS

## GCS Settings

In [None]:
import os

base_dir = "gs://yourbucket/"  # @param { type: "string" }
%env BASE_DIR = $base_dir
models_dir_name = "your_models_dir_name" # @param { type: "string" }
%env MODELS_DIR_NAME = $models_dir_name
MODELS_DIR = os.path.join(base_dir, models_dir_name)
model_size = "large"
%env MODEL_SIZE = $model_size
model_dir_counter = 1 # @param { type: "integer" }
%env MODEL_DIR_COUNTER = $model_dir_counter
bucket = "yourbucket" # @param { type: "string" }
%env BUCKET = $bucket
data_dir_name = "your_data_dir" # @param { type: "string" }
%env DATA_DIR_NAME = $data_dir_name 


## Other Settings

In [None]:
training_steps = 100000 # @param { type: "integer" }
%env TRAINING_STEPS = $training_steps
task = "yelp" # @param { type: "string" }
mixture = "mixture_%s" %task 
%env MIXTURE = $mixture
sequence_length_gin = os.path.join("sequence_lengths", "%s.gin" % task)
%env SEQUENCE_LENGTH_GIN = $sequence_length_gin
control_codes_gin = os.path.join("control_codes", "%s.gin" % task)
%env CONTROL_CODES_GIN = $control_codes_gin

# Training CAET5

In [None]:
# Fine-tune CAET5 (--mode="finetune") on ${MIXTURE} for ${TRAINING_STEPS} steps.
# All ${VAR} values come from the %env exports in the Settings cells.
# NOTE(review): --data_raw_dir_name is hard-coded to "yelp_processed" and does
# not follow the `task` form field; confirm before switching tasks.
# NOTE(review): the TPU topology '2x2' is hard-coded in --gin_param; keep it in
# sync with TPU_TOPOLOGY in the TPU setting cell.
# NOTE(review): IPython also tries to expand {name} in `!` lines from Python
# variables; this presumably works because that expansion fails on names that
# are not Python variables (e.g. BASE_DIR) and the line is passed through for
# the shell to expand ${...} — confirm if variables are renamed.
!caet5 --base_dir="${BASE_DIR}" \
       --model_dir_name="${MODELS_DIR_NAME}" \
       --model_size="${MODEL_SIZE}" \
       --model_dir_counter="${MODEL_DIR_COUNTER}" \
       --tpu="${TPU_ADDRESS}" \
       --module_import=caet5.data.tasks \
       --use_model_api=True \
       --mode="finetune" \
       --bucket="${BUCKET}" \
       --data_raw_dir_name="yelp_processed" \
       --train_steps="${TRAINING_STEPS}" \
       --mixture_or_task="${MIXTURE}" \
       --base_pretrained_model_dir="gs://t5-data/pretrained_models/" \
       --gin_file="dataset.gin" \
       --gin_file="objectives/denoise.gin" \
       --gin_file="models/cae_bi.gin" \
       --gin_file="train.gin" \
       --gin_file="${SEQUENCE_LENGTH_GIN}" \
       --gin_file="${CONTROL_CODES_GIN}" \
       --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'"

# Evaluating CAET5


In [None]:
# Evaluate the latest checkpoint (--checkpoint_mode="latest") of the model
# selected by the Settings cells on ${MIXTURE}.
# Unlike the training command, this does not load objectives/denoise.gin or
# train.gin.
# NOTE(review): the TPU topology '2x2' is hard-coded; keep it in sync with
# TPU_TOPOLOGY in the TPU setting cell.
!caet5 --base_dir="${BASE_DIR}" \
       --model_dir_name="${MODELS_DIR_NAME}" \
       --model_size="${MODEL_SIZE}" \
       --model_dir_counter="${MODEL_DIR_COUNTER}" \
       --tpu="${TPU_ADDRESS}" \
       --module_import=caet5.data.tasks \
       --use_model_api=True \
       --mode="eval" \
       --bucket="${BUCKET}" \
       --mixture_or_task="${MIXTURE}" \
       --base_pretrained_model_dir="gs://t5-data/pretrained_models/" \
       --checkpoint_mode="latest" \
       --gin_file="dataset.gin" \
       --gin_file="models/cae_bi.gin" \
       --gin_file="${SEQUENCE_LENGTH_GIN}" \
       --gin_file="${CONTROL_CODES_GIN}" \
       --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'"

# Predicting / Inferring with CAET5

In [None]:
import re
import time 

# Example Yelp sentences, each paired with the sentiment ("Destination
# attribute") the model should rewrite it towards. Most sentences appear twice,
# once per attribute.
# NOTE(review): the two "chinese" entries use two different sentences, unlike
# every other pair; confirm this is intentional.
comment_attribute_pairs = [
    {"text": "these donuts have the perfect texture and taste .", "Destination attribute": "negative"},
    {"text": "these donuts have the perfect texture and taste .", "Destination attribute": "positive"},

    {"text": "good food for the price .", "Destination attribute": "negative"},
    {"text": "good food for the price .", "Destination attribute": "positive"},

    {"text": "a little dirty on the inside , but wonderful people that work there !", "Destination attribute": "negative"},
    {"text": "a little dirty on the inside , but wonderful people that work there !", "Destination attribute": "positive"},

    {"text": "i always order it when i go there and it is always awesome .", "Destination attribute": "negative"},
    {"text": "i always order it when i go there and it is always awesome .", "Destination attribute": "positive"},

    {"text": "the rest of the food there is good also and not very expensive .", "Destination attribute": "negative"},
    {"text": "the rest of the food there is good also and not very expensive .", "Destination attribute": "positive"},

    {"text": "great food , low prices , and huge quantity !", "Destination attribute": "negative"},
    {"text": "great food , low prices , and huge quantity !", "Destination attribute": "positive"},

    {"text": "so excited to have a chinese place near my office !", "Destination attribute": "negative"},
    {"text": "this is my go to spot for chinese food .", "Destination attribute": "positive"},

    {"text": "i guess i need to spell that out more clearly next time .", "Destination attribute": "positive"},
    {"text": "i guess i need to spell that out more clearly next time .", "Destination attribute": "negative"},

    {"text": "the service the last time i went was just terrible .", "Destination attribute": "positive"},
    {"text": "the service the last time i went was just terrible .", "Destination attribute": "negative"},

    {"text": "it has n't been for quite a few years .", "Destination attribute": "positive"},
    {"text": "it has n't been for quite a few years .", "Destination attribute": "negative"},

    {"text": "the food here is n't very good .", "Destination attribute": "positive"},
    {"text": "the food here is n't very good .", "Destination attribute": "negative"},

    {"text": "i am sad to see how much this place has gone downhill .", "Destination attribute": "positive"},
    {"text": "i am sad to see how much this place has gone downhill .", "Destination attribute": "negative"},

    {"text": "never again will i go back to this restaurant .", "Destination attribute": "positive"},
    {"text": "never again will i go back to this restaurant .", "Destination attribute": "negative"},

    {"text": "i would n't go here for a meal ever again .", "Destination attribute": "positive"},
    {"text": "i would n't go here for a meal ever again .", "Destination attribute": "negative"},

    {"text": "but nothing show stopping .", "Destination attribute": "positive"},
    {"text": "but nothing show stopping .", "Destination attribute": "negative"},

    {"text": "very rude , will not come back .", "Destination attribute": "positive"},
    {"text": "very rude , will not come back .", "Destination attribute": "negative"},
]

# Attribute name -> id used in the control suffix. NOTE(review): presumably must
# match the ids in the task's control_codes gin file; confirm.
attribute_ids = {"negative": "0", "positive": "1"}

# Each model input is "<text>|dst_attribute:<id>".
comments = [
    pair["text"] + "|dst_attribute:" + attribute_ids[pair["Destination attribute"]]
    for pair in comment_attribute_pairs
]

now = time.time()
# Write out the input text to text files.

model_dir = os.path.join("%s_%s" % (MODELS_DIR, str(model_dir_counter)), model_size)
predict_inputs_path = os.path.join(model_dir, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(model_dir, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
    for c in comments:
        c = re.sub(r'\n', r"\\n", c, flags=re.S)
        f.write("%s\n" % c.lower())

predict_batch_size = len(comments)
%env PREDICT_BATCH_SIZE = $predict_batch_size
%env PREDICT_INPUTS_PATH = $predict_inputs_path
%env PREDICT_OUTPUTS_PATH = $predict_outputs_path

In [None]:
# Run inference with the latest checkpoint on the inputs written above
# (${PREDICT_INPUTS_PATH}); results go to ${PREDICT_OUTPUTS_PATH}, batch size
# ${PREDICT_BATCH_SIZE}. infer.gin replaces the training gin files.
# NOTE(review): the TPU topology '2x2' is hard-coded; keep it in sync with
# TPU_TOPOLOGY in the TPU setting cell.
!caet5 --base_dir="${BASE_DIR}" \
       --model_dir_name="${MODELS_DIR_NAME}" \
       --model_size="${MODEL_SIZE}" \
       --model_dir_counter="${MODEL_DIR_COUNTER}" \
       --tpu="${TPU_ADDRESS}" \
       --module_import=caet5.data.tasks \
       --use_model_api=True \
       --mode="predict" \
       --bucket="${BUCKET}" \
       --mixture_or_task="${MIXTURE}" \
       --base_pretrained_model_dir="gs://t5-data/pretrained_models/" \
       --checkpoint_mode="latest" \
       --input_file="${PREDICT_INPUTS_PATH}" \
       --output_file="${PREDICT_OUTPUTS_PATH}" \
       --predict_batch_size="${PREDICT_BATCH_SIZE}" \
       --gin_file="dataset.gin" \
       --gin_file="models/cae_bi.gin" \
       --gin_file="infer.gin" \
       --gin_file="${SEQUENCE_LENGTH_GIN}" \
       --gin_file="${CONTROL_CODES_GIN}" \
       --gin_param="utils.tpu_mesh_shape.tpu_topology = '2x2'"

In [None]:
# The output filename will have the checkpoint step appended ("<path>-<step>"),
# so glob for every candidate and keep the one from the latest checkpoint.
def _checkpoint_step(path):
    """Return the checkpoint step after the last "-" in `path` as an int, or -1.

    Compared numerically so that e.g. step 1000000 ranks above step 999000,
    which a plain lexicographic sort of the file names gets wrong.
    """
    suffix = path.rsplit("-", 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
if not prediction_files:
    raise FileNotFoundError(
        "No prediction files match %s*; did the caet5 predict command finish?"
        % predict_outputs_path)
latest_prediction_file = max(prediction_files, key=_checkpoint_step)
print("\nPredictions using checkpoint %s:\n" % latest_prediction_file.split("-")[-1])
with tf.io.gfile.GFile(latest_prediction_file) as f:
    # Output lines are assumed to follow the order of the inputs written from
    # `comments`; zip stops at whichever of the two is shorter.
    for c, g in zip(comments, f):
        # Drop the "|dst_attribute:<id>" control suffix appended to each input.
        print("Initial text: " + c.split("|dst_attribute:")[0])
        # Strip the line's own newline; the print() below adds the separator.
        print("Generated text: " + g.rstrip("\n"))
        print()