Takes the per-sample prediction scores and the per-taxa attribution values (both from Attribution_calculations.ipynb) and identifies the taxa most and least associated with IBD.

In [6]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import copy
device = "cuda:0"



In [8]:
def load_dict_from_at_tensor(at_tensor):
    """
    Loads a dictionary of attributions from a tensor.
    """
    attributions = {}
    for i in range(len(at_tensor)):
        sample = int(at_tensor[i][0].item())
        microbe = int(at_tensor[i][1].item())
        att = at_tensor[i][2].item()
        base = at_tensor[i][3].item()
        if not microbe in attributions:
            attributions[microbe] = [[], [], []]
        attributions[microbe][0].append(base)
        attributions[microbe][1].append(att)
        attributions[microbe][2].append(sample)
    return attributions

# Load the different attribution values we calculated from the Attribution_calculations.ipynb
sh_at_tensor = torch.load('/path/to/Schirmer_attributions_e.pth')
sh_attributions = load_dict_from_at_tensor(sh_at_tensor)

hf_at_tensor = torch.load("/path/to/Halfvarson_attributions_e.pth")
hf_attributions = load_dict_from_at_tensor(hf_at_tensor)

ibd_at_tensor = torch.load('/path/to/IBD_attributions_e.pth')
ibd_attributions = load_dict_from_at_tensor(ibd_at_tensor)

In [9]:
# Load the labels for the different datasets
hvl = torch.from_numpy(np.load("/path/to/halfvarson_IBD_labels.npy"))
sl = torch.from_numpy(np.load("/path/to/schirmer_IBD_labels.npy"))
ibdl = torch.from_numpy(np.load("/path/to/total_IBD_label.npy")[:,0])
print(sum(hvl), len(hvl))
print(sum(sl), len(sl))
print(sum(ibdl), len(ibdl))

tensor(510, dtype=torch.int32) 564
tensor(155, dtype=torch.int32) 197
tensor(435) 8571


In [10]:
# Load the different datasets
hvd = np.load("/path/to/halfvarson_512_otu.npy")
sd = np.load("/path/to/schirmer_IBD_512_otu.npy")
ibdd = np.load("/path/to/total_IBD_512.npy")

In [11]:
# Load the base scores for the different datasets (as calculated in Attribution_calculations.ipynb)
ibd_base_scores = torch.load('/path/to/IBD_base_scores_e.pth')[0]
hf_base_scores = torch.load('/path/to/Halfvarson_base_scores_e.pth')[0]
sh_base_scores = torch.load('/path/to/Schirmer_base_scores_e.pth')[0]

