Skip to content
This repository has been archived by the owner on Apr 4, 2024. It is now read-only.

Commit

Permalink
fixed memory leak
Browse files Browse the repository at this point in the history
  • Loading branch information
Benny-Nottonson committed Feb 8, 2024
1 parent 19adb5f commit 4ec0889
Showing 1 changed file with 47 additions and 58 deletions.
105 changes: 47 additions & 58 deletions voodoo/graph.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -177,70 +177,59 @@ struct Graph:
return node

fn get_free_data[unique: Bool = False](inout self, node: Node) raises:
let data_id = node.get_data_id()
if data_id != -1:
if node.get_data_id() != -1:
return

let parents = node.get_parents()
let parents_len = len(parents)
let node_cap = node.get_cap()
let node_is_static = node.get_is_static()
let node_checkpoint = node.get_checkpoint()
let node_is_single = node.get_is_single()

var parent_found = False
if (
not node_is_static
and not node_checkpoint
and not unique
and not node_is_single
):
for i in range(parents_len):
let parent = self._nodes[parents[i]]
if (
self.load_ceiled_cap(parent.get_cap())
== self.load_ceiled_cap(node_cap)
and parent.get_dependencies() == 1
and not parent.get_is_static()
and not parent.get_checkpoint()
and not parent.get_is_single()
):
node.set_data_id(parent.get_data_id())
node.set_data(self._memory_pool[node.get_data_id()])
parent_found = True
break

for i in range(parents_len):
if not parent_found:
let parent = self._nodes[parents[i]]
let dependencies = parent.get_dependencies()
parent.set_dependencies(dependencies - 1)
var idx = -1
for i in range(len(node.get_parents())):
let ind = node.get_parents()[i]
let parent = self._nodes[node.get_parents()[i]]
if (
self.load_ceiled_cap(parent.get_cap())
== self.load_ceiled_cap(node.get_cap())
and parent.get_dependencies() == 1
and not parent.get_is_static()
and not node.get_is_static()
and not parent.get_checkpoint()
and not node.get_checkpoint()
and not unique
and not parent.get_is_single()
and not node.get_is_single()
):
node.set_data_id(parent.get_data_id())
node.set_data(self._memory_pool[node.get_data_id()])
idx = i
break

if not parent_found:
self.handle_no_matching_parent(node)
for i in range(len(node.get_parents())):
if i == idx:
continue
else:
let parent = self._nodes[node.get_parents()[i]]
parent.set_dependencies(parent.get_dependencies() - 1)

fn handle_no_matching_parent(inout self, node: Node) raises:
let index = self.get_index(node.get_cap())
var mem_pool = self._memory_pool_manager[index]
if len(mem_pool) > 0:
let data_id = mem_pool.pop_back()
node.set_data_id(data_id)
let ceiled_cap = self.load_ceiled_cap(node.get_cap())
if idx == -1:
let index = self.get_index(node.get_cap())
var mem_pool = self._memory_pool_manager[index]
if len(mem_pool) > 0:
let data_id = mem_pool.pop_back()
node.set_data_id(data_id)
let ceiled_cap = self.load_ceiled_cap(node.get_cap())

node.set_data(self._memory_pool[node.get_data_id()])
memset_zero(node.get_data(), ceiled_cap)
else:
let data_id = self.get_free_data_id()
node.set_data_id(data_id)
let ceiled_cap = self.load_ceiled_cap(node.get_cap() + 1)
let new_data_ptr = DTypePointer[DType.float32].alloc(ceiled_cap)
if data_id == len(self._memory_pool):
self._memory_pool.push_back(new_data_ptr)
node.set_data(self._memory_pool[node.get_data_id()])
memset_zero(node.get_data(), ceiled_cap)
else:
self._memory_pool[data_id] = new_data_ptr

node.set_data(self._memory_pool[node.get_data_id()])
memset_zero(node.get_data(), ceiled_cap)
let data_id = self.get_free_data_id()
node.set_data_id(data_id)
let ceiled_cap = self.load_ceiled_cap(node.get_cap() + 1)
let new_data_ptr = DTypePointer[DType.float32].alloc(ceiled_cap)
if data_id == len(self._memory_pool):
self._memory_pool.push_back(new_data_ptr)
else:
self._memory_pool[data_id] = new_data_ptr

node.set_data(self._memory_pool[node.get_data_id()])
memset_zero(node.get_data(), ceiled_cap)

fn get_free_grad(inout self, node: Node) raises:
if node.get_grad_id() != -1:
Expand Down

0 comments on commit 4ec0889

Please sign in to comment.