In [4]:
# using Pkg
# Pkg.instantiate()

# Import PPSeq
import PPSeq
const seq = PPSeq

# Other Imports
import PyPlot: plt
import DelimitedFiles: readdlm
import Random
import StatsBase: quantile

# Network metadata.
num_neurons = 100  # Number of neurons (rows) in the recordings.
max_time = 3.5     # Recording duration (presumably seconds — TODO confirm against data).

# Neuron label permutation, applied to every spike's neuron index as it is loaded.
# Set `permute_neurons = true` to shuffle the labels with `Random.randperm`, which
# hides the sequence structure in an unsorted raster. The default keeps the
# identity mapping so neuron IDs match the data files.
permute_neurons = false
_p = permute_neurons ? Random.randperm(num_neurons) : collect(1:num_neurons)

# Select which data files are read:
#   numSeq   — number of sequence types in each dataset (also the model's
#              `:num_sequence_types`);
#   instList — instance counts; one dataset is fitted per entry.
numSeq = 10
instList = [100]

# Fit one PPSeq model per dataset listed in `instList`, then save an annotated
# raster of the final MCMC sample to graphs/<numSeq>s<numInst>i.png.
#
# Reads globals defined earlier in this file: `seq` (PPSeq), `num_neurons`,
# `max_time`, `_p`, `numSeq`, `instList`.

# Ensure the output folder exists up front (matplotlib's savefig does not create
# missing directories, so a fresh checkout would fail after the expensive fit).
mkpath("graphs")

# Set to true to order raster rows with `seq.sortperm_neurons`, which sorts the
# neurons to reveal sequences; false keeps the identity order.
sort_neurons = false

for numInst in instList
    # Load spikes. Each row of the tab-delimited file is read as
    # (neuron index, spike time); any columns after the second are ignored.
    filename = "data/Regular/png_$(numSeq)s$(numInst)i.txt"
    spikes = seq.Spike[]
    for (n, t) in eachrow(readdlm(filename, '\t', Float64, '\n'))
        push!(spikes, seq.Spike(_p[Int(n)], t))
    end

    # Spikes emitted per sequence event. Shared by the amplitude prior and the
    # background-rate estimate below so the two cannot drift apart.
    spikes_per_seq = 10.0

    ##################### Must tweak config according to data #####################
    config = Dict(
        # Model hyperparameters
        :num_sequence_types => numSeq,  # Author's note: overestimating is better
        :seq_type_conc_param => 1.0,  # Author's intent: all sequence types equally likely
        :seq_event_rate => (numSeq * numInst) / max_time,  # Sequence events per unit time

        :mean_event_amplitude => spikes_per_seq,  # Avg. number of spikes per sequence
        :var_event_amplitude => 1.5,  # Small because amplitude is fixed in this data

        # NOTE(review): meaning of this parameter is not documented here — confirm in PPSeq.
        :neuron_response_conc_param => 0.1,
        # NOTE(review): by its name this is a prior pseudo-observation count (prior
        # strength), not a 10 ms delay as previously commented — confirm in PPSeq.
        :neuron_offset_pseudo_obs => 0.01,
        :neuron_width_pseudo_obs => 1.0,  # Author observed much longer training if increased
        # Author's intent: tolerance for jitter in within-sequence delays — TODO confirm units.
        :neuron_width_prior => 0.0001,

        :num_warp_values => 1,
        :max_warp => 1.0,
        :warp_variance => 1.0,

        # Spikes not attributed to sequence events, per unit time; floored at 1.0.
        :mean_bkgd_spike_rate =>
            max(1.0, (length(spikes) - spikes_per_seq * numSeq * numInst) / max_time),
        :var_bkgd_spike_rate => 10.0,  # Author's note: tight because sequence content is known
        :bkgd_spikes_conc_param => 0.5,  # Author observed longer training if decreased
        # NOTE(review): presumably a duration in the units of `max_time`, not a spike
        # count; 12.0 is longer than the whole recording here — confirm intent.
        :max_sequence_length => 12.0,

        # MCMC Sampling parameters.
        :num_anneals => 10,
        :samples_per_anneal => 100,
        :max_temperature => 40.0,
        :save_every_during_anneal => 10,
        :samples_after_anneal => 2000,
        :save_every_after_anneal => 10,
        # Author's note: more split/merge moves help reduce mislabeled duplicate sequences.
        :split_merge_moves_during_anneal => 15,
        :split_merge_moves_after_anneal => 20,
        :split_merge_window => 1.0,
    )

    # Initialize all spikes to the background process (assignment -1).
    init_assignments = fill(-1, length(spikes))

    # Construct the PPSeq model and run Gibbs sampling with an initial annealing period.
    model = seq.construct_model(config, max_time, num_neurons)
    results = seq.easy_sample!(model, spikes, init_assignments, config)

    # Grab the final MCMC sample.
    final_globals = results[:globals_hist][end]
    final_events = results[:latent_event_hist][end]
    final_assignments = results[:assignment_hist][:, end]

    neuron_ordering =
        sort_neurons ? seq.sortperm_neurons(final_globals) : collect(1:num_neurons)

    # Plot the model-annotated raster; colors for each sequence type can be modified.
    fig = seq.plot_raster(
        spikes,
        final_events,
        final_assignments,
        neuron_ordering;
        color_cycle=["red", "blue", "yellow", "green", "orange",
                     "purple", "pink", "cyan", "magenta", "gray"],
    )
    fig.set_size_inches([10, 4])
    fig.savefig("graphs/$(numSeq)s$(numInst)i.png")
end

(warp_values, warp_log_proportions) = ([1.0], [0.0])
TEMP:  39.99999999999999
10