Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/libexpr-tests/primops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ TEST_F(PrimOpTest, derivation)
ASSERT_EQ(v.type(), nFunction);
ASSERT_TRUE(v.isLambda());
ASSERT_NE(v.lambda().fun, nullptr);
ASSERT_TRUE(v.lambda().fun->hasFormals());
ASSERT_TRUE(v.lambda().fun->getFormals());
}

TEST_F(PrimOpTest, currentTime)
Expand Down
6 changes: 2 additions & 4 deletions src/libexpr-tests/value/print.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,8 @@ TEST_F(ValuePrintingTests, vLambda)
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
auto posIdx = state.positions.add(origin, 0);
auto body = ExprInt(0);
auto formals = Formals{};

ExprLambda eLambda(posIdx, createSymbol("a"), &formals, &body);
ExprLambda eLambda(posIdx, createSymbol("a"), &body);

Value vLambda;
vLambda.mkLambda(&env, &eLambda);
Expand Down Expand Up @@ -500,9 +499,8 @@ TEST_F(ValuePrintingTests, ansiColorsLambda)
PosTable::Origin origin = state.positions.addOrigin(std::monostate(), 1);
auto posIdx = state.positions.add(origin, 0);
auto body = ExprInt(0);
auto formals = Formals{};

ExprLambda eLambda(posIdx, createSymbol("a"), &formals, &body);
ExprLambda eLambda(posIdx, createSymbol("a"), &body);

Value vLambda;
vLambda.mkLambda(&env, &eLambda);
Expand Down
25 changes: 13 additions & 12 deletions src/libexpr/eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1496,15 +1496,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,

ExprLambda & lambda(*vCur.lambda().fun);

auto size = (!lambda.arg ? 0 : 1) + (lambda.hasFormals() ? lambda.formals->formals.size() : 0);
auto size = (!lambda.arg ? 0 : 1) + (lambda.getFormals() ? lambda.getFormals()->formals.size() : 0);
Env & env2(mem.allocEnv(size));
env2.up = vCur.lambda().env;

Displacement displ = 0;

