## Reproducible strain selection based on 2-global AA consensus for Northern and Southern Hemispheres

**Assumptions:**
- Northern hemisphere: 
  - flu season starts in October of year X and last through April of subsequent year X+1, the vaccine for this season is picked in February of year X
  - 2011/2012 season is the cutoff from using HA1 sequence, using complete sequences afterwards
- Southern hemisphere: 
  - flu season starts in March of year X and last through September of year X, the vaccine for this season in picked in September of the preceding year X-1
  -  2011 season is the cutoff from using HA1 sequence, using complete sequences afterwards


## 0. General

### 0.1. Load Libraries

In [1]:
import os, dendropy, math, sys
import pandas as pd, numpy as np
from Bio import SeqIO
import calendar

from datetime import date, timedelta
from dateutil.relativedelta import relativedelta

sys.path.insert(1, "../scripts")

from labels import *


### 0.2 general variables

In [2]:
start_mature_protein = 17
start_mature_protein_nuc = 17*3

HA1_length_AA = 329 #in mature protein 
HA1_length_nuc= 329*3

protein_length = 567
sequence_length = 1701

In [3]:
region_hemispheres = {"us":"nh", "europe":"nh", "aunz":"sh"}
hemispheres = ["nh", "sh"]

although vaccine strain selection happens some time late february or september, setting vaccine strain selection to the first of the month for coding convienence > working with months mostly anyway

In [4]:
pp = 24 #preceding period in months
pdur = 16 #period duration
flu_seasons = {h:{} for h in hemispheres}
vaccine_selection = {h:{} for h in hemispheres}
for y in range(2002, 2024):
    #northern hemisphere flu season
    season = f"{y}-{y+1}"
    if y != 2023:
        flu_seasons["nh"][season] = (date(y, 10, 1), date(y+1, 4, 30))
        #northern hemisphere vaccine strain selection moment
        vaccine_selection["nh"][season] = date(y,2,1)  
    
    if y==2002:
        continue
    #southern hemisphere flu season
    flu_seasons["sh"][str(y)] = (date(y, 3, 1), date(y, 9, 30))
    #southern hemisphere vaccine strain selection moment
    vaccine_selection["sh"][str(y)] = date(y-1,9,1)

season_periods = {h:{} for h in hemispheres}
period_dates = {}
for h, sd in flu_seasons.items():
    for season, (fss, fse) in sd.items():
        vsd = vaccine_selection[h][season]
        #something with periods in treason doesn't seem to be correct but don't wanna waste time on that for now
        if h == "nh":
            ps, pe  = vsd + relativedelta(months=-pp), vsd + relativedelta(months=+pdur-1)
        else:
            ps, pe  = vsd + relativedelta(months=-pp), vsd + relativedelta(months=+pdur-1)

        
        period = f"{str(ps.year)[2:]}{'0'+str(ps.month) if len(str(ps.month))==1 else str(ps.month)}-{str(pe.year)[2:]}{'0'+str(pe.month) if len(str(pe.month))==1 else str(pe.month)}"
        season_periods[h][season] = period
        period_dates[period] = (ps, pe)

#cut off for HA1 
early_season_cutoff = {'nh':'2011-2012', 'sh':'2011'}
early_season_cutoff_dates = {h:flu_seasons[h][season][-1] for h, season in early_season_cutoff.items()}
early_seasons = [] #get list
for h, cutoff in early_season_cutoff.items():
    csons = list(flu_seasons[h].keys())
    for i, season in enumerate(csons):
        if i <= csons.index(cutoff):
            early_seasons.append(season)

### 0.3. Functions

In [5]:
def check_sequence_length(segment, sequence, mnlp=0.95, mnsl=None):
    """check if sequence is smaller than {mnlp}% of the reference sequence for the segment"""
    #for reference_1968 
    reference_lengths = {'PB2':2280,'PB1':2274,'PA':2151,'HA':1701,'NP':1497,'NA':1410,'M':890,'NS':800}
    if mnsl is None:
        if len(sequence) < np.round(reference_lengths[segment]*mnlp):
            return False 
        else:
            return True
    else:
        if len(sequence) < mnsl:
            return False
        else:
            return True

