In [17]:
import os
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import plotly.express as px
import plotly.io as pio
# import plotly.express as px
# import plotly.graph_objects as go
from IPython.display import display, HTML

# MWZ2.0/2.1 statistics

In [18]:
# CSV files holding one row per dialogue with its number of turns
# (columns `dialogue_id` and `total_turns`, as displayed after loading).
# One file is used for the KAGE experiments, the other for the PPTOD experiments.
kage_mwz20_dialogue2len_path = './data/mwz20/train_dialogue2len.csv'
pptod_mwz20_dialogue2len_path = './data/mwz20/pptod_train_dialogue2len.csv'

In [19]:
# Load the per-dialogue turn counts for both models, then show the two tables
# one after the other (KAGE first, PPTOD second).
kage_mwz20_dialogue2len_df = pd.read_csv(kage_mwz20_dialogue2len_path)
pptod_mwz20_dialogue2len_df = pd.read_csv(pptod_mwz20_dialogue2len_path)

display(kage_mwz20_dialogue2len_df, pptod_mwz20_dialogue2len_df)

Unnamed: 0,dialogue_id,total_turns
0,MUL0001,10
1,MUL0002,7
2,MUL0005,9
3,MUL0006,8
4,MUL0007,7
...,...,...
7883,WOZ20671,3
7884,WOZ20672,5
7885,WOZ20673,5
7886,WOZ20674,4


Unnamed: 0,dialogue_id,total_turns
0,SNG01856,5
1,MUL2168,8
2,MUL2105,9
3,PMUL1690,11
4,MUL2395,7
...,...,...
7896,PMUL4251,9
7897,MUL1383,10
7898,SNG0827,5
7899,PMUL2395,7


Min, max, and average number of turns per dialogue in MWZ2.0

In [20]:
# Summary statistics (count, mean, std, min, quartiles, max) of the number of
# turns per dialogue in the KAGE turn-count table.
kage_mwz20_dialogue2len_df['total_turns'].describe()

count    7888.000000
mean        6.968940
std         2.588889
min         1.000000
25%         5.000000
50%         7.000000
75%         9.000000
max        22.000000
Name: total_turns, dtype: float64

## Paths of the MWZ2.0 test-accuracy files (KAGE and PPTOD)

In [21]:
# k=100 test-accuracy CSVs: the three KAGE files first, then the three PPTOD
# files; within each model the order is Max Entropy, Least Confidence, Random.
acc_path_list = [
    f'./data/mwz20/{model}/{strategy}/k100/k100_test_acc.csv'
    for model in ('KAGE', 'PPTOD')
    for strategy in ('max_entropy', 'least_confidence', 'random')
]

In [22]:
def add_data_size_col(df, test_acc_csv_path, max_turns=7888):
    '''
    Add a '# of labelled dialogue turns' column derived from the 'round' column.

    For a run that labels `k` turns per round, round `r` gets the value
    `r * k + k`, capped at `max_turns`.

    `k` is taken from the first 'k<digits>' token in the path
    (e.g. '.../k100/k100_test_acc.csv' -> 100). Reading the token instead of
    searching for a bare substring keeps unrelated digits in the path, such as
    a seed in 'k100S500_test_acc.csv', from selecting the wrong step size.
    Paths without such a token fall back to the legacy substring lookup for
    2000 / 1500 / 1000 / 500 / 100 (checked in that order).

    Args:
        df: DataFrame with a numeric 'round' column. It is modified in place.
        test_acc_csv_path: path of the CSV the frame was built from.
        max_turns: upper bound for the new column. The default, 7888, equals
            the row count of the KAGE turn-count table shown in this notebook;
            presumably the size of the labelling pool - TODO confirm.

    Returns:
        The same DataFrame object, with the new column added.

    Raises:
        ValueError: if no step size can be determined from the path.
    '''
    import re

    column = '# of labelled dialogue turns'

    # 'k' must not be preceded by a letter or digit, so only standalone
    # tokens such as '/k100/' or '/k100_test_acc.csv' qualify.
    token = re.search(r'(?<![A-Za-z0-9])k(\d+)', test_acc_csv_path)
    if token is not None:
        turns_per_round = int(token.group(1))
    else:
        # Legacy behaviour: larger sizes first, because '500' is a substring
        # of '1500' and '100' is a substring of '1000'.
        for candidate in (2000, 1500, 1000, 500, 100):
            if str(candidate) in test_acc_csv_path:
                turns_per_round = candidate
                break
        else:
            raise ValueError(
                f'cannot determine the number of turns per round from path: {test_acc_csv_path}'
            )

    df[column] = (df['round'] * turns_per_round + turns_per_round).clip(upper=max_turns)

    return df

In [23]:
def get_mean_std_acc_by_path(test_acc_csv_path):
    '''
    Display the per-round mean and population standard deviation (ddof=0)
    of the accuracies stored in one test-accuracy CSV.

    For paths containing 'PPTOD' the joint accuracy is divided by 100 first
    and only 'test_joint_acc' is summarised; for every other path both
    'test_joint_acc' and 'test_slot_acc' are summarised. Nothing is returned.
    '''
    acc_df = add_data_size_col(pd.read_csv(test_acc_csv_path), test_acc_csv_path)

    metric_cols = ['test_joint_acc', 'test_slot_acc']
    if 'PPTOD' in test_acc_csv_path:
        # PPTOD joint accuracy is rescaled by 1/100
        # (presumably stored as a percentage - TODO confirm)
        acc_df['test_joint_acc'] = acc_df['test_joint_acc'] / 100
        metric_cols = ['test_joint_acc']

    per_round = acc_df.groupby('round')[metric_cols]
    display(per_round.mean())
    display(per_round.std(ddof=0))

## Plot

In [24]:
# Endless colour supply: start over from the first colour whenever the
# palette is exhausted, so there can be more lines than colours.
def next_col(cols):
    '''Yield the items of `cols` in order, repeating the sequence forever.'''
    while True:
        yield from cols

