<a href="https://colab.research.google.com/github/amandakonet/amicus-iv/blob/main/nlp/topic_modeling_bert_based_uncased.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Topic Modeling

Using BERTopic

In [1]:
model_checkpoint = 'bert-base-uncased'

## Set up environment

In [None]:
!pip install transformers==4.16.2
!pip install torch
!pip install datasets
!pip install bertopic

You'll need to enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select GPU from the Hardware Accelerator drop-down

In [48]:
import pandas as pd
import numpy as np
from html import unescape

from bertopic import BERTopic

from transformers import pipeline
from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_metric, Dataset

from huggingface_hub import notebook_login

## Data

BERTopic's `fit_transform` takes a list of documents (plain strings), so we need to build that list ourselves.

### Option 1: Tokenize text then decode back to original text

In [4]:
!git config --global credential.helper store
# get access token on Huggingface website > settings > access token (make sure it's a write token)
notebook_login()

Login successful
Your token has been saved to /root/.huggingface/token


Read in HF dataset

In [None]:
ds_path = 'repro-rights-amicus-briefs/repro-rights-amicus'
# use_auth_token must be True because this is a private dataset
ds = load_dataset(ds_path, use_auth_token=True)

# remove html characters
ds = ds.map(
    lambda x: {"text": [unescape(o) for o in x["text"]]}, batched=True
)
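As a quick illustration of what the `unescape` step does, the sketch below applies `html.unescape` to a made-up string (the sample text is an assumption, not taken from the dataset). HTML entities such as `&amp;` and `&quot;` are converted back to the characters they stand for.

```python
from html import unescape

# Made-up sample text containing HTML entities (not from the dataset)
raw_text = "&quot;undue burden&quot; &amp; viability"

# unescape converts HTML entities back to plain characters
clean_text = unescape(raw_text)
print(clean_text)
```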

Tokenize

In [6]:
#instantiate tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# split documents into overlapping chunks of at most 510 tokens (128-token overlap)
def tokenize_and_split(examples):
    result = tokenizer(
        examples["text"],
        truncation = True,
        max_length = 510,#512,
        stride = 128,
        return_overflowing_tokens = True,
        padding = 'max_length'
    )
    # Extract mapping between new and old indices
    sample_map = result.pop("overflow_to_sample_mapping")
    for key, values in examples.items():
        result[key] = [values[i] for i in sample_map]
    return result

# tokenize
tokenized_ds = ds.map(tokenize_and_split, batched = True, batch_size = 100)

# decode tokenized text back to original text 
def decode_chunks(example):
  result = tokenizer.batch_decode(
      example['input_ids'],
      skip_special_tokens=True,
      clean_up_tokenization_spaces=True
  )
  example['text_chunk'] = result
  return example

# decode
tokenized_ds = tokenized_ds.map(decode_chunks, batched=True, batch_size=100)


Put the document chunks into a list, since BERTopic takes a list of documents. The metadata lists are built in the same order so that they stay aligned with the chunks.

In [7]:
# new way using decoded tokenized text
sequences = tokenized_ds['train']['text_chunk'] + tokenized_ds['valid']['text_chunk'] + tokenized_ds['test']['text_chunk']
case = tokenized_ds['train']['case'] + tokenized_ds['valid']['case'] + tokenized_ds['test']['case']
brief_ids = tokenized_ds['train']['id'] + tokenized_ds['valid']['id'] + tokenized_ds['test']['id']
brief_names = tokenized_ds['train']['brief'] + tokenized_ds['valid']['brief'] + tokenized_ds['test']['brief']
brief_party = tokenized_ds['train']['brief_party'] + tokenized_ds['valid']['brief_party'] + tokenized_ds['test']['brief_party']

# check we have the results we expect
print(type(sequences))
print(type(sequences[0]))
print(len(sequences))

<class 'list'>
<class 'str'>
15911
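In the tokenizer call above, `stride=128` together with `return_overflowing_tokens=True` makes consecutive chunks of a long document share some tokens. The pure-Python sketch below illustrates only that windowing idea on a toy list. `sliding_windows` is a hypothetical helper, not part of `transformers`, and it ignores special tokens and padding, so it is an approximation of what the tokenizer does rather than a reimplementation.

```python
def sliding_windows(tokens, window, overlap):
    """Split `tokens` into windows of at most `window` items that share `overlap` items."""
    if overlap >= window:
        raise ValueError("overlap must be smaller than window")
    step = window - overlap
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        # stop once the current window reaches the end of the sequence
        if start + window >= len(tokens):
            break
        start += step
    return chunks

# Toy example: 10 "tokens", windows of 4 with an overlap of 2
toy_tokens = list(range(10))
toy_chunks = sliding_windows(toy_tokens, window=4, overlap=2)
print(toy_chunks)
```

Each window starts `window - overlap` positions after the previous one, so text near a chunk boundary appears in two chunks instead of being cut off without context.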


### Option 2: split text by words 

Define a function to split text into chunks of 512 words. Since we aren't using Hugging Face pipelines here, this is only a rough cut (512 words is not the same as 512 tokens), and we accept that it introduces some inefficiency into the process.

In [None]:
def split_text(text, n):
  # split text on space
  text = text.split()
  # join the words back into strings of up to n words each
  text = [' '.join(text[i:i+n]) for i in range(0,len(text),n)]

  return text

In [None]:
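n = 512
# `df` is assumed to be a DataFrame of briefs with a 'txt_short' column; it is not defined in this notebook
df_512 = df.copy()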
df_512['txt_split'] = df_512.apply(lambda row: split_text(row['txt_short'], n), axis=1)
df_512 = df_512.explode('txt_split')
df_512.drop('txt_short', axis=1, inplace=True)
df_512.rename({'txt_split': 'text'}, axis=1, inplace=True)
len(df_512)

11804

In [None]:
df_512.head(1)

| | case | brief | id | text |
|---|---|---|---|---|
| 0 | Anders v Floyd | Anders v Floyd - amicus brief for appellant (o... | 861815186515 | many roe v wade killings are murder the eviden... |


Make a list of documents -- do not shuffle! 

In [None]:
list_512 = list(df_512['text'])
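The reason not to shuffle is that the chunk list has to stay in the same order as the brief metadata. The sketch below shows the idea on toy data in plain Python. `toy_briefs` and the variable names are made up for illustration, and `split_text` is repeated so the example is self-contained.

```python
def split_text(text, n):
    # split on whitespace, then join back into strings of up to n words each
    words = text.split()
    return [" ".join(words[i:i + n]) for i in range(0, len(words), n)]

# Made-up briefs: id -> text
toy_briefs = {"brief_a": "one two three four five", "brief_b": "six seven"}

toy_chunks, toy_ids = [], []
for brief_id, text in toy_briefs.items():
    for chunk in split_text(text, 2):
        # append to both lists together so position i always refers to the same chunk
        toy_chunks.append(chunk)
        toy_ids.append(brief_id)

print(list(zip(toy_ids, toy_chunks)))
```

Because both lists are built in lockstep, the topic assigned to `toy_chunks[i]` can later be traced back to `toy_ids[i]`. Shuffling one list without the other would break that link.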

## Training

Instantiate BERTopic and set the language to English. Note that we aren't doing any fine-tuning here.

In [8]:
topic_model = BERTopic(language = 'english', calculate_probabilities=True, verbose=True)

"Train"

Note that `fit_transform` takes a list of documents and, optionally, precomputed document embeddings for those documents.

In [None]:
topics, probs = topic_model.fit_transform(sequences)

In [38]:
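# map each chunk to the topic with the highest probability
# NOTE: [1:] drops the first chunk, so this has len(sequences) - 1 entries, offset by one from `sequences`
new_topics = list(np.argmax(probs, axis=1)[1:])
len(new_topics)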

15910
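A toy check of the argmax step, assuming (as the cell above does) that row i of `probs` belongs to chunk i and column j to topic j. `toy_probs` is made up for illustration. The second print shows what slicing with `[1:]` does to the result.

```python
import numpy as np

# Made-up probabilities: 3 chunks (rows) x 3 topics (columns)
toy_probs = np.array([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])

# index of the most probable topic for each row
toy_best = np.argmax(toy_probs, axis=1)
print(toy_best.tolist())      # [1, 0, 2]

# slicing with [1:] drops the first row's assignment
print(toy_best[1:].tolist())  # [0, 2]
```

If the assignments need to line up with `sequences`, the unsliced result is the one with a matching length.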

## Extracting Topics

Topics by frequency. Note that topic -1 can be ignored: it collects the outlier chunks that were not assigned to any topic.

In [None]:
freq = topic_model.get_topic_info()
freq.head(20)

We can examine some of these topics more closely:

In [13]:
topic_model.get_topic(1)

[('undue', 0.025939446907108525),
 ('burden', 0.024608088537923983),
 ('casey', 0.015589724362916887),
 ('test', 0.015122420000045088),
 ('balancing', 0.011077302373755463),
 ('obstacle', 0.009965190373669715),
 ('regulation', 0.008960288375416188),
 ('standard', 0.008484842616708536),
 ('burdens', 0.008414628756562228),
 ('scrutiny', 0.00835415890086354)]

In [16]:
topic_model.get_topic(9)

[('psychological', 0.014608103268376624),
 ('study', 0.011408455676109073),
 ('mental', 0.01084958143738922),
 ('psychiatric', 0.009770017219268741),
 ('studies', 0.009441316404368514),
 ('women', 0.009414834417946938),
 ('abortion', 0.008374241686715404),
 ('suicide', 0.00775442284251304),
 ('pregnancy', 0.007588681215265598),
 ('effects', 0.00758590443470102)]

In [18]:
topic_model.get_topic(10)

[('pennsylvania', 0.015507681747797597),
 ('viable', 0.011728784074544144),
 ('viability', 0.011425569958214145),
 ('fetus', 0.00995848704489944),
 ('3210', 0.009955145502092262),
 ('physician', 0.009865931248267415),
 ('ul0', 0.009407664067993845),
 ('b0', 0.008979926896457564),
 ('section', 0.008638244065021861),
 ('colautti', 0.0069539405985212825)]

We can also search the generated topics for those most similar to an input search term. Note that this does not generate new topics; it only looks up closely related topics that have already been constructed.

In [19]:
similar_topics, similarity = topic_model.find_topics("physician", top_n=5); similar_topics

[78, 185, 109, 43, 154]
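Conceptually, a search like this compares a vector for the search term against one vector per topic and returns the closest topics. The sketch below is a simplified stand-in, not BERTopic's implementation: `top_n_similar` is a hypothetical helper and the two-dimensional vectors are made up, whereas a real model would obtain them from its embedding model.

```python
import numpy as np

def top_n_similar(query_vec, topic_vecs, top_n):
    """Return indices and cosine similarities of the top_n topic vectors closest to query_vec."""
    query = query_vec / np.linalg.norm(query_vec)
    topics = topic_vecs / np.linalg.norm(topic_vecs, axis=1, keepdims=True)
    sims = topics @ query               # cosine similarity of each topic to the query
    order = np.argsort(-sims)[:top_n]   # highest similarity first
    return order.tolist(), sims[order].tolist()

# Made-up topic vectors (one row per topic) and a made-up query vector
toy_topic_vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
toy_query = np.array([1.0, 0.1])

toy_ids, toy_sims = top_n_similar(toy_query, toy_topic_vecs, top_n=2)
print(toy_ids, toy_sims)
```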

In [20]:
topic_model.get_topic(78)

[('patient', 0.02885336486276866),
 ('informed', 0.02317468782756879),
 ('consent', 0.018714302006654557),
 ('disclosure', 0.017975871812715483),
 ('treatment', 0.015422649210622006),
 ('physician', 0.015238578145617701),
 ('information', 0.013503868893424989),
 ('patients', 0.012436537164105825),
 ('medical', 0.011547385778118067),
 ('2d', 0.00987861129709309)]

In [23]:
topic_model.get_topic(185)

[('privileges', 0.04162694954377481),
 ('hospital', 0.03326567794055808),
 ('hospitals', 0.024645481214837277),
 ('admitting', 0.021964369363969855),
 ('credentialing', 0.01741717007979924),
 ('staff', 0.016819683528805324),
 ('physicians', 0.016040354680456274),
 ('doe', 0.015951893259549214),
 ('roa', 0.014462487500781575),
 ('outpatient', 0.014367540629381035)]

In [27]:
similar_topics, similarity = topic_model.find_topics("medical", top_n=5)
print(similar_topics)
topic_model.get_topic(110)

[78, 109, 185, 75, 110]


[('illinois', 0.027474389882757345),
 ('funding', 0.02024761176842546),
 ('medically', 0.01746654436707044),
 ('necessary', 0.01504166437359348),
 ('medicaid', 0.014423541577298039),
 ('idph', 0.011870809607115358),
 ('indigent', 0.0090025732294882),
 ('abortions', 0.008963844140291848),
 ('xix', 0.008190285261000926),
 ('medical', 0.007650652627652126)]

## Visualize topics in space

This first visualization collapses our topics onto two dimensions so we can visually examine which topics are similar to one another. Similar topics could then be merged to reduce the number of topics. Note that this is an interactive visual.

In [None]:
topic_model.visualize_topics()

## Topic hierarchy

Another way to visually examine how topics are related to one another. Just from looking at this, I think it would make more sense to topic model pro-women and pro-opp briefs separately, since they often use similar language and topics but articulate very different points about them!

In [None]:
topic_model.visualize_hierarchy(top_n_topics=25)

## Topic Similarity

Having generated topic embeddings, through both c-TF-IDF and document embeddings, we can create a similarity matrix by computing the cosine similarity between each pair of topic embeddings. The resulting matrix indicates how similar the topics are to one another.

In [None]:
topic_model.visualize_heatmap(n_clusters=10, width=1000, height=1000)
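For intuition, a cosine similarity matrix can be computed by normalizing each topic vector to unit length and taking all pairwise dot products. The sketch below uses made-up two-dimensional vectors in place of real topic embeddings, and `cosine_similarity_matrix` is a hypothetical helper written for this illustration.

```python
import numpy as np

def cosine_similarity_matrix(vectors):
    """Pairwise cosine similarity between the rows of `vectors`."""
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    return unit @ unit.T

# Made-up "topic embeddings": one row per topic
toy_topic_embeddings = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])

toy_sim = cosine_similarity_matrix(toy_topic_embeddings)
print(np.round(toy_sim, 3))
```

The matrix is symmetric with ones on the diagonal, and entry (i, j) is closer to 1 the more closely topics i and j point in the same direction.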

## Reduce n topics

How many topics to keep is a manual decision.

In [None]:
#new_topics, new_probs = topic_model.reduce_topics(list_512, topics, probs, nr_topics=60)

# Part 2: Seed topics

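Here we pass a `seed_topic_list` to BERTopic to guide it toward topics we care about. See the API reference: https://maartengr.github.io/BERTopic/api/bertopic.html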

In [39]:
seed_topic_list = [['physician', 'doctor', 'medical professional', 'medical expert']]
seed_topic_model = BERTopic(language = 'english', calculate_probabilities=True, verbose=True,
                            seed_topic_list = seed_topic_list)
seed_topics, seed_probs = seed_topic_model.fit_transform(sequences)

2022-03-15 19:26:16,469 - BERTopic - Transformed documents to Embeddings
2022-03-15 19:26:39,760 - BERTopic - Reduced dimensionality with UMAP
2022-03-15 19:27:01,233 - BERTopic - Clustered UMAP embeddings with HDBSCAN


In [None]:
seed_freq = seed_topic_model.get_topic_info()
seed_freq.head(10)

# Part 3: Use fine-tuned transformer

Flair allows you to choose almost any 🤗 transformers model. Simply select one from the 🤗 model hub and pass it to BERTopic:

In [None]:
!pip install bertopic[flair]

In [58]:
from flair.embeddings import TransformerDocumentEmbeddings
#import flair.embeddings

So, we can use our fine-tuned model here! We use bert-base-uncased fine-tuned on our reproductive rights amicus briefs.

Note that the model has to be public in order to do this.

This cell takes about 9 minutes to run.

In [62]:
# init embeddings and model
bbu_ft_embed = TransformerDocumentEmbeddings('repro-rights-amicus-briefs/bert-base-uncased-finetuned-RRamicus')
bbu_ft_tm = BERTopic(embedding_model=bbu_ft_embed, language = 'english', calculate_probabilities=True, verbose=True)
bbu_ft_topics, bbu_ft_probs = bbu_ft_tm.fit_transform(sequences)

Some weights of the model checkpoint at repro-rights-amicus-briefs/bert-base-uncased-finetuned-RRamicus were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at repro-rights-amicus-briefs/bert-base-uncased-finetuned-RRamic

In [64]:
bbu_ft_freq = bbu_ft_tm.get_topic_info()
bbu_ft_freq.head(10)

| | Topic | Count | Name |
|---|---|---|---|
| 0 | -1 | 6333 | -1_the_of_to_in |
| 1 | 0 | 2060 | 0_speech_injunction_petitioners_clinic |
| 2 | 1 | 1914 | 1_university_of_at_in |
| 3 | 2 | 431 | 2_casey_undue_burden_health |
| 4 | 3 | 301 | 3_title_regulations_program_funds |
| 5 | 4 | 274 | 4_murder_supreme_unborn_constitution |
| 6 | 5 | 208 | 5_texas_women_clinics_http |
| 7 | 6 | 154 | 6_minor_parental_notification_physician |
| 8 | 7 | 144 | 7_my_me_was_had |
| 9 | 8 | 108 | 8_roe_constitutional_court_wade |


In [66]:
similar_topics, similarity = bbu_ft_tm.find_topics("physician", top_n=10)
print(similar_topics)
bbu_ft_tm.get_topic(85)

[85, 13, 28, 21, 53, 6, 125, 82, 128, 140]


[('health', 0.011792731364578753),
 ('patient', 0.01132847222756928),
 ('akron', 0.010619832111567775),
 ('consent', 0.010106135967386624),
 ('women', 0.01009158532917834),
 ('information', 0.00986812257325628),
 ('physician', 0.009786333138676714),
 ('informed', 0.009295692147359681),
 ('medical', 0.009249887099018759),
 ('woman', 0.007901742772512991)]

In [84]:
representative_docs = bbu_ft_tm.get_representative_docs(85)
representative_docs

["the missouri statute, by vesting a fetus with legal rights and by restricting abortion services, unquestionably undermines sound medical practice. presently, the ability of physicians and other reproductive health professionals to individualize their counsel and care entails balancing the many medical, genetic, and emotional factors involved in each woman's pregnancy. if the professional's legal obligations were expanded to mandate absolute protection of the fetus, the health care provider would be forced to * 15 trade off the life and health of the woman - - a demand inconsistent with the constitution, medical practice, and professional ethics. similarly, if the health professional is restricted, out of concern for the fetus, from freely providing information about abortion and from offering it as a medical option - - as the missouri statute also dictates - - the very core of reproductive health care will be eroded at great cost to the medical rights and lives of women. argument thi

In [85]:
bbu_ft_tm.visualize_topics()

We can break the topics down by class to see how often each one appears in one class versus the other (fem briefs and opp briefs).

In [75]:
topics_per_class = bbu_ft_tm.topics_per_class(sequences, bbu_ft_topics, brief_party)
topics_per_class.head(10)

| | Topic | Words | Frequency | Class |
|---|---|---|---|---|
| 0 | -1 | the, of, to, in, abortion | 3693 | 0 |
| 1 | 0 | speech, or, injunction, hobbs, the | 991 | 0 |
| 2 | 1 | ul0, b0, of, at, in | 1005 | 0 |
| 3 | 2 | casey, undue, burden, at, regulations | 240 | 0 |
| 4 | 3 | title, program, funds, regulations, 1008 | 139 | 0 |
| 5 | 4 | murder, supreme, unborn, constitution, person | 273 | 0 |
| 6 | 5 | texas, hb2, clinics, women, http | 48 | 0 |
| 7 | 6 | minor, parental, notification, parents, decision | 68 | 0 |
| 8 | 7 | my, me, was, baby, had | 47 | 0 |
| 9 | 8 | roe, constitutional, court, wade, this | 69 | 0 |


In [None]:
#fem_brief_bbu_topics = topics_per_class[topics_per_class['Class']==1].drop(['Class'],axis=1,inplace=False)
bbu_ft_tm.visualize_topics_per_class(topics_per_class, top_n_topics=5, normalize_frequency=True)
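The Frequency column in the per-class table above is, at its core, a count of chunks for each (class, topic) pair. The sketch below reproduces only that counting idea in plain Python on made-up labels. It does not compute the per-class topic words, and the variable names are hypothetical.

```python
from collections import Counter

# Made-up topic assignments and class labels, one entry per chunk
toy_topics = [0, 0, 1, 2, 1, 0]
toy_classes = ["A", "A", "A", "B", "B", "B"]

# count how many chunks fall into each (class, topic) pair
toy_counts = Counter(zip(toy_classes, toy_topics))

for (cls, topic), count in sorted(toy_counts.items()):
    print(cls, topic, count)
```

This only works if the topic list and the class list are aligned chunk by chunk, which is why `sequences`, `bbu_ft_topics`, and `brief_party` need to be in the same order.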