In [1]:
# !pip install python-poppler

In [2]:
import os
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
pio.kaleido.scope.default_format = "svg"
from IPython.display import display, HTML

Reference (exporting a Plotly figure so that LaTeX typesets its text in the document font): https://community.plotly.com/t/export-plotly-figure-such-that-when-i-include-it-in-latex-file-it-compiles-the-text-with-the-same-font-as-the-document/49274/2

# MWZ2.0/2.1 statistics

In [3]:
# CSVs read below for their 'total_turns' column, one per model pipeline
# (presumably per-dialogue turn counts of the MWZ2.0 training split, judging by the file names).
kage_dialogue2len_path = './data/mwz20/train_dialogue2len.csv'
pptod_dialogue2len_path = './data/mwz20/pptod_train_dialogue2len.csv'

In [4]:
kage_dialogue2len_df = pd.read_csv(kage_dialogue2len_path)
# display(kage_dialogue2len_df)
# Summary statistics (count, mean, std, min, quartiles, max) of 'total_turns' for the KAGE file.
kage_dialogue2len_df['total_turns'].describe()

count    7888.000000
mean        6.968940
std         2.588889
min         1.000000
25%         5.000000
50%         7.000000
75%         9.000000
max        22.000000
Name: total_turns, dtype: float64

Min, max, and average number of turns per dialogue in MWZ2.0

In [5]:
pptod_dialogue2len_df = pd.read_csv(pptod_dialogue2len_path)
# display(pptod_dialogue2len_df)
# Summary statistics (count, mean, std, min, quartiles, max) of 'total_turns' for the PPTOD file.
pptod_dialogue2len_df['total_turns'].describe()

count    7901.000000
mean        6.966587
std         2.588367
min         1.000000
25%         5.000000
50%         7.000000
75%         9.000000
max        22.000000
Name: total_turns, dtype: float64

## Paths of the MWZ2.0/2.1 test-accuracy CSVs (KAGE and PPTOD)

In [6]:
# Test-accuracy CSVs, grouped as dataset/model blocks (separated by the ##### rulers)
# and, inside each block, by acquisition strategy. Only the k2000 runs are enabled;
# the smaller-k variants are kept commented out.
acc_path_list = [
    # Max Entropy
    './data/mwz20/KAGE/max_entropy/k2000/k2000_test_acc.csv',
#     './data/mwz20/KAGE/max_entropy/k1500/k1500_test_acc.csv',
#     './data/mwz20/KAGE/max_entropy/k1000/k1000_test_acc.csv',
#     './data/mwz20/KAGE/max_entropy/k500/k500_test_acc.csv',
    
    # Least Confidence
    './data/mwz20/KAGE/least_confidence/k2000/k2000_test_acc.csv',
#     './data/mwz20/KAGE/least_confidence/k1500/k1500_test_acc.csv',
#     './data/mwz20/KAGE/least_confidence/k1000/k1000_test_acc.csv',
#     './data/mwz20/KAGE/least_confidence/k500/k500_test_acc.csv',
    
    # Random
    './data/mwz20/KAGE/random/k2000/k2000_test_acc.csv',
#     './data/mwz20/KAGE/random/k1500/k1500_test_acc.csv',
#     './data/mwz20/KAGE/random/k1000/k1000_test_acc.csv',
#     './data/mwz20/KAGE/random/k500/k500_test_acc.csv',
    
    ######################################################
    
    # Max Entropy
    './data/mwz21/KAGE/max_entropy/k2000/k2000_test_acc.csv',
    
    # Least Confidence
    './data/mwz21/KAGE/least_confidence/k2000/k2000_test_acc.csv',
    
    # Random
    './data/mwz21/KAGE/random/k2000/k2000_test_acc.csv',
    
    ######################################################
    
    # Max Entropy
    './data/mwz20/PPTOD/max_entropy/k2000/k2000_test_acc.csv',
    
    # Least Confidence
    './data/mwz20/PPTOD/least_confidence/k2000/k2000_test_acc.csv',
    
    # Random
    './data/mwz20/PPTOD/random/k2000/k2000_test_acc.csv',
    
    ######################################################
    
    # NOTE(review): for mwz21/PPTOD only the Random path below is active; the
    # Max Entropy and Least Confidence paths are disabled - confirm this is intended.
#     # Max Entropy
#     './data/mwz21/PPTOD/max_entropy/k2000/k2000_test_acc.csv',
    
#     # Least Confidence
#     './data/mwz21/PPTOD/least_confidence/k2000/k2000_test_acc.csv',
    
    # Random
    './data/mwz21/PPTOD/random/k2000/k2000_test_acc.csv',

]