In [25]:
def get_colCycle():
    '''
    Build two 'rgba(...)' colour lists from the qualitative Plotly palette.

    Returns:
        (colCycle, line_colCycle): the same colours in the same order, the
        first list with alpha 0.2 (for filled areas) and the second with
        alpha 1 (for lines).
    '''
    palette = px.colors.qualitative.Plotly

    def _rgba_string(hex_code, alpha):
        # '#RRGGBB' -> 'rgba(r, g, b, alpha)'
        digits = hex_code.lstrip('#')
        channels = [int(digits[pos:pos + 2], 16) for pos in (0, 2, 4)]
        return 'rgba' + str(tuple(channels + [alpha]))

    colCycle = [_rgba_string(code, 0.2) for code in palette]
    line_colCycle = [_rgba_string(code, 1) for code in palette]

    return colCycle, line_colCycle

In [26]:
def plot_joint_acc_by_round(df_list, strategy_list):
    '''
    Plot joint goal accuracy against the number of labelled dialogue turns,
    one line per strategy, each with a shaded band of +/- one standard
    deviation.

    Args:
        df_list: DataFrames that each provide the columns 'test_joint_acc',
            'joint_acc_std' and '# of labelled dialogue turns'.
        strategy_list: legend names; strategy_list[i] labels df_list[i].

    Returns:
        The plotly Figure. It is not shown here; the caller displays or
        saves it.
    '''
    fig = go.Figure()

    # get_colCycle() returns two palettes: translucent colours for the bands
    # and opaque colours for the lines. Both are cycled in step so that a
    # band and its line share the same hue.
    colCycle, line_colCycle = get_colCycle()
    area_color = next_col(cols=colCycle)
    line_color = next_col(cols=line_colCycle)

    for i, df in enumerate(df_list):
        # Band outline: upper bound left-to-right, then lower bound
        # right-to-left, so the trace encloses the +/- std region.
        y_upper = list(df["test_joint_acc"] + df["joint_acc_std"])
        y_lower = list(df["test_joint_acc"] - df["joint_acc_std"])
        y_lower = y_lower[::-1]

        x_std = list(df["# of labelled dialogue turns"])
        x_std = x_std + x_std[::-1]

        new_area_col = next(area_color)
        new_line_col = next(line_color)

        # Shaded band; its own outline is fully transparent and it is kept
        # out of the legend.
        fig.add_trace(
            go.Scatter(
                x=x_std,
                y=y_upper + y_lower,
                fill='tozerox',
                fillcolor=new_area_col,
                line=dict(color='rgba(255,255,255,0)'),
                name=strategy_list[i],
                mode="lines",
                showlegend=False,
            )
        )

        # Mean accuracy line.
        fig.add_trace(
            go.Scatter(
                x=df["# of labelled dialogue turns"],
                y=df["test_joint_acc"],
                line=dict(color=new_line_col, width=3),
                mode="lines+markers",
                name=strategy_list[i]
            )
        )

    # Treat the data sizes as evenly spaced categories, not a numeric axis.
    fig.update_xaxes(type='category')
    fig.update_layout(
        yaxis_range=[0, 0.6],
        xaxis_title="# of labelled dialogue turns",
        yaxis_title="Joint Goal Accuracy",
    )

    return fig

In [27]:
def make_merged_df_for_plotting(test_acc_csv_path):
    '''
    Read one test-accuracy CSV and return one row per round with the mean
    accuracy, its population std (ddof=0) and the labelled-data size.

    For paths containing 'PPTOD' the joint accuracy is divided by 100 and
    only 'test_joint_acc' / 'joint_acc_std' are produced; otherwise
    'test_slot_acc' / 'slot_acc_std' are included as well. The
    '# of labelled dialogue turns' column is added by add_data_size_col.
    '''
    acc_df = pd.read_csv(test_acc_csv_path)

    # metric column -> name of its std column in the result
    std_names = {
        'test_joint_acc': 'joint_acc_std',
        'test_slot_acc': 'slot_acc_std',
    }
    if 'PPTOD' in test_acc_csv_path:
        # for PPTOD, divide by 100
        acc_df['test_joint_acc'] = acc_df['test_joint_acc'] / 100
        std_names = {'test_joint_acc': 'joint_acc_std'}

    per_round = acc_df.groupby('round')[list(std_names)]
    mean = per_round.mean().reset_index()
    std = per_round.std(ddof=0).reset_index().rename(columns=std_names)

    merged = pd.merge(mean, std, on='round')
    return add_data_size_col(merged, test_acc_csv_path)

In [28]:
def plot_k100_joint_acc(kage_mwz20_k100, pptod_mwz20_k100):
    '''
    Plot the joint goal accuracy of the KAGE and PPTOD runs in one figure.

    KAGE strategies are drawn as solid lines and PPTOD strategies as dashed
    lines (a trace counts as PPTOD when its legend name contains 'PPTOD').
    One horizontal "last turn" baseline is drawn per model that appears.

    Args:
        kage_mwz20_k100: (df_list, strategy_list) pair for KAGE, as returned
            by prepare_subplot_inputs.
        pptod_mwz20_k100: (df_list, strategy_list) pair for PPTOD.
            Every DataFrame must provide 'test_joint_acc' and
            '# of labelled dialogue turns'.

    Returns:
        The plotly Figure. It is not shown here; the caller displays or
        saves it.
    '''
    kage_k100_df_list, kage_k100_strategy_list = kage_mwz20_k100
    pptod_k100_df_list, pptod_k100_strategy_list = pptod_mwz20_k100

    df_list = kage_k100_df_list + pptod_k100_df_list
    strategy_list = kage_k100_strategy_list + pptod_k100_strategy_list

    fig = go.Figure()

    # Only the opaque line palette is needed; no std bands are drawn here.
    _, line_colCycle = get_colCycle()
    line_color = next_col(cols=line_colCycle)

    # "Last turn" baselines, keyed by whether the trace belongs to PPTOD.
    # 0.2259 is the rounded mean computed in the kage_k500_s998_lt_df cell;
    # NOTE(review): the source of 0.2656 is not shown in this notebook.
    last_turn_baselines = {
        True: dict(y=0.2656, line_color="green", annotation_text="PPTOD-LastTurn"),
        False: dict(y=0.2259, line_color="red", annotation_text="KAGE-LastTurn"),
    }
    # Each baseline is added once, not once per trace, so the lines and
    # their annotations are not stacked on top of themselves.
    baselines_drawn = set()

    for i, df in enumerate(df_list):
        is_pptod = 'PPTOD' in strategy_list[i]

        line_style = dict(color=next(line_color), width=4.5)
        if is_pptod:
            line_style['dash'] = 'dash'

        fig.add_trace(
            go.Scatter(
                x=df["# of labelled dialogue turns"],
                y=df["test_joint_acc"],
                line=line_style,
                mode="lines+markers",
                name=strategy_list[i]
            )
        )

        if is_pptod not in baselines_drawn:
            baselines_drawn.add(is_pptod)
            fig.add_hline(
                line_width=1,
                line_dash="dashdot",
                annotation_position="top left",
                annotation_font_size=10,
                annotation_font_color="rgba(98, 85, 92, 0.8)",
                **last_turn_baselines[is_pptod],
            )

    # Treat the data sizes as evenly spaced categories, not a numeric axis.
    fig.update_xaxes(type='category')
    fig.update_layout(
        margin_l=5, margin_t=5, margin_b=5, margin_r=5,
        yaxis_range=[0, 0.3],
        xaxis_title="# of labelled dialogue turns",
        yaxis_title="Joint Goal Accuracy",
        # horizontal legend centred above the plot area
        legend=dict(
            title=None, orientation='h', y=1, yanchor="bottom", x=0.5, xanchor="center",
            font=dict(size=10)
        )
    )

    return fig

