This notebook implements the RNA folding QUBO of Lewis et al., 2021 (QFold: A new modelling paradigm for the RNA folding problem).

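The QUBO Hamiltonian has the following form, where each binary variable $q_i$ indicates whether stacked quartet (SQ) $i$ — two consecutive, stacked base pairs — is part of the predicted structure: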

$$H = -\sum_i N_iq_i - \sum_{i>j} M_{ij}q_iq_j$$

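Here $N_i$ is the stacking energy of SQ $i$. The coupling is $M_{ij} = M^+ > 0$ if SQs $i$ and $j$ are nested (i.e., they stack to form a stacked octet), $M_{ij} = M^-$ (with $M^+ > M^- > 0$) if they are pseudoknotted, $M_{ij} = M^p \ll 0$ if they overlap, and $M_{ij} = 0$ otherwise.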

In [None]:
# import packages:

import numpy as np
import pandas as pd
import math
import os
import glob

In [None]:
# function to return the SQ energy based on nearest-neighbor interactions:

# Nearest-neighbour stacking energies, looked up as
# _STACKING_ENERGY[current_pair][previous_pair]. Larger values lower the QUBO
# energy, because model() uses the negated SQ energy as the linear term.
_STACKING_ENERGY = {
    "AU": {"AU": 0.9, "CG": 2.2, "GC": 2.1, "UA": 1.1, "GU": 0.6, "UG": 1.4},
    "CG": {"AU": 2.1, "CG": 3.3, "GC": 2.4, "UA": 2.1, "GU": 1.4, "UG": 2.1},
    "GC": {"AU": 2.4, "CG": 3.4, "GC": 3.3, "UA": 2.2, "GU": 1.5, "UG": 2.5},
    "UA": {"AU": 1.3, "CG": 2.4, "GC": 2.1, "UA": 0.9, "GU": 1.0, "UG": 1.3},
    "GU": {"AU": 1.3, "CG": 2.5, "GC": 2.1, "UA": 1.4, "GU": 0.5, "UG": -1.3},
    "UG": {"AU": 1.0, "CG": 1.5, "GC": 1.4, "UA": 0.6, "GU": -0.3, "UG": 0.5},
}


def SQ_energy(sq):
    """Return the summed stacking energy of consecutive base pairs in *sq*.

    sq: sequence of two-letter pair labels ("AU", "CG", "GC", "UA", "GU",
        "UG"), ordered from the outer pair inwards. Any other label (for
        example "noncanonical") contributes nothing.

    Each adjacent (previous, current) couple contributes
    _STACKING_ENERGY[current][previous]. A sequence with fewer than two
    labels yields 0.
    """
    total = 0
    for previous, current in zip(sq, sq[1:]):
        stacked_on = _STACKING_ENERGY.get(current, {})
        if previous in stacked_on:
            total += stacked_on[previous]
    return total

In [None]:
# function to read in .ct file and give a list of known structure SQs:

_CT_CANONICAL_PAIRS = frozenset({"GC", "CG", "GU", "UG", "AU", "UA"})


def _ct_pair_label(base1, base2):
    """Return the upper-case pair label for two bases, or "noncanonical".

    The comparison is case-insensitive, so lower-case bases are classified
    the same way as upper-case ones.
    """
    label = (base1 + base2).upper()
    return label if label in _CT_CANONICAL_PAIRS else "noncanonical"


def actual_SQs(seq_ss, seq_ps, directory=None):
    """Return the stacked quartets (SQs) of the known structure.

    Parameters
    ----------
    seq_ss : str
        File name of the connectivity (.ct) file (secondary structure). Every
        non-blank line must hold the 1-based base index in column 0 and the
        1-based partner index (0 if unpaired) in column 4.
    seq_ps : str
        File name of the FASTA file (primary structure); the sequence is read
        from its second line.
    directory : str, optional
        Folder containing both files. Defaults to the global ``subdirectory``.

    Returns
    -------
    list of [int, int, number]
        One ``[i, j, energy]`` entry for every two consecutive stacked base
        pairs (i, j) and (i+1, j-1) with i < j, where ``energy`` is
        ``SQ_energy([label(i, j), label(i+1, j-1)])``. Non-canonical pairs get
        the label "noncanonical".
    """
    if directory is None:
        directory = subdirectory

    with open(os.path.join(directory, seq_ss)) as file:
        ss_lines = file.readlines()

    with open(os.path.join(directory, seq_ps)) as file:
        ps_lines = file.readlines()

    rna = ps_lines[1].strip()

    SQs_actual = []
    in_stem = False     # was the previous line a paired base?
    previous = None     # (index, partner, label) of the previous paired base
    last_partner = 0    # partner column of the previous line

    for raw_line in ss_lines:
        fields = raw_line.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing empty line
        index = int(fields[0])
        partner = int(fields[4])

        if partner == 0:
            in_stem = False  # an unpaired base ends the current helix
        else:
            label = _ct_pair_label(rna[index - 1], rna[partner - 1])
            # (index, partner) stacks on the previous pair when the previous
            # base was paired with the position one step further 3'.
            if in_stem and last_partner - partner == 1:
                prev_index, prev_partner, prev_label = previous
                # Record each quartet once, from its 5' strand only.
                if prev_partner > prev_index:
                    SQs_actual.append(
                        [prev_index, prev_partner, SQ_energy([prev_label, label])]
                    )
            previous = (index, partner, label)
            in_stem = True

        last_partner = partner

    return SQs_actual

In [None]:
# function to read in .fasta file and generate the list of potential stacked quartets (two stacked canonical base pairs):

def potential_SQs(seq_ps, directory=None):
    """Enumerate every potential stacked quartet (SQ) of a sequence.

    An SQ at 0-based positions (row, col) consists of the two stacked
    canonical base pairs (row, col) and (row+1, col-1), where canonical means
    AU, UA, GU, UG, GC or CG (case-insensitive). An SQ is kept only when
    col - row >= 7, i.e. at least four unpaired positions separate the inner
    pair. NOTE(review): a minimum hairpin loop of three would be
    col - row >= 6; confirm which threshold is intended.

    Parameters
    ----------
    seq_ps : str
        File name of the FASTA file; the sequence is read from its second line.
    directory : str, optional
        Folder containing the file. Defaults to the global ``subdirectory``.

    Returns
    -------
    list
        ``[SQs_potential, rna, len(rna)]``. ``SQs_potential`` holds
        ``[row + 1, col + 1, SQ_energy([outer_label, inner_label])]`` entries
        ordered by row, then col. ``rna`` is the sequence with surrounding
        whitespace (including the line terminator) removed.
    """
    if directory is None:
        directory = subdirectory

    with open(os.path.join(directory, seq_ps)) as file:
        lines = file.readlines()

    # Strip the line terminator so it is neither treated as a base nor counted
    # in the returned length (it would otherwise end up in the CT files).
    rna = lines[1].strip()
    bases = rna.upper()
    canonical = {"AU", "UA", "GU", "UG", "GC", "CG"}

    SQs_potential = []
    for row in range(len(bases)):
        for col in range(row + 7, len(bases)):
            outer = bases[row] + bases[col]
            inner = bases[row + 1] + bases[col - 1]
            if outer in canonical and inner in canonical:
                SQs_potential.append([row + 1, col + 1, SQ_energy([outer, inner])])

    return [SQs_potential, rna, len(rna)]

In [None]:
# function to generate lists of SQ pairs that overlap or nest or pseudoknot:

def potential_couplings(SQs_potential, overlap_penalty=1e6):
    """Classify every pair of SQs as overlapping, nested (stacked) or pseudoknotted.

    Parameters
    ----------
    SQs_potential : list of [int, int, ...]
        SQs as ``[i, j, ...]``; each covers base pairs (i, j) and (i+1, j-1).
    overlap_penalty : float, optional
        Value stored for two SQs that share a base without stacking
        (default 1e6).

    Returns
    -------
    tuple of (overlaps, nestings, pseudoknots)
        Each is a list with one ``[a, b, value]`` entry per index pair a < b,
        in the order (0, 1), (0, 2), ..., (1, 2), ...:
        - overlaps: ``overlap_penalty`` if the SQs share a base and are not
          directly stacked, else 0;
        - nestings: 1 if one SQ sits exactly one step inside the other
          (a stacked octet), else 0;
        - pseudoknots: 1 if the SQs are base-disjoint and cross
          (i_a < i_b < j_a < j_b or vice versa), else 0.
    """
    overlaps_potential = []
    nestings_potential = []
    pseudoknots_potential = []

    # Bases covered by each SQ: its two 5' positions and its two 3' positions.
    covered = [{sq[0], sq[0] + 1, sq[1] - 1, sq[1]} for sq in SQs_potential]

    n = len(SQs_potential)
    for a in range(n):
        i_a, j_a = SQs_potential[a][0], SQs_potential[a][1]
        for b in range(a + 1, n):
            i_b, j_b = SQs_potential[b][0], SQs_potential[b][1]

            overlap = nesting = pseudoknot = 0
            if covered[a] & covered[b]:
                # Sharing a base is only allowed when the SQs stack directly.
                if (i_a == i_b + 1 and j_a == j_b - 1) or (i_b == i_a + 1 and j_b == j_a - 1):
                    nesting = 1
                else:
                    overlap = overlap_penalty
            elif i_a < i_b < j_a < j_b or i_b < i_a < j_b < j_a:
                pseudoknot = 1

            overlaps_potential.append([a, b, overlap])
            nestings_potential.append([a, b, nesting])
            pseudoknots_potential.append([a, b, pseudoknot])

    return (overlaps_potential, nestings_potential, pseudoknots_potential)

In [None]:
# function to generate the Hamiltonian of a given RNA structure from potential SQs, overlaps, and pseudoknots:

def model(SQs_potential, overlaps_potential, nestings_potential, pseudoknots_potential, mplus, mminus, *, drop_zero_couplings=False):
    """Build the linear and quadratic QUBO coefficients.

    Variables are labelled ``str(i)`` for the i-th potential SQ.

    - ``L[str(i)] = -SQs_potential[i][2]`` (negated SQ energy).
    - ``Q[(str(i), str(j))] = overlap - (mplus * nesting + mminus * pseudoknot)``
      for every i < j.

    The coupling lists must be in the pair order (0, 1), (0, 2), ..., (1, 2),
    ... produced by ``potential_couplings``; entry k is matched to the k-th
    (i, j) pair.

    drop_zero_couplings: when True, omit quadratic terms whose value is 0.
    This keeps the interaction graph sparse (easier to embed on a QPU) without
    changing any sample's energy. The default keeps every pair.

    Returns ``(L, Q)``.
    """
    L = {}
    Q = {}
    k = 0  # index into the flattened (i, j) coupling lists

    for i in range(0, len(SQs_potential)):
        L[str(i)] = -SQs_potential[i][2]
        for j in range(i + 1, len(SQs_potential)):
            coupling = overlaps_potential[k][2] - (mplus * nestings_potential[k][2] + mminus * pseudoknots_potential[k][2])
            k += 1
            if drop_zero_couplings and coupling == 0:
                continue
            Q[(str(i), str(j))] = coupling

    return L, Q

In [None]:
# function to evaluate the energy of the known structure under the model Hamiltonian:

def energy(SQs_actual, mplus, mminus):
    """Evaluate the model Hamiltonian with every SQ of the known structure selected.

    The result is -sum(SQ energies) minus, for every SQ pair (i < j),
    ``mplus * nesting + mminus * pseudoknot`` as classified by
    ``potential_couplings``. Overlap terms are not included.
    """
    _, nestings, pseudoknots = potential_couplings(SQs_actual)

    total = 0
    pair_index = 0   # position in the flattened (i, j > i) coupling lists
    count = len(SQs_actual)
    for first, sq in enumerate(SQs_actual):
        total -= sq[2]
        for _ in range(first + 1, count):
            total -= mplus * nestings[pair_index][2] + mminus * pseudoknots[pair_index][2]
            pair_index += 1

    return total

In [None]:
# function to compare actual and predicted structure based on comparison of base-pairs:

def evaluation_1(SQs_actual, SQs_potential):
    """Compare predicted and known structures base pair by base pair.

    Each SQ ``[i, j, ...]`` contributes the base pairs (i, j) and (i+1, j-1).
    Pairs are compared as sets, so a pair shared by two stacked SQs is counted
    once rather than once per SQ.

    Returns ``[ppv, sensitivity]`` with ppv = C / (C + M) and
    sensitivity = C / (C + I), where C counts correctly predicted pairs, M
    predicted pairs absent from the known structure, and I known pairs that
    were not predicted. Either value is 0 when its denominator is 0.
    """
    def base_pairs(sqs):
        return {(sq[0] + offset, sq[1] - offset) for sq in sqs for offset in (0, 1)}

    bp_actual = base_pairs(SQs_actual)
    bp_predicted = base_pairs(SQs_potential)

    correct = len(bp_predicted & bp_actual)
    false_positive = len(bp_predicted - bp_actual)
    missed = len(bp_actual - bp_predicted)

    ppv = correct / (correct + false_positive) if correct + false_positive else 0
    sensitivity = correct / (correct + missed) if correct + missed else 0

    return [ppv, sensitivity]

In [None]:
# function to compare actual and predicted structure based on comparison of bases involved in pairing:

def evaluation_2(SQs_actual, SQs_predicted):
    """Compare predicted and known structures by the set of paired bases.

    Each SQ ``[i, j, ...]`` marks positions i, i+1, j-1 and j as paired.
    Positions are compared as sets, so a base shared by two stacked SQs is
    counted once rather than once per SQ.

    Returns ``[ppv, sensitivity]`` with ppv = C / (C + M) and
    sensitivity = C / (C + I), where C counts correctly predicted paired
    bases, M predicted paired bases that are unpaired in the known structure,
    and I known paired bases that were not predicted. Either value is 0 when
    its denominator is 0.
    """
    def paired_bases(sqs):
        return {pos for sq in sqs for pos in (sq[0], sq[0] + 1, sq[1] - 1, sq[1])}

    pb_actual = paired_bases(SQs_actual)
    pb_predicted = paired_bases(SQs_predicted)

    correct = len(pb_predicted & pb_actual)
    false_positive = len(pb_predicted - pb_actual)
    missed = len(pb_actual - pb_predicted)

    ppv = correct / (correct + false_positive) if correct + false_positive else 0
    sensitivity = correct / (correct + missed) if correct + missed else 0

    return [ppv, sensitivity]

In [None]:
def connectivity_table(bpRNA_id, sequence, stems, t, out_dir="./results/cts"):
    """Write a connectivity (.ct) file for the base pairs implied by *stems*.

    The file is written to ``<out_dir>/<bpRNA_id>_<t>_model2.ct``; *out_dir*
    is created if it does not exist. The first line is
    ``<length> <bpRNA_id>``, followed by one line per base:
    index, base, previous index (0 for the first base), next index (0 for the
    last base), partner index (0 if unpaired), index.

    Each stem ``[i, j, ...]`` pairs i with j and i+1 with j-1. If several
    stems claim the same base, the later stem in *stems* wins.
    """
    # A trailing line terminator from the FASTA line must not become a base.
    sequence = sequence.strip()

    partner = {}
    for stem in stems:
        i, j = stem[0], stem[1]
        for position in (i, i + 1, j - 1, j):
            partner[position] = i + j - position

    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, f"{bpRNA_id}_{t}_model2.ct")
    n = len(sequence)
    # Open the file once so the handle is flushed and closed deterministically.
    with open(path, "w") as handle:
        print(n, bpRNA_id, file=handle)
        for index, base in enumerate(sequence, start=1):
            next_index = index + 1 if index < n else 0
            print(index, base, index - 1, next_index, partner.get(index, 0), index, file=handle)

In [None]:
# connecting with D-Wave:
# The API token is read from the DWAVE_API_TOKEN environment variable rather
# than being hard-coded in the notebook. When it is unset (None), the D-Wave
# clients fall back to their configuration file. Any token previously
# committed in this notebook should be treated as leaked and revoked.

from dwave.cloud import Client

DWAVE_TOKEN = os.environ.get("DWAVE_API_TOKEN")

with Client.from_config(token=DWAVE_TOKEN) as client:
    client.get_solvers()  # sanity check that the solver API is reachable with this token

from dwave.system.samplers import DWaveSampler
from dwave.system.samplers import LeapHybridSampler
from dwave.system.composites import EmbeddingComposite

import dimod

sampler_q = EmbeddingComposite(DWaveSampler(token=DWAVE_TOKEN, solver={'topology__type': 'pegasus'}))
sampler_h = LeapHybridSampler(token=DWAVE_TOKEN)

The following cell processes every structure in a given folder (`subdirectory`), where:

- `mminus` is the pseudoknot coupling weight $M^-$, to be tuned
- `mplus` is the nesting (stacking) coupling weight $M^+$, to be tuned

1. First, the SQs of the known (actual) structure are extracted, and the energy of that structure under the model is computed.
2. Second, the potential SQs are enumerated, together with their potential overlaps, nestings, and pseudoknots.
3. Third, the QUBO model is built.
4. Fourth, the model is run on D-Wave (the Leap hybrid sampler).
5. Fifth, connectivity table (CT) files are written for the actual and predicted structures.
6. Sixth, the predicted structure is evaluated against the actual structure using base-pair (BP) and paired-base (PB) sensitivity and PPV.

We will need to adapt this cell for both the training and the testing steps of our protocol.

In [None]:
mplus = 1
mminus = 1

subdirectory = "./data/woutPKs/s"

# os.listdir order is arbitrary: sort for a reproducible run, and pair each CT
# file with the FASTA file of the same ID instead of relying on the two
# listings happening to be in the same order.
ct = sorted(f for f in os.listdir(subdirectory) if f.endswith('.ct.txt'))
fasta = sorted(f for f in os.listdir(subdirectory) if f.endswith('.fasta.txt'))
fasta_by_id = {f.split('.')[0]: f for f in fasta}

# All lists below are index-aligned with bprna_id. A later stage that fails
# for a structure stores None for it instead of shifting subsequent entries.
bprna_id    = [] # IDs

a_SQs       = [] # actual structure stacked quartets
a_energies  = [] # actual structure energies

p_SQs       = [] # potential structure stacked quartets
p_couplings = [] # potential overlaps, nestings, pseudoknots

models      = [] # models

problem = []     # initiate list of problems
prediction = []  # initiate list of predictions
evaluation = []  # initiate list of evaluations
min_time = []    # initiate list of time to solution

for ct_file in ct:
    structure_id = ct_file.split('.')[0]
    fasta_file = fasta_by_id.get(structure_id)
    if fasta_file is None:
        print("no FASTA file for:", structure_id, "skipping...")
        continue

    # Steps 1-3: extract SQs and build the QUBO (some CT files are malformed).
    try:
        print("building model for:", structure_id)
        sqs_actual = actual_SQs(ct_file, fasta_file)          # find actual SQs of structure
        energy_actual = energy(sqs_actual, mplus, mminus)     # compute energy of actual structure
        sqs_potential = potential_SQs(fasta_file)             # find potential SQs of structure
        couplings = potential_couplings(sqs_potential[0])     # potential overlaps, nestings, pseudoknots
        qubo = model(sqs_potential[0], couplings[0], couplings[1], couplings[2], mplus, mminus)
    except Exception as err:
        print("error in preprocessing, skipping...", repr(err))
        continue

    bprna_id.append(structure_id)
    a_SQs.append(sqs_actual)
    a_energies.append(energy_actual)
    p_SQs.append(sqs_potential)
    p_couplings.append(couplings)
    models.append(qubo)

    # Step 4: sample the QUBO.
    bqm = None
    time_limit = None
    SQs_found = None
    predicted_energy = None
    try:
        print("running model for:", structure_id)
        bqm = dimod.BinaryQuadraticModel(qubo[0], qubo[1], vartype='BINARY', offset=0.0)

        # QPU alternative: sampleset = sampler_q.sample(bqm, num_reads=1000)
        sampleset = sampler_h.sample(bqm)              # hybrid
        time_limit = sampler_h.min_time_limit(bqm)     # hybrid

        # Use the lowest-energy sample. SampleSet.data() iterates in increasing
        # energy, so keeping its last item would select the worst sample.
        best = sampleset.first
        results = best.sample
        predicted_energy = best.energy

        SQs_found = [sqs_potential[0][j] for j in range(len(results)) if results[str(j)] == 1]
    except Exception as err:
        print("no embedding found, skipping...", repr(err))

    problem.append(bqm)
    min_time.append(time_limit)
    if SQs_found is None:
        prediction.append(None)
        evaluation.append(None)
        continue
    prediction.append([SQs_found, predicted_energy])  # record predicted stems and structure energy

    # Step 5: write CT files for the predicted and the actual structure.
    try:
        connectivity_table(structure_id, sqs_potential[1], SQs_found, "predicted")  # write predicted CT file
        connectivity_table(structure_id, sqs_potential[1], sqs_actual, "actual")    # write actual CT file
    except Exception as err:
        print("could not write CT files for:", structure_id, repr(err))

    # Step 6: evaluate the prediction against the known structure.
    try:
        print("evaluating model for:", structure_id)
        metrics_1 = [evaluation_1(sqs_actual, SQs_found)]   # compute BP metrics
        metrics_2 = [evaluation_2(sqs_actual, SQs_found)]   # compute PB metrics
        evaluation.append((metrics_1, metrics_2))
    except Exception as err:
        print("no structure found, skipping...", repr(err))
        evaluation.append(None)