In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.cuda import FloatTensor, LongTensor

# Seed NumPy's global RNG so NumPy-based randomness is repeatable.
# NOTE(review): only NumPy is seeded in the visible code; torch's RNG is not — confirm this is intended.
np.random.seed(42)

In [2]:
class NetState():
    """Registry of per-component layer states, addressed by component name.

    Components are kept in insertion order; every lookup acts on the first
    component whose ``name`` attribute equals the requested name.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every registered component."""
        self.components = []

    def add_layer(self, component):
        """Register ``component`` after the ones already present."""
        self.components.append(component)

    def _find(self, name):
        """Return the first registered component called ``name``, or ``None``."""
        return next((layer for layer in self.components if layer.name == name), None)

    def get_value_by_name(self, name, index, module_name):
        """Return ``get(index, module_name)`` of the component called ``name``.

        ``None`` is returned when no component has that name.
        """
        layer = self._find(name)
        if layer is None:
            return None
        return layer.get(index, module_name)

    def get_full(self, name, module_name):
        """Return ``get_full(module_name)`` of the component called ``name``.

        ``None`` is returned when no component has that name.
        """
        layer = self._find(name)
        if layer is None:
            return None
        return layer.get_full(module_name)

    def add(self, hidden, name):
        """Hand ``hidden`` to the component called ``name``; do nothing if it is unknown."""
        layer = self._find(name)
        if layer is not None:
            layer.add(hidden)
        
class ComponentLayerState():
    """Hidden states produced by one component, plus one read cursor per consumer.

    ``pos`` maps a consumer's module name to its cursor; ``hiddens`` holds the
    stored states (a list that is appended to, or a single object that is
    replaced wholesale when ``is_solid`` is true).
    """

    def __init__(self, name, is_solid):
        self.name = name
        self.is_solid = is_solid
        self.reset()

    def reset(self):
        """Drop all stored hidden states and all consumer cursors."""
        self.pos = {}
        self.hiddens = []

    def get(self, index, module_name):
        """Advance ``module_name``'s cursor and return the state it now points at.

        The cursor is created at -1 on first use. Unless ``index > 0`` nothing is
        read and ``None`` is returned; ``None`` is also returned once the cursor
        has moved past the last stored state.
        """
        self.pos.setdefault(module_name, -1)
        if not index > 0:
            return None
        self.pos[module_name] += 1
        cursor = self.pos[module_name]
        return self.hiddens[cursor] if cursor < len(self.hiddens) else None

    def get_full(self, module_name):
        """Return all stored states the first time ``module_name`` asks, ``None`` afterwards."""
        if module_name in self.pos:
            return None
        # Mark this consumer as having read everything stored so far.
        self.pos[module_name] = len(self.hiddens)
        return self.hiddens

    def add(self, token):
        """Store ``token``: replace the contents when solid, append otherwise."""
        if not self.is_solid:
            self.hiddens.append(token)
        else:
            self.hiddens = token
            
class InputLayerState(ComponentLayerState):
    """Layer state whose stored hidden states are supplied up front as ``inputs``."""

    def __init__(self, name, is_solid, inputs):
        super().__init__(name=name, is_solid=is_solid)
        # The base constructor leaves ``hiddens`` empty; replace it with the given inputs.
        self.hiddens = inputs

In [3]:
class RNNComputer(nn.Module):
    """Single tanh recurrent step: ``tanh(Linear([hidden, input]))``.

    Note the constructor takes ``hidden_size`` before ``input_size``.
    """

    def __init__(self, hidden_size, input_size):
        super().__init__()

        self._hidden_size = hidden_size
        self._hidden = nn.Linear(hidden_size + input_size, hidden_size)

    def forward(self, state, input_token):
        """Compute the next hidden state.

        Args:
            state: passed through untouched.
            input_token: ``(inputs, hidden)`` pair. ``inputs`` being ``None`` means
                there is nothing to process; ``hidden`` being ``None`` means no
                previous hidden state, in which case zeros of shape
                ``(inputs.size(0), hidden_size)`` are used.

        Returns:
            ``(state, new_hidden)``, or ``(state, None)`` when ``inputs`` is ``None``.
        """
        step_input, prev_hidden = input_token
        if step_input is None:
            return state, None
        if prev_hidden is None:
            prev_hidden = step_input.new_zeros(step_input.size(0), self._hidden_size)
        joined = torch.cat((prev_hidden, step_input), -1)
        return state, torch.tanh(self._hidden(joined))
    
