In [None]:
import os
import sys
import time
import json
import regex as re
from collections import Counter

import logging
# Raise the root logger's threshold to CRITICAL so log messages from libraries
# used below do not flood the notebook output.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

import numpy as np
import pandas as pd

from scipy.stats import chi2_contingency

import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn defaults for the figures below.
# NOTE(review): the Figure 3 cells call sns.set_theme(...), which resets style/context
# to set_theme's own defaults for every later figure.
sns.set_style("white")
sns.set_context("paper") #paper, notebook, talk, poster



# Figures

In [None]:
### Load master values for figures
# Annotated digital-health (DH) notes: one row per note, with one column per search term.
notes_df = pd.read_parquet("./output/dhnotes/DH_annotated_notes.parquet.gzip")

# Search-term dictionary: keep only EMERSE-checked terms and drop the Telehealth group.
terms = pd.read_csv("./searchterms.csv", header=0, encoding='mac-roman')
terms = terms[terms["EMERSE_check"] == "Yes"]
terms = terms[terms["File_group"]!="Telehealth"] # Telehealth removed

# term columns: only the terms that actually exist as columns of notes_df
term_cols = [t for t in terms["Term_clean"] if t in notes_df.columns]
term_col_set = set(term_cols)

# Add categories: category name -> "|"-joined clean term names
category_dict = dict(terms.groupby("Category")["Term_clean"].apply("|".join))

for category, joined_terms in category_dict.items():
    # Restrict to term columns present in notes_df. A term listed in searchterms.csv
    # but missing from the annotated notes would otherwise raise a KeyError here.
    category_terms = [t for t in joined_terms.split("|") if t in term_col_set]
    # Boolean flag: True when the note contains at least one term of this category
    notes_df[category] = notes_df[category_terms].sum(axis=1) > 0
    

# Table 1: Demographics

In [None]:
# Load demographics (per category/value counts and proportions for patients with and
# without a digital health note). Small cells appear as the string "<10"
# (presumably small-cell suppression — confirm with the export step).
dh_demo = pd.read_csv("./output/dhnotes_demographics.csv", index_col=0)
no_dh_demo = pd.read_csv("./output/no_dhnotes_demographics.csv", index_col=0)

# Other params
demo_dict = {'sex':"Sex", 'preferredlanguage':"Preferred Language",
             'ucsfderivedraceethnicity_x':"Race/Ethnicity", 'mychartstatus': "MyChart Status"}


def _format_count(count, proportion):
    """Render a count as 'N (P.P%)'; a NaN count (suppressed "<10") stays NaN."""
    if np.isnan(count):
        return np.nan
    return f"{int(count)} ({100 * float(proportion):.1f}%)"


# Get stats for each group
demo_tables = []
for group in ['sex', 'ucsfderivedraceethnicity_x', 'preferredlanguage', 'mychartstatus']:
    dh_subset = dh_demo[dh_demo["category"]==group][["value", "count", "proportion"]]
    no_dh_subset = no_dh_demo[no_dh_demo["category"]==group][["value", "count", "proportion"]]

    # Merge the two cohorts on category value; the "_dh" suffix marks the DH-note cohort.
    demo_df = no_dh_subset.merge(dh_subset, how="inner", on="value", suffixes=('', '_dh')).set_index("value")
    # "<10" -> NaN so that the count/proportion columns can be numeric
    demo_df = demo_df.replace("<10", np.nan)
    demo_df = demo_df.astype(float, errors="ignore")

    # Chi-square test on the observed counts only, computed BEFORE the "(Total)" row is
    # appended. Adding the column totals as an extra row leaves the statistic unchanged
    # but adds (columns - 1) degrees of freedom, which biases the p-value upward.
    # Rows with a suppressed count are dropped from the test.
    chi_df = demo_df[["count", "count_dh"]].astype(float).dropna(how="any")
    chi2, p_value, dof, expected = chi2_contingency(chi_df.values)

    demo_df.loc[demo_dict[group]+" (Total)"] = demo_df.sum(axis=0, numeric_only=True)
    demo_df = demo_df.sort_values("count_dh", ascending=False) # sort by counts in DH note group

    demo_df["Digital health note"] = [_format_count(c, p) for c, p in zip(demo_df["count_dh"], demo_df["proportion_dh"])]
    demo_df["No digital health note"] = [_format_count(c, p) for c, p in zip(demo_df["count"], demo_df["proportion"])]

    # Create final demographics table 1 (copy: columns are added to this slice below)
    demo_df = demo_df[["No digital health note", "Digital health note"]].copy()
    # p-value on the second row; after the descending sort the first row is normally the Total row
    demo_df["pvalue"] = [""]+[p_value]+[""]*(len(demo_df)-2)
    demo_df = demo_df.replace(np.nan, "<10")
    demo_tables.append(demo_df)

