In [1]:
import os
import json
import sys
import logging
# import torch
# from model import QAModel
# from train import train
import pickle
# from train import train
# from model import QAModel

from preprocess import get_glove, create_vocabulary, create_vocab_class
logging.basicConfig(level=logging.INFO)

NUM_EPOCH = 3
LEARNING_RATE =0.001
BATCH_SIZE = 16
HIDDEN_SIZE = 128
CONTEXT_LEN = 300
QUESTION_LEN = 150
ANSWER_LEN = 50
EMBEDDING_SIZE = 100
LOAD_PREV = False
# SAMPLING CONSTANTS

with open("./data/emb_matrix.pkl", "rb") as f:
    emb_matrix = pickle.load(f)

with open("./data/word2id.pkl", "rb") as f:
    word2id=  pickle.load(f)

with open("./data/id2word.pkl", "rb") as f:
    id2word = pickle.load(f)



context_vocab_path = "./data/vocab200.context"
train_context_path = "./data/train.graph"
context_vocab, rev_context_vocab = create_vocabulary(context_vocab_path,train_context_path,200)
NO_CLASS = len(context_vocab)

# TODO 
# qa_model.train(train_context_path, train_qn_path, train_ans_path, dev_qn_path, dev_context_path, dev_ans_path)


Skipping generating vocabulary file for ./data/vocab200.context
158 158


In [2]:
# For Training Steps

import time
import torch
import torch.nn as nn
import torch.optim as optim

loss_func = nn.CrossEntropyLoss()

def loss(loss_func,logits,target):
    return loss_func(logits,target)



def train_iteration(qamodel,batch,criterion,encoder_optimizer,decoder_optimizer):
    context_ids = batch.context_ids
    qn_ids = batch.qn_ids
    ans_ids = batch.ans_ids
    qn_mask = batch.qn_mask
    batch_size = batch.batch_size

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

#     qn_ids = torch.from_numpy(qn_ids).long().to(device)
    context_ids = torch.from_numpy(context_ids).long().to(device)
#     ans_ids = torch.from_numpy(ans_ids).long().to(device)
#     qn_mask = torch.from_numpy(qn_mask).long().to(device)
    
#     print(f"qn_ids - {qn_ids.is_cuda}, context_ids - {context_ids.is_cuda}, ans_ids - {ans_ids.is_cuda}, qn_mask - {qn_mask.is_cuda}")
    decoder_outputs= qamodel(qn_ids,context_ids,ans_ids,qn_mask)
    loss = 0    #loss per batch
    # print(len(decoder_outputs))
    ans_ids = torch.tensor(ans_ids).transpose(0,1).to(device)
    for idx,dec_out in enumerate(decoder_outputs):
        # print(dec_out.shape)
        # dec_out = [bsz,output_vocab_size]
        # print(ans_ids[idx],ans_ids[idx].shape)
        loss += criterion(dec_out,ans_ids[idx])
    
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item()/batch_size


def train(qamodel,num_epochs, context_path, qn_path, ans_path, batch_size):
    epoch = 0
    criterion = nn.CrossEntropyLoss()
    encoder_optimizer = optim.Adam(qamodel.encoder.parameters(),lr = 0.001)
    decoder_optimizer = optim.Adam(qamodel.decoder.parameters(),lr = 0.001)
    # initialise optimiser
    while epoch<num_epochs:
        epoch+=1
        epoch_loss = 0
        epoch_start_time = time.time()
        num_iters = 0
        for batch in get_batch_generator(qamodel.word2id, qamodel.context2id, qamodel.ans2id, context_path,
                                            qn_path, ans_path, batch_size, qamodel.graph_vocab_class,
                                            context_len=300, question_len=150,
                                            answer_len=50, discard_long=False):
            try:

                batch_loss  = train_iteration(qamodel,batch,criterion,encoder_optimizer,decoder_optimizer)
                num_iters += 1
                epoch_loss += batch_loss
                if num_iters%50==0:
                    print(f'End of {num_iters} batches with loss = {batch_loss}')
            except RuntimeError as e:
                print(e)
                continue
            # loss backward
            # optimiser step
            # if num_iters%print_every:

            # add line to print at print_every
        print('End of epoch',epoch,' | Loss = ',epoch_loss)
        epoch_end_time = time.time()
        time_of_epoch = epoch_end_time - epoch_start_time
        print(time_of_epoch)

            



In [3]:
# vocab.py

"""This file contains a function to read the GloVe vectors from file,
and return them as an embedding matrix"""

from tqdm import tqdm
import numpy as np
import re
import nltk
import os

_PAD = r"<pad>"
_UNK = r"<unk>"
_SOS = r"<sos>"
_START_VOCAB = [_PAD, _UNK, _SOS]
PAD_ID = 0
UNK_ID = 1
SOS_ID = 2

# Regular expressions used to tokenize.
_WORD_SPLIT = re.compile(r"([.,!?\"':;()/-])")
_DIGIT_RE = re.compile(r"\d")
_ATTRIBUTE_RE = r'\d+[A-Za-z]+'
_NODE_RE = r'^[O|R|K|B|C|H|L]-\d+'
_ACT_RE = r'[A-Za-z]+'
_DISCARD_TOK = ['(', ')', 'nt', ';']


