Skip to content

Commit

Permalink
[red-knot] resolve class members (#11256)
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed May 3, 2024
1 parent 6a1e555 commit 82dd5e6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 27 deletions.
45 changes: 31 additions & 14 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub(crate) struct Scope {
name: Name,
kind: ScopeKind,
child_scopes: Vec<ScopeId>,
// symbol IDs, hashed by symbol name
/// symbol IDs, hashed by symbol name
symbols_by_name: Map<SymbolId, ()>,
}

Expand Down Expand Up @@ -107,6 +107,7 @@ bitflags! {
pub(crate) struct Symbol {
name: Name,
flags: SymbolFlags,
scope_id: ScopeId,
// kind: Kind,
}

Expand Down Expand Up @@ -141,7 +142,7 @@ pub(crate) enum Definition {
// the small amount of information we need from the AST.
Import(ImportDefinition),
ImportFrom(ImportFromDefinition),
ClassDef(TypedNodeKey<ast::StmtClassDef>),
ClassDef(ClassDefinition),
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
Assignment(TypedNodeKey<ast::StmtAssign>),
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
Expand Down Expand Up @@ -174,6 +175,12 @@ impl ImportFromDefinition {
}
}

#[derive(Clone, Debug)]
pub(crate) struct ClassDefinition {
pub(crate) node_key: TypedNodeKey<ast::StmtClassDef>,
pub(crate) scope_id: ScopeId,
}

#[derive(Debug, Clone)]
pub enum Dependency {
Module(ModuleName),
Expand Down Expand Up @@ -332,7 +339,11 @@ impl SymbolTable {
*entry.key()
}
RawEntryMut::Vacant(entry) => {
let id = self.symbols_by_id.push(Symbol { name, flags });
let id = self.symbols_by_id.push(Symbol {
name,
flags,
scope_id,
});
entry.insert_with_hasher(hash, id, (), |_| hash);
id
}
Expand Down Expand Up @@ -459,8 +470,8 @@ impl SymbolTableBuilder {
symbol_id
}

fn push_scope(&mut self, child_of: ScopeId, name: &str, kind: ScopeKind) -> ScopeId {
let scope_id = self.table.add_child_scope(child_of, name, kind);
fn push_scope(&mut self, name: &str, kind: ScopeKind) -> ScopeId {
let scope_id = self.table.add_child_scope(self.cur_scope(), name, kind);
self.scopes.push(scope_id);
scope_id
}
Expand All @@ -482,10 +493,10 @@ impl SymbolTableBuilder {
&mut self,
name: &str,
params: &Option<Box<ast::TypeParams>>,
nested: impl FnOnce(&mut Self),
) {
nested: impl FnOnce(&mut Self) -> ScopeId,
) -> ScopeId {
if let Some(type_params) = params {
self.push_scope(self.cur_scope(), name, ScopeKind::Annotation);
self.push_scope(name, ScopeKind::Annotation);
for type_param in &type_params.type_params {
let name = match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
Expand All @@ -495,10 +506,11 @@ impl SymbolTableBuilder {
self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED);
}
}
nested(self);
let scope_id = nested(self);
if params.is_some() {
self.pop_scope();
}
scope_id
}
}

Expand All @@ -525,21 +537,26 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
// TODO need to capture more definition statements here
match stmt {
ast::Stmt::ClassDef(node) => {
let def = Definition::ClassDef(TypedNodeKey::from_node(node));
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class);
let scope_id = self.with_type_params(&node.name, &node.type_params, |builder| {
let scope_id = builder.push_scope(&node.name, ScopeKind::Class);
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
let def = Definition::ClassDef(ClassDefinition {
node_key: TypedNodeKey::from_node(node),
scope_id,
});
self.add_or_update_symbol_with_def(&node.name, def);
}
ast::Stmt::FunctionDef(node) => {
let def = Definition::FunctionDef(TypedNodeKey::from_node(node));
self.add_or_update_symbol_with_def(&node.name, def);
self.with_type_params(&node.name, &node.type_params, |builder| {
builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function);
let scope_id = builder.push_scope(&node.name, ScopeKind::Function);
ast::visitor::preorder::walk_stmt(builder, stmt);
builder.pop_scope();
scope_id
});
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
Expand Down
52 changes: 42 additions & 10 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![allow(dead_code)]
use crate::ast_ids::NodeKey;
use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::symbols::SymbolId;
use crate::symbols::{ScopeId, SymbolId};
use crate::{FxDashMap, FxIndexSet, Name};
use ruff_index::{newtype_index, IndexVec};
use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -124,8 +125,15 @@ impl TypeStore {
.add_function(name, decorators)
}

fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name, bases)
fn add_class(
&self,
file_id: FileId,
name: &str,
scope_id: ScopeId,
bases: Vec<Type>,
) -> ClassTypeId {
self.add_or_get_module(file_id)
.add_class(name, scope_id, bases)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
Expand Down Expand Up @@ -253,6 +261,24 @@ pub struct ClassTypeId {
class_id: ModuleClassTypeId,
}

impl ClassTypeId {
fn get_own_class_member<Db>(self, db: &Db, name: &Name) -> QueryResult<Option<Type>>
where
Db: SemanticDb + HasJar<SemanticJar>,
{
// TODO: this should distinguish instance-only members (e.g. `x: int`) and not return them
let ClassType { scope_id, .. } = *db.jar()?.type_store.get_class(self);
let table = db.symbol_table(self.file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(db.infer_symbol_type(self.file_id, symbol_id)?))
} else {
Ok(None)
}
}

// TODO: get_own_instance_member, get_class_member, get_instance_member
}

#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct UnionTypeId {
file_id: FileId,
Expand Down Expand Up @@ -318,9 +344,10 @@ impl ModuleTypeStore {
}
}

fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
fn add_class(&mut self, name: &str, scope_id: ScopeId, bases: Vec<Type>) -> ClassTypeId {
let class_id = self.classes.push(ClassType {
name: Name::new(name),
scope_id,
// TODO: if no bases are given, that should imply [object]
bases,
});
Expand Down Expand Up @@ -405,7 +432,11 @@ impl std::fmt::Display for DisplayType<'_> {

#[derive(Debug)]
pub(crate) struct ClassType {
/// Name of the class at definition
name: Name,
/// `ScopeId` of the class body
pub(crate) scope_id: ScopeId,
/// Types of all class bases
bases: Vec<Type>,
}

Expand Down Expand Up @@ -496,6 +527,7 @@ impl IntersectionType {
#[cfg(test)]
mod tests {
use crate::files::Files;
use crate::symbols::SymbolTable;
use crate::types::{Type, TypeStore};
use crate::FxIndexSet;
use std::path::Path;
Expand All @@ -505,7 +537,7 @@ mod tests {
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_class(file_id, "C", Vec::new());
let id = store.add_class(file_id, "C", SymbolTable::root_scope_id(), Vec::new());
assert_eq!(store.get_class(id).name(), "C");
let inst = Type::Instance(id);
assert_eq!(format!("{}", inst.display(&store)), "C");
Expand All @@ -528,8 +560,8 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
let id = store.add_union(file_id, &elems);
assert_eq!(
Expand All @@ -545,9 +577,9 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let c3 = store.add_class(file_id, "C3", Vec::new());
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
let c2 = store.add_class(file_id, "C2", SymbolTable::root_scope_id(), Vec::new());
let c3 = store.add_class(file_id, "C3", SymbolTable::root_scope_id(), Vec::new());
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
let neg = vec![Type::Instance(c3)];
let id = store.add_intersection(file_id, &pos, &neg);
Expand Down
46 changes: 43 additions & 3 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ruff_python_ast::AstNode;

use crate::db::{HasJar, QueryResult, SemanticDb, SemanticJar};
use crate::module::ModuleName;
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
use crate::symbols::{ClassDefinition, Definition, ImportFromDefinition, SymbolId};
use crate::types::Type;
use crate::FileId;
use ruff_python_ast as ast;
Expand Down Expand Up @@ -51,7 +51,7 @@ where
Type::Unknown
}
}
Definition::ClassDef(node_key) => {
Definition::ClassDef(ClassDefinition { node_key, scope_id }) => {
if let Some(ty) = type_store.get_cached_node_type(file_id, node_key.erased()) {
ty
} else {
Expand All @@ -65,7 +65,8 @@ where
bases.push(infer_expr_type(db, file_id, base)?);
}

let ty = Type::Class(type_store.add_class(file_id, &node.name.id, bases));
let ty =
Type::Class(type_store.add_class(file_id, &node.name.id, *scope_id, bases));
type_store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}
Expand Down Expand Up @@ -133,6 +134,7 @@ mod tests {
use crate::db::{HasJar, SemanticDb, SemanticJar};
use crate::module::{ModuleName, ModuleSearchPath, ModuleSearchPathKind};
use crate::types::Type;
use crate::Name;

// TODO with virtual filesystem we shouldn't have to write files to disk for these
// tests
Expand Down Expand Up @@ -222,4 +224,42 @@ mod tests {

Ok(())
}

#[test]
fn resolve_method() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

let path = case.src.path().join("mod.py");
std::fs::write(path, "class C:\n def f(self): pass")?;
let file = db
.resolve_module(ModuleName::new("mod"))?
.expect("module should be found")
.path(db)?
.file();
let syms = db.symbol_table(file)?;
let sym = syms
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let ty = db.infer_symbol_type(file, sym)?;

let Type::Class(class_id) = ty else {
panic!("C is not a Class");
};

let member_ty = class_id
.get_own_class_member(db, &Name::new("f"))
.expect("C.f to resolve");

let Some(Type::Function(func_id)) = member_ty else {
panic!("C.f is not a Function");
};

let jar = HasJar::<SemanticJar>::jar(db)?;
let function = jar.type_store.get_function(func_id);
assert_eq!(function.name(), "f");

Ok(())
}
}

0 comments on commit 82dd5e6

Please sign in to comment.