In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import time
import datetime
import pickle
from copy import deepcopy
import re
import os

In [None]:
import sys

# Record which interpreter produced the results in this notebook.
python_version = sys.version
print(python_version)

In [None]:
# Directory holding the GenBank download, and the metadata table inside it.
# `datadir` keeps its trailing slash because paths are built by concatenation.
datadir = "/storage/user/GenBank/"
tsvfile = f"{datadir}metadata_dec01.tsv"

In [None]:
# Read the tab-separated metadata table into a pandas DataFrame;
# the first row of the file supplies the column names.
df = pd.read_csv(
    tsvfile,
    header=0,
    sep="\t",
)

In [None]:
df # Display the freshly loaded metadata DataFrame to get a quick overview of its rows and columns

In [None]:
# Filtering the data.
#
# The selected subset is always stored in `df_USA`, whatever `country` is.
# Alternative selections that can be swapped in for the country mask below:
#   country = "United Kingdom"
#   country = "World"     -> df_USA = df
#   country = "NonEurNA"  -> World except Europe and North America:
#       df_USA = df.loc[df['region'] != "Europe"]                 # Filter Europe out
#       df_USA = df_USA.loc[df_USA['region'] != "North America"]  # Filter NA out
#   df_USA = df.loc[df['region'] == "Europe"]
#   df_USA = df.loc[df['region'] == "Asia"]
country = "USA"

# One boolean mask per criterion; a row must satisfy all three.
from_country = df['country'] == country
human_host = df['host'] == "Homo sapiens"            # only human sequences
good_quality = df['QC_overall_status'] == "good"     # only overall quality "good"

df_USA = df.loc[from_country & human_host & good_quality]

print(f"Matched {len(df_USA)} sequences")

In [None]:
df_USA # Display the filtered DataFrame; it keeps the name df_USA whichever country/region selection was used.

In [None]:
# A simple way to measure Hamming distance only from the tsv data:
def dist_from_subs_dels_ins(subs0, subs1, dels0, dels1, ins0, ins1):
    """Return a Hamming-style distance between two sequences from tsv fields.

    Each argument is one metadata cell for sequence 0 or sequence 1:

    * ``subs*``: comma-separated substitution labels, compared as strings.
    * ``dels*``: comma-separated deletions, each a single position (``"510"``)
      or an inclusive range (``"510-512"``); ranges are expanded so that every
      deleted position counts separately.
    * ``ins*``: comma-separated insertions of the form ``"2820:GCT"``; every
      inserted nucleotide becomes one ``(position, nucleotide)`` pair, with
      the position increasing by one per nucleotide.

    A missing cell (``NaN``, whose ``str()`` is ``'nan'``), ``None`` and empty
    strings all mean "no mutations of this kind".

    The distance is the number of substitutions, deleted positions and
    inserted nucleotides present in one sequence but not in the other.

    Raises:
        ValueError: if a deletion or insertion position is not an integer.
        IndexError: if an insertion entry lacks the ``:`` separator.
    """
    return (
        _count_unshared(_split_mutation_field(subs0), _split_mutation_field(subs1))
        + _count_unshared(_expand_deletions(dels0), _expand_deletions(dels1))
        + _count_unshared(_expand_insertions(ins0), _expand_insertions(ins1))
    )


def _split_mutation_field(field):
    """Split a comma-separated cell into its entries, dropping missing ones."""
    if field is None:
        return []
    entries = (entry.strip() for entry in str(field).split(","))
    # str(float('nan')) == 'nan' is how an empty cell shows up here.
    return [entry for entry in entries if entry and entry != 'nan']


def _expand_deletions(field):
    """Turn a deletion cell into the list of individually deleted positions."""
    positions = []
    for entry in _split_mutation_field(field):
        if "-" in entry:  # It's a range, inclusive at both ends
            bounds = entry.split("-")
            first, last = int(bounds[0]), int(bounds[1])
            positions.extend(range(first, last + 1))
        else:  # It's a single deletion
            positions.append(int(entry))
    return positions


def _expand_insertions(field):
    """Turn an insertion cell into a list of (position, nucleotide) pairs."""
    pairs = []
    for entry in _split_mutation_field(field):
        parts = entry.split(":")
        start = int(parts[0])
        pairs.extend((start + offset, nuc) for offset, nuc in enumerate(parts[1]))
    return pairs