if (!lambda.hasFormals())
env2.values[displ++] = args[0];
else {
if (auto formals = lambda.getFormals()) {
try {
forceAttrs(*args[0], lambda.pos, "while evaluating the value passed for the lambda argument");
} catch (Error & e) {
Expand All @@ -1520,7 +1518,7 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
there is no matching actual argument but the formal
argument has a default, use the default. */
size_t attrsUsed = 0;
for (auto & i : lambda.formals->formals) {
for (auto & i : formals->formals) {
auto j = args[0]->attrs()->get(i.name);
if (!j) {
if (!i.def) {
Expand All @@ -1542,13 +1540,13 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,

/* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */
if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs()->size()) {
if (!formals->ellipsis && attrsUsed != args[0]->attrs()->size()) {
/* Nope, so show the first unexpected argument to the
user. */
for (auto & i : *args[0]->attrs())
if (!lambda.formals->has(i.name)) {
if (!formals->has(i.name)) {
StringSet formalNames;
for (auto & formal : lambda.formals->formals)
for (auto & formal : formals->formals)
formalNames.insert(std::string(symbols[formal.name]));
auto suggestions = Suggestions::bestMatches(formalNames, symbols[i.name]);
error<TypeError>(
Expand All @@ -1563,6 +1561,8 @@ void EvalState::callFunction(Value & fun, std::span<Value *> args, Value & vRes,
}
unreachable();
}
} else {
env2.values[displ++] = args[0];
}

nrFunctionCalls++;
Expand Down Expand Up @@ -1747,22 +1747,23 @@ void EvalState::autoCallFunction(const Bindings & args, Value & fun, Value & res
}
}

if (!fun.isLambda() || !fun.lambda().fun->hasFormals()) {
if (!fun.isLambda() || !fun.lambda().fun->getFormals()) {
res = fun;
return;
}
auto formals = fun.lambda().fun->getFormals();

auto attrs = buildBindings(std::max(static_cast<uint32_t>(fun.lambda().fun->formals->formals.size()), args.size()));
auto attrs = buildBindings(std::max(static_cast<uint32_t>(formals->formals.size()), args.size()));

if (fun.lambda().fun->formals->ellipsis) {
if (formals->ellipsis) {
// If the formals have an ellipsis (eg the function accepts extra args) pass
// all available automatic arguments (which includes arguments specified on
// the command line via --arg/--argstr)
for (auto & v : args)
attrs.insert(v);
} else {
// Otherwise, only pass the arguments that the function accepts
for (auto & i : fun.lambda().fun->formals->formals) {
for (auto & i : formals->formals) {
auto j = args.get(i.name);
if (j) {
attrs.insert(*j);
Expand Down
73 changes: 58 additions & 15 deletions src/libexpr/include/nix/expr/nixexpr.hh
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ struct Formal
Expr * def;
};

struct Formals
struct FormalsBuilder
{
typedef std::vector<Formal> Formals_;
/**
Expand All @@ -481,6 +481,23 @@ struct Formals
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg;
}
};

struct Formals
{
std::span<Formal> formals;
bool ellipsis;

Formals(std::span<Formal> formals, bool ellipsis)
: formals(formals)
, ellipsis(ellipsis) {};

bool has(Symbol arg) const
{
auto it = std::lower_bound(
formals.begin(), formals.end(), arg, [](const Formal & f, const Symbol & sym) { return f.name < sym; });
return it != formals.end() && it->name == arg;
}

std::vector<Formal> lexicographicOrder(const SymbolTable & symbols) const
{
Expand All @@ -498,31 +515,57 @@ struct ExprLambda : Expr
PosIdx pos;
Symbol name;
Symbol arg;
Formals * formals;

private:
bool hasFormals;
bool ellipsis;
uint16_t nFormals;
Formal * formalsStart;
public:

std::optional<Formals> getFormals() const
{
if (hasFormals)
return Formals{{formalsStart, nFormals}, ellipsis};
else
return std::nullopt;
}

Expr * body;
DocComment docComment;

ExprLambda(PosIdx pos, Symbol arg, Formals * formals, Expr * body)
ExprLambda(
std::pmr::polymorphic_allocator<char> & alloc,
PosIdx pos,
Symbol arg,
const FormalsBuilder & formals,
Expr * body)
: pos(pos)
, arg(arg)
, formals(formals)
, body(body) {};

ExprLambda(PosIdx pos, Formals * formals, Expr * body)
: pos(pos)
, formals(formals)
, hasFormals(true)
, ellipsis(formals.ellipsis)
, nFormals(formals.formals.size())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can do a silent truncation from size_t -> uint16_t.

, formalsStart(alloc.allocate_object<Formal>(nFormals))
, body(body)
{
}
std::ranges::copy(formals.formals, formalsStart);
Comment on lines -516 to +551
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have a check in case there a more than 65k formals. We must fail gracefully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would overflow the heap buffer otherwise. Let's please not do this.

};

ExprLambda(PosIdx pos, Symbol arg, Expr * body)
: pos(pos)
, arg(arg)
, hasFormals(false)
, ellipsis(false)
, nFormals(0)
, formalsStart(nullptr)
, body(body) {};

ExprLambda(std::pmr::polymorphic_allocator<char> & alloc, PosIdx pos, FormalsBuilder formals, Expr * body)
: ExprLambda(alloc, pos, Symbol(), formals, body) {};

void setName(Symbol name) override;
std::string showNamePos(const EvalState & state) const;

inline bool hasFormals() const
{
return formals != nullptr;
}

PosIdx getPos() const override
{
return pos;
Expand Down
16 changes: 7 additions & 9 deletions src/libexpr/include/nix/expr/parser-state.hh
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct ParserState
void addAttr(
ExprAttrs * attrs, AttrPath && attrPath, const ParserLocation & loc, Expr * e, const ParserLocation & exprLoc);
void addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symbol, ExprAttrs::AttrDef && def);
Formals * validateFormals(Formals * formals, PosIdx pos = noPos, Symbol arg = {});
void validateFormals(FormalsBuilder & formals, PosIdx pos = noPos, Symbol arg = {});
Expr * stripIndentation(const PosIdx pos, std::vector<std::pair<PosIdx, std::variant<Expr *, StringToken>>> && es);
PosIdx at(const ParserLocation & loc);
};
Expand Down Expand Up @@ -213,29 +213,27 @@ ParserState::addAttr(ExprAttrs * attrs, AttrPath & attrPath, const Symbol & symb
}
}

inline Formals * ParserState::validateFormals(Formals * formals, PosIdx pos, Symbol arg)
inline void ParserState::validateFormals(FormalsBuilder & formals, PosIdx pos, Symbol arg)
{
std::sort(formals->formals.begin(), formals->formals.end(), [](const auto & a, const auto & b) {
std::sort(formals.formals.begin(), formals.formals.end(), [](const auto & a, const auto & b) {
return std::tie(a.name, a.pos) < std::tie(b.name, b.pos);
});

std::optional<std::pair<Symbol, PosIdx>> duplicate;
for (size_t i = 0; i + 1 < formals->formals.size(); i++) {
if (formals->formals[i].name != formals->formals[i + 1].name)
for (size_t i = 0; i + 1 < formals.formals.size(); i++) {
if (formals.formals[i].name != formals.formals[i + 1].name)
continue;
std::pair thisDup{formals->formals[i].name, formals->formals[i + 1].pos};
std::pair thisDup{formals.formals[i].name, formals.formals[i + 1].pos};
duplicate = std::min(thisDup, duplicate.value_or(thisDup));
}
if (duplicate)
throw ParseError(
{.msg = HintFmt("duplicate formal function argument '%1%'", symbols[duplicate->first]),
.pos = positions[duplicate->second]});

if (arg && formals->has(arg))
if (arg && formals.has(arg))
throw ParseError(
{.msg = HintFmt("duplicate formal function argument '%1%'", symbols[arg]), .pos = positions[pos]});

return formals;
}

inline Expr *
Expand Down
8 changes: 4 additions & 4 deletions src/libexpr/nixexpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void ExprList::show(const SymbolTable & symbols, std::ostream & str) const
void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
{
str << "(";
if (hasFormals()) {
if (auto formals = getFormals()) {
str << "{ ";
bool first = true;
// the natural Symbol ordering is by creation time, which can lead to the
Expand All @@ -171,7 +171,7 @@ void ExprLambda::show(const SymbolTable & symbols, std::ostream & str) const
i.def->show(symbols, str);
}
}
if (formals->ellipsis) {
if (ellipsis) {
if (!first)
str << ", ";
str << "...";
Expand Down Expand Up @@ -452,14 +452,14 @@ void ExprLambda::bindVars(EvalState & es, const std::shared_ptr<const StaticEnv>
es.exprEnvs.insert(std::make_pair(this, env));

auto newEnv =
std::make_shared<StaticEnv>(nullptr, env, (hasFormals() ? formals->formals.size() : 0) + (!arg ? 0 : 1));
std::make_shared<StaticEnv>(nullptr, env, (getFormals() ? getFormals()->formals.size() : 0) + (!arg ? 0 : 1));

Displacement displ = 0;

if (arg)
newEnv->vars.emplace_back(arg, displ++);

if (hasFormals()) {
if (auto formals = getFormals()) {
for (auto & i : formals->formals)
newEnv->vars.emplace_back(i.name, displ++);

Expand Down
28 changes: 16 additions & 12 deletions src/libexpr/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ static Expr * makeCall(PosIdx pos, Expr * fn, Expr * arg) {
%type <nix::Expr *> expr_pipe_from expr_pipe_into
%type <std::vector<Expr *>> list
%type <nix::ExprAttrs *> binds binds1
%type <nix::Formals *> formals formal_set
%type <nix::FormalsBuilder> formals formal_set
%type <nix::Formal> formal
%type <std::vector<nix::AttrName>> attrpath
%type <std::vector<std::pair<nix::AttrName, nix::PosIdx>>> attrs
Expand Down Expand Up @@ -179,26 +179,30 @@ expr: expr_function;

expr_function
: ID ':' expr_function
{ auto me = new ExprLambda(CUR_POS, state->symbols.create($1), 0, $3);
{ auto me = new ExprLambda(CUR_POS, state->symbols.create($1), $3);
$$ = me;
SET_DOC_POS(me, @1);
}
| formal_set ':' expr_function[body]
{ auto me = new ExprLambda(CUR_POS, state->validateFormals($formal_set), $body);
{
state->validateFormals($formal_set);
auto me = new ExprLambda(state->alloc, CUR_POS, std::move($formal_set), $body);
$$ = me;
SET_DOC_POS(me, @1);
}
| formal_set '@' ID ':' expr_function[body]
{
auto arg = state->symbols.create($ID);
auto me = new ExprLambda(CUR_POS, arg, state->validateFormals($formal_set, CUR_POS, arg), $body);
state->validateFormals($formal_set, CUR_POS, arg);
auto me = new ExprLambda(state->alloc, CUR_POS, arg, std::move($formal_set), $body);
$$ = me;
SET_DOC_POS(me, @1);
}
| ID '@' formal_set ':' expr_function[body]
{
auto arg = state->symbols.create($ID);
auto me = new ExprLambda(CUR_POS, arg, state->validateFormals($formal_set, CUR_POS, arg), $body);
state->validateFormals($formal_set, CUR_POS, arg);
auto me = new ExprLambda(state->alloc, CUR_POS, arg, std::move($formal_set), $body);
$$ = me;
SET_DOC_POS(me, @1);
}
Expand Down Expand Up @@ -490,18 +494,18 @@ list
;

formal_set
: '{' formals ',' ELLIPSIS '}' { $$ = $formals; $$->ellipsis = true; }
| '{' ELLIPSIS '}' { $$ = new Formals; $$->ellipsis = true; }
| '{' formals ',' '}' { $$ = $formals; $$->ellipsis = false; }
| '{' formals '}' { $$ = $formals; $$->ellipsis = false; }
| '{' '}' { $$ = new Formals; $$->ellipsis = false; }
: '{' formals ',' ELLIPSIS '}' { $$ = std::move($formals); $$.ellipsis = true; }
| '{' ELLIPSIS '}' { $$.ellipsis = true; }
| '{' formals ',' '}' { $$ = std::move($formals); $$.ellipsis = false; }
| '{' formals '}' { $$ = std::move($formals); $$.ellipsis = false; }
| '{' '}' { $$.ellipsis = false; }
;

formals
: formals[accum] ',' formal
{ $$ = $accum; $$->formals.emplace_back(std::move($formal)); }
{ $$ = std::move($accum); $$.formals.emplace_back(std::move($formal)); }
| formal
{ $$ = new Formals; $$->formals.emplace_back(std::move($formal)); }
{ $$.formals.emplace_back(std::move($formal)); }
;

formal
Expand Down
23 changes: 11 additions & 12 deletions src/libexpr/primops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3363,21 +3363,20 @@ static void prim_functionArgs(EvalState & state, const PosIdx pos, Value ** args
if (!args[0]->isLambda())
state.error<TypeError>("'functionArgs' requires a function").atPos(pos).debugThrow();

if (!args[0]->lambda().fun->hasFormals()) {
if (const auto & formals = args[0]->lambda().fun->getFormals()) {
auto attrs = state.buildBindings(formals->formals.size());
for (auto & i : formals->formals)
attrs.insert(i.name, state.getBool(i.def), i.pos);
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
always holds:
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
.*/
v.mkAttrs(attrs.alreadySorted());
} else {
v.mkAttrs(&Bindings::emptyBindings);
return;
}

const auto & formals = args[0]->lambda().fun->formals->formals;
auto attrs = state.buildBindings(formals.size());
for (auto & i : formals)
attrs.insert(i.name, state.getBool(i.def), i.pos);
/* Optimization: avoid sorting bindings. `formals` must already be sorted according to
(std::tie(a.name, a.pos) < std::tie(b.name, b.pos)) predicate, so the following assertion
always holds:
assert(std::is_sorted(attrs.alreadySorted()->begin(), attrs.alreadySorted()->end()));
.*/
v.mkAttrs(attrs.alreadySorted());
}

static RegisterPrimOp primop_functionArgs({
Expand Down
Loading
Loading