In [31]:
import transformers
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertPreTrainedModel
from torch.optim import AdamW
from torch.utils.data import Dataset
import datasets
import pandas as pd
import numpy as np
import nltk
import sentencepiece
import re
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import contractions

In [32]:
dataset = pd.read_csv("dataset/cnn_dailymail/train.csv", nrows=1000)
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /Users/merrick/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [33]:
## keep the first N articles if you want to keep it lite 
dtf = pd.DataFrame(dataset).rename(columns={"article":"text", 
      "highlights":"y"})[["text","y"]]
dtf.head()

Unnamed: 0,text,y
0,By . Associated Press . PUBLISHED: . 14:11 EST...,"Bishop John Folda, of North Dakota, is taking ..."
1,(CNN) -- Ralph Mata was an internal affairs li...,Criminal complaint: Cop used his role to help ...
2,A drunk driver who killed a young woman in a h...,"Craig Eccleston-Todd, 27, had drunk at least t..."
3,(CNN) -- With a breezy sweep of his pen Presid...,Nina dos Santos says Europe must be ready to a...
4,Fleetwood are the only team still to have a 10...,Fleetwood top of League One after 2-0 win at S...


In [34]:
def utils_preprocess_text(txt, punkt=True, lower=True, slang=True, lst_stopwords=None, stemm=False, lemm=True):
    ### separate sentences with '. '
    txt = re.sub(r'\.(?=[^ \W\d])', '. ', str(txt))
    ### remove punctuations and characters
    txt = re.sub(r'[^\w\s]', '', txt) if punkt is True else txt
    ### strip
    txt = " ".join([word.strip() for word in txt.split()])
    ### lowercase
    txt = txt.lower() if lower is True else txt
    ### slang
    txt = contractions.fix(txt) if slang is True else txt   
    ### tokenize (convert from string to list)
    lst_txt = txt.split()
    ### stemming (remove -ing, -ly, ...)
    if stemm is True:
        ps = nltk.stem.porter.PorterStemmer()
        lst_txt = [ps.stem(word) for word in lst_txt]
    ### lemmatization (convert the word into root word)
    if lemm is True:
        lem = nltk.stem.wordnet.WordNetLemmatizer()
        lst_txt = [lem.lemmatize(word) for word in lst_txt]
    ### remove Stopwords
    if lst_stopwords is not None:
        lst_txt = [word for word in lst_txt if word not in 
                   lst_stopwords]
    ### back to string
    txt = " ".join(lst_txt)
    return txt

In [35]:
# preprocess text
dtf["text"] = dtf["text"].apply(lambda x: utils_preprocess_text(x))
dtf["y"] = dtf["y"].apply(lambda x: utils_preprocess_text(x))

In [36]:
# Instantiate the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the text in the DataFrame
dtf['text'] = dtf['text'].apply(lambda x: tokenizer.encode(x, return_tensors='pt'))

# Tokenize the target summaries if applicable
dtf['y'] = dtf['y'].apply(lambda x: tokenizer.encode(x, return_tensors='pt'))


Token indices sequence length is longer than the specified maximum sequence length for this model (918 > 512). Running this sequence through the model will result in indexing errors


In [20]:
print(dtf["text"][0])

tensor([[  101,  2011,  3378,  2811,  2405, 15471,  2487,  9765,  2423,  2255,
          2286,  7172, 16710,  2575,  9765,  2423,  2255,  2286,  1996,  3387,
          1997,  1996, 23054,  3234,  5801,  1999,  2167,  7734,  5292,  6086,
          9280,  3634,  1997,  2277,  2266,  1999, 23054,  2882,  9292,  1998,
         27435,  2000,  1996, 28389,  1037,  7865,  1999,  2397,  2244,  1998,
          2220,  2255,  1996,  2110,  2740,  2533,  5292,  3843,  2019,  7319,
          1997,  7524,  2005,  3087,  2040,  3230,  2274,  2277,  1998,  2165,
         15661,  3387,  2198, 10671,  2050, 15885,  1997,  1996, 23054,  3234,
          5801,  1999,  2167,  7734,  5292,  6086,  9280,  3634,  1997,  2277,
          2266,  1999, 23054,  2882,  9292,  1998, 27435,  2000,  1996, 28389,
          1037,  2110, 10047, 23041,  3989,  2565,  3208,  9618, 18473,  2360,
          1996,  3891,  2003,  2659,  2021,  2880,  2514,  2009,  2590,  2000,
          9499,  2111,  2000,  1996,  2825,  7524,  

In [40]:
class CustomDataset(Dataset):
    def __init__(self, dtf):
        self.dtf = dtf

    def __len__(self):
        return len(self.dtf)

    def __getitem__(self, idx):
        return self.dtf[idx]

# Instantiate your custom dataset with the tokenized data
dataset = CustomDataset(dtf)

# Create a DataLoader with appropriate batch size, shuffle, and other options
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

In [41]:
print(dataloader)

<torch.utils.data.dataloader.DataLoader object at 0x2af114f90>


In [43]:
from transformers import AutoTokenizer

checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


Downloading (…)/main/tokenizer.json: 0.00B [00:00, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [44]:
prefix = "summarize: "


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs