In [100]:
pip install accelerate bertopic matplotlib jupyter-dash


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting jupyter-dash
  Downloading jupyter_dash-0.4.2-py3-none-any.whl.metadata (3.6 kB)
Collecting dash (from jupyter-dash)
  Downloading dash-2.16.1-py3-none-any.whl.metadata (10 kB)
Collecting flask (from jupyter-dash)
  Downloading flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting retrying (from jupyter-dash)
  Downloading retrying-1.3.4-py3-none-any.whl.metadata (6.9 kB)
Collecting ansi2html (from jupyter-dash)
  Downloading ansi2html-1.9.1-py3-none-any.whl.metadata (3.7 kB)
Collecting Werkzeug<3.1 (from dash->jupyter-dash)
  Downloading werkzeug-3.0.2-py3-none-any.whl.metadata (4.1 kB)
Collecting dash-html-components==2.0.0 (from dash->jupyter-dash)
  Downloading dash_html_components-2.0.0-py3-none-any.whl.metadata (3.8 kB)
Collecting dash-core-components==2.0.0 (from dash->jupyter-dash)
  Downloading dash_core_components-2.0.0-py3-none-any.whl.metadata (2.9 kB)
Collecting dash-table==5.0.0 (from dash->jupyter-dash)
  Downloading dash_table-5.0.0-py3-none-any.whl.metad

In [None]:
from huggingface_hub import notebook_login

In [None]:
# Run the Hugging Face Hub login helper imported in the previous cell;
# no token is written into the notebook source.
notebook_login()

In [22]:
from datasets import load_dataset, Dataset

In [2]:
import torch 
import json
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from bertopic import BERTopic
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm

In [21]:
# Load the "no-shorts-no-dups" configuration of the grants dataset.
embedded_ds = load_dataset(
    path="mwarchalowski/grants",
    name="no-shorts-no-dups",
)
embedded_ds  # display the available splits and their columns

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'abstr', 'text_len', 'tensors'],
        num_rows: 126778
    })
})

In [4]:
# Load the "labeled_subset" configuration of the same grants dataset.
labeled_ds = load_dataset(
    path="mwarchalowski/grants",
    name="labeled_subset",
)

In [5]:
# Alias the embedded dataset under the generic name that the
# train/test split cell reads (`dataset["train"]`).
dataset = embedded_ds

In [6]:
# Sentence-embedding model handed to BERTopic as its `embedding_model`.
# Note: the fit cell supplies precomputed embeddings to `fit_transform`.
embedding_model = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
)

In [7]:
# Dimensionality-reduction stage passed to BERTopic.
umap_model = UMAP(
    n_neighbors=10,
    n_components=8,
    min_dist=0.0,
    metric="cosine",
    random_state=42,  # fixed seed for the reduction step
)
umap_model  # show the configured estimator

In [8]:
# Clustering stage passed to BERTopic.
hdbscan_model = HDBSCAN(
    min_cluster_size=50,
    metric="euclidean",
    cluster_selection_method="eom",
    prediction_data=True,
)
hdbscan_model  # show the configured estimator

In [9]:
# Bag-of-words stage passed to BERTopic: English stop words removed,
# unigrams and bigrams, terms must appear in at least two documents.
vectorizer_model = CountVectorizer(
    stop_words="english",
    min_df=2,
    ngram_range=(1, 2),
)
vectorizer_model  # show the configured vectorizer

In [10]:
# Assemble the BERTopic pipeline from the components configured in the
# preceding cells.
topic_model = BERTopic(
    # Hyperparameters
    top_n_words=30,
    verbose=True,
    # Pipeline models
    embedding_model=embedding_model,
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    vectorizer_model=vectorizer_model,
)


In [11]:
# Hold out 10% of the rows. The split is seeded so that a fresh
# "Restart & Run All" selects the same train/test rows every time; the seed
# matches the `random_state=42` already used for UMAP in this notebook.
splits = dataset["train"].train_test_split(test_size=0.1, seed=42)
splits  # display the resulting train/test sizes

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'abstr', 'text_len', 'tensors'],
        num_rows: 114100
    })
    test: Dataset({
        features: ['id', 'title', 'abstr', 'text_len', 'tensors'],
        num_rows: 12678
    })
})

