In [5]:
import sys
sys.path.append('..') # add bayesvlm to path

In [32]:
import os
from typing import Tuple

import numpy as np
import torch
import torch.distributions as dists
from torchmetrics.classification import MulticlassCalibrationError

from bayesvlm.utils import get_model_type_and_size, get_image_size, get_transform, load_model
from bayesvlm.data.factory import DataModuleFactory
from bayesvlm.hessians import load_hessians, optimize_prior_precision, compute_covariances
from bayesvlm.precompute import precompute_text_features, precompute_image_features, make_predictions

In [33]:
def evaluate_prediction(prediction: torch.Tensor, label: torch.Tensor, num_classes: int) -> Tuple[np.ndarray, np.ndarray, float]:
    """Score class-probability predictions against ground-truth labels.

    Args:
        prediction: 2-D tensor whose rows are class probabilities (argmax is
            taken over dim 1, and rows are used as ``Categorical`` ``probs``).
        label: tensor of integer class indices, one per row of ``prediction``.
        num_classes: number of classes, forwarded to the calibration metric.

    Returns:
        ``(acc, nlpd, ece)``:
        - ``acc``: per-sample NumPy array, 1.0 where the argmax class equals
          the label and 0.0 otherwise.
        - ``nlpd``: per-sample NumPy array of negative log predictive density.
        - ``ece``: scalar expected calibration error (20 bins, L1 norm).
    """
    # Work on detached CPU copies once: the `.numpy()` conversions below fail on
    # tensors that require grad or live on an accelerator, and a freshly built
    # torchmetrics metric lives on the CPU by default.
    prediction = prediction.detach().cpu()
    label = label.detach().cpu()

    # Class indices (not one-hot) of the most probable class per sample.
    predicted_class = prediction.argmax(dim=1)
    acc = (predicted_class == label).float().numpy()

    # `probs=` made explicit: the inputs are probabilities, not logits.
    nlpd = -dists.Categorical(probs=prediction).log_prob(label).numpy()

    ece_metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=20, norm='l1')
    ece = ece_metric(prediction, label).item()
    return acc, nlpd, ece

In [10]:
# define the model and dataset
model_str = 'clip-base'
dataset = 'food101'
# directory read by `load_hessians` for both the 'img' and 'txt' tags
hessian_dir = '../hessians/hessian_CLIP-ViT-B-32-laion2B-s34B-b79K'
pseudo_data_count = 10  # used as `n` for both prior-precision fits
batch_size = 32
num_workers = 4

# Pick the best available accelerator instead of hard-coding Apple's 'mps',
# so the notebook also runs on CUDA machines and CPU-only hosts.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Class prompts are tokenized before DataLoader workers fork, which makes
# huggingface/tokenizers warn (and risks deadlocks). Disable its parallelism
# before any tokenizer use; an explicit user setting still takes precedence.
os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')

In [7]:
# Resolve the model family/size and the input resolution from `model_str`,
# then build the matching preprocessing transform.
model_type, model_size = get_model_type_and_size(model_str)
transform_image_size = get_image_size(model_str)
transform = get_transform(model_type, transform_image_size)

# Instantiate the image encoder, the text encoder and the combined VLM on `device`.
image_encoder, text_encoder, vlm = load_model(
    model_str,
    device,
)



In [None]:
# Load the Hessians for the image ('img') and text ('txt') projections, then tune
# one prior precision per modality by maximising the marginal log-likelihood.
info = {'n_img': pseudo_data_count, 'n_txt': pseudo_data_count}
A_img, B_img = load_hessians(hessian_dir, tag='img', return_info=False)
A_txt, B_txt = load_hessians(hessian_dir, tag='txt', return_info=False)


def _fit_prior_precision(projection, A, B, n, lmbda_init=1500, lr=1e-2, num_steps=300):
    """Optimise the prior precision for one projection layer and return it as a float.

    Shared by the image and text fits so both use one set of optimiser
    hyper-parameters instead of two copy-pasted calls.
    """
    return optimize_prior_precision(
        projection,
        A=A,
        B=B,
        lmbda_init=lmbda_init,
        n=n,
        lr=lr,
        num_steps=num_steps,
        device=device,
        verbose=True,
    ).item()


info['lambda_img'] = _fit_prior_precision(image_encoder.vision_projection, A_img, B_img, info['n_img'])
info['lambda_txt'] = _fit_prior_precision(text_encoder.text_projection, A_txt, B_txt, info['n_txt'])

for key in ('n_img', 'n_txt', 'lambda_img', 'lambda_txt'):
    print(f"{key}:", info[key])

# pass the covariances to the model
cov_img, cov_txt = compute_covariances(A_img, B_img, A_txt, B_txt, info)
vlm.set_covariances(cov_img, cov_txt)

