## Antigenic comparison between the dominant strain, the vaccine strain and the reproducibly selected strains

## 0. General

### 0.1. Load libraries

In [1]:
import os, dendropy, math, sys, json
import pandas as pd, numpy as np
from Bio import SeqIO
import calendar
from collections import deque
from scipy.spatial.distance import cdist

from datetime import date, timedelta
from dateutil.relativedelta import relativedelta

sys.path.insert(1, "../scripts")

from labels import *

### 0.2. General variables

In [2]:
# Coordinates of the HA coding sequence / protein used throughout the notebook.
start_mature_protein = 17  # 1-based residue at which the mature protein starts (earlier residues are sliced off as "signal protein")
HA1_length_AA = 329  # HA1 length in residues, counted in the mature protein
protein_length = 567  # residues kept from the translated coding sequence
sequence_length = 3 * protein_length  # nucleotides in the coding sequence (1701)

In [3]:
# hemisphere code ("nh" = northern, "sh" = southern) of every analysed region
region_hemispheres = dict(us="nh", europe="nh", aunz="sh")
hemispheres = ["nh", "sh"]

In [4]:
# Build, per hemisphere, the influenza seasons, the vaccine strain selection
# dates and the period label/dates that belong to every season.
pp = 24 #preceding period in months (a period starts pp months before the vaccine selection date)
pdur = 16 #period duration (a period ends pdur-1 months after the vaccine selection date)
flu_seasons = {h: {} for h in hemispheres}
vaccine_selection = {h: {} for h in hemispheres}
for y in range(2002, 2024):
    season = f"{y}-{y+1}"
    if y != 2023:
        #northern hemisphere flu season: 1 October of year y until 30 April of year y+1
        flu_seasons["nh"][season] = (date(y, 10, 1), date(y + 1, 4, 30))
        #northern hemisphere vaccine strain selection moment
        vaccine_selection["nh"][season] = date(y, 2, 1)

    if y != 2002:
        #southern hemisphere flu season: 1 March until 30 September of year y
        flu_seasons["sh"][str(y)] = (date(y, 3, 1), date(y, 9, 30))
        #southern hemisphere vaccine strain selection moment (September of the preceding year)
        vaccine_selection["sh"][str(y)] = date(y - 1, 9, 1)

# season -> period label ("YYMM-YYMM") and period label -> (start date, end date)
season_periods = {h: {} for h in hemispheres}
period_dates = {}
for h, sd in flu_seasons.items():
    for season, (fss, fse) in sd.items():
        vsd = vaccine_selection[h][season]
        # NOTE (from the original author): the periods assigned to the seasons may not be
        # entirely correct; this has not been investigated yet. Both hemispheres use the
        # same offsets relative to the vaccine selection date.
        ps = vsd + relativedelta(months=-pp)
        pe = vsd + relativedelta(months=pdur - 1)

        # two-digit year followed by zero-padded month, e.g. "0002-0305"
        period = f"{ps:%y%m}-{pe:%y%m}"
        season_periods[h][season] = period
        period_dates[period] = (ps, pe)

#cut off for HA1: seasons up to and including these are treated as "early" seasons
early_season_cutoff = {'nh': '2011-2012', 'sh': '2011'}
# end date of the cut-off season per hemisphere
early_season_cutoff_dates = {h: flu_seasons[h][season][-1] for h, season in early_season_cutoff.items()}
early_seasons = [] #season labels of both hemispheres up to and including the cut-off
for h, cutoff in early_season_cutoff.items():
    csons = list(flu_seasons[h])
    early_seasons.extend(csons[:csons.index(cutoff) + 1])

In [5]:
# Positions belonging to each epitope site (keys "A"-"E"). get_mutations() compares
# these against the "mature position" of a residue, so they are used as 1-based
# positions in mature-protein numbering.
epitope_sites= {"A":[122,124,126,130,131,132,133,135,137,138,140,142,143,144,145,146,150,152,168], 
                "B":[128,129,155,156,157,158,159,163,165,186,187,188,189,190,192,193,194,196,197,198], 
                "C":[44,45,46,47,48,50,51,53,54,273,275,276,278,279,280,294,297,299,300,304,305,307,308,309,310,311,312], 
                "D":[96,102,103,117,121,167,170,171,172,173,174,175,176,177,179,182,201,203,207,208,209,212,213,214,215,216,217,218,219,226,227,228,229,230,238,240,242,244,246,247,248,], 
                "E":[57,59,62,63,67,75,78,80,81,82,83,86,87,88,91,92,94,109,260,261,262,265]}
# flat list of all epitope positions, in site order A..E
epitope_positions = [s for sites in epitope_sites.values() for s in sites]

#rbs positions (presumably "receptor binding site" - TODO confirm); same numbering as
#the epitope positions. NOTE(review): 190 is listed twice; harmless for the membership
#tests in get_mutations() but it does affect len(rbs_positions).
rbs_positions = [98, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 144, 145, 146, 153, 154, 155, 156, 157, 158, 159,
                 183, 184, 185, 186, 187, 188, 189, 190, 190, 191, 192, 193, 194, 195, 196, 219, 220, 221, 222, 223, 224, 
                 225, 226, 227, 228,]
#rbs_positions = [131, 132, 133, 134, 135, 136, 137, 138, 140, 183, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228,]
# NOTE(review): not referenced in the visible part of the notebook; the name presumably
# refers to a published set of key antigenic positions - confirm the source.
koel_sites = [145, 189, 193, 156, 159, 158, 155]

#one-letter codes of the 20 standard amino acids, used for coding
aa_list = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]

### 0.3. functions

In [6]:
def get_most_antigenic_similar_sequences(seq, seqdict):
    """Find the sequences in seqdict that differ least from seq at the epitope positions.

    Only the 1-based positions listed in the module-level ``epitope_positions``
    are compared (residue ``p`` is read at index ``p-1``).

    Returns a tuple ``(names, n)``: the names of all sequences that share the
    lowest number of differing epitope positions (in seqdict order) and that
    number. The search starts at ``len(seq)``, so ``([], len(seq))`` is
    returned for an empty seqdict.
    """
    best = len(seq)
    closest = []

    for name, other in seqdict.items():
        mismatches = sum(1 for pos in epitope_positions if seq[pos - 1] != other[pos - 1])
        if mismatches > best:
            continue
        if mismatches < best:
            # strictly better: forget the previous candidates
            best = mismatches
            closest = []
        closest.append(name)

    return closest, best

def get_most_similar_sequences(seq, seqdict):
    """Find the sequences in seqdict with the fewest differences from seq.

    Every index of seq is compared with the same index of the candidate, so
    candidates must be at least as long as seq.

    Returns a tuple ``(names, n)``: the names of all sequences that share the
    lowest number of differing positions (in seqdict order) and that number.
    The search starts at ``len(seq)``, so ``([], len(seq))`` is returned for
    an empty seqdict.
    """
    best = len(seq)
    closest = []

    for name, other in seqdict.items():
        mismatches = sum(1 for idx, residue in enumerate(seq) if residue != other[idx])
        if mismatches > best:
            continue
        if mismatches < best:
            # strictly better: forget the previous candidates
            best = mismatches
            closest = []
        closest.append(name)

    return closest, best

def can_merge_tables(tables, merge_dict):
    """Tell whether all tables can be reached from the first one.

    ``merge_dict`` maps a table to the tables it can be merged with. Starting
    at ``tables[0]`` the links are followed breadth-first, using only tables
    that are themselves in ``tables``. Returns True when every table in
    ``tables`` was reached this way, False otherwise.
    """
    start = tables[0]
    reached = {start}
    pending = deque([start])

    while pending:
        table = pending.popleft()
        for linked in merge_dict.get(table, []):
            if linked in tables and linked not in reached:
                reached.add(linked)
                pending.append(linked)

    return reached == set(tables)

def count_items_in_tables(tables, table_lists, table_strains):
    """Count, per item list, how many of its items occur in the given tables.

    ``table_lists`` and ``table_strains`` are paired element-wise: each entry
    of ``table_strains`` maps a table name to the items it contains, and the
    matching entry of ``table_lists`` holds the items to look for. Tables
    missing from a mapping contribute nothing. Returns one count (of distinct
    items) per pair.
    """
    counts = []
    for wanted, per_table in zip(table_lists, table_strains):
        present = set()
        present = present.union(*(per_table.get(table, []) for table in tables))
        counts.append(len(present & set(wanted)))
    return counts

def write_antigens(fw, mapcode, ag_list, ag_name):
    """Write an R statement that flags the antigens of a map that are in ag_list.

    The statement has the form
    ``<ag_name> <- agNames(map<mapcode>) %in% c(...)`` and is written to the
    file-like object ``fw``.

    * one antigen: written on one line, double-quoted;
    * up to 10 antigens: written on one line using Python's tuple formatting
      (so the names are normally single-quoted and separated by ", ");
    * more than 10 antigens: double-quoted names, 10 per line, with every
      continuation line indented by a tab.

    ``ag_list`` must be a list (it is sliced).
    """
    prefix = f"{ag_name} <- agNames(map{mapcode}) %in% c"

    if len(ag_list) == 1:
        fw.write(f'{prefix}("{ag_list[0]}")\n')
        return
    if len(ag_list) <= 10:
        fw.write(f"{prefix}{tuple(ag_list)}\n")
        return

    # more than 10 antigens: at least two chunks of at most 10 quoted names each
    chunks = [
        ",".join(f'"{ag}"' for ag in ag_list[start:start + 10])
        for start in range(0, len(ag_list), 10)
    ]
    fw.write(f"{prefix}({chunks[0]},\n")
    for chunk in chunks[1:-1]:
        fw.write(f"\t{chunk},\n")
    fw.write(f"\t{chunks[-1]})\n")

def centroid(points):
    """Return ``[mean_x, mean_y]`` of the points.

    Only the first two coordinates of every point are used. An empty
    ``points`` raises ZeroDivisionError.
    """
    sum_x = sum(point[0] for point in points)
    sum_y = sum(point[1] for point in points)
    n_points = len(points)
    return [sum_x / n_points, sum_y / n_points]

def get_mutations(ref, seq, sl="complete"):
    """List the amino acid differences between ref and seq as a DataFrame.

    Parameters
    ----------
    ref, seq : aligned amino acid sequences; seq must be at least as long as ref.
    sl : "HA1" when both sequences contain only HA1 (index 0 is mature
        position 1); any other value (default "complete") when they are
        full-length proteins that start with the signal peptide.

    Positions where either sequence has "X" are skipped.

    Returns
    -------
    DataFrame with one row per difference and the columns
    "actual strain AA" (residue in ref), "antigen strain AA" (residue in seq),
    "position" (1-based, full-length numbering), "mature position" (1-based,
    mature-protein numbering), "protein", "epitope" (bool), "epitope site"
    (site letter or pd.NA) and "RBS" (bool).
    """
    muts = []
    for i, b in enumerate(ref):
        if seq[i] == "X" or b == "X":
            continue
        if seq[i] == b:
            continue

        if sl == "HA1":
            mpos = i + 1
            pos = start_mature_protein + i
            protein = "HA1"
        else: # if sl=="complete"
            pos = i + 1
            mpos = pos - start_mature_protein + 1
            if pos < start_mature_protein:
                protein = "signal protein"
            elif mpos <= HA1_length_AA:
                # HA1_length_AA is counted in the mature protein, so the HA1/HA2 boundary
                # has to be tested on the mature position, not on the full-length position
                protein = "HA1"
            else:
                protein = "HA2"

        ep = mpos in epitope_positions
        rbs = mpos in rbs_positions
        site = pd.NA
        if ep:
            for s, pos_list in epitope_sites.items():
                if mpos in pos_list:
                    site = s
        muts.append([b, seq[i], pos, mpos, protein, ep, site, rbs])

    muts = pd.DataFrame.from_records(muts, columns=["actual strain AA", "antigen strain AA", "position", "mature position", "protein", "epitope", "epitope site", "RBS"])
    return muts

## 1. Get data

### 1.1. WHO vaccine strains

In [7]:
#vaccine strains per hemisphere per season year
#egg-based strain recommendation as listed on the WHO website
#structure: hemisphere ("nh"/"sh") -> season label -> strain name; season labels are
#"YYYY-YYYY" for the northern and "YYYY" for the southern hemisphere, matching the
#keys of flu_seasons (this table also lists seasons outside the range built there)
vaccine_strain_names = {'nh': {'2024-2025':'A/Thailand/8/2022', '2023-2024':'A/Darwin/9/2021', '2022-2023':'A/Darwin/9/2021', '2021-2022':'A/Cambodia/e0826360/2020',
                               '2020-2021':'A/Hong Kong/2671/2019', '2019-2020':'A/Kansas/14/2017', '2018-2019':'A/Singapore/INFIMH-16-0019/2016', '2017-2018':'A/Hong Kong/4801/2014',
                               '2016-2017':'A/Hong Kong/4801/2014','2015-2016':'A/Switzerland/9715293/2013','2014-2015':'A/Texas/50/2012','2013-2014':'A/Texas/50/2012',
                               '2012-2013':'A/Victoria/361/2011','2011-2012':'A/Perth/16/2009','2010-2011':'A/Perth/16/2009','2009-2010':'A/Brisbane/10/2007',
                               '2008-2009':'A/Brisbane/10/2007','2007-2008':'A/Wisconsin/67/2005','2006-2007':'A/Wisconsin/67/2005','2005-2006':'A/California/7/2004',
                               '2004-2005':'A/Fujian/411/2002','2003-2004':'A/Moscow/10/99','2002-2003':'A/Moscow/10/99','2001-2002':'A/Moscow/10/99','2000-2001':'A/Moscow/10/99'},
                        'sh': {'2024':'A/Thailand/8/2022','2023':'A/Darwin/9/2021','2022':'A/Darwin/9/2021','2021':'A/Hong Kong/2671/2019','2020':'A/South Australia/34/2019',
                               '2019':'A/Switzerland/8060/2017','2018':'A/Singapore/INFIMH-16-0019/2016','2017':'A/Hong Kong/4801/2014','2016':'A/Hong Kong/4801/2014',
                               '2015':'A/Switzerland/9715293/2013','2014':'A/Texas/50/2012','2013':'A/Victoria/361/2011','2012':'A/Perth/16/2009','2011':'A/Perth/16/2009',
                               '2010':'A/Perth/16/2009','2009':'A/Brisbane/10/2007','2008':'A/Brisbane/10/2007','2007':'A/Wisconsin/67/2005','2006':'A/California/7/2004',
                               '2005':'A/Wellington/1/2004','2004':'A/Fujian/411/2002','2003':'A/Moscow/10/99','2002':'A/Moscow/10/99','2001':'A/Moscow/10/99',
                               '2000':'A/Moscow/10/99'}
                        }

vaccine_strain_dir = "../data/vaccine_strains/H3N2"

# 0-based slice boundaries of HA1 in the protein and in the coding sequence.
# Residue k (1-based) is encoded by nucleotides [(k-1)*3, k*3), so the nucleotide
# boundaries are the amino acid boundaries times three; this keeps the nucleotide
# segments in frame with the amino acid segments.
HA1_aa_start = start_mature_protein - 1
HA1_aa_end = HA1_aa_start + HA1_length_AA
HA1_nuc_start, HA1_nuc_end = HA1_aa_start * 3, HA1_aa_end * 3

# strain name -> {"nuc"/"aa" -> {"complete", "signal protein", "HA1", "HA2", "mature protein"}}
vaccine_strains = {}
for f in os.listdir(vaccine_strain_dir):
    # the file name (without extension) is the strain name with "_" instead of "/";
    # in the second field "-" stands for a space
    strain = f.split(".")[0].split("_")
    strain[1] = strain[1].replace("-", " ")
    strain = "/".join(strain)

    #get full cds and amino acid sequence (first record of the file only)
    cds = next(SeqIO.parse(os.path.join(vaccine_strain_dir, f), "fasta")).seq[:sequence_length]
    aa_seq = cds.translate()
    aa_seq = aa_seq[:protein_length]

    #get proteins - signal protein
    signal_prot_nuc = cds[:HA1_nuc_start]
    signal_prot_aa = aa_seq[:HA1_aa_start]
    #HA1
    HA1_nuc = cds[HA1_nuc_start:HA1_nuc_end]
    HA1_aa = aa_seq[HA1_aa_start:HA1_aa_end]
    #HA2
    HA2_nuc = cds[HA1_nuc_end:]
    HA2_aa = aa_seq[HA1_aa_end:]

    vaccine_strains[strain] = {"nuc":{"complete":cds, "signal protein":signal_prot_nuc, "HA1":HA1_nuc, "HA2":HA2_nuc, "mature protein":HA1_nuc+HA2_nuc},
                               "aa":{"complete":aa_seq, "signal protein":signal_prot_aa, "HA1":HA1_aa, "HA2":HA2_aa, "mature protein":HA1_aa+HA2_aa}}


### 1.2. Dominant strain

In [8]:
# Dominant strain sequence per region and season; the record ids of the FASTA file
# start with "<region>_<season>".
dominant_strain_file = '../data/dominant_strains.fasta'
dominant_strains = []
for r in SeqIO.parse(dominant_strain_file, "fasta"):
    id_fields = r.id.split("_")
    region = id_fields[0]
    season = id_fields[1]
    dominant_strains.append([region, season, r.seq])

# table indexed by (region, season)
dominant_strains = pd.DataFrame.from_records(dominant_strains, columns=["region", "season", "sequence"])
dominant_strains = dominant_strains.set_index(["region", "season"]).sort_index()

### 1.3. Reproducible selection strains

In [9]:
# Reproducibly selected strain sequences; the record ids of the FASTA file are
# exactly "<region>_<season>_<months>_<timing>".
reproducible_selection_strain_file = "../data/reproducible_selection_strains.fasta"
reproducible_selection_strains = []
for r in SeqIO.parse(reproducible_selection_strain_file, "fasta"):
    id_fields = r.id.split("_")
    region, season, months, timing = id_fields
    reproducible_selection_strains.append([region, season, months, timing, r.seq])

# table indexed by (region, season, timing)
reproducible_selection_strains = pd.DataFrame.from_records(
    reproducible_selection_strains,
    columns=["region", "season", "months", "timing", "sequence"],
)
reproducible_selection_strains = reproducible_selection_strains.set_index(["region", "season", "timing"]).sort_index()

### 1.4. HI data

In [10]:
#load mill hill data
#Every sheet of the workbook is one HI table: rows are antigens (first column),
#columns are antisera plus the optional general columns "date" and "passage".
HI_file = pd.ExcelFile("../data/HI_data/HI_titer_extracted.xlsx")
general_columns = ["date", "passage"]
known_passage = ["Cell", "SIAT", "MDCK", "Egg"]


def _antigen_label(strain, row, columns):
    """Return "<strain>|<date>|<passage>" for one table row ("*" for unknown fields)."""
    fields = [strain]
    for gc in general_columns:
        if gc not in columns or pd.isna(row[gc]) or row[gc] == "unknown":
            fields.append("*")
        elif gc == "date":
            fields.append(str(row[gc].date()))
        else:
            fields.append(str(row[gc]))
    return "|".join(fields)


def _antiserum_label(column):
    """Return "<strain>|<passage>|<other information>" parsed from an antiserum column header.

    The header is split on spaces. A second word that contains "/" and does not
    start with "F" is taken to be part of the strain name (strain name with a
    space). A following word from known_passage is the passage; everything else
    is kept as other information (e.g. ferret number). Missing parts become "*".
    """
    parts = column.split(" ")
    if len(parts) > 1:
        #check if there was a space in the strain name
        if not parts[1].startswith("F") and "/" in parts[1]:
            parts = [f"{parts[0]} {parts[1]}"] + parts[2:]
        parts = [v for v in parts if len(v) > 0]

    if len(parts) == 1: #no additional information (cell type or ferret number)
        return f"{parts[0]}|*|*"
    #determine if there is a passage citation
    if parts[1] not in known_passage:
        return f"{parts[0]}|*|" + " ".join(parts[1:])
    return f"{parts[0]}|{parts[1]}|" + " ".join(parts[2:])


def _normalise_titre(titre):
    """Map not-done/not-tested/missing titres to "*" and a bare "<" to "<40"."""
    if titre == "ND" or titre == "NT" or pd.isna(titre):
        return "*"
    if titre == "<":
        return "<40"
    return titre


HI_titers = []
for sn in HI_file.sheet_names:
    df = HI_file.parse(sn, index_col=0)
    # the antiserum headers only have to be parsed once per table
    serum_columns = [(column, _antiserum_label(column)) for column in df.columns if column not in general_columns]
    for strain, row in df.iterrows():
        antigen = _antigen_label(strain, row, df.columns)
        for column, antiserum in serum_columns:
            HI_titers.append([antigen, antiserum, _normalise_titre(row[column]), sn])

# long format: one row per antigen/antiserum combination; "table" is the sheet name
HI_titers = pd.DataFrame.from_records(HI_titers, columns=["antigen", "antiserum",  "HI titer", "table"])


In [11]:
# table name -> table type; looked up with the "table" column of HI_titers further on
with open("../data/HI_data/HI_table_types.json") as f:
    HI_table_types = json.load(f)

In [12]:
#assign table type, strain, passage type

# passage label (last "|"-separated field of the antigen) -> passage type
passage_labels = {"egg":epl, "cell":cbpl+ocbpl, "unknown":[' ', '0_1', '1.Pa(RKI)_1', '1_1', '320', 'HGR NIB']}


def _passage_type(psl):
    """Return the passage type of a passage label.

    Labels listed in passage_labels get the matching type; "*" (no label) is
    "unknown"; any other label stays missing (NaN).
    """
    for p, pl in passage_labels.items():
        if psl in pl:
            return p
    return "unknown" if psl == "*" else np.nan


# build every column in one go instead of writing the cells row by row
antigen_fields = [antigen.split("|") for antigen in HI_titers["antigen"]]
HI_titers["type"] = [HI_table_types[table] for table in HI_titers["table"]]
# strain name without the "*" characters
HI_titers["strain"] = [fields[0].replace("*", "") for fields in antigen_fields]
HI_titers["passage"] = [_passage_type(fields[-1]) for fields in antigen_fields]



In [13]:
# NOTE(review): presumably describes which HI tables can be merged with each other;
# not used in the visible part of the notebook - confirm against its later use
with open("../data/HI_data/HI_titer_merge.json") as f:
    HI_table_merge = json.load(f)

#### 1.4.1. individual HI tables 
These are needed to construct the antigenic maps.

In [14]:
#to write 
#one csv per HI table: antigens as rows, antisera as columns, HI titers as values
outpath = "../data/HI_data/individual_HI_tables"
os.makedirs(outpath, exist_ok=True)
for table in HI_titers["table"].unique():
    # copy, so that the clean-up below does not write into a slice of HI_titers
    df = HI_titers.loc[HI_titers["table"] == table, ["antigen", "antiserum", "HI titer"]].copy()
    # commas would break the csv layout
    df["antigen"] = [i.replace(",","") for i in df["antigen"]]
    df["antiserum"] = [i.replace(",","") for i in df["antiserum"]]

    try:
        df = df.pivot(index="antigen", columns="antiserum", values="HI titer")
    except ValueError:
        # pivot raises ValueError when an antigen/antiserum combination occurs more than
        # once; such tables are reported and skipped. Other errors (e.g. when writing the
        # file) are no longer hidden behind this message.
        print ("double antigens or antiserums in", table)
        continue

    df.index.name = None
    df.to_csv(os.path.join(outpath, f"{table}.csv"))

double antigens or antiserums in Feb2018_table8-5


### 1.5. Sequence data 
to find closest match with HI antigens 

In [15]:
#get raw unaligned sequences
#get raw data
raw_data_dir = "../data/gisaid_data/raw_downloads/H3N2_HA/"
metadata_columns = ["Isolate_Id","HA Segment_Id", "Isolate_Name", "Passage_History", "Location", "Collection_Date"]
raw_sequences, metadata_frames = {}, []
for f in os.listdir(raw_data_dir):
    if f.endswith(".fasta"):
        for r in SeqIO.parse(os.path.join(raw_data_dir, f), "fasta"):
            # the part of the record id before the first "|" is used as the sequence key
            raw_sequences[r.id.split("|")[0]] = r.seq
    elif f.endswith(".xls"):
        metadata_frames.append(pd.read_excel(os.path.join(raw_data_dir, f), usecols=metadata_columns))

# combine all metadata files; identical rows that occur in several downloads are kept once
raw_metadata = pd.concat(metadata_frames).drop_duplicates().reset_index(drop=True)

#parse collection date as date
raw_metadata["Collection_Date"] =pd.to_datetime(raw_metadata["Collection_Date"], format="%Y-%m-%d")


def _passage_category(ph):
    """Return the passage category of a raw passage history value."""
    if ph in cpl:
        return "clinical"
    if ph in cbpl:
        return "cell-based (MDCK or SIAT)"
    if ph in epl:
        return "egg-based"
    if ph in ocbpl:
        return "cell-based (other)"
    return "unknown or unclear"


def _country_from_location(location):
    """Return the country (second " / "-separated field of the location), or pd.NA."""
    try:
        country = location.split(" / ")[1]
    except (AttributeError, IndexError):
        # missing location or a location without a country field
        return pd.NA
    # harmonise the country name with the cs mapping
    country = cs[country] if country in cs.keys() and not pd.isna(country) else country
    # names written fully in upper case are converted to capitalised names
    if not pd.isna(country) and country.upper() == country:
        country = country.lower().capitalize()
    return country


def _hemisphere_of_country(country):
    """Return "northern"/"southern" for a country, or pd.NA when it cannot be determined."""
    if pd.isna(country):
        return pd.NA
    if country in nhc:
        return "northern"
    if country in shc:
        return "southern"
    # country that is in neither hemisphere list: report it
    print (country)
    return pd.NA


#deterime passage history 
raw_metadata["Passage_History"] = [_passage_category(ph) for ph in raw_metadata["Passage_History"]]

#determine hemisphere and country
raw_metadata["country"] = [_country_from_location(l) for l in raw_metadata["Location"]]
raw_metadata["hemisphere"] = [_hemisphere_of_country(c) for c in raw_metadata["country"]]

#sort by collection date
raw_metadata = raw_metadata.sort_values(["Collection_Date"]).reset_index(drop=True)

#remove sequences with unknown country 
#NOTE(review): dropna() removes every row with a missing value in ANY column (also e.g.
#a missing collection date or segment id), not only rows with an unknown country
raw_metadata = raw_metadata.dropna().reset_index(drop=True)     

# check for valid nucleotides in sequences
valid_nucs = ['A','T','C','G','R','Y','B','D','K','M','H','V','S','W','N']
valid_nuc_set = set(valid_nucs)
weird_nucs = [sid for sid, seq in raw_sequences.items() if not set(seq.upper()) <= valid_nuc_set]
# metadata = metadata[~metadata["Isolate_Id"].isin(found_outliers + weird_nucs)]
raw_metadata = raw_metadata[~raw_metadata["Isolate_Id"].isin(weird_nucs)]



## 2. Sequences for HI antigens

### 2.1 match antigens to GISAID strain names 

In [16]:
#get ids of HI strains
# antigen names that cannot be matched by the generic spelling variants below
known_expections = {"NYMC X-327(A/Kansas/14/17)":"A/Kansas/14/2017","NYMC X-327 (A/Kansas/14/17)":"A/Kansas/14/2017"}
tested_antigens = HI_titers["strain"].unique().tolist()

# isolate name -> Isolate_Id of the first metadata row with that name (built once,
# instead of scanning the metadata for every spelling variant of every antigen)
isolate_id_by_name = {}
for name, iid in zip(raw_metadata["Isolate_Name"], raw_metadata["Isolate_Id"]):
    isolate_id_by_name.setdefault(name, iid)


def _candidate_isolate_names(s):
    """Yield the spelling variants of antigen name s in the order in which they are tried."""
    yield s
    if s in known_expections:
        yield known_expections[s]

    #replace spaces / underscores / ampersands
    yield s.replace(" ", "")
    yield s.replace(" ", "_")
    yield s.replace("_", " ")
    yield s.replace("&", "and")

    #try full year
    name_parts = s.split("/")
    year = name_parts[-1]
    if len(year) == 2:
        yield "/".join(name_parts[:-1] + ["20" + year])
        yield "/".join(name_parts[:-1] + ["19" + year])

    yield s.replace("-", " ")
    yield s.replace("-", "")

    # first word only
    yield s.split(" ")[0]

    # name given between parentheses at the end, e.g. "X (A/...)" or "X_(A/...)".
    # NOTE: lstrip("hy ") removes any leading "h", "y" and space characters, not only a
    # literal "hy " prefix; kept as in the original matching rules.
    ns = s.split(" (")[-1].rstrip(")")
    ns = ns.split("_(")[-1].rstrip(")")
    yield ns.lstrip("hy ")

    #try shortened year
    if len(year) > 2:
        yield "/".join(name_parts[:-1] + [year[-2:]])


tested_antigen_ids = {} # antigen name -> Isolate_Id
antigens_without_gisaid_seq = []
for strain in tested_antigens:
    if strain in tested_antigen_ids:
        continue

    # keep the name up to (and including) the last word that contains a "/"
    words = strain.split(" ")
    slash_word_idx = [words.index(w) for w in words if "/" in w]
    if not slash_word_idx:
        # not a strain name: record it as unmatched and go on with the next antigen
        # (this used to stop the matching of all remaining antigens)
        antigens_without_gisaid_seq.append(strain)
        continue
    li = slash_word_idx[-1]
    s = " ".join(w for w in words if words.index(w) <= li)

    for ns in _candidate_isolate_names(s):
        if ns in isolate_id_by_name:
            tested_antigen_ids[strain] = isolate_id_by_name[ns]
            break
    else:
        antigens_without_gisaid_seq.append(strain)
    

In [17]:
HI_selection_dir = "../analysis/antigenic_comparison/HI_data"
# also creates missing parent directories; no error when the directory already exists
os.makedirs(HI_selection_dir, exist_ok=True)

#write output file for alignment 
#one FASTA record per matched HI antigen, named after the antigen
HI_sequence_file = os.path.join(HI_selection_dir, "HI_antigen_sequences.fasta")
with open(HI_sequence_file, "w") as fw:
    for strain, iid in tested_antigen_ids.items():
        fw.write(f">{strain}\n{raw_sequences[iid]}\n")
        

### 2.2. Align sequences 

In [18]:
import subprocess

threads =12
ref_ha = "../data/references/H3N2_HA.fasta"


# Align the HI antigen sequences to the reference. The alignment is only made when
# the alignment file does not exist yet.
HI_alignment_file = os.path.join(HI_selection_dir, "HI_antigen_seq_alignment.fasta")
if not os.path.isfile(HI_alignment_file): 
    cmd = ['mafft', '--auto', '--thread', str(threads), '--keeplength', '--addfragments', HI_sequence_file , ref_ha]
    try:
        # mafft writes the alignment to standard output
        with open(HI_alignment_file, "w") as fw:
            subprocess.run(cmd, stdout=fw, check=True)
    except (OSError, subprocess.CalledProcessError):
        # do not leave an empty/partial file behind: it would be mistaken for a
        # finished alignment the next time this cell is run
        if os.path.isfile(HI_alignment_file):
            os.remove(HI_alignment_file)
        raise


# translate the aligned sequences of the antigens that occur in the HI tables
HI_strain_names = set(HI_titers["strain"].unique())
HI_antigen_seqs = {} #complete sequences
for r in SeqIO.parse(HI_alignment_file, "fasta"):
    if r.id == "KX879573.1-ref_1968-A/Alchi/2/1968|4-HA": #skip reference
        continue
    # the description is used instead of the id as the full name of the record
    strain = r.description
    if strain not in HI_strain_names:
        continue
    # gaps are replaced by "N" before translation
    HI_antigen_seqs[strain] = r.seq[:sequence_length].replace("-", "N").translate()

# HA1 part of every antigen sequence
HA1_HI_antigen_seqs = {strain:seq[start_mature_protein-1:HA1_length_AA+start_mature_protein-1] for strain, seq in HI_antigen_seqs.items()}


## 3. Get most similar HI strains 

In [19]:
closest_HI_strains =[]
diff_HI_selected = []

# passage type of the first HI_titers row of every antigen strain (looked up once)
first_strain_rows = HI_titers.drop_duplicates(subset="strain")
passage_by_strain = dict(zip(first_strain_rows["strain"], first_strain_rows["passage"]))


def _prioritise_passage(strains):
    """Keep the cell-passaged antigens if there are any, else the egg-passaged ones, else all."""
    pt = [passage_by_strain[s] for s in strains]
    for preferred in ("cell", "egg"):
        if preferred in pt:
            return [s for s, p in zip(strains, pt) if p == preferred]
    return strains


for region, h in region_hemispheres.items():
    for season, (fss, fse) in flu_seasons[h].items():
        early = season in early_seasons

        ds = dominant_strains.loc[(region, season), "sequence"]
        vs = vaccine_strains[vaccine_strain_names[h][season]]["aa"]["HA1" if early else "complete"]
        css = reproducible_selection_strains.loc[(region, season, "WHO-timing"), "sequence"]
        lss = reproducible_selection_strains.loc[(region, season, "delayed-timing"), "sequence"]

        for strain, seq in {"dominant strain": ds, "vaccine strain": vs, "reproducible selection at WHO-timing":css,
                            "reproducible selection at delayed timing":lss}.items():

            if early:
                #for early season GETTING THE MOST ANTIGENICALLY SIMILAR SEQUENCES
                cseq, diff = get_most_antigenic_similar_sequences(seq, HA1_HI_antigen_seqs)
            else:
                #for later season GETTING THE GENETICALLY MOST SIMILAR SEQUENCES
                #prioritizing smallest difference in the antigenic sites
                # NOTE(review): the epitope positions are applied here to the full-length
                # sequences without an offset for the signal peptide, whereas for the early
                # seasons they are applied to HA1 sequences - confirm that this is intended
                acseq, adiff = get_most_antigenic_similar_sequences(seq, HI_antigen_seqs)
                antigenic_closest = set(acseq)
                cseq, diff = get_most_similar_sequences(seq, {name: aseq for name, aseq in HI_antigen_seqs.items() if name in antigenic_closest})

            #prioritise passage types
            cseq = _prioritise_passage(cseq)

            closest_HI_strains.append([region, season, strain, ", ".join(cseq)])
            diff_HI_selected.append([region, season, strain, diff])

closest_HI_strains = pd.DataFrame.from_records(closest_HI_strains, columns=["region", "season", "strain", "HI antigens"]).set_index(["region", 'season']).sort_index()
diff_HI_selected = pd.DataFrame.from_records(diff_HI_selected, columns=["region", "season", "strain", "number of differences"])


## 4. HI tables to create antigenic maps per regional season

### 4.1. manually find tables to merge to create antigenic maps
quickest way to do this was manually.....

In [20]:
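#manual selection: maps each (region, season) to a list of HI-table combinations; each inner list holds the tables that are merged into one antigenic map (an empty inner list means no tables were selected)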
manual_HI_table_selection = {('aunz', '2003'):[["2003_table5", "2003_table6"], ["2003_table5"], ["2003_table6"]],
                             ('aunz', '2004'):[[]], 
                             ('aunz', '2005'):[[]],
                             ('aunz', '2006'):[["Mar2007_table4", "Mar2007_table5", "Mar2006_table4", "Sep2006_table4", 'Sep2006_table6'],["Mar2007_table4"], ["Mar2007_table5"]],
                             ('aunz', '2007'):[[]],
                             ('aunz', '2008'):[["Sep2008_table4A", "Sep2008_table4B", "Sep2007_table8"], ["Sep2008_table4B", "Sep2008_table4A", "Sep2007_table7", "Sep2007_table6"], 
                                               ["Feb2009_table4"],  ["Sep2008_table4A", "Mar2008_table5"],],
                             ('aunz', '2009'):[["Sep2009_table6"]],
                             ('aunz', '2010'):[["Sep2010_table7B", "Feb2010_table13"], ["Sep2009_table5", "Sep2010_table7B", "Feb2010_table13"]],
                             ('aunz', '2011'):[["Sep2011_table18"], ["Sep2012_table9"],["Feb2012_table12", "Feb2012_table13"], ["Sep2012_table15"],["Feb2012_table10"]],
                             ('aunz', '2012'):[["Feb2012_table13"], ["Feb2012_table14"], [ "Sep2012_table18", "Sep2012_table17"], ["Sep2012_table17", "Sep2012_table12"],
                                               ["Feb2012_table9", "Sep2012_table10"]],
                             ('aunz', '2013'):[["Sep2013_table7-7","Sep2013_table7-10","Feb2012_table13"],["Feb2013_table18", "Feb2012_table13"], ["Feb2012_table13", "Sep2013_table7-16", "Sep2013_table7-11"], 
                                               ["Sep2012_table20", "Sep2013_table7-12", "Sep2013_table7-9"], ["Sep2012_table20", "Feb2013_table18", "Feb2013_table20"]],
                             ('aunz', '2014'):[["Sep2012_table9", "Feb2014_table9-8", 'Sep2013_table7-8'], ["Sep2012_table9", "Sep2013_table7-7", "Feb2014_table9-2", "Sep2014_table8-10", "Sep2014_table8-14"], 
                                               ["Sep2012_table15","Sep2013_table7-12", "Feb2014_table9-7", 'Sep2014_table8-14'], ["Sep2012_table15","Sep2013_table7-12", "Feb2014_table9-2", 'Feb2014_table9-8'],
                                               ["Sep2012_table15","Sep2013_table7-7", "Feb2014_table9-7", 'Feb2014_table9-8', "Sep2014_table8-14"]],
                             ('aunz', '2015'):[["Feb2015_table9-10", "Feb2015_table9-4", "Sep2014_table8-18"],["Sep2015_table8-9", "Feb2015_table9-10", "Sep2014_table8-19"], 
                                               ["Sep2014_table8-19","Sep2014_table8-18","Sep2014_table8-10"]],
                             ('aunz', '2016'):[["Sep2017_table8-13", "Sep2017_table8-12", "Sep2016_table8-6","Sep2016_table8-2"], ["Sep2017_table8-13","Feb2016_table9-3","Feb2017_table8-5"],
                                               ["Sep2017_table8-16", "Sep2017_table8-2", "Sep2016_table8-2"], ["Sep2017_table8-13", "Sep2017_table8-2", "Sep2016_table8-2"],
                                               ["Sep2017_table8-13", "Sep2017_table8-2","Feb2016_table9-3"]],
                             ('aunz', '2017'):[["Sep2016_table8-7", "Sep2017_table8-13", "Sep2017_table8-16", "Feb2018_table8-3"], ["Sep2016_table8-7", "Sep2017_table8-13", "Sep2017_table8-16"], 
                                               ["Sep2016_table8-7", "Sep2017_table8-13", "Sep2017_table8-12", "Feb2018_table8-3"], ["Sep2016_table8-7", "Sep2017_table8-13", "Feb2018_table8-3"]],
                             ('aunz', '2018'):[["Sep2018_table8-5","Sep2017_table8-16", "Sep2017_table8-20"], ["Sep2018_table8-5","Feb2018_table8-9","Sep2017_table8-16"]],
                             ('aunz', '2019'):[["Feb2020_table7-11", "Sep2019_table8-3", "Feb2019_table8-9", "Feb2018_table8-9", "Feb2017_table8-5", "Sep2017_table21","Sep2017_table8-16"]],
                             ('aunz', '2020'):[["Feb2021_table7-2","Sep2019_table8-1","Sep2020_table7-2", "Sep2019_table8-18"],["Feb2021_table7-2", "Sep2019_table8-14", "Sep2020_table7-1","Feb2019_table8-11"],
                                               ["Sep2020_table7-7","Sep2019_table8-19","Sep2019_table8-18", "Sep2019_table8-16"]],
                             ('aunz', '2021'):[['Feb2021_table7-2', 'Sep2021_table7-4', 'Feb2022_table9-1'], ['Feb2021_table7-2', 'Sep2021_table7-7', 'Feb2022_table9-1'], 
                                               ['Feb2021_table7-2', 'Sep2021_table7-7', 'Sep2022_table10-6', 'Sep2022_table10-20'],
                                               ['Feb2021_table7-2', 'Sep2021_table7-7', 'Sep2021_table7-1', 'Sep2022_table10-6', 'Sep2022_table10-20', 'Sep2019_table8-16']],
                             ('aunz', '2022'):[["Sep2021_table7-6", "Feb2022_table9-1", "Feb2022_table9-10", "Sep2022_table10-18"], ["Sep2021_table7-6", "Feb2022_table9-1", "Feb2022_table9-2","Sep2022_table10-18"],
                                               ["Sep2021_table7-6", "Sep2022_table10-20", "Feb2022_table9-1","Sep2022_table10-18"]],
                             ('aunz', '2023'):[['Feb2023_tableH3-10', 'Sep2023_tableH3-12', "Sep2022_table10-20", 'Sep2022_table10-12'],
                                               ['Feb2023_tableH3-10', 'Sep2023_tableH3-12', 'Sep2023_tableH3-1', 'Sep2022_table10-25', 'Sep2021_table7-7'],
                                               ['Feb2023_tableH3-10', 'Sep2023_tableH3-12', 'Sep2022_table10-12', 'Feb2022_table9-1'],
                                               ['Feb2023_tableH3-10', 'Sep2023_tableH3-12', 'Sep2022_table10-12', 'Sep2022_table10-10', 'Feb2022_table9-1']],
                            ('europe', '2002-2003'):[["2003_table5", "2003_table6"], ["2003_table5"], ["2003_table6"]],
                            ('europe', '2003-2004'):[["2003_table5", "2003_table6"], ["2003_table5"], ["2003_table6"], ["2004_table5", "2003_table6"], ["2004_table5"]],
                            ('europe', '2004-2005'):[[]],
                            ('europe', '2005-2006'):[[]],
                            ('europe', '2006-2007'):[["Sep2006_table3", "Sep2006_table6"], ["Sep2006_table5", "Sep2006_table3"]],
                            ('europe', '2007-2008'):[["Mar2007_table4", "Sep2006_table6"], ["Mar2007_table4", "Sep2006_table6","Sep2007_table6"], ["Sep2007_table5", "Sep2007_table6", "Sep2007_table8"],
                                                     ["Sep2007_table5", "Sep2007_table6", "Sep2007_table8","Mar2007_table4", "Sep2006_table6"]],
                            ('europe', '2008-2009'):[["Sep2008_table4A", "Sep2007_table8", "Sep2008_table4B"], ["Sep2008_table4A", "Sep2007_table8"], 
                                                     ["Feb2009_table4"], ["Feb2009_table5"], ["Sep2009_table6"]],
                            ('europe', '2009-2010'):[[]],
                            ('europe', '2010-2011'):[[]],
                            ('europe', '2011-2012'):[["Sep2013_table7-12", "Sep2012_table9"],["Sep2013_table7-12", "Sep2012_table10"], ["Sep2013_table7-12", "Feb2012_table12"], 
                                                     ["Sep2013_table7-12", "Feb2013_table13"],["Sep2013_table7-12", "Sep2012_table15"]],
                            ('europe', '2012-2013'):[['Feb2012_table13', "Feb2013_table18", "Feb2013_table13"], ['Feb2012_table8', "Feb2013_table18", "Sep2013_table7-16"],
                                                      ["Sep2012_table18", "Sep2013_table7-7"], ["Sep2013_table7-8", "Sep2012_table17"], ["Sep2013_table7-3", "Sep2012_table9", "Sep2013_table7-16"]],
                            ('europe', '2013-2014'):[["Sep2012_table9","Sep2013_table7-9", "Sep2013_table7-7"],["Sep2012_table9", "Sep2013_table7-12"],["Sep2012_table15","Feb2013_table18"],
                                                     ["Sep2012_table9", "Feb2014_table9-7", "Sep2013_table7-10"], ["Sep2012_table15", 'Sep2013_table7-10', "Feb2014_table9-7"]], 
                            ('europe', '2014-2015'):[["Sep2012_table15", "Sep2013_table7-7", "Sep2014_table8-19", "Feb2014_table9-7", "Sep2014_table8-11"], 
                                                      ["Sep2015_table8-9", "Sep2014_table8-16", "Sep2014_table8-10", 'Feb2014_table9-2', "Feb2015_table9-4", "Sep2012_table12", "Sep2012_table15"]],
                            ('europe', '2015-2016'):[["Sep2016_table8-6","Feb2016_table9-2"],["Sep2015_table8-4","Sep2015_table8-13","Sep2015_table8-6"],
                                                     ["Sep2015_table8-8","Sep2015_table8-9"],["Sep2015_table8-13","Sep2015_table8-9"],["Feb2016_table9-3"]] ,
                            ('europe', '2016-2017'):[["Sep2016_table8-6", "Sep2016_table8-7"], ["Sep2016_table8-7", "Feb2016_table9-3"], ["Sep2016_table8-7", "Feb2016_table9-3", "Sep2016_table8-2"],
                                                      ["Sep2016_table8-7", "Sep2016_table8-2"]],
                            ('europe', '2017-2018'):[["Sep2016_table8-7", "Sep2017_table21"], ["Sep2016_table8-7", "Feb2017_table8-2"], ["Sep2016_table8-7", "Sep2016_table8-9", "Sep2017_table8-2"],
                                                      ["Sep2016_table8-7", "Sep2016_table8-9", "Feb2017_table8-2"], ["Sep2016_table8-7", "Feb2017_table8-2", "Sep2017_table21"]],
                            ('europe', '2018-2019'):[["Feb2019_table8-5","Feb2018_table8-4","Sep2017_table8-16"],["Sep2018_table8-6", "Sep2018_table8-5","Feb2018_table8-6"], ["Sep2018_table8-6","Feb2018_table8-6","Sep2017_table21"],
                                                     ["Feb2018_table8-6","Sep2017_table21", "Sep2018_table8-7"], ["Sep2018_table8-7", "Sep2018_table8-2", "Sep2018_table8-5"]],
                            ('europe', '2019-2020'):[["Feb2020_table7-12", "Sep2019_table8-15", "Sep2019_table8-6"], ["Sep2019_table8-7", "Sep2019_table8-17", "Feb2019_table8-10"],
                                                      ["Sep2019_table8-13", "Sep2019_table8-14", "Sep2019_table8-15"], ["Sep2019_table8-17", "Sep2019_table8-18"],
                                                      ["Sep2019_table8-6", "Sep2019_table8-3", "Sep2019_table8-5"]],                                                                                                                                             
                            ('europe', '2020-2021'):[["Feb2020_table7-5", "Feb2021_table7-3","Feb2020_table7-12"], ["Feb2020_table7-5", "Sep2019_table8-16", "Sep2021_table7-4"], 
                                                      ["Feb2020_table7-5","Sep2020_table7-9","Feb2022_table9-1","Sep2021_table7-4"], ["Feb2020_table7-5", "Sep2020_table7-1", "Sep2021_table7-5"]],
                            ('europe', '2021-2022'):[["Feb2020_table7-5", "Feb2021_table7-2","Sep2021_table7-10","Sep2022_table10-17"],["Feb2020_table7-5", "Feb2021_table7-2","Sep2021_table7-10","Feb2022_table9-9"],
                                                     ["Feb2020_table7-5", "Feb2021_table7-2","Sep2021_table7-10","Feb2022_table9-7"],["Feb2020_table7-5", "Feb2021_table7-2","Feb2021_table7-3","Sep2021_table7-10"]],
                            ('europe', '2022-2023'):[["Sep2021_table7-6","Feb2022_table9-11","Sep2022_table10-8"],["Sep2021_table7-6","Sep2022_table10-8","Feb2023_tableH3-10","Feb2023_tableH3-1"],
                                                     ["Sep2021_table7-6","Sep2022_table10-19"],["Sep2021_table7-6","Sep2022_table10-18"],["Sep2021_table7-6","Sep2022_table10-16","Feb2023_tableH3-1","Feb2023_tableH3-10"]],
                            ('us', '2002-2003'):[["2003_table5", "2003_table6"], ["2003_table5"], ["2003_table6"]],
                            ('us', '2003-2004'):[["2003_table5", "2003_table6"], ["2003_table5"], ["2003_table6"], ["2004_table5", "2003_table6"], ["2004_table5"]],
                            ('us', '2004-2005'):[[]],
                            ('us', '2005-2006'):[[]],
                            ('us', '2006-2007'):[[]],                          
                            ('us', '2007-2008'):[["Mar2006_table4", "Sep2006_table4",'Sep2006_table6','Sep2007_table6',"Sep2008_table4A","Sep2007_table7"]],  
                            ('us', '2008-2009'):[["Feb2009_table5"], ["Sep2009_table6"], ["Feb2009_table4"],["Sep2008_table4A", "Sep2007_table8", 'Sep2008_table4B'], 
                                                 ["Sep2008_table4A", "Sep2007_table8"]],
                            ('us', '2009-2010'):[['Sep2009_table6',"Feb2010_table12"],["Sep2009_table6"]],
                            ('us', '2010-2011'):[[]],
                            ('us', '2011-2012'):[["Sep2013_table7-12","Feb2013_table13"],["Sep2013_table7-12","Feb2012_table14"],["Sep2013_table7-12","Feb2012_table12"],
                                                 ["Sep2013_table7-12","Sep2012_table15"],["Sep2013_table7-12","Sep2012_table9"]],
                            ('us', '2012-2013'):[["Sep2013_table7-10", "Feb2012_table12"], ["Sep2013_table7-10", "Feb2012_table13"], ["Sep2013_table7-10", "Feb2012_table14"],
                                                  ["Sep2013_table7-10", "Feb2012_table11"], ["Sep2013_table7-10", "Sep2012_table12"]],
                            ('us', '2013-2014'):[['Sep2012_table15', 'Sep2013_table7-10', 'Feb2014_table9-2', 'Sep2014_table8-10', "Sep2014_table8-19"],
                                                  ['Sep2012_table9', 'Sep2013_table7-10', 'Feb2014_table9-7', 'Sep2014_table8-13', 'Feb2015_table9-3', 'Feb2015_table9-4']],
                            ('us', '2014-2015'):[["Sep2012_table9", "Sep2013_table7-7", 'Feb2014_table9-7', 'Sep2014_table8-10', "Sep2014_table8-16", 'Feb2015_table9-4',  "Sep2015_table8-9"],
                                                  ["Sep2013_table7-8", "Sep2012_table15", "Sep2014_table8-19", "Sep2014_table8-10", "Feb2014_table9-6"],
                                                  ["Sep2013_table7-8", "Feb2015_table9-4", "Sep2012_table9", "Sep2014_table8-19", "Sep2014_table8-10", "Feb2014_table9-6"]],
                            ('us', '2015-2016'):[["Feb2016_table9-3"], ["Sep2014_table8-19","Feb2015_table9-2"], ["Feb2015_table9-4","Feb2015_table9-2"],
                                                 ["Sep2015_table8-9","Sep2015_table8-8"], ["Sep2016_table8-6","Sep2016_table8-2"]],
                            ('us', '2016-2017'):[['Sep2016_table8-7', 'Sep2016_table8-6'], ['Sep2016_table8-7', 'Sep2016_table8-2'], 
                                                  ['Sep2016_table8-7','Feb2016_table9-3'], ['Sep2016_table8-7',  'Sep2016_table8-2', 'Feb2016_table9-3']],
                            ('us', '2017-2018'):[['Sep2016_table8-7', 'Sep2017_table8-16', 'Sep2017_table8-2'], ['Sep2016_table8-7', 'Sep2017_table8-2'],
                                                  ['Sep2016_table8-7', 'Feb2017_table8-2']],
                            ('us', '2018-2019'):[['Sep2017_table8-16','Feb2018_table8-6','Feb2019_table8-7'], ['Feb2019_table8-9', 'Feb2018_table8-9','Sep2017_table8-2'], 
                                                 ['Feb2018_table8-9', 'Feb2019_table8-7', 'Sep2017_table21']],
                            ('us', '2019-2020'):[["Sep2019_table8-1","Sep2019_table8-11","Sep2019_table8-14"],["Sep2019_table8-6","Sep2019_table8-3","Sep2019_table8-1"],
                                                 ["Feb2019_table8-11","Sep2019_table8-14","Sep2019_table8-15"],["Sep2019_table8-1","Sep2019_table8-18"]],
                            ('us', '2020-2021'):[["Sep2021_table7-9","Sep2020_table7-5"],["Sep2020_table7-9", "Sep2021_table7-4"], 
                                                 ["Feb2021_table7-3","Feb2020_table7-12"],["Feb2022_table9-6","Feb2022_table9-8"],["Sep2021_table7-4","Sep2020_table7-5"]],
                            ('us', '2021-2022'):[['Feb2021_table7-2','Feb2020_table7-5','Sep2022_table10-18', 'Sep2021_table7-4']],                                                                                                       
                            ('us', '2022-2023'):[["Sep2021_table7-6","Sep2022_table10-19"],["Sep2021_table7-6","Sep2022_table10-17","Sep2022_table10-16"],
                                                 ["Sep2021_table7-6","Feb2023_tableH3-10","Feb2023_tableH3-1"]],
                             
                             }

# Flatten the selection dict into one record per (region, season, table combination);
# the tables of a combination are stored as a single ", "-joined string.
selection_records = [
    [sel_region, sel_season, ", ".join(table_combo)]
    for (sel_region, sel_season), table_combos in manual_HI_table_selection.items()
    for table_combo in table_combos
]
manual_HI_table_selection = pd.DataFrame.from_records(
    selection_records,
    columns=["region", "season", "table list"],
)



In [21]:
#Exploratory cell: for the first (region, season) that has no manual table selection yet,
#print per HI table how many of the closest HI antigens of each strain category it contains.
#ds_, vs_, css_, lss_ and the *_tables dicts remain defined afterwards (they are read by
#the merge check that follows).

#(region, season) pairs that already have a selection; built once because it does not
#change inside the loop, and as a set so the membership test does not scan a list
seasons_with_selection = set(manual_HI_table_selection.set_index(["region", "season"]).index)

for (region, season) in closest_HI_strains.index.unique():
    if (region, season) in seasons_with_selection:
        continue
    print (f"('{region}', '{season}')")

    sdf = closest_HI_strains.loc[(region, season), :].set_index("strain")
    HI_antigen_strains = [s for sl in sdf["HI antigens"].tolist() for s in sl.split(", ")]
    HI_antigen_tables = HI_titers[HI_titers["strain"].isin(HI_antigen_strains)][["table", "type", "strain"]].drop_duplicates(ignore_index=True)

    for t in HI_antigen_tables["type"].unique():
        tdf = HI_antigen_tables[HI_antigen_tables["type"]==t]

        #determine antigens for each strain
        ds = tdf[tdf["strain"].isin(sdf.loc["dominant strain", "HI antigens"].split(", "))]["strain"].unique().tolist()
        vs = tdf[tdf["strain"].isin(sdf.loc["vaccine strain", "HI antigens"].split(", "))]["strain"].unique().tolist()
        css = tdf[tdf["strain"].isin(sdf.loc["reproducible selection at WHO-timing", "HI antigens"].split(", "))]["strain"].unique().tolist()
        lss = tdf[tdf["strain"].isin(sdf.loc["reproducible selection at delayed timing", "HI antigens"].split(", "))]["strain"].unique().tolist()

        #if any of the strains has no antigens of this type, continue
        if not (ds and vs and css and lss):
            continue

        #for each strain determine in which table its antigens are (tables without any are left out)
        ds_tables, vs_tables, css_tables, lss_tables = {}, {}, {}, {}
        for tb in tdf["table"].unique():
            ts = tdf[tdf["table"]==tb]["strain"].unique().tolist()

            for strain_antigens, strain_tables in ((ds, ds_tables), (vs, vs_tables), (css, css_tables), (lss, lss_tables)):
                antigens_in_table = [s for s in ts if s in strain_antigens]
                if antigens_in_table:
                    strain_tables[tb] = antigens_in_table

        #make an overview of the tables and the number of strains in each table;
        #tables are numbered in sorted order so that "tn" is the same on every run
        all_tables = sorted(set(ds_tables) | set(vs_tables) | set(css_tables) | set(lss_tables))
        table_overview = [
            [tn, tb, len(ds_tables.get(tb, [])), len(vs_tables.get(tb, [])), len(css_tables.get(tb, [])), len(lss_tables.get(tb, []))]
            for tn, tb in enumerate(all_tables, start=1)
        ]
        table_overview = pd.DataFrame.from_records(table_overview, columns=["tn", "table", "ds", "vs", "css", "lss"])

        print (f"ds: {len(ds)}, vs: {len(vs)}, css: {len(css)}, lss: {len(lss)} ")
        table_overview = table_overview.set_index("table").sort_index()
        table_overview.index.name = None
        print(table_overview[["ds", "vs", "css", "lss"]].to_string())

        #keep the antigen lists of the last reported type available after the loop
        ds_, vs_, css_, lss_= ds, vs, css,lss

    #only the first (region, season) without a selection is reported per run
    break


In [22]:
#candidate table combinations to check; fill in by hand (each entry is a list of keys of HI_table_merge)
csl = [[]]

for cs in csl:
    for i in cs:
        print (HI_table_merge[i])
    try:
        print (can_merge_tables(cs, HI_table_merge), len(cs))
        print (count_items_in_tables(cs, [ds_, vs_, css_, lss_], [ds_tables, vs_tables, css_tables, lss_tables]))
    except Exception as err:
        #best-effort check: keep going with the next combination, but report why this
        #one could not be evaluated instead of hiding the failure
        print (f"could not evaluate {cs}: {type(err).__name__}: {err}")
    print()




In [23]:
#assign map code to keep track of the maps:
#every distinct, non-empty table list gets a consecutive integer code, starting at 1
mc_count = 0
map_codes = {}
for tl in manual_HI_table_selection["table list"].unique():
    #test for NaN first: len() is not defined for a float NaN
    if pd.isna(tl) or len(tl) == 0:
        continue
    mc_count += 1
    map_codes[tl] = mc_count

#annotate in df; rows without a table list are not in map_codes and therefore get NaN
#(float dtype, as a column holding both integer codes and NaN)
manual_HI_table_selection["map code"] = manual_HI_table_selection["table list"].map(map_codes).astype(float)

manual_HI_table_selection = manual_HI_table_selection.set_index(["region", "season"]).sort_index()

### 4.2. Determine HI antigens for each region, season

In [24]:
closest_HI_strains

Unnamed: 0_level_0,Unnamed: 1_level_0,strain,HI antigens
region,season,Unnamed: 2_level_1,Unnamed: 3_level_1
aunz,2003,dominant strain,"A/Fujian/411/02, A/Finland/170/03"
aunz,2003,vaccine strain,"A/Panama/2007/99, A/Panama/2007/1999"
aunz,2003,reproducible selection at WHO-timing,"A/New York/55/01, A/Latvia/1506/03"
aunz,2003,reproducible selection at delayed timing,"A/New York/55/01, A/Latvia/1506/03"
aunz,2004,dominant strain,"A/Stockholm/15/2004, A/Norway/70/2005"
...,...,...,...
us,2021-2022,reproducible selection at delayed timing,A/Qatar/16-VI-19-0049409/2019
us,2022-2023,dominant strain,"A/Netherlands/10205/2021, A/Netherlands/10884/..."
us,2022-2023,vaccine strain,A/Darwin/9/2021
us,2022-2023,reproducible selection at WHO-timing,A/Darwin/6/2021


In [25]:
# For every table combination (map) of a regional season, keep for each strain category
# only those closest HI antigens that occur as a strain in one of the combination's tables.
selected_HI_antigens = []
for region, h in region_hemispheres.items():
    for season in flu_seasons[h]:

        df = closest_HI_strains.loc[(region, season),:]
        tdf = manual_HI_table_selection.loc[(region, season),:]

        for _, row in tdf.iterrows():
            tables = row["table list"].split(", ")
            in_selected_tables = HI_titers["table"].isin(tables)
            tables_antigen_stains = HI_titers.loc[in_selected_tables, "strain"].unique().tolist()

            for strain, candidate_antigens in zip(df["strain"], df["HI antigens"]):
                kept_antigens = ", ".join(
                    ag for ag in candidate_antigens.split(", ") if ag in tables_antigen_stains
                )
                selected_HI_antigens.append(
                    [region, season, row["map code"], row["table list"], strain, kept_antigens]
                )

selected_HI_antigens = pd.DataFrame.from_records(
    selected_HI_antigens,
    columns=["region", "season", "map code", "table list", "strain", "selected antigens"],
)

In [26]:
# Wide overview: one row per (region, season, map code), one column per strain category,
# each cell holding the ", "-joined selected antigens.
strain_columns = [
    "dominant strain",
    "vaccine strain",
    "reproducible selection at WHO-timing",
    "reproducible selection at delayed timing",
]
selected_HI_antigens_overview = selected_HI_antigens.pivot(
    index=["region", "season", "map code"],
    columns="strain",
    values="selected antigens",
)
# a scalar `values` yields plain (single-level) columns; only their name has to be cleared
selected_HI_antigens_overview.columns.name = None
selected_HI_antigens_overview = selected_HI_antigens_overview[strain_columns]


In [27]:
# write the overview to disk with region, season and map code as ordinary columns
overview_csv = "../analysis/antigenic_comparison/selected_antigens.csv"
selected_HI_antigens_overview.reset_index().to_csv(overview_csv, index=False)

#### 4.2.1 Get the actual amino acid differences between the selected antigens and the actual strains

Also getting the GISAID IDs of the selected HI antigens.

In [28]:
# Index the selected antigens by (region, season, map code). set_index raises KeyError
# when these columns are missing, which happens when this cell is re-run and they are
# already the index; in that case restore them as columns first.
try:
    selected_HI_antigens = selected_HI_antigens.set_index(["region", "season", "map code"]).sort_index()
except KeyError:
    selected_HI_antigens = selected_HI_antigens.reset_index().set_index(["region", "season", "map code"]).sort_index()

In [29]:
# For every (region, season, strain category) collect the mutations between the strain's
# sequence and each of its selected HI antigens, and look up the GISAID ID of each antigen.
antigen_strain_ids = {}
mutation_frames = []

#(region, season, strain, antigen) combinations already handled; the same combination
#recurs for every map code of a regional season
seen = set()
for (region, season, mapcode), row in selected_HI_antigens.iterrows():
    if pd.isna(row["selected antigens"]) or row["selected antigens"] == "":
        continue

    #early seasons are compared on HA1 only, later seasons on the complete sequence
    seq_part = "HA1" if season in early_seasons else "complete"

    strain = row["strain"]
    #the labels must match the "strain" column exactly; the former test for "dominant"
    #never matched, so the dominant strain was compared using the delayed-timing sequence
    #NOTE(review): this branch was not reached before the label fix - confirm that
    #dominant_strains is indexed by (region, season) and has a "sequence" column
    if strain == "dominant strain":
        seq = dominant_strains.loc[(region, season), "sequence"]
    elif strain == "vaccine strain":
        seq = vaccine_strains[vaccine_strain_names[region_hemispheres[region]][season]]["aa"][seq_part]
    elif strain == "reproducible selection at WHO-timing":
        seq = reproducible_selection_strains.loc[(region, season, "WHO-timing"),"sequence"]
    elif strain == "reproducible selection at delayed timing":
        seq = reproducible_selection_strains.loc[(region, season, "delayed-timing"),"sequence"]
    else:
        #fail loudly rather than silently comparing against the wrong sequence
        raise ValueError(f"unknown strain label: {strain!r}")

    antigens = row["selected antigens"].split(", ")
    for antigen in antigens:
        if antigen not in antigen_strain_ids:
            antigen_strain_ids[antigen] = tested_antigen_ids[antigen]

        key = (region, season, strain, antigen)
        if key in seen:
            continue
        seen.add(key)

        if season in early_seasons:
            antigen_seq = HA1_HI_antigen_seqs[antigen]
        else:
            antigen_seq = HI_antigen_seqs[antigen]

        muts = get_mutations(seq, antigen_seq, sl=seq_part)
        muts[["region", "season", "strain", "antigen"]] = [region, season, strain, antigen]
        mutation_frames.append(muts)

#concatenate once at the end; an empty frame is used when nothing was collected so that
#the result is always a DataFrame
if mutation_frames:
    strain_HI_antigen_mutations = pd.concat(mutation_frames, ignore_index=True)
else:
    strain_HI_antigen_mutations = pd.DataFrame()

antigen_strain_ids = pd.DataFrame.from_dict(antigen_strain_ids, orient="index", columns=["GISAID ID"]).reset_index().rename(columns={"index":"HI antigen strain"})


In [30]:
# write the mutation table and the GISAID IDs of the HI antigens to disk (without the index)
csv_exports = (
    (strain_HI_antigen_mutations, "../analysis/antigenic_comparison/mutations_between_selected_HI_antigens_and_strain.csv"),
    (antigen_strain_ids, "../analysis/antigenic_comparison/GISAID_ids_of_HI_antigens.csv"),
)
for export_frame, export_path in csv_exports:
    export_frame.to_csv(export_path, index=False)

### 4.3. Prep R-script for antigenic cartography

In [31]:
#determine all unique tables to be loaded, in order of first appearance in the table lists
#(dict keys are unique and keep insertion order)
individual_tables = list(dict.fromkeys(
    table_name
    for table_list in map_codes
    for table_name in table_list.split(", ")
))

#### 4.3.1. antigenic maps construction

In [32]:
#script to write the antigenic maps
def _r_name_suffix(table_name):
    """Return the table name as used in R variable names: 'table' removed, '-' replaced by '_'."""
    return table_name.replace('table', '').replace('-', '_')


with open("../scripts/construct_antigenic_maps.R", "w") as fw:
    #header: library, optimizer cores and the directories used by the R script
    fw.writelines([
        "library(Racmacs)\n",
        "options(RacOptimizer.num_cores = 10)\n",
        'dir <- "./data/HI_data/individual_HI_tables"\n',
        'acmap_dir <- "./analysis/antigenic_comparison/antigenic_maps/complete_maps"\n',
        'coords_dir <- "./analysis/antigenic_comparison/antigenic_maps/coords"\n\n',
    ])

    #read every individual titer table once
    fw.writelines(
        f"t{_r_name_suffix(table_name)} <- read.titerTable(file.path(dir, '{table_name}.csv'))\n"
        for table_name in individual_tables
    )

    fw.write ("\n#load individual maps\n")
    for table_name in individual_tables:
        sfx = _r_name_suffix(table_name)
        fw.write(f"map{sfx} <- acmap( titer_table=t{sfx}, sr_names=colnames(t{sfx}), ag_names=rownames(t{sfx}))\n\n")

    #one merged and optimized map per table list, named after its map code
    for table_list, code in map_codes.items():
        merged_tables = table_list.split(", ")
        maplist = ", ".join(f"map{_r_name_suffix(table_name)}" for table_name in merged_tables)
        #3000 optimizations per merged table
        n_optimizations = round(1000*len(merged_tables)*3)

        fw.write(f"map{code} <- mergeMaps({maplist}, method='table', number_of_dimensions = 2)\n" )
        fw.write(f"map{code} <- optimizeMap(map = map{code}, number_of_dimensions = 2, number_of_optimizations = {n_optimizations}, options = list(ignore_disconnected = TRUE))\n")
        #keepBestOptimization is not written here
        fw.write(f"save.acmap(map{code}, file.path(acmap_dir, 'map{code}.ace'))\n")
        fw.write(f"save.coords(map{code}, file.path(coords_dir, 'map{code}.csv'))\n\n")
        


#### 4.3.2. antigenic map visualization

In [33]:
# Write the R script that reads every complete map, keeps only its best optimization
# and saves the reduced map to the "maps" directory.
with open("../scripts/summarize_antigenic_maps.R", "w") as fw:

    fw.writelines([
        "library(Racmacs)\n",
        "library(ggplot2)\n",
        "options(RacOptimizer.num_cores = 10)\n",
        'dir <- "./data/HI_data/individual_HI_tables/"\n',
        'acmap_dir <- "./analysis/antigenic_comparison/antigenic_maps/complete_maps"\n',
        'acmap_simple_dir <- "./analysis/antigenic_comparison/antigenic_maps/maps"\n',
        'coords_dir <- "./analysis/antigenic_comparison/antigenic_maps/coords"\n',
        "#load individual maps\n",
    ])

    for mapcode in map_codes.values():
        fw.write(
            f"map{mapcode} <- keepBestOptimization(read.acmap(file.path(acmap_dir, 'map{mapcode}.ace')))\n"
            f"save.acmap(map{mapcode}, file.path(acmap_simple_dir, 'map{mapcode}.ace'))\n"
        )




In [34]:
# fill colours for antigens, one per table position within a table list
table_colors = [
    "#ea5545",
    "#ef9b20",
    "#ede15b",
    "#87bc45",
    "#27aeef",
    "#b33dc6",
    "#4e00ff",
]

# fill colours per combination of strain categories sharing an antigen
# (ds, vs, css, lss are the category abbreviations used in the keys)
ag_colors = dict(
    shared="#000000",      # black
    ds_vs_css="#6e6c02",   # mustard
    ds_vs_lss="#093303",   # dark green
    ds_css_lss="#00c1f3",  # cyan
    vs_css_lss="#d399d8",  # light pink
    ds_vs="#331d03",       # brown
    ds_css="#584e82",      # mauve
    ds_lss="#9fd600",      # green
    vs_css="#cf0030",      # red
    vs_lss="#7eb4f2",      # baby blue
    css_lss="#6d21b8",     # purple
    ds="#fcd74e",          # yellow
    vs="#ffa600",          # orange
    css="#d900ba",         # pink
    lss="#0051ff",         # dark blue
)

# switch for the "colored by region and season" part of the visualization script
color_seasons = True

with open("../scripts/visualize_antigenic_maps.R", "w") as fw:
    
    fw.write("library(Racmacs)\n")
    fw.write("library(ggplot2)\n")
    fw.write("options(RacOptimizer.num_cores = 10)\n\n")
    
    fw.write("setwd('~/Desktop/later-strain-selection')\n\n")

    fw.write('dir <- "./data/HI_data/individual_HI_tables/"\n')
    fw.write('acmap_dir <- "./analysis/antigenic_comparison/antigenic_maps/maps"\n')
    fw.write('coords_dir <- "./analysis/antigenic_comparison/antigenic_maps/coords"\n')
    fw.write('table_fig_dir <- "./figures/antigenic_maps/colored_by_table"\n')
    fw.write('strain_fig_dir <- "./figures/antigenic_maps/colored_by_strain"\n\n')

    fw.write("#load individual maps\n")
    for mapcode in map_codes.values():
        fw.write(f"map{mapcode} <- read.acmap(file.path(acmap_dir, 'map{mapcode}.ace'))\n")
        
    #check quality of maps
    fw.write ("\n# check if maps are good")
    for mapcode in map_codes.values():
        fw.write(f"\ncheckHemisphering(map{mapcode})")
    
    for tl, mapcode in map_codes.items():

        fw.write(f"\n#################### MAP {mapcode} ####################\n")
        fw.write("#### colered by tables\n")
        #start with all antigens grey
        fw.write(f"agSize(map{mapcode}) <- 5\n")
        fw.write(f"agFill(map{mapcode}) <- 'grey50'\n")

        #assign antigens
        tl = tl.split(", ")
        if len(tl) >1:
            for i, table in enumerate(tl):
                table_antigens = HI_titers[HI_titers["table"]==table]["antigen"].unique().tolist()
                non_table_antigens = HI_titers[HI_titers["table"].isin([t for t in tl if t!=table])]["antigen"].unique().tolist()
                table = table.replace("-","_")

                unique_table_antigens = [ta for ta in table_antigens if ta not in non_table_antigens]
                write_antigens(fw, mapcode, unique_table_antigens, f"ag_{table}")
            
            #color antigens
            fw.write("\n")
            for i, table in enumerate(tl):
                table = table.replace("-","_")

                fw.write(f"agFill(map{mapcode})[ag_{table}] <- '{table_colors[i]}'\n") 

        fw.write(f"p_map{mapcode} <- ggplot(map{mapcode}) + ggtitle('map {mapcode} ({', '.join(tl)})')\n")
        fw.write(f"p_map{mapcode}_png <- file.path(table_fig_dir, 'map{mapcode}.png')\n")
        fw.write(f"png(file=p_map{mapcode}_png)\n")
        fw.write(f"p_map{mapcode}\n")
        fw.write("dev.off()\n\n")

        if color_seasons:

            fw.write("#### colered by region and season\n")

            df = selected_HI_antigens.reset_index()
            df = df[df["map code"]==mapcode].set_index(["region", "season"]).sort_index()

            for (region, season) in df.index.unique():
                sdf = df.loc[(region, season), :].set_index(["strain"]).sort_index()
                fw.write(f"\n####### {region}, {season} #########\n")
                #start with all antigens grey
                fw.write(f"agSize(map{mapcode}) <- 5\n")
                fw.write(f"agFill(map{mapcode}) <- 'grey50'\n")

                #get antigens of the strains
                ds_ag = sdf.loc["dominant strain", "selected antigens"].split(", ")
                ds_ag = HI_titers[(HI_titers["strain"].isin(ds_ag))&(HI_titers["table"].isin(tl))]["antigen"].unique().tolist()
                vs_ag = sdf.loc["vaccine strain", "selected antigens"].split(", ")
                vs_ag = HI_titers[(HI_titers["strain"].isin(vs_ag))&(HI_titers["table"].isin(tl))]["antigen"].unique().tolist()
                vs_ag = [ag.replace(",", "") for ag in vs_ag]
                css_ag = sdf.loc["reproducible selection at WHO-timing", "selected antigens"].split(", ")
                css_ag = HI_titers[(HI_titers["strain"].isin(css_ag))&(HI_titers["table"].isin(tl))]["antigen"].unique().tolist()
                lss_ag = sdf.loc["reproducible selection at delayed timing", "selected antigens"].split(", ")
                lss_ag = HI_titers[(HI_titers["strain"].isin(lss_ag))&(HI_titers["table"].isin(tl))]["antigen"].unique().tolist()


                #get possible combinations
                shared_ag = [ag for ag in ds_ag if ag in vs_ag and ag in css_ag and ag in lss_ag]
                ds_vs_css_ag = [ag for ag in ds_ag if ag in vs_ag and ag in css_ag and ag not in lss_ag]
                ds_vs_lss_ag= [ag for ag in ds_ag if ag in vs_ag and ag not in css_ag and ag in lss_ag]
                ds_css_lss_ag = [ag for ag in ds_ag if ag not in vs_ag and ag in css_ag and ag in lss_ag]
                vs_css_lss_ag = [ag for ag in vs_ag if ag not in ds_ag and ag in css_ag and ag in lss_ag]
                ds_vs_ag = [ag for ag in ds_ag if ag in vs_ag and ag not in css_ag and ag not in lss_ag]
                ds_css_ag = [ag for ag in ds_ag if ag not in vs_ag and ag in css_ag and ag not in lss_ag]
                ds_lss_ag = [ag for ag in ds_ag if ag not in vs_ag and ag not in css_ag and ag in lss_ag]
                vs_css_ag = [ag for ag in vs_ag if ag not in ds_ag and ag in css_ag and ag not in lss_ag]
                vs_lss_ag = [ag for ag in vs_ag if ag not in ds_ag and ag not in css_ag and ag in lss_ag]
                css_lss_ag = [ag for ag in css_ag if ag not in ds_ag and ag not in vs_ag and ag in lss_ag]

                ds_u_ag = [ag for ag in ds_ag if ag not in vs_ag and ag not in css_ag and ag not in lss_ag]
                vs_u_ag = [ag for ag in vs_ag if ag not in ds_ag and ag not in css_ag and ag not in lss_ag]
                css_u_ag = [ag for ag in css_ag if ag not in ds_ag and ag not in vs_ag and ag not in lss_ag]
                lss_u_ag = [ag for ag in lss_ag if ag not in ds_ag and ag not in vs_ag and ag not in css_ag]

                ag_dict = {"shared":shared_ag, "ds_vs_css":ds_vs_css_ag, "ds_vs_lss":ds_vs_lss_ag, "ds_css_lss":ds_css_lss_ag, "vs_css_lss":vs_css_lss_ag,
                        "ds_vs":ds_vs_ag, "ds_css":ds_css_ag, "ds_lss":ds_lss_ag, "vs_css":vs_css_ag, "vs_lss":vs_lss_ag,
                        "css_lss":css_lss_ag, "ds":ds_u_ag, "vs":vs_u_ag, "css":css_u_ag, "lss":lss_u_ag}

                for ag_name, ag_list in ag_dict.items():
                    if len(ag_list) == 0:
                        continue
                    write_antigens(fw, mapcode, ag_list, ag_name)

                fw.write("\n")
                for ag_name, ag_list in ag_dict.items():
                    if len(ag_list) == 0:
                        continue
                    fw.write(f"agSize(map{mapcode})[{ag_name}] <- 8\n")
                    fw.write(f"agFill(map{mapcode})[{ag_name}] <- '{ag_colors[ag_name]}'\n") 

                li = list(set(ds_ag+vs_ag+lss_ag+css_ag))
                fw.write(f"ptDrawingOrder(map{mapcode}) <- c(seq_len(numSera(map{mapcode})) + numAntigens(map{mapcode}),\n")
                if len(li) <= 10:
                    fw.write(f"    which(!agNames(map{mapcode}) %in% c{tuple(li)}),\n")
                    fw.write(f"    which(agNames(map{mapcode}) %in% c{tuple(li)})\n")
                
                else:
                    
                    l = ','.join([f"'{ag}'" for ag in li[0:10]])
                    fw.write(f"\twhich(!agNames(map{mapcode}) %in% c({l},\n")
                    if len(range(10, len(li)-10, 10)) > 0:
                        for i in range(10, len(li)-10, 10):
                            l = ','.join([f'"{ag}"' for ag in li[i:i+10]])
                            if i > len(li)-10:
                                fw.write(f"\t{l})),\n")
                            else:
                                fw.write(f"\t{l},\n")
                        if i < len(li)-10:
                            l = ','.join([f'"{ag}"' for ag in li[i+10:]])
                            fw.write(f"\t{l})),\n")
                    else:
                        l = ','.join([f'"{ag}"' for ag in li[10:]])
                        fw.write(f"\t{l})),\n")


                    l = ','.join([f"'{ag}'" for ag in li[0:10]])
                    fw.write(f"\twhich(agNames(map{mapcode}) %in% c({l},")
                    if len(range(10, len(li)-10, 10)) > 0:
                        for i in range(10, len(li)-10, 10):
                            l = ','.join([f'"{ag}"' for ag in li[i:i+10]])
                            if i > len(li)-10:
                                fw.write(f"\t{l}))\n")
                            else:
                                fw.write(f"\t{l},")
                        if i < len(li)-10:
                            l = ','.join([f'"{ag}"' for ag in li[i+10:]])
                            fw.write(f"\t{l}))\n")
                    else:
                        l = ','.join([f'"{ag}"' for ag in li[10:]])
                        fw.write(f"\t{l}))\n")
                fw.write(")\n")

                fw.write(f"p_map{mapcode}_{region}_{season.replace('-','_')} <- ggplot(map{mapcode}) + ggtitle('{region} {season}  (map {mapcode})')\n")
                fw.write(f"p_map{mapcode}_{region}_{season.replace('-','_')}\n\n")
                fw.write(f"p_map{mapcode}_{region}_{season.replace('-','_')}_png <- file.path(strain_fig_dir, 'map{mapcode}_{region}_{season.replace('-','_')}.png')\n")
                fw.write(f"png(file=p_map{mapcode}_{region}_{season.replace('-','_')}_png)\n")
                fw.write(f"p_map{mapcode}_{region}_{season.replace('-','_')}\n")
                fw.write("dev.off()\n\n")


### 4.4. Remove low quality maps

In [35]:
# Map codes of the antigenic maps flagged as low quality (section 4.4);
# any map whose code is listed here is skipped when centroid distances are computed.
low_quality_maps = [
    7, 8, 9, 10,
    33, 59,
    71, 72, 86,
]

In [36]:
# Display the manual HI table selection with its index levels moved back into
# regular columns (the result is only shown, not stored).
manual_HI_table_selection.reset_index()

Unnamed: 0,region,season,table list,map code
0,aunz,2003,"2003_table5, 2003_table6",1.0
1,aunz,2003,2003_table5,2.0
2,aunz,2003,2003_table6,3.0
3,aunz,2004,,
4,aunz,2005,,
...,...,...,...,...
200,us,2020-2021,"Sep2021_table7-4, Sep2020_table7-5",160.0
201,us,2021-2022,"Feb2021_table7-2, Feb2020_table7-5, Sep2022_ta...",161.0
202,us,2022-2023,"Sep2021_table7-6, Sep2022_table10-19",125.0
203,us,2022-2023,"Sep2021_table7-6, Sep2022_table10-17, Sep2022_...",162.0


In [37]:
# Interactive helper: for the first (region, season) that has no manual table
# selection yet, print per HI table how many antigen strains are available for
# the dominant strain (ds), the vaccine strain (vs) and the two reproducible
# selections (css = WHO timing, lss = delayed timing).

# The manual selection does not change inside the loop, so collect its keys
# once instead of re-sorting and re-listing the index on every iteration.
manually_selected = set(manual_HI_table_selection.index.to_list())

for (region, season) in closest_HI_strains.index.unique():
    if (region, season) in manually_selected:
        continue
    print (f"('{region}', '{season}')")

    sdf = closest_HI_strains.loc[(region, season), :].set_index("strain")
    #all strains that are listed as HI antigen for any of the four strains
    HI_antigen_strains = [s for sl in sdf["HI antigens"].tolist() for s in sl.split(", ")]
    HI_antigen_tables = HI_titers[HI_titers["strain"].isin(HI_antigen_strains)][["table", "type", "strain"]].drop_duplicates(ignore_index=True)

    for t in HI_antigen_tables["type"].unique():
        tdf = HI_antigen_tables[HI_antigen_tables["type"]==t]

        #determine antigens for each strain
        ds = tdf[tdf["strain"].isin(sdf.loc["dominant strain", "HI antigens"].split(", "))]["strain"].unique().tolist()
        vs = tdf[tdf["strain"].isin(sdf.loc["vaccine strain", "HI antigens"].split(", "))]["strain"].unique().tolist()
        css = tdf[tdf["strain"].isin(sdf.loc["reproducible selection at WHO-timing", "HI antigens"].split(", "))]["strain"].unique().tolist()
        lss = tdf[tdf["strain"].isin(sdf.loc["reproducible selection at delayed timing", "HI antigens"].split(", "))]["strain"].unique().tolist()

        #skip this type as soon as one of the four strains has no antigens
        if not (ds and vs and css and lss):
            continue

        #for each strain determine in which tables its antigens occur;
        #tables without any antigen of a strain are left out of that strain's dict
        ds_tables, vs_tables, css_tables, lss_tables = {}, {}, {}, {}
        for tb in tdf["table"].unique():
            ts = tdf[tdf["table"]==tb]["strain"].unique().tolist()

            for wanted, per_table in ((ds, ds_tables), (vs, vs_tables), (css, css_tables), (lss, lss_tables)):
                present = [s for s in ts if s in wanted]
                if present:
                    per_table[tb] = present

        #make an overview of the tables and the number of strains in each table;
        #tables are numbered in sorted order so that "tn" is reproducible
        #(iterating a set of strings directly gives a run-dependent order)
        tables = sorted(set(ds_tables) | set(vs_tables) | set(css_tables) | set(lss_tables))
        table_overview = []
        for i, tb in enumerate(tables, start=1):
            table_overview.append([i, tb,
                                   len(ds_tables.get(tb, [])), len(vs_tables.get(tb, [])),
                                   len(css_tables.get(tb, [])), len(lss_tables.get(tb, []))])

        table_overview = pd.DataFrame.from_records(table_overview, columns=["tn", "table", "ds", "vs", "css", "lss"])

        print (f"ds: {len(ds)}, vs: {len(vs)}, css: {len(css)}, lss: {len(lss)} ")
        table_overview = table_overview.set_index("table").sort_index()
        table_overview.index.name = None
        print(table_overview[["ds", "vs", "css", "lss"]].to_string())

        #keep the strain lists of the last processed type under separate names
        ds_, vs_, css_, lss_ = ds, vs, css, lss

    #only the first (region, season) without a manual selection is inspected per run
    break


## 5. Get centroid distances

In [38]:
# Index selected_HI_antigens by (region, season, map code) and sort it.
try:
    selected_HI_antigens = selected_HI_antigens.set_index(["region", "season", "map code"]).sort_index()
except KeyError:
    # set_index raises KeyError when a key is not a column, which happens when
    # the keys are already part of the index (e.g. the cell was run before);
    # move the index back into columns first. Any other error is not swallowed.
    selected_HI_antigens = selected_HI_antigens.reset_index().set_index(["region", "season", "map code"]).sort_index()

In [39]:
# For every usable antigenic map, compute the distance between the centroid of
# the dominant-strain antigens and the centroids of the vaccine-strain and the
# two reproducibly selected strains' antigens.

def _antigens_for(strains, tables):
    """Return the unique antigen names in HI_titers for `strains` within `tables`."""
    in_scope = (HI_titers["table"].isin(tables))&(HI_titers["strain"].isin(strains))
    return HI_titers[in_scope]["antigen"].unique().tolist()


def _cluster(coords, antigens):
    """Return (coordinate rows, centroid) of the map points named in `antigens`."""
    points = coords[coords.index.isin(antigens)].values.tolist()
    return points, centroid(points)


centroid_distances = []
coord_dir = "../analysis/antigenic_comparison/antigenic_maps/coords"

for (region, season, mapcode) in selected_HI_antigens.index.unique():
    #skip low quality maps and rows without a map
    if mapcode in low_quality_maps or pd.isna(mapcode):
        continue
    df = selected_HI_antigens.loc[(region, season, mapcode),].set_index("strain").sort_index()
    tl = df["table list"].values[0].split(", ")

    #get antigens of the strains
    ds = df.loc["dominant strain", "selected antigens"].split(", ")
    ds_ag = _antigens_for(ds, tl)
    vs = df.loc["vaccine strain", "selected antigens"].split(", ")
    vs_ag = _antigens_for(vs, tl)
    #commas are removed from the vaccine strain antigen names only
    #NOTE(review): presumably to match the names in the coordinate files - confirm
    vs_ag = [ag.replace(",", "") for ag in vs_ag]
    css = df.loc["reproducible selection at WHO-timing", "selected antigens"].split(", ")
    css_ag = _antigens_for(css, tl)
    lss = df.loc["reproducible selection at delayed timing", "selected antigens"].split(", ")
    lss_ag = _antigens_for(lss, tl)

    #get coordinates (the header row of the file is replaced by these column names)
    coords = pd.read_csv(os.path.join(coord_dir, f"map{int(mapcode)}.csv"), names=["type", "antigen", "X", "Y"], header=0)
    #filter antigens
    coords = coords[coords['type']=="antigen"].set_index("antigen").sort_index().drop(columns="type")

    #get the centroid of each strain's antigen cluster
    ds_coords, ds_centroid = _cluster(coords, ds_ag)
    vs_coords, vs_centroid = _cluster(coords, vs_ag)
    css_coords, css_centroid = _cluster(coords, css_ag)
    lss_coords, lss_centroid = _cluster(coords, lss_ag)

    #get distances from the dominant strain centroid to the other centroids
    ds2vs = cdist([ds_centroid], [vs_centroid])[0][0]
    ds2css = cdist([ds_centroid], [css_centroid])[0][0]
    ds2lss = cdist([ds_centroid], [lss_centroid])[0][0]

    centroid_distances.append([region, season, mapcode, ds2vs, ds2css, ds2lss, len(ds_coords), len(vs_coords), len(css_coords), len(lss_coords)])

centroid_distances = pd.DataFrame.from_records(centroid_distances, columns=["region", "season", "map code", "DS to VS", "DS to CSS", "DS to LSS", "DS cluster size", 
                                                                            "VS cluster size", "CSS cluster size", "LSS cluster size"]).set_index(["region", "season"]).sort_index()

#add an all-NA row for every (region, season) without any map, so that all
#seasons are present in the output; every (region, season) is visited once, so
#the set of known keys can be built a single time instead of per iteration
known_seasons = set(centroid_distances.index.tolist())
n_columns = len(centroid_distances.columns)
for region, h in region_hemispheres.items():
    for season in flu_seasons[h]:
        if (region, season) in known_seasons:
            continue
        centroid_distances.loc[(region, season), :] = [pd.NA]*n_columns
centroid_distances = centroid_distances.sort_index()

  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)
  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)
  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)
  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)
  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)
  centroid_distances.loc[(region, season), :] = [pd.NA]*len(centroid_distances.columns)


In [40]:
# Store the centroid distances, with region and season written as ordinary columns.
centroid_distances_flat = centroid_distances.reset_index()
centroid_distances_flat.to_csv(
    "../analysis/antigenic_comparison/centroid_distances.csv",
    index=False,
)

In [41]:
# Summarise the centroid distances per (region, season): the median and the
# 25% / 75% quantiles of each comparison (long format) and the medians only
# (wide format).
comparison_columns = ["DS to VS", "DS to CSS", "DS to LSS"]
median_centroid_distances, summary_centroid_distances = [], []

for (region, season) in centroid_distances.index.unique():
    # rows/values containing NA are discarded before summarising
    df = centroid_distances.loc[(region, season)].dropna()

    if len(df) == 0:
        # nothing left for this season: keep it in both tables with NA values
        summary_centroid_distances.append([region, season,pd.NA,pd.NA,pd.NA])
        for c in comparison_columns:
            median_centroid_distances.append([region, season,c, pd.NA,pd.NA,pd.NA])
        continue

    sd = []
    for c in comparison_columns:
        values = df[c]
        median = round(np.median(values),3)
        l_iqr = round(np.quantile(values, 0.25),3)
        u_iqr = round(np.quantile(values, 0.75),3)

        median_centroid_distances.append([region, season, c, median, l_iqr, u_iqr])
        sd.append(median)

    summary_centroid_distances.append([region, season, *sd])

median_centroid_distances = pd.DataFrame.from_records(
    median_centroid_distances,
    columns=["region", "season", "comparison", "median", "lower IQR", "upper IQR"],
)
median_centroid_distances = median_centroid_distances.set_index(["region", "season"]).sort_index()

summary_centroid_distances = pd.DataFrame.from_records(
    summary_centroid_distances,
    columns=["region", "season", "DS to VS", "DS to CSS", "DS to LSS"],
)
summary_centroid_distances = summary_centroid_distances.set_index(["region", "season"]).sort_index()