In [9]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project .py files are
# picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
## In the hyper-parameter cell, QUICK_RUN = False switches MAX_STEPS / WARMUP_STEPS to 5e4 / 5e3.
QUICK_RUN = True
## Set USE_PREPROCSSED_DATA = True to skip the data preprocessing
## (the cached train_full.pt / test_full.pt files are then loaded from DATA_PATH).
## NOTE: the name is misspelled ("PREPROCSSED") but other cells reference it as-is, so it is kept.
USE_PREPROCSSED_DATA = True

### Configuration


In [1]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

# from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset
# from utils_nlp.dataset.bundesministerium import BUNDBertSumProcessedData, BUNDSummarizationDataset
from utils_nlp.dataset.swiss import SwissSummarizationDataset

from utils_nlp.eval import compute_rouge_python, compute_rouge_perl
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessedData,
    ExtSumProcessor,
)

from utils_nlp.models.transformers.datasets import SummarizationDataset
import nltk
from nltk import tokenize

import pandas as pd
import scrapbook as sb
import pprint

In [None]:
# TODO:

# - free old PyTorch tensors that are no longer needed

# - create the oracle summary for each document

# - create the lead-1, lead-2 and lead-3 summaries for each document

# - calculate ROUGE scores for the Bundes dataset and the Swiss dataset

In [2]:
# Directory holding the already-processed Bundes dataset splits.
BUNDES_PYTORCH_DATA_PATH = "/home/ubuntu/mnt/data/bundes_dataset/bundes_processed/"

# Resolve both split files first, then deserialize them.
# NOTE: torch.load unpickles the file contents; only load files you trust.
bundes_train_file = os.path.join(BUNDES_PYTORCH_DATA_PATH, "train_full.pt")
bundes_test_file = os.path.join(BUNDES_PYTORCH_DATA_PATH, "test_full.pt")
torch_bundes_train = torch.load(bundes_train_file)
torch_bundes_test = torch.load(bundes_test_file)
    

In [16]:
# Wrap the loaded training samples in a DataFrame for inspection and filtering.
df = pd.DataFrame(torch_bundes_train)

In [25]:
# List the columns of the training frame and their dtypes.
df.dtypes

src           object
src_txt       object
tgt           object
tgt_txt       object
oracle_ids    object
dtype: object

In [34]:
# Derive flat 'summary' / 'source' columns by taking element 0 of each
# entry in the raw 'tgt_txt' / 'src_txt' columns.
for flat_col, raw_col in (("summary", "tgt_txt"), ("source", "src_txt")):
    df[flat_col] = df[raw_col].apply(lambda entry: entry[0])

In [36]:
# Make the helper module from the adaptive-extractive-summarization checkout
# importable; the import has to come after the sys.path tweak.
ada_ext = os.path.abspath("/home/ubuntu/adaptive-extractive-summarization/notebooks/")
if ada_ext not in sys.path:
    sys.path.insert(0, ada_ext)

import processing_utils

# Add a "<col>_len" and "<col>_word_count" column for the source and the
# summary text, in that order, using the processing_utils helpers.
for text_col in ("source", "summary"):
    df[text_col + "_len"] = processing_utils.text_length(df[text_col])
    df[text_col + "_word_count"] = processing_utils.word_count(df[text_col])


In [40]:
# Preview the source texts.
df['source']

0      WARC/1.0 WARC-Type: response WARC-Date: 2019-1...
1      WARC/1.0 WARC-Type: response WARC-Date: 2019-0...
2      WARC/1.0 WARC-Type: response WARC-Date: 2020-0...
3      WARC/1.0 WARC-Type: response WARC-Date: 2019-1...
4      WARC/1.0 WARC-Type: response WARC-Date: 2019-1...
                             ...                        