stats_demo_df = pd.concat(demo_tables)
stats_demo_df.to_csv("./output/figures/Table1.csv")


In [None]:
# Re-read both demographics exports (the same files the Table 1 cell loads) so they can be
# inspected without re-running Table 1.
# NOTE(review): nothing later in this notebook reads these; the cell looks redundant.
no_dh_demo = pd.read_csv(filepath_or_buffer="./output/no_dhnotes_demographics.csv", index_col=0)
dh_demo = pd.read_csv(filepath_or_buffer="./output/dhnotes_demographics.csv", index_col=0)


# Figure 2: Note distribution

## Figure S2A: Heatmap of terms across departments (also writes Supplemental Tables S1 and S2)

In [None]:
### Load data
x_value = "encounterdepartmentspecialty"

# One row per note with its department, provider type and term indicator columns
values_df = notes_df[["note_id", "patientid", x_value, "providertype"]+term_cols].copy()
values_df["count"] = 1

### Supplemental table 1. Number of DH notes per department
numbers_df = values_df.groupby([x_value])["count"].sum().reset_index()
numbers_df = numbers_df.sort_values("count", ascending=False)
numbers_df.to_csv("./output/figures/supplement/TableS1Data.csv")

# Departments ranked by DH-note count (reused by the Figure 2A cell).
# numbers_df only has the columns [x_value, "count"]; index it with x_value rather than a
# hard-coded column name (the previous "encounter_department_specialty" raised KeyError).
# NOTE(review): the Figure 3B cell groups notes_df by "encounter_department_specialty";
# confirm which department column notes_df really has.
top_departments = numbers_df[:25][x_value]
print("Number of DH notes per department (top 25)")
print(numbers_df[:25])
print()

### Supplemental table 2. Distribution of note terms across different departments
# Long format: one row per (note, term), keeping only rows where the term is present
term_long_df = notes_df[['patientid', 'encounterkey', x_value]+term_cols].melt(
    id_vars=['encounterkey', 'patientid', x_value],
    value_vars=term_cols, var_name='term', value_name='termPresent')
term_long_df = term_long_df[term_long_df["termPresent"]]

# Count notes per [term, department]; only the indicator column is summed (not the ID columns)
values_df = (term_long_df.groupby(["term", x_value])["termPresent"].sum()
             .reset_index()
             .sort_values("termPresent", ascending=False))
values_df.to_csv("./output/figures/supplement/TableS2Data.csv")

# print term totals (across all notes)
term_totals = values_df.groupby("term")[["termPresent"]].sum().sort_values("termPresent", ascending=False)
print('Number of notes containing each digital health term (may sum to more than # notes due to notes with multiple DH terms)')
print(term_totals)

# Ten most frequent terms (reused by the Figure 2A cell)
top_terms = term_totals[:10].index

# Heatmap: keep terms and departments with at least heatmap_min_count term occurrences.
# Filter a separate frame so that values_df keeps the full term x department counts;
# the Figure 2A cell reuses values_df and must not inherit these heatmap-only cut-offs.
heatmap_min_count = 1000
heatmap_terms = term_totals[term_totals["termPresent"] >= heatmap_min_count].index
heatmap_values_df = values_df[values_df["term"].isin(heatmap_terms)]

