In [2]:
import copy
import biom
import numpy as np
import pandas as pd
import warnings
from typing import Callable
from scipy.spatial import distance
from skbio import (OrdinationResults,
                   DistanceMatrix)
from scipy.sparse.linalg import svds

from gemelli.optspace import svd_sort
from gemelli.ctf import ctf_table_processing
from gemelli.preprocessing import (build_sparse,
                                   matrix_rclr)
from gemelli._defaults import (DEFAULT_COMP, DEFAULT_MSC,
                               DEFAULT_MFC, DEFAULT_MFF,
                               DEFAULT_TEMPTED_PC,
                               DEFAULT_TEMPTED_EP,
                               DEFAULT_TEMPTED_SMTH,
                               DEFAULT_TEMPTED_RES,
                               DEFAULT_TEMPTED_MAXITER,
                               DEFAULT_TEMPTED_RH as DEFAULT_TRH,
                               DEFAULT_TEMPTED_RHC as DEFAULT_RC,
                               DEFAULT_TEMPTED_SVDC,
                               DEFAULT_TEMPTED_SVDCN as DEFAULT_TSCN)

from gemelli.tempted import (freg_rkhs, bernoulli_kernel)

In [13]:
def udpate_residuals(table_mods, a_hat, b_hats,
                     phi_hats, times, lambdas):
    '''
    Subtract the most recent rank-1 approximation from every
    table, i.e. calculate the residuals that the next component
    will be fitted on. The input is not modified; a deep copy
    is updated and returned.

    For individual i of a modality the residual is
        table_i - lambda * a_hat[i] * outer(b_hat, phi_hat[ti[i]])

    Parameters
    ----------
    table_mods: dictionary, required
        Tables for each modality
        keys = modality
        values = dictionary
            keys = individual_ids
            values = DataFrame
                rows = features
                columns = samples

    a_hat: np.narray, required
        Subject loadings. Entry i is applied to the i-th
        individual (in iteration order) of each modality.

    b_hats: dictionary, required
        Feature loadings
        keys = modality
        values = loadings

    phi_hats: dictionary, required
        Temporal loadings
        keys = modality
        values = loadings

    times: dictionary, required
        Time points for each modality
        keys = modality
        values = list of numpy.ndarray
            list[0] = time points within interval
            list[1] = individual indexes

    lambdas: dictionary, required
        Singular values
        keys = modality
        values = singular value of the component

    Returns
    ----------
    tables_update: dictionary
        Residuals for each modality, same layout
        as table_mods.
    '''

    tables_update = copy.deepcopy(table_mods)

    for modality, table_mod in tables_update.items():

        # get key modality-specific variables
        b_hat = np.asarray(b_hats[modality])
        phi_hat = phi_hats[modality]
        ti = times[modality][0]
        lambda_coeff = lambdas[modality]

        for i, individual_id in enumerate(table_mod):
            # temporal loadings at this individual's time points
            phi_ = phi_hat[ti[i]]
            # rank-1 reconstruction (features x time points); the outer
            # product is required because b_hat and phi_ have
            # different lengths
            rank_one = lambda_coeff * a_hat[i] * np.outer(b_hat, phi_)
            table_mod[individual_id] = table_mod[individual_id] - rank_one

    return tables_update

In [4]:
def update_lambda(individual_id_tables, ti, 
                  a_hat, phi_hat, b_hat):
    '''
    Re-estimate the singular value of one modality from the
    loadings of the most recent iteration.

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Tables of one modality, already subset to the time
        points within the previously defined interval.
        keys = individual_ids
        values = DataFrame, required
            rows = features
            columns = samples
    ti: list of int arrays, required
        Indexes into phi_hat for the time points of
        each individual
    a_hat: np.narray, required
        Subject loadings from the previous iteration
    phi_hat: np.narray, required
        Temporal loadings from the previous iteration
    b_hat: np.narray, required
        Feature loadings from the previous iteration

    Returns
    ----------
    lambda_new: float
        Updated singular value: the sum over individuals of
        a_i * (b' M_i phi_i) divided by the sum over
        individuals of sum((a_i * phi_i) ** 2).
    '''

    numerators, denominators = [], []

    for idx, table in enumerate(individual_id_tables.values()):
        # temporal loadings at the time points of this individual
        phi_sub = phi_hat[ti[idx]]
        weight = a_hat[idx]
        projection = b_hat.dot(table.values).dot(phi_sub)
        numerators.append(weight * projection)
        denominators.append(np.sum((weight * phi_sub) ** 2))

    return np.sum(numerators) / np.sum(denominators)

