In [1]:
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic
import pandas as pd
import numpy as np
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = "3" 


In [3]:
df = pd.read_csv ('../../datasets/Reuters-21578/bertopic_result/50topic_info_3000.csv')
df

Unnamed: 0,Topic,Count,Name,Representation,Llama2,Representative_Docs
0,-1,2160,-1_growth_mergers_finance_economic growth,"['growth', 'mergers', 'finance', 'economic gro...","['Business and Economic Growth', '', '', '', '...","['mergers, acquisitions, dividends', 'antitrus..."
1,0,518,0_oil_oil prices_energy_gas,"['oil', 'oil prices', 'energy', 'gas', 'prices...",['Oil prices and their impact on the energy in...,"['oil, prices', 'oil, prices, firms', 'oil pri..."
2,1,199,1_takeover_shares_ownership_securities,"['takeover', 'shares', 'ownership', 'securitie...","['Corporate takeovers and share ownership', ''...","['takeover, tender offer, shares', 'takeover, ..."
3,2,126,2_acquisitions debt_acquisitions_acquisitions ...,"['acquisitions debt', 'acquisitions', 'acquisi...","['Corporate finance and acquisitions', '', '',...","['acquisitions, debt, downgrade', 'finance, ac..."
4,3,124,3_dividend_dividend finance_dividends_finance ...,"['dividend', 'dividend finance', 'dividends', ...","['Dividend Finance', '', '', '', '', '', '', '...","['dividend, finance', 'dividend, finance', 'di..."
...,...,...,...,...,...,...
138,137,21,137_soybeans_cotton_agriculture soybeans_trade...,"['soybeans', 'cotton', 'agriculture soybeans',...","['Soybean trade and regulation', '', '', '', '...","['agriculture, soybeans, trade', 'imports, soy..."
139,138,21,138_budget deficit_deficit_budget_cuts,"['budget deficit', 'deficit', 'budget', 'cuts'...","['Fiscal policy and government spending', '', ...","['budget, deficit, taxes', 'budget, deficit, g..."
140,139,21,139_energy_energy prices_economy energy_prices...,"['energy', 'energy prices', 'economy energy', ...","['Energy market and prices', '', '', '', '', '...","['economy, energy, prices', 'economy, energy, ..."
141,140,21,140_loan_loans_losses_loan loss,"['loan', 'loans', 'losses', 'loan loss', 'prov...","['Financial performance and loan losses', '', ...","['earnings, loan loss provisions, financial re..."


In [4]:
rep_label = []
for i in range(len(df)):
    s = df['Representation'][i]
    rep_label.append(s[1:-1].split(',')[0][1:-1])

In [5]:
llama_label = []
for i in range(len(df)):
    s = df['Llama2'][i]
    llama_label.append(s[1:-1].split(',')[0][1:-1])

In [6]:
file1 = open('../../datasets/Reuters-21578/test_raw_texts.txt', 'r')
documents = file1.readlines()[:1000]  

In [7]:
from sentence_transformers import SentenceTransformer

# Pre-calculate embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embedding_model.encode(documents, show_progress_bar=True)

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

In [8]:
from umap import UMAP

umap_model = UMAP(n_neighbors=5, n_components=5, min_dist=0.0, metric='cosine', random_state=42)

In [9]:
from hdbscan import HDBSCAN

hdbscan_model = HDBSCAN(min_cluster_size=20, metric='euclidean', cluster_selection_method='eom', prediction_data=True)

In [10]:
from sklearn.feature_extraction.text import CountVectorizer
vectorizer_model = CountVectorizer(stop_words="english", min_df=2, ngram_range=(1, 2))

In [30]:
zeroshot_topic_list = llama_label

In [31]:
topic_model = BERTopic(

  # Sub-models
  embedding_model=embedding_model,
  umap_model=umap_model,
  hdbscan_model=hdbscan_model,
  vectorizer_model=vectorizer_model,
  zeroshot_topic_list=zeroshot_topic_list,
  zeroshot_min_similarity=.05,

  # Hyperparameters
  top_n_words=10,
  verbose=True,
  calculate_probabilities=True,
)

# Train model
topics, probs = topic_model.fit_transform(documents, embeddings)

2024-01-18 10:04:45,177 - BERTopic - Zeroshot Step 1 - Finding documents that could be assigned to either one of the zero-shot topics
2024-01-18 10:04:45,865 - BERTopic - Zeroshot Step 2 - Clustering documents that were not found in the zero-shot model...
2024-01-18 10:04:45,868 - BERTopic - Dimensionality - Fitting the dimensionality reduction algorithm
2024-01-18 10:04:45,868 - BERTopic - Dimensionality - Completed ✓
2024-01-18 10:04:45,870 - BERTopic - Cluster - Start clustering the reduced embeddings
2024-01-18 10:04:45,872 - BERTopic - Cluster - Completed ✓
2024-01-18 10:04:45,876 - BERTopic - Representation - Extracting topics from clusters using representation models.
2024-01-18 10:04:46,240 - BERTopic - Representation - Completed ✓
2024-01-18 10:04:46,557 - BERTopic - Zeroshot Step 2 - Completed ✓
2024-01-18 10:04:46,558 - BERTopic - Zeroshot Step 3 - Combining clustered topics with the zeroshot model


In [32]:
topic_model.get_topic_info()

Unnamed: 0,Topic,Count,Name,Representation,Representative_Docs
0,0,103,Earnings from Discontinued Operations,"[dlrs, 000 dlrs, 000, quarter, loss, lt, net, ...",[XEBEC &lt;XEBC> TO REPORT 2ND QTR LOSS Xebec...
1,1,55,Telecom earnings quarterly results,"[cts, vs, cts vs, shr, net, 000 vs, revs, 000,...",[ELECTRO RENT CORP &lt;ELRC> 3RD QTR FEB 28 NE...
2,2,45,Financial instruments and interest rates,"[days, 00 pct, rate, pct, net, billion, mln dl...","[J.P. MORGAN &lt;JPM> NET HURT BY BRAZIL, TRAD..."
3,3,39,Financial Performance and Losses,"[loss, vs loss, profit, net loss, vs, shr loss...",[ATLAS CONSOLIDATED MINING &lt;ACMB> 4TH QTR ...
4,4,38,Corporate takeovers and share ownership,"[shares, offer, common, stake, stock, dome, ua...",[CRAZY EDDIE &lt;CRZY> SETS DEFENSIVE RIGHTS ...
...,...,...,...,...,...
107,107,1,Corporate Leadership Changes,"[shr 37, 18 dlrs, 37 dlrs, general electric, v...",[GENERAL ELECTRIC CO 1ST QTR SHR 1.37 DLRS VS ...
108,108,1,Corporate earnings and dividends,"[tax gain, include tax, stock distribution, 77...",[DOW JONES AND CO INC &lt;DJ> 1ST QTR NET Shr...
109,109,1,Economic Sanctions and Trade Restrictions in S...,"[court press, fitzwater, press, sanctions, eff...",[WHITE HOUSE STANDING FIRM ON JAPANESE SANCTIO...
110,110,1,Government Subsidies in Agricultural Trade,"[retail, maize, price, imf, zambian, price mai...",[ZAMBIA DOES NOT PLAN RETAIL MAIZE PRICE HIKE ...


In [33]:
df = topic_model.get_document_info(documents)
df

Unnamed: 0,Document,Topic,Name,Representation,Representative_Docs,Top_n_words,Representative_document
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,57,International Trade and Exports,"[trade, steel, exports, imports, japan, steel ...","[CANADA PLANS TO MONITOR STEEL IMPORTS, EXPORT...",trade - steel - exports - imports - japan - st...,True
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,71,Agricultural Production and Disease,"[hectares, mln hectares, china, dry spell, dry...",[CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN ST...,hectares - mln hectares - china - dry spell - ...,True
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,89,Energy market and prices,"[energy, miti, demand, revise, natural, power,...",[JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNW...,energy - miti - demand - revise - natural - po...,True
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER Th...,5,Agricultural Exports,"[nil, nil nil, 87, 09, 09 87, tonnes, 1986 87,...",[WORLD SUPPLY/DEMAND ESTIMATES ISSUED BY USDA ...,nil - nil nil - 87 - 09 - 09 87 - tonnes - 198...,False
4,INDONESIA SEES CPO PRICE RISING SHARPLY Indon...,30,OPEC oil production and prices,"[prices, producer prices, crude, cts barrel, s...",[GERMAN PRODUCER PRICES FALL 0.1 PCT IN MARCH ...,prices - producer prices - crude - cts barrel ...,False
...,...,...,...,...,...,...,...
995,HONEYWELL INC &lt;HON> 1ST QTR OPER NET Oper ...,29,Aviation Financing and Technology,"[showboat, atlantic, fleet, hotel, quarter, ch...",[ATLANTIC FINANCIAL &lt;ATLF.O> TO ACQUIRE S A...,showboat - atlantic - fleet - hotel - quarter ...,False
996,WALL STREET STOCKS/BROWNING FERRIS &lt;BFI> T...,8,Corporate actions: stock splits and earnings,"[split, stock split, stock, mln vs, shares, vs...",[MAYFAIR SUPER MARKETS INC &lt;MYFRA> 2ND QTR ...,split - stock split - stock - mln vs - shares ...,False
997,HOMESTAKE &lt;HM> MULLS BUYING ORE RESERVES H...,37,Mining and Gold Reserves,"[gold, atlas, exploration, tons, 000 tons, min...",[GORDEX MINERALS LOCATES CANADA GOLD DEPOSITS ...,gold - atlas - exploration - tons - 000 tons -...,False
998,CHRONAR CORP &lt;CRNR.O> YEAR LOSS Shr loss 9...,0,Earnings from Discontinued Operations,"[dlrs, 000 dlrs, 000, quarter, loss, lt, net, ...",[XEBEC &lt;XEBC> TO REPORT 2ND QTR LOSS Xebec...,dlrs - 000 dlrs - 000 - quarter - loss - lt - ...,False


In [21]:
file1 = open('../../datasets/Reuters-21578/test_label.txt', 'r')
labels = file1.readlines()  

In [22]:
test_set = []  
for label in labels:
    lbs = label.strip().split(' ')
    test_set.append(lbs)
len(test_set)

3019

In [16]:
true_labels = []
for row in labels:
    lb_set = row.strip().split(' ')
    for lb in lb_set:
        if not lb in true_labels:
            true_labels.append(lb)

In [24]:
len(true_labels)

90

In [92]:
true_labels = []
for row in labels:
    true_labels.append(row.split('; ')[1].strip())
true_labels

['Information Retrieval',
 'Methodology',
 'Quantum Physics',
 'Information Theory',
 'Information Theory',
 'Applications',
 'Computer Vision and Pattern Recognition',
 'Computation and Language',
 'Artificial Intelligence',
 'Numerical Analysis',
 'Mathematical Software',
 'Cryptography and Security',
 'Software Engineering',
 'Machine Learning',
 'Networking and Internet Architecture',
 'Systems and Control',
 'Data Structures and Algorithms',
 'Computational Complexity',
 'Formal Languages and Automata Theory',
 'Robotics',
 'Optimization and Control',
 'Multiagent Systems',
 'Performance',
 'Social and Information Networks',
 'Physics and Society;',
 'Distributed, Parallel, and Cluster Computing',
 'Databases',
 'Combinatorics',
 'Machine Learning',
 'Probability',
 'Neural and Evolutionary Computing',
 'Discrete Mathematics',
 'Statistical Mechanics;',
 'Logic in Computer Science',
 'Computers and Society',
 'Disordered Systems and Neural Networks;',
 'Numerical Analysis',
 'Comp

In [144]:
file2 = open('../../datasets/Amazon-531/test_label.txt', 'r')
test_label_set = file2.readlines()
test_set = []  
for label in test_label_set:
    labels = label.strip().split(', ')
    test_set.append(labels)
len(test_set)

19658

In [145]:
test_set

[['beauty', 'hair_care', 'styling_tools,'],
 ['toys_games', 'dolls_accessories', 'dolls,'],
 ['baby_products', 'gifts,'],
 ['grocery_gourmet_food', 'pantry_staples', 'canned_jarred_food,'],
 ['beauty', 'skin_care', 'sun,'],
 ['grocery_gourmet_food', 'beverages', 'tea,'],
 ['beauty', 'skin_care', 'body,'],
 ['baby_products', 'gear', 'baby_gyms_playmats,'],
 ['health_personal_care', 'personal_care', 'oral_hygiene,'],
 ['toys_games', 'dress_up_pretend_play', 'pretend_play,'],
 ['health_personal_care', 'personal_care', 'lip_care_products,'],
 ['beauty', 'bath_body', 'bath,'],
 ['pet_supplies', 'dogs', 'feeding_watering_supplies,'],
 ['health_personal_care', 'nutrition_wellness', 'nutrition_bars_drinks,'],
 ['health_personal_care', 'health_care', 'massage_relaxation,'],
 ['health_personal_care', 'nutrition_wellness', 'vitamins_supplements,'],
 ['beauty', 'hair_care', 'hair_perms_texturizers,'],
 ['toys_games', 'dolls_accessories', 'doll_accessories,'],
 ['beauty', 'hair_care', 'styling_tool

In [140]:
file2 = open('../../datasets/Amazon-531/test/labels.txt', 'r')
test_label_set = file2.readlines()
test_label_set

['0\tgrocery_gourmet_food\n',
 '1\tmeat_poultry\n',
 '2\tjerky\n',
 '3\ttoys_games\n',
 '4\tgames\n',
 '5\tpuzzles\n',
 '6\tjigsaw_puzzles\n',
 '7\tboard_games\n',
 '8\tbeverages\n',
 '9\tjuices\n',
 '10\tbeauty\n',
 '11\tmakeup\n',
 '12\tnails\n',
 '13\tarts_crafts\n',
 '14\tdrawing_painting_supplies\n',
 '15\taction_toy_figures\n',
 '16\tfigures\n',
 '17\tdolls_accessories\n',
 '18\tdolls\n',
 '19\tcard_games\n',
 '20\tdrawing_sketching_tablets\n',
 '21\tbaby_toddler_toys\n',
 '22\tshape_sorters\n',
 '23\thealth_personal_care\n',
 '24\tpersonal_care\n',
 '25\tdeodorants_antiperspirants\n',
 '26\tnutrition_wellness\n',
 '27\tnutrition_bars_drinks\n',
 '28\tlearning_education\n',
 '29\thabitats\n',
 '30\telectronics_for_kids\n',
 '31\thousehold_supplies\n',
 '32\thousehold_batteries\n',
 '33\tpush_pull_toys\n',
 '34\tstuffed_animals_plush\n',
 '35\ttricycles\n',
 '36\tscooters_wagons\n',
 '37\tclay_dough\n',
 '38\thealth_care\n',
 '39\tallergy\n',
 '40\tbaby_products\n',
 '41\tgear\n',

In [141]:
true_labels = []
for label in test_label_set:
    true_labels.append(label.strip().split('\t')[1])
true_labels

['grocery_gourmet_food',
 'meat_poultry',
 'jerky',
 'toys_games',
 'games',
 'puzzles',
 'jigsaw_puzzles',
 'board_games',
 'beverages',
 'juices',
 'beauty',
 'makeup',
 'nails',
 'arts_crafts',
 'drawing_painting_supplies',
 'action_toy_figures',
 'figures',
 'dolls_accessories',
 'dolls',
 'card_games',
 'drawing_sketching_tablets',
 'baby_toddler_toys',
 'shape_sorters',
 'health_personal_care',
 'personal_care',
 'deodorants_antiperspirants',
 'nutrition_wellness',
 'nutrition_bars_drinks',
 'learning_education',
 'habitats',
 'electronics_for_kids',
 'household_supplies',
 'household_batteries',
 'push_pull_toys',
 'stuffed_animals_plush',
 'tricycles',
 'scooters_wagons',
 'clay_dough',
 'health_care',
 'allergy',
 'baby_products',
 'gear',
 'baby_gyms_playmats',
 'shaving_hair_removal',
 'skin_care',
 'face',
 'animals_figures',
 'feminine_care',
 'music_sound',
 'oral_hygiene',
 'grown_up_toys',
 'dress_up_pretend_play',
 'pretend_play',
 'novelty_gag_toys',
 'bath_body',
 'c

In [None]:
file3 = open('../../datasets/Amazon-531/test_label.txt', 'r')
test_label_set = file2.readlines()

In [25]:
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device = 0)


In [26]:
df

Unnamed: 0,Document,Topic,Name,Representation,Representative_Docs,Top_n_words,Representative_document
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,21,exports trade,"[ec, trade, steel, japan, coffee, exports, jap...",[ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN R...,ec - trade - steel - japan - coffee - exports ...,True
1,CHINA DAILY SAYS VERMIN EAT 7-12 PCT GRAIN STO...,18,commodities agriculture,"[tonnes, grain, wheat, storage, crop, 000 tonn...",[BRAZIL GRAIN HARVEST FACES STORAGE PROBLEMS ...,tonnes - grain - wheat - storage - crop - 000 ...,False
2,JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWA...,97,energy,"[energy, miti, demand, revise, natural, power,...",[JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNW...,energy - miti - demand - revise - natural - po...,True
3,THAI TRADE DEFICIT WIDENS IN FIRST QUARTER Th...,3,agriculture exports,"[nil, nil nil, 87, tonnes, 1986 87, 1987 88, s...",[SOYBEAN SUPPLY/DEMAND BY COUNTRY -- USDA The...,nil - nil nil - 87 - tonnes - 1986 87 - 1987 8...,False
4,INDONESIA SEES CPO PRICE RISING SHARPLY Indon...,4,commodity prices,"[prices, futures, price, traders, palm, rise, ...",[U.K. PRODUCER PRICES SEEN MOVED BY TECHNICALI...,prices - futures - price - traders - palm - ri...,False
...,...,...,...,...,...,...,...
995,HONEYWELL INC &lt;HON> 1ST QTR OPER NET Oper ...,20,earnings mergers,"[vs, net, shr, qtr, 000, dlrs vs, cts, 05 dlrs...",[WESTINGHOUSE ELECTRIC CORP &lt;WX> 1ST QTR NE...,vs - net - shr - qtr - 000 - dlrs vs - cts - 0...,False
996,WALL STREET STOCKS/BROWNING FERRIS &lt;BFI> T...,63,finance stocks,"[steel, fleet, sees, suit, recommend, stock, s...",[WALL STREET STOCKS/BROWNING FERRIS &lt;BFI> ...,steel - fleet - sees - suit - recommend - stoc...,True
997,HOMESTAKE &lt;HM> MULLS BUYING ORE RESERVES H...,32,mining gold,"[gold, mining, tons, 000 tons, ounces, explora...",[GORDEX MINERALS LOCATES CANADA GOLD DEPOSITS ...,gold - mining - tons - 000 tons - ounces - exp...,True
998,CHRONAR CORP &lt;CRNR.O> YEAR LOSS Shr loss 9...,15,loss revenue,"[loss, 000 dlrs, 000, dlrs, vs profit, quarter...",[GENERAL INSTRUMENT CORP &lt;GRL> 4TH QTR LOSS...,loss - 000 dlrs - 000 - dlrs - vs profit - qua...,False


In [35]:
correct = 0
for index in range(len(documents)):
    query_embedding = model.encode(df['Name'][index])
    passage_embedding = model.encode(test_set[index])
    print(df['Name'][index])
    sim_scores = util.dot_score(query_embedding, passage_embedding)[0].numpy()
    rank_list = np.argsort(sim_scores)[-1]
    print(test_set[index][rank_list])
    if sim_scores[rank_list] >= 0.60:
        correct +=1
correct

International Trade and Exports
trade
Agricultural Production and Disease
grain
Energy market and prices
nat-gas
Agricultural Exports
corn
OPEC oil production and prices
veg-oil
Maritime Accidents and Safety
ship
Coffee trade regulations
coffee
International trade and agricultural policies
wheat
Gold Mining Investment and Production
gold
Corporate Finance: Mergers & Acquisitions
acq
Trade Agreements and Negotiations
tin
Financial instruments and interest rates
interest
Mining and Gold Reserves
copper
Economic growth and industrial production
ipi
Global Agricultural Markets
livestock
Dividend Finance
earn
Corporate earnings and revenue
earn
Central banking and monetary policy
money-fx
Trade Surplus and Economic Balance
trade
Trading commodities: future
lead
Gold Mining Investment and Production
acq
Economic growth and industrial production
jobs
Gold Mining Investment and Production
earn
Financial Regulation
earn
Financial Performance and Losses
earn
Dividend Finance
earn
Trade Agreement

6

In [28]:
correct = 0
for lb in rep_label:
    query_embedding = model.encode(lb)
    print(lb)
    passage_embedding = model.encode(true_labels)
    sim_scores = util.dot_score(query_embedding, passage_embedding)[0].numpy()
    rank_list = np.argsort(sim_scores)[-1]
    print(true_labels[rank_list], sim_scores[rank_list])
    if sim_scores[rank_list] >= 0.60:
        correct +=1
correct

growth
money-supply 0.41968155
oil
soy-oil 0.72730654
takeover
lead 0.3945085
acquisitions debt
instal-debt 0.56405336
dividend
earn 0.45968884
policy currency
money-fx 0.5866241
labor
jobs 0.59846985
mergers_acquisitions
strategic-metal 0.25205106
policy rates
cpi 0.3969617
restructuring
housing 0.4261221
revenue quarterly
income 0.45236588
retail
retail 1.0000001
debt
instal-debt 0.63177705
economy exports
money-supply 0.37944
shipping
ship 0.7106267
manufacturing
jobs 0.48060197
bonds
interest 0.43741518
indicators
cpi 0.36446005
trade agriculture
livestock 0.53268754
airlines
jet 0.4906035
weather
heat 0.33420545
relations
interest 0.35102537
divestiture
jobs 0.34716272
split
interest 0.32740128
earnings acquisitions
earn 0.33341706
discontinued
retail 0.36432317
mining
fuel 0.4253512
revenue earnings
income 0.5739685
finance government
money-fx 0.49767777
accidents
jobs 0.41744506
economic growth
income 0.47945014
media
jobs 0.39661086
finance record
income 0.34613106
investment e

21

In [29]:
correct/len(llama_label)

0.14685314685314685