Epoch 1/300, loss: -3929096.5, lmbda: 1500.0001220703125
Epoch 2/300, loss: -3928474.5, lmbda: 1515.07568359375
Epoch 3/300, loss: -3927859.5, lmbda: 1530.2978515625
Epoch 4/300, loss: -3927250.5, lmbda: 1545.6646728515625
Epoch 5/300, loss: -3926650.0, lmbda: 1561.174560546875
Epoch 6/300, loss: -3926057.5, lmbda: 1576.8251953125
Epoch 7/300, loss: -3925472.25, lmbda: 1592.6143798828125
Epoch 8/300, loss: -3924894.5, lmbda: 1608.5386962890625
Epoch 9/300, loss: -3924325.75, lmbda: 1624.595947265625
Epoch 10/300, loss: -3923765.25, lmbda: 1640.782958984375
Epoch 11/300, loss: -3923212.0, lmbda: 1657.096435546875
Epoch 12/300, loss: -3922668.0, lmbda: 1673.5330810546875
Epoch 13/300, loss: -3922132.5, lmbda: 1690.0892333984375
Epoch 14/300, loss: -3921606.0, lmbda: 1706.7611083984375
Epoch 15/300, loss: -3921088.5, lmbda: 1723.5447998046875
Epoch 16/300, loss: -3920579.5, lmbda: 1740.435546875
Epoch 17/300, loss: -3920079.5, lmbda: 1757.4298095703125
Epoch 18/300, loss: -3919589.25, lmb

In [15]:
# Build the data module for `dataset`: train and test share the same
# preprocessing transform, and training-set shuffling is enabled.
f = DataModuleFactory(
    train_transform=transform,
    test_transform=transform,
    shuffle_train=True,
    batch_size=batch_size,
    num_workers=num_workers,
)
dm = f.create(dataset)
dm.setup()

In [24]:
# Embed every test image once; gradients are not needed for this inference pass.
with torch.no_grad():
    image_outputs_test, image_class_ids_test, image_ids_test = precompute_image_features(
        image_encoder=image_encoder,
        loader=dm.test_dataloader(),
    )

# Embed the class prompts the same way; they are used as `text_outputs` when predicting.
with torch.no_grad():
    label_outputs = precompute_text_features(
        text_encoder=text_encoder,
        class_prompts=dm.class_prompts,
        batch_size=batch_size,
    )

100%|██████████| 790/790 [03:53<00:00,  3.39it/s]
100%|██████████| 4/4 [00:01<00:00,  2.17it/s]


In [29]:
def _predict_logits(map_estimate):
    """Run `make_predictions` on the precomputed test embeddings.

    `map_estimate=False` gives the BayesVLM predictive logits (mean and
    variance); `map_estimate=True` gives the vanilla CLIP (MAP) logits.
    """
    return make_predictions(
        clip=vlm,
        image_outputs=image_outputs_test,
        text_outputs=label_outputs,
        batch_size=batch_size,
        device=device,
        map_estimate=map_estimate,
    )


# make predictions for vanilla BayesVLM and vanilla CLIP (MAP estimate)
logits_bayesvlm = _predict_logits(map_estimate=False)
logits_map = _predict_logits(map_estimate=True)

  0%|          | 0/790 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 790/790 [00:04<00:00, 162.43it/s]
  0%|          | 0/790 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 790/790 [00:03<00:00, 250.67it/s]


In [31]:
# Probit-style approximation: scale each mean logit by 1/sqrt(1 + pi/8 * variance)
# before the softmax, shrinking logits with larger variance toward zero.
kappa = (1. + torch.pi / 8 * logits_bayesvlm.var).sqrt().reciprocal()
probas_bayesvlm = (kappa * logits_bayesvlm.mean).softmax(dim=-1)

# The MAP estimate has no variance term: plain softmax over the mean logits.
probas_map = logits_map.mean.softmax(dim=-1)

In [41]:
# Score both predictive distributions against the same test labels.
num_classes = len(dm.class_prompts)

acc_bayesvlm, nlpd_bayesvlm, ece_bayesvlm = evaluate_prediction(
    probas_bayesvlm, image_class_ids_test, num_classes
)
acc_map, nlpd_map, ece_map = evaluate_prediction(
    probas_map, image_class_ids_test, num_classes
)

In [40]:
# Zero-shot results table: per-sample accuracy/NLPD are averaged, ECE is already a scalar.
table_rows = [
    ('acc', acc_bayesvlm.mean(), acc_map.mean()),
    ('nlpd', nlpd_bayesvlm.mean(), nlpd_map.mean()),
    ('ece', ece_bayesvlm, ece_map),
]

print(f"{'':<10}{'BayesVLM':<10}{'MAP':<10}")
for metric_name, bayes_value, map_value in table_rows:
    print(f"{metric_name:<10}{bayes_value:<10.4f}{map_value:<10.4f}")

          BayesVLM  MAP       
acc       0.8032    0.8008    
nlpd      0.6808    0.7053    
ece       0.0083    0.0387    