In [5]:
def update_a_mod(individual_id_tables, 
                 n_individuals, n_features,
                 b_mod, phi_mod, 
                 lambda_mod, ti):
    '''
    Compute the per-modality pieces needed to update the
    subject and feature loadings. The loadings themselves
    are not updated here; only numerators and denominators
    are returned.

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Tables of one modality, already subset to the time
        points within the previously defined interval.
        keys = individual_ids
        values = DataFrame, required
            rows = features
            columns = samples

    n_individuals: int, required
        Number of unique individuals/samples in modality

    n_features: int, required
        Number of unique features in modality

    b_mod: np.narray, required
        Feature loadings from a specific modality

    phi_mod: np.narray, required
        Temporal loadings from a specific modality

    lambda_mod: float, required
        Singular value from a specific modality

    ti: list of int arrays, required
        Indexes into phi_mod for the time points of
        each individual

    Returns
    ----------
    a_num: dictionary
        keys = individual_ids
        values = lambda * (b' M_i phi_i)
    a_denom: dictionary
        keys = individual_ids
        values = lambda ** 2 * sum(phi_i ** 2)
    b_num: np.ndarray, shape (n_features, n_individuals)
        Column i holds M_i phi_i
    common_denom: dictionary
        keys = individual_ids
        values = sum(phi_i ** 2)
    '''

    a_num, a_denom, common_denom = {}, {}, {}
    b_num = np.zeros((n_features, n_individuals))
    lambda_sq = lambda_mod ** 2

    for col, (individual_id, table) in enumerate(individual_id_tables.items()):
        values = table.values
        # temporal loadings at the time points of this individual
        phi_sub = phi_mod[ti[col]]
        phi_norm_sq = np.sum(phi_sub ** 2)

        # shared by the subject and the feature updates
        common_denom[individual_id] = phi_norm_sq
        # feature-update numerator: one column per individual
        b_num[:, col] = values.dot(phi_sub)
        # subject-update numerator and denominator
        a_num[individual_id] = lambda_mod*b_mod.dot(values).dot(phi_sub)
        a_denom[individual_id] = lambda_sq*phi_norm_sq

    return a_num, a_denom, b_num, common_denom

In [6]:
def initialize_tabular(individual_id_tables, 
                       n_individuals,
                       n_components=3):
    """
    Produce starting values for the feature and subject
    loadings of one modality.

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of tables constructed.
        (see build_sparse class)
        keys = individual_ids
        values = DataFrame, required
            rows = features
            columns = samples

    n_individuals: int, required
        Number of unique individuals/samples

    n_components: int, optional : Default is 3
        Number of singular vectors requested from the
        truncated SVD of the unfolded data.

    Returns
    ----------
    b_hat: np.ndarray
        Initial feature loadings: first column of the left
        singular vectors after svd_sort.
    a_hat: np.ndarray
        Initial subject loadings: constant vector of unit
        norm carrying the sign of sum(b_hat).
    """

    # unfold: put all per-individual tables side by side (features x all samples)
    unfolded = np.hstack([table.values
                          for table in individual_id_tables.values()])
    left, sing, right = svds(unfolded, k=n_components, which='LM')
    # NOTE(review): svd_sort presumably orders components by singular
    # value so that column 0 is the leading one -- confirm
    left, sing, right = svd_sort(left, np.diag(sing), right)
    b_hat = left[:, 0]

    # uniform subject loadings, sign-matched to the feature loadings
    sign = np.sign(np.sum(b_hat))
    a_hat = sign * (np.ones(n_individuals) / np.sqrt(n_individuals))

    return b_hat, a_hat

