In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import nltk
import re
import pickle

### Loading Data

In [2]:
news_total = pd.read_csv("/kaggle/input/arabic-text-summarization/summarizdataset.csv")

In [3]:
news = news_total[["text", "summarizer"]]

In [4]:
news["text"][2]

'\nاحتضن جناح تونس في القرية الدولية للأفلام بمدينة "كان" الفرنسية التي تستضيف الدورة السبعين لمهرجان كان السينمائي، لقاء جمع المدير العام للمركز الوطني للسينما والصورة، فتحي الخراط ومسؤولين بالمركز الفرنسي للسينما لتقديم إتفاقية بعث صندوق إنتاج مشترك تونسي فرنسي لتمويل الأفلام.\nوأوضح الخراط في تصريح أنه سيتم الشروع في تفعيل هذه الإتفاقية التي وقّعت في شهر فيفري الفارط بمناسبة زيارة وزيرة الثقافة والاتصال الفرنسية السابقة إلى تونس، مضيفا أن "صندوق الإنتاج المشترك، جاء ليوفر مصدر تمويل جديد للسينمائيين التونسيين"، وفق تأكيده. التعاون من "جحا" وقال مدير المركز الوطني للسينما والصورة إن "التعاون الثنائي بين البلدين في المجال السينمائي يعود إلى سنة 1957 لينطلق مع الفيلم التونسي، "جحا" الذي أخرجه السينمائي الفرنسي، جان باراتييه، موضحا أن "هذا الشريط يعد أول عمل سينمائي أنجز في تونس المستقلة"، على حد تعبيره. واستعرض الجانبان التونسي الفرنسي بالمناسبة، أهم البنود التي جاءت في الإتفاقية والتي تحدد إلتزامات كل طرف، فضلا عن تركيبة اللجنة التي ستنظر في المشاريع المعروضة عليها والشروط الواجب توفر

In [5]:
news.head()

Unnamed: 0,text,summarizer
0,\nأشرف رئيس الجمهورية الباجي قايد السبسي اليوم...,\nأشرف رئيس الجمهورية الباجي قايد السبسي اليوم...
1,"\nتحصل كتاب ""المصحف وقراءاته"" الذي ألفه باحثون...","\nتحصل كتاب ""المصحف وقراءاته"" الذي ألفه باحثون..."
2,\nاحتضن جناح تونس في القرية الدولية للأفلام بم...,تونس حاضرة من جهة أخرى ستكون تونس حاضرة في قائ...
3,\nشهدت برلين أمس الجمعة افتتاح مسجد فريد من نو...,واستأجرت صاحبة المشروع المحامية والكاتبة سيران...
4,\nنعت وزارة الشّؤون الثّقافيّة المنشد الصّوفي ...,\nنعت وزارة الشّؤون الثّقافيّة المنشد الصّوفي ...


In [6]:
news.shape

(8378, 2)

In [7]:
document = news['text']
summary = news['summarizer']

In [8]:
document[30], summary[30]

('\nقال إبراهيم لطيف مدير الدورة 27 لأيام قرطاج السينمائية في تصريح للجوهرة أف أم اليوم الاثنين 14 نوفمبر 2016 إنه قد أعلن عن استقالته منذ يوم الجمعة الماضي أي قبل صدور قرار وزارة الشؤون الثقافية المتعلق بوضع حد لتكليفه.\nوأضاف لطيف أنه تفاجأ بقرار\xa0إنهاء تكليفه مؤكدا أنه لم يتم إعلامه بذلك. وقال لطيف إنه مستعد للمحاسبة وعلى وزير الشؤون الثقافية تحمل مسؤوليته وفق تعبيره.\xa0\xa0 وكانت وزارة الشؤون الثقافية قد دعت في بلاغ أصدرته اليوم الهيئة المديرة للدورة 27 لأيام قرطاج السينمائية إلى تقديم التقريرين المالي والأدبي للدورة المذكورة في أقرب االآجال.\n',
 'وكانت وزارة الشؤون الثقافية قد دعت في بلاغ أصدرته اليوم الهيئة المديرة للدورة 27 لأيام قرطاج السينمائية إلى تقديم التقريرين المالي والأدبي للدورة المذكورة في أقرب االآجال. وقال لطيف إنه مستعد للمحاسبة وعلى وزير الشؤون الثقافية تحمل مسؤوليته وفق تعبيره. وأضاف لطيف أنه تفاجأ بقرار\xa0إنهاء تكليفه مؤكدا أنه لم يتم إعلامه بذلك.')

### Preprocessing

In [9]:
# for decoder sequence
summary = summary.apply(lambda x: '<go> ' + x + ' <stop>')
summary.head()

0    <go> \nأشرف رئيس الجمهورية الباجي قايد السبسي ...
1    <go> \nتحصل كتاب "المصحف وقراءاته" الذي ألفه ب...
2    <go> تونس حاضرة من جهة أخرى ستكون تونس حاضرة ف...
3    <go> واستأجرت صاحبة المشروع المحامية والكاتبة ...
4    <go> \nنعت وزارة الشّؤون الثّقافيّة المنشد الص...
Name: summarizer, dtype: object

#### Tokenizing the texts into integer tokens

In [10]:
# since < and > from default tokens cannot be removed
filters = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n\xa0'
oov_token = '<unk>'

In [11]:
counter = 0
def delete_links(input_text):
    pettern  = r'''(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’]))'''
    out_text = re.sub(pettern, ' ', input_text)
    return out_text

def delete_repeated_characters(input_text):
    pattern  = r'(.)\1{2,}'
    out_text = re.sub(pattern, r"\1\1", input_text)
    return out_text

def remove_extra_spaces(input_text):
    replace = ' +'
    out_text = re.sub(replace, " ", input_text)
    words = nltk.word_tokenize(out_text)
    words = [word for word in words if word.isalpha()]
    out_text = ' '.join(words)
    return out_text

def replace_letters(input_text):
    replace = {"أ": "ا","ة": "ه","إ": "ا","آ": "ا","": ""}
    replace = dict((re.escape(k), v) for k, v in replace.items())
    pattern = re.compile("|".join(replace.keys()))
    out_text = pattern.sub(lambda m: replace[re.escape(m.group(0))], input_text)
    return out_text

def clean_text(input_text):
    replace = r'[^\u0621-\u064A\u0660-\u0669\u06F0-\u06F90-9]'
    out_text = re.sub(replace, " ", input_text)
    #words = nltk.word_tokenize(out_text)
    #words = [word for word in words if word.isalpha()]
    #out_text = ' '.join(words)
    return out_text

def remove_vowelization(input_text):
    vowelization = re.compile(""" ّ|َ|ً|ُ|ٌ|ِ|ٍ|ْ|ـ""", re.VERBOSE)
    out_text = re.sub(vowelization, '', input_text)
    return out_text

