In [None]:
!pip install torch==1.4.0
!pip install transformers==2.9.0
!pip install pytorch_lightning==0.7.5

## Run inference with any question as input

In [19]:
# https://github.com/huggingface/transformers/issues/4411

import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer
import tensorflow_hub as hub
from rouge import Rouge 
from typing import List
import numpy as np
import sacrebleu

"""
def set_seed(seed):
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

set_seed(42)
"""
# Choose the compute device first: CUDA when torch reports one, otherwise CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

# Model weights are loaded from 't5_paraphrase'; the tokenizer is loaded
# from 't5-base'.
model = T5ForConditionalGeneration.from_pretrained('t5_paraphrase')
tokenizer = T5Tokenizer.from_pretrained('t5-base')

print("device ", device)
model = model.to(device)

# Universal Sentence Encoder (USE), loaded from TF Hub; used further down
# to embed the input sentence and the candidate paraphrases.
url = "https://tfhub.dev/google/universal-sentence-encoder-large/4"
embed = hub.load(url)

INPUT_SEN = "What is the ARIMA forecasted revenue for the next 4 months?"

# Sampling config. The generate() call below sets do_sample=True, so
# candidates come from top-k / nucleus sampling rather than beam search.
# Maximum length of each generated sequence
MAX_SEQ_LEN = 256
# Number of paraphrases to generate
NB_GENERATED = 30
# Top N to keep after ranking
TOP_TO_KEEP = 10
# TOP_K: how many of the highest-probability tokens stay eligible at each
# sampling step (top-k filtering)
TOP_K = 100
# TOP_P (value between 0 and 1): cumulative-probability threshold used to
# keep tokens for nucleus sampling
TOP_P = 0.90
# TODO: Do a sweep to find the right top_k/top_p values.


# Put the input in the expected format: task prefix, the sentence, and an
# explicit "</s>" marker appended by hand.
text = "paraphrase: " + INPUT_SEN + " </s>"

encoding = tokenizer.encode_plus(text, pad_to_max_length=True, return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device)

# https://huggingface.co/transformers/model_doc/t5.html?highlight=generate#overview
# https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.PreTrainedModel.generate
# NOTE: despite its name, `beam_outputs` holds sampled sequences; the name is
# kept because the decoding loop that follows reads it.
beam_outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_masks,
    do_sample=True,
    max_length=MAX_SEQ_LEN,
    early_stopping=True,
    top_k=TOP_K,
    top_p=TOP_P,
    num_return_sequences=NB_GENERATED,
)


# Decode every generated sequence and keep only the useful candidates:
#   * drop outputs that merely repeat the input (case-insensitive match),
#   * drop outputs that were already collected.
# Both checks use names defined in this cell (INPUT_SEN and the list being
# built), so the cell does not depend on leftover kernel state.
paraphrases = []
for beam_output in beam_outputs:
    sent = tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if sent.lower() != INPUT_SEN.lower() and sent not in paraphrases:
        paraphrases.append(sent)
        
def get_n_best_para(input_sentence: str, paraphrases: List[str], top_n: int = 1) -> List[str]:
    """
    Rank candidate paraphrases against the input sentence and keep the best.

    Each unique candidate gets two scores:
      * meaning: inner product between the USE embedding of the input
        sentence and the USE embedding of the candidate (module-level
        ``embed``);
      * diversity: ``1 - ROUGE-L F`` between the input sentence and the
        candidate, so a larger value means less surface overlap.
    Candidates are sorted by meaning first and diversity second, both
    descending.

    Side effect: prints the sacreBLEU corpus score of the input sentence
    measured against the candidates.

    Args:
        input_sentence: sentence the paraphrases were generated from.
        paraphrases: candidate paraphrases; duplicates are ignored.
        top_n: maximum number of paraphrases to return.

    Returns
        (list of strings): top n paraphrases that are most semantically similar (using USE embeddings) and most
            different structurally (using L-Rouge) to the input_sentence. Empty when there are no
            candidates or when top_n is not positive.
    """
    # Remove duplicates while keeping first-seen order, so that candidates
    # with identical scores come back in a reproducible order.
    unique_paraphrases = list(dict.fromkeys(paraphrases))

    # Nothing to rank, or nothing requested: return before any scorer runs.
    if not unique_paraphrases or top_n <= 0:
        return []
    top_n = min(top_n, len(unique_paraphrases))

    rouge = Rouge()
    rouge_scrs = [
        1 - rouge.get_scores(input_sentence, para)[0]['rouge-l']['f']
        for para in unique_paraphrases
    ]

    # Diagnostic only: the BLEU score is printed, it does not affect ranking.
    ting = [[para] for para in unique_paraphrases]
    bleu = sacrebleu.corpus_bleu(input_sentence, ting)
    print(bleu.score)

    # NOTE: Measure similarity using inner-product on USE embedding.
    enc_input_sentence = embed([input_sentence])['outputs'].numpy()
    enc_paraphrases = embed(unique_paraphrases)['outputs'].numpy()
    scored_paraphrases = [
        # .item() turns the single-element inner-product result into a plain
        # float so the sort key compares numbers rather than arrays.
        # Assumes the one-sentence batch embeds to a single row - TODO confirm.
        (paraphrase, np.inner(enc_input_sentence, enc_paraphrase).item(), score)
        for (paraphrase, enc_paraphrase, score) in zip(unique_paraphrases, enc_paraphrases, rouge_scrs)
    ]
    # Sort on meaning, then diversity
    top_n_paraphrases = sorted(scored_paraphrases, key=lambda x: (x[1], x[2]), reverse=True)[:top_n]
    return [x[0] for x in top_n_paraphrases]

top_para = get_n_best_para(INPUT_SEN, paraphrases, top_n=TOP_TO_KEEP)

# Print two listings with the same layout: first the ranked selection,
# then every candidate that was collected.
for candidates in (top_para, paraphrases):
    print("INPUT SENTENCE :", INPUT_SEN)
    print("PARAPRHASES :")
    for i, paraphrase in enumerate(candidates):
        print("n°%d : %s" % (i, paraphrase))


device  cpu
91.46912192286942
INPUT SENTENCE : What is the ARIMA forecasted revenue for the next 4 months?
PARAPRHASES :
n°0 : What is ARIMA forecasted revenue for the next 4 months?
n°1 : What is the ARIMA projected revenue for the next 4 months?
n°2 : What is the ARIMA forecast revenue for the next 4 months?
n°3 : What is the projected ARIMA revenue for the next 4 months?
n°4 : What is the forecast ARIMA revenue for the next 4 months?
n°5 : What is the ARIMA estimated revenue for the next 4 months?
n°6 : What is ARIMA projected revenue for the next 4 months?
n°7 : What is ARIMA forecast revenue for the next 4 months?
n°8 : What is the revenue forecast of ARIMA for the next 4 months?
n°9 : What is ARIMA revenue forecast for the next 4 months?
INPUT SENTENCE : What is the ARIMA forecasted revenue for the next 4 months?
PARAPRHASES :
n°0 : What is the projected ARIMA revenue for the next 4 months?
n°1 : How will ARIMA's revenues be for the next 4 months?
n°2 : What is ARIMA forecasted r