In [12]:
def filter_attributions_and_plot_hbars(attribution_dict, lenlim=10, title="Unspecified",
                                       neg_cutoff=1, pos_cutoff=0, per_microbe_std=True, n=30,
                                       data_labels=None, filter=False, rarity_lim=0, total_freqs=[],
                                       all_scores=[], sort_by_positive_class_attribution=False, plot=True):
    """
    Filter and analyze attribution data for microbes, optionally plotting the results.

    This function processes attribution data for microbes, applying various filters and
    calculating statistics. It can optionally generate a horizontal bar plot of the results.

    Parameters:
    - attribution_dict (dict): Dictionary containing attribution data for each microbe.
    - lenlim (int): Minimum number of attributions required for a microbe to be included.
    - title (str): Title for the plot.
    - neg_cutoff (float): Cutoff percentile for negative-class scores. For a datapoint with a negative label, its score must be below or equal to the provided percentile.
    - pos_cutoff (float): Cutoff percentile for positive-class scores. For a datapoint with a positive label, its score must be above or equal to the provided percentile.
    - per_microbe_std (bool): If True, use per-microbe standard deviation for the standard deviation of the attribution scores for plotting error bars; otherwise, use sample size adjusted std.
    - n (int): Number of top microbes to analyze and potentially plot.
    - data_labels (array-like): Binary true labels for the samples.
    - filter (bool): If True, apply filtering based on scores and rarity.
    - rarity_lim (int): Minimum frequency of occurrence for a microbe to be included.
    - total_freqs (list): List of total occurrence frequencies for each microbe.
    - all_scores (list): List of all attribution scores (from Attribution_calculations.ipynb).
    - sort_by_positive_class_attribution (bool): If True, sort by attributions on the positive class alone; otherwise, sort by overall attribution.
    - plot (bool): If True, generate and display a plot of the results.

    Returns:
    - tuple: (microbes, microbe_stats, filtered_attribution_dict)
        microbes (list): List of top microbe IDs.
        microbe_stats (list): Statistics for each microbe (microbe_id, num_attributions, sum_attributions, std_attributions, index).
        filtered_attribution_dict (dict): Filtered attribution dictionary.

    The function performs the following steps:
    1. Optionally filters the attribution data based on scores and rarity.
    2. Calculates statistics for positive and negative attributions.
    3. Sorts the results based on specified criteria.
    4. Returns the top n microbes, their statistics, and the filtered attribution dictionary.
    5. Optionally generates a horizontal bar plot of the top microbes' attributions.

    The plot, if generated, shows positive and negative attributions for each microbe,
    along with error bars and additional annotations about sample counts and attribution types.
    """    
    
    if filter:
        sorted_neg_scores = sorted(all_scores[(1 - data_labels).to(torch.bool)])
        sorted_pos_scores = sorted(all_scores[data_labels.to(torch.bool)])

        # Calculate cutoff scores for negative and positive classes
        neg_cutoff_score = sorted_neg_scores[int(neg_cutoff * len(sorted_neg_scores))]
        pos_cutoff_score = sorted_pos_scores[min(int(pos_cutoff * len(sorted_pos_scores)), len(sorted_pos_scores) - 1)]
        
        # Apply filtering to the attribution dictionary
        attribution_dict = copy.deepcopy(attribution_dict)
        for k in attribution_dict.keys():
            entry = attribution_dict[k]
            filtered_results = [[entry[0][i], entry[1][i], entry[2][i]] for i in range(len(entry[0])) if ( \
                                                                ((not data_labels[entry[2][i]]) and entry[0][i] <= neg_cutoff_score or \
                                                                 data_labels[entry[2][i]] and entry[0][i] >= pos_cutoff_score) and \
                                                                (entry[2][i] in total_freqs and total_freqs[entry[2][i]] >= rarity_lim))]
            # Update the filtered results in the attribution dictionary
            entry[0] = [e[0] for e in filtered_results]
            entry[1] = [e[1] for e in filtered_results]
            entry[2] = [e[2] for e in filtered_results]
    # Create a list of attribution results for each microbe meeting the criteria:
    # Each sublist contains: [microbe_id, num_attributions, sum_attributions, std_attributions, 
    #                         sum_abs_attributions, std_abs_attributions]
    # Only includes microbes with at least 'lenlim' attributions and excludes microbe ID 26727
    attribution_results = [
        [
            k, #0 Microbe ID
            len(attribution_dict[k][1]), #1 Number of attributions
            sum(attribution_dict[k][1]), #2 Sum of attributions
            np.std(attribution_dict[k][1]), #3 Standard deviation of attributions
            sum([abs(w) for w in attribution_dict[k][1]]), #4 Sum of absolute values of all attributions
            np.std([abs(w) for w in attribution_dict[k][1]]) #5 Standard deviation of absolute values of all attributions
        ] 
        for k in attribution_dict.keys() 
        if (len(attribution_dict[k][1]) >= lenlim and k != 26727)
    ]
    attribution_results = sorted(attribution_results, reverse=True, key=lambda x : x[2] / x[1])

    # Separate positive and negative attributions
    positive_attribution_dict = copy.deepcopy(attribution_dict)
    for k in positive_attribution_dict.keys():
        positive_attribution_dict[k][1] = [at for at in positive_attribution_dict[k][1] if at > 0]
    negative_attribution_dict = copy.deepcopy(attribution_dict)
    for k in positive_attribution_dict.keys():
        negative_attribution_dict[k][1] = [at for at in negative_attribution_dict[k][1] if at < 0]

    # Calculate detailed statistics for positive attributions
    positive_attribution_results = [
        [
            k,  #0 Microbe ID
            len(positive_attribution_dict[k][1]),  #1 Number of positive attributions
            sum(positive_attribution_dict[k][1]),  #2 Sum of positive attributions
            np.std(positive_attribution_dict[k][1]) if len(positive_attribution_dict[k][1]) > 0 else 0,  #3 Standard deviation of positive attributions (0 if none)
            sum([abs(w) for w in attribution_dict[k][1]]),  #4 Sum of absolute values of all attributions
            np.std([abs(w) for w in attribution_dict[k][1]]) if len(positive_attribution_dict[k][1]) > 0 else 0,  #5 Standard deviation of absolute attributions (0 if no positive attributions)
            len(attribution_dict[k][1]),  #6 Total number of attributions (positive and negative)
            sum(attribution_dict[k][1])  #7 Sum of all attributions (positive and negative)
        ]
        for k in attribution_dict.keys()
        if (len(attribution_dict[k][1]) >= lenlim and k != 26727)  # Include only if total attributions >= lenlim and microbe ID is not 26727
    ]

    # Calculate detailed statistics for negative attributions
    negative_attribution_results = [
        [
            k,  #0 Microbe ID
            len(negative_attribution_dict[k][1]),  #1 Number of negative attributions
            sum(negative_attribution_dict[k][1]),  #2 Sum of negative attributions
            np.std(negative_attribution_dict[k][1]) if len(negative_attribution_dict[k][1]) > 0 else 0,  #3 Standard deviation of negative attributions (0 if none)
            sum([abs(w) for w in attribution_dict[k][1]]),  #4 Sum of absolute values of all attributions
            np.std([abs(w) for w in attribution_dict[k][1]]) if len(negative_attribution_dict[k][1]) > 0 else 0,  #5 Standard deviation of absolute attributions (0 if no negative attributions)
            len(attribution_dict[k][1]),  #6 Total number of attributions (positive and negative)
            len(positive_attribution_dict[k][1]),  #7 Number of positive attributions
            sum(positive_attribution_dict[k][1]),  #8 Sum of positive attributions
            sum(attribution_dict[k][1])  #9 Sum of all attributions (positive and negative)
        ]
        for k in attribution_dict.keys()
        if (len(attribution_dict[k][1]) >= lenlim and k != 26727)  # Include only if total attributions >= lenlim and microbe ID is not 26727
    ]

    # Sort attribution results based on specified criteria (either by average positive class attribution or by average total attribution)
    if sort_by_positive_class_attribution:
        positive_attribution_results = sorted(positive_attribution_results, reverse=True, key=lambda x : x[2] / max(x[1], 1))
        negative_attribution_results = sorted(negative_attribution_results, reverse=True, key=lambda x : x[8] / max(x[7], 1))
    else:
        positive_attribution_results = sorted(positive_attribution_results, reverse=True, key=lambda x : x[7] / x[6])
        negative_attribution_results = sorted(negative_attribution_results, reverse=True, key=lambda x : x[9] / x[6])

    # Prepare data for plotting
    microbes = [j[0] for j in attribution_results[:n]]
    positive_attributions = [j[2] / max(j[1], 1) for j in positive_attribution_results[:n]]
    negative_attributions = [j[2] / max(j[1], 1) for j in negative_attribution_results[:n]]

    frequencies = [[j[0], j[1]] for j in attribution_results[:n]]
    positive_errors = [j[3] / (1 if per_microbe_std else np.sqrt(max(j[1], 1))) for j in positive_attribution_results[:n]]
    negative_errors = [j[3] / (1 if per_microbe_std else np.sqrt(max(j[1], 1))) for j in negative_attribution_results[:n]]
    
    # microbe_stats entry: [microbe_id, num_attributions, sum_attributions, std_attributions, index]
    microbe_stats = [[j[0], j[1], j[2], j[3], l] for j,l in zip(attribution_results, range(len(attribution_results)))]
    if plot:
        plt.barh(range(n), positive_attributions, label="Positive", xerr=positive_errors)
        plt.barh(range(n), negative_attributions, label="Negative", xerr=negative_errors)
        plt.title("Positive and Negative Attribution Results For " + title)
        plt.legend()
        xlim = -1000
        for i, v in enumerate(frequencies):
            if data_labels is None:
                annotation = str(v[1])
            else:
                samples = np.array(attribution_dict[v[0]][2])
                n_pos = sum(data_labels[samples]).item()
                annotation = "P:" + str(n_pos) + ",  N:" + str(v[1] - n_pos) + "  --  PA:" + str(positive_attribution_results[i][1]) + ", NA:" + str(negative_attribution_results[i][1])
            text_x_loc = positive_attributions[i] + positive_errors[i] + 0.02 * (sum(positive_attributions) - sum(negative_attributions)) / n
            plt.text(text_x_loc, i - 0.15, annotation)
            if text_x_loc > xlim:
                xlim = text_x_loc
        xlim += 0.35 * (sum(positive_attributions) - sum(negative_attributions)) / n
        plt.xlim((-xlim, xlim))
        plt.yticks(range(n), microbes)
        plt.ylabel("Microbe ID")
        plt.xlabel("Attribution (Feature Ablation)")
        plt.show()
    return microbes, microbe_stats, attribution_dict


