In [1]:
# Full imports
import gym
import cv2

# Partial imports 
from tqdm.notebook import tqdm, trange
from telesketch.envs.discrete_telesketch import DiscreteTelesketchEnv
from typing import Any, List, Sequence, Tuple
from numpy import linalg as LA
""
# Aliased imports
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_probability as tfp
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

2023-02-20 14:23:42.619895: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-20 14:23:42.710031: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-02-20 14:23:42.729446: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-02-20 14:23:43.050839: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: li

In [2]:
# --- Reference canvas --------------------------------------------------------
# Side length of the square canvas and of the square patch handed to the env.
IMG_SIZE = 64
PATCH_SIZE = 8

# White RGB canvas: every channel of every pixel starts at 255 (uint8).
ref_canvas = np.full(
    shape=(IMG_SIZE, IMG_SIZE, 3),
    fill_value=255,
    dtype=np.uint8,
)

# Draw a black stroke, 2 px thick, from the top-left corner toward (IMG_SIZE, IMG_SIZE).
# That end point is one past the last pixel index; OpenCV's `line` documentation states
# segments are clipped to the image boundaries.
ref_canvas = cv2.line(
    ref_canvas,
    (0, 0),
    (IMG_SIZE, IMG_SIZE),
    (0, 0, 0),
    2,
)

# Sim func
def rmse_sim(x, y):
    """Mean squared pixel difference between two images.

    Despite the name, this returns the *mean squared* error (no square root).
    The original scale is kept so the similarity values fed to the env do not change.
    TODO: rename or add the square root if a true RMSE is wanted.

    Both inputs are cast to float64 before subtracting. The canvases in this
    notebook are uint8 (``ref_canvas`` is built with ``dtype=np.uint8``), and
    uint8 arithmetic wraps modulo 256. Without the cast, ``x - y`` and ``** 2``
    would silently overflow and produce a meaningless error value.

    Args:
        x: first image (any array-like; typically an H x W x C array).
        y: second image, same shape as ``x`` (or broadcast-compatible).

    Returns:
        The sum of squared element-wise differences divided by the number of
        elements in ``x``.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.sum(diff ** 2) / np.size(x)

# Discrete telesketch environment built around the reference canvas,
# using rmse_sim as its similarity function and square PATCH_SIZE patches.
# NOTE(review): the positional arguments 4 and 2 are interpreted by
# DiscreteTelesketchEnv, whose signature is not visible here; confirm their
# meaning and consider passing them by keyword.
env = DiscreteTelesketchEnv(
    ref_canvas,
    rmse_sim,
    4,
    2,
    patch_size=(PATCH_SIZE,) * 2,
    render_mode="image",
)

In [10]:
# Input shapes: the env's observation shapes with a trailing single-channel axis appended.
# NOTE(review): the dmap block reuses the "cnv" shape; confirm the dmap observation
# really has the same shape as the canvas.
global_input = env.observation_space["cnv"].shape + (1,)
local_input = env.observation_space["cnv_patch"].shape + (1,)

# Width of the actor head (one output per action).
# TODO: confirm this equals the env's discrete action count rather than hard-coding it.
N_ACTIONS = 8


def make_global_block(input_shape: tuple) -> keras.Sequential:
    """Build a conv feature tower for a full-size (global) image input.

    Layers: Conv2D(32, 8x8, stride 4) -> LeakyReLU(0.01) ->
    Conv2D(64, 4x4, stride 2) -> LeakyReLU(0.01) -> Conv2D(64, 3x3) -> Flatten.
    Each call creates new layers, so every tower has its own (unshared) weights.

    Args:
        input_shape: shape of one sample, including the channel axis.

    Returns:
        A Sequential model mapping the image to a flat feature vector.
    """
    return keras.Sequential([
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, (8, 8), (4, 4)),
        keras.layers.LeakyReLU(alpha=0.01),
        keras.layers.Conv2D(64, (4, 4), (2, 2)),
        keras.layers.LeakyReLU(alpha=0.01),
        # NOTE(review): no activation after this conv (the local blocks have one);
        # its output feeds linear Dense heads. Confirm this is intended.
        keras.layers.Conv2D(64, (3, 3)),
        keras.layers.Flatten()
    ])


def make_local_block(input_shape: tuple) -> keras.Sequential:
    """Build a conv feature tower for a patch-sized (local) image input.

    Layers: Conv2D(64, 8x8) -> LeakyReLU(0.01) -> Flatten. The 8x8 kernel
    matches PATCH_SIZE = 8, so the conv collapses each patch to a single position.
    Each call creates new layers, so every tower has its own (unshared) weights.

    Args:
        input_shape: shape of one patch sample, including the channel axis.

    Returns:
        A Sequential model mapping the patch to a flat feature vector.
    """
    return keras.Sequential([
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(64, (8, 8)),
        keras.layers.LeakyReLU(alpha=0.01),
        keras.layers.Flatten()
    ])


# Towers are created in the same order as before, so auto-generated layer names line up.
dmap_block = make_global_block(global_input)
global_ref_block = make_global_block(global_input)
global_cnv_block = make_global_block(global_input)
# These two are feature blocks, not Input layers, despite the `_input` suffix.
# The names are kept because other cells may reference them.
local_ref_input = make_local_block(local_input)
local_cnv_input = make_local_block(local_input)

# Order defines both the concatenation order and the model's input order.
feature_blocks = [
    dmap_block,
    global_ref_block,
    global_cnv_block,
    local_ref_input,
    local_cnv_input,
]

block_concat = keras.layers.Concatenate()([block.output for block in feature_blocks])

# Shared trunk with two linear heads: actor (one output per action) and critic (scalar).
actor_out = keras.layers.Dense(N_ACTIONS)(block_concat)
critic_out = keras.layers.Dense(1)(block_concat)

model = keras.Model(
    [block.input for block in feature_blocks],
    [actor_out, critic_out]
)

model.summary()

Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_26 (InputLayer)          [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 input_27 (InputLayer)          [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 input_28 (InputLayer)          [(None, 64, 64, 1)]  0           []                               
                                                                                                  
 conv2d_37 (Conv2D)             (None, 15, 15, 32)   2080        ['input_26[0][0]']               
                                                                                            