# Label and describe using an LLM

Download and instantiate an LLM from Huggingface.

Load the LDA topic models. 

Prompt the LLM to generate a label and a description for each topic in the models.

In [1]:
import string
import re
import gc

import pandas as pd
import pickle
from transformers import pipeline
import nltk
nltk.download('punkt')

pd.set_option('display.max_colwidth', None)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /home/kobv/atroncos/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Load the topic models fitted in a previous notebook.

* lda_gw: Gravitational Waves topics
* lda_cscl: Computation and Language topics

In [2]:
# LDA gravitational waves model
with open('../models/lda_gw.pickle', 'rb') as handle:
    lda_gw = pickle.load(handle)

# Ensemble LDA gravitational waves model
with open('../models/ensemble_gw.pickle', 'rb') as handle:
    ensemble_gw = pickle.load(handle)


In [3]:
# LDA computing & language model
with open('../models/lda_cscl.pickle', 'rb') as handle:
    lda_cscl = pickle.load(handle)

# Ensemble LDA computing & language model
with open('../models/ensemble_cscl.pickle', 'rb') as handle:
    ensemble_cscl = pickle.load(handle)

Get a list of all topics in the model, each topic described by MAX_WORDS 

* The result is a list of topics. Each topic is represented by a tuple.
* The first element of the tuple is a topic number (int).
* The second element of the tuple is a list of tuples,
* Each tuple represents the words characterising he topic (string) and its corresponding probability (float)

In [4]:
MAX_WORDS = 30

# The expected format for the topics list is:
# list[tuples<int, list[tuple<string, float>]>]

# a Gensim LDA model
topics_gw = lda_gw.show_topics(num_words=MAX_WORDS, formatted=False)

# an Ensemble lDA model, has to be converted to Gensim LDA first
topics_ensemble_gw = ensemble_gw.generate_gensim_representation().show_topics(num_topics=-1, num_words=MAX_WORDS, formatted=False)

In [5]:
# a Gensim LDA model
topics_cscl = lda_cscl.show_topics(num_words=MAX_WORDS, formatted=False)

# an Ensemble lDA model, has to be converted to Gensim LDA first
topics_ensemble_cscl = ensemble_cscl.generate_gensim_representation().show_topics(num_topics=-1, num_words=MAX_WORDS, formatted=False)

### Using Mistral gated model

1. Go to huggingface, login, go to `settings/access tokens` 
2. Create a new READ token, save it to ../token.txt
3. Go here: https://huggingface.co/mistralai/Mistral-7B-v0.1 and accept the usage conditions

In [6]:
from huggingface_hub import login
with open('../token.txt', 'r') as handle:
    token = handle.read()
login(token=token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/kobv/atroncos/.cache/huggingface/token
Login successful


In [7]:
def get_topic_str(topic, max=None):
    """Return the terms describing a topic as a string
    topic: list of tuples<string, float>
    """
    if not max:
        resp = ', '.join([term[0] for term in topic[1]])
    else:
        resp = ', '.join([term[0] for term in topic[1][:max]])    
    return(resp)

An wrapper class for the Mistral LLM.

In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import disk_offload
import torch
from abc import ABC, abstractmethod

class Promptable(ABC):
    @abstractmethod
    def one_shot(self, prompt): pass

    @abstractmethod
    def one_shot_label(self, prompt): pass

    @abstractmethod
    def one_shot_description(self, prompt): pass

class Mistral(Promptable):

    def __init__(self):
        self.model_id="mistralai/Mistral-7B-Instruct-v0.2"

    def one_shot(self, prompt, max_new_tokens):
        model_inputs = self.tokenizer([prompt], return_tensors='pt').to('cuda')
        generated_ids = self.model.generate(**model_inputs, pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=max_new_tokens, do_sample=False, num_beams=1)
        decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        resp = nltk.sent_tokenize(decoded_outputs.strip())[-1]  # Drop prompt
        return resp

    def one_shot_label(self, prompt):
        resp = self.one_shot(prompt, 10)
        # cleanup to get a clean label
        resp = re.sub(r'[^A-Za-z0-9 \-\.:]+', '', resp)  # remove non-alphanumeric chars except space, hyphen, dot, column
        resp = re.split('\.|:', resp)[0] # split into sentences, take first one
        return resp.title()

    def one_shot_description(self, prompt):
        resp = self.one_shot(prompt, 100)
        return resp
    
    def __enter__(self):
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.float16, device_map = 'auto')
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        return self

    def __exit__(self, *args):
        del self.model
        del self.tokenizer
        # release GPU memory
        gc.collect()  # explicitly call garbage collector
        torch.cuda.empty_cache()