def check_max_ambig(sequence, mxa=0.01):
    """check if sequence has ambiguous nucleotide % greater than the mxa"""
    #also checking if there aren't any illegal characters present
    valid_nucs = ['A','T','C','G','R','Y','B','D','K','M','H','V','S','W','N']
    if not all(i in valid_nucs for i in sequence.upper()):
        return False

    N_ambig = sequence.upper().count("N")
    
    if round(N_ambig/len(sequence),2) > mxa:
        return False
    else:
        return True

def get_consensus_sequence(sequences):
    seqs = []
    for seq in sequences:
        if type(seq) == list:
            seqs.append(seq)
        else:
            seqs.append(list(seq))
    seqs = pd.DataFrame.from_records(seqs)#.reset_index(drop=)
    consensus = seqs.mode().iloc[0,:].to_list()
    return "".join(consensus)

## 1. Data prep 
prepare raw GISAID data 


### 1.1. Get data

In [6]:
gisaid_raw_data_folder = "../data/gisaid_data/raw_downloads/H3N2_HA"
outliers = pd.read_csv("../data/outliers/H3N2_HA_gitr.csv")["Isolate_Id"].to_list()


In [7]:
#get raw data
sequences, metadata= {}, []
for f in os.listdir(gisaid_raw_data_folder):
    if "010199" in f:
        continue #dont need data from 99
    if f.endswith(".fasta"):
        for r in SeqIO.parse(os.path.join(gisaid_raw_data_folder, f), "fasta"):
            sequences[r.id] = r.seq
    elif f.endswith(".xls"):
        try:
            metadata = pd.concat([metadata, pd.read_excel(os.path.join(gisaid_raw_data_folder, f), usecols=metcols)])
            metadata = metadata.drop_duplicates().reset_index(drop=True)
        except:
            metadata = pd.read_excel(os.path.join(gisaid_raw_data_folder, f), usecols=metcols)

#parse collection date as date
metadata["Collection_Date"] =pd.to_datetime(metadata["Collection_Date"], format="%Y-%m-%d")

#deterime passage history 
for i, row in metadata.iterrows():
    ph = row["Passage_History"]
    if ph in cpl:
        metadata.loc[i, "Passage_History"] = "clinical"
    elif ph in cbpl:
        metadata.loc[i, "Passage_History"] = "cell-based MDCK or SIAT"
    elif ph in epl:
        metadata.loc[i, "Passage_History"] = "egg-based"
    elif ph in ocbpl:
        metadata.loc[i, "Passage_History"] = "cell-based other"
    else:
        metadata.loc[i, "Passage_History"] = "unknown or unclear"

#determine hemisphere and country
for i, row in metadata.iterrows():
    l = row["Location"]
    #determine country
    try:
        country = l.split(" / ")[1]
    except:
        if l == "Sudan, South":
            country == "sudan"
        else:
            country = pd.NA
    country = cs[country] if country in cs.keys() and not pd.isna(country) else country
    if not pd.isna(country) and country.upper() == country:
        country = country.lower().capitalize()
    metadata.loc[i, "country"] = country
    
    #determine hemisphere
    if not pd.isna(country):
        if country in nhc:
            metadata.loc[i, "hemisphere"] = "northern"
        elif country in shc:
            metadata.loc[i, "hemisphere"] = "southern"
        else:
            print (country)
            metadata.loc[i, "hemisphere"] = pd.NA
    else:
        metadata.loc[i, "hemisphere"] = pd.NA

#sort by collection date
metadata = metadata.sort_values(["Collection_Date"]).reset_index(drop=True)

#remove sequences with unknown country 
print (f"total number of sequences downloaded from GISAID: {len(metadata)}")
print (f"number of sequence with unknown country of origin {len(metadata) - len(metadata.dropna(subset='country'))}.\t (Sequences were removed)")
metadata = metadata.dropna(subset='country').reset_index(drop=True)     

found_outliers = []
for sid in metadata["Isolate_Id"].unique():
    if sid in outliers:
        found_outliers.append(sid)
        continue
print (f"number of sequences that are known outliers: {len(found_outliers)}.\t (Sequences were removed)")
# check for valid nucleotides in sequences
valid_nucs = ['A','T','C','G','R','Y','B','D','K','M','H','V','S','W','N']
weird_nucs = []
for sid, seq in sequences.items():
    if not all(i in valid_nucs for i in seq.upper()):
        weird_nucs.append(sid.split("|")[0])
