In [2]:
import copy
import biom
import numpy as np
import pandas as pd
import warnings
from typing import Callable
from scipy.spatial import distance
from skbio import (OrdinationResults,
                   DistanceMatrix)
from scipy.sparse.linalg import svds

from gemelli.optspace import svd_sort
from gemelli.ctf import ctf_table_processing
from gemelli.preprocessing import (build_sparse,
                                   matrix_rclr)
from gemelli._defaults import (DEFAULT_COMP, DEFAULT_MSC,
                               DEFAULT_MFC, DEFAULT_MFF,
                               DEFAULT_TEMPTED_PC,
                               DEFAULT_TEMPTED_EP,
                               DEFAULT_TEMPTED_SMTH,
                               DEFAULT_TEMPTED_RES,
                               DEFAULT_TEMPTED_MAXITER,
                               DEFAULT_TEMPTED_RH as DEFAULT_TRH,
                               DEFAULT_TEMPTED_RHC as DEFAULT_RC,
                               DEFAULT_TEMPTED_SVDC,
                               DEFAULT_TEMPTED_SVDCN as DEFAULT_TSCN)

from gemelli.tempted import (freg_rkhs, bernoulli_kernel)

In [None]:
def format_time(individual_id_tables,
                individual_id_state_orders,
                n_individuals, resolution,
                timestamps_all, interval=None):
    '''
    Normalize time points onto a common [0, 1] scale and keep
    only the samples that fall inside the interval (if defined)

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of tables constructed.
        (see build_sparse class)
        keys = individual_ids
        values = DataFrame (or numpy.ndarray), required
            rows = features
            columns = samples, in the same order as the
            time points in individual_id_state_orders

    individual_id_state_orders: dictionary, required
        Dictionary of time points for each individual
        keys = individual_ids (same order as the tables)
        values = array-like of time points

    n_individuals: int, required
        Number of unique individuals/samples

    resolution: int, required
        Number of time points to evaluate the value
        of the temporal loading function. The time points
        inside the interval are mapped to integer positions
        in [0, resolution - 1].

    timestamps_all: array-like, required
        Time points over all individuals. Only the first and
        last entries are read; they define the range that is
        mapped to [0, 1] (so the input is expected to be sorted).

    interval: tuple, optional
        Start and end time points to keep, in the original
        time units. Defaults to the full time range.

    Returns
    -------
    interval: tuple
        The interval expressed on the normalized time scale

    tables_update: dictionary
        Copy of the input tables keeping only the samples
        (columns) whose time point lies within the interval

    ti: list of numpy.ndarray
        Integer positions (on the resolution grid) of the
        time points within the interval, per subject

    ind_vec: numpy.ndarray
        Subject indexes for each time point

    tm: numpy.ndarray
        Concatenated normalized time points for all
        subjects (not restricted to the interval)

    Raises
    ------
    ValueError
        If the first and last entries of timestamps_all are
        equal (the time range cannot be normalized).
    '''

    # work on copies so the caller's inputs are never modified
    tables_update = copy.deepcopy(individual_id_tables)
    orders_update = copy.deepcopy(individual_id_state_orders)

    # range used to map all time points onto [0, 1]
    t_min, t_max = timestamps_all[0], timestamps_all[-1]
    t_span = t_max - t_min
    if t_span == 0:
        raise ValueError('Cannot normalize time points: the first and '
                         'last time points are identical.')
    # set the interval if none is given
    if interval is None:
        interval = (t_min, t_max)

    # normalize time points
    for individual_id in orders_update.keys():
        tps = np.asarray(orders_update[individual_id], dtype=float)
        orders_update[individual_id] = (tps - t_min) / t_span
    # ensure interval is in the same format; going through an array
    # makes this work for plain tuples/lists of Python numbers too
    interval = tuple((np.asarray(interval, dtype=float) - t_min) / t_span)
    interval_span = interval[1] - interval[0]

    # initialize variables to store time points (tps)
    Lt = []  # all normalized tps
    ind_vec = []  # individual indexes for each tp
    ti = [[] for _ in range(n_individuals)]  # tps within interval

    # populate variables above
    for i, (id_, time_range_i) in enumerate(orders_update.items()):
        # save all normalized time points
        Lt.append(time_range_i)
        ind_vec.extend([i] * len(time_range_i))
        # define time points within interval
        mask = (time_range_i >= interval[0]) & (time_range_i <= interval[1])
        # position on the resolution grid, truncated to an integer;
        # dtype=int keeps the result usable as an index even when empty
        ti[i] = np.array([int((resolution - 1) * (tp - interval[0])
                              / interval_span)
                          for tp in time_range_i[mask]], dtype=int)
        # keep only the samples (columns) inside the interval
        tables_update[id_] = tables_update[id_].T[mask].T

    # convert variables to numpy arrays
    ind_vec = np.array(ind_vec)
    tm = np.concatenate(Lt)

    return interval, tables_update, ti, ind_vec, tm

