# T-cell vaccine design
Design vaccine(s) to elicit a T-cell response by optimising coverage of potential T-cell epitopes (PTEs).

Here the term epitope `e` refers to a potential T-cell epitope (PTE): a short subsequence of `k` amino acids, which is also represented as a node in the epitope graph `G`.

In [None]:
from Bio import SeqIO
import igviz as ig
from itertools import product
import matplotlib.pyplot as plt
import networkx as nx
import random

In [None]:
# Input settings - edit these before running the notebook
# Length k (in amino acids) of each potential T-cell epitope
k = 9
# FASTA file containing the target protein sequences
fasta_path = '../data/nucleoprotein/3_nuc_pro_uniq.fa'

In [None]:
##########################################
# Utils to convert the format of sequences
##########################################

def seq_to_kmers(seq, k):
    """
    Returns every contiguous window of length k in seq, in order of position
    :param seq: Sequence to split, e.g. an amino acid string (lists also work and yield sub-lists)
    :param k: Integer window length
    :returns: List of k-length slices of seq (empty if seq is shorter than k)
    """
    n_windows = len(seq) - k + 1
    return [seq[start:start + k] for start in range(n_windows)]


def seqs_to_kmers_dict(seqs, k=9):
    """
    Returns a dictionary mapping every k-mer in seqs to the fraction of sequences containing it
    :param seqs: List of amino acid sequences
    :param k: Integer for substring length (default=9)
    :returns: Dictionary of k-mer -> frequency; a k-mer repeated within one sequence counts once
    """
    total = len(seqs)
    counts = {}
    for seq in seqs:
        # Build a set of this sequence's k-mers so repeats inside one sequence are counted once
        window_starts = range(len(seq) - k + 1)
        for kmer in {seq[i:i + k] for i in window_starts}:
            counts[kmer] = counts.get(kmer, 0) + 1
    # Turn per-sequence counts into the fraction of sequences that contain each k-mer
    return {kmer: n / total for kmer, n in counts.items()}


def path_to_seq(path):
    """
    Returns an AA string for a list of epitopes (path)
    Consecutive epitopes overlap by all but one residue, so the sequence is the first
    epitope followed by the last residue of each subsequent epitope.
    :param path: List of epitope strings, each overlapping the previous one by k-1 residues
    :returns: String for the amino acid sequence spelled by the path ('' for an empty path)
    """
    if not path:
        # An empty path spells an empty sequence rather than raising IndexError
        return ''
    return path[0] + ''.join(e[-1] for e in path[1:])


############################################
# Utils to get the min/max index from a list
############################################

def argmax(lst):
    """
    Returns the position of the largest value in lst (the first one when tied)
    """
    return max(range(len(lst)), key=lst.__getitem__)


def argmin(lst):
    """
    Returns the position of the smallest value in lst (the first one when tied)
    """
    return min(range(len(lst)), key=lst.__getitem__)


#####################################
# Utils to retrieve info from a graph
#####################################

def P(G, e):
    """
    Returns the predecessors P(e) of node e, i.e. the epitopes with an edge into e
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: List of predecessors, in the order G reports them
    """
    return [*G.predecessors(e)]


def S(G, e):
    """
    Returns the successors S(e) of node e, i.e. the epitopes e has an edge into
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: List of successors, in the order G reports them
    """
    return [*G.successors(e)]


def f(G, e, f='Frequency'):
    """
    Returns a node feature for a given epitope e, e.g. its frequency in the population
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :param f: String for the name of the node attribute to read (default='Frequency');
              e.g. 'F(e)' reads the cached path score written by F
    :returns: Value stored under attribute f for node e
    :raises KeyError: if e is not a node of G or has no attribute f
    """
    return G.nodes[e][f]


############################################
# Decycling - remove all cycles from a graph
############################################