print (f"number of sequences with illegal characters in their sequence: {len(weird_nucs)}.\t (Sequences were removed)")
# metadata = metadata[~metadata["Isolate_Id"].isin(found_outliers + weird_nucs)]
metadata = metadata[~metadata["Isolate_Id"].isin(weird_nucs)]


total number of sequences downloaded from GISAID: 151892
number of sequence with unknown country of origin 20.	 (Sequences were removed)
number of sequences that are known outliers: 3014.	 (Sequences were removed)
number of sequences with illegal characters in their sequence: 123.	 (Sequences were removed)


### 1.2. Align sequences
check how many sequences there are with complete (HA1) nucleotide sequence 

In [8]:
#make alignment 
alignment = "../data/gisaid_data/gisaid_2000_2023_H3N2_HA_aligment.fasta"
threads = 12
ref_ha = "../data/references/H3N2_HA.fasta"

if not os.path.isfile(alignment):
    seq_file = "../data/gisaid_data/gisaid_2000_2023_H3N2_HA_filtered.fasta"
    with open(seq_file, "w") as fw:
        for sid, seq in sequences.items():
            fw.write(f">{sid}\n{seq}\n")

    cmd = ['mafft', '--auto', '--thread', str(threads), '--keeplength', '--addfragments', seq_file , ref_ha,'>', alignment]
    os.system(" ".join(cmd))

aligned_seqs = {r.id.split("|")[0]:r.seq for r in list(SeqIO.parse(alignment, "fasta"))[1:]} 
aligned_ref = list(SeqIO.parse(alignment, "fasta"))[0].seq

In [9]:
#get sequences that cover HA1
HA1_seqids = []
for sid, seq in aligned_seqs.items():
    HA1_seq = seq[start_mature_protein_nuc -1:start_mature_protein_nuc +HA1_length_nuc-1]
    if check_max_ambig(HA1_seq, mxa=0.01):
        HA1_seqids.append(sid)

In [10]:
#get sequence with complete (>=95%) sequences
complete_seqids = []
for sid, seq in aligned_seqs.items():
    if check_max_ambig(seq[:1701], mxa=0.01) and check_sequence_length("HA", seq[:1701]):
        complete_seqids.append(sid)

In [11]:
#print results
print (f"number of sequences with complete HA1 nucleotide sequence (len=948) and % ambiguous nucleotides < 1% : {len(HA1_seqids)}")
print (f"\t- clinical samples: {len(metadata[(metadata['Passage_History']=='clinical')&(metadata['Isolate_Id'].isin(HA1_seqids))])}")
print (f"\t- cell-based (MDCK or SIAT) samples: {len(metadata[(metadata['Passage_History']=='cell-based MDCK or SIAT')&(metadata['Isolate_Id'].isin(HA1_seqids))])}")
print (f"\t- cell-based (other) samples: {len(metadata[(metadata['Passage_History']=='cell-based other')&(metadata['Isolate_Id'].isin(HA1_seqids))])}")
print (f"\t- egg-based samples: {len(metadata[(metadata['Passage_History']=='egg-based')&(metadata['Isolate_Id'].isin(HA1_seqids))])}")
print (f"\t- unknown or unclear samples: {len(metadata[(metadata['Passage_History']=='unknown or unclear')&(metadata['Isolate_Id'].isin(HA1_seqids))])}")

print (f"number of sequences with complete nucleotide sequence (len=1701) and % ambiguous nucleotides < 1% : {len(complete_seqids)}")
print (f"\t- clinical samples: {len(metadata[(metadata['Passage_History']=='clinical')&(metadata['Isolate_Id'].isin(complete_seqids))])}")
print (f"\t- cell-based (MDCK or SIAT) samples: {len(metadata[(metadata['Passage_History']=='cell-based MDCK or SIAT')&(metadata['Isolate_Id'].isin(complete_seqids))])}")
print (f"\t- cell-based (other) samples: {len(metadata[(metadata['Passage_History']=='cell-based other')&(metadata['Isolate_Id'].isin(complete_seqids))])}")
print (f"\t- egg-based samples: {len(metadata[(metadata['Passage_History']=='egg-based')&(metadata['Isolate_Id'].isin(complete_seqids))])}")
print (f"\t- unknown or unclear samples: {len(metadata[(metadata['Passage_History']=='unknown or unclear')&(metadata['Isolate_Id'].isin(complete_seqids))])}")

