In [1]:
import ast
import os
import urllib
import urllib.request
from math import ceil, floor

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
### Config settings
# Values a reader may want to tune; later cells refer to them by name.

## Data config
# If True, download the dataset again even when a local copy already exists.
download_again = False


## DataLoader config
# Number of samples per batch handed to the DataLoader.
batch_size = 64


## Model config
# File names of the stored encoder / decoder weights (read with torch.load in a later cell).
encoder_name = 'kotor-rnn-encoder-encinput.ptm'
decoder_name = 'kotor-rnn-decoder-encinput.ptm'

In [3]:
# Device selection.
# The notebook is pinned to the CPU; set force_cpu = False to use the first GPU
# whenever CUDA is available.
force_cpu = True
use_cuda = torch.cuda.is_available() and not force_cpu
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)
# Let cuDNN benchmark and pick its kernels (a cuDNN setting, so relevant to GPU runs).
torch.backends.cudnn.benchmark = True

cpu


In [4]:
# Fetch the dialogue corpus unless a local copy exists (or a re-download is requested).
url = 'https://github.com/hmi-utwente/video-game-text-corpora/raw/master/Star%20Wars:%20Knights%20of%20the%20Old%20Republic/data/dataset_20200716.csv'
filename = 'dataset_20200716.csv'
if download_again or not os.path.exists(filename):
    # Download under a temporary name and rename afterwards, so that an interrupted
    # download is never mistaken for a complete dataset on the next run.
    partial_filename = filename + '.part'
    urllib.request.urlretrieve(url,partial_filename)
    os.replace(partial_filename,filename)

In [5]:
# Load the corpus into a DataFrame indexed by the utterance id.
index_col = 'id'
# None loads every column; the commented alternative restricts the load to three columns.
usecols = None
#usecols = ['id','text','previous']
# 'previous' and 'next' are stored as Python list literals (e.g. "[29207, 29211]", "[None]");
# ast.literal_eval turns them back into lists while parsing.
converters = {'previous':ast.literal_eval,
              'next':ast.literal_eval,
             }
data = pd.read_csv(filename,
                   index_col=index_col,
                   usecols=usecols,
                   converters=converters,
                  )
# Show the frame as the cell result.
data

Unnamed: 0_level_0,speaker,listener,text,animation,comment,next,previous,source_dlg,audiofile
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
0,Anchorhead Tradesman,,Take care of yourself. The price of kolto tank...,[],,[],[None],tat17_news_01,NM17AANEWS11000_.mp3
1,Anchorhead Tradesman,,The Selkath put a bunch of export restrictions...,[],,[],[None],tat17_news_01,NM17AANEWS11001_.mp3
2,Anchorhead Tradesman,,I hear that Manaan is no longer shipping kolto...,[],,[],[None],tat17_news_01,NM17AANEWS11002_.mp3
3,Anchorhead Tradesman,,"If you have kolto tanks, use them sparingly. I...",[],,[],[None],tat17_news_01,NM17AANEWS11003_.mp3
4,Anchorhead Tradesman,,I'm sure I saw some holo-footage of you on the...,[],,[],[None],tat17_news_01,NM17AANEWS11004_.mp3
...,...,...,...,...,...,...,...,...,...
29208,Zaalbar,,It is a description of the ritual you have alr...,[],,[],"[29207, 29211]",kas25_ritualmark,
29209,Zaalbar,,I never went to the Shadowlands to prove mysel...,[],,[29210],"[29207, 29211]",kas25_ritualmark,
29210,Zaalbar,,You will have to follow whatever your instinct...,[],,[],[29209],kas25_ritualmark,
29211,Player,,Whatever. Just tell me what you know about it.,[],,"[29208, 29209]",[29206],kas25_ritualmark,


In [6]:
# Column dtypes of the loaded frame.
data.dtypes

speaker       object
listener      object
text          object
animation     object
comment       object
next          object
previous      object
source_dlg    object
audiofile     object
dtype: object

In [7]:
# Non-null counts per column and memory usage.
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 29213 entries, 0 to 29212
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   speaker     29213 non-null  object
 1   listener    2231 non-null   object
 2   text        29213 non-null  object
 3   animation   29213 non-null  object
 4   comment     2658 non-null   object
 5   next        29213 non-null  object
 6   previous    29213 non-null  object
 7   source_dlg  29213 non-null  object
 8   audiofile   12325 non-null  object
dtypes: object(9)
memory usage: 2.2+ MB


In [8]:
# Summary statistics of the per-entry length (len() of each cell) for every column.
for col in data.columns:
    try:
        desc = data[col].apply(len).describe()
    except TypeError:
        # Some entry has no len() (presumably the NaN of a missing value); skip the
        # column. Any other error is a real problem and is allowed to surface.
        continue
    print('\t',col)
    print(desc)
    print()

	 speaker
