In [1]:
using CSV
using DataFrames
using DataStructures
using Distributions
using GZip
using JSON
using LinearAlgebra
using Logging
using NPZ
using ProgressMeter
using PyPlot
using PyCall

using Distributions
using Random

import DataStructures.compare

include("julia/src/enums.jl")
include("julia/src/events.jl")
include("julia/src/event_execution.jl")
include("julia/src/utils.jl")
include("julia/src/data_loading.jl")

# Heap ordering for simulation events: `compare(Earlier(), a, b)` is true when event `a`
# is strictly earlier than event `b`, so a `BinaryHeap{AbstractEvent, Earlier}` pops
# events in increasing time order.
struct Earlier end

function compare(::Earlier, a::AbstractEvent, b::AbstractEvent)
    return time(a) < time(b)
end

compare (generic function with 4 methods)

In [2]:
"""
    Progression

Sampled disease course of one individual, as produced by `sample_progression`.
All times are stored as `Float32` (the default constructor converts its arguments);
`NaN` marks an event that does not happen for this individual.
"""
struct Progression
    severity::Severity
    # a draw from the incubation-time distribution
    incubation_time::Float32
    # incubation_time plus a symptom-onset delay draw; NaN when hospitalization
    # would happen before symptom onset
    symptom_onset_time::Float32
    # only non-NaN for Severe or Critical cases
    hospitalization_time::Float32
end

In [3]:
"""
    sample_progression(rng, dist_severity, dist_incubation, dist_symptom_onset,
                       dist_hospitalization)

Draw one individual's disease `Progression`, using `rng` for every random draw so that
a seeded `rng` reproduces the whole progression.

- `severity` is `Severity(rand(rng, dist_severity))`.
- `incubation_time` is a draw from `dist_incubation`.
- `symptom_onset_time` is `incubation_time` plus a draw from `dist_symptom_onset`.
- `hospitalization_time` is `NaN` unless the severity is `Severe` or `Critical`. In that
  case it is `incubation_time` plus a draw from `dist_hospitalization`. If that falls
  before the symptom onset, `symptom_onset_time` is replaced by `NaN`.
"""
function sample_progression(rng, dist_severity, dist_incubation, dist_symptom_onset,
                            dist_hospitalization)
    severity = Severity(rand(rng, dist_severity))

    incubation_time = rand(rng, dist_incubation)
    symptom_onset_time = incubation_time + rand(rng, dist_symptom_onset)
    hospitalization_time = NaN

    if severity in (Severe, Critical)
        # draw from `rng` (not the global RNG) so seeded runs are reproducible
        hospitalization_time = incubation_time + rand(rng, dist_hospitalization)
        if hospitalization_time < symptom_onset_time
            symptom_onset_time = NaN
        end
    end

    return Progression(severity, incubation_time, symptom_onset_time, hospitalization_time)
end

sample_progression (generic function with 1 method)

In [4]:
# the constant part of the simulation
"""
    SimParams

Constant (immutable) inputs of a simulation run, as assembled by `initialize`.
"""
struct SimParams 
    # NOTE(review): `initialize` currently fills this with placeholder values 1:num_individuals
    households::Vector{Int32}  # to be decided what it is - (list of pointers to the first member?)
    
    # one sampled disease course per individual
    progressions::Vector{Progression}
    
    # multiplied by the length of a source's infectious window to get the Poisson rate
    # of constant-kernel contacts
    constant_kernel_param::Float64
end

# the mutable part of the simulation
"""
    SimState(num_individuals::Integer)

Mutable state of a running simulation over `num_individuals` people. It holds:

- the simulation clock;
- the pending-event queue, ordered by `Earlier`;
- per-person health states and detection flags;
- the recorded infections;
- running counters.

A fresh state starts at time `0.0` with an empty queue, everyone `Healthy`, nobody
detected, no infections and all counters at zero. Calls `error` (an `ErrorException`)
when `num_individuals` is not positive.
"""
mutable struct SimState
    time::Float64
    queue::BinaryHeap{AbstractEvent, Earlier} # TODO change to union once all events are implemented
    health_states::Vector{HealthState}
    detected::BitVector

    # presumably keyed by individual id — TODO confirm
    infections::SortedMultiDict{Int32,AbstractInfectionEvent} # TODO change to union

    num_dead::Int
    num_affected::Int
    num_detected::Int

    function SimState(num_individuals::Integer)
        num_individuals > 0 || error("number of individuals must be positive")

        pending_events = BinaryHeap{AbstractEvent, Earlier}()
        everyone_healthy = fill(Healthy, num_individuals)
        nobody_detected = falses(num_individuals)
        no_infections = SortedMultiDict{Int32,AbstractInfectionEvent}()

        return new(0.0, pending_events, everyone_healthy, nobody_detected, no_infections,
                   0, 0, 0)
    end
