In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from numpy.random import choice
import json
import tweepy

import pandas
from random import random
import copy
import torch
from transformers import pipeline
import re
import os
from random import shuffle
from tqdm import tqdm
import string
import torch.nn.functional as F
from transformers import AdamW, get_linear_schedule_with_warmup
import time
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler 
# Prefer the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Ask PyTorch's caching allocator to release unused cached GPU memory
# (useful when this notebook is re-run in a kernel that already held a model).
torch.cuda.empty_cache()

In [None]:
def read_book(title): #load a single book from disk
    """Return the full contents of ``text/<title>`` decoded as UTF-8.

    Parameters
    ----------
    title : str
        File name (relative to the ``text/`` directory) of the book to load.

    Returns
    -------
    str
        The whole file as one string.

    Raises
    ------
    OSError
        If the file cannot be opened.
    UnicodeDecodeError
        If the file is not valid UTF-8.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open("text/{}".format(title), "r", encoding = "utf-8") as f:
        return f.read()
    
    

In [None]:
#clean up text

# Regex fragments used by split_into_sentences. They are raw strings so that
# sequences such as "\s" reach the regex engine untouched instead of being
# treated as (invalid) Python string escapes.
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov)"
digits = r"([0-9])"

def split_into_sentences(text):
    """Split raw document text into a list of cleaned sentences.

    Processing steps, in order:

    1. Lines that start with ``http://`` or ``https://`` are removed.
    2. Newlines are flattened to spaces and only the text after the last
       ``* * *`` separator is kept.
    3. Periods that do not end a sentence (titles, initials, acronyms,
       decimals, website suffixes, company suffixes) are protected with a
       temporary ``<prd>`` marker.
    4. Sentence-final punctuation placed before a closing quote is moved after
       the quote, then every remaining ``.``, ``?`` and ``!`` becomes a
       sentence boundary.
    5. Sentences starting with ``§`` are discarded, stand-alone numbers are
       replaced by a space, and sentences that are too short (10 characters or
       fewer, or 4 or fewer characters per whitespace-separated word on
       average) are dropped.

    Parameters
    ----------
    text : str
        Raw document text.

    Returns
    -------
    list of str
        The surviving sentences, each followed by a single trailing space.
        Text after the last sentence terminator is discarded. An empty list is
        returned when nothing survives.
    """
    # URL lines must be removed while the text still has line breaks. After
    # flattening, "^" can only match at the very start and ".*" would then
    # swallow the whole document.
    text = re.sub(r'^https?:\/\/.*[\r\n]*', '', text, flags=re.MULTILINE)

    text = " " + text + "  "
    text = text.replace("\n"," ")
    text = text.split("* * *")[-1]

    # Protect periods that are not sentence boundaries.
    text = re.sub(prefixes,"\\1<prd>",text)
    text = re.sub(websites,"<prd>\\1",text)
    text = text.replace("Ph.D.","Ph<prd>D<prd>")
    text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
    text = re.sub(r"\s" + alphabets + "[.] "," \\1<prd> ",text)
    text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
    text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
    text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
    text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)

    # Remove form feeds.
    text = text.replace("\x0c","")

    # Move terminal punctuation outside a closing quote so the quote stays
    # attached to its sentence.
    text = text.replace(".”","”.")
    text = text.replace(".\"","\".")
    text = text.replace("!\"","\"!")
    text = text.replace("?\"","\"?")

    # Mark sentence boundaries, then restore the protected periods.
    text = text.replace(".",".<stop>")
    text = text.replace("?","?<stop>")
    text = text.replace("!","!<stop>")
    text = text.replace("<prd>",".")

    # The last piece is whatever followed the final terminator; drop it.
    pieces = text.split("<stop>")[:-1]

    sentences = []
    for piece in pieces:
        s = piece.strip()
        if s.startswith("§"):
            continue
        #s = s.lower()
        s = re.sub(r"^\d+\s|\s\d+\s|\s\d+$", " ", s)
        if len(s) <= 10:
            continue
        words = s.split()
        # Keep only sentences averaging more than 4 characters per word.
        if words and len(s)/len(words) > 4:
            sentences.append(s+" ")
    return sentences

In [None]:
#Tokenize text. Follows RoBERTa paper (adapted for text generation problem - i.e. no random words)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

def _append_padded_chunk(inputs, labels, chunk_ids, chunk_mask, max_length, device):
    """Right-pad one packed chunk to ``max_length`` and append it to ``inputs``/``labels``."""
    length = chunk_ids.shape[1]
    if length < 2:
        # A causal LM needs at least two tokens to have one next-token target;
        # a shorter chunk would contribute only ignored labels.
        return

    # Filler token id is 1 (unchanged from the previous behaviour). The value
    # is irrelevant because the attention mask is 0 and the label is -100 there.
    padded_ids = torch.ones((1, max_length), dtype=torch.long)
    padded_ids[:, :length] = chunk_ids

    padded_mask = torch.zeros((1, max_length), dtype=torch.long)
    padded_mask[:, :length] = chunk_mask

    # -100 is the index ignored by the language-modelling loss, so padding
    # positions do not contribute to the training signal.
    chunk_labels = padded_ids.clone()
    chunk_labels[padded_mask == 0] = -100

    inputs.append({'input_ids': padded_ids.to(device),
                   'attention_mask': padded_mask.to(device)})
    labels.append(chunk_labels.to(device))


def tokenize_sentences(sentences, tokenizer = tokenizer, max_length = 128, device = None):
    """Tokenize sentences and pack them into fixed-length training examples.

    Consecutive sentences are concatenated while the running token count plus
    the next sentence stays strictly below ``max_length``. Each packed chunk is
    then right-padded to exactly ``max_length`` tokens. The final, partly
    filled chunk is emitted as well, so short documents are not lost.

    Parameters
    ----------
    sentences : list of str
        Sentences to tokenize, in document order.
    tokenizer : callable, optional
        Called as ``tokenizer(sentence, return_tensors="pt", truncation=True,
        max_length=max_length)``; must return a mapping with ``input_ids`` and
        ``attention_mask`` tensors of shape ``(1, n)``. Defaults to the
        module-level GPT-2 tokenizer.
    max_length : int, optional
        Length every example is padded to; single sentences are truncated to it.
    device : torch.device or str, optional
        Where the returned tensors are placed. Defaults to CUDA when available,
        otherwise the CPU.

    Returns
    -------
    inputs : list of dict
        One ``{'input_ids', 'attention_mask'}`` dict per example, each tensor of
        shape ``(1, max_length)`` and dtype ``torch.long``.
    labels : list of torch.Tensor
        One ``(1, max_length)`` long tensor per example: a copy of the input
        ids with padding positions set to ``-100``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = []
    labels = []
    chunk_ids = None
    chunk_mask = None

    for sentence in sentences:
        tok = tokenizer(sentence, return_tensors="pt", truncation=True, max_length = max_length)
        ids = tok["input_ids"]
        mask = tok["attention_mask"]

        if chunk_ids is None:
            # First sentence starts the first chunk.
            chunk_ids, chunk_mask = ids, mask
        elif chunk_ids.shape[1] + ids.shape[1] < max_length:
            # Still room: keep packing into the current chunk.
            chunk_ids = torch.hstack([chunk_ids, ids])
            chunk_mask = torch.hstack([chunk_mask, mask])
        else:
            # Chunk is full: emit it and start a new one with this sentence.
            _append_padded_chunk(inputs, labels, chunk_ids, chunk_mask, max_length, device)
            chunk_ids, chunk_mask = ids, mask

    # Emit the trailing chunk (absent only when `sentences` was empty).
    if chunk_ids is not None:
        _append_padded_chunk(inputs, labels, chunk_ids, chunk_mask, max_length, device)

    return inputs,labels

