## Inspeciting sensitivitiy of phyclip, Robustness of dominant strain selection

parameters tested: 
- minimum cluster size: 3,5 
- FDR: 0, 1

Keep in mind minumum cluster size must be at least n tips / 2 > and minimum cluster size = 5 was therefore unsuccessfull in europe 2002-2003 season

## 0. General

### 0.1. Libraries

In [1]:
import os, dendropy, math
import pandas as pd
from Bio import SeqIO
from datetime import date, timedelta
from dateutil.relativedelta import relativedelta


### 0.2. General 

In [2]:
passage_labels = ['clinical', 'cell-based MDCK or SIAT', 'cell-based other']


In [3]:
region_hemispheres = {"us":"nh", "europe":"nh", "aunz":"sh"}
hemispheres = ["nh", "sh"]

In [4]:
start_mature_protein = 17
HA1_length_AA = 329 #in mature protein 
protein_length = 567
sequence_length = 1701

In [5]:
pp = 24 #preceding period in months
pdur = 16 #period duration
flu_seasons = {h:{} for h in hemispheres}
vaccine_selection = {h:{} for h in hemispheres}
for y in range(2002, 2024):
    #northern hemisphere flu season
    season = f"{y}-{y+1}"
    if y != 2023:
        flu_seasons["nh"][season] = (date(y, 10, 1), date(y+1, 4, 30))
        #northern hemisphere vaccine strain selection moment
        vaccine_selection["nh"][season] = date(y,2,1)
    
    if y==2002:
        continue
    #southern hemisphere flu season
    flu_seasons["sh"][str(y)] = (date(y, 3, 1), date(y, 9, 30))
    #southern hemisphere vaccine strain selection moment
    vaccine_selection["sh"][str(y)] = date(y-1,9,1)

season_periods = {h:{} for h in hemispheres}
period_dates = {}
for h, sd in flu_seasons.items():
    for season, (fss, fse) in sd.items():
        vsd = vaccine_selection[h][season]
        #something with periods in treason doesn't seem to be correct but don't wanna waste time on that for now
        if h == "nh":
            ps, pe  = vsd + relativedelta(months=-pp), vsd + relativedelta(months=+pdur-1)
        else:
            ps, pe  = vsd + relativedelta(months=-pp), vsd + relativedelta(months=+pdur-1)

        
        period = f"{str(ps.year)[2:]}{'0'+str(ps.month) if len(str(ps.month))==1 else str(ps.month)}-{str(pe.year)[2:]}{'0'+str(pe.month) if len(str(pe.month))==1 else str(pe.month)}"
        season_periods[h][season] = period
        period_dates[period] = (ps, pe)

#cut off for HA1 
early_season_cutoff = {'nh':'2011-2012', 'sh':'2011'}
early_season_cutoff_dates = {h:flu_seasons[h][season][-1] for h, season in early_season_cutoff.items()}
early_seasons = [] #get list
for h, cutoff in early_season_cutoff.items():
    csons = list(flu_seasons[h].keys())
    for i, season in enumerate(csons):
        if i <= csons.index(cutoff):
            early_seasons.append(season)

In [6]:
epitope_sites= {"A":[122,124,126,130,131,132,133,135,137,138,140,142,143,144,145,146,150,152,168], 
                "B":[128,129,155,156,157,158,159,163,165,186,187,188,189,190,192,193,194,196,197,198], 
                "C":[44,45,46,47,48,50,51,53,54,273,275,276,278,279,280,294,297,299,300,304,305,307,308,309,310,311,312], 
                "D":[96,102,103,117,121,167,170,171,172,173,174,175,176,177,179,182,201,203,207,208,209,212,213,214,215,216,217,218,219,226,227,228,229,230,238,240,242,244,246,247,248,], 
                "E":[57,59,62,63,67,75,78,80,81,82,83,86,87,88,91,92,94,109,260,261,262,265]}
