In [1]:
from easydict import EasyDict  as edict
import torch
def create_args(exp_params):
    """Build the EasyDict configuration for one fine-tuning experiment.

    Parameters
    ----------
    exp_params : dict
        Experiment description. The keys read here are ``'seed'``,
        ``'method'``, ``'dataset'``, ``'corr'`` and ``'max_iters'``.

    Returns
    -------
    EasyDict
        ``args`` holding model, persistence, dataset and metric settings,
        with the task-specific settings nested under ``args.task_args``.

    Notes
    -----
    The pretrained checkpoint path is rebuilt from ``exp_params`` with the
    model name ``scnn``, the literal ``_True_`` field and the ``_best_.pth``
    suffix hard-coded, so only checkpoints following that exact naming
    scheme can be located.
    """
    args = edict()
    task_args = edict()
    args.exp_id="X"
    args.seed = exp_params['seed']                                 # stored exactly as received (no type conversion here)
    # --------------- MODEL ---------------------
    args.model = 'scnn'
    args.hidden_dim = 100
    args.output_dims = 1  # Equals number of classes
    
    # ---------- TRAINING PARAMS ----------------
    args.load_pretrained = True                                    # Do we start with a pretrained model or not 
    # Path to the pretrained model; the file name encodes method, dataset, correlation and seed.
    args.pretrained_path = f"../models/scnn_{exp_params['method']}_{exp_params['dataset']}_{exp_params['corr']}_True_{exp_params['seed']}_best_.pth" # path to pretrained model
    task_args.n_interventions = 0                                  # Amount of times we stop training before applying intervention (forgetting, playing)
    task_args.total_iterations = exp_params['max_iters']                         # 9452 = 1 epoch (original note; not verifiable from this file)
    
    # ---------- MODEL PERSISTENCE --------------
    args.save_model = False                                        # Save last model
    args.save_best = False                                          # Save best performing model
    args.save_stats = True                                         # Save final performance metrics
    args.save_model_folder = 'models'                              # Folder where models are stored 
    args.save_grads_folder = 'grads'                               # Folder where gradients are saved to
    args.save_model_path = ".pth"                                  # suffix for saved model (name depends on model settings)
    args.use_comet = False
    
    # ------------------- TRAINING METHOD ------------------------------
    args.max_cur_iter = 0
    args.task_iter = 0
    # Only the "task" mode is enabled; the other modes are kept below as commented-out options.
    args.mode = ["task"                                             # task = train on dataset defined in task_args 
                   #, 'play'                                        #  play = train on dataset defined in play_args 
                   #,'forget'                                       # forget = after training on task, forget using method defined in args.forget_method
                       ]
    args.base_method = "erm"                                      # gdro = group distributionally robust optimization
                                                                    # rw = reweight losses
                                                                    # erm = Empirical Risk Minimization 
    
    # --------- DATASET -----------------------------------------------------------------------------
    args.eval_datasets = dict()                                    # Which datasets to evaluate
    args.task_datasets = dict()                                    # Which datasets to train on
    args.dataset_paths = {'synmnist': "../../datasets/SynMNIST",      # Path for each dataset
                          'mnistcifar': "../../datasets/MNISTCIFAR"}
    # Single training environment: the requested dataset at the requested correlation (cast to float).
    args.task_datasets['env1'] = {'name': exp_params['dataset'], 'corr': float(exp_params['corr'])
                                  , 'splits': ['train', 'test'], 'bs': 10000, "binarize": True}
    
    # All datasets listed on eval_datasets will be evaluated. One dataset per key, however, each dataset may evaluate multiple splits.
    # Every task dataset is also registered for evaluation under the key 'task_<id>'.
    for ds_id, ds in args.task_datasets.items():
        args.eval_datasets[f'task_{ds_id}'] = ds 
    # Extra evaluation set: same dataset with corr fixed to 0.0, 'val' split only.
    args.eval_datasets['eval'] = {'name': exp_params['dataset'], 'corr': 0.0, 'splits': ['val'], 'bs': 50000, "binarize": True}
    # -------- METRICS -----------------------------------------------------------------------------
    args.metrics = ['acc', 'loss','worst_group_loss', 'worst_group_acc', "best_group_loss", "best_group_acc"]
    # --------------- Consolidate all settings on args --------------------
    args.task_args = task_args
    args.svdropout_p = 0.0
    return args