In [13]:
# Count the total number of attributions for each microbe across all datasets
counts_dict = {}
for ad in [hf_attributions, ibd_attributions, sh_attributions]:
    for m in ad.keys():
        if not m in counts_dict:
            counts_dict[m] = 0
        counts_dict[m] += len(ad[m][0])


In [14]:
plt.rcParams['figure.figsize'] = [12, 8]

# Set cutoff values for positive and negative class scores
pos_cf = 0.5
neg_cf = 0.5
# Set rarity limit to 5% of the total sample size
rarity_l = int(9332 * 0.05)

# Set minimum number of attributions required for a microbe to be included
L = 5
print("L = 5, n = 30000")

# Process Halfvarson dataset attributions
top_half_abs_L_5_n_300, top_half_abs_L_5_n_300_stats, filtered_hf_atts = filter_attributions_and_plot_hbars(
    hf_attributions,  # Halfvarson attributions
    n=30000,  # Number of top microbes to analyze
    sort_by_positive_class_attribution=False,
    pos_cutoff=pos_cf,
    neg_cutoff=neg_cf,
    lenlim=L,  # Minimum number of attributions
    rarity_lim=rarity_l,
    absolute=False,
    title="Halfvarson (Per Sample STD)",
    data_labels=hvl,  # Halfvarson labels
    filter=True,
    all_scores=hf_base_scores,
    total_freqs=counts_dict,
    plot=False  # Don't generate plot
)