def decycle_graph(G):
    """
    Return a Directed Graph with no cycles
    G is modified in place. Self-loops are dropped. Then, for each cycle found, the
    weakest edge (see weak_edge_in_cycle) is removed, until no cycles remain.
    :param G: Directed Graph containing epitopes
    :returns: The same Directed Graph object, now containing epitopes and no cycles
    """
    # A self-loop (e.g. a homopolymer k-mer like 'AAAAAAAAA', whose last k-1
    # characters equal its first k-1) is a cycle inside a single-node strongly
    # connected component, so the component check below never sees it. Left in
    # place, it makes F(e) call itself without end (RecursionError).
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    # Iterate rather than recurse, so graphs with many cycles cannot exhaust the
    # Python recursion limit
    while True:
        # Every strongly connected component with more than one node contains a cycle
        components = [j for j in nx.strongly_connected_components(G) if len(j) > 1]
        if not components:
            return G
        # Components are disjoint, so one weak edge can be cut from each per pass.
        # They are recomputed afterwards because a cut component may still hold cycles.
        for j in components:
            # Randomly choose two nodes from the selected component
            ea, eb = random.sample(list(j), k=2)
            cycle = cycle_from_two_nodes(G, ea, eb)
            if cycle:
                ea, eb = weak_edge_in_cycle(G, cycle)
                G.remove_edge(ea, eb)


def cycle_from_two_nodes(G, ea, eb):
    """
    Returns a cycle through ea and eb: the shortest path ea -> eb joined to the shortest path eb -> ea
    :param G: Directed Graph containing epitopes
    :param ea: String for the first given potential T-cell epitope (PTE)
    :param eb: String for the second given potential T-cell epitope (PTE)
    :returns: List of epitope strings starting and ending with ea, or [] if either direction has no path
    """
    try:
        forward_leg = nx.shortest_path(G, source=ea, target=eb)
        return_leg = nx.shortest_path(G, source=eb, target=ea)
    except nx.NetworkXNoPath:
        return []
    # Drop eb from the end of the forward leg so it is not listed twice where the legs meet
    return forward_leg[:-1] + return_leg


def weak_edge_in_cycle(G, cycle):
    """
    Returns the weak edge (the edge with the lowest heuristic score) in a cycle
    :param G: Directed Graph containing epitopes
    :param cycle: List of epitope strings on path that is a cycle
    :returns: Tuple (tail, head) for the weakest edge; the first such edge on the cycle if tied
    """
    def score(tail, head):
        # Base value: combined frequency of the two epitopes the edge joins
        value = f(G, tail) + f(G, head)
        # Extra weight if cutting this edge would leave tail with no successor
        if len(S(G, tail)) == 1:
            value += f(G, tail)
        # Extra weight if cutting this edge would leave head with no predecessor
        if len(P(G, head)) == 1:
            value += f(G, head)
        return value

    # Consecutive pairs along the cycle are its edges
    pairs = seq_to_kmers(cycle, k=2)
    scores = [score(tail, head) for tail, head in pairs]
    tail, head = pairs[argmin(scores)]
    return tail, head


###############################################
# Find optimal path through a graph of epitopes
###############################################

def find_optimal_path(G):
    """
    Returns the optimal path through a graph of epitopes
    Side effect: every node of G gets an 'F(e)' attribute (see F).
    :param G: Directed Graph containing epitopes (must be acyclic, e.g. built with decycle=True)
    :returns: List of epitope strings on the optimal path
    :raises networkx.NetworkXUnfeasible: if G contains a cycle
    """
    # Forward loop - compute F(e) for every node. Visiting nodes in topological
    # order means every predecessor's F(e) is already cached when F is called,
    # so F never recurses deeper than one level. In arbitrary node order, F can
    # recurse as deep as the longest path, which can exceed Python's recursion
    # limit for full-length proteins.
    for e in nx.topological_sort(G):
        F(G, e)
    # Backward loop - build the path that achieves the maximal score
    return backward(G)


