In [None]:
# Prefix of the workflow output files that are loaded further down
# (e.g. '../data/output/mapping/<prefix>_KrakenFullDump.csv').
PROJECTPREFIX = 'manchot'

In [None]:
# Required packages / imports

import pandas as pd
import itertools as it
import math
import altair as alt
import os
import numpy as np
import colorcet as cc
import matplotlib.colors as colors
import glob
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from scipy.stats import mannwhitneyu, wilcoxon, pearsonr
from functools import lru_cache,reduce
from statsmodels.stats.proportion import proportion_confint
import scipy.spatial.distance as ssd
import scipy.cluster.hierarchy as sch
import matplotlib.pyplot as plt
#from skbio.stats import ordination

os.makedirs('Output',exist_ok=True)


### Table of Contents

* [Definitions / Helper Functions](#defines)
* [Load Data](#loaddata)
* [Normalization Function](#normalize)
* Analysis
    * [Top Level Statistics](#toplevel)
    * [P-Values Top Level Statistics](#pvalues)
    * [Overview Top Taxa per Level](#toptaxa)
    * [Unmapped and Human Fraction](#human)
    * [AMR Gene Counts](#amr)
    * [Taxa Counts (Diversity)](#taxacount)
    * [Barplots Compositions](#barplots)
    * [Bray-Curtis Based PCoA](#pcoa)
    * [Sampling Times Overview](#sampletimes)
    * [Zymo Std Analysis](#zymo)
    * [Eukaryotes Analysis](#eukaryotes)
    * [Crassphage Analysis](#crassphage)
    * [Population Level Analysis / Illumina](#strains)



# Definitions / Helper Functions <a class="anchor" id="defines"></a>

In [None]:
def generatePrePostPairs(table : pd.DataFrame, patientcolumn: str = 'patientid' , timecolumn: str ='time' ):
    """Pair every negative timepoint of a patient with every positive one.

    Parameters
    ----------
    table : DataFrame holding at least `patientcolumn` and `timecolumn`.
    patientcolumn : name of the column identifying the patient.
    timecolumn : name of the column holding the (numeric) timepoint.

    Returns
    -------
    list of (patient, negative_time, positive_time) tuples. Timepoints equal
    to 0 are on neither side of the comparison and never appear in a pair.
    """
    # One row per distinct (patient, time) combination present in the table
    observed = table.groupby([patientcolumn,timecolumn]).size().reset_index()

    pairs = []
    for patient_id in observed[patientcolumn].unique():
        patient_times = observed.loc[observed[patientcolumn] == patient_id, timecolumn].unique()
        for first, second in it.combinations(patient_times, 2):
            if first < 0 < second:
                pairs.append((patient_id, first, second))
            elif second < 0 < first:
                # Always report the negative timepoint first
                pairs.append((patient_id, second, first))
    return pairs


#Helper function: latest timepoint of the same patient that lies before the given row
def find_previous_sample(table,row):
    """Return the closest earlier timepoint of the row's patient, or None.

    `table` needs the columns 'PatID' and 'time'; `row` needs the keys
    'PatID' and 'time'. Only timepoints strictly smaller than row['time']
    are considered.
    """
    patient_times = table.loc[table['PatID']==row['PatID'], 'time'].unique()
    earlier = [t for t in patient_times if t < row['time']]
    return max(earlier) if earlier else None

#Helper function that relabels a fixed set of sample names
def adjust_table(df):
    """Rename known mislabelled samples in df['samplename'].

    The frame is modified in place and nothing is returned.
    """
    renamed_samples = (
        ('999_0', 'G1_-100'),
        ('132_-100', '45_-100'),
        ('133_558', '14_558'),
        ('134_-5', '46_-5'),
    )
    for old_name, new_name in renamed_samples:
        df.loc[df['samplename'] == old_name, 'samplename'] = new_name

def sort_samples(samples):
    """Return the sample names ('<patient>/<time>') as a new sorted list.

    Sort order: names without a 'G' first, then by the integer in front of
    the '.' in the patient ID, then by the integer behind the last '.', then
    by time. A patient ID without '.' uses the same number for both ID keys
    (so '18.2' sorts before '18').
    """
    def sample_key(sample):
        parts = sample.split('/')
        id_parts = parts[0].replace('G','').split('.')
        return ('G' in sample, int(id_parts[0]), int(id_parts[-1]), int(parts[1]))

    return sorted(samples, key=sample_key)


def is_below_or_equal(x,y):
    """Tell whether taxon `x` is `y` itself or sits below `y`.

    Walks upwards through the module-level `taxonomy` mapping (child ID ->
    parent ID) until `y` is met (True) or a node that is its own parent is
    reached (False). The pseudo IDs '-2' and '-3' are never below anything.
    """
    node = x
    while True:
        if node in ('-2','-3'):
            return False
        if node == y:
            return True
        parent = taxonomy[node]
        if parent == y: #Parent was the node we looked for
            return True
        if parent == node: #This is only the case at the root node
            return False
        node = parent #We need to keep looking

#Constants, frequently used
# Taxon IDs are kept as strings because the 'taxonid' column of the Kraken
# table is loaded with dtype str.
# NOTE(review): presumably these are NCBI taxonomy IDs - confirm against Input/Taxonomy.
CHORDATA = '7711'
FUNGI = '4751'
VIRUSES = '10239'
BACTERIA = '2'
ARCHAEA = '2157'
EUKARYOTA = '2759'
PROTOZOA = '1891100'
FIRMICUTES = '1239'
BACTEROIDETES = '976'
ACTINOBACTERIA = '201174'
PROTEOBACTERIA = '1224'
VERRUCOMICROBIA = '74201'
BACTEROIDES = '816'
PREVOTELLA = '838'
ALISTIPES = '239759'
PARABACTEROIDES = '375288'
PLANTAE = '33090'
HOMO='9605'

## Custom Taxon IDs
# Written into the 'taxonid' column by apply_validation for rows that fall
# below the read count cutoff / fail validation. Note that these are ints
# while the regular taxon IDs are strings.
NOT_ENOUGH_READS = -103
NOT_VALIDATED = -104

# NOTE(review): not used in the visible part of the notebook; presumably a day
# offset for the second course of patient 18 - confirm.
OFFSET_PAT_18 = 161

# Sample names in '<patient>/<time>' notation. The loaded tables use '_' as
# separator instead, so the names are converted with .replace('/','_') before
# they are used as filters.
PAPER_SAMPLES = ['G1/-100','G2/-100','G3/-100','G4/-100','G5/-100','G6/-100',
             'G7/-100','G8/-100','G9/-100','G10/-100','G11/-100',
             '4/-7','7/-6','11/-2','12/-6','14/-7','15/-6','16/-5','17/-7','19/-2',
             '21/-9','24/-10','28/-5','29/-5','32/-5','34/-6','38/-6','39/-6','42/-8',
             '13/-7','18/-3','22/-4','23/-9','25/-62','37/-5','40/-1','41/-6','20/-4',
             '26/-3','31/-6','33/-2','36/-6',
             '4/1','11/0','15/6','16/2','17/9','17/13','24/0','24/9','28/1','29/1','34/10','39/1','42/6',
             '18/13','22/0','23/12','25/1','37/1','41/14','26/12','31/6','33/7',
             '15/16','28/21','29/15','34/18','38/15','39/15','42/18','42/30',
             '25/15','37/17','37/22','40/17','26/27',
             '12/34','24/49','28/63','34/41',
             '18/91','21/50','22/91','23/61','25/31','31/58','33/47','36/38','36/63',
             '11/182','12/342','14/558','15/104','15/189','16/163','16/171','17/154',
             '19/153','24/127','28/237','29/186','39/115',
             '13/298','18.2/-9','18.2/16','18.2/71','21/120','23/130','37/138','26/185','31/178']

# All PAPER_SAMPLES followed by a selection of the newer samples.
# The paper part used to be a literal copy of PAPER_SAMPLES; it is now derived
# from that list so both can no longer drift apart. The resulting list has the
# same elements in the same order as before.
PAPER_AND_NEW_SAMPLES = PAPER_SAMPLES + [
    '11/672', '12/711',
    '15/565', '16/542', '17/530', '19/581', '23/440',
    '24/416', '25/322', '26/394', '28/360',
    '33/278', '34/296', '37/292', '38/256', '39/255', '40/236', '41/161',
    '42.3/164', '43/-7', '44/-6', '44/12', '44/28', '46/-5', '47/-3',
    '47.3/48', '48/-8', '48/3', '49/-3', '49/4', '50/-1', '50/1', '50/22',
    '52/-1', '53/2', '54/-7', '54/2', '56/-7', '53/-6',
]

# Samples that are not part of PAPER_SAMPLES, in '<patient>/<time>' notation.
# NOTE(review): '27/174' is listed twice - harmless for .isin() filters, but
# code that iterates over the list will see the sample twice. Confirm whether
# the second entry should be a different sample.
NEW_SAMPLES = [ '4/29', '11/672', '12/711',
              '15/565', '16/24', '16/83', '16/542', '17/-1', '17/530', '17.2/154', '19/581', '20/-1', '21/-3', '23/440',
                '24/416', '25/322', '26/394', '27/2', '27/14', '27/174', '27/174', '27/234', '27/379', '28/360', '32/3',
               '33/278', '34/296', '36/-4', '37/292', '38/256', '39/255', '40/236', '41/161', 
              '42.3/164', '43/-7', '44/-6', '44/12', '44/28', '45/-100', '46/-5', '47/-3', 
                '47.3/48', '48/-8', '48/3', '49/-3', '49/4', '50/-1', '50/1', '50/22', '51/5',
               '52/-1','53/-6', '53/2', '54/-7', '54/2', '55/2', '56/-7']

# All PAPER_SAMPLES followed by every newer sample, including the replicate
# IDs 42.1-42.4 and 47.1-47.4.
# The paper part used to be a literal copy of PAPER_SAMPLES; it is now derived
# from that list so both can no longer drift apart. The resulting list has the
# same elements in the same order as before.
# NOTE(review): '27/174' appears twice and '53/-6' (present in NEW_SAMPLES and
# PAPER_AND_NEW_SAMPLES) is missing here. Both are kept as found - confirm
# whether that is intended.
ALL_SAMPLES = PAPER_SAMPLES + [
    '4/29', '11/672', '12/711',
    '15/565', '16/24', '16/83', '16/542', '17/-1', '17/530', '17.2/154', '19/581', '20/-1', '21/-3', '23/440',
    '24/416', '25/322', '26/394', '27/2', '27/14', '27/174', '27/174', '27/234', '27/379', '28/360', '32/3',
    '33/278', '34/296', '36/-4', '37/292', '38/256', '39/255', '40/236', '41/161', '42.1/164', '42.2/164',
    '42.3/164', '42.4/164', '43/-7', '44/-6', '44/12', '44/28', '45/-100', '46/-5', '47/-3', '47.1/48',
    '47.2/48', '47.3/48', '47.4/48', '48/-8', '48/3', '49/-3', '49/4', '50/-1', '50/1', '50/22', '51/5',
    '52/-1', '53/2', '54/-7', '54/2', '55/2', '56/-7',
]

# Subset of PAPER_SAMPLES.
# NOTE(review): presumably the samples that were sequenced more than once - confirm.
DUPLICATES = ['15/104',
 '16/163',
 '18.2/71',
 '21/50',
 '23/61',
 '29/15',
 '31/58',
 '34/10',
 '34/18',
 '36/38',
 '37/17',
 '42/18']


# The 'G<n>/-100' control samples (the first eleven entries of PAPER_SAMPLES).
HEALTHY_SAMPLES = ['G1/-100','G2/-100','G3/-100','G4/-100','G5/-100','G6/-100',
             'G7/-100','G8/-100','G9/-100','G10/-100','G11/-100']

# Sample names of the Zymo standard and stool runs.
# NOTE(review): the middle part of the name (Zymo/EZ1/Power/Pro) presumably
# identifies the extraction kit - confirm.
ZYMO_SAMPLES_NEW = ['Zymo_Zymo/-100','Zymo_EZ1/-100','Zymo_Power/-100','Zymo_Pro/-100']+['Stool_Power_11/-100',
 'Stool_Power_21/-100',
 'Stool_Pro_11/-100',
 'Stool_Pro_21/-100',
 'Stool_Zymo_11/-100',
 'Stool_Zymo_21/-100',
 'Stool_EZ1_12/-100',
 'Stool_EZ1_21/-100']

# Sample identifiers of the external "lifelines" data.
# NOTE(review): presumably these match 'samplename' in Input/KrakenLifelines.csv - confirm.
LIFELINES = ['LL81_E07_7627', 'LL91_A04_8564',
       'LL93_H07_8785', 'LL87_A09_8215', 'LL80_A10_7551',
       'LL47_D03_4330', 'LL63_B04_5872', 'LL83_F03_7788',
       'LL66_C11_6217', 'LL80_D01_7482', 'LL92_D01_8637',
       'LL59_F11_5548', 'LL83_G11_7853', 'LL72_C02_6721',
       'LL84_D03_7882', 'LL45_F06_4164', 'LL89_E11_8427',
       'LL60_B04_5584', 'LL45_F11_4204', 'LL70_E06_6539'
    
]

# Genus names that are of special interest for the analysis (matched by name).
MARKER_GENERA = [
    'Prevotella',
    'Enterobacter',
    'Bacteroides',
    'Enterococcus',
    'Faecalibacterium',
    'Klebsiella',
    'Escherichia',
    'Akkermansia',
    'Clostridium',
    'Blautia',
    'Skunavirus',
    'Bifidobacterium',
    'Eubacterium',
    'Ruminococcus'
]


# Species names that are of special interest for the analysis (matched by name).
MARKER_SPECIES = [
    'Akkermansia muciniphila',
    'Bacteroides thetaiotaomicron',
    'Bacteroides ovatus',
    'Faecalibacterium prausnitzii',
    'Enterococcus faecium',
    'Enterococcus faecalis',
    'Enterocloster bolteae',
    'Candida albicans',
    'Saccharomyces cerevisiae',
    'Toxoplasma gondii',
    'Blautia wexlerae',
    'Eubacterium biforme',
    '[Ruminococcus] gnavus',
    'uncultured crAssphage'
]



# Load Data <a class="anchor" id="loaddata"></a>

## Kraken2

In [None]:
#This data can be collected from the snakemake workflow
kraken_dataframe = pd.read_csv(
    '../data/output/mapping/{}_KrakenFullDump.csv'.format(PROJECTPREFIX),
    usecols=[1,2,3,4,5],dtype={'samplename' : str,'taxonid':str}
)
adjust_table(kraken_dataframe)

lifelines_data = pd.read_csv(
    'Input/KrakenLifelines.csv',
    usecols=[0,1,2,3,4],dtype={'samplename' : str,'taxonid':str}
)
# Fail with a real exception: an assert would be skipped when Python runs with -O
if 'patientid' in lifelines_data.columns:
    raise ValueError(
        'The lifelines data has an old format, the columns patientid and time '
        'need to be combined to a new column with format {patientid}_{time}'
    )
kraken_dataframe = pd.concat([kraken_dataframe,lifelines_data])

# Duplicate the crAssphage species rows as a pseudo genus so that they also
# show up in genus level ('G') analyses. The .copy() makes the subset an
# independent frame, so the assignments below neither trigger a
# SettingWithCopyWarning nor depend on pandas' view/copy semantics.
crassphages_species = kraken_dataframe[kraken_dataframe['taxon'] == 'uncultured crAssphage'].copy()
crassphages_species['level'] = 'G'
crassphages_species['taxon'] = 'Crassphage Pseudo-Genus'

kraken_dataframe = pd.concat([kraken_dataframe,crassphages_species])


kraken_dataframe

## Load Sample Annotations

In [None]:
#PatID is specified as String (Text) so IDs like 18.2 don't get confused as decimal numbers

#Tables contain annotations regarding the individual samples
sample_statistics = pd.read_excel('Input/Annotations/SampleStatistics.xlsx',dtype={'PatID' : str})

#Outcomes
# NOTE(review): the dtype key here is 'Pat ID' (with a space) while the sample
# table uses 'PatID'. pandas silently ignores dtype keys that match no column,
# so verify the header of PatientStatistics.xlsx.
outcomes = pd.read_excel('Input/Annotations/PatientStatistics.xlsx',dtype={'Pat ID' : str})

## AMR Detection

In [None]:
# The loaded tables name samples '<patient>_<time>' instead of '<patient>/<time>'
PAPER_SAMPLES_UNDERSCORE = [s.replace('/','_') for s in PAPER_SAMPLES]

#can be extracted from the snakemake workflow, used to normalize the AMR read counts
# The file name is built from PROJECTPREFIX (like the Kraken dump) instead of
# a hard-coded 'manchot', so changing the prefix switches all inputs together.
sequencing_stats = pd.read_csv(
    '../data/output/{}_sampleStats.reads.filtered.2.csv'.format(PROJECTPREFIX)
)[[
    'samplename','nReads','nBases','median','mean','standard deviation','minimum','maximum'
]]


adjust_table(sequencing_stats)
sequencing_stats = sequencing_stats[sequencing_stats['samplename'].isin(PAPER_SAMPLES_UNDERSCORE)]

# Smallest read count among the paper samples
readcount_min = sequencing_stats['nReads'].min()

# Indexed by sample name so that calc_fraction can look up nReads per sample
sequencing_stats = sequencing_stats.set_index(['samplename'])
sequencing_stats

In [None]:
def calc_fraction(x):
    """Return the row's read count divided by the total reads of its sample.

    `x` is a row with the keys 'samplename' and 'readcount'. The total is taken
    from the 'nReads' column of the module-level `sequencing_stats` table
    (indexed by sample name). Samples without sequencing statistics yield NaN.
    """
    if x['samplename'] in sequencing_stats.index:
        return x['readcount'] / sequencing_stats.loc[x['samplename']]['nReads']
    else:
        # np.nan instead of np.NaN: the capitalised alias was removed in NumPy 2.0
        return np.nan

#can be extracted from the snakemake workflow
# The file name is built from PROJECTPREFIX (like the Kraken dump) instead of
# a hard-coded 'manchot'.
amr = pd.read_csv(
    '../data/output/amr/{}_fulldump_amr.csv'.format(PROJECTPREFIX)
).rename(columns={'Samplename':'samplename'})

# Every row is weighted with 1/'Ambiguous Assignments'; rows where that value
# is 0 get the weight 1.
ambiguous_assignments = amr['Ambiguous Assignments']
amr['readcount'] = 1 / ambiguous_assignments.where(ambiguous_assignments != 0, 1)

adjust_table(amr)
# .copy() turns the filtered subset into an independent frame, so adding a
# column below does not trigger a SettingWithCopyWarning
amr = amr[amr['samplename'].isin(PAPER_SAMPLES_UNDERSCORE)].copy()

amr['Gene Pro Read'] = amr.apply(calc_fraction, axis=1)

amr

### Automatically update the sample statistics

In [None]:
# Per-sample ARG summary:
#   'ARG Reads'                 - sum of the weighted AMR read counts
#   'ARG Input Reads'           - number of reads of the sample (nReads)
#   'ARG Reads per 10000 Reads' - sum of the per-read fractions, scaled to 10000 reads
# NOTE(review): 'ARG Input Reads' used to be multiplied by 10000 as well, which
# looks like a copy-paste of the scaling that only belongs to the
# 'per 10000 Reads' column. It now holds the plain read count - please confirm
# that nothing downstream relied on the scaled value.
amr_stats = pd.concat(
    [
        amr.groupby('samplename')['readcount'].sum().rename('ARG Reads'),
        sequencing_stats.groupby('samplename')['nReads'].sum().rename('ARG Input Reads'),
        (amr.groupby('samplename')['Gene Pro Read'].sum()*10000).rename('ARG Reads per 10000 Reads')
    ],axis=1
).reset_index()
amr_stats

In [None]:
# Replace the ARG columns of the sample annotation table with the freshly
# computed values and write the table back to the annotation workbook.
arg_columns = ['ARG Reads','ARG Input Reads','ARG Reads per 10000 Reads']
stale_columns = [column for column in arg_columns if column in sample_statistics.columns]
for column in stale_columns:
    print('The column {} was already found in the table and will be replaced!'.format(column))

# 'samplename' only serves as a temporary merge key and is removed again
sample_statistics = (
    sample_statistics
    .drop(columns=stale_columns)
    .assign(samplename=lambda frame: frame['PatID']+'_'+frame['time'].astype(str))
    .merge(amr_stats,on='samplename',how='left')
    .drop(columns='samplename')
)
sample_statistics.to_excel('Input/Annotations/SampleStatistics.xlsx',index=False)

## Kraken2/Metamaps Taxonomy

In [None]:
def _dmp_records(path):
    """Yield the whitespace-stripped '|'-separated fields of every line in `path`."""
    with open(path,'r') as handle:
        for line in handle.read().splitlines():
            yield [field.strip() for field in line.split('|')]


# Taxon ID -> name; only the entries flagged as 'scientific name' are used
idtonames = {
    fields[0]: fields[1]
    for fields in _dmp_records('Input/Taxonomy/names.dmp')
    if fields[3] == 'scientific name'
}

# Taxon ID -> parent taxon ID (first and second field of nodes.dmp)
taxonomy = {}

# Taxon ID -> third field of nodes.dmp
levels = {}

for fields in _dmp_records('Input/Taxonomy/nodes.dmp'):
    taxonomy[fields[0]] = fields[1]
    levels[fields[0]] = fields[2]

## Zymo Theory

In [None]:
# Theoretical composition of the Zymo standard (headerless two-column CSV)
zymo_theory = pd.read_csv(
    'Input/zymo_theory.csv',
    header=None,
    names=['Taxon','Read Fraction'],
)

zymo_theory_taxa = zymo_theory['Taxon'].tolist()

# NOTE(review): the '(%)' column is the read fraction divided by 100; whether
# the result is meant as a percentage depends on the unit of the input file
# and on how the column is formatted when plotted - confirm.
zymo_theory = zymo_theory.assign(**{
    'Read Fraction (%)': zymo_theory['Read Fraction']/100,
    'Sample ID': 'Zymo Theoretical Composition',
})

zymo_minimap = pd.read_csv('Input/zymo_minimap_verification.csv')

# Validation

In [None]:
# Validation tables keyed by Kraken level code ('S' = species, 'G' = genus)
validation_data = {} #Multilevel

for validation_level, kraken_key in {
    'species' : 'S',
    'genus' : 'G'
}.items():

    validation_path = '../data/output/validation/{}_{}.csv'.format(PROJECTPREFIX,validation_level)

    if not os.path.exists(validation_path):
        print('Could not load validation data for project prefix {} and level {}, check if this is ok'.format(PROJECTPREFIX,validation_level))
        continue

    level_table = pd.read_csv(validation_path,dtype={'Taxon ID':str}).rename(columns={'Sample' : 'samplename'})
    adjust_table(level_table)

    # Split '<patient>_<time>' at the LAST underscore. `n` has to be passed by
    # keyword: positional use was deprecated in pandas 1.4 and raises a
    # TypeError since pandas 2.0.
    level_table[['patientid','time']] = level_table['samplename'].str.rsplit('_',n=1,expand=True) #specific for manchot project
    level_table['time'] = level_table['time'].astype(int)
    level_table['Taxon Name'] = level_table['Taxon ID'].map(idtonames)

    validation_data[kraken_key] = level_table

# Normalization Function <a class="anchor" id="normalize"></a>

In [None]:
def get_normalized_abundances(
    kraken_dataframe,
    level, #No default value here to avoid accidental mistakes!
    samples = None,
    random_seed = (4+8+15+16+23+42),
    normalize=True,
    excluded_taxa_filter = None, #Can be a list of taxa
    included_taxa_filter = None, #Only one taxa, becomes new root
):
    """Extract the composition at one taxonomic level, optionally downsampled.

    Parameters
    ----------
    kraken_dataframe : DataFrame with the columns 'readcount', 'taxon',
        'taxonid', 'level' and 'samplename'.
    level : Kraken level code to extract (e.g. 'S' or 'G').
    samples : optional iterable of sample names to restrict the table to.
        Every listed sample without assigned reads gets one dummy row.
    random_seed : seed for the weighted resampling.
    normalize : if True every sample is resampled (with replacement, weighted
        by read count) to the smallest per-sample read total.
    excluded_taxa_filter : optional list of taxon IDs whose subtrees are removed.
    included_taxa_filter : optional single taxon ID; only its subtree is kept
        and its read count serves as the 100% reference.

    Returns
    -------
    DataFrame with one row per (sample, taxon) plus one 'Unassigned at Level'
    row (taxonid '-2') per sample, holding the reads of the reference taxon
    that are not assigned at `level`.

    Notes
    -----
    Uses the module-level `taxonomy` (child -> parent) and `idtonames`
    mappings whenever a taxon filter is given.
    """

    ### Helper Functions for Filtering

    @lru_cache(maxsize=None)
    def is_not_below(x,y):
        parent_node = taxonomy[x]
        if parent_node == y: #Parent was the node we looked for
            return False
        elif parent_node == x: #This is only the case at the root node
            return True
        else: #We need to keep looking
            return is_not_below(parent_node,y)

    def filter_function(row,taxon):
        # True only for rows that are known to lie OUTSIDE the subtree of
        # `taxon`; the root node, unknown IDs and the taxon itself yield False.
        if row['taxonid'] == '0': #Root Node
            return False
        if row['taxonid'] not in taxonomy:
            return False
        if row['taxonid'] == taxon: #This is the taxon itself, remove
            return False
        #Otherwise we will check if the taxon is below our target taxon in the taxonomy
        return is_not_below(row['taxonid'],taxon)

    def outside_mask(table,taxon):
        # Row-wise filter_function as boolean Series. DataFrame.apply(axis=1)
        # returns a DataFrame for an empty input, so that case is handled here.
        if table.empty:
            return pd.Series(False,index=table.index,dtype=bool)
        return table.apply(lambda row : filter_function(row,taxon),axis=1)

    def reads_per_sample(table):
        return table.groupby(['samplename'])['readcount'].sum()

    ###

    #We begin by assuming the raw kraken dataframe as input
    working_table = kraken_dataframe

    #####################
    #   SELECT SAMPLES
    #
    #####################

    #Filter to target samples
    if samples is not None:
        working_table = working_table[working_table['samplename'].isin(samples)]

    level_rows = working_table[working_table['level'] == level]

    #####################
    #   CALCULATE UNASSIGNED READS
    #
    #####################

    #Determine root readcounts
    if included_taxa_filter is not None:
        root_readcounts_kr = reads_per_sample(
            working_table[working_table['taxonid']==included_taxa_filter]
        )
    else: #If we don't filter we take all reads -> R
        root_readcounts_kr = reads_per_sample(
            working_table[working_table['level']=='R']
        )

    if excluded_taxa_filter is not None:
        excluded_root_reads = reads_per_sample(
            working_table[working_table['taxonid'].isin(excluded_taxa_filter)]
        )
        # fill_value=0: a sample without reads in the excluded taxa (always the
        # case for an empty exclusion list) keeps its count. A plain `-=`
        # aligns both indexes and turns such samples into NaN.
        root_readcounts_kr = root_readcounts_kr.sub(excluded_root_reads,fill_value=0)

    #Determine total read counts at level
    if included_taxa_filter is not None:
        outside = outside_mask(level_rows,included_taxa_filter)
        readcounts_at_level = reads_per_sample(level_rows[~outside])
    else:
        readcounts_at_level = reads_per_sample(level_rows)

    if excluded_taxa_filter is not None:
        for taxon in excluded_taxa_filter:
            outside = outside_mask(level_rows,taxon)
            excluded_at_level = reads_per_sample(level_rows[~outside])
            # See above: missing samples count as 0 instead of producing NaN
            readcounts_at_level = readcounts_at_level.sub(excluded_at_level,fill_value=0)

    #Add unassigned at level
    unassigned_rows = [
        (
            root_readcounts_kr[sample_id]-readcounts_at_level[sample_id],
            'Unassigned at Level','-2',level,sample_id
        )
        for sample_id in readcounts_at_level.keys()
    ]

    #Combine into one table
    if unassigned_rows:
        unassigned_table = pd.DataFrame(unassigned_rows,columns=[
            'readcount','taxon','taxonid','level','samplename'
        ])
        working_table = pd.concat([unassigned_table,working_table])

    #####################
    #   DOWNSAMPLING
    #
    #####################

    #Filter to target level
    downsampled_table = working_table[working_table['level'] == level]

    #Filter for taxon if required
    if included_taxa_filter is not None:
        print('Reducing composition to subtree below taxon: {}'.format(idtonames[included_taxa_filter]))
        outside = outside_mask(downsampled_table,included_taxa_filter)
        downsampled_table = downsampled_table[~outside]

    if normalize:

        # Identify lowest read count
        read_totals = downsampled_table.groupby('samplename')['readcount'].sum()
        minimal_read_count = read_totals.min()
        print('The minimal read count across all samples is [Taxonomic Level {}]: {}'.format(level,minimal_read_count))

        frames = []

        # Draw new counts for each sample
        for sample_id in downsampled_table['samplename'].unique():

            sample_table = downsampled_table[downsampled_table['samplename'] == sample_id]
            drawn_reads = sample_table.sample(
                n=round(minimal_read_count),
                random_state=random_seed,
                weights='readcount',
                replace=True
            )

            # Each drawn row is one read; counting rows per taxon gives the new read count
            frames.append(
                drawn_reads.groupby([
                    'taxon',
                    'taxonid',
                    'samplename',
                    'level'
                ],as_index=False).count()
            )

        #Overwrite table with downsampled entries
        downsampled_table = pd.concat(frames)

    #####################
    #   FILTERING II
    #
    #####################

    unassigned_downsampled_table = downsampled_table[downsampled_table['taxonid'] == '-2']
    assigned_downsampled_table = downsampled_table[downsampled_table['taxonid'] != '-2']

    if excluded_taxa_filter is not None:
        for taxon in excluded_taxa_filter:
            keep = outside_mask(assigned_downsampled_table,taxon)
            assigned_downsampled_table = assigned_downsampled_table[keep]

    tables=[unassigned_downsampled_table,assigned_downsampled_table]

    #Create dummy entries for patients that have nothing
    if samples is not None:
        samples_with_rows = set(assigned_downsampled_table['samplename'].unique())
        for sample_id in samples:
            if sample_id in samples_with_rows:
                continue
            print('Creating a dummy entry for sample {} (No reads with the selected parameters)'.format(sample_id))
            # Exactly one dummy row per sample. The previous implementation
            # appended the growing list of dummies on every iteration and
            # thereby repeated earlier dummy rows.
            tables.append(pd.DataFrame([
                (0,'Absolutely Nothing','Nothing','N',sample_id)
            ],columns=[
                'readcount','taxon','taxonid','level','samplename'
            ]))
            samples_with_rows.add(sample_id)

    return pd.concat(tables)


## Function that applies validation with specified cutoffs

In [None]:
def apply_validation(
    normalized_kraken_dataframe,
    level = 'S',
    validation_cutoff = 0.2,
    readcount_cutoff = 5
):
    """Collapse low-abundance and non-validated taxa into two catch-all groups.

    Parameters
    ----------
    normalized_kraken_dataframe : table with the columns 'samplename',
        'taxonid', 'taxon' and 'readcount'.
    level : key into the module-level `validation_data` dict ('S' or 'G').
    validation_cutoff : minimal 'Validation Rate' for a taxon to count as validated.
    readcount_cutoff : rows with fewer reads become 'Not enough reads'.

    Returns
    -------
    DataFrame grouped by ('taxon', 'taxonid', 'samplename') holding the sums
    of the numeric columns ('readcount', 'Validation Rate', 'Validated').
    """
    with_validation = pd.merge(
        normalized_kraken_dataframe,
        validation_data[level][['samplename','Taxon ID','Validation Rate']],
        left_on=['samplename','taxonid'],
        right_on=['samplename','Taxon ID'],
        how='left'
    )

    #Phase 1: Kick out low abundance groups, assign to "Not enough reads"
    too_few_reads = with_validation['readcount'] < readcount_cutoff
    with_validation.loc[too_few_reads,'taxon'] = 'Not enough reads'
    with_validation.loc[too_few_reads,'taxonid'] = NOT_ENOUGH_READS

    #Phase 2: Check for rest if validates, if not assign to "Not validated"
    # Taxa without validation entry have a NaN rate and therefore fail the comparison
    with_validation['Validated'] = (with_validation['Validation Rate'] >= validation_cutoff)
    not_validated = (with_validation['Validated']!=True)&~(with_validation['taxon'] == 'Not enough reads')
    with_validation.loc[not_validated,'taxon'] = 'Not validated'
    with_validation.loc[not_validated,'taxonid'] = NOT_VALIDATED

    #Readjust sum (group multiple "Not enough reads" entries together)
    # numeric_only=True pins the behaviour of pandas 1.x, where text columns
    # were dropped from the sum. pandas >= 2.0 would otherwise try to sum the
    # text columns ('level', 'Taxon ID') too.
    return with_validation.groupby(['taxon','taxonid','samplename'],as_index=False).sum(numeric_only=True)

# Aggregate Species Validation Info for Top-Level Groups

In [None]:
# (group, sample, validated?) - one entry per top-level group and paper sample
top_level_tuples = []

# group name -> (taxon ID that becomes the root, taxon IDs excluded from its subtree)
validation_map = {
    'Viruses' : (VIRUSES,[]),
    'OtherEukaryota':(EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    'Fungi':(FUNGI,[]),
    'Bacteria':(BACTERIA,[]),
    'Plants':(PLANTAE,[]),
    'Archaea':(ARCHAEA,[]),
    'Microbiome' : ('1',[CHORDATA,PLANTAE]),
    'Human' : (CHORDATA,[])

}

# A species counts as validated from this validation rate on ...
SPECIES_VALIDATION_CUTOFF = 0.2
# ... and a group counts as validated if at least this share of its reads
# belongs to validated species.
GROUP_VALIDATED_READ_FRACTION = 0.8

#New sample format uses underscore '_' instead of '/'
paper_sample_names = [s.replace('/','_') for s in PAPER_SAMPLES]

for group, (TAXON,EXCLUDE) in validation_map.items():

    input_table = get_normalized_abundances(
        kraken_dataframe,
        level='S',
        samples=paper_sample_names,
        included_taxa_filter=TAXON,
        excluded_taxa_filter=EXCLUDE,
        normalize=False
    )

    with_validation = apply_validation(
        input_table,
        level = 'S',
        validation_cutoff = SPECIES_VALIDATION_CUTOFF,
        readcount_cutoff = 5
    )

    for samplename in paper_sample_names:

        subtable = with_validation[with_validation['samplename'] == samplename]

        # Kept as a separate Series instead of a new column on `subtable`:
        # assigning to a filtered slice triggers a SettingWithCopyWarning.
        validated = subtable['Validation Rate'].astype(float) >= SPECIES_VALIDATION_CUTOFF

        # Read-weighted share of validated species (NaN if the sample has no reads)
        weighted_validation_rate = (
            validated*subtable['readcount']
        ).sum()/subtable['readcount'].sum()
        top_level_tuples.append(
            (group,samplename,weighted_validation_rate >= GROUP_VALIDATED_READ_FRACTION)
        )

    


In [None]:
# Persist the per-group validation flags; the plotting cells read this file back
top_level_validation_table = pd.DataFrame(
    top_level_tuples,
    columns=['Group','Sample','Validated'],
)
top_level_validation_table.to_csv('Output/ValidationTopLevelGroups_MoreThan80.csv',index=False)

# Top Level Statistics & Visualization <a class="anchor" id="toplevel"></a>

Here we generate top-level plots for the different timepoints. 

In [None]:
#stats kraken
# 'readcount' is selected explicitly in every aggregation below. A bare
# groupby().sum() relied on pandas 1.x silently dropping the text columns; on
# pandas >= 2.0 it concatenates them instead, which breaks the merges and the
# percentage calculation.
root_and_unclassified = kraken_dataframe[kraken_dataframe['level'].isin(['R','U'])]

# One column per level: R (root, i.e. classified reads) and U (unclassified reads)
classified = root_and_unclassified.groupby(
    ['samplename','taxon','level'],as_index=False
)['readcount'].sum().pivot(
    index=['samplename'],columns=['level'],values=['readcount']
)

classified = classified.rename(columns={'R' : 'Classified','U' : 'Unclassified'})

classified.columns = classified.columns.get_level_values(1)

total = root_and_unclassified.groupby(['samplename'],as_index=False)['readcount'].sum()
total = total.rename(columns={'readcount' : 'Total'})

combined = pd.merge(total,classified,how='left',on=['samplename'])

# Add one read count column per top-level taxon, named after the taxon
for taxon in [CHORDATA,BACTERIA,FUNGI,VIRUSES,EUKARYOTA,ARCHAEA,PLANTAE]:
    taxonOnly = kraken_dataframe[
        kraken_dataframe['taxonid'] == taxon
    ].groupby(['samplename'],as_index=False)['readcount'].sum()
    taxonOnly = taxonOnly.rename(columns={'readcount' : idtonames[taxon]})
    combined = pd.merge(combined,taxonOnly,how='left',on=['samplename'])

combined = combined.set_index(['samplename'])

# Express every count column as percentage of the sample's total reads.
# list(...) takes a snapshot because columns are added while iterating.
for column in list(combined.columns):
    combined[column+'_Prozent'] = (combined[column]/combined['Total'])*100


combined.to_csv('stats_kraken.csv')
combined

In [None]:
top_level_validation = pd.read_csv('Output/ValidationTopLevelGroups_MoreThan80.csv')


#Some relabeling
# This is an alias, not a copy: the relabelled 'timephase' and the new
# 'samplename' column are written into sample_statistics itself.
sample_statistics_relabeling = sample_statistics

#rename healthy to control and harmonise the other phase labels
timephase_labels = {
    'Healthy' : 'Control',
    'Pre TX' : 'Pre-Tx',
    'Leukozytopenia' : 'Leukopenia',
}
sample_statistics_relabeling['timephase'] = sample_statistics_relabeling['timephase'].map(
    lambda label : timephase_labels.get(label,label))

sample_statistics_relabeling['samplename'] = sample_statistics_relabeling['PatID']+'_'+sample_statistics_relabeling['time'].astype(str)

domain_like_groups=['Microbiome','Human','Plants','Unclassified']
sample_keys = ['samplename','timephase']

# Long format: one row per sample and '<domain>_%' column
long_percentages = sample_statistics_relabeling[
    ['timephase',
     'Unclassified_%','Human_%','Plants_%','Microbiome_%',
     'samplename']
].melt(id_vars=sample_keys)

# '<domain>_%' -> '<domain>'
long_percentages['domain'] = long_percentages['variable'].str.split('_',expand=True)[0]

# Attach the validation flag of each (sample, domain); combinations without
# an entry in the validation file are labelled 'Not an outlier'
x = (
    long_percentages
    .pivot(index=sample_keys,columns='domain',values='value')
    .reset_index()
    .melt(id_vars=sample_keys)
    .merge(
        top_level_validation,
        left_on=['samplename','domain'],right_on=['Sample','Group'],how='left'
    )
    .fillna('Not an outlier')
)


# Display order of the time phases in all panels
TIMEPHASE_ORDER = ['Control', 'Pre-Tx','Leukopenia','Reconstitution']


def _timephase_boxplot(y_encoding):
    """Boxplot of one sample statistic, with one unlabelled column per time phase."""
    return alt.Chart(sample_statistics_relabeling,width=100,height=100).mark_boxplot(
        ticks=True,median={'color':'black'}
    ).encode(
        column=alt.Column(
            'timephase:O',
            header=alt.Header(labels=False),
            spacing=0,
            sort=TIMEPHASE_ORDER,
            title=None
        ),
        y=y_encoding,
    )


c1 = _timephase_boxplot(
    alt.Y('Microgram DNA per g Stool', axis=alt.Axis(grid=False,minExtent=40), title=['DNA [µg]','Per g Stool'])
)

c2 = _timephase_boxplot(
    alt.Y('Total Reads per DNA Library:Q',scale=alt.Scale(type='log'),
          axis=alt.Axis(grid=False,tickCount=10, format='.1e',minExtent=40),title='Total Reads')
)

c2_2 = _timephase_boxplot(
    alt.Y('Median Read Length:Q',scale=alt.Scale(type='log'),
          axis=alt.Axis(grid=False,tickCount=10, format='.1e',minExtent=40),title='Median Read Length')
)

# Read fractions per domain; this panel carries the time phase labels (at the bottom)
c3 = alt.Chart(x,width=100).mark_boxplot(ticks=True,median={'color':'black'}).encode(
    column=alt.Column(
        'timephase:O',
        spacing=0,
        sort=TIMEPHASE_ORDER,
        title=None,
        header=alt.Header(labelOrient='bottom', labelPadding=313)
    ),
     x=alt.X("domain:O", title=None, axis=alt.Axis(labels=False, ticks=False), 
             scale=alt.Scale(paddingInner=1), sort=domain_like_groups),    
    y=alt.Y("value:Q",title=['Fraction','[% Reads]'],
            scale=alt.Scale(type='symlog',domain=[0,100]),
            axis=alt.Axis(grid=False,minExtent=40,values=[0.1,0.5,1,5,10,20,50,100])), 
    # Everything whose 'Validated' flag is False is drawn in light grey
    color=alt.condition(
        (alt.datum['Validated'] != False),
        alt.Color("domain:N",sort=domain_like_groups,legend=alt.Legend(orient='right',
                                                                       symbolType='square',title='Reads',
                                                                      symbolStrokeWidth=0,symbolOpacity=1),
                    scale=alt.Scale(domain=domain_like_groups,
                                    range=['#dd1c77', '#fecc5c', '#006837', '#253494'])),
                alt.value('lightgrey')
    )
    ,tooltip=['samplename', 'value']
)


chart1=(c1&c2&c2_2&c3).configure_tick(thickness=2)

# The import cell only creates 'Output'; without the subfolder save() fails
os.makedirs('Output/Composition',exist_ok=True)
chart1.save('Output/Composition/High_Level_Metrics.html')

chart1

In [None]:
# Harmonise the time-phase labels used in the figures.
# NOTE: this is an alias, not a copy - the relabeling below is also visible
# through `sample_statistics`.
sample_statistics_relabeling = sample_statistics
timephase_labels = {
    'Healthy': 'Control',          # rename healthy to control
    'Pre TX': 'Pre-Tx',
    'Leukozytopenia': 'Leukopenia',
}
sample_statistics_relabeling['timephase'] = sample_statistics_relabeling['timephase'].map(
    lambda label: timephase_labels[label] if label in timephase_labels else label)


domain_like_groups=['Bacteria','Fungi','OtherEukaryota','Viruses','Archaea']
percent_columns = ['Bacteria_%','Fungi_%','OtherEukaryota_%','Viruses_%','Archaea_%']

# Long table: one row per sample and domain with its read percentage
z = sample_statistics_relabeling[['timephase'] + percent_columns + ['samplename']]
z = z.melt(id_vars=['samplename','timephase'])
# '<domain>_%' -> '<domain>'
z = z.assign(domain=z['variable'].str.split('_').str[0]).drop(columns=['variable'])
# Round trip through the wide format (one column per domain) and back to long
z = (
    z.pivot(index=['samplename','timephase'],columns='domain',values='value')
     .reset_index()
     .melt(id_vars=['samplename','timephase'])
)
# Attach the per sample/domain validation flags
z = pd.merge(
    z,
    top_level_validation,
    left_on=['samplename','domain'],
    right_on=['Sample','Group'],
    how='left',
)



phase_order = ['Control', 'Pre-Tx','Leukopenia','Reconstitution']


def _phase_column(header):
    """Column facet over the time phases shared by all panels of this figure."""
    return alt.Column(
        'timephase:O',
        spacing=0,
        sort=phase_order,
        title=None,
        header=header
    )


def _phase_boxplot(data, **chart_kwargs):
    """Boxplot chart with ticks and a black median line."""
    return alt.Chart(data, **chart_kwargs).mark_boxplot(ticks=True,median={'color':'black'})


# Panel 1: number of detected species per time phase
c5 = _phase_boxplot(sample_statistics_relabeling,width=100,height=100).encode(
    column=_phase_column(alt.Header(labels=False)),
    y=alt.Y(
        'Detected Species (Validated and Normalized)',
        axis=alt.Axis(grid=False,minExtent=40),
        title=['Detected Species','(Validated and Normalized)']
    )
)

# Panel 2: ARG-carrying reads per time phase
c6 = _phase_boxplot(sample_statistics_relabeling,width=100,height=100).encode(
    column=_phase_column(alt.Header(labels=False)),
    y=alt.Y(
        'ARG Reads per 10000 Reads:Q',
        title=['ARG-carrying Reads','Per 10,000 Reads'],
        scale=alt.Scale(type='symlog'),
        axis=alt.Axis(grid=False,minExtent=40)
    )
)

# Panel 3: read share per domain; entries whose 'Validated' flag is False are drawn in grey
domain_color = alt.Color(
    "domain:N",
    sort=domain_like_groups,
    legend=alt.Legend(title='Microbial Reads',orient='right',symbolType='square', symbolStrokeWidth=0,symbolOpacity=1),
    scale=alt.Scale(domain=domain_like_groups,range=['#EE6677','#fecc5c','#228833', '#66CCEE', '#AA3377'])
)

c7 = _phase_boxplot(z,width=100).encode(
    column=_phase_column(alt.Header(labelOrient='bottom', labelPadding=313)),
    tooltip=['samplename'],
    x=alt.X(
        "domain:O",
        title=None,
        axis=alt.Axis(labels=False, ticks=False),
        scale=alt.Scale(paddingInner=1),
        sort=domain_like_groups
    ),
    y=alt.Y(
        "value:Q",
        title=['Fraction','[% Reads]'],
        scale=alt.Scale(type='symlog',domain=[0,100]),
        axis=alt.Axis(grid=False,minExtent=40, values=[0.1,0.5,1,5,10,20,50,100])
    ),
    color=alt.condition(
        (alt.datum['Validated'] != False),
        domain_color,
        alt.value('lightgrey')
    )
)

# Stack the three panels vertically, store the figure and display it
chart2=(c5&c6&c7).configure_tick(thickness=2)

chart2.save('Output/Composition/High_Level_Metrics_2.html')

chart2

In [None]:
# Median read percentage of every domain within each time phase
domain_phase_groups = z.groupby(['domain','timephase'])
domain_phase_groups['value'].median()

### Correlation DNA / Reads

In [None]:
# Scatter plot: extracted DNA amount against the reads obtained per DNA library
dna_vs_reads = alt.Chart(sample_statistics).mark_point()
dna_vs_reads.encode(
    x=alt.X('Microgram DNA per g Stool'),
    y=alt.Y('Total Reads per DNA Library:Q')
)

### Threshold Criterion
A quick overview of how many samples pass a certain threshold of reads

In [None]:
# Print, for each read-count threshold, how many samples exceed it
for THRESHOLD in (10000,50000,100000):
    samples_above = sample_statistics[sample_statistics['Total Reads'] > THRESHOLD]
    print(THRESHOLD,len(samples_above))
    

# Top Level Statistics P-Values <a class="anchor" id="pvalues"></a>

We compare every pair of groups with respect to each property and run a Mann-Whitney U test to determine whether the two distributions differ (i.e. whether values in one group tend to be larger than in the other)

In [None]:
test_columns = ['Microgram DNA per g Stool',
        'Total Reads', 
       'Median Read Length',
       'Unclassified_%', 'Human_%',
                'Detected Species (Validated and Normalized)',
                'ARG Reads per 10000 Reads']

group_column = 'timephase'

# Samples without a group label cannot be assigned to any comparison
group_labels = sample_statistics[group_column].dropna().unique()

tuples_a = []
for test_column in test_columns:
    # Select each group's values once; missing values are dropped because they
    # would otherwise propagate into the test statistic.
    values_by_group = {
        label: sample_statistics.loc[
            sample_statistics[group_column] == label, test_column
        ].dropna()
        for label in group_labels
    }
    for group1,group2 in it.combinations(group_labels,2):
        sample1 = values_by_group[group1]
        sample2 = values_by_group[group2]
        if len(sample1) == 0 or len(sample2) == 0:
            # No data for one of the groups: report an undefined p-value instead of failing
            pvalue = float('nan')
        else:
            pvalue = mannwhitneyu(sample1,sample2).pvalue
        tuples_a.append((group1,group2,test_column,pvalue))
    
df_a = pd.DataFrame(tuples_a,columns=['Group A','Group B','Category','p-Value Whitney U'])
df_a = df_a.set_index(['Category','Group A','Group B'])

#Calculate Significances
SIGNIFICANCES = [0.05,0.01,0.001]
#We correct for the number of experiments (Bonferroni: divide by the number of tests)
for idx,significance in enumerate(SIGNIFICANCES):
    corrected_significance = significance/len(df_a)
    df_a['{} Significance {} (Corrected for test count: {:.2E})'.format(
        '*'*(idx+1),significance,corrected_significance
    )] = df_a['p-Value Whitney U'] < corrected_significance

df_a.to_csv('Output/Whitney_{}.csv'.format(group_column))
df_a

For the column which was tested we will also track Mean/Min/Max + Std. Deviation

In [None]:
# Must match a column of sample_statistics exactly; the read-length column is
# spelled 'Median Read Length' wherever else it is used in this notebook.
test_column = 'Median Read Length'

# One summary row (mean / std / min / max of the tested column) per group
tuples_b = []
for group in sample_statistics[group_column].unique():
    group_values = sample_statistics.loc[sample_statistics[group_column] == group, test_column]
    tuples_b.append((
        group,
        group_values.mean(),
        group_values.std(),
        group_values.min(),
        group_values.max()
    ))
df_b = pd.DataFrame(tuples_b,columns=['Group','Mean','Std. Deviation','Minimum','Maximum'])
df_b.to_csv('Output/MeanAndStd_{}_{}.csv'.format(test_column,group_column))
df_b

# Taxa Counts (Diversity) Per Level <a class="anchor" id="taxacount"></a>

In [None]:
SAMPLES = PAPER_SAMPLES
EXCLUDE = [CHORDATA, PLANTAE]
############################

os.makedirs('Output/Diversity',exist_ok=True)

# One output table per normalization setting, holding the number of taxa
# per sample on species ('S') and genus ('G') level.
for normalized in (True,False):
    charts = []

    for taxonomic_level in ('S','G'):
        table = apply_validation(
            get_normalized_abundances(
                kraken_dataframe,
                level=taxonomic_level,
                excluded_taxa_filter=EXCLUDE,
                #New sample format uses underscore '_' instead of '/'
                samples=[s.replace('/','_') for s in SAMPLES],
                normalize=normalized
            ),
            level=taxonomic_level,
            validation_cutoff=0.2,
            readcount_cutoff=5
        )

        # Rows with taxon id '-2' are not counted
        table = table[table['taxonid']!='-2']

        # NOTE: the column is suffixed 'Normalized' for both settings of `normalized`
        count_column = taxonomic_level + '_Normalized'
        taxaCounts = table.groupby(['samplename'],as_index=False)['taxonid'].count()
        taxaCounts = taxaCounts.rename(columns={'taxonid' : count_column})

        charts.append(taxaCounts)

    # Left-join the per-level tables on the sample name
    merged_counts = charts[0]
    for level_counts in charts[1:]:
        merged_counts = pd.merge(merged_counts,level_counts,how='left',on=['samplename'])

    merged_counts.to_csv(
        'Output/Diversity/taxa_counts_{}-Without_{}_Normalized={}.csv'.format(
            hash(str(SAMPLES)),
            str(EXCLUDE),
            normalized
        )
    )

### Update Sample Statistics Table with calculated diversity values

In [None]:
# Load the normalized species counts written by the diversity cell
diversity_file = 'Output/Diversity/taxa_counts_{}-Without_{}_Normalized={}.csv'.format(
    hash(str(SAMPLES)),
    [CHORDATA,PLANTAE],
    True
)
diversity = pd.read_csv(diversity_file)[['S_Normalized','samplename']]

species_count_column = 'Detected Species (Validated and Normalized)'

# Replace a previously stored column instead of duplicating it
if species_count_column in sample_statistics.columns:
    print('The column was already found in the table and will be replaced!')
    sample_statistics = sample_statistics.drop(columns=species_count_column)

# Temporary join key '<PatID>_<time>'
sample_statistics['samplename'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)
sample_statistics = pd.merge(sample_statistics,diversity,on='samplename',how='left')
sample_statistics = sample_statistics.drop(columns='samplename')
sample_statistics = sample_statistics.rename(columns={'S_Normalized' : species_count_column})

# Write the updated table back to the annotation file
sample_statistics.to_excel('Input/Annotations/SampleStatistics.xlsx',index=False)

Here we visualize the precalculated diversity values (calculated above)

In [None]:
# Species counts of the Pre-Tx samples, compared between the start clusters
is_pre_tx = sample_statistics['timephase']=='Pre-Tx'
table = sample_statistics[is_pre_tx]
median_per_cluster = table.groupby('Startcluster')['Detected Species (Validated and Normalized)'].median()
print(median_per_cluster)
alt.Chart(table).mark_boxplot().encode(
    x=alt.X('Startcluster:N'),
    y=alt.Y('Detected Species (Validated and Normalized)')
)

# Barplots Compositions <a class="anchor" id="barplots"></a>

In [None]:
# Configuration of the composition barplots.
# Number of top taxa shown individually (the selection below takes TOP_X+2 entries)
TOP_X = 30
# Taxonomic level passed to get_normalized_abundances
LEVEL = 'S'
# NOTE(review): not referenced in this cell - the read count cutoff is passed as a literal 5 below
DISCARD_CUTOFF = 5

#Filters
INCLUDE = None #Use 'None' to use root as top node
EXCLUDE = [CHORDATA,PLANTAE]

# Samples that are plotted
SAMPLES = PAPER_SAMPLES

def SORTFUNCTION(x):
    """Sort key for sample names of the form '[G]<patient>_<time>'.

    Returns a tuple (starts with 'G', patient id as float, time as int).
    Names that cannot be parsed get the key (True, 1, <checksum of the name>),
    so they still sort in a stable order that is identical in every kernel
    session (the built-in hash() of strings changes between sessions).
    """
    try:
        return (
            x[0] == 'G', #Sort by G or regular patient first
            float(x.split('G')[-1].split('_')[0]), #Then Patient ID
            int(x.split('G')[-1].split('_')[1]), #Then Time
        )
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parsing problems are tolerated; anything else should surface
        import zlib
        return (True,1,zlib.crc32(str(x).encode('utf-8')))
    
# Column used to split the samples into one chart per group
GROUPING = 'timephase_and_cluster'
# Column whose values define the bar order (or a list for a manual order)
SORTING = 'Sample ID'

NORMALIZE = False

#################

os.makedirs('Output/Composition',exist_ok=True)

input_table = get_normalized_abundances(
    kraken_dataframe,
    level=LEVEL,
    samples=[s.replace('/','_') for s in SAMPLES], #New sample format uses underscore '_' instead of '/'
    included_taxa_filter=INCLUDE,
    excluded_taxa_filter=EXCLUDE,
    normalize=NORMALIZE
)

# NOTE(review): level is the literal 'S' here while the fetch above uses LEVEL - confirm they should always agree
with_validation = apply_validation(
    input_table,
    level = 'S',
    validation_cutoff = 0.2,
    readcount_cutoff = 5
)
   

#Calculate Read Fractions
# (read count of a row divided by the summed read count of its sample)
with_validation['Read Fraction'] = with_validation['readcount']/with_validation.groupby('samplename')['readcount'].transform('sum')

# Share of each sample's classified reads that is covered by this table
total_reads = with_validation.groupby('samplename',as_index=False)['readcount'].sum()
sample_statistics['samplename']= sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)
total_reads = pd.merge(total_reads,sample_statistics[['samplename','Classified Reads']],on='samplename',how='left').fillna(0)
total_reads['Total Fraction'] = total_reads['readcount']/total_reads['Classified Reads']
total_reads['Sample ID'] = total_reads['samplename']

# determine the top taxa based on means
# (implemented as the summed read fraction over all samples; TOP_X+2 entries are taken)
taxa_we_look_at = list(with_validation.groupby('taxon')['Read Fraction'].sum().sort_values(ascending=False)[:(TOP_X+2)].keys())
if 'Not enough reads' not in taxa_we_look_at:
    taxa_we_look_at.append('Not enough reads')
if 'Not validated' not in taxa_we_look_at:
    taxa_we_look_at.append('Not validated')
print('Determined the following taxa as relevant:',taxa_we_look_at)

#assign everything else to the "other" group and readjust sum
with_validation.loc[~with_validation['taxon'].isin(taxa_we_look_at), 'taxon'] = 'Other'
# NOTE(review): .sum() is applied to every remaining column, not only the numeric ones - confirm this is intended
with_validation = with_validation.groupby(['taxon','samplename'],as_index=False).sum()


with_validation['other'] = with_validation['taxon'] == 'Other'

with_validation = with_validation.rename(columns={
    'samplename' : 'Sample ID',
    'taxon' : 'Taxon'
})

colorMap = {}

# Built BEFORE the two pseudo-taxa are removed below, so `taxa` still contains them
taxa = taxa_we_look_at+['Other']

# Interleaved palette: every second colour first, then the remaining ones
palette = cc.glasbey_light
bright = palette[::2]
muted = palette[1::2]
palette = bright+muted

# This is an alias (no copy): the removals below also change taxa_we_look_at
taxa_we_look_at_assigned = taxa_we_look_at

taxa_we_look_at_assigned.remove('Not validated')
taxa_we_look_at_assigned.remove('Not enough reads')

for tax,col in zip(taxa_we_look_at_assigned,palette):
    colorMap[tax] = col #colors.to_hex(col)
    
# Colour scale for Altair: domain (taxon names) and range (hex colours) in matching order
altdomain = []
altrange = []

for x in taxa_we_look_at_assigned:

    c = colorMap[x]
    altdomain.append(x)
    color = colors.to_hex(c)
    altrange.append(color)
    
# Fixed colours for the three pseudo-taxa: white, dark grey and mid grey
altdomain.append('Other')
altrange.append(colors.to_hex((1,1,1)))
altdomain.append('Not enough reads')
altrange.append(colors.to_hex((0.15,0.15,0.15)))
altdomain.append('Not validated')
altrange.append(colors.to_hex((0.55,0.55,0.55)))

# Attach the sample annotations (grouping column, ARG counts, ...)
with_validation = pd.merge(
    with_validation,
    sample_statistics,
    how='left',
    left_on=['Sample ID'],
    right_on=['samplename'])
# Samples without a group label are collected in 'Unknown Group'
with_validation[GROUPING]=with_validation[GROUPING].fillna('Unknown Group')

# NOTE(review): maxfraction is not used in the rest of this cell
maxfraction = total_reads['Total Fraction'].max()

# NOTE(review): a bare expression in the middle of a cell is not displayed
colorMap

charts = []

# Chart titles per group; list values are multi-line titles.
# NOTE(review): the leukozytopenia_* and reconstitution_* groups carry 'Pre-Tx'
# titles - confirm that this labelling is intended.
title_mapping = {
    '1_Gesund' : 'Control',
    'pre_1' : 'Cluster 1',
    'pre_2' : 'Cluster 2',
    'pre_3' : 'Cluster 3',
    'leukozytopenia_1' : ['Pre-Tx','Cluster 1'],
    'leukozytopenia_2' : ['Pre-Tx','Cluster 2'],
    'leukozytopenia_3' : ['Pre-Tx','Cluster 3'],
    'reconstitution_1' : ['Pre-Tx','Cluster 1'],
    'reconstitution_2' : ['Pre-Tx','Cluster 2'],
    'reconstitution_3' : ['Pre-Tx','Cluster 3']

}

groups = with_validation[GROUPING].unique() if GROUPING is not None else [None]
print('Order of charts:')

for group in groups:
    
    print(group)
    if group is not None:
        #Restrict the table to the samples of this group
        grouptable = with_validation[with_validation[GROUPING] == group]
    else:
        grouptable = with_validation
    
    if isinstance(SORTING,list):#If we provide a list (manual sorting) use this
        patientlist_sorted = SORTING
    else:
        patientlist_sorted = sorted(
            grouptable[SORTING].unique().tolist(),
            key=SORTFUNCTION
        )

    # Groups without an entry in title_mapping (e.g. 'Unknown Group') are titled
    # with their own name instead of raising a KeyError.
    chart_options = {}
    if group is not None:
        chart_options['title'] = title_mapping.get(group,str(group))

    # The chart is built for BOTH sorting modes (manual list and sort function).
    # Upper panel: stacked taxon composition, lower panel: ARG reads per sample.
    chart = alt.Chart(
        grouptable,**chart_options
    ).transform_calculate(
        # Stack order follows the position of the taxon in the colour domain
        order=f"-indexof({altdomain}, datum.Taxon)"
    ).mark_bar(stroke='black',strokeWidth=0.5,strokeOpacity=0.9).encode(
        x=alt.X('Sample ID:N',sort=patientlist_sorted, axis=alt.Axis(labels=False),title=None),
        y=alt.Y('Read Fraction:Q',scale=alt.Scale(
            domain=(0,1)),title=['Estimated', 'Abundance']
               ),
        color=alt.Color('Taxon:N',
                        legend=alt.Legend(columns=1,symbolLimit=0,labelLimit=0,rowPadding=15),
                        sort=taxa,
                        scale=alt.Scale(domain=altdomain,range=altrange)),
        tooltip=['readcount','Read Fraction','Taxon'],
        order=alt.Order('order:Q')
    )& alt.Chart(grouptable).mark_bar(stroke='black',strokeWidth=0.5,strokeOpacity=0.9).encode(
        x=alt.X('Sample ID:N',sort=patientlist_sorted,title=None),
        y=alt.Y('ARG Reads per 10000 Reads',scale=alt.Scale(domain=[0,350]),title=['ARG-carrying Reads','Per 10,000 Reads'])
    )

    charts.append(chart)

# Stack the per-group charts vertically and apply the shared font sizes
chart = reduce(lambda x,y : x&y, charts).configure_axis(
    labelFontSize=16, titleFontSize=16
).configure_title(fontSize=20).configure_legend(titleFontSize=20, labelFontSize=16)

chart.save(
    'Output/Composition/Barplots-Top_{}-{}-Only_{}-Without_{}-{}_{}_{}.html'.format(
        TOP_X,
        LEVEL,
        INCLUDE,
        str(EXCLUDE),
        hash(str(SAMPLES)),
        hash(str(SORTING)) if isinstance(SORTING,list) else SORTING,
        GROUPING
    )
)

chart

Section "Pre-Tx alloHSCT microbiomes could be grouped into 3 distinct clusters":
Sort by timephase and cluster and check which abundance the dominating species reach

# Presence above threshold in sample groups

In [None]:
INCLUDE = None
EXCLUDE = [CHORDATA,PLANTAE]

# Join key '<PatID>_<time>' for attaching the sample annotations
sample_statistics['samplename'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

#Refetch data (no other)
#New sample format uses underscore '_' instead of '/'
underscore_samples = [s.replace('/','_') for s in SAMPLES]

input_table = get_normalized_abundances(
    kraken_dataframe,
    level=LEVEL,
    samples=underscore_samples,
    included_taxa_filter=INCLUDE,
    excluded_taxa_filter=EXCLUDE,
    normalize=NORMALIZE
)

with_validation = apply_validation(
    input_table,
    level='S',
    validation_cutoff=0.2,
    readcount_cutoff=5
)

#Calculate Read Fractions
reads_per_sample = with_validation.groupby('samplename')['readcount'].transform('sum')
with_validation['Read Fraction'] = with_validation['readcount']/reads_per_sample

# Attach the sample annotations to every abundance row
with_validation = pd.merge(
    with_validation,
    sample_statistics,
    how='left',
    on='samplename'
)

In [None]:
def _mean_abundance_table(group_column, output_file):
    """Mean read fraction per taxon (rows) and group (columns).

    A 'Sum' column over all groups is appended, the rows are sorted by it in
    descending order and the table is written to output_file.
    """
    means = with_validation.groupby(['taxon',group_column])['Read Fraction'].mean().reset_index()
    wide = means.pivot(index='taxon',columns=group_column,values='Read Fraction')
    wide['Sum'] = wide.sum(axis=1)
    wide = wide.sort_values(by='Sum',ascending=False)
    wide.to_csv(output_file)
    return wide


t1 = _mean_abundance_table('timephase_and_cluster','Output/meanabundances_timephase_and_cluster.csv')

t2 = _mean_abundance_table('timephase','Output/meanabundances_timephase.csv')

t3 = _mean_abundance_table('Startcluster','Output/meanabundances_Startcluster.csv')

In [None]:
tuples = []

species_of_interest = ['Enterococcus faecium']
thresholds = [0.05,0.1,0.2,0.25,0.5]


def _presence_records(label, group_table):
    """One record per species and threshold for the samples in group_table.

    Each record is (label, species, threshold, rows of the species,
    rows above the threshold, their ratio or 'NaN' if there are no rows).
    """
    records = []
    for species in species_of_interest:
        sp_table = group_table[group_table['taxon'] == species]
        nr_of_samples = len(sp_table)
        for threshold in thresholds:
            nr_of_occurences = len(sp_table[sp_table['Read Fraction'] > threshold])
            if nr_of_samples != 0:
                fraction = nr_of_occurences/nr_of_samples
            else:
                fraction = 'NaN'
            records.append((label,species,threshold,nr_of_samples,nr_of_occurences,fraction))
    return records


# First per timephase and cluster, then per timephase, then per start cluster ...
for grouping_column in ['timephase_and_cluster','timephase','Startcluster']:
    for label in with_validation[grouping_column].unique():
        tc_table = with_validation[with_validation[grouping_column] == label]
        tuples.extend(_presence_records(label,tc_table))

# ... and finally over all samples
tc_table = with_validation
tuples.extend(_presence_records('Full',tc_table))

pd.DataFrame(tuples,columns=[
    'Timephase and Cluster','Species','Threshold','Total Samples','Samples With Presence','Fraction'
]).to_csv('Output/PresenceInGroups.csv',index=False)

Save as a table

In [None]:
def SORTFUNCTION(x):
    """Sort key for sample names of the form '[G]<patient>_<time>'.

    Returns a tuple (starts with 'G', patient id as float, time as int).
    Names that cannot be parsed get the key (True, 1, <checksum of the name>),
    so they still sort in a stable order that is identical in every kernel
    session (the built-in hash() of strings changes between sessions).
    """
    try:
        return (
            x[0] == 'G', #Sort by G or regular patient first
            float(x.split('G')[-1].split('_')[0]), #Then Patient ID
            int(x.split('G')[-1].split('_')[1]), #Then Time
        )
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parsing problems are tolerated; anything else should surface
        import zlib
        return (True,1,zlib.crc32(str(x).encode('utf-8')))

# Matrix of read fractions: taxa as rows, samples as columns, absent taxa as 0
abundance_matrix = with_validation.pivot(
    index=['samplename'],
    columns='taxon',
    values='Read Fraction'
).fillna(0).transpose()

# Sample columns in patient / time order
sorted_columns = sorted(abundance_matrix.columns,key=SORTFUNCTION)

# Rows ordered by their total over all samples (the helper column is dropped again)
with_validation_pivot = (
    abundance_matrix
    .assign(Sum=abundance_matrix.sum(axis=1))
    .sort_values(by='Sum',ascending=False)
    [sorted_columns]
)
with_validation_pivot.to_csv('Output/BarplotsCompositionValues.csv')

check correlation with ARG

In [None]:
# One row per sample (with its ARG read count) and one column per taxon
arg_corr_table = with_validation.pivot(
    index=['samplename','ARG Reads per 10000 Reads'],
    columns='taxon',
    values='Read Fraction'
)
arg_corr_table = arg_corr_table.fillna(0).reset_index()
arg_corr_table

In [None]:
# Pearson correlation (with p-value) of every column except the sample name
# against the ARG read count. The ARG column itself is part of the table and
# therefore shows up with a correlation of 1.
corr_tuples = []
for column in arg_corr_table.columns:
    if column != 'samplename':
        corr,pval = pearsonr(arg_corr_table[column],arg_corr_table['ARG Reads per 10000 Reads'])
        corr_tuples.append((column,corr,pval))
        
pd.DataFrame(corr_tuples,columns=['Correlate','Pearson Correlation','P-Value']).to_csv('pearsonCorrARGWithPVal.csv')

Additional analysis: We check for Pre-TX samples in cluster 2 the highest abundance per taxon and subtract the highest abundance per taxon for the other patient clusters to identify unique features

In [None]:
# with_validation as rebuilt in the "Presence above threshold" section keeps the
# lowercase 'taxon' column (only the barplot cell renames it to 'Taxon').
c13 = with_validation[with_validation['timephase_and_cluster'].isin(['pre_1','pre_3'])]
c2 = with_validation[with_validation['timephase_and_cluster'].isin(['pre_2'])]
# Highest abundance per taxon in cluster 2 minus the highest abundance in clusters 1 and 3
c2.groupby('taxon')['Read Fraction'].max()-c13.groupby('taxon')['Read Fraction'].max()

We track for all taxa that have an abundance over 1/4 in how many samples they occur in cluster 2

In [None]:
# Number of cluster 2 samples in which a taxon exceeds a read fraction of 0.25
# (the taxon column of the refetched table is the lowercase 'taxon')
c2[c2['Read Fraction'] > 0.25].groupby('taxon')['samplename'].count()

Additionally, the maximum read fractions are displayed

In [None]:
# Maximum read fraction per taxon within cluster 2
# (the taxon column of the refetched table is the lowercase 'taxon')
c2.groupby('taxon')['Read Fraction'].max()

We also want to know how many species are needed to explain 50% of the reads

In [None]:
# For every sample: the number of most abundant species needed to reach a
# cumulative read fraction of at least 0.5. Samples that never reach 0.5 with
# real taxa are left out.
pseudo_taxa = ['Other','Not validated','Not enough reads']

div_tuples = []

for timephase_and_cluster in with_validation['timephase_and_cluster'].unique():
    subtable = with_validation[with_validation['timephase_and_cluster'] == timephase_and_cluster]
    for sample in subtable['samplename'].unique():
        sampletable = subtable[subtable['samplename']==sample]
        fractions = sampletable.loc[
            ~sampletable['taxon'].isin(pseudo_taxa),'Read Fraction'
        ].sort_values(ascending=False)
        reached = (fractions.cumsum() >= 0.5).to_numpy()
        if reached.any():
            # argmax returns the first position at which the 50% mark is reached
            div_tuples.append((timephase_and_cluster,sample,int(reached.argmax())+1))

# Only the numeric 'count' column is aggregated; the median of the sample names is undefined
pd.DataFrame(
    div_tuples,columns=['timephase_and_cluster','samplename','count']
).groupby('timephase_and_cluster')[['count']].median()
    
    

### Leukopenia-Loss
We analyze here how much of each taxon is lost comparing leukozytopenia -> pre tx

In [None]:
# The relabeling cell renames the phases in sample_statistics in place
# ('Leukozytopenia' -> 'Leukopenia', 'Pre TX' -> 'Pre-Tx'), so both spellings
# are accepted here. The taxon column of the refetched table is 'taxon'.
leukopenia_labels = ['Leukozytopenia','Leukopenia']
pre_tx_labels = ['Pre TX','Pre-Tx']


def _mean_fraction_per_taxon(phase_labels):
    """Summed read fraction per taxon divided by the number of samples of the phase."""
    phase_table = with_validation[with_validation['timephase'].isin(phase_labels)]
    sample_count = len(phase_table['samplename'].unique())
    return phase_table.groupby('taxon')['Read Fraction'].sum()/sample_count


# Leukopenia minus Pre-Tx: negative values mark taxa that lost abundance
(
    _mean_fraction_per_taxon(leukopenia_labels)-_mean_fraction_per_taxon(pre_tx_labels)
).sort_values(ascending=False)

### Max-Diffs

In [None]:
# For every ordered pair of time phases (a, b): difference of the maximum read
# fraction per taxon (a minus b) together with the number of samples of phase a
# in which the taxon occurs at all / above 25%.
# The taxon column of the refetched table is the lowercase 'taxon'.
taxon_column = 'taxon'

for a,b in it.permutations(with_validation['timephase'].unique(),2):
    ca = with_validation[with_validation['timephase']==a]
    cb = with_validation[with_validation['timephase']==b]
    diffs = (
        ca.groupby(taxon_column)['Read Fraction'].max()-cb.groupby(taxon_column)['Read Fraction'].max()
    ).sort_values(ascending=True)
    above_column = 'Above 25% {}'.format(a)
    combined = pd.concat([
        ca[ca['Read Fraction'] > 0.25].groupby(taxon_column)['samplename'].count().rename(above_column),
        ca.groupby(taxon_column)['samplename'].count().rename('Any% {}'.format(a))
    ],axis=1).fillna(0).reset_index()
    # Taxa that never exceed 25% were filled with 0 above; restore integer counts
    combined[above_column] = combined[above_column].astype(int)
    combined = pd.merge(diffs,combined,how='outer',on=taxon_column)
    combined = combined.rename(columns={'Read Fraction':'Diff Max {} Minus {}'.format(a,b)})
    combined.to_excel('{}_Minus_{}.xlsx'.format(a,b))

In [None]:
# Number of samples per leukocyte phase / cluster label
sample_statistics.groupby('leukocytephase_cluster_2_kurz')[['samplename']].count()

### Domination

In [None]:
# A sample counts as "dominated" when its most abundant taxon exceeds a read
# fraction of 0.5. The grouping column has to be carried through the per-sample
# aggregation, otherwise it is not available for the summary below.
# NOTE(review): falls back to 'timephase' when the annotation table has no
# 'leukocytephase_cluster_kurz' column - confirm the intended grouping.
domination_group = (
    'leukocytephase_cluster_kurz'
    if 'leukocytephase_cluster_kurz' in with_validation.columns
    else 'timephase'
)

dom_1 = with_validation.groupby(['samplename',domination_group],as_index=False)['Read Fraction'].max()

samples_per_group = dom_1.groupby(domination_group)['samplename'].count().rename('n')
# Groups without any dominated sample get a count of 0 instead of a missing value
dominated_per_group = dom_1[dom_1['Read Fraction'] > 0.5].groupby(
    domination_group
)['samplename'].count().reindex(samples_per_group.index,fill_value=0)

pd.concat([
    samples_per_group,
    (dominated_per_group/samples_per_group).rename('Fraction dominated')
],axis=1)

### Occurrence of Species across Time

In [None]:
# Configuration of the presence analysis.
# NOTE(review): TOP_X is not referenced in the rest of this cell
TOP_X = 20
# Taxonomic level used for fetching and validating the abundances
LEVEL = 'S'

# Cutoffs passed to apply_validation
READCOUNT_CUTOFF = 5
VALIDATION_CUTOFF = 0.2
# Minimum read fraction for a taxon to count as present in a sample
PRESENCE_CUTOFF = 0.05

#Filters
# Each entry is (included taxon, [excluded taxa]); one output file is written per entry
GROUPS = [
    (BACTERIA,[]),
    (FUNGI,[]),
    (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    (VIRUSES,[]),
    (ARCHAEA,[]),
    (PLANTAE,[]),
    (CHORDATA,[]),
    # presumably '1' is the taxonomy root id - TODO confirm
    ('1',[CHORDATA,PLANTAE])
]

SAMPLES = PAPER_SAMPLES

def SORTFUNCTION(x):
    """Sort key for sample names of the form '[G]<patient>/<time>'.

    Returns (starts with 'G', patient id as float, time as int). Unlike the
    underscore based variants there is no fallback: names that do not follow
    the pattern raise an exception.
    """
    #return x
    starts_with_g = x[0] == 'G' #Sort by G or regular patient first
    fields = x.split('G')[-1].split('/')
    patient_id = float(fields[0]) #Then Patient ID
    sample_time = int(fields[1]) #Then Time
    return (starts_with_g, patient_id, sample_time)

GROUPING = 'leukocytephase_cluster_kurz'
SORTING = 'samplename'
NORMALIZE = False
#################

os.makedirs('Output/Presence',exist_ok=True)

#################

# One Excel table per taxon filter: the fraction of samples (per time phase, per
# start cluster and overall) in which a taxon reaches PRESENCE_CUTOFF.
for INCLUDE,EXCLUDE in GROUPS:

    input_table = apply_validation(
        get_normalized_abundances(
            kraken_dataframe,
            level=LEVEL,
            samples=[s.replace('/','_') for s in SAMPLES],
            included_taxa_filter=INCLUDE,
            excluded_taxa_filter=EXCLUDE,
            normalize=NORMALIZE
        ),
        level=LEVEL,
        validation_cutoff=VALIDATION_CUTOFF,
        readcount_cutoff=READCOUNT_CUTOFF
    )

    reads_per_sample = input_table.groupby(['samplename'])['readcount'].transform('sum')
    input_table['Read Fraction'] = input_table['readcount'] / reads_per_sample

    sample_statistics['samplename']= sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

    merged = pd.merge(input_table,sample_statistics,how='left',on='samplename')

    # Keep the rows in which a taxon is present, without the two pseudo-taxa
    is_present = merged['Read Fraction'] >= PRESENCE_CUTOFF
    subset_taxa = merged[is_present]['taxon'].unique()
    subset = merged[merged['taxon'].isin(subset_taxa) & is_present]
    subset = subset[~subset['taxon'].isin(['Not enough reads','Not validated'])]

    charts = []

    # First all time phases, then all start clusters
    for annotation_column in ('timephase','Startcluster'):
        for group in subset[annotation_column].unique():
            group_size = len(sample_statistics[sample_statistics[annotation_column] == group])
            present_samples = subset[subset[annotation_column]==group].groupby(['taxon'])['samplename'].count()
            grouptable = (present_samples/group_size).rename(
                'Fraction {} (n={})'.format(group,group_size)
            )
            charts.append(grouptable)

    # NOTE(review): the overall sample count is the literal 112 - confirm it matches the sample set
    totaltable = (subset.groupby(['taxon'])['samplename'].count()/112).rename('Fraction Total (n=112)')
    charts.append(totaltable)

    presence_table = pd.concat(charts,axis=1).sort_values(by='Fraction Total (n=112)',ascending=False)
    presence_table.to_excel(
        'Output/Presence_Above_{}_{}_WITHOUT_{}.xlsx'.format(
            PRESENCE_CUTOFF,
            idtonames[INCLUDE],
            list(idtonames[x] for x in EXCLUDE)
        )
    )

## Overview all domains/groups

In [None]:
TOP_X = 30
LEVEL = 'S'
DISCARD_CUTOFF = 5
# Display name -> (included taxon, [excluded taxa])
DOMAINS = {
    'Bacteriome':(BACTERIA,[]),
    'Mycobiome':(FUNGI,[]),
    'Archaeome':(ARCHAEA,[]),
    'DNA-Virome':(VIRUSES,[]),
    'Non-Fungal Eukaryome':(EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
}

SAMPLES = PAPER_SAMPLES

NORMALIZE = False

#################

os.makedirs('Output/Composition',exist_ok=True)
charts = {}
raw_datasets = {}

# Fetch and validate the abundances once per domain and keep them for the following cells
for domain,(INCLUDE,EXCLUDE) in DOMAINS.items():

    print(domain)

    #New sample format uses underscore '_' instead of '/'
    domain_samples = [s.replace('/','_') for s in SAMPLES]

    input_table = get_normalized_abundances(
        kraken_dataframe,
        level=LEVEL,
        samples=domain_samples,
        included_taxa_filter=INCLUDE,
        excluded_taxa_filter=EXCLUDE,
        normalize=NORMALIZE
    )

    with_validation = apply_validation(
        input_table,
        level='S',
        validation_cutoff=0.2,
        readcount_cutoff=5
    )

    raw_datasets[domain] = with_validation

In [None]:
# save the tables
def _underscore_sample_key(samplename):
    """Sort key for sample names of the form '[G]<patient>_<time>'.

    The tables in raw_datasets use underscore separated sample names, whereas
    the most recently defined SORTFUNCTION splits on '/' and has no fallback,
    so a dedicated key is used here. Names that cannot be parsed are given the
    key (True, 1, <checksum of the name>).
    """
    try:
        fields = samplename.split('G')[-1].split('_')
        return (samplename[0] == 'G', float(fields[0]), int(fields[1]))
    except (AttributeError, IndexError, TypeError, ValueError):
        import zlib
        return (True,1,zlib.crc32(str(samplename).encode('utf-8')))


for domain in raw_datasets:
    st = raw_datasets[domain]
    # The read fraction is stored on the table kept in raw_datasets
    st['Read Fraction'] = st['readcount']/st.groupby('samplename')['readcount'].transform('sum')
    
    # Taxa as rows, samples as columns, absent taxa as 0
    stp = st.pivot(index=['samplename'],columns='taxon',values='Read Fraction')

    stp=stp.fillna(0).transpose()
    sorted_columns = sorted(stp.columns,key=_underscore_sample_key)

    # Rows ordered by their total over all samples (the helper column is dropped again)
    stp['Sum'] = stp.sum(axis=1)
    stp = stp.sort_values(by='Sum',ascending=False)

    stp = stp[sorted_columns]
    stp.to_csv('Output/BarplotsFigure7Compositions_{}.csv'.format(domain))


In [None]:
def SORTFUNCTION_SLASH(x):
    """Sort key for sample IDs written as '<patient>/<time>'.

    Returns a tuple (starts_with_G, patient_number, time) so that IDs not
    starting with 'G' sort before those that do, then by the numeric patient
    ID (the part after the last 'G'), then by the integer time.

    IDs that cannot be parsed fall back to (True, 1, hash(x)). Only parsing
    failures are caught; KeyboardInterrupt / SystemExit are no longer
    swallowed as they were by the previous bare `except:`.
    """
    try:
        return (
            x[0] == 'G', #Sort by G or regular patient first
            float(x.split('G')[-1].split('/')[0]), #Then Patient ID
            int(x.split('G')[-1].split('/')[1]), #Then Time
        )
    except (AttributeError, IndexError, TypeError, ValueError):
        # NOTE: hash() of a str differs between interpreter runs unless
        # PYTHONHASHSEED is fixed, so the relative order of unparsable IDs
        # is not reproducible across kernels.
        return (True,1,hash(x))
    
SORTING = 'samplename'

# --- Loop-invariant setup (previously recomputed for every domain) -----------
sample_statistics['samplename']= sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

maxfraction = total_reads['Total Fraction'].max()

# Values of 'timephase_and_cluster' for which one chart per domain is built.
groups = ['1_Gesund','pre_1','pre_2','pre_3',
          'reconstitution_1','reconstitution_2' ,'reconstitution_3' ,
          'leukozytopenia_1', 'leukozytopenia_2','leukozytopenia_3']

# Interleave the glasbey palette: every second colour first, then the rest.
palette = cc.glasbey_light
bright = palette[::2]
muted = palette[1::2]
palette = bright+muted

# Pseudo-taxa that always get a fixed grey-scale colour at the end of the legend.
SPECIAL_TAXA_COLORS = [
    ('Other',colors.to_hex((1,1,1))),
    ('Not enough reads',colors.to_hex((0.15,0.15,0.15))),
    ('Not validated',colors.to_hex((0.55,0.55,0.55))),
]

for domain in DOMAINS:

    with_validation = raw_datasets[domain].copy()

    #Calculate Read Fractions
    with_validation['Read Fraction'] = with_validation['readcount']/with_validation.groupby('samplename')['readcount'].transform('sum')

    # determine the top taxa based on the read fraction summed over all samples
    taxa_we_look_at = list(with_validation.groupby('taxon')['Read Fraction'].sum().sort_values(ascending=False)[:(TOP_X+2)].keys())
    for special in ('Not enough reads','Not validated'):
        if special not in taxa_we_look_at:
            taxa_we_look_at.append(special)
    print('Determined the following taxa as relevant:',taxa_we_look_at)

    #assign everything else to the "other" group and readjust sum
    with_validation.loc[~with_validation['taxon'].isin(taxa_we_look_at), 'taxon'] = 'Other'
    with_validation = with_validation.groupby(['taxon','samplename'],as_index=False).sum()

    with_validation['other'] = with_validation['taxon'] == 'Other'

    with_validation['Sample ID'] = with_validation['samplename'].str.replace('_','/')

    with_validation = with_validation.rename(columns={
        'taxon' : 'Taxon'
    })

    taxa = taxa_we_look_at+['Other']

    # Real taxa only. Built as a NEW list: the previous code aliased
    # taxa_we_look_at and then removed entries from it in place.
    taxa_we_look_at_assigned = [
        t for t in taxa_we_look_at if t not in ('Not validated','Not enough reads')
    ]

    colorMap = dict(zip(taxa_we_look_at_assigned,palette))

    # Colour scale: real taxa in abundance order, then the fixed pseudo-taxa.
    altdomain = list(taxa_we_look_at_assigned)
    altrange = [colors.to_hex(colorMap[t]) for t in taxa_we_look_at_assigned]
    for special,special_color in SPECIAL_TAXA_COLORS:
        altdomain.append(special)
        altrange.append(special_color)

    with_validation = pd.merge(
        with_validation,
        sample_statistics,
        how='left',
        left_on=['samplename'],
        right_on=['samplename'])

    for group in groups:

        # `group` is always a string here, so the former `group != None`
        # fallback to the unfiltered table could never run and was removed.
        grouptable = with_validation[with_validation['timephase_and_cluster'] == group]

        patientlist_sorted = sorted(
            grouptable['Sample ID'].unique().tolist(),
            key=SORTFUNCTION_SLASH
        )

        # Only the '1_Gesund' chart keeps its y axis; all others hide it.
        y_axis_options = {} if group == '1_Gesund' else {'axis': None}

        chart =  alt.Chart(
            grouptable
        ).transform_calculate(
            # stack order follows the position of the taxon in the colour domain
            order=f"-indexof({altdomain}, datum.Taxon)"
        ).mark_bar(stroke='black',strokeWidth=0.5,strokeOpacity=0.9).encode(
            x=alt.X('Sample ID:N',sort=patientlist_sorted,title=None),
            y=alt.Y('Read Fraction:Q',scale=alt.Scale(domain=(0,1)),title=None,**y_axis_options),
            color=alt.Color('Taxon:N',
                            legend=alt.Legend(symbolLimit=0,labelLimit=0,labelFontSize=17,columnPadding=12),
                            sort=taxa,
                            scale=alt.Scale(domain=altdomain,range=altrange)),
            tooltip=['readcount','Read Fraction','Taxon'],
            order=alt.Order('order:Q')
        )

        charts[(group,domain)] = chart

In [None]:
groups_sorted = ['1_Gesund','pre_1','pre_2','pre_3','leukozytopenia_1', 'leukozytopenia_2','leukozytopenia_3',
              'reconstitution_1','reconstitution_2' ,'reconstitution_3' ,
]

# One row per domain; inside a row the time phases sit side by side with extra
# spacing, and each time phase concatenates its cluster charts horizontally.
charts_domain = []
for domain in DOMAINS:
    charts_timephase = []
    for timephase in ['1','pre','leukozytopenia','reconstitution']:
        if timephase == '1':
            groups = ['1_Gesund']
        else:
            groups = [timephase+'_'+cluster for cluster in ['1','2','3']]
        charts_group = [charts[(group,domain)] for group in groups]
        charts_timephase.append(reduce(lambda left,right : left|right,charts_group))
    charts_domain.append(
        reduce(
            lambda left,right : alt.hconcat(left,right,spacing=60), charts_timephase
        ).properties(title=domain)
    )

# Stack the domain rows vertically, give every row its own colour legend and
# write the figure to disk.
figure7 = reduce(
    lambda upper,lower : upper&lower, charts_domain
).resolve_scale(color='independent').configure_legend(
    orient='none',
    title=None,
    direction='horizontal',
    titleAlign='center',
    columns=8,
    titleAnchor='middle',
    legendX=430,
    legendY=400,
    labelFontSize=14,
    symbolSize=140
).configure_title(
    fontSize=24, offset=5, orient='top', anchor='middle'
).configure_axis(labelFontSize=14)

figure7.save('figure7.html')

## Bray-Curtis Distance based PCoA <a class="anchor" id="pcoa"></a>
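$$
B_{i,j} = 1 - \frac{2C_{i,j}}{S_i+S_j}
$$

where $C_{i,j}$ is the sum, over all taxa, of the smaller of the two abundances observed in samples $i$ and $j$, and $S_i$ and $S_j$ are the total abundances of the two samples.

For relative abundances ($S_i = S_j = 1$), this becomes:

$$
B_{i,j} = 1-C_{i,j}
$$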

## PCoA

In [None]:
import scipy.spatial.distance as ssd
from skbio.stats import ordination
import scipy.cluster.hierarchy as sch
import matplotlib.pyplot as plt

################################

# --- PCoA configuration ------------------------------------------------------
LEVEL = 'S'
READ_CUTOFF = 5
VALIDATION_CUTOFF = 0.2
SAMPLES = PAPER_SAMPLES
EXCLUDED = [CHORDATA,PLANTAE]
INCLUDED = None
NORMALIZE = False
INVERT_X = False
INVERT_Y = False

#################################

# --- Build the sample x taxon matrix of read fractions -----------------------
underscore_samples = [name.replace('/','_') for name in SAMPLES]

subtable = apply_validation(
    get_normalized_abundances(
        kraken_dataframe,
        level=LEVEL,
        samples=underscore_samples,
        excluded_taxa_filter=EXCLUDED,
        included_taxa_filter=INCLUDED,
        normalize=NORMALIZE
    ),
    level=LEVEL,
    readcount_cutoff=READ_CUTOFF,
    validation_cutoff=VALIDATION_CUTOFF
)

# Fraction of each sample's reads assigned to each row.
per_sample_reads = subtable.groupby('samplename')['readcount'].transform('sum')
subtable['Read Fraction'] = subtable['readcount']/per_sample_reads

# One row per sample, one column per taxon id; absent taxa count as 0.
pivot_table = subtable.pivot(
    index=['samplename'],
    columns=['taxonid'],
    values=['Read Fraction']
).fillna(0)

data = pivot_table.values
samples = pivot_table.index.tolist()
taxa_count = len(data[0])

def braycurtis(indexA,indexB):
    """Bray-Curtis distance between two rows of the module-level `data` matrix.

    Computed as 1 - C, where C is the sum over all columns of the element-wise
    minimum of the two rows. This is the Bray-Curtis distance only when both
    rows sum to 1 (relative abundances).

    The minimum and the sum are done by NumPy in one vectorised step instead
    of a Python loop over every taxon, which matters because this function is
    called once per sample pair.

    Parameters: indexA, indexB -- row indices into `data`.
    Returns: the distance 1 - C.
    """
    cij = np.minimum(data[indexA],data[indexB]).sum()
    # Diagnostic output: a distance above 1 means C was negative.
    if (1-cij) > 1:
        print(1-cij)
    return 1-cij

# Pairwise Bray-Curtis distances: a symmetric matrix for the PCoA plus a long
# list of (sample X, sample Y, distance) tuples for the pair-wise analyses.
bc_tuples = []

distance_matrix = np.zeros(shape=(len(samples),len(samples)))

# `samples` holds plain strings of the form '<patient>_<time>'. The previous
# code indexed them as samples[x][0] / samples[x][1], which picks the first two
# CHARACTERS of the name rather than patient and time, so every later
# split('/') produced wrong IDs. Convert the last underscore back to '/' (the
# time part may be negative, the patient part is everything before it).
sample_ids = ['/'.join(name.rsplit('_',1)) for name in samples]

for x in range(len(samples)):
    # progress output once per row instead of once per pair
    print('Calculating all distancesfor sample {} / {}'.format(x+1,len(samples)),end='\r')
    for y in range(x+1,len(samples)):
        distance = braycurtis(x,y)
        distance_matrix[x][y] = distance
        distance_matrix[y][x] = distance
        bc_tuples.append((sample_ids[x],sample_ids[y],distance))

# Principal coordinates analysis on the distance matrix (first three axes).
PCoA = ordination.pcoa(distance_matrix,number_of_dimensions=3)

# The ordination rows are labelled by position; map them back to sample names.
tuples = [
    (coords.PC1,coords.PC2,coords.PC3,samples[int(position)])
    for position,coords in PCoA.samples.iterrows()
]

pcoa_table = pd.DataFrame(tuples,columns=['x','y','z','samplename'])
pcoa_table['Name'] = pcoa_table['samplename']

# Attach the per-sample metadata, keeping every ordinated sample.
sample_statistics['samplename']= sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)
pcoa_table = pd.merge(sample_statistics,pcoa_table,on='samplename',how='right')


def _sample_type(name):
    # Names starting with 'G' are labelled as controls, everything else as patient.
    if name.startswith('G'):
        return 'Control'
    return 'patient'


pcoa_table['Sample Type'] = pcoa_table['Name'].apply(_sample_type)

# Optional mirroring of the axes (orientation of PCoA axes is arbitrary).
for axis_name,invert in (('x',INVERT_X),('y',INVERT_Y)):
    if invert:
        pcoa_table[axis_name] = -pcoa_table[axis_name]

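Calculate a clustering based on the Bray-Curtis distances by running k-means on the first two PCoA axes. The silhouette score is computed for k = 2 to 6 on the pre-transplant and healthy samples.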

In [None]:
sil_tuples = []

# The subset does not depend on k, so select it once. `.copy()` makes it an
# independent frame; assigning a column to the bare filtered slice triggers
# pandas' SettingWithCopyWarning.
# NOTE: this filter uses the raw phase labels ('Pre TX', 'Healthy'). The chart
# cell below renames them in place, so this cell must run before it.
pretx_pcoa = pcoa_table[pcoa_table['timephase'].isin(['Pre TX','Healthy'])].copy()

for k in [2,3,4,5,6]:

    km= KMeans(n_clusters=k,random_state=0,n_init='auto').fit(pretx_pcoa[['x','y']])
    pretx_pcoa['KMeansCluster'] = km.predict(pretx_pcoa[['x','y']])+1 #+1 so cluster numbers run from 1 to k

    sil_tuples.append((k,silhouette_score(pretx_pcoa[['x','y']],pretx_pcoa['KMeansCluster'])))

# Silhouette value per candidate number of clusters.
alt.Chart(pd.DataFrame(sil_tuples,columns=['k','Silhouette Value'])).mark_bar().encode(
    x='k',y='Silhouette Value'
)

In [None]:
# Fixed colours and shapes for the visual groups, in the order of group_names.
group_names = ['Control','Pre-Tx Cluster 1','Pre-Tx Cluster 2','Pre-Tx Cluster 3','Samples from other phases']
group_colors = [
    '#00fdcf', #Phocaiecola Vulgatus
    '#b500ff', #Bacteroides Uniformis
    '#366962', #Enterobacter boltae
    '#d60000', #Enterococcus Faecium
    'lightgray'
]
# Same as group_colors, but with the controls greyed out.
group_colors_grey_control=[
    'lightgray',
    '#b500ff', #Bacteroides Uniformis
    '#366962', #Enterobacter boltae
    '#d60000', #Enterococcus Faecium
    'lightgray'
]
group_shapes = [
    'square',
    'circle',
    'circle',
    'circle',
    'diamond'
]

charts = []

# Translate the raw phase labels into the labels used in the figure.
phase_labels = {
    'Healthy' : 'Control',
    'Pre TX' : 'Pre-Tx',
    'Leukozytopenia' : 'Leukopenia',
}
pcoa_table['timephase'] = pcoa_table['timephase'].apply(
    lambda phase : phase_labels.get(phase,phase))


def _cluster_label(value):
    # Integer cluster numbers become 'Pre-Tx Cluster <n>', 'Gesund' becomes
    # 'Control'; anything else is passed through unchanged.
    if isinstance(value,int):
        return 'Pre-Tx Cluster {}'.format(value)
    if value == 'Gesund':
        return 'Control'
    return value


pcoa_table['Startcluster'] = pcoa_table['Startcluster'].apply(_cluster_label)


for timepoint in ['Pre-Tx','Leukopenia','Reconstitution']:

    group_column = 'Visual Group {}'.format(timepoint)

    # Samples of the selected phase and the controls keep their start cluster;
    # all remaining samples are pooled into one background group.
    pcoa_table[group_column] = pcoa_table.apply(
        lambda row, selected=timepoint : (
            row['Startcluster']
            if row['timephase'] in (selected,'Control')
            else 'Samples from other phases'
        ),
        axis=1
    )

    color_range = group_colors if timepoint == 'Pre-Tx' else group_colors_grey_control

    timepoint_pcoa = alt.Chart(pcoa_table).mark_point(filled=True).encode(
        x=alt.X('x:Q',title='PCoA Axis 1'),
        y=alt.Y('y:Q',title='PCoA Axis 2'),
        color=alt.Color(group_column,scale=alt.Scale(domain=group_names,range=color_range),title='Cluster'),
        shape=alt.Shape(group_column,scale=alt.Scale(domain=group_names,range=group_shapes)),
        tooltip=['Name',group_column]
    )

    charts.append(timepoint_pcoa.properties(title='PCoA based on {} Samples'.format(timepoint)))

# Stack the three charts vertically with independent legends.
reduce(
    lambda upper,lower : upper&lower,charts
).resolve_scale(
    color='independent',shape='independent'
).configure_axis(
    labelFontSize=15,titleFontSize=15,tickCount=7
).configure_title(fontSize=15)


In [None]:
# Long table of all sample pairs with patient and time split out of the IDs.
bc_frame = pd.DataFrame(bc_tuples,columns=['Sample X','Sample Y','BCD'])
bc_frame[['Patient X','Time X']] = bc_frame['Sample X'].str.split('/',expand=True)
bc_frame[['Patient Y','Time Y']] = bc_frame['Sample Y'].str.split('/',expand=True)

# Convert the times to integers BEFORE taking min/max. On the raw strings
# min()/max() compare lexicographically (e.g. '100' < '20').
bc_frame['Time X'] = bc_frame['Time X'].astype(int)
bc_frame['Time Y'] = bc_frame['Time Y'].astype(int)

#Determine most extreme timepoints for each patient

extreme_points = pd.concat(
    [bc_frame.groupby('Patient X')['Time X'].min().rename('Earliest Time'),
    bc_frame.groupby('Patient X')['Time X'].max().rename('Latest Time')],axis=1
).reset_index()

# Despite its name, `eliminated` holds the patients that are KEPT: those with
# at least one sample before (< 0) and one after (> 0) time 0.
eliminated = extreme_points[
    (extreme_points['Earliest Time'] < 0)&
    (extreme_points['Latest Time'] > 0)
]

print(
    'For the following patients no valid pre/post pair could be generated: {} (They will be EXCLUDED from analysis)'.format(

        set(extreme_points['Patient X']).difference(set(eliminated['Patient X']))
    )
)

# Keep within-patient pairs only.
# NOTE(review): the message above announces an exclusion, but no filter on
# `eliminated` is applied here (it was commented out in the original cell), so
# pairs of the listed patients remain in the table. Confirm which is intended.
bc_frame = bc_frame[
    (bc_frame['Patient X'] == bc_frame['Patient Y'])
]

eliminated = eliminated.set_index('Patient X')

# Start cluster of the first sample of each pair ...
bc_frame = pd.merge(
    bc_frame,
    sample_statistics[['PatID','time','Startcluster']],
    how='left',
    left_on=['Patient X','Time X'],
    right_on=['PatID','time']
)

# ... and leukocyte-phase cluster of the second sample.
bc_frame = pd.merge(
    bc_frame,
    sample_statistics[['PatID','time','leukocytephase_cluster_2_kurz']],
    how='left',
    left_on=['Patient Y','Time Y'],
    right_on=['PatID','time']
)[['Sample X','Sample Y','BCD','Startcluster','leukocytephase_cluster_2_kurz']]

In [None]:
# Box plot of the Bray-Curtis distances in bc_frame, one box per start cluster.
alt.Chart(bc_frame).mark_boxplot().encode(
    x='Startcluster:N',
    y='BCD'
)

In [None]:
# Write the pair table as CSV; the path is relative, i.e. the current working directory.
bc_frame.to_csv('BCPrePostPairs.csv')

In [None]:
# Variance of the Bray-Curtis distance within each start cluster.
bc_frame.groupby('Startcluster')['BCD'].var()

### Bray Curtis Pre/Reconst

In [None]:
# All sample pairs, annotated with patient, time and time phase on both sides.
bc_frame = pd.DataFrame(bc_tuples,columns=['Sample X','Sample Y','BCD'])
for side in ('X','Y'):
    bc_frame[['Patient '+side,'Time '+side]] = bc_frame['Sample '+side].str.split('/',expand=True)
for side in ('X','Y'):
    bc_frame['Time '+side] = bc_frame['Time '+side].astype(int)

# Look up the time phase of each side of the pair.
for side in ('X','Y'):
    bc_frame = pd.merge(
        bc_frame,
        sample_statistics[['PatID','time','timephase']],
        how='left',
        left_on=['Patient '+side,'Time '+side],
        right_on=['PatID','time']
    ).drop(columns=['PatID','time']).rename(columns={'timephase':'Phase '+side})

# Keep pairs whose first sample is pre-transplant and whose second sample is
# from the reconstitution phase. `.copy()` turns the selection into an
# independent frame so the column added below does not raise pandas'
# SettingWithCopyWarning.
# NOTE(review): every pair is stored once, in the order the samples appear in
# `samples`; a pair stored as (Reconstitution, Pre TX) is not matched by this
# filter. Confirm whether such between-patient pairs should be included.
bc_frame = bc_frame[
    (bc_frame['Phase X'] == 'Pre TX')&
    (bc_frame['Phase Y'] == 'Reconstitution')
].copy()

bc_frame['Within Patient'] = bc_frame['Patient X'] == bc_frame['Patient Y']
bc_frame

In [None]:
# Binned histogram of the Bray-Curtis distances, one facet per value of
# 'Within Patient' (pairs from the same patient vs. pairs from different patients).
alt.Chart(bc_frame).mark_bar().encode(
    x=alt.X('BCD',bin=True),
    y='count()'
).facet(column='Within Patient')

### "Nadir less similar to Pre-TX than Reconstitution" Hypothesis

In [None]:
# All sample pairs with patient / time parsed from the '<patient>/<time>' IDs.
bc_frame = pd.DataFrame(bc_tuples,columns=['Sample X','Sample Y','BCD'])
for side in ('X','Y'):
    bc_frame[['Patient '+side,'Time '+side]] = bc_frame['Sample '+side].str.split('/',expand=True)
for side in ('X','Y'):
    bc_frame['Time '+side] = bc_frame['Time '+side].astype(int)

# Annotate both sides of every pair with their time phase.
phase_lookup = sample_statistics[['PatID','time','timephase']]
for side in ('X','Y'):
    bc_frame = (
        pd.merge(
            bc_frame,
            phase_lookup,
            how='left',
            left_on=['Patient '+side,'Time '+side],
            right_on=['PatID','time']
        )
        .drop(columns=['PatID','time'])
        .rename(columns={'timephase':'Phase '+side})
    )

bc_frame['Within Patient'] = bc_frame['Patient X'] == bc_frame['Patient Y']

# Same-patient pairs whose first sample is pre-transplant.
bc_frame = bc_frame[bc_frame['Within Patient']&(bc_frame['Phase X'] == 'Pre TX')]

# For every patient and phase combination: the earliest time on the X side and
# the latest time on the Y side.
pair_keys = ['Patient X','Phase X','Phase Y']
extreme_points = bc_frame.groupby(pair_keys).agg(
    **{
        'Earliest Time' : ('Time X','min'),
        'Latest Time' : ('Time Y','max'),
    }
).reset_index()

# Keep only the pair that joins exactly those two extreme samples.
bc_frame = pd.merge(bc_frame,extreme_points,on=pair_keys,how='left')
is_extreme_pair = (
    (bc_frame['Time X'] == bc_frame['Earliest Time'])&
    (bc_frame['Time Y'] == bc_frame['Latest Time'])
)
bc_frame = bc_frame[is_extreme_pair]
bc_frame

In [None]:
# Box plot of the Bray-Curtis distances, one box per phase of the second sample.
alt.Chart(bc_frame).mark_boxplot().encode(
    x='Phase Y',
    y='BCD'
)

In [None]:
# Reshape to one row per patient and one column per 'Phase Y' value; patients
# with a missing value in any column are dropped.
# NOTE: this overwrites bc_frame, so the cell cannot be re-run without
# re-running the cell that builds the long table first.
bc_frame = bc_frame.pivot(
    index='Patient X',
    columns='Phase Y',
    values='BCD'
).dropna()

# Wilcoxon signed-rank test on the per-patient paired distances.
wilcoxon(bc_frame['Leukozytopenia'],bc_frame['Reconstitution'])

In [None]:
# Per-patient difference between the two distances (positive: the
# 'Leukozytopenia' distance is the larger one).
bc_frame['Diff'] = bc_frame['Leukozytopenia']-bc_frame['Reconstitution']
bc_frame

In [None]:
# Box plot of the per-patient differences computed above.
alt.Chart(bc_frame).mark_boxplot().encode(
    y='Diff'
)

# Sampling Time Overview <a class="anchor" id="sampletimes"></a>

In [None]:
# --- Adverse events: one column per event type, holding the event time -------
outcomes_reduced = outcomes[
    ['Pat ID',
     'Day relative to HSCT',
     'Day relative to HSCT.1',
     'Day relative to HSCT.2',
     'Day relative to HSCT.3',
     'Day relative to HSCT.4',
     'Day relative to HSCT.5',
     'Day relative to HSCT.6',
     'Day relative to HSCT.7',
    ]
         ].rename(
    columns={
        'Day relative to HSCT' : '1st Relapse',
        'Day relative to HSCT.1' : '2nd Relapse',
        'Day relative to HSCT.2' : '2nd HSCT',
        'Day relative to HSCT.3' : 'Acute GvHD Grade 1-2',
        'Day relative to HSCT.4' : 'Acute GvHD Grade 3-4',
        'Day relative to HSCT.5' : 'Moderate cGvHD',
        'Day relative to HSCT.6' : 'Severe cGvHD',
        'Day relative to HSCT.7' : 'Death',
    }
)
# A row in which at most one cell is filled (the patient ID) has no event.
# np.nan replaces np.NaN, which was removed in NumPy 2.0.
outcomes_reduced['No Adverse Event']=outcomes_reduced.apply( lambda x : 0 if x.count() <= 1 else np.nan,axis=1)

# Long format: one row per (patient, event); drop missing (NaN != NaN) and unknown ('?') times.
outcomes_reduced = outcomes_reduced.melt(id_vars=['Pat ID'])
outcomes_reduced = outcomes_reduced[(outcomes_reduced['value']==outcomes_reduced['value'])]
# .copy(): the frame is modified via .loc below, which on a bare filtered
# slice triggers pandas' SettingWithCopyWarning.
outcomes_reduced = outcomes_reduced[(outcomes_reduced['value']!='?')].copy()
outcomes_reduced.loc[
    (outcomes_reduced['Pat ID']=='18.2'),'KI ID'
] = '18.2'
# Shift the event times of patient 18.2 by the configured offset.
outcomes_reduced.loc[(outcomes_reduced['Pat ID']=='18.2'),'value'] -= OFFSET_PAT_18
outcomes_reduced = outcomes_reduced[outcomes_reduced['variable'] != 'No Adverse Event']

outcomes_reduced = outcomes_reduced.rename(columns={
    'Pat ID' : 'PatID',
    'variable' : 'Adverse Event',
    'value' : 'time'
})


############

# --- Samples used in the paper, without the 'G' (control) IDs ----------------
sample_statistics['id'] = sample_statistics['PatID']+'/'+sample_statistics['time'].astype(str)
overview = sample_statistics[sample_statistics['id'].isin(PAPER_SAMPLES)]
overview = overview[~overview['PatID'].str.startswith('G')]


# Samples and events in one table; event rows have no 'Startcluster' and
# sample rows have no 'Adverse Event'.
combined = pd.concat([overview,outcomes_reduced])


# Order the patients by start cluster, alphabetically within a cluster.
patientlist_sorted = []

for startcluster in combined['Startcluster'].unique():
    clusterlist = sorted(combined[combined['Startcluster'] == startcluster]['PatID'].unique())
    patientlist_sorted += clusterlist


# Three layers on a shared symlog time axis: adverse events (black shapes),
# samples (circles coloured by start cluster) and the time as a text label.
overview = (alt.Chart(combined[combined['Adverse Event'] == combined['Adverse Event']]).mark_point(size=38,color='black').encode(
    x=alt.X('time',scale=alt.Scale(type='symlog'),axis=alt.Axis(grid=False)),
    y=alt.Y('PatID',sort=patientlist_sorted),
    shape=alt.Shape('Adverse Event:N',scale=alt.Scale(
        domain=['2nd HSCT','Acute GvHD Grade 1-2','Acute GvHD Grade 3-4','Moderate cGvHD','Severe cGvHD','1st Relapse','2nd Relapse','Death'],
        range=['circle','square','square','diamond','diamond','triangle','triangle','cross']))
)+alt.Chart(combined,width=800).mark_circle().encode(
    x=alt.X('time',scale=alt.Scale(type='symlog'),axis=alt.Axis(grid=False)),
    y=alt.Y('PatID:N',sort=patientlist_sorted,axis=alt.Axis(grid=True)),
    color='Startcluster:N'
)+alt.Chart(combined).mark_text(dx=0,dy=-6).encode(
    x=alt.X('time',scale=alt.Scale(type='symlog'),axis=alt.Axis(grid=False)),
    y=alt.Y('PatID:N',sort=patientlist_sorted,axis=alt.Axis(grid=True)),
    text='time'
))

overview.save('Output/Samples.html')

overview

# Zymo Std. Analysis <a class="anchor" id="zymo"></a>

In [None]:
# --- Kraken2 result for the Zymo standard ------------------------------------
species_level = get_normalized_abundances(kraken_dataframe,samples=['Zymo_Power_-100'],level='S')

with_validation = apply_validation(
    species_level,
    level = 'S',
    validation_cutoff = 0.2,
    readcount_cutoff = 5
)
# NOTE: despite the '(%)' in the column name this is a fraction (reads divided
# by the per-sample read total), not a percentage.
with_validation['Read Fraction (%)'] = with_validation['readcount']/with_validation.groupby('samplename')['readcount'].transform('sum')
with_validation=with_validation.rename(columns={
    'samplename' : 'Sample ID',
    'taxon' : 'Taxon'
})

sample_statistics['samplename']= sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)


# Append the theoretical composition and keep the three shared columns.
with_validation = pd.concat([with_validation,zymo_theory])[
    ['Sample ID','Taxon','Read Fraction (%)']
]

# --- Minimap2 results, reshaped to the same three columns --------------------
melt = zymo_minimap.melt(id_vars='Index')

melt = melt.rename(
    columns={
    'variable' : 'Taxon',
        'Index' : 'Sample ID',
        'value' : 'Read Fraction (%)'
    }
)

#Replace underscores with spaces
melt['Taxon'] = melt['Taxon'].apply(lambda x : ' '.join(x.split('_')))

#Merge different E-Coli Strains
melt.loc[melt['Taxon'].str.startswith('Escherichia coli'),'Taxon'] = 'Escherichia coli'
melt.loc[melt['Taxon']=='Candida albican','Taxon'] = 'Candida albicans'

# .copy(): columns are assigned below, which on a bare filtered slice triggers
# pandas' SettingWithCopyWarning.
melt = melt[melt['Sample ID'].isin(['Power/Reads','Power/Bases'])].copy()

melt['Sample ID'] = melt['Sample ID'].apply(lambda x : x + '_Minimap2')

melt['Taxon'] = melt['Taxon'].apply(lambda x : 'Unmapped' if x == 'unmapped' else x)

#Reduce to Species
melt = melt.groupby(['Sample ID','Taxon'],as_index=False).sum()

combined = pd.concat([melt,with_validation])

# Taxa of interest: everything reported for the Minimap2 read-based sample,
# plus the 'Not validated' bucket; all other taxa are pooled into 'Other'.
taxa_we_look_at = list(combined[combined['Sample ID'] == 'Power/Reads_Minimap2']['Taxon'].unique())
if 'Not validated' not in taxa_we_look_at:
    taxa_we_look_at.append('Not validated')
print('Determined the following taxa as relevant:',taxa_we_look_at)

combined.loc[~combined['Taxon'].isin(taxa_we_look_at), 'Taxon'] = 'Other'
combined = combined.groupby(['Taxon','Sample ID'],as_index=False).sum()


# --- Colour scale ------------------------------------------------------------
colorMap = {}

taxa = combined['Taxon'].unique()

# Interleave the glasbey palette: every second colour first, then the rest.
palette = cc.glasbey_light
bright = palette[::2]
muted = palette[1::2]
palette = bright+muted

for tax,col in zip(list(taxa),palette):
    colorMap[tax] = col #colors.to_hex(col)

altdomain = []
altrange = []

for x in combined['Taxon'].unique():

    # the pooled / technical groups get fixed colours at the end of the scale
    if x == 'Other' or x == 'Unmapped' or x == 'Not validated':
        continue

    c = colorMap[x]
    altdomain.append(x)
    color = colors.to_hex(c)
    altrange.append(color)

altdomain.append('Other')
altrange.append(colors.to_hex((1,1,1)))
altdomain.append('Unmapped')
altrange.append(colors.to_hex((0.45,0.45,0.45)))
altdomain.append('Not validated')
altrange.append(colors.to_hex((0.55,0.55,0.55)))

# Display names for the four compared compositions.
combined['Sample ID'] = combined['Sample ID'].map({
    'Zymo_Power_-100' : 'Kraken2',
    'Power/Bases_Minimap2' : 'Minimap2 Bases',
    'Power/Reads_Minimap2' : 'Minimap2 Reads',
    'Zymo Theoretical Composition' : 'Theoretical Composition',

})

# Stacked bars, stacked in the order of the colour domain.
c= alt.Chart(
  combined
).transform_calculate(
order=f"-indexof({altdomain}, datum.Taxon)"
).mark_bar(stroke='black',strokeWidth=0.5,strokeOpacity=0.9).encode(
    x=alt.X('Sample ID:N',title=None,sort=['Kraken2','Minimap2 Bases','Minimap2 Reads','Theoretical Composition']
),
    y=alt.Y('Read Fraction (%):Q',scale=alt.Scale(domain=(0,1)),title='Estimated Abundance'),
    color=alt.Color('Taxon:N',legend=alt.Legend(columns=2,symbolLimit=0,labelLimit=0),scale=alt.Scale(domain=altdomain,range=altrange),title=None),
    tooltip=['Read Fraction (%)','Taxon'],
    order=alt.Order('order:Q')
)
c.save('Output/Zymo_Overview.html')
c

In [None]:
# Taxon x method table of the abundances (missing combinations as 0), saved as CSV.
combined.pivot(index='Taxon',columns='Sample ID',values='Read Fraction (%)').fillna(0).to_csv('Output/ZymoStdAbundances.csv')

### Zymo Pearson

In [None]:
# Pairwise correlation between the methods across taxa (DataFrame.corr defaults to Pearson).
combined.pivot(index='Taxon',columns='Sample ID',values='Read Fraction (%)').fillna(0).corr()


# Crassphage <a class="anchor" id="crassphage"></a>

In [None]:
# Kraken read counts for the two crAssphage taxa and for the root node,
# restricted to the paper samples.
crassphage_rows = kraken_dataframe['taxon'].isin(
    ['uncultured crAssphage','unclassified Crassvirales','root']
)
kraken_data = kraken_dataframe[crassphage_rows]
kraken_data = kraken_data[kraken_data['sample'].isin(PAPER_SAMPLES)]

# One row per sample, one column per taxon; missing counts are 0.
kraken_data = kraken_data.pivot(index='sample',columns='taxon',values='readcount').fillna(0)

# Share of the root read count that falls on the two crAssphage taxa.
crassphage_reads = kraken_data['uncultured crAssphage']+kraken_data['unclassified Crassvirales']
kraken_data['crassphage_detected'] = crassphage_reads/kraken_data['root']

kraken_data = kraken_data.reset_index()

def rename(x):
    """Replace the last underscore of a sample name with a slash.

    'a_b_5' becomes 'a_b/5'. A name without any underscore raises IndexError.
    """
    parts = x.rsplit('_',1)
    return '{}/{}'.format(parts[0],parts[1])

# Minimap2 mapping summary; sample names are converted to the slash form.
minimap_data = pd.read_csv('Input/Crassphage/summary.csv')
minimap_data['Sample'] = minimap_data['Sample'].map(rename)

# Total mapped fraction and mapped reads per sample (summed over all rows of a sample).
minimap_data_aggregated = minimap_data.groupby(
    'Sample',as_index=False
)[['Fraction Mapped','Mapped Reads']].sum()

# Kraken and Minimap2 side by side, keeping every Kraken sample.
combined = pd.merge(
    kraken_data,
    minimap_data_aggregated,
    left_on='sample',
    right_on='Sample',
    how='left'
)
combined

In [None]:
# Largest value on either axis; used as the end point of the diagonal.
max_value = max(
    combined['Fraction Mapped'].max(),combined['crassphage_detected'].max()
)

# Two points describing the identity line y = x.
line = pd.DataFrame({
    'X': [0, max_value],
    'Y': [0, max_value],
})


# Scatter of the Minimap2 fraction against the Kraken fraction on symlog axes,
# layered with the identity line.
# NOTE(review): the two layers request different sizes (400 vs 700) - confirm which one is intended.
(alt.Chart(combined,width=400,height=400).mark_point().encode(
    y=alt.Y('Fraction Mapped',scale=alt.Scale(type='symlog',constant=0.001)),
    x=alt.X('crassphage_detected',scale=alt.Scale(type='symlog',constant=0.001))
)+alt.Chart(line,width=700,height=700).mark_line(color= 'lightgray').encode(
        x= 'X',
        y= 'Y'
    )).interactive()#.save('Output/CrassphageKrakenVsMinimap.html')

## Heatmap, which Crassphage

In [None]:
TOP_X = 50   # number of references kept in the heatmap

# Long table with one row per (sample, reference); combinations that do not
# occur in the summary are filled with 0.
heatmap_table = minimap_data.pivot(index='Sample',columns='Reference',values='Fraction Mapped').fillna(0).melt(ignore_index=False).reset_index().rename(columns={'value' : 'Fraction Mapped'})
heatmap_table['Unambiguous'] = heatmap_table['Reference'] != 'Ambiguous'

#For top bar charts
heatmap_table_simplified = heatmap_table.groupby(['Sample','Unambiguous'],as_index=False)['Fraction Mapped'].sum()

# Paper samples only. `.copy()` makes this an independent frame; overwriting a
# column of the bare filtered slice triggers pandas' SettingWithCopyWarning.
heatmap_table_filtered = heatmap_table[
    (heatmap_table['Sample'].isin(PAPER_SAMPLES)  )
].copy()

# Rescale so that the fractions of each sample sum to 1.
heatmap_table_filtered['Fraction Mapped']=heatmap_table_filtered['Fraction Mapped']/heatmap_table_filtered.groupby('Sample')['Fraction Mapped'].transform('sum')

# The TOP_X references with the largest rescaled fraction summed over all samples.
top_refs = heatmap_table_filtered.groupby('Reference')['Fraction Mapped'].sum().sort_values(ascending=False).keys().tolist()[:TOP_X]

heatmap_table_filtered = heatmap_table_filtered[

    heatmap_table_filtered['Reference'].isin(top_refs)
]

In [None]:
# Top: stacked bars of the mapped fraction per sample, split into unambiguous
# and ambiguous assignments. Bottom: heatmap of the per-sample share of each
# of the top references. Both panels share the sample axis.
#
# The heatmap rows are ordered by `top_refs` from the previous cell (largest
# summed fraction first).
# NOTE(review): the original passed `sort=max_refs`, a name that is not
# defined anywhere in the visible notebook and would raise NameError on a
# fresh kernel. Confirm that the top_refs order is the intended one.
chart = (alt.Chart(heatmap_table_simplified[heatmap_table_simplified['Sample'].isin(PAPER_SAMPLES)]).mark_bar().encode(
    x=alt.X('Sample',sort=sort_samples(PAPER_SAMPLES),axis=alt.Axis(orient='top')),
    y=alt.Y('Fraction Mapped',stack=True),
    color='Unambiguous'
)&alt.Chart(heatmap_table_filtered).mark_rect().encode(
    x=alt.X('Sample',sort=sort_samples(PAPER_SAMPLES),axis=None),
    y=alt.Y('Reference',sort=top_refs),
    color=alt.Color('Fraction Mapped:Q',title='Fraction of uniquely assigned reads'),
    tooltip=['Fraction Mapped']
)).resolve_scale(x='shared',color='independent',y='independent')

chart.save('Output/CrassphageOverview.html')

chart

## Migration Analysis

In [None]:
# Read counts per sample and taxon id; 'Tax ID' is kept as text.
migration = pd.read_csv('Input/Crassphage/migration.csv',dtype={'Tax ID' : str})

# Share of each sample's reads that falls on each taxon id, plus the taxon name.
reads_per_sample = migration.groupby('Sample')['Reads'].transform('sum')
migration['Fraction'] = migration['Reads']/reads_per_sample
migration['Taxon Name'] = migration['Tax ID'].map(idtonames)

# Restrict to the 20 taxon ids with the highest mean fraction.
mean_fraction = migration.groupby('Tax ID')['Fraction'].mean()
top_hits = mean_fraction.sort_values(ascending=False).index[:20]
migration = migration[migration['Tax ID'].isin(top_hits)]

def categorize(taxon_name):
    """Assign a taxon name to one of three display categories.

    Returns 'crAssphage classification' for the three listed crAssphage
    names, 'unclassified' for a missing name (NaN, detected by NaN != NaN)
    and 'other classification' for everything else.
    """
    crassphage_names = ['uncultured phage cr6_1','uncultured crAssphage','CrAss-like virus sp.']
    if taxon_name in crassphage_names:
        return 'crAssphage classification'
    is_missing = taxon_name != taxon_name
    return 'unclassified' if is_missing else 'other classification'

# Category and display name per row. `assign` returns a new frame, so no
# column is written into the filtered slice created in the previous step
# (which would trigger pandas' SettingWithCopyWarning). The category is
# computed first because it relies on missing names still being NaN.
migration_category = migration['Taxon Name'].apply(categorize)
migration_display_name = migration['Taxon Name'].apply(lambda x : 'unclassified' if x != x else x)
migration = migration.assign(**{
    'Category' : migration_category,
    'Taxon Name' : migration_display_name,
})

# Axis order: 'unclassified', then all other taxa alphabetically, then the
# crAssphage taxa. The crAssphage names are sorted as well: iterating a set of
# strings directly gives an order that can change between interpreter runs.
crassphages = {'uncultured phage cr6_1','uncultured crAssphage','CrAss-like virus sp.'}
rest = set(migration['Taxon Name'])-(crassphages.union({'unclassified'}))
sorted_taxa = ['unclassified']+list(sorted(rest))+sorted(crassphages)

alt.Chart(migration).mark_boxplot().encode(
    y='Fraction',
    x=alt.X('Taxon Name:N',sort=sorted_taxa),
    color='Category',
    tooltip=['Reads','Sample','Fraction','Taxon Name']
)

In [None]:
details = pd.read_csv('Input/crassphage_details.csv')

charts = []

# One histogram (100 bins, symlog count axis) per alignment metric.
for metric in ['Percentage Aligned','Mapping Quality','Identity']:

    amount,edges = np.histogram(details[metric],bins=100)

    # Bin centre = mean of the left and right edge of each bin.
    centers = (edges[:-1]+edges[1:])/2

    df = pd.DataFrame({metric : centers,'Count' : amount})

    c = alt.Chart(df).mark_bar().encode(
        x=metric,
        y=alt.Y('Count',scale=alt.Scale(type='symlog'))
    )

    charts.append(c)

# Stack the histograms vertically and write them to disk.
reduce(lambda upper,lower : upper&lower, charts).save('Output/CrassphageAlignmentDetails.html')

# Marker Taxa Overview

In [None]:
CONFIDENCE_INTERVAL_ALPHA = 0.05

sample_statistics['samplename'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

# --- Setup that does not depend on the taxonomic level (hoisted out of the loop)
PAPER_SAMPLES_UNDERSCORE = [x.replace('/','_') for x in PAPER_SAMPLES]
os.makedirs('marker_otus_overview',exist_ok=True)


def SORTFUNCTION(x):
    """Sort key for patient IDs: IDs not starting with 'G' first, then by the
    number that follows the last 'G' (or by the whole ID if it has no 'G').

    NOTE(review): this defines the notebook-global name SORTFUNCTION, which an
    earlier cell also uses as a sort key for sample names - confirm that
    re-running that cell after this one is not expected.
    """
    return (
        x[0] == 'G', #Sort by G or regular patient first
        float(x.split('G')[-1])
    )


# Colour encoding shared by all layers.
validated_color = alt.Color(
    'Validated',
    scale=alt.Scale(domain=['Low abundance/Not validated','Validated'],range=['lightgrey','orange'])
)
# Day axis shared by the patient layers.
day_axis = alt.X('time:Q',scale=alt.Scale(type='symlog'),axis=alt.Axis(grid=False),title='Day')

for level in ['S','G']:

    MARKER_OTUS = MARKER_GENERA if level == 'G' else MARKER_SPECIES

    otu_data = get_normalized_abundances(
        kraken_dataframe,
        samples=PAPER_SAMPLES_UNDERSCORE,
        level=level,
        excluded_taxa_filter=[CHORDATA,PLANTAE],
        normalize=False
    )

    # Read total per sample: denominator for the read fractions and the
    # number of observations for the confidence intervals.
    total_roots = otu_data.groupby('samplename')['readcount'].sum()

    otu_data['Read Fraction'] =  otu_data['readcount'] / otu_data.groupby(
        ['samplename']
    )['readcount'].transform('sum')

    marker_data = validation_data[level]
    # .copy(): a column is added below, which on a bare filtered slice
    # triggers pandas' SettingWithCopyWarning.
    marker_data = marker_data[marker_data['samplename'].isin(PAPER_SAMPLES_UNDERSCORE)].copy()
    marker_data['Validated'] = marker_data['Validation Rate'] >= 0.2

    patientlist_sorted = sorted(
        marker_data['patientid'].unique(),key=SORTFUNCTION
    )


    for taxon in MARKER_OTUS:

        subtable = marker_data[
            marker_data['Taxon Name'] == taxon
        ]

        # Abundances of this taxon for every sample that has validation data.
        subtable = pd.merge(
            otu_data,
            subtable[['samplename','Taxon Name','Validated']],
            left_on=['samplename','taxon'],
            right_on=['samplename','Taxon Name'],
            how='right'
        )

        if len(subtable) == 0:
            print('No data for taxon: {}'.format(taxon))
            continue

        # Split on the LAST underscore only and store the time as an integer,
        # matching how the substitute rows below are built.
        subtable[['patientid','time']] = subtable['samplename'].str.rsplit('_',n=1,expand=True)
        subtable['time'] = subtable['time'].astype(int)

        # Samples without any entry for this taxon are shown with zero reads.
        present_samples = set(subtable['samplename'])
        substitutes = []
        for sample in PAPER_SAMPLES_UNDERSCORE:
            if sample not in present_samples:
                patientid,time = sample.rsplit('_',1)
                substitutes.append(
                    (patientid,int(time),'?',0,patientid+'_'+str(time),0)
                )
        subtable = pd.concat(
            [subtable,pd.DataFrame(substitutes,columns=['patientid','time','Validated','readcount','samplename','Read Fraction'])],
            ignore_index=True  # unique row labels for the column assignment below
        )


        # Confidence interval of the read fraction (reads of the taxon out of all reads of the sample).
        subtable[['Confidence Interval Low','Confidence Interval High']] = subtable.apply(
            lambda row : proportion_confint(row['readcount'],total_roots[row['samplename']],alpha=CONFIDENCE_INTERVAL_ALPHA),
            axis=1,
            result_type='expand'
        )

        subtable['Validated'] = subtable['Validated'].map(
            {
                '?' : 'Low abundance/Not validated',
                False : 'Low abundance/Not validated',
                True : 'Validated'
            }
        )

        subtable['patientid'] = subtable['patientid'].astype(str)

        # IDs starting with 'G' go to the 'Control' panel; all others are drawn per patient over time.
        c=alt.Chart(subtable[~subtable['patientid'].str.startswith('G')],height=40)
        c2 = alt.Chart(subtable[subtable['patientid'].str.startswith('G')],height=40)

        # Patients: interval rule + point per sample and a vertical rule at
        # day 0, one facet row per patient with its own y scale.
        patient_panel = (
            c.mark_rule(size=20,strokeWidth=3).encode(
                x=day_axis,
                color=validated_color,
                y=alt.Y('Confidence Interval Low:Q',title=None,axis=alt.Axis(tickCount=2)),
                y2=alt.Y2('Confidence Interval High:Q',title=None)
            )+c.mark_point(size=50,filled=True).encode(
                x=day_axis,
                color=validated_color,
                y=alt.Y('Read Fraction:Q',title=None)
            )+c.transform_calculate(time='0').mark_rule().encode(
                x=day_axis
            )
        ).resolve_scale(x='shared').facet(
            row=alt.Row('patientid',sort=patientlist_sorted,title='Patient')
        ).resolve_scale(y='independent')

        # Controls: one interval rule + point per control ID.
        control_panel = (
            c2.mark_rule(size=20,strokeWidth=3).encode(
                x=alt.X('patientid',axis=alt.Axis(grid=False),title='Control'),
                color=validated_color,
                y=alt.Y('Confidence Interval Low:Q',title=None,axis=alt.Axis(tickCount=2)),
                y2=alt.Y2('Confidence Interval High:Q',title=None)
            )+c2.mark_point(size=50,filled=True).encode(
                x=alt.X('patientid',axis=alt.Axis(grid=False),title='Control'),
                color=validated_color,
                y=alt.Y('Read Fraction:Q',title=None),
            )
        )

        c = patient_panel|control_panel
        c.properties(title=taxon).save('marker_otus_overview/{}.png'.format(taxon), engine='vl-convert') #Outcomment the engine argument if you have issues with saving



In [None]:
# Species-level validation rows without missing values whose taxon name starts with 'uncultured'.
(
    validation_data['S']
    .dropna()
    .loc[lambda table : table['Taxon Name'].str.startswith('uncultured')]
)

# Illumina Sequencing / Population Shifts <a class="anchor" id="strains"></a>

## Load samplesheet (To resolve illumina ids -> patient/time)

In [None]:
samplesheet = pd.read_csv('Input/samples.tsv',sep='\t')
# Retain only entries that have an illumina file.
samplesheet = samplesheet[samplesheet['illuminafile'].notna()]

# Lookup tables: illumina file -> patient id / time.
samplesheetDictPatient = {}
samplesheetDictTime = {}
for _,entry in samplesheet.iterrows():
    illumina_file = entry['illuminafile']
    samplesheetDictPatient[illumina_file] = entry['patientid']
    samplesheetDictTime[illumina_file] = entry['time']


## Load and process/annotate distances

In [None]:
# Load the pairwise GutTrSnp distances and the per-sample coverages, then annotate every
# distance row with readable names, pair flags and coverage bins.
#
# HIGH_CONFIDENCE_CUTOFF selects which input folder (RealData_<cutoff>) is read and names the
# output folder. NOTE(review): the meaning of the cutoff itself is not visible here - confirm.
HIGH_CONFIDENCE_CUTOFF = 50
output_folder = 'Output/StrainAnalysis/GutTrSnp_{}'.format(HIGH_CONFIDENCE_CUTOFF)

# Ids and taxa are read as strings so they are never parsed as numbers.
guttrsnp_distances = pd.read_csv('Input/GutTrSnp/RealData_{}/distances.csv'.format(
   HIGH_CONFIDENCE_CUTOFF 
),dtype={
    'Destination ID':str,
    'Source ID' :str,
    'Taxon' : str
})

# Remove rows without overlap (and thus without distance): first every row with a missing
# value in any column, then every row whose share of overlap is below 0.01.
guttrsnp_distances = guttrsnp_distances.dropna()
guttrsnp_distances=guttrsnp_distances[guttrsnp_distances['Share Of Overlap'] >= 0.01]

# Readable labels: taxon name via idtonames, and '<id>/<time>' strings for both ends of the pair.
guttrsnp_distances['Taxon Name'] = guttrsnp_distances['Taxon'].map(idtonames)
guttrsnp_distances['Source'] = guttrsnp_distances['Source ID']+'/'+guttrsnp_distances['Source Time'].astype(str)
guttrsnp_distances['Destination'] = guttrsnp_distances['Destination ID']+'/'+guttrsnp_distances['Destination Time'].astype(str)
# PrePost Pair: same id on both ends, source time negative and destination time positive.
guttrsnp_distances['PrePost Pair'] = (
    guttrsnp_distances['Source ID'] == guttrsnp_distances['Destination ID']
)&(
guttrsnp_distances['Source Time'] < 0
)&(
guttrsnp_distances['Destination Time'] > 0
)
guttrsnp_distances['Within Patient'] = guttrsnp_distances['Source ID'] == guttrsnp_distances['Destination ID']
# Distance over time (DoT): distance divided by the time elapsed between the two samples.
# NOTE(review): an elapsed time of 0 yields inf (or NaN for a distance of 0); NaN rows are
# removed by the dropna() further down - confirm that this is intended.
guttrsnp_distances['Elapsed Time'] = guttrsnp_distances['Destination Time'] - guttrsnp_distances['Source Time']
guttrsnp_distances['Distance over Time'] = guttrsnp_distances['GutTrSnp Distance'] / guttrsnp_distances['Elapsed Time']

guttrsnp_distances['Distance Name'] = guttrsnp_distances['Source Time'].astype(str) + '->' + guttrsnp_distances['Destination Time'].astype(str)

coverages = pd.read_csv('Input/GutTrSnp/RealData_{}/coverages.csv'.format(
    HIGH_CONFIDENCE_CUTOFF
),dtype={
    'Patient ID' : str,
    'Taxon ID' : str
})
# Attach the coverage table twice: first matched on the source sample, then on the destination
# sample. The second merge meets the coverage columns a second time, so pandas suffixes them:
# '_x' = source side (first merge), '_y' = destination side (second merge).
guttrsnp_distances=pd.merge(guttrsnp_distances,coverages,how='left',left_on=['Source ID','Source Time','Taxon'],right_on=['Patient ID','Time','Taxon ID'])
guttrsnp_distances=pd.merge(guttrsnp_distances,coverages,how='left',left_on=['Destination ID','Destination Time','Taxon'],right_on=['Patient ID','Time','Taxon ID'])



# The merge keys of the coverage table duplicate columns that already exist; drop them.
guttrsnp_distances = guttrsnp_distances.drop(columns=[
    'Patient ID_x','Patient ID_y',
    'Time_x','Time_y',
    'Taxon ID_x','Taxon ID_y'
])


# Bin edges for the average vertical coverage.
bins=[0,10,20,50,100,200,500,1000,2000]


# 'VCG' columns hold the coverage bin of each end (presumably "vertical coverage group").
# pd.cut assigns NaN to values outside the bin edges, and the dropna() below removes every row
# with a NaN in any column - so rows without coverage match or outside (0, 2000] are discarded.
guttrsnp_distances['VCG Source'] = pd.cut(guttrsnp_distances['Average Vertical Coverage_x'],bins)
guttrsnp_distances['VCG Destination'] = pd.cut(guttrsnp_distances['Average Vertical Coverage_y'],bins)
guttrsnp_distances = guttrsnp_distances.dropna()
# Store the bins as plain strings instead of Interval categories.
guttrsnp_distances['VCG Source']=guttrsnp_distances['VCG Source'].astype(str)
guttrsnp_distances['VCG Destination']=guttrsnp_distances['VCG Destination'].astype(str)

# Sorted list of source time points for every (taxon, source id) combination.
# Every combination of an observed taxon and an observed source id gets an entry; combinations
# that never occur together keep an empty list.
timepoints = {
    combination: []
    for combination in it.product(
        guttrsnp_distances['Taxon'].unique(),
        guttrsnp_distances['Source ID'].unique()
    )
}

# A single grouped pass fills in the observed combinations. This replaces re-filtering the
# complete distance table once per taxon/source-id pair, which scaled with
# (#taxa x #source ids x #rows).
for combination, source_times in guttrsnp_distances.groupby(['Taxon','Source ID'])['Source Time']:
    timepoints[combination] = sorted(source_times.unique().tolist())
        
def sequential_test(row):
    """Tell whether a distance row connects two directly consecutive samples of one patient.

    Parameters
    ----------
    row : mapping
        Must provide 'Within Patient', 'Taxon', 'Source ID', 'Source Time' and
        'Destination Time'.

    Returns
    -------
    bool
        True when the row is a within-patient pair and the destination time directly follows
        the source time in the sorted time points stored in the global ``timepoints`` for
        (taxon, source id). False otherwise - including the case that one of the two times is
        not part of that stored sequence (previously this raised ValueError from list.index).
    """
    if not row['Within Patient']:
        return False

    sequence = timepoints.get((row['Taxon'],row['Source ID']))
    if not sequence:
        return False

    try:
        destination_position = sequence.index(row['Destination Time'])
        source_position = sequence.index(row['Source Time'])
    except ValueError:
        # The sequence only holds times seen as *source* time; a time that is missing from it
        # cannot be ranked, so the pair is not reported as sequential.
        return False

    return destination_position - source_position == 1

# Row-wise flag: True for within-patient pairs whose two time points are direct neighbours.
guttrsnp_distances['Sequential'] = guttrsnp_distances.apply(sequential_test,axis=1)

In [None]:
# Taxon name in the real data -> value of 'Simulated Taxon' in the subspecies simulation.
mapping_to_simulation = dict([
    ('Bacteroides vulgatus ATCC 8482', 'PVulgatus'),
    ('Enterococcus faecium', 'EFaecium'),
    ('Lactobacillus gasseri ATCC 33323 = JCM 1131', 'LGasseri'),
    ('Shigella sonnei 53G', 'EColi'),
])

# Aggregated distances of the subspecies simulation; taxa stay strings and rows with any
# missing value are discarded.
guttrsnp_distances_simulation = pd.read_csv(
    'Input/GutTrSnp/SubspeciesSimulation/aggregatedDistances.csv',
    dtype={'Taxon' : str},
)
guttrsnp_distances_simulation = guttrsnp_distances_simulation.dropna()


In [None]:
# Sanity Checks / Generic Analysis
#
# For every taxon with at least three sequential within-patient distances (overlap >= 0.5) one
# row of panels is built:
#   1. histogram of the sequential same-patient distances
#   2. the same distances plotted against the elapsed time
#   3. histogram of distances between different patients where both time points are negative
#   4. histogram of all distances of the taxon with overlap >= 0.5
#   5. histogram of simulated distances (only for taxa listed in mapping_to_simulation)
# All rows are stacked vertically and written to <output_folder>/Overview.html.
# The per-taxon tables of panel 1/2 are additionally collected in `curation`.


os.makedirs(output_folder,exist_ok=True)

plots = []
curation = []
for taxon in guttrsnp_distances['Taxon Name'].unique():

    # .copy(): the relabelling below must work on an independent frame, not on a slice of
    # guttrsnp_distances (avoids pandas' chained-assignment warning / undefined behaviour).
    taxontable = guttrsnp_distances[
        (guttrsnp_distances['Taxon Name'] == taxon)&
        (guttrsnp_distances['Sequential'])&
        (guttrsnp_distances['Share Of Overlap'] >= 0.5)
    ].copy()

    taxontable['PrePost Pair'] = taxontable['PrePost Pair'].map(
        {
            True : 'Yes',
            False : 'No'
        }
    )

    if len(taxontable) < 3:
        continue

    min_dist = taxontable['GutTrSnp Distance'].min()
    max_dist = taxontable['GutTrSnp Distance'].max()

    # Vertical rule at the largest sequential real-data distance; built before the range is
    # widened by the simulation so that it always marks the real maximum.
    max_line = alt.Chart(
        pd.DataFrame({'Distance' : [max_dist]})
    ).mark_rule().encode(
        x='Distance',
        size=alt.value(3)
    )

    # Simulated distances for this taxon (None when no simulation exists for it).
    subdata = None
    if taxon in mapping_to_simulation:
        subdata = guttrsnp_distances_simulation[
            (guttrsnp_distances_simulation['Simulated Taxon'] == mapping_to_simulation[taxon])&
            (guttrsnp_distances_simulation['Taxon'] == taxontable['Taxon'].values[0])
        ]
        if len(subdata) > 0:
            # Widen the shared histogram range so it covers real and simulated distances.
            # (The upper bound previously used min(), which clipped the range instead.)
            min_dist = min(subdata['GutTrSnp Distance'].min(),min_dist)
            max_dist = max(subdata['GutTrSnp Distance'].max(),max_dist)

    # 50 equally wide bins over the shared range; if all distances are identical the step
    # would be 0, so fall back to automatic binning in that case.
    step = (max_dist-min_dist)/50
    if step > 0:
        shared_bin = alt.Bin(extent=[min_dist,max_dist],step=step)
    else:
        shared_bin = alt.Bin(maxbins=50)

    sequential_histogram = alt.Chart(
        taxontable,height=300,title='Sequential Samples Same Patients'
    ).mark_bar().encode(
        x=alt.X('GutTrSnp Distance',bin=shared_bin,title='Distance'),
        y=alt.Y('count()',axis=alt.Axis(tickMinStep=1),stack=True),
        color=alt.Color('PrePost Pair',title='Across Transplantation')
    )

    sequential_scatter = alt.Chart(
        taxontable,height=300,title='Sequential Samples Same Patients'
    ).mark_point().encode(
        y=alt.Y('GutTrSnp Distance',title='Distance'),
        x='Elapsed Time',
        tooltip=['Source','Destination','Share Of Overlap'],
        color=alt.Color('PrePost Pair',title='Across Transplantation')
    )

    plot = sequential_histogram | sequential_scatter

    curation.append(
        taxontable
    )

    # Distances between different patients where both samples have a negative time.
    prepairs = guttrsnp_distances[
        (guttrsnp_distances['Taxon Name'] == taxon)&
        (guttrsnp_distances['Source ID'] != guttrsnp_distances['Destination ID'])&
        (guttrsnp_distances['Source Time'] < 0)&
        (guttrsnp_distances['Destination Time'] < 0)
    ]

    plot = (plot | alt.Chart(
        prepairs,height=300,title='Pre-Samples Different Patients'
    ).mark_bar().encode(
        x=alt.X('GutTrSnp Distance',bin=shared_bin,title='Distance'),
        y=alt.Y('count()',axis=alt.Axis(tickMinStep=1))
    ))

    all_distances = guttrsnp_distances[
        (guttrsnp_distances['Taxon Name'] == taxon)&
        (guttrsnp_distances['Share Of Overlap'] >= 0.5)
    ]

    plot = (plot | (alt.Chart(
        all_distances,height=300,title='All Distances'
    ).mark_bar().encode(
        x=alt.X('GutTrSnp Distance',bin=alt.Bin(maxbins=50),title='Distance'),
        y=alt.Y('count()',axis=alt.Axis(tickMinStep=1))
    )+max_line))

    if subdata is not None:
        # Simulated distances of exactly 0 are left out of the histogram.
        plot = (plot |(alt.Chart(
            subdata[subdata['GutTrSnp Distance'] != 0],height=300,title='Simulated Data'
        ).mark_bar().encode(
            x=alt.X('GutTrSnp Distance',bin=alt.Bin(maxbins=50),title='Distance'),
            y=alt.Y('count()',axis=alt.Axis(tickMinStep=1))
        )+max_line))

    plots.append(
        plot.properties(title=taxon)
    )


if not plots:
    raise ValueError('No taxon has at least 3 sequential within-patient distances with overlap >= 0.5')

# Stack the per-taxon rows vertically.
chart = reduce(lambda x,y : x& y,plots)

chart.save(output_folder+'/Overview.html')

chart

## Visualization of manually curated shifts

In [None]:
# Timeline plots of the manually curated shifts, one chart per taxon:
# curated distances as lines (Shift / Stable), clinical events as black symbols, sampled time
# points as black dots, paper samples without curated distance for the taxon as grey dots.

# --- Clinical events ---------------------------------------------------------------------
outcomes = pd.read_excel('Input/Annotations/Patient_Statistics (2).xlsx',dtype={'Pat ID' : str})

# Spreadsheet column -> event name. NOTE(review): the '.1' ... '.7' suffixes look like pandas'
# numbering of a repeated 'Day relative to HSCT' header - confirm the column order in the sheet.
outcome_columns = {
    'Day relative to HSCT' : '1st Relapse',
    'Day relative to HSCT.1' : '2nd Relapse',
    'Day relative to HSCT.2' : '2nd HSCT',
    'Day relative to HSCT.3' : 'Acute GvHD Grade 1-2',
    'Day relative to HSCT.4' : 'Acute GvHD Grade 3-4',
    'Day relative to HSCT.5' : 'Moderate cGvHD',
    'Day relative to HSCT.6' : 'Severe cGvHD',
    'Day relative to HSCT.7' : 'Death',
}
outcomes_reduced = outcomes[['Pat ID']+list(outcome_columns)].rename(columns=outcome_columns)

# Long format: one row per (patient, event) with the event day in 'value'.
outcomes_reduced = outcomes_reduced.melt(id_vars=['Pat ID'])

# Drop events without a day and events marked '?'. This happens BEFORE the arithmetic below,
# so a '?' entry can no longer break the subtraction.
outcomes_reduced = outcomes_reduced[
    outcomes_reduced['value'].notna()&
    (outcomes_reduced['value'] != '?')
].copy()

# Event days of patient '18.2' are shifted by OFFSET_PAT_18 (defined elsewhere in the notebook).
outcomes_reduced.loc[(outcomes_reduced['Pat ID']=='18.2'),'value'] -= OFFSET_PAT_18
outcomes_reduced = outcomes_reduced[['Pat ID','variable','value']]

# --- Manually curated distances ----------------------------------------------------------
manual_curation = pd.read_csv('Input/curation_reduced.csv')
manual_curation['Distance Name'] = manual_curation['Taxon Name']+':'+manual_curation['Source']+'->'+manual_curation['Destination']

# 'Source' / 'Destination' have the form '<patient>/<time>'.
source_parts = manual_curation['Source'].str.split('/',expand=True)
destination_parts = manual_curation['Destination'].str.split('/',expand=True)
manual_curation['Time X'] = source_parts[1].astype(int)
manual_curation['Time Y'] = destination_parts[1].astype(int)
manual_curation['Pat ID'] = source_parts[0]
manual_curation['Any Annotation'] = manual_curation['Clinical Annotation']!='No events tracked'

# Every time point (source or destination) that occurs per patient and taxon, one row each.
times_per_patient_taxon = manual_curation.groupby(['Pat ID','Taxon Name'])
sample_times = (
    times_per_patient_taxon['Time X'].apply(list)+times_per_patient_taxon['Time Y'].apply(list)
).reset_index().explode(0).rename(columns={0:'Time'}).drop_duplicates()

charts=[]

# All paper samples ('<patient>/<time>'), split into patient and time.
missing_samples = pd.DataFrame({'Sample' : list(PAPER_SAMPLES)})
missing_samples[['Patient','Time']] = missing_samples['Sample'].str.split('/',expand=True)

# Shape per clinical event type.
event_shape_scale = alt.Scale(
    domain=['2nd HSCT','Acute GvHD Grade 1-2','Acute GvHD Grade 3-4','Moderate cGvHD','Severe cGvHD','1st Relapse','2nd Relapse','Death'],
    range=['circle','square','square','diamond','diamond','triangle','triangle','cross']
)

for taxon in manual_curation['Taxon Name'].unique():
    taxon_curation = manual_curation[manual_curation['Taxon Name'] == taxon]
    taxon_patients = taxon_curation['Pat ID'].unique()
    taxon_outcomes = outcomes_reduced[outcomes_reduced['Pat ID'].isin(taxon_patients)]

    # Paper samples of these patients that take part in no curated distance of this taxon.
    # A sample counts as present when it is the source OR the destination of a distance
    # (checking the source only reported every destination-only sample as missing).
    present_samples = pd.concat([taxon_curation['Source'],taxon_curation['Destination']]).unique()
    taxon_missing = missing_samples[
        (missing_samples['Patient'].isin(taxon_patients))&
        (~missing_samples['Sample'].isin(present_samples))
    ]

    if len(taxon_missing) != 0:
        print(taxon,taxon_missing)

    distance_lines = alt.Chart(taxon_curation).mark_line().encode(
        x=alt.X('Time X',title=None,scale=alt.Scale(type='symlog')),
        x2=alt.X2('Time Y',title=None),
        y=alt.Y('Pat ID',title='Patient'),
        color=alt.Color('Manual Curation',scale=alt.Scale(
            domain=['Shift','Stable'],range=['Orange','Grey']
        ))
    )
    event_points = alt.Chart(taxon_outcomes).mark_point(color='black',size=56).encode(
        x=alt.X('value',title='Day'),
        y=alt.Y('Pat ID',title='Patient'),
        shape=alt.Shape('variable',scale=event_shape_scale)
    )
    # Vertical rule at day 0.
    zero_rule = alt.Chart(pd.DataFrame([(0)],columns=['x'])).mark_rule().encode(x='x')
    sampled_points = alt.Chart(
        sample_times[sample_times['Taxon Name'] == taxon]
    ).mark_point(color='black',filled=True).encode(
        y='Pat ID',
        x=alt.X('Time',title=None)
    )
    missing_points = alt.Chart(taxon_missing).mark_point(color='grey',filled=True).encode(
        y='Patient:N',
        x=alt.X('Time:Q',title=None)
    )

    taxon_chart = (
        distance_lines+event_points+zero_rule+sampled_points+missing_points
    ).resolve_scale(x='shared')
    charts.append(taxon_chart.properties(title=taxon,width=800))

reduce(lambda x,y : x&y, charts).resolve_scale(x='shared')

In [None]:
# Size of the curated data set: curated distances, distinct taxa, distinct patients.
print(
    manual_curation.shape[0],
    manual_curation['Taxon Name'].unique().size,
    manual_curation['Pat ID'].unique().size,
)

In [None]:
# Number of days by which the sampling interval is widened on both sides when looking for a relapse.
PROXIMITY = 0

def decide_proximity(row):
    """Tell whether a curated distance lies close to a relapse or spans the transplantation.

    Parameters
    ----------
    row : mapping
        Must provide 'Pat ID', 'Time X', 'Time Y' and 'Across TX'.

    Returns
    -------
    bool
        True when a '1st Relapse' or '2nd Relapse' event of the patient (looked up in the
        global ``outcomes_reduced``) falls into [Time X - PROXIMITY, Time Y + PROXIMITY],
        or when 'Across TX' is set. 'Across TX' is accepted as the strings 'True' / 'Yes'
        as well as a real boolean True, because pandas parses a True/False CSV column as
        booleans, which never equal the string 'True'.
    """
    relapses = outcomes_reduced[
        (outcomes_reduced['Pat ID'] == row['Pat ID'])&
        (outcomes_reduced['variable'].isin(['1st Relapse','2nd Relapse']))
    ]
    window_start = row['Time X']-PROXIMITY
    window_end = row['Time Y']+PROXIMITY
    if ((relapses['value'] >= window_start)&(relapses['value'] <= window_end)).any():
        return True

    return row['Across TX'] in (True,'True','Yes')

# Flag every curated distance, then count the distances per flag and curation result.
manual_curation['Proximity to Relapse/aHSCT'] = manual_curation.apply(decide_proximity,axis='columns')

(
    manual_curation
    .groupby(['Proximity to Relapse/aHSCT','Manual Curation'])['Pat ID']
    .count()
    .rename('Counts')
)

In [None]:
# Export the annotated curation table (including the proximity flag) to the working directory.
# NOTE(review): other outputs of this notebook go to 'Output/' - confirm the target path.
manual_curation.to_csv('manual_cur_2.csv')

In [None]:
from scipy.stats import chi2_contingency, fisher_exact

# Contingency table: proximity flag (rows) against curation result (columns).
ct = pd.crosstab(manual_curation['Proximity to Relapse/aHSCT'],manual_curation['Manual Curation'])

# Chi-square test of independence: print the p-value only.
chi2_statistic, chi2_p, dof, expected = chi2_contingency(ct)
print(chi2_p)

# Fisher's exact test on the same table: print odds ratio and p-value.
oddsratio, fisher_p = fisher_exact(ct)
print(oddsratio,fisher_p)

# Validation Checker

In [None]:
# Read-weighted validation rate of one sample for one taxonomic subtree.
SAMPLE = 'G11_-100'
TAXON = PLANTAE
EXCLUDE = []
# A taxon counts as validated when its validation rate reaches this value.
VALIDATION_RATE_CUTOFF = 0.2

# Rows of the sample at or below TAXON, minus everything at or below an excluded taxon.
subtable = validation_data[validation_data['Sample'] == SAMPLE]
subtable = subtable[subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,TAXON))]
for ftaxon in EXCLUDE:
    subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]

# Independent copy so that adding a column does not write into a slice of validation_data.
subtable = subtable.copy()
subtable['Validated'] = (subtable['Validation Rate'].astype(float) >= VALIDATION_RATE_CUTOFF)

checked_reads = subtable['Reads'].sum()
if checked_reads == 0:
    # Nothing to weigh: avoid a 0/0 division.
    weighted_validation_rate = float('nan')
    print('No reads found for sample {} in the selected taxa'.format(SAMPLE))
else:
    # Share of the reads that belong to validated taxa.
    weighted_validation_rate= (
        subtable['Validated']*subtable['Reads']
    ).sum()/checked_reads
    # The message reports the cutoff that is really applied (it used to claim 80%).
    print('Weighted Validation Rate: {} (taxa with validation rate >= {:.0%} count as validated)'.format(
        weighted_validation_rate,VALIDATION_RATE_CUTOFF
    ))

# Validation Rate Plots

In [None]:
# Per taxonomic group: box plot of the validation rate of the 30 taxa with the most reads,
# followed by one scatter plot (fraction estimate vs. validation rate) per taxon.
# One HTML file per group is written to Output/ValidationRatePlots.
os.makedirs('Output/ValidationRatePlots',exist_ok=True)

# (group taxon, taxa whose subtrees are excluded from the group).
# PLANTAE used to be listed twice, which rendered and saved the same file twice.
GROUPS = [
    (VIRUSES,[]),
    (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    (FUNGI,[]),
    (BACTERIA,[]),
    (ARCHAEA,[]),
    (CHORDATA,[]),
    (PLANTAE,[]),
]

for taxon,excludes in GROUPS:
    subtable = validation_data[validation_data['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))]
    for ftaxon in excludes:
        subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]

    # The 30 taxon names with the highest read sum (ascending order).
    topxtaxa = subtable.groupby('Taxon Name')['Reads'].sum().sort_values()[-30:].keys()
    # .copy(): the new columns go into an independent frame, not a slice of validation_data.
    subtable = subtable[subtable['Taxon Name'].isin(topxtaxa)].copy()
    subtable['Validated'] = (subtable['Validation Rate'].astype(float) >= 0.2)
    subtable['Validated Reads'] = subtable['Validated'] * subtable['Reads']

    # Taxa ordered by descending read sum for the x axis of the box plot.
    taxa_by_reads_descending = subtable.groupby('Taxon Name')['Reads'].sum().sort_values().keys().tolist()[::-1]

    charts = [
        alt.Chart(
            subtable
        ).mark_boxplot().encode(
            x=alt.X('Taxon Name',sort=taxa_by_reads_descending),
            y=alt.Y('Validation Rate')
        )
    ]
    for genus in topxtaxa[::-1]:
        charts.append(
            alt.Chart(
                subtable[subtable['Taxon Name']==genus],title=genus
            ).mark_point().encode(
                x=alt.X('Fraction Estimate',scale=alt.Scale(type='log')),
                y=alt.Y('Validation Rate',scale=alt.Scale(domain=[0,1])),
                tooltip=['Sample','Validation Rate']
            )
        )

    # Title doubles as file name: '<group>' or '<group> without <excluded,...>'.
    combined_title = idtonames[taxon]
    if excludes != []:
        combined_title += ' without ' + ','.join(idtonames[x] for x in excludes)

    reduce(lambda x,y : x&y, charts).properties(title=combined_title).save('Output/ValidationRatePlots/{}.html'.format(combined_title))

In [None]:
# Median (over the paper samples) of the validated read fraction per taxonomic group, once
# weighting reads by the validation rate ("Continuous") and once counting all reads of taxa
# with a validation rate >= 0.2 ("Binary").
os.makedirs('Output/ValidationRatePlots',exist_ok=True)

# (group taxon, taxa whose subtrees are excluded). PLANTAE was listed twice, which only
# repeated its output line; it is kept once, in last position, so that the table shown at the
# end of the cell is still the Plantae table.
GROUPS = [
    (VIRUSES,[]),
    (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    (FUNGI,[]),
    (BACTERIA,[]),
    (ARCHAEA,[]),
    (CHORDATA,[]),
    (PLANTAE,[])
]

PAPER_SAMPLES_UNDERSCORE = [x.replace('/','_') for x in PAPER_SAMPLES]


def _paper_sample_validation_rows(taxon,excludes):
    """Validation rows of the paper samples at or below `taxon`, without the `excludes` subtrees."""
    rows = validation_data[validation_data['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))]
    #reduce to paper samples only
    rows = rows[rows['Sample'].isin(PAPER_SAMPLES_UNDERSCORE)]
    for ftaxon in excludes:
        rows = rows[~rows['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]
    # Independent copy: callers add a column.
    return rows.copy()


def _median_validated_fraction(rows):
    """Median over samples of sum('Validated Reads') / sum('Reads')."""
    per_sample = rows.groupby('Sample')[['Validated Reads','Reads']].sum()
    return (per_sample['Validated Reads']/per_sample['Reads']).median()


print('Continuous')
for taxon,excludes in GROUPS:
    subtable = _paper_sample_validation_rows(taxon,excludes)
    subtable['Validated Reads'] = subtable['Validation Rate'] * subtable['Reads']
    print(idtonames[taxon],_median_validated_fraction(subtable))

print('Binary')
for taxon,excludes in GROUPS:
    subtable = _paper_sample_validation_rows(taxon,excludes)
    subtable['Validated Reads'] = (subtable['Validation Rate']>=0.2) * subtable['Reads']
    print(idtonames[taxon],_median_validated_fraction(subtable))

# Show the table of the last group for inspection.
subtable

## Non-Fungal Eukaryota Info

In [None]:
# Genus-level read counts of the paper samples (not normalized) with each taxon's share of
# the reads of its sample.
input_table = get_normalized_abundances(
    kraken_dataframe,
    level='G',
    samples=PAPER_SAMPLES,
    included_taxa_filter=None,
    excluded_taxa_filter=None,
    normalize=False
)
input_table['Read Fraction'] = (
    input_table['readcount']
    / input_table.groupby(['time','patientid'])['readcount'].transform('sum')
)

# Eukaryota without Chordata, Fungi and Plantae, paper samples only.
taxon = EUKARYOTA
excludes = [CHORDATA,FUNGI,PLANTAE]
in_group = validation_data['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))
in_paper_samples = validation_data['Sample'].isin(PAPER_SAMPLES_UNDERSCORE)
subtable = validation_data[in_group & in_paper_samples]
for ftaxon in excludes:
    subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]

# Keep only taxa that pass the 0.2 validation cutoff and carry at least one read.
subtable = subtable.assign(**{
    'Validated Reads' : (subtable['Validation Rate']>=0.2) * subtable['Reads']
})
subtable = subtable[subtable['Validated Reads'] > 0]

# Attach the genus-level abundance information of the same patient, time and taxon.
subtable = subtable[['patientid','time','Validation Rate','Taxon ID']].merge(
    input_table,
    left_on=['patientid','time','Taxon ID'],
    right_on=['patientid','time','taxonid'],
    how='left'
)

In [None]:
# Samples in which a validated non-fungal eukaryote has at least 20 reads.
# NOTE(review): the name 'sub20' suggests "below 20", but the filter keeps readcount >= 20 - confirm.
subtable_sub20 = subtable[subtable['readcount'] >= 20]
subtable_sub20['sample'].unique()
# Alternative view: the taxa instead of the samples.
#subtable_sub20['taxon'].unique()

In [None]:

# Number of distinct samples and the distinct taxa in the validated non-fungal eukaryota table.
print(subtable['sample'].unique().size)
print(subtable['taxon'].unique())

# Validation Plots Per Domain

In [None]:
# Box plots of the validated read fraction per domain and time phase, once with the binary
# (validation rate >= 0.2) and once with the continuous (rate-weighted) definition.
# Output: Output/ValidatedReadsPerDomainOverview.html

# Domain label -> (group taxon, taxa whose subtrees are excluded).
GROUPS = {
    'Bacteria':(BACTERIA,[]),
    'Fungi':(FUNGI,[]),
    'OtherEukaryota':(EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    'Viruses':(VIRUSES,[]),
    'Archaea':(ARCHAEA,[]),
    'Homo':(HOMO,[]),
    'Plants':(PLANTAE,[]),
}

# Domain labels in plotting order.
Gruppen = list(GROUPS.keys())

# (Two unused `x=...,` / `y=...,` assignments were removed here: the trailing commas turned
# them into one-element tuples that no chart referenced.)

tables = []
sample_statistics['samplename'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

for groupname,(taxon,excludes) in GROUPS.items():
    # Species-level validation rows at or below the group taxon, minus the excluded subtrees.
    subtable = validation_data['S'][validation_data['S']['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))]
    for ftaxon in excludes:
        subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]

    # Independent copy so the new columns are not written into a slice of validation_data['S'].
    subtable = subtable.copy()
    subtable['Validated'] = (subtable['Validation Rate'].astype(float) >= 0.2)
    subtable['Validated Reads Binary'] = subtable['Validated'] * subtable['Reads']
    subtable['Validated Reads Continuous'] = subtable['Validation Rate'] * subtable['Reads']

    # Sum per sample, then turn the sums into fractions of the sample's reads in this domain.
    subtable = subtable.groupby('samplename',as_index=False)[['Validated Reads Binary','Validated Reads Continuous','Reads']].sum()

    subtable['Validated Read Fraction Binary'] = subtable['Validated Reads Binary'] / subtable['Reads']
    subtable['Validated Read Fraction Continuous'] = subtable['Validated Reads Continuous'] / subtable['Reads']

    subtable['Domain'] = groupname

    tables.append(subtable)

combined = pd.merge(
    pd.concat(tables),
    sample_statistics[['samplename','timephase']],
    on=['samplename'],
    how='left'
)

#filter null group
combined = combined[combined['timephase'].notna()]


def _validated_fraction_boxplot(data,title,fraction_column):
    """Box plot of `fraction_column` per domain, one facet column per time phase."""
    return alt.Chart(data,height=400,title=title).mark_boxplot(ticks=True,median={'color':'black'}).encode(
        x=alt.X('Domain',sort=Gruppen),
        column=alt.Column(
            'timephase:O',
            spacing=0,
            sort=['Healthy', 'Pre TX','Leukozytopenia','Reconstitution'],
            title=None,
        ),
        y=alt.Y(fraction_column,scale=alt.Scale(type='linear',domain=[0,1.1])),
        color=alt.Color(
            "Domain:N",
            sort=Gruppen,
            legend=alt.Legend(title=None,orient='top'),
            scale=alt.Scale(domain=Gruppen,range=[ '#EE6677','#fecc5c','#228833', '#66CCEE', '#AA3377','lightgrey', 'darkgrey'])
        )
    )


(
    _validated_fraction_boxplot(combined,'Binary','Validated Read Fraction Binary')
    &_validated_fraction_boxplot(combined,'Continuous','Validated Read Fraction Continuous')
).save('Output/ValidatedReadsPerDomainOverview.html')

In [None]:
# Inspect the per-sample validated read fractions of the 'Viruses' domain.
combined[combined['Domain']=='Viruses']

In [None]:
# Median over samples of the continuous validated read fraction, per domain.
combined.groupby('Domain')['Validated Read Fraction Continuous'].median()

# Top Validated Taxa Per Domain Stratified By PCoA Groups

In [None]:
# Heat maps of the TOP_X best validated taxa per domain, stratified by
# 'leukocytephase_cluster_kurz'; one PNG per domain in top_validated_taxa/.
TOP_X = 20

# (group taxon, taxa whose subtrees are excluded); '1' is used as group taxon for "everything".
# PLANTAE was listed twice, which duplicated every row of that domain in the combined table.
GROUPS = [
    (VIRUSES,[]),
    (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    (FUNGI,[]),
    (BACTERIA,[]),
    (ARCHAEA,[]),
    (CHORDATA,[]),
    (PLANTAE,[]),
    ('1',[CHORDATA,PLANTAE])
]


def _top_taxa_heatmap(table,fraction_column,title,top_x):
    """Heat map (taxon x cluster) of the mean of `fraction_column` for the `top_x` taxa with
    the largest sum of that column."""
    top_taxa = list(
        table.groupby('Taxon Name')[fraction_column].sum().sort_values(ascending=False)[:top_x].keys()
    )
    return alt.Chart(table[table['Taxon Name'].isin(top_taxa)],title=title).mark_rect().encode(
        x='Taxon Name',
        y='leukocytephase_cluster_kurz',
        color='mean({})'.format(fraction_column)
    )


os.makedirs('top_validated_taxa',exist_ok=True)

tables = []

for taxon,excludes in GROUPS:
    # Select the taxa at or below the group taxon - the same selection every other per-group
    # cell uses. (This cell used is_not_below, i.e. the complement of the group.)
    subtable = validation_data[validation_data['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))]
    for ftaxon in excludes:
        subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]

    # Independent copy so the new columns are not written into a slice of validation_data.
    subtable = subtable.copy()
    subtable['Validated'] = (subtable['Validation Rate'].astype(float) >= 0.2)
    subtable['Validated Reads Binary'] = subtable['Validated'] * subtable['Reads']
    subtable['Validated Reads Continuous'] = subtable['Validation Rate'] * subtable['Reads']

    subtable = subtable.groupby(['Sample','Taxon ID'],as_index=False)[['Validated Reads Binary','Validated Reads Continuous','Reads']].sum()

    subtable['Validated Read Fraction Binary'] = subtable['Validated Reads Binary'] / subtable['Reads']
    subtable['Validated Read Fraction Continuous'] = subtable['Validated Reads Continuous'] / subtable['Reads']

    # Domain label: '<group>' or '<group> without <excluded,...>'; also used in the file name.
    combined_title = idtonames[taxon]
    if excludes != []:
        combined_title += ' without ' + ','.join(idtonames[x] for x in excludes)

    subtable['Domain'] = combined_title

    tables.append(subtable)

combined = pd.concat(tables)

# Attach the cluster label of every sample.
sample_statistics['Sample'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)
combined = pd.merge(
    combined,
    sample_statistics[['Sample','leukocytephase_cluster_kurz']],
    on=['Sample'],
    how='left'
)

for domain in combined['Domain'].unique():
    subtable = combined[combined['Domain'] == domain].copy()
    subtable['Taxon Name'] = subtable['Taxon ID'].map(idtonames)

    # Exactly TOP_X taxa per panel (the slice used to keep TOP_X+1).
    c1 = _top_taxa_heatmap(subtable,'Validated Read Fraction Binary','Binary',TOP_X)
    c2 = _top_taxa_heatmap(subtable,'Validated Read Fraction Continuous','Continuous',TOP_X)

    (c1|c2).resolve_scale(color='independent').save('top_validated_taxa/{}_top_{}.png'.format(domain,TOP_X))


# Table Top Genera

In [None]:
GROUPS = [
    (BACTERIA,[]),
    (FUNGI,[]),
    (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
    (VIRUSES,[]),
    (ARCHAEA,[]),
    (PLANTAE,[]),
    (CHORDATA,[]),
    (PLANTAE,[]),
    ('1',[CHORDATA,FUNGI,PLANTAE])
]

PAPER_SAMPLES_UNDERSCORE = [x.replace('/','_') for x in PAPER_SAMPLES]

os.makedirs('top_validated_taxa',exist_ok=True)
sample_statistics['Sample'] = sample_statistics['PatID']+'_'+sample_statistics['time'].astype(str)

tables = []

total_reads = validation_data.groupby('Sample')['Reads'].sum()
all_charts = []


for group in GROUPS:
    taxon,excludes = group
    subtable = validation_data[validation_data['Taxon ID'].apply(lambda x : is_below_or_equal(x,taxon))]
    for ftaxon in excludes:
        subtable = subtable[~subtable['Taxon ID'].apply(lambda x : is_below_or_equal(x,ftaxon))]
    subtable['Validated'] = (subtable['Validation Rate'].astype(float) >= 0.2)
    subtable['Validated Reads Binary'] = subtable['Validated'] * subtable['Reads']
    #subtable['Validated Reads Continuous'] = subtable['Validation Rate'] * subtable['Reads']

    subtable = subtable.groupby(['Sample','Taxon ID'],as_index=False)[['Validated Reads Binary',
                                                                       #'Validated Reads Continuous',
                                                                       'Reads','Validated']].sum()

    local_reads = subtable.groupby('Sample')['Reads'].sum()
    
    subtable['Validated Local Read Fraction Binary'] = subtable.apply(
        lambda x:x['Validated Reads Binary']/local_reads[x['Sample']],axis=1
    )
    #subtable['Validated Local Read Fraction Continuous'] = subtable['Validated Reads Continuous'] / subtable['Reads']
    
    subtable['Validated Global Read Fraction Binary'] = subtable.apply(
        lambda x:x['Validated Reads Binary']/total_reads[x['Sample']],axis=1
    )
    #subtable['Validated Global Read Fraction Continuous'] = subtable.apply(
    #    lambda x:x['Validated Reads Continuous']/total_reads[x['Sample']],axis=1
    #)        
    
    
    combined_title = idtonames[taxon]
    if excludes != []:
        combined_title += ' without ' + ','.join(idtonames[x] for x in excludes)
    
    subtable['Validated Above 5% Binary Local'] = subtable['Validated Local Read Fraction Binary'] > 0.05
    subtable['Validated Above 5% Binary Global'] = subtable['Validated Global Read Fraction Binary'] > 0.05
    #subtable['Validated Above 5% Continuous Local'] = subtable['Validated Local Read Fraction Continuous'] > 0.05
    #subtable['Validated Above 5% Continuous Global'] = subtable['Validated Global Read Fraction Continuous'] > 0.05
    subtable = pd.merge(
        subtable,
        sample_statistics[['Sample','leukocytephase_cluster_2_kurz','Startcluster']],
        on=['Sample'],
        how='left'
    )
    subtable['Startcluster'] = subtable['Startcluster'].replace({'Gesund':'Control'})
    subtable = subtable.rename(columns={'Validated' : 'Detected'})
    
    totals =subtable.groupby(['Taxon ID'])[[
    'Detected',
    ]].sum().rename(columns={'Detected' : 'Absolute Detected'})

    short =(subtable.groupby(['Taxon ID'])[[
    'Detected',
    'Validated Above 5% Binary Local',
    'Validated Above 5% Binary Global',
    #'Validated Above 5% Continuous Local',
    #'Validated Above 5% Continuous Global',
    ]].sum()/len(validation_data['Sample'].unique())).add_prefix('Total ').fillna(0)

    short = pd.concat([totals,short],axis=1)

    long = subtable.groupby(['leukocytephase_cluster_2_kurz','Taxon ID'])[[
    'Detected',
    'Validated Above 5% Binary Local',
    'Validated Above 5% Binary Global',
    #'Validated Above 5% Continuous Local',
    #'Validated Above 5% Continuous Global'
    ]].sum().reset_index()

    samples_per_group = sample_statistics.groupby(['leukocytephase_cluster_2_kurz'])[
    'Sample'
    ].count().reset_index()

    long = pd.merge(long,samples_per_group,how='left',on='leukocytephase_cluster_2_kurz')

    long['Detected'] = long['Detected'] / long['Sample']
    long['Validated Above 5% Binary Local'] = long['Validated Above 5% Binary Local'] / long['Sample']
    long['Validated Above 5% Binary Global'] = long['Validated Above 5% Binary Global'] / long['Sample']
    #long['Validated Above 5% Continuous Local'] = long['Validated Above 5% Continuous Local'] / long['Sample']
    #long['Validated Above 5% Continuous Global'] = long['Validated Above 5% Continuous Global'] / long['Sample']
    long = long.drop(columns='Sample')

    long = long.pivot(index='Taxon ID',columns='leukocytephase_cluster_2_kurz').fillna(0).swaplevel(0,1,1)

    
    long_2 = subtable.groupby(['Startcluster','Taxon ID'])[[
    'Detected',
    'Validated Above 5% Binary Local',
    'Validated Above 5% Binary Global',
    #'Validated Above 5% Continuous Local',
    #'Validated Above 5% Continuous Global'
    ]].sum().reset_index()

    samples_per_group = sample_statistics.groupby(['Startcluster'])[
    'Sample'
    ].count().reset_index()

    long_2 = pd.merge(long_2,samples_per_group,how='left',on='Startcluster')

    long_2['Detected'] = long_2['Detected'] / long_2['Sample']
    long_2['Validated Above 5% Binary Local'] = long_2['Validated Above 5% Binary Local'] / long_2['Sample']
    long_2['Validated Above 5% Binary Global'] = long_2['Validated Above 5% Binary Global'] / long_2['Sample']
    #long['Validated Above 5% Continuous Local'] = long['Validated Above 5% Continuous Local'] / long['Sample']
    #long['Validated Above 5% Continuous Global'] = long['Validated Above 5% Continuous Global'] / long['Sample']
    long_2 = long_2.drop(columns='Sample')

    long_2 = long_2.pivot(index='Taxon ID',columns='Startcluster').fillna(0).swaplevel(0,1,1)

    long=pd.merge(long,long_2,left_index=True,right_index=True,how='outer')
        
    combined = pd.merge(
        short,long,left_index=True,right_index=True,how='outer'
    ).sort_values(
        by='Total Detected',ascending=False
    )
    combined = combined[[                    'Absolute Detected',
                                                 'Total Detected',
                          'Total Validated Above 5% Binary Local',
                         'Total Validated Above 5% Binary Global',
                                          ('Healthy', 'Detected'),
                                        ('Healthy', 'Validated Above 5% Binary Local'),
                  ('Healthy', 'Validated Above 5% Binary Global'),
                                           ('Pre TX', 'Detected'),
                    ('Pre TX', 'Validated Above 5% Binary Local'),
                   ('Pre TX', 'Validated Above 5% Binary Global'),

                                   ('Leukozytopenia', 'Detected'),
                                 ('Leukozytopenia', 'Validated Above 5% Binary Local'),
           ('Leukozytopenia', 'Validated Above 5% Binary Global'),

                                   ('Reconstitution', 'Detected'),
            ('Reconstitution', 'Validated Above 5% Binary Local'),
           ('Reconstitution', 'Validated Above 5% Binary Global'),
                                                  (1, 'Detected'),

                           (1, 'Validated Above 5% Binary Local'),
                                               (1, 'Validated Above 5% Binary Global'),

                                                                       (2, 'Detected'),
                                                (2, 'Validated Above 5% Binary Local'),
                          (2, 'Validated Above 5% Binary Global'),

                                                  (3, 'Detected'),
                           (3, 'Validated Above 5% Binary Local'),
                          (3, 'Validated Above 5% Binary Global')]
    ]
    for column in combined.columns:
        if column != 'Absolute Detected':
            combined[column] = combined[column].astype(float).map("{:.2%}".format)
    combined = combined[combined['Absolute Detected'] > 0]
            
    excel = combined[:]
    excel.index = excel.index.map(idtonames)
    excel.to_excel('top_validated_taxa/{}.xlsx'.format(combined_title))
    
    #Chart
    
    marker_data = validation_data
    marker_data = marker_data[marker_data['Sample'].isin(PAPER_SAMPLES_UNDERSCORE)]
    marker_data['Validated'] = marker_data['Validation Rate'] >= 0.2

    os.makedirs('marker_genera_overview',exist_ok=True)


    def SORTFUNCTION(x):
        return (
            x[0] == 'G', #Sort by G or regular patient first
            float(x.split('G')[-1])
        )

    patientlist_sorted = sorted(
        marker_data['patientid'].unique(),key=SORTFUNCTION
    )


    tables = []

    for taxon in combined.index[:10]:
        
        subtable = marker_data[
            marker_data['Taxon ID'] == taxon
        ]


        substitutes = []
        for sample in PAPER_SAMPLES_UNDERSCORE:
            if sample not in subtable['Sample'].unique():
                patientid,time = sample.rsplit('_',1)
                substitutes.append(
                    (patientid,int(time),'?')
                )
        subtable = pd.concat([subtable,pd.DataFrame(substitutes,columns=['patientid','time','Validated'])]) 


        occurences = subtable[subtable['Validated']==True].groupby('patientid')['Validated'].count()
        interesting_patients = occurences[occurences >= 1].keys()
        subtable = subtable[subtable['patientid'].isin(interesting_patients)]

        subtable['Validated'] = subtable['Validated'].map(
            {
                '?' : 'Not Validated',
                False : 'Not Validated',
                True : 'Validated'
            }
        )

        subtable['Genus'] = idtonames[taxon] if taxon in idtonames else str(taxon) 

        #Filter patients without any presence

        tables.append(subtable)
    charttable = pd.concat(tables)
    charts=[]
    for genus in list(combined.index[:10]):
        genustable = charttable[charttable['Taxon ID'] == genus]
        if len(genustable) == 0:
            continue
        chart = (alt.Chart(genustable[~genustable['patientid'].str.startswith('G')],title=idtonames[genus]).mark_point().encode(
            x=alt.X('time',scale=alt.Scale(type='symlog'),axis=alt.Axis(grid=False),title='Day'),
                color=alt.Color(
                    'Validated',scale=alt.Scale(domain=['Not Validated','Validated'],range=['lightgrey','orange'])),
            y=alt.Y('patientid:N',sort=patientlist_sorted,axis=alt.Axis(grid=True),title=None)
            )|

            alt.Chart(genustable[genustable['patientid'].str.startswith('G')]).mark_point().encode(            color=alt.Color(
                    'Validated',scale=alt.Scale(domain=['Not Validated','Validated'],range=['lightgrey','orange'])),
            y=alt.Y('patientid:N',sort=patientlist_sorted,title=None)
            )).resolve_scale(y='independent')   
        
        charts.append(chart)
        if group in [
            (BACTERIA,[]),
            (FUNGI,[]),
            (EUKARYOTA,[CHORDATA,FUNGI,PLANTAE]),
            (VIRUSES,[]),
            (ARCHAEA,[]),
            (PLANTAE,[]),
            (CHORDATA,[]),
            (PLANTAE,[])
        ]:
            all_charts.append((chart,len(genustable['patientid'].unique())))
    reduce(lambda x,y : x&y,charts).save('top_validated_taxa/{}.svg'.format(combined_title))
    reduce(lambda x,y : x&y,charts).save('top_validated_taxa/{}.png'.format(combined_title))


In [None]:
# Sanity check: number of collected charts and the total number of patient rows they cover
print(len(all_charts), sum(patient_rows for _chart, patient_rows in all_charts))

In [None]:
# Height budget of one column, counted in patient rows plus per-chart spacing
PATIENT_MAX = 200
# Fixed amount charged for every chart on top of its patient rows
FIXED_SPACE_PER_PLOT = 10

cur = 0
idx = 0

column_charts = []
column_chart = []

# Greedily stack charts vertically until the current column has used up its budget,
# then close that column and start a new one.
for idx, (chart, patients) in enumerate(all_charts, start=1):
    column_chart.append(chart)
    cur += patients + FIXED_SPACE_PER_PLOT
    if cur >= PATIENT_MAX:
        cur = 0
        column_charts.append(
            reduce(lambda x, y: x & y, column_chart)
        )
        column_chart = []

# The final column is normally only partially filled; flush it as well so that the
# charts collected after the last full column are not silently dropped.
if column_chart:
    column_charts.append(
        reduce(lambda x, y: x & y, column_chart)
    )

# reduce() raises on an empty sequence, so only combine when there is something to show
reduce(lambda x, y: x | y, column_charts) if column_charts else None

# Correlation of Validation Rate and Abundance

In [None]:
# Genus-level raw read counts for all paper samples (no taxon filters, no normalisation)
input_table = get_normalized_abundances(
    kraken_dataframe,
    samples=PAPER_SAMPLES,
    level='G',
    normalize=False,
    included_taxa_filter=None,
    excluded_taxa_filter=None,
)

# Share of every genus within its sample.
# NOTE: despite the "(%)" in the column name this is a 0-1 fraction, not a percentage.
input_table['Read Fraction (%)'] = input_table['readcount'].div(
    input_table.groupby(['time', 'patientid'])['readcount'].transform('sum')
)

# Underscore-separated sample key used for the join with validation_data below
input_table['Sample'] = input_table['patientid'] + '_' + input_table['time'].astype(str)


DISCARD_CUTOFF = 20

# Attach the validation rate per (sample, taxon); every row that still contains a
# missing value after the left join is discarded.
with_validation = input_table.merge(
    validation_data[['Sample', 'Taxon ID', 'Validation Rate']],
    how='left',
    left_on=['Sample', 'taxonid'],
    right_on=['Sample', 'Taxon ID'],
).dropna()

In [None]:
# Bin the read fraction into 5%-wide intervals (edges 0, 0.05, ..., 0.95, 1) and keep
# the interval label as a plain string
with_validation['Read Fraction Bin'] = pd.cut(
    with_validation['Read Fraction (%)'],
    [step * 0.05 for step in range(20)] + [1],
).astype(str)

In [None]:
# Allow Altair to embed more rows than its default limit
alt.data_transformers.disable_max_rows()

# Distribution of the validation rate within each read-fraction bin
alt.Chart(
    with_validation.dropna()
).mark_boxplot(ticks=True).encode(
    x='Read Fraction Bin:O',
    y='Validation Rate',
).properties(width=500)

In [None]:
# Underscore-separated sample key, matching with_validation['Sample']
sample_statistics['Sample'] = sample_statistics['PatID'] + '_' + sample_statistics['time'].astype(str)

# Candida rows annotated with the leukocyte phase of their sample
t3 = with_validation[with_validation['taxon'] == 'Candida'].merge(
    sample_statistics[['Sample', 'leukocytephase_cluster_2_kurz']],
    on='Sample',
    how='left',
)

t3['Validated'] = t3['Validation Rate'].ge(0.2)

# Number of validated Candida detections per phase
t3.groupby('leukocytephase_cluster_2_kurz')['Validated'].sum()

In [None]:
# Number of Candida rows per phase (validated or not)
t3.groupby('leukocytephase_cluster_2_kurz')['Validated'].count()

# Lifelines Comparison Bacteroides Phocaeicola

In [None]:
# --- Settings for the Lifelines vs. control comparison ---

# Level handed to get_normalized_abundances
LEVEL = 'S'
# Number of top taxa to show (the selection further down takes TOP_X+2 entries)
TOP_X = 30
# Taxa with fewer reads than this are relabelled 'Not enough reads'
DISCARD_CUTOFF = 5

# Taxon filters handed to get_normalized_abundances
INCLUDE = None
EXCLUDE = [PLANTAE]

# Healthy controls followed by the Lifelines samples
SAMPLES = HEALTHY_SAMPLES + LIFELINES

def SORTFUNCTION(x):
    """Sort key for sample names of the form ``<patient>_<time>``.

    Names may carry a leading ``G``; such samples sort after the others because
    the first key component is ``True`` for them.

    Returns
    -------
    tuple
        ``(starts_with_G, patient_number, time)`` for well-formed names.
        Names that cannot be parsed get ``(True, 1, n)`` where ``n`` is an
        integer derived from the name's bytes, so the resulting order is the
        same in every interpreter run.
    """
    try:
        fields = x.split('G')[-1].split('_')
        return (
            x[0] == 'G', #Sort by G or regular patient first
            float(fields[0]), #Then Patient ID
            int(fields[1]), #Then Time
        )
    except (AttributeError, IndexError, TypeError, ValueError):
        # Only parsing problems are handled here; a bare ``except`` would also
        # swallow KeyboardInterrupt. ``hash()`` of a str is salted per process,
        # so a byte-derived integer is used to keep the ordering reproducible.
        return (True, 1, int.from_bytes(str(x).encode('utf-8'), 'big'))
    

#################

# Make sure the output folder of this section exists
os.makedirs('Output/Lifelines', exist_ok=True)

# Unnormalised abundances for the selected samples; the new sample format uses
# an underscore '_' where the old one used '/'
reset_frame = get_normalized_abundances(
    kraken_dataframe,
    samples=[sample_name.replace('/', '_') for sample_name in SAMPLES],
    level=LEVEL,
    normalize=False,
    included_taxa_filter=INCLUDE,
    excluded_taxa_filter=EXCLUDE,
)


In [None]:
# Prepare the per-sample composition table for the Lifelines comparison.
# Work on a copy so that reset_frame stays untouched and this cell can be re-run.
input_table = reset_frame.copy()

# Phase 1: Kick out low abundance groups, assign to "Not enough reads"
# (both the taxon name and the taxon id are replaced so the rows collapse in the groupby below)
input_table.loc[input_table['readcount']<DISCARD_CUTOFF,'taxon'] = 'Not enough reads'
input_table.loc[input_table['readcount']<DISCARD_CUTOFF,'taxonid'] = NOT_ENOUGH_READS

# Readjust sum (group multiple "Not enough reads" entries together)
input_table=  input_table.groupby(['taxon','taxonid','samplename'],as_index=False).sum()

# Calculate read fractions: reads of a taxon divided by all reads of its sample
input_table['Read Fraction'] = input_table['readcount']/input_table.groupby('samplename')['readcount'].transform('sum')

# Determine the top taxa by their read fraction summed over all samples.
# TOP_X+2 entries are taken - presumably to leave room for the two pseudo-taxa
# that are force-included below (TODO confirm).
taxa_we_look_at = list(input_table.groupby('taxon')['Read Fraction'].sum().sort_values(ascending=False)[:(TOP_X+2)].keys())
if 'Not enough reads' not in taxa_we_look_at:
    taxa_we_look_at.append('Not enough reads')
if 'Unassigned at Level' not in taxa_we_look_at:
    taxa_we_look_at.append('Unassigned at Level')

print('Determined the following taxa as relevant:',taxa_we_look_at)

# Assign everything else to the "Other" group and readjust sum
# (grouping by taxon name and sample only, so rows of different taxon ids are merged)
input_table.loc[~input_table['taxon'].isin(taxa_we_look_at), 'taxon'] = 'Other'
input_table = input_table.groupby(['taxon','samplename'],as_index=False).sum()


# Flag marking the rows of the catch-all group
input_table['other'] = input_table['taxon'] == 'Other'

# Column names used by the charts below
input_table = input_table.rename(columns={
    'samplename' : 'Sample ID',
    'taxon' : 'Taxon'
})

# Order used for the colour legend: selected taxa first, the catch-all 'Other' last
taxa = [*taxa_we_look_at, 'Other']

# Re-order the palette: every second colour first, then the remaining ones
bright = cc.glasbey_light[::2]
muted = cc.glasbey_light[1::2]
palette = bright + muted

# NOTE: this is an alias, not a copy - removing the two pseudo-taxa below also
# removes them from taxa_we_look_at (taxa above was built before and keeps them).
taxa_we_look_at_assigned = taxa_we_look_at
taxa_we_look_at_assigned.remove('Not enough reads')
taxa_we_look_at_assigned.remove('Unassigned at Level')

# One palette colour per remaining taxon
colorMap = dict(zip(taxa_we_look_at_assigned, palette))

# Parallel domain / range lists for the Altair colour scale
altdomain = list(taxa_we_look_at_assigned)
altrange = [colors.to_hex(colorMap[name]) for name in taxa_we_look_at_assigned]

# Fixed colours for the special groups: white, black and mid-grey
altdomain += ['Other', 'Not enough reads', 'Unassigned at Level']
altrange += [
    colors.to_hex((1, 1, 1)),
    colors.to_hex((0, 0, 0)),
    colors.to_hex((0.5, 0.5, 0.5)),
]

In [None]:
def _composition_bars(data, title, sample_order, y_title):
    """Stacked bar chart of the read fraction per sample, coloured by taxon.

    Parameters
    ----------
    data : pd.DataFrame
        Rows with the columns 'Sample ID', 'Taxon', 'Read Fraction' and 'readcount'.
    title : str
        Chart title.
    sample_order : list
        Sample IDs in the order they should appear on the x axis.
    y_title : str, list of str or None
        Title of the y axis (None hides it).

    Returns
    -------
    alt.Chart
        Uses the module-level ``altdomain`` / ``altrange`` / ``taxa`` for the
        colour scale and for the stacking order of the bar segments.
    """
    return alt.Chart(
        data, title=title
    ).transform_calculate(
        # Stack segments in the reverse order of the colour domain
        order=f"-indexof({altdomain}, datum.Taxon)"
    ).mark_bar(stroke='black', strokeWidth=0.5, strokeOpacity=0.9).encode(
        x=alt.X('Sample ID:N', sort=sample_order, axis=alt.Axis(labels=False, ticks=False), title=None),
        y=alt.Y('Read Fraction:Q', scale=alt.Scale(domain=(0, 1)), title=y_title),
        color=alt.Color('Taxon:N',
                        legend=alt.Legend(columns=2, symbolLimit=0, labelLimit=0),
                        sort=taxa,
                        scale=alt.Scale(domain=altdomain, range=altrange)),
        tooltip=['readcount', 'Read Fraction', 'Taxon', 'Sample ID'],
        order=alt.Order('order:Q')
    )


# Samples whose ID starts with 'L' are treated as Lifelines samples, all others as controls
input_table['Lifelines'] = input_table['Sample ID'].str.startswith('L')

ll_table = input_table[input_table['Lifelines']]

patientlist_sorted = sorted(
    ll_table['Sample ID'].unique().tolist())

lls = _composition_bars(ll_table, 'Lifelines', patientlist_sorted, None)

c_table = input_table[~input_table['Lifelines']]

patientlist_sorted = sorted(
    c_table['Sample ID'].unique().tolist(),
    key=SORTFUNCTION
)

ccs = _composition_bars(c_table, 'Control', patientlist_sorted, ['Estimated', 'Abundance'])

# Combined Bacteroides + Phocaeicola abundance per sample.
# .copy() makes bac an independent frame, so adding a column below does not
# write into a slice of input_table (SettingWithCopyWarning).
bac = input_table[input_table['Taxon'].str.startswith(('Phocaeicola','Bacteroides'))].copy()
bac['Healthy'] = bac['Sample ID'].str.startswith('G').map(
    {True : 'Control',
    False : 'Lifelines'}
)
bac = bac.groupby(['Healthy','Sample ID'])['Read Fraction'].sum().reset_index()
bacc = alt.Chart(bac).mark_boxplot().encode(
    x=alt.X('Healthy',sort=['Control','Lifelines'],title=None),
    y=alt.Y('Read Fraction',title='Estimated Abundance Bacteroides+Phocaeicola')
)

(ccs|lls|bacc).configure_title(orient='bottom')

# Correlation to clinical outcome

In [None]:
# Taxa with fewer reads than this are relabelled 'Not enough reads'
DISCARD_CUTOFF = 20

# Genus-level raw read counts for the paper samples without duplicates, chordate reads excluded
table = get_normalized_abundances(
    kraken_dataframe,
    level='G',
    samples=set(PAPER_SAMPLES) - set(DUPLICATES),
    excluded_taxa_filter=[CHORDATA],
    normalize=False,
)

# Underscore-separated sample key, only needed for the join with validation_data
table['Sample'] = table['patientid'] + '_' + table['time'].astype(str)

table = table.merge(
    validation_data[['Sample', 'Taxon ID', 'Validation Rate']],
    how='left',
    left_on=['Sample', 'taxonid'],
    right_on=['Sample', 'Taxon ID'],
)

# Phase 1: low-abundance taxa are collected under "Not enough reads"
table.loc[table['readcount'].lt(DISCARD_CUTOFF), 'taxon'] = 'Not enough reads'

# Phase 2: of the remaining taxa, everything that does not validate becomes "Not validated"
# NOTE(review): the threshold is strict (> 0.2) here, while other cells use >= 0.2 - confirm which is intended
table['Validated'] = table['Validation Rate'].gt(0.2)
table.loc[
    table['Validated'].ne(True) & table['taxon'].ne('Not enough reads'),
    'taxon'
] = 'Not validated'

# Collapse the relabelled rows into one row per (taxon, sample) and recompute the percentages
table = table.groupby(['taxon', 'time', 'patientid'], as_index=False).sum()
table['Read Fraction (%)'] = table['readcount'].mul(100).div(
    table.groupby(['time', 'patientid'])['readcount'].transform('sum')
)

# From here on the sample key is slash-separated
table['Sample'] = table['patientid'] + '/' + table['time'].astype(str)
table

In [None]:
TAXA_OF_INTEREST=[
    'Bacteroides',
    'Candida',
    'Enterococcus',
    'Saccharomyces',
    'Lactobacillus',
    'Methanosarcina',
    'Pseudomonas',
    'Methanobrevibacter',
    'Blautia'
]

# The derived outcome flags do not depend on the taxon, so they are computed once
# up front instead of on every loop iteration.
outcomes['Any GvHD'] = (
    (outcomes['aGVHD Grade 1-2']==1) |
    (outcomes['aGvHD Grade 3 - 4']==1) |
    (outcomes['moderate cGVHD']==1) |
    (outcomes['severe cGvHD']==1)
)

outcomes['Relapse'] = (
    (outcomes['Replase_1']==1) |
    (outcomes['Replase_2']==1)
)

outer_charts = []

for TAXON_OF_INTEREST in TAXA_OF_INTEREST:

    taxon_table = table[table['taxon'] == TAXON_OF_INTEREST]

    # Samples without the taxon get an explicit zero-abundance row so that they
    # still take part in the comparison. The rows are collected first and appended
    # in one go (concatenating row by row is quadratic in the number of samples).
    samples_with_taxon = set(taxon_table['Sample'])
    dummy_rows = []
    for sample in table['Sample'].unique():
        if sample not in samples_with_taxon:
            patientid, time = sample.split('/')
            dummy_rows.append(
                (0, patientid, int(time), TAXON_OF_INTEREST, '???', '???', sample, 0)
            )
    if dummy_rows:
        taxon_table = pd.concat([
            taxon_table,
            pd.DataFrame(
                dummy_rows,
                columns=['readcount','patientid','time','taxon','taxonid','level','Sample','Read Fraction (%)']
            )
        ])

    # Annotate every sample with its start cluster and leukocyte phase
    taxon_table = pd.merge(
        taxon_table,
        sample_statistics[['PatID','time','Startcluster','leukocytephase_cluster_2_kurz']],
        how='left',
        left_on=['patientid','time'],
        right_on=['PatID','time']
    ).drop(columns=['PatID'])

    # Keep the leukozytopenia samples plus all samples of 'G' patients
    taxon_table = taxon_table[
        (taxon_table['leukocytephase_cluster_2_kurz'] == 'Leukozytopenia')|
        (taxon_table['patientid'].str.startswith('G'))
    ]

    # Attach the clinical outcomes; patients without an outcome record get NaN
    taxon_table = pd.merge(
        taxon_table,
        outcomes[['Pat ID','Outcomes (non (0), adverse event (1))','Any GvHD','Death', 'Relapse']],
        how='left',
        left_on=['patientid'],
        right_on=['Pat ID']
    ).rename(columns={
            'Outcomes (non (0), adverse event (1))' : 'Adverse Event',
        'Any GvHD' : 'GvHD',
        'Death' : 'Death',
        'Relapse' : 'Relapse'
    })

    charts = []

    taxon_table.to_csv('{}_MARKER_GENERA_ANNA.csv'.format(TAXON_OF_INTEREST))

    for category in ['Adverse Event','GvHD','Death', 'Relapse']:

        # Rows with a missing outcome match neither comparison and drop out of the test
        yes_distrib = taxon_table[(taxon_table[category] == True)]['Read Fraction (%)']
        no_distrib = taxon_table[(taxon_table[category] == False)]['Read Fraction (%)']
        if len(yes_distrib) > 0 and len(no_distrib) > 0:
            U1, p = mannwhitneyu(yes_distrib, no_distrib, method="exact")
        else:
            # The test is undefined when one of the groups is empty
            U1, p = float('nan'), float('nan')

        # Boxplots for the non-'G' patients, red rule at the mean of the 'G' patients
        chart = alt.Chart(taxon_table[~taxon_table['patientid'].str.startswith('G')],width=80, height=400, title='{} (p-Val: {:.2})'.format(
            category,p
        )).mark_boxplot(ticks=True).encode(
            x=alt.X(category+':N',title=None),
            y=alt.Y('Read Fraction (%)', axis=alt.Axis(format='.2f')),
            tooltip=['Pat ID', 'time', 'Read Fraction (%)'],
        )+alt.Chart(taxon_table[taxon_table['patientid'].str.startswith('G')]).mark_rule(color='red').encode(
            y='mean(Read Fraction (%))'
        )
        charts.append(chart)

    outer_charts.append(
        reduce(lambda x,y : x|y , charts).resolve_scale(y='shared').properties(title=TAXON_OF_INTEREST)
    )
reduce (lambda x,y : x&y, outer_charts).configure_tick(thickness=2)

In [None]:
# Rows whose 'Adverse Event' value is not equal to itself (true for NaN), i.e. rows without an outcome
taxon_table.loc[taxon_table['Adverse Event'].ne(taxon_table['Adverse Event'])]

# Validation Calibration with Zymo Std.

In [None]:
# Validation results of all samples whose name starts with 'Zymo'
validation_data.loc[validation_data['Sample'].str.startswith('Zymo')]

In [None]:
# Work on a copy: a plain assignment would only alias validation_data, and the helper
# columns added below would then silently appear in the shared frame.
stats_validation = validation_data.copy()

# Convert once and use the same float series for both derived columns
validation_rate = stats_validation['Validation Rate'].astype(float)

# Binary: all reads of a taxon count if it validates (rate >= 0.2);
# continuous: reads are weighted by the validation rate itself.
stats_validation['Validated'] = (validation_rate >= 0.2)
stats_validation['Validated Reads Binary'] = stats_validation['Validated'] * stats_validation['Reads']
stats_validation['Validated Reads Continuous'] = validation_rate * stats_validation['Reads']

# Per-sample totals
stats_validation = stats_validation.groupby('Sample',as_index=False)[['Validated Reads Binary','Validated Reads Continuous','Reads']].sum()

stats_validation['Validated Read Fraction Binary'] = stats_validation['Validated Reads Binary'] / stats_validation['Reads']
stats_validation['Validated Read Fraction Continuous'] = stats_validation['Validated Reads Continuous'] / stats_validation['Reads']

# Mean over the paper samples; numeric_only skips the string 'Sample' column,
# which makes .mean() raise on recent pandas versions.
stats_validation[
    stats_validation['Sample'].isin(
        [x.replace('/','_') for x in PAPER_SAMPLES]
    )
                ].mean(numeric_only=True)

# Compare Groups

In [None]:
# --- Settings for the group comparison ---

# Samples and level handed to get_normalized_abundances
SAMPLES = PAPER_SAMPLES
LEVEL = 'G'

# Taxon filters
EXCLUDE = [CHORDATA, PLANTAE]
INCLUDE = None


def SORTFUNCTION(x):
    """Sort key for sample names of the form ``<patient>/<time>``.

    Names starting with 'G' sort after all others; within each group samples
    are ordered by patient number and then by time.
    """
    starts_with_g = x[0] == 'G'
    # Drop an optional leading 'G', then separate patient number and time
    fields = x.split('G')[-1].split('/')
    patient_number = float(fields[0])
    time_point = int(fields[1])
    return (starts_with_g, patient_number, time_point)

# None if you don't want any grouping, otherwise a column to group by
GROUPING = 'leukocytephase_cluster_kurz'
SORTING = 'Sample ID'

NORMALIZE = False

#################

os.makedirs('Output/Composition', exist_ok=True)

input_table = get_normalized_abundances(
    kraken_dataframe,
    samples=SAMPLES,
    level=LEVEL,
    normalize=NORMALIZE,
    included_taxa_filter=INCLUDE,
    excluded_taxa_filter=EXCLUDE,
)

# Percentage of each taxon within its sample
input_table['Read Fraction (%)'] = input_table['readcount'].mul(100).div(
    input_table.groupby(['time', 'patientid'])['readcount'].transform('sum')
)

# Underscore-separated sample key for the join with validation_data
input_table['Sample'] = input_table['patientid'] + '_' + input_table['time'].astype(str)


DISCARD_CUTOFF = 20

with_validation = input_table.merge(
    validation_data[['Sample', 'Taxon ID', 'Validation Rate']],
    how='left',
    left_on=['Sample', 'taxonid'],
    right_on=['Sample', 'Taxon ID'],
)


# Reads that survived the taxon filters, relative to all classified reads of the sample
# NOTE(review): despite the "(%)" in the name this ratio is not multiplied by 100
total_reads = with_validation.groupby('sample', as_index=False)['readcount'].sum()
sample_statistics['sample'] = sample_statistics['PatID'] + '/' + sample_statistics['time'].astype(str)
total_reads = total_reads.merge(
    sample_statistics[['sample', 'Classified Reads']],
    on='sample',
    how='left',
)
total_reads['Total Fraction (%)'] = total_reads['readcount'].div(total_reads['Classified Reads'])
total_reads['Sample ID'] = total_reads['sample']

# Phase 1: low-abundance taxa are collected under "Not enough reads"
with_validation.loc[with_validation['readcount'].lt(DISCARD_CUTOFF), 'taxon'] = 'Not enough reads'

# Phase 2: of the remaining taxa, everything that does not validate gets the taxon id '-3'
with_validation['Validated'] = with_validation['Validation Rate'].gt(0.2)
with_validation.loc[
    with_validation['Validated'].ne(True) & with_validation['taxon'].ne('Not enough reads'),
    'taxonid'
] = '-3'

# Collapse to one row per (taxon id, sample)
with_validation = with_validation.groupby(['taxonid', 'sample'], as_index=False).sum()

In [None]:
# Flag every row whose taxon id satisfies is_below_or_equal(<id>, '1224')
with_validation['Ingroup'] = with_validation['taxonid'].map(
    lambda taxon_id: is_below_or_equal(taxon_id, '1224')
)

In [None]:
# Slash-separated sample key, matching with_validation['sample']
sample_statistics['sample'] = sample_statistics['PatID'] + '/' + sample_statistics['time'].astype(str)

# Summed read percentage of the flagged taxa per sample, annotated with the sample's phase
table = (
    with_validation[with_validation['Ingroup']]
    .groupby(['sample'], as_index=False)['Read Fraction (%)']
    .sum()
    .merge(
        sample_statistics[['sample', 'leukocytephase_cluster_kurz']],
        on='sample',
        how='left',
    )
)

print(table.groupby('leukocytephase_cluster_kurz')['Read Fraction (%)'].median())
alt.Chart(table).mark_boxplot().encode(
    x=alt.X('leukocytephase_cluster_kurz'),
    y=alt.Y('Read Fraction (%)'),
)

In [None]:
# Median number of total reads per leukocyte phase
sample_statistics['Total Reads'].groupby(
    sample_statistics['leukocytephase_cluster_2_kurz']
).median()

# Check combined abundance and validation presence

In [None]:
# Genus-level raw read counts restricted to the VIRUSES filter
input_table = get_normalized_abundances(
    kraken_dataframe,
    samples=PAPER_SAMPLES,
    level='G',
    normalize=False,
    included_taxa_filter=VIRUSES,
    excluded_taxa_filter=None,
)

# Share of each genus among the selected reads of its sample (0-1)
input_table['Read Fraction'] = input_table['readcount'].div(
    input_table.groupby(['time', 'patientid'])['readcount'].transform('sum')
)

# Use the same name for crAssphage that validation_data['Taxon Name'] is filtered on below
input_table.loc[
    input_table['taxon'].eq('Crassphage Pseudo-Genus'), 'taxon'
] = 'uncultured crAssphage'


In [None]:
taxon = 'uncultured crAssphage'
abundance = 0.1

# Validation rows of the taxon in the paper samples, enriched with the sample
# statistics and the abundance table
subtable = (
    validation_data[
        validation_data['Taxon Name'].eq(taxon)
        & validation_data['sample'].isin(PAPER_SAMPLES)
    ]
    .merge(sample_statistics, on='sample', how='left')
    .merge(
        input_table,
        left_on=['sample', 'Taxon ID'],
        right_on=['sample', 'taxonid'],
        how='left',
    )
)

# Keep only the 'pre_*' phases and rows with at least 20 reads
subtable = subtable[
    subtable['leukocytephase_cluster_kurz'].isin(['pre_1', 'pre_2', 'pre_3'])
    & (subtable['readcount'] >= 20)
]

# Number of validated rows, and how many of those also reach the abundance threshold
is_validated = subtable['Validation Rate'] >= 0.2
print(
    len(subtable[is_validated]),
    len(subtable[is_validated & (subtable['Read Fraction'] >= abundance)])
)
subtable[['sample', 'readcount', 'Validation Rate', 'Read Fraction', 'leukocytephase_cluster_kurz']]

# Saturation

In [None]:
# Genus-level raw read counts for all paper samples (no taxon filters, no normalisation)
input_table = get_normalized_abundances(
    kraken_dataframe,
    samples=PAPER_SAMPLES,
    level='G',
    normalize=False,
    included_taxa_filter=None,
    excluded_taxa_filter=None,
)

In [None]:
# Underscore-separated sample key for the join with validation_data
input_table['Sample'] = input_table['patientid'] + '_' + input_table['time'].astype(str)

with_validation = input_table.merge(
    validation_data[['Sample', 'Taxon ID', 'Validation Rate']],
    how='left',
    left_on=['Sample', 'taxonid'],
    right_on=['Sample', 'Taxon ID'],
)


# Taxa that do not validate (rate not above 0.2, or no validation record) get the taxon id '-3'
with_validation['Validated'] = with_validation['Validation Rate'].gt(0.2)
with_validation.loc[
    with_validation['Validated'].ne(True) & with_validation['taxon'].ne('Not enough reads'),
    'taxonid'
] = '-3'


In [None]:
# One point per sample: number of taxon rows against the total read count
alt.Chart(
    with_validation.groupby('sample').agg({'taxon': 'count', 'readcount': 'sum'})
).mark_point().encode(
    x=alt.X('readcount', scale=alt.Scale(type='symlog')),
    y=alt.Y('taxon'),
)

# 