In [None]:
def initialize_tabular(individual_id_tables,
                       n_individuals,
                       n_components=3):

    """
    Initialize subject and feature loadings

    Parameters
    ----------
    individual_id_tables: dictionary or numpy.ndarray, required
        Either the dictionary of tables constructed
        (see build_sparse class)
        keys = individual_ids
        values = DataFrame (or numpy.ndarray), required
            rows = features
            columns = samples
        or the already unfolded matrix, i.e. the tables
        stacked horizontally (features x all samples).

    n_individuals: int, required
        Number of unique individuals/samples

    n_components: int, optional : Default is 3
        Number of singular vectors requested from the
        truncated SVD. Only the leading one is used.

    Returns
    ----------
    b_hat: numpy.ndarray
        Initial feature loadings: leading left singular
        vector of the unfolded data (after svd_sort).

    a_hat: numpy.ndarray
        Initial subject loadings: constant vector of length
        n_individuals with unit norm, signed like sum(b_hat).

    Raises
    ----------
    Whatever scipy's svds raises for invalid input, e.g. when
    n_components is not smaller than the smallest dimension
    of the unfolded data.
    """

    # unfold the data: stack all tables side by side. A matrix that
    # is already unfolded is accepted as-is. Values are cast to float
    # because svds does not accept integer input.
    if isinstance(individual_id_tables, dict):
        data_unfold = np.hstack([np.asarray(m, dtype=float)
                                 for m in individual_id_tables.values()])
    else:
        data_unfold = np.asarray(individual_id_tables, dtype=float)

    # initialize feature loadings
    u, e, v = svds(data_unfold, k=n_components, which='LM')
    # NOTE(review): svd_sort is imported from gemelli.optspace; it is
    # presumably what orders the components so that column 0 is the
    # leading one - confirm against its implementation.
    u, e, v = svd_sort(u, np.diag(e), v)
    b_hat = u[:, 0]

    # initialize subject loadings
    consistent_sign = np.sign(np.sum(b_hat))
    a_hat = (np.ones(n_individuals) / np.sqrt(n_individuals)) * consistent_sign

    return b_hat, a_hat

In [None]:
def update_lambda(individual_id_tables, ti,
                  a_hat, phi_hat, b_hat):
    '''
    Updates the singular value using the loadings
    from the most recent iteration

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of tables constructed. Note that at this point
        the tables have been subset to only include the time points
        within the previously defined interval.
        keys = individual_ids
        values = DataFrame (or numpy.ndarray), required
            rows = features
            columns = samples
    ti: list of int arrays, required
        Positions in phi_hat of the time points within the
        predefined interval, one array per individual (same
        order as the tables)
    a_hat: np.ndarray, required
        Subject loadings from the previous iteration
    phi_hat: np.ndarray, required
        Temporal loadings from the previous iteration
    b_hat: np.ndarray, required
        Feature loadings from the previous iteration

    Returns
    ----------
    lambda_new: float
        Updated singular value:
        sum_i a_i * (b^T X_i phi_i) / sum_i sum((a_i * phi_i)^2)
        where phi_i is phi_hat evaluated at individual i's
        time points.

    Raises
    ----------
    Nothing explicitly. If all denominators are zero the
    division yields nan/inf (with a numpy warning).
    '''

    nums = []
    denoms = []

    for i, m in enumerate(individual_id_tables.values()):
        # np.asarray accepts both DataFrames and plain arrays
        table = np.asarray(m)
        # temporal loading at this individual's time points
        phi_ = phi_hat[ti[i]]
        nums.append(a_hat[i] * (b_hat.dot(table).dot(phi_)))
        denoms.append(np.sum((a_hat[i] * phi_) ** 2))

    lambda_new = np.sum(nums) / np.sum(denoms)

    return lambda_new

