In [1]:
import spacy
import pandas as pd
from spacy.lang.en.stop_words import STOP_WORDS
from string import punctuation
from heapq import nlargest
from datasets import load_dataset

In [2]:
# Load the CNN/DailyMail summarisation corpus (config "1.0.0"); the bare
# `ds` below displays the available splits and their row counts.
ds = load_dataset(path="cnn_dailymail", name="1.0.0")
ds

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

In [3]:
# Stop words as a set: the summariser tests membership once per token, and a
# set answers each test in O(1) instead of scanning a list.
stopwords = set(STOP_WORDS)
# Only tokens and sentence boundaries are used downstream, so the named-entity
# recognizer is skipped to speed up processing of the large dataset splits.
nlp = spacy.load('en_core_web_sm', disable=['ner'])

In [5]:
def select_main_sentence(text, punctuation, nlp, stop_words=None, n_sentences=3):
    """Pick the highest-scoring sentences of ``text`` (frequency-based extractive summary).

    Every token that is neither a stop word nor pure punctuation/whitespace is
    counted case-insensitively. Each sentence is scored as the sum of the counts
    of its tokens, and the ``n_sentences`` best sentences are returned.

    Parameters
    ----------
    text : str
        Document to summarise.
    punctuation : str
        Characters treated as punctuation. A token made up solely of these
        characters and/or whitespace is ignored.
    nlp : callable
        Pipeline (e.g. a loaded spaCy model) mapping ``text`` to a doc that
        iterates over tokens (each with ``.text``) and exposes ``.sents``.
    stop_words : iterable of str, optional
        Lower-case words to ignore. Defaults to the module-level ``stopwords``.
    n_sentences : int, default 3
        Maximum number of sentences returned.

    Returns
    -------
    list
        Up to ``n_sentences`` sentence objects taken from ``doc.sents``, ordered
        by score, highest first (ties keep document order). This is NOT document
        order. Sentences containing no counted word are never returned.
    """
    if stop_words is None:
        stop_words = stopwords
    # A set makes the per-token membership test O(1).
    stop_set = set(stop_words)
    doc = nlp(text)

    def is_ignored(word):
        # A substring test against the punctuation *string* lets multi-character
        # tokens like "..." or "\n\n" through. Instead, skip any token consisting
        # only of punctuation and/or whitespace characters.
        return word in stop_set or all(ch in punctuation or ch.isspace() for ch in word)

    # Key by the lower-cased text so that the lower-cased lookups below also
    # find capitalised words (names, sentence-initial words).
    word_frequencies = {}
    for token in doc:
        word = token.text.lower()
        if not is_ignored(word):
            word_frequencies[word] = word_frequencies.get(word, 0) + 1

    # Only sentences with at least one counted word get an entry.
    sentence_scores = {}
    for sent in doc.sents:
        for token in sent:
            freq = word_frequencies.get(token.text.lower())
            if freq is not None:
                sentence_scores[sent] = sentence_scores.get(sent, 0) + freq

    return nlargest(n_sentences, sentence_scores, key=sentence_scores.get)
    
    


In [6]:
def run_file(ds_type, limit=None):
    """Summarise every article of a dataset split with ``select_main_sentence``.

    Parameters
    ----------
    ds_type : sized iterable of dict-like rows
        A dataset split such as ``ds['test']``. Each row must expose an
        ``'article'`` field. Anything supporting ``len()`` and iteration over
        rows works (a Hugging Face ``Dataset`` or a list of dicts).
    limit : int, optional
        Process only the first ``limit`` rows (useful for quick debugging runs).
        ``None`` (default) processes the whole split.

    Returns
    -------
    list
        One entry per processed row: the sentence list returned by
        ``select_main_sentence`` for that article.
        NOTE(review): spaCy sentence spans presumably keep their parent Doc
        alive, so memory grows with the split size. Converting to text may be
        needed for the train split (TODO confirm).
    """
    total = len(ds_type) if limit is None else min(limit, len(ds_type))
    summary_list = []
    for done, row in enumerate(ds_type, start=1):
        if done > total:
            break
        summary_list.append(select_main_sentence(row['article'], punctuation, nlp))
        # 1-based progress so the final line reads "N/N" rather than "N-1/N".
        print('\r {}/{}'.format(done, total), end='')
    print()  # terminate the in-place progress line
    return summary_list


In [7]:
# Summarise the test split and persist it: one CSV row per article, one column
# per selected sentence (highest-scoring first).
test_summaries = pd.DataFrame(run_file(ds['test']))
assert len(test_summaries) == len(ds['test']), "expected one summary row per article"
test_summaries.to_csv("summary_test.csv")
del test_summaries  # release memory before processing the next split

 11489/11490

In [20]:
# Summarise the validation split and persist it: one CSV row per article, one
# column per selected sentence (highest-scoring first).
validation_summaries = pd.DataFrame(run_file(ds['validation']))
assert len(validation_summaries) == len(ds['validation']), "expected one summary row per article"
validation_summaries.to_csv("summary_validation.csv")
del validation_summaries  # release memory before processing the next split

 13367/13368

In [ ]:
#f3 = run_file(ds['train'])