In [6]:
import os
import dimod
import numpy as np
from braket.aws import AwsDevice
from braket.ocean_plugin import BraketSampler, BraketDWaveSampler
from dwave.system.composites import EmbeddingComposite
from qiskit.algorithms.optimizers import SPSA
from IPython.display import display, clear_output

In [7]:
def get_structures(pks, size):
    subdirectory = './data/'+pks+'/'+size
    bprna = []
    fasta = [f for f in os.listdir(subdirectory) if f.endswith('.fasta.txt')]
    for f in fasta:
        bprna.append(subdirectory+"/"+f.split(".")[0])
    return bprna

In [8]:
# function to return the stem energy based on nearest-neighbor interactions:

def stem_energy(sp):
    se = 0
    if len(sp) > 1:
        for i in range(1, len(sp)):
            if sp[i] == "AU":
                if sp[i-1] == "AU": 
                    se += 0.9
                if sp[i-1] == "CG":
                    se += 2.2
                if sp[i-1] == "GC":
                    se += 2.1
                if sp[i-1] == "UA":
                    se += 1.1
                if sp[i-1] == "GU":
                    se += 0.6
                if sp[i-1] == "UG":
                    se += 1.4
            if sp[i] == "CG":
                if sp[i-1] == "AU": 
                    se += 2.1
                if sp[i-1] == "CG":
                    se += 3.3
                if sp[i-1] == "GC":
                    se += 2.4
                if sp[i-1] == "UA":
                    se += 2.1
                if sp[i-1] == "GU":
                    se += 1.4
                if sp[i-1] == "UG":
                    se += 2.1
            if sp[i] == "GC":
                if sp[i-1] == "AU": 
                    se += 2.4
                if sp[i-1] == "CG":
                    se += 3.4
                if sp[i-1] == "GC":
                    se += 3.3
                if sp[i-1] == "UA":
                    se += 2.2
                if sp[i-1] == "GU":
                    se += 1.5
                if sp[i-1] == "UG":
                    se += 2.5
            if sp[i] == "UA":
                if sp[i-1] == "AU": 
                    se += 1.3
                if sp[i-1] == "CG":
                    se += 2.4
                if sp[i-1] == "GC":
                    se += 2.1
                if sp[i-1] == "UA":
                    se += 0.9
                if sp[i-1] == "GU":
                    se += 1.0
                if sp[i-1] == "UG":
                    se += 1.3
            if sp[i] == "GU":
                if sp[i-1] == "AU": 
                    se += 1.3
                if sp[i-1] == "CG":
                    se += 2.5
                if sp[i-1] == "GC":
                    se += 2.1
                if sp[i-1] == "UA":
                    se += 1.4
                if sp[i-1] == "GU":
                    se += 0.5
                if sp[i-1] == "UG":
                    se += -1.3
            if sp[i] == "UG":
                if sp[i-1] == "AU": 
                    se += 1.0
                if sp[i-1] == "CG":
                    se += 1.5
                if sp[i-1] == "GC":
                    se += 1.4
                if sp[i-1] == "UA":
                    se += 0.6
                if sp[i-1] == "GU":
                    se += -0.3
                if sp[i-1] == "UG":
                    se += 0.5
    return se