256    WARC/1.0 WARC-Type: response WARC-Date: 2020-0...
257    WARC/1.0 WARC-Type: response WARC-Date: 2018-1...
258    WARC/1.0 WARC-Type: response WARC-Date: 2020-0...
259    Coronapandemie: Artikel der Bundeskanzlerin un...
260    WARC/1.0 WARC-Type: response WARC-Date: 2019-0...
Name: source, Length: 261, dtype: object

In [38]:
# Filter the (source, summary) pairs step by step, printing the frame shape
# after every step. Thresholds are named so the progress messages always
# report the value that is actually applied (the previous messages claimed a
# 200-word limit while filtering at 150, and "end with ..." while matching
# '...' anywhere in the summary).
#
# NOTE: this cell rebinds `df`, so it is not idempotent; re-running it filters
# the already-filtered frame.
summary_ratio = 0.8        # summary must be shorter than this fraction of the source (in characters)
min_source_words = 80      # minimum value of 'source_word_count' to keep a row
max_summary_words = 150    # rows with 'summary_word_count' at or above this are dropped

print("Original df shape: ", df.shape)

df = df.drop_duplicates(subset=['summary'])
print("drop summary duplicates: ", df.shape)

df = df.drop_duplicates(subset=['source'])
print("drop source duplicates: ", df.shape)

df = df[~(df['source'] == df['summary'])]
print("drop summaries that are the same as the source: ", df.shape)

df = df[df['summary'].str.len() < df['source'].str.len() * summary_ratio]
print("drop summary that is large percentage of source : ", df.shape)

df = df[df['source_word_count'] >= min_source_words]
print("keep only sources {} words and over: ".format(min_source_words), df.shape)

# regex=False: match the literal three dots anywhere in the summary.
df = df[~df['summary'].str.contains('...', regex=False)]
print("remove summaries that contain '...': ", df.shape)

df = df[df['summary_word_count'] < max_summary_words]
print("remove summaries that are {} words or longer: ".format(max_summary_words), df.shape)
df = df.reset_index(drop=True)

Original df shape:  (20560, 11)
drop summary duplicates:  (17880, 11)
drop source duplicates:  (17139, 11)
drop summaries that are the same as the source:  (16788, 11)
drop summary that is large percentage of source :  (2768, 11)
keep only sources 80 words and over:  (276, 11)
remove summaries that end with ... (261, 11)
remove summaries that are longer than 200 words ... (261, 11)


In [None]:
# Load the train / (optional validation) / test splits for the dataset
# selected by DATA_NAME.
validation = False
# Bound up front so the name exists even when no validation split is requested.
validation_dataset = None

# Strings must be compared by value: `is` checks object identity, which only
# works by accident of interning (and triggers a SyntaxWarning on Python 3.8+).
if DATA_NAME == "cnndm":
    # Imported here because the module-level import of the CNN/DM dataset is
    # commented out in the import cell.
    from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset

    train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH)
elif DATA_NAME == "swiss":
    if validation:
        train_dataset, validation_dataset, test_dataset = SwissSummarizationDataset(top_n=TOP_N, validation=True, language='german')
    else:
        train_dataset, test_dataset = SwissSummarizationDataset(top_n=TOP_N, validation=False, language='german')
else:
    # Fail loudly instead of leaving train_dataset / test_dataset unbound.
    raise ValueError("Unsupported DATA_NAME: {!r} (expected 'cnndm' or 'swiss')".format(DATA_NAME))
        

In [None]:

# Show the size of each split. A validation split is only produced when the
# loading cell ran with validation=True, so it is reported conditionally
# instead of failing on a missing validation_dataset.
if validation:
    split_sizes = (len(train_dataset), len(validation_dataset), len(test_dataset))
else:
    split_sizes = (len(train_dataset), len(test_dataset))
split_sizes

### Preprocess the data.

In [None]:

# Preprocess both splits with the processor, using the "greedy" oracle mode.
# NOTE(review): in the visible notebook `processor` is only assigned in the
# single-sample prediction section further down, so this cell relies on state
# from elsewhere - confirm where the processor is created on a fresh kernel.
ext_sum_train = processor.preprocess(train_dataset, oracle_mode="greedy")
ext_sum_test = processor.preprocess(test_dataset, oracle_mode="greedy")


