In [3]:
# Notebook for building the LaTeX results tables of uncertainty-estimation metrics (GLUE: MRPC, CoLA, SST-2)

In [4]:
import json
import os
import warnings

import numpy as np
import pandas as pd

In [5]:
# Default table configuration read by get_df; later cells overwrite these globals.
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/electra-metric/'

# One tuple per UE setup: (results sub-directory, row label, dropout layers).
_ue_setups = [
    ('last', 'MC', 'last'),
    ('all', 'MC', 'all'),
    ('dpp', 'DPP_on_masks', 'last'),
    ('dpp_with_ood', 'DPP_with_ood', 'last'),
]
ues = [setup[0] for setup in _ue_setups]
ues_names = [setup[1] for setup in _ue_setups]
ues_layers = [setup[2] for setup in _ue_setups]
del _ue_setups

metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = list(metrics)

# (dataset directory, table header label) pairs.
_datasets = [('mrpc', 'MRPC'), ('cola', 'CoLA'), ('sst2', 'SST2 (10%)')]
types = [key for key, label in _datasets]
types_names = [label for key, label in _datasets]
del _datasets

ue_methods = ['max_prob', 'bald', 'sampled_max_prob', 'variance']
# Metrics shown in percent (x100) and metrics shown as a difference from max_prob.
perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

In [6]:
def _notebook_setting(name, value):
    """Return ``value`` unless it is None, else the notebook-level global ``name``.

    Keeps get_df compatible with cells that configure it by assigning module
    globals (``ues``, ``metrics``, ...) before calling it.
    """
    return globals()[name] if value is None else value


def _read_ue_frame(ue_dir, ue_type, metrics, ue_methods, perc_metrics,
                   diff_metrics, baselines_dict, record_baselines):
    """Read one ue directory into a formatted frame (index ue_methods, columns metrics).

    Each cell is ``mean`` and ``std`` (2 decimals) joined by a LaTeX plus-minus.
    When ``record_baselines`` is true, this ue's own max_prob mean is written to
    ``baselines_dict`` before differences are taken.
    """
    columns = {}
    for metric in metrics:
        with open(os.path.join(ue_dir, f'metrics_{metric}.json'), 'r') as f:
            scores = json.load(f)
        scale = 100 if metric in perc_metrics else 1
        # Mean/std over the values stored for each method (presumably one per seed).
        stats = {}
        for ue_method in ue_methods:
            values = list(scores[ue_method].values())
            stats[ue_method] = (np.mean(values) * scale, np.std(values) * scale)

        baseline_key = ue_type + metric + 'max_prob'
        if record_baselines:
            baselines_dict[baseline_key] = stats['max_prob'][0]

        cells = []
        for ue_method in ue_methods:
            mean, std = stats[ue_method]
            if metric in diff_metrics and ue_method != 'max_prob':
                # Non-baseline methods are reported relative to the max_prob baseline.
                mean = mean - baselines_dict[baseline_key]
            cells.append(f'{mean:.2f}$\\pm${std:.2f}')
        columns[metric] = cells
    return pd.DataFrame(columns, index=ue_methods)


