Skip to content

Commit eeadb13

Browse files
committed
Add MSE loss and ability for non-categorical model types
Add configurable y processor for single-file and values datasets.
Add helper methods for binary/sparse categorical raw/tokenized y processors.
Add concurrency to the values dataset to speed things up.
Update examples with the new y processor design.
Fix example output of jobs.
Remove the testing monet example.
Update the transfer-learning example with the new weights get/set design.
Add LeakyReLU layer.
Add model methods to get weights in the correct order / specific layer weights.
Remove hardcoded [][]int32 yTrue and [][]float32 yPred casts to allow for other model types.
Make categorical tokenizers have unlimited numWords.
Add early stopping callback for when a metric flatlines.
1 parent 887c350 commit eeadb13

File tree

110 files changed

+1018
-987
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

110 files changed

+1018
-987
lines changed

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ examples-vanilla:
109109
examples-vanilla-gpu:
110110
go generate ./...
111111
docker-compose up -d tf-jupyter-golang-gpu
112-
docker-compose exec tf-jupyter-golang-gpu sh -c "cd /go/src/tfkg/examples/vanilla && go run main.go"
112+
docker-compose exec tf-jupyter-golang-gpu sh -c "cd /go/src/tfkg/examples/vanilla && python generate_vanilla_model.py && go run main.go"
113113

114114
examples-vanilla-raw:
115115
go generate ./...
116-
cd examples/vanilla && go run main.go
116+
cd examples/vanilla && python generate_vanilla_model.py && go run main.go
117117

118118
test-python:
119119
docker-compose up -d tf-jupyter-golang

callback/early_stopping.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package callback
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
// EarlyStoppingOnMetricMode selects the direction of the early-stopping
// comparison: halt when the watched metric falls to MinValue (Min) or
// rises to MaxValue (Max).
type EarlyStoppingOnMetricMode string

var (
	// Min halts training once the watched metric is <= MinValue.
	Min EarlyStoppingOnMetricMode = "min"
	// Max halts training once the watched metric is >= MaxValue.
	Max EarlyStoppingOnMetricMode = "max"
)
14+
15+
// EarlyStoppingOnMetric is a callback that halts training when a named
// logged metric crosses a configured threshold. It only acts on the
// configured OnEvent/OnMode combination; all other calls are no-ops.
type EarlyStoppingOnMetric struct {
	// OnEvent is the callback Event this check fires on; other events are ignored.
	OnEvent Event
	// OnMode is the Mode this check fires on; other modes are ignored.
	OnMode Mode
	// MetricName is the name of the logged metric to watch (matched
	// case-insensitively against Log.Name).
	MetricName string
	// Mode chooses the comparison direction: Min or Max.
	Mode EarlyStoppingOnMetricMode
	// MaxValue triggers a halt when Mode == Max and the metric >= MaxValue.
	MaxValue float64
	// MinValue triggers a halt when Mode == Min and the metric <= MinValue.
	MinValue float64
}
23+
24+
func (c *EarlyStoppingOnMetric) Init() error {
25+
if c.OnEvent == "" {
26+
return fmt.Errorf("no OnEvent set for callback")
27+
}
28+
if c.OnMode == "" {
29+
return fmt.Errorf("no OnMode set for callback")
30+
}
31+
if c.Mode == "" {
32+
return fmt.Errorf("no Mode set for callback")
33+
}
34+
if c.MetricName == "" {
35+
return fmt.Errorf("unhandled early stopping on metric mode")
36+
}
37+
38+
return nil
39+
}
40+
41+
func (c *EarlyStoppingOnMetric) Call(event Event, mode Mode, epoch int, batch int, logs []Log) ([]Action, error) {
42+
if event != c.OnEvent || mode != c.OnMode {
43+
return []Action{ActionNop}, nil
44+
}
45+
var metricValue float64
46+
if c.MetricName != "" {
47+
found := false
48+
for _, log := range logs {
49+
if strings.ToLower(log.Name) == strings.ToLower(c.MetricName) {
50+
metricValue = log.Value
51+
found = true
52+
}
53+
}
54+
if !found {
55+
return []Action{ActionNop}, fmt.Errorf("metric %s does not exist for the model", c.MetricName)
56+
}
57+
}
58+
59+
if c.Mode == Min {
60+
if metricValue <= c.MinValue {
61+
return []Action{ActionHalt}, nil
62+
}
63+
} else if c.Mode == Max {
64+
if metricValue >= c.MaxValue {
65+
return []Action{ActionHalt}, nil
66+
}
67+
}
68+
69+
return []Action{ActionNop}, nil
70+
}

