## Inspecting the sensitivity of PhyCLIP: robustness of dominant strain selection

## 0. General

### 0.1. Libraries

In [44]:
import os, dendropy, math
import pandas as pd
from Bio import SeqIO
from datetime import date, timedelta
from dateutil.relativedelta import relativedelta


### 0.2. General 

In [45]:
# Passage categories that are kept when the metadata is filtered on its
# "passage_category" column further down in this notebook.
passage_labels = [
    'clinical',
    'cell-based MDCK or SIAT',
    'cell-based other',
]


In [46]:
# Hemisphere code ("nh" / "sh") for every region prefix used in the analysis
# directory names.
region_hemispheres = dict(us="nh", europe="nh", aunz="sh")
# The hemisphere codes themselves; used as keys of the per-hemisphere dicts.
hemispheres = ["nh", "sh"]

In [47]:
# Protein / sequence coordinates used when trimming consensus sequences.
start_mature_protein = 17  # used below as a 1-based start position when slicing out HA1
HA1_length_AA = 329        # in mature protein
protein_length = 567
sequence_length = 1701     # numerically 3 * protein_length

In [48]:
pp = 24    # preceding period in months
pdur = 16  # period duration

# Per hemisphere: season label -> (season start, season end) and
# season label -> vaccine strain selection moment.
flu_seasons = {h: {} for h in hemispheres}
vaccine_selection = {h: {} for h in hemispheres}
for y in range(2002, 2024):
    season = f"{y}-{y+1}"
    if y != 2023:
        # northern hemisphere flu season: 1 October of year y to 30 April of y+1
        flu_seasons["nh"][season] = (date(y, 10, 1), date(y + 1, 4, 30))
        # northern hemisphere vaccine strain selection moment
        vaccine_selection["nh"][season] = date(y, 2, 1)

    if y != 2002:
        sh_season = str(y)
        # southern hemisphere flu season: 1 March to 30 September of year y
        flu_seasons["sh"][sh_season] = (date(y, 3, 1), date(y, 9, 30))
        # southern hemisphere vaccine strain selection moment (previous year)
        vaccine_selection["sh"][sh_season] = date(y - 1, 9, 1)

# season_periods: hemisphere -> season label -> period label ("YYMM-YYMM")
# period_dates:   period label -> (period start, period end)
season_periods = {h: {} for h in hemispheres}
period_dates = {}
for h, sd in flu_seasons.items():
    for season, (fss, fse) in sd.items():
        vsd = vaccine_selection[h][season]
        # NOTE(review): the original author flagged that something about these
        # periods may not be correct and left it uninvestigated. Both
        # hemispheres use exactly the same window around the selection moment.
        ps = vsd + relativedelta(months=-pp)
        pe = vsd + relativedelta(months=pdur - 1)

        # two-digit year followed by zero-padded month, for start and end
        period = f"{ps:%y%m}-{pe:%y%m}"
        season_periods[h][season] = period
        period_dates[period] = (ps, pe)

# Cut-off for HA1: per hemisphere, the last season that counts as "early".
early_season_cutoff = {'nh': '2011-2012', 'sh': '2011'}
# End date of each cut-off season (last element of the season tuple).
early_season_cutoff_dates = {
    h: flu_seasons[h][season][-1] for h, season in early_season_cutoff.items()
}
# Flat list of every season up to and including the cut-off, nh first.
early_seasons = []
for h, cutoff in early_season_cutoff.items():
    csons = list(flu_seasons[h])
    early_seasons.extend(csons[: csons.index(cutoff) + 1])

### 0.3. Functions

In [49]:
def get_leaf_nodes(node):
    """Return the leaves visited by a post-order traversal starting at ``node``.

    ``node`` must provide ``postorder_iter()``; every yielded item must provide
    ``is_leaf()``. The leaves are returned in traversal order.
    """
    return [visited for visited in node.postorder_iter() if visited.is_leaf()]

