In [None]:
%matplotlib inline
import os
import csv
import random
import logging
from tqdm import tqdm, trange
import time
import numpy as np
import pandas as pd
import random
import torch
import pickle
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
import torch.nn as nn
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
#from torchnlp.metrics import get_moses_multi_bleu
# Timestamped INFO-level logging for the whole notebook.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
from torch.distributions import Categorical

In [None]:
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

In [None]:
# Control tokens added to the GPT vocabulary: two sentiment tags, the
# content-start and generation-start markers, and the end marker.
special_tokens = ['<POS>', '<NEG>', '<CON_START>', '<START>', '<END>']
# Use a GPU when one is visible; e.g. torch.device("cuda:1") pins a specific one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = OpenAIGPTTokenizer.from_pretrained('./openai_gpt_vocab/', special_tokens=special_tokens)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', num_special_tokens=len(special_tokens))

### Load GST Model weights

In [None]:
# Fine-tuned BGST generator checkpoint (a state_dict); replace the placeholder path.
yelp_bgst_model_path = "PATH of PRE-TRAINED BGST MODEL"
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
# (device falls back to "cpu" when CUDA is unavailable).
model_state_dict = torch.load(yelp_bgst_model_path, map_location=device)
model.load_state_dict(model_state_dict)
model.to(device)

### Define Decoding Algorithms

In [None]:
# Padded input length (in BERT word-pieces, incl. [CLS]/[SEP]) for the BERT helpers.
max_seq_len = 120
# Shared softmax over the last dimension.
sm = torch.nn.Softmax(dim=-1)