def delete_stopwords(input_text):
    stop_words = set(nltk.corpus.stopwords.words("arabic") + nltk.corpus.stopwords.words("english"))
    tokenizer = nltk.tokenize.WhitespaceTokenizer()
    tokens = tokenizer.tokenize(input_text)
    wnl = nltk.WordNetLemmatizer()
    lemmatizedTokens =[wnl.lemmatize(t) for t in tokens]
    out_text = [w for w in lemmatizedTokens if not w in stop_words]
    out_text = ' '.join(out_text)
    return out_text

def stem_text(input_text):
    st = ISRIStemmer()
    tokenizer = nltk.tokenize.WhitespaceTokenizer()
    tokens = tokenizer.tokenize(input_text)
    out_text = [st.stem(w) for w in tokens]
    out_text = ' '.join(out_text)
    return out_text


def text_prepare(input_text, ar_text):
    global counter
    counter +=1

    out_text = delete_links(input_text)
    out_text = delete_repeated_characters(out_text)
    out_text = delete_stopwords(input_text)
    out_text = clean_text(input_text)
    out_text = remove_extra_spaces(out_text)
    if(counter%100==0):
        print(counter,'\n',out_text)
    return out_text

In [None]:
xl_sum = []

In [13]:
xl_sum  = news
xl_sum.head()

Unnamed: 0,text,summarizer
0,\nأشرف رئيس الجمهورية الباجي قايد السبسي اليوم...,\nأشرف رئيس الجمهورية الباجي قايد السبسي اليوم...
1,"\nتحصل كتاب ""المصحف وقراءاته"" الذي ألفه باحثون...","\nتحصل كتاب ""المصحف وقراءاته"" الذي ألفه باحثون..."
2,\nاحتضن جناح تونس في القرية الدولية للأفلام بم...,تونس حاضرة من جهة أخرى ستكون تونس حاضرة في قائ...
3,\nشهدت برلين أمس الجمعة افتتاح مسجد فريد من نو...,واستأجرت صاحبة المشروع المحامية والكاتبة سيران...
4,\nنعت وزارة الشّؤون الثّقافيّة المنشد الصّوفي ...,\nنعت وزارة الشّؤون الثّقافيّة المنشد الصّوفي ...


In [18]:
xl_sum['paragraph'] = xl_sum['text'].apply(text_prepare, args=(True,))
xl_sum['summary'] = xl_sum['summarizer'].apply(text_prepare, args=(True,))

LookupError: 
**********************************************************************
  Resource [93momw-1.4[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('omw-1.4')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mcorpora/omw-1.4[0m

  Searched in:
    - '/root/nltk_data'
    - '/opt/conda/nltk_data'
    - '/opt/conda/share/nltk_data'
    - '/opt/conda/lib/nltk_data'
    - '/usr/share/nltk_data'
    - '/usr/local/share/nltk_data'
    - '/usr/lib/nltk_data'
    - '/usr/local/lib/nltk_data'
**********************************************************************


In [12]:
document_tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token=oov_token)
summary_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters=filters, oov_token=oov_token)
document_tokenizer.fit_on_texts(document)
summary_tokenizer.fit_on_texts(summary)

In [13]:
inputs = document_tokenizer.texts_to_sequences(document)
targets = summary_tokenizer.texts_to_sequences(summary)

In [14]:
summary_tokenizer.texts_to_sequences([" كل ابن أُنثى وإن طالت سلامتُهُ بوماً على آلةٍ حدباء محمول "])

[[56, 1832, 1, 1612, 3845, 1, 1, 6, 1, 1, 8323]]

In [15]:
summary_tokenizer.sequences_to_texts([[56, 1813, 1, 1690, 3818, 1, 1, 6, 1, 1, 8275]])

['كل بقابس <unk> الخطوة يلتقي <unk> <unk> على <unk> <unk> بنظام']

In [16]:
encoder_vocab_size = len(document_tokenizer.word_index) + 1
decoder_vocab_size = len(summary_tokenizer.word_index) + 1

# vocab_size
encoder_vocab_size, decoder_vocab_size

(92110, 49459)

#### Obtaining insights on lengths for defining maxlen

In [17]:
document_lengths = pd.Series([len(x) for x in document])
summary_lengths = pd.Series([len(x) for x in summary])

In [18]:
document_lengths.describe()

count    8378.000000
mean      663.660062
std       491.505431
min        82.000000
25%       366.000000
50%       524.500000
75%       793.000000
max      9854.000000
dtype: float64

In [19]:
summary_lengths.describe()

count    8378.000000
mean      261.807591
std       117.474832
min        22.000000
25%       160.000000
50%       257.000000
75%       351.000000
max       596.000000
dtype: float64

In [20]:
# maxlen
# taking values > and round figured to 75th percentile
# at the same time not leaving high variance
encoder_maxlen = 400
decoder_maxlen = 75

#### Padding/Truncating sequences for identical sequence lengths

In [21]:
inputs = tf.keras.preprocessing.sequence.pad_sequences(inputs, maxlen=encoder_maxlen, padding='post', truncating='post')
targets = tf.keras.preprocessing.sequence.pad_sequences(targets, maxlen=decoder_maxlen, padding='post', truncating='post')

### Creating dataset pipeline

In [22]:
inputs = tf.cast(inputs, dtype=tf.int32)
targets = tf.cast(targets, dtype=tf.int32)

In [23]:
BUFFER_SIZE = 20000
BATCH_SIZE = 64

In [24]:
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

### Positional Encoding for adding notion of position among words as unlike RNN this is non-directional

In [25]:
def get_angles(position, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return position * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(
        np.arange(position)[:, np.newaxis],
        np.arange(d_model)[np.newaxis, :],
        d_model
    )

    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)


### Masking

- Padding mask for masking "pad" sequences
- Lookahead mask for masking future words from contributing in prediction of current words in self attention

In [26]:
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return seq[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask

### Building the Model

#### Scaled Dot Product

In [27]:
def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    output = tf.matmul(attention_weights, v)
    return output, attention_weights

#### Multi-Headed Attention

In [28]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
            
        return output, attention_weights

### Feed Forward Network

In [29]:
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])

#### Fundamental Unit of Transformer encoder

In [30]:
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
    
    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2


#### Fundamental Unit of Transformer decoder

In [31]:
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)
    
    
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)

        return out3, attn_weights_block1, attn_weights_block2



#### Encoder consisting of multiple EncoderLayer(s)

In [32]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)
        
    def call(self, x, training, mask):
        seq_len = tf.shape(x)[1]

        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)
    
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
    
        return x


#### Decoder consisting of multiple DecoderLayer(s)

In [33]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)
    
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)

            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
    
        return x, attention_weights



#### Finally, the Transformer

