In [94]:
import pandas as pd
import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

In [65]:
# Collect every result CSV in the working directory.
# - Match on the ".csv" suffix (not a "csv" substring), and keep regular files only.
# - Sort, because os.listdir/glob order is arbitrary; sorting makes the
#   concatenation order (and thus merged_df's row order) reproducible.
res_files = sorted(p for p in Path(".").glob("*.csv") if p.is_file())
res_files

[PosixPath("asl_asl_test_notransfer+28022023_20:46:42_ps-{'grid_search': [[10, 'p', 10], [20]]}result_log.csv"),
 PosixPath("asl_test_asl_starter+27022023_17:59:06_ps-{'grid_search': [[2, 'p', 5]]}result_log.csv"),
 PosixPath("asl_test_asl_starter+27022023_13:18:49_ps-{'grid_search': [[2, 'p', 2], [4]]}result_log.csv")]

In [226]:
def merge_results(all_files, **read_csv_kwargs):
    """Read several result CSV files and stack them into one DataFrame.

    Parameters
    ----------
    all_files : iterable of str, path-like or file-like
        CSV sources to read; rows are stacked in the order given.
    **read_csv_kwargs
        Extra keyword arguments forwarded to ``pd.read_csv`` for every file
        (e.g. ``index_col=0`` to absorb a saved index column).

    Returns
    -------
    pd.DataFrame
        All rows, re-indexed 0..n-1 so row labels are unique across files.

    Raises
    ------
    ValueError
        If ``all_files`` is empty.
    """
    frames = [pd.read_csv(file, **read_csv_kwargs) for file in all_files]
    if not frames:
        raise ValueError("merge_results: no CSV files given")
    # Each file's index restarts at 0; ignore_index avoids duplicate row labels.
    return pd.concat(frames, ignore_index=True)

In [227]:
# Stack all result logs into one frame, then peek at the first rows.
merged_df = merge_results(all_files=res_files)
merged_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,loss,accuracy,time_this_iter_s,should_checkpoint,done,timesteps_total,episodes_total,training_iteration,trial_id,...,config/model,config/name_fn,config/num_cpu,config/num_gpu,config/pixel_replacement_method,config/proxy_steps,config/proxy_threshold,config/subset_images,config/transfer_imagenet,logdir
0,0,3.221793,"tensor(0.0878, device='cuda:0', dtype=torch.fl...",31.788656,True,False,,,8,a297c_00000,...,resnet18,<function asl_name_fn at 0x7f0f708ca050>,11,1,mean,"[10, 'p', 10]",0.921169,10000,False,/home/eragon/ray_results/train_proxy_steps_202...
1,1,3.368893,"tensor(0.0540, device='cuda:0', dtype=torch.fl...",42.158657,True,True,,,1,a297c_00001,...,resnet18,<function asl_name_fn at 0x7f0f708ca050>,11,1,mean,"[10, 'p', 10]",0.857935,10000,False,/home/eragon/ray_results/train_proxy_steps_202...
2,2,3.324268,"tensor(0.0576, device='cuda:0', dtype=torch.fl...",30.061499,True,True,,,2,a297c_00002,...,resnet18,<function asl_name_fn at 0x7f0f708ca050>,11,1,mean,"[10, 'p', 10]",0.899466,10000,False,/home/eragon/ray_results/train_proxy_steps_202...
3,3,3.361528,"tensor(0.0432, device='cuda:0', dtype=torch.fl...",31.432137,True,True,,,1,a297c_00003,...,resnet18,<function asl_name_fn at 0x7f0f708ca050>,11,1,mean,"[10, 'p', 10]",0.889297,10000,False,/home/eragon/ray_results/train_proxy_steps_202...
4,4,3.365735,"tensor(0.0364, device='cuda:0', dtype=torch.fl...",31.297887,True,True,,,1,a297c_00004,...,resnet18,<function asl_name_fn at 0x7f0f708ca050>,11,1,mean,"[10, 'p', 10]",0.825207,10000,False,/home/eragon/ray_results/train_proxy_steps_202...


In [228]:
# (number of rows, number of columns) of the merged results
(len(merged_df.index), len(merged_df.columns))

(360, 44)

## Helper functions

