88} ;
99use context_spore:: AsRaw ;
1010use std:: {
11- mem :: ManuallyDrop ,
11+ collections :: BTreeMap ,
1212 ops:: { Deref , DerefMut } ,
1313 ptr:: null_mut,
1414 slice:: { from_raw_parts, from_raw_parts_mut} ,
@@ -58,20 +58,19 @@ pub struct VirByte(u8);
5858pub struct VirMem {
5959 ptr : CUdeviceptr ,
6060 len : usize ,
61+ /// offset -> phy
62+ map : BTreeMap < usize , PhyRegion > ,
6163}
6264
6365impl VirMem {
6466 pub fn new ( len : usize , min_addr : usize ) -> Self {
6567 let mut ptr = 0 ;
6668 driver ! ( cuMemAddressReserve( & mut ptr, len, 0 , min_addr as _, 0 ) ) ;
67- Self { ptr, len }
68- }
69- }
70-
71- impl Drop for VirMem {
72- fn drop ( & mut self ) {
73- let & mut Self { ptr, len } = self ;
74- driver ! ( cuMemAddressFree( ptr, len) )
69+ Self {
70+ ptr,
71+ len,
72+ map : [ ( 0 , len. into ( ) ) ] . into ( ) ,
73+ }
7574 }
7675}
7776
@@ -89,6 +88,19 @@ impl DerefMut for VirMem {
8988 }
9089}
9190
91+ impl Drop for VirMem {
92+ fn drop ( & mut self ) {
93+ let Self { ptr, len, map } = self ;
94+ let map = std:: mem:: take ( map) ;
95+ for ( offset, region) in map {
96+ if let PhyRegion :: Mapped ( phy) = region {
97+ driver ! ( cuMemUnmap( * ptr + offset as CUdeviceptr , phy. len) )
98+ }
99+ }
100+ driver ! ( cuMemAddressFree( * ptr, * len) )
101+ }
102+ }
103+
92104pub struct PhyMem {
93105 location : CUmemLocation ,
94106 handle : CUmemGenericAllocationHandle ,
@@ -134,68 +146,76 @@ impl PhyMem {
134146 }
135147}
136148
137- #[ repr( transparent) ]
138- pub struct MappedMem ( ManuallyDrop < Internal > ) ;
139-
140- /// 需要一个内部结构来控制何时自动释放。
141- ///
142- /// [`MappedMem`] 自动释放时,[`Internal`] 的两个成员递归释放。主动解映射时,[`Internal`] 的成员被取出,不释放。
143- struct Internal {
144- vir : VirMem ,
145- phy : Arc < PhyMem > ,
149+ enum PhyRegion {
150+ Mapped ( Arc < PhyMem > ) ,
151+ Vacant ( usize ) ,
146152}
147153
148- impl VirMem {
149- pub fn map ( self , phy : Arc < PhyMem > ) -> MappedMem {
150- debug_assert ! (
151- self . len >= phy. len,
152- "cannot map physical memory to a smaller address region"
153- ) ;
154- driver ! ( cuMemMap( self . ptr, phy. len, 0 , phy. handle, 0 ) ) ;
155-
156- let desc = CUmemAccessDesc {
157- location : phy. location ,
158- flags : CUmemAccess_flags :: CU_MEM_ACCESS_FLAGS_PROT_READWRITE ,
159- } ;
160- driver ! ( cuMemSetAccess( self . ptr, phy. len, & desc, 1 ) ) ;
161-
162- MappedMem ( ManuallyDrop :: new ( Internal { vir : self , phy } ) )
163- }
164-
165- pub fn map_on ( self , dev : & Device ) -> MappedMem {
166- let len = self . len ;
167- self . map ( dev. mem_prop ( ) . create ( len) )
168- }
169- }
170-
171- impl Drop for MappedMem {
172- fn drop ( & mut self ) {
173- driver ! ( cuMemUnmap( self . 0 . vir. ptr, self . 0 . phy. len) ) ;
174- unsafe { ManuallyDrop :: drop ( & mut self . 0 ) }
154+ impl From < Arc < PhyMem > > for PhyRegion {
155+ fn from ( value : Arc < PhyMem > ) -> Self {
156+ Self :: Mapped ( value)
175157 }
176158}
177159
178- impl MappedMem {
179- pub fn unmap ( mut self ) -> ( VirMem , Arc < PhyMem > ) {
180- driver ! ( cuMemUnmap( self . 0 . vir. ptr, self . 0 . phy. len) ) ;
181- let Internal { vir, phy } = unsafe { ManuallyDrop :: take ( & mut self . 0 ) } ;
182- std:: mem:: forget ( self ) ;
183- ( vir, phy)
160+ impl From < usize > for PhyRegion {
161+ fn from ( value : usize ) -> Self {
162+ Self :: Vacant ( value)
184163 }
185164}
186165
187- impl Deref for MappedMem {
188- type Target = [ DevByte ] ;
189- #[ inline]
190- fn deref ( & self ) -> & Self :: Target {
191- unsafe { from_raw_parts ( self . 0 . vir . ptr as _ , self . 0 . phy . len ) }
166+ impl VirMem {
167+ pub fn map ( & mut self , offset : usize , phy : Arc < PhyMem > ) -> & mut [ DevByte ] {
168+ // 检查范围
169+ assert ! ( offset <= self . len && offset + phy. len <= self . len) ;
170+ // 查找所在区间
171+ let ( head, region) = self . map . range ( ..=offset) . next_back ( ) . unwrap ( ) ;
172+ // 获取空闲段长度
173+ let len = match * region {
174+ PhyRegion :: Mapped ( _) => panic ! ( "mem is mapped" ) ,
175+ PhyRegion :: Vacant ( len) => len,
176+ } ;
177+ assert ! ( phy. len <= len) ;
178+ // 映射
179+ {
180+ let ptr = self . ptr + offset as CUdeviceptr ;
181+ driver ! ( cuMemMap( ptr, phy. len, 0 , phy. handle, 0 ) ) ;
182+ let desc = CUmemAccessDesc {
183+ location : phy. location ,
184+ flags : CUmemAccess_flags :: CU_MEM_ACCESS_FLAGS_PROT_READWRITE ,
185+ } ;
186+ driver ! ( cuMemSetAccess( ptr, phy. len, & desc, 1 ) ) ;
187+ }
188+ // 移除空闲段
189+ let head = * head;
190+ self . map . remove ( & head) ;
191+ // 插入映射段
192+ let phy_len = phy. len ;
193+ self . map . insert ( offset, phy. into ( ) ) ;
194+ // 插入头尾空闲段
195+ let head_len = offset - head;
196+ let tail_len = len - head_len - phy_len;
197+ if head_len > 0 {
198+ self . map . insert ( head, head_len. into ( ) ) ;
199+ }
200+ if tail_len > 0 {
201+ let tail = head + head_len + phy_len;
202+ self . map . insert ( tail, tail_len. into ( ) ) ;
203+ }
204+ unsafe { std:: slice:: from_raw_parts_mut ( ( self . ptr + offset as CUdeviceptr ) as _ , phy_len) }
192205 }
193- }
194206
195- impl DerefMut for MappedMem {
196- #[ inline]
197- fn deref_mut ( & mut self ) -> & mut Self :: Target {
198- unsafe { from_raw_parts_mut ( self . 0 . vir . ptr as _ , self . 0 . phy . len ) }
207+ pub fn unmap ( & mut self , offset : usize ) -> Arc < PhyMem > {
208+ let region = self . map . get_mut ( & offset) . expect ( "offset is not a boundary" ) ;
209+ let len = match region {
210+ PhyRegion :: Mapped ( phy_mem) => phy_mem. len ,
211+ PhyRegion :: Vacant ( _) => panic ! ( "offset is not mapped" ) ,
212+ } ;
213+ let PhyRegion :: Mapped ( phy) = std:: mem:: replace ( region, len. into ( ) ) else {
214+ unreachable ! ( )
215+ } ;
216+ let ptr = self . ptr + offset as CUdeviceptr ;
217+ driver ! ( cuMemUnmap( ptr, phy. len) ) ;
218+ phy
199219 }
200220}
201221
@@ -213,24 +233,24 @@ fn test_behavior() {
213233 println ! ( "minimun = {minimum}, recommended = {recommended}" ) ;
214234
215235 // 分配一个较大的虚地址区域
216- let virmem = VirMem :: new ( 10 * minimum, 0 ) ;
236+ let mut virmem = VirMem :: new ( 10 * minimum, 0 ) ;
217237 // 分配一个较小的物理页
218238 let phymem = prop. create ( minimum) ;
219239 // 建立映射
220- let mut mapped = virmem. map ( phymem. clone ( ) ) ;
240+ let mapped = virmem. map ( minimum , phymem. clone ( ) ) ;
221241
222242 // 通过虚地址操作存储空间
223243 let host = ( 0 ..minimum / size_of :: < usize > ( ) ) . collect :: < Box < _ > > ( ) ;
224244 // 对存储空间的操作仍然需要在上下文中进行
225- dev. context ( ) . apply ( |_| memcpy_h2d ( & mut mapped, & host) ) ;
245+ dev. context ( ) . apply ( |_| memcpy_h2d ( mapped, & host) ) ;
226246
227247 // 分配另一个虚地址区域
228- let virmem = VirMem :: new ( 2 * minimum, 0 ) ;
248+ let mut virmem = VirMem :: new ( 2 * minimum, 0 ) ;
229249 // 将同一个物理页映射到虚地址区域
230- let mapped = virmem. map ( phymem) ;
250+ let mapped = virmem. map ( minimum , phymem) ;
231251 // 在另一个上下文中读取存储空间
232252 let mut host_ = vec ! [ 0usize ; host. len( ) ] ;
233- dev. context ( ) . apply ( |_| memcpy_d2h ( & mut host_, & mapped) ) ;
253+ dev. context ( ) . apply ( |_| memcpy_d2h ( & mut host_, mapped) ) ;
234254
235255 assert_eq ! ( & * host, & * host_)
236256}
0 commit comments