In [34]:
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, pe_input, rate)

        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, pe_target, rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask)

        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)

        final_output = self.final_layer(dec_output)

        return final_output, attention_weights


### Training

In [35]:
# hyper-params
num_layers = 4
d_model = 128
dff = 512
num_heads = 8
EPOCHS = 100

#### Adam optimizer with custom learning rate scheduling

In [36]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps
    
    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


#### Defining losses and other metrics 

In [37]:
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

In [38]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_sum(loss_)/tf.reduce_sum(mask)
train_loss = tf.keras.metrics.Mean(name='train_loss')

#### Transformer

In [39]:
transformer = Transformer(
    num_layers, 
    d_model, 
    num_heads, 
    dff,
    encoder_vocab_size, 
    decoder_vocab_size, 
    pe_input=encoder_vocab_size, 
    pe_target=decoder_vocab_size,
)

#### Masks

In [40]:
def create_masks(inp, tar):
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
  
    return enc_padding_mask, combined_mask, dec_padding_mask


#### Checkpoints

In [41]:
checkpoint_path = "checkpoints"

ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

#### Training steps

In [42]:
@tf.function
def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(
            inp, tar_inp, 
            True, 
            enc_padding_mask, 
            combined_mask, 
            dec_padding_mask
        )
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    return predictions

In [87]:
# # import plotly.graph_objects as go
# from tqdm import tqdm
# import time

# # Initialize lists to store training and validation/test loss
# train_losses = []
# val_losses = []

# # Loop over epochs
# for epoch in tqdm(range(EPOCHS)):
#     start = time.time()

#     train_loss.reset_states()
  
#     for (batch, (inp, tar)) in enumerate(dataset):
#         train_step(inp, tar)
    
#         # 55k samples
#         # we display 3 batch results -- 0th, middle and last one (approx)
#         # 55k / 64 ~ 858; 858 / 2 = 429
#         if batch % 429 == 0:
#             print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, train_loss.result()))

#     # Calculate and append training loss
#     train_losses.append(train_loss.result())

#     if (epoch + 1) % 5 == 0:
#         ckpt_save_path = ckpt_manager.save()
#         print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
#     print ('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))

#     # Simulate validation/test loss calculation
#     val_loss = ...  # calculate validation/test loss
#     val_losses.append(val_loss)

#     print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))



In [48]:
# # !pip install rouge
# import rouge
# from rouge import rouge_score
# # from rouge_score import rouge_scorer
# from nltk.translate.bleu_score import corpus_bleu
# # Initialize Rouge scorer
# rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeLsum'], use_stemmer=True)

# # Initialize BLEU scorer
# def compute_bleu(reference, hypothesis):
#     return corpus_bleu([[r.split()] for r in reference], [h.split() for h in hypothesis])


