In [18]:
import pandas as pd

def get_top_policies(df, k):
        """Prints and returns top-k policies

        Policies are ordered by their expected accuracy increas
        Args:
            k (int) top-k
        Returns
            pandas.DataFrame: top-k policies as dataframe
        """
        trial_avg_val_acc_df = (
            df.drop_duplicates(["trial_no", "sample_no"])
            .groupby("trial_no")
            .mean()["mean_late_val_acc"]
            .reset_index()
        )[["trial_no", "mean_late_val_acc"]]

        x_df = pd.merge(
            df.drop(columns=["mean_late_val_acc"]),
            trial_avg_val_acc_df,
            on="trial_no",
            how="left",
        )

        x_df = x_df.sort_values("mean_late_val_acc", ascending=False)

        baseline_val_acc = x_df[x_df["trial_no"] == 0]["mean_late_val_acc"].values[0]

        x_df["expected_accuracy_increase(%)"] = (
            x_df["mean_late_val_acc"] - baseline_val_acc
        )*100

        top_df = x_df.drop_duplicates(["trial_no"]).sort_values(
            "mean_late_val_acc", ascending=False
        )[:k]

        SELECT = [
            "trial_no",
            'A_aug1_type', 'A_aug1_magnitude', 'A_aug2_type', 'A_aug2_magnitude',
            'B_aug1_type', 'B_aug1_magnitude', 'B_aug2_type', 'B_aug2_magnitude',
            'C_aug1_type', 'C_aug1_magnitude', 'C_aug2_type', 'C_aug2_magnitude',
            'D_aug1_type', 'D_aug1_magnitude', 'D_aug2_type', 'D_aug2_magnitude',
            'E_aug1_type', 'E_aug1_magnitude', 'E_aug2_type', 'E_aug2_magnitude',
            "mean_late_val_acc", "expected_accuracy_increase(%)"
        ]
        top_df = top_df[SELECT]

#         print(f"top-{k} policies:", k)
#         print(top_df)

        return top_df

get_top_policies(pd.read_csv('./autoaugment_results.csv'), 105)

Unnamed: 0,trial_no,A_aug1_type,A_aug1_magnitude,A_aug2_type,A_aug2_magnitude,B_aug1_type,B_aug1_magnitude,B_aug2_type,B_aug2_magnitude,C_aug1_type,...,D_aug1_type,D_aug1_magnitude,D_aug2_type,D_aug2_magnitude,E_aug1_type,E_aug1_magnitude,E_aug2_type,E_aug2_magnitude,mean_late_val_acc,expected_accuracy_increase(%)
310,15,horizontal-flip,0.357,vertical-flip,0.296,dropout,0.447,vertical-flip,0.248,dropout,...,brighten,0.259,emboss,0.671,coarse-salt-pepper,0.131,invert,0.968,0.783,20.8
581,29,gamma-contrast,0.087,coarse-salt-pepper,0.734,rotate,0.031,coarse-salt-pepper,0.244,translate-x,...,brighten,0.462,brighten,0.229,coarse-salt-pepper,0.137,sharpen,0.99,0.767,19.2
170,8,rotate,0.097,vertical-flip,0.26,coarse-dropout,0.448,gaussian-blur,0.352,dropout,...,coarse-dropout,0.167,clouds,0.865,additive-gaussian-noise,0.14,crop,0.983,0.75,17.5
540,27,shear,0.809,gamma-contrast,0.856,clouds,0.183,translate-y,0.206,crop,...,additive-gaussian-noise,0.037,coarse-salt-pepper,0.358,horizontal-flip,0.132,translate-y,0.973,0.75,17.5
610,30,dropout,0.756,clouds,0.181,dropout,0.056,horizontal-flip,0.254,shear,...,coarse-dropout,0.553,horizontal-flip,0.806,coarse-salt-pepper,0.115,additive-gaussian-noise,0.161,0.75,17.5
37,1,gamma-contrast,0.844,coarse-salt-pepper,0.847,brighten,0.384,translate-y,0.057,translate-y,...,emboss,0.836,sharpen,0.648,emboss,0.957,rotate,0.87,0.725,15.0
516,25,invert,0.136,sharpen,0.24,vertical-flip,0.444,sharpen,0.2,vertical-flip,...,vertical-flip,0.098,vertical-flip,0.588,add-to-hue-and-saturation,0.149,rotate,0.158,0.717,14.2
539,26,sharpen,0.17,sharpen,0.723,add-to-hue-and-saturation,0.293,clouds,0.833,crop,...,fog,0.02,brighten,0.279,translate-x,0.135,coarse-salt-pepper,0.968,0.717,14.2
142,7,invert,0.533,translate-x,0.395,brighten,0.475,dropout,0.716,translate-y,...,rotate,0.092,emboss,0.552,crop,0.969,sharpen,0.221,0.7,12.5
80,4,sharpen,0.675,sharpen,0.778,horizontal-flip,0.663,crop,0.623,invert,...,gaussian-blur,0.451,crop,0.442,vertical-flip,0.359,dropout,0.689,0.7,12.5


In [12]:
get_top_policies(pd.read_csv('./autoaugment_results.csv'), 20).to_csv('./top_df',index=False)