In [None]:
def batch_data(inputs, labels, idx,batch_size=4):
    """Assemble one mini-batch from per-example tensors.

    Examples ``idx`` .. ``idx + batch_size - 1`` are stacked row-wise. Fewer
    examples are used when the lists end earlier.

    Parameters
    ----------
    inputs : list of mapping
        Per-example mappings from key to tensor.
    labels : list of torch.Tensor
        Per-example label tensors, aligned with ``inputs``.
    idx : int
        Position of the first example of the batch.
    batch_size : int, optional
        Maximum number of examples to stack.

    Returns
    -------
    tuple
        ``(batched_inputs, batched_labels)``. ``batched_inputs`` is a deep copy
        of ``inputs[idx]`` whose values were stacked with the following
        examples; the source lists are left untouched.
    """
    stop = idx + batch_size
    followers = list(zip(inputs[idx+1:stop], labels[idx+1:stop]))

    # Copy the first example so stacking never touches the caller's data.
    batched_inputs = copy.deepcopy(inputs[idx])
    batched_labels = copy.deepcopy(labels[idx])

    for extra_input, _ in followers:
        for key, value in extra_input.items():
            batched_inputs[key] = torch.vstack((batched_inputs[key], value))

    if followers:
        extra_labels = [extra_label for _, extra_label in followers]
        batched_labels = torch.vstack([batched_labels] + extra_labels)

    return batched_inputs, batched_labels
 

