Skip to content

Commit

Permalink
Merge pull request #22 from UM-Bridge/main
Browse files Browse the repository at this point in the history
Pull request to update my branch
  • Loading branch information
linusseelinger committed Sep 18, 2023
2 parents 0d4a708 + 9129da6 commit dc80bd6
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 61 deletions.
15 changes: 11 additions & 4 deletions clients/matlab/ttClient.m
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
% Analytic-Banana benchmark. Approximate log-posterior density in TT format
% Check for (and download) TT Toolbox
check_tt;

uri = 'http://localhost:4243';
model = HTTPModel(uri,'posterior');
model = HTTPModel(uri, 'posterior');

% TT demo with umbridge model
tol = 1e-6;
x = linspace(-5,5,33);
TTlogLikelihood = amen_cross(numel(x)*ones(2,1), @(i)model.evaluate(x(i)), tol, 'vec', false)
tol = 1e-5;
d = model.get_input_sizes
x = linspace(-5,5,33); % A uniform grid on [-5,5]^d for analytic-banana
% TTlogLikelihood = amen_cross(numel(x)*ones(d,1), @(i)model.evaluate(x(i)), tol, 'vec', false)
tic;
TTlogLikelihood = greedy2_cross(numel(x)*ones(d,1), @(i)model.evaluate(x(i)), tol, 'vec', false)
toc

% benchmark-analytic-banana: 44.138067 seconds with 'send' engine
% 4.250689 seconds with 'webwrite' engine
8 changes: 8 additions & 0 deletions julia/juliaClient.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using UMBridge

url = "http://localhost:4242"

val = UMBridge.evaluate(url, "forward" ,[[1,3]], Dict())

print(val)

17 changes: 11 additions & 6 deletions matlab/@HTTPModel/HTTPModel.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,26 @@
properties
uri
model_name
send_engine % 'send' or 'webwrite'
end

methods
function model = HTTPModel(uri,model_name)
import matlab.net.*
import matlab.net.http.*

function model = HTTPModel(uri,model_name,send_engine)
model.uri = uri;
model.model_name = model_name;
if (nargin<3) || (isempty(send_engine))
model.send_engine = 'send';
else
model.send_engine = send_engine;
end

% check protocol, make sure it's matching
r = RequestMessage;
uri = URI([uri, '/Info']);
uri = matlab.net.URI([uri, '/Info']);
r = matlab.net.http.RequestMessage; % This must be GET
resp = send(r,uri);
HTTPModel.check_http_status(resp); % check if connection broke down
json = jsondecode(resp.Body.string);

if json.protocolVersion ~= 1
error('the protocol version on the server side does not match the client')
end
Expand All @@ -45,6 +49,7 @@

methods (Access=private)
output_json = get_model_info(self);
output_json = send_data(self, uri, value);
end

methods (Static, Access=public)
Expand Down
12 changes: 3 additions & 9 deletions matlab/@HTTPModel/apply_hessian.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
config = struct;
end

import matlab.net.*
import matlab.net.http.*

% Evaluate model
r = RequestMessage('POST');
if (isa(input, 'cell'))
value.input = input;
elseif (isa(input, 'double'))
Expand All @@ -35,11 +31,9 @@
value.inWrt2 = inWrt2;
value.outWrt = outWrt;
value.config = config;
r.Body = MessageBody(jsonencode(value));
uri = URI([self.uri, '/ApplyHessian']);
resp = send(r,uri);
self.check_http_status(resp);
json = jsondecode(resp.Body.string);
value = jsonencode(value);
uri = matlab.net.URI([self.uri, '/ApplyHessian']);
json = self.send_data(uri, value);
self.check_error(json);
output = json.output;

Expand Down
12 changes: 3 additions & 9 deletions matlab/@HTTPModel/apply_jacobian.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
config = struct;
end

import matlab.net.*
import matlab.net.http.*

