Skip to content

Commit

Permalink
Solver: Keep track of a solution's score as we're computing it.
Browse files Browse the repository at this point in the history
No functionality change here; just staging for some future optimizations.


Swift SVN r11028
  • Loading branch information
DougGregor committed Dec 9, 2013
1 parent ac7ee4b commit 79f8175
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 83 deletions.
51 changes: 0 additions & 51 deletions lib/Sema/CSApply.cpp
Expand Up @@ -3353,54 +3353,3 @@ Solution::convertToArrayBound(Expr *expr, ConstraintLocator *locator) const {

return result;
}

int Solution::getFixedScore() const {
if (fixedScore)
return *fixedScore;

int score = 0;

// Consider overload choices.
for (auto overload : overloadChoices) {
auto choice = overload.second.choice;
if (choice.getKind() != OverloadChoiceKind::Decl)
continue;

// -3 penalty for each user-defined conversion.
if (choice.getDecl()->getAttrs().isConversion())
score -= 3;
}

// Consider type bindings.
auto &tc = getConstraintSystem().getTypeChecker();
for (auto binding : typeBindings) {
// Look for type variables corresponding directly to an expression.
auto typeVar = binding.first;
auto locator = typeVar->getImpl().getLocator();
if (!locator || !locator->getAnchor() || !locator->getPath().empty())
continue;

// Check whether there is a literal protocol corresponding to the
// anchor expression.
auto literalProtocol
= tc.getLiteralProtocol(locator->getAnchor());
if (!literalProtocol)
continue;

// Retrieve the default type for this literal protocol, if there is one.
auto defaultType = tc.getDefaultType(literalProtocol,
getConstraintSystem().DC);
if (!defaultType)
continue;

// +1 if the bound type matches the default type for this literal protocol.
// Literal types are always nominal, so we simply check the nominal
// declaration. This covers e.g., Slice vs. Slice<T>.
if (defaultType->getAnyNominal() == binding.second->getAnyNominal())
++score;
}

// Save the fixed score.
fixedScore = score;
return score;
}
25 changes: 23 additions & 2 deletions lib/Sema/CSRanking.cpp
Expand Up @@ -27,6 +27,20 @@ using namespace constraints;
#define DEBUG_TYPE "Constraint solver overall"
STATISTIC(NumDiscardedSolutions, "# of solutions discarded");

void ConstraintSystem::increaseScore(ScoreKind kind) {
unsigned index = static_cast<unsigned>(kind);
++CurrentScore.Data[index];
}

llvm::raw_ostream &constraints::operator<<(llvm::raw_ostream &out,
const Score &score) {
for (unsigned i = 0; i != NumScoreKinds; ++i) {
if (i) out << ' ';
out << score.Data[i];
}
return out;
}

