In [1]:
import sys
sys.path.append("../")
import dotenv
from src.utils.wandb_utils import get_config, config_to_omegaconf
from run import set_personal_dir_from_hostname
import os

%load_ext autoreload
%autoreload 2

In [2]:
set_personal_dir_from_hostname()
# load environment variables from `.env` file if it exists
dotenv.load_dotenv(override=True)

True

In [3]:
user = "kealexanderwang"
project = "importance-reweighing"

In [4]:
run_names = [
    "likely-mountain-495",
    "easy-field-494",
    "grateful-paper-493",
]

query = {"displayName": {"$in": run_names}}

In [5]:
from typing import Tuple


def init_model(name_config: Tuple[str, dict], ckpt_name="last.ckpt"):
    
    name, config = name_config
    config["data_dir"] = os.environ.get("DATA_DIR")
    config["datamodule/data_dir"] = os.environ.get("DATA_DIR")
    config["architecture/pretrained"] = False  # don't load resnet checkpoint for permission reasons
    
    config = config_to_omegaconf(config)
    run_dir = config["run_dir"]
    ckpt_path = f"{run_dir}/checkpoints/{ckpt_name}"
    
    config.trainer.gpus = 0  # don't use GPU for test time
    from src.train import hydra_init
    hydra_objs = hydra_init(config)
    
    model = hydra_objs.model
    import torch
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.cpu()
    
    datamodule = hydra_objs.datamodule
    datamodule.setup()
    return model, datamodule

In [6]:
name_configs = get_config(user=user, project=project, query=query)

In [7]:
model_datamodule_lst = [init_model(name_config) for name_config in name_configs]

<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable', 'class_weights']
Resetting final linear layer


GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Train class counts: Counter({0: 3682, 1: 1113})
Train group counts: Counter({0: 3498, 3: 1057, 1: 184, 2: 56})
Val class counts: Counter({0: 933, 1: 266})
Val group counts: Counter({0: 467, 1: 466, 3: 133, 2: 133})
Test class counts: Counter({0: 4510, 1: 1284})
Test group counts: Counter({0: 2255, 1: 2255, 3: 642, 2: 642})
Dataset classes were undersampled to Counter({1: 118, 0: 105})
Dataset groups were undersampled to Counter({3: 62, 2: 56, 0: 54, 1: 51})


  weights = torch.tensor(weights)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable', 'class_weights']
Resetting final linear layer
Train class counts: Counter({0: 3682, 1: 1113})
Train group counts: Counter({0: 3498, 3: 1057, 1: 184, 2: 56})
Val class counts: Counter({0: 933, 1: 266})
Val group counts: Counter({0: 467, 1: 466, 3: 133, 2: 133})
Test class counts: Counter({0: 4510, 1: 1284})
Test group counts: Counter({0: 2255, 1: 2255, 3: 642, 2: 642})
Dataset classes were undersampled to Counter({0: 115, 1: 104})
Dataset groups were undersampled to Counter({0: 63, 2: 56, 1: 52, 3: 48})


In [8]:
val_dataset = model_datamodule_lst[0][1].val_dataset
models = [m_d[0] for m_d in model_datamodule_lst]

In [9]:
def ensemble_step(batch, models, method="majority"):
    """Majority vote ensembling"""
    import torch
    loss_preds_y = [model.step(batch) for model in models]
    _, logits_lst, preds_lst, _, _ = zip(*loss_preds_y)
    if method == "majority":
        ensemble_preds = torch.stack(preds_lst)
        preds = torch.mode(ensemble_preds, axis=0).values
    elif method == "average":
        ensemble_logits = torch.stack(logits_lst)
        ensemble_probs = torch.nn.functional.softmax(ensemble_logits, dim=-1)
        average_prob = ensemble_probs.mean(0)
        preds = torch.argmax(average_prob, dim=-1)
    return preds

In [10]:
import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(val_dataset, batch_size=64)
y_trues = []
y_preds = []
weights = []
group_labels = []
for batch in dataloader:
    y_pred = ensemble_step(batch, models, "majority")
    
    y = batch.y
    g = batch.g
    w = batch.g
    
    y_trues.append(y)
    y_preds.append(y_pred)
    weights.append(w)
    group_labels.append(g)
    
y_trues = torch.cat(y_trues)
y_preds = torch.cat(y_preds)
group_labels = torch.cat(group_labels)
weights = torch.cat(weights)

correct = (y_trues == y_preds).float()

val_acc = correct.mean().item()

In [11]:
val_acc

0.9099249243736267

In [16]:
group_accs = {}                                                                   
correct = (y_trues == y_preds).float()
for g in torch.unique(group_labels).tolist():                                     
    in_g = group_labels == g                                                      
    acc = correct[in_g].mean().item()                                             
    group_accs[g] = acc                                                           

val_reweighted_acc = (correct * weights).sum(0) / weights.sum(0) 

In [17]:
print(group_accs)

{0: 0.9357602000236511, 1: 0.8991416096687317, 2: 0.8571428656578064, 3: 0.9097744226455688}


In [18]:
print(val_reweighted_acc)

tensor(0.8930)
