# __Step 4.4: Topic over time__

Goals here:
- Analyze topics over time

## ___Set up___

### Module import

In [None]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from pathlib import Path
from bertopic import BERTopic
from tqdm import tqdm
import csv
from xlsxwriter.workbook import Workbook
from plotly.io import write_image
from sklearn.metrics.pairwise import cosine_similarity
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler

### Key variables

In [None]:
# Fixed seed, kept for reproducibility
seed = 20220609

# Project root, plus the working directory for this step (created if missing)
proj_dir = Path.home() / "projects/plant_sci_hist"
work_dir = proj_dir / "4_topic_model/4_4_over_time"
work_dir.mkdir(parents=True, exist_ok=True)

# Plant science corpus
dir25 = proj_dir / "2_text_classify/2_5_predict_pubmed"
corpus_file = dir25 / "corpus_plant_421658.tsv.gz"

# Saved topic model, probabilities and embeddings
dir42 = proj_dir / "4_topic_model/4_2_outlier_assign"
topic_model_file, prob_file, embedding_file = (
  dir42 / "topic_model_updated",
  dir42 / "probs.pickle",
  dir42 / "embeddings_scibert.pickle",
)

# Embed TrueType (Type 42) fonts in saved PDFs; use a sans-serif font family
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams["font.family"] = "sans-serif"

## ___Load data___

### Load original corpus

In [None]:
# Read the tab-separated plant science corpus and preview its first rows
df_corpus = pd.read_csv(corpus_file, delimiter='\t')
df_corpus.iloc[:3]

In [None]:
# (rows, columns) of the corpus table
df_corpus.shape

### Load cleaned data

In [None]:
# Load the pickled cleaned documents (used by hierarchical_topics below).
# NOTE(review): docs_clean_file is not defined in the "Key variables" cell, so
# this raises NameError on a fresh kernel. Define it there (path to the pickled
# cleaned docs, presumably produced in an earlier step) — TODO confirm path.
with open(docs_clean_file, "rb") as f:
  docs_clean = pickle.load(f)

### Load topic model and probability

In [None]:
# Load the saved BERTopic model from topic_model_file
topic_model = BERTopic.load(topic_model_file)

In [None]:
# Load the pickled topic probabilities
with prob_file.open("rb") as fh:
  probs = pickle.load(fh)

In [None]:
# Print the built-in help for the loaded BERTopic object
help(topic_model)

## ___Basic summary___

### Topic size distribution

See the "Revisit topic size plot using the modified topic names" section below for the updated version used in the figures.

In [None]:
# Per-topic summary table from BERTopic (has Topic and Count columns, used below)
topic_info = topic_model.get_topic_info()
topic_info

In [None]:
# Distribution of topic sizes: histogram of log10(documents per topic)
plt.hist(np.log10(topic_info["Count"]), bins=200)
plt.xlabel("log10(Count)")
plt.ylabel("Frequency")
plt.xlim(2.5,5)
plt.savefig(work_dir / "fig4_3_topic_count_dist.pdf")

### Representative docs

In [None]:
rep_docs = topic_model.get_representative_docs()

# The outlier topic is not among the keys
type(rep_docs), len(rep_docs), len(rep_docs[0])

In [None]:
# Write to a tsv: one row per topic (dict keys become the index) and one
# column per representative doc
rep_docs_file = work_dir / "topic_rep_docs.tsv"
rep_docs_df = pd.DataFrame.from_dict(
  rep_docs, orient='index', columns=['doc1', 'doc2', 'doc3'])
rep_docs_df.head()

In [None]:
# Save the representative docs as a TSV (index = topic id)
rep_docs_df.to_csv(rep_docs_file, sep='\t')

## ___Get top words for different topics___

### generate_topic_labels

Get a label for each topic in a user-defined format.
- Use nr_words=10.
- Some labels are hard to interpret; e.g., the label for cluster 5 does not make sense.