epitope_positions = [s for sites in epitope_sites.values() for s in sites]

#rbs
# rbs_positions = [98, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 144, 145, 146, 153, 154, 155, 156, 157, 158, 159,
#                  183, 184, 185, 186, 187, 188, 189, 190, 190, 191, 192, 193, 194, 195, 196, 219, 220, 221, 222, 223, 224, 
#                  225, 226, 227, 228,]
rbs_positions = [131, 132, 133, 134, 135, 136, 137, 138, 140, 183, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228,]
koel_sites = [145, 189, 193, 156, 159, 158, 155]

#amino acid list for coding 
aa_list = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]

### 0.3. Functions

In [7]:
def get_leaf_nodes(node):
    leafs = []
    for child in node.postorder_iter():
        if child.is_leaf():
            leafs.append(child)
    return leafs

def get_mrca(nodes, tree):
    for internal_node in tree.postorder_node_iter():
        children = get_leaf_nodes(internal_node)
        if all([node in children for node in nodes]):
            return internal_node

def get_consensus_sequence(sequences):
    seqs = []
    for seq in sequences:
        if type(seq) == list:
            seqs.append(seq)
        else:
            seqs.append(list(seq))
    seqs = pd.DataFrame.from_records(seqs)#.reset_index(drop=)
    consensus = seqs.mode().iloc[0,:].to_list()
    return "".join(consensus)
        
def float_date_to_date(fd):
    return date(math.floor(fd), 1, 1 ) + timedelta(days=(fd-math.floor(fd))*365)

def get_dominant_clade(tree, phyclip, seqdict, fss, fse, n=1):
    
    #get sequences within flu season from   for phyclip data
    clustecss_in_season = {}
    for i, row in phyclip.iterrows():
        if row["TAXA"].split("|")[0].replace(" ", "_") in seqdict.keys():
            try:
                clustecss_in_season[row["CLUSTER"]].append(row["TAXA"].split("|")[0].replace(" ", "_") )
            except:
                clustecss_in_season[row["CLUSTER"]] = [row["TAXA"].split("|")[0].replace(" ", "_") ]
    
    cluster_seqs = {}
    for cluster, leafs in clustecss_in_season.items():
        
        #get nodes of the phyclip cluster sequences in the big tree
        nodes = []
        for node in tree.leaf_node_iter():
            if node.taxon.label.split("|")[0].replace(" ", "_") in leafs:
                nodes.append(node)

        #get most recent common ancestor of the cluster
        mrca  = get_mrca(nodes, tree)
        
        #get sequences in seqs
        if mrca.is_leaf():
            try:
                d = float_date_to_date(float(mrca.annotations.get_value("date")))
            except:
                for k, v in mrca.annotations.values_as_dict().items():
                    if k.split(",")[-1] == "date":
                        d = float_date_to_date(float(v))
            if d >= fss and d<=fse:
                cluster_seqs[cluster] = [mrca.taxon.label.split("|")[0].replace(" ", "_")]
        else:
            leafs = get_leaf_nodes(mrca)
            leafs_in_season = []
            for node in leafs:
                try:
                    d = float_date_to_date(float(node.annotations.get_value("date")))
                except:
                    for k, v in node.annotations.values_as_dict().items():
                        if k.split(",")[-1] == "date":
                            d = float_date_to_date(float(v))
                if d >= fss and d<=fse:
                    leafs_in_season.append(node.taxon.label.split("|")[0].replace(" ", "_"))
            cluster_seqs[cluster] = leafs_in_season

    clustecss_to_return = {}
    while len(clustecss_to_return) < n and len(clustecss_to_return)<len(cluster_seqs):
        #get biggest cluster 
        biggest_cluster = [k for k, l in cluster_seqs.items() if len(l) == max([len(l) for k,l  in cluster_seqs.items() if k not in list(clustecss_to_return.keys())]) and k not in list(clustecss_to_return.keys())][0]
        seqs = [v for k,v in seqdict.items() if k in cluster_seqs[biggest_cluster]]
        #get consensus
        consensus = get_consensus_sequence(seqs)
        clustecss_to_return[biggest_cluster] = consensus
    if n > 1:
        return clustecss_to_return
    else:
        return consensus
    
