# Imports

## built-in

In [None]:
import os
import json

## third-party

In [None]:
import numpy as np
import torch

In [None]:
# Pick the compute device once: use the GPU when torch reports CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Trailing expression so the notebook displays the chosen device.
device

## local

In [None]:
from source.constants import ALL_CANCER_TYPES
from source.constants import ALL_IMG_NORMS, ALL_EXTRACTOR_MODELS
from source.constants import DATASET_SPECIFIC_NORMALIZATION_CONSTANTS_PATH
from source.constants import DATA_DIR, FEATURE_VECTORS_SAVE_DIR

# Echo the imported project constants, one per line, so a wrong path or
# registry is visible before any heavy work starts.
print(
    f"DATA_DIR: {DATA_DIR}",
    f"FEATURE_VECTORS_SAVE_DIR: {FEATURE_VECTORS_SAVE_DIR}",
    f"ALL_CANCER_TYPES: {ALL_CANCER_TYPES}",
    f"ALL_IMG_NORMS: {ALL_IMG_NORMS}",
    f"DATASET_SPECIFIC_NORMALIZATION_CONSTANTS_PATH: {DATASET_SPECIFIC_NORMALIZATION_CONSTANTS_PATH}",
    f"ALL_EXTRACTOR_MODELS: {ALL_EXTRACTOR_MODELS}",
    sep="\n",
)

In [None]:
from source.feature_extraction.data import get_data_transform
# help(get_data_transform)

In [None]:
from extract_features import (
    prepare_directories,
    calculate_dataset_mean_std,
    update_dataset_specific_mean_std,
    make_pytorch_dataset,
    make_pytorch_dataloader,
    prepare_feature_extractor,
    extract_features,
    save_features,
)

## autoreload

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Notebook-level Constants

In [None]:
# Run configuration for this notebook: which cancer type to process and
# which image normalisation to apply.
CANCER_TYPE = 'lung_aca'
IMG_NORM = 'lc25k-lung_aca-resized'
# An 'lc25k'-prefixed normalisation name must contain the selected cancer type.
assert not IMG_NORM.startswith('lc25k') or CANCER_TYPE in IMG_NORM

# Which extractor model to use and how many images go in a batch.
EXTRACTOR_NAME = 'dinov2_vitb14'
BATCH_SIZE = 256

# Each choice must be one of the values listed in source.constants.
assert CANCER_TYPE in ALL_CANCER_TYPES
assert IMG_NORM in ALL_IMG_NORMS
assert EXTRACTOR_NAME in ALL_EXTRACTOR_MODELS

# Cell output: the entries found directly under the data directory.
os.listdir(DATA_DIR)

# Prepare location to save features, ids, and ids_2_img_paths mapping

In [None]:
# Get the image directory and the feature output directory for this
# (cancer type, normalisation, extractor) combination.
img_dir, features_save_dir = prepare_directories(
    all_img_dir_path=DATA_DIR,
    all_features_save_dir=FEATURE_VECTORS_SAVE_DIR,
    cancer_type=CANCER_TYPE,
    img_norm=IMG_NORM,
    extractor_name=EXTRACTOR_NAME,
)
print(
    f"img_dir:\n {img_dir}",
    f"features_save_dir:\n {features_save_dir}",
    sep="\n",
)

# One output file per artefact, all placed directly inside features_save_dir.
features_save_paths = {
    artefact: f'{features_save_dir}/{file_name}'
    for artefact, file_name in (
        ('ids', 'ids.npy'),
        ('ids_2_img_paths', 'ids_2_img_paths.json'),
        ('features', 'features.npy'),
    )
}

# Get Data Transform

In [None]:
# Build the transform from known normalisation constants. If IMG_NORM raises
# a KeyError, compute the mean/std from the images, build a 'manual'
# transform from them, and pass them to update_dataset_specific_mean_std.
try:
    data_transform = get_data_transform(img_norm=IMG_NORM)
except KeyError as missing_key:
    print(
        f"Key {missing_key} not found in either constansts_zoo of `data.get_norm_constants()` "
        f"or data-specific transforms in {DATASET_SPECIFIC_NORMALIZATION_CONSTANTS_PATH}"
    )
    print("Calculating mean and std for the dataset...")
    mean, std = calculate_dataset_mean_std(
        img_dir=img_dir,
        batch_size=BATCH_SIZE,
    )
    data_transform = get_data_transform(img_norm='manual', mean=mean, std=std)
    update_dataset_specific_mean_std(
        json_path=DATASET_SPECIFIC_NORMALIZATION_CONSTANTS_PATH,
        img_norm=IMG_NORM,
        mean=mean,
        std=std,
    )

print(CANCER_TYPE, data_transform, sep="\n")

# Get Feature Extractor

In [None]:
# Build the extractor named by EXTRACTOR_NAME for the selected device.
feature_extractor = prepare_feature_extractor(
    extractor_name=EXTRACTOR_NAME,
    device=device,
)

# Initialise a Dataset Instance

In [None]:
# Wrap the image directory in a dataset that applies data_transform.
dataset = make_pytorch_dataset(img_dir=img_dir, data_transform=data_transform)
# Display the first sample as a sanity check. Subscription is the idiomatic
# way to reach __getitem__ rather than calling the dunder method directly.
dataset[0]

In [None]:
# Display the dataset's `img_dir` attribute (presumably the directory passed
# to make_pytorch_dataset above; confirm against the dataset class).
dataset.img_dir

In [None]:
# Smoke test: run one image through the extractor and show the output shape.
# unsqueeze(0) adds a leading dimension (presumably the batch axis the model
# expects). Only the shape is inspected, so the forward pass runs under
# no_grad to avoid recording an autograd graph.
with torch.no_grad():
    sample_features = feature_extractor(
        dataset[0]['image'].to(device).unsqueeze(0)
    )
sample_features.shape

# Initialise a Dataloader Instance

In [None]:
# Batch the dataset and report its size.
dataloader = make_pytorch_dataloader(dataset=dataset, batch_size=BATCH_SIZE)
print(f"Total instances:  {len(dataloader.dataset)}")
print(f"Total batches:  {len(dataloader)}")
print()

# Summarise the first batch: tensors report their shape, lists their length,
# and anything else only its type.
first_batch = next(iter(dataloader))
for key, val in first_batch.items():
    if not isinstance(val, (torch.Tensor, list)):
        print(key, ":", type(val))
        continue
    extent = val.shape if isinstance(val, torch.Tensor) else len(val)
    print(key, type(val), ":", extent)

# Run inference on the whole dataset

In [None]:
# Run the extractor over every batch from the dataloader and keep the result
# for the save step.
features_and_info = extract_features(
    feature_extractor=feature_extractor, dataloader=dataloader, device=device
)

# Save features, ids, and ids_2_img_paths mapping

In [None]:
# Write the extracted contents to the output paths prepared earlier.
save_features(
    contents=features_and_info,
    paths=features_save_paths,
)

# Load saved features, ids, and ids_2_img_paths mapping

In [None]:
# Read each saved artefact back from disk to confirm the save step worked.
ids_path = features_save_paths["ids"]
print(ids_path)
print(np.load(ids_path))

# Print only the first element of the saved features array.
features_path = features_save_paths["features"]
print(features_path)
print(np.load(features_path)[0])

mapping_path = features_save_paths["ids_2_img_paths"]
print(mapping_path)
with open(mapping_path, "r") as f:
    ids_2_img_paths = json.load(f)
print(ids_2_img_paths)