In [None]:
!pip install datasets
!pip install transformers
!pip install evaluate
!pip install sentencepiece
!pip install nltk

In [None]:
from nltk import download
# Fetch NLTK's "punkt" sentence-tokenizer models, which nltk.tokenize.word_tokenize relies on.
# NOTE(review): word_tokenize is imported in the next cell but is not called in the code visible here.
# NOTE(review): recent NLTK releases look up the "punkt_tab" resource instead — confirm against the installed version.
download('punkt')

In [None]:
import json
import random

from transformers import AutoModelWithLMHead, AutoTokenizer
from nltk.tokenize import word_tokenize


def init_model(model_name: str, device, do_lower_case: bool = False, args=None):
    """
    Load a pre-trained tokenizer and language model, ready for inference.

    :param model_name: model id or checkpoint path accepted by ``from_pretrained``
    :param device: torch device (CUDA or CPU) the model is moved to
    :param do_lower_case: forwarded to the tokenizer; whether the checkpoint is lower-cased
    :param args: unused; kept so existing call sites keep working
    :return: a ``(tokenizer, model)`` tuple — note the tokenizer comes first;
        the model has been moved to ``device`` and switched to eval mode
    """
    # use_fast=False requests the slow (pure-Python) tokenizer implementation.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        do_lower_case=do_lower_case,
    )
    # nn.Module.to() and .eval() both return the module itself, so they can be chained.
    # NOTE(review): AutoModelWithLMHead is deprecated in transformers; AutoModelForSeq2SeqLM
    # is the suggested replacement for encoder-decoder checkpoints — confirm before switching.
    model = AutoModelWithLMHead.from_pretrained(model_name).to(device).eval()
    return tokenizer, model


def get_json(json_path: str):
    """
    Load and return the parsed contents of a JSON file.

    The file is decoded as UTF-8 explicitly (the encoding JSON text is
    required to use) instead of the platform's locale default, which is not
    UTF-8 on every system and would garble or reject non-ASCII text.

    :param json_path: path to the JSON file
    :return: the deserialized JSON value
    """
    with open(json_path, encoding="utf-8") as json_file:
        return json.load(json_file)


def combine_data(json_path: str, dataset_type: str):
    """
    Build one generation prompt per document set of a dataset split.

    Every entry of ``input_json[dataset_type]`` must carry a ``"text"`` list
    of sentence strings. Each sentence is preceded by its positional tag
    ``<S{i}>`` and the whole sequence is wrapped as
    ``"[shuffled] <S0> s0 <S1> s1 ... [orig]"``.

    :param json_path: path to the dataset JSON file
    :param dataset_type: top-level key selecting the split (e.g. ``"train"``)
    :return: list of prompt strings, in the same order as the document sets
    """
    def tag_sentences(sentences):
        # "<S0> first <S1> second ..." — each tag is space-joined to its sentence.
        return " ".join(
            " ".join((f"<S{idx}>", sentence)) for idx, sentence in enumerate(sentences)
        )

    dataset = get_json(json_path)
    return [
        f"[shuffled] {tag_sentences(docset['text'])} [orig]"
        for docset in dataset[dataset_type]
    ]

In [None]:
import argparse
import json
import logging
import re

import torch
import tqdm


