Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LogGraph #32

Merged
merged 19 commits into from Jun 28, 2019
3 changes: 2 additions & 1 deletion Project.toml
Expand Up @@ -8,6 +8,7 @@ CRC32c = "8bf52ea8-c179-5cab-976a-9e18b702a9bc"
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -24,4 +25,4 @@ TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
WAV = "8149f6b0-98f6-5db9-b78f-408fbbb8ef88"

[targets]
test = ["Test", "Flux", "TestImages", "ImageMagick", "WAV"]
test = ["Test", "TestImages", "ImageMagick", "WAV", "Flux"]
83 changes: 83 additions & 0 deletions src/Loggers/LogGraph.jl
@@ -0,0 +1,83 @@
"""
log_graph
"""
function log_graph(logger::TBLogger, g::AbstractGraph; step = nothing, nodelabel::Vector{String} = map(string, vertices(g)), nodeop::Vector{String} = map(string, vertices(g)), nodedevice::Vector{String} = fill("cpu", nv(g)), nodevalue::Vector{Any} = fill(nothing, nv(g)))
@assert nv(g) == length(nodelabel) "length of nodelable must be same as number of vertices"
@assert nv(g) == length(nodeop) "length of nodeop must be same as number of vertices"
@assert nv(g) == length(nodedevice) "length of nodedevice must be same as number of vertices"
@assert nv(g) == length(nodevalue) "length of nodevalue must be same as number of vertices"
shashikdm marked this conversation as resolved.
Show resolved Hide resolved
summ = SummaryCollection(graph_summary(g, nodelabel, nodeop, nodedevice, nodevalue))
write_event(logger.file, make_event(logger, summ, step=step))
end

function graph_summary(g, nodelabel, nodeop, nodedevice, nodevalue)
function getdtype(dtype)
shashikdm marked this conversation as resolved.
Show resolved Hide resolved
nodetype =
dtype == nothing ? _DataType.DT_INVALID :
dtype == UInt8 ? _DataType.DT_UINT8 :
dtype == UInt16 ? _DataType.DT_UINT16 :
dtype == UInt32 ? _DataType.DT_UINT32 :
dtype == UInt64 ? _DataType.DT_UINT64 :
dtype == Int8 ? _DataType.DT_INT8 :
dtype == Int16 ? _DataType.DT_INT16 :
dtype == Int32 ? _DataType.DT_INT32 :
dtype == Int64 ? _DataType.DT_INT64 :
dtype == Float16 ? _DataType.DT_BFLOAT16 :
dtype == Float32 ? _DataType.DT_FLOAT :
dtype == Float64 ? _DataType.DT_DOUBLE :
dtype <: AbstractString ? _DataType.DT_STRING :
dtype == Bool ? _DataType.DT_BOOL :
dtype ∈ [Complex{Float32},
Complex{Float16},
Complex{UInt8},
Complex{UInt16},
Complex{UInt32},
Complex{Int8},
Complex{Int16},
Complex{Int32}] ? _DataType.DT_COMPLEX64 :
dtype ∈ [Complex{Float64},
Complex{UInt64},
Complex{Int64}] ? _DataType.DT_COMPLEX128 :
@error "Unknown Datatype" dtype
shashikdm marked this conversation as resolved.
Show resolved Hide resolved
end
nodes = Vector{NodeDef}()
for v in vertices(g)
name = nodelabel[v]
op = nodeop[v]
input = [nodelabel[x] for x in inneighbors(g, v)]
device = nodedevice[v]
attr = Dict{String, AttrValue}()
x = nodevalue[v]
if isa(x, AbstractString)
attr["value"] = AttrValue(s = Vector{UInt8}(x))
attr["dtype"] = AttrValue(_type = getdtype(typeof(x)))
elseif isa(x, Integer)
attr["value"] = AttrValue(i = Int64(x))
attr["dtype"] = AttrValue(_type = getdtype(typeof(x)))
elseif isa(x, Real)
attr["value"] = AttrValue(f = Float32(x))
attr["dtype"] = AttrValue(_type = getdtype(typeof(x)))
elseif isa(x, Bool)
attr["value"] = AttrValue(b = x)
attr["dtype"] = AttrValue(_type = getdtype(typeof(x)))
elseif isa(x, AbstractArray)
shape = TensorShapeProto(dim = [TensorShapeProto_Dim(size = d) for d in (collect(size(x)))])
t = TensorProto(dtype = getdtype(eltype(x)), tensor_shape = shape, tensor_content = serialize_proto(string(x)))
attr["value"] = AttrValue(tensor = t)
attr["_output_shapes"] = AttrValue(list = AttrValue_ListValue(shape = [shape]))
elseif isa(x, Tuple)
attr["value"] = AttrValue(list = AttrValue_ListValue(s = [Vector{UInt8}(repr(y)) for y in x]))
shape = TensorShapeProto(dim = [TensorShapeProto_Dim(size = length(x))])
attr["_output_shapes"] = AttrValue(list = AttrValue_ListValue(shape = [shape]))
elseif isa(x, Function)
attr["value"] = AttrValue(func = NameAttrList(name = repr(x)))
elseif x == nothing
#donothing
else
@error "unhandled nodevalue type $(typeof(x))"
end
node = NodeDef(name = name, op = op, input = input, device = device, attr = attr)
push!(nodes, node)
end
GraphDef(node = nodes)
end
4 changes: 3 additions & 1 deletion src/TensorBoardLogger.jl
Expand Up @@ -5,13 +5,14 @@ using ImageCore
using ColorTypes
using FileIO
using FileIO: @format_str
using LightGraphs
using StatsBase #TODO: remove this. Only needed to compute histogram bins.
using Base.CoreLogging: global_logger, LogLevel, Info
import Base.CoreLogging:
AbstractLogger, handle_message, shouldlog, min_enabled_level,
catch_exceptions

