Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat procedure call #30

Merged
merged 11 commits into from
Apr 28, 2024
117 changes: 82 additions & 35 deletions codegen/jasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,61 +6,104 @@ import (
)

type JASM struct {
code *Code
procedureDefinitionName string
labelID int
pst *StackType
endIfLabel string
elseLabel string
repeatLabel string
whileTestLabel string
nextStatementLabel string
forTestLabel string
forVariable string
forStep string
st *SymbolTable
code *Code
pst *StackType
st *SymbolTable

className string
procedureDeclarationName string
procedureStatementName string
labelID int
endIfLabel string
elseLabel string
repeatLabel string
whileTestLabel string
nextStatementLabel string
forTestLabel string
forVariable string
forStep string
}

func NewJASM() *JASM {
return &JASM{
code: NewCode(),
pst: NewStackType(),
st: NewSymbolTable(),
code: NewCode(),
pst: NewStackType(),
st: NewSymbolTable(),
procedureDeclarationName: "main",
}
}

func (j *JASM) StartMainClass(name string) {
j.className = name
j.addLine("// Code generated by POJ 0.1")
j.addLine(fmt.Sprintf("public class %s {", name))
j.incTab()
}

func (j *JASM) FinishMainClass() {
j.finishMain()
j.decTab()
j.addLine("}")
}

func (j *JASM) StartProcedureDeclaration(name string, paramTypes []string) {
j.procedureDeclarationName = name
j.addLine(fmt.Sprintf("static %s(%s)V {", name, j.genSignature(paramTypes)))
j.incTab()
}

func (j *JASM) genSignature(paramTypes []string) string {
javaParams := make([]string, len(paramTypes))
for i, p := range paramTypes {
if p == "string" {
javaParams[i] = "java/lang/String"
} else if p == "integer" {
javaParams[i] = "I"
} else {
javaParams[i] = "UndefinedType"
}
}

return strings.Join(javaParams, ", ")
}

func (j *JASM) FinishProcedureDeclaration() {
j.addLine("return")
j.decTab()
j.addLine("}")
j.procedureDeclarationName = "main"
}

func (j *JASM) StartProcedureStatement(name string) {
j.procedureDefinitionName = name
j.procedureStatementName = name
}