# Process IBD dataset attributions
top_ibd_abs_L_5_n_300, top_ibd_abs_L_5_n_300_stats, filtered_ibd_atts = filter_attributions_and_plot_hbars(
    ibd_attributions,  # IBD attributions
    n=30000,
    sort_by_positive_class_attribution=False,
    pos_cutoff=pos_cf,
    neg_cutoff=neg_cf,
    lenlim=L,
    rarity_lim=rarity_l,
    absolute=False,
    title="IBD (Per Sample STD)",
    data_labels=ibdl,  # IBD labels
    filter=True,
    all_scores=ibd_base_scores,
    total_freqs=counts_dict,
    plot=False
)

# Process Schirmer dataset attributions
top_sh_abs_L_5_n_300, top_sh_abs_L_5_n_300_stats, filtered_sh_atts = filter_attributions_and_plot_hbars(
    sh_attributions,  # Schirmer attributions
    n=30000,
    sort_by_positive_class_attribution=False,
    pos_cutoff=pos_cf,
    neg_cutoff=neg_cf,
    lenlim=L,
    rarity_lim=rarity_l,
    absolute=False,
    title="Schirmer (Per Sample STD)",
    data_labels=sl,  # Schirmer labels
    filter=True,
    all_scores=sh_base_scores,
    total_freqs=counts_dict,
    plot=False
)

L = 5, n = 300


In [15]:
# Combine Halfvarson and Schirmer attributions
hf_sh_attributions = copy.deepcopy(hf_attributions)
sh_attributions_c = copy.deepcopy(sh_attributions)

# Merge Schirmer attributions into the combined dictionary
for m in sh_attributions_c.keys():
    if not m in hf_sh_attributions:
        hf_sh_attributions[m] = [[], [], []]
    for i in range(len(sh_attributions_c[m][0])):
        # Append base score
        hf_sh_attributions[m][0].append(sh_attributions_c[m][0][i])
        # Append attribution
        hf_sh_attributions[m][1].append(sh_attributions_c[m][1][i])
        # Append index
        hf_sh_attributions[m][2].append(sh_attributions_c[m][2][i] + len(hvl))

# Combine Halfvarson and Schirmer labels
hvsl = torch.cat((hvl, sl))
# Combine Halfvarson and Schirmer base scores
hf_sh_base_scores = torch.cat((hf_base_scores, sh_base_scores))

In [16]:
pos_cf = 0.001
neg_cf = 0.999

top_hf_sh_abs_L_5_n_300, top_hf_sh_abs_L_5_n_300_stats, filtered_hf_sh_atts = filter_attributions_and_plot_hbars(hf_sh_attributions, n=30000, sort_by_positive_class_attribution=False, pos_cutoff = pos_cf, \
                                  neg_cutoff = neg_cf, lenlim = L, rarity_lim = rarity_l, absolute = False, title = "Schirmer + Halfvarson (Per Sample STD)", \
                                  data_labels=hvsl, filter=True, all_scores = hf_sh_base_scores, total_freqs = counts_dict, plot = False)

# Set cutoff values for positive and negative class scores
# Essentially disables this part of filtering, since we've already filtered once when generating hf_sh_attributions.
pos_cf = 0.001
neg_cf = 0.999