In [229]:
def tensor_str_to_val(string):
    """Parse a stringified scalar tensor into a ``np.float32``.

    Accepted inputs:
      - ``"tensor(0.0878, device='cuda:0', ...)"`` (the form seen in the logs)
      - ``"tensor(0.0878)"`` (no device/dtype part)
      - a bare numeric string such as ``"0.0878"``
      - an already-numeric value (int, float, numpy scalar, including NaN),
        which is converted directly. This makes re-applying the parser to a
        converted column harmless.

    Parameters
    ----------
    string : str or number
        Value to parse.

    Returns
    -------
    np.float32

    Raises
    ------
    ValueError
        If the text inside ``tensor(...)`` is not a number.
    """
    if isinstance(string, (int, float, np.number)):
        return np.float32(string)
    text = str(string).strip()
    if text.startswith("tensor("):
        text = text[len("tensor("):]
    # Keep only the value: drop ", device=..., dtype=..." and any closing paren.
    value = text.split(",")[0].strip().rstrip(")").strip()
    return np.float32(value)

In [230]:
def check_proxy(string):
    """Return True if a ``config/proxy_steps`` schedule contains a proxy step.

    Schedules are stored as the text of a Python list, e.g. ``"[10, 'p', 10]"``
    (contains a proxy step) or ``"[20]"`` (does not). The text is parsed, and
    the check is for an element equal to ``'p'``. Other strings that merely
    contain the letter p (e.g. ``"[2, 'step']"``) are therefore not counted.

    Parameters
    ----------
    string : str, list/tuple, or missing value
        The schedule as stored in the results frame.

    Returns
    -------
    bool
        False for missing (NaN/None) schedules.
    """
    import ast  # local import keeps this helper self-contained

    if isinstance(string, str):
        try:
            steps = ast.literal_eval(string)
        except (ValueError, SyntaxError):
            # Not a Python literal: fall back to a plain substring test.
            return "p" in string
    else:
        steps = string
    if isinstance(steps, (list, tuple, set)):
        return "p" in steps
    return steps == "p"

In [231]:
def calc_stats(values):
    """Summarise ``values`` as a three-line ``min / max / avg`` string.

    ``values`` must provide ``min()``, ``max()`` and ``mean()`` (e.g. a pandas
    Series or numpy array). Lines are joined with ``" \\n"`` (space + newline).
    """
    summary = (
        ("min", values.min()),
        ("max", values.max()),
        ("avg", values.mean()),
    )
    return " \n".join(f"{label}: {stat}" for label, stat in summary)

In [241]:
def _frame_as_table(frame):
    """Render ``frame`` as a Markdown table, or as plain text if ``tabulate`` is missing."""
    try:
        return frame.to_markdown()
    except ImportError:  # DataFrame.to_markdown needs the optional `tabulate` package
        return frame.to_string()


def check_wrt_accuracy(has_proxy, condition, head=None,
                       metric="accuracy", split_col="config/transfer_imagenet"):
    """Print the mean ``metric`` for each value of ``condition``, split by ``split_col``.

    Rows are grouped by (``condition``, ``split_col``) and ``metric`` is averaged
    within each group. The groups are sorted by that mean, highest first. Two
    tables are printed:
      - rows where ``split_col`` is True, under the heading "Transfer";
      - rows where it is False, under the heading "No Transfer".

    Parameters
    ----------
    has_proxy : pd.DataFrame
        Results frame to summarise (any subset of the merged results).
    condition : str
        Column whose values define the groups being compared.
    head : int, optional
        If given, show only the top ``head`` groups in each table.
    metric : str, default "accuracy"
        Numeric column to average.
    split_col : str, default "config/transfer_imagenet"
        Boolean column used to split the output into the two tables.

    Returns
    -------
    None
        Results are printed only.
    """
    # Select the metric before averaging instead of averaging every numeric column.
    grouped_vals = (
        has_proxy.groupby([condition, split_col])[metric]
        .mean()
        .reset_index()
        .sort_values(by=[metric], ascending=False)
    )
    for heading, flag in (("Transfer", True), ("No Transfer", False)):
        section = grouped_vals[grouped_vals[split_col].eq(flag)]
        if head is not None:
            section = section.head(head)
        print(heading)
        print(_frame_as_table(section))

## Aggregate results: runs with vs. without proxy steps

In [233]:
# Flag each run whose proxy schedule includes a proxy ('p') step.
merged_df["has_proxy"] = merged_df["config/proxy_steps"].map(check_proxy)

