<a href="https://colab.research.google.com/github/LouisCastricato/EDGAR-P/blob/main/EDGAR_P.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs/Setup

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
!pip install transformers sentencepiece textacy auto-tqdm graphviz



In [4]:
%%capture
!python -m spacy download en_core_web_sm
import spacy
import textacy
nlp = spacy.load("en_core_web_sm")

In [5]:
#We're using Neo for this tutorial
from transformers import GPTNeoModel, GPTNeoForCausalLM,\
    GPT2Tokenizer, GPTNeoConfig
import torch
from transformers import (
  StoppingCriteriaList,
  MinLengthLogitsProcessor,
  MaxLengthCriteria,
  AutoTokenizer,
  AutoModelForCausalLM,
  LogitsProcessorList,
  MaxTimeCriteria,
  ForcedEOSTokenLogitsProcessor,
)
import transformers
transformers.logging.set_verbosity(transformers.logging.CRITICAL)

In [6]:
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model.config.pad_token_id = model.config.eos_token_id


In [7]:
%%capture
model = model.to("cuda")

In [8]:
from transformers.generation_logits_process import LogitsProcessor,\
NoBadWordsLogitsProcessor, NoRepeatNGramLogitsProcessor, RepetitionPenaltyLogitsProcessor

class HorizonRepetitionPenalty(LogitsProcessor):
  def __init__(self, penalty: float, horizon: torch.LongTensor, horizon_exclusive = False):
    if not isinstance(penalty, float) or not (penalty > 0):
      raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

    self.penalty = penalty
    self.horizon=horizon
    self.exclusive=horizon_exclusive
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    num_beams = input_ids.shape[0]
    horizon = torch.cat(num_beams*[self.horizon], dim=0)
    if not self.exclusive:
      input_ids = torch.cat((input_ids, horizon), dim=-1)
    else:
      input_ids = horizon
    for i in range(scores.shape[0]):
      for previous_token in set(input_ids[i].tolist()):
        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        if scores[i, previous_token] < 0:
          scores[i, previous_token] *= self.penalty
        else:
          scores[i, previous_token] /= self.penalty
    return scores

In [9]:
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    BeamSearchScorer,
)

bad_words_list = ["Because", "Because,", "Because ", "because",\
                  " Because", " Because,", " because"\
                  "Yes", "Yes,", "Yes ", "yes",\
                  " Yes", " Yes,", " yes",\
                  " No", " No,", " no",\
                  "(", " (", ")", ") "]
bad_words_ids = list(map(lambda x: tokenizer(x)['input_ids'], bad_words_list))

expl = [[1427],[2602, 834],[29343],[37405],[35780],[2602]]
bad_words_ids += expl
print(bad_words_ids)

#Takes a model and computes the perplexity of the target sequence given the input sequence
def perplexity(encodings, stride=1, m=model):
  lls = []
  inp_ids = encodings['input_ids']
  start = encodings['start']
  max_length = len(encodings['input_ids'].squeeze())
  for i in range(start, inp_ids.size(1), stride):
      begin_loc = max(i + stride - max_length, 0)
      end_loc = min(i + stride, inp_ids.size(1))
      trg_len = end_loc - i    # may be different from stride on last loop
      input_ids = inp_ids[:,begin_loc:end_loc].to("cuda")
      target_ids = input_ids.clone()
      target_ids[:,:-trg_len] = -100

      with torch.no_grad():
          outputs = m(input_ids, labels=target_ids)
          log_likelihood = outputs[0] * trg_len
      lls.append(log_likelihood)

  return (torch.exp(torch.stack(lls).sum() / end_loc)).item()

#Constructs a sequence for determining the perplexity of a target given a prompt
def construct(prompt, target, force_start=None):
  prompt_tok = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
  target_tok = tokenizer(target, add_special_tokens=False, return_tensors="pt")
  if force_start is None:
    start = len(prompt_tok['input_ids'].squeeze())
  else:
    start = force_start
  #Start encodes where the prompt sequence ends and target begins
  return {
      'input_ids': torch.cat((prompt_tok['input_ids'],target_tok['input_ids']), dim=-1).cuda(),
      'attention_mask': torch.cat((prompt_tok['attention_mask'],target_tok['attention_mask']), dim=-1).cuda(),
      'start':start,
  }