In [7]:
def add_data_size_col(df, test_acc_csv_path, max_size=7888):
    '''
    Add a '# of labelled dialogue turns' column computed from the 'round' column.

    The per-round acquisition size k is parsed from the `k<digits>` component of
    `test_acc_csv_path` (e.g. `.../k2000/k2000_test_acc.csv` -> k = 2000), so a
    0-based round r maps to (r + 1) * k labelled turns, capped at `max_size`.

    Args:
        df: DataFrame with a numeric 'round' column. It is modified in place.
        test_acc_csv_path: path of the accuracy CSV the frame was built from.
        max_size: upper bound for the computed size. Defaults to 7888, the cap
            this notebook has always used (presumably the size of the labelled
            pool - confirm before reusing with another dataset). Pass None to
            disable the cap.

    Returns:
        The same `df` object, with the new column added.

    Raises:
        ValueError: if no acquisition size can be determined from the path.
    '''
    import re  # local import so the cell does not depend on the top import cell

    # Anchor on a `k` that starts a path component or token, so that digits
    # elsewhere in the path (e.g. a "data_2000" directory) are never mistaken for k.
    match = re.search(r'(?<![A-Za-z0-9])k(\d+)', test_acc_csv_path)
    if match:
        k = int(match.group(1))
    else:
        # Legacy fallback for paths without a `k<digits>` token: look for the
        # known sizes as plain substrings, largest first so '1500' is not read as '500'.
        for candidate in (2000, 1500, 1000, 500, 100):
            if str(candidate) in test_acc_csv_path:
                k = candidate
                break
        else:
            raise ValueError(
                f'cannot determine the acquisition size k from path: {test_acc_csv_path!r}'
            )

    labelled = df['round'] * k + k
    df['# of labelled dialogue turns'] = labelled.clip(upper=max_size)

    return df

In [8]:
def get_mean_std_acc_by_path(test_acc_csv_path):
    '''
    Display the per-round mean and population std (ddof=0) of the accuracies.

    Reads the CSV at `test_acc_csv_path`, groups it by 'round' and displays two
    tables: the means, then the standard deviations. For paths containing
    'PPTOD' only 'test_joint_acc' is reported, after dividing it by 100;
    otherwise both 'test_joint_acc' and 'test_slot_acc' are reported.
    Nothing is returned.
    '''
    acc_df = add_data_size_col(pd.read_csv(test_acc_csv_path), test_acc_csv_path)

    if 'PPTOD' in test_acc_csv_path:
        # PPTOD values are divided by 100 before aggregating
        acc_df['test_joint_acc'] = acc_df['test_joint_acc'] / 100
        metric_cols = ['test_joint_acc']
    else:
        metric_cols = ['test_joint_acc', 'test_slot_acc']

    per_round = acc_df.groupby('round')[metric_cols]
    display(per_round.mean())
    display(per_round.std(ddof=0))

## Plot

In [9]:
# Make sure the colors run in cycles if there are more lines than colors
def next_col(cols):
    '''
    Yield the items of `cols` in order, endlessly, restarting after the last one.

    Args:
        cols: iterable of colors; it is copied once when iteration starts.

    Raises:
        ValueError: on the first `next()` if `cols` is empty (the previous
            implementation would loop forever without yielding anything).
    '''
    palette = list(cols)
    if not palette:
        raise ValueError('next_col() needs at least one color to cycle through')
    while True:
        yield from palette