In [234]:
# Accuracy is logged as a stringified tensor, e.g. "tensor(0.0878, device='cuda:0', ...)".
# Parse it to float32 only while the column is still non-numeric. This makes the
# cell safe to re-run: there is nothing to re-parse once the values are floats.
if not pd.api.types.is_numeric_dtype(merged_df["accuracy"]):
    merged_df["accuracy"] = merged_df["accuracy"].apply(tensor_str_to_val)

In [235]:
# How many runs used a proxy schedule vs. not.
merged_df.has_proxy.value_counts()

True     216
False    144
Name: has_proxy, dtype: int64

In [236]:
# Split runs by whether their schedule contains a proxy step.
# Use the boolean mask directly instead of comparing with `== True` / `== False`.
# .copy() makes each subset an independent frame, so adding columns to it later
# does not trigger SettingWithCopyWarning.
proxy_mask = merged_df["has_proxy"].astype(bool)
has_proxy = merged_df[proxy_mask].copy()
nothas_proxy = merged_df[~proxy_mask].copy()

In [237]:
# Accuracy summary for runs with a proxy step.
has_proxy["accuracy"].pipe(calc_stats)

'min: 0.029600000008940697 \nmax: 0.9905999898910522 \navg: 0.662534236907959'

In [238]:
# Accuracy summary for runs without a proxy step.
nothas_proxy["accuracy"].pipe(calc_stats)


'min: 0.030400000512599945 \nmax: 0.9918000102043152 \navg: 0.5157617926597595'

## Parameter elimination

In [239]:
# Column labels of the merged results frame.
merged_df.keys()

Index(['Unnamed: 0', 'loss', 'accuracy', 'time_this_iter_s',
       'should_checkpoint', 'done', 'timesteps_total', 'episodes_total',
       'training_iteration', 'trial_id', 'experiment_id', 'date', 'timestamp',
       'time_total_s', 'pid', 'hostname', 'node_ip', 'time_since_restore',
       'timesteps_since_restore', 'iterations_since_restore', 'warmup_time',
       'config/batch_size', 'config/change_subset_attention',
       'config/clear_every_step', 'config/device', 'config/ds_name',
       'config/ds_path', 'config/enable_proxy_attention', 'config/epoch_steps',
       'config/experiment_name', 'config/fname_start',
       'config/gradient_method', 'config/image_size', 'config/load_proxy_data',
       'config/model', 'config/name_fn', 'config/num_cpu', 'config/num_gpu',
       'config/pixel_replacement_method', 'config/proxy_steps',
       'config/proxy_threshold', 'config/subset_images',
       'config/transfer_imagenet', 'logdir', 'has_proxy'],
      dtype='object')

In [242]:
# Mean accuracy of proxy runs per value of `config/clear_every_step`.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/clear_every_step")

Transfer
|    | config/clear_every_step   | config/transfer_imagenet   |   accuracy |
|---:|:--------------------------|:---------------------------|-----------:|
|  3 | True                      | True                       |   0.963012 |
|  1 | False                     | True                       |   0.962954 |
No Transfer
|    | config/clear_every_step   | config/transfer_imagenet   |   accuracy |
|---:|:--------------------------|:---------------------------|-----------:|
|  0 | False                     | False                      |  0.0648167 |
|  2 | True                      | False                      |  0.0584556 |


In [215]:
# Mean accuracy of proxy runs per value of `config/change_subset_attention`.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/change_subset_attention")

Transfer
|    |   config/change_subset_attention | config/transfer_imagenet   |   accuracy |
|---:|---------------------------------:|:---------------------------|-----------:|
|  1 |                              0.3 | True                       |   0.965348 |
|  5 |                              0.8 | True                       |   0.962621 |
|  3 |                              0.5 | True                       |   0.960981 |
No Transfer
|    |   config/change_subset_attention | config/transfer_imagenet   |   accuracy |
|---:|---------------------------------:|:---------------------------|-----------:|
|  0 |                              0.3 | False                      |  0.0648167 |
|  2 |                              0.5 | False                      |  0.0641917 |
|  4 |                              0.8 | False                      |  0.0559    |


In [216]:

# Mean accuracy of proxy runs per value of `config/gradient_method`.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/gradient_method")