def get_mrca(nodes, tree):
    """Return the first node, in post-order, whose leaves include all ``nodes``.

    Every node yielded by ``tree.postorder_node_iter()`` is tested in turn;
    the first one whose leaf set (see ``get_leaf_nodes``) contains every
    element of ``nodes`` is returned. Returns None when no node qualifies.
    Note that an empty ``nodes`` is trivially satisfied by the first node
    visited.
    """
    for candidate in tree.postorder_node_iter():
        leaves_below = get_leaf_nodes(candidate)
        missing = [wanted for wanted in nodes if wanted not in leaves_below]
        if not missing:
            return candidate
    return None

def get_consensus_sequence(sequences):
    """Return the column-wise most frequent character of ``sequences``.

    Each sequence (a list or any other iterable of characters) becomes one
    row of a table; the consensus takes, per column, the first row of the
    pandas ``mode()`` result and joins the characters into a string.
    """
    rows = [list(sequence) for sequence in sequences]
    table = pd.DataFrame.from_records(rows)
    first_mode_row = table.mode().iloc[0, :]
    return "".join(first_mode_row.to_list())
        
def float_date_to_date(fd):
    """Convert a decimal year such as 2015.5 into a ``date``.

    The result is 1 January of the integer year plus ``fraction * 365`` days.
    Leap years get no special treatment, and because a ``date`` is returned
    any part-day remainder is dropped.
    """
    year = math.floor(fd)
    year_fraction = fd - year
    offset = timedelta(days=year_fraction * 365)
    return date(year, 1, 1) + offset

def _taxon_id(label):
    """Return the part of a taxon label before the first ``|``, spaces replaced by underscores."""
    return label.split("|")[0].replace(" ", "_")


def _node_date(node):
    """Return the date held in a node's ``date`` annotation, or None if it has none."""
    try:
        return float_date_to_date(float(node.annotations.get_value("date")))
    except (TypeError, ValueError):
        # No usable plain "date" annotation: fall back to annotation keys whose
        # last comma-separated field is "date" (the last such key wins).
        found = None
        for key, value in node.annotations.values_as_dict().items():
            if key.split(",")[-1] == "date":
                found = float_date_to_date(float(value))
        return found


def get_dominant_clade(tree, phyclip, seqdict, fss, fse, n=1):
    """Return the consensus sequence(s) of the largest in-season cluster(s).

    For every cluster in ``phyclip`` the taxa that are present in ``seqdict``
    are located in ``tree``, their MRCA is determined with ``get_mrca`` and
    all leaves below that MRCA whose date lies within ``[fss, fse]`` are
    counted. Clusters are ranked by that count (ties keep the order in which
    clusters first appear in ``phyclip``) and a consensus is built from the
    ``seqdict`` sequences of the in-season leaves.

    Parameters:
        tree: tree object providing ``leaf_node_iter()`` and
            ``postorder_node_iter()``; leaves need ``taxon.label`` and
            ``annotations`` (presumably a dendropy Tree - TODO confirm).
        phyclip: DataFrame with the columns "CLUSTER" and "TAXA".
        seqdict: mapping of sequence id to sequence.
        fss, fse: first and last date (inclusive) of the flu season.
        n: number of clusters to return.

    Returns:
        For ``n > 1`` a dict mapping cluster -> consensus with at most ``n``
        entries; otherwise the consensus string of the single largest cluster.

    Raises:
        ValueError: if ``n <= 1`` and no cluster has any in-season sequence.

    Leaves without a date annotation are treated as outside the season, and
    clusters without matching tree leaves or in-season sequences are skipped.
    """
    # Group the PhyCLIP taxa by cluster, keeping only taxa we have a sequence for.
    cluster_taxa = {}
    for _, row in phyclip.iterrows():
        taxon = _taxon_id(row["TAXA"])
        if taxon in seqdict:
            cluster_taxa.setdefault(row["CLUSTER"], []).append(taxon)

    # Tree leaves and their ids do not depend on the cluster: compute once.
    tree_leaves = [(leaf, _taxon_id(leaf.taxon.label)) for leaf in tree.leaf_node_iter()]

    cluster_seqs = {}
    for cluster, taxa in cluster_taxa.items():
        wanted = set(taxa)
        nodes = [leaf for leaf, leaf_id in tree_leaves if leaf_id in wanted]
        if not nodes:
            # get_mrca() would return an arbitrary node for an empty list
            continue

        mrca = get_mrca(nodes, tree)
        if mrca is None:
            continue

        leaves = [mrca] if mrca.is_leaf() else get_leaf_nodes(mrca)
        in_season = []
        for leaf in leaves:
            sampled = _node_date(leaf)
            if sampled is not None and fss <= sampled <= fse:
                in_season.append(_taxon_id(leaf.taxon.label))
        if in_season:
            cluster_seqs[cluster] = in_season

    # Largest cluster first; sorted() is stable, so ties keep insertion order.
    ranked = sorted(cluster_seqs, key=lambda c: len(cluster_seqs[c]), reverse=True)

    clusters_to_return = {}
    for cluster in ranked:
        if len(clusters_to_return) >= n:
            break
        members = set(cluster_seqs[cluster])
        seqs = [seq for seq_id, seq in seqdict.items() if seq_id in members]
        if not seqs:
            continue
        clusters_to_return[cluster] = get_consensus_sequence(seqs)

    if n > 1:
        return clusters_to_return
    if not clusters_to_return:
        raise ValueError(
            f"no PhyCLIP cluster has sequences between {fss} and {fse}"
        )
    return next(iter(clusters_to_return.values()))