In [None]:
# Load the pretrained GPT-2 medium language model and move it to the device
# selected at the top of the notebook (instead of hard-coding CUDA).
model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)
# Index into books_list of the next document the training loop will read.
book_idx = 0
# Running sum of the per-micro-batch losses accumulated by the training loop.
total_loss = 0

In [None]:
# Optimisation hyper-parameters, consumed by the optimizer/scheduler below and
# by the training loop.
# Learning rate handed to the optimizer (the value the warmup ramps up to).
lr = 1e-4
# Number of scheduler steps; the training loop stops once updates_count exceeds it.
num_total_steps = 6000
# The first 20% of the steps are warmup.
num_warmup_steps = int(num_total_steps*0.2)
# Number of optimizer updates performed so far; incremented by the training loop.
updates_count = 0

# NOTE(review): recent transformers releases deprecate this AdamW in favour of
# torch.optim.AdamW, which has no correct_bias option - confirm before upgrading.
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
# Linear warmup followed by linear decay of the learning rate.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_total_steps)  # PyTorch scheduler

In [None]:
# Training loop.
#
# Each round reads up to `books_per_round` documents, turns them into
# fixed-length examples and trains with gradient accumulation: `accumulation`
# micro-batches of `batch_size` examples are back-propagated before every
# optimizer step, giving an effective batch of `update_size` examples.
# After each round a few samples are generated and appended to generated.txt.
#
# Relies on names defined in earlier cells: model, optimizer, scheduler,
# updates_count, num_total_steps, book_idx, total_loss, device and books_list.
# NOTE(review): books_list is not defined in the cells shown here - confirm it
# is created before this cell runs.
model.to(device)

scaler = GradScaler()

batch_size = 4
update_size = 256 #actual batch size after considering gradient accumulation
accumulation = update_size // batch_size #micro-batches per optimizer step (kept an int so it can drive range())
books_per_round = 500 #documents processed between two evaluations
max_books = 15750 #upper bound on the number of documents to cycle through
checkpoint_every = 1000 #optimizer updates between two checkpoints
# count_global counts micro-batches, so a resumed run starts at
# updates_count * accumulation (not updates_count * update_size).
count_global = updates_count * accumulation
count_local = 0
labels = None
inputs = None


