In [None]:
!pip install gpt-2-simple



In [None]:
from datetime import datetime
from google.colab import files
%tensorflow_version 1.x
import gpt_2_simple as gpt2
import numpy as np

TensorFlow 1.x selected.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [None]:
# gpt2.download_gpt2(model_name="355M")

In [None]:
# Mount Google Drive (gpt-2-simple helper) so the checkpoint and data files below can be copied from it.
gpt2.mount_gdrive()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Copy the fine-tuned 'model-3.0' checkpoint from Google Drive into the local runtime.
gpt2.copy_checkpoint_from_gdrive(run_name='model-3.0')

In [None]:
# Test set of item descriptions; copied from Google Drive so it can be opened by this relative filename.
test_file = "test_descriptions.csv"
gpt2.copy_file_from_gdrive(test_file)

In [None]:
import argparse
import re
# Captures the generated name between the <|startofname|> and <|endofname|>
# markers; the lazy group stops at the first closing marker.
ANSWER_PATTERN = re.compile(
    r'<\|startofname\|>(.+?)<\|endofname\|>'
)

# Samples requested from the model per description, and the maximum number
# of distinct answers kept from them.
NUMBER_OF_ANSWERS_GENERATED = 64
NUMBER_OF_ANSWERS_SELECTED = 64

# Template for one line of the submission file.
SUBMISSION_ROW = '{}\n'

# NOTE(review): not referenced by the visible code (the TF session is started
# with threads=4) — confirm whether this constant is still needed.
THREAD_COUNT = 8

In [None]:
def sanitize_input(input):
    """
    Normalizes a raw item description before it is placed in the model prompt.

    - drops one pair of surrounding double quotes, if the text is wrapped in them
    - deletes the HTML line-break tags '<br>', '</br>' and '<br/>'
    - deletes every '|' character
    - upper-cases the result

    :param input: raw description text
    :return: cleaned, upper-cased description
    """
    text = input
    if text.startswith('"') and text.endswith('"'):
        text = text[1:-1]

    # Removal order matters: tags are stripped before '|' is removed.
    for token in ('<br>', '</br>', '<br/>', '|'):
        text = text.replace(token, '')
    return text.upper()

In [None]:

def sanitize_answer(item):
    """
    Normalizes one answer produced by the model.

    The text is upper-cased, all commas are deleted (commas separate names in
    the submission file), surrounding whitespace is stripped, and a single
    trailing '.' is removed if present.

    :param item: raw answer text
    :return: cleaned answer text (may be empty)
    """
    cleaned = item.upper().replace(',', '').strip()
    return cleaned[:-1] if cleaned.endswith('.') else cleaned

In [None]:
def extract_answer(model_output):
    """
    Pulls the generated item name out of a single model output.

    Only the first ANSWER_PATTERN match is used, and it is cleaned with
    sanitize_answer.

    :param model_output: single element of the model generation result
    :return: the sanitized answer, or None when the output has no match
    """
    first_match = ANSWER_PATTERN.search(model_output)
    if first_match is None:
        return None
    return sanitize_answer(first_match.group(1))

In [None]:
def answer_quality(answer, prompt):
    """
    Heuristic score of how good ``answer`` is as a reply to ``prompt``.

    Placeholder: every answer currently receives the same score, so the value
    does not yet distinguish between answers.

    :param answer: answer extracted from model output
    :param prompt: prompt given to the model
    :return: float in [0, 1]; 0 means nonsensical, 1 means exact (currently always 1.0)
    """
    constant_score = 1.0
    return constant_score

In [None]:
def without_duplicates(items):
    """
    Drops repeated items, keeping the first occurrence of each.

    :param items: iterable of hashable items
    :return: new list of unique items in first-seen order
    """
    seen = set()
    unique = []
    for candidate in items:
        if candidate not in seen:
            seen.add(candidate)
            unique.append(candidate)
    return unique

In [None]:
import pandas as pd
def load_test_descriptions(test_descriptions_file_path):
    """
    Reads the item descriptions to generate names for.

    :param test_descriptions_file_path: path to a single-column, '~'-separated
        CSV file whose header is 'description'
    :return: list of item descriptions as strings, in file order
    """
    test_descriptions = pd.read_csv(
        test_descriptions_file_path,
        sep="~",
        # Keep every description as text. By default pandas converts values such
        # as "NA", "N/A" or "null" to float NaN, which would then crash string
        # handling downstream (e.g. sanitize_input calls .startswith on it).
        dtype=str,
        keep_default_na=False,
    )
    return test_descriptions['description'].tolist()

In [None]:
def select_answers(outputs, prompt):
    """
    Turns raw model outputs into a ranked, de-duplicated list of answers.

    Each non-None output goes through extract_answer. Outputs with no match
    (None) and answers that sanitize down to an empty string are dropped.
    Duplicates are then removed (first occurrence kept), and the remaining
    answers are sorted by answer_quality, highest first. Ties keep their
    original order because the sort is stable.

    :param outputs: texts generated by the model
    :param prompt: prompt given to the model
    :return: at most NUMBER_OF_ANSWERS_SELECTED answers, best first
    """
    extracted = [extract_answer(output) for output in outputs if output is not None]
    # sanitize_answer can reduce a match such as " ." or "," to "". Keeping it
    # would leave a blank entry in the comma-joined submission row.
    non_empty = [answer for answer in extracted if answer]
    unique_answers = without_duplicates(non_empty)
    ranked = sorted(
        unique_answers,
        key=lambda answer: answer_quality(answer, prompt),
        reverse=True
    )
    return ranked[:NUMBER_OF_ANSWERS_SELECTED]

