In [1]:
from sentence_transformers import SentenceTransformer, util
import pandas as pd

# Checkpoint name and input location, named so they are easy to find and change.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
EVENTS_CSV_PATH = "../data/filtered_events_country_code.csv"

# Sentence-embedding model; its `encode` method is used for topics and notes below.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Event table; its `notes` column holds the text that gets classified below.
df = pd.read_csv(EVENTS_CSV_PATH, sep=",")

# Candidate protest topics: short label -> description of what the topic covers.
# Insertion order matters because `topics` is built in this order.
TOPIC_DESCRIPTIONS = {
    "women rights": "Protests advocating for gender equality, reproductive rights, protection from gender-based violence, equal pay, and representation in politics and leadership roles.",
    "climate": "Demonstrations focused on climate change, demanding urgent action to reduce carbon emissions, transition to renewable energy, and hold governments or corporations accountable for environmental damage.",
    "labor rights": "Protests by workers or labor unions for fair wages, better working conditions, protection from exploitation, the right to organize, and opposition to layoffs or union-busting.",
    "health care": "Protests demanding access to affordable and quality health care, opposition to privatization, or calling attention to deficiencies in hospitals, medical services, or health insurance systems.",
    "farmers": "Protests by agricultural workers or rural communities over land rights, crop pricing, subsidies, agricultural reforms, or environmental impacts on farming.",
    "environment": "Protests against pollution, deforestation, mining, industrial projects, or government inaction that harms natural ecosystems or biodiversity.",
    "public services": "Demonstrations in support of better public transport, utilities, postal services, sanitation, or opposition to budget cuts or privatization of essential services.",
    "palestine-israel conflict": "Protests addressing the Israeli-Palestinian conflict, including calls for ceasefire, condemnation of military actions, support for Palestinian or Israeli civilians, or broader geopolitical stances.",
    "immigration": "Protests concerning immigration policies, refugee treatment, deportation, border enforcement, asylum rights, and xenophobia or anti-immigrant rhetoric.",
    "unjust law enforcement": "Demonstrations against police brutality, racial profiling, excessive use of force, lack of accountability, or systemic abuse by law enforcement.",
    "ukraine-russia war": "Protests related to the conflict in Ukraine, including opposition to Russian military aggression, support for Ukrainian sovereignty, calls for peace, or criticism of international responses.",
    "discrimination": "Protests against racial, religious, ethnic, or disability-based discrimination, including civil rights marches, anti-hate movements, and demands for equality and justice.",
    "education": "Demonstrations for accessible, equitable, and quality education; opposition to budget cuts, tuition hikes, or curriculum changes; or support for teachers and students.",
    "housing": "Protests calling for affordable housing, rent control, opposition to evictions or gentrification, or action on homelessness and housing insecurity.",
    "culture": "Protests about cultural preservation, opposition to censorship or erasure of indigenous or minority identities, or resistance to cultural appropriation or defunding of arts programs.",
    "policies": "Protests targeting government or institutional policies, including new laws, reforms, or administrative decisions perceived as harmful, unjust, or controversial.",
    "animal welfare": "Demonstrations advocating for animal rights, opposing factory farming, animal testing, abuse, or calling for stronger animal protection laws and ethical treatment.",
    "pandemic": "Protests related to COVID-19 or other public health crises, including lockdowns, vaccine mandates, health measures, or economic relief demands.",
    "lgbtq": "Protests for LGBTQ rights, including marriage equality, anti-discrimination protections, trans rights, and opposition to anti-LGBTQ legislation or rhetoric.",
}

# Each entry is the text that gets embedded: "<label>: <description>".
topics = [
    f"{label}: {description}"
    for label, description in TOPIC_DESCRIPTIONS.items()
]

# Embed every topic string once; `assign_topic` compares each text against these.
topic_embeddings = model.encode(topics, convert_to_tensor=True)


