#  T-cell vaccine design
Design vaccine(s) to elicit a T-cell response by optimising coverage of potential T-cell epitopes (PTEs).

Here the term epitope `e` refers to a potential T-cell epitope (PTE), which is a short subsequence of `k` amino acids and is also represented as a node in the epitope graph `G`.

In [None]:
from Bio import Align, SeqIO
from Bio.Align import substitution_matrices
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from bin.predict_immunogenicity import Prediction # TODO: Add this to the repo
from bin.pca_protein_rank import pca_protein_rank, plot_pca
from itertools import product
from scipy import stats
from sklearn import preprocessing

import igviz as ig
import json
import matplotlib.pyplot as plt
import mhcflurry
import networkx as nx
import numpy as np
import os
import pandas as pd
import random
import seaborn as sns
import warnings

In [None]:
# Change: user-editable run configuration
fasta_path = '../data/test/1_nuc_cds_uniq.fa'
hla_path = '../data/test/hla_ref_set.class_i.txt'
k = 9
m = 1 # number of antigens
equalise_clades = True
aligned = False
decycle = True
n_clusters = 9 # will be automatically computed if False
plot = False
# Relative weight of each scoring criterion; a weight of 0 disables the criterion
weights = {
    # 'evo_rate': 0,
    'freq': 1,
    'immunogenicity': 1,
    'mhc_binding': 1
}

# Don't change: paths derived from fasta_path
# splitext only strips a real file extension, so an extension-less path (or dots
# in the directory part such as '../') can no longer truncate the base path
base_path = os.path.splitext(fasta_path)[0]
fasta_pr_path = base_path + '_protein.fas'
fasta_pr_designs_path = f'{base_path}_designs.fasta'
msa_pr_path = base_path + '_protein.msa'
msa_pr_designs_path = f'{base_path}_designs_pr_msa.fasta'
pre_msa_path = 'hyphy/pre-msa.bf'
# post_msa_path = 'hyphy/post-msa.bf'
# fasta_nt_path = fasta_path + '_nuc.fas'
# msa_nt_path = fasta_path + '_nuc.msa'
# tree_path = fasta_path + '.treefile'
# fel_path = msa_nt_path + '.FEL.json'
# slac_path = msa_nt_path + '.SLAC.json'

In [None]:
supported_alleles = !mhcflurry-predict --list-supported-alleles
supported_alleles = list(supported_alleles)

alleles = []
with open(hla_path, 'r') as f:
    for line in f:
        allele = line.strip().split(',')[0]
        if allele not in supported_alleles:
            print(f'Allele {allele} not supported by mhcflurry')
        elif allele not in alleles:
            alleles.append(allele)

In [None]:
##################
# Calculate scores
##################

def calc_score(kmers_dict, weights, kmer, equalise_clades=True):
    """
    Compute the weighted-average score of a single k-mer
    :param kmers_dict: Dictionary mapping each k-mer to a dictionary of its individual scores
    :param weights: Dictionary mapping each scoring criterion to its weight
    :param kmer: String of the k-mer to score
    :param equalise_clades: Boolean; if True the score is multiplied by the k-mer's 'clade_weight'
    :returns: Float of the overall score for the k-mer
    """
    weighted_sum = sum([w * kmers_dict[kmer][criterion] for criterion, w in weights.items()])
    weight_total = sum(list(weights.values()))
    score = weighted_sum / weight_total
    if equalise_clades:
        return score * kmers_dict[kmer]['clade_weight']
    return score


def calc_immune_scores(peptides, alleles, mhc_binding=True, immunogenicity=True, percentile_threshold=2):
    """
    Compute the immunogenicity and MHC binding scores for all peptides and alleles
    :param peptides: list of peptides
    :param alleles: list of alleles
    :param mhc_binding: boolean indicating whether to compute MHC binding scores
    :param immunogenicity: boolean indicating whether to compute immunogenicity scores
    :param percentile_threshold: percentile threshold for MHC binding
    :return: Dictionary with the scores
    """
    # Define vars
    ig_predictor = Prediction()
    mhc_predictor = predictor = mhcflurry.Class1PresentationPredictor.load()
    scaler = preprocessing.MinMaxScaler()
    dfs = []
    agg_dict = {}
    cols = []
    if immunogenicity:
        agg_dict['immunogenicity'] = 'mean'
        cols.append('immunogenicity')
    if mhc_binding:
        !mhcflurry-downloads fetch models_class1_presentation
        agg_dict['mhc_binding'] = 'sum'
        cols.append('mhc_binding')

    # Compute the scores for all peptides and alleles
    for allele in alleles:
        if mhc_binding:
            mhc_df = mhc_predictor.predict(peptides, [allele], include_affinity_percentile=True, verbose=0)
            # Determine if the peptide binds to the MHC (based on the percentile threshold)
            mhc_df['mhc_binding'] = mhc_df['affinity_percentile'].apply(lambda x: 1 if x < percentile_threshold else 0)
            mhc_df = mhc_df.rename(columns={'best_allele': 'allele'})
            df = mhc_df
        if immunogenicity:
            ig = ig_predictor.main(peptides=peptides, allele=allele)
            ig_df = pd.DataFrame(ig[1:], columns=ig[0])
            ig_df = ig_df.rename(columns={'score': 'immunogenicity'})
            df = ig_df
        if mhc_binding and immunogenicity:
            df = mhc_df.merge(ig_df, on='peptide')
        dfs.append(df)

    # Concatenate the dataframes
    df = pd.concat(dfs, ignore_index=True)
    # Set the immunogenicity to the minimum value if the peptide is not predicted to bind to the MHC
    if mhc_binding and immunogenicity:
        min_immunogenicity = df['immunogenicity'].min()
        df['immunogenicity'] = df.apply(lambda x: min_immunogenicity if x['mhc_binding'] == 0 else x['immunogenicity'], axis=1)
    # Aggregate the scores for each peptide across the alleles
    df = df.groupby('peptide').agg(agg_dict).reset_index()
    # Convert MHC binding to the fraction of bound alleles
    if mhc_binding:
        df['mhc_binding'] = df['mhc_binding'] / len(alleles)
    # Scale the scores    
    df[cols] = scaler.fit_transform(df[cols])

    return df[cols + ['peptide']].set_index('peptide').to_dict()


# def get_lrt(fel_path):
#     """
#     Returns the likelihood-ratio test (LRT) value for each codon, signed (+/-) to indicate whether the codon is under positive or negative selection
#     :param fel_path: String path to HyPhy FEL JSON output
#     :returns: List of LRT values for each codon with sign to indicate selection type
#     """
#     fel = json.load(open(fel_path))
#     headers = [header[0] for header in fel['MLE']['headers']] # fel['MLE']['headers']
#     content = fel['MLE']['content']['0']
#     df = pd.DataFrame(content, columns=headers)
#     df['omega'] = df['beta'] / df['alpha']
#     df['LRT_selection'] = preprocessing.normalize([np.array(df['LRT'])])[0]
#     df['LRT_selection'] = df.apply(lambda x: x['LRT_selection'] if x['alpha'] < x['beta'] else x['LRT_selection']*-1, axis=1)
#     return df['LRT_selection'].tolist()


