# Preparing workspace

In [40]:
#r "netstandard"
#r @"bin\Cntk.Core.Managed-2.6.dll"
#load @".paket\load\main.group.fsx"
 
open System
open System.IO

// Append the local "bin" folder to the process PATH so that native libraries
// placed next to bin\Cntk.Core.Managed-2.6.dll can be found.
// NOTE(review): assumes the native CNTK DLLs live in "bin" — confirm.
// The membership check keeps re-runs of this cell from appending the same
// entry again. An unset PATH no longer yields a leading separator.
do
    let binPath = Path.GetFullPath("bin")
    let currentPath =
        match Environment.GetEnvironmentVariable("PATH") with
        | null -> ""
        | path -> path
    let alreadyPresent =
        currentPath.Split(Path.PathSeparator)
        |> Array.exists (fun entry -> String.Equals(entry, binPath, StringComparison.Ordinal))
    if not alreadyPresent then
        let updatedPath =
            if String.IsNullOrEmpty currentPath then binPath
            else sprintf "%s%c%s" currentPath Path.PathSeparator binPath
        Environment.SetEnvironmentVariable("PATH", updatedPath)

open CNTK
// Print the kind of device CNTK picks by default (the run below reports CPU).
printfn "Congratulations, you are using CNTK for: %A" (DeviceDescriptor.UseDefaultDevice().Type)

Congratulations, you are using CNTK for: CPU


In [51]:
// Shared settings for the model below: the CPU device, DataType.Float
// elements, and a Glorot-uniform initializer with scale 1.0.
let device = CNTK.DeviceDescriptor.CPUDevice
let dataType = CNTK.DataType.Float
let initialization = CNTKLib.GlorotUniformInitializer(1.0)

// Problem size: two input features, two output classes.
let input_dim = 2
let num_output_classes = 2

#load "fsx/CntkHelpers.fsx"
open CntkHelpers

// Graph inputs: a feature vector of length input_dim and a label vector of
// length num_output_classes.
let input =
    let featureShape = shape [| input_dim |]
    Variable.InputVariable(featureShape, dataType, "Features")
let label =
    let labelShape = shape [| num_output_classes |]
    Variable.InputVariable(labelShape, dataType, "Labels")

// Classifier built by fullyConnectedClassifierNet (CntkHelpers.fsx): one hidden
// layer of 5 units with sigmoid activation (matches the [5 x 2] weights in the
// diagram output below).
let z =
    let hiddenLayerSizes = [ 5 ]
    fullyConnectedClassifierNet input hiddenLayerSizes num_output_classes CNTKLib.Sigmoid

In [102]:
open System.Collections.Generic