In [29]:
def prepare_subplot_inputs(k, mwz, model):
    '''
    Load the per-round accuracy tables of the three selection strategies for
    one model and return them with matching legend labels.

    Args:
        k: number of turns labelled per round; part of the file path.
        mwz: MultiWOZ version tag used in the path (20 -> './data/mwz20/...').
        model: model folder name, e.g. 'KAGE' or 'PPTOD'.

    Returns:
        (df_list, strategy_list): the DataFrames produced by
        make_merged_df_for_plotting for random, least_confidence and
        max_entropy (in that order), and the labels '<model>+RS',
        '<model>+LC', '<model>+ME'.
    '''
    strategies = ['random', 'least_confidence', 'max_entropy']
    # Legend labels in the same order as `strategies`. Building them from
    # `model` keeps them defined for every model / dataset combination.
    strategy_list = [f'{model}+{tag}' for tag in ('RS', 'LC', 'ME')]

    df_list = []
    for strategy in strategies:
        acc_csv_path = f'./data/mwz{mwz}/{model}/{strategy}/k{k}/k{k}_test_acc.csv'
        print(f'============== {acc_csv_path} ==============')

        df = make_merged_df_for_plotting(acc_csv_path)

        # only use round <= 4 for k=100 (the PPTOD k=100 tables shown in this
        # notebook stop at round 4; presumably KAGE is truncated to match)
        if k == 100:
            df = df[df['round'] <= 4]

        display(df)
        df_list.append(df)

    return df_list, strategy_list

In [30]:
# Build the (df_list, strategy_list) pairs for k=100 on MWZ2.0, one per model;
# both are passed to plot_k100_joint_acc further down.
kage_mwz20_k100 = prepare_subplot_inputs(100, 20, 'KAGE')
pptod_mwz20_k100 = prepare_subplot_inputs(100, 20, 'PPTOD')



Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.01682,0.827523,0.0,0.0,100
1,1,0.018516,0.839017,0.000475,0.001976,200
2,2,0.146907,0.901042,0.009631,0.005202,300
3,3,0.154504,0.902469,0.041238,0.014732,400
4,4,0.244575,0.934258,0.00407,0.00248,500




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.017092,0.831158,0.000272,0.001768,100
1,1,0.018041,0.833257,0.001221,0.005711,200
2,2,0.156201,0.874141,0.022859,0.037168,300
3,3,0.178511,0.874035,0.010312,0.046498,400
4,4,0.275366,0.941694,0.001764,0.00071,500




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.01682,0.828043,0.0,0.001215,100
1,1,0.028305,0.842245,0.016052,0.009581,200
2,2,0.157488,0.91327,0.024106,0.007291,300
3,3,0.22703,0.931416,0.01689,0.002654,400
4,4,0.276678,0.941997,0.004461,0.001628,500




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.154232,0.0,100
1,1,0.169425,0.0,200
2,2,0.187466,0.0,300
3,3,0.241454,0.0,400
4,4,0.235757,0.0,500




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.143923,0.0,100
1,1,0.201709,0.0,200
2,2,0.25,0.0,300
3,3,0.256028,0.0,400
4,4,0.283234,0.0,500




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.195876,0.0,100
1,1,0.240776,0.0,200
2,2,0.275672,0.0,300
3,3,0.277581,0.0,400
4,4,0.279616,0.0,500


In [33]:
# Build the k=100 KAGE vs. PPTOD figure and export it as a 450x350 PDF.
# NOTE(review): pio.write_image presumably needs a static-image backend and an
# existing ./data/plot directory - confirm before running on a fresh checkout.

ablation_k100 = plot_k100_joint_acc(kage_mwz20_k100, pptod_mwz20_k100)
pio.write_image(ablation_k100, "./data/plot/ablation_base_DST_k100_v2.pdf", width=450, height=350)

In [16]:
# Mean of two KAGE joint-accuracy values. The result (0.225855) rounds to the
# 0.2259 that is hard-coded as the KAGE last-turn baseline in the plotting code.
# NOTE(review): the variable name suggests k=500 / seed 998 / last-turn runs -
# confirm where the two numbers come from.
kage_k500_s998_lt_df = pd.DataFrame({
    'joint_acc': [0.243082, 0.208627]
})
kage_k500_s998_lt_df.mean()

joint_acc    0.225855
dtype: float64

In [17]:
# Ad-hoc plot of the KAGE runs for one (k, mwz) setting. Unlike
# prepare_subplot_inputs, this cell keeps every round of each table.
k = 100
mwz = 20
# [full, last_turn, random]
# NOTE(review): `baselines` is assigned here but not passed to the plotting call below.
if mwz == 20:
    baselines = [0, 0.2259, 0]

# Folder names of the strategies and, in the same order, their legend labels.
strategies = ['random', 'least_confidence', 'max_entropy']
strategy_list = ['RS', 'LC', 'ME']

# acc_csv_paths = [f'./data/mwz20/KAGE/{strategy}/k2000/k2000_test_acc.csv' for strategy in strategies]