class Vocab:
    """
    Read in the class Vocab, it is a tidied class containing classified elements in graph
    self.node2id: dictionary converting the node token to its id
    self.nodes: list of all node tokens

    """
    def __init__(self, node2id, edge2id, flag2id):
        self.discard_tokens = _DISCARD_TOK
        self.node2id = node2id
        self.edge2id = edge2id
        self.flag2id = flag2id
        self.nodes = list(node2id.keys())
        self.edges = list(edge2id.keys())
        self.flags = list(flag2id.keys())
        self.all_tokens = self.flags + self.nodes + self.edges

    def tidy_in_triplet(self, tokens):
        """
        convert raw tokens into a id list of length [3*N]
        [first_node_id_list, edge_id_list, second_id_list, first_id_list, edge_id_list, second_id_list, ...]
        each entry in the list contains another list because node can be composed of many elements
        [[node_element0, node_element1, ...], [edge_element0, edge_element1, ...], [...], ...]
        :param tokens: a list of raw tokens in graph txt file
        :return: a list of lists
        """
        ids = []
        for (i, w) in enumerate(tokens):
            if w in self.nodes:
                if (i == 0) or (tokens[i - 1] in [';', ')']) or (tokens[i - 1] in self.edges):
                    ids.append([self.node2id[w]])
                else:
                    ids[-1].append(self.node2id[w])
            elif w in self.edges:
                if (tokens[i - 1] in self.flags) or (tokens[i - 1] in self.nodes):
                    ids.append([self.edge2id[w]])
                else:
                    ids[-1].append(self.edge2id[w])
            elif w in ['l', 'r']:
                w = "-".join(tokens[i - 1: i + 1])
                if w in self.edges:
                    ids[-1].append(self.edge2id[w])
                else:
                    ids[-1].append(UNK_ID)
                    
            elif (w in self.discard_tokens) or (re.match(_ATTRIBUTE_RE, w)):
                if w == ';':
                    assert len(ids) % 3 == 0, "error in tidy_in_triplet, can't be divided by 3"
                continue
            elif w not in self.all_tokens:
                raise ValueError("new token %s in graph representation."%w)
        return ids

def get_glove(glove_path, glove_dim):
    """Reads from original GloVe .txt file and returns embedding matrix and
    mappings from words to word ids.

    Input:
      glove_path: path to glove.6B.{glove_dim}d.txt
      glove_dim: integer; needs to match the dimension in glove_path

    Returns:
      emb_matrix: Numpy array shape (400002, glove_dim) containing glove embeddings
(plus PAD and UNK embeddings in first two rows).
        The rows of emb_matrix correspond to the word ids given in word2id and id2word
      word2id: dictionary mapping word (string) to word id (int)
      id2word: dictionary mapping word id (int) to word (string)
    """

    print("Loading GLoVE vectors from file: %s" % glove_path)
    vocab_size = int(4e5)  # this is the vocab size of the corpus we've downloaded

    emb_matrix = np.zeros((vocab_size + len(_START_VOCAB), glove_dim))
    word2id = {}
    id2word = {}

    random_init = True
    # randomly initialize the special tokens
    if random_init:
        emb_matrix[:len(_START_VOCAB), :] = np.random.randn(len(_START_VOCAB), glove_dim)

    # put start tokens in the dictionaries
    idx = 0
    for word in _START_VOCAB:
        word2id[word] = idx
        id2word[idx] = word
        idx += 1

    # go through glove vecs
    with open(glove_path, 'r', encoding="utf-8") as fh:
        for line in tqdm(fh, total=vocab_size):
            line = line.lstrip().rstrip().split(" ")
            word = line[0]
            vector = list(map(float, line[1:]))
            if glove_dim != len(vector):
                raise Exception(
                    "You set --glove_path=%s but --embedding_size=%i. If you set --glove_path yourself then make sure that --embedding_size matches!" % (
                    glove_path, glove_dim))
            emb_matrix[idx, :] = vector
            word2id[word] = idx
            id2word[idx] = word
            idx += 1

    final_vocab_size = vocab_size + len(_START_VOCAB)
    assert len(word2id) == final_vocab_size
    assert len(id2word) == final_vocab_size
    assert idx == final_vocab_size

    return emb_matrix, word2id, id2word


def one_hot_converter(vec_len):
    one_hot_embed = np.zeros((vec_len, vec_len))
    np.fill_diagonal(one_hot_embed, 1)
    return one_hot_embed

def instruction_tokenizer(sentence):
    """
    A special tokenizer for instructions.
    Turn into lower case and split Office-1 or office1 into "Office 1",
    :param sentence: instructions (natural language)
    :return: a list of tokens
    """
    words = []
    prepocessed_sen_list = preprocess_instruction(sentence.strip())
    for space_separated_fragment in prepocessed_sen_list:
        words.extend(_WORD_SPLIT.split(space_separated_fragment))
    return [w.lower() for w in words if w]

def preprocess_instruction(sentence):
    # change "office-12" or "office12" to "office 12"
    # change "12-office" or "12office" to "12 office"
    _WORD_NO_SPACE_NUM_RE = r'([A-Za-z]+)\-?(\d+)'
    _NUM_NO_SPACE_WORD_RE = r'(\d+)\-?([A-Za-z]+)'
    new_str = re.sub(_WORD_NO_SPACE_NUM_RE, lambda m: m.group(1) + ' ' + m.group(2), sentence)
    new_str = re.sub(_NUM_NO_SPACE_WORD_RE, lambda m: m.group(1) + ' ' + m.group(2), new_str)
    lemma = nltk.wordnet.WordNetLemmatizer()
    # correct common typos.
    correct_error_dic = {'rom': 'room', 'gout': 'go out', 'roo': 'room',
                         'immeidately': 'immediately', 'halway': 'hallway',
                         'office-o': 'office 0', 'hall-o': 'hall 0', 'pas': 'pass',
                         'offic': 'office', 'leftt': 'left', 'iffice': 'office'}
    for err_w in correct_error_dic:
        find_w = ' ' + err_w + ' '
        replace_w = ' ' + correct_error_dic[err_w] + ' '
        new_str = new_str.replace(find_w, replace_w)
    sen_list = []
    # Lemmatize words
    for word in new_str.split(' '):
        try:
            word = lemma.lemmatize(word)
            if len(word) > 0 and word[-1] == '-':
                word = word[:-1]
            if word:
                sen_list.append(word)
        except UnicodeDecodeError:
            continue
            # print("unicode error ", word, new_str)
    return sen_list

