## Dependencies

In [None]:
!apt install subversion &> /dev/null
!pip install transformers &> /dev/null
!pip install simpletransformers &> /dev/null
!pip install datasets &> /dev/null
import os
# Force runtime restart after installing dependencies (simpletransformers requires this), no need to run this cell again.
#os.kill(os.getpid(), 9)

## Imports and auxiliary functions

In [None]:
# Imports
import os
import tarfile
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from simpletransformers.seq2seq import Seq2SeqModel

In [None]:
def truncate(output):
    """Decode generated token ids and cut the text after its last full sentence.

    Args:
        output: Indexable collection of generated token-id sequences. Only
            ``output[0]`` is decoded, using the module-level ``tokenizer``.

    Returns:
        The decoded text up to and including the last ``!``, ``?`` or ``.``.
        If the text contains none of these characters it is returned
        unchanged instead of being reduced to an empty string.
    """
    sentence = tokenizer.decode(output[0], skip_special_tokens=True)
    # str.rfind() yields -1 for a missing character, so an overall maximum of
    # -1 means no sentence terminator was found anywhere in the text.
    index = max(sentence.rfind(mark) for mark in '!?.')
    if index == -1:
        return sentence
    return sentence[:index + 1]

**⏰ Note that it could take a while to download the two models. ⏰**



---



## Description Model

In [None]:
tar_name = 'trained_gpt_medium.tar'
tar_gdrive_id = '1-eRePCWxcHnTt6Tf_mchxJxi8F3a2KYd'
model_path = 'gpt-model'
model = None
print('Downloading finetuned model.')

if not os.path.isfile(tar_name):
    !gdown --id {tar_gdrive_id}
    tar = tarfile.open(tar_name, "r:")
    tar.extractall()
    tar.close()
    !rm {tar_name}

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model_desc = GPT2LMHeadModel.from_pretrained(model_path, local_files_only=True, 
                                        pad_token_id=tokenizer.eos_token_id)

Downloading finetuned model.
Downloading...
From: https://drive.google.com/uc?id=1-eRePCWxcHnTt6Tf_mchxJxi8F3a2KYd
To: /content/trained_gpt_medium.tar
1.44GB [00:07, 182MB/s]


## Title Model

In [None]:
tar_name = 'trained_bart_base.tar'
tar_gdrive_id = '1N3p-4Ao9wTQ40cLiNjSxYVZvrdqZxTH8'

print('Downloading finetuned model.')
if not os.path.isfile(tar_name):
  !gdown --id {tar_gdrive_id}
tar = tarfile.open(tar_name, "r:")
tar.extractall()
tar.close()
!rm {tar_name}

model_title = Seq2SeqModel(encoder_decoder_type="bart", use_cuda=False,
                        encoder_decoder_name='outputs')

Downloading finetuned model.
Downloading...
From: https://drive.google.com/uc?id=1N3p-4Ao9wTQ40cLiNjSxYVZvrdqZxTH8
To: /content/trained_bart_base.tar
559MB [00:03, 179MB/s]


## Generator

In [None]:
from IPython.display import HTML, display


def set_css():
    """Display a <style> element that sets `white-space: pre-wrap` on <pre>."""
    wrap_style = HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  ''')
    display(wrap_style)


# Register the hook on IPython's 'pre_run_cell' event.
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
# Generate a movie description from a text prompt, then a title for it.
starter = "In the year 2077"
# Named `input_ids` so the built-in input() is not shadowed for later cells.
input_ids = tokenizer.encode(starter, return_tensors='pt')
# See https://huggingface.co/blog/how-to-generate
output = model_desc.generate(
    input_ids,
    min_length=40,
    max_length=70,
    do_sample=True,
    temperature=0.8,
    top_k=200,
    top_p=0.95,
    early_stopping=True,
)
# Keep only the text up to the last complete sentence.
movie_desc = truncate(output)
# predict() is given a one-item list; [0] selects the matching single result.
movie_title = model_title.predict([movie_desc])[0]
print("Title: {}".format(movie_title))
print("Description: \n{}".format(movie_desc))

HBox(children=(FloatProgress(value=0.0, description='Generating outputs', max=1.0, style=ProgressStyle(descrip…


Title: Pandora
Description: 
In the year 2077, the world is in the grip of a plague called the 'Pandora', which turns humans into mindless killing machines. A squad of elite soldiers led by Sergeant John Paul Selmer is sent to investigate.