def get_evo_rate(slac_path):
    """
    Returns the evolution rate score for each site: the 'S' column minus the 'N' column
    of the HyPhy SLAC by-site 'RESOLVED' table, min-max scaled to [0, 1]
    :param slac_path: String path to HyPhy SLAC JSON output
    :returns: List of evolution rate scores, one per codon
    """
    # Close the file deterministically instead of leaking the handle
    with open(slac_path) as slac_file:
        slac = json.load(slac_file)
    headers = [header[0] for header in slac['MLE']['headers']]
    content = slac['MLE']['content']['0']['by-site']['RESOLVED'] #['AVERAGED']
    sites = pd.DataFrame(content, columns=headers)
    # This is a bit counter-intuitive but the number of non-synonymous sites (N) counts negatively even though they're under positive selection
    # This is because we want to select for sites that are under negative (aka purifying) selection
    raw_rate = (sites['S'] - sites['N']).to_numpy().reshape(-1, 1)
    scaler = preprocessing.MinMaxScaler()
    return scaler.fit_transform(raw_rate).ravel().tolist()


def get_clade_freq_dict(seqs_clusters_dict):
    """
    Compute the fraction of sequences that belong to each clade
    :param seqs_clusters_dict: dictionary mapping sequence IDs to a (sequence, clade) pair
    :return: dictionary mapping each clade to (sequences in clade / total sequences)
    """
    # Tally how many sequences fall into each clade (clade is the second item of each value)
    counts = {}
    for entry in seqs_clusters_dict.values():
        clade = entry[1]
        counts[clade] = counts.get(clade, 0) + 1
    # Normalise the tallies by the total number of sequences
    n_seqs = len(seqs_clusters_dict)
    return {clade: count / n_seqs for clade, count in counts.items()}


############################
# Utils to process sequences
############################

def seq_to_kmers(seq, k):
    """
    Slide a window of length k over a sequence and return every window, in order
    :param seq: Sliceable sequence, e.g. an amino acid string or a list
    :param k: Integer window length
    :returns: List of the overlapping length-k slices (empty if seq is shorter than k)
    """
    kmers = []
    for start in range(len(seq) - k + 1):
        kmers.append(seq[start:start + k])
    return kmers


# For aligned sequences
def seqs_dict_to_df(seqs_dict, evo_rate):
    """
    Build a dataframe with one row per sequence (one residue per column), preceded by an 'Evo_Rate' row
    :param seqs_dict: Dictionary where the keys are sequence IDs and the values are aligned amino acid sequence strings
    :param evo_rate: List containing the evolutionary rate for each codon/site in the MSA
    :returns: Dataframe indexed by 'Evo_Rate' followed by the sequence IDs
    """
    # First row holds the per-site rates; every following row is one sequence split into residues
    rows = [evo_rate] + [list(residues) for residues in seqs_dict.values()]
    row_labels = ['Evo_Rate', *seqs_dict.keys()]
    table = pd.DataFrame(rows)
    table.index = row_labels
    return table


def seqs_df_to_kmers_dict(df, weights, k=9):
    """
    Returns a dictionary of all possible k-mers and their evolutionary rate, position, frequency and overall score
    :param df: Dataframe of aligned amino acid sequences (one residue per column, integer column labels) with an 'Evo_Rate' row holding the per-site rate
    :param weights: Dictionary containing keys for each scoring criteria and values corresponding to their weight.
                    Entries built here only carry 'evo_rate' and 'freq', so any other weight key makes calc_score raise KeyError
    :param k: Integer for substring length
    :returns: Dictionary containing all possible k-mers with their evolutionary rate, position, frequency and overall score
    """
    kmers_dict = {}
    # Number of sequences = every row except the 'Evo_Rate' row
    n_seqs = sum(1 for index in df.index if index != 'Evo_Rate')
    for i in range(len(df.columns) - k + 1):
        # Label-based slice: columns i..i+k-1 inclusive
        kmers = df.loc[:, i:i+k-1]
        evo_rate = kmers.loc['Evo_Rate'].tolist()
        evo_rate = sum(evo_rate) / len(evo_rate)
        pos = i+1
        for index, row in kmers.iterrows():
            if index != 'Evo_Rate':
                kmer = ''.join(row.tolist())
                if kmer not in kmers_dict:
                    kmers_dict[kmer] = {'evo_rate':evo_rate, 'pos':pos, 'freq':1}
                else:
                    if pos == kmers_dict[kmer]['pos']:
                        kmers_dict[kmer]['freq'] += 1
                    else:
                        warnings.warn(f'k-mer {kmer} found at multiple positions {pos},{kmers_dict[kmer]["pos"]}. This was not expected to happen and may cause issues.')
    for kmer in kmers_dict:
        # Convert sequence counts to frequency (fraction of sequences, not of alignment columns)
        kmers_dict[kmer]['freq'] = kmers_dict[kmer]['freq'] / n_seqs
        # Calculate score using weights; no clade information exists for these entries,
        # so clade equalisation must be off (it would look up a missing 'clade_weight')
        kmers_dict[kmer]['score'] = calc_score(kmers_dict, weights, kmer, equalise_clades=False)
    return kmers_dict


