In [1]:
from __future__ import unicode_literals

import pandas as pd
from barbar import Bar
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel, AdamW, GPT2LMHeadModel
from transformers import BertTokenizer, BertForSequenceClassification
from itertools import chain
import json
import torch.nn as nn
import os
import re
import unicodedata
import torch.optim as optim
import torch.nn.functional as F
import random
from torch.utils.data import Dataset as Dataset
from torch.utils.data import DataLoader, random_split
from transformers import get_linear_schedule_with_warmup

from transformers import cached_path
import tarfile
import tempfile
import logging

# Run on the GPU when PyTorch reports that CUDA is available; otherwise use the CPU.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)



cuda


In [2]:
# Marker strings used to delimit and tag the segments of a dialogue sequence.
bos, eos, speaker1, speaker2, CLS = "<BOS>", "<EOS>", "<speaker1>", "<speaker2>", "<CLS>"
SPECIAL_TOKENS = ['<BOS>', '<EOS>', '<speaker1>', '<speaker2>', '<PAD>']


def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False, with_eos=True):
    """Assemble one model input from a persona, the dialogue history and a (partial) reply.

    Args:
        persona: iterable of iterables; its items are flattened one level and placed
            right after the ``<BOS>`` marker.
        history: list of utterances, each a list of tokens.
        reply: list of tokens of the reply being built.
        tokenizer: accepted for interface compatibility; not used by this function.
        lm_labels: when true, the reply tokens (minus the leading speaker tag) are
            exposed as labels; every other position is ``-100``.
        with_eos: when true, ``<EOS>`` is appended to the reply.

    Returns:
        dict with keys ``input_ids`` (flat token list), ``token_type_ids`` (one speaker
        tag per token, alternating per segment), ``mc_token_ids`` (index of the last
        token) and ``lm_labels``.
    """
    closing = [eos] if with_eos else []
    segments = [[bos] + list(chain(*persona))] + history + [reply + closing]

    # Every segment after the persona gets a speaker tag whose parity is counted
    # from the end of the sequence.
    total = len(segments)
    tagged = [segments[0]]
    for offset, segment in enumerate(segments[1:]):
        tag = speaker2 if (total - offset) % 2 else speaker1
        tagged.append([tag] + segment)

    flat_tokens = list(chain.from_iterable(tagged))

    # One type id per token: even-indexed segments are speaker1, odd-indexed speaker2.
    type_ids = []
    for position, segment in enumerate(tagged):
        type_ids.extend([speaker2 if position % 2 else speaker1] * len(segment))

    if lm_labels:
        context_length = sum(len(segment) for segment in tagged[:-1])
        # The extra -100 masks the speaker tag that opens the reply segment.
        labels = [-100] * context_length + [-100] + tagged[-1][1:]
    else:
        labels = [-100] * len(flat_tokens)

    return {
        "input_ids": flat_tokens,
        "token_type_ids": type_ids,
        "mc_token_ids": len(flat_tokens) - 1,
        "lm_labels": labels,
    }

