# Edit distance search
Manipulating inference results and creating training and testing labels.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from tqdm import tqdm
import os

In [None]:
# Path to a single edit-distance motif-search result file (a .txt file, not a directory).
basepath = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\edit_distance_motif_search\res-loose-chain.txt"

In [None]:
def clean_prediction(edit_motif_search_prediction):
    """Convert raw edit-search tokens into per-position motif lists.

    A token equal to ``'fake'`` becomes an empty list. Any other token becomes
    a one-element list holding the integer parsed from its second character
    (e.g. ``'m3'`` -> ``[3]``).
    """
    return [
        [] if token == 'fake' else [int(token[1])]
        for token in edit_motif_search_prediction
    ]

In [None]:

def get_edit_dataframe(edit_inference_filename):
    """Parse an edit-distance motif-search result file into a DataFrame.

    Each non-blank line is whitespace-split, and these fields are used:
      * field 0: read id, with its first 3 characters stripped;
      * field 1: orientation, kept verbatim;
      * field 4: the prediction chain. Its first 8 characters are dropped and
        the rest is split on ``'->'``. A leading token that starts with neither
        ``'f'`` nor ``'m'`` is discarded, and the remaining tokens go through
        ``clean_prediction``.
    (Field meanings are inferred from this parser only; confirm them against
    the motif-search output format.)

    Args:
        edit_inference_filename: path of the text result file.

    Returns:
        DataFrame with columns ``read_id``, ``orientation`` and ``motif_seq``.
        Only the first row for each ``read_id`` is kept.
    """
    with open(edit_inference_filename, 'r') as f:
        lines = f.readlines()

    motif_predictions = []
    orientations = []
    read_ids = []

    for line in tqdm(lines):
        split_line = line.split()
        if not split_line:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # crashing with IndexError on split_line[0].
            continue
        read_id = split_line[0][3:]
        orientation = split_line[1]
        prediction = split_line[4][8:].split('->')
        # Drop a leading token that is neither an 'f...' (presumably 'fake')
        # nor an 'm...' (motif) token.
        if not prediction[0].startswith(('f', 'm')):
            prediction = prediction[1:]

        motif_predictions.append(clean_prediction(prediction))
        orientations.append(orientation)
        read_ids.append(read_id)

    df = pd.DataFrame({"read_id": read_ids, "orientation": orientations, "motif_seq": motif_predictions})
    # Keep only the first occurrence of each read id.
    return df.drop_duplicates(subset=['read_id'])

In [None]:
# Parse the single result file referenced by basepath into read_id/orientation/motif_seq rows.
df = get_edit_dataframe(basepath)

In [None]:
# Load the diluted EIC04 dataset (a pickled DataFrame; it must have a read_id column for the merge below).
dil = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\diluted_EIC04.pkl")

In [None]:
# Inner-join the dataset with the edit-search predictions on read_id.
merged = pd.merge(dil, df, on='read_id')

In [None]:
# Number of reads that have both dataset rows and edit-search predictions.
len(merged)

In [None]:
merged

In [None]:
# ONT barcodes present among the merged reads.
merged['ONT_Barcode'].unique()

In [None]:
# Parse every edit-search result file in ``basepath`` and stack them into one frame.
# NOTE(review): the first cell sets ``basepath`` to a single .txt file, but
# os.listdir needs a directory. Re-point basepath before running this cell.
frames = []
for file in tqdm(os.listdir(basepath)):
    # ``df`` is still rebound on each iteration, as before, because later
    # cells reference ``df``.
    df = get_edit_dataframe(os.path.join(basepath, file))
    frames.append(df)
# Concatenate once instead of inside the loop (repeated concat copies quadratically).
master_df = pd.concat(frames) if frames else pd.DataFrame()

In [None]:
# Persist the combined edit-search predictions for later runs.
master_df.to_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\01-04run\edit_medium.pkl")

## Balancing edit-train df

In [None]:
from data_functions import get_cleaned_encoded_file

