@@ -19,7 +19,6 @@ import (
1919 "math/rand"
2020 "os"
2121 "path/filepath"
22- "strconv"
2322 "strings"
2423 "sync"
2524 "sync/atomic"
@@ -42,7 +41,7 @@ type SingleFileDataset struct {
4241 generatorOffset * int32
4342 generatorOffsetLock * sync.Mutex
4443 cacheDir string
45- categoryOffset int
44+ yProcessor * preprocessor. Processor
4645 columnProcessors []* preprocessor.Processor
4746 lineOffsets []int64
4847 trainPercent float32
@@ -65,7 +64,6 @@ type SingleFileDataset struct {
6564type SingleFileDatasetConfig struct {
6665 FilePath string
6766 CacheDir string
68- CategoryOffset int
6967 TrainPercent float32
7068 ValPercent float32
7169 TestPercent float32
@@ -81,6 +79,7 @@ func NewSingleFileDataset(
8179 logger * cblog.Logger ,
8280 errorHandler * cberrors.ErrorsContainer ,
8381 config SingleFileDatasetConfig ,
82+ yProcessor * preprocessor.Processor ,
8483 columnProcessors ... * preprocessor.Processor ,
8584) (* SingleFileDataset , error ) {
8685
@@ -113,7 +112,7 @@ func NewSingleFileDataset(
113112 skipHeaders : config .SkipHeaders ,
114113 ignoreParseErrors : config .IgnoreParseErrors ,
115114 cacheDir : config .CacheDir ,
116- categoryOffset : config . CategoryOffset ,
115+ yProcessor : yProcessor ,
117116 columnProcessors : columnProcessors ,
118117 trainPercent : config .TrainPercent ,
119118 valPercent : config .ValPercent ,
@@ -150,19 +149,7 @@ func NewSingleFileDataset(
150149 return nil , e
151150 }
152151
153- if _ , e := os .Stat (filepath .Join (config .CacheDir , "category-tokenizer.json" )); e == nil {
154- d .categoryTokenizer = preprocessor .NewTokenizer (
155- errorHandler ,
156- 1 ,
157- - 1 ,
158- preprocessor.TokenizerConfig {IsCategoryTokenizer : true , DisableFiltering : true },
159- )
160- e = d .categoryTokenizer .Load (filepath .Join (config .CacheDir , "category-tokenizer.json" ))
161- if e != nil {
162- errorHandler .Error (e )
163- return nil , e
164- }
165- }
152+ _ = yProcessor .Load ()
166153
167154 e = d .readLineOffsets ()
168155 if e != nil {
@@ -270,11 +257,11 @@ func (d *SingleFileDataset) readLineOffsets() error {
270257 return
271258 }
272259 }
273- if len (line ) < d .categoryOffset {
260+ if len (line ) < d .yProcessor . LineOffset {
274261 if d .ignoreParseErrors {
275262 return
276263 }
277- e = fmt .Errorf ("len(line) (%d) < d.categoryOffset (%d)" , len (line ), d .categoryOffset )
264+ e = fmt .Errorf ("len(line) (%d) < d.yProcessor.LineOffset (%d)" , len (line ), d .yProcessor . LineOffset )
278265 d .errorHandler .Error (e )
279266 errs = append (errs , e )
280267 return
@@ -283,28 +270,36 @@ func (d *SingleFileDataset) readLineOffsets() error {
283270 d .lineOffsets = append (d .lineOffsets , offset )
284271 d .Count ++
285272
286- category := line [d .categoryOffset ]
273+ if d .yProcessor .RequiresFit {
274+ e = d .yProcessor .FitString ([]string {line [d .yProcessor .LineOffset ]})
275+ if e != nil {
276+ if d .ignoreParseErrors {
277+ return
278+ }
279+ d .errorHandler .Error (e )
280+ errs = append (errs , e )
281+ return
282+ }
283+ }
287284
288- categoryInt , e := strconv .Atoi (category )
289- // TODO: this magical behaviour could be nicer
285+ category , e := d .yProcessor .ProcessString ([]string {line [d .yProcessor .LineOffset ]})
290286 if e != nil {
291- if d .categoryTokenizer == nil {
292- d .categoryTokenizer = preprocessor .NewTokenizer (
293- d .errorHandler ,
294- 1 ,
295- - 1 ,
296- preprocessor.TokenizerConfig {IsCategoryTokenizer : true , DisableFiltering : true },
297- )
287+ if d .ignoreParseErrors {
288+ return
298289 }
299- d .categoryTokenizer .Fit (category )
300- categoryInt = int (d .categoryTokenizer .Tokenize (category )[0 ])
290+ d .errorHandler .Error (e )
291+ errs = append (errs , e )
292+ return
301293 }
302294
303- d .classCountsLock .Lock ()
304- count := d .ClassCounts [categoryInt ]
305- count ++
306- d .ClassCounts [categoryInt ] = count
307- d .classCountsLock .Unlock ()
295+ categoryInt , isInt := category .Value ().([][]int32 )
296+ if isInt {
297+ d .classCountsLock .Lock ()
298+ count := d .ClassCounts [int (categoryInt [0 ][0 ])]
299+ count ++
300+ d .ClassCounts [int (categoryInt [0 ][0 ])] = count
301+ d .classCountsLock .Unlock ()
302+ }
308303 }(readBytes , offset )
309304
310305 now := time .Now ().Unix ()
@@ -641,7 +636,7 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
641636 var x []* tf.Tensor
642637
643638 xStrings := make ([][]string , len (d .columnProcessors ))
644- var yInts [][] int32
639+ var yRaw []string
645640
646641 for true {
647642 row , e := d .getRow ()
@@ -687,33 +682,18 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
687682 return nil , nil , nil , e
688683 }
689684
690- if len (row ) <= d .categoryOffset {
685+ if len (row ) <= d .yProcessor . LineOffset {
691686 if d .ignoreParseErrors {
692687 continue
693688 }
694- e = fmt .Errorf ("row did not contain enough columns for categoryOffset at %d" , d .categoryOffset )
689+ e = fmt .Errorf ("row did not contain enough columns for categoryOffset at %d" , d .yProcessor . LineOffset )
695690 d .errorHandler .Error (e )
696691 return nil , nil , nil , e
697692 }
698693
699- var yInt int
700-
701- if d .categoryTokenizer != nil {
702- yInt = int (d .categoryTokenizer .Tokenize (row [d .categoryOffset ])[0 ])
703- } else {
704- yInt , e = strconv .Atoi (row [d .categoryOffset ])
705- if e != nil {
706- if d .ignoreParseErrors {
707- continue
708- }
709- d .errorHandler .Error (e )
710- return nil , nil , nil , e
711- }
712- }
713-
714- yInts = append (yInts , []int32 {int32 (yInt )})
694+ yRaw = append (yRaw , row [d .yProcessor .LineOffset ])
715695
716- if len (yInts ) >= batchSize {
696+ if len (yRaw ) >= batchSize {
717697 break
718698 }
719699 }
@@ -727,18 +707,24 @@ func (d *SingleFileDataset) Generate(batchSize int) ([]*tf.Tensor, *tf.Tensor, *
727707 x = append (x , process )
728708 }
729709
730- var classWeights []float32
731- for _ , yInt32 := range yInts {
732- classWeights = append (classWeights , d .ClassWeights [int (yInt32 [0 ])])
733- }
734-
735- classWeightsTensor , e := tf .NewTensor (classWeights )
710+ y , e := d .yProcessor .ProcessString (yRaw )
736711 if e != nil {
737- d .errorHandler .Error (e )
738712 return nil , nil , nil , e
739713 }
740714
741- y , e := tf .NewTensor (yInts )
715+ var classWeights []float32
716+ yInts , isInt := y .Value ().([][]int32 )
717+ if isInt {
718+ for _ , yInt32 := range yInts {
719+ classWeights = append (classWeights , d .ClassWeights [int (yInt32 [0 ])])
720+ }
721+ } else {
722+ for range yRaw {
723+ classWeights = append (classWeights , 1 )
724+ }
725+ }
726+
727+ classWeightsTensor , e := tf .NewTensor (classWeights )
742728 if e != nil {
743729 d .errorHandler .Error (e )
744730 return nil , nil , nil , e
@@ -768,5 +754,9 @@ func (d *SingleFileDataset) SaveProcessors(saveDir string) error {
768754 return e
769755 }
770756 }
757+ e = d .yProcessor .Save (saveDir )
758+ if e != nil {
759+ return e
760+ }
771761 return nil
772762}
0 commit comments