# One merged mean/std table per strategy; each is displayed as it is loaded.
df_list = []
for strategy in strategies:
    acc_csv_path = f'./data/mwz{mwz}/KAGE/{strategy}/k{k}/k{k}_test_acc.csv'
    print(f'============== {acc_csv_path} ==============')

    df = make_merged_df_for_plotting(acc_csv_path)
    display(df)

    df_list.append(df)

#     print()

# The returned figure is the cell's last expression, so the notebook renders it.
plot_joint_acc_by_round(df_list, strategy_list)



Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.01682,0.827523,0.0,0.0,100
1,1,0.018516,0.839017,0.000475,0.001976,200
2,2,0.146907,0.901042,0.009631,0.005202,300
3,3,0.154504,0.902469,0.041238,0.014732,400
4,4,0.244575,0.934258,0.00407,0.00248,500
5,5,0.290423,0.941628,0.013836,0.003841,600
6,6,0.30392,0.945989,0.002374,0.000719,700
7,7,0.341224,0.952155,0.01621,0.001811,800
8,8,0.35472,0.952631,0.01058,0.000864,900
9,9,0.35879,0.954865,0.002442,0.000172,1000




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.017092,0.831158,0.000272,0.001768,100
1,1,0.018041,0.833257,0.001221,0.005711,200
2,2,0.156201,0.874141,0.022859,0.037168,300
3,3,0.178511,0.874035,0.010312,0.046498,400
4,4,0.275366,0.941694,0.001764,0.00071,500
5,5,0.31111,0.948818,0.0078,0.001911,600
6,6,0.328269,0.951841,0.015192,0.001903,700
7,7,0.338714,0.953323,0.0,0.0,800
8,8,0.383071,0.958501,0.0,0.0,900
9,9,0.383342,0.959138,0.0,0.0,1000




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.01682,0.828043,0.0,0.001215,100
1,1,0.028305,0.842245,0.016052,0.009581,200
2,2,0.157488,0.91327,0.024106,0.007291,300
3,3,0.22703,0.931416,0.01689,0.002654,400
4,4,0.276678,0.941997,0.004461,0.001628,500
5,5,0.301682,0.946325,0.009391,0.001603,600
6,6,0.313709,0.947486,0.004798,0.000762,700
7,7,0.345451,0.952891,0.000942,0.000507,800
8,8,0.363673,0.954645,0.020522,0.002373,900
9,9,0.377193,0.956849,0.008512,0.001438,1000


In [18]:
# # ME all avg:
# x = pd.DataFrame({
#     'test_joint_acc': [0.513384, 0.509812, 0.518448],
#     'test_slot_acc': [0.971557, 0.971176, 0.972468]
# })
# x.describe()

In [19]:
def print_test_acc_by_path(acc_path_list):
    '''Print a banner line for each path, then show its per-round mean/std accuracy tables.'''
    for csv_path in acc_path_list:
        banner = f'------------ {csv_path} ------------'
        print(banner)
        get_mean_std_acc_by_path(csv_path)

In [20]:
# Show the per-round mean and std accuracy tables for every file in acc_path_list.
print_test_acc_by_path(acc_path_list)

------------ ./data/mwz20/KAGE/max_entropy/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.01682,0.828043
1,0.028305,0.842245
2,0.157488,0.91327
3,0.22703,0.931416
4,0.276678,0.941997
5,0.301682,0.946325
6,0.313709,0.947486
7,0.345451,0.952891
8,0.363673,0.954645
9,0.377193,0.956849


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.0,0.001215
1,0.016052,0.009581
2,0.024106,0.007291
3,0.01689,0.002654
4,0.004461,0.001628
5,0.009391,0.001603
6,0.004798,0.000762
7,0.000942,0.000507
8,0.020522,0.002373
9,0.008512,0.001438


------------ ./data/mwz20/KAGE/least_confidence/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.017092,0.831158
1,0.018041,0.833257
2,0.156201,0.874141
3,0.178511,0.874035
4,0.275366,0.941694
5,0.31111,0.948818
6,0.328269,0.951841
7,0.338714,0.953323
8,0.383071,0.958501
9,0.383342,0.959138


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.000272,0.001768
1,0.001221,0.005711
2,0.022859,0.037168
3,0.010312,0.046498
4,0.001764,0.00071
5,0.0078,0.001911
6,0.015192,0.001903
7,0.0,0.0
8,0.0,0.0
9,0.0,0.0


------------ ./data/mwz20/KAGE/random/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.01682,0.827523
1,0.018516,0.839017
2,0.146907,0.901042
3,0.154504,0.902469
4,0.244575,0.934258
5,0.290423,0.941628
6,0.30392,0.945989
7,0.341224,0.952155
8,0.35472,0.952631
9,0.35879,0.954865


Unnamed: 0_level_0,test_joint_acc,test_slot_acc
round,Unnamed: 1_level_1,Unnamed: 2_level_1
0,0.0,0.0
1,0.000475,0.001976
2,0.009631,0.005202
3,0.041238,0.014732
4,0.00407,0.00248
5,0.013836,0.003841
6,0.002374,0.000719
7,0.01621,0.001811
8,0.01058,0.000864
9,0.002442,0.000172


------------ ./data/mwz20/PPTOD/max_entropy/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.195876
1,0.240776
2,0.275672
3,0.277581
4,0.279616


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.0
1,0.0
2,0.0
3,0.0
4,0.0


------------ ./data/mwz20/PPTOD/least_confidence/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.143923
1,0.201709
2,0.25
3,0.256028
4,0.283234


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.0
1,0.0
2,0.0
3,0.0
4,0.0


------------ ./data/mwz20/PPTOD/random/k100/k100_test_acc.csv ------------


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.154232
1,0.169425
2,0.187466
3,0.241454
4,0.235757


Unnamed: 0_level_0,test_joint_acc
round,Unnamed: 1_level_1
0,0.0
1,0.0
2,0.0
3,0.0
4,0.0


## Calculate Annotation Cost

In [21]:
def cal_turn_percentage(x):
    '''
    Return turn_idx / total_turns for one row.

    `x` is any mapping-like row with 'turn_idx' and 'total_turns' entries;
    both are converted with int(). A row whose total_turns is 0 yields 1
    instead of dividing by zero.
    '''
    total = int(x['total_turns'])
    if total == 0:
        return 1

    return int(x['turn_idx']) / total