In [None]:
def modality_iterator(individual_id_tables,
                      individual_id_state_orders,
                      mod_id_ind, interval,
                      resolution=101, maxiter=20,
                      epsilon=1e-4, smooth=1e-6,
                      n_components=3):
    '''
    Iterate over the available modalities

    Work in progress: the temporal loadings and singular values
    are computed per modality, but the subject/feature (tabular)
    loadings are not updated inside the loop yet (see TODO below).

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of 1 to n tables constructed,
        (see build_sparse class), where n is the
        number of modalities.
        keys = individual_ids
        values = list of DataFrame, required
            For each DataFrame (modality):
            rows = features
            columns = samples

    individual_id_state_orders: dictionary, required
        Dictionary of 1 to n lists of time points (one
        per modality) for each sample.
        keys = individual_ids
        values = list of numpy.ndarray
            Each numpy.ndarray contains the time points
            of the corresponding modality
            Note: array of dtype=object to allow for
            different number of time points per modality

    mod_id_ind: dictionary, required
        Dictionary of individual IDs for each modality
        keys = modality
        values = list of tuples
            Each tuple contains the individual id and
            the dataframe index in individual_id_tables

    interval : tuple, optional
        Start and end time points to keep
        (None keeps the full time range)

    resolution: int, optional : Default is 101
        Number of time points to evaluate the value
        of the temporal loading function.

    maxiter: int, optional : Default is 20
        Maximum number of iteration in for rank-1 calculation

    epsilon: float, optional : Default is 0.0001
        Convergence criteria for difference between iterations
        for each rank-1 calculation.

    smooth: float, optional : Default is 1e-6
        Smoothing parameter passed to freg_rkhs

    n_components: int, optional : Default is 3
        Passed to initialize_tabular

    Returns
    ----------
    b_hats: dictionary
        Feature loadings
        keys = modality
        values = loadings

    a_hats: dictionary
        Subject loadings
        keys = modality
        values = loadings

    phi_hats: dictionary
        Temporal loadings (normalized to unit norm)
        keys = modality
        values = loadings

    lambdas: dictionary
        Singular values
        keys = modality
        values = singular value

    Raises
    ----------
    Nothing explicitly; errors from format_time,
    initialize_tabular, bernoulli_kernel and freg_rkhs
    propagate.
    '''

    # The inputs are only read here (format_time works on its own
    # copies), so no defensive copy is needed.

    # get all time points across all individuals and modalities
    timestamps_all = np.unique(np.concatenate(
        [np.asarray(tps, dtype=float)
         for ind_orders in individual_id_state_orders.values()
         for tps in ind_orders]))

    # initialize dictionaries to store outputs per modality
    # To-do: save in self
    b_hats = {}
    phi_hats = {}
    a_hats = {}
    lambdas = {}
    table_mods = {}
    times = {}
    kernels = {}

    # set up each modality
    for modality, ind_tuple_lst in mod_id_ind.items():

        # keep modality-specific time points
        orders_sub = {ind: np.asarray(
                          individual_id_state_orders[ind[0]][ind[1]],
                          dtype=float)
                      for ind in ind_tuple_lst}
        # keep modality-specific tables
        table_mod = {ind: individual_id_tables[ind[0]][ind[1]]
                     for ind in ind_tuple_lst}
        n_individuals = len(table_mod)

        # format time points
        (norm_interval, table_mod,
         ti, ind_vec, tm) = format_time(table_mod, orders_sub,
                                        n_individuals, resolution,
                                        timestamps_all, interval)
        # save key outputs
        table_mods[modality] = table_mod
        times[modality] = [ti, ind_vec]

        # construct the kernel matrices. They depend on this modality's
        # time points, so they are stored per modality instead of being
        # overwritten by the next one.
        Kmat = bernoulli_kernel(tm, tm)
        Kmat_output = bernoulli_kernel(np.linspace(norm_interval[0],
                                                   norm_interval[1],
                                                   num=resolution),
                                       tm)
        kernels[modality] = (Kmat, Kmat_output)

        # initialize feature and subject loadings; initialize_tabular
        # unfolds the dictionary of tables itself
        b_hat, a_hat = initialize_tabular(table_mod,
                                          n_individuals=n_individuals,
                                          n_components=n_components)
        b_hats[modality] = b_hat
        a_hats[modality] = a_hat

    # TO-DO: update components for each modality
    t = 0
    dif = 1
    while t <= maxiter and dif > epsilon:
        for modality in mod_id_ind.keys():

            # get key modality-specific variables
            table_mod = table_mods[modality]
            ti, ind_vec = times[modality]
            Kmat, Kmat_output = kernels[modality]
            a_hat = a_hats[modality]
            b_hat = b_hats[modality]

            # calculate state loadings
            Ly = [a_hat[i] * b_hat.dot(np.asarray(m))
                  for i, m in enumerate(table_mod.values())]
            phi_hat = freg_rkhs(Ly, a_hat, ind_vec, Kmat, Kmat_output,
                                smooth=smooth)
            phi_hat = (phi_hat / np.sqrt(np.sum(phi_hat ** 2)))
            phi_hats[modality] = phi_hat

            # calculate lambda
            lambda_mod = update_lambda(table_mod, ti, a_hat, phi_hat, b_hat)
            lambdas[modality] = lambda_mod

            # update tabular loadings
            # (tables are keyed by (individual id, index) tuples, so the
            # first table is taken by position rather than by key 0)
            n_features = next(iter(table_mod.values())).shape[0]
            # TODO: update a_hats[modality] / b_hats[modality] here and
            # set `dif` from the change in loadings. Until that is wired
            # in, `dif` keeps its initial value and the loop stops on
            # `maxiter` only.

        # count the pass so the loop always terminates
        t += 1

    return b_hats, a_hats, phi_hats, lambdas

