Skip to content

Commit

Permalink
Merge e986edc into 2009390
Browse files Browse the repository at this point in the history
  • Loading branch information
asticode committed Mar 19, 2022
2 parents 2009390 + e986edc commit 21cf918
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 17 deletions.
31 changes: 26 additions & 5 deletions muxer.go
Expand Up @@ -30,7 +30,9 @@ type Muxer struct {
tablesRetransmitPeriod int // period in PES packets

pm *programMap // pid -> programNumber
pmUpdated bool
pmt PMTData
pmtUpdated bool
nextPID uint16
patVersion wrappingCounter
pmtVersion wrappingCounter
Expand Down Expand Up @@ -96,6 +98,7 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {

// TODO multiple programs support
m.pm.set(pmtStartPID, programNumberStart)
m.pmUpdated = true

for _, opt := range opts {
opt(m)
Expand Down Expand Up @@ -125,6 +128,7 @@ func (m *Muxer) AddElementaryStream(es PMTElementaryStream) error {
m.esContexts[es.ElementaryPID] = newEsContext(&es)
// invalidate pmt cache
m.pmtBytes.Reset()
m.pmtUpdated = true
return nil
}

Expand All @@ -144,12 +148,14 @@ func (m *Muxer) RemoveElementaryStream(pid uint16) error {
m.pmt.ElementaryStreams = append(m.pmt.ElementaryStreams[:foundIdx], m.pmt.ElementaryStreams[foundIdx+1:]...)
delete(m.esContexts, pid)
m.pmtBytes.Reset()
m.pmtUpdated = true
return nil
}

// SetPCRPID marks pid as one to look PCRs in
func (m *Muxer) SetPCRPID(pid uint16) {
m.pmt.PCRPID = pid
m.pmtUpdated = true
}

// WriteData writes MuxerData to TS stream
Expand Down Expand Up @@ -181,7 +187,7 @@ func (m *Muxer) WriteData(d *MuxerData) (int, error) {
pktLen := 1 + mpegTsPacketHeaderSize // sync byte + header
pkt := Packet{
Header: &PacketHeader{
ContinuityCounter: uint8(ctx.cc.get()),
ContinuityCounter: uint8(ctx.cc.inc()),
HasAdaptationField: writeAf,
HasPayload: false,
PayloadUnitStartIndicator: false,
Expand Down Expand Up @@ -315,6 +321,12 @@ func (m *Muxer) WriteTables() (int, error) {

func (m *Muxer) generatePAT() error {
d := m.pm.toPATData()

versionNumber := m.patVersion.get()
if m.pmUpdated {
versionNumber = m.patVersion.inc()
}

syntax := &PSISectionSyntax{
Data: &PSISectionSyntaxData{PAT: d},
Header: &PSISectionSyntaxHeader{
Expand All @@ -323,7 +335,7 @@ func (m *Muxer) generatePAT() error {
//LastSectionNumber: 0,
//SectionNumber: 0,
TableIDExtension: d.TransportStreamID,
VersionNumber: uint8(m.patVersion.get()),
VersionNumber: uint8(versionNumber),
},
}
section := PSISection{
Expand Down Expand Up @@ -352,7 +364,7 @@ func (m *Muxer) generatePAT() error {
HasPayload: true,
PayloadUnitStartIndicator: true,
PID: PIDPAT,
ContinuityCounter: uint8(m.patCC.get()),
ContinuityCounter: uint8(m.patCC.inc()),
},
Payload: m.buf.Bytes(),
}
Expand All @@ -361,6 +373,8 @@ func (m *Muxer) generatePAT() error {
return err
}

m.pmUpdated = false

return nil
}

Expand All @@ -376,6 +390,11 @@ func (m *Muxer) generatePMT() error {
return ErrPCRPIDInvalid
}

versionNumber := m.pmtVersion.get()
if m.pmtUpdated {
versionNumber = m.pmtVersion.inc()
}

syntax := &PSISectionSyntax{
Data: &PSISectionSyntaxData{PMT: &m.pmt},
Header: &PSISectionSyntaxHeader{
Expand All @@ -384,7 +403,7 @@ func (m *Muxer) generatePMT() error {
//LastSectionNumber: 0,
//SectionNumber: 0,
TableIDExtension: m.pmt.ProgramNumber,
VersionNumber: uint8(m.pmtVersion.get()),
VersionNumber: uint8(versionNumber),
},
}
section := PSISection{
Expand Down Expand Up @@ -413,7 +432,7 @@ func (m *Muxer) generatePMT() error {
HasPayload: true,
PayloadUnitStartIndicator: true,
PID: pmtStartPID, // FIXME multiple programs support
ContinuityCounter: uint8(m.pmtCC.get()),
ContinuityCounter: uint8(m.pmtCC.inc()),
},
Payload: m.buf.Bytes(),
}
Expand All @@ -422,5 +441,7 @@ func (m *Muxer) generatePMT() error {
return err
}

m.pmtUpdated = false

return nil
}
31 changes: 23 additions & 8 deletions muxer_test.go
Expand Up @@ -3,9 +3,10 @@ package astits
import (
"bytes"
"context"
"testing"

"github.com/asticode/go-astikit"
"github.com/stretchr/testify/assert"
"testing"
)

func patExpectedBytes(versionNumber uint8, cc uint8) []byte {
Expand Down Expand Up @@ -52,21 +53,28 @@ func TestMuxer_generatePAT(t *testing.T) {
assert.Equal(t, MpegTsPacketSize, muxer.patBytes.Len())
assert.Equal(t, patExpectedBytes(0, 0), muxer.patBytes.Bytes())

// to check version number increment
// Version number shouldn't change
err = muxer.generatePAT()
assert.NoError(t, err)
assert.Equal(t, MpegTsPacketSize, muxer.patBytes.Len())
assert.Equal(t, patExpectedBytes(0, 1), muxer.patBytes.Bytes())

// Version number should change
muxer.pmUpdated = true
err = muxer.generatePAT()
assert.NoError(t, err)
assert.Equal(t, MpegTsPacketSize, muxer.patBytes.Len())
assert.Equal(t, patExpectedBytes(1, 1), muxer.patBytes.Bytes())
assert.Equal(t, patExpectedBytes(1, 2), muxer.patBytes.Bytes())
}

func pmtExpectedBytesVideoOnly(versionNumber uint8) []byte {
func pmtExpectedBytesVideoOnly(versionNumber, cc uint8) []byte {
buf := bytes.Buffer{}
w := astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &buf})
w.Write(uint8(syncByte))
w.Write("010") // no transport error, payload start, no priority
w.WriteN(pmtStartPID, 13)
w.Write("0001") // no scrambling, no AF, payload present
w.Write("0000") // CC
w.WriteN(cc, 4)

w.Write(uint16(PSITableIDPMT)) // Table ID
w.Write("1011") // Syntax section indicator, private bit, reserved
Expand Down Expand Up @@ -161,18 +169,25 @@ func TestMuxer_generatePMT(t *testing.T) {
err = muxer.generatePMT()
assert.NoError(t, err)
assert.Equal(t, MpegTsPacketSize, muxer.pmtBytes.Len())
assert.Equal(t, pmtExpectedBytesVideoOnly(0), muxer.pmtBytes.Bytes())
assert.Equal(t, pmtExpectedBytesVideoOnly(0, 0), muxer.pmtBytes.Bytes())

// Version number shouldn't change
err = muxer.generatePMT()
assert.NoError(t, err)
assert.Equal(t, MpegTsPacketSize, muxer.pmtBytes.Len())
assert.Equal(t, pmtExpectedBytesVideoOnly(0, 1), muxer.pmtBytes.Bytes())

err = muxer.AddElementaryStream(PMTElementaryStream{
ElementaryPID: 0x0234,
StreamType: StreamTypeAACAudio,
})
assert.NoError(t, err)

// Version number should change
err = muxer.generatePMT()
assert.NoError(t, err)
assert.Equal(t, MpegTsPacketSize, muxer.pmtBytes.Len())
assert.Equal(t, pmtExpectedBytesVideoAndAudio(1, 1), muxer.pmtBytes.Bytes())
assert.Equal(t, pmtExpectedBytesVideoAndAudio(1, 2), muxer.pmtBytes.Bytes())
}

func TestMuxer_WriteTables(t *testing.T) {
Expand All @@ -190,7 +205,7 @@ func TestMuxer_WriteTables(t *testing.T) {
assert.Equal(t, 2*MpegTsPacketSize, n)
assert.Equal(t, n, buf.Len())

expectedBytes := append(patExpectedBytes(0, 0), pmtExpectedBytesVideoOnly(0)...)
expectedBytes := append(patExpectedBytes(0, 0), pmtExpectedBytesVideoOnly(0, 0)...)
assert.Equal(t, expectedBytes, buf.Bytes())
}

Expand Down
11 changes: 7 additions & 4 deletions wrapping_counter.go
@@ -1,22 +1,25 @@
package astits

type wrappingCounter struct {
wrapAt int
value int
wrapAt int
}

func newWrappingCounter(wrapAt int) wrappingCounter {
return wrappingCounter{
value: wrapAt + 1,
wrapAt: wrapAt,
}
}

// returns current counter state and increments internal value
func (c *wrappingCounter) get() int {
ret := c.value
return c.value
}

func (c *wrappingCounter) inc() int {
c.value++
if c.value > c.wrapAt {
c.value = 0
}
return ret
return c.value
}

0 comments on commit 21cf918

Please sign in to comment.