dept_totals = heatmap_values_df.groupby(x_value)["termPresent"].sum()
heatmap_depts = dept_totals[dept_totals >= heatmap_min_count].index
heatmap_values_df = heatmap_values_df[heatmap_values_df[x_value].isin(heatmap_depts)]

# Plot values: department x term matrix, missing combinations are zero counts
heatmap_df = heatmap_values_df.pivot(columns="term", index=x_value, values="termPresent").fillna(0)
heatmap_grid = sns.clustermap(heatmap_df, vmax=800, vmin=-150, cmap="BuPu")
heatmap_grid.figure.savefig("./output/figures/supplement/Figure2SA.pdf")


## Figure 2A: Distribution of the top terms across the top departments

In [None]:

### Figure 2A. Stacked bar plot of the top terms within the top departments
# Number of departments (ranked by DH-note count) to show
x_n = 5

top_p = list(top_departments)[:x_n]

# Keep only the top departments and the top 10 terms. Filter BEFORE relabelling:
# relabelling "UCSF" -> "Unspecified" first (as before) silently dropped the UCSF rows
# in the isin(top_p) filter and made the later .loc[top_p] raise a KeyError.
unspecified_labels = {"UCSF": "Unspecified"}
plot_df = values_df[values_df[x_value].isin(top_p)
                    & values_df["term"].isin(top_terms[:10])].copy()
plot_df[x_value] = plot_df[x_value].replace(unspecified_labels)
# Department order for the plot/test, using the relabelled names
top_p_labels = [unspecified_labels.get(p, p) for p in top_p]

# Raw counts: notes per department x term
xtab_df = pd.crosstab(index=plot_df[x_value], columns=plot_df["term"],
                      values=plot_df["termPresent"], aggfunc="sum")

# Row-normalized shares for the stacked bar plot, in department rank order
xtab_df_norm = pd.crosstab(index=plot_df[x_value], columns=plot_df["term"],
                           values=plot_df["termPresent"], aggfunc="sum", normalize='index')
xtab_df_norm = xtab_df_norm.loc[top_p_labels]
ax = xtab_df_norm.plot(kind="bar", stacked=True, rot=0, figsize=(10, 8))
ax.legend(title='Digital health term', bbox_to_anchor=(1, 1.02), loc='upper left')
ax.set(xlabel=None)
ax.figure.savefig("./output/figures/Figure2A.pdf", bbox_inches='tight')

# Missing department x term combinations are zero counts
xtab_df = xtab_df.fillna(0)
print(xtab_df.shape)

# Chi-square test of independence between department and term
chi2, p_value, dof, expected = chi2_contingency(xtab_df.loc[top_p_labels].values)

# Save raw data
xtab_df = xtab_df.round(3)
xtab_df_norm = xtab_df_norm.round(3)
xtab_df.to_csv("./output/figures/raw/Figure2AData_1.csv")
xtab_df_norm.to_csv("./output/figures/raw/Figure2AData_2.csv")

# Print the results
print("Chi-square statistic:", chi2)
print("p-value:", p_value)
print("Degrees of freedom:", dof)



## Figure S2B: Top terms across provider types

In [None]:
x_value = "providertype"#"prov_specialty" #"primarycoveragefinancialclass"

# One row per note: provider type, term indicators, and a constant 1 for note counts
values_df = notes_df[[x_value]+term_cols].copy()
values_df["count"] = 1

# Get counts by provider and term. Sum only the term indicators and "count": the
# original also summed the encounterkey identifier column, which then leaked into the
# heatmap and chi-square table as if it were a term.
plot_df = values_df.groupby(x_value)[term_cols + ["count"]].sum()
plot_df.to_csv("./output/figures/raw/Figure2SBData.csv") # save all values

