In [None]:
using Serialization, StatsPlots, CSV, DataFrames, CircularArrays
# Read the two serialized prediction sets from disk.
f1_fiveprime_3x3_8_6 = deserialize("f1_fiveprime_3x3_8_6.predictions")
f1_threeprime_3x3_8_6 = deserialize("f1_threeprime_3x3_8_6.predictions")

# `first` selects the entry stored under the `_fwd` name, `last` the one stored under `_rev`.
# For the selected entry, take the first field of every record in both prediction sets and
# keep the element-wise maximum, wrapped in a CircularVector.
f1_3x3_8_6_fwd, f1_3x3_8_6_rev = map((first, last)) do pick
    fiveprime_scores = first.(pick(f1_fiveprime_3x3_8_6))
    threeprime_scores = first.(pick(f1_threeprime_3x3_8_6))
    CircularVector(max.(fiveprime_scores, threeprime_scores))
end

# --- Load known footprint data ---
# Parse the footprint table; lines starting with "#" are skipped and the ten column
# names are supplied explicitly.
known = DataFrame(CSV.File(
    "Known PPR Binding Sites AP000423.gff";
    comment = "#",
    header = [
        "accession", "software", "feature", "start", "stop",
        "score", "strand", "phase", "attributes", "present",
    ],
))
# Drop every row whose `present` column does not contain the substring "present".
filter!(:present => flag -> occursin("present", flag), known)

# --- Reverse complement coordinate conversion ---
# Rewrite the coordinates of "-" strand footprints in place. Both new values are computed
# from the old ones before either is stored (tuple assignment), so start/stop swap roles.
# NOTE(review): running this loop twice on the same `known` converts the rows back.
for footprint in eachrow(known)
    footprint.strand == "-" || continue
    footprint.start, footprint.stop =
        154478 - footprint.stop + 1, 154478 - footprint.start + 1
end

# --- Utility functions ---
"""
    includes(r1::UnitRange, r2::UnitRange)

Return `true` when every element of `r2` lies inside `r1`.

An empty `r2` is always considered included.
"""
function includes(r1::UnitRange, r2::UnitRange)
    isempty(r2) && return true
    return first(r1) <= first(r2) && last(r2) <= last(r1)
end

"""
    ConfusionMatrix(tp, fp, tn, fn)

Per-threshold confusion-matrix counters.

Entry `i` of each vector holds the count for the `i`-th threshold, so all four
vectors must have the same length.

# Fields
- `tp`: true-positive counts
- `fp`: false-positive counts
- `tn`: true-negative counts
- `fn`: false-negative counts

# Throws
- `DimensionMismatch` if the four vectors do not share one length.
"""
struct ConfusionMatrix
    tp::Vector{Int}
    fp::Vector{Int}
    tn::Vector{Int}
    fn::Vector{Int}

    function ConfusionMatrix(tp, fp, tn, fn)
        # The counters are always indexed with one shared threshold index, so a length
        # mismatch would otherwise only surface later as a BoundsError mid-accumulation.
        n = length(tp)
        if !(length(fp) == n && length(tn) == n && length(fn) == n)
            throw(DimensionMismatch(
                "tp, fp, tn and fn must have equal lengths, got " *
                "$(length(tp)), $(length(fp)), $(length(tn)), $(length(fn))",
            ))
        end
        # `new` converts each argument to Vector{Int}, like the default constructor did.
        return new(tp, fp, tn, fn)
    end
end