def _count_unshared(items0, items1):
    """Count the entries of each list that do not occur in the other list."""
    # Sets give constant-time lookups; iterating the lists (not the sets)
    # keeps the per-entry counting of the original nested list scans.
    lookup0 = set(items0)
    lookup1 = set(items1)
    only_in_0 = sum(1 for item in items0 if item not in lookup1)
    only_in_1 = sum(1 for item in items1 if item not in lookup0)
    return only_in_0 + only_in_1

In [None]:
def get_avg_dists(df_match, downsample=False, comparisons=None):
    """Sample pairwise distances between randomly chosen rows of one frame.

    Args:
        df_match: DataFrame with ``substitutions``, ``deletions`` and
            ``insertions`` columns; needs at least two rows.
        downsample: when ``comparisons`` is not given, draw 50 pairs instead
            of 5000.
        comparisons: optional explicit number of pairs to draw.

    Returns:
        A list with one ``dist_from_subs_dels_ins`` value per sampled pair of
        two different row positions.

    Raises:
        ValueError: if ``df_match`` has fewer than two rows, since no pair of
            distinct rows exists (a single row used to loop forever).
    """
    n_rows = len(df_match)
    if n_rows < 2:
        raise ValueError(
            f"get_avg_dists needs at least 2 rows to form a pair, got {n_rows}"
        )
    if comparisons is None:
        comparisons = 50 if downsample else 5000

    # Pull the three columns out once; indexing plain lists avoids building a
    # row Series six times per comparison.
    subs = df_match["substitutions"].tolist()
    dels = df_match["deletions"].tolist()
    ins = df_match["insertions"].tolist()

    dists = []
    for _ in range(comparisons):
        i = random.randint(0, n_rows - 1)
        j = random.randint(0, n_rows - 1)
        while i == j:  # redraw until the two rows differ
            j = random.randint(0, n_rows - 1)
        dists.append(
            dist_from_subs_dels_ins(subs[i], subs[j], dels[i], dels[j], ins[i], ins[j])
        )
    return dists

def get_avg_dists_two_sets(df1, df2, downsample=False, comparisons=None):
    """Sample distances between random rows of ``df1`` and random rows of ``df2``.

    Args:
        df1, df2: DataFrames with ``substitutions``, ``deletions`` and
            ``insertions`` columns; both must be non-empty.
        downsample: when ``comparisons`` is not given, draw 50 pairs instead
            of 5000.
        comparisons: optional explicit number of pairs to draw.

    Returns:
        A list with one ``dist_from_subs_dels_ins`` value per sampled pair
        (one row position from each frame).

    Raises:
        ValueError: if either frame is empty.
    """
    n_rows1 = len(df1)
    n_rows2 = len(df2)
    if n_rows1 == 0 or n_rows2 == 0:
        raise ValueError(
            f"get_avg_dists_two_sets needs non-empty frames, got {n_rows1} and {n_rows2} rows"
        )
    if comparisons is None:
        comparisons = 50 if downsample else 5000

    # Extract the columns once instead of materialising rows inside the loop.
    subs1, dels1, ins1 = (df1[col].tolist() for col in ("substitutions", "deletions", "insertions"))
    subs2, dels2, ins2 = (df2[col].tolist() for col in ("substitutions", "deletions", "insertions"))

    dists = []
    for _ in range(comparisons):
        i = random.randint(0, n_rows1 - 1)
        j = random.randint(0, n_rows2 - 1)
        dists.append(
            dist_from_subs_dels_ins(subs1[i], subs2[j], dels1[i], dels2[j], ins1[i], ins2[j])
        )
    return dists

In [None]:
# Parallelized Hamming distribution computations:

from multiprocessing import Pool
from matplotlib.ticker import FuncFormatter

def generate_heatmaps_parready(paramdict):
    """Compute the distance sample for one date window, filtering the frame here.

    Expected keys of ``paramdict``:
        ``date_beg_str`` / ``date_end_str``: window bounds; rows with
            ``date_beg_str <= date < date_end_str`` are used.
        ``df``: the full DataFrame to take the window from.
        ``i``: window index (printed, and used as the image file name).
        ``df_ref`` (optional): reference frame, only needed if the
            two-set comparison in ``_window_distances`` is switched on.
        ``savefigs`` / ``downsample`` (optional, default False).

    Returns:
        The list of sampled distances, or nine zeros if the window holds
        fewer than 10 rows.
    """
    date_beg_str = paramdict["date_beg_str"]
    date_end_str = paramdict["date_end_str"]
    i = paramdict["i"]
    df_USA = paramdict["df"]
    print("i:", i)

    in_window = (df_USA['date'] >= date_beg_str) & (df_USA['date'] < date_end_str)
    df_match = df_USA.loc[in_window]
    return _window_distances(df_match, date_beg_str, i, paramdict)


