In [None]:
from Bio import SeqIO
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import time
import datetime
from copy import deepcopy
import os
#import Levenshtein

In [None]:
import sys

# Log the interpreter version so plotted results can be tied to an environment.
version_string = sys.version
print(version_string)

In [None]:
# Input locations. Set the GENBANK_NS_DATADIR environment variable to point at
# another copy of the data without editing the notebook; os.path.join keeps
# the paths valid whether or not that directory ends with a separator.
datadir = os.environ.get("GENBANK_NS_DATADIR", "/storage/user/GenBank/NS_full/")
fastafile = os.path.join(datadir, "genbank_full_aligned.fasta")
tsvfile = os.path.join(datadir, "genbank_full.tsv")

In [None]:
# Load the tab-separated metadata table; the cells below use its
# "accession", "host", "serotype", "country" and "date" columns.
df = pd.read_table(tsvfile, header=0)

In [None]:
# Serotype column restricted to avian-host rows, selected once and reused.
avian_serotypes = df.loc[df["host"] == "Avian", "serotype"]
uniq_ss = avian_serotypes.unique()
print(uniq_ss)
for stype in uniq_ss:
    print(f'There are {np.sum(avian_serotypes == stype)} avian sequences of type {stype}')

In [None]:
def naive_ham(ss1, ss2, gap="-"):
    """Count mismatching positions between two aligned sequences, ignoring gaps.

    A position contributes 1 when the two characters differ and neither of
    them is the gap symbol; positions where either sequence has a gap are
    skipped entirely.

    Parameters
    ----------
    ss1, ss2 : str
        Aligned sequences of equal length.
    gap : str, optional
        Gap symbol to ignore (default ``"-"``).

    Returns
    -------
    int
        Number of gap-free mismatches.

    Raises
    ------
    ValueError
        If the sequences differ in length, i.e. they are not aligned. Without
        this check the comparison would silently cover only a prefix.
    """
    if len(ss1) != len(ss2):
        raise ValueError(
            f"sequences must be aligned to equal length, got {len(ss1)} and {len(ss2)}"
        )
    return sum(1 for a, b in zip(ss1, ss2) if a != b and a != gap and b != gap)

In [None]:
# Parse the aligned FASTA and collect, per record, the sequence plus its
# metadata from `df`. The lists below are index-aligned: entry k of each list
# belongs to the same FASTA record. Records without a parseable year, or
# without a metadata row, are skipped.
seqs = []
ids = []
dates = []      # year (int) parsed from the metadata "date" field
hosts = []
serotypes = []
georegion = []  # "Americas" or "Other", derived from the metadata "country" field

AmericasCountries = ['USA', 'Brazil', 'Chile', 'Peru', 'Argentina', 'Colombia', 'Venezuela', 'El Salvador',
                     'Ecuador', 'Guam', 'Nicaragua', 'Mexico', 'Uruguay',
                     'Panama', 'Puerto Rico', 'Bolivia', 'Dominican Republic','Honduras']

# Build a hash index on accession once instead of scanning the whole
# DataFrame for every FASTA record. keep="first" means the first matching
# metadata row is used when an accession appears more than once.
meta_by_acc = df.drop_duplicates(subset="accession", keep="first").set_index("accession")

count = 0

for record in SeqIO.parse(fastafile, "fasta"):
    accession = record.id
    if "|" in accession:
        # Pipe-delimited headers carry the accession in field 1.
        accession = accession.split("|")[1]
    # Reset per record so the progress line never reports a stale date.
    date = None
    if accession not in meta_by_acc.index:
        print(f"Skipping {accession}: no metadata row!!")
    else:
        row = meta_by_acc.loc[accession]
        raw_date = row["date"]
        if not pd.isna(raw_date):
            try:
                # Keep the part before the first "/" as the year; str() also
                # covers a date column that pandas parsed as plain integers.
                date = int(str(raw_date).split("/")[0])
            except ValueError:
                date = None
        if date is None:
            print("Skipping due to missing date!!")
        else:
            dates.append(date)
            seqs.append(record.seq)
            ids.append(accession)
            hosts.append(row["host"])
            serotypes.append(row["serotype"])
            georegion.append("Americas" if row["country"] in AmericasCountries else "Other")
    print(count, ":", accession, f"\t ({record.id}), date: {date}")
    count += 1