def load_model(model, weights_path, map_location="cpu"):
    """Load pretrained weights from a checkpoint file into ``model``.

    The checkpoint must be a mapping whose ``'model'`` entry is a state
    dict. The keys ``'fc.0.weight'`` and ``'fc.0.bias'`` are renamed to
    ``'fc.weight'`` and ``'fc.bias'`` before loading; every other key is
    passed through unchanged.

    Parameters
    ----------
    model : torch.nn.Module
        Model that receives the weights (modified in place).
    weights_path : str
        Path of the checkpoint file.
    map_location : optional
        Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a checkpoint
        saved on a GPU can be read on a machine without that device;
        ``load_state_dict`` then copies the values onto whatever device the
        model parameters already live on.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, with the matching weights loaded.

    Raises
    ------
    KeyError
        If the checkpoint has no ``'model'`` entry.
    """
    import warnings

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(weights_path, map_location=map_location)

    # Explicit rename table instead of a generic str.replace("0.", ""), which
    # would rewrite every "0." occurrence inside a key.
    renamed_keys = {'fc.0.weight': 'fc.weight', 'fc.0.bias': 'fc.bias'}
    state_dict = {renamed_keys.get(k, k): v for k, v in checkpoint['model'].items()}

    # strict=False is kept for backward compatibility, but mismatches are
    # reported so that a checkpoint that did not actually load is noticed.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_model({weights_path!r}): missing keys {list(result.missing_keys)}, "
            f"unexpected keys {list(result.unexpected_keys)}"
        )
    return model

In [2]:
import sys
from torch.optim import Adam
from tqdm.notebook import tqdm
sys.path.append('/media/alain/Data/Tesis/spur/')
from dataset import make_dataloaders
from train import train,evaluate_splits, run_eval_iteration
from models import create_model
from numpy.random import choice
from torch.utils.data import Subset, DataLoader
from utils import update_metrics, save_stats
# load a model
# load balanced dataset
# define dataset size (hyperparameter)
# finetune model on balanced dataset
# report metrics (worst_group_*, best_group_*, acc, loss)
# create a table with data at the method/dataset/spur/seed level, then aggregate at the method/dataset/spur level

def run_dfr_experiment(exp_params):
    """Fine-tune a pretrained model on a random subset of the env1 'test' split.

    Parameters
    ----------
    exp_params : dict
        Experiment description. Besides the keys read by ``create_args``,
        ``'ft_size'`` gives the number of samples in the fine-tuning subset.

    Returns
    -------
    tuple
        ``(args, all_metrics)``: the configuration built by ``create_args``
        and the metrics accumulated after every training iteration, keyed
        by dataset name.
    """
    print(exp_params)
    args = create_args(exp_params)
    # One empty history list per (split, metric) pair for every evaluated dataset.
    all_metrics = {
        ds_key: {f"{split}_{m}": [] for split in ["train", "val", "test"] for m in args.metrics}
        for ds_key in ('task_env1', 'eval')
    }

    model = create_model(args).cuda()
    model = load_model(model, args.pretrained_path).cuda()
    opt = Adam(model.parameters(), lr=0.001)

    dls = make_dataloaders(args)
    # Fine-tuning data comes from the env1 'test' split (original note: "balanced
    # version of dataset" — TODO confirm against make_dataloaders).
    dl = dls['task']['env1']['test']
    n_samples = len(dl.dataset)
    print(n_samples)

    # Sample WITHOUT replacement: numpy's default (replace=True) would put duplicate
    # indices in the subset, leaving fewer than ft_size distinct samples. The size is
    # clamped because replace=False cannot draw more samples than exist.
    ft_size = min(int(exp_params["ft_size"]), n_samples)
    random_indices = choice(n_samples, ft_size, replace=False).tolist()
    dl = DataLoader(Subset(dl.dataset, indices=random_indices), batch_size=10000, shuffle=True)

    n_iters = args.task_args.total_iterations
    for _ in tqdm(range(n_iters), total=n_iters):
        model, _, _ = train(model, dl, opt, args)
        metrics = evaluate_splits(model, dls['eval'], args, "task")
        # accumulate metrics
        for ds_name, m in metrics.items():
            all_metrics[ds_name] = update_metrics(all_metrics[ds_name], m)

    return args, all_metrics