In [10]:
def get_colCycle(transparency=0.3, colors=None):
    '''
    Build a list of 'rgba(r, g, b, a)' color strings from hex colors.

    Args:
        transparency: alpha value appended to every color (default 0.3).
        colors: iterable of '#RRGGBB' hex strings. Defaults to plotly's
            qualitative "Plotly" palette (`px.colors.qualitative.Plotly`).

    Returns:
        A list of strings such as 'rgba(99, 110, 250, 0.3)', one per input color.
    '''
    if colors is None:
        colors = px.colors.qualitative.Plotly

    def hex_rgba(hex_color, alpha):
        # '#RRGGBB' -> (r, g, b, alpha), with each channel parsed from two hex digits
        digits = hex_color.lstrip('#')
        return tuple(int(digits[i:i+2], 16) for i in (0, 2, 4)) + (alpha,)

    # str() of the tuple supplies the "(r, g, b, a)" part of the CSS-style color
    return ['rgba' + str(hex_rgba(c, transparency)) for c in colors]

In [11]:
def plot_joint_acc_by_round(df_list, strategy_list, baselines):
    '''
    Plot joint goal accuracy against labelled-data size, one line per strategy.

    Every DataFrame in `df_list` must provide the columns
    '# of labelled dialogue turns', 'test_joint_acc' and 'joint_acc_std'.
    The mean is drawn as a line with markers and mean +/- std as a shaded band.

    Args:
        df_list: per-strategy DataFrames (see above).
        strategy_list: legend names, indexed in parallel with `df_list`.
        baselines: 3-item sequence (full, last_turn, random). A `full` value of
            0 means "no full-data baseline" and suppresses the red line; the
            random baseline is accepted but not drawn.

    Returns:
        The populated plotly Figure (it is not shown here).
    '''
    full_baseline, lt_baseline, _rand_baseline = baselines

    fig = go.Figure()
    line_color = next_col(cols=get_colCycle())

    for i, df in enumerate(df_list):
        x = list(df["# of labelled dialogue turns"])
        mean, std = df["test_joint_acc"], df["joint_acc_std"]
        new_col = next(line_color)

        # Std band: upper edge left-to-right, then lower edge right-to-left,
        # which forms a closed outline that 'toself' fills.
        fig.add_trace(
            go.Scatter(
                x=x + x[::-1],
                y=list(mean + std) + list(mean - std)[::-1],
                fill='toself',
                fillcolor=new_col,
                line=dict(color='rgba(255,255,255,0)'),
                name=strategy_list[i],
                mode="lines",
                showlegend=False,
            )
        )

        fig.add_trace(
            go.Scatter(
                x=df["# of labelled dialogue turns"],
                y=mean,
                line=dict(color=new_col, width=3),
                mode="lines+markers",
                name=strategy_list[i],
            )
        )

    # Baselines are drawn once, after the loop; adding them per strategy
    # stacked identical lines and annotations on top of each other.
    fig.add_hline(
        y=lt_baseline,
        line_width=1,
        line_dash="dot",
        line_color="green",
        annotation_text="Last Turn(14.4%)",
        annotation_position="top left",
    )

    if full_baseline != 0:
        fig.add_hline(
            y=full_baseline,
            line_width=1,
            line_dash="dot",
            line_color="red",
            annotation_text="Full Data(100%)",
            annotation_position="top left",
        )

    # Category axis: the data sizes are shown as evenly spaced labels
    fig.update_xaxes(type='category')
    fig.update_layout(
        yaxis_range=[0.2, 0.6],
        xaxis_title="# of labelled dialogue turns",
        yaxis_title="Joint Goal Accuracy",
    )

    return fig

In [12]:
def make_merged_df_for_plotting(test_acc_csv_path):
    '''
    Aggregate an accuracy CSV into one row per round, ready for plotting.

    The CSV is grouped by 'round'; the result holds the per-round means under
    the original column names, the population stds (ddof=0) under
    'joint_acc_std' / 'slot_acc_std', and the data-size column added by
    `add_data_size_col`. For paths containing 'PPTOD' only 'test_joint_acc'
    is aggregated, after dividing it by 100.
    '''
    raw = pd.read_csv(test_acc_csv_path)

    if 'PPTOD' in test_acc_csv_path:
        # PPTOD values are divided by 100 before aggregating
        raw['test_joint_acc'] = raw['test_joint_acc'] / 100
        metrics = ['test_joint_acc']
    else:
        metrics = ['test_joint_acc', 'test_slot_acc']

    std_names = {
        'test_joint_acc': 'joint_acc_std',
        'test_slot_acc': 'slot_acc_std',
    }

    per_round = raw.groupby('round')[metrics]
    mean = per_round.mean().reset_index()
    std = (
        per_round.std(ddof=0)
        .reset_index()
        .rename(columns={metric: std_names[metric] for metric in metrics})
    )

    return add_data_size_col(pd.merge(mean, std, on='round'), test_acc_csv_path)

