Skip to content

Commit

Permalink
[WGSL] Implement shift expression parsing
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=251669
rdar://problem/104997375

Reviewed by Tadeu Zagallo.

Handle the parsing of expressions involving << & >> following the WGSL spec.

Canonical link: https://commits.webkit.org/260166@main
  • Loading branch information
djg committed Feb 12, 2023
1 parent 153929d commit 9bc491c
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 95 deletions.
14 changes: 10 additions & 4 deletions Source/WebGPU/WGSL/Lexer.cpp
Expand Up @@ -76,10 +76,18 @@ Token Lexer<T>::lex()
return makeToken(TokenType::Equal);
case '>':
shift();
return makeToken(TokenType::GT);
if (m_current == '>') {
shift();
return makeToken(TokenType::GtGt);
}
return makeToken(TokenType::Gt);
case '<':
shift();
return makeToken(TokenType::LT);
if (m_current == '<') {
shift();
return makeToken(TokenType::LtLt);
}
return makeToken(TokenType::Lt);
case '@':
shift();
return makeToken(TokenType::Attribute);
Expand Down Expand Up @@ -116,15 +124,13 @@ Token Lexer<T>::lex()
return makeToken(TokenType::MinusMinus);
}
return makeToken(TokenType::Minus);
break;
case '+':
shift();
if (m_current == '+') {
shift();
return makeToken(TokenType::PlusPlus);
}
return makeToken(TokenType::Plus);
break;
case '0': {
shift();
double literalValue = 0;
Expand Down
143 changes: 90 additions & 53 deletions Source/WebGPU/WGSL/Parser.cpp
Expand Up @@ -111,6 +111,29 @@ namespace WGSL {
} \
} while (false)

static bool canContinueMultiplicativeExpression(const Token& token)
{
switch (token.m_type) {
case TokenType::Modulo:
case TokenType::Slash:
case TokenType::Star:
return true;
default:
return false;
}
}

static bool canContinueAdditiveExpression(const Token& token)
{
switch (token.m_type) {
case TokenType::Minus:
case TokenType::Plus:
return true;
default:
return canContinueMultiplicativeExpression(token);
}
}