def generate_conditional(tokenizer, model, args, input, device):
    """
    Generate a sequence with encoder-decoder models such as BART and T5.

    The prompt is tokenized without adding special tokens, and its final
    token is used as ``decoder_start_token_id`` so decoding begins from the
    end of the prompt. For prompts built by ``combine_data`` that final token
    is presumably the ``[orig]`` marker — confirm the tokenizer keeps it as a
    single token.

    :param tokenizer: tokenizer providing ``tokenize``, ``convert_tokens_to_ids``,
        ``decode`` and ``eos_token_id``
    :param model: model exposing ``generate``
    :param args: options object with ``max_length``, ``beams``, ``temperature``,
        ``p`` and ``k``; may optionally define ``min_length`` (default 5) and
        ``no_repeat_ngram_size`` (default 2)
    :param input: prompt text
    :param device: torch device the input ids are moved to
    :return: list with one decoded string per returned sequence (a single
        sequence here); special tokens are kept in the decoded text
    :raises ValueError: if ``input`` produces no tokens
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input))
    if not token_ids:
        # Without this guard, token_ids[-1] fails with an uninformative IndexError.
        raise ValueError("Cannot generate from an empty input: it produced no tokens.")
    decoder_start_token_id = token_ids[-1]
    input_ids = torch.tensor([token_ids]).to(device)

    outputs = model.generate(
        input_ids,
        # beams == 0 selects sampling (top-k or top-p); otherwise beam search / greedy.
        do_sample=args.beams == 0,
        max_length=args.max_length,
        min_length=getattr(args, "min_length", 5),
        temperature=args.temperature,
        # A zero value means "not used"; None lets generate fall back to its defaults.
        top_p=args.p if args.p > 0 else None,
        top_k=args.k if args.k > 0 else None,
        num_beams=args.beams if args.beams > 0 else None,
        early_stopping=True,
        no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 2),
        eos_token_id=tokenizer.eos_token_id,
        decoder_start_token_id=decoder_start_token_id,
        num_return_sequences=1,
    )

    return [
        tokenizer.decode(output, skip_special_tokens=False, clean_up_tokenization_spaces=False)
        for output in outputs
    ]

In [None]:
"""
Generate outputs
"""

# Emit INFO-and-above records with timestamps for the generation run below.
# NOTE(review): basicConfig does nothing if the root logger already has handlers
# (some notebook hosts install one); pass force=True if this format must apply.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the generation run; every option has a default."""
    cli = argparse.ArgumentParser()

    # Input / output paths and the checkpoint to load
    cli.add_argument("--in_file", default="/content/dataset.json", type=str,
                     help="The input json file")
    cli.add_argument("--out_file", default="/content/test.json", type=str,
                     help="out jsonl file")
    cli.add_argument("--model_name_or_path", default="junyinc/LING575-WIN21-Reorder", type=str,
                     help="LM checkpoint for initialization.")

    # Decoding options (exactly one of k, p, beams must be non-zero — checked after parsing)
    cli.add_argument("--max_length", default=1024, type=int, required=False,
                     help="Maximum text length")
    cli.add_argument("--k", default=0, type=int, required=False,
                     help="k for top k sampling")
    cli.add_argument("--p", default=0, type=float, required=False,
                     help="p for nucleus sampling")
    cli.add_argument("--beams", default=1, type=int, required=False,
                     help="beams for beam search")
    cli.add_argument("--temperature", default=1.0, type=float, required=False,
                     help="temperature for sampling")

    # Which split of the input JSON to read
    cli.add_argument("--dataset_type", default="train", type=str,
                     help="Which part of the dataset to load.")
    return cli


parser = _build_arg_parser()
# parse_known_args (rather than parse_args) ignores unrecognised argv entries —
# presumably so this runs inside a notebook kernel, whose argv carries its own flags.
args, unknown = parser.parse_known_args()
logger.debug(args)

# Exactly one decoding strategy may be active: top-k sampling (k), nucleus
# sampling (p), or beam search / greedy decoding (beams). Count the non-zero ones.
if sum(value != 0 for value in (args.k, args.p, args.beams)) != 1:
    raise ValueError(
        "Exactly one of p, k, and beams should be set to a non-zero value."
    )

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.debug(f"Initializing {device}")

tokenizer, model = init_model(args.model_name_or_path, device)

examples = combine_data(args.in_file, args.dataset_type)
logger.info(examples[:5])  # preview the first few prompts

# Marker tokens of the prompt format plus one <S{i}> sentence tag per position up to max_length.
# NOTE(review): neither list is referenced in the code visible here; confirm before removing.
extra_specials = [f"<S{i}>" for i in range(args.max_length)]
special_tokens = ["[shuffled]", "[orig]", "<eos>", *extra_specials]

def _clean_prediction(pred: str) -> str:
    """
    Strip bracket markers such as ``[orig]`` / ``[shuffled]`` from a decoded prediction.

    Removes every run of word characters immediately followed by ``]``, then
    every ``[`` together with the word characters right after it, collapses
    runs of spaces and trims both ends. Text outside the brackets is kept,
    e.g. ``"foo[bar]baz"`` becomes ``"foo"``.
    """
    pred = re.sub(r"(\w*\])", "", pred)
    pred = re.sub(r"(\[\w*)", "", pred)
    return re.sub(" +", " ", pred).strip()


# Write one JSON object per line: the prompt and its cleaned predictions.
with open(args.out_file, "w") as f_out:
    for idx, input_lines in enumerate(tqdm.tqdm(examples)):
        try:
            preds = generate_conditional(tokenizer, model, args, input_lines, device)
        except Exception:
            # Best effort: keep going with an empty prediction list, but record the
            # failing example and full traceback instead of a bare INFO message.
            logger.exception("Generation failed for example %d", idx)
            preds = []

        preds = [_clean_prediction(pred) for pred in preds]
        f_out.write(
            json.dumps({"input": input_lines, "predictions": preds})
            + "\n"
        )