# Install packages

In [None]:
# !python3 -m pip install transformers
# !python3 -m pip install statsmodels
# !python3 -m pip install tensorflow==1.4.1
# !python3 -m pip install torch
# !python3 -m pip install nltk
# import nltk
# nltk.download('punkt')

# Dependencies

In [None]:
from mle import Mandelbrot
import tensorflow as tf
import torch
import nltk
import re
import numpy as np
from tqdm import tqdm
from time import time
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# from gensim.corpora.wikicorpus import extract_pages, filter_wiki, process_article

# Model init

In [None]:
# Use the GPU when present; fall back to CPU so the cell still runs (slowly)
# on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Reuse the EOS token as the PAD token to avoid padding warnings from generate().
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

# Assign rather than leaving `model.to(...)` as the cell's last expression,
# which would print the entire module repr as cell output.
model = model.to(DEVICE)

# Load dataset

In [None]:
# Final, cleaned Wikipedia file; articles are split on their closing </doc> tag.
# Read explicitly as UTF-8 (the generated file is also written as UTF-8) and
# close the handle deterministically instead of leaking it.
with open("subset.txt", "r", encoding="utf-8") as handle:
    subset = handle.read()

# The piece after the final </doc> is dropped (presumably only trailing
# whitespace — confirm against the file format).
texts = subset.split('</doc>')[:-1]

# Generator

### Prompt Maker

In [None]:
def make_prompt(article, begin=True, n_sents=3):
    """Build a generation prompt from the first or last sentences of `article`.

    Args:
        article: Text to split into sentences with NLTK's ``sent_tokenize``.
        begin: If true, take sentences from the start of the article;
            otherwise take them from the end.
        n_sents: Maximum number of sentences to keep (default 3). Articles
            with fewer sentences are used whole.

    Returns:
        The selected sentences joined by single spaces.

    Raises:
        ValueError: If `n_sents` < 1, or if `article` contains no sentences
            (e.g. empty or whitespace-only), since an empty prompt cannot
            seed generation.
    """
    if n_sents < 1:
        raise ValueError("make_prompt: n_sents must be >= 1")

    article_sents = nltk.tokenize.sent_tokenize(article)
    if not article_sents:
        raise ValueError("make_prompt: article contains no sentences")

    # Slicing already clamps to the sentences available, so 1- and 2-sentence
    # articles need no special case. The tail used to take [-4:] under a
    # `>= 3` guard, inconsistent with the 3-sentence head; both ends now use
    # the same count.
    selected = article_sents[:n_sents] if begin else article_sents[-n_sents:]
    return ' '.join(selected)

### Section Generator

