# Data analysis

This notebook contains code to (1) detect the GAAGT seed motif in the second half of basecalled sequences, (2) compute the mean accuracy of trained models by aligning basecalled reads to the ground-truth reference sequence, and compare basecalled read lengths with predicted lengths, and (3) compute the DTW (Dynamic Time Warping) distance between experimental signals and a simulated reference signal as an indicator of read quality.

In [None]:
import csv
import os
import re
from collections import defaultdict
from itertools import groupby

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio import pairwise2
from fastdtw import fastdtw

## Checking for the GAAGT seed in the second half of basecalled sequences

In [None]:
fastq_file = "path/to/fastq"
output_csv = "path/to/output.csv"

seed = "GAAGT"

results = []
total_reads = 0
seed_hits = 0

for record in SeqIO.parse(fastq_file, "fastq"):
    seq = str(record.seq)
    # Search only from len // 2 onward (the middle base of an odd-length read
    # is included); a seed that starts before the midpoint is not counted.
    mid_index = len(seq) // 2
    second_half = seq[mid_index:]
    has_seed = seed in second_half

    results.append({
        "read_id": record.id,
        "sequence": seq,
        "contains_seed_in_second_half": has_seed
    })

    total_reads += 1
    if has_seed:
        seed_hits += 1

with open(output_csv, mode="w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["read_id", "sequence", "contains_seed_in_second_half"])
    writer.writeheader()
    writer.writerows(results)

print(f"Total reads: {total_reads}")
# Guard the percentage: an empty FASTQ would otherwise raise ZeroDivisionError.
if total_reads:
    print(f"Reads with '{seed}' in second half: {seed_hits} ({seed_hits / total_reads * 100:.2f}%)")
else:
    print(f"Reads with '{seed}' in second half: 0 (no reads found)")


## Calculate accuracy for experimental data 

In [None]:
# Matches one "<run length><op>" token of the pseudo-CIGAR strings built below.
split_cigar = r'(\d+)([=IDX])'


def _alignment_op(s, r):
    """Classify one alignment column as '=', 'X', 'I' or 'D' ('-' is a gap)."""
    if s == '-' and r != '-':
        return 'D'  # base present only in the reference
    if r == '-' and s != '-':
        return 'I'  # base present only in the read
    return '=' if s == r else 'X'


def compute_biopython_alignment(seq, ref):
    """
    Perform global alignment and convert to pseudo-CIGAR format.

    Scores passed to ``pairwise2.align.globalms``: match +2, mismatch -3,
    gap open -5, gap extend -2. Only the first optimal alignment is used.

    Operations are relative to the reference: ``=`` match, ``X`` mismatch,
    ``I`` base only in ``seq``, ``D`` base only in ``ref``.

    Args:
        seq: Query sequence (e.g. a basecalled read).
        ref: Reference sequence.

    Returns:
        Run-length encoded string such as ``"12=1X3I40="``; ``""`` when both
        inputs are empty.
    """
    if not seq or not ref:
        # pairwise2 returns no alignments for an empty input, so build the
        # trivial CIGAR directly (all deletions / all insertions / nothing).
        return (f"{len(ref)}D" if ref else "") + (f"{len(seq)}I" if seq else "")

    # one_alignment_only: only the first optimal alignment is consumed, so do
    # not enumerate every co-optimal one (can be very slow on repetitive input).
    # NOTE: Bio.pairwise2 is deprecated in recent Biopython in favour of
    # Bio.Align.PairwiseAligner.
    alignments = pairwise2.align.globalms(
        seq, ref, 2, -3, -5, -2, one_alignment_only=True
    )  # match, mismatch, gap open, gap extend -- parameter values can be changed here
    best = alignments[0]

    ops = [_alignment_op(s, r) for s, r in zip(best.seqA, best.seqB)]
    return "".join(f"{sum(1 for _ in run)}{op}" for op, run in groupby(ops))

def accuracy_ignore_deletions(ref, seq):
    """
    Percent identity of ``seq`` against ``ref`` without penalising deletions.

    accuracy = matches / (matches + mismatches + insertions) * 100, with the
    operation counts read from ``compute_biopython_alignment``'s pseudo-CIGAR.
    Returns 0.0 when that denominator is zero.

    Returns:
        (accuracy_percent, counts), where ``counts`` is a ``defaultdict(int)``
        mapping each CIGAR op character to its summed run length.
    """
    counts = defaultdict(int)
    for run_length, op in re.findall(split_cigar, compute_biopython_alignment(seq, ref)):
        counts[op] += int(run_length)

    matches = counts['=']
    denominator = matches + counts['X'] + counts['I']
    if denominator <= 0:
        return 0.0, counts
    return matches / denominator * 100, counts


fastq_path = "path/to/basecalled_reads.fastq"  # update this path

# read_id -> basecalled sequence; a repeated read ID keeps its last record.
basecalled_dict = {
    record.id: str(record.seq)
    for record in SeqIO.parse(fastq_path, "fastq")
}

reference_sequence = "CCGATGCTGGCTACATCTTAGGCTATCACTCTCACCTGCGATTATATGGTCCGTGCACTCTGAAGTCATT"  # change this if needed

accuracies = []
raw_totals = defaultdict(int)  # CIGAR op -> total length over all reads

for called_seq in basecalled_dict.values():
    acc, counts = accuracy_ignore_deletions(reference_sequence, called_seq)
    accuracies.append(acc)
    for op, op_len in counts.items():
        raw_totals[op] += op_len

if not accuracies:
    print("No reads found to evaluate.")
else:
    mean_accuracy = np.mean(accuracies)
    print(f"Mean accuracy (ignoring deletions): {mean_accuracy:.2f}%")
    print(f"Evaluated {len(accuracies)} reads.")
    print("Alignment stats:")
    stat_rows = (
        ("  Matches (=):     ", "="),
        ("  Mismatches (X):  ", "X"),
        ("  Deletions (D):   ", "D"),
        ("  Insertions (I):  ", "I"),
    )
    for label, op in stat_rows:
        print(f"{label}{raw_totals[op]}")


# Expects 'read_id' and 'predicted length' columns.
info_df = pd.read_csv("path/to/mapping.csv")  # update this path --> this comes from .csv to .pod5 conversion script (mapping_df)

length_errors = []  # basecalled length - predicted length, one per matched read
matched_lengths = 0
missing_from_basecalls = 0

# Iterate just the two needed columns; DataFrame.iterrows builds a Series per row.
for read_id, predicted_len in zip(info_df['read_id'], info_df['predicted length']):
    if read_id in basecalled_dict:
        length_errors.append(len(basecalled_dict[read_id]) - predicted_len)
        matched_lengths += 1
    else:
        missing_from_basecalls += 1

if length_errors:
    mae = np.mean(np.abs(length_errors))
    mse = np.mean(np.square(length_errors))
    mean_error = np.mean(length_errors)

    print("\nBasecalled length vs predicted length:")
    print(f"Matched reads: {matched_lengths}")
    print(f"Missing from basecalls: {missing_from_basecalls}")
    print(f"Mean error: {mean_error:.2f} bases")
    print(f"MAE: {mae:.2f} bases")
    # MSE is a squared quantity, so its unit is bases squared.
    print(f"MSE: {mse:.2f} bases^2")

    plt.figure(figsize=(12, 6))
    plt.hist(length_errors, bins=50, color='steelblue', edgecolor='black')
    plt.title("Length Prediction Error Distribution")
    plt.xlabel("Error (Basecalled - Predicted)")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()
else:
    print("No matching read IDs found between basecalls and predictions.")
    print(f"Missing from basecalls: {missing_from_basecalls}")


## DTW analysis

In [None]:
def moving_6mer_Substrings(string):
    """Return every overlapping length-6 substring of ``string``, left to right.

    Inputs shorter than 6 characters give an empty list.
    """
    return [string[i:i + 6] for i in range(len(string) - 5)]

# change this to different signal generation if needed
def predict_DNA_6mer_5_3_with_sampling(template, lut, lambda_time, sampling_rate, I_max=180, rng=None):
    """Simulate a sampled current trace for ``template`` from a 6-mer table.

    The template is reversed and split into overlapping 6-mers. Each 6-mer
    present in ``lut`` contributes two consecutive segments ("pre" then
    "post") of ``round(lambda_time * sampling_rate)`` samples each, drawn from
    a normal distribution with the table's mean/std scaled by ``I_max``.
    6-mers absent from ``lut`` are skipped.

    Args:
        template: Sequence to simulate.
        lut: Mapping ``kmer -> {"pre_mean", "pre_std", "post_mean", "post_std"}``.
        lambda_time: Duration of each pre/post segment (same time unit as
            ``1 / sampling_rate``).
        sampling_rate: Samples per time unit.
        I_max: Scale applied to the table's means and standard deviations.
        rng: Optional random source with ``normal(loc, scale, size)``, e.g.
            ``np.random.default_rng(seed)`` for reproducible traces. Defaults
            to the global ``np.random`` state.

    Returns:
        DataFrame with columns ``time`` and ``current`` (empty when no 6-mer
        is found in ``lut``).
    """
    if rng is None:
        rng = np.random

    kmers = moving_6mer_Substrings(template[::-1])
    valid_kmers = [k for k in kmers if k in lut]

    # Read parameters by name (not by dict value order); reshape keeps the
    # 4-way unpacking valid when valid_kmers is empty.
    param_names = ("pre_mean", "pre_std", "post_mean", "post_std")
    params = np.array(
        [[lut[k][name] for name in param_names] for k in valid_kmers], dtype=float
    ).reshape(-1, len(param_names))
    pre_mean, pre_std, post_mean, post_std = params.T

    step_times = np.ones(len(valid_kmers)) * lambda_time
    # Round instead of truncating: e.g. 0.29 * 100 evaluates to 28.999999999999996.
    num_samples = np.rint(step_times * sampling_rate).astype(int)

    sampled_signals = []
    sampled_times = []
    current_time = 0.0

    for i, ns in enumerate(num_samples):
        if ns == 0:
            continue

        # Draw order (pre, then post, per k-mer) fixes the random stream.
        pre = rng.normal(pre_mean[i] * I_max, pre_std[i] * I_max, ns)
        post = rng.normal(post_mean[i] * I_max, post_std[i] * I_max, ns)

        step_time = step_times[i]

        # endpoint=False: a segment's last sample must not share its timestamp
        # with the first sample of the next segment.
        sampled_signals.extend(pre)
        sampled_times.extend(np.linspace(current_time, current_time + step_time, ns, endpoint=False))
        current_time += step_time

        sampled_signals.extend(post)
        sampled_times.extend(np.linspace(current_time, current_time + step_time, ns, endpoint=False))
        current_time += step_time

    return pd.DataFrame({
        "time": sampled_times,
        "current": sampled_signals
    })
    

In [None]:
# Sequence whose signal is simulated as the DTW reference.
reference = 'TTACTGAAGTCTCACGTGCCTGGTATATTAGCGTCCACTCTCACTATCGGATTCTACATCGGTCGTAGCC'  # change if needed
# NOTE(review): reference_genome is not read by any of the visible cells.
reference_genome = 'path/to/reference_genome.fna'  # update path

# 6-mer current model: rows keyed by the "kmer_pull_3_5" column, turned into
# {kmer: {"pre_mean": ..., "pre_std": ..., "post_mean": ..., "post_std": ...}}.
LUT_6mer = pd.read_csv('path/to/kmer_model.csv', encoding='utf-8')  # update path
lut = (
    LUT_6mer
    .set_index("kmer_pull_3_5")
    .loc[:, ["pre_mean", "pre_std", "post_mean", "post_std"]]
    .to_dict(orient="index")
)

In [None]:
# Simulated reference trace (lambda_time=0.002, sampling_rate=5000); change values if needed.
reference_signal = predict_DNA_6mer_5_3_with_sampling(reference, lut, 0.002, 5000)['current'].values
# `.values` always returns an array (never None), so test for an empty trace.
if reference_signal.size == 0:
    raise ValueError("No reference signal found or generated.")

folder_path = "path/to/resampled_signals"
# Sorted so files are processed and reported in a deterministic order
# (os.listdir order is arbitrary).
signal_files = sorted(
    f for f in os.listdir(folder_path) if f.startswith("JS") and f.endswith(".csv")
)

results = []

for file in signal_files:
    file_path = os.path.join(folder_path, file)
    df = pd.read_csv(file_path)

    if "RelativeCurrent" not in df.columns:
        continue  # not in the expected signal format

    test_signal = df["RelativeCurrent"].values.astype(float)
    if test_signal.size == 0:
        continue  # nothing to align
    dtw_distance, _ = fastdtw(reference_signal, test_signal, dist=lambda x, y: abs(x - y))
    results.append((file, dtw_distance))

results_df = pd.DataFrame(results, columns=["Filename", "DTW_Distance"])
# Without any scored file the threshold is NaN and idxmin/idxmax below fail.
if results_df.empty:
    raise ValueError(
        f"No usable signal files in {folder_path} "
        "(expected JS*.csv files with a non-empty 'RelativeCurrent' column)."
    )

# Reads in the lowest 25% of DTW distances (closest to the reference) pass.
dtw_threshold = results_df["DTW_Distance"].quantile(0.25)
results_df["Pass"] = results_df["DTW_Distance"] <= dtw_threshold

print(f"DTW distance threshold (25th percentile): {dtw_threshold:.2f}")
print(f"{results_df['Pass'].sum()} / {len(results_df)} reads passed DTW threshold.")

results_df.to_csv("path/to/output/signal_quality_dtw_top25.csv", index=False)

plt.figure(figsize=(10, 6))
plt.hist(results_df["DTW_Distance"], bins=30, color='lightcoral', edgecolor='black')
plt.axvline(dtw_threshold, color='blue', linestyle='--', label=f'Threshold (25th percentile) = {dtw_threshold:.2f}')
plt.xlabel("DTW Distance")
plt.ylabel("Number of Signals")
plt.title("Distribution of DTW Distance Scores")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

best_match = results_df.loc[results_df["DTW_Distance"].idxmin()]
worst_match = results_df.loc[results_df["DTW_Distance"].idxmax()]

print(f"Best match: {best_match['Filename']} | DTW Distance: {best_match['DTW_Distance']:.2f}")
print(f"Worst match: {worst_match['Filename']} | DTW Distance: {worst_match['DTW_Distance']:.2f}")