def run_jtt_experiment(exp_params):
    """Fine-tune a pretrained model while up-weighting the samples it misclassifies.

    The pretrained model is first evaluated on the env1 'train' split; each
    misclassified sample gets weight ``exp_params['lambda']`` and every other
    sample gets weight 1. The weights are then passed to ``train``.

    Parameters
    ----------
    exp_params : dict
        Experiment description. Besides the keys read by ``create_args``,
        ``'lambda'`` is the weight given to misclassified samples.

    Returns
    -------
    tuple
        ``(args, all_metrics)``: the configuration built by ``create_args``
        and the metrics accumulated after every training iteration, keyed
        by dataset name.

    Raises
    ------
    ValueError
        If the training dataloader yields no batches.
    """
    print(exp_params)
    args = create_args(exp_params)
    # One empty history list per (split, metric) pair for every evaluated dataset.
    all_metrics = {
        ds_key: {f"{split}_{m}": [] for split in ["train", "val", "test"] for m in args.metrics}
        for ds_key in ('task_env1', 'eval')
    }

    model = create_model(args).cuda()
    model = load_model(model, args.pretrained_path).cuda()
    opt = Adam(model.parameters(), lr=0.001)

    dls = make_dataloaders(args)
    dl = dls['task']['env1']['train']              # train split of env1

    # Evaluate the pretrained model on EVERY batch of the train split. Previously only
    # the last batch was evaluated, which was correct only while one batch covered the
    # whole dataset.
    # NOTE(review): the weights follow the order in which `dl` yields samples; if the
    # loader shuffles, confirm that `train` sees the same order.
    incorrect_chunks = []
    for x, y, g in dl:
        data = {'input': x, 'labels': y, 'groups': g}
        out = run_eval_iteration(data, model, args)
        # NOTE(review): a 0.5 threshold assumes 'logits' holds probabilities; raw
        # logits would use 0 — confirm against run_eval_iteration.
        batch_incorrect = (out['logits'].squeeze() > 0.5) != y
        incorrect_chunks.append(batch_incorrect.reshape(-1))
    if not incorrect_chunks:
        raise ValueError("training dataloader produced no batches; cannot compute the error set")
    incorrect = torch.cat(incorrect_chunks)

    weights = torch.ones_like(incorrect).float().cuda()
    weights[incorrect] = exp_params['lambda']
    # Report the real number of evaluated samples instead of a hard-coded 9000.
    print(f"N# of Incorrect samples is: {int(incorrect.sum())}/{incorrect.numel()}")

    n_iters = args.task_args.total_iterations
    for _ in tqdm(range(n_iters), total=n_iters):
        model, _, _ = train(model, dl, opt, args, weights=weights)
        metrics = evaluate_splits(model, dls['eval'], args, "task")
        # accumulate metrics
        for ds_name, m in metrics.items():
            all_metrics[ds_name] = update_metrics(all_metrics[ds_name], m)

    return args, all_metrics

In [4]:
from os.path import join
from os import listdir
def choose_experiments(method, model_dir = "models"):
    """List the experiments whose checkpoint file name contains ``method``.

    File names are split on ``"_"`` and read positionally as
    ``<model>_<method>_<dataset>_<corr>_<ignored>_<seed>_...``.

    Parameters
    ----------
    method : str
        Substring that must appear in the file name (so ``"jtt"`` also
        matches a file whose method field is ``"jtt200"``).
    model_dir : str
        Directory that is scanned for checkpoint files.

    Returns
    -------
    list of dict
        One dict per matching file with the keys ``'model'``, ``'method'``,
        ``'dataset'``, ``'corr'`` and ``'seed'`` (all strings), ordered by
        file name so repeated runs process experiments in the same order.
        Matching files with fewer than six ``"_"``-separated fields are
        skipped instead of raising ``IndexError``.
    """
    def make_file_dict(f):
        # Returns None when the name does not follow the expected scheme.
        fields = f.split("_")
        if len(fields) < 6:
            return None
        return {
                'model': fields[0],
                'method': fields[1],
                'dataset': fields[2],
                'corr': fields[3],
                'seed': fields[5]
               }
    files = []
    # listdir() returns entries in arbitrary order; sort for reproducibility.
    for f in sorted(listdir(model_dir)):
        if method not in f:
            continue
        parsed = make_file_dict(f)
        if parsed is not None:
            files.append(parsed)
    return files

# Sweep driver: fine-tune every selected pretrained checkpoint once per
# hyper-parameter value and store the resulting statistics.
# NOTE(review): the checkpoints are always selected with "jtt", independent of
# the `method` switch below — confirm this is intended for the "dfr" branch.
exps = choose_experiments("jtt", model_dir="../models")
max_iters = 3
size_of_ft = 1000

method = "jtt"

if method == "dfr":
    # Sweep over the size of the fine-tuning subset.
    for size_of_ft in [10, 50, 100, 200, 500, 1000]:
        for e in tqdm(exps, total=len(exps)):
            e.update({'max_iters': max_iters, 'ft_size': size_of_ft})
            args, results = run_dfr_experiment(e)
            # Tag the stats with the subset size.
            args.base_method = args.base_method + f"_ft_{size_of_ft}"
            save_stats(args, results, root="../stats")
elif method == "jtt":
    # Sweep over the weight given to misclassified samples.
    for lmbda in [5, 10, 15, 20]:
        for e in tqdm(exps, total=len(exps)):
            e.update({'max_iters': max_iters, 'lambda': lmbda})
            args, results = run_jtt_experiment(e)
            # Tag the stats with the pretraining method and the lambda value.
            args.base_method = "{}_{}".format(e['method'], lmbda)
            print(args.base_method)
            save_stats(args, results, root="../stats")

  0%|          | 0/75 [00:00<?, ?it/s]