# 25 provider types with the most DH notes, then their 25 most frequent terms.
# Use a dedicated name so the global top_terms (used by Figure 2A) is not overwritten.
plot_df = plot_df.sort_values("count").iloc[-25:].drop(columns="count")
provider_top_terms = plot_df.sum(axis=0).sort_values().iloc[-25:].index
plot_df = plot_df[provider_top_terms]

provider_grid = sns.clustermap(plot_df, vmax=800, vmin=-150, cmap="BuPu")
provider_grid.figure.savefig("./output/figures/supplement/Figure2SB.pdf", bbox_inches='tight')

# Chi-square test of independence between provider type and term
chi2, p_value, dof, expected = chi2_contingency(plot_df.values)

# Print the results
print("Chi-square statistic:", chi2)
print("p-value:", p_value)
print("Degrees of freedom:", dof)


# Figure 3: Note occurrence over time

## Figure 3A: Digital health term occurrence over time


In [None]:
from utils.figures import plot_notes_over_time, calculate_cagr, ridge_plot

# Total number of notes per year (all notes, with or without DH terms); hard-coded counts
all_note_counts_df = pd.DataFrame.from_dict({2021:14205387,
                                   2020:13487925,
                                   2019:12773799,
                                   2022:11521289,
                                   2018:11416946,
                                   2017:10749857,
                                   2016:10453727,
                                   2013:10434202,
                                   2015: 9714333,
                                   2014: 8671952,
                                   2012: 5685316}, orient="index")

all_note_counts_df = all_note_counts_df.reset_index()
all_note_counts_df["Digital health term"] = "All notes"
all_note_counts_df.columns = ["Year", "Count","Digital health term", ]
total_cagr=calculate_cagr(all_note_counts_df, group_col="Digital health term", time_col="Year", count_col="Count")
print(total_cagr)
# Previously printed output:
#                      CAGR (%)      Count
# Digital health term
# All notes            7.318513  119114733

### Figure 3A. Change in term type over time
# Per-year number of notes containing each term, in long format. Only the term
# indicator columns are summed: the previous version also carried "note_text_clean",
# whose groupby-sum concatenates every note's text (very slow) and, since pandas 2.0
# no longer drops non-numeric columns, ends up as a bogus "term" holding strings.
values_df = notes_df.groupby("year")[term_cols].sum()
values_df = values_df.stack().reset_index()
values_df.columns = ["Year", "Digital health term", "Count"]

# Get CAGR values
cagr_df = calculate_cagr(values_df, group_col="Digital health term", time_col="Year", count_col="Count")

# add values for all digital health notes
# NOTE(review): this sums term hits per year, so a note with several DH terms counts more
# than once; the "DH notes" label suggests a note count may have been intended — confirm.
dh_note_counts_df = values_df.groupby("Year")[["Count"]].sum().reset_index()
dh_note_counts_df["Digital health term"] = "DH notes"
dh_note_counts_df.columns = ["Year", "Count","Digital health term", ]
dh_cagr = calculate_cagr(dh_note_counts_df, group_col="Digital health term", time_col="Year", count_col="Count")
print(dh_cagr)
# Previously printed output:
#                      CAGR (%)   Count
# Digital health term
# DH notes             26.883556  226898

# add values for all & DH notes
cagr_df = pd.concat([cagr_df,dh_cagr, total_cagr])
cagr_df.to_csv("./output/figures/raw/Figure3AData.csv")

# Largest terms by total count, plus the two summary series ("All notes", "DH notes"),
# which presumably have the largest counts — hence top_n + 2.
top_n = 10
top_values = cagr_df.sort_values("Count").iloc[-(top_n+2):].index
plot_df = values_df[values_df["Digital health term"].isin(top_values)]
plot_df = pd.concat([plot_df, dh_note_counts_df, all_note_counts_df]).reset_index(drop=True)

sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# Facet order: summary series first, then terms from largest to smallest count
facet_order = list(top_values)[::-1]
order = ["All notes", "DH notes"] + [f for f in facet_order if f not in ("All notes", "DH notes")]

# colors
pal = sns.cubehelix_palette(start=0.5, rot=-0.45, dark=0.1, light=.6, gamma=0.95, reverse=False, hue=0.9,as_cmap=True)