func (j *JASM) FinishProcedureStatement() {
if j.procedureDefinitionName == "writeln" {
func (j *JASM) FinishProcedureStatement() error {
if j.procedureStatementName == "writeln" {
j.addStaticPrintStream()
j.addInvokeVirtual("java/io/PrintStream.println()V")
} else if j.procedureStatementName != "write" {
proc, ok := j.st.Get(j.procedureStatementName)
if !ok {
return fmt.Errorf("procedure %s not found", j.procedureStatementName)
}

j.addInvokeStatic(j.procedureStatementName, j.genSignature(proc.ParamTypes))
}

j.procedureDefinitionName = ""
j.procedureStatementName = ""

return nil
}

func (j *JASM) StartParameter() {
if j.procedureDefinitionName == "write" || j.procedureDefinitionName == "writeln" {
if j.procedureStatementName == "write" || j.procedureStatementName == "writeln" {
j.addStaticPrintStream()
}
}

func (j *JASM) FinishParameter() error {
if j.procedureDefinitionName == "write" || j.procedureDefinitionName == "writeln" {
if j.procedureStatementName == "write" || j.procedureStatementName == "writeln" {
if err := j.addInvokeVirtualPrintWithType(); err != nil {
return err
}
Expand All @@ -69,22 +112,14 @@ func (j *JASM) FinishParameter() error {
return nil
}

func (j *JASM) StartBlock() {
if j.procedureDefinitionName == "" {
func (j *JASM) StartMainBlock() {
if j.procedureDeclarationName == "main" {
// Main block.
j.startMain()
j.procedureDeclarationName = ""
}
}

func (j *JASM) FinishBlock() {
if j.procedureDefinitionName == "" {
// Main block.
j.finishMain()
}

j.procedureDefinitionName = ""
}

func (j *JASM) NewConstantString(constant string) {
j.addLdcStringOpcode(constant)
}
Expand Down Expand Up @@ -331,7 +366,7 @@ func (j *JASM) NewVariable(name, pst string) error {
}

func (j *JASM) FinishAssignmentStatement(varName string) error {
ok, symbol := j.st.Get(varName)
symbol, ok := j.st.Get(varName)
if !ok {
return fmt.Errorf("variable %s not found", varName)
}
Expand All @@ -351,7 +386,7 @@ func (j *JASM) FinishAssignmentStatement(varName string) error {
}

func (j *JASM) LoadVarContent(varName string) error {
ok, symbol := j.st.Get(varName)
symbol, ok := j.st.Get(varName)
if !ok {
return fmt.Errorf("variable %s not found", varName)
}
Expand All @@ -370,6 +405,14 @@ func (j *JASM) LoadVarContent(varName string) error {
return nil
}

func (j *JASM) NewProcedure(name string, paramTypes []string) error {
if err := j.st.AddProcedure(name, paramTypes); err != nil {
return err
}

return nil
}

func (j *JASM) Code() string {
return j.code.Code()
}
Expand Down Expand Up @@ -445,6 +488,10 @@ func (j *JASM) addInvokeVirtual(method string) {
j.addOpcode("invokevirtual", method)
}

func (j *JASM) addInvokeStatic(method, signature string) {
j.addOpcode(fmt.Sprintf("invokestatic %s.%s(%s)V", j.className, method, signature))
}

func (j *JASM) addInvokeVirtualPrintWithType() error {
pt := j.pst.Pop()

Expand Down
10 changes: 3 additions & 7 deletions codegen/jasm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,17 @@ func TestJASM_CompleteBasicTest(t *testing.T) {
expected := "// Code generated by POJ 0.1\npublic class XPTO {\n"
assert.Equal(expected, j.Code())

j.StartBlock()
j.StartMainBlock()
expected += "\tpublic static main([java/lang/String)V {\n"
assert.Equal(expected, j.Code())

j.NewConstantString("\"param1\"")
expected += "\t\tldc \"param1\"\n"
assert.Equal(expected, j.Code())

j.FinishBlock()
j.FinishMainClass()
expected += "\t\treturn\n"
expected += "\t}\n"
assert.Equal(expected, j.Code())

j.FinishMainClass()
expected += "}\n"
assert.Equal(expected, j.Code())
}
Expand All @@ -40,13 +37,12 @@ func TestJASM_HelloWorld(t *testing.T) {

j := codegen.NewJASM()
j.StartMainClass("HelloWorld")
j.StartBlock()
j.StartMainBlock()
j.StartProcedureStatement("writeln")
j.StartParameter()
j.NewConstantString("\"Hello, World\"")
j.FinishParameter()
j.FinishProcedureStatement()
j.FinishBlock()
j.FinishMainClass()
expected := `// Code generated by POJ 0.1
public class HelloWorld {
Expand Down
30 changes: 25 additions & 5 deletions codegen/symbol_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ type SymbolType uint8
const (
UndefinedSymbolType SymbolType = iota
Variable
Procedure
)

type Symbol struct {
SymbolType SymbolType
PascalType PascalType
Index int
ParamTypes []string
}

type SymbolTable struct {
Expand All @@ -36,26 +38,44 @@ func (st *SymbolTable) AddVariable(name string, ptype PascalType) error {
return fmt.Errorf("variable %s already declared", name)
}

st.count++
st.symbols[name] = Symbol{
SymbolType: Variable,
PascalType: ptype,
Index: st.count,
}

st.count++

return nil
}

func (st *SymbolTable) AddProcedure(name string, paramTypes []string) error {
name = strings.ToUpper(name)
if _, ok := st.symbols[name]; ok {
return fmt.Errorf("procedure %s already declared", name)
}

st.symbols[name] = Symbol{ // REFACTOR: Symbol is only to variables?
SymbolType: Procedure,
ParamTypes: paramTypes[:],
// PascalType: ptype,
// Index: st.count,
}

return nil
}

func (st *SymbolTable) Get(name string) (bool, Symbol) {
func (st *SymbolTable) Get(name string) (Symbol, bool) {
name = strings.ToUpper(name)
symbol, ok := st.symbols[name]
if !ok {
return false, Symbol{
return Symbol{
SymbolType: UndefinedSymbolType,
PascalType: Undefined,
Index: -1,
}
ParamTypes: nil,
}, false
}

return true, symbol
return symbol, true
}
53 changes: 48 additions & 5 deletions codegen/symbol_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func TestSymbolTable_AddVariable(t *testing.T) {
st := codegen.NewSymbolTable()
ok, symbol := st.Get("xpto")
symbol, ok := st.Get("xpto")
assert.Equal(t, false, ok)
assert.Equal(t, codegen.UndefinedSymbolType, symbol.SymbolType)
assert.Equal(t, codegen.Undefined, symbol.PascalType)
Expand All @@ -18,21 +18,64 @@ func TestSymbolTable_AddVariable(t *testing.T) {
err := st.AddVariable("myvar", codegen.Integer)
assert.Nil(t, err)

ok, symbol = st.Get("myvar")
symbol, ok = st.Get("myvar")
assert.Equal(t, true, ok)
assert.Equal(t, codegen.Variable, symbol.SymbolType)
assert.Equal(t, codegen.Integer, symbol.PascalType)
assert.Equal(t, 1, symbol.Index)
assert.Equal(t, 0, symbol.Index)

ok, symbol = st.Get("MyVar")
symbol, ok = st.Get("MyVar")
assert.Equal(t, true, ok)
assert.Equal(t, codegen.Variable, symbol.SymbolType)
assert.Equal(t, codegen.Integer, symbol.PascalType)
assert.Equal(t, 1, symbol.Index)
assert.Equal(t, 0, symbol.Index)

err = st.AddVariable("myvar", codegen.Integer)
assert.NotNil(t, err)

err = st.AddVariable("MyVar", codegen.Integer)
assert.NotNil(t, err)
}

func TestSymbolTable_AddProcedure(t *testing.T) {
st := codegen.NewSymbolTable()
symbol, ok := st.Get("xpto")
assert.Equal(t, false, ok)
assert.Equal(t, codegen.UndefinedSymbolType, symbol.SymbolType)
assert.Equal(t, codegen.Undefined, symbol.PascalType)
assert.Equal(t, -1, symbol.Index)

err := st.AddProcedure("myproc", []string{})
assert.Nil(t, err)

symbol, ok = st.Get("myproc")
assert.Equal(t, true, ok)
assert.Equal(t, codegen.Procedure, symbol.SymbolType)
assert.Equal(t, codegen.Undefined, symbol.PascalType)
assert.Equal(t, 0, symbol.Index)

symbol, ok = st.Get("MyProc")
assert.Equal(t, true, ok)
assert.Equal(t, codegen.Procedure, symbol.SymbolType)
assert.Equal(t, codegen.Undefined, symbol.PascalType)
assert.Equal(t, 0, symbol.Index)

err = st.AddProcedure("myproc", []string{})
assert.NotNil(t, err)

err = st.AddProcedure("MyProc", []string{})
assert.NotNil(t, err)

err = st.AddVariable("MyProc", codegen.Integer)
assert.NotNil(t, err)

err = st.AddProcedure("myproc2", []string{"integer", "string"})
assert.Nil(t, err)

symbol, ok = st.Get("myproc2")
assert.Equal(t, true, ok)
assert.Equal(t, codegen.Procedure, symbol.SymbolType)
assert.Equal(t, []string{"integer", "string"}, symbol.ParamTypes)
assert.Equal(t, codegen.Undefined, symbol.PascalType)
assert.Equal(t, 0, symbol.Index)
}
Loading