Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,30 @@ jobs:
id: benchmark
run: |
echo "🚀 Running single file benchmark..."
# Run benchmark with ES2004a file and save results to JSON
swift run fluidaudio benchmark --auto-download --single-file ES2004a --output benchmark_results.json
swift run fluidaudio benchmark --auto-download --single-file ES2004a --output benchmark_results.json | tee benchmark.log

# Extract total time from CLI output
if grep -q "Total benchmark execution time:" benchmark.log; then
BENCHMARK_TIME=$(grep "Total benchmark execution time:" benchmark.log | grep -o '[0-9.]*')
echo "BENCHMARK_TIME=${BENCHMARK_TIME}" >> $GITHUB_OUTPUT
else
echo "BENCHMARK_TIME=NA" >> $GITHUB_OUTPUT
fi

# Extract key metrics from JSON output
if [ -f benchmark_results.json ]; then
# Parse JSON results (using basic tools available in GitHub runners)
AVERAGE_DER=$(cat benchmark_results.json | grep -o '"averageDER":[0-9]*\.?[0-9]*' | cut -d':' -f2)
AVERAGE_JER=$(cat benchmark_results.json | grep -o '"averageJER":[0-9]*\.?[0-9]*' | cut -d':' -f2)
PROCESSED_FILES=$(cat benchmark_results.json | grep -o '"processedFiles":[0-9]*' | cut -d':' -f2)

# Get first result details
RTF=$(cat benchmark_results.json | grep -o '"realTimeFactor":[0-9]*\.?[0-9]*' | head -1 | cut -d':' -f2)
DURATION=$(cat benchmark_results.json | grep -o '"durationSeconds":[0-9]*\.?[0-9]*' | head -1 | cut -d':' -f2)
SPEAKER_COUNT=$(cat benchmark_results.json | grep -o '"speakerCount":[0-9]*' | head -1 | cut -d':' -f2)

echo "DER=${AVERAGE_DER}" >> $GITHUB_OUTPUT
echo "JER=${AVERAGE_JER}" >> $GITHUB_OUTPUT
echo "RTF=${RTF}" >> $GITHUB_OUTPUT
echo "DURATION=${DURATION}" >> $GITHUB_OUTPUT
echo "SPEAKER_COUNT=${SPEAKER_COUNT}" >> $GITHUB_OUTPUT
Expand All @@ -61,54 +68,63 @@ jobs:
fi
timeout-minutes: 25

- name: Show benchmark_results.json
if: always()
run: |
echo "--- benchmark_results.json ---"
cat benchmark_results.json || echo "benchmark_results.json not found"
echo "-----------------------------"

- name: Extract benchmark metrics with jq
id: extract
run: |
DER=$(jq '.averageDER' benchmark_results.json)
JER=$(jq '.averageJER' benchmark_results.json)
RTF=$(jq '.results[0].realTimeFactor' benchmark_results.json)
DURATION=$(jq '.results[0].durationSeconds' benchmark_results.json)
SPEAKER_COUNT=$(jq '.results[0].speakerCount' benchmark_results.json)
echo "DER=${DER}" >> $GITHUB_OUTPUT
echo "JER=${JER}" >> $GITHUB_OUTPUT
echo "RTF=${RTF}" >> $GITHUB_OUTPUT
echo "DURATION=${DURATION}" >> $GITHUB_OUTPUT
echo "SPEAKER_COUNT=${SPEAKER_COUNT}" >> $GITHUB_OUTPUT

- name: Comment PR with Benchmark Results
if: always()
uses: actions/github-script@v7
with:
script: |
const success = '${{ steps.benchmark.outputs.SUCCESS }}' === 'true';
const der = parseFloat('${{ steps.extract.outputs.DER }}');
const jer = parseFloat('${{ steps.extract.outputs.JER }}');
const rtf = parseFloat('${{ steps.extract.outputs.RTF }}');
const duration = parseFloat('${{ steps.extract.outputs.DURATION }}').toFixed(1);
const speakerCount = '${{ steps.extract.outputs.SPEAKER_COUNT }}';
const benchmarkTime = '${{ steps.benchmark.outputs.BENCHMARK_TIME }}';