In [None]:
# Load the encoded reference table (tab-separated).
encoded_df = pd.read_csv(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\EIC01-01-1280-T1_encoded.tsv", sep='\t')

In [None]:
# Clean it with the project helper (behaviour defined in data_functions).
t = get_cleaned_encoded_file(encoded_df)

In [None]:
# Keep only the barcode, hardware address and payload columns.
t = t[['ONT_Barcode', 'HW_Address', 'payload']]

### Adding edit labels to test dataset

In [None]:
# Load previously saved edit-distance motif-search results.
edit_df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\edit_distance_motif_search\edit_distance_motif_search.pkl")

In [None]:
edit_df

## Extracting a FASTQ file containing all the subset reads

In [None]:
import pandas as pd
from tqdm import tqdm
import os

In [None]:
# Reload the diluted EIC04 dataset; this rebinds encoded_df (it was the encoded TSV earlier).
encoded_df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\diluted_EIC04.pkl")

In [None]:
# Set of all read ids in the dataset (overwritten by a later cell).
master_read_ids = set(encoded_df['read_id'].tolist())

In [None]:
# Values used to subset rows by their HW_Address column in the next cell.
filtered_barcodes = ['barcode_external02_internal02', 'barcode_external08_internal01',
       'barcode_external05_internal01', 'barcode_external02_internal03',
       'barcode_external02_internal05']

In [None]:
# Randomly sample up to 10000 reads whose HW_Address is one of filtered_barcodes.
# ``encoded_csv`` was never defined in this notebook (NameError); the frame
# loaded just above is ``encoded_df``.
# NOTE(review): the filter values look like ONT barcode names; confirm that
# 'HW_Address' is the intended column.
_matching = encoded_df.loc[encoded_df['HW_Address'].isin(filtered_barcodes)]
# Cap the sample size: DataFrame.sample raises if asked for more rows than exist.
filtered_df = _matching.sample(min(10000, len(_matching)))

In [None]:
# Directory of failed-read FASTQ files to scan (rebinds basepath).
basepath = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\fastq\EIC01-01-1280\FASTQ(pass_fail_logs)\fail"

In [None]:
# Read ids of the sampled subset; converted to a set in a later cell for fast lookups.
master_read_ids = filtered_df['read_id'].tolist()

In [None]:
from Bio import SeqIO, SeqRecord

def parse_biopython(input_fastq):
    """Lazily yield every SeqRecord from a FASTQ file."""
    yield from SeqIO.parse(input_fastq, 'fastq')


def get_fastq_records(fastq_filepath, master_read_ids):
    """Return the records in a FASTQ file whose id is in ``master_read_ids``.

    Args:
        fastq_filepath: path of the FASTQ file to scan.
        master_read_ids: collection of read ids to keep. A non-set collection
            is converted to a set once, so each membership test is O(1)
            instead of a linear scan of a list.

    Returns:
        List of matching records, in file order.
    """
    if not isinstance(master_read_ids, (set, frozenset)):
        master_read_ids = set(master_read_ids)
    return [
        record
        for record in parse_biopython(fastq_filepath)
        if record.id in master_read_ids
    ]

In [None]:
# Accumulator for the FASTQ records of the selected reads.
master_records = []

In [None]:
# Set membership is O(1); a list would be scanned for every record.
master_read_ids = set(master_read_ids)

In [None]:
# Scan every .fastq file in basepath and keep the records of the selected reads.
for file in tqdm(os.listdir(basepath)):
    if file.endswith('.fastq'):
        records = get_fastq_records(os.path.join(basepath, file), master_read_ids)
        master_records.extend(records)

In [None]:
record_read_ids = [str(i.id) for i in master_records]

In [None]:
len(merged['read_id'])

In [None]:
# How many of the merged reads were found in the FASTQ files.
len(set(record_read_ids).intersection(set(merged['read_id'])))

In [None]:
# Records already pass this filter when collected above, so this is a re-check.
t = [i for i in master_records if i.id in master_read_ids]

In [None]:
def create_fasta_file(ids, strands, output_filepath):
    """Write sequences to a FASTA file, one ``>id`` header line plus one sequence line each.

    Args:
        ids: record identifiers, paired with ``strands`` by position. Extra
            ids beyond ``len(strands)`` are ignored.
        strands: sequence strings.
        output_filepath: path of the FASTA file to create or overwrite.

    Raises:
        ValueError: if there are fewer ids than strands. The check runs before
            the file is opened, so no partial file is written.
    """
    if len(ids) < len(strands):
        raise ValueError(
            f"got {len(strands)} strands but only {len(ids)} ids"
        )
    with open(output_filepath, 'w', encoding='utf-8') as f:
        for record_id, strand in zip(ids, strands):
            # No blank line between records: many FASTA consumers reject blank lines.
            f.write(f">{record_id}\n{strand}\n")

In [None]:
# Sequences and ids of the collected records, in matching order.
strands = [str(i.seq) for i in master_records]
ids = [str(i.id) for i in master_records]

In [None]:
# Output directory for the extracted reads.
output_filepath = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\fastq\EIC01-01-1280"

In [None]:
# Write all collected records to a single FASTQ file.
with open(os.path.join(output_filepath, f"fq.fastq"), "w") as output_handle:
        SeqIO.write(master_records, output_handle, "fastq")

In [None]:

# Split the collected records into FASTQ chunks of at most 10000 reads
# (a size MCED can handle) and write them out as fq0.fastq, fq1.fastq, ...
chunk_size = 10000

# Write into output_filepath, not basepath. basepath is the input FASTQ
# directory, and chunks written there would be re-read (duplicating records)
# by the cells that scan that directory for .fastq files.
for ind, start in enumerate(range(0, len(master_records), chunk_size)):
    with open(os.path.join(output_filepath, f"fq{ind}.fastq"), "w") as output_handle:
        SeqIO.write(master_records[start: start + chunk_size], output_handle, "fastq")

# This worked! Let's see how long it takes - failed halfway through huh

In [None]:
from data_functions import sort_library_motif_transcript

# Group each read's library-motif transcript into per-cycle motif lists (project helper).
t = dil['library_motif'].apply(lambda x: sort_library_motif_transcript(x))

In [None]:
df

In [None]:
# Join the dataset with the edit-search predictions on read_id
# (fixes the ``pd.mebrge`` typo, which raised AttributeError).
merged_df = pd.merge(encoded_df, df, on='read_id')

In [None]:
import pandas as pd

In [None]:
# Load the prepared test set; this replaces any merged_df built above.
merged_df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\full_datasets\master_test.pkl")

In [None]:
from utils import evaluate_prediction

In [None]:
# Score the edit-search sequences through the generic motif_seq column.
merged_df['motif_seq'] = merged_df['edit_search_seq']

In [None]:
# Encode orientation as 1 for '+'-prefixed strings, else 0.
# NOTE: not idempotent; re-running fails because ints have no .startswith.
merged_df['orientation_x'] = merged_df['orientation_x'].apply(lambda x: 1 if x.startswith('+') else 0)

In [None]:

# Score every read's edit-search prediction against its payload.
# The three needed columns are zipped directly instead of going through iterrows().
found_arr, err_arr, orientation_arr = [], [], []
total = 0
for edit_seq, payload, orientation in zip(
    merged_df['motif_seq'], merged_df['payload_seq'], merged_df['orientation_x']
):
    # The first two entries of each prediction are skipped before scoring.
    found_, err_ = evaluate_prediction(edit_seq[2:], payload)
    found_arr.append(found_)
    err_arr.append(err_)
    orientation_arr.append(orientation)
    total += 1
    

In [None]:
# One row per scored read: motifs found, errors, and orientation (1 = '+', 0 = otherwise).
t = pd.DataFrame({
    "mf_edit": found_arr,
    "me_edit": err_arr,
    "orientation": orientation_arr
})

In [None]:
# Keep only reads with more than one motif found...
filtered_df = t.loc[t['mf_edit'] > 1]

In [None]:
# ...but this cell discards that filter, so the statistics below cover every read.
filtered_df = t

In [None]:

# Mean found/error counts overall, then split by orientation.
mean_values = filtered_df.mean(numeric_only=True)

print("Mean metrics combined:")
print(mean_values.to_frame(name='Mean').T)
print()

mean_values = filtered_df.loc[filtered_df['orientation'] == 1].mean(numeric_only=True)
print("Mean metrics forward")
print(mean_values.to_frame(name='Mean').T)
print()

mean_values = filtered_df.loc[filtered_df['orientation'] == 0].mean(numeric_only=True)
print("Mean metrics reverse")
print(mean_values.to_frame(name='Mean').T)
print()


# Share of merged_df rows that survived filtering (100% while the filter above is overridden).
print(f"{len(filtered_df)/ len(merged_df) * 100}% of pool")


### Convergence

In [None]:
import numpy as np

In [None]:
def update_tallies(tallies, prediction):
    """Add one read's motif calls to per-cycle tallies, in place.

    Args:
        tallies: indexed by cycle; ``tallies[c][m - 1]`` counts motif ``m``
            (1-based) in cycle ``c``.
        prediction: per-cycle iterables of 1-based motif numbers.

    Returns:
        The same ``tallies`` object, after mutation.
    """
    for cycle_idx, cycle_motifs in enumerate(prediction):
        for motif in cycle_motifs:
            tallies[cycle_idx][motif - 1] += 1
    return tallies


def evaluate_motif_tallies(motif_tallies, payload_cycles, n_motifs=32):
    """Score each cycle's four most-tallied motifs against the payload.

    Args:
        motif_tallies: per-cycle count arrays, where index ``k`` holds the
            count for motif ``k + 1``.
        payload_cycles: per-cycle collections of the true 1-based motifs.
        n_motifs: normaliser for the score. It defaults to 32, the previously
            hard-coded value (presumably 8 cycles x 4 motifs; confirm).

    Returns:
        Number of top-4 motifs that appear in the matching payload cycle,
        divided by ``n_motifs``.
    """
    correct = 0
    for tallies, cycle in zip(motif_tallies, payload_cycles):
        # Indices in descending count order; +1 converts them to 1-based motif numbers.
        top_4 = {int(idx) + 1 for idx in np.argsort(tallies)[::-1][:4]}
        correct += len(top_4.intersection(cycle))
    return correct / n_motifs

In [None]:
# NOTE: this is an alias, not a copy; columns added to test_df are also added to dil.
test_df = dil

In [None]:
# NOTE(review): if cells run top to bottom, t is the mf_edit/me_edit/orientation
# DataFrame here, and pandas rejects assigning a multi-column DataFrame to one column.
# Confirm which t is intended.
dil["motif_seq"] =  t

In [None]:
mf = t['mf_edit']

In [None]:
# Assignment aligns on index labels, not position.
test_df['mf'] = mf

In [None]:
# Barcodes 2, 6, 10, ..., 78 (every 4th, starting at 2).
t1_barcodes = np.arange(2, 81, 4)

In [None]:
test_df['ONT_Barcode'].unique()

In [None]:
# Merged reads belonging to the selected barcodes.
test_df_ = merged.loc[merged['ONT_Barcode'].isin(t1_barcodes)]

In [None]:
from tqdm import tqdm

In [None]:
# For barcoded runs
# For each ONT barcode and HW_Address cycle, collect the edit predictions of all
# matching reads (first two entries of each dropped) plus that cycle's payload.

edit_predictions_arr = []
payloads_arr = []
# batch_size, quality_threshold and prob_threshold are not used in this cell.
batch_size = 8
quality_threshold = 12
prob_threshold = 0.85

# Testing convergence - add quality filtering..
for barcode in tqdm(test_df_['ONT_Barcode'].unique()):
    # Counts non-empty cycles for this barcode; not read elsewhere in this cell.
    u_cycle = 0
    # Iterates every HW_Address in test_df_, not only those seen for this barcode;
    # empty combinations are skipped below.
    for cycle in tqdm(test_df_['HW_Address'].unique()):
        selected_df = test_df_.loc[
            (test_df_['ONT_Barcode'] == barcode) &
            (test_df_['HW_Address'] == cycle)
        ]
        read_ids = selected_df['read_id'].tolist()
        if len(read_ids) == 0:
            continue
        # All rows of one cycle are assumed to share a payload; the first is used.
        payload = selected_df['payload'].tolist()[0]
        decoded_predictions = selected_df['motif_seq']
        edit_predictions_arr.append([i[2:] for i in decoded_predictions])
        
        payloads_arr.append(payload)
        u_cycle += 1
        


In [None]:
## For the non barcoded runs

search_predictions_arr = []
payloads_arr = []
# batch_size, quality_threshold and prob_threshold are not used by this cell.
batch_size = 8
quality_threshold = 12
prob_threshold = 0.85

# For each HW_Address cycle, collect up to 150 predictions (first two entries of
# each dropped) together with that cycle's payload.
for cycle in tqdm(test_df_['HW_Address'].unique()):
    # Build the mask from test_df_ itself. Masking test_df_ with a Series from
    # test_df (a different frame, dil) either raises "Unalignable boolean
    # Series" or silently selects rows by index label instead of by cycle.
    # TODO: optionally restrict to forward strands with
    #   & test_df_['strand'].str.startswith('+')
    selected_df = test_df_.loc[test_df_['HW_Address'] == cycle]
    if selected_df.empty:
        # e.g. a NaN HW_Address never equals itself.
        continue
    payload = selected_df['payload'].tolist()[0]
    decoded_predictions = [i[2:] for i in selected_df['motif_seq'].tolist()[:150]]
    search_predictions_arr.append(decoded_predictions)
    payloads_arr.append(payload)
    


In [None]:
import heapq
from utils import evaluate_prediction

def get_top_4_motifs(tallies: dict):
    """Return, per cycle, up to four motifs with the highest non-zero counts.

    ``tallies`` is indexed by cycle position (in practice a list of dicts).
    Each entry maps motif -> count. Motifs are ordered by descending count, and
    ties keep the dict's iteration order. Motifs with a count of 0 are dropped.
    """
    top_motifs_per_cycle = []
    for cycle_idx in range(len(tallies)):
        counts = tallies[cycle_idx]
        ranked = heapq.nlargest(4, counts.items(), key=lambda item: item[1])
        top_motifs_per_cycle.append([motif for motif, count in ranked if count > 0])
    return top_motifs_per_cycle

# Convergence evaluation: for each cycle group, feed reads in one at a time,
# re-derive the per-cycle top-4 motifs after every read, and record the
# accuracy against the payload.
#payloads_arr = [eval(i) for i in payloads_arr]


search_accs_arr = []
caller_accs_arr = []
edit_accs_arr = []


### Comment out when non-barcoded
# These aliases make the caller and search curves identical to the edit curve.
caller_predictions_arr = edit_predictions_arr
search_predictions_arr = edit_predictions_arr

for search_prediction, caller_prediction, edit_prediction, payload in zip(
    search_predictions_arr, caller_predictions_arr, edit_predictions_arr, payloads_arr):

    n_reads = 0
    # 8 cycle slots with motifs 1..8. A prediction with more than 8 cycles, or a
    # motif outside 1..8, would raise IndexError/KeyError below.
    motif_tallies_caller = [{i: 0 for i in range(1, 9)} for i in range(8)]
    motif_tallies_search = [{i: 0 for i in range(1, 9)} for i in range(8)]
    motif_tallies_edit = [{i: 0 for i in range(1, 9)} for i in range(8)]
    search_accs, caller_accs, edit_accs = [], [], []
    for i, j, k in zip(search_prediction, caller_prediction, edit_prediction):  # for each prediction
        cycle_num = 0
        for search_cycle, caller_cycle, edit_cycle in zip(i, j, k):  # loop over cycles and update tallies
            for motif in search_cycle:
                motif_tallies_search[cycle_num][motif] += 1
            for motif in caller_cycle:
                motif_tallies_caller[cycle_num][motif] += 1
            for motif in edit_cycle:
                motif_tallies_edit[cycle_num][motif] += 1
            cycle_num += 1

        #  Get caller and search predictions
        caller_prediction_ = get_top_4_motifs(motif_tallies_caller)
        search_prediction_ = get_top_4_motifs(motif_tallies_search)
        edit_prediction_ = get_top_4_motifs(motif_tallies_edit)
        
        n_reads += 1

        #print(payload)
        #payload = eval(payload)  # comment out for other datasets
        #  Evaluate predictions
        # /32: presumably 8 cycles x 4 payload motifs per cycle; confirm.
        search_acc = evaluate_prediction(search_prediction_, payload)[0]/32
        caller_acc  = evaluate_prediction(caller_prediction_, payload)[0]/32
        edit_acc = evaluate_prediction(edit_prediction_, payload)[0]/32
        

        search_accs.append(search_acc)
        caller_accs.append(caller_acc)
        edit_accs.append(edit_acc)
        
    
    # One accuracy-after-n-reads list per cycle group.
    search_accs_arr.append(search_accs)
    caller_accs_arr.append(caller_accs)
    edit_accs_arr.append(edit_accs)
    



In [None]:
# Mean final (all-reads) edit accuracy across cycle groups; raises IndexError if any group is empty.
np.mean([i[-1] for i in edit_accs_arr])

In [None]:
def _mean_accuracy_at(accs_arr, step):
    """Return the mean of ``accs[step]`` over runs with more than ``step`` entries, or None if none have that many."""
    values = [accs[step] for accs in accs_arr if step < len(accs)]
    return sum(values) / len(values) if values else None


# Average accuracy after the (i + 1)-th read, across all cycle groups, for each
# method. A curve simply stops once no group has that many reads, instead of
# dividing by zero. The caller and edit curves are filled too, so the plot's
# 'motif caller' and 'edit search' lines are not empty.
avg_search = []
avg_caller = []
avg_edit = []

for i in range(100):
    for accs_arr, averages in (
        (search_accs_arr, avg_search),
        (caller_accs_arr, avg_caller),
        (edit_accs_arr, avg_edit),
    ):
        mean_acc = _mean_accuracy_at(accs_arr, i)
        if mean_acc is not None:
            averages.append(mean_acc)
    

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the three averaged convergence curves against the number of reads.
plt.plot(avg_search, label='zero-error search')
plt.plot(avg_caller, label='motif caller')
plt.plot(avg_edit, label='edit search')
plt.grid()
# Accuracies are fractions in [0, 1], even though the y label says "percentage".
plt.yticks(np.arange(0, 1.01, 0.05))
#plt.xticks(np.arange(0, 200, 2))
plt.xlim(0, 70)
plt.xlabel("Number of reads")
plt.ylabel("Recovery percentage")
plt.title("Convergence of motif-inferring methods")
plt.legend()

### Evaluating zero-error search

In [None]:

def sort_library_motif_transcript(library_prediction):
    """Group a '|'-separated library-motif transcript into per-cycle motif lists.

    Tokens shorter than 8 characters are ignored. For the rest, the last
    character is the motif number. The characters at positions -4 and -3
    appear to encode a two-digit cycle address, handled as follows:
      * '1' then '0' -> cycle 10;
      * '1' then anything else -> ignored;
      * '9' then anything -> ignored;
      * otherwise -> cycle ``int(token[-3])`` (1-based). A resulting cycle 0
        is ignored; indexing with 0 - 1 = -1 would otherwise put it silently
        in cycle 10's slot.

    Args:
        library_prediction: the raw transcript string.

    Returns:
        A list of 10 lists; entry ``k`` holds the motif numbers found for cycle ``k + 1``.

    Raises:
        ValueError: if a token's motif or cycle character is not a digit.
    """
    sorted_prediction = [[] for _ in range(10)]

    for token in library_prediction.split('|'):
        if len(token) < 8:
            continue
        motif_found = int(token[-1])
        tens, units = token[-4], token[-3]
        if tens == '1':
            if units == '0':
                sorted_prediction[9].append(motif_found)  # cycle 10
            continue
        if tens == '9':
            continue
        cycle_address = int(units)  # 1-based; sorted_prediction is 0-based
        if cycle_address < 1:
            continue
        sorted_prediction[cycle_address - 1].append(motif_found)

    return sorted_prediction
    

In [None]:
import pandas as pd

In [None]:
# Load the 01-13 test run (must contain library_motif and read_id columns).
ms_df = pd.read_pickle(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\sequencing_runs\01-13\01-13-test.pkl")

In [None]:
lib_motifs = ms_df['library_motif'].tolist()
read_id_arr = ms_df['read_id'].tolist()

In [None]:
# Convert every library-motif transcript into per-cycle motif lists,
# skipping (and reporting) reads whose transcript cannot be parsed.
motif_seq_arr = []
read_ids_arr = []

for i, j in zip(lib_motifs, read_id_arr):
    try:
        motif_seq = sort_library_motif_transcript(i)
    except (ValueError, IndexError, AttributeError, TypeError) as err:
        # Malformed transcript (e.g. a non-digit motif/cycle character) or a
        # non-string value such as NaN. Narrowed from a bare ``except:`` so
        # KeyboardInterrupt and real bugs are no longer swallowed.
        print(f"Exception at {i}: {err}")
        continue
    motif_seq_arr.append(motif_seq)
    read_ids_arr.append(j)


In [None]:
# Attach the parsed per-cycle motifs to the run's rows.
# Join explicitly on read_id. An implicit merge keys on every shared column,
# which would include a list-valued 'motif_seq' if ms_df already had one.
merged_df = pd.merge(ms_df, pd.DataFrame({
    "read_id": read_ids_arr,
    "motif_seq": motif_seq_arr
}), on="read_id")

In [None]:
# Inspect the merged run with parsed motifs.
merged_df

### Checking average quality

In [None]:
# Directory of failed-read FASTQ files used for the quality/size checks below.
basepath = r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\fastq\EIC01-01-1280\FASTQ(pass_fail_logs)\fail"

In [None]:
from Bio import SeqIO
import numpy as np
import os
from tqdm import tqdm

In [None]:

from Bio import SeqIO

# Mean Phred score of each read across the FASTQ files in basepath.
# NOTE(review): this is the arithmetic mean of Phred values. Averaging error
# probabilities gives a different "average quality"; confirm which is intended.
avg_qs = []
for file in tqdm(os.listdir(basepath)):
    # Parse FASTQ files only, matching the read-collection cell; any other
    # file in the directory would fail to parse as FASTQ.
    if not file.endswith('.fastq'):
        continue
    for record in SeqIO.parse(os.path.join(basepath, file), "fastq"):
        score = record.letter_annotations["phred_quality"]
        if not score:
            # np.mean([]) is NaN and would poison the overall mean below.
            continue
        avg_qs.append(np.mean(score))
    

In [None]:
# Mean of the per-read mean Phred scores.
np.mean(avg_qs)

In [None]:
# On-disk size of each pass-read FASTQ file.
# List the sibling "pass" directory, not basepath (the *fail* directory);
# otherwise the pass and fail sizes are measured on the same files.
# NOTE(review): assumes pass reads live next to "fail" under the same parent.
pass_dir = os.path.join(os.path.dirname(basepath), "pass")
pass_sizes = [
    os.path.getsize(os.path.join(pass_dir, file))
    for file in tqdm(os.listdir(pass_dir))
]
fail_sizes = []

In [None]:

# On-disk size of (up to) 544 fail-read files.
# NOTE(review): os.listdir order is arbitrary, so [:544] is an arbitrary subset; sort first if reproducibility matters.
fail_sizes = []
for file in tqdm(os.listdir(basepath)[:544]):
    fail_sizes.append(os.path.getsize(os.path.join(basepath, file)))

In [None]:
# Mean pass-file size as a fraction of (mean pass size + mean fail size).
np.mean(pass_sizes) / (np.mean(pass_sizes) + np.mean(fail_sizes))

In [None]:
# ``qualities`` is never defined in this notebook (NameError); the per-read
# mean qualities were collected in ``avg_qs``.
np.mean(avg_qs)

In [None]:
import pandas as pd


# Load the barcoded motif-search results.
motif_search_coded = pd.read_csv(r"C:\Users\Parv\Doc\HelixWorks\Basecalling\code\motifcaller\data\empirical\motif_search_barcoded.csv")