In [1]:
import sys
sys.path.append('../')
from functions.functions import count_SIR, STFT1024, predict_5class_pilot, calculate_SIR_for_frames, create_ideal_pilot_from_SIR, get_SIR_of_iFICA_result
from classes.hybrid_system import HybridSystem_IPD17, HybridSystem_IPD35, HybridSystem_IPD14_IPD17, HybridSystem_IPD14_IPD17_IPD35, HybridSystem_ILD17, HybridSystem_ILD35, HybridSystem_ILD14_ILD17, HybridSystem_ILD14_ILD17_ILD35, HybridSystem_IPD17_ILD17

import numpy as np
import soundfile as sf
import pandas as pd
import plotly.graph_objects as go


In [2]:
# Number of test samples to evaluate: the evaluation loops read
# ../data/test/s{i}.wav and ../data/test/y{i}.wav for i in range(number_of_samples).
number_of_samples = 600

## Orientační hodnoty ostatních pilotů pro testovací data

In [None]:
def get_SIR_for_basic_pilots(number_of_samples, data_dir='../data/test', nr_of_mics=9):
    """Evaluate iFICA extraction on the test set with the reference (non-model) pilots.

    For every sample ``i`` the signal of interest ``s{i}.wav`` and the
    interference ``y{i}.wav`` are read from ``data_dir`` and summed into a
    mixture. Extraction is then run via ``get_SIR_of_iFICA_result`` with four
    pilots, in this order:

    * ``SOI``   -- energy of channel 0 of the SOI STFT (ideal pilot),
    * ``label`` -- pilot built by ``create_ideal_pilot_from_SIR`` from the
      per-frame SIR returned by ``calculate_SIR_for_frames``,
    * ``ones``  -- all-ones pilot (no guidance, i.e. blind extraction),
    * ``X``     -- energy of channel 0 of the mixture STFT.

    Args:
        number_of_samples: How many samples (``s0..s{n-1}`` / ``y0..y{n-1}``)
            to process.
        data_dir: Directory with the test wav files (default is the
            previously hard-coded ``'../data/test'``).
        nr_of_mics: Microphone count forwarded to ``get_SIR_of_iFICA_result``
            (default is the previously hard-coded 9).

    Returns:
        ``pandas.DataFrame`` with one row per sample and the columns
        ``SIR_original`` (SIR of the unprocessed mixture from ``count_SIR``)
        followed, for each pilot ``<p>`` in the order above, by
        ``SIR_<p>_pilot`` (SIR after extraction) and ``imp_SIR_<p>_pilot``
        (``SIR_<p>_pilot - SIR_original``).

    Raises:
        ValueError: If ``s{i}.wav`` and ``y{i}.wav`` have different sample
            rates -- summing them would give a meaningless mixture.
    """
    pilot_names = ('SOI', 'label', 'ones', 'X')
    # (n, 1) buffers, as before, so that whatever count_SIR /
    # get_SIR_of_iFICA_result return (scalar or size-1 array) can be stored
    # row by row.
    SIR_original = np.zeros((number_of_samples, 1))
    SIR_pilot = {name: np.zeros((number_of_samples, 1)) for name in pilot_names}

    for i in range(number_of_samples):
        s, fs_s = sf.read(f'{data_dir}/s{i}.wav')
        y, fs_y = sf.read(f'{data_dir}/y{i}.wav')
        if fs_s != fs_y:
            raise ValueError(
                f'Sample rate mismatch for sample {i}: '
                f's{i}.wav has {fs_s} Hz, y{i}.wav has {fs_y} Hz'
            )

        mix_signal = s + y
        SIR_original[i], _ = count_SIR(s, y)

        # STFT1024 is applied to the transposed signals and its output is
        # reordered with transpose(0, 2, 1); axis meaning is defined by
        # STFT1024 -- presumably (channels, frames, freqs), TODO confirm.
        X = STFT1024(mix_signal.T).transpose(0, 2, 1)
        SOI = STFT1024(s.T).transpose(0, 2, 1)
        INT = STFT1024(y.T).transpose(0, 2, 1)
        # Unpacking also asserts the STFT result is 3-D.
        _, _, K = X.shape

        # Channel-0 energy summed over the last STFT axis.
        soi_energy = np.sum(np.abs(SOI[0, :, :])**2, axis=1)
        sir_values = calculate_SIR_for_frames(SOI.transpose(0, 2, 1), INT.transpose(0, 2, 1))
        pilots = {
            'SOI': soi_energy,
            'label': create_ideal_pilot_from_SIR(sir_values),
            'ones': np.ones_like(soi_energy),
            'X': np.sum(np.abs(X[0, :, :])**2, axis=1),
        }
        # Extraction is run in the fixed pilot order above.
        for name in pilot_names:
            SIR_pilot[name][i] = get_SIR_of_iFICA_result(nr_of_mics, K, X, SOI, INT, pilots[name])

    columns = {'SIR_original': SIR_original.flatten()}
    for name in pilot_names:
        columns[f'SIR_{name}_pilot'] = SIR_pilot[name].flatten()
        columns[f'imp_SIR_{name}_pilot'] = (SIR_pilot[name] - SIR_original).flatten()
    return pd.DataFrame(columns)