number of sequences with complete HA1 nucleotide sequence (len=948) and % ambiguous nucleotides < 1% : 137599
	- clinical samples: 62706
	- cell-based (MDCK or SIAT) samples: 32642
	- cell-based (other) samples: 10763
	- egg-based samples: 1104
	- unknown or unclear samples: 30384
number of sequences with complete nucleotide sequence (len=1701) and % ambiguous nucleotides < 1% : 121410
	- clinical samples: 59143
	- cell-based (MDCK or SIAT) samples: 30258
	- cell-based (other) samples: 10024
	- egg-based samples: 996
	- unknown or unclear samples: 20989


In [12]:
# add to metadata
for i, row in metadata.iterrows():
    seqid = row["Isolate_Id"]
    if seqid in complete_seqids:
        metadata.loc[i, "sequence_length"] = "complete"
    elif seqid in HA1_seqids:
        metadata.loc[i, "sequence_length"] = "HA1"
    else:
        metadata.loc[i, "sequence_length"] = "incomplete"

metadata = metadata[metadata["sequence_length"]!="incomplete"]

### 2. Reproducible strain selection based on global AA consensus

In [13]:
#filter metadata on passage history
passage_labels = ['clinical', 'cell-based MDCK or SIAT', 'cell-based other']
metadata = metadata[metadata["Passage_History"].isin(passage_labels)]

In [14]:
later_selection_delay = 3 #months

global_consensus= []
b = False
for region, h in region_hemispheres.items():
    for season, (fss, fse) in flu_seasons[h].items():

        vaccine_strain_selection = vaccine_selection[h][season]

        #reproducible selection at current timing
        cts = vaccine_strain_selection-relativedelta(months=2)
        cte= vaccine_strain_selection-relativedelta(days=1)

        lab = f"{calendar.month_abbr[cts.month]}{str(cts.year)[-2:]}-{calendar.month_abbr[cte.month]}{str(cte.year)[-2:]}"

        seq_ids = metadata[(metadata["Collection_Date"]>=pd.to_datetime(cts))&(metadata["Collection_Date"]<=pd.to_datetime(cte))]["Isolate_Id"].tolist()
        seqs = [v[:1701].replace("-", "n").translate() for k, v in aligned_seqs.items() if k.split("|")[0] in seq_ids]

        if season in early_seasons:
            seqs = [s[start_mature_protein-1:HA1_length_AA+start_mature_protein-1] for s in seqs]

        consensus = get_consensus_sequence(seqs)
        observed = False
        for s in seqs:
            if str(s)==consensus:
                observed=True
                break


        global_consensus.append([region, season, lab, "WHO-timing",consensus, observed])

        #reproducible selection at delayed timing 
        dts = cts + relativedelta(months=later_selection_delay)
        dte= cte + relativedelta(months=later_selection_delay)

        lab = f"{calendar.month_abbr[dts.month]}{str(dts.year)[-2:]}-{calendar.month_abbr[dte.month]}{str(dte.year)[-2:]}"

        seq_ids = metadata[(metadata["Collection_Date"]>=pd.to_datetime(dts))&(metadata["Collection_Date"]<=pd.to_datetime(dte))]["Isolate_Id"].tolist()
        seqs = [v[:1701].replace("-", "n").translate() for k, v in aligned_seqs.items() if k.split("|")[0] in seq_ids]

        if season in early_seasons:
            seqs = [s[start_mature_protein-1:HA1_length_AA+start_mature_protein-1] for s in seqs]

        consensus = get_consensus_sequence(seqs)
        observed = False
        for s in seqs:
            if str(s)==consensus:
                observed=True
                break

        global_consensus.append([region, season, lab, "delayed-timing",consensus, observed])



global_consensus = pd.DataFrame.from_records(global_consensus, columns=["region", "season", "months", "timing", "sequence", "observed in global data"])      


In [15]:
global_consensus_file = "../data/reproducible_selection_strains.fasta"

with open(global_consensus_file, "w") as fw:
    for i, row in global_consensus.iterrows():
        fw.write(f">{row['region']}_{row['season']}_{row['months']}_{row['timing']}\n{row['sequence']}\n")