def basic_tokenizer(sentence):
    """Very basic tokenizer: split the sentence into a list of tokens."""
    words = []
    for space_separated_fragment in sentence.strip().split():
        words.extend(_WORD_SPLIT.split(space_separated_fragment))
    return [w for w in words if w]

def create_vocab_class(vocab_dict):
    """
    Convert the raw tokens in graph to
    To organize the vocab into two groups: node group, action + attribute group
    :param vocab_dict:
    :param rev_vocab:
    :return: A vocab class holding all the information necessary for training
    """
    rev_vocab = vocab_dict.keys()
    new_vocab_dic = {"node": ["S", "N", "E", "W"], "edge": [], "flag": _START_VOCAB}
    for vocab in rev_vocab:
        if vocab in 'lrSNEW':
            continue
        elif re.match(_NODE_RE, vocab):
            new_vocab_dic["node"].append(vocab)
        elif re.match(_ATTRIBUTE_RE, vocab):
            new_vocab_dic["edge"].append(vocab + '-l')
            new_vocab_dic["edge"].append(vocab + '-r')
        elif re.match(_ACT_RE, vocab) and vocab != 'nt':
            new_vocab_dic["edge"].append(vocab)
    node2id = dict([(x, y) for (y, x) in enumerate(new_vocab_dic['node'])])
    edge2id = dict([(x, y) for (y, x) in enumerate(new_vocab_dic['edge'])])
    flag2id = dict([(x, y) for (y, x) in enumerate(new_vocab_dic['flag'])])

    return Vocab(node2id, edge2id, flag2id)


def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
                      tokenizer=None, normalize_digits=False):
    """Create vocabulary file (if it does not exist yet) from data file.

    Data file is assumed to contain one sentence per line. Each sentence is
    tokenized and digits are normalized (if normalize_digits is set).
    Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
    We write it to vocabulary_path in a one-token-per-line format, so that later
    token in the first line gets id=0, second line gets id=1, and so on.

    Args:
      vocabulary_path: path where the vocabulary will be created.
      data_path: data file that will be used to create vocabulary.
      max_vocabulary_size: limit on the size of the created vocabulary.
      tokenizer: a function to use to tokenize each data sentence;
        if None, basic_tokenizer will be used.
      normalize_digits: Boolean; if true, all digits are replaced by 0s.
    """
    if not os.path.exists(vocabulary_path):
        print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
        vocab = {}
        with open(data_path, mode="r", encoding="utf-8") as f:
            counter = 0
            for line in f:
                counter += 1
                if counter % 100000 == 0:
                    print("  processing line %d" % counter)
                
                # TODO - CHANGE THE BELOW LINE
                # line = tf.compat.as_bytes(line)
                tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
                for w in tokens:
                    word = _DIGIT_RE.sub(r"0", w) if normalize_digits else w
                    if word in vocab:
                        vocab[word] += 1
                    else:
                        vocab[word] = 1
            vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
            if len(vocab_list) > max_vocabulary_size:
                vocab_list = vocab_list[:max_vocabulary_size]
            with open(vocabulary_path, mode="w", encoding="utf-8") as vocab_file:
                for w in vocab_list:
                    vocab_file.write(w + b"\n")
        rev_vocab = vocab_list # a list contain all the tokens
        vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])  # key is the token, value is index
        return vocab, rev_vocab
    else:
        print("Skipping generating vocabulary file for {}".format(vocabulary_path))
        return initialize_vocabulary(vocabulary_path)

def initialize_vocabulary(vocabulary_path):
  """Initialize vocabulary from file.

  We assume the vocabulary is stored one-item-per-line, so a file:
    dog
    cat
  will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
  also return the reversed-vocabulary ["dog", "cat"].

  Args:
    vocabulary_path: path to the file containing the vocabulary.

  Returns:
    a pair: the vocabulary (a dictionary mapping string to integers), and
    the reversed vocabulary (a list, which reverses the vocabulary mapping).

  Raises:
    ValueError: if the provided vocabulary_path does not exist.
  """
  if os.path.exists(vocabulary_path):
    rev_vocab = []
    with open(vocabulary_path, mode="r", encoding="utf-8") as f:
      rev_vocab.extend(f.readlines())
    
    # TODO - CHANGE THE BELOW LINE
    # rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab]
    rev_vocab = [line.strip() for line in rev_vocab]
    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
    print(len(vocab.keys()), len(rev_vocab))
    return vocab, rev_vocab
  else:
    raise ValueError("Vocabulary file %s not found.", vocabulary_path)

In [4]:
"""This file contains code to read tokenized data from file,
truncate, pad and process it into batches ready for training"""

import random
import time
import re

import numpy as np

