In [11]:
import pandas as pd

def get_top_policies(df, k, baseline_trial_no=0):
    """Return the top-k augmentation policies ranked by validation accuracy.

    Each trial's score is the mean of ``mean_late_val_acc`` over its distinct
    ``(trial_no, sample_no)`` rows. Trials are ranked by that score, and each
    one is compared against the baseline trial to report the expected
    accuracy increase in percentage points.

    Args:
        df (pandas.DataFrame): raw results; must contain ``trial_no``,
            ``sample_no``, ``mean_late_val_acc`` and the policy columns
            ``{A..E}_aug{1,2}_{type,magnitude}``.
        k (int): number of policies (trials) to return.
        baseline_trial_no (int): trial whose averaged accuracy is the
            reference for ``expected_accuracy_increase(%)``. Defaults to 0.

    Returns:
        pandas.DataFrame: at most ``k`` rows, one per trial, sorted by
        ``mean_late_val_acc`` descending, with columns ``trial_no``, the 20
        policy columns, ``mean_late_val_acc`` and
        ``expected_accuracy_increase(%)``.

    Raises:
        IndexError: if ``df`` has no rows for ``baseline_trial_no``.
    """
    # Select the accuracy column *before* aggregating: averaging the whole
    # frame would also hit the string policy columns, which raises TypeError
    # on pandas >= 2.0 (and was silently dropped on older versions).
    trial_avg_val_acc_df = (
        df.drop_duplicates(["trial_no", "sample_no"])
        .groupby("trial_no")["mean_late_val_acc"]
        .mean()
        .reset_index()
    )

    baseline = trial_avg_val_acc_df.loc[
        trial_avg_val_acc_df["trial_no"] == baseline_trial_no, "mean_late_val_acc"
    ]
    if baseline.empty:
        raise IndexError(
            f"no rows with trial_no == {baseline_trial_no!r}; "
            "cannot compute the baseline accuracy"
        )
    baseline_val_acc = baseline.iloc[0]

    # Replace each row's own accuracy with its trial-level average.
    x_df = pd.merge(
        df.drop(columns=["mean_late_val_acc"]),
        trial_avg_val_acc_df,
        on="trial_no",
        how="left",
    )
    x_df["expected_accuracy_increase(%)"] = (
        x_df["mean_late_val_acc"] - baseline_val_acc
    ) * 100

    # Stable sort keeps rows with equal scores in input order, so the row
    # drop_duplicates keeps for each trial is deterministic (its first row).
    top_df = (
        x_df.sort_values("mean_late_val_acc", ascending=False, kind="mergesort")
        .drop_duplicates(["trial_no"])
        .head(k)
    )

    # A_aug1_type, A_aug1_magnitude, A_aug2_type, A_aug2_magnitude, B_aug1_type, ...
    policy_cols = [
        f"{prefix}_aug{aug_idx}_{field}"
        for prefix in "ABCDE"
        for aug_idx in (1, 2)
        for field in ("type", "magnitude")
    ]
    selected_cols = [
        "trial_no",
        *policy_cols,
        "mean_late_val_acc",
        "expected_accuracy_increase(%)",
    ]
    return top_df[selected_cols]

# Display the ranked policies for this results file as the cell's output.
get_top_policies(
    df=pd.read_csv('./autoaugment_results.csv'),
    k=105,  # NOTE(review): presumably chosen to exceed the trial count so every policy is listed — confirm
)

Unnamed: 0,trial_no,A_aug1_type,A_aug1_magnitude,A_aug2_type,A_aug2_magnitude,B_aug1_type,B_aug1_magnitude,B_aug2_type,B_aug2_magnitude,C_aug1_type,...,D_aug1_type,D_aug1_magnitude,D_aug2_type,D_aug2_magnitude,E_aug1_type,E_aug1_magnitude,E_aug2_type,E_aug2_magnitude,mean_late_val_acc,expected_accuracy_increase(%)
602,30,super-pixels,0.813,perspective-transform,0.903,emboss,0.359,emboss,0.996,grayscale,...,translate-y,0.838,invert,0.169,sharpen,0.163,add-to-hue-and-saturation,0.423,0.671,10.7
843,42,translate-y,0.714,clouds,0.073,translate-x,0.972,gaussian-blur,0.008,sharpen,...,gamma-contrast,0.049,clouds,0.888,sharpen,0.15,shear,0.258,0.667,10.3
122,6,add-to-hue-and-saturation,0.704,emboss,0.433,fog,0.396,perspective-transform,0.639,perspective-transform,...,rotate,0.767,coarse-dropout,0.675,sharpen,0.313,grayscale,0.588,0.663,9.9
421,21,perspective-transform,0.802,dropout,0.27,brighten,0.991,emboss,0.713,add-to-hue-and-saturation,...,rotate,0.896,clouds,0.815,sharpen,0.184,translate-x,0.673,0.649,8.5
837,41,brighten,0.018,elastic-transform,0.096,elastic-transform,0.329,clouds,0.105,perspective-transform,...,rotate,0.626,emboss,0.131,sharpen,0.126,translate-y,0.311,0.647,8.3
920,46,coarse-dropout,0.626,translate-y,0.265,vertical-flip,0.975,rotate,0.595,emboss,...,crop,0.745,dropout,0.411,sharpen,0.498,rotate,0.48,0.629,6.5
283,14,translate-x,0.99,clouds,0.169,emboss,0.083,brighten,0.97,sharpen,...,grayscale,0.474,coarse-salt-pepper,0.987,shear,0.118,sharpen,0.598,0.627,6.3
785,39,rotate,0.81,coarse-dropout,0.002,vertical-flip,0.832,add-to-hue-and-saturation,0.704,shear,...,fog,0.418,elastic-transform,0.073,sharpen,0.131,invert,0.905,0.626,6.2
979,48,shear,0.798,brighten,0.127,rotate,0.992,gaussian-blur,0.452,brighten,...,grayscale,0.303,translate-x,0.682,additive-gaussian-noise,0.11,elastic-transform,0.555,0.624,6.0
447,22,emboss,0.826,add-to-hue-and-saturation,0.146,elastic-transform,0.496,sharpen,0.984,elastic-transform,...,sharpen,0.823,translate-x,0.864,sharpen,0.006,coarse-salt-pepper,0.461,0.62,5.6


In [12]:
# Export the 20 best policies. NOTE(review): the output path has no ".csv"
# extension; kept verbatim because downstream code may read this exact path.
top_20_policies = get_top_policies(df=pd.read_csv('./autoaugment_results.csv'), k=20)
top_20_policies.to_csv('./top_df', index=False)