In [7]:
def decomposition_iter(table_mods, times, 
                       individual_id_lst, 
                       Kmats, Kmat_outputs,
                       maxiter=20, epsilon=1e-4,
                       smooth=1e-6, n_components=3):
    '''
    Fit one rank-1 component jointly across all modalities by
    alternating updates of the feature, temporal and subject
    loadings until convergence.

    Parameters
    ----------
    table_mods: dictionary, required
        Updated tables for each modality. Times are
        normalized and only points within the interval
        are kept.
        keys = modality
        values = dictionary
            keys = individual_ids
            values = DataFrame
                rows = features
                columns = samples

    times: dictionary, required
        Updated time points for each modality
        keys = modality
        values = list of numpy.ndarray
            list[0] = time points within interval
            list[1] = individual indexes

    individual_id_lst: list, required
        List of unique individual IDs. Defines the order
        of the entries of the returned a_hat.

    Kmats: dictionary, required
        Kernel matrix for each modality
        keys = modality
        values = numpy.ndarray
            rows, columns = time points

    Kmat_outputs: dictionary, required
        Bernoulli kernel matrix for each modality
        keys = modality
        values = numpy.ndarray
            rows = resolution
            columns = time points

    maxiter: int, optional : Default is 20
        Maximum number of iteration in for rank-1 calculation.
        The loop condition is ``t <= maxiter`` with t starting
        at 0, so up to maxiter + 1 passes are executed.

    epsilon: float, optional : Default is 0.0001
        Convergence criteria for difference between iterations
        for each rank-1 calculation.

    smooth: float, optional : Default is 1e-6
        Smoothing parameter passed to freg_rkhs

    n_components: int, optional : Default is 3
        Number of singular vectors requested when the
        loadings are initialized (see initialize_tabular).

    Returns
    ----------
    Rank-1 loadings
    a_hat: np.narray
        Subject loadings, shared across modalities

    b_hats: dictionary
        Feature loadings
        keys = modality
        values = loadings

    phi_hats: dictionary
        Temporal loadings
        keys = modality
        values = loadings

    lambdas: dictionary
        Singular values
        keys = modality
        values = singular value
    '''

    b_hats = {}
    phi_hats = {}
    lambdas = {}
    common_denom = {}
    b_num = {}

    #iterate until convergence
    t = 0
    dif = 1
    while t <= maxiter and dif > epsilon:

        #variables to save intermediate outputs
        a_num = {}
        a_denom = {}
        b_hat_difs = {}
        for modality in table_mods.keys():

            #get key modality-specific variables
            table_mod = table_mods[modality]
            ti, ind_vec = times[modality]
            Kmat = Kmats[modality]
            Kmat_output = Kmat_outputs[modality]
            n_individuals = len(table_mod)
            #tables are keyed by individual ID, so peek at the first value
            n_features = next(iter(table_mod.values())).shape[0]

            if t == 0:
                #initialize feature and subject loadings
                #(initialize_tabular unfolds the tables itself)
                b_hat, a_hat = initialize_tabular(table_mod,
                                                  n_individuals=n_individuals,
                                                  n_components=n_components)
                b_hats[modality] = b_hat
            else:
                #update feature loadings
                b_temp = b_num[modality]
                #per-individual denominators, in the same order as
                #the columns of b_num (table iteration order)
                denom_vec = np.array([common_denom[modality][key]
                                      for key in table_mod])
                b_new = b_temp.dot(a_hat) / denom_vec.dot(a_hat ** 2)
                b_hat = b_new / np.sqrt(np.sum(b_new ** 2))
                b_hat_difs[modality] = np.sum((b_hats[modality] - b_hat) ** 2)
                b_hats[modality] = b_hat

            #calculate state loadings
            Ly = [a_hat[i] * b_hat.dot(m.values)
                  for i, m in enumerate(table_mod.values())]
            phi_hat = freg_rkhs(Ly, a_hat, ind_vec, Kmat, Kmat_output, smooth=smooth)
            phi_hat = (phi_hat / np.sqrt(np.sum(phi_hat ** 2)))
            phi_hats[modality] = phi_hat
            #calculate lambda
            lambda_mod = update_lambda(table_mod, ti, a_hat, phi_hat, b_hat)
            lambdas[modality] = lambda_mod
            #begin updating subject and feature loadings
            (a_mod_num, a_mod_denom,
             b_mod_num, common_mod_denom) = update_a_mod(table_mod, n_individuals, n_features,
                                                         b_hat, phi_hat, lambda_mod, ti)
            #save intermediate b-hat variables
            b_num[modality] = b_mod_num
            common_denom[modality] = common_mod_denom
            #accumulate subject loading terms across modalities
            for key in a_mod_num:
                a_num[key] = a_mod_num[key] + a_num.get(key, 0)
            for key in a_mod_denom:
                a_denom[key] = a_mod_denom[key] + a_denom.get(key, 0)
        #update subject loadings (array, so it can be squared/normalized)
        a_tilde = np.array([a_num[ind_id] / a_denom[ind_id]
                            for ind_id in individual_id_lst])
        a_new = a_tilde / np.sqrt(np.sum(a_tilde ** 2))
        a_hat_dif = np.sum((a_hat - a_new) ** 2)
        a_hat = a_new
        #check for convergence
        dif = max([a_hat_dif]+list(b_hat_difs.values())) #or take mean of b_hat_difs?
        t += 1

    return a_hat, b_hats, phi_hats, lambdas

