# NGS mapping  
  
This jupyter notebook maps the raw FASTQ sequencing reads of all samples to the given reference fasta (H1N1pdm09_Cali09 full genome).

# Import libraries

In [None]:
import gzip
import multiprocessing as mp
import os
import platform
import re
import subprocess
from importlib.machinery import SourceFileLoader
from os.path import expanduser

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# folder holding the custom flu and NGS helper libraries
laeb_lib = expanduser("../python_lib")

# load them as modules 'fc' (flu_common.py) and 'ngs' (laeb_ngs_pipeline.py)
fc = SourceFileLoader('fc', "%s/flu_common.py" % laeb_lib).load_module()
ngs = SourceFileLoader('ngs', "%s/laeb_ngs_pipeline.py" % laeb_lib).load_module()

# Inputs

In [None]:
# inputs 
# file path to data folder - FASTQ files to be analysed must be in {data_folder}/raw
data_folder = './data' 
# reference FASTA file (path relative to this notebook)
ref_fasta_fname = './input/H1N1pdm09_Cali09.fasta' 
# CSV file containing the CDS (coding sequence) regions of each gene segment (numbering based on the given reference sequence)
cds_coords = "./input/CDS_H1N1pdm09_Cali09.csv"
# CSV file of CDS position shifts; this path is passed to ngs.initialisation in the initialisation cell,
# where the variable is replaced by the parsed result
nucpos_shift = "./input/CDS_shift_H1N1pdm09_Cali09.csv"
# file path to metadata file (must contain a 'sampid' column)
meta_fname = './input/metadata.csv' 

# mapping options
# NOTE(review): trimmomatic_fpath is not referenced elsewhere in the visible notebook (trimming is reimplemented in Python)
trimmomatic_fpath = expanduser('~/opt/anaconda3/pkgs/trimmomatic-0.39-1/share/trimmomatic-0.39-1/') # file path to trimmomatic
threadnum = 4 # number of CPU threads for parallelization 
base_qual_threshold = 20 # minimum accepted base quality (Phred)
max_indel_prop = 0.05 # max tolerable proportion of indels wrt read length 
max_indel_abs = 10 # max tolerable absolute number of indels 

# variant calling options
Query_HAnum_subtype = 'absH1pdm' # query HA numbering scheme; column of HA_numbering_conversion.csv used as its index
HAnum_subtype = 'H3' # reporting HA numbering subtype
subtype_ant = 'H1ant'  # HA canonical antigenic site of interest (NOTE(review): not referenced elsewhere in the visible notebook)
min_cov = 50 # minimum coverage 
min_var_freq = 0 # NOTE(review): not referenced elsewhere in the visible notebook; min_var_prop is the frequency threshold passed on
min_var_prop = 0.02 # minimum variant proportion 
err_tol = 0.01 # max tolerated probability that a called variant results from base-calling errors (p_Err threshold)
min_breadth = 0.7 # min breadth of gene segment to be mapped for further analysis (NOTE(review): not referenced elsewhere in the visible notebook)

# Parameters and functions 

This cell performs several initialisation steps, including: 
 - defining parameters needed by the pipeline (e.g. gene segment lengths) and obtaining the CDS (coding sequence) regions of each protein from the reference FASTA and the CDS coordinate file;
 - loading the dataframe used for HA numbering conversion.

In [None]:
# presets 
reffasta = ref_fasta_fname

# initialise
ngs = SourceFileLoader('ngs', "%s/laeb_ngs_pipeline.py"%(laeb_lib)).load_module()
# `nucpos_shift` starts out as the CDS-shift CSV path (inputs cell) and is overwritten below with what
# ngs.initialisation returns. Remember the path so that re-running this cell still passes the path
# rather than the previously parsed object. A None value is passed through unchanged.
if nucpos_shift is None or isinstance(nucpos_shift, (str, os.PathLike)):
    nucpos_shift_fpath = nucpos_shift
gene_to_proteinorf, influenza_gene_len, sorted_refnames, nucpos_shift = ngs.initialisation(
    cds_coords, reffasta, laeb_lib, nucpos_shift=nucpos_shift_fpath)
display(gene_to_proteinorf.head())

# HA numbering conversion table ('-' read as missing), indexed by the query numbering scheme column
ha_numbering_conversion = pd.read_csv(expanduser('%s/HA_numbering_conversion.csv'%(laeb_lib)),
                                      na_values='-')
ha_numbering_conversion = ha_numbering_conversion.set_index(Query_HAnum_subtype)
display(ha_numbering_conversion.head())

all_bases = ['A', 'T', 'G', 'C']

## Read metadata

Sample IDs are parsed from metadata file under the header "sampid".

In [None]:
# metadata: the 'sampid' column is the sample identifier used throughout the pipeline
meta_df = pd.read_csv('{}'.format(meta_fname))
sorted_sampid = sorted(set(meta_df['sampid']))

# parse dates, add CT rounded to the nearest integer as 'ct', then index by sample ID
meta_df = (
    meta_df
    .assign(date=lambda d: pd.to_datetime(d['date']),
            ct=lambda d: np.around(d['CT'], 0))
    .sort_values(by=['project', 'sampid'])
    .set_index('sampid')
)
display(meta_df.head())

In [None]:
# get path to raw FASTQ files: each sample is expected to have exactly one '<sampid>.gz' file in {data_folder}/raw
raw_fnames = sorted(os.listdir('{}/raw'.format(data_folder)))  # list the folder once, not once per sample

dat_df = []
for sampid in sorted_sampid:
    # literal match (re.escape keeps '.' and other metacharacters in the ID from acting as wildcards),
    # anchored at the end so e.g. '<sampid>.gz.bak' is not picked up
    fname_pattern = re.compile(re.escape('{}.gz'.format(sampid)) + '$')
    for fname in raw_fnames:
        if fname_pattern.search(fname):
            dat_df.append({'sampid': sampid, 'fpath': '{}/raw/{}'.format(data_folder, fname)})

dat_df = pd.DataFrame(dat_df, columns=['sampid', 'fpath'])

# a sample ID that is a suffix of another (e.g. 'S1' vs 'AS1') can still match the wrong file;
# fail loudly instead of producing duplicate index entries that break the downstream commands
ambiguous = dat_df[dat_df['sampid'].duplicated(keep=False) | dat_df['fpath'].duplicated(keep=False)]
if len(ambiguous) > 0:
    raise ValueError('Ambiguous raw FASTQ matches (a sample matches several files or a file matches '
                     'several samples):\n{}'.format(ambiguous))
missing_sampid = sorted(set(map(str, sorted_sampid)) - set(map(str, dat_df['sampid'])))
if missing_sampid:
    print('No raw FASTQ file found for: {}'.format(', '.join(missing_sampid)))

dat_df = dat_df.set_index(['sampid'])
dat_df = dat_df.sort_index()
dat_df.head()

# Perform ```FASTQC```

```FastQC``` is used to check the quality of the raw FASTQ files (i.e. that $\ge$90% of reads have an acceptable quality score), to determine the crop length for the ```trimmomatic```-style trimming below, and to ensure that only a negligible amount of adapter sequence is present. 

In [None]:
# create the FastQC output folder on the first run
if os.path.isdir('{}/fastqc'.format(data_folder)) is False:
    os.mkdir('{}/fastqc'.format(data_folder))

def analyse_fastqc(fname, qual_threshold=None):
    """Summarise one FastQC ``fastqc_data.txt`` report.

    Three FastQC modules are parsed into tables:
      - ">>Per sequence quality scores" (columns Quality, Count)
      - ">>Per base sequence quality"   (columns Base, ..., Median, ...)
      - ">>Adapter Content"             (column Position plus one column per adapter)

    Parameters
    ----------
    fname : str
        Path to a ``fastqc_data.txt`` file.
    qual_threshold : int or float, optional
        Minimum accepted Phred quality. Defaults to the notebook-level ``base_qual_threshold``.
        The default is looked up at call time, so re-running the inputs cell takes effect.

    Returns
    -------
    percent_abv_qualthres : float
        Fraction (0-1, despite the name) of reads whose per-sequence quality is >= ``qual_threshold``.
        NaN if the report lists no reads.
    ac : dict
        {adapter name: mean adapter-content value over all positions}. Only adapters whose mean
        is > 0 are included.
    start_pos, end_pos : int or None
        First and last read position whose median quality is >= ``qual_threshold``. Positions are
        parsed from the "Base" column, which may hold ranges such as "10-14". None when no
        position qualifies.

    Raises
    ------
    KeyError
        If either quality module is missing from the report. A missing "Adapter Content"
        module yields ``ac == {}``.
    """
    if qual_threshold is None:
        qual_threshold = base_qual_threshold

    # report section marker -> short key used below
    tracked_modules = {
        '>>Per base sequence quality': 'pbsq',
        '>>Adapter Content': 'ac',
        '>>Per sequence quality scores': 'psqs',
    }

    tables = {}
    module = None
    headers, columns = [], {}
    with open(fname, 'r') as fhandle:
        for line in fhandle:
            started = next((key for marker, key in tracked_modules.items() if marker in line), None)
            if started is not None:
                module = started
                headers, columns = [], {}
                continue
            if module is None:
                # outside a tracked module
                continue
            if '>>END_MODULE' in line:
                tables[module] = pd.DataFrame(columns)
                module = None
            elif line.startswith('#'):
                # column header line, e.g. "#Base\tMean\tMedian..."
                headers = line.strip().replace('#', '').split('\t')
                columns = {h: [] for h in headers}
            else:
                values = line.strip().split('\t')
                for h, v in zip(headers, values):
                    columns[h].append(v)

    for key, marker in (('psqs', '>>Per sequence quality scores'), ('pbsq', '>>Per base sequence quality')):
        if key not in tables:
            raise KeyError('FastQC module {!r} not found in {}'.format(marker, fname))

    # fraction of reads whose per-sequence quality is >= threshold
    psqs = tables['psqs']
    quality = pd.to_numeric(psqs['Quality'])
    counts = pd.to_numeric(psqs['Count'])
    total = counts.sum()
    percent_abv_qualthres = counts[quality >= qual_threshold].sum() / total if total > 0 else np.nan

    # first/last positions whose median quality is >= threshold
    pbsq = tables['pbsq']
    passing_bases = pbsq.loc[pd.to_numeric(pbsq['Median']) >= qual_threshold, 'Base']
    start_pos = end_pos = None
    if len(passing_bases) > 0:
        first = re.search(r'^(\d+)', passing_bases.iloc[0])   # start of the first passing range
        last = re.search(r'(\d+)$', passing_bases.iloc[-1])   # end of the last passing range
        start_pos = int(first.group(1)) if first else None
        end_pos = int(last.group(1)) if last else None

    # mean adapter-content value per adapter, keeping only adapters actually seen
    ac = {}
    if 'ac' in tables:
        adapter_table = tables['ac'].drop('Position', axis=1)
        for adapter in adapter_table.columns:
            values = pd.to_numeric(adapter_table[adapter])
            mean_val = values.sum() / len(values) if len(values) else 0.
            if mean_val > 0.:
                ac[adapter] = mean_val

    return percent_abv_qualthres, ac, start_pos, end_pos
    
# per adapter: the largest mean adapter-content value observed in any sample
adapters_to_rm = {}

for sampid in sorted(set(dat_df.index)):
    fastqc_prefix = '{}/fastqc/{}_fastqc'.format(data_folder, sampid)

    # run fastqc (only if this sample has no report yet)
    if not os.path.isfile(fastqc_prefix + '.html'):
        fpath = dat_df.loc[sampid, 'fpath']
        # decompress with `gzip -dc`, which behaves the same on Linux and macOS (no platform switch needed),
        # and pipe into FastQC without a shell so paths with spaces or shell metacharacters are safe
        unzip_proc = subprocess.Popen(['gzip', '-dc', fpath], stdout=subprocess.PIPE)
        subprocess.call(['fastqc', 'stdin:{}'.format(sampid), '--outdir={}/fastqc'.format(data_folder)],
                        stdin=unzip_proc.stdout)
        unzip_proc.stdout.close()
        unzip_proc.wait()

        ## unzip files; -o overwrites leftovers of an earlier run instead of prompting (a prompt would block the kernel)
        subprocess.call(['unzip', '-o', fastqc_prefix + '.zip', '-d', '{}/fastqc/'.format(data_folder)])

    # analyse results
    percent_abv_qualthres, ac, start_pos, end_pos = analyse_fastqc(fastqc_prefix + '/fastqc_data.txt')
    for adapter, val in ac.items():
        if val > adapters_to_rm.get(adapter, 0.):
            adapters_to_rm[adapter] = val

    dat_df.at[sampid, 'percent_abv_qualthres'] = percent_abv_qualthres
    dat_df.at[sampid, 'start_pos'] = start_pos
    dat_df.at[sampid, 'end_pos'] = end_pos

if len(adapters_to_rm) > 0: 
    print ('- Presence of adapter sequence (max. proportion of reads) -')
    for k, v in adapters_to_rm.items(): 
        print ("{}: {:.2f}%".format(k, v))

dat_df.head()

# Trimmomatic-style read trimming adapted for 454 sequencing data

In [None]:
# --- trimming parameters ---
trim_min_len = 30             # reads shorter than this (before or after trimming) are discarded
maxinfo_target_len = 80       # target length of the MAXINFO length-threshold score
maxinfo_strictness = 0.4      # MAXINFO strictness: weight of the error score vs. the coverage score
adapter_seed_mismatch = 2     # mismatches tolerated when matching an adapter against a read
# adapter sequences, iterated as (header, sequence) pairs by the trimming function
adapter_fdat = fc.parsefasta('./input/454-emPCR.fa')

# output folder for the trimmed reads
if os.path.isdir('{}/trimmed'.format(data_folder)) is False:
    os.mkdir('{}/trimmed'.format(data_folder))
    
def LTScore(l_array, target_len=None):
    """Length-threshold component of the MAXINFO score: 1 / (1 + exp(target_len - l)).

    Parameters
    ----------
    l_array : array-like of float
        Candidate read lengths.
    target_len : float, optional
        Length at which the score equals 0.5. Defaults to the notebook-level
        ``maxinfo_target_len``, looked up at call time.

    Returns
    -------
    numpy.ndarray
        Logistic score per length, increasing with length.
    """
    if target_len is None:
        target_len = maxinfo_target_len
    return 1 / (1 + np.exp(target_len - np.asarray(l_array, dtype=float)))

# maxinfo function of trimmomatic 
def maxinfo(length_array, quality_array,
            target_len=None, strictness=None):
    """Return the MAXINFO crop length for one read.

    For every candidate length l (``length_array``, normally 1..n) the score

        LT(l) * l**(1 - strictness) * P(l)**strictness

    is computed. LT(l) = 1 / (1 + exp(target_len - l)), and P(l) is the product of the per-base
    correctness probabilities 1 - 10**(-q/10) over the first l bases. The 1-based position of the
    best score is returned.

    Parameters
    ----------
    length_array : array-like
        Candidate lengths, one per base (the caller passes ``np.linspace(1, n, n)``).
    quality_array : array-like
        Phred qualities of the read, at least as long as ``length_array``.
    target_len, strictness : float, optional
        Default to the notebook-level ``maxinfo_target_len`` / ``maxinfo_strictness``.
        They are looked up at call time, so re-running the parameter cell takes effect.

    Returns
    -------
    int
        Number of leading bases to keep; 0 for an empty read.
    """
    if target_len is None:
        target_len = maxinfo_target_len
    if strictness is None:
        strictness = maxinfo_strictness

    length_array = np.asarray(length_array, dtype=float)
    quality_array = np.asarray(quality_array, dtype=float)
    if length_array.size == 0:
        return 0

    lt_score = 1 / (1 + np.exp(target_len - length_array))
    cov_score = length_array ** (1 - strictness)

    correctness = 1 - 10 ** (-quality_array / 10)
    # prefix products in O(n) instead of re-multiplying every prefix
    err_score = np.cumprod(correctness[:len(length_array)]) ** strictness

    return int(np.argmax(lt_score * cov_score * err_score)) + 1

def _find_adapter_cut(seq, adapters, max_mismatch):
    """Return the read offset at which the best adapter hit starts, or None if no adapter hit is found.

    Every offset of ``seq`` is compared base-by-base with every adapter sequence in ``adapters``
    ({header: sequence}). An offset is a hit when at least one base matches and the number of
    mismatches over the overlapping part is <= ``max_mismatch``. The hit with the most matching
    bases wins; ties go to the first hit (adapters in dict order, then increasing offset).

    NOTE(review): near the 3' end the overlap shrinks, so an overlap of up to max_mismatch + 1
    bases with a single matching base already counts as a hit. Confirm this permissiveness is
    intended.
    """
    best_idx, best_match = None, 0
    for adapter_seq in adapters.values():
        adapter_len = len(adapter_seq)
        for idx in range(len(seq)):
            segment = seq[idx:idx + adapter_len]
            match = sum(c1 == c2 for c1, c2 in zip(adapter_seq, segment))
            if match > 0 and len(segment) - match <= max_mismatch and match > best_match:
                best_idx, best_match = idx, match
    return best_idx


def _strip_low_quality_ends(qarray, min_qual=3):
    """Return (start, end) so that qarray[start:end] drops the runs of bases with quality < min_qual at both ends."""
    end = len(qarray)
    while end > 0 and qarray[end - 1] < min_qual:
        end -= 1
    start = 0
    while start < end and qarray[start] < min_qual:
        start += 1
    return start, end


def trim_reads_fn(sampid):
    """Trim the raw reads of one sample and write them to ``{data_folder}/trimmed/{sampid}_trimmed.fastq.gz``.

    Per read:
      1. discard reads shorter than ``trim_min_len``;
      2. cut at the best adapter hit from ``adapter_fdat`` (see ``_find_adapter_cut``);
      3. strip bases with Phred quality < 3 from both ends (LEADING/TRAILING);
      4. discard if now shorter than ``trim_min_len``;
      5. crop to the MAXINFO position, then to the sample's FastQC ``end_pos``
         (skipped when ``end_pos`` is missing);
      6. discard if now shorter than ``trim_min_len``, otherwise write the read (Phred+33).

    The per-sample counts and length statistics dict is appended to the notebook-level
    ``trimmed_stats`` list and also returned.

    Uses the notebook globals data_folder, dat_df, adapter_fdat, adapter_seed_mismatch,
    trim_min_len and trimmed_stats, and Biopython's ``SeqIO``.
    NOTE(review): ``SeqIO`` is not imported in the visible notebook; confirm that
    ``from Bio import SeqIO`` runs before this function is called.
    """
    total_reads = 0
    discarded_reads = 0
    original_reads_len = []
    remaining_reads_len = []

    # crop length derived from FastQC (last position with median quality >= threshold);
    # NaN/None means no position passed, so this crop step is skipped
    end_pos = dat_df.loc[sampid, 'end_pos']
    trim_crop = None if pd.isna(end_pos) else int(end_pos)

    out_fpath = '{}/trimmed/{}_trimmed.fastq.gz'.format(data_folder, sampid)
    # write the trimmed reads straight to a gzipped FASTQ (overwriting any earlier output);
    # both handles are closed even if parsing fails
    with gzip.open(dat_df.loc[sampid, 'fpath'], "rt") as fhandle, gzip.open(out_fpath, 'wt') as trimmed_output:
        for rec in SeqIO.parse(fhandle, 'fastq'):
            seq = str(rec.seq)
            qarray = np.array(rec.letter_annotations["phred_quality"])

            total_reads += 1
            original_reads_len.append(len(seq))

            # sequence falls below min_len
            if len(seq) < trim_min_len:
                discarded_reads += 1
                continue

            # cut adapter
            cut = _find_adapter_cut(seq, adapter_fdat, adapter_seed_mismatch)
            if cut is not None:
                seq, qarray = seq[:cut], qarray[:cut]

            # cut leading and trailing low-quality bases
            start, end = _strip_low_quality_ends(qarray)
            seq, qarray = seq[start:end], qarray[start:end]

            # sequence falls below min_len after adapter/end trimming
            if len(seq) < trim_min_len:
                discarded_reads += 1
                continue

            # cut to maxinfo crop, then to the FastQC crop length
            crop = maxinfo(np.linspace(1, len(seq), len(seq)), qarray)
            if trim_crop is not None:
                crop = min(crop, trim_crop)
            seq, qarray = seq[:crop], qarray[:crop]

            # sequence falls below min_len after all the trimming
            if len(seq) < trim_min_len:
                discarded_reads += 1
                continue

            # write to file (Phred+33 quality encoding)
            qual_line = ''.join(chr(q + 33) for q in qarray)
            trimmed_output.write('@{}\n{}\n+\n{}\n'.format(rec.id, seq, qual_line))
            remaining_reads_len.append(len(seq))

    # update trimmed_stats
    stats = {'sampid': sampid, 'total_reads': total_reads, 'discarded': discarded_reads,
             'initial_mean_len': np.mean(original_reads_len), 'initial_sd_len': np.std(original_reads_len, ddof=1),
             'trimmed_mean_len': np.mean(remaining_reads_len), 'trimmed_sd_len': np.std(remaining_reads_len, ddof=1)}
    trimmed_stats.append(stats)
    return stats

trim_stats_fpath = "{}/trim_stats.csv".format(data_folder)
if os.path.isfile(trim_stats_fpath):
    # reuse the statistics of a previous run (trimmed FASTQ files are assumed to exist already)
    trimmed_stats = pd.read_csv(trim_stats_fpath)
else:
    # trim_reads_fn appends to the notebook-level `trimmed_stats`, so it must be a process-shared list
    with mp.Manager() as manager:
        trimmed_stats = manager.list()

        # parallelise across all sampids
        with mp.Pool(processes=threadnum) as pool:
            results = [pool.apply_async(trim_reads_fn, args=[sampid]) for sampid in sorted_sampid]
            # .get() re-raises any exception from a worker instead of silently dropping that sample
            for res in results:
                res.get()

        # copy the shared list out before the manager shuts down
        trimmed_stats = pd.DataFrame(list(trimmed_stats))

    trimmed_stats['retained'] = trimmed_stats['total_reads']-trimmed_stats['discarded']
    trimmed_stats['retained_prop'] = trimmed_stats['retained']/trimmed_stats['total_reads']
    # workers finish in arbitrary order; sort for a deterministic table
    trimmed_stats = trimmed_stats[["sampid", "total_reads", "discarded", "retained", "retained_prop",
                                   "initial_mean_len", "initial_sd_len", "trimmed_mean_len",
                                   "trimmed_sd_len"]].sort_values('sampid')
    trimmed_stats.to_csv(trim_stats_fpath, index=False)

trimmed_stats = trimmed_stats.set_index('sampid')
trimmed_stats.head()

# Read mapping

We use ```bowtie2``` to align the trimmed, merged reads to the reference sequence. 

Flags used for ```bowtie2```: 
```
-x <refid> : Reference sequence to align by 
-X <int>   : Maximum fragment length for valid paired-end alignments, e.g. with -X 100 two 20-bp mate alignments separated by a 60-bp gap are valid, but not with a 61-bp gap (no effect on the single-end -U reads used here)
-k <int>   : Searches for at most <int> valid, distinct alignments for each read 
--local    : Does not require that the entire read align from one end to the other. Rather, some characters may be omitted ("soft clipped") from the ends in order to achieve the greatest possible alignment score. 
--very-sensitive-local : --local mode with the preset -D 20 -R 3 -N 0 -L 20 -i S,1,0.50 (the preset actually passed to bowtie2 below)
```

In [None]:
print ('Index reference sequence...')
# bowtie2 index basename: the reference file name without its directory and extension
ref_key = os.path.splitext(os.path.basename(reffasta))[0]
subprocess.run(['bowtie2-build', reffasta, ref_key], check=True)

# align sequences 
print ('Mapping reads with bowtie2...')
align_dir = "{}/align".format(data_folder)
if not os.path.isdir(align_dir):
    os.mkdir(align_dir)
# bowtie2 log (stdout + stderr of every run), kept in the data folder
bt_log_fpath = '{}/bt_aln.log'.format(data_folder)

for sampid in sorted_sampid:
    sam_fpath = '{}/{}.sam'.format(align_dir, sampid)
    if os.path.isfile(sam_fpath + '.gz'):
        # already aligned in a previous run
        continue

    # mapping with bowtie
    with open(bt_log_fpath, 'a') as output:
        output.write('{}\n'.format(sampid))
        # flush so the sample header precedes bowtie2's own output in the log
        output.flush()
        cmd = ['bowtie2',
               '-x', ref_key,
               # -X is the max fragment length for paired-end alignments; it does not affect the -U reads below
               '-X', str(max(influenza_gene_len.values())),
               '-k', '2',
               '--very-sensitive-local',
               '-p', str(threadnum),
               '-U', '{}/trimmed/{}_trimmed.fastq.gz'.format(data_folder, sampid),
               '-S', sam_fpath]
        returncode = subprocess.call(cmd, stderr=subprocess.STDOUT, stdout=output)
        output.write('\n')
    if returncode != 0:
        raise RuntimeError('bowtie2 failed for {} (exit code {}); see {}'.format(sampid, returncode, bt_log_fpath))

    # gzip sam file
    subprocess.run(['gzip', sam_fpath], check=True)

print ('...done.')

# Parse SAM files 

Quality filters:  
- exclude all unmapped and non-primary read alignments
- accept only bases with Q-score $\ge$ ```base_qual_threshold``` 

In [None]:
# parse the alignments into per-sample mapping statistics
# (single worker; plt_show=0 presumably suppresses plotting inside ngs.parse_sam)
mapping_stats = ngs.parse_sam(
    sorted_sampid, sorted_refnames, data_folder,
    base_qual_threshold, max_indel_abs, max_indel_prop,
    nucpos_shift=nucpos_shift,
    threadnum=1,
    plt_show=0,
)
display(mapping_stats.head())

# Tally base and codon counts

In [None]:
# re-load the pipeline module from disk (picks up any edits to laeb_ngs_pipeline.py), then tally bases/codons
ngs = SourceFileLoader('ngs', "%s/laeb_ngs_pipeline.py" % laeb_lib).load_module()
ngs.tally_bases(sorted_sampid, threadnum=threadnum, reanalyze_bool=1)

# Variant calling

In addition to the 2% minimum variant frequency (```min_var_prop```), as per Illingworth (Bioinformatics, 2016), we compute a statistical threshold for a variant to be called. Suppose $q$ is the minimum required base quality score; the per-base error rate is then $p_e = 10^{-q/10}$. If $n$ of the $N$ bases called at a site support the variant, the probability that this arose from sequencing errors alone is modelled as: 

$p_{Err} = \sum_{i=n}^{N}{\begin{pmatrix}
N \\
i 
\end{pmatrix}p_{e}^{i}(1-p_e)^{N-i}}$

A variant is only called if $p_{Err}<0.01$ (```err_tol```).

In [None]:
# call variants, passing the coverage (min_cov), frequency (min_var_prop) and error-probability (err_tol) thresholds
variant_call_df = ngs.variant_calling(
    sorted_sampid, sorted_refnames, base_qual_threshold,
    min_cov, min_var_prop, gene_to_proteinorf, err_tol,
    ha_numbering_conversion=ha_numbering_conversion,
    HAnum_subtype=HAnum_subtype,
    threadnum=threadnum,
)
display(variant_call_df.head())

## Plot nucleotide coverage plots for each subject

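Here, for each subject, we plot the mean coverage of every gene segment per sample, averaged over 50-bp sliding windows (step size 25 bp). We also compute the breadth of coverage for each gene segment, i.e. the fraction of the segment lying in steps whose window mean coverage is at least ```min_cov```.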


In [None]:
# standardise maximum y-value for plots: next power of 10 above the highest coverage of any sample
ymax = -1
for sampid in sorted_sampid:
    map_nuc_fpath = './results/map_nuc_results_{}.csv'.format(sampid)
    if not os.path.isfile(map_nuc_fpath):
        # sample without mapping results
        continue
    ymax = max(ymax, pd.read_csv(map_nuc_fpath)['Coverage'].max())
# fall back to 10 when no coverage was found, so that log10 is defined
ymax = 10**int(np.ceil(np.log10(ymax))) if ymax > 0 else 10

# average coverage is based on a sliding window of 50 bp with stepsize of 25 bp
sliding_window = 50
stepsize = 25
label_size = 12
color_scheme = ["#2f4f4f","#228b22","#7f0000","#000080","#ff8c00","#ffff00","#00ff00","#00ffff","#ff00ff","#1e90ff","#ffe4b5","#ff69b4"]

# gene segments in plotting order (alphabetical) and their lengths, used as panel widths
plot_refnames = sorted(influenza_gene_len.keys())
sorted_gene_len = np.array([influenza_gene_len[refname] for refname in plot_refnames])

# output folder for figures
os.makedirs('./results/figures', exist_ok=True)

# reindex meta_df based on subject_id and enrollment day
# (meta_df is modified in place: per-segment coverage breadth columns are added below)
meta_df = meta_df.reset_index().set_index(['subject_id', 'enrolD', 'sampid']).sort_index()

# window-averaged coverage of every sample, used for the overall distribution plot
overall_gene_coverage_distribution = []

for subject_id in sorted(set(meta_df.index.get_level_values(0))):
    print (subject_id)

    subject_meta_df = meta_df.loc[subject_id]

    # initialise coverage plot figure for each subject (one panel per gene segment)
    with plt.style.context("default"):
        fig = plt.figure(figsize=(11.7, 4.1))
        spec = gridspec.GridSpec(1, len(plot_refnames), figure=fig, wspace=0.2,
                                 width_ratios=sorted_gene_len/np.sum(sorted_gene_len))

        axes = []  # one subplot per segment, created with the first sample that has results

        for enrolD in sorted(set(subject_meta_df.index.get_level_values(0))):
            enrolD_subject_meta_df = subject_meta_df.loc[enrolD]

            for sampid in enrolD_subject_meta_df.index:
                sampid_enrolD_subject_meta_df = enrolD_subject_meta_df.loc[sampid]

                try:
                    timepoint_label = "D%i (%s)"%(int(sampid_enrolD_subject_meta_df.timepoint), sampid)
                except (AttributeError, TypeError, ValueError):
                    # no usable 'timepoint' value (column absent or NaN): label by enrolment day instead
                    timepoint_label = "T%i (%s)"%(enrolD, sampid)

                sample_type = sampid_enrolD_subject_meta_df['SampleType']

                # read map_nuc_results
                map_nuc_fpath = './results/map_nuc_results_{}.csv'.format(sampid)
                if not os.path.isfile(map_nuc_fpath):
                    print ('No mapped reads found for %s'%sampid)
                    continue
                # parse coverage/site quality results
                map_nuc_results = pd.read_csv(map_nuc_fpath, keep_default_na=False)

                for _r, refname in enumerate(plot_refnames):
                    rdf = map_nuc_results[map_nuc_results['Gene']==refname]

                    refseq_len = influenza_gene_len[refname]
                    gene_start_pos = 1
                    gene_end_pos = gene_start_pos+refseq_len

                    # window end positions: every `stepsize` bp, plus the end of the segment
                    x_values = np.arange(gene_start_pos, gene_end_pos, stepsize)
                    if x_values[-1] != gene_end_pos:
                        x_values = np.append(x_values, gene_end_pos)

                    # coverage by position; missing positions and non-numeric entries count as 0
                    coverage = pd.to_numeric(rdf.set_index('Position')['Coverage'], errors='coerce').fillna(0)

                    y_values = np.zeros(len(x_values)-1)
                    covered_len = 0  # bp of steps whose window mean coverage >= min_cov
                    for idx in range(1, len(x_values)):
                        x_val = x_values[idx]
                        # mean coverage over the `sliding_window` bp ending just before x_val
                        pos_range = range(int(max(0, x_val-sliding_window)), x_val)
                        mean_coverage = coverage.reindex(pos_range, fill_value=0).mean()
                        y_values[idx-1] = mean_coverage

                        # breadth of coverage: count the step [x_values[idx-1], x_val) as covered
                        if mean_coverage >= min_cov:
                            covered_len += x_val - x_values[idx-1]

                        # store computed mean coverage
                        overall_gene_coverage_distribution.append({'gene':refname, 'pos':x_val, 'coverage':mean_coverage})

                    # breadth of coverage = covered fraction of the segment
                    meta_df.loc[(subject_id, enrolD, sampid), refname] = covered_len/refseq_len

                    if len(axes) == _r:
                        # first sample with results: add this segment's subplot
                        ax = fig.add_subplot(spec[0,_r])
                        ax.set_title(refname, fontsize=label_size) # title

                        # plot min_cov line
                        ax.plot(x_values[1:],
                                np.zeros(len(x_values)-1)+min_cov,
                                color='k', linestyle='--')
                        axes.append(ax)
                    else:
                        ax = axes[_r]

                    # colours cycle when there are more enrolment days than colours
                    color = color_scheme[(enrolD-1) % len(color_scheme)]
                    # plot coverage; patch samples get a dashed line
                    if re.search('_P$', sampid):
                        label = '{}-{} (patch)'.format(timepoint_label, sample_type)
                        ax.plot(x_values[1:], y_values, '--', color=color, label=label)
                    else:
                        label = '{}{}'.format(timepoint_label, "" if pd.isna(sample_type) else "-%s"%(sample_type))
                        ax.plot(x_values[1:], y_values, color=color, label=label)

        if not axes:
            # none of this subject's samples had mapping results: nothing to plot
            print ('No coverage results to plot for %s'%subject_id)
            plt.close(fig)
            continue

        for _ax, ax in enumerate(axes):
            if _ax == 0:
                ax.set_ylabel('Coverage')
                ax.yaxis.label.set_fontsize(label_size)
            else:
                # remove y-axis label (sharey)
                ax.tick_params(labelleft=False)

            # gray facecolor for odd panels
            if _ax % 2 != 0:
                ax.set_facecolor(color='#d1d1d1')

            # remove left and right spines
            ax.spines['left'].set_visible(False)
            ax.spines['right'].set_visible(False)

            # x-ticks evenly spaced from 1 to the segment length: 4 if > 2000 bp, 3 if > 1000 bp, else 2
            refseq_len = influenza_gene_len[plot_refnames[_ax]]
            n_xticks = 4 if refseq_len > 2000 else (3 if refseq_len > 1000 else 2)
            xticks = np.linspace(1, refseq_len, n_xticks)
            ax.set_xticks(xticks)
            ax.set_xticklabels([int(x) for x in xticks])

            # set ylim and yscale
            ax.set_ylim((1, ymax))
            ax.set_yscale('symlog')

            # change tick size
            ax.tick_params(axis='both', which='major', labelsize=label_size*0.8)

            # change axis size
            ax.xaxis.label.set_fontsize(label_size)

        # x-axis label
        fig.text(0.5, 0.01, 'Position', ha='center', fontsize=label_size)
        plt.legend(loc='center left',  bbox_to_anchor=(1, 0.5))
        plt.savefig('./results/figures/coverage_plots_{}.pdf'.format(subject_id.replace("/", "-")),
                    bbox_inches='tight', pad_inches=0.)
        plt.show()

# convert overall_gene_coverage_distribution to dataframe (columns fixed so an empty result still works)
overall_gene_coverage_distribution = pd.DataFrame(overall_gene_coverage_distribution,
                                                  columns=['gene', 'pos', 'coverage'])
overall_gene_coverage_distribution = overall_gene_coverage_distribution.set_index(['gene', 'pos']).sort_index()

# Overall coverage across all patients

In [None]:
# use the same (alphabetical) segment order as the per-subject plots, so that each panel's data,
# width and x-ticks all refer to the same segment
overall_refnames = sorted(influenza_gene_len.keys())
overall_gene_len = np.array([influenza_gene_len[refname] for refname in overall_refnames])
genes_with_data = set(overall_gene_coverage_distribution.index.get_level_values('gene'))
os.makedirs('./results/figures', exist_ok=True)

with plt.style.context("default"):
    # initialise coverage plot figure
    fig = plt.figure(figsize=(11.7, 4.1))
    spec = gridspec.GridSpec(1, len(overall_refnames), figure=fig, wspace=0.2,
                             width_ratios=overall_gene_len/np.sum(overall_gene_len))

    axes = [] # list of subplots (by segments)

    for _r, refname in enumerate(overall_refnames):
        ax = fig.add_subplot(spec[0,_r])
        ax.set_title(refname, fontsize=label_size) # title

        if refname in genes_with_data:
            # distribution across samples of the window-averaged coverage at each window position
            by_pos = overall_gene_coverage_distribution.xs(refname, level='gene')['coverage'].groupby(level=0)
            median_cov = by_pos.median()
            X_array = median_cov.index.values

            # plot min_cov line
            ax.plot(X_array, [min_cov]*len(X_array), "--", color='#fcae91')

            # median, inter-quartile range and full range across samples
            ax.plot(X_array, median_cov.values, color='#000000')
            ax.fill_between(X_array, by_pos.quantile(0.25).values,
                            by_pos.quantile(0.75).values, facecolor='#ef3b2c', alpha=0.5)
            ax.fill_between(X_array, by_pos.min().values,
                            by_pos.max().values, facecolor='#fcbba1', alpha=0.2)

        axes.append(ax)

    for _ax, ax in enumerate(axes):
        if _ax == 0:
            ax.set_ylabel('Coverage')
            ax.yaxis.label.set_fontsize(label_size)
        else:
            # remove y-axis label (sharey)
            ax.tick_params(labelleft=False)

        # gray facecolor for odd panels
        if _ax % 2 != 0:
            ax.set_facecolor(color='#d1d1d1')

        # remove left and right spines
        ax.spines['left'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # x-ticks evenly spaced from 1 to the segment length: 4 if > 2000 bp, 3 if > 1000 bp, else 2
        refseq_len = influenza_gene_len[overall_refnames[_ax]]
        n_xticks = 4 if refseq_len > 2000 else (3 if refseq_len > 1000 else 2)
        xticks = np.linspace(1, refseq_len, n_xticks)
        ax.set_xticks(xticks)
        ax.set_xticklabels([int(x) for x in xticks])

        # set ylim and yscale
        ax.set_ylim((1, ymax))
        ax.set_yscale('symlog')

        # change tick size
        ax.tick_params(axis='both', which='major', labelsize=label_size*0.8)

        # change axis size
        ax.xaxis.label.set_fontsize(label_size)

    # x-axis label
    fig.text(0.5, 0.01, 'Position', ha='center', fontsize=label_size)
    plt.savefig('./results/figures/coverage_plots_overall.pdf',
                bbox_inches='tight', pad_inches=0.)
    plt.show()

# save meta_df with coverage breadth to results
meta_df.to_csv('./results/metadata_w_covbreadth.csv')