In [9]:
def actual_stems(seq_ss, seq_ps):
    
    with open(seq_ss) as file:
        ss_lines = file.readlines()
    
    with open(seq_ps) as file:
        ps_lines = file.readlines()
    
    rna = ps_lines[1]
    
    stems_actual = []

    sip = False                       # stem in progress?
    sl = 0                            # stem length
    sp = []
    last_line = [0, 0, 0, 0, 0, 0]    # initiate last line

    for i in range(0, len(ss_lines)):
        line = ss_lines[i].strip().split()
        
        if (int(line[4]) != 0 and sip == False):
            sip = True
            temp = [int(line[0]), int(line[4])]
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')):
                sp.append("GC")
            if (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("CG")
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("GU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("UG")
            if (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("AU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sp.append("UA")
            sl += 1
            
        if (int(line[4]) != 0 and sip == True and (int(last_line[4])-int(line[4]) == 1)):
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')):
                sp.append("GC")
            if (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("CG")
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("GU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("UG")
            if (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("AU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sp.append("UA")
            sl += 1
            
        if (int(line[4]) == 0 and sip == True):
            sip = False
            temp.append(sl)
            temp.append(int(temp[1]-temp[0]-2*sl))
            temp.append(stem_energy(sp))
            if temp[1] > temp[0]:
                stems_actual.append(temp)
            sl = 0
            sp = []
            
        if ((int(last_line[4])-int(line[4]) != 1) and int(last_line[4]) != 0  and sip == True):
            temp.append(sl)
            temp.append(int(temp[1]-temp[0]-2*sl))
            temp.append(stem_energy(sp))
            if temp[1] > temp[0]:
                stems_actual.append(temp)
            temp = [int(line[0]), int(line[4])]
            sl = 0
            sp = []
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')):
                sp.append("GC")
            if (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("CG")
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("GU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')):
                sp.append("UG")
            if (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')):
                sp.append("AU")
            if (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sp.append("UA")
            sl += 1
        
        last_line = line
        
    return stems_actual

In [10]:
def potential_stems(seq_ps):
    
    with open(seq_ps) as file:
        lines = file.readlines()
    
    rna = lines[1]
    
    matrix = np.zeros((len(rna),len(rna)))
    for diag in range(0, len(matrix)):
        for row in range(0, len(matrix)-diag):
            col = row + diag
            base1 = rna[row]
            base2 = rna[col]
            if row != col:
                if ((base1 == ("A" or "a")) and (base2 == ("U" or "u"))) or ((base1 == ("U" or "u")) and (base2 == ("A" or "a"))) or ((base1 == ("G" or "g")) and (base2 == ("U" or "u"))) or ((base1 == ("U" or "u")) and (base2 == ("G" or "g"))) or ((base1 == ("G" or "g")) and (base2 == ("C" or "c"))) or ((base1 == ("C" or "c")) and (base2 == ("G" or "g"))):
                    matrix[row][col] = 1
                    
    stems_potential = []
    mu = 0

    for row in range(0, len(matrix)):
        for col in range (row, len(matrix)):
            if row != col:
                if matrix[row][col] != 0:
                    sp = []                 # stem pairs
                    temp_row = row
                    temp_col = col
                    stem = [row+1,col+1,0,0,0]
                    length = 0
                    while (matrix[temp_row][temp_col] != 0) and (temp_row != temp_col):
                        base1 = rna[temp_row]
                        base2 = rna[temp_col]
                        if (base1 == ('G' or 'g') and base2 == ('C' or 'c')):
                            sp.append("GC")
                        if (base1 == ('C' or 'c') and base2 == ('G' or 'g')):
                            sp.append("CG")
                        if (base1 == ('G' or 'g') and base2 == ('U' or 'u')):
                            sp.append("GU")
                        if (base1 == ('U' or 'u') and base2 == ('G' or 'g')):
                            sp.append("UG")
                        if (base1 == ('A' or 'a') and base2 == ('U' or 'u')):
                            sp.append("AU")
                        if (base1 == ('U' or 'u') and base2 == ('A' or 'a')):
                            sp.append("UA")
                        length += 1
                        temp_row += 1
                        temp_col -= 1
                        if length >= 3:
                            stem[2] = int(length)
                            stem[3] = int(col-row-2*length)
                            stem[4] = stem_energy(sp)
                            stems_potential.append(stem.copy())
                    if stem_energy(sp) > mu:
                        mu = stem_energy(sp)
    
    return [stems_potential, mu, rna, len(rna)]

In [11]:
# function to generate energy per in-line pseudoknotted helix of length n:

def pseudoknot_sub_penalty(length):
    return np.exp(0.572992*length+0.219677)

In [12]:
def potential_pseudoknots(stems_potential, gamma):

    pseudoknots_potential = []

    for i in range(len(stems_potential)):
        for j in range(i + 1, len(stems_potential)):
            
            stem1 = stems_potential[i]
            stem2 = stems_potential[j]
    
            i_a = stem1[0]
            j_a = stem1[1]
            i_b = stem2[0]
            j_b = stem2[1]
    
            pseudoknot = [i,j,1]
            if (i_a < i_b and i_b < j_a and j_a < j_b) or (i_b < i_a and i_a < j_b and j_b < j_a):
                
                pseudoknot[2] = gamma*np.log(stem1[2]*pseudoknot_sub_penalty(stem1[2])**2+stem2[2]*pseudoknot_sub_penalty(stem2[2])**2)
    
            pseudoknots_potential.append(pseudoknot)
            
    return pseudoknots_potential

In [13]:
def potential_overlaps(stems_potential):
    
    overlaps_potential = []
    overlap_penalty = 1e6

    for i in range(len(stems_potential)):
        for j in range(i+1, len(stems_potential)):
    
            stem1 = stems_potential[i]
            stem2 = stems_potential[j]
    
            overlap = [i, j, 0]
    
            stem1_cspan1 = set(range(stem1[1]-int(stem1[2])+1, stem1[1]+1))
            stem2_cspan1 = set(range(stem2[1]-int(stem2[2])+1, stem2[1]+1))
            
            stem1_cspan2 = set(range(stem1[0], stem1[0]+int(stem1[2])))
            stem2_cspan2 = set(range(stem2[0], stem2[0]+int(stem2[2])))
    
            if (len(stem1_cspan1 & stem2_cspan1) != 0) or (len(stem1_cspan2 & stem2_cspan2) != 0)  or (len(stem1_cspan1 & stem2_cspan2) != 0) or (len(stem1_cspan2 & stem2_cspan1) != 0):
        
                overlap[2] = overlap_penalty
        
            overlaps_potential.append(overlap)
            
    return overlaps_potential

In [14]:
def loop_penalty(ll):
    lp = 0
    if ll == 0 or ll == 1 or ll == 2:
        lp = 1000
    if ll == 3:
        lp = 7.4
    if ll == 4:
        lp = 5.9
    if ll == 5:
        lp = 4.4
    if ll == 6:
        lp = 4.3
    if ll >= 7:
        lp = 4.1
    return lp

In [15]:
def model(stems_potential, pseudoknots_potential, overlaps_potential, mu, alpha, beta):
    
    L = {}
    Q = {}
    k = 0

    for i in range(0, len(stems_potential)):
        k_i = stems_potential[i][4]
        lp  = stems_potential[i][3]
        L[str(i)] = alpha*((k_i-mu)**2)-beta*(k_i-loop_penalty(stems_potential[i][2]))
        for j in range(i+1, len(stems_potential)):
            Q[(str(i), str(j))] = pseudoknots_potential[k][2]+overlaps_potential[k][2]
            k += 1
    
    return L, Q

In [16]:
def energy(stems_actual, gamma, alpha, beta):
    k = 0
    pseudoknots_actual = potential_pseudoknots(stems_actual, gamma)
    cost = 0
    mu = max(list(map(list, zip(*stems_actual)))[2])
        
    for i in range(0, len(stems_actual)):
        k_i = stems_actual[i][4]
        cost += alpha*((k_i-mu)**2)-beta*(k_i-loop_penalty(stems_actual[i][2]))
        for j in range(i+1, len(stems_actual)):
            cost += pseudoknots_actual[k][2]
            k += 1
    
    return cost

In [17]:
# function to compare actual and predicted structure based on comparison of base-pairs:

def evaluation_1(stems_actual, stems_potential):
    
    bp_actual = []
    bp_predicted = []

    for i in range(0, len(stems_actual)):
        for j in range(0, stems_actual[i][2]):
            bp_actual.append((stems_actual[i][0]+j, stems_actual[i][1]-j))
        
    for i in range(0, len(stems_potential)):
        for j in range(0, stems_potential[i][2]):
            bp_predicted.append((stems_potential[i][0]+j, stems_potential[i][1]-j))
            
    C = 0    # number of correctly identified base pairs
    M = 0    # number of the predicted base pairs missing from the known structure
    I = 0    # number of non-predicted base pairs present in the known structure

    for i in range(0, len(bp_predicted)):
        if bp_predicted[i] in bp_actual:
            C += 1
        else:
            M += 1

    for i in range(0, len(bp_actual)):
        if bp_actual[i] not in bp_predicted:
            I += 1
    ppv = 0
    sensitivity = 0
    if C != 0 or M != 0:
        ppv = C/(C+M)
    if C != 0 or I !=0:
        sensitivity = C/(C+I)
    
    return [ppv, sensitivity]

In [18]:
# function to compare actual and predicted structure based on comparison of bases involved in pairing:

def evaluation_2(stems_actual, stems_predicted):
    
    pb_actual = []
    pb_predicted = []

    for i in range(0, len(stems_actual)):
        for j in range(0, stems_actual[i][2]):
            pb_actual.append(stems_actual[i][0]+j)
            pb_actual.append(stems_actual[i][1]-j)
        
    for i in range(0, len(stems_predicted)):
        for j in range(0, stems_predicted[i][2]):
            pb_predicted.append(stems_predicted[i][0]+j)
            pb_predicted.append(stems_predicted[i][1]-j)
            
    C = 0    # number of correctly identified bases that are paired
    M = 0    # number of the predicted paired bases missing from the known structure
    I = 0    # number of non-predicted paired bases present in the known structure

    for i in range(0, len(pb_predicted)):
        if pb_predicted[i] in pb_actual:
            C += 1
        else:
            M += 1

    for i in range(0, len(pb_actual)):
        if pb_actual[i] not in pb_predicted:
            I += 1
            
    ppv = 0
    sensitivity = 0
    if C != 0 or M !=0:
        ppv = C/(C+M)
    if C != 0 or I !=0:
        sensitivity = C/(C+I)
    
    return [ppv, sensitivity]

In [19]:
def spsa_optimizer_callback(nb_fct_eval, params, fct_value, stepsize, step_accepted, train_history):
    print("In callback")
    train_history.append((nb_fct_eval,params,fct_value))
    clear_output(wait=True)
    display(f'evaluations : {nb_fct_eval} loss: {fct_value:0.4f}')

In [20]:
def calculate_cost_function(expectation_values, target_values):
    product_zt = expectation_values*target_values
    all_costs = ((1-product_zt)/2)**2
    return all_costs

In [21]:
num_reads = 100

In [22]:
def harmonic_mean_metric(first_metric, second_metric):
    metric_1 = []
    metric_2 = []
    for t in first_metric:
        sen = t[0]
        ppv = t[0]
        if sen == 0 and ppv == 0:
            metric_1.append(0)
        else:
            metric_1.append((2*sen*ppv)/(sen+ppv))
    for t in second_metric:
        sen = t[0]
        ppv = t[0]
        if sen == 0 and ppv == 0:
            metric_2.append(0)
        else:
            metric_2.append((2*sen*ppv)/(sen+ppv))
    total_metric = []
    for i in range(len(metric_1)):
        m_1 = metric_1[i]
        m_2 = metric_2[i]
        if m_1 == 0 or m_2 == 0:
            total_metric.append(0)
        else:
            total_metric.append(2/((1/m_1)+(1/m_2)))
    return np.array(total_metric)

In [23]:
#sampler = BraketDWaveSampler(device_arn='arn:aws:braket:::device/qpu/d-wave/Advantage_system4')
#sampler = EmbeddingComposite(sampler)

sampler = dimod.SimulatedAnnealingSampler()

def optimize_params(optimizer, hyper_params, inital_point):
    target_value = 1
    def cost_function(hyper_params):
        print(hyper_params)
        first_metric = []
        second_metric = []
        full_bprna = []
        pks = ["wPks", "woutPKs"]
        sizes = ["s"]
        a_stems = {}
        p_stems = {}
        a_energies = {}
        problems = {}
        stems_f = {}
        for pk in pks:
            for size in sizes:    
                bprna = get_structures(pks, size)
                full_bprna = full_bprna + bprna
        for bprna in full_bprna:
            bprna_id = bprna.split("/")[4]
            fasta_file = bprna + ".fasta.txt"
            ct_file = bprna + ".ct.txt" 
            a_stems[bprna_id] = actual_stems(ct_file, fasta_file)
            a_energies[bprna_id] = energy(a_stems[bprna_id], hyper_params[0], hyper_params[1], hyper_params[2])
            p_stems[bprna_id] = potential_stems(fasta_file)
            p_psudoknots = potential_pseudoknots(p_stems[bprna_id][0], hyper_params[0])
            p_overlaps = potential_overlaps(p_stems[bprna_id][0])
            md = model(p_stems[bprna_id][0], p_psudoknots, p_overlaps, p_stems[bprna_id][1],hyper_params[1], hyper_params[2])
            problems[bprna_id] = dimod.BinaryQuadraticModel(md[0], md[1], vartype = 'BINARY', offset = 0.0)
            
        print("Finished creating the models")
        
        for key, value in problems.items():
            sampleset = sampler.sample(value, num_reads=num_reads)
            #print("Ready:", key)
            for datum in sampleset.data(['sample', 'energy', 'num_occurrences']):
                results_hybrid = datum.sample
                predicted_energy = datum.energy
    
            f_stems = []

            for j in range(0, len(results_hybrid)):
                if results_hybrid[str(j)] == 1:
                    f_stems.append(p_stems[key][0][j])
                
            stems_f[key] = ([f_stems, predicted_energy])
            first_metric.append(evaluation_1(a_stems[key], stems_f[key][0]))
            second_metric.append(evaluation_2(a_stems[key], stems_f[key][0]))
        print("Finished running the models")
        all_metrics = harmonic_mean_metric(first_metric, second_metric)
        all_costs = calculate_cost_function(all_metrics, target_value)
        #print(all_costs)
        return np.sum(all_costs)/len(all_costs)
    
    model_values, loss, nfev = optimizer.optimize(len(hyper_params), cost_function, initial_point=inital_point)
    return model_values, loss, nfev

initial_point = [0, 1, 1]
model_params = [0, 1, 1]
train_history = []
optimizer = SPSA(maxiter=50, callback=lambda n, p, v, ss, sa: spsa_optimizer_callback(n, p, v, ss, sa, train_history))
model_values, loss, nfev = optimize_params(optimizer, model_params, initial_point)
print("Final values:",model_values)

'evaluations : 150 loss: 0.0707'

[-5.47154165  4.1117629   3.97340818]
Finished creating the models
Finished running the models
Final values: [-5.47154165  4.1117629   3.97340818]


In [24]:
full_bprna = []
pks = ["wPks"]
sizes = ["s"]
penalties = [-5.4715]
cl = 4.1117
cb = 3.973
a_stems = {}
p_stems = {}
a_energies = {}
problems = {}
stems_f = {}
for pk in pks:
    for size in sizes:    
        bprna = get_structures("wPKs", "s")
        full_bprna = full_bprna + bprna
for penalty in penalties:
    for bprna in full_bprna:
        bprna_id = bprna.split("/")[4]
        fasta_file = bprna + ".fasta.txt"
        ct_file = bprna + ".ct.txt" 
        a_stems[bprna_id] = actual_stems(ct_file, fasta_file)
        a_energies[bprna_id] = energy(a_stems[bprna_id], penalty, cl, cb)
        p_stems[bprna_id] = potential_stems(fasta_file)
        p_psudoknots = potential_pseudoknots(p_stems[bprna_id][0], penalty)
        p_overlaps = potential_overlaps(p_stems[bprna_id][0])
        md = model(p_stems[bprna_id][0], p_psudoknots, p_overlaps, p_stems[bprna_id][1], cl, cb)
        problems[bprna_id] = dimod.BinaryQuadraticModel(md[0], md[1], vartype = 'BINARY', offset = 0.0)

In [25]:
metric_1 = {}
metric_2 = {}
sampler = BraketDWaveSampler(device_arn='arn:aws:braket:::device/qpu/d-wave/Advantage_system4')
sampler = EmbeddingComposite(sampler)
for key, value in problems.items():
    sampleset = sampler.sample(value, num_reads=num_reads)
    print("Ready:", key)
    for datum in sampleset.data(['sample', 'energy', 'num_occurrences']):
        results_hybrid = datum.sample
        predicted_energy = datum.energy
    
    f_stems = []

    for j in range(0, len(results_hybrid)):
        if results_hybrid[str(j)] == 1:
            f_stems.append(p_stems[key][0][j])
        
    stems_f[key] = ([f_stems, predicted_energy])
    metric_1[key] = evaluation_1(a_stems[key], stems_f[key][0])
    metric_2[key] = evaluation_2(a_stems[key], stems_f[key][0])

Ready: bpRNA_RFAM_23352
Ready: bpRNA_RFAM_23366
Ready: bpRNA_RFAM_23381
Ready: bpRNA_RFAM_23457
Ready: bpRNA_RFAM_23495


In [26]:
print(metric_1)

{'bpRNA_RFAM_23352': [0.10204081632653061, 0.5555555555555556], 'bpRNA_RFAM_23366': [0.25, 0.375], 'bpRNA_RFAM_23381': [0.0, 0.0], 'bpRNA_RFAM_23457': [0.575, 1.0], 'bpRNA_RFAM_23495': [0.0, 0.0]}


In [27]:
print(metric_2)

{'bpRNA_RFAM_23352': [0.47959183673469385, 0.94], 'bpRNA_RFAM_23366': [0.5416666666666666, 0.6842105263157895], 'bpRNA_RFAM_23381': [0.20833333333333334, 0.3125], 'bpRNA_RFAM_23457': [0.75, 1.0], 'bpRNA_RFAM_23495': [0.5, 0.5454545454545454]}
