In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from IPython.display import display
import pandas as pd
import h5py
import os
from tqdm import tqdm
# to load data from the test file
from utils.create_datasets import SumDatasets
# To load the data sets here, also we will create some examples to explore torch
data_dir = '../data/'
from utils import config
import pickle
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

  from ._conv import register_converters as _register_converters


In [2]:
sum_dataset = SumDatasets(os.path.join(data_dir, 'features-600-40_v2.hdf5'))

In [5]:
def encoder_transform(features_1, features_2, features_3, features_4):
    """
    To transform the data to the format, we need to order the data in the sequences of length
    ,so we can later utilize the data in the NN model with the pack_paded_sequences and pad_packed_sequences.
    
    Parameters
    ----------
        
    """
    assert isinstance(features_1, torch.Tensor), "You must give the data to tensor object"
    batch_size = features_1.size(0)
    if batch_size == 1:
        pass
    else:
        sorted_length, sorted_idx = features_2.sort()  # sort will return both the ascending sorted value and also the sorted index
        reverse_idx = torch.linspace(batch_size - 1, 0, batch_size).long()  # this will contain the batch_size-1
        sorted_length, sorted_idx = sorted_length[reverse_idx], sorted_idx[reverse_idx]
        features_1 = features_1[sorted_idx]
        features_2 = features_2[sorted_idx]
        features_3 = features_3[sorted_idx]
        features_4 = features_4[sorted_idx]
    features_1.squeeze_(1), features_3.squeeze_(1), features_4.squeeze_(1)
    features_3 = features_3.permute([1, 0])
    features_4 = features_4.permute([1, 0])
    return features_1, features_2, features_3, features_4
features_1, features_2, features_3, features_4 = sum_dataset[0:10] # assume we have a batch data of 20 items.
display(features_2)

tensor([600, 600, 600, 600, 600, 600, 600, 600, 600, 436])

## given the data, build a very naive summarization model here
* first, we need to compute all the hidden state of each time step and a final state, contain cell state and hidden state of the final time and feed it into the network and do the decoder agian.

In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(config.NUM_WORDS + 2, config.embedding_dim)
        self.rnn = nn.LSTM(config.embedding_dim, config.hidden_dim, bidirectional=True)
        self.reduce_ = nn.Linear(2 * config.hidden_dim, config.hidden_dim)
        
    def forward(self, X, seq_lens):
        """return the final satets and also the outputs of each timesteps, for the later usage of
        computing the Attentaion matrix foe each time step input of the Decoder.
        
        Parameters
        ----------
        X : [Torch tensor with batch*MAX_STEP]
            
        seq_lens : [descend order of the real length of the data]
        """
        X = self.embed(X)
        batch_size = X.size()[0]
        packed_x = pack_padded_sequence(X, seq_lens, batch_first=True)
        outputs, hidden = self.rnn(packed_x)
        outputs, seq_lens = pad_packed_sequence(outputs, batch_first=True)
        # outputs is a bathc*max_enc_steps*(2*hidden_dim), but for the hidden
        # must give then to batch first format, so we need to implement this
        # with the following code.
        hidden_c, hidden_s = hidden
        hidden_c = self.reduce_(hidden_c.permute([1, 0, 2]).contiguous().view(batch_size, -1))
        hidden_s = self.reduce_(hidden_s.permute([1, 0, 2]).contiguous().view(batch_size, -1))
        
        hidden_c.unsqueeze_(1)
        hidden_s.unsqueeze_(1)
        # to make the hidden to be batch first and later, do the loop over time
        hidden = (hidden_c.permute([1,0,2]), hidden_s.permute([1, 0, 2]))
        return outputs, hidden

In [6]:
encoder = Encoder()
# display(features_1.squeeze_(1).shape)
features_1, features_2, features_3, features_4 = encoder_transform(features_1, features_2, features_3, features_4)
outputs, hidden = encoder(features_1, features_2)
display(hidden[0].shape)

torch.Size([1, 10, 50])

