In [1]:
import sys
sys.path.append("../")
import dotenv
from src.utils.wandb_utils import get_config, config_to_omegaconf
from run import set_personal_dir_from_hostname
import os

%load_ext autoreload
%autoreload 2

In [2]:
set_personal_dir_from_hostname()
# load environment variables from `.env` file if it exists
dotenv.load_dotenv(override=True)

True

In [3]:
user = "kealexanderwang"
project = "importance-reweighing"

In [4]:
run_names = [
    "likely-mountain-495",
    "easy-field-494",
    "grateful-paper-493",
]

query = {"displayName": {"$in": run_names}}

In [5]:
from typing import Tuple


def init_model(name_config: Tuple[str, dict], ckpt_name="last.ckpt"):
    
    name, config = name_config
    config["data_dir"] = os.environ.get("DATA_DIR")
    config["datamodule/data_dir"] = os.environ.get("DATA_DIR")
    config["architecture/pretrained"] = False  # don't load resnet checkpoint for permission reasons
    
    config = config_to_omegaconf(config)
    run_dir = config["run_dir"]
    ckpt_path = f"{run_dir}/checkpoints/{ckpt_name}"
    
    config.trainer.gpus = 0  # don't use GPU for test time
    from src.train import hydra_init
    hydra_objs = hydra_init(config)
    
    model = hydra_objs.model
    import torch
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.cpu()
    
    datamodule = hydra_objs.datamodule
    datamodule.setup()
    return model, datamodule

In [6]:
name_configs = get_config(user=user, project=project, query=query)

In [7]:
model_datamodule_lst = [init_model(name_config) for name_config in name_configs]

<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable', 'class_weights']
Resetting final linear layer


GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Train class counts: Counter({0: 3682, 1: 1113})
Train group counts: Counter({0: 3498, 3: 1057, 1: 184, 2: 56})
Val class counts: Counter({0: 933, 1: 266})
Val group counts: Counter({0: 467, 1: 466, 3: 133, 2: 133})
Test class counts: Counter({0: 4510, 1: 1284})
Test group counts: Counter({0: 2255, 1: 2255, 3: 642, 2: 642})
Dataset classes were undersampled to Counter({1: 108, 0: 98})
Dataset groups were undersampled to Counter({1: 57, 2: 56, 3: 52, 0: 41})


  weights = torch.tensor(weights)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable', 'class_weights']
Resetting final linear layer
Train class counts: Counter({0: 3682, 1: 1113})
Train group counts: Counter({0: 3498, 3: 1057, 1: 184, 2: 56})
Val class counts: Counter({0: 933, 1: 266})
Val group counts: Counter({0: 467, 1: 466, 3: 133, 2: 133})
Test class counts: Counter({0: 4510, 1: 1284})
Test group counts: Counter({0: 2255, 1: 2255, 3: 642, 2: 642})
Dataset classes were undersampled to Counter({0: 130, 1: 115})
Dataset groups were undersampled to Counter({0: 74, 3: 59, 2: 56, 1: 56})


In [8]:
val_dataset = model_datamodule_lst[0][1].val_dataset
models = [m_d[0] for m_d in model_datamodule_lst]

In [9]:
import copy
averaged_model = copy.deepcopy(models[0])

In [10]:
models[0].state_dict().keys()