In [12]:
# Each entry of the "tensors" column is parsed with json.loads and wrapped in
# a NumPy array, giving one vector per training row.
embeddings = list(
    map(
        lambda encoded: np.array(json.loads(encoded)),
        splits["train"]["tensors"],
    )
)

In [13]:
# Combine the per-row vectors into a single NumPy array (the displayed output
# is 2-D: one row per training document). This rebinds `embeddings` from a
# list to an array; the array is what the fit cell passes to BERTopic.
embeddings = np.array(embeddings)
embeddings

array([[ 0.09480636,  0.25657061,  0.3466478 , ..., -0.35156459,
         0.29383859,  0.08172194],
       [-0.06521176, -0.63509202, -0.14487246, ...,  0.37998062,
         0.21496971,  0.09903586],
       [-0.44754729,  0.25623447,  0.42298391, ..., -0.44946998,
        -0.78393954, -1.05713916],
       ...,
       [-0.1533751 ,  0.61580908, -0.26142266, ...,  0.14666657,
         0.85336673, -0.24948071],
       [ 0.16619322,  0.02835516, -0.28797734, ..., -0.45414737,
        -0.03824041,  0.24011645],
       [ 0.02334986,  0.58924997, -0.89733708, ..., -0.20840904,
        -0.22607224,  0.21241634]])

In [14]:
# Fit the topic model on the training abstracts, reusing the precomputed
# embeddings rather than having BERTopic embed the documents itself.
topics, probs = topic_model.fit_transform(
    documents=splits["train"]["abstr"],
    embeddings=embeddings,
)

2024-04-09 06:13:23,664 - BERTopic - Dimensionality - Fitting the dimensionality reduction algorithm
2024-04-09 06:15:31,141 - BERTopic - Dimensionality - Completed ✓
2024-04-09 06:15:31,145 - BERTopic - Cluster - Start clustering the reduced embeddings
2024-04-09 06:15:39,248 - BERTopic - Cluster - Completed ✓
2024-04-09 06:15:39,277 - BERTopic - Representation - Extracting topics from clusters using representation models.
2024-04-09 06:17:09,934 - BERTopic - Representation - Completed ✓


In [15]:
# Do not truncate cell contents, so the full keyword lists stay readable.
pd.set_option("display.max_colwidth", None)
# First 25 rows of the topic table, restricted to size and keywords.
topic_model.get_topic_info().loc[:, ["Count", "Representation"]].head(25)


Unnamed: 0,Count,Representation
0,58528,"[research, project, new, cancer, cells, data, development, study, based, high, cell, using, use, used, systems, br, studies, develop, time, provide, patients, human, different, gt, breast, lt, lt br, br gt, design, specific]"
1,2569,"[physics, stars, universe, solar, galaxies, dark, matter, observations, particle, gravitational, stellar, star, space, galaxy, dark matter, lhc, planets, particles, neutrino, mass, theory, gravity, br, energy, standard model, detector, particle physics, magnetic, cosmic, data]"
2,1439,"[alloys, materials, material, mechanical, composite, manufacturing, alloy, strength, process, composites, temperature, properties, microstructure, crack, high, fatigue, grain, corrosion, steel, phase, steels, mechanical properties, deformation, metal, coating, thermal, ceramic, coatings, components, aluminum]"
3,1390,"[br, theory, lt br, br gt, lt, gt, geometry, algebraic, algebras, equations, spaces, geometric, manifolds, mathematics, algebra, groups, mathematical, differential, problems, conjecture, dimensional, operators, lie, invariants, finite, functions, number theory, varieties, group, space]"
4,1369,"[species, evolutionary, populations, evolution, genetic, variation, selection, traits, ecological, population, plant, reproductive, diversity, birds, speciation, biodiversity, sexual, males, conservation, ecology, br, natural, habitat, females, plants, mating, change, patterns, fitness, animals]"
5,1048,"[mantle, seismic, rocks, earth, crust, magma, earthquake, fault, volcanic, tectonic, plate, crustal, earthquakes, subduction, magmatic, deformation, deposits, rock, mineral, continental, minerals, geological, br, evolution, lt, fluid, gt, lt br, br gt, zone]"
6,964,"[plant, plants, arabidopsis, genes, crop, gene, wheat, auxin, proteins, seed, root, genetic, protein, resistance, molecular, stress, mutants, crops, growth, breeding, cell, barley, expression, pathogen, thaliana, rice, genome, chloroplast, cell wall, regulation]"
7,804,"[coal, water, waste, gas, process, sludge, wastewater, combustion, removal, treatment, emissions, plant, biomass, co2, gasification, energy, industrial, bed, carbon, furnace, technology, flotation, slag, dust, production, wastes, air, steel, recovery, industry]"
8,782,"[antibiotic, bacteria, antibiotics, bacterial, resistance, infections, infection, resistant, antimicrobial, host, antibiotic resistance, aureus, amr, tb, pathogens, virulence, strains, tuberculosis, coli, pathogen, mtb, vaccine, drug, gram, biofilm, antimicrobial resistance, aeruginosa, mrsa, proteins, phage]"
9,724,"[synthesis, reactions, chemistry, complexes, reaction, chiral, compounds, synthetic, catalysts, metal, organic, catalysis, asymmetric, catalytic, reactivity, bond, ligands, molecules, transition metal, ring, chemical, bonds, new, group, transition, reagents, catalyzed, enantioselective, ligand, organometallic]"