In [8]:
def format_time(individual_id_tables,
                individual_id_state_orders, 
                n_individuals, resolution, 
                input_time_range, interval=None):
    '''
    Normalize time points to be in the same format and keep
    only the defined interval (if defined). Inputs are not
    modified; deep copies are updated and returned.

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of tables constructed.
        (see build_sparse class)
        keys = individual_ids
        values = DataFrame, required
            rows = features
            columns = samples

    individual_id_state_orders : dict
        Dictionary of time points for each individual

    n_individuals: int, required
        Number of unique individuals/samples

    resolution: int, required
        Number of time points to evaluate the value
        of the temporal loading function.

    input_time_range: tuple, required
        Start and end time points for each individual.
        Used to rescale all time points to [0, 1].

    interval : tuple, optional
        Start and end time points to keep, on the original
        time scale. If None, the full input_time_range is used.

    Returns
    -------
    interval: tuple
        The interval rescaled with input_time_range

    tables_update: dictionary
        Copy of individual_id_tables where only the samples
        (columns) inside the interval are kept

    ti: list of numpy.ndarray
        Per subject, the time points within the interval
        mapped to integer grid indexes in [0, resolution - 1]
        (truncated with int())

    ind_vec: numpy.ndarray
        Subject indexes for each time point

    tm: numpy.ndarray
        Concatenated normalized time points for all
        subjects
    '''

    # make copy of tables to update
    tables_update = copy.deepcopy(individual_id_tables)
    orders_update = copy.deepcopy(individual_id_state_orders)

    range_start = input_time_range[0]
    range_width = input_time_range[1] - input_time_range[0]

    # no interval given: keep every time point
    if interval is None:
        interval = input_time_range

    # normalize time points
    for individual_id in orders_update.keys():
        orders_update[individual_id] = (np.asarray(orders_update[individual_id])
                                        - range_start) / range_width
    # ensure interval is in the same format; converting to an array
    # first makes plain tuples/lists of Python numbers work too
    interval = tuple((np.asarray(interval, dtype=float) - range_start)
                     / range_width)

    # initialize variables to store time points (tps)
    Lt = [] # all normalized tps
    ind_vec = [] #individual indexes for each tp
    ti = [[] for i in range(n_individuals)] # tps within interval per subject

    # populate variables above
    for i, (id_, time_range_i) in enumerate(orders_update.items()):
        # save all normalized time points
        Lt.append(time_range_i)
        ind_vec.extend([i] * len(Lt[-1]))
        # define time points within interval
        mask = (time_range_i >= interval[0]) & (time_range_i <= interval[1])
        temp = time_range_i[mask]
        temp = [(resolution-1)*(tp - interval[0])/(interval[1] - interval[0]) for tp in temp]
        ti[i] = np.array(list(map(int, temp)))
        # keep only the samples (columns) inside the interval
        tables_update[id_] = tables_update[id_].T[mask].T

    # convert variables to numpy arrays 
    ind_vec = np.array(ind_vec)
    tm = np.concatenate(Lt)

    return interval, tables_update, ti, ind_vec, tm

In [9]:
def formatting_iter(individual_id_tables, 
                    individual_id_state_orders,
                    mod_id_ind, input_time_range, 
                    interval, resolution):
    '''
    Split the concatenated inputs back into one set of tables
    and time points per modality, format the time points
    (see format_time) and calculate the kernel matrices.

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of 1 to n tables constructed,
        (see build_sparse class), where n is the 
        number of modalities.
        keys = individual_ids
        values = list of DataFrame, required
            For each DataFrame (modality):
            rows = features
            columns = samples

    individual_id_state_orders: dictionary, required
        Dictionary of 1 to n lists of time points (one 
        per modality) for each sample.
        keys = individual_ids
        values = list of numpy.ndarray
            Each numpy.ndarray contains the time points
            of the corresponding modality

    mod_id_ind: dictionary, required
        Dictionary of individual IDs for each modality
        keys = modality
        values = list of tuples
            Each tuple contains the individual id and 
            the dataframe index in individual_id_tables

    input_time_range: tuple, required
        Start and end time points for each individual

    interval : tuple, optional
        Start and end time points to keep

    resolution: int, required
        Number of time points to evaluate the value
        of the temporal loading function.

    Returns
    ----------
    table_mods: dictionary
        Updated tables for each modality. Times are
        normalized and only points within the interval
        are kept.
        keys = modality
        values = dictionary
            keys = individual_ids
            values = DataFrame
                rows = features
                columns = samples

    times: dictionary
        Updated time points for each modality
        keys = modality
        values = list of numpy.ndarray
            list[0] = time points within interval
            list[1] = individual indexes

    Kmats: dictionary
        Kernel matrix for each modality
        keys = modality
        values = bernoulli_kernel evaluated on the
            normalized time points against themselves

    Kmat_outputs: dictionary
        Bernoulli kernel matrix for each modality
        keys = modality
        values = bernoulli_kernel evaluated on an evenly
            spaced grid of `resolution` points over the
            normalized interval against the time points

    norm_interval: tuple
        Normalized interval as returned by format_time for
        the last modality processed.
    '''

    # outputs, one entry per modality
    # To-do: save in self
    table_mods, times = {}, {}
    Kmats, Kmat_outputs = {}, {}

    for modality, id_index_pairs in mod_id_ind.items():

        # modality-specific time points, then modality-specific tables;
        # each pair is (individual id, position in the per-individual list)
        orders_mod = {}
        for pair in id_index_pairs:
            orders_mod[pair[0]] = individual_id_state_orders[pair[0]][pair[1]]
        tables_mod = {}
        for pair in id_index_pairs:
            tables_mod[pair[0]] = individual_id_tables[pair[0]][pair[1]]

        # format time points and keep points in the interval
        formatted = format_time(tables_mod, orders_mod,
                                len(tables_mod), resolution,
                                input_time_range, interval)
        norm_interval, tables_mod, ti, ind_vec, tm = formatted

        # save key outputs
        table_mods[modality] = tables_mod
        times[modality] = [ti, ind_vec]

        # construct the kernel matrices
        Kmats[modality] = bernoulli_kernel(tm, tm)
        output_grid = np.linspace(norm_interval[0],
                                  norm_interval[1],
                                  num=resolution)
        Kmat_outputs[modality] = bernoulli_kernel(output_grid, tm)

    return table_mods, times, Kmats, Kmat_outputs, norm_interval