In [13]:
def prepare_subplot_inputs(k, mwz, model):
    '''
    Load the per-strategy plotting frames and the baselines for one panel.

    For each of the strategies random, least_confidence and max_entropy the CSV
    `./data/mwz{mwz}/{model}/{strategy}/k{k}/k{k}_test_acc.csv` is aggregated with
    `make_merged_df_for_plotting`; each path and frame is echoed as it is loaded.

    Returns:
        (df_list, strategy_list, baselines) where `strategy_list` is
        ['RS', 'LC', 'ME'] and `baselines` is [full, last_turn, random] for the
        (model, mwz) pair, or an empty list when the pair is unknown.
    '''
    # (model, mwz) -> [full, last_turn, random]
    known_baselines = {
        ('KAGE', 20): [0.5486, 0.5043, 0.4937],
        ('KAGE', 21): [0, 0.4912, 0.4898],
        ('PPTOD', 20): [0.5337, 0.4383, 0.4461],
        ('PPTOD', 21): [0.5710, 0.4594, 0.4721],
    }
    baselines = list(known_baselines.get((model, mwz), []))

    df_list = []
    for strategy in ('random', 'least_confidence', 'max_entropy'):
        acc_csv_path = f'./data/mwz{mwz}/{model}/{strategy}/k{k}/k{k}_test_acc.csv'
        print(f'============== {acc_csv_path} ==============')

        merged = make_merged_df_for_plotting(acc_csv_path)
        display(merged)
        df_list.append(merged)

    print(f'baselines: {baselines}')

    # labels are in the same order as the strategies loaded above
    return df_list, ['RS', 'LC', 'ME'], baselines

In [14]:
# Load the k=2000 inputs for every (model, dataset) pair; each value is the
# (df_list, strategy_list, baselines) tuple returned by prepare_subplot_inputs.
kage_mwz20_k2000 = prepare_subplot_inputs(2000, 20, 'KAGE')
kage_mwz21_k2000 = prepare_subplot_inputs(2000, 21, 'KAGE')
pptod_mwz20_k2000 = prepare_subplot_inputs(2000, 20, 'PPTOD')
pptod_mwz21_k2000 = prepare_subplot_inputs(2000, 21, 'PPTOD')



Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.425258,0.961175,0.007732,0.000843,2000
1,1,0.459237,0.966656,0.001696,7e-06,4000
2,2,0.488606,0.969459,0.008681,0.000228,6000
3,3,0.503663,0.971113,0.005154,0.000644,7888




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.376356,0.947755,0.072233,0.016339,2000
1,1,0.471175,0.967464,0.010377,0.001581,4000
2,2,0.487521,0.969556,0.006104,0.00085,6000
3,3,0.505629,0.971027,0.000746,8.3e-05,7888




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.40577,0.959173,0.022105,0.004563,2000
1,1,0.469479,0.967684,0.005184,0.000304,4000
2,2,0.50841,0.971095,0.003771,0.000258,6000
3,3,0.513384,0.971557,0.00504,0.000552,7888


baselines: [0.5486, 0.5043, 0.4937]


Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.403434,0.959932,0.006718,0.000866,2000
1,1,0.452022,0.966031,0.002375,0.000346,4000
2,2,0.459894,0.967585,0.004954,1.8e-05,6000
3,3,0.469802,0.968149,0.006447,0.000744,7888




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.40384,0.960799,0.005632,0.001746,2000
1,1,0.467359,0.967581,0.004004,0.000475,4000
2,2,0.478081,0.969099,0.006447,0.0003,6000
3,3,0.481338,0.969406,0.001968,2e-05,7888