In [16]:
from transformers import AutoTokenizer, pipeline

# Chat model used to turn topic keyword lists into human-readable labels.
evaluator = "HuggingFaceH4/zephyr-7b-beta"

# The tokenizer is loaded separately because `evaluate` uses it to render the
# chat template into a prompt string.
tokenizer = AutoTokenizer.from_pretrained(evaluator)

pipe = pipeline(
    "text-generation",
    model=evaluator,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)


tokenizer_config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/638 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/1.98G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/816M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [48]:
# Few-shot conversation shared by every call: a system instruction followed by
# two worked "word list -> category" exchanges. Kept at module level so it is
# built once, and spelled correctly so the model is not shown misspelled
# exemplars to imitate.
_FEW_SHOT_MESSAGES = (
    {
        "role": "system",
        "content": "You are a knowledgeable science expert. Given list of words, find category that the words fit into. Omit all explanations and artifacts.",
    },
    {"role": "user", "content": "[apple, pear, carrot, potato, banana]"},
    {"role": "assistant", "content": "Fruits and Vegetables"},
    {"role": "user", "content": "[car, bus, passenger, ferry]"},
    {"role": "assistant", "content": "Transportation"},
)


def evaluate(example):
    """Ask the chat model for a category name describing a list of words.

    Relies on the notebook-level ``tokenizer`` (to render the chat template)
    and ``pipe`` (the text-generation pipeline).

    Args:
        example: Mapping with a ``"text"`` entry holding the word list as a
            string, e.g. ``"[car, bus, ferry]"``.

    Returns:
        ``{"label": <generated text>}`` with surrounding whitespace removed.
        Sampling is enabled (``do_sample=True``), so the label can differ
        between runs.
    """
    # Copy the shared few-shot turns and append this example as the final
    # user turn; the shared tuple itself is never mutated.
    messages = [*_FEW_SHOT_MESSAGES, {"role": "user", "content": example["text"]}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.3,
        top_k=20,
        top_p=0.45,
        return_full_text=False,  # only the completion, not the echoed prompt
    )
    # Strip stray leading/trailing whitespace or newlines from the completion
    # so the label is clean when shown in a table.
    return {"label": outputs[0]["generated_text"].strip()}


In [55]:
def _as_bracketed_text(row):
    """Render a topic's keyword list as a single "[w1, w2, ...]" string."""
    return {"text": "[" + ", ".join(row["Representation"]) + "]"}


# Rows 50-99 of the topic table, keeping the keywords and the topic size.
z = Dataset.from_pandas(
    topic_model.get_topic_info()[["Representation", "Count"]].iloc[50:100]
)
# Add the prompt-ready "text" column, then label each topic with the chat model.
z = z.map(_as_bracketed_text)
x = z.map(evaluate)

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [56]:
# Show the labelled topics as a DataFrame: size, keywords and generated label.
x.with_format("pandas")[:].loc[:, ["Count", "Representation", "label"]]

