In [5]:
import numpy as np
from datasets import get_dataset
import sys 
import matplotlib.pyplot as plt
from tiffwrapper import make_composite
from decoders import get_decoder
import torch
import os
from stedfm.model_builder import get_base_model, get_pretrained_model_v2
from stedfm.configuration import Configuration 
from stedfm.DEFAULTS import BASE_PATH


DATASET = "lioness"
MODEL = "mae-lightning-small"
WEIGHTS = "MAE_SMALL_STED"
CHECKPOINT = "/home-local/Frederic/segmentation-baselines/mae-lightning-small/lioness/pretrained-MAE_SMALL_STED-46" 
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
backbone, cfg = get_pretrained_model_v2(MODEL, WEIGHTS)
_, _, test_dataset = get_dataset(name=DATASET, cfg=cfg)

in_channels 1
--- mae-lightning-small | /home-local/Frederic/baselines/mae-small_STED/pl_checkpoint-999.pth ---

--- Loaded model mae-lightning-small with weights MAE_SMALL_STED ---


In [7]:
class SegmentationConfiguration(Configuration):
    
    freeze_backbone: bool = True
    num_epochs: int = 300
    learning_rate: float = 1e-4
segmentation_cfg = SegmentationConfiguration()
for key, value in segmentation_cfg.__dict__.items():
        setattr(cfg, key, value)

model = get_decoder(backbone, cfg)
ckpt = torch.load(os.path.join(CHECKPOINT, "result.pt"))["model"]
model.load_state_dict(ckpt)
model = model.to(DEVICE)
model.eval()


===== Loading ViTSegmentationClassifier =====

--- Freezing backbone ---


ViTSegmentationClassifier(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(app

In [8]:
from tqdm import tqdm

def merge_masks(masks):
    colors = ["green", "magenta"] #, "cyan", "yellow"]
    min_value = np.min(masks)
    max_value = np.max(masks)
    image_rgb = make_composite(masks, luts=colors, ranges=[(0, np.max(masks)) for _ in range(masks.shape[0])])
    return image_rgb

N = len(test_dataset)
indices = np.random.randint(0, N, size=50)

for i in tqdm(indices, total=len(indices)):
    image, mask = test_dataset[i]
    pred = model(image.unsqueeze(0).to(DEVICE)).squeeze().detach().cpu().numpy()
    image = image.squeeze().cpu().numpy()
    mask = mask.cpu().numpy()
    mask = merge_masks(mask)
    pred = merge_masks(pred)
    fig, axs = plt.subplots(1, 3, figsize=(10, 5))
    axs[0].imshow(image, cmap="hot")
    axs[1].imshow(mask, vmin=0, vmax=1)
    axs[2].imshow(pred, vmin=0, vmax=1)
    # axs[0].set_title(i)
    for ax in axs:
        ax.axis("off")
    fig.savefig(f"./dummy_images/{DATASET}_example_{i}.pdf", bbox_inches="tight", dpi=1200)
    plt.close(fig)
    # plt.show()


100%|██████████| 50/50 [01:09<00:00,  1.39s/it]