"""
    addcounts!(ref::DataFrame, predictions::CircularVector{Float32}, counts::ConfusionMatrix)

Accumulate per-threshold confusion counts for `predictions` into `counts` and return it.

For each index `j` produced by `enumerate(predictions)`, the window `j:j+49` is a positive
when it fully contains at least one `start:stop` interval from the rows of `ref`, and a
negative otherwise. The score at `j` is then compared against every threshold from 1.0
down to 0.0 in steps of 0.001; threshold number `i` updates entry `i` of the matching
counter (`tp`/`fn` for positives, `fp`/`tn` for negatives).
"""
function addcounts!(ref::DataFrame, predictions::CircularVector{Float32}, counts::ConfusionMatrix)
    thresholds = range(1.0, 0.0; step = -0.001)
    for (j, p) in enumerate(predictions)
        window = j:j+49
        # `any` stops at the first footprint contained in the window.
        ispositive = any(row -> includes(window, row.start:row.stop), eachrow(ref))
        # The class of the window is fixed, so choose the counter pair once per position.
        called, notcalled = ispositive ? (counts.tp, counts.fn) : (counts.fp, counts.tn)
        for (i, t) in enumerate(thresholds)
            if p >= t
                called[i] += 1
            else
                notcalled[i] += 1
            end
        end
    end
    return counts
end

"""
    ROC(fwd_predictions::CircularVector{Float32}, rev_predictions::CircularVector{Float32})

Return `(fpr, tpr)`, the false- and true-positive rates at 1001 thresholds.

Forward predictions are scored against the "+" rows of the global `known` table and
reverse predictions against its "-" rows; both contribute to the same counters.
"""
function ROC(fwd_predictions::CircularVector{Float32}, rev_predictions::CircularVector{Float32})
    nthresholds = 1001
    counts = ConfusionMatrix(
        zeros(Int, nthresholds), zeros(Int, nthresholds),
        zeros(Int, nthresholds), zeros(Int, nthresholds),
    )
    for (strand, predictions) in (("+", fwd_predictions), ("-", rev_predictions))
        strand_ref = filter(row -> row.strand == strand, known)
        addcounts!(strand_ref, predictions, counts)
    end
    negatives = counts.tn .+ counts.fp
    positives = counts.tp .+ counts.fn
    fpr = 1 .- counts.tn ./ negatives
    tpr = counts.tp ./ positives
    return fpr, tpr
end

# ROC curve as an (fpr, tpr) tuple, computed from the forward and reverse score vectors.
roc = ROC(f1_3x3_8_6_fwd, f1_3x3_8_6_rev)

([0.0005837408709413738, 0.0029316763740611096, 0.0032981359208188454, 0.0035186602498410435, 0.0037326985691862324, 0.003878633786921659, 0.004034298019172655, 0.004180233236907971, 0.004332654464320407, 0.004371570522383239  …  0.08082216658667252, 0.08826810569601373, 0.09777983888751962, 0.1096524796014996, 0.12670419904266494, 0.14898688528843285, 0.21568576580316257, 0.25085615327738064, 0.30983019626665287, 1.0], [0.02666666666666667, 0.16166666666666665, 0.17333333333333334, 0.18666666666666668, 0.19, 0.19333333333333333, 0.19333333333333333, 0.195, 0.195, 0.19666666666666666  …  0.52, 0.52, 0.5233333333333333, 0.5283333333333333, 0.535, 0.5366666666666666, 0.575, 0.58, 0.615, 1.0])

In [21]:
# Disabled alternative: choose the threshold whose ROC point maximizes TPR + 1 - FPR
# (`last(roc)` is the TPR vector, `first(roc)` the FPR vector).
#bestt = range(1.0, 0.0; step = -0.001)[argmax(last(roc) .+ 1 .- first(roc))]
# Score cutoff, fixed by hand instead of being derived from `roc`.
bestt = 0.99

0.99

## distillation ##

In [None]:
using CSV, DataFrames, CircularArrays

# --- Load GFF to get structural RNA coordinates ---
# Parse the annotation table; "#" lines are skipped and the nine column names are
# supplied explicitly.
Atgff = DataFrame(CSV.File(
    "AP000423.gff";
    comment = "#",
    header = [
        "accession", "software", "feature", "start", "stop",
        "score", "strand", "phase", "attributes",
    ],
))

# Genome length is the `stop` of the first row whose feature is "region".
genome_length = first(Atgff.stop[Atgff.feature .== "region"])

"""
    rc(x::Real)

Map coordinate `x` to `genome_length - x + 1`, using the global `genome_length`.
"""
function rc(x::Real)
    return genome_length - x + 1