/// \brief Remove the initializers from any tuple types within the
/// given type.
static Type stripInitializers(TypeChecker &tc, Type origType) {
Expand Down Expand Up @@ -446,6 +460,13 @@ Comparison TypeChecker::compareDeclarations(DeclContext *dc,
return decl1Better? Comparison::Better : Comparison::Worse;
}

/// Simplify a score into a single integer.
/// FIXME: Temporary hack.
static int simplifyScore(const Score &score) {
return (int)score.Data[SK_UserConversion] * -3
+ (int)score.Data[SK_NonDefaultLiteral] * -1;
}

SolutionCompareResult ConstraintSystem::compareSolutions(
ConstraintSystem &cs,
ArrayRef<Solution> solutions,
Expand All @@ -458,8 +479,8 @@ SolutionCompareResult ConstraintSystem::compareSolutions(
// Solution comparison uses a scoring system to determine whether one
// solution is better than the other. Retrieve the fixed scores for each of
// the solutions, which we'll modify with relative scoring.
int score1 = solutions[idx1].getFixedScore();
int score2 = solutions[idx2].getFixedScore();
int score1 = simplifyScore(solutions[idx1].getFixedScore());
int score2 = simplifyScore(solutions[idx2].getFixedScore());

// Compare overload sets.
for (auto &overload : diff.overloads) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Sema/CSSimplify.cpp
Expand Up @@ -526,6 +526,9 @@ tryUserConversion(ConstraintSystem &cs, Type type, ConstraintKind kind,
cs.addConstraint(kind, outputTV, otherType, resultLocator);
}

// We're adding a user-defined conversion.
cs.increaseScore(SK_UserConversion);

return ConstraintSystem::SolutionKind::Solved;
}

Expand Down
22 changes: 18 additions & 4 deletions lib/Sema/CSSolver.cpp
Expand Up @@ -73,7 +73,7 @@ static Optional<Type> checkTypeOfBinding(ConstraintSystem &cs,
Solution ConstraintSystem::finalize(
FreeTypeVariableBinding allowFreeTypeVariables) {
// Create the solution.
Solution solution(*this);
Solution solution(*this, CurrentScore);

// For any of the type variables that has no associated fixed type, assign a
// fresh generic type parameters.
Expand Down Expand Up @@ -123,6 +123,9 @@ Solution ConstraintSystem::finalize(
}

void ConstraintSystem::applySolution(const Solution &solution) {
// Update the score.
CurrentScore += solution.getFixedScore();

// Assign fixed types to the type variables solved by this solution.
llvm::SmallPtrSet<TypeVariableType *, 4>
knownTypeVariables(TypeVariables.begin(), TypeVariables.end());
Expand All @@ -134,7 +137,7 @@ void ConstraintSystem::applySolution(const Solution &solution) {
// If we don't already have a fixed type for this type variable,
// assign the fixed type from the solution.
if (!getFixedType(binding.first) && !binding.second->hasTypeVariable())
assignFixedType(binding.first, binding.second);
assignFixedType(binding.first, binding.second, /*updateScore=*/false);
}

// Register overload choices.
Expand Down Expand Up @@ -583,6 +586,7 @@ ConstraintSystem::SolverScope::SolverScope(ConstraintSystem &cs)
numConstraintRestrictions = cs.solverState->constraintRestrictions.size();
oldGeneratedConstraints = cs.solverState->generatedConstraints;
cs.solverState->generatedConstraints = &generatedConstraints;
PreviousScore = cs.CurrentScore;

++cs.solverState->NumStatesExplored;

Expand Down Expand Up @@ -618,6 +622,9 @@ ConstraintSystem::SolverScope::~SolverScope() {
// Reset the prior generated-constraints pointer.
cs.solverState->generatedConstraints = oldGeneratedConstraints;

// Reset the previous score.
cs.CurrentScore = PreviousScore;

// Clear out other "failed" state.
cs.failedConstraint = nullptr;
}
Expand Down Expand Up @@ -932,7 +939,8 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
auto solution = finalize(allowFreeTypeVariables);
if (TC.getLangOpts().DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream();
log.indent(solverState->depth * 2) << "(found solution)\n";
log.indent(solverState->depth * 2)
<< "(found solution " << CurrentScore << ")\n";
}

solutions.push_back(std::move(solution));
Expand Down Expand Up @@ -1069,6 +1077,11 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
// Move the type variables back, clear out constraints; we're
// ready for the next component.
TypeVariables = std::move(allTypeVariables);

// For each of the partial solutions, substract off the current score.
// It doesn't contribute.
for (auto &solution : partialSolutions[component])
solution.getFixedScore() -= CurrentScore;
}

// Move the constraints back. The system is back in a normal state.
Expand Down Expand Up @@ -1103,7 +1116,8 @@ bool ConstraintSystem::solve(SmallVectorImpl<Solution> &solutions,
auto solution = finalize(allowFreeTypeVariables);
if (TC.getLangOpts().DebugConstraintSolver) {
auto &log = getASTContext().TypeCheckerDebug->getStream();
log.indent(solverState->depth * 2) << "(composed solution)\n";
log.indent(solverState->depth * 2)
<< "(composed solution " << CurrentScore << ")\n";
}

// Save this solution.
Expand Down
46 changes: 44 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Expand Up @@ -69,9 +69,51 @@ void ConstraintSystem::mergeEquivalenceClasses(TypeVariableType *typeVar1,
}
}

void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type) {
void ConstraintSystem::assignFixedType(TypeVariableType *typeVar, Type type,
bool updateScore) {
typeVar->getImpl().assignFixedType(type, getSavedBindings());


if (updateScore && !type->is<TypeVariableType>()) {
// If this type variable represents a literal, check whether we picked the
// default literal type. First, find the corresponding protocol.
ProtocolDecl *literalProtocol = nullptr;
if (CG) {
// If we have the constraint graph, we can check all type variables in
// the equivalence class. This is the More Correct path.
// FIXME: Eliminate the less-correct path.
auto typeVarRep = getRepresentative(typeVar);
for (auto tv : (*CG)[typeVarRep].getEquivalenceClass()) {
auto locator = tv->getImpl().getLocator();
if (!locator || !locator->getPath().empty())
continue;

auto anchor = locator->getAnchor();
if (!anchor)
continue;

literalProtocol = TC.getLiteralProtocol(anchor);
if (literalProtocol)
break;
}
} else {
// FIXME: This is the less-correct path.
auto locator = typeVar->getImpl().getLocator();
if (locator && locator->getPath().empty() && locator->getAnchor()) {
literalProtocol = TC.getLiteralProtocol(locator->getAnchor());
}
}

// If the protocol has a default type, check it.
if (literalProtocol) {
if (auto defaultType = TC.getDefaultType(literalProtocol, DC)) {
// Check whether the nominal types match. This makes sure that we
// properly handle Slice vs. Slice<T>.
if (defaultType->getAnyNominal() != type->getAnyNominal())
increaseScore(SK_NonDefaultLiteral);
}
}
}

// Notify the constraint graph.
if (CG) {
CG->bindTypeVariable(typeVar, type);
Expand Down

0 comments on commit 79f8175

Please sign in to comment.