# For unaligned sequences
def seqs_to_kmers_dict(seqs_clusters_dict, k=9, weights=weights, alleles=None, aligned=False, equalise_clades=True):
    """
    Returns a dictionary of all possible k-mers and their scores for a given set of sequences and value of k
    :param seqs_clusters_dict: Dictionary where the keys are sequence IDs and the values are (sequence, clade) pairs
    :param k: Integer for substring length
    :param weights: Dictionary of scoring criteria and their weights; a missing or zero
                    'immunogenicity' / 'mhc_binding' weight disables that prediction
    :param alleles: List of MHC alleles to use for MHC binding / immunogenicity prediction
    :param aligned: Boolean indicating whether the sequences are aligned or unaligned (currently unused here)
    :param equalise_clades: Boolean indicating whether to multiply scores by the clade weight
    :returns: Dictionary mapping each k-mer to its clades, clade weight, frequency, immune scores and min-max scaled overall score
    :raises ValueError: if an immune score is requested but no alleles are given
    """
    # Get a list of k-mers from the sequences, for each k-mer record the occurrence count and clades
    kmers_dict = {}
    N = len(seqs_clusters_dict)
    for entry in seqs_clusters_dict.values():
        seq, clade = entry[0], entry[1]
        for i in range(len(seq) - k + 1):
            e = seq[i:i+k]
            if e not in kmers_dict:
                kmers_dict[e] = {'clades': [clade], 'clade_weight': 0, 'evo_rate': 0, 'freq': 1, 'immunogenicity':0, 'mhc_binding': 0, 'pos': 0, 'score': 0}
            else:
                kmers_dict[e]['freq'] += 1
                if clade not in kmers_dict[e]['clades']:
                    kmers_dict[e]['clades'].append(clade)

    clade_freq_dict = get_clade_freq_dict(seqs_clusters_dict)
    n_clades = len(clade_freq_dict)
    # .get() so that leaving a criterion out of `weights` simply disables it
    immunogenicity = bool(weights.get('immunogenicity', 0))
    mhc_binding = bool(weights.get('mhc_binding', 0))
    empty_dict = {e:0 for e in kmers_dict}
    peptides = list(kmers_dict.keys())

    if immunogenicity or mhc_binding:
        if not alleles:
            raise ValueError('alleles must be provided when the immunogenicity or mhc_binding weight is non-zero')
        immune_dict = calc_immune_scores(peptides, alleles, mhc_binding, immunogenicity)
    else:
        immune_dict = {}
    # Any score that was not computed defaults to 0 for every k-mer
    immune_dict.setdefault('mhc_binding', empty_dict)
    immune_dict.setdefault('immunogenicity', empty_dict)

    for e, d in kmers_dict.items():
        # k-mers from over-represented clades are down-weighted, under-represented ones up-weighted
        avg_clade_freq = np.mean([clade_freq_dict[clade] for clade in d['clades']])
        d['clade_weight'] = (1 / n_clades) / avg_clade_freq
        d['freq'] = d['freq']/N
        d['immunogenicity'] = immune_dict['immunogenicity'][e]
        d['mhc_binding'] = immune_dict['mhc_binding'][e]
        d['score'] = calc_score(kmers_dict, weights, e, equalise_clades)

    # MinMax scale the scores
    kmers = list(kmers_dict)
    raw_scores = np.array([kmers_dict[e]['score'] for e in kmers], dtype=float).reshape(-1, 1)
    scaler = preprocessing.MinMaxScaler()
    scaled_scores = scaler.fit_transform(raw_scores).ravel().tolist()
    for kmer, score in zip(kmers, scaled_scores):
        kmers_dict[kmer]['score'] = score

    return kmers_dict


def path_to_seq(path):
    """
    Collapse a path of overlapping epitopes into a single AA string
    :param path: List of epitope strings where consecutive epitopes overlap by all but one character
    :returns: String made of the first epitope followed by the last character of every later epitope
    """
    residues = [path[0]]
    for epitope in path[1:]:
        residues.append(epitope[-1])
    return ''.join(residues)


def align_seqs(seq1, seq2, matrix='BLOSUM62'):
    """
    Pairwise-align two sequences with Biopython's PairwiseAligner.
    :param seq1: First sequence AA string
    :param seq2: Second sequence AA string
    :param matrix: Name of the substitution matrix to load (default: BLOSUM62)
    :return: Whatever PairwiseAligner.align returns for the two sequences
    """
    pairwise_aligner = Align.PairwiseAligner()
    scoring_matrix = substitution_matrices.load(matrix)
    pairwise_aligner.substitution_matrix = scoring_matrix
    alignments = pairwise_aligner.align(seq1, seq2)
    return alignments


############################################
# Utils to get the min/max index from a list
############################################

def argmax(lst):
    """
    Find where the largest value of a list sits
    :param lst: Non-empty list of comparable values
    :returns: Integer index of the first occurrence of the maximum value
    """
    positions = range(len(lst))
    return max(positions, key=lst.__getitem__)


def argmin(lst):
    """
    Find where the smallest value of a list sits
    :param lst: Non-empty list of comparable values
    :returns: Integer index of the first occurrence of the minimum value
    """
    positions = range(len(lst))
    return min(positions, key=lst.__getitem__)


#####################################
# Utils to retrieve info from a graph
#####################################

def P(G, e):
    """
    Collect the predecessors of node e, i.e. the nodes with an edge pointing to e
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: List of predecessors
    """
    return [*G.predecessors(e)]


def S(G, e):
    """
    Collect the successors of node e, i.e. the nodes that e has an edge pointing to
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: List of successors
    """
    return [*G.successors(e)]


def f(G, e, f='Score'):
    """
    Look up a node attribute (feature) of epitope e, e.g. its frequency or score
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :param f: String for the name of the node feature (default = 'Score')
    :returns: The stored value of the feature for the epitope, e.g. a float score
    """
    node_features = G.nodes[e]
    return node_features[f]


############################################
# Decycling - remove all cycles from a graph
############################################

def decycle_graph(G):
    """
    Remove edges from G (in place) until it contains no cycles
    :param G: Directed Graph containing epitopes
    :returns: The same Directed Graph, now acyclic
    """
    # A k-mer that overlaps itself (e.g. a homopolymer) forms a self-loop. This is a cycle
    # inside a single-node component, so the component-based loop below would never see it.
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    # Iterate rather than recurse: one recursion level per removed edge can exceed
    # Python's recursion limit on highly cyclic graphs
    while True:
        # Each component is a set of nodes; only components with 2+ nodes can still hold a cycle
        components = [j for j in nx.strongly_connected_components(G) if len(j) != 1]
        if not components:
            return G
        # Components are disjoint, so one edge can be cut in each before recomputing them
        for j in components:
            # Randomly choose two nodes from the selected component
            ea, eb = random.sample(list(j), k=2)
            cycle = cycle_from_two_nodes(G, ea, eb)
            if cycle:
                # Cut the cycle at its least valuable edge
                ea, eb = weak_edge_in_cycle(G, cycle)
                G.remove_edge(ea, eb)


def cycle_from_two_nodes(G, ea, eb):
    """
    Build a cycle through two nodes by joining the shortest path ea -> eb with the shortest path eb -> ea
    :param G: Directed Graph containing epitopes
    :param ea: String for the first given potential T-cell epitope (PTE)
    :param eb: String for the second given potential T-cell epitope (PTE)
    :returns: List of epitope strings starting and ending with ea, or an empty list if either path does not exist
    """
    try:
        outbound = nx.shortest_path(G, source=ea, target=eb)
        inbound = nx.shortest_path(G, source=eb, target=ea)
    except nx.NetworkXNoPath:
        return []
    # Drop the duplicated eb where the two legs meet
    return outbound[:-1] + inbound