A function to create labels for all topics

In [9]:
def predict_topic_labels(category, topics, model):
    """
    Predict label for a list of topics.

    category: an arxiv category, see https://arxiv.org/category_taxonomy, e.g. "General Relativity and Quantum Cosmology"
    topics: topics in an LDA model, obtained through lda_XX.show_topics(num_words=MAX_WORDS, formatted=False)
    model: Object implementing abstract class "Promptable" (see above)
    returns: dataFrame with columns: topic id, label
    """
    topic_labels = []  # topic label, string, generated by LLM
    topic_ids = []  # topic id, numeric
    topic_main_words = []  # the first 5 keywords, as string
    topic_descriptions = []  # description of a topic
    
    topics_range = [topic[0] for topic in topics]
    for count, topic in enumerate(topics):
        print(f"Processing topic {count} / {len(topics_range)}")
        topic_id = topic[0]
        terms = get_topic_str(topic) # all keywords, as string

        # label
        prompt = f"What concise and human-readable label best describes the topic in the \"{category}\" category characterised by these terms: {terms}? Output only the label."
        label = model.one_shot_label(prompt)
        topic_labels.append(label)

        # numeric topic id
        topic_ids.append(topic_id)

        # topic keywords
        topic_main_words.append(get_topic_str(topic, 5))

        # description
#        prompt = f"Describe the topic in the \"{category}\" category characterised by these terms: {terms}."
#        description = model.one_shot_description(prompt)
#        topic_descriptions.append(description)
        
    return(pd.DataFrame.from_dict({'Topic': topic_ids, 'First 5 keywords': topic_main_words, 'Label': topic_labels}))

### Topics for Gravitational Waves LDA

In [10]:
%%time

with Mistral() as model:
    topics_gw_df = predict_topic_labels("General Relativity and Quantum Cosmology", topics_gw, model)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.97it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


Processing topic 0 / 4
Processing topic 1 / 4
Processing topic 2 / 4
Processing topic 3 / 4
CPU times: user 2min 26s, sys: 5.42 s, total: 2min 31s
Wall time: 2min 22s


In [11]:
topics_gw_df

Unnamed: 0,Topic,First 5 keywords,Label
0,0,"detector, signal, data, noise, frequency",Gravitational Wave Detection
1,1,"binary, hole, mass, black, star",Binary Black Hole Merger
2,2,"model, spectrum, energy, dark, background",Cosmic Microwave Background
3,3,"field, theory, mode, gravity, equation",General Relativity And Quantum Gravity


### Topics for Gravitational Waves ensemble LDA

In [13]:
%%time

with Mistral() as model:
    topics_ensemble_gw = predict_topic_labels("General Relativity and Quantum Cosmology", topics_ensemble_gw, model)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.84it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


Processing topic 0 / 12
Processing topic 1 / 12
Processing topic 2 / 12
Processing topic 3 / 12
Processing topic 4 / 12
Processing topic 5 / 12
Processing topic 6 / 12
Processing topic 7 / 12
Processing topic 8 / 12
Processing topic 9 / 12
Processing topic 10 / 12
Processing topic 11 / 12
CPU times: user 7min 2s, sys: 5.92 s, total: 7min 8s
Wall time: 6min 57s