def _save_checkpoint():
    """Save a resumable checkpoint plus a weights-only file for the current state."""
    # Create the target directories up front so a missing folder cannot make
    # torch.save fail after hours of training.
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("models", exist_ok=True)
    checkpoint = {
        'updates_count': updates_count,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        "book_idx" : book_idx,
        'total_loss':total_loss}
    torch.save(checkpoint, 'checkpoints/checkpoint_{}_{}.pth'.format(book_idx,updates_count))
    torch.save(model.state_dict(), "models/pytorch_model_{}_{}.bin".format(book_idx,updates_count))


if len(books_list) == 0:
    raise ValueError("books_list is empty - nothing to train on")
# Never index past the end of books_list, even if it holds fewer than max_books titles.
num_books = min(max_books, len(books_list))

while updates_count <= num_total_steps: #stop once the learning rate schedule has reached 0
    model.train()
    count_local = 0
    training_examples = 0
    last_book = None
    for _ in tqdm(range(books_per_round)): #so every 500 speeches we have evaluation
        if book_idx >= num_books: #wrap around after the last document
            book_idx = 0
        last_book = books_list[book_idx]
        book_idx += 1
        try:
            book = read_book(last_book)
        except Exception as e:
            # Best effort: report the problem and move on. Without the
            # `continue` the previous book would be trained on again (or
            # `book` would be undefined on the very first iteration).
            print(e)
            continue

        sentences = split_into_sentences(book)
        if len(sentences) == 0:
            continue  #nothing usable in this document

        new_inputs, new_labels = tokenize_sentences(sentences)
        training_examples += len(new_labels)
        if labels is None:
            inputs, labels = new_inputs, new_labels
        else:
            inputs.extend(new_inputs)
            labels.extend(new_labels)

        # Number of full optimizer updates the pending examples allow.
        num_backwards = len(labels)//update_size

        idx = 0
        for _update in range(num_backwards):
            # Accumulate gradients over exactly `accumulation` micro-batches ...
            for _micro in range(accumulation):
                temp_input,temp_label = batch_data(inputs, labels, idx, batch_size)
                with autocast():
                    outputs = model(**temp_input, labels=temp_label)
                    # Scale so the accumulated gradient is the mean over the update.
                    loss = outputs.loss/accumulation

                total_loss += loss.item()
                scaler.scale(loss).backward()
                count_global += 1
                count_local += 1
                idx += batch_size

            # ... then apply one optimizer step.
            updates_count += 1
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            if (updates_count+1) % checkpoint_every == 0: #save checkpoint every ____ updates
                # total_loss is cumulative, so it is averaged over the
                # cumulative micro-batch count.
                mean_loss = accumulation * total_loss/count_global
                print("Saving model, {} steps, mean loss is {}".format(updates_count,mean_loss))
                _save_checkpoint()

        # Keep the examples that did not fill a complete update for later.
        inputs = inputs[num_backwards*update_size:]
        labels = labels[num_backwards*update_size:]

    # Guard against a round in which no backward pass has happened yet.
    mean_loss = accumulation*total_loss/max(count_global, 1)


    model.eval()


    text1 = "Number of sentences: {}, Mean Loss: {}, Number of backprops: {}, Forwards: {}, Last book: {}".format(training_examples, mean_loss, updates_count,count_local,last_book)
    text2 = "-------------------------------------------"

    with open("generated.txt","a", encoding = "utf-8") as f:
        f.write(text1)
        f.write("\n") #keep the header on its own line
        for i in range(5):
            # Start each sample from a random token id.
            sample_outputs = model.generate(
                                            bos_token_id=np.random.randint(1,50256),
                                            do_sample=True,
                                            top_k=50,
                                            max_length = 128,
                                            top_p=0.95,
                                            num_return_sequences=1
                                        )

            for sample_output in sample_outputs:
                generated_text = tokenizer.decode(sample_output, skip_special_tokens=True)
                text = "{}: {} \n lenght of generated text: {} ".format(i,generated_text,len(generated_text))
                f.write(text)
                f.write("\n")

        f.write(text2)
        f.write("\n") #so the next round's header starts on a fresh line


# Final checkpoint once training has finished.
_save_checkpoint()