def weak_edge_in_cycle(G, cycle):
    """
    Pick the edge of a cycle whose removal costs the least according to a score heuristic
    :param G: Directed Graph containing epitopes
    :param cycle: List of epitope strings on path that is a cycle
    :returns: Tuple (ea, eb) for the lowest-value edge; the first one wins on ties
    """
    def edge_value(edge):
        src, dst = edge
        # Heuristic value of the edge: the scores of both of its endpoints
        value = f(G, src) + f(G, dst)
        # Cutting the only outgoing edge of src would isolate it, so count it again
        if len(S(G, src)) == 1:
            value = value + f(G, src)
        # Cutting the only incoming edge of dst would isolate it, so count it again
        if len(P(G, dst)) == 1:
            value = value + f(G, dst)
        return value

    # Consecutive node pairs along the cycle are its edges
    edges = list(zip(cycle, cycle[1:]))
    ea, eb = min(edges, key=edge_value)
    return ea, eb


###############################################
# Find optimal path through a graph of epitopes
###############################################

def find_optimal_path(G):
    """
    Find the highest-scoring path through a graph of epitopes
    :param G: Directed Graph containing epitopes
    :returns: List of epitope strings on the optimal path
    """
    # Forward pass: make sure F(e) is computed and stored for every node
    for node in G.nodes:
        F(G, node)
    # Backward pass: trace back the path that achieves the maximal score
    return backward(G)


def F(G, e):
    """
    Returns the maximum total score over all paths that end in e.
    The value is cached on each node as the 'F(e)' attribute and reused on later calls.
    :param G: Directed Graph containing epitopes (must be acyclic)
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: Float for the maximum total epitope score
    :raises ValueError: if a cycle is found among the ancestors of e
    """
    key = 'F(e)'
    # Use precomputed F(e) if it already exists for the epitope
    if key not in G.nodes[e]:
        # Explicit stack instead of recursion: a long chain of predecessors would
        # otherwise exceed Python's recursion limit
        stack = [e]
        expanded = set()  # nodes waiting for their predecessors to be solved
        while stack:
            node = stack[-1]
            if key in G.nodes[node]:
                # Already solved (possibly reached via another branch)
                stack.pop()
                continue
            predecessors = P(G, node)
            unsolved = [pe for pe in predecessors if key not in G.nodes[pe]]
            if unsolved:
                expanded.add(node)
                # A predecessor that is itself still waiting means node is its own ancestor
                if any(pe in expanded for pe in unsolved):
                    raise ValueError(f'Cannot compute F(e): the graph contains a cycle through {node!r}; decycle it first')
                stack.extend(unsolved)
                continue
            if not predecessors:
                # If the set of predecessors P(e) is empty, then F(e) = f(e)
                Fe = f(G, node)
            else:
                # If the set of predecessors P(e) is not empty, then F(e) = f(e) + max(F(P(e)))
                Fe = f(G, node) + max([G.nodes[pe][key] for pe in predecessors])
            # Save F(e) to the graph for this epitope
            G.nodes[node][key] = Fe
            stack.pop()
    return f(G, e, f=key)


def backward(G, path=None):
    """
    Returns the path that achieves the maximal score by walking back from the best end node,
    always stepping to the predecessor with the highest precomputed F(e)
    :param G: Directed Graph containing epitopes with 'F(e)' set on every node and 'BEGIN'/'END' sentinel nodes
    :param path: Optional list of epitope strings to complete; it is extended in place at the front (default=None)
    :returns: List of epitope strings on path that achieve maximum score
    :raises ValueError: if the walk does not terminate, i.e. the graph contains a cycle
    """
    # Get the precomputed F(e) from the graph for all epitopes
    Fe_dict = nx.get_node_attributes(G, 'F(e)')
    if not path:
        # Get the epitope with the maximum F(e) as the final epitope in our optimal path
        end_nodes = {e: Fe_dict[e] for e in P(G, 'END')}
        path = [max(end_nodes, key=end_nodes.get)]
    # Walk iteratively (recursing once per node overflows the stack on long paths) and
    # collect nodes end-to-start; they are spliced onto the front of the path once at the end
    walked = []
    max_steps = G.number_of_nodes()
    e = path[0]
    while True:
        predecessors = P(G, e)
        # Stop at the start of the graph: a node attached to 'BEGIN' or with no predecessors
        if not predecessors or predecessors[0] == 'BEGIN':
            break
        if len(walked) >= max_steps:
            raise ValueError('Backward walk did not terminate: the graph contains a cycle')
        # Step to the best (highest F(e)) predecessor; the first one wins on ties
        e = max(predecessors, key=Fe_dict.__getitem__)
        walked.append(e)
    path[:0] = reversed(walked)
    return path


###########################################################
# Cocktail: Find (and iteratively refine) a set of antigens
###########################################################

def cocktail(G, m, refine=True, score='Score'):
    """
    Returns a list of m antigens, found greedily: each antigen is the optimal path after
    the scores of epitopes already covered by earlier antigens are set to 0
    :param G: Directed Graph containing epitopes
    :param m: Integer for number of antigens
    :param refine: Boolean for if the antigens should be iteratively refined
    :param score: String name of the node attribute that is zeroed and later restored (default='Score')
    :returns: List containing m antigens (each a list of epitope strings)
    """
    Q = [] # vaccine
    # Save original epitope score so it can be reset later
    score_dict = nx.get_node_attributes(G, score)
    try:
        for _ in range(m):
            # Compute and save next antigen sequence
            q = find_optimal_path(G)
            # Add q to vaccine
            Q.append(q)
            # No credit for including e in subsequent antigens
            nx.set_node_attributes(G, dict.fromkeys(q, 0), score)
            # Remove F(e) so it's recomputed using the updated scores
            for _node, d in G.nodes(data=True):
                d.pop('F(e)', None)
    finally:
        # Reset to the original scores, even if path finding failed part-way
        nx.set_node_attributes(G, score_dict, score)
    # Optional: Repeat - iterative refinement (nothing to refine for an empty vaccine)
    if refine and Q:
        Q = iterative_refinement(G, Q, score=score)
    return Q


def iterative_refinement(G, Q, score='Score', max_iter=100):
    """
    Returns a list of iteratively refined antigens. Each antigen in turn is recomputed as the
    optimal path given that epitopes covered by the other antigens earn no credit; passes
    repeat until a full pass changes nothing.
    :param G: Directed Graph containing epitopes
    :param Q: List containing antigens (updated in place)
    :param score: String name of the node attribute that is temporarily zeroed (default='Score')
    :param max_iter: Integer cap on the number of passes, guarding against oscillation (default=100)
    :returns: List containing iteratively refined antigens
    """
    # Save original epitope score so it can be reset later
    score_dict = nx.get_node_attributes(G, score)

    def _clear_F():
        # Remove F(e) so it's recomputed using the current scores
        for _node, d in G.nodes(data=True):
            d.pop('F(e)', None)

    _clear_F()
    for _pass in range(max_iter):
        changed = False
        for n in range(len(Q)):
            # No credit for including e in the other antigens (antigen n is left out by
            # position, so duplicate antigens are handled correctly)
            covered = {e for i, other in enumerate(Q) if i != n for e in other}
            nx.set_node_attributes(G, dict.fromkeys(covered, 0), score)
            try:
                # Compute replacement for old sequence Q[n]
                q = find_optimal_path(G)
            finally:
                # Reset to the original scores
                nx.set_node_attributes(G, score_dict, score)
                _clear_F()
            if q != Q[n]:
                Q[n] = q
                changed = True
        # Converged: a whole pass left every antigen unchanged
        if not changed:
            break
    return Q