def generate_heatmaps_parready_preread(paramdict):
    """Compute the distance sample for one date window that was pre-filtered.

    Same as ``generate_heatmaps_parready`` except that the rows of the window
    arrive ready-made under the ``df_match`` key, so ``df`` is not needed and
    ``date_end_str`` is not used for filtering.
    """
    date_beg_str = paramdict["date_beg_str"]
    df_match = paramdict["df_match"]
    i = paramdict["i"]
    print("i:", i)
    return _window_distances(df_match, date_beg_str, i, paramdict)


def _window_distances(df_match, title, index, paramdict):
    """Sample distances within one window and optionally save its histogram."""
    savefigs = paramdict.get("savefigs", False)
    downsample = paramdict.get("downsample", False)

    if len(df_match) < 10:
        # Too few sequences for a meaningful sample: placeholder of nine zeros.
        return [0, 0, 0, 0, 0, 0, 0, 0, 0]

    dists = get_avg_dists(df_match, downsample=downsample)
    # Alternative: distances against a fixed reference window
    #dists = get_avg_dists_two_sets(df_match, paramdict["df_ref"], downsample=downsample)
    if savefigs:
        _save_hamming_histogram(dists, title, index)
    return dists


def _save_hamming_histogram(dists, title, index):
    """Write the histogram of ``dists`` to ``datadir + str(index) + ".png"``."""
    plt.ioff()
    # The plain 'seaborn' style name is absent from newer Matplotlib releases,
    # which ship it as 'seaborn-v0_8'; use whichever one is installed.
    mystyle = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
    plt.style.use(mystyle)
    plt.rc('font', family='sans', weight='normal', size=10)

    fig, ax = plt.subplots()
    ax.hist(dists, range=[0, 150], bins=151, density=True)
    ax.set_xlim([0, 150])
    ax.set_ylim([0, 0.1])
    ax.set_xlabel("Hamming distance")
    ax.set_ylabel("Frequency")
    ax.set_title(title)
    fig.savefig(datadir + str(index) + ".png", dpi=175, facecolor='white', transparent=False, bbox_inches='tight', pad_inches=0)
    # Close explicitly so figures do not pile up in a long-lived worker.
    plt.close(fig)

# Output directory for this country's frames and pickled results.
datadir = "/storage/user/GenBank/frames/"
datadir = datadir + country + "_dec01/"

timespan = 1005 # days

# Set a reference sequence (only used if utilizing get_avg_dists_two_sets() in generate_heatmaps_parready_preread() )
# Alpha:
#df_ref = df_USA.loc[df_USA['date'] >= "2021-03-15"]
#df_ref = df_ref.loc[df_ref['date'] < "2021-03-22"]
# Delta:
#df_ref = df_USA.loc[df_USA['date'] >= "2021-09-20"]
#df_ref = df_ref.loc[df_ref['date'] < "2021-09-27"]
# Omicron (BA1):
# NOTE(review): the active window is the one marked "UK"; confirm it is the
# intended one for the currently selected `country`.
df_ref = df_USA.loc[df_USA['date'] >= "2022-01-01"]  # UK
df_ref = df_ref.loc[df_ref['date'] < "2022-01-08"] # UK
#df_ref = df_USA.loc[df_USA['date'] >= "2022-01-17"]  # USA
#df_ref = df_ref.loc[df_ref['date'] < "2022-01-24"] # USA
# Omicron BA2:
#df_ref = df_USA.loc[df_USA['date'] >= "2022-04-09"]
#df_ref = df_ref.loc[df_ref['date'] < "2022-04-16"]

# Starting date:
date_beg = datetime.datetime(2020, 3, 1)

distributions = []
dateranges = []


# Generate one 7-day window [begin, end) per day of `timespan`,
# each starting one day after the previous one.
date_beg_strs = []
date_end_strs = []
for i in range(timespan):
    date_end = date_beg + datetime.timedelta(days=7)
    date_beg_strs.append(date_beg.strftime('%Y-%m-%d'))
    date_end_strs.append(date_end.strftime('%Y-%m-%d'))
    date_beg += datetime.timedelta(days=1)


# Set to True to reuse `paramdicts` left in the kernel by an earlier run.
already_prepared = False

