In [40]:
import torch
from models.model_utils import get_models, models_size, models, ModelWrapper
from dataset.pcam import get_pcam_dataloaders
from tqdm import tqdm
import numpy as np
import pandas as pd
# Compute device used by every later cell: the default CUDA device when
# PyTorch can see one, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [32]:
def get_logits(data, model):
    """Run ``model`` over every batch in ``data`` and collect logits and labels.

    The model is moved to the notebook-global ``device`` and put in eval mode
    before the pass.

    Args:
        data: iterable of ``(img, label)`` batches (e.g. a DataLoader). ``label``
            must be a tensor that ``torch.cat`` can join along dim 0.
        model: module whose forward returns a 2-tuple; the second element is
            taken as the logits (presumably ``(batch, num_classes)`` — TODO
            confirm against ``ModelWrapper``).

    Returns:
        ``(logits, labels)`` as NumPy arrays, concatenated along the batch axis
        in the order ``data`` yielded them.

    Raises:
        ValueError: if ``data`` yields no batches.
    """
    model.to(device)
    model.eval()
    labels = []
    logits = []
    # One no_grad context for the whole pass. Each batch's logits are copied to
    # CPU immediately, so the device never holds the outputs of a full split.
    with torch.no_grad():
        for img, label in tqdm(data):
            _, logit = model(img.to(device))
            logits.append(logit.cpu())
            labels.append(label.cpu())
    if not logits:
        # torch.cat([]) would fail with an opaque error; say what actually happened.
        raise ValueError("get_logits: `data` yielded no batches")
    return torch.cat(logits).numpy(), torch.cat(labels).numpy()

In [45]:
train_loader, val_loader, test_loader = get_pcam_dataloaders("", batch_size=64, num_workers=4, shuffle_train=False)
log_train = {}
log_test = {}
log_val = {}
# Ground-truth labels per split. No loader is shuffled (shuffle_train=False), so
# every model sees samples in the same order and labels only need storing once.
labels_by_split = {}
base = "./chkpnt/teachers/checkpoints/"
num_classes = 2
# Prefix that the training wrapper puts on every parameter name in the checkpoint.
ckpt_prefix = "model."
# Same evaluation order as before: test, val, then the largest split (train).
splits = [
    ("test", test_loader, log_test),
    ("val", val_loader, log_val),
    ("train", train_loader, log_train),
]
for name in models.keys():
    model = models[name]()
    # Load onto CPU: the weights are copied into `model` below and moved to
    # `device` inside get_logits, so a GPU copy of the checkpoint would only
    # waste memory (and would fail on a CPU-only machine).
    # NOTE(review): on torch >= 2.6 the default weights_only=True may reject
    # full Lightning checkpoints; pass weights_only=False for trusted files.
    checkpoint = torch.load(base + f"pcam_{name}.ckpt", map_location="cpu")
    # Strip only the leading prefix; str.replace would also remove "model."
    # occurring in the middle of a nested parameter name.
    state_dict = {
        (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    in_features = state_dict["fc.weight"].shape[1]
    model.fc = torch.nn.Linear(in_features, num_classes)
    model.load_state_dict(state_dict, strict=True)
    # The weights now live in `model`; drop the checkpoint copy before inference.
    del checkpoint, state_dict
    model = ModelWrapper(model)
    for split, loader, split_log in splits:
        yp, y = get_logits(loader, model)
        split_log[name] = yp
        stored = labels_by_split.setdefault(split, y)
        # Logits from different models are only comparable row-by-row if the
        # sample order is identical across runs.
        assert np.array_equal(stored, y), f"{split} labels differ for {name}; loader order is not deterministic"

Using cache found in /export/livia/home/vision/Bkarimian/.cache/torch/hub/kaiko-ai_towards_large_pathology_fms_main
  4%|▍         | 20/512 [00:13<05:25,  1.51it/s]

In [None]:
# Persist the logits of every split, one .npz per split and one array per model
# name. The train/val passes are the expensive part of the previous cell and
# would otherwise be lost with the kernel. The test file path is unchanged.
for split, split_logits in (("test", log_test), ("val", log_val), ("train", log_train)):
    np.savez(f"./chkpnt/{split}_logits.npz", **split_logits)


In [None]:
# Long-format table of test logits: one row per (model, sample), one
# `logits_<i>` column per class. `with` closes the NpzFile's file handle
# once the arrays have been read.
with np.load("./chkpnt/test_logits.npz") as loaded:
    frames = []
    for name in loaded.files:
        model_logits = loaded[name]  # presumably (n_samples, num_classes) — TODO confirm
        frame = pd.DataFrame(
            {f"logits_{i}": model_logits[:, i] for i in range(model_logits.shape[1])}
        )
        frame.insert(0, "model", name)
        frames.append(frame)
# pd.concat refuses an empty list, so an empty archive gets an empty frame instead.
df = (
    pd.concat(frames, ignore_index=True)
    if frames
    else pd.DataFrame(columns=["model", "logits_0", "logits_1"])
)
df.head()

Unnamed: 0,model,logits_0,logits_1
0,DINOL14,1.672314,-1.205865
1,DINOL14,-4.42147,4.478689
2,DINOL14,2.958307,-1.599156
3,DINOL14,-8.745443,10.118289
4,DINOL14,2.635189,-0.246823
