# **Tutorial** - (semi)-supervised topic modeling
(last updated 26-04-2021)

In this tutorial, we will be looking at a new feature of BERTopic, namely (semi)-supervised topic modeling! This allows us to steer the dimensionality reduction of the embeddings into a space that closely follows any labels you might already have. 

## Semi-supervised modeling
(semi)-supervised topic modeling is a class of methods that allows the user to perform topic modeling with previously defined labels. This might help nudge the model towards specific topics or classes for which you have labels. 

<br>

<img src="https://raw.githubusercontent.com/MaartenGr/BERTopic/master/images/logo.png" width="40%">

# Enabling the GPU

First, you'll need to enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- select GPU from the Hardware Accelerator drop-down

[Reference](https://colab.research.google.com/notebooks/gpu.ipynb)

# Installing BERTopic

We start by installing BERTopic from PyPi:

In [2]:
%%capture
!pip install bertopic

## Restart the Notebook
After installing BERTopic, some packages that were already loaded were updated and in order to correctly use them, we should now restart the notebook.

From the Menu:

Runtime → Restart Runtime

# **Data**
For this example, we use the popular 20 Newsgroups dataset which contains roughly 18000 newsgroups posts that each is assigned to one of 20 topics:

In [1]:
import pandas as pd
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

data = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))
docs = data["data"]
targets = data["target"]
target_names = data["target_names"]
classes = [data["target_names"][i] for i in data["target"]]

In [201]:
import numpy as np
import io
from google.colab import drive
drive.mount('/content/drive')
unlabeled_data = pd.read_csv('/content/drive/My Drive/BERTopic_Unlabeled.csv')
labeled_data = pd.read_csv('/content/drive/My Drive/BERTopic_Labeled.csv')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



Columns (28) have mixed types.Specify dtype option on import or set low_memory=False.



In [202]:
def make_string(text):
    text_preprocessed = (" ").join(text)
    return text_preprocessed

In [203]:
from ast import literal_eval
unlabeled_data['tokens'] = unlabeled_data['tokens'].apply(lambda row: literal_eval(row))
unlabeled_data['tokens'] = unlabeled_data['tokens'].apply(lambda x: make_string(x))
labeled_data['tokens'] = labeled_data['tokens'].apply(lambda row: literal_eval(row))
labeled_data['tokens'] = labeled_data['tokens'].apply(lambda x: make_string(x))

In [204]:
labeled_data = labeled_data.replace(np.nan, False)

In [205]:
labeled_data['true count'] = labeled_data[['university', 'relationships','break ups', 'divorce', 'weddings', 'death', 'family', 'friendship']].sum(axis=1)

In [206]:
labeled_data['true count'].value_counts()

1    409
0    393
2    149
3     38
4      9
5      2
Name: true count, dtype: int64

In [207]:
labeled_data['true list'] = labeled_data[['university', 'relationships','break ups', 'divorce', 'weddings', 'death', 'family', 'friendship']].apply(lambda row: row[row == True].index.tolist(), axis=1)

In [208]:
labeled_data['true list'] = labeled_data['true list'].apply(lambda x: [-1] if x == [] else x )

In [209]:
labeled_data