def generate(ids, max_length=1024, horizon=None, horizon_penalty=None, beams=2, extra_bad_words = None, repetition_penalty=2.0):
  bad_words_t = bad_words_ids
  if extra_bad_words is not None:
    bad_words_t += extra_bad_words
  model_out=None
  if horizon is None:
    model_out = model.generate(input_ids = ids['input_ids'],\
                               max_length=max_length, num_beams=beams,\
                               no_repeat_ngram_size=5, bad_words_ids=bad_words_t, repetition_penalty=repetition_penalty)[0]
  else:
    horizon_ids = tokenizer(horizon, return_tensors="pt")['input_ids'].cuda()
    input_ids = ids["input_ids"]
    model.config.max_length = max_length
    # instantiate logits processors
    logits_processor = LogitsProcessorList([
        MinLengthLogitsProcessor(ids['input_ids'].shape[1], model.config.eos_token_id),
        NoRepeatNGramLogitsProcessor(5),
        NoBadWordsLogitsProcessor(bad_words_t, eos_token_id=model.config.eos_token_id),
        HorizonRepetitionPenalty(penalty=horizon_penalty, horizon=horizon_ids, horizon_exclusive=True),
        RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty)
    ])
    stopping_criteria = StoppingCriteriaList([
        MaxLengthCriteria(max_length=max_length),
    ])
    model_kwargs={
        "attention_mask":ids['attention_mask'],
        "use_cache":True,
    }
    with torch.no_grad():
      model_out = model.greedy_search(
          input_ids=ids["input_ids"], logits_processor=logits_processor,\
          stopping_criteria=stopping_criteria)[0]
    
  return tokenizer.decode(model_out)

[[8128], [8128, 11], [8128, 220], [13893], [4362], [4362, 11], [780, 5297], [5297, 11], [5297, 220], [8505], [3363], [3363, 11], [3763], [1400], [1400, 11], [645], [7], [357], [8], [8, 220], [1427], [2602, 834], [29343], [37405], [35780], [2602]]


# Ranker

In [10]:
%mkdir distilgpt2-ranker-roc
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-roc/config.json distilgpt2-ranker-roc/config.json
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-roc/training_args.bin distilgpt2-ranker-roc/training_args.bin
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-roc/pytorch_model.bin distilgpt2-ranker-roc/pytorch_model.bin

%mkdir distilgpt2-ranker-scifi
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-scifi/config.json distilgpt2-ranker-scifi/config.json
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-scifi/training_args.bin distilgpt2-ranker-scifi/training_args.bin
!cp /content/gdrive/My\ Drive/Colab\ Notebooks/distilgpt2-scifi/pytorch_model.bin distilgpt2-ranker-scifi/pytorch_model.bin



mkdir: cannot create directory ‘distilgpt2-ranker-roc’: File exists
mkdir: cannot create directory ‘distilgpt2-ranker-scifi’: File exists


In [11]:
from transformers import AutoTokenizer, AutoModelWithLMHead
special_tokens_dict = {'prompt' : '<pmpt>'}
#model_name = "distilgpt2-ranker-roc/"
#Download models
#tokenizer_roc =  AutoTokenizer.from_pretrained("distilgpt2")
#model_roc = AutoModelWithLMHead.from_pretrained(model_name).to("cuda")

model_name = "distilgpt2-ranker-scifi/"
#Download models
tokenizer_scifi =  AutoTokenizer.from_pretrained("distilgpt2")
model_scifi = AutoModelWithLMHead.from_pretrained(model_name).to("cuda")



In [12]:
import transformers
from torch.nn import CrossEntropyLoss, MSELoss


rep = transformers.RepetitionPenaltyLogitsProcessor(1.1)

def compute_loss(logits, labels):
  # Shift so that tokens < n predict n
  shift_logits = logits[..., :-1, :].contiguous()
  shift_labels = labels[..., 1:].contiguous()
  # Flatten the tokens
  loss_fct = CrossEntropyLoss()
  return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))