[0m

NameError: name 'rouge_scorer' is not defined

In [43]:
from tqdm import tqdm 
import time

# Initialize Rouge and BLEU calculators
# rouge_calculator = RougeCalculator()
# bleu_calculator = BleuCalculator()

# Initialize metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

for epoch in tqdm(range(EPOCHS)):
    start = time.time()

    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
  
    for (batch, (inp, tar)) in enumerate(dataset):
        predictions = train_step(inp, tar)
    
        # Update accuracy
        train_accuracy(tar[:, 1:], predictions)

        # 55k samples
        # we display 3 batch results -- 0th, middle and last one (approx)
        # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, train_loss.result()))
      
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
    print ('Epoch {} Loss: {:.4f}'.format(epoch + 1, train_loss.result()))
    print ('Epoch {} Accuracy: {:.4f}'.format(epoch + 1, train_accuracy.result()))

#     # Calculate Rouge and BLEU scores for training data
#     train_rouge_scores = rouge_calculator.compute_scores(train_reference, train_predictions)
#     train_bleu_score = bleu_calculator.compute_bleu(train_reference, train_predictions)

#     print ('Train Rouge Scores:', train_rouge_scores)
#     print ('Train BLEU Score:', train_bleu_score)

#     # Calculate Rouge and BLEU scores for testing data
#     test_rouge_scores = rouge_calculator.compute_scores(test_reference, test_predictions)
#     test_bleu_score = bleu_calculator.compute_bleu(test_reference, test_predictions)

#     print ('Test Rouge Scores:', test_rouge_scores)
#     print ('Test BLEU Score:', test_bleu_score)

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 Batch 0 Loss 10.8121


  1%|          | 1/100 [01:16<2:05:40, 76.16s/it]

Epoch 1 Loss: 10.7137
Epoch 1 Accuracy: 0.0102
Time taken for 1 epoch: 76.16401386260986 secs

Epoch 2 Batch 0 Loss 10.5221


  2%|▏         | 2/100 [02:20<1:53:14, 69.33s/it]

Epoch 2 Loss: 10.1771
Epoch 2 Accuracy: 0.0167
Time taken for 1 epoch: 64.54160857200623 secs

Epoch 3 Batch 0 Loss 9.7594


  3%|▎         | 3/100 [03:25<1:48:52, 67.35s/it]

Epoch 3 Loss: 9.3648
Epoch 3 Accuracy: 0.0167
Time taken for 1 epoch: 64.9888083934784 secs

Epoch 4 Batch 0 Loss 9.0053


  4%|▍         | 4/100 [04:30<1:46:16, 66.42s/it]

Epoch 4 Loss: 8.7945
Epoch 4 Accuracy: 0.0167
Time taken for 1 epoch: 65.00418472290039 secs

Epoch 5 Batch 0 Loss 8.6519


  5%|▌         | 5/100 [05:36<1:44:56, 66.28s/it]

Saving checkpoint for epoch 5 at checkpoints/ckpt-1
Epoch 5 Loss: 8.6553
Epoch 5 Accuracy: 0.0192
Time taken for 1 epoch: 66.02557396888733 secs

Epoch 6 Batch 0 Loss 8.5877


  6%|▌         | 6/100 [06:41<1:43:14, 65.90s/it]

Epoch 6 Loss: 8.4628
Epoch 6 Accuracy: 0.0264
Time taken for 1 epoch: 65.1675157546997 secs

Epoch 7 Batch 0 Loss 8.3351


  7%|▋         | 7/100 [07:46<1:41:44, 65.64s/it]

Epoch 7 Loss: 8.0822
Epoch 7 Accuracy: 0.0397
Time taken for 1 epoch: 65.0904860496521 secs

Epoch 8 Batch 0 Loss 7.8517


  8%|▊         | 8/100 [08:52<1:40:24, 65.49s/it]

Epoch 8 Loss: 7.6487
Epoch 8 Accuracy: 0.0518
Time taken for 1 epoch: 65.17139005661011 secs

Epoch 9 Batch 0 Loss 7.3428


  9%|▉         | 9/100 [09:57<1:39:10, 65.39s/it]

Epoch 9 Loss: 7.2362
Epoch 9 Accuracy: 0.0656
Time taken for 1 epoch: 65.16575646400452 secs

Epoch 10 Batch 0 Loss 6.8483


 10%|█         | 10/100 [11:03<1:38:28, 65.65s/it]

Saving checkpoint for epoch 10 at checkpoints/ckpt-2
Epoch 10 Loss: 6.8553
Epoch 10 Accuracy: 0.0776
Time taken for 1 epoch: 66.23180317878723 secs

Epoch 11 Batch 0 Loss 6.5348


 11%|█         | 11/100 [12:08<1:37:11, 65.52s/it]

Epoch 11 Loss: 6.4954
Epoch 11 Accuracy: 0.0889
Time taken for 1 epoch: 65.2299497127533 secs

Epoch 12 Batch 0 Loss 6.1555


 12%|█▏        | 12/100 [13:14<1:36:00, 65.46s/it]

Epoch 12 Loss: 6.1480
Epoch 12 Accuracy: 0.1007
Time taken for 1 epoch: 65.32147312164307 secs

Epoch 13 Batch 0 Loss 5.7968


 13%|█▎        | 13/100 [14:19<1:34:49, 65.40s/it]

Epoch 13 Loss: 5.8104
Epoch 13 Accuracy: 0.1125
Time taken for 1 epoch: 65.24665832519531 secs

Epoch 14 Batch 0 Loss 5.4427


 14%|█▍        | 14/100 [15:24<1:33:39, 65.34s/it]

Epoch 14 Loss: 5.4767
Epoch 14 Accuracy: 0.1243
Time taken for 1 epoch: 65.20362401008606 secs

Epoch 15 Batch 0 Loss 4.9812


 15%|█▌        | 15/100 [16:30<1:32:53, 65.57s/it]

Saving checkpoint for epoch 15 at checkpoints/ckpt-3
Epoch 15 Loss: 5.1413
Epoch 15 Accuracy: 0.1379
Time taken for 1 epoch: 66.11102533340454 secs

Epoch 16 Batch 0 Loss 4.8044


 16%|█▌        | 16/100 [17:35<1:31:38, 65.45s/it]

Epoch 16 Loss: 4.7972
Epoch 16 Accuracy: 0.1535
Time taken for 1 epoch: 65.18008661270142 secs

Epoch 17 Batch 0 Loss 4.3963


 17%|█▋        | 17/100 [18:41<1:30:31, 65.44s/it]

Epoch 17 Loss: 4.4509
Epoch 17 Accuracy: 0.1716
Time taken for 1 epoch: 65.39781403541565 secs

Epoch 18 Batch 0 Loss 3.9949


 18%|█▊        | 18/100 [19:46<1:29:20, 65.37s/it]

Epoch 18 Loss: 4.1057
Epoch 18 Accuracy: 0.1916
Time taken for 1 epoch: 65.22041869163513 secs

Epoch 19 Batch 0 Loss 3.6774


 19%|█▉        | 19/100 [20:51<1:28:15, 65.38s/it]

Epoch 19 Loss: 3.7554
Epoch 19 Accuracy: 0.2129
Time taken for 1 epoch: 65.39759135246277 secs

Epoch 20 Batch 0 Loss 3.4123


 20%|██        | 20/100 [21:57<1:27:25, 65.57s/it]

Saving checkpoint for epoch 20 at checkpoints/ckpt-4
Epoch 20 Loss: 3.4076
Epoch 20 Accuracy: 0.2360
Time taken for 1 epoch: 66.01503920555115 secs

Epoch 21 Batch 0 Loss 2.8106


 21%|██        | 21/100 [23:03<1:26:13, 65.49s/it]

Epoch 21 Loss: 3.0680
Epoch 21 Accuracy: 0.2589
Time taken for 1 epoch: 65.30859756469727 secs

Epoch 22 Batch 0 Loss 2.7201


 22%|██▏       | 22/100 [24:08<1:25:05, 65.46s/it]

Epoch 22 Loss: 2.7349
Epoch 22 Accuracy: 0.2843
Time taken for 1 epoch: 65.37945318222046 secs

Epoch 23 Batch 0 Loss 2.2268


 23%|██▎       | 23/100 [25:13<1:23:57, 65.43s/it]

Epoch 23 Loss: 2.4168
Epoch 23 Accuracy: 0.3088
Time taken for 1 epoch: 65.35070276260376 secs

Epoch 24 Batch 0 Loss 1.8470


 24%|██▍       | 24/100 [26:19<1:22:48, 65.38s/it]

Epoch 24 Loss: 2.1179
Epoch 24 Accuracy: 0.3346
Time taken for 1 epoch: 65.27370405197144 secs

Epoch 25 Batch 0 Loss 1.7144


 25%|██▌       | 25/100 [27:25<1:22:05, 65.67s/it]

Saving checkpoint for epoch 25 at checkpoints/ckpt-5
Epoch 25 Loss: 1.8545
Epoch 25 Accuracy: 0.3576
Time taken for 1 epoch: 66.33617973327637 secs

Epoch 26 Batch 0 Loss 1.4357


 26%|██▌       | 26/100 [28:30<1:20:51, 65.56s/it]

Epoch 26 Loss: 1.6163
Epoch 26 Accuracy: 0.3789
Time taken for 1 epoch: 65.30693197250366 secs

Epoch 27 Batch 0 Loss 1.2571


 27%|██▋       | 27/100 [29:36<1:19:39, 65.47s/it]

Epoch 27 Loss: 1.4189
Epoch 27 Accuracy: 0.3968
Time taken for 1 epoch: 65.27124071121216 secs

Epoch 28 Batch 0 Loss 1.1893


 28%|██▊       | 28/100 [30:41<1:18:32, 65.45s/it]

Epoch 28 Loss: 1.2524
Epoch 28 Accuracy: 0.4128
Time taken for 1 epoch: 65.3794195652008 secs

Epoch 29 Batch 0 Loss 0.8849


 29%|██▉       | 29/100 [31:46<1:17:23, 65.41s/it]

Epoch 29 Loss: 1.1096
Epoch 29 Accuracy: 0.4265
Time taken for 1 epoch: 65.31606030464172 secs

Epoch 30 Batch 0 Loss 0.8534


 30%|███       | 30/100 [32:53<1:16:39, 65.71s/it]

Saving checkpoint for epoch 30 at checkpoints/ckpt-6
Epoch 30 Loss: 0.9915
Epoch 30 Accuracy: 0.4379
Time taken for 1 epoch: 66.42801308631897 secs

Epoch 31 Batch 0 Loss 0.6874


 31%|███       | 31/100 [33:58<1:15:27, 65.62s/it]

Epoch 31 Loss: 0.8891
Epoch 31 Accuracy: 0.4491
Time taken for 1 epoch: 65.39900541305542 secs

Epoch 32 Batch 0 Loss 0.6696


 32%|███▏      | 32/100 [35:04<1:14:17, 65.56s/it]

Epoch 32 Loss: 0.7951
Epoch 32 Accuracy: 0.4595
Time taken for 1 epoch: 65.41250133514404 secs

Epoch 33 Batch 0 Loss 0.5513


 33%|███▎      | 33/100 [36:09<1:13:08, 65.50s/it]

Epoch 33 Loss: 0.6975
Epoch 33 Accuracy: 0.4714
Time taken for 1 epoch: 65.36624026298523 secs

Epoch 34 Batch 0 Loss 0.4881


 34%|███▍      | 34/100 [37:14<1:12:02, 65.49s/it]

Epoch 34 Loss: 0.6183
Epoch 34 Accuracy: 0.4815
Time taken for 1 epoch: 65.4605484008789 secs

Epoch 35 Batch 0 Loss 0.4926


 35%|███▌      | 35/100 [38:21<1:11:15, 65.78s/it]

Saving checkpoint for epoch 35 at checkpoints/ckpt-7
Epoch 35 Loss: 0.5573
Epoch 35 Accuracy: 0.4898
Time taken for 1 epoch: 66.44639015197754 secs

Epoch 36 Batch 0 Loss 0.4049


 36%|███▌      | 36/100 [39:26<1:10:02, 65.66s/it]

Epoch 36 Loss: 0.5030
Epoch 36 Accuracy: 0.4971
Time taken for 1 epoch: 65.38026642799377 secs

Epoch 37 Batch 0 Loss 0.3503


 37%|███▋      | 37/100 [40:32<1:08:55, 65.64s/it]

Epoch 37 Loss: 0.4554
Epoch 37 Accuracy: 0.5041
Time taken for 1 epoch: 65.60781931877136 secs

Epoch 38 Batch 0 Loss 0.2927


 38%|███▊      | 38/100 [41:37<1:07:44, 65.55s/it]

Epoch 38 Loss: 0.4147
Epoch 38 Accuracy: 0.5101
Time taken for 1 epoch: 65.33800148963928 secs

Epoch 39 Batch 0 Loss 0.3024


 39%|███▉      | 39/100 [42:43<1:06:38, 65.55s/it]

Epoch 39 Loss: 0.3838
Epoch 39 Accuracy: 0.5145
Time taken for 1 epoch: 65.55772161483765 secs

Epoch 40 Batch 0 Loss 0.2829


 40%|████      | 40/100 [43:49<1:05:45, 65.76s/it]

Saving checkpoint for epoch 40 at checkpoints/ckpt-8
Epoch 40 Loss: 0.3536
Epoch 40 Accuracy: 0.5191
Time taken for 1 epoch: 66.25106573104858 secs

Epoch 41 Batch 0 Loss 0.2725


 41%|████      | 41/100 [44:54<1:04:33, 65.65s/it]

Epoch 41 Loss: 0.3258
Epoch 41 Accuracy: 0.5235
Time taken for 1 epoch: 65.38230633735657 secs

Epoch 42 Batch 0 Loss 0.2482


 42%|████▏     | 42/100 [46:00<1:03:22, 65.56s/it]

Epoch 42 Loss: 0.3035
Epoch 42 Accuracy: 0.5266
Time taken for 1 epoch: 65.34623718261719 secs

Epoch 43 Batch 0 Loss 0.2222


 43%|████▎     | 43/100 [47:05<1:02:14, 65.52s/it]

Epoch 43 Loss: 0.2840
Epoch 43 Accuracy: 0.5297
Time taken for 1 epoch: 65.42472767829895 secs

Epoch 44 Batch 0 Loss 0.2078


 44%|████▍     | 44/100 [48:11<1:01:06, 65.48s/it]

Epoch 44 Loss: 0.2661
Epoch 44 Accuracy: 0.5323
Time taken for 1 epoch: 65.38583087921143 secs

Epoch 45 Batch 0 Loss 0.2298


 45%|████▌     | 45/100 [49:17<1:00:15, 65.74s/it]

Saving checkpoint for epoch 45 at checkpoints/ckpt-9
Epoch 45 Loss: 0.2525
Epoch 45 Accuracy: 0.5345
Time taken for 1 epoch: 66.3420717716217 secs

Epoch 46 Batch 0 Loss 0.1834


 46%|████▌     | 46/100 [50:22<59:03, 65.62s/it]  

Epoch 46 Loss: 0.2369
Epoch 46 Accuracy: 0.5370
Time taken for 1 epoch: 65.3372631072998 secs

Epoch 47 Batch 0 Loss 0.1750


 47%|████▋     | 47/100 [51:28<57:54, 65.55s/it]

Epoch 47 Loss: 0.2249
Epoch 47 Accuracy: 0.5389
Time taken for 1 epoch: 65.38879728317261 secs

Epoch 48 Batch 0 Loss 0.1844


 48%|████▊     | 48/100 [52:33<56:45, 65.49s/it]

Epoch 48 Loss: 0.2147
Epoch 48 Accuracy: 0.5406
Time taken for 1 epoch: 65.35401725769043 secs

Epoch 49 Batch 0 Loss 0.1702


 49%|████▉     | 49/100 [53:39<55:41, 65.51s/it]

Epoch 49 Loss: 0.2017
Epoch 49 Accuracy: 0.5426
Time taken for 1 epoch: 65.56240057945251 secs

Epoch 50 Batch 0 Loss 0.1670


 50%|█████     | 50/100 [54:45<54:49, 65.80s/it]

Saving checkpoint for epoch 50 at checkpoints/ckpt-10
Epoch 50 Loss: 0.1925
Epoch 50 Accuracy: 0.5444
Time taken for 1 epoch: 66.4630663394928 secs

Epoch 51 Batch 0 Loss 0.1852


 51%|█████     | 51/100 [55:50<53:36, 65.65s/it]

Epoch 51 Loss: 0.1825
Epoch 51 Accuracy: 0.5459
Time taken for 1 epoch: 65.30337476730347 secs

Epoch 52 Batch 0 Loss 0.1587


 52%|█████▏    | 52/100 [56:56<52:25, 65.54s/it]

Epoch 52 Loss: 0.1722
Epoch 52 Accuracy: 0.5474
Time taken for 1 epoch: 65.28259229660034 secs

Epoch 53 Batch 0 Loss 0.1621


 53%|█████▎    | 53/100 [58:01<51:16, 65.45s/it]

Epoch 53 Loss: 0.1677
Epoch 53 Accuracy: 0.5481
Time taken for 1 epoch: 65.24159669876099 secs

Epoch 54 Batch 0 Loss 0.1246


 54%|█████▍    | 54/100 [59:06<50:08, 65.40s/it]

Epoch 54 Loss: 0.1602
Epoch 54 Accuracy: 0.5493
Time taken for 1 epoch: 65.26932525634766 secs

Epoch 55 Batch 0 Loss 0.1406


 55%|█████▌    | 55/100 [1:00:12<49:14, 65.66s/it]

Saving checkpoint for epoch 55 at checkpoints/ckpt-11
Epoch 55 Loss: 0.1532
Epoch 55 Accuracy: 0.5503
Time taken for 1 epoch: 66.26700043678284 secs

Epoch 56 Batch 0 Loss 0.1321


 56%|█████▌    | 56/100 [1:01:18<48:03, 65.54s/it]

Epoch 56 Loss: 0.1474
Epoch 56 Accuracy: 0.5514
Time taken for 1 epoch: 65.26070928573608 secs

Epoch 57 Batch 0 Loss 0.1107


 57%|█████▋    | 57/100 [1:02:23<46:54, 65.45s/it]

Epoch 57 Loss: 0.1391
Epoch 57 Accuracy: 0.5527
Time taken for 1 epoch: 65.24136567115784 secs

Epoch 58 Batch 0 Loss 0.1247


 58%|█████▊    | 58/100 [1:03:28<45:46, 65.38s/it]

Epoch 58 Loss: 0.1347
Epoch 58 Accuracy: 0.5535
Time taken for 1 epoch: 65.22942566871643 secs

Epoch 59 Batch 0 Loss 0.0983


 59%|█████▉    | 59/100 [1:04:33<44:39, 65.35s/it]

Epoch 59 Loss: 0.1315
Epoch 59 Accuracy: 0.5541
Time taken for 1 epoch: 65.27144718170166 secs

Epoch 60 Batch 0 Loss 0.1109


 60%|██████    | 60/100 [1:05:39<43:43, 65.58s/it]

Saving checkpoint for epoch 60 at checkpoints/ckpt-12
Epoch 60 Loss: 0.1241
Epoch 60 Accuracy: 0.5551
Time taken for 1 epoch: 66.1025812625885 secs

Epoch 61 Batch 0 Loss 0.0983


 61%|██████    | 61/100 [1:06:45<42:33, 65.49s/it]

Epoch 61 Loss: 0.1191
Epoch 61 Accuracy: 0.5560
Time taken for 1 epoch: 65.27210283279419 secs

Epoch 62 Batch 0 Loss 0.0926


 62%|██████▏   | 62/100 [1:07:50<41:26, 65.44s/it]

Epoch 62 Loss: 0.1167
Epoch 62 Accuracy: 0.5564
Time taken for 1 epoch: 65.34123182296753 secs

Epoch 63 Batch 0 Loss 0.0884


 63%|██████▎   | 63/100 [1:08:55<40:18, 65.36s/it]

Epoch 63 Loss: 0.1116
Epoch 63 Accuracy: 0.5573
Time taken for 1 epoch: 65.18000364303589 secs

Epoch 64 Batch 0 Loss 0.0999


 64%|██████▍   | 64/100 [1:10:01<39:12, 65.33s/it]

Epoch 64 Loss: 0.1070
Epoch 64 Accuracy: 0.5581
Time taken for 1 epoch: 65.26262021064758 secs

Epoch 65 Batch 0 Loss 0.0792


 65%|██████▌   | 65/100 [1:11:07<38:13, 65.53s/it]

Saving checkpoint for epoch 65 at checkpoints/ckpt-13
Epoch 65 Loss: 0.1038
Epoch 65 Accuracy: 0.5587
Time taken for 1 epoch: 65.99302506446838 secs

Epoch 66 Batch 0 Loss 0.0912


 66%|██████▌   | 66/100 [1:12:12<37:05, 65.45s/it]

Epoch 66 Loss: 0.1014
Epoch 66 Accuracy: 0.5591
Time taken for 1 epoch: 65.26125693321228 secs

Epoch 67 Batch 0 Loss 0.0861


 67%|██████▋   | 67/100 [1:13:17<35:58, 65.41s/it]

Epoch 67 Loss: 0.0976
Epoch 67 Accuracy: 0.5596
Time taken for 1 epoch: 65.31169891357422 secs

Epoch 68 Batch 0 Loss 0.0857


 68%|██████▊   | 68/100 [1:14:22<34:51, 65.37s/it]

Epoch 68 Loss: 0.0946
Epoch 68 Accuracy: 0.5602
Time taken for 1 epoch: 65.28817296028137 secs

Epoch 69 Batch 0 Loss 0.0803


 69%|██████▉   | 69/100 [1:15:28<33:45, 65.34s/it]

Epoch 69 Loss: 0.0920
Epoch 69 Accuracy: 0.5605
Time taken for 1 epoch: 65.26100635528564 secs

Epoch 70 Batch 0 Loss 0.0925


 70%|███████   | 70/100 [1:16:34<32:48, 65.62s/it]

Saving checkpoint for epoch 70 at checkpoints/ckpt-14
Epoch 70 Loss: 0.0875
Epoch 70 Accuracy: 0.5613
Time taken for 1 epoch: 66.27508664131165 secs

Epoch 71 Batch 0 Loss 0.0674


 71%|███████   | 71/100 [1:17:39<31:40, 65.53s/it]

Epoch 71 Loss: 0.0855
Epoch 71 Accuracy: 0.5618
Time taken for 1 epoch: 65.33165526390076 secs

Epoch 72 Batch 0 Loss 0.0628


 72%|███████▏  | 72/100 [1:18:45<30:32, 65.46s/it]

Epoch 72 Loss: 0.0839
Epoch 72 Accuracy: 0.5618
Time taken for 1 epoch: 65.29195594787598 secs

Epoch 73 Batch 0 Loss 0.0747


 73%|███████▎  | 73/100 [1:19:50<29:25, 65.40s/it]

Epoch 73 Loss: 0.0806
Epoch 73 Accuracy: 0.5624
Time taken for 1 epoch: 65.26714372634888 secs

Epoch 74 Batch 0 Loss 0.0614


 74%|███████▍  | 74/100 [1:20:55<28:19, 65.36s/it]

Epoch 74 Loss: 0.0790
Epoch 74 Accuracy: 0.5627
Time taken for 1 epoch: 65.24929618835449 secs

Epoch 75 Batch 0 Loss 0.0679


 75%|███████▌  | 75/100 [1:22:01<27:18, 65.56s/it]

Saving checkpoint for epoch 75 at checkpoints/ckpt-15
Epoch 75 Loss: 0.0778
Epoch 75 Accuracy: 0.5628
Time taken for 1 epoch: 66.02468848228455 secs

Epoch 76 Batch 0 Loss 0.0683


 76%|███████▌  | 76/100 [1:23:06<26:11, 65.48s/it]

Epoch 76 Loss: 0.0738
Epoch 76 Accuracy: 0.5636
Time taken for 1 epoch: 65.28208541870117 secs

Epoch 77 Batch 0 Loss 0.0673


 77%|███████▋  | 77/100 [1:24:12<25:04, 65.41s/it]

Epoch 77 Loss: 0.0726
Epoch 77 Accuracy: 0.5637
Time taken for 1 epoch: 65.25299739837646 secs

Epoch 78 Batch 0 Loss 0.0472


 78%|███████▊  | 78/100 [1:25:17<23:57, 65.34s/it]

Epoch 78 Loss: 0.0706
Epoch 78 Accuracy: 0.5642
Time taken for 1 epoch: 65.18315935134888 secs

Epoch 79 Batch 0 Loss 0.0605


 79%|███████▉  | 79/100 [1:26:22<22:51, 65.30s/it]

Epoch 79 Loss: 0.0676
Epoch 79 Accuracy: 0.5645
Time taken for 1 epoch: 65.18841981887817 secs

Epoch 80 Batch 0 Loss 0.0558


 80%|████████  | 80/100 [1:27:28<21:50, 65.52s/it]

Saving checkpoint for epoch 80 at checkpoints/ckpt-16
Epoch 80 Loss: 0.0667
Epoch 80 Accuracy: 0.5646
Time taken for 1 epoch: 66.05284452438354 secs

Epoch 81 Batch 0 Loss 0.0651


 81%|████████  | 81/100 [1:28:33<20:43, 65.45s/it]

Epoch 81 Loss: 0.0651
Epoch 81 Accuracy: 0.5651
Time taken for 1 epoch: 65.26150918006897 secs

Epoch 82 Batch 0 Loss 0.0531


 82%|████████▏ | 82/100 [1:29:39<19:36, 65.38s/it]

Epoch 82 Loss: 0.0633
Epoch 82 Accuracy: 0.5653
Time taken for 1 epoch: 65.21295976638794 secs

Epoch 83 Batch 0 Loss 0.0477


 83%|████████▎ | 83/100 [1:30:44<18:30, 65.32s/it]

Epoch 83 Loss: 0.0611
Epoch 83 Accuracy: 0.5655
Time taken for 1 epoch: 65.17601704597473 secs

Epoch 84 Batch 0 Loss 0.0502


 84%|████████▍ | 84/100 [1:31:49<17:25, 65.31s/it]

Epoch 84 Loss: 0.0605
Epoch 84 Accuracy: 0.5659
Time taken for 1 epoch: 65.3038318157196 secs

Epoch 85 Batch 0 Loss 0.0590


 85%|████████▌ | 85/100 [1:32:55<16:24, 65.62s/it]

Saving checkpoint for epoch 85 at checkpoints/ckpt-17
Epoch 85 Loss: 0.0592
Epoch 85 Accuracy: 0.5661
Time taken for 1 epoch: 66.34156322479248 secs

Epoch 86 Batch 0 Loss 0.0401


 86%|████████▌ | 86/100 [1:34:01<15:17, 65.54s/it]

Epoch 86 Loss: 0.0565
Epoch 86 Accuracy: 0.5664
Time taken for 1 epoch: 65.34748411178589 secs

Epoch 87 Batch 0 Loss 0.0545


 87%|████████▋ | 87/100 [1:35:06<14:10, 65.46s/it]

Epoch 87 Loss: 0.0555
Epoch 87 Accuracy: 0.5666
Time taken for 1 epoch: 65.25665926933289 secs

Epoch 88 Batch 0 Loss 0.0422


 88%|████████▊ | 88/100 [1:36:11<13:04, 65.40s/it]

Epoch 88 Loss: 0.0554
Epoch 88 Accuracy: 0.5667
Time taken for 1 epoch: 65.27080798149109 secs

Epoch 89 Batch 0 Loss 0.0387


 89%|████████▉ | 89/100 [1:37:17<11:58, 65.36s/it]

Epoch 89 Loss: 0.0533
Epoch 89 Accuracy: 0.5670
Time taken for 1 epoch: 65.26919341087341 secs

Epoch 90 Batch 0 Loss 0.0409


 90%|█████████ | 90/100 [1:38:23<10:56, 65.64s/it]

Saving checkpoint for epoch 90 at checkpoints/ckpt-18
Epoch 90 Loss: 0.0515
Epoch 90 Accuracy: 0.5673
Time taken for 1 epoch: 66.27811336517334 secs

Epoch 91 Batch 0 Loss 0.0393


 91%|█████████ | 91/100 [1:39:28<09:49, 65.52s/it]

Epoch 91 Loss: 0.0500
Epoch 91 Accuracy: 0.5676
Time taken for 1 epoch: 65.23678183555603 secs

Epoch 92 Batch 0 Loss 0.0461


 92%|█████████▏| 92/100 [1:40:33<08:43, 65.47s/it]

Epoch 92 Loss: 0.0490
Epoch 92 Accuracy: 0.5676
Time taken for 1 epoch: 65.37428164482117 secs

Epoch 93 Batch 0 Loss 0.0439


 93%|█████████▎| 93/100 [1:41:39<07:37, 65.39s/it]

Epoch 93 Loss: 0.0478
Epoch 93 Accuracy: 0.5678
Time taken for 1 epoch: 65.19862604141235 secs

Epoch 94 Batch 0 Loss 0.0636


 94%|█████████▍| 94/100 [1:42:44<06:31, 65.33s/it]

Epoch 94 Loss: 0.0464
Epoch 94 Accuracy: 0.5681
Time taken for 1 epoch: 65.17573738098145 secs

Epoch 95 Batch 0 Loss 0.0400


 95%|█████████▌| 95/100 [1:43:50<05:28, 65.63s/it]

Saving checkpoint for epoch 95 at checkpoints/ckpt-19
Epoch 95 Loss: 0.0472
Epoch 95 Accuracy: 0.5680
Time taken for 1 epoch: 66.33777570724487 secs

Epoch 96 Batch 0 Loss 0.0397


 96%|█████████▌| 96/100 [1:44:55<04:22, 65.51s/it]

Epoch 96 Loss: 0.0448
Epoch 96 Accuracy: 0.5683
Time taken for 1 epoch: 65.22428369522095 secs

Epoch 97 Batch 0 Loss 0.0264


 97%|█████████▋| 97/100 [1:46:01<03:16, 65.43s/it]

Epoch 97 Loss: 0.0446
Epoch 97 Accuracy: 0.5684
Time taken for 1 epoch: 65.23095035552979 secs

Epoch 98 Batch 0 Loss 0.0432


 98%|█████████▊| 98/100 [1:47:06<02:10, 65.38s/it]

Epoch 98 Loss: 0.0434
Epoch 98 Accuracy: 0.5686
Time taken for 1 epoch: 65.26409149169922 secs

Epoch 99 Batch 0 Loss 0.0336


 99%|█████████▉| 99/100 [1:48:11<01:05, 65.33s/it]

Epoch 99 Loss: 0.0427
Epoch 99 Accuracy: 0.5687
Time taken for 1 epoch: 65.22480154037476 secs

Epoch 100 Batch 0 Loss 0.0391


100%|██████████| 100/100 [1:49:17<00:00, 65.58s/it]

Saving checkpoint for epoch 100 at checkpoints/ckpt-20
Epoch 100 Loss: 0.0418
Epoch 100 Accuracy: 0.5688
Time taken for 1 epoch: 66.2869815826416 secs






In [44]:
from tqdm import tqdm 
for epoch in tqdm(range(EPOCHS)):
    start = time.time()

    train_loss.reset_states()
  
    for (batch, (inp, tar)) in enumerate(dataset):
        train_step(inp, tar)
    
        # 55k samples
        # we display 3 batch results -- 0th, middle and last one (approx)
        # 55k / 64 ~ 858; 858 / 2 = 429
        if batch % 429 == 0:
            print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, train_loss.result()))
      
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))
    
    print ('Epoch {} Loss {:.4f}'.format(epoch + 1, train_loss.result()))

    print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 Batch 0 Loss 0.0336


  1%|          | 1/100 [00:58<1:36:22, 58.41s/it]