Unnamed: 0.3,Unnamed: 0,Unnamed: 0.2,Unnamed: 0.1,index,Unnamed: 0.1.1,bookId,title,series,author,rating,...,divorce,weddings,death,family,friendship,labeled?,Contains True?,tokens,true count,true list
0,0,0,0,0,39822,34838660-not-part-of-the-plan,Not Part of the Plan,Blue Moon #4,Lucy Score (Goodreads Author),4.46,...,False,False,False,False,False,Yes,1.0,wall street journal amazon bestselling author ...,1,[relationships]
1,1,1,1,1,34235,20176552-dragon-age-volume-1,"Dragon Age, Volume 1",Dragon Age Graphic Novels #1-3,"David Gaider, Chad Hardin (Illustrator), Antho...",4.26,...,False,False,False,False,False,Yes,0.0,helping set stage biowares hotly anticipated d...,0,[-1]
2,2,2,2,2,27904,124110.Dangerous_to_Know,Dangerous to Know,False,Barbara Taylor Bradford (Goodreads Author),3.73,...,True,False,True,True,False,Yes,1.0,sebastian locke fiftysixyearold patriarch powe...,4,"[relationships, divorce, death, family]"
3,3,3,3,3,10515,1046450.The_Wheel_of_Fortune,The Wheel of Fortune,False,Susan Howatch,4.11,...,False,False,False,True,False,Yes,1.0,take back oxmoon lost paradise childhood take ...,2,"[relationships, family]"
4,4,4,4,4,935,872333.Blue_Bloods,Blue Bloods,Blue Bloods #1,Melissa de la Cruz (Goodreads Author),3.69,...,False,False,False,False,False,Yes,0.0,mayflower set sail carried board men women wou...,0,[-1]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,995,995,995,995,17361,588326.The_Blue_Helmet,The Blue Helmet,False,William Bell,3.42,...,False,False,False,False,True,False,1.0,lee wants tarantula member biggest powerful ga...,1,[friendship]
996,996,996,996,996,9029,93007.The_Merry_Adventures_of_Robin_Hood,The Merry Adventures of Robin Hood,False,Howard Pyle,4.07,...,False,False,False,False,False,False,0.0,merry adventures robin hood great renown notti...,0,[-1]
997,997,997,997,997,32216,1085376.Before_You_Sleep,Before You Sleep,False,"Linn Ullmann, Tiina Nunnally (Translator)",3.34,...,False,False,False,True,False,False,1.0,moving presentday oslo brooklyn sleep tells st...,1,[family]
998,998,998,998,998,1036,28195.Inkspell,Inkspell,Inkworld #2,"Cornelia Funke (Goodreads Author), Anthea Bell...",3.91,...,False,False,False,False,False,False,0.0,captivating sequel inkheart critically acclaim...,0,[-1]


In [210]:
import random

In [211]:
labeled_data['target'] = labeled_data['true list'].apply(lambda x: random.choice(x))

In [214]:
test_set = labeled_data[700:899]
full_labeled = labeled_data
labeled_data.drop(labeled_data.index[700:900],0,inplace=True)

In [217]:
labeled_data['tokens'][900]



Each document can be put into one of the following categories:

In [218]:
target_names = ['university', 'relationships','break ups', 'divorce', 'weddings', 'death', 'family', 'friendship']

In [219]:
labeled_data['target_num'] = labeled_data['target'].apply(lambda x: target_names.index(x) if x != -1 else x)

In [220]:
data = labeled_data[['tokens', 'target', 'target_num']].reset_index()

In [221]:
data

Unnamed: 0,index,tokens,target,target_num
0,0,wall street journal amazon bestselling author ...,relationships,1
1,1,helping set stage biowares hotly anticipated d...,-1,-1
2,2,sebastian locke fiftysixyearold patriarch powe...,death,5
3,3,take back oxmoon lost paradise childhood take ...,relationships,1
4,4,mayflower set sail carried board men women wou...,-1,-1
...,...,...,...,...
795,995,lee wants tarantula member biggest powerful ga...,friendship,7
796,996,merry adventures robin hood great renown notti...,-1,-1
797,997,moving presentday oslo brooklyn sleep tells st...,family,6
798,998,captivating sequel inkheart critically acclaim...,-1,-1


In [222]:
classes = data['target']
docs = data['tokens']
targets = data['target_num']

In [223]:
data


Unnamed: 0,index,tokens,target,target_num
0,0,wall street journal amazon bestselling author ...,relationships,1
1,1,helping set stage biowares hotly anticipated d...,-1,-1
2,2,sebastian locke fiftysixyearold patriarch powe...,death,5
3,3,take back oxmoon lost paradise childhood take ...,relationships,1
4,4,mayflower set sail carried board men women wou...,-1,-1
...,...,...,...,...
795,995,lee wants tarantula member biggest powerful ga...,friendship,7
796,996,merry adventures robin hood great renown notti...,-1,-1
797,997,moving presentday oslo brooklyn sleep tells st...,family,6
798,998,captivating sequel inkheart critically acclaim...,-1,-1