count    29213.000000
mean         8.973094
std          4.522340
min          3.000000
25%          6.000000
50%          6.000000
75%         12.000000
max         24.000000
Name: speaker, dtype: float64

	 text
count    29213.000000
mean        82.780748
std         49.850241
min          1.000000
25%         39.000000
50%         74.000000
75%        122.000000
max        344.000000
Name: text, dtype: float64

	 animation
count    29213.000000
mean         4.144867
std          8.150476
min          2.000000
25%          2.000000
50%          2.000000
75%          2.000000
max        243.000000
Name: animation, dtype: float64

	 next
count    29213.000000
mean         1.498100
std          1.451002
min          0.000000
25%          1.000000
50%          1.000000
75%          2.000000
max         22.000000
Name: next, dtype: float64

	 previous
count    29213.000000
mean         1.611235
std          1.664769
min          1.000000
25%          1.000000
50%          1.0000

In [9]:
def add_start_stop_codon(data,column,start='\r',stop='\n',force=False):
    """Wrap every entry of ``data[column]`` in a start and a stop marker, in place.

    Parameters:
        data: DataFrame holding the text column; it is modified in place.
        column: name of the column whose strings are wrapped.
        start: marker prepended to every string.
        stop: marker appended to every string.
        force: when True, wrap even if a marker already occurs in the column.

    Returns:
        None.

    Raises:
        ValueError: if any entry already contains ``start`` or ``stop`` and
            ``force`` is False.
    """
    def contains_codon(text):
        return start in text or stop in text

    def wrap(text):
        return start+text+stop

    already_marked = data[column].apply(contains_codon).any()
    if already_marked and not force:
        raise ValueError('data already contains start or stop codon at column: {0}'.format(column))
    data[column] = data[column].apply(wrap)

In [10]:
# Mark the beginning ('\r') and the end ('\n') of every utterance.
# Note: `data` is modified in place, so running this cell a second time raises
# ValueError (the markers are then already present) unless the data is reloaded.
add_start_stop_codon(data,'text')
data['text']

id
0        \rTake care of yourself. The price of kolto ta...
1        \rThe Selkath put a bunch of export restrictio...
2        \rI hear that Manaan is no longer shipping kol...
3        \rIf you have kolto tanks, use them sparingly....
4        \rI'm sure I saw some holo-footage of you on t...
                               ...                        