{'model': 'scnn', 'method': 'jtt200', 'dataset': 'mnistcifar', 'corr': '0.25', 'seed': '222', 'max_iters': 3, 'lambda': 5}
N# of Incorrect samples is: 720/9000


  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    test     train
------------------------  -------  -------
loss                      2.060    1.977
best_group_loss           0.001    0.002
worst_group_loss          4.236    4.292
acc                       53.30%   53.73%
best_group_acc            100.00%  99.96%
worst_group_acc           6.56%    6.42%
#   0-eval-metric    val
-------------------  -------
loss                 2.088
best_group_loss      0.001
worst_group_loss     4.390
acc                  53.30%
best_group_acc       100.00%
worst_group_acc      3.52%

#   1-task_env1-metric    test    train
------------------------  ------  -------
loss                      0.404   0.353
best_group_loss           0.028   0.042
worst_group_loss          0.838   0.837
acc                       83.10%  84.54%
best_group_acc            99.46%  99.07%
worst_group_acc           64.59%  62.31%
#   1-eval-metric    val
-------------------  -------
loss                 0.399
best_group_loss      0.032
worst_group_los

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    test     train
------------------------  -------  -------
loss                      3.196    2.557
best_group_loss           0.001    0.000
worst_group_loss          7.199    7.041
acc                       50.90%   51.32%
best_group_acc            100.00%  100.00%
worst_group_acc           0.00%    0.00%
#   0-eval-metric    val
-------------------  -------
loss                 3.023
best_group_loss      0.000
worst_group_loss     7.506
acc                  51.10%
best_group_acc       100.00%
worst_group_acc      0.00%

#   1-task_env1-metric    test    train
------------------------  ------  -------
loss                      0.507   0.288
best_group_loss           0.021   0.019
worst_group_loss          1.131   1.086
acc                       78.80%  87.38%
best_group_acc            99.29%  99.77%
worst_group_acc           52.46%  48.60%
#   1-eval-metric    val
-------------------  -------
loss                 0.444
best_group_loss      0.016
worst_group_lo

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    test     train
------------------------  -------  -------
loss                      3.222    2.737
best_group_loss           0.001    0.001
worst_group_loss          7.039    6.938
acc                       50.50%   50.20%
best_group_acc            100.00%  100.00%
worst_group_acc           0.00%    0.00%
#   0-eval-metric    val
-------------------  -------
loss                 3.064
best_group_loss      0.001
worst_group_loss     7.157
acc                  49.90%
best_group_acc       100.00%
worst_group_acc      0.00%

#   1-task_env1-metric    test     train
------------------------  -------  -------
loss                      0.660    0.416
best_group_loss           0.033    0.036
worst_group_loss          1.410    1.433
acc                       71.10%   81.43%
best_group_acc            100.00%  99.77%
worst_group_acc           36.89%   35.53%
#   1-eval-metric    val
-------------------  ------
loss                 0.571
best_group_loss      0.035
worst_g

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    test     train
------------------------  -------  -------
loss                      2.302    2.307
best_group_loss           0.006    0.006
worst_group_loss          4.817    4.678
acc                       50.40%   49.97%
best_group_acc            100.00%  100.00%
worst_group_acc           0.00%    0.04%
#   0-eval-metric    val
-------------------  -------
loss                 2.372
best_group_loss      0.006
worst_group_loss     4.720
acc                  49.60%
best_group_acc       100.00%
worst_group_acc      0.00%

#   1-task_env1-metric    test    train
------------------------  ------  -------
loss                      0.548   0.586
best_group_loss           0.067   0.073
worst_group_loss          1.039   1.108
acc                       73.30%  70.81%
best_group_acc            99.60%  99.56%
worst_group_acc           43.18%  41.70%
#   1-eval-metric    val
-------------------  ------
loss                 0.586
best_group_loss      0.072
worst_group_los

  0%|          | 0/3 [00:00<?, ?it/s]

#   0-task_env1-metric    test    train
------------------------  ------  -------
loss                      2.070   1.973
best_group_loss           0.013   0.013
worst_group_loss          4.281   4.359
acc                       54.30%  54.32%
best_group_acc            99.68%  99.23%
worst_group_acc           8.20%   6.36%
#   0-eval-metric    val
-------------------  -------
loss                 2.038
best_group_loss      0.004
worst_group_loss     4.421
acc                  54.20%
best_group_acc       100.00%
worst_group_acc      3.91%

#   1-task_env1-metric    test    train
------------------------  ------  -------
loss                      0.522   0.471
best_group_loss           0.079   0.082
worst_group_loss          1.066   1.080
acc                       76.20%  78.98%
best_group_acc            95.68%  97.01%
worst_group_acc           50.82%  51.83%
#   1-eval-metric    val
-------------------  ------
loss                 0.499
best_group_loss      0.074
worst_group_loss     1.0

KeyboardInterrupt: 