class RNNSolidComputer(nn.Module):
    """Runs a whole sequence through a bidirectional LSTM in one call."""

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self._hidden_size = hidden_size
        self._rnn = nn.LSTM(input_size, hidden_size, bidirectional=True)

    def forward(self, state, input_token):
        """Return ``(state, lstm_output)``; ``(state, None)`` when there is no input.

        The LSTM always starts from zero hidden and cell states, sized from
        ``input_token.shape[1]``; the final LSTM state is discarded.
        """
        if input_token is None:
            return state, None
        zeros_shape = (2, input_token.shape[1], self._hidden_size)
        h0 = input_token.new_zeros(zeros_shape)
        c0 = input_token.new_zeros(zeros_shape)
        output, _final_state = self._rnn(input_token, (h0, c0))
        return state, output
    
class TaggerComputer(nn.Module):
    """Applies one linear layer (``input_size`` -> ``hidden_size``) to its input."""

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self._hidden_size = hidden_size
        self._hidden = nn.Linear(input_size, hidden_size)

    def forward(self, state, input_token):
        """Return ``(state, linear(input_token))``, or ``(state, None)`` for a ``None`` input."""
        projected = None if input_token is None else self._hidden(input_token)
        return state, projected
    
class EmbeddingComputer(nn.Module):
    """Looks ids up in an embedding table of ``vocab_size`` rows and ``hidden_size`` columns."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()

        self._embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, state, input_token):
        """Return ``(state, embeddings)``, or ``(state, None)`` for a ``None`` input."""
        embedded = None if input_token is None else self._embed(input_token)
        return state, embedded

In [4]:
class RNNRecurrent():
    """Fetches, step by step, the next input and this component's next own hidden state."""

    def __init__(self, input_name, self_name):
        super().__init__()

        self._input_name = input_name
        self._self_name = self_name

    def get(self, state, net):
        """Return ``(inputs, hidden)`` read from ``net`` on behalf of this component.

        Both values come from ``net.get_value_by_name(..., 1, self_name)``: the
        first from the input layer, the second from this component's own layer.
        ``state`` is not used.
        """
        reader = self._self_name
        step_input = net.get_value_by_name(self._input_name, 1, reader)
        own_hidden = net.get_value_by_name(reader, 1, reader)
        return step_input, own_hidden
    
class RNNSolidRecurrent():
    """Fetches the complete contents of the input layer in a single read."""

    def __init__(self, input_name, self_name):
        super().__init__()

        self._input_name = input_name
        self._self_name = self_name

    def get(self, state, net):
        """Return ``net.get_full(input_name, self_name)``; ``state`` is not used."""
        return net.get_full(self._input_name, self._self_name)
    
class TaggerRecurrent():
    """Fetches the complete input layer, stacking it into one tensor when it is a list."""

    def __init__(self, input_name, self_name):
        super().__init__()

        self._input_name = input_name
        self._self_name = self_name

    def get(self, state, net):
        """Return the full input layer for this component.

        A list result is turned into a single tensor with ``torch.stack`` and
        flagged with ``requires_grad_()``; anything else (including ``None``) is
        returned unchanged. ``state`` is not used.
        """
        layer_content = net.get_full(self._input_name, self._self_name)
        if not isinstance(layer_content, list):
            return layer_content
        stacked = torch.stack(layer_content)
        stacked.requires_grad_()
        return stacked

