In [1]:
import numpy as np
import os
import pandas as pd
import regex as re
import datetime
import matplotlib.pyplot as plt
from tqdm import tqdm
from difflib import SequenceMatcher
from collections import Counter
from functools import partial

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import TensorBoard

# import pysbd
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import words
from nltk.corpus import stopwords
# from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from datasets import load_metric
from transformers import AutoTokenizer
from transformers import AdamWeightDecay
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers.keras_callbacks import KerasMetricCallback
import evaluate
import bert_score
from sklearn.model_selection import train_test_split


In [2]:
# Read the raw news-summary CSV (decoded as ISO-8859-1), discard the
# author/date/read_more columns and expose the `text` column as `summary`.
df = (
    pd.read_csv('news_summary.csv', encoding='ISO-8859-1')
    .drop(['author', 'date', 'read_more'], axis=1)
    .rename(columns={'text': 'summary'})
)

In [3]:
# Lower-case the three text columns, then discard every row that has a
# missing value in any column.
for _col in ('ctext', 'summary', 'headlines'):
    df[_col] = df[_col].str.lower()
df = df.dropna(axis=0, how='any')

In [4]:
# Length, in characters, of the longest summary.
df['summary'].map(len).max()

400

In [7]:
import string
def remove_punctuation(text):
    """Return `text` with every character of `string.punctuation` deleted."""
    # Mapping a character to None in a translation table deletes it.
    deletion_table = str.maketrans(dict.fromkeys(string.punctuation))
    return text.translate(deletion_table)

# Strip punctuation from the article bodies only; `summary` and `headlines`
# are not touched here.
df['ctext'] = df['ctext'].map(remove_punctuation)

In [8]:
# Persist the cleaned frame without the index column. UTF-8 is what pandas
# writes by default; it is spelled out here so the on-disk encoding is explicit.
df.to_csv('refined.csv', index=False, encoding='utf-8')

In [9]:
# `refined.csv` is produced by `df.to_csv`, which writes UTF-8 unless told
# otherwise, so it has to be decoded as UTF-8 here. Decoding it as ISO-8859-1
# would silently garble every non-ASCII character in the text columns.
dataframe = load_dataset('csv', data_files='refined.csv', encoding='utf-8')
dataframe

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['headlines', 'summary', 'ctext'],
        num_rows: 4396
    })
})

In [10]:
# Hold out 10% of the rows for validation. A fixed seed makes the split
# repeatable across kernel restarts, and the two splits are looked up by name
# instead of relying on the ordering of `.values()`.
_splits = dataframe['train'].train_test_split(test_size=0.1, seed=42)
train_set, val_set = _splits['train'], _splits['test']

In [11]:
# Pretrained checkpoint to fine-tune, and the ROUGE metric used for evaluation.
model_checkpoint = 'facebook/bart-large-cnn'
metric = load_metric('rouge')

# `metric_fn` refers to the tokenizer as `tokenizer`, while the tokenization
# and collator cells use the misspelled name `tokeizer`. Bind both names to the
# same object so every cell resolves on a fresh kernel.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokeizer = tokenizer

# Truncation limits passed as `max_length` when tokenizing the article
# (model input) and the summary (labels).
max_length = 1024
max_target = 300
def apply_tokenization(dataset, sentence, target, maxlen, maxtarget, tokenizer):
    """Tokenize one batch of source texts together with their target texts.

    Args:
        dataset: Mapping from column name to that column's values (for
            example the batch handed over by ``Dataset.map(batched=True)``).
        sentence: Name of the column holding the model inputs.
        target: Name of the column holding the reference summaries.
        maxlen: ``max_length`` used when tokenizing the inputs (truncating).
        maxtarget: ``max_length`` used when tokenizing the targets (truncating).
        tokenizer: Tokenizer applied to both the inputs and the targets.

    Returns:
        The tokenizer output for the inputs, with the target token ids stored
        under the extra key ``"labels"``.
    """
    inputs = tokenizer(dataset[sentence], max_length=maxlen, truncation=True)

    # Use the tokenizer and the target limit that were passed in rather than
    # notebook-level globals, so the function honours its own arguments.
    # NOTE(review): `as_target_tokenizer` is deprecated in newer transformers
    # releases in favour of `tokenizer(text_target=...)` - confirm the
    # installed version before switching.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(dataset[target], max_length=maxtarget, truncation=True)

    inputs['labels'] = labels["input_ids"]
    return inputs


In [15]:
def _tokenize_batch(batch):
    """Tokenize one batch: `ctext` is the model input, `summary` the target."""
    return apply_tokenization(
        batch, 'ctext', 'summary', max_length, max_target, tokeizer
    )