In [None]:
# Labels built from the top 10 words of each topic, prefixed with the topic id
# and separated by "|"
topic_labels = topic_model.generate_topic_labels(separator='|',
                                                 topic_prefix=True,
                                                 nr_words=10)

In [None]:
# Inspect the label list. The first entry is presumably the outlier topic
# (-1); it is skipped when the labels are exported below.
type(topic_labels), topic_labels[:10]

### get_topic

Return top 10 words for a specific topic and their c-TF-IDF scores

In [None]:
# Example: top words and c-TF-IDF scores of topic 0
topic0 = topic_model.get_topic(0)
type(topic0)

In [None]:
# Display the top words of topic 0 with their c-TF-IDF scores
topic0

In [None]:
# Get all topic top 10 words, exclude the outlier cluster
# Ok, this can be done with topic_model.get_topics() (see the next cell)

#topic_top10 = {} # {cluster_id: top_10_list}
#for cluster_id in range(topic_info.shape[0]-1):
#  topic_top10[cluster_id] = topic_model.get_topic(cluster_id)


In [None]:
# Top words of every topic, keyed by topic id (indexed by key 0 below)
all_topics = topic_model.get_topics()
type(all_topics), all_topics[0]

### Topic-term matrix

Follow [this](https://maartengr.github.io/BERTopic/getting_started/tips_and_tricks/tips_and_tricks.html#topic-term-matrix)
- The approaches above only give the top 10 terms; we want more terms from each topic.
- So process the matrix instead.
- To get the top n entries per row, follow [this post](https://stackoverflow.com/questions/31790819/scipy-sparse-csr-matrix-how-to-get-top-ten-values-and-indices).
- Also, see [this post](https://stackoverflow.com/questions/3179106/python-select-subset-from-list-based-on-index-set) for selecting a subset from a list based on indices.


In [None]:
# Sparse matrix with topics as rows and features (i.e. terms) as columns,
# values are c-TF-IDF. Row 0 is the outlier topic (-1), so topic t is row t+1.
topic_term_matrix = topic_model.c_tf_idf

In [None]:
# Inspect the matrix type and (n_topics incl. outlier, n_terms) shape
type(topic_term_matrix), topic_term_matrix.shape

In [None]:
# A list of features (terms); position i names column i of topic_term_matrix.
# get_feature_names() was deprecated in scikit-learn 1.0 and removed in 1.2,
# so prefer get_feature_names_out() and only fall back for old versions.
if hasattr(topic_model.vectorizer_model, "get_feature_names_out"):
  # Returns an ndarray; convert to a list of plain str like the old API
  terms = topic_model.vectorizer_model.get_feature_names_out().tolist()
else:
  terms = topic_model.vectorizer_model.get_feature_names()

In [None]:
# Inspect the vocabulary container and its size
type(terms), len(terms)

In [None]:
# Get top 50 terms
top_50 = {} # {topic:[top50_idx_list, top50_c-tf-idf_list, top50_feat_list]}

# Skip the outlier topic, which is row 0 of topic_term_matrix. Row r holds
# topic r-1, hence the name topic_plus1 for the row index.
for topic_plus1 in tqdm(range(1, topic_term_matrix.shape[0])):
  row = topic_term_matrix.getrow(topic_plus1).toarray().ravel()

  # argsort is ascending, so the last 50 positions are the largest values,
  # ordered from low to high. Sort once and reuse the ordering for the values.
  t50_idx = list(row.argsort()[-50:])
  t50_val = list(row[t50_idx])

  t50_fea = [terms[i] for i in t50_idx]
  top_50[topic_plus1-1] = [t50_idx, t50_val, t50_fea]

In [None]:
# Persist the top-50 term table as a pickle
top_50_pickle = work_dir / 'top_50_terms_per_topic.pickle'
with top_50_pickle.open('wb') as fh:
  pickle.dump(top_50, fh)

### Save top terms for different topics into an xlsx file

See [this post](https://www.geeksforgeeks.org/convert-a-tsv-file-to-excel-using-python/) for saving TSVs into spreadsheet
- Also include a topic label sheet where label is the top 10 words
- Also include a representative doc sheet containing the top 3 docs of each topic.

In [None]:
# Workbook that collects the topic labels, representative docs and per-topic
# top-50 terms as separate worksheets; it is closed after the top-50 cell.
xlsx_file = work_dir / "table_top_50.xlsx"
xlsx      = Workbook(xlsx_file)

In [None]:
# Include the top 10 words in a worksheet
topic_label_file = work_dir / "topic_labels.txt"

# Do not output outlier (the first label)
topic_label_df = pd.DataFrame(topic_labels[1:], columns=["label"])
topic_label_df.to_csv(topic_label_file, sep='\t')

# Copy the TSV into the worksheet and close the file afterwards
# (newline='' is what the csv module expects for file objects).
worksheet = xlsx.add_worksheet("topic_label")
with open(topic_label_file, 'r', encoding='utf-8', newline='') as fh:
  for row, data in enumerate(csv.reader(fh, delimiter='\t')):
    worksheet.write_row(row, 0, data)

In [None]:
# Include the representative docs in a worksheet. The file is closed once it
# has been copied; newline='' keeps newlines inside quoted fields intact.
worksheet = xlsx.add_worksheet("representative docs")
with open(rep_docs_file, 'r', encoding='utf-8', newline='') as fh:
  for row, data in enumerate(csv.reader(fh, delimiter='\t')):
    worksheet.write_row(row, 0, data)

In [None]:
# Put the top 50 term info into different tsv files in the top_50 folder
top_50_dir = work_dir / "top_50"
top_50_dir.mkdir(parents=True, exist_ok=True)

# Output individual tsv files and put each one into its own xlsx worksheet
for topic in top_50:
  topic_file = top_50_dir / f"topic_{topic}.tsv"
  # The nested list has index, c-tf-idf, and feature as rows, so it is
  # transposed to have them as columns. The iloc bit reverses the order so
  # higher c-tf-idf entries are on top.
  topic_df = pd.DataFrame(top_50[topic]).transpose().iloc[::-1]
  topic_df.columns = ["index", "c-tf-idf", "feature"]
  topic_df.to_csv(topic_file, sep='\t')

  # Save to xlsx; close each TSV once copied (previously one handle per topic
  # was left open)
  worksheet = xlsx.add_worksheet(f"{topic}")
  with open(topic_file, 'r', encoding='utf-8', newline='') as fh:
    for row, data in enumerate(csv.reader(fh, delimiter='\t')):
      worksheet.write_row(row, 0, data)

xlsx.close()

## ___Visualize topic___

### Topic relations in 2D

- visualize_topics:
  - This is useful to see how topics are related to each other in 2D.

In [None]:
# BERTopic's 2D inter-topic map
vis1 = topic_model.visualize_topics()
type(vis1)

In [None]:
# Render the interactive figure in the notebook
vis1

In [None]:
# Save the 2D topic map twice: interactive HTML and static PDF
html_out = work_dir / "fig4_3_topic_relation_2d.html"
pdf_out  = work_dir / "fig4_3_topic_relation_2d.pdf"
vis1.write_html(html_out)
write_image(vis1, pdf_out, format='pdf')

### Get the dataframe for plotting the 2D graph

Reconstruct the dataframe used to draw the 2D topic plot, based on the code [here](https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_topics.py).

In [None]:
# Inputs for the 2D topic map, mirroring BERTopic's visualize_topics.
# `topics` is not defined yet at this point in the notebook. Later cells bind
# it to one topic per document (hierarchy section) or to all topic ids
# including -1 (heatmap section). So derive the non-outlier topic list here
# from the topic frequency table instead.
freq_df_vis = topic_model.get_topic_freq()
topic_list  = sorted(freq_df_vis.loc[freq_df_vis.Topic != -1, "Topic"].to_list())

# Row positions of these topics in c_tf_idf, whose rows follow the sorted topic
# ids with the outlier (-1) in row 0
all_topics   = sorted(topic_model.get_topics().keys())
indices      = np.array([all_topics.index(topic) for topic in topic_list])
frequencies  = [topic_model.topic_sizes[topic] for topic in topic_list]

# Top 5 words of each topic
words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]])
                                                      for topic in topic_list]

