diff --git a/src/librustc_mir/transform/deaggregator.rs b/src/librustc_mir/transform/deaggregator.rs index fcdeae6d6c080..e13c8e02137a0 100644 --- a/src/librustc_mir/transform/deaggregator.rs +++ b/src/librustc_mir/transform/deaggregator.rs @@ -36,71 +36,70 @@ impl<'tcx> MirPass<'tcx> for Deaggregator { // In fact, we might not want to trigger in other cases. // Ex: when we could use SROA. See issue #35259 - let mut curr: usize = 0; for bb in mir.basic_blocks_mut() { - let idx = match get_aggregate_statement_index(curr, &bb.statements) { - Some(idx) => idx, - None => continue, - }; - // do the replacement - debug!("removing statement {:?}", idx); - let src_info = bb.statements[idx].source_info; - let suffix_stmts = bb.statements.split_off(idx+1); - let orig_stmt = bb.statements.pop().unwrap(); - let (lhs, rhs) = match orig_stmt.kind { - StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs), - _ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt), - }; - let (agg_kind, operands) = match rhs { - &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands), - _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs), - }; - let (adt_def, variant, substs) = match agg_kind { - &AggregateKind::Adt(adt_def, variant, substs, None) => (adt_def, variant, substs), - _ => span_bug!(src_info.span, "expected struct, not {:?}", rhs), - }; - let n = bb.statements.len(); - bb.statements.reserve(n + operands.len() + suffix_stmts.len()); - for (i, op) in operands.iter().enumerate() { - let ref variant_def = adt_def.variants[variant]; - let ty = variant_def.fields[i].ty(tcx, substs); - let rhs = Rvalue::Use(op.clone()); - - let lhs_cast = if adt_def.variants.len() > 1 { - Lvalue::Projection(Box::new(LvalueProjection { - base: lhs.clone(), - elem: ProjectionElem::Downcast(adt_def, variant), - })) - } else { - lhs.clone() + let mut curr: usize = 0; + while let Some(idx) = get_aggregate_statement_index(curr, &bb.statements) { + // do the replacement + debug!("removing statement {:?}", idx); + let src_info = bb.statements[idx].source_info; + let suffix_stmts = bb.statements.split_off(idx+1); + let orig_stmt = bb.statements.pop().unwrap(); + let (lhs, rhs) = match orig_stmt.kind { + StatementKind::Assign(ref lhs, ref rhs) => (lhs, rhs), + _ => span_bug!(src_info.span, "expected assign, not {:?}", orig_stmt), }; - - let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection { - base: lhs_cast, - elem: ProjectionElem::Field(Field::new(i), ty), - })); - let new_statement = Statement { - source_info: src_info, - kind: StatementKind::Assign(lhs_proj, rhs), + let (agg_kind, operands) = match rhs { + &Rvalue::Aggregate(ref agg_kind, ref operands) => (agg_kind, operands), + _ => span_bug!(src_info.span, "expected aggregate, not {:?}", rhs), }; - debug!("inserting: {:?} @ {:?}", new_statement, idx + i); - bb.statements.push(new_statement); - } + let (adt_def, variant, substs) = match agg_kind { + &AggregateKind::Adt(adt_def, variant, substs, None) + => (adt_def, variant, substs), + _ => span_bug!(src_info.span, "expected struct, not {:?}", rhs), + }; + let n = bb.statements.len(); + bb.statements.reserve(n + operands.len() + suffix_stmts.len()); + for (i, op) in operands.iter().enumerate() { + let ref variant_def = adt_def.variants[variant]; + let ty = variant_def.fields[i].ty(tcx, substs); + let rhs = Rvalue::Use(op.clone()); - // if the aggregate was an enum, we need to set the discriminant - if adt_def.variants.len() > 1 { - let set_discriminant = Statement { - kind: StatementKind::SetDiscriminant { - lvalue: lhs.clone(), - variant_index: variant, - }, - source_info: src_info, + let lhs_cast = if adt_def.variants.len() > 1 { + Lvalue::Projection(Box::new(LvalueProjection { + base: lhs.clone(), + elem: ProjectionElem::Downcast(adt_def, variant), + })) + } else { + lhs.clone() + }; + + let lhs_proj = Lvalue::Projection(Box::new(LvalueProjection { + base: lhs_cast, + elem: ProjectionElem::Field(Field::new(i), ty), + })); + let new_statement = Statement { + source_info: src_info, + kind: StatementKind::Assign(lhs_proj, rhs), + }; + debug!("inserting: {:?} @ {:?}", new_statement, idx + i); + bb.statements.push(new_statement); + } + + // if the aggregate was an enum, we need to set the discriminant + if adt_def.variants.len() > 1 { + let set_discriminant = Statement { + kind: StatementKind::SetDiscriminant { + lvalue: lhs.clone(), + variant_index: variant, + }, + source_info: src_info, + }; + bb.statements.push(set_discriminant); }; - bb.statements.push(set_discriminant); - }; - curr = bb.statements.len(); - bb.statements.extend(suffix_stmts); + curr = bb.statements.len(); + bb.statements.extend(suffix_stmts); + } } } } diff --git a/src/test/mir-opt/deaggregator_test_enum_2.rs b/src/test/mir-opt/deaggregator_test_enum_2.rs new file mode 100644 index 0000000000000..02d496b2901e6 --- /dev/null +++ b/src/test/mir-opt/deaggregator_test_enum_2.rs @@ -0,0 +1,57 @@ +// Copyright 2016 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Test that deaggregate fires in more than one basic block + +enum Foo { + A(i32), + B(i32), +} + +fn test1(x: bool, y: i32) -> Foo { + if x { + Foo::A(y) + } else { + Foo::B(y) + } +} + +fn main() {} + +// END RUST SOURCE +// START rustc.node12.Deaggregator.before.mir +// bb1: { +// _6 = _4; +// _0 = Foo::A(_6,); +// goto -> bb3; +// } +// +// bb2: { +// _7 = _4; +// _0 = Foo::B(_7,); +// goto -> bb3; +// } +// END rustc.node12.Deaggregator.before.mir +// START rustc.node12.Deaggregator.after.mir +// bb1: { +// _6 = _4; +// ((_0 as A).0: i32) = _6; +// discriminant(_0) = 0; +// goto -> bb3; +// } +// +// bb2: { +// _7 = _4; +// ((_0 as B).0: i32) = _7; +// discriminant(_0) = 1; +// goto -> bb3; +// } +// END rustc.node12.Deaggregator.after.mir +// diff --git a/src/test/mir-opt/deaggregator_test_multiple.rs b/src/test/mir-opt/deaggregator_test_multiple.rs new file mode 100644 index 0000000000000..a180a69be55af --- /dev/null +++ b/src/test/mir-opt/deaggregator_test_multiple.rs @@ -0,0 +1,48 @@ +// Copyright 2016 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// Test that deaggregate fires more than once per block + +enum Foo { + A(i32), + B, +} + +fn test(x: i32) -> [Foo; 2] { + [Foo::A(x), Foo::A(x)] +} + +fn main() { } + +// END RUST SOURCE +// START rustc.node10.Deaggregator.before.mir +// bb0: { +// _2 = _1; +// _4 = _2; +// _3 = Foo::A(_4,); +// _6 = _2; +// _5 = Foo::A(_6,); +// _0 = [_3, _5]; +// return; +// } +// END rustc.node10.Deaggregator.before.mir +// START rustc.node10.Deaggregator.after.mir +// bb0: { +// _2 = _1; +// _4 = _2; +// ((_3 as A).0: i32) = _4; +// discriminant(_3) = 0; +// _6 = _2; +// ((_5 as A).0: i32) = _6; +// discriminant(_5) = 0; +// _0 = [_3, _5]; +// return; +// } +// END rustc.node10.Deaggregator.after.mir