def find_mismatch_HA(s1, s2):
    
    signal_peptide=  []
    epitope = []
    non_epitope = []

    
    if len(s1) != len(s2):
        print ("unequal length")
    if len(s1) > 329:
        
        #starting with signal peptide
        signal_peptide=  []
        for i, aa in enumerate(s1[:start_mature_protein-1]):
            if aa != s2[i]:
                signal_peptide.append(f"{aa}{i+1}{s2[i]}")
            
         #mature protein
        for i, aa in enumerate(s1[start_mature_protein-1:]):
            if s2[start_mature_protein-1:][i]!=aa:
                if i in epitope_positions:
                    epitope.append(f"{aa}{i+1}{s2[i]}")
                else:
                    non_epitope.append(f"{aa}{i+1}{s2[i]}")
    else:
        #just HA1
        for i, aa in enumerate(s1[:start_mature_protein-1]):
            if aa != s2[i]:
                if i in epitope_positions:
                    epitope.append(f"{aa}{i+1}{s2[i]}")
                else:
                    non_epitope.append(f"{aa}{i+1}{s2[i]}")
       
    sum = []             
    if len(signal_peptide) > 0:
        sum.append("sig pep: "+", ".join(signal_peptide))
    if len(epitope) >0:
        sum.append("epitope: "+", ".join(epitope))
    if len(non_epitope) > 0:
        sum.append("non-epitope: "+", ".join(non_epitope))
        
    return "; ".join(sum)
        

## 1. Prep phyclip run

Should be performed already in dominant strain script

In [8]:
analysis_dir = "../analysis"
for d in os.listdir(analysis_dir):
    if os.path.isdir(os.path.join(analysis_dir,d)) and d.split("_")[0] in list(region_hemispheres.keys()):
        region = d.split("_")[0]

        phyclip_input_dir = os.path.join(analysis_dir, d, "phyclip_input_trees")
        if not os.path.isdir(phyclip_input_dir):
            os.mkdir(phyclip_input_dir)

        treetime_dir = os.path.join(analysis_dir, d, "treetime")
        if not os.path.isdir(treetime_dir):
            print (f"can not find treetime output for {region}")

        for f in os.listdir(os.path.join(treetime_dir)):
            if f.endswith(".nexus") and "divergence" in f:
                output_tree_file = os.path.join(os.path.join(phyclip_input_dir, f.replace("H3N2_HA", f"H3N2_HA_{region}").replace(".nexus", ".tree")))
                if not os.path.isfile(output_tree_file):
                    tree = dendropy.Tree.get(path=os.path.join(treetime_dir, f), schema="nexus", preserve_underscores=True)
                    tree.write(path=output_tree_file, schema="newick", suppress_internal_node_labels=True,)
                    print (output_tree_file) 
        
        
       

### 1.1. Run Phyclip

running run_phyclip_sensitivity.py manually

## 2. Results

### 2.1. get output data

In [9]:
#get phyclip output
phyclip_output = "../analysis/phyclip_sensitivity"

phyclip_files = {region:{s:{} for s in season_periods[h].keys()}  for region, h in region_hemispheres.items()}
for f in os.listdir(phyclip_output):
    if not f.startswith("cluster"):
        continue
    
    region = f.split("_")[-3]
    p = f.split("_")[-2]
    if p not in season_periods[region_hemispheres[region]].values():
         continue
            
    season = [s for s,pe in season_periods[region_hemispheres[region]].items() if pe ==p][0]
    
    cs = int(f.split("_")[3][2:])
    fdr = float(f.split("_")[4][3:])

    phyclip_files[region][season][(cs, fdr)] = os.path.join(phyclip_output, f)
  

