1. Load the topic models fitted in a previous notebook:
   * lda_gw: Gravitational Waves topics
2. Load the train dataset
3. Load the tokenized train dataset
4. Assign topics to all entries in the train dataset
5. Save the assigned topics to a CSV file
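A minimal sketch of step 5, assuming a dataframe with the same columns as the topic table built in this notebook (the toy rows and the file name `topics_gw_train.csv` are hypothetical). `DataFrame.to_csv` writes the list-valued `Topics` column as its string representation, so it has to be parsed back after reading, here with `ast.literal_eval`.

```python
import ast
import os
import tempfile

import pandas as pd

# Hypothetical toy rows with the same columns as the assigned-topics table
topics = pd.DataFrame({
    'id': ['2401.02604', 'astro-ph/0511394'],
    'Dominant Topic': [4, 0],
    '% Score': [0.670213, 0.874086],
    'Topics': [[(4, 0.670213), (1, 0.170135)], [(0, 0.874086), (3, 0.089193)]],
})

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, 'topics_gw_train.csv')  # hypothetical file name
    topics.to_csv(path, index=False)

    # Read the arXiv ids as strings; the Topics column comes back as text
    loaded = pd.read_csv(path, dtype={'id': str})

# Turn strings such as "[(4, 0.670213), (1, 0.170135)]" back into lists of tuples
loaded['Topics'] = loaded['Topics'].apply(ast.literal_eval)
print(loaded['Topics'][0])
```

Passing `dtype={'id': str}` matters when every id in the file is a new-style arXiv identifier such as `2401.02604`, which `read_csv` would otherwise parse as a float.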

In [2]:
import pickle
import pandas as pd


In [3]:
gw_train = pd.read_csv('../data/gw_train.csv.zip', index_col=0)

In [4]:
with open('../models/lda_gw.pickle', 'rb') as handle:
    lda_gw = pickle.load(handle)

In [5]:
with open('../data/corpus_train_gw.pickle', 'rb') as handle:
    corpus_train_gw = pickle.load(handle)
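Since `ldamodel[corpus]` is applied to it below, `corpus_train_gw` is assumed here to follow gensim's bag-of-words format: one list of `(token_id, count)` pairs per document. A stdlib-only sketch of that representation on made-up tokens, mimicking what gensim's `Dictionary.doc2bow` returns; the ids here are assigned in order of first appearance, which need not match the ids gensim would assign.

```python
from collections import Counter

# Made-up tokenized documents
docs = [['black', 'hole', 'merger', 'black', 'hole'],
        ['neutron', 'star', 'merger']]

# Assign an integer id to every distinct token, in order of first appearance
token2id = {}
for doc in docs:
    for token in doc:
        token2id.setdefault(token, len(token2id))

def to_bow(doc):
    """Return (token_id, count) pairs sorted by token id."""
    counts = Counter(token2id[token] for token in doc)
    return sorted(counts.items())

corpus = [to_bow(doc) for doc in docs]
print(corpus)  # [[(0, 2), (1, 2), (2, 1)], [(2, 1), (3, 1), (4, 1)]]
```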

## Assign topics to the data

For each document, collect the dominant topic, its probability and the full topic distribution in a dataframe (adapted from: https://campus.datacamp.com/courses/fraud-detection-in-python/fraud-detection-using-text?ex=11)

In [6]:
def get_topic_details(ldamodel, corpus):
    """Return the dominant topic, its probability and the full topic distribution of each document."""
    topic_details_list = []
    for row in ldamodel[corpus]:
        # Sort the (topic_id, probability) pairs by decreasing probability
        row = sorted(row, key=lambda x: x[1], reverse=True)
        topic_num, prop_topic = row[0]  # dominant topic
        topic_details_list.append([topic_num, prop_topic, row])
    return pd.DataFrame(topic_details_list, columns=['Dominant_Topic', '% Score', 'Topics'])
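Each element of `ldamodel[corpus]` is a list of `(topic_id, probability)` pairs for one document. By default gensim leaves out topics whose probability is below the model's `minimum_probability` threshold, which is why the `Topics` lists in the `topics_gw_train.head()` output have different lengths. Because only the first element of the sorted list is used, the dominant topic is simply the pair with the largest probability. A stdlib-only sketch on made-up distributions (not real model output):

```python
# Made-up per-document topic distributions in gensim's (topic_id, probability) format
doc_topics = [
    [(0, 0.156), (1, 0.170), (4, 0.670)],
    [(0, 0.874), (1, 0.033), (3, 0.089)],
]

def dominant_topic(row):
    """Return the (topic_id, probability) pair with the highest probability.

    On ties, max keeps the first pair, as the stable sort in get_topic_details does.
    """
    return max(row, key=lambda pair: pair[1])

dominants = [dominant_topic(row) for row in doc_topics]
print(dominants)  # [(4, 0.67), (0, 0.874)]
```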

In [7]:
def assign_topics(ldamodel, corpus, df):
    # Put the arXiv id, title and publication year and month of each entry in a dataframe
    # and combine them with the output of get_topic_details (rows are matched by position)
    topics_df = pd.DataFrame()
    topic_details = get_topic_details(ldamodel, corpus)
    topics_df['id'] = list(df['id'])
    topics_df['title'] = list(df['title'])
    topics_df['year'] = list(df['year'])
    topics_df['month'] = list(df['month'])
    topics_df['Dominant Topic'] = topic_details['Dominant_Topic']
    topics_df['% Score'] = topic_details['% Score']
    topics_df['Topics'] = topic_details['Topics']
    return topics_df
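The same table can be built without copying each column by hand. A minimal sketch, assuming the metadata frame holds the `id`, `title`, `year` and `month` columns and the topic-details frame has one row per document in the same order (the helper name `assign_topics_concat` and the toy values are hypothetical). `reset_index(drop=True)` gives both frames a 0..n-1 index, so `pd.concat(axis=1)` pairs rows by position, which is what the `list(...)` calls achieve in `assign_topics`.

```python
import pandas as pd

def assign_topics_concat(df, details):
    """Hypothetical helper: join metadata and topic details row by row."""
    meta = df[['id', 'title', 'year', 'month']].reset_index(drop=True)
    details = details.rename(columns={'Dominant_Topic': 'Dominant Topic'}).reset_index(drop=True)
    return pd.concat([meta, details], axis=1)

# Toy metadata with a non-default index, as after read_csv(..., index_col=0)
papers = pd.DataFrame(
    {'id': ['a', 'b'], 'title': ['T1', 'T2'], 'year': [2021, 2018], 'month': [12, 10]},
    index=[57, 3],
)
details = pd.DataFrame({'Dominant_Topic': [4, 0], '% Score': [0.83, 0.91],
                        'Topics': [[(4, 0.83)], [(0, 0.91)]]})
print(assign_topics_concat(papers, details))
```

Without the `reset_index` calls, `concat` would align on the original labels (57 and 3 versus 0 and 1) and fill the mismatched cells with `NaN`.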

In [8]:
topics_gw_train = assign_topics(lda_gw, corpus_train_gw, gw_train)

In [9]:
topics_gw_train.head()

|   | id | title | year | month | Dominant Topic | % Score | Topics |
|---|---|---|---|---|---|---|---|
| 0 | 2401.02604 | Precision constraints on the neutron star equa... | 2024 | 1 | 4 | 0.670213 | [(4, 0.670213), (1, 0.17013457), (0, 0.15641485)] |
| 1 | 2112.01979 | Multi-Messenger Constraints on Magnetic Fields... | 2021 | 12 | 4 | 0.825816 | [(4, 0.8258161), (0, 0.16997717)] |
| 2 | 1810.09764 | On the post-common-envelope central star of th... | 2018 | 10 | 4 | 0.910031 | [(4, 0.9100307), (0, 0.084191985)] |
| 3 | 2106.06209 | Analytical computation of quasi-normal modes o... | 2021 | 6 | 3 | 0.53075 | [(3, 0.53075004), (0, 0.23554358), (2, 0.23075... |
| 4 | astro-ph/0511394 | Non-linear axisymmetric pulsations of rotating... | 2005 | 11 | 0 | 0.874086 | [(0, 0.8740861), (3, 0.08919316), (1, 0.033166... |


Concatenate, in random order, the titles of all entries whose dominant topic is topic 4

In [10]:
topic_id = 4  # dominant topic whose titles are concatenated

idx = topics_gw_train['Dominant Topic'] == topic_id
topics_gw_train_0 = topics_gw_train[idx]  # entries of the selected topic
topics_gw_train_0 = topics_gw_train_0.reset_index()
topics_gw_train_0 = topics_gw_train_0.sample(frac=1)  # shuffle the rows
all_titles = '. '.join(topics_gw_train_0['title'])
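The cell above builds the text for a single topic. To get one concatenated string per dominant topic at once, a `groupby` over the `Dominant Topic` column can be used. A minimal sketch on a toy frame with made-up titles, leaving out the shuffling step:

```python
import pandas as pd

# Toy frame with the two columns the concatenation needs (titles are made up)
topics = pd.DataFrame({
    'Dominant Topic': [4, 0, 4, 3, 0],
    'title': ['NS equation of state', 'Rotating stars', 'Magnetar fields',
              'Quasi-normal modes', 'Stellar pulsations'],
})

# One '. '-joined string of titles per dominant topic, indexed by topic id;
# within each group the titles keep their row order
titles_by_topic = topics.groupby('Dominant Topic')['title'].agg('. '.join)
print(titles_by_topic[4])  # NS equation of state. Magnetar fields
```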

In [11]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
nltk.download('punkt')  # sentence tokenizer used to keep only the first generated sentence

# Alternative models that were tried:
#model_name = 'fabiochiu/t5-small-medium-title-generation'
#model_name = 'deep-learning-analytics/automatic-title-generation'
#model_name = 'sshleifer/distilbart-cnn-12-6'
model_name = 'google-t5/t5-small'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# T5 selects the task through a text prefix; "summarize:" is one of the prefixes it was trained with.
# Alternative prompts that were tried:
#prompt = [f"Select a suitable label for these keywords: {all_titles}"]
#prompt = [f"Which topic is described by these keywords (response should be between 1 and 12 words): {all_titles}"]
#prompt = all_titles
prompt = [f"summarize: {all_titles}"]

# t5-small was pre-trained on 512-token inputs; max_length=1024*10 lets much longer prompts through
inputs = tokenizer(prompt, max_length=1024*10, truncation=True, return_tensors="pt")
# Beam search combined with sampling: the generated title can change from run to run
output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=1, max_length=12)
decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
# Keep only the first sentence of the generated text
predicted_title = nltk.sent_tokenize(decoded_output.strip())[0]

print(predicted_title)


Gravitational waves from supermassive black holes
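`t5-small` was pre-trained on inputs of 512 tokens, while a prompt built from all the titles of one topic can be much longer; `max_length=1024*10` lets the tokenizer keep it anyway. One alternative, not what this notebook does, is to split the titles into chunks that fit a token budget and summarize each chunk separately. A stdlib-only sketch of the chunking step; `count_tokens` and `chunk_titles` are hypothetical helpers, and the whitespace word count only stands in for a real tokenizer.

```python
def count_tokens(text):
    """Crude stand-in for a tokenizer: counts whitespace-separated words."""
    return len(text.split())

def chunk_titles(titles, budget, sep='. '):
    """Greedily pack titles into sep-joined chunks of at most `budget` tokens.

    A single title longer than the budget still becomes its own (oversized) chunk.
    """
    chunks, current = [], []
    for title in titles:
        candidate = sep.join(current + [title])
        if current and count_tokens(candidate) > budget:
            chunks.append(sep.join(current))
            current = [title]
        else:
            current.append(title)
    if current:
        chunks.append(sep.join(current))
    return chunks

titles = ['Gravitational waves from binaries', 'Neutron star mergers',
          'Quasi-normal modes of black holes', 'Pulsar timing arrays']
for chunk in chunk_titles(titles, budget=8):
    print(chunk)
```

With the notebook's tokenizer, `count_tokens` could instead return `len(tokenizer(text)['input_ids'])`, and each chunk would then be prefixed with `summarize: ` before calling `model.generate`.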
