In [1]:
from easydict import EasyDict  as edict
import torch
def create_args(exp_params):
    """Build the EasyDict configuration for one fine-tuning run.

    Args:
        exp_params: mapping that must provide 'seed', 'method', 'dataset',
            'corr' and 'max_iters'. 'corr' must be convertible with ``float``.

    Returns:
        EasyDict holding model, checkpoint, persistence, training-method,
        dataset and metric settings. Task-specific settings (number of
        interventions, total iterations) live in the nested ``.task_args``.
    """
    # Read the run parameters once, in the same order the settings use them.
    seed = exp_params['seed']
    method = exp_params['method']
    dataset = exp_params['dataset']
    corr = exp_params['corr']
    max_iters = exp_params['max_iters']

    cfg = edict()
    task_cfg = edict()

    # Identification.
    cfg.exp_id = "X"
    cfg.seed = seed

    # Model architecture.
    cfg.model = 'scnn'
    cfg.hidden_dim = 100
    cfg.output_dims = 1  # equals the number of classes

    # Start from the "best" checkpoint of the matching pretraining run.
    cfg.load_pretrained = True
    cfg.pretrained_path = f"../models/scnn_{method}_{dataset}_{corr}_True_{seed}_best_.pth"

    # Task schedule: no interventions (forgetting/playing), just `max_iters` iterations.
    task_cfg.n_interventions = 0
    task_cfg.total_iterations = max_iters

    # Persistence: only save_stats is enabled.
    cfg.save_model = False
    cfg.save_best = False
    cfg.save_stats = True
    cfg.save_model_folder = 'models'
    cfg.save_grads_folder = 'grads'
    cfg.save_model_path = ".pth"  # suffix appended to the model-specific file name
    cfg.use_comet = False

    # Training method: plain task training with ERM. The other modes
    # ('play', 'forget') and base methods ('gdro', 'rw') are not used here.
    cfg.max_cur_iter = 0
    cfg.task_iter = 0
    cfg.mode = ["task"]
    cfg.base_method = "erm"

    # Datasets: one task environment, every task dataset is also evaluated
    # (under 'task_<id>'), plus a 'val' split with corr 0.0 under 'eval'.
    cfg.eval_datasets = dict()
    cfg.task_datasets = dict()
    cfg.dataset_paths = {'synmnist': "../../datasets/SynMNIST",
                         'mnistcifar': "../../datasets/MNISTCIFAR"}
    cfg.task_datasets['env1'] = {'name': dataset, 'corr': float(corr),
                                 'splits': ['train', 'test'], 'bs': 10000, "binarize": True}
    for env_id, env_spec in cfg.task_datasets.items():
        cfg.eval_datasets[f'task_{env_id}'] = env_spec
    cfg.eval_datasets['eval'] = {'name': dataset, 'corr': 0.0, 'splits': ['val'],
                                 'bs': 50000, "binarize": True}

    # Metrics reported for every evaluated split.
    cfg.metrics = ['acc', 'loss', 'worst_group_loss', 'worst_group_acc',
                   "best_group_loss", "best_group_acc"]

    cfg.task_args = task_cfg
    return cfg