### Save the data.

In [None]:
# Flip to True to persist the preprocessed splits to disk.
SAVE_DATA = False

if SAVE_DATA:
    # Files go to "<DATA_PATH>/<DATA_NAME>_processed/".
    save_path = os.path.join(DATA_PATH, DATA_NAME + "_processed")
    os.makedirs(save_path, exist_ok=True)

    for split_data, split_file in (
        (ext_sum_train, "train_full.pt"),
        (ext_sum_test, "test_full.pt"),
    ):
        torch.save(split_data, os.path.join(save_path, split_file))

In [None]:
# Number of preprocessed training samples.
len(ext_sum_train)

#### Inspect Data

In [None]:
# Show the first preprocessed training sample in full.
ext_sum_train[0]

In [None]:
# Field names available on a preprocessed sample.
ext_sum_train[0].keys()

##### [Option 2] Reuse cached preprocessed data

In [None]:
# When USE_PREPROCSSED_DATA is set, replace ext_sum_train / ext_sum_test with
# the cached splits stored directly under DATA_PATH.
# NOTE(review): the save cell writes to "<DATA_PATH>/<DATA_NAME>_processed/",
# whereas this cell reads from DATA_PATH itself - confirm DATA_PATH already
# points at the processed directory.
# NOTE: torch.load unpickles the file contents; only load files you trust.
if USE_PREPROCSSED_DATA:
    save_path = os.path.join(DATA_PATH)
    ext_sum_train = torch.load(os.path.join(save_path, "train_full.pt"))
    ext_sum_test = torch.load(os.path.join(save_path, "test_full.pt"))
    

### Model training
To start model training, we need to create an instance of ExtractiveSummarizer.

Potentially, RoBERTa-based models and XLNet can be supported, but this still needs to be tested.
#### Choose the encoder algorithm.
There are four options:
- baseline: uses a smaller transformer model in place of the BERT model, together with a transformer summarization layer
- classifier: uses pretrained BERT and fine-tunes it with a **simple logistic classification** summarization layer
- transformer: uses pretrained BERT and fine-tunes it with a **transformer** summarization layer
- RNN: uses pretrained BERT and fine-tunes it with an **LSTM** summarization layer

In [None]:
# Batch size, measured in number of samples.
BATCH_SIZE = 5
MAX_POS_LENGTH = 512

# Number of GPUs torch can see; used for training and prediction.
NUM_GPUS = torch.cuda.device_count()

# Encoder name. Options are: baseline, classifier, transformer, rnn.
ENCODER = "transformer"

# Learning rate.
LEARNING_RATE = 2e-3

# Interval, in steps, between training statistics reports.
REPORT_EVERY = 50

# Total number of training steps and number of warm-up steps: a short
# schedule for QUICK_RUN, the long one otherwise.
# NOTE(review): in the quick-run schedule the warm-up (5e2) is longer than the
# whole run (1e2) - confirm this is intended.
MAX_STEPS = 1e2 if QUICK_RUN else 5e4
WARMUP_STEPS = 5e2 if QUICK_RUN else 5e3
 

In [None]:
# Build the extractive summarizer from the processor, model name, encoder
# choice, maximum position length and cache directory.
# NOTE(review): `processor`, `MODEL_NAME` and `CACHE_DIR` are not assigned in
# any visible cell above this one - confirm where they are defined.
summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)

In [None]:
# Show the summarizer's model_name attribute.
summarizer.model_name

In [None]:
# Training options, collected in one place and forwarded to fit() as keywords.
fit_options = dict(
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    verbose=True,
    report_every=REPORT_EVERY,
    clip_grad_norm=False,
    use_preprocessed_data=False,
)