end

In [5]:
"""
    initialize(rng=MersenneTwister(); population_path, incubation_time_samples_path,
               t0_to_t1_samples_path, t0_to_t2_samples_path,
               severity_probs=[4/10, 3/10, 2/10, 1/10], constant_kernel_param=1.0)

Build the constant `SimParams` for a simulation run.

The population is loaded from `population_path`; each row of the loaded table is one
individual. The incubation, symptom-onset and hospitalization delay distributions are
loaded from their sample files. One `Progression` per individual is then drawn with
`sample_progression`, using `rng` for all draws.

# Keywords
- `severity_probs`: probability of each `Severity` level, in order. It is passed to
  `Categorical`, which requires a valid probability vector.
- `constant_kernel_param`: stored as `SimParams.constant_kernel_param`.

`households` is still a placeholder (`1:num_individuals`).
"""
function initialize(rng=MersenneTwister();
        population_path::AbstractString,
        incubation_time_samples_path::AbstractString,
        t0_to_t1_samples_path::AbstractString,
        t0_to_t2_samples_path::AbstractString,
        severity_probs::AbstractVector{<:Real}=[4/10, 3/10, 2/10, 1/10],
        constant_kernel_param::Real=1.0)

    individuals_df = load_individuals(population_path)
    num_individuals = nrow(individuals_df)

    dist_severity = Categorical(severity_probs)
    dist_incubation_time = load_dist_from_samples(incubation_time_samples_path)
    dist_symptom_onset_time = load_dist_from_samples(t0_to_t1_samples_path)
    dist_hospitalization_time = load_dist_from_samples(t0_to_t2_samples_path)

    # one progression per individual, drawn sequentially from `rng`
    progressions = [
        sample_progression(rng,
            dist_severity,
            dist_incubation_time,
            dist_symptom_onset_time,
            dist_hospitalization_time)
        for _ in 1:num_individuals
    ]

    households = collect(1:num_individuals) # TODO: real household structure

    return SimParams(households, progressions, constant_kernel_param)
end



# Load the Wroclaw population and the empirical delay distributions, then create an
# initial simulation state with one entry per element of `params.households`.
params = initialize(;
    population_path = "data/simulations/wroclaw-population-orig.csv.gz",
    incubation_time_samples_path = "test/models/assets/incubation_period_distribution.npy",
    t0_to_t1_samples_path = "test/models/assets/t1_distribution.npy",
    t0_to_t2_samples_path = "test/models/assets/t1_t2_distribution.npy",
)

state = SimState(length(params.households))

SimState(0.0, BinaryHeap{AbstractEvent,Earlier}(Earlier(), AbstractEvent[]), HealthState[Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy  …  Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy, Healthy], Bool[0, 0, 0, 0, 0, 0, 0, 0, 0, 0  …  0, 0, 0, 0, 0, 0, 0, 0, 0, 0], SortedMultiDict(Base.Order.ForwardOrdering(),), 0, 0, 0)

In [6]:
# Input data file locations, bound as globals for interactive use.
(population_path,
 incubation_time_samples_path,
 t0_to_t1_samples_path,
 t0_to_t2_samples_path) = (
    "data/simulations/wroclaw-population-orig.csv.gz",
    "test/models/assets/incubation_period_distribution.npy",
    "test/models/assets/t1_distribution.npy",
    "test/models/assets/t1_t2_distribution.npy",
)


"test/models/assets/t1_t2_distribution.npy"

