Skip to content

Commit

Permalink
both test pass
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed May 19, 2021
1 parent a4d9508 commit 8dc9af2
Show file tree
Hide file tree
Showing 18 changed files with 549 additions and 56 deletions.
13 changes: 11 additions & 2 deletions apps/microtvm/zephyr/demo_runtime/CMakeLists.txt
Expand Up @@ -10,8 +10,17 @@ find_package(Zephyr HINTS $ENV{ZEPHYR_BASE})
project(microtvm_zephyr_runtime)

set(CMAKE_VERBOSE_MAKEFILE ON)
file(GLOB TVM_SOURCES ${CMAKE_SOURCE_DIR}/__tvm*.c)
target_sources(app PRIVATE src/main.c ${TVM_SOURCES})

if($ENV{ZEPHYR_RUNTIME} STREQUAL "HOST-DRIVEN")
target_sources(app PRIVATE
host_driven/src/main.c
)
elseif($ENV{ZEPHYR_RUNTIME} STREQUAL "ZEPHYR-AOT")
file(GLOB TVM_SOURCES zephyr_aot/src/*.c)
target_sources(app PRIVATE
${TVM_SOURCES}
)
endif()

foreach(tvm_lib ${TVM_LIBS})
string(LENGTH ${tvm_lib} tvm_lib_length)
Expand Down
Expand Up @@ -21,7 +21,7 @@
CONFIG_CMSIS_DSP=y

# Required for Cortex-M33 devices.
CONFIG_MAIN_STACK_SIZE=1536
CONFIG_MAIN_STACK_SIZE=18000

# For random number generation.
CONFIG_ENTROPY_GENERATOR=y
Expand Down
5 changes: 3 additions & 2 deletions apps/microtvm/zephyr/demo_runtime/boards/qemu_x86.conf
Expand Up @@ -21,5 +21,6 @@
CONFIG_TEST_RANDOM_GENERATOR=y
CONFIG_TIMER_RANDOM_GENERATOR=y

# Default stack size is 1k, this is required for debug mode.
CONFIG_MAIN_STACK_SIZE=1536
# Default stack size is 1k, this is required for debug mode and
# for AOT mode.
CONFIG_MAIN_STACK_SIZE=18000
84 changes: 84 additions & 0 deletions apps/microtvm/zephyr/demo_runtime/zephyr_aot/include/zephyr_uart.h
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_APPS_MICROTVM_ZEPHYR_DEMO_RUNTIME_ZEPHYR_AOT_INCLUDE_ZEPHYR_UART_H_
#define TVM_APPS_MICROTVM_ZEPHYR_DEMO_RUNTIME_ZEPHYR_AOT_INCLUDE_ZEPHYR_UART_H_

#include <drivers/uart.h>
#include <sys/ring_buffer.h>

static const struct device* g_utvm_uart;
#define RING_BUF_SIZE_BYTES (TVM_CRT_MAX_PACKET_SIZE_BYTES + 100)

// Ring buffer used to store data read from the UART on rx interrupt.
RING_BUF_DECLARE(uart_rx_rbuf, RING_BUF_SIZE_BYTES);

size_t write_serial(const char* data, size_t size) {
for (size_t i = 0; i < size; i++) {
uart_poll_out(g_utvm_uart, data[i]);
}
return size;
}

static uint8_t uart_data[8];
// UART interrupt callback.
void uart_irq_cb(const struct device* dev, void* user_data) {
while (uart_irq_update(dev) && uart_irq_is_pending(dev)) {
struct ring_buf* rbuf = (struct ring_buf*)user_data;
if (uart_irq_rx_ready(dev) != 0) {
for (;;) {
// Read a small chunk of data from the UART.
int bytes_read = uart_fifo_read(dev, uart_data, sizeof(uart_data));
if (bytes_read < 0) {
TVMPlatformAbort((tvm_crt_error_t)(0xbeef1));
} else if (bytes_read == 0) {
break;
}
// Write it into the ring buffer.
int bytes_written = ring_buf_put(rbuf, uart_data, bytes_read);
if (bytes_read != bytes_written) {
TVMPlatformAbort((tvm_crt_error_t)(0xbeef2));
}
}
}
}
}

// Used to initialize the UART receiver.
void uart_rx_init(struct ring_buf* rbuf, const struct device* dev) {
uart_irq_callback_user_data_set(dev, uart_irq_cb, (void*)rbuf);
uart_irq_rx_enable(dev);
}

// Used to read data from the UART.
int uart_rx_buf_read(struct ring_buf* rbuf, uint8_t* data, size_t data_size_bytes) {
unsigned int key = irq_lock();
int bytes_read = ring_buf_get(rbuf, data, data_size_bytes);
irq_unlock(key);
return bytes_read;
}

// Initialize UART
void TVMPlatformUARTInit() {
// Claim console device.
g_utvm_uart = device_get_binding(DT_LABEL(DT_CHOSEN(zephyr_console)));
uart_rx_init(&uart_rx_rbuf, g_utvm_uart);
}

#endif /* TVM_APPS_MICROTVM_ZEPHYR_DEMO_RUNTIME_ZEPHYR_AOT_INCLUDE_ZEPHYR_UART_H_ */
204 changes: 204 additions & 0 deletions apps/microtvm/zephyr/demo_runtime/zephyr_aot/src/main.c
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <assert.h>
#include <float.h>
#include <kernel.h>
#include <power/reboot.h>
#include <stdio.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/crt/internal/aot_executor/aot_executor.h>
#include <tvm/runtime/crt/logging.h>
#include <tvm/runtime/crt/stack_allocator.h>
#include <unistd.h>
#include <zephyr.h>

#include "input_data.h"
#include "output_data.h"
#include "zephyr_uart.h"

#ifdef CONFIG_ARCH_POSIX
#include "posix_board_if.h"
#endif

#define WORKSPACE_SIZE (270 * 1024)

static uint8_t g_aot_memory[WORKSPACE_SIZE];
extern tvm_model_t network;
tvm_workspace_t app_workspace;

const unsigned char g_wakeup_sequence[12] = {0xfe, 0xff, 0xfd, 0x03, 0x00, 0x00,
0x00, 0x00, 0x00, 0x02, 0x66, 0x77};

size_t TVMPlatformFormatMessage(char* out_buf, size_t out_buf_size_bytes, const char* fmt,
va_list args) {
return vsnprintk(out_buf, out_buf_size_bytes, fmt, args);
}

void TVMLogf(const char* msg, ...) {
char buffer[256];
int size;
va_list args;
va_start(args, msg);
size = vsprintf(buffer, msg, args);
va_end(args);
write_serial(buffer, (size_t)size);
}

void TVMPlatformAbort(tvm_crt_error_t error) {
TVMLogf("TVMPlatformAbort: %08x\n", error);
sys_reboot(SYS_REBOOT_COLD);
for (;;)
;
}

tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr);
}

tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
return StackMemoryManager_Free(&app_workspace, ptr);
}

void timer_expiry_function(struct k_timer* timer_id) { return; }

#define MILLIS_TIL_EXPIRY 200
#define TIME_TIL_EXPIRY (K_MSEC(MILLIS_TIL_EXPIRY))
struct k_timer g_utvm_timer;
uint32_t g_utvm_start_time;
int g_utvm_timer_running = 0;

// Called to start system timer.
tvm_crt_error_t TVMPlatformTimerStart() {
if (g_utvm_timer_running) {
TVMLogf("timer already running");
return kTvmErrorPlatformTimerBadState;
}

k_timer_start(&g_utvm_timer, TIME_TIL_EXPIRY, TIME_TIL_EXPIRY);
g_utvm_start_time = k_cycle_get_32();
g_utvm_timer_running = 1;
return kTvmErrorNoError;
}

// Called to stop system timer.
tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
if (!g_utvm_timer_running) {
TVMLogf("timer not running");
return kTvmErrorSystemErrorMask | 2;
}

uint32_t stop_time = k_cycle_get_32();

// compute how long the work took
uint32_t cycles_spent = stop_time - g_utvm_start_time;
if (stop_time < g_utvm_start_time) {
// we rolled over *at least* once, so correct the rollover it was *only*
// once, because we might still use this result
cycles_spent = ~((uint32_t)0) - (g_utvm_start_time - stop_time);
}