#######################################################
# Design: Wrapper function to design a vaccine cocktail
#######################################################

def design_vaccines(fasta_pr_path, msa_pr_path, n_clusters, k=9, m=1, weights=None, alleles=None, aligned=False, equalise_clades=True, decycle=False, plot=False):
    """
    Design a vaccine cocktail from a FASTA file of protein sequences.
    :param fasta_pr_path: Path to the FASTA file of protein sequences.
    :param msa_pr_path: Path to the MSA file of protein sequences.
    :param n_clusters: Number of clusters to use for clustering the sequences.
    :param k: Length of the k-mers to use.
    :param m: Number of epitopes to include in the vaccine cocktail.
    :param weights: Dictionary of weights to use for the scoring functions.
    :param alleles: List of alleles to use for the scoring functions.
    :param aligned: Boolean indicating whether the sequences are already aligned.
    :param equalise_clades: Boolean indicating whether to equalise the clades of the sequences.
    :param decycle: Boolean indicating whether to decycle the graph.
    :param plot: Boolean indicating whether to generate all plots.
    :return: Sequences dictionary and list of the optimal vaccine cocktails.
    """
    # Load the FASTA seqs into a dictionary
    fasta_seqs = SeqIO.parse(fasta_pr_path,'fasta')
    seqs_dict = {seq.id:str(seq.seq) for seq in fasta_seqs}
    # Assign sequences to clusters
    if not os.path.exists(msa_pr_path):
        !mafft --auto {fasta_pr_path} > {msa_pr_path}
    clusters_dict, comp_df = pca_protein_rank(msa_pr_path, n_clusters=n_clusters, plot=plot)
    seqs_clusters_dict = {seq_id: [seqs_dict[seq_id], clusters_dict[seq_id]] for seq_id in seqs_dict}
    # Convert sequences to kmers
    kmers_dict = seqs_to_kmers_dict(seqs_clusters_dict, k=k, weights=weights, alleles=alleles, aligned=aligned, equalise_clades=equalise_clades)
    # Construct the graph
    G = construct_graph(kmers_dict, decycle=decycle, aligned=aligned)
    # Find the optimal path(s) through the graph of epitopes
    Q = cocktail(G, m)
    return seqs_dict, Q


##################################################
# Construct/visualise the epitope graph and scores
##################################################

def construct_graph(kmers_dict, decycle=True, aligned=False, edge_colour='#BFBFBF'):
    """
    Return a Directed Graph with unique k-mers as nodes, where overlapping k-mers are connected by edges
    :param kmers_dict: Dictionary mapping each k-mer to its 'clades', 'clade_weight', 'freq', 'immunogenicity', 'mhc_binding', 'evo_rate', 'pos' and 'score'
    :param decycle: Boolean if the output graph should be decycled (default=True)
    :param aligned: Boolean; if False each node's x-position is recomputed as its distance from 'BEGIN' (default=False)
    :param edge_colour: String for edge colour (default='#BFBFBF')
    :returns: Directed Graph containing epitopes plus the 'BEGIN' and 'END' sentinel nodes
    """
    # Create graph
    G = nx.DiGraph()
    # Add nodes - for each unique k-mer
    for e, d in kmers_dict.items():
        G.add_node(e, Clades=d['clades'], Clade_Weight=d['clade_weight'], Frequency=d['freq'], Immunogenicity=d['immunogenicity'], MHC_Binding=d['mhc_binding'], Evo_Rate=d['evo_rate'], pos=(d['pos'], d['score']), Score=d['score'])
    # Add edges - where the last k−1 characters of ea match the first k−1 characters of eb.
    # Index the nodes by their first k-1 characters so each node only meets its real successors
    # instead of being compared against every other node; edge insertion order is unchanged.
    kmers = list(G.nodes)
    nodes_by_prefix = {}
    for eb in kmers:
        nodes_by_prefix.setdefault(eb[:-1], []).append(eb)
    for ea in kmers:
        for eb in nodes_by_prefix.get(ea[1:], ()):
            G.add_edge(ea, eb, colour=edge_colour)
    # Decycle graph
    if decycle:
        G = decycle_graph(G)
    # Add begin and end nodes
    # NOTE: the sentinels are plain strings, so a k-mer spelled exactly 'BEGIN' or 'END' would collide with them
    begin_nodes = [e for e in list(G.nodes) if not P(G, e)]
    end_nodes = [e for e in list(G.nodes) if not S(G, e)]
    G.add_node('BEGIN', Clades=[], Clade_Weight=0, Frequency=0, Immunogenicity=0, MHC_Binding=0, Evo_Rate=0, pos=(0,0), Score=0)
    G.add_node('END', Clades=[], Clade_Weight=0, Frequency=0, Immunogenicity=0, MHC_Binding=0, Evo_Rate=0, pos=(0,0), Score=0)
    for e in begin_nodes:
        G.add_edge('BEGIN', e, colour=edge_colour)
    for e in end_nodes:
        G.add_edge(e, 'END', colour=edge_colour)
    # Add the position node attribute (for unaligned sequences)
    if not aligned:
        # One breadth-first search from 'BEGIN' gives the distance to every reachable node
        depths = dict(nx.single_source_shortest_path_length(G, 'BEGIN'))
        for e in G.nodes():
            # Nodes unreachable from 'BEGIN' (only possible in a cyclic graph) keep their initial position
            if e != 'BEGIN' and e != 'END' and e in depths:
                G.nodes[e]['pos'] = (depths[e], f(G, e))
    # Ensure the 'END' node is always at the end of the graph
    end_pos = max([pos[0] for pos in list(nx.get_node_attributes(G, 'pos').values())]) + 1
    nx.set_node_attributes(G, {'END': (end_pos, 0)}, 'pos')
    return G