def get_df(raw_path, reg_type, baselines_dict=None, baselines=None, *,
           ues=None, ues_names=None, metrics=None, types=None, types_names=None,
           ue_methods=None, perc_metrics=None, diff_metrics=None):
    """Build the results table from ``metrics_<metric>.json`` files.

    For every ``ue`` and dataset ``ue_type`` this reads
    ``<raw_path>/<ue_type>/<ue>/metrics_<metric>.json``. Each file maps a UE
    method name to a dict of scores; the table shows their mean and std
    (x100 for ``perc_metrics``). For ``diff_metrics`` every method except
    ``max_prob`` is shown as the difference from the max_prob baseline.

    Parameters
    ----------
    raw_path : str
        Root directory of the metric files.
    reg_type : str
        Value written into the "Reg. Type" column of every row.
    baselines_dict : dict, optional
        Baseline means keyed by ``ue_type + metric + 'max_prob'``. A fresh
        dict is created when omitted.
    baselines : object, optional
        If None, baselines come from each ue's own max_prob scores and are
        recorded into ``baselines_dict``. Otherwise they are read from
        ``baselines_dict``.
    ues, ues_names, metrics, types, types_names, ue_methods, perc_metrics, diff_metrics
        Table configuration. Each defaults to the notebook global of the
        same name, so cells that set globals keep working.

    Returns
    -------
    (pandas.DataFrame, dict)
        The table has columns "Method", "Reg. Type" and "UE Score", followed
        by one ``(dataset name, metric)`` column per dataset/metric. The single
        max_prob Baseline row comes last. The baselines dict is returned as
        well.

    Raises
    ------
    ValueError
        If ``ue_methods`` lacks 'max_prob', ``ues_names`` and ``ues`` differ
        in length, or ``types_names`` is shorter than ``types``.
    """
    if baselines_dict is None:
        # Fresh dict per call: a shared ``{}`` default would leak baselines between calls.
        baselines_dict = {}
    ues = _notebook_setting('ues', ues)
    ues_names = _notebook_setting('ues_names', ues_names)
    metrics = _notebook_setting('metrics', metrics)
    types = _notebook_setting('types', types)
    types_names = _notebook_setting('types_names', types_names)
    ue_methods = _notebook_setting('ue_methods', ue_methods)
    perc_metrics = _notebook_setting('perc_metrics', perc_metrics)
    diff_metrics = _notebook_setting('diff_metrics', diff_metrics)

    if 'max_prob' not in ue_methods:
        raise ValueError("ue_methods must include 'max_prob' (the baseline)")
    if len(ues_names) != len(ues):
        raise ValueError(f'ues_names has {len(ues_names)} entries but ues has {len(ues)}')
    if len(types_names) < len(types):
        raise ValueError(f'types_names has {len(types_names)} entries but types has {len(types)}')

    dataset_dfs = []
    for ue_type, type_name in zip(types, types_names):
        ue_frames = [
            _read_ue_frame(os.path.join(raw_path, ue_type, ue), ue_type, metrics,
                           ue_methods, perc_metrics, diff_metrics, baselines_dict,
                           record_baselines=baselines is None)
            for ue in ues
        ]
        dataset_df = pd.concat(ue_frames)
        dataset_df.columns = pd.MultiIndex.from_tuples([(type_name, metric) for metric in metrics])
        dataset_dfs.append(dataset_df)
    raw_df = pd.concat(dataset_dfs, axis=1)

    # Every ue contributes its own max_prob row; keep one (the first ue's) as
    # the Baseline row and move it to the bottom of the table.
    is_baseline = raw_df.index == 'max_prob'
    baseline_rows = raw_df[is_baseline]
    if len(baseline_rows.drop_duplicates()) > 1:
        warnings.warn('max_prob results differ between ues; using the one from '
                      f'{ues[0]!r} as the Baseline row')
    table = pd.concat([raw_df[~is_baseline], baseline_rows.iloc[[0]]])

    methods = [name for name in ues_names for _ in range(len(ue_methods) - 1)]
    methods.append('Baseline')
    names_df = pd.DataFrame(
        {'Method': methods,
         'Reg. Type': [reg_type] * len(table),
         'UE Score': list(table.index)},
        index=table.index,
    )
    return pd.concat([names_df, table], axis=1), baselines_dict

# Final tables

In [14]:
# MSD (mixup) results for ELECTRA on MRPC, CoLA and SST-2.
raw_path = '/home/jovyan/uncertainty-estimation/workdir/run_calc_ues_metrics/mixup_electra/'

ues = ['msd/all', 'msd/last']
ues_names = ['MSD|all', 'MSD|last']
ues_layers = ['all', 'last']  # one entry per ue
metrics = ['rcc-auc', 'rpp']
metric_names = ['rcc-auc', 'rpp']
types = ['mrpc', 'cola', 'sst2']
types_names = ['MRPC', 'CoLA', 'SST2 (10%)']
ue_methods = ['max_prob', 'mixup']
perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'MSD', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'sampled\_max\_prob' must be replaced before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{lllllllll}
\toprule
  Method & Reg. Type & UE Score & (MRPC, rcc-auc) &   (MRPC, rpp) & (CoLA, rcc-auc) &   (CoLA, rpp) & (SST2 (10\%), rcc-auc) & (SST2 (10\%), rpp) \\