class Batch(object):
    """A class to hold the information needed for a training batch"""

    def __init__(self, context_ids, context_tokens, qn_ids, qn_mask, qn_tokens, ans_ids, ans_mask,
                 ans_tokens, batch_size):
        """
        Inputs:
          {context/qn}_ids: Numpy arrays.
            Shape (batch_size, {context_len/question_len}). Contains padding.
          {context/qn}_mask: Numpy arrays, same shape as _ids.
            Contains 1s where there is real data, 0s where there is padding.
          {context/qn/ans}_tokens: Lists length batch_size, containing lists (unpadded) of tokens (strings)
          ans_span: numpy array, shape (batch_size, 2)
        """
        self.context_ids = context_ids
        # self.context_mask = context_mask
        self.context_tokens = context_tokens
        # self.context_embeddings = context_embeddings

        self.qn_ids = qn_ids
        self.qn_mask = qn_mask
        self.qn_tokens = qn_tokens

        self.ans_ids = ans_ids
        self.ans_mask = ans_mask
        self.ans_tokens = ans_tokens

        self.batch_size = batch_size


def split_by_whitespace(sentence):
    words = []
    for space_separated_fragment in sentence.strip().split():
        words.extend(re.split(" ", space_separated_fragment))
    return [w for w in words if w]


def intstr_to_intlist(string):
    """Given a string e.g. '311 9 1334 635 6192 56 639', returns as a list of integers"""
    return [int(s) for s in string.split()]


def sentence_to_token_ids(sentence, word2id, is_instr=False):
    """Turns an already-tokenized sentence string into word indices
    e.g. "i do n't know" -> [9, 32, 16, 96]
    Note any token that isn't in the word2id mapping gets mapped to the id for UNK
    """
    if is_instr:
        tokens = instruction_tokenizer(sentence)  # list of strings
    else:
        tokens = split_by_whitespace(sentence)

    # if simply split in tokens.
    ids = [word2id.get(w, UNK_ID) for w in tokens]
    ''' for debugging
    if UNK_ID in ids:
        print(tokens[ids.index(UNK_ID)], " ".join(tokens))
    '''
    return tokens, ids

def padded(token_batch, batch_pad=0):
    """
    Inputs:
      token_batch: List (length batch size) of lists of ints.
      batch_pad: Int. Length to pad to. If 0, pad to maximum length sequence in token_batch.
    Returns:
      List (length batch_size) of padded of lists of ints.
        All are same length - batch_pad if batch_pad!=0, otherwise the maximum length in token_batch
    """
    maxlen = max(map(lambda x: len(x), token_batch)) if batch_pad == 0 else batch_pad
    return map(lambda token_list: token_list + [PAD_ID] * (maxlen - len(token_list)), token_batch)

def reorganize(context_line, ans_line):
    start = ans_line.strip().split()[0]
    context_trip_list = context_line.strip().split(';')
    trips_contain_start = []
    trips_not_contain_start = []

    for trip_str in context_trip_list:
        if start in trip_str:
            trips_contain_start.append(trip_str)
        else:
            trips_not_contain_start.append(trip_str)
    if trips_not_contain_start[0][0] != ' ':
        trips_not_contain_start[0] = ' ' + trips_not_contain_start[0]
    organized_context_line = ";".join(trips_contain_start + trips_not_contain_start).strip() + '\n'

    #assert len(organized_context_line) == len(context_line), "len {} {}{} len {}".\
    #      format(len(context_line), context_line, organized_context_line, len(organized_context_line))
    return organized_context_line


def refill_batches(batches, word2id, context2id, ans2id, context_file, qn_file, ans_file, batch_size, context_len,
                   question_len, ans_len, discard_long, shuffle=True, output_goal=False):
    """
    Adds more batches into the "batches" list.
    Inputs:
      batches: list to add batches to
      word2id: dictionary mapping word (string) to word id (int)
      context_file, qn_file, ans_file: paths to {train/dev}.{context/question/answer} data files
      batch_size: int. how big to make the batches
      context_len, question_len: max length of context and question respectively
      discard_long: If True, discard any examples that are longer than context_len or question_len.
        If False, truncate those exmaples instead.
    """
    print("Refilling batches...")
    tic = time.time()
    examples = []  # list of (qn_ids, context_ids, ans_span, ans_tokens) triples
    context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline()  # read the next line from each
    # print(context_line,qn_line,ans_line)
    while context_line and qn_line and ans_line:  # while you haven't reached the end

        # Reorganize the map to make the nodes containing the start point comes at the front.
        context_line = reorganize(context_line, ans_line)
        # Convert tokens to word ids
        context_tokens, context_ids = sentence_to_token_ids(context_line, context2id)
        qn_tokens, qn_ids = sentence_to_token_ids(qn_line, word2id, is_instr=True)

        ans_tokens, ans_ids = sentence_to_token_ids(ans_line, ans2id)

        ############# reorganize ans tokens into [start] + [action list] (+ [end]) #####################
        if output_goal:
            ans_tokens = [ans_tokens[0]] + ans_tokens[1::2] + [ans_tokens[-1]]
            ans_ids = [ans_ids[0]] + ans_ids[1::2] + [ans_ids[-1]]
        else:
            ans_tokens = [ans_tokens[0]] + ans_tokens[1::2]
            ans_ids = [ans_ids[0]] + ans_ids[1::2]
        ##############################################################################################s
        

        # read the next line from each file
        context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline()

        # discard or truncate too-long questions
        if len(qn_ids) > question_len:
            if discard_long:
                continue
            else:  # truncate
                qn_ids = qn_ids[:question_len]

        # discard or truncate too-long contexts
        if len(context_ids) > context_len:
            if discard_long:
                continue
            else:  # truncate
                context_ids = context_ids[:context_len]

        # discard or truncate too-long answer
        if len(ans_ids) > ans_len:
            if discard_long:
                continue
            else:  # truncate
                ans_ids = ans_ids[:ans_len]

        # add to examples
        examples.append((context_ids, context_tokens, qn_ids, qn_tokens, ans_ids, ans_tokens))

        # stop refilling if you have 160 batches
        if len(examples) == batch_size * 160:
            break

    # Once you've either got 160 batches or you've reached end of file:

    # Sort by context length for speed
    # Note: if you sort by context length, then you'll have batches which contain the same context many times
    # (because each context appears several times, with different questions)
    # shuffle==False means to not change the sequence of the input data, thus no sorting.
    if shuffle:
        examples = sorted(examples, key=lambda e: len(e[0]))

    # Make into batches and append to the list batches
    for batch_start in range(0, len(examples), batch_size):
        # Note: each of these is a list length batch_size of lists of ints (except on last iter when it might be less than batch_size)
        context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch = zip(*examples[batch_start:batch_start + batch_size])

        batches.append(
            (context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch))
    if shuffle:
        # shuffle the batches
        random.shuffle(batches)

    toc = time.time()
    print("Refilling batches took %.2f seconds" % (toc - tic))
    return