In [22]:
def merge_df(dialogue2len_df, selected_turn_df):
    '''
    Attach dialogue lengths to the selected turns and report how far into
    each dialogue the selected turns lie.

    Args:
        dialogue2len_df: DataFrame with 'dialogue_id' and 'total_turns'.
        selected_turn_df: DataFrame with 'round' and 'selected_turn_id',
            where the id has the form '<dialogue_id>-<0-based turn index>'.
            The frame is not modified; a copy is used internally.

    Returns:
        A new DataFrame with the input columns plus 'dialogue_id',
        'turn_idx' (1-based), 'total_turns', 'dialogue' (row number),
        'turn_percentage' (turn_idx / total_turns, or 1 when total_turns
        is 0) and 'turn_percentage_by_round' (mean per round).
        Overall and per-round mean/std of 'turn_percentage' are printed.

    Raises:
        ValueError: if a selected dialogue has no entry in dialogue2len_df.
    '''
    # Work on a copy so the caller's frame keeps its original columns.
    turns = selected_turn_df.copy()

    # 'MUL1950-7' -> ('MUL1950', 7). Split on the last hyphen only, so a
    # dialogue id that itself contains a hyphen stays intact.
    id_parts = turns['selected_turn_id'].str.rsplit('-', n=1)
    turns['dialogue_id'] = id_parts.str[0]
    turns['turn_idx'] = id_parts.str[1].astype(int)

    merged_df = pd.merge(turns, dialogue2len_df, on='dialogue_id', how='left')

    # The left join leaves NaN for dialogues missing from the length table.
    unmatched = merged_df['total_turns'].isna()
    if unmatched.any():
        missing_ids = sorted(merged_df.loc[unmatched, 'dialogue_id'].unique())
        raise ValueError(f'no total_turns found for dialogue ids: {missing_ids[:5]}')

    merged_df['dialogue'] = merged_df.index
    # turn_idx+1 because the index starts from 0
    merged_df['turn_idx'] = merged_df['turn_idx'] + 1

    ### metric1: calculate turn percentage
    # Shows which part of each dialogue tends to be selected: values near 0
    # mean an early turn, 1 means the last turn. Same rule as
    # cal_turn_percentage, applied to whole columns at once.
    total_turns = merged_df['total_turns']
    merged_df['turn_percentage'] = (merged_df['turn_idx'] / total_turns).where(total_turns != 0, 1.0)

    ### metric2: share of each dialogue that annotators have to read
    # if total_turns is 10 and the selected turn_idx is 3, the annotator
    # reads 3/10 of the dialogue to label that turn
    annotate_turns_percent = round(merged_df['turn_percentage'].mean(), 4)
    std_annotate_turns_percent = round(merged_df['turn_percentage'].std(ddof=0), 4)
    print(f'# of turns read by annotators: mean - {annotate_turns_percent} std - {std_annotate_turns_percent}')

    # Rows with round == -1 are excluded from the "without budget" mean.
    merged_df_wo_budget = merged_df[merged_df['round'] != -1].reset_index(drop=True)
    annotate_turns_percent_wo_budget = round(merged_df_wo_budget['turn_percentage'].mean(), 4)
    print(f'# of turns read by annotators without budget: {annotate_turns_percent_wo_budget}')
    print('--------------------------------------------------')

    per_round = merged_df.groupby('round')['turn_percentage']
    merged_df['turn_percentage_by_round'] = per_round.transform('mean')
    annotate_turns_percent_by_round = per_round.mean()
    std_annotate_turns_percent_by_round = per_round.std(ddof=0)
    for idx in annotate_turns_percent_by_round.index:
        print(f'# of turns read by annotators by round {idx}: mean - '
              f'{round(annotate_turns_percent_by_round[idx], 4)} '
              f'std - {round(std_annotate_turns_percent_by_round[idx], 4)}'
             )

    return merged_df

In [23]:
def read_all_by_folder_name(folder_name, max_round=4):
    '''
    Pool every '*selected_turn_id.csv' file of one folder and compute the
    annotation-cost statistics for the pooled selections.

    Args:
        folder_name: directory that holds the selected-turn CSV files.
            If the name contains 'PPTOD' the PPTOD turn-count table is used,
            otherwise the KAGE one.
        max_round: only rows with round <= max_round are kept (default 4).

    Returns:
        The DataFrame returned by merge_df for the pooled selections.

    Raises:
        ValueError: if the folder contains no matching file.
    '''
    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the result reproducible across runs and machines.
    selected_turn_path_list = sorted(
        f'{folder_name}/{filename}'
        for filename in os.listdir(folder_name)
        if filename.endswith('selected_turn_id.csv')
    )
    if not selected_turn_path_list:
        raise ValueError(f'no *selected_turn_id.csv files found in {folder_name}')

    df_list = []
    for path in selected_turn_path_list:
        # Only the first two columns are read (round, selected_turn_id in the
        # files shown in this notebook); presumably for the same reason as the
        # k2000 cell, where later rows have a different column count.
        df = pd.read_csv(path, usecols=[0, 1])
        df_list.append(df[df['round'] <= max_round])

    merged = pd.concat(df_list)

    if 'PPTOD' in folder_name:
        dialogue2len_df = pptod_mwz20_dialogue2len_df
    else:
        dialogue2len_df = kage_mwz20_dialogue2len_df

    return merge_df(dialogue2len_df, merged)

In [24]:
# k=100 result folders: the three KAGE folders first, then the three PPTOD
# folders; within each model the order is Max Entropy, Least Confidence, Random.
selected_turn_folder_list = [
    f'./data/mwz20/{model}/{strategy}/k100'
    for model in ('KAGE', 'PPTOD')
    for strategy in ('max_entropy', 'least_confidence', 'random')
]

In [25]:
# For every result folder: print a header, compute the annotation-cost
# statistics (printed inside merge_df) and display the merged per-turn table.
for selected_turn_folder in selected_turn_folder_list:
    print(f'============= {selected_turn_folder} =============')
    statis = read_all_by_folder_name(selected_turn_folder)

    display(statis)
    print()

