<a href="https://colab.research.google.com/github/Arziva/Summarise_transformer/blob/main/summarise_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install trax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting trax
  Downloading trax-1.4.1-py2.py3-none-any.whl (637 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m637.9/637.9 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
Collecting funcsigs (from trax)
  Downloading funcsigs-1.0.2-py2.py3-none-any.whl (17 kB)
Collecting tensorflow-text (from trax)
  Downloading tensorflow_text-2.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m77.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: funcsigs, tensorflow-text, trax
Successfully installed funcsigs-1.0.2 tensorflow-text-2.12.1 trax-1.4.1


In [4]:
import sys
import os
import numpy as np

import textwrap
# Shared wrapper that detokenize() uses to wrap decoded text at 70 columns.
wrapper = textwrap.TextWrapper(width=70)

import trax
from trax import layers as tl
from trax.fastmath import numpy as jnp

# Print NumPy arrays in full: with the threshold raised to sys.maxsize the
# abbreviated ("...") form is effectively never used.
np.set_printoptions(threshold=sys.maxsize)

In [5]:
# Stream factories for the CNN/DailyMail articles dataset.
# The data is looked up under 'data/'; when it is not there yet it is
# downloaded and prepared first, which can take a while on the first run.
train_stream_fn = trax.data.TFDS(
    'cnn_dailymail',
    data_dir='data/',
    keys=('article', 'highlights'),
    train=True,
)

# Same dataset with train=False. The data prepared above is reused,
# so this call should be much faster.
eval_stream_fn = trax.data.TFDS(
    'cnn_dailymail',
    data_dir='data/',
    keys=('article', 'highlights'),
    train=False,
)



Downloading and preparing dataset 558.32 MiB (download: 558.32 MiB, generated: 1.29 GiB, total: 1.84 GiB) to data/cnn_dailymail/3.4.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/287113 [00:00<?, ? examples/s]

Shuffling data/cnn_dailymail/3.4.0.incompleteSM69QW/cnn_dailymail-train.tfrecord*...:   0%|          | 0/28711…

Generating validation examples...:   0%|          | 0/13368 [00:00<?, ? examples/s]

Shuffling data/cnn_dailymail/3.4.0.incompleteSM69QW/cnn_dailymail-validation.tfrecord*...:   0%|          | 0/…

Generating test examples...:   0%|          | 0/11490 [00:00<?, ? examples/s]

Shuffling data/cnn_dailymail/3.4.0.incompleteSM69QW/cnn_dailymail-test.tfrecord*...:   0%|          | 0/11490 …

Dataset cnn_dailymail downloaded and prepared to data/cnn_dailymail/3.4.0. Subsequent calls will reuse this data.


In [6]:
def tokenize(input_str, EOS=1):
    """Encode one string as a list of token ids terminated by EOS.

    Args:
        input_str: the text to encode.
        EOS: id appended after the encoded tokens to mark the end.

    Returns:
        A new list holding the token ids of `input_str` followed by `EOS`.
    """
    # trax.data.tokenize maps a stream to a stream, so wrap the string in a
    # one-item iterator and pull the single result back out with next().
    single_item_stream = iter([input_str])
    token_stream = trax.data.tokenize(
        single_item_stream,
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    )

    token_ids = list(next(token_stream))
    token_ids.append(EOS)  # mark the end of the sentence
    return token_ids

def detokenize(integers):
    """Decode token ids to text and wrap it for display.

    Args:
        integers: the token ids to decode.

    Returns:
        The decoded string, line-wrapped by the module-level `wrapper`.
    """
    decoded_text = trax.data.detokenize(
        integers,
        vocab_dir='vocab_dir/',
        vocab_file='summarize32k.subword.subwords',
    )

    wrapped_text = wrapper.fill(decoded_text)
    return wrapped_text

In [None]:
# Special tokens
SEP = 0 # Padding or separator token
EOS = 1 # End of sentence token

# Concatenate tokenized inputs and targets using 0 as separator.
def preprocess(stream):
    """Turn (article, summary) token pairs into language-model examples.

    Each pair is laid out as ``article + [EOS, SEP] + summary + [EOS]``.

    Args:
        stream: iterable of (article, summary) pairs, each element being an
            iterable of token ids.

    Yields:
        A tuple (inputs, targets, mask). `inputs` and `targets` are the same
        joined array. `mask` is an array of the same length that is 0 over
        the article plus its EOS and SEP markers and 1 over the summary plus
        its closing EOS.
    """
    for (article, summary) in stream:
        # Materialise each side exactly once: the ids are needed both for the
        # joined sequence and for the mask length, and a one-shot iterable
        # would be empty on a second pass, leaving the mask too short.
        article_ids = list(article)
        summary_ids = list(summary)

        joint = np.array(article_ids + [EOS, SEP] + summary_ids + [EOS])

        # +2 covers the EOS and SEP after the article, +1 the final EOS.
        mask = np.array([0] * (len(article_ids) + 2) + [1] * (len(summary_ids) + 1))

        yield joint, joint, mask

# You can combine a few data preprocessing steps into a pipeline like this.
input_pipeline = trax.data.Serial(
    # Tokenizes
    trax.data.Tokenize(vocab_dir='vocab_dir/',
                       vocab_file='summarize32k.subword.subwords'),
    # Uses function defined above
    preprocess,
    # Filters out examples longer than 2048
    trax.data.FilterByLength(2048)
)

# Apply preprocessing to data streams.
train_stream = input_pipeline(train_stream_fn())
eval_stream = input_pipeline(eval_stream_fn())

# Pull one example to sanity-check what the pipeline produces.
train_input, train_target, train_mask = next(train_stream)

# They are the same in Language Model (LM). np.array_equal also compares
# shapes, so a length mismatch fails this assertion instead of raising a
# broadcasting error.
assert np.array_equal(train_input, train_target), \
    'inputs and targets should be identical for a language model'
# preprocess builds the mask with one entry per token of the joined sequence.
assert train_mask.shape == train_input.shape, \
    'mask should have one entry per input token'