# __Step 5.3: Species, topic, and time__

Questions:
- What taxa tend to be worked on in a topic?
- Are there different focal species for a topic over time?

To do:
- Taxa over/under-represented in a topic
  - Focus on the genus level, top 100
  - Must include bioenergy related taxa: Populus, Sorghum, Panicum
- Taxa over/under-represented in a topic/time bin
  - Focus on a few example genera:
    - Arabidopsis, Solanum, Oryza, Zea, Populus, Sorghum, Panicum

Other thoughts:
- Analyze an emerging field, e.g., bioenergy research
  - Keywords: bioenergy/biofuel/feedstock
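
One way to pull out records for an emerging field is a case-insensitive keyword match on the text columns. The following is a minimal sketch on made-up records; the column names `Title` and `Abstract` follow the corpus table used in this notebook, while `toy`, `kw_pattern`, and `bioenergy_df` are hypothetical names introduced only for illustration.

```python
import pandas as pd

# Toy records standing in for the corpus (made-up data)
toy = pd.DataFrame({
    "Title": ["Switchgrass as a bioenergy feedstock",
              "Photosystem II fluorescence decay",
              "Poplar lignin and biofuel yield"],
    "Abstract": ["Panicum virgatum biomass",
                 "Spinach chloroplasts",
                 "Populus cell walls"],
})

# Keywords from the note above, joined into one regular expression
kw_pattern = "bioenergy|biofuel|feedstock"

# Match against title and abstract together, ignoring case
text = toy["Title"] + " " + toy["Abstract"]
is_bioenergy = text.str.contains(kw_pattern, case=False, regex=True)

bioenergy_df = toy[is_bioenergy]
print(bioenergy_df["Title"].tolist())
```

The boolean mask `is_bioenergy` can be reused to subset any array that is aligned with the records.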

## ___Set up___

### Module import

In [1]:
import pickle, itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from scipy.stats import fisher_exact
from multiprocessing.pool import Pool
from functools import partial

### Key variables

In [2]:
# Reproducibility
seed = 20220609

# Setting working directory
proj_dir   = Path.home() / "projects/plant_sci_hist"
work_dir   = proj_dir / "5_species_over_time/5_3_sp_topic_time"
work_dir.mkdir(parents=True, exist_ok=True)

# topic assignment
dir42             = proj_dir / "4_topic_model/4_2_outlier_assign"
file_topic_assign = dir42 / "table4_2_corpus_with_topic_assignment.tsv.gz"

# topic name
dir44             = proj_dir / "4_topic_model/4_4_over_time"
file_topic_name   = dir44 / "fig4_4_tot_heatmap_weighted_xscaled_names.txt"

# species-time analysis folder
dir51 = proj_dir / "5_species_over_time/5_1_sp_time"
# taxa count sparse matrices and corresponding taxa names
file_csr_fam      = dir51 / "match_csr_family.pickle"
file_csr_fam_nm   = dir51 / "match_csr_family_names.pickle"
file_csr_genus    = dir51 / "match_csr_genus.pickle"
file_csr_genus_nm = dir51 / "match_csr_genus_names.pickle"
# taxa count time series
file_ts_genus     = dir51 / "Table5_1_ts_genusALL_count.txt"
#file_ts_fam       = dir51 / "Table5_1_ts_familyALL_count.txt"

# Embed TrueType (Type 42) fonts so that text in saved PDFs stays editable
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams["font.family"] = "sans-serif"

## ___Process topic data___

### Read topic assignment

In [3]:
#https://stackoverflow.com/questions/35101093/load-directly-gz-file-into-pandas-dataframe
#https://www.delftstack.com/howto/python-pandas/pandas-read-gz-file/
#https://stackoverflow.com/questions/36519086/how-to-get-rid-of-unnamed-0-column-in-a-pandas-dataframe-read-in-from-csv-fil

# topic data-frame
tdf = pd.read_csv(file_topic_assign, sep='\t', compression='gzip', index_col=[0])
tdf.shape

(421658, 12)

In [4]:
tdf.head(1)

| Unnamed: 0 | Index_1385417 | PMID | Date | Journal | Title | Abstract | Initial filter qualifier | Corpus | reg_article | Text classification score | Preprocessed corpus | Topic |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 3 | 61 | 1975-12-11 | Biochimica et biophysica acta | Identification of the 120 mus phase in the dec... | After a 500 mus laser flash a 120 mus phase in... | spinach | Identification of the 120 mus phase in the dec... | 1 | 0.716394 | identification 120 mus phase decay delayed flu... | 52 |


In [5]:
toc_array = tdf['Topic'].values
type(toc_array), toc_array.shape

(numpy.ndarray, (421658,))

In [6]:
# map True/False to 1/0
#https://stackoverflow.com/questions/17383094/how-can-i-map-true-false-to-1-0-in-a-pandas-dataframe

toc0 = (toc_array==0).astype(int)
len(toc0), sum(toc0)

(421658, 895)

### Set key numbers for topics

In [7]:
# topic indices
tocs = np.unique(toc_array)

# exclude topic=-1
tocs_90 = tocs[1:]

# number of topic=-1
n_rec_toc_unassigned = sum((toc_array==-1).astype(int))

# Total number of docs. Originally considered subtracting the unassigned docs,
# but the total for taxa is the number of all docs, so it does not make sense
# to remove the unassigned ones.
n_rec_total  = len(toc_array)

n_rec_toc_unassigned, n_rec_total 

(49228, 421658)
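
The slice `tocs[1:]` works because `np.unique` returns sorted values and -1 is the smallest label. As a minimal sketch on made-up labels, the same numbers can be obtained without relying on the position of -1; `toy_toc_array` and the other `toy_*` names are hypothetical.

```python
import numpy as np

# Toy topic assignments; -1 marks records without a topic (made-up data)
toy_toc_array = np.array([-1, 0, 2, 2, -1, 1, 0, 2])

# Unique topic labels and the number of records for each
toy_tocs, toy_counts = np.unique(toy_toc_array, return_counts=True)

# Drop the unassigned label explicitly instead of relying on its position
assigned_mask = toy_tocs != -1
toy_tocs_assigned = toy_tocs[assigned_mask]
toy_n_unassigned = int(toy_counts[~assigned_mask].sum())

print(toy_tocs_assigned, toy_n_unassigned, len(toy_toc_array))
```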

### Read topic names

In [8]:
toc_names = pd.read_csv(file_topic_name, sep='\t')
toc_names.head(2)

Unnamed: 0,Topic,Mod_name
0,22,enzyme | fatty acids | lipid | synthesis
1,18,protein | dna | rna | synthesis | mrna


## ___Species representation among topics___

|              | In topic T | Not in topic T |
| ---          | ---        | ---            |
|In taxa X     | a          | b              |
|Not in taxa X | c          | d              |
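
As a worked example of the table above, the sketch below fills a, b, c, and d from two made-up sets of record indices and runs Fisher's exact test. The sets `in_T` and `in_X` and the total of 20 records are invented for illustration. The test code in this notebook builds the transposed table (topic on rows), which gives the same p-value.

```python
from scipy.stats import fisher_exact

# Toy record indices (made-up): records in topic T and records mentioning taxon X
n_total = 20
in_T = {0, 1, 2, 3, 4, 5}
in_X = {3, 4, 5, 6}

a = len(in_T & in_X)          # in taxon X, in topic T
b = len(in_X) - a             # in taxon X, not in topic T
c = len(in_T) - a             # not in taxon X, in topic T
d = n_total - a - b - c       # in neither

odds_ratio, p_value = fisher_exact([[a, b], [c, d]], alternative="two-sided")

# Transposed layout (topic on rows), as used by fet() in this notebook
_, p_transposed = fisher_exact([[a, c], [b, d]], alternative="two-sided")

# Odds ratio > 1 suggests over-representation of X in T, < 1 under-representation
print((a, b, c, d), odds_ratio, p_value, p_transposed)
```

Because the two-sided p-value alone does not give the direction, keeping the odds ratio alongside it is one way to separate over- from under-representation.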


### Function for dealing with matching non-zero values between two arrays

- `toc_array` is a numpy array. 
- `csr_genus` is a sparse matrix. 
  - Cannot easily convert a `csr_genus` column (like `csr_genus[:,0]`) into an array because it is still a sparse matrix of dimension (421658, 1). After `toarray()` it stays two-dimensional, and combining it with the 1-D `toc_array` can broadcast to a (421658, 421658) dense array. Not good.
- Work on non-zero indices instead.
- Got ideas from the following posts:
  - [access row/col of a csr](https://stackoverflow.com/questions/25310760/access-a-particular-row-column-in-a-csr-matrix)
  - [boolean to int array](https://stackoverflow.com/questions/16869990/how-to-convert-from-boolean-array-to-int-array-in-python)
  - [Flipping 0 and 1 in an array](https://stackoverflow.com/questions/26890477/flipping-zeroes-and-ones-in-one-dimensional-numpy-array)
  - [Convert a row matrix to a numpy array](https://stackoverflow.com/questions/38405047/how-to-convert-a-scipy-row-matrix-into-a-numpy-array)
  - [Boolean array, index of true](https://stackoverflow.com/questions/36941294/find-the-index-of-a-boolean-array-whose-values-are-true)
  - [non-zero values in sparse array](https://stackoverflow.com/questions/40984516/most-efficient-way-of-accessing-non-zero-values-in-row-column-in-scipy-sparse-ma)
  - [iterate through sparse matrix](https://stackoverflow.com/questions/4319014/iterating-through-a-scipy-sparse-vector-or-matrix)
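
The idea of working on non-zero indices can be sketched on a small made-up matrix: take the row indices where a taxon column is non-zero, take the row indices where the topic label matches, and intersect the two sets. `toy_mat` and `toy_toc_array` are hypothetical stand-ins for `csr_genus` and `toc_array`.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy record-by-taxon count matrix (made-up): 5 records, 3 taxa
toy_mat = csr_matrix(np.array([[0, 2, 0],
                               [1, 0, 0],
                               [0, 0, 0],
                               [3, 1, 0],
                               [0, 0, 4]]))
toy_toc_array = np.array([7, 7, -1, 8, 8])

# Row indices of the records that mention taxon 0, without densifying
col = toy_mat[:, 0]
tax_rows = set(col.nonzero()[0].tolist())

# Row indices of the records assigned to topic 7
toc_rows = set(np.flatnonzero(toy_toc_array == 7).tolist())

# Records that are both in topic 7 and mention taxon 0
both_rows = tax_rows & toc_rows
print(sorted(tax_rows), sorted(toc_rows), sorted(both_rows))
```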


In [9]:
##https://stackoverflow.com/questions/13070461/get-indices-of-the-top-n-values-of-a-list
def get_topX(mat_taxa, topX):
  '''Get the topX taxa based on total counts
  Args:
   mat_taxa (csr): record (rows), taxa_idx (column)
   topX (int): top X taxa with the highest number of records
  Return:
   mat_taxa_topX (csr): sub-matrix of the topX taxa
  '''

  # Sum the matrix passed in (not the global csr_genus) over records
  col_sum  = np.squeeze(np.asarray(mat_taxa.sum(axis=0)))
  topX_idx = np.argpartition(col_sum, -topX)[-topX:]

  mat_taxa_topX = mat_taxa[:, topX_idx]

  return mat_taxa_topX

In [10]:
def fet(tax_idx):
  '''Function for calculating Fisher's exact test statistics
  Relies on the globals mat_taxa, toc_in, n_inT, and n_total defined by the
  calling cell.
  Args:
    tax_idx (int): taxa index, i.e., the column index in mat_taxa
  Return:
    tax_idx (int): returned so the caller of the multiprocessing Pool knows
      which taxon the result belongs to.
    res (tuple): a tuple from fisher_exact() with (FET_statistic, p-value)
  '''

  # csr_matrix with info on whether a record mentions a taxon. This has a
  # dimension of (421658,1).
  # convert to a coo matrix for more efficient operation next
  tax_in_mat = mat_taxa[:,tax_idx].tocoo()

  # Get non-zero row indices for a taxon
  tax_in = set(tax_in_mat.row)
  
  # Get numbers for fisher test
  # in: in a toc (T) or taxa (X)
  # ni: not in T or X
  n_inX = len(tax_in)
  n_inT_inX = len(toc_in.intersection(tax_in))
  n_inT_niX = n_inT - n_inT_inX
  n_niT_inX = n_inX - n_inT_inX
  n_niT_niX = n_total - n_inT_inX - n_inT_niX - n_niT_inX
  res = fisher_exact([[n_inT_inX, n_inT_niX], 
                      [n_niT_inX, n_niT_niX]], alternative='two-sided')
  
  return tax_idx, res

### Genus level

- Genus count matrix: (421658, 16794)
  - row: each record
  - col: each taxa
- Genus name list: 16794

#### Read genus count matrix and genus names

In [11]:
with open(file_csr_genus, "rb") as f:
  csr_genus = pickle.load(f)

with open(file_csr_genus_nm, "rb") as f:
  csr_genus_nm = pickle.load(f)

csr_genus.shape, len(csr_genus_nm)

((421658, 16794), 16794)

#### Genus top 100

In [12]:
# Get top 100 genus matrix
topX = 100
mat_genus_topX = get_topX(csr_genus, topX)
mat_genus_topX.shape

(421658, 100)
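
The to-do list asks that Populus, Sorghum, and Panicum be included, but a top-100 cut by count does not guarantee that. A minimal sketch of one way to handle this is to take the union of the top-X column indices and the indices of the required genera. It assumes the name list is aligned with the matrix columns; `get_topX_with_required`, `toy_mat`, and `toy_names` are hypothetical and not part of the original analysis.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Toy data (made-up): 4 records x 5 genera, with names aligned to columns
toy_names = ["Arabidopsis", "Oryza", "Populus", "Zea", "Sorghum"]
toy_mat = csr_matrix(np.array([[5, 1, 0, 2, 0],
                               [3, 0, 1, 0, 0],
                               [4, 2, 0, 1, 0],
                               [1, 3, 0, 0, 1]]))

def get_topX_with_required(mat_taxa, names, topX, required):
    """Column indices of the topX taxa by total count plus required names."""
    col_sum = np.squeeze(np.asarray(mat_taxa.sum(axis=0)))
    top_idx = np.argpartition(col_sum, -topX)[-topX:]
    name_to_idx = {name: i for i, name in enumerate(names)}
    # Names absent from the list are skipped rather than raising an error
    req_idx = [name_to_idx[n] for n in required if n in name_to_idx]
    keep_idx = np.union1d(top_idx, req_idx).astype(int)
    return keep_idx, [names[i] for i in keep_idx]

keep_idx, keep_names = get_topX_with_required(
    toy_mat, toy_names, topX=2, required=["Populus", "Sorghum", "Panicum"])
print(keep_names)
```

The sub-matrix would then be `toy_mat[:, keep_idx]`, with `keep_names` giving the matching column labels.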

In [13]:
# Get counts for topic/taxa combo
mat_taxa = mat_genus_topX
num_taxa = mat_taxa.shape[1]

toc_tax_pvalue = {} # {topic_idx:{taxa_idx:fisher exact test pvalue}}
n_total        = len(toc_array) # total number of records
# iterate through tocs
for toc_idx in tqdm(tocs_90):

  # indices of True (i.e., in a topic); nonzero() returns a tuple of arrays
  toc_in  = set((toc_array==toc_idx).nonzero()[0])

  # number of records in a topic
  n_inT   = len(toc_in)
  #print(f"topic:{toc_idx}, n_records:{n_inT}")

  toc_tax_pvalue[toc_idx] = {}
  
  with Pool() as pool:
    for tax_idx, res in pool.imap(fet, range(num_taxa)):
      # store pvalue
      toc_tax_pvalue[toc_idx][tax_idx] = res[1]


100%|██████████| 90/90 [00:40<00:00,  2.23it/s]


In [14]:
file_toc_tax_pval = work_dir / 'toc_genus_pvalue.pickle'
with open(file_toc_tax_pval, 'wb') as f:
  pickle.dump(toc_tax_pvalue, f)
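
For downstream use, the nested dictionary `{topic_idx:{taxa_idx:pvalue}}` can be turned into a topic-by-taxa table. The sketch below uses made-up p-values; `toy_pvalues`, `pval_df`, and `pval_long` are hypothetical names.

```python
import pandas as pd

# Toy nested dictionary (made-up p-values) with the same layout as
# toc_tax_pvalue: {topic_idx: {taxa_idx: p-value}}
toy_pvalues = {0: {0: 0.50, 1: 1e-8},
               1: {0: 0.03, 1: 0.90}}

# Wide format: rows are topics, columns are taxa indices
pval_df = pd.DataFrame.from_dict(toy_pvalues, orient="index")
pval_df.index.name = "topic"

# Long format: one row per topic/taxon pair
pval_long = pval_df.reset_index().melt(
    id_vars="topic", var_name="taxa_idx", value_name="pvalue")
print(pval_df.shape, len(pval_long))
```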

## ___Species/topic representation over time___

## ___Testing___

In [None]:
g = csr_genus[:,0].indices
type(g), g.shape, g[:100]

In [None]:
help(g)

### Multiprocessing implementation 1

Fast (11 secs/topic)

### Multiprocessing implementation 2

This allows multiple arguments to be passed, but it is somehow really slow, even compared to the non-parallelized version.

In [None]:
def fet(tax_idx, fet_args):
  '''Construct 2x2 and do fisher exact test
  Args:
    tax_idx (int): index for taxa
    fet_args (list): [mat_taxa, toc_in, n_inT] where:
      mat_taxa (csr): the matrix with taxa count info
      toc_in (set): the indices of records in a topic
      n_inT (int): number of records in a topic
  Return:
    tax_idx (int): taxa index, returned because this is called with imap() and
      the caller needs to know which taxon the result belongs to.
    res (tuple): Fisher exact test [statistic, p-value]
  '''

  # Unpack the bundled arguments
  [mat_taxa, toc_in, n_inT] = fet_args

  # csr_matrix with info on whether a record mentions a taxon. This has a
  # dimension of (421658,1).
  # convert to a coo matrix for more efficient operation next
  tax_in_mat = mat_taxa[:,tax_idx].tocoo()

  # Get non-zero row indices for a taxon
  tax_in = set(tax_in_mat.row)
  
  # Get numbers for fisher test
  # in: in a toc (T) or taxa (X)
  # ni: not in T or X
  n_inX = len(tax_in)
  n_inT_inX = len(list(toc_in.intersection(tax_in)))
  n_inT_niX = n_inT - n_inT_inX
  n_niT_inX = n_inX - n_inT_inX
  n_niT_niX = n_total - n_inT_inX - n_inT_niX - n_niT_inX
  res = fisher_exact([[n_inT_inX, n_inT_niX], 
                      [n_niT_inX, n_niT_niX]], alternative='two-sided')
  
  return tax_idx, res

In [None]:
def get_counts(mat_taxa):
  '''Get counts for topic/taxa combo
  Args:
    mat_taxa (csr): Sparse matrix for whether a taxon (column) is present in a
      record (row)
  Return:
    toc_tax_pvalue (dict): {topic_idx:{taxa_idx:fisher exact test pvalue}}
  '''
  num_taxa = mat_taxa.shape[1]

  toc_tax_pvalue = {} # {topic_idx:{taxa_idx:fisher exact test pvalue}}

  # iterate through tocs
  for toc_idx in tocs_90:
    # indices of True (i.e., in a topic); nonzero() returns a tuple of arrays
    toc_in   = set((toc_array==toc_idx).nonzero()[0])
    n_inT = len(toc_in)
    print(f"topic:{toc_idx}, n_records:{n_inT}")

    toc_tax_pvalue[toc_idx] = {}
    
    # rest of the arguments for fet()
    fet_args = [mat_taxa, toc_in, n_inT]
    with Pool() as pool:
      for tax_idx, res in tqdm(
          pool.imap(partial(fet, fet_args=fet_args), range(num_taxa)), 
          total=num_taxa):
        # store pvalue
        toc_tax_pvalue[toc_idx][tax_idx] = res[1]
    break

  return toc_tax_pvalue
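
One plausible explanation for the slowness of implementation 2, offered here as an untested assumption, is that `partial` binds the large arguments to the function and that bundle is serialized for each task sent to a worker. The sketch below does not start a pool. It only compares the pickled size of a made-up argument bundle with the size of a bare task index; all `toy_*` names and sizes are hypothetical.

```python
import pickle
import numpy as np
from scipy.sparse import random as sparse_random

# Toy stand-ins (made-up sizes) for the arguments bundled by partial()
toy_mat = sparse_random(10_000, 100, density=0.01, format="csr", random_state=0)
toy_toc_in = set(range(0, 10_000, 3))
toy_fet_args = [toy_mat, toy_toc_in, len(toy_toc_in)]

# Bytes needed to serialize the bundled arguments once
args_bytes = len(pickle.dumps(toy_fet_args))
# Bytes needed to serialize a bare task index
idx_bytes = len(pickle.dumps(5))

# If the bundle travels with every task, the cost scales with the task count
n_tasks = toy_mat.shape[1]
print(args_bytes, idx_bytes, args_bytes * n_tasks)
```

If this is indeed the cause, it would be consistent with implementation 1 being fast, since that version passes only the taxa index and reads everything else from globals.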


### Multiprocessing implementation 3

In [None]:
def fet(toc_idx):

  # indices of True (i.e., in a topic); nonzero() returns a tuple of arrays
  toc_in   = set((toc_array==toc_idx).nonzero()[0])
  # number of records in a topic
  n_inT    = len(toc_in)
  #print(f"topic:{toc_idx}, n_records:{n_inT}")

  # Store p-value for each taxa
  tax_pvalues = {} #{taxa_idx:fisher exact test pvalue}

  num_taxa = mat_taxa.shape[1]
  for tax_idx in range(num_taxa):
    # csr_matrix with info on whether a record mention a taxa. This has a
    # dimension of (421658,1).
    # convert to a coo matrix for more efficient operation next
    tax_in_mat = mat_taxa[:,tax_idx].tocoo()

    # Get non-zero row indices for a taxa
    tax_in = []
    for r, c in zip(tax_in_mat.row, tax_in_mat.col):
      tax_in.append(r)
    tax_in = set(tax_in)
    
    # Get numbers for fisher test
    # in: in a toc (T) or taxa (X)
    # ni: not in T or X
    n_inX = len(tax_in)
    n_inT_inX = len(list(toc_in.intersection(tax_in)))
    n_inT_niX = n_inT - n_inT_inX
    n_niT_inX = n_inX - n_inT_inX
    n_niT_niX = n_total - n_inT_inX - n_inT_niX - n_niT_inX

    res = fisher_exact([[n_inT_inX, n_inT_niX], 
                        [n_niT_inX, n_niT_niX]], alternative='two-sided')
    # store pvalue
    tax_pvalues[tax_idx] = res[1]
  
  return toc_idx, tax_pvalues

In [None]:
# Get counts for topic/taxa combo
mat_taxa        = csr_genus
toc_tax_pvalues = {} # {topic_idx:{taxa_idx:fisher exact test pvalue}}

# iterate through tocs
with Pool() as pool:
  for toc_idx, tax_pvalues in tqdm(pool.imap(fet, tocs_90), total=len(tocs_90)):
    print("done:", toc_idx)
    toc_tax_pvalues[toc_idx] = tax_pvalues



### Non-parallelized version

Takes 1 min 27 sec for just one topic.

In [None]:
# Get counts for topic/taxa combo
mat_taxa = csr_genus

toc_tax_pvalue = {} # {topic_idx:{taxa_idx:fisher exact test pvalue}}

# iterate through tocs
for toc_idx in tocs_90:
  # indices of True (i.e., in a topic); nonzero() returns a tuple of arrays
  toc_in   = set((toc_array==toc_idx).nonzero()[0])
  n_inT = len(toc_in)
  print(f"topic:{toc_idx}, n_records:{n_inT}")

  toc_tax_pvalue[toc_idx] = {}
  
  for tax_idx in tqdm(range(mat_taxa.shape[1])):
    # csr_matrix with info on whether a record mention a taxa. This has a
    # dimension of (421658,1).
    # convert to a coo matrix for more efficient operation next
    tax_in_mat = mat_taxa[:,tax_idx].tocoo()

    # Get non-zero row indices for a taxa
    tax_in = []
    for r, c in zip(tax_in_mat.row, tax_in_mat.col):
      tax_in.append(r)
    tax_in = set(tax_in)
    
    # Get numbers for fisher test
    # in: in a toc (T) or taxa (X)
    # ni: not in T or X
    n_inX = len(tax_in)
    n_inT_inX = len(list(toc_in.intersection(tax_in)))
    n_inT_niX = n_inT - n_inT_inX
    n_niT_inX = n_inX - n_inT_inX
    n_niT_niX = n_total - n_inT_inX - n_inT_niX - n_niT_inX
    res = fisher_exact([[n_inT_inX, n_inT_niX], 
                        [n_niT_inX, n_niT_niX]], alternative='two-sided')
    
    # store pvalue
    toc_tax_pvalue[toc_idx][tax_idx] = res[1]

  break
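
As an alternative to intersecting sets one taxon at a time, the topic/taxon overlap counts for all taxa can be obtained with a single sparse matrix-vector product. This is a different technique from the loops in this notebook, shown as a minimal sketch on made-up data; `toy_mat`, `toy_toc_array`, and `presence` are hypothetical names, and the counts are first reduced to presence/absence so that sums count records.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.stats import fisher_exact

# Toy data (made-up): 6 records x 3 taxa, and a topic label per record
toy_mat = csr_matrix(np.array([[2, 0, 0],
                               [1, 1, 0],
                               [0, 0, 0],
                               [0, 3, 0],
                               [1, 0, 1],
                               [0, 0, 0]]))
toy_toc_array = np.array([4, 4, 4, 9, 9, -1])
n_total = toy_mat.shape[0]

# Presence/absence (counts > 0 become 1) so sums count records, not mentions
presence = (toy_mat > 0).astype(np.int64)
n_inX_all = np.asarray(presence.sum(axis=0)).ravel()   # records per taxon

toc_idx = 4
in_T = (toy_toc_array == toc_idx).astype(np.int64)     # 1 if record in topic
n_inT = int(in_T.sum())

# One sparse product gives the topic/taxon overlap for every taxon at once
n_inT_inX_all = presence.T @ in_T

pvalues = {}
for tax_idx, (n_both, n_inX) in enumerate(zip(n_inT_inX_all, n_inX_all)):
    n_inT_niX = n_inT - n_both
    n_niT_inX = n_inX - n_both
    n_niT_niX = n_total - n_both - n_inT_niX - n_niT_inX
    pvalues[tax_idx] = fisher_exact([[n_both, n_inT_niX],
                                     [n_niT_inX, n_niT_niX]])[1]
print(n_inT_inX_all.tolist(), pvalues)
```

Only the Fisher's exact test itself remains inside the Python loop; the counting is done by the sparse product.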


In [None]:
len((toc_array==toc_idx).nonzero()[0])

### Index of top x values in an array

- https://stackoverflow.com/questions/13070461/get-indices-of-the-top-n-values-of-a-list
- https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array/23734295#23734295


In [None]:
a = [5,3,1,4,10]

In [None]:
sorted(range(len(a)), key=lambda i: a[i])[-2:]

In [None]:
sorted(range(len(a)), key=lambda i: a[i], reverse=True)[:2]

In [None]:
# Returns the indices that would sort an array
np.argsort(a), np.argsort(a)[-2:]

In [None]:
# this is more efficient, but not sorted
np.argpartition(a, -2), np.argpartition(a, -2)[-2:]
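
Since `np.argpartition` returns the top indices in arbitrary order, a minimal sketch for getting them ordered is to sort only the selected indices afterwards. The variable names `k`, `top_idx`, and `top_idx_sorted` are introduced here for illustration.

```python
import numpy as np

a = np.array([5, 3, 1, 4, 10])
k = 2

# Indices of the k largest values, in arbitrary order
top_idx = np.argpartition(a, -k)[-k:]

# Order just those k indices from largest to smallest value
top_idx_sorted = top_idx[np.argsort(a[top_idx])[::-1]]
print(top_idx_sorted, a[top_idx_sorted])
```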