29208    \rIt is a description of the ritual you have a...
29209    \rI never went to the Shadowlands to prove mys...
29210    \rYou will have to follow whatever your instin...
29211    \rWhatever. Just tell me what you know about i...
29212    \r[Obviously this was once a place of great ri...
Name: text, Length: 29213, dtype: object

In [11]:
class CustomDataset:
    """Character-level dataset pairing every utterance with one preceding utterance.

    ``data`` must provide a 'text' column of strings and a 'previous' column whose
    entries are lists of row labels (or ``None`` for "no predecessor"). Items are
    looked up by row label (``data.loc``), not by position.
    """
    def __init__(self,data):
        self.data = data
        # Cache of encoded utterances, keyed by row label (None for the empty utterance).
        self.tensors = {}

        # Vocabulary: every distinct character of the text column, in sorted order.
        self.vocab = sorted(set(''.join(self.data['text'])))
        self.ch2i = { ch:i for i,ch in enumerate(self.vocab) }
        self.i2ch = { i:ch for ch,i in self.ch2i.items() }

    def __len__(self):
        """Number of rows in the underlying frame."""
        return(len(self.data))
    def str2vec(self,text):
        """Encode a string as a 1-D long tensor of vocabulary indices.

        Raises KeyError for a character that is not part of the vocabulary.
        """
        out = torch.tensor([ self.ch2i[s] for s in text],dtype=torch.long)
        return(out)
    def vec2str(self,vec):
        """Decode an iterable of index tensors (e.g. a 1-D tensor) back into a string."""
        out = ''.join([ self.i2ch[i.item()] for i in vec ])
        return(out)
    def get_dialogue(self,idx):
        """Return the encoded text of row ``idx``, encoding and caching it on first use.

        ``idx=None`` stands for a missing predecessor and yields the encoding of the
        empty utterance '\\r\\n' (start marker directly followed by stop marker).
        """
        if idx not in self.tensors:
            if idx is None:
                dialogue = '\r\n'
            else:
                dialogue = self.data.loc[idx,'text']
            self.tensors[idx] = self.str2vec(dialogue)
        return(self.tensors[idx])
    def __getitem__(self,idx):
        """Return ``(prevs, ins, outs)`` for row ``idx``.

        ``ins`` is the encoded utterance without its last character and ``outs`` the
        same utterance without its first character (next-character targets).
        ``prevs`` is the encoding of one randomly chosen entry of the row's
        'previous' list, so repeated calls may return different predecessors.
        """
        response = self.get_dialogue(idx)
        ins = response[:-1]
        outs = response[1:]
        # Several predecessors may lead to this utterance; pick one at random.
        prevs = np.random.choice(self.data.loc[idx,'previous'])
        prevs = self.get_dialogue(prevs)
        return(prevs,ins,outs)

In [12]:
# Build the dataset and inspect it: size, vocabulary, and the first and last samples
# (each sample is a (prevs, ins, outs) tuple of index tensors).
dataset = CustomDataset(data)
print(len(dataset))
print()
print(dataset.vocab)
print()
print(dataset[0])
print()
print(dataset[len(dataset)-1])

29213

['\n', '\r', ' ', '!', '"', '#', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '^', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

(tensor([1, 0]), tensor([ 1, 53, 64, 74, 68,  2, 66, 64, 81, 68,  2, 78, 69,  2, 88, 78, 84, 81,
        82, 68, 75, 69, 15,  2, 53, 71, 68,  2, 79, 81, 72, 66, 68,  2, 78, 69,
         2, 74, 78, 75, 83, 78,  2, 83, 64, 77, 74, 82,  2, 71, 64, 82,  2, 73,
        84, 76, 79, 68, 67,  2, 83, 71, 81, 78, 84, 70, 71,  2, 83, 71, 68,  2,
        81, 78, 78, 69, 15]), tensor([53, 64, 74, 68,  2, 66, 64, 81, 68,  2, 78, 69,  2, 88, 78, 84, 81, 82,
        68, 75, 69, 15,  2, 53, 71, 68,  2, 79, 81, 72, 66, 68,  2, 78, 69,  2,
        74, 

In [13]:
def collate_fn(data_points):
    """Pad a list of ``(prevs, ins, outs)`` samples into batch tensors.

    Parameters:
        data_points: non-empty list of samples; each sample is a tuple of three 1-D
            tensors where ``ins`` and ``outs`` have the same length.

    Returns:
        ``((prevs, ins, outs), L_prevs, L_currents)``: three zero-padded tensors of
        shape (batch, longest sequence) followed by the unpadded lengths of the
        ``prevs`` and of the ``ins``/``outs`` sequences as Python lists.
    """
    L_prevs = [len(p) for p,_,_ in data_points]
    L_currents = [len(i) for _,i,_ in data_points]
    # max() over the list itself: the unpacked form max(*lengths) fails for a batch
    # that holds a single sample.
    N_prevs = max(L_prevs)
    N_currents = max(L_currents)
    B = len(data_points)
    first_prev,first_in,first_out = data_points[0]
    prevs = torch.zeros((B,N_prevs),dtype=first_prev.dtype)
    ins = torch.zeros((B,N_currents),dtype=first_in.dtype)
    outs = torch.zeros((B,N_currents),dtype=first_out.dtype)
    for k,(p,i,o) in enumerate(data_points):
        prevs[k,:L_prevs[k]] = p
        ins[k,:L_currents[k]] = i
        outs[k,:L_currents[k]] = o
    return((prevs,ins,outs),L_prevs,L_currents)

In [14]:
# Shuffled mini-batches of padded samples; the last incomplete batch is dropped.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         drop_last=True,
                                         # Pinned host memory only serves host-to-GPU
                                         # copies, so request it for CUDA runs only.
                                         pin_memory=use_cuda,
                                         collate_fn=collate_fn,
                                        )

In [15]:
class TrainablePositionalEncoding(nn.Module):
    """Learned sinusoidal encoding of a position value.

    Two linear maps produce an exponent and an angle per output channel; the result
    is ``exp(exponent/80) * sin(angle)`` concatenated with ``exp(exponent/80) * cos(angle)``
    along the last axis, giving ``encoding_dim`` channels.

    Parameters:
        encoding_dim: size of the produced encoding; must be even.
        num_of_features: size of the last input axis. When None the input is treated
            as scalar positions (a trailing axis of size one is appended) and the
            angle map has no bias.

    Raises:
        ValueError: if ``encoding_dim`` is odd.
    """
    def __init__(self,encoding_dim,num_of_features=None):
        super().__init__()
        self.num_of_features = num_of_features
        self.encoding_dim = encoding_dim
        if encoding_dim%2 != 0:
            raise ValueError('encoding_dim should be a multiple of two!')
        half_dim = encoding_dim//2
        scalar_positions = num_of_features is None
        in_features = 1 if scalar_positions else num_of_features
        self.exp_linear = nn.Linear(in_features,half_dim,bias=True)
        # The angle map is bias-free for scalar positions and biased otherwise.
        self.angle_linear = nn.Linear(in_features,half_dim,bias=not scalar_positions)
    def forward(self,x):
        """Encode ``x``; the output has one more axis than a scalar-position input."""
        if self.num_of_features is None:
            x = x[...,None]
        envelope = torch.exp(self.exp_linear(x)/80)
        phase = self.angle_linear(x)
        sin_part = envelope*torch.sin(phase)
        cos_part = envelope*torch.cos(phase)
        return(torch.cat((sin_part,cos_part),dim=-1))

In [16]:
def add_positional_info(x,ch2i = dataset.ch2i):
    """Attach a word-position feature to a batch of encoded character sequences.

    For every position the feature counts the spaces seen so far since the most
    recent sentence-ending mark ('.', '!' or '?'); the mark itself restarts the
    count at zero.

    Parameters:
        x: 2-D integer tensor (batch, sequence) of character indices.
        ch2i: mapping from character to index; must contain ' ', '.', '!' and '?'.

    Returns:
        Tensor of shape (batch, sequence, 2): ``[..., 0]`` is ``x`` and ``[..., 1]``
        is the word-position feature.
    """
    # The lookups are rebuilt on every call (they are tiny), so the `ch2i` argument
    # is always honoured and the index tensor lives on the same device as `x`.
    space_idx = ch2i[' ']
    punct_idx = torch.tensor([ ch2i[s] for s in ['.','!','?'] ],device=x.device)
    # Running number of sentence-ending marks: identifies the sentence of each position.
    punct_mask = torch.isin(x,punct_idx).cumsum(axis=1)
    space_mask = x==space_idx
    # An empty input has no maximum; a single pass over "sentence 0" is enough then.
    punct_mask_max = punct_mask.max().item() if punct_mask.numel() > 0 else 0
    out = 0
    for punct_mark in range(punct_mask_max+1):
        # Count the spaces inside this sentence only and keep the counts for its positions.
        punct_mark_mask = (punct_mask==punct_mark)
        out += (space_mask*punct_mark_mask).cumsum(axis=1)*punct_mark_mask
    out = torch.stack((x,out),dim=2)
    return(out)

In [17]:
class Encoder(nn.Module):
    """Encode a padded batch of character sequences into one vector per sequence.

    The model is a sum of ``n_layers`` sub-models: sub-model ``k`` is a GRU with ``k``
    stacked layers followed by a linear projection of the GRU output at the last
    valid time step of each sequence. Every sub-model has a trainable initial state.

    Parameters:
        vocab_size: number of distinct character indices.
        out_dim: size of the produced encoding.
        embedding_dim: size of the character embedding and of the positional encoding.
        rnn_units: hidden size of every GRU.
        n_layers: number of sub-models (the deepest GRU has this many layers).
    """
    def __init__(self, vocab_size, out_dim, embedding_dim, rnn_units, n_layers=2):
        super(Encoder,self).__init__()
        self.n_layers = n_layers
        self.rnn_units = rnn_units
        self.out_dim = out_dim
        self.embedding = nn.Embedding(vocab_size,
                                      embedding_dim,
                                     )
        self.pos_encoding = TrainablePositionalEncoding(embedding_dim)
        self.grus = nn.ModuleList()
        self.linears = nn.ModuleList()
        self.initial_states = nn.ParameterList()
        for submodel_layers in range(1,n_layers+1):
            submodel_gru = nn.GRU(input_size=embedding_dim,
                                          hidden_size=self.rnn_units,
                                          num_layers=submodel_layers,
                                          batch_first=True,
                                         )
            submodel_linear = nn.Linear(rnn_units,
                                                out_dim,
                                                bias=True,
                                               )
            self.grus.append(submodel_gru)
            self.linears.append(submodel_linear)
            self.initial_states.append(nn.Parameter(torch.randn((submodel_layers,self.rnn_units,))))
        # Not used by forward(); kept so that stored state dicts (which contain this
        # key) still load.
        self.bias = nn.Parameter(torch.randn((out_dim,)))
    def batch_initial_states(self,batch_size):
        """Return the trainable initial states, each repeated to (layers, batch_size, rnn_units)."""
        states = [ init_state.repeat((batch_size,1,1)).permute(1,0,2) for init_state in self.initial_states ]
        return(states)
    def forward(self, inputs, lengths, states=None,device=None):
        """Encode ``inputs``.

        Parameters:
            inputs: tensor (batch, sequence, 2) holding the character index in
                ``[..., 0]`` and the word-position feature in ``[..., 1]``.
            lengths: unpadded length of every sequence in the batch.
            states: optional list with one initial hidden state per sub-model;
                defaults to the trainable initial states.
            device: device for the index tensors that select the last valid time
                step; defaults to the device of the GRU output.

        Returns:
            Tensor (batch, out_dim).
        """
        batch_size = len(lengths)
        if states is None:
            states = self.batch_initial_states(batch_size)
        x = self.embedding(inputs[...,0])
        x = x + self.pos_encoding(inputs[...,1].float())
        packed = torch.nn.utils.rnn.pack_padded_sequence(x,lengths,batch_first=True,enforce_sorted=False)
        out = 0
        for submodel_gru,submodel_linear,state in zip(self.grus,self.linears,states):
            # Apply GRU
            sub_out,_ = submodel_gru(packed,state)

            # Pick the output at the last valid (unpadded) time step of every sequence
            sub_out,_ = torch.nn.utils.rnn.pad_packed_sequence(sub_out, batch_first=True)
            if device is None:
                device = sub_out.device
            batch_index = torch.arange(batch_size,device=device)
            last_step = torch.as_tensor(lengths,device=device)-1
            sub_out = sub_out[batch_index,last_step]

            # Apply linear transform and collect in output
            out = out + submodel_linear(sub_out)
        return(out)
    def noisify(self,scale):
        """Add Gaussian noise of standard deviation ``scale`` to all GRU and linear weights, in place."""
        with torch.no_grad():
            for p in self.grus.parameters():
                p.add_(torch.randn_like(p),alpha=scale)
            for p in self.linears.parameters():
                p.add_(torch.randn_like(p),alpha=scale)

In [18]:
class Decoder(nn.Module):
    """Predict next-character logits for a padded batch, conditioned on an encoding.

    The logits are the sum over ``n_layers`` sub-models (sub-model ``k`` is a GRU with
    ``k`` stacked layers followed by a bias-free linear projection to the vocabulary)
    plus one shared trainable bias vector.

    Parameters:
        vocab_size: number of distinct character indices (also the output size).
        embedding_dim: size of the character embedding and of the positional encoding.
        rnn_units: hidden size of every GRU.
        n_layers: number of sub-models (the deepest GRU has this many layers).
    """
    def __init__(self, vocab_size, embedding_dim, rnn_units, n_layers=2):
        super().__init__()
        self.n_layers = n_layers
        self.rnn_units = rnn_units
        self.embedding = nn.Embedding(vocab_size,embedding_dim)
        self.pos_encoding = TrainablePositionalEncoding(embedding_dim)
        self.grus = nn.ModuleList()
        self.linears = nn.ModuleList()
        self.initial_states = nn.ParameterList()
        for depth in range(1,n_layers+1):
            self.grus.append(nn.GRU(input_size=embedding_dim,
                                    hidden_size=rnn_units,
                                    num_layers=depth,
                                    batch_first=True,
                                   ))
            self.linears.append(nn.Linear(rnn_units,vocab_size,bias=False))
            self.initial_states.append(nn.Parameter(torch.randn((depth,rnn_units))))
        self.bias = nn.Parameter(torch.randn((vocab_size,)))
    def batch_initial_states(self,batch_size):
        """Return the trainable initial states, each repeated to (layers, batch_size, rnn_units)."""
        return([ init_state.unsqueeze(1).repeat(1,batch_size,1) for init_state in self.initial_states ])
    def forward(self, inputs, encoding_tensor, lengths, states=None):
        """Compute logits for every position of ``inputs``.

        Parameters:
            inputs: tensor (batch, sequence, 2) holding the character index in
                ``[..., 0]`` and the word-position feature in ``[..., 1]``.
            encoding_tensor: tensor (batch, embedding_dim) added to the embedding of
                every time step.
            lengths: unpadded length of every sequence in the batch.
            states: optional list with one hidden state per sub-model; defaults to
                the trainable initial states. A passed list is updated in place.

        Returns:
            ``(logits, states)``: logits of shape (batch, sequence, vocab_size) and
            the list of final hidden states, one per sub-model.
        """
        if states is None:
            states = self.batch_initial_states(len(lengths))
        token_ids = inputs[...,0]
        word_positions = inputs[...,1].float()
        x = self.embedding(token_ids)
        x += self.pos_encoding(word_positions)
        x += encoding_tensor[:,None,:]
        packed = torch.nn.utils.rnn.pack_padded_sequence(x,lengths,batch_first=True,enforce_sorted=False)
        out = 0
        for k,(gru,linear) in enumerate(zip(self.grus,self.linears)):
            # Run the recurrent stack and remember its final hidden state
            rnn_out, states[k] = gru(packed,states[k])

            # Project the padded outputs onto the vocabulary and accumulate
            rnn_out,_ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
            out += linear(rnn_out)

        out += self.bias[None,None,:]
        return(out,states)
    def noisify(self,scale):
        """Add Gaussian noise of standard deviation ``scale`` to all GRU and linear weights, in place."""
        with torch.no_grad():
            for module_list in (self.grus,self.linears):
                for p in module_list.parameters():
                    p.add_(torch.randn_like(p),alpha=scale)

In [19]:
# Hyper-parameters shared by the encoder and the decoder.
vocab_size = len(dataset.vocab)
embedding_dim = 256
rnn_units = 256
n_layers = 5

# The encoder output is added to the decoder's embeddings, hence out_dim = embedding_dim.
encoder = Encoder(
    vocab_size=vocab_size,
    out_dim=embedding_dim,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    n_layers=n_layers,
).to(device)
print(encoder)

decoder = Decoder(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    n_layers=n_layers,
).to(device)
print(decoder)

Encoder(
  (embedding): Embedding(90, 256)
  (pos_encoding): TrainablePositionalEncoding(
    (exp_linear): Linear(in_features=1, out_features=128, bias=True)
    (angle_linear): Linear(in_features=1, out_features=128, bias=False)
  )
  (grus): ModuleList(
    (0): GRU(256, 256, batch_first=True)
    (1): GRU(256, 256, num_layers=2, batch_first=True)
    (2): GRU(256, 256, num_layers=3, batch_first=True)
    (3): GRU(256, 256, num_layers=4, batch_first=True)
    (4): GRU(256, 256, num_layers=5, batch_first=True)
  )
  (linears): ModuleList(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
  (initial_states): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 1x256]
      (1): Parameter containing: [tor

In [20]:
# Total number of scalar parameters per model.
print('encoder parameters:',sum(p.numel() for p in encoder.parameters()))
print('decoder parameters:',sum(p.numel() for p in decoder.parameters()))

encoder parameters: 6277760
decoder parameters: 6063834


In [21]:
# Restore the trained weights. map_location remaps tensors that were saved on another
# device (e.g. a GPU checkpoint opened on a CPU-only machine) onto the active device.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
encoder.load_state_dict(torch.load(encoder_name,map_location=device))
encoder.to(device)
print('Loaded encoder')
decoder.load_state_dict(torch.load(decoder_name,map_location=device))
decoder.to(device)
print('Loaded decoder')

Loaded encoder
Loaded decoder


In [22]:
# Smoke test on the first sample: encode the previous utterance, decode the current
# one, and sample one character per position from the decoder output.
# NOTE(review): the cell for the last sample below repeats this code verbatim; a shared
# helper taking the sample index would remove the duplication.
prevs,ins,outs = dataset[0]
N_prevs = len(prevs)
N_currents = len(ins)
# Add a leading batch axis of size one.
prevs = prevs.to(device)[None,:]
ins = ins.to(device)[None,:]
outs = outs.to(device)[None,:]
prevs = add_positional_info(prevs)
ins = add_positional_info(ins)
print(prevs.shape)
print(prevs)
print(dataset.vec2str(prevs[0,:,0]))
# Word-position feature as a digit string, skipping the first two positions.
print(''.join( str(i) for i in prevs[0,:,1].tolist()[2:]))
print(ins.shape)
print(ins)
print(dataset.vec2str(ins[0,:,0]))
print(''.join( str(i) for i in ins[0,:,1].tolist()[2:]))
encoder_tensor = encoder(prevs,[N_prevs])
print(encoder_tensor.shape)
print(encoder_tensor)
pred, states = decoder(ins,encoder_tensor,[N_currents])
print(states)
# exp() turns the logits into unnormalised sampling weights for multinomial().
print(dataset.vec2str(torch.exp(pred).squeeze(0).multinomial(1)))

torch.Size([1, 2, 2])
tensor([[[1, 0],
         [0, 0]]])



torch.Size([1, 77, 2])
tensor([[[ 1,  0],
         [53,  0],
         [64,  0],
         [74,  0],
         [68,  0],
         [ 2,  1],
         [66,  1],
         [64,  1],
         [81,  1],
         [68,  1],
         [ 2,  2],
         [78,  2],
         [69,  2],
         [ 2,  3],
         [88,  3],
         [78,  3],
         [84,  3],
         [81,  3],
         [82,  3],
         [68,  3],
         [75,  3],
         [69,  3],
         [15,  0],
         [ 2,  1],
         [53,  1],
         [71,  1],
         [68,  1],
         [ 2,  2],
         [79,  2],
         [81,  2],
         [72,  2],
         [66,  2],
         [68,  2],
         [ 2,  3],
         [78,  3],
         [69,  3],
         [ 2,  4],
         [74,  4],
         [78,  4],
         [75,  4],
         [83,  4],
         [78,  4],
         [ 2,  5],
         [83,  5],
         [64,  5],
         [77,  5],
         [74,  5],
         [82,  5],
   

In [23]:
# Same smoke test as for the first sample, now on the last sample of the dataset.
# NOTE(review): this cell duplicates the previous inspection cell except for the index;
# a shared helper taking the sample index would remove the duplication.
prevs,ins,outs = dataset[len(dataset)-1]
N_prevs = len(prevs)
N_currents = len(ins)
# Add a leading batch axis of size one.
prevs = prevs.to(device)[None,:]
ins = ins.to(device)[None,:]
outs = outs.to(device)[None,:]
prevs = add_positional_info(prevs)
ins = add_positional_info(ins)
print(prevs.shape)
print(prevs)
print(dataset.vec2str(prevs[0,:,0]))
# Word-position feature as a digit string, skipping the first two positions.
print(''.join( str(i) for i in prevs[0,:,1].tolist()[2:]))
print(ins.shape)
print(ins)
print(dataset.vec2str(ins[0,:,0]))
print(''.join( str(i) for i in ins[0,:,1].tolist()[2:]))
encoder_tensor = encoder(prevs,[N_prevs])
print(encoder_tensor.shape)
print(encoder_tensor)
pred, states = decoder(ins,encoder_tensor,[N_currents])
print(states)
# exp() turns the logits into unnormalised sampling weights for multinomial().
print(dataset.vec2str(torch.exp(pred).squeeze(0).multinomial(1)))

torch.Size([1, 145, 2])
tensor([[[ 1,  0],
         [60,  0],
         [39,  0],
         [68,  0],
         [68,  0],
         [67,  0],
         [ 2,  1],
         [83,  1],
         [71,  1],
         [68,  1],
         [ 2,  2],
         [65,  2],
         [68,  2],
         [64,  2],
         [82,  2],
         [83,  2],
         [ 2,  3],
         [64,  3],
         [77,  3],
         [67,  3],
         [ 2,  4],
         [72,  4],
         [83,  4],
         [ 2,  5],
         [86,  5],
         [72,  5],
         [75,  5],
         [75,  5],
         [ 2,  6],
         [71,  6],
         [68,  6],
         [68,  6],
         [67,  6],
         [ 2,  7],
         [88,  7],
         [78,  7],
         [84,  7],
         [81,  7],
         [ 2,  8],
         [66,  8],
         [64,  8],
         [75,  8],
         [75,  8],
         [15,  0],
         [ 2,  1],
         [53,  1],
         [64,  1],
         [74,  1],
         [68,  1],
         [ 2,  2],
         [85,  2],
       

In [24]:
class Bot(nn.Module):
    """Character-level chat bot that samples an answer from an encoder/decoder pair.

    Parameters:
        encoder: module called as ``encoder(inputs, lengths)`` returning the encoding
            of the query.
        decoder: module called as ``decoder(inputs, encoding, lengths, states)``
            returning ``(logits, states)``.
        chars_from_ids: callable turning a tensor of character indices into a string.
        ids_from_chars: callable turning a string into a 1-D tensor of character indices.
        temperature: either a number (the logits are divided by it) or a callable
            ``temperature(logits, position)`` returning the rescaled logits.
    """
    def __init__(self, encoder, decoder, chars_from_ids, ids_from_chars, temperature=1.0):
        super().__init__()
        if not callable(temperature):
            # A plain number: divide the logits by it, ignoring the position argument.
            temperature = lambda pred,*args,temp=temperature: pred/temp
        self.temperature = temperature
        self.encoder = encoder
        self.decoder = decoder
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars
        self.softmax = nn.Softmax(-1)

    def _sample_next(self, step_input, encoder_tensor, states, position):
        """Run one decoder step and sample the next character.

        Returns ``(ids, char, states)``: the sampled index tensor, the matching
        character and the updated decoder states.
        """
        predicted_logits, states = self.decoder(step_input,encoder_tensor,[1],states)
        # Only use the last prediction, rescaled by the temperature schedule.
        predicted_logits = self.temperature(predicted_logits[:, -1, :],position)
        predicted_ids = self.softmax(predicted_logits).multinomial(1)
        predicted_char = list(self.chars_from_ids(predicted_ids[0]))[0]
        return predicted_ids, predicted_char, states

    def generate_answer(self, inputs, max_length=500):
        """Sample an answer to ``inputs``.

        Parameters:
            inputs: query string; callers in this notebook wrap it in the start
                marker '\\r' and the stop marker '\\n'.
            max_length: number of decoder steps taken after the first sampled
                character, unless the stop marker is sampled earlier.

        Returns:
            The answer as a string that starts with '\\r' and ends with '\\n'.
        """
        self.encoder.eval()
        self.decoder.eval()
        # Pure inference: no autograd graph is needed (and none is kept alive
        # through the recurrent states).
        with torch.no_grad():
            # Encode the query.
            query_ids = self.ids_from_chars(inputs)[None,:].to(device)
            query_ids = add_positional_info(query_ids)
            encoder_tensor = self.encoder(query_ids,[len(inputs)])

            # The decoder starts from the start marker.
            step_input = self.ids_from_chars('\r')[None,:].to(device)
            step_input = add_positional_info(step_input)
            # Word-position feature of the next input; updated in place below.
            last_positional_info = step_input[:,-1:,1].clone()

            states = None
            prediction_position = len(inputs)
            run = '\r'
            # One initial step plus at most max_length further steps.
            for _ in range(max_length+1):
                predicted_ids, predicted_char, states = self._sample_next(
                    step_input,encoder_tensor,states,prediction_position)
                run = run + predicted_char
                # Stop as soon as the stop marker is sampled, including on the
                # very first step.
                if predicted_char=='\n':
                    break

                # Update positional info: spaces advance the word counter,
                # sentence-ending marks reset it.
                if predicted_char == ' ':
                    last_positional_info += 1
                elif predicted_char in ['.','!','?']:
                    last_positional_info *= 0

                step_input = torch.stack((predicted_ids,last_positional_info),dim=2)
                prediction_position += 1

        if run[-1]!='\n':
            run += '\n'

        return run

In [25]:
# Bot with the default temperature of 1.0; the dataset provides both index/character conversions.
bot = Bot(encoder,decoder,dataset.vec2str,dataset.str2vec)

In [26]:
# The query is wrapped in the start ('\r') and stop ('\n') markers; the answer is sampled,
# so it differs from run to run.
bot.generate_answer('\rWhat are you thinking?\n')

'\rThat world... It was freedom at Bastila.\n'

In [27]:
# Earlier experiment, kept for reference:
#temperature = lambda pred,x,b=0: (pred-b)*(pred>b)+(pred-b)*(pred<b)/.33 + b
# Position-dependent temperature: the divisor is 1 at position x=0 and decreases towards
# b=.9 as x grows, so the logits are scaled up (sampling gets sharper) for later positions.
temperature = lambda pred,x,b=.9: pred/(b+(1.-b)/np.sqrt(.1*x+1))
bot = Bot(encoder,decoder,dataset.vec2str,dataset.str2vec,temperature=temperature)
bot.generate_answer('\rWhat are you thinking?\n')

"\rThat world... I... I didn't know that of it. I must be able to help you with this, Mandalore! Your father stations may have been thinking for me!\n"

In [28]:
# Sample ten answers to the same query to get an impression of the variety.
query = 'Got any money?'
print('Query:',query,end='\n\n\n')
for _ in range(10):
    print(bot.generate_answer('\r'+query+'\n'))

Query: Got any money?


Taris guards the prress of its hard significance before. Stock you're quite some stuch things. Brejik may have gained through this plate now.

Don't pass the different assassin in my previous use-tomeral.

There you are weakness - she still go aheaday after all. You take them for your advancent, yes? My ancient angratunates?

There is a Gamorrean of the light side, raving so. Jone does not know much about a Jedi...

There used to be a rancor species. Not personally. But how unfortunates takes you now?

There you are definitely 8 who will stop me until you do it many thousand beasts that Rorcus Lirs. It claims themselves's running of my allit-tod.

There is a docking fee on the last three, though he did not realize them. She was up against the Force... but you have alloed them.

There is a planet has been forged better shop to do. The capaties have no cruel something.

There is supposed to have a Jedi holocron? Haven's power genetity of our war because we have be

In [None]:
# Interactive chat loop. An empty line, end-of-input or a kernel interrupt at the
# prompt ends the session instead of leaving the cell running forever.
while True:
    try:
        query = input()
    except (EOFError, KeyboardInterrupt):
        break
    if not query:
        break
    print('Query:',query,end='\n\n\n')
    print(bot.generate_answer('\r'+query+'\n'),end='\n\n')

Help!
Query: Help!


We don't allow us to live down. There was more troops for him.