def plot_graph(G, paths=None, node_size=150, with_labels=False, ylim=([0,1]), interactive=False, colour_by='Score'):
    """
    Draw the epitope graph, either interactively (igviz) or as a static matplotlib figure
    :param G: Directed Graph containing epitopes
    :param paths: List of lists of epitope strings whose edges are coloured red (default=None)
    :param node_size: Integer for size of nodes in non-interactive plot (default=150)
    :param with_labels: Boolean for if epitope labels should be displayed in the non-interactive plot (default=False)
    :param ylim: List for y-axis limits (default=[0,1])
    :param interactive: Boolean for if the plot should be interactive (default=False)
    :param colour_by: String name of the node attribute used to colour the nodes (default='Score')
    :returns: The result of fig.show() for the interactive plot, otherwise None
    """
    if interactive:
        hover_fields = ['Score', 'Clades', 'Clade_Weight', 'Frequency', 'Immunogenicity', 'MHC_Binding', 'pos']
        fig = ig.plot(G, color_method=colour_by, node_text=hover_fields) # layout='spectral','spiral','spring'
        return fig.show()

    # Static plot: nodes are placed at their stored 'pos' coordinates
    node_colours = list(nx.get_node_attributes(G, colour_by).values())
    layout = nx.get_node_attributes(G, 'pos')
    if paths:
        for path in paths:
            # Highlight every edge on the path, including the links to the sentinel nodes
            # (this overwrites the 'colour' attribute stored on the graph's edges)
            full_path = ['BEGIN'] + path + ['END']
            for ea, eb in zip(full_path, full_path[1:]):
                G.edges[ea, eb]['colour'] = 'red'
    edge_colours = nx.get_edge_attributes(G, 'colour')
    # Plot
    fig, ax = plt.subplots(1, figsize=(16, 8))
    nx.draw(G, node_color=node_colours, pos=layout, node_size=node_size, with_labels=with_labels, edge_color=edge_colours.values(), width=2, font_color='white', ax=ax)
    plt.axis('on')
    x_max = max([xy[0] for xy in layout.values()]) + 0.5
    ax.set_xlim([-0.5, x_max])
    ax.set_ylim(ylim)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.tick_params(axis='both', which='major', labelsize=14)
    plt.ylabel('Epitope Score, f(e)', fontsize=18)


def format_title(title):
    """
    Turn an attribute name into a readable label by replacing underscores with spaces
    :param title: String such as 'MHC_Binding'
    :returns: String with every '_' replaced by ' '
    """
    return title.replace('_', ' ')


def plot_score(G, score='Frequency'):
    """
    Draw a 100-bin histogram of one node attribute across all nodes of the graph.
    :param G: NetworkX graph
    :param score: String for the node attribute to be plotted (default='Frequency')
    :returns: Figure
    """
    values = [*nx.get_node_attributes(G, score).values()]
    # Plot
    fig, ax = plt.subplots(1, figsize=(16, 8))
    ax.hist(values, bins=100)
    ax.set_xlabel(score, fontsize=18)
    ax.set_ylabel('Number of PTEs', fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=14)
    # Hide the top/right frame lines
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    return fig


def plot_scores(G, paths=None):
    """
    Plot the frequency, MHC binding and immunogenicity scores by position as a stacked area chart
    :param G: The graph of epitopes
    :param paths: The optimal paths through the graph; if specified, the scores of the nodes on the FIRST path are plotted, otherwise the mean scores of all nodes at each position
    :return: A figure of the scores of the nodes in the graph
    """
    if paths:
        # Get the scores of each node in the path
        path_nodes = [G.nodes[node] for node in paths[0]]
        pos = [data['pos'][0] for data in path_nodes]
        freq = [data['Frequency'] for data in path_nodes]
        mhc_binding = [data['MHC_Binding'] for data in path_nodes]
        immunogenicity = [data['Immunogenicity'] for data in path_nodes]
    else:
        # Group the nodes by position in a single pass
        nodes_by_pos = {}
        for _node, data in G.nodes(data=True):
            nodes_by_pos.setdefault(data['pos'][0], []).append(data)
        # Sorted positions: a stacked area plot needs a monotonic x-axis, which a set does not guarantee
        pos = sorted(nodes_by_pos)
        # Get the (mean) scores of all node positions in the graph
        freq = [np.mean([data['Frequency'] for data in nodes_by_pos[p]]) for p in pos]
        mhc_binding = [np.mean([data['MHC_Binding'] for data in nodes_by_pos[p]]) for p in pos]
        immunogenicity = [np.mean([data['Immunogenicity'] for data in nodes_by_pos[p]]) for p in pos]
    # Plot stacked area
    fig, ax = plt.subplots(1,figsize=(16,8))
    ax.stackplot(pos, freq, mhc_binding, immunogenicity, labels=['Frequency', 'MHC Binding', 'Immunogenicity'])
    ax.set_xlabel('Position', fontsize=18)
    ax.set_ylabel('Score', fontsize=18)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.legend(loc='upper left', fontsize=14)
    ax.set_xlim([0, max(pos)])
    return fig


def plot_corr(G, x='Frequency', y='MHC_Binding'):
    """
    Plot two node scores against each other with a regression fit
    :param G: The graph of epitopes
    :param x: Name of the node attribute to show on the x-axis
    :param y: Name of the node attribute to show on the y-axis
    :return: The seaborn joint plot, annotated with the Pearson correlation coefficient of the two scores
    """
    # Collect both scores for every node in the graph
    x_scores = [G.nodes[node][x] for node in G.nodes]
    y_scores = [G.nodes[node][y] for node in G.nodes]
    df = pd.DataFrame({x: x_scores, y: y_scores})
    corr_plot = sns.jointplot(x=x, y=y, data=df, kind='reg', height=10)
    # Only the correlation coefficient is shown, the p-value is not used
    r, _ = stats.pearsonr(df[x], df[y])
    joint_ax = corr_plot.ax_joint
    joint_ax.annotate(
        f'$R_\\rho = {r:.2f}$',
        xy=(0.3, 0.3),
        xycoords=joint_ax.transAxes,
        fontsize=16,
    )
    corr_plot.set_axis_labels(format_title(x), format_title(y), fontsize=16)
    return corr_plot


def plot_vaccine_design_pca(seqs_dict, Q, fasta_pr_designs_path, msa_pr_designs_path, n_clusters, plot_type='2D', plot=False):
    """
    Plot the PCA of the vaccine design.
    :param seqs_dict: Dictionary of sequences.
    :param Q: List of the optimal vaccine cocktail(s).
    :param fasta_pr_designs_path: Path to the FASTA file of protein sequences.
    :param msa_pr_designs_path: Path to the MSA file of protein sequences.
    :param n_clusters: Number of clusters to use for clustering the sequences.
    :param plot: Boolean indicating whether to generate all plots.
    :param plot_type: Type of plot to generate.
    :return: Plot of the PCA of the vaccine design.
    """
    # Add design to sequences dictionary
    seqs_dict['vaccine_design'] = path_to_seq(Q[0])
    # Write the sequences to a fasta file
    SeqIO.write([SeqRecord(Seq(seqs_dict[seq_id]), id=seq_id, description='') for seq_id in seqs_dict], fasta_pr_designs_path, 'fasta')
    # Perform MSA
    !mafft --auto {fasta_pr_designs_path} > {msa_pr_designs_path}
    # Perform PCA on the MSA
    clusters_dict, comp_df = pca_protein_rank(msa_pr_designs_path, n_clusters=n_clusters, plot=plot)
    # Plot the PCA results
    pca_plot = plot_pca(comp_df, plot_type=plot_type)
    return pca_plot, comp_df