# of turns read by annotators: mean - 0.7518 std - 0.2905
# of turns read by annotators without budget: 0.7518
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.4561 std - 0.2694
# of turns read by annotators by round 1: mean - 0.9452 std - 0.117
# of turns read by annotators by round 2: mean - 0.9226 std - 0.141
# of turns read by annotators by round 3: mean - 0.7326 std - 0.2727
# of turns read by annotators by round 4: mean - 0.7024 std - 0.2906


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,PMUL0274-4,PMUL0274,5,10,0,0.500000,0.456058
1,0,PMUL0620-0,PMUL0620,1,7,1,0.142857,0.456058
2,0,WOZ20501-2,WOZ20501,3,5,2,0.600000,0.456058
3,0,MUL1277-0,MUL1277,1,9,3,0.111111,0.456058
4,0,PMUL1240-0,PMUL1240,1,6,4,0.166667,0.456058
...,...,...,...,...,...,...,...,...
1495,4,SNG1357-0,SNG1357,1,5,1495,0.200000,0.702376
1496,4,MUL1704-2,MUL1704,3,6,1496,0.500000,0.702376
1497,4,MUL1992-1,MUL1992,2,7,1497,0.285714,0.702376
1498,4,MUL2631-11,MUL2631,12,12,1498,1.000000,0.702376



# of turns read by annotators: mean - 0.8551 std - 0.2465
# of turns read by annotators without budget: 0.8551
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.5954 std - 0.293
# of turns read by annotators by round 1: mean - 0.9748 std - 0.0836
# of turns read by annotators by round 2: mean - 0.9299 std - 0.2027
# of turns read by annotators by round 3: mean - 0.8774 std - 0.211
# of turns read by annotators by round 4: mean - 0.8979 std - 0.1891


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,WOZ20203-2,WOZ20203,3,3,0,1.000000,0.595392
1,0,PMUL2157-3,PMUL2157,4,7,1,0.571429,0.595392
2,0,PMUL2843-8,PMUL2843,9,11,2,0.818182,0.595392
3,0,PMUL2545-4,PMUL2545,5,5,3,1.000000,0.595392
4,0,SSNG0203-4,SSNG0203,5,10,4,0.500000,0.595392
...,...,...,...,...,...,...,...,...
995,4,PMUL1469-4,PMUL1469,5,6,995,0.833333,0.897856
996,4,PMUL4135-7,PMUL4135,8,9,996,0.888889,0.897856
997,4,PMUL3653-11,PMUL3653,12,12,997,1.000000,0.897856
998,4,SSNG0141-3,SSNG0141,4,4,998,1.000000,0.897856



# of turns read by annotators: mean - 0.5783 std - 0.2812
# of turns read by annotators without budget: 0.5783
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.577 std - 0.2782
# of turns read by annotators by round 1: mean - 0.6176 std - 0.2727
# of turns read by annotators by round 2: mean - 0.5596 std - 0.2888
# of turns read by annotators by round 3: mean - 0.5547 std - 0.2739
# of turns read by annotators by round 4: mean - 0.5826 std - 0.2876


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,SNG0561-4,SNG0561,5,5,0,1.000000,0.576952
1,0,MUL1121-1,MUL1121,2,7,1,0.285714,0.576952
2,0,PMUL3144-8,PMUL3144,9,9,2,1.000000,0.576952
3,0,PMUL4298-5,PMUL4298,6,7,3,0.857143,0.576952
4,0,PMUL4942-2,PMUL4942,3,8,4,0.375000,0.576952
...,...,...,...,...,...,...,...,...
995,4,PMUL4669-1,PMUL4669,2,7,995,0.285714,0.582591
996,4,MUL2668-5,MUL2668,6,7,996,0.857143,0.582591
997,4,PMUL4677-4,PMUL4677,5,5,997,1.000000,0.582591
998,4,PMUL3655-0,PMUL3655,1,8,998,0.125000,0.582591



# of turns read by annotators: mean - 0.5868 std - 0.3153
# of turns read by annotators without budget: 0.5868
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.5522 std - 0.3231
# of turns read by annotators by round 1: mean - 0.5913 std - 0.2887
# of turns read by annotators by round 2: mean - 0.5708 std - 0.3414
# of turns read by annotators by round 3: mean - 0.74 std - 0.2551
# of turns read by annotators by round 4: mean - 0.4795 std - 0.3022


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,WOZ20000-0,WOZ20000,1,5,0,0.200000,0.552174
1,0,PMUL3877-0,PMUL3877,1,4,1,0.250000,0.552174
2,0,MUL0015-0,MUL0015,1,9,2,0.111111,0.552174
3,0,PMUL4105-3,PMUL4105,4,10,3,0.400000,0.552174
4,0,MUL1524-0,MUL1524,1,8,4,0.125000,0.552174
...,...,...,...,...,...,...,...,...
495,4,WOZ20153-0,WOZ20153,1,3,495,0.333333,0.479531
496,4,PMUL1989-7,PMUL1989,8,9,496,0.888889,0.479531
497,4,MUL1962-3,MUL1962,4,6,497,0.666667,0.479531
498,4,SNG1317-0,SNG1317,1,4,498,0.250000,0.479531



# of turns read by annotators: mean - 0.8113 std - 0.2228
# of turns read by annotators without budget: 0.8113
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.7952 std - 0.2219
# of turns read by annotators by round 1: mean - 0.8343 std - 0.2282
# of turns read by annotators by round 2: mean - 0.8177 std - 0.2138
# of turns read by annotators by round 3: mean - 0.7279 std - 0.2345
# of turns read by annotators by round 4: mean - 0.8813 std - 0.1831


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,PMUL2369-2,PMUL2369,3,8,0,0.375000,0.795186
1,0,PMUL3009-7,PMUL3009,8,10,1,0.800000,0.795186
2,0,MUL1236-8,MUL1236,9,11,2,0.818182,0.795186
3,0,MUL1907-5,MUL1907,6,6,3,1.000000,0.795186
4,0,MUL1990-4,MUL1990,5,8,4,0.625000,0.795186
...,...,...,...,...,...,...,...,...
495,4,SNG02298-3,SNG02298,4,5,495,0.800000,0.881302
496,4,SSNG0374-3,SSNG0374,4,4,496,1.000000,0.881302
497,4,MUL0758-4,MUL0758,5,11,497,0.454545,0.881302
498,4,MUL1170-3,MUL1170,4,6,498,0.666667,0.881302