In [None]:
# Locate AF333238 (H1N1, 1918, Brevig) among the loaded accession ids.
ids_arr = np.asarray(ids)
np.nonzero(ids_arr == "AF333238")

In [None]:
# Prepare alpha values so as to represent years with different sampling density
# in a legible manner.
def compute_alphas(years_out, exponent=0.7, min_alpha=0.1):
    """Per-point scatter alpha that fades densely sampled years.

    Each point gets ``max(min_alpha, n ** -exponent)``, where ``n`` is the
    number of entries in ``years_out`` that share that point's year. A year
    with a single sequence is fully opaque, and heavily sampled years are
    faded down to ``min_alpha``.

    Parameters
    ----------
    years_out : array-like of int
        Year of each point (1-D).
    exponent : float, optional
        Fading exponent (default 0.7).
    min_alpha : float, optional
        Lower bound on the returned alpha (default 0.1).

    Returns
    -------
    numpy.ndarray of float
        One alpha per entry of ``years_out``, in the same order. Empty for
        empty input.
    """
    years = np.asarray(years_out)
    if years.size == 0:
        # A host/serotype group can be empty; np.min/np.max would raise on it.
        return np.array([], dtype=float)
    # Count points per distinct year and map each point back to its year's
    # count. Every year, including the latest one, gets its true count.
    _, inverse, counts = np.unique(years, return_inverse=True, return_counts=True)
    per_point_counts = counts[inverse.reshape(-1)].astype(float)
    return np.maximum(min_alpha, per_point_counts ** -exponent)

In [None]:
def get_hams_to_ref(i0):
    """Gap-ignoring Hamming distance from sequence ``i0`` to every loaded sequence.

    Reads the module-level, index-aligned lists ``seqs``, ``ids``, ``dates``,
    ``hosts``, ``serotypes`` and ``georegion`` built by the FASTA-parsing
    cell. Progress is printed every 10000 sequences.

    Parameters
    ----------
    i0 : int
        Index (into ``seqs``) of the reference sequence.

    Returns
    -------
    tuple of five lists, each of length ``len(seqs)`` and aligned with ``seqs``
        ``(years, distances_to_reference, hosts, serotypes, georegions)``.
        The metadata lists are fresh copies, so callers may modify them
        without affecting the globals.
    """
    print(f"Rooting in {ids[i0]}, {serotypes[i0]}, {dates[i0]}")

    # Convert the reference once rather than once per comparison.
    ref_str = str(seqs[i0])
    dists_out = []
    for i, seq in enumerate(seqs):
        dists_out.append(naive_ham(ref_str, str(seq)))
        if i % 10000 == 0:
            print(i)
    print(f"Done (ref: {ids[i0]})")

    n = len(seqs)
    return dates[:n], dists_out, hosts[:n], serotypes[:n], georegion[:n]

In [None]:
# NOTE(review): this cell reads years_out / dists_out / serotypes_out /
# georegions_out / hosts_out, which are only produced by get_hams_to_ref(...)
# further down the notebook, so it fails on a fresh kernel run top to bottom.
ys = np.array(years_out)
ds = np.array(dists_out)
ss = np.array(serotypes_out)
gs = np.array(georegions_out)
hs = np.array(hosts_out)

# Different masks can be used to search for sequences
# matching certain criteria:

#mask = (ys == 1918) * (ss == "H1N1") * (hs == "Human")
#mask = (ys == 2019) * (ss == "H1N1") * (hs == "Human") * (ds < 30) * (ds > 25)
#mask = (ys == 1950) * (ss == "H1N1") * (hs == "Human") * (ds < 50) * (ds > 25)
#mask = (ys == 1965) * (ss == "H2N2") * (hs == "Human") * (ds < 75) * (ds > 50)
#mask = (ys == 1973) * (ss == "H3N2") * (hs == "Human") * (ds < 80) * (ds > 60)
#mask = (ys == 1977) * (ss == "H1N1") * (hs == "Human") * (ds < 50) * (ds > 25)
#mask = (ys == 2005) * (ss == "H3N2") * (hs == "Human") * (ds < 100) * (ds > 75)
#mask = (ys == 2019) * (ss == "H3N2") * (hs == "Human") * (ds < 110) * (ds > 90)
#mask = (ys == 2009) * (ss == "H1N1") * (hs == "Human") * (ds < 85) * (ds > 70) # non-pdm
#mask = (ys == 2009) * (ss == "H1N1") * (hs == "Human") * (ds < 140) * (ds > 115) # pdm

#mask = (ys == 2002) * (ss == "H1N1") * (hs == "Human") * (ds < 15) * (ds >= 0) # pdm

mask = (ys == 1918)

# Compute the matching indices once instead of re-running np.where on every access.
selected = np.where(mask)[0]
for idx in selected:
    print(f"Rooting in idx={idx}, {ids[idx]}, {serotypes[idx]}, {dates[idx]}")

print("All of them:", selected)

In [None]:
this_host = "Avian"
# Serotype labels of the sequences from this host, selected once and reused.
host_serotypes = ss[hs == this_host]
uniq_ss = np.unique(host_serotypes)
print(uniq_ss)
for stype in uniq_ss:
    n_loc = np.sum(host_serotypes == stype)
    # Serotypes with more than 1000 sequences are flagged with a "---" prefix.
    if n_loc > 1000:
        message = f"--- There are {n_loc} {this_host} sequences of type {stype}"
    else:
        message = f"Only {n_loc} {this_host} sequences of type {stype}"
    print(message)

In [None]:
# Let pandas render up to 500 rows, then show every H5N1 metadata row.
pd.set_option("display.max_rows", 500)
df.loc[df["serotype"] == "H5N1"]

## Distance-to-reference scatter plot (point alpha scaled by per-year sampling density):

In [None]:
with_avian = True
with_other = True

print("Reminder: Did you wish to recompute Hamming distances?") 

# De-comment this line to recompute Hamming distances,
# rooted in a sequence with index i0:
#years_out, dists_out, hosts_out, serotypes_out, georegions_out = get_hams_to_ref(i0)
# NOTE(review): otherwise this cell relies on years_out/dists_out/... and i0
# already being in the kernel from an earlier get_hams_to_ref call.

ys = np.array(years_out)
ds = np.array(dists_out)
ss = np.array(serotypes_out)
gs = np.array(georegions_out)
hs = np.array(hosts_out)

years_out = np.array(years_out)
dists_out = np.array(dists_out)
hosts_out = np.array(hosts_out)
serotypes_out = np.array(serotypes_out)

fontsize = 10

# 'sans-serif' is Matplotlib's generic family name; 'sans' is not a recognised
# family and only produces a font-lookup warning.
font = {'family': 'sans-serif',
        'weight': 'normal',
        'size': fontsize}

# Matplotlib 3.6 renamed the bundled seaborn style to 'seaborn-v0_8' and 3.8
# removed the old name; fall back to 'seaborn' on older versions.
mystyle = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(mystyle)
plt.rc('font', **font)

fig, ax1 = plt.subplots(1, 1, dpi=170, figsize=(6, 3.8))

# Serotype labels treated as belonging to each subtype.
H1N1syns = ["H1N1", "H1", "N1"]
H3N2syns = ["H3N2", "H3", "N2"]
H5N1syns = ["H5N1"]
H7N9syns = ["H7N9"]
# Human serotypes that are NOT lumped into "Human other".
knownsum = H1N1syns + H3N2syns + H5N1syns + ["H2N2"]

is_human = hosts_out == "Human"
is_swine = hosts_out == "Swine"
is_americas = np.isin(georegions_out, ["Americas"])

# One entry per plotted series, in drawing order:
# (legend label, boolean mask, alpha multiplier, extra scatter kwargs).
groups = []
if with_avian:
    groups.append(("Avian", hosts_out == "Avian", 0.5, {"color": "blue"}))