In [None]:
# Dense c-TF-IDF rows of the topics selected by `indices` in the cell above
embed_ctfidf = topic_model.c_tf_idf.toarray()[indices]

In [None]:
# Rescale each term column to [0, 1] before the UMAP projection
embed_scaled = MinMaxScaler().fit_transform(embed_ctfidf)

In [None]:
# This takes ~3 min. Project the scaled c-TF-IDF rows to 2D.
# random_state=seed makes the saved coordinates reproducible, as intended by
# the `seed` set in "Key variables". UMAP turns off some parallelism when
# random_state is set, so this may run slower.
embed_umap   = UMAP(
  n_neighbors=2, n_components=2, metric='hellinger',
  random_state=seed).fit_transform(embed_scaled)

In [None]:
# One row per topic: 2D UMAP coordinates, topic id, top words and size
embed_df = pd.DataFrame(dict(
  x=embed_umap[:, 0],
  y=embed_umap[:, 1],
  Topic=topic_list,
  Words=words,
  Size=frequencies,
))

In [None]:
# Preview the first 3 topics
embed_df.head(3)

In [None]:
# Save the per-topic 2D coordinates, words and sizes as a TSV
embed_df.to_csv(work_dir / "table4_3_topic_relation_embedding_scaled_umap.tsv",
                sep='\t')