#Computes perplexity using a rep penalty
def perplexity_w_rep(encodings, stride=1, m=None):
  lls = []
  inp_ids = encodings['input_ids']
  start = encodings['start']
  max_length = len(encodings['input_ids'].squeeze())
  for i in range(start, inp_ids.size(1), stride):
      begin_loc = max(i + stride - max_length, 0)
      end_loc = min(i + stride, inp_ids.size(1))
      trg_len = end_loc - i    # may be different from stride on last loop
      input_ids = inp_ids[:,begin_loc:end_loc].to("cuda")
      target_ids = input_ids.clone()
      target_ids[:,:-trg_len] = -100

      with torch.no_grad():
          outputs = m(input_ids, labels=target_ids)
          logits = outputs.logits.squeeze()
          for i in range(1, len(logits)):
            #print(i)
            #print(logits[i].shape)
            ids_t = input_ids[0, :i].unsqueeze(0)
            logits_t = logits[i].unsqueeze(0)
            logits[i] = rep(ids_t, logits_t).squeeze()
          loss = compute_loss(logits.unsqueeze(0), target_ids)
          #logits = rep(input_ids, logits)

          #print(target_ids.shape)
          #print(input_ids.shape)

          log_likelihood = loss * trg_len
      lls.append(log_likelihood)

  return (torch.exp(torch.stack(lls).sum() / end_loc)).item()


def rank(string, force_start=2, model=model_scifi):
  #Pull out the last word to use the construct function
  t1 = string.split()
  t2 = " ".join(t1[-1:])
  t1 = " ".join(t1[:-1])

  #Filtering out the first few tokens helps significantly. so force_start = 3
  return perplexity_w_rep(construct(t1, t2, force_start = force_start), m=model) 

rank("Karen was assigned a roommate her first year of college. Her roommate asked her to go to a nearby city for a concert.")

68.33157348632812

# Generator

In [13]:
questions_bad_words = ["What", " What", "\nWhat", " what", "what"]
questions_bad_words_ids = list(map(lambda x: tokenizer(x)['input_ids'], questions_bad_words))

def clean_story(story):
  #Need to remove colons and line breaks
  story = story.replace(":", "-")
  story = story.replace("\n\n", "")
  story = story.replace("\n", " ")
  story = story.replace("<|endoftext|>", "")
  return story

def get_questions(story):
  instructions = "You will be given a set of short stories and asked to write a set of questions.\n\
  Please write questions, in no particular order, that when answered tell the events leading up to the story.\n"
  story1 = "Story 1: John went for a swim.\n"
  good1 = "Good Questions: How did John get to the swimming pool? What happened before John went swimming? Why did John go swimming?\n"
  bad1 = "Bad Questions: Who is John? What did the pool water taste like? What happened after John went swimming? What does John do now?\n\n"


  story2 = "Story 2: The walk to school that day was long but, Tom was motivated to give Jim back his book. Tom gave the book to Jim.\n"
  good2 = "Good Questions: Why was Tom motivated? How did Tom get the book? Why did Tom give Jim the book? Why did Jim want the book?\n"
  bad2 = "Bad Questions: Who were they? Did Jim want the book? How are they similar? What is Tom wearing? When did Tom give Jim the book? Where is the book? What happens next? What is Jim going to do once he gets the book?\n\n"

  story3 = "Story 3: Mary was so happy to have finally crossed the street.\n"
  good3 = "Good Questions: Why did Mary cross the street? Why was Mary unhappy? Why was Mary running from someone?\n"
  bad3 = "Bad Questions: What is Mary like? What does Mary do after crossing the street? What does Mary do now that she is happy? What happens next?\n\n"

  story4 = "Story 4: " + story + "\n"
  
  inp = instructions + story1 + good1 +  bad1 + story2 + good2 + bad2 + story3 + good3 + bad3 + story4


  out = generate(construct(inp, "Good Questions: Why"), max_length=512, repetition_penalty=2.8, extra_bad_words=questions_bad_words_ids)
  #Get questions out
  out = out.split(":")[11].split("?")[:-1]
  out = "?".join(out)
  if out[-1] != "?":
    out+="?"
  out = sent_tokenize(out)
  good_questions = list()
  #As soon as we find a bad question, break
  for string in out:
    if not ("Bad Question" in string):
      good_questions.append(string)
    else:
      break
  #Remove the space from the first question
  if good_questions[0][0] == ' ':
    good_questions[0] = good_questions[0][1:]
  return good_questions

