In this script we simulate possible future variants of the SARS-CoV-2 receptor-binding domain (RBD) that could be even more infectious than the current ones. 

First, we compute mutation rates at two levels:
1. Single-nucleotide mutation rates
2. Trinucleotide-level mutation rates

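To detect mutations we use a phylogeny-based approach. Each sequence (a sampled leaf or a reconstructed ancestral node) in a pre-constructed SARS-CoV-2 phylogenetic tree is compared against its nearest ancestor. The resulting substitutions are then used to estimate the mutation rates.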

In [None]:
def root_tree_with_first_sequence(tree, root_sequence):
    """
    Root the tree on the given outgroup sequence.

    The tree's ``root_with_outgroup`` method re-roots the tree object itself.
    It does not produce a file or a new tree, so the rooted tree is returned
    here to let callers chain on it. Callers that ignore the return value keep
    working, because ``tree`` is modified in place.

    Args:
        tree: The tree object to be rooted (modified in place).
        root_sequence (str): The sequence_id / clade name to use as the outgroup.

    Returns:
        The same ``tree`` object, now rooted on ``root_sequence``.
    """
    # Re-rooting happens in place; the method's own return value is not the rooted tree
    tree.root_with_outgroup(root_sequence)
    return tree



In [None]:
def get_reconstructed_ancestral_sequences(ancestral_state_file):
    """
    Reconstruct ancestral sequences from an ancestral state file.

    The file is read as tab-separated values with ``Node`` and ``State`` columns
    (e.g. IQ-TREE's ``.state`` output). Lines starting with ``#`` are treated as
    comments and skipped. This replaces a hard-coded number of header rows, so
    parsing does not silently break if the comment header changes length.

    Parameters:
        ancestral_state_file (str): The path to the ancestral state file.

    Returns:
        dict: Maps each node name to its reconstructed sequence. Each sequence
        is formed by concatenating that node's ``State`` values in file order.
    """
    # Skip the '#'-prefixed comment header instead of assuming exactly 8 header lines
    ancestral_states = pd.read_csv(ancestral_state_file, sep='\t', comment='#')
    # groupby preserves the original row order inside each group, so joining the
    # State column gives the per-node sequence in the order the sites were written
    return {
        node: ''.join(group['State'])
        for node, group in ancestral_states.groupby('Node')
    }


In [None]:
# Count k-mers (e.g. trinucleotides) in a sequence; the counts are used to
# normalize the trinucleotide mutation rate.
def get_kmer_frequency(sequence, k):
    """
    Count the occurrences of every possible k-mer over the alphabet A/C/G/T.

    Alignment gaps (``-``) are removed first, and the sequence is upper-cased.
    As a result, lower-case bases are counted under the same upper-case keys
    that callers look up. Windows containing any other character (e.g. ``N``)
    are skipped.

    Args:
        sequence (str): The input nucleotide sequence (DNA alphabet, ``T`` not ``U``).
        k (int): The length of the k-mer.

    Returns:
        dict: Every one of the 4**k possible k-mers mapped to its count
        (0 when the k-mer does not occur).
    """
    nucleotides = ["C", "T", "G", "A"]
    valid_bases = set(nucleotides)  # built once, not once per window

    # Remove gaps and normalize case so 'acg' and 'ACG' are the same k-mer
    sequence = sequence.replace("-", "").upper()

    # Start every possible k-mer at 0 so absent k-mers are still present as keys
    kmers_frequency_dict = {''.join(combo): 0 for combo in product(nucleotides, repeat=k)}

    # Slide a window of length k; the range is empty when the sequence is shorter than k
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        # Only count windows made entirely of A/C/G/T
        if valid_bases.issuperset(kmer):
            kmers_frequency_dict[kmer] += 1

    return kmers_frequency_dict

In [None]:
def find_substitution_mutations(ancestor_seq, sample_seq):
    """
    Find single-base substitutions of ``sample_seq`` relative to ``ancestor_seq``.

    Both sequences are upper-cased before comparison. This means a case-only
    difference (``a`` vs ``A``) is not reported as a mutation, and the
    trinucleotide lookup uses the same case as the frequency table. The two
    sequences are assumed to be aligned; ``zip`` stops at the shorter one.

    A substitution is kept only if the ancestral base, the sample base and the
    ancestral trinucleotide centred on it all consist of A/C/G/T. Substitutions
    in the first or last alignment column are skipped because their
    trinucleotide context is incomplete.

    Args:
        ancestor_seq (str): The (aligned) ancestor sequence.
        sample_seq (str): The (aligned) sample sequence.

    Returns:
        list: One ``[mutation, trinucleotide, weight]`` entry per kept
        substitution, where:
          - ``mutation`` is e.g. ``"G3T"`` (1-based alignment position);
          - ``trinucleotide`` is the ancestral context;
          - ``weight`` is ``1 / count`` of that trinucleotide in the gap-free
            ancestor.
    """
    valid_bases = {"C", "T", "G", "A"}
    # Normalize case once so the comparison, context extraction and frequency lookup agree
    ancestor_seq = ancestor_seq.upper()
    sample_seq = sample_seq.upper()
    # Trinucleotide counts in the ancestor, used to normalize each mutation
    trinucleotide_mutation_frequency = get_kmer_frequency(ancestor_seq, 3)

    mutations = []
    for i, (ref_base, sample_base) in enumerate(zip(ancestor_seq, sample_seq)):
        if ref_base == sample_base:
            continue
        position = i + 1  # 1-based position in the alignment
        # Context = columns i-1, i, i+1. At i == 0 the start index -1 wraps to the
        # end of the string, giving an empty/short slice that is rejected below.
        trinucleotide = ancestor_seq[i - 1:i + 2]
        if (
            len(trinucleotide) != 3
            or not valid_bases.issuperset(trinucleotide)
            or ref_base not in valid_bases
            or sample_base not in valid_bases
        ):
            continue
        # The context is three consecutive A/C/G/T bases of the ancestor, so it
        # also occurs in the gap-free ancestor and its count is at least 1
        mutations.append([
            f"{ref_base}{position}{sample_base}",
            trinucleotide,
            1 / trinucleotide_mutation_frequency[trinucleotide],
        ])

    return mutations


In [None]:
def find_ancestral_node(tree, node_name):
    """
    Return the clade that directly contains the clade named ``node_name``.

    Names are compared after dropping everything from the first ``/``
    (e.g. ``"Node12/95"`` matches ``"Node12"``), because bootstrapping can
    append support values to node names. If several clades match, the first
    one in ``tree.find_clades()`` order is used.

    Args:
        tree: The phylogenetic tree to search in.
        node_name: The bare name of the node whose nearest ancestor is wanted.

    Returns:
        The first clade in ``tree.find_clades()`` order for which
        ``target in clade`` is true. Returns None when the name is not found
        (a message is printed) or when no clade contains the target
        (e.g. the target is the root).
    """
    def bare_name(clade):
        return str(clade.name).split("/")[0]

    target_node = next(
        (clade for clade in tree.find_clades() if bare_name(clade) == node_name),
        None,
    )
    if target_node is None:
        print(f"Node '{node_name}' not found.")
        return None

    # Presumably `in` walks the clade's direct children (via its __iter__), so the
    # first clade containing the target is its parent. Confirm against Biopython.
    return next((clade for clade in tree.find_clades() if target_node in clade), None)



In [None]:

def get_mutation_data_for_internal_nodes(ancestral_sequences, alignment_file, tree):
    """
    Compute de novo and inherited mutations for every reconstructed internal node.

    Nodes are processed in ascending order of the first integer in their name
    (``Node1``, ``Node2``, ...):
      - ``Node1`` is compared against the first record of the alignment, which
        is treated as the reference.
      - Every other node is compared against the clade returned by
        ``find_ancestral_node``. Its inherited mutations are that ancestor's
        de novo plus inherited mutations.

    NOTE(review): this requires each ancestor to be processed before its
    descendants. That holds only if the numeric node order agrees with the
    rooted tree's topology; confirm this, otherwise the ancestor lookup raises
    KeyError.

    Args:
        ancestral_sequences (dict): Node name -> reconstructed ancestral sequence.
        alignment_file (str): Path to the multiple sequence alignment (FASTA).
        tree: The phylogenetic tree.

    Returns:
        tuple: ``(denovo_mutations, ancestral_mutations, ancestral_node_data)``
            - denovo_mutations: node -> mutations relative to its nearest ancestor.
            - ancestral_mutations: node -> mutations inherited along its lineage.
            - ancestral_node_data: node -> nearest ancestral clade object
              (for ``Node1``, the reference record's id string).
    """
    def node_number(name):
        return int(re.search(r'\d+', name).group())

    def bare_name(clade):
        # Tree node names may carry support values after a '/'; keep only the name
        return str(clade).split("/")[0]

    denovo_mutations, ancestral_mutations, ancestral_node_data = {}, {}, {}

    for node in sorted(ancestral_sequences, key=node_number):
        node_seq = ancestral_sequences[node]

        if node != "Node1":
            parent = find_ancestral_node(tree, node)
            parent_name = bare_name(parent)
            denovo_mutations[node] = find_substitution_mutations(
                ancestral_sequences[parent_name], node_seq
            )
            ancestral_mutations[node] = (
                denovo_mutations[parent_name] + ancestral_mutations[parent_name]
            )
            ancestral_node_data[node] = parent
            continue

        # Node1 has no reconstructed parent: compare it to the reference record
        reference_record = next(SeqIO.parse(alignment_file, "fasta"), None)
        reference_seq = str(reference_record.seq).upper()
        denovo_mutations[node] = find_substitution_mutations(reference_seq, node_seq)
        ancestral_mutations[node] = []
        ancestral_node_data[node] = reference_record.id

    return denovo_mutations, ancestral_mutations, ancestral_node_data


In [None]:

def get_mutation_data_for_leaf_nodes(alignment_file, ancestral_sequences, denovo_mutations, ancestral_mutations, tree, ancestral_node_data):
    """
    Compute de novo and inherited mutations for every leaf (sampled sequence).

    The first alignment record is the reference and gets empty mutation lists.
    Every other record is compared against the reconstructed sequence of its
    nearest ancestral node in ``tree``. Its inherited mutations are that
    ancestor's de novo plus inherited mutations, so the internal-node results
    must already be present in the dictionaries passed in.

    Parameters:
    - alignment_file (str): Path to the multiple sequence alignment (FASTA).
    - ancestral_sequences (dict): Node name -> reconstructed ancestral sequence.
    - denovo_mutations (dict): De novo mutations per node; updated in place.
    - ancestral_mutations (dict): Inherited mutations per node; updated in place.
    - tree: The phylogenetic tree object.
    - ancestral_node_data (dict): Node -> nearest ancestral clade; updated in place.

    Returns:
    - denovo_mutations (dict): The updated de novo mutation dictionary.
    - ancestral_mutations (dict): The updated inherited mutation dictionary.
    - ancestral_node_data (dict): The updated ancestral node data. Every leaf
      except the reference gets an entry.
    """
    # Parse the alignment ONCE. The previous code re-scanned the whole file for
    # every leaf to fetch its sequence, which is quadratic in the number of sequences.
    # Only the first record for a duplicated id is kept, matching that linear search.
    leaf_nodes = []
    sequences_by_id = {}
    for record in SeqIO.parse(alignment_file, "fasta"):
        leaf_nodes.append(record.id)
        if record.id not in sequences_by_id:
            sequences_by_id[record.id] = str(record.seq).upper()

    # The first record is the reference sequence: it carries no mutations
    denovo_mutations[leaf_nodes[0]] = []
    ancestral_mutations[leaf_nodes[0]] = []

    for leaf_node in leaf_nodes[1:]:
        # Nearest ancestral clade of this leaf in the tree
        ancestral_node = find_ancestral_node(tree, leaf_node)
        ancestral_node_data[leaf_node] = ancestral_node
        # Bootstrapping can append support values to tree node names ("Node12/95"),
        # so keep only the part before '/' to index the mutation dictionaries
        ancestor_name = str(ancestral_node).split("/")[0]
        denovo_mutations[leaf_node] = find_substitution_mutations(
            ancestral_sequences[ancestor_name],
            sequences_by_id[leaf_node],
        )
        ancestral_mutations[leaf_node] = (
            denovo_mutations[ancestor_name] + ancestral_mutations[ancestor_name]
        )

    return denovo_mutations, ancestral_mutations, ancestral_node_data



In [None]:
# Multiple sequence alignment (MSA) used for every mutation comparison
alignment_file = 'SARS-CoV-2_18000+1_msa.fasta'

# Load the pre-built phylogenetic tree (Newick format)
tree = Phylo.read("Fast IQTREE/fast.treefile", "newick")

# Root the tree on the outgroup / reference sequence; the tree object is re-rooted in place
outgroup_id = "NC_045512.2"
root_tree_with_first_sequence(tree, outgroup_id)

# Rebuild the ancestral sequence of every internal node from the ancestral state file
ancestral_sequences = get_reconstructed_ancestral_sequences("Fast IQTREE/fast.state")

# Internal nodes first: the leaf step below looks up their mutation data
(
    denovo_mutations,
    ancestral_mutations,
    ancestral_node_data,
) = get_mutation_data_for_internal_nodes(ancestral_sequences, alignment_file, tree)

# Leaf nodes next, extending the same three dictionaries
(
    denovo_mutations,
    ancestral_mutations,
    ancestral_node_data,
) = get_mutation_data_for_leaf_nodes(
    alignment_file,
    ancestral_sequences,
    denovo_mutations,
    ancestral_mutations,
    tree,
    ancestral_node_data,
)


In [None]:
from modeller import *
from modeller.automodel import *

#from modeller import soap_protein_od

# Homology-modelling run with MODELLER's AutoModel.
# NOTE(review): the alignment/template/sequence names below ('basic-example/',
# 'TvLDH', '1bdmA') look like generic example inputs rather than SARS-CoV-2 RBD
# files. Confirm this cell is meant to be adapted before use.
env = Environ()

# Directory list where the template coordinate files (the `knowns` below) are looked up
env.io.atom_files_directory = ['basic-example/']  # Ensure the directory contains your PDB files


# Model the target sequence 'TvLDH' on the template '1bdmA' using the given
# alignment file. Models are scored with the DOPE and GA341 assessment methods.
a = AutoModel(env, alnfile='basic-example/TvLDH-1bdmA.ali',
              knowns='1bdmA', sequence='TvLDH',
              assess_methods=(assess.DOPE,
                              # SOAP scoring is disabled; it would also need the commented-out soap_protein_od import
                              #soap_protein_od.Scorer(),
                              assess.GA341))
# Request models with indices 1 through 5
a.starting_model = 1
a.ending_model = 5
# Run the modelling
a.make()





In [5]:
!python -c 'import pyrosetta_installer; pyrosetta_installer.install_pyrosetta()'


Installing PyRosetta:
 os: ubuntu
 type: Release
 Rosetta C++ extras: 
 mirror: https://west.rosettacommons.org/pyrosetta/release/release
 extra packages: numpy

PyRosetta wheel url: https://:@west.rosettacommons.org/pyrosetta/release/release/PyRosetta4.Release.python312.ubuntu.wheel/pyrosetta-2024.19+release.a34b73c40f-cp312-cp312-linux_x86_64.whl
Collecting pyrosetta==2024.19+release.a34b73c40f
  Downloading https://:****@west.rosettacommons.org/pyrosetta/release/release/PyRosetta4.Release.python312.ubuntu.wheel/pyrosetta-2024.19+release.a34b73c40f-cp312-cp312-linux_x86_64.whl (1668.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 GB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:04[0mm
[?25hInstalling collected packages: pyrosetta
Successfully installed pyrosetta-2024.19+release.a34b73c40f


In [6]:
!pip install pyrosetta-distributed

Collecting pyrosetta-distributed
  Downloading pyrosetta_distributed-0.0.3-py3-none-any.whl.metadata (906 bytes)
Collecting billiard (from pyrosetta-distributed)
  Downloading billiard-4.2.0-py3-none-any.whl.metadata (4.4 kB)
Collecting dask-jobqueue (from pyrosetta-distributed)
  Downloading dask_jobqueue-0.9.0-py2.py3-none-any.whl.metadata (1.3 kB)
Collecting blosc (from pyrosetta-distributed)
  Downloading blosc-1.11.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting python-xz (from pyrosetta-distributed)
  Downloading python_xz-0.5.0-py3-none-any.whl.metadata (8.5 kB)
Collecting fqdn (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->notebook->jupyter->pyrosetta-distributed)
  Downloading fqdn-1.5.1-py3-none-any.whl.metadata (1.4 kB)
Collecting isoduration (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=2.4.0->notebook->jupyter->pyrosetta-distributed)
  Downloading isodur

In [16]:

import pyrosetta

# Initialize PyRosetta
# (default options; no multithreading or other flags are passed here)
pyrosetta.init()

# Load the structure in 6m0j.pdb into a Pose
pose = pyrosetta.pose_from_pdb('6m0j.pdb')

# Create a MoveMap object to specify which degrees of freedom the minimizer may change
movemap = pyrosetta.rosetta.core.kinematics.MoveMap()
movemap.set_bb(True)  # Allow backbone minimization
movemap.set_chi(True)  # Allow side-chain minimization

# Create and configure the MinMover
scorefxn = pyrosetta.get_fa_scorefxn()  # Get the full-atom score function
min_mover = pyrosetta.rosetta.protocols.minimization_packing.MinMover()
min_mover.movemap(movemap)  # Apply the MoveMap to the MinMover
min_mover.score_function(scorefxn)  # Set the score function
min_mover.max_iter(5000)  # Cap the number of minimization iterations
# NOTE(review): no minimization type or tolerance is set, so MinMover's defaults
# apply. Confirm they are the intended ones.

# Apply the minimization to `pose` (it is scored below as the minimized structure)
min_mover.apply(pose)


# Score the minimized structure with the same full-atom score function
total_energy = scorefxn(pose)
print(f"Total Energy after minimization: {total_energy} REU")

┌──────────────────────────────────────────────────────────────────────────────┐
│                                 PyRosetta-4                                  │
│              Created in JHU by Sergey Lyskov and PyRosetta Team              │
│              (C) Copyright Rosetta Commons Member Institutions               │
│                                                                              │
│ NOTE: USE OF PyRosetta FOR COMMERCIAL PURPOSES REQUIRE PURCHASE OF A LICENSE │
│         See LICENSE.PyRosetta.md or email license@uw.edu for details         │
└──────────────────────────────────────────────────────────────────────────────┘
PyRosetta-4 2024 [Rosetta PyRosetta4.Release.python312.ubuntu 2024.19+release.a34b73c40fe9c61558d566d6a63f803cfb15a4fc 2024-05-02T16:22:03] retrieved from: http://www.pyrosetta.org
core.init: Checking for fconfig files in pwd and ./rosetta/flags
core.init: Rosetta version: PyRosetta4.Release.python312.ubuntu r381 2024.19+release.a34b73c40f a34b73c40f

In [None]:
1736.1101167092763

In [13]:
# Load the complex into PyRosetta
# NOTE(review): this re-reads 6m0j.pdb, so the energies below are for the structure
# as read from the file, not the minimized `pose` from the previous cell. Confirm
# that this is intended.
pose_complex = pyrosetta.pose_from_pdb('6m0j.pdb')

# Separate the complex into receptor and ligand.
# split_by_chain is given integer chain indices here (presumably 1 = first chain
# in the file, 2 = second), not chain letters.
# NOTE(review): the original assumption was "chain A = receptor, chain B = ligand".
# Verify which chains of 6m0j.pdb are actually the receptor and the ligand.
pose_receptor = pose_complex.split_by_chain(1)  # Receptor
pose_ligand = pose_complex.split_by_chain(2)    # Ligand

# Get the score function
scorefxn = pyrosetta.get_fa_scorefxn()

# Score the complex and each isolated chain with the same score function
energy_complex = scorefxn(pose_complex)
energy_receptor = scorefxn(pose_receptor)
energy_ligand = scorefxn(pose_ligand)

# Interaction-energy estimate: E(complex) - (E(receptor) + E(ligand)).
# This is a difference of Rosetta scores (REU) for single static structures,
# not an experimentally comparable binding free energy.
binding_affinity = energy_complex - (energy_receptor + energy_ligand)
print(f"Binding Affinity (ΔG): {binding_affinity} REU")

core.import_pose.import_pose: File '6m0j.pdb' automatically determined to be of type PDB
core.conformation.Conformation: Found disulfide between residues 115 123
core.conformation.Conformation: Found disulfide between residues 326 343
core.conformation.Conformation: Found disulfide between residues 512 524
core.conformation.Conformation: Found disulfide between residues 607 632
core.conformation.Conformation: Found disulfide between residues 650 703
core.conformation.Conformation: Found disulfide between residues 662 796
core.conformation.Conformation: Found disulfide between residues 751 759
core.pack.pack_missing_sidechains: packing residue number 600 because of missing atom number 10 atom name O1
core.pack.pack_missing_sidechains: packing residue number 601 because of missing atom number 10 atom name O1
core.pack.pack_missing_sidechains: packing residue number 602 because of missing atom number 10 atom name O1
core.pack.pack_missing_sidechains: packing residue number 603 because of 