In [None]:
def generate_model_outputs(input):
    """
    Samples NUMBER_OF_ANSWERS_GENERATED continuations from the fine-tuned
    gpt-2 model held in the global ``tf_sess``.

    The sanitized description is wrapped in the start/desc/name markers, so
    the model continues right after ``<|startofname|>``. Sampling uses a fixed
    seed (666), temperature 1.0 and a length of 60.

    :param input: raw item description (sanitized here before prompting)
    :return: list of generated texts, as returned by gpt2.generate
    """
    prompt_template = "<|startoftext|> <|startofdesc|> {} <|endofdesc|> <|startofname|>"
    prefix_text = prompt_template.format(sanitize_input(input))
    generation_options = dict(
        temperature=1.0,
        length=60,
        nsamples=NUMBER_OF_ANSWERS_GENERATED,
        prefix=prefix_text,
        run_name='model-3.0',
        return_as_list=True,
        seed=666,
    )
    return gpt2.generate(tf_sess, **generation_options)

In [None]:
def generate_item_name_candidates(description, step, on_finished):
    """
    Produces the ranked name candidates for one item description.

    The model session may be reset first (via reset_model). The model is then
    sampled, the answers are selected and ranked, and ``on_finished`` is
    invoked once before the result is returned.

    :param description: single description from the test descriptions file
    :param step: zero-based index of this description in the run
    :param on_finished: zero-argument callback fired after the answers are selected
    :return: candidate names joined with ', '
    """
    reset_model(step)
    ranked_names = select_answers(generate_model_outputs(description), description)
    on_finished()
    return ', '.join(ranked_names)

In [None]:
!pip install alive_progress



In [None]:
import alive_progress as ap
def generate_answers_file(test_descriptions_file_path, answer_file_path):
    """
    Generates name candidates for every test description and writes them to a file.

    The file starts with a 'name' header row, followed by one row per
    description in input order. Each row holds the comma-joined candidates
    from generate_item_name_candidates.

    Rows are written and flushed as soon as they are produced. A failure
    part-way through therefore keeps every finished row instead of losing the
    whole run. Opening the file up front also surfaces a bad output path
    before any slow generation work starts.

    :param test_descriptions_file_path: '~'-separated CSV with a 'description' column
    :param answer_file_path: destination path of the submission file
    """
    descriptions = load_test_descriptions(test_descriptions_file_path)
    with open(answer_file_path, 'w', encoding='utf-8') as file:
        file.write(SUBMISSION_ROW.format('name'))
        with ap.alive_bar(len(descriptions), bar='filling') as bar:
            for index, description in enumerate(descriptions):
                item_name_candidates = generate_item_name_candidates(description, index, bar)
                file.write(SUBMISSION_ROW.format(item_name_candidates))
                file.flush()

In [None]:
# Start the global TF session used by every generation call and restore the
# fine-tuned 'model-3.0' checkpoint into it.
# NOTE(review): THREAD_COUNT (8) is defined in the constants cell, but this
# session is started with threads=4 — confirm which value is intended.
tf_sess = gpt2.start_tf_sess(threads=4)
gpt2.load_gpt2(tf_sess, run_name='model-3.0', multi_gpu=True)

Loading checkpoint checkpoint/model-3.0/model-403
INFO:tensorflow:Restoring parameters from checkpoint/model-3.0/model-403


In [None]:
def load_model():
    """
    Restores the fine-tuned 'model-3.0' checkpoint into the global ``tf_sess``.
    """
    gpt2.load_gpt2(tf_sess, run_name='model-3.0', multi_gpu=True)

In [None]:
def reset_model(step_count, reset_every=10):
    """
    Periodically replaces the TF session and reloads the fine-tuned model.

    On every ``reset_every``-th step (never on step 0), the global ``tf_sess``
    is passed to gpt2.reset_session. The fresh session it returns is stored
    back into ``tf_sess``, and the checkpoint is restored into it via
    load_model.

    NOTE(review): the reset presumably keeps the TF graph from growing across
    many gpt2.generate calls — confirm.
    NOTE(review): reset_session is called without ``threads``, so the new
    session may not match the threads=4 setting of the initial one — confirm.

    :param step_count: zero-based index of the item being processed
    :param reset_every: reset interval in steps (default 10)
    """
    global tf_sess
    if step_count > 0 and step_count % reset_every == 0:
        # Rebind the *global* session. reset_session closes the session it is
        # given, so keeping the old handle makes load_model() (and later
        # generate() calls) fail with a RuntimeError on a closed session.
        tf_sess = gpt2.reset_session(sess=tf_sess)
        load_model()

In [None]:
# Write the submission to the working directory first, then copy it to Google
# Drive with the same gpt-2-simple helper family used to fetch the test file.
# The previous target 'content/drive/submission.csv' had no leading '/', so it
# resolved relative to the working directory instead of the Drive mount at
# /content/drive.
SUBMISSION_FILE = 'submission.csv'
generate_answers_file(test_file, SUBMISSION_FILE)
gpt2.copy_file_to_gdrive(SUBMISSION_FILE)

RuntimeError: ignored