In [4]:
# Toy time-point fixtures.
# Multi-modality layout: one object array per individual, holding one
# list of time points per modality (the lists have different lengths).
individual_id_state_orders = {
    'ind1': np.array([[0, 0.5, 1, 2], [0, 0.5, 1], [0, 0.5, 1]],
                     dtype=object),
    'ind2': np.array([[0, 1, 3], [0, 0.5, 1, 3]],
                     dtype=object),
}

# Single-modality layout: one flat array of time points per individual.
individual_id_state_orders2 = dict(
    ind1=np.array([0, 0.5, 1, 2]),
    ind2=np.array([0, 1, 3]),
)

# show both fixtures
for fixture in (individual_id_state_orders, individual_id_state_orders2):
    print(fixture)

{'ind1': array([list([0, 0.5, 1, 2]), list([0, 0.5, 1]), list([0, 0.5, 1])],
      dtype=object), 'ind2': array([list([0, 1, 3]), list([0, 0.5, 1, 3])], dtype=object)}
{'ind1': array([0. , 0.5, 1. , 2. ]), 'ind2': array([0, 1, 3])}


In [5]:
# create random tables (rows = features, columns = samples)
# A seeded generator keeps the fixtures identical across
# "Restart Kernel -> Run All".
rng = np.random.default_rng(42)

# The number of columns of each table equals the number of time points
# listed for that individual/modality in individual_id_state_orders;
# format_time masks the columns with one boolean per time point.
table1 = rng.integers(0, 100, size=(10, 4))  # ind1, modality 0
table2 = rng.integers(0, 100, size=(12, 3))  # ind1, modality 1

table3 = rng.integers(0, 100, size=(10, 3))  # ind2, modality 0
table4 = rng.integers(0, 100, size=(12, 4))  # ind2, modality 1

# create dictionary of tables
# multi-modality layout: one list of tables per individual
individual_id_tables = {'ind1': [table1, table2], 'ind2': [table3, table4]}
# single-modality layout (matches individual_id_state_orders2)
individual_id_mod = {'ind1': table1, 'ind2': table3}