data/single_file_dataset.go

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"math/rand"
2020
"os"
2121
"path/filepath"
22-
"strconv"
2322
"strings"
2423
"sync"
2524
"sync/atomic"
@@ -42,7 +41,7 @@ type SingleFileDataset struct {
4241
generatorOffset *int32
4342
generatorOffsetLock *sync.Mutex
4443
cacheDir string
45-
categoryOffset int
44+
yProcessor *preprocessor.Processor
4645
columnProcessors []*preprocessor.Processor
4746
lineOffsets []int64
4847
trainPercent float32
@@ -65,7 +64,6 @@ type SingleFileDataset struct {
6564
type SingleFileDatasetConfig struct {
6665
FilePath string
6766
CacheDir string
68-
CategoryOffset int
6967
TrainPercent float32
7068
ValPercent float32
7169
TestPercent float32
@@ -81,6 +79,7 @@ func NewSingleFileDataset(
8179
logger *cblog.Logger,
8280
errorHandler *cberrors.ErrorsContainer,
8381
config SingleFileDatasetConfig,
82+
yProcessor *preprocessor.Processor,
8483
columnProcessors ...*preprocessor.Processor,
8584
) (*SingleFileDataset, error) {
8685

@@ -113,7 +112,7 @@ func NewSingleFileDataset(
113112
skipHeaders: config.SkipHeaders,
114113
ignoreParseErrors: config.IgnoreParseErrors,
115114
cacheDir: config.CacheDir,
116-
categoryOffset: config.CategoryOffset,
115+
yProcessor: yProcessor,
117116
columnProcessors: columnProcessors,
118117
trainPercent: config.TrainPercent,
119118
valPercent: config.ValPercent,
@@ -150,19 +149,7 @@ func NewSingleFileDataset(
150149
return nil, e
151150
}
152151

153-
if _, e := os.Stat(filepath.Join(config.CacheDir, "category-tokenizer.json")); e == nil {
154-
d.categoryTokenizer = preprocessor.NewTokenizer(
155-
errorHandler,
156-
1,
157-
-1,
158-
preprocessor.TokenizerConfig{IsCategoryTokenizer: true, DisableFiltering: true},
159-
)
160-
e = d.categoryTokenizer.Load(filepath.Join(config.CacheDir, "category-tokenizer.json"))
161-
if e != nil {
162-
errorHandler.Error(e)
163-
return nil, e
164-
}
165-
}
152+
_ = yProcessor.Load()
166153

167154
e = d.readLineOffsets()
168155
if e != nil {
@@ -270,11 +257,11 @@ func (d *SingleFileDataset) readLineOffsets() error {
270257
return
271258
}
272259
}
273-
if len(line) < d.categoryOffset {
260+
if len(line) < d.yProcessor.LineOffset {
274261
if d.ignoreParseErrors {
275262
return
276263
}
277-
e = fmt.Errorf("len(line) (%d) < d.categoryOffset (%d)", len(line), d.categoryOffset)
264+
e = fmt.Errorf("len(line) (%d) < d.yProcessor.LineOffset (%d)", len(line), d.yProcessor.LineOffset)
278265
d.errorHandler.Error(e)
279266
errs = append(errs, e)
280267
return
@@ -283,28 +270,36 @@ func (d *SingleFileDataset) readLineOffsets() error {
283270
d.lineOffsets = append(d.lineOffsets, offset)
284271
d.Count++
285272

286-
category := line[d.categoryOffset]
273+
if d.yProcessor.RequiresFit {
274+
e = d.yProcessor.FitString([]string{line[d.yProcessor.LineOffset]})
275+
if e != nil {
276+
if d.ignoreParseErrors {
277+
return
278+
}
279+
d.errorHandler.Error(e)
280+
errs = append(errs, e)
281+
return
282+
}
283+
}
287284

288-
categoryInt, e := strconv.Atoi(category)
289-
// TODO: this magical behaviour could be nicer
285+
category, e := d.yProcessor.ProcessString([]string{line[d.yProcessor.LineOffset]})
290286
if e != nil {
291-
if d.categoryTokenizer == nil {
292-
d.categoryTokenizer = preprocessor.NewTokenizer(
293-
d.errorHandler,
294-
1,
295-
-1,
296-
preprocessor.TokenizerConfig{IsCategoryTokenizer: true, DisableFiltering: true},
297-
)
287+
if d.ignoreParseErrors {
288+
return
298289
}
299-
d.categoryTokenizer.Fit(category)
300-
categoryInt = int(d.categoryTokenizer.Tokenize(category)[0])
290+
d.errorHandler.Error(e)
291+
errs = append(errs, e)
292+
return
301293
}
302294

303-
d.classCountsLock.Lock()
304-
count := d.ClassCounts[categoryInt]
305-
count++
306-
d.ClassCounts[categoryInt] = count
307-
d.classCountsLock.Unlock()
295+
categoryInt, isInt := category.Value().([][]int32)
296+
if isInt {
297+
d.classCountsLock.Lock()
298+
count := d.ClassCounts[int(categoryInt[0][0])]
299+
count++
300+
d.ClassCounts[int(categoryInt[0][0])] = count
301+
d.classCountsLock.Unlock()
302+
}
308303
}(readBytes, offset)
309304

310305
now := time.Now().Unix()
@@ -641,7 +636,7 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
641636
var x []*tf.Tensor
642637

643638
xStrings := make([][]string, len(d.columnProcessors))
644-
var yInts [][]int32
639+
var yRaw []string
645640

646641
for true {
647642
row, e := d.getRow()
@@ -687,33 +682,18 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
687682
return nil, nil, nil, e
688683
}
689684

690-
if len(row) <= d.categoryOffset {
685+
if len(row) <= d.yProcessor.LineOffset {
691686
if d.ignoreParseErrors {
692687
continue
693688
}
694-
e = fmt.Errorf("row did not contain enough columns for categoryOffset at %d", d.categoryOffset)
689+
e = fmt.Errorf("row did not contain enough columns for categoryOffset at %d", d.yProcessor.LineOffset)
695690
d.errorHandler.Error(e)
696691
return nil, nil, nil, e
697692
}
698693

699-
var yInt int
700-
701-
if d.categoryTokenizer != nil {
702-
yInt = int(d.categoryTokenizer.Tokenize(row[d.categoryOffset])[0])
703-
} else {
704-
yInt, e = strconv.Atoi(row[d.categoryOffset])
705-
if e != nil {
706-
if d.ignoreParseErrors {
707-
continue
708-
}
709-
d.errorHandler.Error(e)
710-
return nil, nil, nil, e
711-
}
712-
}
713-
714-
yInts = append(yInts, []int32{int32(yInt)})
694+
yRaw = append(yRaw, row[d.yProcessor.LineOffset])
715695

716-
if len(yInts) >= batchSize {
696+
if len(yRaw) >= batchSize {
717697
break
718698
}
719699
}
@@ -727,18 +707,24 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
727707
x = append(x, process)
728708
}
729709

730-
var classWeights []float32
731-
for _, yInt32 := range yInts {
732-
classWeights = append(classWeights, d.ClassWeights[int(yInt32[0])])
733-
}
734-
735-
classWeightsTensor, e := tf.NewTensor(classWeights)
710+
y, e := d.yProcessor.ProcessString(yRaw)
736711
if e != nil {
737-
d.errorHandler.Error(e)
738712
return nil, nil, nil, e
739713
}
740714

741-
y, e := tf.NewTensor(yInts)
715+
var classWeights []float32
716+
yInts, isInt := y.Value().([][]int32)
717+
if isInt {
718+
for _, yInt32 := range yInts {
719+
classWeights = append(classWeights, d.ClassWeights[int(yInt32[0])])
720+
}
721+
} else {
722+
for range yRaw {
723+
classWeights = append(classWeights, 1)
724+
}
725+
}
726+
727+
classWeightsTensor, e := tf.NewTensor(classWeights)
742728
if e != nil {
743729
d.errorHandler.Error(e)
744730
return nil, nil, nil, e
@@ -768,5 +754,9 @@ func (d *SingleFileDataset) SaveProcessors(saveDir string) error {
768754
return e
769755
}
770756
}
757+
e = d.yProcessor.Save(saveDir)
758+
if e != nil {
759+
return e
760+
}
771761
return nil
772762
}

0 commit comments

Comments
 (0)