Transfer
|    | config/gradient_method   | config/transfer_imagenet   |   accuracy |
|---:|:-------------------------|:---------------------------|-----------:|
|  5 | gradcamplusplus          | True                       |   0.964796 |
|  3 | gradcam                  | True                       |   0.962375 |
|  1 | eigencam                 | True                       |   0.961779 |
No Transfer
|    | config/gradient_method   | config/transfer_imagenet   |   accuracy |
|---:|:-------------------------|:---------------------------|-----------:|
|  4 | gradcamplusplus          | False                      |  0.0675958 |
|  0 | eigencam                 | False                      |  0.0597    |
|  2 | gradcam                  | False                      |  0.0576125 |


In [217]:

# Mean accuracy of proxy runs per value of `config/pixel_replacement_method`.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/pixel_replacement_method")

Transfer
|    | config/pixel_replacement_method   | config/transfer_imagenet   |   accuracy |
|---:|:----------------------------------|:---------------------------|-----------:|
|  3 | max                               | True                       |   0.965078 |
|  5 | mean                              | True                       |   0.963639 |
|  1 | halfmax                           | True                       |   0.962253 |
|  7 | min                               | True                       |   0.960964 |
No Transfer
|    | config/pixel_replacement_method   | config/transfer_imagenet   |   accuracy |
|---:|:----------------------------------|:---------------------------|-----------:|
|  2 | max                               | False                      |  0.0663611 |
|  6 | min                               | False                      |  0.0648556 |
|  0 | halfmax                           | False                      |  0.0601667 |
|  4 | mean                              | F

In [219]:

# Top 4 `config/proxy_threshold` values by mean accuracy among proxy runs.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/proxy_threshold", head=4)

Transfer
|     |   config/proxy_threshold | config/transfer_imagenet   |   accuracy |
|----:|-------------------------:|:---------------------------|-----------:|
|   0 |                 0.800505 | True                       |     0.9906 |
|  51 |                 0.826222 | True                       |     0.9906 |
| 131 |                 0.875854 | True                       |     0.9902 |
| 172 |                 0.914068 | True                       |     0.9878 |
No Transfer
|     |   config/proxy_threshold | config/transfer_imagenet   |   accuracy |
|----:|-------------------------:|:---------------------------|-----------:|
|  29 |                 0.81375  | False                      |     0.1319 |
| 135 |                 0.883611 | False                      |     0.1299 |
| 170 |                 0.913836 | False                      |     0.1269 |
| 155 |                 0.901073 | False                      |     0.121  |


In [220]:

# Mean accuracy of proxy runs per value of `config/subset_images`.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/subset_images")

Transfer
|    |   config/subset_images | config/transfer_imagenet   |   accuracy |
|---:|-----------------------:|:---------------------------|-----------:|
|  1 |                  10000 | True                       |   0.962983 |
No Transfer
|    |   config/subset_images | config/transfer_imagenet   |   accuracy |
|---:|-----------------------:|:---------------------------|-----------:|
|  0 |                  10000 | False                      |  0.0616361 |


In [221]:

# Mean accuracy of proxy runs per `config/proxy_steps` schedule.
check_wrt_accuracy(has_proxy=has_proxy, condition="config/proxy_steps")

Transfer
|    | config/proxy_steps   | config/transfer_imagenet   |   accuracy |
|---:|:---------------------|:---------------------------|-----------:|
|  2 | [2, 'p', 5]          | True                       |   0.964621 |
|  1 | [2, 'p', 2]          | True                       |   0.961346 |
No Transfer
|    | config/proxy_steps   | config/transfer_imagenet   |   accuracy |
|---:|:---------------------|:---------------------------|-----------:|
|  0 | [10, 'p', 10]        | False                      |  0.0616361 |


In [244]:

# Same breakdown over ALL runs (with and without proxy steps).
check_wrt_accuracy(has_proxy=merged_df, condition="config/proxy_steps")

Transfer
|    | config/proxy_steps   | config/transfer_imagenet   |   accuracy |
|---:|:---------------------|:---------------------------|-----------:|
|  4 | [4]                  | True                       |   0.966972 |
|  2 | [2, 'p', 5]          | True                       |   0.964621 |
|  1 | [2, 'p', 2]          | True                       |   0.961346 |
No Transfer
|    | config/proxy_steps   | config/transfer_imagenet   |   accuracy |
|---:|:---------------------|:---------------------------|-----------:|
|  3 | [20]                 | False                      |  0.0645514 |
|  0 | [10, 'p', 10]        | False                      |  0.0616361 |