In [5]:
class TBRU(nn.Module):
    """One network unit: a ``recurrent`` that selects the input and a ``computer`` that processes it.

    Attributes:
        name: key under which this unit's outputs are stored in the net state.
        is_solid: stored as given; read by code outside this class.
        state_shape: stored as given; read by code outside this class.
    """

    def __init__(self, name, recurrent, computer, state_shape, is_solid):
        super().__init__()

        self.name = name
        self.is_solid = is_solid
        self.state_shape = state_shape
        self._rec = recurrent
        self._comp = computer

    def forward(self, state, net):
        """Run one step and record a non-``None`` result in ``net`` under this unit's name.

        Returns:
            ``(state, hidden)`` exactly as produced by the computer.
        """
        step_input = self._rec.get(state, net)
        state, hidden = self._comp(state, step_input)
        if hidden is None:
            return state, hidden
        net.add(hidden, self.name)
        return state, hidden

In [6]:
class MasterComponent(nn.Module):
    """Container that runs its registered units one after another, in registration order."""

    def __init__(self):
        super().__init__()

        self.comp = []

    def add_component(self, component):
        """Register ``component`` as a sub-module under ``component.name``."""
        self.add_module(component.name, component)

    def prepare_net(self, net):
        """Add one empty ``ComponentLayerState`` per registered unit to ``net`` and return it."""
        for unit in self._modules.values():
            net.add_layer(ComponentLayerState(unit.name, unit.is_solid))
        return net

    def forward(self, net):
        """Run every unit repeatedly until it reports ``None`` as its hidden output.

        Each unit starts from a zero-filled NumPy state of shape ``state_shape``
        and is called at least once. Returns ``net``.
        """
        for unit in self._modules.values():
            state = np.zeros(unit.state_shape)
            while True:
                state, hidden = unit(state, net)
                if hidden is None:
                    break

        return net

class DRAGNNMaster():
    """Front end owning the trainable ``MasterComponent`` and the ``NetState`` of the current pass.

    The master component is moved to CUDA on construction.
    """

    def __init__(self):
        super().__init__()

        self.net = NetState()
        self.main = MasterComponent().cuda()

    def add_component(self, component):
        """Register ``component`` with the underlying master component."""
        self.main.add_component(component)

    def build_net(self, input_layer):
        """Start a fresh net state holding ``input_layer`` plus one empty layer per component."""
        if self.net is not None:
            self.net = NetState()
        self.net.reset()
        self.net.add_layer(input_layer)
        self.main.prepare_net(self.net)

    def forward(self, input_layer):
        """Run all components on ``input_layer`` and return the last layer's output.

        A solid last layer is returned as stored; otherwise its list of hidden
        states is stacked into one tensor and flagged with ``requires_grad_()``.
        """
        self.build_net(input_layer)
        self.net = self.main(self.net)
        final_layer = self.net.components[-1]
        hiddens = final_layer.hiddens
        if final_layer.is_solid:
            return hiddens
        stacked = torch.stack(hiddens)
        stacked.requires_grad_()
        return stacked

    def save_model(self, filename):
        """Save the master component's ``state_dict`` to ``filename``."""
        weights = self.main.state_dict()
        torch.save(weights, filename)

    def load_model(self, filename):
        """Load a ``state_dict`` previously written by ``save_model``."""
        weights = torch.load(filename)
        self.main.load_state_dict(weights)

    def save_checkpoint(self, epoch, optimizer, filename='checkpoint.pth.tar'):
        """Save ``epoch + 1`` together with the model and optimizer state dicts."""
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': self.main.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, filename)

In [7]:
# Sentence delimiters. Vocab raises if its vocab file contains either of them.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'

# Special tokens. Vocab reserves ids for them in the order [UNK]=0, [PAD]=1, [START]=2, [STOP]=3.
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences


class Vocab(object):
  """Bidirectional word <-> id mapping loaded from a vocabulary file.

  Ids 0-3 are reserved for [UNK], [PAD], [START] and [STOP]; words read from
  the file receive consecutive ids after them, in file order.
  """

  def __init__(self, vocab_file, max_size):
    """Build the vocabulary.

    Args:
      vocab_file: path to a text file with two whitespace-separated fields per
        line; only the first field (the word) is used. Lines that do not have
        exactly two fields are skipped with a warning.
      max_size: stop reading as soon as the vocabulary (special tokens
        included) holds at least this many entries; 0 means no limit.

    Raises:
      Exception: if the file contains a reserved token or a duplicated word.
    """
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0 # keeps track of total number of words in the Vocab

    # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
    for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
      self._word_to_id[w] = self._count
      self._id_to_word[self._count] = w
      self._count += 1

    # Read the vocab file and add words up to max_size
    with open(vocab_file, 'r') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          print('Warning: incorrectly formatted line in vocabulary file: %s\n' % line)
          continue
        w = pieces[0]
        if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
          raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
        if w in self._word_to_id:
          raise Exception('Duplicated word in vocabulary file: %s' % w)
        self._word_to_id[w] = self._count
        self._id_to_word[self._count] = w
        self._count += 1
        if max_size != 0 and self._count >= max_size:
          print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
          break

    print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count-1]))

  def word2id(self, word):
    """Return the id of `word`, or the [UNK] id when the word is not in the vocabulary."""
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN]
    return self._word_to_id[word]

  def id2word(self, word_id):
    """Return the word stored under `word_id`.

    Raises:
      ValueError: if `word_id` is not a known id.
    """
    if word_id not in self._id_to_word:
      raise ValueError('Id not found in vocab: %d' % word_id)
    return self._id_to_word[word_id]

  def size(self):
    """Return the number of entries, special tokens included."""
    return self._count

  def write_metadata(self, fpath):
    """Write every word, in id order, to `fpath` as a one-column tab-separated file (no header row)."""
    # Local import: `csv` is not imported anywhere in the visible notebook.
    import csv

    print("Writing word embedding metadata file to %s..." % (fpath))
    # newline="" lets the csv writer control line endings itself, as the csv module asks for.
    with open(fpath, "w", newline="") as f:
      fieldnames = ['word']
      writer = csv.DictWriter(f, delimiter="\t", fieldnames=fieldnames)
      # `range` rather than `xrange`: the latter does not exist on Python 3.
      for i in range(self.size()):
        writer.writerow({"word": self._id_to_word[i]})
        


def article2ids(article_words, vocab):
    """Map the whitespace-separated words of ``article_words`` to vocabulary ids.

    Words that resolve to the [UNK] id contribute no id at all; each distinct
    one is recorded once, in order of first appearance.

    Returns:
        ``(ids, oovs)``: the ids of the in-vocabulary words in article order, and
        the list of distinct out-of-vocabulary words.
    """
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    ids, oovs = [], []
    for token in article_words.split():
        token = str(token)
        token_id = vocab.word2id(token)
        if token_id != unk_id:
            ids.append(token_id)
        elif token not in oovs:
            # First sighting of this OOV word: remember it, emit no id for it.
            oovs.append(token)
    return ids, oovs


def abstract2ids(abstract_words, vocab, article_oovs):
  """Map an iterable of abstract words to vocabulary ids.

  In-vocabulary words map to their id. A word that resolves to [UNK] is
  emitted as the [UNK] id unless it is listed in `article_oovs`, in which case
  no id is emitted for it.
  """
  unk_id = vocab.word2id(UNKNOWN_TOKEN)
  ids = []
  for word in abstract_words:
    word_id = vocab.word2id(word)
    if word_id != unk_id:
      ids.append(word_id)
    elif word not in article_oovs:
      # Out-of-article OOV: represent it with the UNK token id.
      ids.append(unk_id)
  return ids

In [8]:
import struct
from tensorflow.core.example import example_pb2