groups += [
    ("Swine (Americas)", is_swine & is_americas, 1.0, {"marker": "_", "color": "lawngreen"}),
    ("Swine (Other)", is_swine & ~is_americas, 1.0, {"marker": "_", "color": "green"}),
    ("Human H3N2", is_human & np.isin(serotypes_out, H3N2syns), 1.0, {"color": "orange"}),
    ("Human H2N2", is_human & np.isin(serotypes_out, ["H2N2"]), 1.0, {"color": "yellow"}),
    ("Human H1N1", is_human & np.isin(serotypes_out, H1N1syns), 1.0, {"color": "red", "zorder": 2}),
    ("Human other", is_human & ~np.isin(serotypes_out, knownsum), 1.0, {"color": "cyan"}),
]
if with_other:
    if with_avian:
        groups.append(("Other hosts", ~np.isin(hosts_out, ["Human", "Swine", "Avian"]), 0.3,
                       {"color": "black"}))
    else:
        groups.append(("Other hosts", ~np.isin(hosts_out, ["Human", "Swine"]), 0.3,
                       {"color": "black", "zorder": 0.9}))

for label, mask, alpha_scale, style_kwargs in groups:
    # Per-point alpha fades densely sampled years so sparse years stay visible.
    alphas = compute_alphas(years_out[mask])
    ax1.scatter(years_out[mask], dists_out[mask], s=4, alpha=alpha_scale * alphas,
                label=label, **style_kwargs)

ax1.set_ylim([-5, 200])
ax1.set_ylabel("Nucleotide distance")
ax1.set_xlabel("Year")
ax1.set_title(f"Ref: {ids[i0]}, {serotypes_out[i0]}, {years_out[i0]}, {hosts_out[i0]}") 

leg = ax1.legend(markerscale=3)
# Draw legend markers fully opaque regardless of the per-point alpha.
# `legendHandles` became `legend_handles` in Matplotlib 3.7 (old name removed in 3.9).
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

dt_string = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outfile = f"ref_{ids[i0]}_{years_out[i0]}_{serotypes_out[i0]}_{hosts_out[i0]}_{dt_string}"
if with_avian:
    outfile += "_incl_avian"
outfile += ".png"
print(outfile)
outdir = "./"
fig.savefig(os.path.join(outdir, outfile))


In [None]:
# Percentage of sequences whose host is Human, Swine or Avian
# (the complement of the "Other hosts" series).
mask = ~np.isin(hosts_out, ["Human", "Swine", "Avian"])
100 * (1 - mask.sum() / len(hosts_out))

## Data for multi-panel plot:

In [None]:
## Data for multi-panel plot:
# Indices into `seqs` of the reference sequence for each panel:
#   2998 -> CY015056, Avian, H7N7, 1902, dist=37
#   105  -> Brevig sequence
#   6856 -> FJ966086, H1N1, 2009, pdm
# The three 1902 candidates were:
#   CY015056, Avian, H7N7, 1902, dist=37
#   idx=8728, GU186781, Avian, H7N7, 1902, dist=38
#   L37798, Avian, H7N7, 1902, dist=43
# NOTE(review): these indices depend on FASTA record order and on which records
# were skipped while parsing; re-check them against `ids` if the input changes.
i0s = [2998, 105, 6856]
datas = []
for i0 in i0s:
    # Each entry is the 5-tuple returned by get_hams_to_ref for that reference.
    datas += [get_hams_to_ref(i0)]

## Tri-panel plot (1902, 1918, 2009) for supplement:

In [None]:
# One distance-vs-year scatter panel per reference in `datas` (supplement figure).
# NOTE(review): reads `with_avian` / `with_other` from the single-reference plot cell.
datas_here = datas
i0s_here = i0s
n_panels = len(datas_here)
# squeeze=False keeps `axes` indexable when only one reference is plotted;
# otherwise plt.subplots returns a bare Axes and axes[i] fails.
fig, axes = plt.subplots(1, n_panels, dpi=170, figsize=(6*n_panels, 3.8), squeeze=False)
axes = axes[0]

