# In [None]:
using NextGenSeqUtils, PyPlot, StatsBase, DPMeansClustering, RobustAmpliconDenoising, TableReader

"""
    demux_gel_plot(demux_dic; keys_to_plot = nothing, labels = nothing, alpha = 0.025,
                   height = 6, grid_lines_every = 1000)

Draw a "virtual gel": one jittered column of points per key, where each point is the
length of the first element of one entry in `demux_dic[key]`.

# Arguments
- `demux_dic`: dictionary whose values are collections of entries; only
  `length(entry[1])` is used for plotting.

# Keywords
- `keys_to_plot`: keys to draw, left to right; defaults to `keys(demux_dic)`. A key that
  is absent from `demux_dic` still occupies a column but draws no points.
- `labels`: optional labels indexed by column position; each is suffixed with the number
  of entries in that column and drawn rotated above the points.
- `alpha`: point transparency.
- `height`: figure height; the figure width is the number of columns plus one.
- `grid_lines_every`: spacing (in length units) of faint horizontal red reference lines.
"""
function demux_gel_plot(demux_dic; keys_to_plot = nothing, labels = nothing, alpha = 0.025,
                        height = 6, grid_lines_every = 1000)
    if isnothing(keys_to_plot)
        keys_to_plot = keys(demux_dic)
    end
    x_pos = Float64[]
    points = Int[]
    label_text = String[]
    count = 1
    for k in keys_to_plot
        l = 0
        if haskey(demux_dic, k)
            entries = demux_dic[k]
            append!(points, [length(e[1]) for e in entries])
            # Horizontal jitter within [count, count + 0.7) so overlapping points spread out.
            append!(x_pos, [count + rand() * 0.7 for _ in entries])
            l = length(entries)
        end
        if !isnothing(labels)
            push!(label_text, labels[count] * " $l")
        end
        count += 1
    end
    # `maximum` throws on an empty vector; with no points at all, fall back to 0 so the
    # (empty) axes are still drawn instead of erroring.
    max_len = isempty(points) ? 0 : maximum(points)
    figure(figsize = (count, height))
    for g in grid_lines_every:grid_lines_every:max_len
        plot([0.5, count + 0.5], [g, g], alpha = 0.05, color = "red")
    end
    plot(x_pos, points, ".", alpha = alpha, color = "grey")
    if !isnothing(labels)
        for i in 1:count-1
            text(i + 0.33, max_len * 1.1, label_text[i], rotation = 90)
        end
    end
    xlim(0.5, count + 0.5)
    xticks([])
    tight_layout(w_pad = 0, h_pad = 0, pad = 0)
end


"""
    sizes_to_names(sizes; prefix = "")

Build one record name per entry of `sizes`, of the form `"<prefix>seq_<i>_<size>"`,
where `i` is the entry's 1-based position.
"""
sizes_to_names(sizes; prefix = "") =
    [string(prefix, "seq_", i, "_", s) for (i, s) in enumerate(sizes)]

"""
    name_clean(s)

Return `s` with every space character replaced by an underscore.
"""
name_clean(s) = replace(s, ' ' => '_')

"""
    dedup(reads; thresh = 2)

Collapse identical reads, keeping only those that occur at least `thresh` times.

Prints how many distinct reads were kept and discarded. Returns `(seqs, counts)`: the
kept distinct reads and their multiplicities, ordered by decreasing count (ties broken by
decreasing read).
"""
function dedup(reads; thresh = 2)
    counts_by_read = countmap(reads)
    uniq_reads = collect(keys(counts_by_read))
    uniq_counts = collect(values(counts_by_read))
    keep = uniq_counts .>= thresh
    println("Keeping: ", sum(keep), ". Discarding: ", sum(.!keep))
    # Tuples sort by count first, then by read, so `rev = true` ranks most abundant first.
    ranked = sort!(collect(zip(uniq_counts[keep], uniq_reads[keep])); rev = true)
    seqs = [read for (_, read) in ranked]
    counts = [c for (c, _) in ranked]
    return seqs, counts
end

"""
    barcode_extract(seq)

Return the slice of `seq` from index `end-29` through `end-10`, which is 20 characters
for an ASCII sequence. As originally noted, this window is meant to include up to the
first discriminating base in the isotype primer.
"""
function barcode_extract(seq)
    stop = lastindex(seq) - 10
    return seq[stop-19:stop]
end