# Process combined Halfvarson and Schirmer attributions
top_hf_sh_abs_L_5_n_300, top_hf_sh_abs_L_5_n_300_stats, filtered_hf_sh_atts = filter_attributions_and_plot_hbars(
    hf_sh_attributions,  # Combined Halfvarson and Schirmer attributions
    n=30000,  # Number of top microbes to analyze (everything)
    sort_by_positive_class_attribution=False,  # Sort by overall attribution, not just positive class
    pos_cutoff=pos_cf,  # Cutoff for positive class scores (0.1% percentile)
    neg_cutoff=neg_cf,  # Cutoff for negative class scores (99.9% percentile)
    lenlim=L,  # Minimum number of attributions required (defined earlier)
    rarity_lim=rarity_l,  # Minimum frequency of occurrence (defined earlier)
    absolute=False,  # Use signed attributions, not absolute values
    title="Schirmer + Halfvarson (Per Sample STD)",
    data_labels=hvsl,  # Combined Halfvarson and Schirmer labels
    filter=True,  # Apply filtering based on scores and rarity
    all_scores=hf_sh_base_scores,  # Combined base scores
    total_freqs=counts_dict,  # Dictionary of total frequencies for each microbe
    plot=False  # Don't generate a plot
)

In [17]:
def report_overlap(top_IBD_abs, top_half_abs, top_sch_abs):
    """
    Reports the overlap between the top IBD, Halfvarson, and Schirmer microbes.

    Args:
    top_IBD_abs (list): List of top microbes from IBD dataset.
    top_half_abs (list): List of top microbes from Halfvarson dataset.
    top_sch_abs (list): List of top microbes from Schirmer dataset.

    Returns:
    tuple: A tuple containing:
        - intersections (list): List of overlap counts between datasets.
        - overlap_lists (list): List of overlapping microbes between datasets.
    """
    names = ["IBD", "HF ", "SH "]
    string = '     IBD HF  SH\n'
    intersections = []
    overlap_lists = []
    for x, name in zip([top_IBD_abs, top_half_abs, top_sch_abs], names):
        row = []
        string = string + name + ": "
        for y in [top_IBD_abs, top_half_abs, top_sch_abs]:
            num_same = sum([1 if t in y else 0 for t in x])
            overlap = [t for t in x if t in y and x != y]
            string = string + str(num_same) + ("  " if len(str(num_same)) > 1 else "   ")
            row.append(num_same)
            overlap_lists.append(overlap)
        string = string + "\n"
        intersections.append(row)
    print(string)
    return intersections, overlap_lists

def pick_entry(m, att_stats):
    """
    Finds and returns the attribution statistics for a given microbe.

    Args:
    m (int): Microbe ID to search for.
    att_stats (list): List of attribution statistics for all microbes.

    Returns:
    list or int: Attribution statistics for the microbe if found, -1 otherwise.
    """
    for a in att_stats:
        if a[0] == m:
            return a
    return -1

#  k, len(attribution_dict[k][1]), sum(attribution_dict[k][1]), np.std(attribution_dict[k][1])

def report_attributions(microbes, attribution_stats, dataset_names):
    """
    Generates a report of attribution statistics for given microbes across multiple datasets.

    Args:
    microbes (list): List of microbe IDs to report on.
    attribution_stats (list): List of attribution statistics for each dataset.
    dataset_names (list): Names of the datasets.

    Returns:
    list: Sorted list of attribution reports for each microbe across all datasets.
    """
    output = []
    for m in microbes:
        microbe_result = [m]
        for att, n in zip(attribution_stats, dataset_names):
            microbe_stats = pick_entry(m, att)
            if type(microbe_stats) == list:
                n_averaged = microbe_stats[1]
                avg_at = microbe_stats[2] / n_averaged
                std = microbe_stats[3] / (n_averaged ** 0.5)
                pos_in_orig = microbe_stats[4]
                pads = 6 - len(str(n_averaged)) - len(str(pos_in_orig))
                padding = ' '.join('' for _ in range(pads))
                microbe_result.append([n, "{:.5f}".format(avg_at), "{:.5f}".format(std), str(n_averaged), str(pos_in_orig) + padding, avg_at])
            else:
                microbe_result.append([n, "missing", "missing", "0", "---", 0])
        output.append(microbe_result)
    output = sorted(output, reverse=True, key=lambda x : x[1][-1])
    for mr in output:
        for t in mr:
            if type(t) == list and type(t[-1]) != str:
                del t[-1]
    return output