# Fine-tune the summarizer on the preprocessed training split.
summarizer.fit(ext_sum_train, **fit_options)


In [None]:
# File name encodes the model name, the preprocessed-data flag and the step count.
model_save_name = "extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt".format(
    MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS
)

# Store the trained model inside the cache directory.
summarizer.save_model(os.path.join(CACHE_DIR, model_save_name))

### Model Evaluation

[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation, is commonly used for evaluating text summarization.

In [None]:
# for loading a previous saved model

model_filename = "dist_extsum_model.pt"
model_filepath = "/home/ubuntu/mnt/train/distilbert-base-german-cased/2007142250/"
model_path = os.path.join(model_filepath, model_filename)
summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)
summarizer.model.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
# Collect, for every test sample, the source text and the reference summary
# that the predictions are scored against.
#
# Reference sentences are separated by "\n" in both branches, matching the
# sentence separator used for the predictions. (The non-preprocessed branch
# used to glue sentences together with no separator at all.)
# An empty test set yields empty lists instead of an IndexError.
if len(ext_sum_test) > 0 and "segs" in ext_sum_test[0]:  # preprocessed_data
    source = [sample['src_txt'] for sample in ext_sum_test]
    # 'tgt_txt' is a single string whose sentences are delimited by "<q>".
    target = ["\n".join(sample['tgt_txt'].split("<q>")) for sample in ext_sum_test]
else:
    source = [sample["src_txt"] for sample in ext_sum_test]
    # Each item of 'tgt' is joined with spaces into one sentence string
    # (presumably a list of tokens - verify against the dataset).
    target = [
        "\n".join(" ".join(sentence) for sentence in sample['tgt'])
        for sample in ext_sum_test
    ]

In [None]:
%%time
sentence_separator = "\n"
prediction = summarizer.predict(ext_sum_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE, sentence_separator=sentence_separator)

In [None]:
# Score the predictions (candidates) against the reference summaries and
# pretty-print the resulting ROUGE scores.
rouge_scores = compute_rouge_python(cand=prediction, ref=target)
pprint.pprint(rouge_scores)

In [None]:
# First prediction, with the newline sentence separators flattened to spaces.
prediction[0].replace("\n", " ")

In [None]:
# Dump every test sample (source text, reference summary, model prediction)
# to sample_results.txt for manual inspection.
# The encoding is set explicitly: the texts are German and contain non-ASCII
# characters, and open() otherwise falls back to a platform-dependent default.
with open('sample_results.txt', 'w', encoding='utf-8') as f:
    for i, predicted_summary in enumerate(prediction):
        # source[i] is joined with spaces into a single string.
        source_output = " ".join(source[i])
        f.write("Source Text: \n")
        f.write("\"" + source_output + "\" \n")
        f.write("\n")
        f.write("Source target: \n")
        f.write("\"" + target[i] + "\" \n")
        f.write("\n")
        f.write("Model Prediction: \n")
        # Newline sentence separators are flattened to spaces in the report.
        f.write("\"" + predicted_summary.replace("\n", " ") + "\" \n")
        f.write("\n")
        f.write("======================================")
        f.write("\n \n")

In [None]:
# Reference summary of test sample 10 (compare with prediction[10]).
target[10]

In [None]:
# Reference summary of test sample 2.
target[2]

In [None]:
# Model prediction for test sample 10 (compare with target[10]).
prediction[10]

In [None]:
# Display all collected source texts.
# NOTE(review): this prints the whole list; consider slicing (e.g. source[:3]) before sharing the notebook.
source

In [None]:
# for testing
sb.glue("rouge_2_f_score", rouge_scores['rouge-2']['f'])

## Prediction on a single input sample

