# Headline Generation - PyTorch Implementation

## Imports and Config

*This notebook does not actually use fastai; it is plain PyTorch. fastai could be imported later to explore its alternatives for doing the same things.*

In [None]:
import sys
import random
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import sys
import time
import json
import pprint
from random import randint
pp = pprint.PrettyPrinter(indent=4)

In [None]:
# Load the experiment configuration: data/dictionary/output paths, vocabulary
# size and the special-token strings used throughout the notebook.
with open("./gigaword_attn_config.json") as f:
    config = json.load(f)

## Data Preprocessing

In [None]:
# `counter` and `_dict` are not used in this cell; `_dict` is overwritten by the
# dictionary loaded further down (presumably leftovers of the rebuild step).
counter = 0
_dict = {}

path = config["data_path"]
input_path = config["training"]["inputs"]
output_path = config["training"]["outputs"]

validation_input_path = config["validation"]["inputs"]
validation_output_path = config["validation"]["outputs"]


def _read_lines(rel_path):
    """Return every line (trailing newline included) of the file ``path + rel_path``."""
    with open(path + rel_path, "r", encoding="utf-8") as fh:
        return fh.readlines()


# One example per line; generateBatch pairs source and target by line index.
inputs = _read_lines(input_path)
outputs = _read_lines(output_path)
val_inputs = _read_lines(validation_input_path)
val_outputs = _read_lines(validation_output_path)

print("Training Samples (x,y):", len(inputs), len(outputs))
print("Validation Samples (x,y):", len(val_inputs), len(val_outputs))

assert len(inputs) == len(outputs), "training source/target files are not line-aligned"
assert len(val_inputs) == len(val_outputs), "validation source/target files are not line-aligned"

### Rebuild Dictionary

### ...or load from disk instead

In [None]:
import pickle
# NOTE: pickle.load can execute arbitrary code - only load dictionary files you created.
with open(config["dictionary_path"], "rb") as f:
    _dict = pickle.load(f)  # presumably word -> frequency count; verify against the builder

In [None]:
# Mapping -> list of (word, value) pairs ordered by value, largest first
# (value presumably being the word frequency). Not idempotent: re-running needs the dict.
_dict = sorted(_dict.items(), key=lambda pair: pair[1], reverse=True)

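Set `config["vocab_size"]` to `-1` to use every word in the vocabulary; any other value keeps only the ids below it (the special tokens plus the most frequent words) and maps every other word to the OOV token.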

In [None]:
# Prepend the special tokens so they receive the lowest ids (0, 1, ...).
# NOTE(review): if a special token also occurs in the loaded dictionary it
# appears twice in r_dict, and f_dict keeps the *later* (frequency-ranked) id.
add_tokens = config["dictionary_tokens"]
_dict = [(token, 1) for token in add_tokens] + _dict

# word -> id (later duplicates win, as with a plain assignment loop) and id -> word.
f_dict = {word: idx for idx, (word, _) in enumerate(_dict)}
r_dict = [word for word, _ in _dict]

# Ids >= vocab_size are mapped to OOV. Clamp to the real dictionary size so the
# model can never predict an id that has no word in r_dict.
vocab_size = len(r_dict) if config["vocab_size"] == -1 else min(config["vocab_size"], len(r_dict))

In [None]:
print(str(vocab_size))  # effective vocabulary size used for embeddings and outputs

## Scaffolding

In [None]:
# Peek at the lowest ids: the special tokens followed by the most frequent words.
# Slicing (unlike indexing) does not fail when the vocabulary is smaller than 34.
print(r_dict[:34])

In [None]:
# Special-token *strings* from the config; look their ids up through f_dict.
oov_token, padding_token = config["oov_token"], config["padding_token"]

def generateBatch(x_source, y_source, input_ts=30, output_ts = 10, bs=64,
                  word_index=None, max_vocab=None, oov=None, pad=None):
    """Sample a random batch of (source, target) pairs and encode them as id arrays.

    Samples are drawn uniformly with replacement; ``x_source[l]`` is paired with
    ``y_source[l]``. Source ids are truncated to ``input_ts`` and left-padded;
    target ids are right-padded / truncated to ``output_ts``.

    For every target position the copy targets record whether the word also
    occurs in the (truncated, padded) source and, if so, the index of its first
    occurrence. OOV words and positions added as padding are never marked as
    copies and point at ``input_ts - 1``.

    Args:
        x_source, y_source: aligned lists of whitespace-tokenised strings.
        input_ts, output_ts: fixed source / target lengths.
        bs: number of samples in the batch.
        word_index: token -> id mapping (defaults to the global ``f_dict``).
        max_vocab: ids >= this become the OOV id (defaults to ``vocab_size``).
        oov, pad: OOV / padding token strings (default to ``oov_token`` /
            ``padding_token``).

    Returns:
        ``(x, y, u, u_ind)`` integer numpy arrays of shapes ``(bs, input_ts)``,
        ``(bs, output_ts)``, ``(bs, output_ts)`` (0/1 copy flag) and
        ``(bs, output_ts)`` (source index to copy from).
    """
    word_index = f_dict if word_index is None else word_index
    max_vocab = vocab_size if max_vocab is None else max_vocab
    oov = oov_token if oov is None else oov
    pad = padding_token if pad is None else pad

    oov_id = word_index[oov]
    pad_id = word_index[pad]
    dont_copy_index = input_ts - 1

    def encode(line):
        # Unknown tokens and ids beyond the vocabulary cut-off both become OOV.
        ids = [word_index.get(tok, oov_id) for tok in line.split()]
        return [t if t < max_vocab else oov_id for t in ids]

    x_, y_, u_, u_ind = [], [], [], []
    while len(x_) < bs:
        l = randint(0, len(x_source) - 1)
        x = encode(x_source[l])[:input_ts]
        x = [pad_id] * (input_ts - len(x)) + x  # left-pad to input_ts
        y = encode(y_source[l])

        _u, _u_ind = [], []
        for word in y:
            # Compare against the OOV *id*: y holds ids, not token strings.
            if word != oov_id and word in x:
                _u.append(1)
                _u_ind.append(x.index(word))
            else:
                _u.append(0)
                _u_ind.append(dont_copy_index)

        n_pad = output_ts - len(y)
        if n_pad > 0:
            y = y + [pad_id] * n_pad
            _u = _u + [0] * n_pad
            _u_ind = _u_ind + [dont_copy_index] * n_pad

        x_.append(x)
        y_.append(y[:output_ts])
        u_.append(_u[:output_ts])
        u_ind.append(_u_ind[:output_ts])

    return np.array(x_), np.array(y_), np.array(u_), np.array(u_ind)

In [None]:
# Sanity-check one validation batch. Copy indices are source positions, so they
# must lie below input_ts (the "don't copy" sentinel is input_ts - 1).
x, y, u, ind = generateBatch(val_inputs, val_outputs, input_ts=20, output_ts=10)
assert np.max(ind) < 20
x.shape, y.shape, u.shape, ind.shape, np.max(ind)

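The loss function **flattens the batch and timestep dimensions** and computes a loss for every predicted word (the criterion then reduces them, e.g. `nn.NLLLoss` averages with its default `reduction='mean'`).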

In [None]:
def lossfn_multi(outputs, acts, criterion, input_ts = 30):
    """Word-level loss over every (timestep, sample) pair.

    Args:
        outputs: decoder log-probabilities of shape (output_ts, batch, vocab),
            as returned first by customGRU.forward.
        acts: gold word ids of shape (batch, output_ts).
        criterion: loss applied to the flattened pairs, e.g. ``nn.NLLLoss()``.
        input_ts: unused; kept for backward compatibility.

    Returns:
        Whatever ``criterion`` returns (a scalar tensor for mean reduction).
    """
    # Targets go time-major, (output_ts, batch), so rows line up with `outputs`.
    targets = acts.transpose(0, 1).reshape(-1)
    scores = outputs.reshape(-1, outputs.size(-1))
    return criterion(scores, targets)

In [None]:
def validate(model, lossfn, criterion, num_batches = 10, bs = 128, output_ts=10):
    """Average ``lossfn`` over ``num_batches`` random validation batches of size ``bs``.

    Runs without autograd and in eval mode (dropout off); the model's previous
    train/eval mode is restored afterwards, even if an error occurs.

    Args:
        model: the customGRU being evaluated.
        lossfn: callable ``(log_probs, targets, criterion) -> scalar tensor``.
        criterion: passed through to ``lossfn``.
        num_batches: number of batches to average over.
        bs: validation batch size.
        output_ts: number of decoder steps.

    Returns:
        The mean loss as a Python float.
    """
    was_training = model.training
    model.eval()
    t_loss = 0.0
    try:
        with torch.no_grad():
            for _ in range(num_batches):
                x, y, _u, _u_ind = generateBatch(val_inputs, val_outputs, output_ts=output_ts, bs=bs)
                y = torch.LongTensor(y).cuda()
                h = model.reinitInputHiddenState(bs)
                log_probs, _atts, _uts = model(torch.from_numpy(x).cuda(), h, output_ts)
                t_loss += lossfn(log_probs, y, criterion).item()
    finally:
        model.train(was_training)
    return t_loss / num_batches

In [None]:
def trainBatch(x,y,u,u_indices, 
               model, 
               optimizer, 
               criterion, 
               bs, 
               use_tf = False, 
               output_ts=10):
    """Run one optimisation step on a batch produced by generateBatch.

    The objective is the word-level loss (lossfn_multi) plus the copy-index
    cross-entropy on the attention logits (lossfn_uind). The copy-flag loss
    (lossfn_u) is disabled, so ``u`` is unused and its slot in the result is 0.

    Args:
        x, y, u, u_indices: numpy arrays from generateBatch.
        model: the customGRU being trained.
        optimizer: torch optimizer over the model parameters.
        criterion: word-level criterion passed to lossfn_multi / lossfn_uind.
        bs: batch size (used for the placeholder hidden state).
        use_tf: enable teacher forcing in the decoder.
        output_ts: number of decoder steps.

    Returns:
        ``(total_loss, word_loss, 0, attention_loss)`` as Python numbers.
    """
    model.train()
    src = torch.from_numpy(x).cuda()
    y = torch.LongTensor(y).cuda()
    u_indices = torch.LongTensor(u_indices).cuda()
    h = model.reinitInputHiddenState(bs)
    log_probs, attn_logits, _copy_probs = model(src, h, output_ts, y_acts=y, use_tf=use_tf)

    optimizer.zero_grad()
    att_loss = lossfn_uind(u_indices, attn_logits, criterion)
    word_loss = lossfn_multi(log_probs, y, criterion)
    loss = word_loss + att_loss
    # Backpropagate once: an extra att_loss.backward() would count its gradient twice.
    loss.backward()
    optimizer.step()

    return loss.item(), word_loss.item(), 0, att_loss.item()

In [None]:
def sample(model, generate_words = 10, print_attn=False):
    """Greedy-decode a random validation batch and return readable samples.

    Runs without autograd and in eval mode (dropout off); the model's previous
    train/eval mode is restored afterwards.

    Args:
        model: the customGRU to sample from.
        generate_words: number of decoder steps.
        print_attn: forwarded to the model's attention debug printing.

    Returns:
        One dict per batch item with ``text.source`` / ``text.actual`` /
        ``text.predicted`` strings and ``attention``: the per-step softmax of the
        attention logits as nested lists, shape (generate_words, input_ts).
    """
    x, y, _u, _u_ind = generateBatch(val_inputs, val_outputs, output_ts=generate_words)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            h = model.reinitInputHiddenState(x.shape[0])
            log_probs, atts, _uts = model(torch.from_numpy(x).cuda(), h, generate_words,
                                          print_attn=print_attn)
    finally:
        model.train(was_training)

    # argmax of log-probs == argmax of probs; (steps, batch) -> (batch, steps).
    words = log_probs.argmax(-1).t().tolist()
    attention = F.softmax(atts, dim=-1).cpu().numpy()

    samples = []
    for i in range(x.shape[0]):
        samples.append({
            "text": {
                "source": " ".join(r_dict[w] for w in x[i].tolist()),
                "actual": " ".join(r_dict[w] for w in y[i].tolist()),
                "predicted": " ".join(r_dict[w] for w in words[i])
            },
            "attention": attention[i].tolist()
        })
    return samples

In [None]:
# Per-epoch metric histories; train() appends to them and dumps them to JSON.
losses, val_losses, network_losses, ut_losses, att_losses = [], [], [], [], []

In [None]:
def _format_train_progress(epoch, batch, avg_loss, avg_network, avg_ut, avg_att, eta, avg_time):
    """Return the carriage-return-terminated progress line train() rewrites after each batch."""
    return ("e{}, batch: {}\tloss:{:10.3f} ({:10.3f},{:10.3f},{:10.3f})  "
            "\t\teta: {:5.1f}s\t{:1.2f}s/batch\r").format(
                epoch, batch, avg_loss, avg_network, avg_ut, avg_att, eta, avg_time)


def train(model, epochs=1, 
          batches=128, 
          optim=None, 
          criterion = None, 
          bs = 64, 
          output_ts=20, 
          use_tf=False, 
          lr=1e-3, 
          num_valid_batches=10):
    """Train ``model`` for ``epochs`` epochs of ``batches`` random training batches.

    After every epoch it:
      * validates on ``num_valid_batches`` batches;
      * saves the whole model to ``config["save_model_path"]`` when the
        validation loss beats every previous epoch;
      * appends epoch-averaged losses to the global history lists and dumps
        them to ``config["save_training_cycle_path"]``;
      * writes greedy samples as JSON under ``config["save_samples_path"]``.

    Args:
        model: the customGRU to train.
        epochs, batches: number of epochs and of batches per epoch.
        optim: optimizer; defaults to Adam over the model parameters with ``lr``.
        criterion: word-level criterion; defaults to ``nn.NLLLoss()``.
        bs: batch size for training and validation.
        output_ts: number of decoder steps.
        use_tf: enable teacher forcing.
        lr: learning rate, used only when ``optim`` is None.
        num_valid_batches: number of batches averaged for the validation loss.

    Returns:
        The global ``losses`` list (mean training loss per epoch, all calls so far).
    """
    if optim is None:
        optim = torch.optim.Adam(model.parameters(), lr=lr)
    if criterion is None:
        criterion = nn.NLLLoss()

    for e in range(epochs):
        rolling_loss = rolling_network = rolling_ut = rolling_att = 0.0
        rolling_time = 0.0
        print("\n")
        for b in range(batches):
            b_start = time.time()
            loss, network_loss, ut_loss, att_loss = trainBatch(
                *generateBatch(inputs, outputs, output_ts=output_ts, bs=bs),
                model,
                optim,
                criterion,
                bs,
                output_ts=output_ts,
                use_tf=use_tf)
            rolling_loss += loss
            rolling_network += network_loss
            rolling_ut += ut_loss
            rolling_att += att_loss
            rolling_time += time.time() - b_start

            done = b + 1
            avg_time = rolling_time / done
            eta = (batches - done) * avg_time  # batches still to run after this one
            sys.stdout.write(_format_train_progress(
                e + 1, done,
                rolling_loss / done, rolling_network / done,
                rolling_ut / done, rolling_att / done,
                eta, avg_time))
            sys.stdout.flush()

        epoch_loss = rolling_loss / batches
        losses.append(epoch_loss)

        valid_loss = validate(model, lossfn_multi, criterion, bs=bs,
                              num_batches=num_valid_batches, output_ts=output_ts)
        print("\n")
        print("validation loss:", "{:3.2f}".format(valid_loss))

        # Checkpoint whenever validation improves on every earlier epoch.
        if not val_losses or valid_loss < np.min(val_losses):
            print("Saving Model:", config["save_model_path"])
            torch.save(model, config["save_model_path"])
        val_losses.append(valid_loss)
        # Epoch means, consistent with `losses` (not just the last batch's value).
        network_losses.append(rolling_network / batches)
        ut_losses.append(rolling_ut / batches)
        att_losses.append(rolling_att / batches)

        with open(config["save_training_cycle_path"], "w") as f:
            json.dump({
                "training_loss": losses,
                "validation_loss": val_losses,
                "network_loss": network_losses,
                "ut_loss": ut_losses,
                "att_loss": att_losses
            }, f)

        samples = sample(model, generate_words=output_ts)
        sample_record = {
            "epochs": len(losses),
            "used_tf": use_tf,
            "loss": epoch_loss,
            "val_loss": valid_loss,
            "samples": samples
        }
        with open(config["save_samples_path"] + str(time.time()) + "_.json", "w") as f:
            json.dump(sample_record, f, indent=4)

    return losses

## Model

In [None]:
start_token = f_dict[config["start_token"]]
class customGRU(nn.Module):
    """Attention encoder-decoder used to generate headlines.

    Encoder: ``lstm_layers``-deep bidirectional GRU over the embedded source
    (output width ``2 * lstm_dim``). Decoder: single-layer GRU seeded with the
    encoder output at the last source position. At every step it attends over
    all encoder outputs with a bilinear score ``h @ attn_W @ enc_t``. It then
    predicts the next word by projecting ``[h; context]`` into embedding space
    and scoring it against the embedding matrix (tied output weights).

    The decoder state is initialised from, and scored against, encoder outputs,
    so ``hidden_dim`` must equal ``2 * lstm_dim``. The default arguments do not
    satisfy this.

    Args:
        vocab_size: number of embedding rows / output classes.
        embed_dim: word-embedding width.
        lstm_dim: per-direction hidden width of the encoder GRU.
        hidden_dim: decoder GRU hidden width; must equal ``2 * lstm_dim``.
        bidirec: unused (the encoder is always bidirectional).
        lstm_layers: number of stacked encoder GRU layers.
        start_token: id fed to the decoder at the first step.
    """
    def __init__(self, 
                 vocab_size=128, 
                 embed_dim = 100, 
                 lstm_dim= 90, 
                 hidden_dim=64, 
                 bidirec=False, 
                 lstm_layers = 3,
                 start_token = start_token):
        super(customGRU, self).__init__()
        self.lstm_dim = lstm_dim
        self.lstm_layers = lstm_layers
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        # Shared by the encoder, the decoder input and the tied output projection.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx = f_dict[padding_token])
        self.start_token = start_token
        self.input_lstm = nn.GRU(embed_dim, lstm_dim, num_layers=lstm_layers, dropout=0.1, bidirectional=True)
        # Bilinear attention weight (requires hidden_dim == 2 * lstm_dim).
        self.attn_W = nn.Parameter(torch.randn(hidden_dim, hidden_dim))
        # Not used by forward(); kept so saved models keep the same parameter set.
        self.w2i = nn.Parameter(torch.randn(hidden_dim + lstm_dim*2, hidden_dim))
        self.dec_lstm = nn.GRU(embed_dim, hidden_dim)
        # Projects [decoder state; attention context] into embedding space.
        self.decoder = nn.Parameter(torch.randn(self.hidden_dim + self.lstm_dim*2, embed_dim))
        # Per-step scalar head on [decoder state; context]; sigmoid-ed in forward()
        # and returned third (intended for the currently disabled copy-flag loss).
        self.ut = nn.Parameter(torch.randn(self.hidden_dim + self.lstm_dim*2, 1))

        # forward() only uses sigmoid and log_softmax; the rest are unused.
        self.dropout = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.norm_softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

    def forward(self, x, hidden, output_ts, use_tf=False, y_acts=None, train=True, print_attn=False):
        """Decode ``output_ts`` tokens for a batch of source sequences.

        Args:
            x: LongTensor (batch, src_len) of token ids.
            hidden: ignored (the encoder starts from zeros); kept for API compatibility.
            output_ts: number of decoder steps.
            use_tf: teacher forcing; at each step, with probability 7/11, the gold
                token replaces the greedy prediction as the next input.
            y_acts: LongTensor (batch, output_ts) of gold ids; required if ``use_tf``.
            train: unused.
            print_attn: print attention debugging information at every step.

        Returns:
            ``(log_probs, attn_logits, ut_probs)``:
            ``(output_ts, batch, vocab_size)`` log-softmax scores,
            ``(batch, output_ts, src_len)`` pre-softmax attention scores, and
            ``(batch, output_ts, 1)`` sigmoid outputs of the ``ut`` head.

        Raises:
            ValueError: if ``use_tf`` is set but ``y_acts`` is None.
        """
        if use_tf and y_acts is None:
            raise ValueError("y_acts is required when use_tf=True")
        bs, ts = x.size()
        src = self.embed(x.t())  # (ts, bs, embed_dim): the GRUs are sequence-first
        y_ = y_acts.t() if y_acts is not None else None  # (output_ts, bs)
        i_lh, _ = self.input_lstm(src)  # (ts, bs, 2*lstm_dim)
        h = i_lh[-1].unsqueeze(0)  # (1, bs, 2*lstm_dim): seed decoder with last encoder output
        enc_keys = i_lh.permute(1, 2, 0)  # (bs, 2*lstm_dim, ts), loop-invariant
        enc_vals = i_lh.transpose(0, 1)  # (bs, ts, 2*lstm_dim), loop-invariant

        pw = torch.full((1, bs), self.start_token, dtype=torch.long, device=x.device)
        o_wh = []
        atts_ = []
        uts_ = []
        for i in range(output_ts):
            emb = self.embed(pw)  # (1, bs, embed_dim)
            h, _ = self.dec_lstm(emb, h)  # (1, bs, hidden_dim)

            # Bilinear attention scores over every source position.
            a = torch.matmul(h, self.attn_W).permute(1, 0, 2)  # (bs, 1, hidden_dim)
            e = torch.bmm(a, enc_keys)  # (bs, 1, ts)
            alpha = F.softmax(e, dim=-1)

            if print_attn:
                print("ALPHA:", alpha.size())
                print("ALPHA[0] SUM:", torch.sum(alpha, -1)[0])
                print("ALPHA[0]:", alpha[0])
                print("E:", e.size())
                print("E[0]:", e[0])
                print("I_LH_PERM:", enc_keys.size())
                print("A_PERM:", a.size())
                print("a_matmul", a.size())
                print("w_attn:", self.attn_W)
                print("h[0]", h[0])
                print("\n\n\n\n\n\n")

            atts_.append(e.view(-1, ts))  # (bs, ts)

            # Context = attention-weighted sum of encoder outputs, per batch item.
            # (A plain .view() of alpha to (ts, bs, 1) would mix weights across items.)
            att_out = torch.bmm(alpha, enc_vals).transpose(0, 1)  # (1, bs, 2*lstm_dim)
            h_att = torch.cat([h, att_out], -1)

            uts_.append(self.sigmoid(torch.matmul(h_att, self.ut)))

            w_proj = torch.matmul(h_att, self.decoder)  # (1, bs, embed_dim)
            w_ = self.log_softmax(torch.matmul(w_proj, self.embed.weight.t()))  # (1, bs, vocab)
            o_wh.append(w_)

            # Next decoder input: greedy prediction, or the gold token under teacher forcing.
            pw = torch.max(w_, -1)[1]  # (1, bs)
            if use_tf and torch.randint(11, (1,))[0] > 3:
                pw = y_[i, :].unsqueeze(0)

        o_wh = torch.stack(o_wh, 0).squeeze(1)
        atts_ = torch.stack(atts_, 0).permute(1, 0, 2)
        uts_ = torch.stack(uts_, 0).squeeze(1).permute(1, 0, 2)
        return o_wh, atts_, uts_

    def reinitInputHiddenState(self,bs):
        """Return a zero tensor (lstm_layers, bs, lstm_dim) on the model's device.

        forward() ignores its ``hidden`` argument, so this is only a placeholder
        (a bidirectional encoder state would need 2 * lstm_layers rows).
        """
        return torch.zeros((self.lstm_layers, bs, self.lstm_dim), device=self.embed.weight.device)

In [None]:
# The bidirectional encoder outputs 2 * lstm_dim = 256 features, which must match hidden_dim.
m = customGRU(
    vocab_size=vocab_size,
    embed_dim=300,
    lstm_dim=128,
    hidden_dim=256,
).cuda()

In [None]:
# Smoke test: one sampling pass through the (untrained) model must run without errors.
samples = sample(model=m, print_attn=False)

# Training Loop

In [None]:
def lossfn_u(outputs_u, acts, from_logits=True):
    """Binary cross-entropy between predicted copy flags and their 0/1 targets.

    Args:
        outputs_u: predictions with one value per (sample, step), e.g. shape
            (batch, output_ts, 1).
        acts: 0/1 targets with the same number of elements, e.g. generateBatch's
            ``u`` of shape (batch, output_ts); any numeric dtype.
        from_logits: True (default) treats ``outputs_u`` as raw logits
            (BCEWithLogitsLoss). Pass False for probabilities; customGRU's third
            output has already been through a sigmoid.

    Returns:
        Scalar mean loss tensor.
    """
    loss = nn.BCEWithLogitsLoss() if from_logits else nn.BCELoss()
    # Flatten both sides identically so each prediction meets its own target;
    # BCE needs floating-point targets.
    outputs = outputs_u.reshape(-1, 1).float()
    targets = acts.reshape(-1, 1).to(outputs.dtype)
    return loss(outputs, targets)
    
def lossfn_uind(acts_ind, outputs_ind, criterion):
    """Cross-entropy between attention logits and the source index to copy from.

    Args:
        acts_ind: target source positions, shape (batch, output_ts).
        outputs_ind: attention logits, shape (batch, output_ts, src_len).
        criterion: unused; a default (mean-reduced) cross-entropy is always applied.

    Returns:
        Scalar mean loss tensor.
    """
    n_positions = outputs_ind.size()[-1]
    # Put time first on both sides so the flattened rows line up as (step, sample).
    targets = acts_ind.transpose(0, 1).contiguous().view(-1)
    logits = outputs_ind.permute(1, 0, 2).contiguous().view(-1, n_positions)
    return F.cross_entropy(logits, targets)

In [None]:
# Training schedule: stage k runs epochs[k] epochs of batches[k] batches with
# teacher forcing tf[k] at learning rate lrs[k]. zip() stops at the shortest
# list, so with epochs = [10] only the first stage runs.
epochs = [10]
batches = [2500, 1000]
tf = [True, True]
lrs = [1e-3, 1e-4]

_output_ts = 10
_bs = 128
max_cycles = None  # None = repeat the schedule until the kernel is interrupted
optim = torch.optim.Adam(m.parameters(), lr=1e-3)

cycle = 0
while max_cycles is None or cycle < max_cycles:
    for e, b, stage_tf, stage_lr in zip(epochs, batches, tf, lrs):
        # Apply this stage's learning rate to the shared optimizer.
        for group in optim.param_groups:
            group["lr"] = stage_lr
        print("\n")
        print(str(e) + " epoch(s) (" + str(b) + " batches of " + str(_bs) + " samples each.) Teacher Forcing:", stage_tf)
        _losses = train(m,
                        epochs=e,
                        batches=b,
                        optim=optim,
                        output_ts=_output_ts,
                        use_tf=stage_tf,
                        bs=_bs)
    cycle += 1
        