In [6]:
import sys, os
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
import nibabel as nib
import nrrd
import random
from omegaconf import OmegaConf
from typing import Iterable, Dict, Callable, Tuple, Union
from scipy.ndimage import binary_erosion, binary_dilation
import random


sys.path.append('..')
from dataset import EvalDataset
from utils import *
from model import get_model
from user_model import UserModel

In [3]:
# Evaluation settings are read from a YAML file with OmegaConf.
# The path is relative, so it resolves against the notebook's working directory.
EVAL_CONFIG_PATH = '../configs/eval.yaml'
cfg = OmegaConf.load(EVAL_CONFIG_PATH)

In [4]:
# Build the evaluation dataset for a single subject from the loaded config.
# NOTE(review): to_gpu=False presumably keeps the data on the CPU -- confirm in EvalDataset.
SUBJECT_ID = 729254
dataset = EvalDataset(subject_id=SUBJECT_ID, cfg=cfg,
                      modality='reconstruction', to_gpu=False)

In [None]:
# Attach a UserModel (built from the dataset labels and the config) to the dataset,
# then seed the dataset with its initial annotations.
dataset.user = UserModel(dataset.label, cfg)
# NOTE(review): 'per_class' presumably selects the strategy used by
# initial_annotation() below -- confirm in EvalDataset.
dataset.init = 'per_class'
# Start from an empty annotation set. clear_annotation() can be called again at any
# point to reset the annotations.
dataset.clear_annotation()
# Draw the initial annotations with a fixed seed and store them on the dataset.
annot = dataset.initial_annotation(seed=42)
dataset.update_annotation(annot)
# Report how many annotations are set. detach()/cpu() make the reduction
# device-independent.
print(f"number of annotations: {dataset.annotations.detach().cpu().sum()}")

In [5]:
# Load model
# Build the network from the config. return_state_dict=True makes get_model also
# hand back a state dict, which is loaded into the model here.
model, state_dict = get_model(cfg=cfg, return_state_dict=True)
load_result = model.load_state_dict(state_dict)
# Display the load report (e.g. whether all keys matched).
load_result

<All keys matched successfully>

In [8]:
# Extract features from one named layer of the model, cache them for the whole
# dataset, and convert them to a NumPy array for the random-forest step.
f_layer = 'encoder'
# FeatureExtractor (from utils) collects the outputs of the requested layers,
# presumably via PyTorch's module hooks (see nn.Module.register_forward_hook).
extractor = FeatureExtractor(model, layers=[f_layer])
# Run the extractor over the dataset. The result is indexed by layer name.
hooked_results = extractor(dataset)
features = hooked_results[f_layer]
# Move axis 1 to the end. This assumes features are (N, C, H, W), giving
# (N, H, W, C) -- TODO confirm.
# detach()/cpu() let .numpy() succeed even if the hooked tensor still tracks
# gradients or lives on a GPU. They are no-ops for a plain CPU tensor.
features = features.permute(0, 2, 3, 1).detach().cpu().numpy()