# Add color by CAGR (copy so the new column is not set on a slice)
hue = "Digital health term"
plot_df = plot_df[plot_df[hue].isin(order)].copy()
plot_df["CAGR"] = plot_df[hue].map(cagr_df["CAGR (%)"])
# Colour-scale limits; also reused by the Figure 3B cell so both ridge plots share a scale
vmax = plot_df['CAGR'].max()
vmin = plot_df['CAGR'].min()
g = ridge_plot(plot_df, hue=hue, order=order, pal=pal, vmax=vmax, vmin=vmin)
g.figure.savefig("./output/figures/Figure3A.pdf", bbox_inches='tight')



In [None]:
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# Figure 3B data: number of DH notes per specialty and year.
# NOTE(review): the Figure S2A cell uses the column "encounterdepartmentspecialty";
# confirm which department column notes_df actually has.
plot_df = notes_df.groupby(["year", "encounter_department_specialty"])["note_id"].count().reset_index()
plot_df = plot_df.sort_values("note_id", ascending=False)
plot_df.columns = ["Year", "Specialty", "Count"]

# Append the "DH notes" / "All notes" series from the Figure 3A cell. Rename copies rather
# than reassigning .columns on the shared frames, so this cell (and the Figure 3A cell)
# can be re-run without hidden-state breakage.
summary_columns = {"Digital health term": "Specialty"}
plot_df = pd.concat([plot_df,
                     dh_note_counts_df.rename(columns=summary_columns),
                     all_note_counts_df.rename(columns=summary_columns)])
plot_df = plot_df.reset_index(drop=True)

# CAGR per specialty (including the two summary series); save all values
cagr_df = calculate_cagr(plot_df, group_col="Specialty", time_col="Year", count_col="Count")
cagr_df.to_csv("./output/figures/raw/Figure3BData.csv")

# Largest specialties plus the two summary series (top_n comes from the Figure 3A cell)
top_values = cagr_df.sort_values("Count").iloc[-(top_n+2):].index

# Facet order: summary series first, then specialties from largest to smallest count
facet_order = list(top_values)[::-1]
order = ["All notes", "DH notes"] + [f for f in facet_order if f not in ("All notes", "DH notes")]

# colors
pal = sns.cubehelix_palette(start=0.5, rot=-0.45, dark=0.1, light=.6, gamma=0.95, reverse=False, hue=0.9,as_cmap=True)

# Add color by CAGR (copy so the new column is not set on a slice)
hue = "Specialty"
plot_df = plot_df[plot_df[hue].isin(order)].copy()
plot_df["CAGR"] = plot_df[hue].map(cagr_df["CAGR (%)"])

# use same min/max values as Figure 3A for same-scale ridge plots
g = ridge_plot(plot_df, hue=hue, order=order, pal=pal, vmin=vmin, vmax=vmax)
g.figure.savefig("./output/figures/Figure3B.pdf", bbox_inches='tight')



# Figure 4: LDA groups

In [None]:
import glob
from utils.ClinicalNoteLDA import ClinicalNoteLDA

### Do topic modeling for each type of digital health term
# Load data (presumably one sentence file per DH term — confirm against the extraction step)
dh_term_files = glob.glob("./output/dhnotes/sentences/*.parquet.gzip")
dh_terms_clean = list(terms["Term_clean"])

# get all sentences from all patients
# sent_extended means each digital health sentence + sentence before and sentence after.
# Both columns hold sequences of sentences, flattened here to one space-joined string per row.
term_frames = []
for file in dh_term_files:
    term_df = pd.read_parquet(file)
    term_df["note_text_dh_sent"] = term_df["note_text_dh_sent"].apply(" ".join)
    term_df["note_text_dh_sent_extended"] = term_df["note_text_dh_sent_extended"].apply(" ".join)
    term_frames.append(term_df)
# Concatenate once; repeated concat inside the loop is quadratic
dh_terms_df = pd.concat(term_frames) if term_frames else pd.DataFrame()

