In [1]:
import os
import dimod
import numpy as np
from braket.aws import AwsDevice
from braket.ocean_plugin import BraketSampler, BraketDWaveSampler
from dwave.system.composites import EmbeddingComposite
from qiskit.algorithms.optimizers import SPSA
from IPython.display import display, clear_output

In [2]:
def get_structures(pks, size):
    subdirectory = './data/'+pks+'/'+size
    bprna = []
    fasta = [f for f in os.listdir(subdirectory) if f.endswith('.fasta.txt')]
    for f in fasta:
        bprna.append(subdirectory+"/"+f.split(".")[0])
    return bprna

In [3]:
def actual_stems(seq_ss, seq_ps):
    
    with open(seq_ss) as file:
        lines = file.readlines()
    
    with open(seq_ps) as file:
        fasta_lines = file.readlines()
    
    rna = fasta_lines[1]
    
    stems_actual = []

    sip = False                       # stem in progress?
    sl = 0                            # stem length
    last_line = [0, 0, 0, 0, 0, 0]    # initiate last line

    for i in range(0, len(lines)):
        line = lines[i].strip().split()
        if (int(line[4]) != 0 and sip == False):
            sip = True
            temp = [int(line[0]), int(line[4])]
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')) or (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sl += 3
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')) or (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sl += 2
        if (int(line[4]) != 0 and sip == True and (int(last_line[4])-int(line[4]) == 1)):
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')) or (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sl += 3
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')) or (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sl += 2
        if (int(line[4]) == 0 and sip == True):
            sip = False
            temp.append(sl)
            if temp[1] > temp[0]:
                stems_actual.append(temp)
            sl = 0
        if ((int(last_line[4])-int(line[4]) != 1) and int(last_line[4]) != 0  and sip == True):
            temp.append(sl)
            if temp[1] > temp[0]:
                stems_actual.append(temp)
            temp = [int(line[0]), int(line[4])]
            sl = 0
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('C' or 'c')) or (rna[i] == ('C' or 'c') and rna[int(line[4])-1] == ('G' or 'g')):
                sl = 3
            if (rna[i] == ('G' or 'g') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('G' or 'g')) or (rna[i] == ('A' or 'a') and rna[int(line[4])-1] == ('U' or 'u')) or (rna[i] == ('U' or 'u') and rna[int(line[4])-1] == ('A' or 'a')):
                sl = 2
        
        last_line = line
        
    return stems_actual

In [4]:
def potential_stems(seq_ps):
    
    with open(seq_ps) as file:
        lines = file.readlines()
    
    rna = lines[1]
    
    matrix = np.zeros((len(rna),len(rna)))
    for diag in range(0, len(matrix)):
        for row in range(0, len(matrix)-diag):
            col = row + diag
            base1 = rna[row]
            base2 = rna[col]
            if row != col:
                if ((base1 == ("A" or "a")) and (base2 == ("U" or "u"))) or ((base1 == ("U" or "u")) and (base2 == ("A" or "a"))) or ((base1 == ("G" or "g")) and (base2 == ("U" or "u"))) or ((base1 == ("U" or "u")) and (base2 == ("G" or "g"))):
                    matrix[row][col] = 2
                if ((base1 == ("G" or "g")) and (base2 == ("C" or "c"))) or ((base1 == ("C" or "c")) and (base2 == ("G" or "g"))):
                    matrix[row][col] = 3
    
    stems_potential = []
    mu = 0

    for row in range(0, len(matrix)):
        for col in range (row, len(matrix)):
            if row != col:
                if matrix[row][col] != 0:
                    temp_row = row
                    temp_col = col
                    stem = [row+1,col+1,0]
                    length_N = 0
                    length_H = 0
                    while (matrix[temp_row][temp_col] != 0) and (temp_row != temp_col):
                        length_N+=1
                        length_H+=matrix[temp_row][temp_col]
                        temp_row+=1
                        temp_col-=1
                        if length_N >= 3:
                            stem[2] = int(length_H)
                            stems_potential.append(stem.copy())
                    if length_H > mu:
                        mu = length_H
    
    return [stems_potential, mu, rna, len(rna)]

