Skip to content

Commit

Permalink
Added an io pachage with a MultiWriter similar to the one included in…
Browse files Browse the repository at this point in the history
… Go. The difference is this one will let you manage the Writers attached to it. This will be used by the log package.
  • Loading branch information
mattfarina committed Oct 16, 2013
1 parent 9eaf783 commit 0fe08ed
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
57 changes: 57 additions & 0 deletions io/multi.go
@@ -0,0 +1,57 @@
package io

import (
"io"
)

// MultiWriter enables you to have a writer that passes on the writing to one
// of more Writers where the write is duplicated to each Writer. MultiWriter
// is similar to the multiWriter that is part of Go. The difference is
// this MultiWriter allows you to manager the Writers attached to it via CRUD
// operations. To do this you will need to mock the type. For example,
// mw := NewMultiWriter()
// mw.(*MultiWriter).AddWriter("foo", foo)
type MultiWriter struct {
writers map[string]io.Writer
}

func (t *MultiWriter) Write(p []byte) (n int, err error) {
for _, w := range t.writers {
n, err = w.Write(p)
if err != nil {
return
}
if n != len(p) {
err = io.ErrShortWrite
return
}
}
return len(p), nil
}

func (t *MultiWriter) Init() *MultiWriter {
t.writers = make(map[string]io.Writer)
return t
}

func (t *MultiWriter) Writer(name string) (io.Writer, bool) {
value, found := t.writers[name]
return value, found
}

func (t *MultiWriter) Writers() map[string]io.Writer {
return t.writers
}

func (t *MultiWriter) AddWriter(name string, writer io.Writer) {
t.writers[name] = writer
}

func (t *MultiWriter) RemoveWriter(name string) {
delete(t.writers, name)
}

func NewMultiWriter() io.Writer {
w := new(MultiWriter).Init()
return w
}
62 changes: 62 additions & 0 deletions io/multi_test.go
@@ -0,0 +1,62 @@
package io

import (
"bytes"
"crypto/sha1"
"fmt"
"io"
"strings"
"testing"
)

func TestMultiWrite(t *testing.T) {
sha1 := sha1.New()
sink := new(bytes.Buffer)
mw := NewMultiWriter()
mw.(*MultiWriter).AddWriter("sha1", sha1)
mw.(*MultiWriter).AddWriter("sink", sink)

sourceString := "My input text."
source := strings.NewReader(sourceString)
written, err := io.Copy(mw, source)

if written != int64(len(sourceString)) {
t.Errorf("short write of %d, not %d", written, len(sourceString))
}

if err != nil {
t.Errorf("unexpected error: %v", err)
}

sha1hex := fmt.Sprintf("%x", sha1.Sum(nil))
if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
t.Error("incorrect sha1 value")
}

if sink.String() != sourceString {
t.Errorf("expected %q; got %q", sourceString, sink.String())
}
}

func TestMultiWriterCRUD(t *testing.T) {
sha1 := sha1.New()
mw := NewMultiWriter()
mw.(*MultiWriter).AddWriter("sha1", sha1)

sha1a, found := mw.(*MultiWriter).Writer("sha1")

if found == false {
t.Error("Did not find sha1 as expected.")
}
if sha1a != sha1 {
t.Error("Expected sha1 returned from MultiWriter to be what was set. They were different.")
}

mw.(*MultiWriter).RemoveWriter("sha1")
_, found = mw.(*MultiWriter).Writer("sha1")

if found == true {
t.Error("Expected sha1 to be removed from MultiWriter but it was not.")
}

}

0 comments on commit 0fe08ed

Please sign in to comment.