In [10]:
#get sequence output 
sequence_files = {region:{s:{} for s in season_periods[h].keys()}  for region, h in region_hemispheres.items()}
for d in os.listdir(analysis_dir):
    if os.path.isdir(os.path.join(analysis_dir,d)) and d.split("_")[0] in list(region_hemispheres.keys()):
        region = d.split("_")[0]

        #getting from alignment dir to sequences are trimmed to CDS
        alignment_dir = os.path.join(analysis_dir,d, "alignment")
        for f in os.listdir(alignment_dir):
            if f.startswith("."):
                continue #.DS store file #livelovemac
            #determine season of interest from  file 
            p = f.split("_")[-2]
            if p not in season_periods[region_hemispheres[region]].values():
                continue
            season = [s for s,pe in season_periods[region_hemispheres[region]].items() if pe ==p][0]
            

            sequence_files[region][season] = os.path.join(alignment_dir, f)
     

In [11]:
#get metadata
metadata_files = {region:{s:{} for s in season_periods[h].keys()}  for region, h in region_hemispheres.items()}
for d in os.listdir(analysis_dir):
    if os.path.isdir(os.path.join(analysis_dir,d)) and d.split("_")[0] in list(region_hemispheres.keys()):
        region = d.split("_")[0]

        #getting from alignment dir to sequences are trimmed to CDS
        sequence_dir = os.path.join(analysis_dir,d, "sequences")
        for f in os.listdir(sequence_dir):
            if not  f.endswith(".csv"):
                continue #.DS store file #livelovemac
            #determine season of interest from  file 
            p = f.split("_")[-2]
            if p not in season_periods[region_hemispheres[region]].values():
                continue
            season = [s for s,pe in season_periods[region_hemispheres[region]].items() if pe ==p][0]
            

            metadata_files[region][season] = os.path.join(sequence_dir, f)       

In [12]:
#get tree files
tree_files = {region:{p:{} for p in season_periods[h].values()}  for region, h in region_hemispheres.items()}
for d in os.listdir(analysis_dir):
    if os.path.isdir(os.path.join(analysis_dir,d)) and d.split("_")[0] in list(region_hemispheres.keys()):
        region = d.split("_")[0]

        #getting from alignment dir to sequences are trimmed to CDS
        treetime_dir = os.path.join(analysis_dir,d, "treetime")
        for f in os.listdir(treetime_dir):
            if not f.endswith("timetree.nexus") or "dominant_clade" in f:
                continue #.DS store file #livelovemac
            #determine season of interest from  file 
            p = f.split("_")[-2]
            if p not in season_periods[region_hemispheres[region]].values():
                continue
            season = [s for s,pe in season_periods[region_hemispheres[region]].items() if pe ==p][0]

            tree_files[region][season] = os.path.join(treetime_dir, f)
     

In [13]:
#get og dominant strains
dominant_strain_file = '../data/dominant_strains.fasta'

dominant_strains = {region:{p:{} for p in flu_seasons[h].keys()}  for region, h in region_hemispheres.items()}
for r in SeqIO.parse(dominant_strain_file, "fasta"):
    region, season = r.id.split("_")[0:2]
    dominant_strains[region][season] = r.seq
    

### 2.2. Determine dominant strains

In [14]:
dominant_strain_file = '../data/sensitivity_analysis_dominant_strains.fasta'