"""
    barcoded_dedup(reads, barcode_func; thresh = 0, randomize = true)

Collapse identical reads, then keep only the most abundant read for each barcode.

Identical reads are counted. Distinct reads seen fewer than `thresh` times are dropped.
The remaining reads are ranked by count (highest first) and walked in that order. A read
is kept only if `barcode_func(read)` was not already produced by an earlier, more
abundant read. So when several distinct reads share a barcode, only the top-ranked one
survives.

With `randomize = true`, ties in count are broken randomly. This uses a jitter below 0.01
that is rounded away before returning. With `randomize = false`, ties are broken by
decreasing read.

Prints kept/discarded counts, a sample of the (jittered) counts when `randomize` is set,
and a barcode summary. Returns `(seqs, counts)` for the kept reads in rank order.

NOTE(review): barcodes are stored in a `Set{String}`, so `barcode_func` must return
something convertible to `String`.
"""
function barcoded_dedup(reads, barcode_func; thresh = 0, randomize = true)
    cm = countmap(reads)
    uniq_reads = collect(keys(cm))
    counts = collect(values(cm))
    filt = counts .>= thresh
    println("Keeping: ", sum(filt), ". Discarding: ", sum(.!filt))
    vals = counts
    if randomize
        # Sub-unit jitter only reorders reads with equal counts; `round` below recovers
        # the exact integer counts.
        vals = vals .+ (0.01 .* rand(length(vals)))
        println("Check:", vals[1:min(end, 10)])
    end
    ranked = sort(collect(zip(vals[filt], uniq_reads[filt])), rev = true)

    barcode_set = Set{String}()
    keep_vec = zeros(Bool, length(ranked))
    for (i, (_, read)) in enumerate(ranked)
        # Use the caller-supplied extractor (previously ignored in favor of a hard-coded
        # `barcode_extract`).
        bc = barcode_func(read)
        if !(bc in barcode_set)
            push!(barcode_set, bc)
            keep_vec[i] = true
        end
    end

    println(length(barcode_set), " barcodes from ", length(ranked), " deduplicated reads.")

    return [p[2] for p in ranked][keep_vec], [round(Int, p[1]) for p in ranked][keep_vec]
end

#paths
primer_csv = "primer_sheet.csv";
fastq_file = "m54340U_200124_230845.Q20.fastq"
min_len = 750
max_len = 1500
err_rate = 0.001
primer_trunc_length = 12

# Primer sheet: column 1 = sample ID, column 2 = forward primer, column 3 = reverse primer.
primer_table = readcsv(primer_csv);
IDs = primer_table[!,1]
fwd_p = primer_table[!,2]
rev_p = primer_table[!,3];

# Only the first `primer_trunc_length` bases of each primer are used for demultiplexing.
f_trunc = [p[1:primer_trunc_length] for p in fwd_p];
r_trunc = [p[1:primer_trunc_length] for p in rev_p];

# Distinct truncated primers (first-occurrence order) and lookups from a truncated
# primer to its position in that list. Despite the names, these map primer -> index.
unique_fwds = union(f_trunc)
unique_revs = union(r_trunc)
ind_to_f_p = Dict(zip(unique_fwds,collect(1:length(unique_fwds))))
ind_to_r_p = Dict(zip(unique_revs,collect(1:length(unique_revs))))
# One (fwd index, rev index) pair per sample row; these index `unique_fwds`/`unique_revs`.
occuring_pairs = collect(zip([ind_to_f_p[p] for p in f_trunc],[ind_to_r_p[p] for p in r_trunc]));

# Truncation must not merge distinct primers into one key, or their samples become
# indistinguishable. Each difference is 0 when truncation is safe, negative otherwise.
# (Previously this was a bare expression whose value was silently discarded.)
trunc_check = (length(union(f_trunc)) - length(union(fwd_p)),
               length(union(r_trunc)) - length(union(rev_p)))
if any(d -> d != 0, trunc_check)
    @warn "Truncating primers to $primer_trunc_length bases merges distinct primers" trunc_check
end

# Filter the raw reads by error rate and length into a new FASTQ, then load it back.
filtered_fastq = fastq_file * "filt.fastq"
fastq_filter(fastq_file, filtered_fastq, error_rate = err_rate,
             min_length = min_len, max_length = max_len);

@time seqs, phreds, seq_names = read_fastq(filtered_fastq);

# Demultiplex reads by primer pair and save a per-sample read-length "gel" plot.
# NOTE(review): the original note said "Matching entire primer.", but the primers passed
# here are the truncated `unique_fwds`/`unique_revs`; `tol_one_error` presumably tolerates
# a single mismatch — confirm against NextGenSeqUtils.
@time demux_dic = demux_dict(seqs, unique_fwds, unique_revs, verbose = false,
                             phreds = phreds, tol_one_error = true);
demux_gel_plot(demux_dic, keys_to_plot = occuring_pairs, labels = IDs,
               alpha = 0.01, height = 5, grid_lines_every = 100);
savefig(fastq_file * ".gel.png")

#Dereplication
# For each sample row: trim that sample's primers from every read demultiplexed to its
# primer pair, dereplicate the trimmed reads barcode-aware, and write one FASTA per sample.
sampleIDs = IDs
seqID = 1
for i in 1:length(sampleIDs)
    pair = occuring_pairs[i]
    println("Trying $(sampleIDs[i])")
    if haskey(demux_dic,pair)
        # `pair` holds indices into `unique_fwds`/`unique_revs`, NOT into `fwd_p`/`rev_p`
        # (samples may share primers), so this sample's full primers come from row `i`.
        fwd_primer = fwd_p[i]
        rev_primer = rev_p[i][1:15]
        trims = [double_primer_trim(s[1], s[2], fwd_primer, rev_primer) for s in demux_dic[pair]]
        trimmed_seqs = [t[1] for t in trims]
        cons, sizes = barcoded_dedup(trimmed_seqs, barcode_extract, thresh = 2)
        write_fasta("$(sampleIDs[i])_barcode_aware_dereplicated_thresh2.fasta", cons,
                    names = sizes_to_names(sizes, prefix = sampleIDs[i]*"_"))
    else
        println("No data for sample $(sampleIDs[i]).")
    end
end