In [None]:
def udpate_tabular(individual_id_tables,
                   n_individuals, a_hat,
                   n_features, b_hat,
                   ti, phi_hat):
    '''
    Update the tabular (subject and feature) loadings

    Parameters
    ----------
    individual_id_tables: dictionary, required
        Dictionary of tables constructed. Note that at this point
        the tables have been subset to only include the time points
        within the previously defined interval.
        keys = individual_ids
        values = DataFrame (or numpy.ndarray), required
            rows = features
            columns = samples
    n_individuals: int, required
        Number of unique individuals/samples
    a_hat: np.ndarray, required
        Subject loadings from the previous iteration. Not used in
        the computation; callers can compare it with a_new to
        measure convergence.
    n_features: int, required
        Number of features (rows) in each table
    b_hat: np.ndarray, required
        Feature loadings from the previous iteration
    ti: list of int arrays, required
        Positions in phi_hat of the time points within the
        predefined interval, one array per individual (same
        order as the tables)
    phi_hat: np.ndarray, required
        Temporal loadings from the previous iteration

    Returns
    ----------
    a_new: np.ndarray
        Updated subject loadings (unit norm)
    b_new: np.ndarray
        Updated feature loadings (unit norm)

    Raises
    ----------
    Nothing explicitly.
    '''

    # initialize variables
    a_tilde = np.zeros(n_individuals)
    temp_num = np.zeros((n_features, n_individuals))
    temp_denom = np.zeros(n_individuals)

    # TO-DO: Iterate through modalities
    for i, m in enumerate(individual_id_tables.values()):
        # np.asarray accepts both DataFrames and plain arrays
        table = np.asarray(m)
        # temporal loading at this individual's time points
        phi_ = phi_hat[ti[i]]
        phi_sq = np.sum(phi_ ** 2)
        a_tilde[i] = b_hat.dot(table).dot(phi_) / phi_sq
        temp_num[:, i] = table.dot(phi_)
        temp_denom[i] = phi_sq

    # update subject loadings
    a_new = a_tilde / np.sqrt(np.sum(a_tilde ** 2))

    # update feature loadings (uses the freshly updated subject loadings)
    b_tilde = temp_num.dot(a_new) / (temp_denom.dot(a_new ** 2))
    b_new = b_tilde / np.sqrt(np.sum(b_tilde ** 2))

    # The iteration counter and the convergence difference belong to
    # the calling loop: a `t += 1` here referenced an undefined local
    # and made every call fail.
    return a_new, b_new

In [None]:
def udpate_residuals():
    '''
    Placeholder for the residual update: subtracting the
    approximation from the previous iteration from the
    tensor that is being factorized.

    Not implemented yet. Takes no arguments, does nothing
    and returns None.
    '''
    return None