def check_att_match(att_stats, microbes, cutoff):
    """
    Checks which microbes have an average attribution above a given cutoff.

    Args:
    att_stats (list): Attribution statistics for all microbes.
    microbes (list): List of microbe IDs to check.
    cutoff (float): Minimum average attribution value.

    Returns:
    list: Microbes that pass the cutoff criterion.
    """
    passed_microbes = []
    for m in microbes:
        microbe_stats = pick_entry(m, att_stats)
        if type(microbe_stats) == list:
            n_averaged = microbe_stats[1]
            avg_at = microbe_stats[2] / n_averaged
            if avg_at > cutoff:
                passed_microbes.append(m)
    return passed_microbes

def check_passes_min_atts(all_att_stats, microbes, minimum):
    """
    Checks which microbes have an average attribution above a minimum value across all datasets.

    Args:
    all_att_stats (list): List of attribution statistics for each dataset.
    microbes (list): List of microbe IDs to check.
    minimum (float): Minimum average attribution value.

    Returns:
    list: Microbes that pass the minimum criterion across all datasets.
    """
    passed_microbes = []
    for m in microbes:
        passes_all_datasets = True
        for att_stats in all_att_stats:
            microbe_stats = pick_entry(m, att_stats)
            if type(microbe_stats) == list:
                n_averaged = microbe_stats[1]
                avg_at = microbe_stats[2] / n_averaged
                if avg_at < minimum:
                    passes_all_datasets = False
        if passes_all_datasets:
            passed_microbes.append(m)
    return passed_microbes


def check_same_atts_sign(all_att_stats, microbes):
    """
    Checks which microbes have the same attribution sign across all datasets.

    Args:
    all_att_stats (list): List of attribution statistics for each dataset.
    microbes (list): List of microbe IDs to check.

    Returns:
    list: Microbes that have consistent attribution signs across all datasets.
    """
    passed_microbes = []
    for m in microbes:
        att_list = []
        for att_stats in all_att_stats:
            microbe_stats = pick_entry(m, att_stats)
            if type(microbe_stats) == list:
                n_averaged = microbe_stats[1]
                avg_at = microbe_stats[2] / n_averaged
                att_list.append(avg_at < 0)
        passes_all_datasets = (len(att_list) == len(all_att_stats))
        current_sign = att_list[0]
        for s in att_list[1:]:
            if s != current_sign:
                passes_all_datasets = False
        if passes_all_datasets:
            passed_microbes.append(m)
    return passed_microbes


def get_combined_microbes_atts(m, att_dicts):
    """
    Retrieves all attributions for a given microbe across multiple attribution dictionaries.

    Args:
    m (int): Microbe ID to retrieve attributions for.
    att_dicts (list): List of attribution dictionaries from different datasets.

    Returns:
    list: Combined list of attributions for the microbe across all datasets.
    """
    atts = []
    for att_dict in att_dicts:
        if m in att_dict:
            for a in att_dict[m][1]:
                atts.append(a)
    return atts


In [18]:
intersections, overlaps = report_overlap(top_ibd_abs_L_5_n_300, top_half_abs_L_5_n_300, top_sh_abs_L_5_n_300)

ibd_microbes = pd.DataFrame(list(ibd_attributions.keys()))
hf_microbes = pd.DataFrame(list(hf_attributions.keys()))
sh_microbes = pd.DataFrame(list(sh_attributions.keys()))

     IBD HF  SH
IBD: 2408  645  405  
HF : 645  1829  204  
SH : 405  204  550  



In [19]:
# Check if microbes pass the minimum attribution threshold

top_half_pass_min_att = check_passes_min_atts([top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], top_half_abs_L_5_n_300, 0.0001)
top_sch_pass_min_att = check_passes_min_atts([top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], top_sh_abs_L_5_n_300, 0.0001)


In [20]:
# Check if microbes have the same attribution sign across all datasets

top_half_sh_same_sign = check_same_atts_sign([top_ibd_abs_L_5_n_300_stats, top_hf_sh_abs_L_5_n_300_stats], top_hf_sh_abs_L_5_n_300)


In [21]:
# Combine attributions for microbes that pass the above criteria
combined_stats = []

for m in top_half_sh_same_sign:
    combined_atts = get_combined_microbes_atts(m, [filtered_hf_sh_atts, filtered_ibd_atts])
    n = len(combined_atts)
    mean_att = sum(combined_atts) / n
    mean_std = torch.std(torch.tensor(combined_atts)).item() / n ** 0.5
    combined_stats.append([m, n, mean_att, mean_std, "Null"])

# Sort combined stats
combined_stats = sorted(combined_stats, reverse = True, key=lambda x : x[2])

In [22]:
# Find top and bottom 10 microbes
top_10_microbes = [stats[0] for stats in combined_stats[:10]]
bottom_10_microbes = [stats[0] for stats in combined_stats[-10:]]