if already_prepared:
    if 'paramdicts' not in globals():
        # Fail early with an explanation instead of at pool.map() below.
        raise NameError("already_prepared is True but 'paramdicts' is not defined; "
                        "set already_prepared = False to build it.")
    print("Data already initialized, proceeding directly to Hamming computations.")
    print("If this is not correct, abort computation immediately!")
    time.sleep(10)
    print("Commencing!")
else:
    # One parameter dictionary per window, carrying the pre-filtered rows.
    paramdicts = []
    for i in range(len(date_beg_strs)):
        print(f"Preparing data for day {i+1} out of {len(date_beg_strs)}.")
        in_window = (df_USA['date'] >= date_beg_strs[i]) & (df_USA['date'] < date_end_strs[i])
        # No deepcopy needed: boolean-mask selection already yields a new
        # frame, and pool.map() pickles every argument for the workers anyway.
        paramdicts.append({
            "date_beg_str": date_beg_strs[i],
            "date_end_str": date_end_strs[i],
            "df_ref": df_ref,
            "i": i,
            "df_match": df_USA.loc[in_window],
        })

if __name__ == '__main__':
    # How many concurrent processes? The context manager shuts the workers
    # down once map() has returned all results.
    with Pool(processes=35) as pool:
        outarr = pool.map(generate_heatmaps_parready_preread, paramdicts)


In [None]:
# Create the output directory (and any missing parents) without going
# through a shell command.
os.makedirs(datadir, exist_ok=True)

## Populate datadict on the basis of the parallel runs:
# dateranges is a list of (begin, end) 2-tuples, distributions the matching
# per-window results in the same order.

savedata = True

dateranges = list(zip(date_beg_strs, date_end_strs))
distributions = [outarr[i] for i in range(len(date_beg_strs))]

datadict = {
    "hammingdistributions": distributions,
    "dateranges": dateranges,
}

if savedata:
    # A random 9-digit ID keeps separate runs from overwriting each other.
    rand_ID = str(random.randint(100000000,999999999))
    filename = "data_" + rand_ID
    pklname = filename + ".pkl"
    # `with` closes the file even if pickling fails part-way.
    with open(datadir + pklname, "wb") as f:
        pickle.dump(datadict, f)
    print("Saved to:", datadir + pklname)

In [None]:
# Preparing timeseries data for plotting
# To work from a previous run instead, load its pickle first, e.g.:
#load_from = "/storage/user/GenBank/frames/United Kingdom_dec01/data_920481519.pkl"
#datadict = pickle.load( open(load_from, "rb" ) )

ddh = datadict["hammingdistributions"]
distributions = ddh # Synonym
ddd = datadict["dateranges"]

print("len(ddh) =", len(ddh))

# Per-window summary statistics.
means = [np.mean(window) for window in ddh]
variances = [np.var(window) for window in ddh]

# deciles[k] is the time series of the (k+1)*10th percentile, k = 0..8.
deciles = [
    [np.percentile(window, 10 * j) for window in ddh]
    for j in range(1, 10)
]

# Window indices and the begin/end date strings of every window.
i_s = list(range(len(ddh)))
begdates = [ddd[idx][0] for idx in i_s]
enddates = [ddd[idx][1] for idx in i_s]

print("len(begdates)=", len(begdates))

In [None]:
# Write one histogram image per window (the analysis must already be done):
# file <i>.png in `datadir`, titled with the window's begin date.
for i, dists in enumerate(ddh):
    date_beg_str, date_end = ddd[i][0], ddd[i][1]
    __, __, __ = plt.hist(
        dists,
        bins=161,
        range=[0, 160],
        density=True,
    )
    plt.xlim([0, 160])
    plt.ylim([0, 0.1])
    plt.xlabel("Hamming distance")
    plt.ylabel("Frequency")
    plt.title(date_beg_str)
    plt.savefig(
        f"{datadir}{i}.png",
        dpi=175,
        facecolor='white',
        transparent=False,
        bbox_inches='tight',
        pad_inches=0,
    )
    # Clear the current figure so the next window starts from an empty canvas.
    plt.clf()

In [None]:
# Index into `deciles`; 4 selects the 50% percentile series.
dec_idx = 4

# Begin dates of the windows as datetime.date objects.
x_values = []
for d in begdates:
    parsed = datetime.datetime.strptime(d, "%Y-%m-%d")
    x_values.append(parsed.date())

In [None]:
# Plot mean and median Hamming distances (as timeseries)

