In [None]:
import copy
import biom
import numpy as np
import pandas as pd
import warnings
from typing import Callable
from scipy.spatial import distance
from skbio import (OrdinationResults,
                   DistanceMatrix)
from scipy.sparse.linalg import svds
from gemelli.optspace import svd_sort
from gemelli.ctf import ctf_table_processing
from gemelli.preprocessing import (build_sparse,
                                   matrix_rclr)
from gemelli._defaults import (DEFAULT_COMP, DEFAULT_MSC,
                               DEFAULT_MFC, DEFAULT_MFF,
                               DEFAULT_TEMPTED_PC,
                               DEFAULT_TEMPTED_EP,
                               DEFAULT_TEMPTED_SMTH,
                               DEFAULT_TEMPTED_RES,
                               DEFAULT_TEMPTED_MAXITER,
                               DEFAULT_TEMPTED_RH as DEFAULT_TRH,
                               DEFAULT_TEMPTED_RHC as DEFAULT_RC,
                               DEFAULT_TEMPTED_SVDC,
                               DEFAULT_TEMPTED_SVDCN as DEFAULT_TSCN)

from gemelli.tempted import (freg_rkhs, bernoulli_kernel)

In [None]:
def format_time():
    """
    Normalize time points to a common format.

    When an interval is defined, only the time points that fall
    inside that interval are kept.

    Notes
    -----
    Not implemented yet: the function takes no arguments and
    currently returns None.
    """
    return None

In [None]:
def initialization(n_individuals, tables_update):
    '''
    Initialize the subject and feature loadings.

    Parameters
    ----------
    n_individuals: int, required
        Number of unique individuals/samples.
    tables_update: dictionary, required
        Dictionary of tables constructed
        (see build_sparse class).
        keys = individual_ids
        values = list of DataFrame, one per modality, each with
            rows = features
            columns = samples

    Returns
    -------
    a_hat: list of int
        Initial subject loadings.
        NOTE(review): loadings are presumably real-valued, so
        "list of int" may be a typo -- confirm.
    b_hat: dictionary
        Initial feature loadings.
        keys = modality
        values = loadings

    Raises
    ------
    TODO

    Notes
    -----
    Not implemented yet: the body is a stub that returns None.
    '''
    return None

In [None]:
def udpate_tabular(tables_update, n_individuals,
                   tipos, a_hat, phi_hats, b_hats):
    """
    Update the tabular (subject and feature) loadings.

    NOTE: the function name is spelled "udpate"; it is kept as-is
    so that existing callers keep working.

    Parameters
    ----------
    tables_update: dictionary, required
        Dictionary of tables constructed
        (see build_sparse class).
        keys = individual_ids
        values = list of DataFrame, one per modality, each with
            rows = features
            columns = samples
    n_individuals: int, required
        Number of unique individuals/samples.
    tipos: list of boolean, required
        Mask of the time points to keep, based on the defined
        interval.
    a_hat: list of int, required
        Subject loadings from the previous iteration.
        NOTE(review): "list of int" may be a typo for real-valued
        loadings -- confirm.
    phi_hats: DataFrame, required
        Temporal loadings from the previous iteration.
        rows = timepoints
        columns = modality
    b_hats: dictionary, required
        Feature loadings from the previous iteration.
        keys = modality
        values = loadings

    Returns
    -------
    a_new: list of int
        Updated subject loadings.
    b_new: dictionary
        Updated feature loadings.
        keys = modality
        values = loadings

    Raises
    ------
    TODO

    Notes
    -----
    Not implemented yet: the body is a stub that returns None.
    """
    return None

In [None]:
def update_lambda(tables_update, tipos,
                  a_hat, phi_hats, b_hats):
    """
    Update the singular values from the loadings of the most
    recent iteration.

    Parameters
    ----------
    tables_update: dictionary, required
        Dictionary of tables constructed
        (see build_sparse class).
        keys = individual_ids
        values = list of DataFrame, one per modality, each with
            rows = features
            columns = samples
    tipos: list of boolean, required
        Mask of the time points to keep, based on the defined
        interval.
    a_hat: list of int, required
        Subject loadings from the previous iteration.
        NOTE(review): "list of int" may be a typo for real-valued
        loadings -- confirm.
    phi_hats: DataFrame, required
        Temporal loadings from the previous iteration.
        rows = timepoints
        columns = modality
    b_hats: dictionary, required
        Feature loadings from the previous iteration.
        keys = modality
        values = loadings

    Returns
    -------
    lambda_new: dictionary
        Updated singular values.
        keys = modality
        values = singular value(s) for that modality

    Raises
    ------
    TODO

    Notes
    -----
    Not implemented yet: the body is a stub that returns None.
    """
    return None

In [None]:
def udpate_residuals():
    """
    Update the tensor to be factorized by subtracting the
    approximation obtained in the previous iteration.

    NOTE: the function name is spelled "udpate"; it is kept as-is
    so that existing callers keep working.

    Notes
    -----
    Not implemented yet: the function takes no arguments and
    currently returns None.
    """
    return None