In [None]:
# Expensive: runs iFICA extraction four times per test sample.
SIR_df = get_SIR_for_basic_pilots(number_of_samples=number_of_samples)

In [6]:
# Cache the baseline-pilot results so they can be reloaded without recomputation.
SIR_df.to_csv(path_or_buf='../results/SIR_test.csv', index=True)

### Načtení již spočítaných hodnot

In [3]:
# Reload cached results; the first CSV column is the index written by to_csv.
SIR_df = pd.read_csv(filepath_or_buffer='../results/SIR_test.csv', index_col=0)

## Hybridní systémy

In [4]:
def get_improvement_for_hybrid_system(hybrid_system, number_of_samples, data_dir='../data/test'):
    """Evaluate a hybrid system's speaker extraction on every test sample.

    For each sample ``i`` the SOI ``s{i}.wav`` and interference ``y{i}.wav``
    are read from ``data_dir``, summed into a mixture and passed to
    ``hybrid_system.evaluate_extraction_of_main_speaker``. The second and
    third values it returns are stored as the SIR and the SIR improvement.

    Args:
        hybrid_system: Object exposing ``evaluate_extraction_of_main_speaker(
            mixture=..., soi=..., interference=...)`` returning a 3-tuple.
        number_of_samples: How many samples to process.
        data_dir: Directory with the test wav files (default is the
            previously hard-coded ``'../data/test'``).

    Returns:
        Tuple ``(SIR_improvement_model_pilot, SIR_model_pilot)`` of
        ``(number_of_samples, 1)`` arrays. Note the improvement comes FIRST,
        the reverse of the order returned by the evaluation method.

    Raises:
        ValueError: If ``s{i}.wav`` and ``y{i}.wav`` have different sample rates.
    """
    SIR_model_pilot = np.zeros((number_of_samples, 1))
    SIR_improvement_model_pilot = np.zeros((number_of_samples, 1))

    for i in range(number_of_samples):
        s, fs_s = sf.read(f'{data_dir}/s{i}.wav')
        y, fs_y = sf.read(f'{data_dir}/y{i}.wav')
        if fs_s != fs_y:
            raise ValueError(
                f'Sample rate mismatch for sample {i}: '
                f's{i}.wav has {fs_s} Hz, y{i}.wav has {fs_y} Hz'
            )

        mix_signal = s + y
        _, SIR_model_pilot[i], SIR_improvement_model_pilot[i] = hybrid_system.evaluate_extraction_of_main_speaker(
            mixture=mix_signal, soi=s, interference=y
        )

    return SIR_improvement_model_pilot, SIR_model_pilot