## Simple example 

In [None]:
# Define the input epitopes as (epitope, position, score); every other attribute is zero or empty
# Cyclic example: kmers_dict = {'MSA': 0.6, 'SAM': 0.2, 'AMS': 0.4}
kmers_dict = {
    epitope: {'evo_rate': 0, 'clades': [], 'clade_weight': 0, 'immunogenicity': 0, 'mhc_binding': 0, 'pos': pos, 'freq': 0, 'score': score}
    for epitope, pos, score in [
        ('MSA', 1, 0.6),
        ('SAM', 2, 0.2),
        ('AMQ', 3, 0.2),
        ('MQL', 4, 0.2),
        ('SAR', 2, 0.4),
        ('MGA', 1, 0.3),
        ('GAR', 2, 0.7),
        ('ARQ', 3, 0.4),
        ('RQL', 4, 0.4),
    ]
}

# Construct the graph
G = construct_graph(kmers_dict, aligned=True)

# Find the optimal path(s) through the graph of epitopes and print each one as a sequence
Q = cocktail(G, m)
print(list(map(path_to_seq, Q)))

# Plot the results
plot_graph(
    G,
    paths=Q,
    node_size=2000,
    ylim=[-0.1, 1],
    with_labels=True,
    interactive=False,
)

## Perform a Multiple Sequence Alignment (MSA)
The input nucleotide sequences are preprocessed by performing an MSA. This is required to calculate omega (dN/dS) and is done in the following steps:
1. **Pre-MSA:** Run these sequences through `pre-msa.bf` in order to correct frame-shift mutations and translate the resulting sequences to proteins
2. **MSA:** Take the output of step 1 and run it through MAFFT to generate a **protein** MSA
3. **Post-MSA:** Run the protein MSA (from step 2) and the frameshift corrected nucleotide sequences (from step 1) through `post-msa.bf` to obtain a nucleotide MSA. This step will also compress all identical sequences, i.e. replace them with a single representative sequence.

In [None]:
# Step 1) Pre-MSA
# Only run HyPhy when the protein FASTA file does not exist yet, so re-running the notebook reuses the existing file
if not os.path.exists(fasta_pr_path):
    !hyphy {pre_msa_path} --input {fasta_path} --protein {fasta_pr_path}

In [None]:
# Step 2) MSA
# Only run MAFFT when the protein MSA file does not exist yet; the alignment is redirected from stdout into that file
if not os.path.exists(msa_pr_path):
    !mafft --auto {fasta_pr_path} > {msa_pr_path}

In [None]:
# Step 3) Post-MSA
# !hyphy {post_msa_path} --protein-msa {msa_pr_path} --nucleotide-sequences {fasta_nt_path} --output {msa_nt_path}

## Compare nonsynonymous (dN) and synonymous (dS) substitution rates

> Note: This section is deprecated.

Use HyPhy FEL to infer the number of nonsynonymous (N) and synonymous (S) substitutions on a per-site basis for a given coding alignment and corresponding phylogeny. Report N and S, and combine them to ascertain if N is greater than S:
1. **Tree construction:** Construct a phylogenetic tree in NEWICK format using IQ-TREE
2. **Compare dN and dS:** Using the nucleotide MSA (from MSA step 3) and the phylogenetic tree (from step 1) as input to HyPhy FEL, infer the nonsynonymous (dN) and synonymous (dS) rates and estimate the number of synonymous/nonsynonymous sites for each codon/site in the MSA
3. **Get Evolutionary rate:** Load the JSON output from HyPhy SLAC (step 2) and adjust the sign (+/-) of the number of synonymous (-1) vs nonsynonymous (1) sites to indicate if the codon is under positive/negative selection

Note:
- Diversifying or positive selection is where dS/α < dN/β
- Purifying or negative selection is where dS/α > dN/β

We want to select for codons under purifying selection because they're more conserved

In [None]:
# Step 1) Tree construction
# !iqtree -s {msa_nt_path} -pre {fasta_path} --redo --fast

In [None]:
# Step 2) Compare dN and dS
# !hyphy slac --alignment {msa_nt_path} --tree {tree_path} --output {slac_path}

In [None]:
# Step 3) Get Evolutionary Rate
# evo_rate = get_evo_rate(slac_path)

## Load the FASTA sequences

In [None]:
# Read the protein MSA when working with aligned sequences, otherwise the unaligned protein FASTA
fasta_path = msa_pr_path if aligned else fasta_pr_path

# Map each sequence ID to its sequence as a plain string
fasta_seqs = SeqIO.parse(fasta_path, 'fasta')
seqs_dict = {record.id: str(record.seq) for record in fasta_seqs}

## Assign each sequence to a clade

In [None]:
# Cluster the sequences of the protein MSA; clusters_dict is indexed by sequence ID in the k-mer cell below
clusters_dict, comp_df = pca_protein_rank(msa_pr_path, n_clusters=n_clusters, plot=plot)
# Optional PCA plots of the clustering (disabled)
# NOTE(review): plot_vaccine_design_pca calls plot_pca with `plot_type=`, so the `type=` keyword below may be stale - confirm against plot_pca
# pca_2d_plot = plot_pca(comp_df, type='2D')
# pca_3d_plot = plot_pca(comp_df, type='3D')
# pca_interactive_plot = plot_pca(comp_df, interactive=True)

# pca_2d_plot.show()
# pca_3d_plot.show()
# pca_interactive_plot.show()

## Split into k-mers
Compute all possible k-mers of length `k` for the given target sequences and score each k-mer

In [None]:
# TODO: Add logic to use the input weights to check if all required inputs are present and only calculate desired scores
if aligned:
    # NOTE(review): evo_rate is only assigned by the commented-out "Get Evolutionary Rate" step above,
    # so this branch looks like it raises a NameError as the notebook stands - confirm before setting aligned = True
    df = seqs_dict_to_df(seqs_dict, evo_rate)
    kmers_dict = seqs_df_to_kmers_dict(df, weights, k)