In [7]:
"""
    enqueue_transmissions!(state::SimState, ::Type{Val{ConstantKernelContact}},
                           source_id::Integer, params::SimParams; rng=Random.GLOBAL_RNG)

Schedule `TransmissionEvent`s from `source_id` under the constant-kernel contact model.

The source's infectious window runs from its `incubation_time` to its
`symptom_onset_time`. When `symptom_onset_time` is `NaN`, the window ends at
`hospitalization_time` instead.

- The number of contacts is drawn from
  `Poisson(window_length * params.constant_kernel_param)`.
- Each contact picks a uniformly random individual other than the source, with
  replacement.
- A `TransmissionEvent` is pushed onto `state.queue` only if that individual is
  currently `Healthy`.
- The contact time is drawn uniformly from the infectious window.

All random draws use `rng`.

NOTE(review): the window bounds come straight from the source's `Progression`, with no
`state.time` offset. Confirm whether events in `state.queue` expect absolute times.
"""
function enqueue_transmissions!(state::SimState, ::Type{Val{ConstantKernelContact}},
                                source_id::Integer, params::SimParams;
                                rng::AbstractRNG=Random.GLOBAL_RNG)
    progression = params.progressions[source_id]

    start_time = progression.incubation_time
    # symptom onset is NaN when hospitalization comes first; the window then ends there
    end_time = isnan(progression.symptom_onset_time) ?
        progression.hospitalization_time : progression.symptom_onset_time

    total_infection_rate = (end_time - start_time) * params.constant_kernel_param
    num_infections = rand(rng, Poisson(total_infection_rate))
    num_infections == 0 && return nothing

    num_individuals = length(state.health_states)
    num_individuals <= 1 && return nothing # nobody other than the source to infect

    # built only after the early returns: `Uniform` rejects an empty window,
    # which always yields zero infections above
    time_dist = Uniform(start_time, end_time)

    # draw from 1:n-1, then shift ids at or above the source up by one to skip it
    for drawn_id in rand(rng, 1:(num_individuals - 1), num_infections)
        subject_id = drawn_id >= source_id ? drawn_id + 1 : drawn_id
        if state.health_states[subject_id] == Healthy
            infection_time = rand(rng, time_dist)
            push!(state.queue,
                  TransmissionEvent(infection_time, source_id, subject_id, ConstantKernelContact))
        end
    end
    return nothing
end

enqueue_transmissions! (generic function with 1 method)

In [8]:
#isactive(state::HealthState) = (state == Infectious) || (state == StayingHome)

In [9]:
"""
    health(state::SimState, person_id::Integer)

Return the current `HealthState` of individual `person_id`.
"""
health(state::SimState, person_id::Integer) = state.health_states[person_id]

# Health of the individual returned by `subject(event)`.
subjecthealth(state::SimState, event::AbstractEvent) = health(state, subject(event))
# Health of the individual returned by `source(event)` for a transmission.
sourcehealth(state::SimState, event::TransmissionEvent) = health(state, source(event))

sourcehealth (generic function with 1 method)

In [17]:
# Return `getproperty(params, name)` when `params` has that property, else `default`.
_param_or(params, name::Symbol, default) =
    name in propertynames(params) ? getproperty(params, name) : default

"""
    simulate!(state, params; stop_simulation_threshold, max_time)

Process events from `state.queue` until one of these happens:

- the queue is empty;
- `state.num_affected` reaches `stop_simulation_threshold`;
- the next event's time is at or after `max_time`.

In the last two cases the popped event is discarded. Every other event first sets
`state.time` to the event time and is then passed to `execute!(state, params, event)`.

Both limits default to the same-named property of `params` when it exists, and to
`Inf` (no limit) otherwise.
"""
function simulate!(state, params;
                   stop_simulation_threshold=_param_or(params, :stop_simulation_threshold, Inf),
                   max_time=_param_or(params, :max_time, Inf))
    while true
        if isempty(state.queue)
            @info "Empty queue"
            break
        end

        event = pop!(state.queue)
        if state.num_affected >= stop_simulation_threshold
            @info "The outbreak reached a high number $(stop_simulation_threshold)"
            break
        elseif time(event) >= max_time
            @info "Max time reached"
            break
        end

        state.time = time(event)

        execute!(state, params, event)
    end
    return nothing
end

simulate! (generic function with 1 method)

In [18]:
# Sanity check of the event ordering: events pushed out of time order should be
# popped (and displayed) from the heap in increasing time order.
h = DataStructures.BinaryHeap{AbstractEvent, Earlier}()
for ev in (OutsideInfectionEvent(1.2, 34),
           OutsideInfectionEvent(1.3, 45),
           TransmissionEvent(1.25, 100, 56, ConstantKernelContact))
    push!(h, ev)
end

for _ in 1:3
    display(pop!(h))
end

OutsideInfectionEvent(1.2f0, 34)

TransmissionEvent(1.25f0, 100, 56, ConstantKernelContact)

OutsideInfectionEvent(1.3f0, 45)