\midrule
 MSD|all &       MSD &    mixup &  12.54$\pm$1.03 & 1.66$\pm$0.14 &  41.25$\pm$2.00 & 2.06$\pm$0.06 &        13.80$\pm$0.82 &     0.96$\pm$0.06 \\
MSD|last &       MSD &    mixup &  12.79$\pm$1.07 & 1.70$\pm$0.15 &  42.12$\pm$2.12 & 2.12$\pm$0.06 &        13.93$\pm$0.80 &     0.97$\pm$0.06 \\
Baseline &       MSD & MP &  12.68$\pm$2.37 & 1.67$\pm$0.29 &  53.57$\pm$5.33 & 2.49$\pm$0.09 &        15.53$\pm$1.87 &     1.00$\pm$0.12 \\
\bottomrule
\end{tabular}



In [51]:
# CoLA: Mahalanobis-distance UE (results in 'maha_mc') from the metric_opt_electra_3hyp run.
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/metric_opt_electra_3hyp/'

ues = ['maha_mc']
ues_names = ['MD']
ues_layers = ['-']

metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = ['rejection-curve-auc', 'rcc-auc', 'rpp']
types = ['cola']
types_names = ['CoLA']  # same header label as the other CoLA tables
ue_methods = ['max_prob', 'mahalanobis_distance', 'sampled_mahalanobis_distance']

perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'metric', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'mahalanobis\_distance' -> 'MD' must run before 'sampled\_MD' -> 'SMD', and
# 'sampled\_max\_prob' before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'sampled\_MD', 'SMD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{llllll}
\toprule
  Method & Reg. Type &                     UE Score & (cola, rejection-curve-auc) & (cola, rcc-auc) &   (cola, rpp) \\
\midrule
      MD &    metric &         MD &               0.37$\pm$0.12 &  46.30$\pm$2.97 & 2.23$\pm$0.12 \\
      MD &    metric & SMD &              -0.13$\pm$0.12 &  55.56$\pm$2.39 & 2.76$\pm$0.11 \\
Baseline &    metric &                     MP &              91.81$\pm$0.13 &  56.01$\pm$3.06 & 2.82$\pm$0.11 \\
\bottomrule
\end{tabular}



In [52]:
# CoLA: 'MD SN (Ours)' Mahalanobis-distance variant (results in 'maha_sn_mc'), metric_opt_electra_3hyp run.
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/metric_opt_electra_3hyp/'

ues = ['maha_sn_mc']
ues_names = ['MD SN (Ours)']
ues_layers = ['-']

metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = ['rejection-curve-auc', 'rcc-auc', 'rpp']
types = ['cola']
types_names = ['CoLA']
ue_methods = ['max_prob', 'mahalanobis_distance', 'sampled_mahalanobis_distance']

perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'metric', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'mahalanobis\_distance' -> 'MD' must run before 'sampled\_MD' -> 'SMD', and
# 'sampled\_max\_prob' before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'sampled\_MD', 'SMD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{llllll}
\toprule
      Method & Reg. Type &                     UE Score & (CoLA, rejection-curve-auc) &  (CoLA, rcc-auc) &   (CoLA, rpp) \\
\midrule
MD SN (Ours) &    metric &         MD &              -1.06$\pm$0.21 &   70.29$\pm$3.13 & 3.44$\pm$0.18 \\
MD SN (Ours) &    metric & SMD &              -3.71$\pm$0.38 & 175.69$\pm$11.57 & 6.18$\pm$0.34 \\
    Baseline &    metric &                     MP &              89.46$\pm$0.41 & 148.20$\pm$12.91 & 5.18$\pm$0.47 \\
\bottomrule
\end{tabular}