def F(G, e):
    """
    Returns the maximum total frequency over all paths that end in e
    F(e) = f(e) when e has no predecessors, otherwise f(e) + max of F over P(e).
    The result is cached on the node as attribute 'F(e)' and reused on later calls.
    :param G: Directed Graph containing epitopes
    :param e: String for a given potential T-cell epitope (PTE)
    :returns: Float for the maximum total epitope frequency
    """
    attrs = G.nodes[e]
    if 'F(e)' in attrs:
        # Already computed for this epitope
        return attrs['F(e)']
    own = attrs['Frequency']
    preds = list(G.predecessors(e))
    if preds:
        total = own + max(F(G, pe) for pe in preds)
    else:
        total = own
    attrs['F(e)'] = total
    return total


def backward(G, path=None):
    """
    Returns the path that achieves the maximal score
    Requires F(e) to have been computed for the nodes (see find_optimal_path).
    :param G: Directed Graph containing epitopes
    :param path: List of epitope strings to extend backwards (default=None, i.e. start from the
                 epitope with the highest F(e)); a non-empty list passed in is extended in place
    :returns: List of epitope strings on path that achieve maximum score (BEGIN/END excluded)
    """
    # Read the precomputed F(e) values once, rather than on every step
    Fe_dict = nx.get_node_attributes(G, 'F(e)')
    if not path:
        # The epitope with the maximum F(e) is the final epitope in our optimal path.
        # BEGIN/END are helper nodes, never epitopes.
        candidates = [e for e in Fe_dict if e not in ('BEGIN', 'END')]
        path = [max(candidates, key=Fe_dict.get)]
    # Walk back one predecessor at a time (iteratively, so long paths cannot hit the recursion limit)
    while True:
        predecessors = P(G, path[0])
        # Stop at the start of the graph: BEGIN is the predecessor, or there is none at all
        if not predecessors or 'BEGIN' in predecessors:
            break
        # Add the best (highest F(e)) predecessor P(e) of epitope e to our path
        i = argmax([Fe_dict[pe] for pe in predecessors])
        path.insert(0, predecessors[i])
    return path


###########################################################
# Cocktail: Find (and iteratively refine) a set of antigens
###########################################################

def cocktail(G, m):
    """
    Returns a list of m antigens
    Each antigen is the optimal path for the current frequencies. Its epitopes then
    get frequency 0, so later antigens receive no credit for covering them again.
    Side effects on G:
      - 'Original Frequency' stores each node's 'Frequency' as it was when called.
      - 'Frequency' is set to 0 for every epitope used in an antigen.
      - 'F(e)' attributes are removed.
    :param G: Directed Graph containing epitopes
    :param m: Integer for number of antigens
    :returns: List containing m antigens, each a list of epitope strings (see path_to_seq)
    """
    Q = []  # vaccine
    # Save the original epitope frequencies once, before any are zeroed. Saving
    # inside the loop (as before) overwrote them with already-zeroed values from
    # the second antigen onwards.
    freq_dict = nx.get_node_attributes(G, 'Frequency')
    nx.set_node_attributes(G, freq_dict, 'Original Frequency')
    for _ in range(m):
        # Compute and save next antigen sequence
        q = find_optimal_path(G)
        Q.append(q)
        # No credit for including e in subsequent antigens
        nx.set_node_attributes(G, {e: 0 for e in q}, 'Frequency')
        # Remove F(e) so it's recomputed using the updated frequencies
        for _, d in G.nodes(data=True):
            d.pop('F(e)', None)
    # TODO: Optional - Repeat for iterative refinement (restore from 'Original Frequency')
    return Q


#######################################
# Construct/visualise the epitope graph
#######################################