Epoch 1 Loss 0.0412
Time taken for 1 epoch: 58.40574526786804 secs

Epoch 2 Batch 0 Loss 0.0288


  2%|▏         | 2/100 [01:56<1:35:25, 58.42s/it]

Epoch 2 Loss 0.0400
Time taken for 1 epoch: 58.431933641433716 secs

Epoch 3 Batch 0 Loss 0.0250


  3%|▎         | 3/100 [02:55<1:34:24, 58.40s/it]

Epoch 3 Loss 0.0391
Time taken for 1 epoch: 58.373368978500366 secs

Epoch 4 Batch 0 Loss 0.0326


  4%|▍         | 4/100 [03:53<1:33:25, 58.39s/it]

Epoch 4 Loss 0.0383
Time taken for 1 epoch: 58.36535120010376 secs

Epoch 5 Batch 0 Loss 0.0297


  5%|▌         | 5/100 [04:52<1:32:56, 58.70s/it]

Saving checkpoint for epoch 5 at checkpoints/ckpt-21
Epoch 5 Loss 0.0382
Time taken for 1 epoch: 59.25064253807068 secs

Epoch 6 Batch 0 Loss 0.0302


  6%|▌         | 6/100 [05:51<1:31:48, 58.60s/it]

Epoch 6 Loss 0.0366
Time taken for 1 epoch: 58.40699887275696 secs