def get_batch_generator(word2id, context2id, ans2id, context_path, qn_path, ans_path, batch_size, graph_vocab_class,
                        context_len, question_len, answer_len, discard_long, shuffle=True, output_goal=False):
    
    context_file, qn_file, ans_file = open(context_path, encoding="utf-8"), open(qn_path, encoding="utf-8"), open(ans_path, encoding="utf-8")
    batches = []


    while True:
        if len(batches) == 0:  # add more batches
            refill_batches(batches, word2id, context2id, ans2id, context_file, qn_file, ans_file, batch_size,
                           context_len, question_len, answer_len, discard_long, shuffle=shuffle, output_goal=output_goal)
        if len(batches) == 0:
            break

        # Get next batch. These are all lists length batch_size
        (context_ids, context_tokens, qn_ids, qn_tokens, ans_ids, ans_tokens) = batches.pop(0)

        # Pad context_ids and qn_ids
        qn_ids = padded(qn_ids, question_len)  # pad questions to length question_len
        context_ids = padded(context_ids, context_len)  # pad contexts to length context_len
        ans_ids = padded(ans_ids, answer_len) # pad ans to maximum length

        # Make qn_ids into a np array and create qn_mask
        qn_ids = np.array(list(qn_ids))  # shape (batch_size, question_len)
        qn_mask = (qn_ids != PAD_ID).astype(np.int32)  # shape (batch_size, question_len)

        # Make context_ids into a np array and create context_mask
        context_ids = np.array(list(context_ids))  # shape (batch_size, context_len)
        # context_mask = (context_ids != PAD_ID).astype(np.int32)  # shape (batch_size, context_len)

        # Make ans_ids into a np array and create ans_mask
        ans_ids = np.array(list(ans_ids))
        ans_mask = (ans_ids != PAD_ID).astype(np.int32)
        # print(list(ans_ids), list(context_ids), list(qn_ids))
        # interpret graph as triplets and append the first token
        # if not show_start_tokens:
        # context_embeddings, context_mask = compute_graph_embedding(context_tokens, graph_vocab_class, context_mask.shape[1])
        # else:
        #     context_embeddings, context_mask = compute_graph_embedding(context_tokens, graph_vocab_class, context_mask.shape[1],
        #                                                     np.array([ans_token[0] for ans_token in ans_tokens]))
        
        # Make into a Batch object
        batch = Batch(context_ids, context_tokens, qn_ids, qn_mask, qn_tokens, ans_ids, ans_mask, ans_tokens, batch_size)
        # print(len(batch))
        yield batch

    return

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
# the Model
import math
import time


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class Encoder(nn.Module):
    def __init__(self,hidden_size,embedding_size,keep_prob):
        super(Encoder,self).__init__()
        self.hidden_size = hidden_size
        self.keep_prob = keep_prob
        self.gru = nn.GRU(embedding_size,hidden_size,dropout=1-keep_prob,bidirectional=True)
    
    def forward(self, embedded_inputs):
        # print(embedded_inputs.shape,'\n',embedded_inputs)
        # embedded_inputs.to(device)
        output,hidden = self.gru(embedded_inputs.to(device))
        
        return output
        # output = [seq_len,batch_size,hidden_size*2]
        # hidden = [2,batch_size,hidden_size]
        
        # can add Dropout layer

# class Decoder(nn.Module):
#     def __init__(self, batch_size, hidden_size, tgt_vocab_size, max_decoder_length, embeddings, 
#                 keep_prob, sampling_prob, schedule_embed=False, pred_method='greedy'):
#         self.hidden_size = hidden_size
#         self.projection_layer = nn.Linear(hidden_size,tgt_vocab_size)
#         self.gru = nn.GRU(hidden_size,hidden_size)
#         self.batch_size = batch_size
#         self.embeddings = embeddings
#         self.start_id = SOS_ID
#         self.end_id = PAD_ID
#         self.tgt_vocab_size = tgt_vocab_size
#         self.max_decoder_length = max_decoder_length
#         self.keep_prob = keep_prob
#         self.schedule_embed = schedule_embed
#         self.pred_method = pred_method
#         self.beam_width = 9
#         self.sampling_prob = sampling_prob

#     def forward(self, blended_reps_final, encoder_hidden, decoder_emb_inputs, ans_masks, ans_ids, context_masks):
#         start_ids = ans_ids[:,0]
#         train_output = blended_reps_final
#         context_lengths = torch.Tensor([context_masks.size(1)]*self.batch_size)
#         decoder_lengths = torch.Tensor([context_masks.size(1)]*self.batch_size)

#         # traininghelper vali lines

#         pred_start_ids = ans_ids[:,0]

#         # pred_helper vali lines


