In [None]:
from __future__ import print_function

import random
import os
import time
import argparse

import pickle as pkl
import torch
from torch import nn, optim
import numpy as np
from beeprint import pp

from models.vrnn import VRNN
from data_apis.data_utils import SWDADataLoader
from data_apis.SWDADialogCorpus import SWDADialogCorpus
from utils.loss import print_loss
import params

from torch.autograd import Variable
import torch.nn.functional as F

In [None]:
class ARGS():
    """Notebook stand-in for the run-time flags consumed by the cells below.

    Edit the class attributes (or the instance created from it) before
    running the rest of the notebook.
    """
    # Skip training and only run decoding with a stored checkpoint.
    forward_only = False
    # Continue training from a stored checkpoint.
    resume = False
    # Not read by any cell visible in this notebook; kept for compatibility.
    checkpoint_path = ""
    test_path = ""
    # Write a checkpoint at the end of every training epoch.
    save_model = True
    # In forward-only mode, decode the test/validation loader (True) or the
    # training loader (False).
    use_test_batch = True
    # Sub-directory of params.log_dir and file name of the checkpoint to load.
    # The setup cell reads both whenever forward_only or resume is set; they
    # were previously missing, which raised AttributeError in that branch.
    ckpt_dir = ""
    ckpt_name = ""

In [None]:
# Flag container read by the setup, training and inference cells as `args`.
args = ARGS()

In [None]:
def get_dataset():
    """Load the pickled corpus API and build the data loaders.

    Returns:
        A 4-tuple ``(train_feed, valid_feed, test_feed, word2vec)``.
        ``valid_feed`` and ``test_feed`` are the *same* loader object built
        from the "test" split. ``word2vec`` is ``np.array(api.word2vec)``,
        or ``None`` when the API object carries no embedding table.
    """
    # NOTE(security): pickle.load can execute arbitrary code; params.api_dir
    # must only ever point at a trusted file.
    with open(params.api_dir, "rb") as fh:
        api = pkl.load(fh, encoding='latin1')
    dial_corpus = api.get_dialog_corpus()

    train_dial = dial_corpus.get("train")
    test_dial = dial_corpus.get("test")

    # convert to numeric input outputs
    train_feed = SWDADataLoader("Train", train_dial, params.max_utt_len,
                                params.max_dialog_len)
    valid_feed = test_feed = SWDADataLoader("Test", test_dial,
                                            params.max_utt_len,
                                            params.max_dialog_len)

    # np.array(None) is a 0-d object array, which is "not None"; keep a
    # missing table as None so the caller's `word2vec is not None` guard works.
    word2vec = None if api.word2vec is None else np.array(api.word2vec)
    return train_feed, valid_feed, test_feed, word2vec


def train(model, train_loader, optimizer):
    """Run one training epoch over ``train_loader``.

    Args:
        model: module whose call ``model(*batch)`` returns the scalar loss.
        train_loader: loader exposing ``next_batch()`` (returns ``None`` when
            the epoch is exhausted), ``num_batch`` and ``ptr``.
        optimizer: optimizer updating ``model``'s parameters.

    Progress is reported through ``print_loss`` roughly ten times per epoch
    and once more when the epoch finishes.
    """
    losses = []
    local_t = 0
    start_time = time.time()
    loss_names = ["loss"]
    model.train()

    # Clamp to 1: with fewer than 10 batches `num_batch // 10` is 0 and the
    # modulo below would raise ZeroDivisionError.
    report_every = max(1, train_loader.num_batch // 10)

    while True:
        batch = train_loader.next_batch()
        if batch is None:
            break
        local_t += 1
        optimizer.zero_grad()
        loss = model(*batch)
        loss.backward()
        optimizer.step()
        # Keep only the value; holding the attached tensor would keep every
        # batch's autograd graph alive for the whole epoch.
        losses.append(loss.detach())

        if local_t % report_every == 0:
            print_loss("%.2f" %
                       (train_loader.ptr / float(train_loader.num_batch)),
                       loss_names, [losses],
                       postfix='')
    # finish epoch!
    epoch_time = time.time() - start_time
    print_loss("Epoch Done", loss_names, [losses],
               "step time %.4f" % (epoch_time / max(1, train_loader.num_batch)))


def valid(model, valid_loader):
    """Evaluate ``model`` on every batch of ``valid_loader`` and report it.

    Args:
        model: module whose call ``model(*batch)`` returns the loss.
        valid_loader: loader exposing ``next_batch()``, which returns
            ``None`` once exhausted.

    The model is switched to eval mode and left there; ``train`` switches it
    back at the start of the next epoch.
    """
    losses = []
    # Evaluation must not use train-mode layers, and needs no autograd graph.
    model.eval()
    with torch.no_grad():
        while True:
            batch = valid_loader.next_batch()
            if batch is None:
                break
            losses.append(model(*batch))

    print_loss("ELBO_VALID", ['losses valid'], [losses], "")


def decode(model, data_loader):
    """Run ``model`` in interpret mode over every batch of ``data_loader``.

    Args:
        model: module called as ``model(*batch, interpret=True)``.
        data_loader: loader exposing ``next_batch()``, which returns
            ``None`` once exhausted.

    Returns:
        List with one model output per batch, in loader order.
    """
    results = []
    # Inference only: use eval-mode layers and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        while True:
            batch = data_loader.next_batch()
            if batch is None:
                break
            results.append(model(*batch, interpret=True))
    return results

In [None]:
# Print the active hyper-parameter module so the run is self-describing.
pp(params)
# set random seeds
# Python, NumPy and PyTorch each get a different offset of the same base seed.
seed = params.seed
random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)