input_story = "They wanted to see what happened outside."

get_questions(clean_story(input_story))

['Why would they want to see if there was something outside?',
 'Why would they not want to see anything outside?']

In [14]:
def continue_story(story, question, width = 10):
  instructions = "You will be given a set of short stories and question.\nThe question asks about what happens before the story starts.\nList plausible answers for the question\n"
  story1="Story 1: Jim sat by the swings as Tom slowly approached. Tom gave the book to Jim during recess.\n"
  question1="Question: How did Tom get the book?\n"
  answer1=\
"Correct Answers:\n\
1. Jim gave Tom the book because Tom is his best friend.\n\
2. Tom is a thief, so he stole the book from Jim.\n\
3. Tom took the book from the bully that beat up Jim.\n\
4. Tom noticed that the book fell off Jim's lap as he played on the swings.\n\
Wrong Answers:\n\
1. Jim ran from Tom.\n\
2. Jim wanted something to drink.\n\
3. Jim sat patiently in his bedroom.\n\
4. Jim was not going to give Tom the book.\n\
5. Because Jim gave it to him.\n\n"
  story2 = "Story 2: Mary was so happy to have finally crossed the street.\n"
  question2 = "Question: What was Mary running from?\n"
  answer2=\
"Correct Answers:\n\
1. Mary's heart pounded. As she looked over her shoulder, she saw it.\n\
2. Mary had been trying to avoid them all day. If they knew she was skipping school she would be in trouble.\n\
3. The stop sign was the finish line, she only had a few hundred more feet to go.\n\
Wrong Answers:\n\
1. Mary was in the desert where there are no roads.\n\
2. Mary was going for a joyful stroll through the park.\n\
3. Mary was indifferent about her pursuer.\n\n" 
  story3 = "Story 3: " + story + "\n"
  question3 = "Question: " + question +"\n"
  inp = instructions + story1 + question1 +  answer1 + story2 + question2 + answer2 + story3 + question3


  out = generate(construct(inp, "Correct Answers:\n1."), max_length=1024, horizon=story, horizon_penalty=3.0)
  out = out.split("Story 3")[1]
  #Remove the next story
  out = out.split("Story 4")[0]
  #Filter to correct answers
  out = out.split("Correct Answers:")[1]
  out = out.split("Wrong")[0]

  #print(out)
  #If the user does not specify a width
  if width == -1:
    #Capture all of them
    width = 100
  responses = list()
  #Reads through the outputted list and returns every item
  for i in range(1, width):
    try:
      start = "\n"+str(i)
      end = "\n"+str(i+1)
      responses.append(out.split(start)[1].split(end)[0])
    except:
      break
  #Remove first space
  for i in range(len(responses)):
    if responses[i][:2] == '. ':
      responses[i] = responses[i][2:] 
    elif responses[i][:2] == ') ':
      responses[i] = responses[i][2:]
    elif responses[i][0] == ' ':
      responses[i] = responses[i][1:]

    responses[i] = " ".join(responses[i].split())
    responses[i] = clean_story(responses[i])
  return responses
inp_story = "They felt lucky they had evacuated when they did."
starts = continue_story(inp_story, "What did Karen do to get to the concert?")
print(starts)
new_stories = list(map(lambda x: x+" "+inp_story, starts))
print(new_stories)

['She walked home from school.', 'She rode her bike to school.', 'She walked home with her friends.']
['She walked home from school. They felt lucky they had evacuated when they did.', 'She rode her bike to school. They felt lucky they had evacuated when they did.', 'She walked home with her friends. They felt lucky they had evacuated when they did.']