def load_model(model, weights_path, map_location="cpu"):
    """Load pretrained weights from a checkpoint file into ``model``.

    The checkpoint may be a dict holding the state dict under ``'model'``, or
    a bare state dict. The keys ``'fc.0.weight'`` / ``'fc.0.bias'`` are renamed
    to ``'fc.weight'`` / ``'fc.bias'``. Presumably the checkpoint's head was
    saved under ``fc.0`` (e.g. inside an ``nn.Sequential``) while the target
    model exposes it as ``fc`` (TODO confirm against the model definitions).

    Loading stays non-strict, as before. However, missing or unexpected keys
    are now reported with a warning instead of being silently ignored.

    Args:
        model: ``torch.nn.Module`` to load into; it is modified in place.
        weights_path: path to the checkpoint file.
        map_location: forwarded to ``torch.load``. It defaults to ``"cpu"`` so a
            checkpoint written on a GPU can always be read. ``load_state_dict``
            copies values into the model's existing tensors, so the model's
            device is unchanged.

    Returns:
        The same ``model`` instance.

    Raises:
        ValueError: if the file does not contain a state dict (e.g. the
            checkpoint was saved as ``None``).
    """
    import warnings

    head_renames = {'fc.0.weight': 'fc.weight', 'fc.0.bias': 'fc.bias'}

    checkpoint = torch.load(weights_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    if not isinstance(state_dict, dict):
        raise ValueError(
            f"No state dict found in checkpoint {weights_path!r} "
            f"(got {type(state_dict).__name__})"
        )

    renamed = {head_renames.get(key, key): value for key, value in state_dict.items()}
    result = model.load_state_dict(renamed, strict=False)
    if result.missing_keys or result.unexpected_keys:
        # strict=False would otherwise hide a checkpoint/architecture mismatch.
        warnings.warn(
            f"load_model({weights_path!r}): missing keys {list(result.missing_keys)}, "
            f"unexpected keys {list(result.unexpected_keys)}"
        )
    return model

In [2]:
import sys
from torch.optim import Adam
from tqdm.notebook import tqdm
sys.path.append('/media/alain/Data/Tesis/spur/')
from dataset import make_dataloaders
from train import train,evaluate_splits
from models import create_model
from numpy.random import choice
from torch.utils.data import Subset, DataLoader
from utils import update_metrics, save_stats
# Plan:
# 1. load a pretrained model
# 2. load the balanced dataset
# 3. choose the fine-tuning set size (hyperparameter)
# 4. fine-tune the model on the balanced dataset
# 5. report metrics (worst_group_*, best_group_*, acc, loss)
# 6. build a table at the method/dataset/spur/seed level, then aggregate (over seeds) to method/dataset/spur

def _finetune_seed(seed):
    """Map an experiment seed (an int or a string such as '777') to a valid numpy seed."""
    import zlib
    try:
        return int(seed) % (2 ** 32)
    except (TypeError, ValueError):
        # Non-numeric seed tokens are hashed deterministically (str hash() is salted per process).
        return zlib.crc32(str(seed).encode("utf-8"))


def _sample_finetune_indices(n_samples, ft_size, seed):
    """Draw min(ft_size, n_samples) distinct indices from range(n_samples), reproducibly per seed."""
    import numpy as np
    rng = np.random.RandomState(_finetune_seed(seed))
    # replace=False: numpy's default (replace=True) yields duplicate examples,
    # so the subset would hold fewer than ft_size distinct samples.
    return rng.choice(n_samples, size=min(ft_size, n_samples), replace=False)


def run_experiment(exp_params):
    """Fine-tune a pretrained model on a small random subset and record metrics.

    Steps:
    1. Load the checkpoint selected by ``create_args(exp_params)``.
    2. Draw ``exp_params['ft_size']`` distinct examples from the task
       dataset's 'test' split. The draw is seeded by the run's seed.
    3. For ``exp_params['max_iters']`` iterations, call ``train`` once on
       that subset, then ``evaluate_splits`` on the evaluation dataloaders.

    Args:
        exp_params: dict with 'seed', 'method', 'dataset', 'corr',
            'max_iters' and 'ft_size'.

    Returns:
        ``(args, all_metrics)``: the run configuration, and for each evaluated
        dataset ('task_env1', 'eval') a dict of metric histories keyed
        '<split>_<metric>', accumulated with ``update_metrics``.
    """
    print(exp_params)
    args = create_args(exp_params)
    all_metrics = {
        ds_name: {f"{split}_{m}": [] for split in ("train", "val", "test") for m in args.metrics}
        for ds_name in ('task_env1', 'eval')
    }

    model = create_model(args).cuda()
    model = load_model(model, args.pretrained_path)
    opt = Adam(model.parameters(), lr=0.001)

    dls = make_dataloaders(args)
    # The original note says this split is the balanced version of the dataset.
    # NOTE(review): evaluation below also reports task_env1 'test' metrics (see the
    # printed tables), so those numbers include the fine-tuning examples; the
    # held-out 'eval' split is the uncontaminated signal.
    ft_source = dls['task']['env1']['test'].dataset
    n_samples = len(ft_source)
    print(n_samples)
    indices = _sample_finetune_indices(n_samples, exp_params["ft_size"], args.seed)
    ft_loader = DataLoader(Subset(ft_source, indices=indices.tolist()), batch_size=10000, shuffle=True)

    n_iters = args.task_args.total_iterations
    for _iteration in tqdm(range(n_iters), total=n_iters):
        model, _, _ = train(model, ft_loader, opt, args)
        metrics = evaluate_splits(model, dls['eval'], args, "task")
        for ds_name, m in metrics.items():
            all_metrics[ds_name] = update_metrics(all_metrics[ds_name], m)

    return args, all_metrics

In [3]:
from os.path import join
from os import listdir
def choose_experiments(method, model_dir = "models"):
    """List the pretrained runs in ``model_dir`` that used ``method``.

    Checkpoint names are expected to look like
    ``<model>_<method>_<dataset>_<corr>_<flag>_<seed>_...`` (e.g.
    ``scnn_rw_mnistcifar_0.5_True_777_best_.pth``).

    The method field must equal ``method`` exactly. A substring test would
    also pick up names that merely contain it, such as "rw" in "forward".
    Names with fewer than six '_'-separated fields are skipped. Several files
    of the same run (e.g. different suffixes) yield a single entry, because
    every run loads the same pretrained checkpoint. Results come in sorted
    file-name order, so runs are reproducible.

    Args:
        method: pretraining method to select (second field of the name).
        model_dir: directory holding the checkpoints.

    Returns:
        List of dicts with keys 'model', 'method', 'dataset', 'corr', 'seed'
        (all strings).
    """
    def make_file_dict(parts):
        return {
            'model': parts[0],
            'method': parts[1],
            'dataset': parts[2],
            'corr': parts[3],
            'seed': parts[5],
        }

    experiments = []
    seen = set()
    for fname in sorted(listdir(model_dir)):
        parts = fname.split("_")
        if len(parts) < 6 or parts[1] != method:
            continue
        exp = make_file_dict(parts)
        key = tuple(exp.values())
        if key in seen:
            continue
        seen.add(key)
        experiments.append(exp)
    return experiments

# Fine-tune every "rw" checkpoint on subsets of increasing size and save the stats.
exps = choose_experiments("rw", model_dir="../models")
max_iters = 3
for size_of_ft in [10, 50, 100, 200, 500, 1000]:
    for e in tqdm(exps, total=len(exps)):
        # A fresh per-run dict: the shared entries of `exps` are not mutated.
        run_params = {**e, 'max_iters': max_iters, 'ft_size': size_of_ft}
        args, results = run_experiment(run_params)
        # NOTE(review): create_args always sets base_method = "erm", so the stats tag
        # is "erm_ft_<size>" whatever the pretraining method ("rw" here); confirm that
        # save_stats keeps runs from different pretraining methods apart.
        args.base_method += f"_ft_{size_of_ft}"
        save_stats(args, results, root="../stats")

  0%|          | 0/60 [00:00<?, ?it/s]

{'model': 'scnn', 'method': 'rw', 'dataset': 'mnistcifar', 'corr': '0.5', 'seed': '777', 'max_iters': 3, 'ft_size': 1000}
1000


  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  ------
loss                      1.105    1.359
best_group_loss           0.004    0.049
worst_group_loss          2.461    2.651
acc                       86.43%   85.70%
best_group_acc            99.72%   99.29%
worst_group_acc           72.27%   71.54%
#   0-eval-metric    val
-------------------  ------
loss                 1.549
best_group_loss      0.060
worst_group_loss     3.166
acc                  84.30%
best_group_acc       98.75%
worst_group_acc      70.52%

#   1-task_env1-metric    train    test
------------------------  -------  -------
loss                      2.808    3.271
best_group_loss           0.000    0.001
worst_group_loss          6.351    6.870
acc                       76.23%   75.60%
best_group_acc            100.00%  100.00%
worst_group_acc           49.30%   47.52%
#   1-eval-metric    val
-------------------  -------
loss                 3.283
best_group_loss      0.002
worst_gro

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      0.241    1.752
best_group_loss           0.001    0.000
worst_group_loss          0.351    1.906
acc                       95.41%   81.70%
best_group_acc            100.00%  100.00%
worst_group_acc           93.85%   79.29%
#   0-eval-metric    val
-------------------  ------
loss                 1.025
best_group_loss      0.141
worst_group_loss     1.828
acc                  88.20%
best_group_acc       96.67%
worst_group_acc      80.86%

#   1-task_env1-metric    train    test
------------------------  -------  -------
loss                      1.156    1.587
best_group_loss           0.000    0.000
worst_group_loss          1.987    2.793
acc                       87.51%   84.70%
best_group_acc            100.00%  100.00%
worst_group_acc           79.73%   74.69%
#   1-eval-metric    val
-------------------  ------
loss                 1.441
best_group_loss      0.370
worst_gr

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      1.971    2.410
best_group_loss           0.000    0.002
worst_group_loss          4.543    6.027
acc                       77.40%   75.90%
best_group_acc            100.00%  100.00%
worst_group_acc           51.68%   45.73%
#   0-eval-metric    val
-------------------  -------
loss                 2.685
best_group_loss      0.000
worst_group_loss     6.151
acc                  76.60%
best_group_acc       100.00%
worst_group_acc      47.50%

#   1-task_env1-metric    train    test
------------------------  -------  ------
loss                      0.379    0.671
best_group_loss           0.000    0.054
worst_group_loss          1.080    1.834
acc                       91.84%   90.70%
best_group_acc            100.00%  98.36%
worst_group_acc           78.10%   78.38%
#   1-eval-metric    val
-------------------  ------
loss                 0.867
best_group_loss      0.139
worst_gr

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      3.632    7.473
best_group_loss           0.000    0.000
worst_group_loss          8.045    15.152
acc                       73.71%   61.50%
best_group_acc            100.00%  100.00%
worst_group_acc           30.19%   24.27%
#   0-eval-metric    val
-------------------  -------
loss                 6.020
best_group_loss      0.000
worst_group_loss     15.875
acc                  64.90%
best_group_acc       100.00%
worst_group_acc      17.58%

#   1-task_env1-metric    train    test
------------------------  -------  -------
loss                      0.296    1.320
best_group_loss           0.001    0.000
worst_group_loss          0.423    1.387
acc                       94.79%   85.10%
best_group_acc            100.00%  100.00%
worst_group_acc           92.73%   84.10%
#   1-eval-metric    val
-------------------  ------
loss                 0.985
best_group_loss      0.187
wors

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  ------
loss                      0.002    27.955
best_group_loss           0.000    13.301
worst_group_loss          0.004    42.377
acc                       99.92%   1.50%
best_group_acc            100.00%  2.62%
worst_group_acc           99.84%   0.40%
#   0-eval-metric    val
-------------------  -------
loss                 14.104
best_group_loss      0.000
worst_group_loss     41.737
acc                  49.80%
best_group_acc       100.00%
worst_group_acc      0.00%

#   1-task_env1-metric    train    test
------------------------  -------  ------
loss                      0.023    22.195
best_group_loss           0.000    8.001
worst_group_loss          0.045    36.164
acc                       99.33%   7.30%
best_group_acc            100.00%  14.72%
worst_group_acc           98.66%   0.00%
#   1-eval-metric    val
-------------------  -------
loss                 11.257
best_group_loss      0.000
worst_g

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      0.327    2.164
best_group_loss           0.000    0.000
worst_group_loss          0.668    3.910
acc                       92.70%   72.90%
best_group_acc            100.00%  100.00%
worst_group_acc           84.43%   53.97%
#   0-eval-metric    val
-------------------  ------
loss                 1.428
best_group_loss      0.003
worst_group_loss     4.192
acc                  80.60%
best_group_acc       99.58%
worst_group_acc      50.20%

#   1-task_env1-metric    train    test
------------------------  -------  -------
loss                      0.252    1.060
best_group_loss           0.000    0.001
worst_group_loss          0.339    1.252
acc                       94.01%   83.90%
best_group_acc            100.00%  100.00%
worst_group_acc           92.43%   82.16%
#   1-eval-metric    val
-------------------  ------
loss                 0.782
best_group_loss      0.158
worst_gr

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      0.188    1.946
best_group_loss           0.000    0.000
worst_group_loss          0.326    2.164
acc                       96.03%   79.40%
best_group_acc            100.00%  100.00%
worst_group_acc           93.20%   77.41%
#   0-eval-metric    val
-------------------  ------
loss                 1.157
best_group_loss      0.140
worst_group_loss     2.601
acc                  87.60%
best_group_acc       98.33%
worst_group_acc      75.70%

#   1-task_env1-metric    train    test
------------------------  -------  -------
loss                      2.557    4.044
best_group_loss           0.000    0.000
worst_group_loss          5.197    7.931
acc                       79.32%   73.40%
best_group_acc            100.00%  100.00%
worst_group_acc           59.25%   49.38%
#   1-eval-metric    val
-------------------  ------
loss                 3.401
best_group_loss      0.026
worst_gr

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  ------
loss                      0.574    0.987
best_group_loss           0.010    0.096
worst_group_loss          1.157    1.997
acc                       89.07%   84.90%
best_group_acc            99.69%   97.41%
worst_group_acc           78.17%   72.29%
#   0-eval-metric    val
-------------------  ------
loss                 1.352
best_group_loss      0.148
worst_group_loss     2.760
acc                  84.00%
best_group_acc       97.63%
worst_group_acc      68.75%

#   1-task_env1-metric    train    test
------------------------  -------  ------
loss                      2.560    2.924
best_group_loss           0.000    0.009
worst_group_loss          5.185    6.185
acc                       68.56%   67.50%
best_group_acc            100.00%  99.61%
worst_group_acc           36.23%   32.20%
#   1-eval-metric    val
-------------------  ------
loss                 3.139
best_group_loss      0.020
worst_group_

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    train    test
------------------------  -------  -------
loss                      4.447    5.094
best_group_loss           0.000    0.000
worst_group_loss          8.955    10.594
acc                       60.99%   60.90%
best_group_acc            100.00%  100.00%
worst_group_acc           19.30%   18.85%
#   0-eval-metric    val
-------------------  -------
loss                 5.110
best_group_loss      0.000
worst_group_loss     10.686
acc                  59.60%
best_group_acc       100.00%
worst_group_acc      14.74%

#   1-task_env1-metric    train    test
------------------------  -------  ------
loss                      0.506    1.027
best_group_loss           0.000    0.117
worst_group_loss          1.219    2.006
acc                       90.16%   86.10%
best_group_acc            100.00%  98.09%
worst_group_acc           77.35%   74.66%
#   1-eval-metric    val
-------------------  ------
loss                 1.039
best_group_loss      0.102
worst_

KeyboardInterrupt: 

In [None]:
# Path of the gdro checkpoint inspected by the cells below.
'../models/scnn_gdro_mnistcifar_0.75_True_222_best_.pth'

In [36]:
# Display `e`, the experiment dict left in the namespace by a driver loop.
# The recorded output below is from an earlier run, and its fields appear to be
# parsed from a differently named checkpoint ('method': 'mnistcifar', 'seed': 'nofrz').
e

{'model': 'scnn',
 'method': 'mnistcifar',
 'dataset': '0.75',
 'corr': 'True',
 'seed': 'nofrz',
 'max_iters': 10}

In [23]:
# Check that the gdro checkpoint file exists on disk.
!ls ../models/scnn_gdro_mnistcifar_0.75_True_222_best_.pth

../models/scnn_gdro_mnistcifar_0.75_True_222_best_.pth


In [32]:
# Load the raw gdro checkpoint object to inspect what was saved.
gdro_ckpt_path = "../models/scnn_gdro_mnistcifar_0.75_True_222_best_.pth"
a = torch.load(gdro_ckpt_path)

In [34]:
# Show the loaded checkpoint. The recorded output is None: this file holds no
# state dict, so it cannot serve as a pretrained checkpoint.
print(a)

None


In [13]:
# numpy's `choice` samples WITH replacement by default (note the repeated 0s below),
# so choice(n, k) can return duplicate indices.
choice(2, size=5, replace=True)

array([0, 0, 1, 0, 1])