In [None]:
def generate_section(section, min_tokens, k, p, t, rep_pen, n_gram):
    """Generate more than `min_tokens` whitespace-separated words continuing `section`.

    The first prompt is built from the opening sentences of `section`. Each
    round samples up to 512 new tokens with GPT-2, keeps only the words that
    follow the prompt, and re-prompts from the closing sentences of the latest
    output. Rounds continue until more than `min_tokens` words are collected.

    Args:
        section: Source text whose opening sentences seed the first prompt.
        min_tokens: The loop stops once strictly more than this many words
            have been collected.
        k: ``top_k`` for sampling (``None`` disables it).
        p: ``top_p`` (nucleus) for sampling.
        t: Sampling ``temperature``.
        rep_pen: ``repetition_penalty`` passed to ``model.generate``.
        n_gram: ``no_repeat_ngram_size`` passed to ``model.generate``.

    Returns:
        List of generated words. It may be longer than `min_tokens`; the
        caller truncates it.

    Raises:
        ValueError: If a prompt alone already fills GPT-2's context window,
            so no new token could be generated.
    """
    new_tokens_per_round = 512
    # GPT-2 has a fixed number of position embeddings (1024 for "gpt2").
    # Asking generate() for a longer sequence crashes inside the model.
    context_limit = model.config.n_positions

    def _encode(text):
        # Tokenize a prompt and place it on the same device as the model.
        return tokenizer.encode(text, return_tensors='pt').to(model.device)

    def _length_bounds(prompt_ids):
        # Return (prompt token count, target total length), clipped to the
        # context window.
        prompt_id_len = len(prompt_ids[0])
        if prompt_id_len >= context_limit:
            raise ValueError(
                f"generate_section: prompt of {prompt_id_len} tokens fills the "
                f"model's {context_limit}-token context window"
            )
        return prompt_id_len, min(prompt_id_len + new_tokens_per_round, context_limit)

    prompt = make_prompt(section, begin=True)
    prompt_len = len(prompt.split())
    in_ids = _encode(prompt)
    curr_id_len, max_len = _length_bounds(in_ids)

    out = []
    while len(out) <= min_tokens:
        # Re-seed every round so each continuation is reproducible given its prompt.
        torch.manual_seed(0)

        # min_length == max_length forces a full-length continuation
        # (no early EOS).
        out_ids = model.generate(
            in_ids,
            do_sample=True,
            max_length=max_len,
            min_length=max_len,
            top_k=k,
            top_p=p,
            temperature=t,
            repetition_penalty=rep_pen,
            no_repeat_ngram_size=n_gram,
        )

        output = tokenizer.decode(out_ids[0], skip_special_tokens=True).split()
        # Decoded text begins with the prompt; drop its words and keep only new text.
        out.extend(output[prompt_len:])

        # Next prompt: the closing sentences of this round's full output.
        prompt = make_prompt(" ".join(output), begin=False)
        prompt_len = len(prompt.split())
        in_ids = _encode(prompt)

        # Error control: restart from the opening sentences of the collected
        # text, skipping its first sentence when there is more than one.
        # NOTE(review): this fires only on an exact length match
        # (previous prompt + 128 tokens); intent unclear — confirm whether
        # `>=` was meant.
        if len(in_ids[0]) == curr_id_len + 128:
            out_sents = nltk.tokenize.sent_tokenize(" ".join(out))
            new_sents = " ".join(out_sents[1:] if len(out_sents) > 1 else out_sents)

            prompt = make_prompt(new_sents, begin=True)
            prompt_len = len(prompt.split())
            in_ids = _encode(prompt)

        curr_id_len, max_len = _length_bounds(in_ids)

    return out

In [None]:
def generate(txt, k, p, t, rep_pen, n_gram):
    """Rewrite `txt` section by section with GPT-2 continuations.

    `txt` is split on blank lines (``'\\n\\n'``). Sections of 10 words or fewer
    are dropped. Every longer section is replaced by exactly as many generated
    words as it had originally, followed by a ``"\\n\\n"`` marker.

    Args:
        txt: One article's text.
        k, p, t, rep_pen, n_gram: Sampling settings forwarded unchanged to
            ``generate_section``.

    Returns:
        The generated words and section markers joined by single spaces.
    """
    pieces = []
    for section in txt.split('\n\n'):
        word_count = len(section.split())
        if word_count <= 10:
            continue  # too short to seed a meaningful continuation

        generated = generate_section(
            section,
            min_tokens=word_count,
            k=k,
            p=p,
            t=t,
            rep_pen=rep_pen,
            n_gram=n_gram,
        )
        # Match the original section length, then mark the section boundary.
        pieces += generated[:word_count]
        pieces.append("\n\n")

    return " ".join(pieces)

### Generation Process

In [None]:
%%time

for txt in texts:
    print(i)
    curr_txt = generate(txt, k=None, p=0.95, t=1, rep_pen=1.0, n_gram=3)
    with open("gen_set.txt", "a", encoding='utf-8') as handle:
        handle.write(curr_txt)
        handle.write("\n\n</doc>\n\n\n\n")
    print(": completed")
#     print(curr_txt)
    i += 1