Skip to content

Commit

Permalink
Demuxer: emit PAT/PMT when seen (#24)
Browse files Browse the repository at this point in the history
* packetPool: flush old packets whenever new payload starts

* packetPool: split out a packetAccumulator so the pool only has to keep one per-pid map

* packetPool: accept demuxer pointer

* Demuxer: helper wrapper for parseData

* packetAccumulator: flush PAT/PMT as soon as parsed successfully

* packetAccumulator: cleanups

* test to ensure demuxer returns PAT/PMT when seen

* decouple packetPool from Demuxer

* strip all whitespace from hex before parsing

* move comment

* decouple packetPool and packetAccumulator
  • Loading branch information
tmm1 committed Apr 17, 2021
1 parent c44c340 commit 7a46f0b
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 42 deletions.
4 changes: 2 additions & 2 deletions data.go
Expand Up @@ -35,7 +35,7 @@ type MuxerData struct {
}

// parseData parses a payload spanning over multiple packets and returns a set of data
func parseData(ps []*Packet, prs PacketsParser, pm programMap) (ds []*DemuxerData, err error) {
func parseData(ps []*Packet, prs PacketsParser, pm *programMap) (ds []*DemuxerData, err error) {
// Use custom parser first
if prs != nil {
var skip bool
Expand Down Expand Up @@ -99,7 +99,7 @@ func parseData(ps []*Packet, prs PacketsParser, pm programMap) (ds []*DemuxerDat
}

// isPSIPayload checks whether the payload is a PSI one
func isPSIPayload(pid uint16, pm programMap) bool {
func isPSIPayload(pid uint16, pm *programMap) bool {
return pid == PIDPAT || // PAT
pm.exists(pid) || // PMT
((pid >= 0x10 && pid <= 0x14) || (pid >= 0x1e && pid <= 0x1f)) //DVB
Expand Down
6 changes: 3 additions & 3 deletions demuxer.go
Expand Up @@ -27,7 +27,7 @@ type Demuxer struct {
optPacketsParser PacketsParser
packetBuffer *packetBuffer
packetPool *packetPool
programMap programMap
programMap *programMap
r io.Reader
}

Expand All @@ -40,10 +40,10 @@ func NewDemuxer(ctx context.Context, r io.Reader, opts ...func(*Demuxer)) (d *De
// Init
d = &Demuxer{
ctx: ctx,
packetPool: newPacketPool(),
programMap: newProgramMap(),
r: r,
}
d.packetPool = newPacketPool(d.optPacketsParser, d.programMap)

// Apply options
for _, opt := range opts {
Expand Down Expand Up @@ -180,7 +180,7 @@ func (dmx *Demuxer) updateData(ds []*DemuxerData) (d *DemuxerData) {
func (dmx *Demuxer) Rewind() (n int64, err error) {
dmx.dataBuffer = []*DemuxerData{}
dmx.packetBuffer = nil
dmx.packetPool = newPacketPool()
dmx.packetPool = newPacketPool(dmx.optPacketsParser, dmx.programMap)
if n, err = rewind(dmx.r); err != nil {
err = fmt.Errorf("astits: rewinding reader failed: %w", err)
return
Expand Down
48 changes: 48 additions & 0 deletions demuxer_test.go
Expand Up @@ -3,14 +3,31 @@ package astits
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"strings"
"testing"
"unicode"

"github.com/asticode/go-astikit"
"github.com/stretchr/testify/assert"
)

func hexToBytes(in string) []byte {
cin := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return -1
}
return r
}, in)
o, err := hex.DecodeString(cin)
if err != nil {
panic(err)
}
return o
}

func TestDemuxerNew(t *testing.T) {
ps := 1
pp := func(ps []*Packet) (ds []*DemuxerData, skip bool, err error) { return }
Expand Down Expand Up @@ -84,6 +101,37 @@ func TestDemuxerNextData(t *testing.T) {
assert.EqualError(t, err, ErrNoMorePackets.Error())
}

func TestDemuxerNextDataPATPMT(t *testing.T) {
pat := hexToBytes(`474000100000b00d0001c100000001f0002ab104b2ffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffff`)
pmt := hexToBytes(`475000100002b0170001c10000e100f0001be100f0000fe101f0002f44
b99bffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff
ffffffffffffffffff`)
r := bytes.NewReader(append(pat, pmt...))
dmx := NewDemuxer(context.Background(), r, DemuxerOptPacketSize(188))
assert.Equal(t, 188*2, r.Len())

d, err := dmx.NextData()
assert.NoError(t, err)
assert.Equal(t, uint16(0), d.FirstPacket.Header.PID)
assert.NotNil(t, d.PAT)
assert.Equal(t, 188, r.Len())

d, err = dmx.NextData()
assert.NoError(t, err)
assert.Equal(t, uint16(0x1000), d.FirstPacket.Header.PID)
assert.NotNil(t, d.PMT)
}

func TestDemuxerRewind(t *testing.T) {
r := bytes.NewReader([]byte("content"))
dmx := NewDemuxer(context.Background(), r)
Expand Down
5 changes: 3 additions & 2 deletions muxer.go
Expand Up @@ -4,8 +4,9 @@ import (
"bytes"
"context"
"errors"
"github.com/asticode/go-astikit"
"io"

"github.com/asticode/go-astikit"
)

const (
Expand All @@ -28,7 +29,7 @@ type Muxer struct {
packetSize int
tablesRetransmitPeriod int // period in PES packets

pm programMap // pid -> programNumber
pm *programMap // pid -> programNumber
pmt PMTData
nextPID uint16
patVersion wrappingCounter
Expand Down
102 changes: 70 additions & 32 deletions packet_pool.go
Expand Up @@ -5,17 +5,76 @@ import (
"sync"
)

// packetPool represents a pool of packets
// packetAccumulator keeps track of packets for a single PID and decides when to flush them
type packetAccumulator struct {
parser PacketsParser
pid uint16
programMap *programMap
q []*Packet
}

// newPacketAccumulator creates a new packet queue for a single PID
func newPacketAccumulator(pid uint16, parser PacketsParser, programMap *programMap) *packetAccumulator {
return &packetAccumulator{
parser: parser,
pid: pid,
programMap: programMap,
}
}

// add adds a new packet for this PID to the queue
func (b *packetAccumulator) add(p *Packet) (ps []*Packet) {
mps := b.q

// Empty buffer if we detect a discontinuity
if hasDiscontinuity(mps, p) {
mps = []*Packet{}
}

// Throw away packet if it's the same as the previous one
if isSameAsPrevious(mps, p) {
return
}

// Flush buffer if new payload starts here
if p.Header.PayloadUnitStartIndicator {
ps = mps
mps = []*Packet{p}
} else {
mps = append(mps, p)
}

// Check if PSI payload is complete
if b.programMap != nil &&
(b.pid == PIDPAT || b.programMap.exists(b.pid)) {
// TODO Use partial data parsing instead
if _, err := parseData(mps, b.parser, b.programMap); err == nil {
ps = mps
mps = nil
}
}

b.q = mps
return
}

// packetPool represents a queue of packets for each PID in the stream
type packetPool struct {
b map[uint16][]*Packet // Indexed by PID
b map[uint16]*packetAccumulator // Indexed by PID
m *sync.Mutex

parser PacketsParser
programMap *programMap
}

// newPacketPool creates a new packet pool
func newPacketPool() *packetPool {
// newPacketPool creates a new packet pool with an optional parser and programMap
func newPacketPool(parser PacketsParser, programMap *programMap) *packetPool {
return &packetPool{
b: make(map[uint16][]*Packet),
b: make(map[uint16]*packetAccumulator),
m: &sync.Mutex{},

parser: parser,
programMap: programMap,
}
}

Expand All @@ -36,34 +95,13 @@ func (b *packetPool) add(p *Packet) (ps []*Packet) {
b.m.Lock()
defer b.m.Unlock()

// Init buffer
var mps []*Packet
var ok bool
if mps, ok = b.b[p.Header.PID]; !ok {
mps = []*Packet{}
}

// Empty buffer if we detect a discontinuity
if hasDiscontinuity(mps, p) {
mps = []*Packet{}
}

// Throw away packet if it's the same as the previous one
if isSameAsPrevious(mps, p) {
return
// Make sure accumulator exists
if _, ok := b.b[p.Header.PID]; !ok {
b.b[p.Header.PID] = newPacketAccumulator(p.Header.PID, b.parser, b.programMap)
}

// Flush buffer if new payload starts here
if p.Header.PayloadUnitStartIndicator {
ps = mps
mps = []*Packet{p}
} else {
mps = append(mps, p)
}

// Assign
b.b[p.Header.PID] = mps
return
// Add to the accumulator
return b.b[p.Header.PID].add(p)
}

// dump dumps the packet pool by looking for the first item with packets inside
Expand All @@ -76,7 +114,7 @@ func (b *packetPool) dump() (ps []*Packet) {
}
sort.Ints(keys)
for _, k := range keys {
ps = b.b[uint16(k)]
ps = b.b[uint16(k)].q
delete(b.b, uint16(k))
if len(ps) > 0 {
return
Expand Down
2 changes: 1 addition & 1 deletion packet_pool_test.go
Expand Up @@ -21,7 +21,7 @@ func TestIsSameAsPrevious(t *testing.T) {
}

func TestPacketPool(t *testing.T) {
b := newPacketPool()
b := newPacketPool(nil, nil)
ps := b.add(&Packet{Header: &PacketHeader{ContinuityCounter: 0, HasPayload: true, PID: 1}})
assert.Len(t, ps, 0)
ps = b.add(&Packet{Header: &PacketHeader{ContinuityCounter: 1, HasPayload: true, PayloadUnitStartIndicator: true, PID: 1}})
Expand Down
4 changes: 2 additions & 2 deletions program_map.go
Expand Up @@ -9,8 +9,8 @@ type programMap struct {
}

// newProgramMap creates a new program ids map
func newProgramMap() programMap {
return programMap{
func newProgramMap() *programMap {
return &programMap{
m: &sync.Mutex{},
p: make(map[uint16]uint16),
}
Expand Down

0 comments on commit 7a46f0b

Please sign in to comment.