Skip to content

Commit

Permalink
Improve performance of DB tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arp242 committed Jul 2, 2020
1 parent d1b78e3 commit 3ca1f43
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 108 deletions.
245 changes: 137 additions & 108 deletions gctest/gctest.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ import (
"zgo.at/goatcounter/cron"
"zgo.at/zdb"
"zgo.at/zhttp"
)

var (
schema string
migrations [][]string
"zgo.at/zlog"
"zgo.at/zstd/zstring"
)

type tester interface {
Expand All @@ -35,51 +32,130 @@ type tester interface {
Logf(string, ...interface{})
}

var (
dbname = "goatcounter_test_" + zhttp.Secret()[:25]
db *sqlx.DB
tables []string
)

// DB starts a new database test.
func DB(t tester) (context.Context, func()) {
t.Helper()

clean := func() {}
defer func() {
r := recover()
if r != nil {
clean()
panic(r)
l := zlog.Module("gctest")
//l = l.SetDebug("gctest")

if db == nil {
var err error
if cfg.PgSQL {
{
out, err := exec.Command("createdb", dbname).CombinedOutput()
if err != nil {
t.Fatalf("%s → %s", err, out)
}
l = l.Since("createdb")
}

db, err = sqlx.Connect("postgres", "dbname="+dbname+" sslmode=disable password=x")
} else {
db, err = sqlx.Connect("sqlite3", "file::memory:?cache=shared")
}
if err != nil {
t.Fatalf("connect to DB: %s", err)
}
l = l.Since("connect")

setupDB(t)

l = l.Since("setupDB")

if cfg.PgSQL {
err = db.Select(&tables, `select c.relname as name
from pg_catalog.pg_class c
left join pg_catalog.pg_namespace n on n.oid = c.relnamespace
where
c.relkind = 'r' and
n.nspname <> 'pg_catalog' and
n.nspname <> 'information_schema' and
n.nspname !~ '^pg_toast' and
pg_catalog.pg_table_is_visible(c.oid);`)
} else {
err = db.Select(&tables, `select name from sqlite_master where type='table'`)
}
if err != nil {
t.Fatal(err)
}
}()

dbname := "goatcounter_test_" + zhttp.Secret()[:25]
exclude := []string{"iso_3166_1", "version"}
tables = zstring.Filter(tables, func(t string) bool { return !zstring.Contains(exclude, t) })

if cfg.PgSQL {
// TODO: avoid using shell commands if possible; it's quite slow!
out, err := exec.Command("createdb", dbname).CombinedOutput()
l = l.Since("list tables")
} else {
q := `delete from %s`
if cfg.PgSQL {
// TODO: takes about 450ms, which is rather long. See if we can
// speed this up.
q = `truncate %s restart identity cascade`
}
for _, t := range tables {
db.MustExec(fmt.Sprintf(q, t))
}
if !cfg.PgSQL {
db.MustExec(`delete from sqlite_sequence`)
}

l = l.Since("truncate")
}
ctx := zdb.With(context.Background(), db)

{
_, err := db.ExecContext(ctx, `insert into sites
(code, plan, settings, created_at) values ('test', 'personal', '{}', $1)`,
goatcounter.Now().Format(zdb.Date))
if err != nil {
panic(fmt.Sprintf("%s → %s", err, out))
t.Fatalf("create site: %s", err)
}
l = l.Since("create site")

clean = func() {
go func() {
out, err := exec.Command("dropdb", dbname).CombinedOutput()
if err != nil {
t.Logf("dropdb: %s → %s", err, out)
}
}()
var site goatcounter.Site
err = site.ByID(ctx, 1)
if err != nil {
t.Fatalf("get site: %s", err)
}
ctx = goatcounter.WithSite(ctx, &site)
l = l.Since("get site")
}

var (
db *sqlx.DB
err error
)
if cfg.PgSQL {
db, err = sqlx.Connect("postgres", "dbname="+dbname+" sslmode=disable password=x")
} else {
db, err = sqlx.Connect("sqlite3", "file::memory:?cache=shared")
{
_, err := db.ExecContext(ctx, `insert into users
(site, email, password, created_at) values (1, 'test@example.com', 'xx', $1)`,
goatcounter.Now().Format(zdb.Date))
if err != nil {
t.Fatalf("create site: %s", err)
}
l = l.Since("create user")

var user goatcounter.User
err = user.BySite(ctx, 1)
if err != nil {
t.Fatalf("get user: %s", err)
}
ctx = goatcounter.WithUser(ctx, &user)
l = l.Since("get user")
}
if err != nil {
t.Fatalf("connect to DB: %s", err)

return ctx, func() {
goatcounter.Salts.Clear()

// TODO: run after all tests are done.
// out, err := exec.Command("dropdb", dbname).CombinedOutput()
// if err != nil {
// t.Logf("dropdb: %s → %s", err, out)
// }
}
}

func setupDB(t tester) {
top, err := os.Getwd()
if err != nil {
t.Fatalf(fmt.Sprintf("cannot get cwd: %s", err))
Expand All @@ -95,52 +171,45 @@ func DB(t tester) (context.Context, func()) {
break
}
}

schemapath := top + "/db/schema.sql"
migratepath := top + "/db/migrate/sqlite"
if cfg.PgSQL {
schemapath = top + "/db/schema.pgsql"
migratepath = top + "/db/migrate/pgsql"
}

if schema == "" {
s, err := ioutil.ReadFile(schemapath)
if err != nil {
t.Fatalf("read schema: %v", err)
}
schema = string(s)
_, err = db.ExecContext(context.Background(), schema)
if err != nil {
t.Fatalf("run schema %q: %v", schemapath, err)
}

migs, err := ioutil.ReadDir(migratepath)
if err != nil {
t.Fatalf("read migration directory: %s", err)
}
s, err := ioutil.ReadFile(schemapath)
if err != nil {
t.Fatalf("read schema: %v", err)
}
schema := string(s)
_, err = db.ExecContext(context.Background(), schema)
if err != nil {
t.Fatalf("run schema %q: %v", schemapath, err)
}

for _, m := range migs {
if !strings.HasSuffix(m.Name(), ".sql") {
continue
}
var ran bool
db.Get(&ran, `select 1 from version where name=$1`, m.Name()[:len(m.Name())-4])
if ran {
continue
}
migs, err := ioutil.ReadDir(migratepath)
if err != nil {
t.Fatalf("read migration directory: %s", err)
}

mp := fmt.Sprintf("%s/%s", migratepath, m.Name())
mb, err := ioutil.ReadFile(mp)
if err != nil {
t.Fatalf("read migration: %s", err)
}
migrations = append(migrations, []string{mp, string(mb)})
var migrations [][]string
for _, m := range migs {
if !strings.HasSuffix(m.Name(), ".sql") {
continue
}
} else {
_, err = db.ExecContext(context.Background(), schema)
var ran bool
db.Get(&ran, `select 1 from version where name=$1`, m.Name()[:len(m.Name())-4])
if ran {
continue
}

mp := fmt.Sprintf("%s/%s", migratepath, m.Name())
mb, err := ioutil.ReadFile(mp)
if err != nil {
t.Fatalf("create schema: %s", err)
t.Fatalf("read migration: %s", err)
}
migrations = append(migrations, []string{mp, string(mb)})
}

for _, m := range migrations {
Expand All @@ -149,46 +218,6 @@ func DB(t tester) (context.Context, func()) {
t.Fatalf("run migration %q: %s", m[0], err)
}
}

now := `datetime()`
if cfg.PgSQL {
now = `now()`
}

_, err = db.ExecContext(context.Background(), fmt.Sprintf(
`insert into sites (code, plan, settings, created_at) values
('test', 'personal', '{}', %s);`, now))
if err != nil {
t.Fatalf("create site: %s", err)
}

ctx := zdb.With(context.Background(), db)

var site goatcounter.Site
err = site.ByID(ctx, 1)
if err != nil {
t.Fatalf("get site: %s", err)
}
ctx = goatcounter.WithSite(ctx, &site)

var user goatcounter.User
err = user.BySite(ctx, site.ID)
if err != nil {
user.Site = 1
user.Email = "test@example.com"
user.Password = []byte("coconuts")
err = user.Insert(ctx)
}
if err != nil {
t.Fatalf("get/create user: %s", err)
}
ctx = goatcounter.WithUser(ctx, &user)

return ctx, func() {
db.Close()
goatcounter.Salts.Clear()
clean()
}
}

// StoreHits is a convenient helper to store hits in the DB via Memstore and
Expand Down
35 changes: 35 additions & 0 deletions gctest/gctest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright © 2019 Martin Tournoij <martin@arp242.net>
// This file is part of GoatCounter and published under the terms of the EUPL
// v1.2, which can be found in the LICENSE file or at http://eupl12.zgo.at

package gctest

import (
"fmt"
"testing"

"zgo.at/zlog"
)

func TestDB(t *testing.T) {
zlog.SetDebug("gctest")
fmt.Println("Run 1")
_, clean := DB(t)
clean()

fmt.Println("\nRun 2")
_, clean = DB(t)
clean()

fmt.Println("\nRun 3")
_, clean = DB(t)
clean()
}

// func BenchmarkTestDBDB(b *testing.B) {
// b.ReportAllocs()
// for n := 0; n < b.N; n++ {
// _, clean := DB(b)
// clean()
// }
// }

0 comments on commit 3ca1f43

Please sign in to comment.