In [1]:
!pip install pod5 pysam

Collecting pod5
  Downloading pod5-0.3.35-py3-none-any.whl.metadata (21 kB)
Collecting pysam
  Downloading pysam-0.23.3-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (1.7 kB)
Collecting deprecated~=1.2.18 (from pod5)
  Downloading Deprecated-1.2.18-py2.py3-none-any.whl.metadata (5.7 kB)
Collecting lib_pod5==0.3.35 (from pod5)
  Downloading lib_pod5-0.3.35-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.4 kB)
Collecting iso8601 (from pod5)
  Downloading iso8601-2.1.0-py3-none-any.whl.metadata (3.7 kB)
Collecting pyarrow~=22.0.0 (from pod5)
  Downloading pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Collecting vbz_h5py_plugin (from pod5)
  Downloading vbz_h5py_plugin-1.0.1-py3-none-any.whl.metadata (1.7 kB)
Collecting wrapt<2,>=1.10 (from deprecated~=1.2.18->pod5)
  Downloading wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (6.4 kB)
Downloading pod5-0.3.35-py3-none-any.whl (68 kB)
[2K   

In [3]:
# Imports & Global Setup
import os
import glob
import subprocess
from google.colab import drive
import pod5
import torch
INPUT_DIR = "/content/drive/MyDrive/Untitled folder/pod5"   # Here is your pod5 path
POD5_GLOB = "/content/drive/MyDrive/Untitled folder/pod5/*.pod5"   # Here is your pod5 path

OUTPUT_DIR = "/content/deepsignal_paper"
MODEL_NAME = "dna_r10.4.1_e8.2_400bps_hac@v5.0.0"
MODS = "6mA"
MODEL_CACHE = "/content/.dorado/models"
REF_PATH = os.path.join(OUTPUT_DIR, "listeria_ref.fa")
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MIN_QSCORE = 0
BATCHSIZE = 32 if "cuda" in DEVICE else 2
MAX_READS = 1000
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_CACHE, exist_ok=True)

def run_cmd(cmd, capture_output=True, shell=False):
    env = os.environ.copy()
    env['PATH'] = f"/content/dorado/bin:{env.get('PATH', '')}"
    if isinstance(cmd, str):
        if shell:
            proc = subprocess.run(cmd, shell=True, capture_output=capture_output, text=True, env=env)
        else:
            cmd = cmd.split()
            proc = subprocess.run(cmd, capture_output=capture_output, text=True, env=env)
    else:
        proc = subprocess.run(cmd, capture_output=capture_output, text=True, env=env)
    if proc.returncode != 0:
        stderr_snip = proc.stderr[:300] if proc.stderr is not None else ''
        print(f"RC={proc.returncode}: {stderr_snip}")
    return proc.returncode, proc.stdout, proc.stderr

# Install Dependencies
print("Installing dependencies...")
run_cmd(["apt", "update", "-qq"])
run_cmd(["apt", "install", "-y", "samtools"])
run_cmd(["pip", "install", "pod5", "pysam"])

# GPU check
try:
    run_cmd(["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv"], capture_output=False)
except FileNotFoundError:
    print("No GPU - using CPU")

# Download Reference Genome
print("Downloading reference genome...")
ref_url = "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/196/035/GCF_000196035.1_ASM19603v1/GCF_000196035.1_ASM19603v1_genomic.fna.gz"
if not os.path.exists(REF_PATH):
    run_cmd(["wget", "-q", "-O", f"{REF_PATH}.gz", ref_url])
    run_cmd(["gunzip", "-f", f"{REF_PATH}.gz"])
    run_cmd(["samtools", "faidx", REF_PATH])
print(f"Ref ready: {REF_PATH}")

# Dorado Install
print("Installing Dorado...")
dorado_bin = "/content/dorado/bin/dorado"
if not os.path.exists(dorado_bin):
    run_cmd(
        "wget -qO dorado.tar.gz "
        "https://cdn.oxfordnanoportal.com/software/analysis/dorado-1.2.0-linux-x64.tar.gz",
        shell=True
    )
    run_cmd("tar -xzf dorado.tar.gz", shell=True)
    folders = glob.glob("/content/dorado-1.2.0*")
    if len(folders) == 0:
        raise RuntimeError("Dorado extraction failed")
    extracted = folders[0]
    os.rename(extracted, "/content/dorado")
    print(f"Moved {extracted} -> /content/dorado")
rc, out, err = run_cmd("dorado --version")
print("Dorado version:", (out or err).strip())