% Evaluate model
r = RequestMessage('POST');
if (isa(input, 'cell'))
value.input = input;
elseif (isa(input, 'double'))
Expand All @@ -30,11 +26,9 @@
value.inWrt = inWrt;
value.outWrt = outWrt;
value.config = config;
r.Body = MessageBody(jsonencode(value));
uri = URI([self.uri, '/ApplyJacobian']);
resp = send(r,uri);
self.check_http_status(resp);
json = jsondecode(resp.Body.string);
value = jsonencode(value);
uri = matlab.net.URI([self.uri, '/ApplyJacobian']);
json = self.send_data(uri, value);
self.check_error(json);
output = json.output;

Expand Down
15 changes: 5 additions & 10 deletions matlab/@HTTPModel/evaluate.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
config = struct;
end

import matlab.net.*
import matlab.net.http.*

% Evaluate model
r = RequestMessage('POST');
% Parse inputs
if (isa(input, 'cell'))
value.input = input;
elseif (isa(input, 'double'))
Expand All @@ -21,11 +17,10 @@
end
value.name = self.model_name;
value.config = config;
r.Body = MessageBody(jsonencode(value));
uri = URI([self.uri, '/Evaluate']);
resp = send(r,uri);
self.check_http_status(resp);
json = jsondecode(resp.Body.string);
value = jsonencode(value);
uri = matlab.net.URI([self.uri, '/Evaluate']);
% Evaluate model
json = self.send_data(uri, value);
self.check_error(json);
output = json.output;

Expand Down
11 changes: 2 additions & 9 deletions matlab/@HTTPModel/get_model_info.m
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
function [output_json] = get_model_info(self)

import matlab.net.*
import matlab.net.http.*

r = RequestMessage('POST');
value.name = self.model_name;
r.Body = MessageBody(jsonencode(value));
uri = URI([self.uri,'/ModelInfo']);
resp = send(r,uri);
self.check_http_status(resp); % check if connection broke down
json = jsondecode(resp.Body.string);
uri = matlab.net.URI([self.uri,'/ModelInfo']);
json = self.send_data(uri, jsonencode(value));
self.check_error(json); % check if message came in but contains an error
output_json = json.support;

Expand Down
12 changes: 3 additions & 9 deletions matlab/@HTTPModel/gradient.m
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
config = struct;
end

import matlab.net.*
import matlab.net.http.*

% Evaluate model
r = RequestMessage('POST');
if (isa(input, 'cell'))
value.input = input;
elseif (isa(input, 'double'))
Expand All @@ -30,11 +26,9 @@
value.inWrt = inWrt;
value.outWrt = outWrt;
value.config = config;
r.Body = MessageBody(jsonencode(value));
uri = URI([self.uri, '/Gradient']);
resp = send(r,uri);
self.check_http_status(resp);
json = jsondecode(resp.Body.string);
value = jsonencode(value);
uri = matlab.net.URI([self.uri, '/Gradient']);
json = self.send_data(uri, value);
self.check_error(json);
output = json.output;

Expand Down
16 changes: 16 additions & 0 deletions matlab/@HTTPModel/send_data.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function output_json = send_data(self, uri, value)
if strcmpi(self.send_engine, 'send')
r = matlab.net.http.RequestMessage('POST');
r.Body = matlab.net.http.MessageBody(value);
resp = send(r, uri);
self.check_http_status(resp);
output_json = jsondecode(resp.Body.string);
else
opts = weboptions('Timeout', inf, 'RequestMethod', 'POST');
opts.MediaType = 'application/json';
output_json = webwrite(uri, value, opts);
if ~isa(output_json, 'struct')
output_json = jsondecode(output_json);
end
end
end
9 changes: 4 additions & 5 deletions matlab/umbridge_supported_models.m
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
function [model_names] = umbridge_supported_models(uri)
import matlab.net.*
import matlab.net.http.*

% Get model info
r = RequestMessage;
uri = URI([uri, '/Info']);
uri = matlab.net.URI([uri, '/Info']);

r = matlab.net.http.RequestMessage; % This needs to be GET
resp = send(r,uri);
HTTPModel.check_http_status(resp); % check if connection broke down
json = jsondecode(resp.Body.string);

HTTPModel.check_error(json); % check if message came in but contains an error
model_names = json.models;

Expand Down

0 comments on commit dc80bd6

Please sign in to comment.