## Step 1: Extracting relevant sections to convert into embedding vectors

In [1]:
import pandas as pd

In [2]:
import os
os.chdir('/kaggle/input/data-laysumm')

filename = "eLife_train.jsonl"
df = pd.read_json(filename, orient="records", lines=True)
df.head()

| | lay_summary | article | headings | keywords | id |
|---|---|---|---|---|---|
| 0 | In the USA , more deaths happen in the winter ... | In temperate climates , winter deaths exceed s... | [Abstract, Introduction, Results, Discussion, ...] | [epidemiology and global health] | elife-35500-v1 |
| 1 | Most people have likely experienced the discom... | Whether complement dysregulation directly cont... | [Abstract, Introduction, Results, Discussion, ...] | [microbiology and infectious disease, immunolo... | elife-48378-v2 |
| 2 | The immune system protects an individual from ... | Variation in the presentation of hereditary im... | [Abstract, Introduction, Results, Discussion, ...] | [microbiology and infectious disease, immunolo... | elife-04494-v1 |
| 3 | The brain adapts to control our behavior in di... | Rapid and flexible interpretation of conflicti... | [Abstract, Introduction, Results, Discussion, ...] | [neuroscience] | elife-12352-v2 |
| 4 | Cells use motor proteins that to move organell... | Myosin 5a is a dual-headed molecular motor tha... | [Abstract, Introduction, Results, Discussion, ...] | [structural biology and molecular biophysics] | elife-05413-v2 |


In [3]:
df.loc[0, 'lay_summary']



In [4]:
print(df.loc[0, 'article'])

In temperate climates , winter deaths exceed summer ones . However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . Death rates in men and women ≥ 45 years peaked in December to February and were lowest in June to August , driven by cardiorespiratory diseases and injuries . In these ages , percent difference in death rates between peak and minimum months did not vary across climate regions , nor changed from 1980 to 2016 . Under five years , seasonality of all-cause mortality largely disappeared after the 1990s . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , driven by injury deaths . 
 It is well-established that d

In [5]:
df.loc[0, 'headings']

['Abstract', 'Introduction', 'Results', 'Discussion', 'Materials and methods']

In [6]:
df.loc[0, 'keywords']

['epidemiology and global health']

In [7]:
df['headings'].value_counts()

headings
[Abstract, Introduction, Results, Discussion, Materials and methods]     3484
[Abstract, Introduction, Results, Discussion, Materials and methods]      250
[Abstract, Introduction, Results and discussion, Materials and methods]   231
[Abstract, Introduction, Results, Discussion, Material and methods]        78
[Abstract, Introduction, Results, Discussion]                              49

In [8]:
set(df['headings'].explode().tolist())

{'',
 '1\xa0Introduction',
 '2 Adhesiveness as a quantitative trait affecting group formation and function',
 '3 Adaptive dynamics of adhesiveness',
 '4 The evolution of adhesiveness by attachment',
 '5 Discussion',
 'Abstract',
 'Accession codes',
 'Accession numbers',
 'Acknowledgments',
 'Analysis',
 'Bacterial films',
 'Bacterial swarms',
 'Bradykinin-induced local pulmonary angioedema',
 'Building the stimulus space',
 'Clinical observations',
 'Computational\xa0methods',
 'Conclusion',
 'Conclusions',
 'Database depositions',
 'Deep learning networks',
 'Description',
 'Differential diagnosis',
 'Discussion',
 'Discussion and conclusions',
 'Discussions',
 'Ethics statement',
 'Experiment 1 – Defining a relative value perception curve',
 'Experiment 2 – ruling out alternative explanations using scent training',
 'Experiment 3 – expectation setting via trophallaxis: the nest as an information hub',
 'Experimental background',
 'Experimental procedures',
 'Hypothesis and Results',
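The heading set above mixes non-breaking spaces (`\xa0`), numeric prefixes such as `1 ` and `5 `, and case or singular/plural variants of the same section. A minimal normalization sketch, assuming canonical lowercase headings are what later matching wants; `normalize_heading` is a hypothetical helper, not part of this notebook:

```python
import re
import unicodedata


def normalize_heading(heading: str) -> str:
    """Map a raw eLife heading to a canonical lowercase form (hypothetical helper)."""
    # NFKC turns compatibility characters such as U+00A0 (no-break space) into plain spaces
    text = unicodedata.normalize("NFKC", heading)
    # Drop a leading section number such as "1 " or "5 "
    text = re.sub(r"^\d+\s+", "", text)
    # Collapse runs of whitespace and lowercase for case-insensitive matching
    return " ".join(text.split()).lower()


raw = ["1\xa0Introduction", "5 Discussion", "Computational\xa0methods", "Conclusions", "Abstract"]
print(sorted({normalize_heading(h) for h in raw}))
# ['abstract', 'computational methods', 'conclusions', 'discussion', 'introduction']
```

Lowercasing also lets a key like `'conclusion'` match `'Discussion and conclusions'`.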


In [9]:
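imp_sections = ['Abstract', 'Introduction', 'Conclusion']

def get_relevant_text(article, headings):
    """Join the paragraphs whose heading contains one of `imp_sections`.
    Assumes the article's newline-separated paragraphs line up one-to-one with `headings`."""
    article_split = article.split("\n")
    # Outer loop over imp_sections keeps the output in Abstract -> Introduction -> Conclusion order
    section_ids = [idx for imp_section in imp_sections
                   for idx, section in enumerate(headings) if imp_section in section]
    # Skip headings that have no matching paragraph instead of raising IndexError
    return " ".join(article_split[idx] for idx in section_ids if idx < len(article_split))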
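`get_relevant_text` relies on two assumptions: the article's newline-separated paragraphs line up one-to-one with `headings`, and the `imp_sections` keys match case-sensitively, so `'Conclusion'` does not pick up a heading such as `'Discussion and conclusions'`. A sketch of a case-insensitive variant on a made-up toy record (`get_relevant_text_ci`, `IMP_SECTIONS` and the record are hypothetical):

```python
# Hypothetical toy record: one newline-separated paragraph per heading
article = "Abstract text .\nIntro text .\nResults text .\nWrap-up text ."
headings = ["Abstract", "1\xa0Introduction", "Results", "Discussion and conclusions"]

IMP_SECTIONS = ["abstract", "introduction", "conclusion"]


def get_relevant_text_ci(article, headings, wanted=IMP_SECTIONS):
    """Case-insensitive variant of get_relevant_text (hypothetical sketch)."""
    paragraphs = article.split("\n")
    picked = []
    for key in wanted:  # preserve the Abstract -> Introduction -> Conclusion order
        for idx, heading in enumerate(headings):
            # substring match on the lowercased heading; skip repeats and headings without a paragraph
            if key in heading.lower() and idx not in picked and idx < len(paragraphs):
                picked.append(idx)
    return " ".join(paragraphs[idx] for idx in picked)


print(get_relevant_text_ci(article, headings))
# Abstract text . Intro text . Wrap-up text .
```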

In [10]:
get_relevant_text(df.loc[0, 'article'], df.loc[0, 'headings'])

'In temperate climates , winter deaths exceed summer ones . However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . Death rates in men and women ≥ 45 years peaked in December to February and were lowest in June to August , driven by cardiorespiratory diseases and injuries . In these ages , percent difference in death rates between peak and minimum months did not vary across climate regions , nor changed from 1980 to 2016 . Under five years , seasonality of all-cause mortality largely disappeared after the 1990s . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , driven by injury deaths .   It is well-established that 

In [11]:
df['extracted'] = df.apply(lambda x: get_relevant_text(x['article'], x['headings']), axis=1)

In [12]:
df['extracted'][0]

'In temperate climates , winter deaths exceed summer ones . However , there is limited information on the timing and the relative magnitudes of maximum and minimum mortality , by local climate , age group , sex and medical cause of death . We used geo-coded mortality data and wavelets to analyse the seasonality of mortality by age group and sex from 1980 to 2016 in the USA and its subnational climatic regions . Death rates in men and women ≥ 45 years peaked in December to February and were lowest in June to August , driven by cardiorespiratory diseases and injuries . In these ages , percent difference in death rates between peak and minimum months did not vary across climate regions , nor changed from 1980 to 2016 . Under five years , seasonality of all-cause mortality largely disappeared after the 1990s . In adolescents and young adults , especially in males , death rates peaked in June/July and were lowest in December/January , driven by injury deaths .   It is well-established that 

In [13]:
df['extracted'].apply(lambda x: len(x.split())).describe()  # word counts; 95th percentile ≈ 1945 words, 99th ≈ 2800 words

count     4346.000000
mean      1172.854579
std        587.048829
min         96.000000
25%        867.250000
50%       1087.000000
75%       1364.000000
max      19729.000000
Name: extracted, dtype: float64

- With the `Results` section also extracted, the word-count statistics become:
```
count     4346.000000
mean      6206.520249
std       2684.050603
min         96.000000
25%       4289.250000
50%       5862.500000
75%       7828.750000
max      24023.000000
Name: extracted, dtype: float64
```
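The 95th and 99th percentiles quoted in the cell comment can be reported directly by passing `percentiles` to `Series.describe`. A small sketch on toy strings (`word_count_stats` is a hypothetical helper); in the notebook the input would be `df['extracted']`:

```python
import pandas as pd


def word_count_stats(texts: pd.Series) -> pd.Series:
    """Summarize whitespace word counts, including the 95th and 99th percentiles."""
    # str.split() with no pattern splits on whitespace, like len(x.split()) in the notebook
    counts = texts.str.split().str.len()
    return counts.describe(percentiles=[0.25, 0.5, 0.75, 0.95, 0.99])


toy = pd.Series(["a b c", "a b", "a b c d e", "a"])
print(word_count_stats(toy))
```

These are whitespace word counts, not model tokens. The Longformer tokenizer never merges two whitespace-separated words into one token, so each word costs at least one token and any text longer than 4096 words is cut by `truncation=True, max_length=4096` in `get_embedding`. With the `Results` section included, even the median (5,862 words) exceeds that limit.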

## Step 2: Chunking of articles and calculating relevance 

In [14]:
import torch
import logging
from transformers import logging as transformers_logging
from transformers import LongformerModel, LongformerTokenizer

logging.basicConfig(level=logging.WARNING)
transformers_logging.set_verbosity_warning()

# Fall back to CPU when no GPU is attached to the session
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerModel.from_pretrained('allenai/longformer-base-4096')
model = model.to(device)
model.eval()  # inference only: disables dropout

def get_embedding(text):
    """Return a (1, hidden_size) mean-pooled Longformer embedding of `text` on the CPU."""
    # Texts longer than 4096 tokens are truncated to the model's maximum length
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=4096)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Average the token vectors into one vector per input
    embedding = torch.mean(outputs.last_hidden_state, dim=1)
    return embedding.cpu()
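`get_embedding` pools with a plain `torch.mean(..., dim=1)`. That is fine for one text at a time, but if several chunks were tokenized as a batch, `padding=True` would add pad positions that a plain mean averages in. A minimal sketch of attention-mask-weighted mean pooling, shown on a toy tensor so it runs without downloading the model (`masked_mean_pool` is a hypothetical helper):

```python
import torch


def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors per sequence, ignoring positions where attention_mask == 0."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # avoid division by zero
    return summed / counts


# Toy batch: 2 sequences of length 3, hidden size 2; the second has one padded position
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
                       [[2.0, 0.0], [4.0, 0.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
print(masked_mean_pool(hidden, mask))
# tensor([[3., 3.],
#         [3., 0.]])
```

Inside `get_embedding` it would be called as `masked_mean_pool(outputs.last_hidden_state, inputs['attention_mask'])`.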



In [15]:
from scipy.spatial.distance import cosine
from nltk.tokenize import sent_tokenize


def chunks(lst, n):
    """Yield consecutive slices of `lst` with at most `n` items each."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def get_chunk_embeddings(article, sentences_per_chunk=5):
    """Split `article` into groups of sentences and embed each group."""
    sentences = sent_tokenize(article)
    chunk_embeddings = []

    for group in chunks(sentences, sentences_per_chunk):
        chunk_text = ' '.join(group)
        chunk_embeddings.append((chunk_text, get_embedding(chunk_text)))

    return chunk_embeddings

def find_similar_chunks(article, extracted_text, threshold=0.96):
    """Return the chunks whose cosine similarity to the extracted sections is >= threshold."""
    extracted_embedding = get_embedding(extracted_text).squeeze().numpy()

    similar_chunks = []
    for chunk_text, chunk_embedding in get_chunk_embeddings(article):
        # scipy's `cosine` is a distance, so similarity = 1 - distance
        similarity = 1 - cosine(extracted_embedding, chunk_embedding.squeeze().numpy())
        if similarity >= threshold:
            similar_chunks.append(chunk_text)

    return similar_chunks
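`find_similar_chunks` computes one scipy `cosine` distance per chunk. An equivalent sketch with NumPy scores all chunks in a single matrix-vector product and also returns the scores, which makes it easier to see how many chunks clear the 0.96 threshold. `select_similar` is a hypothetical helper and assumes no embedding is the zero vector:

```python
import numpy as np


def select_similar(query: np.ndarray, chunk_matrix: np.ndarray, texts, threshold=0.96):
    """Return the texts whose cosine similarity to `query` is >= threshold, plus all scores."""
    q = query / np.linalg.norm(query)
    m = chunk_matrix / np.linalg.norm(chunk_matrix, axis=1, keepdims=True)
    scores = m @ q  # one cosine similarity per chunk row
    keep = [t for t, s in zip(texts, scores) if s >= threshold]
    return keep, scores


query = np.array([1.0, 0.0])
chunk_matrix = np.array([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
kept, scores = select_similar(query, chunk_matrix, ["a", "b", "c"])
print(kept, np.round(scores, 3))
```

With the notebook's embeddings, the inputs would be `get_embedding(extracted_text).squeeze().numpy()` and `torch.cat([emb for _, emb in get_chunk_embeddings(article)]).numpy()`.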

In [None]:
results = []
for _, row in df.iterrows():
    # Keep every chunk of the full article that is close to the article's extracted sections
    similar_chunks = find_similar_chunks(row['article'], row['extracted'])
    results.append(" ".join(similar_chunks))

Input ids are automatically padded from 649 to 1024 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 145 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 134 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 110 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 131 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 90 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 176 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 148 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 120 to 512 to be a multiple of `config.attention_window`: 512
Input ids are automatically padded from 151 to 512 to be a multiple of `config.att