Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

In [1]:
%load_ext autoreload

%autoreload 2

### Configuration


In [2]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.swiss import SwissSummarizationDataset
from utils_nlp.dataset.bundes import BundesSummarizationDataset

from utils_nlp.eval import compute_rouge_python, compute_rouge_perl
from utils_nlp.models.transformers.extractive_summarization import (
    ExtractiveSummarizer,
    ExtSumProcessedData,
    ExtSumProcessor,
)

from utils_nlp.models.transformers.datasets import SummarizationDataset
import nltk
from nltk import tokenize

import pandas as pd
import scrapbook as sb
import pprint

  import pandas.util.testing as tm


In [3]:
pd.DataFrame({"model_name": ExtractiveSummarizer.list_supported_models()})

Unnamed: 0,model_name
0,bert-base-uncased
1,bert-base-german-cased
2,distilbert-base-uncased
3,dbmdz/bert-base-german-uncased
4,bert-base-german-dbmdz-cased
5,bert-base-multilingual-cased
6,distilbert-base-german-cased
7,bert-base-german-dbmdz-uncased
8,severinsimmler/bert-adapted-german-press
9,xlm-roberta-large-finetuned-conll03-german


In [4]:
# notebook parameters
# the cache data path during find tuning
CACHE_DIR = TemporaryDirectory().name

In [5]:
BUNDES_DATA_PATH='/home/ubuntu/data/bundes_dataset/'
SWISS_DATA_PATH='/home/ubuntu/data/swiss_dataset/'

bundes_save_path = os.path.join(BUNDES_DATA_PATH)
# bundes_train = torch.load(os.path.join(bundes_save_path, "train_full202008111812.pt"))
bundes_test = torch.load(os.path.join(bundes_save_path, "test_full202008111812.pt"))


swiss_save_path = os.path.join(SWISS_DATA_PATH)
# swiss_train = torch.load(os.path.join(swiss_save_path, "train_full.pt"))
swiss_test = torch.load(os.path.join(swiss_save_path, "test_full.pt"))


### Model Evaluation