In [32]:
# MRPC, Mahalanobis distance with new parameters (metric_opt_electra_fix run).
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/metric_opt_electra_fix/'

ues = ['maha_mc']
ues_names = ['MD']
ues_layers = ['-']

metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = ['rejection-curve-auc', 'rcc-auc', 'rpp']
types = ['mrpc']
types_names = ['MRPC']
ue_methods = ['max_prob', 'mahalanobis_distance', 'sampled_mahalanobis_distance']

perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'metric', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'mahalanobis\_distance' -> 'MD' must run before 'sampled\_MD' -> 'SMD', and
# 'sampled\_max\_prob' before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'sampled\_MD', 'SMD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{llllll}
\toprule
  Method & Reg. Type &                     UE Score & (MRPC, rejection-curve-auc) & (MRPC, rcc-auc) &   (MRPC, rpp) \\
\midrule
      MD &    metric &         MD &               0.23$\pm$0.36 &  18.38$\pm$3.13 & 2.29$\pm$0.32 \\
      MD &    metric & SMD &              -0.61$\pm$1.02 & 31.14$\pm$11.04 & 3.14$\pm$0.98 \\
Baseline &    metric &                     MP &              91.46$\pm$0.35 &  27.26$\pm$7.23 & 3.14$\pm$0.34 \\
\bottomrule
\end{tabular}



In [33]:
# MRPC, Mahalanobis distance with new parameters (metric_opt_electra_fix6 run).
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/metric_opt_electra_fix6/'

ues = ['maha_mc']
ues_names = ['MD']
ues_layers = ['-']

metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = ['rejection-curve-auc', 'rcc-auc', 'rpp']
types = ['mrpc']
types_names = ['MRPC']
ue_methods = ['max_prob', 'mahalanobis_distance', 'sampled_mahalanobis_distance']

perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'metric', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'mahalanobis\_distance' -> 'MD' must run before 'sampled\_MD' -> 'SMD', and
# 'sampled\_max\_prob' before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'sampled\_MD', 'SMD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{llllll}
\toprule
  Method & Reg. Type &                     UE Score & (MRPC, rejection-curve-auc) & (MRPC, rcc-auc) &   (MRPC, rpp) \\
\midrule
      MD &    metric &         MD &               0.58$\pm$0.40 &  15.95$\pm$2.84 & 2.03$\pm$0.40 \\
      MD &    metric & SMD &               0.22$\pm$0.37 &  23.27$\pm$4.29 & 2.43$\pm$0.36 \\
Baseline &    metric &                     MP &              92.16$\pm$0.52 &  21.00$\pm$4.21 & 2.57$\pm$0.46 \\
\bottomrule
\end{tabular}



In [43]:
# New MRPC table: MC dropout vs. DPP-based variants (metric_opt_electra_fix6 run).
raw_path = '/notebook/ue/uncertainty-estimation/workdir/run_calc_ues_metrics/metric_opt_electra_fix6/'

ues = ['all', 'dpp', 'dpp_with_ood']
ues_names = ['MC', 'DDPP (+DPP) (Ours)', 'DDPP (+OOD) (Ours)']
ues_layers = ['all', 'last', 'last']
metrics = ['rejection-curve-auc', 'rcc-auc', 'rpp']
metric_names = ['rejection-curve-auc', 'rcc-auc', 'rpp']
types = ['mrpc']
types_names = ['MRPC']
ue_methods = ['max_prob', 'bald', 'sampled_max_prob', 'variance']
perc_metrics = ['rejection-curve-auc', 'rpp']
diff_metrics = ['rejection-curve-auc', 'roc-auc']