## 1. Prep phyclip run

This step should already have been performed by the dominant strain script.

In [50]:
# Convert the treetime divergence trees of every region into newick files
# that serve as PhyCLIP input (existing output files are left untouched).
analysis_dir = "../analysis"
for d in os.listdir(analysis_dir):
    if not os.path.isdir(os.path.join(analysis_dir, d)):
        continue
    region = d.split("_")[0]
    if region not in region_hemispheres:
        continue

    phyclip_input_dir = os.path.join(analysis_dir, d, "phyclip_input_trees")
    os.makedirs(phyclip_input_dir, exist_ok=True)

    treetime_dir = os.path.join(analysis_dir, d, "treetime")
    if not os.path.isdir(treetime_dir):
        print (f"can not find treetime output for {region}")
        # nothing to convert; listing the missing directory would raise
        continue

    for f in os.listdir(treetime_dir):
        if not (f.endswith(".nexus") and "divergence" in f):
            continue
        output_name = f.replace("H3N2_HA", f"H3N2_HA_{region}").replace(".nexus", ".tree")
        output_tree_file = os.path.join(phyclip_input_dir, output_name)
        if os.path.isfile(output_tree_file):
            continue
        tree = dendropy.Tree.get(path=os.path.join(treetime_dir, f), schema="nexus", preserve_underscores=True)
        tree.write(path=output_tree_file, schema="newick", suppress_internal_node_labels=True,)
        print (output_tree_file)
        
        
       

### 1.1. Run Phyclip

PhyCLIP is run manually with `run_phyclip_sensitivity.py`.

## 2. Results

### 2.1. get output data

In [51]:
# Collect the PhyCLIP cluster files:
# region -> season -> (minimum cluster size, fdr) -> file path
phyclip_output = "../analysis/phyclip_sensitivity"

phyclip_files = {
    region: {s: {} for s in season_periods[h]}
    for region, h in region_hemispheres.items()
}
for f in os.listdir(phyclip_output):
    if not f.startswith("cluster"):
        continue

    # region, period and the PhyCLIP settings are read from the
    # underscore-separated fields of the file name
    name_fields = f.split("_")
    region = name_fields[-3]
    p = name_fields[-2]
    periods_of_region = season_periods[region_hemispheres[region]]
    if p not in periods_of_region.values():
        continue

    season = next(s for s, pe in periods_of_region.items() if pe == p)

    cs = int(name_fields[3][2:])
    fdr = float(name_fields[4][3:])

    phyclip_files[region][season][(cs, fdr)] = os.path.join(phyclip_output, f)

    