In [23]:
top_10_microbes, bottom_10_microbes

([2602, 1054, 5969, 2139, 2721, 1988, 506, 1116, 4202, 1555],
 [138, 99, 1013, 1074, 1711, 4830, 1050, 1785, 2710, 647])

In [24]:
top_half_pass_min_att_pass_IBD_check = check_att_match(top_ibd_abs_L_5_n_300_stats, top_half_pass_min_att, 0.00098)
top_sch_pass_min_att_pass_IBD_check = check_att_match(top_ibd_abs_L_5_n_300_stats, top_sch_pass_min_att, 0.00101)

In [25]:
# Report some overlap statistics

intersections, overlap_lists = report_overlap(top_ibd_abs_L_5_n_300, top_half_abs_L_5_n_300, top_sh_abs_L_5_n_300)
ibd_hf_overlap = overlap_lists[3]
ibd_sh_overlap = overlap_lists[6]
hf_sh_overlap = overlap_lists[7]
all_overlap = [x for x in ibd_hf_overlap if x in hf_sh_overlap]
print("IBD / Halfvarson Overlap      = ", ibd_hf_overlap)
print("IBD / Schirmer Overlap        = ", ibd_sh_overlap)
print("Halfvarson / Schirmer Overlap = ", hf_sh_overlap)
print("All overlap                   = ", all_overlap)

     IBD HF  SH
IBD: 2408  645  405  
HF : 645  1829  204  
SH : 405  204  550  