# Model Download
print("Downloading model...")
base_model = "dna_r10.4.1_e8.2_400bps_hac@v5.0.0"
base_path = os.path.join(MODEL_CACHE, base_model)
if not os.path.exists(base_path):
    print(f"Downloading base {base_model}...")
    rc, _, _ = run_cmd(f"dorado download --model {base_model} --models-directory {MODEL_CACHE}")
    if rc != 0:
        tar_file = f"{base_model}.tar.gz"
        cdn_url = f"https://resources.nanoporetech.com/models/{base_model}.tar.gz"
        run_cmd(f"wget -qO {tar_file} {cdn_url}", shell=True)
        run_cmd(f"tar -xzf {tar_file} -C {MODEL_CACHE}", shell=True)
        os.remove(tar_file)
print(f"Base model ready: {base_path}")

# Process POD5 Files
print("Processing POD5 files...")
pod5_files = sorted(glob.glob(os.path.join(INPUT_DIR, "*.pod5")))
if not pod5_files:
    raise FileNotFoundError("No POD5s found")
print(f"Found {len(pod5_files)} POD5s")

bams = []
for i, pod5_file in enumerate(pod5_files):
    base = os.path.basename(pod5_file)
    print(f"Processing {base}")
    bam_raw = os.path.join(OUTPUT_DIR, f"batch_{i}.bam")
    bam_sorted = os.path.join(OUTPUT_DIR, f"batch_{i}.sorted.bam")

    if os.path.exists(bam_sorted) and os.path.getsize(bam_sorted) > 5000:
        print("Skipping - sorted BAM exists")
        bams.append(bam_sorted)
        continue

    cmd = [
        "/content/dorado/bin/dorado", "basecaller",
        MODEL_NAME,
        pod5_file,
        "--reference", REF_PATH,
        "--modified-bases", MODS,
        "--device", DEVICE,
        "--min-qscore", str(MIN_QSCORE),
        "--batchsize", str(BATCHSIZE),
        "--max-reads", str(MAX_READS),
        "--emit-moves",
        "--disable-read-splitting"
    ]
    print("Running Dorado...")
    with open(bam_raw, "wb") as fh:
        proc = subprocess.run(cmd, stdout=fh, stderr=subprocess.PIPE, text=True)
    if proc.returncode != 0:
        print(f"Dorado error:\n{proc.stderr[:500]}")
        continue
    print(f"Raw BAM created: {bam_raw}")

    print("Sorting BAM...")
    rc, out, err = run_cmd(["samtools", "sort", "-o", bam_sorted, bam_raw])
    if rc != 0:
        print("Failed to sort BAM")
        continue

    print("Indexing BAM...")
    rc, out, err = run_cmd(["samtools", "index", bam_sorted])
    if rc != 0:
        print("Failed to index BAM")

    try:
        os.remove(bam_raw)
    except:
        pass

    rc, out, _ = run_cmd(["samtools", "view", "-c", bam_sorted])
    print(f"Final sorted BAM: {bam_sorted} ({out.strip()} alignments)")
    bams.append(bam_sorted)

merged_bam = os.path.join(OUTPUT_DIR, "all_pod5s_merged.bam")
if len(bams) > 1:
    rc, out, err = run_cmd(["samtools", "merge", "-f", merged_bam] + bams)
    if rc == 0:
        print(f"Merged BAM created: {merged_bam}")
        rc, out, err = run_cmd(["samtools", "index", merged_bam])

print("Processing complete!")

Installing dependencies...
Downloading reference genome...
Ref ready: /content/deepsignal_paper/listeria_ref.fa
Installing Dorado...
Moved /content/dorado-1.2.0-linux-x64 -> /content/dorado
Dorado version: [2025-11-21 23:48:36.950] [info] Running: "--version"
1.2.0+f9443bb8
Downloading model...
Downloading base dna_r10.4.1_e8.2_400bps_hac@v5.0.0...
Base model ready: /content/.dorado/models/dna_r10.4.1_e8.2_400bps_hac@v5.0.0
Processing POD5 files...
Found 3 POD5s
Processing ATCC_19119__202309_batch1.pod5
Running Dorado...
Raw BAM created: /content/deepsignal_paper/batch_0.bam
Sorting BAM...
Indexing BAM...
Final sorted BAM: /content/deepsignal_paper/batch_0.sorted.bam (1039 alignments)
Processing ATCC_19119__202309_batch86.pod5
Running Dorado...
Raw BAM created: /content/deepsignal_paper/batch_1.bam
Sorting BAM...
Indexing BAM...
Final sorted BAM: /content/deepsignal_paper/batch_1.sorted.bam (1143 alignments)
Processing ATCC_19119__202309_batch90.pod5
Running Dorado...
Raw BAM created: 

In [4]:
import os
import glob
import re
import numpy as np
import pysam
from collections import defaultdict
from pod5 import Reader
import h5py