In [3]:
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering.

        The tensor passed as `logits` is modified in place and the same tensor is returned;
        pass a copy if the unfiltered values are still needed.

        Args:
            logits: 1-D tensor of logits (vocabulary size).
            top_k: <=0: no filtering, >0: keep only the top k tokens with highest logit.
                Float values are truncated to an integer.
            top_p: <=0.0: no filtering, >0.0: keep the smallest set of highest-probability
                tokens whose cumulative probability mass exceeds top_p (the first token
                crossing the threshold is kept).
            threshold: logits strictly below this value are filtered out.
            filter_value: value written into every filtered position.

        Returns:
            The input tensor, with filtered positions set to `filter_value`.
    """
    assert logits.dim() == 1  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
    # torch.topk needs an integer k, while the default (0.) and callers may supply a float.
    top_k = int(min(top_k, logits.size(-1)))
    if top_k > 0:
        # Remove all tokens with a logit lower than the k-th largest logit
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        indices_to_remove = logits < kth_largest
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to the filter value
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    # Finally drop anything below the absolute threshold
    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits


In [8]:
# Interactive chat loop: each user prompt is classified by a BERT emotion classifier,
# the predicted class conditions a GPT-2 model that generates the reply token by token,
# and the reply's own emotion is reported at the start of the next turn.
history = []
reply = ""
first = 1
sent = ""

personalities = ["I have no emotion.", "I am angry.", "I am happy.", "I am sad.", "I am surprised."]

temperature = 0.7
decoding_strategy = "nucleus"  # one of SUPPORTED_DECODING below
SUPPORTED_DECODING = ("greedy", "top-k", "nucleus")

# Fail before any model is loaded instead of hitting an undefined `prev` mid-generation.
if decoding_strategy not in SUPPORTED_DECODING:
    raise ValueError(
        "Unsupported decoding_strategy %r; choose one of %s" % (decoding_strategy, ", ".join(SUPPORTED_DECODING))
    )

# Load both models once; they do not change between turns.
bert_tokenizer = BertTokenizer.from_pretrained('./Sentiment Analysis/')
bert_model = BertForSequenceClassification.from_pretrained(
    './Sentiment Analysis/', num_labels=5, output_attentions=False, output_hidden_states=False
)
bert_model.to(device)
bert_model.eval()

gpt2_model = GPT2LMHeadModel.from_pretrained('./gpt2_model/').to(device)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained('./gpt2_model/')


def _classify_emotion(text):
    """Return the index of the highest-scoring class the BERT classifier gives `text`."""
    encoded_dict = bert_tokenizer.encode_plus(
        text,                          # Sentence to encode.
        add_special_tokens=True,       # Add '[CLS]' and '[SEP]'
        max_length=78,                 # Pad & truncate all sentences.
        pad_to_max_length=True,
        return_attention_mask=True,    # Construct attn. masks.
        return_tensors='pt',           # Return pytorch tensors.
    )
    input_ids = encoded_dict['input_ids'].to(device)
    attention_mask = encoded_dict['attention_mask'].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = bert_model(input_ids, token_type_ids=None, attention_mask=attention_mask)
    _, best_index = torch.max(outputs[0][0], 0)
    return best_index.item()


def _pick_next_token(last_logits):
    """Choose the next token id from the last-position logits per `decoding_strategy`."""
    if decoding_strategy == "greedy":
        _, best_index = torch.max(last_logits, 0)
        return best_index

    # Both sampling strategies scale by temperature, filter, then sample once.
    if decoding_strategy == "top-k":
        filtered = top_filtering(last_logits / temperature, top_k=20, top_p=0.0)
    else:  # "nucleus"
        filtered = top_filtering(last_logits / temperature, top_k=0, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)
    return torch.multinomial(probs, 1)


while True:

    if first != 1:
        # Report the emotion of the reply produced on the previous turn.
        personality = _classify_emotion(sent)
        print('Reply personality', personality, personalities[personality])
    first = 0

    # Read the prompt; keep asking while it is empty.
    raw_text = input(">>> ")
    while not raw_text:
        print('Prompt should not be empty!')
        raw_text = input(">>> ")
    # Checked after the retry loop so that 'q' quits even when it follows an empty prompt.
    if raw_text == 'q':
        break

    # Emotion of the user's prompt.
    personality = _classify_emotion(raw_text)
    print('Input personality', personality, personalities[personality])
    # Classes above 1 are shifted by 2 before being used as the persona string.
    # NOTE(review): presumably this maps classifier labels onto the ids the GPT-2
    # model was trained with - confirm against the training data.
    if personality > 1:
        personality += 2

    seq0 = str(personality)
    seq2 = re.sub(r"[^\w]", " ", reply).split()

    history.append(re.sub(r"[^\w]", " ", raw_text).split())

    # Generate the reply one token at a time.
    # NOTE(review): this loop only ends once '<EOS>' is decoded with a non-empty reply;
    # there is no cap on reply length.
    while True:
        instance = build_input_from_segments(seq0, history, seq2, gpt2_tokenizer, lm_labels=False, with_eos=False)

        input_ids = torch.tensor(
            [gpt2_tokenizer.convert_tokens_to_ids(instance["input_ids"])], dtype=torch.long
        ).to(device)
        token_type_ids = torch.tensor(
            [gpt2_tokenizer.convert_tokens_to_ids(instance["token_type_ids"])], dtype=torch.long
        ).to(device)

        with torch.no_grad():
            outputs = gpt2_model(input_ids=input_ids, token_type_ids=token_type_ids)
            prev = _pick_next_token(outputs[0][0, -1, :])

        predicted_text = gpt2_tokenizer.decode([prev.item()])
        # Special tokens steer generation but are never part of the visible reply.
        if predicted_text not in SPECIAL_TOKENS:
            seq2.append(predicted_text)

        if predicted_text == '<EOS>' and len(seq2) > 0:
            history.append(seq2)
            break

    # Space-terminated join of the generated tokens (each token is followed by one space).
    sent = "".join(str(token) + ' ' for token in seq2)
    print(sent)

print(history)

>>> i am very hungry and i do not feel well
Input personality 3 I am sad.
i am so sorry i feel well i am very sick 
Reply personality 3 I am sad.
>>> i am sorry that you feel sick. Can i help you?
Input personality 3 I am sad.
i know i know what you need to do i ll be back soon 
Reply personality 3 I am sad.
>>> ok, i will be waiting for you here
Input personality 2 I am happy.
i will be right 
Reply personality 3 I am sad.
>>> can not wait
Input personality 2 I am happy.
how much 
Reply personality 0 I have no emotion.
>>> a lot
Input personality 3 I am sad.
i know 
Reply personality 3 I am sad.
>>> good see you soon
Input personality 2 I am happy.
i will be back 
Reply personality 2 I am happy.
>>> bye 
Input personality 2 I am happy.
i will be right 
Reply personality 3 I am sad.
>>> q
[['i', 'am', 'very', 'hungry', 'and', 'i', 'do', 'not', 'feel', 'well'], ['i', 'am', 'so', 'sorry', 'i', 'feel', 'well', 'i', 'am', 'very', 'sick'], ['i', 'am', 'sorry', 'that', 'you', 'feel', 'sick',

In [None]:
# NOTE(review): no visible cell defines `q`, so executing this cell would raise NameError.
# Presumably the chat loop's quit command was typed into a fresh cell by mistake - confirm and delete.
q