In [5]:
# IPD 17 model: evaluate on the test set, store absolute SIR and improvement.
hybrid_system = HybridSystem_IPD17(model_path="../models/ipd17_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(
    hybrid_system=hybrid_system,
    number_of_samples=number_of_samples,
)

SIR_df['SIR_ipd17_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ipd17_model_pilot'] = SIR_imp_model_pilot

In [None]:
# IPD 35 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_IPD35(model_path="../models/ipd35_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ipd35_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ipd35_model_pilot'] = SIR_imp_model_pilot

In [None]:
# IPD 14+17 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_IPD14_IPD17(model_path="../models/ipd14_ipd17_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ipd14_ipd17_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ipd14_ipd17_model_pilot'] = SIR_imp_model_pilot

In [None]:
# IPD 17 + ILD 17 model. Model paths are relative to the parent directory,
# like '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_IPD17_ILD17(model_path="../models/ipd17_ild17_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ipd17_ild17_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ipd17_ild17_model_pilot'] = SIR_imp_model_pilot

In [None]:
# IPD 14+17+35 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_IPD14_IPD17_IPD35(model_path="../models/ipd14_ipd17_ipd35_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ipd14_ipd17_ipd35_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ipd14_ipd17_ipd35_model_pilot'] = SIR_imp_model_pilot

In [None]:
# ILD 17 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_ILD17(model_path="../models/ild17_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ild17_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ild17_model_pilot'] = SIR_imp_model_pilot

In [None]:
# ILD 35 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_ILD35(model_path="../models/ild35_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ild35_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ild35_model_pilot'] = SIR_imp_model_pilot

In [None]:
# ILD 14+17 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_ILD14_ILD17(model_path="../models/ild14_ild17_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ild14_ild17_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ild14_ild17_model_pilot'] = SIR_imp_model_pilot

In [None]:
# ILD 14+17+35 model. Model paths are relative to the parent directory, like
# '../data', '../results' and '../models/ipd17_model.pt' in this notebook.
hybrid_system = HybridSystem_ILD14_ILD17_ILD35(model_path="../models/ild14_ild17_ild35_model.pt")
SIR_imp_model_pilot, SIR_model_pilot = get_improvement_for_hybrid_system(hybrid_system, number_of_samples)

SIR_df['SIR_ild14_ild17_ild35_model_pilot'] = SIR_model_pilot
SIR_df['imp_SIR_ild14_ild17_ild35_model_pilot'] = SIR_imp_model_pilot

### Uložení nových hodnot

In [13]:
# Overwrite the cache with the baseline columns plus all model columns.
SIR_df.to_csv(path_or_buf='../results/SIR_test.csv', index=True)

## Vyhodnocení

In [7]:
# Reference pilots: the two ideal pilots first, then the naive baselines.
print(f"Průměrné zlepšení s pilotem ze SOI {np.mean(SIR_df['imp_SIR_SOI_pilot'])}")
print(f"Průměrné zlepšení s labelovým pilotem {np.mean(SIR_df['imp_SIR_label_pilot'])}")
print("-----------------------------------------------------")
print(f"Průměrné zlepšení jedničkový pilot {np.mean(SIR_df['imp_SIR_ones_pilot'])}")
print(f"Průměrné zlepšení s pilotem z X {np.mean(SIR_df['imp_SIR_X_pilot'])}")
print("-----------------------------------------------------")

# Pilots predicted by the trained models, one line per model.
print("Průměrné zlepšení s modely:")
for model_name, column in (
    ("ipd17", 'imp_SIR_ipd17_model_pilot'),
    ("ipd35", 'imp_SIR_ipd35_model_pilot'),
    ("ipd14_ipd17", 'imp_SIR_ipd14_ipd17_model_pilot'),
    ("ipd17_ild17", 'imp_SIR_ipd17_ild17_model_pilot'),
    ("ipd14_ipd17_ipd35", 'imp_SIR_ipd14_ipd17_ipd35_model_pilot'),
    ("ild17", 'imp_SIR_ild17_model_pilot'),
    ("ild35", 'imp_SIR_ild35_model_pilot'),
    ("ild14_ild17", 'imp_SIR_ild14_ild17_model_pilot'),
    ("ild14_ild17_ild35", 'imp_SIR_ild14_ild17_ild35_model_pilot'),
):
    print(f"{model_name} {np.mean(SIR_df[column])}")