class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))
        stdv = 1. / math.sqrt(self.v.size(0))
        self.v.data.normal_(mean=0, std=stdv)

    def forward(self, hidden, encoder_outputs, src_len=None):
        '''
        :param hidden: 
            previous hidden state of the decoder, in shape (layers*directions,B,H)
        :param encoder_outputs:
            encoder outputs from Encoder, in shape (T,B,H)
        :param src_len:
            used for masking. NoneType or tensor in shape (B) indicating sequence length
        :return
            attention energies in shape (B,T)
        '''
        encoder_outputs = encoder_outputs.transpose(0,1).to(device) # [T*B*H]
        max_len = encoder_outputs.size(0)
        this_batch_size = encoder_outputs.size(1)
        H = hidden.repeat(max_len,1,1).transpose(0,1).to(device)
        # print(hidden.shape, encoder_outputs.shape,H.shape)
        attn_energies = self.score(H,encoder_outputs) # compute attention score
        
        if src_len is not None:
            mask = []
            for b in range(src_len.size(0)):
                mask.append([0] * src_len[b].item() + [1] * (encoder_outputs.size(1) - src_len[b].item()))
            mask = cuda_(torch.ByteTensor(mask).unsqueeze(1)) # [B,1,T]
            attn_energies = attn_energies.masked_fill(mask, -1e18)
        
        return F.softmax(attn_energies).unsqueeze(1) # normalize with softmax

    def score(self, hidden, encoder_outputs):
        encoder_outputs = encoder_outputs.transpose(0,1)
        # print(hidden.shape,encoder_outputs.shape)

        energy = F.tanh(self.attn(torch.cat([hidden, encoder_outputs], 2))) # [B*T*2H]->[B*T*H]
        energy = energy.transpose(2,1) # [B*H*T]
        v = self.v.repeat(encoder_outputs.data.shape[0],1).unsqueeze(1) #[B*1*H]
        energy = torch.bmm(v,energy) # [B*1*T]
        return energy.squeeze(1) #[B*T]

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, embed_size, output_size, n_layers=1, dropout_p=0.1):
        super(DecoderRNN, self).__init__()
        # Define parameters
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        # Define layers
        self.embedding_dec = nn.Embedding(output_size, embed_size)
        self.dropout = nn.Dropout(dropout_p)
        self.attn = Attn('concat', hidden_size)
        self.gru = nn.GRU(hidden_size + embed_size, hidden_size, n_layers, dropout=dropout_p)
        #self.attn_combine = nn.Linear(hidden_size + embed_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, word_input, last_hidden, encoder_outputs):
        '''
        :param word_input: === decoder_input
            word input for current time step, in shape (B)
        :param last_hidden:=== decoder_hidden
            last hidden stat of the decoder, in shape (layers*direction*B*H)
        :param encoder_outputs:
            encoder outputs in shape (T*B*H)
        :return
            decoder output
        Note: we run this one step at a time i.e. you should use a outer loop 
            to process the whole sequence
            '''
        # Get the embedding of the current input word (last output word)
        word_input = word_input.long().to(device)
        last_hidden = last_hidden.to(device)
        encoder_outputs = encoder_outputs.to(device)

        word_embedded = self.embedding_dec(word_input).view(1, word_input.size(0), -1) # (1,B,V)
        word_embedded = self.dropout(word_embedded)
        # Calculate attention weights and apply to encoder outputs
        attn_weights = self.attn(last_hidden[-1], encoder_outputs)
        context = attn_weights.bmm(encoder_outputs)  # (B,1,V)
        context = context.transpose(0, 1)  # (1,B,V)
        # Combine embedded input word and attended context, run through RNN
        # print(word_embedded.shape,context.shape)
        # word_embedded = [1,1,100]
        # context = [1,1,128]
        rnn_input = torch.cat((word_embedded, context), 2)
        #rnn_input = self.attn_combine(rnn_input) # use it in case your size of rnn_input is different
        # rnn_input = [1,1,228]
        # print(last_hidden.shape)
        # last_hidden = [150,128]
        output, hidden = self.gru(rnn_input, last_hidden)
        output = output.squeeze(0)  # (1,B,V)->(B,V)
        # context = context.squeeze(0)
        # update: "context" input before final layer can be problematic.
        # output = F.log_softmax(self.out(torch.cat((output, context), 1)))
        output = F.log_softmax(self.out(output))
        # Return final output, hidden state
        return output, hidden



class BasicAttn(nn.Module):
    def __init__(self, keep_prob, key_vec_size, value_vec_size):
        super(BasicAttn,self).__init__()
        self.keep_prob = keep_prob
        self.key_vec_size = key_vec_size
        self.value_vec_size = value_vec_size

    def forward(self, values, values_mask, keys):
        # values = torch.from_numpy(values).float() 
        values_mask = torch.from_numpy(values_mask).float().to(device) 
        # keys = torch.from_numpy(keys).float() 
        # values = values.to(device)
        # keys = keys.to(device)
        attn_logits_mask = torch.unsqueeze(values_mask, 1).to(device) # -> (batch_size, 1, num_values)
        
        w = torch.zeros(self.key_vec_size, self.value_vec_size)
        w = nn.init.xavier_normal_(w).to(device)
        values_t = torch.transpose(values, 0, 1) 
        values_t = torch.transpose(values_t, 1,2)# -> (batch_size, value_vec_size, num_values)
        def fn(a, x):
            return torch.matmul(x, w)

        list_ = [fn(8, keys[i, :,:]) for i in range(keys.shape[0])]
        part_logits = torch.stack(list_)
        # part_logits = torch.Tensor(list_) # (batch_size, num_keys, value_vec)
        # print(values_t.shape)
        attn_logits = torch.bmm(part_logits, values_t).to(device) # -> (batch_size, num_keys, num_values)
        _, attn_dist = masked_softmax(attn_logits, attn_logits_mask, dim = -1)
        # attn_dist = attn_dist.transpose(1,2)
        # print('attn dist: ',attn_dist.shape)
        # print(values.shape,values_t.shape)
        output = torch.matmul(attn_dist, values.transpose(0,1))

        return attn_dist, output


