In [62]:
import glob
import os
from dataclasses import dataclass

import numpy as np
import torch
from tifffile import tifffile
from torch import Tensor
from torch.nn import functional as F

from create_dataloader import Augmentation, Dataset, DatasetScale, Label, MNInSecTVariant, SplitType
from model_picker import ModelType, get_model_name

# Data / model locations. Each can be overridden with an environment variable so
# the notebook runs on machines other than the one that produced the attention
# maps; the defaults are the original hard-coded paths.
MODELS_ROOT = os.environ.get("MNINSECT_MODELS_ROOT", "./models")
DATA_PATH = os.environ.get("MNINSECT_DATA_PATH", "./datasets/MNInSecT/")
ATTENTION_MAPS_ROOT = os.environ.get("MNINSECT_ATTENTION_MAPS_ROOT", "D:/attention_maps")

In [63]:
# --- Experiment selection ---------------------------------------------------
# The trained model whose attention maps are visualised, plus the dataset
# variant (augmentation + scale) it is paired with via MNInSecTVariant below.
model_type: ModelType = ModelType.SEResNet50
dataset_augmentation: Augmentation = Augmentation.Original
scale: DatasetScale = DatasetScale.Scale50

In [64]:
dataset_variant = MNInSecTVariant(dataset_augmentation, scale)
# Test split of the selected variant, loaded with as_rgb=False.
# NOTE(review): the seed presumably has to match the one used to create the
# train/test split when the model was trained — confirm.
dataset: Dataset = Dataset(
    type=SplitType.Test,
    variant=dataset_variant,
    seed=69420,
    as_rgb=False,
    MNInSecT_root=DATA_PATH,
)
model_name = get_model_name(model_type, dataset_variant)

@dataclass
class AttentionMap:
    """Locator for one attention map saved as a TIFF file.

    Files are looked up as
    ``<root>/<get_model_name(model, dataset)>/layer<layer>/<image_name>/<label abbreviation>*.tif``.
    """

    model: ModelType  # model whose maps were exported; part of the model-name directory
    dataset: MNInSecTVariant  # dataset variant; combined with ``model`` via get_model_name
    layer: int  # used to build the ``layer<N>`` directory name
    image_name: str  # per-image sub-directory name
    label: Label  # its ``abbreviation`` is the file-name prefix that is matched

    def fetch(self, root: str) -> Tensor:
        """Load this attention map from under ``root``.

        Args:
            root: top-level attention-map directory.

        Returns:
            The image read by ``tifffile.imread`` wrapped with ``torch.from_numpy``.
            If several files match, the lexicographically first one is used.

        Raises:
            FileNotFoundError: if no file matches the expected pattern.
        """
        map_dir = os.path.join(
            root,
            get_model_name(self.model, self.dataset),
            f"layer{self.layer}",
            self.image_name,
        )
        # Escape the literal parts so glob metacharacters ([, ], *, ?) in names
        # are matched literally; only the trailing ``*`` is a wildcard.
        pattern = os.path.join(glob.escape(map_dir), f"{glob.escape(self.label.abbreviation)}*.tif")
        # glob order is unspecified; sort so the chosen file is deterministic.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f"No attention map matches {pattern!r}")
        return torch.from_numpy(tifffile.imread(matches[0]))


def scale_image(image: Tensor, new_size) -> Tensor:
    """Resize a 2-D tensor to ``new_size`` with ``F.interpolate`` (its default ``nearest`` mode).

    Args:
        image: 2-D tensor (H, W) to resize.
        new_size: target (height, width), e.g. another tensor's ``.shape``.

    Returns:
        The resized 2-D tensor.
    """
    # interpolate works on (batch, channel, H, W) input for 2-D resizing:
    # add two leading singleton dims, resize, then drop them again.
    batched = image.unsqueeze(dim=0).unsqueeze(dim=0)
    return F.interpolate(batched, size=new_size)[0][0]