# ===================== CONFIG (EDIT BEFORE RUN) =====================
MERGED_BAM_PATH = "/content/deepsignal_paper/all_pod5s_merged.bam"  # output of PART A
OUTPUT_DATASET = "/content/deepsignal_paper/deepsignal_dataset_colab.h5"
WINDOW_SIZE = 17
SIGNAL_LENGTH = 360
MIN_COVERAGE = 5
MAX_READS = 500       # <<-- safe test; increase or set None to run all reads (may OOM)
MIN_MAPQ = 20           # mapping quality filter for reads
ML_HIGH = 0.7           # Dorado prob threshold considered methylated
ML_LOW = 0.3            # Dorado prob threshold considered unmethylated

# ================ Utility helpers ================

def safe_ml_probs(ml_tag):
    """Convert ML tag (bytes/list/int) to list of floats [0,1]."""
    if ml_tag is None:
        return None
    # bytes: treat as uint8 array
    if isinstance(ml_tag, (bytes, bytearray)):
        arr = np.frombuffer(ml_tag, dtype=np.uint8)
        return (arr.astype(np.float32) / 255.0).tolist()
    # list/tuple of ints
    try:
        return (np.array(list(ml_tag), dtype=np.float32) / 255.0).tolist()
    except Exception:
        try:
            # single int
            return [float(ml_tag) / 255.0]
        except Exception:
            return None


def parse_mm_tag_positions(mm_tag):
    """Robustly parse MM tag and return list of (base, positions[] as 0-based offsets in read).
    We don't try to be perfectly spec-conforming for every odd variant; we aim to extract numeric positions.
    """
    out = []
    if not mm_tag:
        return out
    # mm_tag may be like: 'A+m,0,5;'
    blocks = mm_tag.split(';')
    for block in blocks:
        if not block.strip():
            continue
        # split at first comma to get head like 'A+m' and remaining numeric tokens
        parts = block.split(',')
        head = parts[0]
        # base info could be like 'A+m' or 'C+m' or '6mA' variants; keep head
        base = head.split(':')[0] if ':' in head else head
        # collect digits from the rest
        pos_tokens = []
        for tok in parts[1:]:
            tok = tok.strip().lstrip('+')
            m = re.match(r'-?(\d+)', tok)
            if m:
                pos_tokens.append(int(m.group(1)))
        # Dorado/Tombo spec: these are deltas (cumulative)
        positions = []
        offset = 0
        for v in pos_tokens:
            offset += v
            positions.append(offset)
        out.append((base, positions))
    return out


# ================ POD5 selective loader ================


def load_pod5_signals_for_reads(pod5_glob, read_id_set):
    """Load only required reads' raw signals into dict {str(read_id): norm_signal}.
    Normalizes per-read using MAD (paper).
    """
    raw_signals = {}
    pod5_list = sorted(glob.glob(pod5_glob))
    if not pod5_list:
        raise FileNotFoundError(f"No POD5 files found at: {pod5_glob}")

    # normalize read_id_set to strings
    wanted = set(str(x) for x in read_id_set)

    for p5 in pod5_list:
        with Reader(p5) as r:
            for read in r.reads():
                rid = str(read.read_id)  # <-- CAST TO STRING
                if rid in wanted and rid not in raw_signals:
                    sig = np.array(read.signal, dtype=np.float32)
                    median = np.median(sig)
                    mad = np.median(np.abs(sig - median))
                    if mad > 0:
                        norm = (sig - median) / mad
                    else:
                        norm = sig - median
                    raw_signals[rid] = norm

    return raw_signals



# ================ MM/ML parsing to candidates (first pass) ================

def parse_dorado_modifications_to_candidates(read, require_motif='GATC'):
    """
    Return list of candidate dicts from a pysam read using MM/ML tags.
    Candidate fields:
      - read_id: str
      - read_pos: 0-based position in read
      - ref_pos: genome position
      - probability: float (0..1)
      - strand: '+' or '-'
      - motif: sequence context around candidate
    """
    candidates = []
    if not read.has_tag('MM'):
        return candidates

    mm_tag = read.get_tag('MM')
    ml_tag = read.get_tag('ML') if read.has_tag('ML') else None
    ml_probs = safe_ml_probs(ml_tag)
    seq = (read.query_sequence or "").upper()
    ref_start = read.reference_start if read.reference_start is not None else 0

    parsed = parse_mm_tag_positions(mm_tag)
    ml_idx = 0

    for base_head, positions in parsed:
        # Only accept adenine-type mods (A or variants containing 'A' or '6mA')
        if not re.search(r'A', base_head, re.I) and '6m' not in base_head.lower():
            ml_idx += len(positions)
            continue

        for pos in positions:
            if pos < 0 or pos >= len(seq):
                ml_idx += 1
                continue

            # ================== FIX: Motif window ==================
            ctx_len = 2  # number of bases before/after the candidate
            ctx_start = max(0, pos - ctx_len)
            ctx_end = min(len(seq), pos + ctx_len + 1)  # +1 because end index is exclusive
            motif = seq[ctx_start:ctx_end]
            # =======================================================

            prob = ml_probs[ml_idx] if (ml_probs and ml_idx < len(ml_probs)) else 0.5
            candidates.append({
                'read_id': str(read.query_name),
                'read_pos': pos,
                'ref_pos': ref_start + pos,
                'probability': float(prob),
                'strand': '-' if read.is_reverse else '+',
                'motif': motif
            })

            ml_idx += 1

    return candidates


