/
sqlserver.go
143 lines (121 loc) · 3.08 KB
/
sqlserver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package sqlserver
import (
"bufio"
"bytes"
"context"
"database/sql"
"encoding/gob"
"fmt"
"io"
"net"
"os"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// SqlResult is the reply built for a single query: either a result set
// (Columns and Data) or an error message (Error).
type SqlResult struct {
	// Columns holds the column names of the result set.
	Columns []string
	// Data holds the rows; each row has one text value per column.
	Data [][]string
	// Error holds the error text when the query could not be executed.
	Error string
}
// OpenDatabase returns a database handle for the SQLite3 database file name,
// opened through the "sqlite3" driver. The handle's pool is limited to 20
// open connections, of which at most 5 are kept idle.
func OpenDatabase(name string) (*sql.DB, error) {
	handle, openErr := sql.Open("sqlite3", name)
	if openErr != nil {
		return nil, openErr
	}

	// Pool limits for the handle.
	handle.SetMaxOpenConns(20)
	handle.SetMaxIdleConns(5)

	return handle, nil
}
// ExecQuery runs query on db and collects the complete result set in memory.
//
// The query is bounded by a five second deadline. Every value is rendered as
// text: SQL NULL becomes the literal "NULL", []byte values are converted to a
// string directly, and anything else is formatted with %v.
//
// conn is not used by this function; the parameter is kept so existing
// callers keep compiling.
//
// An error is returned when the query cannot be started, the column names
// cannot be read, a row cannot be scanned, or iteration stops early (for
// example because the deadline expired). In all of these cases the returned
// result is nil.
func ExecQuery(db *sql.DB, conn net.Conn, query string) (*SqlResult, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("reading result columns: %w", err)
	}

	// Scan needs a pointer per column, so every slot of values is paired
	// with a pointer to it; both slices are reused for each row.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	data := [][]string{}
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			fmt.Fprintf(os.Stderr, "[SQLSERVER]: error scanning row: %v\n", err)
			return nil, err
		}

		rowData := make([]string, len(values))
		for i, value := range values {
			switch v := value.(type) {
			case nil:
				rowData[i] = "NULL"
			case []byte:
				// %v would print a byte slice as a list of numbers.
				rowData[i] = string(v)
			default:
				rowData[i] = fmt.Sprintf("%v", v)
			}
		}
		data = append(data, rowData)
	}

	// Next returns false both at the end of the result set and on failure;
	// without this check a truncated result would be reported as success.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterating result rows: %w", err)
	}

	return &SqlResult{Columns: columns, Data: data}, nil
}
// SendQueryResult gob-encodes data and writes the encoded bytes to conn.
//
// The value is encoded into a buffer first, so a failed encode never puts a
// partial message on the connection. Failures are logged to standard output
// rather than returned. A nil data is logged and skipped, because gob cannot
// encode a nil pointer.
func SendQueryResult(conn net.Conn, data *SqlResult) {
	if data == nil {
		fmt.Println("[SQLSERVER]: error sending results: nil result")
		return
	}

	var buffer bytes.Buffer
	if err := gob.NewEncoder(&buffer).Encode(data); err != nil {
		fmt.Printf("[SQLSERVER]: error encoding results: %v\n", err)
		return
	}

	if _, err := io.Copy(conn, &buffer); err != nil {
		fmt.Printf("[SQLSERVER]: error sending results: %v\n", err)
	}
}
func HandleConnection(conn net.Conn, db *sql.DB) {
fmt.Println("Connection from", conn.RemoteAddr())
// We close each client connection when the communication loop ends.
defer conn.Close()
// Create a buffered IO Reader from the net.Conn
reader := bufio.NewReader(conn)
// Indifinite communication loop between server and client.
for {
var buf bytes.Buffer
for {
// We read the query until we encounter a semicolon.
// if no semi-colon exists, it blocks.
chunk, err := reader.ReadString(';')
if err != nil && err != io.EOF {
// give the client another life-line.
break
}
// The client has disconnected. We return and it's connection will be closed.
if err == io.EOF {
return
}
// Happy path: Write this chunck to the buffer
buf.WriteString(chunk)
// We are at end of query
if strings.Contains(chunk, ";") {
break
}
}
// Now we have the query. Hopefully safe!! :)
query := strings.TrimSpace(buf.String())
buf.Reset() // reset the buffer
// Execute the query
result, err := ExecQuery(db, conn, query)
if err != nil {
result = &SqlResult{
Error: err.Error(),
}
}
// Send the query results onto the connection.
SendQueryResult(conn, result)
}
}