# All DH sentences of each note, joined with a space. The previous .sum() concatenated
# the strings with no separator, which can glue the last word of one block to the first
# word of the next and corrupt the LDA tokens.
dh_term_sentences = dh_terms_df.groupby("note_id")["note_text_dh_sent"].agg(" ".join)
notes_df["sentences"] = notes_df["note_id"].map(dh_term_sentences)

# Notes that have DH sentences; "All" marks every such note for the pooled topic model.
# copy() so adding the column does not write into a slice of notes_df.
dh_terms_df = notes_df.dropna(subset="sentences").copy()
dh_terms_df["All"] = True

# Disabled: coherence sweep used to choose the number of topics. It writes
# ./output/figures/supplement/Figure4Data_coherence.csv, which the final-LDA cell reads.
# NOTE(review): this sweep passes custom_stopwords to preprocessDHNotes, but the final-LDA
# cell does not — confirm both use the same preprocessing.
'''
# Coherence metrics for topic selection
figure_name = "Figure4"
category="All"

dh_subset_df = dh_terms_df[dh_terms_df[category]]

med_lda = ClinicalNoteLDA(list(dh_subset_df["sentences"]))
custom_stopwords = ["ug", "mg","kg", "ml", "year", "years", "month", "months", "day", "days"]
med_lda.preprocessDHNotes(custom_stopwords=custom_stopwords)

# Grid search - Done
grid_search = {"num_topics":[10,11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 40, 45, 50]}
coherence_df = med_lda.hparam_sweep(grid_search, no_below=5, no_above=0.5, coherence="c_npmi")
fig, ax = plt.subplots(figsize=(10,10))
ax = sns.scatterplot(data=coherence_df,  y=0, x="index") 

# Save coherence figure and raw data
ax.figure.savefig(f"./output/figures/supplement/{figure_name}Data_coherence.pdf", bbox_inches='tight')
coherence_df.to_csv(f"./output/figures/supplement/{figure_name}Data_coherence.csv")
'''


In [None]:
# Disabled cell: per-category coherence sweeps (Figures 4A, 4B, S4A, S4B). The whole body
# is a string literal, so nothing here executes. When enabled, it would write
# ./output/figures/supplement/<figure_name>Data_coherence.csv for each category, which the
# final-LDA cell reads to pick the number of topics.
'''figures_dict = {"Figure4A":"Connected_digital_product",
                "Figure4B":"Generic_software_intervention",
                "FigureS4A":"Single_intervention",
                "FigureS4B":"Other"}

### Figure 4A, 4B, 4SA, 4SB: LDA coherence scores for each category
for figure_name in figures_dict:
    category = figures_dict[figure_name]
    dh_subset_df = dh_terms_df[dh_terms_df[category]]

    med_lda = ClinicalNoteLDA(list(dh_subset_df["sentences"]))
    med_lda.preprocessDHNotes()

    # Grid search - Done
    grid_search = {"num_topics":[10,11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]} # 10 is best with no_below=5, no_above=0.5
    coherence_df = med_lda.hparam_sweep(grid_search, no_below=5, no_above=0.5, coherence="c_npmi")
    fig, ax = plt.subplots(figsize=(10,10))
    ax = sns.scatterplot(data=coherence_df,  y=0, x="index") 

    # Save coherence figure and raw data
    ax.figure.savefig(f"./output/figures/supplement/{figure_name}Data_coherence.pdf", bbox_inches='tight')
    coherence_df.to_csv(f"./output/figures/supplement/{figure_name}Data_coherence.csv")

'''

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

# Per-category variant (needs the per-category coherence sweeps to have been run):
# figures_dict = {"Figure4A": "Connected_digital_product",
#                 "Figure4B": "Generic_software_intervention",
#                 "FigureS4A": "Single_intervention",
#                 "FigureS4B": "Other"}
figures_dict = {"Figure4":"All"}

