In [115]:
import torch 
from torch.utils.data import DataLoader 
from tqdm import tqdm 
import numpy as np
from typing import List, Optional, Tuple
import sys 
sys.path.insert(0, "../../")
from DEFAULTS import BASE_PATH 
from loaders import get_dataset 
from model_builder import get_pretrained_model_v2 

DATASET = "dl-sim"
MODEL = "mae-lightning-small"
WEIGHTS = "MAE_SMALL_IMAGENET1K_V1"
GLOBAL_POOL = "avg"


def get_classes(dataset: str):
    if dataset == "optim":
        return ["Actin", "Tubulin", "CaMKII", "PSD95"]  
    elif dataset == "neural-activity-states":
        return ["Block", "0Mg", "GluGly", "48hTTX"]
    elif dataset == "peroxisome":
        return ["6hGluc", "6hMeOH"]
    elif dataset == "polymer-rings":
        return ["CdvB1", "CdvB2"]
    elif dataset == "dl-sim":
        return ["adhesion", "factin", "microtubule", "mitosis"]
    else:
        raise ValueError(f"Dataset {dataset} not supported")
    
CLASSES = get_classes(DATASET)
N_CLASSES = len(CLASSES)


In [116]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on {DEVICE} ---")

model, cfg = get_pretrained_model_v2(
    name=MODEL,
    weights=WEIGHTS,
    path=None,
    mask_ratio=0.0,
    pretrained=True if "imagenet" in WEIGHTS.lower() else False,
    in_channels=3 if "imagenet" in WEIGHTS.lower() else 1,
    as_classifier=True,
    blocks="all",
    num_classes=4
)
model.to(DEVICE)
model.eval()

--- Running on cuda ---
mask_ratio 0.0
pretrained True
in_channels 3
blocks all
num_classes 4
--- mae-lightning-small | Pretrained Image-Net ---

--- Loaded model mae-lightning-small with ImageNet weights ---
--- Freezing every parameter in mae-lightning-small ---
--- Added linear probe to all frozen blocks ---


In [117]:
_, _, test_loader = get_dataset(
    name=DATASET,
    transform=None,
    training=True,
    path=None,
    batch_size=cfg.batch_size,
    n_channels=3 if "imagenet" in WEIGHTS.lower() else 1,
)

Samples:  DL-SIM-training.txt
adhesion 440
factin 623
microtubule 725
mitochondrial 669
----------
Class adhesion samples: 440
Class factin samples: 623
Class microtubule samples: 725
Class mitochondrial samples: 669
----------
Mean: 0.2438482108613763, Std: 0.1221117670657044
Samples:  DL-SIM-validation.txt
adhesion 229
factin 259
microtubule 282
mitochondrial 284
----------
Class adhesion samples: 229
Class factin samples: 259
Class microtubule samples: 282
Class mitochondrial samples: 284
----------
Mean: 0.23256373279542203, Std: 0.12156702883520837
Samples:  DL-SIM-testing.txt
adhesion 99
factin 126
microtubule 167
mitochondrial 135
----------
Class adhesion samples: 99
Class factin samples: 126
Class microtubule samples: 167
Class mitochondrial samples: 135
----------
Mean: 0.24179844530927158, Std: 0.12182894677490595
Training size: 5863
Validation size: 2535
Test size: 1242


In [118]:
embeddings, labels, dataset_idx = [], [], []
N = len(test_loader.dataset)
with torch.no_grad():
    for i in range(N):
        img = test_loader.dataset[i][0].unsqueeze(0)
        metadata = test_loader.dataset[i][1]
        img = img.to(DEVICE)
        label = metadata["label"]
        d_id = metadata["dataset-idx"]
        output = model.forward_features(img)
        embeddings.append(output)
        labels.append(label)
        dataset_idx.append(d_id)

embeddings = torch.cat(embeddings, dim=0)
labels = np.array(labels)
dataset_idx = np.array(dataset_idx)
print(embeddings.shape, labels.shape)
assert embeddings.shape[0] == labels.shape[0]

torch.Size([1242, 384]) (1242,)


In [119]:
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import trange


num_repetitions = 50
num_samples = 30
rep_accuracies = []
for n in trange(embeddings.shape[0]):
    random_embedding = embeddings[n]
    target_label = labels[n]
    img = test_loader.dataset[n][0].squeeze().cpu().numpy()
    img = img[0] if "imagenet" in WEIGHTS.lower() else img
    similarities = F.cosine_similarity(embeddings, random_embedding.unsqueeze(0), dim=1).cpu().numpy()
    sorted_indices = np.argsort(similarities)[::-1]
    query_labels = []

    for i in sorted_indices[1:num_samples+1]:
        sim = similarities[i]
        data_index = dataset_idx[i]
        similar_img = test_loader.dataset[data_index][0].squeeze().cpu().numpy()
        query_labels.append(labels[i])
        similar_img = similar_img[0] if "imagenet" in WEIGHTS.lower() else similar_img
        # if labels[i] != target_label:
        #     fig, axs = plt.subplots(1, 2)
        #     axs[0].imshow(img, cmap='hot')
        #     axs[1].imshow(similar_img, cmap='hot')
        #     axs[1].set_title(f"Similarity: {sim:.2f}")
        #     for ax in axs:
        #         ax.axis('off')
        #     plt.show()

    retrieval_accuracy = np.sum(np.array(query_labels) == target_label) / len(query_labels)
    rep_accuracies.append(retrieval_accuracy)

print(f"Average retrieval accuracy: {np.mean(rep_accuracies):.2f} ± {np.std(rep_accuracies):.2f}")



100%|██████████| 1242/1242 [00:30<00:00, 40.60it/s]

Average retrieval accuracy: 0.88 ± 0.22



