In [1]:
import sys
sys.path.append("../")
from src.utils.wandb_utils import get_config, config_to_omegaconf

%load_ext autoreload
%autoreload 2

In [2]:
# Identifiers handed to get_config() to locate the runs: owning user and project.
user, project = "kealexanderwang", "importance-reweighing"

In [3]:
# Display names of the runs that make up the ensemble.
run_names = [
    "hopeful-jazz-53",
    "summer-bee-52",
    "peachy-universe-51",
    "worthy-star-50",
    "peachy-hill-49",
]

# Filter selecting exactly the runs whose display name is listed above.
query = dict(displayName={"$in": run_names})

In [5]:
from typing import Tuple


def init_model(name_config: Tuple[str, dict], epoch: int):
    """Rebuild one run's model and datamodule on CPU from a saved checkpoint.

    Args:
        name_config: ``(run name, run config)`` pair. Only the config is used;
            it must provide ``run_dir`` and a ``trainer`` section.
        epoch: Epoch of the checkpoint to restore; the file read is
            ``{run_dir}/checkpoints/epoch={epoch}.ckpt``.

    Returns:
        ``(model, datamodule)``: the model with the checkpoint's ``state_dict``
        loaded, switched to eval mode and moved to CPU, and the datamodule
        after ``setup()`` has been called.
    """
    import torch
    from src.train import hydra_init

    _, config = name_config
    config = config_to_omegaconf(config)
    run_dir = config["run_dir"]
    ckpt_path = f"{run_dir}/checkpoints/epoch={epoch}.ckpt"

    config.trainer.gpus = 0  # don't use GPU for test time
    hydra_objs = hydra_init(config)

    model = hydra_objs.model
    # Remap storages to CPU while deserializing: without map_location, tensors
    # saved from a CUDA device are restored onto that device, which wastes GPU
    # memory here and fails outright on a machine without CUDA.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    model.cpu()

    datamodule = hydra_objs.datamodule
    datamodule.setup()
    return model, datamodule

In [6]:
# Look up the selected runs; each entry is consumed downstream as a (name, config) pair.
name_configs = get_config(
    user=user,
    project=project,
    query=query,
)

In [7]:
# Checkpoint epoch restored for every ensemble member.
# NOTE(review): presumably the final training epoch of these runs - confirm against the run configs.
CKPT_EPOCH = 499

# One (model, datamodule) pair per selected run.
model_datamodule_lst = [
    init_model(name_config, epoch=CKPT_EPOCH) for name_config in name_configs
]

<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


GPU available: True, used: False
TPU available: False, using: 0 TPU cores


<class 'src.pl_models.imbalanced_classifier_model.ImbalancedClassifierModel'> initialized with unused kwargs: ['params_total', 'params_trainable', 'params_not_trainable']
Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Score every model on a single validation set: the one held by the first run's datamodule.
# NOTE(review): assumes all runs share the same validation split - confirm.
val_dataset = model_datamodule_lst[0][1].val_dataset
# Keep only the model from each (model, datamodule) pair.
models = [model for model, _datamodule in model_datamodule_lst]

In [15]:
def ensemble_step(batch, models, method="majority"):
    """Combine the predictions of several models on one batch.

    Args:
        batch: Batch passed unchanged to every model's ``step``.
        models: Non-empty iterable of objects whose ``step(batch)`` returns a
            4-tuple with logits at index 1 and predicted labels at index 2
            (indices 0 and 3 are ignored here).
        method: ``"majority"`` returns the per-sample mode of the members'
            predicted labels. ``"average"`` applies softmax over the last
            dimension of each member's logits, averages the resulting
            probabilities across members and returns the argmax.

    Returns:
        Tensor holding the ensembled predicted label for each sample.

    Raises:
        ValueError: If ``method`` is not ``"majority"`` or ``"average"``, or
            if ``models`` is empty.
    """
    import torch

    # Validate up front so a typo fails before any forward pass is run,
    # instead of surfacing later as an unbound local variable.
    if method not in ("majority", "average"):
        raise ValueError(
            f"Unknown ensembling method {method!r}; expected 'majority' or 'average'."
        )
    models = list(models)
    if not models:
        raise ValueError("ensemble_step requires at least one model.")

    # Inference only: skip autograd bookkeeping for every member's forward pass.
    with torch.no_grad():
        step_outputs = [model.step(batch) for model in models]
    _, logits_lst, preds_lst, _ = zip(*step_outputs)

    if method == "majority":
        # Stack to (n_models, ...) and vote along the model dimension.
        return torch.mode(torch.stack(preds_lst), dim=0).values

    member_probs = torch.nn.functional.softmax(torch.stack(logits_lst), dim=-1)
    return member_probs.mean(dim=0).argmax(dim=-1)

In [19]:
import torch
from torch.utils.data import DataLoader

# Evaluate the ensemble on the validation set using probability averaging.
dataloader = DataLoader(val_dataset, batch_size=256)
y_true_batches = []
y_pred_batches = []
# Evaluation only: disable autograd so no graph is retained for the forward
# passes of every model on every batch.
with torch.no_grad():
    for x, y in dataloader:
        batch = x, y
        y_pred = ensemble_step(batch, models, "average")

        y_true_batches.append(y)
        y_pred_batches.append(y_pred)

# Concatenate the per-batch results into one tensor each.
y_trues = torch.cat(y_true_batches)
y_preds = torch.cat(y_pred_batches)

# Fraction of samples whose ensembled prediction equals the label.
val_acc = (y_trues == y_preds).float().mean().item()

In [20]:
# Display the accuracy currently held in val_acc (the "average" ensemble when cells run in order).
val_acc

0.7225000262260437

In [21]:
import torch
from torch.utils.data import DataLoader

# Evaluate the ensemble on the validation set using majority voting.
dataloader = DataLoader(val_dataset, batch_size=256)
y_true_batches = []
y_pred_batches = []
# Evaluation only: disable autograd so no graph is retained for the forward
# passes of every model on every batch.
with torch.no_grad():
    for x, y in dataloader:
        batch = x, y
        y_pred = ensemble_step(batch, models, "majority")

        y_true_batches.append(y)
        y_pred_batches.append(y_pred)

# Concatenate the per-batch results into one tensor each.
y_trues = torch.cat(y_true_batches)
y_preds = torch.cat(y_pred_batches)

# Fraction of samples whose ensembled prediction equals the label.
val_acc = (y_trues == y_preds).float().mean().item()

In [22]:
# Display the accuracy currently held in val_acc (the "majority" ensemble when cells run in order).
val_acc

0.7210000157356262