In [1]:
import sys
sys.path.append("../")
from src.utils.wandb_utils import get_config, config_to_omegaconf

%load_ext autoreload
%autoreload 2

In [6]:
# W&B account and project that host the runs to be averaged.
user, project = "kealexanderwang", "importance-reweighing"

# Display names of the runs whose checkpoints are loaded and averaged.
run_names = [
    "lemon-donkey-92",
    "fanciful-sound-91",
    "hopeful-music-90",
    "vibrant-durian-89",
    "solar-grass-88",
]

# Mongo-style filter keeping only runs whose displayName is one of `run_names`.
# NOTE(review): presumably forwarded to the W&B API by get_config — confirm.
query = {"displayName": {"$in": run_names}}

In [10]:
from typing import Tuple


def init_model(name_config: Tuple[str, dict], ckpt_name="last.ckpt"):
    """Rebuild a trained run's model and datamodule on the CPU from its checkpoint.

    Args:
        name_config: ``(name, config)`` pair, as yielded by ``get_config``. Only the
            config is used. It is converted with ``config_to_omegaconf``. Its
            ``run_dir`` is read, and its ``trainer.gpus`` is overwritten with 0.
        ckpt_name: File name of the checkpoint inside ``<run_dir>/checkpoints/``.

    Returns:
        ``(model, datamodule)``. The model has the checkpoint's ``state_dict``
        loaded and is in eval mode on the CPU. The datamodule has had ``setup()``
        called on it.
    """
    import torch
    from src.train import hydra_init

    _, config = name_config  # the run name is not needed here
    config = config_to_omegaconf(config)
    run_dir = config["run_dir"]
    ckpt_path = f"{run_dir}/checkpoints/{ckpt_name}"

    config.trainer.gpus = 0  # don't use GPU for test time
    hydra_objs = hydra_init(config)

    model = hydra_objs.model
    # Without map_location, torch.load restores tensors to the device they were
    # saved from. A checkpoint written during GPU training would then fail to load
    # (or land on the GPU) even though this function is meant to run on the CPU.
    # NOTE(review): checkpoints are assumed to be trusted local files (torch.load
    # unpickles them).
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.cpu()

    datamodule = hydra_objs.datamodule
    datamodule.setup()
    return model, datamodule

In [11]:
# Fetch the configs of the runs matching `query`; init_model unpacks each entry as (name, config).
name_configs = get_config(query=query, user=user, project=project)

In [13]:
# One (model, datamodule) pair per run, each restored from that run's "last.ckpt" for CPU inference.
model_datamodule_lst = list(map(init_model, name_configs))

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Validation split taken from the first run's datamodule.
# NOTE(review): assumes every run shares the same validation split — confirm.
val_dataset = model_datamodule_lst[0][1].val_dataset
# The bare models, one per run; the datamodules are not needed past this point.
models = [model for model, _ in model_datamodule_lst]

In [16]:
import copy
# Deep copy, so that loading averaged weights into `averaged_model` later leaves
# models[0] itself unchanged.
averaged_model = copy.deepcopy(models[0])

In [21]:
# List the parameter/buffer names in the state dict; the keys averaged in the next step are picked from this list.
averaged_model.state_dict().keys()

odict_keys(['class_weights', 'architecture.0.weight', 'architecture.0.bias', 'architecture.2.weight', 'architecture.2.bias', 'architecture.5.weight', 'architecture.5.bias', 'architecture.7.weight', 'architecture.7.bias', 'architecture.9.weight', 'architecture.9.bias', 'architecture.13.weight', 'architecture.13.bias', 'architecture.15.weight', 'architecture.15.bias', 'architecture.17.weight', 'architecture.17.bias', 'criterion.weight'])

In [26]:
from typing import Dict, List, Optional
import torch

def average_params(target_state_dict: dict, other_state_dicts: List[dict], state_dict_keys: Optional[List[str]] = None):
    """Overwrite selected entries of ``target_state_dict`` with their element-wise mean.

    For each key, the tensors stored under that key in every dict of
    ``other_state_dicts`` are stacked along a new leading dimension and averaged.
    The result replaces ``target_state_dict[key]``. Keys that are not selected are
    left untouched.

    Args:
        target_state_dict: State dict to update. It is modified in place and also
            returned.
        other_state_dicts: State dicts to average over. Each one must contain every
            selected key, with tensors of matching shape.
        state_dict_keys: Keys to average. ``None`` (the default) averages every key
            present in ``target_state_dict``.

    Returns:
        ``target_state_dict`` (the same object), with the selected entries replaced
        by their averages.

    Raises:
        ValueError: If ``other_state_dicts`` is empty.
        KeyError: If a selected key is missing from one of ``other_state_dicts``.
    """
    if not other_state_dicts:
        raise ValueError("other_state_dicts must contain at least one state dict to average")
    if state_dict_keys is None:
        state_dict_keys = list(target_state_dict.keys())

    for state_dict_key in state_dict_keys:
        stacked = torch.stack([other[state_dict_key] for other in other_state_dicts], dim=0)
        if stacked.is_floating_point() or stacked.is_complex():
            averaged_params = stacked.mean(dim=0)
        else:
            # torch.mean rejects integer/bool tensors (e.g. integer counter buffers),
            # so average in float64 and round back to the original dtype.
            averaged_params = stacked.double().mean(dim=0).round().to(stacked.dtype)
        target_state_dict[state_dict_key] = averaged_params
    return target_state_dict

In [27]:
# Average only the `architecture.17` weight and bias (the highest-indexed layer in the
# state-dict keys listed above) across all runs. Every other entry keeps the values of
# models[0], from which `averaged_model` was deep-copied.
AVERAGED_PARAM_KEYS = ["architecture.17.weight", "architecture.17.bias"]
averaged_state_dict = average_params(
    averaged_model.state_dict(),
    [m.state_dict() for m in models],
    AVERAGED_PARAM_KEYS,
)
averaged_model.load_state_dict(averaged_state_dict)

<All keys matched successfully>

In [39]:
import torch
from torch.utils.data import DataLoader

# Evaluate the averaged model on the validation split.
# eval() is already inherited via deepcopy; calling it again is cheap and explicit.
averaged_model.eval()
dataloader = DataLoader(val_dataset, batch_size=256)
y_trues = []
y_preds = []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for x, y in dataloader:
        # NOTE(review): step() is assumed to return a 4-tuple whose third element is
        # the predicted labels — confirm against ImbalancedClassifierModel.step.
        _, _, y_pred, _ = averaged_model.step((x, y))
        y_trues.append(y)
        y_preds.append(y_pred)

y_trues = torch.cat(y_trues)
y_preds = torch.cat(y_preds)

# Fraction of validation examples whose prediction equals the true label.
val_acc = (y_trues == y_preds).float().mean().item()

# Fraction of validation examples predicted as label 1.
val_frac_predicted_pos = (y_preds == 1).sum().item() / len(y_preds)

In [40]:
val_acc  # validation accuracy of the averaged model (displayed as the cell's output)

0.6884999871253967


In [41]:
val_frac_predicted_pos  # share of validation examples predicted as label 1 (displayed as the cell's output)

0.7395