In [10]:
def joint_ctf_helper(individual_id_tables,
                     individual_id_state_orders,
                     mod_id_ind, interval,
                     resolution, maxiter,
                     epsilon, smooth, 
                     n_components):
    '''
    Joint decomposition of two or more tensors. Components are
    extracted one at a time: each rank-1 fit is followed by
    subtracting its reconstruction from the tables (deflation).

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of 1 to n tables constructed,
        (see build_sparse class), where n is the 
        number of modalities.
        keys = individual_ids
        values = list of DataFrame, required
            For each DataFrame (modality):
            rows = features
            columns = samples

    individual_id_state_orders: dictionary, required
        Dictionary of 1 to n lists of time points (one 
        per modality) for each sample.
        keys = individual_ids
        values = list of numpy.ndarray
            Each numpy.ndarray contains the time points
            of the corresponding modality. The arrays may
            have different lengths.

    mod_id_ind: dictionary, required
        Dictionary of individual IDs for each modality
        keys = modality
        values = list of tuples
            Each tuple contains the individual id and 
            the dataframe index in individual_id_tables

    interval : tuple or None, required
        Start and end time points to keep. If None, the
        smallest and largest observed time points are used.

    resolution: int, required
        Number of time points for the temporal 
        loading function.

    maxiter: int, required
        Maximum number of iteration in for rank-1 calculation.

    epsilon: float, required
        Convergence criteria for difference between iterations
        for each rank-1 calculation.

    smooth: float, required
        Smoothing parameter for RKHS norm. Larger means
        smoother temporal loading functions.

    n_components: int, required
        The underlying rank of the data and number of
        output dimentions.

    Returns
    ----------
    individual_loadings: pd.DataFrame
        Subject loadings
        rows = individual IDs
        columns = component number

    feature_loadings: dictionary
        Feature loadings
        keys = component number
        values = dictionary of modality-specific loadings

    state_loadings: dictionary
        Temporal loadings
        keys = component number
        values = dictionary of modality-specific loadings

    lambda_coeff: dictionary
        Singular values
        keys = component number
        values = dictionary of modality-specific values

    time_return: np.ndarray
        Time points for the temporal loading function,
        on the original (un-normalized) time scale
    '''

    #make copy of tables to update
    tables_update = copy.deepcopy(individual_id_tables)
    orders_update = copy.deepcopy(individual_id_state_orders)
    #get all individual IDs
    individual_id_lst = list(orders_update.keys())
    n_individuals_all = len(individual_id_lst)
    #get all time points across all individuals and modalities;
    #flatten array by array so differing lengths are allowed
    timestamps_all = np.concatenate([np.ravel(order)
                                     for orders in orders_update.values()
                                     for order in orders])
    timestamps_all = np.unique(timestamps_all)
    # set the interval if none is given
    if interval is None:
        interval = (timestamps_all[0], timestamps_all[-1])
    # observed range, used to rescale times to [0, 1]
    input_time_range = (timestamps_all[0], timestamps_all[-1])

    #format time points and keep points in defined interval
    (table_mods, times,
     Kmats, Kmat_outputs,
     norm_interval) = formatting_iter(tables_update,
                                      orders_update,
                                      mod_id_ind,
                                      input_time_range,
                                      interval, resolution)
    #init dataframes to fill
    #key: component number, value: dictionary of modality-specific loadings
    n_component_col_names = ['component_' + str(i+1)
                             for i in range(n_components)]
    #rows follow individual_id_lst, the order in which
    #decomposition_iter assembles a_hat
    individual_loadings = pd.DataFrame(np.zeros((n_individuals_all, n_components)),
                                       index=individual_id_lst,
                                       columns=n_component_col_names)
    feature_loadings = {}
    state_loadings = {}
    lambda_coeff = {}

    #perform decomposition, one rank-1 component at a time
    for r, comp_name in enumerate(n_component_col_names):
        (a_hat, b_hats,
         phi_hats, lambdas) = decomposition_iter(table_mods, times,
                                                 individual_id_lst,
                                                 Kmats, Kmat_outputs,
                                                 maxiter, epsilon,
                                                 smooth, n_components)
        #save rank-1 components
        individual_loadings.iloc[:, r] = a_hat
        feature_loadings[comp_name] = b_hats
        state_loadings[comp_name] = phi_hats
        lambda_coeff[comp_name] = lambdas

        #calculate residuals; the next component is fitted on them
        table_mods = udpate_residuals(table_mods, a_hat, b_hats,
                                      phi_hats, times, lambdas)

    #TODO: revise signs to make sure summation is non-negative(?)
    #TODO: find better format to save decompositions(?)
    #return original time points
    time_return = np.linspace(norm_interval[0],
                              norm_interval[1],
                              resolution)
    time_return *= (input_time_range[1] - input_time_range[0])
    time_return += input_time_range[0]

    return (individual_loadings, feature_loadings, 
            state_loadings, lambda_coeff, time_return)