Epoch 7 Batch 0 Loss 0.0254


  7%|▋         | 7/100 [06:49<1:30:44, 58.54s/it]

Epoch 7 Loss 0.0353
Time taken for 1 epoch: 58.41326093673706 secs

Epoch 8 Batch 0 Loss 0.0350


  8%|▊         | 8/100 [07:48<1:29:42, 58.51s/it]

Epoch 8 Loss 0.0351
Time taken for 1 epoch: 58.44146251678467 secs

Epoch 9 Batch 0 Loss 0.0320


  9%|▉         | 9/100 [08:46<1:28:41, 58.48s/it]

Epoch 9 Loss 0.0348
Time taken for 1 epoch: 58.42677092552185 secs

Epoch 10 Batch 0 Loss 0.0238


  9%|▉         | 9/100 [09:04<1:31:42, 60.47s/it]


KeyboardInterrupt: 

In [46]:
!pip install Rouge
import Rouge
# Create an instance of the Rouge object
rouge = Rouge()
system_summaries = model_summaries["summary"].tolist()
reference_summaries = xl_test["summary"].tolist()

# Calculate RougeL scores for the list of summaries
scores = rouge.get_scores(system_summaries, reference_summaries)

# Print the RougeL scores for each summary pair
rouge_l_score = []
for i, score in enumerate(scores):
    rouge_l_score.append(score['rouge-l']['f'])
    print("RougeL Score for Summary", i + 1, ":", score['rouge-l']['f'])