# of turns read by annotators: mean - 0.5792 std - 0.3043
# of turns read by annotators without budget: 0.5792
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.592 std - 0.3049
# of turns read by annotators by round 1: mean - 0.6426 std - 0.2919
# of turns read by annotators by round 2: mean - 0.5382 std - 0.2935
# of turns read by annotators by round 3: mean - 0.4911 std - 0.3107
# of turns read by annotators by round 4: mean - 0.6319 std - 0.2927


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,PMUL2290-4,PMUL2290,5,5,0,1.000000,0.592026
1,0,MUL1717-5,MUL1717,6,7,1,0.857143,0.592026
2,0,PMUL2501-2,PMUL2501,3,7,2,0.428571,0.592026
3,0,MUL0611-6,MUL0611,7,10,3,0.700000,0.592026
4,0,WOZ20000-2,WOZ20000,3,5,4,0.600000,0.592026
...,...,...,...,...,...,...,...,...
495,4,SSNG0197-3,SSNG0197,4,4,495,1.000000,0.631876
496,4,PMUL1190-0,PMUL1190,1,4,496,0.250000,0.631876
497,4,MUL0451-0,MUL0451,1,7,497,0.142857,0.631876
498,4,SNG0315-5,SNG0315,6,6,498,1.000000,0.631876





In [45]:
# Pool the selected turns of the three k=2000 max-entropy files (S202, S588, S813).
# Only the first two columns (round, selected_turn_id) are read, because the
# number of columns does not match for the final round.
k2000S202 = pd.read_csv('./data/mwz20/KAGE/max_entropy/k2000/k2000S202_selected_turn_id.csv',usecols=[0,1])
k2000S588 = pd.read_csv('./data/mwz20/KAGE/max_entropy/k2000/k2000S588_selected_turn_id.csv',usecols=[0,1])
k2000S813 = pd.read_csv('./data/mwz20/KAGE/max_entropy/k2000/k2000S813_selected_turn_id.csv',usecols=[0,1])
# pd.concat keeps each file's own row index, so index values repeat across files.
k2000_merged = pd.concat([k2000S202, k2000S588, k2000S813])
k2000_merged

Unnamed: 0,round,selected_turn_id
0,0,MUL1950-7
1,0,PMUL3752-3
2,0,PMUL3020-0
3,0,PMUL1571-0
4,0,MUL0429-1
...,...,...
7883,3,WOZ20634-1
7884,3,WOZ20649-2
7885,3,WOZ20667-1
7886,3,WOZ20672-4


In [46]:
# Annotation-cost statistics for the pooled k=2000 selections, using the KAGE
# turn-count table; the trailing expression displays the merged table.
k2000_merged_statis = merge_df(kage_mwz20_dialogue2len_df, k2000_merged)
k2000_merged_statis

# of turns read by annotators: mean - 0.6258 std - 0.2852
# of turns read by annotators without budget: 0.6258
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.5782 std - 0.2997
# of turns read by annotators by round 1: mean - 0.6536 std - 0.2946
# of turns read by annotators by round 2: mean - 0.6409 std - 0.2671
# of turns read by annotators by round 3: mean - 0.6308 std - 0.2715


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,MUL1950-7,MUL1950,8,8,0,1.000000,0.578239
1,0,PMUL3752-3,PMUL3752,4,8,1,0.500000,0.578239
2,0,PMUL3020-0,PMUL3020,1,10,2,0.100000,0.578239
3,0,PMUL1571-0,PMUL1571,1,8,3,0.125000,0.578239
4,0,MUL0429-1,MUL0429,2,8,4,0.250000,0.578239
...,...,...,...,...,...,...,...,...
23659,3,WOZ20634-1,WOZ20634,2,7,23659,0.285714,0.630771
23660,3,WOZ20649-2,WOZ20649,3,3,23660,1.000000,0.630771
23661,3,WOZ20667-1,WOZ20667,2,4,23661,0.500000,0.630771
23662,3,WOZ20672-4,WOZ20672,5,5,23662,1.000000,0.630771


In [47]:
# Show every selected turn that belongs to dialogue MUL1950.
k2000_merged_statis.loc[k2000_merged_statis['dialogue_id'].eq('MUL1950')]

Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,MUL1950-7,MUL1950,8,8,0,1.0,0.578239
12295,2,MUL1950-3,MUL1950,4,8,12295,0.5,0.640939
16179,0,MUL1950-5,MUL1950,6,8,16179,0.75,0.578239


In [48]:
# Show every selected turn that belongs to dialogue PMUL3752.
k2000_merged_statis.loc[k2000_merged_statis['dialogue_id'].eq('PMUL3752')]

Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
1,0,PMUL3752-3,PMUL3752,4,8,1,0.5,0.578239
10550,1,PMUL3752-2,PMUL3752,3,8,10550,0.375,0.653631
21623,2,PMUL3752-1,PMUL3752,2,8,21623,0.25,0.640939


In [49]:
# Show every selected turn that belongs to dialogue WOZ20634.
k2000_merged_statis.loc[k2000_merged_statis['dialogue_id'].eq('WOZ20634')]

Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
5136,2,WOZ20634-0,WOZ20634,1,7,5136,0.142857,0.640939
11515,1,WOZ20634-2,WOZ20634,3,7,11515,0.428571,0.653631
23659,3,WOZ20634-1,WOZ20634,2,7,23659,0.285714,0.630771


In [50]:
# Show every selected turn that belongs to dialogue WOZ20672.
k2000_merged_statis.loc[k2000_merged_statis['dialogue_id'].eq('WOZ20672')]

Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
641,0,WOZ20672-1,WOZ20672,2,5,641,0.4,0.578239
15775,3,WOZ20672-2,WOZ20672,3,5,15775,0.6,0.630771
23662,3,WOZ20672-4,WOZ20672,5,5,23662,1.0,0.630771


在 PMUL3752 这个例子中，从 turn_idx=4 开始就不会再增加新的 state，因此我们的模型可以很好地找到这个临界 turn 去标记。同时，round=0 标记的是 turn=3（即 selected_turn_id 为 PMUL3752-3，对应 turn_idx=4），因为该 turn 的 state 最多、最难预测；随着 round 的进行，模型已经具备了一定的能力，因此标记的 turn 会越来越靠前。

同样，MUL1950 从 turn_idx=4 开始，state 就增加得很少，因此在 turn_idx=3 的时候，state 达到峰值，此时去标记最划算。

