In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerDecoder,TransformerDecoderLayer
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import Transformer
import pandas as pd
import torch.optim as optim
import itertools

In [2]:
from functools import lru_cache
import gensim
import gensim.downloader as api
from gensim.models import KeyedVectors
import gensim.utils as utils

import nltk
nltk.download('wordnet')
from nltk.stem import WordNetLemmatizer
cached_lemmatize = lru_cache(maxsize=50000)(WordNetLemmatizer().lemmatize)
from gensim.utils import simple_preprocess, to_unicode

unable to import 'smart_open.gcs', disabling that module
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\mapka\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
class BiLSTMEncoder(nn.Module):
    def __init__(self, input_dim,emb_dim,enc_hid_dim,dec_hid_dim,dropout=0.5):
        
        super(Encoder,self).__init__()
        
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.enc_hid_dim = enc_hid_dim
        self.dec_hid_dim = dec_hid_dim
        self.dropout = dropout
        
        self.embedding = nn.Embedding(input_dim,emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, enc_hid_dim, bidirectional = True)
        
        self.fc = nn.Linear( enc_hid_dim * 2, dec_hid_dim )
        
        self.dropout = nn.Dropout( dropout )
        
    def forward(self, X):
        
        embedded = self.dropout(self.embedding(X))
        
        outputs, hidden = self.rnn(embedded)
        
        hidden = F.tanh( self.fc ( torch.cat( (hidden[-2,:,:], hidden[-1, : , : ] ), dim = 1 ) ) )
        
        return outputs, hidden


        
        

In [4]:
base_dir = "data"
train_file_X = os.path.join(base_dir,"train.source")
train_file_y = os.path.join(base_dir,"train.target")
test_file_X = os.path.join(base_dir,"test.source")
test_file_y = os.path.join(base_dir,"test.target")
val_file_X = os.path.join(base_dir,"val.source")
val_file_y = os.path.join(base_dir,"val.target")


In [5]:
import re
import string

STOP_WORDS = ["i", "a", "about", "an", "are", "as", "at", "be", "by", 
                "for", "from", "how", "in", "is", "it", "of", "on", "or", "that", "the", 
                "this", "to", "was", "what", "when", "where", "who", "will", "with"]

def ExpandContractions(contraction):

    contraction = re.sub(r"won\'t", "will not", contraction)
    contraction = re.sub(r"can\'t", "can not", contraction)

    contraction = re.sub(r"n\'t", " not", contraction)
    contraction = re.sub(r"\'re", " are", contraction)
    contraction = re.sub(r"\'s", " is", contraction)
    contraction = re.sub(r"\'d", " would", contraction)
    contraction = re.sub(r"\'ll", " will", contraction)
    contraction = re.sub(r"\'t", " not", contraction)
    contraction = re.sub(r"\'ve", " have", contraction)
    contraction = re.sub(r"\'m", " am", contraction)

    return contraction

def PreProcess(line):
    
    line = line.translate(str.maketrans("", "", string.punctuation))
    line = ExpandContractions(line)
    line = simple_preprocess(to_unicode(line))
    line = [cached_lemmatize(word) for word in line if word not in STOP_WORDS]

    line = " ".join(line)
    return line

