In [33]:
# import packages
import csv
import os
import subprocess

import dendropy
import numpy as np
from Bio import SeqIO

# project helpers (check_sequence_length, check_max_ambig, ...); kept before the
# pandas/datetime imports so those explicit names win over anything utils re-exports
from utils import *

import pandas as pd
from datetime import datetime

In [34]:
# Dataset selection. `timeperiod` is the tag used in the GISAID download folder
# name and in the clean metadata file name.
subtype, segment = "H3N2", "PA"
timeperiod = "0019"

In [35]:
# Project data root. Set the FLU_EVOLUTION_DIR environment variable to run the
# notebook on another machine instead of editing this machine-specific default.
base_dir = os.environ.get("FLU_EVOLUTION_DIR", "/Users/annelies/Desktop/flu-evolution")
main_dir = os.path.join(base_dir, f"{subtype}_{segment}_gisaid_sequences")
raw_fasta = os.path.join(main_dir, "raw", f"{subtype}_{segment}_gisaid_raw.fasta")
metadata_file = os.path.join(main_dir, "raw", f"{subtype}_{segment}_metadata_gisaid_raw.csv")

In [36]:
#check if raw exists and else create to complete raw from where direct gisaid download is stored
# Merge every .fasta / .xls file of the GISAID download folder into one raw
# FASTA and one metadata CSV. Skipped when the raw FASTA already exists.
if not os.path.isfile(raw_fasta):
    gisaid_download_dir = os.path.join(main_dir, "..", "gisaid_raw", f"{subtype}_{segment}_{timeperiod}")
    records = []
    metadata_frames = []
    # sorted() makes record/row order independent of the filesystem listing order
    for fname in sorted(os.listdir(gisaid_download_dir)):
        path = os.path.join(gisaid_download_dir, fname)
        if fname.endswith(".fasta"):
            records.extend(SeqIO.parse(path, "fasta"))
        elif fname.endswith(".xls"):
            metadata_frames.append(pd.read_excel(path))

    # Fail loudly instead of caching an empty raw FASTA (or reusing a stale `df`
    # left in the kernel, which the old bare `except:` allowed).
    if not records:
        raise FileNotFoundError(f"no sequences found in .fasta files of {gisaid_download_dir}")
    if not metadata_frames:
        raise FileNotFoundError(f"no .xls metadata files found in {gisaid_download_dir}")

    os.makedirs(os.path.dirname(raw_fasta), exist_ok=True)
    with open(raw_fasta, "w") as fw:
        SeqIO.write(records, fw, "fasta")

    # Concatenate once rather than growing a frame file by file.
    df = pd.concat(metadata_frames)
    df.to_csv(metadata_file, index=False)

In [37]:
#set files and directory paths
cluster_dir = os.path.join(main_dir, "cluster_seqs")
alignment_dir = os.path.join(main_dir, "alignment")
tree_dir = os.path.join(main_dir, "tree")
outliers_dir = os.path.join(main_dir, "outliers_removed")
cdhit_dir = os.path.join(main_dir, "cdhit_output")

# Create every directory later cells write into. makedirs(exist_ok=True) also
# creates missing parents and is a no-op for existing directories.
for d in [cluster_dir, alignment_dir, tree_dir, outliers_dir, cdhit_dir]:
    os.makedirs(d, exist_ok=True)

# Despite the "_raw" suffix, this file holds the quality-filtered sequences.
clean_fasta = os.path.join(outliers_dir, f"{subtype}_{segment}_gisaid_raw.fasta")

# Relative to the notebook's working directory.
molecular_clock_outliers = f"../data/to_drop/{subtype}_{segment}_gitr.csv"

clean_metadata = os.path.join(outliers_dir, f"{subtype}_{segment}_metadata_gisaid_{timeperiod}.xlsx")

# Cluster membership file written by CD-HIT (run manually, see below).
cluster_results = os.path.join(cdhit_dir, f"{subtype}_{segment}_clus.clstr")


In [38]:
#remove low quality and max ambig sequences
redo = True  # re-filter even if clean_fasta already exists

# Thresholds passed to the utils helpers below.
# NOTE(review): presumably the minimum fraction of the expected segment length
# and the maximum fraction of ambiguous nucleotides - confirm against utils.
MIN_LENGTH_FRACTION = 0.95
MAX_AMBIG_FRACTION = 0.01

# mco collects [Isolate_Id, reason] rows for every excluded isolate. It is seeded
# from the outlier CSV written at the end of a previous run; those ids are
# skipped outright when building clean_fasta.
mco = [] #molecular clock outliers
to_remove = []
if os.path.isfile(molecular_clock_outliers):
    with open(molecular_clock_outliers, "r", newline="") as fr:
        for row in csv.reader(fr):
            if not row or "reason" in row:  # blank line or the header row
                continue
            mco.append(row)
            to_remove.append(row[0])

if redo or not os.path.isfile(clean_fasta):
    already_removed = set(to_remove)  # O(1) membership tests in the loop
    records = []
    for r in SeqIO.parse(raw_fasta, "fasta"):
        isolate_id = r.id.split("|")[0]
        if isolate_id in already_removed:
            continue
        seq = str(r.seq)
        if not check_sequence_length(segment, seq, MIN_LENGTH_FRACTION): #check length
            mco.append([isolate_id, "too short"])
        elif not check_max_ambig(seq, MAX_AMBIG_FRACTION):
            mco.append([isolate_id, "too many ambiguous nucleotides"])
        else:
            records.append(r)

    os.makedirs(os.path.dirname(clean_fasta), exist_ok=True)
    with open(clean_fasta, "w") as fw:
        SeqIO.write(records, fw, "fasta")

In [39]:
#filter outliers from metadata
# Drop every isolate excluded so far (entries from the outlier file plus the
# length/ambiguity failures collected in mco) and save the clean metadata table.
metadata = pd.read_csv(metadata_file)
bad_ids = [row[0] for row in mco]
metadata = metadata[~metadata["Isolate_Id"].isin(bad_ids)]
os.makedirs(os.path.dirname(clean_metadata), exist_ok=True)
# index must be the bool False: the string "False" is truthy, so the row index
# was being written as an extra column.
metadata.to_excel(clean_metadata, index=False)

  metadata = pd.read_csv(metadata_file)


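## Run CD-HIT
Run CD-HIT manually on the command line with the following command:  
`cd-hit -i {clean_fasta} -o {main_dir}/cdhit_output/{subtype}_{segment}_clus -c 0.998 -T {number of threads}`  
CD-HIT writes the representative sequences to the `-o` path and the cluster membership to the same path with `.clstr` appended; that `.clstr` file (`cluster_results`) is what is parsed below, so do not include the suffix in `-o`. Note that the thread count is set with upper-case `-T`; lower-case `-t` is the redundance tolerance.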

In [40]:
#parse cd hit file
# A line starting with ">" opens a new cluster (its text minus the ">" is the
# cluster id); every other line is a member entry whose isolate id is the text
# after the last ">" up to the first "|".
with open(cluster_results, "r") as f:
    lines = [l.strip("\n") for l in f]

clusters = {}
for line in lines:
    if not line.startswith(">"):
        member_id = line.split(">")[-1].split("|")[0]
        clusters[cluster_id].append(member_id)
        continue
    cluster_id = line.lstrip(">")
    clusters[cluster_id] = []

In [41]:
#get dates and set season definitions
metadata = pd.read_csv(metadata_file)
metadata["Collection_Date"] = pd.to_datetime(metadata["Collection_Date"], format="%Y-%m-%d")

# Seasons run from 1 May to 30 April of the next year and are keyed by the two
# two-digit years they span, e.g. "0001" = 2000-05-01 .. 2001-04-30.
FIRST_SEASON_YEAR, LAST_SEASON_YEAR = 2000, 2018  # start years of the first and last season
seasons = {
    f"{year % 100:02d}{(year + 1) % 100:02d}": [datetime(year, 5, 1), datetime(year + 1, 4, 30)]
    for year in range(FIRST_SEASON_YEAR, LAST_SEASON_YEAR + 1)
}

  metadata = pd.read_csv(metadata_file)


In [42]:
#split clusters per season
def determine_season(d, season_defs=None):
    """
    Return the key of the flu season whose [start, end] interval contains `d`.

    Parameters
    ----------
    d : datetime
        Collection date to classify.
    season_defs : dict, optional
        Mapping season key -> [start, end] datetimes; both bounds are inclusive.
        Defaults to the notebook-level `seasons` dict.

    Returns
    -------
    str or None
        The matching season key, or None if `d` lies outside every season.
    """
    if season_defs is None:
        season_defs = seasons
    # Inclusive on both ends: the previous `d < end` test pushed samples taken on
    # 30 April into the following season and dropped them for the last season.
    for key, bounds in season_defs.items():
        if bounds[0] <= d <= bounds[-1]:
            return key
    return None

# Isolate -> collection date, built once instead of scanning the whole metadata
# table for every isolate. keep="first" mirrors the previous `.values[0]` lookup
# when an Isolate_Id occurs more than once.
_first_rows = metadata.drop_duplicates(subset="Isolate_Id", keep="first")
isolate_dates = dict(zip(_first_rows["Isolate_Id"], _first_rows["Collection_Date"]))


def _collection_date(sid):
    """Return the collection date of isolate `sid` as a date-only datetime, or None if undated.

    Raises KeyError if `sid` has no metadata row (FASTA and metadata out of sync).
    """
    if sid not in isolate_dates:
        raise KeyError(f"{sid} from the CD-HIT clusters is missing from {metadata_file}")
    ts = isolate_dates[sid]
    if pd.isna(ts):
        return None
    return datetime(ts.year, ts.month, ts.day)


def _date_and_season(sid):
    """Return (date, season key) for `sid`; the season is None if undated or outside all seasons."""
    d = _collection_date(sid)
    return d, (determine_season(d) if d is not None else None)


clus_representatives = {s: {} for s in seasons}  # season -> cluster -> [eldest, youngest] ids
season_clus_ids = {s: {} for s in seasons}  # season -> cluster -> all member ids in that season
singles = {s: [] for s in seasons}  # season -> ids of single-member clusters

for c, ids in clusters.items():
    if len(ids) == 1:
        _, s = _date_and_season(ids[0])
        if s:
            singles[s].append(ids[0])
        continue

    # (id, date, season) for every member that falls inside a defined season
    members = []
    for sid in ids:
        d, cson = _date_and_season(sid)
        if cson is not None:
            members.append((sid, d, cson))

    # A cluster may span several seasons: pick the earliest and latest isolate
    # of each season's subset as that season's representatives.
    for cson in {m[2] for m in members}:
        cson_ids = [sid for sid, _, s in members if s == cson]
        cson_dates = [d for _, d, s in members if s == cson]

        eldest = cson_ids[cson_dates.index(min(cson_dates))]
        youngest = cson_ids[cson_dates.index(max(cson_dates))]

        season_clus_ids[cson][c] = cson_ids
        # eldest and youngest coincide when the subset has one member or one date
        clus_representatives[cson][c] = [eldest] if eldest == youngest else [eldest, youngest]



In [43]:
#get sequence files per batch
# Seasons are grouped into batches with one FASTA per batch. The grouping is
# subtype specific, so select it from `subtype` instead of hand-toggling
# commented lines (the H1N1pdm grouping was active for an H3N2 run, silently
# dropping seasons 0001-0809).
# NOTE(review): confirm the subtype string used for pandemic H1N1 runs matches the key below.
SEASON_SETS = {
    "H3N2": [["0001", "0102", "0203", "0304", "0405", "0506", "0607", "0708", "0809", "0910"],
             ["1011", "1112", "1213", "1314", "1415"],
             ["1516", "1617", "1718", "1819"]],
    "H1N1pdm": [["0910", "1011", "1112", "1213", "1314", "1415"],
                ["1516", "1617", "1718", "1819"]],
}
if subtype not in SEASON_SETS:
    raise KeyError(f"no season batches defined for subtype {subtype!r}; add it to SEASON_SETS")
season_sets = SEASON_SETS[subtype]

# Newick metacharacters; stripped so the tip labels survive tree building.
NEWICK_UNSAFE = [":", ",", "(", ")", "'"]

for s_set in season_sets:
    # cluster representatives plus single-member clusters of every season in the batch
    ids = set()
    for s in s_set:
        for rep_ids in clus_representatives[s].values():
            ids.update(rep_ids)
        ids.update(singles[s])

    records = []
    for r in SeqIO.parse(clean_fasta, "fasta"):
        if r.id.split("|")[0] not in ids:
            continue
        for ch in NEWICK_UNSAFE:
            if ch in r.id:
                r.id = r.id.replace(ch, "")
                r.name = r.name.replace(ch, "")
                r.description = r.description.replace(ch, "")
        records.append(r)

    out_fasta = f"{subtype}_{segment}_gisaid_{s_set[0]}_{s_set[-1]}.fasta"
    with open(os.path.join(cluster_dir, out_fasta), "w") as fw:
        SeqIO.write(records, fw, "fasta")

## Construct MSA
MSA via MAFFT with the following command: `mafft --auto --thread 3 --keeplength --addfragments {cluster_file} {reference} > {alignment_file}`  
The reference is removed afterwards: it is from 1968 and would distort the molecular-clock (root-to-tip) analysis.

In [44]:
#remove reference from alignment
# The reference comes first in each alignment; rewrite the file without it when
# its id contains "-ref". Files that no longer start with it are left untouched.
for fname in os.listdir(alignment_dir):
    if fname.startswith("."):
        continue  # hidden files such as .DS_Store
    path = os.path.join(alignment_dir, fname)
    aligned = list(SeqIO.parse(path, "fasta"))
    starts_with_reference = "-ref" in aligned[0].id
    if not starts_with_reference:
        continue
    with open(path, "w") as handle:
        SeqIO.write(aligned[1:], handle, "fasta")

## Construct tree
Exact precision is not needed and the trees should be built as fast as possible, so FastTree is used.  
Command used: `fasttree -gtr -nt {alignment} > {tree}`

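## TempEst
Get the dates for TempEst: TempEst takes the dates in the order of the tree, so each tip label is annotated with its collection date (`label|YYYY-MM-DD`).  
Load the annotated trees in TempEst → best-fitting root → use the root-to-tip and residual plots to find the outliers.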

In [45]:
#annotate tree files with dates
# For every FastTree output (*.tree, not yet annotated) append "|YYYY-MM-DD"
# to each tip label and save the result as *_annotated.tree.

# Isolate -> collection date, built once instead of scanning the metadata per
# leaf; keep="first" mirrors the previous `.values[0]` lookup.
tip_dates = (
    metadata.drop_duplicates(subset="Isolate_Id", keep="first")
    .set_index("Isolate_Id")["Collection_Date"]
)

for f in os.listdir(tree_dir):
    if not f.endswith(".tree") or "annotated" in f:
        continue
    tree_path = os.path.join(tree_dir, f)
    tree = dendropy.Tree.get(path=tree_path, schema="newick")
    # dendropy turns unquoted "_" into spaces when reading Newick; undo that so
    # the labels match the raw text of the tree file.
    labels = [leaf.taxon.label.replace(" ", "_") for leaf in tree.leaf_node_iter()]

    label_dates = {}
    for label in labels:
        rid = label.split("|")[0]
        if rid not in tip_dates.index:
            raise KeyError(f"tip {rid} in {f} has no row in the metadata")
        d = tip_dates[rid]
        # "NaT" keeps the previous output for isolates without a collection date
        date_str = d.strftime("%Y-%m-%d") if pd.notna(d) else "NaT"
        label_dates[label] = f"{label}|{date_str}"

    # FastTree writes the whole tree on a single line.
    with open(tree_path, "r") as fr:
        tree_line = fr.readline()

    for label, new_label in label_dates.items():
        tree_line = tree_line.replace(label, new_label)

    with open(os.path.join(tree_dir, f.replace(".tree", "_annotated.tree")), "w") as fw:
        fw.write(tree_line)

## Removing problematic clusters

In [46]:
# Isolates identified as outliers in the TempEst root-to-tip / residual plots.
tempest_outliers = [
    "EPI_ISL_286128",
    "EPI_ISL_498990",
    "EPI_ISL_309766",
    "EPI_ISL_498987",
    "EPI_ISL_12995401",
]

In [47]:
#make reverse list of clusters: isolate id -> CD-HIT cluster id
reversed_clusters = {sid: cluster for cluster, members in clusters.items() for sid in members}

# isolate id -> "cluster-season" for every member recorded in season_clus_ids
season_clus_ids_reversed = {
    sid: f"{cluster}-{season}"
    for season, per_cluster in season_clus_ids.items()
    for cluster, members in per_cluster.items()
    for sid in members
}

OUTLIER_REASON = "molecular clock outlier"
# Entries already in mco, as tuples, for O(1) duplicate checks.
already_flagged = {tuple(entry) for entry in mco}


def _flag_outlier(sid):
    """Append [sid, OUTLIER_REASON] to mco unless that exact entry is already present."""
    entry = [sid, OUTLIER_REASON]
    if tuple(entry) not in already_flagged:
        mco.append(entry)
        already_flagged.add(tuple(entry))


for outlier in tempest_outliers:
    if outlier in season_clus_ids_reversed:
        # Flag every member of the outlier's cluster within the same season.
        # rsplit: the season key is whatever follows the last "-".
        clus, cson = season_clus_ids_reversed[outlier].rsplit("-", 1)
        for cid in season_clus_ids[cson][clus]:
            _flag_outlier(cid)
    elif outlier in to_remove or outlier in reversed_clusters:
        # single-member or out-of-season cluster member, or already excluded earlier
        _flag_outlier(outlier)
    

In [48]:
#update outlier file
# Use a new name for the table: rebinding `mco` to a DataFrame would break the
# list-based outlier cell above if it were re-run.
mco_df = pd.DataFrame.from_records(mco, columns=["Isolate_Id", "reason"])
os.makedirs(os.path.dirname(molecular_clock_outliers), exist_ok=True)
mco_df.to_csv(molecular_clock_outliers, index=False)