template<typename Lexer>
std::optional<Error> parse(ShaderModule& shaderModule)
{
Expand Down Expand Up @@ -361,9 +384,9 @@ template<typename Lexer>
Expected<AST::TypeName::Ref, Error> Parser<Lexer>::parseTypeNameAfterIdentifier(AST::Identifier&& name, SourcePosition _startOfElementPosition)
{
if (auto kind = AST::ParameterizedTypeName::stringViewToKind(name.id())) {
CONSUME_TYPE(LT);
CONSUME_TYPE(Lt);
PARSE(elementType, TypeName);
CONSUME_TYPE(GT);
CONSUME_TYPE(Gt);
RETURN_NODE_REF(ParameterizedTypeName, *kind, WTFMove(elementType));
}
RETURN_NODE_REF(NamedTypeName, WTFMove(name));
Expand All @@ -379,7 +402,7 @@ Expected<AST::TypeName::Ref, Error> Parser<Lexer>::parseArrayType()
AST::TypeName::Ptr maybeElementType;
AST::Expression::Ptr maybeElementCount;

if (current().m_type == TokenType::LT) {
if (current().m_type == TokenType::Lt) {
// We differ from the WGSL grammar here by allowing the type to be optional,
// which allows us to use `parseArrayType` in `parseCallExpression`.
consume();
Expand All @@ -395,10 +418,10 @@ Expected<AST::TypeName::Ref, Error> Parser<Lexer>::parseArrayType()
// The WGSL grammar doesn't specify expression operator precedence so
// until then just parse AdditiveExpression.
PARSE(elementCountLHS, UnaryExpression);
PARSE(elementCount, AdditiveExpression, WTFMove(elementCountLHS));
PARSE(elementCount, AdditiveExpressionPostUnary, WTFMove(elementCountLHS));
maybeElementCount = elementCount.moveToUniquePtr();
}
CONSUME_TYPE(GT);
CONSUME_TYPE(Gt);
}

RETURN_NODE_REF(ArrayTypeName, WTFMove(maybeElementType), WTFMove(maybeElementCount));
Expand All @@ -418,7 +441,7 @@ Expected<AST::Variable::Ref, Error> Parser<Lexer>::parseVariableWithAttributes(A
CONSUME_TYPE(KeywordVar);

std::unique_ptr<AST::VariableQualifier> maybeQualifier = nullptr;
if (current().m_type == TokenType::LT) {
if (current().m_type == TokenType::Lt) {
PARSE(variableQualifier, VariableQualifier);
maybeQualifier = WTF::makeUnique<AST::VariableQualifier>(WTFMove(variableQualifier));
}
Expand Down Expand Up @@ -447,7 +470,7 @@ Expected<AST::VariableQualifier, Error> Parser<Lexer>::parseVariableQualifier()
{
START_PARSE();

CONSUME_TYPE(LT);
CONSUME_TYPE(Lt);
PARSE(storageClass, StorageClass);

// FIXME: verify that Read is the correct default in all cases.
Expand All @@ -458,7 +481,7 @@ Expected<AST::VariableQualifier, Error> Parser<Lexer>::parseVariableQualifier()
accessMode = actualAccessMode;
}

CONSUME_TYPE(GT);
CONSUME_TYPE(Gt);

RETURN_NODE(VariableQualifier, storageClass, accessMode);
}
Expand Down Expand Up @@ -631,79 +654,93 @@ Expected<AST::ReturnStatement, Error> Parser<Lexer>::parseReturnStatement()
}

template<typename Lexer>
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseRelationalExpression(AST::Expression::Ref&& lhs)
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseRelationalExpressionPostUnary(AST::Expression::Ref&& lhs)
{
// FIXME: fill in
return parseShiftExpression(WTFMove(lhs));
return parseShiftExpressionPostUnary(WTFMove(lhs));
}

template<typename Lexer>
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseShiftExpression(AST::Expression::Ref&& lhs)
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseShiftExpressionPostUnary(AST::Expression::Ref&& lhs)
{
// FIXME: fill in
return parseAdditiveExpression(WTFMove(lhs));
}
if (canContinueAdditiveExpression(current()))
return parseAdditiveExpressionPostUnary(WTFMove(lhs));

template<typename Lexer>
Expected<AST::BinaryOperation, Error> Parser<Lexer>::parseAdditiveOperator()
{
START_PARSE();

switch (current().m_type) {
case TokenType::Minus:
case TokenType::GtGt: {
consume();
return AST::BinaryOperation::Subtract;
case TokenType::Plus:
PARSE(rhs, UnaryExpression);
RETURN_NODE_UNIQUE_REF(BinaryExpression, WTFMove(lhs), WTFMove(rhs), AST::BinaryOperation::RightShift);
}

case TokenType::LtLt: {
consume();
return AST::BinaryOperation::Add;
PARSE(rhs, UnaryExpression);
RETURN_NODE_UNIQUE_REF(BinaryExpression, WTFMove(lhs), WTFMove(rhs), AST::BinaryOperation::LeftShift);
}

default:
FAIL("Expected one of + or -"_s);
return WTFMove(lhs);
}
}

template<typename Lexer>
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseAdditiveExpression(AST::Expression::Ref&& lhs)
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseAdditiveExpressionPostUnary(AST::Expression::Ref&& lhs)
{
START_PARSE();
PARSE_MOVE(lhs, MultiplicativeExpression, WTFMove(lhs));
PARSE_MOVE(lhs, MultiplicativeExpressionPostUnary, WTFMove(lhs));

while (canContinueAdditiveExpression(current())) {
auto op = AST::BinaryOperation::Add;
switch (current().m_type) {
case TokenType::Minus:
op = AST::BinaryOperation::Subtract;
break;

case TokenType::Plus:
op = AST::BinaryOperation::Add;
break;

default:
// parseMultiplicativeExpression handles multiplicative operators so
// token should be PLUS or MINUS.
RELEASE_ASSERT_NOT_REACHED_WITH_MESSAGE("Expected + or -");
}

while (current().m_type == TokenType::Plus || current().m_type == TokenType::Minus) {
PARSE(op, AdditiveOperator);
consume();
PARSE(unary, UnaryExpression);
PARSE(rhs, MultiplicativeExpression, WTFMove(unary));
PARSE(rhs, MultiplicativeExpressionPostUnary, WTFMove(unary));
lhs = MAKE_NODE_UNIQUE_REF(BinaryExpression, WTFMove(lhs), WTFMove(rhs), op);
}

return WTFMove(lhs);
}

template<typename Lexer>
Expected<AST::BinaryOperation, Error> Parser<Lexer>::parseMultiplicativeOperator()
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseMultiplicativeExpressionPostUnary(AST::Expression::Ref&& lhs)
{
START_PARSE();
switch (current().m_type) {
case TokenType::Modulo:
consume();
return AST::BinaryOperation::Modulo;
case TokenType::Slash:
consume();
return AST::BinaryOperation::Divide;
case TokenType::Star:
consume();
return AST::BinaryOperation::Multiply;
default:
FAIL("Expected one of %, / or *"_s);
}
}
while (canContinueMultiplicativeExpression(current())) {
auto op = AST::BinaryOperation::Multiply;
switch (current().m_type) {
case TokenType::Modulo:
op = AST::BinaryOperation::Modulo;
break;

template<typename Lexer>
Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseMultiplicativeExpression(AST::Expression::Ref&& lhs)
{
START_PARSE();
while (current().m_type == TokenType::Modulo
|| current().m_type == TokenType::Slash
|| current().m_type == TokenType::Star) {
PARSE(op, MultiplicativeOperator)
case TokenType::Slash:
op = AST::BinaryOperation::Divide;
break;

case TokenType::Star:
op = AST::BinaryOperation::Multiply;
break;

default:
RELEASE_ASSERT_NOT_REACHED();
}

consume();
PARSE(rhs, UnaryExpression);
lhs = MAKE_NODE_UNIQUE_REF(BinaryExpression, WTFMove(lhs), WTFMove(rhs), op);
}
Expand Down Expand Up @@ -790,7 +827,7 @@ Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parsePrimaryExpressio
}
case TokenType::Identifier: {
PARSE(ident, Identifier);
if (current().m_type == TokenType::LT || current().m_type == TokenType::ParenLeft) {
if (current().m_type == TokenType::Lt || current().m_type == TokenType::ParenLeft) {
PARSE(type, TypeNameAfterIdentifier, WTFMove(ident), _startOfElementPosition);
PARSE(arguments, ArgumentExpressionList);
RETURN_NODE_UNIQUE_REF(CallExpression, WTFMove(type), WTFMove(arguments));
Expand Down Expand Up @@ -844,7 +881,7 @@ Expected<UniqueRef<AST::Expression>, Error> Parser<Lexer>::parseExpression()
{
// FIXME: Fill in
PARSE(lhs, UnaryExpression);
return parseRelationalExpression(WTFMove(lhs));
return parseRelationalExpressionPostUnary(WTFMove(lhs));
}

template<typename Lexer>
Expand Down
10 changes: 4 additions & 6 deletions Source/WebGPU/WGSL/ParserPrivate.h
Expand Up @@ -73,12 +73,10 @@ class Parser {
Expected<AST::CompoundStatement, Error> parseCompoundStatement();
Expected<AST::ReturnStatement, Error> parseReturnStatement();
Expected<AST::Expression::Ref, Error> parseShortCircuitOrExpression();
Expected<AST::Expression::Ref, Error> parseRelationalExpression(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseShiftExpression(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseAdditiveExpression(AST::Expression::Ref&& lhs);
Expected<AST::BinaryOperation, Error> parseAdditiveOperator();
Expected<AST::Expression::Ref, Error> parseMultiplicativeExpression(AST::Expression::Ref&& lhs);
Expected<AST::BinaryOperation, Error> parseMultiplicativeOperator();
Expected<AST::Expression::Ref, Error> parseRelationalExpressionPostUnary(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseShiftExpressionPostUnary(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseAdditiveExpressionPostUnary(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseMultiplicativeExpressionPostUnary(AST::Expression::Ref&& lhs);
Expected<AST::Expression::Ref, Error> parseUnaryExpression();
Expected<AST::Expression::Ref, Error> parseSingularExpression();
Expected<AST::Expression::Ref, Error> parsePostfixExpression(AST::Expression::Ref&& base, SourcePosition startPosition);
Expand Down
8 changes: 6 additions & 2 deletions Source/WebGPU/WGSL/Token.cpp
Expand Up @@ -105,10 +105,14 @@ String toString(TokenType type)
return ","_s;
case TokenType::Equal:
return "="_s;
case TokenType::GT:
case TokenType::Gt:
return ">"_s;
case TokenType::LT:
case TokenType::GtGt:
return ">>"_s;
case TokenType::Lt:
return "<"_s;
case TokenType::LtLt:
return "<<"_s;
case TokenType::Minus:
return "-"_s;
case TokenType::MinusMinus:
Expand Down
6 changes: 4 additions & 2 deletions Source/WebGPU/WGSL/Token.h
Expand Up @@ -79,8 +79,10 @@ enum class TokenType: uint32_t {
Colon,
Comma,
Equal,
GT,
LT,
Gt,
GtGt,
Lt,
LtLt,
Minus,
MinusMinus,
Modulo,
Expand Down

0 comments on commit 9bc491c

Please sign in to comment.