end

# --- Identify structural features (tRNA, rRNA, trn-containing introns) ---
# Keep rows that are rRNA or tRNA features, plus any row whose attributes mention "trn"
# (case-insensitive), e.g. introns like trnK-UUU.
structuralRNA = filter(Atgff) do row
    row.feature in ("rRNA", "tRNA") || occursin(r"(?i)trn", row.attributes)
end

# Order the selected features by start coordinate.
sort!(structuralRNA, :start)

# --- Build boolean masks for forward and reverse strands ---
# Mark every position covered by a structural feature, in forward coordinates.
mask_fwd = falses(genome_length)
for row in eachrow(structuralRNA)
    mask_fwd[row.start:row.stop] .= true
end
# Position i maps to genome_length - i + 1 on the other strand, so the reverse-strand
# mask is the forward mask read back to front.
mask_rev = reverse(mask_fwd)

# --- Distillation with structural exclusion ---
# One (is_forward, position, score) tuple per position: forward strand first, then reverse.
stranded_predictions = Tuple{Bool, Int, Float32}[]
append!(stranded_predictions, ((true, i, p) for (i, p) in enumerate(f1_3x3_8_6_fwd)))
append!(stranded_predictions, ((false, i, p) for (i, p) in enumerate(f1_3x3_8_6_rev)))

# Highest score first.
sort!(stranded_predictions; by = last, rev = true)

# Per-strand record of positions already covered by a processed prediction window.
fwdmask = CircularVector(falses(length(f1_3x3_8_6_fwd)))
revmask = CircularVector(falses(length(f1_3x3_8_6_rev)))

distilled_predictions = Tuple{Bool, Int, Float32}[]

# Walk the predictions from best to worst score and keep one only when its window
# touches neither a structural region nor a window that was processed earlier.
for p in stranded_predictions
    is_fwd, pos, score = p
    # Input is sorted by descending score, so nothing after this can pass the cutoff.
    score < bestt && break

    # NOTE(review): the window end is clamped to genome_length, so windows starting near
    # the end are truncated and do not wrap, although the covered masks are circular.
    # Confirm this is intended.
    window = pos:clamp(pos + 49, 1, genome_length)
    structural, covered = is_fwd ? (mask_fwd, fwdmask) : (mask_rev, revmask)

    # Windows overlapping rRNA/tRNA/trn regions are dropped and leave no mark behind.
    any(structural[window]) && continue

    any(covered[window]) || push!(distilled_predictions, p)
    # The window is marked even when the prediction was rejected for overlap.
    covered[window] .= true
end

distilled_predictions

99-element Vector{Tuple{Bool, Int64, Float32}}:
 (1, 32511, 1.0)
 (1, 33700, 1.0)
 (1, 48412, 1.0)
 (1, 54874, 1.0)
 (1, 57002, 1.0)
 (1, 65648, 1.0)
 (1, 67136, 1.0)
 (1, 93556, 1.0)
 (1, 115563, 1.0)
 (1, 141463, 1.0)
 ⋮
 (1, 33966, 0.99601704)
 (0, 93306, 0.99531484)
 (1, 87288, 0.9948232)
 (0, 38897, 0.99469805)
 (1, 141029, 0.9927418)
 (1, 8350, 0.9926473)
 (1, 67767, 0.9923679)
 (1, 95741, 0.99234056)
 (1, 152040, 0.991689)

In [23]:
using CSV, DataFrames
# Build the output table column by column from the (is_forward, position, score) tuples.
# `stop` is 49 past `start`; the strand flag becomes "+" for true and "-" for false.
df = DataFrame(
    start = Int[p[2] for p in distilled_predictions],
    stop = Int[p[2] + 49 for p in distilled_predictions],
    score = Float32[p[3] for p in distilled_predictions],
    strand = String[p[1] ? "+" : "-" for p in distilled_predictions],
)
# Save the table as CSV.
CSV.write("distilled_predictions_filtered_masking.csv", df)

"distilled_predictions_filtered_masking.csv"