# START


In [1]:
import math
import re
import pickle
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from datasets import load_dataset
from transformers import (
    TrainerCallback,
    GPT2Config,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AdamW,
    TrainingArguments,
    Trainer,
)

In [None]:
# Load the fine-tuned BART tokenizer and checkpoint, then wrap them in a generation pipeline.
from transformers import BartTokenizerFast, BartModel, pipeline, BartForConditionalGeneration, BartForCausalLM

tokenizer = BartTokenizerFast.from_pretrained('BART-movie-plot-generator')
model = BartForConditionalGeneration.from_pretrained('./BART-movie-plot-generator/checkpoint-12500')

# Resolve the device BEFORE building the pipeline so both agree, and fall back to CPU
# when no GPU is present (a hard-coded device=0 fails on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# BartForConditionalGeneration is an encoder-decoder (seq2seq) model. The documented
# pipeline task for such models is "text2text-generation"; "text-generation" is meant
# for causal (decoder-only) LMs and reports this model class as unsupported.
# pipeline() takes a device index: 0 = first CUDA device, -1 = CPU.
generator = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if device.type == 'cuda' else -1,
)

In [None]:
# Prompt format: "<s> <genre> Title <\s>".
# NOTE(review): "<\s>" in the original source was an invalid escape sequence that Python
# kept as a literal backslash + "s"; "<\\s>" spells out that same string explicitly.
# BART's end-of-sequence token is "</s>" — TODO confirm which marker the training data used.
txt = "<s> <action> The lost village <\\s>"
input_ids = tokenizer.encode(txt, return_tensors='pt')
# Inputs must live on the same device as the model.
input_ids = input_ids.to(device)

# Generate one sample with combined top-k and nucleus (top-p) sampling.
# `device` and `sep_token_id` are intentionally NOT passed: they are neither generation
# parameters nor model forward arguments, so recent transformers versions raise a
# ValueError for them as unused model kwargs. The inputs are already on `device`.
sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=512,
    top_k=50,
    top_p=0.95,
    num_return_sequences=1,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)[0]
# Keep special tokens visible so the prompt/section markers can be inspected.
print(tokenizer.decode(sample_output, skip_special_tokens=False))

In [None]:
# Generate via the pipeline, seeding the plot with an opening phrase.
# "<\\s>" reproduces the literal marker the original (invalid-escape) "<\s>" produced —
# TODO confirm against the training format (BART's EOS token is "</s>").
stories = generator("<s> <action> The lost village <\\s> Was once a large", max_length=512, num_return_sequences=1)

# Join explicitly: print(*items) would put a space between stories, so every story
# after the first would start with a stray leading space.
separator = "\n\n\n------------------------\n"
print("".join(story['generated_text'] + separator for story in stories))