Průměrné zlepšení s pilotem ze SOI 6.131486621507868
Průměrné zlepšení s labelovým pilotem 6.426319148799491
-----------------------------------------------------
Průměrné zlepšení jedničkový pilot 3.70768234826794
Průměrné zlepšení s pilotem z X 4.017802264588616
-----------------------------------------------------
Průměrné zlepšení s modely:
ipd17 5.354334191029949
ipd35 4.037910934117927
ipd14_ipd17 5.463142009455973
ipd17_ild17 5.446696070798598
ipd14_ipd17_ipd35 5.520953698495893
ild17 5.285119870577854
ild35 5.223958778697805
ild14_ild17 5.433549850162899
ild14_ild17_ild35 5.475712734263329


In [8]:
# One box per trained model: distribution of the SIR improvement obtained
# with the pilot predicted by that model (test data). The comprehension
# replaces nine copy-pasted add_trace calls; trace order is unchanged.
fig_box = go.Figure(data=[
    go.Box(y=SIR_df[column], name=label)
    for column, label in [
        ("imp_SIR_ipd17_model_pilot", "Model IPD 17"),
        ("imp_SIR_ipd35_model_pilot", "Model IPD 35"),
        ("imp_SIR_ipd14_ipd17_model_pilot", "Model IPD 14+17"),
        ("imp_SIR_ipd14_ipd17_ipd35_model_pilot", "Model IPD 14+17+35"),
        ("imp_SIR_ild17_model_pilot", "Model ILD 17"),
        ("imp_SIR_ild35_model_pilot", "Model ILD 35"),
        ("imp_SIR_ild14_ild17_model_pilot", "Model ILD 14+17"),
        ("imp_SIR_ild14_ild17_ild35_model_pilot", "Model ILD 14+17+35"),
        ("imp_SIR_ipd17_ild17_model_pilot", "Model IPD 17 + ILD 17"),
    ]
])

fig_box.update_layout(
    title=dict(
        text="Zlepšení SIR extrahovaného signálu s pilotem z natrénovaných modelů (testovací data)",
        font=dict(size=20)
    ),
    xaxis=dict(
        # "Model pro tvorbu pilotu" = model used to create the pilot
        # (previously misspelled "tvrobu").
        title=dict(text="Model pro tvorbu pilotu", font=dict(size=18)),
        tickfont=dict(size=16)
    ),
    yaxis=dict(
        title=dict(text="Zlepšení SIR [dB]", font=dict(size=18)),
        tickfont=dict(size=16)
    ),
    showlegend=False,
    height=800,
    width=1000
)

fig_box.show()

In [9]:
# Ideal pilots vs. the best model pilot vs. naive pilots, by SIR improvement.
fig_box = go.Figure(data=[
    go.Box(y=SIR_df[column], name=label)
    for column, label in (
        ("imp_SIR_SOI_pilot", "Ideální pilot (energie SOI)"),
        ("imp_SIR_label_pilot", "Ideální pilot (intenzity, labely)"),
        ("imp_SIR_ild14_ild17_ild35_model_pilot", "Pilot z modelu (ILD 14+17+35)"),
        ("imp_SIR_X_pilot", "Pilot z energie původního signálu"),
        ("imp_SIR_ones_pilot", "Pilot z 1 (slepá extrakce)"),
    )
])

fig_box.update_layout(
    title={
        "text": "Zlepšení SIR extrahovaného signálu s různými piloty (testovací data)",
        "font": {"size": 20},
    },
    xaxis={
        "title": {"text": "Typ pilotu", "font": {"size": 18}},
        "tickfont": {"size": 16},
    },
    yaxis={
        "title": {"text": "Zlepšení SIR [dB]", "font": {"size": 18}},
        "tickfont": {"size": 16},
    },
    showlegend=False,
    height=800,
    width=800,
)

