Skip to content

Commit

Permalink
[External] [stdlib] Add method unsafe_get to List (#40553)
Browse files Browse the repository at this point in the history
[External] [stdlib] Add method `unsafe_get` to `List`

See modularml#2677 (comment)
for the background about this method.

We are currently missing methods to access elements in collections
without any bounds checks or wraparound, for maximum performance. I
suggest that we introduce `unsafe_get` and `unsafe_set` to our List-like
collections. This is equivalent to
* https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked
*
https://doc.rust-lang.org/std/vec/struct.Vec.html#method.get_unchecked_mut
in Rust.

This should prove useful to makers of high performance libraries like
Max :p

We can then make `__getitem__` and `__setitem__` as safe as we want
without impacting the power users.

Co-authored-by: Gabriel de Marmiesse <gabriel.demarmiesse@datadoghq.com>
Closes modularml#2800
MODULAR_ORIG_COMMIT_REV_ID: 5321c191d83262f240c5d8d5ac77afccaa00a7ba

Signed-off-by: Avinag <udayagiriavinag@gmail.com>
  • Loading branch information
gabrieldemarmiesse authored and Av1nag committed May 27, 2024
1 parent 0bcdf3c commit 87fb3bb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,11 @@ what we publish.
- Added `os.getsize` function, which gives the size in bytes of a path.
([PR 2626](https://github.com/modularml/mojo/pull/2626) by [@artemiogr97](https://github.com/artemiogr97))

- `List` now has a method `unsafe_get` to get the reference to an
element without bounds check or wraparound for negative indices.
Note that this method is unsafe. Use with caution.
([PR #2800](https://github.com/modularml/mojo/pull/2800) by [@gabrieldemarmiesse](https://github.com/gabrieldemarmiesse))

- Added `fromkeys` method to `Dict` to return a `Dict` with the specified keys
and value.
([PR 2622](https://github.com/modularml/mojo/pull/2622) by [@artemiogr97](https://github.com/artemiogr97))
Expand Down
38 changes: 37 additions & 1 deletion stdlib/src/collections/list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,43 @@ struct List[T: CollectionElement](CollectionElement, Sized, Boolable):
if i < 0:
normalized_idx += self[].size

return (self[].data + normalized_idx)[]
return self[].unsafe_get(normalized_idx)

@always_inline
fn unsafe_get[
IndexerType: Indexer,
](self: Reference[Self, _, _], idx: IndexerType) -> Reference[
Self.T, self.is_mutable, self.lifetime
]:
"""Get a reference to an element of self without checking index bounds.
Users should consider using `__getitem__` instead of this method as it is unsafe.
If an index is out of bounds, this method will not abort, it will be considered
undefined behavior.
Note that there is no wraparound for negative indices, caution is advised.
Using negative indices is considered undefined behavior.
Never use `my_list.unsafe_get(-1)` to get the last element of the list. It will not work.
Instead, do `my_list.unsafe_get(len(my_list) - 1)`.
Parameters:
IndexerType: The type of the argument used as index.
Args:
idx: The index of the element to get.
Returns:
A reference to the element at the given index.
"""
var idx_as_int = index(idx)
debug_assert(
0 <= idx_as_int < len(self[]),
(
"The index provided must be within the range [0, len(List) -1]"
" when using List.unsafe_get()"
),
)
return (self[].data + idx_as_int)[]

fn count[T: ComparableCollectionElement](self: List[T], value: T) -> Int:
"""Counts the number of occurrences of a value in the list.
Expand Down
22 changes: 22 additions & 0 deletions stdlib/test/collections/test_list.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ def test_list():
assert_equal(7, list[-1])


def test_list_unsafe_get():
var list = List[Int]()

for i in range(5):
list.append(i)

assert_equal(5, len(list))
assert_equal(0, list.unsafe_get(0)[])
assert_equal(1, list.unsafe_get(1)[])
assert_equal(2, list.unsafe_get(2)[])
assert_equal(3, list.unsafe_get(3)[])
assert_equal(4, list.unsafe_get(4)[])

list[2] = -2
assert_equal(-2, list.unsafe_get(2)[])

list.clear()
list.append(2)
assert_equal(2, list.unsafe_get(0)[])


def test_list_clear():
var list = List[Int](1, 2, 3)
assert_equal(len(list), 3)
Expand Down Expand Up @@ -789,6 +810,7 @@ def test_indexing():
def main():
test_mojo_issue_698()
test_list()
test_list_unsafe_get()
test_list_clear()
test_list_to_bool_conversion()
test_list_pop()
Expand Down

0 comments on commit 87fb3bb

Please sign in to comment.