In [5]:
def potential_pseudoknots(stems_potential, pkp):

    pseudoknots_potential = []
    pseudoknot_penalty = pkp

    for i in range(len(stems_potential)):
        for j in range(i + 1, len(stems_potential)):
            
            stem1 = stems_potential[i]
            stem2 = stems_potential[j]
    
            i_a = stem1[0]
            j_a = stem1[1]
            i_b = stem2[0]
            j_b = stem2[1]
    
            pseudoknot = [i,j,1]
    
            if (i_a < i_b and i_b < j_a and j_a < j_b) or (i_b < i_a and i_a < j_b and j_b < j_a):
        
                pseudoknot[2] = pseudoknot_penalty
    
            pseudoknots_potential.append(pseudoknot)
            
    return pseudoknots_potential

In [6]:
def potential_overlaps(stems_potential):
    
    overlaps_potential = []
    overlap_penalty = 1e6

    for i in range(len(stems_potential)):
        for j in range(i+1, len(stems_potential)):
    
            stem1 = stems_potential[i]
            stem2 = stems_potential[j]
    
            overlap = [i, j, 0]
    
            stem1_cspan1 = set(range(stem1[1]-int(stem1[2])+1, stem1[1]+1))
            stem2_cspan1 = set(range(stem2[1]-int(stem2[2])+1, stem2[1]+1))
            
            stem1_cspan2 = set(range(stem1[0], stem1[0]+int(stem1[2])))
            stem2_cspan2 = set(range(stem2[0], stem2[0]+int(stem2[2])))
    
            if (len(stem1_cspan1 & stem2_cspan1) != 0) or (len(stem1_cspan2 & stem2_cspan2) != 0)  or (len(stem1_cspan1 & stem2_cspan2) != 0) or (len(stem1_cspan2 & stem2_cspan1) != 0):
        
                overlap[2] = overlap_penalty
        
            overlaps_potential.append(overlap)
            
    return overlaps_potential

In [7]:
def model(stems_potential, pseudoknots_potential, overlaps_potential, mu, cl, cb):
    
    L = {}
    Q = {}
    k = 0

    for i in range(0, len(stems_potential)):
        L[str(i)] = cl*((stems_potential[i][2]**2)-2*mu*stems_potential[i][2]+mu**2)-cb*(stems_potential[i][2]**2)
        for j in range(i+1, len(stems_potential)):
            Q[(str(i), str(j))] = -2*cb*stems_potential[i][2]*stems_potential[j][2]*pseudoknots_potential[k][2]+overlaps_potential[k][2]
            k += 1
    
    return L, Q

In [8]:
def energy(stems_actual, pkp, cl, cb):
    k = 0
    
    pseudoknots_actual = potential_pseudoknots(stems_actual, pkp)
    cost = 0
    mu = max(list(map(list, zip(*stems_actual)))[2])
    
    for i in range(0, len(stems_actual)):
        cost += cl*((stems_actual[i][2]**2)-2*mu*stems_actual[i][2]+mu**2)-cb*(stems_actual[i][2]**2)
        for j in range(i+1, len(stems_actual)):
            cost -= 2*cb*stems_actual[i][2]*stems_actual[j][2]*pseudoknots_actual[k][2]
            k += 1
    
    return cost

In [9]:
def evaluation_1(stems_actual, stems_potential):
    
    bp_actual = []
    bp_predicted = []

    for i in range(0, len(stems_actual)):
        for j in range(0, stems_actual[i][2]):
            bp_actual.append((stems_actual[i][0]+j, stems_actual[i][1]-j))
        
    for i in range(0, len(stems_potential)):
        for j in range(0, stems_potential[i][2]):
            bp_predicted.append((stems_potential[i][0]+j, stems_potential[i][1]-j))
            
    C = 0    # number of correctly identified base pairs
    M = 0    # number of the predicted base pairs missing from the known structure
    I = 0    # number of non-predicted base pairs present in the known structure

    for i in range(0, len(bp_predicted)):
        if bp_predicted[i] in bp_actual:
            C += 1
        else:
            M += 1

    for i in range(0, len(bp_actual)):
        if bp_actual[i] not in bp_predicted:
            I += 1
            
    ppv = C/(C+M)
    sensitivity = C/(C+I)
    
    return [sensitivity, ppv]