In [None]:
def top_p_sampling(ref_text, p=0.9):
    """
    Decode one continuation of ``ref_text`` with nucleus (top-p) sampling.

    At each step the next-token distribution is sorted in descending order and
    cut to the shortest prefix whose cumulative probability exceeds ``p``. The
    next token is sampled from that prefix, which Categorical renormalises.
    Decoding stops when one of these happens:
      * <END> is sampled;
      * len(tokens) + 5 new tokens have been produced;
      * the model context (model.config.n_positions) is full.

    ref_text : str : prompt text, tokenised with the module-level ``tokenizer``
    p : float : nucleus threshold; if no prefix exceeds it (e.g. p >= 1) the
        whole vocabulary is kept

    Returns (decoded_sentence, generated_ids, training_inputs):
        decoded_sentence : str : generated text, excluding <END>
        generated_ids : list[int] : generated ids, excluding <END>
        training_inputs : list[int] : prompt ids followed by generated_ids
    """
    end_id = tokenizer.special_tokens["<END>"]
    tokens = tokenizer.tokenize(ref_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
    prompt = torch.tensor(indexed_tokens).unsqueeze(0).to(device)
    max_new_tokens = min(model.config.n_positions - len(tokens), len(tokens) + 5)

    # Sample with dropout disabled: the RL loop puts the model into train()
    # mode for its updates, which would otherwise leak into these rollouts.
    model.eval()

    decoded_indexes = []
    current_state = prompt
    while len(decoded_indexes) < max_new_tokens:
        with torch.no_grad():
            probs = torch.softmax(model(current_state)[0, -1, :], dim=-1)
        sort_v, sort_i = torch.sort(probs, descending=True)
        # The running total is non-decreasing, so the number of entries whose
        # cumulative sum is <= p is the index of the first entry that pushes
        # it above p. Keep up to and including that entry (all if none does).
        cutoff = int((torch.cumsum(sort_v, dim=0) <= p).sum().item()) + 1
        cutoff = min(cutoff, sort_v.numel())
        sampled = torch.distributions.Categorical(sort_v[:cutoff]).sample().item()
        decoded_indexes.append(sort_i[sampled].item())
        if decoded_indexes[-1] == end_id:
            break
        current_state = torch.cat(
            (prompt, torch.tensor(decoded_indexes).unsqueeze(0).to(device)), dim=1)

    try:
        end_index = decoded_indexes.index(end_id)
    except ValueError:
        end_index = len(decoded_indexes)
    generated_ids = decoded_indexes[:end_index]
    return tokenizer.decode(generated_ids), generated_ids, indexed_tokens + generated_ids

In [None]:
def top_p_decoding(ref_text, samples=5, p=0.9):
    """
    Draw ``samples`` independent nucleus-sampled continuations of one prompt.

    ref_text : str : prompt passed to top_p_sampling
    samples : int : number of rollouts to draw
    p : float : nucleus threshold forwarded to top_p_sampling (0.9 by default,
        matching the sampler's own default)

    Returns three lists of length ``samples``: decoded sentences, generated id
    lists and training inputs (prompt ids followed by generated ids).
    """
    rollouts = [top_p_sampling(ref_text, p=p) for _ in range(samples)]
    decoded_sentences = [rollout[0] for rollout in rollouts]
    decoded_indexes = [rollout[1] for rollout in rollouts]
    training_inputs = [rollout[2] for rollout in rollouts]
    return decoded_sentences, decoded_indexes, training_inputs

In [None]:
def preditction_with_beam_search(ref_text, beam_width=5, vocab_length=40483):
    """
    Decode ``beam_width`` continuations of a single prompt with beam search.

    Beam scores are running products of next-token probabilities. Decoding
    stops when one of these happens:
      * every beam slot has produced <END>;
      * len(tokens) + 5 steps have run;
      * the model context (model.config.n_positions) is full.

    NOTE(review): ``done`` is tracked per beam *slot*, not per hypothesis, and
    finished beams keep being extended. Their tails are cut at the first <END>
    when the results are assembled.

    ref_text : str : Input sentence
    beam_width : int : Width of the output beam
    vocab_length : int : Kept for backward compatibility (it documented the vocab
        size after adding the special tokens). The flattened score width is now
        read from the model output, so a mismatched value can no longer corrupt
        the parent-beam index.

    Returns (decoded_sentences, beam_indexes, training_inputs), each a list of
    length beam_width:
        decoded_sentences : generated text up to (excluding) the first <END>
        beam_indexes : generated ids up to (excluding) the first <END>
        training_inputs : prompt ids followed by the generated ids
    """
    end_id = tokenizer.special_tokens["<END>"]
    done = [False] * beam_width

    tokens = tokenizer.tokenize(ref_text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
    index_tokens = [indexed_tokens for _ in range(beam_width)]  # prompt replicated per beam
    torch_tensor = torch.tensor(index_tokens).to(device)

    beam_indexes = [[] for _ in range(beam_width)]
    best_scores = [0] * beam_width
    max_steps = min(model.config.n_positions - len(tokens), len(tokens) + 5)

    model.eval()
    for count in range(max_steps):
        if count == 0:
            # All beams share the prompt, so expand the top-k of the first row.
            with torch.no_grad():
                preds = torch.softmax(model(torch_tensor), dim=-1)
            top_v, top_i = preds[:, -1, :].topk(beam_width)
            beam_indexes = [[top_i[0][k].item()] for k in range(beam_width)]
            best_scores = [top_v[0][k].item() for k in range(beam_width)]
            continue

        current_state = torch.cat((torch_tensor, torch.tensor(beam_indexes).to(device)), dim=1)
        with torch.no_grad():
            preds = torch.softmax(model(current_state), dim=-1)
        n_vocab = preds.size(-1)
        # Score every (beam, token) extension and keep the best beam_width overall.
        flatten_score = (preds[:, -1, :] * torch.tensor(best_scores).to(device).unsqueeze(1)).view(-1)
        vals, inx = flatten_score.topk(beam_width)
        # Floor division recovers the parent beam. True division ("/") yields
        # floats on recent PyTorch, which cannot be used as list indices.
        parents = (inx // n_vocab).tolist()
        next_tokens = (inx % n_vocab).tolist()
        best_scores = vals.tolist()

        for k in range(beam_width):
            if next_tokens[k] == end_id:
                done[k] = True
        beam_indexes = [beam_indexes[parents[k]] + [next_tokens[k]] for k in range(beam_width)]
        if all(done):
            break

    # Strip everything from the first <END> onwards.
    decoded_sentences = []
    training_inputs = [None] * beam_width
    for k in range(beam_width):
        beam = beam_indexes[k]
        end_index = beam.index(end_id) if end_id in beam else len(beam)
        training_inputs[k] = index_tokens[k] + beam[:end_index]
        decoded_sentences.append(tokenizer.decode(beam[:end_index]))
        beam_indexes[k] = beam[:end_index]
    return decoded_sentences, beam_indexes, training_inputs

### Load Data

In [None]:
# Number of source sentences used for RL fine-tuning (split evenly between sentiments).
rl_training_samples = 2000
# Seed every RNG the notebook uses. Gotcha: `np.random.seed = 5` would
# replace NumPy's seeding function with an int instead of seeding it.
SEED = 5
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Placeholders: plain-text files with one sentence per line.
train_sentiment0_path = "Path of sentiment0 (Negative) training data"
train_sentiment1_path = "Path of sentiment1 (Positive) training data"


def _sample_prompts(path, n_samples):
    """Read `path` (one sentence per line), draw `n_samples` distinct lines with
    the global NumPy RNG and wrap each as a "<CON_START> ... <START>" prompt."""
    with open(path, encoding="utf-8") as fp:
        lines = fp.read().splitlines()
    picked = np.random.choice(len(lines), n_samples, replace=False)
    return ["<CON_START> {} <START>".format(lines[k]) for k in picked]


# Half of the RL budget per sentiment; negative is drawn first, then positive.
train0 = _sample_prompts(train_sentiment0_path, rl_training_samples // 2)
train1 = _sample_prompts(train_sentiment1_path, rl_training_samples // 2)

train = train0 + train1

In [None]:
# Pair every prompt with both sentiment tags (positive first), so each source
# sentence is rolled out once towards each target attribute.
tags = ["<POS>", "<NEG>"]
train = ["{} {}".format(tag, prompt) for prompt in train0 + train1 for tag in tags]

In [None]:
# Recover the raw source sentences (BLEU references) by stripping prompt
# markers. For prompts built as "<CON_START> ... <START>" the sentiment-tag
# replacement has no effect.
r0 = [prompt.replace("<POS> ", "").replace("<CON_START> ", "").replace(" <START>", "")
      for prompt in train0]
r1 = [prompt.replace("<NEG> ", "").replace("<CON_START> ", "").replace(" <START>", "")
      for prompt in train1]
r = r0 + r1
r[:5]

In [None]:
# Each source sentence yields 2 tagged prompts x 5 sampled rollouts = 10
# generations, so every reference is repeated 10 times to stay aligned.
r_for_bleu = [sentence for sentence in r for _ in range(10)]
len(r_for_bleu), r_for_bleu[:40]

In [None]:
# Labels follow the tag order of `train`: 1 for the "<POS>" prompt, 0 for "<NEG>".
lbls = [1, 0]
labls = lbls * (len(train0) + len(train1))

In [None]:
print(labls[:10], len(labls), len(train))
# Fail fast if prompts and labels ever drift out of alignment.
assert len(labls) == len(train), "expected exactly one label per tagged prompt"

### Load Feedback classifier 

In [None]:
# Fine-tuned two-class BERT sentiment classifier used as the attribute critic.
BERT_CLASSIFIER_PATH = "path of bert classifier"
tokenizer_cls = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model_cls = BertForSequenceClassification.from_pretrained(BERT_CLASSIFIER_PATH, num_labels=2)
model_cls.to(device)
model_cls.eval()

In [None]:
def run_examples(input_sentences, bs=32, dvc='cpu'):
    """
    Classify sentences with the BERT sentiment classifier ``model_cls``.

    Each sentence is word-piece tokenised with ``tokenizer_cls`` (at most
    max_seq_len - 2 pieces), wrapped in [CLS] ... [SEP], and zero-padded to
    max_seq_len.

    input_sentences : list[str]
    bs : int : batch size
    dvc : device the classifier runs on

    Returns (pred_lt, pred_value1):
        pred_lt : list[int] : predicted class (argmax of the softmax) per sentence
        pred_value1 : list[float] : softmax probability of that predicted class
    """
    if len(input_sentences) == 0:
        return [], []

    ids, segment_ids, input_masks = [], [], []
    for sen in input_sentences:
        text_tokens = tokenizer_cls.tokenize(sen)[:max_seq_len - 2]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        token_ids = tokenizer_cls.convert_tokens_to_ids(tokens)
        padding = [0] * (max_seq_len - len(token_ids))
        ids.append(token_ids + padding)
        input_masks.append([1] * len(token_ids) + padding)
        segment_ids.append([0] * (len(token_ids) + len(padding)))

    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    pred_lt, pred_value1 = [], []
    # Step in bs-sized slices; range(0, n, bs) never yields an empty trailing
    # batch, so labels and probabilities always stay aligned.
    for start in trange(0, len(ids), bs):
        stop = start + bs
        with torch.no_grad():
            logits = model_cls(ids[start:stop].to(dvc),
                               segment_ids[start:stop].to(dvc),
                               input_masks[start:stop].to(dvc))
        top_prob, top_cls = torch.softmax(logits, dim=-1).max(dim=-1)
        pred_lt.extend(top_cls.tolist())
        pred_value1.extend(top_prob.tolist())
    return pred_lt, pred_value1

In [None]:
# Aliases consumed by the RL training cells below (r_for_bleu is used as-is;
# the former `r_for_bleu = r_for_bleu` self-assignment was a no-op).
input_data = train
input_data_label = labls

### BERT model for token level feedback for classification

In [None]:
# Plain BERT encoder initialised from the classifier checkpoint; its attention
# weights are used for token-level feedback.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
bert_model = BertModel.from_pretrained(BERT_CLASSIFIER_PATH)
bert_model.to(device)
bert_model.eval()

In [None]:
def run_attn_examples(input_sentences, dvc='cpu', bs=32, head=7):
    """
    Extract one attention head's row for query position 0 ([CLS]) from layer 9 of ``bert_model``.

    Each sentence is processed as follows:
      * lightly normalised (re-joins " ' " and expands "n't" contractions);
      * word-piece tokenised with ``bert_tokenizer``, keeping at most max_seq_len - 5 pieces;
      * wrapped in [CLS] ... [SEP];
      * followed by a single [PAD] (id 0) sentinel, which is inside the attention mask;
      * zero-padded to max_seq_len.

    input_sentences : list[str]
    dvc : device the encoder runs on
    bs : int : batch size
    head : int : attention head to read

    Returns (attention_weights, ids_to_decode, tokens_to_decode), one entry per
    sentence:
        attention_weights[k] : CPU tensor ``attn[9]['attn_probs'][k][head][0]``.
            Assumes the layout is (batch, head, query, key) — TODO confirm
            against the bertviz BertModel.
        ids_to_decode[k] : padded id list of length max_seq_len. The first 0 marks
            the end of the real tokens.
        tokens_to_decode[k] : word-pieces ([CLS] ... [SEP], '[PAD]'), unpadded.
    """
    n = len(input_sentences)
    attention_weights = [None] * n
    ids_to_decode = [None] * n
    tokens_to_decode = [None] * n
    ids, segment_ids, input_masks = [], [], []

    for j, sen in enumerate(tqdm(input_sentences)):
        sen = sen.replace(" ' ", "'").replace("ca n't", "can not").replace("wo n't", "will not").replace("n't", " not")
        text_tokens = bert_tokenizer.tokenize(sen)[:max_seq_len - 5]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        tokens_to_decode[j] = tokens + ['[PAD]']
        token_ids = bert_tokenizer.convert_tokens_to_ids(tokens) + [0]
        padding = [0] * (max_seq_len - len(token_ids))
        padded_ids = token_ids + padding
        ids_to_decode[j] = padded_ids
        ids.append(padded_ids)
        input_masks.append([1] * len(token_ids) + padding)
        segment_ids.append([0] * len(padded_ids))

    if n == 0:
        return attention_weights, ids_to_decode, tokens_to_decode

    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    # range(0, n, bs) never yields an empty trailing batch when n % bs == 0.
    for start in trange(0, n, bs):
        stop = start + bs
        with torch.no_grad():
            _, _, attn = bert_model(ids[start:stop].to(dvc),
                                    segment_ids[start:stop].to(dvc),
                                    input_masks[start:stop].to(dvc))
        for offset, probs in enumerate(attn[9]['attn_probs']):
            attention_weights[start + offset] = probs[head][0].to('cpu')

    return attention_weights, ids_to_decode, tokens_to_decode

In [None]:
# Function words plus special/punctuation tokens that prepare_data never
# deletes as attribute words.
common_words = [
    'is', 'are', 'was', 'were', 'has', 'have', 'had', 'a', 'an', 'the', 'this', 'that',
    'these', 'those', 'there', 'how', 'i', 'we', 'he', 'she', 'it', 'they', 'them',
    'their', 'his', 'him', 'her', 'us', 'our', 'and', 'in', 'my', 'your', 'you',
    'will', 'shall',
]
common_words_tokens = bert_tokenizer.convert_tokens_to_ids(common_words)
not_to_remove_ids = bert_tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?", "!"]) + common_words_tokens

In [None]:
def prepare_data(aw, ids_to_decode, tokens_to_decode, threshold=0.5):
    """
    Delete the most-attended (attribute) words from each sentence.

    For every sentence, the real positions are ranked by attention weight in ``aw``.
    The following are dropped from the ranking:
      * ids listed in ``not_to_remove_ids``;
      * "##" continuation pieces.
    The top ``threshold`` fraction of the remaining positions is then deleted,
    together with each deleted word's "##" continuation pieces. What remains is
    mapped back to text with the BERT tokenizer.

    aw : list of 1-D tensors : attention weight per position
    ids_to_decode : list[list[int]] : BERT ids; the first 0 marks the end of
        the real tokens
    tokens_to_decode : list[list[str]] : word-pieces aligned with those ids
    threshold : float : fraction of candidate words to delete

    Returns list[str] : each sentence with the selected words removed.
    """
    out_sen = [None] * len(aw)
    for i in trange(len(aw)):
        sen_ids = ids_to_decode[i]
        sen_tokens = tokens_to_decode[i]
        n_real = sen_ids.index(0)
        # Rank real positions only. Ranking the padded tail could select an
        # index beyond tokens_to_decode[i] and raise IndexError.
        ranked = aw[i][:n_real].topk(n_real)[1].tolist()
        ranked = [k for k in ranked if sen_ids[k] not in not_to_remove_ids]  # keep special/common ids
        ranked = [k for k in ranked if "##" not in sen_tokens[k]]  # only whole-word starts
        to_remove = set(ranked[:int(threshold * len(ranked))])

        kept_ids = []
        count = 0
        while sen_ids[count] != 0:
            if count in to_remove:
                # Also skip the removed word's "##" continuation pieces.
                while sen_ids[count + 1] != 0 and "##" in sen_tokens[count + 1]:
                    count += 1
            else:
                kept_ids.append(sen_ids[count])
            count += 1

        # These are BERT ids, so they must be decoded with the BERT tokenizer;
        # the GPT `tokenizer` has an unrelated vocabulary.
        text = " ".join(bert_tokenizer.convert_ids_to_tokens(kept_ids))
        out_sen[i] = text.replace(" ##", "").replace("[CLS]", "").replace("[SEP]", "").strip()

    return out_sen

### Load LM for Perplexity feedback

In [None]:
# GPT language model used as the fluency (perplexity) critic.
# NOTE(review): this LM registers only <POS>, <NEG>, <END>. <POS> and <NEG>
# share ids with the generator, but the generator's <CON_START>/<START>/<END>
# ids do not line up with this vocabulary — confirm the rollouts never contain
# them.
lm_model_special_tokens = ["<POS>", "<NEG>", "<END>"]
lm_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt', special_tokens=lm_model_special_tokens)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lm_model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt', num_special_tokens=len(lm_model_special_tokens))

openai_gpt_lm_path = "PATH OF OPENAI GPT MODEL"
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
lm_model_state_dict = torch.load(openai_gpt_lm_path, map_location=device)
lm_model.load_state_dict(lm_model_state_dict)
lm_model.to(device)
lm_model.eval()

In [None]:
def calculate_ppl(tokenized_ids, bs=32, dvc='cpu'):
    """
    Score token-id sequences with the language-model critic ``lm_model``.

    Sequences are right-padded (inputs with 0, labels with the ignored value
    -1). The model predicts token t+1 from the tokens up to t.

    tokenized_ids : list[list[int]] : sequences of LM vocabulary ids
    bs : int : batch size
    dvc : device the LM and its inputs live on

    Returns (ppl, ppl_lt):
        ppl : list[float] : exp(total next-token loss / len(sequence)).
            NOTE(review): len-1 tokens are predicted but the divisor is len;
            kept so scores stay comparable with earlier runs.
        ppl_lt : list[(list[float], list[int])] : per sequence, the
            cross-entropy of each predicted token paired with those tokens
            (sequence[1:]).
    """
    if len(tokenized_ids) == 0:
        return [], []

    lm_loss = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='none')
    sen_lengths = [len(x) for x in tokenized_ids]
    n_batch = len(tokenized_ids)
    max_sen_length = max(sen_lengths)

    input_ids = np.zeros(shape=(n_batch, max_sen_length), dtype=np.int64)
    # Explicit int64: CrossEntropyLoss needs Long targets, and np.full's default
    # integer dtype is platform dependent (int32 on Windows).
    lm_labels = np.full(shape=(n_batch, max_sen_length), fill_value=-1, dtype=np.int64)
    for row, tokens in enumerate(tokenized_ids):
        input_ids[row, :len(tokens)] = tokens
        lm_labels[row, :len(tokens) - 1] = tokens[1:]
    input_ids = torch.tensor(input_ids)
    lm_labels = torch.tensor(lm_labels)

    ppl, token_losses = [], []
    # range(0, n, bs) never yields an empty trailing batch.
    for start in trange(0, n_batch, bs):
        stop = start + bs
        batch_ids = input_ids[start:stop].to(dvc)
        batch_labels = lm_labels[start:stop].to(dvc)
        with torch.no_grad():
            lm_pred = lm_model(batch_ids)
        loss_val = lm_loss(lm_pred.view(-1, lm_pred.size(-1)), batch_labels.view(-1))
        loss_val = loss_val.view(len(batch_ids), -1)
        # Lengths live on the same device as the loss (not the global `device`).
        lengths = torch.tensor(sen_lengths[start:stop], dtype=torch.float32, device=loss_val.device)
        ppl.extend(torch.exp(loss_val.sum(dim=-1) / lengths).tolist())
        token_losses.extend(loss_val.tolist())

    ppl_lt = [(token_losses[k][:sen_lengths[k] - 1], tokenized_ids[k][1:]) for k in range(n_batch)]
    return ppl, ppl_lt


In [None]:
from collections import Counter
import math
def bleu_stats(hypothesis, reference):
    """
    Collect the sufficient statistics for BLEU of one hypothesis/reference pair.

    Works on any sliceable sequence. Token lists give word n-grams; plain
    strings give character n-grams, because slicing a str yields characters.

    Returns a list of 10 numbers:
    [len(hyp), len(ref), match_1, count_1, ..., match_4, count_4]. Here match_n
    is the clipped n-gram overlap and count_n is the number of hypothesis n-grams.
    """
    def ngram_counts(seq, n):
        return Counter(tuple(seq[k:k + n]) for k in range(len(seq) + 1 - n))

    stats = [len(hypothesis), len(reference)]
    for n in range(1, 5):
        overlap = ngram_counts(hypothesis, n) & ngram_counts(reference, n)
        stats.extend([max(sum(overlap.values()), 0), max(len(hypothesis) + 1 - n, 0)])
    return stats


def bleu(stats):
    """BLEU (in [0, 1]) from accumulated statistics; 0 if any statistic is 0."""
    if any(value == 0 for value in stats):
        return 0
    c, r = stats[0], stats[1]
    log_precision = sum(
        math.log(float(hit) / total) for hit, total in zip(stats[2::2], stats[3::2])
    ) / 4.
    brevity = min(0, 1 - float(r) / c)
    return math.exp(brevity + log_precision)


def get_bleu(hypotheses, reference):
    """Corpus BLEU (0-100) over aligned hypothesis/reference pairs, summing their statistics."""
    totals = np.zeros(10)
    for hyp, ref in zip(hypotheses, reference):
        totals += np.array(bleu_stats(hyp, ref))
    return 100 * bleu(totals)

In [None]:
# Optimiser hyper-parameters for the RL fine-tuning stage.
output_dir = "Path to save the output model"
lr = 6.25e-6  # 10x below the original fine-tuning rate of 6.25e-5
warmup = 0.002
max_grad_norm = 1
weight_decay = 0.01

In [None]:
max_input_len = 90
num_train_epochs = 5
batch_size = 16
beam_size = 5
bs = batch_size // beam_size
# Index of the last batch: the training loop runs steps + 1 batches, so this
# is ceil(n_rollouts / batch_size) - 1, written with floor division.
steps = (len(input_data) * beam_size - 1) // batch_size
print(steps, len(input_data), beam_size, batch_size)

In [None]:
# One label per rollout: each prompt's label repeated beam_size times, in rollout order.
rollout_input_labels = [label for label in input_data_label for _ in range(beam_size)]
# Label 0 -> "<NEG>" id, label 1 -> "<POS>" id; prepended to rollouts for LM scoring.
sent_labls = tokenizer.convert_tokens_to_ids(["<NEG>", "<POS>"])
sents = [sent_labls[label] for label in rollout_input_labels]

In [None]:
param_optimizer = list(model.named_parameters())
# Biases and LayerNorm parameters are excluded from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    # Use the configured weight_decay rather than a duplicated literal.
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
num_train_optimization_steps = len(input_data) * num_train_epochs * beam_size // batch_size
optimizer = OpenAIAdam(optimizer_grouped_parameters,
                       lr=lr,
                       warmup=warmup,
                       max_grad_norm=max_grad_norm,
                       weight_decay=weight_decay,
                       t_total=num_train_optimization_steps)

### RL Training

In [None]:
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
# Id of "<START>" in the generator vocabulary. Every training sequence is
# prompt + generated tokens, and the prompt ends with "<START>".
start_id = tokenizer.special_tokens["<START>"]
os.makedirs(output_dir, exist_ok=True)

for epoch in trange(int(num_train_epochs), desc="Epoch"):
    tr_loss = 0
    nb_tr_steps = 0

    # ---- 1. Roll out beam_size sampled trajectories per prompt ----
    # Sample with dropout disabled; the update phase below leaves the model in train() mode.
    model.eval()
    predictions = [None] * len(input_data)
    prediction_indexes = [None] * len(input_data)
    training_indexes = [None] * len(input_data)
    print("Rolling out Trejectories...\n")
    with open("policy_gradient_roll_outs_after_{}_epochs_topp_sampling_dev_cls_lm.txt".format(epoch), 'w',
              encoding="utf-8") as fp:
        for i, x in enumerate(tqdm(input_data)):
            # samples=beam_size keeps the rollout count aligned with rollout_input_labels.
            predictions[i], prediction_indexes[i], training_indexes[i] = top_p_decoding(x, samples=beam_size)
            for z in predictions[i]:
                fp.write(z + "\n")
    predictions = [y for x in predictions if x is not None for y in x]  # generated sentences
    prediction_indexes = [y for x in prediction_indexes if x is not None for y in x]  # generated ids
    training_indexes = [y for x in training_indexes if x is not None for y in x]  # prompt + generated ids
    # (position of <START>, last position): the logits over this span produced the generated tokens.
    pointer_indexes = [(x.index(start_id), len(x) - 1) for x in training_indexes]
    # Sentiment tag id + generated ids, scored by the LM critic.
    predictions_for_ppl = [[tag] + gen for tag, gen in zip(sents, prediction_indexes)]

    # ---- 2. Classifier feedback ----
    print("Classifier Predictions...")
    classifier_preds, classifier_values = run_examples(predictions, dvc=device, bs=64)
    classifier_preds_comparison = [x == y for x, y in zip(classifier_preds, rollout_input_labels)]
    # A set: membership is tested once per rollout below.
    false_sens_indexes = {i2 for i2, ok in enumerate(classifier_preds_comparison) if not ok}

    # ---- 3. Attribute words = words removed by the attention-based deletion ----
    aw, ids, tkns = run_attn_examples(predictions, dvc=device, bs=32)
    data1 = prepare_data(aw, ids, tkns)

    attn_words = [None] * len(data1)
    attn_words_index = [None] * len(data1)
    for i5, sen in enumerate(data1):
        tmp_lt = []
        tmp_lt1 = []
        for word in predictions[i5].split():
            if word not in sen:  # substring test against the content-only sentence
                tmp_lt.append(word)
                pieces = tokenizer.tokenize(word)
                if pieces:  # guard: an empty tokenization used to raise IndexError
                    tmp_lt1.append(tokenizer.convert_tokens_to_ids(pieces[0]))
        attn_words[i5] = (predictions[i5], tmp_lt)
        attn_words_index[i5] = tmp_lt1
    with open("Attention_words_{}_r0_tmp.txt".format(epoch + 1), 'w', encoding="utf-8") as fp1:
        for sen1, atw in attn_words:
            fp1.write(sen1 + "\tAttributes --> " + " ".join(atw) + "\n")

    # ---- 4. Per-token returns: [return, token_id] for every generated token ----
    prediction_index_with_return = []
    for i3, generated in enumerate(prediction_indexes):
        # Attribute tokens earn the classifier confidence when the classifier
        # agrees with the target label and lose it when it does not; others start at 0.
        if i3 in false_sens_indexes:
            attr_reward = -classifier_values[i3]
        else:
            attr_reward = classifier_values[i3]
        prediction_index_with_return.append(
            [[attr_reward if tok in attn_words_index[i3] else 0.0, tok] for tok in generated])

    # ---- 5. Fluency feedback from the LM critic ----
    _, ppl_lt = calculate_ppl(predictions_for_ppl, dvc=device, bs=32)
    for i3, token_returns in enumerate(prediction_index_with_return):
        for j1, entry in enumerate(token_returns):
            entry[0] += (np.exp(-ppl_lt[i3][0][j1]) - 0.1) * 1.2

    # ---- 6. Content-preservation feedback (BLEU against the source sentence) ----
    # NOTE(review): predictions and r_for_bleu are raw strings, so get_bleu
    # compares character n-grams rather than words — confirm this is intended.
    bleu_src = [0] * len(r_for_bleu)
    for c1, (x1, x2) in enumerate(zip(predictions, r_for_bleu)):
        bleu_src[c1] = get_bleu([x1], [x2])
    for i3, token_returns in enumerate(prediction_index_with_return):
        for entry in token_returns:
            entry[0] += (bleu_src[i3] * 0.01) * 0.8

    # ---- 7. Policy-gradient updates ----
    training_indexes_tensor = np.zeros(shape=(len(training_indexes), max_input_len), dtype=np.int64)
    for i, tokens in enumerate(training_indexes):
        clipped = tokens[:max_input_len]
        training_indexes_tensor[i, :len(clipped)] = clipped
    training_indexes_tensor = torch.tensor(training_indexes_tensor)

    progress = trange(steps + 1)
    for i in progress:
        lo, hi = i * batch_size, (i + 1) * batch_size
        train_tensor = training_indexes_tensor[lo:hi]
        train_pointer_indexes = pointer_indexes[lo:hi]
        train_prediction_indexes_with_return = prediction_index_with_return[lo:hi]

        if len(train_tensor) > 0:
            model.train()
            preds = model(train_tensor.to(device))

            # Loss term per generated token: -return * log pi(token | prefix).
            lt = []
            for i1 in range(preds.shape[0]):
                start, end = train_pointer_indexes[i1]
                # Logits at position j predict token j + 1. Positions at or beyond
                # max_input_len were clipped from the input tensor, so stop there
                # (wrapping the index would score unrelated logits).
                for count, j in enumerate(range(start, min(end, max_input_len))):
                    m = Categorical(logits=preds[i1][j])
                    reward, token = train_prediction_indexes_with_return[i1][count]
                    temp = torch.tensor(token).to(device)
                    lt.append(-1 * reward * m.log_prob(temp).unsqueeze(0))

            if lt:  # torch.cat([]) would raise on a batch without generated tokens
                optimizer.zero_grad()
                loss = torch.cat(lt).mean()
                loss.backward()
                optimizer.step()
                tr_loss += loss.item()
                nb_tr_steps += 1
                progress.set_postfix(loss=loss.item())

        if i % 500 == 0:
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(output_dir, "pytorch_model_zero_grad_{}_{}_v1.bin".format(epoch, i))
            torch.save(model_to_save.state_dict(), output_model_file)

    if nb_tr_steps:
        logger.info("Epoch %d: mean RL loss %.4f over %d steps", epoch, tr_loss / nb_tr_steps, nb_tr_steps)

    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(output_dir, "pytorch_model_zero_grad_{}_final_v1.bin".format(epoch))
    torch.save(model_to_save.state_dict(), output_model_file)