In [30]:
torch.cuda.is_available()


False

In [31]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

In [7]:
torch.cuda.is_available()


False

In [32]:
from __future__ import print_function
import torch
x = torch.rand(5, 3)
print(x)


tensor([[0.1188, 0.2462, 0.4069],
        [0.9233, 0.1082, 0.3635],
        [0.7462, 0.9186, 0.2222],
        [0.0582, 0.3274, 0.4069],
        [0.0055, 0.4024, 0.0434]])


In [33]:
torch.cuda.is_available()


False

In [34]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

In [9]:
torch.cuda.is_available()


True

In [60]:
corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join(r"C:\Users\kanna\Desktop", corpus_name)

def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

printLines(os.path.join(corpus, "movie_lines.txt"))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [61]:
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines


def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            utterance_id_pattern = re.compile('L[0-9]+')
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations


def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [62]:
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister. 

In [66]:
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [65]:
MAX_LENGTH = 10  # Maximum sentence length to consider

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

def readVocs(datafile, corpus_name):
    print("Reading lines...")
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').\
        read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [67]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break

        if keep_input and keep_output:
            keep_pairs.append(pair)

    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs


pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [68]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]


def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[ 219,  197,   68,  112,  883],
        [  25,  117,    7,   18,    4],
        [3412, 7247,   69,   36, 4937],
        [ 177, 7246,   70,    6,    6],
        [   4,   76,   71,    2,    2],
        [   4,   37,    6,    0,    0],
        [   4,    2,    2,    0,    0],
        [   4,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([9, 7, 7, 5, 5])
target_variable: tensor([[ 219, 4423,   34,   18,  318],
        [  25,   76,    4, 4929,    6],
        [1010, 4448,    2,  203,    2],
        [   7,    4,    0,   12,    0],
        [   4,    2,    0, 1717,    0],
        [   4,    0,    0,    4,    0],
        [   4,    0,    0,    2,    0],
        [   2,    0,    0,    0,    0]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True, False,  True, False],
        [ True,  True, False,  True, False],
        [ True, False

In [69]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
        return outputs, hidden

In [70]:
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
    

In [71]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden

In [72]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [73]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    loss = 0
    print_losses = []
    n_totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    decoder_hidden = encoder_hidden[:decoder.n_layers]

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_input = target_variable[t].view(1, -1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    loss.backward()

    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [74]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):

    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
                      for _ in range(n_iteration)]

    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [75]:
class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores

In [83]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    indexes_batch = [indexesFromSentence(voc, sentence)]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    input_batch = input_batch.to(device)
    lengths = lengths.to(device)
    tokens, scores = searcher(input_batch, lengths, max_length)
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words
from firebase import firebase
firebase = firebase.FirebaseApplication('https://demi-19b11.firebaseio.com/', None)
def evaluateInput(encoder, decoder, searcher, voc):   
    result = firebase.get('/input/', '')
    x=len(result)
    input_sentence = ''
    while(1):
        try:
            while(x==len(result)):
                result = firebase.get('/input/', '')

            j=0
            x=len(result)
            for i in result:
                if j==x-1:
                   input_sentence = result[i]
                j+=1
            #input_sentence = input('> ')
            if input_sentence == 'q' or input_sentence == 'quit':
                break 
            elif 'your name' in input_sentence:
                firebase.post('/output/',"I'm Demi and I'll be there for you")
            else:
                input_sentence = normalizeString(input_sentence)
                output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
                output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
                firebase.post('/output/'," ".join(output_words))


        except KeyError:
            firebase.post('/output/',"I don't know but whatever dude . . .")

In [21]:
from firebase import firebase
firebase = firebase.FirebaseApplication('https://demi-19b11.firebaseio.com/', None)
data="bih"
result = firebase.post('/output/', 'hello')


In [84]:
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

loadFilename = None
checkpoint_iter = 4000
loadFilename = os.path.join(save_dir, model_name, corpus_name,
                           '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
                            '{}_checkpoint.tar'.format(checkpoint_iter))


if loadFilename:
    checkpoint = torch.load(loadFilename)
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
encoder = encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [30]:
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

encoder.train()
decoder.train()


print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)


for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()


print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
           embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size,
           print_every, save_every, clip, corpus_name, loadFilename)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9632