# max_prob baselines copied from the main results table, keyed by
# dataset + metric + 'max_prob'. rejection-curve-auc and rpp are x100 to match
# the percent scaling get_df applies to perc_metrics.
baselines_dict = {'mrpcrejection-curve-aucmax_prob': 0.9208435457516339 * 100,
                  'mrpcrcc-aucmax_prob': 23.279293481630972,
                  'mrpcrppmax_prob': 0.026788574907087016 * 100,
                  'colarejection-curve-aucmax_prob': 0.9203619367209971 * 100,
                  'colarcc-aucmax_prob': 59.03726591032054,
                  'colarppmax_prob': 0.02631936969193335 * 100,
                  'sst2rejection-curve-aucmax_prob': 0.9379778287461774 * 100,
                  'sst2rcc-aucmax_prob': 18.067838464295736,
                  'sst2rppmax_prob': 0.012349462026204303 * 100}
raw_df, baselines_dict = get_df(raw_path, 'metric', baselines_dict, True)

# New frame rather than an in-place reset through an alias, so raw_df keeps its index.
miscl_df = raw_df.reset_index(drop=True)

# (pattern, replacement) pairs applied in order. Raw strings are needed because
# the LaTeX escapes (\$, \_, \pm) are invalid escape sequences in plain literals.
# 'sampled\_max\_prob' must be replaced before 'max\_prob'.
latex_replacements = [
    (r'\$\textbackslash pm\$', r'$\pm$'),
    ('variance', 'PV'),
    (r'var\_ratio', 'VR'),
    (r'sampled\_entropy', 'SE'),
    (r'sampled\_max\_prob', 'SMP'),
    (r'mahalanobis\_distance', 'MD'),
    (r'max\_prob', 'MP'),
    ('bald', 'BALD'),
]
latex_table = miscl_df.to_latex(bold_rows=False, index=False)
for pattern, replacement in latex_replacements:
    latex_table = latex_table.replace(pattern, replacement)
print(latex_table)

\begin{tabular}{llllll}
\toprule
            Method & Reg. Type &         UE Score & (MRPC, rejection-curve-auc) & (MRPC, rcc-auc) &   (MRPC, rpp) \\
\midrule
                MC &    metric &             BALD &               0.22$\pm$0.29 &  22.25$\pm$3.29 & 2.42$\pm$0.27 \\
                MC &    metric & SMP &               0.44$\pm$0.37 &  20.17$\pm$3.83 & 2.22$\pm$0.33 \\
                MC &    metric &         PV &               0.27$\pm$0.30 &  21.76$\pm$3.53 & 2.38$\pm$0.29 \\
DDPP (+DPP) (Ours) &    metric &             BALD &              -0.21$\pm$0.65 &  23.46$\pm$6.14 & 2.83$\pm$0.66 \\
DDPP (+DPP) (Ours) &    metric & SMP &              -0.01$\pm$0.50 &  21.88$\pm$3.63 & 2.67$\pm$0.45 \\
DDPP (+DPP) (Ours) &    metric &         PV &              -0.08$\pm$0.74 &  22.22$\pm$6.44 & 2.70$\pm$0.74 \\
DDPP (+OOD) (Ours) &    metric &             BALD &              -0.38$\pm$0.55 &  24.67$\pm$6.18 & 3.05$\pm$0.53 \\
DDPP (+OOD) (Ours) &    metric & SMP &               0.05$\p

In [15]:
import torch

In [30]:
# Toy token labels for the one-hot masking experiment; -100 marks padded positions.
num_labels = 3
labels = torch.tensor([-100, 0, 1, 2, -100])

In [31]:
# Boolean mask of padded positions, then overwrite those labels with class 0 so
# they are valid one_hot() inputs.
# NOTE: this mutates `labels` in place. Re-running this cell on its own gives an
# all-False padding_ids, because no -100 values remain.
padding_ids = labels.eq(-100)
labels.masked_fill_(padding_ids, 0)
labels

tensor([0, 0, 1, 2, 0])

In [32]:
# One row per label with num_labels columns; padded positions currently encode class 0.
one_hot = torch.nn.functional.one_hot(labels, num_labels)
one_hot

tensor([[1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0]])

In [33]:
# Clear the whole one-hot row of every padded position so it encodes no class.
one_hot[padding_ids, :] = 0

In [34]:
# Padded positions are now all-zero rows; the other rows keep their one-hot vector.
one_hot

tensor([[0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 0]])