Unnamed: 0,Count,Representation,label
0,294,"[students, military connected, military, school, connected students, stem, connected, math, achievement, district, student, military dependent, schools, social emotional, school district, learning, dependent students, elementary, science, college career, grades, career, college, support, emotional, professional, professional development, teachers, instruction, instructional]",Education and Military Connected Communities
1,293,"[immune, tumor, cells, cancer, immunotherapy, nk, car, breast, antigen, vaccine, breast cancer, cell, nk cells, vaccines, antigens, patients, tumors, tumor cells, ovarian, anti, cancer cells, ovarian cancer, ctl, immune response, immunity, therapy, immune cells, cancers, anti tumor, response]",Immunology and Cancer Research
2,291,"[laser, electron, pulses, pulse, beam, ultrafast, ray, attosecond, high, femtosecond, electrons, power, laser pulses, intense, lasers, ionization, accelerator, source, field, emission, sources, accelerators, beams, light, coherent, plasma, dynamics, energy, physics, phase]",Attosecond and Ultrafast Science
3,285,"[rock, soil, numerical, geotechnical, clay, pore, tests, fracture, soils, piles, flow, hydraulic, pile, underground, loading, waste, geological, model, ground, borehole, gas, behaviour, reservoir, method, co2, models, deformation, transport, stress, pressure]",Geotechnical Engineering and Geology
4,283,"[cartilage, oa, joint, bone, articular, osteoarthritis, articular cartilage, knee, ptoa, hip, tissue, implant, joints, pain, repair, disc, ivd, degeneration, mechanical, osteoarthritis oa, implants, chondrocytes, clinical, traumatic, spine, injuries, regeneration, injury, tissues, biomechanical]",Orthopedics and Joint Health
5,279,"[parasite, malaria, parasites, falciparum, infection, host, leishmania, antigens, vaccine, leishmaniasis, plasmodium, drug, infected, gondii, brucei, immune, drugs, endemic, disease, trypanosomes, resistance, toxoplasma, infections, antimalarial, schistosomiasis, proteins, parasitic, human, vivax, immunity]",Parasitology and Infectious Diseases
6,276,"[hiv, living hiv, aids, prep, hiv aids, health, sexual, care, living, people living, hiv prevention, sex, community, people, prevention, msm, women, intervention, men, stigma, adherence, canada, plwh, indigenous, youth, risk, interventions, transmission, gay, sex workers]",HIV/AIDS and Sexual Health
7,271,"[climate, records, ice, glacial, lake, proxy, ocean, holocene, variability, atlantic, past, sediment, cores, changes, sediments, climatic, circulation, sea, record, isotope, interglacial, paleoclimate, north, change, reconstructions, resolution, climate change, temperature, climate variability, 000 years]",Paleoclimatology and Climate Science
8,271,"[manufacturing, product, scheduling, design, production, planning, supply, supply chain, problems, companies, chain, construction, industry, optimization, process, customer, inventory, management, systems, stochastic, transportation, engineering, tools, smes, models, model, simulation, information, problem, research]",Operations Management and Industrial Engineering
9,269,"[er, estrogen, breast, tamoxifen, breast cancer, estrogen receptor, receptor, cancer, estrogens, er alpha, growth, tumors, alpha, antiestrogens, hormone, resistance, aib1, antiestrogen, breast cancers, expression, cancers, cancer cells, receptors, src, cells, binding, genes, therapy, breast tumors, steroid]",Breast Cancer Research and Treatment


In [57]:
# Build the topic hierarchy from the same training abstracts the model was
# fitted on.
hierarchical_topics = topic_model.hierarchical_topics(
    docs=splits["train"]["abstr"],
)


100%|██████████| 251/251 [00:08<00:00, 29.69it/s]


In [113]:
import plotly.io as pio
import plotly.offline as pyo

# Render Plotly figures inline in the notebook.
pio.renderers.default = "notebook"
pyo.init_notebook_mode(connected=True)


In [115]:
import plotly.express as px

# Minimal bar chart used only to check that Plotly output renders here.
categories = ["a", "b", "c"]
heights = [1, 3, 2]
fig = px.bar(x=categories, y=heights)
fig.show()


In [116]:
# Display BERTopic's topic visualization for the fitted model; the returned
# figure is rendered as the cell output.
topic_model.visualize_topics()