In [7]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(config.NUM_WORDS+4, config.embedding_dim)
        self.rnn = nn.LSTM(config.embedding_dim, config.hidden_dim, batch_first=True)
        self.logits = nn.Linear(config.hidden_dim, config.NUM_WORDS+4, bias=False)
        
        
        
    def forward(self, X, hidden):
        """
        using the encoder's hidden state to initiaize the decoder
        input and also use the teaching force in the training mode. here
        TODO: how we tell the real difference between the train, teaching force and
        the evalutaion?
        """
        X = self.embed(X)
        # first we need to transoform the X to be the seq_len first
        X = X.unsqueeze_(1) # just one time step, the X now shoud be batch_size*1*embeeding_dim
        outputs, hidden = self.rnn(X, hidden)
        outputs = F.log_softmax(self.logits(outputs.squeeze_(1)), dim=1)
        return outputs, hidden

In [9]:
decoder = Decoder()
outputs = []
# simply test the time dimension 
# feature 3 shoud be time first
# features_4 = features_4.permute([1, 0])
loss = 0
# display(features_3.squeeze_(1).permute([1, 0]))
critertion = nn.NLLLoss(ignore_index=0)
target = features_4
for di in range(config.max_dec_steps):
    X = features_3[di]
    output, hidden = decoder(X, hidden)
    # here to implement the loss wrt target
    print("+++++++++++++{}++++++++++++++".format(di))
    display(output.shape)
    display(target[di])
    loss += critertion(output, target[di])
    display(loss)


+++++++++++++0++++++++++++++


torch.Size([10, 50004])

tensor([  91, 2441, 2017,   31,   31, 2017,   15, 2250, 1559, 2017])

tensor(10.8721, grad_fn=<AddBackward>)

+++++++++++++1++++++++++++++


torch.Size([10, 50004])

tensor([  75, 2250,  208, 2250, 2250,   15, 1307, 1336,  623,    5])

tensor(21.7097, grad_fn=<ThAddBackward>)

+++++++++++++2++++++++++++++


torch.Size([10, 50004])

tensor([   7,    5, 1559,  623,  623, 1670, 2017, 2441,    5,  770])

tensor(32.5488, grad_fn=<ThAddBackward>)

+++++++++++++3++++++++++++++


torch.Size([10, 50004])

tensor([  31,   31, 1559,    5,    7,  623, 1559,   31, 1670,  623])

tensor(43.3820, grad_fn=<ThAddBackward>)

+++++++++++++4++++++++++++++


torch.Size([10, 50004])

tensor([1336,   31,   31, 1559,  208,   31,  208, 1336,   15, 1336])

tensor(54.2054, grad_fn=<ThAddBackward>)

+++++++++++++5++++++++++++++


torch.Size([10, 50004])

tensor([1351, 1336, 1351, 1043, 1351, 2250,  623,    7, 1307, 2017])

tensor(65.0108, grad_fn=<ThAddBackward>)

+++++++++++++6++++++++++++++


torch.Size([10, 50004])

tensor([  15,  215,    5,   15,  461,   15, 1307,  461, 1710,  208])

tensor(75.8343, grad_fn=<ThAddBackward>)

+++++++++++++7++++++++++++++


torch.Size([10, 50004])

tensor([ 623, 1336, 1307, 1710, 1351, 1307,   31,  623, 1710,    7])

tensor(86.6443, grad_fn=<ThAddBackward>)

+++++++++++++8++++++++++++++


torch.Size([10, 50004])

tensor([   7,    5,    7,  623,   15, 1710,   15, 1307, 1336,    7])

tensor(97.4298, grad_fn=<ThAddBackward>)

+++++++++++++9++++++++++++++


torch.Size([10, 50004])

tensor([1336, 2017,  770, 1307,    7,    7,    5,  215, 1336, 1336])

tensor(108.2259, grad_fn=<ThAddBackward>)

+++++++++++++10++++++++++++++


torch.Size([10, 50004])

tensor([ 208,   31, 1351,    7,   15,   15, 1559, 1559, 1710,  770])

tensor(119.0322, grad_fn=<ThAddBackward>)

+++++++++++++11++++++++++++++


torch.Size([10, 50004])

tensor([1351,  623,   15,   15, 1307, 1670, 1043,  623, 1559,  770])

tensor(129.8577, grad_fn=<ThAddBackward>)

+++++++++++++12++++++++++++++


torch.Size([10, 50004])

tensor([1351, 1351,  461,    5, 1710,  623, 1336,    7,  623,  623])

tensor(140.7036, grad_fn=<ThAddBackward>)

+++++++++++++13++++++++++++++


torch.Size([10, 50004])