// combine filters to single lambda
/// Combines two predicates into one that holds only when both hold.
/// `filter2` is evaluated only when `filter1` returns true.
let inline (<&>) (filter1 : 'T -> bool) (filter2 : 'T -> bool) =
    fun (t : 'T) -> if filter1 t then filter2 t else false

/// Collects every Function reachable from `root`, keyed by its Uid.
///
/// From each function the walk follows its Inputs, its Outputs and
/// (via the CntkHelpers `Var` helper) its RootFunction to the functions that
/// own those variables. Variables with a null Owner are skipped (presumably
/// graph inputs, parameters and constants — confirm). Entries are added in
/// depth-first pre-order, so `root` is always the first one.
///
/// An explicit stack replaces the earlier recursion so very deep graphs cannot
/// overflow the call stack. Children are pushed in reverse so the visiting
/// order matches the recursive version.
let decomposeFunction (root: Function) =
    let visited = Dictionary<string, Function>()
    let pending = Stack<Function>()
    pending.Push root

    while pending.Count > 0 do
        let current = pending.Pop()
        // A function can be pushed more than once (it may own several of the
        // variables seen so far); only its first pop expands it.
        if not (visited.ContainsKey current.Uid) then
            visited.Add(current.Uid, current)
            let owners =
                seq {
                    yield! current.Inputs
                    yield! current.Outputs
                    yield Var current.RootFunction
                }
                |> Seq.map (fun v -> v.Owner)
                |> Seq.filter (fun owner -> not (isNull owner) && not (visited.ContainsKey owner.Uid))
                |> Array.ofSeq
            // Reverse push: the first owner ends up on top and is expanded first.
            for i in owners.Length - 1 .. -1 .. 0 do
                pending.Push owners.[i]

    visited

In [103]:
// Every function reachable from the model z, keyed by Uid.
let decz = z |> decomposeFunction

In [105]:
// Uids of the collected functions, in the dictionary's enumeration order.
decz.Keys |> Array.ofSeq

[|"CompositeFunction1627"; "Plus1626"; "Times1623"; "StableSigmoid1618";
  "Plus1615"; "Times1612"|]

In [93]:
// https://github.com/Microsoft/CNTK/blob/master/bindings/python/cntk/logging/graph.py

/// Emits Graphviz DOT statements describing one CNTK function and the
/// variables directly attached to it (adapted from CNTK's Python
/// cntk/logging/graph.py).
///
/// The returned sequence, with duplicates removed, contains:
///   * a "root function" edge when `f` is not its own RootFunction,
///   * a shape and a label statement for every input and output variable of `f`,
///   * an edge between `f` and each of those variables,
///   * a label and a shape statement for `f`, its RootFunction and the
///     non-null owners of its variables.
/// The statements are not wrapped in `digraph { ... }`; callers do that.
let extractGraphVizDotNotation (f: Function) =
    // Backslashes and double quotes would otherwise end or corrupt the quoted DOT label.
    let escapeLabel (s: string) = s.Replace("\\", "\\\\").Replace("\"", "\\\"")

    let varText (v: Variable) =
        let name = if String.IsNullOrEmpty v.Name then v.Uid else v.Name
        // "\\n" is DOT's line-break escape inside a label.
        escapeLabel name + "\\n" + v.Shape.AsString()

    let funText (fn: Function) =
        escapeLabel (if String.IsNullOrEmpty fn.Name then fn.Uid else fn.Name)

    let varLabel (v: Variable) = sprintf "%s [label=\"%s\"];" v.Uid (varText v)
    let funLabel (fn: Function) = sprintf "%s [label=\"%s\"];" fn.Uid (funText fn)

    let varShape (v: Variable) =
        if v.IsInput then sprintf "%s [shape=invhouse, color=yellow];" v.Uid
        elif v.IsOutput then sprintf "%s [shape=invhouse, color=gray];" v.Uid
        elif v.IsPlaceholder then sprintf "%s [shape=invhouse, color=yellow];" v.Uid
        elif v.IsParameter then sprintf "%s [shape=diamond, color=green];" v.Uid
        elif v.IsConstant then sprintf "%s [shape=rectangle, color=lightblue];" v.Uid
        else sprintf "%s [shape=circle, color=purple];" v.Uid

    let funShape (fn: Function) =
        if fn.IsComposite then sprintf "%s [shape=ellipse, fontsize=20, penwidth=2, peripheries=2];" fn.Uid
        elif fn.IsPrimitive then sprintf "%s [shape=ellipse, fontsize=20, penwidth=2, size=0.6];" fn.Uid
        else sprintf "%s [shape=ellipse, fontsize=20, penwidth=4];" fn.Uid

    // Read Inputs/Outputs once. They are CNTK binding properties (presumably
    // backed by native calls), and the lookup sets are now built once per
    // function instead of once per variable.
    let inputs = Array.ofSeq f.Inputs
    let outputs = Array.ofSeq f.Outputs
    let vars = Array.append inputs outputs
    let inputUids = inputs |> Array.map (fun v -> v.Uid) |> set
    let outputUids = outputs |> Array.map (fun v -> v.Uid) |> set

    /// Edge between `f` and `v`: inputs point into `f`, outputs point out of it.
    let varEdge (v: Variable) =
        let isInput = inputUids.Contains v.Uid
        let isOutput = outputUids.Contains v.Uid
        if isInput && v.IsParameter then Some (sprintf "%s -> %s [label=\"input param\"];" v.Uid f.Uid)
        elif isOutput && v.IsParameter then Some (sprintf "%s -> %s [label=\"output param\"];" f.Uid v.Uid)
        elif isInput then Some (sprintf "%s -> %s [label=input];" v.Uid f.Uid)
        elif isOutput then Some (sprintf "%s -> %s [label=output];" f.Uid v.Uid)
        else None

    let funs =
        [|
            yield f
            yield f.RootFunction
            for v in vars do
                let owner = v.Owner
                if not (isNull owner) then yield owner
        |]

    seq {
        if f.Uid <> f.RootFunction.Uid then
            yield sprintf "%s -> %s [label=\"root function\"];" f.RootFunction.Uid f.Uid
        yield! vars |> Seq.map varShape
        yield! vars |> Seq.map varLabel
        yield! vars |> Seq.choose varEdge
        yield! funs |> Seq.map funLabel
        yield! funs |> Seq.map funShape
    }
    |> Seq.distinct


In [110]:
/// Returns the sorted, de-duplicated DOT statements for every non-composite
/// function reachable from `f`. Wrap the result in `digraph { ... }` to render it.
let createGraphVizDiagram (f: Function) =
    let nonCompositeFunctions =
        (decomposeFunction f).Values
        |> Seq.filter (fun fn -> not fn.IsComposite)
    nonCompositeFunctions
    |> Seq.collect extractGraphVizDotNotation
    |> Seq.distinct
    |> Seq.sort


createGraphVizDiagram z |> Array.ofSeq

[|"Input1608 -> Times1612 [label=input];"; "Input1608 [label="Features\n[2]"];";
  "Input1608 [shape=invhouse, color=yellow];";
  "Parameter1610 -> Times1612 [label="input param"];";
  "Parameter1610 [label="Weights\n[5 x 2]"];";
  "Parameter1610 [shape=diamond, color=green];";
  "Parameter1611 -> Plus1615 [label="input param"];";
  "Parameter1611 [label="Bias\n[5]"];";
  "Parameter1611 [shape=diamond, color=green];";
  "Parameter1621 -> Times1623 [label="input param"];";
  "Parameter1621 [label="Weights\n[2 x 5]"];";
  "Parameter1621 [shape=diamond, color=green];";
  "Parameter1622 -> Plus1626 [label="input param"];";
  "Parameter1622 [label="Bias\n[2]"];";
  "Parameter1622 [shape=diamond, color=green];";
  "Plus1615 -> Plus1615_Output_0 [label=output];"; "Plus1615 [label="Layer"];";
  "Plus1615 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];";
  "Plus1615_Output_0 -> StableSigmoid1618 [label=input];";
  "Plus1615_Output_0 [label="Layer\n[5]"];";
  "Plus1615_Output_0 [shape=invhou

In [53]:
// Load the d3-jupyter bundle (path relative to the notebook) into the page.
// NOTE(review): presumably provides d3 and d3-graphviz used by later cells — confirm.
Display (Util.Html @"<script src='../../d3-jupyter/dist/bundle.js'></script>")

In [54]:
/// Escapes `s` for embedding inside a single-quoted ('...') or template (`...`)
/// JavaScript string literal that sits within an HTML <script> element.
/// Without this, DOT's own "\n" escapes were consumed by JavaScript, raw
/// newlines broke single-quoted literals, and quotes, backticks, "${" or
/// "</script>" in the payload broke the generated script.
let escapeForJsString (s: string) =
    let sb = System.Text.StringBuilder(s.Length)
    for c in s do
        match c with
        | '\\' -> sb.Append("\\\\") |> ignore
        | '\'' -> sb.Append("\\'") |> ignore
        | '"' -> sb.Append("\\\"") |> ignore
        | '`' -> sb.Append("\\`") |> ignore
        | '$' -> sb.Append("\\$") |> ignore
        | '\n' -> sb.Append("\\n") |> ignore
        | '\r' -> sb.Append("\\r") |> ignore
        | '<' -> sb.Append("\\x3C") |> ignore // keeps "</script>" from closing the tag
        | '\u2028' -> sb.Append("\\u2028") |> ignore
        | '\u2029' -> sb.Append("\\u2029") |> ignore
        | other -> sb.Append(other) |> ignore
    sb.ToString()

/// Fires the page-level INIT_D3 event with the given selector/path argument.
let initGraph jqueryPath =
    sprintf "<script>$(document).trigger('INIT_D3', ['%s']);</script>" (escapeForJsString jqueryPath)
    |> Util.Html

/// Fires RENDER_GRAPH with one DOT document passed as a JS template literal.
let renderDot dotNotation =
    sprintf "<script>$(document).trigger('RENDER_GRAPH', [`%s`]);</script>" (escapeForJsString dotNotation)
    |> Util.Html

/// Fires RENDER_SERIES with an array of DOT documents. An empty input yields
/// an empty series instead of throwing.
let renderSeries (digraphs : string[]) =
    digraphs
    |> Array.map (fun digraph -> "'" + escapeForJsString digraph + "'")
    |> String.concat ","
    |> sprintf "<script>$(document).trigger('RENDER_SERIES', [[%s]]);</script>"
    |> Util.Html

initGraph "#graph"

In [111]:
// Render model z as a single Graphviz digraph. String.concat replaces
// Seq.reduce (sprintf ...): it avoids quadratic re-concatenation and does not
// throw when the diagram is empty.
createGraphVizDiagram z
|> String.concat "\n"
|> sprintf "digraph { %s }"
|> renderDot

<div id="graph" style="width: 100%;height:800px; border: solid lightblue 1px"></div>

In [None]:
// Scratch cell: renders a hard-coded DOT graph (its Uids look copied from an
// earlier run) into #graph with d3-graphviz. Once rendering ends, it adds a
// drop-shadow SVG filter and applies it to nodes on mouseover.
// NOTE(review): assumes d3 and d3-graphviz were loaded by the bundle.js cell —
// confirm. d3.select("svg") picks the first <svg> on the whole page, not
// necessarily the one inside #graph.
"""<script>
        var dot = `
  digraph {
    Combine1787 -> CompositeFunction1788 [label="root function"];
Combine1787 [label="Combine1787"];
Combine1787 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1789 -> CompositeFunction1790 [label="root function"];
Combine1789 [label="Combine1789"];
Combine1789 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1791 -> CompositeFunction1792 [label="root function"];
Combine1791 [label="Combine1791"];
Combine1791 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1793 -> CompositeFunction1794 [label="root function"];
Combine1793 [label="Combine1793"];
Combine1793 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1795 -> CompositeFunction1796 [label="root function"];
Combine1795 [label="Combine1795"];
Combine1795 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1797 -> CompositeFunction1798 [label="root function"];
Combine1797 [label="Combine1797"];
Combine1797 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1799 -> CompositeFunction1800 [label="root function"];
Combine1799 [label="Combine1799"];
Combine1799 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1801 -> CompositeFunction1802 [label="root function"];
Combine1801 [label="Combine1801"];
Combine1801 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
Combine1803 -> CompositeFunction1804 [label="root function"];
Combine1803 [label="Combine1803"];
Combine1803 [shape=ellipse, fontsize=20, penwidth=2, size=0.6];
  }`;
        var viz = d3.select("#graph")
            .graphviz()
            //.logEvents(true)
            .on("initEnd", function () {

                //viz.engine("circo");
                viz.renderDot(dot)
                    .on("end", function () {

                        let svg = d3.select("svg");//.attr("width", "1750pt").attr("height","1000pt");

                        let defs = svg.append("defs");

                        let filter = defs
                            .append("filter")
                            .attr("id", "shadow")
                            .attr("x", "-50%")
                            .attr("y", "-50%")
                            .attr("width", "200%")
                            .attr("height", "200%");
                        filter.append("feGaussianBlur")
                            .attr("in", "SourceAlpha")
                            .attr("stdDeviation", 3)
                            .attr("result", "blur");
                        filter
                            .append("feOffset")
                            .attr("in", "blur")
                            .attr("dx", 3)
                            .attr("dy", 3);
                        filter.append("feComponentTransfer")
                            .append("feFuncA").attr("type", "linear")
                            .attr("slope", 0.35);

                        var merge = filter.append("feMerge");
                        merge.append("feMergeNode");
                        merge.append("feMergeNode").attr("in", "SourceGraphic");

                        d3.selectAll(".node ellipse, .node polygon")
                            .style("fill", "white")
                            .on("mouseover", (d, i, n) => d3.select(n[i]).style("filter", "url(#shadow)"))
                            .on("mouseout", (d, i, n) => d3.select(n[i]).style("filter", null));
                    });
            });
    </script>""" |> Util.Html |> Display

In [None]:
// Scratch cell: checks that d3 can draw into #graph by appending a 500x500
// SVG containing a red circle of radius 50 at its centre.
// The HTML value is returned rather than passed to Display; presumably the
// notebook renders it as the cell result — confirm.
"""<script>
        var canvas = d3.select("#graph")
                        .append("svg")
                        .attr("width", 500)
                        .attr("height", 500);
        var circle = canvas.append("circle")
                        .attr("cx",250)
                        .attr("cy", 250)
                        .attr("r", 50)
                        .attr("fill", "red");
    </script>
""" |> Util.Html

In [67]:
// Compare three Uids: the model function z, the variable the CntkHelpers
// `Var` helper returns for it, and a function rebuilt from that variable.
// In the run below the rebuilt function has a different CompositeFunction Uid
// than z.
z.Uid, (Var z).Uid, (Var z).ToFunction().Uid

("CompositeFunction1627", "Plus1626_Output_0", "CompositeFunction1846")

In [71]:
// Converting the "Features" input variable to a function also produces a
// CompositeFunction (see the Uid in the output below).
input.ToFunction().Uid

"CompositeFunction1899"