In [11]:
class concat_tensors():

    '''
    Collect the per-individual tables and time points of
    several modalities into one object.

    After calling concat(tensors) the instance exposes:

    individual_id_tables: dictionary
        keys = individual_ids
        values = list of tables, one entry per modality
            in which the individual appears

    individual_id_state_orders: dictionary
        keys = individual_ids
        values = list of state orders, one entry per
            modality in which the individual appears

    mod_id_ind: dictionary
        keys = modality
        values = list of tuples (individual id, position of
            this modality's table in the individual's list)
    '''
    def __init__(self):
        pass

    def concat(self, tensors):
        '''
        Concatenate tensors from each modality into a
        single tensor. Note: tensors should have been
        preprocessed by this point.

        Parameters
        ----------
        tensors: dictionary, required
            keys = modality
            values = tensor object exposing the attributes
                individual_id_tables_centralized and
                individual_id_state_orders (both dictionaries
                keyed by individual id)

        Returns
        ----------
        self: object
            Returns the instance itself
        '''

        tables_by_individual = {}
        orders_by_individual = {}
        index_by_modality = {}

        for mod, tensor in tensors.items():

            # gather tables and remember where each modality's table landed
            for ind_id, table in tensor.individual_id_tables_centralized.items():
                stack = tables_by_individual.setdefault(ind_id, [])
                stack.append(table)
                index_by_modality.setdefault(mod, []).append((ind_id, len(stack) - 1))

            # gather state orders
            for ind_id, order in tensor.individual_id_state_orders.items():
                orders_by_individual.setdefault(ind_id, []).append(order)

        ##TODO make sure individuals are ordered the same way in all dictionaries?

        # store all to self
        self.individual_id_tables = tables_by_individual
        self.individual_id_state_orders = orders_by_individual
        self.mod_id_ind = index_by_modality

        return self

