Skip to content

Commit c96680a

Browse files
authored
Merge pull request RustPython#1050 from palaviv/relative-import
Support relative import
2 parents 6d5f381 + da4c0bc commit c96680a

File tree

6 files changed

+23
-6
lines changed

6 files changed

+23
-6
lines changed

tests/snippets/dir_module/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .relative import value
2+
from .dir_module_inner import value2
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ..relative import value
2+
3+
value2 = value + 2

tests/snippets/dir_module/relative.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
value = 5

tests/snippets/import_module.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import dir_module
2+
assert dir_module.value == 5
3+
assert dir_module.value2 == 7

vm/src/frame.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,9 @@ impl Frame {
912912
.iter()
913913
.map(|symbol| vm.ctx.new_str(symbol.to_string()))
914914
.collect();
915-
let module = vm.import(module, &vm.ctx.new_tuple(from_list))?;
915+
let level = module.chars().take_while(|char| *char == '.').count();
916+
let module_name = &module[level..];
917+
let module = vm.import(module_name, &vm.ctx.new_tuple(from_list), level)?;
916918

917919
if symbols.is_empty() {
918920
self.push_value(module);
@@ -928,7 +930,8 @@ impl Frame {
928930
}
929931

930932
fn import_star(&self, vm: &VirtualMachine, module: &str) -> FrameResult {
931-
let module = vm.import(module, &vm.ctx.new_tuple(vec![]))?;
933+
let level = module.chars().take_while(|char| *char == '.').count();
934+
let module = vm.import(module, &vm.ctx.new_tuple(vec![]), level)?;
932935

933936
// Grab all the names from the module and put them in the context
934937
if let Some(dict) = &module.dict {

vm/src/vm.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,15 @@ impl VirtualMachine {
136136

137137
pub fn try_class(&self, module: &str, class: &str) -> PyResult<PyClassRef> {
138138
let class = self
139-
.get_attribute(self.import(module, &self.ctx.new_tuple(vec![]))?, class)?
139+
.get_attribute(self.import(module, &self.ctx.new_tuple(vec![]), 0)?, class)?
140140
.downcast()
141141
.expect("not a class");
142142
Ok(class)
143143
}
144144

145145
pub fn class(&self, module: &str, class: &str) -> PyClassRef {
146146
let module = self
147-
.import(module, &self.ctx.new_tuple(vec![]))
147+
.import(module, &self.ctx.new_tuple(vec![]), 0)
148148
.unwrap_or_else(|_| panic!("unable to import {}", module));
149149
let class = self
150150
.get_attribute(module.clone(), class)
@@ -302,7 +302,7 @@ impl VirtualMachine {
302302
TryFromObject::try_from_object(self, repr)
303303
}
304304

305-
pub fn import(&self, module: &str, from_list: &PyObjectRef) -> PyResult {
305+
pub fn import(&self, module: &str, from_list: &PyObjectRef, level: usize) -> PyResult {
306306
let sys_modules = self
307307
.get_attribute(self.sys_module.clone(), "modules")
308308
.unwrap();
@@ -314,9 +314,14 @@ impl VirtualMachine {
314314
func,
315315
vec![
316316
self.ctx.new_str(module.to_string()),
317-
self.get_none(),
317+
if self.current_frame().is_some() {
318+
self.get_locals().into_object()
319+
} else {
320+
self.get_none()
321+
},
318322
self.get_none(),
319323
from_list.clone(),
324+
self.ctx.new_int(level),
320325
],
321326
),
322327
Err(_) => Err(self.new_exception(

0 commit comments

Comments
 (0)