### Figure 4A, 4B, 4SA, 4SB: Best LDA topics for each category
### Separating out this for loop so the previous cell does not have to be rerun
for figure_name, category in figures_dict.items():
    # Notes in this category (boolean column of dh_terms_df). copy() because topic
    # columns are added to this frame below.
    dh_subset_df = dh_terms_df[dh_terms_df[category]].copy()

    # NOTE(review): the pooled coherence sweep used custom_stopwords; this call does not.
    med_lda = ClinicalNoteLDA(list(dh_subset_df["sentences"]))
    med_lda.preprocessDHNotes()

    # Number of topics with the highest coherence score from the saved sweep.
    # Read the "index" column directly and cast to int: .iloc[0] on the whole mixed
    # int/float row upcasts it to float (e.g. 10.0), which is not a valid topic count.
    best_df = pd.read_csv(f"./output/figures/supplement/{figure_name}Data_coherence.csv", index_col=0)
    best_df = best_df.sort_values("0", ascending=False)
    best_k = int(best_df["index"].iloc[0])

    # Create final LDA
    med_lda.create_lda(no_below=5, no_above=0.5, num_topics=best_k, coherence='c_npmi') #c_uci
    vis = med_lda.visualize_lda()

    # Topic scatterplot: "FigureS*" figures go to the supplement folder
    supplement = "/supplement" if "FigureS" in figure_name else ""
    # TODO: axis limits are tuned for the pooled "All" model; per-category figures may need others
    fig = ClinicalNoteLDA.scatterplot_topics(vis, x_axis_lim=[-0.42, 0.24], y_axis_lim=[-0.3, 0.25])
    fig.figure.savefig(f"./output/figures{supplement}/{figure_name}_Scatter.pdf", bbox_inches='tight')

    # Top terms of every topic (gensim show_topics, default num_words) and their weights.
    # Weights are cast to built-in floats so the saved list repr contains plain numbers
    # (numpy >= 2 would write numpy scalars as "np.float32(...)", which the Figure 4B
    # cells cannot parse).
    topics_df = pd.DataFrame(med_lda.lda_model.show_topics(formatted=False, num_topics=-1))
    topics_df["topics"] = [",".join(word for word, _ in word_probs) for word_probs in topics_df[1]]
    topics_df["freq"] = [[float(prob) for _, prob in word_probs] for word_probs in topics_df[1]]
    topics_df = topics_df.iloc[:, [0,2,3]]
    topics_df.columns = ["unsorted_topic_num", "Top terms", "Frequency"]

    # Assign each note to its most probable topic.
    # Assumes med_lda.corpus is row-aligned with dh_subset_df — TODO confirm.
    doc_topics = [med_lda.lda_model.get_document_topics(bow) for bow in med_lda.corpus]
    best_topics = [dt[0] if len(dt) == 1 else ClinicalNoteLDA.get_top_topic(dt) for dt in doc_topics]
    dh_subset_df["topic"] = [t[0] for t in best_topics]
    dh_subset_df["probs"] = [t[1] for t in best_topics]

    # Supplemental figure: distribution of the assigned-topic probabilities
    fig, ax = plt.subplots(figsize=(12,8))
    sns.histplot(dh_subset_df["probs"], ax=ax)
    ax.set(xlabel="Probability")
    fig.savefig(f"./output/figures/supplement/{figure_name}_TopicDistribution.pdf", bbox_inches='tight')

    # Raw data: number of notes assigned to each topic
    topics_df["Document counts"] = topics_df["unsorted_topic_num"].map(dh_subset_df.value_counts("topic"))

    # pyLDAvis re-orders topics ("Topic1", "Topic2", ...). Map them back to gensim's topic
    # ids by matching each topic's comma-joined top-10 term list.
    # NOTE(review): relies on both libraries ranking terms identically; ties -> no match (NaN).
    topic_values_df = vis.topic_info.sort_values("Freq", ascending=False).groupby("Category").head(10)
    topic_terms = topic_values_df.groupby("Category")["Term"].agg(",".join)
    topics_df["sorted_topic"] = topics_df["Top terms"].map(dict(zip(topic_terms, topic_terms.index)))

    # "TopicN" -> N; unmatched (NaN) or other labels are kept as-is instead of raising TypeError
    topics_df["sorted_topic_num"] = [
        int(label.split("Topic")[1]) if isinstance(label, str) and "Topic" in label else label
        for label in topics_df["sorted_topic"]
    ]
    topics_df = topics_df.sort_values("Document counts", ascending=False)
    topics_df.to_csv(f"./output/figures{supplement}/{figure_name}Topics.csv")

    # Save per-note topic assignments
    dh_subset_df["sorted_topic_num"] = dh_subset_df["topic"].map(dict(zip(topics_df["unsorted_topic_num"], topics_df["sorted_topic_num"])))
    topic_distribution = dh_subset_df[["note_id", "patientid", "sentences", "topic" ,"sorted_topic_num", "probs"]]
    topic_distribution.to_parquet(f"./output/dhnotes/{figure_name}DHnotes_topics.parquet.gzip", compression="gzip")

    