IBD / Halfvarson Overlap      =  [1054, 5620, 2602, 2721, 4626, 5690, 10722, 1205, 2139, 4202, 4768, 852, 4029, 1392, 1988, 1576, 1498, 1555, 2000, 891, 6116, 2074, 3451, 1510, 8077, 1044, 5267, 1133, 5752, 2018, 678, 1618, 2672, 688, 1534, 657, 1970, 2161, 1541, 746, 3407, 2623, 446, 5478, 1249, 1192, 2173, 2650, 2481, 693, 1729, 1540, 1958, 11627, 7271, 12357, 537, 2040, 979, 531, 965, 3866, 1307, 2149, 357, 2336, 5252, 510, 1775, 1037, 1219, 1123, 2010, 6042, 3188, 535, 794, 939, 3497, 1095, 1092, 1710, 1466, 3142, 2528, 2483, 2864, 937, 4389, 2687, 1138, 4664, 5393, 1030, 940, 931, 1232, 1289, 399, 2884, 1159, 3192, 1461, 466, 1301, 734, 1881, 1917, 1611, 630, 3097, 3987, 1927, 1554, 2046, 1571, 557, 964, 1965, 665, 379, 1367, 1587, 10260, 5776, 1717, 1164, 823, 1169, 3511, 1769, 715, 800, 959, 1002, 1080, 6204, 2144, 1446, 1595, 1084, 2093, 12757, 3099, 1366, 1034, 5278, 1149, 1722, 443, 1625, 2134, 

In [26]:
# Now we take a look at what happens when we validate the top schirmer and halfvarson microbes on the IBD dataset

print("\n\n\nTop Schirmer Microbes Validated on IBD Microbe Average Attributions")
print("                 Attribution  STD      N     Rank                      Attribution  STD      N     Rank                      Attribution  STD      N     Rank   ")
all_overlap_atts = report_attributions(top_sch_pass_min_att_pass_IBD_check, [top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], ["IBD", "Halfvarson", "Schirmer"])
for x in all_overlap_atts:
    print('\t'.join([str(y) for y in x]))

print("\n\n\nTop Halfvarson Microbes Validated on IBD Microbe Average Attributions")
print("                 Attribution  STD      N     Rank                      Attribution  STD      N     Rank                      Attribution  STD      N     Rank   ")
all_overlap_atts = report_attributions(top_half_pass_min_att_pass_IBD_check, [top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], ["IBD", "Halfvarson", "Schirmer"])
for x in all_overlap_atts:
    print('\t'.join([str(y) for y in x]))

# And also look at the statistics of the top and bottom combined attributions
print("\n\n\nTop Combined Microbe Average Attributions")
print("                 Attribution  STD      N     Rank                      Attribution  STD      N     Rank                      Attribution  STD      N     Rank   ")
all_overlap_atts = report_attributions(top_10_microbes, [top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], ["IBD", "Halfvarson", "Schirmer"])
for x in all_overlap_atts:
    print('\t'.join([str(y) for y in x]))

print("\n\n\nBottom Combined Microbe Average Attributions")
print("                 Attribution  STD      N     Rank                      Attribution  STD      N     Rank                      Attribution  STD      N     Rank   ")
all_overlap_atts = report_attributions(bottom_10_microbes, [top_ibd_abs_L_5_n_300_stats, top_half_abs_L_5_n_300_stats, top_sh_abs_L_5_n_300_stats], ["IBD", "Halfvarson", "Schirmer"])
for x in all_overlap_atts:
    print('\t'.join([str(y) for y in x]))





Top Schirmer Microbes Validated on IBD Microbe Average Attributions
                 Attribution  STD      N     Rank                      Attribution  STD      N     Rank                      Attribution  STD      N     Rank   
2074	['IBD', '0.00264', '0.00023', '6', '15  ']	['Halfvarson', '0.00069', '0.00006', '29', '322']	['Schirmer', '0.00597', '0.00070', '10', '2  ']
506	['IBD', '0.00235', '0.00050', '5', '32  ']	['Halfvarson', 'missing', 'missing', '0', '---']	['Schirmer', '0.00100', '0.00020', '17', '119']
1205	['IBD', '0.00235', '0.00030', '10', '33 ']	['Halfvarson', '0.00093', '0.00010', '19', '215']	['Schirmer', '0.00515', '0.00241', '5', '7   ']
814	['IBD', '0.00230', '0.00020', '16', '39 ']	['Halfvarson', 'missing', 'missing', '0', '---']	['Schirmer', '0.00315', '0.00069', '5', '27  ']
678	['IBD', '0.00201', '0.00021', '21', '59 ']	['Halfvarson', '0.00053', '0.00007', '15', '405']	['Schirmer', '0.00306', '0.00042', '5', '29  ']
891	['IBD', '0.00173', '0.00015', '16', '91

In [28]:
# Print FASTA sequences for top and bottom microbes
fastas = []
for l in open("/path/to/seqs_.07_embed.fasta", "r").readlines():
    if len(l) > 20:
        fastas.append(l)


In [30]:
# Print FASTA sequences for top 10 microbes
for id in top_10_microbes:
    print("ID: " + str(id) + ", FASTA: " + fastas[id][:-1])
print()

# Print FASTA sequences for bottom 10 microbes
for id in bottom_10_microbes:
    print("ID: " + str(id) + ", FASTA: " + fastas[id][:-1])

ID: 2602, FASTA: TACGTAGGTGGCGAGCGTTGTCCGGAATTACTGGGTGTAAAGGGTGCGTAGGCGGGGATGCAAGTCAGATGTGAAATCTATCGGCTTAACTGGTAAACTGCATTTGAAACTGCATTTCTTGAGTGGTGGAGAGGTAAGCG
ID: 1054, FASTA: CACCGGCAGCTCAAGTGGTAGCTGTTTTTATTGGGCCTAAAGCGTTCGTAGCCGGTTTGATAAGTCTTTGGTGAAAGCTTGTAGCTTAACTATAAGAATTGCTGAAGATACTGTCAGACTTGAAGTCGGGAGAGGTTAGA
ID: 5969, FASTA: TACGTAGGTGGCAAGCGTTGTCCGGATTTACTGGGTGTAAAGGGCGAGTAGGCGGGACGGAAAGTCAGTAGTGAAATACCGAGGCTTAACTTCGGGGCTGCTATTGAAACTTCTGTTCTTGAGTGATGGAGAGGCAGGCG
ID: 2139, FASTA: TACGTAGGGGGCAAGCGTTATCCGGAATTACTGGGTGTAAAGGGTGCGTAGGCGGCCCGGCAAGTTTGATGTGAAACCCATAGGCTTAACCTGTGGCATGCATCAAAAACTACCGAGCTAGAGTGCAGGAGAGGAAAGCG
ID: 2721, FASTA: TACGTAGGGAGCGAGCGTTGTCCGGATTTACTGGGTGTAAAGGGCGTGCAGCCGGGCTGGTAAGTCAGATGTGAAATCCGTGGGCTTAACCCACGAACTGCATTTGAAACTGCTGGTCTTGAGTACCGGAGAGGTTATCG
ID: 1988, FASTA: TACGTAGGTGGCAAGCGTTGTCCGGAATTACTGGGTGTAAAGGGAGCGTAGGCGGGAGTGCAAGTTGAATGTGAAAACGATGGGCTCAACCCATCGTTGCGTTCAAAACTGCATTTCTTGAGTGAAGTAGAGGTAAGCGG
ID: 506, FASTA: TACGTAGGGGGCAAGCGTTATCCGGATTCATTGGGC