In [1]:
import os
import numpy as np
import pandas as pd
import tomotopy as tp
from tomotopy.utils import Corpus
from local.caching import load, save

In [2]:
# Load the cached 'biocyc' table through the project-local caching helper,
# pointing it at the sibling "../prep/" workspace.
# NOTE(review): the cells below read each row as one document and each column
# label as a token with an integer count — confirm against the prep notebook.
df: pd.DataFrame = load('biocyc', alt_workspace="../prep/")
# Show (n_rows, n_columns) as a quick sanity check of what was loaded.
df.shape

recovering & decompressing cached data from [../prep/cache/biocyc.pkl.gz]


(19999, 7853)

In [3]:
# Build (or reload from disk) the tomotopy corpus.
# Each row of `df` becomes one document: every column label is emitted as a
# token as many times as the count stored in that row for that column.
folder = "./cache/corpus"
corp_name = "ec_all"
corp_file = f"{folder}/{corp_name}"

if os.path.exists(corp_file):
    corp = Corpus.load(corp_file)
else:
    # Create the cache directory up front so a missing folder is reported
    # before the expensive build instead of at the final save.
    os.makedirs(folder, exist_ok=True)

    corp = Corpus()
    columns = df.columns
    # Coerce labels to str once; tokens handed to the corpus are plain strings.
    labels = columns.astype(str).to_numpy()

    for i, (_, row) in enumerate(df.iterrows()):
        if i % 16 == 0: print(f"{i+1} of {len(df)}", " "*25, end='\r')
        # np.repeat expands each label by its count, preserving column order,
        # which replaces the per-cell / per-count Python loops.
        doc = np.repeat(labels, row.to_numpy()).tolist()
        corp.add_doc(doc)

    corp.save(corp_file)

In [10]:
# Train (or reload from disk) a topic model over `corp`.
K = 50
SEED = 42  # fixed seed so a rebuild of the cached model is repeatable
model_type = "ctm"
model_name = "ec"
model_path = f"./cache/{model_name}_{model_type}_{K}"
# Map the configured model type straight to its tomotopy class; an unknown
# `model_type` raises KeyError here.
cls = {
    "ctm": tp.CTModel,
    "lda": tp.LDAModel,
}[model_type]
if os.path.exists(model_path):
    model = cls.load(model_path)
else:
    # Make sure the cache directory exists before training, so the trained
    # model is not lost to a failing save at the very end.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # NOTE(review): with multi-threaded training, exact repeatability may also
    # depend on the worker count — confirm against the tomotopy documentation.
    model = cls(k=K, rm_top=5, min_cf=20, seed=SEED)
    model.add_corpus(corp)
    # 10 rounds of 10 iterations each, printing the round index as progress.
    for i in range(10):
        model.train(iter=10, workers=14)
        print(i, end='\r')
    model.save(model_path)

9

src/TopicModel/../Utils/TruncMultiNormal.hpp(56): wrong truncation range [-1.47544, -1.47544]


In [11]:
# For every document held by the model, keep the first element of what
# `model.infer` returns for it.
topic_d = [model.infer(doc)[0] for doc in model.docs]
len(topic_d)

19999

In [12]:
import plotly.graph_objects as go

# settings

axis_col = 'rgba(0, 0, 0, 0.15)'
no_col = 'rgba(0, 0, 0, 0)'
# Shared axis styling: invisible axis line, faint grid and zero line.
axis_desc: dict = {
    "linecolor": no_col,
    "gridcolor": axis_col,
    "zerolinecolor": axis_col,
    "zerolinewidth": 1,
}
# Base figure layout; individual plots override entries as needed.
layout = {
    "autosize": False,
    "width": 1400,
    "height": 650,
    "margin": {"l": 25, "r": 25, "b": 25, "t": 50, "pad": 5},
    "font_family": "Times New Roman",
    "font_color": "black",
    "font_size": 20,
    "plot_bgcolor": 'white',
    "xaxis": axis_desc,
    "yaxis": axis_desc,
    "xaxis2": axis_desc,
    "yaxis2": axis_desc,
}

s, o = 5, 1

# Flatten the first 1000 entries of `topic_d` into parallel coordinate lists:
# x is the position of a value inside its entry, y is the value itself.
shown = topic_d[:1000]
xs = [pos for topic in shown for pos, _ in enumerate(topic)]
ys = [val for topic in shown for val in topic]

fig = go.Figure(data=[
    go.Scatter(
        x=xs,
        y=ys,
        mode='markers',
        marker={"size": s, "opacity": o},
    )
])

# Per-figure layout: base settings with untitled axes and a narrower canvas.
_layout = {
    **layout,
    "xaxis": {"title": "", **axis_desc},
    "yaxis": {"title": "", **axis_desc},
    "width": 900,
}
fig.update_annotations(font_size=24)
fig.update_layout(_layout)
fig.show()