pretrain affine Q network

In [3]:
using Pkg
Pkg.activate("/home/jiaxingl/project/verify_julia_env")
Pkg.status()

[32m[1m  Activating[22m[39m project at `~/project/verify_julia_env`


[32m[1mStatus[22m[39m `~/project/verify_julia_env/Project.toml`
  [90m[d8c2afa5] [39mCersyve v1.0.0-DEV `~/project/Cersyve.jl`
[33m⌅[39m [90m[587475ba] [39mFlux v0.13.17
  [90m[f67ccb44] [39mHDF5 v0.17.2
  [90m[7073ff75] [39mIJulia v1.26.0
[33m⌅[39m [90m[033835bb] [39mJLD2 v0.4.53
  [90m[6d061d49] [39mModelVerification v0.1.0 `/home/jiaxingl/project/ModelVerification.jl#cersyve`
  [90m[85610aed] [39mNaiveNASflux v2.0.8 `~/project/ModelVerification.jl/onnx_parser/NaiveNASflux`
[32m⌃[39m [90m[bd45eb3e] [39mNaiveNASlib v2.0.11
[32m⌃[39m [90m[d0dd6a25] [39mONNX v0.2.0
  [90m[2e935253] [39mONNXNaiveNASflux v0.2.7 `~/project/ModelVerification.jl/onnx_parser/ONNXNaiveNASflux`
  [90m[438e738f] [39mPyCall v1.96.4
  [90m[9a3f8284] [39mRandom
[36m[1mInfo[22m[39m Packages marked with [32m⌃[39m and [33m⌅[39m have new versions available, but those with [33m⌅[39m are restricted by compatibility constraints from upgrading. To see why use `status --outdated

In [5]:
using Cersyve
using Flux
using JLD2
using Random

struct FilterX
    W::Matrix  # Weight matrix
end

struct FilterU
    W::Matrix  # Weight matrix
end

function (layer::FilterX)(input::Matrix{Float32})
    return layer.W * input
end

Flux.@functor FilterX  # Make the layer compatible with Flux
function Flux.params(layer::FilterX)
    return Flux.Params([])  # Exclude weights from being trainable
end

# Define a filtering layer for extracting u (indices 9 to 14)


function (layer::FilterU)(input::Matrix{Float32})
    return layer.W * input
end

Flux.@functor FilterU  # Make the layer compatible with Flux
function Flux.params(layer::FilterU)
    return Flux.Params([])  # Exclude weights from being trainable
end

# Initialize the fixed weight matrices for filtering
function create_filter_matrix(start_idx, end_idx, total_len)
    W = zeros(end_idx - start_idx + 1, total_len)
    for i in start_idx:end_idx
        W[i - start_idx + 1, i] = 1.0
    end
    return W
end

function create_cat_affine_Q(x_dim, u_dim)
    x_filter = FilterX(create_filter_matrix(1, x_dim, x_dim+u_dim))  # Extract x
    u_filter = FilterU(create_filter_matrix(x_dim, x_dim+u_dim-1, x_dim+u_dim))  # Extract u

    # Define the branch1 network (process x)
    branch1 = Chain(
        Dense(x_dim, 32, relu),  # First hidden layer (32 neurons, input size is 8 for x)
        Dense(32, 32, relu)  # Second hidden layer (32 neurons)
    )

    # Define the final output layer (scalar output)
    final_layer = Chain(Dense(32 + u_dim, 1))  # Concatenation of x (32) and u (6)

    # Complete model
    model = Chain(
        x -> (x_filter(x), u_filter(x)),  # Apply the filters to extract x and u
        x -> (branch1(x[1]), x[2]),       # Process x through branch1, keep u unchanged
        x -> vcat(x[1], x[2]),            # Concatenate outputs of branch1 and u
        final_layer                       # Compute scalar output
    )
    return model
end

function create_add_affine_Q(x_dim, u_dim)
    # Assume the input has 13 elements: x (0–7), u (8–13)
    x_filter = FilterX(create_filter_matrix(1, x_dim, x_dim+u_dim))  # Extract x
    u_filter = FilterU(create_filter_matrix(x_dim, x_dim+u_dim-1, x_dim+u_dim))  # Extract u

    b1_l1 = Dense(x_dim, 32, relu)
    b1_l2 = Dense(32, u_dim, relu)

    # Define the final output layer (scalar output)
    final_layer = Chain(Dense(u_dim, 1))  # Concatenation of x (32) and u (6)

    # Complete model
    model = Chain(
        x -> (x_filter(x), u_filter(x)),  # Apply the filters to extract x and u
        x1 -> (branch1(x[1])),       # Process x through branch1, keep u unchanged
        x -> (x1 + x[2]),               # add outputs of branch1 and u
        final_layer                       # Compute scalar output
    )
    return model
end

function create_parallel_affine_Q(x_dim, u_dim)
    # Assume the input has 13 elements: x (0–7), u (8–13)
    x_w = create_filter_matrix(1, x_dim, x_dim+u_dim)
    x_b = zeros(x_dim)
    filter_x = Dense(x_w, x_b)

    u_w = create_filter_matrix(x_dim, x_dim+u_dim-1, x_dim+u_dim)
    u_b = zeros(u_dim)
    filter_u = Dense(u_w, u_b)
    
    #Branch1
    b1 = Chain(
        filter_x,  # First hidden layer (32 neurons, input size is 8 for x)
        Dense(x_dim, 32, relu),  # First hidden layer (32 neurons, input size is 8 for x)
        Dense(32, 32, relu)  # Second hidden layer (32 neurons)
    )

    # Define the final output layer (scalar output)
    final_layer = Chain(Dense(32 + u_dim, 1))  # Concatenation of x (32) and u (6)

    # Complete model
    model = Chain(
        Parallel(
            vcat, 
            b1,
            filter_u
        ),
        final_layer                       # Compute scalar output
    )
    return model
end

task = Unicycle
value_hidden_sizes = [32, 32]
dynamics_hidden_sizes = [32, 32]
constraint_hidden_sizes = [16]
data_path = joinpath(@__DIR__, "../data/unicycle_data.jld2")
model_dir = joinpath(@__DIR__, "../model/unicycle/")
log_dir = joinpath(@__DIR__, "../log/unicycle/")
seed = 1

Random.seed!(seed)

# V_model = Cersyve.create_mlp(task.x_dim, 1, value_hidden_sizes)
# Q_model = Cersyve.create_mlp(task.x_dim + task.u_dim, 1, value_hidden_sizes)


data = JLD2.load(data_path)["data"]
f_model = Cersyve.create_mlp(task.x_dim + task.u_dim, task.x_dim, dynamics_hidden_sizes)
Flux.loadmodel!(f_model, JLD2.load(joinpath(model_dir, "f.jld2"), "state"))
f_pi_model = Cersyve.create_closed_loop_dynamics_model(
    f_model, task.pi_model, data, task.x_low, task.x_high, task.u_dim)

h_model = Cersyve.create_mlp(task.x_dim, 1, constraint_hidden_sizes)
Flux.loadmodel!(h_model, JLD2.load(joinpath(model_dir, "h.jld2"), "state"))

x_a_low =  [task.x_low; task.u_low]
x_a_high = [task.x_high; task.u_high]


affine_Q = create_parallel_affine_Q(task.x_dim, task.u_dim)

# trainable parameters
# println(affine_Q[1][1][2])
# println(affine_Q[1][1][3])
# println(affine_Q[2])

# pretrain_Q(
#     affine_Q,
#     f_pi_model,
#     task.pi_model,
#     h_model,
#     task.x_low,
#     task.x_high;
#     penalty="APA",
#     space_size=x_a_high - x_a_low,
#     apa_coef=1e-4,
#     log_dir=log_dir,
# )




Chain(
  Parallel(
    vcat,
    Chain(
      Dense(5 => 3),                    [90m# 18 parameters[39m
      Dense(3 => 32, relu),             [90m# 128 parameters[39m
      Dense(32 => 32, relu),            [90m# 1_056 parameters[39m
    ),
    Dense(5 => 2),                      [90m# 12 parameters[39m
  ),
  Chain(
    Dense(34 => 1),                     [90m# 35 parameters[39m
  ),
) [90m                  # Total: 10 arrays, [39m1_249 parameters, 5.871 KiB.

In [7]:
function find_min_at_vertices(Q_model, x::Vector{Float32}, u_low::Vector{Float32}, u_high::Vector{Float32}, x_dim::Int)
    # Step 1: Compute vertices
    vertices = compute_vertices(u_low, u_high)  # Each column is a vertex

    # Step 2: Evaluate Q_model at each vertex
    min_value = Inf
    min_vertex = nothing
    for i in 1:size(vertices, 2)
        u = vertices[:, i]             # Extract the i-th vertex (u part)
        xu = copy(x)                   # Clone x
        xu[x_dim+1:end] .= u           # Replace u portion in x
        Q_val = Q_model(xu)[1]         # Evaluate Q_model (assume scalar output)
        # Step 3: Check if this is the minimum
        if Q_val < min_value
            min_value = Q_val
            min_vertex = u
        end
    end

    return min_value, min_vertex
end


function compute_vertices(u_low::Vector{Float32}, u_high::Vector{Float32})
    n = length(u_low)  # Dimension of the space
    vertices = []      # Initialize an empty array to store vertices

    # Iterate over all 2^n combinations
    for i in 0:(2^n - 1)
        vertex = Float32[]  # Initialize a vertex
        for j in 1:n
            # Check the j-th bit of i to decide low or high
            if (i >> (j - 1)) & 1 == 0
                push!(vertex, u_low[j])
            else
                push!(vertex, u_high[j])
            end
        end
        push!(vertices, vertex)  # Add the vertex to the list
    end

    return reduce(hcat, vertices)  # Return vertices as a matrix (each column is a vertex)
end

x = Float32.(rand(task.x_dim+task.u_dim))
min_value, min_vertex = find_min_at_vertices(affine_Q, x, task.u_low, task.u_high, task.x_dim)
println(min_value, min_vertex)


[33m[1m│ [22m[39m  The input will be converted, but any earlier layers may be very slow.
[33m[1m│ [22m[39m  layer = Dense(3 => 32, relu)  [90m# 128 parameters[39m
[33m[1m│ [22m[39m  summary(x) = "3-element Vector{Float64}"
[33m[1m└ [22m[39m[90m@ Flux ~/.julia/packages/Flux/n3cOc/src/layers/stateless.jl:60[39m


-0.20661746Float32[-1.0, -1.0]