In [15]:
continue_bad_words_list = ["Wrong", "\nWrong", " Wrong",\
                           "Wrongs", "\nWrongs", " Wrongs",\
                           "WrONG", "\nWrONG", " WrONG",\
                           "Wrond", "\nWrond", " Wrond",\
                           "wrong", "\nwrong", " wrong"]
continue_bad_words_list_ids = list(map(lambda x: tokenizer(x)['input_ids'], continue_bad_words_list))

#Generates SVO tuples continuing the story
def continue_story_svo(story, question, width = 10):
  instructions = "You will be given a set of short stories and question.\nThe question asks about what happens before the story starts.\nList plausible answers for the question without contradicting the story.\n"
  story1="Story 1: Jim sat by the swings as Tom slowly approached. Tom gave the book to Jim during recess.\n"
  question1="Question: Why did Jim recieve the book?\n"
  answer1=\
"Correct Answers:\n\
1. Tom desired to return the book to his friend.\n\
2. Jim bought the book from Tom.\n\
3. Tom noticed Jim dropped his book.\n\
4. Jim needed a book to study for his exam.\n\
5. Tom wanted to give Jim his favorite book.\n\
Wrong Answers:\n\
1. Jim did not want the book.\n\
2. Tom noticed Jim was sitting alone during lunch.\n\
3. Jim sat patiently in his bedroom.\n\
4. Jim was not going to give Tom the book.\n\n"
  story2 = "Story 2: Mary was happy to finally cross the street.\n"
  question2 = "Question: Why was Mary running?\n"
  answer2=\
"Correct Answers:\n\
1. Mary was running from a monster.\n\
2. Mary was running a marathon.\n\
3. Mary wanted to get away from her parents.\n\
Wrong Answers:\n\
1. Mary was in the desert where there are no roads.\n\
2. Mary was going for a joyful stroll through the park.\n\
3. Mary was indifferent about her pursuer.\n\
4. Mary was running on a road.\n\n" 
  story3 = "Story 3: " + story + "\n"
  question3 = "Question: " + question +"\n"
  inp = instructions + story1 + question1 +  answer1 + story2 + question2 + answer2 + story3 + question3
  #print(inp)

  out = generate(construct(inp, "Correct Answers:\n1."), max_length=512, horizon=story, horizon_penalty=3.5, repetition_penalty=2.0)
  #print(out)
  out = out.split("Story 3")[1]
  #Remove the next story
  out = out.split("Story 4")[0]
  #Filter to correct answers
  out = out.split("Correct Answers:")[1]
  out = out.split("Wrong")[0]

  #If the user does not specify a width
  if width == -1:
    #Capture all of them
    width = 100
  responses = list()
  #Reads through the outputted list and returns every item
  for i in range(1, width):
    try:
      start = "\n"+str(i)
      end = "\n"+str(i+1)
      responses.append(out.split(start)[1].split(end)[0])
    except:
      break
  #Remove first space
  for i in range(len(responses)):
    if responses[i][:2] == '. ':
      responses[i] = responses[i][2:] 
    elif responses[i][:2] == ') ':
      responses[i] = responses[i][2:]
    elif responses[i][0] == ' ':
      responses[i] = responses[i][1:]

    responses[i] = " ".join(responses[i].split())
    responses[i] = clean_story(responses[i])
  return responses
inp_story = "They felt lucky they had evacuated when they did."
starts = continue_story_svo(inp_story, "Why would they want to see what happened out there?")
print(starts)

['They wanted to see if the evacuation was successful.', 'They wanted to know if the evacuation was safe.', 'They wanted to find out if the evacuation was complete.']


In [16]:
def rank_sort(stories, model=model_scifi):
  ranks = list(map(lambda x: rank(x, force_start=1, model=model), stories))
  ranked_stories = zip(stories, ranks)
  sorted_stories = sorted(ranked_stories, key=lambda x: x[1])
  return list(map(lambda x: x[0], sorted_stories))
#print(rank_sort(new_stories)[0])

# EDGAR-P

In [17]:
#Determines if we should reject based off of if it has an SVO
def hasSVO(sent):
  text = nlp(sent)
  ex =  textacy.extract.subject_verb_object_triples(text)
  for i in ex:
    return True
  return False