In [6]:
class LineSentenceGenerator(object):

    def __init__(self, source, preprocess=None, max_sentence_length=4000, limit=None, preprocess_flag=True):
        self.source = source
        self.max_sentence_length = max_sentence_length
        self.limit = limit
        self.input_files = []

        if preprocess != None and callable(preprocess) and preprocess_flag:
            self.preprocess = preprocess
        else:
            self.preprocess = lambda line: line.rstrip("\r\n")

        if isinstance(self.source, list):
            print('List of files given as source. Verifying entries and using.')
            self.input_files = [filename for filename in self.source if os.path.isfile(filename)]
            self.input_files.sort()  # makes sure it happens in filename order

        elif os.path.isfile(self.source):
            print('Single file given as source, rather than a list of files. Wrapping in list.')
            self.input_files = [self.source]  # force code compatibility with list of files

        elif os.path.isdir(self.source):
            self.source = os.path.join(self.source, '')  # ensures os-specific slash at end of path
            print('Directory of files given as source. Reading directory %s', self.source)
            self.input_files = os.listdir(self.source)
            self.input_files = [self.source + filename for filename in self.input_files]  # make full paths
            self.input_files.sort()  # makes sure it happens in filename order
        else:  # not a file or a directory, then we can't do anything with it
            raise ValueError('Input is neither a file nor a path nor a list')
        print('Files read into LineSentenceGenerator: %s' % ('\n'.join(self.input_files)))

        self.token_count = 0

    def __iter__(self):
        for file_name in self.input_files:
            print('Reading file %s', file_name)
            with open(file_name, 'rb') as fin:
                for line in itertools.islice(fin, self.limit):
                    line = self.preprocess(utils.to_unicode(line))
                    self.token_count += len(line)
                    i = 0
                    while i < len(line):
                        yield line[i:i + self.max_sentence_length]
                        i += self.max_sentence_length

    def __len__(self):
        if self.token_count > 0:
            return self.token_count
        else:
            return len(self.input_files)

    def __bool__(self):
        return self.has_data()

    def is_empty(self):
        return len(self.input_files) == 0

    def has_data(self):
        return not self.is_empty()

In [7]:
from torchtext.data import Dataset,Example
from torchtext.data import Field, BucketIterator
from torchtext.data.utils import get_tokenizer

SRC = Field(tokenize = get_tokenizer("spacy"),
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = False)


In [43]:
def read_data(X, y, src, preprocess=None, limit=1000):
    examples = []
    fields = {'text-tokens': ('text', src),
              'summ-tokens': ('summ', src)}
    for i,(x,y) in enumerate(zip(LineSentenceGenerator(X, preprocess),LineSentenceGenerator(y, preprocess))):
        text_field = x
        summ_field = y
        
        if i > limit:
            break
       
        e = Example.fromdict({"text-tokens": text_field, "summ-tokens": summ_field},
                             fields=fields)
        examples.append(e)
    print("examples: \n", examples[0])
    return Dataset(examples, fields=[('text', src), ('summ', src)])

In [45]:
train_data = read_data(train_file_X,train_file_y,SRC,PreProcess,1000)
test_data = read_data(test_file_X,test_file_y,SRC,PreProcess,200)
val_data = read_data(val_file_X,val_file_y,SRC,PreProcess,200)

Single file given as source, rather than a list of files. Wrapping in list.
Files read into LineSentenceGenerator: data\train.source
Single file given as source, rather than a list of files. Wrapping in list.
Files read into LineSentenceGenerator: data\train.target
Reading file %s data\train.source
Reading file %s data\train.target
examples: 
 <torchtext.data.example.Example object at 0x0000022F73E5BB08>
Single file given as source, rather than a list of files. Wrapping in list.
Files read into LineSentenceGenerator: data\test.source
Single file given as source, rather than a list of files. Wrapping in list.
Files read into LineSentenceGenerator: data\test.target
Reading file %s data\test.source
Reading file %s data\test.target
examples: 
 <torchtext.data.example.Example object at 0x0000022F18762288>
Single file given as source, rather than a list of files. Wrapping in list.
Files read into LineSentenceGenerator: data\val.source
Single file given as source, rather than a list of files.

In [46]:
print("text: ",train_data[0].text)
print("\ntext-len: ",len(train_data[0].text))
print("\n\nsummary: ",train_data[0].summ)