可以统计一下每个turn的total state，看看截止到哪个turn之后，总state最多，标记哪个turn最划算。

In [24]:
# Count how many rows fall on each (dialogue_id, turn_idx) pair; the count is
# reported in the total_turns column.
k2000_merged_statis.groupby(['dialogue_id', 'turn_idx'])[['total_turns']].count()

Unnamed: 0_level_0,Unnamed: 1_level_0,total_turns
dialogue_id,turn_idx,Unnamed: 2_level_1
MUL0001,7,1
MUL0001,9,1
MUL0001,10,1
MUL0002,6,1
MUL0002,7,2
...,...,...
WOZ20674,2,1
WOZ20674,4,2
WOZ20675,2,1
WOZ20675,3,1


Across rounds, the reading cost decreases, indicating that active learning (AL) reduces the annotation effort.

The first round may be random, because the model has not seen any data

In [25]:
# Same statistics as for the merged frame, restricted to the k2000S202 selection alone.
merge_df(kage_mwz20_dialogue2len_df, k2000S202)

# of turns read by annotators: mean - 0.6361 std - 0.2943
# of turns read by annotators without budget: 0.6361
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.41 std - 0.2391
# of turns read by annotators by round 1: mean - 0.7741 std - 0.2668
# of turns read by annotators by round 2: mean - 0.6718 std - 0.2712
# of turns read by annotators by round 3: mean - 0.6915 std - 0.2632


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,MUL1950-7,MUL1950,8,8,0,1.000000,0.409967
1,0,PMUL3752-3,PMUL3752,4,8,1,0.500000,0.409967
2,0,PMUL3020-0,PMUL3020,1,10,2,0.100000,0.409967
3,0,PMUL1571-0,PMUL1571,1,8,3,0.125000,0.409967
4,0,MUL0429-1,MUL0429,2,8,4,0.250000,0.409967
...,...,...,...,...,...,...,...,...
7883,3,WOZ20641-0,WOZ20641,1,4,7883,0.250000,0.691537
7884,3,WOZ20661-2,WOZ20661,3,4,7884,0.750000,0.691537
7885,3,WOZ20662-3,WOZ20662,4,4,7885,1.000000,0.691537
7886,3,WOZ20664-0,WOZ20664,1,6,7886,0.166667,0.691537


In [26]:
# Same statistics as for the merged frame, restricted to the k2000S588 selection alone.
merge_df(kage_mwz20_dialogue2len_df, k2000S588)

# of turns read by annotators: mean - 0.6035 std - 0.2811
# of turns read by annotators without budget: 0.6035
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.5449 std - 0.2966
# of turns read by annotators by round 1: mean - 0.6467 std - 0.2957
# of turns read by annotators by round 2: mean - 0.6509 std - 0.2505
# of turns read by annotators by round 3: mean - 0.5697 std - 0.2622


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,PMUL1855-1,PMUL1855,2,9,0,0.222222,0.544864
1,0,MUL0059-1,MUL0059,2,9,1,0.222222,0.544864
2,0,MUL2516-4,MUL2516,5,6,2,0.833333,0.544864
3,0,MUL2019-2,MUL2019,3,8,3,0.375000,0.544864
4,0,MUL1582-3,MUL1582,4,7,4,0.571429,0.544864
...,...,...,...,...,...,...,...,...
7883,3,WOZ20649-0,WOZ20649,1,3,7883,0.333333,0.569677
7884,3,WOZ20656-1,WOZ20656,2,4,7884,0.500000,0.569677
7885,3,WOZ20662-3,WOZ20662,4,4,7885,1.000000,0.569677
7886,3,WOZ20666-0,WOZ20666,1,3,7886,0.333333,0.569677


In [27]:
# Same statistics as for the merged frame, restricted to the k2000S813 selection alone.
merge_df(kage_mwz20_dialogue2len_df, k2000S813)

# of turns read by annotators: mean - 0.6379 std - 0.2787
# of turns read by annotators without budget: 0.6379
--------------------------------------------------
# of turns read by annotators by round 0: mean - 0.7799 std - 0.233
# of turns read by annotators by round 1: mean - 0.5401 std - 0.2726
# of turns read by annotators by round 2: mean - 0.6001 std - 0.2741
# of turns read by annotators by round 3: mean - 0.6311 std - 0.2751


Unnamed: 0,round,selected_turn_id,dialogue_id,turn_idx,total_turns,dialogue,turn_percentage,turn_percentage_by_round
0,0,PMUL3578-8,PMUL3578,9,9,0,1.000000,0.779885
1,0,MUL1534-6,MUL1534,7,7,1,1.000000,0.779885
2,0,MUL2038-2,MUL2038,3,9,2,0.333333,0.779885
3,0,MUL1665-6,MUL1665,7,9,3,0.777778,0.779885
4,0,SNG01160-5,SNG01160,6,6,4,1.000000,0.779885
...,...,...,...,...,...,...,...,...
7883,3,WOZ20634-1,WOZ20634,2,7,7883,0.285714,0.631100
7884,3,WOZ20649-2,WOZ20649,3,3,7884,1.000000,0.631100
7885,3,WOZ20667-1,WOZ20667,2,4,7885,0.500000,0.631100
7886,3,WOZ20672-4,WOZ20672,5,5,7886,1.000000,0.631100


In [62]:
# Inspect the raw k2000S202 selection frame (columns: round, selected_turn_id).
k2000S202

Unnamed: 0,round,selected_turn_id
0,0,MUL1950-7
1,0,PMUL3752-3
2,0,PMUL3020-0
3,0,PMUL1571-0
4,0,MUL0429-1
...,...,...
7883,3,WOZ20641-0
7884,3,WOZ20661-2
7885,3,WOZ20662-3
7886,3,WOZ20664-0


In [53]:
# Disabled debugging helper: prints the name of every entry in the k2000 results
# directory. The whole cell is commented out, so running it does nothing.
# for filename in os.listdir('./data/mwz20/KAGE/max_entropy/k2000'):
#     print(filename)
# #     if filename.endswith(".asm") or filename.endswith(".py"): 
# #          # print(os.path.join(directory, filename))
# #         continue
# #     else:
# #         continue

D2000S588_selected_turn_id.csv
D2000S202_selected_turn_id.csv
.DS_Store
D2000_test_acc.csv
D2000S813_selected_turn_id.csv