# Tokenize both splits the same way; the original columns of each split are
# removed so only the tokenizer output remains.
train_set_tokenized = train_set.map(
    _tokenize_batch,
    batched=True,
    remove_columns=train_set.column_names,
)
val_set_tokenized = val_set.map(
    _tokenize_batch,
    batched=True,
    remove_columns=val_set.column_names,
)

Map:   0%|          | 0/3956 [00:00<?, ? examples/s]



Map:   0%|          | 0/440 [00:00<?, ? examples/s]

In [41]:
# Load the pretrained seq2seq checkpoint as a TensorFlow model.
# NOTE(review): the load log printed below this cell lists `model.shared.weight`
# as newly initialized - confirm the checkpoint weights were actually restored.
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Collator handed to `to_tf_dataset` as `collate_fn`; it is configured to
# return TensorFlow tensors.
data_collator = DataCollatorForSeq2Seq(
    tokeizer,
    model=model,
    return_tensors='tf',
)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBartForConditionalGeneration: ['model.decoder.embed_tokens.weight']
- This IS expected if you are initializing TFBartForConditionalGeneration from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBartForConditionalGeneration from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFBartForConditionalGeneration were not initialized from the PyTorch model and are newly initialized: ['model.shared.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [52]:
# Optimizer / training hyperparameters.
weight_decay = 0.001
epochs = 5
learning_rate = 2e-5

# Data-pipeline settings.
BATCH_SIZE = 4
SEED = 42
GENERATION_SAMPLES = 200  # upper bound on examples used for generation metrics


def _to_tf(dataset, shuffle=False, drop_remainder=False):
    """Convert a tokenized `Dataset` into a batched `tf.data.Dataset`.

    Args:
        dataset: Tokenized dataset exposing `input_ids`, `attention_mask`
            and `labels` columns.
        shuffle: Forwarded to `to_tf_dataset`.
        drop_remainder: Forwarded to `to_tf_dataset`.

    Returns:
        Whatever `dataset.to_tf_dataset` returns for these settings.
    """
    return dataset.to_tf_dataset(
        batch_size=BATCH_SIZE,
        columns=['input_ids', 'attention_mask', 'labels'],
        shuffle=shuffle,
        drop_remainder=drop_remainder,
        collate_fn=data_collator,
    )


# Shuffle the training examples; evaluation data keeps a fixed order.
train_dataset = _to_tf(train_set_tokenized, shuffle=True, drop_remainder=True)
valid_dataset = _to_tf(val_set_tokenized)

# Seeded subset of the validation split for generation-based metrics. The size
# is capped at the split's length so a small validation set does not raise.
_n_generation = min(GENERATION_SAMPLES, len(val_set_tokenized))
generation_dataset = _to_tf(
    val_set_tokenized.shuffle(seed=SEED).select(range(_n_generation))
)

In [55]:
# Optimizer built from the hyperparameters defined above.
optimizer = AdamWeightDecay(
    learning_rate=learning_rate,
    weight_decay_rate=weight_decay,
)

# No loss is passed to `compile` - presumably the model's built-in loss is
# used; confirm against the transformers documentation.
model.compile(optimizer=optimizer)

In [57]:
def metric_fn(eval_predictions):
    """Compute ROUGE F-measures and the mean generated length.

    Args:
        eval_predictions: A ``(predictions, labels)`` pair of token-id arrays.
            Negative label ids are treated as masked positions.

    Returns:
        dict mapping every key returned by the ROUGE metric to its mid
        F-measure scaled by 100, plus ``"gen_len"``: the mean number of
        non-padding tokens per prediction.
    """
    predictions, labels = eval_predictions
    # The notebook binds its tokenizer to the (misspelled) global `tokeizer`;
    # no global called `tokenizer` is defined by the earlier cells.
    pad_id = tokeizer.pad_token_id

    decoded_predictions = tokeizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace masked label tokens with the pad id so they can be decoded. New
    # arrays are built so the caller's label arrays are left unmodified.
    labels = [np.where(np.asarray(row) < 0, pad_id, row) for row in labels]
    decoded_labels = tokeizer.batch_decode(labels, skip_special_tokens=True)

    # Rouge expects a newline after each sentence.
    decoded_predictions = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_predictions
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]

    result = metric.compute(
        predictions=decoded_predictions, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of each score, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    # Add mean generated length (padding tokens are not counted).
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return result
# Callback that evaluates `metric_fn` on `generation_dataset`, with
# `predict_with_generate` switched on.
metric_callback = KerasMetricCallback(
    metric_fn=metric_fn,
    eval_dataset=generation_dataset,
    predict_with_generate=True,
)



In [58]:
# Fine-tune the model. `callbacks` is passed as a list, which is the form the
# Keras `fit` API documents.
history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    callbacks=[metric_callback],
)

Epoch 1/5
  2/989 [..............................] - ETA: 32:27:50 - loss: 10.8363