def construct_graph(kmers_dict, decycle=True, edge_colour='#BFBFBF'):
    """
    Return a Directed Graph with unique k-mers as nodes, where overlapping k-mers are connected by edges
    An edge ea -> eb is added when the last k-1 characters of ea equal the first k-1 of eb.
    Extra 'BEGIN' and 'END' nodes (Frequency 0) are linked to every node without
    predecessors / successors.
    Each node gets a 'pos' attribute (x, frequency) used for plotting:
      - x is the shortest-path distance from BEGIN.
      - END is placed one step past the deepest epitope.
    :param kmers_dict: Dictionary mapping each k-mer to its frequency (e.g. from seqs_to_kmers_dict)
    :param decycle: Boolean if the output graph should be decycled (default=True)
    :param edge_colour: String for edge colour (default='#BFBFBF')
    :returns: Directed Graph containing epitopes
    :raises networkx.NetworkXNoPath: if an epitope is unreachable from BEGIN (possible only with
                                     decycle=False, when a cycle has no entry point)
    """
    # Create graph
    G = nx.DiGraph()
    # Add nodes - for each unique k-mer
    for e, freq in kmers_dict.items():
        G.add_node(e, Frequency=freq)
    # Add edges - where the last k-1 characters of ea match the first k-1 characters of eb.
    # Index nodes by their first k-1 characters, so matches are found by lookup in O(n)
    # rather than by comparing every pair of nodes in O(n^2). Edges are still added in
    # the same (ea, eb) node order as an all-pairs scan.
    by_prefix = {}
    for e in G.nodes():
        by_prefix.setdefault(e[:-1], []).append(e)
    for ea in list(G.nodes()):
        for eb in by_prefix.get(ea[1:], []):
            G.add_edge(ea, eb, colour=edge_colour)
    # Decycle graph
    if decycle:
        G = decycle_graph(G)
    # Add begin and end nodes
    begin_nodes = [e for e in G.nodes if not P(G, e)]
    end_nodes = [e for e in G.nodes if not S(G, e)]
    G.add_node('BEGIN', Frequency=0)
    G.add_node('END', Frequency=0)
    for e in begin_nodes:
        G.add_edge('BEGIN', e, colour=edge_colour)
    for e in end_nodes:
        G.add_edge(e, 'END', colour=edge_colour)
    # Add the position node attribute. A single breadth-first search from BEGIN
    # gives every shortest-path length at once.
    depth = nx.single_source_shortest_path_length(G, 'BEGIN')
    # Make sure END is drawn at the far right, one step past the deepest epitope
    depth['END'] = max(d for node, d in depth.items() if node != 'END') + 1
    for e in G.nodes():
        if e not in depth:
            raise nx.NetworkXNoPath(f'Node {e!r} is not reachable from BEGIN')
        G.nodes[e]['pos'] = (depth[e], f(G, e))
    return G


def plot_graph(G, path=None, node_size=150, with_labels=False, ylim=(0, 1), interactive=False):
    """
    Plot the epitope graph
    G is not modified. The path is highlighted with a local list of edge colours,
    so earlier plots do not leave red edges behind on the graph.
    :param G: Directed Graph containing epitopes, with 'Frequency' and 'pos' node attributes
    :param path: List of epitope strings to highlight in red, without BEGIN/END (default=None)
    :param node_size: Integer for size of nodes in non-interactive plot (default=150)
    :param with_labels: Boolean for if epitope labels should be displayed in the non-interactive plot (default=False)
    :param ylim: Pair of y-axis limits (default=(0, 1))
    :param interactive: Boolean for if the plot should be interactive (default=False)
    :returns: None
    :raises KeyError: if consecutive nodes of BEGIN + path + END are not joined by an edge in G
    """
    if interactive:
        fig = ig.plot(G, color_method='Frequency', node_text=['Frequency', 'F(e)'])  # layout='spectral','spiral','spring'
        fig.show()
        return None
    # Node colours listed in G.nodes order, which is the order nx.draw uses
    freq = [G.nodes[e]['Frequency'] for e in G.nodes]
    pos = nx.get_node_attributes(G, 'pos')
    path_edges = set()
    if path:
        full_path = ['BEGIN'] + list(path) + ['END']
        path_edges = set(zip(full_path, full_path[1:]))
        missing = [edge for edge in path_edges if not G.has_edge(*edge)]
        if missing:
            raise KeyError(f'Path edges not found in graph: {missing}')
    # Edge colours listed in G.edges order; path edges are drawn red
    edge_colours = [
        'red' if (ea, eb) in path_edges else d.get('colour', '#BFBFBF')
        for ea, eb, d in G.edges(data=True)
    ]
    # Plot
    fig, ax = plt.subplots(1, figsize=(16, 8))
    nx.draw(G, node_color=freq, pos=pos, node_size=node_size, with_labels=with_labels,
            edge_color=edge_colours, width=2, font_color='white', ax=ax)
    # Make sure the axes (and so the frequency scale) are visible after drawing
    ax.set_axis_on()
    max_pos = max(p[0] for p in pos.values()) + 0.5
    ax.set_xlim([-0.5, max_pos])
    ax.set_ylim(ylim)
    # Item access works on all matplotlib versions (attribute access needs >= 3.4)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_ylabel('Epitope Frequency, f(e)', fontsize=18)
    return None