Unnamed: 0,round,test_joint_acc,test_slot_acc,joint_acc_std,slot_acc_std,# of labelled dialogue turns
0,0,0.386537,0.95748,0.011943,0.002508,2000
1,1,0.457723,0.966246,0.004682,0.000167,4000
2,2,0.478963,0.968585,0.006786,0.000507,6000
3,3,0.499865,0.971259,0.010994,0.001018,7888


baselines: [0, 0.4912, 0.4898]


Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.382189,6.8e-05,2000
1,1,0.410201,0.013565,4000
2,2,0.430819,0.003527,6000
3,3,0.437127,0.008071,7888




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.354585,0.048833,2000
1,1,0.447165,0.010648,4000
2,2,0.451506,0.013497,6000
3,3,0.457949,0.003527,7888




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.375407,0.006172,2000
1,1,0.428039,0.008071,4000
2,2,0.45761,0.002374,6000
3,3,0.469208,0.007868,7888


baselines: [0.5337, 0.4383, 0.4461]


Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.388399,0.00848,2000
1,1,0.413094,6.8e-05,4000
2,2,0.463161,0.010109,6000
3,3,0.469607,0.001764,7888




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.394301,0.004206,2000
1,1,0.440706,0.003528,4000
2,2,0.463569,0.003596,6000
3,3,0.473745,0.003189,7888




Unnamed: 0,round,test_joint_acc,joint_acc_std,# of labelled dialogue turns
0,0,0.340977,0.010312,2000
1,1,0.440638,0.007395,4000
2,2,0.47924,0.006784,6000
3,3,0.48209,0.010041,7888


baselines: [0.571, 0.4594, 0.4721]


In [21]:
def plot_4_joint_acc_by_round(kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000, pptod_mwz21_k2000):
    '''
    Draw the 2 x 2 grid of joint-goal-accuracy curves (models in columns, datasets in rows).

    Each argument is a `(df_list, strategy_list, baselines)` tuple as returned
    by `prepare_subplot_inputs`: `df_list` holds one DataFrame per strategy with
    the columns '# of labelled dialogue turns', 'test_joint_acc' and
    'joint_acc_std'; `baselines` is `[full, last_turn, random]`.

    Layout: KAGE on MWZ 2.0 at (1, 1), PPTOD on MWZ 2.0 at (1, 2),
    KAGE on MWZ 2.1 at (2, 1), PPTOD on MWZ 2.1 at (2, 2).
    Only the first panel contributes legend entries; the strategy names are
    taken from the first argument. A full-data baseline of 0 means "not
    available" and its red line is skipped. The random baseline is not drawn.

    Returns:
        The populated plotly Figure (it is not shown here).
    '''
    strategy_list = kage_mwz20_k2000[1]

    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=(
            "KAGE-GPT2 on MultiWOZ 2.0",
            "PPTOD_base on MultiWOZ 2.0",
            "KAGE-GPT2 on MultiWOZ 2.1",
            "PPTOD_base on MultiWOZ 2.1",
        ),
        vertical_spacing=0.1,
        horizontal_spacing=0.05,
    )

    palette = get_colCycle()
    baseline_style = dict(
        line_width=1.5,
        line_dash="dot",
        annotation_position="top left",
        annotation_font_size=10,
        annotation_font_color="rgba(98, 85, 92, 0.8)",
    )

    def _add_panel(panel_inputs, row, col, show_legend=False):
        # Draw one subplot: std band + mean line per strategy, then the baselines.
        df_list, _, (full_baseline, lt_baseline, _rand_baseline) = panel_inputs
        # restart the palette so a strategy has the same color in every panel
        line_color = next_col(cols=palette)

        for i, df in enumerate(df_list):
            x = list(df["# of labelled dialogue turns"])
            mean, std = df["test_joint_acc"], df["joint_acc_std"]
            new_col = next(line_color)

            # Upper edge left-to-right, then lower edge right-to-left: a closed
            # outline that 'toself' fills.
            fig.add_trace(
                go.Scatter(
                    x=x + x[::-1],
                    y=list(mean + std) + list(mean - std)[::-1],
                    fill='toself',
                    fillcolor=new_col,
                    line=dict(color='rgba(255,255,255,0)'),
                    name=strategy_list[i],
                    legendgroup=strategy_list[i],
                    mode="lines",
                    showlegend=False,
                ),
                row=row, col=col
            )

            fig.add_trace(
                go.Scatter(
                    x=df["# of labelled dialogue turns"],
                    y=mean,
                    line=dict(color=new_col, width=3),
                    mode="lines+markers",
                    name=strategy_list[i],
                    legendgroup=strategy_list[i],
                    showlegend=show_legend,
                ),
                row=row, col=col
            )

        # Baselines are added once per panel; adding them inside the strategy
        # loop stacked identical lines and annotations on top of each other.
        fig.add_hline(
            y=lt_baseline,
            line_color="green",
            annotation_text="Last Turn(14.4%)",
            row=row, col=col,
            **baseline_style
        )
        if full_baseline != 0:
            fig.add_hline(
                y=full_baseline,
                line_color="red",
                annotation_text="Full Data(100%)",
                row=row, col=col,
                **baseline_style
            )

    _add_panel(kage_mwz20_k2000, row=1, col=1, show_legend=True)
    _add_panel(kage_mwz21_k2000, row=2, col=1)
    _add_panel(pptod_mwz20_k2000, row=1, col=2)
    _add_panel(pptod_mwz21_k2000, row=2, col=2)

    # Category axes everywhere; only the bottom row carries the x-axis title
    fig.update_xaxes(type="category", row=1, col=1)
    fig.update_xaxes(type="category", row=1, col=2)
    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=2, col=1)
    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=2, col=2)

    # Per-panel y ranges; only the left column carries the y-axis title
    fig.update_yaxes(title_text="Joint Goal Accuracy", range=[0.25, 0.6], row=1, col=1)
    fig.update_yaxes(range=[0.25, 0.55], row=1, col=2)
    fig.update_yaxes(title_text="Joint Goal Accuracy", range=[0.35, 0.55], row=2, col=1)
    fig.update_yaxes(range=[0.3, 0.6], row=2, col=2)

    fig.update_layout(
        height=500,
        margin_l=5, margin_t=20, margin_b=5, margin_r=5
    )

    return fig