In [12]:
def joint_ctf(tables,
              sample_metadatas,
              modality_ids,
              individual_id_column: str,
              state_column: str,
              n_components: int = DEFAULT_COMP,
              ##could be done separately by user
              ##also for rclr or other transformations
              ##default can be same transformation
              min_sample_count: int = DEFAULT_MSC,
              min_feature_count: int = DEFAULT_MFC,
              min_feature_frequency: float = DEFAULT_MFF,
              transformation: Callable = matrix_rclr,
              pseudo_count: float = DEFAULT_TEMPTED_PC,
              ##important to test dif modalities
              replicate_handling: str = DEFAULT_TRH,
              svd_centralized: bool = DEFAULT_TEMPTED_SVDC,
              n_components_centralize: int = DEFAULT_TSCN,
              smooth: float = DEFAULT_TEMPTED_SMTH,
              resolution: int = DEFAULT_TEMPTED_RES,
              max_iterations: int = DEFAULT_TEMPTED_MAXITER,
              epsilon: float = DEFAULT_TEMPTED_EP):
    '''
    Joint decomposition of two or more tensors.

    Every (table, metadata, modality ID) triplet is filtered with
    ``ctf_table_processing``, turned into a sparse tensor with
    ``build_sparse``, merged with ``concat_tensors`` and finally
    handed to ``joint_ctf_helper``. The same filtering and
    transformation settings are applied to every modality.

    Parameters
    ----------
    tables: list of biom.Table, required
        List of feature tables (1-n) from different modalities
        in biom format containing the samples over which
        metrics should be computed.
        Each modality should contain same number of samples
        or individuals. Length of features might vary.

    sample_metadatas: list of DataFrame, required
        Sample metadata files in QIIME2 formatting for each
        modality, in the same order as ``tables``. The file
        must contain the columns for individual_id_column and
        state_column and the rows matched to the table.

    modality_ids: list of str, required
        One unique, hashable identifier per modality, in the
        same order as ``tables``. Used as the keys of the
        per-modality tensor dictionary.

    individual_id_column: str, required
        Metadata column containing subject IDs to use for
        pairing samples. Replicates for an individual ID at
        a given state are handled according to
        ``replicate_handling``.

    state_column: str, required
        Metadata column containing state (e.g.,Time,
        BodySite) across which samples are paired.

    n_components: int, optional : Default is DEFAULT_COMP
        The underlying rank of the data and number of
        output dimensions.

    min_sample_count: int, optional : Default is DEFAULT_MSC
        Minimum sum cutoff of sample across all features.
        The value can be at minimum zero and must be a
        whole integer. It is suggested to be greater than
        or equal to 500.

    min_feature_count: int, optional : Default is DEFAULT_MFC
        Minimum sum cutoff of features across all samples.
        The value can be at minimum zero and must be
        a whole integer.

    min_feature_frequency: float, optional : Default is DEFAULT_MFF
        Minimum percentage of samples a feature must appear
        with a value greater than zero. This value can range
        from 0 to 100 with decimal values allowed.

    transformation: function, optional : Default is matrix_rclr
        The transformation function to use on the data.

    pseudo_count: float, optional : Default is DEFAULT_TEMPTED_PC
        The pseudo count to add to all values before applying
        the transformation.

    replicate_handling: str, optional : Default is DEFAULT_TEMPTED_RH
        Choose how replicate samples are handled. If replicates are
        detected, "error" causes method to fail; "drop" will discard
        all replicated samples; "random" chooses one representative at
        random from among replicates.

    svd_centralized: bool, optional : Default is DEFAULT_TEMPTED_SVDC
        Removes the mean structure of the temporal tensor.

    n_components_centralize: int, optional : Default is DEFAULT_TEMPTED_SVDCN
        Rank of approximation for average matrix in svd-centralize.

    smooth: float, optional : Default is DEFAULT_TEMPTED_SMTH
        Smoothing parameter for RKHS norm. Larger means
        smoother temporal loading functions.

    resolution: int, optional : Default is DEFAULT_TEMPTED_RES
        Number of time points to evaluate the value
        of the temporal loading function.

    max_iterations: int, optional : Default is DEFAULT_TEMPTED_MAXITER
        Maximum number of iterations for each rank-1 calculation.

    epsilon: float, optional : Default is DEFAULT_TEMPTED_EP
        Convergence criteria for difference between iterations
        for each rank-1 calculation.

    Returns
    -------
    tuple
        ``(individual_loadings, feature_loadings, state_loadings,
        eigenvalues, time_return)``, passed through unchanged from
        ``joint_ctf_helper``.
        TODO: wrap these into the OrdinationResults / DataFrame /
        DistanceMatrix outputs (subject biplot, temporal loadings,
        subject-subject distances, SVD-centralize loadings).

    Raises
    ------
    ValueError
        if tables, sample_metadatas and modality_ids do not
        all have the same length
    ValueError
        if no modality is provided
    ValueError
        if modality_ids contains duplicated entries
    Other errors
        The table checks, tensor construction and decomposition
        are delegated to ctf_table_processing, build_sparse and
        joint_ctf_helper; see those functions for the errors and
        warnings they raise (e.g. missing metadata columns,
        invalid table values, invalid n_components).

    Examples
    --------
    TODO
    '''

    # materialize the inputs so their lengths can be compared; zip()
    # would otherwise silently drop the unmatched trailing entries
    tables = list(tables)
    sample_metadatas = list(sample_metadatas)
    modality_ids = list(modality_ids)
    if not (len(tables) == len(sample_metadatas) == len(modality_ids)):
        raise ValueError('tables, sample_metadatas and modality_ids must '
                         'have the same length (got %i, %i and %i).'
                         % (len(tables), len(sample_metadatas),
                            len(modality_ids)))
    if len(tables) == 0:
        raise ValueError('At least one modality must be provided.')
    # a repeated ID would overwrite the tensor stored under the same key
    if len(set(modality_ids)) != len(modality_ids):
        raise ValueError('modality_ids must be unique.')

    #note: we assume each modality has a dif table and associated
    #metadata. We also assume filtering conditions are the same
    tensors = {}
    for table, sample_metadata, mod_ids in zip(tables,
                                               sample_metadatas,
                                               modality_ids):

        # check the table for validity and then filter
        process_results = ctf_table_processing(table,
                                               sample_metadata,
                                               individual_id_column,
                                               [state_column],
                                               min_sample_count,
                                               min_feature_count,
                                               min_feature_frequency,
                                               None)
        table = process_results[0]
        sample_metadata = process_results[1]
        # build the sparse tensor format
        tensor = build_sparse()
        tensor.construct(table,
                         sample_metadata,
                         individual_id_column,
                         state_column,
                         transformation=transformation,
                         pseudo_count=pseudo_count,
                         branch_lengths=None,
                         replicate_handling=replicate_handling,
                         svd_centralized=svd_centralized,
                         n_components_centralize=n_components_centralize)
        tensors[mod_ids] = tensor
    #save all tensors to a class
    n_tensors = concat_tensors().concat(tensors)

    # run joint-CTF
    # concat_tensors stores the merged centralized tables under
    # `individual_id_tables` (there is no `*_centralized` attribute).
    # NOTE(review): concat_tensors does not define `feature_order` either,
    # so the third argument below still raises AttributeError. It exposes
    # `mod_id_ind` instead; confirm against the signature of
    # joint_ctf_helper which object is expected here and update the call.
    joint_ctf_res = joint_ctf_helper(n_tensors.individual_id_tables,
                                     n_tensors.individual_id_state_orders,
                                     n_tensors.feature_order,
                                     n_components=n_components,
                                     smooth=smooth,
                                     resolution=resolution,
                                     maxiter=max_iterations,
                                     epsilon=epsilon)
    (individual_loadings,
     feature_loadings,
     state_loadings,
     eigenvalues,
     time_return) = joint_ctf_res

    return (individual_loadings, feature_loadings,
            state_loadings, eigenvalues, time_return)