def assign_topic(text):
    """Return the entry of `topics` most similar to `text`.

    `text` is embedded with `model`, scored against `topic_embeddings` by
    cosine similarity, and the full topic string with the highest score is
    returned.
    """
    text_embedding = model.encode(text, convert_to_tensor=True)
    similarity = util.cos_sim(text_embedding, topic_embeddings)
    # argmax runs over the flattened score tensor; this assumes `text` is a
    # single string so the flat position is a valid index into `topics`.
    winner = int(similarity.argmax())
    return topics[winner]
 
# Embed all notes in a single batched `encode` call instead of invoking the
# model once per row through `.apply`; embeddings come back in input order.
note_texts = df['notes'].tolist()
if note_texts:
    note_embeddings = model.encode(note_texts, convert_to_tensor=True)
    # One row of cosine scores per note, one column per topic: take the
    # best-scoring column for each row.
    best_topic_indices = util.cos_sim(note_embeddings, topic_embeddings).argmax(dim=1).tolist()
    predicted_topics = [topics[idx] for idx in best_topic_indices]
else:
    # Nothing to encode; still create the column so later cells find it.
    predicted_topics = []
df['predicted_topic'] = predicted_topics

In [2]:
# The short label is the text before the first ':' of each predicted topic string.
df['topic'] = df['predicted_topic'].str.split(':', n=1).str[0]

# Export only the note text and its label. `df` itself is not rebound to the
# two-column view: doing so would drop `predicted_topic` and make this cell
# fail with a KeyError when it is run a second time.
df[['notes', 'topic']].to_csv("../data/notes_with_inferred_topics_ESS_strictly_slides.csv", index=False)

In [3]:
import plotly.graph_objects as go
from plotly.offline import iplot, init_notebook_mode
import plotly.express as px
import plotly.io as pio

# Initialise plotly's offline notebook mode before rendering the figure.
init_notebook_mode(True)

# Number of notes per predicted topic, computed once and reused for every bar attribute.
topic_counts = df["topic"].value_counts()

fig = px.bar(
    x=topic_counts.index,
    y=topic_counts,
    color=topic_counts.index,
    text=topic_counts,
)
fig.update_traces(hovertemplate="Category:'%{x}' Counted: %{y}")
fig.update_layout(
    title={"text": "Category Counts", "x": 0.5, "font": {"size": 35}},
    xaxis={"title": "Category", "showgrid": False},
    yaxis={"title": "Value", "showgrid": False},
    plot_bgcolor="white",
    width=800,
    height=500,
    showlegend=False,
)
iplot(fig)

In [5]:
from sklearn.metrics import classification_report, accuracy_score

# Evaluation set; the code below reads its `notes` and `class` columns.
df_true = pd.read_csv("../data/topics_evaluation.csv")

# The same note text can appear on several rows of `df`. Joining against all of
# them would repeat each labelled row once per duplicate and distort support
# and accuracy, so keep one prediction per distinct note text before merging.
predictions = df[['notes', 'topic']].drop_duplicates(subset='notes')
df_merged = pd.merge(df_true, predictions, on='notes', how='inner')

# Rows whose class is 'unknown' are excluded from scoring
# (presumably they have no usable ground-truth label -- confirm with the labelling notes).
df_eval = df_merged[df_merged['class'] != 'unknown']

y_true = df_eval['class']
y_pred = df_eval['topic']

print("Accuracy:", accuracy_score(y_true, y_pred))
print("\nClassification Report:\n", classification_report(y_true, y_pred))

Accuracy: 0.5871156661786238

Classification Report:
                            precision    recall  f1-score   support

           animal welfare       0.67      0.86      0.75         7
                  climate       0.73      0.83      0.77        29
                  culture       0.57      0.44      0.50         9
           discrimination       0.03      0.33      0.06         3
                education       0.86      0.36      0.51        53
              environment       0.48      0.58      0.53        24
                  farmers       0.86      0.92      0.89        53
              health care       0.17      0.15      0.16        27
                  housing       0.28      0.88      0.42         8
              immigration       0.09      0.50      0.15         2
             labor rights       0.96      0.39      0.55       205
                    lgbtq       0.77      0.83      0.80        12
palestine-israel conflict       0.97      0.90      0.94        40
       