uint32_t ns_spent = (uint32_t)k_cyc_to_ns_floor64(cycles_spent);
double hw_clock_res_us = ns_spent / 1000.0;

// need to grab time remaining *before* stopping. when stopped, this function
// always returns 0.
int32_t time_remaining_ms = k_timer_remaining_get(&g_utvm_timer);
k_timer_stop(&g_utvm_timer);
// check *after* stopping to prevent extra expiries on the happy path
if (time_remaining_ms < 0) {
TVMLogf("negative time remaining");
return kTvmErrorSystemErrorMask | 3;
}
uint32_t num_expiries = k_timer_status_get(&g_utvm_timer);
uint32_t timer_res_ms = ((num_expiries * MILLIS_TIL_EXPIRY) + time_remaining_ms);
double approx_num_cycles =
(double)k_ticks_to_cyc_floor32(1) * (double)k_ms_to_ticks_ceil32(timer_res_ms);
// if we approach the limits of the HW clock datatype (uint32_t), use the
// coarse-grained timer result instead
if (approx_num_cycles > (0.5 * (~((uint32_t)0)))) {
*elapsed_time_seconds = timer_res_ms / 1000.0;
} else {
*elapsed_time_seconds = hw_clock_res_us / 1e6;
}

g_utvm_timer_running = 0;
return kTvmErrorNoError;
}

void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint,
int dtype_bits_hint) {
tvm_crt_error_t err = kTvmErrorNoError;
void* ptr = 0;
DLDevice dev = {device_type, device_id};
assert(nbytes > 0);
err = TVMPlatformMemoryAllocate(nbytes, dev, &ptr);
CHECK_EQ(err, kTvmErrorNoError,
"TVMBackendAllocWorkspace(%d, %d, %" PRIu64 ", %d, %d) -> %" PRId32, device_type,
device_id, nbytes, dtype_code_hint, dtype_bits_hint, err);
return ptr;
}

int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
tvm_crt_error_t err = kTvmErrorNoError;
DLDevice dev = {device_type, device_id};
err = TVMPlatformMemoryFree(ptr, dev);
return err;
}

void main(void) {
TVMPlatformUARTInit();
k_timer_init(&g_utvm_timer, NULL, NULL);
// Wake up host side.
write_serial(g_wakeup_sequence, 12);
TVMLogf("Zephyr AOT Runtime\n");

void* inputs[1] = {
input_data,
};
void* outputs[1] = {
output_data,
};

StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE);

double elapsed_time = 0;
TVMPlatformTimerStart();
int ret_val = tvm_runtime_run(&network, inputs, outputs);
TVMPlatformTimerStop(&elapsed_time);

if (ret_val != 0) {
TVMLogf("Error: %d\n", ret_val);
TVMPlatformAbort(kTvmErrorPlatformCheckFailure);
}

// int8_t* results = (int8_t*)(&output_data);
size_t max_ind = -1;
float max_val = -FLT_MAX;
for (size_t i = 0; i < output_data_len; i++) {
if (output_data[i] >= max_val) {
max_ind = i;
max_val = output_data[i];
}
}
TVMLogf("result:%d\n", max_ind);
#ifdef CONFIG_ARCH_POSIX
posix_exit(0);
#endif
}
3 changes: 1 addition & 2 deletions docker/bash.sh
Expand Up @@ -98,7 +98,7 @@ else
fi

if [[ "${DOCKER_IMAGE_NAME}" == *"ci"* ]]; then
CI_ADDON_ENV="-e PYTHONPATH=/workspace/python:/workspace/.local/lib/python3.6/site-packages"
CI_ADDON_ENV="-e PYTHONPATH=/workspace/python"
else
CI_ADDON_ENV=""
fi
Expand Down Expand Up @@ -167,7 +167,6 @@ ${DOCKER_BINARY} run --rm --pid=host\
${CI_ADDON_ENV} \
${CUDA_ENV} \
"${CI_DOCKER_EXTRA_PARAMS[@]}" \
--mount type=bind,source=/home/mhessar/tinymlperf,target=/tinymlperf \
${DOCKER_IMAGE_NAME} \
bash --login /docker/with_the_same_user \
"${COMMAND[@]}"

0 comments on commit 8dc9af2

Please sign in to comment.