In [None]:
def joint_ctf(
    tables,
    sample_metadata: pd.DataFrame,
    individual_id_column: str,
    state_column: str,
    # tensor_column: str,
    n_components: int = DEFAULT_COMP,
    # Filtering and transformation (rclr or other) are meant to be
    # done separately by the user; the default can be the same
    # transformation:
    #   min_sample_count: int = DEFAULT_MSC,
    #   min_feature_count: int = DEFAULT_MFC,
    #   min_feature_frequency: float = DEFAULT_MFF,
    #   transformation: Callable = matrix_rclr,
    #   pseudo_count: float = DEFAULT_TEMPTED_PC,
    # important to test different modalities
    replicate_handling: str = DEFAULT_TRH,
    svd_centralized: bool = DEFAULT_TEMPTED_SVDC,
    n_components_centralize: int = DEFAULT_TSCN,
    smooth: float = DEFAULT_TEMPTED_SMTH,
    resolution: int = DEFAULT_TEMPTED_RES,
    max_iterations: int = DEFAULT_TEMPTED_MAXITER,
    epsilon: float = DEFAULT_TEMPTED_EP,
):
    # Planned return annotation:
    #   (OrdinationResults, pd.DataFrame, DistanceMatrix, pd.DataFrame)
    '''
    Joint decomposition of two or more tensors

    NOTE: not implemented yet. The body is a placeholder and the
    function currently returns None. Everything below describes
    the intended interface.

    Parameters
    ----------
    tables: list of numpy.ndarray, required
        List of feature tables (1-n) from different modalities
        in biom format containing the samples over which
        metrics should be computed.
        Each modality should contain the same number of samples
        or individuals. The number of features may vary.

    sample_metadata: DataFrame, required
        Sample metadata file in QIIME2 formatting. The file must
        contain the columns for individual_id_column and
        state_column and the rows matched to the table.

    individual_id_column: str, required
        Metadata column containing subject IDs to use for
        pairing samples. WARNING: if replicates exist for an
        individual ID at either state_1 to state_N, that
        subject will be mean grouped by default.

    state_column: str, required
        Metadata column containing state (e.g., Time,
        BodySite) across which samples are paired. At least
        one is required but up to four are allowed by other
        state inputs.

    n_components: int, optional : Default is DEFAULT_COMP
        The underlying rank of the data and number of
        output dimensions.

    replicate_handling: str, optional : Default is DEFAULT_TRH
        Choose how replicate samples are handled. If replicates are
        detected, "error" causes method to fail; "drop" will discard
        all replicated samples; "random" chooses one representative at
        random from among replicates.

    svd_centralized: bool, optional : Default is DEFAULT_TEMPTED_SVDC
        Removes the mean structure of the temporal tensor.

    n_components_centralize: int, optional : Default is DEFAULT_TSCN
        Rank of approximation for average matrix in svd-centralize.

    smooth: float, optional : Default is DEFAULT_TEMPTED_SMTH
        Smoothing parameter for RKHS norm. Larger means
        smoother temporal loading functions.

    resolution: int, optional : Default is DEFAULT_TEMPTED_RES
        Number of time points to evaluate the value
        of the temporal loading function.

    max_iterations: int, optional : Default is DEFAULT_TEMPTED_MAXITER
        Maximum number of iterations for the rank-1 calculation.

    epsilon: float, optional : Default is DEFAULT_TEMPTED_EP
        Convergence criteria for difference between iterations
        for each rank-1 calculation.

    Parameters not (yet) part of the signature
    ------------------------------------------
    The following are commented out in the signature and are
    documented here for reference only.

    tensor_column: str
        Metadata column denoting which modality was
        collected at a specific time point for each sample.
        Note that for some time points, data from several
        modalities might be available.

    min_sample_count: int (DEFAULT_MSC)
        Minimum sum cutoff of sample across all features.
        The value can be at minimum zero and must be a
        whole integer. It is suggested to be greater than
        or equal to 500.

    min_feature_count: int (DEFAULT_MFC)
        Minimum sum cutoff of features across all samples.
        The value can be at minimum zero and must be
        a whole integer.

    min_feature_frequency: float (DEFAULT_MFF)
        Minimum percentage of samples a feature must appear
        with a value greater than zero. This value can range
        from 0 to 100 with decimal values allowed.

    transformation: function (matrix_rclr)
        The transformation function to use on the data.

    pseudo_count: float (DEFAULT_TEMPTED_PC)
        The pseudo count to add to all values before applying
        the transformation.

    Returns
    -------
    Currently None. Planned outputs:

    OrdinationResults
        Compositional biplot of subjects as points and
        features as arrows. Where the variation between
        subject groupings is explained by the log-ratio
        between opposing arrows.

    DataFrame
        Each component's temporal loadings across the
        input resolution included as a column called
        'time_interval'.

    DistanceMatrix
        A subject-subject distance matrix generated
        from the euclidean distance of the
        subject ordinations and itself.

    DataFrame
        The loadings from the SVD centralize
        function, used for projecting new data.
        Warning: If SVD-centering is not used
        then the function will add all ones as the
        output to avoid variable outputs.

    Raises
    ------
    Currently nothing. Planned:

    ValueError
        if features don't match between tables
        across the values of the dictionary
    ValueError
        if id_ not in mapping
    ValueError
        if any state_column not in mapping
    ValueError
        Table is not 2-dimensional
    ValueError
        Table contains negative values
    ValueError
        Table contains np.inf or -np.inf
    ValueError
        Table contains np.nan or missing.
    Warning
        If a conditional-sample pair
        has multiple IDs associated
        with it. In this case the
        default method is to mean them.
    ValueError
        `ValueError: n_components must be at least 2`.
    ValueError
        `ValueError: Data-table contains
         either np.inf or -np.inf`.
    ValueError
        `ValueError: The n_components must be less
         than the minimum shape of the input tensor`.

    Examples
    --------
    TODO
    '''
    return None

In [None]:
def joint_ctf_helper():
    '''
    Placeholder for the helper that performs the joint
    decomposition of two or more tensors.

    Not implemented yet. Takes no arguments, does nothing
    and returns None.
    '''
    return None