In [18]:
#Determines if one statement implies another
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
#This stays on CPU
tokenizer_deberta = AutoTokenizer.from_pretrained("microsoft/deberta-v2-xxlarge-mnli")
model_deberta = AutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v2-xxlarge-mnli")


In [19]:
import numpy as np
def not_contradict(sentA, sentB):
  to_run = "[CLS]" + sentA  + "[SEP]" + sentB + "[SEP]"
  inputs = tokenizer_deberta(to_run, return_tensors="pt")
  with torch.no_grad():
    outputs = model_deberta(**inputs)

  outputs = np.array(outputs.logits.squeeze().cpu().tolist())
  choice = np.argmax(outputs)
  if choice == 0:
    return -100
  return outputs[2]
  

In [20]:
print(not_contradict("The sky is blue.", "The sky is not blue."))
print(not_contradict("The sky is blue.", "The sky is not red."))
print(not_contradict("The sky is blue.", "The ground is green."))

KeyboardInterrupt: ignored

In [None]:
questions = get_questions(inp_story)
extensions = list()
for q in questions:
  extensions += continue_story_svo(inp_story, q)
#Do we have SVO
extensions = list(filter(hasSVO, extensions))
new_stories = list(map(lambda x: x+" "+inp_story, extensions))
sorted_l = rank_sort(new_stories)
print(sorted_l[0])
inp_story=sorted_l[0]

In [21]:
from auto_tqdm import tqdm

#Beams should be [inp_story] when we start
def beam_search(beams, width=20, diversity_width=2, graph=None):
  candidates = list()
  for story in tqdm(beams):
    story_sents = sent_tokenize(story)
    #print(story)

    questions = get_questions(story)

    extensions = list()
    for q in questions:
      continuation = continue_story_svo(story, q, width=-1)
      q_cur = [q]*len(continuation)
      extensions += zip(continuation, q_cur)
    #Sort by most likely to imply and take the top k
    implication_story = " ".join(story_sents[0:min(len(story_sents), 3)]) #The first three sentences sliding window
    #extensions = list(filter(hasSVO, extensions)) #Filter on if there is an SVO tuple
    extensions_ranks = list(map(lambda x: not_contradict(x[0], implication_story), extensions)) #Rank
    extensions_zip = list(filter(lambda x: x[1] > -100, zip(extensions, extensions_ranks))) #Zip
    extensions_zip = sorted(extensions_zip, key=lambda x: x[1], reverse=True)
    extensions = list(map(lambda x: x[0], extensions_zip)) #Sort

    extensions = extensions[:min(diversity_width, len(extensions))] #Take top k
    new_stories = list(map(lambda x: x[0]+" "+story, extensions)) #Concat

    #Debug mode
    if graph is not None:
      for i, s in enumerate(new_stories):
        graph.node(extensions[i][0])
        graph.edge(story_sents[0], extensions[i][0], label=extensions[i][1])

    #print("\n".join(new_stories))

    #Internally rank the new stories to preserve diversity
    #new_stories = rank_sort(new_stories)[0:min(diversity_width, len(new_stories))]
    candidates += new_stories

  sorted_l = rank_sort(candidates, model=model_scifi)
  #print(sorted_l[0])
  return sorted_l[:min(len(sorted_l), width)], graph

In [97]:
from graphviz import Digraph
from IPython.core.display import display, HTML
f = Digraph('transition_graph', filename='tg.png')
f.format='png'
f.engine = 'dot'
f.ratio="fill"
f.node_attr['fixedsize'] = 'false'
f.attr(rankdir='TB', size='100,100')

inp_story="They felt lucky they had evacuated when they did.."
#Initial vertex
f.node(sent_tokenize(inp_story)[0])