## ___Topic hierarchical relations___

### Topic tree

In [None]:
# 95th percentile of all topic probabilities (over every document and topic)
probability_threshold = np.percentile(probs, 95)

# For hierarchical_topics, a list of per-document topics is required. This is
# returned by fit or fit_transform, but it does not make sense to run it again.
# So assign each document to its most probable topic, or to the outlier topic
# (-1) when that probability is below the threshold. Vectorized over the whole
# probability matrix instead of a Python loop over ~421k documents.
probs_arr = np.asarray(probs)
topics = np.where(probs_arr.max(axis=1) >= probability_threshold,
                  probs_arr.argmax(axis=1), -1).tolist()

In [None]:
# Build the topic hierarchy from the cleaned docs and the per-document topics
hierarchical_topics = topic_model.hierarchical_topics(docs_clean, topics)

In [None]:
# Inspect the hierarchy table
type(hierarchical_topics), hierarchical_topics.shape, hierarchical_topics.head()

In [None]:
# Topic hierarchy dendrogram (orientation="bottom"), saved as HTML
fig_hier_bot = topic_model.visualize_hierarchy(orientation="bottom")
fig_hier_bot.write_html(work_dir / "fig4_3_topic_hierarchy_bottom.html")

In [None]:
# Topic hierarchy dendrogram (orientation="left"), saved as HTML
fig_hier_lef = topic_model.visualize_hierarchy(orientation="left")
fig_hier_lef.write_html(work_dir / "fig4_3_topic_hierarchy_left.html")

In [None]:
# Save both dendrogram orientations as PDF
for hier_fig, hier_side in ((fig_hier_bot, "bottom"), (fig_hier_lef, "left")):
  write_image(hier_fig, work_dir / f"fig4_3_topic_hierarchy_{hier_side}.pdf",
              format='pdf')

### Topic relations heatmap

This is not particularly helpful out of the box as the axes are not clustered.
- Look into [source code](https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_heatmap.py) to see if I can get the distance matrix out and do my own.

