Skip to content

Commit aa963dc

Browse files
Merge pull request #67 from RustPython/mro
Compute Method Resolution Order (MRO).
2 parents c664fab + a9c3349 commit aa963dc

File tree

12 files changed

+245
-58
lines changed

12 files changed

+245
-58
lines changed

parser/src/ast.rs

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub enum Statement {
8181
ClassDef {
8282
name: String,
8383
body: Vec<Statement>,
84+
args: Vec<String>,
8485
// TODO: docstring: String,
8586
},
8687
FunctionDef {

parser/src/parser.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -175,11 +175,12 @@ mod tests {
175175

176176
#[test]
177177
fn test_parse_class() {
178-
let source = String::from("class Foo:\n def __init__(self):\n pass\n");
178+
let source = String::from("class Foo(A, B):\n def __init__(self):\n pass\n");
179179
assert_eq!(
180180
parse_statement(&source),
181181
Ok(ast::Statement::ClassDef {
182182
name: String::from("Foo"),
183+
args: vec![String::from("A"), String::from("B")],
183184
body: vec![ast::Statement::FunctionDef {
184185
name: String::from("__init__"),
185186
args: vec![String::from("self")],

parser/src/python.lalrpop

+2-1
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,9 @@ TypedArgsList: Vec<String> = {
208208
};
209209

210210
ClassDef: ast::Statement = {
211-
"class" <n:Identifier> <_a:("(" ")")?> ":" <s:Suite> => ast::Statement::ClassDef {
211+
"class" <n:Identifier> <a:Parameters?> ":" <s:Suite> => ast::Statement::ClassDef {
212212
name: n,
213+
args: a.unwrap_or(vec![]),
213214
body: s},
214215
};
215216

tests/snippets/mro.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
class X():
2+
pass
3+
4+
class Y():
5+
pass
6+
7+
class A(X, Y):
8+
pass
9+
10+
print(A.__mro__)
11+
12+
class B(X, Y):
13+
pass
14+
15+
print(B.__mro__)
16+
17+
class C(A, B):
18+
pass
19+
20+
print(C.__mro__)

vm/src/builtins.rs

+10-13
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ fn builtin_range(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
270270
PyObjectKind::Integer { ref value } => {
271271
let range_elements: Vec<PyObjectRef> =
272272
(0..*value).map(|num| vm.context().new_int(num)).collect();
273-
Ok(vm.context().new_list(Some(range_elements)))
273+
Ok(vm.context().new_list(range_elements))
274274
}
275275
_ => panic!("first argument to range must be an integer"),
276276
}
@@ -372,13 +372,15 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
372372
obj
373373
}
374374

375-
pub fn builtin_build_class_(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
376-
let function = args.args[0].clone();
377-
let a1 = &*args.args[1].borrow();
378-
let name = match &a1.kind {
379-
PyObjectKind::String { value: name } => name,
380-
_ => panic!("Class name must be a string."),
375+
pub fn builtin_build_class_(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult {
376+
let function = args.shift();
377+
let name_arg = args.shift();
378+
let name = match name_arg.borrow().kind {
379+
PyObjectKind::String { ref value } => value.to_string(),
380+
_ => panic!("Class name must by a string!"),
381381
};
382+
let mut bases = args.args.clone();
383+
bases.push(vm.context().object.clone());
382384
let metaclass = vm.get_type();
383385
let namespace = vm.new_dict();
384386
&vm.invoke(
@@ -387,10 +389,5 @@ pub fn builtin_build_class_(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResu
387389
args: vec![namespace.clone()],
388390
},
389391
);
390-
objtype::new(
391-
metaclass,
392-
name.to_string(),
393-
vec![vm.get_object()],
394-
namespace,
395-
)
392+
objtype::new(metaclass, name, bases, namespace)
396393
}

vm/src/compile.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ impl Compiler {
257257
name: name.to_string(),
258258
});
259259
}
260-
ast::Statement::ClassDef { name, body } => {
260+
ast::Statement::ClassDef { name, body, args } => {
261261
self.emit(Instruction::LoadBuildClass);
262262
self.code_object_stack
263263
.push(CodeObject::new(vec![String::from("__locals__")]));
@@ -288,7 +288,13 @@ impl Compiler {
288288
value: name.clone(),
289289
},
290290
});
291-
self.emit(Instruction::CallFunction { count: 2 });
291+
292+
for base in args {
293+
self.emit(Instruction::LoadName { name: base.clone() });
294+
}
295+
self.emit(Instruction::CallFunction {
296+
count: 2 + args.len(),
297+
});
292298

293299
self.emit(Instruction::StoreName {
294300
name: name.to_string(),

vm/src/objclass.rs

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ pub fn new_instance(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult {
88
let type_ref = args.shift();
99
let dict = vm.new_dict();
1010
let obj = PyObject::new(PyObjectKind::Instance { dict: dict }, type_ref.clone());
11-
// TODO Raise TypeError if init returns not None.
1211
Ok(obj)
1312
}
1413

vm/src/objfunction.rs

+30-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use super::pyobject::{PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult};
1+
use super::objtype;
2+
use super::pyobject::{
3+
AttributeProtocol, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult,
4+
};
25
use super::vm::VirtualMachine;
36
use std::collections::HashMap;
47

@@ -40,3 +43,29 @@ pub fn create_bound_method_type(type_type: PyObjectRef) -> PyObjectRef {
4043
fn bind_method(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
4144
Ok(vm.new_bound_method(args.args[0].clone(), args.args[1].clone()))
4245
}
46+
47+
pub fn create_member_descriptor_type(type_type: PyObjectRef, object: PyObjectRef) -> PyResult {
48+
let mut dict = HashMap::new();
49+
50+
dict.insert(
51+
String::from("__get__"),
52+
PyObject::new(
53+
PyObjectKind::RustFunction {
54+
function: member_get,
55+
},
56+
type_type.clone(),
57+
),
58+
);
59+
60+
objtype::new(
61+
type_type.clone(),
62+
String::from("member_descriptor"),
63+
vec![object],
64+
PyObject::new(PyObjectKind::Dict { elements: dict }, type_type.clone()),
65+
)
66+
}
67+
68+
fn member_get(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult {
69+
let function = args.shift().get_attr(&String::from("function"));
70+
vm.invoke(function, args)
71+
}

vm/src/objtype.rs

+145-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::pyobject::{
2-
AttributeProtocol, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, ToRust,
3-
TypeProtocol,
2+
AttributeProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef,
3+
PyResult, ToRust, TypeProtocol,
44
};
55
use super::vm::VirtualMachine;
66
use std::collections::HashMap;
@@ -30,33 +30,62 @@ pub fn create_type() -> PyObjectRef {
3030
typ
3131
}
3232

33-
pub fn type_type_add_methods(type_type: PyObjectRef, function_type: PyObjectRef) {
34-
type_type.set_attr(
35-
&String::from("__call__"),
36-
PyObject::new(
37-
PyObjectKind::RustFunction {
38-
function: type_call,
39-
},
40-
function_type.clone(),
41-
),
33+
pub fn init(context: &mut PyContext) {
34+
context
35+
.type_type
36+
.set_attr(&String::from("__call__"), context.new_rustfunc(type_call));
37+
context
38+
.type_type
39+
.set_attr(&String::from("__new__"), context.new_rustfunc(type_new));
40+
41+
context.type_type.set_attr(
42+
&String::from("__mro__"),
43+
context.new_member_descriptor(type_mro),
44+
);
45+
context.type_type.set_attr(
46+
&String::from("__class__"),
47+
context.new_member_descriptor(type_new),
4248
);
43-
type_type.set_attr(
44-
&String::from("__new__"),
45-
PyObject::new(
46-
PyObjectKind::RustFunction { function: type_new },
47-
function_type.clone(),
48-
),
49+
context.type_type.set_attr(
50+
&String::from("__dict__"),
51+
context.new_member_descriptor(type_dict),
4952
);
5053
}
5154

55+
fn type_mro(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
56+
match _mro(args.args[0].clone()) {
57+
Some(mro) => Ok(vm.context().new_tuple(mro)),
58+
None => Err(vm.new_exception("Only classes have an MRO.".to_string())),
59+
}
60+
}
61+
62+
fn _mro(cls: PyObjectRef) -> Option<Vec<PyObjectRef>> {
63+
match cls.borrow().kind {
64+
PyObjectKind::Class { ref mro, .. } => {
65+
let mut mro = mro.clone();
66+
mro.insert(0, cls.clone());
67+
Some(mro)
68+
}
69+
_ => None,
70+
}
71+
}
72+
73+
fn type_dict(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
74+
match args.args[0].borrow().kind {
75+
PyObjectKind::Class { ref dict, .. } => Ok(dict.clone()),
76+
_ => Err(vm.new_exception("type_dict must be called on a class.".to_string())),
77+
}
78+
}
79+
5280
pub fn type_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5381
debug!("type.__new__{:?}", args);
5482
if args.args.len() == 2 {
5583
Ok(args.args[1].typ())
5684
} else if args.args.len() == 4 {
5785
let typ = args.args[0].clone();
5886
let name = args.args[1].to_str().unwrap();
59-
let bases = args.args[2].to_vec().unwrap();
87+
let mut bases = args.args[2].to_vec().unwrap();
88+
bases.push(vm.context().object.clone());
6089
let dict = args.args[3].clone();
6190
new(typ, name, bases, dict)
6291
} else {
@@ -108,12 +137,61 @@ pub fn get_attribute(vm: &mut VirtualMachine, obj: PyObjectRef, name: &String) -
108137
}
109138
}
110139

140+
fn take_next_base(
141+
mut bases: Vec<Vec<PyObjectRef>>,
142+
) -> Option<(PyObjectRef, Vec<Vec<PyObjectRef>>)> {
143+
let mut next = None;
144+
145+
bases = bases.into_iter().filter(|x| !x.is_empty()).collect();
146+
147+
for base in &bases {
148+
let head = base[0].clone();
149+
if !(&bases)
150+
.into_iter()
151+
.any(|x| x[1..].into_iter().any(|x| x.get_id() == head.get_id()))
152+
{
153+
next = Some(head);
154+
break;
155+
}
156+
}
157+
158+
if let Some(head) = next {
159+
for ref mut item in &mut bases {
160+
if item[0].get_id() == head.get_id() {
161+
item.remove(0);
162+
}
163+
}
164+
return Some((head, bases));
165+
}
166+
None
167+
}
168+
169+
fn linearise_mro(mut bases: Vec<Vec<PyObjectRef>>) -> Option<Vec<PyObjectRef>> {
170+
debug!("Linearising MRO: {:?}", bases);
171+
let mut result = vec![];
172+
loop {
173+
if (&bases).into_iter().all(|x| x.is_empty()) {
174+
break;
175+
}
176+
match take_next_base(bases) {
177+
Some((head, new_bases)) => {
178+
result.push(head);
179+
bases = new_bases;
180+
}
181+
None => return None,
182+
}
183+
}
184+
Some(result)
185+
}
186+
111187
pub fn new(typ: PyObjectRef, name: String, bases: Vec<PyObjectRef>, dict: PyObjectRef) -> PyResult {
188+
let mros = bases.into_iter().map(|x| _mro(x).unwrap()).collect();
189+
let mro = linearise_mro(mros).unwrap();
112190
Ok(PyObject::new(
113191
PyObjectKind::Class {
114192
name: name,
115193
dict: dict,
116-
mro: bases,
194+
mro: mro,
117195
},
118196
typ,
119197
))
@@ -123,3 +201,51 @@ pub fn call(vm: &mut VirtualMachine, typ: PyObjectRef, args: PyFuncArgs) -> PyRe
123201
let function = get_attribute(vm, typ, &String::from("__call__"))?;
124202
vm.invoke(function, args)
125203
}
204+
205+
#[cfg(test)]
206+
mod tests {
207+
use super::{create_type, linearise_mro, new};
208+
use super::{IdProtocol, PyContext, PyObjectRef};
209+
210+
fn map_ids(obj: Option<Vec<PyObjectRef>>) -> Option<Vec<usize>> {
211+
match obj {
212+
Some(vec) => Some(vec.into_iter().map(|x| x.get_id()).collect()),
213+
None => None,
214+
}
215+
}
216+
217+
#[test]
218+
fn test_linearise() {
219+
let context = PyContext::new();
220+
let object = context.object;
221+
let type_type = create_type();
222+
223+
let a = new(
224+
type_type.clone(),
225+
String::from("A"),
226+
vec![object.clone()],
227+
type_type.clone(),
228+
).unwrap();
229+
let b = new(
230+
type_type.clone(),
231+
String::from("B"),
232+
vec![object.clone()],
233+
type_type.clone(),
234+
).unwrap();
235+
236+
assert_eq!(
237+
map_ids(linearise_mro(vec![
238+
vec![object.clone()],
239+
vec![object.clone()]
240+
])),
241+
map_ids(Some(vec![object.clone()]))
242+
);
243+
assert_eq!(
244+
map_ids(linearise_mro(vec![
245+
vec![a.clone(), object.clone()],
246+
vec![b.clone(), object.clone()],
247+
])),
248+
map_ids(Some(vec![a.clone(), b.clone(), object.clone()]))
249+
);
250+
}
251+
}

0 commit comments

Comments
 (0)