[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization.

In [6]:
# models:

model_names = ['distilbert-base-german-cased', 'bert-base-german-cased']
train_names = ['200805_distilbert-base-german-cased_swiss', 
               '200806_bert-base-german-cased_swiss',
               '200811_bert-base-german-cased_swissBundes',
               '200811_distilbert-base-german-cased_swissBundes',
               '200812_bert-base-german-cased_bundes',
               '200812_distilbert-base-german-cased_bundes',
               'lead_1',
               'lead_2',
               'lead_3',
              ]

model_filepaths = ['/home/ubuntu/models/200805_distilbert-base-german-cased_swiss/output/',
                   '/home/ubuntu/models/200806_bert-base-german-cased_swiss/output/',
                   '/home/ubuntu/models/200811_bert-base-german-cased_swissBundes/output/',
                   '/home/ubuntu/models/200811_distilbert-base-german-cased_swissBundes/output/',
                   '/home/ubuntu/models/200812_bert-base-german-cased_bundes/output/',
                   '/home/ubuntu/models/200812_distilbert-base-german-cased_bundes/output/',
                  ]

models = {
    '200805_distilbert-base-german-cased_swiss':{
        'model': 'distilbert-base-german-cased',
        'filepath': '/home/ubuntu/models/200805_distilbert-base-german-cased_swiss/output/'
    },
    '200806_bert-base-german-cased_swiss':{
        'model': 'bert-base-german-cased',
        'filepath': '/home/ubuntu/models/200806_bert-base-german-cased_swiss/output/'
    },
    '200811_bert-base-german-cased_swissBundes':{
        'model': 'bert-base-german-cased',
        'filepath': '/home/ubuntu/models/200811_bert-base-german-cased_swissBundes/output/'
    },
    '200811_distilbert-base-german-cased_swissBundes':{
        'model': 'distilbert-base-german-cased',
        'filepath': '/home/ubuntu/models/200811_distilbert-base-german-cased_swissBundes/output/'
    },
    '200812_bert-base-german-cased_bundes':{
        'model': 'bert-base-german-cased',
        'filepath': '/home/ubuntu/models/200812_bert-base-german-cased_bundes/output/'
    },
    '200812_distilbert-base-german-cased_bundes':{
        'model': 'distilbert-base-german-cased',
        'filepath': '/home/ubuntu/models/200812_distilbert-base-german-cased_bundes/output/'
    },
    
}

In [7]:
MAX_POS_LENGTH = 512


# GPU used for training
NUM_GPUS = torch.cuda.device_count()

# Encoder name. Options are: 1. baseline, classifier, transformer, rnn.
ENCODER = "transformer"

# How often the statistics reports show up in training, unit is step.
REPORT_EVERY=50

In [9]:
# create processors:
processors = {}
for model_name in model_names:
    processors[model_name] = ExtSumProcessor(model_name=model_name, cache_dir=CACHE_DIR)

HBox(children=(IntProgress(value=0, description='Downloading', max=464, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=239836, style=ProgressStyle(description_wid…




HBox(children=(IntProgress(value=0, description='Downloading', max=433, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=254728, style=ProgressStyle(description_wid…




In [10]:
summarizers= {}
model_filename = "dist_extsum_model.pt"

for model, meta in list(models.items()):
    print("creating summarizer for", model)

    processor = processors[meta['model']]
    print("Processor loaded for", meta['model'])
    
    model_path = os.path.join(meta['filepath'], model_filename)
    summarizer = ExtractiveSummarizer(processor, meta['model'], ENCODER, MAX_POS_LENGTH, CACHE_DIR)
    summarizer.model.load_state_dict(torch.load(model_path, map_location="cpu"))
    print("model loaded for", meta['model'])
    summarizers[model] = summarizer

creating summarizer for 200805_distilbert-base-german-cased_swiss
Processor loaded for distilbert-base-german-cased


HBox(children=(IntProgress(value=0, description='Downloading', max=269752043, style=ProgressStyle(description_…


model loaded for distilbert-base-german-cased
creating summarizer for 200806_bert-base-german-cased_swiss
Processor loaded for bert-base-german-cased


HBox(children=(IntProgress(value=0, description='Downloading', max=438869143, style=ProgressStyle(description_…


model loaded for bert-base-german-cased
creating summarizer for 200811_bert-base-german-cased_swissBundes
Processor loaded for bert-base-german-cased
model loaded for bert-base-german-cased
creating summarizer for 200811_distilbert-base-german-cased_swissBundes
Processor loaded for distilbert-base-german-cased
model loaded for distilbert-base-german-cased
creating summarizer for 200812_bert-base-german-cased_bundes
Processor loaded for bert-base-german-cased
model loaded for bert-base-german-cased
creating summarizer for 200812_distilbert-base-german-cased_bundes
Processor loaded for distilbert-base-german-cased
model loaded for distilbert-base-german-cased


In [11]:
source = {}
target = {}


source['bundes'] = []
source['swiss'] = []

temp_target_bundes = []
temp_target_swiss = []
for i in bundes_test:
    source['bundes'].append(i["src_txt"]) 
    
    temp_target_bundes.append(" ".join(j) for j in i['tgt']) 
target['bundes'] = [''.join(i) for i in list(temp_target_bundes)]

for i in swiss_test:
    source['swiss'].append(i["src_txt"]) 
    
    temp_target_swiss.append(" ".join(j) for j in i['tgt']) 
target['swiss'] = [''.join(i) for i in list(temp_target_swiss)]


### create test dictionary
torch_tests = {
    'bundes': bundes_test,
    'swiss': swiss_test
}

In [12]:
summarizers['lead_1']={
    'bundes': [s[0] for s in source['bundes']],
    'swiss': [s[0] for s in source['swiss']]
}

summarizers['lead_2']={
    'bundes': [" ".join(s[:2]) for s in source['bundes']],
    'swiss': [" ".join(s[:2]) for s in source['swiss']],
}

summarizers['lead_3']={
    'bundes': [" ".join(s[:3]) for s in source['bundes']],
    'swiss': [" ".join(s[:3]) for s in source['swiss']], 
}

In [18]:
%%time
sentence_separator = "\n"
batch_size = 400
rouge_scores = {}
predictions = {}

TEST = False


for dataset in ['bundes','swiss']:
    predictions[dataset] = {}
    rouge_scores[dataset] = {}
    print("Dataset: ", dataset)
    if TEST:
        n = 2
    else:
        n = len(torch_tests[dataset])
    print("Sample size:", n)
    
    for train_name, summarizer in summarizers.items():
        print("model name: ", train_name)
        if "lead" in train_name:
            predictions[dataset][train_name] = summarizer[dataset][:n]
        else:
            predictions[dataset][train_name] = summarizer.predict(torch_tests[dataset][:n], num_gpus=NUM_GPUS, batch_size=batch_size, sentence_separator=sentence_separator)
        
        rouge_scores[dataset][train_name] = compute_rouge_python(cand=predictions[dataset][train_name], ref=target[dataset][:n])
    




Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

Dataset:  bundes
Sample size: 545
model name:  200805_distilbert-base-german-cased_swiss





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.69s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:13<00:00,  7.79s/it][A[A[A

Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

model name:  200806_bert-base-german-cased_swiss





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.55s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:12<00:00,  7.71s/it][A[A[A

Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

model name:  200811_bert-base-german-cased_swissBundes





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.55s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:12<00:00,  7.71s/it][A[A[A

Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

model name:  200811_distilbert-base-german-cased_swissBundes





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.50s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:12<00:00,  7.67s/it][A[A[A

Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

model name:  200812_bert-base-german-cased_bundes





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.51s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:12<00:00,  7.68s/it][A[A[A

Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/2 [00:00<?, ?it/s][A[A[A

model name:  200812_distilbert-base-german-cased_bundes





Scoring:  50%|█████     | 1/2 [00:09<00:09,  9.46s/it][A[A[A


Scoring: 100%|██████████| 2/2 [00:12<00:00,  7.64s/it][A[A[A

Number of candidates: 545
Number of references: 545
model name:  lead_1
Number of candidates: 545
Number of references: 545
model name:  lead_2
Number of candidates: 545
Number of references: 545
model name:  lead_3
Number of candidates: 545
Number of references: 545





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

Dataset:  swiss
Sample size: 5000
model name:  200805_distilbert-base-german-cased_swiss





Scoring:   8%|▊         | 1/13 [00:10<02:04, 10.35s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:52, 10.21s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:42, 10.25s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:40<01:32, 10.25s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:50<01:20, 10.11s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:00<01:10, 10.09s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:10<01:00, 10.15s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.14s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.16s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.06s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.09s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.06s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:06<00:00,  8.66s/it][A[A[A

Number of candidates: 5000
Number of references: 5000





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

model name:  200806_bert-base-german-cased_swiss





Scoring:   8%|▊         | 1/13 [00:10<02:04, 10.36s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:52, 10.23s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:42, 10.29s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:40<01:32, 10.27s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:50<01:21, 10.13s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:00<01:10, 10.11s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:11<01:00, 10.15s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.13s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.15s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.04s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.08s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.07s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:06<00:00,  8.67s/it][A[A[A

Number of candidates: 5000
Number of references: 5000





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

model name:  200811_bert-base-german-cased_swissBundes





Scoring:   8%|▊         | 1/13 [00:10<02:05, 10.44s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:53, 10.29s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:43, 10.32s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:41<01:32, 10.29s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:50<01:21, 10.14s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:00<01:10, 10.11s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:11<01:00, 10.14s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.13s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.15s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.04s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.07s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.06s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:06<00:00,  8.66s/it][A[A[A

Number of candidates: 5000
Number of references: 5000





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

model name:  200811_distilbert-base-german-cased_swissBundes





Scoring:   8%|▊         | 1/13 [00:10<02:05, 10.45s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:53, 10.32s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:43, 10.35s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:41<01:33, 10.34s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:51<01:21, 10.19s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:01<01:11, 10.16s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:11<01:01, 10.21s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.19s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.20s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.10s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.14s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.11s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:07<00:00,  8.70s/it][A[A[A

Number of candidates: 5000
Number of references: 5000





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

model name:  200812_bert-base-german-cased_bundes





Scoring:   8%|▊         | 1/13 [00:10<02:04, 10.38s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:52, 10.25s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:42, 10.29s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:40<01:32, 10.28s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:50<01:21, 10.13s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:00<01:10, 10.10s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:11<01:00, 10.14s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.12s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.15s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.04s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.08s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.07s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:06<00:00,  8.67s/it][A[A[A

Number of candidates: 5000
Number of references: 5000





Scoring:   0%|          | 0/13 [00:00<?, ?it/s][A[A[A

model name:  200812_distilbert-base-german-cased_bundes





Scoring:   8%|▊         | 1/13 [00:10<02:05, 10.44s/it][A[A[A


Scoring:  15%|█▌        | 2/13 [00:20<01:53, 10.32s/it][A[A[A


Scoring:  23%|██▎       | 3/13 [00:30<01:43, 10.35s/it][A[A[A


Scoring:  31%|███       | 4/13 [00:41<01:32, 10.33s/it][A[A[A


Scoring:  38%|███▊      | 5/13 [00:51<01:21, 10.18s/it][A[A[A


Scoring:  46%|████▌     | 6/13 [01:01<01:11, 10.16s/it][A[A[A


Scoring:  54%|█████▍    | 7/13 [01:11<01:01, 10.20s/it][A[A[A


Scoring:  62%|██████▏   | 8/13 [01:21<00:50, 10.18s/it][A[A[A


Scoring:  69%|██████▉   | 9/13 [01:31<00:40, 10.20s/it][A[A[A


Scoring:  77%|███████▋  | 10/13 [01:41<00:30, 10.10s/it][A[A[A


Scoring:  85%|████████▍ | 11/13 [01:51<00:20, 10.13s/it][A[A[A


Scoring:  92%|█████████▏| 12/13 [02:01<00:10, 10.11s/it][A[A[A


Scoring: 100%|██████████| 13/13 [02:07<00:00,  8.70s/it][A[A[A

Number of candidates: 5000
Number of references: 5000
model name:  lead_1
Number of candidates: 5000
Number of references: 5000
model name:  lead_2
Number of candidates: 5000
Number of references: 5000
model name:  lead_3
Number of candidates: 5000
Number of references: 5000
CPU times: user 17min 54s, sys: 1min 38s, total: 19min 33s
Wall time: 18min 45s


In [19]:
# print out the calculated rouge scores
pprint.pprint(rouge_scores)

{'bundes': {'200805_distilbert-base-german-cased_swiss': {'rouge-1': {'f': 0.3827514805602676,
                                                                      'p': 0.30461223878755994,
                                                                      'r': 0.6075377864949992},
                                                          'rouge-2': {'f': 0.2966035810615314,
                                                                      'p': 0.2358905302107044,
                                                                      'r': 0.4753467444078293},
                                                          'rouge-l': {'f': 0.35791017468425335,
                                                                      'p': 0.28495270654344645,
                                                                      'r': 0.5680595056734346}},
            '200806_bert-base-german-cased_swiss': {'rouge-1': {'f': 0.3352210738955823,
                                                 

In [20]:
import pickle

pickle.dump(rouge_scores, open('rouge_scores.p','wb'))

In [None]:
with open('sample_results.txt','w') as f:
    for i in range(len(prediction)):
        source_output = " ".join(source[i]) 
        f.write("Source Text: \n")
        f.write("\"" + source_output + "\" \n")
        f.write("\n")
        f.write("Source target: \n")
        f.write("\"" + target[i] + "\" \n")
        f.write("\n")
        f.write("Model Prediction: \n")
        f.write("\"" + prediction[i].replace("\n", " ") + "\" \n")        
        f.write("\n")
        f.write("======================================")        
        f.write("\n \n")

In [None]:
target[10]

In [None]:
prediction[10]

In [None]:
# for testing
sb.glue("rouge_2_f_score", rouge_scores['rouge-2']['f'])

## Prediction on a single input sample

## Clean up temporary folders

In [None]:
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH, ignore_errors=True)
if os.path.exists(CACHE_DIR):
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
if USE_PREPROCSSED_DATA:
    if os.path.exists(PROCESSED_DATA_PATH):
        shutil.rmtree(PROCESSED_DATA_PATH, ignore_errors=True)