In [1]:
using Graphs
using Printf
using DataFrames
using CSV
using LinearAlgebra

include("functions.jl")


fit (generic function with 1 method)

In [2]:
# Path to the dataset CSV (a relative path); passed to `init_data` below.
infile = "data/small.csv"

In [5]:
"""
    init_data(fname)

Load the CSV at `fname` and build the inputs used for structure learning.

Returns the tuple `(vars, G, D, p)`:
- `vars`: one `Variable` per CSV column, built from the column name (as a `Symbol`)
  and `maximum - minimum + 1` of that column's values
- `G`: an edgeless `SimpleDiGraph` with one vertex per column
- `D`: the data matrix transposed, so each column of `D` is one CSV row
- `p`: `bayesian_score(vars, G, D)`, i.e. the score of the edgeless graph
"""
function init_data(fname)
    table = DataFrame(CSV.File(fname))
    samples = Matrix(table)
    labels = names(table)

    # Number of values per variable, taken as the span of each column.
    # NOTE(review): this equals the column maximum only when every column starts
    # at 1 — confirm that is what `Variable` / `bayesian_score` expect.
    spans = map(eachcol(samples)) do col
        lo, hi = extrema(col)
        hi - lo + 1
    end

    vars = [Variable(Symbol(label), span) for (label, span) in zip(labels, spans)]
    G = SimpleDiGraph(length(labels))

    # The score function is fed the transposed (variables x samples) matrix.
    D = samples'
    p = bayesian_score(vars, G, D)
    return vars, G, D, p
end

init_data (generic function with 1 method)

In [6]:
# Load the dataset: the variables, an edgeless starting graph, the transposed data
# matrix, and the Bayesian score of that edgeless graph.
vars, G, D, P = init_data(infile)

(Variable[Variable(:age, 3), Variable(:portembarked, 3), Variable(:fare, 3), Variable(:numparentschildren, 3), Variable(:passengerclass, 3), Variable(:sex, 2), Variable(:numsiblings, 3), Variable(:survived, 2)], SimpleDiGraph{Int64}(0, [Int64[], Int64[], Int64[], Int64[], Int64[], Int64[], Int64[], Int64[]], [Int64[], Int64[], Int64[], Int64[], Int64[], Int64[], Int64[], Int64[]]), [1 2 … 1 2; 1 2 … 2 3; … ; 1 1 … 1 1; 1 2 … 2 1], -4166.225858784901)

In [14]:
"""
    K2Search(ordering)

Settings for the K2 structure search performed by `fit(::K2Search, vars, D)`.
"""
struct K2Search 
    ordering::Vector{Int}    # variable indices in visiting order; `fit` only draws a variable's parents from entries that come before it
end


"""
    fit(method::K2Search, vars, D)

Greedy K2 structure search over the variable ordering `method.ordering`.

Starting from an edgeless graph with `length(vars)` vertices, the variables are
visited in order (from the second one on). For each variable `i`, the single
edge `j -> i` from an earlier variable `j` that raises
`bayesian_score(vars, G, D)` the most is added, repeatedly, until no candidate
edge improves the score.

# Arguments
- `method`: the search settings; only `method.ordering` is read.
- `vars`, `D`: passed through unchanged to `bayesian_score`.

# Returns
`(G, y)`: the learned `SimpleDiGraph` and its Bayesian score. For an ordering
with fewer than two entries this is the edgeless graph and its score.
"""
function fit(method::K2Search, vars, D)
    G = SimpleDiGraph(length(vars))
    ordering = method.ordering

    # `y` always holds the score of the current `G`. Scoring once up front avoids
    # re-scoring an unchanged graph for every variable and gives a correct result
    # even when the loop below never runs.
    y = bayesian_score(vars, G, D)

    for k in 2:length(ordering)
        i = ordering[k]
        # Only variables earlier in the ordering may become parents of `i`.
        candidates = view(ordering, 1:(k - 1))
        while true
            y_best, j_best = -Inf, 0
            for j in candidates
                has_edge(G, j, i) && continue
                # Tentatively add the edge, score the graph, then undo.
                add_edge!(G, j, i)
                y_new = bayesian_score(vars, G, D)
                if y_new > y_best
                    y_best, j_best = y_new, j
                end
                rem_edge!(G, j, i)
            end
            # Commit the best edge only if it strictly improves the score.
            if y_best > y
                y = y_best
                add_edge!(G, j_best, i)
            else
                break
            end
        end
    end
    return G, y
end


fit (generic function with 1 method)

