In [3]:
import pandas as pd

def get_top_policies(df, k):
    """Return the top-k policies ranked by mean late validation accuracy.

    Each trial's ``mean_late_val_acc`` is averaged over its unique
    ``(trial_no, sample_no)`` rows. The gain of every trial over trial 0
    (used as the baseline) is reported in percentage points in the
    ``expected_accuracy_increase(%)`` column.

    Args:
        df (pandas.DataFrame): results table containing at least the columns
            ``trial_no``, ``sample_no``, ``mean_late_val_acc`` and the policy
            columns ``{A..E}_aug{1,2}_{type,magnitude}``.
        k (int): number of best trials to keep.

    Returns:
        pandas.DataFrame: at most ``k`` rows, one per trial, sorted by
        ``mean_late_val_acc`` in descending order and restricted to the
        trial number, the policy columns and the two accuracy columns.

    Raises:
        IndexError: if ``df`` contains no row with ``trial_no == 0``.
    """
    # Pick the accuracy column *before* aggregating. Averaging the whole
    # group would also try to average the string-valued ``*_type`` columns,
    # which raises TypeError on pandas >= 2.0 (numeric_only defaults to False).
    trial_avg_val_acc_df = (
        df.drop_duplicates(["trial_no", "sample_no"])
        .groupby("trial_no")["mean_late_val_acc"]
        .mean()
        .reset_index()
    )

    # Replace the per-row accuracy with the per-trial average.
    x_df = pd.merge(
        df.drop(columns=["mean_late_val_acc"]),
        trial_avg_val_acc_df,
        on="trial_no",
        how="left",
    )

    x_df = x_df.sort_values("mean_late_val_acc", ascending=False)

    baseline_rows = x_df.loc[x_df["trial_no"] == 0, "mean_late_val_acc"]
    if baseline_rows.empty:
        # Same exception type the bare ``.values[0]`` lookup would raise,
        # but with a message that names the actual problem.
        raise IndexError(
            "df has no rows with trial_no == 0 to use as the baseline"
        )
    baseline_val_acc = baseline_rows.values[0]

    x_df["expected_accuracy_increase(%)"] = (
        x_df["mean_late_val_acc"] - baseline_val_acc
    ) * 100

    # One row per trial, best trials first.
    top_df = (
        x_df.drop_duplicates(["trial_no"])
        .sort_values("mean_late_val_acc", ascending=False)
        .head(k)
    )

    # Sub-policies A..E, each with two augmentations (type + magnitude).
    policy_columns = [
        f"{sub_policy}_aug{aug_no}_{field}"
        for sub_policy in "ABCDE"
        for aug_no in (1, 2)
        for field in ("type", "magnitude")
    ]
    select = (
        ["trial_no"]
        + policy_columns
        + ["mean_late_val_acc", "expected_accuracy_increase(%)"]
    )

    return top_df[select]

# Rank the trials stored in the results CSV and show up to the 105 best ones.
get_top_policies(
    df=pd.read_csv('./autoaugment_results.csv'),
    k=105,
)

Unnamed: 0,trial_no,A_aug1_type,A_aug1_magnitude,A_aug2_type,A_aug2_magnitude,B_aug1_type,B_aug1_magnitude,B_aug2_type,B_aug2_magnitude,C_aug1_type,...,D_aug1_type,D_aug1_magnitude,D_aug2_type,D_aug2_magnitude,E_aug1_type,E_aug1_magnitude,E_aug2_type,E_aug2_magnitude,mean_late_val_acc,expected_accuracy_increase(%)
79,3,coarse-salt-pepper,0.902,rotate,0.97,horizontal-flip,0.171,additive-gaussian-noise,0.751,coarse-salt-pepper,...,elastic-transform,0.653,coarse-salt-pepper,0.995,add-to-hue-and-saturation,0.414,gamma-contrast,0.624,0.73,5.9
438,21,histogram-equalize,0.934,coarse-salt-pepper,0.547,gaussian-blur,0.729,additive-gaussian-noise,0.643,fog,...,fog,0.093,add-to-hue-and-saturation,0.846,invert,0.493,invert,0.334,0.715,4.4
80,4,additive-gaussian-noise,0.675,emboss,0.778,elastic-transform,0.663,crop,0.623,horizontal-flip,...,gaussian-blur,0.451,crop,0.442,grayscale,0.359,brighten,0.689,0.702,3.1
121,6,add-to-hue-and-saturation,0.704,emboss,0.433,fog,0.396,perspective-transform,0.639,perspective-transform,...,rotate,0.767,coarse-dropout,0.675,sharpen,0.313,grayscale,0.588,0.699,2.8
325,16,gamma-contrast,0.91,gamma-contrast,0.845,translate-y,0.888,clouds,0.993,coarse-salt-pepper,...,gaussian-blur,0.498,fog,0.856,coarse-dropout,0.456,coarse-dropout,0.068,0.694,2.3
381,19,crop,0.96,crop,0.842,horizontal-flip,0.494,elastic-transform,0.793,gaussian-blur,...,perspective-transform,0.22,fog,0.864,coarse-dropout,0.496,horizontal-flip,0.559,0.694,2.3
522,26,additive-gaussian-noise,0.999,gamma-contrast,0.057,shear,0.877,gaussian-blur,0.68,coarse-dropout,...,rotate,0.598,histogram-equalize,0.79,elastic-transform,0.456,dropout,0.538,0.688,1.7
292,14,gaussian-blur,1.0,perspective-transform,0.75,shear,0.191,rotate,0.616,translate-x,...,dropout,0.494,translate-x,0.654,additive-gaussian-noise,0.41,dropout,0.595,0.687,1.6
159,7,horizontal-flip,0.533,translate-y,0.395,coarse-salt-pepper,0.475,gamma-contrast,0.716,emboss,...,rotate,0.092,additive-gaussian-noise,0.552,crop,0.969,additive-gaussian-noise,0.221,0.687,1.6
474,23,emboss,0.834,horizontal-flip,0.615,perspective-transform,0.715,translate-y,0.932,sharpen,...,translate-x,0.064,translate-y,0.987,sharpen,0.496,shear,0.301,0.684,1.3


In [12]:
# Write the 20 best trials to './top_df' as CSV, without the index column.
get_top_policies(
    df=pd.read_csv('./autoaugment_results.csv'),
    k=20,
).to_csv('./top_df', index=False)