In [15]:
redo = False
if not os.path.isfile(dominant_strain_file) or redo:
    sensitivity_results = {region:{p:{} for p in flu_seasons[h].keys()}  for region, h in region_hemispheres.items()}
    for region, sd in phyclip_files.items():
        for season, phyclipd in sd.items():
            
            for (cs, fdr), f in phyclipd.items():
                
                #determine dominant strain first 
                sequences = {r.id.split("|")[0]:r.seq[:1701].replace("-", "N").translate() for r in SeqIO.parse(sequence_files[region][season], "fasta")}
                metadata = pd.read_csv(metadata_files[region][season])
                metadata["Collection_Date"] = [d.date() for d in pd.to_datetime(metadata["Collection_Date"], errors='coerce')] 
                metadata = metadata[metadata["passage_category"].isin(passage_labels)]

                #get sequences within influenza season
                seq_ids = metadata[(metadata["Collection_Date"]>=fss)&(metadata["Collection_Date"]<=fse)]["Isolate_Id"].to_list()
                tree = dendropy.Tree.get(path=tree_files[region][season], schema="nexus")
                

                phyclip = pd.read_csv(f, sep="\t")
                fss, fse = flu_seasons[region_hemispheres[region]][season]
                dominant_clade_consensus = get_dominant_clade(tree, phyclip, sequences, fss, fse)
                if season in early_seasons and len(dominant_clade_consensus)>HA1_length_AA:
                    dominant_clade_consensus = dominant_clade_consensus[start_mature_protein-1:HA1_length_AA+start_mature_protein-1]
                    
                sensitivity_results[region][season][(cs, fdr)] = dominant_clade_consensus
        
    with open(dominant_strain_file, "w") as fw:
        for region, sd in sensitivity_results.items():
            for season, dsd in sd.items():
                for (cs, fdr), strain in dsd.items():
                    if season in early_seasons:
                        fw.write(f">{region}_{season}_cs{cs}_fdr{fdr}_HA1\n{strain}\n") 
                    else:
                        fw.write(f">{region}_{season}_cs{cs}_fdr{fdr}\n{strain}\n") 
                                
else:
    sensitivity_results = {region:{p:{} for p in flu_seasons[h].keys()}  for region, h in region_hemispheres.items()}
    for r in SeqIO.parse(dominant_strain_file, "fasta"):
        region, season = r.id.split("_")[0:2]
        cs = int(r.id.split("_")[2][2:])
        fdr = float(r.id.split("_")[3][3:])
        
        sensitivity_results[region][season][(cs, fdr)] = r.seq
    
                
        

### 2.3. compare dominant strain with the original dominant strain file

In [16]:
#get an overview of the match
sensitivity_results_overview = []
for region, sd in sensitivity_results.items():
    for season, dsd in sd.items():
        dominant_strain = dominant_strains[region][season]
        for (cs, fdr), strain in dsd.items():
            
            if strain == dominant_strain:
                sensitivity_results_overview.append([region, season, cs, fdr, "match"])
            else:
                sensitivity_results_overview.append([region, season, cs, fdr, find_mismatch_HA(dominant_strain, strain)])
                
                print (f"no match for {region} {season} cs{cs} fdr{fdr}")

sensitivity_results_overview = pd.DataFrame.from_records(sensitivity_results_overview, columns=["region", "season", "minimum cluster size", "fdr", "difference between identified dominant strain and original dominant strain (cs3, fdr0.2)"])

no match for europe 2022-2023 cs5 fdr0.1
no match for europe 2022-2023 cs5 fdr0.2
no match for aunz 2018 cs3 fdr0.1
no match for aunz 2018 cs5 fdr0.2
no match for aunz 2018 cs3 fdr0.2
no match for aunz 2018 cs5 fdr0.1


### 2.4. investigation robustness of mismatches

In [17]:
mismatch_df = sensitivity_results_overview[sensitivity_results_overview["difference between identified dominant strain and original dominant strain (cs3, fdr0.2)"]!="match"]
mismatch_df = mismatch_df.set_index(["region", "season"]).sort_index()
mismatch_df

Unnamed: 0_level_0,Unnamed: 1_level_0,minimum cluster size,fdr,"difference between identified dominant strain and original dominant strain (cs3, fdr0.2)"
region,season,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
aunz,2018,3,0.1,epitope: N171T
aunz,2018,5,0.2,epitope: N171T
aunz,2018,3,0.2,epitope: N171T
aunz,2018,5,0.1,epitope: N171T
europe,2022-2023,5,0.1,epitope: S262N; non-epitope: R33H
europe,2022-2023,5,0.2,epitope: S262N; non-epitope: R33H