### testing

In [3]:
from biom import load_table

In [5]:
# folder holding the tutorial inputs (relative to this notebook)
data_path = '../ipynb/tutorials/multi-omics-10333/'
# load one feature table per amplicon
table_16S, table_18S, table_ITS = [
    load_table(data_path + amplicon + '-table.biom')
    for amplicon in ('16S', '18S', 'ITS')
]
# load the sample metadata (tab-separated, first column used as index)
metadata = pd.read_csv(data_path + 'metadata.tsv', sep='\t', index_col=0)
metadata.head(2)

Unnamed: 0_level_0,ac_sampled_room,accult_score,analysis_name,animals_in_house,anonymized_name,carbon_dioxide_inside,cats,cleaning_frequency,closed_trash,collection_date,...,temp_inside_house,temp_outside_house,title,use_soaps,village,water_source,wind_speed_outside,years_house_inhabited,zone,train_test
sample_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
10333.Iqu.959.bedr,not provided,26,Iqu.959.bedr,0,MG.0850,not provided,0,daily,not provided,7/18/12,...,,,Dominguez Sloan SAWesternization gradient,,Iquitos,municipal,,5.0,Centro,train
10333.Man.1767.bath,FALSE,30,Man.1767.bath,4,MG.1035,not provided,4,7,TRUE,9/5/12,...,,,Dominguez Sloan SAWesternization gradient,,Manaus,municipal,,15.0,Aleixo,train


In [1]:
# run the joint decomposition on the three amplicon tables; the same
# metadata frame is supplied once per modality
(individual_loadings,
 feature_loadings,
 state_loadings,
 eigenvalues,
 time_return) = joint_ctf(tables=[table_16S, table_18S, table_ITS],
                          sample_metadatas=[metadata, metadata, metadata],
                          modality_ids=['16S', '18S', 'ITS'],
                          individual_id_column='host_subject_id',
                          state_column='time')

NameError: name 'joint_ctf' is not defined

### sanity checks

In [None]:
#create dummy orders_update
# ragged per-individual state orders: one inner list per modality
individual_id_state_orders = {'ind1': np.array([[0, 0.5, 1, 2],[0, 0.5, 1],[0, 0.5, 1]], dtype=object),
                              'ind2': np.array([[0, 1, 3],[0, 0.5, 1, 3]], dtype=object)}

# single-modality counterpart: one flat array per individual
individual_id_state_orders2 = {'ind1': np.array([0, 0.5, 1, 2]),
                              'ind2': np.array([0, 1, 3])}

print(individual_id_state_orders)
print(individual_id_state_orders2)

#create random tables
# a seeded, local generator keeps the dummy tables identical between
# runs without touching numpy's global random state
dummy_rng = np.random.RandomState(42)
table1 = dummy_rng.randint(0,100,size=(10, 4))
table2 = dummy_rng.randint(0,100,size=(12, 4))

table3 = dummy_rng.randint(0,100,size=(10, 3))
table4 = dummy_rng.randint(0,100,size=(12, 4))

#create dictionary of tables
#individual_id_tables = {'ind1': [table1, table2], 'ind2': [table3, table4]}
#individual_id_mod = {'ind1': table1, 'ind2': table3}

# two mock tensors that only partly share individuals
# (ind4 exists only in the first one, ind3 only in the second one)
tensor1_tables = {'ind1': table1, 'ind2': table2, 'ind4': table1}
tensor2_tables = {'ind1': table3, 'ind2': table4, 'ind3': table3}

tensor1_state_orders = {'ind1': [0, 0.5, 1, 2], 'ind2': [0, 0.5, 1, 3], 'ind4': [0, 0.5, 1, 3]}
tensor2_state_orders = {'ind1': [0, 0.5, 1], 'ind2': [0, 0.5, 1, 3], 'ind3': [0, 0.5, 2]}

individual_id_orders = {}
individual_id_tables = {}

#concat lst from both tensors by individual
for tensor in [tensor1_state_orders, tensor2_state_orders]:

    for key, value in tensor.items():
        individual_id_orders.setdefault(key, []).append(value)

for tensor in [tensor1_tables, tensor2_tables]:

    for key, value in tensor.items():
        individual_id_tables.setdefault(key, []).append(value)