In [19]:
def evaluation_2(stems_actual, stems_predicted):
    
    b_actual = []
    b_predicted = []

    for i in range(0, len(stems_actual)):
        for j in range(0, stems_actual[i][2]):
            b_actual.append(stems_actual[i][0]+j)
            b_actual.append(stems_actual[i][1]-j)
        
    for i in range(0, len(stems_predicted)):
        for j in range(0, stems_predicted[i][2]):
            b_predicted.append(stems_predicted[i][0]+j)
            b_predicted.append(stems_predicted[i][1]-j)
            
    C = 0    # number of correctly identified bases that are paired
    M = 0    # number of the predicted paired bases missing from the known structure
    I = 0    # number of non-predicted paired bases present in the known structure

    for i in range(0, len(b_predicted)):
        if b_predicted[i] in b_actual:
            C += 1
        else:
            M += 1

    for i in range(0, len(b_actual)):
        if b_actual[i] not in b_predicted:
            I += 1
    
    PPV = 0
    sensitivity = 0
    if C != 0 or M != 0:
        PPV = C/(C+M)
    if C != 0 or I != 0:
        sensitivity = C/(C+I)
    
    return [sensitivity, PPV]

In [12]:
num_reads = 100

In [13]:
def spsa_optimizer_callback(nb_fct_eval, params, fct_value, stepsize, step_accepted, train_history):
    print("In callback")
    train_history.append((nb_fct_eval,params,fct_value))
    clear_output(wait=True)
    display(f'evaluations : {nb_fct_eval} loss: {fct_value:0.4f}')

In [14]:
def calculate_cost_function(expectation_values, target_values):
    product_zt = expectation_values*target_values
    all_costs = ((1-product_zt)/2)**2
    return all_costs

In [25]:
#sampler = BraketDWaveSampler(device_arn='arn:aws:braket:::device/qpu/d-wave/Advantage_system4')
#sampler = EmbeddingComposite(sampler)

sampler = dimod.SimulatedAnnealingSampler()

def optimize_params(optimizer, hyper_params, inital_point):
    target_value = 1
    def cost_function(hyper_params):
        print(hyper_params)
        all_metrics = []
        full_bprna = []
        pks = ["wPks"]
        sizes = ["s"]
        penalties = [0.5]
        a_stems = {}
        p_stems = {}
        a_energies = {}
        problems = {}
        stems_f = {}
        for pk in pks:
            for size in sizes:    
                bprna = get_structures("wPKs", "s")
                full_bprna = full_bprna + bprna
        for bprna in full_bprna:
            bprna_id = bprna.split("/")[4]
            fasta_file = bprna + ".fasta.txt"
            ct_file = bprna + ".ct.txt" 
            a_stems[bprna_id] = actual_stems(ct_file, fasta_file)
            a_energies[bprna_id] = energy(a_stems[bprna_id], hyper_params[0], hyper_params[1], hyper_params[2])
            p_stems[bprna_id] = potential_stems(fasta_file)
            p_psudoknots = potential_pseudoknots(p_stems[bprna_id][0], hyper_params[0])
            p_overlaps = potential_overlaps(p_stems[bprna_id][0])
            md = model(p_stems[bprna_id][0], p_psudoknots, p_overlaps, p_stems[bprna_id][1],hyper_params[1], hyper_params[2])
            problems[bprna_id] = dimod.BinaryQuadraticModel(md[0], md[1], vartype = 'BINARY', offset = 0.0)
            
        #print("Finished creating the models")
        
        for key, value in problems.items():
            sampleset = sampler.sample(value, num_reads=num_reads)
            #print("Ready:", key)
            for datum in sampleset.data(['sample', 'energy', 'num_occurrences']):
                results_hybrid = datum.sample
                predicted_energy = datum.energy
    
            f_stems = []

            for j in range(0, len(results_hybrid)):
                if results_hybrid[str(j)] == 1:
                    f_stems.append(p_stems[key][0][j])
                
            stems_f[key] = ([f_stems, predicted_energy])
            all_metrics.append(evaluation_2(a_stems[key], stems_f[key][0]))
        all_metrics = np.array(list(map(lambda x: (2*x[0]*x[1])/(x[0]+x[1]), all_metrics)))
        all_costs = calculate_cost_function(all_metrics, target_value)
        #print(all_costs)
        return np.sum(all_costs)/len(all_costs)
    
    model_values, loss, nfev = optimizer.optimize(len(hyper_params), cost_function, initial_point=inital_point)
    return model_values, loss, nfev