print("axes.shape:", axes.shape)

# Serotype labels treated as belonging to each subtype (same for every panel).
H1N1syns = ["H1N1", "H1", "N1"]
H3N2syns = ["H3N2", "H3", "N2"]
H5N1syns = ["H5N1"]
H7N9syns = ["H7N9"]
# Human serotypes that are NOT lumped into "Human other".
knownsum = H1N1syns + H3N2syns + H5N1syns + ["H2N2"]

for i in range(n_panels):
    years_out, dists_out, hosts_out, serotypes_out, georegions_out = datas_here[i]

    ys = np.array(years_out)
    ds = np.array(dists_out)
    ss = np.array(serotypes_out)
    gs = np.array(georegions_out)
    hs = np.array(hosts_out)

    years_out = np.array(years_out)
    dists_out = np.array(dists_out)
    hosts_out = np.array(hosts_out)
    serotypes_out = np.array(serotypes_out)

    panel_ax = axes[i]
    is_human = hosts_out == "Human"
    is_swine = hosts_out == "Swine"
    is_americas = np.isin(georegions_out, ["Americas"])

    # One entry per plotted series, in drawing order:
    # (legend label, boolean mask, alpha multiplier, extra scatter kwargs).
    groups = []
    if with_avian:
        groups.append(("Avian", hosts_out == "Avian", 0.5, {"color": "blue"}))
    groups += [
        ("Swine (Americas)", is_swine & is_americas, 1.0, {"marker": "_", "color": "lawngreen"}),
        ("Swine (Other)", is_swine & ~is_americas, 1.0, {"marker": "_", "color": "green"}),
        ("Human H3N2", is_human & np.isin(serotypes_out, H3N2syns), 1.0, {"color": "orange"}),
        ("Human H2N2", is_human & np.isin(serotypes_out, ["H2N2"]), 1.0, {"color": "yellow"}),
        ("Human H1N1", is_human & np.isin(serotypes_out, H1N1syns), 1.0, {"color": "red", "zorder": 2}),
        ("Human other", is_human & ~np.isin(serotypes_out, knownsum), 1.0, {"color": "cyan"}),
    ]
    if with_other:
        if with_avian:
            groups.append(("Other hosts", ~np.isin(hosts_out, ["Human", "Swine", "Avian"]), 0.3,
                           {"color": "black"}))
        else:
            groups.append(("Other hosts", ~np.isin(hosts_out, ["Human", "Swine"]), 0.3,
                           {"color": "black", "zorder": 0.9}))

    for label, mask, alpha_scale, style_kwargs in groups:
        # Per-point alpha fades densely sampled years so sparse years stay visible.
        alphas = compute_alphas(years_out[mask])
        panel_ax.scatter(years_out[mask], dists_out[mask], s=4, alpha=alpha_scale * alphas,
                         label=label, **style_kwargs)

    panel_ax.set_ylim([-5, 260])
    if i > 0:
        # All panels use the same y-limits; only the left-most keeps tick labels.
        panel_ax.yaxis.set_ticklabels([])
    ref = i0s_here[i]
    panel_ax.set_title(f"Ref: {ids[ref]}, {serotypes_out[ref]}, {years_out[ref]}, {hosts_out[ref]}")

fig.subplots_adjust(wspace=0.02, hspace=0)
leg = axes[0].legend(markerscale=3)
# Draw legend markers fully opaque regardless of the per-point alpha.
# `legendHandles` became `legend_handles` in Matplotlib 3.7 (old name removed in 3.9).
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

dt_string = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outfile = "multiplot" + "_" + dt_string
if with_avian:
    outfile += "_incl_avian"
outfile += ".png"
print(outfile)
outdir = "./"
fig.savefig(os.path.join(outdir, outfile), facecolor='white', transparent=False, bbox_inches='tight', pad_inches=0)

## Unrooted Hamming Distribution -- "Diversity plots"

### Get the data (set accession number):

In [None]:
## Get the data:
# Set the Accession number of the root sequence:
acc = "GQ166657" # Human 2009pdm
#acc = "OP458646" # Swine 2020