In [None]:
source = """
Italien erlaubt nach tagelangem Zögern den etwa 180 Migranten auf dem privaten Rettungsschiff "Ocean Viking" den Wechsel auf das italienische Quarantäne-Schiff "Moby Zaza". Die Übernahme der aus Seenot geretteten Menschen sei für Montag geplant, hieß es am Samstagabend aus Quellen im Innenministerium in Rom. Zuvor hatte sich die Lage auf dem Schiff der Organisation SOS Méditerranée, das sich in internationalen Gewässern vor Sizilien befindet, zugespitzt.
Die Betreiber berichteten demnach von einem Hungerstreik unter den Geflüchteten. Verena Papke, Geschäftsführerin von SOS Méditerranée für Deutschland, hatte am Freitag von mehreren Suizidversuchen gesprochen. Die "Ocean Viking" hatte zudem den Notstand an Bord ausgerufen. Bis dahin waren mehrere Bitten um Zuweisung eines sicheren Hafens in Malta und Italien erfolglos geblieben.

Corona-Abstriche bei den Migranten geplant
Die Crew sandte die dringende Anfrage an die Behörden beider Länder zur Aufnahme von rund 45 Menschen, die in schlechter Verfassung seien. Italien schickte daraufhin am Samstag einen Psychiater und einen kulturellen Mediator aus Pozzallo für mehrere Stunden an Bord, berichteten beide Seiten. Danach kam die Erlaubnis aus Rom zur Übernahme auf die "Moby Zaza". Die Lage an Bord habe sich jedoch etwas entspannt, hieß es aus der italienischen Hauptstadt. Am Sonntag seien zunächst Corona-Abstriche bei den Migranten geplant.

Wie SOS Méditerranée am Samstag schrieb, nahm das Schiff in insgesamt vier Einsätzen am 25. und am 30. Juni etwa 180 Menschen aus dem Mittelmeer an Bord. Italien und Malta hatten sich in der Corona-Pandemie zu nicht sicheren Häfen erklärt. Trotzdem brechen Migranten von Libyen und Tunesien in Richtung Europa auf. Rom und Valletta nahmen zuletzt zwar wieder Menschen von privaten Schiffen auf, doch die Länder zögern mit der Zuweisung von Häfen oft lange. Sie fordern von anderen EU-Staaten regelmäßig Zusagen über die Weiterverteilung der Menschen."""

In [None]:
# Wrap the single input text in a SummarizationDataset and preprocess it.
# This rebinds `test_dataset` and `processor`, which earlier cells also use.
# NOTE(review): nltk's sent_tokenize defaults to English; confirm whether the
# dataset forwards language='german' to the sentence splitter.
test_dataset = SummarizationDataset(
    None,
    source=[source],
    source_preprocessing=[tokenize.sent_tokenize],
    word_tokenize=nltk.word_tokenize,
    language='german'
)
processor = ExtSumProcessor(model_name=MODEL_NAME,  cache_dir=CACHE_DIR)
# No oracle_mode is passed here, unlike the train/test preprocessing above.
preprocessed_dataset = processor.preprocess(test_dataset)

In [None]:
# Field names of the preprocessed single sample.
preprocessed_dataset[0].keys()

In [None]:
# Summarize the single preprocessed sample with num_gpus=0 (presumably
# CPU-only - confirm) and one sample per batch. This overwrites the
# `prediction` list from the test-set evaluation.
prediction = summarizer.predict(preprocessed_dataset, num_gpus=0, batch_size=1, sentence_separator="\n")

In [None]:
# Show the predicted summary for the single input sample.
prediction

## Clean up temporary folders

In [None]:
def _remove_dir(path):
    """Delete the directory tree at ``path`` if it exists, ignoring removal errors."""
    if os.path.exists(path):
        shutil.rmtree(path, ignore_errors=True)


# NOTE(review): this deletes DATA_PATH and CACHE_DIR outright - confirm both
# point at temporary directories rather than the original datasets.
_remove_dir(DATA_PATH)
_remove_dir(CACHE_DIR)
# The processed-data directory is only removed when cached data was in use.
if USE_PREPROCSSED_DATA:
    _remove_dir(PROCESSED_DATA_PATH)