# **(semi)-Supervised modeling**


## Basic Model
Before we start with semi-supervised modeling, let us first take a look at the output of the basic model.

The topics that were created mostly make sense. There are some clearly defined topics such as "nasa, orbit, spacecraft, moon" but also some topics that seem mostly derived from other topics. We can visualize this by extracting the topic representations per class and see if our unsupervised model closely resembles this. 

**NOTE**: You can **hover** over the bars to see the representation per class!!

The results do seem promising. Topics like "nasa, space, etc" seem to be clearly related to sci.space, but some topics were created that span many categories. For example, we expect the topic "bike, bikes, etc"  to only appear in rec.motorcycles.  

## Semi-supervised
In the example above you might notice that some topics were somewhat smushed together. What we would like to see is a clear separation between those topics. Fortunately, we have to labels and can use them to improve the model. 

Since we are not interested in any other topics, this method is called semi-supervised topic modeling. In practice, this means that we have the labels of some documents but not all. 

For this example let's say we only have the labels of all computer-related categories:

When generating our new labels it is important to mark unknown classes as **-1**. Next, we use those newly constructed labels to again run BERTopic:

In [224]:
topic_model = BERTopic(verbose=True, calculate_probabilities=True)
topics, probs = topic_model.fit_transform(docs, y=targets)

Batches:   0%|          | 0/25 [00:00<?, ?it/s]

2022-07-27 04:32:54,838 - BERTopic - Transformed documents to Embeddings
2022-07-27 04:33:00,951 - BERTopic - Reduced dimensionality
2022-07-27 04:33:01,022 - BERTopic - Clustered reduced embeddings


In [232]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name
0,-1,423,-1_life_new_one_world
1,0,75,0_novel_love_first_story
2,1,72,1_king_queen_one_must
3,2,46,2_friends_friend_best_love
4,3,43,3_killer_one_murder_detective
5,4,34,4_world_earth_city_planet
6,5,21,5_school_shes_nathan_girl
7,6,19,6_love_life_woman_sent
8,7,18,7_times_new_york_flavia
9,8,17,8_life_heart_marriage_woman


Finally, we can again extract the topics per class to see if our semi-supervised approach had some effect:

In [234]:
topics_per_class = topic_model.topics_per_class(docs, topics, classes=classes)
fig_semi_supervised = topic_model.visualize_topics_per_class(topics_per_class)
fig_semi_supervised

9it [00:01,  8.38it/s]


In [229]:
test_topic, test_probs = topic_model.transform(test_set['tokens'][700])

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2022-07-27 04:37:39,383 - BERTopic - Reduced dimensionality
2022-07-27 04:37:39,390 - BERTopic - Calculated probabilities with HDBSCAN
2022-07-27 04:37:39,396 - BERTopic - Predicted clusters


In [231]:
test_probs

array([[1.43701307e-20, 8.74104916e-15, 2.72220487e-19, 1.15641793e-14,
        9.64154277e-15, 1.04078157e-14, 9.07016529e-01, 1.50201422e-14,
        1.44848562e-20, 3.89370317e-19, 1.02133808e-14]])

We can clearly see that many more topics about computers were created and that the seperation between those topics are solid. This indicates that even if you do not have all the labels, you can definitely improve the model!

However, there are still some clusters that could be improved with the labels that we have. 

## Supervised

Finally, we are going to be using all labels. These labels help BERTopic understand where most clusters can be found. However, this does not mean that it will only find the 20 clusters that we have defined. If there are sub-clusters to be found, then there is a good chance BERTopic will find them! 

Not only do we see a nice seperation of the topics, there are significantly less outliers which shows that BERTopic has improved in connecting the documents to topics. 

Let's see the results by again visualizing the topic representation per class:

Now that we have used all labels, BERTopic seems to closely match our pre-defined labels. Moreover, it still allows to discover topics that were not previously defined. Thus, you can use this method to find unknown topics in pre-defined topics!