initial_point = [0.5, 1, 1]
model_params = [0.5, 1, 1]
train_history = []
optimizer = SPSA(maxiter=50, callback=lambda n, p, v, ss, sa: spsa_optimizer_callback(n, p, v, ss, sa, train_history))
model_values, loss, nfev = optimize_params(optimizer, model_params, initial_point)
print("Final values:",model_values)

'evaluations : 30 loss: 0.0425'

[1.25323124 3.15961042 2.67520782]
Final values: [1.25323124 3.15961042 2.67520782]


In [26]:
full_bprna = []
pks = ["wPks"]
sizes = ["s"]
penalties = [1.253]
cl = 3.160
cb = 2.675
a_stems = {}
p_stems = {}
a_energies = {}
problems = {}
stems_f = {}
for pk in pks:
    for size in sizes:    
        bprna = get_structures("wPKs", "s")
        full_bprna = full_bprna + bprna
for penalty in penalties:
    for bprna in full_bprna:
        bprna_id = bprna.split("/")[4]
        fasta_file = bprna + ".fasta.txt"
        ct_file = bprna + ".ct.txt" 
        a_stems[bprna_id] = actual_stems(ct_file, fasta_file)
        a_energies[bprna_id] = energy(a_stems[bprna_id], penalty, cl, cb)
        p_stems[bprna_id] = potential_stems(fasta_file)
        p_psudoknots = potential_pseudoknots(p_stems[bprna_id][0], penalty)
        p_overlaps = potential_overlaps(p_stems[bprna_id][0])
        md = model(p_stems[bprna_id][0], p_psudoknots, p_overlaps, p_stems[bprna_id][1], cl, cb)
        problems[bprna_id] = dimod.BinaryQuadraticModel(md[0], md[1], vartype = 'BINARY', offset = 0.0)

In [27]:
metric_1 = {}
metric_2 = {}
sampler = BraketDWaveSampler(device_arn='arn:aws:braket:::device/qpu/d-wave/Advantage_system4')
sampler = EmbeddingComposite(sampler)
for key, value in problems.items():
    sampleset = sampler.sample(value, num_reads=num_reads)
    print("Ready:", key)
    for datum in sampleset.data(['sample', 'energy', 'num_occurrences']):
        results_hybrid = datum.sample
        predicted_energy = datum.energy
    
    f_stems = []

    for j in range(0, len(results_hybrid)):
        if results_hybrid[str(j)] == 1:
            f_stems.append(p_stems[key][0][j])
        
    stems_f[key] = ([f_stems, predicted_energy])
    metric_1[key] = evaluation_1(a_stems[key], stems_f[key][0])
    metric_2[key] = evaluation_2(a_stems[key], stems_f[key][0])

Ready: bpRNA_RFAM_23352
Ready: bpRNA_RFAM_23366
Ready: bpRNA_RFAM_23381
Ready: bpRNA_RFAM_23457
Ready: bpRNA_RFAM_23495


In [28]:
print(metric_1)

{'bpRNA_RFAM_23352': [0.36363636363636365, 0.16], 'bpRNA_RFAM_23366': [0.42857142857142855, 0.34615384615384615], 'bpRNA_RFAM_23381': [0.0, 0.0], 'bpRNA_RFAM_23457': [0.0, 0.0], 'bpRNA_RFAM_23495': [0.3076923076923077, 0.4]}


In [29]:
print(metric_2)

{'bpRNA_RFAM_23352': [0.9893617021276596, 0.93], 'bpRNA_RFAM_23366': [0.8363636363636363, 0.8846153846153846], 'bpRNA_RFAM_23381': [1.0, 0.5138888888888888], 'bpRNA_RFAM_23457': [0.8823529411764706, 0.7758620689655172], 'bpRNA_RFAM_23495': [0.8, 1.0]}