else:
    # Pair each sequence with its cluster so both are available when the k-mers are computed
    seqs_clusters_dict = {seq_id: [seqs_dict[seq_id], clusters_dict[seq_id]] for seq_id in seqs_dict}
    kmers_dict = seqs_to_kmers_dict(seqs_clusters_dict, k=k, weights=weights, alleles=alleles, aligned=aligned, equalise_clades=equalise_clades)

## Construct the epitope graph
Create a Directed Graph (`DiGraph`) using the `networkx` package, where each epitope `e` is a node and edges connect nodes where the last `k−1` characters of `ea` match the first `k−1` characters of `eb`. For computational convenience, two extra nodes `BEGIN` and `END` are added. The `BEGIN` node connects to all the nodes that lack predecessors (`P(e)`) (corresponding to epitopes that are the first `k` characters in a sequence). Nodes that lack successors (`S(e)`) (because they are the last `k` characters in a sequence) are connected to the `END` node. For plotting convenience, the length of the shortest path to the `BEGIN` node is added as a node attribute.

In [None]:
# Build the epitope graph from the scored k-mers, using the decycle and aligned settings from the configuration cell
G = construct_graph(kmers_dict, decycle=decycle, aligned=aligned)

## Assembly 
Take a path through the graph to optimise epitope frequency.

The forward loop computes the function `F(e)` (the largest sum achievable for any path that terminates with the epitope `e`) for all the nodes in a stepwise manner. The backward loop chooses the node with maximum value as the final epitope in our optimal string and works backwards to build the path that achieves the maximal score

In [None]:
# Find the optimal path(s) through the graph of epitopes and show each one as a sequence
Q = cocktail(G, m)
list(map(path_to_seq, Q))

## Plot the epitope graph
The nodes are the epitopes `e` and the edges connect epitopes whose sequences overlap by `k − 1` amino acids. The x-axis shows the shortest path length to the `BEGIN` node, the y-axis indicates the epitope frequency `f(e)` in this target sequence set. The optimal path is shown in red which corresponds to the protein sequence that maximizes epitope coverage of the population

In [None]:
# F(e)
# Plot the epitope graph together with the optimal path(s) Q, colouring by the 'Frequency' node attribute
plot_graph(G, paths=Q, node_size=25, with_labels=False, ylim=[0, 5], interactive=False, colour_by='Frequency')

## Evaluate the implemented scores

Compare the generated vaccine design to a design optimised only for epitope frequency to determine if the additional scores change the vaccine design

> The additional immune scores (MHC binding and immunogenicity) do not change the vaccine design

In [None]:
# Hard-coded reference design (optimised only for epitope frequency) to compare the current design against
seq_freq = 'MSDNGPQNQRSAPRITFGGPTDSTDNNQDGGRSGARPKQRRPQGLPNNTASWFTALTQHGKEELRFPRGQGVPINTNSGKDDQIGYYRRATRRVRGGDGKMKELSPRWYFYYLGTGPEASLPYGANKEGIVWVATEGALNTPKDHIGTRNPNNNAAIVLQLPQGTTLPKGFYAEGSRGGSQASSRSSSRSRGNSRNSTPGSSRGNSPARMASGGGETALALLLLDRLNQLESKVSGKGQQQQGQTVTKKSAAEASKKPRQKRTATKQYNVTQAFGRRGPEQTQGNFGDQELIRQGTDYKHWPQIAQFAPSASAFFGMSRIGMEVTPSGTWLTYHGAIKLDDKDPQFKDNVILLNKHIDAYKTFPPTEPKKDKKKKTDEAQPLPQRQKKQPTVTLLPAADMDDFSRQLQNSMSGASADSTQA'
# The current design is the first optimal path written out as a sequence
seq_immune = path_to_seq(Q[0])
align = align_seqs(seq_freq, seq_immune)
# Report whether the two designs are identical, then show the first alignment
if seq_freq == seq_immune:
    similar = 'the same'
else:
    similar = 'different'
print(f'The sequences are {similar}')
print(next(align))

Determine how each of the individual scores contribute to the overall score
> Here, we can see that the epitope frequency is contributing to all PTEs in our vaccine design, whereas the immune scores (MHC binding and immunogenicity) are more spiky and only contribute to some of the PTEs. It seems like the epitope frequency is having a disproportionately large effect on the overall score (despite the scores being equally weighted). Interestingly, it looks like there is a correlation between the frequency and immune scores. If they are correlated, then this would explain why the additional scores do not change the vaccine design

In [None]:
# Stacked area plot of the individual scores along the first optimal path in Q
scores_plot = plot_scores(G, paths=Q)
scores_plot.show()

Compare different scores to see whether they are correlated, to determine if they are redundant and if there is a relationship between the different scores
> Here, we can see that the frequency and immune score (MHC binding) are NOT correlated. This means that we have not identified why the epitope frequency is having a disproportionately large effect on the overall score. It also means that the immune scores are not redundant and are providing additional information that is not captured by the epitope frequency

In [None]:
# Correlation plot with the default scores of plot_corr ('Frequency' against 'MHC_Binding')
corr_plot = plot_corr(G)
# The bare expression is left last so that the notebook displays the plot
corr_plot

View the distribution of the scores for all the epitopes in the graph
> Looking at the distribution of the epitope frequencies we can see that there are many infrequent PTEs and a few frequent PTEs. My current hypothesis is that the epitope frequency is having a disproportionately large effect on the overall score because of the **"zero-sum"** nature of epitope frequency. At a particular position, for one epitope to have a high frequency the other epitopes at that position must have a low frequency. Crucially, as this is for each particular position, the epitope frequency has a disproportionately large effect on the overall score because it affects the choice of epitope at each decision point.

In [None]:
# Plot one score at a time: uncomment the line for the score of interest and comment out the others
# hist_plot = plot_score(G, score='Frequency')
hist_plot = plot_score(G, score='MHC_Binding')
# hist_plot = plot_score(G, score='Immunogenicity')
hist_plot.show()

### See where the vaccine design(s) are in sequence space
For different numbers of clusters, plot the input target sequences and the vaccine design(s) on a PCA plot

In [None]:
# Numbers of clusters to try
# NOTE(review): this rebinds n_clusters from the single value set in the configuration cell to a list,
# so cells above that pass n_clusters to pca_protein_rank would receive the list if re-run afterwards - confirm this is intended
n_clusters = [1, 5, 10, 20]

for n in n_clusters:
    # Redo the vaccine design with n clusters, then plot it with the input sequences
    # (seqs_dict, Q and comp_df from the earlier cells are overwritten on every iteration)
    seqs_dict, Q = design_vaccines(fasta_pr_path, msa_pr_path, n, k, m, weights, alleles, aligned, equalise_clades, decycle)
    pca_plot, comp_df = plot_vaccine_design_pca(seqs_dict, Q, fasta_pr_designs_path, msa_pr_designs_path, n, plot_type='2D')
    pca_plot.show()