def example_gen(filename):
    """Yield ``(article, abstract)`` text pairs from a length-prefixed binary file.

    Each record is an 8-byte length (struct format ``'q'``) followed by that
    many bytes, which are parsed with ``example_pb2.Example.FromString``. All
    records are read, and the file is closed, before the first pair is yielded.

    Args:
        filename: path of the binary file to read.

    Yields:
        Tuples of two ``str``: the UTF-8 decoded first value of the ``'article'``
        and ``'abstract'`` features of each example.

    Raises:
        struct.error: if a length prefix or a record is truncated.
    """
    examples = []
    # The context manager closes the file even if a record fails to parse.
    with open(filename, 'rb') as reader:
        while True:
            len_bytes = reader.read(8)
            if not len_bytes:
                break  # finished reading this file
            str_len = struct.unpack('q', len_bytes)[0]
            # Unpacking with an explicit size fails loudly on a short read.
            example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
            examples.append(example_pb2.Example.FromString(example_str))

    for e in examples:
        article_text = e.features.feature['article'].bytes_list.value[0]
        abstract_text = e.features.feature['abstract'].bytes_list.value[0]
        yield (article_text.decode('utf-8'), abstract_text.decode('utf-8'))
        

  from ._conv import register_converters as _register_converters


KeyboardInterrupt: 

In [None]:
# Notebook-wide vocabulary; the batching and target helpers below read this global directly.
vocab = Vocab("finished_files/vocab", 50000)

In [None]:
def build_tagger_model(vocab_size=50000, embed_size=1000, hidden_size=500):
    """Assemble the tagger: embedding -> bidirectional LSTM -> linear layer with one output.

    Args:
        vocab_size: number of rows of the embedding table.
        embed_size: embedding width, which is also the LSTM input width.
        hidden_size: LSTM hidden width per direction.

    Returns:
        A ``DRAGNNMaster`` holding the components "embed", "rnn" and
        "extractive2", each moved to CUDA. With the default arguments the layer
        sizes are 50000/1000, 1000/500 and 1000/1.
    """
    master = DRAGNNMaster()
    master.add_component(
        TBRU("embed", TaggerRecurrent("input", "embed"),
             EmbeddingComputer(vocab_size, embed_size), (1,), True).cuda())
    master.add_component(
        TBRU("rnn", RNNSolidRecurrent("embed", "rnn"),
             RNNSolidComputer(embed_size, hidden_size), (1,), True).cuda())

    # The LSTM is bidirectional, so each position carries 2 * hidden_size features.
    master.add_component(
        TBRU("extractive2", TaggerRecurrent("rnn", "extractive2"),
             TaggerComputer(2 * hidden_size, 1), (1,), True).cuda())
    return master

# Build the model (its components are moved to CUDA inside build_tagger_model),
# then display the CUDA memory currently allocated as the cell's output.
model = build_tagger_model()
torch.cuda.memory_allocated()

In [None]:
def add_padding(articles, targets):
    """Pad every article and its target list, in place, to the longest article's length.

    Articles are padded with the [PAD] id taken from the global ``vocab``,
    targets with 0; each target is padded by the same amount as its article.

    Returns:
        ``(np.array(articles).T, np.array(targets))``: the article array is
        transposed, the target array is not.
    """
    longest = max(len(article) for article in articles)
    pad_id = vocab.word2id(PAD_TOKEN)

    for row, article in enumerate(articles):
        # Measure the gap before extending the article, so both lists grow equally.
        missing = longest - len(article)
        targets[row].extend([0] * missing)
        article.extend([pad_id] * missing)
    return np.array(articles).T, np.array(targets)