# ================ Labeling & coverage helpers ================

def check_your_coverage(bam_file, max_sites=200000):
    coverage_per_site = defaultdict(int)
    with pysam.AlignmentFile(bam_file, 'rb') as bam:
        for read in bam.fetch():
            if read.is_unmapped:
                continue
            # limit for speed
            for pos in range(read.reference_start, read.reference_end):
                coverage_per_site[pos] += 1
            if len(coverage_per_site) > max_sites:
                break
    coverages = list(coverage_per_site.values())
    if not coverages:
        return []
    return coverages


def determine_labeling_strategy(bam_file):
    coverages = check_your_coverage(bam_file)
    if not coverages:
        return 'dorado_only'
    avg_coverage = np.mean(coverages)
    high_cov_sites = sum(c >= MIN_COVERAGE for c in coverages)
    total_sites = len(coverages)
    if avg_coverage >= 10 and high_cov_sites > total_sites * 0.6:
        return 'consensus_plus_dorado'
    elif avg_coverage >= 5 and high_cov_sites > total_sites * 0.3:
        return 'dorado_primary'
    else:
        return 'dorado_only'


def simple_effective_labeling(candidates, strategy='dorado_only'):
    site_probs = defaultdict(list)
    for c in candidates:
        site_probs[c['ref_pos']].append(c['probability'])
    high_conf = []
    # dorado thresholds
    for c in candidates:
        p = c['probability']
        if p >= ML_HIGH:
            c['label'] = 1; c['method'] = 'dorado_high'; high_conf.append(c)
        elif p <= ML_LOW:
            c['label'] = 0; c['method'] = 'dorado_low'; high_conf.append(c)
        else:
            c['label'] = None; c['method'] = 'ambiguous'
    # consensus
    if strategy in ('consensus_plus_dorado', 'dorado_primary'):
        for c in candidates:
            probs = site_probs[c['ref_pos']]
            if len(probs) >= MIN_COVERAGE:
                freq = np.mean(probs)
                if freq == 1.0:
                    c['label'] = 1; c['method'] = 'consensus_meth'
                elif freq == 0.0:
                    c['label'] = 0; c['method'] = 'consensus_unmeth'
                if c not in high_conf:
                    high_conf.append(c)
    return high_conf

# ================ Feature extraction (two-pass) ================

def resample_to_length(arr, length):
    """Simple 1D resample (linear interpolation) to `length` values."""
    if len(arr) == 0:
        return np.zeros(length, dtype=np.float32)
    if len(arr) == length:
        return arr.astype(np.float32)
    x_old = np.linspace(0, 1, num=len(arr))
    x_new = np.linspace(0, 1, num=length)
    return np.interp(x_new, x_old, arr).astype(np.float32)