In [None]:
# BERTopic's built-in topic similarity heatmap, saved as HTML and PDF
fig_heatmap = topic_model.visualize_heatmap()
heatmap_html = work_dir / "fig4_3_topic_heatmap.html"
heatmap_pdf  = work_dir / "fig4_3_topic_heatmap.pdf"
fig_heatmap.write_html(heatmap_html)
write_image(fig_heatmap, heatmap_pdf, format='pdf')

### Topic relations heatmap - manual

In [None]:
# Load embeddings

# The following won't work because when I save the model, I did not save the
# embeddings. I did not save it because it is precomputed in 4.2.
#embeddings = np.array(topic_model.topic_embeddings)

# Load the saved embedding file
# The following won't work because it asks for >600Gb of memory when the
# distance matrix is being created.
#with open(embedding_file, "rb") as f:
#  embeddings = pickle.load(f)

# Realize that the embedding needed here is the topic embedding, not the doc
# embedding, so the above is not useful. Use the per-topic c-TF-IDF rows
# (one row per topic, outlier included) as topic vectors instead.
embeddings = topic_model.c_tf_idf

In [None]:
# Inspect the topic-vector matrix
type(embeddings), embeddings.shape

In [None]:
# All topic ids (outlier -1 included), sorted; these label the rows/columns of
# the similarity matrix below
freq_df = topic_model.get_topic_freq()
topics  = sorted(freq_df["Topic"].tolist())

In [None]:
# NOTE: despite the name, cosine_similarity returns similarities (higher =
# more alike), not distances. The heatmaps below therefore show topic
# similarity between the c-TF-IDF rows.
distance_matrix = cosine_similarity(embeddings)

In [None]:
# Label rows/columns with the sorted topic ids, matching the c-TF-IDF row order
# (row t+1 holds topic t)
dist_matrix_df = pd.DataFrame(distance_matrix,
                              index=topics,
                              columns=topics)

In [None]:
# Hierarchically clustered similarity heatmap; only row labels are shown
topic_clustergrid = sns.clustermap(dist_matrix_df, cmap="coolwarm", 
                                  xticklabels=False, yticklabels=True)
topic_clustergrid.savefig(work_dir / 'fig4_3_topic_heatmap_seaborn.pdf')

In [None]:
# Row labels: "<topic>\t" followed by the topic's first 9 words, joined by "|"
# (the topic prefix takes the first of 10 slots)
new_labels = ["|".join([f"{topic}\t"] +
                       [entry[0] for entry in topic_model.get_topic(topic)[:9]])
              for topic in topics]
#new_labels = [label if len(label) < 30 else label[:27] + "..." 
#                                                    for label in new_labels]

In [None]:
# Preview the first two labels
new_labels[:2]

In [None]:
# reordered_ind gives row positions (not labels or column names) in clustered
# order; new_labels follows the same row order as dist_matrix_df.
order_idx   = topic_clustergrid.dendrogram_row.reordered_ind
topic_order = [new_labels[pos] for pos in order_idx]

# One clustered topic label per line
with open(work_dir / "fig4_3_topic_heatmap_seaborn_order.txt", "w") as out:
  out.write("\n".join(topic_order))

### Topic relations heatmap - lower half