fig_box.show()


In [4]:
# Publication-styled variant of the per-model box plot: white background,
# black mirrored axis lines, light-grey grid, thicker box outlines.
fig_box = go.Figure(data=[
    go.Box(y=SIR_df[column], name=label, line={"width": 2})
    for column, label in (
        ("imp_SIR_ipd17_model_pilot", "Model IPD 17"),
        ("imp_SIR_ipd35_model_pilot", "Model IPD 35"),
        ("imp_SIR_ipd14_ipd17_model_pilot", "Model IPD 14+17"),
        ("imp_SIR_ipd14_ipd17_ipd35_model_pilot", "Model IPD 14+17+35"),
        ("imp_SIR_ild17_model_pilot", "Model ILD 17"),
        ("imp_SIR_ild35_model_pilot", "Model ILD 35"),
        ("imp_SIR_ild14_ild17_model_pilot", "Model ILD 14+17"),
        ("imp_SIR_ild14_ild17_ild35_model_pilot", "Model ILD 14+17+35"),
        ("imp_SIR_ipd17_ild17_model_pilot", "Model IPD 17 + ILD 17"),
    )
])

fig_box.update_layout(
    title={
        "text": "Zlepšení SIR extrahovaného signálu s pilotem z natrénovaných modelů (testovací data)",
        "font": {"size": 20},
    },
    xaxis={
        "title": {"text": "Model pro tvorbu pilotu", "font": {"size": 18}},
        "tickfont": {"size": 16},
        "linecolor": 'black',
        "mirror": True,
        "showgrid": True,
        "gridcolor": 'lightgrey',
    },
    yaxis={
        "title": {"text": "Zlepšení SIR [dB]", "font": {"size": 18}},
        "tickfont": {"size": 16},
        "linecolor": 'black',
        "mirror": True,
        "showgrid": True,
        "zeroline": True,
        "zerolinecolor": 'lightgrey',
        "gridcolor": 'lightgrey',
    },
    plot_bgcolor='white',
    paper_bgcolor='white',
    showlegend=False,
    height=800,
    width=1000,
)

fig_box.show()

In [5]:
# Publication-styled variant of the pilot-type comparison box plot.
fig_box = go.Figure(data=[
    go.Box(y=SIR_df[column], name=label, line={"width": 2})
    for column, label in (
        ("imp_SIR_SOI_pilot", "Ideální pilot (energie SOI)"),
        ("imp_SIR_label_pilot", "Ideální pilot (intenzity, labely)"),
        ("imp_SIR_ild14_ild17_ild35_model_pilot", "Pilot z modelu (ILD 14+17+35)"),
        ("imp_SIR_X_pilot", "Pilot z energie původního signálu"),
        ("imp_SIR_ones_pilot", "Pilot z 1 (slepá extrakce)"),
    )
])

fig_box.update_layout(
    title={
        "text": "Zlepšení SIR extrahovaného signálu s různými piloty (testovací data)",
        "font": {"size": 20},
    },
    xaxis={
        "title": {"text": "Typ pilotu", "font": {"size": 18}},
        "tickfont": {"size": 16},
        "linecolor": 'black',
        "mirror": True,
        "showgrid": True,
        "gridcolor": 'lightgrey',
    },
    yaxis={
        "title": {"text": "Zlepšení SIR [dB]", "font": {"size": 18}},
        "tickfont": {"size": 16},
        "linecolor": 'black',
        "mirror": True,
        "showgrid": True,
        "zeroline": True,
        "zerolinecolor": 'lightgrey',
        "gridcolor": 'lightgrey',
    },
    plot_bgcolor='white',
    paper_bgcolor='white',
    showlegend=False,
    height=800,
    width=800,
)

fig_box.show()