Skip to content

Commit aaad2ba

Browse files
committed
refactor: 发现虚存可以从中间开始映射,重做虚存抽象
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 17109a0 commit aaad2ba

File tree

4 files changed

+105
-83
lines changed

4 files changed

+105
-83
lines changed

cuda/src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ fn test_mem_info() {
235235
println!("mem info: {free}/{total}");
236236
// 从池中分配空间
237237
let stream = ctx.stream();
238-
let mem = stream.malloc::<u8>((free.0 >> 30).saturating_sub(1) << 30);
238+
let mem = stream.malloc::<u8>(free.0 * 3 / 4);
239239
let (free, total) = ctx.mem_info();
240240
println!("mem info: {free}/{total}");
241241
// 释放的存储只会回到池中,驱动不可见

cuda/src/graph/memcpy.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -151,29 +151,31 @@ mod test {
151151
let prop = dev.mem_prop();
152152
let minium = prop.granularity_minimum();
153153

154-
let mut dst = VirMem::new(minium, 0).map_on(&dev);
155-
let src = VirMem::new(minium, 0).map_on(&dev);
154+
let mut dst_vir = VirMem::new(minium, 0);
155+
let mut src_vir = VirMem::new(minium, 0);
156+
let dst = dst_vir.map(0, prop.create(minium));
157+
let src = src_vir.map(0, prop.create(minium));
156158

157159
let graph = Graph::new();
158160
// 虚存不能直接传入 memcpy node,当时必须是已映射状态
159-
graph.add_memcpy_d2d(&mut dst, &src, &[]);
161+
graph.add_memcpy_d2d(dst, src, &[]);
160162

161163
let phy0 = prop.create(minium);
162164
let phy1 = prop.create(minium);
163165
let phy2 = prop.create(minium);
164166
let phy3 = prop.create(minium);
165167

166-
let (vir_dst, _) = dst.unmap();
167-
let (vir_src, _) = src.unmap();
168-
let mut dst = vir_dst.map(phy0);
169-
let mut src = vir_src.map(phy1);
170-
test_memcpy_in_graph(&dev, &graph, &mut dst, &mut src, 0..);
171-
172-
let (vir_dst, _phy0) = dst.unmap();
173-
let (vir_src, _phy1) = src.unmap();
174-
let mut dst = vir_dst.map(phy2);
175-
let mut src = vir_src.map(phy3);
176-
test_memcpy_in_graph(&dev, &graph, &mut dst, &mut src, (0..u64::MAX).rev());
168+
let _ = dst_vir.unmap(0);
169+
let _ = src_vir.unmap(0);
170+
let dst = dst_vir.map(0, phy0);
171+
let src = src_vir.map(0, phy1);
172+
test_memcpy_in_graph(&dev, &graph, dst, src, 0..);
173+
174+
let _phy0 = dst_vir.unmap(0);
175+
let _phy1 = src_vir.unmap(0);
176+
let dst = dst_vir.map(0, phy2);
177+
let src = src_vir.map(0, phy3);
178+
test_memcpy_in_graph(&dev, &graph, dst, src, (0..u64::MAX).rev());
177179
}
178180

179181
fn test_memcpy_in_graph(

cuda/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ pub use graph::*;
7171
pub use host_mem::{HostMem, HostMemSpore};
7272
pub use nvrtc::{KernelFn, KernelParamPtrs, KernelParams, Module, ModuleSpore, Ptx, Symbol};
7373
pub use stream::{Stream, StreamSpore};
74-
pub use virtual_mem::{MappedMem, MemProp, PhyMem, VirByte, VirMem};
74+
pub use virtual_mem::{MemProp, PhyMem, VirByte, VirMem};
7575

7676
use std::{
7777
cmp::Ordering,

cuda/src/virtual_mem.rs

Lines changed: 87 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
};
99
use context_spore::AsRaw;
1010
use std::{
11-
mem::ManuallyDrop,
11+
collections::BTreeMap,
1212
ops::{Deref, DerefMut},
1313
ptr::null_mut,
1414
slice::{from_raw_parts, from_raw_parts_mut},
@@ -58,20 +58,19 @@ pub struct VirByte(u8);
5858
pub struct VirMem {
5959
/// Base device address of the reserved range, as written by `cuMemAddressReserve`.
ptr: CUdeviceptr,
6060
/// Length of the reserved range, as passed to `cuMemAddressReserve`.
len: usize,
61+
/// Segment table: offset from `ptr` -> region (mapped physical memory or vacant space)
/// that starts at that offset.
62+
map: BTreeMap<usize, PhyRegion>,
6163
}
6264

6365
impl VirMem {
6466
pub fn new(len: usize, min_addr: usize) -> Self {
6567
let mut ptr = 0;
6668
driver!(cuMemAddressReserve(&mut ptr, len, 0, min_addr as _, 0));
67-
Self { ptr, len }
68-
}
69-
}
70-
71-
impl Drop for VirMem {
72-
fn drop(&mut self) {
73-
let &mut Self { ptr, len } = self;
74-
driver!(cuMemAddressFree(ptr, len))
69+
Self {
70+
ptr,
71+
len,
72+
map: [(0, len.into())].into(),
73+
}
7574
}
7675
}
7776

@@ -89,6 +88,19 @@ impl DerefMut for VirMem {
8988
}
9089
}
9190

91+
impl Drop for VirMem {
92+
fn drop(&mut self) {
93+
let Self { ptr, len, map } = self;
94+
let map = std::mem::take(map);
95+
for (offset, region) in map {
96+
if let PhyRegion::Mapped(phy) = region {
97+
driver!(cuMemUnmap(*ptr + offset as CUdeviceptr, phy.len))
98+
}
99+
}
100+
driver!(cuMemAddressFree(*ptr, *len))
101+
}
102+
}
103+
92104
pub struct PhyMem {
93105
location: CUmemLocation,
94106
handle: CUmemGenericAllocationHandle,
@@ -134,68 +146,76 @@ impl PhyMem {
134146
}
135147
}
136148

137-
#[repr(transparent)]
138-
pub struct MappedMem(ManuallyDrop<Internal>);
139-
140-
/// 需要一个内部结构来控制何时自动释放。
141-
///
142-
/// [`MappedMem`] 自动释放时,[`Internal`] 的两个成员递归释放。主动解映射时,[`Internal`] 的成员被取出,不释放。
143-
struct Internal {
144-
vir: VirMem,
145-
phy: Arc<PhyMem>,
149+
/// One segment of a [`VirMem`] address range, stored as a value of its offset-keyed table.
enum PhyRegion {
150+
/// Segment with a physical allocation mapped onto it.
Mapped(Arc<PhyMem>),
151+
/// Segment with nothing mapped; the payload is the segment length.
Vacant(usize),
146152
}
147153

148-
impl VirMem {
149-
pub fn map(self, phy: Arc<PhyMem>) -> MappedMem {
150-
debug_assert!(
151-
self.len >= phy.len,
152-
"cannot map physical memory to a smaller address region"
153-
);
154-
driver!(cuMemMap(self.ptr, phy.len, 0, phy.handle, 0));
155-
156-
let desc = CUmemAccessDesc {
157-
location: phy.location,
158-
flags: CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE,
159-
};
160-
driver!(cuMemSetAccess(self.ptr, phy.len, &desc, 1));
161-
162-
MappedMem(ManuallyDrop::new(Internal { vir: self, phy }))
163-
}
164-
165-
pub fn map_on(self, dev: &Device) -> MappedMem {
166-
let len = self.len;
167-
self.map(dev.mem_prop().create(len))
168-
}
169-
}
170-
171-
impl Drop for MappedMem {
172-
fn drop(&mut self) {
173-
driver!(cuMemUnmap(self.0.vir.ptr, self.0.phy.len));
174-
unsafe { ManuallyDrop::drop(&mut self.0) }
154+
impl From<Arc<PhyMem>> for PhyRegion {
155+
fn from(value: Arc<PhyMem>) -> Self {
156+
Self::Mapped(value)
175157
}
176158
}
177159

178-
impl MappedMem {
179-
pub fn unmap(mut self) -> (VirMem, Arc<PhyMem>) {
180-
driver!(cuMemUnmap(self.0.vir.ptr, self.0.phy.len));
181-
let Internal { vir, phy } = unsafe { ManuallyDrop::take(&mut self.0) };
182-
std::mem::forget(self);
183-
(vir, phy)
160+
impl From<usize> for PhyRegion {
161+
fn from(value: usize) -> Self {
162+
Self::Vacant(value)
184163
}
185164
}
186165

187-
impl Deref for MappedMem {
188-
type Target = [DevByte];
189-
#[inline]
190-
fn deref(&self) -> &Self::Target {
191-
unsafe { from_raw_parts(self.0.vir.ptr as _, self.0.phy.len) }
166+
impl VirMem {
167+
pub fn map(&mut self, offset: usize, phy: Arc<PhyMem>) -> &mut [DevByte] {
168+
// 检查范围
169+
assert!(offset <= self.len && offset + phy.len <= self.len);
170+
// 查找所在区间
171+
let (head, region) = self.map.range(..=offset).next_back().unwrap();
172+
// 获取空闲段长度
173+
let len = match *region {
174+
PhyRegion::Mapped(_) => panic!("mem is mapped"),
175+
PhyRegion::Vacant(len) => len,
176+
};
177+
assert!(phy.len <= len);
178+
// 映射
179+
{
180+
let ptr = self.ptr + offset as CUdeviceptr;
181+
driver!(cuMemMap(ptr, phy.len, 0, phy.handle, 0));
182+
let desc = CUmemAccessDesc {
183+
location: phy.location,
184+
flags: CUmemAccess_flags::CU_MEM_ACCESS_FLAGS_PROT_READWRITE,
185+
};
186+
driver!(cuMemSetAccess(ptr, phy.len, &desc, 1));
187+
}
188+
// 移除空闲段
189+
let head = *head;
190+
self.map.remove(&head);
191+
// 插入映射段
192+
let phy_len = phy.len;
193+
self.map.insert(offset, phy.into());
194+
// 插入头尾空闲段
195+
let head_len = offset - head;
196+
let tail_len = len - head_len - phy_len;
197+
if head_len > 0 {
198+
self.map.insert(head, head_len.into());
199+
}
200+
if tail_len > 0 {
201+
let tail = head + head_len + phy_len;
202+
self.map.insert(tail, tail_len.into());
203+
}
204+
unsafe { std::slice::from_raw_parts_mut((self.ptr + offset as CUdeviceptr) as _, phy_len) }
192205
}
193-
}
194206

195-
impl DerefMut for MappedMem {
196-
#[inline]
197-
fn deref_mut(&mut self) -> &mut Self::Target {
198-
unsafe { from_raw_parts_mut(self.0.vir.ptr as _, self.0.phy.len) }
207+
pub fn unmap(&mut self, offset: usize) -> Arc<PhyMem> {
208+
let region = self.map.get_mut(&offset).expect("offset is not a boundary");
209+
let len = match region {
210+
PhyRegion::Mapped(phy_mem) => phy_mem.len,
211+
PhyRegion::Vacant(_) => panic!("offset is not mapped"),
212+
};
213+
let PhyRegion::Mapped(phy) = std::mem::replace(region, len.into()) else {
214+
unreachable!()
215+
};
216+
let ptr = self.ptr + offset as CUdeviceptr;
217+
driver!(cuMemUnmap(ptr, phy.len));
218+
phy
199219
}
200220
}
201221

@@ -213,24 +233,24 @@ fn test_behavior() {
213233
println!("minimun = {minimum}, recommended = {recommended}");
214234

215235
// 分配一个较大的虚地址区域
216-
let virmem = VirMem::new(10 * minimum, 0);
236+
let mut virmem = VirMem::new(10 * minimum, 0);
217237
// 分配一个较小的物理页
218238
let phymem = prop.create(minimum);
219239
// 建立映射
220-
let mut mapped = virmem.map(phymem.clone());
240+
let mapped = virmem.map(minimum, phymem.clone());
221241

222242
// 通过虚地址操作存储空间
223243
let host = (0..minimum / size_of::<usize>()).collect::<Box<_>>();
224244
// 对存储空间的操作仍然需要在上下文中进行
225-
dev.context().apply(|_| memcpy_h2d(&mut mapped, &host));
245+
dev.context().apply(|_| memcpy_h2d(mapped, &host));
226246

227247
// 分配另一个虚地址区域
228-
let virmem = VirMem::new(2 * minimum, 0);
248+
let mut virmem = VirMem::new(2 * minimum, 0);
229249
// 将同一个物理页映射到虚地址区域
230-
let mapped = virmem.map(phymem);
250+
let mapped = virmem.map(minimum, phymem);
231251
// 在另一个上下文中读取存储空间
232252
let mut host_ = vec![0usize; host.len()];
233-
dev.context().apply(|_| memcpy_d2h(&mut host_, &mapped));
253+
dev.context().apply(|_| memcpy_d2h(&mut host_, mapped));
234254

235255
assert_eq!(&*host, &*host_)
236256
}

0 commit comments

Comments
 (0)