In [None]:
def joint_ctf(tables,
              sample_metadata: pd.DataFrame,
              individual_id_column: str,
              state_column: str,
              tensor_column: str,
              n_components: int = DEFAULT_COMP,
              min_sample_count: int = DEFAULT_MSC,
              min_feature_count: int = DEFAULT_MFC,
              min_feature_frequency: float = DEFAULT_MFF,
              transformation: Callable = matrix_rclr,
              pseudo_count: float = DEFAULT_TEMPTED_PC,
              replicate_handling: str = DEFAULT_TRH,
              svd_centralized: bool = DEFAULT_TEMPTED_SVDC,
              n_components_centralize: int = DEFAULT_TSCN,
              smooth: float = DEFAULT_TEMPTED_SMTH,
              resolution: int = DEFAULT_TEMPTED_RES,
              max_iterations: int = DEFAULT_TEMPTED_MAXITER,
              epsilon: float = DEFAULT_TEMPTED_EP) -> (
                  OrdinationResults,
                  pd.DataFrame,
                  DistanceMatrix,
                  pd.DataFrame):
    '''
    Joint decomposition of two or more tensors.

    Parameters
    ----------
    tables: list, required
        List of feature tables from different modalities,
        in biom format, containing the samples over which
        metrics should be computed.
        Each modality should contain the same number of samples
        or individuals; the number of features may vary.
        NOTE(review): earlier drafts typed this as a list of
        numpy.ndarray -- confirm whether biom.Table or ndarray
        inputs are expected.

    sample_metadata: DataFrame, required
        Sample metadata file in QIIME2 formatting. The file must
        contain the columns for individual_id_column,
        state_column and tensor_column, and its rows must be
        matched to the tables.

    individual_id_column: str, required
        Metadata column containing subject IDs to use for
        pairing samples. If replicates exist for an individual
        ID at a given state, they are handled according to
        replicate_handling.

    state_column: str, required
        Metadata column containing the state (e.g., Time,
        BodySite) across which samples are paired.

    tensor_column: str, required
        Metadata column denoting which modality was
        collected for each sample. Note that for some time
        points, data from several modalities might be available.

    n_components: int, optional : Default is 3
        The underlying rank of the data and number of
        output dimensions.

    min_sample_count: int, optional : Default is 0
        Minimum sum cutoff of sample across all features.
        The value can be at minimum zero and must be a
        whole integer. It is suggested to be greater than
        or equal to 500.

    min_feature_count: int, optional : Default is 0
        Minimum sum cutoff of features across all samples.
        The value can be at minimum zero and must be
        a whole integer.

    min_feature_frequency: float, optional : Default is 0
        Minimum percentage of samples a feature must appear
        in with a value greater than zero. This value can range
        from 0 to 100 with decimal values allowed.

    transformation: function, optional : Default is matrix_rclr
        The transformation function to use on the data.

    pseudo_count: float, optional : Default is 1
        The pseudo count to add to all values before applying
        the transformation.

    replicate_handling: str, optional
        Choose how replicate samples are handled. If replicates are
        detected, "error" causes the method to fail; "drop" will
        discard all replicated samples; "random" chooses one
        representative at random from among replicates.
        NOTE(review): the default was previously documented as
        "sum", which is not one of the options above -- confirm
        against DEFAULT_TEMPTED_RH.

    svd_centralized: bool, optional : Default is True
        Removes the mean structure of the temporal tensor.

    n_components_centralize: int, optional
        Rank of approximation for the average matrix in
        svd-centralize.

    smooth: float, optional : Default is 1e-8
        Smoothing parameter for RKHS norm. Larger means
        smoother temporal loading functions.

    resolution: int, optional : Default is 101
        Number of time points at which to evaluate the value
        of the temporal loading function.

    max_iterations: int, optional : Default is 20
        Maximum number of iterations for each rank-1 calculation.

    epsilon: float, optional : Default is 0.0001
        Convergence criterion for the difference between
        iterations of each rank-1 calculation.

    Returns
    -------
    OrdinationResults
        Compositional biplot of subjects as points and
        features as arrows. Where the variation between
        subject groupings is explained by the log-ratio
        between opposing arrows.

    DataFrame
        Each component's temporal loadings across the
        input resolution, with the time grid included as a
        column called 'time_interval'.

    DistanceMatrix
        A subject-subject distance matrix generated
        from the euclidean distance of the
        subject ordinations and itself.

    DataFrame
        The loadings from the SVD centralize
        function, used for projecting new data.
        Warning: If SVD-centering is not used
        then the function will output all ones
        to keep the outputs consistent.

    Raises
    ------
    ValueError
        if features don't match between tables
        across the values of the dictionary
    ValueError
        if individual_id_column is not a column of
        sample_metadata
    ValueError
        if state_column is not a column of sample_metadata
    ValueError
        Table is not 2-dimensional
    ValueError
        Table contains negative values
    ValueError
        Table contains np.inf or -np.inf
    ValueError
        Table contains np.nan or missing.
    Warning
        If a conditional-sample pair
        has multiple IDs associated
        with it.
    ValueError
        `ValueError: n_components must be at least 2`.
    ValueError
        `ValueError: Data-table contains
         either np.inf or -np.inf`.
    ValueError
        `ValueError: The n_components must be less
         than the minimum shape of the input tensor`.

    Examples
    --------
    TODO
    '''
    # TODO: check whether a separate sample_metadata per modality
    # will be needed.
    # TODO: decide whether the filters (min_sample_count,
    # min_feature_count, min_feature_frequency) should be set per
    # modality; if so, they could become lists of values.
    # Not implemented yet: the body is a stub that returns None.
    pass

In [None]:
def joint_ctf_helper():
    """
    Joint decomposition of two or more tensors.

    NOTE(review): presumably the worker routine behind joint_ctf --
    confirm its intended inputs once implemented.

    Notes
    -----
    Not implemented yet: the function takes no arguments and
    currently returns None.
    """
    return None