print(sum(rouge_l_score)/len(rouge_l_score))

Collecting Rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: Rouge
Successfully installed Rouge-1.0.1
[0m

ModuleNotFoundError: No module named 'Rouge'

In [None]:
# Create a figure with two subplots (one for training loss, one for validation/test loss)
fig = go.Figure()

# Add training loss subplot
fig.add_trace(go.Scatter(x=list(range(len(train_losses))), y=train_losses, mode='lines', name='Train Loss'))

# Add validation/test loss subplot
fig.add_trace(go.Scatter(x=list(range(len(val_losses))), y=val_losses, mode='lines', name='Validation/Test Loss'))

# Add labels and title
fig.update_layout(xaxis_title='Epoch', yaxis_title='Loss', title='Training and Validation/Test Loss')

# Show the plot
fig.show()

### Inference

#### Predicting one word at a time at the decoder and appending it to the output; then taking the complete sequence as an input to the decoder and repeating until maxlen or stop keyword appears

In [90]:
def evaluate(input_document):
    input_document = document_tokenizer.texts_to_sequences([input_document])
    input_document = tf.keras.preprocessing.sequence.pad_sequences(input_document, maxlen=encoder_maxlen, padding='post', truncating='post')

    encoder_input = tf.expand_dims(input_document[0], 0)

    decoder_input = [summary_tokenizer.word_index["<go>"]]
    output = tf.expand_dims(decoder_input, 0)
    
    for i in range(decoder_maxlen):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(encoder_input, output)

        predictions, attention_weights = transformer(
            encoder_input, 
            output,
            False,
            enc_padding_mask,
            combined_mask,
            dec_padding_mask
        )

        predictions = predictions[: ,-1:, :]
        predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)

        if predicted_id == summary_tokenizer.word_index["<stop>"]:
            return tf.squeeze(output, axis=0), attention_weights

        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights

In [91]:
def summarize(input_document):
    # not considering attention weights for now, can be used to plot attention heatmaps in the future
    summarized = evaluate(input_document=input_document)[0].numpy()
    summarized = np.expand_dims(summarized[1:], 0)  # not printing <go> token
    return summary_tokenizer.sequences_to_texts(summarized)[0]  # since there is just one translated document

In [92]:
summarize(
    "US-based private equity firm General Atlantic is in talks to invest about \
    $850 million to $950 million in Reliance Industries' digital unit Jio \
    Platforms, the Bloomberg reported. Saudi Arabia's $320 billion sovereign \
    wealth fund is reportedly also exploring a potential investment in the \
    Mukesh Ambani-led company. The 'Public Investment Fund' is looking to \
    acquire a minority stake in Jio Platforms."
)

'وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف وأضاف'