In [15]:
topics_ensemble_gw.sort_values(by='Topic')

Unnamed: 0,Topic,First 5 keywords,Label
0,0,"hole, black, binary, mass, spin",Black Hole Binaries
1,1,"search, signal, detector, data, ligo",Gravitational Wave Detection
2,2,"pulsar, timing, array, noise, data",Pulsar Timing Array
3,3,"star, neutron, merger, mass, binary",Neutron Star Mergers
4,4,"mode, star, instability, frequency, neutron",Rotating Neutron Star Inst
5,5,"ray, gamma, burst, energy, emission",Gamma-Ray Bursts Gr
6,6,"binary, source, distance, parameter, mass",Binary Black Hole Merger
7,7,"theory, gravity, general, scalar, field",General Theory Of Gravity
8,8,"model, dark, spectrum, inflation, matter",Dark Matter And Inflation Model
9,9,"transition, phase, model, order, electroweak",Transition Between Early Universe Phases


### Topics Computing & Language LDA

In [16]:
%%time

with Mistral() as model:
    topics_cscl_df = predict_topic_labels("Computation and Language", topics_cscl, model)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.91it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


Processing topic 0 / 6
Processing topic 1 / 6
Processing topic 2 / 6
Processing topic 3 / 6
Processing topic 4 / 6
Processing topic 5 / 6
CPU times: user 3min 32s, sys: 4.25 s, total: 3min 37s
Wall time: 3min 26s


In [17]:
topics_cscl_df

Unnamed: 0,Topic,First 5 keywords,Label
0,0,"data, research, user, analysis, text",Text Analysis Tools For Social Media
1,1,"task, data, training, performance, learning",Machine Learning For Text Data Processing
2,2,"question, llm, human, knowledge, task",Question-Answering Llm
3,3,"translation, speech, english, data, machine",Multilingual Machine Learning System For Spe
4,4,"word, based, method, sentence, representation",Word-Based Neural Network Methods For Sentence
5,5,"image, text, speech, visual, feature",Multimodal Information Processing


### Topics Computing & Language ensemble LDA

In [18]:
%%time

with Mistral() as model:
    topics_ensemble_cscl = predict_topic_labels("Computation and Language", topics_ensemble_cscl, model)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.01it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


Processing topic 0 / 15
Processing topic 1 / 15
Processing topic 2 / 15
Processing topic 3 / 15
Processing topic 4 / 15
Processing topic 5 / 15
Processing topic 6 / 15
Processing topic 7 / 15
Processing topic 8 / 15
Processing topic 9 / 15
Processing topic 10 / 15
Processing topic 11 / 15
Processing topic 12 / 15
Processing topic 13 / 15
Processing topic 14 / 15
CPU times: user 8min 43s, sys: 6.33 s, total: 8min 49s
Wall time: 8min 39s


In [19]:
topics_ensemble_cscl.sort_values(by='Topic')

Unnamed: 0,Topic,First 5 keywords,Label
0,0,"translation, machine, data, nmt, neural",Machine Learning For Multilingual Sent
1,1,"question, answer, answering, task, reasoning",Question Answering And Text Understanding
2,2,"llm, task, large, performance, prompt",Large-Scale Llm Performance
3,3,"speech, data, task, recognition, training",Speech Recognition System
4,4,"bias, gender, data, task, based",Gender Bias In Large Language
5,5,"dialogue, task, state, system, human",Dialogue System For Task-Oriented Con
6,6,"style, knowledge, task, transfer, text",Style And Knowledge Transfer In Text
7,7,"evaluation, human, metric, task, summarization",Automatic Text Summarization Evaluation Met
8,8,"topic, approach, document, method, word",Topic Modeling
9,9,"event, argument, extraction, task, method",Event-Based Argument Extraction And
