In [1]:
import json
import pickle
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [2]:
@dataclass
class Paper:
    """Text fields extracted from one paper, identified by its source filename."""

    filename: str
    title: str = ''
    authors: str = ''
    abstract: str = ''
    keywords: str = ''
    introduction: str = ''

    def __repr__(self):
        # Render each field as its own titled section; sections are separated
        # by a blank line so the whole paper is readable when printed.
        section_names = ('filename', 'title', 'authors', 'abstract', 'keywords', 'introduction')
        sections = [
            f' {name} \n----------\n {getattr(self, name)}'
            for name in section_names
        ]
        return '\n\n'.join(sections)

In [3]:
def to_camel_case(text: str) -> str:
    """Convert a whitespace-separated phrase to lowerCamelCase.

    The first word is lower-cased. Every later word goes through
    ``str.capitalize`` (first letter upper, remaining letters lower). The
    words are then joined without separators, e.g.
    ``'Key Information Extraction'`` -> ``'keyInformationExtraction'``.

    Args:
        text: Phrase whose words are separated by whitespace.

    Returns:
        The lowerCamelCase string, or ``''`` if ``text`` contains no words.
    """
    words = text.split()
    if not words:
        # Guard: indexing words[0] on an empty/whitespace-only input would raise.
        return ''

    first, *rest = words
    return first.lower() + ''.join(word.capitalize() for word in rest)

def authors_to_list(authors: str) -> list[str]:
    """Split an author string such as ``'A, B and C'`` into ``['A', 'B', 'C']``.

    Only the last separator is treated as ``' and '``. Everything before it is
    split on commas. Names are whitespace-stripped, and empty entries are
    dropped. Empty entries come from an Oxford comma (``'A, B, and C'``) or
    from an empty input string.

    Args:
        authors: Author names separated by commas, optionally with a final
            ``' and '`` before the last author.

    Returns:
        The list of non-empty, stripped author names (``[]`` for ``''``).
    """
    # There should be at most one ' and ', right before the last author.
    head_and_last = authors.rsplit(' and ', maxsplit=1)

    names = head_and_last[0].split(',')
    if len(head_and_last) == 2:
        names.append(head_and_last[1])

    return [name.strip() for name in names if name.strip()]

In [4]:
# Load the sentence-embedding model, preferring the copy saved under
# local-models/. If no saved copy exists, download it by name and save it
# there so the next run can load it locally.
#
# An explicit existence check replaces the previous bare `except:`. That
# handler also swallowed genuine load errors (and KeyboardInterrupt) and
# silently triggered a re-download.

model_name = 'sentence-transformers/all-MiniLM-L6-v2'
local_model_path = Path('local-models') / model_name

if local_model_path.is_dir():
    model = SentenceTransformer(str(local_model_path))
    print('Loaded local model')
else:
    model = SentenceTransformer(model_name)
    model.save(str(local_model_path))
    print('Downloaded and saved model')



Downloaded and saved model


In [5]:
# Load the pickled Paper objects and embed each one as a single
# "Title: ... \n Abstract: ..." string.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load a papers.pkl produced by a trusted step of this project.

with open('papers.pkl', 'rb') as f:
    papers: list[Paper] = pickle.load(f)

# One text per paper, in the same order as `papers`.
papers_text = [
    f'Title: {p.title} \n Abstract: {p.abstract}'
    for p in papers
]
papers_emb = model.encode(papers_text)

## Categories

In [6]:
# Zero-shot categorisation: embed each bare category name and assign every
# paper the category with the highest score under the model's similarity
# function.

categories = [
    'Tables',
    'Classification',
    'Key Information Extraction',
    'Optical Character Recognition',
    'Datasets',
    'Document Layout Understanding',
    'Others',
]

categories_emb = model.encode(categories)
similarities = model.similarity(papers_emb, categories_emb)

# Row-wise argmax -> index of the best-matching category for each paper.
predictions = [categories[idx] for idx in similarities.argmax(dim=1)]

In [7]:
# Save the plain-category predictions as a ';'-separated CSV for further analysis.

rows = [
    {
        'filename': paper.filename,
        'title': paper.title,
        'authors': paper.authors,
        'category': prediction,
    }
    for paper, prediction in zip(papers, predictions)
]

predictions_df = pd.DataFrame(rows)
predictions_df.to_csv('similarity-preds.csv', sep=';', index=False)

In [8]:
# Group papers by predicted category and write them as JSON. Keys are the
# camelCase category names; every category is present, even if it is empty.

result = {to_camel_case(category): [] for category in categories}

for paper, prediction in zip(papers, predictions):
    bucket = result[to_camel_case(prediction)]
    bucket.append({
        "originalFileName": paper.filename,
        "title": paper.title,
        "authors": authors_to_list(paper.authors),
    })

with open('similarity-preds.json', 'w') as f:
    json.dump(result, f, indent=4)

## Extended Categories

In [9]:
# Extended categories: embed a one-sentence description of each category
# instead of only its bare name. The descriptions MUST stay in the same order
# as `categories`, because the argmax index below is mapped back to the short
# category name through `categories`.

extended_categories = [
    'Tables are structured representations of data organized in rows and columns, often used to present numerical information, comparisons, and relationships clearly and efficiently.', 
    'Classification is the task of assigning predefined categories to text documents based on their content, enabling systematic organization and retrieval of information.', 
    'Key Information Extraction is the automatic identification and extraction of significant entities and relevant data from unstructured texts, facilitating efficient access to critical information and enhancing data organization.',
    'Optical Character Recognition is the technology used to convert different types of documents, such as scanned paper documents and images, into editable and searchable data by recognizing and extracting printed or handwritten text.', 
    'Datasets are collections of structured or unstructured data organized for analysis and research purposes, often used in machine learning and statistical modeling to train and evaluate algorithms.', 
    'Document Layout Understanding is the process of analyzing and interpreting the structural layout of documents to extract meaningful information about the arrangement and organization of content, including text, images, tables, and other elements.', 
    'Others are any additional tasks or methodologies related to document processing and information extraction that do not fit into the predefined categories, encompassing a variety of techniques and applications.'
]
# Fail loudly if the two lists drift apart; otherwise predictions would be
# silently mislabelled (or raise IndexError).
assert len(extended_categories) == len(categories), 'each category needs exactly one description'

extended_categories_emb = model.encode(extended_categories)

similarities_extended = model.similarity(papers_emb, extended_categories_emb)

predictions_extended = [categories[idx] for idx in similarities_extended.argmax(dim=1)]

In [10]:
# Save the extended-category predictions as a ';'-separated CSV for further analysis.

rows = [
    {
        'filename': paper.filename,
        'title': paper.title,
        'authors': paper.authors,
        'category': prediction,
    }
    for paper, prediction in zip(papers, predictions_extended)
]

predictions_extended_df = pd.DataFrame(rows)

# Write the extended frame. The previous code wrote `predictions_df` (the
# plain-category results) to this file by mistake.
predictions_extended_df.to_csv('similarity-preds-ext.csv', sep=';', index=False)

In [11]:
# Group papers by their extended-category prediction and write them as JSON,
# using the same camelCase keys as the plain-category output.

result = {to_camel_case(category): [] for category in categories}

for paper, prediction in zip(papers, predictions_extended):
    category_key = to_camel_case(prediction)
    result[category_key].append({
        "originalFileName": paper.filename,
        "title": paper.title,
        "authors": authors_to_list(paper.authors),
    })

with open('similarity-preds-ext.json', 'w') as f:
    json.dump(result, f, indent=4)