# Label and describe using an LLM

Download and instantiate an LLM from Huggingface.

Load the LDA topic models. 

Prompt the LLM to generate a label and a description for each topic in the models.

In [32]:
import string
import re
import pandas as pd
import pickle
from transformers import pipeline
import nltk
nltk.download('punkt')


[nltk_data] Downloading package punkt to
[nltk_data]     /home/kobv/atroncos/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

Load the topic models fitted in a previous notebook.

* lda_gw: Gravitational Waves topics
* lda_cscl: Computation and Language topics

In [2]:
with open('../models/lda_gw.pickle', 'rb') as handle:
    lda_gw = pickle.load(handle)

Get a list of all topics in the model, each topic described by MAX_WORDS 

* The result is a list of topics. Each topic is represented by a tuple.
* The first element of the tuple is a topic number (int).
* The second element of the tuple is a list of tuples,
* Each tuple represents the words characterising he topic (string) and its corresponding probability (float)

In [3]:
MAX_WORDS = 30
# list[tuples<int, list[tuple<string, float>]>]
topics_gw = lda_gw.show_topics(num_words=MAX_WORDS, formatted=False)

### Using gated models

1. Go to huggingface, login, go to `settings/access tokens` 
2. Create a new READ token, save it to ../token.txt
3. Go here: https://huggingface.co/mistralai/Mistral-7B-v0.1 and accept the usage conditions

In [4]:
from huggingface_hub import login
with open('../token.txt', 'r') as handle:
    token = handle.read()
login(token=token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/kobv/atroncos/.cache/huggingface/token
Login successful


In [5]:
def get_topic_str(topic, max=None):
    """Return the terms describing a topic as a string
    topic: list of tuples<string, float>
    """
    if not max:
        resp = ', '.join([term[0] for term in topic[1]])
    else:
        resp = ', '.join([term[0] for term in topic[1][:max]])    
    return(resp)

In [36]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import disk_offload
import torch
from abc import ABC, abstractmethod

class Promptable(ABC):
    @abstractmethod
    def one_shot(self, prompt): pass

class Mistral(Promptable):
    def __init__(self):
        self.model_id="mistralai/Mistral-7B-Instruct-v0.2"
#        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.float16, low_cpu_mem_usage = False)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.float16, device_map = 'auto')
#        disk_offload(model=self.model, offload_dir="alpha")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
    
    def one_shot(self, prompt):
        model_inputs = self.tokenizer([prompt], return_tensors='pt').to('cuda')
#        model_inputs = model_inputs.to('meta')
#        generated_ids = self.model.generate(**model_inputs, pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=25, do_sample=True)
        generated_ids = self.model.generate(**model_inputs, pad_token_id=self.tokenizer.eos_token_id, max_new_tokens=10, do_sample=False, num_beams=1)
        decoded_outputs = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        resp = nltk.sent_tokenize(decoded_outputs.strip())[2]  # split into sentences
        resp = re.sub(r'[^A-Za-z0-9 ]+', '', resp)  # remove non-alphanumeric chars
        return(resp)

In [39]:
def predict_topic_labels(category, topics, model):
    """
    Predict label for a list of topics.

    category: an arxiv category, see https://arxiv.org/category_taxonomy, e.g. "General Relativity and Quantum Cosmology"
    topics: topics in an LDA model, obtained through lda_XX.show_topics(num_words=MAX_WORDS, formatted=False)
    model: Object implementing abstract class "Promptable" (see above)
    returns: dataFrame with columns: topic id, label
    """
    topics_range = [topic[0] for topic in topics]
    topic_labels = []
    topic_ids = []
    resp = []
    for topic_id in topics_range:
        print(f"Processing topic {topic_id} / {len(topics_range)}")
        terms = get_topic_str(topics[topic_id]) # all keywords, as string
        prompt = f"What concise and human-readable label best describes the topic in the \"{category}\" category characterized by these terms: {terms}? Output only the label."
        label = model.one_shot(prompt)
        topic_labels.append(label)  # topic label, string, generated by LLM
        topic_ids.append(topic_id)  # topic id, numeric
        topic_main_words = get_topic_str(topics[topic_id], 5) # the first 5 keywords, as string
    return(pd.DataFrame.from_dict({'topic': topic_ids, 'first 5 keywords': topic_main_words, 'label': topic_labels}))

In [42]:
%%time

mistral = Mistral()
topics_gw_df = predict_topic_labels("General Relativity and Quantum Cosmology", topics_gw, mistral)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.31it/s]
Some parameters are on the meta device device because they were offloaded to the cpu.


Processing topic 0 / 4
Processing topic 1 / 4
Processing topic 2 / 4
Processing topic 3 / 4
CPU times: user 2min 31s, sys: 5 s, total: 2min 36s
Wall time: 2min 28s


In [41]:
pd.set_option('display.max_colwidth', None)
topics_gw_df

Unnamed: 0,topic,first 5 keywords,label
0,0,"field, theory, mode, gravity, equation",Gravitational Wave Detection
1,1,"field, theory, mode, gravity, equation",Binary Black Hole Merger The
2,2,"field, theory, mode, gravity, equation",Dark Energy and Inflation Model
3,3,"field, theory, mode, gravity, equation",General Relativity and Quantum Gravity


In [43]:
pd.set_option('display.max_colwidth', None)
topics_gw_df

Unnamed: 0,topic,first 5 keywords,label
0,0,"field, theory, mode, gravity, equation",Gravitational Wave Detection
1,1,"field, theory, mode, gravity, equation",Binary Black Hole Merger The
2,2,"field, theory, mode, gravity, equation",Dark Energy and Inflation Model
3,3,"field, theory, mode, gravity, equation",General Relativity and Quantum Gravity