Iteration: 2; Percent complete: 0.1%; Average loss: 8.8471
Iteration: 3; Percent complete: 0.1%; Average loss: 8.6954
Iteration: 4; Percent complete: 0.1%; Average loss: 8.3812
Iteration: 5; Percent complete: 0.1%; Average loss: 8.0285
Iteration: 6; Percent complete: 0.1%; Average loss: 7.4813
Iteration: 7; Percent complete: 0.2%; Average loss: 6.9908
Iteration: 8; Percent complete: 0.2%; Average loss: 6.6937
Iteration: 9; Percent complete: 0.2%; Average loss: 6.9017
Iteration: 10; Percent complete: 0.2%; Average loss: 6.5428
Iteration: 11; Percent complete: 0.3%; Average loss: 6.3633
Iteration: 12; Percent complete: 0.3%; Average loss: 6.0605
Iteration: 13; Percent complete: 0.3%; Average loss: 5.1000
Iteration: 14; Percent complete: 0.4%; Average loss: 5.6447
Iteration: 15; Percent complete: 0.4%; Average loss: 5.7529
Iteration: 16; Percent complete: 0.4%

Iteration: 269; Percent complete: 6.7%; Average loss: 4.0793
Iteration: 270; Percent complete: 6.8%; Average loss: 3.7750
Iteration: 271; Percent complete: 6.8%; Average loss: 3.8930
Iteration: 272; Percent complete: 6.8%; Average loss: 3.8035
Iteration: 273; Percent complete: 6.8%; Average loss: 3.9466
Iteration: 274; Percent complete: 6.9%; Average loss: 4.0039
Iteration: 275; Percent complete: 6.9%; Average loss: 3.5757
Iteration: 276; Percent complete: 6.9%; Average loss: 4.0177
Iteration: 277; Percent complete: 6.9%; Average loss: 4.1376
Iteration: 278; Percent complete: 7.0%; Average loss: 4.0673
Iteration: 279; Percent complete: 7.0%; Average loss: 3.9291
Iteration: 280; Percent complete: 7.0%; Average loss: 3.8834
Iteration: 281; Percent complete: 7.0%; Average loss: 3.8974
Iteration: 282; Percent complete: 7.0%; Average loss: 3.9503
Iteration: 283; Percent complete: 7.1%; Average loss: 3.8684
Iteration: 284; Percent complete: 7.1%; Average loss: 3.9405
Iteration: 285; Percent 

Iteration: 535; Percent complete: 13.4%; Average loss: 3.6776
Iteration: 536; Percent complete: 13.4%; Average loss: 3.8125
Iteration: 537; Percent complete: 13.4%; Average loss: 3.7877
Iteration: 538; Percent complete: 13.5%; Average loss: 3.8364
Iteration: 539; Percent complete: 13.5%; Average loss: 3.8760
Iteration: 540; Percent complete: 13.5%; Average loss: 3.6615
Iteration: 541; Percent complete: 13.5%; Average loss: 3.8671
Iteration: 542; Percent complete: 13.6%; Average loss: 3.7049
Iteration: 543; Percent complete: 13.6%; Average loss: 3.7934
Iteration: 544; Percent complete: 13.6%; Average loss: 3.5548
Iteration: 545; Percent complete: 13.6%; Average loss: 3.9642
Iteration: 546; Percent complete: 13.7%; Average loss: 3.9553
Iteration: 547; Percent complete: 13.7%; Average loss: 4.1543
Iteration: 548; Percent complete: 13.7%; Average loss: 3.7200
Iteration: 549; Percent complete: 13.7%; Average loss: 3.5336
Iteration: 550; Percent complete: 13.8%; Average loss: 3.5096
Iteratio

