# Inverse graphics using graphics-based generative models and MCMC

### Installation:
Don't run this code (it may not even work for you; there may be permission issues).
It is included only to show how we installed Mitsuba 3 in our environment,
namely through a Conda environment.

```julia
using Pkg
Pkg.activate("myenv")
Pkg.add("Conda")
using Conda
Conda.pip("install", "mitsuba")

```

In [1]:
using Pkg
Pkg.activate("myenv")
using Distributions
using ProgressMeter
using Gen, Plots
using Parameters
using PyCall
numpy = pyimport("numpy")
mi = pyimport("mitsuba")
mi.set_variant("scalar_rgb")
@pyinclude("./cbox-generic.py")

[32m[1m  Activating[22m[39m project at `~/Algorithms-of-the-Mind/labs/lab-05/myenv`


# Brief Mitsuba3 Introduction

Mitsuba 3 is a physically-based graphics renderer. This means that instead of taking some of the common shortcuts (such as rasterization), it more faithfully captures the physical process by which light bounces off surfaces and creates images. This process is called ray tracing: rays are simulated from the light sources in the scene, bounce off the surfaces, and are finally sensed by a sensor (e.g., an eye or a camera) placed somewhere in the scene.

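For our purposes, Mitsuba 3 serves as the forward model inside our generative model: given scene parameters (here, the scales of two objects), it renders the image we would expect to observe.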

In [2]:
# Use the teapot and transform it in various ways.
# NOTE: `ModelParams` and `mitsuba_render_2` are defined in a later cell. Run that
# cell first, otherwise this one fails with `UndefVarError` (as recorded below).

# Re-include the Python helper that provides `initialize_scene()` (the setup cell
# already includes it once).
@pyinclude("./cbox-generic.py")
# The `o` suffix keeps the result as a raw PyObject instead of converting it.
scene_d = py"initialize_scene()"o
# Load a scene with a teapot.
scene = mi.load_dict(scene_d)

# Read a data structure of what is in this scene (kept only for inspection).
params = @pycall mi.traverse(scene)::PyObject

modelparams = ModelParams(scene = scene)
mitsuba_render_2(modelparams)

# Review this output and the xml file itself to see which objects and parameters the
# scene contains.

LoadError: UndefVarError: `ModelParams` not defined

# Inverse graphics

We start by writing our generative model, which in this case wraps a graphics engine (Mitsuba 3).

In [2]:
# Rendering configuration consumed by `mitsuba_render` / `mitsuba_render_2`.
# Built with Parameters.jl's keyword constructor, e.g. `ModelParams(scene = s)` or
# `ModelParams(path = "./scenes/cbox.xml", spp = 64)`.
@with_kw struct ModelParams
    # scene file that the default `scene` is loaded from
    path::String = "./cbox_generic.xml"
    # loaded Mitsuba scene; this default is evaluated from `path` only when no
    # `scene` keyword is supplied
    scene::PyObject = @pycall mi.load_file(path)::PyObject
    # disabled: no `params` (scene-parameter) field is stored in this struct
    #params::PyObject = @pycall mi.traverse(scene)::PyObject
    # samples per pixel passed to `mi.render`
    spp::Int32 = 16
end

"""
    mitsuba_transform(modelparams, object, scale, translation)

Set the scene parameter `object` (a key of `mi.traverse(scene)`, e.g.
`"left-object.to_world"`) to `mi.Transform4f.translate(translation).scale(scale)`,
then call `update()` on the parameter set so the change reaches the scene.

The parameter set is `modelparams.params` when `modelparams` has that field;
otherwise it is obtained with `mi.traverse(modelparams.scene)`. Reading
`modelparams.params` unconditionally fails for a `ModelParams` without that field.

NOTE(review): not every Mitsuba shape may expose `to_world` as an updatable
parameter; confirm before relying on this for mesh objects.
"""
function mitsuba_transform(modelparams, object, scale, translation)
    params = if hasproperty(modelparams, :params)
        modelparams.params
    else
        @pycall mi.traverse(modelparams.scene)::PyObject
    end
    t = mi.Transform4f.translate(translation).scale(scale)
    result = set!(params, object, t)
    # Without `update()`, the edited value stays in the parameter table only.
    params.update()
    return result
end

"""
    mitsuba_transform_2(scene_d, object, scale, translation)

Set `scene_d[object]["to_world"]` to
`mi.Transform4f.translate(translation).scale(scale)`.

`scene_d` is a Python scene dictionary held as a raw `PyObject` (e.g. from
`py"initialize_scene()"o`) that is later passed to `mi.load_dict`. The nested
object dictionary is modified in place.
"""
function mitsuba_transform_2(scene_d, object, scale, translation)
    t = mi.Transform4f.translate(translation).scale(scale)
    # Request a raw PyObject. The default `get(scene_d, object)` converts the nested
    # Python dict into a Julia Dict (recursively converting its values), and wrapping
    # that in `PyObject` builds a *new* Python dict whose values may no longer have
    # their original Python types.
    object_dict = get(scene_d, PyObject, object)
    set!(object_dict, "to_world", t)
    # `object_dict` already aliases the entry in `scene_d`; storing it back is a
    # harmless no-op that keeps the previous return value.
    return set!(scene_d, object, object_dict)
end

"""
    mitsuba_render(modelparams)

Render `modelparams.scene` with `modelparams.spp` samples per pixel, apply sRGB
gamma conversion, and return the pixels as a 3-dimensional `Array{Float32, 3}`.
"""
function mitsuba_render(modelparams)
    # Same render + sRGB conversion as `mitsuba_render_2`, then copy into Julia.
    srgb_bitmap = mitsuba_render_2(modelparams)
    return @pycall numpy.array(srgb_bitmap)::Array{Float32, 3}
end

"""
    mitsuba_render_2(modelparams)

Render `modelparams.scene` with `modelparams.spp` samples per pixel and return
the result as a Mitsuba `Bitmap` (a `PyObject`) converted with sRGB gamma.
The pixels are not copied into a Julia array, so the bitmap can be displayed
directly in the notebook.
"""
function mitsuba_render_2(modelparams)
    raw_image = @pycall mi.render(modelparams.scene, spp=modelparams.spp)::PyObject
    return @pycall mi.Bitmap(raw_image).convert(srgb_gamma=true)::PyObject
end

mitsuba_render_2 (generic function with 1 method)

In [3]:
# Generative model of the scene with two objects of unknown scale.
#
# Random choices:
#   :s_left  ~ uniform(0.1, 1.0)          scale of "left-object"
#   :s_right ~ uniform(0.1, 1.0)          scale of "right-object"
#   :pred    ~ broadcasted_normal(mu, 1)  observed image, where `mu` is the Mitsuba
#                                         rendering with both scales applied
# Returns the value of `:pred`.
@gen function room()
    scene_dict = py"initialize_scene()"o::PyObject

    # Object positions are assumed known a priori; only their scales are inferred.
    left_position = [-0.3, -0.5, 0.2]
    right_position = [0.5, -0.75, -0.2]

    # prior over the scale of the left sphere
    s_left ~ uniform(0.1, 1.0)
    mitsuba_transform_2(scene_dict, "left-object", s_left, left_position)

    # prior over the scale of the right sphere
    s_right ~ uniform(0.1, 1.0)
    mitsuba_transform_2(scene_dict, "right-object", s_right, right_position)

    # Load the edited dictionary and render it into a Julia array.
    render_params = ModelParams(scene = mi.load_dict(scene_dict))
    rendered = mitsuba_render(render_params)

    # Likelihood: every pixel is the rendering plus Gaussian noise (std 1).
    pred ~ broadcasted_normal(rendered, 1)
    return pred
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##room#292", Bool[], false)

In [5]:
# Build and load the default scene outside of `room`, and wrap it in `ModelParams`
# to inspect the render settings (shown below).
scene_d = py"initialize_scene()"o
scene = mi.load_dict(scene_d)
modelparams = ModelParams(scene=scene)

ModelParams
  path: String "./cbox_generic.xml"
  scene: PyObject
  spp: Int32 16


In [6]:
# Draw one unconstrained sample from the generative model. The return value is the
# noisy image `:pred` (the rendering plus Gaussian noise with std 1).
trace = Gen.simulate(room, ())
pred = Gen.get_retval(trace)

128×128×3 Array{Float64, 3}:
[:, :, 1] =
 -0.507486    -2.35315     0.225231  …   1.52657     0.286634   -0.740246
  0.600153    -0.361887   -0.676484      0.563702    1.31995     0.47113
 -1.44814     -1.1177     -0.204935      2.01847    -0.326344   -2.15873
 -0.349315     0.59885    -0.1675        0.816439   -1.01983     0.848047
  0.189625     0.237773   -0.898043     -0.284799    0.109829    0.631023
 -0.658513     1.2538     -0.294827  …  -1.52137    -1.50899    -2.53637
  0.248365     0.223267    0.230716      2.42842     0.293488   -1.01267
 -0.845496    -0.172073    1.46944      -0.0764017   0.164691   -0.198467
 -1.66382     -0.576817   -0.351042      0.236869   -1.95389     0.615337
 -0.0445027    0.437693   -0.388322     -0.513755   -0.0780845   0.854136
 -0.789243     0.137362    1.0328    …   0.437443    0.77958    -0.526408
  0.743617     0.933641   -0.386164      0.630846   -0.289547    1.16369
  2.09444      0.0633071   0.048906     -1.05585     0.313983   -0.825128
  

In [7]:
# Show the random choices recorded in the trace (:s_left, :s_right, :pred).
get_choices(trace)

│
├── :s_left : 0.5226638946376776
│
├── :pred : [-0.5074856129595317 -2.353152331783289 0.22523059161994655 1.4336897526315096 0.47081704593653373 0.5706051399214431 0.2874271218173733 -1.5947563878465667 -0.6818109467734447 0.8919884969694771 2.051640258182933 1.39267625650542 -1.1044746888189774 -1.1428022066902352 -1.1544362676006494 0.6286689917144408 -0.5975377566408502 0.3898063019989889 0.6144240914334153 -0.3086350890694263 0.5618853301369868 0.12117416799932429 -1.4120214637815 1.381328993495927 -1.3931442972654149 -3.645129540922589 0.30906269537549125 0.05215264805819646 1.5235994182187342 0.014052931983433006 0.24317833431904326 1.1900325608959164 -0.3854700751869542 0.9050157696398045 0.7563104681540379 0.05077839167851651 1.9762278682924075 -1.1153389893770995 1.8294423341382637 -0.5965731047786381 -0.3705930060940722 -0.8354472486538713 -1.6878689481827325 -1.0296943137003292 -2.0106142981974164 1.7735762285436651 -1.5705040128011563 -1.459281055171917 -0.28905456161254

### Now let's make an observed image using an entry in the Mitsuba gallery

In [4]:
# Render the gallery scene at 64 spp and use its pixels as the observation for `:pred`.
# NOTE(review): this image must have the same size as the rendering produced inside
# `room`, or the likelihood cannot be compared pixel-by-pixel. Confirm that both
# scenes use the same film resolution.
obs_scene = ModelParams(path = "./scenes/cbox.xml", spp=64)
obs_bitmap = mitsuba_render_2(obs_scene)
obs_image = Gen.choicemap()
# constrain the model's `:pred` address to the observed pixels
obs_image[:pred] = @pycall numpy.array(obs_bitmap)::Array{Float32, 3}
# to view the image on your notebook
obs_bitmap

### Inference: Random walk MH

In [5]:
# include the truncated norm distribution
include("truncatednorm.jl")

In [6]:
# proposal distribution for the scale variables
# Random-walk proposal over both scale variables: each new scale is drawn from a
# normal (std 0.2) centred on the current value, truncated to [0.1, 1.0], the same
# interval as the uniform priors in `room`.
# Exercise: why do we need a truncated norm, instead of a regular normal distribution?
@gen function scale_proposal(current_trace)
    sigma, lower, upper = 0.2, 0.1, 1.0
    # trunc_norm(mean, std, lower_bound, upper_bound)
    s_left ~ trunc_norm(current_trace[:s_left], sigma, lower, upper)
    s_right ~ trunc_norm(current_trace[:s_right], sigma, lower, upper)
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##scale_proposal#293", Bool[0], false)

In [7]:
"""
    random_walk_mh(tr)

Perform one Metropolis-Hastings step on `tr`, using `scale_proposal` as the
random-walk proposal over the scale variables, and return the resulting trace
(Gen's `mh` returns the original trace when the move is rejected).
"""
function random_walk_mh(tr)
    next_trace, _accepted = mh(tr, scale_proposal, ())
    return next_trace
end

random_walk_mh (generic function with 1 method)

In [16]:
"""
    do_inference(K = 200; observations = obs_image, verbose = true)

Run `K` random-walk MH steps on `room`, starting from a trace generated under the
constraints in `observations` (by default the global `obs_image` choicemap).

Returns `(scores, trace)`: `scores[i]` is `get_score` of the trace after step `i`,
and `trace` is the final trace. When `verbose` is true, the score history is
printed, as in the original cell.

Throws `ArgumentError` if `K` is negative.
"""
function do_inference(K::Integer = 200; observations = obs_image, verbose::Bool = true)
    K >= 0 || throw(ArgumentError("K must be non-negative, got $K"))
    trace, = generate(room, (), observations)
    scores = Vector{Float64}(undef, K)
    @showprogress for i in 1:K
        trace = random_walk_mh(trace)
        scores[i] = get_score(trace)
    end
    verbose && println(scores)
    return scores, trace
end

do_inference (generic function with 1 method)

In [None]:
# Run MCMC against the observed image; returns the per-step scores and the final trace.
(scores, t) = do_inference()

[32mProgress:   4%|█▋                                       |  ETA: 0:00:59[39m

In [15]:
# Re-render the scene with the scales from the final MCMC trace `t`, for visual
# comparison with the observed image.
# NOTE(review): these translations duplicate the ones hard-coded in `room`; keep them
# in sync if the model changes.
scene_d = py"initialize_scene()"o::PyObject
t_left = [-0.3, -0.5, 0.2]
mitsuba_transform_2(scene_d, "left-object", t[:s_left], t_left)
    
t_right = [0.5, -0.75, -0.2] 
mitsuba_transform_2(scene_d, "right-object", t[:s_right], t_right)
scene = mi.load_dict(scene_d)
modelparams = ModelParams(scene=scene)
bitmap = mitsuba_render_2(modelparams)
