diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index db41ef782b8..b0dffd1ec53 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -427,24 +427,14 @@ void Fusion::removeStatementsCreatedAfter( int64_t num_vals_before) { auto* c = ir_container(); - NVF_ERROR( - c->exprs_up_.size() == c->exprs_.size(), - "exprs_up_ (size ", - c->exprs_up_.size(), - ") and exprs_ (size ", - c->exprs_.size(), - ") are out of sync."); - NVF_ERROR( - std::ssize(c->exprs_up_) >= num_exprs_before, - "exprs_up_ size (", - std::ssize(c->exprs_up_), - ") is less than num_exprs_before (", - num_exprs_before, - ")."); - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->exprs_up_) > num_exprs_before) { + while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) { + // Pop from global deque back — statements created by this Fusion during + // the guard scope are at the tail (LIFO invariant). Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[this].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); for (Val* in : e->inputs()) { in->removeUse(e); } @@ -453,8 +443,12 @@ void Fusion::removeStatementsCreatedAfter( c->exprs_up_.pop_back(); } - while (std::ssize(c->vals_up_) > num_vals_before) { + while (numValsExcludingShortcuts() > num_vals_before) { Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[this].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + // Null out shortcut caches if they point to vals about to be destroyed if (v == zero_val_) { zero_val_ = nullptr; } else if (v == one_val_) { diff --git a/csrc/fusion.h b/csrc/fusion.h index d244517a703..f2b50425ecd 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -546,13 +546,26 @@ class NVF_API Fusion : public PolymorphicBase { return ir_container()->valsOwnedBy(this); } - // Count queries + // Count queries (per-Fusion: only counts statements owned by this Fusion) int64_t numExprs() const noexcept { - return ir_container()->numExprs(); + return std::ssize(ir_container()->exprsOwnedBy(this)); } int64_t numVals() const noexcept { - return ir_container()->numVals(); + return std::ssize(ir_container()->valsOwnedBy(this)); + } + + //! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.). + //! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but + //! since they're singletons that should persist across StatementGuard scopes, + //! this count excludes them so the LIFO pop-back in + //! removeStatementsCreatedAfter correctly skips over them. + int64_t numValsExcludingShortcuts() const noexcept { + int64_t count = std::ssize(ir_container()->valsOwnedBy(this)); + count -= (zero_val_ != nullptr) + (one_val_ != nullptr) + + (true_val_ != nullptr) + (false_val_ != nullptr) + + (magic_zero_val_ != nullptr); + return count; } // Shortcut values (frequently used constants) diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index e7a64d0afea..bfe811cc124 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -223,27 +223,27 @@ void IrContainer::transferStatementOwnership( void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { auto vals_it = per_fusion_vals_.find(fusion); if (vals_it != per_fusion_vals_.end()) { - for (auto it = vals_up_.begin(); it != vals_up_.end();) { - if (vals_it->second.count(it->get()) > 0) { - vals_.erase(it->get()); - it = vals_up_.erase(it); - } else { - ++it; + const auto& owned = vals_it->second; + std::erase_if(vals_up_, [&](const std::unique_ptr& v) { + if (owned.count(v.get()) > 0) { + vals_.erase(v.get()); + return true; } - } + return false; + }); per_fusion_vals_.erase(vals_it); } auto exprs_it = per_fusion_exprs_.find(fusion); if (exprs_it != per_fusion_exprs_.end()) { - for (auto it = exprs_up_.begin(); it != exprs_up_.end();) { - if (exprs_it->second.count(it->get()) > 0) { - exprs_.erase(it->get()); - it = exprs_up_.erase(it); - } else { - ++it; + const auto& owned = exprs_it->second; + std::erase_if(exprs_up_, [&](const std::unique_ptr& e) { + if (owned.count(e.get()) > 0) { + exprs_.erase(e.get()); + return true; } - } + return false; + }); per_fusion_exprs_.erase(exprs_it); } } diff --git a/csrc/statement_guard.cpp b/csrc/statement_guard.cpp index 4575bb59076..15a3b4159c3 100644 --- a/csrc/statement_guard.cpp +++ b/csrc/statement_guard.cpp @@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion) return fusion; }()), prev_num_exprs_(fusion_->numExprs()), - prev_num_vals_(fusion_->numVals()) {} + prev_num_vals_(fusion_->numValsExcludingShortcuts()) {} StatementGuard::~StatementGuard() { fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);