odict_keys(['architecture.conv1.weight', 'architecture.bn1.weight', 'architecture.bn1.bias', 'architecture.bn1.running_mean', 'architecture.bn1.running_var', 'architecture.bn1.num_batches_tracked', 'architecture.layer1.0.downsample.0.weight', 'architecture.layer1.0.downsample.1.weight', 'architecture.layer1.0.downsample.1.bias', 'architecture.layer1.0.downsample.1.running_mean', 'architecture.layer1.0.downsample.1.running_var', 'architecture.layer1.0.downsample.1.num_batches_tracked', 'architecture.layer1.0.conv1.weight', 'architecture.layer1.0.bn1.weight', 'architecture.layer1.0.bn1.bias', 'architecture.layer1.0.bn1.running_mean', 'architecture.layer1.0.bn1.running_var', 'architecture.layer1.0.bn1.num_batches_tracked', 'architecture.layer1.0.conv2.weight', 'architecture.layer1.0.bn2.weight', 'architecture.layer1.0.bn2.bias', 'architecture.layer1.0.bn2.running_mean', 'architecture.layer1.0.bn2.running_var', 'architecture.layer1.0.bn2.num_batches_tracked', 'architecture.layer1.0.conv3.w

In [11]:
from typing import List, Dict
import torch

def average_params(target_state_dict: dict, other_state_dicts: List[dict], state_dict_keys: List[str]=[]):
    if state_dict_keys == []:
        state_dict_keys = list(target_state_dict.keys())    
    for state_dict_key in state_dict_keys:
        other_params = [other[state_dict_key] for other in other_state_dicts]
        averaged_params = torch.stack(other_params, axis=0).float().mean(0) # WARNING: We're also averaging batch norm statistics here!
        target_state_dict[state_dict_key] = averaged_params
    return target_state_dict

In [12]:
averaged_state_dict = average_params(averaged_model.state_dict(), [m.state_dict() for m in models])
averaged_model.load_state_dict(averaged_state_dict)

<All keys matched successfully>

In [13]:
del models # save some memory, delete other models

In [16]:
y_pred

(tensor(0.2222, grad_fn=<DivBackward0>),
 tensor([[-1.7820,  1.7610],
         [-1.9164,  2.0339],
         [ 0.8052, -0.9037],
         [-1.4280,  1.2056],
         [-1.5421,  1.5030],
         [-1.5618,  1.2545],
         [-1.6002,  1.7873],
         [-3.2132,  2.8146],
         [-0.3205,  0.1856],
         [-1.9415,  1.9130],
         [-1.7621,  1.5743],
         [-2.7017,  2.3098],
         [-1.4035,  1.2672],
         [-0.3577,  0.2682],
         [-2.6781,  2.4736],
         [-1.2993,  1.3919],
         [-0.6123,  0.4669],
         [-1.9032,  1.8806],
         [-1.7044,  1.7879],
         [-0.8972,  0.5990],
         [-1.5478,  1.4132],
         [ 1.5000, -1.6037],
         [ 1.0748, -1.1990],
         [ 1.5667, -1.7044],
         [ 1.2408, -1.4194],
         [-0.1605,  0.2223],
         [ 0.2227, -0.3095],
         [ 0.2881, -0.4925],
         [-0.7544,  0.6778],
         [-0.8659,  0.8924],
         [-0.6766,  0.7123],
         [ 0.2415, -0.3054]], grad_fn=<AddmmBackward>),
 ten

In [18]:
import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(val_dataset, batch_size=64)
y_trues = []
y_preds = []
weights = []
group_labels = []
for batch in dataloader:
    _, logits, y_pred, _, _ = averaged_model.step(batch)

    y = batch.y
    g = batch.g
    w = batch.g
    
    y_trues.append(y)
    y_preds.append(y_pred)
    weights.append(w)
    group_labels.append(g)
    
y_trues = torch.cat(y_trues)
y_preds = torch.cat(y_preds)
group_labels = torch.cat(group_labels)
weights = torch.cat(weights)

correct = (y_trues == y_preds).float()

val_acc = correct.mean().item()

In [19]:
val_acc

0.9274395108222961

In [20]:
group_accs = {}                                                                   
correct = (y_trues == y_preds).float()
for g in torch.unique(group_labels).tolist():                                     
    in_g = group_labels == g                                                      
    acc = correct[in_g].mean().item()                                             
    group_accs[g] = acc                                                           

val_reweighted_acc = (correct * weights).sum(0) / weights.sum(0) 

In [21]:
print(group_accs)

{0: 0.9464668035507202, 1: 0.9442059993743896, 2: 0.8345864415168762, 3: 0.8947368264198303}


In [22]:
print(val_reweighted_acc)

tensor(0.9010)