%matplotlib inline

from matplotlib.ticker import FuncFormatter

fontsize=10

font = {'family' : 'sans',
        'weight' : 'normal',
        'size'   : fontsize}

mystyle = 'seaborn'

plt.style.use(mystyle)
plt.rc('font', **font)


fig, ((ax1)) = plt.subplots(1, 1, dpi=275, figsize=(6,3.8))


cm = plt.cm.get_cmap('inferno')

dcol = 0.25
col0 = 0 # Color offset

# Set colours:
# Normal order
col_multipliers = [1, 2, 2.8]
# Reversed order:
col_reversed = True
if col_reversed:
    col_multipliers = list(reversed(col_multipliers))
col1 = cm(col0 + col_multipliers[0]*dcol)
col2 = cm(col0 + col_multipliers[1]*dcol)
col3 = cm(col0 + col_multipliers[2]*dcol)

ax = plt.gca()



plt.plot_date(begdates, means, fmt='', label="Mean", color=col1)
plt.plot_date(begdates, deciles[dec_idx], fmt='--', label=f"{(dec_idx+1)*10}% percentile", color=col2)

#plt.xlabel("Time (generations)")
plt.ylabel("Hamming distance")

ticklist = []
for i in range(len(begdates)):
    if begdates[i][-2:]=="01":
        if datetime.datetime.strptime(begdates[i],"%Y-%m-%d").month % 2==0:
            pass # Skip even months
        else:
            ticklist.append(begdates[i])
            print("Adding tick", ticklist[-1])
plt.xticks(ticklist, rotation=90)

print("First day:", begdates[0])
print("Last day:", begdates[-1])

plt.legend()

plt.xlim([-1,978])
print(begdates[978])
plt.ylim([0,145])

In [None]:
# Build the heatmap matrix: one row per window (time), one column per
# unit-wide Hamming-distance bin in [0, x_max).

ddh = distributions

x_max = 200
nx = x_max      # number of distance bins
ny = len(ddh)   # time
plotmat = np.zeros((ny, nx))
for row, window in enumerate(ddh):
    hist_n, hist_bins = np.histogram(window, bins=nx, range=[0, x_max], density=True)
    # Assigning into the row copies the values into plotmat.
    plotmat[row, :] = hist_n
    print(f"t={row+1} out of {ny}")

In [None]:
# Plot Heatmap

%matplotlib inline

from matplotlib.ticker import FuncFormatter

fontsize=10

font = {'family' : 'sans',
        'weight' : 'normal',
        'size'   : fontsize}


#mystyle = 'seaborn-notebook'
mystyle = 'seaborn'

# bmh is quite good somehow, but maybe too heavy
plt.style.use(mystyle)
plt.rc('font', **font)
#plt.style.use("ggplot")

fig, ((ax)) = plt.subplots(1, 1, dpi=175, figsize=(6,8))

#last_idx = 793
last_idx = 978

#showmat = np.log(0.02+plotmat[:936,:175]) 
showmat = np.log(0.02+plotmat[:last_idx,:175]) 

cax = ax.imshow(showmat, interpolation='nearest', aspect='auto', cmap=plt.get_cmap('inferno'))

cbar = fig.colorbar(cax, ticks=[np.min(showmat[:,:]), np.max(showmat[:,:])])
cbar.ax.set_yticklabels(['0', '1'])  # vertically oriented colorbar
cbar.set_label('Frequency (logarithmic) within generation', rotation=270)

ax = plt.gca()

ax.grid(False)

#country_name = "UK"
#country_name = "USA"
country_name = country
#country_name = "Europe"

startdate = "2020-03-01"

#plt.xlabel("Hamming distance (nucleotides)")
plt.xlabel("Hamming distance")
#plt.ylabel(f"Day (since {startdate})")
plt.title(f"Hamming heatmap ({country_name})")

#ax.set_yticks(np.arange(0,936))
#ax.set_yticklabels(begdates)

ticklist = []
ticklabellist = []
for i in range(len(begdates)):
    if begdates[i][-2:]=="01":
        if datetime.datetime.strptime(begdates[i],"%Y-%m-%d").month % 2==0:
            pass # Skip even months
        else:
            ticklist.append(i)
            ticklabellist.append(begdates[i])
            print("Adding tick", ticklist[-1])
        
__ = plt.yticks(ticklist, labels=ticklabellist)

print("Last plotted range is", begdates[last_idx], "to", enddates[last_idx])