def iterate_batches(filename, batch_size):
    """Yield padded ``(articles, targets)`` batches read from ``filename``.

    For every example the article and the abstract are converted to ids with
    ``article2ids``; the target holds, per article id, 1 if that id also occurs
    in the abstract and is not the [UNK] id, else 0. Up to ``batch_size``
    examples are grouped and passed through ``add_padding``; the last batch may
    be smaller.

    Only exhaustion of the example stream ends the iteration; any other error
    raised while reading or converting an example propagates to the caller
    instead of being silently treated as end of data.
    """
    generator = example_gen(filename)
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    while True:
        articles = []
        targets = []
        for _slot in range(batch_size):
            try:
                article_text, abstract_text = next(generator)
            except StopIteration:
                break
            article_ids, _ = article2ids(article_text, vocab)
            abstract_ids, _ = article2ids(abstract_text, vocab)
            # Set membership keeps the labelling linear in the article length.
            abstract_id_set = set(abstract_ids)
            target = [int(tok in abstract_id_set and tok != unk_id) for tok in article_ids]

            articles.append(article_ids)
            targets.append(target)
        if len(articles) == 0:
            break
        yield add_padding(articles, targets)

In [None]:
def calculate_mask(articles):
    """Return a boolean array shaped like ``articles``: False at [PAD] ids, True elsewhere.

    The [PAD] id is looked up in the global ``vocab``.
    """
    pad_id = vocab.word2id(PAD_TOKEN)
    is_padding = (articles == pad_id)
    return np.logical_not(is_padding)

def get_target(self, article, abstract=None):
    """Label each article id with 1 if it also occurs in the abstract and is not [UNK], else 0.

    This is a module-level function whose first parameter is nevertheless named
    ``self``. To keep existing three-argument calls working while also
    supporting the natural two-argument call, both forms are accepted:

        get_target(article, abstract)
        get_target(ignored, article, abstract)

    Args:
        self: ignored in the three-argument form; the article in the two-argument form.
        article: sequence of article ids (the abstract in the two-argument form).
        abstract: sequence of abstract ids; omitted in the two-argument form.

    Returns:
        A list of 0/1 ints, one per article id.
    """
    if abstract is None:
        # Two-argument call: the values arrived one slot to the left.
        article, abstract = self, article
    if len(article) == 0:
        return []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    abstract_ids = set(abstract)
    return [int(tok in abstract_ids and tok != unk_id) for tok in article]

def push_abs_ptr(article, abstract, i, abs_ptr):
    """Advance ``abs_ptr`` to the next abstract id that can still be matched in the article.

    An abstract id is skipped when it does not occur in ``article[i+1:]`` or
    when it equals the id the global ``vocab`` returns for the sentence-start,
    sentence-end or [UNK] token.

    Args:
        article: sequence of article ids.
        abstract: sequence of abstract ids.
        i: index of the last consumed article position (-1 for none).
        abs_ptr: current position in ``abstract``.

    Returns:
        The new position; ``len(abstract)`` (or the unchanged ``abs_ptr`` if it
        was already past the end) when nothing matchable remains.
    """
    if abs_ptr >= len(abstract):
        return abs_ptr
    remaining = article[i + 1:]
    # SENTENCE_END is the constant this file defines; SENTENCE_STOP does not exist.
    skipped_ids = (vocab.word2id(SENTENCE_START),
                   vocab.word2id(SENTENCE_END),
                   vocab.word2id(UNKNOWN_TOKEN))
    while abs_ptr < len(abstract) and (abstract[abs_ptr] not in remaining
                                       or abstract[abs_ptr] in skipped_ids):
        abs_ptr += 1
    return abs_ptr

def get_target2(self, article, abstract=None):
    """Label article ids by matching the abstract against the article in order.

    Walks the article once while a pointer walks the abstract (advanced with
    ``push_abs_ptr``). An article position is labelled 1 when its id equals the
    abstract id currently pointed at, otherwise 0; once the abstract is
    exhausted the remaining positions are labelled 0.

    This is a module-level function whose first parameter is named ``self``;
    both ``get_target2(article, abstract)`` and
    ``get_target2(ignored, article, abstract)`` are accepted.

    Returns:
        A list of 0/1 ints with exactly ``len(article)`` entries.
    """
    if abstract is None:
        # Two-argument call: the values arrived one slot to the left.
        article, abstract = self, article

    abs_ptr = push_abs_ptr(article, abstract, -1, 0)
    if abs_ptr >= len(abstract):
        # Nothing in the abstract can be matched at all.
        return [0] * len(article)

    target = []
    for i, art in enumerate(article):
        if art == abstract[abs_ptr]:
            target.append(1)
            abs_ptr = push_abs_ptr(article, abstract, i, abs_ptr + 1)
            if abs_ptr >= len(abstract):
                # i + 1 labels exist so far; fill up the positions after i.
                target.extend([0] * (len(article) - i - 1))
                break
        else:
            target.append(0)
    return target