text:  ['editor', 'note', 'our', 'behind', 'scene', 'series', 'cnn', 'correspondent', 'share', 'their', 'experience', 'covering', 'news', 'and', 'analyze', 'story', 'behind', 'event', 'here', 'soledad', 'obrien', 'take', 'user', 'inside', 'jail', 'many', 'inmate', 'mentally', 'ill', 'inmate', 'housed', 'forgotten', 'floor', 'many', 'mentally', 'ill', 'inmate', 'housed', 'miami', 'before', 'trial', 'miami', 'florida', 'cnn', 'ninth', 'floor', 'miamidade', 'pretrial', 'detention', 'facility', 'dubbed', 'forgotten', 'floor', 'here', 'inmate', 'most', 'severe', 'mental', 'illness', 'incarcerated', 'until', 'they', 're', 'ready', 'appear', 'court', 'most', 'often', 'they', 'face', 'drug', 'charge', 'charge', 'assaulting', 'officer', 'charge', 'judge', 'steven', 'leifman', 'say', 'usually', 'avoidable', 'felony', 'he', 'say', 'arrest', 'often', 'result', 'confrontation', 'police', 'mentally', 'ill', 'people', 'often', 'wo', 'nt', 'do', 'they', 're', 'told', 'police', 'arrive', 'scene', 'conf

In [47]:
train_data.fields

{'text': <torchtext.data.field.Field at 0x22f625f8b08>,
 'summ': <torchtext.data.field.Field at 0x22f625f8b08>}

In [48]:
print("text: ", test_data[100].text)
print("\n\nsumm: ",test_data[100].summ)

text:  ['cnna', 'frenchlanguage', 'global', 'television', 'network', 'regained', 'control', 'one', 'it', 'channel', 'thursday', 'after', 'cyberattack', 'day', 'earlier', 'crippled', 'it', 'broadcast', 'and', 'social', 'medium', 'account', 'television', 'network', 'tv', 'monde', 'gradually', 'regaining', 'control', 'it', 'channel', 'and', 'social', 'medium', 'outlet', 'after', 'suffering', 'network', 'director', 'called', 'extremely', 'powerful', 'cyberattack', 'addition', 'it', 'channel', 'tv', 'monde', 'lost', 'control', 'it', 'social', 'medium', 'outlet', 'and', 'it', 'website', 'director', 'yves', 'bigot', 'said', 'video', 'message', 'posted', 'later', 'facebook', 'mobile', 'site', 'which', 'still', 'active', 'network', 'said', 'hacked', 'islamist', 'group', 'isi', 'logo', 'and', 'marking', 'appeared', 'tv', 'monde', 'social', 'medium', 'account', 'but', 'there', 'no', 'immediate', 'claim', 'responsibility', 'isi', 'any', 'other', 'group', 'day', 'broke', 'thursday', 'europe', 'netw

In [49]:
SRC.build_vocab(train_data.text, min_freq = 2)

In [50]:
device = torch.device('cpu')

BATCH_SIZE = 128

train_iter = BucketIterator(train_data,BATCH_SIZE, shuffle=True,
                                                 sort_key=lambda x: len(x.text), sort_within_batch=True)

val_iter = BucketIterator(val_data, BATCH_SIZE, sort_key=lambda x: len(x.text), sort_within_batch=True)
test_iter = BucketIterator(test_data,BATCH_SIZE, sort_key=lambda x: len(x.text), sort_within_batch=True)

In [51]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [52]:
import math
class TransformerSummarizer(nn.Module):
    def __init__(self, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length,vocab_size, pad_idx,  d_model=None, pos_dropout =0.1, trans_dropout= 0.1,embeddings=None):
        super().__init__()
       
        if embeddings is None:
            self.embed_src = nn.Embedding(vocab_size, d_model)
            self.embed_tgt = nn.Embedding(vocab_size, d_model)
        else:
            d_model = embeddings.size(1)
            self.d_model = embeddings.size(1)
            self.embed_src = nn.Embedding(*embeddings.shape)
            self.embed_src.weight = nn.Parameter(embeddings,requires_grad=False)
            
            self.embed_tgt = nn.Embedding(*embeddings.shape)
            self.embed_tgt.weight = nn.Parameter(embeddings,requires_grad=False)
        
        
        self.pos_enc = PositionalEncoding(d_model, pos_dropout, max_seq_length)

        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, trans_dropout)
        
        self.fc = nn.Linear(d_model, vocab_size)
        
        self.pad_idx = pad_idx
        
        self.src_mask = None
        self.tgt_mask = None
        self.memory_mask = None
        
    def generate_square_mask(self,sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def make_pad_mask(self,seq,pad_idx):
        mask = (seq == pad_idx).transpose(0,1)
        return mask
    

    def forward(self, src, tgt):
        if self.tgt_mask is None or self.tgt_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_mask(len(trg)).to(trg.device)
        
        
        src_pad_mask = self.make_pad_mask(src,self.pad_idx)
        tgt_pad_mask = self.make_pad_mask(tgt,self.pad_idx)

        
        src = self.pos_enc(self.embed_src(src) * math.sqrt(self.d_model))

        tgt = self.pos_enc(self.embed_tgt(tgt) * math.sqrt(self.d_model))
        

        output = self.transformer(src, tgt, src_mask=self.src_mask, tgt_mask=self.tgt_mask, memory_mask=self.memory_mask, 
                                 src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask, memory_key_padding_mask=src_pad_mask)
        
        return self.fc(output)
        
        
        

        
        

In [53]:

PAD_IDX = SRC.vocab.stoi[SRC.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [54]:

SEQ_LEN = 4000

D_MODEL = 300 # Embedding dimension
DIM_FEEDFORWARD = 300  # Dimensionality of the hidden state
VOCAB_SIZE = len(SRC.vocab)  # size of the vocabulary
print("vocab-size: ", VOCAB_SIZE) 

ATTENTION_HEADS = 6  # number of attention heads
N_LAYERS = 1 # number of encoder/decoder layers



# nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length,vocab_size, pad_idx,  d_model=None, pos_dropout =0.1, trans_dropout= 0.1,embeddings=None

vocab-size:  14744


In [55]:
from torchtext.vocab import FastText

ff = FastText("en")

embeddings =  ff.get_vecs_by_tokens(SRC.vocab.itos)

embeddings.shape

torch.Size([14744, 300])

In [56]:
model = TransformerSummarizer( ATTENTION_HEADS,N_LAYERS, N_LAYERS, DIM_FEEDFORWARD, SEQ_LEN,VOCAB_SIZE,PAD_IDX,embeddings=embeddings).to(device)

In [57]:
model

TransformerSummarizer(
  (embed_src): Embedding(14744, 300)
  (embed_tgt): Embedding(14744, 300)
  (pos_enc): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=300, out_features=300, bias=True)
          )
          (linear1): Linear(in_features=300, out_features=300, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=300, out_features=300, bias=True)
          (norm1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): Tr

In [58]:
import math
import time
from tqdm.notebook import tqdm_notebook as tqdm

def train(model: nn.Module,
          iterator: BucketIterator,
          num_batches: int,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):
    
    print("Training......")

    model.train()

    epoch_loss = 0

    for batch in tqdm(iterator,total=num_batches):
        src = batch.text
        trg = batch.summ
        

        trg_inp, trg_out = trg[:-1, :], trg[1:, :]

        optimizer.zero_grad()

        output = model(src.to(device), trg_inp.to(device))
    
        output = output.view(-1, output.shape[-1])

        loss = criterion(output, trg_out.view(-1))

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()
        
    print("Training Done.....")

    return epoch_loss / len(iterator)

In [61]:
def evaluate(model: nn.Module,
             iterator: BucketIterator,
             num_batches:int,
             criterion: nn.Module,
            desc: str):

    model.eval()

    epoch_loss = 0
    
    print(f'{desc}ing')
    
    with torch.no_grad():

        for batch in tqdm(iterator,total=num_batches):
            
            src = batch.text
            trg = batch.summ
            
            trg_inp, trg_out = trg[:-1, :], trg[1:, :]

            output = model(src.to(device), trg_inp.to(device))

            output = output.view(-1,output.shape[-1])

            loss = criterion(output, trg_out.view(-1))

            epoch_loss += loss.item()
            
        print(f"{desc}ing Done........")

    return epoch_loss / len(iterator)

In [62]:
src_list = SRC.vocab.itos  # index2word
src_dict = SRC.vocab.stoi # word2index


In [65]:

def epoch_time(start_time: int,
               end_time: int):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

parameters = filter(lambda p:p.requires_grad, model.parameters())
optimizer = optim.Adam(parameters)
num_batches = math.ceil(len(train_data)/BATCH_SIZE)
val_batches = math.ceil(len(val_data)/BATCH_SIZE)

N_EPOCHS = 1
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = train(model, train_iter, num_batches,optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_iter,val_batches, criterion, "evaluate")

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')
    
test_size = math.ceil(len(test_data)/BATCH_SIZE)
test_loss = evaluate(model, test_iter,test_size, criterion, "testing")

print(f'| Test Loss: {test_loss:.3f}')

Training......


HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Training Done.....
evaluateing


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


evaluateing Done........
Epoch: 01 | Time: 2m 36s
	Train Loss: 8.669
	 Val. Loss: 7.575
testinging


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


testinging Done........
| Test Loss: 7.611


In [66]:
## code to view text

from einops import rearrange
for i,batch in enumerate(train_iter):
    if i == 1:
        break
    src = batch.text
    trg = batch.summ
    trg_inp, trg_out = trg[:-1, :], trg[1:, :]
    
    
    print("text: ",[src_list[i] for i in src.squeeze(1).transpose(0,1)[1].tolist()])
    
    print("\n\nsumm: ",[src_list[i] for i in trg.squeeze(1).transpose(0,1)[1].tolist()])
    
    
    
    
    memory = model.transformer.encoder(model.pos_enc(model.embed_src(src) * math.sqrt(model.d_model)))
        
    out = model.fc(model.transformer.decoder(model.pos_enc(model.embed_tgt(trg) * math.sqrt(model.d_model)), memory))

    
    out_  = rearrange(out,'t b e -> b t e')
    
    print("out:\n",out,"\n********Done*******")
    
    print("\n\nargmax: ",out_.argmax(2)[0].tolist())
    
    l = out_.argmax(2)[1].tolist()
    
    print([src_list[i] for i in l])
    
    output_dim = out.shape[-1]
    
    print("view adjusted: \n",out.view(-1,output_dim),"\n*******Done********")
    
    print("view-adjusted-shape: ",out.view(-1,output_dim).shape)


text:  ['<sos>', 'london', 'england', 'cnn', 'handsome', 'articulate', 'and', 'lightning', 'fast', 'mclarens', 'lewis', 'hamilton', 'can', 'now', 'add', 'two', 'more', 'word', 'his', 'list', 'quality', 'very', 'rich', 'lewis', 'hamilton', 'able', 'afford', 'lot', 'more', 'champagne', 'future', 'briton', 'set', 'become', 'one', 'most', 'marketable', 'sport', 'star', 'world', 'perhaps', 'second', 'only', 'tiger', 'wood', 'and', 'earn', 'more', 'than', 'billion', 'dollar', 'if', 'he', 'can', 'maintain', 'buzz', 'created', 'his', 'first', 'season', 'formula', 'one', 'expert', 'say', 'sunday', 'he', 'started', 'his', 'second', 'season', 'perfect', 'fashion', 'easily', 'winning', 'australian', 'grand', 'prix', 'melbourne', 'yearold', 'signed', 'fiveyear', 'contract', 'mclaren', 'worth', 'estimated', 'january', 'leaf', 'him', 'lagging', 'along', 'way', 'behind', 'ferraris', 'kimi', 'raikkonen', 'paid', 'estimated', 'year', 'driving', 'but', 'through', 'endorsement', 'he', 'stand', 'reap', 'gr



argmax:  [0, 3, 3, 0, 3, 0, 0, 3, 3, 3, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 3, 3, 0, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
['<unk>', '<eos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<eos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<eos>', '<eos>', '<unk>', '<eos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<eos>', '<eos>', '<eos>', '<unk>', '<unk>', '<unk>', '<unk>', '<eos>', '<unk>', '<unk>', '<unk>', '<eos>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']
view adjusted: 
 tensor([[ 5.8492, -2.0493, -2.4461,  ..., -1.3770, -1.7189, -2.2254],
        [ 5.8767, -2.0387, -2.4712,  ..., -1.3859, -1.7242, -2.2262],
        [ 5.8534, -2.0546, -2.4363,  ..., -1.3769, -1.7005, -2.2369],
        ...,
        [ 5.8542, -2.0322, -2.4615,  ..., -1.3587, -1.7510, -2.2788],
        [ 5.8559, -2.0362, -2.4856,  ..., -1.3826, -1.7207, -2.2352],
        [ 5.8144, -2.0324, -2.4