class QAModel(nn.Module):
    def __init__(self, id2word, word2id, emb_matrix, ans2id, id2ans, context2id,hidden_size, embedding_size, tgt_vocab_size,batch_size):
        super(QAModel,self).__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.tgt_vocab_size = tgt_vocab_size
        self.id2word = id2word
        self.word2id = word2id
        self.ans_vocab_size = len(ans2id)
        self.ans2id = ans2id
        self.id2ans = id2ans
        self.batch_size = batch_size
        self.emb_matrix = emb_matrix
        self.context2id = context2id
        self.keep_prob = 0.8
        self.embedding = nn.Embedding(len(context2id),embedding_size)
        self.linear21 = nn.Linear(2*hidden_size,hidden_size)
        self.linear41 = nn.Linear(4*hidden_size,hidden_size)
        self.graph_vocab_class = create_vocab_class(context2id)
        self.context_dimension_compressed = len(self.graph_vocab_class.all_tokens) + len(self.graph_vocab_class.nodes)

        self.encoder = Encoder(self.hidden_size,self.embedding_size, self.keep_prob).to(device)
        self.decoder = DecoderRNN(hidden_size,embedding_size,tgt_vocab_size).to(device)

    def forward(self,qn_ids,context_ids,ans_ids,qn_mask):

        context_embs = self.embedding(context_ids)
        # context_embs.to(device)
        context_hiddens = self.encoder(context_embs)  # (batch_size, context_len, hidden_size*2)

        qn_embs = self.get_embeddings(self.id2word,self.emb_matrix,qn_ids,self.embedding_size, self.batch_size)
        question_hiddens = self.encoder(qn_embs)  # (batch_size, question_len, hidden_size*2)
        # print('question hiddens: ',question_hiddens.shape)
        question_last_hidden = question_hiddens[-1, :, :]
        # question_last_hidden.to(device)
        # print('question last hidden: ',question_last_hidden.shape)
        question_last_hidden = self.linear21(question_last_hidden)
        question_last_hidden = question_last_hidden.unsqueeze(0)
        # Working fine till here

        attn_layer = BasicAttn(self.keep_prob, self.hidden_size * 2, self.hidden_size * 2)
        _, attn_output = attn_layer(question_hiddens, qn_mask, context_hiddens)
        # Concat attn_output to context_hiddens to get blended_reps
        blended_reps = torch.cat((context_hiddens, attn_output), axis=2)  # (batch_size, context_len, hidden_size*4)
        blended_reps_final = self.linear41(blended_reps)
        dec_hidden = question_last_hidden
        # Idhar for loop lagaane ka hai
        decoder_outputs = []
        # print('ans_ids',len(ans_ids))
        ans_ids = torch.tensor(ans_ids)
        ans_ids = ans_ids.transpose(0,1)
        # print(ans_ids.shape)
        tgt_len,_ = ans_ids.shape
        # ans_ids = [50=tgt_len,bsz]
        for idx in range(tgt_len):
            dec_output,dec_hidden = self.decoder(ans_ids[idx,:],dec_hidden,blended_reps_final)
            decoder_outputs.append(dec_output)
            # topk
            # loss add
        return decoder_outputs #, loss
        
        # ----------------------------------- #
        
    def get_embeddings(self,token2id,embed_matrix,input_ids,embed_size,batch_size):
            array = np.zeros((len(input_ids[0]),batch_size,embed_size)) 
            # input_ids = [bsz,src_len]
            # print(input_ids)
            for idx,tokenised_words in enumerate(input_ids):
                # words = [token2id[char_id] for char_id in tokenised_id]
                for word_idx,word in enumerate(tokenised_words):
                    array[word_idx,idx,:] = embed_matrix[word,:]
            vector = torch.from_numpy(array).float()
            # print(vector.size,vector)
            return vector.to(device)



def masked_softmax(logits, masks, dim):
    # print('attn logits: ',logits.shape)
    inf_mask = (1 - masks.type(torch.FloatTensor)) * (-1e30)
    inf_mask = inf_mask.to(device)
    masked_logits = torch.add(logits, inf_mask)
    sm = nn.LogSoftmax(dim)
    softmax_out = sm(masked_logits)
    return masked_logits.to(device), softmax_out.to(device)


In [7]:
qa_model = QAModel(id2word, word2id, emb_matrix, context_vocab, rev_context_vocab, context_vocab, HIDDEN_SIZE, EMBEDDING_SIZE, NO_CLASS,BATCH_SIZE)
qa_model = qa_model.to(device)
# file_handler = logging.FileHandler(os.path.join("./model", "log.txt"))
# logging.getLogger().addHandler(file_handler)

# if LOAD_PREV:
#     qa_model.load_state_dict(torch.load("./data"))

train(qa_model, NUM_EPOCH, "./data/train.graph", "./data/train.instruction", "./data/train.answer", BATCH_SIZE)

  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))


Refilling batches...
Refilling batches took 11.97 seconds