In [52]:
# Collect the alignment file of every region and season:
# region -> season -> file path ({} while no file has been found)
sequence_files = {
    region: {s: {} for s in season_periods[h]}
    for region, h in region_hemispheres.items()
}
for d in os.listdir(analysis_dir):
    if not os.path.isdir(os.path.join(analysis_dir, d)):
        continue
    if d.split("_")[0] not in region_hemispheres:
        continue
    region = d.split("_")[0]
    periods_of_region = season_periods[region_hemispheres[region]]

    # taken from the alignment dir so sequences are trimmed to CDS
    alignment_dir = os.path.join(analysis_dir, d, "alignment")
    for f in os.listdir(alignment_dir):
        if f.startswith("."):
            continue  # hidden files such as .DS_Store
        # the period (second-to-last field of the file name) identifies the season
        p = f.split("_")[-2]
        if p not in periods_of_region.values():
            continue
        season = next(s for s, pe in periods_of_region.items() if pe == p)

        sequence_files[region][season] = os.path.join(alignment_dir, f)

        

In [53]:
# Collect the metadata CSV of every region and season:
# region -> season -> file path ({} while no file has been found)
metadata_files = {
    region: {s: {} for s in season_periods[h]}
    for region, h in region_hemispheres.items()
}
for d in os.listdir(analysis_dir):
    if not os.path.isdir(os.path.join(analysis_dir, d)):
        continue
    if d.split("_")[0] not in region_hemispheres:
        continue
    region = d.split("_")[0]
    periods_of_region = season_periods[region_hemispheres[region]]

    # metadata lives next to the raw sequences
    sequence_dir = os.path.join(analysis_dir, d, "sequences")
    for f in os.listdir(sequence_dir):
        if not f.endswith(".csv"):
            continue  # only the csv files hold metadata
        # the period (second-to-last field of the file name) identifies the season
        p = f.split("_")[-2]
        if p not in periods_of_region.values():
            continue
        season = next(s for s, pe in periods_of_region.items() if pe == p)

        metadata_files[region][season] = os.path.join(sequence_dir, f)

        

In [54]:
# Collect the treetime time tree of every region and season:
# region -> season -> file path ({} while no file has been found).
# The dict is keyed by season, like sequence_files and metadata_files, because
# that is the key used both when it is filled here and when it is read later.
tree_files = {
    region: {s: {} for s in season_periods[h]}
    for region, h in region_hemispheres.items()
}
for d in os.listdir(analysis_dir):
    if not os.path.isdir(os.path.join(analysis_dir, d)):
        continue
    region = d.split("_")[0]
    if region not in region_hemispheres:
        continue
    periods_of_region = season_periods[region_hemispheres[region]]

    treetime_dir = os.path.join(analysis_dir, d, "treetime")
    for f in os.listdir(treetime_dir):
        # only the full time trees, not the dominant-clade subtrees
        if not f.endswith("timetree.nexus") or "dominant_clade" in f:
            continue
        # the period (second-to-last field of the file name) identifies the season
        p = f.split("_")[-2]
        if p not in periods_of_region.values():
            continue
        season = next(s for s, pe in periods_of_region.items() if pe == p)

        tree_files[region][season] = os.path.join(treetime_dir, f)


        

In [55]:
# Load the original dominant strains: region -> season -> sequence
# ({} for seasons that are absent from the FASTA file).
dominant_strain_file = '../data/dominant_strains.fasta'

dominant_strains = {
    region: {p: {} for p in flu_seasons[h]}
    for region, h in region_hemispheres.items()
}
for r in SeqIO.parse(dominant_strain_file, "fasta"):
    # record ids start with "<region>_<season>"
    region, season = r.id.split("_")[:2]
    dominant_strains[region][season] = r.seq
    

### 2.2. Determine dominant strains

In [56]:
# From here on the name refers to the FASTA file holding the dominant strains
# of the sensitivity analysis (written or read by the next cell).
dominant_strain_file = "../data/sensitivity_analysis_dominant_strains.fasta"