Generate another version with only the lower half, see [this post](https://stackoverflow.com/questions/67879908/lower-triangle-mask-with-seaborn-clustermap) but did not lead to a figure. Try [this](https://medium.com/@fleetw00d/plotting-a-triangluar-portion-of-a-seaborn-clustermap-92f3405c2f4d).

In [None]:
# Attempt to hide half of the existing clustermap in place by masking the
# colour values of its heatmap.
# NOTE(review): np.tril(...) is 1 on and below the diagonal, and masked entries
# are not drawn. So this presumably hides the diagonal and lower triangle of
# the displayed (clustered) matrix, i.e. it keeps the upper half. The next
# cell produces the lower-half figure.
mask   = np.tril(np.ones_like(dist_matrix_df))
# Colour values of the drawn heatmap, reshaped to the matrix shape
values = topic_clustergrid.ax_heatmap.collections[0].get_array().reshape(
                                                          dist_matrix_df.shape)
new_values = np.ma.array(values, mask=mask)
topic_clustergrid.ax_heatmap.collections[0].set_array(new_values)
plt.show()

In [None]:
# Lower-half heatmap. clustermap reorders the mask together with the data, so
# the mask is built in the original (unclustered) order such that, after
# reordering, it covers the upper triangle (diagonal included).
mask = np.zeros_like(dist_matrix_df)
mask[np.triu_indices_from(mask)] = True
# Throwaway tiny clustermap, only used to obtain the clustered row/column order
# (NOTE(review): this figure is never closed)
g = sns.clustermap(dist_matrix_df, mask=mask, vmax=.3, figsize=(0.1,0.1))
# Map the triangle back to the original order: clustered position i is original
# row reordered_ind[i], so index by the inverse permutation (argsort)
mask = mask[np.argsort(g.dendrogram_row.reordered_ind),:]
mask = mask[:,np.argsort(g.dendrogram_col.reordered_ind)]
# Final figure; relies on clustering giving the same order as the run above
topic_clustergrid_lower = sns.clustermap(dist_matrix_df, 
                                         figsize=(40,40), mask=mask, 
                                         cmap='coolwarm', 
                                         xticklabels=False, 
                                         yticklabels=dist_matrix_df.columns)
# Hide the column dendrogram
topic_clustergrid_lower.ax_col_dendrogram.set_visible(False)
topic_clustergrid_lower.savefig(work_dir/'fig4_3_topic_heatmap_seaborn_lower.pdf')

### Modified top terms

Take `fig4_3_topic_heatmap_seaborn_order.txt`:
- Manually go through the terms to reduce redundancy and select 4-6 representative terms for each topic. The rules are:
  - Combine singular and plural forms (e.g., gene and genes)
  - Combine terms that are describing similar entities (e.g., strain and isolate)
  - Remove overly common words (e.g., plant, and gene in some cases)
- The result is in `fig4_3_topic_heatmap_seaborn_order_modified.txt`. Parse this so the info can be used as the topic names in the heatmap.

In [None]:
# Parse the manually curated file. Each line is "<topic>\t<term>\t<term>...",
# possibly with empty cells. Condense each line to
# "<topic>\t<term1> | <term2> ...\n".
with open(work_dir / 'fig4_3_topic_heatmap_seaborn_order_modified.txt', 'r') as f:
  topic_modified = []
  for line in f:
    elements = line.strip().split("\t")
    # Skip blank lines; they carry no topic and would yield malformed entries
    if elements == [""]:
      continue
    # join() avoids trimming a trailing " | ", which also chopped the topic id
    # when a line had no terms
    kept = " | ".join(element for element in elements[1:] if element != "")
    topic_modified.append(f"{elements[0]}\t{kept}\n")


In [None]:
# Save the condensed topic names; every entry already ends with a newline
with open(work_dir / 'fig4_3_topic_heatmap_seaborn_order_condensed.txt', 'w') as fh:
  fh.writelines(topic_modified)

### Revisit topic size plot using the modified topic names

In [None]:
# Preview topic_info before adding the condensed names
topic_info.head(2)

In [None]:
# Ordered based on the index: topic t goes to position t+1, so the outlier (-1)
# lands at position 0
topic_modified_ordered = [""] * len(topic_modified)
for entry in topic_modified:
  topic_id, topic_name = entry.strip().split('\t')
  topic_modified_ordered[int(topic_id) + 1] = f"{topic_id}: {topic_name}"
topic_modified_ordered[:5]

In [None]:
# Position 0 holds the outlier topic (-1); give it a readable name
topic_modified_ordered[0] = "OUTLIER"

In [None]:
# Look up each row's condensed name by its topic id (position = topic + 1,
# outlier at 0). Do not rely on topic_info rows already being sorted by id.
topic_info['Modified'] = [topic_modified_ordered[topic + 1]
                          for topic in topic_info['Topic']]
topic_info.head(2)

In [None]:
# Sort ascending by document count, in place; the bar charts below use this order
topic_info.sort_values('Count', inplace=True)
topic_info.head(2)

In [None]:
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/barh.html

# Do not plot outliers. Filter the outlier topic (-1) by id instead of dropping
# the last row, which only works if the outlier is the largest topic. Rows keep
# the ascending Count order from the sort above.
topic_info_plot = topic_info[topic_info["Topic"] != -1]
fig, ax = plt.subplots(figsize=(6,16))
y_pos = np.arange(topic_info_plot.shape[0])
ax.barh(y_pos, topic_info_plot['Count'], align='center')
ax.set_yticks(y_pos, labels=topic_info_plot['Modified'])
ax.set_xlabel("Number of documents")
plt.savefig(work_dir / "fig4_3_number_docs_per_topic.pdf")
plt.show()

In [None]:
# Plot top 30: the top_n largest non-outlier topics (rows are sorted ascending
# by Count, so they are the last rows once the outlier is filtered out by id)
top_n      = 30
top_topics = topic_info[topic_info["Topic"] != -1].tail(top_n)
fig, ax = plt.subplots(figsize=(6,6))
y_pos = np.arange(top_topics.shape[0])
ax.barh(y_pos, top_topics['Count'], align='center')
ax.set_yticks(y_pos, labels=top_topics['Modified'])
ax.set_xlabel("Number of documents")
plt.savefig(work_dir / f"fig4_3_number_docs_per_topic_top{top_n}.pdf")
plt.show()

In [None]:
# Document counts of the top_n largest non-outlier topics
topic_info.loc[topic_info["Topic"] != -1, 'Count'].tail(top_n)

In [None]:
# Inspect the row for topic 72
topic_info[topic_info["Topic"] == 72]

## ___Compare topics___

### Topic correlation scatter plot

The c-Tf-Idf values are multiplied by 1,000 so it is easier to read.

In [None]:
def topic_pair_scatter(topic_term_matrix, topic_pair, top_50, out_file=None,
                       t_annotate=6, ctfidf_max=12):
  '''Generate and save a scatter plot of the top50 c-Tf-Idfs of a topic pair

  Each point is one term from the union of the two topics' top-50 terms. It is
  placed at (c-Tf-Idf x1,000 in topic1, c-Tf-Idf x1,000 in topic2), with a red
  dashed y = x reference line. Terms at or above t_annotate in either topic
  are labelled. Terms beyond ctfidf_max are pinned to the plot edge with a
  purple arrow and labelled with their actual values.

  Args:
   topic_term_matrix (csr): A sparse matrix returned from topic_model.c_tf_idf.
     Row 0 is the outlier topic (-1), so topic t is in row t+1.
   topic_pair (list): a pair of topic indices (0, ..., 90) in a list; both must
     be keys of top_50 (the outlier topic -1 is not)
   top_50 (dict): {topic:[t50_idx, t50_val, t50_fea]}
   out_file (str or PathLike, optional): output pdf name. If None, or not a
     path (older call sites passed the annotation threshold in this position),
     it defaults to work_dir / "fig4_3_topic_pair_scatter_{topic1}_{topic2}.pdf".
   t_annotate (float): threshold c-Tf-Idfx1000 values to show feature annotation
   ctfidf_max (float): upper limit of both axes, in c-Tf-Idf x1000
  Output:
    out_file (pdf): the scatter plot.
  '''
  import os

  [topic1, topic2] = topic_pair
  idx1, _, feat1 = top_50[topic1]
  idx2, _, feat2 = top_50[topic2]

  # Map term index -> feature name for the union of both top-50 lists. Topic 1
  # entries are applied last so they take precedence for shared indices.
  idx_to_feat = dict(zip(idx2, feat2))
  idx_to_feat.update(zip(idx1, feat1))
  top_50_both = sorted(idx_to_feat)
  top_50_both_feats = [idx_to_feat[idx] for idx in top_50_both]

  # Get all feature ctfidf values. topic_term_matrix includes the -1 topic, so
  # topic t is in row t+1, hence the +1 in getrow.
  row1 = topic_term_matrix.getrow(topic1+1).toarray().ravel()
  row2 = topic_term_matrix.getrow(topic2+1).toarray().ravel()

  ctfidf1 = row1[top_50_both]*1e3
  ctfidf2 = row2[top_50_both]*1e3

  plt.figure(figsize=(5,5))
  plt.scatter(ctfidf1, ctfidf2)
  plt.xlabel(f"topic {topic1} c-Tf-Idf (x1,000)")
  plt.ylabel(f"topic {topic2} c-Tf-Idf (x1,000)")
  plt.plot([0, ctfidf_max], [0, ctfidf_max], 'ro--')
  plt.xlim(0, ctfidf_max)
  plt.ylim(0, ctfidf_max)

  for label, x, y in zip(top_50_both_feats, ctfidf1, ctfidf2):
    # Only annotate terms that are prominent in at least one topic
    if x < t_annotate and y < t_annotate:
      continue
    # Points within the plot area are labelled in place
    if x <= ctfidf_max and y <= ctfidf_max:
      plt.annotate(label, (x, y), fontsize=6)
      continue
    # Out-of-range points: pin to the edge, draw an arrow, show actual values
    new_x = min(x, ctfidf_max)
    new_y = min(y, ctfidf_max)
    plt.arrow(new_x-0.6, new_y-0.6, 0.5, 0.5,
              width=0.05, head_width=0.2, ec="purple")
    plt.annotate(f"{label}({x:.1f},{y:.1f})", (new_x, new_y), fontsize=6)

  if out_file is None or not isinstance(out_file, (str, os.PathLike)):
    out_file = work_dir / f"fig4_3_topic_pair_scatter_{topic1}_{topic2}.pdf"
  plt.savefig(out_file)
  


### Illustrate similarities between topics in a "super-cluster"

- Topic 74, 75: host, larve, herbivore, pest, host
- Topic 4, 44: uvb, stress, light, leave, co
- Topic 25, 26, 27

In [None]:
# Topic pairs within the same "super-cluster"
pair1, pair2, pair3, pair4, pair5 = (
  [74, 75], [4, 44], [25, 26], [25, 27], [26, 27])

In [None]:
# Pass the threshold by keyword; the 4th positional parameter is out_file
# (None = default path)
topic_pair_scatter(topic_term_matrix, pair1, top_50, None, t_annotate=5.5)

In [None]:
# Pass the threshold by keyword; the 4th positional parameter is out_file
# (None = default path)
for pair in (pair2, pair3, pair4, pair5):
  topic_pair_scatter(topic_term_matrix, pair, top_50, None, t_annotate=5.5)

In [None]:
pair11 = [86, 88]
# Pass the threshold by keyword; the 4th positional parameter is out_file
# (None = default path)
topic_pair_scatter(topic_term_matrix, pair11, top_50, None, t_annotate=5.5)

### Illustrate similarities between topics not in a "super-cluster"

- Topic 25, 6: auxin signaling, microscopy
- Topic 22, 15: lipid, oil
- Topic 4, 24: uvb stress, ros, metabolism
- Topic 44, 52: light, leaves, co, electron, light
- Topic 44, 51: 51 and 52 are highly similar, but not between 44 and 51

In [None]:
pair6 = [25, 6]
pair7 = [22, 15]
pair8 = [4, 24]
pair9 = [44, 52]
pair10 = [44, 51]

In [None]:
# Pass the threshold by keyword; the 4th positional parameter is out_file
# (None = default path)
for pair in (pair6, pair7, pair8, pair9, pair10):
  topic_pair_scatter(topic_term_matrix, pair, top_50, None, t_annotate=5.5)