# Fail with a clear message instead of an opaque IndexError when the accession
# was never loaded (absent from the FASTA, or skipped during parsing).
if acc not in ids:
    raise ValueError(f"Accession {acc!r} is not among the loaded sequences (ids).")
i0 = ids.index(acc)
years_out, dists_out, hosts_out, serotypes_out, georegions_out = get_hams_to_ref(i0)

ys = np.array(years_out)
ds = np.array(dists_out)
ss = np.array(serotypes_out)
gs = np.array(georegions_out)
hs = np.array(hosts_out)

### Filter sequences and select year-by-year

In [None]:
# Filter sequences and select year-by-year.
# For each year, sample random pairs among host-filtered sequences close to the
# reference (distance <= ds_crit) and histogram their pairwise Hamming distances.
nx = 220             # histogram bins covering distances 0..nx-1
startyear = 2009
endyear = 2023
nyears = endyear - startyear + 1
host_filter = "Swine"
n_comps_max = 5000   # cap on the number of random pairs compared per year
min_selected = 5     # years with fewer sequences stay an all-zero row
# Distance-to-reference cutoff ramps linearly from 25 nt (2009) to 40 nt (2009 + 14).
ds_crit_lo, ds_crit_hi = 25, 40
ramp_ref_year, ramp_years = 2009, 14
# Fixed seed so the sampled pairs, and hence the heatmap, are reproducible.
rng = np.random.default_rng(0)

plotmat = np.zeros((nyears, nx))
for k, yr in enumerate(range(startyear, endyear + 1)):
    ds_crit = ds_crit_lo + abs(yr - ramp_ref_year) * (ds_crit_hi - ds_crit_lo) / ramp_years
    mask = (ys == yr) & (hs == host_filter) & (ds <= ds_crit)

    selected = np.where(mask)[0]
    print(yr, "selected", len(selected), "sequences at ds_crit =", ds_crit)
    if len(selected) < min_selected:
        continue
    ncomps = min(n_comps_max, len(selected)**2)
    # Stringify each candidate once rather than twice per sampled pair.
    seq_strs = {idx: str(seqs[idx]) for idx in selected}
    Dvec = np.zeros(ncomps)
    for c in range(ncomps):
        a, b = rng.choice(selected, size=2, replace=False)
        Dvec[c] = naive_ham(seq_strs[a], seq_strs[b])
    hist_n, hist_bins = np.histogram(Dvec, density=True, range=[0, nx-1], bins=nx)
    plotmat[k, :] = hist_n

### Plot the resulting Hamming heatmap

In [None]:
# Plot the resulting Hamming heatmap: one row per year, one column per distance bin.
save_heatmap = False  # set True to write the PNG named below
scalefactor = 0.75
fig, ax = plt.subplots(1, 1, dpi=275, figsize=(scalefactor*2, scalefactor*8))

# Log-scale the densities; the 0.02 offset keeps empty bins finite.
showmat = np.log(0.02 + plotmat)
cax = ax.imshow(showmat, interpolation='nearest', aspect='auto', cmap=plt.get_cmap('inferno'))

cbar = fig.colorbar(cax)

# One labelled y tick per year row.
years_range = range(startyear, endyear + 1)
ticklist = [yr - startyear for yr in years_range]
ticklabellist = [str(yr) for yr in years_range]
ax.set_yticks(ticklist)
ax.set_yticklabels(ticklabellist)

# Vertical colorbar: label only its two extremes.
cbar.ax.set_yticks([np.min(showmat), np.max(showmat)])
cbar.ax.set_yticklabels(['Low', 'High'])
#cbar.set_label('Frequency of pairwise distance (logarithmic)', rotation=270)

#ax.set_xlabel("Hamming distance (nucleotides)")

ax.grid(False)

dt_string = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
outfile = "hamdist" + "_" + dt_string + ".png"
print(outfile)
outdir = "./"
if save_heatmap:
    fig.savefig(os.path.join(outdir, outfile), facecolor='white', transparent=False, bbox_inches='tight', pad_inches=0)