beams, f = beam_search([inp_story], width=5, diversity_width=2, graph=f)
for i in range(5):
  f.render()
  #display(HTML(f.svg()))

  #print(beams)
  beams = list(set(beams))
  print("\nSTEP: " + str(i) + "\n\n")
  print("\n".join(beams))
  beams, f = beam_search(beams, width=5, diversity_width=2,graph=f)




  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [05:41<00:00, 341.25s/it][A[A

                                              [A[A

  0%|          | 0/2 [00:00<?, ?it/s][A[A


STEP: 0


They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
They were lucky that they were able to escape. They felt lucky they had evacuated when they did..




 50%|█████     | 1/2 [06:57<06:57, 417.88s/it][A[A

100%|██████████| 2/2 [09:44<00:00, 392.79s/it][A[A

                                              [A[A

  0%|          | 0/4 [00:00<?, ?it/s][A[A


STEP: 1


The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
The school was on fire! They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
The storm was moving so fast that they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
The storm was coming closer so they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..




 25%|██▌       | 1/4 [05:16<15:49, 316.35s/it][A[A

 50%|█████     | 2/4 [12:26<10:55, 327.69s/it][A[A

 75%|███████▌  | 3/4 [18:24<05:30, 330.77s/it][A[A

100%|██████████| 4/4 [21:10<00:00, 314.29s/it][A[A

                                              [A[A

  0%|          | 0/5 [00:00<?, ?it/s][A[A


STEP: 2


The storm made it easy for them to escape. The storm was moving so fast that they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
There is never an evacuation drill. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
The storm was very dangerous for them to stay at home. The storm was coming closer so they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
The storm was approaching too quickly for them to leave later. The storm was coming closer so they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
The siren is too loud. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..




 20%|██        | 1/5 [07:44<30:56, 464.23s/it][A[A

 40%|████      | 2/5 [10:03<21:35, 431.71s/it][A[A

 60%|██████    | 3/5 [13:34<13:39, 409.62s/it][A[A

 80%|████████  | 4/5 [18:51<06:40, 400.36s/it][A[A

100%|██████████| 5/5 [22:03<00:00, 379.52s/it][A[A

                                              [A[A

  0%|          | 0/5 [00:00<?, ?it/s][A[A


STEP: 3


They were lucky because they escaped the fire at the last minute. There is never an evacuation drill. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
The firemen have left the building because the siren hasn't been turned on. The siren is too loud. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
They were lucky their house wasn’t destroyed by the storm. The storm was approaching too quickly for them to leave later. The storm was coming closer so they left early. They were lucky because they were able to evacuate. They felt lucky they had evacuated when they did..
The firemen thought the siren sounded like a car alarm, but it wasn't. The siren is too loud. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they h



 20%|██        | 1/5 [05:48<23:12, 348.11s/it][A[A

 40%|████      | 2/5 [09:07<16:39, 333.27s/it][A[A

 60%|██████    | 3/5 [14:32<11:04, 332.45s/it][A[A

 80%|████████  | 4/5 [17:53<05:19, 319.27s/it][A[A

100%|██████████| 5/5 [25:54<00:00, 335.43s/it][A[A

                                              [A[A

  0%|          | 0/5 [00:00<?, ?it/s][A[A


STEP: 4


Their house caught on fire. They were lucky, because they didn’t have to evacuate until the next day. There is never an evacuation drill. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
It's not necessary anymore. The firemen have left the building because the siren hasn't been turned on. The siren is too loud. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
They were lucky since they could go out into the streets safely. They were lucky, because they didn’t have to evacuate until the next day. There is never an evacuation drill. The fire department told them it was safe to leave. They were lucky that they were able to escape. They felt lucky they had evacuated when they did..
They could have gone to a shelter or a hotel instead of leaving right away. They were lucky their h

KeyboardInterrupt: ignored

In [98]:
f.attr(rankdir='TB', size='100,100')
f.engine = 'dot'
f.ratio="fill"
f.format="gv"
f.node_attr['fixedsize'] = 'false'
f.attr(fontsize='100')
#print(f.node)
#f.node_attr['width'] = '7'
f.node_attr['height'] = '3'
f.render("tg")

!unflatten -l 4 -f tg.gv | dot -Tpdf -o wide.pdf
!ls

distilgpt2-ranker-roc	 sample_data  tg.gv.gv	tg.png.gv   wide.pdf
distilgpt2-ranker-scifi  tg	      tg.gz.gv	tg.png.pdf  wide.png
gdrive			 tg.gv	      tg.png	tg.png.png