class Batcher():
    """Reads a whole example file on construction and keeps its padded batches in memory.

    Attributes:
        batch_size: maximum number of examples per batch.
        batches: list of ``(articles, targets, mask)`` tuples, where ``articles``
            and ``targets`` come from ``add_padding`` and ``mask`` from
            ``calculate_mask(articles)``.
    """

    def __init__(self, filename, batch_size):
        self.batch_size = batch_size
        generator = example_gen(filename)

        self.batches = []
        while True:
            articles = []
            targets = []
            for _slot in range(batch_size):
                try:
                    article_text, abstract_text = next(generator)
                except StopIteration:
                    # Only the end of the data stops batching; real errors propagate.
                    break
                article_ids, _ = article2ids(article_text, vocab)
                abstract_ids, _ = article2ids(abstract_text, vocab)
                # get_target's first parameter is an unused `self`; pass a placeholder for it.
                target = get_target(None, article_ids, abstract_ids)

                articles.append(article_ids)
                targets.append(target)
            if len(articles) == 0:
                break
            articles, targets = add_padding(articles, targets)
            mask = calculate_mask(articles)
            self.batches.append((articles, targets, mask))

    def generator(self):
        """Yield the stored batches in the order they were read."""
        yield from self.batches

In [None]:
def calc_f1(tp, fp, tn, fn):
    """Return the F1 score for the given confusion counts.

    Args:
        tp, fp, fn: true-positive, false-positive and false-negative counts.
        tn: true-negative count; accepted for symmetry but not used by F1.

    Returns:
        ``2 * precision * recall / (precision + recall)``, or ``0.0`` when the
        score is undefined (no predicted positives, no actual positives, or no
        true positives) instead of dividing by zero.
    """
    predicted_positive = tp + fp
    actual_positive = tp + fn
    if predicted_positive == 0 or actual_positive == 0:
        return 0.0
    precision = tp / predicted_positive
    recall = tp / actual_positive
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def precalc_f1(articles_tokens, articles, target):
    """Count confusion-matrix entries over the non-padding positions of a batch.

    Args:
        articles_tokens: padded id array; ``calculate_mask`` of it, transposed,
            selects the positions that are counted.
        articles: model scores (despite the name); a position counts as
            predicted positive when its score is greater than 0.5.
        target: 0/1 labels.
        (assumes ``articles``, ``target`` and the transposed mask broadcast
        against each other — TODO confirm shapes against the caller)

    Returns:
        ``(tp, fp, tn, fn)`` where fp is "predicted positive, labelled negative"
        and fn is "predicted negative, labelled positive".
    """
    mask = calculate_mask(articles_tokens).T
    predicted = (articles > 0.5)
    not_predicted = np.logical_not(predicted)
    negative = np.logical_not(target)
    tp = (predicted * target * mask).sum()
    fp = (predicted * negative * mask).sum()
    tn = (not_predicted * negative * mask).sum()
    fn = (not_predicted * target * mask).sum()
    return tp, fp, tn, fn

In [None]:
import math
import time