export log_histogram, log_value, log_vector, log_text, log_image, log_images, log_audio, log_audios
export log_histogram, log_value, log_vector, log_text, log_image, log_images, log_audio, log_audios, log_graph
export scalar_summary, histogram_summary, text_summary, make_event
export TBLogger
export reset!, set_step!, increment_step!
Expand Down Expand Up @@ -43,6 +44,7 @@ include("Loggers/LogText.jl")
include("Loggers/LogHistograms.jl")
include("Loggers/LogImage.jl")
include("Loggers/LogAudio.jl")
include("Loggers/LogGraph.jl")

include("logger_dispatch.jl")
include("logger_dispatch_overrides.jl")
Expand Down
7 changes: 3 additions & 4 deletions src/event.jl
Expand Up @@ -3,11 +3,10 @@
SummaryCollection(;kwargs...) = Summary(value=Vector{Summary_Value}(); kwargs...)
SummaryCollection(summaries::Vector{Summary_Value}; kwargs...) = Summary(value=summaries; kwargs...)
SummaryCollection(summary::Summary_Value; kwargs...) = Summary(value=[summary]; kwargs...)
SummaryCollection(summary::GraphDef; kwargs...) = summary

function make_event(logger::TBLogger, summary::Summary;
step::Int=TensorBoardLogger.step(logger))
return Event(wall_time=time(), summary=summary, step=step)
end
make_event(logger::TBLogger, summary::Summary; step::Int=TensorBoardLogger.step(logger)) = Event(wall_time=time(), summary=summary, step=step)
make_event(logger::TBLogger, summary::GraphDef; step::Int=TensorBoardLogger.step(logger)) = Event(wall_time=time(), graph_def=serialize_proto(summary), step=step)

function write_event(file::IOStream, event::Event)
data = PipeBuffer();
Expand Down
8 changes: 7 additions & 1 deletion src/utils.jl
Expand Up @@ -6,8 +6,14 @@ function masked_crc32c(data)
return UInt32(((x >> 15) | UInt32(x << 17)) + 0xa282ead8)
end

function serialize_proto(data::Any)
function serialize_proto(data::Union{ProtoType, ProtoEnum})
pb = PipeBuffer()
writeproto(pb, data)
pb.data
end

function serialize_proto(data::Any)
pb = PipeBuffer()
write(pb, data)
pb.data
end
19 changes: 17 additions & 2 deletions test/runtests.jl
@@ -1,12 +1,12 @@
using TensorBoardLogger, Logging
using TensorBoardLogger: preprocess, summary_impl
using Test
using Flux.Data.MNIST
using Flux, Flux.Data.MNIST
using TestImages
using ImageCore
using ColorTypes
using FileIO

using LightGraphs

@testset "TBLogger" begin
include("test_TBLogger.jl")
Expand Down Expand Up @@ -217,6 +217,21 @@ end
log_audios(logger, "audiosample", samples, fs, step = step)
end

@testset "Graph Logger" begin
logger = TBLogger("test_logs/t", tb_overwrite)
step = 1
ss = TensorBoardLogger.graph_summary(DiGraph(1), ["1"], ["1"], ["cpu"], [nothing])
@test isa(ss, TensorBoardLogger.GraphDef)
g = DiGraph(7)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 6)
add_edge!(g, 4, 6)
add_edge!(g, 5, 6)
add_edge!(g, 5, 7)
log_graph(logger, g, step = step, nodedevice = ["cpu", "cpu", "gpu", "gpu", "gpu", "gpu", "cpu"], nodevalue = [1, "tf", 3.14, [1.0 2.0; 3.0 4.0], true, +, (10, "julia", 12.4)])
end

@testset "Logger dispatch overrides" begin
include("test_logger_dispatch_overrides.jl")
end