End of 50 batches with loss = 2.8000874519348145
End of 100 batches with loss = 2.5970027446746826
End of 150 batches with loss = 2.5395398139953613
Refilling batches...
Refilling batches took 8.96 seconds
End of 200 batches with loss = 2.315581798553467
End of 250 batches with loss = 2.7311747074127197
End of 300 batches with loss = 2.550208330154419
Refilling batches...
Refilling batches took 8.65 seconds
End of 350 batches with loss = 2.594959020614624
End of 400 batches with loss = 2.375993013381958
End of 450 batches with loss = 2.577575206756592
Refilling batches...
Refilling batches took 1.41 seconds
invalid argument 7: equal number of batches expected at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:493
End of 500 batches with loss = 2.8143985271453857
Refilling batches...
Refilling batches took 0.00 seconds
End of epoch 1  | Loss =  1516.9305906295776
296.8049328327179
Refilling batches...
Refilling batches took 8.44 seconds
End of 50 batches with loss = 2.582042455673217

In [10]:
checkpoint = {'model': QAModel(id2word, word2id, emb_matrix, context_vocab, rev_context_vocab, context_vocab, HIDDEN_SIZE, EMBEDDING_SIZE, NO_CLASS,BATCH_SIZE).to(device),
              'state_dict': qa_model.state_dict(),
              "encoder_optimizer": optim.Adam(qa_model.encoder.parameters(),lr = 0.001).state_dict(),
              "decoder_optimizer": optim.Adam(qa_model.decoder.parameters(),lr = 0.001).state_dict()}

In [11]:
torch.save(checkpoint, 'model.pth')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


In [None]:
from collections import Counter, defaultdict
import string
import re
import argparse
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd

# Interface for summarizing all metrics
# input are two lists of strings of form [start] + [action list] + [goal]
def compute_all_metrics(pred_answer, true_answer):
    # because we only consider the accuracy of the actions, we remove first and last items.
    prediction_str = normalize_answer(" ".join(pred_answer[1:-1]))
    ground_truth_str = normalize_answer(" ".join(true_answer[1:-1]))
    em = exact_match_score(prediction_str, ground_truth_str)
    f1 = f1_score(prediction_str, ground_truth_str)
    ed = edit_distance(prediction_str.split(), ground_truth_str.split())
    if em > int(pred_answer[-1] == true_answer[-1]):
        print("weird thing happens, pred {}, true {}".format(pred_answer, true_answer))
    return f1, em, ed, pred_answer[-1] == true_answer[-1]

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = prediction.split()
    ground_truth_tokens = ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def edit_distance(s1, s2):
    """
    :param s1: list
    :param s2: list
    :return: edit distance of two lists
    """
    if len(s1) < len(s2):
        return edit_distance(s2, s1)

    # len(s1) >= len(s2)
    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[
                             j + 1] + 1  # j+1 instead of j since previous_row and current_row are one character longer
            deletions = current_row[j] + 1  # than s2
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]

def exact_match_score(prediction, ground_truth):
    return prediction == ground_truth

def rough_match_score(prediction, ground_truth):
    prediction = ' '.join(prediction.split(' '))
    ground_truth = ' '.join(ground_truth.split(' '))
    pred_list = prediction.split(' ')
    truth_list = ground_truth.split(' ')
    poss_correct = len(pred_list) == len(truth_list) or \
                   (len(pred_list) > len(truth_list) and pred_list[len(truth_list)] not in ['oor', 'ool'])
    return prediction[: len(ground_truth)] == ground_truth and poss_correct

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(ground_truth, predictions):
    f1_total = em_total = 0
    total = len(ground_truth)
    err_analysis = defaultdict(list)
    assert len(ground_truth) == len(predictions)

    for i in range(total):
        truth = ground_truth[i].strip().split(' ')[1:]
        pred = predictions[i].strip().split(' ')
        f1 = f1_score(predictions[i], " ".join(truth))
        em = exact_match_score(predictions[i], " ".join(truth))
        for j in range(len(truth)):
            err_analysis[j].append(j < len(pred) and truth[j] == pred[j])
        f1_total += f1
        em_total += em
    err_dist = np.zeros([len(err_analysis)])

    for k in err_analysis:
        err_dist[k] = sum(err_analysis[k]) / float(len(err_analysis[k]))
    plt.plot(err_dist)
    plt.xlabel("pos in the answer")
    plt.ylabel("accuracy")
    plt.show()
    exact_match = 100.0 * em_total / total
    f1 = 100.0 * f1_total / total
    print('exact_match: {}, f1: {}'.format(exact_match, f1))
    return

def evaluate_new(ground_truth, predictions):
    """
    :param ground_truth: a list of strings
    :param predictions: a list of strings
    :return: nil, side effect: print out the metrics value.
    """
    assert len(ground_truth) == len(predictions)
    f1_all = 0.0
    em_all = 0.0
    ed_all = 0.0
    gem_all = 0.0
    i = 0
    for (g, p) in zip(ground_truth, predictions):
        i += 1
        # print(i)
        pred_answer = p.strip().split(" ")
        true_answer = g.strip().split(" ")
        true_answer = [true_answer[0]] + true_answer[1::2] + [true_answer[-1]]
        f1, em, ed, gem = compute_all_metrics(pred_answer, true_answer)
        f1_all += f1
        em_all += em
        ed_all += ed
        gem_all += gem

    f1_all /= len(ground_truth)
    em_all /= len(ground_truth)
    ed_all /= len(ground_truth)
    gem_all /= len(ground_truth)
    print("f1 {}, em {}, ed {}, gem {}".format(f1_all, em_all, ed_all, gem_all))
    return

In [None]:
with open("./data/train.answer") as true_file:
    dataset = true_file.readlines()
with open(args.prediction_file) as prediction_file:
    predictions = prediction_file.readlines()
evaluate_new(dataset, predictions)