## Simple example 

In [None]:
# Toy set of 3-mer epitopes mapped to their population frequencies
# Cyclic example: kmers_dict = {'MSA': 0.6, 'SAM': 0.2, 'AMS': 0.4}
kmers_dict = dict(
    MSA=0.6,
    SAM=0.2,
    AMQ=0.2,
    MQL=0.2,
    SAR=0.4,
    MGA=0.3,
    GAR=0.7,
    ARQ=0.4,
    RQL=0.4,
)

# Build the epitope graph, find its highest-scoring path and print the sequence it spells
G = construct_graph(kmers_dict)
path = find_optimal_path(G)
print(path_to_seq(path))

# Draw the graph with the optimal path highlighted
plot_graph(G, path=path, node_size=2000, ylim=[-0.1, 1], with_labels=True, interactive=False)

## Load the FASTA sequences

In [None]:
# Read every FASTA record as a plain string; the with-block closes the file handle
# (previously the handle passed to SeqIO.parse was never closed)
with open(fasta_path) as fasta_file:
    fasta_seqs = SeqIO.parse(fasta_file, 'fasta')
    seqs = [str(seq.seq) for seq in fasta_seqs]

## Split into k-mers
Compute all possible k-mers of length `k` for the given target sequences

In [None]:
# Fraction of target sequences containing each k-mer
kmers_dict = seqs_to_kmers_dict(seqs, k=k)

## Construct the epitope graph
Create a Directed Graph (`DiGraph`) using the `networkx` package, where each epitope `e` is a node and edges connect nodes where the last `k−1` characters of `ea` match the first `k−1` characters of `eb`. Cycles are then removed by cutting the weakest edge in each cycle. For computational convenience, two extra nodes `BEGIN` and `END` are added. The `BEGIN` node connects to all the nodes that lack predecessors (`P(e)`): epitopes that are the first `k` characters in a sequence, or whose incoming edges were cut during decycling. Nodes that lack successors (`S(e)`), such as the last `k` characters in a sequence, are connected to the `END` node. For plotting convenience, the length of the shortest path from the `BEGIN` node is added as a node attribute.

In [None]:
# Build the (decycled) epitope graph from the k-mer frequencies
G = construct_graph(kmers_dict, decycle=True)

## Assembly 
Take a path through the graph to optimise epitope frequency.

The forward loop computes the function `F(e)` (the largest sum achievable for any path that terminates with the epitope `e`) for all the nodes in a stepwise manner. The backward loop chooses the node with the maximum value as the final epitope in our optimal string and works backwards to build the path that achieves the maximal score.

In [None]:
# TODO: Add additional measures to scoring function eg binding affinity prediction

In [None]:
# Compute the highest-scoring path and display the antigen sequence it spells out
path = find_optimal_path(G=G)
path_to_seq(path=path)

## Plot the epitope graph
The nodes are the epitopes `e`, and the edges connect epitopes whose sequences overlap by `k − 1` amino acids. The x-axis shows the length of the shortest path from the `BEGIN` node, and the y-axis shows the epitope frequency `f(e)` in the target sequence set. The optimal path, shown in red, corresponds to the protein sequence that maximises epitope coverage of the population.

In [None]:
# Static plot of the full graph with the optimal path in red
plot_graph(G, path=path, interactive=False, with_labels=False, node_size=25, ylim=[0, 1])