Iteration: 799; Percent complete: 20.0%; Average loss: 3.5847
Iteration: 800; Percent complete: 20.0%; Average loss: 3.4284
Iteration: 801; Percent complete: 20.0%; Average loss: 3.6451
Iteration: 802; Percent complete: 20.1%; Average loss: 3.4569
Iteration: 803; Percent complete: 20.1%; Average loss: 3.5169
Iteration: 804; Percent complete: 20.1%; Average loss: 3.6458
Iteration: 805; Percent complete: 20.1%; Average loss: 3.4345
Iteration: 806; Percent complete: 20.2%; Average loss: 3.4322
Iteration: 807; Percent complete: 20.2%; Average loss: 3.2288
Iteration: 808; Percent complete: 20.2%; Average loss: 3.3902
Iteration: 809; Percent complete: 20.2%; Average loss: 3.5191
Iteration: 810; Percent complete: 20.2%; Average loss: 3.5458
Iteration: 811; Percent complete: 20.3%; Average loss: 3.8268
Iteration: 812; Percent complete: 20.3%; Average loss: 3.6797
Iteration: 813; Percent complete: 20.3%; Average loss: 3.8512
Iteration: 814; Percent complete: 20.3%; Average loss: 3.6315
Iteratio

Iteration: 1062; Percent complete: 26.6%; Average loss: 3.5460
Iteration: 1063; Percent complete: 26.6%; Average loss: 3.2479
Iteration: 1064; Percent complete: 26.6%; Average loss: 3.6636
Iteration: 1065; Percent complete: 26.6%; Average loss: 3.3405
Iteration: 1066; Percent complete: 26.7%; Average loss: 3.5223
Iteration: 1067; Percent complete: 26.7%; Average loss: 3.3905
Iteration: 1068; Percent complete: 26.7%; Average loss: 3.5586
Iteration: 1069; Percent complete: 26.7%; Average loss: 3.5608
Iteration: 1070; Percent complete: 26.8%; Average loss: 3.3827
Iteration: 1071; Percent complete: 26.8%; Average loss: 3.4835
Iteration: 1072; Percent complete: 26.8%; Average loss: 3.4635
Iteration: 1073; Percent complete: 26.8%; Average loss: 3.6281
Iteration: 1074; Percent complete: 26.9%; Average loss: 3.3969
Iteration: 1075; Percent complete: 26.9%; Average loss: 3.1591
Iteration: 1076; Percent complete: 26.9%; Average loss: 3.3773
Iteration: 1077; Percent complete: 26.9%; Average loss:

Iteration: 1322; Percent complete: 33.1%; Average loss: 3.3070
Iteration: 1323; Percent complete: 33.1%; Average loss: 3.4159
Iteration: 1324; Percent complete: 33.1%; Average loss: 3.1362
Iteration: 1325; Percent complete: 33.1%; Average loss: 3.3257
Iteration: 1326; Percent complete: 33.1%; Average loss: 3.2997
Iteration: 1327; Percent complete: 33.2%; Average loss: 3.4089
Iteration: 1328; Percent complete: 33.2%; Average loss: 3.3569
Iteration: 1329; Percent complete: 33.2%; Average loss: 3.3855
Iteration: 1330; Percent complete: 33.2%; Average loss: 3.2641
Iteration: 1331; Percent complete: 33.3%; Average loss: 3.4300
Iteration: 1332; Percent complete: 33.3%; Average loss: 3.3180
Iteration: 1333; Percent complete: 33.3%; Average loss: 3.2752
Iteration: 1334; Percent complete: 33.4%; Average loss: 3.1670
Iteration: 1335; Percent complete: 33.4%; Average loss: 3.3024
Iteration: 1336; Percent complete: 33.4%; Average loss: 3.3983
Iteration: 1337; Percent complete: 33.4%; Average loss:

Iteration: 1582; Percent complete: 39.6%; Average loss: 3.1494
Iteration: 1583; Percent complete: 39.6%; Average loss: 3.2258
Iteration: 1584; Percent complete: 39.6%; Average loss: 3.1287
Iteration: 1585; Percent complete: 39.6%; Average loss: 3.3868
Iteration: 1586; Percent complete: 39.6%; Average loss: 3.0130
Iteration: 1587; Percent complete: 39.7%; Average loss: 3.2238
Iteration: 1588; Percent complete: 39.7%; Average loss: 3.2592
Iteration: 1589; Percent complete: 39.7%; Average loss: 3.0086
Iteration: 1590; Percent complete: 39.8%; Average loss: 3.1312
Iteration: 1591; Percent complete: 39.8%; Average loss: 3.1161
Iteration: 1592; Percent complete: 39.8%; Average loss: 3.2578
Iteration: 1593; Percent complete: 39.8%; Average loss: 3.2432
Iteration: 1594; Percent complete: 39.9%; Average loss: 3.2380
Iteration: 1595; Percent complete: 39.9%; Average loss: 3.2888
Iteration: 1596; Percent complete: 39.9%; Average loss: 3.0999
Iteration: 1597; Percent complete: 39.9%; Average loss:

Iteration: 1842; Percent complete: 46.1%; Average loss: 2.9923
Iteration: 1843; Percent complete: 46.1%; Average loss: 3.2919
Iteration: 1844; Percent complete: 46.1%; Average loss: 3.1618
Iteration: 1845; Percent complete: 46.1%; Average loss: 3.1627
Iteration: 1846; Percent complete: 46.2%; Average loss: 3.2911
Iteration: 1847; Percent complete: 46.2%; Average loss: 3.3888
Iteration: 1848; Percent complete: 46.2%; Average loss: 3.2120
Iteration: 1849; Percent complete: 46.2%; Average loss: 3.0942
Iteration: 1850; Percent complete: 46.2%; Average loss: 3.1665
Iteration: 1851; Percent complete: 46.3%; Average loss: 3.3093
Iteration: 1852; Percent complete: 46.3%; Average loss: 3.0568
Iteration: 1853; Percent complete: 46.3%; Average loss: 3.2423
Iteration: 1854; Percent complete: 46.4%; Average loss: 3.1585
Iteration: 1855; Percent complete: 46.4%; Average loss: 3.0544
Iteration: 1856; Percent complete: 46.4%; Average loss: 3.1635
Iteration: 1857; Percent complete: 46.4%; Average loss:

Iteration: 2102; Percent complete: 52.5%; Average loss: 3.0803
Iteration: 2103; Percent complete: 52.6%; Average loss: 3.2083
Iteration: 2104; Percent complete: 52.6%; Average loss: 3.2893
Iteration: 2105; Percent complete: 52.6%; Average loss: 3.1613
Iteration: 2106; Percent complete: 52.6%; Average loss: 3.1986
Iteration: 2107; Percent complete: 52.7%; Average loss: 3.3112
Iteration: 2108; Percent complete: 52.7%; Average loss: 3.1273
Iteration: 2109; Percent complete: 52.7%; Average loss: 2.9888
Iteration: 2110; Percent complete: 52.8%; Average loss: 3.0775
Iteration: 2111; Percent complete: 52.8%; Average loss: 3.0856
Iteration: 2112; Percent complete: 52.8%; Average loss: 3.0258
Iteration: 2113; Percent complete: 52.8%; Average loss: 3.3363
Iteration: 2114; Percent complete: 52.8%; Average loss: 3.1003
Iteration: 2115; Percent complete: 52.9%; Average loss: 2.9717
Iteration: 2116; Percent complete: 52.9%; Average loss: 3.1771
Iteration: 2117; Percent complete: 52.9%; Average loss:

Iteration: 2362; Percent complete: 59.1%; Average loss: 3.0030
Iteration: 2363; Percent complete: 59.1%; Average loss: 3.0322
Iteration: 2364; Percent complete: 59.1%; Average loss: 3.1997
Iteration: 2365; Percent complete: 59.1%; Average loss: 2.7470
Iteration: 2366; Percent complete: 59.2%; Average loss: 3.0300
Iteration: 2367; Percent complete: 59.2%; Average loss: 2.9209
Iteration: 2368; Percent complete: 59.2%; Average loss: 3.0102
Iteration: 2369; Percent complete: 59.2%; Average loss: 2.9932
Iteration: 2370; Percent complete: 59.2%; Average loss: 2.9899
Iteration: 2371; Percent complete: 59.3%; Average loss: 2.6393
Iteration: 2372; Percent complete: 59.3%; Average loss: 3.0768
Iteration: 2373; Percent complete: 59.3%; Average loss: 2.8870
Iteration: 2374; Percent complete: 59.4%; Average loss: 3.1859
Iteration: 2375; Percent complete: 59.4%; Average loss: 2.9579
Iteration: 2376; Percent complete: 59.4%; Average loss: 3.1964
Iteration: 2377; Percent complete: 59.4%; Average loss:

Iteration: 2622; Percent complete: 65.5%; Average loss: 2.8909
Iteration: 2623; Percent complete: 65.6%; Average loss: 3.1752
Iteration: 2624; Percent complete: 65.6%; Average loss: 3.0537
Iteration: 2625; Percent complete: 65.6%; Average loss: 3.3208
Iteration: 2626; Percent complete: 65.6%; Average loss: 2.8378
Iteration: 2627; Percent complete: 65.7%; Average loss: 2.9342
Iteration: 2628; Percent complete: 65.7%; Average loss: 2.9864
Iteration: 2629; Percent complete: 65.7%; Average loss: 2.9894
Iteration: 2630; Percent complete: 65.8%; Average loss: 2.8463
Iteration: 2631; Percent complete: 65.8%; Average loss: 3.0905
Iteration: 2632; Percent complete: 65.8%; Average loss: 3.1772
Iteration: 2633; Percent complete: 65.8%; Average loss: 2.9757
Iteration: 2634; Percent complete: 65.8%; Average loss: 2.9784
Iteration: 2635; Percent complete: 65.9%; Average loss: 3.0471
Iteration: 2636; Percent complete: 65.9%; Average loss: 3.0806
Iteration: 2637; Percent complete: 65.9%; Average loss:

Iteration: 2882; Percent complete: 72.0%; Average loss: 2.8483
Iteration: 2883; Percent complete: 72.1%; Average loss: 2.8489
Iteration: 2884; Percent complete: 72.1%; Average loss: 3.0738
Iteration: 2885; Percent complete: 72.1%; Average loss: 3.0019
Iteration: 2886; Percent complete: 72.2%; Average loss: 2.6977
Iteration: 2887; Percent complete: 72.2%; Average loss: 2.8378
Iteration: 2888; Percent complete: 72.2%; Average loss: 2.9714
Iteration: 2889; Percent complete: 72.2%; Average loss: 3.0415
Iteration: 2890; Percent complete: 72.2%; Average loss: 2.7769
Iteration: 2891; Percent complete: 72.3%; Average loss: 3.2011
Iteration: 2892; Percent complete: 72.3%; Average loss: 2.9625
Iteration: 2893; Percent complete: 72.3%; Average loss: 3.0317
Iteration: 2894; Percent complete: 72.4%; Average loss: 2.9052
Iteration: 2895; Percent complete: 72.4%; Average loss: 2.9954
Iteration: 2896; Percent complete: 72.4%; Average loss: 2.9059
Iteration: 2897; Percent complete: 72.4%; Average loss:

Iteration: 3142; Percent complete: 78.5%; Average loss: 2.7164
Iteration: 3143; Percent complete: 78.6%; Average loss: 2.8200
Iteration: 3144; Percent complete: 78.6%; Average loss: 2.6954
Iteration: 3145; Percent complete: 78.6%; Average loss: 2.8374
Iteration: 3146; Percent complete: 78.6%; Average loss: 2.9727
Iteration: 3147; Percent complete: 78.7%; Average loss: 2.8860
Iteration: 3148; Percent complete: 78.7%; Average loss: 2.8553
Iteration: 3149; Percent complete: 78.7%; Average loss: 2.9185
Iteration: 3150; Percent complete: 78.8%; Average loss: 3.0021
Iteration: 3151; Percent complete: 78.8%; Average loss: 2.8352
Iteration: 3152; Percent complete: 78.8%; Average loss: 2.7796
Iteration: 3153; Percent complete: 78.8%; Average loss: 2.8095
Iteration: 3154; Percent complete: 78.8%; Average loss: 2.6803
Iteration: 3155; Percent complete: 78.9%; Average loss: 2.7033
Iteration: 3156; Percent complete: 78.9%; Average loss: 2.9283
Iteration: 3157; Percent complete: 78.9%; Average loss:

Iteration: 3402; Percent complete: 85.0%; Average loss: 2.6494
Iteration: 3403; Percent complete: 85.1%; Average loss: 2.9606
Iteration: 3404; Percent complete: 85.1%; Average loss: 3.1035
Iteration: 3405; Percent complete: 85.1%; Average loss: 2.5107
Iteration: 3406; Percent complete: 85.2%; Average loss: 2.8345
Iteration: 3407; Percent complete: 85.2%; Average loss: 2.5890
Iteration: 3408; Percent complete: 85.2%; Average loss: 2.9581
Iteration: 3409; Percent complete: 85.2%; Average loss: 2.8808
Iteration: 3410; Percent complete: 85.2%; Average loss: 2.9036
Iteration: 3411; Percent complete: 85.3%; Average loss: 2.7194
Iteration: 3412; Percent complete: 85.3%; Average loss: 2.7384
Iteration: 3413; Percent complete: 85.3%; Average loss: 2.7954
Iteration: 3414; Percent complete: 85.4%; Average loss: 2.8108
Iteration: 3415; Percent complete: 85.4%; Average loss: 2.5112
Iteration: 3416; Percent complete: 85.4%; Average loss: 2.7609
Iteration: 3417; Percent complete: 85.4%; Average loss:

Iteration: 3662; Percent complete: 91.5%; Average loss: 2.5471
Iteration: 3663; Percent complete: 91.6%; Average loss: 2.4469
Iteration: 3664; Percent complete: 91.6%; Average loss: 2.6716
Iteration: 3665; Percent complete: 91.6%; Average loss: 2.8111
Iteration: 3666; Percent complete: 91.6%; Average loss: 2.7160
Iteration: 3667; Percent complete: 91.7%; Average loss: 2.7099
Iteration: 3668; Percent complete: 91.7%; Average loss: 2.5280
Iteration: 3669; Percent complete: 91.7%; Average loss: 2.3445
Iteration: 3670; Percent complete: 91.8%; Average loss: 2.5546
Iteration: 3671; Percent complete: 91.8%; Average loss: 2.7245
Iteration: 3672; Percent complete: 91.8%; Average loss: 2.8298
Iteration: 3673; Percent complete: 91.8%; Average loss: 2.6515
Iteration: 3674; Percent complete: 91.8%; Average loss: 2.8947
Iteration: 3675; Percent complete: 91.9%; Average loss: 2.6965
Iteration: 3676; Percent complete: 91.9%; Average loss: 2.6214
Iteration: 3677; Percent complete: 91.9%; Average loss:

Iteration: 3922; Percent complete: 98.0%; Average loss: 2.5156
Iteration: 3923; Percent complete: 98.1%; Average loss: 2.5632
Iteration: 3924; Percent complete: 98.1%; Average loss: 2.7889
Iteration: 3925; Percent complete: 98.1%; Average loss: 2.6349
Iteration: 3926; Percent complete: 98.2%; Average loss: 2.5906
Iteration: 3927; Percent complete: 98.2%; Average loss: 2.9124
Iteration: 3928; Percent complete: 98.2%; Average loss: 2.6142
Iteration: 3929; Percent complete: 98.2%; Average loss: 2.7076
Iteration: 3930; Percent complete: 98.2%; Average loss: 2.6930
Iteration: 3931; Percent complete: 98.3%; Average loss: 2.9138
Iteration: 3932; Percent complete: 98.3%; Average loss: 2.5599
Iteration: 3933; Percent complete: 98.3%; Average loss: 2.7715
Iteration: 3934; Percent complete: 98.4%; Average loss: 2.7511
Iteration: 3935; Percent complete: 98.4%; Average loss: 2.7273
Iteration: 3936; Percent complete: 98.4%; Average loss: 2.7335
Iteration: 3937; Percent complete: 98.4%; Average loss:

In [96]:
encoder.eval()
decoder.eval()

searcher = GreedySearchDecoder(encoder, decoder)

evaluateInput(encoder, decoder, searcher, voc)


SSLError: HTTPSConnectionPool(host='demi-19b11.firebaseio.com', port=443): Max retries exceeded with url: /input/.json (Caused by SSLError(SSLError("bad handshake: SysCallError(10054, 'WSAECONNRESET')")))

In [None]:
torch.save

In [None]:
pip install dill


In [None]:
import dill
dill.dump_session('notebook_env.db')

In [1]:
import dill
dill.load_session('notebook_env.db')