tensor([ 623, 1710,   31, 1307,    7, 1559, 1336,    7, 2017, 1351])

tensor(151.4653, grad_fn=<ThAddBackward>)

+++++++++++++14++++++++++++++


torch.Size([10, 50004])

tensor([   5, 1351, 2017,  770,  623,  623, 1529,    7, 1336,    7])

tensor(162.3031, grad_fn=<ThAddBackward>)

+++++++++++++15++++++++++++++


torch.Size([10, 50004])

tensor([ 215,    5, 1336, 1351,  770,    5,    7,    5, 1351,   15])

tensor(173.1577, grad_fn=<ThAddBackward>)

+++++++++++++16++++++++++++++


torch.Size([10, 50004])

tensor([ 623,  215, 1351,  208, 1351, 1351,   15, 1670, 2017, 1307])

tensor(183.9357, grad_fn=<ThAddBackward>)

+++++++++++++17++++++++++++++


torch.Size([10, 50004])

tensor([1351,  208,  124,    7,  623, 1307, 1351,  623, 1351,  770])

tensor(194.7434, grad_fn=<ThAddBackward>)

+++++++++++++18++++++++++++++


torch.Size([10, 50004])

tensor([  7,   5, 623,   5,  31, 623, 623, 124,  15, 623])

tensor(205.5338, grad_fn=<ThAddBackward>)

+++++++++++++19++++++++++++++


torch.Size([10, 50004])

tensor([1559,   31, 1351,  215,   31,  215,    5, 1336,  623,   91])

tensor(216.3450, grad_fn=<ThAddBackward>)

+++++++++++++20++++++++++++++


torch.Size([10, 50004])

tensor([1336,   15, 1710,  623, 1336, 2017,  215, 1351, 1307,  149])

tensor(227.1494, grad_fn=<ThAddBackward>)

+++++++++++++21++++++++++++++


torch.Size([10, 50004])

tensor([1670, 1307, 1351,    5,  770, 1351,   15,  623,  215,  149])

tensor(237.9324, grad_fn=<ThAddBackward>)

+++++++++++++22++++++++++++++


torch.Size([10, 50004])

tensor([ 623, 1710, 1336, 1307, 2250, 1336, 1307,    5, 5175,  209])

tensor(248.7524, grad_fn=<ThAddBackward>)

+++++++++++++23++++++++++++++


torch.Size([10, 50004])

tensor([ 215,  770,  208,  215,    5,  124,   31, 1307, 2250, 2441])

tensor(259.5410, grad_fn=<ThAddBackward>)

+++++++++++++24++++++++++++++


torch.Size([10, 50004])

tensor([  15, 1336,  461,   31, 1307,    7, 2250,  215, 1336,    5])

tensor(270.2885, grad_fn=<ThAddBackward>)

+++++++++++++25++++++++++++++


torch.Size([10, 50004])

tensor([1307, 1559, 1336, 2250, 1710, 2250,  623,  770, 2441, 1307])

tensor(281.0934, grad_fn=<ThAddBackward>)

+++++++++++++26++++++++++++++


torch.Size([10, 50004])

tensor([  91, 1559, 1307,  623,   15, 1336, 1559, 1559, 2441,   31])

tensor(291.9171, grad_fn=<ThAddBackward>)

+++++++++++++27++++++++++++++


torch.Size([10, 50004])

tensor([ 149,  623,  770, 1043, 1307, 1351,    5,  623,  623,    7])

tensor(302.7193, grad_fn=<ThAddBackward>)

+++++++++++++28++++++++++++++


torch.Size([10, 50004])

tensor([  75, 1710,  623, 1559, 1710,   31,    7,    5,  215, 1336])

tensor(313.5751, grad_fn=<ThAddBackward>)

+++++++++++++29++++++++++++++


torch.Size([10, 50004])

tensor([ 211,  623, 1336,    5,   31,   31,   31, 1307,   15,  208])

tensor(324.3560, grad_fn=<ThAddBackward>)

+++++++++++++30++++++++++++++


torch.Size([10, 50004])

tensor([50002, 50002,     5,   770,  2250,   623,    91,   208,   215,    31])

tensor(335.2165, grad_fn=<ThAddBackward>)

+++++++++++++31++++++++++++++


torch.Size([10, 50004])

