# Summarize Texts
In this notebook we summarize each paper with abstractive summarization (a distilled BART model) so that the result can be fed into the OpenAI GPT-3.5 API.

Papers can be long, so they easily exceed BART's 1024-token input limit.

For this reason we split each paper into multiple parts, summarize each part separately, and then join the partial summaries into a summary of the complete document.
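The split-then-summarize approach can be outlined in plain Python. This is a minimal, framework-free sketch, not the notebook's code. `summarize_long` and the `summarize_chunk` callback are hypothetical names, and the callback stands in for running `model.generate` on one chunk and decoding the result.

```python
from typing import Callable, List

def summarize_long(token_ids: List[int],
                   summarize_chunk: Callable[[List[int]], str],
                   max_size: int = 1024) -> str:
    """Split token_ids into consecutive chunks of at most max_size tokens,
    summarize each chunk independently and join the summaries, one per line."""
    chunk_summaries = []
    for start in range(0, len(token_ids), max_size):
        chunk = token_ids[start:start + max_size]  # the last chunk may be shorter
        chunk_summaries.append(summarize_chunk(chunk))
    return "\n".join(chunk_summaries)

# Toy usage: a dummy "summarizer" that only reports each chunk's length
print(summarize_long(list(range(2500)), lambda chunk: f"{len(chunk)} tokens"))
```

Each chunk is summarized without seeing the others. This keeps every call within the model's limit, but a sentence split across a chunk boundary is seen only in halves.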

In [101]:
from transformers import BartForConditionalGeneration, AutoTokenizer

model_ckpt = "sshleifer/distilbart-cnn-6-6"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BartForConditionalGeneration.from_pretrained(model_ckpt)

Papers to summarize

In [102]:
%store -r similar_docs
print(similar_docs[1].payload['body'])

1. Introduction
Artificial intelligence (AI) has attracted substantial attention in recent years and is often regarded as heralding the fourth industrial revolution, in many expert assessments (, ). Developed countries have made significant investments in AI research and its application in healthcare. The COVID-19 epidemic has increased the demand for AI resources and knowledge in healthcare in order to reduce workload and diagnostic errors ().
While the application of AI in radiological interpretations is well known, recent research has switched to investigating its potential in other domains to improve the efficiency and efficacy of radiologists. These applications include enhancing image collecting procedures, diagnosing pathology, increasing research productivity, optimizing radiation dosage, and providing high-quality medical treatment ().
The shortage of radiologists is a critical issue that has been affecting the healthcare sector in Saudi Arabia and around the world (). This sh

Tokenize the documents

In [103]:
import torch
from tqdm import tqdm
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

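# Tokenize each full paper without truncation; the token sequences are split into 1024-token chunks below.
# The tokenizer may warn that a sequence is longer than the model maximum; that is expected here.
# (padding='max_length' is not needed: it only pads short papers up to 1024 tokens and never truncates long ones.)
input_ids = [tokenizer(doc.payload['body'], return_tensors='pt').to(device) for doc in tqdm(similar_docs)]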

100%|██████████| 6/6 [00:00<00:00, 26.43it/s]


Split each document into chunks of at most 1024 tokens

In [86]:
len(input_ids[0]['input_ids'][0])

3678
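For a document of n tokens, splitting into chunks of 1024 gives ceil(n / 1024) chunks, and the last chunk holds n - 1024 * (ceil(n / 1024) - 1) tokens. For the 3678 tokens above, that is 4 chunks, the last with 606 tokens. The helper below is a hypothetical sketch that computes the same `start - end` ranges the next cell prints. Taking the end as `min(start + max_size, n)` also handles lengths that are an exact multiple of 1024. In that case a remainder-based end (`n % max_size`) would be 0 and would leave an empty last chunk.

```python
import math
from typing import List, Tuple

def chunk_ranges(n_tokens: int, max_size: int = 1024) -> List[Tuple[int, int]]:
    """Return (start, end) token offsets of consecutive chunks of at most max_size tokens."""
    n_splits = math.ceil(n_tokens / max_size)
    return [(i * max_size, min((i + 1) * max_size, n_tokens)) for i in range(n_splits)]

for start, end in chunk_ranges(3678):
    print(f"{start} - {end}")
```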

In [105]:
max_size = 1024  # BART's maximum input length in tokens

documents_tokenized = []

for input_id in input_ids:
    n_tokens = input_id['input_ids'].shape[1]
    n_splits = math.ceil(n_tokens / max_size)
    token_splits = []
    for index in range(n_splits):
        start = index * max_size
        end = min(start + max_size, n_tokens)  # the last chunk may be shorter
        print(f"{start} - {end}")
        # Plain slicing keeps the (1, length) shape and the device; no torch.tensor() copy is needed
        token_splits.append({"input_ids": input_id['input_ids'][:, start:end],
                             "attention_mask": input_id['attention_mask'][:, start:end]})

    documents_tokenized.append(token_splits)

0 - 1024
1024 - 2048
2048 - 3072
3072 - 3678
0 - 1024
1024 - 2048
2048 - 3072
3072 - 4096
4096 - 5120
5120 - 6111
0 - 1024
1024 - 2048
2048 - 3072
3072 - 3579
0 - 1024
1024 - 1756
0 - 1024
1024 - 2048
2048 - 3072
3072 - 4096
4096 - 5120
5120 - 6144
6144 - 6333
0 - 1024
1024 - 2048
2048 - 3072
3072 - 4096
4096 - 5120
5120 - 6144
6144 - 6441
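The tokenizer adds special tokens only once per document. As a result, only the first chunk starts with BART's `<s>` token and only the last chunk ends with `</s>`. One possible refinement, not used in this notebook, is to cut the body into pieces of `max_size - 2` tokens and wrap each piece in its own start and end tokens. The sketch below works on plain lists of IDs, and `rewrap_chunks` is a hypothetical helper. The IDs 0 for `<s>` and 2 for `</s>` are an assumption matching the BART vocabulary. In practice, read them from `tokenizer.bos_token_id` and `tokenizer.eos_token_id`.

```python
from typing import List

def rewrap_chunks(token_ids: List[int], max_size: int = 1024,
                  bos_id: int = 0, eos_id: int = 2) -> List[List[int]]:
    """Re-split a tokenized document so every chunk is <s> ... </s> and at most max_size long.
    bos_id/eos_id default to BART's <s>/</s> IDs (assumption; read them from the tokenizer)."""
    # Strip the document-level special tokens added by the tokenizer, if present
    body = token_ids[:]
    if body and body[0] == bos_id:
        body = body[1:]
    if body and body[-1] == eos_id:
        body = body[:-1]
    inner = max_size - 2  # leave room for <s> and </s> in every chunk
    return [[bos_id] + body[i:i + inner] + [eos_id] for i in range(0, len(body), inner)]

chunks = rewrap_chunks([0] + list(range(10, 2510)) + [2])
print([len(c) for c in chunks])
```

The cost is two tokens of content per chunk. The benefit is that every chunk looks like a complete input of the kind the summarizer was fine-tuned on.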


In [106]:
model.to(device)  # keep the model on the same device as the input tensors
summaries = []

for doc_tokenized in tqdm(documents_tokenized):
    doc_summary = []
    # Summarize each chunk of at most 1024 tokens independently
    for chunk in tqdm(doc_tokenized):
        doc_summary.append(model.generate(input_ids=chunk['input_ids'],
                                          attention_mask=chunk['attention_mask'],
                                          min_length=16,
                                          max_length=64))
    summaries.append(doc_summary)

100%|██████████| 4/4 [00:40<00:00, 10.18s/it]
100%|██████████| 6/6 [01:03<00:00, 10.60s/it]
100%|██████████| 4/4 [00:41<00:00, 10.47s/it]
100%|██████████| 2/2 [00:18<00:00,  9.47s/it]
100%|██████████| 7/7 [01:08<00:00,  9.86s/it]
100%|██████████| 7/7 [01:10<00:00, 10.04s/it]
100%|██████████| 6/6 [05:04<00:00, 50.74s/it]


In [107]:
text_summaries = []

for summary in summaries:
    text_summary = ""

    for split in summary:
        extracted_summary = tokenizer.decode(split[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

        # max_length=64 can cut the last sentence off: keep the text up to and including the last period
        if "." in extracted_summary:
            text_summary += extracted_summary[:extracted_summary.rfind(".") + 1] + "\n"
        else:
            text_summary += extracted_summary + "\n"

    text_summaries.append(text_summary)
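Cutting at the last "." character also reacts to decimals such as "3.5". A slightly more careful variant, sketched here as a hypothetical helper `trim_incomplete_sentence`, counts `.`, `!` or `?` as a sentence end only when whitespace or the end of the text follows. This is still a heuristic: abbreviations like "e.g." followed by a space are treated as sentence ends.

```python
import re

# A period, question mark or exclamation mark that ends a sentence:
# followed by whitespace or the end of the string (so "3.5" does not count)
_SENTENCE_END = re.compile(r"[.!?](?=\s|$)")

def trim_incomplete_sentence(text: str) -> str:
    """Drop a trailing sentence fragment, e.g. one cut off by max_length during generation.
    Text without any sentence terminator is returned unchanged."""
    ends = [m.end() for m in _SENTENCE_END.finditer(text)]
    return text[:ends[-1]] if ends else text

print(trim_incomplete_sentence("Version 3.5 was tested. Results were mixed. The next st"))
```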

In [108]:
%store text_summaries

Stored 'text_summaries' (list)
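The stored `text_summaries` are meant to be passed to the GPT-3.5 API. The function below is a hedged sketch of that hand-off. It only builds the list of role/content message dictionaries used by the OpenAI Chat Completions endpoint and makes no API call. The `build_messages` name, the prompt wording and the choice to send all summaries in one request are assumptions, not part of this notebook.

```python
from typing import Dict, List

def build_messages(text_summaries: List[str], question: str) -> List[Dict[str, str]]:
    """Pack the per-paper summaries into a chat-style message list (hypothetical prompt)."""
    context = "\n\n".join(f"Paper {i + 1}:\n{s.strip()}" for i, s in enumerate(text_summaries))
    return [
        {"role": "system", "content": "Answer using only the paper summaries provided."},
        {"role": "user", "content": f"{context}\n\nQuestion: {question}"},
    ]

messages = build_messages(["First summary.\n", "Second summary.\n"], "What do the papers discuss?")
print(messages[1]["content"])
```

Keeping the prompt construction separate from the API call makes it easy to check the total prompt length before sending it.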
