Load the train dataset, and the train corpus

Load the topic models fitted in a previous notebook.
* lda_gw: Gravitational Waves topics

Load the tokenized train dataset

4. Assign topics to all entries in the test dataset
5. Save the assigned topics to a CSV file

In [1]:
import pickle
from transformers import pipeline
import pandas as pd
from transformers import pipeline
import random
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
nltk.download('punkt')


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/atroncos/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

Load the topic models fitted in a previous notebook.
* lda_gw: Gravitational Waves topics


In [2]:
with open('../models/lda_gw.pickle', 'rb') as handle:
    lda_gw = pickle.load(handle)

Load the train dataset, and the train corpus

In [3]:
gw_train = pd.read_csv('../data/gw_train.csv.zip', index_col=0)

In [4]:
with open('../data/corpus_train_gw.pickle', 'rb') as handle:
    corpus_train_gw = pickle.load(handle)

## Assign topics to the data

Aggregate topic information in a dataframe (see: https://campus.datacamp.com/courses/fraud-detection-in-python/fraud-detection-using-text?ex=11)

In [5]:
def get_topic_details(ldamodel, corpus):
    topic_details_list = []
    for i, row in enumerate(ldamodel[corpus]):
        row = sorted(row, key=lambda x: (x[1]), reverse=True)
        for j, (topic_num, prop_topic) in enumerate(row):
            if j == 0:  # dominant topic
                wp = ldamodel.show_topic(topic_num)
                topic_details_list.append([topic_num, prop_topic, row])
    topic_details_df = pd.DataFrame(topic_details_list)
    topic_details_df.columns = ['Dominant_Topic', '% Score', 'Topics']
    return topic_details_df

In [6]:
def assign_topics(ldamodel, corpus, df):
    # put the arxiv id, original categories data and creation date in a dataframe
    # combine with result of topic details function
    topics_df = pd.DataFrame()
    topic_details = get_topic_details(ldamodel, corpus)
    topics_df['id'] = list(df['id'])
    topics_df['title'] = list(df['title'])
    topics_df['year'] = list(df['year'])
    topics_df['month'] = list(df['month'])
    topics_df['Dominant Topic'] = topic_details['Dominant_Topic']
    topics_df['% Score'] = topic_details['% Score']
    topics_df['Topics'] = topic_details['Topics']
    return topics_df

In [7]:
topics_gw_train = assign_topics(lda_gw, corpus_train_gw, gw_train)

In [8]:
topics_gw_train.head()

Unnamed: 0,id,title,year,month,Dominant Topic,% Score,Topics
0,711.025,Tail effects in the third post-Newtonian gravi...,2007,11,1,0.394228,"[(1, 0.39422807), (3, 0.35740227), (0, 0.24538..."
1,2009.08103,The Advanced Virgo Photon Calibrators,2020,9,0,0.970369,"[(0, 0.9703692), (3, 0.022496346)]"
2,1905.08286,Listening to the sound of dark sector interact...,2019,5,2,0.721912,"[(2, 0.7219119), (0, 0.2207775), (3, 0.05395951)]"
3,2004.06503,Computationally efficient models for the domin...,2020,4,1,0.358168,"[(1, 0.35816848), (3, 0.34653628), (0, 0.29264..."
4,2212.05291,Probing Minimal Grand Unification through Grav...,2022,12,2,0.715022,"[(2, 0.71502197), (1, 0.10800291), (0, 0.10706..."


Concatenate all the titles by dominant topic

In [9]:
def shuffle_titles(dominant_topic):
    """
    Concatenates all the titles for papers about this topic,
    the dominant topic is used to filter the papers. Paper titles are shuffled.
    
    dominant_topic: int id of the dominant topic for a paper
    returns: string
    """
    idx = topics_gw_train['Dominant Topic'] == dominant_topic
    topics_gw_train_0 = topics_gw_train[idx]
    topics_gw_train_0 = topics_gw_train_0.reset_index()
    topics_gw_train_0 = topics_gw_train_0.sample(frac=1)  # shuffle
    all_titles = '. '.join(topics_gw_train_0['title'])
    return all_titles

In [10]:
def predict_label_ALT(all_titles, min_length=1, max_length=12):
    """
    Predict a label for a given concatenation of titles.
    """
    #model_name = 'fabiochiu/t5-small-medium-title-generation'
    #model_name = 'deep-learning-analytics/automatic-title-generation'
    model_name = 'google-t5/t5-small'
    #model_name = 'sshleifer/distilbart-cnn-12-6'
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    
    #inputs = [f"Select a suitable label for these keywords: {all_titles}"]
    #inputs = [f"Which topic is described by these keywords (response should be between 1 and 12 words): {all_titles}"]
    inputs = [f"summarize: {all_titles}"]
    #inputs = all_titles
    
    inputs = tokenizer(inputs, max_length=512, truncation=True, return_tensors="pt")
    output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=1, max_length=12)
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    predicted_label = nltk.sent_tokenize(decoded_output.strip())[0]
    return(predicted_label)


In [35]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

def predict_label(all_titles, min_length=1, max_length=12):
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

    input_text = [f"summarize: {all_titles}"]
    input_ids = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt").input_ids.to('cpu')

    outputs = model.generate(input_ids, num_beams=8, do_sample=True, min_length=min_length, max_length=max_length)
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    predicted_label = nltk.sent_tokenize(decoded_outputs.strip())[0]

    return(predicted_label)  

In [33]:
def predict_topic_labels(df, min_length, max_length):
    topics_range = set(df['Dominant Topic'])
    labels = []
    topics = []
    for topic in topics_range:
        print(f"Processing topic {topic} / {len(topics_range)}")
        all_titles = shuffle_titles(topic)
        label = predict_label(all_titles, min_length, max_length)
        topics.append(topic)
        labels.append(label)
        print(f"label: {label}")
    return(pd.DataFrame.from_dict({'topic': topics, 'label': labels}))

In [36]:
%%time


df1 = predict_topic_labels(topics_gw_train, 1, 5)
df2 = predict_topic_labels(topics_gw_train, 1, 5)

Processing topic 0 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: Detection
Processing topic 1 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: The Golden Era
Processing topic 2 / 4
label: Observational
Processing topic 3 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: Summary: Theory of
Processing topic 0 / 4
label: Search for gravitation
Processing topic 1 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: Summary: Gravitati
Processing topic 2 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: Summary: Gravitati
Processing topic 3 / 4


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


label: Review of gravitation
CPU times: user 26 s, sys: 2.02 s, total: 28 s
Wall time: 9.71 s


In [28]:
df1

Unnamed: 0,topic,label
0,0,Detecting a stochastic gravitational wave from...
1,1,GW190425 Using Astrophysical Argument
2,2,X-ray model
3,3,Observation of polarized gravitational waves i...


In [29]:
df2

Unnamed: 0,topic,label
0,0,Detecting Gravitational Waves from Binary Neut...
1,1,Observational
2,2,Gravitational Waves from a universe filled wit...
3,3,Quantity


In [39]:
pd.concat([df1, df2]).reset_index()

Unnamed: 0,index,topic,label
0,0,0,Detection
1,1,1,The Golden Era
2,2,2,Observational
3,3,3,Summary: Theory of
4,0,0,Search for gravitation
5,1,1,Summary: Gravitati
6,2,2,Summary: Gravitati
7,3,3,Review of gravitation