let comment = '## 🎯 Single File Benchmark Results\n\n';
comment += `**Test File:** ES2004a (${duration}s audio)\n\n`;
comment += '| Metric | Value | Target | Status |\n';
comment += '|--------|-------|--------|---------|\n';
comment += `| **DER** (Diarization Error Rate) | ${der.toFixed(1)}% | < 30% | ${der < 30 ? '✅' : '❌'} |\n`;
comment += `| **JER** (Jaccard Error Rate) | ${jer.toFixed(1)}% | < 25% | ${jer < 25 ? '✅' : '❌'} |\n`;
comment += `| **RTF** (Real-Time Factor) | ${rtf.toFixed(2)}x | < 1.0x | ${rtf < 1.0 ? '✅' : '❌'} |\n`;
comment += `| **Speakers Detected** | ${speakerCount} | - | ℹ️ |\n`;
comment += `| **Benchmark Runtime** | ${benchmarkTime}s | - | ℹ️ |\n\n`;

if (success) {
const der = parseFloat('${{ steps.benchmark.outputs.DER }}').toFixed(1);
const jer = parseFloat('${{ steps.benchmark.outputs.JER }}').toFixed(1);
const rtf = parseFloat('${{ steps.benchmark.outputs.RTF }}').toFixed(2);
const duration = parseFloat('${{ steps.benchmark.outputs.DURATION }}').toFixed(1);
const speakerCount = '${{ steps.benchmark.outputs.SPEAKER_COUNT }}';

comment += `**Test File:** ES2004a (${duration}s audio)\n\n`;
comment += '| Metric | Value | Target | Status |\n';
comment += '|--------|-------|--------|---------|\n';
comment += `| **DER** (Diarization Error Rate) | ${der}% | < 30% | ${der < 30 ? '✅' : '❌'} |\n`;
comment += `| **JER** (Jaccard Error Rate) | ${jer}% | < 25% | ${jer < 25 ? '✅' : '❌'} |\n`;
comment += `| **RTF** (Real-Time Factor) | ${rtf}x | < 1.0x | ${rtf < 1.0 ? '✅' : '❌'} |\n`;
comment += `| **Speakers Detected** | ${speakerCount} | - | ℹ️ |\n\n`;

// Performance assessment
if (der < 20) {
comment += '🎉 **Excellent Performance!** - Competitive with state-of-the-art research\n';
} else if (der < 30) {
comment += '✅ **Good Performance** - Meeting target benchmarks\n';
} else {
comment += '⚠️ **Performance Below Target** - Consider parameter optimization\n';
}

comment += '\n📊 **Research Comparison:**\n';
comment += '- Powerset BCE (2023): 18.5% DER\n';
comment += '- EEND (2019): 25.3% DER\n';
comment += '- x-vector clustering: 28.7% DER\n';

// Performance assessment
if (der < 20) {
comment += '🎉 **Excellent Performance!** - Competitive with state-of-the-art research\n';
} else if (der < 30) {
comment += '✅ **Good Performance** - Meeting target benchmarks\n';
} else {
comment += '❌ **Benchmark Failed**\n\n';
comment += 'The single file benchmark could not complete successfully. ';
comment += 'This may be due to:\n';
comment += '- Network issues downloading test data\n';
comment += '- Model initialization problems\n';
comment += '- Audio processing errors\n\n';
comment += 'Please check the workflow logs for detailed error information.';
comment += '⚠️ **Performance Below Target** - Consider parameter optimization\n';
}

comment += '\n📊 **Research Comparison:**\n';
comment += '- Powerset BCE (2023): 18.5% DER\n';
comment += '- EEND (2019): 25.3% DER\n';
comment += '- x-vector clustering: 28.7% DER\n';

comment += '\n\n---\n*Automated benchmark using AMI corpus ES2004a test file*';

github.rest.issues.createComment({
Expand Down
102 changes: 85 additions & 17 deletions Sources/DiarizationCLI/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct DiarizationCLI {
--debug Enable debug mode
--output <file> Output results to JSON file
--auto-download Automatically download dataset if not found

NOTE: Benchmark now uses real AMI manual annotations from Tests/ami_public_1.6.2/
If annotations are not found, falls back to simplified placeholder.

Expand All @@ -72,22 +72,24 @@ struct DiarizationCLI {
EXAMPLES:
# Download AMI datasets
swift run fluidaudio download --dataset ami-sdm

# Run AMI SDM benchmark with auto-download
swift run fluidaudio benchmark --auto-download

# Run benchmark with custom threshold and save results
swift run fluidaudio benchmark --threshold 0.8 --output results.json

# Process a single audio file
swift run fluidaudio process meeting.wav

# Process file with custom settings
swift run fluidaudio process meeting.wav --threshold 0.6 --output output.json
""")
}

static func runBenchmark(arguments: [String]) async {
let benchmarkStartTime = Date()

var dataset = "ami-sdm"
var threshold: Float = 0.7
var minDurationOn: Float = 1.0
Expand Down Expand Up @@ -189,6 +191,9 @@ struct DiarizationCLI {
print("💡 Supported datasets: ami-sdm, ami-ihm")
exit(1)
}

let benchmarkElapsed = Date().timeIntervalSince(benchmarkStartTime)
print("\n⏱️ Total benchmark execution time: \(String(format: "%.1f", benchmarkElapsed)) seconds")
}

static func downloadDataset(arguments: [String]) async {
Expand Down Expand Up @@ -781,11 +786,74 @@ struct DiarizationCLI {
/// Computes the Jaccard Error Rate (JER) between predicted and reference
/// diarization segments, as a percentage.
///
/// The timeline is discretized into 10 ms frames (the same frame size the DER
/// calculation uses, for consistency). Predicted speaker labels are first
/// mapped onto ground-truth labels with the Hungarian assignment produced by
/// `findOptimalSpeakerMapping`, then the frame-level Jaccard index is taken
/// over all frames where at least one side has an active speaker:
/// JER = (1 - |intersection| / |union|) * 100.
///
/// - Parameters:
///   - predicted: Hypothesis segments produced by the diarizer.
///   - groundTruth: Reference segments from the manual annotations.
/// - Returns: JER in [0, 100]. Returns 0 when both inputs are empty
///   (perfect match) and 100 when exactly one of them is empty.
static func calculateJaccardErrorRate(
    predicted: [TimedSpeakerSegment], groundTruth: [TimedSpeakerSegment]
) -> Float {
    // Degenerate cases: both empty is a perfect match; one empty is a total miss.
    if predicted.isEmpty && groundTruth.isEmpty {
        return 0.0
    } else if predicted.isEmpty || groundTruth.isEmpty {
        return 100.0
    }

    // Use the same frame size as the DER calculation for consistency.
    let frameSize: Float = 0.01
    let totalDuration = max(
        predicted.map { $0.endTimeSeconds }.max() ?? 0,
        groundTruth.map { $0.endTimeSeconds }.max() ?? 0
    )
    let totalFrames = Int(totalDuration / frameSize)

    // Optimal predicted → ground-truth label mapping (Hungarian algorithm),
    // reusing the same mapping logic the DER computation relies on.
    let speakerMapping = findOptimalSpeakerMapping(
        predicted: predicted,
        groundTruth: groundTruth,
        totalDuration: totalDuration
    )

    var intersectionFrames = 0
    var unionFrames = 0

    // Frame-by-frame comparison of reference vs. (mapped) hypothesis labels.
    for frame in 0..<totalFrames {
        let frameTime = Float(frame) * frameSize

        let gtSpeaker = findSpeakerAtTime(frameTime, in: groundTruth)
        let predSpeaker = findSpeakerAtTime(frameTime, in: predicted)

        // Translate the predicted label into ground-truth label space.
        let mappedPredSpeaker = predSpeaker.flatMap { speakerMapping[$0] }

        switch (gtSpeaker, mappedPredSpeaker) {
        case (nil, nil):
            // Both silent — contributes to neither intersection nor union.
            continue
        case (nil, _):
            // False alarm: prediction active while reference is silent.
            unionFrames += 1
        case (_, nil):
            // Missed speech: reference active while prediction is silent.
            unionFrames += 1
        case let (gt?, pred?):
            // Both active; only a matching (mapped) speaker counts toward
            // the intersection. A mismatch contributes to the union alone.
            unionFrames += 1
            if gt == pred {
                intersectionFrames += 1
            }
        }
    }

    // Jaccard index over active frames; JER is its complement, in percent.
    let jaccardIndex = unionFrames > 0 ? Float(intersectionFrames) / Float(unionFrames) : 0.0
    let jer = (1.0 - jaccardIndex) * 100.0

    // Both inputs are guaranteed non-empty here (early returns above), so the
    // debug line is emitted unconditionally.
    print("🔍 JER DEBUG: Intersection: \(intersectionFrames), Union: \(unionFrames), Jaccard Index: \(String(format: "%.3f", jaccardIndex)), JER: \(String(format: "%.1f", jer))%")

    return jer
}

static func findSpeakerAtTime(_ time: Float, in segments: [TimedSpeakerSegment]) -> String? {
Expand Down Expand Up @@ -833,7 +901,7 @@ struct DiarizationCLI {
// Find optimal assignment using Hungarian Algorithm for globally optimal solution
let predSpeakerArray = Array(predSpeakers).sorted() // Consistent ordering
let gtSpeakerArray = Array(gtSpeakers).sorted() // Consistent ordering

// Build numerical overlap matrix for Hungarian algorithm
var numericalOverlapMatrix: [[Int]] = []
for predSpeaker in predSpeakerArray {
Expand All @@ -843,35 +911,35 @@ struct DiarizationCLI {
}
numericalOverlapMatrix.append(row)
}

// Convert overlap matrix to cost matrix (higher overlap = lower cost)
let costMatrix = HungarianAlgorithm.overlapToCostMatrix(numericalOverlapMatrix)

// Solve optimal assignment
let assignments = HungarianAlgorithm.minimumCostAssignment(costs: costMatrix)

// Create speaker mapping from Hungarian result
var mapping: [String: String] = [:]
var totalAssignmentCost: Float = 0
var totalOverlap = 0

for (predIndex, gtIndex) in assignments.assignments.enumerated() {
if gtIndex != -1 && predIndex < predSpeakerArray.count && gtIndex < gtSpeakerArray.count {
let predSpeaker = predSpeakerArray[predIndex]
let gtSpeaker = gtSpeakerArray[gtIndex]
let overlap = overlapMatrix[predSpeaker]![gtSpeaker]!

if overlap > 0 { // Only assign if there's actual overlap
mapping[predSpeaker] = gtSpeaker
totalOverlap += overlap
print("🔍 HUNGARIAN MAPPING: '\(predSpeaker)' → '\(gtSpeaker)' (overlap: \(overlap) frames)")
}
}
}

totalAssignmentCost = assignments.totalCost
print("🔍 HUNGARIAN RESULT: Total assignment cost: \(String(format: "%.1f", totalAssignmentCost)), Total overlap: \(totalOverlap) frames")

// Handle unassigned predicted speakers
for predSpeaker in predSpeakerArray {
if mapping[predSpeaker] == nil {
Expand Down