Skip to content

Commit 5151891

Browse files
committed
refactor: update watcher tests to use public interface of fsnotify to prevent race condition in tests
1 parent bc38b05 commit 5151891

File tree

4 files changed

+43
-33
lines changed

4 files changed

+43
-33
lines changed

.version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.3.927
1+
0.3.931

cmd/templ/generatecmd/cmd.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,12 +306,17 @@ func (cmd *Generate) walkAndWatch(ctx context.Context, events chan fsnotify.Even
306306
return
307307
}
308308
cmd.Log.Info("Watching files")
309-
rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.Args.WatchPattern, events, errs)
309+
rw, err := watcher.Recursive(ctx, cmd.Args.WatchPattern, events, errs)
310310
if err != nil {
311311
cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err))
312312
errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)}
313313
return
314314
}
315+
if err = rw.Add(cmd.Args.Path); err != nil {
316+
cmd.Log.Error("Failed to add path to watcher", slog.Any("error", err))
317+
errs <- FatalError{Err: fmt.Errorf("failed to add path to watcher: %w", err)}
318+
return
319+
}
315320
defer func() {
316321
if err := rw.Close(); err != nil {
317322
cmd.Log.Error("Failed to close watcher", slog.Any("error", err))

cmd/templ/generatecmd/watcher/watch.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515

1616
func Recursive(
1717
ctx context.Context,
18-
path string,
1918
watchPattern *regexp.Regexp,
2019
out chan fsnotify.Event,
2120
errors chan error,
@@ -24,20 +23,21 @@ func Recursive(
2423
if err != nil {
2524
return nil, err
2625
}
27-
w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors)
28-
go w.loop()
29-
return w, w.Add(path)
30-
}
31-
32-
func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher {
33-
return &RecursiveWatcher{
26+
w = &RecursiveWatcher{
3427
ctx: ctx,
35-
w: w,
28+
w: fsnw,
3629
WatchPattern: watchPattern,
37-
Events: events,
30+
Events: out,
3831
Errors: errors,
3932
timers: make(map[timerKey]*time.Timer),
33+
loopComplete: sync.WaitGroup{},
4034
}
35+
w.loopComplete.Add(1)
36+
go func() {
37+
defer w.loopComplete.Done()
38+
w.loop()
39+
}()
40+
return w, nil
4141
}
4242

4343
// WalkFiles walks the file tree rooted at path, sending a Create event for each
@@ -73,6 +73,7 @@ type RecursiveWatcher struct {
7373
Errors chan error
7474
timerMu sync.Mutex
7575
timers map[timerKey]*time.Timer
76+
loopComplete sync.WaitGroup
7677
}
7778

7879
type timerKey struct {
@@ -88,6 +89,10 @@ func timerKeyFromEvent(event fsnotify.Event) timerKey {
8889
}
8990

9091
func (w *RecursiveWatcher) Close() error {
92+
w.loopComplete.Wait()
93+
for _, timer := range w.timers {
94+
timer.Stop()
95+
}
9196
return w.w.Close()
9297
}
9398

@@ -115,6 +120,9 @@ func (w *RecursiveWatcher) loop() {
115120
w.timerMu.Unlock()
116121
if !ok {
117122
t = time.AfterFunc(100*time.Millisecond, func() {
123+
if w.ctx.Err() != nil {
124+
return
125+
}
118126
w.Events <- event
119127
})
120128
w.timerMu.Lock()

cmd/templ/generatecmd/watcher/watch_test.go

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,20 @@ import (
1212

1313
func TestWatchDebouncesDuplicates(t *testing.T) {
1414
ctx, cancel := context.WithCancel(context.Background())
15-
w := &fsnotify.Watcher{
16-
Events: make(chan fsnotify.Event),
17-
}
1815
events := make(chan fsnotify.Event, 2)
1916
errors := make(chan error)
2017
watchPattern, err := regexp.Compile(".*")
2118
if err != nil {
2219
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
2320
}
24-
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
21+
rw, err := Recursive(ctx, watchPattern, events, errors)
22+
if err != nil {
23+
t.Fatal(fmt.Errorf("failed to create recursive watcher: %w", err))
24+
}
2525
go func() {
2626
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
2727
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
28-
cancel()
29-
close(rw.w.Events)
3028
}()
31-
rw.loop()
3229
count := 0
3330
exp := time.After(300 * time.Millisecond)
3431
for {
@@ -39,6 +36,8 @@ func TestWatchDebouncesDuplicates(t *testing.T) {
3936
if count != 1 {
4037
t.Errorf("expected 1 event, got %d", count)
4138
}
39+
cancel()
40+
rw.Close()
4241
return
4342
}
4443
}
@@ -64,23 +63,20 @@ func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
6463
}
6564
for _, test := range tests {
6665
ctx, cancel := context.WithCancel(context.Background())
67-
w := &fsnotify.Watcher{
68-
Events: make(chan fsnotify.Event),
69-
}
7066
events := make(chan fsnotify.Event, 2)
7167
errors := make(chan error)
7268
watchPattern, err := regexp.Compile(".*")
7369
if err != nil {
7470
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
7571
}
76-
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
72+
rw, err := Recursive(ctx, watchPattern, events, errors)
73+
if err != nil {
74+
t.Fatal(fmt.Errorf("failed to create recursive watcher: %w", err))
75+
}
7776
go func() {
7877
rw.w.Events <- test.event1
7978
rw.w.Events <- test.event2
80-
cancel()
81-
close(rw.w.Events)
8279
}()
83-
rw.loop()
8480
count := 0
8581
exp := time.After(300 * time.Millisecond)
8682
for {
@@ -91,6 +87,8 @@ func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
9187
if count != 2 {
9288
t.Errorf("expected 2 event, got %d", count)
9389
}
90+
cancel()
91+
rw.Close()
9492
return
9593
}
9694
}
@@ -99,24 +97,21 @@ func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) {
9997

10098
func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
10199
ctx, cancel := context.WithCancel(context.Background())
102-
w := &fsnotify.Watcher{
103-
Events: make(chan fsnotify.Event),
104-
}
105100
events := make(chan fsnotify.Event, 2)
106101
errors := make(chan error)
107102
watchPattern, err := regexp.Compile(".*")
108103
if err != nil {
109104
t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err))
110105
}
111-
rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors)
106+
rw, err := Recursive(ctx, watchPattern, events, errors)
107+
if err != nil {
108+
t.Fatal(fmt.Errorf("failed to create recursive watcher: %w", err))
109+
}
112110
go func() {
113111
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
114112
<-time.After(200 * time.Millisecond)
115113
rw.w.Events <- fsnotify.Event{Name: "test.templ"}
116-
cancel()
117-
close(rw.w.Events)
118114
}()
119-
rw.loop()
120115
count := 0
121116
exp := time.After(500 * time.Millisecond)
122117
for {
@@ -127,6 +122,8 @@ func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) {
127122
if count != 2 {
128123
t.Errorf("expected 2 event, got %d", count)
129124
}
125+
cancel()
126+
rw.Close()
130127
return
131128
}
132129
}

0 commit comments

Comments
 (0)