In [25]:
# Build the 2x2 figure; the call is the cell's final expression, so its result is the cell output.
plot_4_joint_acc_by_round(kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000, pptod_mwz21_k2000)

# Optional PDF export of the same figure (disabled):
# k2000_4plot = plot_4_joint_acc_by_round(kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000, pptod_mwz21_k2000)
# pio.write_image(k2000_4plot, "./data/plot/k2000_4plots.pdf", width=700, height=450)



In [42]:
def row1col4_plot_4_joint_acc_by_round(kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000, pptod_mwz21_k2000=None):
    '''
    Draw the four joint-goal-accuracy panels side by side (1 row x 4 columns).

    Each argument is a `(df_list, strategy_list, baselines)` tuple as returned
    by `prepare_subplot_inputs`: `df_list` holds one DataFrame per strategy with
    the columns '# of labelled dialogue turns', 'test_joint_acc' and
    'joint_acc_std'; `baselines` is `[full, last_turn, random]`.

    Columns: 1 = KAGE on MWZ 2.0, 2 = PPTOD on MWZ 2.0, 3 = KAGE on MWZ 2.1,
    4 = PPTOD on MWZ 2.1. Only the first panel contributes legend entries; the
    strategy names are taken from the first argument. A full-data baseline of 0
    means "not available" and its red line is skipped. The random baseline is
    not drawn.

    Args:
        kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000: panel inputs.
        pptod_mwz21_k2000: inputs for the fourth panel. Optional only for
            backward compatibility: when omitted, the MWZ 2.0 PPTOD inputs are
            drawn again in the panel titled "MultiWOZ 2.1" (what this function
            used to do unconditionally) and a UserWarning is emitted.

    Returns:
        The populated plotly Figure (it is not shown here).
    '''
    import warnings  # local import so the cell does not depend on the top import cell

    if pptod_mwz21_k2000 is None:
        warnings.warn(
            'pptod_mwz21_k2000 was not provided: the "PPTOD+ME on MultiWOZ 2.1" '
            'panel shows the MultiWOZ 2.0 PPTOD results instead',
            UserWarning,
            stacklevel=2,
        )
        pptod_mwz21_k2000 = pptod_mwz20_k2000

    strategy_list = kage_mwz20_k2000[1]

    fig = make_subplots(
        rows=1, cols=4,
        subplot_titles=(
            "KAGE-GPT2+ME on MultiWOZ 2.0",
            "PPTOD+ME on MultiWOZ 2.0",
            "KAGE-GPT2+ME on MultiWOZ 2.1",
            "PPTOD+ME on MultiWOZ 2.1",
        ),
        vertical_spacing=0.1,
        horizontal_spacing=0.05,
    )

    palette = get_colCycle()
    baseline_style = dict(
        line_width=1.5,
        line_dash="dot",
        annotation_position="top left",
        annotation_font_size=10,
        annotation_font_color="rgba(98, 85, 92, 0.8)",
    )

    def _add_panel(panel_inputs, col, show_legend=False):
        # Draw one subplot: std band + mean line per strategy, then the baselines.
        df_list, _, (full_baseline, lt_baseline, _rand_baseline) = panel_inputs
        # restart the palette so a strategy has the same color in every panel
        line_color = next_col(cols=palette)

        for i, df in enumerate(df_list):
            x = list(df["# of labelled dialogue turns"])
            mean, std = df["test_joint_acc"], df["joint_acc_std"]
            new_col = next(line_color)

            # Upper edge left-to-right, then lower edge right-to-left: a closed
            # outline that 'toself' fills.
            fig.add_trace(
                go.Scatter(
                    x=x + x[::-1],
                    y=list(mean + std) + list(mean - std)[::-1],
                    fill='toself',
                    fillcolor=new_col,
                    line=dict(color='rgba(255,255,255,0)'),
                    name=strategy_list[i],
                    legendgroup=strategy_list[i],
                    mode="lines",
                    showlegend=False,
                ),
                row=1, col=col
            )

            fig.add_trace(
                go.Scatter(
                    x=df["# of labelled dialogue turns"],
                    y=mean,
                    line=dict(color=new_col, width=3),
                    mode="lines+markers",
                    name=strategy_list[i],
                    legendgroup=strategy_list[i],
                    showlegend=show_legend,
                ),
                row=1, col=col
            )

        # Baselines are added once per panel; adding them inside the strategy
        # loop stacked identical lines and annotations on top of each other.
        fig.add_hline(
            y=lt_baseline,
            line_color="green",
            annotation_text="Last Turn(14.4%)",
            row=1, col=col,
            **baseline_style
        )
        if full_baseline != 0:
            fig.add_hline(
                y=full_baseline,
                line_color="red",
                annotation_text="Full Data(100%)",
                row=1, col=col,
                **baseline_style
            )

    _add_panel(kage_mwz20_k2000, col=1, show_legend=True)
    _add_panel(kage_mwz21_k2000, col=3)
    _add_panel(pptod_mwz20_k2000, col=2)
    _add_panel(pptod_mwz21_k2000, col=4)

    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=1, col=1)
    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=1, col=2)
    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=1, col=3)
    fig.update_xaxes(title_text="# of labelled dialogue turns", type="category", row=1, col=4)

    # Per-panel y ranges; only the leftmost panel carries the y-axis title
    fig.update_yaxes(title_text="Joint Goal Accuracy", range=[0.25, 0.6], row=1, col=1)
    fig.update_yaxes(range=[0.25, 0.55], row=1, col=2)
    fig.update_yaxes(range=[0.35, 0.55], row=1, col=3)
    fig.update_yaxes(range=[0.2, 0.6], row=1, col=4)

    fig.update_layout(
        height=300,
        width=1400,
        margin_l=5, margin_t=20, margin_b=5, margin_r=5
    )

    return fig

In [46]:
# NOTE(review): only three result sets are passed here, so the panel titled
# "PPTOD+ME on MultiWOZ 2.1" is drawn from the MWZ 2.0 PPTOD inputs;
# pptod_mwz21_k2000 is loaded above but never reaches this figure.
plot1 = row1col4_plot_4_joint_acc_by_round(kage_mwz20_k2000, kage_mwz21_k2000, pptod_mwz20_k2000)
plot1

In [48]:
# Export the 1x4 figure to test2.pdf at 1400x300, the same size as the layout set when it was built.
pio.write_image(plot1, "test2.pdf", width=1400, height=300)