visual inspection shows no Koel site mutations, so only loading epitope mutations

In [18]:
#get original epitope mutation difference results 
#dominant strain vs WHO vacine strain (VS)
ds_vs = pd.read_csv("../analysis/genetic_comparison_results/mutations_dominant_strain_vs_vaccine_strain.csv")
ds_vs["comparison"] = "DS vs VS"

#dominant strain vs reproducible selection strain at WHO timing (CSS)
ds_css = pd.read_csv("../analysis/genetic_comparison_results/mutations_dominant_strain_vs_reproducible_selection_at_WHO-timing.csv")
ds_css["comparison"] = "DS vs CSS"

#dominant strain vs later selection strain (reproducible) (LSS)
ds_lss = pd.read_csv("../analysis/genetic_comparison_results/mutations_dominant_strain_vs_reproducible_selection_at_delayed_timing.csv")
ds_lss["comparison"] = "DS vs LSS"

genetic_comparison = pd.concat([ds_vs, ds_css, ds_lss], ignore_index=True).set_index(["region", "season", "comparison"]).sort_index()

In [19]:
def get_performance_results(performance_dict):
    performance = []
    if performance_dict["DS vs CSS"] == performance_dict["DS vs VS"]:
        performance.append('reproducible selection performs the same as WHO vaccine strain')
        
        if performance_dict["DS vs LSS"] < performance_dict['DS vs CSS']:
            performance.append('added benefit from delayed selection')
    
    elif performance_dict["DS vs CSS"] < performance_dict["DS vs VS"]:
        performance.append('reproducible selection performs beter than WHO vaccine strain')
        
        if performance_dict["DS vs LSS"] < performance_dict['DS vs CSS']:
            performance.append('added benefit from delayed selection')
        

    return "; ".join(performance)
    

comps = ["DS vs VS", "DS vs CSS", "DS vs LSS"]
for (region, season) in mismatch_df.index.unique():
    gen = genetic_comparison.loc[(region, season)]
    #care about epitope mutations 
    gen = gen[gen["epitope"]==True]
    #get counts 
    original_counts = gen.reset_index().groupby("comparison")["H3 mutation"].count().to_dict()
    for comp in comps:
        if comp not in original_counts.keys():
            original_counts[comp] = 0
    
    df = mismatch_df.loc[(region, season)]
    
    for l in df["difference between identified dominant strain and original dominant strain (cs3, fdr0.2)"].unique():
        if not 'epitope' in l:
            continue
        
        ep_muts = [i for i in l.split(";") if "epitope" in i][0].lstrip("epitope: ").split(",")
        
        updated_counts = original_counts.copy()
      
        for ep_mut in ep_muts:
            h3_pos = int(ep_mut[1:-1])
            
            for comp in comps:
                try:
                    ml = gen.loc[comp]["mature position"].tolist()
                    if h3_pos not in ml:
                        updated_counts[comp] += 1
                except:
                    #comparison not in gen > count zero so new mutations is a count
                    updated_counts[comp] += 1

        original_performance = get_performance_results(original_counts)
        updated_performance = get_performance_results(original_counts)
        
        if original_performance != updated_performance:
            print ("time to investigate")
            
        else:
            combos =df[df["difference between identified dominant strain and original dominant strain (cs3, fdr0.2)"]==l].set_index(["minimum cluster size", "fdr"]).index.tolist() 
            print (f"similar performance to original performance for {region} {season}")
            print ("combonations tested: ", combos)
            print ()
        
        
        
                
                


similar performance to original performance for aunz 2018
combonations tested:  [(3, 0.1), (5, 0.2), (3, 0.2), (5, 0.1)]

similar performance to original performance for europe 2022-2023
combonations tested:  [(5, 0.1), (5, 0.2)]

