In [2]:
import numpy as np
import pandas as pd
import torch
from scipy.special import expit
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
from transformers import AutoTokenizer

    
# Hugging Face Hub id of the checkpoint used for topic tagging.
# (Plain string: there is nothing to interpolate, so no f-prefix is needed.)
MODEL = "cardiffnlp/tweet-topic-21-multi"
# from_pretrained fetches the tokenizer/model files on first use and then
# reuses the local cache, hence the download progress bars on a fresh machine.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
# Lookup from output index to topic label name, as shipped in the model config.
class_mapping = model.config.id2label

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.30k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.88k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

In [25]:
def get_most_likely_topics(comment:str, length=200, threshold=0.5)->str:
    """Return the topics predicted for a comment as one ';'-joined string.

    The comment is cut to its first ``length`` characters and scored with the
    module-level ``tokenizer`` / ``model`` pair. Each label is judged
    independently: every label whose sigmoid score reaches ``threshold`` is
    kept, so zero, one or several topics may come back. The result is a single
    string so it fits in one DataFrame column; the intended use is
    ``Series.apply`` on the comment text column.

    Args:
        comment: Raw comment text. Non-string values (for example the NaN
            pandas uses for missing cells, or None) are treated as "no text".
        length: Maximum number of leading characters that are classified.
        threshold: Minimum sigmoid score for a label to be reported.

    Returns:
        The selected label names joined with ';', in label-index order.
        An empty string when there is no text to classify or when no label
        reaches the threshold.
    """
    # Missing cells arrive as float NaN / None; slicing them would raise.
    if not isinstance(comment, str):
        return ""
    text = comment[:length]
    # Nothing to classify: do not let the model invent a topic for blank text.
    if not text.strip():
        return ""
    # The character cut does not bound the token count, so also let the
    # tokenizer truncate to the model's maximum sequence length.
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**tokens).logits[0]
    # Independent sigmoid per label (multi-label), not a softmax over labels.
    scores = expit(logits.numpy())
    return ";".join(
        class_mapping[i] for i, score in enumerate(scores) if score >= threshold
    )

# Quick sanity check of the tagger on a single hand-written sentence.
get_most_likely_topics("food is tasty at grandma's house")

'food_&_dining'

In [26]:
# NOTE: tagging one day's worth of comments takes roughly one minute.
df = pd.read_csv("../data_for_plotting/comment_data.csv")
# One ';'-joined topic string per comment, computed from the "text" column.
df["topics"] = df["text"].apply(get_most_likely_topics)

In [35]:
import plotly.express as px
import kaleido

# Turn the ';'-joined topic string into one row per (comment, topic) pair.
df['individual_topics'] = df['topics'].str.split(';')
df_expanded = df.explode('individual_topics')
# A comment with no predicted topic has topics == '', which splits to [''].
# Leave that empty placeholder out of the counts so the chart does not get a
# blank, unnamed slice (missing values are already dropped by value_counts).
has_topic = df_expanded['individual_topics'].fillna('').ne('')
topic_counts = df_expanded.loc[has_topic, 'individual_topics'].value_counts()
fig = px.pie(topic_counts, names=topic_counts.index, values=topic_counts.values, title='Topic Distribution')
fig.show()