# TODO: set device
# CUDA is selected only when it is both requested in params and available.
use_cuda = params.use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# valid_loader and test_loader are the same loader object (see get_dataset).
train_loader, valid_loader, test_loader, word2vec = get_dataset()

# Resuming / inference reuses an existing run directory; a fresh training run
# gets a new timestamped one.
# NOTE(review): the first branch reads args.ckpt_dir and args.ckpt_name, so
# both must be defined on ARGS before forward_only or resume is enabled.
# `checkpoint_path` is only defined in that branch.
if args.forward_only or args.resume:
    log_dir = os.path.join(params.log_dir, args.ckpt_dir)
    checkpoint_path = os.path.join(log_dir, args.ckpt_name)
else:
    log_dir = os.path.join(params.log_dir, "run" + str(int(time.time())))
os.makedirs(log_dir, exist_ok=True)


In [None]:
model = VRNN()
# TODO: learning rate with decay
optimizer = optim.Adam(model.parameters(), lr=params.init_lr)

if word2vec is not None and not args.forward_only:
    print("Load word2vec")
    # TODO: trainable pretrained embedding
    # nn.Embedding.from_pretrained is a classmethod that returns a *new*
    # layer; calling it on the instance and dropping the result left
    # model.embedding at its random initialisation. Copy the vectors into the
    # existing weight instead, so the optimizer created above keeps tracking
    # the same Parameter. copy_ casts the NumPy dtype to the weight dtype and
    # raises if the shapes do not match.
    with torch.no_grad():
        model.embedding.weight.copy_(torch.from_numpy(word2vec))
    # Equivalent of freeze=False: the embedding stays trainable.
    model.embedding.weight.requires_grad_(True)

# Write config to a file for logging
if not args.forward_only:
    with open(os.path.join(log_dir, "run.log"), "w") as f:
        f.write(pp(params, output=False))

In [None]:
# Epoch counter to continue from; stays 0 for a fresh run.
last_epoch = 0
if args.resume:
    print("Resuming training from %s" % checkpoint_path)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine (and vice versa) instead of failing during deserialisation.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    last_epoch = state['epoch']

In [None]:
 # Train and evaluate
# TODO: early stop
if not args.forward_only:
    # Training: one pass over the training loader plus one validation pass
    # per epoch, starting after the last completed epoch when resuming.
    for epoch in range(last_epoch + 1, params.max_epoch + 1):
        print(">> Epoch %d with lr %f" % (epoch, params.init_lr))
        # (Re)initialise the loader on first use or once it is exhausted.
        if train_loader.num_batch is None or train_loader.ptr >= train_loader.num_batch:
            train_loader.epoch_init(params.batch_size, shuffle=True)
        train(model, train_loader, optimizer)
        valid_loader.epoch_init(params.batch_size, shuffle=False)
        valid(model, valid_loader)

        if args.save_model:
            print("Save the model at the end of each epoch.")
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(log_dir, "vrnn_" + str(epoch) + ".pt"))
# Inference only
else:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine instead of failing during deserialisation.
    state = torch.load(checkpoint_path, map_location=device)
    print("Load model from %s" % checkpoint_path)
    model.load_state_dict(state['state_dict'])
    if not args.use_test_batch:
        train_loader.epoch_init(params.batch_size, shuffle=False)
        results = decode(model, train_loader)
    else:
        valid_loader.epoch_init(params.batch_size, shuffle=False)
        results = decode(
            model, valid_loader
        )  # [num_batches(8), 4, batch_size(16), max_dialog_len(10), n_state(10)]
    # Persist the decoded outputs next to the checkpoint's run directory.
    with open(os.path.join(log_dir, "result.pkl"), "wb") as fh:
        pkl.dump(results, fh)