def do_epoch(model, criterion, data, batch_size, optimizer=None):
    """Run one pass over ``data`` and return ``(summed batch losses, F1 over the pass)``.

    The pass trains when an ``optimizer`` is given and only evaluates (gradients
    disabled, ``model.main`` in eval mode) otherwise. A progress line with the
    current batch loss and the running F1 is printed after every batch.

    Args:
        model: object exposing ``main`` (an ``nn.Module``) and ``forward``.
        criterion: loss called as ``criterion(logits.transpose(0, 1), targets)``.
        data: file name handed to ``iterate_batches``.
        batch_size: number of examples per batch.
        optimizer: optimizer to step after each batch, or ``None`` to evaluate.
    """
    training = optimizer is not None
    model.main.train(training)

    total_loss = 0.
    tp, fp, tn, fn = 0, 0, 0, 0

    with torch.autograd.set_grad_enabled(training):
        for step, (token_ids, labels) in enumerate(iterate_batches(data, batch_size)):
            X_batch = LongTensor(token_ids)
            y_batch = FloatTensor(labels)
            net_input = InputLayerState("input", False, X_batch)
            logits = model.forward(net_input).squeeze(-1)

            loss = criterion(logits.transpose(0, 1), y_batch)
            total_loss += loss.item()

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            scores = logits.cpu().detach().numpy().T
            batch_tp, batch_fp, batch_tn, batch_fn = precalc_f1(token_ids, scores, labels)
            tp, fp, tn, fn = tp + batch_tp, fp + batch_fp, tn + batch_tn, fn + batch_fn
            running_f1 = calc_f1(tp, fp, tn, fn)
            print('\r[{}]: Loss = {:.4f}, F1 = {:.4f}'.format(step, loss.item(), running_f1), end='')

    return total_loss, calc_f1(tp, fp, tn, fn)

def fit(model, criterion, optimizer, train_data, epochs_count=1,
        batch_size=32, val_data=None, val_batch_size=None):
    """Train for ``epochs_count`` epochs, optionally validating after each one.

    After every epoch one summary line is printed with the epoch time, the
    training loss and the training F1; when ``val_data`` is given the
    validation loss and validation F1 are appended to that line.

    Args:
        model, criterion, optimizer: passed through to ``do_epoch``.
        train_data: training file name.
        epochs_count: number of epochs to run.
        batch_size: training batch size.
        val_data: validation file name, or ``None`` to skip validation.
        val_batch_size: validation batch size; defaults to ``batch_size``.
    """
    if val_data is not None and val_batch_size is None:
        val_batch_size = batch_size

    for epoch in range(epochs_count):
        start_time = time.time()
        train_loss, train_f1 = do_epoch(model, criterion, train_data, batch_size, optimizer)

        output_info = '\rEpoch {} / {}, Epoch Time = {:.2f}s: Train Loss = {:.4f}: F1-Score = {:.4f}'
        metrics = [train_loss, train_f1]
        if val_data is not None:
            # Separate names keep the validation F1 from replacing the training F1.
            val_loss, val_f1 = do_epoch(model, criterion, val_data, val_batch_size, None)
            output_info += ', Val Loss = {:.4f}, Val F1-Score = {:.4f}'
            metrics += [val_loss, val_f1]

        epoch_time = time.time() - start_time
        print(output_info.format(epoch + 1, epochs_count, epoch_time, *metrics))

In [None]:
# Mean-squared-error loss between the model outputs and the 0/1 targets; Adam with its default settings.
criterion = nn.MSELoss().cuda()
optimizer = optim.Adam(model.main.parameters())

# Train for 50 epochs on the full training file, without validation.
fit(model, criterion, optimizer, epochs_count=50, batch_size=32, train_data="finished_files/train.bin",
    val_data=None, val_batch_size=32)

In [None]:
# Persist the trained weights (the state_dict of model.main) to the file "first_try".
model.save_model("first_try")

In [None]:
# Continue training the same model and optimizer for 50 more epochs on one training chunk,
# validating on val_000.bin after every epoch.
fit(model, criterion, optimizer, epochs_count=50, batch_size=32, train_data="finished_files/chunked/train_000.bin",
    val_data="finished_files/chunked/val_000.bin", val_batch_size=32)