diff --git a/compiler/rustc_expand/src/proc_macro_server.rs b/compiler/rustc_expand/src/proc_macro_server.rs index ff135f60a822a..8404e16fd1eb4 100644 --- a/compiler/rustc_expand/src/proc_macro_server.rs +++ b/compiler/rustc_expand/src/proc_macro_server.rs @@ -1,7 +1,7 @@ use crate::base::{ExtCtxt, ResolverExpand}; use rustc_ast as ast; -use rustc_ast::token::{self, Nonterminal, NtIdent, TokenKind}; +use rustc_ast::token::{self, Nonterminal, NtIdent}; use rustc_ast::tokenstream::{self, CanSynthesizeMissingTokens}; use rustc_ast::tokenstream::{DelimSpan, Spacing::*, TokenStream, TreeAndSpacing}; use rustc_ast_pretty::pprust; @@ -537,30 +537,49 @@ impl server::Ident for Rustc<'_> { impl server::Literal for Rustc<'_> { fn from_str(&mut self, s: &str) -> Result { - let override_span = None; - let stream = parse_stream_from_source_str( - FileName::proc_macro_source_code(s), - s.to_owned(), - self.sess, - override_span, - ); - if stream.len() != 1 { - return Err(()); - } - let tree = stream.into_trees().next().unwrap(); - let token = match tree { - tokenstream::TokenTree::Token(token) => token, - tokenstream::TokenTree::Delimited { .. } => return Err(()), + let name = FileName::proc_macro_source_code(s); + let mut parser = rustc_parse::new_parser_from_source_str(self.sess, name, s.to_owned()); + + let first_span = parser.token.span.data(); + let minus_present = parser.eat(&token::BinOp(token::Minus)); + + let lit_span = parser.token.span.data(); + let mut lit = match parser.token.kind { + token::Literal(lit) => lit, + _ => return Err(()), }; - let span_data = token.span.data(); - if (span_data.hi.0 - span_data.lo.0) as usize != s.len() { - // There is a comment or whitespace adjacent to the literal. + + // Check no comment or whitespace surrounding the (possibly negative) + // literal, or more tokens after it. + if (lit_span.hi.0 - first_span.lo.0) as usize != s.len() { return Err(()); } - let lit = match token.kind { - TokenKind::Literal(lit) => lit, - _ => return Err(()), - }; + + if minus_present { + // If minus is present, check no comment or whitespace in between it + // and the literal token. + if first_span.hi.0 != lit_span.lo.0 { + return Err(()); + } + + // Check literal is a kind we allow to be negated in a proc macro token. + match lit.kind { + token::LitKind::Bool + | token::LitKind::Byte + | token::LitKind::Char + | token::LitKind::Str + | token::LitKind::StrRaw(_) + | token::LitKind::ByteStr + | token::LitKind::ByteStrRaw(_) + | token::LitKind::Err => return Err(()), + token::LitKind::Integer | token::LitKind::Float => {} + } + + // Synthesize a new symbol that includes the minus sign. + let symbol = Symbol::intern(&s[..1 + lit.symbol.len()]); + lit = token::Lit::new(lit.kind, symbol, lit.suffix); + } + Ok(Literal { lit, span: self.call_site }) } fn debug_kind(&mut self, literal: &Self::Literal) -> String { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index a8f969782b22d..6c2f69c4b671f 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1588,6 +1588,10 @@ impl Symbol { self.0.as_u32() } + pub fn len(self) -> usize { + with_interner(|interner| interner.get(self).len()) + } + pub fn is_empty(self) -> bool { self == kw::Empty } diff --git a/src/test/ui/proc-macro/auxiliary/api/parse.rs b/src/test/ui/proc-macro/auxiliary/api/parse.rs index 4105236b7f2d3..a304c5e81a4bb 100644 --- a/src/test/ui/proc-macro/auxiliary/api/parse.rs +++ b/src/test/ui/proc-macro/auxiliary/api/parse.rs @@ -1,9 +1,15 @@ use proc_macro::Literal; pub fn test() { + test_display_literal(); test_parse_literal(); } +fn test_display_literal() { + assert_eq!(Literal::isize_unsuffixed(-10).to_string(), "- 10"); + assert_eq!(Literal::isize_suffixed(-10).to_string(), "- 10isize"); +} + fn test_parse_literal() { assert_eq!("1".parse::().unwrap().to_string(), "1"); assert_eq!("1.0".parse::().unwrap().to_string(), "1.0"); @@ -12,7 +18,10 @@ fn test_parse_literal() { assert_eq!("b\"\"".parse::().unwrap().to_string(), "b\"\""); assert_eq!("r##\"\"##".parse::().unwrap().to_string(), "r##\"\"##"); assert_eq!("10ulong".parse::().unwrap().to_string(), "10ulong"); + assert_eq!("-10ulong".parse::().unwrap().to_string(), "- 10ulong"); + assert!("true".parse::().is_err()); + assert!(".8".parse::().is_err()); assert!("0 1".parse::().is_err()); assert!("'a".parse::().is_err()); assert!(" 0".parse::().is_err()); @@ -20,4 +29,6 @@ fn test_parse_literal() { assert!("/* comment */0".parse::().is_err()); assert!("0/* comment */".parse::().is_err()); assert!("0// comment".parse::().is_err()); + assert!("- 10".parse::().is_err()); + assert!("-'x'".parse::().is_err()); }