In [57]:
# Determine the dominant strain for every PhyCLIP parameter combination, or
# reload earlier results from the FASTA file.
# sensitivity_results: region -> season -> (cluster size, fdr) -> sequence
redo = False  # set to True to recompute even if the FASTA file already exists
sensitivity_results = {
    region: {p: {} for p in flu_seasons[h]}
    for region, h in region_hemispheres.items()
}
if redo or not os.path.isfile(dominant_strain_file):
    for region, sd in phyclip_files.items():
        for season, phyclipd in sd.items():
            if not phyclipd:
                continue  # no PhyCLIP output for this season, nothing to load

            # The season bounds must be known before they are used below.
            fss, fse = flu_seasons[region_hemispheres[region]][season]

            # Sequences, metadata and tree only depend on region and season,
            # so they are loaded once instead of once per (cs, fdr) pair.
            sequences = {
                r.id.split("|")[0]: r.seq[:sequence_length].replace("-", "N").translate()
                for r in SeqIO.parse(sequence_files[region][season], "fasta")
            }
            metadata = pd.read_csv(metadata_files[region][season])
            metadata["Collection_Date"] = [d.date() for d in pd.to_datetime(metadata["Collection_Date"], errors='coerce')]
            metadata = metadata[metadata["passage_category"].isin(passage_labels)]

            # get sequences within influenza season
            # NOTE(review): seq_ids is not used anywhere below.
            seq_ids = metadata[(metadata["Collection_Date"]>=fss)&(metadata["Collection_Date"]<=fse)]["Isolate_Id"].to_list()
            tree = dendropy.Tree.get(path=tree_files[region][season], schema="nexus")

            for (cs, fdr), f in phyclipd.items():
                phyclip = pd.read_csv(f, sep="\t")
                dominant_clade_consensus = get_dominant_clade(tree, phyclip, sequences, fss, fse)
                # early seasons are compared on HA1 only: trim a longer consensus
                if season in early_seasons and len(dominant_clade_consensus)>HA1_length_AA:
                    dominant_clade_consensus = dominant_clade_consensus[start_mature_protein-1:HA1_length_AA+start_mature_protein-1]

                sensitivity_results[region][season][(cs, fdr)] = dominant_clade_consensus

    with open(dominant_strain_file, "w") as fw:
        for region, sd in sensitivity_results.items():
            for season, dsd in sd.items():
                suffix = "_HA1" if season in early_seasons else ""
                for (cs, fdr), strain in dsd.items():
                    fw.write(f">{region}_{season}_cs{cs}_fdr{fdr}{suffix}\n{strain}\n")

else:
    # Record ids look like "<region>_<season>_cs<size>_fdr<fdr>[_HA1]".
    for r in SeqIO.parse(dominant_strain_file, "fasta"):
        id_fields = r.id.split("_")
        region, season = id_fields[0:2]
        cs = int(id_fields[2][2:])
        fdr = float(id_fields[3][3:])

        sensitivity_results[region][season][(cs, fdr)] = r.seq

    # The loaded results used to be stored under this name only, which left
    # `sensitivity_results` undefined for the next cell; keep it as an alias.
    sensitivity_results_ = sensitivity_results
                
        

In [58]:
# Compare every sensitivity-run dominant strain with the original one.
# Only matches end up in the table; mismatches are just reported on stdout.
overview_columns = [
    "region",
    "season",
    "minimum cluster size",
    "fdr",
    "difference between identified dominant strain and original dominant strain (cs3, fdr0.2)",
]
overview_rows = []
for region, sd in sensitivity_results.items():
    for season, dsd in sd.items():
        dominant_strain = dominant_strains[region][season]
        for (cs, fdr), strain in dsd.items():
            if strain == dominant_strain:
                overview_rows.append([region, season, cs, fdr, "match"])
                continue
            print (f"no match for {region} {season} cs{cs} fdr{fdr}, write code to summarize differences")

sensitivity_results_overview = pd.DataFrame.from_records(overview_rows, columns=overview_columns)

In [59]:
# Save the overview table without the DataFrame index column.
sensitivity_results_overview.to_csv(
    path_or_buf="../data/sensitivity_analysis_dominant_strains_summary.csv",
    index=False,
)