tensor([    0,     0,  1307,  1529,   623,  1351,  2976,   461,  2250, 50002])

tensor(346.0214, grad_fn=<ThAddBackward>)

+++++++++++++32++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0,  215,  124, 2441,  124,  623, 2976,    5,    0])

tensor(356.8384, grad_fn=<ThAddBackward>)

+++++++++++++33++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0, 1351,    5, 1336,   31,    5, 1336, 1307,    0])

tensor(367.6624, grad_fn=<ThAddBackward>)

+++++++++++++34++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0,  623,    7, 1351, 1351, 1351,  208,  215,    0])

tensor(378.4536, grad_fn=<ThAddBackward>)

+++++++++++++35++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0, 2441,    7, 1559,    5,    7, 1351, 1559,    0])

tensor(389.2600, grad_fn=<ThAddBackward>)

+++++++++++++36++++++++++++++


torch.Size([10, 50004])

tensor([    0,     0,   124, 50002,   215,  1670, 50002,  2017,   623,     0])

tensor(400.1111, grad_fn=<ThAddBackward>)

+++++++++++++37++++++++++++++


torch.Size([10, 50004])

tensor([    0,     0,     5,     0, 50002,   623,     0,    15,    31,     0])

tensor(410.9999, grad_fn=<ThAddBackward>)

+++++++++++++38++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0,    7,    0,    0, 1559,    0, 1307, 2250,    0])

tensor(421.8301, grad_fn=<ThAddBackward>)

+++++++++++++39++++++++++++++


torch.Size([10, 50004])

tensor([   0,    0, 1336,    0,    0, 1559,    0,    5,  623,    0])

tensor(432.6760, grad_fn=<ThAddBackward>)

In [10]:
from torch.utils.data import DataLoader

train_data = DataLoader(sum_dataset, shuffle=True, batch_size=10)
len(train_data)

14583

In [11]:
for i_batch, sample_batched in enumerate(tqdm(train_data)):
    if i_batch % 100 == 0:
        print("one batch finished.")

  3%|▎         | 455/14583 [00:00<00:06, 2268.89it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


  6%|▋         | 947/14583 [00:00<00:05, 2360.33it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 10%|▉         | 1454/14583 [00:00<00:05, 2416.65it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 13%|█▎        | 1961/14583 [00:00<00:05, 2444.61it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 17%|█▋        | 2506/14583 [00:01<00:04, 2499.43it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 21%|██        | 3032/14583 [00:01<00:04, 2490.27it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 24%|██▍       | 3555/14583 [00:01<00:04, 2449.34it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 28%|██▊       | 4055/14583 [00:01<00:04, 2408.26it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 32%|███▏      | 4604/14583 [00:01<00:04, 2443.44it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 35%|███▌      | 5142/14583 [00:02<00:03, 2466.94it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 39%|███▉      | 5680/14583 [00:02<00:03, 2471.23it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 45%|████▍     | 6496/14583 [00:02<00:03, 2499.48it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 48%|████▊     | 7047/14583 [00:02<00:02, 2517.26it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 52%|█████▏    | 7589/14583 [00:03<00:02, 2518.11it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 56%|█████▌    | 8126/14583 [00:03<00:02, 2507.97it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 59%|█████▉    | 8662/14583 [00:03<00:02, 2517.81it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 63%|██████▎   | 9201/14583 [00:03<00:02, 2527.17it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 69%|██████▉   | 10123/14583 [00:03<00:01, 2566.33it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 74%|███████▍  | 10772/14583 [00:04<00:01, 2598.90it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 78%|███████▊  | 11408/14583 [00:04<00:01, 2605.62it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 83%|████████▎ | 12088/14583 [00:04<00:00, 2640.15it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 87%|████████▋ | 12719/14583 [00:04<00:00, 2654.88it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 92%|█████████▏| 13367/14583 [00:04<00:00, 2678.15it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


 96%|█████████▋| 14039/14583 [00:05<00:00, 2704.27it/s]

one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.
one batch finished.


100%|██████████| 14583/14583 [00:05<00:00, 2720.30it/s]

one batch finished.
one batch finished.
one batch finished.





In [12]:
display(features_1.shape)
display(features_2.shape)
display(features_3.shape)
display(features_4.shape)

torch.Size([10, 600])

torch.Size([10])

torch.Size([40, 10])

torch.Size([40, 10])