In [15]:
"""
    permutations(arr)

Return all orderings of the elements of `arr` as a vector of vectors.

The result is concretely typed (`Vector{Vector{T}}` with `T` the element type of
`collect(arr)`), and every permutation is a fresh vector that does not alias
`arr`. Permutations are produced in index-lexicographic order: all those
starting with `arr[1]` first, then `arr[2]`, and so on. An empty input yields a
single empty permutation.

Note that the number of permutations grows as `factorial(length(arr))`.
"""
function permutations(arr)
    items = vec(collect(arr))
    # Zero or one element: exactly one permutation, the (copied) input itself.
    length(items) <= 1 && return [items]

    perms = Vector{typeof(items)}()
    for i in eachindex(items)
        rest = deleteat!(copy(items), i)
        for tail in permutations(rest)
            # `tail` is freshly allocated by the recursive call, so it can be
            # extended in place.
            push!(perms, pushfirst!(tail, items[i]))
        end
    end
    return perms
end

permutations (generic function with 1 method)

In [16]:
# K2 search settings that visit the variables in index order 1 through 8.
method = K2Search([1,2,3,4,5,6,7,8])

K2Search([1, 2, 3, 4, 5, 6, 7, 8])

In [17]:
# One K2 run with `method`'s ordering; `fit` returns the learned graph and its score.
G_, BS_best = fit(method, vars, D)

(SimpleDiGraph{Int64}(10, [[4, 5], [5, 6], [5], [6, 7], [6, 8], [8], Int64[], Int64[]], [Int64[], Int64[], Int64[], [1], [1, 2, 3], [2, 4, 5], [4], [5, 6]]), -3835.67942521279)

In [18]:
using ProgressMeter

"""
    search_orderings(orderings, vars, D)

Run `fit(K2Search(p), vars, D)` for every ordering `p` in `orderings` and keep
the result with the highest score.

Returns `(graph, score)`. If `orderings` is empty, the result is an edgeless
graph with `length(vars)` vertices and a score of `-Inf`.

The loop lives in a function so it does not reassign global variables from
inside a top-level `for`, which only works under notebook/REPL scoping rules.
"""
function search_orderings(orderings, vars, D)
    best_graph = SimpleDiGraph(length(vars))
    best_score = -Inf
    @showprogress dt=1 desc="Computing..." for p in orderings
        graph, score = fit(K2Search(p), vars, D)
        if score > best_score
            best_score = score
            best_graph = graph
        end
    end
    return best_graph, best_score
end

# Exhaustive search: one K2 run per permutation of the 8 variable indices.
elements = 1:8
perm = permutations(elements)

G, bs = search_orderings(perm, vars, D)

[32mComputing... 100%|███████████████████████████████████████| Time: 1:33:42[39m


In [20]:
# Display the best score found over all orderings tried by the search cell.
bs

-3794.855597709798

In [21]:
"""
    write_gph(dag::DiGraph, idx2names, filename)

Write the edges of `dag` to `filename`, one `parent,child` line per edge.

Vertex indices are translated to names by indexing `idx2names` with the edge's
source and destination. The file is opened in `"w"` mode, so existing content
is replaced.
"""
function write_gph(dag::DiGraph, idx2names, filename)
    open(filename, "w") do out
        for e in edges(dag)
            parent = idx2names[src(e)]
            child = idx2names[dst(e)]
            @printf(out, "%s,%s\n", parent, child)
        end
    end
end

write_gph (generic function with 1 method)

In [22]:
# `df` is only bound inside `init_data`, so read the column names from the CSV
# again instead of relying on a leftover global.
column_names = names(CSV.File(infile) |> DataFrame)

8-element Vector{String}:
 "age"
 "portembarked"
 "fare"
 "numparentschildren"
 "passengerclass"
 "sex"
 "numsiblings"
 "survived"

In [25]:
# Vertex index => CSV column name, in the column order of the data file.
# The names must match the CSV header exactly because they are written verbatim
# to the output graph file.
idx2names = Dict(
    1 => "age", 2 => "portembarked", 3 => "fare", 4 => "numparentschildren",
    5 => "passengerclass", 6 => "sex", 7 => "numsiblings", 8 => "survived",
)

Dict{Int64, String} with 8 entries:
  5 => "passengerclass"
  4 => "numparentschildren"
  6 => "sex"
  7 => "numsiblings"
  2 => "protembarked"
  8 => "survived"
  3 => "fare"
  1 => "age"

In [26]:
# Save graph `G` to disk as "parent,child" lines, using the names in `idx2names`.
write_gph(G, idx2names, "opt_graph1.gph")