# Pre-Trained Transformer Model

In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import random
import torch
import time
import csv

## Load Models


In [2]:
# Identifier passed to both from_pretrained calls below, so the tokenizer and
# the model weights always come from the same checkpoint.
checkpoint_name = "prithivida/informal_to_formal_styletransfer"

# Run on the first CUDA device when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_name)
model = model.to(device)

## Create Wrapper


In [3]:
def empathize(input_string, max_length=32, max_candidate=1, quality_filter=0.95):
    """Run ``input_string`` through the loaded seq2seq model and return one rewrite.

    The text is encoded with the module-level ``tokenizer``, moved to
    ``device``, and passed to ``model.generate`` with sampling enabled
    (``top_k=50``, ``top_p=0.95``).

    Args:
        input_string: Text to transform.
        max_length: Forwarded to ``generate`` as ``max_length``.
        max_candidate: Forwarded to ``generate`` as ``num_return_sequences``.
            Only the first returned sequence is decoded; any further
            candidates are discarded.
        quality_filter: Accepted but not used anywhere in this function.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace, or ``None`` when ``generate``
        yields no sequences.
    """
    input_ids = tokenizer.encode(input_string, return_tensors="pt").to(device)

    generation_options = dict(
        do_sample=True,
        max_length=max_length,
        top_k=50,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=max_candidate,
    )
    candidates = model.generate(input_ids, **generation_options)

    # Take only the first candidate, mirroring a loop that returns on its
    # first iteration.
    first_candidate = next(iter(candidates), None)
    if first_candidate is None:
        return None
    return tokenizer.decode(first_candidate, skip_special_tokens=True).strip()


## Load Example Data


In [4]:
# Read the corpus one line at a time: strip surrounding whitespace from each
# line and drop the ones that end up empty.
with open("unlabelled.txt", "r") as source_file:
    lines = list(filter(None, map(str.strip, source_file)))

# Shuffle in place so the examples are processed in random order.
random.shuffle(lines)
print(f"Loaded {len(lines)} examples:\n")

Loaded 7871825 examples:



## Formatting Helpers


In [5]:
def format_time(seconds):
    """Render a duration given in seconds as a short human-readable string.

    Durations below 60 are shown in seconds (``"s"``), durations below 3600
    in minutes (``"min"``), and everything else in hours (``"hr"``). The
    value is always formatted with two decimal places.
    """
    if seconds < 60:
        value, unit = seconds, "s"
    elif seconds < 3600:
        value, unit = seconds / 60, "min"
    else:
        value, unit = seconds / 3600, "hr"
    return f"{value:.2f}{unit}"


## Generate Example Spreadsheet


In [6]:
output = "examples.csv"

# Open the CSV once for the whole run instead of reopening it for every row.
# newline="" lets the csv module control line endings itself, and writing the
# header through the same writer keeps quoting and line endings consistent
# between the header and the data rows.
with open(output, "w", newline="") as w:
    wr = csv.writer(w, quoting=csv.QUOTE_ALL)
    # NOTE(review): the input text is written under "Formal" and the model
    # output under "Casual", while the loaded checkpoint is named
    # informal_to_formal -- confirm the column labels are the intended way round.
    wr.writerow(["Formal", "Casual", "Time/s"])

    total = len(lines)
    total_start = time.time()
    for i, line in enumerate(lines):
        # Time a single generation call; this is the value stored per row.
        start = time.time()
        emp = empathize(line)
        elapsed = time.time() - start

        wr.writerow([line, emp, elapsed])
        # Flush after every row so an interrupted run keeps everything
        # generated so far on disk.
        w.flush()

        done = i + 1
        total_elapsed = time.time() - total_start
        # ETA = average time per finished example * examples still to do.
        eta = total_elapsed / done * (total - done)
        # The "%" format type multiplies by 100, so the progress is shown as a
        # real percentage rather than a raw fraction followed by a "%" sign.
        print(f"[{done}/{total}] - {done / total:.2%}, Elapsed: {format_time(total_elapsed)}, ETA: {format_time(eta)}")

[1/7871825] - 0.00%, Elapsed: 6.58s, ETA: 14393.99hr
[2/7871825] - 0.00%, Elapsed: 6.86s, ETA: 7500.91hr
[3/7871825] - 0.00%, Elapsed: 7.28s, ETA: 5304.54hr
[4/7871825] - 0.00%, Elapsed: 7.75s, ETA: 4238.84hr
[5/7871825] - 0.00%, Elapsed: 8.38s, ETA: 3665.27hr
[6/7871825] - 0.00%, Elapsed: 8.92s, ETA: 3251.55hr
[7/7871825] - 0.00%, Elapsed: 9.56s, ETA: 2985.27hr
[8/7871825] - 0.00%, Elapsed: 10.18s, ETA: 2781.99hr
[9/7871825] - 0.00%, Elapsed: 10.49s, ETA: 2549.45hr
[10/7871825] - 0.00%, Elapsed: 11.04s, ETA: 2413.02hr
[11/7871825] - 0.00%, Elapsed: 11.65s, ETA: 2315.11hr
[12/7871825] - 0.00%, Elapsed: 11.92s, ETA: 2172.48hr
[13/7871825] - 0.00%, Elapsed: 12.29s, ETA: 2067.09hr
[14/7871825] - 0.00%, Elapsed: 12.94s, ETA: 2020.34hr
[15/7871825] - 0.00%, Elapsed: 13.57s, ETA: 1977.49hr
[16/7871825] - 0.00%, Elapsed: 13.92s, ETA: 1902.00hr
[17/7871825] - 0.00%, Elapsed: 14.31s, ETA: 1840.66hr
[18/7871825] - 0.00%, Elapsed: 14.90s, ETA: 1810.44hr
[19/7871825] - 0.00%, Elapsed: 15.51s, ETA:

KeyboardInterrupt: 