In [1]:
using Pkg; Pkg.activate(".")
using HypergraphModularity

using Random
using StatsBase

[32m[1m Activating[22m[39m environment at `~/codes/hypergraph_modularities_code/Project.toml`


In [2]:
"""
    identity(p::Vector{Int64})

Aggregator that passes the integer vector `p` through unchanged (the very same
array object is returned, not a copy).

NOTE(review): this shares its name with `Base.identity`; confirm the shadowing is
intended in this session.
"""
function identity(p::Vector{Int64})
    return p
end

"""
    sum_of_ext_degs(p::Vector{Int64})

Aggregator returning the tuple `(sum(p), length(p) - 1)`: the total of the entries
of `p` together with the number of entries minus one.

`p` is presumably the partition of an edge across groups — confirm against the
aggregator contract expected by `estimateΩEmpirically`.
"""
sum_of_ext_degs(p::Vector{Int64}) = (sum(p), length(p) - 1)

"""
    all_or_nothing(p::Vector{Int64})

Aggregator returning `(sum(p), flag)`, where `flag` is `true` exactly when `p` has
a single entry and `false` otherwise.

`p` is presumably the partition of an edge across groups — confirm against the
aggregator contract expected by `estimateΩEmpirically`.
"""
function all_or_nothing(p::Vector{Int64})
    total = sum(p)
    single_group = (length(p) == 1)
    return (total, single_group)
end

"""
    rainbow(p::Vector{Int64})

Aggregator returning `(sum(p), flag)`, where `flag` is `true` exactly when `p` has
more than one entry and the number of entries equals their sum.

`p` is presumably the partition of an edge across groups — confirm against the
aggregator contract expected by `estimateΩEmpirically`.
"""
function rainbow(p::Vector{Int64})
    total = sum(p)
    nparts = length(p)
    return (total, nparts > 1 && nparts == total)
end


"""
    MLE_ll(Hyp, Z, agg)

Estimate `Ω` from `Hyp` and `Z` with `estimateΩEmpirically` (using `agg` as the
aggregator and `min_val=0`), evaluate `L(Hyp, Z, Ω; α=0, bigInt=true)` at that
estimate, and return the sum of the result converted to `Float64`.

Presumably `Hyp` is a hypergraph, `Z` a node-label vector and `L` the model
log-likelihood — confirm against HypergraphModularity.
"""
function MLE_ll(Hyp, Z, agg)
    Ω_hat = estimateΩEmpirically(Hyp, Z; aggregator=agg, min_val=0)
    terms = L(Hyp, Z, Ω_hat; bigInt=true, α=0)
    return Float64(sum(terms))
end

MLE_ll (generic function with 1 method)

In [3]:
"""
    sub_hypergraph(h::hypergraph, labels, in_subset::Vector{Bool})

Restrict `h` to the nodes `v` for which `in_subset[v]` is `true`.

Each edge stored in `h.E` is reduced to its retained nodes; reduced edges with
fewer than two nodes are discarded. Retained nodes are renumbered consecutively
from 1 in increasing order of their original index. Every surviving edge is
stored once, with value `1` (the values stored in `h.E` are not carried over),
keyed by edge size, and node degrees are recomputed from the stored edges.

Returns a 3-tuple: the new `hypergraph(1:n, subE, subD)`, `labels[in_subset]`,
and the `Dict{Int64,Int64}` mapping original node index to new node index.
"""
function sub_hypergraph(h::hypergraph, labels, in_subset::Vector{Bool})
    # Original index -> new consecutive index, for retained nodes only.
    node_map = Dict{Int64,Int64}()
    for old in findall(in_subset)
        node_map[old] = length(node_map) + 1
    end

    # Restrict each edge to the retained nodes, renumber it, and file it under
    # its new size. Edges that shrink to a single node (or none) are skipped.
    subE = Dict{Integer, Dict}()
    for (_, edges) in h.E, (edge, _) in edges
        kept = filter(v -> in_subset[v], edge)
        length(kept) > 1 || continue
        renumbered = [node_map[v] for v in kept]
        bucket = get!(() -> Dict(), subE, length(renumbered))
        bucket[renumbered] = 1
    end

    # Degree of a node = number of stored edges that contain it.
    n = length(node_map)
    subD = zeros(Int64, n)
    for (_, edges) in subE, (e, _) in edges
        subD[e] .+= 1
    end

    return hypergraph(1:n, subE, subD), labels[in_subset], node_map
end

sub_hypergraph (generic function with 1 method)

In [4]:
# merged labels
"""
    merge_gender_labels(labels, names)
    merge_gender_labels()

Collapse labels whose names share the same first whitespace-separated token
(e.g. `"5B M"` and `"5B F"` both become `"5B"`).

# Arguments
- `labels`: iterable of integer indices into `names`.
- `names`: vector of label names; `names[label]` is the name of `label`.

# Returns
A tuple `(merge_names, merge_labels)`:
- `merge_names::Vector{String}`: distinct first tokens, in order of first
  appearance while scanning `labels`.
- `merge_labels::Vector{Int64}`: for each entry of `labels`, the index of its
  first token in `merge_names`.

The zero-argument form is kept for backward compatibility; it forwards the
global variables `labels` and `names`, which must already be defined.
"""
function merge_gender_labels(labels, names)
    merge_names = String[]
    merge_labels = Int64[]
    split_names = [split(name) for name in names]
    for label in labels
        merge_name = split_names[label][1]
        ind = findfirst(isequal(merge_name), merge_names)
        if ind === nothing
            # First time this token is seen: it takes the next free index.
            push!(merge_names, merge_name)
            ind = length(merge_names)
        end
        push!(merge_labels, ind)
    end
    return merge_names, merge_labels
end

# Backward-compatible entry point: operate on the globals `labels` and `names`.
merge_gender_labels() = merge_gender_labels(labels, names)

merge_gender_labels (generic function with 1 method)

24-element Array{String,1}:
 "5B M"
 "5B F"
 "5A F"
 "5A M"
 "4A M"
 "4A F"
 "Teachers Unknown"
 "3B M"
 "3B F"
 "4B F"
 "2A F"
 "2A M"
 "4A Unknown"
 "4B M"
 "5A Unknown"
 "1B M"
 "1B F"
 "2B F"
 "2B M"
 "1A F"
 "3A M"
 "3A F"
 "1A M"
 "1A Unknown"

In [8]:
# For each contact dataset, compare the all-or-nothing MLE log-likelihood obtained
# with the full labels against the one obtained with the merged labels returned by
# `merge_gender_labels`, after dropping nodes whose label name contains "Unknown".
for dataset in ["primary-school", "high-school"]
    # `merge_gender_labels()` reads the *global* `labels` and `names`. Declaring
    # them global here guarantees the values loaded in this iteration are the ones
    # it sees, regardless of whether this loop runs under notebook/REPL soft-scope
    # rules or from a script (where the assignments would otherwise be loop-local).
    global labels, names
    H, labels = read_hypergraph_data("contact-$dataset-classes-gender", 10)
    names = read_hypergraph_label_names("contact-$dataset-classes-gender")

    # Throw out data where gender is unknown
    keep = Bool[!occursin("Unknown", names[label]) for label in labels]

    merge_names, merge_labels = merge_gender_labels()

    # Same node subset both times; only the label vector differs.
    subH, sublabels1, _ = sub_hypergraph(H, labels, keep)
    _, sublabels2, _ = sub_hypergraph(H, merge_labels, keep)

    println("dataset = ", dataset)
    println("MLE (with genders): ", MLE_ll(subH, sublabels1, all_or_nothing))
    println("MLE (just classes): ", MLE_ll(subH, sublabels2, all_or_nothing))

    println("")
end

dataset = primary-school
MLE (with genders): -470152.1789186721
MLE (just classes): -54676.65251816025

dataset = high-school
MLE (with genders): -92326.53363478063
MLE (just classes): -27550.859421192486



In [None]:
# out of date?

"""
    dyadic_MLE_ll(Hyp, Z, weighted::Bool)

Compute `ω_in, ω_out` with `computeDyadicResolutionParameter(Hyp, Z; mode="ω")`
and return `dyadicLogLikelihood(Hyp, Z, ω_in, ω_out)` converted to `Float64`.
`weighted` is forwarded to both calls.

Presumably `Hyp` is a hypergraph and `Z` a node-label vector — confirm against
HypergraphModularity.
"""
function dyadic_MLE_ll(Hyp, Z, weighted::Bool)
    # Use the `Hyp` argument throughout; the earlier version referenced a global
    # `H` here, so the hypergraph passed by the caller was ignored.
    ω_in, ω_out =
        computeDyadicResolutionParameter(Hyp, Z; mode="ω", weighted=weighted)
    return Float64(dyadicLogLikelihood(Hyp, Z, ω_in, ω_out, weighted=weighted))
end
# NOTE(review): `subH` and `sublabels1` are assigned inside the dataset `for` loop
# earlier in this file — confirm they are visible at global scope before running.
dyadic_MLE_ll(subH, sublabels1, true)