In [235]:
seed_topic_list = ['weddings', 'friendship', 'family', 'break ups', 'relationships', 'death', 'divorce', 'university']
topic_model = BERTopic(seed_topic_list=seed_topic_list,calculate_probabilities=True, verbose=True)
topics, probs = topic_model.fit_transform(docs, y=targets)


Batches:   0%|          | 0/25 [00:00<?, ?it/s]

2022-07-27 04:45:14,554 - BERTopic - Transformed documents to Embeddings


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2022-07-27 04:45:21,768 - BERTopic - Reduced dimensionality
2022-07-27 04:45:21,836 - BERTopic - Clustered reduced embeddings


In [238]:
probs[0]

array([2.24355509e-308, 3.44345959e-308, 3.18122837e-308, 1.48034627e-308,
       1.00000000e+000, 2.23852284e-308, 2.71970757e-308, 2.02957495e-308,
       2.25570805e-308])

In [236]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name
0,-1,371,-1_one_new_life_world
1,0,111,0_love_life_shes_school
2,1,107,1_novel_life_love_story
3,2,61,2_killer_one_murder_body
4,3,51,3_queen_kingdom_one_prince
5,4,40,4_stoker_new_one_times
6,5,23,5_earth_world_planet_humankind
7,6,14,6_doctor_mack_rose_book
8,7,12,7_vampire_vampires_saba_blood
9,8,10,8_unbounded_expected_landon_never


In [237]:
topics_per_class = topic_model.topics_per_class(docs, topics, classes=classes)
fig_semi_supervised = topic_model.visualize_topics_per_class(topics_per_class, top_n_topics=10)
fig_semi_supervised

9it [00:00, 14.43it/s]


In [1]:
topic_model.get_topic_info(10)

NameError: ignored