In [None]:
# Figure 4B: term weights of the 10 topics with the most assigned notes.
# Each saved row holds a topic's comma-separated top terms and its list of weights
# (stored as a "[w1, w2, ...]" string); explode to one row per (topic, term).
term_freq = pd.read_csv("./output/figures/Figure4Topics.csv", index_col=0)
term_freq = term_freq.assign(
    Term=term_freq["Top terms"].map(lambda joined: joined.split(",")),
    Frequency=term_freq["Frequency"].map(
        lambda weights: [float(w) for w in weights.strip("[]").split(", ")]),
)
term_freq = term_freq.explode(["Term", "Frequency"])[["sorted_topic", "Term", "Frequency", "Document counts"]]

# Rank topics by document count (ascending) and keep the ten largest
top_topics = term_freq.groupby("sorted_topic").first().sort_values("Document counts")
largest_topics = top_topics.index[-10:]
plot_df = term_freq.loc[term_freq["sorted_topic"].isin(largest_topics)]

g = sns.catplot(
    data=plot_df, x="Frequency", y="Term", col="sorted_topic", kind="bar",
    height=6, aspect=0.7, col_wrap=5, sharex=False, sharey=False,
    facet_kws=dict(margin_titles=True),
)
g.figure.subplots_adjust(wspace=.55, hspace=.2)
g.set_titles(template="{col_name}")
g.figure.savefig("./output/figures/Figure4B.pdf", bbox_inches='tight')


In [None]:
# Figure S4A: term weights of all remaining topics (everything outside the 10 largest
# topics shown in Figure 4B). Same parsing as Figure 4B: explode each topic's
# comma-separated terms and "[w1, w2, ...]" weight list to one row per (topic, term).
term_freq = pd.read_csv("./output/figures/Figure4Topics.csv", index_col=0)
term_freq = term_freq.assign(
    Term=term_freq["Top terms"].map(lambda joined: joined.split(",")),
    Frequency=term_freq["Frequency"].map(
        lambda weights: [float(w) for w in weights.strip("[]").split(", ")]),
)
term_freq = term_freq.explode(["Term", "Frequency"])[["sorted_topic", "Term", "Frequency", "Document counts"]]

# Rank topics by document count (ascending) and drop the ten largest
top_topics = term_freq.groupby("sorted_topic").first().sort_values("Document counts")
remaining_topics = top_topics.index[:-10]
plot_df = term_freq.loc[term_freq["sorted_topic"].isin(remaining_topics)]

g = sns.catplot(
    data=plot_df, x="Frequency", y="Term", col="sorted_topic", kind="bar",
    height=6, aspect=0.7, col_wrap=5, sharex=False, sharey=False,
    facet_kws=dict(margin_titles=True),
)
g.figure.subplots_adjust(wspace=.55, hspace=.2)
g.set_titles(template="{col_name}")
g.figure.savefig("./output/figures/supplement/FigureS4A.pdf", bbox_inches='tight')