def extract_paper_features_colab(bam_path=MERGED_BAM_PATH, pod5_glob=POD5_GLOB, max_reads=MAX_READS):
    """Memory-safe, paper-accurate feature extraction for Colab.
    Returns: seq_features (n,17,7) where 7 = one-hot(4)+mean+std+count,
             sig_features (n,360), labels (n,), metadata list
    """
    # 2) first pass to collect candidate read_ids (and candidates) up to max_reads
    bam = pysam.AlignmentFile(bam_path, 'rb')
    candidate_list = []
    read_ids_needed = set()
    for i, read in enumerate(bam.fetch(until_eof=True)):
        if max_reads and i >= max_reads:
            break
        if read.is_unmapped or (read.mapping_quality is not None and read.mapping_quality < MIN_MAPQ):
            continue
        cands = parse_dorado_modifications_to_candidates(read)
        if cands:
           candidate_list.extend(cands)
           read_ids_needed.update(str(c['read_id']) for c in cands)


    bam.close()

    if not candidate_list:
        return np.empty((0, WINDOW_SIZE, 7), dtype=np.float32), np.empty((0, SIGNAL_LENGTH), dtype=np.float32), np.empty(0, dtype=int), []

    # 3) load only required pod5 signals
    raw_signals = load_pod5_signals_for_reads(pod5_glob, read_ids_needed)
    print(f"Loaded raw signals for {len(raw_signals)} reads (needed {len(read_ids_needed)}).")

    # 4) second pass: compute feature matrices
    bam = pysam.AlignmentFile(bam_path, 'rb')
    seq_features = []
    sig_features = []
    labels = []
    metadata = []

    # Build index: read_id -> list of candidate dicts (to avoid re-parsing MM again)
    per_read_candidates = defaultdict(list)
    for c in candidate_list:
        per_read_candidates[str(c['read_id'])].append(c)

    for i, read in enumerate(bam.fetch(until_eof=True)):
        if max_reads and i >= max_reads:
            break
        rid = read.query_name
        if rid not in per_read_candidates:
            continue
        seq = (read.query_sequence or "").upper()
        n_bases = len(seq)
        # get raw signal
        if rid in raw_signals:
            raw_sig = raw_signals[rid]
            if read.is_reverse:
                raw_sig = raw_sig[::-1]
        else:
            raw_sig = None
        # event tags (if present)
        ev = None
        if read.has_tag('EV') and read.has_tag('ES') and read.has_tag('ED'):
            try:
                ev = np.array(read.get_tag('EV'), dtype=np.float32)
                es = np.array(read.get_tag('ES'), dtype=np.float32)
                ed = np.array(read.get_tag('ED'), dtype=np.float32)
                # if reverse, flip to query-order
                if read.is_reverse:
                    ev = ev[::-1]; es = es[::-1]; ed = ed[::-1]
            except Exception:
                ev = None

        # for each candidate in this read make features
        for cand in per_read_candidates[rid]:
            read_pos = cand['read_pos']
            # ========== SEQUENCE FEATURES (17 x 7) ==========
            start_seq = max(0, read_pos - (WINDOW_SIZE//2))
            end_seq = min(n_bases, read_pos + (WINDOW_SIZE//2) + 1)
            window_seq = seq[start_seq:end_seq]
            # pad both sides to CENTER candidate at index WINDOW_SIZE//2
            left_pad = (WINDOW_SIZE//2) - (read_pos - start_seq)
            right_pad = WINDOW_SIZE - len(window_seq) - max(0, left_pad)
            if left_pad > 0:
                window_seq = ('N'*left_pad) + window_seq
            if right_pad > 0:
                window_seq = window_seq + ('N'*right_pad)

            # Use event stats if present else compute from raw per-base slices
            base_stats = []  # [(mean,std,count) per base]
            if ev is not None and len(ev) >= n_bases:
                # map base->event index (simple scaling)
                ratio = len(ev) / max(1, n_bases)
                for b_idx in range(WINDOW_SIZE):
                    base_idx = start_seq + b_idx - (left_pad if left_pad>0 else 0)
                    ev_idx = int(round(base_idx * ratio)) if base_idx>=0 and base_idx<n_bases else None
                    if ev_idx is not None and 0 <= ev_idx < len(ev):
                        m = float(ev[ev_idx])
                        s = float(es[ev_idx]) if 'es' in locals() else 0.0
                        d = float(ed[ev_idx]) if 'ed' in locals() else 0.0
                    else:
                        m = s = d = 0.0
                    base_stats.append((m, s, d))
            elif raw_sig is not None and n_bases > 0:
                samples_per_base = max(1, int(round(len(raw_sig) / n_bases)))
                for b_idx in range(WINDOW_SIZE):
                    base_idx = start_seq + b_idx - (left_pad if left_pad>0 else 0)
                    if base_idx < 0 or base_idx >= n_bases:
                        base_stats.append((0.0, 0.0, 0.0))
                    else:
                        sstart = int(base_idx * samples_per_base)
                        send = min(len(raw_sig), sstart + samples_per_base)
                        slice_ = raw_sig[sstart:send] if send> sstart else np.array([], dtype=np.float32)
                        if slice_.size == 0:
                            base_stats.append((0.0,0.0,0.0))
                        else:
                            base_stats.append((float(np.mean(slice_)), float(np.std(slice_)), int(slice_.size)))
            else:
                # fallback zeros
                base_stats = [(0.0,0.0,0.0)] * WINDOW_SIZE

            # build one-hot + stats
            seq_matrix = np.zeros((WINDOW_SIZE, 7), dtype=np.float32)
            for b in range(WINDOW_SIZE):
                base = window_seq[b] if b < len(window_seq) else 'N'
                one_hot = [0.0, 0.0, 0.0, 0.0]
                if base == 'A': one_hot[0] = 1.0
                elif base == 'C': one_hot[1] = 1.0
                elif base == 'G': one_hot[2] = 1.0
                elif base == 'T': one_hot[3] = 1.0
                m,s,d = base_stats[b]
                seq_matrix[b, :4] = one_hot
                seq_matrix[b, 4] = m
                seq_matrix[b, 5] = s
                seq_matrix[b, 6] = d

            # ========== SIGNAL FEATURES (360) ==========
            # Prefer EV if present (resample to 360), else resample raw_sig around center sample
            if ev is not None and len(ev) > 0:
                # center event index
                ratio = len(ev) / max(1, n_bases)
                center_event = int(round(read_pos * ratio))
                start_e = max(0, center_event - (SIGNAL_LENGTH//2))
                end_e = min(len(ev), start_e + SIGNAL_LENGTH)
                window_ev = ev[start_e:end_e]
                if len(window_ev) < SIGNAL_LENGTH:
                    window_ev = np.pad(window_ev, (0, SIGNAL_LENGTH - len(window_ev)), 'constant')
                sig_vec = resample_to_length(window_ev, SIGNAL_LENGTH)
            elif raw_sig is not None:
                # approximate center in samples
                samples_per_base = max(1, int(round(len(raw_sig) / max(1, n_bases))))
                center_sample = int(read_pos * samples_per_base)
                start_s = max(0, center_sample - (SIGNAL_LENGTH//2))
                end_s = min(len(raw_sig), start_s + SIGNAL_LENGTH)
                window_sig = raw_sig[start_s:end_s]
                if len(window_sig) < SIGNAL_LENGTH:
                    window_sig = np.pad(window_sig, (0, SIGNAL_LENGTH - len(window_sig)), 'constant')
                sig_vec = resample_to_length(window_sig, SIGNAL_LENGTH)
            else:
                sig_vec = np.zeros((SIGNAL_LENGTH,), dtype=np.float32)

            # MAD normalization per-signal vector (guarded)
            med = np.median(sig_vec)
            mad = np.median(np.abs(sig_vec - med))
            if mad > 0:
                sig_vec = (sig_vec - med) / mad
            else:
                sig_vec = sig_vec - med

            # collect
            seq_features.append(seq_matrix)
            sig_features.append(sig_vec)
            labels.append(1 if cand.get('probability', 0.5) >= ML_HIGH else 0 if cand.get('probability',0.5) <= ML_LOW else -1)
            metadata.append(cand)

    bam.close()

    seq_features = np.array(seq_features, dtype=np.float32) if seq_features else np.empty((0, WINDOW_SIZE, 7), dtype=np.float32)
    sig_features = np.array(sig_features, dtype=np.float32) if sig_features else np.empty((0, SIGNAL_LENGTH), dtype=np.float32)
    labels = np.array(labels, dtype=np.int32) if labels else np.empty((0,), dtype=np.int32)

    print(f"Extraction done: seq={seq_features.shape}, sig={sig_features.shape}, labels={labels.shape}")
    return seq_features, sig_features, labels, metadata

# ================ Save utility ================

def save_dataset(seq_data, sig_data, y_true, metadata, output_path=OUTPUT_DATASET):
    with h5py.File(output_path, 'w') as f:
        f.create_dataset('sequence_features', data=seq_data, compression='gzip')
        f.create_dataset('signal_features', data=sig_data, compression='gzip')
        f.create_dataset('labels', data=y_true, compression='gzip')
        mg = f.create_group('metadata')
        for i, meta in enumerate(metadata):
            g = mg.create_group(f'sample_{i}')
            for k, v in meta.items():
                if isinstance(v, (str, int, float)):
                    g.attrs[k] = v
                elif isinstance(v, np.ndarray):
                    g.create_dataset(k, data=v)
        f.attrs['paper_reference'] = 'DeepSignal (paper-accurate features, adapted)'
        f.attrs['window_size'] = WINDOW_SIZE
        f.attrs['signal_length'] = SIGNAL_LENGTH
        f.attrs['n_samples'] = seq_data.shape[0]
    print(f"Saved dataset: {output_path}")

# ================ Runner ================

def run_part_b_colab(bam_path=MERGED_BAM_PATH, pod5_glob=POD5_GLOB, max_reads=MAX_READS, save_path=OUTPUT_DATASET):
    seq, sig, labels, meta = extract_paper_features_colab(bam_path, pod5_glob, max_reads=max_reads)
    if seq.shape[0] == 0:
        print('No extracted samples — check MM tags / POD5 availability / motif filters')
        return
    save_dataset(seq, sig, labels, meta, save_path)
    print('PART B (Colab-optimized) complete.')


# If called as script, run with defaults (safe test)
if __name__ == '__main__':
    run_part_b_colab()

Loaded raw signals for 340 reads (needed 340).
Extraction done: seq=(1422974, 17, 7), sig=(1422974, 360), labels=(1422974,)
Saved dataset: /content/deepsignal_paper/deepsignal_dataset_colab.h5
PART B (Colab-optimized) complete.


In [5]:
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import defaultdict

DATA_PATH = "/content/deepsignal_paper/deepsignal_dataset_colab.h5"
SAVE_PATH = "/content/deepsignal_paper/deepsignal_balanced.h5"
BATCH_SIZE = 256
NUM_EPOCHS = 10
RANDOM_SEED = 42

class DeepSignalDataset(Dataset):
    def __init__(self, h5_path, read_ids=None):
        self.h5_path = h5_path
        with h5py.File(h5_path, 'r') as f:
            self.sequence_features = f['sequence_features'][:]
            self.signal_features = f['signal_features'][:]
            self.labels = f['labels'][:]
            n = self.labels.shape[0]
            self.read_ids = [f[f'metadata/sample_{i}'].attrs['read_id'] for i in range(n)]
        if read_ids is not None:
            id_set = set(read_ids)
            self.indices = [i for i, rid in enumerate(self.read_ids) if rid in id_set]
        else:
            self.indices = list(range(len(self.labels)))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        seq = torch.tensor(self.sequence_features[actual_idx], dtype=torch.float32)
        sig = torch.tensor(self.signal_features[actual_idx], dtype=torch.float32)
        label = torch.tensor(self.labels[actual_idx], dtype=torch.float32)
        return seq, sig, label

def split_data_by_read_ids(h5_path, test_size=0.2, val_size=0.1):
    with h5py.File(h5_path, 'r') as f:
        n_samples = f['labels'].shape[0]
        read_ids = []
        labels_list = []
        for i in range(n_samples):
            sample_meta = f[f'metadata/sample_{i}']
            read_ids.append(sample_meta.attrs['read_id'])
            labels_list.append(f['labels'][i])
    read_id_to_samples = {}
    for i, read_id in enumerate(read_ids):
        if read_id not in read_id_to_samples:
            read_id_to_samples[read_id] = []
        read_id_to_samples[read_id].append({
            'sample_index': i,
            'label': labels_list[i]
        })
    unique_read_ids = list(read_id_to_samples.keys())
    train_reads, test_reads = train_test_split(
        unique_read_ids,
        test_size=test_size,
        random_state=RANDOM_SEED,
        shuffle=True
    )
    train_reads, val_reads = train_test_split(
        train_reads,
        test_size=val_size/(1-test_size),
        random_state=RANDOM_SEED,
        shuffle=True
    )
    train_samples = sum(len(read_id_to_samples[rid]) for rid in train_reads)
    val_samples = sum(len(read_id_to_samples[rid]) for rid in val_reads)
    test_samples = sum(len(read_id_to_samples[rid]) for rid in test_reads)
    train_set = set(train_reads)
    val_set = set(val_reads)
    test_set = set(test_reads)
    assert len(train_set & val_set) == 0
    assert len(train_set & test_set) == 0
    assert len(val_set & test_set) == 0
    return train_reads, val_reads, test_reads

def create_balanced_dataset(h5_path, output_path, samples_per_class=35000):
    with h5py.File(h5_path, 'r') as f:
        labels = f['labels'][:]
        read_ids = [f[f'metadata/sample_{i}'].attrs['read_id'] for i in range(len(labels))]
        valid_mask = (labels == 0) | (labels == 1)
        valid_indices = np.where(valid_mask)[0]
        methylated_idx = valid_indices[labels[valid_indices] == 1]
        unmethylated_idx = valid_indices[labels[valid_indices] == 0]
        meth_size = min(samples_per_class, len(methylated_idx))
        unmeth_size = min(samples_per_class, len(unmethylated_idx))
        methylated_sample = np.random.choice(methylated_idx, meth_size, replace=False)
        unmethylated_sample = np.random.choice(unmethylated_idx, unmeth_size, replace=False)
        balanced_indices = np.concatenate([methylated_sample, unmethylated_sample])
        sorted_indices = np.sort(balanced_indices)
        seq_data = f['sequence_features'][sorted_indices]
        sig_data = f['signal_features'][sorted_indices]
        labels_data = f['labels'][sorted_indices]
        metadata_list = []
        for idx in sorted_indices:
            meta_group = f[f'metadata/sample_{idx}']
            metadata_list.append({
                'read_id': meta_group.attrs['read_id'],
                'motif': meta_group.attrs.get('motif', ''),
                'probability': meta_group.attrs.get('probability', 0.5),
                'read_pos': meta_group.attrs.get('read_pos', 0),
                'ref_pos': meta_group.attrs.get('ref_pos', 0),
                'strand': meta_group.attrs.get('strand', '+')
            })
    with h5py.File(output_path, 'w') as f_out:
        shuffle_idx = np.random.permutation(len(seq_data))
        f_out.create_dataset('sequence_features', data=seq_data[shuffle_idx], compression='gzip')
        f_out.create_dataset('signal_features', data=sig_data[shuffle_idx], compression='gzip')
        f_out.create_dataset('labels', data=labels_data[shuffle_idx], compression='gzip')
        meta_group = f_out.create_group('metadata')
        for new_idx, old_idx in enumerate(shuffle_idx):
            new_meta = meta_group.create_group(f'sample_{new_idx}')
            for key, value in metadata_list[old_idx].items():
                new_meta.attrs[key] = value
    return output_path

class DeepSignalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.signal_cnn_head = nn.Sequential(
            nn.Conv1d(1, 64, 7, stride=2, padding=3),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(3, stride=2, padding=1)
        )
        blocks = []
        in_ch = 64
        out_ch = 32
        for _ in range(3):
            blocks.append(InceptionBlock(in_ch, out_ch))
            in_ch = out_ch * 4
        self.signal_inception = nn.Sequential(*blocks)
        self.signal_tail = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(in_ch, 128),
            nn.ReLU()
        )
        self.sequence_brnn = nn.Sequential(
            nn.LSTM(7, 32, batch_first=True, bidirectional=True, num_layers=2),
            nn.Linear(64, 128),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, seq, sig):
        sig = self.signal_cnn_head(sig.unsqueeze(1))
        sig = self.signal_inception(sig)
        sig = self.signal_tail(sig)
        lstm_out, _ = self.sequence_brnn[0](seq)
        seq_feat = self.sequence_brnn[1](lstm_out[:, -1, :])
        combined = torch.cat([sig, seq_feat], dim=1)
        return self.classifier(combined).squeeze()

class InceptionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = nn.Conv1d(in_channels, out_channels, 1)
        self.branch2 = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.Conv1d(out_channels, out_channels, 3, padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.Conv1d(out_channels, out_channels, 5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool1d(3, stride=1, padding=1),
            nn.Conv1d(in_channels, out_channels, 1)
        )

    def forward(self, x):
        return torch.cat([
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x)
        ], dim=1)

def train_model():
    balanced_path = create_balanced_dataset(DATA_PATH, SAVE_PATH, samples_per_class=35000)
    train_reads, val_reads, test_reads = split_data_by_read_ids(balanced_path)
    train_dataset = DeepSignalDataset(balanced_path, train_reads)
    val_dataset = DeepSignalDataset(balanced_path, val_reads)
    test_dataset = DeepSignalDataset(balanced_path, test_reads)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = DeepSignalModel().to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0
        for batch_idx, (seq, sig, labels) in enumerate(train_loader):
            seq, sig, labels = seq.to(device), sig.to(device), labels.float().to(device)
            optimizer.zero_grad()
            outputs = model(seq, sig)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for seq, sig, labels in val_loader:
                seq, sig, labels = seq.to(device), sig.to(device), labels.float().to(device)
                outputs = model(seq, sig)
                val_loss += criterion(outputs, labels).item()
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {train_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}')
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for seq, sig, labels in test_loader:
            seq, sig, labels = seq.to(device), sig.to(device), labels.float().to(device)
            outputs = model(seq, sig)
            preds = (outputs > 0.5).float()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall = recall_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    cm = confusion_matrix(all_labels, all_preds)
    print(f"Test Results:")
    print(f"Accuracy:  {accuracy:.3f}")
    print(f"Precision: {precision:.3f}")
    print(f"Recall:    {recall:.3f}")
    print(f"F1-Score:  {f1:.3f}")
    print(f"Confusion Matrix:\n{cm}")
    return model

if __name__ == "__main__":
    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    model = train_model()



Epoch [1/10], Train Loss: 0.6929, Val Loss: 0.6938
Epoch [2/10], Train Loss: 0.6908, Val Loss: 0.6950
Epoch [3/10], Train Loss: 0.6886, Val Loss: 0.7015
Epoch [4/10], Train Loss: 0.6873, Val Loss: 0.7029
Epoch [5/10], Train Loss: 0.6865, Val Loss: 0.7146
Epoch [6/10], Train Loss: 0.6847, Val Loss: 0.7005
Epoch [7/10], Train Loss: 0.6832, Val Loss: 0.7117
Epoch [8/10], Train Loss: 0.6819, Val Loss: 0.6946
Epoch [9/10], Train Loss: 0.6805, Val Loss: 0.7002
Epoch [10/10], Train Loss: 0.6796, Val Loss: 0.6978
Test Results:
Accuracy:  0.507
Precision: 0.494
Recall:    0.247
F1-Score:  0.329
Confusion Matrix:
[[5573 1787]
 [5328 1748]]