In [50]:
similar_topics, similarity = topic_model.find_topics("university", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('london', 0.05192365395005623), ('city', 0.020787882842927326), ('londons', 0.015272069092610306), ('squid', 0.014952913069499684), ('jas', 0.014952913069499684), ('hannay', 0.012145605798262236), ('streets', 0.011840191293395055), ('richard', 0.01057144067236192), ('tube', 0.010031996644519975), ('malkanis', 0.009064856059962663)]
[('school', 0.013431133094552485), ('best', 0.009823454304689973), ('friends', 0.009464132922901597), ('summer', 0.00925477248264793), ('friend', 0.007521725997622359), ('shes', 0.007428361508460959), ('girl', 0.00669198165047273), ('girls', 0.006624445942547353), ('hes', 0.006591530574952619), ('year', 0.006245443569027486)]
[('kent', 0.0429726320735342), ('war', 0.02448913251228431), ('civil', 0.022706699046543346), ('america', 0.013737854573026915), ('south', 0.013472278520865457), ('kents', 0.013392252290198192), ('american', 0.01212188497955491), ('richmond', 0.012082065340881935), ('sid', 0.011619056906458657), ('philip', 0.010782754203034435)]
[('la

In [51]:
similar_topics, similarity = topic_model.find_topics("weddings", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('wedding', 0.03916593957420759), ('kady', 0.020050764215207333), ('courcy', 0.01981639926858164), ('lavon', 0.019255973011817765), ('laurel', 0.016080266033420866), ('weddings', 0.015425903691891344), ('saoirse', 0.015404236943605848), ('fake', 0.01468915082628304), ('dresses', 0.013383102735489205), ('paddy', 0.012373499051415122)]
[('lady', 0.012961724086009402), ('duke', 0.012770980444443641), ('london', 0.01041792680048587), ('earl', 0.010334892305214215), ('handsome', 0.009158555970069062), ('marry', 0.009114910700602493), ('marriage', 0.008750041093465399), ('lord', 0.007848802719804085), ('sebastian', 0.007683679399235111), ('man', 0.007219229312385458)]
[('leigh', 0.03543143098235519), ('laura', 0.02560559461904678), ('roe', 0.025499112589130388), ('beau', 0.023053917227500064), ('clayborne', 0.022632620821345602), ('retirement', 0.02203703943189772), ('husband', 0.02161702700636077), ('moriarty', 0.020418432812756977), ('veronica', 0.019738595664811454), ('tess', 0.019580747

In [52]:
similar_topics, similarity = topic_model.find_topics("break ups", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('life', 0.009562962604971383), ('sang', 0.009357778784174663), ('heart', 0.007853420914954843), ('im', 0.007665856079708666), ('love', 0.0073053206784513035), ('never', 0.00715828303561882), ('hes', 0.006350655925898168), ('away', 0.006010356821241176), ('everything', 0.005785891180156243), ('didnt', 0.005626585367011441)]
[('life', 0.003676018545439045), ('one', 0.0036720614476845933), ('world', 0.0034660000369633424), ('new', 0.003419656703564767), ('love', 0.0033461391623841523), ('family', 0.00292463995291762), ('story', 0.0028872719143259577), ('time', 0.00287022316312423), ('shes', 0.0028274234966838984), ('find', 0.0027607571398824892)]
[('shes', 0.00958059681358302), ('boyfriend', 0.008298512449777798), ('job', 0.008202797414343876), ('hollywood', 0.008131362487835914), ('movie', 0.0070544027894235235), ('perfect', 0.006878291545082357), ('alex', 0.006701325830191191), ('career', 0.006605880610103581), ('becky', 0.006084279908527788), ('lucy', 0.006078596701183006)]
[('aubrey

In [53]:
similar_topics, similarity = topic_model.find_topics("friendship", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('school', 0.013431133094552485), ('best', 0.009823454304689973), ('friends', 0.009464132922901597), ('summer', 0.00925477248264793), ('friend', 0.007521725997622359), ('shes', 0.007428361508460959), ('girl', 0.00669198165047273), ('girls', 0.006624445942547353), ('hes', 0.006591530574952619), ('year', 0.006245443569027486)]
[('life', 0.003676018545439045), ('one', 0.0036720614476845933), ('world', 0.0034660000369633424), ('new', 0.003419656703564767), ('love', 0.0033461391623841523), ('family', 0.00292463995291762), ('story', 0.0028872719143259577), ('time', 0.00287022316312423), ('shes', 0.0028274234966838984), ('find', 0.0027607571398824892)]
[('love', 0.006861141948113139), ('family', 0.00632674839651223), ('novel', 0.005236338713958875), ('life', 0.004602173525220112), ('story', 0.004353950772237878), ('mother', 0.004266249748517645), ('new', 0.00391445130865658), ('author', 0.003881338566236123), ('taylor', 0.003806573611090984), ('characters', 0.0037501324252181517)]
[('vegetar

In [54]:
similar_topics, similarity = topic_model.find_topics("divorce", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('lady', 0.012961724086009402), ('duke', 0.012770980444443641), ('london', 0.01041792680048587), ('earl', 0.010334892305214215), ('handsome', 0.009158555970069062), ('marry', 0.009114910700602493), ('marriage', 0.008750041093465399), ('lord', 0.007848802719804085), ('sebastian', 0.007683679399235111), ('man', 0.007219229312385458)]
[('leigh', 0.03543143098235519), ('laura', 0.02560559461904678), ('roe', 0.025499112589130388), ('beau', 0.023053917227500064), ('clayborne', 0.022632620821345602), ('retirement', 0.02203703943189772), ('husband', 0.02161702700636077), ('moriarty', 0.020418432812756977), ('veronica', 0.019738595664811454), ('tess', 0.019580747015524435)]
[('love', 0.006861141948113139), ('family', 0.00632674839651223), ('novel', 0.005236338713958875), ('life', 0.004602173525220112), ('story', 0.004353950772237878), ('mother', 0.004266249748517645), ('new', 0.00391445130865658), ('author', 0.003881338566236123), ('taylor', 0.003806573611090984), ('characters', 0.003750132425

In [55]:
similar_topics, similarity = topic_model.find_topics("relationships", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('aubrey', 0.013444589718433698), ('jenny', 0.01145067057633071), ('love', 0.011145459886762916), ('rosie', 0.010972596937458995), ('hopey', 0.010667048810013272), ('relationship', 0.01057607388291274), ('things', 0.010325908158456543), ('lincoln', 0.010305225036077018), ('alex', 0.010171013308748934), ('bertha', 0.010000709776472558)]
[('love', 0.006861141948113139), ('family', 0.00632674839651223), ('novel', 0.005236338713958875), ('life', 0.004602173525220112), ('story', 0.004353950772237878), ('mother', 0.004266249748517645), ('new', 0.00391445130865658), ('author', 0.003881338566236123), ('taylor', 0.003806573611090984), ('characters', 0.0037501324252181517)]
[('life', 0.003676018545439045), ('one', 0.0036720614476845933), ('world', 0.0034660000369633424), ('new', 0.003419656703564767), ('love', 0.0033461391623841523), ('family', 0.00292463995291762), ('story', 0.0028872719143259577), ('time', 0.00287022316312423), ('shes', 0.0028274234966838984), ('find', 0.0027607571398824892)]

In [56]:
similar_topics, similarity = topic_model.find_topics("death", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('zombie', 0.033486816651960336), ('virus', 0.025997025440850807), ('zombies', 0.019739186508940614), ('survivors', 0.015725412605839144), ('dead', 0.014763352191325197), ('plague', 0.011536406850210333), ('infected', 0.011472722316704112), ('walking', 0.009021840062440546), ('apocalypse', 0.008793566271296235), ('disease', 0.008658269074107511)]
[('killer', 0.014232142563428791), ('murder', 0.01137766122305586), ('detective', 0.011061228975776148), ('case', 0.010989466299622266), ('police', 0.008939553520021225), ('crime', 0.00798741800201927), ('investigation', 0.0073132879804698485), ('serial', 0.006808359834051321), ('murdered', 0.005687811497059156), ('victims', 0.005535739160214359)]
[('life', 0.009562962604971383), ('sang', 0.009357778784174663), ('heart', 0.007853420914954843), ('im', 0.007665856079708666), ('love', 0.0073053206784513035), ('never', 0.00715828303561882), ('hes', 0.006350655925898168), ('away', 0.006010356821241176), ('everything', 0.005785891180156243), ('didn

In [57]:
similar_topics, similarity = topic_model.find_topics("family", top_n=5); 
for i in similar_topics:
  print(topic_model.get_topic(i))

[('love', 0.006861141948113139), ('family', 0.00632674839651223), ('novel', 0.005236338713958875), ('life', 0.004602173525220112), ('story', 0.004353950772237878), ('mother', 0.004266249748517645), ('new', 0.00391445130865658), ('author', 0.003881338566236123), ('taylor', 0.003806573611090984), ('characters', 0.0037501324252181517)]
[('life', 0.003676018545439045), ('one', 0.0036720614476845933), ('world', 0.0034660000369633424), ('new', 0.003419656703564767), ('love', 0.0033461391623841523), ('family', 0.00292463995291762), ('story', 0.0028872719143259577), ('time', 0.00287022316312423), ('shes', 0.0028274234966838984), ('find', 0.0027607571398824892)]
[('lucy', 0.022611974004423364), ('father', 0.016254596988697984), ('holly', 0.012558803919365157), ('mother', 0.012084998778411434), ('mothers', 0.009925133982837196), ('alfie', 0.00840786068204534), ('daughter', 0.008394862654365925), ('amy', 0.007827823640422782), ('life', 0.0078250906410042), ('sally', 0.0074682581384666515)]
[('ami

In [24]:
topic_model.visualize_topics()

In [25]:
topic_model.visualize_distribution(probs[200], min_probability=0.015)

NameError: ignored

In [None]:
topic_model.visualize_hierarchy(top_n_topics=50)

In [None]:
topic_model.visualize_barchart(top_n_topics=5)

In [None]:
topic_model.visualize_heatmap(n_clusters=20, width=1000, height=1000)