def combine_image(image: Tensor, true_map: Tensor, predicted_map: Tensor,
                  map_threshold: float = 0.2, image_threshold: float = 0.1):
    """Overlay the true-label and predicted-label attention maps on an input image.

    Both maps are resized with ``scale_image`` to the image's squeezed shape and
    values below ``map_threshold`` are zeroed. The image is normalised by its
    maximum and values below ``image_threshold`` are zeroed. The channels
    ``[image, predicted_map, true_map, opacity]`` are stacked along a new last
    axis (presumably viewed as RGBA — TODO confirm).

    Args:
        image: input image; assumed to squeeze to a 2-D (H, W) tensor, e.g.
            shape (1, H, W) — TODO confirm against the dataset.
        true_map: 2-D attention map for the true label.
        predicted_map: 2-D attention map for the predicted label.
        map_threshold: attention values below this are dropped (default 0.2).
        image_threshold: normalised image values below this are dropped (default 0.1).

    Returns:
        ``np.uint8`` array of shape (H, W, 4) with values in [0, 255].
    """
    target_shape = image.squeeze().shape

    # masked_fill returns new tensors, so the caller's maps are never modified.
    true_scaled = scale_image(true_map, target_shape)
    true_map = true_scaled.masked_fill(true_scaled < map_threshold, 0)

    predicted_scaled = scale_image(predicted_map, target_shape)
    predicted_map = predicted_scaled.masked_fill(predicted_scaled < map_threshold, 0)

    # Use the same squeezed view as target_shape so all channels line up.
    image = image.squeeze()
    peak = image.max()
    # Avoid 0/0 -> NaN for an all-zero image.
    image = image / (peak if peak != 0 else 1)
    image = image.masked_fill(image < image_threshold, 0)

    # Start fully opaque and make pixels with image / map content more transparent.
    opacity = torch.ones(image.shape)
    opacity[image != 0] -= 0.2
    opacity[true_map != 0] -= 0.1
    opacity[predicted_map != 0] -= 0.1

    combined = torch.stack([image, predicted_map, true_map, opacity], dim=-1).numpy()
    # Clip before the uint8 cast: out-of-range floats are not saturated by astype
    # (the result is platform-dependent), so e.g. a map value of 1.2 would wrap.
    return (np.clip(combined, 0.0, 1.0) * 255).astype(np.uint8)

def empty_folder(path: str) -> None:
    """Make sure ``path`` is an existing directory and delete the files in it.

    Only entries matched by the glob ``*`` are considered, so names starting with
    ``.`` are kept. Sub-directories are skipped (``os.remove`` cannot delete them);
    symlinks are removed as links.

    Args:
        path: directory to create if missing and then clear.
    """
    os.makedirs(path, exist_ok=True)
    # Escape the folder part so glob metacharacters in its name match literally.
    for entry in glob.glob(os.path.join(glob.escape(path), "*")):
        if os.path.isdir(entry) and not os.path.islink(entry):
            continue
        os.remove(entry)

In [65]:
# Dataset index of the image to render first; the rendering cell increments it
# at the end of every run, so re-running that cell steps through the dataset.
image_id: int = 0

In [66]:
# For each exported layer, write an overlay of the input image with the
# true-label and predicted-label attention maps of image `image_id`.
# NOTE: `image_id` is incremented at the end, so each re-run renders the next image.
OUTPUT_DIR = "./combined"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # make sure the output directory exists before writing
empty_folder(OUTPUT_DIR)


def _first_matching_file(pattern: str) -> str:
    """Return the lexicographically first path matching glob ``pattern``.

    Raises:
        FileNotFoundError: if nothing matches (instead of a bare IndexError).
    """
    matches = sorted(glob.glob(pattern))  # glob order is unspecified; sort for determinism
    if not matches:
        raise FileNotFoundError(f"No attention map matches {pattern!r}")
    return matches[0]


image_name = dataset.get_name_of_image(image_id)
# The first two characters of the image name, upper-cased, are used as the true label's abbreviation.
true_label = Label.from_abbreviation(image_name[:2].upper())
insect_image = dataset[image_id][0]
model_maps_root = os.path.join(ATTENTION_MAPS_ROOT, model_name)

for layer in range(1, 5):
    # Escape the directory so glob metacharacters in names are matched literally.
    layer_dir = glob.escape(os.path.join(model_maps_root, f"layer{layer}", image_name))

    true_map_filename = _first_matching_file(os.path.join(layer_dir, f"{true_label.abbreviation}*.tif"))
    true_map = torch.from_numpy(tifffile.imread(true_map_filename))

    prediction_map_filename = _first_matching_file(os.path.join(layer_dir, "*prediction*.tif"))
    if prediction_map_filename != true_map_filename:
        prediction_map = torch.from_numpy(tifffile.imread(prediction_map_filename))
    else:
        # The prediction file is the true-label file (presumably a correct
        # prediction): draw no separate predicted-map channel.
        prediction_map = torch.zeros_like(true_map)

    output_path = os.path.join(OUTPUT_DIR, f"{image_name[:6]}, {model_name} layer {layer}.tif")
    tifffile.imwrite(output_path, combine_image(insect_image, true_map, prediction_map))

image_id += 1