Skip to content

Commit

Permalink
Filter NALs globally (#1403)
Browse files Browse the repository at this point in the history
* Linux: Remove filter_NAL and avoid copy

* Rework NAL filtering

* Free packet before getting new one (@nowrep suggestion)

* Remove unnecessary condition from `extractHeaders()`

* Fix decoder init and properly offset the pointers

* Use AMFBufferPtr instead of composing FramePacket

* Add boundary checks
  • Loading branch information
deiteris committed Jan 28, 2023
1 parent 0734d61 commit 58093df
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 131 deletions.
130 changes: 82 additions & 48 deletions alvr/server/cpp/alvr_server/ClientConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,65 +8,99 @@
#include "Utils.h"
#include "Settings.h"

static const uint8_t NAL_TYPE_SPS = 7;
static const char NAL_HEADER[] = {0x00, 0x00, 0x00, 0x01};

static const uint8_t H264_NAL_TYPE_SPS = 7;
static const uint8_t H265_NAL_TYPE_VPS = 32;

ClientConnection::ClientConnection() {
m_Statistics = std::make_shared<Statistics>();
static const uint8_t H264_NAL_TYPE_AUD = 9;
static const uint8_t H265_NAL_TYPE_AUD = 35;

// Constructs the connection and allocates the shared Statistics instance
// that is stored in m_Statistics.
ClientConnection::ClientConnection() {
m_Statistics = std::make_shared<Statistics>();
}

int findVPSSPS(const uint8_t *frameBuffer, int frameByteSize) {
int zeroes = 0;
int foundNals = 0;
for (int i = 0; i < frameByteSize; i++) {
if (frameBuffer[i] == 0) {
zeroes++;
} else if (frameBuffer[i] == 1) {
if (zeroes >= 2) {
foundNals++;
if (Settings::Instance().m_codec == ALVR_CODEC_H264 && foundNals >= 3) {
// Find end of SPS+PPS on H.264.
return i - 3;
} else if (Settings::Instance().m_codec == ALVR_CODEC_H265 && foundNals >= 4) {
// Find end of VPS+SPS+PPS on H.265.
return i - 3;
}
}
zeroes = 0;
} else {
zeroes = 0;
}
}
return -1;
/*
  Sends the (VPS + )SPS + PPS video configuration headers from an H.264 or H.265
  stream to InitializeDecoder() as one contiguous sequence of NALs.
  (VPS + )SPS + PPS have a short size (8 bytes + 28 bytes in some environments), so we
  can assume they are contained in the first fragment.

  buf    - in/out: points at the start code of the first config NAL; on success it is
           advanced to the start code of the first NAL after the config NALs.
  len    - in/out: number of bytes available at *buf; on success it is reduced by the
           size of the config NALs.
  nalNum - number of config NALs expected (2 for SPS + PPS, 3 for VPS + SPS + PPS).

  The end of the headers is delimited by the start code of the NAL that follows them,
  so nalNum + 1 start codes must be present. If they are not, nothing is sent and
  *buf / *len are left untouched.
*/
void sendHeaders(uint8_t **buf, int *len, int nalNum) {
    uint8_t *const start = *buf;
    const int total = *len;
    const int prefixLen = static_cast<int>(sizeof(NAL_HEADER));

    if (start == nullptr || total <= 0) {
        return;
    }

    int headersLen = 0;
    int foundHeaders = -1; // Offset by 1 header to find the length until the next header

    // Scan by offset and stop as soon as the offset reaches or passes the end. A start
    // code located in the last bytes of the buffer advances the offset by
    // prefixLen + 1, i.e. one past the end, so an equality-only end test would never
    // terminate and would read out of bounds.
    while (headersLen < total) {
        const bool atStartCode =
            headersLen + prefixLen <= total &&
            memcmp(start + headersLen, NAL_HEADER, sizeof(NAL_HEADER)) == 0;
        if (atStartCode) {
            foundHeaders++;
            if (foundHeaders == nalNum) {
                // headersLen now spans exactly the config NALs.
                break;
            }
            headersLen += prefixLen;
        }

        headersLen++;
    }
    if (foundHeaders != nalNum) {
        return;
    }
    InitializeDecoder(reinterpret_cast<const unsigned char *>(start), headersLen);

    // move the cursor forward excluding config NALs
    *buf = start + headersLen;
    *len = total - headersLen;
}

/*
  Inspects the first NAL of an H.264 frame and strips the parts that must not be
  forwarded with the video data.

  - If the first NAL is an access unit delimiter (AUD) and more data follows, the
    AUD (start code + 2 bytes) is skipped.
  - If the (next) NAL is an SPS, the SPS + PPS headers are handed to sendHeaders(),
    which also moves the cursor past them.

  buf / len are in/out: they describe the frame on entry and the remaining video
  data on return. The frame is expected to start with a 4-byte start code.
*/
void processH264Nals(uint8_t **buf, int *len) {
    const int prefixLen = static_cast<int>(sizeof(NAL_HEADER));
    const int audNalLen = prefixLen + 2; // start code + 2 bytes of AUD NAL

    uint8_t *b = *buf;
    int l = *len;

    // The NAL header byte lives right after the start code; without it there is
    // nothing to classify (and reading it would run past the buffer).
    if (b == nullptr || l <= prefixLen) {
        return;
    }
    uint8_t nalType = b[prefixLen] & 0x1F;

    // Only skip the AUD when a further start code plus NAL header byte fits behind it.
    if (nalType == H264_NAL_TYPE_AUD && l > prefixLen + audNalLen) {
        b += audNalLen;
        l -= audNalLen;
        nalType = b[prefixLen] & 0x1F;
    }
    if (nalType == H264_NAL_TYPE_SPS) {
        sendHeaders(&b, &l, 2); // 2 headers SPS and PPS
    }
    *buf = b;
    *len = l;
}

/*
  Inspects the first NAL of an H.265 frame and strips the parts that must not be
  forwarded with the video data.

  - If the first NAL is an access unit delimiter (AUD) and more data follows, the
    AUD (start code + 3 bytes) is skipped.
  - If the (next) NAL is a VPS, the VPS + SPS + PPS headers are handed to
    sendHeaders(), which also moves the cursor past them.

  buf / len are in/out: they describe the frame on entry and the remaining video
  data on return. The frame is expected to start with a 4-byte start code.
*/
void processH265Nals(uint8_t **buf, int *len) {
    const int prefixLen = static_cast<int>(sizeof(NAL_HEADER));
    const int audNalLen = prefixLen + 3; // start code + 3 bytes of AUD NAL

    uint8_t *b = *buf;
    int l = *len;

    // The first NAL header byte lives right after the start code; without it there
    // is nothing to classify (and reading it would run past the buffer).
    if (b == nullptr || l <= prefixLen) {
        return;
    }
    // H.265 keeps the NAL type in bits 1..6 of the first header byte.
    uint8_t nalType = (b[prefixLen] >> 1) & 0x3F;

    // Only skip the AUD when a further start code plus NAL header byte fits behind it.
    if (nalType == H265_NAL_TYPE_AUD && l > prefixLen + audNalLen) {
        b += audNalLen;
        l -= audNalLen;
        nalType = (b[prefixLen] >> 1) & 0x3F;
    }
    if (nalType == H265_NAL_TYPE_VPS) {
        sendHeaders(&b, &l, 3); // 3 headers VPS, SPS and PPS
    }
    *buf = b;
    *len = l;
}

void ClientConnection::SendVideo(uint8_t *buf, int len, uint64_t targetTimestampNs) {
// Report before the frame is packetized
ReportEncoded(targetTimestampNs);

uint8_t NALType;
if (Settings::Instance().m_codec == ALVR_CODEC_H264)
NALType = buf[4] & 0x1F;
else
NALType = (buf[4] >> 1) & 0x3F;

if ((Settings::Instance().m_codec == ALVR_CODEC_H264 && NALType == NAL_TYPE_SPS) ||
(Settings::Instance().m_codec == ALVR_CODEC_H265 && NALType == H265_NAL_TYPE_VPS)) {
// This frame contains (VPS + )SPS + PPS + IDR on NVENC H.264 (H.265) stream.
// (VPS + )SPS + PPS has short size (8bytes + 28bytes in some environment), so we can
// assume SPS + PPS is contained in first fragment.

int end = findVPSSPS(buf, len);
if (end == -1) {
// Invalid frame.
return;
}

InitializeDecoder((const unsigned char *)buf, end);
if (len < sizeof(NAL_HEADER)) {
return;
}

// move the cursor forward excluding config NALs
buf = &buf[end];
len = len - end;
int codec = Settings::Instance().m_codec;
if (codec == ALVR_CODEC_H264) {
processH264Nals(&buf, &len);
} else if (codec == ALVR_CODEC_H265) {
processH265Nals(&buf, &len);
}

VideoSend(targetTimestampNs, buf, len);
Expand Down
9 changes: 3 additions & 6 deletions alvr/server/cpp/platform/linux/CEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ void CEncoder::Run() {

fprintf(stderr, "CEncoder starting to read present packets");
present_packet frame_info;
std::vector<uint8_t> encoded_data;
while (not m_exiting) {
read_latest(client, (char *)&frame_info, sizeof(frame_info), m_exiting);

Expand All @@ -250,9 +249,8 @@ void CEncoder::Run() {

static_assert(sizeof(frame_info.pose) == sizeof(vr::HmdMatrix34_t&));

encoded_data.clear();
uint64_t pts;
if (!encode_pipeline->GetEncoded(encoded_data, &pts)) {
alvr::FramePacket packet;
if (!encode_pipeline->GetEncoded(packet)) {
Error("Failed to get encoded data!");
continue;
}
Expand All @@ -279,10 +277,9 @@ void CEncoder::Run() {
ReportPresent(pose->targetTimestampNs, present_offset);
ReportComposed(pose->targetTimestampNs, composed_offset);

m_listener->SendVideo(encoded_data.data(), encoded_data.size(), pts);
m_listener->SendVideo(packet.data, packet.size, packet.pts);

m_listener->GetStatistics()->EncodeOutput();

}
}
catch (std::exception &e) {
Expand Down
74 changes: 12 additions & 62 deletions alvr/server/cpp/platform/linux/EncodePipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,6 @@ extern "C" {
#include <libavcodec/avcodec.h>
}

namespace {

// Decides whether an H.264 NAL should be kept in the output stream.
// header_start points at the NAL's start code, which may be the 3-byte (00 00 01)
// or the 4-byte (00 00 00 01) form. Returns false for SEI (type 6) and access unit
// delimiter (type 9) NALs, true for everything else.
bool should_keep_nal_h264(const uint8_t * header_start)
{
    // A zero third byte means a 4-byte start code, so the NAL header byte sits at
    // index 4; otherwise the start code is 3 bytes long and the header is at index 3.
    const int nal_header_index = (header_start[2] == 0) ? 4 : 3;
    const uint8_t nal_type = header_start[nal_header_index] & 0x1F;

    const bool is_sei = (nal_type == 6); // supplemental enhancement information
    const bool is_aud = (nal_type == 9); // access unit delimiter
    return !(is_sei || is_aud);
}

// Decides whether an H.265 NAL should be kept in the output stream.
// header_start points at the NAL's start code, which may be the 3-byte (00 00 01)
// or the 4-byte (00 00 00 01) form. Returns false for access unit delimiter
// (type 35) and SEI (type 39) NALs, true for everything else.
bool should_keep_nal_h265(const uint8_t * header_start)
{
    // A zero third byte means a 4-byte start code, so the NAL header byte sits at
    // index 4; otherwise the start code is 3 bytes long and the header is at index 3.
    const int nal_header_index = (header_start[2] == 0) ? 4 : 3;
    // The NAL type occupies bits 1..6 of the first header byte.
    const uint8_t nal_type = (header_start[nal_header_index] >> 1) & 0x3F;

    const bool is_aud = (nal_type == 35); // access unit delimiter
    const bool is_sei = (nal_type == 39); // supplemental enhancement information
    return !(is_aud || is_sei);
}

// Copies the NALs of an Annex B byte stream into `out`, dropping those rejected by
// should_keep_nal_h264() / should_keep_nal_h265() for the codec selected in Settings.
// `input` must begin with a start code; inputs shorter than 4 bytes are ignored.
// Kept NALs are appended to `out` together with their start code.
void filter_NAL(const uint8_t* input, size_t input_size, std::vector<uint8_t> &out)
{
    if (input_size < 4) {
        return;
    }

    const auto codec = Settings::Instance().m_codec;
    const std::array<uint8_t, 3> start_code = {{0, 0, 1}};
    const uint8_t *const stream_end = input + input_size;

    for (const uint8_t *nal_begin = input; nal_begin != stream_end;) {
        // Look for the next start code after the one that opens the current NAL.
        const uint8_t *nal_end =
            std::search(nal_begin + 3, stream_end, start_code.begin(), start_code.end());
        // A preceding zero byte belongs to a 4-byte start code of the next NAL.
        if (nal_end != stream_end && nal_end[-1] == 0) {
            --nal_end;
        }

        if (codec == ALVR_CODEC_H264 && should_keep_nal_h264(nal_begin)) {
            out.insert(out.end(), nal_begin, nal_end);
        }
        if (codec == ALVR_CODEC_H265 && should_keep_nal_h265(nal_begin)) {
            out.insert(out.end(), nal_begin, nal_end);
        }

        nal_begin = nal_end;
    }
}

}

void alvr::EncodePipeline::SetBitrate(int64_t bitrate) {
encoder_ctx->bit_rate = bitrate;
encoder_ctx->rc_buffer_size = bitrate / Settings::Instance().m_refreshRate * 1.1;
Expand Down Expand Up @@ -112,17 +59,20 @@ alvr::EncodePipeline::~EncodePipeline()
avcodec_free_context(&encoder_ctx);
}

bool alvr::EncodePipeline::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
bool alvr::EncodePipeline::GetEncoded(FramePacket &packet)
{
AVPacket * enc_pkt = av_packet_alloc();
int err = avcodec_receive_packet(encoder_ctx, enc_pkt);
if (err == AVERROR(EAGAIN)) {
return false;
} else if (err) {
av_packet_free(&encoder_packet);
encoder_packet = av_packet_alloc();
int err = avcodec_receive_packet(encoder_ctx, encoder_packet);
if (err != 0) {
av_packet_free(&encoder_packet);
if (err == AVERROR(EAGAIN)) {
return false;
}
throw alvr::AvException("failed to encode", err);
}
filter_NAL(enc_pkt->data, enc_pkt->size, out);
*pts = enc_pkt->pts;
av_packet_free(&enc_pkt);
packet.data = encoder_packet->data;
packet.size = encoder_packet->size;
packet.pts = encoder_packet->pts;
return true;
}
10 changes: 9 additions & 1 deletion alvr/server/cpp/platform/linux/EncodePipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <vector>

extern "C" struct AVCodecContext;
extern "C" struct AVPacket;

class Renderer;

Expand All @@ -14,6 +15,12 @@ class VkFrame;
class VkFrameCtx;
class VkContext;

// Describes one encoded frame handed out by EncodePipeline::GetEncoded().
// The struct only carries a pointer, a byte count and a timestamp; it does not
// free `data`.
// NOTE(review): the GetEncoded() implementations appear to keep the backing buffer
// alive only until their next call — confirm before holding on to `data`.
struct FramePacket {
    uint8_t *data = nullptr; // start of the encoded bitstream, null until filled in
    int size = 0;            // number of valid bytes at `data`
    uint64_t pts = 0;        // presentation timestamp reported by the encoder
};

class EncodePipeline
{
public:
Expand All @@ -25,13 +32,14 @@ class EncodePipeline
virtual ~EncodePipeline();

virtual void PushFrame(uint64_t targetTimestampNs, bool idr) = 0;
virtual bool GetEncoded(std::vector<uint8_t> & out, uint64_t *pts);
virtual bool GetEncoded(FramePacket &data);
virtual Timestamp GetTimestamp() { return timestamp; }

virtual void SetBitrate(int64_t bitrate);
static std::unique_ptr<EncodePipeline> Create(Renderer *render, VkContext &vk_ctx, VkFrame &input_frame, VkFrameCtx &vk_frame_ctx, uint32_t width, uint32_t height);
protected:
AVCodecContext *encoder_ctx = nullptr; //shall be initialized by child class
AVPacket *encoder_packet = NULL;
Timestamp timestamp = {};
};

Expand Down
20 changes: 8 additions & 12 deletions alvr/server/cpp/platform/linux/EncodePipelineAMF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,26 +456,27 @@ void EncodePipelineAMF::PushFrame(uint64_t targetTimestampNs, bool idr)
m_amfComponents.front()->SubmitInput(surface);
}

bool EncodePipelineAMF::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
bool EncodePipelineAMF::GetEncoded(FramePacket &packet)
{
m_frameBuffer = NULL;
if (m_hasQueryTimeout) {
m_pipeline->Run();
} else {
uint32_t timeout = 4 * 1000; // 1 second
while (m_outBuffer.empty() && --timeout != 0) {
while (m_frameBuffer == NULL && --timeout != 0) {
std::this_thread::sleep_for(std::chrono::microseconds(250));
m_pipeline->Run();
}
}

if (m_outBuffer.empty()) {
if (m_frameBuffer == NULL) {
Error("Timed out waiting for encoder data");
return false;
}

out = m_outBuffer;
*pts = m_targetTimestampNs;
m_outBuffer.clear();
packet.data = reinterpret_cast<uint8_t *>(m_frameBuffer->GetNative());
packet.size = static_cast<int>(m_frameBuffer->GetSize());
packet.pts = m_targetTimestampNs;

uint64_t query;
VK_CHECK(vkGetQueryPoolResults(m_render->m_dev, m_queryPool, 0, 1, sizeof(uint64_t), &query, sizeof(uint64_t), VK_QUERY_RESULT_64_BIT));
Expand All @@ -499,12 +500,7 @@ void EncodePipelineAMF::SetBitrate(int64_t bitrate)

void EncodePipelineAMF::Receive(amf::AMFDataPtr data)
{
amf::AMFBufferPtr buffer(data); // query for buffer interface

char *p = reinterpret_cast<char*>(buffer->GetNative());
int length = static_cast<int>(buffer->GetSize());

m_outBuffer = std::vector<uint8_t>(p, p + length);
m_frameBuffer = amf::AMFBufferPtr(data); // query for buffer interface
}

void EncodePipelineAMF::ApplyFrameProperties(const amf::AMFSurfacePtr &surface, bool insertIDR)
Expand Down
4 changes: 2 additions & 2 deletions alvr/server/cpp/platform/linux/EncodePipelineAMF.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class EncodePipelineAMF : public EncodePipeline
~EncodePipelineAMF();

void PushFrame(uint64_t targetTimestampNs, bool idr) override;
bool GetEncoded(std::vector<uint8_t> &out, uint64_t *pts) override;
bool GetEncoded(FramePacket &packet) override;
void SetBitrate(int64_t bitrate) override;

private:
Expand Down Expand Up @@ -96,7 +96,7 @@ class EncodePipelineAMF : public EncodePipeline
int m_bitrateInMBits;

bool m_hasQueryTimeout = false;
std::vector<uint8_t> m_outBuffer;
amf::AMFBufferPtr m_frameBuffer;
uint64_t m_targetTimestampNs;
};

Expand Down

0 comments on commit 58093df

Please sign in to comment.