In [None]:
class Attn(nn.Module):
    """Attention scorer with 'dot', 'general' and 'concat' energy functions.

    Args:
        method: one of ``'dot'``, ``'general'`` or ``'concat'``.
        hidden_size: feature size shared by the query and the encoder outputs.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """

    METHODS = ('dot', 'general', 'concat')

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        if method not in self.METHODS:
            raise ValueError("Unknown attention method %r; expected one of %s"
                             % (method, ", ".join(self.METHODS)))

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(1, n) is uninitialised memory (may hold
            # NaN/inf); allocate and then initialise explicitly.
            self.v = nn.Parameter(torch.empty(1, hidden_size))
            bound = hidden_size ** -0.5
            nn.init.uniform_(self.v, -bound, bound)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights over the encoder time steps.

        Args:
            hidden: query tensor with ``B`` rows and ``hidden_size`` values
                per row (``B x H``; extra singleton dims are flattened).
            encoder_outputs: ``B x S x H`` tensor of encoder states.

        Returns:
            ``B x 1 x S`` tensor whose last dimension sums to 1, on the same
            device as the inputs.
        """
        batch_size = encoder_outputs.size(0)
        max_len = encoder_outputs.size(1)
        query = hidden.reshape(batch_size, -1)  # B x H

        # Batched equivalents of calling score() for every (b, i) pair.
        if self.method == 'concat':
            tiled = query.unsqueeze(1).expand(-1, max_len, -1)  # B x S x H
            projected = self.attn(torch.cat((tiled, encoder_outputs), dim=2))
            attn_energies = projected.matmul(self.v.view(-1))  # B x S
        else:
            if self.method == 'general':
                keys = self.attn(encoder_outputs)
            else:
                keys = encoder_outputs
            attn_energies = keys.bmm(query.unsqueeze(2)).squeeze(2)  # B x S

        # Normalize energies to weights in range 0 to 1, resize to B x 1 x S
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

    def score(self, hidden, encoder_output):
        """Return the scalar energy between one query and one encoder state.

        Both arguments are flattened, so ``H`` and ``1 x H`` inputs are
        accepted alike.
        """
        query = hidden.reshape(-1)
        key = encoder_output.reshape(-1)

        if self.method == 'dot':
            return torch.dot(key, query)

        if self.method == 'general':
            return torch.dot(self.attn(key), query)

        if self.method == 'concat':
            # NOTE(review): no tanh is applied before the projection onto v,
            # matching the original code — confirm this is intended.
            return torch.dot(self.v.view(-1),
                             self.attn(torch.cat((query, key))))

        raise ValueError("Unknown attention method %r" % (self.method,))

In [None]:
# Random stand-in for encoder states: 16 x 5 x 800 (Attn.forward reads dim 0
# as the batch size and dim 1 as the sequence length).
encoder_outputs = torch.randn(16,5,800)

In [None]:
# 'general' scoring with hidden size 800, matching the last dim of the dummy tensors.
attn = Attn("general", 800)

In [None]:
# Random stand-in for the query: one 800-dim vector per batch element.
hidden = torch.randn(16, 800)

In [None]:
# Attention weights over the 5 encoder steps; forward() returns 16 x 1 x 5.
attn_weights = attn(hidden, encoder_outputs)

In [None]:
# Weighted sum of encoder states: (16 x 1 x 5) bmm (16 x 5 x 800).
context = attn_weights.bmm(encoder_outputs)

In [None]:
# Sanity check: the batched product above yields torch.Size([16, 1, 800]).
context.shape